# **NIRS (Near-Infrared Spectroscopy) Data Modeling**

In [None]:
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, WeightedRandomSampler
from data_preparation.dataset import SignalDataset
from architecture.net import ClassifierNet
from pipeline.train import TrainingPipeline
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score

# 1. Load Data

In [None]:
# NOTE(review): the "eeg" in these variable names is a misnomer — the data loaded
# here is NIRS (signal="NIRS"). The names are kept because later cells use them.

# Settings shared by both splits.
t_size = 100  # forwarded as SignalDataset(t_size=...); presumably the segment length in timesteps — TODO confirm
train_size = 0.8  # forwarded as SignalDataset(sample_size=...) for the training split only

# Training split: oxy-hemoglobin NIRS segments from the n-back task, presumably a
# `train_size` fraction of the available segments.
training_eeg_dataset = SignalDataset(
    "data/samples",
    task="nback",
    signal="NIRS",
    hemoglobin="oxy",
    t_size=t_size,
    sample_size=train_size,
)

# Evaluation split: same signal/task/hemoglobin settings, excluding the segment
# files already taken by the training split.
eval_eeg_dataset = SignalDataset(
    "data/samples",
    task="nback",
    signal="NIRS",
    hemoglobin="oxy",
    t_size=t_size,
    excluded=training_eeg_dataset.segment_files,
)

# Quick summary of split sizes and the label vocabulary.
print(f"Number of train samples: {len(training_eeg_dataset)}")
print(f"Number of eval samples: {len(eval_eeg_dataset)}")
print(f"Number of classes: {len(training_eeg_dataset.get_label_names())}")
print(f"Class names: {training_eeg_dataset.get_label_names()}")

# 2. Check Class Imbalance

In [None]:
# Side-by-side bar charts of label counts, to compare class balance between the
# training and evaluation splits.
# NOTE(review): assumes get_sample_classes() returns tabular data with a
# "class_label" column (as required by the countplot call) — TODO confirm.
train_classes = training_eeg_dataset.get_sample_classes()
eval_classes = eval_eeg_dataset.get_sample_classes()

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
panels = zip(axs, (train_classes, eval_classes), ("training", "evaluation"))
for panel_ax, split_classes, split_name in panels:
    sns.countplot(split_classes, x="class_label", ax=panel_ax)
    panel_ax.set_title(f"class count plot for {split_name} samples")
plt.show()

# 3. Define DataLoaders and Account for Class Imbalance with a Weighted Random Sampler

In [None]:
num_workers = 4
batch_size = 16

# Training loader: draw len(dataset) samples with replacement, weighted per
# sample, so mini-batches come from a re-balanced class distribution.
# NOTE(review): assumes get_sample_weights() returns (<unused>, per-sample
# weights) — TODO confirm against SignalDataset.
_, train_sample_weights = training_eeg_dataset.get_sample_weights()
train_sampler = WeightedRandomSampler(
    train_sample_weights,
    num_samples=len(training_eeg_dataset),
    replacement=True,
)
train_eeg_dataloader = DataLoader(
    training_eeg_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,  # must stay False: DataLoader rejects shuffle=True together with a sampler
    sampler=train_sampler,
)

# Evaluation loader: visit every held-out sample exactly once, in a fixed order.
# Weighted sampling with replacement here would duplicate some samples, skip
# others, and change the evaluated set every epoch. The eval metrics would then
# describe neither the real held-out split nor be comparable across epochs.
# To account for imbalance at evaluation time, use class-balanced metrics
# (e.g. macro-averaged F1/recall) rather than resampling.
eval_eeg_dataloader = DataLoader(
    eval_eeg_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,
)

# 4. Visualise a Sample

In [None]:
# Plot every channel of a single training sample on a shared time axis.
sample_signal, label = training_eeg_dataset[12]
# NOTE(review): assumes the squeezed signal is (channels, timesteps), which is
# what the per-row plotting and the axis labels below imply — TODO confirm.
channels = sample_signal.squeeze()

plt.figure(figsize=(20, 5))
for channel_trace in channels:
    plt.plot(channel_trace)
plt.title(f"NIRS signals (shape: {channels.shape})")
plt.xlabel("Timesteps")
plt.ylabel("Channel readings")
plt.show()

# 5. Define the Relevant Hyper-parameters and Objects for Data Modeling

In [None]:
# --- Model (positional arguments of ClassifierNet in the next cell) ---
network = "resnet18"
in_channels = 1  # presumably the signal is fed as a single-channel 2-D input — TODO confirm
num_classes = len(training_eeg_dataset.get_label_names())
dropout = 0.0
pretrained_weights = None  # alternative value previously tried: "DEFAULT"
track_grads = True

# --- Adam optimizer ---
lr = 1e-3
weight_decay = 0.0
betas = (0.9, 0.999)

# --- CosineAnnealingWarmRestarts schedule ---
min_lr = 1e-5  # used as eta_min
T_0 = 10  # epochs in the first cycle (the scheduler is stepped once per epoch)
T_mult = 2  # each subsequent cycle is T_mult times longer

# --- Hardware ---
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# NOTE(review): the variable name says "eeg", but the model is trained on NIRS data.
eeg_classifier = ClassifierNet(in_channels, num_classes, dropout, network, pretrained_weights, track_grads)
optimizer = torch.optim.Adam(eeg_classifier.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)
# Unweighted cross-entropy: class imbalance is addressed by the training
# loader's WeightedRandomSampler rather than by loss weights.
lossfunc = nn.CrossEntropyLoss()

# define pipeline
pipeline = TrainingPipeline(eeg_classifier, lossfunc, optimizer, device, weight_init=False)

# Cosine-annealed learning rate with warm restarts. It anneals from the
# optimizer's initial LR down to eta_min=min_lr over cycles of
# T_0, T_0*T_mult, T_0*T_mult**2, ... scheduler steps.
# `verbose` is deprecated since PyTorch 2.2 (it emits a UserWarning), so the
# learning rate is reported explicitly through get_last_lr() instead.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    pipeline.optimizer, T_0=T_0, T_mult=T_mult, eta_min=min_lr
)
print(f"initial lr: {lr_scheduler.get_last_lr()}")

# 6. Train and Evaluate the Model

In [None]:
# With the current T_0=10 / T_mult=2, 70 epochs ends exactly at the end of the
# third cosine cycle (10 + 20 + 40); revisit if the schedule changes.
epochs = 70

import warnings

# Silence warnings only while training runs. A bare
# warnings.filterwarnings("ignore") would install a process-wide filter that
# keeps hiding warnings in every cell executed afterwards.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for epoch in range(epochs):
        print(f"epoch: {epoch+1} / {epochs}")
        # Report the LR used for this epoch (step() is called once per epoch below).
        print(f"lr: {lr_scheduler.get_last_lr()}")
        pipeline.train(train_eeg_dataloader, verbose=True)
        pipeline.evaluate(eval_eeg_dataloader, verbose=True)
        lr_scheduler.step()
        print("-"*130)