# SNPE & RF

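Learning receptive-field (RF) parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters. This notebook draws MCMC samples from the posterior over the RF parameters of a toy cell.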

In [None]:
%%capture
%matplotlib inline

import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import setup_sim, setup_sampler, quick_plot, contour_draws


# parameters for this experiment

In [None]:
# Seed used to generate xo for the selected cell. MCMC currently not seeded !
seed = 42

# Index of the toy cell to load.
idx_cell = 3

# Intended to fix the RF position to (0,0) during sampling.
# NOTE(review): this flag is not passed to setup_sampler below, which hard-codes
# fix_position=False -- confirm which behaviour is intended.
fix_position = True
# Chosen parameterization of the Gabor (affects priors !).
parametrization = 'logit_φ'

# Number of MCMC samples; also encoded in the output filename.
n_samples = 100000

savefile = ('../results/MCMC/toycell_{}/maprf_MCMC_prior01_run_1_'
            '{}samples_param7_nosvi_CDELFI.pkl').format(idx_cell, n_samples)


# set up simulator

In [None]:
# Build the generator `g` (used below via g.model / g.summary / g.prior), the prior,
# and `d`, the side length of the square RF image (see the reshape(d, d) further down).
project_root = '..'
g, prior, d = setup_sim(seed, path=project_root)

# Load the toy cell and generate the observed data xo

In [None]:
filename = '../results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
# The file holds a pickled dict (unwrapped with [()] below). numpy>=1.16.3 refuses to
# unpickle object arrays unless allow_pickle=True is passed explicitly.
# Only load trusted files this way: unpickling can execute arbitrary code.
params_dict_true = np.load(filename, allow_pickle=True)[()]

m = g.model
# NOTE(review): dict.copy() is shallow -- the nested 'kernel'/'glm' dicts are shared
# with params_dict_true, so any in-place change by the model would leak back.
m.params_dict = params_dict_true.copy()
m.rng = np.random.RandomState(seed=seed)  # seeded so that xo is reproducible

pars_true, obs = m.read_params_buffer(), m.gen_single()
obs_stats = g.summary.calc([obs])

# Ground-truth RF image for the loaded parameters.
rf = m.params_to_rf(pars_true)[0]


# set up MCMC sampling

In [None]:
# Use the parametrization from the configuration cell instead of a duplicated literal.
# NOTE(review): the configuration sets fix_position=True, but the location is left free
# here; its sampler is re-initialised from the prior below and its log-likelihood terms
# are zeroed in the next cell -- confirm this is intended.
inference, data = setup_sampler(prior, obs, d, g, params_dict=params_dict_true,
                                fix_position=False, parametrization=parametrization)

# Set the location (xo, yo) mean and standard deviation of the first sampler to the
# prior's values.
loc_sampler = inference.samplers[0]
for loc_key in ('xo', 'yo'):
    loc_sampler.mu[loc_key] = prior[loc_key]['mu'][0]
    loc_sampler.sd[loc_key] = prior[loc_key]['sigma'][0]

In [None]:
# Multiply the log-likelihood terms for the RF location by zero, presumably so that
# (xo, yo) are driven by the prior alone.
for loc_key in ('xo', 'yo'):
    inference.loglik[loc_key] *= 0

# sample RF parameters (with Poisson bias marginalized out)

In [None]:
# Draw `n_samples` MCMC samples, using the value from the configuration cell.
# It is also embedded in `savefile`, so overriding it here (previously n_samples = 1000)
# would produce a saved file whose name misreports the chain length.
# To iterate quickly, lower n_samples in the configuration cell instead.
T, L = inference.sample(n_samples)
# Re-key the traces by each variable's `.name` string, so they can be indexed as T['xo'] etc.
T = {k.name: t for k, t in T.items()}


In [None]:
# Visualise the sampled RF location: traces, marginals and joint scatter.
xo_samples, yo_samples = T['xo'], T['yo']

plt.figure(figsize=(15, 8))

plt.subplot(221)  # traces of both location parameters
plt.plot(xo_samples, label='xo')
plt.plot(yo_samples, label='yo')
plt.xlabel('MCMC iteration')
plt.legend()

# `normed` was removed from plt.hist in matplotlib 3.1; `density` is its replacement.
plt.subplot(222)
plt.hist(xo_samples, alpha=0.5, density=True)
plt.title('xo')

plt.subplot(224)
plt.hist(yo_samples, alpha=0.5, density=True)
plt.title('yo')

plt.subplot(223)  # joint distribution of (xo, yo)
plt.plot(xo_samples, yo_samples, '.k', alpha=0.1)
plt.xlabel('xo')
plt.ylabel('yo')
plt.show()


# sample Poisson bias (conditioned on the others)

In [None]:
# Sample the Poisson bias conditioned on the other RF parameters; afterwards the
# 'bias' and 'λo' traces are read from T (presumably filled in by sample_biases).
inference.sample_biases(data, T, m.dt)

plt.figure(figsize=(12,5))
for row, trace_key in enumerate(['bias', 'λo'], start=1):
    trace = T[trace_key]
    plt.subplot(2, 1, row)
    plt.plot(trace)
    print('mean: ' + str(trace.mean()) + ', var: ' + str(trace.var()))
plt.show()

# Example posterior draws (each shown side by side with the ground-truth RF)

In [None]:

# Show 12 randomly chosen posterior samples as RF images, each next to the ground truth.
n_draws = 12  # fills the 3 x 4 subplot grid
draw_rng = np.random.RandomState(seed)  # seeded so that re-runs show the same draws
draw_idx = np.sort(draw_rng.choice(T['gain'].shape[0], n_draws, replace=False))

# Ground-truth RF, computed once and before the loop shifts the model's grid.
rf_true = m.params_to_rf(pars_true)[0]

# The loop re-centres the model's evaluation grid on each sample's location.
# Remember the current grid so the model is left unchanged afterwards.
grid_orig = (m._gen.grid_x, m._gen.grid_y)

plt.figure(figsize=(16,12))
try:
    for i, t in enumerate(draw_idx, start=1):
        m._gen.grid_x, m._gen.grid_y = np.meshgrid(m.axis_x - T['xo'][t],
                                                   m.axis_y - T['yo'][t])

        ks = m._eval_ks(bias=T['bias'][t],
                        angle=T['angle'][t],
                        freq=T['freq'][t],
                        gain=T['gain'][t],
                        phase=T['phase'][t],
                        ratio=T['ratio'][t],
                        width=T['width'][t])

        plt.subplot(3, 4, i)
        # left half: sampled RF, right half: ground-truth RF
        plt.imshow(np.hstack((ks.reshape(d, d), rf_true)), interpolation='None')
        plt.title('t =' + str(t))

        print('loc:', [T['xo'][t], T['yo'][t]])
finally:
    m._gen.grid_x, m._gen.grid_y = grid_orig
plt.show()



# marginal histograms for each (transformed) parameter

In [None]:

burnin = 50  # number of initial MCMC samples discarded before summarising

# One normalised histogram per available (transformed) parameter trace,
# followed by its mean and variance.
for key in ['bias', 'λo',
            'gain', 'log_A', 'phase', 'logit_φ',
            'angle', 'logit_θ', 'freq', 'log_f',
            'ratio', 'width', 'log_γ', 'log_b',
            ]:

    if key not in T:
        continue  # not every parameterization produces every trace
    trace = T[key][burnin:]
    if trace.size == 0:
        continue  # chain shorter than the burn-in: nothing to summarise

    # A constant trace would give identical linspace edges, which plt.hist rejects.
    if trace.max() > trace.min():
        bins = np.linspace(trace.min(), trace.max(), 20)
    else:
        bins = 20
    # `normed` was removed from plt.hist in matplotlib 3.1; `density` replaces it.
    plt.hist(trace, bins=bins, alpha=0.5, density=True)
    plt.title(key)
    plt.show()
    print('mean:', trace.mean())
    print('var:', trace.var())
        

# posterior samples versus prior


## actual parameters

In [None]:
# Posterior samples of the untransformed parameters, overlaid on the prior, with the
# ground-truth values as reference.
raw_keys = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']
# one column per parameter, one row per MCMC sample
samples = np.hstack([np.atleast_2d(T[name].T).T for name in raw_keys])

true_glm = params_dict_true['glm']
true_shape = params_dict_true['kernel']['s']
true_loc = params_dict_true['kernel']['l']
pars_raw = np.array([true_glm['bias']]
                    + [true_shape[name] for name in ('gain', 'phase', 'freq', 'angle', 'ratio', 'width')]
                    + [true_loc[name] for name in ('xo', 'yo')])

plot_pdf(g.prior, lims=[-3, 3], gt=pars_raw.reshape(-1), figsize=(16, 16), resolution=100,
         samples=samples.T, ticks=True, labels_params=list(raw_keys));


## parameters in log/logit space

In [None]:
# Same comparison, but for the parameters in the sampler's log/logit space.
transformed_keys = ['bias', 'gain', 'logit_φ', 'log_f', 'logit_θ', 'log_γ', 'log_b', 'xo', 'yo']
transformed_labels = ['bias', 'gain', 'logit phase', 'log freq', 'logit angle',
                      'log ratio', 'log width', 'xo', 'yo']
samples = np.hstack([np.atleast_2d(T[name].T).T for name in transformed_keys])

plot_pdf(g.prior, lims=[-3, 3], gt=pars_true.reshape(-1), figsize=(16, 16),
         resolution=100, samples=samples.T, ticks=True,
         labels_params=transformed_labels);


# (roughly) check for mixing of the chain

In [None]:
# Overlay the trace of every transformed parameter from the previous cell; slow drifts
# or long flat stretches are a rough sign of poor mixing.
fig, ax = plt.subplots(figsize=(16, 5))
ax.plot(samples)
plt.show()

# save results

In [None]:
# `savefile` ends in '.pkl', but np.savez silently appends '.npz' (writing '<name>.pkl.npz').
# It also stores the positional dict as a pickled object array under 'arr_0'.
# Write a real pickle to the advertised path instead, creating the directory if needed.
os.makedirs(os.path.dirname(savefile), exist_ok=True)
with open(savefile, 'wb') as f:
    pickle.dump({'T': T, 'params_dict_true': params_dict_true}, f,
                protocol=pickle.HIGHEST_PROTOCOL)
