# SNPE & RF

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import get_maprf_prior_01, get_data_o, quick_plot, contour_draws


# parameters for this experiment

In [None]:
# Random seed shared by simulator setup, data generation and training.
seed = 42

# Index of the toy cell used for validation: the file loaded is
# toy_cell_<idx_cell + 1>.npy, so idx_cell = 2 selects toy cell #3.
# In this 'amortized inference' experiment the cell does not enter training;
# it only serves to validate the results.
idx_cell = 2

# --- network architecture -------------------------------------------------
# 8 layers: 4 conv ReLU layers, 3 fully connected layers and 1 mixture-of-
# Gaussians output layer (~20k parameters in total).
filter_sizes = [3, 3, 3, 2]      # kernel size of each of the 4 conv layers
n_filters = (16, 16, 32, 32)     # 16 filters in the first two conv layers, 32 in the last two
pool_sizes = [1, 2, 2, 1]        # pooling per conv layer (1 presumably means no pooling)
n_hiddens = [50, 50, 50]         # units in each of the 3 fully connected layers

# Simulations per round (N = 100k).
n_train = 100000

# A single mixture component: the posterior for most STAs is well approximated
# by a single Gaussian, and we also want to run more SNPE-A.
n_components = 1

# A single round. The first round is always 'amortized' and can be applied to
# any other STA covered by the prior.
n_rounds = 1

# Number of extra inputs passed straight to the fully connected layers,
# bypassing the conv layers. Here it is the number of spikes (a single number),
# which lets the network normalize the STAs and thus helps the conv layers.
# Without this extra input the RF gain could no longer be recovered.
n_inputs_hidden = 1

# --- learning schedule ----------------------------------------------------
lr_decay = 0.99
epochs = 50
minibatch = 50

# --- regularization -------------------------------------------------------
svi = False         # with N this large, SVI should make no difference anyway
reg_lambda = 0.     # ... and this makes doubly sure SVI is switched off

# --- normalization --------------------------------------------------------
# z-scoring based on the pilot samples only applies to the extra input (here:
# firing rate) fed directly to the fully connected layers.
pilot_samples = 1000
prior_norm = True   # doesn't hurt
init_norm = False   # no good way found yet to normalize initialization through conv and ReLU layers

# --- output files ---------------------------------------------------------
path = '../results/SNPE'
base_name = path + '/maprf_100k_amortized_prior01_run_1_round1_param7_nosvi_base'
filename1 = base_name + '.pkl'       # receives (log, trn_data, posteriors)
filename2 = base_name + '_res.pkl'   # receives the inference object


# set up simulator

In [None]:
# `setup_sim` is not among the notebook's top-level imports, so calling it here
# raised a NameError. Bring it into scope from the same helper module as
# get_data_o / quick_plot / contour_draws (assumed location -- confirm).
from lfimodels.maprf.utils import setup_sim

# Build the simulator. `g` is the delfi generator and `d` is the side length of
# the square (d x d) network input; the second return value is not used here.
g, _, d = setup_sim(seed, path='..')

# load cell, generate xo
- obs_stats & pars_true are not used during training. They are only used afterwards to check the results anecdotally.

In [None]:
# Toy-cell files are numbered idx_cell + 1, e.g. idx_cell = 2 -> toy_cell_3.npy.
filename = '../results/toy_cells/toy_cell_{}.npy'.format(idx_cell + 1)
# Observed summary statistics and ground-truth parameters of that cell.
obs_stats, pars_true = get_data_o(filename, g, seed)

# train network over full prior

In [None]:
# CDELFI inference object (run as SNPE-A below) with a convolutional
# mixture-density network over (1, d, d) inputs plus n_inputs_hidden extra
# inputs that bypass the conv layers.
inf = infer.CDELFI(
    generator=g,
    obs=obs_stats,
    seed=seed,
    # input normalization / initialization
    prior_norm=prior_norm,
    init_norm=init_norm,
    pilot_samples=pilot_samples,
    # regularization (SVI switched off)
    svi=svi,
    reg_lambda=reg_lambda,
    # network architecture
    n_components=n_components,
    n_inputs=(1, d, d),
    n_filters=n_filters,
    filter_sizes=filter_sizes,
    pool_sizes=pool_sizes,
    n_hiddens=n_hiddens,
    n_inputs_hidden=n_inputs_hidden,
)

# Print the number of parameters per layer (weights only, not biases).
# Per the original note, the weight arrays sit at the odd positions
# 1, 3, ..., 15 of `inf.network.aps`, one per layer of the 8-layer network.
def get_shape(i):
    """Return the shape of entry ``i`` of ``inf.network.aps``."""
    return inf.network.aps[i].get_value().shape

weight_shapes = [get_shape(idx) for idx in range(1, 17, 2)]
print(weight_shapes)
print([np.prod(shape) for shape in weight_shapes])

# Run SNPE-A (a single round here). Returns the training log, the training
# data and a list of posteriors; the last entry is used below.
log, trn_data, posteriors = inf.run(
    n_train=n_train,
    n_rounds=n_rounds,
    epochs=epochs,
    minibatch=minibatch,
    lr_decay=lr_decay,
)

# check results
- Quickly check the results for a single toy cell to make sure the network has learned a sensible rough picture.

In [None]:
# Posterior from the final round.
posterior = posteriors[-1]
# NOTE(review): presumably the plotting helpers need `posterior.ndim`; it is
# taken from the first mixture component's dimensionality.
posterior.ndim = posterior.xs[0].ndim

quick_plot(g, obs_stats, d, pars_true, posterior, log)

# All pairwise marginals of the fitted posterior, with the prior overlaid and
# the ground-truth parameters marked; saved as res.pdf in the working directory.
param_labels = ['bias', 'log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width']
fig, _ = plot_pdf(
    posterior,
    pdf2=g.prior,
    lims=[-3, 3],
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=param_labels,
)
fig.savefig('res.pdf')

# Contours of a bunch of example posterior draws.
contour_draws(posterior, g, obs_stats, d)

# test other cells (check for 'amortization')

In [None]:
# Apply the network trained over the full prior to other toy cells, without
# retraining, to check that the inference is really 'amortized'.

# Simulator pieces from the generator built earlier. Previously `m` and `s`
# (and the prior `p`) were used below without ever being defined.
m = g.model      # model: params_dict / read_params_buffer / gen_single / params_to_rf
s = g.summary    # summary statistics: calc

# Gain and bias together define firing rate and SNR. Every test cell gets the
# same values, so only its receptive-field shape differs.
a, b = 0.357703095858336, 1.54546216004078  # 5 Hz, SNR 10**(-15 / 10)

n_test_cells = 3  # toy cells 1..n_test_cells are checked
labels_params = ['bias', 'log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width']

for j in range(n_test_cells):

    print('\n')
    print('cell #' + str(j+1))
    print('\n')

    # The .npy file holds a pickled dict stored as a 0-d object array (hence
    # `[()]`). numpy >= 1.16.3 refuses to unpickle unless allow_pickle=True.
    # Only load trusted, locally produced files this way.
    params_dict_test = np.load('../results/toy_cells/toy_cell_' + str(j+1) + '.npy',
                               allow_pickle=True)[()]

    # Override gain and bias *before* handing the dict to the model. The old
    # order only worked because the shallow .copy() shared the nested dicts.
    params_dict_test['kernel']['gain'] = a
    params_dict_test['glm']['bias'] = b
    m.params_dict = params_dict_test.copy()

    pars_test = m.read_params_buffer()

    # Simulate once and condition the posterior on exactly the data plotted
    # below. Previously a second, independent simulation was fed to the network.
    stats = s.calc([m.gen_single(pars_test)])
    post_test = inf.predict_uncorrected(stats)

    fig, _ = plot_pdf(post_test, pdf2=g.prior, lims=[-3, 3], gt=pars_test.reshape(-1), figsize=(12, 12),
                      resolution=100, labels_params=labels_params)

    # Top row: true RF, RF at the mean of the first posterior component, and the
    # simulated summary statistics without their last entry, shown as a d x d image.
    plt.figure(figsize=(6, 6))
    plt.subplot(2, 3, 1)
    plt.imshow(m.params_to_rf(pars_test)[0], interpolation='None')
    plt.subplot(2, 3, 2)
    plt.imshow(m.params_to_rf(post_test.xs[0].m.reshape(-1))[0], interpolation='None')
    plt.subplot(2, 3, 3)
    plt.imshow(stats[0, :-1].reshape(d, d), interpolation='None')

    # Bottom half: RFs of 8 draws from the posterior.
    for i in range(8):
        plt.subplot(4, 4, 9 + i)
        plt.imshow(m.params_to_rf(post_test.gen().reshape(-1))[0], interpolation='None')

    plt.show()

# save results

In [None]:
import os

# Store the output, creating the results directory if it does not exist yet.
for out_file in (filename1, filename2):
    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

io.save_pkl((log, trn_data, posteriors), filename1)

# The simulator model cannot be pickled (at the moment), so detach it while the
# inference object is saved. Re-attach it afterwards: inf.generator is the same
# generator `g` used throughout the notebook, and leaving model = None would
# break any cell re-run after this one.
model_backup = inf.generator.model
inf.generator.model = None
try:
    io.save(inf, filename2)
finally:
    inf.generator.model = model_backup
