# Import Modules

In [None]:
%load_ext autoreload
%autoreload 2

import os
import argparse
import glob
import sys 
import yaml 
import glob
import h5py 
import ray
import logging 
import json
import gc
import cv2
import time
import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
# import io_dict_to_hdf5 as ioh5
import xarray as xr
import scipy.linalg as linalg
import scipy.sparse as sparse

from tqdm.notebook import tqdm, trange
from matplotlib.backends.backend_pdf import PdfPages
from scipy import signal
from pathlib import Path
from scipy.optimize import minimize_scalar
from scipy.interpolate import interp1d
from scipy.ndimage import shift as imshift
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from sklearn.pipeline import make_pipeline
from sklearn import linear_model as lm 
from scipy.stats import binned_statistic
from sklearn.utils import shuffle


# import torch
# import torch.nn as nn
# import torch.optim as optim
# import torch.nn.functional as F
# from torchvision import transforms
# from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
# torch.backends.cudnn.benchmark = True
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sys.path.append(str(Path('.').absolute().parent))
from utils import *
import io_dict_to_hdf5 as ioh5
from format_data import load_ephys_data_aligned

# Show every row when displaying pandas objects.
pd.set_option('display.max_rows', None)
# Output directory for figures; check_path comes from `from utils import *`
# (presumably it creates the directory if missing — TODO confirm).
FigPath = check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(),'Figures/Decoding')

# Start a local Ray runtime for the parallel GLM fits below. Re-running this
# cell is harmless (ignore_reinit_error) and only ERROR-level logs are shown.
ray.init(
    ignore_reinit_error=True,
    logging_level=logging.ERROR,
)
# Dashboard address as reported by Ray, then rewritten to localhost with the
# same port (presumably for access through a forwarded port).
print(f'Dashboard URL: http://{ray.get_dashboard_url()}')
print('Dashboard URL: http://localhost:{}'.format(ray.get_dashboard_url().split(':')[-1]))

To Do: 1. Make a Poisson GLM with an L2 (ridge) penalty and solve the equations.

# Gather Data

In [None]:
def load_train_test(file_dict, save_dir, model_dt=.1, frac=.2, do_shuffle=False, do_norm=False):
    """Load aligned ephys data, keep good frames, and add a blocked train/test split.

    Parameters
    ----------
    file_dict : dict
        Recording description, passed through to ``load_ephys_data_aligned``.
    save_dir : pathlib.Path
        Directory passed through to ``load_ephys_data_aligned``.
    model_dt : float
        Model time bin, passed through to ``load_ephys_data_aligned``.
    frac : float
        Size of each contiguous group used by ``GroupShuffleSplit``, as a
        fraction of the kept samples. ``int(1/frac)`` groups are formed; the
        last group absorbs any remainder.
    do_shuffle : bool
        If True, shuffle the order of the train and test indices.
    do_norm : bool
        If True, z-score ``model_th``, ``model_phi``, ``model_roll`` and
        ``model_pitch`` over the kept samples.

    Returns
    -------
    dict
        The loaded arrays filtered to good samples (``model_active`` is left
        unfiltered), plus ``raw_nsp``, ``model_dth``, ``model_dphi`` and the
        ``train_*`` / ``test_*`` splits of vid, nsp, th, phi, roll, pitch, t,
        dth, dphi and gz.
    """
    ##### Load in preprocessed data #####
    data = load_ephys_data_aligned(file_dict, save_dir, model_dt=model_dt)

    ##### Good timepoints: mouse active and no NaN in any array #####
    nan_idxs = [np.where(np.isnan(data[key]))[0] for key in data.keys()]
    good_idxs = np.ones(len(data['model_active']), dtype=bool)
    good_idxs[data['model_active'] < .5] = False
    good_idxs[np.unique(np.hstack(nan_idxs))] = False

    # NOTE(review): raw_nsp is added before the filtering loop, so it is also
    # filtered by good_idxs and ends up equal to model_nsp. If an unfiltered
    # copy was intended, move this after the loop (this changes its length;
    # check downstream users first).
    data['raw_nsp'] = data['model_nsp'].copy()
    ##### Keep only good samples (model_active itself is left as-is) #####
    for key in data.keys():
        if key != 'model_active':
            data[key] = data[key][good_idxs]

    ##### Blocked split: contiguous groups of ~frac*nT samples #####
    gss = GroupShuffleSplit(n_splits=1, train_size=.7, random_state=42)
    nT = data['model_nsp'].shape[0]
    n_groups = int(1/frac)
    # Group i covers [int(frac*(i-1)*nT), int(frac*i*nT)); the last edge is
    # pinned to nT so every sample gets a group even if 1/frac is fractional.
    edges = [int((frac*i)*nT) for i in range(n_groups + 1)]
    edges[-1] = nT
    groups = np.repeat(np.arange(1, n_groups + 1, dtype=float), np.diff(edges))

    train_idx, test_idx = next(gss.split(np.arange(nT), groups=groups))
    print("TRAIN:", train_idx, "TEST:", test_idx)
    if do_shuffle:
        train_idx = shuffle(train_idx, random_state=42)
        test_idx = shuffle(test_idx, random_state=42)

    # Sample-to-sample differences of the eye angles (last entry is 0 - x[-1]).
    data['model_dth'] = np.diff(data['model_th'], append=0)
    data['model_dphi'] = np.diff(data['model_phi'], append=0)

    def _zscore(x):
        # Z-score each column over time (axis 0) using its own mean and std.
        return (x - np.mean(x, axis=0)) / np.std(x, axis=0)

    data['model_vid_sm'] = _zscore(data['model_vid_sm'])
    if do_norm:
        for key in ('model_th', 'model_phi', 'model_roll', 'model_pitch'):
            data[key] = _zscore(data[key])

    ##### Split data by train/test (keys added as train_x, test_x pairs) #####
    split_sources = {
        'vid': 'model_vid_sm',
        'nsp': 'model_nsp',
        'th': 'model_th',
        'phi': 'model_phi',
        'roll': 'model_roll',
        'pitch': 'model_pitch',
        't': 'model_t',
        'dth': 'model_dth',
        'dphi': 'model_dphi',
        'gz': 'model_gz',
    }
    for short, src in split_sources.items():
        data['train_' + short] = data[src][train_idx]
        data['test_' + short] = data[src][test_idx]
    return data

In [None]:
# Session directory; file_dict.json describes the recording's source files
# (presumably the same structure as the literal dict in the next cell).
save_dir = Path('~/Research/SensoryMotorPred_Data/data/070921/J553RT/fm1').expanduser()
with open(save_dir / 'file_dict.json','r') as fp:
    file_dict = json.load(fp)

In [None]:
# Hard-coded description of the 070921/J553RT fm1 recording. Running this cell
# replaces any file_dict loaded from file_dict.json.
file_dict = {'cell': 0,
 'drop_slow_frames': True,
 'ephys': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_ephys_merge.json',
 'ephys_bin': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_Ephys.bin',
 'eye': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_REYE.nc',
 'imu': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_imu.nc',
 'mapping_json': '/home/seuss/Research/Github/FreelyMovingEphys/probes/channel_maps.json',
 'mp4': True,
 'name': '070921_J553RT_control_Rig2_fm1',
 'probe_name': 'DB_P128-6',
 'save': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1',
 'speed': None,
 'stim_type': 'light',
 'top': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_TOP1.nc',
 'world': '/home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1/070921_J553RT_control_Rig2_fm1_world.nc'}

In [None]:
# Load, filter and split the data, then expose every entry (model_*, train_*,
# test_*, raw_nsp) as a notebook variable. At notebook top level locals() is
# the global namespace, so this update takes effect.
data = load_train_test(file_dict, save_dir, do_shuffle=False, do_norm=False)
locals().update(data)

In [None]:
# model_dt = .1
# data = load_ephys_data_aligned(file_dict, save_dir, model_dt=model_dt)
# nan_idxs = []
# for key in data.keys():
#     nan_idxs.append(np.where(np.isnan(data[key]))[0])
# good_idxs = np.ones(len(data['model_active']),dtype=bool)
# good_idxs[data['model_active']<.5] = False
# good_idxs[np.unique(np.hstack(nan_idxs))] = False

# raw_nsp = data['model_nsp'].copy()

# for key in data.keys():
#     if (key != 'model_nsp') & (key != 'model_active'):
#         data[key] = data[key][good_idxs] # interp_nans(data[key]).astype(float)
#     elif (key == 'model_nsp'):
#         data[key] = data[key][good_idxs]
# locals().update(data)

# plt.hist(data['model_active'], bins=100)
# plt.axvline(x=.5)
# # movement_times = (data['model_active']>.5) & (~np.isnan(data['model_th']))

# ##### Print memory of local variables #####
# for name, size in sorted(((name, sys.getsizeof(value)) for name, value in locals().items()), key= lambda x: -x[1])[:10]:
#     print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))

# ##### Group shuffle #####
# do_shuffle = False
# do_norm = False

# gss = GroupShuffleSplit(n_splits=1, train_size=.7, random_state=42)
# nT = model_nsp.shape[0]
# frac = .2
# groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((.2*(i-1))*nT)) for i in range(1,int(1/frac)+1)])

# for train_idx, test_idx in gss.split(np.arange(len(model_nsp)), groups=groups):
#     print("TRAIN:", train_idx, "TEST:", test_idx)
# if do_shuffle:
#     train_idx, test_idx = shuffle(train_idx, test_idx, random_state=42)
# model_dth = np.diff(model_th,append=0)
# model_dphi = np.diff(model_phi,append=0)

# model_vid_sm = (model_vid_sm - np.mean(model_vid_sm,axis=0))/np.std(model_vid_sm,axis=0) 
# if do_norm:
#     model_th = (model_th - np.mean(model_th,axis=0))/np.std(model_th,axis=0) 
#     model_phi = (model_phi - np.mean(model_phi,axis=0))/np.std(model_phi,axis=0) 
#     model_roll = (model_roll - np.mean(model_roll,axis=0))/np.std(model_roll,axis=0) 
#     model_pitch = (model_pitch - np.mean(model_pitch,axis=0))/np.std(model_pitch,axis=0) 

# ##### Split Data by train/test #####
# train_vid = model_vid_sm[train_idx]
# test_vid = model_vid_sm[test_idx]
# train_nsp = model_nsp[train_idx]
# test_nsp = model_nsp[test_idx]
# train_th = model_th[train_idx]
# test_th = model_th[test_idx]
# train_phi = model_phi[train_idx]
# test_phi = model_phi[test_idx]
# train_roll = model_roll[train_idx]
# test_roll = model_roll[test_idx]
# train_pitch = model_pitch[train_idx]
# test_pitch = model_pitch[test_idx]
# train_t = model_t[train_idx]
# test_t = model_t[test_idx]
# train_dth = model_dth[train_idx]
# test_dth = model_dth[test_idx]
# train_dphi = model_dphi[train_idx]
# test_dphi = model_dphi[test_idx]
# train_gz = model_gz[train_idx]
# test_gz = model_gz[test_idx]

# Testing Tuning Curves

In [None]:
def tuning_curve(model_nsp, var, model_dt = .025, N_bins=10):
    """Mean firing rate of every cell as a function of a binned variable.

    ``N_bins`` evenly spaced edges span mean +/- 2 std of ``var`` (NaNs
    ignored), giving ``N_bins - 1`` half-open bins ``[edge_j, edge_{j+1})``.
    Samples outside that span, or with NaN ``var``, are not counted.

    Parameters
    ----------
    model_nsp : ndarray, shape (T, n_cells)
        Spike counts per sample.
    var : ndarray, shape (T,)
        Variable to bin.
    model_dt : float
        Sample duration; counts are divided by it to give rates.
    N_bins : int
        Number of bin edges.

    Returns
    -------
    tuning : ndarray, shape (n_cells, N_bins - 1)
        Mean rate per cell and bin (NaN for an empty bin).
    tuning_std : ndarray, shape (n_cells, N_bins - 1)
        Standard error: std of the rate divided by sqrt(samples in the bin).
    left_edges : ndarray, shape (N_bins - 1,)
        Left edge of each bin.
    """
    center, spread = np.nanmean(var), np.nanstd(var)
    edges = np.linspace(center - 2*spread, center + 2*spread, N_bins)
    n_cells = model_nsp.shape[-1]
    n_bins = len(edges) - 1
    tuning = np.zeros((n_cells, n_bins))
    tuning_std = np.zeros((n_cells, n_bins))
    for j, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # The bin mask and its sample count do not depend on the cell.
        in_bin = (var >= lo) & (var < hi)
        sem_scale = np.sqrt(np.count_nonzero(in_bin))
        for cell in range(n_cells):
            counts = model_nsp[in_bin, cell]
            tuning[cell, j] = np.nanmean(counts) / model_dt
            tuning_std[cell, j] = (np.nanstd(counts) / model_dt) / sem_scale
    return tuning, tuning_std, edges[:-1]


In [None]:
# Phi tuning of every cell on the training split: 10 edges -> 9 bins, counts
# divided by 0.1 to give rates.
tuning, tuning_std, var_range = tuning_curve(train_nsp, train_phi, N_bins=10, model_dt=.1)

In [None]:
# Plot the phi tuning curve (with SEM error bars) of one example cell. The x
# positions are the bin left edges returned by tuning_curve.
n = 51
fig, axs = plt.subplots(1,figsize=(7,5))
axs.errorbar(var_range,tuning[n], yerr=tuning_std[n])
axs.set_ylim(bottom=0)
axs.set_xlabel('Eye Phi')
axs.set_ylabel('Spikes/s')
axs.set_title('Neuron: {}'.format(n))
plt.tight_layout()
fig.savefig(FigPath/'ExampleTuningCurve.png',bbox_inches='tight',transparent=False, facecolor='w')

## PCA on Vid

In [None]:
# PCA of the video with each frame flattened to a pixel vector; keep the
# smallest number of components explaining more than 90% of the variance.
from sklearn.decomposition import PCA  # not among the visible imports; explicit here

vid_flat = model_vid_sm.reshape(-1,model_vid_sm.shape[1]*model_vid_sm.shape[2])
pca = PCA().fit(vid_flat)
cum_var = np.cumsum(pca.explained_variance_ratio_)
plt.plot(cum_var)
# cum_var[k] is the variance explained by the first k+1 components, so the
# first index above 0.9 plus one is the number of components to keep.
comp_to_keep = np.where(cum_var>.9)[0][0] + 1
plt.axvline(x=comp_to_keep-1)  # mark where 90% is first exceeded
pca = PCA(n_components=comp_to_keep)
pcs = pca.fit_transform(vid_flat)
print('keep {} PCs'.format(comp_to_keep))
# recon = pca.inverse_transform(pcs)

In [None]:
# Quick look at the shapes of the model inputs (displayed as a tuple).
model_vid_sm.shape,model_th.shape,model_phi.shape,model_roll.shape,model_pitch.shape,

# GLM Movement Only

In [None]:
##### Movement-only GLM for one cell: every subset of th/phi/roll/pitch, at each lag #####
# NOTE(review): relies on lag_list and model_dt, which are first assigned in
# later cells of this notebook (and on `lambdas` when model_type == 'ridgecv').
from sklearn.preprocessing import StandardScaler  # not among the visible imports

model_type = 'poissonregressor'
if model_type == 'elasticnetcv':
    model = make_pipeline(StandardScaler(), lm.ElasticNetCV()) # lm.RidgeCV(alphas=np.arange(100,10000,1000))) #  #MultiOutputRegressor(lm.Ridge(),n_jobs=-1)) 
elif model_type == 'ridgecv':
    model = make_pipeline(StandardScaler(), lm.RidgeCV(alphas=lambdas))
elif model_type == 'poissonregressor':
    model = make_pipeline(StandardScaler(), lm.PoissonRegressor())

# Sizes come from the (time x cells) spike matrix. nT used to be assigned only
# later, inside the lag loop, and n_units was never defined.
nT, n_units = model_nsp.shape
test_frac = 0.3
ntest = int(nT*test_frac)
titles = np.array(['th','phi','roll','pitch','dth','dphi'])
# Contiguous split: the first ntest samples are test, the remainder train.
move_all = np.stack((model_th, model_phi, model_roll, model_pitch, model_dth, model_dphi), axis=1)
move_train = move_all[ntest:]
move_test = move_all[:ntest]

# All non-empty subsets of the first four variables (th, phi, roll, pitch).
move_combos = [np.array(c) for k in range(1, 5) for c in itertools.combinations([0, 1, 2, 3], k)]
n_models = len(move_combos)  # 15
sps_smooth_all = np.zeros((n_models,len(lag_list),move_test.shape[0]))
pred_smooth_all = np.zeros((n_models,len(lag_list),move_test.shape[0]))
cc_all = np.zeros((n_units,n_models,len(lag_list)))
model_coef_all = []
titles_all = []

celln = 51
bin_length = 80
with PdfPages(FigPath/ 'ModelSelection_{}_MoveOnly.pdf'.format(model_type)) as pdf:
    for model_ind, combo in enumerate(move_combos):
        fig, axs = plt.subplots(1,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),5))
        axs = np.atleast_1d(axs)  # a single lag would otherwise give a bare Axes
        move_train2 = move_train[:,combo]
        move_test2 = move_test[:,combo]
        # iterate through timing lags
        for lag_ind, lag in enumerate(lag_list):
            # Circularly shift this cell's spikes by -lag samples.
            sps = np.roll(model_nsp[:,celln],-lag)
            sps_train = sps[ntest:]
            sps_test = sps[:ntest]

            model.fit(move_train2,sps_train)
            sps_pred = model.predict(move_test2)

            # Boxcar-smoothed rates (spikes/s) for comparison.
            sps_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
            pred_smooth = (np.convolve(sps_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
            cc_all[celln,model_ind,lag_ind] = np.corrcoef(sps_smooth, pred_smooth)[0,1]
            sps_smooth_all[model_ind,lag_ind] = sps_smooth
            pred_smooth_all[model_ind,lag_ind] = pred_smooth
            # make_pipeline names each step after its lower-cased class name,
            # which is what model_type holds.
            model_coef_all.append(model[model_type].coef_)

            axs[lag_ind].plot(sps_smooth,'k',label='smoothed FR')
            axs[lag_ind].plot(pred_smooth,'r', label='pred FR')
            axs[lag_ind].set_title('cc={:.3f}'.format(cc_all[celln,model_ind,lag_ind]))

        titles_all.append('_'.join(titles[combo]))
        plt.suptitle(titles_all[-1])
        plt.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)  # one figure per combination; release it once saved


In [None]:
# NOTE(review): in the visible code, msetrain / msetest are only assigned
# inside the GLM fit functions, so they are undefined at notebook scope. This
# cell likely raises NameError unless they are set elsewhere.
plt.plot(msetrain)
plt.plot(msetest)

# Parallel Processing GLM

## Vis Only sklearn

In [None]:
# Eye-angle change combined with the change in train_gz (sum and difference).
# NOTE(review): train_gz is differenced here, which only makes sense if it is
# an angle rather than a rate — confirm. Also, np.diff runs in train_idx
# order, which jumps across gaps between train groups.
train_dgaze_p = train_dth + np.diff(train_gz,append=0)
train_dgaze_n = train_dth - np.diff(train_gz,append=0)

In [None]:
# if model_type == 'elasticnetcv':
#     model = make_pipeline(StandardScaler(), lm.ElasticNetCV()) # lm.RidgeCV(alphas=np.arange(100,10000,1000))) #  #MultiOutputRegressor(lm.Ridge(),n_jobs=-1)) 
# elif model_type == 'ridgecv':
#     model = make_pipeline(StandardScaler(), lm.RidgeCV(alphas=lambdas))
# elif model_type == 'poissonregressor':
#     model = make_pipeline(StandardScaler(), lm.PoissonRegressor())


In [None]:
@ray.remote
def do_glm_fit_vis_skl(train_nsp, test_nsp, train_data, test_data, celln, model_type, nt_glm_lag, bin_length=40, model_dt=.1):
    """Fit a visual-only GLM for one cell with scikit-learn (run as a Ray task).

    The design matrix holds the flattened current frame and the previous
    ``nt_glm_lag - 1`` frames, oldest first.

    Parameters
    ----------
    train_nsp, test_nsp : ndarray
        Spike counts indexed ``[time, cell]``.
    train_data, test_data : ndarray
        Video with time on axis 0; ``shape[1:]`` is the frame shape used to
        reshape the fitted weights.
    celln : int
        Column of ``*_nsp`` to fit.
    model_type : {'elasticnetcv', 'ridgecv', 'poissonregressor'}
        Estimator to use. 'poissonregressor' sweeps ``alpha`` over powers of 2
        and keeps the value with the lowest test MSE.
    nt_glm_lag : int
        Number of frames in the temporal kernel.
    bin_length : int
        Boxcar width (samples) used to smooth test and predicted traces before
        the correlation.
    model_dt : float
        Sample duration (s); smoothed traces are in spikes/s.

    Returns
    -------
    cc : float
        Correlation of smoothed test and predicted rates.
    sta_all : ndarray, shape ``(nt_glm_lag,) + train_data.shape[1:]``
        Fitted weights reshaped into a spatio-temporal RF.
    sps_test : ndarray
        Test spike counts of the cell.
    sp_pred : ndarray
        Predicted test counts.

    Raises
    ------
    ValueError
        If ``model_type`` is not supported.
    """
    if model_type not in ('elasticnetcv', 'ridgecv', 'poissonregressor'):
        # Fail clearly instead of hitting an unbound `model` in the sweep.
        raise ValueError('Unsupported model_type: {!r}'.format(model_type))

    ##### Format data #####
    nks = np.shape(train_data)[1:]  # frame shape, used to reshape the RF
    sps_train = train_nsp[:,celln]
    sps_test = test_nsp[:,celln]

    def _lagged_design(frames):
        # Flatten frames and hstack copies delayed by nt_glm_lag-1 ... 0 samples.
        flat = frames.reshape(frames.shape[0], -1)
        blocks = []
        for nframes in reversed(range(nt_glm_lag)):
            delayed = np.roll(flat, nframes, axis=0)
            # np.roll is circular; zero the first `nframes` rows instead of
            # letting frames from the end of the recording leak into them
            # (for z-scored video, 0 is the mean).
            delayed[:nframes] = 0
            blocks.append(delayed)
        return np.hstack(blocks)

    x_train = _lagged_design(train_data)
    x_test = _lagged_design(test_data)

    if model_type == 'elasticnetcv':
        model = lm.ElasticNetCV(l1_ratio=[.05, .01, .5, .7])
        model.fit(x_train,sps_train)
        w = model.coef_
        sp_pred = model.predict(x_test)
    elif model_type == 'ridgecv':
        lambdas = 1024 * (2**np.arange(0,16))
        model = lm.RidgeCV(alphas=lambdas)
        model.fit(x_train,sps_train)
        w = model.coef_
        sp_pred = model.predict(x_test)
    else:  # 'poissonregressor'
        # NOTE(review): alpha is chosen on the same test data used for the
        # reported cc, which makes cc optimistic.
        lambdas = 2**np.arange(0,16)
        msetest = np.zeros(len(lambdas))
        w_ridge = np.zeros((x_train.shape[-1],len(lambdas)))
        pred_all = np.zeros((x_test.shape[0],len(lambdas)))
        for l, lam in enumerate(lambdas):
            model = lm.PoissonRegressor(alpha=lam)
            model.fit(x_train,sps_train)
            w_ridge[:,l] = model.coef_
            pred_all[:,l] = model.predict(x_test)
            msetest[l] = np.mean((sps_test - pred_all[:,l])**2)
        best_lambda = np.argmin(msetest)
        w = w_ridge[:,best_lambda]
        sp_pred = pred_all[:,best_lambda]
    sta_all = np.reshape(w,(nt_glm_lag,)+nks)

    # Boxcar-smoothed rates (spikes/s) and their correlation.
    kernel = np.ones(bin_length)
    sp_smooth = np.convolve(sps_test, kernel, 'same') / (bin_length * model_dt)
    pred_smooth = np.convolve(sp_pred, kernel, 'same') / (bin_length * model_dt)
    cc = np.corrcoef(sp_smooth, pred_smooth)[0,1]

    return cc, sta_all, sps_test, sp_pred

In [None]:
# Settings for the sklearn visual-only GLM.
model_dt = .1
model_type = 'elasticnetcv'
nt_glm_lag=5
# Load data with theta/phi/roll/pitch z-scored (do_norm=True).
do_shuffle=False
data = load_train_test(file_dict, save_dir, do_shuffle=do_shuffle, do_norm=True)
locals().update(data)

##### Start GLM Parallel Processing #####
start = time.time()
nks = np.shape(train_vid)[1:]; nk = nks[0]*nks[1]*nt_glm_lag

# Store each array once in the Ray object store; tasks receive the refs.
train_nsp_r = ray.put(train_nsp)
test_nsp_r = ray.put(test_nsp)
train_data_r = ray.put(train_vid)
test_data_r = ray.put(test_vid)
result_ids = []
# Submit one remote fit per cell (column of train_nsp); result_ids keeps the order.
for celln in range(train_nsp.shape[1]):
    result_ids.append(do_glm_fit_vis_skl.remote(train_nsp_r, test_nsp_r, train_data_r, test_data_r, celln, model_type, nt_glm_lag, model_dt=model_dt))

print('N_proc:', len(result_ids))
# Block until all fits finish; results_p[i] belongs to cell i.
results_p = ray.get(result_ids)
print('GLM Add: ', time.time()-start)

In [None]:
##### Gather Data and Find Max CC Model #####
# One result per cell in submission order, each (cc, sta, test spikes, predicted spikes).
mcc = np.stack([results_p[i][0] for i in range(len(results_p))])
msta = np.stack([results_p[i][1] for i in range(len(results_p))])
msp = np.stack([results_p[i][2] for i in range(len(results_p))])
mpred = np.stack([results_p[i][3] for i in range(len(results_p))])

# cc_all = cc_all.reshape((model_nsp.shape[1],nt_glm_lag,) + cc_all.shape[1:])
# sta_all = sta_all.reshape((model_nsp.shape[1],nt_glm_lag,) + sta_all.shape[1:])
# sp_raw = sp_raw.reshape((model_nsp.shape[1],nt_glm_lag,) + sp_raw.shape[1:])
# pred_raw = pred_raw.reshape((model_nsp.shape[1],nt_glm_lag,) + pred_raw.shape[1:])

# m_cells, m_lags = np.where(cc_all==np.max(cc_all,axis=(-1), keepdims=True))

# mcc = cc_all[m_cells,m_lags]
# msta = sta_all[m_cells,m_lags]
# msp = sp_raw[m_cells,m_lags]
# mpred = pred_raw[m_cells,m_lags]


In [None]:
# Save the per-cell results; the filename encodes the model type, dt in ms,
# the number of lag frames, and whether the train/test order was shuffled.
GLM_Data = {'mcc': mcc,
            'msta': msta,
            'msp': msp,
            'mpred': mpred,}
if do_shuffle:
    ioh5.save(save_dir/'GLM_{}_Data_VisOnly_notsmooth_dt{:03d}_T{:02d}_shuffled.h5'.format(model_type,int(model_dt*1000), nt_glm_lag), GLM_Data)
else:
    ioh5.save(save_dir/'GLM_{}_Data_VisOnly_notsmooth_dt{:03d}_T{:02d}.h5'.format(model_type,int(model_dt*1000), nt_glm_lag), GLM_Data)

In [None]:
# Show one cell's spatio-temporal RF, one panel per lag frame, on a symmetric
# colour scale. Panel 0 is the oldest frame (lag -(nt_glm_lag-1)); the last
# panel is lag 0.
celln=51
fig, axs = plt.subplots(1,nt_glm_lag,figsize=(20,5))
crange = np.max(np.abs(msta[celln]))
for n in range(nt_glm_lag):
    img = axs[n].imshow(msta[celln,n],cmap='seismic',vmin=-crange,vmax=crange)
    add_colorbar(img)
    axs[n].axis('off')
    axs[n].set_title('Lag:{}'.format(n-nt_glm_lag+1))
plt.suptitle('Celln:{}, cc={:.03f}'.format(celln,mcc[celln]),y=.75,fontsize=20)
plt.tight_layout()

# fig.savefig(FigPath/'TemporalRF_N{}.png'.format(celln), facecolor='white', transparent=True)


## VisOnly By Hand

In [None]:
@ray.remote
def do_glm_fit_visonly(train_nsp, test_nsp, train_data, test_data, celln, perms, lag, lambdas, bin_length=40, model_dt=.1):
    """Closed-form ridge regression of one cell's spikes on the video (Ray task).

    For each value in ``lambdas`` the weights are
    ``solve(X'X + lambda * C, X'y)``. ``C`` is the identity on the pixel
    weights and 0 on the appended column of ones, so the intercept is not
    penalised. The lambda with the lowest MSE on the test data is kept.

    Parameters
    ----------
    train_nsp, test_nsp : ndarray
        Spike counts indexed ``[time, cell]``.
    train_data, test_data : ndarray
        Video with time on axis 0. ``shape[1:]`` must have two entries, whose
        product sizes the penalty matrix.
    celln : int
        Column of ``*_nsp`` to fit.
    perms : object
        Unused.
    lag : int
        Spikes are circularly shifted by ``-lag`` samples (``np.roll``).
    lambdas : sequence of float
        Ridge strengths to try.
    bin_length : int
        Boxcar width (samples) for smoothing before the correlation.
    model_dt : float
        Sample duration (s); smoothed traces are in spikes/s.

    Returns
    -------
    cc_all : float
        Correlation of smoothed test and predicted rates.
    sta_all : ndarray, shape ``train_data.shape[1:]``
        Pixel weights, with the intercept dropped.
    sps_test : ndarray
        Shifted test spike counts.
    sp_pred : ndarray
        Linear predicted test counts (unconstrained, may be negative).

    NOTE(review): lambda is selected on the same test data used for cc_all,
    which makes cc_all optimistic. np.roll also wraps the ends of the spike
    trains around. Confirm both are acceptable.
    """
    ##### Format data #####
    # Frame shape (to reshape weights) and number of pixels (to size the penalty)
    nks = np.shape(train_data)[1:]; nk = nks[0]*nks[1]
    
    # Shift spikes by -lag for GLM fits (circular shift)
    sps_train = np.roll(train_nsp[:,celln],-lag)
    sps_test = np.roll(test_nsp[:,celln],-lag)
        
    # Flatten video frames into a (T, n_pixels) matrix
    x_train = train_data.reshape(train_data.shape[0],-1)
    x_train = np.append(x_train, np.ones((x_train.shape[0],1)), axis = 1) # append column of ones for fitting intercept
#     x_train = np.concatenate((x_train, move_train),axis=1)
    
    x_test = test_data.reshape(test_data.shape[0],-1) 
    x_test = np.append(x_test,np.ones((x_test.shape[0],1)), axis = 1) # append column of ones
#     x_test = np.concatenate((x_test, move_test),axis=1)
    
    # Normal-equation terms, computed once and reused for every lambda
    nlam = len(lambdas)
    XXtr = x_train.T @ x_train
    XYtr = x_train.T @ sps_train
    
    # Initialize MSE traces for the regularization sweep
    msetrain = np.zeros((nlam,1))
    msetest = np.zeros((nlam,1))
    w_ridge = np.zeros((x_train.shape[-1],nlam))
    # Penalty matrix: identity on pixel weights, 0 on the intercept (last column)
    Cinv = np.eye(nk)
    Cinv = linalg.block_diag(Cinv,np.zeros((1, 1)))
    # loop over regularization strength
    for l in range(len(lambdas)):
        # ridge (MAP) estimate for this lambda
        w = np.linalg.solve(XXtr + lambdas[l]*Cinv, XYtr) # equivalent of \ (left divide) in matlab
        w_ridge[:,l] = w
        # mean squared error on train and test data
        msetrain[l] = np.mean((sps_train - x_train@w)**2)
        msetest[l] = np.mean((sps_test - x_test@w)**2)
    
    # keep the lambda with the lowest test MSE
    best_lambda = np.argmin(msetest)
    w = w_ridge[:,best_lambda]
    ridge_rf = w_ridge[:,best_lambda]
    sta_all = np.reshape(w[:-1],nks)
    
    # predicted spike counts on the test data
    sp_pred = x_test@ridge_rf
    # boxcar-smooth both traces to get rates (spikes/s) vs time
    sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    # a few diagnostics (err is computed but not returned)
    err = np.mean((sp_smooth-pred_smooth)**2)
    cc = np.corrcoef(sp_smooth, pred_smooth)
    cc_all = cc[0,1]
    
    return cc_all, sta_all, sps_test, sp_pred

In [None]:

start = time.time()

# Spike shifts to try (order matters: results are indexed by position in
# this list), and the ridge strengths swept inside each fit.
lag_list = [ -4, -2, 0 , 4, 2]
lambdas = 1024 * (2**np.arange(0,16))
nks = np.shape(model_vid_sm)[1:]; nk = nks[0]*nks[1]

# Store each array once in the Ray object store; tasks receive the refs.
train_nsp_r = ray.put(train_nsp)
test_nsp_r = ray.put(test_nsp)
train_data_r = ray.put(train_vid)
test_data_r = ray.put(test_vid)
result_ids = []
# Submit one task per (cell, lag), cell-major then lag-minor.
for celln in range(train_nsp.shape[1]):
#     for n in range(1,len(titles)):
    perms = [] #np.array(list(itertools.combinations(np.arange(len(titles), n)))
#     for ind in range(perms.shape[0]):
    for lag_ind, lag in enumerate(lag_list):    
        result_ids.append(do_glm_fit_visonly.remote(train_nsp_r, test_nsp_r, train_data_r, test_data_r, celln, perms, lag, lambdas, model_dt=model_dt))

print('N_proc:', len(result_ids))
results_p = ray.get(result_ids)
print('GLM Add: ', time.time()-start)

In [None]:
##### Gather Data and Find Max CC Model #####
##### Gather results and find the max-CC lag per cell #####
# results_p follows the submission order: cell-major, lag-minor.
cc_all = np.stack([res[0] for res in results_p])
sta_all = np.stack([res[1] for res in results_p])
sp_raw = np.stack([res[2] for res in results_p])
pred_raw = np.stack([res[3] for res in results_p])

n_cells, n_lags = model_nsp.shape[1], len(lag_list)
cc_all = cc_all.reshape((n_cells, n_lags) + cc_all.shape[1:])
sta_all = sta_all.reshape((n_cells, n_lags) + sta_all.shape[1:])
sp_raw = sp_raw.reshape((n_cells, n_lags) + sp_raw.shape[1:])
pred_raw = pred_raw.reshape((n_cells, n_lags) + pred_raw.shape[1:])

# Exactly one best lag per cell. Matching each row against its max with
# np.where gave extra entries on ties and dropped cells whose row contained
# NaN (np.corrcoef is NaN for a constant prediction), so mcc[celln] could
# belong to another cell. Here NaN ranks below every real correlation.
m_cells = np.arange(n_cells)
m_lags = np.argmax(np.where(np.isnan(cc_all), -np.inf, cc_all), axis=-1)

mcc = cc_all[m_cells,m_lags]
msta = sta_all[m_cells,m_lags]
msp = sp_raw[m_cells,m_lags]
mpred = pred_raw[m_cells,m_lags]


In [None]:
# Save the full (cell x lag) results of the by-hand visual-only fits.
GLM_add = {'cc_all': cc_all,
           'sta_all': sta_all,
           'sp_raw': sp_raw,
           'pred_raw': pred_raw,}
# Only tag the file as shuffled when the loaded split actually was shuffled;
# the name used to end in '_Shuffled' unconditionally.
shuffle_tag = '_Shuffled' if do_shuffle else ''
ioh5.save(save_dir/'Add_GLM_Data_VisOnly_notsmooth_dt{:03d}{}.h5'.format(int(model_dt*1000), shuffle_tag), GLM_add)

In [None]:
# Load the unshuffled and shuffled-split sklearn GLM results. Both files must
# already exist, i.e. the fit/save cells were run with do_shuffle=False and
# with do_shuffle=True.
GLM_data = ioh5.load(save_dir/'GLM_{}_Data_VisOnly_notsmooth_dt{:03d}_T{:02d}.h5'.format(model_type,int(model_dt*1000), nt_glm_lag))

GLM_shuff = ioh5.load(save_dir/'GLM_{}_Data_VisOnly_notsmooth_dt{:03d}_T{:02d}_shuffled.h5'.format(model_type,int(model_dt*1000), nt_glm_lag))


### Plotting

In [None]:
# Upsample each video frame by an integer factor sf for display. cv2.resize
# takes the target size as (width, height), hence shape[2] before shape[1].
sf = 2
model_vid = np.zeros((model_vid_sm.shape[0],sf*model_vid_sm.shape[1],sf*model_vid_sm.shape[2]))
for n in range(model_vid_sm.shape[0]):
    model_vid[n] = cv2.resize(model_vid_sm[n],(sf*model_vid_sm.shape[2],sf*model_vid_sm.shape[1]))

In [None]:
from matplotlib.animation import FuncAnimation, PillowWriter, FFMpegWriter
def init():
    """FuncAnimation init callback: label the axes of the global `axs` figure.

    Reads notebook globals axs, mcc, celln, lag_list and m_lags.
    """
    axs[0].axis('off')
    axs[1].set_xlabel('Frame #')
    axs[1].set_ylabel('Smoothed FR (spks/s)')
    axs[1].set_title('cc={:.2f}, \n lag={:d}'.format(mcc[celln],lag_list[m_lags[celln]]))
    plt.tight_layout()

def update(t):
    """FuncAnimation frame callback: show video frame t and move the marker to x=t.

    Reads notebook globals img, ln and model_vid. ln is an axvline, so its y
    data [0, 1] spans the full axes height.
    """
    img.set_data(model_vid[t])
    ln.set_data([t, t], [0, 1])
    plt.draw()

In [None]:
# Show a video frame next to the smoothed test vs predicted rate of the best
# cell over a window of dt samples starting at t.
# NOTE(review): bin_length is a notebook global here (80 in the movement-only
# cell, 40 in the neuron-exploration cell); the value depends on run order.
t = 3500# np.argmin((msp_smooth-pred_smooth)**2)
dt = 1500
celln = np.argmax(mcc)
msp_smooth=(np.convolve(msp[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
pred_smooth=(np.convolve(mpred[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)

fig, axs = plt.subplots(1,2,figsize=(12,5))
img = axs[0].imshow(model_vid[t],cmap='gray')
axs[0].axis('off')
# Vertical marker for the current frame; moved by update() during animation.
ln = axs[1].axvline(x=t,c='b')
axs[1].plot(np.arange(t,t+dt),msp_smooth[t:t+dt],'k',label='test FR')
axs[1].plot(np.arange(t,t+dt),pred_smooth[t:t+dt],'r', label='pred FR')
axs[1].set_xlabel('Frame #')
axs[1].set_ylabel('Smoothed FR (spks/s)')
axs[1].set_title('cc={:.2f}, \n lag={:d}'.format(mcc[celln],lag_list[m_lags[celln]]))
plt.tight_layout()
# fig.savefig(os.path.join(FigurePath,'testimg.png'))

In [None]:
# Animate frames t..t+dt-1 (tqdm wraps the frame iterable to show progress)
# and write the result as a 60 fps MP4 (requires ffmpeg).
# NOTE(review): the tqdm iterable may be iterated more than once (display and
# save); confirm the progress bar / frame count behave as expected.
ani = FuncAnimation(fig, update, tqdm(np.arange(t,t+dt)), init_func=init)  #range(tot_samps.shape[1])
plt.show()
vname =  'SampleVid.mp4'
writervideo = FFMpegWriter(fps=60)
ani.save(FigPath/vname, writer=writervideo)
print('DONE!!!')

In [None]:
import plotly.express as px

# Interactive viewer stepping through msta along axis 0 (one frame per cell).
# NOTE(review): this assumes msta is (cells, H, W), as produced by the by-hand
# fits; the sklearn results are (cells, lags, H, W) — confirm which is loaded.
fig = px.imshow(msta, animation_frame=0, binary_string=False,color_continuous_scale='RdBu_r')
fig.update_layout(width=500,
                  height=500,)
fig.show()

In [None]:
# Reload the saved sklearn visual-only GLM results and expose them as globals.
GLM_data = ioh5.load(save_dir/'GLM_{}_Data_VisOnly_notsmooth_dt{:03d}_T{:02d}.h5'.format(model_type,int(model_dt*1000), nt_glm_lag))
locals().update(GLM_data)
# The save cell in this notebook writes 'mcc'/'msta'/'msp'/'mpred'; files that
# use 'cc_all'/'sta_all'/'sp_raw'/'pred_raw' are also accepted. Reading from
# GLM_data (instead of `mcc = cc_all`) avoids picking up stale cc_all/sta_all
# globals left over from the by-hand fits.
mcc = GLM_data['mcc'] if 'mcc' in GLM_data else GLM_data['cc_all']
msta = GLM_data['msta'] if 'msta' in GLM_data else GLM_data['sta_all']
msp = GLM_data['msp'] if 'msp' in GLM_data else GLM_data['sp_raw']
mpred = GLM_data['mpred'] if 'mpred' in GLM_data else GLM_data['pred_raw']
# GLM_shuff = ioh5.load(save_dir/'GLM_{}_Data_VisOnly_notsmooth_dt{:03d}_T{:02d}_shuffled.h5'.format(model_type,int(model_dt*1000), nt_glm_lag))

In [None]:

# Compare test correlations against the shuffled-split control.
# Results saved by this notebook store correlations under 'mcc'; files using
# 'cc_all' are also accepted.
cc_test = np.asarray(GLM_data['mcc'] if 'mcc' in GLM_data else GLM_data['cc_all'])
cc_shuff = np.asarray(GLM_shuff['mcc'] if 'mcc' in GLM_shuff else GLM_shuff['cc_all'])
# Drop NaN correlations (e.g. constant predictions); plt.hist cannot bin NaN.
plt.hist(cc_test[~np.isnan(cc_test)],bins=20,color='k',alpha=.5,label='Test CC')
plt.hist(cc_shuff[~np.isnan(cc_shuff)],bins=20,color='r', alpha=.5,label='Shuffled CC')
plt.xlabel('Corr. Coeff.')
plt.legend()
plt.savefig(FigPath/'CC_comparison_{}.png'.format(model_type), facecolor='white', transparent=True)

In [None]:
def f_add(alpha,stat_range,stat_all):
    return np.mean((stat_range - stat_all+alpha)**2)

def f_mult(alpha,stat_range,stat_all):
    """Mean squared error of a multiplicative-gain model.

    Models the binned rate for one movement range as the all-data binned rate
    scaled by a constant: ``stat_range ≈ stat_all * alpha``.

    Args:
        alpha: Multiplicative gain (unitless).
        stat_range: Binned mean rates for the subset.
        stat_all: Binned mean rates for all data, on the same bins.

    Returns:
        Mean squared residual over bins where both inputs are finite;
        ``binned_statistic`` reports empty bins as NaN and those are skipped.
    """
    residual = np.asarray(stat_range) - np.asarray(stat_all) * alpha
    return np.nanmean(residual**2)

In [None]:
# Load the train/test split and build per-cell tuning curves of test-set
# firing rate vs. each eye/head variable.
bin_length=40  # smoothing window (bins) used by the plotting cells
data = load_train_test(file_dict, save_dir, do_shuffle=do_shuffle, do_norm=False)
# At notebook top level this exposes each returned array as a variable; the
# cell below reads test_th, test_phi, test_roll, test_pitch, test_dth, test_dphi,
# test_nsp and model_nsp, presumably from `data`.
locals().update(data)
##### Explore Neurons #####
colors = plt.cm.winter(np.linspace(0,1,4))  # one color per quartile range
clrs = ['blue','orange','green','red']  # one color per movement variable in titles
# Initialize movement combinations
titles = np.array(['th','phi','roll','pitch']) # 
titles_all = []
# Names of every combination of 1..len(titles)-1 variables (range stops before
# len(titles), so the combination of all four is not included).
for n in range(1,len(titles)):
    perms = np.array(list(itertools.combinations(np.arange(len(titles)), n)))
    for ind in range(perms.shape[0]):
        titles_all.append('_'.join([t for t in titles[perms[ind]]]))

# move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis], train_dth[:,np.newaxis],train_dphi[:,np.newaxis]))
# Columns: th, phi, roll, pitch, dth, dphi; only the first len(titles) are used below.
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis], test_dth[:,np.newaxis],test_dphi[:,np.newaxis]))
move_test = move_test - np.mean(move_test,axis=0)  # center each variable on its test-set mean
# Create all tuning curves for plotting
N_bins=10
ncells = model_nsp.shape[-1]
ax_ylims = np.zeros((model_nsp.shape[-1],len(titles)))  # (cells, variables): peak of each tuning curve
tuning_curves = np.zeros((model_nsp.shape[-1],len(titles),N_bins-1))  # (cells, variables, N_bins-1)
tuning_stds = np.zeros((model_nsp.shape[-1],len(titles),N_bins-1))  # same layout as tuning_curves
var_ranges = np.zeros((len(titles),N_bins-1))  # (variables, N_bins-1) x-values of the curves
for modeln in range(len(titles)):
    metric = move_test[:,modeln]
    # NOTE(review): model_dt is hard-coded to .1; it matches load_train_test's
    # default (the call above does not pass model_dt) but not necessarily the
    # global model_dt used by other cells — confirm.
    tuning, tuning_std, var_range = tuning_curve(test_nsp, metric, N_bins=N_bins, model_dt=.1)
    tuning_curves[:,modeln] = tuning
    tuning_stds[:,modeln] = tuning_std
    ax_ylims[:,modeln] = np.max(tuning,axis=1)
    var_ranges[modeln] = var_range

In [None]:
# Write one summary page per cell with test correlation > .3. Each page shows:
#   - smoothed measured vs. predicted rate
#   - the STA
#   - movement tuning curves
#   - a 2-D histogram of predicted vs. actual rate
#   - actual-vs-predicted rate curves split into movement quartiles
#   - the MSE_add - MSE_mult comparison
quartiles = [0,.25,.5,.75,1]  # quantile edges that split each movement variable
n_quart = len(quartiles)-1
with PdfPages(FigPath/ 'MaxCC_GLM_dt{:03d}_cellsummary_sig.pdf'.format(int(model_dt*1000))) as pdf:
    for celln in tqdm(range(msp.shape[0])):
        if mcc[celln]>.3:
            fig, axs = plt.subplots(2,5, figsize=((35,10)))
            # Per-bin spike counts -> rates (spikes/s).
            predcell = mpred[celln]/model_dt
            nspcell = msp[celln]/model_dt
            # Box-car smoothing over bin_length bins, expressed in spikes/s.
            msp_smooth=(np.convolve(msp[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
            pred_smooth=(np.convolve(mpred[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
            axs[0,0].plot(msp_smooth,'k',label='test FR')
            axs[0,0].plot(pred_smooth,'r', label='pred FR')
            axs[0,0].set_xlabel('Frame #')
            axs[0,0].set_ylabel('Smoothed Firing Rate (spks/s)')
            axs[0,0].legend()
            axs[0,0].set_title('cc={:.2f}, \n lag={:d}'.format(mcc[celln],lag_list[m_lags[celln]]))
            # Symmetric limits so zero sits at the center of the diverging colormap.
            crange = np.max(np.abs(msta[celln]))
            img = axs[0,1].imshow(msta[celln],cmap='seismic',vmin=-crange,vmax=crange)
            axs[0,1].set_title('STA,cell: {:d}'.format(celln))
            axs[0,1].axis('off')
            add_colorbar(img)

            for modeln in range(len(titles)):
                axs[0,2].errorbar(var_ranges[modeln],tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln],label=titles[modeln],c=clrs[modeln],lw=4,elinewidth=3)
            axs[0,2].set_ylim(bottom=0,top=np.max(ax_ylims,axis=1)[celln]+2*np.max(tuning_stds,axis=(1,2))[celln])
            axs[0,2].set_xlim(-50,50)
            axs[0,2].set_xlabel(r'Angle ($ ^{\degree}$)')
            axs[0,2].set_ylabel('Spikes/s')
            axs[0,2].set_title('Tuning Curves')
            axs[0,2].legend(bbox_to_anchor=(1.01, 1), fontsize=12)

            # Rate ranges from the 10th to the 99th percentile; the predicted range
            # is split into 4 equal-width bins for the binned comparisons below.
            pred_range = np.quantile(predcell,[.1,.99])
            msp_range = np.quantile(nspcell,[.1,.99])
            pred_rangelin = np.linspace(pred_range[0],pred_range[1],5)

            hist,xedges,yedges,img =axs[0,3].hist2d(predcell,nspcell,range=np.vstack((pred_range,msp_range)))
            axs[0,3].set_xlabel('Predicted Spike Rate')
            axs[0,3].set_ylabel('Actual Spike Rate')
            cbar = add_colorbar(img)
            cbar.set_label('count')

            # Sized from the quartile list, not from an `nranges` left over by an
            # earlier cell (NameError on a fresh kernel).
            mse_add = np.zeros((len(titles),n_quart,1))
            mse_mult = np.zeros((len(titles),n_quart,1))
            alpha_add = np.zeros((len(titles),n_quart,1))
            alpha_mult = np.zeros((len(titles),n_quart,1))
            identity = np.linspace(pred_range[0],pred_range[1])

            for modeln in range(len(titles)):
                metric = move_test[:,modeln]
                nranges = np.quantile(metric,quartiles)
                # Mean actual rate per predicted-rate bin over all test data.
                stat_all, edges, _ = binned_statistic(predcell,nspcell, statistic='mean',bins=pred_rangelin)
                edge_mids = (edges[:-1]+edges[1:])/2
                for n in range(n_quart):
                    # Samples whose movement value falls in this quartile range.
                    ind = np.where(((metric<=nranges[n+1])&(metric>nranges[n])))[0]
                    stat_range, _, _ = binned_statistic(predcell[ind], nspcell[ind], statistic='mean',bins=pred_rangelin)

                    res_add = minimize_scalar(f_add,args=(stat_range,stat_all))
                    res_mult = minimize_scalar(f_mult,args=(stat_range,stat_all))
                    mse_add[modeln, n] = res_add.fun
                    mse_mult[modeln, n] = res_mult.fun
                    alpha_add[modeln, n] = res_add.x
                    alpha_mult[modeln, n] = res_mult.x

                    axs[1,modeln].plot(edge_mids, stat_range,'.-', c=colors[n],label='{:.02f} : {:.02f}'.format(nranges[n],nranges[n+1]),lw=4,ms=20)
                axs[1,modeln].plot(identity,identity,'k--',zorder=0)
                axs[1,modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
                axs[1,modeln].set_xlabel('Predicted Spike Rate')
                axs[1,modeln].set_ylabel('Actual Spike Rate')
                axs[1,modeln].plot(edge_mids, stat_all,'.-', c='k', lw=5, ms=20, label='All_data')
                axs[1,modeln].legend(loc='upper left', fontsize=12)
                axs[1,modeln].axis('equal')

            # Positive values: the multiplicative model has the lower error.
            min_add = np.min(mse_add,axis=-1)
            min_mult = np.min(mse_mult,axis=-1)

            crange = np.max(np.abs(min_add-min_mult))
            im = axs[1,-1].imshow(min_add-min_mult,cmap='seismic',vmin=-crange,vmax=crange)
            axs[1,-1].set_yticks(np.arange(0,4))
            axs[1,-1].set_yticklabels(titles)
            axs[1,-1].set_ylabel('Movement Model')
            axs[1,-1].set_xticks(np.arange(0,4))
            axs[1,-1].set_xticklabels(['.25','.5','.75','1'])
            axs[1,-1].set_xlabel('Quantile Range')
            axs[1,-1].set_title('$MSE_{add}$ - $MSE_{mult}$')
            cbar = add_colorbar(im)
            plt.tight_layout()

            pdf.savefig(fig)
            plt.close(fig)
        
# fig.savefig(FigPath/'CellSummary_N{}.png'.format(celln), facecolor='white', transparent=True)

In [None]:
# Summary page for one cell (celln=50). Panels:
#   - smoothed rates, STA, tuning curves
#   - predicted-vs-actual scatter
#   - actual-vs-predicted rate curves split into movement quartiles
#   - the MSE_add - MSE_mult comparison
# Per-quartile traces and fit results are stored in
# (cell, movement variable, quartile, ...) arrays.
celln=50
ncells=model_nsp.shape[-1]
colors = plt.cm.winter(np.linspace(0,1,4))
clrs = ['blue','orange','green','red']
fig, axs = plt.subplots(2,5, figsize=((35,10)))
# Per-bin spike counts -> rates (spikes/s).
predcell = mpred[celln]/model_dt
nspcell = msp[celln]/model_dt
msp_smooth=(np.convolve(msp[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
pred_smooth=(np.convolve(mpred[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
axs[0,0].plot(msp_smooth,'k',label='test FR')
axs[0,0].plot(pred_smooth,'r', label='pred FR')
axs[0,0].set_xlabel('Frame #')
axs[0,0].set_ylabel('Smoothed Firing Rate (spks/s)')
axs[0,0].legend()
axs[0,0].set_title('cc={:.2f}, \n lag={:d}'.format(mcc[celln],lag_list[m_lags[celln]]))
crange = np.max(np.abs(msta[celln]))
img = axs[0,1].imshow(msta[celln],cmap='seismic',vmin=-crange,vmax=crange)
axs[0,1].set_title('STA,cell: {:d}'.format(celln))
axs[0,1].axis('off')
add_colorbar(img)

for modeln in range(len(titles)):
    axs[0,2].errorbar(var_ranges[modeln],tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln],label=titles[modeln],c=clrs[modeln],lw=4,elinewidth=3)
axs[0,2].set_ylim(bottom=0,top=np.max(ax_ylims,axis=1)[celln]+2*np.max(tuning_stds,axis=(1,2))[celln])
axs[0,2].set_xlim(-50,50)
axs[0,2].set_xlabel(r'Angle ($ ^{\degree}$)')
axs[0,2].set_ylabel('Spikes/s')
axs[0,2].set_title('Tuning Curves')
axs[0,2].legend(bbox_to_anchor=(1.01, 1), fontsize=12)

# Rate ranges from the 10th to the 99th percentile; the predicted range is
# split into stat_bins-1 equal-width bins for the binned comparisons below.
stat_bins = 5
quartiles = [0,.25,.5,.75,1]  # quantile edges that split each movement variable
pred_range = np.quantile(predcell,[.1,.99])
msp_range = np.quantile(nspcell,[.1,.99])
pred_rangelin = np.linspace(pred_range[0],pred_range[1],stat_bins)

# The scatter has no color mapping, so it gets no colorbar (a colorbar here
# previously attached a 'count' bar to the STA image instead).
axs[0,3].scatter(predcell,nspcell,c='k',s=15)
axs[0,3].set_xlabel('Predicted Spike Rate')
axs[0,3].set_ylabel('Actual Spike Rate')

# Sized from the quartile list, not from an `nranges` left over by an earlier cell.
n_quart = len(quartiles)-1
mse_add = np.zeros((ncells,len(titles),n_quart))
mse_mult = np.zeros((ncells,len(titles),n_quart))
alpha_add = np.zeros((ncells,len(titles),n_quart))
alpha_mult = np.zeros((ncells,len(titles),n_quart))

traces = np.zeros((ncells,len(titles),n_quart,stat_bins-1)) # (cell, movement variable, quartile, rate bin)
traces_mean = np.zeros((ncells,len(titles),stat_bins-1)) # (cell, movement variable, rate bin)
edges_all = np.zeros((ncells,len(titles),n_quart,stat_bins-1)) # bin centers, same layout as traces
identity = np.linspace(pred_range[0],pred_range[1])
for modeln in range(len(titles)):
    metric = move_test[:,modeln]
    nranges = np.quantile(metric,quartiles)
    # Mean actual rate per predicted-rate bin over all test data.
    stat_all, edges, _ = binned_statistic(predcell,nspcell, statistic='mean',bins=pred_rangelin)
    edge_mids = (edges[:-1]+edges[1:])/2
    traces_mean[celln,modeln]=stat_all
    for n in range(n_quart):
        # Samples whose movement value falls in this quartile range.
        ind = np.where(((metric<=nranges[n+1])&(metric>nranges[n])))[0]
        stat_range, _, _ = binned_statistic(predcell[ind], nspcell[ind], statistic='mean',bins=pred_rangelin)
        traces[celln,modeln,n]=stat_range
        edges_all[celln,modeln,n]=edge_mids
        res_add = minimize_scalar(f_add,args=(stat_range,stat_all))
        res_mult = minimize_scalar(f_mult,args=(stat_range,stat_all))
        mse_add[celln, modeln, n] = res_add.fun
        mse_mult[celln, modeln, n] = res_mult.fun
        alpha_add[celln, modeln, n] = res_add.x
        alpha_mult[celln, modeln, n] = res_mult.x

        axs[1,modeln].plot(edge_mids, stat_range,'.-', c=colors[n],label='{:.02f} : {:.02f}'.format(nranges[n],nranges[n+1]),lw=4,ms=20)
    axs[1,modeln].plot(identity,identity,'k--',zorder=0)
    axs[1,modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
    axs[1,modeln].set_xlabel('Predicted Spike Rate')
    axs[1,modeln].set_ylabel('Actual Spike Rate')
    # axs[1,modeln].plot(edge_mids, stat_all,'.-', c='k', lw=5, ms=20, label='All_data')
    axs[1,modeln].legend(loc='upper left', fontsize=12)
    axs[1,modeln].axis('equal')

# Positive values: the multiplicative model has the lower error.
dmodel = mse_add[celln]-mse_mult[celln]
crange = np.max(np.abs(dmodel))
im = axs[1,-1].imshow(dmodel,cmap='seismic',vmin=-crange,vmax=crange)
axs[1,-1].set_yticks(np.arange(0,4))
axs[1,-1].set_yticklabels(titles)
axs[1,-1].set_ylabel('Movement Model')
axs[1,-1].set_xticks(np.arange(0,4))
axs[1,-1].set_xticklabels(['.25','.5','.75','1'])
axs[1,-1].set_xlabel('Quantile Range')
axs[1,-1].set_title('$MSE_{add}$ - $MSE_{mult}$')
cbar = add_colorbar(im)
plt.tight_layout()

fig.savefig(FigPath/'CellSummary_N{}.png'.format(celln), facecolor='white', transparent=True)

### Plotting Temporal Fits

In [None]:
# Summary page for one cell (celln=51) with the spatiotemporal STA.
#   row 0: STA at each GLM time lag
#   row 1: smoothed rates, eye/head tuning curves, predicted-vs-actual scatter
#          and 2-D histogram
#   row 2: actual-vs-predicted rate curves split into movement quartiles, and
#          the MSE_add - MSE_mult comparison
celln=51
bin_length=40
ncells=model_nsp.shape[-1]
colors = plt.cm.winter(np.linspace(0,1,4))
clrs = ['blue','orange','green','red']
fig, axs = plt.subplots(3,5, figsize=((35,15)))
# Per-bin spike counts -> rates (spikes/s).
predcell = mpred[celln]/model_dt
nspcell = msp[celln]/model_dt
msp_smooth=(np.convolve(msp[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
pred_smooth=(np.convolve(mpred[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
axs[1,0].plot(msp_smooth,'k',label='test FR')
axs[1,0].plot(pred_smooth,'r', label='pred FR')
axs[1,0].set_xlabel('Frame #')
axs[1,0].set_ylabel('Firing Rate (spks/s)')
axs[1,0].legend()
axs[1,0].set_title('Smoothed FRs')

# One symmetric color scale shared by all lags so the panels are comparable.
crange = np.max(np.abs(msta[celln]))
# The top row only has axs.shape[1] panels; indexing past them raised IndexError.
for n in range(min(nt_glm_lag, axs.shape[1])):
    img = axs[0,n].imshow(msta[celln,n],cmap='seismic',vmin=-crange,vmax=crange)
    add_colorbar(img)
    axs[0,n].axis('off')
    axs[0,n].set_title('Lag:{:03d} ms'.format(int(1000*(n-nt_glm_lag+1)*model_dt)))

# Eye Tuning Curves (first two entries of titles)
for modeln in range(len(titles)-2):
    axs[1,1].errorbar(var_ranges[modeln],tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln],label=titles[modeln],c=clrs[modeln],lw=4,elinewidth=3)
axs[1,1].set_ylim(bottom=0,top=np.max(ax_ylims,axis=1)[celln]+2*np.max(tuning_stds,axis=(1,2))[celln])
axs[1,1].set_xlim(-50,50)
axs[1,1].set_xlabel(r'Angle ($ ^{\degree}$)')
axs[1,1].set_ylabel('Spikes/s')
axs[1,1].set_title('Eye Tuning Curves')
axs[1,1].legend(bbox_to_anchor=(1.01, 1), fontsize=12)

# Head Tuning Curves (remaining entries of titles)
for modeln in range(2,len(titles)):
    axs[1,2].errorbar(var_ranges[modeln],tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln],label=titles[modeln],c=clrs[modeln],lw=4,elinewidth=3)
axs[1,2].set_ylim(bottom=0,top=np.max(ax_ylims,axis=1)[celln]+2*np.max(tuning_stds,axis=(1,2))[celln])
axs[1,2].set_xlim(-50,50)
axs[1,2].set_xlabel(r'Angle ($ ^{\degree}$)')
axs[1,2].set_ylabel('Spikes/s')
axs[1,2].set_title('Head Tuning Curves')
axs[1,2].legend(bbox_to_anchor=(1.01, 1), fontsize=12)

# Rate ranges from the 10th to the 99th percentile; the predicted range is
# split into stat_bins-1 equal-width bins for the binned comparisons below.
stat_bins = 5
pred_range = np.quantile(predcell,[.1,.99])
msp_range = np.quantile(nspcell,[.1,.99])
pred_rangelin = np.linspace(pred_range[0],pred_range[1],stat_bins)
quartiles = [0,.25,.5,.75,1]  # quantile edges that split each movement variable
# The scatter has no color mapping, so it gets no colorbar (a colorbar here
# previously attached a 'count' bar to the last STA panel instead).
axs[1,3].scatter(predcell,nspcell,c='k',s=15)
axs[1,3].set_xlabel('Predicted Spike Rate')
axs[1,3].set_ylabel('Actual Spike Rate')

hist,xedges,yedges,img =axs[1,4].hist2d(predcell,nspcell,range=np.vstack((pred_range,msp_range)))
axs[1,4].set_xlabel('Predicted Spike Rate')
axs[1,4].set_ylabel('Actual Spike Rate')
cbar = add_colorbar(img)
cbar.set_label('count')

n_quart = len(quartiles)-1
mse_add = np.zeros((ncells,len(titles),n_quart))
mse_mult = np.zeros((ncells,len(titles),n_quart))
alpha_add = np.zeros((ncells,len(titles),n_quart))
alpha_mult = np.zeros((ncells,len(titles),n_quart))

traces = np.zeros((ncells,len(titles),n_quart,stat_bins-1)) # (cell, movement variable, quartile, rate bin)
traces_mean = np.zeros((ncells,len(titles),stat_bins-1)) # (cell, movement variable, rate bin)
edges_all = np.zeros((ncells,len(titles),n_quart,stat_bins-1)) # bin centers, same layout as traces
identity = np.linspace(pred_range[0],pred_range[1])
for modeln in range(len(titles)):
    metric = move_test[:,modeln]
    nranges = np.quantile(metric,quartiles)
    # Mean actual rate per predicted-rate bin over all test data.
    stat_all, edges, _ = binned_statistic(predcell,nspcell, statistic='mean',bins=pred_rangelin)
    edge_mids = (edges[:-1]+edges[1:])/2
    traces_mean[celln,modeln]=stat_all
    # Normalize by the peak binned rate so the MSEs are unitless. nanmax because
    # empty bins come back as NaN, and np.max would make every fit NaN.
    max_fr = np.nanmax(stat_all)

    for n in range(n_quart):
        # Samples whose movement value falls in this quartile range.
        ind = np.where(((metric<=nranges[n+1])&(metric>nranges[n])))[0]
        stat_range, _, _ = binned_statistic(predcell[ind], nspcell[ind], statistic='mean',bins=pred_rangelin)
        traces[celln,modeln,n]=stat_range
        edges_all[celln,modeln,n]=edge_mids
        res_add = minimize_scalar(f_add,args=(stat_range/max_fr, stat_all/max_fr))
        res_mult = minimize_scalar(f_mult,args=(stat_range/max_fr, stat_all/max_fr))
        mse_add[celln, modeln, n] = res_add.fun
        mse_mult[celln, modeln, n] = res_mult.fun
        alpha_add[celln, modeln, n] = res_add.x
        alpha_mult[celln, modeln, n] = res_mult.x

        axs[2,modeln].plot(edge_mids, stat_range,'.-', c=colors[n],label='{:.02f} : {:.02f}'.format(nranges[n],nranges[n+1]),lw=4,ms=20)
    axs[2,modeln].plot(identity,identity,'k--',zorder=0)
    axs[2,modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
    axs[2,modeln].set_xlabel('Predicted Spike Rate')
    axs[2,modeln].set_ylabel('Actual Spike Rate')
    # axs[2,modeln].plot(edge_mids, stat_all,'.-', c='k', lw=5, ms=20, label='All_data')
    axs[2,modeln].legend(loc='upper left', fontsize=12)
    axs[2,modeln].axis('equal')

# Positive values: the multiplicative model has the lower error.
dmodel = mse_add[celln]-mse_mult[celln]
crange = np.max(np.abs(dmodel))
im = axs[2,-1].imshow(dmodel,cmap='seismic',vmin=-crange,vmax=crange)
axs[2,-1].set_yticks(np.arange(0,4))
axs[2,-1].set_yticklabels(titles)
axs[2,-1].set_ylabel('Movement Model')
axs[2,-1].set_xticks(np.arange(0,4))
axs[2,-1].set_xticklabels(['.25','.5','.75','1'])
axs[2,-1].set_xlabel('Quantile Range')
axs[2,-1].set_title('$MSE_{add}$ - $MSE_{mult}$')
cbar = add_colorbar(im)

plt.suptitle('Celln:{}, cc={:.03f}'.format(celln,mcc[celln]),y=1,fontsize=30)
plt.tight_layout()


# fig.savefig(FigPath/'CellSummary_N{}_T{:02d}.png'.format(celln,nt_glm_lag), facecolor='white', transparent=True)

In [None]:
# PDF version of the temporal summary page, one page per cell with test cc > .3.
# The result arrays are allocated once, before the loop, so fits and traces
# accumulate across cells. Allocating them inside the loop kept only the last
# cell's results.
stat_bins = 5  # number of edges in the predicted-rate binning (-> stat_bins-1 bins)
quartiles = [0,.25,.5,.75,1]  # quantile edges that split each movement variable
n_quart = len(quartiles)-1
mse_add = np.zeros((msp.shape[0],len(titles),n_quart))
mse_mult = np.zeros((msp.shape[0],len(titles),n_quart))
alpha_add = np.zeros((msp.shape[0],len(titles),n_quart))
alpha_mult = np.zeros((msp.shape[0],len(titles),n_quart))
traces = np.zeros((msp.shape[0],len(titles),n_quart,stat_bins-1)) # (cell, movement variable, quartile, rate bin)
traces_mean = np.zeros((msp.shape[0],len(titles),stat_bins-1)) # (cell, movement variable, rate bin)
edges_all = np.zeros((msp.shape[0],len(titles),n_quart,stat_bins-1)) # bin centers, same layout as traces

with PdfPages(FigPath/ 'GLM_{}_dt{:03d}_T{:02d}_cellsummary_sig.pdf'.format(model_type,int(model_dt*1000),nt_glm_lag)) as pdf:
    for celln in tqdm(range(msp.shape[0])):
        if mcc[celln]>.3:
            fig, axs = plt.subplots(3,5, figsize=((35,15)))
            # Per-bin spike counts -> rates (spikes/s).
            predcell = mpred[celln]/model_dt
            nspcell = msp[celln]/model_dt
            msp_smooth=(np.convolve(msp[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
            pred_smooth=(np.convolve(mpred[celln], np.ones(bin_length), 'same')) / (bin_length * model_dt)
            axs[1,0].plot(msp_smooth,'k',label='test FR')
            axs[1,0].plot(pred_smooth,'r', label='pred FR')
            axs[1,0].set_xlabel('Frame #')
            axs[1,0].set_ylabel('Firing Rate (spks/s)')
            axs[1,0].legend()
            axs[1,0].set_title('Smoothed FRs')

            # One symmetric color scale shared by all lags; only as many lags as
            # there are panels in the top row can be drawn.
            crange = np.max(np.abs(msta[celln]))
            for n in range(min(nt_glm_lag, axs.shape[1])):
                img = axs[0,n].imshow(msta[celln,n],cmap='seismic',vmin=-crange,vmax=crange)
                add_colorbar(img)
                axs[0,n].axis('off')
                axs[0,n].set_title('Lag:{:03d} ms'.format(int(1000*(n-nt_glm_lag+1)*model_dt)))

            # Eye Tuning Curves (first two entries of titles)
            for modeln in range(len(titles)-2):
                axs[1,1].errorbar(var_ranges[modeln],tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln],label=titles[modeln],c=clrs[modeln],lw=4,elinewidth=3)
            axs[1,1].set_ylim(bottom=0,top=np.max(ax_ylims,axis=1)[celln]+2*np.max(tuning_stds,axis=(1,2))[celln])
            axs[1,1].set_xlim(-50,50)
            axs[1,1].set_xlabel(r'Angle ($ ^{\degree}$)')
            axs[1,1].set_ylabel('Spikes/s')
            axs[1,1].set_title('Eye Tuning Curves')
            axs[1,1].legend(bbox_to_anchor=(1.01, 1), fontsize=12)

            # Head Tuning Curves (remaining entries of titles)
            for modeln in range(2,len(titles)):
                axs[1,2].errorbar(var_ranges[modeln],tuning_curves[celln,modeln], yerr=tuning_stds[celln,modeln],label=titles[modeln],c=clrs[modeln],lw=4,elinewidth=3)
            axs[1,2].set_ylim(bottom=0,top=np.max(ax_ylims,axis=1)[celln]+2*np.max(tuning_stds,axis=(1,2))[celln])
            axs[1,2].set_xlim(-50,50)
            axs[1,2].set_xlabel(r'Angle ($ ^{\degree}$)')
            axs[1,2].set_ylabel('Spikes/s')
            axs[1,2].set_title('Head Tuning Curves')
            axs[1,2].legend(bbox_to_anchor=(1.01, 1), fontsize=12)

            # Rate ranges from the 10th to the 99th percentile; the predicted range
            # is split into stat_bins-1 equal-width bins for the comparisons below.
            pred_range = np.quantile(predcell,[.1,.99])
            msp_range = np.quantile(nspcell,[.1,.99])
            pred_rangelin = np.linspace(pred_range[0],pred_range[1],stat_bins)
            # No colorbar for the scatter: it has no color mapping.
            axs[1,3].scatter(predcell,nspcell,c='k',s=15)
            axs[1,3].set_xlabel('Predicted Spike Rate')
            axs[1,3].set_ylabel('Actual Spike Rate')

            hist,xedges,yedges,img =axs[1,4].hist2d(predcell,nspcell,range=np.vstack((pred_range,msp_range)))
            axs[1,4].set_xlabel('Predicted Spike Rate')
            axs[1,4].set_ylabel('Actual Spike Rate')
            cbar = add_colorbar(img)
            cbar.set_label('count')

            identity = np.linspace(pred_range[0],pred_range[1])
            for modeln in range(len(titles)):
                metric = move_test[:,modeln]
                nranges = np.quantile(metric,quartiles)
                # Mean actual rate per predicted-rate bin over all test data.
                stat_all, edges, _ = binned_statistic(predcell,nspcell, statistic='mean',bins=pred_rangelin)
                edge_mids = (edges[:-1]+edges[1:])/2
                traces_mean[celln,modeln]=stat_all
                # Normalize by the peak binned rate so the MSEs are unitless;
                # nanmax because empty bins come back as NaN.
                max_fr = np.nanmax(stat_all)

                for n in range(n_quart):
                    # Samples whose movement value falls in this quartile range.
                    ind = np.where(((metric<=nranges[n+1])&(metric>nranges[n])))[0]
                    stat_range, _, _ = binned_statistic(predcell[ind], nspcell[ind], statistic='mean',bins=pred_rangelin)
                    traces[celln,modeln,n]=stat_range
                    edges_all[celln,modeln,n]=edge_mids
                    res_add = minimize_scalar(f_add,args=(stat_range/max_fr, stat_all/max_fr))
                    res_mult = minimize_scalar(f_mult,args=(stat_range/max_fr, stat_all/max_fr))
                    mse_add[celln, modeln, n] = res_add.fun
                    mse_mult[celln, modeln, n] = res_mult.fun
                    alpha_add[celln, modeln, n] = res_add.x
                    alpha_mult[celln, modeln, n] = res_mult.x

                    axs[2,modeln].plot(edge_mids, stat_range,'.-', c=colors[n],label='{:.02f} : {:.02f}'.format(nranges[n],nranges[n+1]),lw=4,ms=20)
                axs[2,modeln].plot(identity,identity,'k--',zorder=0)
                axs[2,modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
                axs[2,modeln].set_xlabel('Predicted Spike Rate')
                axs[2,modeln].set_ylabel('Actual Spike Rate')
                axs[2,modeln].legend(loc='upper left', fontsize=12)
                axs[2,modeln].axis('equal')

            # Positive values: the multiplicative model has the lower error.
            dmodel = mse_add[celln]-mse_mult[celln]
            crange = np.max(np.abs(dmodel))
            im = axs[2,-1].imshow(dmodel,cmap='seismic',vmin=-crange,vmax=crange)
            axs[2,-1].set_yticks(np.arange(0,4))
            axs[2,-1].set_yticklabels(titles)
            axs[2,-1].set_ylabel('Movement Model')
            axs[2,-1].set_xticks(np.arange(0,4))
            axs[2,-1].set_xticklabels(['.25','.5','.75','1'])
            axs[2,-1].set_xlabel('Quantile Range')
            axs[2,-1].set_title('$MSE_{add}$ - $MSE_{mult}$')
            cbar = add_colorbar(im)

            plt.suptitle('Celln:{}, cc={:.03f}'.format(celln,mcc[celln]),y=1,fontsize=30)
            plt.tight_layout()

            pdf.savefig(fig)
            plt.close(fig)
        
# fig.savefig(FigPath/'CellSummary_N{}.png'.format(celln), facecolor='white', transparent=True)

In [None]:
# Debug cell: recompute the quartile-split fit for one (cell, movement variable,
# quartile range) combination outside the plotting loops.
n = 1  # quartile range index (0..3 for the 5 quantile edges below)
model_n=-1
celln=51
# NOTE(review): `model_n` is never used; the code below indexes with `modeln`, i.e.
# whatever value an earlier loop over `titles` left behind. Confirm which was meant.
# NOTE(review): predcell, nspcell and pred_rangelin are not recomputed here; they
# belong to whichever cell the last summary cell processed, which need not be celln.
# for modeln in range(len(titles)):
# Re-allocating these discards any traces stored by earlier cells.
traces = np.zeros((ncells,len(titles),len(quartiles)-1,stat_bins-1)) # (model_type,quartile,FR)
traces_mean = np.zeros((ncells,len(titles),stat_bins-1)) # (model_type,quartile,FR)
edges_all = np.zeros((ncells,len(titles),len(quartiles)-1,stat_bins-1)) # (model_type,quartile,FR)
metric = move_test[:,modeln]
nranges = np.quantile(metric,[0,.25,.5,.75,1])# np.linspace(np.nanmean(metric)-2*np.nanstd(metric), np.nanmean(metric)+2*np.nanstd(metric),N_bins)
stat_all, edges, _ = binned_statistic(predcell,nspcell, statistic='mean',bins=pred_rangelin)
edge_mids = np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
#     axs[1,modeln].set_xlim(0,pred_range[1]+np.std(pred_range))
#     axs[1,modeln].set_ylim(0,np.max(stat)+np.std(stat))
#     for n in range(len(nranges)-1):
# Samples whose movement value falls in quartile range n (lower edge exclusive).
ind = np.where(((metric<=nranges[n+1])&(metric>nranges[n])))[0]
pred = predcell[ind]
sp = nspcell[ind]

stat_range, edges, _ = binned_statistic(pred, sp, statistic='mean',bins=pred_rangelin)
edge_mids = np.array([(edges[i]+edges[i+1])/2 for i in range(len(edges)-1)])
traces_mean[celln,modeln]=stat_all

traces[celln,modeln,n]=stat_range
edges_all[celln,modeln,n]=edge_mids
# Unlike the temporal summary cells, the rates are not divided by max_fr here.
res_add = minimize_scalar(f_add,args=(stat_range,stat_all))
res_mult = minimize_scalar(f_mult,args=(stat_range,stat_all))
mse_add[celln, modeln, n] = res_add.fun
mse_mult[celln, modeln, n] = res_mult.fun
alpha_add[celln, modeln, n] = res_add.x
alpha_mult[celln, modeln, n] = res_mult.x

#         axs[1,modeln].plot(edge_mids, stat_range,'.-', c=colors[n],label='{:.02f} : {:.02f}'.format(nranges[n],nranges[n+1]),lw=4,ms=20)
#         axs[1,modeln].plot(np.linspace(pred_range[0],pred_range[1]),np.linspace(pred_range[0],pred_range[1]),'k--',zorder=0)
#         axs[1,modeln].set_title('Metric: {}'.format(titles[modeln]), color=clrs[modeln])
#         axs[1,modeln].set_xlabel('Predicted Spike Rate')
#         axs[1,modeln].set_ylabel('Actual Spike Rate')
#     axs[1,modeln].plot(edge_mids, stat_all,'.-', c='k', lw=5, ms=20, label='All_data')
#     axs[1,modeln].legend(loc='upper left', fontsize=12)
#     axs[1,modeln].axis('equal')


Since we are optimizing the visual model, fit an elastic net (with an L1 penalty) and look at the resulting RFs.

Add a temporal component to obtain a spatiotemporal RF.

In [None]:
# Refit the same (cell, variable, quartile) combination with a tighter tolerance.
res_add, res_mult = (
    minimize_scalar(cost, args=(stat_range, stat_all), tol=1e-6)
    for cost in (f_add, f_mult)
)
mse_add[celln, modeln, n], alpha_add[celln, modeln, n] = res_add.fun, res_add.x
mse_mult[celln, modeln, n], alpha_mult[celln, modeln, n] = res_mult.fun, res_mult.x

In [None]:
# Optimizer result vs. stored multiplicative gain for the current combination.
(res_mult.x, alpha_mult[celln, modeln, n])

In [None]:
# Full optimizer results for both models.
(res_add, res_mult)

In [None]:
# Stored additive offset and multiplicative gain for the current combination.
(alpha_add[celln, modeln, n], alpha_mult[celln, modeln, n])

In [None]:
# Stored fit errors for the current combination.
(mse_add[celln, modeln, n], mse_mult[celln, modeln, n])

In [None]:

# Compare the all-data curve (black) with its multiplicative (dashed) and
# additive (dash-dot) fits and with the measured curve for quartile range n.
plt.figure(figsize=(8, 8))
plt.plot(edge_mids, traces_mean[celln, modeln], '.-', c='k', lw=5, ms=20, label='All_data')
plt.plot(edge_mids, traces_mean[celln, modeln] * alpha_mult[celln, modeln, n], '--',
         label='MultFit', c=colors[n], lw=4, ms=20)
plt.plot(edge_mids, traces_mean[celln, modeln] + alpha_add[celln, modeln, n], '-.',
         label='AddFit', c=colors[n], lw=4, ms=20)
plt.plot(edge_mids, traces[celln, modeln, n], '.-', c=colors[n],
         label='{:.02f} : {:.02f}'.format(nranges[n], nranges[n + 1]), lw=4, ms=20)
plt.legend()

In [None]:
# Brute-force check of the scalar optimizer: evaluate both model errors on a
# dense grid of alphas. Conventions:
#   additive:       stat_range ≈ stat_all + alpha
#   multiplicative: stat_range ≈ stat_all * alpha
# Bins that binned_statistic left empty (NaN) are ignored.
# NOTE: this rebinds mse_add / mse_mult to 1-D grid curves, replacing the
# per-cell result arrays from the fitting cells.
alphas = np.arange(-10,10,.001)
mse_add = np.nanmean((stat_range - (stat_all + alphas[:, np.newaxis]))**2, axis=1)
mse_mult = np.nanmean((stat_range - stat_all * alphas[:, np.newaxis])**2, axis=1)

In [None]:
# Lowest grid error for each model.
(mse_add.min(), mse_mult.min())

In [None]:
# Error vs. grid index for the additive (first line) and multiplicative (second line) models.
plt.plot(mse_add)
plt.plot(mse_mult)

In [None]:
# Overlay the best grid-search fits on the all-data curve (black):
#   additive:       stat_all + alpha
#   multiplicative: stat_all * alpha
# The multiplicative fit was previously drawn as an offset (stat_all + alpha).
plt.plot(stat_all+alphas[np.argmin(mse_add)], label='AddFit')
plt.plot(stat_all*alphas[np.argmin(mse_mult)], label='MultFit')
plt.plot(stat_all, 'k', label='All_data')
plt.legend()

In [None]:
# Heat map of (best) MSE_add - MSE_mult: rows are movement variables, columns
# are quartile ranges. Red cells are where the multiplicative model has the
# lower error.
fig, axs = plt.subplots(figsize=(5, 5))
crange = np.abs(min_add - min_mult).max()
im = axs.imshow(min_add - min_mult, cmap='seismic', vmin=-crange, vmax=crange, origin='lower')
axs.set(
    yticks=np.arange(0, 4),
    yticklabels=titles,
    ylabel='Movement Model',
    xticks=np.arange(0, 4),
    xticklabels=['.25', '.5', '.75', '1'],
    xlabel='Quantile Range',
    title='$MSE_{add}$ - $MSE_{mult}$',
)
cbar = add_colorbar(im)


In [None]:
# Heat map of the fitted multiplicative gains: rows are movement variables,
# columns are quartile ranges. A gain of 1 means "no modulation", so the
# diverging colormap is centred on 1 rather than 0. The title previously
# (wrongly) read MSE_add - MSE_mult.
alpha_img = np.squeeze(alpha_mult)
if alpha_img.ndim == 3:
    # Per-cell arrays shaped (cells, variables, quartiles): show the selected cell.
    alpha_img = alpha_img[celln]
fig, axs = plt.subplots(figsize=(5,5))
crange = np.nanmax(np.abs(alpha_img-1))
im = axs.imshow(alpha_img,cmap='seismic',vmin=1-crange,vmax=1+crange,origin='lower')
axs.set_yticks(np.arange(0,4))
axs.set_yticklabels(titles)
axs.set_ylabel('Movement Model')
axs.set_xticks(np.arange(0,4))
axs.set_xticklabels(['.25','.5','.75','1'])
axs.set_xlabel('Quantile Range')
axs.set_title(r'$\alpha_{mult}$')
cbar = add_colorbar(im)


In [None]:
##### Explore Neurons Write to pdf #####
# One page per cell: smoothed measured vs. predicted rate, STA, and fitted
# movement weights.
# NOTE(review): m_models and mw_move are not defined in the visible part of this
# notebook; they presumably come from an earlier multi-model fit.
with PdfPages(FigPath/ 'MaxCC_GLM_dt{:03d}.pdf'.format(int(model_dt*1000))) as pdf:
    for ind in tqdm(range(model_nsp.shape[1])):
        fig, axs = plt.subplots(1,3, figsize=((25,6))) #np.floor(7.5*len(model_nsp)).astype(int)
        # Box-car smoothed rates in spikes/s.
        axs[0].plot((np.convolve(msp[ind], np.ones(bin_length), 'same')) / (bin_length * model_dt),'k',label='test FR')
        axs[0].plot((np.convolve(mpred[ind], np.ones(bin_length), 'same')) / (bin_length * model_dt),'r', label='pred FR')
        axs[0].set_xlabel('Frame #')
        axs[0].set_ylabel('Firing Rate (spks/s)')
        axs[0].legend()
        axs[0].set_title('cc={:.2f}, {}, \n lag={:d}'.format(mcc[ind],titles_all[m_models[ind]],lag_list[m_lags[ind]]))
        # Unlike the other STA panels, no symmetric vmin/vmax is set here.
        img = axs[1].imshow(msta[ind],cmap='seismic')
        axs[1].axis('off')
        axs[1].set_title('STA,cell: {:d}'.format(ind))
        add_colorbar(img)
        axs[2].bar(np.arange(mw_move.shape[-1]),mw_move[ind],color='b')
        axs[2].set_xticks(np.arange(mw_move.shape[-1]))
        # NOTE(review): the number of names in `titles` must equal mw_move.shape[-1].
        axs[2].set_xticklabels(titles)
        axs[2].set_ylabel('Movement Weights')
        axs[2].axhline(0, color='grey', linewidth=0.8)
        plt.tight_layout()
        pdf.savefig()
        plt.close()

## Additive GLM

### Testing Temporal components

In [None]:
# Flatten everything after the first axis of train_data into one row per sample.
# NOTE(review): train_data is assigned in the following cell (train_data = train_vid);
# this cell only works if that cell has already been run.
x_train = train_data.reshape(train_data.shape[0],-1)
x_train.shape

In [None]:
# Scratch fit of a single cell with a spatiotemporal (multi-frame) ridge GLM.
train_data = train_vid; test_data=test_vid
lag = 0
celln = 51
nt_glm_lag=5  # number of frames (including the current one) in the temporal filter

# Initialize movement combinations
# NOTE(review): titles lists th, phi, dth, dphi, but move_train/move_test below stack
# columns th, phi, roll, pitch, dth, dphi, so an index into titles does not name
# the same variable as the same index into move_train. perms is set to [] further
# down, so no movement columns are used in this cell.
titles = np.array(['th','phi','dth','dphi']) # 'roll','pitch'
titles_all = []
for n in range(1,len(titles)):
    perms = np.array(list(itertools.combinations(np.arange(len(titles)), n)))
    for ind in range(perms.shape[0]):
        titles_all.append('_'.join([t for t in titles[perms[ind]]]))
        
move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis], train_dth[:,np.newaxis],train_dphi[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis], test_dth[:,np.newaxis],test_dphi[:,np.newaxis]))

lag_list = [ -2, -1, 0 , 1, 2]
lambdas = 1024 * (2**np.arange(0,16))
# (overwritten below from train_data, including the temporal-lag factor)
nks = np.shape(model_vid_sm)[1:]; nk = nks[0]*nks[1]

# Empty movement selection: the fit below uses video + intercept only
perms = []#np.array(list(itertools.combinations(np.arange(len(titles)), 1)))[1]

##### Format data #####
# save shape of train_data for initialization; nk counts pixels x temporal lags
nks = np.shape(train_data)[1:]; nk = nks[0]*nks[1]*nt_glm_lag

# Shift spikes by -lag for GLM fits
sps_train = np.roll(train_nsp[:,celln],-lag)
sps_test = np.roll(test_nsp[:,celln],-lag)

# Initialize saving movement weights 
w_move = np.zeros(move_train.shape[1])

# Take combination of movements (zero columns here because perms is empty)
move_train = move_train[:,perms]
move_test = move_test[:,perms]

# Reshape data (video) into a (T, pixels * nt_glm_lag) array:
# time-shifted copies are stacked oldest first, so the last block is the unshifted
# (current) frame. np.roll wraps, so the first nt_glm_lag-1 rows contain frames
# from the end of the array.
x_train = train_data.reshape(train_data.shape[0],-1)
x_train = np.hstack([np.roll(x_train, nframes, axis=0) for nframes in reversed(range(nt_glm_lag))])
x_train = np.append(x_train, np.ones((x_train.shape[0],1)), axis=1) # append column of ones for fitting intercept
# move_train = np.hstack([np.roll(move_train,nframes, axis=0) for nframes in reversed(range(nt_glm_lag))])
# x_train = np.concatenate((x_train, move_train),axis=1)

x_test = test_data.reshape(test_data.shape[0],-1) 
x_test = np.hstack([np.roll(x_test,nframes, axis=0) for nframes in reversed(range(nt_glm_lag))])
x_test = np.append(x_test,np.ones((x_test.shape[0],1)), axis=1) # append column of ones
# move_test = np.hstack([np.roll(move_test,nframes, axis=0) for nframes in reversed(range(nt_glm_lag))])
# x_test = np.concatenate((x_test, move_test),axis=1)

# Prepare Design Matrix (normal-equation terms)
nlam = len(lambdas)
XXtr = x_train.T @ x_train
XYtr = x_train.T @ sps_train

# Initialze mse traces for regularization cross validation
msetrain = np.zeros((nlam,1))
msetest = np.zeros((nlam,1))
w_ridge = np.zeros((x_train.shape[-1],nlam))
# Inverse matrix for regularization: identity on the nk pixel/lag weights,
# zeros (no penalty) on the intercept and any movement columns
Cinv = np.eye(nk)
Cinv = linalg.block_diag(Cinv,np.zeros((1+move_test.shape[-1], 1+move_test.shape[-1])))
# loop over regularization strength
for l in range(len(lambdas)):  
    # calculate MAP estimate               
    w = np.linalg.solve(XXtr + lambdas[l]*Cinv, XYtr) # equivalent of \ (left divide) in matlab
    w_ridge[:,l] = w
    # calculate test and training mean squared error
    msetrain[l] = np.mean((sps_train - x_train@w)**2)
    msetest[l] = np.mean((sps_test - x_test@w)**2)

In [None]:
# With temporal filter
# Training MSE (first line) and test MSE (second line) per entry of lambdas
plt.plot(msetrain)
plt.plot(msetest)

In [None]:
# Index of the lambda with the lowest test MSE, and its weight vector
best_lambda = np.argmin(msetest)
w = w_ridge[:,best_lambda]

In [None]:
# Show the spatial filter for the current frame. The pixel weights hold
# nt_glm_lag stacked frames (stacked with reversed(range(nt_glm_lag)), so the
# last block is the unshifted frame); reshape with the actual frame shape
# instead of a hard-coded (20, 30).
rf_lags = np.reshape(w[:-(1+move_test.shape[-1])], (nt_glm_lag,)+tuple(nks))
plt.imshow(rf_lags[-1])
plt.colorbar()

In [None]:
# One panel per temporal lag of the fitted filter: oldest frame first,
# current frame (lag 0) last. The figure is saved to FigPath.
rf_stack = np.reshape(w[:-(1+move_test.shape[-1])], (nt_glm_lag,)+nks)
fig, axs = plt.subplots(1, nt_glm_lag, figsize=(20, 5))
for n in range(nt_glm_lag):
    ax = axs[n]
    img = ax.imshow(rf_stack[n], cmap='seismic')
    add_colorbar(img)
    ax.axis('off')
    ax.set_title('Lag:{}'.format(n - nt_glm_lag + 1))
plt.suptitle('Celln:{}'.format(celln), y=.75, fontsize=20)
plt.tight_layout()

fig.savefig(FigPath/'TemporalRF_N{}.png'.format(celln), facecolor='white', transparent=True)


In [None]:
# Trailing movement weights of w.
# NOTE(review): when move_test has 0 columns (perms = []), this is w[-0:], i.e. all of w.
w[-move_test.shape[-1]:]

In [None]:
# Plot the trailing movement weights of w.
# NOTE(review): with 0 movement columns this slice is w[-0:], which plots all of w.
plt.plot(w[-move_test.shape[-1]:])


In [None]:
# Take the lambda with the lowest test MSE and evaluate that fit on the test set.
best_lambda = np.argmin(msetest)
ridge_rf = w_ridge[:, best_lambda]
w = w_ridge[:, best_lambda]
n_tail = 1 + move_test.shape[-1]  # intercept + movement columns after the pixel weights
sta_all = np.reshape(w[:-n_tail], (nt_glm_lag,)+nks)

# Unsmoothed prediction, then boxcar-smoothed rates in spikes/s
sp_pred = x_test @ ridge_rf
box = np.ones(bin_length)
rate_div = bin_length * model_dt
sp_smooth = np.convolve(sps_test, box, 'same') / rate_div
pred_smooth = np.convolve(sp_pred, box, 'same') / rate_div
# Diagnostics on the smoothed traces
err = np.mean((sp_smooth - pred_smooth)**2)
cc = np.corrcoef(sp_smooth, pred_smooth)
cc_all = cc[0, 1]

In [None]:
# Overlay smoothed test rate and smoothed prediction; display the correlation matrix
plt.plot(sp_smooth)
plt.plot(pred_smooth)
cc

In [None]:
@ray.remote
def do_glm_temporal_fit(train_nsp, test_nsp, train_data, test_data, move_train, move_test, celln, perms, lag, lambdas, bin_length=40, model_dt=.1,nt_glm_lag=4):
    """Ridge GLM with a spatiotemporal stimulus filter plus additive movement terms.

    For each sample, the design row holds the flattened frames shifted by
    nt_glm_lag-1, ..., 1, 0 samples along time (current frame last), an
    intercept column, and the movement columns selected by ``perms``. Only the
    pixel/lag weights are L2-penalized; the penalty is chosen by test-set MSE
    over ``lambdas``.

    Args:
        train_nsp, test_nsp: spike arrays indexed ``[:, celln]``, time on axis 0.
        train_data, test_data: stimulus arrays, time on axis 0.
        move_train, move_test: movement regressors, time on axis 0.
        celln: column of the spike arrays to fit.
        perms: movement column indices to include.
        lag: spikes are rolled by ``-lag`` samples before fitting.
        lambdas: candidate ridge penalties.
        bin_length: boxcar width (samples) used to smooth rates for the correlation.
        model_dt: sample duration; smoothed rates are divided by bin_length*model_dt.
        nt_glm_lag: number of frames (including the current one) in the filter.

    Returns:
        cc_all: correlation between smoothed test spikes and smoothed prediction.
        sta_all: filter reshaped to ``(nt_glm_lag,) + frame_shape`` (current frame last).
        sps_test: lag-shifted test spikes.
        sp_pred: unsmoothed test prediction.
        w_move: weights for every movement column (0 where not selected by perms).
    """
    ##### Format data #####
    nks = tuple(np.shape(train_data)[1:])
    # Penalized weights = pixels per frame x temporal lags (must match x_train's width)
    nk = int(np.prod(nks)) * nt_glm_lag

    # Shift spikes by -lag for GLM fits
    sps_train = np.roll(train_nsp[:,celln],-lag)
    sps_test = np.roll(test_nsp[:,celln],-lag)

    # Movement weights for all columns; unselected columns stay 0
    w_move = np.zeros(move_train.shape[1])

    # Take combination of movements
    move_train = move_train[:,perms]
    move_test = move_test[:,perms]
    n_move = move_test.shape[-1]

    def build_design(frames, moves):
        # Flatten frames and stack copies rolled along time (axis 0 only), oldest
        # shift first so the final block is the current frame. np.roll wraps, so
        # the first nt_glm_lag-1 rows see frames from the end of the array.
        flat = frames.reshape(frames.shape[0], -1)
        lagged = np.hstack([np.roll(flat, nframes, axis=0) for nframes in reversed(range(nt_glm_lag))])
        return np.concatenate((lagged, np.ones((flat.shape[0], 1)), moves), axis=1)

    x_train = build_design(train_data, move_train)
    x_test = build_design(test_data, move_test)

    # Normal-equation terms
    XXtr = x_train.T @ x_train
    XYtr = x_train.T @ sps_train

    # Penalize only the pixel/lag weights; intercept and movement columns are unpenalized
    Cinv = linalg.block_diag(np.eye(nk), np.zeros((1+n_move, 1+n_move)))

    msetest = np.zeros((len(lambdas),1))
    w_ridge = np.zeros((x_train.shape[-1],len(lambdas)))
    for l, lam in enumerate(lambdas):
        # MAP estimate (equivalent of \ (left divide) in matlab)
        w = np.linalg.solve(XXtr + lam*Cinv, XYtr)
        w_ridge[:,l] = w
        msetest[l] = np.mean((sps_test - x_test@w)**2)

    # select best cross-validated lambda
    best_lambda = np.argmin(msetest)
    w = w_ridge[:,best_lambda]
    # Weight layout: [pixel/lag weights (nk) | intercept | movement (n_move)]
    sta_all = np.reshape(w[:nk], (nt_glm_lag,)+nks)
    w_move[perms] = w[nk+1:]

    # predicted firing rate and boxcar-smoothed rates (spikes/s)
    sp_pred = x_test@w
    sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    cc_all = np.corrcoef(sp_smooth, pred_smooth)[0,1]

    return cc_all, sta_all, sps_test, sp_pred, w_move

### Parallel Fitting

In [None]:
@ray.remote
def do_glm_fit(train_nsp, test_nsp, train_data, test_data, move_train, move_test, celln, perms, lag, lambdas, bin_length=40, model_dt=.1):
    """Fit one cell with a ridge-regularized linear model: pixels + intercept + movement.

    Each design row is ``[flattened frame | 1 | movement columns in perms]``.
    Only the pixel weights are L2-penalized; the penalty is chosen by test-set
    MSE over ``lambdas``.

    Args:
        train_nsp, test_nsp: spike arrays indexed ``[:, celln]``, time on axis 0.
        train_data, test_data: stimulus arrays, time on axis 0.
        move_train, move_test: movement regressors, time on axis 0.
        celln: column of the spike arrays to fit.
        perms: movement column indices to include.
        lag: spikes are rolled by ``-lag`` samples before fitting.
        lambdas: candidate ridge penalties.
        bin_length: boxcar width (samples) used to smooth rates for the correlation.
        model_dt: sample duration; smoothed rates are divided by bin_length*model_dt.

    Returns:
        cc_all: correlation between smoothed test spikes and smoothed prediction.
        sta_all: pixel weights reshaped to the frame shape.
        sps_test: lag-shifted test spikes.
        sp_pred: unsmoothed test prediction.
        w_move: weights for every movement column (0 where not selected by perms).
    """
    ##### Format data #####
    nks = np.shape(train_data)[1:]
    nk = int(np.prod(nks))  # number of pixel (penalized) weights

    # Shift spikes by -lag for GLM fits
    sps_train = np.roll(train_nsp[:,celln],-lag)
    sps_test = np.roll(test_nsp[:,celln],-lag)

    # Movement weights for all columns; unselected columns stay 0
    w_move = np.zeros(move_train.shape[1])

    # Take combination of movements
    move_train = move_train[:,perms]
    move_test = move_test[:,perms]
    n_move = move_test.shape[-1]

    # Design matrices: [flattened frame | intercept | movement]
    x_train = np.concatenate((train_data.reshape(train_data.shape[0],-1), np.ones((train_data.shape[0],1)), move_train), axis=1)
    x_test = np.concatenate((test_data.reshape(test_data.shape[0],-1), np.ones((test_data.shape[0],1)), move_test), axis=1)

    # Normal-equation terms
    XXtr = x_train.T @ x_train
    XYtr = x_train.T @ sps_train

    # Penalize only pixel weights; intercept and movement columns are unpenalized
    Cinv = linalg.block_diag(np.eye(nk), np.zeros((1+n_move, 1+n_move)))

    msetest = np.zeros((len(lambdas),1))
    w_ridge = np.zeros((x_train.shape[-1],len(lambdas)))
    for l, lam in enumerate(lambdas):
        # MAP estimate (equivalent of \ (left divide) in matlab)
        w = np.linalg.solve(XXtr + lam*Cinv, XYtr)
        w_ridge[:,l] = w
        msetest[l] = np.mean((sps_test - x_test@w)**2)

    # select best cross-validated lambda for RF
    best_lambda = np.argmin(msetest)
    w = w_ridge[:,best_lambda]
    # Slice by explicit offsets (pixels | intercept | movement); negative slicing
    # like w[-n_move:] would return all of w when perms selects no columns.
    sta_all = np.reshape(w[:nk],nks)
    w_move[perms] = w[nk+1:]

    # predicted firing rate and boxcar-smoothed rates (spikes/s)
    sp_pred = x_test@w
    sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    cc_all = np.corrcoef(sp_smooth, pred_smooth)[0,1]

    return cc_all, sta_all, sps_test, sp_pred, w_move

In [None]:
start = time.time()

# Movement variables. The columns of move_train/move_test must follow the same
# order as `titles`, because each combination in `perms` indexes both the labels
# (titles_all) and the movement columns passed to do_glm_fit.
titles = np.array(['th','phi','dth','dphi']) # 'roll','pitch'
move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_dth[:,np.newaxis],train_dphi[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_dth[:,np.newaxis],test_dphi[:,np.newaxis]))

# Labels for every combination of 1..len(titles)-1 movement variables, in the
# same order the jobs are submitted below
titles_all = []
for n in range(1,len(titles)):
    perms = np.array(list(itertools.combinations(np.arange(len(titles)), n)))
    for ind in range(perms.shape[0]):
        titles_all.append('_'.join([t for t in titles[perms[ind]]]))

lag_list = [ -2, -1, 0 , 1, 2]
lambdas = 1024 * (2**np.arange(0,16))
nks = np.shape(model_vid_sm)[1:]; nk = nks[0]*nks[1]

# Put data into shared memory for parallization 
train_nsp_r = ray.put(train_nsp)
test_nsp_r = ray.put(test_nsp)
train_data_r = ray.put(train_vid)
test_data_r = ray.put(test_vid)
move_train_r = ray.put(move_train)
move_test_r = ray.put(move_test)
result_ids = []
# Loop over parameters appending process ids (order: cell -> model -> lag)
for celln in range(train_nsp.shape[1]):
    for n in range(1,len(titles)):
        perms = np.array(list(itertools.combinations(np.arange(len(titles)), n)))
        for ind in range(perms.shape[0]):
            for lag_ind, lag in enumerate(lag_list):    
                result_ids.append(do_glm_fit.remote(train_nsp_r, test_nsp_r, train_data_r, test_data_r, move_train_r, move_test_r, celln, perms[ind], lag, lambdas, model_dt=model_dt))

print('N_proc:', len(result_ids))
results_p = ray.get(result_ids)
print('GLM Add: ', time.time()-start)

In [None]:
##### Gather Data and Find Max CC Model #####
# results_p is in submission order: cell -> movement model -> lag
grid = (model_nsp.shape[1],len(titles_all),len(lag_list),)
cc_all = np.stack([res[0] for res in results_p])
sta_all = np.stack([res[1] for res in results_p])
sp_raw = np.stack([res[2] for res in results_p])
pred_raw = np.stack([res[3] for res in results_p])
w_move_all = np.stack([res[4] for res in results_p])

cc_all = cc_all.reshape(grid + cc_all.shape[1:])
sta_all = sta_all.reshape(grid + sta_all.shape[1:])
sp_raw = sp_raw.reshape(grid + sp_raw.shape[1:])
pred_raw = pred_raw.reshape(grid + pred_raw.shape[1:])
w_move_all = w_move_all.reshape(grid + w_move_all.shape[1:])

# Exactly one best (model, lag) per cell, in cell order. Matching cc == max can
# return several entries on ties, or none when a cell has NaN correlations, which
# would misalign the per-cell arrays below. NaNs are ranked lowest.
n_cells = cc_all.shape[0]
cc_ranked = np.where(np.isnan(cc_all), -np.inf, cc_all).reshape(n_cells, -1)
m_models, m_lags = np.unravel_index(np.argmax(cc_ranked, axis=1), cc_all.shape[1:])
m_cells = np.arange(n_cells)

mcc = cc_all[m_cells,m_models,m_lags]
msta = sta_all[m_cells,m_models,m_lags]
msp = sp_raw[m_cells,m_models,m_lags]
mpred = pred_raw[m_cells,m_models,m_lags]
mw_move = w_move_all[m_cells,m_models,m_lags]

In [None]:
# Collect additive-GLM results and write them to HDF5; the file name encodes
# model_dt in milliseconds (int(model_dt*1000)).
GLM_add = {'cc_all': cc_all,
            'sta_all': sta_all,
            'sp_raw': sp_raw,
            'pred_raw': pred_raw,
            'w_move_all': w_move_all,}
ioh5.save(save_dir/'Add_GLM_Data_notsmooth_dt{:03d}.h5'.format(int(model_dt*1000)), GLM_add)

In [None]:
import plotly.express as px

# Interactive viewer: one animation frame per entry along axis 0 of msta (one per cell)
fig = px.imshow(msta, animation_frame=0, binary_string=False,color_continuous_scale='RdBu_r')
fig.update_layout(width=500,
                  height=500,
                 )
fig.show()

In [None]:
bin_length=40
##### Explore Neurons #####
# Inspect one cell's best additive model: rates, STA and movement weights.
ind = 50
smooth_div = bin_length * model_dt
boxcar = np.ones(bin_length)
test_fr = np.convolve(msp[ind], boxcar, 'same') / smooth_div
pred_fr = np.convolve(mpred[ind], boxcar, 'same') / smooth_div
fig, axs = plt.subplots(1, 3, figsize=(25, 6))
ax_rate, ax_sta, ax_move = axs
ax_rate.plot(test_fr, 'k', label='test FR')
ax_rate.plot(pred_fr, 'r', label='pred FR')
ax_rate.set_xlabel('Frame #')
ax_rate.set_ylabel('Firing Rate (spks/s)')
ax_rate.legend()
ax_rate.set_title('cc={:.2f}, {}, \n lag={:d}'.format(mcc[ind], titles_all[m_models[ind]], lag_list[m_lags[ind]]))
img = ax_sta.imshow(msta[ind], cmap='seismic')
ax_sta.set_title('STA,cell: {:d}'.format(ind))
ax_sta.axis('off')
add_colorbar(img)
move_pos = np.arange(mw_move.shape[-1])
ax_move.bar(move_pos, mw_move[ind], color='b')
ax_move.set_xticks(move_pos)
ax_move.set_xticklabels(titles)
ax_move.axhline(0, color='grey', linewidth=0.8)

plt.tight_layout()


In [None]:
##### Explore Neurons Write to pdf #####
# Write one page per cell summarizing its best additive model.
out_pdf = FigPath / 'MaxCC_GLM_dt{:03d}.pdf'.format(int(model_dt*1000))
kernel = np.ones(bin_length)
norm = bin_length * model_dt
with PdfPages(out_pdf) as pdf:
    for ind in tqdm(range(model_nsp.shape[1])):
        fig, axs = plt.subplots(1, 3, figsize=(25, 6))
        rate_ax, sta_ax, move_ax = axs
        rate_ax.plot(np.convolve(msp[ind], kernel, 'same') / norm, 'k', label='test FR')
        rate_ax.plot(np.convolve(mpred[ind], kernel, 'same') / norm, 'r', label='pred FR')
        rate_ax.set_xlabel('Frame #')
        rate_ax.set_ylabel('Firing Rate (spks/s)')
        rate_ax.legend()
        rate_ax.set_title('cc={:.2f}, {}, \n lag={:d}'.format(mcc[ind], titles_all[m_models[ind]], lag_list[m_lags[ind]]))
        img = sta_ax.imshow(msta[ind], cmap='seismic')
        sta_ax.axis('off')
        sta_ax.set_title('STA,cell: {:d}'.format(ind))
        add_colorbar(img)
        xpos = np.arange(mw_move.shape[-1])
        move_ax.bar(xpos, mw_move[ind], color='b')
        move_ax.set_xticks(xpos)
        move_ax.set_xticklabels(titles)
        move_ax.set_ylabel('Movement Weights')
        move_ax.axhline(0, color='grey', linewidth=0.8)
        plt.tight_layout()
        pdf.savefig()
        plt.close()

In [None]:
##### Plot cell model with lags #####
# For one cell and movement model, show the smoothed fit (top) and STA (bottom) at every lag.
celln=33
model_ind=0
# The gather cell stores unsmoothed traces (sp_raw/pred_raw) indexed
# (cell, model, lag, time); smooth them here the same way as the PDF summaries.
kernel = np.ones(bin_length)
rate_div = bin_length * model_dt
fig, axs = plt.subplots(2,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),10))
for lag_ind, lag in enumerate(lag_list):
    sp_trace = np.convolve(sp_raw[celln,model_ind,lag_ind], kernel, 'same') / rate_div
    pred_trace = np.convolve(pred_raw[celln,model_ind,lag_ind], kernel, 'same') / rate_div
    axs[0,lag_ind].plot(sp_trace,'k',label='smoothed FR')
    axs[0,lag_ind].plot(pred_trace,'r', label='pred FR')
    axs[0,lag_ind].set_title('cc={:.2f}'.format(cc_all[celln,model_ind,lag_ind]))
    axs[1,lag_ind].imshow(sta_all[celln,model_ind,lag_ind])
    axs[1,lag_ind].set_title('lag={:d}'.format(lag))
    axs[1,lag_ind].axis('off')
plt.suptitle(titles_all[model_ind])
plt.tight_layout()

In [None]:
# NOTE(review): after the gather cell, cc_all has shape (cells, models, lags), so this
# (n_units, n_lags) reshape only works with a single model; n_units is not defined
# in this part of the notebook — confirm where it comes from.
cc_all.reshape(n_units,len(lag_list))

In [None]:
# figure of receptive fields: one row per unit, one column per lag
# NOTE(review): this indexes sta_all as (cell, lag, y, x) and cc_all as (cell, lag)
# and lays out 6 subplot columns, while the gather cells above produce arrays with a
# model axis and lag_list has 5 entries. It looks like an older result layout —
# verify shapes before running.
fig = plt.figure(figsize=(25,256),dpi=50)
for celln in tqdm(range(n_units)):
    for lag_ind, lag in enumerate(lag_list):
        # symmetric color limits per cell, shared across its lags
        crange = np.max(np.abs(sta_all[celln,:,:,:]))
        plt.subplot(n_units,6,(celln*6)+lag_ind + 1)  
        plt.imshow(sta_all[celln, lag_ind, :, :], vmin=-crange, vmax=crange, cmap='jet')
        plt.title('cc={:.2f}'.format (cc_all[celln,lag_ind]),fontsize=5)
        plt.axis('off')
plt.tight_layout()

In [None]:
# Save the current figure object (fig) to save_dir
fig.savefig(save_dir/'STA1_5.pdf')

## Multiplicative GLM

In [None]:
@ray.remote
def do_glm_fit_mult(train_nsp, test_nsp, train_data, test_data, move_train, move_test, celln, perms, lag, lambdas, alpha, bin_length=40, model_dt=.1):
    """Fit one cell with a ridge model whose pixel regressors are gain-modulated by movement.

    Every flattened frame at sample t is multiplied by ``1 + alpha * m(t)``, where
    ``m`` is the single movement column selected by ``perms``; an intercept is
    appended and only pixel weights are L2-penalized. The penalty is chosen by
    test-set MSE over ``lambdas``.

    Returns:
        cc_all: correlation between smoothed test spikes and smoothed prediction.
        sta_all: pixel weights reshaped to the frame shape.
        sps_test: lag-shifted test spikes.
        sp_pred: unsmoothed test prediction.
        msetest at the chosen lambda (length-1 array).

    Raises:
        ValueError: if ``perms`` does not select exactly one movement column
            (the per-sample gain broadcast requires a single column).
    """
    sps_train = np.roll(train_nsp[:,celln],-lag)
    sps_test = np.roll(test_nsp[:,celln],-lag)
    move_train = move_train[:,perms]
    move_test = move_test[:,perms]
    if move_train.ndim != 2 or move_train.shape[1] != 1:
        raise ValueError('do_glm_fit_mult expects exactly one movement column, got perms={}'.format(perms))
    nks = np.shape(train_data)[1:]
    nk = int(np.prod(nks))  # number of pixel (penalized) weights

    # (T, n_pix) * (T, 1): per-sample gain applied to every pixel, then intercept column
    x_train = train_data.reshape(train_data.shape[0],-1)*(1 + alpha*move_train)
    x_train = np.append(x_train, np.ones((x_train.shape[0],1)), axis = 1)

    x_test = test_data.reshape(test_data.shape[0],-1)*(1 + alpha*move_test)
    x_test = np.append(x_test,np.ones((x_test.shape[0],1)), axis = 1)

    # Penalize pixel weights only; the intercept is unpenalized
    Cinv = linalg.block_diag(np.eye(nk), np.zeros((1,1)))
    XXtr = x_train.T @ x_train
    XYtr = x_train.T @ sps_train
    msetest = np.zeros((len(lambdas),1))
    w_ridge = np.zeros((x_train.shape[-1],len(lambdas)))

    for l, lam in enumerate(lambdas):
        # MAP estimate (equivalent of \ (left divide) in matlab)
        w = np.linalg.solve(XXtr + lam*Cinv, XYtr)
        w_ridge[:,l] = w
        msetest[l] = np.mean((sps_test - x_test@w)**2)

    # select best cross-validated lambda for RF
    best_lambda = np.argmin(msetest)
    w = w_ridge[:,best_lambda]
    # Weight layout: [pixels (nk) | intercept]; there are no separate movement weights
    sta_all = np.reshape(w[:nk],nks)

    # predicted firing rate and boxcar-smoothed rates (spikes/s)
    sp_pred = x_test@w
    sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
    cc_all = np.corrcoef(sp_smooth, pred_smooth)[0,1]
    return cc_all, sta_all, sps_test, sp_pred, msetest[best_lambda]

In [None]:
start = time.time()
titles = np.array(['th','phi','roll','pitch','dth','dphi'])
# Movement regressors (columns th, phi, roll, pitch); perms index these columns
move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis])) # ,train_dth[:,np.newaxis],train_dphi[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis])) # ,test_dth[:,np.newaxis],test_dphi[:,np.newaxis]))

# Single-variable models only (the multiplicative gain takes one movement column).
# Labels and jobs are both built from move_cols so results reshape consistently
# to len(titles_all) models and no job indexes a column move_train lacks.
move_cols = np.arange(move_train.shape[1])
titles_all = []
for n in range(1,2):
    perms = np.array(list(itertools.combinations(move_cols, n)))
    for ind in range(perms.shape[0]):
        titles_all.append('_'.join([t for t in titles[perms[ind]]]))

lag_list = [ -4, -2, 0 , 2, 4]
lambdas = 1024 * (2**np.arange(0,16))
alpha_list = np.arange(-2,2+.5,.5)
nks = np.shape(train_vid)[1:]; nk = nks[0]*nks[1]

train_nsp_r = ray.put(train_nsp)
test_nsp_r = ray.put(test_nsp)
train_data_r = ray.put(train_vid)
test_data_r = ray.put(test_vid)
move_train_r = ray.put(move_train)
move_test_r = ray.put(move_test)
result_ids = []
# Submission order: cell -> model -> lag -> alpha (the gather cell relies on it)
for celln in range(train_nsp.shape[1]):
    for n in range(1,2):
        perms = np.array(list(itertools.combinations(move_cols, n)))
        for ind in range(perms.shape[0]):
            for lag_ind, lag in enumerate(lag_list):
                for alpha in alpha_list:
                    result_ids.append(do_glm_fit_mult.remote(train_nsp_r, test_nsp_r, train_data_r, test_data_r, move_train_r, move_test_r, celln, perms[ind], lag, lambdas, alpha, model_dt=model_dt))

print('N_proc:', len(result_ids))
results_p = ray.get(result_ids)                
print('GLM Mult: ', time.time()-start)

In [None]:
##### Gather Data and Find Max CC Model #####
# results_p is in submission order: cell -> model -> lag -> alpha
grid = (model_nsp.shape[1],len(titles_all),len(lag_list),len(alpha_list))
cc_all = np.stack([res[0] for res in results_p])
sta_all = np.stack([res[1] for res in results_p])
sp_raw = np.stack([res[2] for res in results_p])
pred_raw = np.stack([res[3] for res in results_p])
mse_test_all = np.stack([res[4] for res in results_p])

cc_all = cc_all.reshape(grid + cc_all.shape[1:])
sta_all = sta_all.reshape(grid + sta_all.shape[1:])
sp_raw = sp_raw.reshape(grid + sp_raw.shape[1:])
pred_raw = pred_raw.reshape(grid + pred_raw.shape[1:])
mse_test_all = mse_test_all.reshape(grid + mse_test_all.shape[1:])

# Exactly one best (model, lag, alpha) per cell, in cell order; ties resolve to
# the first maximum and NaN correlations are ranked lowest so no cell is dropped.
n_cells = cc_all.shape[0]
cc_ranked = np.where(np.isnan(cc_all), -np.inf, cc_all).reshape(n_cells, -1)
m_models, m_lags, m_alphas = np.unravel_index(np.argmax(cc_ranked, axis=1), cc_all.shape[1:])
m_cells = np.arange(n_cells)

mcc = cc_all[m_cells,m_models,m_lags,m_alphas]
msta = sta_all[m_cells,m_models,m_lags,m_alphas]
msp = sp_raw[m_cells,m_models,m_lags,m_alphas]
mpred = pred_raw[m_cells,m_models,m_lags,m_alphas]
mmsetest = mse_test_all[m_cells,m_models,m_lags,m_alphas]

In [None]:
# Collect multiplicative-GLM results and write them to HDF5. The file name encodes
# |min alpha|, max alpha (as ints) and model_dt in milliseconds.
GLM_mult = {'cc_all': cc_all,
            'sta_all': sta_all,
            'sp_raw': sp_raw,
            'pred_raw': pred_raw,
            'mse_test_all': mse_test_all,}
ioh5.save(save_dir/'Mult_GLM_Data_alpha_{:d}_{:d}_notsmooth_dt{:03d}.h5'.format(int(np.abs(np.min(alpha_list))),int(np.max(alpha_list)),int(model_dt*1000)), GLM_mult)

In [None]:
##### Explore Neurons Write to pdf #####
# One page per cell for its best multiplicative model: smoothed rates and spatial RF.
out_pdf = FigPath / 'MaxCC_GLM_mult_dt{:03d}.pdf'.format(int(model_dt*1000))
kernel = np.ones(bin_length)
rate_div = bin_length * model_dt
with PdfPages(out_pdf) as pdf:
    for ind in tqdm(range(model_nsp.shape[1])):
        fig, axs = plt.subplots(1, 2, figsize=(15, 6))
        ax_fr, ax_rf = axs
        ax_fr.plot(np.convolve(msp[ind], kernel, 'same') / rate_div, 'k', label='test FR')
        ax_fr.plot(np.convolve(mpred[ind], kernel, 'same') / rate_div, 'r', label='pred FR')
        ax_fr.legend()
        ax_fr.set_xlabel('Frame #')
        ax_fr.set_ylabel('Firing Rate (spks/s)')
        ax_fr.set_title('cc={:.2f}, {}, \n lag={:d}, alpha={:.2f}'.format(mcc[ind], titles_all[m_models[ind]], lag_list[m_lags[ind]], alpha_list[m_alphas[ind]]))
        img = ax_rf.imshow(msta[ind], cmap='seismic')
        ax_rf.set_title('STA,cell: {:d}'.format(ind))
        ax_rf.axis('off')
        add_colorbar(img)
        plt.tight_layout()
        pdf.savefig()
        plt.close()

To do:

- Add a temporal component to the model.
- Fit dth and dphi at a shorter timescale (model_dt = .025).
- Create a movie of the world-camera video with animated traces.

In [None]:
# Display the test_vid array in the notebook
test_vid

In [None]:
# Same alpha grid as the multiplicative fit; the Mult file name below is built from it
alpha_list = np.arange(-2,2+.5,.5)


## Comparison of Models

In [None]:
# Reload saved fits. The names must match the save cells, which append
# '_dt{model_dt in ms:03d}' to both file names.
GLM_add = ioh5.load(save_dir/'Add_GLM_Data_notsmooth_dt{:03d}.h5'.format(int(model_dt*1000)))
GLM_mult = ioh5.load(save_dir/'Mult_GLM_Data_alpha_{:d}_{:d}_notsmooth_dt{:03d}.h5'.format(int(np.abs(np.min(alpha_list))),int(np.max(alpha_list)),int(model_dt*1000)))


In [None]:
##### GLM Mult #####
# One best (model, lag, alpha) per cell, in cell order. argmax returns the first
# maximum in C order (the first matching entry per cell) without relying on set
# iteration order; NaN correlations are ranked lowest so no cell is dropped.
cc_mult = GLM_mult['cc_all']
n_cells_mult = cc_mult.shape[0]
cc_mult_ranked = np.where(np.isnan(cc_mult), -np.inf, cc_mult).reshape(n_cells_mult, -1)
m_models_mult, m_lags_mult, m_alphas_mult = np.unravel_index(np.argmax(cc_mult_ranked, axis=1), cc_mult.shape[1:])
m_cells_mult = np.arange(n_cells_mult)

# The save cell writes unsmoothed traces under 'sp_raw'/'pred_raw'; fall back to
# 'sp_smooth'/'pred_smooth' for files that use those keys.
sp_key_mult = 'sp_raw' if 'sp_raw' in GLM_mult else 'sp_smooth'
pred_key_mult = 'pred_raw' if 'pred_raw' in GLM_mult else 'pred_smooth'

mcc_mult = GLM_mult['cc_all'][m_cells_mult,m_models_mult,m_lags_mult,m_alphas_mult]
msta_mult = GLM_mult['sta_all'][m_cells_mult,m_models_mult,m_lags_mult,m_alphas_mult]
msp_mult = GLM_mult[sp_key_mult][m_cells_mult,m_models_mult,m_lags_mult,m_alphas_mult]
mpred_mult = GLM_mult[pred_key_mult][m_cells_mult,m_models_mult,m_lags_mult,m_alphas_mult]
mmsetest_mult = GLM_mult['mse_test_all'][m_cells_mult,m_models_mult,m_lags_mult,m_alphas_mult]

In [None]:
##### GLM Add #####
# One best (model, lag) per cell, in cell order, so these arrays line up with the
# multiplicative ones; NaN correlations are ranked lowest.
cc_add = GLM_add['cc_all']
n_cells_add = cc_add.shape[0]
cc_add_ranked = np.where(np.isnan(cc_add), -np.inf, cc_add).reshape(n_cells_add, -1)
m_models_add, m_lags_add = np.unravel_index(np.argmax(cc_add_ranked, axis=1), cc_add.shape[1:])
m_cells_add = np.arange(n_cells_add)

# The save cell writes unsmoothed traces under 'sp_raw'/'pred_raw'; fall back to
# 'sp_smooth'/'pred_smooth' for files that use those keys.
sp_key_add = 'sp_raw' if 'sp_raw' in GLM_add else 'sp_smooth'
pred_key_add = 'pred_raw' if 'pred_raw' in GLM_add else 'pred_smooth'

mcc_add = GLM_add['cc_all'][m_cells_add,m_models_add,m_lags_add]
msta_add = GLM_add['sta_all'][m_cells_add,m_models_add,m_lags_add]
msp_add = GLM_add[sp_key_add][m_cells_add,m_models_add,m_lags_add]
mpred_add = GLM_add[pred_key_add][m_cells_add,m_models_add,m_lags_add]
mw_move_add = GLM_add['w_move_all'][m_cells_add,m_models_add,m_lags_add]

In [None]:
# Histogram of each cell's best lag *index* (into lag_list), not the lag value
plt.hist(m_lags_mult)

Idea: model each neuron's response as the product of the velocity and the spatial gradient. This would indicate which terms in the regression model should carry weight.

9/14: potential meeting with Cris, James, and me.

In [None]:
# NOTE(review): this lag_list matches the multiplicative fits; the additive submit
# cell used its own lag_list, and m_lags_add holds indices into that list.
# The histogram below shows lag indices and does not use lag_list.
lag_list = [ -4, -2, 0 , 2, 4]
plt.hist(m_lags_add)

In [None]:
# Test MSE at each cell's best (model, lag) for *every* alpha (the alpha axis is
# not indexed), giving one MSE-vs-alpha curve per cell
mmsetest_mult = GLM_mult['mse_test_all'][m_cells_mult,m_models_mult,m_lags_mult,]

In [None]:
# Display the alpha grid
alpha_list

In [None]:
# Histogram of the alpha index (into alpha_list) with the lowest test MSE per cell
plt.hist(np.argmin(mmsetest_mult,axis=1))

In [None]:
# One curve per cell: test MSE vs alpha, mean-subtracted so curves share a baseline
for mse_curve in mmsetest_mult:
    plt.plot(alpha_list, mse_curve - np.mean(mse_curve))

In [None]:
# Best correlation per cell: additive (x) vs multiplicative (y).
# Points above the identity line have a higher CC under the multiplicative model.
fig, axs = plt.subplots(1, 1, figsize=(5, 5))
unity = [0, 1]
axs.plot(unity, unity, 'k')
axs.scatter(mcc_add, mcc_mult)
axs.set(xlabel='CC Add', ylabel='CC Mult.')
axs.set_aspect('equal')

In [None]:
# Predicted (x) vs actual (y) test response for cell index 21, additive model
plt.scatter(mpred_add[21],msp_add[21], alpha=.1)

In [None]:
# Already imported in the first cell of the notebook; re-importing is harmless
from scipy.stats import binned_statistic

In [None]:
# Empirical output nonlinearity: mean actual response within bins of predicted
# response (binned_statistic default of 10 bins), one scatter series per cell.
fig, ax = plt.subplots(1)
for n in range(mpred_add.shape[0]):
    stat, edges, _ = binned_statistic(mpred_add[n], msp_add[n], statistic='mean')
    edge_mids = (edges[:-1] + edges[1:]) / 2
    ax.scatter(edge_mids, stat)
    

Restrict the data ranges of th and phi, and of pitch and roll.

# GLM on PCA of population

In [None]:
from sklearn.decomposition import PCA

# Full PCA to find how many components explain more than 90% of population variance
pca = PCA()
pcs = pca.fit_transform(model_nsp)
cum_var = np.cumsum(pca.explained_variance_ratio_)
plt.plot(cum_var)
# cum_var[i] is the variance explained by the first i+1 components, so the first
# index above 0.9 corresponds to (index + 1) components.
first_idx = int(np.where(cum_var>.9)[0][0])
comp_to_keep = first_idx + 1
plt.axvline(x=first_idx)
pca = PCA(n_components=comp_to_keep)
pcs = pca.fit_transform(model_nsp)
print('keep {} PCs'.format(comp_to_keep))

In [None]:
# Split PC time series into train/test rows using train_idx/test_idx
# (presumably the same indices used for the video split — confirm)
train_pcs = pcs[train_idx]
test_pcs = pcs[test_idx]

In [None]:
start = time.time()

# Movement variables; the columns of move_train/move_test follow this order
titles = np.array(['th','phi','roll','pitch'])
# All non-empty combinations of the 4 movement variables (sizes 1..4), in job order
move_combos = [np.array(combo) for size in range(1,5) for combo in itertools.combinations([0,1,2,3], size)]
titles_all = ['_'.join(list(titles[combo])) for combo in move_combos]

move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis])) # ,train_dth[:,np.newaxis],train_dphi[:,np.newaxis]))
move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis])) # ,test_dth[:,np.newaxis],test_dphi[:,np.newaxis]))

lag_list = [ -4, -2, 0 , 2, 4]
lambdas = 1024 * (2**np.arange(0,16))

# Shared-memory copies for the workers (PCs take the place of spike counts)
train_nsp_r = ray.put(train_pcs)
test_nsp_r = ray.put(test_pcs)
train_data_r = ray.put(train_vid)
test_data_r = ray.put(test_vid)
move_train_r = ray.put(move_train)
move_test_r = ray.put(move_test)

# One job per (PC, movement combination, lag), in that nesting order
result_ids = [
    do_glm_fit.remote(train_nsp_r, test_nsp_r, train_data_r, test_data_r, move_train_r, move_test_r, celln, combo, lag, lambdas)
    for celln in range(train_pcs.shape[1])
    for combo in move_combos
    for lag in lag_list
]

results_p = ray.get(result_ids)
print('GLM: ', time.time()-start)

In [None]:
##### Gather Data and Find Max CC Model #####
# results_p is in submission order: PC -> movement model -> lag.
# NOTE: despite the names, sp_smooth/pred_smooth hold the unsmoothed test traces
# returned by do_glm_fit (sps_test, sp_pred).
grid = (pcs.shape[1],len(titles_all),len(lag_list),)
cc_all = np.stack([res[0] for res in results_p])
sta_all = np.stack([res[1] for res in results_p])
sp_smooth = np.stack([res[2] for res in results_p])
pred_smooth = np.stack([res[3] for res in results_p])
w_move_all = np.stack([res[4] for res in results_p])

cc_all = cc_all.reshape(grid + cc_all.shape[1:])
sta_all = sta_all.reshape(grid + sta_all.shape[1:])
sp_smooth = sp_smooth.reshape(grid + sp_smooth.shape[1:])
pred_smooth = pred_smooth.reshape(grid + pred_smooth.shape[1:])
w_move_all = w_move_all.reshape(grid + w_move_all.shape[1:])

# Exactly one best (model, lag) per PC, in PC order; NaN correlations ranked lowest
n_pcs = cc_all.shape[0]
cc_ranked = np.where(np.isnan(cc_all), -np.inf, cc_all).reshape(n_pcs, -1)
m_models, m_lags = np.unravel_index(np.argmax(cc_ranked, axis=1), cc_all.shape[1:])
m_cells = np.arange(n_pcs)

mcc = cc_all[m_cells,m_models,m_lags]
msta = sta_all[m_cells,m_models,m_lags]
msp = sp_smooth[m_cells,m_models,m_lags]
mpred = pred_smooth[m_cells,m_models,m_lags]
mw_move = w_move_all[m_cells,m_models,m_lags]

In [None]:
# Collect the PCA-based GLM fits and save them (the dict built here, not GLM_add)
GLM_pcs = {'cc_all': cc_all,
            'sta_all': sta_all,
            'sp_smooth': sp_smooth,
            'pred_smooth': pred_smooth,
            'w_move_all': w_move_all,}
ioh5.save(save_dir/'Add_GLM_PCs_Data.h5', GLM_pcs)

In [None]:
import plotly.express as px

# Interactive viewer of the best-model STAs: axis 0 of msta becomes the animation slider.
fig = px.imshow(
    msta,
    animation_frame=0,
    binary_string=False,
    color_continuous_scale='RdBu_r',
)
fig.update_layout(width=500, height=500)
fig.show()

In [None]:
##### Explore Neurons Write to pdf #####
# One PDF page per PC, using the max-CC (model, lag) selected for it: test vs predicted
# trace, the STA image, and the movement weights as a bar chart.
with PdfPages(FigPath / 'MaxCC_pcaSpikes.pdf') as pdf:
    for ind in tqdm(range(pcs.shape[1])):
        fig, axs = plt.subplots(1, 3, figsize=(25, 6))

        axs[0].plot(msp[ind], 'k', label='test pc{:d}'.format(ind))
        axs[0].plot(mpred[ind], 'r', label='pred pc{:d}'.format(ind))
        axs[0].legend()
        axs[0].set_title('cc={:.2f}, {}, \n lag={:d}'.format(mcc[ind], titles_all[m_models[ind]], lag_list[m_lags[ind]]))

        img = axs[1].imshow(msta[ind], cmap='seismic')
        axs[1].set_title('STA, PC: {:d}'.format(ind))
        axs[1].axis('off')
        add_colorbar(img)

        # NOTE(review): assumes len(titles) == mw_move.shape[-1]; set_xticklabels needs a match.
        move_x = np.arange(mw_move.shape[-1])
        axs[2].bar(move_x, mw_move[ind], color='b')
        axs[2].set_xticks(move_x)
        axs[2].set_xticklabels(titles)
        axs[2].axhline(0, color='grey', linewidth=0.8)

        fig.tight_layout()
        # Save and close this iteration's figure explicitly, so a helper that changes
        # pyplot's "current figure" cannot redirect what lands in the PDF.
        pdf.savefig(fig)
        plt.close(fig)

Idea: use the spatial gradient as an input to the GLM, multiplied by dth/dphi/dgaze.

In [None]:
# Export frame t of model_vid_sm as a small, axis-free grayscale PNG.
t = 2004
frame_fig, frame_ax = plt.subplots(figsize=(2, 2))
frame_ax.imshow(model_vid_sm[t], cmap='gray')
frame_ax.axis('off')
frame_fig.tight_layout()
frame_fig.savefig(FigPath/'ExampleFrame.png', facecolor='white', transparent=True)

# Sequential GLM

In [None]:
# ##### Train_Test Split with sklearn #####
# model_vid = model_vid_sm
# model_dt = .1
# nks = np.shape(model_vid)[1:]; nk = nks[0]*nks[1]
# nT = np.shape(pcs)[0]
# x = model_vid.reshape(pcs.shape[0], -1).copy()
# # image dimensions
# n_units = np.shape(pcs)[1]

# titles = np.array(['th','phi','roll','pitch'])
# move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis],train_dth[:,np.newaxis],train_dphi[:,np.newaxis]))
# move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis],test_dth[:,np.newaxis],test_dphi[:,np.newaxis]))

# perms = np.array([0,1,2,3]) #np.array(list(itertools.combinations([0,1,2,3], n)))
# move_train = move_train[:,perms]
# move_test = move_test[:,perms]

# # set up prior matrix (regularizer)
# # L2 prior
# Imat = np.eye(nk)
# Imat = linalg.block_diag(Imat,np.zeros((1+move_test.shape[-1],1+move_test.shape[-1])))
# # smoothness prior
# consecutive = np.ones((nk, 1))
# consecutive[nks[1]-1::nks[1]] = 0
# diff = np.zeros((1,2))
# diff[0,0] = -1
# diff[0,1]= 1
# Dxx = sparse.diags((consecutive @ diff).T, np.array([0, 1]), (nk-1,nk))
# Dxy = sparse.diags((np.ones((nk,1))@ diff).T, np.array([0, nks[1]]), (nk-nks[1], nk))
# Dx = Dxx.T @ Dxx + Dxy.T @ Dxy
# D  = linalg.block_diag(Dx.toarray(),np.zeros((1+move_test.shape[-1],1+move_test.shape[-1])))   
# # summed prior matrix
# # Cinv = D + Imat
# Cinv = Imat

# lag_list = [ -4, -2, 0 , 2, 4]
# lambdas = 1024 * (2**np.arange(0,16))
# nlam = len(lambdas)
# # set up empty arrays for receptive field and cross correlation
# sta_all = np.zeros((n_units, len(lag_list), nks[0], nks[1]))
# cc_all = np.zeros((n_units,len(lag_list)))

# celln = 1
# fig, axs = plt.subplots(2,len(lag_list), figsize=(np.floor(7.5*len(lag_list)).astype(int),10))
# for lag_ind, lag in enumerate(lag_list):
    
#     sps_train = np.roll(train_pcs[:,celln],-lag)
#     sps_test = np.roll(test_pcs[:,celln],-lag)
    
   
#     #calculate a few terms
#     x_train = train_vid.reshape(train_vid.shape[0],-1)
#     x_train = np.append(x_train, np.ones((x_train.shape[0],1)), axis = 1) # append column of ones
#     x_train = np.concatenate((x_train,move_train),axis=1)

#     x_test = test_vid.reshape(test_vid.shape[0],-1)
#     x_test = np.append(x_test,np.ones((x_test.shape[0],1)), axis = 1) # append column of ones
#     x_test = np.concatenate((x_test,move_test),axis=1)
    
#     XXtr = x_train.T @ x_train
#     XYtr = x_train.T @ sps_train
    
#     msetrain = np.zeros((nlam,1))
#     msetest = np.zeros((nlam,1))
#     w_ridge = np.zeros((nk+1+move_test.shape[1],nlam))
#     # initial guess
#     # loop over regularization strength
#     for l in range(len(lambdas)):  
#         # calculate MAP estimate               
#         w = np.linalg.solve(XXtr + lambdas[l]*Cinv, XYtr) # equivalent of \ (left divide) in matlab
#         w_ridge[:,l] = w
#         # calculate test and training rms error
#         msetrain[l] = np.mean((sps_train - x_train@w)**2)
#         msetest[l] = np.mean((sps_test - x_test@w)**2)
#     # select best cross-validated lambda for RF
#     best_lambda = np.argmin(msetest)
#     w = w_ridge[:,best_lambda]
#     ridge_rf = w_ridge[:,best_lambda]
#     sta_all[celln,lag_ind,:,:] = np.reshape(w[:-(1+move_test.shape[-1])],nks)
#     # plot predicted vs actual firing rate
#     # predicted firing rate
#     sp_pred = x_test@ridge_rf
#     # bin the firing rate to get smooth rate vs time
#     bin_length = 40
#     sp_smooth = (np.convolve(sps_test, np.ones(bin_length), 'same')) / (bin_length * model_dt)
#     pred_smooth = (np.convolve(sp_pred, np.ones(bin_length), 'same')) / (bin_length * model_dt)
#     # a few diagnostics
#     err = np.mean((sp_smooth-pred_smooth)**2)
#     cc = np.corrcoef(sp_smooth, pred_smooth)
#     cc_all[celln,lag_ind] = cc[0,1]

#     axs[0,lag_ind].plot(sp_smooth,'k',label='smoothed FR')
#     axs[0,lag_ind].plot(pred_smooth,'r', label='pred FR')
#     axs[0,lag_ind].set_title('cc={:.2f}'.format(cc_all[celln,lag_ind]))
#     axs[1,lag_ind].imshow(sta_all[celln,lag_ind])
#     axs[1,lag_ind].set_title('lag={:d}'.format(lag_list[lag_ind]))
#     axs[1,lag_ind].axis('off')
#     plt.suptitle('GLM on PCA spikes')
#     plt.tight_layout()