# Imports

In [2]:
# Notebook setup: autoreload for edited project modules, all imports, torch device,
# project import paths, pandas display options and a Ray runtime.
%load_ext autoreload
%autoreload 2

import os
import argparse
import glob
import sys 
import yaml 
# NOTE(review): glob is imported a second time here; redundant but harmless.
import glob
import h5py 
import ray
import logging 
import json
import gc
import cv2
import time
import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import xarray as xr
import scipy.linalg as linalg
import scipy.sparse as sparse
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches

from tqdm.auto import tqdm, trange
from matplotlib.backends.backend_pdf import PdfPages
from scipy import signal
from pathlib import Path
from scipy.optimize import minimize_scalar,minimize
from scipy.interpolate import interp1d
from scipy.ndimage import shift as imshift
from sklearn.model_selection import train_test_split, GroupShuffleSplit
from sklearn.pipeline import make_pipeline
from sklearn import linear_model as lm 
from scipy.stats import binned_statistic
from sklearn.utils import shuffle
from sklearn.metrics import r2_score, mean_poisson_deviance
from pyglmnet import GLMCV, GLM

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
# Let cuDNN benchmark convolution algorithms and reuse the fastest one per input shape.
torch.backends.cudnn.benchmark = True
# Use the GPU when one is available; later cells move tensors and models to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Make the project repository and the notebook's own directory importable.
# NOTE(review): hardcoded absolute path; consider a configurable repo root.
sys.path.append('/home/seuss/Research/MyRepos/NonLinearMixedSel_FreelyMoving/')
sys.path.append(str(Path('.').absolute()))
# NOTE(review): the star imports supply helpers used later (check_path, load_train_test,
# PoissonGLM_VM_staticreg, EarlyStopping, ...). Which module defines which is not visible
# here, and they may shadow names imported above; explicit imports would be clearer.
from utils import *
import io_dict_to_hdf5 as ioh5
from format_data import *

# Display every row of pandas objects.
pd.set_option('display.max_rows', None)

# Start (or reuse, via ignore_reinit_error) a Ray runtime that only logs errors.
# NOTE(review): ray is not used by any cell shown here.
ray.init(
    ignore_reinit_error=True,
    logging_level=logging.ERROR,
)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Which session to analyse and where its inputs/outputs/figures live.
free_move = True
stim_type = 'fm1' if free_move else 'hf1_wn'
# Other sessions used before: '012821/EE8P6LT', '062921/G6HCK1ALTRN'; note "128: 070921/J553RT".
date_ani = '070921/J553RT'

# check_path comes from the project star imports.
data_dir = Path('~/Goeppert/freely_moving_ephys/ephys_recordings/').expanduser() / date_ani / stim_type
save_dir = check_path(Path('~/Research/SensoryMotorPred_Data/data/').expanduser() / date_ani, stim_type)
FigPath = check_path(
    check_path(Path('~/Research/SensoryMotorPred_Data').expanduser(), 'Figures/Encoding') / date_ani,
    stim_type,
)
FigPath_SFN = check_path(FigPath, 'SFN')

for _label, _path in (('save_dir:', save_dir), ('data_dir:', data_dir), ('FigPath:', FigPath)):
    print(_label, _path)

save_dir: /home/seuss/Research/SensoryMotorPred_Data/data/070921/J553RT/fm1
data_dir: /home/seuss/Goeppert/freely_moving_ephys/ephys_recordings/070921/J553RT/fm1
FigPath: /home/seuss/Research/SensoryMotorPred_Data/Figures/Encoding/070921/J553RT/fm1


In [4]:
def _first_match(directory, pattern):
    """Return the first file in `directory` matching glob `pattern`, as a POSIX string.

    Matches are sorted so the choice is reproducible when several files match, and a
    missing file raises FileNotFoundError naming the pattern instead of the bare
    IndexError produced by `list(directory.glob(pattern))[0]`.
    """
    matches = sorted(directory.glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No file matching {pattern!r} in {directory}')
    return matches[0].as_posix()


# Inputs for load_train_test: this session's ephys, eye, IMU/top (freely moving only),
# speed (head-fixed only) and world files, plus probe/recording metadata.
file_dict = {'cell': 0,
             'drop_slow_frames': True,
             'ephys': _first_match(data_dir, '*ephys_merge.json'),
             'ephys_bin': _first_match(data_dir, '*Ephys.bin'),
             'eye': _first_match(data_dir, '*REYE.nc'),
             'imu': _first_match(data_dir, '*imu.nc') if stim_type=='fm1' else None,
             'mapping_json': '/home/seuss/Research/Github/FreelyMovingEphys/probes/channel_maps.json',
             'mp4': True,
             # NOTE(review): this name encodes session 01221/EE8P6LT, but date_ani is
             # '070921/J553RT' (see the trailing comment) -- confirm before relying on it.
             'name': '01221_EE8P6LT_control_Rig2_'+stim_type, #070921_J553RT
             'probe_name': 'DB_P128-6',
             'save': data_dir.as_posix(),
             'speed': _first_match(data_dir, '*speed.nc') if stim_type=='hf1_wn' else None,
             'stim_type': 'light',
             'top': _first_match(data_dir, '*TOP1.nc') if stim_type=='fm1' else None,
             'world': _first_match(data_dir, '*world.nc'),}

In [5]:
# First load: un-normalized data with a symmetric +/-2-bin lag window.
model_dt = .05  # bin width; 1000*lag*model_dt below prints lags in ms, so presumably seconds
do_norm = False
do_shuffle = False
data, train_idx, test_idx = load_train_test(
    file_dict,
    save_dir,
    model_dt=model_dt,
    do_shuffle=do_shuffle,
    do_norm=do_norm,
    free_move=free_move,
    has_imu=free_move,
    has_mouse=False,
)
locals().update(data)  # expose every entry of `data` (e.g. model_nsp) as a notebook variable

lag_list = np.arange(-2, 3)  # lags in model bins: -2, -1, 0, 1, 2
nt_glm_lag = lag_list.size
print(lag_list,1000*lag_list*model_dt)
model_type = 'Pytorch'
ncells = model_nsp.shape[-1]
bin_length = 40


Done Loading Aligned Data
TRAIN: 15628 TEST: 6698
[-2 -1  0  1  2] [-100.  -50.    0.   50.  100.]


In [7]:
# Refit setup: instantaneous (lag-0) video input, normalized data.
lag_list = np.array([0])  # earlier runs used np.array([-2, -1, 0, 1, 2])
nt_glm_lag = len(lag_list)
print(lag_list,1000*lag_list*model_dt)
do_shuffle = False
model_type = 'Pytorch'

# Load data; every key of `data` becomes a top-level notebook variable
# (train_vid, model_vid_sm, train_th, model_nsp, ...).
data, train_idx, test_idx = load_train_test(file_dict, save_dir, model_dt=model_dt, do_shuffle=do_shuffle, do_norm=True,free_move=free_move, has_imu=free_move, has_mouse=False)
locals().update(data)

# Names of every non-empty combination of movement variables, in
# itertools.combinations order: 'Theta', 'Phi', ..., 'Theta_Phi', ..., 'Theta_Phi_Roll_Pitch'.
titles = np.array(['Theta','Phi','Roll','Pitch'])
titles_all = [
    '_'.join(titles[list(combo)])
    for n_vars in range(1, len(titles) + 1)
    for combo in itertools.combinations(range(len(titles)), n_vars)
]

if free_move:
    # Movement design matrices, columns ordered like `titles`.
    move_train = np.hstack((train_th[:,np.newaxis],train_phi[:,np.newaxis],train_roll[:,np.newaxis],train_pitch[:,np.newaxis]))
    move_test = np.hstack((test_th[:,np.newaxis],test_phi[:,np.newaxis],test_roll[:,np.newaxis],test_pitch[:,np.newaxis]))
    model_move = np.hstack((model_th[:,np.newaxis],model_phi[:,np.newaxis],model_roll[:,np.newaxis],model_pitch[:,np.newaxis]))
    model_move = model_move - np.mean(model_move,axis=0)
    # Center both splits with the TRAINING mean: the model is fit on centered training
    # data, so test data must get the same shift. Centering the test set with its own
    # mean (as before) leaks test statistics into preprocessing.
    move_train_mean = np.mean(move_train, axis=0)
    move_train = move_train - move_train_mean
    move_test = move_test - move_train_mean

# Video input size: trailing dims of train_vid (presumably frame height/width) and the
# flattened lagged size.
nks = np.shape(train_vid)[1:]
nk = nks[0]*nks[1]*nt_glm_lag
# Movement columns used by the movement models: the single 4-variable combination (all of them).
n = 4
ind = 0
perms = np.array(list(itertools.combinations(np.arange(len(titles)), n)))

# Lagged video design matrix: stack copies of model_vid_sm shifted by each lag along
# axis 1, then flatten each time point into one row.
# NOTE: np.roll wraps around, so the first/last |lag| rows borrow frames from the other end.
rolled_vid = np.hstack([np.roll(model_vid_sm, nframes, axis=0) for nframes in lag_list])
rolled_vid_flat = rolled_vid.reshape(rolled_vid.shape[0],-1)
x_train = rolled_vid_flat[train_idx]
x_test = rolled_vid_flat[test_idx]


[0] [0.]
Done Loading Aligned Data
TRAIN: 15628 TEST: 6698


In [None]:
# Removed placeholder: `for epoch in Nepochs:` had no body (a SyntaxError that halts
# Restart & Run All) and `Nepochs` is never defined. Training is done by the full-batch
# loop over `Nbatches` in the GLM fitting cell.


In [20]:
class Decoding_Network(nn.Module):
    """Feed-forward network mapping flattened inputs to non-negative per-output rates.

    Architecture: ``in_features -> hidden_features`` (ReLU), then ``hidden_layers``
    further ``hidden_features -> hidden_features`` layers (ReLU), then
    ``hidden_features -> out_features``.  An optional movement input is projected
    linearly onto the output pre-activation.  Outputs go through softplus so they can
    be used as Poisson rates with :meth:`loss`.

    Args:
        in_features: number of input features per sample.
        hidden_features: width of every hidden layer.
        out_features: number of outputs.
        hidden_layers: number of extra hidden->hidden layers.
        device: kept for backward compatibility; place the module with ``.to(device)``.
        move_features: width of the optional movement input; ``None`` disables it.
        reg_lam: optional L2 penalty strength applied in :meth:`loss`.
        reg_alph: optional L1 penalty strength applied in :meth:`loss`.
    """

    def __init__(self, in_features, hidden_features, out_features, hidden_layers=2, device='cuda',
                 move_features=None, reg_lam=None, reg_alph=None):
        super(Decoding_Network, self).__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.move_features = move_features
        self.reg_lam = reg_lam
        self.reg_alph = reg_alph
        self.relu = nn.ReLU()
        # The first layer must project to the hidden width (it previously mapped to
        # out_features, which broke the hidden stack).
        self.input_layer = nn.Linear(self.in_features, self.hidden_features)
        # nn.ModuleList (not a plain list) so the hidden layers are registered: they are
        # returned by .parameters(), follow .to(device) and are saved in state_dict().
        self.layers = nn.ModuleList(
            [nn.Linear(self.hidden_features, self.hidden_features) for _ in range(hidden_layers)]
        )
        self.out_layer = nn.Linear(self.hidden_features, self.out_features)
        self.move_layer = (nn.Linear(move_features, self.out_features, bias=False)
                           if move_features is not None else None)

    def forward(self, inputs, move_input=None):
        """Return softplus rates of shape ``(batch, out_features)``.

        Raises:
            ValueError: if ``inputs`` is not 2-D with ``in_features`` columns, or if
                ``move_input`` is given to a network built without ``move_features``.
        """
        if inputs.dim() != 2 or inputs.shape[1] != self.in_features:
            raise ValueError(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
        hidden = self.relu(self.input_layer(inputs))
        for layer in self.layers:
            hidden = self.relu(layer(hidden))
        output = self.out_layer(hidden)
        if move_input is not None:
            if self.move_layer is None:
                raise ValueError('move_input was given but the network was built with move_features=None')
            output = output + self.move_layer(move_input)
        # softplus == log1p(exp(x)) but does not overflow for large x.
        return F.softplus(output)

    def extra_repr(self):
        return 'in_features={}, hidden_features={}, out_features={}, move_features={}'.format(
            self.in_features, self.hidden_features, self.out_features, self.move_features
        )

    def loss(self, Yhat, Y):
        """Per-output Poisson loss ``mean(Yhat - Y*log(Yhat))`` over samples, plus penalties.

        This is the Poisson negative log-likelihood without the Y-only ``log(Y!)`` term.
        Optional per-output L2 (``reg_lam``) / L1 (``reg_alph``) row norms of the output
        (and movement) weights are added.  Returns a tensor of shape ``(out_features,)``.
        """
        # Clamp so a softplus that underflowed to 0 gives a finite log instead of -inf/NaN.
        loss_vec = torch.mean(Yhat - Y*torch.log(torch.clamp(Yhat, min=1e-12)), dim=0)
        penalized = [self.out_layer.weight]
        if self.move_layer is not None:
            penalized.append(self.move_layer.weight)
        for weight in penalized:
            if self.reg_lam is not None:
                loss_vec = loss_vec + self.reg_lam*torch.linalg.norm(weight, ord=2, dim=1)
            if self.reg_alph is not None:
                loss_vec = loss_vec + self.reg_alph*torch.linalg.norm(weight, ord=1, dim=1)
        return loss_vec


(torch.Size([15628, 600]), (15628, 128))

In [25]:
# Full-batch datasets/loaders: each loader yields its whole split as a single batch.
dataset_tr = TensorDataset(*(torch.from_numpy(arr).float() for arr in (x_train, train_nsp)))
dataloader_tr = DataLoader(dataset_tr, batch_size=len(dataset_tr))
dataset_te = TensorDataset(*(torch.from_numpy(arr).float() for arr in (x_test, test_nsp)))
dataloader_te = DataLoader(dataset_te, batch_size=len(dataset_te))

In [26]:

# Choose which design matrices feed the GLM:
#   0 -> movement only (xtr = selected movement columns, no separate movement input)
#   1 -> video only (flattened lagged video)
#   2 -> video (xtr) + additive movement input (xtrm)
#   3 -> video (xtr) + [video * each movement variable, movement columns] as xtrm
MovModel = 1
if MovModel not in (0, 1, 2, 3):
    raise ValueError(f'Unsupported MovModel: {MovModel}')
if MovModel != 1 and not free_move:
    # move_train / move_test are only built when free_move is True.
    raise ValueError(f'MovModel {MovModel} needs movement data (move_train/move_test), which is only built when free_move=True')


def _as_device_tensor(arr):
    """Return `arr` as a float32 torch tensor on the notebook's `device`."""
    return torch.from_numpy(arr.astype(np.float32)).to(device)


if MovModel == 0:
    mx_train = move_train[:,perms[ind]]
    mx_test = move_test[:,perms[ind]]
    xtr = _as_device_tensor(mx_train)
    xte = _as_device_tensor(mx_test)
    move_features = None
    nk = 0
    xtrm = None
    xtem = None
elif MovModel == 1:
    # Same rows of rolled_vid as x_train / x_test, as float32.
    x_train_m1 = x_train.astype(np.float32)
    x_test_m1 = x_test.astype(np.float32)
    xtr = torch.from_numpy(x_train_m1).to(device)
    xte = torch.from_numpy(x_test_m1).to(device)
    move_features = None
    xtrm = None
    xtem = None
elif MovModel == 2:
    xtrm = _as_device_tensor(move_train[:,perms[ind]])
    xtem = _as_device_tensor(move_test[:,perms[ind]])
    xtr = _as_device_tensor(x_train)
    xte = _as_device_tensor(x_test)
    move_features = xtrm.shape[-1]
else:  # MovModel == 3
    # Movement input: video features multiplied by each movement variable in turn,
    # followed by the selected movement columns.
    x_train_m3 = np.hstack((np.hstack([x_train*move_train[:,modeln][:,np.newaxis] for modeln in np.arange(len(titles))]), move_train[:,perms[ind]]))
    x_test_m3 = np.hstack((np.hstack([x_test*move_test[:,modeln][:,np.newaxis] for modeln in np.arange(len(titles))]), move_test[:,perms[ind]]))
    xtr = _as_device_tensor(x_train)
    xte = _as_device_tensor(x_test)
    xtrm = _as_device_tensor(x_train_m3)
    xtem = _as_device_tensor(x_test_m3)
    move_features = x_train_m3.shape[-1]

ytr = _as_device_tensor(train_nsp)
yte = _as_device_tensor(test_nsp)
input_size = xtr.shape[1]
output_size = ytr.shape[1]
print('Model: {}, move_features: {}'.format(MovModel, move_features))


Model: 1, move_features: None


In [None]:

# Fit PoissonGLM_VM_staticreg over a grid of `alphas` x `lambdas` (lambda is used as the
# optimizer weight decay on l1.weight for MovModel 1-3) and keep, per grid point, the
# per-cell train/validation loss traces, bias traces, final weights/biases, held-out
# Poisson loss (msetest) and test predictions.

Nbatches = 20000  # full-batch optimisation steps per grid point
if move_features is not None:
    reg_params = np.zeros((Nbatches,output_size,4))
    reg_titles = ['lambda','lambda_m','alpha','alpha_m']
else:
    reg_params = np.zeros((Nbatches,output_size,2))
    reg_titles = ['lambda','alpha']


def _sta_init():
    """Spike-triggered average of the flattened lagged video, scaled by 1/10, as float32 (cells x features).

    Cells with no spikes get an all-zero filter instead of a 0/0 = NaN initialisation.
    """
    spike_counts = np.sum(model_nsp, axis=0)
    spike_counts = np.where(spike_counts > 0, spike_counts, 1)
    return torch.from_numpy(((rolled_vid_flat.T@model_nsp)/(10*spike_counts)).T.astype(np.float32))


# Per-step weight traces are not recorded, so no w_move_traces_all array is allocated
# (it was nalph*nlam*Nbatches*output_size*n_weights float64 values, never written).
if MovModel == 0:
    sta_init = None
    lambdas = [0]
    nlam = len(lambdas)
    alphas = [0]
    nalph = len(alphas)
    # Movement-only model: its fitted l1.weight is stored as the movement weights
    # (this array was previously never allocated, so the fit ended in a NameError).
    w_move_cv = np.zeros((nalph, nlam, output_size, input_size), dtype=np.float32)
elif MovModel == 1:
    lambdas = (2**(np.arange(0,10)))/100
    nlam = len(lambdas)
    alphas = np.array([.0075])
    nalph = len(alphas)
    sta_init = _sta_init()
elif MovModel == 2:
    lambdas = (2**(np.arange(0,10)))/100
    lambdas_m = (2**(np.arange(0, 10)))/10
    nlam = len(lambdas)
    alphas = np.array([.0075,])
    nalph = len(alphas)
    sta_init = _sta_init()
    w_move_cv = np.zeros((nalph,nlam,output_size,move_features))
else:
    lambdas = (2**(np.arange(0, 10)))
    lambdas_m = (2**(np.arange(0, 10)))
    nlam = len(lambdas)
    alphas = np.array([.01])
    nalph = len(alphas)
    sta_init = _sta_init()
    w_move_cv = np.zeros((nalph, nlam, output_size, move_features), dtype=np.float32)

# Bias init: inverse softplus of each cell's mean count, log(exp(m) - 1). expm1 is the
# accurate form for small m; the clamp keeps silent cells (m == 0) from getting -inf.
mean_counts = torch.mean(torch.tensor(model_nsp,dtype=torch.float32), dim=0)
meanbias = torch.log(torch.expm1(torch.clamp(mean_counts, min=1e-8)))

msetrain = np.zeros((nalph,nlam,output_size))
msetest = np.zeros((nalph,nlam,output_size))
pred_cv = np.zeros((x_test.shape[0],nalph,nlam,output_size),dtype=np.float32)
w_cv = np.zeros((x_train.shape[-1],nalph,nlam,output_size),dtype=np.float32)
bias_cv = np.zeros((nalph,nlam,output_size),dtype=np.float32)
tloss_trace_all = np.zeros((nalph, nlam, Nbatches, output_size),dtype=np.float32)
vloss_trace_all = np.zeros((nalph, nlam, Nbatches, output_size),dtype=np.float32)
bias_traces_all = np.zeros((nalph, nlam, Nbatches, output_size),dtype=np.float32)

lr_w = [1e-6, 1e-4]  # [base, max] cyclic learning rate for the visual weights
lr_b = [1e-5, 5e-3]  # [base, max] for the biases
lr_m = [1e-5, 1e-3]  # [base, max] for the movement weights
start = time.time()
for a, reg_alph in enumerate(tqdm(alphas)):
    for l, reg_lam in enumerate(tqdm(lambdas)):
        if MovModel == 0:
            l1 = PoissonGLM_VM_staticreg(input_size,output_size,reg_lam=None,reg_alph=None,move_features=move_features,meanfr=meanbias,init_sta=sta_init,device=device).to(device)
            optimizer = optim.ASGD(params=[{'params': [l1.weight],'lr': 1e-3,},
                                           {'params': [l1.bias],'lr':lr_b[1]},], lr=5e-5)
            scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=[lr_m[0],lr_b[0]], max_lr=[lr_m[1],lr_b[1]], cycle_momentum=False)
        elif MovModel == 1:
            l1 = PoissonGLM_VM_staticreg(input_size,output_size,reg_lam=None,reg_alph=reg_alph,move_features=move_features,meanfr=meanbias,init_sta=sta_init,device=device).to(device)
            optimizer = optim.ASGD(params=[{'params': [l1.weight],'lr':lr_w[1],'weight_decay':lambdas[l]},
                                           {'params': [l1.bias],'lr':lr_b[1]},], lr=5e-5)
            scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=[lr_w[0],lr_b[0]], max_lr=[lr_w[1],lr_b[1]], cycle_momentum=False)
        else:
            l1 = PoissonGLM_VM_staticreg(input_size,output_size,reg_lam=None,reg_alph=reg_alph,move_features=move_features,meanfr=meanbias,init_sta=sta_init,device=device).to(device)
            optimizer = optim.ASGD(params=[{'params': [l1.weight],'lr':lr_w[1],'weight_decay':lambdas[l]},
                                           {'params': [l1.bias],'lr':lr_b[1]},
                                           {'params': [l1.move_weights],'lr':1e-3, 'weight_decay': lambdas_m[l]}], lr=5e-5)
            scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=[lr_w[0],lr_b[0],lr_m[0]], max_lr=[lr_w[1],lr_b[1],lr_m[1]], cycle_momentum=False)
        # Monitored only: the early-stopping break is disabled.
        early_stopping = EarlyStopping(patience=1000,min_delta=.005)

        vloss_trace = np.zeros((Nbatches,output_size),dtype=np.float32)
        tloss_trace = np.zeros((Nbatches,output_size),dtype=np.float32)
        for batchn in np.arange(Nbatches):
            out = l1(xtr,xtrm)
            loss = l1.loss(out,ytr)
            # Validation needs no autograd graph; building one every step only costs
            # time and memory.
            with torch.no_grad():
                pred = l1(xte,xtem)
                val_loss = l1.loss(pred,yte)
            vloss_trace[batchn] = val_loss.cpu().numpy()
            tloss_trace[batchn] = loss.detach().cpu().numpy()
            bias_traces_all[a,l,batchn] = l1.bias.detach().cpu().numpy()
            optimizer.zero_grad()
            # loss is a per-cell vector; backward with ones sums the per-cell gradients.
            loss.backward(torch.ones_like(loss))
            optimizer.step()
            scheduler.step()
            early_stopping(np.mean(vloss_trace[batchn]))
        tloss_trace_all[a,l] = tloss_trace
        vloss_trace_all[a,l] = vloss_trace
        bias_cv[a,l] = l1.bias.detach().cpu().numpy()
        if MovModel != 0:
            w_cv[:,a,l] = l1.weight.detach().cpu().numpy().T
        if MovModel == 0:
            w_move_cv[a,l] = l1.weight.detach().cpu().numpy()
        elif MovModel != 1:
            w_move_cv[a,l] = l1.move_weights.detach().cpu().numpy()
        with torch.no_grad():
            pred = l1(xte,xtem)
            msetest[a,l] = torch.mean(pred-yte*torch.log(pred),dim=0).cpu().numpy()
        pred_cv[:,a,l] = pred.cpu().numpy().squeeze()

print('GLM: ', time.time()-start)
if MovModel != 0:
    # Filters reshaped to (cells, lambda, alpha, lag, *nks).
    w_cv2 = w_cv.T.reshape((output_size,nlam,nalph,nt_glm_lag,)+nks)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

GLM:  513.058262348175


In [None]:
# Per cell, choose the (alpha, lambda) grid point with the lowest held-out Poisson loss.
# NaN losses (diverged fits) are treated as +inf so they are never chosen, and every
# cell keeps exactly one entry. Previously an all-NaN cell silently vanished from
# `cellnum`, misaligning every per-cell array below with `test_nsp`.
test_loss_flat = np.where(np.isnan(msetest), np.inf, msetest).reshape(nalph*nlam, output_size)
# argmin returns the first minimum, i.e. the smallest (alpha, lambda) index on ties.
malph, mlam = np.unravel_index(np.argmin(test_loss_flat, axis=0), (nalph, nlam))
cellnum = np.arange(output_size)

# Gather the selected fit for every cell.
sta_all = w_cv[:,malph,mlam,cellnum].T.reshape((output_size,nt_glm_lag,)+nks)
pred_all = pred_cv[:,malph,mlam,cellnum]
bias_all = bias_cv[malph,mlam,cellnum]
tloss_trace_all2 = tloss_trace_all[malph,mlam,:,cellnum]
vloss_trace_all2 = vloss_trace_all[malph,mlam,:,cellnum]
bias_traces = bias_traces_all[malph, mlam, :, cellnum]
if MovModel != 1:
    w_move = w_move_cv[malph,mlam,cellnum]

# r^2: squared Pearson correlation between box-car smoothed (bin_length bins) observed
# and predicted rates, trimming bin_length bins from each end to drop edge effects.
bin_length=40
r2_all = np.zeros((output_size))
smooth_kernel = np.ones(bin_length)
for celln in range(output_size):
    sp_smooth = ((np.convolve(test_nsp[:,celln], smooth_kernel, 'same')) / (bin_length * model_dt))[bin_length:-bin_length]
    pred_smooth = ((np.convolve(pred_all[:,celln], smooth_kernel, 'same')) / (bin_length * model_dt))[bin_length:-bin_length]
    r2_all[celln] = (np.corrcoef(sp_smooth,pred_smooth)[0,1])**2


In [None]:
# Index into `lambdas` of the lambda selected for each cell.
mlam

array([0, 0, 0, 9, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 9, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 9, 0, 0, 0, 0, 0, 5, 0])

In [None]:

# Collect the selected-fit results. Key order matches the per-model layouts:
# sta_all only for models with a video filter, w_move only for models with movement weights.
GLM_Data = {'r2_all': r2_all}
if MovModel != 0:
    GLM_Data['sta_all'] = sta_all
GLM_Data.update({
    'test_nsp': test_nsp,
    'pred_all': pred_all,
    'bias_all': bias_all,
    'tloss_trace_all': tloss_trace_all2,
    'vloss_trace_all': vloss_trace_all2,
})
if MovModel != 1:
    GLM_Data['w_move'] = w_move

name_template = ('GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}_shuffled.h5' if do_shuffle
                 else 'GLM_{}_Data_VisMov_dt{:03d}_T{:02d}_MovModel{:d}.h5')
save_datafile = save_dir / name_template.format(model_type, int(model_dt*1000), nt_glm_lag, MovModel)
ioh5.save(save_datafile, GLM_Data)
print(save_datafile)

/home/seuss/Research/SensoryMotorPred_Data/data/070921/J553RT/fm1/GLM_Pytorch_Data_VisMov_dt050_T05_MovModel1.h5
