# Day 1 ECoG Seq2seq

Michael Nolan

2020.09.16

In [None]:
import os.path as path
from os import makedirs, chmod
import glob
import functools

import time
import datetime
import tqdm

import aopy
import ecog_is2s

import numpy as np
import torch
import torch.optim as optim
from torchvision.transforms import Compose

import matplotlib.pyplot as plt

## Getting the data squared away
Build a dataset from the list of data files, then split it into train/validation/test loaders.

In [None]:
# file list to dataset
data_path_root = 'C:\\Users\\mickey\\aoLab\\Data\\WirelessData\\Goose_Multiscale_M1'
data_path_day = path.join(data_path_root, '180326')
# Search one level of subdirectories under the day folder. Letting path.join
# insert the separators keeps the pattern from being tied to Windows '\\'.
# glob returns matches in arbitrary order; sort them so the dataset order
# (and therefore any seeded partition) is reproducible.
data_file_list = sorted(glob.glob(path.join(data_path_day, '*', '*ECOG*clfp_ds250_fl0u10.dat')))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'mounting to device: {device}')
print(f'files found:\t{len(data_file_list)}')
print(f'files: {data_file_list}')
# Fail here with a clear message rather than later inside dataset construction.
if not data_file_list:
    raise FileNotFoundError(f'no ECoG data files matched under {data_path_day}')
datafile_list = [aopy.data.DataFile(df) for df in data_file_list]

In [None]:
# Sample windowing, passed positionally to aopy.data.DatafileDataset.
# Units are presumably seconds (the run directory is named enc{src_t}_dec{trg_t}) -- TODO confirm.
src_t = 0.5  # source (encoder input) window length
trg_t = 0.2  # target (decoder output) window length
step_t = src_t + trg_t  # step equals the full src+trg window; presumably consecutive samples do not overlap -- confirm
diff_transform = ecog_is2s.Util.add_signal_diff() # no need for the srate parameter, dx est. is z-scored as well
zscore_transform = ecog_is2s.Util.local_zscore()
# Apply z-scoring first, then the diff transform (same order as
# diff_transform(zscore_transform(sample))). Compose is used instead of a named lambda (PEP 8 E731).
transform = Compose([zscore_transform, diff_transform])
dfds_list = [aopy.data.DatafileDataset(df, src_t, trg_t, step_t, device=device) for df in datafile_list]
datafile_concatdataset = aopy.data.DatafileConcatDataset(dfds_list, transform=transform)

In [None]:
# Split the concatenated dataset into train / validation / test loaders.
# partition presumably gives the relative split sizes (4:1:1) -- confirm in aopy.
batch_size = 500
partition = (4, 1, 1)
train_loader, valid_loader, test_loader = datafile_concatdataset.get_data_loaders(
    partition=partition,
    batch_size=batch_size,
    rand_part=True,
)

## Create model
Construct the seq2seq network and its training apparatus: loss criterion, optimizer, and learning-rate scheduler.

In [None]:
# ---- network hyperparameters ----
n_ch = datafile_concatdataset.n_ch  # input channel count, taken from the dataset
n_unit = 1 << 10  # GRU hidden size (1024)
n_layers = 1
dropout = 0.0
use_diff = True
bidirectional = False
# ---- optimization hyperparameters ----
LOSS_OBJ = 'MSE'  # objective name handed to ECOGLoss
LEARN_RATE = 1e-4
LR_SCHEDULE_FACTOR = 0.2  # LR multiplier applied by ReduceLROnPlateau on a plateau
TFR = 0.0  # passed as teacher_forcing_ratio= to model.train_iter
CLIP = 2.0  # passed as clip= to model.train_iter (presumably gradient clipping)

model = ecog_is2s.Seq2Seq.Seq2Seq_GRU(
    input_dim=n_ch,
    hid_dim=n_unit,
    n_layers=n_layers,
    enc_len=0,
    dec_len=0,
    device=device,
    dropout=dropout,
    use_diff=use_diff,
    bidirectional=bidirectional,
).to(device)
print(f'The model has {ecog_is2s.Util.count_parameters(model):,} trainable parameters')

criterion = ecog_is2s.Training.ECOGLoss(objective=LOSS_OBJ)
optimizer = optim.Adam(model.parameters(), lr=LEARN_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=LR_SCHEDULE_FACTOR)

In [None]:
# prep constants
N_EPOCHS = 250
epoch_train_loss = np.zeros(N_EPOCHS)  # mean training loss per epoch of this session
epoch_valid_loss = np.zeros(N_EPOCHS)  # mean validation loss per epoch of this session
best_train_loss = np.inf
best_valid_loss = np.inf  # lowest validation loss so far; gates checkpoint saves
epoch_time = np.zeros(N_EPOCHS)  # wall-clock seconds per epoch
# create session directory
model_save_dir = "D:\\Users\\mickey\\Data\\models\\pyt\\seq2seq"
# To resume an existing run, set model_dir to that run's folder name instead, e.g.:
# model_dir = "enc1.0_dec0.5_srate250_20200924221626"
model_dir = f"enc{src_t}_dec{trg_t}_srate{datafile_concatdataset.srate}_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
checkpoint_save_path = path.join(model_save_dir, model_dir)
checkpoint_file_path = path.join(checkpoint_save_path, 'checkpoint.pt')
# Key resume on the checkpoint file itself: the session directory can exist
# without a checkpoint (e.g. a run that crashed before its first save).
if path.exists(checkpoint_file_path):  # continue training an existing model
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one
    checkpoint_state_dict = torch.load(checkpoint_file_path, map_location=device)
    if 'model_state_dict' in checkpoint_state_dict:  # modern save format
        model.load_state_dict(checkpoint_state_dict['model_state_dict'])
        optimizer.load_state_dict(checkpoint_state_dict['optimizer_state_dict'])
        criterion.load_state_dict(checkpoint_state_dict['criterion_state_dict'])
        if 'scheduler_state_dict' in checkpoint_state_dict:
            scheduler.load_state_dict(checkpoint_state_dict['scheduler_state_dict'])
        # the stored epoch has already been trained; resume at the next one
        epoch_start = checkpoint_state_dict['epoch'] + 1
        # keep the saved best so a worse resumed epoch cannot overwrite the best checkpoint
        best_valid_loss = checkpoint_state_dict.get('valid_loss', np.inf)
    else:  # old save format, model state dict only
        model.load_state_dict(checkpoint_state_dict)
        epoch_start = 0
    model.to(device)  # device load may not be necessary, JIC
else:  # train a new model
    makedirs(checkpoint_save_path, mode=0o777, exist_ok=True)
    epoch_start = 0
# epoch_idx indexes this session's arrays; epoch is the absolute epoch number across resumes
for epoch_idx, epoch in tqdm.tqdm(enumerate(range(epoch_start, epoch_start + N_EPOCHS)), total=N_EPOCHS):
    _t = time.time()
    # get data loaders
    # NOTE(review): loaders are rebuilt every epoch with rand_part=True. If that re-draws the
    # train/valid/test split, samples move between splits across epochs and validation is no
    # longer held out -- confirm get_data_loaders semantics in aopy.
    train_loader, valid_loader, test_loader = datafile_concatdataset.get_data_loaders(partition=partition, batch_size=batch_size, rand_part=True)
    # one training pass, then evaluation on the validation loader
    _, trbl_ = model.train_iter(train_loader, optimizer, criterion, clip=CLIP, teacher_forcing_ratio=TFR)
    epoch_train_loss[epoch_idx] = trbl_.mean()
    _, vabl_ = model.eval_iter(valid_loader, criterion)
    epoch_valid_loss[epoch_idx] = vabl_.mean()
    epoch_time[epoch_idx] = time.time() - _t
    # Step the scheduler before checkpointing, so the saved optimizer/scheduler state
    # matches what the next epoch uses. ReduceLROnPlateau.step's `epoch` argument is deprecated.
    scheduler.step(epoch_valid_loss[epoch_idx])
    # if validation loss has decreased, save the checkpoint
    if epoch_valid_loss[epoch_idx] < best_valid_loss:
        # plain float keeps numpy scalars out of the checkpoint
        best_valid_loss = float(epoch_valid_loss[epoch_idx])
        checkpoint_state_dict = {
            'model_state_dict' : model.state_dict(),
            'optimizer_state_dict' : optimizer.state_dict(),
            'scheduler_state_dict' : scheduler.state_dict(),
            'criterion_state_dict' : criterion.state_dict(),
            'epoch' : epoch,
            'valid_loss' : best_valid_loss,
            'input_dim' : model.input_dim,
            'hid_dim' : model.hid_dim,
            'bidir' : model.bidirectional,
            'use_diff' : model.use_diff
        }
        torch.save(checkpoint_state_dict, checkpoint_file_path)
    # update the loss figure; x-axis shows absolute epoch numbers so resumed runs line up
    epochs_run = np.arange(epoch_start, epoch + 1)
    loss_fig, loss_ax = plt.subplots(1, 1, dpi=100)
    loss_ax.plot(epochs_run, epoch_train_loss[:epoch_idx + 1], 'b.', label='train')
    loss_ax.plot(epochs_run, epoch_valid_loss[:epoch_idx + 1], 'r.', label='validation')
    loss_ax.set_xlabel('epochs')
    loss_ax.set_ylabel(f'loss ({LOSS_OBJ})')
    loss_ax.set_title('Learning Plots, Seq2seq model')
    loss_ax.legend(loc=0)
    loss_fig.savefig(path.join(checkpoint_save_path, 'training_loss.png'))
    plt.close(loss_fig)
