# deep identity mapping

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear (LN) neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline
# notebook currently depends on code found only in feature_maprf-branch of lfi_models !

import delfi.neuralnet as dn
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import matplotlib.pyplot as plt
import numpy as np
import lfimodels.maprf.utils as utils

from lfimodels.maprf.maprf import maprf
from lfimodels.maprf.maprfStats import maprfStats
from delfi.utils.viz import plot_pdf

import lasagne.layers as ll
import theano
import theano.tensor as tt
import collections


def pol2cart(v):
    """Convert a polar pair ``(radius, angle)`` to Cartesian coordinates.

    ``v[0]`` is the radius and ``v[1]`` the angle in radians. The result is
    ``radius * [cos(angle), sin(angle)]``, stacked along a new leading axis
    of length 2.
    """
    radius, angle = v[0], v[1]
    unit_vector = np.stack([np.cos(angle), np.sin(angle)])
    return radius * unit_vector


def cart2pol(v):
    """Convert Cartesian coordinates ``(x, y)`` to polar ``(radius, angle)``.

    Inverse of ``pol2cart``: ``v[0]`` is x, ``v[1]`` is y. Returns the radius
    and the angle in radians (range ``[-pi, pi]``, from ``arctan2``), stacked
    along a new leading axis of length 2.
    """
    # np.hypot avoids the overflow/underflow of sqrt(x**2 + y**2) for very
    # large or very small coordinates.
    radius = np.hypot(v[0], v[1])
    angle = np.arctan2(v[1], v[0])
    return np.stack([radius, angle])

In [None]:

seed = 42

## simulation model

d = 41                     # edge length of the (square) receptive field
parametrization = 'gabor'  # one of ['full', 'gaussian', 'gabor']
# (d, d, 2): spatial extent plus, presumably, the number of temporal bins -- confirm against maprf
filter_shape = np.array((d, d, 2))
m = maprf(filter_shape=filter_shape,
          parametrization=parametrization,
          seed=seed,
          duration=1000)


## prior over simulation parameters
# The prior vector is the concatenation, in insertion order, of:
#   b0 (1), vec_f (2), vec_A (2), log_γ (1), log_b (1), kt (ax_t.size)

prior = collections.OrderedDict()
prior['b0']    = {'mu': np.array([0.]), 'sigma': np.array([1])}
prior['vec_f'] = {'mu': np.zeros(2), 'sigma': 0.4 * np.ones(2)}
prior['vec_A'] = {'mu': np.zeros(2), 'sigma': 1.0 * np.ones(2)}
#prior = priors['kernel']['s']['width'] # fields 'a', 'b', 'f'
#mu, sigma = normal_from_ci(*zip(prior['b'], prior['a']), prior['f'])
prior['log_γ'] = {'mu': np.array([-0.098]), 'sigma': np.array([0.256])}
prior['log_b'] = {'mu': np.array([0.955]), 'sigma': np.array([0.236])}
#prior['xo'] = {'mu' : np.array([0.]), 'sigma' : np.array([5/np.sqrt(.5)])}
#prior['yo'] = {'mu' : np.array([0.]), 'sigma' : np.array([5/np.sqrt(.5)])}

# Temporal-kernel prior: Λ scales each time bin by (t/τ)·exp(1 - t/τ) with
# τ = 0.075, D is the first-difference matrix ((D k)_i = k_i - k_{i-1}),
# and Σ = Λ (D Dᵀ)⁻¹ Λ is the resulting covariance of kt.
ax_t = m._gen.axis_t
Λ = np.diag(ax_t / 0.075 * np.exp(1 - ax_t / 0.075))
D = np.eye(ax_t.shape[0]) - np.eye(ax_t.shape[0], k=-1)
F = np.dot(D, D.T)
Σ = np.dot(Λ, np.linalg.inv(F).dot(Λ))
prior['kt'] = {'mu': np.zeros_like(ax_t), 'sigma': np.linalg.inv(D).dot(Λ)}

mu = np.concatenate([prior[i]['mu'] for i in prior.keys()])
# Diagonal scales for every parameter except the last key ('kt'), which
# gets its dense factor inv(D)·Λ as a separate diagonal block.
L = np.diag(np.concatenate([prior[i]['sigma'] for i in list(prior.keys())[:-1]]))
L = np.block([[L, np.zeros((L.shape[0], ax_t.size))],
              [np.zeros((ax_t.size, L.shape[1])), prior['kt']['sigma']]])
# Note the order L.T·L (not L·L.T): for the kt block this gives
# (inv(D)Λ)ᵀ(inv(D)Λ) = Λ (D Dᵀ)⁻¹ Λ = Σ.
p = dd.Gaussian(m=mu, S=L.T.dot(L), seed=seed)

## data summary statistics

s = maprfStats(n_summary=d*d)

g = dg.Default(model=m, prior=p, summary=s)

## network

n_hiddens = [30, 30]
n_filters = [16, 16, 16]
n_train_round = 1

network = dn.NeuralNet.NeuralNet(n_inputs=[d, d],
                                 n_outputs=m.n_params,
                                 n_components=1,
                                 n_filters=n_filters,
                                 n_hiddens=n_hiddens,
                                 seed=seed,
                                 svi=False)
# training loss: negative mean log-probability reported by the network
loss = -tt.mean(network.lprobs)

# Sanity check of the conv stack: compile a function returning every conv
# layer's output and evaluate it ONCE on a dummy batch of 10 zero images
# (the original evaluated it once per layer).
conv_layers = [network.layer['conv_' + str(i)] for i in range(1, len(n_filters) + 1)]
test_fun = theano.function([network.stats], [ll.get_output(layer) for layer in conv_layers])
conv_outputs = test_fun(np.zeros((10, 1, d, d)))
print('conv layer shapes:', [out.shape for out in conv_outputs])

trn_inputs = [network.params, network.stats]


In [None]:
epochs = 200
minibatch = 50
n_samples = 5000

## training data and true parameters, data, statistics

# Ground truth in the prior's parameter order:
#   b0, vec_f (2), vec_A (2), log_γ, log_b, kt (2)
pars_true = np.array([-0.5,
                      *pol2cart([0.7, 0.3]),
                      *pol2cart([2., 1.]),
                      np.log(1.), np.log(2.5),
                      1., 0.])
obs = m.gen_single(pars_true)
obs_stats = s.calc([obs])

trn_data = g.gen(n_samples)
# reshape summary stats to (batch, 1, d, d) images for the conv network
trn_data = (trn_data[0], trn_data[1].reshape(-1, 1, d, d))

## training

t = dn.Trainer.Trainer(network, loss,
                       trn_data=trn_data, trn_inputs=trn_inputs)
logs = []
logs.append(t.train(epochs=epochs, minibatch=minibatch))

posterior = network.get_mog(obs_stats.reshape(1, 1, d, d))

# Axis labels must follow the prior's parameter order (b0, vec_f, vec_A,
# log_γ, log_b, kt), which is also the order of pars_true; the previous
# labels listed A before f and so mislabelled four marginals.
param_labels = (['b', 'f_1', 'f_2', 'A_1', 'A_2', 'log ratio', 'log width']
                + ['kt_' + str(i) for i in range(filter_shape[2])])

# bunch of example prior draws
plt.figure(figsize=(16, 10))
for i in range(15):
    plt.subplot(3, 5, i + 1)
    plt.imshow(m.params_to_rf(p.gen().reshape(-1))[0], interpolation='None')
plt.suptitle('RF prior draws')  # figure-level title; plt.title would only label the last panel
plt.show()

plt.figure(figsize=(16, 6))
plt.subplot(1, 3, 1)
plt.imshow(obs_stats.reshape(d, d), interpolation='None')
plt.title('data STA')
plt.subplot(1, 3, 2)
plt.imshow(m.params_to_rf(pars_true)[0], interpolation='None')
plt.title('ground-truth RF')
plt.subplot(1, 3, 3)
plt.imshow(m.params_to_rf(posterior.xs[0].m)[0], interpolation='None')
plt.title('posterior mean RF')
plt.show()

# bunch of example posterior draws
plt.figure(figsize=(16, 10))
for i in range(15):
    plt.subplot(3, 5, i + 1)
    plt.imshow(m.params_to_rf(posterior.gen().reshape(-1))[0], interpolation='None')
plt.suptitle('RF posterior draws')  # figure-level title; plt.title would only label the last panel
plt.show()

plot_pdf(posterior.xs[0], lims=[-5, 5], gt=pars_true.reshape(-1), figsize=(16, 16),
         labels_params=param_labels);


In [None]:
# Prior marginals with the ground truth overlaid. Labels follow the prior's
# parameter order (b0, vec_f, vec_A, log_γ, log_b, kt); the previous labels
# listed A before f.
plot_pdf(p, lims=[-2, 2], gt=pars_true.reshape(-1), figsize=(16, 16),
         labels_params=(['b', 'f_1', 'f_2', 'A_1', 'A_2', 'log ratio', 'log width']
                        + ['kt_' + str(i) for i in range(filter_shape[2])]));


# compare with maprf sampling

In [None]:
# Inspect the simulator's x grid; the same array is copied into
# rf.grids['s'][0] when the maprf model is set up in the next cell.
m._gen.grid_x

In [None]:
import numpy as np
import numpy.random as nr
import maprf.config as config
import maprf.rfs.v1 as V1
import maprf.invlink as invlink
import maprf.glm as glm 
from maprf.utils import *
from maprf.data import SymbolicData
import time
import maprf.filters as filters
import maprf.kernels as kernels
# from maprf.sampling.slice import EllipticalSliceSampler as ESS

import theano.printing as printing
import theano.tensor as tt

import theano
from theano import In

import pickle
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

from os import path
from maprf.inference import *

def pyprint(var, filename):
    """Render the Theano graph of ``var`` to a PDF file named ``filename``.

    Thin wrapper around ``printing.pydotprint`` (theano.printing) with fixed
    options: PDF output, ``high_contrast`` disabled and ``with_ids`` enabled.
    """
    pydot_options = {
        'format': 'pdf',
        'outfile': filename,
        'high_contrast': False,
        'with_ids': True,
    }
    printing.pydotprint(var, **pydot_options)

# Load the maprf configuration; the relative path resolves against the
# current working directory (os.path.join with one argument was a no-op).
cfg = config.load('config.yaml')

# The forward part of the model
rf = V1.SimpleLinear()
emt = glm.Poisson()
# inputs and outputs: shared containers, filled with the observed data later
data = [theano.shared(empty(3), 'frames'),
        theano.shared(empty(1, dtype='int64'))]
frames, spikes = data

# fill the grids with the simulator's axes so both models share one geometry
rf.grids['s'][0].set_value(m._gen.grid_x)
rf.grids['s'][1].set_value(m._gen.grid_y)
rf.grids['t'][0].set_value(m._gen.axis_t)

import numpy.linalg as linalg
# Build the prior for the temporal kernel: the same Λ / D / Σ construction
# as the kt prior used for the network above.
# NOTE: local names deliberately avoid `s` (the summary-statistics object
# used for obs_stats), which this cell previously overwrote.
ax_t = rf.grids['t'][0].get_value()
t_scaled = ax_t / 0.075
n_t = ax_t.shape[0]
Λ = np.diag(t_scaled * np.exp(1 - t_scaled))
D = np.eye(n_t) - np.eye(n_t, k=-1)
F = np.dot(D, D.T)
Σ = np.dot(Λ, linalg.inv(F).dot(Λ))

# inference model
inference = Inference(rf, emt)
inference.priors = cfg['priors']
# linalg.cholesky returns the lower-triangular factor C with Σ = C·Cᵀ
inference.priors['kernel']['t'] = {'mu': np.zeros_like(ax_t), 'sigma': linalg.cholesky(Σ)}


inference.add_sampler(GaborSampler())
inference.add_sampler(KernelSampler())

plt.imshow(Σ, interpolation='None')
plt.show()

print('inputs: ', inference.inputs)
print('priors: ', inference.priors)

inference.build(data)
inference.compile()

In [None]:
# Inspect the simulator's temporal axis (the array copied into rf.grids['t'][0]).
m._gen.axis_t

In [None]:

# Values assigned to inference.loglik before sampling, in this order
# (presumably the sampler's starting point -- confirm against maprf).
initial_values = (
    ('xo', 0),
    ('yo', 0),
    # earlier kt setting: np.array([ 0.02047043,  0.51640702,  0.61731474,  0.01362172, -0.37586342,
    #       -0.3750627 , -0.23319645, -0.1131711 , -0.04670192, -0.01714343,
    #       -0.00575457])
    ('kt', np.array([0.5, 0.])),
    ('vec_A', np.zeros(2)),  # alternative: np.array([2.0, 0.0])
    ('vec_f', np.zeros(2)),  # alternative: 0.3 * np.array([np.cos(0.7), np.sin(0.7)])
    ('log_γ', 0.0),
    ('log_b', np.log(2.5)),
)
for param_name, start_value in initial_values:
    inference.loglik[param_name] = start_value

# Load the observed stimulus (reshaped to d x d frames) and the spike data
# into the shared inputs, then plot the spikes and print their total.
frames.set_value(obs['I'].reshape(-1, d, d))
spikes.set_value(obs['data'])
plt.plot(spikes.get_value())

print(np.sum(obs['data']))

In [None]:
# Display the inference object's inputs (also printed while setting up the model).
inference.inputs

In [None]:
# Print the current value held by each of the inference object's buffers.
for var, shared_buf in inference.buffer.items():
    current = shared_buf.get_value()
    print('{} = {}'.format(var, current))

In [None]:
import datetime

# Draw 10 samples. The first return value is keyed by symbolic variables;
# re-key it by variable name so traces can be indexed as T['xo'].
# NOTE(review): this rebinds L, which earlier held the prior's scale factor.
T, L = inference.sample(10)
T = dict((var.name, trace) for var, trace in T.items())


In [None]:
x = T['xo']
y = T['yo']

# Burn-in: drop the first samples. A fixed 500 left every plot empty when
# fewer samples were drawn (e.g. the 10 draws above), so clamp it to at
# most half the chain.
burn_in = min(500, len(x) // 2)

plt.figure(figsize=(15, 4))
plt.subplot(121)
plt.plot(x[burn_in:])

plt.subplot(122)
# `normed` is deprecated (since Matplotlib 2.1) and removed in later
# releases; `density` is its replacement.
plt.hist(x[burn_in:], alpha=0.5, density=True)
plt.show()


plt.figure(figsize=(15, 4))
plt.subplot(121)
plt.plot(x[burn_in:], y[burn_in:], '.k', alpha=0.1)
plt.show()

# plt.plot(T['vec_f'][:,0], T['vec_f'][:,1], '.k', alpha=0.1)
# plt.xlim((-.7, .7))
# plt.ylim((-.7, .7))
# plt.show()

# plt.plot(T['width'], T['gain'], '.-k', alpha=0.1)