# Import Modules

In [113]:
%load_ext autoreload
%autoreload 2

import ray
import gc
import cv2
import time
import warnings 
import argparse
import yaml
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import xarray as xr

from ray import tune
from tqdm.auto import tqdm
from pathlib import Path
from itertools import chain
from typing import Tuple
from asyncio import Event
from test_tube import Experiment
from matplotlib.backends.backend_pdf import PdfPages
from scipy.signal import medfilt
from scipy.stats import binned_statistic
from scipy.interpolate import interp1d
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GroupShuffleSplit
from sklearn.utils import shuffle


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from kornia.geometry.transform import Affine
# Let cuDNN benchmark and cache the fastest convolution algorithms for the input sizes it sees.
torch.backends.cudnn.benchmark = True
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

import Utils.io_dict_to_hdf5 as ioh5
from Utils.utils import *
from Utils.params import *
from Utils.format_raw_data import *
from Utils.format_model_data import *


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Key Parameters:
- model_dt:     (float) width of the model time bins, in seconds
- date_ani:     (str) recording date and animal ID, e.g. `'070921/J553RT'`
- base_dir:     (str) base directory
- save_dir:     (str) directory where processed data will be saved
- data_dir:     (Path) directory holding the raw data
- downsamp_vid: (int) factor by which the videos are downsampled
- lag_list:     (list) which time steps (lags) to include in the fits

# Format Data

## Test loading raw data

In [2]:
# Input arguments
args = arg_parser(jupyter=True)

# Sessions available for this analysis; only dates_all[0] is used below.
# Inactive sessions are kept in the trailing comment for reference.
dates_all = ['070921/J553RT', '101521/J559NC', '102821/J570LT', '110421/J569LT']  # ,'122021/J581RT','020422/J577RT'] # '102621/J558NC' '062921/G6HCK1ALTRN',
args['date_ani']      = dates_all[0]
args['free_move']     = True
args['train_shifter'] = True
args['NoL1']          = False
args['NoL2']          = False
args['do_shuffle']    = False
args['Nepochs']       = 10000


ModelID = 1
params, exp = load_params(args, ModelID, exp_dir_name=None, nKfold=0, debug=False)


def _first_glob_match(data_dir, pattern):
    """Return the first file in ``data_dir`` matching the glob ``pattern`` as a POSIX string.

    Args:
        data_dir (str or Path): directory to search.
        pattern (str): glob pattern relative to ``data_dir``, e.g. ``'*REYE.nc'``.

    Returns:
        str: POSIX path of the first match, in ``Path.glob`` iteration order
            (the same file ``list(...glob(pattern))[0]`` would have picked).

    Raises:
        FileNotFoundError: if nothing matches. The message names the pattern and the
            directory, instead of the bare ``IndexError`` that ``list(...)[0]`` raised.
    """
    matches = list(Path(data_dir).glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching {pattern!r} found in {data_dir}")
    return matches[0].as_posix()


# IMU and top-camera files are only looked up when stim_cond equals params['fm_dir']
# (presumably the freely-moving condition); the running-ball speed file only for 'hf1_wn'.
file_dict = {'cell': 0,
             'drop_slow_frames': False,
             'ephys': _first_glob_match(params['data_dir'], '*ephys_merge.json'),
             'ephys_bin': _first_glob_match(params['data_dir'], '*Ephys.bin'),
             'eye': _first_glob_match(params['data_dir'], '*REYE.nc'),
             'imu': _first_glob_match(params['data_dir'], '*imu.nc') if params['stim_cond'] == params['fm_dir'] else None,
             # NOTE(review): hard-coded, user-specific location of the probe channel map;
             # consider moving it into params so the notebook runs on other machines.
             'mapping_json': Path('~/Research/Github/FreelyMovingEphys/config/channel_maps.json').expanduser(),
             'mp4': True,
             'name': params['date_ani2'] + '_control_Rig2_' + params['stim_cond'],  # 070921_J553RT
             'probe_name': 'DB_P128-6',
             'save': params['data_dir'].as_posix(),
             'speed': _first_glob_match(params['data_dir'], '*speed.nc') if params['stim_cond'] == 'hf1_wn' else None,
             # NOTE(review): hard-coded to 'light' even though 'name' above uses params['stim_cond'];
             # confirm this is intended.
             'stim_cond': 'light',
             'top': _first_glob_match(params['data_dir'], '*TOP1.nc') if params['stim_cond'] == params['fm_dir'] else None,
             'world': _first_glob_match(params['data_dir'], '*world.nc'),
             'ephys_csv': _first_glob_match(params['data_dir'], '*Ephys_BonsaiBoardTS.csv')}


In [3]:
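# NOTE(review): this cell executes nothing. It is a commented-out reference copy of
# format_raw_data, interp_raw_data, load_aligned_data, format_data and load_Kfold_data.
# The live versions are presumably the ones star-imported from Utils.format_raw_data /
# Utils.format_model_data in the imports cell. Consider deleting this cell so it cannot
# silently drift from the library code.
# def format_raw_data(file_dict, params, medfiltbins=11, **kwargs):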
#     """ Formatting raw data for Niell Lab freely moving ephys data 

#     Args:
#         file_dict (dict): file dictionary containing raw data paths.
#         params (dict): parameter dictionary holding key parameters for formatting.
#         medfiltbins (int, optional): filter bin size for smoothing. Defaults to 11.

#     Returns:
#         raw_data (dict): returns formatted dictionary of raw data
#         goodcells (pd.DataFrame): returns a DataFrame with ephys unit information
#     """
    
#     ##### Set up condition shorthand #####
#     if file_dict['imu'] is not None:
#         has_imu = True
#         has_mouse = False
#     else:
#         has_imu = False
#         has_mouse = True
        
#     # open worldcam
#     print('opening worldcam data')
#     world_data = xr.open_dataset(file_dict['world'])
#     world_vid_raw = np.uint8(world_data['WORLD_video'])
#     # resize worldcam
#     sz = world_vid_raw.shape # raw video size
#     # if size is larger than the target 60x80, resize by 0.5
#     if sz[1]>160:
#         downsamp = 0.5
#         world_vid = np.zeros((sz[0],int(sz[1]*downsamp),int(sz[2]*downsamp)), dtype = 'uint8')
#         for f in range(sz[0]):
#             world_vid[f,:,:] = cv2.resize(world_vid_raw[f,:,:],(int(sz[2]*downsamp),int(sz[1]*downsamp)))
#     else:
#         # if the worldcam has already been resized when the nc file was written in preprocessing, don't resize
#         world_vid = world_vid_raw.copy()
#     del world_vid_raw
#     gc.collect()

#     # world timestamps
#     worldT = world_data.timestamps.copy()
#     if params['free_move'] == True:
#         print('opening top data')
#     #     # open the topdown camera nc file
#         top_data = xr.open_dataset(file_dict['top'])
#         top_speed = top_data.TOP1_props[:,0].data
#         topT = top_data.timestamps.data.copy() # read in time timestamps
#     #     top_vid = np.uint8(top_data['TOP1_video']) # read in top video
#         # clear from memory
#         del top_data
#         gc.collect()
#     else: 
#         topT = []
#         top_speed = []
        
#     # load IMU data
#     if has_imu:
#         print('opening imu data')
#         imu_data = xr.open_dataset(file_dict['imu'])
#         try:
#             accT = imu_data.IMU_data.sample # imu timestamps
#             acc_chans = imu_data.IMU_data # imu dample data
#         except AttributeError:
#             accT = imu_data.__xarray_dataarray_variable__.sample
#             acc_chans = imu_data.__xarray_dataarray_variable__
        
#         # raw gyro values
#         gx = np.array(acc_chans.sel(channel='gyro_x_raw'))
#         gy = np.array(acc_chans.sel(channel='gyro_y_raw'))
#         gz = np.array(acc_chans.sel(channel='gyro_z_raw'))
#         gz = (gz-np.mean(gz))*7.5 # Rescale gz
#         # gyro values in degrees
#         gx_deg = np.array(acc_chans.sel(channel='gyro_x'))
#         gy_deg = np.array(acc_chans.sel(channel='gyro_y'))
#         gz_deg = np.array(acc_chans.sel(channel='gyro_z'))
#         # pitch and roll in deg
#         groll = medfilt(np.array(acc_chans.sel(channel='roll')),medfiltbins)
#         gpitch = medfilt(np.array(acc_chans.sel(channel='pitch')),medfiltbins)
#     else: 
#         accT = []
#         gz = []
#         groll = []
#         gpitch = []
#     # load optical mouse nc file from running ball
#     if file_dict['speed'] is not None:
#         print('opening speed data')
#         speed_data = xr.open_dataset(file_dict['speed'])
#         try:
#             spdVals = speed_data.BALL_data
#         except AttributeError:
#             spdVals = speed_data.__xarray_dataarray_variable__
#         try:
#             spd = spdVals.sel(move_params = 'speed_cmpersec')
#             spd_tstamps = spdVals.sel(move_params = 'timestamps')
#         except:
#             spd = spdVals.sel(frame = 'speed_cmpersec')
#             spd_tstamps = spdVals.sel(frame = 'timestamps')
#     print('opening ephys data')
#     # ephys data for this individual recording
#     ephys_data = pd.read_json(file_dict['ephys'])
#     # sort units by shank and site order
#     ephys_data = ephys_data.sort_values(by='ch', axis=0, ascending=True)
#     ephys_data = ephys_data.reset_index()
#     ephys_data = ephys_data.drop('index', axis=1)
#     # spike times
#     ephys_data['spikeTraw'] = ephys_data['spikeT']
#     print('getting good cells')
#     # select good cells from phy2
#     goodcells = ephys_data.loc[ephys_data['group']=='good']
#     units = goodcells.index.values
#     # get number of good units
#     n_units = len(goodcells)
#     # plot spike raster
#     plt.close()
#     print('opening eyecam data')
#     # load eye data
#     eye_data = xr.open_dataset(file_dict['eye'])
#     eye_vid = np.uint8(eye_data['REYE_video'])
#     eyeT = eye_data.timestamps.copy()
#     # plot eye postion across recording
#     eye_params = eye_data['REYE_ellipse_params']
#     # define theta, phi and zero-center
#     th = np.array((eye_params.sel(ellipse_params = 'theta'))*180/np.pi)
#     phi = np.array((eye_params.sel(ellipse_params = 'phi'))*180/np.pi)
#     eyerad = eye_data.REYE_ellipse_params.sel(ellipse_params = 'longaxis').data

#     print('adjusting camera times to match ephys')
#     # adjust eye/world/top times relative to ephys
#     ephysT0 = ephys_data.iloc[0,12]
#     eyeT = eye_data.timestamps  - ephysT0
#     if eyeT[0]<-600:
#         eyeT = eyeT + 8*60*60 # 8hr offset for some data
#     worldT = world_data.timestamps - ephysT0
#     if worldT[0]<-600:
#         worldT = worldT + 8*60*60
#     if params['free_move'] is True and has_imu is True:
#         accTraw = accT - ephysT0
#     if params['free_move'] is False and has_mouse is True:
#         speedT = spd_tstamps - ephysT0
#     if params['free_move'] is True:
#         topT = topT - ephysT0

#     ##### Clear some memory #####
#     del eye_data 
#     gc.collect()

#     if file_dict['drop_slow_frames'] is True:
#         # in the case that the recording has long time lags, drop data in a window +/- 3 frames around these slow frames
#         isfast = np.diff(eyeT)<=0.05
#         isslow = sorted(list(set(chain.from_iterable([list(range(int(i)-3,int(i)+4)) for i in np.where(isfast==False)[0]]))))
#         th[isslow] = np.nan
#         phi[isslow] = np.nan


#     print(world_vid.shape)
#     # calculate eye veloctiy
#     dEye = np.diff(th)
#     accT_correction_file = params['save_dir']/'acct_correction_{}.h5'.format(params['data_name'])
#     # check accelerometer / eye temporal alignment
#     if (accT_correction_file.exists()):# & (reprocess==False):
#         accT_correction = ioh5.load(accT_correction_file)
#         offset0    = accT_correction['offset0']
#         drift_rate = accT_correction['drift_rate']
#         accT = accTraw - (offset0 + accTraw*drift_rate)
#         found_good_offset = True
#     else:
#         if (has_imu):
#             print('checking accelerometer / eye temporal alignment')
#             lag_range = np.arange(-0.2,0.2,0.002)
#             cc = np.zeros(np.shape(lag_range))
#             t1 = np.arange(5,len(dEye)/60-120,20).astype(int) # was np.arange(5,1600,20), changed for shorter videos
#             t2 = t1 + 60
#             offset = np.zeros(np.shape(t1))
#             ccmax = np.zeros(np.shape(t1))
#             acc_interp = interp1d(accTraw, (gz-3)*7.5)
#             for tstart in tqdm(range(len(t1))):
#                 for l in range(len(lag_range)):
#                     try:
#                         c, lag= nanxcorr(-dEye[t1[tstart]*60 : t2[tstart]*60] , acc_interp(eyeT[t1[tstart]*60:t2[tstart]*60]+lag_range[l]),1)
#                         cc[l] = c[1]
#                     except: # occasional problem with operands that cannot be broadcast togther because of different shapes
#                         cc[l] = np.nan
#                 offset[tstart] = lag_range[np.argmax(cc)]    
#                 ccmax[tstart] = np.max(cc)
#             offset[ccmax<0.1] = np.nan
#             del ccmax, dEye
#             gc.collect()
#             if np.isnan(offset).all():
#                 found_good_offset = False
#             else:
#                 found_good_offset = True

#         if has_imu and found_good_offset is True:
#             print('fitting regression to timing drift')
#             # fit regression to timing drift
#             model = LinearRegression()
#             dataT = np.array(eyeT[t1*60 + 30*60])
#             model.fit(dataT[~np.isnan(offset)].reshape(-1,1),offset[~np.isnan(offset)]) 
#             offset0 = model.intercept_
#             drift_rate = model.coef_
#             del dataT
#             gc.collect()
#         elif file_dict['speed'] is not None or found_good_offset is False:
#             offset0 = 0.1
#             drift_rate = -0.000114
#         if has_imu:
#             accT_correction = {'offset0': offset0, 'drift_rate': drift_rate}
#             ioh5.save(accT_correction_file,accT_correction)
#             accT = accTraw - (offset0 + accTraw*drift_rate)
#             del accTraw
#             gc.collect()


#     print('correcting ephys spike times for offset and timing drift')
#     for i in ephys_data.index:
#         ephys_data.at[i,'spikeT'] = np.array(ephys_data.at[i,'spikeTraw']) - (offset0 + np.array(ephys_data.at[i,'spikeTraw']) *drift_rate)
#     goodcells = ephys_data.loc[ephys_data['group']=='good']


#     ##### Calculating image norm #####
#     print('Calculating Image Norm')
#     start = time.time()
#     sz = np.shape(world_vid)
#     world_vid_sm = np.zeros((sz[0],int(sz[1]/params['downsamp_vid']),int(sz[2]/params['downsamp_vid'])),dtype=np.uint8)
#     for f in range(sz[0]):
#         world_vid_sm[f,:,:] = cv2.resize(world_vid[f,:,:],(int(sz[2]/params['downsamp_vid']),int(sz[1]/params['downsamp_vid'])))

#     del world_vid
#     gc.collect()

#     raw_data = {
#                 'eye':{
#                     'th':     th,
#                     'phi':    phi,
#                     'eyerad': eyerad,
#                     'eyeTS':  eyeT,
#                 },
#                 'acc': {
#                     'gz':     gz,
#                     'roll':  groll,
#                     'pitch': gpitch,    
#                     'accTS':  accT,
#                 },
#                 'top':{
#                     'speed': top_speed,
#                     'topTS': topT,
#                 },
#                 'vid':{
#                     'vidTS':    worldT,
#                     'vid_sm':   world_vid_sm,
#                 },
#                 }

#     return raw_data, goodcells


# def interp_raw_data(raw_data, align_t, model_dt=0.05, goodcells=None):
#     """Interpolates raw data based on nested dictionary. 

#     Args:
#         raw_data (dict): nested dictionary where first level (raw_data[key0]) represents data type
#                          second level contains data and timestamps assuming the following format respectively: 
#                          raw_data[key0]['datatype'] or raw_data[key0]['datatypeTS'] 
#                          where datatype is the variabel name of the data.
#         align_t (np.array): timestamps which to align data to.
#         model_dt (float, optional): model bin size to align data. Defaults to 0.05.
#         goodcells (pd.DataFrame, optional): If processing ephys data input DataFrame with size (units x features) with a column named 
#                                             spikeT. spikeT containsa list spike times (row). Defaults to None.

#     Returns:
#         model_data (dict): dictionary containing interpolated time-aligned model data with naming convention 'model_[datatype]'
#     """


#     ##### Set up model interpolated time #####
#     model_t = np.arange(0,np.max(align_t), model_dt)

#     ##### Interpolate raw data #####
#     model_data = {}
#     for key0 in raw_data.keys():
#             for key1 in raw_data[key0].keys():
#                 if 'TS' not in key1:
#                     if 'vid' in key0: # Z score video then interpolate
#                         std_im = np.std(raw_data[key0][key1], axis=0, dtype=float)
#                         img_norm = ((raw_data[key0][key1]-np.mean(raw_data[key0][key1],axis=0,dtype=float))/std_im).astype(float)
#                         std_im[std_im<20] = 0 # zero out extreme values
#                         img_norm = (img_norm * (std_im>0)).astype(float)
#                         interp = interp1d(raw_data[key0][key0+'TS'], img_norm,'nearest', axis=0,bounds_error = False) 
#                         testimg = interp(model_t[0])
#                         model_vid_sm = np.zeros((len(model_t),int(np.shape(testimg)[0]),int(np.shape(testimg)[1])),dtype=float)
#                         for i in tqdm(range(len(model_t))):
#                             model_vid = interp(model_t[i] + model_dt/2)
#                             model_vid_sm[i,:] = model_vid
#                         model_vid_sm[np.isnan(model_vid_sm)]=0
#                         model_data['model_'+ key1] = model_vid_sm
#                     else:
#                         interp = interp1d(raw_data[key0][key0+'TS'],raw_data[key0][key1],axis=0, bounds_error=False)
#                         model_data['model_'+ key1] = interp(model_t+model_dt/2)
#     model_data['model_t'] = model_t
#     if 'acc' in raw_data.keys():
#         model_data['model_active'] = np.convolve(np.abs(model_data['model_gz']), np.ones(int(1/model_dt)), 'same') / len(np.ones(int(1/model_dt)))

#     # get spikes / rate
#     if goodcells is not None:
#         n_units = len(goodcells)
#         model_nsp = np.zeros((len(model_t),n_units))
#         bins = np.append(model_t,model_t[-1]+model_dt)
#         for i,ind in enumerate(goodcells.index):
#             model_nsp[:,i],bins = np.histogram(goodcells.at[ind,'spikeT'],bins)
#         model_data['model_nsp'] = model_nsp
#         model_data['unit_nums'] = goodcells.index.values

#     return model_data


# def load_aligned_data(file_dict, params, reprocess=False):
#     """ Load time aligned data from file or process raw data and return formatted data

#     Args:
#         file_dict (dict): file dictionary containing raw data paths.
#         params (dict): parameter dictionary holding key parameters for formatting.
#         reprocess (bool, optional): reprocess raw data. Defaults to False.

#     Returns:
#         model_data (dict): returns dictionary with time aligned model data
#     """

#     model_file = params['save_dir'] / 'ModelData_{}_dt{:03d}_rawWorldCam_{:d}ds.h5'.format(params['data_name'],int(params['model_dt']*1000),int(params['downsamp_vid']))
#     if (model_file.exists()) & (reprocess==False):
#         model_data = ioh5.load(model_file)
#     else:
#         raw_data, goodcells = format_raw_data(file_dict,params)
#         model_data = interp_raw_data(raw_data,raw_data['vid']['vidTS'],model_dt=params['model_dt'],goodcells=goodcells)
#         if params['free_move']:
#             ##### Saving average and std of parameters for centering and scoring across conditions #####
#             FM_move_avg = np.zeros((2,6))
#             FM_move_avg[:,0] = np.array([np.nanmean(model_data['model_th']),np.nanstd(model_data['model_th'])])
#             FM_move_avg[:,1] = np.array([np.nanmean(model_data['model_phi']),np.nanstd(model_data['model_phi'])])
#             FM_move_avg[:,2] = np.array([np.nanmean(model_data['model_roll']),np.nanstd(model_data['model_roll'])])
#             FM_move_avg[:,3] = np.array([np.nanmean(model_data['model_pitch']),np.nanstd(model_data['model_pitch'])])
#             FM_move_avg[:,4] = np.array([np.nanmean(model_data['model_speed']),np.nanmax(model_data['model_speed'])])
#             FM_move_avg[:,5] = np.array([np.nanmean(model_data['model_eyerad']),np.nanmax(model_data['model_eyerad'])])
#             np.save(params['save_dir_fm']/'FM_MovAvg_{}_dt{:03d}.npy'.format(params['data_name'],int(params['model_dt']*1000)),FM_move_avg)
#         ephys_file = params['save_dir'] / 'RawEphysData_{}.h5'.format(params['data_name'])
#         goodcells.to_hdf(ephys_file,key='goodcells', mode='w')
#         ioh5.save(model_file, model_data)
#     return model_data



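# NOTE(review): the commented-out load_aligned_data that follows is identical to the
# earlier commented-out load_aligned_data in this cell; one copy can be removed.
# def load_aligned_data(file_dict, params, reprocess=False):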
#     """ Load time aligned data from file or process raw data and return formatted data

#     Args:
#         file_dict (dict): file dictionary containing raw data paths.
#         params (dict): parameter dictionary holding key parameters for formatting.
#         reprocess (bool, optional): reprocess raw data. Defaults to False.

#     Returns:
#         model_data (dict): returns dictionary with time aligned model data
#     """

#     model_file = params['save_dir'] / 'ModelData_{}_dt{:03d}_rawWorldCam_{:d}ds.h5'.format(params['data_name'],int(params['model_dt']*1000),int(params['downsamp_vid']))
#     if (model_file.exists()) & (reprocess==False):
#         model_data = ioh5.load(model_file)
#     else:
#         raw_data, goodcells = format_raw_data(file_dict,params)
#         model_data = interp_raw_data(raw_data,raw_data['vid']['vidTS'],model_dt=params['model_dt'],goodcells=goodcells)
#         if params['free_move']:
#             ##### Saving average and std of parameters for centering and scoring across conditions #####
#             FM_move_avg = np.zeros((2,6))
#             FM_move_avg[:,0] = np.array([np.nanmean(model_data['model_th']),np.nanstd(model_data['model_th'])])
#             FM_move_avg[:,1] = np.array([np.nanmean(model_data['model_phi']),np.nanstd(model_data['model_phi'])])
#             FM_move_avg[:,2] = np.array([np.nanmean(model_data['model_roll']),np.nanstd(model_data['model_roll'])])
#             FM_move_avg[:,3] = np.array([np.nanmean(model_data['model_pitch']),np.nanstd(model_data['model_pitch'])])
#             FM_move_avg[:,4] = np.array([np.nanmean(model_data['model_speed']),np.nanmax(model_data['model_speed'])])
#             FM_move_avg[:,5] = np.array([np.nanmean(model_data['model_eyerad']),np.nanmax(model_data['model_eyerad'])])
#             np.save(params['save_dir_fm']/'FM_MovAvg_{}_dt{:03d}.npy'.format(params['data_name'],int(params['model_dt']*1000)),FM_move_avg)
#         ephys_file = params['save_dir'] / 'RawEphysData_{}.h5'.format(params['data_name'])
#         goodcells.to_hdf(ephys_file,key='goodcells', mode='w')
#         ioh5.save(model_file, model_data)
#     return model_data



# def format_data(data, params, frac=.1, shifter_train_size=.5, test_train_size=.75, do_shuffle=False, do_norm=False, NKfold=1, thresh_cells=False, cut_inactive=True, move_medwin=11,**kwargs):
#     """ Fully format data for model training

#     Args:
#         data (dict): data dictionary containing time aligned data. get as ouput from load_aligned_data
#         params (dict): parameter dictionary holding key parameters for formatting.
#         frac (float, optional): from of total length, size of groups for train/test split. Defaults to .1.
#         shifter_train_size (float, optional): shifter train/test split fraction. Defaults to .5.
#         test_train_size (float, optional): train/test size for random group shuffle. Defaults to .75.
#         do_shuffle (bool, optional): shuffle spikes. Defaults to False.
#         do_norm (bool, optional): normalize data. Defaults to False.
#         NKfold (int, optional): How many Kfolds to make. Defaults to 1.
#         thresh_cells (bool, optional): threshold out bad cells. Defaults to False.
#         cut_inactive (bool, optional): cut inactive timepoints. Defaults to True.
#         move_medwin (int, optional): window to smooth roll/pitch. Defaults to 11.

#     Returns:
#         _type_: _description_
#     """
#     ##### Load in preprocessed data #####
#     if params['free_move']:
#         ##### Find 'good' timepoints when mouse is active #####
#         nan_idxs = []
#         for key in data.keys():
#             nan_idxs.append(np.where(np.isnan(data[key]))[0])
#         good_idxs = np.ones(len(data['model_active']),dtype=bool)
#         good_idxs[data['model_active']<0.5] = False # .5 based on histogram, determined emperically 
#         good_idxs[np.unique(np.hstack(nan_idxs))] = False
#     else:
#         good_idxs = np.where((np.abs(data['model_th'])<50) & (np.abs(data['model_phi'])<50))[0].astype(int)


#     data['raw_nsp'] = data['model_nsp'].copy()
#     if cut_inactive:
#         ##### return only active data #####
#         for key in data.keys():
#             if (key != 'model_nsp') & (key != 'model_active') & (key != 'unit_nums') & (key != 'model_vis_sm_shift'):
#                 if len(data[key])>0:
#                     data[key] = data[key][good_idxs] # interp_nans(data[key]).astype(float)
#             elif (key == 'model_nsp'):
#                 data[key] = data[key][good_idxs]
#             elif (key == 'unit_nums') | (key == 'model_vis_sm_shift'):
#                 pass
    
    
#     ##### Splitting data for shifter, then split for model training #####
#     if params['shifter_5050']==False:
#         gss = GroupShuffleSplit(n_splits=NKfold, train_size=test_train_size, random_state=42)
#         nT = data['model_nsp'].shape[0]
#         groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((frac*(i-1))*nT)) for i in range(1,int(1/frac)+1)])

#         train_idx_list=[]
#         test_idx_list = []
#         for train_idx, test_idx in gss.split(np.arange(data['model_nsp'].shape[0]), groups=groups):
#             train_idx_list.append(train_idx)
#             test_idx_list.append(test_idx)
#     else:
#         ##### Need to fix 50/50 splitting ###################
#         # gss = GroupShuffleSplit(n_splits=NKfold, train_size=shifter_train_size, random_state=42)
#         np.random.seed(42)
#         nT = data['model_nsp'].shape[0]
        
#         shifter_train_size = .5
#         groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((frac*(i-1))*nT)) for i in range(1,int(1/frac)+1)])
#         train_idx_list_shifter=[]
#         test_idx_list_shifter=[]
#         train_idx_list=[]
#         test_idx_list = []
#         for Kfold in np.arange(NKfold):
#             glist = np.arange(1,1/frac+1)
#             shifter_groups = np.random.choice(glist,size=int((1/frac)*shifter_train_size),replace=False)
#             idx = np.arange(nT)
#             sampled_inds = np.any(np.array([groups==shifter_groups[n] for n in np.arange(len(shifter_groups))]),axis=0)
#             train_idx_list_shifter.append(idx[sampled_inds])
#             test_idx_list_shifter.append(idx[~sampled_inds])
#             glist=glist[~np.any(np.array([glist==shifter_groups[n] for n in np.arange(len(shifter_groups))]),axis=0)]
            
#             if params['train_shifter']==False:
#                 if params['shifter_5050_run']:
#                     train_idx_list = test_idx_list_shifter
#                     test_idx_list  = train_idx_list_shifter
#                 else:
#                     train_idx_list = train_idx_list_shifter
#                     test_idx_list  = test_idx_list_shifter
#             else:
#                 if params['shifter_5050_run']:
#                     train_idx_list = test_idx_list_shifter
#                     test_idx_list  = train_idx_list_shifter
#                 else:
#                     train_idx_list = train_idx_list_shifter
#                     test_idx_list  = test_idx_list_shifter

#     if thresh_cells:
#         print('Tot_units: {}'.format(data['unit_nums'].shape))
#         if params['free_move']:
#             if (params['save_dir_fm']/'bad_cells.npy').exists():
#                 bad_cells = np.load(params['save_dir_fm']/'bad_cells.npy')
#             else:
#                 mean_thresh = np.nanmean(data['model_nsp']/params['model_dt'],axis=0) < 1 # Thresholding out units under 1 Hz
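#                 # NOTE(review): the trailing comment on the next line says "first 25% and last 25%",
#                 # but both cut points use *.5, i.e. first half vs. second half. The line after it also
#                 # divides by a bare `model_dt`, which is not a parameter or local of format_data;
#                 # params['model_dt'] is presumably intended. Check the live Utils version for the same issues.
#                 f25,l75=int((data['model_nsp'].shape[0])*.5),int((data['model_nsp'].shape[0])*.5) # Checking first 25% and last 25% firing rate for drift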
#                 scaled_fr = (np.nanmean(data['model_nsp'][:f25], axis=0)/np.nanstd(data['model_nsp'][:f25], axis=0) - np.nanmean(data['model_nsp'][l75:], axis=0)/np.nanstd(data['model_nsp'][l75:], axis=0))/model_dt
#                 bad_cells = np.where((mean_thresh | (np.abs(scaled_fr)>4)))[0] # Locating bad units      
#                 np.save(params['save_dir_fm']/'bad_cells.npy',bad_cells)
#         else:
#             bad_cells = np.load(params['save_dir_fm']/'bad_cells.npy')

#         data['model_nsp'] = np.delete(data['model_nsp'],bad_cells,axis=1) # removing bad units
#         data['unit_nums'] = np.delete(data['unit_nums'],bad_cells,axis=0) # removing bad units
        
#     data['model_dth'] = np.diff(data['model_th'],append=0)
#     data['model_dphi'] = np.diff(data['model_phi'],append=0)
#     FM_move_avg = np.load(params['save_dir_fm']/'FM_MovAvg_{}_dt{:03d}.npy'.format(params['data_name'],int(params['model_dt']*1000)))
#     data['model_th'] = data['model_th'] - FM_move_avg[0,0]
#     data['model_phi'] = (data['model_phi'] - FM_move_avg[0,1])
#     if do_norm:
#         data['model_vid_sm'] = (data['model_vid_sm'] - np.mean(data['model_vid_sm'],axis=0))/np.nanstd(data['model_vid_sm'],axis=0)
#         data['model_vid_sm'][np.isnan(data['model_vid_sm'])]=0
#         data['model_th'] = (data['model_th'])/FM_move_avg[1,0] # np.nanstd(data['model_th'],axis=0) 
#         data['model_phi'] = (data['model_phi'])/FM_move_avg[1,1] # np.nanstd(data['model_phi'],axis=0) 
#         if params['free_move']:
#             data['model_roll'] = (data['model_roll'] - FM_move_avg[0,2])/FM_move_avg[1,2]
#             data['model_pitch'] = (data['model_pitch'] - FM_move_avg[0,3])/FM_move_avg[1,3]
#             data['model_roll'] = medfilt(data['model_roll'],move_medwin)
#             data['model_pitch'] = medfilt(data['model_pitch'],move_medwin)
#             if params['use_spdpup']:
#                 data['model_speed'] = (data['model_speed']-FM_move_avg[0,4])/FM_move_avg[1,4]
#                 data['model_eyerad'] = (data['model_eyerad']-FM_move_avg[0,5])/FM_move_avg[1,5]
#         else:
#             # data['model_roll'] = (0 - FM_move_avg[0,2])/FM_move_avg[1,2])
#             data['model_pitch'] = (np.zeros(data['model_phi'].shape) - FM_move_avg[0,3])/FM_move_avg[1,3]
#     else:
#         if params['free_move']:
#             data['model_roll']   = (data['model_roll'] - FM_move_avg[0,2])
#             data['model_pitch']  = (data['model_pitch'] - FM_move_avg[0,3])
#             data['model_roll']   = medfilt(data['model_roll'],move_medwin)
#             data['model_pitch']  = medfilt(data['model_pitch'],move_medwin)
#             if params['use_spdpup']:
#                 data['model_speed']  = (data['model_speed'])
#                 data['model_eyerad'] = (data['model_eyerad'])
#         else:
#             data['model_pitch']  = (np.zeros(data['model_phi'].shape) - FM_move_avg[0,3])
#     return data,train_idx_list,test_idx_list


# ##### Load in Kfold Data #####
# def load_Kfold_data(data,params,train_idx,test_idx):
#     """ Create Train/Test splits 

#     Args:
#         data (dict): dictionary of formatted data
#         params (dict): parameter dictionary holding key parameters for formatting.
#         train_idx (array): training indecies
#         test_idx (array): testing indecies

#     Returns:
#         data (dict): data dictionary with train/test splits
#     """
#     data['train_vid'] = data['model_vid_sm'][train_idx]
#     data['test_vid'] = data['model_vid_sm'][test_idx]
#     data['train_nsp'] = shuffle(data['model_nsp'][train_idx],random_state=42) if params['do_shuffle'] else data['model_nsp'][train_idx]
#     data['test_nsp'] = shuffle(data['model_nsp'][test_idx],random_state=42) if params['do_shuffle'] else data['model_nsp'][test_idx]
#     data['train_th'] = data['model_th'][train_idx]
#     data['test_th'] = data['model_th'][test_idx]
#     data['train_phi'] = data['model_phi'][train_idx]
#     data['test_phi'] = data['model_phi'][test_idx]
#     data['train_roll'] = data['model_roll'][train_idx] if params['free_move'] else []
#     data['test_roll'] = data['model_roll'][test_idx] if params['free_move'] else []
#     data['train_pitch'] = data['model_pitch'][train_idx]
#     data['test_pitch'] = data['model_pitch'][test_idx]
#     data['train_t'] = data['model_t'][train_idx]
#     data['test_t'] = data['model_t'][test_idx]
#     data['train_dth'] = data['model_dth'][train_idx]
#     data['test_dth'] = data['model_dth'][test_idx]
#     data['train_dphi'] = data['model_dphi'][train_idx]
#     data['test_dphi'] = data['model_dphi'][test_idx]
#     data['train_gz'] = data['model_gz'][train_idx] if params['free_move'] else []
#     data['test_gz'] = data['model_gz'][test_idx] if params['free_move'] else []
#     data['train_speed'] = data['model_speed'][train_idx] if ((params['use_spdpup'])&params['free_move']) else []
#     data['test_speed'] = data['model_speed'][test_idx] if ((params['use_spdpup'])&params['free_move']) else []
#     data['train_eyerad'] = data['model_eyerad'][train_idx] if ((params['use_spdpup'])&params['free_move']) else []
#     data['test_eyerad'] = data['model_eyerad'][test_idx] if ((params['use_spdpup'])&params['free_move']) else []
#     return data


In [None]:
# Debug/inspection cell: format the raw recordings directly.
# load_aligned_data in the next cell presumably does this itself when no cached model file exists.
# Uncomment the loops below to print each array's shape/dtype.
raw_data, goodcells = format_raw_data(
    file_dict,
    params,
)
# for group_name, group in raw_data.items():
#     for var_name, values in group.items():
#         print(group_name, var_name, values.shape, values.dtype)
# model_data = interp_raw_data(raw_data, raw_data['vid']['vidTS'], model_dt=0.05, goodcells=goodcells)
# for var_name, values in model_data.items():
#     print(var_name, values.shape)

In [4]:
# Load the time-aligned model data (presumably read from a cached file unless reprocess=True),
# normalise / threshold / cut inactive periods, then build the first train/test fold.
data = load_aligned_data(file_dict, params, reprocess=False)
data, train_idx_list, test_idx_list = format_data(
    data,
    params,
    do_norm=True,
    thresh_cells=True,
    cut_inactive=True,
)
data = load_Kfold_data(data, params, train_idx_list[0], test_idx_list[0])

Tot_units: (128,)


## Prepare data for PyTorch

In [6]:

def get_modeltype(params,load_for_training=False):
    """Create the model name string from the parameter configuration.

    The base name comes from ``params['ModelID']`` (0: Mot, 1: Vis, 2: Add,
    3: Mul), or is always the visual name when ``load_for_training`` is True.
    Shifter, run and regularization tags are then appended.

    Side effects: when ``params['train_shifter']`` is set this creates the
    ``Shifter`` directory two levels above ``params['save_model']``, stores it
    in ``params['save_model_shift']``, and forces ``params['NoL1']`` and
    ``params['do_norm']`` to True.

    Args:
        params (dict): parameter dictionary holding key parameters
        load_for_training (bool, optional): If True, build the name of the
            visual (shifter) model and return only the string. Defaults to False.

    Returns:
        dict or str: ``params`` with a ``'model_type'`` key added when
        ``load_for_training`` is False, otherwise the ``model_type`` string.

    Raises:
        ValueError: if ``params['ModelID']`` is not 0-3 and ``load_for_training``
            is False.
    """
    base_names = {0: 'Pytorch_Mot', 1: 'Pytorch_Vis', 2: 'Pytorch_Add', 3: 'Pytorch_Mul'}
    if load_for_training:
        model_type = 'Pytorch_Vis'
    else:
        if params['ModelID'] not in base_names:
            raise ValueError('Unknown ModelID {!r}; expected one of {}'.format(params['ModelID'], sorted(base_names)))
        model_type = base_names[params['ModelID']]

    if params['train_shifter']:
        params['save_model_shift'] = params['save_model'].parent.parent / 'Shifter'
        params['save_model_shift'].mkdir(parents=True, exist_ok=True)
        # Shifter training always runs without L1 and with normalized data.
        params['NoL1'] = True
        params['do_norm'] = True
        model_type += 'Shifter'
        if params['shifter_5050']:
            model_type += 'Train_1' if params['shifter_5050_run'] else 'Train_0'
    elif params['shifter_5050']:
        model_type += '1' if params['shifter_5050_run'] else '0'

    if params['EyeHead_only']:
        model_type += '_EyeOnly' if params['EyeHead_only_run'] else '_HeadOnly'
    if params['NoShifter']:
        model_type += 'NoShifter'

    if params['only_spdpup']:
        model_type += '_onlySpdPup'
    elif params['use_spdpup']:
        model_type += '_SpdPup'
    # Regularization / simulation suffixes, appended in this fixed order.
    for flag, suffix in (('NoL1', '_NoL1'), ('NoL2', '_NoL2'), ('SimRF', '_SimRF')):
        if params[flag]:
            model_type += suffix

    if load_for_training:
        return model_type
    params['model_type'] = model_type
    return params

In [51]:

def format_pytorch_data(data,params,train_idx,test_idx, device='cuda'):
    """Convert one train/test split of the aligned data into PyTorch tensors.

    The visual input depends on the flags in ``params``:
      * ``train_shifter``: lagged ``model_vid_sm`` kept as
        ``(samples, nt_glm_lag, H, W)``, restricted to samples whose position
        values all lie strictly between ``quantiles[0]`` and ``quantiles[1]``;
      * ``NoShifter``: lagged, optionally cropped ``model_vid_sm``, flattened;
      * otherwise: lagged, optionally cropped shifted video loaded from the
        ``ModelWC_shifted_...h5`` file, flattened.

    Side effects: writes ``nks``, ``nk``, ``pos_features`` and ``Ncells`` into
    ``params``. It also writes ``shift_in``/``shift_out`` when training the
    shifter in free-moving mode, and updates ``save_model`` when ``SimRF`` is set.

    Args:
        data (dict): aligned data with ``train_*``, ``test_*`` and ``model_*`` arrays.
        params (dict): key parameters.
        train_idx (array-like of int): sample indices of the training split.
        test_idx (array-like of int): sample indices of the test split.
        device (str, optional): device for SimRF targets/bias. Defaults to 'cuda'.

    Returns:
        tuple: ``(xtr, xte, xtr_pos, xte_pos, ytr, yte, meanbias)``.
    """
    def _stack_columns(prefix, keys):
        # 1-D arrays data[prefix + key] -> (samples, len(keys)) array
        return np.hstack([data[prefix + key][:, np.newaxis] for key in keys])

    def _crop(vid):
        # Trim params['crop_input'] pixels from every spatial border; 0 means no crop.
        crop = params['crop_input']
        if crop != 0:
            return vid[:, crop:-crop, crop:-crop]
        return vid

    def _lag_stack(vid):
        # Concatenate copies of vid rolled along axis 0 by each lag, joined along axis 1.
        return np.hstack([np.roll(vid, nframes, axis=0) for nframes in params['lag_list']])

    ##### Position inputs #####
    if params['free_move']:
        if params['train_shifter']:
            # The shifter network always takes th, phi and pitch.
            shifter_keys = ['th', 'phi', 'pitch']
            pos_train = _stack_columns('train_', shifter_keys)
            pos_test = _stack_columns('test_', shifter_keys)
            model_pos = _stack_columns('model_', shifter_keys)
            params['shift_in'] = model_pos.shape[-1]
            params['shift_out'] = model_pos.shape[-1]
        else:
            pos_train = _stack_columns('train_', params['position_vars'])
            pos_test = _stack_columns('test_', params['position_vars'])
            model_pos = _stack_columns('model_', params['position_vars'])
    else:
        def _with_zero_column(prefix):
            # th, phi, pitch plus a column of zeros -> 4 position features.
            cols = _stack_columns(prefix, ['th', 'phi', 'pitch'])
            return np.hstack((cols, np.zeros((cols.shape[0], 1))))
        pos_train = _with_zero_column('train_')
        pos_test = _with_zero_column('test_')
        model_pos = _with_zero_column('model_')

    ##### Save dimensions #####
    params['nks'] = np.shape(data['train_vid'])[1:]
    params['nk'] = params['nks'][0]*params['nks'][1]*params['nt_glm_lag']
    if params['train_shifter']:
        rolled_vid = _lag_stack(data['model_vid_sm'])
        move_quantiles = np.quantile(model_pos, params['quantiles'], axis=0)
        train_range = np.all((pos_train > move_quantiles[0]) & (pos_train < move_quantiles[1]), axis=1)
        test_range = np.all((pos_test > move_quantiles[0]) & (pos_test < move_quantiles[1]), axis=1)
        x_train = rolled_vid[train_idx].reshape((len(train_idx), params['nt_glm_lag'])+params['nks']).astype(np.float32)[train_range]
        x_test = rolled_vid[test_idx].reshape((len(test_idx), params['nt_glm_lag'])+params['nks']).astype(np.float32)[test_range]
        pos_train = pos_train[train_range]
        pos_test = pos_test[test_range]
        ytr = torch.from_numpy(data['train_nsp'][train_range].astype(np.float32))
        yte = torch.from_numpy(data['test_nsp'][test_range].astype(np.float32))
    else:
        if params['NoShifter']:
            # Always start from the un-shifted video, cropping only when requested.
            vid = _crop(data['model_vid_sm'])
        else:
            ##### Rework for new model format ######
            vid = _crop(ioh5.load(params['save_dir']/params['exp_name']/'ModelWC_shifted_dt{:03d}_ModelID{:d}.h5'.format(int(params['model_dt']*1000), 1))['model_vid_sm_shift'])
        params['nks'] = np.shape(vid)[1:]
        params['nk'] = params['nks'][0]*params['nks'][1]*params['nt_glm_lag']
        rolled_vid = _lag_stack(vid)
        x_train = rolled_vid[train_idx].reshape(len(train_idx), -1).astype(np.float32)
        x_test = rolled_vid[test_idx].reshape(len(test_idx), -1).astype(np.float32)
        ytr = torch.from_numpy(data['train_nsp'].astype(np.float32))
        yte = torch.from_numpy(data['test_nsp'].astype(np.float32))

    ##### Convert to Tensors #####
    if params['ModelID']==0:
        # Position-only model: the "visual" input is the position matrix.
        xtr = torch.from_numpy(pos_train.astype(np.float32))
        xte = torch.from_numpy(pos_test.astype(np.float32))
    else:
        xtr = torch.from_numpy(x_train.astype(np.float32))
        xte = torch.from_numpy(x_test.astype(np.float32))
    xtr_pos = torch.from_numpy(pos_train.astype(np.float32))
    xte_pos = torch.from_numpy(pos_test.astype(np.float32))
    params['pos_features'] = xtr_pos.shape[-1]
    params['Ncells'] = ytr.shape[-1]

    if params['SimRF']:
        SimRF_file = params['save_dir'].parent.parent.parent/'021522/SimRF/fm1/SimRF_withL1_dt050_T01_Model1_NB10000_Kfold00_best.h5'
        SimRF_data = ioh5.load(SimRF_file)
        ytr = torch.from_numpy(SimRF_data['ytr'].astype(np.float32)).to(device)
        yte = torch.from_numpy(SimRF_data['yte'].astype(np.float32)).to(device)
        params['save_model'] = params['save_model'] / 'SimRF'
        params['save_model'].mkdir(parents=True, exist_ok=True)
        meanbias = torch.from_numpy(SimRF_data['bias_sim'].astype(np.float32)).to(device)
    else:
        meanbias = torch.mean(torch.tensor(data['model_nsp'], dtype=torch.float32), axis=0)
    return xtr, xte, xtr_pos, xte_pos, ytr, yte, meanbias


In [52]:

class FreeMovingEphysDataset(Dataset):
    """Map-style dataset pairing two aligned model inputs with one target.

    Item ``idx`` is the tuple ``(input0[idx], input1[idx], target[idx])``; the
    dataset length is taken from ``input0``.
    """

    def __init__(self, input0, input1, target):
        """Store the two inputs and the target.

        Args:
            input0 (Tensor): first model input, indexed along dim 0
            input1 (Tensor): second model input, indexed along dim 0
            target (Tensor): target data, indexed along dim 0
        """
        self.input0, self.input1, self.target = input0, input1, target

    def __len__(self):
        return len(self.input0)

    def __getitem__(self, idx):
        return tuple(source[idx] for source in (self.input0, self.input1, self.target))

In [180]:
# Configure a visual model (ModelID=1) with shifter training on the first split,
# then wrap the tensors in Datasets / DataLoaders.
params['ModelID']=1
# NOTE: with train_shifter=True and free_move set, format_pytorch_data uses th/phi/pitch
# for the position input and does not read position_vars.
params['position_vars'] = ['th','phi','pitch','roll']#,'speed','eyerad']
params['train_shifter']=True
train_idx = train_idx_list[0]
test_idx = test_idx_list[0]

xtr, xte, xtr_pos, xte_pos, ytr, yte, meanbias = format_pytorch_data(data,params,train_idx,test_idx)
train_dataset = FreeMovingEphysDataset(xtr,xtr_pos,ytr)
test_dataset  = FreeMovingEphysDataset(xte,xte_pos,yte)
# batch_size equals the split size, so each loader yields a single full-data batch.
train_dataloader = DataLoader(train_dataset, batch_size=xtr.shape[0],num_workers=2,pin_memory=True,)
test_dataloader = DataLoader(test_dataset, batch_size=xte.shape[0],num_workers=2,pin_memory=True,)


In [36]:
# Pull the single full-data training batch (video, position, spikes) for inspection.
vid,pos,Y = next(iter(train_dataloader))

In [25]:
# Shape of the training position input: (samples, position features).
xtr_pos.shape

torch.Size([18972, 3])

# Models

In [37]:

class BaseModel(nn.Module):
    r"""Base GLM network: a single linear layer from the input to every cell, followed by ReLU.

    Args:
        in_features: size of the input dimension
        N_cells: the number of cells to fit
        config: network configuration dict; must contain ``'L1_alpha'``
            (None disables the L1 penalty)
        device: device on which the L1 coefficient tensor is created
    """
    def __init__(self, 
                    in_features, 
                    N_cells,
                    config,
                    device='cuda'):
        super(BaseModel, self).__init__()
        self.config = config
        self.in_features = in_features
        self.N_cells = N_cells

        self.Cell_NN = nn.Sequential(nn.Linear(self.in_features, self.N_cells, bias=True))
        self.activations = nn.ModuleDict({'SoftPlus': nn.Softplus(),
                                          'ReLU': nn.ReLU(),})
        # Near-zero weights so an untrained model's output is essentially its bias.
        torch.nn.init.uniform_(self.Cell_NN[0].weight, a=-1e-6, b=1e-6)

        # Initialize Regularization parameters
        self.L1_alpha = config['L1_alpha']
        if self.L1_alpha is not None:
            self.alpha = config['L1_alpha'] * torch.ones(1).to(device)

    def init_weights(self, m):
        """Near-zero init for ``nn.Linear`` modules (usable with ``Module.apply``)."""
        if isinstance(m, nn.Linear):
            torch.nn.init.uniform_(m.weight, a=-1e-6, b=1e-6)
            m.bias.data.fill_(1e-6)

    def forward(self, inputs, **kwargs):
        """Return ``ReLU(Cell_NN(inputs))``; extra keyword arguments are ignored."""
        output = self.Cell_NN(inputs)
        return self.activations['ReLU'](output)

    def loss(self, Yhat, Y):
        """Per-cell mean squared error, plus an optional L1 penalty.

        Returns:
            Tensor: squared error averaged over dim 0. When ``L1_alpha`` is set,
            ``alpha * ||W||_1`` is added to every entry, where ``W`` is the whole
            ``Cell_NN`` weight matrix and the norm is a single value.
        """
        loss_vec = torch.mean((Yhat - Y) ** 2, axis=0)
        if self.L1_alpha is not None:
            l1_reg0 = torch.stack([torch.linalg.vector_norm(NN_params, ord=1) for name, NN_params in self.Cell_NN.named_parameters() if '0.weight' in name])
            loss_vec = loss_vec + self.alpha * l1_reg0
        return loss_vec


class ShifterNetwork(BaseModel):
    r"""Shifter GLM network.

    A small MLP (``shifter_nn``) maps each sample's shifter input to
    ``shift_out`` values. The last value is used as a rotation angle (clamped
    to [-45, 45]) and the first two as an (x, y) translation (clamped to
    [-15, 15]). These are applied to the video with kornia's ``Affine``. The
    transformed video is flattened and passed through the base GLM
    (``Cell_NN`` followed by ReLU).

    Args:
        in_features: size of the flattened (lags * H * W) input to the GLM
        N_cells: the number of cells to fit
        config: dict with ``'shift_in'``, ``'shift_hidden'``, ``'shift_out'``
            and ``'L1_alpha'``
        device: device on which the L1 coefficient tensor is created
    """
    def __init__(self, 
                    in_features, 
                    N_cells, 
                    config,
                    device='cuda'):
        # Forward device so BaseModel builds the L1 coefficient on the requested device.
        super(ShifterNetwork, self).__init__(in_features, N_cells, config, device=device)
        self.config = config
        ##### shifter network initialization #####
        self.shift_in = config['shift_in']
        self.shift_hidden = config['shift_hidden']
        self.shift_out = config['shift_out']
        self.shifter_nn = nn.Sequential(
                                        nn.Linear(self.shift_in, self.shift_hidden),
                                        nn.Softplus(),
                                        nn.Linear(self.shift_hidden, self.shift_out)
                                    )

    def forward(self, inputs, shifter_input=None, **kwargs):
        """Shift the video by the predicted affine transform, then apply the GLM.

        Args:
            inputs (Tensor): 4-D video batch; its last three dims are flattened
                after the shift and must total ``in_features``.
            shifter_input (Tensor): (batch, shift_in) input to ``shifter_nn``.

        Returns:
            Tensor: (batch, N_cells) non-negative output.

        Raises:
            ValueError: if ``shifter_input`` is missing or the flattened input
                width differs from ``in_features``.
        """
        if shifter_input is None:
            raise ValueError('ShifterNetwork.forward requires shifter_input')
        ##### Forward Pass of Shifter #####
        batchsize, _, _, _ = inputs.shape
        dxy = self.shifter_nn(shifter_input)
        shift = Affine(angle=torch.clamp(dxy[:, -1], min=-45, max=45),
                       translation=torch.clamp(dxy[:, :2], min=-15, max=15))
        inputs = shift(inputs)
        inputs = inputs.reshape(batchsize, -1).contiguous()
        ##### forward pass of GLM #####
        if inputs.shape[1] != self.in_features:
            raise ValueError(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
        output = self.Cell_NN(inputs)
        return self.activations['ReLU'](output)



class MixedNetwork(BaseModel):
    r"""Mixed GLM network combining the visual GLM with a linear position network.

    ``Cell_NN(inputs)`` is either added to ``posNN(pos_inputs)`` (``LinMix``
    truthy) or multiplied by ``|posNN(pos_inputs)|`` (otherwise). The result
    is then passed through ReLU.

    Args:
        in_features: size of the input dimension
        N_cells: the number of cells to fit
        config: dict with ``'L1_alpha'`` (L1 on the visual weights),
            ``'L1_alpham'`` (L1 on the position weights), ``'pos_features'``
            (number of position features) and ``'LinMix'`` (True = additive,
            False = multiplicative)
        device: device on which the L1 coefficient tensors are created
    """
    def __init__(self, 
                in_features, 
                N_cells, 
                config,
                device='cuda'):
        super(MixedNetwork, self).__init__(in_features, N_cells, config, device=device)
        self.config = config
        self.LinMix = config['LinMix']
        ##### Position Network Initialization #####
        # Stored on the instance because loss() reads it.
        self.L1_alpham = config['L1_alpham']
        if self.L1_alpham is not None:
            self.alpha_m = self.L1_alpham * torch.ones(1).to(device)
        self.posNN = nn.Sequential(nn.Linear(config['pos_features'], N_cells))
        torch.nn.init.uniform_(self.posNN[0].weight, a=-1e-6, b=1e-6)
        if self.LinMix:
            # Additive mixing: bias 0, so the untrained position term adds ~nothing.
            torch.nn.init.zeros_(self.posNN[0].bias)
        else:
            # Multiplicative mixing: bias 1, so the untrained gain |posNN| is ~1.
            torch.nn.init.ones_(self.posNN[0].bias)

    def forward(self, inputs, pos_inputs, **kwargs):
        """Combine the visual GLM drive with the position network and apply ReLU.

        Args:
            inputs (Tensor): (batch, in_features) visual input.
            pos_inputs (Tensor): (batch, pos_features) position input.

        Returns:
            Tensor: (batch, N_cells) non-negative output.

        Raises:
            ValueError: if ``inputs.shape[1] != in_features``.
        """
        if inputs.shape[1] != self.in_features:
            raise ValueError(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
        output = self.Cell_NN(inputs)
        if self.LinMix:
            output = output + self.posNN(pos_inputs)
        else:
            move_out = torch.abs(self.posNN(pos_inputs))
            output = output * move_out
        return self.activations['ReLU'](output)

    def loss(self, Yhat, Y):
        """Per-cell mean squared error plus optional per-cell L1 penalties.

        When enabled, ``alpha * ||W_visual[c]||_1`` and ``alpha_m * ||W_pos[c]||_1``
        (L1 norms of each cell's weight row) are added to cell ``c``'s loss.
        """
        loss_vec = torch.mean((Yhat - Y) ** 2, dim=0)
        if self.L1_alpha is not None:
            loss_vec = loss_vec + self.alpha * torch.linalg.norm(self.Cell_NN[0].weight, dim=1, ord=1)
        if self.L1_alpham is not None:
            loss_vec = loss_vec + self.alpha_m * torch.linalg.norm(self.posNN[0].weight, dim=1, ord=1)
        return loss_vec


In [161]:
def make_network_config(params,single_trial=False,device=None):
    """Create the network config dictionary (hyperparameters) for training or tuning.

    Args:
        params (dict): key parameters. Reads 'nk', 'Ncells', 'shift_hidden',
            'LinMix', 'pos_features', 'NoL1' and 'NoL2'. It also reads
            'shift_in' / 'shift_out' when present; format_pytorch_data only
            sets these when training the shifter, and they default to None.
        single_trial (bool, optional): if True, use one fixed L2 value instead
            of a ray-tune grid search. Defaults to False.
        device (str, optional): stored under ``'device'`` (read by
            ``model_wrapper``). Defaults to 'cuda' when available, else 'cpu'.

    Returns:
        network_config (dict): dictionary with hyperparameters
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    l2_grid = np.logspace(-2, 3, 20)

    network_config = {}
    network_config['device']        = device
    network_config['in_features']   = params['nk']
    network_config['Ncells']        = params['Ncells']
    network_config['shift_in']      = params.get('shift_in')
    network_config['shift_hidden']  = params['shift_hidden']
    network_config['shift_out']     = params.get('shift_out')
    network_config['LinMix']        = params['LinMix']
    network_config['pos_features']  = params['pos_features']
    # Learning rates must be plain numbers: torch optimizers reject list-valued lr.
    network_config['lr_shift']      = 1e-2
    network_config['lr_w']          = 1e-3
    network_config['lr_b']          = 1e-3
    network_config['lr_m']          = 1e-3
    network_config['L1_alpha']      = None if params['NoL1'] else .0001
    network_config['L1_alpham']     = None

    if params['NoL2']:
        network_config['L2_lambda']   = 0
        network_config['L2_lambda_m'] = 0
    elif single_trial:
        network_config['L2_lambda']   = l2_grid[1]
        network_config['L2_lambda_m'] = l2_grid[1]
    else:
        network_config['L2_lambda']   = tune.grid_search(l2_grid)
        network_config['L2_lambda_m'] = tune.grid_search(l2_grid)
    return network_config

def model_wrapper(ARGS,**kwargs):
    """Instantiate a model class from a network config.

    Args:
        ARGS (tuple): ``(config, Model)``, i.e. the network config dict and the model class.
        **kwargs: extra keyword arguments forwarded to the model constructor.

    Returns:
        model: ``Model(config['in_features'], config['Ncells'], config, device=..., **kwargs)``.
        The device is ``config['device']`` when present; otherwise it is 'cuda'
        if available, else 'cpu'.
    """
    config, Model = ARGS[0], ARGS[1]
    device = config.get('device')
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return Model(config['in_features'], config['Ncells'], config, device=device, **kwargs)

In [181]:
# Build a single-trial (fixed-L2) config for `device` and instantiate a ShifterNetwork from it.
network_config = make_network_config(params,single_trial=True,device=device)
model = model_wrapper((network_config,ShifterNetwork))

In [182]:
# NOTE(review): hard-coded to GPU 0 and fails on CPU-only machines; the imports cell
# already defines `device`, so model.to(device) may be intended.
model.to('cuda:0')

ShifterNetwork(
  (Cell_NN): Sequential(
    (0): Linear(in_features=1200, out_features=108, bias=True)
  )
  (activations): ModuleDict(
    (SoftPlus): Softplus(beta=1, threshold=20)
    (ReLU): ReLU()
  )
  (shifter_nn): Sequential(
    (0): Linear(in_features=3, out_features=20, bias=True)
    (1): Softplus(beta=1, threshold=20)
    (2): Linear(in_features=20, out_features=3, bias=True)
  )
)

In [183]:
# Fetch the full-data training batch and move video, position and spikes to `device`.
minibatch = next(iter(train_dataloader))
vid,pos,y = minibatch
vid,pos,y = vid.to(device),pos.to(device),y.to(device)

In [185]:
# Forward pass; `pos` is passed positionally as the shifter input.
outputs = model(vid,pos)

In [100]:
# checkpoint = torch.load(list(params['save_dir'].glob('GLM_Pytorch_BestShift*'))[0])
# First file in save_dir matching GLM_Pytorch_BestShift* (IndexError if none exist).
filename = list(params['save_dir'].glob('GLM_Pytorch_BestShift*'))[0]

In [167]:

def load_model(model,params,filename,meanbias=None):
    """Load pretrained GLM parameters from a checkpoint into ``model``.

    Every model state entry that does not belong to ``posNN`` or ``shifter_nn``
    is taken from ``checkpoint['model_state_dict']``. Weight tensors are tiled
    ``params['nt_glm_lag']`` times along dim 1, so a single-lag checkpoint can
    initialize a multi-lag model. ``posNN`` and ``shifter_nn`` keep their
    current values.

    Precedence (later wins): checkpoint values, then the SimRF STA/bias (when
    ``params['SimRF']``), then ``meanbias`` for ``Cell_NN.0.bias``.

    Args:
        model (nn.Module): model to load into; modified in place and returned.
        params (dict): needs 'nt_glm_lag' and 'SimRF' (and 'save_dir' when SimRF is set).
        filename (str or Path): checkpoint file containing a 'model_state_dict' entry.
        meanbias (Tensor, optional): value for 'Cell_NN.0.bias'. Defaults to None.

    Returns:
        nn.Module: the same ``model`` with the loaded parameters.
    """
    # map_location='cpu' lets a GPU-saved checkpoint load anywhere; load_state_dict
    # copies the values onto the model's own device.
    checkpoint = torch.load(filename, map_location='cpu')
    saved_state = checkpoint['model_state_dict']
    state_dict = model.state_dict()
    for key in state_dict.keys():
        if ('posNN' in key) or ('shifter_nn' in key):
            continue
        if 'weight' in key:
            state_dict[key] = saved_state[key].repeat(1, params['nt_glm_lag'])
        else:
            state_dict[key] = saved_state[key]
    if params['SimRF']:
        # NOTE(review): this filename differs from the SimRF file used in format_pytorch_data
        # ('withmodel' / 'Modemodel' vs 'withL1' / 'Model1'); confirm it is correct.
        SimRF_file = params['save_dir'].parent.parent.parent/'121521/SimRF/fm1/SimRF_withmodel_dt050_T01_Modemodel_NB10000_Kfold00_best.h5'
        SimRF_data = ioh5.load(SimRF_file)
        # Write into state_dict (not .data) so load_state_dict below does not overwrite it.
        state_dict['Cell_NN.0.weight'] = torch.from_numpy(SimRF_data['sta'].astype(np.float32).T)
        state_dict['Cell_NN.0.bias'] = torch.from_numpy(SimRF_data['bias_sim'].astype(np.float32))
    if meanbias is not None:
        # Override only the bias; the checkpoint weights loaded above are kept.
        state_dict['Cell_NN.0.bias'] = meanbias
    model.load_state_dict(state_dict)
    return model

def setup_model_training(model,params,network_config):
    """Set up an Adam optimizer with per-parameter groups and a StepLR scheduler.

    Parameter groups:
      * ``shifter_nn`` parameters (only when ``params['train_shifter']``):
        lr ``lr_shift``, weight decay 1e-4;
      * ModelID < 2: ``Cell_NN`` weights (lr ``lr_w``, weight decay
        ``L2_lambda``) and biases (lr ``lr_b``);
      * ModelID >= 2: ``posNN`` weights (lr ``lr_w``, weight decay
        ``L2_lambda_m``) and biases (lr ``lr_b``). ``Cell_NN`` is not optimized.

    Learning rates given as lists/tuples (e.g. ``[1e-3]``) are reduced to their
    first element, because torch optimizers need a number.

    Args:
        model (nn.Module): network model to train
        params (dict): key parameters ('train_shifter', 'ModelID', 'Nepochs')
        network_config (dict): dictionary of hyperparameters

    Returns:
        optimizer: pytorch optimizer
        scheduler: learning rate scheduler (StepLR, step size ~Nepochs/5, at least 1)
    """
    def _scalar(value):
        # Accept one-element sequences as well as plain numbers.
        if isinstance(value, (list, tuple)):
            return value[0]
        return value

    lr_w = _scalar(network_config['lr_w'])
    lr_b = _scalar(network_config['lr_b'])
    param_list = []
    if params['train_shifter']:
        param_list.append({'params': list(model.shifter_nn.parameters()), 'lr': _scalar(network_config['lr_shift']), 'weight_decay': .0001})

    target, l2_key = ('Cell_NN', 'L2_lambda') if params['ModelID'] < 2 else ('posNN', 'L2_lambda_m')
    for name, p in model.named_parameters():
        if target not in name:
            continue
        if 'weight' in name:
            param_list.append({'params': [p], 'lr': lr_w, 'weight_decay': network_config[l2_key]})
        elif 'bias' in name:
            param_list.append({'params': [p], 'lr': lr_b})

    optimizer = optim.Adam(params=param_list)
    # StepLR needs step_size >= 1; int(Nepochs/5) would be 0 for Nepochs < 5.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, int(params['Nepochs'] / 5)))
    return optimizer, scheduler

In [169]:
# Rebuild the config and model, initialize from checkpoint `filename` with `meanbias`
# as the Cell_NN bias, and create the optimizer / scheduler.
network_config = make_network_config(params,single_trial=True)
model = model_wrapper((network_config,ShifterNetwork))
model = load_model(model,params,filename,meanbias=meanbias)
optimizer, scheduler = setup_model_training(model,params,network_config)

Cell_NN.0.weight torch.Size([108, 1200])
Cell_NN.0.bias torch.Size([108])
shifter_nn.0.weight torch.Size([20, 3])
shifter_nn.0.bias torch.Size([20])
shifter_nn.2.weight torch.Size([3, 20])
shifter_nn.2.bias torch.Size([3])


In [None]:
def train_network(params, checkpoint_dir=None, data_dir=None):
    """Ray Tune trainable: fit a ShifterNetwork and report the validation loss each epoch.

    Uses the notebook-level ``data``, ``train_idx`` and ``test_idx`` (set in
    earlier cells) to build tensors with ``format_pytorch_data``.

    Args:
        params (dict): key parameters (passed as the tune config).
        checkpoint_dir (str, optional): directory holding a ``checkpoint`` file
            written by this function (a ``(model_state, optimizer_state)`` tuple).
            If given, training resumes from it.
        data_dir (str, optional): unused; kept for the trainable signature.
    """
    import os

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Build tensors first: format_pytorch_data writes nk / Ncells / pos_features (and
    # shift_in / shift_out for shifter training) into params, which the config reads.
    xtr, xte, xtr_pos, xte_pos, ytr, yte, meanbias = format_pytorch_data(data, params, train_idx, test_idx, device=device)
    train_dataset = FreeMovingEphysDataset(xtr, xtr_pos, ytr)
    test_dataset  = FreeMovingEphysDataset(xte, xte_pos, yte)
    train_dataloader = DataLoader(train_dataset, batch_size=xtr.shape[0], num_workers=2, pin_memory=True,)
    test_dataloader = DataLoader(test_dataset, batch_size=xte.shape[0], num_workers=2, pin_memory=True,)

    network_config = make_network_config(params, single_trial=True)
    # model_wrapper reads config['device']; use the device actually trained on.
    network_config['device'] = device
    base_model = model_wrapper((network_config, ShifterNetwork))
    base_model.to(device)
    # DataParallel hides custom attributes (loss, shifter_nn), so the unwrapped model
    # is used for the optimizer, loss and checkpoints.
    model = nn.DataParallel(base_model) if torch.cuda.device_count() > 1 else base_model

    optimizer, scheduler = setup_model_training(base_model, params, network_config)

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"), map_location=device)
        base_model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, (vid, pos, y) in enumerate(train_dataloader):
            vid, pos, y = vid.to(device), pos.to(device), y.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(vid, pos)
            loss = base_model.loss(outputs, y)  # one value per cell
            # Backpropagate the sum of the per-cell losses.
            loss.backward(torch.ones_like(loss))
            optimizer.step()

            # print statistics
            running_loss += loss.mean().item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0
        scheduler.step()

        # Validation loss (mean over cells, averaged over batches)
        val_loss = 0.0
        val_steps = 0
        with torch.no_grad():
            for vid, pos, y in test_dataloader:
                vid, pos, y = vid.to(device), pos.to(device), y.to(device)
                outputs = model(vid, pos)
                val_loss += base_model.loss(outputs, y).mean().item()
                val_steps += 1

        with tune.checkpoint_dir(epoch) as ckpt_dir:
            path = os.path.join(ckpt_dir, "checkpoint")
            torch.save((base_model.state_dict(), optimizer.state_dict()), path)

        # Regression model: only the loss is reported (classification accuracy does not apply).
        tune.report(loss=(val_loss / max(val_steps, 1)))
    print("Finished Training")

In [165]:
# FIXME(review): this cell looks like the body of a training function pasted at top level.
# Nepochs, track_all, l1, pbar, xtrm, xtem, shift_in_tr and shift_in_te are not defined in
# any cell visible here, so it raises NameError on a fresh kernel. Wrap it in a function
# or define these names first.
# Per-epoch loss traces: one row per epoch, one column per cell (ytr.shape[-1]).
vloss_trace = np.zeros((Nepochs, ytr.shape[-1]), dtype=np.float32)
tloss_trace = np.zeros((Nepochs, ytr.shape[-1]), dtype=np.float32)
# Optional per-epoch snapshot of every parameter of l1.
Epoch_GLM = {}
if track_all:
    for name, p in l1.named_parameters():
        Epoch_GLM[name] = np.zeros((Nepochs,) + p.shape, dtype=np.float32)

# Reuse an existing progress bar if one was given; otherwise iterate over a new one.
if pbar is None:
    pbar = pbar2 = tqdm(np.arange(Nepochs))
else:
    pbar2 = np.arange(Nepochs)
# Full-batch training: one optimizer step per iteration on all of xtr.
for batchn in pbar2:
    out = l1(xtr, xtrm, shift_in_tr)
    loss = l1.loss(out, ytr)
    # NOTE: validation runs with autograd enabled (no torch.no_grad()).
    pred = l1(xte, xtem, shift_in_te)
    val_loss = l1.loss(pred, yte)
    # presumably per-cell loss vectors, stored as one row of each trace — confirm l1.loss
    vloss_trace[batchn] = val_loss.clone().cpu().detach().numpy()
    tloss_trace[batchn] = loss.clone().cpu().detach().numpy()
    pbar.set_description('Loss: {:.03f}'.format(np.nanmean(val_loss.clone().cpu().detach().numpy())))
    pbar.refresh()
    optimizer.zero_grad()
    # Gradient of ones: backpropagates the sum of the loss entries.
    loss.backward(torch.ones_like(loss))
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    if track_all:
        for name, p in l1.named_parameters():
            Epoch_GLM[name][batchn] = p.clone().cpu().detach().numpy()

In [50]:
# Shape of the last training-set prediction: (samples, cells).
out.shape

torch.Size([18972, 108])

In [None]:
# Re-fetch the single full-data training batch (video, position, spikes).
vid,pos,Y = next(iter(train_dataloader))