In [None]:
!mkdir models
!mkdir plots
%reload_ext autoreload
%autoreload 2

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchmetrics.classification import MulticlassAccuracy, MulticlassMatthewsCorrCoef
from torch.nn.utils.rnn import pad_sequence

from classifier import NappaSleepNet

from utils.dataset_classes import NappaDataset
from utils.dataset_preprocess import HybridScaler

from utils.plots import plot_learning_curves

In [None]:
# Console output: four decimals and no scientific notation, for numpy and torch alike.
np.set_printoptions(suppress=True, precision=4)
torch.set_printoptions(precision=4, threshold=1000, sci_mode=False)

# Plot styling: base font size first, then the seaborn theme on top of it,
# then switch the axes style to "ticks".
# NOTE(review): the theme's context and the later set_style call may override
# font.size / font.family set here — confirm the explicit font settings take effect.
plt.rcParams['font.size'] = 12
sns.set_theme(
    context='notebook',
    style='white',
    palette='deep',
    font='arial',
    font_scale=1,
    color_codes=True,
    rc=None,
)
sns.set_style('ticks')

In [4]:
def select_device(eval=False):
  """
  Picks the torch device to run on.

  Args:
    eval: If True, force the CPU even when CUDA is available.

  Returns:
    torch.device: CUDA when it is available and `eval` is False, otherwise the CPU.
  """
  use_cuda = torch.cuda.is_available() and not eval
  return torch.device('cuda' if use_cuda else 'cpu')

In [5]:
# ---- Hyperparameters ---------------------------------------------------------
learning_rate = 1e-3
batch_size = 4          # recordings per mini-batch
num_epochs = 100
PADDING_VALUE = -1      # fill value for padded time steps; also the metrics' ignore_index

# ---- Device, seeding, scaling ------------------------------------------------
device = select_device()
# Seed before the model is built so its initialisation is reproducible
# (assuming NappaSleepNet draws from torch's global RNG).
torch.manual_seed(0)
scaler = HybridScaler(method='global')

# ---- Labels: five sleep stages collapsed into three classes ------------------
mapping = {
    'N3': 0,
    'N2': 0,
    'N1': 1,
    'REM': 1,
    'Wake': 2,
}
sleep_classes = ['N2/N3', 'N1/REM', 'Wake']  # display name per class id

nappa_dataset = (
    NappaDataset('nappa_dataset.pkl')
    .labelsToNumeric(mapping)
    .sortById()
)

NUM_FEATURES = nappa_dataset.features.shape[1]
NUM_CLASSES = len(sleep_classes)

# ---- Model and evaluation metrics --------------------------------------------
model = NappaSleepNet(n_features=NUM_FEATURES, n_classes=NUM_CLASSES).to(device)

multiclass_matthewscorrcoef = MulticlassMatthewsCorrCoef(
    num_classes=NUM_CLASSES,
    ignore_index=PADDING_VALUE,
).to(device)
multiclass_accuracy = MulticlassAccuracy(
    num_classes=NUM_CLASSES,
    average='micro',
    ignore_index=PADDING_VALUE,
).to(device)

In [6]:
def compute_metrics(model, criterion, dataloader):
    """
    Computes the loss, accuracy, and Matthews correlation coefficient (MCC) for the given model and data.

    The model is run in eval mode without gradient tracking over every batch of
    `dataloader`. Accuracy and MCC are accumulated in the module-level torchmetrics
    objects (`multiclass_accuracy`, `multiclass_matthewscorrcoef`). These are reset
    first, so the returned values cover the whole dataset rather than one batch.

    The loss is averaged per labelled time step over the whole dataset: each batch's
    mean loss is weighted by its number of non-ignored labels. A plain mean of batch
    losses would over-weight batches with few labels, because the criterion already
    averages over the non-ignored targets of each batch.

    Args:
        model: The neural network model to evaluate. It is left in eval mode on return.
        criterion: Loss called as ``criterion(logits, targets)``. Assumed to use
            ``reduction='mean'`` without class weights, so that ``loss * n_valid``
            recovers the batch's summed loss. Its ``ignore_index`` (falling back to
            PADDING_VALUE) decides which labels count.
        dataloader: Iterable of ``(features, labels, rec_lengths)`` batches, such as
            those produced by ``collate``.

    Returns:
        tuple: ``(average_loss, accuracy, mcc)``. ``average_loss`` is a Python float
        (``nan`` if the loader contains no labelled time step). ``accuracy`` and
        ``mcc`` are CPU tensors.
    """
    ignore_index = getattr(criterion, 'ignore_index', PADDING_VALUE)
    total_loss = 0.0
    n_labels = 0
    multiclass_accuracy.reset()
    multiclass_matthewscorrcoef.reset()

    model.eval()
    with torch.no_grad():
        for features, labels, rec_lengths in dataloader:
            features = features.to(device=device, dtype=torch.float)
            labels = labels.to(device=device, dtype=torch.long)

            output = model(features, rec_lengths)
            flat_labels = labels.flatten()

            n_valid = int((flat_labels != ignore_index).sum().item())
            if n_valid > 0:
                # The criterion returns the mean over valid targets; undo the mean
                # so each batch contributes in proportion to its number of labels.
                loss = criterion(output.reshape(-1, NUM_CLASSES), flat_labels)
                total_loss += loss.item() * n_valid
                n_labels += n_valid

            predicted_classes = torch.argmax(output, dim=-1)
            # update() only accumulates state. Calling the metric object would also
            # compute a per-batch value, which is not needed here.
            multiclass_matthewscorrcoef.update(predicted_classes, labels)
            multiclass_accuracy.update(predicted_classes, labels)

    # Aggregate the results from all batches to get metrics on the whole input dataset
    accuracy = multiclass_accuracy.compute()
    mcc = multiclass_matthewscorrcoef.compute()

    average_loss = total_loss / n_labels if n_labels else float('nan')
    return average_loss, accuracy.cpu(), mcc.cpu()

In [7]:
def collate(batch):
  """
  Collates a list of SleepRecording instances into padded batch tensors.

  Recordings are ordered from longest to shortest (ties keep their input order).
  Their features and labels are then padded with PADDING_VALUE up to the length
  of the longest recording.

  Args:
    batch (list): SleepRecording instances, each exposing `features` and `labels`.

  Returns:
    tuple: (features_padded, labels_padded, lengths), where
      - features_padded has shape (batch_size, max_rec_length, n_features),
      - labels_padded has shape (batch_size, max_rec_length),
      - lengths lists the original recording lengths in the same (descending)
        order as the rows of the padded tensors.
  """
  # sorted() is stable, so recordings of equal length keep their relative order.
  # NOTE(review): descending order is presumably what the model's packed-sequence
  # handling expects — confirm in NappaSleepNet.
  ordered = sorted(batch, key=lambda rec: len(rec.features), reverse=True)

  padded_features = pad_sequence(
      [torch.tensor(rec.features) for rec in ordered],
      batch_first=True,
      padding_value=PADDING_VALUE,
  )
  padded_labels = pad_sequence(
      [torch.tensor(rec.labels) for rec in ordered],
      batch_first=True,
      padding_value=PADDING_VALUE,
  )

  lengths = [len(rec.features) for rec in ordered]
  return padded_features, padded_labels, lengths


In [8]:
def _train_one_epoch(model, loader, criterion, optimizer):
  """Runs one optimisation pass over `loader` with the model in train mode."""
  model.train()
  for features, labels, rec_lengths in loader:
    features = features.to(device=device, dtype=torch.float)
    labels = labels.to(device=device, dtype=torch.long)

    optimizer.zero_grad()
    output = model(features, rec_lengths).reshape(-1, NUM_CLASSES)
    loss = criterion(output, labels.flatten())
    loss.backward()
    optimizer.step()


def _print_epoch_report(epoch, train_metrics, test_metrics):
  """Prints MCC, accuracy and loss for the test and training sets side by side."""
  print(f'Epoch {(epoch + 1):<5} {"Test set":10} {"Training set":5}')
  print(f'{"MCC:":<12} {test_metrics[2]:<7.2f} {train_metrics[2]:8.2f}')
  print(f'{"Accuracy:":<12} {test_metrics[1]:<7.1%} {train_metrics[1]:9.1%}')
  print(f'{"Loss:":<12} {test_metrics[0]:<7.2f} {train_metrics[0]:8.2f}')
  print()


def _show_and_close(fig):
  """Displays the current pyplot figures, then releases `fig`."""
  # pyplot.show() takes no figure argument; passing one lands in the backend's
  # `close`/`block` parameter instead of selecting the figure.
  plt.show()
  plt.close(fig)


def LOSOCV(model, dataset, lr, batch_size, num_epochs, verbose=False, plot_curves=False,
           log_every=20):
  """
  Performs Leave-One-Subject-Out Cross-Validation (LOSOCV) on a given dataset using given model.

  Each subject in `dataset` is held out once. For each fold:
    - the remaining subjects form the training set;
    - features are normalised with the module-level `scaler`, and the held-out
      subject is scaled with the training set's global mean/std;
    - the model is reset with `model.reset()` and trained with Adam;
    - the final weights are saved to `models/model_subject_<id>.pth`.
  Loss, accuracy and MCC on both sets are recorded after every epoch.

  Args:
      model: The neural network model to be trained and evaluated.
      dataset: The dataset containing all subjects' data; each subject is used once as a test set.
      lr: Learning rate for the optimizer.
      batch_size: Number of sleep recordings per batch.
      num_epochs: Number of epochs to train each model.
      verbose: If True, prints detailed logs for each fold.
      plot_curves: If True, shows learning curves after each fold. It also shows a
          final summary plot, saved to 'plots/learning curves.png'.
      log_every: With `verbose`, print metrics every `log_every` epochs (must be
          positive), plus the first and last epoch.
  """
  n_folds = len(dataset)
  # Axis 0 indexes the metric: 0 = loss, 1 = accuracy, 2 = MCC (compute_metrics' return order).
  all_train_metrics = np.zeros((3, n_folds, num_epochs))
  all_test_metrics = np.zeros((3, n_folds, num_epochs))

  # Stateless, so one instance serves every fold.
  criterion = nn.CrossEntropyLoss(ignore_index=PADDING_VALUE)

  for fold, test_subject in enumerate(dataset):
    # Exclude the current test subject to create the training dataset
    train_dataset = NappaDataset([subject for subject in dataset if subject.id != test_subject.id])

    # Statistics come from the unscaled training data, so the held-out subject
    # never influences its own normalisation.
    train_global_mean = train_dataset.features.mean(axis=0)
    train_global_std = train_dataset.features.std(axis=0)

    train_dataset = scaler(train_dataset)
    test_dataset = scaler(NappaDataset([test_subject]), is_testset=True,
                          trainset_mean=train_global_mean,
                          trainset_std=train_global_std)

    # Shuffle the training recordings every epoch; evaluate the test subject one recording at a time.
    train_loader = DataLoader(train_dataset, collate_fn=collate, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, collate_fn=collate, batch_size=1)

    model.reset()  # fresh parameters for every fold
    optimizer = optim.Adam(model.parameters(), lr=lr)

    if verbose:
      print(f'----------------------fold number {fold+1}----------------------')
      print(f'Test subject id: {test_subject.id}, age: {test_subject.age} months')
      print(f'Train set ids: {train_dataset.ids}')
      print(f'Test set size: {test_dataset.labels.shape[0]}')
      print(f'Train set size: {train_dataset.labels.shape[0]}\n')

    for epoch in range(num_epochs):
      _train_one_epoch(model, train_loader, criterion, optimizer)

      train_metrics = compute_metrics(model, criterion, train_loader)
      test_metrics = compute_metrics(model, criterion, test_loader)
      for i, (train_metric, test_metric) in enumerate(zip(train_metrics, test_metrics)):
        all_train_metrics[i, fold, epoch] = train_metric
        all_test_metrics[i, fold, epoch] = test_metric

      is_first_or_last = epoch == 0 or epoch + 1 == num_epochs
      if verbose and ((epoch + 1) % log_every == 0 or is_first_or_last):
        _print_epoch_report(epoch, train_metrics, test_metrics)

    if plot_curves:
      fig = plot_learning_curves(all_train_metrics, all_test_metrics, num_epochs, fold, final_plot=False)
      _show_and_close(fig)

    # Save the model weights for each test subject
    model_path = f'models/model_subject_{test_subject.id}.pth'
    torch.save(model.state_dict(), model_path)
    if verbose:
      print(f'Model saved to: {model_path} \n')

  # With an empty dataset there is no fold to summarise (the loop variable would be unbound).
  if plot_curves and n_folds > 0:
    fig = plot_learning_curves(all_train_metrics, all_test_metrics, num_epochs, n_folds - 1, final_plot=True)
    fig.savefig('plots/learning curves.png')
    _show_and_close(fig)

In [None]:
# Run the full leave-one-subject-out experiment with the configuration defined earlier.
LOSOCV(
    model,
    nappa_dataset,
    lr=learning_rate,
    batch_size=batch_size,
    num_epochs=num_epochs,
    verbose=True,
    plot_curves=True,
)