# Import Modules

In [None]:
%load_ext autoreload
%autoreload 2

import os
import argparse
import glob
import sys 
import yaml 
import glob
import h5py 
import ray
import logging 
import json
import gc
import cv2
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
# import io_dict_to_hdf5 as ioh5
import xarray as xr

from tqdm.notebook import tqdm, trange
from matplotlib.backends.backend_pdf import PdfPages
from scipy import interpolate 
from scipy import signal
from pathlib import Path
from scipy.interpolate import interp1d
from scipy.ndimage import shift as imshift
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sys.path.append(str(Path('.').absolute().parent))
from utils import *
import io_dict_to_hdf5 as ioh5
from format_data import load_ephys_data_aligned

# Show every row when a DataFrame is displayed.
pd.set_option('display.max_rows', None)
# Output directory for figures, resolved through the project's check_path
# helper (presumably it creates the directory when missing -- confirm in utils).
FigPath = check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(),'Figures/Decoding')

# Start Ray; ignore_reinit_error lets this cell be re-run in the same kernel
# without raising, and only ERROR-level Ray log messages are shown.
ray.init(
    ignore_reinit_error=True,
    logging_level=logging.ERROR,
)
# Print the dashboard address as reported by Ray, then the same port on
# localhost.
print(f'Dashboard URL: http://{ray.get_dashboard_url()}')
print('Dashboard URL: http://localhost:{}'.format(ray.get_dashboard_url().split(':')[-1]))

# Gather Data

In [None]:
# Directory holding the data of the session analysed in this notebook.
save_dir = Path('~/Research/SensoryMotorPred_Data/data/070921/J553RT/fm1').expanduser()
# Read the file-location dictionary stored next to the session data.
with open(save_dir / 'file_dict.json','r') as fp:
    file_dict = json.load(fp)

In [None]:
# Hard-coded file dictionary for session 070921_J553RT_control_Rig2_fm1.
# Running this cell replaces any `file_dict` previously read from
# file_dict.json. All paths are absolute and machine-specific.
file_dict = {'cell': 0,
 'drop_slow_frames': True,
 'ephys': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_ephys_merge.json',
 'ephys_bin': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_Ephys.bin',
 'eye': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_REYE.nc',
 'imu': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_imu.nc',
 'mapping_json': '/home/seuss/Research/Github/FreelyMovingEphys/probes/channel_maps.json',
 'mp4': True,
 'name': '070921_J553RT_control_Rig2_fm1',
 'probe_name': 'DB_P128-6',
 'save': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1',
 'speed': None,
 'stim_type': 'light',
 'top': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_TOP1.nc',
 'world': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_world.nc'}

In [None]:
# Bin width handed to the loader (presumably seconds -- confirm in format_data).
model_dt = .1
data = load_ephys_data_aligned(file_dict, save_dir, model_dt=model_dt)

# First-axis indices of every NaN entry, one index array per loaded variable.
nan_idxs = [np.where(np.isnan(values))[0] for values in data.values()]

# A sample is kept only if no variable holds a NaN at that index and the
# activity trace is not below .5.
good_idxs = np.ones(len(data['model_active']), dtype=bool)
nan_rows = np.unique(np.concatenate(nan_idxs))
good_idxs[nan_rows] = False
good_idxs[data['model_active'] < .5] = False

# Untouched copy of the spike counts, taken before any sample is dropped.
raw_nsp = data['model_nsp'].copy()

In [None]:
# Keep only the selected samples of every variable. 'model_active' itself is
# left at full length, exactly as before.
for key, values in data.items():
    if key == 'model_active':
        continue
    data[key] = values[good_idxs]
# Expose every entry of `data` as a variable of the same name.
locals().update(data)

In [None]:
# Distribution of the activity trace; the vertical line marks the .5
# threshold below which samples are excluded.
plt.hist(data['model_active'], bins=100)
plt.axvline(x=.5)
# movement_times = (data['model_active']>.5) & (~np.isnan(data['model_th']))

In [None]:
##### Print memory of local variables #####
# Lists the ten largest names in the current namespace, largest first.
# sys.getsizeof is shallow: containers report their own size, not the size of
# the objects they reference. sizeof_fmt comes from utils (presumably it
# renders a byte count as a human-readable string -- confirm).
for name, size in sorted(((name, sys.getsizeof(value)) for name, value in locals().items()), key= lambda x: -x[1])[:10]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))

In [None]:
# Sample-to-sample change of th and phi, same length as the inputs.
# The last sample has no successor, so the series is padded by repeating its
# final value, which makes the last difference 0. Padding with a literal 0
# would instead store -model_th[-1] there, which is not a difference between
# neighbouring samples.
model_dth = np.diff(model_th, append=model_th[-1:])
model_dphi = np.diff(model_phi, append=model_phi[-1:])

In [None]:
##### Group shuffle #####
# Hold out whole contiguous chunks of the recording instead of individual
# samples: the sample axis is cut into int(1/frac) consecutive groups and
# GroupShuffleSplit assigns entire groups to either train or test.
from sklearn.model_selection import GroupShuffleSplit
gss = GroupShuffleSplit(n_splits=1, train_size=.7, random_state=42)
nT = model_nsp.shape[0]
frac = .2
# Group i covers samples [int(frac*(i-1)*nT), int(frac*i*nT)). Both edges are
# derived from `frac` so the group sizes stay consistent when it is changed.
groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((frac*(i-1))*nT)) for i in range(1,int(1/frac)+1)])
# n_splits=1, so this loop runs once and leaves train_idx / test_idx set.
for train_idx, test_idx in gss.split(np.arange(len(model_nsp)), groups=groups):
    print("TRAIN:", train_idx, "TEST:", test_idx)

# z-score the video (per pixel) and the movement variables, using statistics
# of the whole recording (train and test together).
model_vid_sm = (model_vid_sm - np.mean(model_vid_sm,axis=0))/np.std(model_vid_sm,axis=0)
model_th = (model_th - np.mean(model_th,axis=0))/np.std(model_th,axis=0)
model_phi = (model_phi - np.mean(model_phi,axis=0))/np.std(model_phi,axis=0)
model_roll = (model_roll - np.mean(model_roll,axis=0))/np.std(model_roll,axis=0)
model_pitch = (model_pitch - np.mean(model_pitch,axis=0))/np.std(model_pitch,axis=0)


# Apply the same split to every variable.
train_vid = model_vid_sm[train_idx]
test_vid = model_vid_sm[test_idx]
train_nsp = model_nsp[train_idx]
test_nsp = model_nsp[test_idx]
train_th = model_th[train_idx]
test_th = model_th[test_idx]
train_phi = model_phi[train_idx]
test_phi = model_phi[test_idx]
train_roll = model_roll[train_idx]
test_roll = model_roll[test_idx]
train_pitch = model_pitch[train_idx]
test_pitch = model_pitch[test_idx]
train_t = model_t[train_idx]
test_t = model_t[test_idx]
train_dth = model_dth[train_idx]
test_dth = model_dth[test_idx]
train_dphi = model_dphi[train_idx]
test_dphi = model_dphi[test_idx]
train_gz = model_gz[train_idx]
test_gz = model_gz[test_idx]

In [None]:
# Traces of three example video pixels over the first 1000 training samples.
plt.plot(train_vid[:1000,5,10])
plt.plot(train_vid[:1000,11,10])
plt.plot(train_vid[:1000,10,11])

In [None]:
# Relationship between phi and pitch on the training split; the commented
# lines are alternative variable pairs.
plt.scatter(train_phi,train_pitch, alpha=.1)
# plt.scatter(train_dphi,train_roll, alpha=.1)
# plt.scatter(train_dth,train_roll, alpha=.1)
# plt.scatter(train_dth,train_pitch, alpha=.1)


In [None]:
# Create tuning curves of every unit with respect to one covariate
def tuning_curve(model_nsp, var, model_dt = .025, N_bins=10):
    """Mean response of every unit as a function of a binned covariate.

    ``var`` is cut into ``N_bins - 1`` half-open bins ``[lo, hi)`` whose edges
    are evenly spaced over ``mean(var) +/- 2 * std(var)`` (NaNs in ``var`` are
    ignored when computing the edges and never fall into a bin).

    Parameters
    ----------
    model_nsp : array_like, shape (n_samples, n_units)
        Per-sample response of every unit.
    var : array_like, shape (n_samples,)
        Covariate to bin.
    model_dt : float
        Duration of one sample; means are divided by it to give a rate.
    N_bins : int
        Number of bin *edges*; the result has ``N_bins - 1`` bins.

    Returns
    -------
    tuning : ndarray, shape (n_units, N_bins - 1)
        Mean of ``model_nsp`` within each bin divided by ``model_dt``.
        NaN where a bin holds no valid sample for that unit.
    tuning_std : ndarray, shape (n_units, N_bins - 1)
        Standard error of that mean: std / sqrt(number of non-NaN samples).
    var_range : ndarray, shape (N_bins - 1,)
        Left edge of each bin.
    """
    model_nsp = np.asarray(model_nsp)
    var = np.asarray(var)

    # Bin edges depend only on `var`, so compute its statistics once.
    center = np.nanmean(var)
    spread = np.nanstd(var)
    var_range = np.linspace(center - 2*spread, center + 2*spread, N_bins)

    n_units = model_nsp.shape[-1]
    n_bins = len(var_range) - 1
    # NaN marks "no data", which is also what an empty bin produced before,
    # but without the empty-slice RuntimeWarnings.
    tuning = np.full((n_units, n_bins), np.nan)
    tuning_std = np.full((n_units, n_bins), np.nan)

    for j in range(n_bins):
        usePts = (var >= var_range[j]) & (var < var_range[j+1])
        if not usePts.any():
            continue
        in_bin = model_nsp[usePts]
        # Count only the samples that enter nanmean/nanstd, so the standard
        # error is not shrunk by NaN entries.
        n_valid = np.count_nonzero(~np.isnan(in_bin), axis=0)
        has_data = n_valid > 0
        if not has_data.any():
            continue
        valid_cols = in_bin[:, has_data]
        tuning[has_data, j] = np.nanmean(valid_cols, axis=0)/model_dt
        tuning_std[has_data, j] = (np.nanstd(valid_cols, axis=0)/model_dt)/np.sqrt(n_valid[has_data])
    return tuning, tuning_std, var_range[:-1]

In [None]:
# Tuning of every unit to th on the training split (10 bin edges; model_dt=.1
# matches the bin width used when loading the data).
tuning, tuning_std, var_range = tuning_curve(train_nsp, train_th, N_bins=10, model_dt=.1)

In [None]:
# Plot the tuning curve of one example unit with its error bars and save it.
n = 22  # index of the unit to show
fig, axs = plt.subplots(1,figsize=(7,5))
axs.errorbar(var_range,tuning[n], yerr=tuning_std[n])
axs.set_ylim(bottom=0)
axs.set_xlabel('Eye Theta')
axs.set_ylabel('Spikes/s')
axs.set_title('Neuron: {}'.format(n))
plt.tight_layout()
# Opaque white background so the saved PNG is readable on any viewer.
fig.savefig(FigPath/'ExampleTuningCurve.png',bbox_inches='tight',transparent=False, facecolor='w')

## GLM Check

In [None]:
import scipy.linalg as linalg
import scipy.sparse as sparse


In [None]:
# Build the design matrix and the regularisation (prior) matrices for the
# closed-form ridge fits.
model_vid = model_vid_sm
model_dt = .1
# frame size (rows, columns) and number of pixels per frame
nks = np.shape(model_vid)[1:]; nk = nks[0]*nks[1]
nT = np.shape(model_nsp)[0]
# one flattened frame per row; .copy() so model_vid_sm is not modified below
x = model_vid.reshape(model_nsp.shape[0], -1).copy()
# number of units (columns of model_nsp)
n_units = np.shape(model_nsp)[1]
# subtract mean and renormalize -- necessary? 
mn_img = np.mean(x,axis=0)
x = x-mn_img
x = x/np.std(x,axis =0)
x = np.append(x,np.ones((nT,1)), axis = 1) # append column of ones

# set up prior matrix (regularizer)
# L2 prior: identity on the pixel weights, zero for the constant term so the
# offset is not penalised
Imat = np.eye(nk)
Imat = linalg.block_diag(Imat,np.zeros((1,1)))
# smoothness prior
# `consecutive` is 1 for every pixel except the last pixel of each frame row,
# so within-row differences do not run from the end of one row into the next
consecutive = np.ones((nk, 1))
consecutive[nks[1]-1::nks[1]] = 0
# two-tap difference kernel [-1, 1]
diff = np.zeros((1,2))
diff[0,0] = -1
diff[0,1]= 1
# Dxx: difference between a pixel and the next pixel in the same row (offset 1)
Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1,nk))
# Dxy: difference between a pixel and the same column one row later (offset nks[1])
Dxy = sparse.diags((np.ones((nk,1))@ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
# quadratic penalty on both sets of differences; the constant term is unpenalised
Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
D  = linalg.block_diag(Dx.toarray(),np.zeros((1,1)))      
# summed prior matrix
# Cinv = D + Imat
# Only the L2 prior is active here; D is built but not used.
Cinv = Imat
# shifts (in samples) applied to the spike train relative to the stimulus
lag_list = [ -4, -2, 0 , 2, 4]
# regularisation strengths: 1024 * 2**k for k = 0..15
lambdas = 1024 * (2**np.arange(0,16))
nlam = len(lambdas)
# set up empty arrays for receptive field and cross correlation
sta_all = np.zeros((n_units, len(lag_list), nks[0], nks[1]))
cc_all = np.zeros((n_units,len(lag_list)))

## PCA on Vid

In [None]:
# PCA of the flattened video frames, keeping just enough components to
# explain more than 90 % of the variance.
n_pix = model_vid.shape[1]*model_vid.shape[2]
pca = PCA()
pcs = pca.fit_transform(model_vid_sm.reshape(-1,n_pix))
cum_var = np.cumsum(pca.explained_variance_ratio_)
plt.plot(cum_var)
# Zero-based index of the first component at which the cumulative explained
# variance exceeds .9 (this is the x position on the plot above).
first_above = np.where(cum_var>.9)[0][0]
plt.axvline(x=first_above)
# An index is not a count: reaching that component takes first_above + 1
# components. Using the bare index kept one component too few (and would ask
# for 0 components if the first one alone exceeded the threshold).
comp_to_keep = int(first_above) + 1
pca = PCA(n_components=comp_to_keep)
pcs = pca.fit_transform(model_vid_sm.reshape(-1,n_pix))
print('keep {} PCs'.format(comp_to_keep))
# recon = pca.inverse_transform(pcs)

In [None]:
# Display the shapes of the video and the four movement variables.
model_vid_sm.shape,model_th.shape,model_phi.shape,model_roll.shape,model_pitch.shape,

In [None]:
# z-score the video (per pixel) and the movement variables over the whole
# recording.
model_vid_sm = (model_vid_sm - np.mean(model_vid_sm,axis=0))/np.std(model_vid_sm,axis=0)
model_th = (model_th - np.mean(model_th,axis=0))/np.std(model_th,axis=0)
model_phi = (model_phi - np.mean(model_phi,axis=0))/np.std(model_phi,axis=0)
model_roll = (model_roll - np.mean(model_roll,axis=0))/np.std(model_roll,axis=0)
model_pitch = (model_pitch - np.mean(model_pitch,axis=0))/np.std(model_pitch,axis=0)

# Sample-to-sample change of the z-scored th / phi. The series is padded with
# its own last value so the final difference is 0; padding with a literal 0
# would store -model_th[-1] in the last sample instead.
model_dth = np.diff(model_th, append=model_th[-1:])
model_dphi = np.diff(model_phi, append=model_phi[-1:])
# Chronological split (shuffle=False): the first 70 % of the samples train,
# the rest test. random_state has no effect when shuffle is False.
train_vid, test_vid, train_nsp, test_nsp, train_th, test_th, train_phi, test_phi, train_roll, test_roll, train_pitch, test_pitch, train_t, test_t, train_dth, test_dth, train_dphi, test_dphi, train_pcs,test_pcs = \
train_test_split(model_vid_sm, model_nsp, model_th, model_phi, model_roll, model_pitch, model_t, model_dth, model_dphi, pcs, train_size=.7, shuffle=False, random_state=0)

# Ridge/Elastic Net

In [None]:
from sklearn import linear_model as lm # MultiTaskLassoCV, RidgeCV, MultiTaskElasticNetCV, LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import itertools
# Choose the regression pipeline. `model_type` doubles as the name of the
# final pipeline step (make_pipeline names steps after the lower-cased class
# name), which is how later cells reach the fitted coefficients through
# model[model_type].
model_type = 'ridgecv'
if model_type == 'elasticnetcv':
    model = make_pipeline(StandardScaler(), lm.ElasticNetCV()) # lm.RidgeCV(alphas=np.arange(100,10000,1000))) #  #MultiOutputRegressor(lm.Ridge(),n_jobs=-1)) 
elif model_type == 'ridgecv':
    model = make_pipeline(StandardScaler(), lm.RidgeCV(alphas=lambdas))
else:
    # Fail loudly rather than silently reusing a `model` left over from an
    # earlier run of this cell.
    raise ValueError("Unknown model_type {!r}; expected 'elasticnetcv' or 'ridgecv'".format(model_type))


In [None]:
@ray.remote
def run_model(train_nsp, test_nsp, train_data, test_data, move_train, move_test, celln, lag, bin_length=80, model_dt=.1):
    """Fit the regression pipeline for one unit and one lag (Ray remote task).

    Relies on the notebook-level names ``model``, ``model_type`` and ``pca``.
    The design matrix columns are
    ``[flattened data | column of ones | movement regressors]``.

    Parameters
    ----------
    train_nsp, test_nsp : ndarray, shape (n_samples, n_units)
        Spike counts of the train / test split.
    train_data, test_data : ndarray
        Stimulus features, flattened to one row per sample. The returned
        ``sta`` goes through ``pca.inverse_transform``, so these are expected
        to be PCA scores (NOTE(review): confirm with the caller).
    move_train, move_test : ndarray, shape (n_samples, n_move)
        Movement regressors appended as the last columns.
    celln : int
        Column of ``train_nsp`` / ``test_nsp`` to model.
    lag : int
        Spike counts are rolled by ``-lag`` samples (circularly) before fitting.
    bin_length : int
        Length, in samples, of the boxcar used to smooth both traces.
    model_dt : float
        Duration of one sample; scales the smoothed traces to a rate.

    Returns
    -------
    sp_smooth, pred_smooth : ndarray
        Smoothed observed and predicted traces on the test split.
    cc : float
        Pearson correlation between the two smoothed traces.
    sta : ndarray, shape (20, 30)
        Stimulus weights mapped back through ``pca.inverse_transform``.
    """
    sps_train = np.roll(train_nsp[:,celln],-lag)
    sps_test = np.roll(test_nsp[:,celln],-lag)

    # Use the regressors handed to the task, not notebook globals.
    n_move = move_train.shape[-1]

    x_train = train_data.reshape(train_data.shape[0],-1) #train_pcs
    x_train = np.append(x_train, np.ones((x_train.shape[0],1)), axis = 1) # append column of ones
    x_train = np.concatenate((x_train, move_train),axis=1)

    x_test = test_data.reshape(test_data.shape[0],-1) #test_pcs
    x_test = np.append(x_test,np.ones((x_test.shape[0],1)), axis = 1) # append column of ones
    x_test = np.concatenate((x_test, move_test),axis=1)

    model.fit(x_train,sps_train)

    sp_pred = model.predict(x_test)
    boxcar = np.ones(bin_length)
    sp_smooth = (np.convolve(sps_test, boxcar, 'same')) / (bin_length * model_dt)
    pred_smooth = (np.convolve(sp_pred, boxcar, 'same')) / (bin_length * model_dt)
    cc = np.corrcoef(sp_smooth, pred_smooth)[0,1]
    # Drop the weight of the ones column as well as the movement weights so
    # that only the stimulus weights reach inverse_transform.
    stim_coef = model[model_type].coef_[:-(1+n_move)]
    sta = pca.inverse_transform(stim_coef).reshape(20,30)
    return sp_smooth, pred_smooth, cc, sta

In [None]:
# Shape of the fitted coefficient vector once the trailing movement columns
# are removed. Needs a fitted `model` and an existing `move_train2`.
model[model_type].coef_[:-move_train2.shape[-1]].shape

In [None]:
# Model selection for one unit: fit the regression pipeline on the raw video
# plus every non-empty subset of the movement variables, at every lag.
celln =22
bin_length = 80

titles = np.array(['th','phi','roll','pitch'])
move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis]))#,train_dth[:,np.newaxis], train_dphi[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis]))#,test_dth[:,np.newaxis], test_dphi[:,np.newaxis]))

# All non-empty column subsets, smallest first (15 for four variables). The
# result arrays are sized from this list and from the video frame shape
# instead of hard-coded numbers.
n_move = move_train.shape[-1]
subsets = [list(cols) for size in range(1,n_move+1) for cols in itertools.combinations(range(n_move), size)]
frame_shape = train_vid.shape[1:]

cc_all = np.zeros((len(subsets),len(lag_list)))
sta_all = np.zeros((len(subsets),len(lag_list)) + tuple(frame_shape))
sp_smooth_all = np.zeros((len(subsets),len(lag_list),move_test.shape[0]))
pred_smooth_all = np.zeros((len(subsets),len(lag_list),move_test.shape[0]))
titles_all = []

# The flattened video does not depend on the subset or the lag.
flat_train = train_vid.reshape(train_vid.shape[0],-1) #train_pcs
flat_test = test_vid.reshape(test_vid.shape[0],-1) #test_pcs
boxcar = np.ones(bin_length)

for model_ind, cols in enumerate(subsets):
    move_train2 = move_train[:,cols]
    move_test2 = move_test[:,cols]
    # No column of ones is appended: the pipeline standardises the features
    # and the linear model fits its own intercept.
    x_train = np.concatenate((flat_train, move_train2),axis=1)
    x_test = np.concatenate((flat_test, move_test2),axis=1)
    for lag_ind, lag in enumerate(lag_list):
        # circular shift of the spike train by -lag samples
        sps_train = np.roll(train_nsp[:,celln],-lag)
        sps_test = np.roll(test_nsp[:,celln],-lag)

        model.fit(x_train,sps_train)

        sp_pred = model.predict(x_test)
        sp_smooth = (np.convolve(sps_test, boxcar, 'same')) / (bin_length * model_dt)
        pred_smooth = (np.convolve(sp_pred, boxcar, 'same')) / (bin_length * model_dt)
        sp_smooth_all[model_ind,lag_ind] = sp_smooth
        pred_smooth_all[model_ind,lag_ind] = pred_smooth
        cc_all[model_ind,lag_ind] = np.corrcoef(sp_smooth, pred_smooth)[0,1]
        # video weights only: the trailing movement weights are dropped
        sta_all[model_ind,lag_ind] = model[model_type].coef_[:-move_train2.shape[-1]].reshape(frame_shape)# pca.inverse_transform(model[model_type].coef_[:-move_train2.shape[-1]]).reshape(20,30)
    titles_all.append('_'.join(titles[cols]))



In [None]:
# One PDF page per movement-variable subset: top row shows observed (black)
# vs predicted (red) smoothed traces for each lag, bottom row the fitted
# video weights for that lag.
with PdfPages(FigPath/ 'ModelSelection_{}_rawvid.pdf'.format(model_type)) as pdf:
    for model_ind,title in enumerate(titles_all):
        fig, axs = plt.subplots(2,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),10))
        for lag_ind in range(axs.shape[-1]):
            axs[0,lag_ind].plot(sp_smooth_all[model_ind,lag_ind],'k',label='smoothed FR')
            axs[0,lag_ind].plot(pred_smooth_all[model_ind,lag_ind],'r', label='pred FR')
            axs[0,lag_ind].set_title('cc={:.2f}'.format(cc_all[model_ind,lag_ind]))
            axs[1,lag_ind].imshow(sta_all[model_ind,lag_ind])
            axs[1,lag_ind].set_title('lag={:d}'.format(lag_list[lag_ind]))
            axs[1,lag_ind].axis('off')
            # title and layout are re-applied on every column
            plt.suptitle(title)
            plt.tight_layout()
        # savefig() with no argument writes the current figure; closing it
        # afterwards keeps the number of open figures bounded
        pdf.savefig()
        plt.close()
print('Done Plotting!')

In [None]:
# Index into cc_all of the highest correlation.
np.unravel_index(np.argmax(cc_all),shape=cc_all.shape)

In [None]:
# Display the shapes of the result arrays.
cc_all.shape, sta_all.shape, pred_smooth_all.shape, sp_smooth_all.shape

# GLM with eye/head

In [None]:
##### Group shuffle #####
# Hold out whole contiguous chunks of the recording instead of individual
# samples: the sample axis is cut into int(1/frac) consecutive groups and
# GroupShuffleSplit assigns entire groups to either train or test.
from sklearn.model_selection import GroupShuffleSplit
gss = GroupShuffleSplit(n_splits=1, train_size=.7, random_state=42)
nT = model_nsp.shape[0]
frac = .2
# Group i covers samples [int(frac*(i-1)*nT), int(frac*i*nT)). Both edges are
# derived from `frac` so the group sizes stay consistent when it is changed.
groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((frac*(i-1))*nT)) for i in range(1,int(1/frac)+1)])
# n_splits=1, so this loop runs once and leaves train_idx / test_idx set.
for train_idx, test_idx in gss.split(np.arange(len(model_nsp)), groups=groups):
    print("TRAIN:", train_idx, "TEST:", test_idx)

# Apply the same split to every variable.
train_vid = model_vid_sm[train_idx]
test_vid = model_vid_sm[test_idx]
train_nsp = model_nsp[train_idx]
test_nsp = model_nsp[test_idx]
train_th = model_th[train_idx]
test_th = model_th[test_idx]
train_phi = model_phi[train_idx]
test_phi = model_phi[test_idx]
train_roll = model_roll[train_idx]
test_roll = model_roll[test_idx]
train_pitch = model_pitch[train_idx]
test_pitch = model_pitch[test_idx]
train_t = model_t[train_idx]
test_t = model_t[test_idx]
train_dth = model_dth[train_idx]
test_dth = model_dth[test_idx]
train_dphi = model_dphi[train_idx]
test_dphi = model_dphi[test_idx]


In [None]:
# z-score the video (per pixel) and the movement variables over the whole
# recording.
model_vid_sm = (model_vid_sm - np.mean(model_vid_sm,axis=0))/np.std(model_vid_sm,axis=0)
model_th = (model_th - np.mean(model_th,axis=0))/np.std(model_th,axis=0)
model_phi = (model_phi - np.mean(model_phi,axis=0))/np.std(model_phi,axis=0)
model_roll = (model_roll - np.mean(model_roll,axis=0))/np.std(model_roll,axis=0)
model_pitch = (model_pitch - np.mean(model_pitch,axis=0))/np.std(model_pitch,axis=0)

# Sample-to-sample change of the z-scored th / phi. The series is padded with
# its own last value so the final difference is 0; padding with a literal 0
# would store -model_th[-1] in the last sample instead.
model_dth = np.diff(model_th, append=model_th[-1:])
model_dphi = np.diff(model_phi, append=model_phi[-1:])
# Chronological split (shuffle=False): the last 30 % of the samples are the
# test set. random_state has no effect when shuffle is False.
train_vid, test_vid, train_nsp, test_nsp, train_th, test_th, train_phi, test_phi, train_roll, test_roll, train_pitch, test_pitch, train_t, test_t, train_dth, test_dth, train_dphi, test_dphi, train_pcs,test_pcs = \
train_test_split(model_vid_sm, model_nsp, model_th, model_phi, model_roll, model_pitch, model_t, model_dth, model_dphi, pcs, test_size=.3, shuffle=False, random_state=0)

In [None]:
##### Train_Test Split with sklearn #####
# Closed-form ridge regression of one unit's spike counts on the video pixels,
# a constant term and a subset of the movement variables, repeated for every
# lag in lag_list. Uses the train_* / test_* arrays from the preceding split.
model_vid = model_vid_sm
model_dt = .1
# frame size (rows, columns) and number of pixels per frame
nks = np.shape(model_vid)[1:]; nk = nks[0]*nks[1]
nT = np.shape(model_nsp)[0]
# flattened full-length video; not used further in this cell
x = model_vid.reshape(model_nsp.shape[0], -1).copy()
# number of units (columns of model_nsp)
n_units = np.shape(model_nsp)[1]

titles = np.array(['th','phi','roll','pitch'])
# column order: th, phi, roll, pitch, dth, dphi
move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis],train_dth[:,np.newaxis],train_dphi[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis],test_dth[:,np.newaxis],test_dphi[:,np.newaxis]))

# keep only columns 2 and 3 (roll and pitch)
perms = np.array([2,3]) #np.array(list(itertools.combinations([0,1,2,3], n)))
move_train = move_train[:,perms]
move_test = move_test[:,perms]

# set up prior matrix (regularizer)
# L2 prior: identity on the pixel weights; the constant term and the movement
# weights get a zero block, i.e. they are not penalised
Imat = np.eye(nk)
Imat = linalg.block_diag(Imat,np.zeros((1+move_test.shape[-1],1+move_test.shape[-1])))
# smoothness prior
# `consecutive` is 1 for every pixel except the last pixel of each frame row,
# so within-row differences do not run from the end of one row into the next
consecutive = np.ones((nk, 1))
consecutive[nks[1]-1::nks[1]] = 0
# two-tap difference kernel [-1, 1]
diff = np.zeros((1,2))
diff[0,0] = -1
diff[0,1]= 1
# Dxx: pixel vs next pixel in the same row; Dxy: pixel vs same column one row later
Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1,nk))
Dxy = sparse.diags((np.ones((nk,1))@ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
D  = linalg.block_diag(Dx.toarray(),np.zeros((1+move_test.shape[-1],1+move_test.shape[-1])))   
# summed prior matrix
# Cinv = D + Imat
# Only the L2 prior is active here; D is built but not used.
Cinv = Imat

# shifts (in samples) applied to the spike train relative to the stimulus
lag_list = [ -4, -2, 0 , 2, 4]
# regularisation strengths: 1024 * 2**k for k = 0..15
lambdas = 1024 * (2**np.arange(0,16))
nlam = len(lambdas)
# set up empty arrays for receptive field and cross correlation
sta_all = np.zeros((n_units, len(lag_list), nks[0], nks[1]))
cc_all = np.zeros((n_units,len(lag_list)))

# unit to fit; only this row of sta_all / cc_all is filled
celln = 51
fig, axs = plt.subplots(2,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),10))
for lag_ind, lag in enumerate(lag_list):
    
    # circular shift of the spike train by -lag samples (values wrap around)
    sps_train = np.roll(train_nsp[:,celln],-lag)
    sps_test = np.roll(test_nsp[:,celln],-lag)
    
   
    # design matrices: [flattened video | column of ones | movement columns]
    x_train = train_vid.reshape(train_vid.shape[0],-1)
    x_train = np.append(x_train, np.ones((x_train.shape[0],1)), axis = 1) # append column of ones
    x_train = np.concatenate((x_train,move_train),axis=1)

    x_test = test_vid.reshape(test_vid.shape[0],-1)
    x_test = np.append(x_test,np.ones((x_test.shape[0],1)), axis = 1) # append column of ones
    x_test = np.concatenate((x_test,move_test),axis=1)
    
    # sufficient statistics of the least-squares problem
    XXtr = x_train.T @ x_train
    XYtr = x_train.T @ sps_train
    
    msetrain = np.zeros((nlam,1))
    msetest = np.zeros((nlam,1))
    # one weight vector per regularisation strength
    w_ridge = np.zeros((nk+1+move_test.shape[1],nlam))
    # loop over regularization strength
    for l in range(len(lambdas)):  
        # calculate MAP estimate               
        w = np.linalg.solve(XXtr + lambdas[l]*Cinv, XYtr) # equivalent of \ (left divide) in matlab
        w_ridge[:,l] = w
        # calculate test and training mean squared error
        msetrain[l] = np.mean((sps_train - x_train@w)**2)
        msetest[l] = np.mean((sps_test - x_test@w)**2)
    # pick the lambda with the lowest error on the test split
    # NOTE(review): the same test split is then used to report cc below
    best_lambda = np.argmin(msetest)
    w = w_ridge[:,best_lambda]
    ridge_rf = w_ridge[:,best_lambda]
    # pixel weights only: drop the constant term and the movement weights
    sta_all[celln,lag_ind,:,:] = np.reshape(w[:-(1+move_test.shape[-1])],nks)
    # plot predicted vs actual firing rate
    # predicted firing rate
    sp_pred = x_test@ridge_rf
    # bin the firing rate to get smooth rate vs time
    bin_length = 80
    # boxcar sum over bin_length samples divided by the window duration
    sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    # a few diagnostics
    err = np.mean((sp_smooth-pred_smooth)**2)
    cc = np.corrcoef(sp_smooth, pred_smooth)
    cc_all[celln,lag_ind] = cc[0,1]

    # top row: observed (black) vs predicted (red); bottom row: pixel weights
    axs[0,lag_ind].plot(sp_smooth,'k',label='smoothed FR')
    axs[0,lag_ind].plot(pred_smooth,'r', label='pred FR')
    axs[0,lag_ind].set_title('cc={:.2f}'.format(cc_all[celln,lag_ind]))
    axs[1,lag_ind].imshow(sta_all[celln,lag_ind])
    axs[1,lag_ind].set_title('lag={:d}'.format(lag_list[lag_ind]))
    axs[1,lag_ind].axis('off')
    plt.suptitle('No Smoothness w/ movements')
    plt.tight_layout()

In [None]:
# Train and test mean squared error against the index into `lambdas`, for the
# last lag that was fitted.
plt.plot(msetrain)
plt.plot(msetest)

In [None]:
# ##### Test_Train split with test=first, train=later #####
# model_vid = model_vid_sm
# model_dt = .1
# nks = np.shape(model_vid)[1:]; nk = nks[0]*nks[1]
# nT = np.shape(model_nsp)[0]
# x = model_vid.reshape(model_nsp.shape[0], -1).copy()
# # image dimensions
# n_units = np.shape(model_nsp)[1]
# # subtract mean and renormalize -- necessary? 
# mn_img = np.mean(x,axis=0)
# x = x-mn_img
# x = x/np.std(x,axis =0)
# x = np.append(x,np.ones((nT,1)), axis = 1) # append column of ones
# test_frac = 0.7
# ntest = int(nT*test_frac)
# titles = np.array(['th','phi','roll','pitch'])
# move_train = np.hstack((model_th[:ntest,np.newaxis],model_phi[:ntest,np.newaxis],model_roll[:ntest,np.newaxis],model_pitch[:ntest,np.newaxis],model_dth[:ntest,np.newaxis],model_dphi[:ntest,np.newaxis]))
# move_test = np.hstack((model_th[ntest:,np.newaxis],model_phi[ntest:,np.newaxis],model_roll[ntest:,np.newaxis],model_pitch[ntest:,np.newaxis],model_dth[ntest:,np.newaxis],model_dphi[ntest:,np.newaxis]))

# perms = np.array([2,3]) #np.array(list(itertools.combinations([0,1,2,3], n)))
# move_train = move_train[:,perms]
# move_test = move_test[:,perms]


# # set up prior matrix (regularizer)
# # L2 prior
# Imat = np.eye(nk)
# Imat = linalg.block_diag(Imat,np.zeros((1+move_test.shape[-1],1+move_test.shape[-1])))
# # smoothness prior
# consecutive = np.ones((nk, 1))
# consecutive[nks[1]-1::nks[1]] = 0
# diff = np.zeros((1,2))
# diff[0,0] = -1
# diff[0,1]= 1
# Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1,nk))
# Dxy = sparse.diags((np.ones((nk,1))@ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
# Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
# D  = linalg.block_diag(Dx.toarray(),np.zeros((1+move_test.shape[-1],1+move_test.shape[-1])))   
# # summed prior matrix
# # Cinv = D + Imat
# Cinv = Imat

# lag_list = [ -4, -2, 0 , 2, 4]
# lambdas = 1024 * (2**np.arange(0,16))
# nlam = len(lambdas)
# # set up empty arrays for receptive field and cross correlation
# sta_all = np.zeros((n_units, len(lag_list), nks[0], nks[1]))
# cc_all = np.zeros((n_units,len(lag_list)))

# celln = 51
# fig, axs = plt.subplots(2,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),10))
# # iterate through timing lags
# for lag_ind, lag in enumerate(lag_list):
#     sps = np.roll(model_nsp.T[celln,:],-lag)
#     nT = len(sps)
#     #split training and test data
#     test_frac = 0.7
#     ntest = int(nT*test_frac)
#     x_train = x[:ntest,:] ; sps_train = sps[:ntest]
#     x_test = x[ntest:,:]; sps_test = sps[ntest:]

#     x_train = np.concatenate((x_train,move_train),axis=1)
#     x_test = np.concatenate((x_test,move_test),axis=1)

#     #calculate a few terms
#     sta = x_train.T@sps_train/np.sum(sps_train)
#     XXtr = x_train.T @ x_train
#     XYtr = x_train.T @sps_train
#     msetrain = np.zeros((nlam,1))
#     msetest = np.zeros((nlam,1))
#     w_ridge = np.zeros((nk+1+move_test.shape[-1],nlam))
#     # initial guess
#     w = sta
#     # loop over regularization strength
#     for l in range(len(lambdas)):  
#         # calculate MAP estimate               
#         w = np.linalg.solve(XXtr + lambdas[l]*Cinv, XYtr) # equivalent of \ (left divide) in matlab
#         w_ridge[:,l] = w
#         # calculate test and training rms error
#         msetrain[l] = np.mean((sps_train - x_train@w)**2)
#         msetest[l] = np.mean((sps_test - x_test@w)**2)
#     # select best cross-validated lambda for RF
#     best_lambda = np.argmin(msetest)
#     w = w_ridge[:,best_lambda]
#     ridge_rf = w_ridge[:,best_lambda]
#     sta_all[celln,lag_ind,:,:] = np.reshape(w[:nk],nks)
#     # plot predicted vs actual firing rate
#     # predicted firing rate
#     sp_pred = x_test@ridge_rf
#     # bin the firing rate to get smooth rate vs time
#     bin_length = 80
#     sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
#     pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
#     # a few diagnostics
#     err = np.mean((sp_smooth-pred_smooth)**2)
#     cc = np.corrcoef(sp_smooth, pred_smooth)
#     cc_all[celln,lag_ind] = cc[0,1]
    

#     axs[0,lag_ind].plot(sp_smooth,'k',label='smoothed FR')
#     axs[0,lag_ind].plot(pred_smooth,'r', label='pred FR')
#     axs[0,lag_ind].set_title('cc={:.2f}'.format(cc_all[celln,lag_ind]))
#     axs[1,lag_ind].imshow(sta_all[celln,lag_ind])
#     axs[1,lag_ind].set_title('lag={:d}'.format(lag_list[lag_ind]))
#     axs[1,lag_ind].axis('off')
#     plt.suptitle('No Smoothness splitdata pipeline')
#     plt.tight_layout()

In [None]:
# Train and test mean squared error against the index into `lambdas`, for the
# last lag that was fitted.
plt.plot(msetrain)
plt.plot(msetest)

In [None]:

# Closed-form ridge regression of one unit's spike counts on the video pixels,
# a constant term and a subset of the movement variables. Here the split is
# done by position: the first 30 % of the samples are the test set and the
# remaining samples the training set.
model_vid = model_vid_sm
model_dt = .1
# frame size (rows, columns) and number of pixels per frame
nks = np.shape(model_vid)[1:]; nk = nks[0]*nks[1]
nT = np.shape(model_nsp)[0]
# one flattened frame per row; .copy() so model_vid_sm is not modified below
x = model_vid.reshape(model_nsp.shape[0], -1).copy()
# number of units (columns of model_nsp)
n_units = np.shape(model_nsp)[1]
# subtract mean and renormalize -- necessary? 
mn_img = np.mean(x,axis=0)
x = x-mn_img
x = x/np.std(x,axis =0)
x = np.append(x,np.ones((nT,1)), axis = 1) # append column of ones

test_frac = 0.3
# number of leading samples that form the test set
ntest = int(nT*test_frac)
titles = np.array(['th','phi','roll','pitch'])
# column order: th, phi, roll, pitch, dth, dphi
move_train = np.hstack((model_th[ntest:,np.newaxis],model_phi[ntest:,np.newaxis],model_roll[ntest:,np.newaxis],model_pitch[ntest:,np.newaxis],model_dth[ntest:,np.newaxis],model_dphi[ntest:,np.newaxis]))
move_test = np.hstack((model_th[:ntest,np.newaxis],model_phi[:ntest,np.newaxis],model_roll[:ntest,np.newaxis],model_pitch[:ntest,np.newaxis],model_dth[:ntest,np.newaxis],model_dphi[:ntest,np.newaxis]))

# keep only columns 2 and 3 (roll and pitch)
perms = np.array([2,3]) #np.array(list(itertools.combinations([0,1,2,3], n)))
move_train = move_train[:,perms]
move_test = move_test[:,perms]

# set up prior matrix (regularizer)
# L2 prior: identity on the pixel weights; the constant term and the movement
# weights get a zero block, i.e. they are not penalised
Imat = np.eye(nk)
Imat = linalg.block_diag(Imat,np.zeros((1+move_test.shape[-1],1+move_test.shape[-1])))
# smoothness prior
# `consecutive` is 1 for every pixel except the last pixel of each frame row,
# so within-row differences do not run from the end of one row into the next
consecutive = np.ones((nk, 1))
consecutive[nks[1]-1::nks[1]] = 0
# two-tap difference kernel [-1, 1]
diff = np.zeros((1,2))
diff[0,0] = -1
diff[0,1]= 1
# Dxx: pixel vs next pixel in the same row; Dxy: pixel vs same column one row later
Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1,nk))
Dxy = sparse.diags((np.ones((nk,1))@ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
D  = linalg.block_diag(Dx.toarray(),np.zeros((1+move_test.shape[-1],1+move_test.shape[-1])))   
# summed prior matrix
# Cinv = D + Imat
# Only the L2 prior is active here; D is built but not used.
Cinv = Imat

# shifts (in samples) applied to the spike train relative to the stimulus
lag_list = [ -4, -2, 0 , 2, 4]
# regularisation strengths: 1024 * 2**k for k = 0..15
lambdas = 1024 * (2**np.arange(0,16))
nlam = len(lambdas)
# set up empty arrays for receptive field and cross correlation
sta_all = np.zeros((n_units, len(lag_list), nks[0], nks[1]))
cc_all = np.zeros((n_units,len(lag_list)))

# unit to fit; only this row of sta_all / cc_all is filled
celln = 51
fig, axs = plt.subplots(2,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),10))
# iterate through timing lags
for lag_ind, lag in enumerate(lag_list):
    # circular shift of the full spike train by -lag samples (values wrap around)
    sps = np.roll(model_nsp.T[celln,:],-lag)
    nT = len(sps)
    #split training and test data

    x_train = x[ntest:,:] ; sps_train = sps[ntest:]
    x_test = x[:ntest,:]; sps_test = sps[:ntest]
    
    
    # append the movement columns after the pixels and the constant term
    x_train = np.concatenate((x_train,move_train),axis=1) # x_train*(1+alpha*model_th)
    x_test = np.concatenate((x_test,move_test),axis=1)

    #calculate a few terms
    # spike-weighted average of the regressors
    sta = x_train.T@sps_train/np.sum(sps_train)
    XXtr = x_train.T @ x_train 
    XYtr = x_train.T @sps_train
    msetrain = np.zeros((nlam,1))
    msetest = np.zeros((nlam,1))
    # one weight vector per regularisation strength
    w_ridge = np.zeros((nk+1+move_test.shape[-1],nlam))
    # initial guess (overwritten by the first solve below)
    w = sta
    # loop over regularization strength
    for l in range(len(lambdas)):  
        # calculate MAP estimate               
        w = np.linalg.solve(XXtr + lambdas[l]*Cinv, XYtr) # equivalent of \ (left divide) in matlab
        w_ridge[:,l] = w
        # calculate test and training mean squared error
        msetrain[l] = np.mean((sps_train - x_train@w)**2)
        msetest[l] = np.mean((sps_test - x_test@w)**2)
    # pick the lambda with the lowest error on the test split
    # NOTE(review): the same test split is then used to report cc below
    best_lambda = np.argmin(msetest)
    w = w_ridge[:,best_lambda]
    ridge_rf = w_ridge[:,best_lambda]
    # pixel weights are the first nk entries
    sta_all[celln,lag_ind,:,:] = np.reshape(w[:nk],nks)
    # plot predicted vs actual firing rate
    # predicted firing rate
    sp_pred = x_test@ridge_rf
    # bin the firing rate to get smooth rate vs time
    bin_length = 40
    # boxcar sum over bin_length samples divided by the window duration
    sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    # a few diagnostics
    err = np.mean((sp_smooth-pred_smooth)**2)
    cc = np.corrcoef(sp_smooth, pred_smooth)
    cc_all[celln,lag_ind] = cc[0,1]
    

    # top row: observed (black) vs predicted (red); bottom row: pixel weights
    axs[0,lag_ind].plot(sp_smooth,'k',label='smoothed FR')
    axs[0,lag_ind].plot(pred_smooth,'r', label='pred FR')
    axs[0,lag_ind].set_title('cc={:.2f}'.format(cc_all[celln,lag_ind]))
    axs[1,lag_ind].imshow(sta_all[celln,lag_ind])
    axs[1,lag_ind].set_title('lag={:d}'.format(lag_list[lag_ind]))
    axs[1,lag_ind].axis('off')
    plt.suptitle('With/out Smoothness splitdata pipeline')
    plt.tight_layout()

In [None]:
# Fitted weights of the movement columns (the last entries of w) for the last
# lag that was fitted.
w[-move_test.shape[-1]:]

In [None]:
# Train and test mean squared error against the index into `lambdas`, for the
# last lag that was fitted.
plt.plot(msetrain)
plt.plot(msetest)

## GLM Movement Only

In [None]:
# Movement-only model selection for one unit: fit the regression pipeline on
# every non-empty subset of the first four movement variables, at every lag,
# and write one PDF page per subset.
model_type = 'ridgecv'
if model_type == 'elasticnetcv':
    model = make_pipeline(StandardScaler(), lm.ElasticNetCV()) # lm.RidgeCV(alphas=np.arange(100,10000,1000))) #  #MultiOutputRegressor(lm.Ridge(),n_jobs=-1)) 
elif model_type == 'ridgecv':
    model = make_pipeline(StandardScaler(), lm.RidgeCV(alphas=lambdas))


test_frac = 0.3
# the first ntest samples are the test set, the rest the training set
ntest = int(nT*test_frac)
# names for all six columns; only indices 0-3 are combined below
titles = np.array(['th','phi','roll','pitch','dth','dphi'])
move_train = np.hstack((model_th[ntest:,np.newaxis],model_phi[ntest:,np.newaxis],model_roll[ntest:,np.newaxis],model_pitch[ntest:,np.newaxis],model_dth[ntest:,np.newaxis],model_dphi[ntest:,np.newaxis]))
move_test = np.hstack((model_th[:ntest,np.newaxis],model_phi[:ntest,np.newaxis],model_roll[:ntest,np.newaxis],model_pitch[:ntest,np.newaxis],model_dth[:ntest,np.newaxis],model_dphi[:ntest,np.newaxis]))
# 15 = number of non-empty subsets of four variables
sps_smooth_all = np.zeros((15,len(lag_list),move_test.shape[0]))
pred_smooth_all = np.zeros((15,len(lag_list),move_test.shape[0]))
cc_all = np.zeros((n_units,15,len(lag_list)))
# fitted coefficient vectors, appended in (subset, lag) order
model_coef_all= [] # = np.zeros((15,len(lag_list)))
titles_all = []

# unit to fit; only this row of cc_all is filled
celln = 51
bin_length = 80
model_ind = 0
with PdfPages(FigPath/ 'ModelSelection_{}_MoveOnly.pdf'.format(model_type)) as pdf:

    for n in range(1,5):
        # all subsets of size n of the column indices 0..3
        perms = np.array(list(itertools.combinations([0,1,2,3], n)))
        for ind in range(perms.shape[0]):
            fig, axs = plt.subplots(1,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),5))
            move_train2 = move_train[:,perms[ind]]
            move_test2 = move_test[:,perms[ind]]
            # iterate through timing lags
            for lag_ind, lag in enumerate(lag_list):
                # circular shift of the full spike train by -lag samples
                sps = np.roll(model_nsp.T[celln,:],-lag)
                nT = len(sps)
                sps_train = sps[ntest:]
                sps_test = sps[:ntest]

                model.fit(move_train2,sps_train)
                sps_pred = model.predict(move_test2)

                # boxcar sum over bin_length samples divided by the window duration
                sps_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
                pred_smooth = (np.convolve(sps_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
                cc_all[celln,model_ind,lag_ind] = np.corrcoef(sps_smooth, pred_smooth)[0,1]
                sps_smooth_all[model_ind,lag_ind] = sps_smooth
                pred_smooth_all[model_ind,lag_ind] = pred_smooth
                model_coef_all.append(model[model_type].coef_)
                
                # observed (black) vs predicted (red) smoothed trace
                axs[lag_ind].plot(sps_smooth,'k',label='smoothed FR')
                axs[lag_ind].plot(pred_smooth,'r', label='pred FR')
                axs[lag_ind].set_title('cc={:.3f}'.format(cc_all[celln,model_ind,lag_ind]))
            #     axs[1,lag_ind].imshow(sta_all[celln,lag_ind])
            #     axs[1,lag_ind].set_title('lag={:d}'.format(lag_list[lag_ind]))
            #     axs[1,lag_ind].axis('off')
            #     plt.suptitle('No Smoothness splitdata pipeline')
                
            # page title: names of the variables in this subset joined by '_'
            titles_all.append('_'.join([t for t in titles[perms[ind]]]))
            plt.suptitle(titles_all[-1])
            plt.tight_layout()
            # savefig() with no argument writes the current figure; the figure
            # is left open afterwards
            pdf.savefig()
            
            model_ind+=1


In [None]:
np.unravel_index(np.argmax(cc_all),shape=cc_all.shape)

In [None]:
# Regularised linear receptive-field fit for ONE example unit at several lags.
# The design matrix is the z-scored, flattened video plus a bias column; the
# prior is identity + spatial smoothness, and its strength is picked from
# `lambdas` by the lowest MSE on the held-out block.
model_vid = model_vid_sm
model_dt = .1
nks = np.shape(model_vid)[1:]; nk = nks[0]*nks[1]
nT = np.shape(model_nsp)[0]
x = model_vid.reshape(model_nsp.shape[0], -1).copy()
# image dimensions
n_units = np.shape(model_nsp)[1]
# subtract mean and renormalize -- necessary? 
# NOTE(review): a pixel that is constant over time has std 0 and would turn
# into NaN/inf here -- confirm the video has no such pixels.
mn_img = np.mean(x,axis=0)
x = x-mn_img
x = x/np.std(x,axis =0)
x = np.append(x,np.ones((nT,1)), axis = 1) # append column of ones

# set up prior matrix (regularizer)
# L2 prior (the trailing 1x1 zero block leaves the bias weight unpenalised)
Imat = np.eye(nk)
Imat = linalg.block_diag(Imat,np.zeros((1,1)))
# smoothness prior: squared first differences along both image axes
consecutive = np.ones((nk, 1))
consecutive[nks[1]-1::nks[1]] = 0   # no horizontal difference across a row boundary
diff = np.zeros((1,2))
diff[0,0] = -1
diff[0,1]= 1
Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1,nk))
Dxy = sparse.diags((np.ones((nk,1))@ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
D  = linalg.block_diag(Dx.toarray(),np.zeros((1,1)))   
# summed prior matrix
Cinv = D + Imat
# Cinv = Imat

lag_list = [ -4, -2, 0 , 2, 4]
lambdas = 1024 * (2**np.arange(0,16))
nlam = len(lambdas)
# set up empty arrays for receptive field and cross correlation
sta_all = np.zeros((n_units, len(lag_list), nks[0], nks[1]))
cc_all = np.zeros((n_units,len(lag_list)))

celln = 51
test_frac = 0.3
bin_length = 80

# The train/test split and the Gram matrix depend only on the stimulus, not on
# the lag, so they are computed once instead of once per lag.
# The FIRST `ntest` samples are the test block.
ntest = int(nT*test_frac)
x_train = x[ntest:,:]
x_test = x[:ntest,:]
XXtr = x_train.T @ x_train

# iterate through timing lags
fig, axs = plt.subplots(2,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),10))
for lag_ind, lag in enumerate(lag_list):
    # circular shift of the spike train (samples wrap around the ends)
    sps = np.roll(model_nsp.T[celln,:],-lag)
    sps_train = sps[ntest:]
    sps_test = sps[:ntest]
    XYtr = x_train.T @sps_train
    # msetrain/msetest are kept as globals because a later cell plots them
    msetrain = np.zeros((nlam,1))
    msetest = np.zeros((nlam,1))
    w_ridge = np.zeros((nk+1,nlam))
    # loop over regularization strength
    for l in range(nlam):
        # calculate MAP estimate
        w = np.linalg.solve(XXtr + lambdas[l]*Cinv, XYtr) # equivalent of \ (left divide) in matlab
        w_ridge[:,l] = w
        # calculate test and training mean squared error
        msetrain[l] = np.mean((sps_train - x_train@w)**2)
        msetest[l] = np.mean((sps_test - x_test@w)**2)
    # select the lambda with the lowest held-out error
    best_lambda = np.argmin(msetest)
    w = w_ridge[:,best_lambda]
    ridge_rf = w
    # drop the bias weight and fold the rest back into image shape
    sta_all[celln,lag_ind,:,:] = np.reshape(w[:-1],nks)
    # predicted activity on the held-out block
    sp_pred = x_test@ridge_rf
    # boxcar-smooth actual and predicted counts and scale by the window duration
    sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    cc = np.corrcoef(sp_smooth, pred_smooth)
    cc_all[celln,lag_ind] = cc[0,1]
    
    axs[0,lag_ind].plot(sp_smooth,'k',label='smoothed FR')
    axs[0,lag_ind].plot(pred_smooth,'r', label='pred FR')
    axs[0,lag_ind].set_title('cc={:.2f}'.format(cc_all[celln,lag_ind]))
    axs[1,lag_ind].imshow(sta_all[celln,lag_ind])
    axs[1,lag_ind].set_title('lag={:d}'.format(lag_list[lag_ind]))
    axs[1,lag_ind].axis('off')
# figure-level title and layout only need to be applied once, after all panels exist
plt.suptitle('Current_pipeline')
plt.tight_layout()

In [None]:
plt.plot(msetrain)
plt.plot(msetest)

# Parallel Processing GLM

In [None]:
@ray.remote
def do_glm_fit_mult(train_nsp, test_nsp, train_data, test_data, move_train, move_test, celln, perms, lag, Cinv, lambdas, bin_length=80, model_dt=.1):
    """Fit a regularised linear model (stimulus + movement) for one unit and lag.

    The design matrix is ``[flattened stimulus | bias column | selected
    movement columns]``.  For every value in ``lambdas`` the weights are the
    solution of ``(X'X + lambda * Cinv) w = X'y``; the lambda with the lowest
    mean squared error on the test split is kept.

    Args:
        train_nsp, test_nsp: 2-D arrays indexed ``[sample, unit]``.
        train_data, test_data: stimulus arrays whose first axis is the sample
            axis; the remaining axes are flattened into features.
        move_train, move_test: 2-D arrays indexed ``[sample, regressor]``.
        celln: column of ``*_nsp`` to model.
        perms: indices of the movement columns to include.
        lag: the spike train is circularly shifted by ``-lag`` samples
            (``np.roll``), so samples wrap around the ends.
        Cinv: square prior matrix with one row/column per design-matrix column.
        lambdas: sequence of regularisation strengths to try.
        bin_length: length in samples of the boxcar used for smoothing.
        model_dt: sample duration used to scale the smoothed traces.

    Returns:
        ``(cc, sta, sp_smooth, pred_smooth)``: correlation between smoothed
        actual and predicted test activity, the stimulus weights reshaped to
        ``train_data.shape[1:]``, and the two smoothed test traces.
    """
    sps_train = np.roll(train_nsp[:, celln], -lag)
    sps_test = np.roll(test_nsp[:, celln], -lag)
    move_train = move_train[:, perms]
    move_test = move_test[:, perms]

    # design matrices: flattened stimulus, then a column of ones, then movement
    x_train = np.concatenate(
        (train_data.reshape(train_data.shape[0], -1), np.ones((train_data.shape[0], 1)), move_train),
        axis=1)
    x_test = np.concatenate(
        (test_data.reshape(test_data.shape[0], -1), np.ones((test_data.shape[0], 1)), move_test),
        axis=1)

    XXtr = x_train.T @ x_train
    XYtr = x_train.T @ sps_train
    nlam = len(lambdas)
    msetest = np.zeros(nlam)
    w_ridge = np.zeros((x_train.shape[-1], nlam))

    # loop over regularization strength
    for l, lam in enumerate(lambdas):
        # MAP estimate (equivalent of \ (left divide) in matlab)
        w = np.linalg.solve(XXtr + lam*Cinv, XYtr)
        w_ridge[:, l] = w
        msetest[l] = np.mean((sps_test - x_test@w)**2)

    # keep the weights with the lowest held-out error
    ridge_rf = w_ridge[:, np.argmin(msetest)]
    # Strip the bias and movement weights; what is left is the stimulus filter.
    # Its shape comes from the stimulus itself rather than from a notebook
    # global, which a Ray worker could hold in a stale state.
    n_move = move_test.shape[-1]
    sta_all = np.reshape(ridge_rf[:-(1+n_move)], train_data.shape[1:])

    # predicted activity on the test split
    sp_pred = x_test@ridge_rf
    # boxcar-smooth actual and predicted counts and scale by the window duration
    boxcar = np.ones(bin_length)
    sp_smooth = np.convolve(sps_test, boxcar, 'same') / (bin_length * model_dt)
    pred_smooth = np.convolve(sp_pred, boxcar, 'same') / (bin_length * model_dt)
    cc_all = np.corrcoef(sp_smooth, pred_smooth)[0, 1]
    return cc_all, sta_all, sp_smooth, pred_smooth

In [None]:

@ray.remote
def do_glm_fit(train_nsp, test_nsp, train_data, test_data, move_train, move_test, celln, perms, lag, Cinv, lambdas, bin_length=40, model_dt=.1):
    """Fit a regularised linear model (stimulus + movement) for one unit and lag.

    The design matrix is ``[flattened stimulus | bias column | selected
    movement columns]``.  For every value in ``lambdas`` the weights are the
    solution of ``(X'X + lambda * Cinv) w = X'y``; the lambda with the lowest
    mean squared error on the test split is kept.

    NOTE: another cell of this notebook defines a function with the same name
    but a different signature; whichever cell was executed last is the one
    that ``do_glm_fit.remote`` calls.

    Args:
        train_nsp, test_nsp: 2-D arrays indexed ``[sample, unit]``.
        train_data, test_data: stimulus arrays whose first axis is the sample
            axis; the remaining axes are flattened into features.
        move_train, move_test: 2-D arrays indexed ``[sample, regressor]``.
        celln: column of ``*_nsp`` to model.
        perms: indices of the movement columns to include.
        lag: the spike train is circularly shifted by ``-lag`` samples
            (``np.roll``), so samples wrap around the ends.
        Cinv: square prior matrix with one row/column per design-matrix column.
        lambdas: sequence of regularisation strengths to try.
        bin_length: length in samples of the boxcar used for smoothing.
        model_dt: sample duration used to scale the smoothed traces.

    Returns:
        ``(cc, sta, sp_smooth, pred_smooth)``: correlation between smoothed
        actual and predicted test activity, the stimulus weights reshaped to
        ``train_data.shape[1:]``, and the two smoothed test traces.
    """
    sps_train = np.roll(train_nsp[:, celln], -lag)
    sps_test = np.roll(test_nsp[:, celln], -lag)
    move_train = move_train[:, perms]
    move_test = move_test[:, perms]

    # design matrices: flattened stimulus, then a column of ones, then movement
    x_train = np.concatenate(
        (train_data.reshape(train_data.shape[0], -1), np.ones((train_data.shape[0], 1)), move_train),
        axis=1)
    x_test = np.concatenate(
        (test_data.reshape(test_data.shape[0], -1), np.ones((test_data.shape[0], 1)), move_test),
        axis=1)

    XXtr = x_train.T @ x_train
    XYtr = x_train.T @ sps_train
    nlam = len(lambdas)
    msetest = np.zeros(nlam)
    w_ridge = np.zeros((x_train.shape[-1], nlam))

    # loop over regularization strength
    for l, lam in enumerate(lambdas):
        # MAP estimate (equivalent of \ (left divide) in matlab)
        w = np.linalg.solve(XXtr + lam*Cinv, XYtr)
        w_ridge[:, l] = w
        msetest[l] = np.mean((sps_test - x_test@w)**2)

    # keep the weights with the lowest held-out error
    ridge_rf = w_ridge[:, np.argmin(msetest)]
    # Strip the bias and movement weights; what is left is the stimulus filter.
    # Its shape comes from the stimulus itself rather than from a notebook
    # global, which a Ray worker could hold in a stale state.
    n_move = move_test.shape[-1]
    sta_all = np.reshape(ridge_rf[:-(1+n_move)], train_data.shape[1:])

    # predicted activity on the test split
    sp_pred = x_test@ridge_rf
    # boxcar-smooth actual and predicted counts and scale by the window duration
    boxcar = np.ones(bin_length)
    sp_smooth = np.convolve(sps_test, boxcar, 'same') / (bin_length * model_dt)
    pred_smooth = np.convolve(sp_pred, boxcar, 'same') / (bin_length * model_dt)
    cc_all = np.corrcoef(sp_smooth, pred_smooth)[0, 1]
    return cc_all, sta_all, sp_smooth, pred_smooth

In [None]:
# Launch one Ray task per (unit, regressor subset, lag) and collect the results.
start = time.time()

titles = np.array(['th','phi','roll','pitch'])
# Candidate models: every non-empty subset of the four movement regressors,
# ordered by subset size and then in itertools.combinations order.
model_perms = [np.array(combo) for n in range(1,5) for combo in itertools.combinations([0,1,2,3], n)]
titles_all = ['_'.join(titles[p]) for p in model_perms]

move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis],train_dth[:,np.newaxis],train_dphi[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis],test_dth[:,np.newaxis],test_dphi[:,np.newaxis]))

lag_list = [ -4, -2, 0 , 2, 4]
lambdas = 1024 * (2**np.arange(0,16))


def _build_prior(n_unpenalized, use_smoothness=False):
    """Return the prior matrix for `nk` pixel weights plus `n_unpenalized` free weights.

    The pixel block is the identity (L2 prior).  With ``use_smoothness=True``
    the squared-first-difference smoothness prior over the ``nks`` image grid
    is added to it.  The trailing block is all zeros, so the bias and movement
    weights are not regularised.  Reads the notebook globals ``nk`` and ``nks``.
    """
    free_block = np.zeros((n_unpenalized, n_unpenalized))
    prior = linalg.block_diag(np.eye(nk), free_block)
    if use_smoothness:
        consecutive = np.ones((nk, 1))
        consecutive[nks[1]-1::nks[1]] = 0   # no horizontal difference across a row boundary
        diff = np.array([[-1., 1.]])
        Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1,nk))
        Dxy = sparse.diags((np.ones((nk,1))@ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
        Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
        prior = prior + linalg.block_diag(Dx.toarray(), free_block)
    return prior


# Large inputs go into the object store once and are passed by reference.
train_nsp_r = ray.put(train_nsp)
test_nsp_r = ray.put(test_nsp)
train_data_r = ray.put(train_vid)
test_data_r = ray.put(test_vid)
move_train_r = ray.put(move_train)
move_test_r = ray.put(move_test)
# The prior depends only on HOW MANY movement regressors a model has (plus the
# bias), so it is built once per subset size and shared, instead of being
# rebuilt and re-serialised for every task.  L2-only prior, as before; pass
# use_smoothness=True to _build_prior to add the smoothness term.
prior_refs = {n: ray.put(_build_prior(1 + n)) for n in range(1, 5)}

# Task order is unit-major, then model, then lag; the reshape of the stacked
# results in the following cells relies on exactly this order.
result_ids = [
    do_glm_fit.remote(train_nsp_r, test_nsp_r, train_data_r, test_data_r, move_train_r, move_test_r, celln, p, lag, prior_refs[len(p)], lambdas)
    for celln in range(train_nsp.shape[1])
    for p in model_perms
    for lag in lag_list
]

results_p = ray.get(result_ids)
print('GLM: ', time.time()-start)

In [None]:
# Each task returned (cc, sta, sp_smooth, pred_smooth); regroup the list of
# per-task tuples into one sequence per output and stack along a new task axis.
cc_parts, sta_parts, sp_parts, pred_parts = zip(*results_p)
cc_all = np.stack(cc_parts)
sta_all = np.stack(sta_parts)
sp_smooth = np.stack(sp_parts)
pred_smooth = np.stack(pred_parts)

In [None]:
# Unfold the flat task axis into (unit, model, lag); any trailing axes are kept as they are.
lead_shape = (model_nsp.shape[1], len(titles_all), len(lag_list))
cc_all = cc_all.reshape(lead_shape + cc_all.shape[1:])
sta_all = sta_all.reshape(lead_shape + sta_all.shape[1:])
sp_smooth = sp_smooth.reshape(lead_shape + sp_smooth.shape[1:])
pred_smooth = pred_smooth.reshape(lead_shape + pred_smooth.shape[1:])

In [None]:
# For every unit pick the (model, lag) pair with the highest test correlation.
# An argmax over the flattened (model, lag) axes yields exactly ONE entry per
# unit, so m_cells[i] == i always holds.  Comparing against the maximum with
# np.where would drop units whose correlations are NaN and duplicate units
# with ties, which would misalign every array indexed by unit afterwards.
n_cells = cc_all.shape[0]
cc_flat = cc_all.reshape(n_cells, -1)
# NaN correlations are ranked last; a unit whose fits are all NaN falls back to index 0
best_flat = np.argmax(np.where(np.isnan(cc_flat), -np.inf, cc_flat), axis=1)
m_models, m_lags = np.unravel_index(best_flat, cc_all.shape[1:])
m_cells = np.arange(n_cells)

In [None]:
# Pull out, for each selected (unit, model, lag) triple, the correlation,
# receptive field and the two smoothed traces.
best_fit = (m_cells, m_models, m_lags)
mcc = cc_all[best_fit]
msta = sta_all[best_fit]
msp = sp_smooth[best_fit]
mpred = pred_smooth[best_fit]

In [None]:
import plotly.express as px

fig = px.imshow(msta, animation_frame=0, binary_string=False,color_continuous_scale='RdBu_r')
fig.update_layout(width=500,
                  height=500,
                 )
fig.show()

In [None]:
##### Explore Neurons #####
ind = 25
fig, axs = plt.subplots(1,2, figsize=((15,5))) #np.floor(7.5*len(model_nsp)).astype(int)
axs[0].plot(msp[ind],'k',label='test FR')
axs[0].plot(mpred[ind],'r', label='pred FR')
axs[0].legend()
axs[0].set_title('cc={:.2f}, {}, \n lag={:d}'.format(mcc[ind],titles_all[m_models[ind]],lag_list[m_lags[ind]]))
axs[1].imshow(msta[ind],cmap='seismic')
axs[1].axis('off')
plt.tight_layout()

In [None]:
# Write the best fit of every unit to a PDF: one row per unit (smoothed actual
# vs. predicted trace on the left, receptive field on the right), ten units per page.
rows_per_page = 10
n_best = len(mcc)   # one entry per selected best fit
with PdfPages(FigPath/ 'MaxCC_{}.pdf'.format(model_type)) as pdf:
    for page_start in range(0, n_best, rows_per_page):
        fig, axs = plt.subplots(rows_per_page,2, figsize=((15,3*rows_per_page)))
        page_inds = range(page_start, min(page_start + rows_per_page, n_best))
        for row, ind in enumerate(page_inds):
            axs[row,0].plot(msp[ind],'k',label='smoothed FR')
            axs[row,0].plot(mpred[ind],'r', label='pred FR')
            axs[row,0].set_title('cc={:.2f}, {}, \n lag={:d}'.format(mcc[ind],titles_all[m_models[ind]],lag_list[m_lags[ind]]))
            axs[row,1].imshow(msta[ind])
            axs[row,1].axis('off')
        # blank the unused rows of a final, partially filled page
        for row in range(len(page_inds), rows_per_page):
            axs[row,0].axis('off')
            axs[row,1].axis('off')
        plt.tight_layout()
        pdf.savefig(fig)
        # release the figure so that many large pages do not pile up in memory
        plt.close(fig)

In [None]:
celln=33
model_ind=13
fig, axs = plt.subplots(2,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),10))
for lag_ind, lag in enumerate(lag_list):
    axs[0,lag_ind].plot(sp_smooth[celln,model_ind,lag_ind],'k',label='smoothed FR')
    axs[0,lag_ind].plot(pred_smooth[celln,model_ind,lag_ind],'r', label='pred FR')
    axs[0,lag_ind].set_title('cc={:.2f}'.format(cc_all[celln,model_ind,lag_ind]))
    axs[1,lag_ind].imshow(sta_all[celln,model_ind,lag_ind])
    axs[1,lag_ind].set_title('lag={:d}'.format(lag_list[lag_ind]))
    axs[1,lag_ind].axis('off')
    plt.suptitle(titles_all[model_ind])
    plt.tight_layout()

In [None]:
cc_all.reshape(n_units,len(lag_list))

In [None]:
@ray.remote
def do_glm_fit(model_nsp, x, celln, lag, model_dt, lambdas, lag_list, test_frac=.3, bin_length=80):
    """Fit a regularised linear receptive field for one unit at one lag.

    The first ``int(n_samples * test_frac)`` samples form the test block and
    the rest the training block.  For every value in ``lambdas`` the weights
    solve ``(X'X + lambda * Cinv) w = X'y``; the lambda with the lowest mean
    squared error on the test block is kept.

    The prior matrix ``Cinv`` and the image shape ``nks`` are NOT parameters:
    they are read from the notebook's global namespace, so the cell that
    defines them must have been run before this function is invoked.

    NOTE: another cell of this notebook defines a function with the same name
    but a different signature; whichever cell was executed last is the one
    that ``do_glm_fit.remote`` calls.

    Args:
        model_nsp: 2-D array indexed ``[unit, sample]``.
        x: design matrix indexed ``[sample, feature]``; the last column is
            treated as the bias term and dropped from the returned filter.
        celln: row of ``model_nsp`` to model.
        lag: the spike train is circularly shifted by ``-lag`` samples
            (``np.roll``), so samples wrap around the ends.
        model_dt: sample duration used to scale the smoothed traces.
        lambdas: sequence of regularisation strengths to try.
        lag_list: unused; kept so existing call sites keep working.
        test_frac: fraction of samples (taken from the start) used for testing.
        bin_length: length in samples of the boxcar used for smoothing.

    Returns:
        ``(cc, sta, sp_smooth, pred_smooth)``: correlation between smoothed
        actual and predicted test activity, the filter reshaped to ``nks``,
        and the two smoothed test traces.
    """
    sps = np.roll(model_nsp[celln, :], -lag)
    # split training and test data
    ntest = int(len(sps)*test_frac)
    x_train, x_test = x[ntest:, :], x[:ntest, :]
    sps_train, sps_test = sps[ntest:], sps[:ntest]

    XXtr = x_train.T @ x_train
    XYtr = x_train.T @ sps_train
    # Buffers are sized from the arguments actually passed in, not from the
    # notebook globals nlam / nk, which can disagree with `lambdas` and `x`.
    nlam = len(lambdas)
    msetest = np.zeros(nlam)
    w_ridge = np.zeros((x_train.shape[1], nlam))

    # loop over regularization strength
    for l, lam in enumerate(lambdas):
        # MAP estimate (equivalent of \ (left divide) in matlab)
        w = np.linalg.solve(XXtr + lam*Cinv, XYtr)
        w_ridge[:, l] = w
        msetest[l] = np.mean((sps_test - x_test@w)**2)

    # keep the weights with the lowest held-out error
    ridge_rf = w_ridge[:, np.argmin(msetest)]
    # drop the bias weight and fold the rest back into image shape
    sta_all = np.reshape(ridge_rf[:-1], nks)

    # predicted activity on the test block
    sp_pred = x_test@ridge_rf
    # boxcar-smooth actual and predicted counts and scale by the window duration
    boxcar = np.ones(bin_length)
    sp_smooth = np.convolve(sps_test, boxcar, 'same') / (bin_length * model_dt)
    pred_smooth = np.convolve(sp_pred, boxcar, 'same') / (bin_length * model_dt)
    cc_all = np.corrcoef(sp_smooth, pred_smooth)[0, 1]
    return cc_all, sta_all, sp_smooth, pred_smooth

In [None]:
start = time.time()
# Put the shared inputs into the Ray object store once; tasks receive references.
model_nsp_r = ray.put(model_nsp.T)
x_r = ray.put(x)
model_dt_r = ray.put(model_dt)
# One task per (unit, lag), unit-major, so the stacked results can later be
# reshaped to (n_units, n_lags, ...).
result_ids = [
    do_glm_fit.remote(model_nsp_r, x_r, celln, lag, model_dt_r, lambdas, lag_list, test_frac=.3, bin_length=80)
    for celln in range(model_nsp.shape[1])
    for lag in lag_list
]
results_p = ray.get(result_ids)
print('GLM: ', time.time()-start)

In [None]:
# Each task returned (cc, sta, sp_smooth, pred_smooth); regroup the list of
# per-task tuples into one sequence per output and stack along a new task axis.
cc_parts, sta_parts, sp_parts, pred_parts = zip(*results_p)
cc_all = np.stack(cc_parts)
sta_all = np.stack(sta_parts)
sp_smooth = np.stack(sp_parts)
pred_smooth = np.stack(pred_parts)

In [None]:
# Unfold the flat task axis into (unit, lag); any trailing axes are kept as they are.
lead_shape = (model_nsp.shape[1], len(lag_list))
cc_all = cc_all.reshape(lead_shape + cc_all.shape[1:])
sta_all = sta_all.reshape(lead_shape + sta_all.shape[1:])
sp_smooth = sp_smooth.reshape(lead_shape + sp_smooth.shape[1:])
pred_smooth = pred_smooth.reshape(lead_shape + pred_smooth.shape[1:])

In [None]:
cc_all.shape,sta_all.shape,sp_smooth.shape,pred_smooth.shape,

In [None]:
plt.imshow(sta_all[51,2])
plt.colorbar()

In [None]:
# Regularised linear receptive-field fit for EVERY unit at several lags (serial version).
# The design matrix is the z-scored, flattened video plus a bias column; the
# prior is identity + spatial smoothness, and its strength is picked from
# `lambdas` by the lowest MSE on the held-out block.
model_vid = model_vid_sm
model_dt = .1
nks = np.shape(model_vid)[1:]; nk = nks[0]*nks[1]
nT = np.shape(model_nsp)[0]
x = model_vid.reshape(model_nsp.shape[0], -1).copy()
# image dimensions
n_units = np.shape(model_nsp)[1]
# subtract mean and renormalize -- necessary? 
# NOTE(review): a pixel that is constant over time has std 0 and would turn
# into NaN/inf here -- confirm the video has no such pixels.
mn_img = np.mean(x,axis=0)
x = x-mn_img
x = x/np.std(x,axis =0)
x = np.append(x,np.ones((nT,1)), axis = 1) # append column of ones

# set up prior matrix (regularizer)
# L2 prior (the trailing 1x1 zero block leaves the bias weight unpenalised)
Imat = np.eye(nk)
Imat = linalg.block_diag(Imat,np.zeros((1,1)))
# smoothness prior: squared first differences along both image axes
consecutive = np.ones((nk, 1))
consecutive[nks[1]-1::nks[1]] = 0   # no horizontal difference across a row boundary
diff = np.zeros((1,2))
diff[0,0] = -1
diff[0,1]= 1
Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1,nk))
Dxy = sparse.diags((np.ones((nk,1))@ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
D  = linalg.block_diag(Dx.toarray(),np.zeros((1,1)))   
# summed prior matrix
Cinv = D + Imat
lag_list = [ -4, -2, 0 , 2, 4]
lambdas = 1024 * (2**np.arange(0,16))
nlam = len(lambdas)
# set up empty arrays for receptive field and cross correlation
sta_all = np.zeros((n_units, len(lag_list), nks[0], nks[1]))
cc_all = np.zeros((n_units,len(lag_list)))

test_frac = 0.3
bin_length = 80
# The train/test split and the Gram matrix depend only on the stimulus, not on
# the unit or the lag, so they are computed once instead of
# n_units * len(lag_list) times.  The FIRST `ntest` samples are the test block.
ntest = int(nT*test_frac)
x_train = x[ntest:,:]
x_test = x[:ntest,:]
XXtr = x_train.T @ x_train

# iterate through units
for celln in tqdm(range(n_units)):
    unit_spikes = model_nsp[:, celln]
    # iterate through timing lags
    for lag_ind, lag in enumerate(lag_list):
        # circular shift of the spike train (samples wrap around the ends)
        sps = np.roll(unit_spikes,-lag)
        sps_train = sps[ntest:]
        sps_test = sps[:ntest]
        XYtr = x_train.T @sps_train
        msetrain = np.zeros((nlam,1))
        msetest = np.zeros((nlam,1))
        w_ridge = np.zeros((nk+1,nlam))
        # loop over regularization strength
        for l in range(nlam):
            # calculate MAP estimate
            w = np.linalg.solve(XXtr + lambdas[l]*Cinv, XYtr) # equivalent of \ (left divide) in matlab
            w_ridge[:,l] = w
            # calculate test and training mean squared error
            msetrain[l] = np.mean((sps_train - x_train@w)**2)
            msetest[l] = np.mean((sps_test - x_test@w)**2)
        # select the lambda with the lowest held-out error
        best_lambda = np.argmin(msetest)
        w = w_ridge[:,best_lambda]
        ridge_rf = w
        # drop the bias weight and fold the rest back into image shape
        sta_all[celln,lag_ind,:,:] = np.reshape(w[:-1],nks)
        # predicted activity on the held-out block
        sp_pred = x_test@ridge_rf
        # boxcar-smooth actual and predicted counts and scale by the window duration
        sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
        pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
        cc = np.corrcoef(sp_smooth, pred_smooth)
        cc_all[celln,lag_ind] = cc[0,1]
        
# # figure of receptive fields
# fig = plt.figure(figsize=(25,np.int(np.ceil(n_units/3))),dpi=50)
# for celln in tqdm(range(n_units)):
#     for lag_ind, lag in enumerate(lag_list):
#         crange = np.max(np.abs(sta_all[celln,:,:,:]))
#         plt.subplot(n_units,6,(celln*6)+lag_ind + 1)  
#         plt.imshow(sta_all[celln, lag_ind, :, :], vmin=-crange, vmax=crange, cmap='jet')
#         plt.title('cc={:.2f}'.format (cc_all[celln,lag_ind]),fontsize=5)

In [None]:
# figure of receptive fields
fig = plt.figure(figsize=(25,256),dpi=50)
for celln in tqdm(range(n_units)):
    for lag_ind, lag in enumerate(lag_list):
        crange = np.max(np.abs(sta_all[celln,:,:,:]))
        plt.subplot(n_units,6,(celln*6)+lag_ind + 1)  
        plt.imshow(sta_all[celln, lag_ind, :, :], vmin=-crange, vmax=crange, cmap='jet')
        plt.title('cc={:.2f}'.format (cc_all[celln,lag_ind]),fontsize=5)
        plt.axis('off')
plt.tight_layout()

In [None]:
fig.savefig(save_dir/'STA1_5.pdf')

# Linear Models

In [None]:
import joblib
from ray.util.joblib import register_ray
from sklearn import linear_model as lm # MultiTaskLassoCV, RidgeCV, MultiTaskElasticNetCV, LinearRegression
from sklearn import svm
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.multioutput import MultiOutputRegressor
from sklearn.ensemble import GradientBoostingRegressor

## Regression on Movement

In [None]:
Y_train = train_th #np.stack((train_th, train_phi),axis=1) # StandardScaler().fit_transform() # train_vid.reshape(train_vid.shape[0],-1)#[:,10:11] # np.stack((train_roll, train_pitch),axis=1) # 
Y_test = test_th #np.stack((test_th, test_phi),axis=1) # StandardScaler().fit_transform() # test_vid.reshape(test_vid.shape[0],-1)#[:,10:11] # np.stack((test_roll, test_pitch),axis=1) # 

In [None]:
model_type = 'ridgecv'
if model_type == 'elasticnetcv':
    model = make_pipeline(StandardScaler(), lm.ElasticNetCV()) # lm.RidgeCV(alphas=np.arange(100,10000,1000))) #  #MultiOutputRegressor(lm.Ridge(),n_jobs=-1)) 
elif model_type == 'ridgecv':
    model = make_pipeline(StandardScaler(), lm.RidgeCV(alphas=np.arange(100,10000,1000)))

model.fit(train_nsp, Y_train)
pred_train = model.predict(train_nsp)
pred_test = model.predict(test_nsp)
train_score = model.score(train_nsp,Y_train)
test_score = model.score(test_nsp, Y_test)
print('Train Score:', train_score, 'Test Score:', test_score)
# print(model['model_type'].coef_[22])

##### Flip test and train #####
# model2 = make_pipeline(StandardScaler(), lm.RidgeCV())
# model2.fit(test_nsp, Y_test)
# pred_train = model2.predict(test_nsp)
# pred_test = model2.predict(train_nsp)
# train_score = model2.score(test_nsp,Y_test)
# test_score = model2.score(train_nsp, Y_train)
# print('Train Score:', train_score, 'Test Score:', test_score)
# print(model2['ridgecv'].coef_[22])


In [None]:
model[model_type].alpha_

In [None]:
Y_train.shape

In [None]:
t = 0
dt = 10000
plt.plot(np.arange(t,t+dt),Y_train[t:t+dt])
plt.plot(np.arange(t,t+dt), pred_train[t:t+dt])

In [None]:
t = 100
dt = 100
fig, axs = plt.subplots(1,figsize=(7,5))
cc = np.corrcoef(Y_test,pred_test)[0,1]
axs.plot(np.arange(t,t+dt)*model_dt,Y_test[t:t+dt], 'k', label='Ground Truth')
axs.plot(np.arange(t,t+dt)*model_dt,pred_test[t:t+dt], 'r', label='Prediction')
axs.set_title('CorrCoeff: {:.02f}'.format(cc))
axs.set_xlabel('Time (s)')
# axs.set_ylabel('Eye Phi Angle')
axs.legend()
plt.tight_layout()
# fig.savefig(FigPath/'LinearRegressionExample_phi.png',bbox_inches='tight',transparent=False, facecolor='w')

In [None]:
np.corrcoef(Y_train,pred_train)

In [None]:
np.corrcoef(Y_test,pred_test)

In [None]:
plt.scatter(Y_test,pred_test, alpha=.1)

In [None]:
# plt.plot(model['elasticnetcv'].coef_)
plt.plot(model['ridgecv'].coef_)


In [None]:
pred_test.shape,Y_test.shape

# Autocorrelation of the th, vid 

In [None]:
from scipy.ndimage import uniform_filter1d 

In [None]:
np.corrcoef(test_th, test_nsp[:,22])

In [None]:
xcorr_data = plt.xcorr(test_th, test_nsp[:,22], maxlags=100)
lags, xscore = xcorr_data[0], xcorr_data[1]

In [None]:
lags[np.argmax(xscore)], xscore[np.argmax(xscore)],lags[np.argmin(xscore)], xscore[np.argmin(xscore)]

In [None]:
plt.acorr(train_th, maxlags=100)

## Regression on Video

In [None]:
train_vid, test_vid, train_nsp, test_nsp, train_th, test_th, train_phi, test_phi, train_roll, test_roll, train_pitch, test_pitch, train_t, test_t, train_dth, test_dth, train_dphi, test_dphi = \
train_test_split(model_vid_sm, model_nsp, model_th, model_phi, model_roll, model_pitch, model_t, model_dth, model_dphi, train_size=.6, shuffle=False, random_state=0)

In [None]:
Y_train = train_vid.reshape(train_vid.shape[0],-1)#[:,10:11] # np.stack((train_roll, train_pitch),axis=1) # 
Y_test = test_vid.reshape(test_vid.shape[0],-1)#[:,10:11] # np.stack((test_roll, test_pitch),axis=1) # 

In [None]:
@ray.remote
def multi_regression(train_nsp,Y_train,test_nsp,Y_test,idx,model_type):
    """Regress a single target column on the predictors with a standardised linear model.

    Args:
        train_nsp, test_nsp: predictor matrices passed to ``fit`` / ``predict``.
        Y_train, Y_test: 2-D target arrays; only column ``idx`` is used.
        idx: index of the target column to fit.
        model_type: ``'elasticnetcv'`` or ``'ridgecv'``.  It is also the name
            under which the estimator is looked up in the pipeline.

    Returns:
        ``(pred_train, pred_test, train_score, test_score, model_coeff, alpha)``
        where the scores are Pearson correlations between prediction and
        target, and ``alpha`` is the regularisation strength the
        cross-validated estimator selected.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    if model_type == 'elasticnetcv':
        estimator = lm.ElasticNetCV()
    elif model_type == 'ridgecv':
        estimator = lm.RidgeCV(alphas=np.arange(100,10000,1000))
    else:
        # fail with a clear message instead of an UnboundLocalError further down
        raise ValueError("model_type must be 'elasticnetcv' or 'ridgecv', got {!r}".format(model_type))
    model = make_pipeline(StandardScaler(), estimator)

    target_train = Y_train[:, idx]
    target_test = Y_test[:, idx]
    model.fit(train_nsp, target_train)
    pred_train = model.predict(train_nsp)
    pred_test = model.predict(test_nsp)

    # correlation (not R^2) between prediction and target
    train_score = np.corrcoef(pred_train, target_train)[0,1]
    test_score = np.corrcoef(pred_test, target_test)[0,1]

    fitted = model[model_type]
    model_coeff = fitted.coef_
    alphas = fitted.alpha_
    return pred_train, pred_test, train_score, test_score, model_coeff, alphas


In [None]:
model_type = 'elasticnetcv'

start = time.time()
# Put the shared inputs into the Ray object store once; tasks receive references.
train_nsp_r = ray.put(train_nsp)
Y_train_r = ray.put(Y_train)
test_nsp_r = ray.put(test_nsp)
Y_test_r = ray.put(Y_test)
# One regression task per pixel of the flattened video frame, in pixel order.
n_pixels = train_vid.shape[-1]*train_vid.shape[-2]
result_ids = [
    multi_regression.remote(train_nsp_r,Y_train_r,test_nsp_r,Y_test_r,idx,model_type)
    for idx in range(n_pixels)
]
results_p = ray.get(result_ids)
print('MultiReg Time: ', time.time() - start)

In [None]:
# Each task returned (pred_train, pred_test, train_score, test_score, coef, alpha);
# regroup the list of per-task tuples into one sequence per output.
ptrain_parts, ptest_parts, trscore_parts, tescore_parts, coef_parts, alpha_parts = zip(*results_p)
pred_train = np.stack(ptrain_parts)
pred_test = np.stack(ptest_parts)
train_scores = np.array(trscore_parts)
test_scores = np.array(tescore_parts)
model_coeff = np.array(coef_parts)
alphas = np.array(alpha_parts)


In [None]:
pred_train = pred_train.T.reshape(pred_train.shape[-1],train_vid.shape[1],train_vid.shape[2])
pred_test = pred_test.T.reshape(pred_test.shape[-1],test_vid.shape[1],test_vid.shape[2])
model_coeff = model_coeff.T.reshape(train_nsp.shape[-1],train_vid.shape[1],train_vid.shape[2])

In [None]:
# ElasticNet_data = {
#                 'train_scores': train_scores,
#                 'test_scores': test_scores,
#                 'pred_train': pred_train,
#                 'pred_test': pred_test, }
# ioh5.save(save_dir/'ElasticNet_data.h5',ElasticNet_data)

# Ridge_data = ioh5.load(save_dir/'RidgeData.h5')
# locals().update(ElasticNet_data)

In [None]:
# import plotly.express as px
# t = 200
# dt = 500
# comb = np.concatenate((pred_train[t:t+dt,np.newaxis,:,:], train_vid[t:t+dt,np.newaxis,:,:]),axis=1)

# fig = px.imshow(comb, animation_frame=0, facet_col=1, binary_string=False)
# fig.update_layout(width=1000,
#                   height=500,
#                  )
# fig.show()

Need to look at the decoding weights and see whether they resemble receptive fields.

In [None]:
pred_train.shape

## Plotting Decoded Video

In [None]:
import cv2
import torchvision
from scipy.ndimage import uniform_filter1d 

In [None]:
t = 2000 #2000
dt = 50
im_grid = torchvision.utils.make_grid(torch.from_numpy(pred_test[t:t+dt,np.newaxis,:,:]),nrow=10,normalize=False)[0]
im_grid2 = torchvision.utils.make_grid(torch.from_numpy(test_vid[t:t+dt,np.newaxis,:,:]),nrow=10,normalize=False)[0]
fig, axs = plt.subplots(2,1,figsize=(20,10))
axs[0].imshow(im_grid, cmap='gray')#.permute(1,2,0))
axs[0].set_title('Decoding Prediction')
axs[1].imshow(im_grid2, cmap='gray')#.permute(1,2,0))
axs[1].set_title('Actual Frame')
plt.tight_layout()
fig.savefig(FigPath/'DecodedMontage_{}.png'.format(model_type),bbox_inches='tight',transparent=False, facecolor='w')

In [None]:
im_grid = torchvision.utils.make_grid(torch.from_numpy(model_coeff[:,np.newaxis]),nrow=10,normalize=False)[0]
fig, axs = plt.subplots(1,1,figsize=(10,10))
axs.imshow(im_grid, cmap='gray')#.permute(1,2,0))
axs.set_title('Decoding Coeff')
plt.tight_layout()
fig.savefig(FigPath/'DecodingWeights_{}.png'.format(model_type),bbox_inches='tight',transparent=False, facecolor='w')

In [None]:
# Normalise and spatially upsample (by `sf`) the decoded and actual frames of
# both splits, then stack prediction and ground truth of the chosen split.
sf = 2
pred_test_norm = normimgs(pred_test)
pred_test_up = np.zeros((pred_test.shape[0],sf*pred_test.shape[1],sf*pred_test.shape[2]))
test_vid_norm = normimgs(test_vid)
test_vid_up = np.zeros((test_vid.shape[0],sf*test_vid.shape[1],sf*test_vid.shape[2]))
pred_train_norm = normimgs(pred_train)
pred_train_up = np.zeros((pred_train.shape[0],sf*pred_train.shape[1],sf*pred_train.shape[2]))
train_vid_norm = normimgs(train_vid)
train_vid_up = np.zeros((train_vid.shape[0],sf*train_vid.shape[1],sf*train_vid.shape[2]))
# Each split is resized over ITS OWN number of frames.  Driving both splits
# from the test length leaves train frames beyond that length as zeros (or
# raises IndexError when the test split is the longer one).
# cv2.resize takes the target size as (width, height).
for n in range(pred_test.shape[0]):
    pred_test_up[n] = cv2.resize(pred_test_norm[n],(sf*pred_test.shape[2],sf*pred_test.shape[1]))
    test_vid_up[n] = cv2.resize(test_vid_norm[n],(sf*test_vid.shape[2],sf*test_vid.shape[1]))
for n in range(pred_train.shape[0]):
    pred_train_up[n] = cv2.resize(pred_train_norm[n],(sf*pred_train.shape[2],sf*pred_train.shape[1]))
    train_vid_up[n] = cv2.resize(train_vid_norm[n],(sf*train_vid.shape[2],sf*train_vid.shape[1]))

# axis 0 of tot_samps: 0 = prediction, 1 = actual
cond = 'test'
if cond == 'train':
    tot_samps = np.stack((pred_train_up, train_vid_up))
else:
    tot_samps = np.stack((pred_test_up, test_vid_up))
tot_samps.shape

In [None]:
# # Example Frames Video
# t = 0
# dt = pred_test.shape[0]
# # comb = np.concatenate((normimgs(pred_test),normimgs(test_vid)),axis=2)
# comb = np.concatenate((pred_test_up,test_vid_up),axis=2).astype(np.uint8)
# # comb = (comb - np.min(comb,axis=(-1,-2))[:,np.newaxis,np.newaxis])/(np.max(comb,axis=(-1,-2))-np.min(comb,axis=(-1,-2)))[:,np.newaxis,np.newaxis]
# # comb = (comb*255).astype(np.uint8)

# FPS = 10
# out = cv2.VideoWriter(os.path.join(FigPath,'Frames_ExVid.avi'), cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), FPS, (comb.shape[-1], comb.shape[-2]),0)
            
# for fm in tqdm(range(comb.shape[0])):
#     out.write(comb[fm])
# out.release()

In [None]:
# ###### Grab data of longest continuous sequence ######
# def func1(a,b):
#     # "Enclose" mask with sentients to catch shifts later on
#     mask = np.r_[False,a,False]

#     # Get the shifting indices
#     idx = np.flatnonzero(mask[1:] != mask[:-1])

#     s0,s1 = idx[::2], idx[1::2]
#     idx_b = np.r_[0,(s1-s0).cumsum()]
#     out = []
#     for (i,j,k,l) in zip(s0,s1-1,idx_b[:-1],idx_b[1:]):
#         out.append(((i, j), b[k:l]))
#     return out

# train_idxs,test_idxs = train_test_split(good_idxs,train_size=.6,random_state=0)

# out = func1(test_idxs,np.arange(test_idxs.shape[0]))

# max_seqn = 0
# for n in range(len(out)):
#     if len(out[n][1]) > max_seqn:
#         max_seq = np.arange(out[n][0][0],out[n][0][1])
#         max_seqn = len(out[n][1])

In [None]:
win_size = 3
tot_samps2 = uniform_filter1d(tot_samps,win_size,axis=1)

In [None]:
t = 500
dt = 100
plt.plot(tot_samps2[0,t:t+dt,5,10])
plt.plot(tot_samps[1,t:t+dt,5,10])
plt.plot(tot_samps2[1,t:t+dt,5,10])
plt.legend(['Pred','Actual','Actual_smoothed'])

In [None]:
from matplotlib.animation import FuncAnimation, PillowWriter, FFMpegWriter
from matplotlib import colors
def init():
    """Animation init hook: switch off the axes of both panels and tighten the layout."""
    for panel in (axs[0], axs[1]):
        panel.axis('off')
    plt.tight_layout()

def update(t):
    """Animation frame hook: push frame ``t`` of each stream in ``tot_samps2`` to its image artist."""
    for stream, artist in enumerate((ims[0], ims[1])):
        artist.set_data(tot_samps2[stream, t])
    plt.draw()

In [None]:
# Build the two-panel figure (prediction | ground truth) whose images the animation updates.
t = 0  # first frame to show (alternative: max_seq[0])
lat_dims = 2
x, y = [], []
titles = ['Pred', 'Actual']
fig, axs = plt.subplots(1, 2, figsize=(8, 4))  # alternative layout: 8,16,figsize=(50,30)
axs = axs.flatten()
ims = []
for n, (ax, title) in enumerate(zip(axs, titles)):
    # each panel gets its own Normalize instance
    ims.append(ax.imshow(tot_samps2[n, t], cmap='gray', norm=colors.Normalize()))
    ax.axis('off')
    ax.set_title(title)
plt.tight_layout()
# fig.savefig(os.path.join(FigurePath,'testimg.png'))

In [None]:
# Render every frame of the smoothed stack and write the animation out as an mp4.
# writervideo = PillowWriter(fps=60)
n_frames = tot_samps2.shape[1]  # alternative source: tot_samps.shape[1]
ani = FuncAnimation(fig, update, frames=tqdm(range(n_frames)), init_func=init)
plt.show()
vpath = check_path(FigPath, 'version_{:d}'.format(0))
vname = 'DecodedVideo_{}_upsampled{:d}_smoothed{:d}_{}.mp4'.format(model_type, sf, win_size, cond)
writervideo = FFMpegWriter(fps=10)
ani.save(os.path.join(vpath, vname), writer=writervideo)
print('DONE!!!')

In [None]:
# Reshape the flat per-pixel scores to the frame's last two dimensions and plot them as maps.
train_scores = train_scores.reshape(tuple(train_vid.shape[-2:]))
test_scores = test_scores.reshape(tuple(test_vid.shape[-2:]))
fig, axs = plt.subplots(2, 1, figsize=(10, 10))
panels = (
    (axs[0], train_scores, 'Train Correlation Map'),
    (axs[1], test_scores, 'Test Correlation Map'),
)
score_ims = []
for ax, scores, title in panels:
    im = ax.imshow(scores, vmin=0, vmax=.55)  # same colour range for both maps
    ax.set_title(title)
    add_colorbar(im)
    score_ims.append(im)
im1, im2 = score_ims
plt.tight_layout()
fig.savefig(FigPath/'DecodingScores_{}.png'.format(model_type),bbox_inches='tight',transparent=False, facecolor='w')

In [None]:
# Montage of `dt` consecutive test frames: prediction stacked above ground truth,
# successive frames laid out left to right.
t = 2000
dt = 20
clip = slice(t, t + dt)
pairs = np.concatenate((pred_test[clip, :, :], test_vid[clip, :, :]), axis=1)
comb = np.hstack(list(pairs))
fig, ax = plt.subplots(1, figsize=(25, 20))
ax.imshow(comb)

# Pytorch

In [None]:
# train_roll = train_roll/np.max(train_roll)
# train_roll -=train_roll[0]
# train_pitch = train_pitch/np.max(train_pitch)
# train_pitch -=train_pitch[0]
# test_roll = test_roll/np.max(test_roll)
# test_roll -=test_roll[0]
# test_pitch = test_pitch/np.max(test_pitch)
# test_pitch -=test_pitch[0]

# Y_train = torch.from_numpy(np.stack((train_roll, train_pitch),axis=1)).float()
# Y_test  = torch.from_numpy(np.stack((test_roll, test_pitch),axis=1)).float()

# Regression targets: roll only, as a float column tensor.
# Alternatives tried earlier: flattened video frames (train_vid.reshape(train_vid.shape[0],-1)),
# a single column of them ([:,10:11]), or np.stack((train_roll, train_pitch),axis=1).
Y_train = torch.from_numpy(train_roll).float().unsqueeze(1)
Y_test = torch.from_numpy(test_roll).float().unsqueeze(1)

In [None]:
class DecodingDataset(Dataset):
    """Dataset returning a window of ``N_fm`` consecutive rows ending just before ``idx``.

    Each item is a pair of flattened float tensors: the window taken from
    ``data`` and the same window taken from ``output``.
    """

    def __init__(self, data, output, N_fm, transform=None):
        self.N_fm = N_fm
        self.transform = transform  # stored only; __getitem__ does not apply it
        self.data = data
        self.output = output

    def __len__(self):
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, idx):
        # Indices that fall inside the first window are redirected to the first full window.
        idx = self.N_fm if idx < self.N_fm else idx
        if torch.is_tensor(idx):
            idx = idx.tolist()

        window = slice(idx - self.N_fm, idx)
        sample = torch.from_numpy(self.data[window]).float()
        gt = torch.from_numpy(self.output[window, :]).float()
        return sample.view(-1), gt.view(-1)
    
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the current validation loss and the
    model. When the loss counts as an improvement the model's ``state_dict`` is
    written to ``path`` and the counter is reset; otherwise the counter is
    incremented and ``early_stop`` becomes True once it reaches ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every NumPy version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss (float): Validation loss of the current epoch (lower is better).
            model: Object exposing ``state_dict()``; checkpointed on improvement.
        """
        # Work with a score where larger is better.
        score = -val_loss

        if self.best_score is None:
            # First call: always becomes the reference and is checkpointed.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not an improvement of at least `delta`.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Hyper-parameters and data loaders for the fully-connected decoder.
N_fm = 1                                  # multiplier on input/output widths (window length used by DecodingDataset)
batch_size = 1024
in_neurons = train_nsp.shape[1]*N_fm      # input width of the network
out_neurons = 2048                        # width of the hidden layers
out_dims = Y_train.shape[-1]*N_fm         # output width of the network
NEpochs = 500
# train_dataset = DecodingDataset(train_nsp, np.stack((train_roll, train_pitch),axis=1), N_fm=N_fm)
# test_dataset = DecodingDataset(test_nsp, np.stack((test_roll, test_pitch),axis=1), N_fm=N_fm)
train_dataset = TensorDataset(torch.from_numpy(train_nsp).float(), Y_train)
test_dataset = TensorDataset(torch.from_numpy(test_nsp).float(), Y_test)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# The evaluation loader must keep the dataset order: predictions collected from it
# are concatenated and compared row-by-row against Y_test.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)


In [None]:
# Three-layer fully-connected decoder with ReLU activations, plus its optimiser,
# loss and early-stopping helper (checkpoint written inside save_dir).
layers = [
    nn.Linear(in_neurons, out_neurons),
    nn.ReLU(),
    nn.Linear(out_neurons, out_neurons),
    nn.ReLU(),
    nn.Linear(out_neurons, out_dims),
]
model = nn.Sequential(*layers).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
criteria = nn.MSELoss()
early_stopping = EarlyStopping(path=save_dir/'checkpoint.pt')

In [None]:
# Train for up to NEpochs, tracking the mean batch loss on both splits and
# stopping once the early-stopping helper reports no further improvement.
tot_loss = []
test_tot_loss = []
for epoch in tqdm(range(NEpochs)):
    # ---- optimisation pass over the training batches ----
    epoch_loss = []
    for batch, y in train_dataloader:
        optimizer.zero_grad()
        pred = model(batch.to(device))
        loss = criteria(pred, y.to(device))
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    train_avg = np.mean(epoch_loss)
    tot_loss.append(train_avg)

    # ---- evaluation pass, no gradients ----
    test_epoch_loss = []
    with torch.no_grad():
        for batch, y in test_dataloader:
            pred = model(batch.to(device))
            loss = criteria(pred, y.to(device))
            test_epoch_loss.append(loss.item())
    test_avg = np.mean(test_epoch_loss)
    test_tot_loss.append(test_avg)

    early_stopping(test_avg, model)
    if early_stopping.early_stop:
        print('Stopped Early!')
        break
    print('Epoch:', epoch, 'Epoch_Loss_Avg: ', train_avg, 'Test_Epoch_Loss_Avg: ', test_avg)

In [None]:
# Overlay the decoder output on the training targets for the first samples.
# The predictions are recomputed here from the training inputs in dataset order:
# the `pred` variable left behind by the training loop is only the last
# mini-batch of the final evaluation pass and does not line up with Y_train.
wind = np.arange(0, min(1000, Y_train.shape[0]))
with torch.no_grad():
    pred_wind = model(torch.from_numpy(train_nsp[wind]).float().to(device)).cpu().numpy()

names = ['roll', 'pitch']
n_out = Y_train.shape[-1]  # one panel per decoded output column
fig, ax = plt.subplots(n_out, 1, figsize=(20,10), squeeze=False)
ax = ax[:, 0]
for k in range(n_out):
    name = names[k] if k < len(names) else 'out{}'.format(k)
    ax[k].plot(Y_train[wind, k], 'b-', label=name)
    ax[k].plot(pred_wind[:, k], 'r-', label='pred_' + name)
    ax[k].legend()

In [None]:
# Collect predictions for the whole test set.
# A dedicated sequential loader is used so that row i of `predt` corresponds to
# row i of `Y_test`; iterating a shuffling loader would scramble that alignment.
eval_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
with torch.no_grad():
    predt = []
    for batch, _ in eval_loader:
        pred = model(batch.to(device))
        predt.append(pred.cpu().numpy())
    predt = np.concatenate(predt,axis=0)

In [None]:
# Overlay test-set predictions (red) on the test targets (blue) for the first 1000 samples.
# NOTE(review): both arrays are indexed at column 1 below, so two output columns
# (roll, pitch) are required; Y_test built from test_roll[:,np.newaxis] alone has only one.
# NOTE(review): the overlay is only meaningful if `predt` is in the same sample
# order as Y_test — confirm the loader that produced it does not shuffle.
wind = np.arange(0,1000)
fig, ax = plt.subplots(2,1,figsize=(20,10))
ax[0].plot(Y_test[wind,0],'b-', label='roll')
ax[0].plot(predt[wind,0],'r-', label='pred_roll')
ax[1].plot(Y_test[wind,1],'b-', label='pitch')
ax[1].plot(predt[wind,1],'r-', label='pred_pitch')