In [None]:
#%%capture
%matplotlib inline

import copy
import os

import matplotlib.pyplot as plt
import numpy as np

#import pygpu
#import theano.gpuarray
#theano.gpuarray.use('cuda0')

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.utils import get_maprf_prior_01, setup_sim, setup_sampler, \
get_data_o, quick_plot, contour_draws
from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats

# Global RNG seed, passed to the model, the prior and the generator in the cells below.
seed = 42

In [None]:
# NOTE: this rebinds `io` from delfi.utils.io (imported above) to scipy.io;
# the spike-data cells below rely on `io.loadmat`.
from scipy import io

# Movie stimulus; frames lie on the last axis of `mv` (see the indexing below).
file_name = ('data/'
             'ActualMoviecat-noise_0.12_1.25_0.50_300_10_120px_4Hz_mag14.mat')
out = io.loadmat(file_name, appendmat=True, mdict=None)
mv = out['moviedata']

In [None]:
# Simulation settings saved by an earlier run. The .npy holds a pickled dict
# (indexed by key below): NumPy >= 1.16.3 refuses to unpickle unless
# allow_pickle=True (only load files you trust). `[()]` unwraps the 0-d object
# array into the dict itself.
path = '.'
sim_info = np.load(path + '/results/sim_info.npy', allow_pickle=True)[()]

d = 41  # side length of the square d x d receptive-field filter
m = model(filter_shape=[d, d, 2],
          duration=300,
          dt=1./30.,
          params_ls=sim_info['params_ls'],
          seed=seed,
          parametrization=sim_info['parametrization'])

In [None]:

# Stimulus seen by the model: crop the movie (rows 13:-14, cols 39:-40) to a
# d x d patch, flatten each frame (frames are on mv's last axis) into one row
# of m.I, then z-score over the whole stimulus.
crop_rows, crop_cols = slice(13, -14), slice(39, -40)
m.I = 1. * mv[crop_rows, crop_cols, :].transpose(2, 0, 1).reshape(mv.shape[2], -1)
m.I = (m.I - np.mean(m.I)) / np.std(m.I)
m.n_params = 9

# Visual check of one frame: full movie frame, model input, column-cropped frame.
i = 200
panels = (((2, 1, 1), mv[:, :, i]),
          ((2, 2, 3), m.I[i, :].reshape(d, d)),
          ((2, 2, 4), mv[:, 26:-26, i]))
for position, frame in panels:
    plt.subplot(*position)
    plt.imshow(frame, interpolation='None')
plt.show()


In [None]:
# Prior over the RF parameters and the summary statistics
# (d x d spike-triggered image + total spike count as the last entry).
p, prior = get_maprf_prior_01(sim_info['params_ls'], seed)
s = maprfStats(n_summary=d*d+1)  # summary stats (d x d RF + spike_count)


def rej(x):
    """Rejection rule for the generator: keep simulations with at least one spike.

    x : summary statistics with one row per simulation (assumed) whose last
        column is the spike count.
    Returns a boolean mask, True where the simulation is kept.
    """
    return x[:, -1] > 0


# generator object that auto-rejects some data-pairs (theta_i, x_i) right at sampling
g = dg.RejKernel(model=m, prior=p, summary=s, rej=rej, seed=seed)

# Parameters of toy cell #6, stored as a pickled dict (hence allow_pickle=True).
toy_cell_idx = 6
params_dict_true = np.load('./results/toy_cells/toy_cell_' + str(toy_cell_idx) + '.npy',
                           allow_pickle=True)[()]
params_dict_true['glm']['binsize'] = g.model.dt
g.model.rng = np.random.RandomState(seed=seed)
# Deep copy: dict.copy() is shallow, so the nested 'glm'/'kernel' dicts would
# otherwise be shared between the model and params_dict_true (used later as
# the reference parameters).
g.model.params_dict = copy.deepcopy(params_dict_true)


In [None]:
# Recording to analyse: experiment R150902ct01, cell 5, session 'b'.
idx_cell = 5
cell_id = str(idx_cell)
sess_id = 'b'
recording = 'R150902ct01_cell' + cell_id + sess_id
cell_dir = 'data/cell0' + cell_id

# Stimulus timestamps (first column of 'framebasedSpk') and recorded spike times.
stimTime = io.loadmat(cell_dir + '/StimTime/' + recording + '_FramebasedSpknum.mat')['framebasedSpk'][:, 0]
spkTime = io.loadmat(cell_dir + '/SpkTime/' + recording + '_SpikeTiming.mat')['resid_feature']

print([stimTime[0], stimTime[-1]])

# Sanity check: do the spike times (/ 1000) fall between the first and last
# stimulus timestamp?
plt.plot(spkTime / 1000)
plt.plot([0, spkTime.size], stimTime[0] * np.ones(2), 'r')
plt.plot([0, spkTime.size], stimTime[-1] * np.ones(2), 'm')
plt.legend(['spike times', 'stim onset', 'stim offset'])
plt.xlabel('# spike')
plt.ylabel('spike time / 1000 (in seconds?)')
# Name the file after the recording actually analysed; it was hard-coded to
# cell3b and would overwrite that cell's check plot with this cell's data.
plt.savefig(recording + '_spikes_check.pdf')
plt.show()

# Spike counts between consecutive stimulus timestamps, with a leading 0 so
# there is one count per element of stimTime. `np.int` was removed in
# NumPy 1.24; it was an alias of the builtin `int`.
spks = np.histogram(spkTime / 1000, stimTime)[0]
spks = np.hstack([np.zeros(1), spks]).astype(int)
plt.plot(stimTime, spks)
plt.show()

# Observed data in the generator's format, and its summary statistics.
obs = {'data': spks, 'I': g.model.I}
obs_stats = g.summary.calc([obs])

# Spike-weighted sums of the full movie and of the model's cropped stimulus,
# next to the d x d part of the summary statistics (all but the spike count).
plt.subplot(2, 1, 1)
plt.imshow(mv.dot(spks), interpolation='None', cmap='gray')
plt.subplot(2, 2, 3)
plt.imshow(g.model.I.reshape(-1, d, d).transpose(1, 2, 0).dot(spks), interpolation='None', cmap='gray')
plt.subplot(2, 2, 4)
plt.imshow(obs_stats[0, :-1].reshape(d, d), interpolation='None', cmap='gray')
plt.show()

plt.imshow(obs_stats[0, :-1].reshape(d, d), interpolation='None', cmap='gray')
plt.savefig(recording + '_STA.pdf')
plt.show()

In [None]:
# Mean firing rate: total spike count (last summary statistic) divided by the
# stimulus duration (number of stimulus frames x bin size).
stimulus_duration = g.model.I.shape[0] * g.model.params_dict['glm']['binsize']
print('firing rate', obs_stats[0, -1] / stimulus_duration)


In [None]:
# Rebuild model, generator and observations in one place, this time starting
# from a hand-picked initial parameter setting (params_dict_init).
path = '.'
# Pickled dict inside a .npy: NumPy >= 1.16.3 requires allow_pickle=True
# (only load files you trust); `[()]` unwraps the 0-d object array.
sim_info = np.load(path + '/results/sim_info.npy', allow_pickle=True)[()]
d = 41

params_dict_init = {'kernel': {'s': {'ratio': 1.4,
                                     'width': 1.5,
                                     'gain': 0.5,
                                     'phase': np.pi/2 - 0.0001,
                                     'freq': 1.5,
                                     'angle': 0.9 * np.pi/2},
                               't': {'value': np.array([1., 0.])},
                               'l': {'xo': 0.5, 'yo': 0.1}},
                    'glm': {'bias': 1.5,
                            'binsize': 1./30}}

# model

m = model(filter_shape=[d, d, 2],
          duration=300,
          dt=1./30.,
          params_ls=sim_info['params_ls'],
          seed=seed,
          parametrization=sim_info['parametrization'])

# Crop the movie to the d x d patch (rows 13:-14, cols 39:-40), one row per
# frame, and z-score over the whole stimulus.
m.I = 1. * mv[:, 39:-40, :][13:-14, :, :].transpose(2, 0, 1).reshape(mv.shape[2], -1)
m.I = (m.I - np.mean(m.I)) / np.std(m.I)
m.n_params = 9

# Deep copy: dict.copy() is shallow, so the nested 'kernel'/'glm' dicts would be
# shared between the model and params_dict_init, which is also passed to the
# sampler and saved with the results further down.
m.params_dict = copy.deepcopy(params_dict_init)

# generator

p, prior = get_maprf_prior_01(sim_info['params_ls'], seed)
s = maprfStats(n_summary=d*d+1)  # summary stats (d x d RF + spike_count)


def rej(x):
    """Rejection rule: keep rows whose last summary statistic (spike count) is > 0."""
    return x[:, -1] > 0


g = dg.RejKernel(model=m, prior=p, summary=s, rej=rej, seed=seed)

# observations

# One simulation, presumably from the parameters just set on the model.
obs_0 = g.model.gen_single()
obs_init = g.summary.calc([obs_0])

idx_cell, sess_id = 5, 'b'
cell_id = str(idx_cell)
recording = 'R150902ct01_cell' + cell_id + sess_id
cell_dir = 'data/cell0' + cell_id
stimTime = io.loadmat(cell_dir + '/StimTime/' + recording + '_FramebasedSpknum.mat')['framebasedSpk'][:, 0]
spkTime = io.loadmat(cell_dir + '/SpkTime/' + recording + '_SpikeTiming.mat')['resid_feature']
# `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.
spks = np.hstack([np.zeros(1), np.histogram(spkTime / 1000, stimTime)[0]]).astype(int)

obs = {'data': spks, 'I': g.model.I}
obs_stats = g.summary.calc([obs])
# NOTE(review): despite the name, this reads whatever parameters the model
# currently holds (presumably those from params_dict_init), not a ground truth.
pars_true = g.model.read_params_buffer().copy()
rf = g.model.params_to_rf(pars_true)[0]

sta = obs_stats[0, :-1].reshape(d, d)

# simulated summary stats | RF of the model parameters | STA of the recording
plt.figure(figsize=(12, 4))
plt.imshow(np.hstack((obs_init[0, :-1].reshape(d, d), rf, sta)), interpolation='None')
plt.show()

print('#spikes : ', (obs_0['data'].sum(), obs_init[0, -1]))

In [None]:

# Run the MCMC sampler on the recorded spike train, initialised at params_dict_init.
n_samples = 500000

inference, data = setup_sampler(prior, obs, d, g, params_dict=params_dict_init,
                                fix_position=False, parametrization='logit_φ')

# Set the first sampler's mu / sd for the (logit-space) RF location to the
# prior's mean and sigma.
location_keys = ('logit_xo', 'logit_yo')
sampler = inference.samplers[0]
for key in location_keys:
    sampler.mu[key] = prior[key]['mu'][0]
for key in location_keys:
    sampler.sd[key] = prior[key]['sigma'][0]

T, L = inference.sample(n_samples)
# Re-key the traces by variable name.
T = {var.name: trace for var, trace in T.items()}

In [None]:
def scaled_expit_i(v):
    """Map logit-space values to (-1, 1): ``2 * expit(v) - 1``.

    Evaluated as the mathematically identical ``tanh(v / 2)``. Unlike
    ``2 / (1 + exp(-v)) - 1`` it does not overflow ``np.exp`` (RuntimeWarning)
    for large negative ``v``, and it keeps full relative precision near 0.

    v : scalar, ndarray or array-like (lists are accepted).
    Returns an ndarray (or NumPy scalar) with the shape of ``v``.
    """
    return np.tanh(np.asarray(v) / 2.)

# Map the logit-space location traces back to (-1, 1) coordinates.
T['xo'], T['yo'] = scaled_expit_i(T['logit_xo']), scaled_expit_i(T['logit_yo'])

x, y = T['xo'], T['yo']

plt.figure(figsize=(15, 8))
plt.subplot(221)
plt.plot(x)
plt.plot(y)
plt.title('location traces')
plt.xlabel('sample')
plt.legend(['xo', 'yo'])

# `normed` was removed from plt.hist in Matplotlib 3.1; `density` replaces it.
plt.subplot(222)
plt.hist(x, alpha=0.5, density=True)
plt.title('xo')

plt.subplot(224)
plt.hist(y, alpha=0.5, density=True)
plt.title('yo')

plt.subplot(223)
plt.plot(x, y, '.k', alpha=0.1)
plt.xlabel('xo')
plt.ylabel('yo')
plt.show()

In [None]:
# Sample the bias terms (presumably adding the 'bias' / 'λo' traces to T),
# then plot each trace and print its mean and variance.
inference.sample_biases(data, T, m.dt)

plt.figure(figsize=(12,5))
for row, name in enumerate(['bias', 'λo'], start=1):
    plt.subplot(2, 1, row)
    trace = T[name]
    plt.plot(trace)
    print('mean: ' + str(trace.mean()) + ', var: ' + str(trace.var()))
plt.show()

In [None]:
# Centre the STA (subtract its mean; the sum would shift every pixel by d*d
# times the mean) and divide by the L2 norm of the raw STA.
sta = obs_stats[0, :-1].reshape(d, d)
sta = (sta - sta.mean()) / np.sqrt(np.sum(sta**2))

# Reproducible choice of 12 posterior samples to display.
rng_display = np.random.RandomState(seed)
idx_show = np.sort(rng_display.choice(T['gain'].shape[0], 12, replace=False))

# The loop shifts the model's private spatial grid per sample; keep the
# current grid so the model is left unchanged afterwards.
# NOTE(review): assumes m._gen already has grid_x / grid_y set.
grid_orig = (m._gen.grid_x, m._gen.grid_y)

plt.figure(figsize=(16, 12))
try:
    for i, t in enumerate(idx_show, start=1):
        # Shift the grid to this sample's RF location.
        axis_x = m.axis_x - T['xo'][t]
        axis_y = m.axis_y - T['yo'][t]
        m._gen.grid_x, m._gen.grid_y = np.meshgrid(axis_x, axis_y)

        ks = m._eval_ks(bias=T['bias'][t],
                        angle=T['angle'][t],
                        freq=T['freq'][t],
                        gain=T['gain'][t],
                        phase=T['phase'][t],
                        ratio=T['ratio'][t],
                        width=T['width'][t])

        # sampled kernel | gain-scaled STA
        plt.subplot(3, 4, i)
        plt.imshow(np.hstack((ks.reshape(d, d), sta * T['gain'][t])),
                   interpolation='None')
        plt.title('t =' + str(t))

        print('loc:', [T['xo'][t], T['yo'][t]])
finally:
    m._gen.grid_x, m._gen.grid_y = grid_orig

plt.show()


In [None]:

burnin = 50  # number of initial MCMC samples dropped before summarising

for key in ['bias', 'λo',
            'gain', 'log_A', 'phase', 'logit_φ',
            'angle', 'logit_θ', 'freq', 'log_f',
            'ratio', 'width', 'log_γ', 'log_b',
            'xo', 'yo', 'logit_xo', 'logit_yo'
            ]:
    if key not in T:
        continue

    x = T[key][burnin:]
    # 20 edges over the sample range; a constant trace would give zero-width
    # bins, so fall back to matplotlib's default binning in that case.
    lo, hi = x.min(), x.max()
    bins = np.linspace(lo, hi, 20) if hi > lo else 20
    # `normed` was removed from plt.hist in Matplotlib 3.1; `density` replaces it.
    plt.hist(x, bins=bins, alpha=0.5, density=True)
    plt.title(key)
    plt.xlabel(key)
    plt.show()
    print('mean:', x.mean())
    print('var:', x.var())
        

In [None]:
# Posterior samples as an (n_samples, 9) array, columns in this order.
param_names = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']
samples = np.column_stack([T[name] for name in param_names])

# Reference parameters in the same column order.
# NOTE(review): params_dict_true holds toy cell 6's parameters, while the
# samples were fitted to the recorded spikes; confirm this is the intended `gt`.
kernel_s = params_dict_true['kernel']['s']
kernel_l = params_dict_true['kernel']['l']
pars_raw = np.array([params_dict_true['glm']['bias']]
                    + [kernel_s[k] for k in ('gain', 'phase', 'freq', 'angle', 'ratio', 'width')]
                    + [kernel_l['xo'], kernel_l['yo']])

plot_pdf(g.prior, lims=[-3, 3], gt=pars_raw.reshape(-1), figsize=(16, 16), resolution=100,
         samples=samples.T, ticks=True, labels_params=list(param_names));


In [None]:
# Side by side: summary stats of the simulation from the initial parameters,
# the RF implied by pars_raw, and the d x d summary stats of the recording.
rf = g.model.params_to_rf(pars_raw)[0]
sim_init_img = obs_init[0, :-1].reshape(d, d)
sta_img = obs_stats[0, :-1].reshape(d, d)

plt.figure(figsize=(12, 4))
plt.imshow(np.hstack([sim_init_img, rf, sta_img]), interpolation='None')
plt.show()


In [None]:
'R150902ct01_cell{}{}'.format(cell_id, sess_id)

In [None]:

# Results go to ./results/MCMC/<recording>/; the file name records the run settings.
fldr = 'R150902ct01_cell{}{}'.format(cell_id, sess_id)
savefile = './results/MCMC/{}/maprf_MCMC_prior01_run_1_{}samples_param9_5min'.format(fldr, n_samples)


In [None]:
# np.savez raises if the output folder does not exist, which would lose the
# samples of the long MCMC run; create it first.
os.makedirs(os.path.dirname(savefile), exist_ok=True)
# NOTE: the dict is passed positionally, so it is stored as a pickled 0-d object
# array under the key 'arr_0' (and '.npz' is appended to the file name). Load it
# back with np.load(..., allow_pickle=True)['arr_0'][()].
np.savez(savefile, {'T' : T, 'params_dict_init' : params_dict_init})
