# inference for Gabor-GLM with ABC methods

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

- we run a classical likelihood-free inference algorithm (SMC-ABC) on the Gabor-GLM simulator
- like SNPE, SMC-ABC iteratively refines a posterior estimate across multiple rounds
- within each round, SMC-ABC runs a rejection-sampling scheme that rejects parameters based on a distance measure $d(x,x_o)$.


- the design of this distance measure $d$ can be tricky and requires good summary statistics.
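- here we use a standard approach: the Euclidean distance between (normalized) summary statistics, $d(x,x_o) = \left\| (x-x_o) / \sigma \right\|_2$, where $\sigma$ is the per-dimension standard deviation of the summary statistics.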
- this standard approach does not work well here, since the summary statistics in this case include a $1641$-dimensional spike-triggered average

In [None]:
%%capture
%matplotlib inline

import theano
theano.config.floatX='float64'

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
#from delfi.utils.viz import plot_pdf

from model.gabor_rf import maprf as model
from model.gabor_stats import maprfStats
from utils import setup_sim, get_data_o, setup_sampler, quick_plot, contour_draws

import sys; sys.path.append('../')
from common import col, svg, plot_pdf

from support_files.run_abc import run_smc

# parameters for this experiment

In [None]:
# Seed used to generate the observed data xo for the selected cell.
# NOTE: the MCMC run loaded further below was not seeded.
seed = 42

# Index of the toy cell to load.
idx_cell = 6

# SMC-ABC budget: maximal number of simulations and number of particles.
maxsim = 10 ** 6
n_particles = 10 ** 3

# Stem (no file extension) of the SMC-ABC result file for this cell.
savefile = ('results/SMC/toycell_{}/maprf_SMC_prior01_run_1_{}particles_param9'
            .format(idx_cell, n_particles))

# load cell, generate xo

In [None]:
# Build the simulator/generator, its prior and the RF side length d.
g, prior, d = setup_sim(seed, path='')

# Ground-truth parameters of the selected toy cell.
toy_cell_file = 'results/toy_cells/toy_cell_{}.npy'.format(idx_cell)
params_dict_true = np.load(toy_cell_file, allow_pickle=True)[()]
# Fix the temporal kernel to the constant value 1.
params_dict_true['kernel']['t'] = {'value': 1.}

# Put the true parameters and a seeded RNG into the model so that the
# observation xo is reproducible.
m = g.model
m.params_dict = params_dict_true.copy()
m.rng = np.random.RandomState(seed=seed)

# Read the true parameter vector first, then draw the observation (same order
# of model calls as before, so the RNG stream is unchanged).
pars_true = m.read_params_buffer()
obs = m.gen_single()
obs_stats = g.summary.calc([obs])

# Side by side: all summary statistics except the last one, reshaped to
# d x d, next to the ground-truth receptive field.
rf = g.model.params_to_rf(pars_true)[0]
sta_image = obs_stats[0, :-1].reshape(d, d)
plt.imshow(np.hstack((sta_image, rf)), interpolation='None')
plt.show()

# The last summary statistic is reported as the spike count.
print('spike count', obs_stats[0, -1])

# compute likelihoods

In [None]:
# Build the exact-likelihood machinery (mapRF inference object plus two
# compiled theano functions) that loglikelihood() below relies on.
from utils import setup_sampler  # same name as imported in the first cell
from maprf.utils import empty
from theano import In
import theano.tensor as tt
import scipy.stats as st

# Fresh simulator. `prior` (= g.prior) is the global that loglikelihood() and
# logjointdensities() evaluate, until it is rebound when the SNPE file is loaded.
g, prior_dict, d = setup_sim(seed, path='')
params_dict_true['kernel']['t'] = {'value' : 1. }
prior = g.prior

# mapRF inference object plus the data pair (data[0] -> x, data[1] -> y)
# derived from the observation `obs`.
inference, data = setup_sampler(prior_dict, obs, d, g, params_dict=params_dict_true, 
                          fix_position=False, parametrization='logit_φ')

# Gamma prior parameters
alpha = inference.priors['glm']['bias']['alpha']
beta = inference.priors['glm']['bias']['beta']

# could try to grab the following also from the existing graph?
x = tt.as_tensor_variable(data[0], 'x')
y = tt.as_tensor_variable(data[1], 'y')
η = inference.filter(x, inference.updates)  # filter output; exponentiated in β below

# Shape parameter of the Gamma density over exp(bias) used in loglikelihood():
# alpha + sum(y). NOTE(review): presumably the conjugate Gamma-Poisson update
# with y holding spike counts -- confirm.
α = theano.function([], tt.sum(y) + alpha,
                    on_unused_input='warn',
                    allow_input_downcast=True)

# get binsize without adding it to self.inputs
Δ = theano.shared(empty(inference.emt.binsize.ndim), inference.emt.binsize.name)
in_Δ = In(inference.emt.binsize, value=Δ.container, implicit=False)
Δ.set_value(m.dt)  # bin size taken from the simulator model's dt

# Rate parameter of that Gamma density: binsize * sum(exp(η)) + beta.
# NOTE(review): assumes the inference inputs are shared variables that pick up
# the values last written via inference.loglik[...] -- confirm.
i = list(inference.inputs.values()) + [in_Δ]
β = theano.function(i, inference.emt.binsize * tt.sum(tt.exp(η)) + beta,
                    on_unused_input='warn', allow_input_downcast=True)

def loglikelihood(params, fix_position=False):
    """Return the log-likelihood of the observed data for a parameter vector.

    `params` is given in the SNPE/SMC parametrization. It is translated to the
    mapRF parametrization through `g.model`, written into the inputs of
    `inference.loglik`, and evaluated. `inference.loglik` integrates out the
    GLM bias, so the bias-dependent terms are added back explicitly.

    Parameters
    ----------
    params : array-like
        Parameter vector in the SNPE/SMC parametrization.
    fix_position : bool, optional
        If True, the RF position inputs ('logit_xo', 'logit_yo') are set to 0
        instead of being derived from `params`. Defaults to False, which is
        the behaviour this function always had.

    Returns
    -------
    float
        The log-likelihood. It can be -inf where a density evaluates to zero.
    """
    # use g.model to translate between SNPE/SMC parametrization and mapRF parametrization
    g.model._set_pars_dict(params)
    params_dict = g.model.params_dict

    # update inference object with translated parameters
    loglik = inference.loglik
    if fix_position:
        loglik['logit_xo'] = 0.
        loglik['logit_yo'] = 0.
    else:
        kl = params_dict['kernel']['l']
        # log((1 + v) / (1 - v)) maps a position v in (-1, 1) to the real line
        loglik['logit_xo'] = np.log((1 + kl['xo']) / (1. - kl['xo']))
        loglik['logit_yo'] = np.log((1 + kl['yo']) / (1. - kl['yo']))
    loglik['kt'] = params_dict['kernel']['t']['value']
    ks = params_dict['kernel']['s']
    loglik['log_γ'] = np.log(ks['ratio'])
    loglik['log_b'] = np.log(ks['width'])
    loglik['log_A'] = np.log(ks['gain'])
    loglik['logit_φ'] = np.log(ks['phase'] / (np.pi - ks['phase']))        # logit on (0, pi)
    loglik['log_f'] = np.log(ks['freq'])
    loglik['logit_θ'] = np.log(ks['angle'] / (2 * np.pi - ks['angle']))    # logit on (0, 2 pi)

    # inference.loglik is integrated over biases!
    # If all 9 parameters are divided into 8 params for the spatial kernel params_ks and
    # one bias parameter, then the probabilities work out as
    # p(x|params_ks, bias) = p(bias|x, params_ks) * p(x | params_ks) / p(bias|params_ks)
    ll = inference.loglik()  # = log p(x | params_ks)

    exp_bias = np.exp(params_dict['glm']['bias'])
    # logpdf instead of log(pdf): stays finite where the pdf would underflow to 0.
    ll += st.gamma.logpdf(exp_bias, a=α(), scale=1. / β(), loc=0)  # = log p(bias | x, params_ks)
    # NOTE(review): both bias terms are evaluated at exp(bias), while
    # logjointdensities() evaluates the prior directly at `params`; confirm that
    # prior dimension 0 is defined on the exp(bias) scale.
    ll -= prior.eval(exp_bias, ii=0, log=True)  # = log p(bias | params_ks)

    return ll

def logjointdensities(params):
    """Return log p(xo, params) = log p(xo | params) + log p(params).

    The prior term uses the notebook-global `prior`, i.e. whichever prior is
    bound at call time.
    """
    log_lik = loglikelihood(params)
    log_prior = prior.eval(params, log=True)
    return log_lik + log_prior

# load SMC results

In [None]:
class normed_summary():  # definition just necessary for loading SMC results below
    """Summary statistics standardized with stored means and standard deviations.

    Only needed so that the pickled SMC results can be loaded. `stats_mean`
    and `stats_std` are looked up as notebook globals at call time.
    """

    def calc(self, y):
        raw_stats = g.summary.calc(y)
        centered = raw_stats - stats_mean
        return centered / stats_std

# Load the SMC-ABC results (a pickled dict, hence allow_pickle=True).
res = np.load('results/SMC/toycell_6/maprf_SMC_prior01_run_1_1000particles_param9.npy',
    allow_pickle=True)[()]

# Expose every entry of the results dict as a notebook global.
# NOTE(review): presumably this is where `stats_mean`/`stats_std` (read by
# normed_summary.calc) come from -- confirm against the SMC run script.
for k, v in res.items():
    globals()[k] = v

# Particles of the final SMC round.
params_SMC = res['all_ps'][-1]
n_smc = params_SMC.shape[0]
lls_SMC = np.zeros(n_smc)
ljs_SMC = np.zeros(n_smc)
for i in range(n_smc):
    lls_SMC[i] = loglikelihood(params_SMC[i, :])
    ljs_SMC[i] = logjointdensities(params_SMC[i, :])

# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
plt.hist(lls_SMC, density=True)
# Ground-truth log-likelihood as a full-height line, independent of the y-scale.
plt.axvline(loglikelihood(pars_true), color='r')
plt.show()

In [None]:
# Re-create the simulator, then simulate one dataset per final SMC particle.
g, _, d = setup_sim(seed, path='')
corrs_SMC = np.zeros(params_SMC.shape[0])
for i in range(params_SMC.shape[0]):
    out = g.model.gen_single(params_SMC[i])
    # Pearson correlation between observed and simulated data.
    corrs_SMC[i] = np.corrcoef(obs['data'], out['data'])[0, 1]

# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
plt.hist(corrs_SMC, density=True)
plt.show()

# load SNPE results

In [None]:
# Load the SNPE posterior together with the proposal and prior of that run.
tmp = np.load('results/SNPE/toycell_6/maprf_100k_prior01_run_1_round2_param9_nosvi_CDELFI_posterior.npy',
              allow_pickle=True)[()]
# NOTE: this rebinds the global `prior`; loglikelihood() and
# logjointdensities() read that global, so from here on they use the SNPE prior.
posterior, proposal, prior = tmp['posterior'], tmp['proposal'], tmp['prior']

params_SNPE = posterior.gen(1000)

n_snpe = params_SNPE.shape[0]
lls_SNPE = np.zeros(n_snpe)
ljs_SNPE = np.zeros(n_snpe)
for i in range(n_snpe):
    lls_SNPE[i] = loglikelihood(params_SNPE[i, :])
    ljs_SNPE[i] = logjointdensities(params_SNPE[i, :])

# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
plt.hist(lls_SNPE, density=True)
# Ground-truth log-likelihood as a full-height line, independent of the y-scale.
plt.axvline(loglikelihood(pars_true), color='r')
plt.show()

In [None]:
# Re-create the simulator, then simulate one dataset per SNPE posterior sample.
g, _, d = setup_sim(seed, path='')
corrs_SNPE = np.zeros(params_SNPE.shape[0])
for i in range(params_SNPE.shape[0]):
    out = g.model.gen_single(params_SNPE[i])
    # Pearson correlation between observed and simulated data.
    corrs_SNPE[i] = np.corrcoef(obs['data'], out['data'])[0, 1]

# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
plt.hist(corrs_SNPE, density=True)
plt.show()

In [None]:
# Visualize the SNPE posterior with plot_pdf (from ../common); the trailing ';'
# suppresses the cell's return-value output.
plot_pdf(posterior, lims=[-4,4], figsize=(16,16));

# load MCMC results

In [None]:
idx_cell = 6 # load toy cell number 6 (cosine-shaped RF with 1Hz firing rate)

# NOTE(review): fix_position and parametrization only describe the settings of
# the MCMC run; they are not used in this cell, and loglikelihood() evaluates
# with fix_position=False.
fix_position=True         # fixes RF position during sampling to (0,0)
parametrization='logit_φ' # chosen parameterization of Gabor (affects priors !) 

n_samples = 1000000  # number of MCMC samples

savefile = 'results/MCMC/toycell_' + str(idx_cell) + '/maprf_MCMC_prior01_run_1_'+ str(n_samples)+'samples_param9_5min'

# MCMC trace dict, one entry per sampled quantity.
T = np.load(savefile+'.npy',allow_pickle=True)[()]['T']
# Stack each trace as one column, in this fixed key order. NOTE(review):
# presumably the same order as the SNPE/SMC parameter vectors, since rows are
# passed to loglikelihood() below -- confirm.
samples_MCMC = np.hstack([np.atleast_2d(T[key].T).T for key in ['bias', 'gain', 'logit_φ', 'log_f','logit_θ','log_γ','log_b', 'logit_xo', 'logit_yo']])
# Thin the chain: every 500th sample of the last 500000 -> 1000 samples.
samples_MCMC = samples_MCMC[-500000:-1:500, :]


In [None]:
# Re-expose the SMC results dict as notebook globals.
# NOTE(review): if `res` holds keys such as `prior`, this rebinds globals that
# loglikelihood()/logjointdensities() read -- confirm whether this is intended.
for k, v in res.items():
    globals()[k] = v

n_mcmc = samples_MCMC.shape[0]
lls_MCMC = np.zeros(n_mcmc)
ljs_MCMC = np.zeros(n_mcmc)
for i in range(n_mcmc):
    lls_MCMC[i] = loglikelihood(samples_MCMC[i, :])
    ljs_MCMC[i] = logjointdensities(samples_MCMC[i, :])

# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
plt.hist(lls_MCMC, density=True)
# Ground-truth log-likelihood as a full-height line, independent of the y-scale.
plt.axvline(loglikelihood(pars_true), color='r')
plt.show()

In [None]:
# Re-create the simulator, then simulate one dataset per thinned MCMC sample.
g, _, d = setup_sim(seed, path='')
corrs_MCMC = np.zeros(samples_MCMC.shape[0])
for i in range(samples_MCMC.shape[0]):
    out = g.model.gen_single(samples_MCMC[i])
    # Pearson correlation between observed and simulated data.
    corrs_MCMC[i] = np.corrcoef(obs['data'], out['data'])[0, 1]

# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
plt.hist(corrs_MCMC, density=True)
plt.show()

# compare with prior

In [None]:
# Baseline: parameters drawn from the (currently bound) prior.
params_prior = prior.gen(1000)

# Size the arrays and the loop by the prior draws themselves.
n_prior = params_prior.shape[0]
lls_prior = np.zeros(n_prior)
ljs_prior = np.zeros(n_prior)
for i in range(n_prior):
    lls_prior[i] = loglikelihood(params_prior[i, :])
    ljs_prior[i] = logjointdensities(params_prior[i, :])

# Lazy numerics: replace -inf by a large negative constant so the evaluation
# focuses on the region where the numerics are stable. Each array is masked on
# its own -inf entries. Masking ljs_prior with lls_prior after lls_prior had
# been overwritten would never match anything.
lls_prior[np.isneginf(lls_prior)] = -10000
ljs_prior[np.isneginf(ljs_prior)] = -10000

# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
plt.hist(lls_prior, density=True)
# Ground-truth log-likelihood as a full-height line, independent of the y-scale.
plt.axvline(loglikelihood(pars_true), color='r')
plt.show()

In [None]:
# Re-create the simulator, then simulate one dataset per prior draw.
g, _, d = setup_sim(seed, path='')
corrs_prior = np.zeros(params_prior.shape[0])
for i in range(params_prior.shape[0]):
    out = g.model.gen_single(params_prior[i])
    # Pearson correlation between observed and simulated data.
    corrs_prior[i] = np.corrcoef(obs['data'], out['data'])[0, 1]

# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
plt.hist(corrs_prior, density=True)
plt.show()

# compare all

In [None]:
# sanity check: distribution of r(x, xo) for x ~ p(x|theta_true) - same order of magnitude?
# Sanity check: distribution of r(x, xo) for repeated draws x ~ p(x | theta_true).
# Is it of the same order of magnitude as for the inferred parameters?
corrs_px = np.array([
    np.corrcoef(g.model.gen_single(pars_true)['data'], obs['data'])[0, 1]
    for _ in range(10000)
])

plt.hist(corrs_px)
plt.show()

In [None]:
corrs_px.mean(), corrs_px.std()  # mean and spread of r(x, xo) under theta_true

In [None]:
# Overlay of data-correlation distributions (automatic bins). Each histogram is
# labelled so that plt.legend() has entries to show.
plt.figure(figsize=(8,6))
plt.hist(corrs_prior, density=True, label='prior')
plt.hist(corrs_SMC,   density=True, label='SMC')
plt.hist(corrs_SNPE,  density=True, label='SNPE')
plt.xlabel('correlations')
plt.ylabel('density')
plt.yticks([])
plt.legend()
plt.show()

In [None]:
# Overlay of log-likelihood distributions on shared bins.
ll_bins = np.linspace(-3300, -2750, 30)

plt.figure(figsize=(8,6))
plt.hist(lls_prior, density=True, bins=ll_bins, label='prior')
plt.hist(lls_SMC,   density=True, bins=ll_bins, label='SMC')
plt.hist(lls_SNPE,  density=True, bins=ll_bins, label='SNPE')
# Ground-truth log-likelihood as a full-height line, independent of the y-scale.
plt.axvline(loglikelihood(pars_true), color='k', linewidth=2, label='gt')
plt.xlabel('log-likelihoods')
plt.ylabel('density')
plt.yticks([])
plt.xticks([-3100, -3000, -2900, -2800])
plt.legend()
plt.show()

In [None]:
fontsize = 11

# Method labels and colours shared by both panels.
method_labels = ['prior', 'SMC-ABC', 'SNPE', 'MCMC']
method_colors = [(0.55,0.,0.), col['SMC'], col['SNPE'], col['MCMC']]


def _despine(axis):
    """Hide the right/top spines and keep ticks on the left/bottom only."""
    axis.spines['right'].set_visible(False)
    axis.spines['top'].set_visible(False)
    axis.yaxis.set_ticks_position('left')
    axis.xaxis.set_ticks_position('bottom')


plt.figure(figsize=(12,5))

# Left panel: correlation between observed and simulated data, per method.
# A list of arrays lets the sample sets differ in length.
ax = plt.subplot(1,2,1)
plt.hist([corrs_prior, corrs_SMC, corrs_SNPE, corrs_MCMC], density=True,
         bins=np.linspace(-0.035, 0.05, 20),
         label=method_labels,
         color=method_colors,
         histtype='bar',
         rwidth=1.0)
plt.xlabel('correlations', fontsize=fontsize)
plt.ylabel('density', fontsize=fontsize)
plt.yticks([])
plt.xticks([-0.02, 0.0, 0.02, 0.04])
plt.legend(fontsize=fontsize, frameon=False)
_despine(ax)

# Right panel: log joint densities. This histogram is NOT normalized, so the
# y-axis shows counts.
ax = plt.subplot(1,2,2)
plt.axvline(logjointdensities(pars_true), color='g', linewidth=1, label='ground-truth param.')
plt.hist([ljs_prior, ljs_SMC, ljs_SNPE, ljs_MCMC],
         bins=np.linspace(-3150, -2750, 40),
         label=method_labels,
         color=method_colors,
         histtype='bar',
         rwidth=1.0)
plt.xlabel('log joint densities ' + r'$p(x_o, \theta)$', fontsize=fontsize)
plt.ylabel('count', fontsize=fontsize)
plt.yticks([])
plt.xticks([-3100, -3000, -2900, -2800])
plt.legend(fontsize=fontsize, frameon=False)
_despine(ax)

# NOTE(review): the RF-correlation figure cell further below saves to the same
# .svg path and overwrites this file -- confirm which figure is intended.
plt.savefig("fig/fig3_gabor_supp_smc_abc_comps.svg")
plt.savefig("fig/fig3_gabor_supp_smc_abc_comps.pdf")

plt.show()



In [None]:
# Box plots of the log joint densities per method.
fig, ax = plt.subplots(figsize=(16,6))
ax.boxplot(x=(ljs_prior, ljs_SMC, ljs_SNPE), labels=['prior', 'SMC', 'SNPE'])
ax.set_ylabel('log joint densities')
plt.show()

In [None]:
np.shape(params_prior)  # (number of prior draws, number of parameters)

In [None]:
# Ground-truth receptive field, flattened once for the correlations below.
rf_true = g.model.params_to_rf(pars_true)[0]
rf_true_flat = rf_true.flatten()


def _rf_correlations(samples):
    """Return the Pearson correlation between rf_true and the RF of each row of `samples`."""
    corrs = np.zeros(samples.shape[0])
    for j, theta in enumerate(samples):
        rf_j = g.model.params_to_rf(theta)[0]
        corrs[j] = np.corrcoef(rf_true_flat, rf_j.flatten())[0, 1]
    return corrs


# `density=True` replaces `normed=True`, which matplotlib 3.1 removed from hist.
rfcs_MCMC = _rf_correlations(samples_MCMC)
plt.hist(rfcs_MCMC, density=True)
plt.show()

rfcs_SMC = _rf_correlations(params_SMC)
plt.hist(rfcs_SMC, density=True)
plt.show()

rfcs_SNPE = _rf_correlations(params_SNPE)
plt.hist(rfcs_SNPE, density=True)
plt.show()

rfcs_prior = _rf_correlations(params_prior)
plt.hist(rfcs_prior, density=True)
plt.show()

In [None]:
fontsize = 9
fig_inches = (4,4)
rf_bins = np.linspace(-0.5, 1, 30)  # shared bins for all RF-correlation histograms

with mpl.rc_context(fname='../.matplotlibrc'):
    fig = plt.figure(figsize=fig_inches)

    ax = plt.subplot(1,1,1)

    # Filled histograms for the two likelihood-free methods ...
    plt.hist(rfcs_SMC, density=True,
             bins=rf_bins,
             label=['SMC-ABC'],
             color=[col['SMC']],
             histtype='bar',
             rwidth=1.0)
    plt.hist(rfcs_SNPE, density=True,
             bins=rf_bins,
             label=['SNPE'],
             color=[col['SNPE']],
             histtype='bar',
             rwidth=1.0)
    # ... and outlines for prior and MCMC on top (rwidth is ignored for 'step').
    plt.hist([rfcs_prior, rfcs_MCMC], density=True,
             bins=rf_bins,
             label=['prior', 'MCMC'],
             color=[(0.55,0.,0.), col['MCMC']],
             histtype='step',
             lw=2)
    plt.xlabel('correlations', fontsize=fontsize)
    plt.ylabel('density', fontsize=fontsize)
    plt.yticks([])
    plt.xticks([-0.5, 0.0, 0.5, 1.])
    plt.legend(fontsize=fontsize, frameon=False, loc=2)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    # NOTE(review): same path as the earlier two-panel comparison figure, which
    # is overwritten here -- confirm this is the intended final figure.
    fig.savefig('fig/fig3_gabor_supp_smc_abc_comps.svg', transparent=True)

In [None]:
fontsize=11

method_labels = ['prior', 'SMC-ABC', 'SNPE', 'MCMC']
method_rfcs = [rfcs_prior, rfcs_SMC, rfcs_SNPE, rfcs_MCMC]
positions = np.arange(len(method_labels))

plt.figure(figsize=(3,2))
ax = plt.subplot(1,1,1)

# Mean RF correlation per method ...
plt.bar(x=positions,
        height=[rfcs.mean() for rfcs in method_rfcs],
        tick_label=method_labels,
        color=[(0.55,0.,0.), col['SMC'], col['SNPE'], col['MCMC']])
# ... with +/- one standard deviation drawn as a vertical line.
for i, rfcs in enumerate(method_rfcs):
    plt.plot(
        i*np.ones(2),
        rfcs.std()*np.array([-1,1]) + rfcs.mean(),
        color=(0.4,0.4,0.4)
    )
# Reference line at perfect correlation. (No labelled artists here, so no legend.)
plt.plot([-0.5, len(method_labels) - 0.5], [1., 1.], 'k--')
plt.ylabel('correlation', fontsize=fontsize)
plt.yticks([0, 0.5, 1.], fontsize=fontsize)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

# savefig no longer accepts `frameon` (removed in matplotlib 3.3); transparent=True
# already leaves the figure background unpainted.
plt.savefig("fig/fig3_gabor_supp_smc_abc_comps_inset.svg", bbox_inches='tight', transparent=True)
plt.show()