# deep identity mapping

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters.

## simplistic setup: 
- model parameters $\theta$ are the full spatiotemporal kernel (of size d*d, with d < 10), 
- summary statistic $x_0$ is spike-triggered average
- for sufficiently long simulations, posterior mean is just data summary statistic $x_0$. 
- purely spatial kernel for now (kernel size in the temporal dimension fixed to 1)  
- *very* basic spiking non-linearity and noise model: threshold crossing, no spiking noise


In [None]:
%%capture
# notebook currently depends on code found only in feature_maprf-branch of lfi_models !

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import matplotlib.pyplot as plt
import numpy as np
import lfimodels.maprf.utils as utils

from lfimodels.maprf.maprf import maprf
from lfimodels.maprf.maprfStats import maprfStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

In [None]:

# --- configuration -----------------------------------------------------------
seed = 42

d = 51                        # edge length of the (square) receptive field
parametrization = 'gaussian'  # RF is a Gaussian bump with specific mean and cov

# Third axis has length 1, i.e. the kernel is purely spatial here.
filter_shape = np.array((d, d, 1))

# --- simulator, prior, summary statistics, generator -------------------------
m = maprf(
    filter_shape=filter_shape,
    parametrization=parametrization,
    seed=seed,
    duration=1000,
)

if parametrization == 'gaussian':
    # NOTE(review): `-d//2` parses as `(-d)//2`, so for d = 51 the bounds are
    # (-26, 25) rather than the symmetric (-25, 25) - confirm this is intended.
    p = dd.GaussianRF(
        ab=[-2, -0],
        dd=[-d//2, d//2, -d//2, d//2],
        ks=np.array([3, 3]),
        cd=[-.9, 0.9],
    )
elif parametrization == 'full':
    p = dd.Gaussian(m=np.zeros(d*d+1), P=1/100*np.eye(d*d+1))

s = maprfStats(n_summary=m.n_params)
g = dg.Default(model=m, prior=p, summary=s)

# --- ground truth and observed data ------------------------------------------
# The parameter values returned by utils.obs_params are replaced by the
# hand-picked vector on the next line; only `labels_params` is kept from it.
true_params, labels_params = utils.obs_params(filter_shape, parametrization)
true_params = np.array([-d, -d//4, -d//4, 10, 15, 0.7])
obs = m.gen_single(true_params)
obs_stats = s.calc([obs])

# true_params[0] is used below as an additive offset; the remaining entries
# are what params_to_rf turns into the filter.
h_true = m.params_to_rf(true_params[1:])

# --- figure 1: activation trace, true filter, summary statistics -------------
plt.figure(figsize=(16, 7))

plt.subplot(1, 3, 1)
plt.plot(np.dot(obs['I'], h_true) + true_params[0])
plt.title('neural activation function over observed data')

image_panels = [
    (h_true.reshape(d, d), 'ground-truth filter'),
    # first summary statistic is dropped, the rest is viewed as an image
    (obs_stats.T[1:].reshape(filter_shape[0], filter_shape[1]),
     'summary statistics'),
]
for column, (image, caption) in enumerate(image_panels, start=2):
    plt.subplot(1, 3, column)
    plt.imshow(image, interpolation='None')
    plt.title(caption)
plt.show()

# --- figure 2: a bunch of example prior draws --------------------------------
plt.figure(figsize=(16, 7))
for panel in range(1, 11):
    plt.subplot(2, 5, panel)
    prior_draw = p.gen()[0, 1:]  # one prior sample without its first entry
    plt.imshow(m.params_to_rf(prior_draw).reshape(d, d), interpolation='None')
# Select panel 3 (top row, middle) again so the title is centred over the grid.
plt.subplot(2, 5, 3)
plt.title('RF prior samples')
plt.show()


# Last expression of the cell: the notebook displays the mean of obs['data'].
obs['data'].mean()

In [None]:

# Fit the posterior estimator with SNPE on simulations drawn from generator `g`
# and read off the posterior at the observed summary statistics.
res = infer.SNPE(g, obs=obs_stats, n_hiddens=[20, 20])

# NOTE(review): the first positional argument is presumably the number of
# simulations per round - confirm against the delfi version in use.
out = res.run(10000, n_rounds=1, minibatch=50, epochs=1000)

# Evaluate the trained network once; the same posterior object feeds both the
# filter plots and the marginal plots below.
posterior = res.predict(obs_stats)

# Filters implied by the prior mean and by the posterior mean. The first
# parameter entry is dropped before it is handed to params_to_rf.
h_prior = m.params_to_rf(p.mean[1:])
h_est = m.params_to_rf(posterior.calc_mean_and_cov()[0][1:])

filter_panels = [
    (h_prior.reshape(d, d), 'prior mean filter'),
    (h_est.reshape(d, d), 'posterior mean filter'),
    (h_true.reshape(d, d), 'true spatial filter'),
    (obs_stats.T[1:].reshape(filter_shape[0], filter_shape[1]), 'x0'),
]

plt.figure(figsize=(16, 6))
for column, (image, caption) in enumerate(filter_panels, start=1):
    plt.subplot(1, 4, column)
    plt.imshow(image, interpolation='None')
    plt.title(caption)
plt.show()

# Optional numeric dump for debugging - uncomment as needed.
#print('prior mean\n', p.mean[1:].reshape(-1,1))
#print('post mean\n', posterior.calc_mean_and_cov()[0][1:].reshape(-1,1))
#print('true pars\n', true_params[1:].reshape(-1,1))
#print('x0\n', obs_stats.T[1:])

# First mixture component of the posterior: plot its marginals against the
# ground truth, then evaluate it at the true parameters (the result is shown
# as the cell output; presumably a log-density - verify against delfi).
tmp = posterior.xs[0]
plot_pdf(tmp, lims=[-5, 5], gt=true_params, figsize=(12, 12))
tmp.eval(true_params.reshape(1, -1))

# construction site

In [None]:
%%capture
%matplotlib inline
# notebook currently depends on code found only in feature_maprf-branch of lfi_models !

import delfi.neuralnet as dn
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import matplotlib.pyplot as plt
import numpy as np
import lfimodels.maprf.utils as utils

from lfimodels.maprf.maprf import maprf
from lfimodels.maprf.maprfStats import maprfStats
from delfi.utils.viz import plot_pdf

import lasagne.layers as ll
import theano
import theano.tensor as tt
import collections

In [None]:

seed = 42

## simulation model

d = 41                     # edge length of the (square) receptive field
parametrization = 'gabor'  # one of ['full', 'gaussian', 'gabor']
filter_shape = np.array((d, d, 2))
m = maprf(filter_shape=filter_shape,
          parametrization=parametrization,
          seed=seed,
          duration=500)


## prior over simulation parameters
# Every entry except the temporal kernel 'kt' gets an independent Gaussian
# (mean 'mu', standard deviation 'sigma'); 'kt' gets a correlated Gaussian
# whose factor is built further down.

prior = collections.OrderedDict()
prior['b0'] = {'mu' : np.array([0.]), 'sigma' : np.array([1]) }
prior['vec_f']  = {'mu' : np.zeros(2), 'sigma' : 0.4 * np.ones(2) }
prior['vec_A']  = {'mu' : np.zeros(2), 'sigma' : 1.0 * np.ones(2) }
#prior = priors['kernel']['s']['width'] # fields 'a', 'b', 'f'
#mu, sigma = normal_from_ci(*zip(prior['b'], prior['a']), prior['f'])
prior['log_γ']  = {'mu' : np.array([-0.098]), 'sigma' : np.array([0.256])}
prior['log_b']  = {'mu' : np.array([ 0.955]), 'sigma' : np.array([0.236])}
#prior['xo'] = {'mu' : np.array([0.]), 'sigma' : np.array([5/np.sqrt(.5)])}
#prior['yo'] = {'mu' : np.array([0.]), 'sigma' : np.array([5/np.sqrt(.5)])}

# Temporal-kernel prior.
#   Λ : diagonal envelope (t / 0.075) * exp(1 - t / 0.075), which equals 1 at
#       t = 0.075.
#   D : identity minus its first sub-diagonal, i.e. a first-difference operator.
#   Σ : Λ F^{-1} Λ with F = D D^T. It is not used below; it equals
#       (D^{-1} Λ)^T (D^{-1} Λ), which is exactly the 'kt' block that
#       L.T.dot(L) produces further down.
ax_t = m._gen.axis_t
Λ =  np.diag(ax_t / 0.075 * np.exp(1 - ax_t / 0.075))
D = np.eye(ax_t.shape[0]) - np.eye(ax_t.shape[0], k=-1)
F = np.dot(D, D.T)
Σ = np.dot(Λ, np.linalg.inv(F).dot(Λ))
prior['kt'] = {'mu': np.zeros_like(ax_t), 'sigma': np.linalg.inv(D).dot(Λ)}

# Stack all means; build the block-diagonal factor L from the scalar sigmas
# (all entries but the last, 'kt') and the matrix-valued 'kt' factor.
mu  = np.concatenate([prior[i][ 'mu'  ] for i in prior.keys()])
L = np.diag(np.concatenate([prior[i]['sigma'] for i in list(prior.keys())[:-1]]))
L = np.block([[L, np.zeros((L.shape[0], ax_t.size))],
              [np.zeros((ax_t.size, L.shape[1])), prior['kt']['sigma']]])
p = dd.Gaussian(m=mu, S=L.T.dot(L))

## data summary statistics

s = maprfStats(n_summary=d*d)

g = dg.Default(model=m, prior=p, summary=s)

## network

n_inputs = [d, d]
n_outputs = m.n_params
n_components = 1
n_hiddens = [50]
n_filters = [30, 30]
n_train_round = 1

network = dn.NeuralNet.NeuralNet(n_inputs,
                     n_outputs,
                     n_components,
                     n_filters=n_filters,
                     n_hiddens=n_hiddens,
                     seed=seed,
                     svi=False)
loss = -tt.mean(network.lprobs)


# Sanity check: push a dummy batch through the convolutional layers and report
# the activation shapes.
conv_layers = [network.layer['conv_' + str(i)]
               for i in range(1, len(n_filters) + 1)]
test_fun = theano.function([network.stats],
                           [ll.get_output(layer) for layer in conv_layers])
# Build the dummy batch in the dtype of the network's input tensor: Theano
# rejects float64 arrays for float32 inputs unless downcasting is explicitly
# allowed.
dummy_stats = np.zeros((10, 1, d, d), dtype=network.stats.dtype)
# A single call returns the activations of every conv layer.
print('conv layer shapes:',
      [activation.shape for activation in test_fun(dummy_stats)])

trn_inputs = [network.params, network.stats]


In [None]:
# --- training hyper-parameters -----------------------------------------------
epochs = 100
minibatch = 50
n_samples = 1000

## ground-truth parameters with the data and statistics they generate

pars_true = np.array([-0.5, .7, .3, 2., 1., 1., 2.5, 1., 0.])
obs = m.gen_single(pars_true)
obs_stats = s.calc([obs])

## training set: first entry of the generator output is kept as is, the
## second is reshaped into single-channel d x d images for the network

raw_samples = g.gen(n_samples)
trn_data = (raw_samples[0], raw_samples[1].reshape(-1, 1, d, d))

## training

t = dn.Trainer.Trainer(network, loss,
                       trn_data=trn_data, trn_inputs=trn_inputs)
logs = [t.train(epochs=epochs, minibatch=minibatch)]

# Posterior mixture at the observed statistics (fed as a 1 x 1 x d x d batch).
posterior = network.get_mog(obs_stats.reshape(1, 1, d, d))


# --- a bunch of example prior draws ------------------------------------------
plt.figure(figsize=(16, 10))
for panel in range(1, 16):
    plt.subplot(3, 5, panel)
    prior_rf = m.params_to_rf(p.gen().reshape(-1))[0]
    plt.imshow(prior_rf, interpolation='None')
# The title is set after the loop, so it lands on the last panel drawn.
plt.title('RF prior STAs')
plt.show()

# --- observed statistics next to true and estimated filters ------------------
comparison_panels = [
    (obs_stats.reshape(d, d), 'data STA'),
    (m.params_to_rf(pars_true)[0], 'ground-truth posterior RF'),
    (m.params_to_rf(posterior.xs[0].m)[0], 'posterior mean STA'),
]
plt.figure(figsize=(16, 6))
for column, (image, caption) in enumerate(comparison_panels, start=1):
    plt.subplot(1, 3, column)
    plt.imshow(image, interpolation='None')
    plt.title(caption)
plt.show()

# --- a bunch of example posterior draws --------------------------------------
plt.figure(figsize=(16, 10))
for panel in range(1, 16):
    plt.subplot(3, 5, panel)
    posterior_rf = m.params_to_rf(posterior.gen().reshape(-1))[0]
    plt.imshow(posterior_rf, interpolation='None')
# As above: the title is attached to the last panel.
plt.title('RF posterior STAs')
plt.show()

# Posterior marginals against the ground truth; the trailing semicolon keeps
# the notebook from echoing the return value.
plot_pdf(posterior, lims=[-5, 5], gt=np.asarray(pars_true).reshape(-1), figsize=(12, 12));
