In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt

import numpy as np

import pickle 

import sys
sys.path.append("../..")

from neuroprob import stats, tools, neural_utils, animal
import neuroprob.models as mdl

dev = tools.PyTorch()


plt.style.use(['paper.mplstyle'])

PyTorch version: 1.6.0+cu101
Using device: cuda:0


### Data

In [None]:
# Load data

In [None]:
behav_cov = hd_t, w_t, s_t, x_t, y_t

# binning
def binning(bin_size, spktrain):
    """
    Re-bin spikes and behaviour to ``bin_size`` and derive velocity covariates.

    Head direction is unwrapped before binning (so averaging inside a bin does not
    straddle the 0/2pi seam) and wrapped back afterwards. Angular velocity and
    running speed are forward finite differences divided by the new bin width,
    with the final value repeated so every series has the binned length.

    Reads the notebook globals ``sample_bin``, ``hd_t``, ``x_t`` and ``y_t``.

    Returns
    -------
    rcov : tuple
        (wrapped hd, angular velocity, speed, x, y, time) at the new binning.
    units_used : int
        ``rc_t.shape[0]``.
    tbin, resamples, rc_t
        Passed through from ``neural_utils.BinTrain``.
    """
    tbin, resamples, rc_t, binned_behav = neural_utils.BinTrain(
        bin_size, sample_bin, spktrain, spktrain.shape[1],
        (np.unwrap(hd_t), x_t, y_t), average_behav=True, binned=True)
    rhd_t, rx_t, ry_t = binned_behav

    def _pad_last(series):
        # repeat the final sample so a differenced series regains its original length
        return np.concatenate((series, series[-1:]))

    rw_t = _pad_last((rhd_t[1:] - rhd_t[:-1]) / tbin)

    rvx_t = (rx_t[1:] - rx_t[:-1]) / tbin
    rvy_t = (ry_t[1:] - ry_t[:-1]) / tbin
    rs_t = _pad_last(np.sqrt(rvx_t**2 + rvy_t**2))

    rtime_t = np.arange(resamples) * tbin
    rcov = (tools.WrapPi(rhd_t, True), rw_t, rs_t, rx_t, ry_t, rtime_t)
    return rcov, rc_t.shape[0], tbin, resamples, rc_t

### Models

In [None]:
def GP_params(units, tbin, behav_tuple, num_induc, max_speed=100., variance=None, rng=None):
    """
    Set up a sparse GP rate model with inputs (x, y, speed, theta, hd).

    Kernel: a per-unit variance times a Euclidean RBF over (x, y, speed) and a
    toroidal RBF over the two angles (theta, hd). Kernel parameters are not shared
    across units and the GP uses an exp inverse link.

    Parameters
    ----------
    units : int
        Number of neurons modelled.
    tbin : float
        Time bin width, forwarded to ``set_params``.
    behav_tuple : sequence
        Behavioural covariates. The first five entries are expected in the order
        (x, y, speed, theta, hd); only their count is checked here.
    num_induc : int
        Number of inducing points (the same set is repeated for every unit).
    max_speed : float, optional
        Upper end of the uniform range for speed inducing points
        (default 100., the previously hard-coded value).
    variance : array_like, optional
        Per-unit kernel variance; defaults to ``np.ones(units)``.
        NOTE(review): this used to come from an undefined global ``v``.
    rng : optional
        Object with a ``random(size)`` method (e.g. ``np.random.default_rng(seed)``)
        used to draw inducing points. Defaults to the global ``np.random`` state,
        which produces the same draws as the previous ``np.random.rand`` calls.

    Returns
    -------
    mdl.nonparametrics.Gaussian_process

    Notes
    -----
    The spatial inducing-point range comes from the notebook globals
    ``left_x``, ``right_x``, ``bottom_y`` and ``arena_height``.
    """
    if rng is None:
        rng = np.random
    if variance is None:
        variance = np.ones(units)

    # inducing-point coordinates per input dimension, in (x, y, speed, theta, hd) order
    ind_list = [np.linspace(left_x, right_x, num_induc),
                bottom_y + arena_height * rng.random(num_induc),
                rng.random(num_induc) * max_speed,
                rng.random(num_induc) * 2 * np.pi,
                rng.random(num_induc) * 2 * np.pi]
    n_dims = len(ind_list)
    if len(behav_tuple) < n_dims:
        raise IndexError(f'behav_tuple needs {n_dims} covariates (x, y, speed, theta, hd), '
                         f'got {len(behav_tuple)}')

    l = 10. * np.ones(units)      # spatial lengthscales (x and y)
    l_s = 10. * np.ones(units)    # speed lengthscale
    l_ang = 2. * np.ones(units)   # angular lengthscales (theta, hd)
    kt = [('variance', variance),
          ('RBF', 'euclid', np.array([l, l, l_s])),
          ('RBF', 'torus', np.array([l_ang, l_ang]))]

    # one VI tuple per input dimension
    VI_tuples = [(None, None, None, 1)] * n_dims

    # shape (units, num_induc, n_dims): the same inducing set for every unit
    inducing_points = np.array(ind_list).T[None, ...].repeat(units, axis=0)
    gp_rate = mdl.nonparametrics.Gaussian_process(units, inducing_points, kt, VI_tuples,
                                                  inv_link='exp', shared_kernel_params=False,
                                                  cov_type='factorized', mean=np.zeros((units)),
                                                  whiten=True)
    gp_rate.set_params(tbin, jitter=1e-5)

    return gp_rate

In [None]:
# GP model fit with random restarts. The objective is nonconvex (inducing points are
# drawn at random), so the fit is repeated `nonconvex_trials` times and the run with
# the lowest final loss is kept in `glm` / `glm_rate`.
# NOTE(review): requires units, tbin, rc_t, behav_tuple, num_induc from the data
# cells (not visible here) — confirm they are defined before this cell.
nonconvex_trials = 3
max_retries = 2  # extra attempts per trial after a RuntimeError/AssertionError

# Training inputs in the (x, y, speed, theta, hd) order GP_params expects. Built here
# so the fit never reads a `covariates` name left over from another cell.
train_covariates = list(behav_tuple[:5])

best_loss = np.inf
glm = glm_rate = None
for trial in range(nonconvex_trials):
    retries = 0  # reset for every trial
    while True:
        try:
            trial_rate = GP_params(units, tbin, behav_tuple, num_induc)

            likelihood = mdl.likelihoods.Poisson(units, 'exp')
            likelihood.set_params(tbin)

            trial_glm = mdl.inference.nll_optimized([trial_rate], likelihood)
            trial_glm.preprocess(train_covariates, train_covariates[0].shape[0], rc_t,
                                 batch_size=100000)
            trial_glm.to(dev)

            # fit: Adam, with a scheduler that multiplies the learning rate by 0.9 per step
            sch = lambda o: optim.lr_scheduler.MultiplicativeLR(o, lambda e: 0.9)
            opt_tuple = (optim.Adam, 100, sch)
            opt_lr_dict = {'default': 5*1e-2}
            trial_glm.set_optimizers(opt_tuple, opt_lr_dict)

            annealing = lambda x: 1.0  # constant annealing factor
            losses = trial_glm.fit(3000, loss_margin=-1e1, stop_iters=100, anneal_func=annealing,
                                   cov_samples=1, ll_samples=10, ll_mode='MC', bound='ELBO')

            fig, ax = plt.subplots()
            ax.plot(losses)
            ax.set(xlabel='epoch', ylabel='NLL', title=f'trial {trial}')
            plt.show()
            break
        except (RuntimeError, AssertionError) as err:
            if retries >= max_retries:
                print('Stopped.')
                raise ValueError(f'GP fit failed {retries + 1} times in trial {trial}') from err
            print('Retrying...')
            retries += 1

    # keep the best restart (NaN/inf final losses are never selected)
    final_loss = losses[-1]
    if final_loss < best_loss:
        best_loss, glm, glm_rate = final_loss, trial_glm, trial_rate

if glm is None:
    raise ValueError('No trial reached a finite final loss')


In [None]:
# Place field: evaluate the fitted GP rate on a spatial grid with speed, theta and
# head direction held fixed at SP, TH, HD.
# NOTE(review): this cell reads n, X, Y, SP, TH, HD, maxspeed and the accumulators
# place_field, lower_s, mean_s, upper_s, none of which it creates — it looks like the
# body of a loop over n; confirm where they are defined.
grid_size = [int(arena_width/10), int(arena_height/10)]  # roughly 10 arena units per grid cell
grid_shape = [[left_x, right_x], [bottom_y, top_y]]

def func(pos):
    """
    Rate from ``glm_rate.eval_rate`` for unit index ``[0]`` at the grid positions ``pos``.

    ``pos[0]`` / ``pos[1]`` hold the x / y grid coordinates; the result is
    reshaped back to the grid shape ``pos.shape[1:]``.
    """
    grid_dims = pos.shape[1:]
    x = pos[0].flatten()
    y = pos[1].flatten()
    fill = np.ones_like(x)
    grid_covariates = [x, y, SP*fill, TH*fill, HD*fill]
    return glm_rate.eval_rate(grid_covariates, [0])[0].reshape(*grid_dims)

(xx, yy), place_field_ = tools.compute_mesh(grid_size, grid_shape, func)
place_field.append(place_field_)


# GP speed tuning at position (X[n], Y[n]), sweeping speed over [0, maxspeed].
# Uses its own name instead of `covariates` so it cannot clobber the training inputs
# read by the fit cell.
steps = 100

speed_covariates = (X[n]*np.ones(steps), Y[n]*np.ones(steps), np.linspace(0, maxspeed, steps),
                    TH*np.ones(steps), HD*np.ones(steps))
lower, mean, upper = glm_rate.eval_rate(speed_covariates, [0], False)
lower_s.append(lower[0])
mean_s.append(mean[0])
upper_s.append(upper[0])
