# SNPE & RF

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear (LN) neuron models with parameterized linear filters. This notebook fits a single toy cell with ABC methods: PMC (via `abcpmc`) and SMC.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import setup_sim, setup_sampler, quick_plot, contour_draws


# parameters for this experiment

In [None]:
# Seed for generating xo for the selected cell. MCMC currently not seeded !
seed = 42

# Which toy cell to load (file toy_cell_<idx_cell>.npy).
idx_cell = 6

maxsim = int(1e5)         # simulation budget handed to SMC
n_particles = int(1e3)    # population size used by both PMC and SMC

# Output path prefix shared by all results of this run.
savefile = ''.join(['../results/MCMC/toycell_', str(idx_cell),
                    '/maprf_PMC_prior01_run_1_', str(n_particles), 'particles_param9'])
savefile


# Load the toy cell and generate the observed data xo

In [None]:
# Simulator setup: generator `g` (model + summary), prior, and RF side length `d`.
g, prior, d = setup_sim(seed, path='..')

# Ground-truth parameters of the selected toy cell. The .npy file holds a Python
# object (presumably the dict written by np.save), which numpy >= 1.16.3 refuses
# to load unless allow_pickle=True. Unpickling can execute code, so only load
# result files you produced yourself.
filename = '../results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
params_dict_true = np.load(filename, allow_pickle=True)[()]

m = g.model
m.params_dict = params_dict_true.copy()
m.rng = np.random.RandomState(seed=seed)  # makes the simulation of xo reproducible

# True parameter vector, the observed data xo and its summary statistics.
pars_true, obs = m.read_params_buffer(), m.gen_single()
obs_stats = g.summary.calc([obs])

rf = m.params_to_rf(pars_true)[0]

# Left: pixel summary statistics of xo; right: the true receptive field.
plt.imshow(np.hstack((obs_stats[0, :-1].reshape(d, d), rf)), interpolation='None')
plt.show()

print('spike count', obs_stats[0, -1])


# Define the distance function (normalization estimated from pilot runs)

In [None]:
from lfimodels.abc_methods.run_abc import run_smc

# Pilot runs from the prior: `gts` holds the sampled parameters, `pilots` the
# corresponding summary statistics (last column = spike count).
gts, pilots, _ = g.gen(10000)
n_stats = pilots.shape[1]

# Start from the identity normalization (mean 0, std 1 for every statistic).
stats_mean = np.zeros((1, n_stats))
stats_std = np.ones((1, n_stats))

# Pixel statistics share one scale (the std over all pilot pixel values) and are
# not centred; with this choice the z-scored spike count below weighs roughly
# like a single pixel in the distance.
stats_std[:, :-1] = pilots[:, :-1].std()

# Spike count: z-score against the pilot runs.
stats_mean[:, -1] = pilots[:, -1].mean()
stats_std[:, -1] = pilots[:, -1].std()

# Alternatives tried earlier (not active):
# - z-score every statistic, then reset the pixels to the identity:
#     stats_mean, stats_std = pilots.mean(axis=0).reshape(1,-1), pilots.std(axis=0).reshape(1,-1)
#     stats_mean[:,:-1] = 0; stats_std[:,:-1] = 1
# - spike count ~50% of total variance (normalizes with std across *all* pixels):
#     stats_std[:,:-1] = pilots[:, :-1].std()*np.sqrt(pilots.shape[1]-1)
# - spike count ~10% of total variance:
#     stats_std[:,:-1] = 1./3*pilots[:, :-1].std()*np.sqrt(pilots.shape[1]-1)
# - rescale the spike-count stat to contribute ~50% of the distance on average:
#     stats_std[:,-1] *= 1/np.sqrt(d/2)

class normed_summary(object):
    """Summary-statistics wrapper that standardizes another summary's output.

    ``calc(y)`` returns ``(summary.calc(y) - mean) / std``.

    Any constructor argument left as ``None`` falls back to the notebook globals
    ``g.summary``, ``stats_mean`` and ``stats_std``. These are looked up at call
    time, which is the original behaviour. Passing them explicitly pins the
    normalization, so later re-assignment of those globals cannot silently
    change it.

    Parameters
    ----------
    summary : object with a ``calc(y)`` method, optional
        Raw summary statistics; defaults to ``g.summary``.
    mean, std : array-like, optional
        Shift and scale, broadcast against ``summary.calc(y)``; default to
        ``stats_mean`` / ``stats_std``.
    """

    def __init__(self, summary=None, mean=None, std=None):
        self.summary = summary
        self.mean = mean
        self.std = std

    def calc(self, y):
        """Return standardized summary statistics for the simulator outputs ``y``."""
        # Globals are only touched when no explicit value was supplied.
        summary = g.summary if self.summary is None else self.summary
        mean = stats_mean if self.mean is None else self.mean
        std = stats_std if self.std is None else self.std
        return (summary.calc(y) - mean) / std

# Standardize xo exactly like `normed_summary`: the (1, D) stats arrays broadcast
# against the flattened observation.
obs_flat = obs_stats.ravel()
obs_statz = (obs_flat - stats_mean) / stats_std


# unnormalized L2 distance

In [None]:
# Unnormalized distance: compares raw summary statistics without any rescaling.

def calc_dist(stats_1, stats_2):
    """Euclidean (L2) distance between two summary-statistic vectors.

    Inputs may have different but broadcast-compatible shapes (e.g. one row of
    stats against the 2-D `obs_stats`); every entry of the broadcast
    difference contributes to the sum.
    """
    diff = stats_1 - stats_2
    return np.sqrt(np.sum(diff * diff))

# Distances to xo under the unnormalized distance for
#   (a) 1000 fresh simulations at the true parameters, and
#   (b) the pilot runs drawn from the prior.
y_true = g.model.gen([pars_true for _ in range(1000)])
x_true = [g.summary.calc(y) for y in y_true]
stats_true = np.vstack(x_true)

dists_true = np.array([calc_dist(s, obs_stats) for s in stats_true])
print('min distance, simulations at true params:', np.min(dists_true))

dists = np.array([calc_dist(s, obs_stats) for s in pilots])
print('min distance, pilot runs from prior:', np.min(dists))

# Histograms, used to pick an initial epsilon (e.g. roughly the median distance).
# `normed=` was removed in matplotlib 3.1; `density=` replaces it (>= 2.1).
# The bin range is extended by one bin width so the largest distances are not
# silently dropped from the histogram.
bin_width = 10
for values, label in ((dists_true, 'true parameters'), (dists, 'pilot runs (prior)')):
    plt.hist(values, bins=np.arange(0, np.max(values) + bin_width, bin_width),
             density=True, alpha=0.5, label=label)

plt.xlim([0, 500])
plt.xlabel('unnormalized L2 distance to xo')
plt.legend()
plt.show()

# normalized L2 distance

In [None]:
# Normalized distance: each statistic is divided by its pilot-based scale `stats_std`.

def calc_dist(stats_1, stats_2):
    """Euclidean (L2) distance after scaling every statistic by the global `stats_std`."""
    scaled_diff = (stats_1 - stats_2) / stats_std
    return np.sqrt(np.sum(scaled_diff * scaled_diff))

# Same comparison as with the unnormalized distance, now using the normalized one:
#   (a) 1000 fresh simulations at the true parameters vs xo,
#   (b) pilot runs from the prior vs xo (`dists` is reused by later cells).
y_true = g.model.gen([pars_true for _ in range(1000)])
x_true = [g.summary.calc(y) for y in y_true]
stats_true = np.vstack(x_true)

dists_true = np.array([calc_dist(s, obs_stats) for s in stats_true])
print('min distance, simulations at true params:', np.min(dists_true))

dists = np.array([calc_dist(s, obs_stats) for s in pilots])
print('min distance, pilot runs from prior:', np.min(dists))

# Histograms, used to pick an initial epsilon (e.g. roughly the median distance).
# `normed=` was removed in matplotlib 3.1; `density=` replaces it (>= 2.1).
# The bin range is extended by one bin width so the largest distances are kept.
bin_width = 0.01
for values, label in ((dists_true, 'true parameters'), (dists, 'pilot runs (prior)')):
    plt.hist(values, bins=np.arange(0, np.max(values) + bin_width, bin_width),
             density=True, alpha=0.5, label=label)

plt.xlabel('normalized L2 distance to xo')
plt.legend()
plt.show()

# Check that variances are comparable between simulated samples and obs_stats

In [None]:
# Left: pixel statistics of one simulation at the true parameters; right: the same for xo.
pixel_panels = [stats[0, :-1].reshape(d, d) for stats in (stats_true, obs_stats)]
plt.imshow(np.hstack(pixel_panels), interpolation='None')
plt.show()

# Visualize the 10 pilot summary stats closest to xo under the chosen distance function

In [None]:

# Show the pilot runs closest to xo, ranked by the per-pilot distances in `dists`.
# Each panel shows the standardized pixel statistics, titled with the spike count
# (last statistic). The contour of the receptive field that generated the panel
# is overlaid in red.
n_closest = 10     # fills the 2 x 5 subplot grid below
lvls = [0.2, 0.2]  # contour levels as fractions of the RF's min and max
closest = np.argsort(dists)[:n_closest]  # sort once instead of once per panel

plt.figure(figsize=(15, 5))
for i, idx in enumerate(closest):
    plt.subplot(2, 5, i + 1)
    print(dists[idx])

    x = (pilots[idx, :] - stats_mean) / stats_std
    plt.imshow(x[0, :-1].reshape(d, d), interpolation='None', cmap='gray')
    plt.title(pilots[idx, -1])
    rfm = g.model.params_to_rf(gts[idx, :].reshape(-1))[0]
    plt.contour(rfm, levels=[lvls[0] * rfm.min(), lvls[1] * rfm.max()], colors='r')
    plt.axis('off')

plt.subplots_adjust(wspace=0.2, hspace=0.1, left=0.1, bottom=0.12)
plt.show()


# run PMC

In [None]:
import abcpmc

# Tolerance schedule for PMC. Positional args presumably (T, eps_max, eps_min):
# 21 generations going linearly from 5.0 down to 1.0 — confirm in the abcpmc docs.
pmc_eps_args = (21, 5.0, 1.0)
eps = abcpmc.LinearEps(*pmc_eps_args)
def calc_dist(stats_1, stats_2):
    """Scaled Euclidean distance used by the PMC sampler: ||(stats_1 - stats_2) / stats_std||_2."""
    z = (stats_1 - stats_2) / stats_std
    return np.sqrt(np.square(z).sum())

# abcpmc needs its own prior object, built here from the mean and covariance of
# the delfi prior. It is named `pmc_prior` so that it does not overwrite the delfi
# `prior` returned by setup_sim (which would leave hidden state behind).
pmc_prior = abcpmc.GaussianPrior(mu=g.prior.m, sigma=g.prior.S)


def postfn(theta):
    """Simulate once at parameters `theta` and return the summary statistics."""
    return g.summary.calc([g.model.gen_single(theta.flatten())])


sampler = abcpmc.Sampler(N=n_particles, Y=obs_stats, postfn=postfn, dist=calc_dist)

# Each pool is one PMC generation: log the per-parameter mean ± std and
# checkpoint the population to `savefile` + 't<generation>'.
for pool in sampler.sample(pmc_prior, eps):
    print("T: {0}, eps: {1:>.4f}, ratio: {2:>.4f}".format(pool.t, pool.eps, pool.ratio))
    thetas_mean = np.mean(pool.thetas, axis=0)
    thetas_std = np.std(pool.thetas, axis=0)
    for i, (mean, std) in enumerate(zip(thetas_mean, thetas_std)):
        print(u"    theta[{0}]: {1:>.4f} \u00B1 {2:>.4f}".format(i, mean, std))
    np.save(savefile + 't' + str(pool.t), {'t': pool.t,
                                         'eps': pool.eps,
                                         # pool.ratio is presumably the acceptance ratio,
                                         # so this estimates the simulations used.
                                         'n_samples_iter': n_particles / pool.ratio,
                                         'params': pool.thetas,
                                         'dists': pool.dists
                                         })


# run SMC

In [None]:
# Seed for the SMC run only. This is kept separate from `seed` (used to generate xo)
# so that re-running the data-generation cell later still reproduces the same xo.
smc_seed = 90

# Initial epsilon: the pilot distance closest to the median pilot distance
# (`dists` as left by the distance cells above).
eps_init = dists[np.argmin(np.abs(dists - np.median(dists)))]
print(eps_init)

all_ps, all_xs, all_logweights, all_eps, all_nsims = run_smc(model=g.model, prior=g.prior,
                                                             summary=normed_summary(),
                                                             obs_stats=obs_statz,
                                                             seed=smc_seed, fn=savefile,
                                                             n_particles=n_particles,
                                                             eps_init=eps_init, maxsim=maxsim)

In [None]:
# Gaussian fit to the final SMC population. The particles carry importance weights
# (`all_logweights`), so weighted moments are used. For uniform weights this is
# identical to the plain mean / np.cov.
# NOTE(review): assumes all_logweights[-1] holds one (possibly unnormalized)
# log-weight per row of all_ps[-1] — confirm against run_smc.
final_ps = np.asarray(all_ps[-1])
final_logw = np.asarray(all_logweights[-1], dtype=float).ravel()
final_w = np.exp(final_logw - final_logw.max())  # subtract max for numerical stability
final_w /= final_w.sum()

posterior = dd.Gaussian(m=np.average(final_ps, axis=0, weights=final_w),
                        S=np.cov(final_ps.T, aweights=final_w))
plot_pdf(posterior, pdf2=g.prior, lims=[-2, 2], samples=final_ps.T, figsize=(16, 16));

In [None]:
# Bundle the run configuration and SMC results into one dict and write it to
# savefile + '.npz'. np.savez stores a positional argument under the key 'arr_0'
# as a pickled object array. Read it back with
#     np.load(savefile + '.npz', allow_pickle=True)['arr_0'][()]
smc_results = {
    'eps_init': eps_init,
    'obs_statz': obs_statz,
    'obs_stats': obs_stats,
    'n_particles': n_particles,
    'maxsim': maxsim,
    'stats_mean': stats_mean,
    'stats_std': stats_std,
    'all_ps': all_ps,
    'all_logweights': all_logweights,
    'all_eps': all_eps,
    'all_nsims': all_nsims,
    'params_dict_true': params_dict_true,
}
np.savez(savefile, smc_results)
