# Imports and Setup

In [1]:
# Imports for Tensor
import csv
import itertools
import math
import numpy as np
import os
import pickle
import shutil
import sys
from collections import OrderedDict
from datetime import datetime
from tempfile import TemporaryDirectory
from typing import Tuple

from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

from torchvision import transforms

from diffusers import StableDiffusionPipeline

sys.path.append("../")

%load_ext autoreload
%autoreload 2

In [2]:
from common.dog import DoG, LDoG, PDoG
from models import CNNClassifier
from common.eeg_datasets import MathSpectrumDataset, MathPreprocessor
from training import train_class, evaluate_class, TrainingConfig
from visualization import *

In [3]:
#!ls 'gdrive/My Drive/Muller Group Drive/Ear EEG/Drowsiness_Detection/classifier_TBME'
# !ls C:\Users\arya_bastani\Documents\ear_eeg\data\ear_eeg_data
ear_eeg_base_path = '/data/shared/signal-diffusion/'
# ear_eeg_base_path = '/mnt/d/data/signal-diffusion/'
ear_eeg_data_path = ear_eeg_base_path + 'eeg_classification_data/ear_eeg_data/ear_eeg_clean'
math_data_path = '/data/shared/signal-diffusion/eeg_math/raw_eeg'
# math_data_path = '/mnt/d/data/signal-diffusion/eeg_math/raw_eeg'
%ls {ear_eeg_data_path}

# Data Preprocessing (run once)

In [4]:
nsamps = 2000
preprocessor = MathPreprocessor(math_data_path, nsamps)

samps_per_file = 20
%time preprocessor.preprocess(samps_per_file, 0.9, 0.05, 0.05)

# Models and DataLoaders

In [5]:
# Peek at the training split: shape of each element of the first sample, then
# the dataset size. 512 and 80 are the same values the training cell names
# RESOLUTION and HOP_LENGTH.
d = MathSpectrumDataset(math_data_path + "/train", 512, 80)
for part in (0, 1):
    print(d[0][part].shape)
print(len(d))

In [6]:
# Pull one shuffled batch from the inspection dataset and report its shapes
# and basic value statistics.
dl = torch.utils.data.DataLoader(d, batch_size=16, shuffle=True)
S, y = next(iter(dl))
print(S.shape, y.shape)
batch_stats = (S.max(), S.min(), S.mean(), S.std())
print(*batch_stats)

# Math EEG Training

In [11]:
# Parameters
BATCH_SIZE = 64
SHUFFLE = True
NUM_WORKERS = 16
N_TOKENS = 128
RESOLUTION = 512
HOP_LENGTH = 80
# Worker processes can only be kept alive between epochs when there are workers.
persistent = NUM_WORKERS > 0

# Datasets: one MathSpectrumDataset per split directory.
train_set = MathSpectrumDataset(math_data_path + "/train", RESOLUTION, HOP_LENGTH)
val_set = MathSpectrumDataset(math_data_path + "/val", RESOLUTION, HOP_LENGTH)
test_set = MathSpectrumDataset(math_data_path + "/test", RESOLUTION, HOP_LENGTH)

# Loaders: all three splits share exactly the same DataLoader options.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    persistent_workers=persistent,
)
train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_set, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, **loader_kwargs)


In [12]:
# Model / optimisation hyperparameters
OUTPUT_DIM = 4
DROPOUT = 0.25
BATCH_FIRST = True # True: (batch, seq, feature). False: (seq, batch, feature)
WEIGHT_DECAY = 0.0001

# Pick the first CUDA device when one is available, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Enable cuDNN's benchmark mode.
torch.backends.cudnn.benchmark = True

# Classification loss shared by training and evaluation.
criterion = nn.CrossEntropyLoss()

In [13]:
# Runtime training parameters: each list below is one axis of the sweep grid.
opt_options = [
    torch.optim.AdamW,
#     LDoG,
    # PDoG,
#     DoG,
]
# restart_options = [0, 100]
restart_options = [0]
# max_eta_options = [None, 100.]
max_eta_options = [None]
# decay_options = [0.0, 0.0001, 0.001]
decay_options = [0.0001]
# decouple_options = [True, False]
decouple_options = [True]

# Run for each option combo
seen_decays = set()
combos = list(itertools.product(
    opt_options, decay_options, restart_options, max_eta_options, decouple_options))
print(f"Running {len(combos)} hyperparameter tests")
for (opt, decay, restart, max_eta, decouple) in combos:
    # Only run once per decay value for AdamW: max_eta and decouple are never
    # passed to it, so the remaining grid points would be duplicates.
    is_adamw = opt is torch.optim.AdamW
    if is_adamw:
        if decay in seen_decays:
            continue
        seen_decays.add(decay)

    # Create a fresh model instance for every run.
    model = CNNClassifier(1, OUTPUT_DIM,)
    model = model.to(device)
    if is_adamw:
        optimizer = opt(model.parameters(), lr=1e-3, weight_decay=decay)
    else:
        optimizer = opt(model.parameters(), weight_decay=decay, max_eta=max_eta, decouple_weight_decay=decouple)

    # Create training configuration
    ARGS = TrainingConfig(epochs=300, val_every_epochs=10, opt_restart_every=restart)

    # Build the run name used for the TensorBoard directory and saved checkpoints.
    postfix = ""
    if isinstance(optimizer, DoG):
        postfix = f"_restart{restart}_etamax{max_eta}_decouple{str(int(decouple))}"
    # Use the class name directly instead of slicing it out of str(type(...)).
    comment = f"cnnclass_{type(optimizer).__name__}_decay{decay}{postfix}"
    log_dir = "./tensorboard_logs/cnn/" + comment + "-" + datetime.now().isoformat(sep='_')
    print("#" * 80)
    print("Training", comment)

    # The context manager closes the writer (flushing pending events) even if
    # training raises; previously one writer per run was left open.
    with SummaryWriter(log_dir=log_dir, comment=comment) as tbsw:
        # Training loop
        losses, accs, val_accs = train_class(
            ARGS, model,
            train_loader, val_loader,
            optimizer, criterion,
            device, tbsw
        )

        # Load best model and evaluate on the test set.
        # NOTE(review): best_model.pt / last_model.pt are presumably written to
        # the working directory by train_class — confirm.
        model.load_state_dict(torch.load('best_model.pt'))
        # Last positional argument: one step past the final training iteration.
        test_loss, test_acc = evaluate_class(model, test_loader, criterion, device, tbsw, ARGS.epochs * len(train_loader) + 1)
    print(f'Test loss={test_loss:.3f}; test accuracy={test_acc:.3f}')

    # Copy model to unique filename so the next run does not overwrite it.
    os.makedirs("models", exist_ok=True)
    shutil.copyfile("best_model.pt", f"models/best_model-{comment}.pt")
    shutil.copyfile("last_model.pt", f"models/last_model-{comment}.pt")


In [18]:
# Confusion plot for `model` as left by the training cell (best checkpoint
# reloaded), evaluated on the training split.
# NOTE(review): `math_class_map` and `plt` are not defined in this notebook;
# presumably they arrive via `from visualization import *` — confirm.
cf, p = class_confusion(model, train_loader, math_class_map, device)
# Assumes class_confusion leaves its plot as the current pyplot axes.
plt.title("Training")

In [19]:
# Same confusion plot, evaluated on the validation split.
# NOTE(review): `math_class_map` and `plt` presumably come from
# `from visualization import *` — confirm.
cf, p = class_confusion(model, val_loader, math_class_map, device)
# Assumes class_confusion leaves its plot as the current pyplot axes.
plt.title("Validation")

In [20]:
# Same confusion plot, evaluated on the held-out test split.
# NOTE(review): `math_class_map` and `plt` presumably come from
# `from visualization import *` — confirm.
cf, p = class_confusion(model, test_loader, math_class_map, device)
# Assumes class_confusion leaves its plot as the current pyplot axes.
plt.title("Test")

In [21]:
# Show class prevalence for the training loader; the trailing `;` suppresses
# the cell's return-value output.
class_prevalence(train_loader, math_class_map);

# Generated Spectrograms

In [4]:
# Vocabulary the random prompts are drawn from.
genders = ["female", "male"]
ages = ["17", "18", "19", "20", "21"]
activities = ["resting", "doing math"]

def make_prompt():
    """Draw one random text prompt together with its 4-way class label.

    Gender, age and activity are each sampled (in that order) with
    ``np.random.choice`` from the module-level lists.

    Returns:
        ``(prompt, y)`` where ``y = 2 * math + gender``; ``math`` is 1 for
        "doing math" and the gender bit is 1 for "female":

        | y | math? | gender |
        |---|-------|--------|
        | 0 |   0   |   0    |
        | 1 |   0   |   1    |
        | 2 |   1   |   0    |
        | 3 |   1   |   1    |
    """
    # Generator unpacking keeps the original draw order: gender, age, activity.
    gender, age, activity = (
        np.random.choice(options) for options in (genders, ages, activities)
    )
    is_math = activity == "doing math"
    is_female = gender == "female"
    y = is_math * 2 + is_female
    prompt = f"an EEG spectrogram of a {gender} {age} year old subject {activity}"
    return prompt, y

def make_n_prompts(n):
    """Lazily yield ``n`` independent ``(prompt, y)`` pairs from ``make_prompt``."""
    yield from (make_prompt() for _ in range(n))

# Print one sample (prompt, y) pair as a quick check of the prompt format.
print(make_prompt())

def im2tensor(image):
    """Convert an image to a single-channel tensor normalised with mean 0.5, std 0.5.

    Args:
        image: an image object exposing ``convert("L")`` (e.g. a PIL image).

    Returns:
        The result of ``ToTensor`` followed by ``Normalize([0.5], [0.5])``
        applied to the grayscale image.
    """
    grayscale = image.convert("L")
    to_normalised_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    return to_normalised_tensor(grayscale)


## Real-Trained Classifier on Generated Data

In [45]:
# Diffusion pipeline used to generate spectrograms (alternate checkpoint kept for reference).
model_path = "../data/stft-full.eeg_math.1/"
# model_path = "../data/stft-post-ft.eeg_math.0/"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

# Classifier restored from the best checkpoint in the working directory,
# moved to the GPU and put in eval mode.
model = CNNClassifier(1, 4)
best_weights = torch.load('best_model.pt')
model.load_state_dict(best_weights)
model.to("cuda").eval()
# Disable gradients on all parameters; as the last expression this also shows the module.
model.requires_grad_(False)

In [46]:
# Generate one image per random prompt and keep it alongside the prompted label.
gen_data = []
for prompt_text, prompt_label in make_n_prompts(200):
    result = pipe(prompt_text, num_inference_steps=50, guidance_scale=7.5)
    gen_data.append((result.images[0], prompt_label))
print(gen_data)

In [50]:
# Persist the generated (image, label) pairs.
# NOTE(review): the filename says "post-ft" while the active `model_path` above
# is the "stft-full" checkpoint — confirm the name matches the generator used.
with open("gen_data.post-ft.pkl", "wb") as pkl_file:
    pickle.dump(gen_data, pkl_file)

In [35]:
def test_gen_data(classifier, gen_data,):
    classifier.eval()
    y_preds = []
    ys = []
    for (image, y) in gen_data:
        input = im2tensor(image).to("cuda")
        yhat = classifier(input)
        ys.append(y)
        y_preds.append(torch.argmax(yhat).item())
    return np.array(ys), np.array(y_preds)

In [51]:
# Run the real-trained classifier over the generated images:
# `ys` are the prompted labels, `yhats` the classifier's predictions.
ys, yhats = test_gen_data(model, gen_data)

In [52]:
# Fraction of generated images whose predicted class equals the prompted class.
matches = ys == yhats
accuracy = sum(matches) / len(ys)
accuracy

In [53]:
# Confusion plot of prompted class vs. predicted class; only the second
# returned value (used here as a figure) is kept.
_, fig = raw_class_confusion(ys, yhats)
# Relabel the y axis: the "true" label here is the class the generator was prompted with.
fig.gca().set_ylabel("Generator Prompted Class")

In [49]:
# Show how often each class was prompted; the trailing `;` suppresses the
# cell's return-value output.
raw_class_prevalence(ys);

## Generative-Trained Classifier on Real Data

### First, make a large dataset

In [5]:
def _skip_safety_check(images, **kwargs):
    """Stand-in safety checker: return the images untouched with a False flag."""
    return images, False


model_path = "../data/stft-full.eeg_math.1/"
# model_path = "../data/stft-post-ft.eeg_math.0/"

# Disable safety checker - spectrograms won't be NSFW
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    safety_checker=_skip_safety_check,
    torch_dtype=torch.float16,
)
pipe.to("cuda")

In [None]:
# Number of images to generate and how many prompts go through the pipeline per call.
N = 1000
batch = 8

gendir = "/data/shared/signal-diffusion/eeg_math/gen-stft/"
os.makedirs(gendir, exist_ok=True)

files = []
ys = []

# Ceiling division. Named `n_batches` rather than `nn` so the `torch.nn`
# import used by earlier cells is not overwritten.
n_batches = (N + batch - 1) // batch
for i in range(n_batches):
    print(f"Batch {i+1}/{n_batches}")
    start = i * batch
    # The final batch is truncated so exactly N images are produced even when
    # N is not a multiple of `batch`.
    n_this = min(batch, N - start)
    prompts, yy = zip(*make_n_prompts(n_this))
    ys.extend(yy)
    images = pipe(list(prompts), num_inference_steps=50, guidance_scale=7.5).images
    for j in range(n_this):
        imname = f"gen-{start + j}.png"
        files.append(imname)
        fname = os.path.join(gendir, imname)
        # Store as single-channel grayscale.
        images[j].convert("L").save(fname)

# NOTE(review): metadata.csv is written to the current working directory, not
# to `gendir` next to the images — confirm that is intended.
# newline="" is what the csv module requires to avoid blank rows on some platforms.
with open("metadata.csv", "w", newline="") as f:
    writer = csv.writer(f)
    # csv.writer objects expose writerow/writerows; there is no `.write` method.
    writer.writerow(["file", "y"])
    writer.writerows(zip(files, ys))