# **EEG (Electroencephalography) / NIRS (Near Infrared Spectroscopy) Data Modeling**

In [None]:
import torch, os
import torch.nn as nn
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from data_preparation.dataset import SignalDataset
from architecture.net import ClassifierNet
from pipeline.train import TrainingPipeline
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score

In [None]:
# Training-time augmentation pipeline handed to the training SignalDataset.
# The torchvision module is bound to a private alias so that assigning the
# Compose pipeline to the public name `transforms` does not shadow the module:
# with `transforms = transforms.Compose(...)`, re-running this cell raised
# AttributeError because `transforms` was no longer the torchvision module.
import torchvision.transforms as _tv_transforms

transforms = _tv_transforms.Compose([
    _tv_transforms.RandomHorizontalFlip(p=0.5),
    _tv_transforms.RandomVerticalFlip(p=0.5),
    _tv_transforms.GaussianBlur(sigma=(0.001, 2), kernel_size=5),
])

# Forwarded to SignalDataset as `transforms_p`; presumably the probability of
# applying the pipeline to a given sample — confirm in data_preparation.dataset.
transforms_p = 0.5

# 1. Load Data

In [None]:
# --- What to load ------------------------------------------------------------
# Root folder holding the samples and the per-signal metadata CSV.
data_dir = "data/samples"
task = "nback"
signal = "EEG"
hemoglobin = None

# --- How to sample it --------------------------------------------------------
# Forwarded to SignalDataset as `t_size`; presumably timesteps per sample.
t_size = 300
# Forwarded as `sample_size` for the training split; presumably the fraction
# of recordings used for training (the remainder becomes the eval split).
train_size = 0.9
shuffle = True

# --- Representation and label encoding ---------------------------------------
use_spectrogram = False
use_rp = True  # recurrence-plot representation
onehot_labels = False

# Class names removed from the task before building either split.
excluded_classes = [
    "0-back target",
    "0-back session",
    "2-back session",
    "3-back session",
    "3-back target",
    "3-back non-target",
]

In [None]:
# Metadata index for the chosen signal, e.g. "<data_dir>/EEG.csv";
# head() previews the first rows in the notebook output.
meta_csv_path = os.path.join(data_dir, signal.upper() + ".csv")
meta_df = pd.read_csv(meta_csv_path)
meta_df.head()

In [None]:
# Keyword arguments common to both splits, so the two datasets cannot drift
# apart in configuration.
_shared_dataset_kwargs = dict(
    meta_df=meta_df,
    signal=signal,
    task=task,
    hemoglobin=hemoglobin,
    excluded_classes=excluded_classes,
    t_size=t_size,
    use_spectrogram=use_spectrogram,
    use_rp=use_rp,
    onehot_labels=onehot_labels,
    shuffle=shuffle,
)

# Training split: sized by `train_size` and given the augmentation pipeline.
training_dataset = SignalDataset(
    data_dir,
    sample_size=train_size,
    transforms=transforms,
    transforms_p=transforms_p,
    **_shared_dataset_kwargs,
)

# Evaluation split: excludes every path already selected for training;
# no transforms are passed.
eval_dataset = SignalDataset(
    data_dir,
    excluded_paths=training_dataset.meta_df["path"].tolist(),
    **_shared_dataset_kwargs,
)

print("Number of train samples:", len(training_dataset))
print("Number of eval samples:", len(eval_dataset))
print("Number of classes:", len(training_dataset.get_label_names()))
print("Class names:", training_dataset.get_label_names())

# 2. Visualise a Sample

In [None]:
# Preview one training sample in whichever representation the dataset was
# configured to produce: spectrogram, recurrence plot, or raw multichannel
# time series.
sample_signal, label = training_dataset[0]
print(sample_signal.shape)
print(label.shape)
# `.item()` only works on a single-element label. With onehot_labels=True the
# label presumably holds one entry per class, so report the arg-max index.
if label.squeeze().ndim == 0:
    class_index = label.item()
else:
    class_index = label.argmax().item()
print(f"class label: {class_index}")

plt.figure(figsize=(20, 5))
if training_dataset.use_spectrogram:
    if training_dataset.avg_spectrogram_ch:
        print(f"spectrogram kwargs: {training_dataset.spectrogram_kwargs}")
        plt.imshow(sample_signal.squeeze())
        plt.title(f"Averaged Channel {signal} Spectrogram")
        plt.xlabel("Timesteps")
        plt.ylabel("Channel readings")
        plt.show()
    else:
        # Non-averaged spectrograms are not rendered here; say so and drop
        # the empty figure instead of silently leaving a blank plot.
        print("Per-channel spectrogram preview is not implemented.")
        plt.close()

elif training_dataset.use_rp:
    # `cmap` is an imshow argument. plt.title forwards extra kwargs to a Text
    # artist, which has no `cmap` property, so passing it there raised.
    plt.imshow(sample_signal.squeeze(), cmap="gray")
    plt.title("Recurrence plot")
    plt.show()

else:
    for channel in sample_signal.squeeze():
        plt.plot(channel)
    plt.title(f"All channels {signal} Time signals")
    plt.xlabel("Timesteps")
    plt.ylabel("Channel readings")
    plt.show()

# 3. Check Class Imbalance

In [None]:
# One count plot per split, side by side, to compare the class balance of the
# training and evaluation sets.
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for ax, split_df, split_name in zip(
    axs,
    (training_dataset.meta_df, eval_dataset.meta_df),
    ("training", "evaluation"),
):
    sns.countplot(split_df, x="class_name", ax=ax)
    ax.set_title(f"class count plot for {split_name} samples")
plt.show()

# 4. Define the DataLoaders and Account for Class Imbalance with a Weighted Random Sampler

In [None]:
# Loader settings shared by both splits.
num_workers = 4
batch_size = 32

# Training: draw samples with replacement according to the dataset's
# per-sample weights (presumably inverse class frequency), so minority
# classes are oversampled. `shuffle` must stay False when a sampler is given.
_, train_sample_weights = training_dataset.get_sample_weights()
train_dataloader = DataLoader(
    training_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,
    sampler=WeightedRandomSampler(train_sample_weights, len(training_dataset), replacement=True),
)

# Evaluation: visit every held-out sample exactly once, in a fixed order.
# Resampling the eval set with replacement (as done for training) duplicates
# some samples, skips others and re-balances the class mix differently each
# epoch. The reported metrics would then neither reflect the real held-out
# distribution nor be comparable across epochs. Handle imbalance at eval time
# through the metric instead (e.g. macro-averaged F1).
eval_dataloader = DataLoader(
    eval_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,
)

# 5. Define the Relevant Hyper-parameters and Objects for Data Modeling

In [None]:
# --- Network (forwarded positionally to ClassifierNet) -----------------------
in_channels = 1
num_classes = len(training_dataset.get_label_names())
dropout = 0.3
network = "resnet18"
pretrained_weights = None  # alternative previously tried: "DEFAULT"
track_grads = True

# --- AdamW optimiser ---------------------------------------------------------
lr = 1e-4
weight_decay = 7e-6
betas = (0.9, 0.9999)

# --- CosineAnnealingWarmRestarts schedule -------------------------------------
min_lr = 1e-6  # used as eta_min
T_0 = 40
T_mult = 2

# --- Runtime and checkpoint location -----------------------------------------
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model_folder = "saved_model"
model_name = f"{signal}_{task}_model.pth.tar"

In [None]:
# Seed every RNG the notebook may touch and force deterministic cuDNN kernels.
# Python's `random` is seeded too, because any library code that draws from
# it would otherwise differ between runs.
# NOTE(review): the datasets (shuffle=True) and DataLoaders are constructed in
# earlier cells, before this seeding runs, so any randomness they consume at
# construction time is not covered. Consider running this cell first.
import random

SEED = 3407
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# torch.manual_seed already seeds CUDA; manual_seed_all makes the multi-GPU
# intent explicit (torch.cuda.manual_seed only covers the current device).
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Model and its optimiser.
classifier = ClassifierNet(in_channels, num_classes, dropout, network, pretrained_weights, track_grads)
optimizer = torch.optim.AdamW(
    classifier.parameters(),
    lr=lr,
    betas=betas,
    weight_decay=weight_decay,
)

# The loss follows the label encoding chosen for the datasets.
if onehot_labels:
    # NOTE(review): BCELoss expects probabilities in [0, 1]. If ClassifierNet
    # returns raw logits, BCEWithLogitsLoss is required; confirm in
    # architecture.net.
    lossfunc = nn.BCELoss()
else:
    lossfunc = nn.CrossEntropyLoss()

# Training pipeline wrapping model, loss, optimiser, device and checkpointing.
pipeline = TrainingPipeline(
    classifier,
    lossfunc,
    optimizer,
    device,
    weight_init=False,
    dirname=model_folder,
    filename=model_name,
    onehot_labels=onehot_labels,
)

# Cosine annealing with warm restarts: the first cycle lasts T_0 scheduler
# steps, each later cycle is T_mult times longer, and the lr never drops below
# min_lr. NOTE: recent PyTorch releases deprecate `verbose` on LR schedulers.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    pipeline.optimizer,
    T_0=T_0,
    T_mult=T_mult,
    eta_min=min_lr,
    verbose=True,
)

# 6. Model the Dataset

In [None]:
import warnings

epochs = 120

# Silence warnings only while training and evaluating. A bare
# `warnings.filterwarnings("ignore")` edits the global filter list, which in a
# notebook persists for the rest of the kernel session and hides every later
# warning. `catch_warnings` restores the previous filters on exit.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for epoch in range(epochs):
        print(f"epoch: {epoch+1} / {epochs}")
        pipeline.train(train_dataloader, verbose=True)
        pipeline.evaluate(eval_dataloader, verbose=True)
        # The scheduler is stepped once per epoch, so T_0 / T_mult are in epochs.
        lr_scheduler.step()
        print("-" * 130)

# Metric curves for both splits.
pipeline.plot_metrics("train")
pipeline.plot_metrics("eval")