In [None]:
import warnings

import torch, os, random
import torch.nn as nn
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision import transforms as tv_transforms
from data_preparation.dataset import SignalDataset
from architecture.net import ResClassificationNet, SimpleCNClassificationNet, MultiModalClassifier
from pipeline.train import TrainingPipeline
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score

In [None]:
# Training-time augmentation pipeline, passed to the training SignalDataset as `transforms=transforms`.
# The torchvision builders are reached through the `tv_transforms` alias (imported in the first
# cell). This cell rebinds the name `transforms` to the Compose pipeline, so calling
# `transforms.Compose` through that name would raise AttributeError whenever the cell is re-run.
# NOTE(review): torchvision flips act on the last two tensor dims; if samples are (channels, time),
# RandomVerticalFlip reverses the channel order — confirm that is a meaningful augmentation.
transforms = tv_transforms.Compose([
    tv_transforms.RandomHorizontalFlip(p=0.5),
    tv_transforms.RandomVerticalFlip(p=0.5),
    tv_transforms.GaussianBlur(sigma=(0.001, 2), kernel_size=5),
])

# Passed to SignalDataset as transforms_p; presumably the probability of augmenting a sample — TODO confirm.
transforms_p = 0.5

In [None]:
data_dir = "data/samples"

# Load the EEG and NIRS metadata tables, keep only the "wg" task rows, and order each table by
# file path so that row i of one table can be paired with row i of the other further down.
eeg_df, nirs_df = (
    pd.read_csv(os.path.join(data_dir, csv_name))
    .loc[lambda frame: frame["task"] == "wg"]
    .sort_values("path")
    for csv_name in ("EEG.csv", "NIRS.csv")
)

In [None]:
# Preview the first five rows of the filtered, path-sorted EEG metadata.
eeg_df.iloc[:5]

In [None]:
# Largest value in the EEG "timestep" column (presumably consulted when choosing t_eeg_size — confirm).
eeg_df["timestep"].agg("max")

In [None]:
# Preview the first five rows of the filtered, path-sorted NIRS metadata.
nirs_df.iloc[:5]

In [None]:
# Pair each EEG recording with a NIRS recording. The pairing is positional: both tables were
# sorted by "path", and the i-th EEG row is matched with the i-th NIRS row.
# NOTE(review): this assumes the sorted EEG and NIRS paths correspond one-to-one — confirm.
assert len(eeg_df) == len(nirs_df), (
    f"EEG/NIRS row count mismatch: {len(eeg_df)} EEG vs {len(nirs_df)} NIRS rows"
)

df = pd.DataFrame({
    "eeg_path": eeg_df["path"].to_numpy(),
    "nirs_path": nirs_df["path"].to_numpy(),
    # Task, datatype and label metadata are taken from the EEG side of each pair.
    "task": eeg_df["task"].to_numpy(),
    "datatype": eeg_df["datatype"].to_numpy(),
    "class_name": eeg_df["class_name"].to_numpy(),
})

# Shuffle with a fixed random_state so the row order is identical on every Restart & Run All.
# The train/eval split below is built from this frame.
df = df.sample(frac=1.0, random_state=42)
df.head()

In [None]:
task = "wg"
t_eeg_size = 1500
t_nirs_size = 300
train_size = 0.9
signal = "multimodal"
onehot_labels = False
use_rp = False

# Seed every global RNG before any dataset, sampler or model is created. WeightedRandomSampler,
# the torchvision random transforms and (presumably) the pipeline's weight init draw from the torch
# RNG; SignalDataset may use random/NumPy for its split — seeding all three keeps runs reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Keyword arguments shared by the training and the evaluation dataset.
common_dataset_kwargs = dict(
    meta_df=df,
    signal=signal,
    task=task,
    t_eeg_size=t_eeg_size,
    t_nirs_size=t_nirs_size,
    use_rp=use_rp,
    onehot_labels=onehot_labels,
)

# Training split: sized by sample_size=train_size, with the augmentation pipeline attached.
training_dataset = SignalDataset(
    data_dir,
    sample_size=train_size,
    transforms=transforms,
    transforms_p=transforms_p,
    **common_dataset_kwargs,
)

# Evaluation split: excludes the EEG paths chosen for training (presumably leaving the remaining
# rows — confirm in SignalDataset); no transforms are passed.
eval_dataset = SignalDataset(
    data_dir,
    excluded_paths=training_dataset.meta_df["eeg_path"].tolist(),
    **common_dataset_kwargs,
)

In [None]:
# Fetch one training sample and show the shape of each modality's signal tensor.
eeg_signals, nirs_oxy_signals, nirs_deoxy_signals, labels = training_dataset[0]

# One shape per modality: EEG, NIRS oxy, NIRS deoxy (the third entry must be the deoxy signal).
print(eeg_signals.shape, nirs_oxy_signals.shape, nirs_deoxy_signals.shape)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# Side-by-side class distributions: training split on the left, evaluation split on the right.
panels = (
    (training_dataset.meta_df, "class count plot for training samples"),
    (eval_dataset.meta_df, "class count plot for evaluation samples"),
)
for ax, (split_meta, title) in zip(axs, panels):
    sns.countplot(data=split_meta, x="class_name", ax=ax)
    ax.set_title(title)

plt.show()

In [None]:
num_workers = 4
batch_size = 32

# Training: draw len(training_dataset) samples per epoch, with replacement, with probabilities
# given by the dataset's per-sample weights (presumably class-balancing weights — confirm).
# shuffle is left at its default (False) because a sampler is supplied.
_, train_sample_weights = training_dataset.get_sample_weights()
train_dataloader = DataLoader(
    training_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=WeightedRandomSampler(
        train_sample_weights, num_samples=len(training_dataset), replacement=True
    ),
)

# Evaluation: visit every held-out sample exactly once, in a fixed order. Weighted sampling with
# replacement (as for training) would repeat some eval samples and skip others on every epoch,
# making eval metrics noisy and unrepresentative of the held-out set.
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

In [None]:
# --- model architecture ---
n_eeg_channels = 30
n_nirs_channels = 36
# Per-modality weights handed to MultiModalClassifier: equal weight for the three branches.
prediction_weights = dict(eeg_w=1/3, nirs_oxy_w=1/3, nirs_deoxy_w=1/3)
num_classes = len(training_dataset.get_label_names())
dropout = 0.1
track_grads = True

# --- optimisation ---
lr, min_lr = 1e-4, 1e-6
weight_decay = 0.0
# NOTE(review): beta2=0.9999 differs from Adam's default betas (0.9, 0.999) — confirm this is intended.
betas = (0.9, 0.9999)
# CosineAnnealingWarmRestarts: length of the first cycle (in scheduler steps, one per epoch in the
# training loop) and the factor by which each later cycle grows.
T_0, T_mult = 20, 2

# --- runtime / checkpointing ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model_folder = "saved_model"
model_name = f"{signal}_{task}_model.pth.tar"

In [None]:
def _make_branch(n_channels, t_size):
    """Build one single-modality SimpleCNClassificationNet using the shared num_classes and dropout."""
    return SimpleCNClassificationNet(n_channels, t_size, num_classes=num_classes, dropout=dropout)


# One branch per modality, created in the order EEG, NIRS oxy, NIRS deoxy.
eeg_model = _make_branch(n_eeg_channels, t_eeg_size)
nirs_oxy_model = _make_branch(n_nirs_channels, t_nirs_size)
nirs_deoxy_model = _make_branch(n_nirs_channels, t_nirs_size)

# Combine the branches; prediction_weights supplies eeg_w / nirs_oxy_w / nirs_deoxy_w.
classifier = MultiModalClassifier(eeg_model, nirs_oxy_model, nirs_deoxy_model, **prediction_weights)

optimizer = torch.optim.Adam(
    classifier.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=weight_decay,
)

# One-hot targets -> BCELoss, class-index targets -> CrossEntropyLoss.
# NOTE(review): BCELoss expects inputs in [0, 1]; confirm the classifier output is already
# squashed (sigmoid/softmax) when onehot_labels=True.
if onehot_labels:
    lossfunc = nn.BCELoss()
else:
    lossfunc = nn.CrossEntropyLoss()

# Training/evaluation driver; checkpoints go to model_folder/model_name.
pipeline = TrainingPipeline(
    classifier,
    lossfunc,
    optimizer,
    device,
    weight_init=True,
    dirname=model_folder,
    filename=model_name,
    onehot_labels=onehot_labels,
)

# Cosine annealing with warm restarts on the pipeline's optimizer: the first cycle lasts T_0
# steps, each later cycle is T_mult times longer, and the lr decays towards eta_min=min_lr.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    pipeline.optimizer,
    T_0=T_0,
    T_mult=T_mult,
    eta_min=min_lr,
    verbose=True,
)

In [None]:
# 140 = 20 + 40 + 80: with T_0=20 and T_mult=2 the scheduler (stepped once per epoch) restarts
# after epochs 20 and 60, so training ends exactly at the end of the third annealing cycle.
epochs = 140

# Suppress warnings only for this cell. A bare warnings.filterwarnings("ignore") would install a
# kernel-wide filter that silently hides warnings in every cell run afterwards.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for epoch in range(1, epochs + 1):
        print(f"epoch: {epoch} / {epochs}")
        pipeline.train(train_dataloader, verbose=True)
        pipeline.evaluate(eval_dataloader, verbose=True)
        lr_scheduler.step()
        print("-" * 130)

    pipeline.plot_metrics("train")
    pipeline.plot_metrics("eval")