# Import Modules

In [1]:
%load_ext autoreload
%autoreload 2

import ray
import gc
import cv2
import time
import warnings 
import argparse
import yaml
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import xarray as xr

from ray import tune
from ray import air
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers import ASHAScheduler
from filelock import FileLock
from tqdm.auto import tqdm
from pathlib import Path
from itertools import chain
from typing import Tuple
from asyncio import Event
from test_tube import Experiment
from matplotlib.backends.backend_pdf import PdfPages
from scipy.signal import medfilt
from scipy.stats import binned_statistic
from scipy.interpolate import interp1d
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GroupShuffleSplit
from sklearn.utils import shuffle


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from kornia.geometry.transform import Affine
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# import pytorchGLM.Utils.io_dict_to_hdf5 as ioh5
# from pytorchGLM.Utils.utils import *
# from pytorchGLM.Utils.params import *
# from pytorchGLM.Utils.format_raw_data import *
# from pytorchGLM.Utils.format_model_data import *
# from pytorchGLM.main.models import *

In [1]:
import pytorchGLM as pglm

In [2]:
dir(pglm)

['DataLoader',
 'Dataset',
 'FreeMovingEphysDataset',
 'GroupShuffleSplit',
 'LinearRegression',
 'Path',
 'Utils',
 '__author__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__version__',
 'add_colorbar',
 'arg_parser',
 'argparse',
 'chain',
 'check_path',
 'cv2',
 'discrete_cmap',
 'format_data',
 'format_model_data',
 'format_pytorch_data',
 'format_raw_data',
 'gc',
 'get_freer_gpu',
 'get_modeltype',
 'h5load',
 'h5store',
 'interp1d',
 'interp_nans',
 'interp_raw_data',
 'io_dict_to_hdf5',
 'ioh5',
 'load_Kfold_data',
 'load_aligned_data',
 'load_model',
 'load_params',
 'main',
 'make_network_config',
 'medfilt',
 'nan_helper',
 'nanxcorr',
 'nn',
 'normimgs',
 'np',
 'optim',
 'os',
 'params',
 'pd',
 'plt',
 'setup_model_training',
 'shuffle',
 'sizeof_fmt',
 'str_to_bool',
 'time',
 'torch',
 'tqdm',
 'tune',
 'utils',
 'xr',
 'yaml']

In [3]:
pglm.arg_parser(jupyter=True)

{'date_ani': '070921/J553RT',
 'base_dir': '~/Research/SensoryMotorPred_Data/Testing',
 'fig_dir': '~/Research/SensoryMotorPred_Data/FigTesting',
 'data_dir': '~/Goeppert/nlab-nas/Dylan/freely_moving_ephys/ephys_recordings/',
 'model_dt': 0.05,
 'ds_vid': 4,
 'Kfold': 0,
 'ModRun': '1',
 'Nepochs': 10,
 'load_ray': False,
 'do_norm': True,
 'crop_input': True,
 'free_move': True,
 'thresh_cells': True,
 'fm_dark': False,
 'NoL1': False,
 'NoL2': False,
 'NoShifter': False,
 'do_shuffle': False,
 'use_spdpup': False,
 'only_spdpup': False,
 'train_shifter': False,
 'shifter_5050': False,
 'shifter_5050_run': False,
 'EyeHead_only': False,
 'EyeHead_only_run': False,
 'SimRF': False}

Key Parameters:
- model_dt:     (float) size of time bins in seconds
- date_ani:     (str) date and animal ID
- base_dir:     (str) base directory
- save_dir:     (str) directory where processed data is going to be saved
- data_dir:     (str) directory where raw data is held
- downsamp_vid: (int) factor videos are downsampled by
- lag_list:     (list) which timesteps to include in fits

# Format Data

## Testing Loading Raw Data

In [4]:
# Input arguments
args = pglm.arg_parser(jupyter=True)

dates_all = ['070921/J553RT' ,'101521/J559NC','102821/J570LT','110421/J569LT'] #,'122021/J581RT','020422/J577RT'] # '102621/J558NC' '062921/G6HCK1ALTRN',
args['date_ani']        = dates_all[0]
args['free_move']       = True
args['train_shifter']   = True
args['NoL1']            = False
args['NoL2']            = False
args['do_shuffle']      = False
args['Nepochs']         = 10000


ModelID = 1
params, file_dict, exp = pglm.load_params(args,ModelID,file_dict=None,exp_dir_name=None,nKfold=0,debug=False)


## Prep data for pytorch

In [17]:
params['ModelID']=1
params['position_vars'] = ['th','phi','pitch','roll']#,'speed','eyerad']
params['train_shifter']=True


data = load_aligned_data(file_dict, params, reprocess=False)
data,train_idx_list,test_idx_list = format_data(data, params,do_norm=True,thresh_cells=True,cut_inactive=True)
train_idx = train_idx_list[0]
test_idx = test_idx_list[0]
data = load_Kfold_data(data,params,train_idx,test_idx)
xtr, xte, xtr_pos, xte_pos, ytr, yte, meanbias = format_pytorch_data(data,params,train_idx,test_idx)

train_dataset = FreeMovingEphysDataset(xtr,xtr_pos,ytr)
test_dataset  = FreeMovingEphysDataset(xte,xte_pos,yte)
train_dataloader = DataLoader(train_dataset, batch_size=xtr.shape[0],num_workers=2,pin_memory=True,)
test_dataloader = DataLoader(test_dataset, batch_size=xte.shape[0],num_workers=2,pin_memory=True,)


In [36]:
vid,pos,Y = next(iter(train_dataloader))

# Models

In [8]:
network_config = make_network_config(params,single_trial=True)
model = model_wrapper((network_config,ShifterNetwork))

In [9]:
model.to('cuda:0')

ShifterNetwork(
  (Cell_NN): Sequential(
    (0): Linear(in_features=1200, out_features=108, bias=True)
  )
  (activations): ModuleDict(
    (SoftPlus): Softplus(beta=1, threshold=20)
    (ReLU): ReLU()
  )
  (shifter_nn): Sequential(
    (0): Linear(in_features=3, out_features=20, bias=True)
    (1): Softplus(beta=1, threshold=20)
    (2): Linear(in_features=20, out_features=3, bias=True)
  )
)

In [183]:
minibatch = next(iter(train_dataloader))
vid,pos,y = minibatch
vid,pos,y = vid.to(device),pos.to(device),y.to(device)

In [185]:
outputs = model(vid,pos)

In [5]:
# checkpoint = torch.load(list(params['save_dir'].glob('GLM_Pytorch_BestShift*'))[0])
filename = list(params['save_dir'].glob('GLM_Pytorch_BestShift*'))[0]

In [6]:
params = get_modeltype(params)
network_config = make_network_config(params,single_trial=True)
if params['train_shifter']:
    model = model_wrapper((network_config,ShifterNetwork))
elif (params['ModelID']==2) | (params['ModelID']==3):
    model = model_wrapper((network_config,MixedNetwork))
    model = load_model(model,params,filename,meanbias=meanbias)
else:
    model = model_wrapper((network_config,BaseModel))
    model = load_model(model,params,filename,meanbias=meanbias)
optimizer, scheduler = setup_model_training(model,params,network_config)

In [16]:
params['ModelID']=1
params['position_vars'] = ['th','phi','pitch','roll']#,'speed','eyerad']
params['train_shifter']=True


#####

torch.Size([108])

# Test training

In [6]:

def load_datasets(file_dict,params,single_trial=False):

    data = load_aligned_data(file_dict, params, reprocess=False)
    data,train_idx_list,test_idx_list = format_data(data, params,do_norm=True,thresh_cells=True,cut_inactive=True)
    train_idx = train_idx_list[0]
    test_idx = test_idx_list[0]
    data = load_Kfold_data(data,params,train_idx,test_idx)
    xtr, xte, xtr_pos, xte_pos, ytr, yte, meanbias = format_pytorch_data(data,params,train_idx,test_idx)
    network_config = make_network_config(params,single_trial=single_trial)
    with FileLock(params['save_model']/'data.lock'):
        train_dataset = FreeMovingEphysDataset(xtr,xtr_pos,ytr)
        test_dataset  = FreeMovingEphysDataset(xte,xte_pos,yte)
    return train_dataset, test_dataset, network_config

def train_network(network_config={},params={},train_dataset=None,test_dataset=None):
    if params['train_shifter']:
        model = model_wrapper((network_config,ShifterNetwork))
    elif (params['ModelID']==2) | (params['ModelID']==3):
        model = model_wrapper((network_config,MixedNetwork))
        model = load_model(model,params,filename,meanbias=meanbias)
    else:
        model = model_wrapper((network_config,BaseModel))
        model = load_model(model,params,filename,meanbias=meanbias)

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
            
    model.to(device)

    optimizer, scheduler = setup_model_training(model,params,network_config)
    train_dataloader = DataLoader(train_dataset, batch_size=len(train_dataset), num_workers=2, pin_memory=True,)
    test_dataloader  = DataLoader(test_dataset,  batch_size=len(test_dataset),  num_workers=2, pin_memory=True,)

    tloss_trace = torch.zeros((params['Nepochs'], network_config['Ncells']), dtype=torch.float)
    vloss_trace = torch.zeros((params['Nepochs'], network_config['Ncells']), dtype=torch.float)

    for epoch in (range(params['Nepochs'])):  # loop over the dataset multiple times
        for i, minibatch in enumerate(train_dataloader, 0):
            # get the inputs; minibatch is a list of [vid, pos, y]
            vid,pos,y = minibatch
            vid,pos,y = vid.to(device),pos.to(device),y.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(vid,pos)
            loss = model.loss(outputs, y)
            loss.backward(torch.ones_like(loss))
            optimizer.step()

        # print statistics
        tloss_trace[epoch] = loss.detach().cpu()
            
        if scheduler is not None:
            scheduler.step()

        # Validation loss
        for i, minibatch in enumerate(test_dataloader, 0):
            with torch.no_grad():
                # get the inputs; minibatch is a list of [vid, pos, y]
                vid,pos,y = minibatch
                vid,pos,y = vid.to(device),pos.to(device),y.to(device)
                outputs = model(vid,pos)
                loss = model.loss(outputs, y)
                vloss_trace[epoch] = loss.detach().cpu()

    # Here we save a checkpoint. It is automatically registered with
    # Ray Tune and can be accessed through `session.get_checkpoint()`
    # API in future iterations.
    model_name = 'GLM_{}_ModelID{:d}_dt{:03d}_T{:02d}_NB{}_{}.pt'.format(params['model_type'], params['ModelID'],int(params['model_dt']*1000), params['nt_glm_lag'], params['Nepochs'],session.get_trial_name())
    torch.save((model.state_dict(), optimizer.state_dict()), params['save_model']/ model_name)
    checkpoint = Checkpoint.from_dict({'step':epoch})
    # session.report({"avg_loss": float(torch.mean(vloss_trace[-1],dim=-1).numpy())})
    session.report({'avg_loss':float(torch.mean(vloss_trace[-1],dim=-1).numpy())}, checkpoint=checkpoint)

    print("Finished Training")
    # return dict(avg_loss=float(torch.mean(vloss_trace[-1],dim=-1).numpy()))
    

In [5]:
params = pglm.get_modeltype(params)
params['Nepochs'] = 10
train_dataset, test_dataset, network_config = pglm.load_datasets(file_dict,params,single_trial=False)

AttributeError: module 'pytorchGLM' has no attribute 'load_datasets'

In [8]:
sync_config = tune.SyncConfig()  # the default mode is to use use rsync
tuner = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(train_network,params=params,train_dataset=train_dataset,test_dataset=test_dataset),
        resources={"cpu": 2, "gpu": .5}),
    tune_config=tune.TuneConfig(metric="avg_loss",mode="min",),
    param_space=network_config,
    run_config=air.RunConfig(local_dir=params['save_model'], name="test_experiment",sync_config=sync_config,verbose=2)
)
results = tuner.fit()

best_result = results.get_best_result("avg_loss", "min")

print("Best trial config: {}".format(best_result.config))
print("Best trial final validation loss: {}".format(best_result.metrics["avg_loss"]))
df = results.get_dataframe()
df.to_hdf(params['save_model']/'experiment_data.h5',key='df', mode='w')
best_network = list(params['save_model'].glob('*{}.pt'.format(best_result.metrics['trial_id'])))[0]


2023-01-09 02:40:09,396	INFO worker.py:1538 -- Started a local Ray instance.


0,1
Current time:,2023-01-09 02:40:43
Running for:,00:00:32.46
Memory:,17.1/125.8 GiB

Trial name,status,loc,L2_lambda,L2_lambda_m,iter,total time (s),avg_loss
train_network_fdf98_00000,TERMINATED,184.171.84.86:4193152,0.01,0.01,1,16.5794,1.03623
train_network_fdf98_00001,TERMINATED,184.171.84.86:4193256,1000.0,0.01,1,15.3057,1.11026
train_network_fdf98_00002,TERMINATED,184.171.84.86:4193152,0.01,1000.0,1,11.4374,1.03299
train_network_fdf98_00003,TERMINATED,184.171.84.86:4193256,1000.0,1000.0,1,10.5793,1.08691


  0%|          | 0/10 [00:00<?, ?it/s]0m 
 10%|█         | 1/10 [00:03<00:29,  3.30s/it]
  0%|          | 0/10 [00:00<?, ?it/s]0m 
 20%|██        | 2/10 [00:04<00:16,  2.03s/it]
 30%|███       | 3/10 [00:05<00:11,  1.65s/it]
 10%|█         | 1/10 [00:02<00:25,  2.84s/it]
 40%|████      | 4/10 [00:06<00:08,  1.43s/it]
 20%|██        | 2/10 [00:03<00:14,  1.79s/it]
 50%|█████     | 5/10 [00:07<00:06,  1.31s/it]
 30%|███       | 3/10 [00:04<00:09,  1.41s/it]
 60%|██████    | 6/10 [00:08<00:04,  1.23s/it]
 40%|████      | 4/10 [00:05<00:07,  1.24s/it]
 70%|███████   | 7/10 [00:10<00:03,  1.22s/it]
 50%|█████     | 5/10 [00:06<00:05,  1.15s/it]
 60%|██████    | 6/10 [00:07<00:04,  1.12s/it]
 80%|████████  | 8/10 [00:11<00:02,  1.22s/it]
 90%|█████████ | 9/10 [00:12<00:01,  1.17s/it]
 70%|███████   | 7/10 [00:08<00:03,  1.11s/it]
 80%|████████  | 8/10 [00:10<00:02,  1.10s/it]


Trial name,avg_loss,should_checkpoint
train_network_fdf98_00000,1.03623,True
train_network_fdf98_00001,1.11026,True
train_network_fdf98_00002,1.03299,True
train_network_fdf98_00003,1.08691,True


100%|██████████| 10/10 [00:13<00:00,  1.34s/it]


[2m[36m(train_network pid=4193152)[0m Finished Training


  0%|          | 0/10 [00:00<?, ?it/s]0m 
 90%|█████████ | 9/10 [00:11<00:01,  1.07s/it]
 10%|█         | 1/10 [00:01<00:09,  1.11s/it]
100%|██████████| 10/10 [00:12<00:00,  1.21s/it]


[2m[36m(train_network pid=4193256)[0m Finished Training


  0%|          | 0/10 [00:00<?, ?it/s]0m 
 20%|██        | 2/10 [00:02<00:08,  1.12s/it]
 10%|█         | 1/10 [00:01<00:09,  1.08s/it]
 30%|███       | 3/10 [00:03<00:07,  1.12s/it]
 20%|██        | 2/10 [00:02<00:08,  1.04s/it]
 40%|████      | 4/10 [00:04<00:06,  1.12s/it]
 30%|███       | 3/10 [00:03<00:07,  1.02s/it]
 50%|█████     | 5/10 [00:05<00:05,  1.13s/it]
 40%|████      | 4/10 [00:04<00:06,  1.08s/it]
 60%|██████    | 6/10 [00:06<00:04,  1.15s/it]
 50%|█████     | 5/10 [00:05<00:05,  1.06s/it]
 70%|███████   | 7/10 [00:07<00:03,  1.13s/it]
 60%|██████    | 6/10 [00:06<00:04,  1.06s/it]
 80%|████████  | 8/10 [00:09<00:02,  1.13s/it]
 70%|███████   | 7/10 [00:07<00:03,  1.05s/it]
 90%|█████████ | 9/10 [00:10<00:01,  1.11s/it]
 80%|████████  | 8/10 [00:08<00:02,  1.04s/it]
100%|██████████| 10/10 [00:11<00:00,  1.12s/it]


[2m[36m(train_network pid=4193152)[0m Finished Training


 90%|█████████ | 9/10 [00:09<00:01,  1.03s/it]
100%|██████████| 10/10 [00:10<00:00,  1.03s/it]
2023-01-09 02:40:43,435	INFO tune.py:762 -- Total run time: 32.64 seconds (32.45 seconds for the tuning loop).


[2m[36m(train_network pid=4193256)[0m Finished Training
Best trial config: {'in_features': 1200, 'Ncells': 108, 'shift_in': 3, 'shift_hidden': 20, 'shift_out': 3, 'LinMix': False, 'pos_features': 3, 'lr_shift': 0.01, 'lr_w': 0.001, 'lr_b': 0.001, 'lr_m': 0.001, 'L1_alpha': None, 'L1_alpham': None, 'L2_lambda': 0.01, 'L2_lambda_m': 1000.0}
Best trial final validation loss: 1.032989501953125


In [165]:
torch.save((model.state_dict(), optimizer.state_dict()), params['save_model']/ "checkpoint.pt")
checkpoint = Checkpoint.from_directory("my_model")

In [25]:
results.get_dataframe()['config/L2_lambda_m']

0       0.01
1       0.01
2    1000.00
3    1000.00
Name: config/L2_lambda_m, dtype: float64

In [9]:
results.get_best_result(metric='loss',mode='min',filter_nan_and_inf=False)



RuntimeError: No best trial found for the given metric: loss. This means that no trial has reported this metric.

In [16]:
results.get_best_result().metrics_dataframe()



RuntimeError: No best trial found for the given metric: loss. This means that no trial has reported this metric, or all values reported for this metric are NaN. To not ignore NaN values, you can set the `filter_nan_and_inf` arg to False.

In [None]:
params = get_modeltype(params)
train_dataset, test_dataset, network_config = load_datasets(file_dict,params,single_trial=True)

if params['train_shifter']:
    model = model_wrapper((network_config,ShifterNetwork))
elif (params['ModelID']==2) | (params['ModelID']==3):
    model = model_wrapper((network_config,MixedNetwork))
    model = load_model(model,params,filename,meanbias=meanbias)
else:
    model = model_wrapper((network_config,BaseModel))
    model = load_model(model,params,filename,meanbias=meanbias)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        
model.to(device)

optimizer, scheduler = setup_model_training(model,params,network_config)
train_dataloader = DataLoader(train_dataset, batch_size=len(train_dataset), num_workers=2, pin_memory=True,)
test_dataloader  = DataLoader(test_dataset,  batch_size=len(test_dataset),  num_workers=2, pin_memory=True,)
params['Nepochs']=10

tloss_trace = torch.zeros((params['Nepochs'], network_config['Ncells']), dtype=torch.float)
vloss_trace = torch.zeros((params['Nepochs'], network_config['Ncells']), dtype=torch.float)

for epoch in tqdm(range(params['Nepochs'])):  # loop over the dataset multiple times
    for i, minibatch in enumerate(train_dataloader, 0):
        # get the inputs; minibatch is a list of [vid, pos, y]
        vid,pos,y = minibatch
        vid,pos,y = vid.to(device),pos.to(device),y.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(vid,pos)
        loss = model.loss(outputs, y)
        loss.backward(torch.ones_like(loss))
        optimizer.step()

    # print statistics
    tloss_trace[epoch] = loss.detach().cpu()
        
    if scheduler is not None:
        scheduler.step()

Tot_units: (128,)
Good_units: (108,)


  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
vid,pos,Y = next(iter(train_dataloader))