# SNPE & RF

Learning receptive-field (RF) parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import get_maprf_prior_01, get_data_o, quick_plot, contour_draws


# set up simulator

In [None]:
# Single seed shared by the simulator, the prior and (further down) the generator.
seed = 42

# sim_info.npy holds a pickled Python object. NumPy 1.16.3+ refuses to unpickle
# unless allow_pickle=True is passed explicitly, so without it this line raises
# ValueError. `[()]` unwraps the stored object exactly as before.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
sim_info = np.load('../results/sim_info.npy', allow_pickle=True)[()]

d, params_ls = sim_info['d'], sim_info['params_ls']

# simulator; its filter shape is (d, d, 2) and all other settings come from sim_info
m = model(filter_shape=np.array((d, d, 2)),
          parametrization=sim_info['parametrization'],
          params_ls=params_ls,
          seed=seed,
          dt=sim_info['dt'],
          duration=sim_info['duration'])

# prior over the parameters listed in params_ls
p = get_maprf_prior_01(params_ls, seed)

# summary stats (d x d RF + spike_count)
s = maprfStats(n_summary=d*d+1)
 
def rej(x):
    """Flag which rows of summary statistics contain at least one spike.

    Parameters
    ----------
    x : 2-D array, one row of summary statistics per simulation; the last
        column is read as the number of spikes.

    Returns
    -------
    Boolean array with one entry per row: True where the spike count is
    positive, False where it is zero (those are the rows to reject).
    NOTE(review): presumably dg.RejKernel keeps the True rows -- confirm.
    """
    n_spikes = x[:, -1]
    return n_spikes > 0

# Generator producing data pairs (theta_i, x_i) from prior `p`, simulator `m` and
# summary statistics `s`; pairs flagged by `rej` are discarded right at sampling.
g = dg.RejKernel(
    model=m,
    prior=p,
    summary=s,
    rej=rej,
    seed=seed,
)


# load cell, generate xo

In [None]:
## observed summary statistics and true parameters of one toy cell

# index of the toy cell to load; the file name uses idx_cell + 1
idx_cell = 1

filename = '../results/toy_cells/toy_cell_{}.npy'.format(idx_cell + 1)
obs_stats, pars_true = get_data_o(filename, g, seed)


# load pretrained network

In [None]:
# stored inference object of the pretrained network
filename = ('../results/SNPE/'
            'maprf_100k_amortized_prior01_run_1_round1_param7_nosvi_base_res.pkl')
inf = io.load(filename)

# attach the simulator built above to the loaded generator
# (the save cell at the end of this notebook sets it to None before writing)
inf.generator.model = m

# network prediction at the observed summary statistics
posterior = inf.predict(obs_stats, threshold=0.01)

# check pretrained network evaluated on xo

In [None]:
# visual check of the pretrained network's posterior against the true parameters
quick_plot(
    g,
    obs_stats,
    d,
    pars_true,
    posterior,
)

# continue training network ('refine' the fit)

In [None]:
# simulation budget: N = 10k per round
n_train = 10000

# two rounds (rough idea: 1st to refine the proposal locally, 2nd to figure if posterior is non-Gaussian)
n_rounds = 2

# some learning-schedule parameters
epochs = 100
minibatch = 50
lr_decay = 1.0

# recorded for reference only -- neither value is passed to inf.run below
n_components = 2     # two components
n_inputs_hidden = 1  # number of features passing directly to the hidden layers (number of spikes)

run_kwargs = dict(
    n_train=n_train,
    epochs=epochs,
    minibatch=minibatch,
    n_rounds=n_rounds,
    lr_decay=lr_decay,
)
log, trn_data, posteriors = inf.run(**run_kwargs)

# keep the posterior of the final round
posterior = posteriors[-1]
# copy ndim over from the first entry of posterior.xs
# (presumably the first mixture component -- confirm)
posterior.ndim = posterior.xs[0].ndim

# check retrained network

In [None]:

# same visual check as for the pretrained network, now also passing the training log
quick_plot(g, obs_stats, d, pars_true, posterior, log)


# all pairwise marginals of fitted posterior (prior `p` passed as pdf2)
param_labels = ['bias', 'log gain', 'logit phase', 'log freq',
                'logit angle', 'log ratio', 'log width']
fig, _ = plot_pdf(
    posterior,
    pdf2=p,
    lims=[-3,3],
    gt=pars_true.reshape(-1),
    figsize=(16,16),
    resolution=100,
    labels_params=param_labels,
)
fig.savefig('res.pdf')


# bunch of example posterior draw contours
contour_draws(g, obs_stats, d)


# test other cells (check for 'amortization')

- If retraining changed the network substantially, 'amortization' should be lost: the posterior should remain good only for the cell the network was refined on (the current evaluation cell).


In [None]:
# 5 Hz, SNR 10**(-15 / 10)
# a is written to the kernel gain and b to the GLM bias of every test cell;
# together they define firing rate and SNR.
a, b = 0.357703095858336, 1.54546216004078

# axis labels for the pairwise-marginal plots (identical for every cell)
param_labels = ['bias', 'log gain', 'logit phase', 'log freq',
                'logit angle', 'log ratio', 'log width']

for j in range(3):

    print('\n')
    print('cell #' + str(j+1))
    print('\n')

    if j == idx_cell:
        print(' current cell ! should be good ! ')
        print('\n')

    # Same '../results' directory the observed cell was loaded from (this line used
    # './results', unlike every other load in the notebook). allow_pickle=True is
    # required by NumPy 1.16.3+ to read a pickled object; only load trusted files.
    params_dict_test = np.load('../results/toy_cells/toy_cell_' + str(j+1) + '.npy',
                               allow_pickle=True)[()]

    # Override gain and bias BEFORE handing the dict to the simulator, so the model
    # sees them directly instead of only through the nested dicts that a shallow
    # .copy() happens to share.
    params_dict_test['kernel']['gain'] = a
    params_dict_test['glm']['bias'] = b
    m.params_dict = params_dict_test.copy()

    pars_test = m.read_params_buffer()

    # Simulate once and use the same summary statistics for both the posterior and
    # the plot below; simulating a second time would condition the posterior on
    # different data than what is displayed.
    stats = s.calc([m.gen_single(pars_test)])
    post_test = inf.predict_uncorrected(stats)

    fig, _ = plot_pdf(post_test, pdf2=p, lims=[-3,3], gt=pars_test.reshape(-1),
                      figsize=(12,12), resolution=100, labels_params=param_labels)

    # top row: RF from the test parameters, RF from post_test.xs[0].m, and the
    # d x d part of the summary statistics
    plt.figure(figsize=(6,6))
    plt.subplot(2,3,1)
    plt.imshow(m.params_to_rf(pars_test)[0], interpolation='None')
    plt.subplot(2,3,2)
    plt.imshow(m.params_to_rf(post_test.xs[0].m.reshape(-1))[0], interpolation='None')
    plt.subplot(2,3,3)
    plt.imshow(stats[0,:-1].reshape(d,d), interpolation='None')

    # bottom half (slots 9..16 of a 4x4 grid): RFs of eight draws from the posterior
    for i in range(8):
        plt.subplot(4,4,9+i)
        plt.imshow(m.params_to_rf(post_test.gen().reshape(-1))[0], interpolation='None')

    plt.show()


# save & load

In [None]:
# tbd

In [None]:
import delfi.utils.io as io

# Persist the refined fit: (log, training data, posterior), the true parameter dict
# of the observed cell, the bare network, and the full inference object.

# Make sure the inference object has an `observables` attribute before saving
# (hasattr replaces a bare `except:` that would also have hidden unrelated errors).
if not hasattr(inf, 'observables'):
    inf.observables = []

# Output files. These names were previously only present as commented-out lines, so
# the cell failed with NameError. A distinct stem is used so that the pretrained
# '..._base_res.pkl' loaded earlier is never overwritten by the refined network.
# NOTE(review): adjust the stem if a different naming scheme is wanted.
out_stem = ('../results/SNPE/'
            'maprf_100k_amortized_prior01_run_1_round1_param7_nosvi_base'
            '_refined_cell' + str(idx_cell + 1))
filename1 = out_stem + '.pkl'
filename2 = out_stem + '_res.pkl'
filename3 = out_stem + '_conf.pkl'
filename4 = out_stem + '_net_only.pkl'

# `params_dict_true` was never defined in this notebook. It is read back from the
# toy-cell file of the observed cell (the file handed to get_data_o earlier).
# NOTE(review): assumes that file holds the parameter dict, as the amortization-test
# cell treats it; get_data_o may alter values before simulating -- confirm.
params_dict_true = np.load('../results/toy_cells/toy_cell_' + str(idx_cell + 1) + '.npy',
                           allow_pickle=True)[()]

io.save_pkl((log, trn_data, posterior), filename1)
# np.save appends '.npy' when the file name does not already end with it
np.save(filename3, params_dict_true)
io.save_pkl(inf.network, filename4)

# The simulator is detached while the inference object is written and re-attached
# afterwards, so `inf` stays usable in the running kernel even if saving fails.
inf.generator.model = None
try:
    io.save(inf, filename2)
finally:
    inf.generator.model = m
