# deep identity mapping

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear (LN) neuron models with parametrized linear filters.

## simplistic setup: 
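- model parameters $\theta$ are the full spatiotemporal kernel (of size $d \times d$, with $d < 10$); note that the code below deviates from this and uses $d = 11$ with a Gaussian-bump parametrization of the kernel (`parametrization = 'gaussian'`),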
- the summary statistic $x_0$ is the spike-triggered average (STA),
- for sufficiently long simulations, the posterior mean is just the summary statistic $x_0$ of the observed data,
- purely spatial kernel for now (kernel size in the temporal dimension fixed to 1),
- *very* basic spiking non-linearity and noise model: threshold crossing, no spiking noise


In [None]:
%%capture
# notebook currently depends on code found only in feature_maprf-branch of lfi_models !

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import matplotlib.pyplot as plt
import numpy as np
import lfimodels.maprf.utils as utils

from lfimodels.maprf.maprf import maprf
from lfimodels.maprf.maprfStats import maprfStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

In [None]:

# --- experiment configuration -------------------------------------------
seed = 42
d = 11                         # edge length of the (square) receptive field
parametrization = 'gaussian'   # RF is a Gaussian bump with specific mean and cov

# Filter is d x d spatially with a temporal extent of 1 (purely spatial).
filter_shape = np.array([d, d, 1])

# --- simulator, prior, summary statistics, generator ---------------------
m = maprf(
    filter_shape=filter_shape,
    parametrization=parametrization,
    seed=seed,
    duration=10000,
)
p = dd.GaussianRF(
    ab=[-1, 1],
    m=np.zeros(2),
    B=np.sqrt(d) * np.eye(2),
    df=2,
    scale=np.eye(2) / 2,
)
s = maprfStats(n_summary=m.n_params)
g = dg.Default(model=m, prior=p, summary=s)

# --- "observed" data from the ground-truth parameters --------------------
true_params, labels_params = utils.obs_params(filter_shape, parametrization)
obs = m.gen_single(true_params)
obs_stats = s.calc([obs])

# --- moments of the summary statistics over 1000 prior simulations -------
# (consumed by the inference cell that follows)
_, stats = g.gen(1000)
stats_mean, stats_std = stats.mean(axis=0), stats.std(axis=0)

# --- visual check of the ground truth ------------------------------------
# Entries 1: of the parameter vector define the RF; entry 0 is added as an
# offset to the filter response below.
h_true = m.params_to_rf(true_params[1:])
activation = np.dot(obs['I'], h_true) + true_params[0]

fig, (ax_act, ax_rf) = plt.subplots(1, 2)
ax_act.plot(activation)
ax_act.set_title('neural activation function over observed data')
rf_image = ax_rf.imshow(h_true.reshape(d, d), interpolation='None')
fig.colorbar(rf_image, ax=ax_rf)

plt.show()

In [None]:

# Fit a density network (one hidden layer, 50 units) with delfi's SNPE on
# simulations from generator `g`, conditioned on the observed statistics.
res = infer.SNPE(g, obs=obs_stats, n_hiddens=[50])
# Overwrite the statistic moments with those estimated from the 1000 prior
# simulations of the previous cell (presumably used by SNPE to standardise
# summary statistics -- confirm against delfi).
res.stats_mean = stats_mean
res.stats_std = stats_std

out = res.run(1000, n_rounds=1)

posterior = res.predict(obs_stats)
# calc_mean_and_cov() returns (mean, cov); evaluate it once and reuse the mean
# for both the plot and the printout instead of recomputing it.
post_mean = posterior.calc_mean_and_cov()[0]

# Compare four d x d images: prior mean filter, posterior mean filter, true
# filter, and the observed summary statistic x0. Parameter entry 0 is
# excluded everywhere because params_to_rf is fed entries 1: only.
plt.figure(figsize=(16, 6))

plt.subplot(1, 4, 1)
h_prior = m.params_to_rf(p.mean[1:])
plt.imshow(h_prior.reshape(d, d), interpolation='None')
plt.title('prior mean filter')

plt.subplot(1, 4, 2)
h_est = m.params_to_rf(post_mean[1:])
plt.imshow(h_est.reshape(d, d), interpolation='None')
plt.title('posterior mean filter')

plt.subplot(1, 4, 3)
plt.imshow(h_true.reshape(d, d), interpolation='None')
plt.title('true spatial filter')

plt.subplot(1, 4, 4)
# Drop statistic 0 and show the rest as a d x d image (filter_shape[:2] == (d, d)).
plt.imshow(obs_stats.T[1:].reshape(d, d), interpolation='None')
plt.title('x0')

plt.show()

print('prior mean\n', p.mean[1:].reshape(-1, 1))
print('post mean\n', post_mean[1:].reshape(-1, 1))
print('true pars\n', true_params[1:].reshape(-1, 1))
