# **NIRS (Spectroscopy) Data Modeling**

In [None]:
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, WeightedRandomSampler
from data_preparation.dataset import SignalDataset
from architecture.net import ClassifierNet
from pipeline.train import TrainingPipeline
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score

# 1. Load Data

In [None]:
# Experiment configuration shared by the dataset, loss and pipeline cells.
task = "nback"
t_size = 300
train_size = 0.9
data_dir = "data/samples"
# Must be a plain bool: a trailing comma (`onehot_labels=True,`) would bind the
# tuple `(True,)`, which is always truthy -- even `(False,)` -- and is forwarded
# to SignalDataset, TrainingPipeline and the loss-function selection.
onehot_labels = True

# Training split. `sample_size=train_size` presumably selects that fraction of
# the segment files -- TODO confirm in SignalDataset.
training_nirs_dataset = SignalDataset(
    data_dir,
    signal="NIRS",
    task=task,
    sample_size=train_size,
    t_size=t_size,
    hemoglobin="oxy",
    use_spectrogram=False,
    onehot_labels=onehot_labels,
)

# Evaluation split: built with the training split's segment files excluded,
# presumably so the two splits do not overlap -- verify in SignalDataset.
eval_nirs_dataset = SignalDataset(
    data_dir,
    signal="NIRS",
    task=task,
    excluded=training_nirs_dataset.segment_files,
    t_size=t_size,
    hemoglobin="oxy",
    use_spectrogram=False,
    onehot_labels=onehot_labels,
)

label_names = training_nirs_dataset.get_label_names()
print(f"Number of train samples: {len(training_nirs_dataset)}")
print(f"Number of eval samples: {len(eval_nirs_dataset)}")
print(f"Number of classes: {len(label_names)}")
print(f"Class names: {label_names}")

# 2. Visualise a Sample

In [None]:
# Inspect the first training sample: print its shape, then plot it either as a
# spectrogram image or as one line per channel, depending on dataset settings.
sample_signal, label = training_nirs_dataset[0]
print(sample_signal.shape)

if training_nirs_dataset.use_spectrogram:
    if training_nirs_dataset.avg_spectrogram_ch:
        # Channel-averaged spectrogram; assumed to be 2-D after squeeze() so
        # imshow can render it -- TODO confirm against SignalDataset.
        print(f"spectrogram kwargs: {training_nirs_dataset.spectrogram_kwargs}")
        plt.figure(figsize=(20, 5))
        plt.imshow(sample_signal.squeeze())
        plt.title("Averaged Channel NIRS Spectrogram")
        plt.xlabel("Timesteps")
        plt.ylabel("Channel readings")
        plt.show()
    else:
        # Per-channel spectrograms are not plotted here; report it instead of
        # silently leaving an empty figure.
        print("Per-channel spectrogram visualisation is not implemented.")
else:
    # Raw time series: iterate over the first (channel) axis after squeeze().
    plt.figure(figsize=(20, 5))
    for ch in sample_signal.squeeze():
        plt.plot(ch)
    plt.title("All channels NIRS Time signals")
    plt.xlabel("Timesteps")
    plt.ylabel("Channel readings")
    plt.show()

# 3. Check Class Imbalance

In [None]:
# Plot the class distribution of each split side by side to judge imbalance.
# get_sample_classes() is presumably a DataFrame with a "class_label" column
# (it is plotted with x="class_label") -- confirm in SignalDataset.
train_classes = training_nirs_dataset.get_sample_classes()
eval_classes = eval_nirs_dataset.get_sample_classes()

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for ax, classes, split in zip(
    axs,
    (train_classes, eval_classes),
    ("training", "evaluation"),
):
    # Pass the table via `data=`: passing it positionally is only valid in
    # seaborn >= 0.12.
    sns.countplot(data=classes, x="class_label", ax=ax)
    ax.set_title(f"class count plot for {split} samples")
plt.show()

# 4. Define DataLoaders and Account for Class Imbalance with a Weighted Random Sampler

In [None]:
num_workers = 4
batch_size = 16

# Training loader: draw len(dataset) samples per epoch, with replacement,
# according to per-sample weights from get_sample_weights() (second element of
# the returned pair). These are presumably inverse class-frequency weights
# that counter class imbalance -- confirm in SignalDataset.
_, train_sample_weights = training_nirs_dataset.get_sample_weights()
train_nirs_dataloader = DataLoader(
    training_nirs_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,  # DataLoader forbids shuffle=True together with a sampler
    sampler=WeightedRandomSampler(
        train_sample_weights, len(training_nirs_dataset), replacement=True
    ),
)

# Evaluation loader: visit every sample exactly once, in a fixed order.
# Weighted sampling with replacement here would duplicate some eval samples
# and drop others, making eval metrics noisy between epochs and unrepresentative
# of the held-out data.
eval_nirs_dataloader = DataLoader(
    eval_nirs_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,
)

# 5. Define the Relevant Hyper-parameters and Objects for Data Modeling

In [None]:
# Model architecture.
in_channels = 1
num_classes = len(training_nirs_dataset.get_label_names())
dropout = 0.3
network = "resnet18"
pretrained_weights = None  # alternative: "DEFAULT" (presumably torchvision weights)
track_grads = True

# Adam optimiser.
lr = 1e-4
weight_decay = 7e-6
betas = (0.9, 0.9999)
device = "cuda" if torch.cuda.is_available() else "cpu"

# CosineAnnealingWarmRestarts schedule: first cycle T_0 epochs, each later
# cycle T_mult times longer, LR annealed down to min_lr.
T_0 = 40
T_mult = 2
min_lr = 1e-6

# Checkpoint location. Prefixed "nirs_" (previously "eeg_") so this NIRS model
# cannot overwrite or be confused with an EEG checkpoint in the same folder.
model_folder = "saved_model"
model_name = f"nirs_{task}_model.pth.tar"

In [None]:
# Seed every RNG this notebook relies on, so that model initialisation and
# data sampling are repeatable between runs.
SEED = 3407

np.random.seed(SEED)
torch.manual_seed(SEED)
# Per the PyTorch docs, torch.manual_seed already seeds all devices; the
# explicit CUDA call is kept for clarity.
torch.cuda.manual_seed(SEED)

# Disable cuDNN autotuning and request deterministic kernels, trading some
# speed for run-to-run reproducibility.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Build the classifier and its Adam optimiser from the hyper-parameters above.
nirs_classifier = ClassifierNet(
    in_channels, num_classes, dropout, network, pretrained_weights, track_grads
)
optimizer = torch.optim.Adam(
    nirs_classifier.parameters(), lr=lr, weight_decay=weight_decay, betas=betas
)

# One-hot targets -> BCELoss; otherwise (class-index targets) -> CrossEntropyLoss.
# NOTE(review): nn.BCELoss expects inputs already in [0, 1]. If ClassifierNet
# returns raw logits, nn.BCEWithLogitsLoss is the numerically stable choice --
# confirm against ClassifierNet's output layer.
lossfunc = nn.BCELoss() if onehot_labels else nn.CrossEntropyLoss()

# Training/evaluation driver; checkpoints go to model_folder/model_name.
pipeline = TrainingPipeline(
    nirs_classifier,
    lossfunc,
    optimizer,
    device,
    weight_init=False,
    dirname=model_folder,
    filename=model_name,
    onehot_labels=onehot_labels,
)

# Cosine-annealing LR schedule with warm restarts, stepped once per epoch in
# the training loop.
# NOTE(review): recent PyTorch releases deprecate `verbose`; consider printing
# lr_scheduler.get_last_lr() instead -- check the installed version.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    pipeline.optimizer, T_0=T_0, T_mult=T_mult, eta_min=min_lr, verbose=True
)

# 6. Model the Dataset

In [None]:
import warnings

epochs = 120

# Silence warnings only for the duration of training. catch_warnings()
# restores the previous filters on exit, unlike a bare
# warnings.filterwarnings("ignore"), which would hide every warning for the
# rest of the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for epoch in range(epochs):
        print(f"epoch: {epoch+1} / {epochs}")
        pipeline.train(train_nirs_dataloader, verbose=True)
        pipeline.evaluate(eval_nirs_dataloader, verbose=True)
        # Per-epoch step: T_0 / T_mult above are therefore in epochs.
        lr_scheduler.step()
        print("-" * 130)

pipeline.plot_metrics("train")
pipeline.plot_metrics("eval")