# SNPE & RF

Learning receptive-field (RF) parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import os

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import setup_sim, get_data_o, quick_plot, contour_draws


# parameters for this experiment

In [None]:
# random seed passed to the simulator setup and to the observed-data loader
seed = 42

# toy cell to fit; the data file on disk is numbered idx_cell + 1
idx_cell = 2

# training-set size: N = 10k simulations per round
n_train = 10000

# two components (handed to inf.run as n_components)
n_components = 2

# extra training rounds
# (rough idea: the 1st refines the proposal locally, a 2nd would show
#  whether the posterior is non-Gaussian)
n_rounds = 1

# number of features passed directly to the hidden layers (number of spikes)
n_inputs_hidden = 1

# learning-schedule parameters
epochs = 100
minibatch = 50
lr_decay = 1.0

# output directory for this cell and the result files written into it
path = '../results/SNPE/toycell_' + str(idx_cell)

filename1 = path + '/maprf_2x10k_refined_prior01_run_1_round2_param7_nosvi_CDELFI.pkl'
filename2 = path + '/maprf_100k_amortized_prior01_run_1_round1_param7_nosvi_base_res.pkl'

# set up simulator

In [None]:
# Build the simulation objects from the experiment seed.
#   g     : generator; its .model, .summary and .prior are used in later cells
#   prior : presumably the parameter prior (not used again in this notebook) - TODO confirm
#   d     : side length used later to reshape summary statistics into a d x d image
g, prior, d = setup_sim(seed, path='..')

# load cell, generate xo

In [None]:
# Data file of the selected toy cell (file number is idx_cell + 1).
filename = (
    '../results/toy_cells/toy_cell_'
    + str(idx_cell + 1)
    + '.npy'
)

# obs_stats: summary statistics of the observation xo
# pars_true: ground-truth parameters of the cell (used as `gt` in the plots)
obs_stats, pars_true = get_data_o(filename, g, seed)

# load pretrained network

In [None]:
# Pretrained (amortized) inference object stored on disk.
# NOTE(review): io.load presumably unpickles this file - only load trusted files.
filename = '../results/SNPE/maprf_100k_amortized_prior01_run_1_round1_param7_nosvi_base_res.pkl'
inf = io.load(filename)

# Re-attach the simulator model to the loaded generator
# (the save cell notes that the model cannot be pickled).
inf.generator.model = g.model

# Prediction of the pretrained network at xo, plotted in the next cell
# as a check of the proposal.
proposal = inf.predict(obs_stats)

# Swap out xo: point the inference object at the current cell's statistics.
inf.obs = obs_stats

# check pretrained network evaluated on xo

In [None]:
# Plot the pretrained network's prediction (`proposal`) for xo together with
# the ground-truth parameters, before any retraining.
quick_plot(g, obs_stats, d, pars_true, proposal)

# continue training network ('refine' the fit)

In [None]:
# Continue training the loaded network for xo ('refine' the fit).
# Yields the training log, the training data and the posteriors
# (the last entry of `posteriors` is inspected in the next cell).
log, trn_data, posteriors = inf.run(
    n_train=n_train,
    n_components=n_components,
    epochs=epochs,
    minibatch=minibatch,
    n_rounds=n_rounds,
    lr_decay=lr_decay,
)

# check retrained network

In [None]:
# Posterior estimate after the last training round.
posterior = posteriors[-1]
# NOTE(review): ndim is copied over from the first component, presumably
# because the plotting helpers below expect it on the posterior - confirm.
posterior.ndim = posterior.xs[0].ndim

# Same overview plot as for the pretrained network, now including the training log.
quick_plot(g, obs_stats, d, pars_true, posterior, log)

# All pairwise marginals of the fitted posterior, with the prior as second
# density and the true parameters as ground truth; saved to res.pdf.
fig, _ = plot_pdf(
    posterior,
    pdf2=g.prior,
    lims=[-3,3],
    gt=pars_true.reshape(-1),
    figsize=(16,16),
    resolution=100,
    labels_params=['bias', 'log gain', 'logit phase', 'log freq',
                   'logit angle', 'log ratio', 'log width'],
)
fig.savefig('res.pdf')

# A bunch of example posterior draw contours.
contour_draws(posterior, g, obs_stats, d)

# test other cells (check for 'amortization')

- If retraining changed the network much, 'amortization' should be lost for all cells except the one currently being evaluated.


In [None]:
# gain (a) and bias (b) together define firing rate and SNR
a,b = 0.357703095858336, 1.54546216004078  # 5 Hz, SNR 10**(-15 / 10)

# parameter labels for the pairwise-marginal plots (loop-invariant, built once)
labels_params = ['bias', 'log gain', 'logit phase', 'log freq',
                 'logit angle', 'log ratio', 'log width']

m = g.model
for j in range(3):

    print('\n')
    print('cell #' + str(j+1))
    print('\n')

    # The .npy file holds a pickled dict inside a 0-d object array:
    #  - allow_pickle=True is required, NumPy >= 1.16.3 refuses pickled
    #    data by default (only load trusted files this way);
    #  - [()] unwraps the dict from the 0-d array.
    params_dict_test = np.load('../results/toy_cells/toy_cell_' + str(j+1) + '.npy',
                               allow_pickle=True)[()]

    # NOTE(review): .copy() is shallow, so the nested 'kernel' / 'glm' dicts are
    # shared and the gain / bias overrides below presumably also reach
    # m.params_dict - confirm this is intended before switching to a deep copy.
    m.params_dict = params_dict_test.copy()
    params_dict_test['kernel']['gain'] = a
    params_dict_test['glm']['bias'] = b

    # simulate the test cell and compute its summary statistics
    pars_test = m.read_params_buffer()
    stats = g.summary.calc([m.gen_single(pars_test)])

    # posterior estimate of the (retrained) network for this cell
    post_test = inf.predict_uncorrected(stats)

    fig, _ = plot_pdf(post_test, pdf2=g.prior, lims=[-3,3], gt=pars_test.reshape(-1), figsize=(12,12), resolution=100,
                  labels_params=labels_params)

    # top row: RF from the test parameters, RF from the mean of the first
    # posterior component, and the summary statistics (all but the last
    # entry) reshaped to a d x d image
    plt.figure(figsize=(6,6))
    plt.subplot(2,3,1)
    plt.imshow(m.params_to_rf(pars_test)[0], interpolation='None')
    plt.subplot(2,3,2)
    plt.imshow(m.params_to_rf(post_test.xs[0].m.reshape(-1))[0], interpolation='None')
    plt.subplot(2,3,3)
    plt.imshow(stats[0,:-1].reshape(d,d), interpolation='None')

    # bottom two rows of a 4x4 grid: RFs of eight draws from the posterior
    for i in range(8):
        plt.subplot(4,4,9+i)
        plt.imshow(m.params_to_rf(post_test.gen().reshape(-1))[0], interpolation='None')

    plt.show()

# save results

In [None]:
# store output
# Create the results directory first so that the output of the (expensive)
# training run is not lost to a missing-directory error on save.
out_dir = os.path.dirname(filename1)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
io.save_pkl((log, trn_data, posteriors),filename1)
#inf.generator.model = None # model cannot be pickled atm
#io.save(inf, filename2)