# SNPE & RF

Learning receptive-field (RF) parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import os

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.utils import get_maprf_prior_01, setup_sim, setup_sampler, \
get_data_o, quick_plot, contour_draws
from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats

seed = 42

# Root of the results tree. Set the MAPRF_ROOT_PATH environment variable to run
# the notebook on another machine; the original author's path is the default.
root_path = os.environ.get(
    'MAPRF_ROOT_PATH',
    '/home/marcel/Dropbox (mackelab)/team/Write/Working_Manuscripts/InferenceNeuralDynamics_2018/materials/fig2/mapRF')

In [None]:
## Ground truth: toy cell, true parameters, observed data and statistics

idx_cell = 6  # index of the toy cell to load
filename = '{}/results/toy_cells/toy_cell_{}.npy'.format(root_path, idx_cell)

g, prior, d = setup_sim(seed, path=root_path)
obs_stats, pars_true = get_data_o(filename, g, seed)

# RF image for the true parameters (first entry returned by params_to_rf).
rf = g.model.params_to_rf(pars_true)[0]

plt.figure(figsize=(6, 6))
plt.imshow(rf, interpolation='None')
#plt.savefig('ground_truth_RF.pdf')
plt.title('ground-truth receptive field')
plt.show()

# Notebook display: the summary statistics and their last entry.
obs_stats, obs_stats[0, -1]

In [None]:
algo = 'CDELFI'


# Network architecture: 5 convolutional (ReLU) layers, 3 fully connected
# layers and a mixture-of-Gaussians (MoG) output layer.
# NOTE(review): an earlier comment described "4x conv" and "20k parameters in
# total"; the lists below define 5 conv layers and the parameter count has not
# been re-checked.
filter_sizes = [3, 3, 3, 3, 2]    # kernel size of each of the 5 conv layers
n_filters = (16, 16, 32, 32, 32)  # 16 to 32 filters per conv layer
pool_sizes = [1, 2, 2, 2, 2]      # pooling size per conv layer
n_hiddens = [100, 100, 100]       # 3 fully connected layers of 100 units

# Number of training samples per round.
# NOTE(review): the saved runs loaded below have '100k' in their file names,
# and n_train is not passed to CDELFI in the loading cell.
n_train = 10000

# Number of MoG components.
# NOTE(review): an earlier comment argued for a single component (the posterior
# at most STAs is well approximated by one Gaussian); the loading cell takes
# n_components from the saved network spec rather than from this variable.
n_components = 4

# Number of rounds (the first round is always 'amortized' and can be used with
# any other STA covered by the prior).
n_rounds = 1

# CNN feature: an extra input is passed directly to the hidden (fully
# connected) layers, bypassing the conv layers. Here it is the number of spikes
# (a single number), which allows the network to normalize the STAs and helps
# the conv layers. Without that extra input the RF gain could not be recovered.
n_inputs_hidden = 1

# Learning-schedule parameters.
lr_decay = 0.999
epochs = 200
minibatch = 50

svi = False       # no SVI; with large N it should make little difference anyway
reg_lambda = 0.0  # no regularization, to make doubly sure SVI is switched off

# z-scoring only applies to the extra inputs (here: the spike count) that are
# fed directly to the fully connected layers.
pilot_samples = 1000

prior_norm = True  # doesn't hurt
init_norm = False  # unclear yet how best to normalize initialization through conv and ReLU layers

# NOTE(review): the original comment said this fits only DIAGONAL covariances;
# confirm against delfi's `rank` semantics.
rank = None


# Load results from a previous round

In [None]:
# Load the results of round `round_` and continue from there.
round_ = 2

# Careful here: each round starts by computing its proposal from the current
# MDN state and the *previous* proposal.

path = root_path + '/results/SNPE/toycell_' + str(idx_cell) + '/'
run_prefix = path + 'maprf_100k_elife_prior01_run_99_round'

if round_ > 1:
    # Proposal saved at the end of the previous round. The pickle may hold a
    # list of proposals, in which case its last entry is used.
    filename1 = run_prefix + str(round_ - 1) + '_param9_nosvi_CDELFI.pkl'
    _, _, proposal = io.load_pkl(filename1)
    proposal = proposal[-1] if isinstance(proposal, list) else proposal

# Network spec and weights saved for the current round.
filename4 = run_prefix + str(round_) + '_param9_nosvi_CDELFI_net_only.pkl'
if algo == 'CDELFI':
    tmp = io.load_pkl(filename4)

    # Construct CDELFI with the hyper-parameters from the config cell (number
    # of MoG components taken from the saved spec), then restore the weights.
    inf = infer.CDELFI(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=init_norm,
                       pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
                       n_components=tmp['network.spec_dict']['n_components'], rank=rank,
                       n_hiddens=n_hiddens, n_filters=n_filters, n_inputs=(1, d, d),
                       filter_sizes=filter_sizes, pool_sizes=pool_sizes,
                       n_inputs_hidden=n_inputs_hidden, verbose=True)
    inf.network.params_dict = tmp['network.params_dict']
    if round_ > 1:
        print('# proposal components:', proposal.n_components)
        # A single-component proposal is projected to a plain Gaussian.
        inf.generator.proposal = proposal.project_to_gaussian() if proposal.n_components == 1 else proposal
    inf.round = round_

elif algo == 'SNPE':
    raise NotImplementedError("loading saved 'SNPE' results is not implemented")
else:
    # Fail loudly instead of silently leaving `inf` undefined.
    raise ValueError("unknown algo {!r}; expected 'CDELFI' or 'SNPE'".format(algo))


In [None]:
#np.save(root_path + '/results/SNPE/toycell_6/posterior_round4', {'posterior' : inf.predict(obs_stats),
#                                                                    'proposal'  : proposal, 
#                                                                    'prior' : g.prior})

In [None]:
# Load a saved {'posterior', 'proposal', 'prior'} dict (stored via np.save).
# numpy >= 1.16.3 refuses to load object arrays unless allow_pickle=True; only
# load trusted files, since unpickling can execute arbitrary code.
# NOTE(review): this file is from run 9 / round 4, whereas the figures below
# are named run99_round2 -- confirm the intended run.
posterior_file = (root_path + '/results/SNPE/toycell_' + str(idx_cell)
                  + '/maprf_100k_elife_prior01_run_9_round4_param9_nosvi_CDELFI_posterior.npy')
tmp = np.load(posterior_file, allow_pickle=True)[()]
posterior, proposal, prior = tmp['posterior'], tmp['proposal'], tmp['prior']

In [None]:
#posterior = inf.predict(obs_stats)

def _param_transform_kwargs():
    """Return fresh per-parameter transform settings shared by prior and posterior.

    Judging by the axis labels used below, flag 1 marks log-transformed and
    flag 2 logit-transformed parameters (bounded by `lower`/`upper`).
    """
    return dict(flags=[0, 0, 2, 1, 2, 1, 1, 2, 2],
                lower=[0, 0, 0, 0, 0, 0, 0, -1, -1],
                upper=[0, 0, np.pi, 0, 2 * np.pi, 0, 0, 1, 1])


plot_prior = dd.TransformedNormal(m=g.prior.m, S=g.prior.S, **_param_transform_kwargs())

# Mixture components of the SNPE posterior, re-wrapped with the same transforms.
post_components = [posterior.xs[k] for k in range(posterior.n_components)]
plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
    ms=[comp.m for comp in post_components],
    Ss=[comp.S for comp in post_components],
    a=posterior.a,
    **_param_transform_kwargs())

# Plot ranges: one row per parameter, columns (lower, upper).
lims = np.array([[-1.5, -1.1, .001, 0, .001, 0, 0, -.999, -.999],
                 [1.5, 1.1, .999 * np.pi, 2.5, 1.999 * np.pi, 2, 4., .999, .999]]).T

# True parameters passed through plot_post._f (presumably the forward
# transform into the plotted space).
gt_transformed = plot_post._f(pars_true.reshape(1, -1)).reshape(-1)

fig, _ = plot_pdf(plot_post, pdf2=plot_prior, lims=lims, gt=gt_transformed,
                  figsize=(16, 16), resolution=100,
                  labels_params=['bias', 'gain', 'logit phase', 'log freq', 'logit angle',
                                 'log ratio', 'log width', 'xo', 'yo'])

#fig.savefig('quadro_posterior_run9_round1_comps1_1Hz_SNR12_5min.pdf')

# Compare against the MCMC sampler

In [None]:
n_samples = 1000000
path = root_path + '/results/MCMC/'
savefile = path + 'toycell_' + str(idx_cell) + '/maprf_MCMC_prior01_run_1_' + str(n_samples) + 'samples_param9_5min.npy'
# The file stores a pickled dict {'T' : T, 'params_dict_true' : params_dict_true}.
# numpy >= 1.16.3 refuses to load object arrays unless allow_pickle=True; only
# load trusted files, since unpickling can execute arbitrary code.
tmp = np.load(savefile, allow_pickle=True)[()]


In [None]:

# MCMC chains keyed by parameter name, plus the true parameter values.
T, params_dict_true = tmp['T'], tmp['params_dict_true']

# Column order of the sample matrix; 1-D chains become single columns
# (symmetrize_sample_modes below expects 9 columns in this order).
params_ls = ['bias', 'gain', 'phase', 'freq','angle','ratio','width', 'xo', 'yo']
samples = np.column_stack([T[name] for name in params_ls])

def symmetrize_sample_modes(samples):
    """Mix MCMC samples with copies transformed by two parameter symmetries.

    Two transformations are applied in turn:

    1. angle -> angle - pi if angle > pi, angle + pi if angle < pi
       (angles exactly equal to pi are left unchanged);
    2. (gain, phase) -> (-gain, pi - phase).

    After each transformation the current rows and their transformed copies
    are stacked, and every second row of the stack is kept. The output
    therefore has the same number of rows as the input.

    NOTE(review): for even n, the kept rows are the even-indexed input rows
    followed by their transformed copies, so odd-indexed draws are dropped.
    After both steps roughly a quarter of the distinct input draws remain.
    Confirm this thinning is intended.

    Parameters
    ----------
    samples : array_like, shape (n, 9)
        Columns in the order bias, gain, phase, freq, angle, ratio, width,
        xo, yo. Phase must lie in [0, pi] and angle in [0, 2*pi]; freq,
        ratio and width must be non-negative.

    Returns
    -------
    ndarray, shape (n, 9)

    Raises
    ------
    ValueError
        If `samples` has the wrong shape or a parameter is out of range
        (NaNs are rejected as out of range).
    """
    samples = np.asarray(samples)
    if samples.ndim != 2 or samples.shape[1] != 9:
        raise ValueError('expected samples of shape (n, 9), got {}'.format(samples.shape))

    phase, angle = samples[:, 2], samples[:, 4]
    # Range checks: the bounds are applied to np.min/np.max themselves
    # (previously `np.max(x <= bound)` only required ONE row to be in range).
    if not (np.min(phase) >= 0. and np.max(phase) <= np.pi):
        raise ValueError('phase samples must lie in [0, pi]')
    if not (np.min(angle) >= 0. and np.max(angle) <= 2 * np.pi):
        raise ValueError('angle samples must lie in [0, 2*pi]')
    if not np.all(np.min(samples[:, [3, 5, 6]], axis=0) >= 0.):
        raise ValueError('freq, ratio and width samples must be non-negative')

    # Symmetry 1: shift the angle by pi (masks computed from the input).
    above, below = angle > np.pi, angle < np.pi
    shifted = samples.copy()
    shifted[above, 4] = shifted[above, 4] - np.pi
    shifted[below, 4] = shifted[below, 4] + np.pi
    samples_all = np.vstack((samples, shifted))[::2, :]

    # Symmetry 2: flip the sign of the gain and reflect the phase.
    flipped = samples_all.copy()
    flipped[:, 1] = -flipped[:, 1]
    flipped[:, 2] = np.pi - flipped[:, 2]
    return np.vstack((samples_all, flipped))[::2, :]

samples = symmetrize_sample_modes(samples)
#samples = np.vstack((lims.T, samples))

# Ground-truth parameters in the same order as `params_ls` and the columns of
# `samples`: bias, gain, phase, freq, angle, ratio, width, xo, yo.
# 'freq' must come before 'angle' to line up with `lims` and `labels_params`.
pars_raw = np.array([params_dict_true['glm']['bias'],
                     params_dict_true['kernel']['s']['gain'],
                     params_dict_true['kernel']['s']['phase'],
                     params_dict_true['kernel']['s']['freq'],
                     params_dict_true['kernel']['s']['angle'],
                     params_dict_true['kernel']['s']['ratio'],
                     params_dict_true['kernel']['s']['width'],
                     params_dict_true['kernel']['l']['xo'],
                     params_dict_true['kernel']['l']['yo']])

# Plot ranges: one row per parameter, columns (lower, upper).
lims = np.array([[-1.5, -1.5, .001, 0, .001, 0, 0, -0.5, -0.5],
                 [1.5, 1.5, .999 * np.pi, 3, 1.999 * np.pi, 3, 5, 0.5, 0.5]]).T

fig, _ = plot_pdf(plot_post, pdf2=plot_prior, lims=lims, gt=pars_raw.reshape(-1),
                  figsize=(16, 16), resolution=100, samples=samples.T,
                  labels_params=['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']);

fig.savefig('quadro_posterior_vs_MCMC_run99_round2_comps1_1Hz_SNR12_5min.pdf')


In [None]:
# Reduced plot over a subset of parameters. Column indices follow the
# parameter order bias(0), gain(1), phase(2), freq(3), angle(4), ...
idx = np.array([0, 1, 3, 4])  # bias, gain, frequency and angle

plot_post_small = dd.mixture.MoTG(ms=[x.m[idx] for x in plot_post.xs],
                                  Ss=[x.S[idx][:, idx] for x in plot_post.xs],
                                  a=plot_post.a,
                                  flags=plot_post.flags[idx],
                                  lower=plot_post.lower[idx],
                                  upper=plot_post.upper[idx])

plot_prior_small = dd.TransformedNormal(m=plot_prior.m[idx], S=plot_prior.S[idx][:, idx],
                                        flags=plot_prior.flags[idx],
                                        lower=plot_prior.lower[idx],
                                        upper=plot_prior.upper[idx])

fig, _ = plot_pdf(plot_post_small, pdf2=plot_prior_small, lims=lims[idx], gt=pars_raw.reshape(-1)[idx],
                  figsize=(9, 9), resolution=100,
                  samples=samples[:, idx].T,
                  labels_params=['bias', 'gain', 'freq', 'angle']);

fig.savefig('quadro_posterior_vs_MCMC_run99_round2_comps1_1Hz_SNR12_5min_small.pdf')

# Visualize the STA, the ground-truth RF and RF posterior draws

In [None]:
# Observed STA: all summary statistics except the last, shown as a d x d image.
plt.figure(figsize=(6,6))
plt.imshow(obs_stats[0,:-1].reshape(d,d), interpolation='None', cmap='gray')
plt.title('STA')
plt.savefig('quadro_posterior_run99_round2_comps4_1Hz_SNR12_5min_STA_only.pdf')
plt.show()

In [None]:
lvls = [0.5, 0.5]  # contour levels as fractions of the RF's min and max
plt.figure(figsize=(6,6))
plt.imshow(obs_stats[0,:-1].reshape(d,d), interpolation='None', cmap='gray')
plt.title('STA + GT')

# Ground-truth RF contour (red) on top of the STA.
rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()], colors='r')

plt.savefig('quadro_posterior_run99_round2_comps4_1Hz_SNR12_5min_STA.pdf')

plt.show()

In [None]:
lvls = [0.5, 0.5]  # contour levels as fractions of the RF's min and max
p = posterior
n_draws = 10
plt.figure(figsize=(6,6))
plt.imshow(obs_stats[0,:-1].reshape(d,d), interpolation='None', cmap='gray')

# Overlay contours of RFs for parameter draws from the posterior.
# (plt.hold was removed in matplotlib 3.0; successive plots overlay anyway.)
for _ in range(n_draws):
    rfm = g.model.params_to_rf(p.gen().reshape(-1))[0]
    plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()])
plt.title('RF posterior draws')

# Ground-truth RF contour in red.
rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()], colors='r')

plt.savefig('quadro_posterior_run99_round2_comps4_1Hz_SNR12_5min_draws.pdf')
plt.show()
