# Import Modules

In [2]:
%load_ext autoreload
%autoreload 2

import warnings 
warnings.filterwarnings('ignore')
import logging
import numpy as np
import matplotlib.pyplot as plt

import ray
from ray import tune
from ray import air

import torch
# Enable the cuDNN benchmark mode flag (see PyTorch docs for torch.backends.cudnn.benchmark).
torch.backends.cudnn.benchmark = True
# Use the GPU when one is available; every tensor in later cells is moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##### Custom package #####
import pytorchGLM as pglm
from pytorchGLM.main.training import train_network

##### Plotting settings ######
import matplotlib as mpl

# Global matplotlib style applied to every figure created after this cell:
# 10 pt Arial sans-serif text, thick axes/ticks, no top/right spines, white background.
mpl.rcParams.update({'font.size':         10,
                     'axes.linewidth':    2,
                     'xtick.major.size':  3,
                     'xtick.major.width': 2,
                     'ytick.major.size':  3,
                     'ytick.major.width': 2,
                     'axes.spines.right': False,
                     'axes.spines.top':   False,
                     'font.sans-serif':   "Arial",
                     'font.family':       "sans-serif",
                     # NOTE(review): 42 should select TrueType font embedding in PDFs - confirm in matplotlib docs
                     'pdf.fonttype':      42,
                     'xtick.labelsize':   10,
                     'ytick.labelsize':   10,
                     'figure.facecolor': 'white'

                    })


# Format Data

## Loading  Niell lab Raw Data

load_aligned_data calls 2 functions when preprocessing the raw data: 
- format_raw_data: formats the raw data based on file_dict and params
- interp_raw_data: interpolates the formatted data from format_raw_data

In [None]:
# Input arguments: start from the package defaults (jupyter=True is passed because this runs in a notebook)
args = pglm.arg_parser(jupyter=True)

##### Modify default arguments if needed #####
# Available recording sessions, written as 'date/animal'; only the first one is used below.
dates_all = ['070921/J553RT' ,'101521/J559NC','102821/J570LT','110421/J569LT'] #,'122021/J581RT','020422/J577RT'] # '102621/J558NC' '062921/G6HCK1ALTRN',
args['date_ani']        = dates_all[0]
args['free_move']       = True
args['train_shifter']   = True
args['Nepochs']         = 10000

# Build the parameter dict, file dictionary and experiment object for model 1.
ModelID = 1
params, file_dict, exp = pglm.load_params(args,ModelID,file_dict=None,exp_dir_name=None,nKfold=0,debug=False)


In [None]:
# Load the aligned data (reprocess=False is passed; presumably reuses previously processed files - TODO confirm).
data = pglm.load_aligned_data(file_dict, params, reprocess=False)
# get_modeltype returns an updated params dict that replaces the previous one.
params = pglm.get_modeltype(params)
# Build the dataset dictionary and the network configuration for a single (non-search) run.
datasets, network_config = pglm.load_datasets(file_dict,params,single_trial=True)


In [None]:
# Sanity check: take the first 10 training samples of each array and print their shapes.
x,xpos,y = datasets['xtr'][:10],datasets['xtr_pos'][:10],datasets['ytr'][:10]
print(x.shape,xpos.shape,y.shape)

## Custom Dataset Formatting Base Model

This section is dedicated to formatting any custom datasets. The key components are: 
- Formatting data
    - Train/Test Splits
    - Inputs (time,in_features)
    - Optional Inputs (time,pos_features)
    - Outputs (time,Ncells)
- network_config
    - in_features: input dims
    - Ncells: output dims
    - initW: How to initialize weights, 'zero' or 'normal' 
    - optimizer: optimizer to use: 'adam' or 'sgd'
    - lr_w: learning rate for weights
    - lr_b: learning rate for bias
    - lr_m: learning rate for additional inputs
    - single_trial: flag for single trial or hyperparam search
    - L1_alpha: L1 regularization parameter. Single value or hyperparam search
    - L1_alpham: L1 regularization parameter. Single value or hyperparam search
    - L2_lambda: L2 regularization parameter. Single value or hyperparam search
    - L2_lambda_m: L2 regularization parameter. Single value or hyperparam search

In [None]:
# Reset args to the package defaults and display them as the cell output.
args = pglm.arg_parser(jupyter=True)
args

In [5]:

def load_BaseModel_params(args,exp_dir_name='Testing',ModelID=0,nKfold=0,debug=False):
    """ Load parameter dictionary for custom BaseModel network. Minimal implementation
        adapting to custom datasets.

    Creates the save directories, sets up test_tube versioning and, unless
    ``debug`` is set, writes the parameters to ``model_params.yaml`` inside the
    versioned model directory.

    Args:
        args (dict): Argument dictionary. Keys read here: 'date_ani', 'data_dir',
            'base_dir', 'stim_cond', 'model_dt', 'Nepochs', 'Kfold', 'NoL1', 'NoL2'.
        exp_dir_name (str, optional): Name of the experiment directory created
            under the save directory. Defaults to 'Testing'.
        ModelID (int, optional): Model Identification number. Defaults to 0.
        nKfold (int, optional): Kfold number for versioning. Defaults to 0.
        debug (bool, optional): Forwarded to test_tube; when True the parameter
            yaml is not written. Defaults to False.

    Returns:
        params (dict): dictionary of parameters
        exp (obj): Test_tube object for organizing files and tensorboard
    """
    import yaml
    from pathlib import Path
    from test_tube import Experiment

    ##### Create directories and paths #####
    # 'date/animal' -> 'date_animal' so it can be used inside file names
    date_ani2 = '_'.join(args['date_ani'].split('/'))
    data_dir = Path(args['data_dir']).expanduser() / args['date_ani'] / args['stim_cond']
    base_dir = Path(args['base_dir']).expanduser()
    save_dir = base_dir / args['date_ani'] / args['stim_cond']
    save_dir.mkdir(parents=True, exist_ok=True)
    base_dir.mkdir(parents=True, exist_ok=True)

    ##### Set up test_tube versioning #####
    exp = Experiment(name='ModelID{}'.format(ModelID),
                     save_dir=save_dir / exp_dir_name,
                     debug=debug,
                     version=nKfold)

    save_model = exp.save_dir / exp.name / 'version_{}'.format(nKfold)

    params = {
        ##### Data Parameters #####
        'data_dir':                 data_dir,
        'base_dir':                 base_dir,
        'exp_name_base':            base_dir.name,
        'stim_cond':                args['stim_cond'],
        'save_dir':                 save_dir,
        'exp_name':                 exp.save_dir.name,
        'save_model':               save_model,
        'date_ani2':                date_ani2,
        'model_dt':                 args['model_dt'],
        ##### Model Parameters #####
        'ModelID':                  ModelID,
        'lag_list':                 [0], # List of which timesteps to include in model fit
        'Nepochs':                  args['Nepochs'],
        'Kfold':                    args['Kfold'],
        'NoL1':                     args['NoL1'],
        'NoL2':                     args['NoL2'],
        'initW':                    'zero',
        'train_shifter':            False,
        'model_type':               'pytorchGLM_custom', # For naming files
    }

    params['nt_glm_lag'] = len(params['lag_list']) # number of timesteps for model fits
    params['data_name'] = '_'.join([params['date_ani2'],params['stim_cond']])

    ##### Saves yaml of parameters #####
    if not debug:
        # yaml cannot serialise Path objects cleanly, so store them as posix strings
        serializable = {key: (value.as_posix() if isinstance(value, Path) else value)
                        for key, value in params.items()}

        # test_tube is expected to have created this directory already;
        # create it defensively so the write below cannot fail on a missing folder
        save_model.mkdir(parents=True, exist_ok=True)
        pfile_path = save_model / 'model_params.yaml'
        with open(pfile_path, 'w') as file:
            yaml.dump(serializable, file, sort_keys=True)

    return params, exp


In [None]:
# Input arguments: start from the package defaults
args = pglm.arg_parser(jupyter=True)

##### Modify default arguments if needed #####
# Paths may start with '~'; load_BaseModel_params expands base_dir and data_dir.
args['base_dir']        = '~/Research/SensoryMotorPred_Data/Testing'
args['fig_dir']         = '~/Research/SensoryMotorPred_Data/FigTesting'
args['data_dir']        = '~/Goeppert/nlab-nas/Dylan/freely_moving_ephys/ephys_recordings/'
args['date_ani']        = '011523/TestAni'
args['stim_cond']       = 'Control'
args['Nepochs']         = 50
args['NoL1']            = True
args['NoL2']            = False
args['model_dt']        = 0

# Creates the save directories and (debug defaults to False) writes model_params.yaml.
params, exp = load_BaseModel_params(args=args,exp_dir_name='CustomData',ModelID=0)

In [7]:
from sklearn.gaussian_process.kernels import RBF

# Seed the global NumPy and PyTorch RNGs for this notebook session.
# Note: initialize_GP_inputs below re-seeds both from its own `seed` argument (seed+1).
seed = 2
np.random.seed(seed)
torch.manual_seed(seed)

def initialize_GP_inputs(Npats,length_scale,batch_size,Nx_low,Nx,Ny_star,Nr,seed=42,multi_input=False,pytorch=True):
    """Generate smooth synthetic inputs and targets drawn from Gaussian processes.

    Low-dimensional latent traces are sampled from a zero-mean GP with an RBF
    kernel over the pattern index and projected to ``Nx`` features with a random
    Gaussian matrix. Targets are sampled from the same kind of GP and divided by
    their maximum absolute value.

    Side effect: re-seeds the global NumPy and PyTorch RNGs with ``seed + 1``.

    Args:
        Npats (int): number of patterns (samples along the GP axis).
        length_scale (float): length scale of the RBF kernel.
        batch_size (int): number of independent batches.
        Nx_low (int): number of latent GP traces per batch.
        Nx (int): number of input features after the random projection.
        Ny_star (int): number of target sets; only used when ``multi_input`` is True.
        Nr (int): number of output dimensions.
        seed (int, optional): base random seed. Defaults to 42.
        multi_input (bool, optional): build two input streams that share one
            projection, and ``Ny_star`` target sets. Defaults to False.
        pytorch (bool, optional): return torch tensors if True, numpy arrays
            otherwise. Defaults to True.

    Returns:
        x_all: float32 inputs, (batch_size, Npats, Nx), or
            (batch_size, 2, Npats, Nx) when ``multi_input`` is True.
        y_all: float32 targets, (batch_size, Npats, Nr), or
            (batch_size, Ny_star, Npats, Nr) when ``multi_input`` is True.
    """
    def sample_gp(cov, size):
        # Zero-mean GP draws; the trailing axis of the result has length Npats
        draws = np.random.multivariate_normal(np.zeros(Npats), cov, size=size)
        return torch.from_numpy(draws)

    def peak_abs(values):
        # Maximum absolute value reduced over dim 1 and then dim 2 (dims kept)
        over_dim1 = torch.max(torch.abs(values), dim=1, keepdim=True)[0]
        return torch.max(over_dim1, dim=2, keepdim=True)[0]

    ##### Set random seed #####
    np.random.seed(seed+1)
    torch.manual_seed(seed+1)

    ##### Initialize RBF kernels #####
    kernel = RBF(length_scale=length_scale)
    pattern_idx = np.arange(Npats)[:, np.newaxis]
    cov_x = kernel(pattern_idx, pattern_idx)
    cov_y = kernel(pattern_idx, pattern_idx)

    ##### Initialize inputs #####
    latent_shape = (batch_size, Nx_low)
    projection_shape = (batch_size, Nx_low, Nx)
    # (batch_size, Nx_low, Npats) -> (batch_size, Npats, Nx_low)
    latent0 = torch.transpose(sample_gp(cov_x, latent_shape), 2, 1).float()

    if multi_input:
        latent1 = torch.transpose(sample_gp(cov_x, latent_shape), 2, 1).float()
        projection = torch.randn(size=projection_shape).float()
        streams = (torch.bmm(latent0, projection), torch.bmm(latent1, projection))
        x_all = torch.stack(streams, dim=1).float()
        ##### Initialize target patterns #####
        targets = torch.transpose(sample_gp(cov_y, (1, Ny_star, Nr)), 3, 2)
        y_all = (targets / peak_abs(targets)).repeat(batch_size, 1, 1, 1)
    else:
        projection = torch.randn(size=projection_shape).float()
        x_all = torch.bmm(latent0, projection)
        ##### Initialize target patterns #####
        targets = sample_gp(cov_y, (1, Nr))
        scaled = (targets / peak_abs(targets)).repeat(batch_size, 1, 1)
        y_all = torch.transpose(scaled, -1, -2)

    ##### Cast to float32 and pick the output container #####
    x_all, y_all = x_all.float(), y_all.float()
    if not pytorch:
        x_all, y_all = x_all.numpy(), y_all.numpy()

    return x_all, y_all


In [None]:
from sklearn.model_selection import GroupShuffleSplit

##### Generating data #####
# Synthetic GP data: 1000 samples, 100 input features, 10 output dims (single batch, squeezed away below).
x_all,y_all = initialize_GP_inputs(Npats=1000,length_scale=5,batch_size=1,Nx_low=2,Nx=100,Ny_star=2,Nr=10,pytorch=True)
x_all, y_all = x_all.squeeze(),y_all.squeeze()
# Targets come out scaled to [-1, 1]; shift/scale them to [0, 1].
y_all = (y_all+1)/2
# Z-score every input feature across samples.
x_all = (x_all - np.nanmean(x_all,axis=0))/np.nanstd(x_all,axis=0)

##### Train/Test Splits ####
gss = GroupShuffleSplit(n_splits=1, train_size=.8, random_state=42)
frac = 0.1
nT = x_all.shape[0]
# Label contiguous chunks of `frac`*nT samples with group ids 1..int(1/frac) so that
# the split keeps whole chunks together in either train or test.
groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((frac*(i-1))*nT)) for i in range(1,int(1/frac)+1)])
train_idx, test_idx = next(iter(gss.split(np.arange(x_all.shape[0]), groups=groups)))
# train_idx, test_idx = torch.from_numpy(train_idx), torch.from_numpy(test_idx)
xtr,xte = x_all[train_idx], x_all[test_idx]
# All-zero stand-ins for the optional position inputs, same shape as the main inputs.
xtr_pos,xte_pos = torch.zeros_like(xtr).float(),torch.zeros_like(xte).float()
ytr,yte = y_all[train_idx], y_all[test_idx]

print('X:',xtr.shape,'Xpos:',xtr_pos.shape,'y:',ytr.shape)
print('X:',xte.shape,'Xpos:',xte_pos.shape,'y:',yte.shape)
# Record input/output dimensionality in params.
params['nk'] = xtr.shape[-1]
params['Ncells'] = ytr.shape[-1]
# NOTE(review): the per-output mean is computed over ALL samples (train and test).
meanbias = torch.mean(y_all,dim=0)

# Move every tensor to the selected device before handing it to the trainer.
xtr, xte, xtr_pos, xte_pos, ytr, yte, meanbias=xtr.to(device), xte.to(device), xtr_pos.to(device), xte_pos.to(device), ytr.to(device), yte.to(device), meanbias.to(device)
datasets = {
            'xtr':xtr,
            'xte':xte,
            'xtr_pos':xtr_pos,
            'xte_pos':xte_pos,
            'ytr':ytr,
            'yte':yte,
            'meanbias':meanbias,
        }


In [None]:
# Override the defaults set by load_BaseModel_params, then build the network config.
params['initW'] = 'normal' #'zero' # 'normal'
params['optimizer'] = 'sgd'
network_config = pglm.make_network_config(params,single_trial=0,custom=True)
# Fixed learning rates for this single run (weights and bias respectively).
network_config['lr_w'] = .001
network_config['lr_b'] = .1

In [None]:
# Train once on the dataset dict; presumably returns the training/validation loss traces,
# the fitted model and its optimizer - TODO confirm against train_network.
tloss_trace,vloss_trace,model,optimizer = train_network(network_config,**datasets, params=params,filename=None)

In [None]:
##### Make prediction #####
# Run the model on the held-out inputs and bring prediction and target back to numpy on the CPU.
yhat = model(xte.to(device),xte_pos.to(device)).detach().cpu().numpy().squeeze()
yt = yte.cpu().detach().numpy().squeeze()

In [None]:
# Left panel: one loss curve per output dimension (last axis of vloss_trace), colour-coded.
fig, axs = plt.subplots(1,2,figsize=(8,4))
ax = axs[0]
cmap = pglm.discrete_cmap(vloss_trace.shape[-1],'jet')
for cell in range(vloss_trace.shape[-1]):
    ax.plot(vloss_trace[:,cell],c=cmap(cell))
ax.set_xlabel('iteration')
ax.set_ylabel('loss')
# Colour bar mapping line colour to output-dimension index.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=vloss_trace.shape[-1]))
cbar = fig.colorbar(sm,ax=ax,format=None,shrink=0.7,pad=0.01)
cbar.outline.set_linewidth(1)
cbar.set_label('output dim')
cbar.ax.tick_params(labelsize=12, width=1,direction='in')

# Right panel: target vs. prediction for one output dimension, titled with their correlation.
# `ncell` is also reused by the linear-regression cells further down.
ncell = 0
ax = axs[1]
ax.plot(yt[:,ncell])
ax.plot(yhat[:,ncell])
ax.set_xlabel('time')
ax.set_ylabel('activity')
ax.set_title('cc={:.03}'.format(np.corrcoef(yhat[:,ncell],yt[:,ncell])[1,0]))
plt.show()

In [None]:
from sklearn.linear_model import LinearRegression
# Baseline: fit an ordinary least-squares model on synthetic GP data (numpy output, 50 output dims).
x_all2,y_all2 = initialize_GP_inputs(Npats=1000,length_scale=5,batch_size=1,Nx_low=2,Nx=100,Ny_star=2,Nr=50,pytorch=False)
x_all2, y_all2 = x_all2.squeeze(),y_all2.squeeze()
# Targets to [0, 1]; inputs scaled by their global maximum absolute value.
y_all2 = (y_all2+1)/2
x_all2 = x_all2/np.max(np.abs(x_all2))
# Same grouped 80/20 split as above: contiguous chunks of `frac`*nT samples stay together.
gss = GroupShuffleSplit(n_splits=1, train_size=.8, random_state=42)
frac = 0.1
nT = x_all2.shape[0]
groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((frac*(i-1))*nT)) for i in range(1,int(1/frac)+1)])
train_idx, test_idx = next(iter(gss.split(np.arange(x_all2.shape[0]), groups=groups)))
# train_idx, test_idx = torch.from_numpy(train_idx), torch.from_numpy(test_idx)
xtr2,xte2 = x_all2[train_idx], x_all2[test_idx]
# NOTE(review): the zero position arrays are not used by the regression below.
xtr_pos2,xte_pos2 = np.zeros_like(xtr2),np.zeros_like(xte2)
ytr2,yte2 = y_all2[train_idx], y_all2[test_idx]


l1 = LinearRegression()
l1.fit(xtr2,ytr2)
yhat2 = l1.predict(xte2)
# `ncell` comes from the plotting cell above.
print('cc=',np.corrcoef(yhat2[:,ncell],yte2[:,ncell])[0,1])

In [None]:
# Held-out target vs. linear-regression prediction for output dimension `ncell`.
plt.plot(yte2[:,ncell])
plt.plot(yhat2[:,ncell])

# Ray Tune Training: Parallel Cross Validation

In [3]:
import ray
from ray import tune
from ray.air import session
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.hyperopt import HyperOptSearch
from hyperopt import hp

In [8]:
# Input arguments: start from the package defaults
args = pglm.arg_parser(jupyter=True)

##### Modify default arguments if needed #####
args['base_dir']        = '~/Research/SensoryMotorPred_Data/Testing'
args['fig_dir']         = '~/Research/SensoryMotorPred_Data/FigTesting'
args['data_dir']        = '~/Goeppert/nlab-nas/Dylan/freely_moving_ephys/ephys_recordings/'
args['date_ani']        = '011523/TestAni'
args['stim_cond']       = 'Control'
args['Nepochs']         = 50
args['NoL1']            = True
args['NoL2']            = True
args['model_dt']        = 0

# Creates the save directories and (debug defaults to False) writes model_params.yaml.
params, exp = load_BaseModel_params(args=args,exp_dir_name='CustomData',ModelID=0)

from sklearn.model_selection import GroupShuffleSplit

##### Generating data #####
# Synthetic GP data: 1000 samples, 100 input features, 10 output dims (single batch, squeezed away below).
x_all,y_all = initialize_GP_inputs(Npats=1000,length_scale=5,batch_size=1,Nx_low=2,Nx=100,Ny_star=2,Nr=10,pytorch=True)
x_all, y_all = x_all.squeeze(),y_all.squeeze()
# Targets come out scaled to [-1, 1]; shift/scale them to [0, 1].
y_all = (y_all+1)/2
# Z-score every input feature across samples.
x_all = (x_all - np.nanmean(x_all,axis=0))/np.nanstd(x_all,axis=0)

##### Train/Test Splits ####
gss = GroupShuffleSplit(n_splits=1, train_size=.8, random_state=42)
frac = 0.1
nT = x_all.shape[0]
# Contiguous chunks of `frac`*nT samples share a group id so the split keeps them together.
groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((frac*(i-1))*nT)) for i in range(1,int(1/frac)+1)])
train_idx, test_idx = next(iter(gss.split(np.arange(x_all.shape[0]), groups=groups)))
# train_idx, test_idx = torch.from_numpy(train_idx), torch.from_numpy(test_idx)
xtr,xte = x_all[train_idx], x_all[test_idx]
# All-zero stand-ins for the optional position inputs, same shape as the main inputs.
xtr_pos,xte_pos = torch.zeros_like(xtr).float(),torch.zeros_like(xte).float()
ytr,yte = y_all[train_idx], y_all[test_idx]

print('X:',xtr.shape,'Xpos:',xtr_pos.shape,'y:',ytr.shape)
print('X:',xte.shape,'Xpos:',xte_pos.shape,'y:',yte.shape)

# Record input/output dimensionality in params.
params['nk'] = xtr.shape[-1]
params['Ncells'] = ytr.shape[-1]
# NOTE(review): the per-output mean is computed over ALL samples (train and test).
meanbias = torch.mean(y_all,dim=0)
# Move every tensor to the selected device before handing it to the trainer.
xtr, xte, xtr_pos, xte_pos, ytr, yte, meanbias=xtr.to(device), xte.to(device), xtr_pos.to(device), xte_pos.to(device), ytr.to(device), yte.to(device), meanbias.to(device)
datasets = {
            'xtr':xtr,
            'xte':xte,
            'xtr_pos':xtr_pos,
            'xte_pos':xte_pos,
            'ytr':ytr,
            'yte':yte,
            'meanbias':meanbias,
        }

params['initW'] = 'normal' #'zero' # 'normal'
params['optimizer'] = 'sgd'
network_config = pglm.make_network_config(params,custom=True)
# Replace the fixed learning rates with log-uniform search spaces sampled by Ray Tune.
network_config['lr_w'] = tune.loguniform(1e-4, 1e-2)
network_config['lr_b'] = tune.loguniform(1e-2, 1)

X: torch.Size([800, 100]) Xpos: torch.Size([800, 100]) y: torch.Size([800, 10])
X: torch.Size([200, 100]) Xpos: torch.Size([200, 100]) y: torch.Size([200, 10])


In [9]:
network_config

{'in_features': 100,
 'Ncells': 10,
 'initW': 'normal',
 'optimizer': 'sgd',
 'lr_w': <ray.tune.search.sample.Float at 0x7f31b6b54460>,
 'lr_b': <ray.tune.search.sample.Float at 0x7f31b6b54580>,
 'lr_m': 0.001,
 'single_trial': None,
 'L1_alpha': None,
 'L1_alpham': None,
 'L2_lambda': 0,
 'L2_lambda_m': 0}

In [10]:
# Configuration(s) the search evaluates first: the learning rates used in the single-trial run above.
initial_params = [
    {"lr_w": 0.001,"lr_b": 0.1, },
]
# HyperOpt-based search, capped at 4 concurrently running trials.
algo = HyperOptSearch(points_to_evaluate=initial_params)
algo = ConcurrencyLimiter(algo, max_concurrent=4)
# Number of configurations to sample from the search space.
num_samples = 10

In [11]:

sync_config = tune.SyncConfig()  # the default mode is to use rsync
# Each trial runs train_network with the datasets and params bound in, using 2 CPUs and half a GPU;
# trials are ranked by the minimum reported "avg_loss".
tuner = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(train_network,**datasets, params=params),
        resources={"cpu": 2, "gpu": .5}),
    tune_config=tune.TuneConfig(metric="avg_loss",mode="min",search_alg=algo,num_samples=num_samples),
    param_space=network_config,
    run_config=air.RunConfig(local_dir=params['save_model'], name="NetworkAnalysis",sync_config=sync_config,verbose=2)
)
results = tuner.fit()

best_result = results.get_best_result("avg_loss", "min")

print("Best trial config: {}".format(best_result.config))
print("Best trial final validation loss: {}".format(best_result.metrics["avg_loss"]))
df = results.get_dataframe()
# Locate the '.pt' file whose name ends with the best trial id.
# NOTE(review): indexing [0] raises IndexError if no matching file exists in save_model.
best_network = list(params['save_model'].glob('*{}.pt'.format(best_result.metrics['trial_id'])))[0]
# Store the results table together with the best-network path and trial id.
pglm.h5store(params['save_model'] / 'NetworkAnalysis/experiment_data.h5', df, **{'best_network':best_network,'trial_id':best_result.metrics['trial_id']})

2023-01-16 11:10:00,723	INFO worker.py:1538 -- Started a local Ray instance.


0,1
Current time:,2023-01-16 11:10:11
Running for:,00:00:09.59
Memory:,12.4/125.8 GiB

Trial name,status,loc,L1_alpha,L1_alpham,L2_lambda,L2_lambda_m,Ncells,in_features,initW,lr_b,lr_m,lr_w,optimizer,single_trial,iter,total time (s),avg_loss
train_network_cbcf823f,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.1,0.001,0.001,sgd,,1,1.49916,0.0422347
train_network_be3f38ea,TERMINATED,184.171.84.86:2520119,,,0,0,10,100,normal,0.899085,0.001,0.00124164,sgd,,1,1.55057,0.0168743
train_network_ecc387d1,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.476843,0.001,0.000285234,sgd,,1,0.0834661,0.0219259
train_network_b18b4fc1,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.685615,0.001,0.000302728,sgd,,1,0.0822985,0.023392
train_network_cb357bf1,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.396142,0.001,0.000591262,sgd,,1,0.0869327,0.0163544
train_network_74f5e6a4,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.236181,0.001,0.00636346,sgd,,1,0.0857477,0.0157562
train_network_b5175c4a,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.0198365,0.001,0.0060693,sgd,,1,0.0882962,0.186705
train_network_435e2b79,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.0175796,0.001,0.00504152,sgd,,1,0.0849552,0.154375
train_network_a4840bad,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.0742141,0.001,0.000410058,sgd,,1,0.0852156,0.104674
train_network_75b66696,TERMINATED,184.171.84.86:2519966,,,0,0,10,100,normal,0.0597361,0.001,0.000651634,sgd,,1,0.0873818,0.0953736


Trial name,avg_loss,should_checkpoint
train_network_435e2b79,0.154375,True
train_network_74f5e6a4,0.0157562,True
train_network_75b66696,0.0953736,True
train_network_a4840bad,0.104674,True
train_network_b18b4fc1,0.023392,True
train_network_b5175c4a,0.186705,True
train_network_be3f38ea,0.0168743,True
train_network_cb357bf1,0.0163544,True
train_network_cbcf823f,0.0422347,True
train_network_ecc387d1,0.0219259,True


[2m[36m(train_network pid=2519966)[0m Finished Training
[2m[36m(train_network pid=2519966)[0m Finished Training
[2m[36m(train_network pid=2519966)[0m Finished Training
[2m[36m(train_network pid=2519966)[0m Finished Training
[2m[36m(train_network pid=2519966)[0m Finished Training
[2m[36m(train_network pid=2519966)[0m Finished Training
[2m[36m(train_network pid=2519966)[0m Finished Training
[2m[36m(train_network pid=2519966)[0m Finished Training
[2m[36m(train_network pid=2519966)[0m Finished Training


2023-01-16 11:10:11,723	INFO tune.py:762 -- Total run time: 9.83 seconds (9.56 seconds for the tuning loop).


[2m[36m(train_network pid=2520119)[0m Finished Training
Best trial config: {'in_features': 100, 'Ncells': 10, 'initW': 'normal', 'optimizer': 'sgd', 'lr_w': 0.00636345897537398, 'lr_b': 0.2361814700834408, 'lr_m': 0.001, 'single_trial': None, 'L1_alpha': None, 'L1_alpham': None, 'L2_lambda': 0, 'L2_lambda_m': 0}
Best trial final validation loss: 0.015756191685795784


In [12]:
df

Unnamed: 0,avg_loss,time_this_iter_s,should_checkpoint,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,...,config/L2_lambda_m,config/Ncells,config/in_features,config/initW,config/lr_b,config/lr_m,config/lr_w,config/optimizer,config/single_trial,logdir
0,0.042235,1.499163,True,False,,,1,cbcf823f,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-07,...,0,10,100,normal,0.1,0.001,0.001,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
1,0.016874,1.55057,True,False,,,1,be3f38ea,7722e2b2bca74907b6671a5764174e44,2023-01-16_11-10-11,...,0,10,100,normal,0.899085,0.001,0.001242,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
2,0.021926,0.083466,True,False,,,1,ecc387d1,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-10,...,0,10,100,normal,0.476843,0.001,0.000285,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
3,0.023392,0.082299,True,False,,,1,b18b4fc1,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-10,...,0,10,100,normal,0.685615,0.001,0.000303,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
4,0.016354,0.086933,True,False,,,1,cb357bf1,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-10,...,0,10,100,normal,0.396142,0.001,0.000591,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
5,0.015756,0.085748,True,False,,,1,74f5e6a4,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-10,...,0,10,100,normal,0.236181,0.001,0.006363,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
6,0.186705,0.088296,True,False,,,1,b5175c4a,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-10,...,0,10,100,normal,0.019837,0.001,0.006069,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
7,0.154375,0.084955,True,False,,,1,435e2b79,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-10,...,0,10,100,normal,0.01758,0.001,0.005042,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
8,0.104674,0.085216,True,False,,,1,a4840bad,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-10,...,0,10,100,normal,0.074214,0.001,0.00041,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...
9,0.095374,0.087382,True,False,,,1,75b66696,970c934a09444041a25f3189ee0da078,2023-01-16_11-10-10,...,0,10,100,normal,0.059736,0.001,0.000652,sgd,,/home/seuss/Research/SensoryMotorPred_Data/Tes...


In [None]:
# Load the file saved for the best trial; it is unpacked into two objects, presumably the
# model state dict and the optimizer state - TODO confirm against train_network.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
state_dict,optim2=torch.load(best_network)