# SNPE & RF

Learning receptive-field (RF) parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear (LN) neuron models with parameterized linear filters. Here the parameters are inferred with SMC-ABC.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import setup_sim, setup_sampler, quick_plot, contour_draws


# Parameters for this experiment

In [None]:
# Experiment settings.
# Seed used to generate xo for the selected cell; the SMC cell further down
# sets its own seed. NOTE(review): the original note said "MCMC currently not
# seeded" -- confirm whether every random step inside run_smc is seeded.
seed = 42

idx_cell = 3  # which toy cell to load (toy_cell_<idx_cell>.npy)

maxsim = 100000      # passed to run_smc as maxsim (same value as int(1e5))
n_particles = 1000   # passed to run_smc as n_particles (same as int(1e3))

savefile = ('../results/MCMC/toycell_' + str(idx_cell)
            + '/maprf_SMC_prior01_run_1_' + str(maxsim)
            + 'samples_' + str(n_particles)
            + 'particles_param9.pkl')
savefile

# Load the toy cell and generate the observation xo

In [None]:
# Build the simulator and prior, load the ground-truth parameters of toy cell
# `idx_cell`, and simulate the observation xo from them.
g, prior, d = setup_sim(seed, path='..')

filename = '../results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
# The .npy file holds a (pickled) parameter dict in a 0-d object array; [()]
# extracts it. Since NumPy 1.16.3, np.load refuses object arrays unless
# allow_pickle=True is passed. Only load trusted files this way: unpickling
# can execute arbitrary code.
params_dict_true = np.load(filename, allow_pickle=True)[()]

m = g.model
m.params_dict = params_dict_true.copy()
# Seed the model RNG before simulating; presumably gen_single draws from
# m.rng -- confirm in the model implementation.
m.rng = np.random.RandomState(seed=seed)

pars_true, obs = m.read_params_buffer(), m.gen_single()
obs_stats = g.summary.calc([obs])

rf = g.model.params_to_rf(pars_true)[0]

# Left: the first d*d observed summary stats as an image; right: the true RF.
plt.imshow(np.hstack((obs_stats[0,:-1].reshape(d,d), rf)), interpolation='None')
plt.show()

print('spike count', obs_stats[0,-1])

# Define the distance function (normalization based on pilot runs)

In [None]:
from lfimodels.abc_methods.run_abc import run_smc

# 1000 pilot simulations, used only to calibrate the distance normalization.
gts, pilots, _ = g.gen(1000)
stats_mean = pilots.mean(axis=0).reshape(1, -1)
stats_std = pilots.std(axis=0).reshape(1, -1)

# Every statistic except the last keeps its raw scale (zero shift, unit scale).
stats_mean[:, :-1] = 0
stats_std[:, :-1] = 1

# The last statistic (the spike count) is standardized against the pilots, and
# its scale is then shrunk by sqrt(d). Per the original author, this is meant
# to make the firing-rate statistic contribute about 50% of the distance on average.
stats_std[:, -1] *= 1 / np.sqrt(d)

class normed_summary:
    """Summary-statistics wrapper that standardizes ``g.summary``'s output.

    Each statistic is shifted by the module-level ``stats_mean`` and divided
    by ``stats_std`` (both derived from the pilot runs). This way SMC-ABC
    compares simulated summaries on the same scale as ``obs_statz``.
    """

    def calc(self, y):
        """Return ``(g.summary.calc(y) - stats_mean) / stats_std``."""
        return (g.summary.calc(y) - stats_mean) / stats_std

# Observed summary stats, normalized with the same stats_mean / stats_std as
# normed_summary; broadcasting against the (1, n) arrays yields shape (1, n).
obs_statz = (obs_stats.ravel() - stats_mean) / stats_std
    


In [None]:
# Reproduce the distance function as used internally by the SMC-ABC implementation.
def calc_dist(stats_1, stats_2):
    """Return the Euclidean distance between two summary-statistic arrays.

    All elements of ``stats_1 - stats_2`` (after broadcasting) are pooled.
    For 2-D inputs this is therefore the Frobenius norm of the difference.
    """
    diff = stats_1 - stats_2
    return np.sqrt(np.sum(np.square(diff)))


# Distance of every pilot run to xo. Each pilot's summary stats get the same
# stats_mean / stats_std normalization that normed_summary applies.
dists = np.empty(pilots.shape[0])
for i, pilot in enumerate(pilots):
    dists[i] = calc_dist((pilot - stats_mean) / stats_std, obs_statz)

# Show the distance histogram. Use it to pick the initial epsilon (e.g. roughly
# the median distance). `density` replaces `normed`, which Matplotlib
# deprecated in 2.1 and removed in 3.1.
plt.hist(dists, bins=50, density=True)
plt.show()

In [None]:
# Visualize the 10 pilot summary stats closest to xo under the chosen distance.
# For each pilot, print two lines:
#   1. raw last stat of pilot and of xo, then their normalized counterparts;
#   2. squared-distance contribution of all stats except the last, that of the
#      last stat, and the total distance.
closest = np.argsort(dists)[:10]  # sort once instead of on every iteration
for idx in closest:
    x = (pilots[idx, :] - stats_mean) / stats_std
    plt.imshow(x[0, :-1].reshape(d, d), interpolation='None')
    plt.show()
    print(pilots[idx, -1], obs_stats[0, -1], x[0, -1], obs_statz[0, -1])
    print(np.sum((x[0, :-1] - obs_statz[0, :-1]) ** 2),
          np.sum((x[0, -1] - obs_statz[0, -1]) ** 2),
          dists[idx])

# Run SMC-ABC

In [None]:
# SMC-ABC run. The SMC seed gets its own name so that `seed` (used above to
# generate xo) is not overwritten. Otherwise, re-running the earlier cells
# after this one would silently regenerate xo with the SMC seed.
seed_smc = 90
eps_init = 10.0  # initial epsilon; the pilot distance histogram is meant to guide this

all_ps, all_logweights, all_eps, all_nsims = run_smc(
    model=g.model, prior=g.prior, summary=normed_summary(),
    obs_stats=obs_statz,
    seed=seed_smc, fn=savefile,
    n_particles=n_particles, eps_init=eps_init, maxsim=maxsim)

In [None]:
# Store all SMC outputs together with the ground-truth parameters.
# NOTE(review): np.savez stores a positional argument under the key 'arr_0'
# (as a pickled object array) and appends '.npz' to the file name, so this
# writes `savefile + '.npz'`. Read it back with
# np.load(savefile + '.npz', allow_pickle=True)['arr_0'][()].
np.savez(savefile, dict(all_ps=all_ps,
                        all_logweights=all_logweights,
                        all_eps=all_eps,
                        all_nsims=all_nsims,
                        params_dict_true=params_dict_true))
