## Imports and Setup

In [1]:
# Imports for Tensor
import csv
import itertools
import math
import numpy as np
import os
import pandas as pd
import shutil
import sys
from collections import OrderedDict
from datetime import datetime
from tempfile import TemporaryDirectory
from typing import Tuple

from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

from torchvision import transforms

from diffusers import StableDiffusionPipeline
from datasets import load_dataset

sys.path.append("../")

%load_ext autoreload
%autoreload 2

In [2]:
from common.dog import DoG, LDoG, PDoG
from cnn_models import model_size
from cnn_models import CNNClassifier, CNNClassifierLight
from cnn_models import EfficientNet, ShuffleNet, ResNet
from data_processing.math import MathDataset
from data_processing.parkinsons import ParkinsonsDataset
from data_processing.seed import SEEDDataset
from data_processing.general_dataset import GeneralPreprocessor, GeneralDataset, GeneralSampler
from data_processing.general_dataset import general_class_labels, general_dataset_map
from training import train_class, evaluate_class, TrainingConfig, LabelSmoothingCrossEntropy
from visualization import *

In [3]:
# Global seeds. 205 gave a good train/val/test split — presumably the split code
# draws from numpy's global RNG. torch is seeded as well so torch-side randomness
# (weight init, dropout, DataLoader shuffling) starts from a fixed state.
# Note: cudnn.benchmark is enabled later, so GPU runs are still not bit-exact.
random_seed = 205  # 205 gave a good split for training
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [4]:
# Datapaths
# The data root can be overridden with the SIGNAL_DIFFUSION_DATA environment variable
# (e.g. SIGNAL_DIFFUSION_DATA=/data/shared/signal-diffusion on the shared server);
# without it, the original local path is used.
datadirs = {}
datahome = os.environ.get("SIGNAL_DIFFUSION_DATA", "/mnt/d/data/signal-diffusion")

# Math dataset
datadirs['math'] = f'{datahome}/eeg_math'
datadirs['math-stft'] = os.path.join(datadirs['math'], 'stfts')

# Parkinsons dataset
datadirs['parkinsons'] = f'{datahome}/parkinsons/'
datadirs['parkinsons-stft'] = os.path.join(datadirs['parkinsons'], 'stfts')

# SEED dataset
datadirs['seed'] = f'{datahome}/seed/'
datadirs['seed-stft'] = os.path.join(datadirs['seed'], "stfts")

# Data Preprocessing (run once)

In [5]:
# One-time preprocessing setup. nsamps / ovr_perc / fs are presumably window length
# (samples), window overlap fraction and sampling rate (Hz) — confirm against
# GeneralPreprocessor.
nsamps = 2000
preprocessor = GeneralPreprocessor(
    datadirs,
    nsamps,
    ovr_perc=0.5,
    fs=125,
)
# Uncomment to (re)run preprocessing; it only needs to run once.
# preprocessor.preprocess(resolution=256)

In [6]:
# math_subs = list(math_train_subs.loc[:,"subject"])
# math_m, math_f = 0, 0
# for sub in math_subs:
#     ind = int(sub[-2:])
#     gen = math_train_info.iloc[ind][2]
#     if gen == "F":
#         math_f += 1
#     else:
#         math_m += 1
        
# park_m = park_df.gender[park_df["gender"] == "M"].count()
# park_f = park_df.gender[park_df["gender"] == "F"].count()

In [7]:
# total_math = math_m + math_f
# total_park = park_m + park_f

# print(total_math, total_park)
# print(total_math / (total_math + total_park))

In [8]:
# total_m = math_m + park_m
# total_f = math_f + park_f

# # Male female breakdown
# print(total_m, total_f)
# print(total_m / (total_m + total_f))

# Models and DataLoaders

In [9]:
# # Load Individual Datasets
# math_dataset = MathSpectrumDataset(datadirs['math'] + "/train", resolution=256)
# parkinsons_dataset = ParkinsonsDataset(datadirs['parkinsons-stft'], split="train")

# datasets = [math_dataset, math_dataset]

In [10]:
# d = GeneralDataset(datasets, split="train")
# print(d[0][0].shape)
# print(d[0][1])
# print(len(d))

In [11]:
# dl = torch.utils.data.DataLoader(d, batch_size=16, shuffle=True,)
# S, y = next(iter(dl))
# print(S.shape, y.shape)
# print(S.max(), S.min(), S.mean(), S.std())

# Train on Real Data

In [12]:
# Loader / input parameters
BATCH_SIZE = 32
SHUFFLE = True
NUM_WORKERS = 16
N_TOKENS = 128
RESOLUTION = 256
HOP_LENGTH = 80
# DataLoader rejects persistent_workers=True when num_workers == 0.
persistent = NUM_WORKERS > 0

# Data augmentation (not passed to any dataset in this cell)
randtxfm = transforms.Compose([
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Per-source datasets, built split by split.
math_val_dataset = MathDataset(datadirs['math-stft'], split="val")
parkinsons_val_dataset = ParkinsonsDataset(datadirs['parkinsons-stft'], split="val")
seed_val_dataset = SEEDDataset(datadirs['seed-stft'], split="val")

math_test_dataset = MathDataset(datadirs['math-stft'], split="test")
parkinsons_test_dataset = ParkinsonsDataset(datadirs['parkinsons-stft'], split="test")
seed_test_dataset = SEEDDataset(datadirs['seed-stft'], split="test")

math_real_train_dataset = MathDataset(datadirs['math-stft'], split="train")
parkinsons_real_train_dataset = ParkinsonsDataset(datadirs['parkinsons-stft'], split="train", transform=None)
seed_real_train_dataset = SEEDDataset(datadirs['seed-stft'], split="train")

# The math datasets are built but currently excluded from the combined sets;
# add the matching math_*_dataset to each list below to include them.
val_datasets = [parkinsons_val_dataset, seed_val_dataset]
test_datasets = [parkinsons_test_dataset, seed_test_dataset]
real_train_datasets = [parkinsons_real_train_dataset, seed_real_train_dataset]

# Combined multi-source datasets
val_set = GeneralDataset(val_datasets, split='val')
test_set = GeneralDataset(test_datasets, split='test')
real_train_set = GeneralDataset(real_train_datasets, split='train')

# Training batches are drawn by GeneralSampler, so the training loader gets a
# sampler instead of a shuffle flag; val/test loaders shuffle with SHUFFLE.
train_samp = GeneralSampler(real_train_datasets, BATCH_SIZE, split='train')

real_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                          pin_memory=True, persistent_workers=persistent)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=SHUFFLE, **real_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=SHUFFLE, **real_loader_kwargs)
real_train_loader = torch.utils.data.DataLoader(real_train_set, sampler=train_samp,
                                                **real_loader_kwargs)

In [13]:
# Hyperparameters for the real-data classifier
OUTPUT_DIM = 2
DROPOUT = 0.5
BATCH_FIRST = True  # True: (batch, seq, feature). False: (seq, batch, feature); not used by the CNN cells

# Device: first CUDA GPU when available, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Let cuDNN autotune convolution algorithms (pays off when input sizes stay fixed).
torch.backends.cudnn.benchmark = True

# Cross-entropy with label smoothing (epsilon = 0.1)
criterion = LabelSmoothingCrossEntropy(epsilon=0.1)

In [14]:
# Architecture sanity check: one forward pass on a dummy single-channel
# RESOLUTION x RESOLUTION input, then report the model size. This instance is
# discarded; the training cell builds a fresh model.
model = CNNClassifierLight(in_channels=1, out_dim=OUTPUT_DIM, dropout=DROPOUT, pooling="max")
with torch.no_grad():  # shape check only, no autograd graph needed
    y = model(torch.randn(1, 1, RESOLUTION, RESOLUTION))
model_size(model)

In [15]:
# Runtime training parameters. The `decay` in this tuple is only a placeholder:
# the sweep loop below rebinds `decay` for every run.
opt, decay, restart, max_eta, decouple = (torch.optim.AdamW, 0.05, 0, None, True)

for decay in [0.001,]:
    # Fresh model per sweep value
    model = CNNClassifierLight(1, OUTPUT_DIM, dropout=DROPOUT, pooling="max")
    model = model.to(device)
    if opt == torch.optim.AdamW:
        optimizer = opt(model.parameters(), lr=1e-3, weight_decay=decay)
    else:
        # Non-AdamW optimizers (presumably the DoG variants imported above) take
        # max_eta / decouple_weight_decay instead of an explicit learning rate.
        optimizer = opt(model.parameters(), weight_decay=decay, max_eta=max_eta,
                        decouple_weight_decay=decouple)

    # Create training configuration
    ARGS = TrainingConfig(epochs=1, val_every_epochs=8, opt_restart_every=restart)

    # Run name for the TensorBoard log dir and the saved checkpoints,
    # e.g. "cnnclass_AdamW_decay0.001"; DoG runs also encode their options.
    postfix = ""
    if isinstance(optimizer, DoG):
        postfix = f"_restart{restart}_etamax{max_eta}_decouple{str(int(decouple))}"
    comment = f"cnnclass_{type(optimizer).__name__}_decay{decay}{postfix}"
    tbsw = SummaryWriter(log_dir="./tensorboard_logs/cnn/" + comment + "-" +
                         datetime.now().isoformat(sep='_'),
                         comment=comment)
    print("#" * 80)
    print("Training", comment)

    # Training loop. train_class is expected to leave best_model.pt and
    # last_model.pt in the working directory (both are read below) — TODO confirm.
    losses, accs, val_accs = train_class(
        ARGS, model,
        real_train_loader, val_loader,
        optimizer, criterion,
        device, tbsw
    )

    # Reload the best checkpoint and evaluate on the test set; map_location keeps
    # the load working if the checkpoint was written from a different device.
    model.load_state_dict(torch.load('best_model.pt', map_location=device))
    test_loss, test_acc = evaluate_class(model, test_loader, criterion, device,
                                         tbsw, ARGS.epochs * len(real_train_loader) + 1)
    print(f'Test loss={test_loss:.3f}; test accuracy={test_acc:.3f}')

    # Keep per-run copies so the next run does not overwrite them
    os.makedirs("models", exist_ok=True)
    shutil.copyfile("best_model.pt", f"models/best_model-{comment}.pt")
    shutil.copyfile("last_model.pt", f"models/last_model-{comment}.pt")
    print(f"Copied best model to models/best_model-{comment}.pt")

# Keep a handle on the real-data model; later cells rebind `model`.
rmodel = model

In [16]:
# Class balance of the batches yielded by real_train_loader (trailing `;` hides the return value).
class_prevalence(real_train_loader, general_class_labels);

In [17]:
# Plot the loss/accuracy histories returned by the most recent train_class call.
train_vs_epoch(losses, accs, 'train')

In [18]:
# Confusion matrix of the real-data model on the combined training loader.
# NOTE(review): `plt` is not imported explicitly; presumably it comes from `from visualization import *`.
cf, p = class_confusion(rmodel, real_train_loader, general_class_labels, device)
plt.title("MAIN Training")

In [19]:
# Confusion matrix of the real-data model on the combined validation loader.
cf, p = class_confusion(rmodel, val_loader, general_class_labels, device)
plt.title("Validation")

In [20]:
# Confusion matrix of the real-data model on the combined test loader.
cf, p = class_confusion(rmodel, test_loader, general_class_labels, device)
plt.title("Test")

In [21]:
# Per-source breakdown on the real training data. Loop-local names are used so the
# combined real_train_set / real_train_loader (reused by later cells) are not
# overwritten and the `dataset` module imported from torch.utils.data is not shadowed.
for source_ds in real_train_datasets:
    source_name = type(source_ds).__name__
    per_source_set = GeneralDataset([source_ds], split='train')
    per_source_loader = torch.utils.data.DataLoader(per_source_set, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
                                                    num_workers=NUM_WORKERS, pin_memory=True,
                                                    persistent_workers=persistent)

    # Logged at the same TensorBoard step as the combined test evaluation.
    train_loss, train_acc = evaluate_class(rmodel, per_source_loader, criterion, device,
                                           tbsw, ARGS.epochs * len(real_train_loader) + 1)
    print(f'{source_name}: train loss={train_loss:.3f}; train accuracy={train_acc:.3f}')
    cf, p = class_confusion(rmodel, per_source_loader, general_class_labels, device)
    plt.title(f"Train ({source_name})")

In [22]:
# Per-source breakdown on the validation data. Loop-local names keep the combined
# test_set / test_loader and the test_loss / test_acc results intact for later cells.
for source_ds in val_datasets:
    source_name = type(source_ds).__name__
    per_source_set = GeneralDataset([source_ds], split='val')
    per_source_loader = torch.utils.data.DataLoader(per_source_set, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
                                                    num_workers=NUM_WORKERS, pin_memory=True,
                                                    persistent_workers=persistent)

    # Logged at the same TensorBoard step as the combined test evaluation.
    val_loss, val_acc = evaluate_class(rmodel, per_source_loader, criterion, device,
                                       tbsw, ARGS.epochs * len(real_train_loader) + 1)
    print(f'{source_name}: Val loss={val_loss:.3f}; Val accuracy={val_acc:.3f}')
    cf, p = class_confusion(rmodel, per_source_loader, general_class_labels, device)
    plt.title(f"Validation ({source_name})")

In [23]:
# Per-source breakdown on the test data. Loop-local names keep the combined
# test_set / test_loader (reused by later cells) pointing at all test sources.
for source_ds in test_datasets:
    source_name = type(source_ds).__name__
    per_source_set = GeneralDataset([source_ds], split='test')
    per_source_loader = torch.utils.data.DataLoader(per_source_set, batch_size=BATCH_SIZE, shuffle=SHUFFLE,
                                                    num_workers=NUM_WORKERS, pin_memory=True,
                                                    persistent_workers=persistent)

    # Logged at the same TensorBoard step as the combined test evaluation.
    source_test_loss, source_test_acc = evaluate_class(rmodel, per_source_loader, criterion, device,
                                                       tbsw, ARGS.epochs * len(real_train_loader) + 1)
    print(f'{source_name}: Test loss={source_test_loss:.3f}; test accuracy={source_test_acc:.3f}')
    cf, p = class_confusion(rmodel, per_source_loader, general_class_labels, device)
    plt.title(f"Test ({source_name})")

In [24]:
# Subject distribution across the real training datasets.
subject_prevalence(real_train_datasets)

In [25]:
# Subject distribution across the validation datasets.
subject_prevalence(val_datasets)

In [26]:
# Subject distribution across the test datasets.
subject_prevalence(test_datasets)

In [27]:
# Share of each source dataset within the real training data.
dataset_prevalence(real_train_datasets, general_dataset_map)

In [28]:
# Class balance of whatever loader is currently bound to real_train_loader.
# NOTE(review): verify no earlier per-source loop has rebound real_train_loader,
# otherwise this reports a single source instead of the combined training set.
class_prevalence(real_train_loader, general_class_labels);

In [29]:
# Class balance of the combined validation loader.
class_prevalence(val_loader, general_class_labels);

In [30]:
# NOTE(review): duplicate of the earlier subject_prevalence(val_datasets) cell — candidate for removal.
subject_prevalence(val_datasets)

In [31]:
# Class balance of whatever loader is currently bound to test_loader.
# NOTE(review): verify no earlier per-source loop has rebound test_loader.
class_prevalence(test_loader, general_class_labels);

In [32]:
# Deliberate halt for "Run All": the generated-data experiments below are not run by default.
raise RuntimeError("Stopping here: the generated-data cells below are not run by default.")

# Experiments with a generated (synthetic) dataset

# Generated Spectrograms

In [33]:
# Generated Parkinsons spectrograms, resolved under the same data root as the real
# data (previously hard-coded to /data/shared/signal-diffusion/...).
gendir = os.path.join(datadirs['parkinsons'], "gen-stft.1/")


In [34]:
# Load the generated spectrogram set and print its size and first item.
# NOTE(review): GeneratedSpectrumDataset is not imported by any visible import cell — confirm its module.
gen_set = GeneratedSpectrumDataset(gendir)
n_gen_items = len(gen_set)
print(n_gen_items)
first_gen_item = gen_set[0]
print(first_gen_item)

## Classifier Trained on Generated Data

In [35]:
# Loaders for the generated-data experiments (batch/worker settings from the real-data cell).

# Augmentation for generated training images (same pipeline as the real-data cell)
randtxfm = transforms.Compose([
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Settings shared by every loader built in this cell
gen_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=NUM_WORKERS,
                         pin_memory=True, persistent_workers=persistent)

# Generated data: augmented copy for training, un-augmented copy for evaluation.
# NOTE(review): GeneratedSpectrumDataset is not imported by any visible import cell.
train_set = GeneratedSpectrumDataset(gendir, transform=randtxfm)
train_loader = torch.utils.data.DataLoader(train_set, **gen_loader_kwargs)

noaug_train_set = GeneratedSpectrumDataset(gendir, transform=None)
noaug_train_loader = torch.utils.data.DataLoader(noaug_train_set, **gen_loader_kwargs)

# Real data: rebuild val/test loaders from the existing val_set / test_set.
val_loader = torch.utils.data.DataLoader(val_set, **gen_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, **gen_loader_kwargs)

In [36]:
# Hyperparameters for the generated-data classifier
OUTPUT_DIM = 4
DROPOUT = 0.5
BATCH_FIRST = True  # True: (batch, seq, feature). False: (seq, batch, feature)
WEIGHT_DECAY = 0.0001  # NOTE(review): the training cell below passes its own `decay`, not this value

# Device: first CUDA GPU when available, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
torch.backends.cudnn.benchmark = True

# Plain cross-entropy (no label smoothing) for this experiment
criterion = nn.CrossEntropyLoss()

In [37]:
# Runtime training parameters
opt, decay, restart, max_eta, decouple = (torch.optim.AdamW, 0.01, 0, None, True)


# Create model instance
model = CNNClassifier(1, OUTPUT_DIM, dropout=DROPOUT)
model = model.to(device)
if opt == torch.optim.AdamW:
    optimizer = opt(model.parameters(), lr=1e-3, weight_decay=decay)
else:
    # Non-AdamW optimizers (presumably the DoG variants imported above) take
    # max_eta / decouple_weight_decay instead of an explicit learning rate.
    optimizer = opt(model.parameters(), weight_decay=decay, max_eta=max_eta,
                    decouple_weight_decay=decouple)

# Create training configuration
ARGS = TrainingConfig(epochs=30, val_every_epochs=3, opt_restart_every=restart)

# Run name for the TensorBoard log dir and the saved checkpoints,
# e.g. "cnnclass_gen_AdamW_decay0.01"; DoG runs also encode their options.
postfix = ""
if isinstance(optimizer, DoG):
    postfix = f"_restart{restart}_etamax{max_eta}_decouple{str(int(decouple))}"
comment = f"cnnclass_gen_{type(optimizer).__name__}_decay{decay}{postfix}"
tbsw = SummaryWriter(log_dir="./tensorboard_logs/cnn/" + comment + "-" +
                     datetime.now().isoformat(sep='_'),
                     comment=comment)
print("#" * 80)
print("Training", comment)

# Training loop on generated data, validated on real data. train_class is expected
# to leave best_model.pt / last_model.pt in the working directory — TODO confirm.
losses, accs, val_accs = train_class(
    ARGS, model,
    train_loader, val_loader,
    optimizer, criterion,
    device, tbsw
)

# Reload the best checkpoint and evaluate on the real test set; map_location keeps
# the load working if the checkpoint was written from a different device.
model.load_state_dict(torch.load('best_model.pt', map_location=device))
test_loss, test_acc = evaluate_class(model, test_loader, criterion, device,
                                     tbsw, ARGS.epochs * len(train_loader) + 1)
print(f'Test loss={test_loss:.3f}; test accuracy={test_acc:.3f}')

# Keep per-run copies so the next run does not overwrite them
os.makedirs("models", exist_ok=True)
shutil.copyfile("best_model.pt", f"models/best_model-{comment}.pt")
shutil.copyfile("last_model.pt", f"models/last_model-{comment}.pt")
print(f"Copied best model to models/best_model-{comment}.pt")

# Keep a handle on the generated-data model; later cells rebind `model`.
gmodel = model

### Evaluate on generated data

In [38]:
# Generated-data model evaluated on its own (augmented) generated training loader.
# NOTE(review): parkinsons_class_labels is not imported by any visible import cell
# (only general_class_labels is) — confirm where it comes from.
cf, p = class_confusion(gmodel, train_loader, parkinsons_class_labels, device)
plt.title("Gen-train, Gen-test")

### Evaluate on real data

In [39]:
# Generated-data model evaluated on the real *training* loader.
# NOTE(review): real_train_loader yields GeneralDataset samples (labelled with
# general_class_labels elsewhere), while gmodel has OUTPUT_DIM=4 outputs and is
# scored against parkinsons_class_labels here — confirm the label spaces match.
cf, p = class_confusion(gmodel, real_train_loader, parkinsons_class_labels, device)
plt.title("Gen-train, Real-test")

In [40]:
# Generated-data model evaluated on the real test loader.
cf, p = class_confusion(gmodel, test_loader, parkinsons_class_labels, device)
plt.title("Gen-train, Real-test")

In [41]:
# Class balance of the generated training loader.
class_prevalence(train_loader, parkinsons_class_labels);

## Real-trained classifier on generated data

In [42]:
# Reload a real-data classifier from disk for cross-evaluation on generated data.
# map_location + device (instead of a hard-coded "cuda") lets this run on CPU-only machines.
# NOTE(review): the real-data training cell saves this filename from a
# CNNClassifierLight with OUTPUT_DIM=2, but it is loaded here into
# CNNClassifier(1, OUTPUT_DIM=4) — a state-dict mismatch is likely; confirm the
# intended checkpoint/architecture.
rmodel = CNNClassifier(1, OUTPUT_DIM,)
rmodel.load_state_dict(torch.load("models/best_model-cnnclass_AdamW_decay0.001.pt",
                                  map_location=device))
rmodel.to(device)

### Evaluate on generated data

In [43]:
# Real-data model evaluated on the un-augmented generated data.
cf, p = class_confusion(rmodel, noaug_train_loader, parkinsons_class_labels, device)
plt.title("Real-train, Gen-test")

### Evaluate on real data

In [44]:
# Real-data model evaluated on the real test loader.
cf, p = class_confusion(rmodel, test_loader, parkinsons_class_labels, device)
plt.title("Real-train, Real-test")

In [45]:
# 2x2 cross-evaluation grid: rows = training data (real, generated),
# columns = evaluation data (real test set, un-augmented generated set).
# plt.figure (not plt.subplots) so no stray full-figure Axes is left behind the
# 2x2 panels — newer Matplotlib no longer removes overlapping axes in plt.subplot.
fig = plt.figure(figsize=(10, 10))
panels = [
    (rmodel, test_loader, "Real-train, Real-test"),
    (rmodel, noaug_train_loader, "Real-train, Gen-test"),
    (gmodel, test_loader, "Gen-train, Real-test"),
    (gmodel, noaug_train_loader, "Gen-train, Gen-test"),
]
for panel_idx, (panel_model, panel_loader, panel_title) in enumerate(panels, start=1):
    plt.subplot(2, 2, panel_idx)
    cf, _ = class_confusion(panel_model, panel_loader, parkinsons_class_labels, device, fig=fig)
    accuracy_label = "Accuracy: {:.2%}".format(np.trace(cf) / np.sum(cf))
    plt.title(panel_title)

    # Tick labels only on the outer edges of the grid.
    if panel_idx > 2:  # bottom row
        plt.xticks(rotation=60)
        plt.xlabel("Predicted Class\n" + accuracy_label)
    else:
        plt.xticks([])
        plt.xlabel(accuracy_label)
    if panel_idx % 2 == 1:  # left column
        plt.yticks(rotation=60)
    else:
        plt.yticks([])
    if panel_idx == 1:
        plt.ylabel("True Class")

plt.subplots_adjust(bottom=0.2)

## Train on both

In [46]:
# Mixed real + generated training set, plus fresh val/test loaders.
# NOTE(review): real_train_set is the combined GeneralDataset (general_class_labels)
# while train_set is the generated set used with OUTPUT_DIM=4 — confirm both use
# the same label encoding before concatenating.
both_set = torch.utils.data.ConcatDataset([real_train_set, train_set])

both_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=NUM_WORKERS,
                          pin_memory=True, persistent_workers=persistent)
both_loader = torch.utils.data.DataLoader(both_set, **both_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_set, **both_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, **both_loader_kwargs)

In [47]:
# Runtime training parameters
opt, decay, restart, max_eta, decouple = (torch.optim.AdamW, 0.01, 0, None, True)


# Create model instance
model = CNNClassifier(1, OUTPUT_DIM, dropout=DROPOUT)
model = model.to(device)
if opt == torch.optim.AdamW:
    optimizer = opt(model.parameters(), lr=1e-3, weight_decay=decay)
else:
    # Non-AdamW optimizers (presumably the DoG variants imported above) take
    # max_eta / decouple_weight_decay instead of an explicit learning rate.
    optimizer = opt(model.parameters(), weight_decay=decay, max_eta=max_eta,
                    decouple_weight_decay=decouple)

# Create training configuration
ARGS = TrainingConfig(epochs=15, val_every_epochs=1, opt_restart_every=restart)

# Run name for the TensorBoard log dir and the saved checkpoints,
# e.g. "cnnclass_both_AdamW_decay0.01"; DoG runs also encode their options.
postfix = ""
if isinstance(optimizer, DoG):
    postfix = f"_restart{restart}_etamax{max_eta}_decouple{str(int(decouple))}"
comment = f"cnnclass_both_{type(optimizer).__name__}_decay{decay}{postfix}"
tbsw = SummaryWriter(log_dir="./tensorboard_logs/cnn/" + comment + "-" +
                     datetime.now().isoformat(sep='_'),
                     comment=comment)
print("#" * 80)
print("Training", comment)

# Training loop on real + generated data. train_class is expected to leave
# best_model.pt / last_model.pt in the working directory — TODO confirm.
losses, accs, val_accs = train_class(
    ARGS, model,
    both_loader, val_loader,
    optimizer, criterion,
    device, tbsw
)

# Reload the best checkpoint and evaluate on the real test set. The TensorBoard
# step is derived from both_loader because that is the loader training ran on.
model.load_state_dict(torch.load('best_model.pt', map_location=device))
test_loss, test_acc = evaluate_class(model, test_loader, criterion, device,
                                     tbsw, ARGS.epochs * len(both_loader) + 1)
print(f'Test loss={test_loss:.3f}; test accuracy={test_acc:.3f}')

# Keep per-run copies so the next run does not overwrite them
os.makedirs("models", exist_ok=True)
shutil.copyfile("best_model.pt", f"models/best_model-{comment}.pt")
shutil.copyfile("last_model.pt", f"models/last_model-{comment}.pt")
print(f"Copied best model to models/best_model-{comment}.pt")

# Keep a handle on the mixed-data model; later cells may rebind `model`.
bmodel = model

In [48]:
# Mixed-data model evaluated on the real test loader.
cf, p = class_confusion(bmodel, test_loader, parkinsons_class_labels, device)
plt.title("Both-train, Real-test")

In [49]:
# Mixed-data model evaluated on the un-augmented generated data.
cf, p = class_confusion(bmodel, noaug_train_loader, parkinsons_class_labels, device)
plt.title("Both-train, Gen-test")