# SNPE & RF

Learning receptive-field (RF) parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.utils import get_maprf_prior_01, setup_sim, setup_sampler, get_data_o, quick_plot, contour_draws

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats

In [None]:
# Settings shared by the cells below.
idx_cell = 3   # index of the toy cell whose parameter file is loaded
duration = 20  # forwarded to the simulator as its `duration` argument
seed = 42      # passed to the model, prior, generator and inference objects

In [None]:
## training data and true parameters, data, statistics

# sim_info.npy holds a pickled object wrapped in a 0-d object array.
# `allow_pickle=True` is required because NumPy >= 1.16.3 refuses to unpickle
# object arrays by default; `[()]` then unwraps the 0-d array into the object
# itself (indexed like a dict below).
# Only load files you trust: unpickling can execute arbitrary code.
sim_info = np.load('./results/sim_info.npy', allow_pickle=True)[()]
d, params_ls = sim_info['d'], sim_info['params_ls']

# simulator: filter shape is (d, d, 2), remaining settings come from sim_info
m = model(filter_shape=np.array((d, d, 2)),
          parametrization=sim_info['parametrization'],
          params_ls=params_ls,
          seed=seed,
          dt=sim_info['dt'],
          duration=duration)

# prior over the parameters listed in params_ls
p, prior = get_maprf_prior_01(params_ls, seed)

s = maprfStats(n_summary=d*d+1)  # summary stats (d x d RF + spike_count)

def rej(x):
    """Return a boolean row mask that is True where the last column is > 0.

    Per the comments in this notebook, the last column of the summary
    statistics is the spike count, so rows without any spike map to False.
    """
    spike_counts = x[:, -1]
    return spike_counts > 0

# generator object that auto-rejects some data-pairs (theta_i, x_i) right at sampling
g = dg.RejKernel(model=m, prior=p, summary=s, rej=rej, seed=seed)

# load cell, generate xo
filename = './results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
# The file holds a pickled object in a 0-d object array: NumPy >= 1.16.3
# needs allow_pickle=True to read it, and `[()]` unwraps the 0-d array.
# Only load files you trust: unpickling can execute arbitrary code.
params_dict_true = np.load(filename, allow_pickle=True)[()]
# reset the simulator's RNG so the observed data is reproducible for this seed
m.rng = np.random.RandomState(seed=seed)
m.params_dict = params_dict_true.copy()
pars_true = m.read_params_buffer()
# observed summary statistics from a single simulation at the true parameters
obs_stats = s.calc([m.gen_single()])

In [None]:
# Call contour_draws with the generator's prior, the generator itself and the
# observed summary statistics (d is the RF side length loaded from sim_info).
contour_draws(g.prior, g, obs_stats, d=d)

# SNPE-A version

In [None]:
# --- network architecture ---------------------------------------------------
# 8 layers in total [4x conv, 3x fully connected, 1x MoG], about 20k parameters
n_hiddens = [50, 50, 50]        # 3 fully connected layers
filter_sizes = [3, 3, 3, 2]     # 4 conv ReLU layers
n_filters = (16, 16, 32, 32)    # 16 to 32 filters
pool_sizes = [1, 2, 2, 1]

# Extra input for CNN architectures: one value is passed directly to the
# hidden layers, bypassing the conv layers. Here it is the number of spikes,
# which allows normalizing the STAs and so helps the conv layers; without it
# the RF gain could no longer be recovered.
n_inputs_hidden = 1

# --- simulation budget and posterior family ---------------------------------
n_train = 50000     # N = 50k per round
n_rounds = 1        # single round (the first round is always 'amortized' and
                    # can be reused for any other STA covered by the prior)
n_components = 1    # single component: the posterior at most STAs is well
                    # approximated by one Gaussian (more SNPE-A runs planned)

# --- learning schedule ------------------------------------------------------
minibatch = 50
epochs = 50
lr_decay = 0.99

# --- regularization and normalization ---------------------------------------
reg_lambda = 0.       # just to make doubly sure SVI is switched off
svi = False           # with large N this should make no difference anyway
pilot_samples = 1000  # z-scoring only applies to the extra inputs (here: firing
                      # rate) fed directly to the fully connected layers
init_norm = False     # open question: how to normalize the initialization
                      # through conv and ReLU layers
prior_norm = True     # doesn't hurt

# The constructor is given n_components=1 as a literal (not the variable).
inf = infer.CDELFI(
    generator=g,
    obs=obs_stats,
    prior_norm=prior_norm,
    init_norm=init_norm,
    pilot_samples=pilot_samples,
    seed=seed,
    reg_lambda=reg_lambda,
    svi=svi,
    n_components=1,
    n_hiddens=n_hiddens,
    n_filters=n_filters,
    n_inputs=(1, d, d),
    filter_sizes=filter_sizes,
    pool_sizes=pool_sizes,
    n_inputs_hidden=n_inputs_hidden,
    verbose=True)

# Report parameter shapes and counts for entries 1, 3, ..., 15 of
# inf.network.aps (taken to be the weights, not the biases).
def get_shape(i):
    """Return the shape of the value stored in the i-th entry of inf.network.aps."""
    shared_param = inf.network.aps[i]
    return shared_param.get_value().shape


print([get_shape(i) for i in range(1, 17, 2)])
print([np.prod(get_shape(i)) for i in range(1, 17, 2)])

# run SNPE-A for n_rounds round(s)
log, trn_data, posteriors = inf.run(
    n_train=n_train,
    epochs=epochs,
    minibatch=minibatch,
    n_rounds=n_rounds,
    lr_decay=lr_decay,
    n_components=n_components)


# SNPE-B version

In [None]:
# --- network architecture ---------------------------------------------------
# 8 layers in total [4x conv, 3x fully connected, 1x MoG], about 20k parameters
n_hiddens = [50, 50, 50]        # 3 fully connected layers
filter_sizes = [3, 3, 3, 2]     # 4 conv ReLU layers
n_filters = (16, 16, 32, 32)    # 16 to 32 filters
pool_sizes = [1, 2, 2, 1]

# Extra input for CNN architectures: one value is passed directly to the
# hidden layers, bypassing the conv layers. Here it is the number of spikes,
# which allows normalizing the STAs and so helps the conv layers; without it
# the RF gain could no longer be recovered.
n_inputs_hidden = 1

# --- simulation budget and posterior family ---------------------------------
n_train = 10000     # N = 10k per round
n_rounds = 1        # single round (the first round is always 'amortized' and
                    # can be reused for any other STA covered by the prior)
n_components = 4    # mixture of Gaussians with this many components

# --- learning schedule ------------------------------------------------------
minibatch = 50
epochs = 100
lr_decay = 0.99

# --- regularization and normalization ---------------------------------------
reg_lambda = 0.       # just to make doubly sure SVI is switched off
svi = False           # with large N this should make no difference anyway
pilot_samples = 1000  # z-scoring only applies to the extra inputs (here: firing
                      # rate) fed directly to the fully connected layers
init_norm = False     # open question: how to normalize the initialization
                      # through conv and ReLU layers
prior_norm = True     # doesn't hurt

inf = infer.SNPE(
    generator=g,
    obs=obs_stats,
    prior_norm=prior_norm,
    init_norm=init_norm,
    pilot_samples=pilot_samples,
    seed=seed,
    reg_lambda=reg_lambda,
    svi=svi,
    n_components=n_components,
    n_hiddens=n_hiddens,
    n_filters=n_filters,
    n_inputs=(1, d, d),
    filter_sizes=filter_sizes,
    pool_sizes=pool_sizes,
    n_inputs_hidden=n_inputs_hidden,
    verbose=True)

# Report parameter shapes and counts for entries 1, 3, ..., 15 of
# inf.network.aps (taken to be the weights, not the biases).
def get_shape(i):
    """Return the shape of the value stored in the i-th entry of inf.network.aps."""
    shared_param = inf.network.aps[i]
    return shared_param.get_value().shape


print([get_shape(i) for i in range(1, 17, 2)])
print([np.prod(get_shape(i)) for i in range(1, 17, 2)])

# run SNPE-B for n_rounds round(s)
log, trn_data, posteriors = inf.run(
    n_train=n_train,
    epochs=epochs,
    minibatch=minibatch,
    n_rounds=n_rounds,
    lr_decay=lr_decay)


# round #1

In [None]:
n_samples = 20000

# Load previously stored MCMC samples for comparison with the SNPE posterior.
savefile = './results/MCMC/elife/maprf_MCMC_prior01_run_1_'+ str(n_samples)+'samples_param7'
# The archive stores a pickled dict as 'arr_0' (a 0-d object array).
# NumPy >= 1.16.3 needs allow_pickle=True to read it; `[()]` unwraps the array.
# Only load files you trust: unpickling can execute arbitrary code.
tmp = np.load(savefile + '.npz', allow_pickle=True)['arr_0'][()]
T = tmp['T']
# One column per parameter: atleast_2d(...T).T turns each 1-D trace into (n, 1).
samples = np.hstack([np.atleast_2d(T[key].T).T for key in ['bias','A','logit_φ','log_f','logit_θ','log_γ','log_b']])

# latest posterior against the prior, with MCMC samples and ground truth
fig, _ = plot_pdf(posteriors[-1], pdf2=g.prior, lims=[-3,3], samples=samples.T,
                  gt=pars_true.reshape(-1), figsize=(16,16), resolution=100,
                  labels_params=['bias', 'gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width'])


# round #2

In [None]:
# Second call to inf.run, now requesting 4 mixture components.
n_components = 4
log, trn_data, posteriors = inf.run(
    n_train=n_train,
    epochs=epochs,
    minibatch=minibatch,
    n_rounds=n_rounds,
    lr_decay=lr_decay,
    n_components=n_components)

In [None]:
# Plot the 'loss' entry of the most recent item in `log` twice, side by side.
plt.subplot(1,2,1)
plt.semilogx(log[-1]['loss'])  # left panel: logarithmic x-axis
plt.subplot(1,2,2)
plt.plot(log[-1]['loss'])  # right panel: linear axes
plt.show()

In [None]:
# Latest posterior vs. prior, with MCMC samples and ground-truth parameters
# (finer grid: resolution=200).
fig, _ = plot_pdf(
    posteriors[-1],
    pdf2=g.prior,
    lims=[-3, 3],
    samples=samples.T,
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=200,
    labels_params=[
        'bias',
        'gain',
        'logit phase',
        'log freq',
        'logit angle',
        'log ratio',
        'log width',
    ])
# uncomment to save the figure:
# fig.savefig('quadro_posterior_2rounds_CDELFI_110k_total.pdf')

In [None]:
# Latest posterior with MCMC samples and ground truth, without the prior.
fig, _ = plot_pdf(
    posteriors[-1],
    lims=[-3, 3],
    samples=samples.T,
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=[
        'bias',
        'gain',
        'logit phase',
        'log freq',
        'logit angle',
        'log ratio',
        'log width',
    ])
# uncomment to save the figure:
# fig.savefig('quadro_posterior_2rounds_CDELFI_110k_total_noPrior.pdf')

In [None]:
# Latest posterior with ground truth only (no prior, no MCMC samples).
fig, _ = plot_pdf(
    posteriors[-1],
    lims=[-3, 3],
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    labels_params=[
        'bias',
        'gain',
        'logit phase',
        'log freq',
        'logit angle',
        'log ratio',
        'log width',
    ])
# uncomment to save the figure:
# fig.savefig('quadro_posterior_2rounds_CDELFI_110k_total_noSamples.pdf')

In [None]:
# Same call as for the prior, now with the latest posterior. Use the
# notebook's `d` instead of a hard-coded 21 so the size stays consistent with
# obs_stats, whose length is derived from d (d*d + 1 summary statistics).
contour_draws(posteriors[-1], g, obs_stats, d=d)

In [None]:

# Check whether the network object has an `extra_stats` attribute; as the
# last expression of the cell, the boolean is displayed by the notebook.
hasattr(inf.network, 'extra_stats')

# save & load

In [None]:
import delfi.utils.io as io

# Make sure `inf.observables` exists before the inference object is saved.
# Only a missing attribute is handled here; a bare `except:` would also
# swallow unrelated errors (including KeyboardInterrupt).
if not hasattr(inf, 'observables'):
    inf.observables = []

filename1 = './results/SNPE/elife/maprf_10k_amortized_prior01_run_1_round2_param7_nosvi_CDELFI.pkl'
filename2 = './results/SNPE/elife/maprf_10k_amortized_prior01_run_1_round2_param7_nosvi_CDELFI_res.pkl'
filename3 = './results/SNPE/elife/maprf_10k_amortized_prior01_run_1_round2_param7_nosvi_CDELFI_conf.pkl'
filename4 = './results/SNPE/elife/maprf_10k_amortized_prior01_run_1_round2_param7_nosvi_CDELFI_net_only.pkl'

# training log, training data and posteriors
io.save_pkl((log, trn_data, posteriors), filename1)
# NOTE: np.save appends '.npy' when the name does not already end in it, so
# the true parameters are written to filename3 + '.npy'.
np.save(filename3, params_dict_true)
# the network on its own
io.save_pkl(inf.network, filename4)


In [None]:
# Remove the simulator from the generator before saving the inference object
# (presumably because the model cannot be serialized -- TODO confirm).
# NOTE(review): `g` was passed as the generator; if `inf.generator` is that
# same object, g.model is None after this cell as well.
inf.generator.model = None
io.save(inf, filename2)  # filename2 is defined in the preceding save cell

# compare with maprf sampling

In [None]:
# MCMC chain length (including burnin)
n_samples = 20000

# NOTE: this rebinds the notebook-level g, prior and d.
g, prior, d = setup_sim(seed, path='.')

filename = './results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
# The file holds a pickled object in a 0-d object array: NumPy >= 1.16.3
# needs allow_pickle=True to read it, and `[()]` unwraps the 0-d array.
# Only load files you trust: unpickling can execute arbitrary code.
params_dict_true = np.load(filename, allow_pickle=True)[()]

# Rebuild the simulator; sim_info and params_ls come from the setup cell above.
m = model(filter_shape=np.array((d, d, 2)),
          parametrization=sim_info['parametrization'],
          params_ls=params_ls,
          seed=seed,
          dt=sim_info['dt'],
          duration=duration)
m.params_dict = params_dict_true.copy()
# reset the RNG so the observation is reproducible for this seed
m.rng = np.random.RandomState(seed=seed)
obs = m.gen_single()

inference, data = setup_sampler(prior, obs, d, g, params_dict=params_dict_true,
                      fix_position=True, parametrization='logit_φ')


In [None]:
# Step 1: sample the RF parameters, then re-key the result by variable name.
print('- sampling RF params')
T, L = inference.sample(n_samples)
T = {variable.name: trace for variable, trace in T.items()}

# Step 2: sample the Poisson parameters given the RF samples
# (the trailing semicolon suppresses the notebook's output of the call).
print('- sampling Poisson params')
inference.sample_biases(data, T, g.model.dt);


In [None]:
# Store the MCMC samples together with the true parameters.
# np.savez appends '.npz' to the name and stores the positional dict under
# the key 'arr_0', which is how the loading cells read it back.
savefile = './results/MCMC/elife/maprf_MCMC_prior01_run_1_'+ str(n_samples)+'samples_param7'
np.savez(savefile, {'T' : T, 'params_dict_true' : params_dict_true})    

# inspect results

In [None]:
# Reload the stored MCMC samples.
savefile = './results/MCMC/elife/maprf_MCMC_prior01_run_1_'+ str(n_samples)+'samples_param7'
# The archive stores a pickled dict as 'arr_0' (a 0-d object array).
# NumPy >= 1.16.3 needs allow_pickle=True to read it; `[()]` unwraps the array.
# Only load files you trust: unpickling can execute arbitrary code.
tmp = np.load(savefile + '.npz', allow_pickle=True)['arr_0'][()]
T = tmp['T']
# One column per parameter: atleast_2d(...T).T turns each 1-D trace into (n, 1).
samples = np.hstack([np.atleast_2d(T[key].T).T for key in ['bias','A','logit_φ','log_f','logit_θ','log_γ','log_b']])


In [None]:

burnin = 50  # number of initial samples discarded from every trace

# Histogram plus mean/variance for each candidate parameter name; names that
# are not present in T are skipped.
for key in ['bias', 'λo',
            'gain', 'log_A', 'phase', 'logit_φ',
            'angle', 'logit_θ', 'freq', 'log_f',
            'ratio', 'width', 'log_γ', 'log_b',
            ]:

    if key not in T:
        continue

    x = T[key][burnin:]
    # density=True replaces `normed=True`, which was removed from
    # matplotlib's hist; the histogram is still normalized to unit area.
    plt.hist(x, bins=np.linspace(x.min(), x.max(), 20), alpha=0.5, density=True)
    plt.title(key)
    plt.show()
    print('mean:', x.mean())
    print('var:', x.var())
        

In [None]:
# Latest SNPE posterior vs. prior, with the MCMC samples and ground truth.
fig, _ = plot_pdf(
    posteriors[-1],
    pdf2=g.prior,
    lims=[-3, 3],
    gt=pars_true.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    samples=samples.T,
    labels_params=[
        'bias',
        'log gain',
        'logit phase',
        'log freq',
        'logit angle',
        'log ratio',
        'log width',
    ])
