In [None]:
#%%capture
%matplotlib inline

import copy
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy import io

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.utils import (get_maprf_prior_01, setup_sim, setup_sampler,
                                   get_data_o, quick_plot, contour_draws)
from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats

seed = 42
# Also seed NumPy's global RNG: later cells draw from np.random without an explicit
# RandomState, which otherwise makes Restart & Run All non-reproducible.
np.random.seed(seed)

In [None]:
# Which recording to analyse: experiment id, cell index and session letter.
# All three are combined below into the data / results paths.
exp_id = 'R151013ct01'
idx_cell = 8
cell_id = '%d' % idx_cell  # string form of the cell index, used in file names
sess_id = 'g'

In [None]:
path = '.'
# sim_info.npy holds a pickled dict-like object (overall setup info: parametrization,
# stimulus size etc.), stored as a 0-d object array. Since NumPy 1.16.3, np.load refuses
# to unpickle object arrays unless allow_pickle=True is passed.
# NOTE: unpickling can execute code; only load this file from a trusted source.
sim_info = np.load(path + '/results/sim_info.npy', allow_pickle=True)[()]
d = sim_info['d']  # side length of the d x d spatial filter / STA

# Initial parameter estimate (= MCMC initial chain segment).
# 'kernel' holds the spatial ('s'), temporal ('t') and location ('l') parts of the
# receptive field; 'glm' holds the output nonlinearity settings.
params_dict_init = dict(
    kernel=dict(
        s=dict(ratio=0.9,
               width=3.5,
               gain=0.3,
               phase=np.pi / 2 - 0.001,
               freq=1.45,
               angle=3 * np.pi / 2 - 0.1),
        t=dict(value=np.array([1., 0.])),
        l=dict(xo=0.3, yo=-0.03),
    ),
    glm=dict(bias=0.15,
             binsize=1. / 30),
)

# simulator
m = model(filter_shape=([d, d, 2]),
          duration=300,
          dt=1./30.,
          params_ls=sim_info['params_ls'],
          seed=seed,
          parametrization=sim_info['parametrization'])

# load stimulus, overwrite simulator stimulus
file_name = 'data/ActualMoviecat-noise_0.12_1.25_0.50_300_10_120px_4Hz_mag14.mat'
mv = io.loadmat(file_name, mdict=None, appendmat=True)['moviedata']
# Crop the movie spatially (offsets hard-coded for this stimulus file), move the last
# axis (one entry per time bin) to the front and flatten each frame into a row.
m.I = 1. * mv[:, 39:-40, :][13:-14, :, :].transpose(2, 0, 1).reshape(mv.shape[2], -1)
# z-score the whole stimulus with one global mean / std
m.I = (m.I - np.mean(m.I)) / np.std(m.I)
# same stimulus reshaped to (len(m.t), filter rows, filter cols) for the model's generator
m._gen.x = m.I.reshape(len(m.t), m.filter_shape[0], m.filter_shape[1])
m.n_params = 9  # presumably bias, gain, phase, freq, angle, ratio, width, xo, yo
# Deep copy: params_dict_init is nested, so a shallow .copy() would share its inner
# 'kernel'/'glm' dicts with the model, and any write through m.params_dict would
# silently change params_dict_init (later passed to setup_sampler and saved to disk).
m.params_dict = copy.deepcopy(params_dict_init)
m.rng = np.random.RandomState(seed=seed)

# generator: prior over the model parameters and the summary-statistics object
p, prior = get_maprf_prior_01(sim_info['params_ls'], seed)
n_summary_stats = d * d + 1  # d x d RF + spike_count
s = maprfStats(n_summary=n_summary_stats)
def rej(x):
    """Acceptance mask for a batch of summary statistics.

    The last column of ``x`` is the spike count. Rows with a positive count
    map to True; rows with zero spikes map to False and are rejected.
    """
    return np.greater(x[:, -1], 0)
g = dg.RejKernel(model=m,
                 prior=p,
                 summary=s,
                 rej=rej,
                 seed=seed)

# observations: one simulation at the initial parameter guess, plus its summary
# statistics and a snapshot of the model's parameter buffer
sim_model = g.model
obs_0 = sim_model.gen_single()
obs_init = g.summary.calc([obs_0])
pars_init = sim_model.read_params_buffer().copy()

# Recorded data for this cell / session: frame times and spike times.
# NOTE(review): '/cell0' + cell_id hard-codes one leading zero, i.e. it matches a
# two-digit folder name only for single-digit cell ids - confirm for idx_cell >= 10.
cell_dir = 'data/' + exp_id + '/cell0' + cell_id
rec_name = exp_id + '_cell' + cell_id + sess_id
stimTime = io.loadmat(cell_dir + '/StimTime/' + rec_name + '_FramebasedSpknum.mat')['framebasedSpk'][:, 0]
spkTime = io.loadmat(cell_dir + '/SpkTime/' + rec_name + '_SpikeTiming.mat')['resid_feature']

print([stimTime[0], stimTime[-1]])

# Bin spikes by frame, using stimTime as bin edges (spkTime / 1000 presumably converts
# ms to the units of stimTime - TODO confirm). len(stimTime) edges give
# len(stimTime) - 1 counts, so an empty first bin is prepended.
spks = np.histogram(spkTime / 1000, stimTime)[0]
# builtin int: np.int was only an alias for it and was removed in NumPy 1.24
spks = np.hstack([np.zeros(1), spks]).astype(int)

obs = {'data': spks,
       'I': g.model.I}
obs_stats = g.summary.calc([obs])

print(' observed data: ')
print('firing rate', obs_stats[0, -1] / (g.model.I.shape[0] * g.model.params_dict['glm']['binsize']))
print('spike count', obs_stats[0, -1])

print(' initial parameter guess: ')
print('spike count ', obs_init[0,-1])

# Side by side: observed STA, STA simulated at the initial guess, and the RF implied
# by the initial parameters.
fig, (ax_obs, ax_sta_init, ax_rf_init) = plt.subplots(1, 3, figsize=(12, 4))
ax_obs.imshow(obs_stats[0, :-1].reshape(d, d), interpolation='None')
ax_obs.set_title('observed STA')
ax_sta_init.imshow(obs_init[0, :-1].reshape(d, d), interpolation='None')
ax_sta_init.set_title('STA of initial parameter guess')
ax_rf_init.imshow(g.model.params_to_rf(pars_init)[0], interpolation='None')
ax_rf_init.set_title('RF of initial parameter guess')
plt.show()


# MCMC

In [None]:

n_samples = 10 ** 6  # number of MCMC samples to draw

inference, data = setup_sampler(prior, obs, d, g,
                                params_dict=params_dict_init,
                                fix_position=False,
                                parametrization='logit_φ')

# Take the mu / sd entries of the first sampler for the logit-space RF location
# from the corresponding prior mean / sigma.
sampler0 = inference.samplers[0]
for loc_key in ('logit_xo', 'logit_yo'):
    sampler0.mu[loc_key] = prior[loc_key]['mu'][0]
for loc_key in ('logit_xo', 'logit_yo'):
    sampler0.sd[loc_key] = prior[loc_key]['sigma'][0]

raw_traces, L = inference.sample(n_samples)
# re-key the traces by variable name
T = {var.name: trace for var, trace in raw_traces.items()}

In [None]:
def scaled_expit_i(v):
    """Rescaled logistic sigmoid, 2 * expit(v) - 1, mapping reals into [-1, 1].

    Works elementwise on NumPy arrays; used below to map the logit-space RF
    location samples back to xo / yo.
    """
    denom = 1. + np.exp(-v)
    return 2. / denom - 1

# map the logit-space RF location samples back to xo / yo
T['xo'], T['yo'] = scaled_expit_i(T['logit_xo']), scaled_expit_i(T['logit_yo'])

x, y = T['xo'], T['yo']

plt.figure(figsize=(15, 8))
plt.subplot(221)
plt.plot(x)
plt.plot(y)
plt.legend(('x-pos', 'y-pos'))
plt.xlabel('MCMC sample')

# `density` replaces `normed`, which matplotlib deprecated in 2.1 and removed in 3.1
plt.subplot(222)
plt.hist(x, alpha=0.5, density=True)
plt.title('x-positions')

plt.subplot(224)
plt.hist(y, alpha=0.5, density=True)
plt.title('y-positions')

plt.subplot(223)
plt.plot(x, y, '.k', alpha=0.1)
plt.title('x/y')
plt.show()

In [None]:
# sample_biases is expected to fill T['bias'] and T['λo'], which are read right below
inference.sample_biases(data, T, m.dt)

# one trace panel per quantity, each followed by its mean / variance
plt.figure(figsize=(12,5))
for row, trace_key in enumerate(('bias', 'λo'), start=1):
    trace = T[trace_key]
    plt.subplot(2, 1, row)
    plt.plot(trace)
    print('mean: ' + str(trace.mean()) + ', var: ' + str(trace.var()))
plt.show()


In [None]:
import theano

# Compile the model's log-likelihood graph into a callable as a quick sanity check.
loglik_inputs = list(inference.inputs.values())
loglik = theano.function(loglik_inputs,
                         inference.logL,
                         updates=[],
                         on_unused_input='warn')

# NOTE(review): called with no arguments, which only works if every input above
# supplies its own value - confirm, otherwise theano reports missing inputs.
loglik()

In [None]:
first_sampler = inference.samplers[0]
first_sampler.get_point()  # presumably the sampler's current point; shown as cell output

# inspect results

In [None]:
# random selection of MCMC samples

# Observed STA, zero-centred and scaled to unit L2 norm for side-by-side display with
# the sampled kernels (it is multiplied by each sample's gain below). The mean - not
# the sum - is subtracted; subtracting the sum shifts every pixel by the total and
# distorts the colour scale shared with the kernel in each panel.
sta = obs_stats[0, :-1].reshape(d, d)
sta = sta - sta.mean()
sta_norm = np.sqrt(np.sum(sta ** 2))
if sta_norm > 0:  # a constant STA would otherwise turn into NaNs
    sta = sta / sta_norm

n_chain = T['gain'].shape[0]
n_show = min(12, n_chain)  # up to 12 panels on a 3 x 4 grid
rng_display = np.random.RandomState(seed)  # seeded so the displayed subset is reproducible

# spatial-kernel parameters, in the order they are stored in params_dict['kernel']['s']
kernel_s_keys = ('phase', 'angle', 'freq', 'ratio', 'width', 'gain')

plt.figure(figsize=(16, 12))
for panel, t in enumerate(np.sort(rng_display.choice(n_chain, n_show, replace=False)), start=1):
    # `params_dict` and `t` intentionally stay bound after the loop: the next cell
    # re-simulates the last (latest) chain element shown here.
    params_dict = {'kernel': {'s': {key: T[key][t] for key in kernel_s_keys},
                              'l': {'xo': T['xo'][t], 'yo': T['yo'][t]}},
                   'glm': {'bias': T['bias'][t]}}

    # NOTE: shifts the model's spatial grid in place (mutates m._gen); the next cell
    # builds a fresh `m`.
    axis_x = m.axis_x - params_dict['kernel']['l']['xo']
    axis_y = m.axis_y - params_dict['kernel']['l']['yo']
    m._gen.grid_x, m._gen.grid_y = np.meshgrid(axis_x, axis_y)

    ks = m._eval_ks(bias=params_dict['glm']['bias'], **params_dict['kernel']['s'])

    # left: sampled spatial kernel, right: normalised observed STA scaled by the sample's gain
    plt.subplot(3, 4, panel)
    plt.imshow(np.hstack((ks.reshape(d, d), sta * params_dict['kernel']['s']['gain'])),
               interpolation='None')
    plt.title('t =' + str(t))

    print('loc:', [T['xo'][t], T['yo'][t]])

plt.show()


In [None]:
# closer look at a single (late) MCMC chain element
#
# Relies on `params_dict` and `t` left bound by the sample-display loop of the
# previous cell: the element re-simulated here is the last one plotted there.

m = model(filter_shape=([d, d, 2]),
          duration=300,
          dt=1./30.,
          params_ls=sim_info['params_ls'],
          seed=seed,
          parametrization=sim_info['parametrization'])

# same stimulus preprocessing as in the setup cell: crop, flatten frames, z-score
movie_cropped = mv[:, 39:-40, :][13:-14, :, :]
m.I = 1. * movie_cropped.transpose(2, 0, 1).reshape(mv.shape[2], -1)
m.I = (m.I - np.mean(m.I)) / np.std(m.I)
m._gen.x = m.I.reshape(len(m.t), m.filter_shape[0], m.filter_shape[1])
m.n_params = 9

# add the fixed entries the simulator needs to the sampled parameters
params_dict['glm']['binsize'] = 1./30
params_dict['kernel']['t'] = {'value': np.array([1., 0.])}

m.params_dict = params_dict
pars_raw = m.read_params_buffer()
obs_test = m.gen_single()
obs_stats_test = g.summary.calc([obs_test])

element_label = 'MCMC chain element #' + str(t)
fig, (ax_obs, ax_sta_test, ax_rf_test) = plt.subplots(1, 3, figsize=(12, 4))
ax_obs.imshow(obs_stats[0, :-1].reshape(d, d), interpolation='None')
ax_obs.set_title('observed STA')
ax_sta_test.imshow(obs_stats_test[0, :-1].reshape(d, d), interpolation='None')
ax_sta_test.set_title('STA of ' + element_label)
ax_rf_test.imshow(m.params_to_rf(pars_raw)[0], interpolation='None')
ax_rf_test.set_title('RF of ' + element_label)
plt.show()

# spike count of the re-simulated data
obs_stats_test[0, -1]

In [None]:
# distribution of MCMC samples

param_names = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']

# one row per chain element, one column per parameter (in param_names order)
samples = np.column_stack([T[name] for name in param_names])

# reference point: the chain element inspected in the previous cell, same order
kernel_s = params_dict['kernel']['s']
kernel_l = params_dict['kernel']['l']
pars_raw = np.array([params_dict['glm']['bias']]
                    + [kernel_s[name] for name in ('gain', 'phase', 'freq', 'angle', 'ratio', 'width')]
                    + [kernel_l[name] for name in ('xo', 'yo')])

plot_pdf(g.prior, lims=[-3, 3], gt=pars_raw.reshape(-1), figsize=(16, 16),
         resolution=100, samples=samples, ticks=True, labels_params=param_names);


# save results

In [None]:

# results/MCMC/<experiment>_cell<id><session>/ ; file name encodes the run settings
fldr = '{}_cell{}{}'.format(exp_id, cell_id, sess_id)
savefile = './results/MCMC/{}/maprf_MCMC_prior01_run_2_{}samples_param9_5min'.format(fldr, n_samples)
savefile

In [None]:
# The per-cell results folder may not exist yet; np.savez does not create directories.
os.makedirs(os.path.dirname(savefile), exist_ok=True)
# NOTE: the dict is passed positionally, so np.savez stores it pickled as a 0-d object
# array under the key 'arr_0'. Read it back with
#   np.load(savefile + '.npz', allow_pickle=True)['arr_0'][()]
np.savez(savefile, {'T' : T, 'params_dict_init' : params_dict_init})
