# SNPE & RF

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.utils import get_maprf_prior_01, setup_sim, setup_sampler, \
get_data_o, quick_plot, contour_draws
from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats

# Single random seed, passed to setup_sim, get_data_o and the inference
# constructors in the cells below.
seed = 42

In [None]:
# Load one toy cell: the generator, the prior, the observed summary
# statistics and the ground-truth parameters that go with them.

idx_cell = 6  # index of the toy-cell file to load
filename = './results/toy_cells/toy_cell_{}.npy'.format(idx_cell)

# `d` is used further down as the side length when the statistics (all but
# the last entry) are reshaped into a d x d image.
g, prior, d = setup_sim(seed, path='.')
obs_stats, pars_true = get_data_o(filename, g, seed)

# Receptive field implied by the ground-truth parameters.
rf = g.model.params_to_rf(pars_true)[0]
plt.imshow(rf, interpolation='None')
plt.show()

# Notebook display: the full statistics array and its last entry
# (NOTE(review): presumably the spike count - confirm against maprfStats).
(obs_stats, obs_stats[0, -1])

In [None]:
# Same helper that is used for posterior draws further down, here fed with
# the prior (g.prior) so the two can be compared by eye.
contour_draws(g.prior, g, obs_stats, d=d)

In [None]:
# Inference algorithm; the cells below branch on 'CDELFI' vs 'SNPE'.
algo = 'CDELFI'


# --- network architecture ---------------------------------------------------
# Five conv ReLU layers (one entry per layer in the three lists below), then
# three fully connected layers and the mixture-of-Gaussians output.
# NOTE(review): an earlier note here quoted "4x conv" and "20k parameters in
# total"; with five conv entries the parameter count should be re-checked.

filter_sizes = [3, 3, 3, 3, 2]      # one filter size per conv layer
n_filters = (16, 16, 32, 32, 32)    # 16 to 32 filters per conv layer
pool_sizes = [1, 2, 2, 2, 2]        # one pooling size per conv layer
n_hiddens = [100, 100, 100]         # 3 fully connected layers

# --- amount of training data ------------------------------------------------
# 10k simulations per round by default; the third- and fourth-round cells
# overwrite n_train with 100000.

n_train = 10000

# Number of mixture components handed to the SNPE constructor. The CDELFI
# branches below pass their own component counts explicitly.

n_components = 4

# One round per call to inf.run(); later rounds are started by hand below.
# Original note: the first round is always 'amortized' and can be used with
# any other STA covered by the prior.

n_rounds = 1

# Feature for CNN architectures: pass a value directly to the hidden layers,
# bypassing the conv layers. Original note: here this is the number of spikes
# (a single number), which allows normalizing the STAs and so helps the conv
# layers; without that extra input the RF gain could not be recovered.
n_inputs_hidden = 1

# --- learning schedule ------------------------------------------------------
lr_decay = 0.999
epochs = 200
minibatch = 50

svi = False        # original note: large N should make this do nothing anyway
reg_lambda = 0.0   # original note: makes doubly sure SVI is switched off

# Original note: z-scoring only applies to the extra inputs (here: firing
# rate) that are fed directly to the fully connected layers.
pilot_samples = 1000

prior_norm = True   # original note: doesn't hurt
# Original note: not yet clear how best to normalize the initialization
# through conv and ReLU layers.
init_norm = False

# NOTE(review): the original comment read "fitting only DIAGONAL covariances";
# confirm what rank=None selects in the delfi network.
rank = None


# First round

In [None]:
# Build the inference object for the selected algorithm. Both variants get
# the same generator, observation and network options; they differ only in
# the number of mixture components and in init_norm.
_net_kwargs = dict(
    generator=g,
    obs=obs_stats,
    prior_norm=prior_norm,
    pilot_samples=pilot_samples,
    seed=seed,
    reg_lambda=reg_lambda,
    svi=svi,
    n_hiddens=n_hiddens,
    n_filters=n_filters,
    n_inputs=(1, d, d),
    filter_sizes=filter_sizes,
    pool_sizes=pool_sizes,
    n_inputs_hidden=n_inputs_hidden,
    rank=rank,
    verbose=True,
)

if algo == 'CDELFI':

    # CDELFI is built with a single component here, independent of the
    # global n_components.
    inf = infer.CDELFI(n_components=1, init_norm=init_norm, **_net_kwargs)

elif algo == 'SNPE':

    # SNPE overrides the global init_norm flag before constructing.
    init_norm = True
    inf = infer.SNPE(n_components=n_components, init_norm=init_norm,
                     **_net_kwargs)

In [None]:
def get_shape(i):
    """Return the shape of the value stored in entry ``i`` of ``inf.network.aps``."""
    return inf.network.aps[i].get_value().shape


# Shapes and element counts of every second entry (indices 1, 3, ..., 15).
# Original note: these are the per-layer weights, not the biases.
weight_shapes = [get_shape(idx) for idx in range(1, 17, 2)]
print(weight_shapes)
print([np.prod(shape) for shape in weight_shapes])

In [None]:
# First training round. Both algorithms share the same schedule arguments.
_run_kwargs = dict(
    n_train=n_train,
    epochs=epochs,
    minibatch=minibatch,
    n_rounds=n_rounds,
    lr_decay=lr_decay,
)

if algo == 'CDELFI':

    # run SNPE-A for one round, with a single mixture component
    log, trn_data, posteriors = inf.run(n_components=1, **_run_kwargs)

elif algo == 'SNPE':

    # run SNPE-B for one round
    log, trn_data, posteriors = inf.run(**_run_kwargs)


In [None]:
# Evaluate the trained network at the observed statistics.
posterior = inf.predict(obs_stats)
# Copy the dimensionality of the first mixture component onto the mixture.
posterior.ndim = posterior.xs[0].ndim

quick_plot(g, obs_stats, d, pars_true, posterior, log)

# All pairwise marginals of the fitted posterior, with the prior passed as
# the second density and the ground-truth parameters as `gt`.
# NOTE(review): this cell labels the second parameter 'log gain', the later
# cells label it 'gain' - confirm which one is correct.
fig, _ = plot_pdf(
    posterior,
    pdf2=g.prior,
    lims=[-3, 3],
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=[
        'bias', 'log gain', 'logit phase', 'log freq', 'logit angle',
        'log ratio', 'log width', 'xo', 'yo',
    ],
)
fig.savefig('res.pdf')

# A handful of example posterior draws, shown as contours.
contour_draws(posterior, g, obs_stats, d)

In [None]:
# Overlay contours of receptive fields drawn from the posterior on an image
# showing the observed statistics (left half) next to the ground-truth RF
# (right half).
n_draws = 10
plt.figure(figsize=(6, 6))
plt.imshow(np.hstack((
    obs_stats[0, :-1].reshape(d, d),
    g.model.params_to_rf(pars_true.reshape(-1))[0])),
    interpolation='None', cmap='gray')
# Two contour levels per draw: half the minimum and half the maximum of the
# drawn RF.
lvls = [0.5, 0.5]
for i in range(n_draws):
    rfm = g.model.params_to_rf(posterior.gen(1).reshape(-1))[0]
    # No plt.hold(True) call: it was deprecated in Matplotlib 2.0 and removed
    # in 3.0 (AttributeError there). Successive plotting calls always add to
    # the current axes, so the overlay works without it.
    plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()])
plt.title('RF posterior draws')
#plt.savefig('posterior_location_round1_example.pdf')
plt.show()

In [None]:
# Persist the round-1 results and, separately, just the network spec/weights.
round_ = 1
_stem = './results/SNPE/maprf_100k_elife_prior01_run_9_round{}_param9_nosvi_CDELFI'.format(round_)
filename1 = _stem + '.pkl'            # log, training data and posterior
filename2 = _stem + '_res.pkl'        # defined but not written in this cell
filename4 = _stem + '_net_only.pkl'   # network spec + parameters only


io.save_pkl((log, trn_data, posterior), filename1)

net = inf.network
data = {
    'network.spec_dict': net.spec_dict,
    'network.params_dict': net.params_dict,
}
io.save_pkl(data, filename4)

# Second round

In [None]:
# Load a stored round-1 network and continue from it.
# NOTE(review): these paths end in '_base...', whereas the round-1 save cell
# writes '_CDELFI...' files - confirm the '_base' files exist and are the
# intended starting point.
round_ = 1
_stem = './results/SNPE/maprf_100k_elife_prior01_run_9_round{}_param9_nosvi_base'.format(round_)
filename1 = _stem + '.pkl'
filename2 = _stem + '_res.pkl'
filename4 = _stem + '_net_only.pkl'

if algo == 'CDELFI':

    tmp = io.load_pkl(filename4)

    # Rebuild the inference object with the stored number of components,
    # then replace its network parameters with the stored ones.
    inf = infer.CDELFI(
        generator=g,
        obs=obs_stats,
        prior_norm=prior_norm,
        init_norm=init_norm,
        pilot_samples=pilot_samples,
        seed=seed,
        reg_lambda=reg_lambda,
        svi=svi,
        n_components=tmp['network.spec_dict']['n_components'],
        rank=rank,
        n_hiddens=n_hiddens,
        n_filters=n_filters,
        n_inputs=(1, d, d),
        filter_sizes=filter_sizes,
        pool_sizes=pool_sizes,
        n_inputs_hidden=n_inputs_hidden,
        verbose=True,
    )
    inf.network.params_dict = tmp['network.params_dict']
    # Tell the inference object which round has already been completed.
    inf.round = round_

elif algo == 'SNPE':
    raise NotImplementedError


posterior = inf.predict(obs_stats)
quick_plot(g, obs_stats, d, pars_true, posterior)

# All pairwise marginals of the reloaded posterior against the prior.
fig, _ = plot_pdf(
    posterior,
    pdf2=g.prior,
    lims=[-3, 3],
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=[
        'bias', 'gain', 'logit phase', 'log freq', 'logit angle',
        'log ratio', 'log width', 'xo', 'yo',
    ],
)
#fig.savefig('quadro_posterior_2rounds_CDELFI_200k_total_symmetry_breaking.pdf')

# A handful of example posterior draws, shown as contours.
contour_draws(posterior, g, obs_stats, d)

In [None]:
# Second training round.
if algo == 'CDELFI':

    # run SNPE-A for one round
    n_components2 = 1

    log2, trn_data2, posteriors2 = inf.run(
        n_train=n_train,
        epochs=epochs,
        minibatch=minibatch,
        n_rounds=n_rounds,
        lr_decay=lr_decay,
        n_components=n_components2,
        stndrd_comps=True,
    )

elif algo == 'SNPE':

    # run SNPE-B for one round, with an explicit learning rate
    lr = 0.0001
    log2, trn_data2, posteriors2 = inf.run(
        n_train=n_train,
        epochs=epochs,
        minibatch=minibatch,
        n_rounds=n_rounds,
        lr=lr,
        lr_decay=lr_decay,
    )

    # Normalize the third element of the last training-data entry
    # (presumably the importance weights - confirm) and report
    # 1 / sum(w^2) as the effective sample size.
    iws = trn_data2[-1][2]
    iws = iws / iws.sum()
    ESS = 1. / np.sum(iws ** 2)
    print('ESS', ESS)


# Continue with the most recent posterior of this round.
posterior = posteriors2[-1]
posterior.ndim = posterior.xs[0].ndim

quick_plot(g, obs_stats, d, pars_true, posterior, log2)

# All pairwise marginals of the fitted posterior against the prior.
fig, _ = plot_pdf(
    posterior,
    pdf2=g.prior,
    lims=[-3, 3],
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=[
        'bias', 'gain', 'logit phase', 'log freq', 'logit angle',
        'log ratio', 'log width', 'xo', 'yo',
    ],
)
#fig.savefig('quadro_posterior_2rounds_CDELFI_200k_total_symmetry_breaking.pdf')

# A handful of example posterior draws, shown as contours.
contour_draws(posterior, g, obs_stats, d)

In [None]:
# Persist the round-2 results and, separately, just the network spec/weights.
round_ = 2
_stem = './results/SNPE/maprf_100k_elife_prior01_run_9_round{}_param9_nosvi_CDELFI'.format(round_)
filename1 = _stem + '.pkl'            # log, training data and posteriors
filename2 = _stem + '_res.pkl'        # defined but not written in this cell
filename4 = _stem + '_net_only.pkl'   # network spec + parameters only

io.save_pkl((log2, trn_data2, posteriors2), filename1)

net = inf.network
data = {
    'network.spec_dict': net.spec_dict,
    'network.params_dict': net.params_dict,
}
io.save_pkl(data, filename4)

# Third round

In [None]:
# Load the round-2 network and continue from it.
round_ = 2

# Original note: careful here - rounds start with computing the proposal from
# the current MDN state and the *previous* proposal, hence the file of round
# (round_ - 1) is read for the proposal.
filename1 = './results/SNPE/maprf_100k_elife_prior01_run_9_round{}_param9_nosvi_CDELFI.pkl'.format(round_ - 1)
_, _, proposal = io.load_pkl(filename1)
# The stored third element is either a single posterior or a list of them;
# in the latter case take the last one.
if isinstance(proposal, list):
    proposal = proposal[-1]

filename4 = './results/SNPE/maprf_100k_elife_prior01_run_9_round{}_param9_nosvi_CDELFI_net_only.pkl'.format(round_)
if algo == 'CDELFI':

    tmp = io.load_pkl(filename4)

    # Rebuild the inference object with the stored number of components,
    # then replace its network parameters with the stored ones.
    inf = infer.CDELFI(
        generator=g,
        obs=obs_stats,
        prior_norm=prior_norm,
        init_norm=init_norm,
        pilot_samples=pilot_samples,
        seed=seed,
        reg_lambda=reg_lambda,
        svi=svi,
        n_components=tmp['network.spec_dict']['n_components'],
        rank=rank,
        n_hiddens=n_hiddens,
        n_filters=n_filters,
        n_inputs=(1, d, d),
        filter_sizes=filter_sizes,
        pool_sizes=pool_sizes,
        n_inputs_hidden=n_inputs_hidden,
        verbose=True,
    )
    inf.network.params_dict = tmp['network.params_dict']
    print('# proposal components:', proposal.n_components)
    # A one-component proposal is projected to a plain Gaussian; a proposal
    # with more components is used as it is.
    if proposal.n_components == 1:
        inf.generator.proposal = proposal.project_to_gaussian()
    else:
        inf.generator.proposal = proposal
    # Tell the inference object which round has already been completed.
    inf.round = round_


elif algo == 'SNPE':
    raise NotImplementedError
    

In [None]:
# Third training round.
if algo == 'CDELFI':

    # run SNPE-A for one round. n_train and epochs are reassigned here and
    # keep their new values for any cell executed afterwards.
    n_components3 = 8
    project_proposal = False
    stndrd_comps = True
    n_train = 100000
    epochs = 50

    log3, trn_data3, posteriors3 = inf.run(
        n_train=n_train,
        epochs=epochs,
        minibatch=minibatch,
        n_rounds=n_rounds,
        lr_decay=lr_decay,
        n_components=n_components3,
        stndrd_comps=stndrd_comps,
        project_proposal=project_proposal,
    )

elif algo == 'SNPE':

    # run SNPE-B for one round, with an explicit learning rate
    lr = 0.0001
    log3, trn_data3, posteriors3 = inf.run(
        n_train=n_train,
        epochs=epochs,
        minibatch=minibatch,
        n_rounds=n_rounds,
        lr_decay=lr_decay,
        lr=lr,
    )

    # Normalize the third element of the last training-data entry
    # (presumably the importance weights - confirm) and report
    # 1 / sum(w^2) as the effective sample size.
    iws = trn_data3[-1][2]
    iws = iws / iws.sum()
    ESS = 1. / np.sum(iws ** 2)
    print('ESS', ESS)

# Continue with the most recent posterior of this round.
posterior = posteriors3[-1]
posterior.ndim = posterior.xs[0].ndim

quick_plot(g, obs_stats, d, pars_true, posterior, log3)

# All pairwise marginals of the fitted posterior against the prior.
fig, _ = plot_pdf(
    posterior,
    pdf2=g.prior,
    lims=[-3, 3],
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=[
        'bias', 'gain', 'logit phase', 'log freq', 'logit angle',
        'log ratio', 'log width', 'xo', 'yo',
    ],
)

# A handful of example posterior draws, shown as contours.
contour_draws(posterior, g, obs_stats, d)


In [None]:
# Re-express prior and posterior as "transformed" distributions and plot their
# pairwise marginals.
# NOTE(review): flags/lower/upper presumably select a per-parameter transform
# and its bounds - confirm against delfi.distribution.
plot_prior = dd.TransformedNormal(
    m=g.prior.m,
    S=g.prior.S,
    flags=[0, 0, 2, 1, 2, 1, 1, 2, 2],
    lower=[0, 0, 0, 0, 0, 0, 0, -1, -1],
    upper=[0, 0, np.pi, 0, 2*np.pi, 0, 0, 1, 1],
)

# Means and covariances of the first n_components posterior components.
_comps = [posterior.xs[k] for k in range(posterior.n_components)]
plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
    ms=[comp.m for comp in _comps],
    Ss=[comp.S for comp in _comps],
    a=posterior.a,
    flags=[0, 0, 2, 1, 2, 1, 1, 2, 2],
    lower=[0, 0, 0, 0, 0, 0, 0, -1, -1],
    upper=[0, 0, np.pi, 0, 2*np.pi, 0, 0, 1, 1],
)

# Per-parameter plotting limits: one (low, high) row per parameter after the
# transpose.
_lims_low = [-2, -2, .001, 0, .001, 0, 0, -.999, -.999]
_lims_high = [2, 2, .999*np.pi, 3, 1.999*np.pi, 3, 3, .999, .999]
lims = np.array([_lims_low, _lims_high]).T

# The ground truth is mapped through the same transform as the posterior
# before it is handed to plot_pdf.
fig, _ = plot_pdf(
    plot_post,
    pdf2=plot_prior,
    lims=lims,
    gt=plot_post._f(pars_true.reshape(1, -1)).reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=[
        'bias', 'gain', 'logit phase', 'log freq', 'logit angle',
        'log ratio', 'log width', 'xo', 'yo',
    ],
)

#fig.savefig('quadro_posterior_run8_round3_comps4_0_5Hz_SNR12.pdf')

In [None]:
# Overlay contours of receptive fields drawn from the posterior, plus the
# ground-truth RF in red, on the observed statistics reshaped to d x d.
# Two contour levels per RF: half its minimum and half its maximum.
lvls = [0.5, 0.5]
p = posterior
n_draws = 10
plt.figure(figsize=(6, 6))
plt.imshow(obs_stats[0, :-1].reshape(d, d), interpolation='None', cmap='gray')
for i in range(n_draws):
    rfm = g.model.params_to_rf(p.gen().reshape(-1))[0]
    # No plt.hold(True) call: it was deprecated in Matplotlib 2.0 and removed
    # in 3.0 (AttributeError there). Successive plotting calls always add to
    # the current axes, so the overlay works without it.
    plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()])
plt.title('RF posterior draws')

# Ground-truth receptive field, drawn last and in red.
rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()], colors='r')

#plt.savefig('quadro_posterior_run8_round3_comps4_0_5Hz_SNR12_draws')
plt.show()


In [None]:
# Observed statistics (reshaped to d x d) with only the ground-truth RF
# contours on top; no posterior draws in this figure.
lvls = [0.5, 0.5]
p = posterior    # not used in this cell
n_draws = 10     # not used in this cell

# Ground-truth receptive field and its two contour levels (half the minimum,
# half the maximum).
rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
_gt_levels = [lvls[0]*rfm.min(), lvls[1]*rfm.max()]

plt.figure(figsize=(6, 6))
plt.imshow(
    obs_stats[0, :-1].reshape(d, d),
    interpolation='None',
    cmap='gray',
)
plt.title('STA + GT')
plt.contour(rfm, levels=_gt_levels, colors='r')

#plt.savefig('quadro_posterior_run8_round3_comps4_0_5Hz_SNR12_STA')

plt.show()

In [None]:
# Persist the round-3 results and, separately, just the network spec/weights.
round_ = 3
_stem = './results/SNPE/maprf_100k_elife_prior01_run_9_round{}_param9_nosvi_CDELFI'.format(round_)
filename1 = _stem + '.pkl'            # log, training data and posteriors
filename2 = _stem + '_res.pkl'        # defined but not written in this cell
filename4 = _stem + '_net_only.pkl'   # network spec + parameters only

io.save_pkl((log3, trn_data3, posteriors3), filename1)

net = inf.network
data = {
    'network.spec_dict': net.spec_dict,
    'network.params_dict': net.params_dict,
}
io.save_pkl(data, filename4)

In [None]:

# One pairwise-marginal figure per mixture component, each plotted against
# the generator's current proposal.
# NOTE(review): `posterior_unc` is not assigned in any visible cell of this
# notebook - it has to be defined before this cell can run.

# Plotting limits, one [low, high] pair per parameter.
lims = [
    [-0.5, 0],
    [-1.5, 1.5],
    [-4, 4],
    [0, 1],
    [-3, 3],
    [-1.5, 1.5],
    [-1, 2],
    [0, 0.7],
    [0, 1],
]

for i, post_comp in enumerate(posterior_unc.xs):
    fig, _ = plot_pdf(
        post_comp,
        pdf2=g.proposal,
        lims=lims,
        gt=pars_true.reshape(-1),
        figsize=(16, 16),
        resolution=100,
        ticks=True,
        labels_params=[
            'bias', 'gain', 'logit phase', 'log freq', 'logit angle',
            'log ratio', 'log width', 'xo', 'yo',
        ],
    )


# ... and fourth?

In [None]:
# Disabled experiment (kept for reference, not executed): drop the stored
# means/precisions entries with suffix 2 and 3, keep only the first two
# columns/entries of the mixture-weight layer, rebuild CDELFI with two
# components from the reduced parameter dict and use its prediction as the
# next proposal.
#params_dict = inf.network.params_dict
#params_dict.pop('means.mW2')
#params_dict.pop('means.mb2')
#params_dict.pop('means.mW3')
#params_dict.pop('means.mb3')
#
#params_dict.pop('precisions.mW2')
#params_dict.pop('precisions.mb2')
#params_dict.pop('precisions.mW3')
#params_dict.pop('precisions.mb3')
#
#params_dict['weights.mW'] = params_dict['weights.mW'][:,:2]
#params_dict['weights.mb'] = params_dict['weights.mb'][  :2]
#
#inf = infer.CDELFI(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=init_norm,
#                 pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
#                 n_components=2, rank=rank,
#                 n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d),
#                 filter_sizes=filter_sizes, pool_sizes=pool_sizes, n_inputs_hidden=n_inputs_hidden,verbose=True)
#
#inf.network.params_dict = params_dict
#inf.generator.proposal = inf.predict(obs_stats)

In [None]:

# Fourth training round.
if algo == 'CDELFI':

    # run SNPE-A for one round. n_train and epochs are reassigned here and
    # keep their new values for any cell executed afterwards.
    n_components4 = 4
    project_proposal = False
    stndrd_comps = False
    n_train = 100000
    epochs = 20

    log4, trn_data4, posteriors4 = inf.run(
        n_train=n_train,
        epochs=epochs,
        minibatch=minibatch,
        n_rounds=n_rounds,
        lr_decay=lr_decay,
        n_components=n_components4,
        stndrd_comps=stndrd_comps,
        project_proposal=project_proposal,
    )


elif algo == 'SNPE':

    # run SNPE-B for one round, with an explicit learning rate
    lr = 0.0001
    log4, trn_data4, posteriors4 = inf.run(
        n_train=n_train,
        epochs=epochs,
        minibatch=minibatch,
        n_rounds=n_rounds,
        lr_decay=lr_decay,
        lr=lr,
    )

    # Normalize the third element of the last training-data entry
    # (presumably the importance weights - confirm) and report
    # 1 / sum(w^2) as the effective sample size.
    iws = trn_data4[-1][2]
    iws = iws / iws.sum()
    ESS = 1. / np.sum(iws ** 2)
    print('ESS', ESS)

# Continue with the most recent posterior of this round.
posterior = posteriors4[-1]
posterior.ndim = posterior.xs[0].ndim

quick_plot(g, obs_stats, d, pars_true, posterior, log4)

# All pairwise marginals of the fitted posterior against the prior.
fig, _ = plot_pdf(
    posterior,
    pdf2=g.prior,
    lims=[-3, 3],
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=[
        'bias', 'gain', 'logit phase', 'log freq', 'logit angle',
        'log ratio', 'log width', 'xo', 'yo',
    ],
)

# A handful of example posterior draws, shown as contours.
contour_draws(posterior, g, obs_stats, d)


In [None]:
# Persist the round-4 results and, separately, just the network spec/weights.
round_ = 4
_stem = './results/SNPE/maprf_100k_elife_prior01_run_9_round{}_param9_nosvi_CDELFI'.format(round_)
filename1 = _stem + '.pkl'            # log, training data and posteriors
filename2 = _stem + '_res.pkl'        # defined but not written in this cell
filename4 = _stem + '_net_only.pkl'   # network spec + parameters only

io.save_pkl((log4, trn_data4, posteriors4), filename1)

net = inf.network
data = {
    'network.spec_dict': net.spec_dict,
    'network.params_dict': net.params_dict,
}
io.save_pkl(data, filename4)