# SNPE & RF

Learning receptive-field (RF) parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear (LN) neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import setup_sim, setup_sampler, get_data_o, quick_plot, contour_draws


# Experiment-wide settings shared by all cells below:
#   seed      -- passed to setup_sim() and get_data_o()
#   n_cells   -- toy cells are numbered 1..n_cells
#   n_samples -- sample count encoded in the MCMC result file names
seed, n_cells, n_samples = 42, 3, 500000

# load posteriors

## round #1

In [None]:
# Round 1: for every toy cell, compare the round-1 SNPE posterior with the
# stored MCMC reference samples, then plot the SNPE training loss.
mcmc_keys = ['bias', 'log_A', 'logit_φ', 'log_f', 'logit_θ', 'log_γ', 'log_b']

for idx_cell in range(1, n_cells + 1):

    print('\n')
    print('cell #' + str(idx_cell))
    print('\n')

    # MCMC: load the reference samples previously saved for this cell.
    print('MCMC')
    savefile = ('../results/MCMC/toycell_' + str(idx_cell)
                + '/maprf_MCMC_prior01_run_1_' + str(n_samples)
                + 'samples_param7')
    # The archive stores a pickled dict under 'arr_0'; NumPy >= 1.16.3 refuses
    # to unpickle object arrays unless allow_pickle=True, so only load files
    # you trust. The `with` block closes the .npz file handle after reading.
    with np.load(savefile + '.npz', allow_pickle=True) as archive:
        T = archive['arr_0'][()]['T']

    # Stack the per-parameter sample arrays side by side (one column block per
    # key). Order presumably matches the SNPE parameter order used by
    # plot_pdf -- TODO confirm.
    samples = np.hstack([np.atleast_2d(T[key].T).T for key in mcmc_keys])

    # SNPE: the round-1 amortized result is a single file shared by all cells.
    # NOTE(review): posteriors[-1] is therefore the same density for every
    # cell; confirm this is intended rather than conditioning on obs_stats.
    print('SNPE')
    snpe_file = ('../results/SNPE/maprf_100k_amortized_prior01_run_1'
                 '_round1_param7_nosvi_base.pkl')
    log, trn_data, posteriors = io.load_pkl(snpe_file)

    # Build the generator once (the original built it twice and discarded the
    # first), then load this toy cell's observed stats / true parameters.
    g, prior, d = setup_sim(seed, path='..')
    filename = '../results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
    obs_stats, pars_true = get_data_o(filename, g, seed)

    print('plotting')
    gt = pars_true.flatten()
    # Posterior vs. prior marginals, first alone, then with MCMC samples.
    plot_pdf(posteriors[-1], pdf2=g.prior, gt=gt, lims=[-1, 1],
             figsize=(12, 12), resolution=100)
    plot_pdf(posteriors[-1], pdf2=g.prior, gt=gt, samples=samples.T,
             lims=[-1, 1], figsize=(12, 12), resolution=100)

    # Training loss from the last entry of `log`, on log-x and linear x axes.
    loss = log[-1]['loss']
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.semilogx(loss)
    plt.subplot(1, 2, 2)
    plt.plot(loss)
    plt.show()

## round #2

In [None]:
# Round 2: for every toy cell, compare that cell's refined (round-2) SNPE
# posterior with the stored MCMC reference samples, then plot the SNPE
# training loss.
mcmc_keys = ['bias', 'log_A', 'logit_φ', 'log_f', 'logit_θ', 'log_γ', 'log_b']

for idx_cell in range(1, n_cells + 1):

    print('\n')
    print('cell #' + str(idx_cell))
    print('\n')

    # MCMC: load the reference samples previously saved for this cell.
    print('MCMC')
    savefile = ('../results/MCMC/toycell_' + str(idx_cell)
                + '/maprf_MCMC_prior01_run_1_' + str(n_samples)
                + 'samples_param7')
    # The archive stores a pickled dict under 'arr_0'; NumPy >= 1.16.3 refuses
    # to unpickle object arrays unless allow_pickle=True, so only load files
    # you trust. The `with` block closes the .npz file handle after reading.
    with np.load(savefile + '.npz', allow_pickle=True) as archive:
        T = archive['arr_0'][()]['T']

    # Stack the per-parameter sample arrays side by side (one column block per
    # key). Order presumably matches the SNPE parameter order used by
    # plot_pdf -- TODO confirm.
    samples = np.hstack([np.atleast_2d(T[key].T).T for key in mcmc_keys])

    # SNPE: per-cell refined (round-2) result.
    print('SNPE')
    snpe_file = ('../results/SNPE/toycell_' + str(idx_cell)
                 + '/maprf_2x10k_refined_prior01_run_1_round2_param7_nosvi_CDELFI.pkl')
    log, trn_data, posteriors = io.load_pkl(snpe_file)

    # Build the generator once (the original built it twice and discarded the
    # first), then load this toy cell's observed stats / true parameters.
    g, prior, d = setup_sim(seed, path='..')
    filename = '../results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
    obs_stats, pars_true = get_data_o(filename, g, seed)

    print('plotting')
    gt = pars_true.flatten()
    # Posterior vs. prior marginals, first alone, then with MCMC samples.
    plot_pdf(posteriors[-1], pdf2=g.prior, gt=gt, lims=[-1, 1],
             figsize=(12, 12), resolution=100)
    plot_pdf(posteriors[-1], pdf2=g.prior, gt=gt, samples=samples.T,
             lims=[-1, 1], figsize=(12, 12), resolution=100)

    # Training loss from the last entry of `log`, on log-x and linear x axes.
    loss = log[-1]['loss']
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.semilogx(loss)
    plt.subplot(1, 2, 2)
    plt.plot(loss)
    plt.show()