# Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os
import argparse
import glob
import sys 
import yaml 
import glob
import h5py 
import ray
import logging 
import json
import gc
import cv2
import time
import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
# import io_dict_to_hdf5 as ioh5
import xarray as xr
import scipy.linalg as linalg
import scipy.sparse as sparse
import matplotlib.gridspec as gridspec

from tqdm.notebook import tqdm, trange
from matplotlib.backends.backend_pdf import PdfPages
from scipy import signal
from pathlib import Path
from scipy.optimize import minimize_scalar,minimize
from scipy.interpolate import interp1d
from scipy.ndimage import shift as imshift
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from sklearn.pipeline import make_pipeline
from sklearn import linear_model as lm 
from scipy.stats import binned_statistic
from sklearn.utils import shuffle
from sklearn.metrics import r2_score, mean_poisson_deviance
from pyglmnet import GLMCV, GLM

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sys.path.append(str(Path('.').absolute()))
from utils import *
import io_dict_to_hdf5 as ioh5
from format_data import load_ephys_data_aligned

# Display every DataFrame row (None is passed as the row limit).
pd.set_option('display.max_rows', None)

# Initialise Ray for this notebook session. `ignore_reinit_error=True` is passed
# so that re-running this cell does not raise; Ray logging is set to ERROR level.
ray.init(
    ignore_reinit_error=True,
    logging_level=logging.ERROR,
)
# print(f'Dashboard URL: http://{ray.get_dashboard_url()}')
# print('Dashboard URL: http://localhost:{}'.format(ray.get_dashboard_url().split(':')[-1]))

# Gather Data

In [None]:
def load_train_test(file_dict, save_dir, model_dt=.1, frac=.1, train_size=.7, do_shuffle=False, do_norm=False, free_move=True, has_imu=True, has_mouse=False,):
    """Load the aligned recording, keep usable time bins and split them into train/test.

    Parameters
    ----------
    file_dict, save_dir, model_dt, free_move, has_imu, has_mouse
        Forwarded unchanged to ``load_ephys_data_aligned``.
    frac : float
        Approximate fraction of the kept time bins that forms one contiguous
        group. Whole groups are assigned to train or test, so neighbouring
        bins stay on the same side of the split. Must satisfy ``0 < frac <= 1``.
    train_size : float
        Passed to ``GroupShuffleSplit`` (``random_state=42``, one split).
    do_shuffle : bool
        If True, the rows of the train and test spike arrays are permuted with
        ``sklearn.utils.shuffle(random_state=42)``; all other arrays keep
        their order.
    do_norm : bool
        If True, z-score ``model_th`` and ``model_phi`` (and ``model_roll`` /
        ``model_pitch`` when ``free_move``) over the kept time bins.

    Returns
    -------
    data : dict
        The loaded dict after filtering, with ``raw_nsp``, ``model_dth``,
        ``model_dphi`` and the ``train_*`` / ``test_*`` entries added.
    train_idx, test_idx : ndarray
        Integer indices into the filtered time axis.

    Raises
    ------
    ValueError
        If ``frac`` is not in ``(0, 1]``.
    """
    if not 0 < frac <= 1:
        raise ValueError('frac must be in (0, 1], got {}'.format(frac))

    ##### Load in preprocessed data #####
    data = load_ephys_data_aligned(file_dict, save_dir, model_dt=model_dt, free_move=free_move, has_imu=has_imu, has_mouse=has_mouse,)
    if free_move:
        ##### Find 'good' timepoints when mouse is active #####
        # Rows containing a NaN in any entry of `data` are discarded.
        nan_idxs = [np.where(np.isnan(data[key]))[0] for key in data.keys()]
        good_idxs = np.ones(len(data['model_active']), dtype=bool)
        good_idxs[data['model_active'] < .5] = False
        good_idxs[np.unique(np.hstack(nan_idxs))] = False
    else:
        good_idxs = np.where((np.abs(data['model_th']) < 10) & (np.abs(data['model_phi']) < 10))[0]

    data['raw_nsp'] = data['model_nsp'].copy()
    ##### return only active data #####
    # Everything except 'model_active' and 'unit_nums' is restricted to the good rows.
    # NOTE(review): this includes 'raw_nsp', so despite its name it is filtered
    # exactly like 'model_nsp' -- kept as in the original; confirm this is intended.
    for key in data.keys():
        if key not in ('model_active', 'unit_nums'):
            data[key] = data[key][good_idxs]

    ##### Contiguous groups for the group-wise split #####
    gss = GroupShuffleSplit(n_splits=1, train_size=train_size, random_state=42)
    nT = data['model_nsp'].shape[0]
    n_groups = max(1, int(round(1/frac)))
    edges = (frac*np.arange(n_groups+1)*nT).astype(int)
    # Pin the last edge to nT so `groups` always has exactly nT entries, even when
    # frac*n_groups is not exactly 1 (otherwise the split raises on a length mismatch).
    edges[-1] = nT
    groups = np.repeat(np.arange(1, n_groups+1), np.diff(edges))

    train_idx, test_idx = next(gss.split(np.arange(nT), groups=groups))
    print("TRAIN:", len(train_idx), "TEST:", len(test_idx))

    # Frame-to-frame eye velocity. NOTE(review): `append=0` makes the last sample
    # equal to -x[-1], and differences are taken across gaps left by the filtering.
    data['model_dth'] = np.diff(data['model_th'], append=0)
    data['model_dphi'] = np.diff(data['model_phi'], append=0)

    # Per-pixel z-score of the video; NaNs (e.g. from zero-variance pixels) are set to 0.
    data['model_vid_sm'] = (data['model_vid_sm'] - np.mean(data['model_vid_sm'], axis=0))/np.nanstd(data['model_vid_sm'], axis=0)
    data['model_vid_sm'][np.isnan(data['model_vid_sm'])] = 0
    if do_norm:
        norm_keys = ['model_th', 'model_phi']
        if free_move:
            norm_keys += ['model_roll', 'model_pitch']
        for key in norm_keys:
            data[key] = (data[key] - np.mean(data[key], axis=0))/np.std(data[key], axis=0)

    ##### Split Data by train/test #####
    data_train_test = {
        'train_vid': data['model_vid_sm'][train_idx],
        'test_vid': data['model_vid_sm'][test_idx],
        'train_nsp': shuffle(data['model_nsp'][train_idx], random_state=42) if do_shuffle else data['model_nsp'][train_idx],
        'test_nsp': shuffle(data['model_nsp'][test_idx], random_state=42) if do_shuffle else data['model_nsp'][test_idx],
        'train_th': data['model_th'][train_idx],
        'test_th': data['model_th'][test_idx],
        'train_phi': data['model_phi'][train_idx],
        'test_phi': data['model_phi'][test_idx],
        'train_roll': data['model_roll'][train_idx] if free_move else [],
        'test_roll': data['model_roll'][test_idx] if free_move else [],
        'train_pitch': data['model_pitch'][train_idx] if free_move else [],
        'test_pitch': data['model_pitch'][test_idx] if free_move else [],
        'train_t': data['model_t'][train_idx],
        'test_t': data['model_t'][test_idx],
        'train_dth': data['model_dth'][train_idx],
        'test_dth': data['model_dth'][test_idx],
        'train_dphi': data['model_dphi'][train_idx],
        'test_dphi': data['model_dphi'][test_idx],
        'train_gz': data['model_gz'][train_idx] if free_move else [],
        'test_gz': data['model_gz'][test_idx] if free_move else [],
    }

    # The split arrays are merged into the same dict that holds the full arrays.
    data.update(data_train_test)
    return data, train_idx, test_idx


def f_add(alpha,stat_range,stat_all):
    """Mean squared error between `stat_range` and `stat_all` shifted by `alpha`."""
    shifted = stat_all + alpha
    residual = stat_range - shifted
    return np.mean(residual**2)

def f_mult(alpha,stat_range,stat_all):
    """Mean squared error between `stat_range` and `stat_all` scaled by `alpha`."""
    scaled = stat_all * alpha
    residual = stat_range - scaled
    return np.mean(residual**2)

In [None]:
# Select the recording session and derive the data / output / figure directories.
free_move = True
# 'fm1' is used for freely-moving sessions, 'hf1_wn' otherwise.
stim_type = 'fm1' if free_move else 'hf1_wn'
# 012821/EE8P6LT
# 128: 070921/J553RT
date_ani = '070921/J553RT' #'062921/G6HCK1ALTRN'
data_dir  = Path('~/Goeppert/freely_moving_ephys/ephys_recordings/').expanduser() / date_ani / stim_type
# check_path comes from `utils`; presumably it creates the directory if needed and
# returns the joined path -- verify against utils.
save_dir  = check_path(Path('~/Research/SensoryMotorPred_Data/data/').expanduser() / date_ani, stim_type)
FigPath = check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(),'Figures/Encoding')
FigPath = check_path(FigPath/date_ani, stim_type)

print('save_dir:',save_dir)
print('data_dir:',data_dir)
print('FigPath:', FigPath)
# with open(save_dir / 'file_dict.json','r') as fp:
#     file_dict = json.load(fp)

In [None]:
def _first_match(pattern):
    """Return the POSIX path of the first file in `data_dir` matching `pattern`.

    Raises FileNotFoundError (naming the pattern) instead of a bare IndexError
    when the expected file is missing from the session directory.
    """
    match = next(iter(data_dir.glob(pattern)), None)
    if match is None:
        raise FileNotFoundError(f'No file matching {pattern!r} found in {data_dir}')
    return match.as_posix()


# Inputs handed to the preprocessing/loading code for this session.
# NOTE(review): 'name' and 'mapping_json' are hard-coded and 'name' does not match
# `date_ani` above -- confirm they are correct for the selected recording.
file_dict = {'cell': 0,
 'drop_slow_frames': True,
 'ephys': _first_match('*ephys_merge.json'),
 'ephys_bin': _first_match('*Ephys.bin'),
 'eye': _first_match('*REYE.nc'),
 'imu': _first_match('*imu.nc') if stim_type=='fm1' else None,
 'mapping_json': '/home/seuss/Research/Github/FreelyMovingEphys/probes/channel_maps.json',
 'mp4': True,
 'name': '01221_EE8P6LT_control_Rig2_'+stim_type, #070921_J553RT
 'probe_name': 'DB_P128-6',
 'save': data_dir.as_posix(),
 'speed': _first_match('*speed.nc') if stim_type=='hf1_wn' else None,
 'stim_type': 'light',
 'top': _first_match('*TOP1.nc') if stim_type=='fm1' else None,
 'world': _first_match('*world.nc'),}

In [None]:
# Load the session at 50 ms resolution and expose every entry of `data`
# (model_*, train_*, test_* arrays) as a notebook-level variable.
model_dt = .05
do_shuffle = False
do_norm = False
data,train_idx,test_idx = load_train_test(file_dict, save_dir, model_dt=model_dt, do_shuffle=do_shuffle, do_norm=do_norm,free_move=free_move, has_imu=free_move, has_mouse=False)
# globals() is the notebook namespace; unlike locals(), updating it is reliable
# even if this code is later moved inside a function.
globals().update(data)

# Temporal lags (in frames) of the video used as GLM regressors.
lag_list = np.array([-2,-1,0,1,2]) #np.array([-1,0,1,2,3]) #,np.arange(minlag,maxlag,np.floor((maxlag-minlag)/nt_glm_lag).astype(int))
nt_glm_lag = len(lag_list)
print(lag_list,1000*lag_list*model_dt)
model_type = 'Pytorch'

# Testing Tuning Curves

In [None]:
# Create Tuning curve for theta
def tuning_curve(model_nsp, var, model_dt = .025, N_bins=10, Nstds=3):
    """Binned firing-rate tuning of every unit to a 1-D variable.

    `var` is cut into ``N_bins-1`` half-open bins ``[lo, hi)`` whose edges are
    evenly spaced over mean(var) +/- ``Nstds`` standard deviations (NaN-aware).
    For each unit (last axis of `model_nsp`) and bin, the mean count divided by
    `model_dt` and its standard error (std / dt / sqrt(#points in bin)) are
    returned, together with the left edge of each bin. Empty bins yield NaN.
    """
    center = np.nanmean(var)
    half_width = Nstds*np.nanstd(var)
    var_range = np.linspace(center-half_width, center+half_width, N_bins)

    n_units = model_nsp.shape[-1]
    n_intervals = len(var_range)-1
    tuning = np.zeros((n_units, n_intervals))
    tuning_std = np.zeros((n_units, n_intervals))

    for j, (lo, hi) in enumerate(zip(var_range[:-1], var_range[1:])):
        # The bin mask does not depend on the unit, so build it once per bin.
        usePts = (var >= lo) & (var < hi)
        n_pts = np.count_nonzero(usePts)
        for unit in range(n_units):
            counts = model_nsp[usePts, unit]
            tuning[unit, j] = np.nanmean(counts)/model_dt
            tuning_std[unit, j] = (np.nanstd(counts)/model_dt)/np.sqrt(n_pts)
    return tuning, tuning_std, var_range[:-1]


# VisMov Poisson GLM

In [None]:
class PoissonGLM_VM(nn.Module):
    """Linear-softplus Poisson GLM with learnable per-output elastic-net strengths.

    The prediction is ``softplus(inputs @ weight.T + bias)``. `loss` returns one
    value per output unit: the Poisson negative log-likelihood (without the
    constant term) plus optional L2/L1 penalties on the weights.

    Parameters
    ----------
    in_features, out_features : int
        Size of the input vector and number of output units.
    bias : bool
        Whether to learn an additive bias per output unit.
    reg_lam : float or None
        Initial value of the per-unit L2 strength ``lam``. ``None`` disables
        all weight penalties in `loss`.
    reg_alph : float or None
        Initial value of the per-unit L1 strength ``alpha`` (clamped to [0, 1]
        when used). ``None`` omits the L1 term on the non-movement weights.
    move_features : int or None
        If given, the last `move_features` input columns are penalised
        separately through ``lam_m`` / ``alpha_m`` (both initialised to 1).
    """

    def __init__(self, in_features, out_features, bias=True, reg_lam=None, reg_alph=None, move_features=None):
        super(PoissonGLM_VM, self).__init__()
        self.move_features = move_features
        if self.move_features is not None:
            self.lam_m = torch.nn.Parameter(torch.Tensor(out_features))
            self.alpha_m = torch.nn.Parameter(torch.Tensor(out_features))
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features),)
        self.reg_lam = reg_lam
        self.reg_alph = reg_alph
        # `self.bias` must not be set to the bool flag first: register_parameter
        # raises KeyError if a plain attribute with that name already exists.
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        if self.reg_lam is not None:
            self.lam = torch.nn.Parameter(torch.Tensor(out_features))
        if self.reg_alph is not None:
            self.alpha = torch.nn.Parameter(torch.Tensor(out_features))

        # Kept for backward compatibility; `loss` below does not use it.
        self.lossfn = torch.nn.PoissonNLLLoss(log_input=True,reduction='mean')
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise the weights, bias and regularisation strengths."""
        torch.nn.init.kaiming_uniform_(self.weight) #, a=np.sqrt(5)
        if self.reg_lam is not None:
            torch.nn.init.constant_(self.lam,self.reg_lam)
        if self.reg_alph is not None:
            torch.nn.init.constant_(self.alpha,self.reg_alph)
        if self.move_features is not None:
            torch.nn.init.constant_(self.lam_m,1)#self.reg_lam)
            torch.nn.init.constant_(self.alpha_m,1)#self.reg_alph)
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / np.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, inputs):
        """Return the predicted (non-negative) rate for `inputs` of shape (..., in_features).

        Raises
        ------
        ValueError
            If the last dimension of `inputs` is not `in_features`.
        """
        if inputs.shape[-1] != self.in_features:
            raise ValueError(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
        output = inputs.matmul(self.weight.t())
        if self.bias is not None:
            output = output + self.bias
        # softplus == log(1 + exp(x)), but does not overflow to inf for large x.
        return torch.nn.functional.softplus(output)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    def loss(self,Yhat, Y):
        """Per-output loss: mean over axis 0 of ``Yhat - Y*log(Yhat)`` plus penalties.

        Penalties are applied only when `reg_lam` was given:
        ``clamp(lam, min=0) * ||w||_2`` and, if `reg_alph` was given,
        ``clamp(alpha, 0, 1) * ||w||_1``, where ``w`` is each unit's weight row
        without the movement columns. When `move_features` is set, the movement
        columns get the same two terms with ``lam_m`` / ``alpha_m``.
        """
        loss_vec = torch.mean(Yhat-Y*torch.log(Yhat),dim=0)
        if self.reg_lam is None:
            return loss_vec

        if self.move_features is not None:
            w_main = self.weight[:,:-self.move_features]
            w_move = self.weight[:,-self.move_features:]
        else:
            w_main = self.weight
            w_move = None

        loss_vec = loss_vec + torch.clamp(self.lam,min=0)*torch.linalg.norm(w_main,ord=2,dim=1)
        # Previously a missing `alpha` (reg_alph=None) raised AttributeError here.
        if self.reg_alph is not None:
            loss_vec = loss_vec + torch.clamp(self.alpha,0,1)*torch.linalg.norm(w_main,ord=1,dim=1)
        if w_move is not None:
            loss_vec = loss_vec + torch.clamp(self.lam_m,min=0)*torch.linalg.norm(w_move,ord=2,dim=1)
            loss_vec = loss_vec + torch.clamp(self.alpha_m,0,1)*torch.linalg.norm(w_move,ord=1,dim=1)
        return loss_vec


In [None]:
# Build the train/test design matrices (lagged video and/or movement signals)
# and move them, together with the spike counts, onto `device`.
lag_list = np.array([-2,-1,0,1,2]) #np.array([-1,0,1,2,3]) #,np.arange(minlag,maxlag,np.floor((maxlag-minlag)/nt_glm_lag).astype(int))
nt_glm_lag = len(lag_list)
print(lag_list,1000*lag_list*model_dt)
do_shuffle = False
model_type = 'Pytorch'


# for do_shuffle in [False,True]:
# Load Data (movement signals z-scored via do_norm=True)
data, train_idx, test_idx = load_train_test(file_dict, save_dir, model_dt=model_dt, do_shuffle=do_shuffle, do_norm=True,free_move=free_move, has_imu=free_move, has_mouse=False)
# globals() is the notebook namespace; this exposes model_*, train_*, test_* arrays by name.
globals().update(data)

# Initialize movement combinations: names of every non-empty subset of `titles`
titles = np.array(['Theta','Phi','Roll','Pitch']) # 'dg_p','dg_n' 'roll','pitch'
titles_all = ['_'.join(titles[list(combo)])
              for n_vars in range(1,len(titles)+1)
              for combo in itertools.combinations(range(len(titles)), n_vars)]

# train_dgaze_p = train_dth + np.diff(train_gz,append=0)
# train_dgaze_n = train_dth - np.diff(train_gz,append=0)
# test_dgaze_p = test_dth + np.diff(test_gz,append=0)
# test_dgaze_n = test_dth - np.diff(test_gz,append=0)
# NOTE: the roll/pitch entries are only arrays when free_move is True.
move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis]))
model_move = np.hstack((model_th[:,np.newaxis],model_phi[:,np.newaxis],model_roll[:,np.newaxis],model_pitch[:,np.newaxis]))
model_move = model_move - np.mean(model_move,axis=0)
# NOTE(review): move_test is mean-centred here but move_train is not -- kept as
# in the original; confirm whether the training regressors should be centred too.
move_test = move_test - np.mean(move_test,axis=0)

##### Start GLM Parallel Processing #####
nks = np.shape(train_vid)[1:]; nk = nks[0]*nks[1]*nt_glm_lag
# Use the single combination that contains all four movement variables.
n=4; ind=0
perms = np.array(list(itertools.combinations(np.arange(len(titles)), n)))

# Stack lagged copies of the video along axis 1 and flatten each time bin to one row.
# NOTE(review): np.roll wraps around at the ends and the time axis has already been
# filtered, so lagged frames are not always true temporal neighbours.
rolled_vid = np.hstack([np.roll(model_vid_sm, nframes, axis=0) for nframes in lag_list]) # nt_glm_lag
x_train = rolled_vid[train_idx].reshape(len(train_idx),-1)
x_test = rolled_vid[test_idx].reshape(len(test_idx),-1)

# 0: movement only, 1: video only, 2: video + movement, 3: video + video*movement + movement
MovModel = 1
if MovModel == 0:
    mx_train = move_train[:,perms[ind]]
    mx_test = move_test[:,perms[ind]]
    xtr = torch.from_numpy(mx_train.astype(np.float32)).to(device)
    xte = torch.from_numpy(mx_test.astype(np.float32)).to(device)
    move_features = mx_train.shape[-1]
    nk = 0
elif MovModel == 1:
    xtr = torch.from_numpy(x_train.astype(np.float32)).to(device)
    xte = torch.from_numpy(x_test.astype(np.float32)).to(device)
    move_features = None
elif MovModel == 2:
    x_train_m2 = np.concatenate((x_train,move_train[:,perms[ind]]),axis=1)
    x_test_m2 = np.concatenate((x_test,move_test[:,perms[ind]]),axis=1)
    xtr = torch.from_numpy(x_train_m2.astype(np.float32)).to(device)
    xte = torch.from_numpy(x_test_m2.astype(np.float32)).to(device)
    move_features = x_train_m2.shape[-1]-nk
elif MovModel == 3:
    x_train_m3 = np.hstack((x_train,np.hstack([x_train*move_train[:,modeln][:,np.newaxis] for modeln in np.arange(len(titles))]), move_train[:,perms[ind]]))
    x_test_m3 = np.hstack((x_test,np.hstack([x_test*move_test[:,modeln][:,np.newaxis] for modeln in np.arange(len(titles))]), move_test[:,perms[ind]]))
    xtr = torch.from_numpy(x_train_m3.astype(np.float32)).to(device)
    xte = torch.from_numpy(x_test_m3.astype(np.float32)).to(device)
    move_features = x_train_m3.shape[-1]-nk
else:
    # Fail loudly instead of silently reusing xtr/xte from an earlier run of this cell.
    raise ValueError(f'Unknown MovModel: {MovModel}')


ytr = torch.from_numpy(train_nsp.astype(np.float32)).to(device)
yte = torch.from_numpy(test_nsp.astype(np.float32)).to(device)
print('move_features: {}'.format(move_features))

In [None]:
# Fit the GLM with full-batch ASGD, recording train/validation loss and the
# learned regularisation strengths at every iteration.
input_size = xtr.shape[1]
output_size = ytr.shape[1]
reg_lam = 512
reg_alph = 10
l1 = PoissonGLM_VM(input_size,output_size,reg_lam=reg_lam,reg_alph=reg_alph,move_features=move_features).to(device)
# lam_m / alpha_m only exist on the model when move_features is given.
reg_param_group = [l1.lam, l1.alpha]
if move_features is not None:
    reg_param_group = reg_param_group + [l1.lam_m, l1.alpha_m]
# Two parameter groups: weights/bias and regularisation strengths (own learning rate).
optimizer = optim.ASGD([{'params': [l1.weight, l1.bias]},
                        {'params': reg_param_group, 'lr': .01},
                        ], lr=1e-3)
# optimizer = optim.ASGD(params=l1.parameters()) #
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1000)

# One (base_lr, max_lr) pair per parameter group above.
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=[1e-4, .1], max_lr=[1e-3, 16], cycle_momentum=False)
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=[0.001,], max_lr=[10], cycle_momentum=False)

# NOTE: not used by the loop below, which calls l1.loss instead.
lossfn = torch.nn.PoissonNLLLoss(log_input=True,reduction='mean')

Nbatches= 10000
if move_features is not None:
    reg_params = np.zeros((Nbatches,output_size,4))
    reg_titles = ['lambda','lambda_m','alpha','alpha_m']
else:
    reg_params = np.zeros((Nbatches,output_size,2))
    reg_titles = ['lambda','alpha']
vloss_trace = np.zeros((Nbatches,output_size))
tloss_trace = np.zeros((Nbatches,output_size))
early_stopping = EarlyStopping(patience=500,min_delta=.5)
# lam_grad = np.zeros((Nbatches,output_size))
for batchn in tqdm(np.arange(Nbatches)):
    out = l1(xtr)
    loss = l1.loss(out,ytr)
    # Validation is monitoring only: skip autograd bookkeeping for it.
    with torch.no_grad():
        pred = l1(xte)
        val_loss = l1.loss(pred,yte)
    vloss_trace[batchn] = val_loss.cpu().numpy()
    tloss_trace[batchn] = loss.detach().cpu().numpy()
    if move_features is not None:
        reg_params[batchn] = torch.stack((l1.lam.detach().cpu(),l1.lam_m.detach().cpu(),l1.alpha.detach().cpu(),l1.alpha_m.detach().cpu())).numpy().T
    else:
        reg_params[batchn] = torch.stack((l1.lam.detach().cpu(),l1.alpha.detach().cpu())).numpy().T
    optimizer.zero_grad()
    # `loss` has one entry per output unit; a ones-vector backpropagates their sum.
    loss.backward(torch.ones_like(loss))
    optimizer.step()
    scheduler.step()
#     lam_grad[batchn]= l1.lam.grad.detach().cpu().numpy()
#     early_stopping(val_loss.item())
#     if early_stopping.early_stop:
#         break
with torch.no_grad():
    pred_all = l1(xte).cpu().numpy()
# The first nk weight columns are the lagged-video filters; the rest are movement weights.
if MovModel != 0:
    sta_all = l1.weight.cpu().detach().numpy()[:,:(nk)].reshape((output_size,nt_glm_lag)+nks)
if MovModel != 1:
    w_move = l1.weight.cpu().detach().numpy()[:,(nk):]

In [None]:
# reg_titles = ['lambda','lambda_m','alpha','alpha_m']
# reg_titles = ['lambda','alpha']
# One figure per regularisation parameter: its trajectory over iterations,
# one line per output unit.
n_reg = reg_params.shape[-1]
for n in range(n_reg):
    fig, ax = plt.subplots()
    param_trace = reg_params[:, :, n]
    ax.plot(param_trace)
    ax.set_title(reg_titles[n])

In [None]:
# Score each unit by the squared correlation between box-car smoothed observed
# and predicted rates, then save the fit results to an HDF5 file.
bin_length=40
r2_all = np.zeros(output_size)
_box = np.ones(bin_length)
_rate_scale = bin_length * model_dt
for celln in range(output_size):
    # `bin_length` samples are trimmed from both ends of the 'same' convolution.
    sp_smooth = (np.convolve(test_nsp[:,celln], _box, 'same') / _rate_scale)[bin_length:-bin_length]
    pred_smooth = (np.convolve(pred_all[:,celln], _box, 'same') / _rate_scale)[bin_length:-bin_length]
    r2_all[celln] = np.corrcoef(sp_smooth,pred_smooth)[0,1]**2

# 'sta_all' exists for every model except 0, 'w_move' for every model except 1.
GLM_Data = {'r2_all': r2_all}
if MovModel != 0:
    GLM_Data['sta_all'] = sta_all
GLM_Data['spks_all'] = test_nsp
GLM_Data['pred_all'] = pred_all
if MovModel != 1:
    GLM_Data['w_move'] = w_move

_fname_fmt = ('GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}_shuffled.h5'
              if do_shuffle else
              'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}.h5')
save_datafile = save_dir/_fname_fmt.format(model_type,int(model_dt*1000), nt_glm_lag, MovModel)
ioh5.save(save_datafile, GLM_Data)
print(save_datafile)



## Plotting

In [None]:
# Model variant assumed by the plotting cells below (1 is the video-only
# branch of the design-matrix cell, for which no movement weights are fitted).
MovModel=1

In [None]:
# Reload the data without z-scoring the movement signals (do_norm=False) and
# pre-compute the tuning curves of every unit for the plots below.
bin_length=40
data, train_idx, test_idx = load_train_test(file_dict, save_dir, model_dt=model_dt, do_shuffle=do_shuffle, do_norm=False,free_move=free_move, has_imu=free_move, has_mouse=False)
# globals() is the notebook namespace; this exposes model_*, train_*, test_* arrays by name.
globals().update(data)
# if do_shuffle:
#     GLM_Vis = ioh5.load(save_dir/'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}_shuffled.h5'.format(model_type,int(model_dt*1000), nt_glm_lag, MovModel))
# else:
#     GLM_Vis = ioh5.load(save_dir/'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}.h5'.format(model_type,int(model_dt*1000), nt_glm_lag, MovModel))
# locals().update(GLM_Vis)
##### Explore Neurons #####
colors = plt.cm.cool(np.linspace(0,1,4))
clrs = ['blue','orange','green','red']
# Initialize movement combinations: names of every non-empty subset of `titles`
titles = np.array(['Theta','Phi','Roll','Pitch']) # 'dg_p','dg_n' 'roll','pitch'
titles_all = []
for n in range(1,len(titles)+1):
    perms = np.array(list(itertools.combinations(np.arange(len(titles)), n)))
    for ind in range(perms.shape[0]):
        titles_all.append('_'.join(titles[perms[ind]]))

# train_dgaze_p = train_dth + np.diff(train_gz,append=0)
# train_dgaze_n = train_dth - np.diff(train_gz,append=0)
# test_dgaze_p = test_dth + np.diff(test_gz,append=0)
# test_dgaze_n = test_dth - np.diff(test_gz,append=0)
move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis]))
model_move = np.hstack((model_th[:,np.newaxis],model_phi[:,np.newaxis],model_roll[:,np.newaxis],model_pitch[:,np.newaxis]))
model_move = model_move - np.mean(model_move,axis=0)
move_test = move_test - np.mean(move_test,axis=0)
# Create all tuning curves for plotting (one per unit and movement variable)
N_bins=10
ncells = model_nsp.shape[-1]
ax_ylims = np.zeros((ncells,len(titles)))
tuning_curves = np.zeros((ncells,len(titles),N_bins-1))
tuning_stds = np.zeros((ncells,len(titles),N_bins-1))
var_ranges = np.zeros((len(titles),N_bins-1))
for modeln in range(len(titles)):
    metric = move_test[:,modeln]
    tuning, tuning_std, var_range = tuning_curve(test_nsp, metric, N_bins=N_bins, model_dt=model_dt, Nstds=2)
    tuning_curves[:,modeln] = tuning
    tuning_stds[:,modeln] = tuning_std
    # Largest binned rate per unit, used later to set the y-axis limits.
    ax_ylims[:,modeln] = np.nanmax(tuning,axis=1)
    var_ranges[modeln] = var_range

In [None]:
# Summary figure for one unit: spatiotemporal filter (row 0), smoothed rates,
# eye/head tuning, predicted-vs-actual rates and movement weights (row 1), and
# predicted-vs-actual curves split by movement quartile (row 2).
import matplotlib as mpl  # `mpl` is used below but not imported explicitly elsewhere in this file

celln = 51# np.argmax(r2_all)
bin_length = 40
ncells=model_nsp.shape[-1]
colors = plt.cm.cool(np.linspace(0,1,4))
clrs = ['blue','orange','green','red']
quartiles = np.arange(0,1.25,.25)

fig, axs = plt.subplots(3,5, figsize=((35,15)))
# Replace the five axes of the top row by one sub-grid with a panel per lag.
gs = axs[0,0].get_gridspec()
gs_sub = gs[0,:].subgridspec(1,nt_glm_lag)
for ax in axs[0,:]:
    ax.remove()
top_grid = np.zeros((nt_glm_lag),dtype=object)
for ind in range(nt_glm_lag):
    top_grid[ind] = fig.add_subplot(gs_sub[0,ind])

# Per-bin counts converted to rates (spikes/s) for the selected unit.
predcell = pred_all[:,celln]/model_dt
nspcell = test_nsp[:,celln]/model_dt
test_nsp_smooth=((np.convolve(test_nsp[:,celln], np.ones(bin_length), 'same')) / (bin_length * model_dt))[bin_length:-bin_length]
pred_smooth=((np.convolve(pred_all[:,celln], np.ones(bin_length), 'same')) / (bin_length * model_dt))[bin_length:-bin_length]
# Bin edges / bin centres for the predicted rate, taken from its quantiles
pred_range = np.quantile(predcell,[.1,.9])
test_nsp_range = np.quantile(nspcell,[.01,1])
spike_percentiles = np.arange(0,1.25,.25)
spike_percentiles[-1]=.99
spk_percentile2 = np.arange(.125,1.125,.25)
pred_rangelin = np.quantile(predcell,spike_percentiles)
xbin_pts = np.quantile(predcell,spk_percentile2)
stat_bins = len(pred_rangelin) #5


axs[1,0].plot(np.arange(len(test_nsp_smooth))*model_dt,test_nsp_smooth,'k',label='test FR')
axs[1,0].plot(np.arange(len(pred_smooth))*model_dt,pred_smooth,'r', label='pred FR')
axs[1,0].set_xlabel('Time (s)')
axs[1,0].set_ylabel('Firing Rate (spks/s)')
axs[1,0].legend()
axs[1,0].set_title('Smoothed FRs')

# Fitted video weights of this unit, one image per lag, on a symmetric colour scale.
crange = np.max(np.abs(sta_all[celln]))
for n in range(nt_glm_lag):
    img = top_grid[n].imshow(sta_all[celln,n],cmap='RdBu_r',vmin=-crange,vmax=crange)
    top_grid[n].set_title('Lag:{:03d} ms'.format(int(1000*lag_list[n]*model_dt)))
    top_grid[n].axis('off')
add_colorbar(img)

# Shared upper y-limit for both tuning panels (NaN-safe).
top_yaxs = np.max(ax_ylims[celln])+2*np.nanmax(tuning_stds[celln])

# Eye Tuning Curve
for i,modeln in enumerate(range(len(titles)-2)):
    metric = move_test[:,modeln]
    nranges = np.quantile(metric,quartiles)
    stat_range, edges, _ = binned_statistic(metric,test_nsp[:,celln],statistic='mean',bins=nranges)
    edge_mids = np.quantile(metric,spk_percentile2)#np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
    cmap = mpl.colors.ListedColormap(colors, N=colors.shape[0])
    norm = mpl.colors.BoundaryNorm(boundaries=np.floor(nranges), ncolors=len(cmap.colors))
    # Shade each quartile of the variable; each variable uses one half of the panel height.
    for m in range(len(nranges)-1):
        axs[1,1].axvspan(nranges[m], nranges[m+1],ymin=i*1/2,ymax=(i+1)*1/2,alpha=0.8, color=colors[m],zorder=0)
    #     axs[1,1].errorbar(var_ranges[modeln],tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln],label=titles[modeln],c=clrs[modeln],lw=4,elinewidth=3)
    axs[1,1].plot(edge_mids,stat_range/model_dt,'.-', ms=20, lw=4,c=clrs[modeln])

axs[1,1].set_ylim(bottom=0,top=top_yaxs)
axs[1,1].set_xlim(-30,30)
axs[1,1].set_xlabel(r'Angle ($ ^{\degree}$)')
axs[1,1].set_ylabel('Spikes/s')
axs[1,1].set_title('Eye Tuning Curves')
lines = axs[1,1].get_lines()
legend1 = axs[1,1].legend([lines[0]],[titles[0]],bbox_to_anchor=(1.01, .2), fontsize=12)
legend2 = axs[1,1].legend([lines[1]],[titles[1]],bbox_to_anchor=(1.01, .9), fontsize=12)
# A second legend() call replaces the first, so re-add the first one as an artist.
axs[1,1].add_artist(legend1)

# Head Tuning Curves
for i, modeln in enumerate(range(2,len(titles))):
    metric = move_test[:,modeln]
#     nranges = np.round(np.quantile(var_ranges[modeln],quartiles),decimals=1)
    nranges = np.round(np.quantile(metric,quartiles),decimals=1)
    stat_range, edges, _ = binned_statistic(metric,test_nsp[:,celln],statistic='mean',bins=nranges)
    edge_mids = np.quantile(metric,spk_percentile2)#np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
    cmap = mpl.colors.ListedColormap(colors, N=colors.shape[0])
    norm = mpl.colors.BoundaryNorm(boundaries=np.floor(nranges), ncolors=len(cmap.colors))
    for m in range(len(nranges)-1):
        axs[1,2].axvspan(nranges[m], nranges[m+1],ymin=i*1/2,ymax=(i+1)*1/2,alpha=0.8, color=colors[m],zorder=0)
#     axs[1,2].errorbar(var_ranges[modeln], tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln], label=titles[modeln], c=clrs[modeln],lw=4,elinewidth=3)
    axs[1,2].plot(edge_mids,stat_range/model_dt,'.-', ms=20, lw=4,c=clrs[modeln])

axs[1,2].set_ylim(bottom=0,top=top_yaxs)
axs[1,2].set_xlim(-30,30)
axs[1,2].set_xlabel(r'Angle ($ ^{\degree}$)')
axs[1,2].set_ylabel('Spikes/s')
axs[1,2].set_title('Head Tuning Curves')
lines = axs[1,2].get_lines()
legend1 = axs[1,2].legend([lines[0]],[titles[2]],bbox_to_anchor=(1.01, .2), fontsize=12)
legend2 = axs[1,2].legend([lines[1]],[titles[3]],bbox_to_anchor=(1.01, .9), fontsize=12)
axs[1,2].add_artist(legend1)

# axs[1,2].legend(bbox_to_anchor=(1.01, 1), fontsize=12)


# pred_rangelin = np.linspace(pred_range[0],pred_range[1],stat_bins)
# Predicted vs. actual rate of this unit, one point per test time bin.
# (pred_all/test_nsp are (time, unit): the unit must be selected on the last axis,
# which `predcell`/`nspcell` already do.)
axs[1,3].scatter(predcell,nspcell,c='k',s=15)
axs[1,3].plot(np.linspace(test_nsp_range[0],test_nsp_range[1]),np.linspace(test_nsp_range[0],test_nsp_range[1]),'k--',zorder=0)
axs[1,3].set_xlabel('Predicted Spike Rate')
axs[1,3].set_ylabel('Actual Spike Rate')
# NOTE(review): `img` is the last filter image in the top row, so this adds a second
# colorbar there rather than to the scatter panel -- kept as in the original.
cbar = add_colorbar(img)
# cbar.set_label('count')

if MovModel == 1:
    # No movement regressors in the video-only model: show zero weights.
    w_move = np.zeros((model_nsp.shape[-1],len(titles)))
elif MovModel == 3:
    # The video*movement columns are stacked movement-variable-major (one block of
    # lag x pixels per variable), so reshape as (unit, variable, lag, y, x) and then
    # swap to the (unit, lag, variable, y, x) layout used when indexing Msta.
    # NOTE: re-running this cell slices w_move again; refit or reload first.
    Msta = w_move[:,:-len(titles)].reshape((model_nsp.shape[-1],len(titles),nt_glm_lag)+nks).swapaxes(1,2)
    w_move = w_move[:,-len(titles):]
for modeln in range(len(titles)):
    axs[1,4].bar(modeln, w_move[celln,modeln], color=clrs[modeln])
axs[1,4].set_xticks(np.arange(0,len(titles)))
axs[1,4].set_xticklabels(titles)
axs[1,4].set_ylabel('GLM Weight')


# Additive vs. multiplicative fit of each quartile curve to the all-data curve.
mse_add = np.zeros((ncells,len(titles),len(quartiles)-1))
mse_mult = np.zeros((ncells,len(titles),len(quartiles)-1))
alpha_add = np.zeros((ncells,len(titles),len(quartiles)-1))
alpha_mult = np.zeros((ncells,len(titles),len(quartiles)-1))

traces = np.zeros((ncells,len(titles),len(quartiles)-1,stat_bins-1)) # (model_type,quartile,FR)
traces_mean = np.zeros((ncells,len(titles),stat_bins-1)) # (model_type,quartile,FR)
edges_all = np.zeros((ncells,len(titles),len(quartiles)-1,stat_bins-1)) # (model_type,quartile,FR)
# df_traces = pd.DataFrame([],columns=['modeln','quartile','FR'])

# Mean actual rate per predicted-rate bin over all test data; it does not depend on
# the movement variable, so compute it once.
stat_all, edges, _ = binned_statistic(predcell,nspcell, statistic='mean',bins=pred_rangelin)
max_fr = np.max(stat_all)
for modeln in range(len(titles)):
    metric = move_test[:,modeln]
    nranges = np.quantile(metric,quartiles)# np.linspace(np.nanmean(metric)-2*np.nanstd(metric), np.nanmean(metric)+2*np.nanstd(metric),N_bins)
    edge_mids = xbin_pts#np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
    traces_mean[celln,modeln]=stat_all
#     axs[1,modeln].set_xlim(0,pred_range[1]+np.std(pred_range))
#     axs[1,modeln].set_ylim(0,np.max(stat)+np.std(stat))

    for n in range(len(nranges)-1):
        # Time bins whose movement value falls in quartile n (lower edge exclusive).
        ind = np.where(((metric<=nranges[n+1])&(metric>nranges[n])))[0]
        pred = predcell[ind]
        sp = nspcell[ind]

        stat_range, edges, _ = binned_statistic(pred, sp, statistic='mean',bins=pred_rangelin)
        edge_mids = xbin_pts #np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
        traces[celln,modeln,n]=stat_range
        edges_all[celln,modeln,n]=edge_mids
        # Best additive offset / multiplicative gain mapping the all-data curve onto
        # this quartile's curve (both normalised by the peak all-data rate).
        res_add = minimize_scalar(f_add,args=(stat_range/max_fr, stat_all/max_fr))
        res_mult = minimize_scalar(f_mult,args=(stat_range/max_fr, stat_all/max_fr))
        mse_add[celln, modeln, n] = res_add.fun
        mse_mult[celln, modeln, n] = res_mult.fun
        alpha_add[celln, modeln, n] = res_add.x
        alpha_mult[celln, modeln, n] = res_mult.x

        axs[2,modeln].plot(edge_mids, stat_range,'.-', c=colors[n],label='{:.02f} : {:.02f}'.format(nranges[n],nranges[n+1]),lw=4,ms=20,alpha=.9)
        axs[2,modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
        axs[2,modeln].set_xlabel('Predicted Spike Rate')
        axs[2,modeln].set_ylabel('Actual Spike Rate')

    lim_max = np.max(np.hstack((edge_mids,traces[celln,modeln].flatten())))+.5*np.std(edges)
    lim_min = np.min(np.hstack((edge_mids,traces[celln,modeln].flatten())))-.5*np.std(edges)
    lims = (0, lim_max) if (lim_min)<0 else (lim_min,lim_max)
    axs[2,modeln].plot(np.linspace(lims[0],lims[1]),np.linspace(lims[0],lims[1]),'k--',zorder=0)
    axs[2,modeln].plot(edge_mids, stat_all,'.-', c='k', lw=5, ms=20, label='All_data', alpha=.8)
    axs[2,modeln].legend(bbox_to_anchor=(1.01, 1), fontsize=12)
    axs[2,modeln].axis('equal')
#     axs[2,modeln].set_xlim(left=0)
    axs[2,modeln].set(xlim=lims, ylim=lims)
#     axs[2,modeln].set_xlim([0,xbin_pts[-1]])
    axs[2,modeln].set_ylim(bottom=0)

# Positive values: the multiplicative fit has the lower error for that variable/quartile.
dmodel = mse_add[celln]-mse_mult[celln]
crange = np.max(np.abs(dmodel))
im = axs[2,-1].imshow(dmodel,cmap='seismic',vmin=-crange,vmax=crange)
axs[2,-1].set_yticks(np.arange(0,4))
axs[2,-1].set_yticklabels(titles)
axs[2,-1].set_ylabel('Movement Model')
axs[2,-1].set_xticks(np.arange(0,4))
axs[2,-1].set_xticklabels(['.25','.5','.75','1'])
axs[2,-1].set_xlabel('Quantile Range')
axs[2,-1].set_title('$MSE_{add}$ - $MSE_{mult}$')
cbar = add_colorbar(im)

plt.suptitle('Celln:{}, r2={:.03f}'.format(celln,r2_all[celln]),y=1,fontsize=30)
plt.tight_layout()


# fig.savefig(FigPath/'CellSummary_N{}_T{:02d}.png'.format(celln,nt_glm_lag), facecolor='white', transparent=True)

In [None]:
# Msta = w_move[:,:-4].reshape((model_nsp.shape[-1],nt_glm_lag,len(titles))+nks)
# fig, ax = plt.subplots(4,nt_glm_lag,figsize=(20,10))
# for modeln in range(len(titles)):
#     crange = np.max(np.abs(Msta[celln,:,modeln]))
#     for n in range(nt_glm_lag):
#         img = ax[modeln,n].imshow(Msta[celln,n,modeln],cmap='RdBu_r',vmin=-crange,vmax=crange)
#         ax[modeln,n].axis('off')
#         ax[modeln,n].set_title('Lag:{:03d} ms'.format(int(1000*lag_list[n]*model_dt)))
#         ax[modeln,n].axis('off')
#     add_colorbar(img)

In [None]:
# Quick per-cell figure: STA at every lag on the top row, smoothed actual vs.
# predicted firing rate bottom-left, and the two loss traces bottom-right.
bin_length = 40
boxcar = np.ones(bin_length)            # un-normalised boxcar smoothing kernel
trim = slice(bin_length, -bin_length)   # drop one kernel length at each end
half_lags = nt_glm_lag // 2

for celln in tqdm([13, 67]):
    fig2 = plt.figure(constrained_layout=False, figsize=(20, 7))
    spec2 = gridspec.GridSpec(ncols=nt_glm_lag, nrows=2, figure=fig2)
    axs = np.array([fig2.add_subplot(spec2[0, col]) for col in range(nt_glm_lag)])
    f2_ax6 = fig2.add_subplot(spec2[1, :half_lags])
    f2_ax7 = fig2.add_subplot(spec2[1, half_lags:])

    # Shared symmetric colour range across all lags of this cell.
    crange = np.max(np.abs(sta_all[celln]))
    for lag_idx, ax in enumerate(axs):
        im = ax.imshow(sta_all[celln, lag_idx], cmap='RdBu_r', vmin=-crange, vmax=crange)
        cbar = add_colorbar(im)
        ax.axis('off')

    # Boxcar-smoothed spike counts converted to rates (divide by window duration).
    sp_conv = np.convolve(test_nsp[:, celln], boxcar, 'same')
    sp_smooth = (sp_conv / (bin_length * model_dt))[trim]
    pred_conv = np.convolve(pred_all[:, celln], boxcar, 'same')
    pred_smooth = (pred_conv / (bin_length * model_dt))[trim]

    f2_ax6.plot(np.arange(len(sp_smooth)) * model_dt, sp_smooth, 'k', lw=2)
    f2_ax6.plot(np.arange(len(pred_smooth)) * model_dt, pred_smooth, 'r', lw=2)
    f2_ax6.set_xlabel('Time (s)')
    f2_ax6.set_ylabel('Spike Rate')

    for loss_trace in (tloss_trace, vloss_trace):
        f2_ax7.plot(loss_trace[:, celln])
    f2_ax7.set_xlabel('Batch #')
    f2_ax7.set_ylabel('Loss')

    # Squared Pearson correlation between the smoothed traces.
    r2 = np.corrcoef(sp_smooth, pred_smooth)[0, 1] ** 2
    plt.suptitle('celln: {} $r^2$:{:.03f}'.format(celln, r2))
    plt.tight_layout()

## Make PDF of All Cells

In [592]:
##### Make PDF of All Cells #####
# For every cell, build a 3x5 summary page and append it to a multi-page PDF:
#   row 0: STA at each GLM lag
#   row 1: smoothed firing rates, eye tuning, head tuning, predicted-vs-actual
#          scatter, GLM movement weights
#   row 2: predicted-vs-actual rate curves split by quartile of each movement
#          variable, plus the MSE_add - MSE_mult summary image
bin_length = 40
do_shuffle = False
colors = plt.cm.cool(np.linspace(0, 1, 4))  # one colour per metric quartile
clrs = ['blue', 'orange', 'green', 'red']   # one colour per movement variable
for MovModel in [1]: #
    data, train_idx, test_idx = load_train_test(file_dict, save_dir, model_dt=model_dt, do_shuffle=do_shuffle, do_norm=False,free_move=free_move, has_imu=free_move, has_mouse=False)
    # At notebook top level locals() is the module namespace, so this injects
    # every key of `data` as a global (presumably the train_*/test_*/model_*
    # arrays used below -- verify against load_train_test).
    locals().update(data)
    shuffle_tag = '_shuffled' if do_shuffle else ''
    GLM_Vis = ioh5.load(save_dir/'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}{}.h5'.format(model_type,int(model_dt*1000), nt_glm_lag, MovModel, shuffle_tag))
    locals().update(GLM_Vis)

    ##### Explore Neurons #####
    # Initialize movement combinations
    titles = np.array(['Theta','Phi','Roll','Pitch']) # 'dg_p','dg_n' 'roll','pitch'
    titles_all = []
    for n in range(1,len(titles)+1):
        for combo in itertools.combinations(range(len(titles)), n):
            titles_all.append('_'.join(titles[list(combo)]))

    move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis]))
    move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis]))
    model_move = np.hstack((model_th[:,np.newaxis],model_phi[:,np.newaxis],model_roll[:,np.newaxis],model_pitch[:,np.newaxis]))
    # Mean-centre each movement variable.
    model_move = model_move - np.mean(model_move,axis=0)
    move_test = move_test - np.mean(move_test,axis=0)

    # Create all tuning curves for plotting
    N_bins = 10
    ncells = model_nsp.shape[-1]
    ax_ylims = np.zeros((ncells,len(titles)))
    tuning_curves = np.zeros((ncells,len(titles),N_bins-1))
    tuning_stds = np.zeros((ncells,len(titles),N_bins-1))
    var_ranges = np.zeros((len(titles),N_bins-1))
    for modeln in range(len(titles)):
        metric = move_test[:,modeln]
        tuning, tuning_std, var_range = tuning_curve(test_nsp, metric, N_bins=N_bins, model_dt=model_dt)
        tuning_curves[:,modeln] = tuning
        tuning_stds[:,modeln] = tuning_std
        ax_ylims[:,modeln] = np.max(tuning,axis=1)
        var_ranges[modeln] = var_range

    # Split the movement weights once, before the per-cell loop: the MovModel 3
    # branch re-slices w_move, so repeating it for every cell would fail on the
    # second iteration.
    if MovModel == 1:
        w_move = np.zeros((ncells,len(titles)))
    elif MovModel == 3:
        Msta = w_move[:,:-len(titles)].reshape((ncells,nt_glm_lag,len(titles))+nks)
        w_move = w_move[:,-len(titles):]

    # Quantile grids (identical for every cell).
    quartiles = np.arange(0,1.25,.25)
    spike_percentiles = np.arange(0,1.25,.25)
    spike_percentiles[-1] = .99                 # top bin edge at the 99th percentile
    spk_percentile2 = np.arange(.125,1.125,.25) # x-position of each bin
    stat_bins = len(spike_percentiles)
    n_quart = len(quartiles)-1

    # Per-cell fit results; allocated once so they accumulate across cells
    # instead of being reset to zero on every iteration.
    mse_add = np.zeros((ncells,len(titles),n_quart))
    mse_mult = np.zeros((ncells,len(titles),n_quart))
    alpha_add = np.zeros((ncells,len(titles),n_quart))
    alpha_mult = np.zeros((ncells,len(titles),n_quart))
    traces = np.zeros((ncells,len(titles),n_quart,stat_bins-1)) # (cell,metric,quartile,FR)
    traces_mean = np.zeros((ncells,len(titles),stat_bins-1)) # (cell,metric,FR)
    edges_all = np.zeros((ncells,len(titles),n_quart,stat_bins-1)) # (cell,metric,quartile,FR)

    boxcar = np.ones(bin_length)
    trim = slice(bin_length,-bin_length)

    pdf_suffix = '_CellSummary_shuff.pdf' if do_shuffle else '_CellSummary.pdf'
    pdf_name = FigPath/ ('VisMov_{}_dt{:03d}_Lags{:02d}_MovModel{:d}'.format(model_type,int(model_dt*1000),nt_glm_lag, MovModel) + pdf_suffix)
    with PdfPages(pdf_name) as pdf:
        for celln in tqdm(range(ncells)):
            fig, axs = plt.subplots(3,5, figsize=(35,15))
            # Replace the five top-row axes with one axis per GLM lag.
            gs = axs[0,0].get_gridspec()
            gs_sub = gs[0,:].subgridspec(1,nt_glm_lag)
            for ax in axs[0,:]:
                ax.remove()
            top_grid = [fig.add_subplot(gs_sub[0,ind]) for ind in range(nt_glm_lag)]

            # Per-bin counts converted to rates, plus boxcar-smoothed versions.
            predcell = pred_all[:,celln]/model_dt
            nspcell = test_nsp[:,celln]/model_dt
            test_nsp_smooth = (np.convolve(test_nsp[:,celln], boxcar, 'same') / (bin_length * model_dt))[trim]
            pred_smooth = (np.convolve(pred_all[:,celln], boxcar, 'same') / (bin_length * model_dt))[trim]

            # Predicted-rate bin edges / bin positions from quantiles of this cell's prediction.
            test_nsp_range = np.quantile(nspcell,[.01,1])
            pred_rangelin = np.quantile(predcell,spike_percentiles)
            xbin_pts = np.quantile(predcell,spk_percentile2)

            axs[1,0].plot(np.arange(len(test_nsp_smooth))*model_dt,test_nsp_smooth,'k',label='test FR')
            axs[1,0].plot(np.arange(len(pred_smooth))*model_dt,pred_smooth,'r', label='pred FR')
            axs[1,0].set_xlabel('Time (s)')
            axs[1,0].set_ylabel('Firing Rate (spks/s)')
            axs[1,0].legend()
            axs[1,0].set_title('Smoothed FRs')

            # STA at each lag with a shared symmetric colour range.
            crange = np.max(np.abs(sta_all[celln]))
            for n in range(nt_glm_lag):
                img = top_grid[n].imshow(sta_all[celln,n],cmap='RdBu_r',vmin=-crange,vmax=crange)
                top_grid[n].axis('off')
                top_grid[n].set_title('Lag:{:03d} ms'.format(int(1000*lag_list[n]*model_dt)))
            add_colorbar(img)

            # Eye (first two variables) and head (remaining) tuning curves.
            # Head quartile edges are rounded to one decimal, as before.
            top_yaxs = np.max(ax_ylims[celln])+2*np.nanmax(tuning_stds[celln])
            tuning_panels = (
                (axs[1,1], range(len(titles)-2), 'Eye Tuning Curves', False),
                (axs[1,2], range(2,len(titles)), 'Head Tuning Curves', True),
            )
            for ax_tc, panel_models, panel_title, round_edges in tuning_panels:
                for i, modeln in enumerate(panel_models):
                    metric = move_test[:,modeln]
                    nranges = np.quantile(metric,quartiles)
                    if round_edges:
                        nranges = np.round(nranges,decimals=1)
                    stat_range, _, _ = binned_statistic(metric,test_nsp[:,celln],statistic='mean',bins=nranges)
                    edge_mids = np.quantile(metric,spk_percentile2)
                    # Shade each quartile of the metric; each variable gets half the panel height.
                    for m in range(len(nranges)-1):
                        ax_tc.axvspan(nranges[m], nranges[m+1],ymin=i/2,ymax=(i+1)/2,alpha=0.8, color=colors[m],zorder=0)
                    ax_tc.plot(edge_mids,stat_range/model_dt,'.-', ms=20, lw=4,c=clrs[modeln])

                ax_tc.set_ylim(bottom=0,top=top_yaxs)
                ax_tc.set_xlim(-30,30)
                ax_tc.set_xlabel(r'Angle ($ ^{\degree}$)')
                ax_tc.set_ylabel('Spikes/s')
                ax_tc.set_title(panel_title)
                # Two separate legends, one per curve; re-add the first because
                # a second legend() call replaces it.
                lines = ax_tc.get_lines()
                legend1 = ax_tc.legend([lines[0]],[titles[panel_models[0]]],bbox_to_anchor=(1.01, .2), fontsize=12)
                ax_tc.legend([lines[1]],[titles[panel_models[1]]],bbox_to_anchor=(1.01, .9), fontsize=12)
                ax_tc.add_artist(legend1)

            # Predicted vs. actual rate for this cell. pred_all/test_nsp are
            # indexed [:, celln] everywhere else, so use the per-cell columns
            # rather than row `celln`.
            axs[1,3].scatter(predcell,nspcell,c='k',s=15)
            axs[1,3].plot(np.linspace(test_nsp_range[0],test_nsp_range[1]),np.linspace(test_nsp_range[0],test_nsp_range[1]),'k--',zorder=0)
            axs[1,3].set_xlabel('Predicted Spike Rate')
            axs[1,3].set_ylabel('Actual Spike Rate')

            for modeln in range(len(titles)):
                axs[1,4].bar(modeln, w_move[celln,modeln], color=clrs[modeln])
            axs[1,4].set_xticks(np.arange(0,len(titles)))
            axs[1,4].set_xticklabels(titles)
            axs[1,4].set_ylabel('GLM Weight')

            # Mean actual rate per predicted-rate bin over all test samples
            # (independent of the movement variable, so computed once per cell).
            stat_all, _, _ = binned_statistic(predcell,nspcell, statistic='mean',bins=pred_rangelin)
            max_fr = np.nanmax(stat_all)
            lim_pad = .5*np.std(pred_rangelin)
            for modeln in range(len(titles)):
                metric = move_test[:,modeln]
                nranges = np.quantile(metric,quartiles)
                traces_mean[celln,modeln] = stat_all

                for n in range(len(nranges)-1):
                    # Samples whose metric falls in quartile n (lower edge exclusive).
                    ind = np.where((metric<=nranges[n+1])&(metric>nranges[n]))[0]
                    stat_range, _, _ = binned_statistic(predcell[ind], nspcell[ind], statistic='mean',bins=pred_rangelin)
                    traces[celln,modeln,n] = stat_range
                    edges_all[celln,modeln,n] = xbin_pts
                    # Scalar fits relating the quartile curve to the all-data
                    # curve, both scaled by max_fr.
                    res_add = minimize_scalar(f_add,args=(stat_range/max_fr, stat_all/max_fr))
                    res_mult = minimize_scalar(f_mult,args=(stat_range/max_fr, stat_all/max_fr))
                    mse_add[celln, modeln, n] = res_add.fun
                    mse_mult[celln, modeln, n] = res_mult.fun
                    alpha_add[celln, modeln, n] = res_add.x
                    alpha_mult[celln, modeln, n] = res_mult.x

                    axs[2,modeln].plot(xbin_pts, stat_range,'.-', c=colors[n],label='{:.02f} : {:.02f}'.format(nranges[n],nranges[n+1]),lw=4,ms=20,alpha=.9)

                axs[2,modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
                axs[2,modeln].set_xlabel('Predicted Spike Rate')
                axs[2,modeln].set_ylabel('Actual Spike Rate')

                # NaN-aware limits: an empty bin yields NaN from binned_statistic,
                # and NaN axis limits are rejected by matplotlib.
                curve_vals = np.hstack((xbin_pts,traces[celln,modeln].ravel()))
                lim_max = np.nanmax(curve_vals)+lim_pad
                lim_min = np.nanmin(curve_vals)-lim_pad
                lims = (0, lim_max) if lim_min<0 else (lim_min,lim_max)
                axs[2,modeln].plot(np.linspace(lims[0],lims[1]),np.linspace(lims[0],lims[1]),'k--',zorder=0)
                axs[2,modeln].plot(xbin_pts, stat_all,'.-', c='k', lw=5, ms=20, label='All_data', alpha=.8)
                axs[2,modeln].legend(bbox_to_anchor=(1.01, 1), fontsize=12)
                axs[2,modeln].axis('equal')
                axs[2,modeln].set(xlim=lims, ylim=lims)
                axs[2,modeln].set_ylim(bottom=0)

            # Positive values: the multiplicative fit has the lower error.
            dmodel = mse_add[celln]-mse_mult[celln]
            crange = np.nanmax(np.abs(dmodel))
            im = axs[2,-1].imshow(dmodel,cmap='seismic',vmin=-crange,vmax=crange)
            axs[2,-1].set_yticks(np.arange(0,len(titles)))
            axs[2,-1].set_yticklabels(titles)
            axs[2,-1].set_ylabel('Movement Model')
            axs[2,-1].set_xticks(np.arange(0,n_quart))
            axs[2,-1].set_xticklabels(['.25','.5','.75','1'])
            axs[2,-1].set_xlabel('Quantile Range')
            axs[2,-1].set_title('$MSE_{add}$ - $MSE_{mult}$')
            cbar = add_colorbar(im)

            fig.suptitle('Celln:{}, r2={:.03f}'.format(celln,r2_all[celln]),y=1,fontsize=30)
            fig.tight_layout()

            # Save and close explicitly by handle so pages never pile up in memory.
            pdf.savefig(fig)
            plt.close(fig)

# fig.savefig(FigPath/'CellSummary_N{}.png'.format(celln), facecolor='white', transparent=True)

Done Loading Aligned Data
TRAIN: 15628 TEST: 6698


  0%|          | 0/128 [00:00<?, ?it/s]

In [None]:
# Per-cell figure: STA at every lag (top row), smoothed actual vs. predicted
# firing rate, loss traces, and the GLM movement weights in the last column.
bin_length = 40
boxcar = np.ones(bin_length)            # un-normalised boxcar smoothing kernel
trim = slice(bin_length, -bin_length)   # drop one kernel length at each end
half_lags = nt_glm_lag // 2

for celln in tqdm([21, 25, 51, 117]):
    fig2 = plt.figure(constrained_layout=False, figsize=(20, 7))
    spec2 = gridspec.GridSpec(ncols=nt_glm_lag, nrows=2, figure=fig2)
    axs = np.array([fig2.add_subplot(spec2[0, col]) for col in range(nt_glm_lag)])
    f2_ax6 = fig2.add_subplot(spec2[1, :half_lags])
    f2_ax7 = fig2.add_subplot(spec2[1, half_lags:-1])
    f2_ax8 = fig2.add_subplot(spec2[1, -1])

    # Shared symmetric colour range across all lags of this cell.
    crange = np.max(np.abs(sta_all[celln]))
    for lag_idx, ax in enumerate(axs):
        im = ax.imshow(sta_all[celln, lag_idx], cmap='RdBu_r', vmin=-crange, vmax=crange)
        cbar = add_colorbar(im)
        ax.axis('off')

    # Boxcar-smoothed spike counts converted to rates (divide by window duration).
    sp_conv = np.convolve(test_nsp[:, celln], boxcar, 'same')
    sp_smooth = (sp_conv / (bin_length * model_dt))[trim]
    pred_conv = np.convolve(pred_all[:, celln], boxcar, 'same')
    pred_smooth = (pred_conv / (bin_length * model_dt))[trim]

    f2_ax6.plot(np.arange(len(sp_smooth)) * model_dt, sp_smooth, 'k', lw=2)
    f2_ax6.plot(np.arange(len(pred_smooth)) * model_dt, pred_smooth, 'r', lw=2)
    f2_ax6.set_xlabel('Time (s)')
    f2_ax6.set_ylabel('Spike Rate')

    for loss_trace in (tloss_trace, vloss_trace):
        f2_ax7.plot(loss_trace[:, celln])
    f2_ax7.set_xlabel('Batch #')
    f2_ax7.set_ylabel('Loss')

    # Squared Pearson correlation between the smoothed traces.
    r2 = np.corrcoef(sp_smooth, pred_smooth)[0, 1] ** 2

    # One bar per movement variable, coloured like the other panels.
    n_titles = len(titles)
    for modeln in range(n_titles):
        f2_ax8.bar(modeln, w_move[celln, modeln], color=clrs[modeln])
        f2_ax8.set_xticks(np.arange(0, n_titles))
        f2_ax8.set_xticklabels(titles)
        f2_ax8.set_ylabel('GLM Weight')

    plt.suptitle('celln: {} $r^2$:{:.03f}'.format(celln, r2))
    plt.tight_layout()

## Shuffle Comparison

In [None]:
# Histogram of per-cell r2_all for MovModel 1 and MovModel 2, with the cells
# that fall below the shuffled-control threshold overlaid in red.
model_dt = .05
# for model_dt in [.025,.05,.1]:
vismov_fmt = 'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}.h5'
vismov_shuff_fmt = 'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}_shuffled.h5'
fname_args = (model_type, int(model_dt*1000), nt_glm_lag)

GLM_VisMov_shuff = ioh5.load(save_dir/vismov_shuff_fmt.format(*fname_args, 1))
GLM_VisMov_m0 = ioh5.load(save_dir/vismov_fmt.format(*fname_args, 0))
GLM_VisMov_m1 = ioh5.load(save_dir/vismov_fmt.format(*fname_args, 1))
GLM_VisMov_m2 = ioh5.load(save_dir/vismov_fmt.format(*fname_args, 2))
# GLM_VisMov_m3 = ioh5.load(save_dir/vismov_fmt.format(*fname_args, 3))

r2_m1 = GLM_VisMov_m1['r2_all']
r2_m2 = GLM_VisMov_m2['r2_all']
r2_shuff = GLM_VisMov_shuff['r2_all']

# Threshold: largest value seen in the shuffled control.
# NOTE(review): the shuffled r2_all is squared here while the m1/m2 values it
# is compared against are not -- confirm whether r2_all stores r or r^2.
max_shuff = np.max(r2_shuff**2)

sig = r2_m1[r2_m1 > max_shuff]
non_sig = r2_m1[r2_m1 < max_shuff]
non_sig_vm = r2_m2[r2_m2 < max_shuff]

# All three histograms share the same fixed-width bins on [0, 1).
hbins = .02
r2_bins = np.arange(0, 1, hbins)
count, edges = np.histogram(r2_m1, bins=r2_bins)
count_m2, edges_m2 = np.histogram(r2_m2, bins=r2_bins)
count_shuff, edges_shuff = np.histogram(non_sig, bins=r2_bins)

# Bin centres = mean of neighbouring edges.
edges_mid = (edges[:-1] + edges[1:]) / 2
edges_mid_m2 = (edges_m2[:-1] + edges_m2[1:]) / 2
edges_mid_shuff = (edges_shuff[:-1] + edges_shuff[1:]) / 2

fig, ax = plt.subplots(1, 1, figsize=(7, 5))
# Counts are shown as fractions of cells.
# NOTE(review): the red bars come from m1 values (non_sig) but are divided by
# the number of shuffled cells -- equivalent only if both files hold the same
# number of cells.
ax.bar(edges_mid, count/len(r2_m1), color='k', width=hbins, alpha=.5, label='Significant $R^2$')
ax.bar(edges_mid_m2, count_m2/len(r2_m2), color='b', width=hbins, alpha=.5, label='Significant VisMov $R^2$')
ax.bar(edges_mid_shuff, count_shuff/len(r2_shuff), color='r', width=hbins, alpha=1, label='NonSignificant $R^2$')

ax.set_xlabel('$R^2$')
frac_ticks = np.arange(0, 1, .1)
ax.set_yticks(frac_ticks)
ax.set_yticklabels(np.round(frac_ticks, decimals=3))
ax.set_xlim(-.01, .5)
ax.set_ylim(0, .2)
ax.legend(fontsize=12, loc=(.3, .9))

In [None]:
# Reload the saved GLM results for MovModel 0, 1 and 2 at the current
# model_dt / lag setting.
GLM_VisMov_m0, GLM_VisMov_m1, GLM_VisMov_m2 = [
    ioh5.load(save_dir/'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}.h5'.format(model_type,int(model_dt*1000), nt_glm_lag, mov_model))
    for mov_model in (0, 1, 2)
]
# Element-wise sum of the MovModel 1 and MovModel 0 r2_all values ("V+M").
r_vpm = GLM_VisMov_m1['r2_all'] + GLM_VisMov_m0['r2_all']

In [None]:
# Pairwise scatter plots of per-cell r2_all between the fitted models, each
# with a dashed unity line for reference. One figure per comparison.
r2_m0 = GLM_VisMov_m0['r2_all']
r2_m1 = GLM_VisMov_m1['r2_all']
r2_m2 = GLM_VisMov_m2['r2_all']

# Unity-line extent: largest r2 over the three models plus 1.5 pooled std.
r2_stack = (r2_m0, r2_m1, r2_m2)
lim_max = np.max(r2_stack) + 1.5*np.std(r2_stack)
lims = (0, lim_max)
unity = np.linspace(lims[0], lims[1])

# Set to True to also write each figure to FigPath (off by default, matching
# the previously commented-out savefig calls).
save_r2_figs = False

# (x values, y values, x label, y label, output file name)
r2_comparisons = (
    (r2_m1, r2_m2, 'Visual only $R^2$', 'VisMov $R^2$', 'M1M2_R2_Comparison.png'),
    (r2_m0, r2_m2, 'Movement Only $R^2$', 'VisMov $R^2$', 'M0M2_R2_Comparison.png'),
    (r2_m0, r2_m1, 'Movement Only $R^2$', 'Visual $R^2$', 'M0M1_R2_Comparison.png'),
    (r_vpm, r2_m2, 'V+M $R^2$', 'VisMov $R^2$', 'VPM_R2_Comparison.png'),
)

for x_r2, y_r2, x_label, y_label, fig_name in r2_comparisons:
    fig, ax = plt.subplots(figsize=(7,7))
    # NOTE(review): every cell is plotted here, yet the legend entry reads
    # 'Significant' -- the non-significant overlay was disabled upstream.
    ax.scatter(x_r2, y_r2, c='k', label='Significant $R^2$')
    ax.plot(unity, unity, 'k--', zorder=0)
    ax.legend(fontsize=12, loc=(.3,.9))
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    plt.tight_layout()
    if save_r2_figs:
        fig.savefig(FigPath/fig_name, facecolor='white', transparent=True)


## Plotting Motor Only

In [None]:
# Motor-only analysis setup: reload the aligned data and the saved movement
# GLM results, assemble the mean-centred movement matrices, and precompute a
# tuning curve for every cell against every movement variable.
bin_length = 40
data, train_idx, test_idx = load_train_test(file_dict, save_dir, model_dt=model_dt, do_shuffle=do_shuffle, do_norm=False,free_move=free_move, has_imu=free_move, has_mouse=False)
# At notebook top level locals() is the module namespace, so every key of
# `data` becomes a global variable.
locals().update(data)

mov_fname_args = (model_type, int(model_dt*1000), nt_glm_lag, MovModel)
if do_shuffle:
    mov_fname = 'GLM_{}_Data_Mov_dt{:03d}_T{:02d}_MovModel{:d}_shuffled.h5'.format(*mov_fname_args)
else:
    mov_fname = 'GLM_{}_Data_Mov_dt{:03d}_T{:02d}_MovModel{:d}.h5'.format(*mov_fname_args)
save_datafile = save_dir/mov_fname
GLM_Vis = ioh5.load(save_datafile)
locals().update(GLM_Vis)

##### Explore Neurons #####
colors = plt.cm.cool(np.linspace(0,1,4))
clrs = ['blue','orange','green','red']

# Initialize movement combinations: every non-empty subset of the movement
# variables, named by joining the member titles with '_'.
titles = np.array(['Theta','Phi','Roll','Pitch']) # 'dg_p','dg_n' 'roll','pitch'
n_titles = len(titles)
titles_all = [
    '_'.join(titles[list(combo)])
    for n_terms in range(1, n_titles+1)
    for combo in itertools.combinations(range(n_titles), n_terms)
]

# One column per movement variable, in the same order as `titles`.
move_train = np.stack((train_th, train_phi, train_roll, train_pitch), axis=1)
move_test = np.stack((test_th, test_phi, test_roll, test_pitch), axis=1)
model_move = np.stack((model_th, model_phi, model_roll, model_pitch), axis=1)
# Mean-centre each column.
model_move = model_move - np.mean(model_move, axis=0)
move_test = move_test - np.mean(move_test, axis=0)

# Create all tuning curves for plotting
N_bins = 10
ncells = model_nsp.shape[-1]
ax_ylims = np.zeros((ncells, n_titles))
tuning_curves = np.zeros((ncells, n_titles, N_bins-1))
tuning_stds = np.zeros((ncells, n_titles, N_bins-1))
var_ranges = np.zeros((n_titles, N_bins-1))
for modeln in range(n_titles):
    metric = move_test[:, modeln]
    tuning, tuning_std, var_range = tuning_curve(test_nsp, metric, N_bins=N_bins, model_dt=model_dt)
    tuning_curves[:, modeln] = tuning
    tuning_stds[:, modeln] = tuning_std
    var_ranges[modeln] = var_range
    # Per-cell peak of the tuning curve, used later to scale the y-axis.
    ax_ylims[:, modeln] = np.max(tuning, axis=1)

In [None]:
celln = 6 #np.argmax(mr2)
bin_length = 40
ncells=model_nsp.shape[-1]
colors = plt.cm.cool(np.linspace(0,1,4))
clrs = ['blue','orange','green','red']
quartiles = np.arange(0,1.25,.25)

fig, axs = plt.subplots(2,5, figsize=((35,10))) 

predcell = pred_all[:,celln]/model_dt
nspcell = test_nsp[:,celln]/model_dt
test_nsp_smooth=(np.convolve(test_nsp[:,celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
pred_smooth=(np.convolve(pred_all[:,celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
# Set up predicted spike range between 1-99th percentile
stat_bins = 5
pred_range = np.quantile(predcell,[.1,.9])
test_nsp_range = np.quantile(nspcell,[.01,1])
spike_percentiles = np.arange(0,1.25,.25)
spike_percentiles[-1]=.99
spk_percentile2 = np.arange(.125,1.125,.25)
pred_rangelin = np.quantile(predcell,spike_percentiles)
xbin_pts = np.quantile(predcell,spk_percentile2)
stat_bins = len(pred_rangelin) #5


axs[0,0].plot(np.arange(len(test_nsp_smooth))*model_dt,test_nsp_smooth,'k',label='test FR')
axs[0,0].plot(np.arange(len(pred_smooth))*model_dt,pred_smooth,'r', label='pred FR')
axs[0,0].set_xlabel('Time (s)')
axs[0,0].set_ylabel('Firing Rate (spks/s)')
axs[0,0].legend()
axs[0,0].set_title('Smoothed FRs')

# Eye Tuning Curve
top_yaxs = np.max(ax_ylims[celln])+2*np.max(tuning_stds[celln])
for i,modeln in enumerate(range(len(titles)-2)):
    metric = move_test[:,modeln]
    nranges = np.quantile(metric,quartiles)
    stat_range, edges, _ = binned_statistic(metric,test_nsp[:,celln],statistic='mean',bins=nranges)
    edge_mids = np.quantile(metric,spk_percentile2)#np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
    cmap = mpl.colors.ListedColormap(colors, N=colors.shape[0])
    norm = mpl.colors.BoundaryNorm(boundaries=np.floor(nranges), ncolors=len(cmap.colors))
    for m in range(len(nranges)-1):
        axs[0,1].axvspan(nranges[m], nranges[m+1],ymin=i*1/2,ymax=(i+1)*1/2,alpha=0.8, color=colors[m],zorder=0)
    #     axs[0,1].errorbar(var_ranges[modeln],tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln],label=titles[modeln],c=clrs[modeln],lw=4,elinewidth=3)
    axs[0,1].plot(edge_mids,stat_range/model_dt,'.-', ms=20, lw=4,c=clrs[modeln])

axs[0,1].set_ylim(bottom=0,top=np.max(ax_ylims,axis=1)[celln]+2*np.max(tuning_stds,axis=(1,2))[celln])
axs[0,1].set_xlim(-30,30)
axs[0,1].set_xlabel('Angle ($ ^{\degree}$)')
axs[0,1].set_ylabel('Spikes/s')
axs[0,1].set_title('Eye Tuning Curves')
lines = axs[0,1].get_lines()
legend1 = axs[0,1].legend([lines[0]],[titles[0]],bbox_to_anchor=(1.01, .2), fontsize=12)
legend2 = axs[0,1].legend([lines[1]],[titles[1]],bbox_to_anchor=(1.01, .9), fontsize=12)
axs[0,1].add_artist(legend1)

# Head Tuning Curves
top_yaxs = np.max(ax_ylims[celln])+2*np.max(tuning_stds[celln])
for i, modeln in enumerate(range(2,len(titles))):
    metric = move_test[:,modeln]
#     nranges = np.round(np.quantile(var_ranges[modeln],quartiles),decimals=1)
    nranges = np.round(np.quantile(metric,quartiles),decimals=1)
    stat_range, edges, _ = binned_statistic(metric,test_nsp[:,celln],statistic='mean',bins=nranges)
    edge_mids = np.quantile(metric,spk_percentile2)#np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
    cmap = mpl.colors.ListedColormap(colors, N=colors.shape[0])
    norm = mpl.colors.BoundaryNorm(boundaries=np.floor(nranges), ncolors=len(cmap.colors))
    for m in range(len(nranges)-1):
        axs[0,2].axvspan(nranges[m], nranges[m+1],ymin=i*1/2,ymax=(i+1)*1/2,alpha=0.8, color=colors[m],zorder=0)
#     axs[0,2].errorbar(var_ranges[modeln], tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln], label=titles[modeln], c=clrs[modeln],lw=4,elinewidth=3)
    axs[0,2].plot(edge_mids,stat_range/model_dt,'.-', ms=20, lw=4,c=clrs[modeln])

axs[0,2].set_ylim(bottom=0,top=top_yaxs)
axs[0,2].set_xlim(-30,30)
axs[0,2].set_xlabel('Angle ($ ^{\degree}$)')
axs[0,2].set_ylabel('Spikes/s')
axs[0,2].set_title('Head Tuning Curves')
lines = axs[0,2].get_lines()
legend1 = axs[0,2].legend([lines[0]],[titles[2]],bbox_to_anchor=(1.01, .2), fontsize=12)
legend2 = axs[0,2].legend([lines[1]],[titles[3]],bbox_to_anchor=(1.01, .9), fontsize=12)
axs[0,2].add_artist(legend1)

# axs[0,2].legend(bbox_to_anchor=(1.01, 1), fontsize=12)


# pred_rangelin = np.linspace(pred_range[0],pred_range[1],stat_bins)
axs[0,3].scatter(pred_all[celln]/model_dt,test_nsp[celln]/model_dt,c='k',s=15)
axs[0,3].plot(np.linspace(test_nsp_range[0],test_nsp_range[1]),np.linspace(test_nsp_range[0],test_nsp_range[1]),'k--',zorder=0)
axs[0,3].set_xlabel('Predicted Spike Rate')
axs[0,3].set_ylabel('Actual Spike Rate')
cbar = add_colorbar(img)
# cbar.set_label('count')

if MovModel == 1:
    w_move = np.zeros((model_nsp.shape[-1],len(titles)))
elif MovModel == 3:
    Msta = w_move[:,:-len(titles)].reshape((model_nsp.shape[-1],nt_glm_lag,len(titles))+nks)
    w_move = w_move[:,-len(titles):]
for modeln in range(len(titles)):
    axs[0,4].bar(modeln, w_move[celln,modeln], color=clrs[modeln])
    axs[0,4].set_xticks(np.arange(0,len(titles)))
    axs[0,4].set_xticklabels(titles)
    axs[0,4].set_ylabel('GLM Weight')


mse_add = np.zeros((ncells,len(titles),len(quartiles)-1))
mse_mult = np.zeros((ncells,len(titles),len(quartiles)-1))
alpha_add = np.zeros((ncells,len(titles),len(quartiles)-1))
alpha_mult = np.zeros((ncells,len(titles),len(quartiles)-1))

traces = np.zeros((ncells,len(titles),len(quartiles)-1,stat_bins-1)) # (model_type,quartile,FR)
traces_mean = np.zeros((ncells,len(titles),stat_bins-1)) # (model_type,quartile,FR)
edges_all = np.zeros((ncells,len(titles),len(quartiles)-1,stat_bins-1)) # (model_type,quartile,FR)
# df_traces = pd.DataFrame([],columns=['modeln','quartile','FR']) 
for modeln in range(len(titles)):
    metric = move_test[:,modeln]
    nranges = np.quantile(metric,quartiles)# np.linspace(np.nanmean(metric)-2*np.nanstd(metric), np.nanmean(metric)+2*np.nanstd(metric),N_bins)
    stat_all, edges, _ = binned_statistic(predcell,nspcell, statistic='mean',bins=pred_rangelin)
    edge_mids = xbin_pts#np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
    traces_mean[celln,modeln]=stat_all
    max_fr = np.max(stat_all)
#     axs[0,modeln].set_xlim(0,pred_range[1]+np.std(pred_range))
#     axs[0,modeln].set_ylim(0,np.max(stat)+np.std(stat))

    for n in range(len(nranges)-1):
        ind = np.where(((metric<=nranges[n+1])&(metric>nranges[n])))[0]
        pred = predcell[ind]
        sp = nspcell[ind]

        stat_range, edges, _ = binned_statistic(pred, sp, statistic='mean',bins=pred_rangelin)
        edge_mids = xbin_pts #np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
        traces[celln,modeln,n]=stat_range
        edges_all[celln,modeln,n]=edge_mids
        res_add = minimize_scalar(f_add,args=(stat_range/max_fr, stat_all/max_fr))
        res_mult = minimize_scalar(f_mult,args=(stat_range/max_fr, stat_all/max_fr))
        mse_add[celln, modeln, n] = res_add.fun
        mse_mult[celln, modeln, n] = res_mult.fun
        alpha_add[celln, modeln, n] = res_add.x
        alpha_mult[celln, modeln, n] = res_mult.x

        axs[1,modeln].plot(edge_mids, stat_range,'.-', c=colors[n],label='{:.02f} : {:.02f}'.format(nranges[n],nranges[n+1]),lw=4,ms=20,alpha=.9)
        axs[1,modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
        axs[1,modeln].set_xlabel('Predicted Spike Rate')
        axs[1,modeln].set_ylabel('Actual Spike Rate')
    
    axs[1,modeln].plot([0, 1], [0, 1], 'k--', transform=axs[1,modeln].transAxes, lw=2, zorder=0)
    axs[1,modeln].plot(edge_mids, stat_all,'.-', c='k', lw=5, ms=20, label='All_data', alpha=.8)
    axs[1,modeln].legend(bbox_to_anchor=(1.01, 1), fontsize=12)
#     axs[1,modeln].axis('equal')
    axs[1,modeln].set_xlim(left=0)
#     axs[1,modeln].set(xlim=lims, ylim=lims)
#     axs[1,modeln].set_xlim([0,xbin_pts[-1]])
    axs[1,modeln].set_ylim(bottom=0)

dmodel = mse_add[celln]-mse_mult[celln]
crange = np.max(np.abs(dmodel))
im = axs[1,-1].imshow(dmodel,cmap='seismic',vmin=-crange,vmax=crange)
axs[1,-1].set_yticks(np.arange(0,4))
axs[1,-1].set_yticklabels(titles)
axs[1,-1].set_ylabel('Movement Model')
axs[1,-1].set_xticks(np.arange(0,4))
axs[1,-1].set_xticklabels(['.25','.5','.75','1'])
axs[1,-1].set_xlabel('Quantile Range')
axs[1,-1].set_title('$MSE_{add}$ - $MSE_{mult}$')
cbar = add_colorbar(im)

plt.suptitle('Celln:{}, r2={:.03f}'.format(celln,r2_all[celln]),y=1,fontsize=30)
plt.tight_layout()


# fig.savefig(FigPath/'CellSummary_N{}_T{:02d}.png'.format(celln,nt_glm_lag), facecolor='white', transparent=True)

In [None]:
##### Make PDF of All Cells #####
# For every model_dt: load the aligned data and the fitted GLM results, then
# write one summary page per cell into a single PDF. Each page shows
#   row 0: smoothed firing rates, eye tuning, head tuning,
#          predicted-vs-actual scatter, GLM movement weights
#   row 1: predicted-vs-actual rate split by quartile of each movement
#          variable, plus a heat map of MSE_add - MSE_mult.
bin_length = 40
MovModel = 1
do_shuffle = False
colors = plt.cm.cool(np.linspace(0, 1, 4))
clrs = ['blue', 'orange', 'green', 'red']
for model_dt in [.05]:
    data, train_idx, test_idx = load_train_test(file_dict, save_dir, model_dt=model_dt, do_shuffle=do_shuffle, do_norm=False, free_move=free_move, has_imu=free_move, has_mouse=False)
    # At notebook top level locals() is the global namespace, so this exposes
    # every key of `data` (train_*, test_*, model_* arrays) as a variable.
    locals().update(data)
    glm_suffix = '_shuffled' if do_shuffle else ''
    GLM_Vis = ioh5.load(save_dir/'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}{}.h5'.format(model_type, int(model_dt*1000), nt_glm_lag, MovModel, glm_suffix))
    locals().update(GLM_Vis)

    ##### Explore Neurons #####
    # Initialize movement combinations
    titles = np.array(['Theta', 'Phi', 'Roll', 'Pitch'])
    titles_all = []
    for n in range(1, len(titles)+1):
        for combo in itertools.combinations(np.arange(len(titles)), n):
            titles_all.append('_'.join(titles[list(combo)]))

    # One column per movement variable, ordered like `titles`.
    move_train = np.column_stack((train_th, train_phi, train_roll, train_pitch))
    move_test = np.column_stack((test_th, test_phi, test_roll, test_pitch))
    model_move = np.column_stack((model_th, model_phi, model_roll, model_pitch))
    model_move = model_move - np.mean(model_move, axis=0)
    move_test = move_test - np.mean(move_test, axis=0)

    # Create all tuning curves for plotting
    N_bins = 10
    ncells = model_nsp.shape[-1]
    ax_ylims = np.zeros((ncells, len(titles)))
    tuning_curves = np.zeros((ncells, len(titles), N_bins-1))
    tuning_stds = np.zeros((ncells, len(titles), N_bins-1))
    var_ranges = np.zeros((len(titles), N_bins-1))
    for modeln in range(len(titles)):
        metric = move_test[:, modeln]
        tuning, tuning_std, var_range = tuning_curve(test_nsp, metric, N_bins=N_bins, model_dt=model_dt)
        tuning_curves[:, modeln] = tuning
        tuning_stds[:, modeln] = tuning_std
        ax_ylims[:, modeln] = np.max(tuning, axis=1)
        var_ranges[modeln] = var_range

    # Quantile grids shared by every cell.
    quartiles = np.arange(0, 1.25, .25)
    spike_percentiles = np.arange(0, 1.25, .25)
    spike_percentiles[-1] = .99          # clip the top bin edge at the 99th percentile
    spk_percentile2 = np.arange(.125, 1.125, .25)  # mid-points of the quartile bins
    stat_bins = len(spike_percentiles)
    nquart = len(quartiles) - 1

    # Per-cell results. Allocated once, outside the cell loop, so that values
    # from every cell are still available after the PDF has been written.
    mse_add = np.zeros((ncells, len(titles), nquart))
    mse_mult = np.zeros((ncells, len(titles), nquart))
    alpha_add = np.zeros((ncells, len(titles), nquart))
    alpha_mult = np.zeros((ncells, len(titles), nquart))
    traces = np.zeros((ncells, len(titles), nquart, stat_bins-1))      # (cell, model_type, quartile, FR)
    traces_mean = np.zeros((ncells, len(titles), stat_bins-1))          # (cell, model_type, FR)
    edges_all = np.zeros((ncells, len(titles), nquart, stat_bins-1))   # (cell, model_type, quartile, FR)

    # Movement weights. This must run once per loaded file: slicing `w_move`
    # inside the cell loop would re-slice the already-sliced array.
    if MovModel == 1:
        w_move = np.zeros((ncells, len(titles)))
    elif MovModel == 3:
        Msta = w_move[:, :-len(titles)].reshape((ncells, nt_glm_lag, len(titles))+nks)
        w_move = w_move[:, -len(titles):]

    # The model type is part of the file name; the previous format string had
    # no placeholder for it, which made .format() raise.
    pdf_suffix = '_CellSummary_shuff.pdf' if do_shuffle else '_CellSummary.pdf'
    pdf_name = FigPath/('{}_VisMov_dt{:03d}_T{:02d}_MovModel{:d}'.format(model_type, int(model_dt*1000), nt_glm_lag, MovModel) + pdf_suffix)
    with PdfPages(pdf_name) as pdf:
        for celln in tqdm(range(ncells)):

            fig, axs = plt.subplots(2, 5, figsize=(35, 10))

            # Rates in spikes/s for this cell.
            predcell = pred_all[:, celln]/model_dt
            nspcell = test_nsp[:, celln]/model_dt
            box = np.ones(bin_length)
            test_nsp_smooth = np.convolve(test_nsp[:, celln], box, 'same') / (bin_length * model_dt)
            pred_smooth = np.convolve(pred_all[:, celln], box, 'same') / (bin_length * model_dt)
            test_nsp_range = np.quantile(nspcell, [.01, 1])
            # Bin edges / bin centres along the predicted-rate axis.
            pred_rangelin = np.quantile(predcell, spike_percentiles)
            xbin_pts = np.quantile(predcell, spk_percentile2)

            axs[0, 0].plot(np.arange(len(test_nsp_smooth))*model_dt, test_nsp_smooth, 'k', label='test FR')
            axs[0, 0].plot(np.arange(len(pred_smooth))*model_dt, pred_smooth, 'r', label='pred FR')
            axs[0, 0].set_xlabel('Time (s)')
            axs[0, 0].set_ylabel('Firing Rate (spks/s)')
            axs[0, 0].legend()
            axs[0, 0].set_title('Smoothed FRs')

            # Eye and head tuning curves share the same layout; the head panel
            # additionally rounds its bin edges to one decimal.
            top_yaxs = np.max(ax_ylims[celln]) + 2*np.max(tuning_stds[celln])
            tuning_panels = [(axs[0, 1], (0, 1), 'Eye Tuning Curves', False),
                             (axs[0, 2], (2, 3), 'Head Tuning Curves', True)]
            for ax, model_inds, panel_title, round_edges in tuning_panels:
                for i, modeln in enumerate(model_inds):
                    metric = move_test[:, modeln]
                    nranges = np.quantile(metric, quartiles)
                    if round_edges:
                        nranges = np.round(nranges, decimals=1)
                    stat_range, edges, _ = binned_statistic(metric, test_nsp[:, celln], statistic='mean', bins=nranges)
                    edge_mids = np.quantile(metric, spk_percentile2)
                    # Shade each quartile; the two variables of a panel occupy
                    # the lower and upper half of the axes respectively.
                    for m in range(len(nranges)-1):
                        ax.axvspan(nranges[m], nranges[m+1], ymin=i/2, ymax=(i+1)/2, alpha=0.8, color=colors[m], zorder=0)
                    ax.plot(edge_mids, stat_range/model_dt, '.-', ms=20, lw=4, c=clrs[modeln])

                ax.set_ylim(bottom=0, top=top_yaxs)
                ax.set_xlim(-30, 30)
                ax.set_xlabel(r'Angle ($ ^{\degree}$)')
                ax.set_ylabel('Spikes/s')
                ax.set_title(panel_title)
                lines = ax.get_lines()
                # A second legend() call replaces the first, so the first one
                # is re-attached as a plain artist.
                legend1 = ax.legend([lines[0]], [titles[model_inds[0]]], bbox_to_anchor=(1.01, .2), fontsize=12)
                legend2 = ax.legend([lines[1]], [titles[model_inds[1]]], bbox_to_anchor=(1.01, .9), fontsize=12)
                ax.add_artist(legend1)

            # Predicted vs. actual rate for this cell (columns, not rows, of
            # pred_all / test_nsp), with the identity line for reference.
            axs[0, 3].scatter(predcell, nspcell, c='k', s=15)
            identity = np.linspace(test_nsp_range[0], test_nsp_range[1])
            axs[0, 3].plot(identity, identity, 'k--', zorder=0)
            axs[0, 3].set_xlabel('Predicted Spike Rate')
            axs[0, 3].set_ylabel('Actual Spike Rate')
            # No colorbar here: a scatter with a fixed colour has no mappable,
            # and the previously referenced `img` is not created in this cell.

            axs[0, 4].bar(np.arange(len(titles)), w_move[celln], color=clrs)
            axs[0, 4].set_xticks(np.arange(0, len(titles)))
            axs[0, 4].set_xticklabels(titles)
            axs[0, 4].set_ylabel('GLM Weight')

            # Mean actual rate per predicted-rate bin over all test data; it is
            # the same for every movement variable.
            stat_all, edges, _ = binned_statistic(predcell, nspcell, statistic='mean', bins=pred_rangelin)
            max_fr = np.max(stat_all)
            edge_mids = xbin_pts
            for modeln in range(len(titles)):
                metric = move_test[:, modeln]
                nranges = np.quantile(metric, quartiles)
                traces_mean[celln, modeln] = stat_all

                for n in range(len(nranges)-1):
                    # Time points whose movement value falls in quartile n.
                    ind = np.where((metric <= nranges[n+1]) & (metric > nranges[n]))[0]
                    pred = predcell[ind]
                    sp = nspcell[ind]

                    stat_range, edges, _ = binned_statistic(pred, sp, statistic='mean', bins=pred_rangelin)
                    traces[celln, modeln, n] = stat_range
                    edges_all[celln, modeln, n] = edge_mids
                    # Fit an additive and a multiplicative offset between the
                    # quartile curve and the all-data curve (both normalised).
                    res_add = minimize_scalar(f_add, args=(stat_range/max_fr, stat_all/max_fr))
                    res_mult = minimize_scalar(f_mult, args=(stat_range/max_fr, stat_all/max_fr))
                    mse_add[celln, modeln, n] = res_add.fun
                    mse_mult[celln, modeln, n] = res_mult.fun
                    alpha_add[celln, modeln, n] = res_add.x
                    alpha_mult[celln, modeln, n] = res_mult.x

                    axs[1, modeln].plot(edge_mids, stat_range, '.-', c=colors[n], label='{:.02f} : {:.02f}'.format(nranges[n], nranges[n+1]), lw=4, ms=20, alpha=.9)

                axs[1, modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
                axs[1, modeln].set_xlabel('Predicted Spike Rate')
                axs[1, modeln].set_ylabel('Actual Spike Rate')
                axs[1, modeln].plot([0, 1], [0, 1], 'k--', transform=axs[1, modeln].transAxes, lw=2, zorder=0)
                axs[1, modeln].plot(edge_mids, stat_all, '.-', c='k', lw=5, ms=20, label='All_data', alpha=.8)
                axs[1, modeln].legend(bbox_to_anchor=(1.01, 1), fontsize=12)
                axs[1, modeln].set_xlim(left=0)
                axs[1, modeln].set_ylim(bottom=0)

            # Positive values: the multiplicative fit has the lower error.
            dmodel = mse_add[celln] - mse_mult[celln]
            crange = np.max(np.abs(dmodel))
            im = axs[1, -1].imshow(dmodel, cmap='seismic', vmin=-crange, vmax=crange)
            axs[1, -1].set_yticks(np.arange(0, 4))
            axs[1, -1].set_yticklabels(titles)
            axs[1, -1].set_ylabel('Movement Model')
            axs[1, -1].set_xticks(np.arange(0, 4))
            axs[1, -1].set_xticklabels(['.25', '.5', '.75', '1'])
            axs[1, -1].set_xlabel('Quantile Range')
            axs[1, -1].set_title('$MSE_{add}$ - $MSE_{mult}$')
            cbar = add_colorbar(im)

            fig.suptitle('Celln:{}, r2={:.03f}'.format(celln, r2_all[celln]), y=1, fontsize=30)
            fig.tight_layout()

            pdf.savefig(fig)
            plt.close(fig)

    # fig.savefig(FigPath/'CellSummary_N{}.png'.format(celln), facecolor='white', transparent=True)

# Testing Regularization

In [605]:
class PoissonGLM_VM_staticreg(nn.Module):
    """Linear-softplus Poisson GLM with fixed (non-learned) L1/L2 penalties.

    The model computes ``softplus(inputs @ weight.T + bias)`` and exposes a
    per-output-unit loss made of the Poisson negative log-likelihood
    ``mean(Yhat - Y*log(Yhat))`` plus optional norm penalties on the weights.

    Args:
        in_features: number of input columns expected by ``forward``.
        out_features: number of output units (rows of ``weight``).
        bias: if true, learn one bias per output unit; if false, ``self.bias``
            is registered as ``None``.
        reg_lam: L2 penalty strength, or ``None`` to disable it.
        reg_alph: L1 penalty strength, or ``None`` to disable it. Required
            when ``move_features`` is given.
        move_features: if not ``None``, the last ``move_features`` weight
            columns are penalised separately (L1 only) from the leading ones.
        meanfr: optional tensor used as the initial bias.
        init_sta: optional tensor used as the initial weight matrix.
        device: device on which the penalty-strength vectors are created.
    """

    def __init__(self, in_features, out_features, bias=True, reg_lam=None, reg_alph=None, move_features=None, meanfr=None, init_sta=None, device='cuda'):
        super(PoissonGLM_VM_staticreg, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.move_features = move_features
        self.reg_lam = reg_lam
        self.reg_alph = reg_alph

        if self.move_features is not None:
            if reg_alph is None:
                raise ValueError('reg_alph must be set when move_features is given')
            # NOTE(review): lam_m is built from reg_alph, not reg_lam. It is not
            # used by loss(); kept as-is so existing attribute values are unchanged.
            self.lam_m = reg_alph*torch.ones(out_features).to(device)
            self.alpha_m = reg_alph*torch.ones(out_features).to(device)

        if init_sta is not None:
            self.weight = torch.nn.Parameter(init_sta, requires_grad=True)
            self.init_sta = True
        else:
            self.init_sta = False
            self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))

        # `bias` must never be stored as a plain attribute first: nn.Module
        # refuses register_parameter('bias', None) if the name already exists.
        if bias:
            if meanfr is not None:
                self.bias = torch.nn.Parameter(meanfr, requires_grad=True)
                self.meanfr = True
            else:
                self.meanfr = None
                self.bias = torch.nn.Parameter(torch.Tensor(out_features))
        else:
            self.meanfr = None
            self.register_parameter('bias', None)

        # One penalty strength per output unit so the loss stays a vector.
        if self.reg_lam is not None:
            self.lam = reg_lam*torch.ones(out_features).to(device)
        if self.reg_alph is not None:
            self.alpha = reg_alph*torch.ones(out_features).to(device)

        self.lossfn = torch.nn.PoissonNLLLoss(log_input=True, reduction='mean')
        self.reset_parameters()

    def reset_parameters(self):
        """Randomly initialise whichever of weight/bias was not supplied."""
        if not self.init_sta:
            torch.nn.init.kaiming_uniform_(self.weight)
        if self.bias is not None and self.meanfr is None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / np.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, inputs):
        """Return the predicted rate for a 2-D ``(samples, in_features)`` input.

        If the number of columns does not match ``in_features`` a message is
        printed and ``0`` is returned.
        """
        x, y = inputs.shape
        if y != self.in_features:
            print(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
            return 0
        output = inputs.matmul(self.weight.t())
        if self.bias is not None:
            output = output + self.bias
        # softplus == log(1 + exp(x)); the library version does not overflow
        # to inf for large activations.
        return torch.nn.functional.softplus(output)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    def loss(self, Yhat, Y):
        """Return the penalised Poisson loss, one value per output unit.

        Args:
            Yhat: predicted rates, as returned by ``forward``.
            Y: observed counts with the same shape as ``Yhat``.
        """
        loss_vec = torch.mean(Yhat-Y*torch.log(Yhat), axis=0)
        if self.move_features is not None:
            # Separate L1 penalties for the leading weight columns and for the
            # trailing `move_features` columns; no L2 term in this mode.
            l1_reg = self.alpha*(torch.linalg.norm(self.weight[:, :-self.move_features], dim=1, ord=1))
            l1_regm = self.alpha_m*(torch.linalg.norm(self.weight[:, -self.move_features:], dim=1, ord=1))
            return loss_vec + l1_reg + l1_regm
        if self.reg_lam is not None:
            loss_vec = loss_vec + self.lam*(torch.linalg.norm(self.weight, dim=1, ord=2))
        if self.reg_alph is not None:
            loss_vec = loss_vec + self.alpha*(torch.linalg.norm(self.weight, dim=1, ord=1))
        return loss_vec


In [596]:
# Build the design matrices and tensors for the PyTorch GLM.
# Temporal lags (in frames) of the video that are stacked side by side.
lag_list = np.array([-2, -1, 0, 1, 2])
nt_glm_lag = len(lag_list)
print(lag_list,1000*lag_list*model_dt)
do_shuffle = False
model_type = 'Pytorch'

# Load Data
data, train_idx, test_idx = load_train_test(file_dict, save_dir, model_dt=model_dt, do_shuffle=do_shuffle, do_norm=True, free_move=free_move, has_imu=free_move, has_mouse=False)
# At notebook top level locals() is the global namespace, so every key of
# `data` becomes a variable (train_*, test_*, model_* arrays).
locals().update(data)

# Initialize movement combinations
titles = np.array(['Theta', 'Phi', 'Roll', 'Pitch'])
titles_all = []
for n in range(1, len(titles)+1):
    for combo in itertools.combinations(np.arange(len(titles)), n):
        titles_all.append('_'.join(titles[list(combo)]))

# One column per movement variable, ordered like `titles`.
move_train = np.column_stack((train_th, train_phi, train_roll, train_pitch))
move_test = np.column_stack((test_th, test_phi, test_roll, test_pitch))
model_move = np.column_stack((model_th, model_phi, model_roll, model_pitch))
model_move = model_move - np.mean(model_move, axis=0)
move_test = move_test - np.mean(move_test, axis=0)

# Number of visual weights per cell: pixels x lags.
nks = np.shape(train_vid)[1:]; nk = nks[0]*nks[1]*nt_glm_lag
# Use the single combination that contains all n=4 movement variables.
n = 4; ind = 0
perms = np.array(list(itertools.combinations(np.arange(len(titles)), n)))

# Reshape data (video) into (T*n)xN array. Built once; np.roll wraps around,
# so the first/last |lag| frames contain frames from the other end.
rolled_vid = np.hstack([np.roll(model_vid_sm, nframes, axis=0) for nframes in lag_list])
x_train = rolled_vid[train_idx].reshape(len(train_idx), -1)
x_test = rolled_vid[test_idx].reshape(len(test_idx), -1)


def _as_device_tensor(arr):
    """Convert a numpy array to a float32 tensor on `device`."""
    return torch.from_numpy(arr.astype(np.float32)).to(device)


# MovModel: 0 = movement only, 1 = visual only, 2 = visual + movement,
# 3 = visual + (visual x movement) + movement.
MovModel = 2
if MovModel == 0:
    mx_train = move_train[:, perms[ind]]
    mx_test = move_test[:, perms[ind]]
    xtr = _as_device_tensor(mx_train)
    xte = _as_device_tensor(mx_test)
    move_features = mx_train.shape[-1]
    nk = 0
elif MovModel == 1:
    x_train_m1 = x_train.astype(np.float32)
    x_test_m1 = x_test.astype(np.float32)
    xtr = torch.from_numpy(x_train_m1).to(device)
    xte = torch.from_numpy(x_test_m1).to(device)
    move_features = None
elif MovModel == 2:
    x_train_m2 = np.concatenate((x_train, move_train[:, perms[ind]]), axis=1)
    x_test_m2 = np.concatenate((x_test, move_test[:, perms[ind]]), axis=1)
    xtr = _as_device_tensor(x_train_m2)
    xte = _as_device_tensor(x_test_m2)
    move_features = x_train_m2.shape[-1]-nk
elif MovModel == 3:
    x_train_m3 = np.hstack((x_train, np.hstack([x_train*move_train[:, modeln][:, np.newaxis] for modeln in np.arange(len(titles))]), move_train[:, perms[ind]]))
    x_test_m3 = np.hstack((x_test, np.hstack([x_test*move_test[:, modeln][:, np.newaxis] for modeln in np.arange(len(titles))]), move_test[:, perms[ind]]))
    xtr = _as_device_tensor(x_train_m3)
    xte = _as_device_tensor(x_test_m3)
    # Everything after the first nk columns counts as a movement feature.
    move_features = x_train_m3.shape[-1]-nk


ytr = _as_device_tensor(train_nsp)
yte = _as_device_tensor(test_nsp)
print('move_features: {}'.format(move_features))

[-2 -1  0  1  2] [-100.  -50.    0.   50.  100.]
Done Loading Aligned Data
TRAIN: 15628 TEST: 6698
move_features: 4


In [None]:
def add_weight_decay(net, l2_value, skip_list=()):
    """Split the trainable parameters of `net` into two optimizer groups.

    Parameters that are 1-D, whose name ends in ".bias", or whose name is
    listed in `skip_list` get no weight decay; every other trainable
    parameter gets `l2_value`. Parameters with requires_grad=False are
    left out of both groups.

    Returns a two-element list of param-group dicts: first the group
    without decay, then the group with decay.
    """
    exempt = []
    regularized = []
    for name, param in net.named_parameters():
        if param.requires_grad:
            skip_decay = (
                len(param.shape) == 1
                or name.endswith(".bias")
                or name in skip_list
            )
            target = exempt if skip_decay else regularized
            target.append(param)
    return [
        {'params': exempt, 'weight_decay': 0.},
        {'params': regularized, 'weight_decay': l2_value},
    ]


In [608]:
# Shape check only: stack (rolled_vid_flat.T @ model_nsp), normalised by each
# column sum of model_nsp and transposed, next to an (output_size, move_features)
# block of zeros, and display the shape of the result.
np.hstack((((rolled_vid_flat.T@model_nsp)/(np.sum(model_nsp,axis=0))).T.astype(np.float32),np.zeros((output_size,move_features)))).shape

(128, 3004)

In [None]:
# Grid search over L1 strength (alphas, applied inside the model loss) and
# L2 strength (lambdas, applied as optimizer weight decay on the weights).
# Each grid point trains a fresh model for Nbatches full-batch steps.
input_size = xtr.shape[1]
output_size = ytr.shape[1]
lossfn = torch.nn.PoissonNLLLoss(log_input=True,reduction='mean')

Nbatches = 5000
if move_features is not None:
    reg_params = np.zeros((Nbatches, output_size, 4))
    reg_titles = ['lambda', 'lambda_m', 'alpha', 'alpha_m']
else:
    reg_params = np.zeros((Nbatches, output_size, 2))
    reg_titles = ['lambda', 'alpha']

lambdas = (2**(np.arange(0, 10)))
alphas = np.array([.005, .01, .02])
nlam = len(lambdas)
nalph = len(alphas)

# Initial weights: spike-weighted average of the lagged video per cell, padded
# with zeros for the movement columns when the model has any.
sta_vis = ((rolled_vid_flat.T@model_nsp)/(np.sum(model_nsp, axis=0))).T
if MovModel != 1:
    sta_init = torch.from_numpy(np.hstack((sta_vis, np.zeros((output_size, move_features)))).astype(np.float32))
else:
    sta_init = torch.from_numpy(sta_vis.astype(np.float32))
meanfr = torch.log1p(torch.mean(torch.tensor(model_nsp, dtype=torch.float32), axis=0))

n_move = 0 if move_features is None else move_features
msetrain = np.zeros((nalph, nlam, output_size))
msetest = np.zeros((nalph, nlam, output_size))
# Held-out predictions for every grid point; read by the model-selection cell.
pred_cv = np.zeros((x_test.shape[0], nalph, nlam, output_size))
w_cv = np.zeros((x_train.shape[-1], nalph, nlam, output_size))
w_intercept = np.zeros((nalph, nlam, output_size))
w_move_cv = np.zeros((nalph, nlam, output_size, n_move))
tloss_trace_all = np.zeros((nalph, nlam, Nbatches, output_size))
vloss_trace_all = np.zeros((nalph, nlam, Nbatches, output_size))
lr_w = [1e-6, 1e-4]   # [base, max] learning rate for the weights
lr_b = [1e-5, 5e-3]   # [base, max] learning rate for the bias
start = time.time()
for a, reg_alph in enumerate(tqdm(alphas)):
    for l, reg_lam in enumerate(tqdm(lambdas)):
        # Clone the initial tensors: nn.Parameter shares storage with the
        # tensor it wraps, so without a copy in-place optimizer updates on CPU
        # would carry over into the next grid point's starting values.
        l1 = PoissonGLM_VM_staticreg(input_size, output_size, reg_lam=None, reg_alph=float(reg_alph), move_features=move_features, meanfr=meanfr.clone(), init_sta=sta_init.clone(), device=device).to(device)
        optimizer = optim.ASGD(params=[{'params': [l1.weight], 'lr': lr_w[1], 'weight_decay': float(reg_lam)}, {'params': [l1.bias], 'lr': lr_b[1]}], lr=5e-5)
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=[lr_w[0], lr_b[0]], max_lr=[lr_w[1], lr_b[1]], cycle_momentum=False)
        early_stopping = EarlyStopping(patience=1000, min_delta=.005)

        vloss_trace = np.zeros((Nbatches, output_size))
        tloss_trace = np.zeros((Nbatches, output_size))
        for batchn in range(Nbatches):
            out = l1(xtr)
            loss = l1.loss(out, ytr)
            # Validation loss is only monitored, so no graph is needed.
            with torch.no_grad():
                val_loss = l1.loss(l1(xte), yte)
            vloss_trace[batchn] = val_loss.cpu().numpy()
            tloss_trace[batchn] = loss.detach().cpu().numpy()
            optimizer.zero_grad()
            # `loss` has one entry per cell; a ones gradient sums them.
            loss.backward(torch.ones_like(loss))
            optimizer.step()
            scheduler.step()
            # Early stopping is tracked but intentionally not acted upon.
            early_stopping(np.mean(vloss_trace[batchn]))
#             if early_stopping.early_stop:
#                 break
        tloss_trace_all[a, l] = tloss_trace
        vloss_trace_all[a, l] = vloss_trace
        weights = l1.weight.detach().cpu().numpy()
        w_cv[:, a, l] = weights[:, :nk].T
        w_intercept[a, l] = l1.bias.detach().cpu().numpy()
        if MovModel != 1:
            w_move_cv[a, l] = weights[:, nk:]
        with torch.no_grad():
            pred = l1(xte)
            # Unpenalised Poisson loss on the held-out data, per cell.
            msetest[a, l] = torch.mean(pred-yte*torch.log(pred), axis=0).cpu().numpy()
        pred_cv[:, a, l] = pred.cpu().numpy().squeeze()

print('GLM: ', time.time()-start)


if MovModel != 0:
    w_cv2 = w_cv.T.reshape((output_size, nlam, nalph, nt_glm_lag,)+nks)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# For every cell pick the (alpha, lambda) grid point with the lowest held-out
# loss. nanargmin returns exactly one index per cell (the first one on ties),
# so the arrays below always have one entry per cell; it raises ValueError if
# a cell's losses are all NaN.
n_out = msetest.shape[-1]
best_flat = np.nanargmin(msetest.reshape(-1, n_out), axis=0)
malph, mlam = np.unravel_index(best_flat, msetest.shape[:2])
cellnum = np.arange(n_out)
sta_all = w_cv[:, malph, mlam, cellnum].T.reshape((output_size, nt_glm_lag,)+nks)
pred_all = pred_cv[:, malph, mlam, cellnum]

# m_cells, m_cinds = np.unique(m_cells,return_index=True)
# m_models = m_models[m_cinds]
# mcc = cc_all[m_cells,m_models]
# msp = sp_raw[m_cells,m_models]
# mpred = pred_raw[m_cells,m_models]
# mw_move = w_move_all[m_cells,m_models]
# mr2 = r2_all[m_cells,m_models]

In [None]:
# Score the selected models and save the results.
# r2 is the squared correlation between box-car smoothed actual and predicted
# rates, with bin_length samples trimmed at both ends to drop the edge bins of
# the 'same' convolution.
bin_length = 40
box = np.ones(bin_length)
r2_all = np.zeros((output_size))
for celln in range(output_size):
    sp_smooth = (np.convolve(test_nsp[:, celln], box, 'same') / (bin_length * model_dt))[bin_length:-bin_length]
    pred_smooth = (np.convolve(pred_all[:, celln], box, 'same') / (bin_length * model_dt))[bin_length:-bin_length]
    r2_all[celln] = (np.corrcoef(sp_smooth, pred_smooth)[0, 1])**2

GLM_Data = {'r2_all': r2_all,
            'test_nsp': test_nsp,
            'pred_all': pred_all}
if MovModel != 0:
    # Visual weights exist for every model except movement-only.
    GLM_Data['sta_all'] = sta_all
if MovModel != 1:
    # Movement weights of the selected (alpha, lambda) model of each cell.
    # The CV loop only fills w_move_cv, so w_move has to be derived here
    # rather than picked up from whatever an earlier cell left behind.
    w_move = w_move_cv[malph, mlam, cellnum]
    GLM_Data['w_move'] = w_move

shuffle_tag = '_shuffled' if do_shuffle else ''
save_datafile = save_dir/'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}{}.h5'.format(model_type, int(model_dt*1000), nt_glm_lag, MovModel, shuffle_tag)
ioh5.save(save_datafile, GLM_Data)
print(save_datafile)

In [None]:
# Reverse the axis order of w_cv, then split the resulting last axis into
# (nt_glm_lag,) + nks.
w_cv2 = np.reshape(
    np.transpose(w_cv),
    (output_size, nlam, nalph, nt_glm_lag) + nks,
)

In [None]:
# For a few example cells plot the selected spatiotemporal filter (one image
# per lag), the smoothed actual vs. predicted rate, and the loss traces of the
# grid point that was selected for that cell.
bin_length = 40
box = np.ones(bin_length)
for celln in tqdm([21, 51, 25, 117, 126]):
    # Grid point with the lowest held-out loss for this cell. It is looked up
    # per cell, inside the loop, so the traces match sta_all / pred_all.
    alph, lam = np.unravel_index(np.nanargmin(msetest[:, :, celln]), msetest.shape[:2])
    print(celln, alph, lam)

    fig2 = plt.figure(constrained_layout=False, figsize=(20, 7))
    spec2 = gridspec.GridSpec(ncols=nt_glm_lag, nrows=2, figure=fig2)
    axs = np.array([fig2.add_subplot(spec2[0, lagn]) for lagn in range(nt_glm_lag)])
    f2_ax6 = fig2.add_subplot(spec2[1, :nt_glm_lag//2])
    f2_ax7 = fig2.add_subplot(spec2[1, nt_glm_lag//2:])
    # Shared symmetric colour range across all lags of this cell.
    crange = np.max(np.abs(sta_all[celln]))
    for lagn, ax in enumerate(axs):
        im = ax.imshow(sta_all[celln, lagn], 'RdBu_r', vmin=-crange, vmax=crange)
        cbar = add_colorbar(im)
        ax.axis('off')

    sp_smooth = (np.convolve(test_nsp[:, celln], box, 'same') / (bin_length * model_dt))[bin_length:-bin_length]
    f2_ax6.plot(np.arange(len(sp_smooth))*model_dt, sp_smooth, 'k', lw=2)
    pred_smooth = (np.convolve(pred_all[:, celln], box, 'same') / (bin_length * model_dt))[bin_length:-bin_length]
    f2_ax6.plot(np.arange(len(pred_smooth))*model_dt, pred_smooth, 'r', lw=2)
    f2_ax6.set_xlabel('Time (s)')
    f2_ax6.set_ylabel('Spike Rate')
    f2_ax7.plot(tloss_trace_all[alph, lam, :, celln], label='train')
    f2_ax7.plot(vloss_trace_all[alph, lam, :, celln], label='test')
    f2_ax7.set_xlabel('Batch #')
    f2_ax7.set_ylabel('Loss')
    f2_ax7.legend()
    r2 = (np.corrcoef(sp_smooth, pred_smooth)[0, 1])**2
    fig2.suptitle('celln: {} $r^2$:{:.03f}'.format(celln, r2))
    fig2.tight_layout()

In [None]:
# Fit summary figures for a few example cells, using the cross-validated
# weights. Top row: one image per lag from w_cv2[celln, lam, alph];
# bottom-left: smoothed observed (black) vs. predicted (red) trace;
# bottom-right: the tloss_trace_all / vloss_trace_all traces for the same
# (alpha, lambda) pair.
bin_length = 40  # boxcar width, in samples, used for smoothing
alph = 0
example_cells = [21, 51, 25, 117, 126]

# Choose the lambda with the lowest test error separately for every cell, on
# the same alpha row that is plotted. Selecting it once before the loop would
# depend on whatever `celln` an earlier notebook cell left behind.
# NOTE(review): assumes msetest is indexed (alpha, lambda, cell) -- confirm.
best_lam = np.argmin(msetest[alph], axis=0)
print(best_lam[example_cells])

smooth_kernel = np.ones(bin_length)
rate_scale = bin_length * model_dt
for celln in tqdm(example_cells):
    lam = best_lam[celln]
    fig2 = plt.figure(constrained_layout=False, figsize=(20,7))
    spec2 = gridspec.GridSpec(ncols=nt_glm_lag, nrows=2, figure=fig2)
    axs = np.array([fig2.add_subplot(spec2[0, lag]) for lag in range(nt_glm_lag)])
    f2_ax6 = fig2.add_subplot(spec2[1, :nt_glm_lag//2])
    f2_ax7 = fig2.add_subplot(spec2[1, nt_glm_lag//2:])

    # Shared symmetric colour range so the lag panels are directly comparable.
    crange = np.max(np.abs(w_cv2[celln,lam,alph]))
    for lag, ax in enumerate(axs):
        im = ax.imshow(w_cv2[celln,lam,alph,lag],'RdBu_r',vmin=-crange,vmax=crange)
        add_colorbar(im)
        ax.axis('off')

    # Trim bin_length samples from each end, where the 'same' convolution
    # only partially overlaps the data.
    sp_smooth = (np.convolve(test_nsp[:,celln], smooth_kernel, 'same') / rate_scale)[bin_length:-bin_length]
    pred_smooth = (np.convolve(pred_all.T[celln,lam,alph], smooth_kernel, 'same') / rate_scale)[bin_length:-bin_length]
    f2_ax6.plot(np.arange(len(sp_smooth))*model_dt,sp_smooth, 'k', lw=2)
    f2_ax6.plot(np.arange(len(pred_smooth))*model_dt,pred_smooth,'r', lw=2)
    f2_ax6.set_xlabel('Time (s)')
    f2_ax6.set_ylabel('Spike Rate')

    f2_ax7.plot(tloss_trace_all[alph,lam,:,celln])
    f2_ax7.plot(vloss_trace_all[alph,lam,:,celln])
    f2_ax7.set_xlabel('Batch #')
    f2_ax7.set_ylabel('Loss')

    r2 = (np.corrcoef(sp_smooth,pred_smooth)[0,1])**2
    fig2.suptitle('celln: {} $r^2$:{:.03f}'.format(celln, r2))
    fig2.tight_layout()

In [None]:
# Write a fit-summary PDF with one page per cell.
# Each page: one image per lag from w_cv2[celln, lam, alph] on top, the
# smoothed observed (black) vs. predicted (red) trace bottom-left, and the
# tloss_trace_all / vloss_trace_all traces bottom-right.
bin_length = 40  # boxcar width, in samples, used for smoothing
alph = 0

# Choose the lambda with the lowest test error separately for every cell.
# Selecting it once before the loop would apply the lambda of whatever
# `celln` an earlier notebook cell left behind to every page of the PDF.
# NOTE(review): assumes msetest is indexed (alpha, lambda, cell) -- confirm.
best_lam = np.argmin(msetest[alph], axis=0)
print(best_lam)

pdf_name = FigPath/ 'VisMov_{}_dt{:03d}_Lags{:02d}_MovModel{:d}_FitSummary.pdf'.format(model_type,int(model_dt*1000),nt_glm_lag, MovModel)
smooth_kernel = np.ones(bin_length)
rate_scale = bin_length * model_dt
with PdfPages(pdf_name) as pdf:
    for celln in tqdm(range(output_size)):
        lam = best_lam[celln]
        fig2 = plt.figure(constrained_layout=False, figsize=(20,7))
        spec2 = gridspec.GridSpec(ncols=nt_glm_lag, nrows=2, figure=fig2)
        axs = np.array([fig2.add_subplot(spec2[0, lag]) for lag in range(nt_glm_lag)])
        f2_ax6 = fig2.add_subplot(spec2[1, :nt_glm_lag//2])
        f2_ax7 = fig2.add_subplot(spec2[1, nt_glm_lag//2:])

        # Shared symmetric colour range so the lag panels are comparable.
        crange = np.max(np.abs(w_cv2[celln,lam,alph]))
        for lag, ax in enumerate(axs):
            im = ax.imshow(w_cv2[celln,lam,alph,lag],'RdBu_r',vmin=-crange,vmax=crange)
            add_colorbar(im)
            ax.axis('off')

        # Trim bin_length samples from each end, where the 'same'
        # convolution only partially overlaps the data.
        sp_smooth = (np.convolve(test_nsp[:,celln], smooth_kernel, 'same') / rate_scale)[bin_length:-bin_length]
        pred_smooth = (np.convolve(pred_all.T[celln,lam,alph], smooth_kernel, 'same') / rate_scale)[bin_length:-bin_length]
        f2_ax6.plot(np.arange(len(sp_smooth))*model_dt,sp_smooth, 'k', lw=2)
        f2_ax6.plot(np.arange(len(pred_smooth))*model_dt,pred_smooth,'r', lw=2)
        f2_ax6.set_xlabel('Time (s)')
        f2_ax6.set_ylabel('Spike Rate')

        f2_ax7.plot(tloss_trace_all[alph,lam,:,celln])
        f2_ax7.plot(vloss_trace_all[alph,lam,:,celln])
        f2_ax7.set_xlabel('Batch #')
        f2_ax7.set_ylabel('Loss')

        r2 = (np.corrcoef(sp_smooth,pred_smooth)[0,1])**2
        fig2.suptitle('celln: {} $r^2$:{:.03f}'.format(celln, r2))
        fig2.tight_layout()

        # Save and close this specific figure rather than the implicit
        # "current" one, so memory is released page by page.
        pdf.savefig(fig2)
        plt.close(fig2)

In [None]:
# Overlay the `celln` column of test_nsp with `pred2` on the current axes.
# NOTE(review): `pred2` and `celln` come from other notebook cells --
# presumably pred2 is a prediction for this same cell; confirm before reuse.
plt.plot(test_nsp[:,celln])
plt.plot(pred2)

In [None]:
# Mean of pred2 - Y*log(pred2): the Poisson negative log-likelihood form with
# the Y-only term dropped.
# NOTE(review): assumes pred2 holds strictly positive predicted rates and Y
# the matching observed counts -- both are defined in other cells; confirm.
np.mean(pred2-Y*np.log(pred2))

In [None]:
# Mean of pred2 - Y*log(pred2): the Poisson negative log-likelihood form with
# the Y-only term dropped. This cell repeats the expression of the cell
# directly before it.
# NOTE(review): assumes pred2 holds strictly positive predicted rates and Y
# the matching observed counts -- both are defined in other cells; confirm.
np.mean(pred2-Y*np.log(pred2))

In [None]:
# Display the test-error values along the second axis of msetest for first-axis
# index 0 and the current `celln`.
# NOTE(review): presumably axes are (alpha, lambda, cell) -- confirm.
msetest[0,:,celln]

In [None]:
# Plot the test error along the second axis of msetest (first-axis index 2),
# one line per example cell.
# The indexing is done in two steps on purpose: msetest[2, :, [..]] mixes a
# scalar and a list index separated by a slice, so NumPy moves the list axis
# to the front and returns shape (5, n_second_axis); plt.plot would then draw
# one line per second-axis entry over only 5 points instead of one per cell.
# NOTE(review): presumably axes are (alpha, lambda, cell) -- confirm.
plt.plot(msetest[2][:, [21,51,25,117,126]])

In [None]:
# Regularization sweep for a linear model trained with a Poisson NLL loss.
# For every (alpha, lambda) pair the model is reset, trained for up to
# Nbatches full-batch SGD steps on (xtr, ytr), and its weights, bias, loss
# trace and prediction on xte are stored.
# NOTE(review): xtr, ytr, xte, x_train, x_test, train_nsp, test_nsp, celln,
# myLinear and early_stopping are all defined in other notebook cells.
input_size = xtr.shape[1]
output_size = ytr.shape[1]
Nbatches=5000
# log_input=True: the model output is treated as the log of the rate.
lossfn = torch.nn.PoissonNLLLoss(log_input=True,reduction='sum')

# NOTE(review): sps_train / sps_test are not used in the visible part of this cell.
sps_train = train_nsp[:,celln]
sps_test = test_nsp[:,celln]
# Regularization strengths: 1024 * 2**k for k = 0..15.
lambdas = 1024 * (2**np.arange(0,16))
# lambdas = 2**np.arange(0,18)# np.concatenate((np.arange(.1,1,.1),
alphas = np.array([1]) #np.arange(.01,.5,.05)
nlam = len(lambdas)
nalph = len(alphas)
# Initialize mse traces for regularization cross validation
# NOTE(review): msetrain / msetest are allocated here but not filled in the
# visible part of this cell.
msetrain = np.zeros((nalph,nlam,1))
msetest = np.zeros((nalph,nlam,1))
pred_all =np.zeros((x_test.shape[0],nalph,nlam)) 
w_cv = np.zeros((x_train.shape[-1],nalph,nlam))
w_intercept = np.zeros((nalph,nlam,1))
# One model instance is reused for the whole sweep and re-initialized per fit.
l1 = myLinear(input_size,output_size).to(device)
loss_trace_all = np.zeros((nalph,nlam,Nbatches))
# loop over regularization strength
for a,alpha in enumerate(tqdm(alphas)):
    for l in (range(len(lambdas))):
        # l1 = nn.Linear(input_size,output_size).to(device)
        # b.requires_grad=True
        # optimizer = optim.AdamW(lr=.01,params=l1.parameters(),weight_decay=2) #[W,b]
        l1.reset_parameters()
#         early_stopping = EarlyStopping()
        # weight_decay is 0, so the only regularization is the explicit
        # penalty added to the loss below.
        optimizer = optim.SGD(lr=.0001,params=l1.parameters(), weight_decay=0)#lambdas[l]) #
        # lossfn = nn.MSELoss(reduction='sum')

        # Entries after an early stop keep their initial value of 0.
        loss_trace = np.zeros(Nbatches)
        for batchn in (np.arange(Nbatches)):
            # Full-batch step: the whole training set goes through the model.
            out = l1(xtr)
    #         loss = torch.mean(2*(ytr*torch.log(ytr/out) + out + ytr)) + alpha*torch.sum(torch.abs(l1.weight)) + (1-alpha)*lambdas[l]*torch.sum(torch.norm(l1.weight))
    #         print(loss.item())
            # Poisson NLL + (1-alpha) * L1 norm of the weights
            # + alpha * lambda * (unsquared) norm of the weights.
            # NOTE(review): lambda scales only the second penalty term, and
            # with alphas == [1] the L1 term is multiplied by 0.
            loss = lossfn(out,ytr) + (1-alpha)*torch.sum(torch.abs(l1.weight)) + alpha*lambdas[l]*torch.norm(l1.weight)
            loss_trace[batchn] = loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # NOTE(review): the per-fit `EarlyStopping()` construction above is
            # commented out, so this uses an `early_stopping` object created
            # elsewhere; its state presumably carries over between lambdas
            # unless it is reset somewhere -- confirm.
            early_stopping(loss.item())
            if early_stopping.early_stop:
                break
        # Store this fit's results at position (alpha index, lambda index).
        loss_trace_all[a,l] = loss_trace
        w_cv[:,a,l] = l1.weight.clone().cpu().detach().numpy()
        w_intercept[a,l] = l1.bias.clone().cpu().detach().numpy()
        pred_all[:,a,l] = l1(xte).clone().cpu().detach().numpy().squeeze()