In [None]:
# Notebook setup: auto-reload edited project modules, import libraries, make the repository
# root importable, select the PyTorch device handle and apply the paper plotting style.
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator # for minor ticks
from matplotlib import gridspec
from matplotlib import patches
from matplotlib.offsetbox import TextArea, VPacker, AnnotationBbox

import numpy as np

import pickle

import sys, time
# Two levels up is added to the import path so the `neuroprob` package below can be found.
sys.path.append("../..")

from neuroprob import stats, tools, neural_utils
import neuroprob.models as mdl

# Device handle; later cells move every model onto it with `glm.to(dev)`.
dev = tools.PyTorch()

# Local helper module providing set_glm, compute_count_stats, compute_isi_stats, pred_ll, ...
import models

plt.style.use(['paper.mplstyle'])

### Place cells show overdispersion

In [None]:
# NOTE(review): `binning`, `spktrain`, `session_id`, `sample_bin` and `hd_t` are not defined
# anywhere in this notebook; a data-loading cell appears to be missing, so Restart & Run All
# fails here until it is restored.

# --- Count statistics (20-sample bins) for the HD-only GP rate model ------------------------
rcov, units_used, tbin, resamples, rc_t = binning(20, spktrain)

show_neuron = [5, 13, 24]
show_neurons = len(show_neuron)

# computation
ll_mode, r_mode, spk_cpl, num_induc = ('IP', 'hd', None, 8)
glm, cov_used = models.set_glm(r_mode, ll_mode, spk_cpl, rcov, units_used, tbin, rc_t, num_induc, inv_link='exp')
glm.to(dev)

model_name = 'GPR_{}_{}'.format(session_id, r_mode)
checkpoint = torch.load('./checkpoint/' + model_name)
glm.load_state_dict(checkpoint['glm'])

# Posterior tuning curves over a full circle of head directions. `covariates`, `lower`,
# `mean` and `upper` are exported by the pickle cell at the end of the notebook, so later
# cells must not reuse these names.
steps = 100
covariates = np.linspace(0, 2*np.pi, steps)[None, :]

lower, mean, upper = glm.rate_model[0].eval_rate(covariates, show_neuron, 'posterior')
cnt_tuples_pl, _ = models.compute_count_stats(glm, 'IP', tbin, rc_t, cov_used, show_neuron,
                                              traj_len=100, start=0, T=resamples, bs=10000)


# --- Large-bin empirical rates ---------------------------------------------------------------
bsize = 1000
bin_time, bin_samples, bc_t, (bhd_t,) = neural_utils.BinTrain(bsize, sample_bin, spktrain,
                                                    spktrain.shape[1], (hd_t,),
                                                    average_behav=False, binned=True)

brate = bc_t[show_neuron, :]/bin_time


# --- ISI statistics (1-sample bins) for the HD-only point-process model ----------------------
rcov, units_used, tbin, resamples, rc_t = binning(1, spktrain)

ll_mode, r_mode, spk_cpl, num_induc = ('IP', 'hd', None, 8)
rr_mode = r_mode + (spk_cpl if spk_cpl is not None else '')

k = 0           # index of the data segment the GPPP checkpoint was fitted on
trunc = 250000  # samples per segment
segment = slice(k*trunc, (k+1)*trunc)
trc_behav = [rc[segment] for rc in rcov]
Y = (rc_t[:, segment] > 0) # correct for duplicate spikes
# Keep the covariates returned here: they match the 1-sample binning and the truncated
# segment. Previously they were discarded and the 20-sample-bin `cov_used` from the top of
# this cell was passed to compute_isi_stats alongside the 1-sample `rc_t` / `tbin`.
glm, cov_used_isi = models.set_glm(r_mode, ll_mode, spk_cpl, trc_behav, units_used, sample_bin, Y, num_induc, inv_link='exp', jitter=1e-4)
glm.to(dev)

model_name = 'GPPP_{}_{}_{}_{}'.format(session_id, ll_mode, rr_mode, k)
checkpoint = torch.load('./checkpoint/' + model_name)
glm.load_state_dict(checkpoint['glm'])

isi_tuples_pl = models.compute_isi_stats(glm, ll_mode, tbin, rc_t, cov_used_isi, show_neuron, start=0, T=trunc, bs=10000)

In [None]:
# NOTE(review): `binning`, `spktrain`, `session_id`, `left_x`, `right_x`, `bottom_y` and
# `top_y` are not defined anywhere in this notebook; a data-loading cell appears to be missing.
rcov, units_used, tbin, resamples, rc_t = binning(20, spktrain)

show_neuron = [9, 11, # PoS
               15, 25] # ANT
show_neurons = len(show_neuron)
all_units = np.arange(units_used)

# Local names in this cell are deliberately distinct from `covariates` / `mean` of the
# overdispersion cell. Those globals are exported at the end of the notebook, and reusing the
# names here silently replaced the HD tuning curves in ./output/ca1.p.


# --- HD x AHV model ----------------------------------------------------------------------------
ll_mode, r_mode, spk_cpl, num_induc = ('IP', 'hd_w', None, 16)
glm, cov_used = models.set_glm(r_mode, ll_mode, spk_cpl, rcov, units_used, tbin, rc_t, num_induc, inv_link='exp')
glm.to(dev)

model_name = 'GPR_{}_{}'.format(session_id, r_mode)
checkpoint = torch.load('./checkpoint/' + model_name)
glm.load_state_dict(checkpoint['glm'])

# Preferred HD of every unit at each angular velocity, taken as the circular centre of mass
# of its mean tuning curve (the tuning-curve argmax would be a cruder alternative).
ATI_N = 51
w_arr = np.linspace(-5., 5., ATI_N)  # odd count, so w_arr[ATI_N//2] is the (near-)zero AHV
hd_steps = 300
hd_grid = np.linspace(0, 2*np.pi, hd_steps+1)[:-1]  # drop 2*pi, which duplicates 0
Z = np.cos(hd_grid) + np.sin(hd_grid)*1j # CoM angle: unit phasor per grid direction
pref_hd = []
for w_eval in w_arr:
    hdw_covariates = [hd_grid, w_eval*np.ones(hd_steps)]
    mean_rate = glm.rate_model[0].eval_rate(hdw_covariates, all_units, 'mean')
    pref_hd.append(np.angle((Z[None, :]*mean_rate).mean(-1)) % (2*np.pi))
pref_hd = np.array(pref_hd)  # rows: AHV values in w_arr; columns indexed by unit below

# ATI: circular-linear regression of preferred HD on AHV/(2*pi) per unit; `a` is presumably
# the fitted slope, sign-flipped to give the ATI. The final loss is kept as residual variance.
ATI = []
res_var = []
for k in range(units_used):
    _, a, shift, losses = tools.circ_lin_regression(pref_hd[:, k], w_arr/(2*np.pi), dev='cpu', iters=1000, lr=1e-2)
    ATI.append(-a)
    res_var.append(losses[-1])
ATI = np.array(ATI)
res_var = np.array(res_var)


# Tuning of the shown units over an (HD, AHV) mesh.
grid_size_hdw = (50, 40)
grid_shape_hdw = [[0, 2*np.pi], [-20., 20.]]


def _hdw_rate(pos):
    """Rates of `show_neuron` on a mesh of (HD, AHV) points, shaped (show_neurons, *mesh)."""
    prevshape = pos.shape[1:]
    mesh_covariates = np.array([pos[0].flatten(), pos[1].flatten()])
    return glm.rate_model[0].eval_rate(mesh_covariates, show_neuron).reshape(show_neurons, *prevshape)


_, field_hdw = tools.compute_mesh(grid_size_hdw, grid_shape_hdw, _hdw_rate)


# --- HD x AHV x speed model --------------------------------------------------------------------
ll_mode, r_mode, spk_cpl, num_induc = ('IP', 'hd_w_s', None, 20)
glm, cov_used = models.set_glm(r_mode, ll_mode, spk_cpl, rcov, units_used, tbin, rc_t, num_induc, inv_link='exp')
glm.to(dev)

model_name = 'GPR_{}_{}'.format(session_id, r_mode)
checkpoint = torch.load('./checkpoint/' + model_name)
glm.load_state_dict(checkpoint['glm'])

# Posterior speed tuning of each shown unit at its preferred HD and zero AHV.
mean_s = []
lower_s = []
upper_s = []
steps = 100
covariates_s = np.linspace(0, 30., steps)
for n in show_neuron:
    speed_covariates = [pref_hd[ATI_N//2, n]*np.ones(steps), # pref hd at zero AHV
                        np.zeros(steps),
                        covariates_s]
    l, m, u = glm.rate_model[0].eval_rate(speed_covariates, [n], 'posterior', n_samp=100000)
    lower_s.append(l[0])
    mean_s.append(m[0])
    upper_s.append(u[0])


# --- HD x AHV x speed x position model ----------------------------------------------------------
ll_mode, r_mode, spk_cpl, num_induc = ('IP', 'hd_w_s_pos', None, 32)
glm, cov_used = models.set_glm(r_mode, ll_mode, spk_cpl, rcov, units_used, tbin, rc_t, num_induc, inv_link='exp')
glm.to(dev)

model_name = 'GPR_{}_{}'.format(session_id, r_mode)
checkpoint = torch.load('./checkpoint/' + model_name)
glm.load_state_dict(checkpoint['glm'])

grid_size_pos = (50, 40)
grid_shape_pos = [[left_x, right_x], [bottom_y, top_y]]


def _pos_rate(pos, n):
    """Rate of unit `n` over an (x, y) mesh at its preferred HD, zero AHV and zero speed."""
    prevshape = pos.shape[1:]
    x = pos[0].flatten()
    y = pos[1].flatten()
    hd = pref_hd[ATI_N//2, n]*np.ones_like(x)
    w = 0.*np.ones_like(x)
    s = 0.*np.ones_like(x)
    mesh_covariates = np.array([hd, w, s, x, y])
    return glm.rate_model[0].eval_rate(mesh_covariates, n).reshape(*prevshape)


# Each lambda is consumed immediately by compute_mesh, so binding `n` late is safe here.
field_pos = np.stack([tools.compute_mesh(grid_size_pos, grid_shape_pos, lambda pos: _pos_rate(pos, n))[1]
                      for n in show_neuron])

In [None]:
# Goodness of fit of nested GP rate models (HD -> +AHV -> +speed -> +position): per-fold
# validation log-likelihoods and count-statistic measures for every unit. Uses the
# 20-sample binning (`rcov`, `rc_t`, `tbin`, `resamples`) set up by the previous cell.
modes = [('IP', 'hd', None, 8),
         ('IP', 'hd_w', None, 16),
         ('IP', 'hd_w_s', None, 20),
         ('IP', 'hd_w_s_pos', None, 32)]

T_DS_arr = []
T_KS_arr = []
p_DS_arr = []
I_arr = []
pred_LL_arr = []

folds = 10
# Derived from `folds` so that changing the split can never index a non-existent fold.
cv_runs = np.arange(folds)
all_units = np.arange(units_used)

for m in modes:
    ll_mode, r_mode, spk_cpl, num_induc = m
    print(m)

    T_DS_ = []  # collects the Z_DS statistic of each unit
    T_KS_ = []
    p_DS_ = []
    pred_LL_ = []

    glm, cov_used = models.set_glm(r_mode, ll_mode, spk_cpl, rcov, units_used, tbin, rc_t, num_induc, inv_link='exp')
    glm.to(dev)

    model_name = 'GPR_{}_{}'.format(session_id, r_mode)
    checkpoint = torch.load('./checkpoint/' + model_name)
    glm.load_state_dict(checkpoint['glm'])

    # NOTE(review): a single checkpoint per model (no fold index in its name) is scored on
    # every validation fold. Unless the GPR_* checkpoints were trained on held-out splits,
    # these likelihoods are not out-of-sample — confirm against the training script.
    cv_set = neural_utils.SpikeTrainCV(folds, rc_t, rc_t.shape[1], rcov)
    for kcv in cv_runs:
        _, _, vtrain, vcov = cv_set[kcv]  # only the validation split is used here
        vcov_used = models.cov_used(r_mode, vcov)
        pred_LL_.append(models.pred_ll(glm, vtrain, vcov_used, vtrain.shape[-1], all_units))

    cnt_tuples, I = models.compute_count_stats(glm, ll_mode, tbin, rc_t, cov_used, all_units, 100,
                                               start=0, T=resamples, bs=5000)
    # Each tuple is (q_cdf, Z_DS, T_KS, s_DS, s_KS, p_DS, p_KS); only three fields are kept.
    for _, Z_DS, T_KS, _, _, p_DS, _ in cnt_tuples:
        T_DS_.append(Z_DS)
        T_KS_.append(T_KS)
        p_DS_.append(p_DS)

    T_DS_arr.append(T_DS_)
    T_KS_arr.append(T_KS_)
    p_DS_arr.append(p_DS_)
    I_arr.append(I)
    pred_LL_arr.append(pred_LL_)

T_DS_arr = np.array(T_DS_arr)
T_KS_arr = np.array(T_KS_arr)
p_DS_arr = np.array(p_DS_arr)
I_arr = np.array(I_arr)
pred_LL_arr = np.array(pred_LL_arr)

In [None]:
# Rate maps of all units over an (x, y) arena mesh, with the third covariate fixed at 0.
# NOTE(review): `gp_lvm`, `left_x`, `right_x`, `bottom_y` and `top_y` are not defined anywhere
# in this notebook, so this cell raises NameError on a fresh kernel. It looks carried over from
# another analysis: its covariate order [x, y, theta] differs from the HD-first order of the
# GPR models above. Confirm the intended model before reusing it.
grid_size = [50, 40]
grid_shape = [[left_x, right_x], [bottom_y, top_y]]
# N_x / N_y are not used below (left over from a fixed N_x*N_y unit grid).
N_x = 6
N_y = 6
neurons = units_used#N_x*N_y

def func(pos):
    """Evaluate `gp_lvm` rates for units 0..neurons-1 on a mesh of positions.

    `pos` is stacked as (2, *mesh); x and y are flattened, theta is fixed at zero, and the
    result is reshaped to (neurons, *mesh).
    """
    prevshape = pos.shape[1:]
    x = pos[0].flatten()
    y = pos[1].flatten()
    theta = np.zeros_like(x)
    covariates = [x, y, theta]
    return gp_lvm.eval_rate(covariates, np.arange(neurons)).reshape(neurons, *prevshape)

(xx, yy), place_field = tools.compute_mesh(grid_size, grid_shape, func)




In [None]:
# Export the figure data for the plotting notebook to ./output/ca1.p.
# NOTE(review): T_DS_arrl, T_KS_arrl, traj_arr, T_DS_arrs, T_KS_arrs, I_arrs, pred_LL and
# shift_times are not defined anywhere in this notebook. The goodness-of-fit cell produces
# T_DS_arr, T_KS_arr, p_DS_arr, I_arr and pred_LL_arr instead, so this cell raises NameError on
# a fresh kernel. Also verify that `covariates` and `mean` still hold the overdispersion-cell
# values (later cells must not reuse those names). Reconcile the tuple with its loader.
export = (cnt_tuples_pl, isi_tuples_pl, covariates, lower, mean, upper, brate, bhd_t,
          T_DS_arrl, T_KS_arrl, traj_arr, T_DS_arrs, T_KS_arrs, I_arrs, pred_LL,
          show_neuron, tbin, shift_times)
# `with` guarantees the file is flushed and closed; the bare open() left the handle open.
with open('./output/ca1.p', 'wb') as f:
    pickle.dump(export, f)