In [None]:
import sys
# Print the path of the Python interpreter running this kernel, to confirm which environment is active.
print(sys.executable)

In [30]:
from scipy.io import loadmat
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import multiprocessing
import optuna
from optuna.trial import TrialState
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from optuna.pruners import MedianPruner
import pickle
from sklearn.metrics import f1_score, precision_score, recall_score
import time
from optuna.samplers import RandomSampler
import h5py

# Dataset loading

In [47]:
# Loading dataset 1
import gc
def deref_all_refs(group, ref_container):
    """
    Resolve every object reference stored in `ref_container` against `group`.

    Each row of `ref_container` carries the reference in its first element; the
    referenced dataset is read in full with `[:]`.

    Returns a list with one loaded array per row, in iteration order.
    (The same helper is also defined in the second loading cell.)
    """
    resolved = []
    for row in ref_container:
        target = group[row[0]]
        resolved.append(target[:])
    return resolved

def load_combined_data(path='combinedEpochs_v2.mat', group_name='combinedEpochs'):
    """
    Read a MATLAB v7.3 (HDF5) epochs file and concatenate trials across subjects.

    The group `group_name` must contain the reference datasets 'rotation_data',
    'label' and 'magnitude', each holding one object reference per subject.

    Parameters
    ----------
    path       : str, HDF5 .mat file to read (defaults to the original file name).
    group_name : str, top-level group inside the file.

    Returns
    -------
    rotation_dataset_all_subjects : np.ndarray (float32), per-subject arrays joined along axis 0
    labels_all_subjects           : np.ndarray (uint8),   per-subject arrays joined along axis 1
    magnitudes_all_subjects       : np.ndarray (uint8),   per-subject arrays joined along axis 1

    Raises
    ------
    ValueError : if the file contains no subjects.

    NOTE(review): earlier docs described the data as [sequence length, channels,
    num trials], but subjects are joined along axis 0, and h5py exposes MATLAB
    arrays with their axes reversed. Axis 0 is therefore presumably the trial
    axis here; confirm against the .mat layout.
    NOTE(review): magnitudes are cast to uint8, which truncates fractional values
    and wraps values above 255; confirm the stored magnitudes fit that range.
    """
    with h5py.File(path, 'r') as bci_file:
        bci_group = bci_file[group_name]
        # deref_all_refs reads every referenced dataset fully into memory, so the
        # file can be closed as soon as these three lists are built.
        data_per_subject = deref_all_refs(bci_group, bci_group['rotation_data'])
        labels_per_subject = deref_all_refs(bci_group, bci_group['label'])
        magnitudes_per_subject = deref_all_refs(bci_group, bci_group['magnitude'])

    if not data_per_subject:
        raise ValueError(f'no subjects found in {path!r} / {group_name!r}')

    print('number of subjects:', len(data_per_subject))
    print('checking shape of one subject:', data_per_subject[0].shape)

    # Cast to compact dtypes, releasing each raw list as soon as it is converted so
    # the raw arrays are not kept alive during concatenation.
    rotation_dataset = [arr.astype(np.float32) for arr in data_per_subject]
    del data_per_subject
    labels = [arr.astype(np.uint8) for arr in labels_per_subject]
    del labels_per_subject
    magnitudes = [arr.astype(np.uint8) for arr in magnitudes_per_subject]
    del magnitudes_per_subject
    gc.collect()

    rotation_dataset_all_subjects = np.concatenate(rotation_dataset, axis=0)
    del rotation_dataset
    labels_all_subjects = np.concatenate(labels, axis=1)
    magnitudes_all_subjects = np.concatenate(magnitudes, axis=1)

    print('rotation shape:', rotation_dataset_all_subjects.shape)
    print('labels shape:', labels_all_subjects.shape)
    print('magnitudes shape:', magnitudes_all_subjects.shape)

    return rotation_dataset_all_subjects, labels_all_subjects, magnitudes_all_subjects

In [48]:
# Loading dataset 2

import gc 
def deref_all_refs(group, ref_container):
    """
    Load every dataset referenced by `ref_container` (reference in element 0 of each row).

    Returns a list of the fully-read datasets, in the order of `ref_container`.
    (Identical helper to the one in the first loading cell; keep them in sync.)
    """
    return list(map(lambda row: group[row[0]][:], ref_container))

def load_neuromod_data(path='discrete_errp.mat', group_name='discrete_errp'):
    """
    Read a MATLAB v7.3 (HDF5) ErrP file and concatenate trials across subjects.

    The group `group_name` must contain the reference datasets 'data' and 'label',
    each holding one object reference per subject.

    Parameters
    ----------
    path       : str, HDF5 .mat file to read (defaults to 'discrete_errp.mat').
    group_name : str, top-level group inside the file.

    Returns
    -------
    Two arrays (there is no magnitude output):
    neuromod_dataset_all_subjects : np.ndarray (float32), per-subject arrays joined along axis 0
    labels_all_subjects           : np.ndarray (uint8),   per-subject arrays joined along axis 1

    NOTE(review): h5py exposes MATLAB arrays with their axes reversed, so axis 0
    of the data is presumably the trial axis; confirm against the .mat layout.
    """
    with h5py.File(path, 'r') as discrete_file:
        discrete_group = discrete_file[group_name]
        # deref_all_refs reads each referenced dataset fully, so the file can be
        # closed once both lists are built.
        data_per_subject = deref_all_refs(discrete_group, discrete_group['data'])
        labels_per_subject = deref_all_refs(discrete_group, discrete_group['label'])

    # Delete each raw list right after conversion. Deleting only an alias would
    # leave the raw arrays referenced, and therefore alive, during concatenation.
    neuromod_dataset = [arr.astype(np.float32) for arr in data_per_subject]
    del data_per_subject
    labels = [arr.astype(np.uint8) for arr in labels_per_subject]
    del labels_per_subject
    gc.collect()

    neuromod_dataset_all_subjects = np.concatenate(neuromod_dataset, axis=0)
    del neuromod_dataset
    labels_all_subjects = np.concatenate(labels, axis=1)

    return neuromod_dataset_all_subjects, labels_all_subjects

In [46]:
# Grand average plotting of error and correct trials

def GA_plotting(data, labels, channel, fs=512, t_start=-0.5):
    """
    Plot and return the grand-average waveforms of error and correct trials for one channel.

    Parameters
    ----------
    data    : np.ndarray indexed as [time, channel, trial]
    labels  : array of trial labels, 1 = error trial, 0 = correct trial; after
              flattening its length must equal data.shape[2]
    channel : int, index into axis 1 of `data`
    fs      : float, sampling rate in Hz used for the time axis (default 512)
    t_start : float, time in seconds of the first sample (default -0.5)

    Returns
    -------
    error_mean, correct_mean : np.ndarray of length data.shape[0], trial-averaged waveforms
    """
    trial_labels = np.asarray(labels).flatten()
    channel_data = data[:, channel, :]
    error_mean = np.mean(channel_data[:, trial_labels == 1], axis=1)
    correct_mean = np.mean(channel_data[:, trial_labels == 0], axis=1)

    # Derive the time axis from the epoch length rather than a fixed 1.5 s window,
    # so any epoch length plots. For 768 samples at 512 Hz this reproduces
    # np.arange(-0.5, 1, 1/512) exactly.
    time_stamps = t_start + np.arange(data.shape[0]) / fs

    plt.plot(time_stamps, error_mean, label='Error Trials')
    plt.plot(time_stamps, correct_mean, label='Correct Trials')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.title(f'Grand average, channel {channel}')
    plt.legend()
    plt.show()
    return error_mean, correct_mean


# Preprocessing

In [49]:
#Basic preprocessing: sequence rejection via amplitude thresholding

def sequence_rejection(data, labels, magnitudes, task_labels, threshold):
    """
    Reject trials whose absolute amplitude exceeds `threshold` at any sample or channel.

    Parameters
    ----------
    data        : np.ndarray [sequence length, channels, num trials]
    labels      : np.ndarray [num trials]
    magnitudes  : np.ndarray [num trials]
    task_labels : np.ndarray [num trials]
    threshold   : float, amplitude limit applied symmetrically (|x| > threshold rejects)

    Returns
    -------
    cleaned_data, cleaned_labels, cleaned_magnitudes, cleaned_task_labels :
        the inputs restricted to the kept trials, in the original order.
    """
    # Artifacts can be large negative as well as positive deflections, so compare the
    # absolute value; a plain `data > threshold` lets negative artifacts through.
    # Reducing over axes (0, 1) yields one flag per trial.
    exceeds = np.any(np.abs(data) > threshold, axis=(0, 1))
    keep_mask = ~exceeds

    cleaned_data = data[:, :, keep_mask]
    cleaned_labels = labels[keep_mask]
    cleaned_magnitudes = magnitudes[keep_mask]
    cleaned_task_labels = task_labels[keep_mask]
    return cleaned_data, cleaned_labels, cleaned_magnitudes, cleaned_task_labels

In [50]:
# Basic preprocessing: common average reference (CAR) spatial filtering to increase spatial resolution
def spatial_filter(data):
    """
    Common average reference: subtract, at every sample of every trial, the mean over channels.

    data : np.ndarray [sequence length, channels, num trials]
    Returns an array of the same shape whose channel-wise mean is zero at each sample.
    """
    channel_average = np.mean(data, axis=1, keepdims=True)
    referenced = data - channel_average
    return referenced

# Dataset creation and splitting

In [8]:
#PyTorch Dataset for single trial ErrP sequences

class ErrPDataset(Dataset):
    """
    Map-style dataset of single-trial ErrP sequences.

    Indexing returns a dict with keys 'features' (float tensor, one trial),
    'labels', 'magnitudes' and 'task_labels' (long tensors).
    """

    def __init__(self, sequences, labels, magnitudes, task_labels):
        """
        Args
        ----
        sequences   : np.ndarray [num trials, seq_len, num_features]
        labels      : np.ndarray [num_trials], binary (0/1) sequence-level labels
        magnitudes  : np.ndarray [num_trials]
        task_labels : np.ndarray [num_trials]
        """
        self.sequences = sequences
        self.labels = labels
        self.magnitudes = magnitudes
        self.task_labels = task_labels

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # One trial: features are [seq_len, num_features]; the rest are per-trial scalars.
        features = torch.tensor(self.sequences[idx], dtype=torch.float)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        magnitude = torch.tensor(self.magnitudes[idx], dtype=torch.long)
        task_label = torch.tensor(self.task_labels[idx], dtype=torch.long)
        return {
            'features': features,
            'labels': label,
            'magnitudes': magnitude,
            'task_labels': task_label,
        }

In [9]:
# collate_fn to turn a list of sample dicts into batch tensors

def collate_fn(batch):
    """
    Stack a list of ErrPDataset sample dicts into batch tensors.

    Returns a tuple (features, labels, magnitudes, task_labels), each stacked
    along a new leading batch dimension.
    """
    field_order = ('features', 'labels', 'magnitudes', 'task_labels')
    return tuple(torch.stack([sample[key] for sample in batch]) for key in field_order)

In [10]:
#Function for dataset split (train / val / test)

def Dataset_splitting(tensor_data, train_ratio, val_ratio, test_ratio):
    """
    Split a dataset into contiguous train / validation / test chunks (no shuffling).

    Args
    ----
    tensor_data : sized, indexable dataset (e.g. ErrPDataset); every item is materialised.
    train_ratio : float in (0,1), fraction of samples used for training.
    val_ratio   : float in (0,1), fraction of samples used for validation.
    test_ratio  : float in (0,1); not used for sizing. The test set receives every
                  sample left after the train and validation chunks.

    Returns
    -------
    train_set, test_set, val_set : lists of items. Note the order: test before val.

    Raises
    ------
    ValueError : if train_ratio + val_ratio exceeds 1.
    """
    # Small tolerance so ratios such as 0.8 + 0.2 are not rejected for float rounding.
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f'train_ratio + val_ratio must not exceed 1 (got {train_ratio} + {val_ratio})')

    n_total = len(tensor_data)
    # Deterministic, order-preserving split; shuffle externally if needed.
    train_end = int(n_total * train_ratio)
    val_end = train_end + int(n_total * val_ratio)

    train_set = [tensor_data[i] for i in range(train_end)]
    val_set = [tensor_data[i] for i in range(train_end, val_end)]
    test_set = [tensor_data[i] for i in range(val_end, n_total)]
    return train_set, test_set, val_set

# Feature-wise normalization

In [11]:
# Function for per-channel, per-time-step z-score normalization (statistics from the training set)

def data_normalization(train_data, test_data, val_data):
    """
    Per-channel, per-time-step z-score normalisation using training-set statistics.

    For every (time step, channel) position, the mean and unbiased standard
    deviation are taken across training samples only. They are then applied to
    all three splits, so no validation or test information leaks into the statistics.

    Args
    ----
    train_data / test_data / val_data : list[dict], each dict holding a 'features'
        tensor of shape [seq_len, num_features] (all the same shape).

    Returns
    -------
    The same three lists, in the order (train, test, val). Every 'features' tensor
    is overwritten in place.

    Raises
    ------
    ValueError : if train_data is empty or its stacked features are not 3-D.
    """
    if len(train_data) == 0:
        raise ValueError('train_data is empty; cannot compute normalisation statistics')

    # torch.stack copies, so the statistics are unaffected by the in-place updates below.
    stacked = torch.stack([sample['features'] for sample in train_data])
    if stacked.dim() != 3:
        raise ValueError(f'expected features of shape [seq_len, num_features], got {tuple(stacked.shape[1:])}')

    mean = torch.mean(stacked, dim=0)          # [seq_len, num_features]
    denom = torch.std(stacked, dim=0) + 1e-6   # epsilon guards constant positions

    for split in (train_data, val_data, test_data):
        for sample in split:
            sample['features'][...] = (sample['features'] - mean) / denom

    return train_data, test_data, val_data

In [12]:
# Function for global (sequence-wide) z-score normalization, computed separately for each channel

def data_normalization_v2(train_data, test_data, val_data):
    """
    Global (whole-sequence) z-score normalisation, computed separately for each channel.

    For each channel, one mean and one unbiased standard deviation are taken over all
    training samples and all time steps. They are then applied to the train,
    validation and test splits.

    Args
    ----
    train_data / test_data / val_data : list[dict] with a 'features' tensor [seq_len, num_features]

    Returns
    -------
    The same three lists (train, test, val), with 'features' tensors modified in place.
    """
    stacked = torch.stack([sample['features'] for sample in train_data])
    _, _, n_channels = stacked.shape

    for c in range(n_channels):
        column = stacked[:, :, c]
        mu = torch.mean(column, dim=[0, 1], keepdim=True).squeeze(0)
        sigma = torch.std(column, dim=[0, 1], keepdim=True).squeeze(0) + 1e-6
        # Statistics come from the training split only; apply to train, then val, then test.
        for split in (train_data, val_data, test_data):
            for sample in split:
                sample['features'][:, c] = (sample['features'][:, c] - mu) / sigma

    return train_data, test_data, val_data

# RNN models

In [13]:
#Vanilla RNN

class ErrPDetectionRNNModel(nn.Module):
    """
    Plain GRU -> Dropout -> Linear classifier on the final time step.

    Input  : [batch_size, sequence_length, input_size]
    Output : [batch_size, output_size] raw logits (apply sigmoid before thresholding).
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout_rate, output_size):
        super().__init__()
        # Attribute names and creation order are kept so state_dict keys and
        # parameter initialisation stay identical.
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        sequence_states = self.gru(x)[0]                       # [batch, seq_len, hidden]
        final_state = self.dropout(sequence_states)[:, -1, :]  # [batch, hidden]
        return self.fc(final_state)

In [14]:
#GRU + attention mechanisms

class ErrPDetectionRNNModel_with_att(nn.Module):
    """
    GRU backbone with additive attention pooling across time.

    Input  : [batch_size, sequence_length, input_size]
    Output : [batch_size, output_size] raw logits (apply sigmoid before thresholding).
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout_rate, output_size):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        # One scalar score per time step; softmax over time turns scores into weights.
        self.attention_layer = nn.Linear(hidden_size, 1)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output, _ = self.gru(x)                           # [batch, seq_len, hidden]
        output = self.dropout(output)
        attn_scores = self.attention_layer(output)        # [batch, seq_len, 1]
        # torch.softmax (rather than F.softmax) keeps this cell self-contained:
        # `F` is only imported in a later notebook cell.
        attn_weights = torch.softmax(attn_scores, dim=1)
        context = torch.sum(attn_weights * output, dim=1)  # [batch, hidden]
        return self.fc(context)

In [15]:
#CNN + GRU + attention mechanisms

class ErrPDetection_CNN_GRU_Attn(nn.Module):
    """
    1-D convolutional front end + GRU + additive attention pooling over time.

    Input  : [batch, time, input_size]
    Output : [batch, output_size] raw logits (apply sigmoid before thresholding).

    The Conv1d slides over the time axis and mixes all input channels. It is followed
    by BatchNorm and ReLU, and the GRU then runs over the convolved sequence, which
    has the same length as the input.
    """

    def __init__(self, input_size, cnn_out_channels, cnn_kernel_size,
                 hidden_size, num_layers, dropout_rate, output_size):
        super().__init__()

        # Symmetric padding of k // 2 preserves the length for odd kernels. For even
        # kernels it produces one extra time step, which forward() trims off.
        self.cnn = nn.Conv1d(
            in_channels=input_size,
            out_channels=cnn_out_channels,
            kernel_size=cnn_kernel_size,
            padding=cnn_kernel_size // 2
        )
        self.bn = nn.BatchNorm1d(cnn_out_channels)
        self.relu = nn.ReLU()

        # GRU input size = cnn_out_channels (features per time point)
        self.gru = nn.GRU(
            input_size=cnn_out_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout_rate)

        # Attention layer: one scalar score per time step
        self.attention_layer = nn.Linear(hidden_size, 1)

        # Final classifier
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        x = x.permute(0, 2, 1)              # [batch, input_size, time] as Conv1d expects
        # Trim the extra step an even kernel yields, before BatchNorm, so it does not
        # enter the normalisation statistics.
        x = self.cnn(x)[:, :, :seq_len]     # [batch, cnn_out_channels, time]
        x = self.relu(self.bn(x))
        x = x.permute(0, 2, 1)              # [batch, time, cnn_out_channels] -> GRU input

        gru_out, _ = self.gru(x)            # [batch, time, hidden_size]
        gru_out = self.dropout(gru_out)

        attn_scores = self.attention_layer(gru_out)          # [batch, time, 1]
        # torch.softmax keeps the class independent of the later `F` import.
        attn_weights = torch.softmax(attn_scores, dim=1)
        context = torch.sum(attn_weights * gru_out, dim=1)   # [batch, hidden_size]

        return self.fc(context)             # [batch, output_size]

In [16]:
#Bidirectional GRU + attention mechanisms

class ErrPDetectionRNNModel_with_att_bidir(nn.Module):
    """
    Bidirectional GRU + temporal attention pooling.

    Input  : [batch_size, seq_len, input_size]
    Output : [batch_size, output_size] raw logits (apply sigmoid before thresholding).
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout_rate, output_size):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout_rate)

        # Forward and backward states are concatenated, hence 2 * hidden_size.
        self.attention_layer = nn.Linear(2 * hidden_size, 1)
        self.fc = nn.Linear(2 * hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gru_output, _ = self.gru(x)                             # [batch, seq_len, 2 * hidden]
        gru_output = self.dropout(gru_output)

        attn_scores = self.attention_layer(gru_output)          # [batch, seq_len, 1]
        # torch.softmax avoids depending on `F`, which is imported only in a later cell.
        attn_weights = torch.softmax(attn_scores, dim=1)

        context = torch.sum(attn_weights * gru_output, dim=1)   # [batch, 2 * hidden]
        return self.fc(context)                                 # [batch, output_size]


# Loss function

In [17]:
#Binary-class cross entropy loss
def cross_entropy_loss(predictions, labels, class_weights):
    """
    Binary cross-entropy on raw logits (BCEWithLogitsLoss), optionally weighting positives.

    Args
    ----
    predictions   : raw logits, [batch_size] or [batch_size, 1]
    labels        : 0 / 1 targets, [batch_size]
    class_weights : 1-D FloatTensor used as `pos_weight`, or None for an unweighted loss

    Returns
    -------
    Scalar loss tensor (mean over the batch).
    """
    targets = labels.float()
    # Models with output_size=1 emit [batch, 1] logits while labels are [batch].
    # BCEWithLogitsLoss requires identical shapes, so align them when the element
    # counts agree.
    if predictions.shape != targets.shape and predictions.numel() == targets.numel():
        predictions = predictions.reshape(targets.shape)
    if class_weights is not None:
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    else:
        loss_fn = nn.BCEWithLogitsLoss()
    return loss_fn(predictions, targets)

In [28]:
# Returns TPR, TNR, Accuracy given hard predictions

def evaluation(predictions, labels):
    """
    Compute TPR, TNR and accuracy from hard 0/1 predictions.

    predictions / labels : 1-D tensors of equal length.
    Returns (TPR, TNR, accuracy) as tensors. A 1e-6 term keeps TPR/TNR finite when
    a class is absent.
    """
    positives = labels == 1
    negatives = labels == 0
    tp_count = torch.logical_and(predictions == 1, positives).sum()
    tn_count = torch.logical_and(predictions == 0, negatives).sum()

    TPR = tp_count / (positives.sum() + 1e-6)
    TNR = tn_count / (negatives.sum() + 1e-6)
    accuracy = (tp_count + tn_count) / len(labels)
    return TPR, TNR, accuracy

In [19]:
#Grid-search 0 → 1 step 0.01 for the best TPR×TNR threshold

def threshold_tuning(all_probs_val, all_labels_val):
    """
    Grid-search thresholds 0.00, 0.01, ..., 1.00 for the best TPR x TNR.

    all_probs_val  : tensor of sigmoid probabilities
    all_labels_val : labels of equal length

    Returns (best_threshold, best_performance). On ties the lowest threshold wins;
    if no threshold scores above 0, (0.0, 0.0) is returned.
    """
    best_threshold, best_performance = 0.0, 0.0
    for step in range(101):
        candidate = round(step * 0.01, 2)
        TPR, TNR, _ = evaluation(all_probs_val > candidate, all_labels_val)
        score = TPR * TNR
        if score > best_performance:
            best_threshold, best_performance = candidate, score
    return best_threshold, best_performance

# Optuna hyper-parameter search (objective function)

In [20]:
#Objective function of Optuna

def objective(trial, number_of_features, train_loader, val_loader, device, folder_name):
    """
    Optuna objective for one hyper-parameter trial.

    Steps:
      - samples hyper-parameters from `trial`
      - instantiates the attention-GRU model, optimiser and ReduceLROnPlateau scheduler
      - trains with early stopping on the validation TPR x TNR score
      - saves a checkpoint with the weights of the best epoch to
        f'{folder_name}/model_{trial.number}.pth' (folder created if missing)
      - returns the best validation score (TPR x TNR) as a float for Optuna to maximise

    Args
    ----
    trial              : optuna trial object used for the suggest_* calls
    number_of_features : int, input channels per time step (model input_size)
    train_loader / val_loader : iterables yielding (sequences, labels, magnitudes, task_labels)
                         and exposing `.dataset` for the sample count
    device             : torch device for model and batches
    folder_name        : str, checkpoint directory
    """
    import os

    model_params = {}

    # Hyper-parameter space. suggest_float(log=True) replaces the deprecated
    # suggest_loguniform; parameter names are unchanged, so existing studies stay compatible.
    tuning_params = {
        'hidden_size': trial.suggest_categorical('hidden_size', [200, 300, 400, 500]),
        'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.5),
        'num_layers': trial.suggest_int('num_layers', 1, 5),
        'learning_rate': trial.suggest_float('lr', 1e-5, 5e-4, log=True),
        'l2_lambda': trial.suggest_float('l2_lambda', 1e-6, 1e-5, log=True),
        'epochs': 100,
        'optimizer': trial.suggest_categorical('optimizer', ['Adam', 'RMSProp']),
        'learning_rate_scheduler': True,
        'patience': 5,
        'batch_size': 32,
        'class_weight_error': trial.suggest_float('class_weight_error', 0.5, 2.0)}

    print(f"\n[OPTUNA] Starting trial {trial.number}")
    print(f"[OPTUNA] Hyperparameters: {tuning_params}")

    # Model instantiation
    model_params['num_layers'] = tuning_params['num_layers']
    model_params['dropout_rate'] = tuning_params['dropout_rate']
    model_params['hidden_size'] = tuning_params['hidden_size']
    model_params['output_size'] = 1
    model_params['input_size'] = number_of_features

    model = ErrPDetectionRNNModel_with_att(**model_params).to(device)

    # Optimiser & LR scheduler ('optimizer' can only be 'Adam' or 'RMSProp')
    if tuning_params['optimizer'] == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=tuning_params['learning_rate'],
                               weight_decay=tuning_params['l2_lambda'])
    else:
        optimizer = optim.RMSprop(model.parameters(), lr=tuning_params['learning_rate'],
                                  weight_decay=tuning_params['l2_lambda'])

    # The current LR is printed every epoch below, so the deprecated `verbose` flag is not needed.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5, min_lr=1e-6)

    # pos_weight for the error class; built once so the validation loop can always use it.
    weight_tensor = torch.tensor([tuning_params['class_weight_error']], device=device)

    # Early-stopping state
    no_improve_epochs = 0
    best_perf = 0.0
    best_tpr = 0.0
    best_tnr = 0.0
    best_acc = 0.0
    best_threshold = 0.0
    best_state = None
    best_epoch = None

    # Enable multi-GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    # Unwrapped module for checkpointing, so state_dict keys carry no 'module.' prefix
    # and load directly into ErrPDetectionRNNModel_with_att(**model_params).
    core_model = model.module if isinstance(model, torch.nn.DataParallel) else model

    print('training')
    for epoch in range(tuning_params['epochs']):
        print('epoch number ' + str(epoch + 1))
        epoch_start_time = time.time()

        # Training
        model.train()
        train_loss = 0.0
        print('training loop, size of tensor: ' + str(len(train_loader)))
        for sequences, labels, magnitudes, task_labels in train_loader:
            sequences = sequences.to(device)
            labels = labels.to(device).reshape(-1)
            optimizer.zero_grad()
            # [batch, 1] logits -> [batch] so they line up with the label vector
            outputs = model(sequences).reshape(-1)
            loss = cross_entropy_loss(outputs, labels, class_weights=weight_tensor)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item() * sequences.size(0)  # accumulate total loss
        train_loss /= len(train_loader.dataset)  # average per sample
        print(f"[EPOCH {epoch + 1}] Training loss: {train_loss:.4f}")

        # Validation
        print('validation')
        print('validation loop, size of tensor: ' + str(len(val_loader)))
        model.eval()
        val_loss = 0.0
        prob_batches = []
        label_batches = []
        with torch.no_grad():
            for sequences, labels, magnitudes, task_labels in val_loader:
                sequences = sequences.to(device)
                labels = labels.to(device).reshape(-1)
                outputs = model(sequences).reshape(-1)
                loss = cross_entropy_loss(outputs, labels, class_weights=weight_tensor)
                val_loss += loss.item() * sequences.size(0)
                # reshape(-1) rather than squeeze(): a 1-sample batch stays 1-D, so torch.cat works
                prob_batches.append(torch.sigmoid(outputs).cpu())
                label_batches.append(labels.cpu())

        all_probs_val = torch.cat(prob_batches)
        all_labels_val = torch.cat(label_batches)
        val_loss /= len(val_loader.dataset)

        # Best threshold and its TPR x TNR on the validation set
        threshold, performance = threshold_tuning(all_probs_val, all_labels_val)
        performance = float(performance)
        print('threshold:' + str(threshold))
        print(f"[EPOCH {epoch + 1}] validation loss: {val_loss:.4f}")
        print(f"[EPOCH {epoch + 1}] validation perf: {performance:.4f}")
        preds = (all_probs_val > threshold).int()
        tpr, tnr, acc = (float(v) for v in evaluation(preds, all_labels_val))
        print(f"[EPOCH {epoch + 1}] validation accuracy: {acc:.4f}")
        print(f"[EPOCH {epoch + 1}] validation TPR:      {tpr:.4f}")
        print(f"[EPOCH {epoch + 1}] validation TNR:  {tnr:.4f}")

        time_elapsed = time.time() - epoch_start_time
        print(f'Time elapsed for one epoch: {time_elapsed:.2f} sec')

        # Scheduler & early stopping
        scheduler.step(tpr * tnr)
        for param_group in optimizer.param_groups:
            print(f"Current learning rate: {param_group['lr']}")
        if performance > best_perf:
            best_perf = performance
            no_improve_epochs = 0
            # Snapshot the weights now. Keeping a reference to `model` would let later
            # epochs keep training it, so the saved "best" checkpoint would really be
            # the last epoch's weights.
            best_state = {k: v.detach().cpu().clone() for k, v in core_model.state_dict().items()}
            best_epoch = epoch
            best_tpr = tpr
            best_tnr = tnr
            best_acc = acc
            best_threshold = threshold
        else:
            no_improve_epochs += 1
        if no_improve_epochs >= tuning_params['patience']:
            print('early stopping')
            print('Best validation tpr x tnr: ' + str(best_perf))
            print('Best validation tpr: ' + str(best_tpr))
            print('Best validation tnr: ' + str(best_tnr))
            print('Best validation acc: ' + str(best_acc))
            break

    # Persist the best checkpoint (weights are None if no epoch scored above 0)
    os.makedirs(folder_name, exist_ok=True)
    torch.save({
        'epoch': best_epoch,
        'model': best_state,
        'best_perf': best_perf,
        'best_threshold': best_threshold,
        'model_params': model_params,
    }, f'{folder_name}/model_{trial.number}.pth')
    return best_perf  # value to be maximised by Optuna

# Training and Testing

In [None]:
# Main training / HP tuning / testing script
import random
import torch.nn.functional as F

# Data-set switches.
# NOTE: the loading branches in the main block are independent `if` statements that
# each assign rotation_dataset_all_subjects / labels_all_subjects, so when several
# switches are True a later branch overwrites the data loaded by an earlier one.
# TODO(review): small_sample_size_experiment is not read in the visible part of the
# main block; confirm where it is used.
small_sample_size_experiment = False
# Keeps the original 'datesets' spelling because the main block refers to this name.
# When True, the main block calls both load_combined_data() and load_neuromod_data().
use_multiple_datesets = False
use_perceptual_dataset = False   # when True: load_combined_data() (combinedEpochs_v2.mat by default)
use_neuromod_dataset = True      # when True: load_neuromod_data() (discrete_errp.mat by default)

# Apply CAR (common average reference, as implemented by spatial_filter);
# presumably consumed later in the main block, which is not fully visible here.
spatial_filter_use = True

# Checkpoint directory; presumably passed to objective() as folder_name (confirm).
folder_name = 'Optuna_models_att'

if __name__ == "__main__":
    # Reproducibility: seed the Python, NumPy and torch global RNGs so the trial
    # shuffle below, and any code drawing from these RNGs (presumably
    # Dataset_splitting and weight initialisation), give the same result every run.
    # 40 matches the seed of the Optuna RandomSampler used further down.
    SEED = 40
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Channel subsets (feature selection)
    RNN_features_PL_dataset = [4,8,10,14,18,20,24]
    RNN_features_neuromod_dataset = np.arange(32)  # all 32 channels

    if not (use_perceptual_dataset or use_neuromod_dataset or use_multiple_datesets):
        raise ValueError('No dataset enabled: set use_perceptual_dataset, '
                         'use_neuromod_dataset or use_multiple_datesets to True')

    # Load whichever data set(s) are enabled. If several flags are on, the last
    # branch that runs wins (perceptual -> neuromod -> multiple).
    if use_perceptual_dataset:
        rotation_dataset_all_subjects, labels_all_subjects, magnitudes_all_subjects = load_combined_data()
        labels_all_subjects = np.squeeze(labels_all_subjects)
        magnitudes_all_subjects = np.squeeze(magnitudes_all_subjects)
        task_labels_all_subjects = np.ones(labels_all_subjects.size)  # task 1 = perceptual
        rotation_dataset_all_subjects = rotation_dataset_all_subjects[:,RNN_features_PL_dataset,:]

    if use_neuromod_dataset:
        rotation_dataset_all_subjects, labels_all_subjects = load_neuromod_data()
        labels_all_subjects = np.squeeze(labels_all_subjects)
        task_labels_all_subjects = 2 * np.ones(labels_all_subjects.size)  # task 2 = neuromod
        magnitudes_all_subjects = np.ones(labels_all_subjects.size)  # no magnitudes in this set
        rotation_dataset_all_subjects = rotation_dataset_all_subjects[:,RNN_features_neuromod_dataset,:]

    if use_multiple_datesets:
        rotation_dataset_all_subjects, labels_all_subjects, magnitudes_all_subjects = load_combined_data()
        labels_all_subjects = np.squeeze(labels_all_subjects)
        magnitudes_all_subjects = np.squeeze(magnitudes_all_subjects)
        task_labels_PL = np.ones(labels_all_subjects.size)
        rotation_dataset_all_subjects = rotation_dataset_all_subjects[:,RNN_features_PL_dataset,:]

        neuromod_dataset_all_subjects, neuromod_labels_all_subjects = load_neuromod_data()
        task_labels_neuromod = 2 * np.ones(neuromod_labels_all_subjects.size)
        neuromod_labels_all_subjects = np.squeeze(neuromod_labels_all_subjects)
        neuromod_dataset_all_subjects = neuromod_dataset_all_subjects[:,RNN_features_neuromod_dataset,:]
        # Both sets are stacked along the trial axis, so their channel subsets must
        # have the same length; fail with a readable message instead of a numpy error.
        if neuromod_dataset_all_subjects.shape[1] != rotation_dataset_all_subjects.shape[1]:
            raise ValueError(
                'Cannot concatenate datasets with different channel counts '
                f'({rotation_dataset_all_subjects.shape[1]} perceptual vs '
                f'{neuromod_dataset_all_subjects.shape[1]} neuromod); '
                'choose channel subsets of equal length.')
        labels_all_subjects = np.concatenate((labels_all_subjects, neuromod_labels_all_subjects))
        magnitudes_temp = np.ones(neuromod_labels_all_subjects.size)
        magnitudes_all_subjects = np.concatenate((magnitudes_all_subjects, magnitudes_temp))
        task_labels_all_subjects = np.concatenate((task_labels_PL, task_labels_neuromod))
        rotation_dataset_all_subjects = np.concatenate((rotation_dataset_all_subjects, neuromod_dataset_all_subjects))

    # Model input size = number of channels actually loaded (7 for the perceptual
    # subset, 32 for neuromod) rather than always the neuromod subset length.
    number_of_features = rotation_dataset_all_subjects.shape[1]

    # Put the trial axis last, as assumed by the indexing below.
    rotation_dataset_all_subjects = np.transpose(rotation_dataset_all_subjects, (2, 1, 0))

    # Shuffle trials across subjects and tasks (same permutation for every array).
    shuffling_indexes = np.arange(labels_all_subjects.size)
    np.random.shuffle(shuffling_indexes)
    labels_all_subjects = labels_all_subjects[shuffling_indexes]
    magnitudes_all_subjects = magnitudes_all_subjects[shuffling_indexes]
    task_labels_all_subjects = task_labels_all_subjects[shuffling_indexes]
    rotation_dataset_all_subjects = rotation_dataset_all_subjects[:,:,shuffling_indexes]

    if small_sample_size_experiment:
        small_sample_size = 8000
        rotation_dataset_all_subjects = rotation_dataset_all_subjects[:,:,:small_sample_size]
        labels_all_subjects = labels_all_subjects[:small_sample_size]
        magnitudes_all_subjects = magnitudes_all_subjects[:small_sample_size]
        task_labels_all_subjects = task_labels_all_subjects[:small_sample_size]

    # Trial-level artefact rejection
    cleaned_data, cleaned_labels, cleaned_magnitudes, cleaned_task_labels = sequence_rejection(rotation_dataset_all_subjects, labels_all_subjects, magnitudes_all_subjects, task_labels_all_subjects, 150)
    cleaned_labels = cleaned_labels.reshape(-1, 1)
    cleaned_magnitudes = cleaned_magnitudes.reshape(-1, 1)
    cleaned_task_labels = cleaned_task_labels.reshape(-1, 1)

    print('number of trials after trial rejection ' + str(cleaned_data.shape))
    print('number of labels after trial rejection ' + str(cleaned_labels.shape))

    # Common Average Referencing
    if spatial_filter_use:
        cleaned_data = spatial_filter(cleaned_data)

    # Truncate to 1s period after trigger onset (note this depends on the dataset)
    RNN_features = cleaned_data[256:, :, :]
    print('Input space ' + str(RNN_features.shape))  # sequence length, features, number of sequences

    # Dataset creation to integrate with PyTorch's DataLoader
    RNN_features = np.transpose(RNN_features, (2, 0, 1))  # number of sequences, sequence length, features

    tensor_data = ErrPDataset(RNN_features, cleaned_labels, cleaned_magnitudes, cleaned_task_labels)
    # Sanity check on tensors
    data = tensor_data[0]
    print('Each sequence size ' + str(data['features'].shape))

    # Dataset splitting
    val_ratio = 0.1
    test_ratio = 0.11
    train_ratio = 1 - val_ratio - test_ratio
    train_set, test_set, val_set = Dataset_splitting(tensor_data, train_ratio, val_ratio, test_ratio)

    # No normalisation for now (data_normalization_v2 is not applied).
    norm_train, norm_test, norm_val = train_set, test_set, val_set

    print('train set size: ' + str(len(norm_train)))
    print('test set size: ' + str(len(norm_test)))
    print('val set size: ' + str(len(norm_val)))

    # Dataloaders
    n_cpus = multiprocessing.cpu_count()
    print('number of cpus ' + str(n_cpus))
    n_workers = max(2, n_cpus // 2)
    model_training_params = {'batch_size': 32}

    def _make_loader(dataset, is_train):
        """Build a DataLoader; only the training loader shuffles and drops the ragged last batch."""
        return DataLoader(
            dataset,
            batch_size=model_training_params['batch_size'],
            shuffle=is_train,
            num_workers=n_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            # Evaluation loaders keep every sample so validation/test metrics cover the
            # whole split. NOTE(review): assumes the model accepts a smaller final batch.
            drop_last=is_train,
        )

    train_loader = _make_loader(norm_train, is_train=True)
    val_loader = _make_loader(norm_val, is_train=False)
    test_loader = _make_loader(norm_test, is_train=False)

    # Hyper-parameter search with Optuna
    # If GPU is available, use a GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('number of GPUs ' + str(torch.cuda.device_count()))

    sampler = RandomSampler(seed=SEED)
    pruner = MedianPruner()
    study = optuna.create_study(direction='maximize', sampler=sampler, pruner=pruner)
    study.optimize(lambda trial: objective(trial, number_of_features, train_loader, val_loader, device, folder_name), n_trials=70)

    def save_study(study, filename):
        """Pickle the Optuna study to `filename`."""
        with open(filename, 'wb') as f:
            pickle.dump(study, f)

    save_study(study, f'{folder_name}.pkl')
    print(f"Best trial: {study.best_trial.number}")

    # Reload best checkpoint & evaluate on test
    checkpoint = torch.load(f'{folder_name}/model_{study.best_trial.number}.pth', map_location=device)
    if checkpoint['model'] is None:
        # objective() stores None when no epoch ever improved on its initial best score.
        raise RuntimeError(f'Checkpoint of trial {study.best_trial.number} contains no model weights')

    model_params = checkpoint['model_params']
    model = ErrPDetectionRNNModel_with_att(**model_params).to(device)
    model.load_state_dict(checkpoint['model'])
    best_threshold = checkpoint['best_threshold']
    print(f"[OPTUNA] Best threshold: {checkpoint['best_threshold']:.2f}")
    print(f"[OPTUNA] Best validation performance: {checkpoint['best_perf']:.4f}")
    model.eval()

    print("Running final evaluation on test set...")

    # Evaluation on the test set: collect per-batch probabilities and metadata.
    all_predicted = []
    all_labels = []
    all_magnitudes = []
    all_task_labels = []

    with torch.no_grad():
        for sequences, labels, magnitudes, task_labels in test_loader:
            sequences = sequences.to(device)
            probabilities = torch.sigmoid(model(sequences))
            all_predicted.append(probabilities)
            all_labels.append(labels.to(device))
            all_magnitudes.append(magnitudes.to(device))
            all_task_labels.append(task_labels.to(device))

    all_predicted = torch.cat(all_predicted).view(-1)
    all_labels = torch.cat(all_labels).view(-1)
    all_magnitudes = torch.cat(all_magnitudes).view(-1)
    all_task_labels = torch.cat(all_task_labels).view(-1)

    preds = (all_predicted > best_threshold).int()
    tpr, tnr, acc = evaluation(preds, all_labels)
    print(f"Accuracy:  {acc:.4f}")
    print(f"TPR:       {tpr:.4f}")
    print(f"TNR:       {tnr:.4f}")

# Evaluation

In [None]:
# Per-task / per-magnitude breakdown of the test-set predictions from the main cell.
print(study.best_trial.number)

# NOTE(review): this replaces the validation-tuned threshold stored in the checkpoint
# with a hand-picked value. If it was chosen by looking at test results, the numbers
# below are optimistically biased. Set to None to use the tuned threshold instead.
threshold_override = 0.03
best_threshold = threshold_override if threshold_override is not None else checkpoint['best_threshold']
preds = (all_predicted > best_threshold).int()


def _report_subset(name, mask, predictions, labels):
    """Print accuracy / TPR / TNR for the test samples selected by boolean `mask`.

    Returns the (tpr, tnr, acc) tuple from `evaluation`, or None when the subset is
    empty (e.g. task 1 when only the neuromod data set, task label 2, was loaded).
    """
    if not bool(mask.any()):
        print(f"{name}: no test samples, skipped")
        return None
    tpr, tnr, acc = evaluation(predictions[mask], labels[mask])
    print(f"{name} Accuracy:  {acc:.4f}")
    print(f"{name} TPR:       {tpr:.4f}")
    print(f"{name} TNR:       {tnr:.4f}")
    return tpr, tnr, acc


task1_mask = all_task_labels == 1
subsets = [
    ("Task 1", task1_mask),
    # NOTE(review): label says "3 mag" but the filter is magnitude == 6 — confirm the encoding.
    ("Task 1 3 mag", task1_mask & (all_magnitudes == 6)),
    ("Task 1 0 mag", task1_mask & (all_magnitudes == 0)),
    ("Task 2", all_task_labels == 2),
]
subset_metrics = {name: _report_subset(name, mask, preds, all_labels) for name, mask in subsets}

# Sanity checks: data loading, trial rejection and GA plotting

In [None]:
# Sanity check: load the selected data set(s), reject high-amplitude trials and
# plot the remaining trials with GA_plotting. Self-contained: does not rely on
# variables left over from the training cell.
small_sample_size_experiment = False
# note: a large input dataset leads to crashing
use_multiple_datesets = False
use_perceptual_dataset = True
use_neuromod_dataset = False
shuffling = False

# Trials whose absolute value exceeds this at any sample/channel are dropped.
amplitude_limit = 100
# Keep only trials with this magnitude value; None keeps every magnitude.
ga_magnitude = 12

# Features extraction: one channel subset per data set. Both have 7 channels so the
# two data sets can be concatenated when use_multiple_datesets is on.
RNN_features_PL_dataset = [4,8,10,14,18,20,24]
RNN_features_neuromod_dataset = [5,9,10,15,20,21,25]


def _load_perceptual_subset():
    """Perceptual data set on RNN_features_PL_dataset: (data, labels, magnitudes, task labels = 1)."""
    data, labels, magnitudes = load_combined_data()
    labels = np.squeeze(labels)
    magnitudes = np.squeeze(magnitudes)
    return data[:, RNN_features_PL_dataset, :], labels, magnitudes, np.ones(labels.size)


def _load_neuromod_subset():
    """Neuromod data set on RNN_features_neuromod_dataset: (data, labels, magnitudes = 1, task labels = 2)."""
    data, labels = load_neuromod_data()
    labels = np.squeeze(labels)
    return data[:, RNN_features_neuromod_dataset, :], labels, np.ones(labels.size), 2 * np.ones(labels.size)


# Same precedence as the training cell: multiple > neuromod > perceptual.
if use_multiple_datesets:
    (rotation_dataset_all_subjects, labels_all_subjects,
     magnitudes_all_subjects, task_labels_all_subjects) = [
        np.concatenate((pl_part, nm_part))
        for pl_part, nm_part in zip(_load_perceptual_subset(), _load_neuromod_subset())
    ]
elif use_neuromod_dataset:
    (rotation_dataset_all_subjects, labels_all_subjects,
     magnitudes_all_subjects, task_labels_all_subjects) = _load_neuromod_subset()
elif use_perceptual_dataset:
    (rotation_dataset_all_subjects, labels_all_subjects,
     magnitudes_all_subjects, task_labels_all_subjects) = _load_perceptual_subset()
else:
    raise ValueError('No dataset enabled: set use_perceptual_dataset, '
                     'use_neuromod_dataset or use_multiple_datesets to True')

number_of_features = rotation_dataset_all_subjects.shape[1]

# Put the trial axis last, as assumed by the indexing below.
rotation_dataset_all_subjects = np.transpose(rotation_dataset_all_subjects, (2, 1, 0))

# Shuffling across subjects and tasks (same permutation for every array).
if shuffling:
    shuffling_indexes = np.arange(labels_all_subjects.size)
    np.random.shuffle(shuffling_indexes)
    labels_all_subjects = labels_all_subjects[shuffling_indexes]
    magnitudes_all_subjects = magnitudes_all_subjects[shuffling_indexes]
    task_labels_all_subjects = task_labels_all_subjects[shuffling_indexes]
    rotation_dataset_all_subjects = rotation_dataset_all_subjects[:,:,shuffling_indexes]

if small_sample_size_experiment:
    rotation_dataset_all_subjects = rotation_dataset_all_subjects[:,:,:1000]
    labels_all_subjects = labels_all_subjects[:1000]
    magnitudes_all_subjects = magnitudes_all_subjects[:1000]
    task_labels_all_subjects = task_labels_all_subjects[:1000]

print(rotation_dataset_all_subjects.shape)

# Amplitude-based trial rejection plus optional magnitude selection; every per-trial
# array is filtered with the same mask so they stay aligned.
logical_index = ~np.any(np.abs(rotation_dataset_all_subjects) > amplitude_limit, axis=(0, 1))
if ga_magnitude is not None:
    logical_index = logical_index & (magnitudes_all_subjects == ga_magnitude)
rotation_dataset_all_subjects = rotation_dataset_all_subjects[:,:,logical_index]
labels_all_subjects = labels_all_subjects[logical_index]
magnitudes_all_subjects = magnitudes_all_subjects[logical_index]
task_labels_all_subjects = task_labels_all_subjects[logical_index]

error_mean, correct_mean = GA_plotting(rotation_dataset_all_subjects, labels_all_subjects, 0)