# Deep Learning Development - Braindecode

## Import Library

In [1]:
import sys
import os
import importlib
import numpy as np
from pathlib import Path
import torch
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
from collections import Counter
from einops import rearrange

# Add the parent directory to the path so Python can find 'scripts'
sys.path.append(os.path.abspath('..'))

# Reload the modules to get the latest changes
import script.myModules.models.pytorch_lightning as pm  # noqa: E402
import script.myModules.utils.data_utils as du  # noqa: E402
import script.myModules.utils.utility as ut  # noqa: E402
importlib.reload(du)
importlib.reload(pm)

# ut.set_seed(42)
pl.seed_everything(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

Seed set to 42


## Utility Functions

In [2]:
# Check if CUDA is available
gpu_count = ut.cuda_device_count()
cuda_available = ut.cuda_is_available()
print(f'GPU count: {gpu_count}, CUDA available: {cuda_available}')

GPU count: 1, CUDA available: True


## Step 1: Load Truncated Data

### 1. Create Dataset and DataLoader

In [3]:
class fNIRSDataModule(pl.LightningDataModule):
    def __init__(self, data_path, task_type, batch_size=32, oversampling=False):
        super().__init__()
        self.data_path = data_path
        self.task_type = task_type
        self.batch_size = batch_size
        self.oversampling = oversampling

    def prepare_data(self):
        # Load fnirs data and create labels
        fNIRS_data_dict = du.load_concatenated_fNIRS_data(self.data_path, self.task_type)
        self.data, self.labels = du.create_data_and_labels(fNIRS_data_dict)

    def setup(self, stage=None):
        # Split the data into train and test sets
        data_HbO = self.data[:, :23]
        self.train_data, self.test_data, self.train_labels, self.test_labels = train_test_split(
            data_HbO, self.labels, test_size=0.2, stratify=self.labels, random_state=42
        )

        # Check the shape of the training data
        num_subjects, num_channels, num_timepoints = self.train_data.shape
        print(f"Original train data shape: {self.train_data.shape}")

        # Use einops to reshape the 3D data to 2D: (samples, channels * time_points)
        reshaped_train_data = rearrange(self.train_data, 's c t -> s (c t)')

        if self.oversampling:
            # Apply SMOTE to the reshaped data
            smote = SMOTE(sampling_strategy='auto', random_state=42)
            reshaped_train_data, self.train_labels = smote.fit_resample(reshaped_train_data, self.train_labels)

            # Use einops to reshape the data back to 3D: (samples, channels, time_points)
            self.train_data = rearrange(reshaped_train_data, 's (c t) -> s c t', c=num_channels, t=num_timepoints)

            # Print the new shape and class distribution after SMOTE
            print(f"Reshaped train data after SMOTE: {self.train_data.shape}")
            print(f"Class distribution after SMOTE: {Counter(self.train_labels)}")


    def train_dataloader(self):
        # Create dataset and dataloader for training data
        train_dataset = self.fNIRSDataset(self.train_data, self.train_labels)
        return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        # Create dataset and dataloader for test data
        test_dataset = self.fNIRSDataset(self.test_data, self.test_labels)
        return DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

    class fNIRSDataset(Dataset):
        def __init__(self, data, labels):
            self.data = torch.tensor(data, dtype=torch.float32)
            self.labels = torch.tensor(labels, dtype=torch.int64)

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            # Return a single sample (data, label)
            return self.data[idx], self.labels[idx]

In [6]:
# Example usage
df_10_1_Hz_truncated = Path('../../data/truncated/10_1_Hz')
task_type = 'GNG'

data_module = fNIRSDataModule(df_10_1_Hz_truncated, task_type, batch_size=8, oversampling=True)
data_module.prepare_data()
data_module.setup()

train_loader = data_module.train_dataloader()
test_loader = data_module.test_dataloader()

Original train data shape: (40, 23, 2424)
Reshaped train data after SMOTE: (52, 23, 2424)
Class distribution after SMOTE: Counter({0: 26, 1: 26})


### 2. Prepare Dataset for Braindecode

In [5]:
import mne
import pandas as pd

df_10_1_Hz_truncated = Path('../../data/truncated/10_1_Hz')
task_type = 'GNG'

fNIRS_data_dict = du.load_concatenated_fNIRS_data(df_10_1_Hz_truncated, task_type)
data, labels = du.create_data_and_labels(fNIRS_data_dict)

sample_idx = np.random.randint(len(data))
sample_data = data[sample_idx]
sample_label = labels[sample_idx]

# Check the shape
print(f"Sample data shape: {sample_data.shape}")  # Should print (channels, time

# Extract the HbO only Data which is the first index to 23rd index
sample_data_hbo = sample_data[:23]
print(f"Sample data HbO shape: {sample_data_hbo.shape}")

Sample data shape: (69, 2424)
Sample data HbO shape: (23, 2424)


In [12]:
# Assuming your data has 23 channels (HbO data) and 2424 time points
data = np.random.rand(23, 2424)  # Replace this with your actual data

# Define channel names for the 23 HbO channels (e.g., S1_D1, S2_D1, etc.)
ch_names = [
    "S1_D1 hbo", "S1_D2 hbo", "S2_D1 hbo", "S2_D2 hbo", "S3_D1 hbo", 
    "S3_D2 hbo", "S4_D1 hbo", "S4_D2 hbo", "S5_D1 hbo", "S5_D2 hbo",
    "S6_D1 hbo", "S6_D2 hbo", "S7_D1 hbo", "S7_D2 hbo", "S8_D1 hbo",
    "S8_D2 hbo", "S9_D1 hbo", "S9_D2 hbo", "S10_D1 hbo", "S10_D2 hbo",
    "S11_D1 hbo", "S11_D2 hbo", "S12_D1 hbo"
]

# Define the channel types (all are 'hbo' for HbO)
ch_types = ["hbo"] * 23

# Sampling frequency of the data (adjust based on your data)
sfreq = 10.1  # Adjust as needed

# Create MNE Info object
info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)

# Create RawArray object for your data (optional, just to check the data structure)
raw = mne.io.RawArray(data, info)

# Extract the `info["chs"]` to pass as `chs_info`
chs_info = info["chs"]

print(chs_info)

Creating RawArray with float64 data, n_channels=23, n_times=2424
    Range : 0 ... 2423 =      0.000 ...   239.901 secs
Ready.
[{'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 1100 (FIFFV_FNIRS_CH), 'coil_type': 300 (FIFFV_COIL_FNIRS_HBO), 'unit': 6 (FIFF_UNIT_MOL), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'S1_D1 hbo', 'scanno': 1, 'logno': 1}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 1100 (FIFFV_FNIRS_CH), 'coil_type': 300 (FIFFV_COIL_FNIRS_HBO), 'unit': 6 (FIFF_UNIT_MOL), 'coord_frame': 4 (FIFFV_COORD_HEAD), 'ch_name': 'S1_D2 hbo', 'scanno': 2, 'logno': 2}, {'loc': array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]), 'unit_mul': 0 (FIFF_UNITM_NONE), 'range': 1.0, 'cal': 1.0, 'kind': 1100 (FIFFV_FNIRS_CH), 'coil_type': 300 (FIFFV_COIL_FNIRS_HBO), 'unit': 6 (FIFF_UNIT

## Step 2: Create LightningModule

In [20]:
from torchmetrics import MeanMetric
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score, BinaryAUROC, BinaryConfusionMatrix
import seaborn as sns

class fNIRSModule(pl.LightningModule):

    def __init__(self, model_name, optimizer_name, optimizer_hparams):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
            optimizer_name - Name of the optimizer to use. Currently supported: Adam, SGD
            optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc.
        """

        super().__init__()
        self.save_hyperparameters()
        self.model = create_model(model_name)  # type: ignore # noqa: F821
        self.loss_module = nn.CrossEntropyLoss()
        self.example_input_array = torch.zeros(1, 23, 2424, dtype=torch.float32)

        # Metrics for training
        self.mean_train_loss = MeanMetric()
        self.mean_train_acc = BinaryAccuracy()

        # Metrics for validation
        self.mean_valid_loss = MeanMetric()
        self.mean_valid_acc = BinaryAccuracy()

        # Metrics for testing
        self.test_preds = []
        self.test_targets = []
        self.mean_test_acc = BinaryAccuracy()
        self.test_f1 = BinaryF1Score()
        self.test_auroc = BinaryAUROC()
        self.test_conf_matrix = BinaryConfusionMatrix()

    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        if self.hparams.optimizer_name == 'Adam':
            optimizer = torch.optim.AdamW(self.model.parameters(), **self.hparams.optimizer_hparams)
        elif self.hparams.optimizer_name == 'SGD':
            optimizer = torch.optim.SGD(self.model.parameters(), **self.hparams.optimizer_hparams)
        else:
            assert False, f"Unknown optimizer: {self.hparams.optimizer_name}"

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], gamma=0.1
        )
        return [optimizer], [scheduler]

    
    def training_step(self, batch, *args, **kwargs):
        data, target = batch
        output = self.model(data)
        loss = self.loss_module(output, target)
        pred_batch = output.detach().argmax(dim=1)

        self.mean_train_loss(loss, weight=data.shape[0])
        self.mean_train_acc(pred_batch, target)

        self.log("train/batch_loss", self.mean_train_loss, prog_bar=True, logger=True, on_epoch=True)
        self.log("train/batch_acc", self.mean_train_acc, prog_bar=True, logger=True, on_epoch=True)
        
        return loss

    def validation_step(self, batch, *args, **kwargs):
        data, target = batch
        output = self.model(data)
        loss = self.loss_module(output, target)
        pred_batch = output.argmax(dim=1)

        # Update validation metrics
        self.mean_valid_loss(loss, weight=data.shape[0])    
        self.mean_valid_acc(pred_batch, target)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self.model(data)
        pred_batch = output.argmax(dim=1)

        # Update test metrics (accuracy, F1 score, AUROC, and confusion matrix)
        self.mean_test_acc(pred_batch, target)
        
        # Store predictions and labels for confusion matrix
        self.test_preds.append(pred_batch)
        self.test_targets.append(target)

    def on_validation_epoch_end(self):
        """Calculate epoch level metrics for the validation set"""

        self.log("valid/loss", self.mean_valid_loss, prog_bar=True, logger=True, on_epoch=True)
        self.log("valid/acc", self.mean_valid_acc, prog_bar=True, logger=True, on_epoch=True)


    def on_test_epoch_end(self):
        """Calculate final metrics for the test set after all batches have been processed."""
        
        # Concatenate all predictions and targets across batches
        final_preds = torch.cat(self.test_preds)
        final_targets = torch.cat(self.test_targets)

        # Compute final test metrics
        test_acc = self.mean_test_acc.compute()
        test_f1 = self.test_f1(final_preds, final_targets)
        test_auroc = self.test_auroc(final_preds, final_targets)
        test_conf_matrix = self.test_conf_matrix(final_preds, final_targets)

        # Move confusion matrix to CPU and convert to NumPy
        test_conf_matrix_cpu = test_conf_matrix.cpu().numpy()

        # Plot confusion matrix
        self.plot_confusion_matrix(test_conf_matrix_cpu)

        # Log the final metrics
        self.log("test/acc", test_acc)
        self.log("test/f1", test_f1)
        self.log("test/auroc", test_auroc)

        # Optionally print the confusion matrix
        print("Test Confusion Matrix:\n", test_conf_matrix)

        # Reset metrics for the next test run (if applicable)
        self.test_f1.reset()
        self.test_auroc.reset()
        self.test_conf_matrix.reset()

        # Clear stored predictions and targets
        self.test_preds.clear()
        self.test_targets.clear()

    def plot_confusion_matrix(self, cm):
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=ax, 
                    xticklabels=['Predicted Healthy', 'Predicted Non-Healthy'], 
                    yticklabels=['Healthy', 'Non-Healthy'])
        ax.set_xlabel("Predicted labels")
        ax.set_ylabel("True labels")
        ax.set_title("Confusion Matrix")

        # Log confusion matrix to TensorBoard
        self.logger.experiment.add_figure("Confusion Matrix", fig, self.current_epoch)
        plt.close(fig)

## Step 3. Model Training

In [8]:
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, RichModelSummary
CHECKPOINT_PATH = Path('./saved_models/tensorboard')

In [24]:
def train_model(model_name, save_name=None, **kwargs):
    """
    Inputs:
        model_name - Name of the model you want to run. Is used to look up the class in "model_dict"
        save_name (optional) - If specified, this name will be used for creating the checkpoint and logging directory.
    """
    if save_name is None:
        save_name = model_name

    # Crreate a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),                            # Where to save models
                         accelerator='gpu' if str(device).startswith("cuda") else "cpu",                       # We run on a GPU (if possible)
                         devices=1,                                                                            # How many GPUs/CPUs we want to use (1 is enough for the notebooks)
                         max_epochs=100,                                                                      # How many epochs to train for if no patience is set
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="valid/acc"),    # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
                                    LearningRateMonitor(logging_interval='epoch'),
                                    RichModelSummary(max_depth=3)],                            # Log learning rate every epoch
                        enable_progress_bar=True,
                        log_every_n_steps=1)                                                               
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    pl.seed_everything(42)
    model = fNIRSModule(model_name=model_name, **kwargs)
    trainer.fit(model, train_loader, test_loader)
    model = fNIRSModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation set and test set (if available)
    val_result = trainer.test(model, test_loader, verbose=False)

    return model, val_result

In [25]:
model_dict = {}

def create_model(model_name):
    if model_name in model_dict:
        return model_dict[model_name]
    else:
        assert False, f"Unknown model name \"{model_name}\". Available models are: {str(model_dict.keys())}"


#### 1. TCN (Temporal Convolutional Network) - Braindecode 

In [27]:
from braindecode.models.util import models_dict

print(f'All the Braindecode models:\n{list(models_dict.keys())}')

from braindecode.models import TCN

model = TCN(
    n_chans=23, 
    n_outputs=2,
    n_times=2424,
    n_blocks=4, 
    n_filters=64,
    kernel_size=5,
    drop_prob=0.5,
    chs_info=chs_info,
    sfreq=10.1,
)
print(model)


All the Braindecode models:
['ATCNet', 'Deep4Net', 'DeepSleepNet', 'EEGConformer', 'EEGITNet', 'EEGInception', 'EEGInceptionERP', 'EEGInceptionMI', 'EEGNetv1', 'EEGNetv4', 'EEGResNet', 'HybridNet', 'ShallowFBCSPNet', 'SleepStagerBlanco2020', 'SleepStagerChambon2018', 'SleepStagerEldele2021', 'TCN', 'TIDNet', 'USleep']
Layer (type (var_name):depth-idx)                  Input Shape               Output Shape              Param #                   Kernel Shape
TCN (TCN)                                          [1, 23, 2424]             [1, 2, 2304]              --                        --
├─Ensure4d (ensuredims): 1-1                       [1, 23, 2424]             [1, 23, 2424, 1]          --                        --
├─Sequential (temporal_blocks): 1-2                [1, 23, 2424]             [1, 64, 2424]             --                        --
│    └─TemporalBlock (temporal_block_0): 2-1       [1, 23, 2424]             [1, 64, 2424]             --                        --
│    │    

In [29]:
model_dict["TCN"] = model

tcn_model, tcn_results = train_model(model_name="TCN",
                                     optimizer_name="Adam",
                                     optimizer_hparams={"lr": 1e-3, 
                                                        "weight_decay": 1e-4})

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


  assert time_size >= self.min_len
  out_size = 1 + max(0, time_size - min_len)
  assert x.size()[3] == 1
  if x.size()[2] == 1:


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

C:\Users\LENOVO X1E\AppData\Roaming\Python\Python310\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


RuntimeError: Expected target size [8, 2304], got [8]