# SNPE & RF

Learning receptive-field (RF) parameters from the inputs (white-noise videos) and outputs (spike trains) of linear–nonlinear (LN) neuron models with parameterized linear filters, and comparing against MCMC sampling with `maprf`.

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats
from lfimodels.maprf.utils import get_maprf_prior_01


# set up simulator

In [None]:
seed = 42

# sim_info.npy appears to hold a pickled dict inside a 0-d object array (hence
# the `[()]` unwrap below). NumPy >= 1.16.3 refuses to load object arrays unless
# allow_pickle=True. Only load trusted files: unpickling can execute code.
sim_info = np.load('../results/sim_info.npy', allow_pickle=True)[()]

# d: side length of the square spatial RF (filter shape is d x d x 2)
d, params_ls = sim_info['d'], sim_info['params_ls']
m = model(filter_shape=np.array((d, d, 2)),
          parametrization=sim_info['parametrization'],
          params_ls=params_ls,
          seed=seed,
          dt=sim_info['dt'],
          duration=sim_info['duration'])

p = get_maprf_prior_01(params_ls, seed)

s = maprfStats(n_summary=d*d+1)  # summary stats (d x d RF + spike_count)
 
def rej(x):
    """Acceptance mask for a batch of summary statistics.

    The last column of `x` holds the spike count. A row is kept (True) only
    when that count is strictly positive, so zero-spike samples are rejected.
    """
    spike_counts = x[:, -1]
    return spike_counts > 0

# Generator for (theta_i, x_i) pairs: parameters from prior `p`, data from
# simulator `m`, summarized by `s`; pairs flagged by `rej` (zero spikes) are
# rejected already at sampling time.
g = dg.RejKernel(prior=p, model=m, summary=s, seed=seed, rej=rej)


# Load a toy cell and generate its observed data x_o

In [None]:
## training data and true parameters, data, statistics

# `get_data_o` is not imported anywhere else in this notebook.
# NOTE(review): assumed to live in lfimodels.maprf.utils next to
# get_maprf_prior_01 -- confirm the module path.
from lfimodels.maprf.utils import get_data_o

idx_cell = 1  # 0-based index: loads file toy_cell_<idx_cell + 1>.npy

filename = '../results/toy_cells/toy_cell_' + str(idx_cell+1) + '.npy'
obs_stats, pars_true = get_data_o(filename, g, seed)


# Compare with MCMC sampling from `maprf`

In [None]:

from maprf.utils import *
import theano  # used below via theano.shared; `import theano.tensor as tt` alone does not bind `theano`
import theano.tensor as tt
from tqdm import tqdm
from maprf.inference import *
from maprf.rfs.v1 import SimpleLinear_full_kt
from maprf.glm import Poisson

# This cell reads `params_dict_true` (nested dict of true parameters), `prior`
# (prior hyperparameters keyed by variable name) and `obs` (stimulus frames 'I'
# and spike train 'data'), none of which is defined earlier in this notebook.
# NOTE(review): they presumably came from a deleted cell -- restore it. Fail
# early with a clear message instead of halfway through building the model.
_missing = [name for name in ('params_dict_true', 'prior', 'obs')
            if name not in globals()]
if _missing:
    raise NameError('undefined before MCMC setup: ' + ', '.join(_missing))

fix_position = True  # fixes RF position during sampling to (0,0)
parametrization = 'logit_φ'

# generative model
rf = SimpleLinear_full_kt()
emt = Poisson()

# inputs and outputs (shared variables, filled with the observed data below)
data = [theano.shared(empty(3), 'frames'),
        theano.shared(empty(1, dtype='int64'))]
frames, spikes = data

# fill the grids with the simulator's spatial and temporal axes
rf.grids['s'][0].set_value(m._gen.grid_x)
rf.grids['s'][1].set_value(m._gen.grid_y)
rf.grids['t'][0].set_value(m._gen.axis_t)

# inference model
inference = Inference(rf, emt, bias=params_dict_true['glm']['bias'])
inference.priors = {
    'glm': {         'bias':  {'name':  'gamma',
                               'varname': 'λo',
                               'alpha': 1.0, #prior['λo']['alpha'][0],
                               'beta':  1.0}}, #prior['λo']['beta'][0]}},
    'kernel': {'s': {'ratio':  {'name': 'normal',
                                'varname': 'log_γ',
                                'sigma': prior['log_γ']['sigma'][0],
                                'mu':    prior['log_γ']['mu'][0]},
                     'width':  {'name': 'normal',
                                'varname': 'log_b',
                                'sigma': prior['log_b']['sigma'][0],
                                'mu':    prior['log_b']['mu'][0]}}}}

if 'log_A' in prior.keys() and 'logit_φ' in prior.keys():
    inference.priors['kernel']['s']['gain'] =  {'name': 'lognormal',
                                                  'varname': 'log_A',
                                                  'mu': prior['log_A']['mu'][0],
                                                  'sigma': prior['log_A']['sigma'][0]}
    inference.priors['kernel']['s']['phase'] =  {'name': 'logitnormal',
                                                  'varname': 'logit_φ',
                                                  'mu': prior['logit_φ']['mu'][0],
                                                  'sigma': prior['logit_φ']['sigma'][0]}
else:
    # NotImplemented is a sentinel value, not an exception: calling it raises TypeError
    raise NotImplementedError('prior must define both log_A and logit_φ')

if 'log_f' in prior.keys() and 'logit_θ' in prior.keys():
    inference.priors['kernel']['s']['freq'] =  {'name': 'lognormal',
                                                  'varname': 'log_f',
                                                  'mu': prior['log_f']['mu'][0],
                                                  'sigma': prior['log_f']['sigma'][0]}
    inference.priors['kernel']['s']['angle'] =  {'name': 'logitnormal',
                                                  'varname': 'logit_θ',
                                                  'mu': prior['logit_θ']['mu'][0],
                                                  'sigma': prior['logit_θ']['sigma'][0]}
else:
    raise NotImplementedError('prior must define both log_f and logit_θ')


if 'kt' in prior.keys():
    inference.priors['kernel']['t'] = prior['kt']


inference.add_sampler(GaborSampler(fix_position=fix_position, parametrization=parametrization))
print(inference.samplers[0].params)

# temporal kernel (here [1,0])
kt = tt.vector('kt')
inference.rf.filter.kernel['t'] = kt / tt.sqrt(tt.dot(kt, kt)) # ensure normalization (to firing rate)
inference.add_inputs(kt)

print('inputs: ', inference.inputs)
print('priors: ', inference.priors)

inference.build(data)
inference.compile()


# set MCMC chain initializer: RF position at the origin, all other
# parameters at their true values (mapped to the sampler's log/logit space)
inference.loglik['xo'] = 0.
inference.loglik['yo'] = 0.
ks = params_dict_true['kernel']['s']
inference.loglik['log_γ'] = np.log(ks['ratio'])
inference.loglik['log_b'] = np.log(ks['width'])
inference.loglik['kt'] =    params_dict_true['kernel']['t']['value'].copy()  # np.array([0.5, 0.0])
if 'log_A' in prior.keys() and 'logit_φ' in prior.keys():
    inference.loglik['log_A']   = np.log(ks['gain'])
    # logit of the phase on (0, π) -- the phase, not the angle, belongs in the
    # denominator. NOTE(review): confirm (0, π) is the phase range of the prior.
    inference.loglik['logit_φ'] = np.log(ks['phase'] / (np.pi - ks['phase']))
if 'log_f' in prior.keys() and 'logit_θ' in prior.keys():
    inference.loglik['log_f'] = np.log(ks['freq'])
    inference.loglik['logit_θ'] = np.log(ks['angle'] / (2*np.pi - ks['angle']))

# hand over data
frames.set_value(obs['I'][:,:].reshape(-1,d,d))
spikes.set_value(obs['data'][:])

# use this instead for sampling from the prior:
#frames.set_value(0*obs['I'][:1,:].reshape(-1,d,d))
#spikes.set_value(0*obs['data'][:1])

plt.plot(spikes.get_value())

print(np.sum(obs['data']))

# sample RF parameters (with Poisson bias marginalized out)

In [None]:

N_MCMC = 10000  # number of MCMC iterations
BURN_IN = 500   # initial iterations excluded from the plots below

T, L = inference.sample(N_MCMC)
T = {k.name: t for k, t in T.items()}  # key traces by variable name, e.g. 'xo', 'log_b'

x = T['xo']
y = T['yo']

# trace and histogram of the RF x-location
plt.figure(figsize=(15, 4))
plt.subplot(121)
plt.plot(x[BURN_IN:])

plt.subplot(122)
# `normed` was removed from plt.hist in Matplotlib 3.1; `density` replaces it
plt.hist(x[BURN_IN:], alpha=0.5, density=True)
plt.show()

# joint scatter of the sampled RF location (xo, yo)
plt.figure(figsize=(15, 4))
plt.subplot(121)
plt.plot(x[BURN_IN:], y[BURN_IN:], '.k', alpha=0.1)
plt.show()


# Sample the Poisson bias (conditioned on the other sampled parameters)

In [None]:
# Sample the GLM bias given the RF samples; afterwards T holds 'bias' and 'λo'
# traces (presumably added in place by sample_biases -- they are read from T below).
inference.sample_biases(data, T, m.dt)

plt.figure(figsize=(12,5))
for panel, key in enumerate(['bias', 'λo'], start=1):
    plt.subplot(2, 1, panel)
    plt.plot(T[key])
    print('mean: ' + str(T[key].mean()) + ', var: ' + str(T[key].var()))
plt.show()

# (roughly) check for mixing of the chain

In [None]:
# Trace plot of the main RF parameters across MCMC iterations. Built from T
# here, because `samples` only exists after a later cell has been run.
trace_keys = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width']
traces = np.column_stack([T[key] for key in trace_keys])

plt.figure(figsize=(16, 5))
plt.plot(traces)
plt.legend(trace_keys)
plt.xlabel('MCMC iteration')
plt.show()

# Example posterior draws (each shown side by side with the ground-truth RF)

In [None]:

# Show 10 distinct, randomly chosen MCMC iterations (in increasing order): each
# image is the sampled spatial RF (left) next to the ground-truth RF (right).
for t in np.sort(np.random.choice(T['gain'].shape[0], 10, replace=False)):
    params_dict = {'kernel': {'s': {}}, 'glm': {'bias': T['bias'][t]}}
    rf_spatial = params_dict['kernel']['s']
    for param_name in ('phase', 'angle', 'freq', 'ratio', 'width', 'gain'):
        rf_spatial[param_name] = T[param_name][t]

    # NOTE(review): the kernel is evaluated with a fixed bias of -0.5 rather than
    # the sampled params_dict['glm']['bias'] -- confirm this is intended.
    ks = m._eval_ks(bias=-0.5,
                    angle=rf_spatial['angle'],
                    freq=rf_spatial['freq'],
                    gain=rf_spatial['gain'],
                    phase=rf_spatial['phase'],
                    ratio=rf_spatial['ratio'],
                    width=rf_spatial['width'])

    side_by_side = np.hstack((ks.reshape(d, d), m.params_to_rf(pars_true)[0]))
    plt.imshow(side_by_side, interpolation='None')
    plt.title('t =' + str(t))
    plt.show()

    print('loc:', [T['xo'][t], T['yo'][t]])

# marginal histograms for each (transformed) parameter

In [None]:
def plot_marginal(samples, title, burn_in=50, n_bins=20):
    """Plot a normalized histogram of one parameter's MCMC trace and print its moments.

    Parameters
    ----------
    samples : array-like
        1-D trace of one (possibly transformed) parameter.
    title : str
        Figure title.
    burn_in : int
        Number of initial iterations discarded before plotting.
    n_bins : int
        Number of bin edges passed to np.linspace (i.e. n_bins - 1 bins).

    The mean and variance are computed on the same post-burn-in samples that
    are histogrammed, so the printed moments match the plot.
    """
    x = np.asarray(samples)[burn_in:]
    # `density` replaces `normed`, which was removed from plt.hist in Matplotlib 3.1
    plt.hist(x, bins=np.linspace(x.min(), x.max(), n_bins), alpha=0.5, density=True)
    plt.title(title)
    plt.show()
    print('mean:', x.mean())
    print('var:', x.var())


# parameters that are always present in T: (key, title)
for key, title in [('ratio', 'ratio'), ('log_γ', 'log ratio'),
                   ('width', 'width'), ('log_b', 'log width'),
                   ('angle', 'angle'), ('freq', 'freq')]:
    plot_marginal(T[key], title)

if 'vec_f' in T.keys():
    for col in range(2):
        plot_marginal(T['vec_f'][:, col], 'vec_f[%d]' % col)

# the gain trace is trimmed with a longer burn-in
plot_marginal(T['gain'], 'gain', burn_in=500)
plot_marginal(T['phase'], 'phase')

if 'logit_φ' in T.keys():
    plot_marginal(T['logit_φ'], 'logit phase')

if 'log_A' in T.keys():
    plot_marginal(T['log_A'], 'log_A', n_bins=50)

if 'vec_A' in T.keys():
    for col in range(2):
        plot_marginal(T['vec_A'][:, col], 'vec_A[%d]' % col)

if 'bias' in T.keys():
    plot_marginal(T['bias'], 'bias')

if 'λo' in T.keys():
    plot_marginal(T['λo'], 'exp bias')


# posterior samples versus prior


## actual parameters

In [None]:
# Posterior samples in the original (untransformed) parameter space, one column
# per parameter, overlaid on the prior pdf.
# NOTE(review): the same `gt` vector is also used for the log/logit-space plot
# below; at most one of the two spaces can match pars_true -- confirm.
actual_keys = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width']
samples = np.column_stack([T[key] for key in actual_keys])

plot_pdf(p, lims=[-3, 3], gt=pars_true.reshape(-1), figsize=(16, 16), resolution=100,
         samples=samples.T, ticks=True, labels_params=list(actual_keys));


## parameters in log/logit space

In [None]:
# Posterior samples in the sampler's transformed (log / logit) space, overlaid
# on the prior pdf. Pairs are (key in T, axis label).
log_space_params = [('bias', 'bias'), ('log_A', 'log gain'), ('logit_φ', 'logit phase'),
                    ('log_f', 'log freq'), ('logit_θ', 'logit angle'),
                    ('log_γ', 'log ratio'), ('log_b', 'log width')]
samples = np.column_stack([T[key] for key, _ in log_space_params])

plot_pdf(p, lims=[-3, 3], gt=pars_true.reshape(-1), figsize=(16, 16), resolution=100,
         samples=samples.T, ticks=True,
         labels_params=[label for _, label in log_space_params]);


In [None]:
# tbd

In [None]:
# Optionally persist the posterior samples. Off by default so that
# "Restart & Run All" does not write files; set to True to save.
SAVE_POSTERIOR_SAMPLES = False
if SAVE_POSTERIOR_SAMPLES:
    try:
        np.savez('posterior_samples_20', {'T': T})
    except OSError as err:
        # best effort: report the failure instead of silently swallowing it
        print('could not save posterior samples:', err)

In [None]:
#try: 
#    np.savez('prior_samples', {'T' : T})
#except:
#    pass