# Imports and Setup

In [1]:
# Imports: standard library, NumPy, tqdm, and PyTorch
import csv
import math
import numpy as np
import os
import sys
from collections import OrderedDict
from datetime import datetime
from tempfile import TemporaryDirectory
from typing import Tuple

from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

%load_ext autoreload
%autoreload 2

In [25]:
from dog import DoG, LDoG
from models import TransformerClassifier
from datasets import EarDataset, EarEEGPreprocessor

In [26]:
#!ls 'gdrive/My Drive/Muller Group Drive/Ear EEG/Drowsiness_Detection/classifier_TBME'
# !ls C:\Users\arya_bastani\Documents\ear_eeg\data\ear_eeg_data
# ear_eeg_base_path = '/data/shared/signal-diffusion/'
ear_eeg_base_path = '/mnt/d/data/signal-diffusion/'
ear_eeg_data_path = ear_eeg_base_path + 'eeg_classification_data/ear_eeg_data/ear_eeg_clean'

%ls {ear_eeg_data_path}

# Data Preprocessing (run once)

In [27]:
# Argument for the (currently disabled) preprocessing call at the bottom of
# this cell.
seq_len = 2000

# Preprocessor constructed from the base data directory.
preprocessor = EarEEGPreprocessor(ear_eeg_base_path)

# Preprocessing only needs to run once; uncomment to run it again.
# %time preprocessor.preprocess(seq_len)

# Models and DataLoaders

In [28]:
# Sanity check: build the training dataset and print the shapes of the two
# components ([0] and [1]) of its first item.
d = EarDataset(f"{ear_eeg_base_path}ear_eeg_train", 40)
for part in (0, 1):
    print(d[0][part].shape)

# Training

In [17]:
def log_etas(logger, opt, step):
    """Write one histogram of ``eta`` values per optimizer parameter group.

    Args:
        logger: object exposing ``add_histogram(tag, values, global_step=...)``.
        opt: optimizer whose ``state_dict()['param_groups']`` entries each
            carry an ``'eta'`` sequence of tensors.
        step: step index attached to every histogram.

    Returns:
        An empty dict; no scalar values are collected.
    """
    param_groups = opt.state_dict()['param_groups']
    for group_idx, group in enumerate(param_groups):
        # Combine the group's etas into one tensor, detached and moved to CPU.
        eta_values = torch.stack(group['eta']).detach().cpu()
        logger.add_histogram(f"Eta.{group_idx}", eta_values, global_step=step)
    return {}

In [18]:
# define training and evaluation functions
def train_epoch(model, iterator, optimizer, criterion, device, logger, progress):
    """Run one optimization pass over ``iterator``.

    Each batch's loss and accuracy are logged to ``logger`` under
    "Loss/train" and "Accuracy/train" at the module-level ``global_step``,
    which is incremented once per batch. When ``optimizer`` is a ``DoG``
    instance, its eta histograms are logged every 100 steps.

    Args:
        model: network being trained; called as ``model(src)``.
        iterator: iterable yielding ``(src, trg)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to ``output.permute(0, 2, 1)`` and ``trg``.
        device: device each batch is moved to.
        logger: object exposing ``add_scalar`` (and ``add_histogram``).
        progress: progress bar, advanced by one per batch.

    Returns:
        ``(losses, accuracies)``: lists with one Python float per batch.
    """
    global global_step
    model.train()
    batch_losses, batch_accuracies = [], []
    for src, trg in iterator:
        src, trg = src.to(device), trg.to(device)

        # Forward pass. The last two output axes are swapped for the criterion
        # (presumably to put class scores on dim 1 -- confirm against the model).
        output = model(src)
        loss = criterion(output.permute(0, 2, 1), trg)

        # Backward pass and parameter update; gradients are cleared afterwards.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Accuracy = fraction of elements whose arg-max over the last axis
        # equals the target.
        predictions = output.argmax(dim=-1)
        loss_value = loss.item()
        accuracy_value = ((predictions == trg).sum() / predictions.nelement()).item()
        batch_losses.append(loss_value)
        batch_accuracies.append(accuracy_value)

        logger.add_scalar("Loss/train", loss_value, global_step=global_step)
        logger.add_scalar("Accuracy/train", accuracy_value, global_step=global_step)
        if isinstance(optimizer, DoG):
            if global_step % 100 == 0:
                log_etas(logger, optimizer, global_step)

        # Refresh the progress bar and advance the shared step counter.
        progress.set_postfix({"loss": round(loss_value, 5), "acc": round(accuracy_value, 3)})
        progress.update(1)
        global_step += 1
    return batch_losses, batch_accuracies

def evaluate(model, iterator, criterion, device, logger):
    """Compute the mean per-batch loss and accuracy of ``model`` on ``iterator``.

    Runs in eval mode with gradients disabled. Both means are taken over the
    number of batches (``len(iterator)``) and logged to ``logger`` under
    "Loss/validate" and "Accuracy/validate" at the module-level
    ``global_step``.

    Args:
        model: network to evaluate; called as ``model(src)``.
        iterator: sized iterable yielding ``(src, trg)`` batches.
        criterion: loss applied to ``output.permute(0, 2, 1)`` and ``trg``.
        device: device each batch is moved to.
        logger: object exposing ``add_scalar``.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over batches.
    """
    global global_step
    model.eval()
    num_batches = len(iterator)
    loss_sum, accuracy_sum = 0, 0
    with torch.no_grad():
        for src, trg in iterator:
            src, trg = src.to(device), trg.to(device)
            output = model(src)
            loss_sum += criterion(output.permute(0, 2, 1), trg).item()
            # Fraction of elements whose arg-max over the last axis matches.
            predictions = output.argmax(dim=-1)
            accuracy_sum += ((predictions == trg).sum() / predictions.nelement()).item()
    mean_loss = loss_sum / num_batches
    mean_accuracy = accuracy_sum / num_batches
    logger.add_scalar("Loss/validate", mean_loss, global_step=global_step)
    logger.add_scalar("Accuracy/validate", mean_accuracy, global_step=global_step)
    return mean_loss, mean_accuracy

In [19]:
# Parameters
BATCH_SIZE = 256
SHUFFLE = True  # applies to the training loader only
NUM_WORKERS = 8
CONTEXT_SAMPS = 40

# Datasets
# `EarDataset` is the dataset class imported from `datasets`; the bare name
# `Dataset` is never imported in this notebook and raised a NameError on a
# fresh kernel.
training_set = EarDataset(ear_eeg_base_path + "ear_eeg_train", CONTEXT_SAMPS)
validation_set = EarDataset(ear_eeg_base_path + "ear_eeg_val", CONTEXT_SAMPS)
test_set = EarDataset(ear_eeg_base_path + "ear_eeg_test", CONTEXT_SAMPS)

# DataLoaders
# Settings shared by all three loaders.
LOADER_KWARGS = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                     pin_memory=True, persistent_workers=True)
training_generator = torch.utils.data.DataLoader(training_set, shuffle=SHUFFLE,
                                                 **LOADER_KWARGS)
# Evaluation loaders are not shuffled: metrics are averaged per batch, so a
# fixed batch composition keeps validation/test numbers reproducible.
validation_generator = torch.utils.data.DataLoader(validation_set, shuffle=False,
                                                   **LOADER_KWARGS)
test_generator = torch.utils.data.DataLoader(test_set, shuffle=False,
                                             **LOADER_KWARGS)

In [22]:
# --- Model hyperparameters ---
INPUT_DIM = CONTEXT_SAMPS * 10
OUTPUT_DIM = 2
HID_DIM = INPUT_DIM // 4
N_LAYERS, N_HEADS = 4, 4
FF_DIM = 256
DROPOUT = 0.1
# Tensor layout flag: True -> (batch, seq, feature); False -> (seq, batch, feature)
BATCH_FIRST = True

# --- Device selection: first CUDA device when available, otherwise CPU ---
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
torch.backends.cudnn.benchmark = True

# --- Model, placed on the selected device ---
model = TransformerClassifier(
    INPUT_DIM, OUTPUT_DIM, HID_DIM, N_LAYERS, N_HEADS, FF_DIM, DROPOUT, BATCH_FIRST
).to(device)

# --- Loss function and optimizer ---
criterion = nn.CrossEntropyLoss()
optimizer = LDoG(model.parameters())

In [23]:
# Log statistics
# NOTE(review): the SummaryWriter docs say `comment` has no effect when
# `log_dir` is given; the call is kept unchanged.
tbsw = SummaryWriter(log_dir="./tensorboard_logs/" + "txfmr_classifier-" + datetime.now().isoformat(sep='_'),
                     comment="txfmr_classifier")

# Training loop
N_EPOCHS = 500
VAL_EVERY = 10  # validate every VAL_EVERY epochs, and always on the last epoch
# Start below any reachable accuracy so the first validation always
# checkpoints. The previous `inf` start with a `<` comparison kept the
# *worst* model instead of the best one.
best_valid_acc = float('-inf')
epoch_losses = []  # one (per-batch train losses, validation loss or None) per epoch
best_epoch = 0
global_step = 0  # shared step counter read and updated by train_epoch / evaluate
prog = tqdm(total=len(training_generator) * N_EPOCHS)
try:
    for epoch in range(N_EPOCHS):
        train_losses, train_accs = train_epoch(model, training_generator, optimizer, criterion, device, tbsw, prog)
        if epoch % VAL_EVERY == 0 or epoch == N_EPOCHS - 1:
            valid_loss, valid_acc = evaluate(model, validation_generator, criterion, device, tbsw)
            # Checkpoint only when validation accuracy improves.
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'model.pt')
            prog.set_postfix({"Epoch": epoch+ 1, "TAcc": round(train_accs[-1], 3), "VAcc": round(valid_acc, 3)})
        else:
            valid_loss = None
        epoch_losses.append((train_losses, valid_loss))
finally:
    # Finish progress bar, even if training is interrupted or fails
    prog.close()

# load best model and evaluate on test set
# map_location keeps the load working when the checkpoint was written on a
# different device than the one currently selected.
model.load_state_dict(torch.load('model.pt', map_location=device))
# NOTE: evaluate() logs under the "Loss/validate" / "Accuracy/validate" tags,
# so the test metrics appear on those TensorBoard curves at the final step.
test_loss, test_acc = evaluate(model, test_generator, criterion, device, tbsw)
print(f'Test loss={test_loss:.3f}; test accuracy={test_acc:.3f}')

In [24]:
# Show the model's repr as the cell output.
model