# Imports and Setup

In [3]:
# Imports for Tensor
import csv
import itertools
import math
import numpy as np
import os
import shutil
import sys
from collections import OrderedDict
from datetime import datetime
from tempfile import TemporaryDirectory
from typing import Tuple

from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

from torchvision import transforms

# Put the parent directory on the import path -- presumably so the
# project-local imports in the following cell resolve (TODO confirm layout).
sys.path.append("../")

# Load IPython's autoreload extension; mode 2 reloads all modules before each
# cell executes, so edits to project .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [4]:
from common.dog import DoG, LDoG, PDoG
from models import CNNClassifier
from common.eeg_datasets import MathSpectrumDataset, MathPreprocessor
from training import train_class, evaluate_class, TrainingConfig
from visualization import *

In [5]:
# Root of the shared signal-diffusion data. Every dataset path below is derived
# from this single value, so relocating the data only requires editing one line.
ear_eeg_base_path = '/data/shared/signal-diffusion/'
# ear_eeg_base_path = '/mnt/d/data/signal-diffusion/'  # alternate mount point

# Cleaned ear-EEG recordings (only listed in this notebook, as a path check).
ear_eeg_data_path = ear_eeg_base_path + 'eeg_classification_data/ear_eeg_data/ear_eeg_clean'

# Math EEG data root; the train/val/test datasets are read from
# sub-directories of this path.
math_data_path = ear_eeg_base_path + 'eeg_math/raw_eeg'
# Sanity check: list the ear-EEG directory to confirm the path resolves here.
%ls {ear_eeg_data_path}

# Data Preprocessing (run once)

In [6]:
# Preprocessing settings. Per the section heading, the preprocessing pass only
# needs to be run once, so the call itself stays commented out.
samps_per_file = 100
nsamps = 2000

preprocessor = MathPreprocessor(math_data_path, nsamps)

# Uncomment to run (and time) the preprocessing pass:
#%time preprocessor.preprocess(samps_per_file)

# Models and DataLoaders

In [None]:
# Quick look at the training split. The literals 256 and 192 match RESOLUTION
# and HOP_LENGTH used for the real training datasets further down.
d = MathSpectrumDataset(f"{math_data_path}/train", 256, 192)

# Shapes of the first two elements of sample 0, one per line.
print(d[0][0].shape, d[0][1].shape, sep="\n")

# Number of samples in the split.
print(len(d))

In [None]:
# Draw a single shuffled mini-batch from the probe dataset `d`.
dl = torch.utils.data.DataLoader(dataset=d, shuffle=True, batch_size=16)
S, y = next(iter(dl))

# Batch shapes, followed by max / min / mean / std of the first batch element.
print(S.shape, y.shape)
print(S.max(), S.min(), S.mean(), S.std())

# Math EEG Training

In [None]:
# Parameters
BATCH_SIZE = 64
SHUFFLE = True  # applied to the training loader only
NUM_WORKERS = 16
N_TOKENS = 128
RESOLUTION = 256
HOP_LENGTH = 192
# DataLoader rejects persistent_workers=True when num_workers == 0,
# so only request persistent workers when worker processes exist.
persistent = NUM_WORKERS > 0

# Datasets: one MathSpectrumDataset per split directory.
train_set = MathSpectrumDataset(math_data_path + "/train", RESOLUTION, HOP_LENGTH)
val_set = MathSpectrumDataset(math_data_path + "/val", RESOLUTION, HOP_LENGTH)
test_set = MathSpectrumDataset(math_data_path + "/test", RESOLUTION, HOP_LENGTH)

# Settings shared by all three loaders, kept in one place so the splits
# cannot drift apart.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=persistent,
)

# Only the training data is shuffled. Validation and test batches keep a fixed
# order: shuffling them adds nothing and makes evaluation non-reproducible.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=SHUFFLE, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)


In [None]:
# Model / regularisation hyperparameters
OUTPUT_DIM = 4
DROPOUT = 0.25
BATCH_FIRST = True  # True: (batch, seq, feature). False: (seq, batch, feature)
WEIGHT_DECAY = 1e-4

# Device selection, in order of preference: CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# Enable the cuDNN autotuner only when actually running on CUDA.
torch.backends.cudnn.benchmark = device.type == "cuda"

# Classification loss
criterion = nn.CrossEntropyLoss()

In [None]:
# Hyperparameter sweep: train one CNNClassifier per (optimizer, weight decay,
# optimizer-restart period, max_eta, decoupled-decay) combination, log each run
# to its own TensorBoard directory, evaluate the best checkpoint on the test
# set, and archive the checkpoints under models/.
opt_options = [
    torch.optim.AdamW,
    LDoG,
    # PDoG,
    DoG,
]
restart_options = [0, 100]
# restart_options = [0]
max_eta_options = [None, 100.]
# max_eta_options = [None]
# decay_options = [0.0, 0.0001, 0.001]
decay_options = [0.0001, 0.001]
# decouple_options = [True, False]
decouple_options = [True]

combos = list(itertools.product(
    opt_options, decay_options, restart_options, max_eta_options, decouple_options))

# AdamW is only constructed with (lr, weight_decay) below, and its run name
# only encodes the decay, so it is trained once per decay value (using the
# first restart/max_eta/decouple combination) instead of once per combination.
# The original loop had a stray `continue` here that skipped AdamW entirely
# and left this de-duplication unreachable.
seen_decays = set()
runs = []
for combo in combos:
    opt, decay = combo[0], combo[1]
    if opt is torch.optim.AdamW:
        if decay in seen_decays:
            continue
        seen_decays.add(decay)
    runs.append(combo)

# Report the number of runs that will actually execute.
print(f"Running {len(runs)} hyperparameter tests")

for (opt, decay, restart, max_eta, decouple) in runs:
    is_adamw = opt is torch.optim.AdamW

    # Fresh model for every run so no weights carry over between configurations.
    model = CNNClassifier(1, OUTPUT_DIM).to(device)
    if is_adamw:
        optimizer = opt(model.parameters(), lr=1e-3, weight_decay=decay)
    else:
        optimizer = opt(model.parameters(), weight_decay=decay, max_eta=max_eta,
                        decouple_weight_decay=decouple)

    # Create training configuration
    ARGS = TrainingConfig(epochs=300, val_every_epochs=10, opt_restart_every=restart)

    # Run name. Every non-AdamW optimizer gets the restart/max_eta/decouple
    # postfix, using the same test that selected its constructor arguments,
    # so two distinct configurations can never share a name (and overwrite
    # each other's archived checkpoints).
    postfix = ""
    if not is_adamw:
        postfix = f"_restart{restart}_etamax{max_eta}_decouple{str(int(decouple))}"
    comment = f"cnnclass_{type(optimizer).__name__}_decay{decay}{postfix}"
    tbsw = SummaryWriter(log_dir="./tensorboard_logs/cnn/" + comment + "-" + datetime.now().isoformat(sep='_'),
                         comment=comment)
    print("#" * 80)
    print("Training", comment)

    try:
        # Training loop
        losses, accs, val_accs = train_class(
            ARGS, model,
            train_loader, val_loader,
            optimizer, criterion,
            device, tbsw
        )

        # Load the best checkpoint (best_model.pt, presumably written by
        # train_class -- TODO confirm) and evaluate it on the test set.
        # map_location keeps the load valid whatever device saved the file.
        model.load_state_dict(torch.load("best_model.pt", map_location=device))
        test_loss, test_acc = evaluate_class(
            model, test_loader, criterion, device, tbsw,
            ARGS.epochs * len(train_loader) + 1)
    finally:
        # Flush and release this run's event file even if training fails;
        # otherwise one open writer accumulates per sweep iteration.
        tbsw.close()
    print(f'Test loss={test_loss:.3f}; test accuracy={test_acc:.3f}')

    # Copy checkpoints to run-specific filenames before the next run
    # overwrites best_model.pt / last_model.pt.
    os.makedirs("models", exist_ok=True)
    shutil.copyfile("best_model.pt", f"models/best_model-{comment}.pt")
    shutil.copyfile("last_model.pt", f"models/last_model-{comment}.pt")


In [None]:
# NOTE: `model` is whatever the hyperparameter sweep cell left behind, i.e. the
# last configuration trained, with its best checkpoint loaded.
cf, p = class_confusion(model, train_loader, math_class_map, device)
# Trailing ';' keeps the Text(...) repr of the title out of the cell output.
plt.title("Training");

In [None]:
# NOTE: `model` is whatever the hyperparameter sweep cell left behind, i.e. the
# last configuration trained, with its best checkpoint loaded.
cf, p = class_confusion(model, val_loader, math_class_map, device)
# Trailing ';' keeps the Text(...) repr of the title out of the cell output.
plt.title("Validation");

In [None]:
# NOTE: `model` is whatever the hyperparameter sweep cell left behind, i.e. the
# last configuration trained, with its best checkpoint loaded.
cf, p = class_confusion(model, test_loader, math_class_map, device)
# Trailing ';' keeps the Text(...) repr of the title out of the cell output.
plt.title("Test");

In [None]:
# Run class_prevalence over the training loader -- presumably this shows how
# often each class in math_class_map occurs (confirm in visualization).
# The trailing ';' suppresses the echoed return value.
class_prevalence(train_loader, math_class_map);