# SNPE & RF

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear (LN) neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import setup_sim, setup_sampler, get_data_o, quick_plot, contour_draws


# ---------------------------------------------------------------------------
# Experiment parameters
# ---------------------------------------------------------------------------

seed = 42       # random seed used throughout the experiment
n_cells = 3     # how many toy cells to fit

# SNPE training: 10k simulations per round, a 2-component mixture density.
n_train = 10000
n_components = 2

# Extra SNPE rounds. Rough idea: the 1st round refines the proposal locally,
# a 2nd round would reveal whether the posterior is non-Gaussian.
n_rounds = 1

# Features passed directly to the hidden layers (the number of spikes).
# NOTE(review): not referenced by the cells visible in this notebook.
n_inputs_hidden = 1

# Learning schedule.
lr_decay = 1.0
epochs = 100
minibatch = 50

# MCMC chain length, burn-in included.
n_samples = 500000


# Learn posteriors: an MCMC reference and an SNPE posterior for each toy cell

In [None]:
import os

# Switches for the two analyses below (they replace the former `#"""`
# comment toggles); set either to False to skip that part for every cell.
run_mcmc = True
run_snpe = True

for idx_cell in range(1, n_cells + 1):

    print('\n')
    print('cell #' + str(idx_cell))
    print('\n')

    path = '../results/SNPE/toycell_' + str(idx_cell)

    # SNPE outputs written below: (log, trn_data, posteriors) -> filename1,
    # serialized inference state -> filename3.
    filename1 = path + '/maprf_2x10k_refined_prior01_run_1_round2_param7_nosvi_CDELFI.pkl'
    # NOTE(review): filename2 is not read anywhere in this cell; kept in case
    # a later cell relies on it.
    filename2 = path + '/maprf_100k_amortized_prior01_run_1_round2_param7_nosvi_base_res.pkl'
    filename3 = path + '/maprf_100k_amortized_prior01_run_1_round2_param7_nosvi_CDELFI_inference.pkl'

    savefile = ('../results/MCMC/toycell_' + str(idx_cell)
                + '/maprf_MCMC_prior01_run_1_' + str(n_samples) + 'samples_param7')

    # Ground-truth file for this toy cell, built once and shared by both
    # sections (it used to be recomputed three times per iteration).
    filename = '../results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'

    if run_mcmc:
        # MCMC
        print('MCMC')
        g, prior, d = setup_sim(seed, path='..')
        # Kept as in the original; the results are not used by the sampler
        # below. NOTE(review): verify get_data_o has no needed side effects
        # on g before removing this call.
        obs_stats, pars_true = get_data_o(filename, g, seed)

        # The .npy file presumably holds a dict saved with np.save, i.e. a
        # 0-d object array that [()] unwraps. Loading object arrays requires
        # allow_pickle=True on numpy >= 1.16.3 (ValueError otherwise).
        # This unpickles arbitrary objects: only load trusted files.
        params_dict_true = np.load(filename, allow_pickle=True)[()]
        m = g.model
        m.params_dict = params_dict_true.copy()
        # Re-seed the model RNG so the simulated observation is reproducible.
        m.rng = np.random.RandomState(seed=seed)
        obs = m.gen_single()

        inference, data = setup_sampler(prior, obs, d, g, params_dict=params_dict_true,
                                        fix_position=True, parametrization='logit_φ')

        print('- sampling RF params')
        T, L = inference.sample(n_samples)
        T = {k.name: t for k, t in T.items()}  # re-key samples by each key's .name
        print('- sampling Poisson params')
        inference.sample_biases(data, T, g.model.dt)

        # Create the output directory first so a long MCMC run is not lost
        # to a FileNotFoundError at save time.
        os.makedirs(os.path.dirname(savefile), exist_ok=True)
        # np.savez stores the positional dict under the key 'arr_0' (as a
        # pickled object array) and appends '.npz' to savefile.
        np.savez(savefile, {'T': T, 'params_dict_true': params_dict_true})

    if run_snpe:
        # SNPE
        print('SNPE')
        # Reloaded for every cell, presumably so each refinement starts from
        # the same amortized network rather than the previous cell's result.
        snpe_base_file = '../results/SNPE/maprf_100k_amortized_prior01_run_1_round1_param7_nosvi_base_res.pkl'
        inf = io.load(snpe_base_file)
        g, prior, d = setup_sim(seed, path='..')  # reset model
        obs_stats, pars_true = get_data_o(filename, g, seed)

        inf.generator = g
        inf.obs = obs_stats

        #proposal = inf.predict(obs_stats)
        #quick_plot(g, obs_stats, d, pars_true, proposal)

        log, trn_data, posteriors = inf.run(n_train=n_train,
                                            n_components=n_components,
                                            epochs=epochs, minibatch=minibatch,
                                            n_rounds=n_rounds, lr_decay=lr_decay)

        # NOTE(review): `posterior` is undefined here -- presumably an entry
        # of `posteriors`; fix before uncommenting.
        #quick_plot(g, obs_stats, d, pars_true, posterior, log)
        #contour_draws(posterior, g, obs_stats, d)

        # Ensure the per-cell SNPE directory exists before pickling results.
        os.makedirs(path, exist_ok=True)
        io.save_pkl((log, trn_data, posteriors), filename1)
        inf.generator.model = None  # model cannot be pickled atm
        io.save_pkl((inf.generator, inf.stats_mean, inf.stats_std,
                     inf.network, inf.kwargs,
                     inf.loss, inf.reg_lambda), filename3)