# SNPE & RF

Learning receptive field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.utils import get_maprf_prior_01, setup_sim, setup_sampler, get_data_o, quick_plot, contour_draws

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats

In [None]:
# Fix the random seed and build the simulation setup in one call.
# setup_sim returns the generator `g` (later handed to the inference object), a prior
# object and `d`; later cells use `d` as the side length when reshaping flattened
# images to (d, d).
seed = 42
g, prior, d = setup_sim(seed, path='.')

In [None]:
## training data and true parameters, data, statistics

idx_cell = 3  # index of the toy cell to load

filename = './results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
# get_data_o yields the observed summary statistics and the ground-truth parameter vector
obs_stats, pars_true = get_data_o(filename, g, seed)

# the first element returned by params_to_rf is displayed as the RF image
# (the trailing commented-out slice would crop it to a central window)
rf = g.model.params_to_rf(pars_true)[0]#[10:31, 10:31]

plt.imshow(rf, interpolation='None')
plt.show()

# Last expression of the cell, shown as its output: RF energy divided by the squared
# second parameter (labelled 'log gain' in the plots below — TODO confirm this is the
# intended normalisation).
np.sum(rf**2) / pars_true[1]**2

In [None]:
# Same helper that is later called with the fitted posterior; here it is fed the prior,
# so the output serves as a prior reference (presumably RF contours of prior draws
# against the observed statistics — TODO confirm against contour_draws).
contour_draws(g.prior, g, obs_stats, d=d)

In [None]:
# network architecture: 5 conv layers, 3 fully connected layers and a MoG output layer
# (original note: about 20k parameters in total)

filter_sizes=[3,3,3,3,2]   # 5 conv ReLU layers
n_filters=(16,16,32,32,32) # 16 to 32 filters
pool_sizes=[1,2,2,2,2]     # one entry per conv layer (a size of 1 presumably means no pooling)
n_hiddens=[50,50,50]       # 3 fully connected layers

# N = 100k per round

n_train=100000

# single component (posterior at most STAs is well-approximated by a single Gaussian - we also want to run more SNPE-A)

n_components=1

# single round (the first round is always 'amortized' and can be used with any other STA covered by the prior)

n_rounds=1

# new feature for CNN architectures: passing a value directly to the hidden layers (bypassing the conv layers).
# In this case, we pass the number of spikes (single number) directly, which allows to normalize the STAs
# and hence help out the conv layers. Without that extra input, we couldn't recover the RF gain anymore.
n_inputs_hidden = 1

# some learning-schedule parameters
lr_decay = 0.99
epochs=10
minibatch=50

svi=False          # large N should make this do nothing anyways
reg_lambda=0.      # just to make doubly sure SVI is switched off...

pilot_samples=1000 # z-scoring only applies to extra inputs (here: firing rate) directly fed to fully connected layers

prior_norm = True  # doesn't hurt.
init_norm = False  # didn't yet figure how to best normalize initialization through conv- and ReLU- layers

# n_components is passed through from the variable above (it used to be hard-coded to 1 here),
# so the network and the later inf.run(...) call cannot silently disagree.
inf = infer.CDELFI(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=init_norm,
                 pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
                 n_components=n_components, n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d),
                 filter_sizes=filter_sizes, pool_sizes=pool_sizes, n_inputs_hidden=n_inputs_hidden,verbose=True)

# Report the shape and the number of entries of every second parameter array of the
# network (per the original note: the weights, not the biases).
def get_shape(i):
    """Return the shape of the value stored in the i-th entry of ``inf.network.aps``."""
    param = inf.network.aps[i]
    return param.get_value().shape

layer_shapes = [get_shape(idx) for idx in range(1,17,2)]
print(layer_shapes)
print([np.prod(shape) for shape in layer_shapes])

# Train with SNPE-A for `n_rounds` round(s); keeps the training log, the training
# data and the per-round posteriors.
log, trn_data, posteriors = inf.run(
    n_train=n_train,
    n_rounds=n_rounds,
    epochs=epochs,
    minibatch=minibatch,
    lr_decay=lr_decay,
    n_components=n_components,
)


In [None]:
# posterior estimate of the trained network at the observed summary statistics
posterior = inf.predict(obs_stats)
# copy `ndim` from the first entry of posterior.xs onto the posterior object itself
# (presumably needed by the plotting helpers below — TODO confirm)
posterior.ndim = posterior.xs[0].ndim

# summary plot using the training log of the first run
quick_plot(g, obs_stats, d, pars_true, posterior, log)

# all pairwise marginals of fitted posterior (prior passed as second density, true parameters as gt)
fig, _ = plot_pdf(posterior, pdf2=g.prior, lims=[-3,3], gt=pars_true.reshape(-1), figsize=(16,16), resolution=100,
                  labels_params=['bias', 'log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width', 'xo', 'yo'])
# NOTE(review): a later cell saves to the same file name 'res.pdf' and overwrites this figure
fig.savefig('res.pdf')

# bunch of example posterior draw contours
contour_draws(posterior, g, obs_stats, d)

In [None]:
# Show the observed statistics (left) next to the true RF (right) and overlay the
# contours of RFs generated from `n_draws` posterior samples.
n_draws = 10
plt.figure(figsize=(6,6))
plt.imshow(np.hstack((
    obs_stats[0,:-1].reshape(d,d),   # all but the last entry of the statistics, reshaped to (d, d)
    g.model.params_to_rf(pars_true.reshape(-1))[0])),
    interpolation='None', cmap='gray')
# contour levels as fractions of each sampled RF's minimum and maximum
lvls=[0.5, 0.5]
for i in range(n_draws):
    rfm = g.model.params_to_rf(posterior.gen(1).reshape(-1))[0]
    # Successive contour calls draw into the same axes by default; the former
    # plt.hold(True) call was removed from matplotlib (3.0) and would raise.
    plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()])
plt.title('RF posterior draws')
plt.savefig('posterior_location_round1_example.pdf')
plt.show()

In [None]:
# Second call to inf.run with the same settings on the same inference object; results
# are kept separately as log2 / trn_data2 / posteriors2.
# (presumably this continues from the state reached by the first run — TODO confirm)
log2, trn_data2, posteriors2 = inf.run(n_train=n_train, epochs=epochs, minibatch=minibatch, n_rounds=n_rounds,  
               lr_decay=lr_decay,n_components=n_components)

In [None]:
# posterior returned by the last round of the second run
posterior = posteriors2[-1]
# copy `ndim` from the first entry of posterior.xs onto the posterior object itself
# (presumably needed by the plotting helpers below — TODO confirm)
posterior.ndim = posterior.xs[0].ndim

# summary plot using the training log of the second run
quick_plot(g, obs_stats, d, pars_true, posterior, log2)

# All pairwise marginals of the fitted posterior. The label list includes the RF
# location ('xo', 'yo') so that it matches the labels used for the first-run
# posterior of the same inference object.
fig, _ = plot_pdf(posterior, pdf2=g.prior, lims=[-3,3], gt=pars_true.reshape(-1), figsize=(16,16), resolution=100,
                  labels_params=['bias', 'log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width', 'xo', 'yo'])
# NOTE(review): same file name as the first-run figure, which is overwritten here
fig.savefig('res.pdf')

# bunch of example posterior draw contours
contour_draws(posterior, g, obs_stats, d)

# test toy cells (check for 'amortization')

# 0.5 Hz, SNR -15

In [None]:
# gain (a) and bias (b) together define firing rate and SNR
a,b = 0.504180511342426, -0.8202461745687

# NOTE(review): `m`, `s` and `p` are not defined in this cell. `m = g.model` is assigned in
# the 'compare with maprf sampling' section; `s` and `p` are presumably the summary-statistics
# object and a prior — confirm they exist before running this cell.
for j in range(3):

    print('\n')
    print('cell #' + str(j+1))
    print('\n')

    # The file holds a pickled object inside a 0-d array: allow_pickle is required on
    # NumPy >= 1.16.3, and [()] unwraps the stored object.
    params_dict_test = np.load('./results/toy_cells/toy_cell_' + str(j+1) + '.npy',
                               allow_pickle=True)[()]

    m.params_dict = params_dict_test.copy()

    # NOTE(review): copy() is shallow, so if m.params_dict simply keeps the copied dict,
    # these writes to the nested dicts are visible to the model as well — confirm intended.
    params_dict_test['kernel']['gain'] = a
    params_dict_test['glm']['bias'] = b

    pars_test = m.read_params_buffer()
    stats = s.calc([m.gen_single(pars_test)])

    # Condition on the very statistics that are displayed below. Previously a second,
    # independent simulation was generated here, so plot and posterior did not match.
    post_test = inf.predict_uncorrected(stats)

    fig, _ = plot_pdf(post_test, pdf2=p, lims=[-3,3], gt=pars_test.reshape(-1), figsize=(12,12), resolution=100,
                  labels_params=['bias', 'log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width'])

    # top row: RF of the test parameters, RF at the mean of the first posterior component,
    # and the simulated statistics reshaped to (d, d)
    plt.figure(figsize=(6,6))
    plt.subplot(2,3,1)
    plt.imshow(m.params_to_rf(pars_test)[0], interpolation='None')
    plt.subplot(2,3,2)
    plt.imshow(m.params_to_rf(post_test.xs[0].m.reshape(-1))[0], interpolation='None')
    plt.subplot(2,3,3)
    plt.imshow(stats[0,:-1].reshape(d,d), interpolation='None')

    # bottom two rows of a 4x4 grid: RFs of eight posterior samples
    for i in range(8):
        plt.subplot(4,4,9+i)
        plt.imshow(m.params_to_rf(post_test.gen().reshape(-1))[0], interpolation='None')

    plt.show()


# 2 Hz, SNR -15

In [None]:
# gain (a) and bias (b) together define firing rate and SNR
a,b = 0.422120643630379, 0.604054261670482

# NOTE(review): `m`, `s` and `p` are not defined in this cell. `m = g.model` is assigned in
# the 'compare with maprf sampling' section; `s` and `p` are presumably the summary-statistics
# object and a prior — confirm they exist before running this cell.
for j in range(3):

    print('\n')
    print('cell #' + str(j+1))
    print('\n')

    # The file holds a pickled object inside a 0-d array: allow_pickle is required on
    # NumPy >= 1.16.3, and [()] unwraps the stored object.
    params_dict_test = np.load('./results/toy_cells/toy_cell_' + str(j+1) + '.npy',
                               allow_pickle=True)[()]

    m.params_dict = params_dict_test.copy()

    # NOTE(review): copy() is shallow, so if m.params_dict simply keeps the copied dict,
    # these writes to the nested dicts are visible to the model as well — confirm intended.
    params_dict_test['kernel']['gain'] = a
    params_dict_test['glm']['bias'] = b

    pars_test = m.read_params_buffer()
    stats = s.calc([m.gen_single(pars_test)])

    # Condition on the very statistics that are displayed below. Previously a second,
    # independent simulation was generated here, so plot and posterior did not match.
    post_test = inf.predict_uncorrected(stats)

    fig, _ = plot_pdf(post_test, pdf2=p, lims=[-3,3], gt=pars_test.reshape(-1), figsize=(12,12), resolution=100,
                  labels_params=['bias', 'log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width'])

    # top row: RF of the test parameters, RF at the mean of the first posterior component,
    # and the simulated statistics reshaped to (d, d)
    plt.figure(figsize=(6,6))
    plt.subplot(2,3,1)
    plt.imshow(m.params_to_rf(pars_test)[0], interpolation='None')
    plt.subplot(2,3,2)
    plt.imshow(m.params_to_rf(post_test.xs[0].m.reshape(-1))[0], interpolation='None')
    plt.subplot(2,3,3)
    plt.imshow(stats[0,:-1].reshape(d,d), interpolation='None')

    # bottom two rows of a 4x4 grid: RFs of eight posterior samples
    for i in range(8):
        plt.subplot(4,4,9+i)
        plt.imshow(m.params_to_rf(post_test.gen().reshape(-1))[0], interpolation='None')

    plt.show()


# 5 Hz, SNR -15

In [None]:
# gain (a) and bias (b) together define firing rate and SNR
a,b = 0.357703095858336, 1.54546216004078

# NOTE(review): `m`, `s` and `p` are not defined in this cell. `m = g.model` is assigned in
# the 'compare with maprf sampling' section; `s` and `p` are presumably the summary-statistics
# object and a prior — confirm they exist before running this cell.
for j in range(3):

    print('\n')
    print('cell #' + str(j+1))
    print('\n')

    # The file holds a pickled object inside a 0-d array: allow_pickle is required on
    # NumPy >= 1.16.3, and [()] unwraps the stored object.
    params_dict_test = np.load('./results/toy_cells/toy_cell_' + str(j+1) + '.npy',
                               allow_pickle=True)[()]

    m.params_dict = params_dict_test.copy()

    # NOTE(review): copy() is shallow, so if m.params_dict simply keeps the copied dict,
    # these writes to the nested dicts are visible to the model as well — confirm intended.
    params_dict_test['kernel']['gain'] = a
    params_dict_test['glm']['bias'] = b

    pars_test = m.read_params_buffer()
    stats = s.calc([m.gen_single(pars_test)])

    # Condition on the very statistics that are displayed below. Previously a second,
    # independent simulation was generated here, so plot and posterior did not match.
    post_test = inf.predict_uncorrected(stats)

    fig, _ = plot_pdf(post_test, pdf2=p, lims=[-3,3], gt=pars_test.reshape(-1), figsize=(12,12), resolution=100,
                  labels_params=['bias', 'log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width'])

    # top row: RF of the test parameters, RF at the mean of the first posterior component,
    # and the simulated statistics reshaped to (d, d)
    plt.figure(figsize=(6,6))
    plt.subplot(2,3,1)
    plt.imshow(m.params_to_rf(pars_test)[0], interpolation='None')
    plt.subplot(2,3,2)
    plt.imshow(m.params_to_rf(post_test.xs[0].m.reshape(-1))[0], interpolation='None')
    plt.subplot(2,3,3)
    plt.imshow(stats[0,:-1].reshape(d,d), interpolation='None')

    # bottom two rows of a 4x4 grid: RFs of eight posterior samples
    for i in range(8):
        plt.subplot(4,4,9+i)
        plt.imshow(m.params_to_rf(post_test.gen().reshape(-1))[0], interpolation='None')

    plt.show()


# 1 Hz, SNR $\sim$ 1/8

In [None]:
# gain (a) and bias (b) together define firing rate and SNR
a,b = 0.852551611364445, -0.363422125020056

# NOTE(review): `m`, `s` and `p` are not defined in this cell. `m = g.model` is assigned in
# the 'compare with maprf sampling' section; `s` and `p` are presumably the summary-statistics
# object and a prior — confirm they exist before running this cell.
for j in range(3):

    print('\n')
    print('cell #' + str(j+1))
    print('\n')

    # The file holds a pickled object inside a 0-d array: allow_pickle is required on
    # NumPy >= 1.16.3, and [()] unwraps the stored object.
    params_dict_test = np.load('./results/toy_cells/toy_cell_' + str(j+1) + '.npy',
                               allow_pickle=True)[()]

    m.params_dict = params_dict_test.copy()

    # NOTE(review): copy() is shallow, so if m.params_dict simply keeps the copied dict,
    # these writes to the nested dicts are visible to the model as well — confirm intended.
    params_dict_test['kernel']['gain'] = a
    params_dict_test['glm']['bias'] = b

    pars_test = m.read_params_buffer()
    stats = s.calc([m.gen_single(pars_test)])

    # Condition on the very statistics that are displayed below. Previously a second,
    # independent simulation was generated here, so plot and posterior did not match.
    post_test = inf.predict_uncorrected(stats)

    fig, _ = plot_pdf(post_test, pdf2=p, lims=[-3,3], gt=pars_test.reshape(-1), figsize=(12,12), resolution=100,
                  labels_params=['bias', 'log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width'])

    # top row: RF of the test parameters, RF at the mean of the first posterior component,
    # and the simulated statistics reshaped to (d, d)
    plt.figure(figsize=(6,6))
    plt.subplot(2,3,1)
    plt.imshow(m.params_to_rf(pars_test)[0], interpolation='None')
    plt.subplot(2,3,2)
    plt.imshow(m.params_to_rf(post_test.xs[0].m.reshape(-1))[0], interpolation='None')
    plt.subplot(2,3,3)
    plt.imshow(stats[0,:-1].reshape(d,d), interpolation='None')

    # bottom two rows of a 4x4 grid: RFs of eight posterior samples
    for i in range(8):
        plt.subplot(4,4,9+i)
        plt.imshow(m.params_to_rf(post_test.gen().reshape(-1))[0], interpolation='None')

    plt.show()


# save & load

In [None]:
import delfi.utils.io as io

# Make sure `inf.observables` exists before the object is saved. Only a missing
# attribute is handled; any other error is no longer silently swallowed.
if not hasattr(inf, 'observables'):
    inf.observables = []

filename1 = './results/SNPE/maprf_100k_amortized_prior01_run_1_round1_param9_nosvi_base.pkl'
filename2 = './results/SNPE/maprf_100k_amortized_prior01_run_1_round1_param9_nosvi_base_res.pkl'
filename3 = './results/SNPE/maprf_100k_amortized_prior01_run_1_round1_param9_nosvi_base_conf.pkl'
filename4 = './results/SNPE/maprf_100k_amortized_prior01_run_1_round1_param9_nosvi_base_net_only.pkl'

# training log, training data and the current `posterior` object
io.save_pkl((log, trn_data, posterior),filename1)
# NOTE(review): `params_dict_true` is assigned in the 'compare with maprf sampling' section,
# so that cell has to be run first. np.save appends '.npy' to a name without that
# extension, i.e. the file on disk is '<filename3>.npy'.
np.save(filename3, params_dict_true)
io.save_pkl(inf.network, filename4)

# The model is detached only while the inference object is written (presumably because
# it cannot be serialised — TODO confirm) and re-attached afterwards, so that later
# cells can keep using the generator's model.
_model = inf.generator.model
inf.generator.model = None
try:
    io.save(inf, filename2)
finally:
    inf.generator.model = _model


In [None]:
# Load stored sampler output: 'arr_0' wraps a pickled dict whose entry 'T' maps parameter
# names to sample arrays. allow_pickle is required on NumPy >= 1.16.3.
T = np.load('posterior_samples_20.npz', allow_pickle=True)['arr_0'].tolist()['T']
# one column per parameter (assumes every T[key] is a 1-D array of samples — TODO confirm)
samples = np.hstack([np.atleast_2d(T[key].T).T for key in ['log_A', 'logit_φ', 'log_f','logit_θ','log_γ','log_b']])

# NOTE(review): `p` is not defined in the visible part of the notebook; presumably a prior
# over these six parameters — confirm it exists before running this cell.
# The trailing ';' suppresses the notebook's echo of the return value.
plot_pdf(posterior, pdf2=p, lims=[-3,3], gt=pars_true.reshape(-1), figsize=(16,16), resolution=100, samples=samples.T,
         ticks=True, labels_params=['log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width']);


# check prior

In [None]:
# Load stored prior samples once: 'arr_0' wraps a pickled dict whose entry 'T' maps
# parameter names to sample arrays. allow_pickle is required on NumPy >= 1.16.3.
# Both plots below use this same dict (the file used to be read a second time).
T = np.load('prior_samples.npz', allow_pickle=True)['arr_0'].tolist()['T']

# NOTE(review): `p` is not defined in the visible part of the notebook; presumably a prior
# object — confirm it exists before running this cell.

# 1) samples in the transformed (log / logit) parametrization
samples = np.hstack([np.atleast_2d(T[key].T).T for key in ['log_A', 'logit_φ', 'log_f','logit_θ','log_γ','log_b']])

plot_pdf(p, lims=[-3,3], gt=pars_true.reshape(-1), figsize=(16,16), resolution=100, samples=samples.T,
         ticks=True, labels_params=['log_A', 'logit_φ', 'log_f', 'logit_θ', 'log ratio', 'log width']);


# 2) samples in the raw parametrization, with the raw true kernel parameters as ground truth
samples = np.hstack([np.atleast_2d(T[key].T).T for key in ['gain', 'phase', 'freq','angle','ratio','width']])

ks = params_dict_true['kernel']['s']
pars_raw = np.array([ ks['gain'], ks['phase'], ks['freq'], ks['angle'], ks['ratio'], ks['width'] ])

plot_pdf(p, lims=[-3,3], gt=pars_raw.reshape(-1), figsize=(16,16), resolution=100, samples=samples.T,
         ticks=True, labels_params=['gain', 'phase', 'freq', 'angle', 'ratio', 'width']);


# compare with maprf sampling

In [None]:
# Reload the parameters of the selected toy cell and regenerate its data with a freshly
# seeded model RNG.
filename = './results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
# The file holds a pickled object inside a 0-d array: allow_pickle is required on
# NumPy >= 1.16.3, and [()] unwraps the stored object.
params_dict_true = np.load(filename, allow_pickle=True)[()]

m = g.model
m.params_dict = params_dict_true.copy()
m.rng = np.random.RandomState(seed=seed)

# true parameter vector and one simulated observation, then its summary statistics
pars_true, obs = m.read_params_buffer(), m.gen_single()
obs_stats = g.summary.calc([obs])

In [None]:
# Build the maprf sampler for the simulated observation `obs`. The RF position is not
# fixed (fix_position=False) and the 'logit_φ' parametrization is selected.
# Returns the sampler object (`inference`) and the data it operates on.
inference, data = setup_sampler(prior, obs, d, g, params_dict=params_dict_true, 
                          fix_position=False, parametrization='logit_φ')

In [None]:
n_samples = 10000

# draw samples; the returned dict is re-keyed by the `.name` of its original keys
T, L = inference.sample(n_samples)
T = {k.name: t for k, t in T.items()}

x,y = T['xo'],T['yo']

# trace and normalised histogram of xo, skipping the first 500 samples (presumably burn-in)
plt.figure(figsize=(15, 4))
plt.subplot(121)
plt.plot(x[500:])

plt.subplot(122)
# `density=True` replaces the `normed=True` argument that was removed in matplotlib 3.1
plt.hist(x[500:], alpha=0.5, density=True)
plt.show()

# scatter of the sampled (xo, yo) pairs
plt.figure(figsize=(15, 4))
plt.subplot(121)
plt.plot(x[500:], y[500:], '.k', alpha=0.1)
plt.show()


In [None]:
# Sample the bias terms for the existing samples in T.
# (presumably this adds the 'bias' and 'λo' entries to T in place, since both are read
# from T right below — TODO confirm)
inference.sample_biases(data, T, m.dt)

# traces of 'bias' (top) and 'λo' (bottom), with their mean and variance printed
plt.figure(figsize=(12,5))
plt.subplot(2,1,1)
plt.plot(T['bias'])
print('mean: ' + str(T['bias'].mean()) + ', var: ' + str(T['bias'].var()))
plt.subplot(2,1,2)
plt.plot(T['λo'])
print('mean: ' + str(T['λo'].mean()) + ', var: ' + str(T['λo'].var()))
plt.show()

In [None]:

# Rebuild the spatial kernel for 12 randomly chosen samples (sorted by sample index) and
# show each one side by side with the true RF; the sampled location is printed.
kernel_keys = ('phase', 'angle', 'freq', 'ratio', 'width', 'gain')

plt.figure(figsize=(16,12))
sample_ids = np.sort(np.random.choice(T['gain'].shape[0], 12, replace=False))

i = 1
for t in sample_ids:
    # parameter dict of sample t, laid out like the toy-cell parameter dicts
    params_dict = {'kernel' : {'s' : {key: T[key][t] for key in kernel_keys}},
                   'glm': {'bias': T['bias'][t]}}

    # the kernel is evaluated with a fixed bias of -0.5, not with the sampled bias
    ks = m._eval_ks(bias=-0.5, **params_dict['kernel']['s'])

    true_rf = m.params_to_rf(pars_true)[0]
    plt.subplot(3,4,i)
    plt.imshow(np.hstack((ks.reshape(d,d), true_rf)), interpolation='None')
    plt.title('t =' + str(t))

    print('loc:' , [T['xo'][t], T['yo'][t]])
    i += 1
plt.show()



In [None]:

# Normalised histogram plus mean and variance for every listed parameter that is present
# in the sample dict T, after dropping the first `burnin` samples.
burnin = 50

for key in ['bias', 'λo', 
            'gain', 'log_A', 'phase', 'logit_φ',
            'angle', 'logit_θ', 'freq', 'log_f',
            'ratio', 'width', 'log_γ', 'log_b', 
            ]:

    # which keys exist depends on how the sampler was set up, so missing ones are skipped
    if key in T:
        x = T[key][burnin:]
        # `density=True` replaces the `normed=True` argument that was removed in matplotlib 3.1
        plt.hist(x, bins=np.linspace(x.min(), x.max(), 20), alpha=0.5, density=True)
        plt.title(key)
        plt.show()
        print('mean:', x.mean())
        print('var:', x.var())
        

In [None]:
# Compare the samples of the raw parameters with the prior `g.prior`, marking the true
# raw parameter values of the toy cell as ground truth.
raw_names = ['bias', 'gain', 'phase', 'freq','angle','ratio','width']

# one column per parameter
columns = [np.atleast_2d(T[key].T).T for key in raw_names]
samples = np.hstack(columns)

# true values in the same order: the bias lives under 'glm', the rest under ['kernel']['s']
true_kernel = params_dict_true['kernel']['s']
pars_raw = np.array([params_dict_true['glm']['bias']] + [true_kernel[name] for name in raw_names[1:]])

plot_pdf(g.prior, lims=[-3,3], gt=pars_raw.reshape(-1), figsize=(16,16), resolution=100, samples=samples.T,
         ticks=True, labels_params=list(raw_names));


In [None]:
# Uses the sample dict T currently in memory; the disabled line below would instead
# load previously stored samples from disk.
#T = np.load('posterior_samples_20.npz')['arr_0'].tolist()['T']
# one column per transformed (log / logit) parameter
samples = np.hstack([np.atleast_2d(T[key].T).T for key in ['log_A', 'logit_φ', 'log_f','logit_θ','log_γ','log_b']])

# NOTE(review): `p` is not defined in the visible part of the notebook — confirm it exists
# before running this cell.
plot_pdf(p, lims=[-3,3], gt=pars_true.reshape(-1), figsize=(16,16), resolution=100, samples=samples.T,
         ticks=True, labels_params=['log gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width']);


In [None]:
# Plot the current `samples` array; for a 2-D array matplotlib draws one line per column,
# i.e. one trace per parameter.
plt.plot(samples)
plt.show()