# <b><span style='color: #196f3d '>| HMS:</span>  EEGNET [Train]</b> 
This notebook is primarily based on [Nischay Dhankar's EEGNET Starter](https://www.kaggle.com/code/nischaydnk/lightning-1d-eegnet-training-pipeline-hbs). The layout follows [Moth's WaveNet Starter](https://www.kaggle.com/code/alejopaullier/hms-wavenet-pytorch-train), and some ideas are taken from [Chris Deotte's WaveNet Starter](https://www.kaggle.com/code/cdeotte/wavenet-starter-lb-0-52). This notebook trains the model. The inference notebook can be found [here](https://www.kaggle.com/code/jasperpieterse/hms-1d-eegnet-inference).

### <b><span style='color: #196f3d '>Table of Contents</span></b> <a class='anchor' id='top'></a>
<div style=" background-color: #ecf0f1 ; padding: 13px 13px; border-radius: 8px; color: white">
<li><a href="#import_libraries">Import Libraries</a></li>
<li><a href="#configuration">Configuration</a></li>
<li><a href="#utils">Utils</a></li>
<li><a href="#load_metadata">Load Train Metadata</a></li>
<li><a href="#crossvalidation">Cross Validation</a></li>    
<li><a href="#load_data_eeg">Load Raw Data</a></li>
<li><a href="#preprocessing">Preprocessing Functions</a></li>
<li><a href="#augmentations">Augmentation Functions</a></li>
<li><a href="#dataset">Dataset</a></li>
<li><a href="#dataloader">DataLoader</a></li>
<li><a href="#model">Model</a></li>
<li><a href="#optimizer">Optimizer & Scheduler</a></li>
<li><a href="#loss">Loss Function</a></li>
<li><a href="#lightning">Pytorch Lightning Model</a></li>
<li><a href="#prediction">Prediction Function</a></li>
<li><a href="#train_loop">Train Loop</a></li>
<li><a href="#train">Train Model</a></li>
<li><a href="#train">Score</a></li>
<li><a href="#explanation"> Further Explanations</a></li>
    
</div>

# <b><span style='color: #196f3d '>|</span> Import Libraries</b><a class='anchor' id='import_libraries'></a> [↑](#top) 

***

Import and install all the required libraries for this notebook.

In [None]:
!pip install -q git+https://github.com/asteroid-team/torch-audiomentations
!pip install -q git+https://github.com/deeplearningforfun/torch-toolbox.git@master

In [None]:
import sys
import os
import gc
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations as A
import librosa
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.io 
import torch_audiomentations as tA
import torch.multiprocessing as mp
import torch.optim.lr_scheduler as scheduler 
import torchmetrics
import pytorch_lightning as pl
import sklearn.metrics
import datetime

sys.path.append('/kaggle/input/kaggle-kl-div')
from kaggle_kl_div import score
from glob import glob
from tqdm import tqdm
from scipy.signal import butter, lfilter
from torch.utils.data import Dataset, DataLoader
from sklearn import model_selection
from sklearn.model_selection import KFold, GroupKFold
from PIL import Image
from torch.nn.functional import cross_entropy
from torchtoolbox.tools import mixup_data, mixup_criterion
from pytorch_lightning.callbacks import ModelCheckpoint, BackboneFinetuning, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# <b><span style='color: #196f3d '>|</span> Configuration</b><a class='anchor' id='configuration'></a> [↑](#top) 

***

In [None]:
class Config:
    #Data Creation
    EXPERIMENT_NAME = '8CH_PARALLEL_FEATURE_ENGINEERING_TRUEFLT32' #Name of model to store in TensorBoard logger
    CREATE_EEGS = False # Whether to create EEG data in kernel instead of loading it from dataset
    FEATURE_ENGINEERING = True # Whether to use feature engineering (bipolar-montage differences between neighbouring electrodes along each chain) or the raw electrode signals
    CUSTOM_FEATURES =['Fp1','T3','C3','O1','Fp2','C4','T4','O2'] # List containing what EEG features to use to create data. Leave empty to use all
    
    #Training Parameters
    NUM_WORKERS = 2         
    PRECISION   = 32        
    BATCH_SIZE  = 32
    EPOCHS      = 30 
    PATIENCE    = 20   
    P_DROPOUT   = 0.0
    SEED        = 2024
    LR          = 8e-3
    FOLDS       = 5 


    #Model parameters
    KERNELS = [3,5,7,9]
    FIXED_KERNEL_SIZE = 5 #Size of the kernel used in the fixed blocks
    NUM_FEATURE_MAPS = 24 #Number of feature maps (channels) output by each parallel convolution
    RESNET_BLOCKS = 9     #Number of ResNet blocks to use
    if CUSTOM_FEATURES:
        NUM_CHANNELS = len(CUSTOM_FEATURES) # Number of channels used by the CNN
    else:  
        NUM_CHANNELS = 20   # All raw columns; EEGDataset overrides this to 16 when FEATURE_ENGINEERING is on

    #Regularization and initialization
    PRETRAINED   = False   # Whether to use a pretrained EEG net        
    WEIGHT_DECAY = 0.01

    #Augmentation
    USE_MIXUP    = False   # Whether to use mixup augmentation
    MIXUP_ALPHA  = 0.1     # Alpha parameter for mixup

    
class paths:
    OUTPUT_DIR  = "/kaggle/working"
    TRAIN_CSV   = "/kaggle/input/hms-harmful-brain-activity-classification/train.csv"
    RAW_EEGS    = "/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/"
    EEG_DATASET = "/kaggle/input/brain-eegs/eegs.npy" 
    
pl.seed_everything(Config.SEED, workers=True)

warnings.filterwarnings('ignore')

# <b><span style='color: #196f3d '>|</span> Utility Functions</b><a class='anchor' id='utils'></a> [↑](#top) 

***

In [None]:
def eeg_from_parquet(parquet_path: str, display: bool = False) -> np.ndarray:
    """
    Reads a parquet file, extracts the middle 50 seconds of the recording and fills NaN values
    in each channel with that channel's mean (ignoring NaNs). Channels that are entirely NaN are set to 0.
    Columns are selected with the global `eeg_features` list.
    :param parquet_path: path to parquet file.
    :param display: whether to plot the extracted EEG channels.
    :return data: np.ndarray of shape (time_steps, eeg_features) -> (10_000, len(eeg_features)), e.g. (10_000, 8)
    """
    #CHOICE: Extract middle 50 seconds of parquet
    eeg = pd.read_parquet(parquet_path, columns=eeg_features)
    rows = len(eeg)                       # total number of samples in the EEG
    offset = (rows - 10_000) // 2         # 50 s * 200 Hz = 10_000 samples
    eeg = eeg.iloc[offset:offset+10_000]  # middle 50 seconds, same number of readings on each side
    if display: 
        plt.figure(figsize=(10,5))
        plot_offset = 0                   # vertical offset so the plotted channels do not overlap
        
    # === Convert to numpy ===
    data = np.zeros((10_000, len(eeg_features)), dtype='float32') # float32 placeholder (a float64 one would undo the cast below)
    for index, feature in enumerate(eeg_features):
        x = eeg[feature].values.astype('float32') # CHOICE: convert to float32; float16 overflows
        mean = np.nanmean(x)               # arithmetic mean, ignoring NaNs
        nan_fraction = np.isnan(x).mean()  # fraction of NaN values in the channel
        # === Fill NaN values ===
        if nan_fraction < 1: # some values are NaN, but not all
            x = np.nan_to_num(x, nan=mean)
        else:                # all values are NaN
            x[:] = 0
        data[:, index] = x
        
        # === Plot EEGs ===
        if display: 
            if index != 0:
                plot_offset += x.max()
            plt.plot(range(10_000), x - plot_offset, label=feature)
            plot_offset -= x.min()
    if display:
        plt.legend()
        name = parquet_path.split('/')[-1].split('.')[0]
        plt.yticks([])
        plt.title(f'EEG {name}',size=16)
        plt.show()    

    return data

def sep():
    """Prints a horizontal separator line."""
    print("-"*100)
    
def show_batch(batch, EEG_IDS, predict_arr=None):
    """Displays 8 random EEG samples from a batch in a 2x4 plot grid.

    Args:
        batch (array-like): Batch of EEG signals with shape (batch_size, num_channels, signal_length).
        EEG_IDS (array-like): EEG ids in the same order as the samples in `batch` (used for the titles).
        predict_arr (array-like, optional): Array of predicted labels (currently unused). Defaults to None.
    """

    num_items = 8  # Number of signals to display
    num_rows = 2
    num_cols = 4

    # Create a figure for plotting
    fig = plt.figure(figsize=(14, 8))

    # Randomly select up to 8 distinct indices from the batch
    img_indices = np.random.choice(batch.shape[0], size=min(num_items, batch.shape[0]), replace=False)

    # Iterate over the selected indices
    for index, img_index in enumerate(img_indices):
        img = batch[img_index]

        # Create subplot within the grid
        ax = fig.add_subplot(num_rows, num_cols, index + 1, xticks=[], yticks=[]) 

        # If input is a PyTorch tensor, convert to NumPy array 
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        img = np.asarray(img).T  # (channels, time_steps) -> (time_steps, channels)

        # Plot each channel of the EEG signal with a vertical offset for clarity
        offset = 0
        for j in range(img.shape[-1]):
            if j != 0: offset -= img[:, j].min() # shift so this channel sits above the previous one
            ax.plot(img[:, j] + offset, label=f'feature {j+1}')
            offset += img[:, j].max() + 1        # update offset for the next channel

        # Add title
        ax.set_title(f'EEG_Id = {EEG_IDS[img_index]}', size=14)

    return fig 
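The window extraction and NaN filling in `eeg_from_parquet` can be checked on a toy frame. The sketch below re-implements only that part (no plotting, no file I/O) on made-up data; `middle_window` is a hypothetical helper name, not a function used elsewhere in this notebook.

```python
import numpy as np
import pandas as pd


def middle_window(eeg: pd.DataFrame, n_samples: int = 10_000) -> np.ndarray:
    """Hypothetical helper mirroring eeg_from_parquet: take the centred n_samples rows,
    then fill NaNs in each column with that column's mean (0 if the column is all NaN)."""
    offset = (len(eeg) - n_samples) // 2
    window = eeg.iloc[offset:offset + n_samples].to_numpy(dtype='float32')
    out = np.zeros_like(window)                 # float32, all-NaN columns stay 0
    for j in range(window.shape[1]):
        x = window[:, j]
        if np.isnan(x).all():
            continue
        out[:, j] = np.nan_to_num(x, nan=np.nanmean(x))
    return out


# Toy recording: 10 samples, one partly-NaN channel and one all-NaN channel
toy = pd.DataFrame({'Fp1': [0, 1, 2, np.nan, 4, 5, 6, 7, 8, 9],
                    'T3':  [np.nan] * 10})
w = middle_window(toy, n_samples=4)             # rows 3..6 -> Fp1 = [nan, 4, 5, 6]
print(w)
```

Note that the mean is taken over the extracted window only, exactly as in `eeg_from_parquet`, not over the whole recording.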

# <b><span style='color: #196f3d '>|</span> Load Train Metadata </b><a class='anchor' id='load_metadata'></a> [↑](#top) 

***
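We load the metadata and collapse it to one row per unique EEG id: the expert votes are summed over all labelled sub-windows of each recording and normalized to a probability distribution over the six classes, while `patient_id` and `expert_consensus` are taken from the first sub-window.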

In [None]:
df = pd.read_csv(paths.TRAIN_CSV)
print(f"Dataframe shape is: {df.shape}")
df.head()

In [None]:
#Create target dictionaries
TARGETS = df.columns[-6:]
TARS = {'Seizure':0, 'LPD':1, 'GPD':2, 'LRDA':3, 'GRDA':4, 'Other':5} #labels to num
TARS_INV = {x:y for y,x in TARS.items()}                              #num to labels
Config.NUM_CLASSES = len(TARS.keys())

#Create training dataframe with only relevant features and unique IDs
EEG_IDS = df.eeg_id.unique()
train = df.groupby('eeg_id')[['patient_id']].agg('first')

tmp = df.groupby('eeg_id')[TARGETS].agg('sum')
for t in TARGETS:
    train[t] = tmp[t].values
    
y_data = train[TARGETS].values
y_data = y_data / y_data.sum(axis=1,keepdims=True)
train[TARGETS] = y_data

tmp = df.groupby('eeg_id')['expert_consensus'].agg('first')
train['target'] = tmp  # aligned on the eeg_id index

train = train.reset_index()
train = train.loc[train.eeg_id.isin(EEG_IDS)]
print('Train Dataframe with unique eeg_ids shape:', train.shape )
train.head()
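To make the aggregation concrete, here is a toy run of the same groupby/sum/normalize steps. The frame and its values are made up; the vote column names are assumed to mirror the last six columns of `train.csv`.

```python
import numpy as np
import pandas as pd

# Made-up metadata: two labelled sub-windows of eeg_id 1, one of eeg_id 2
votes = ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
toy = pd.DataFrame({
    'eeg_id':           [1, 1, 2],
    'patient_id':       [10, 10, 20],
    'expert_consensus': ['Seizure', 'Seizure', 'Other'],
    'seizure_vote':     [3, 1, 0],
    'lpd_vote':         [0, 1, 0],
    'gpd_vote':         [0, 0, 0],
    'lrda_vote':        [0, 0, 0],
    'grda_vote':        [0, 0, 1],
    'other_vote':       [1, 2, 3],
})

# One row per eeg_id: sum the votes over sub-windows, then normalize each row to sum to 1
agg = toy.groupby('eeg_id')[votes].sum()
agg[votes] = agg.values / agg.values.sum(axis=1, keepdims=True)
agg['patient_id'] = toy.groupby('eeg_id')['patient_id'].first()
agg['target'] = toy.groupby('eeg_id')['expert_consensus'].first()
agg = agg.reset_index()
print(agg)
```

For eeg_id 1 the summed votes are 4 seizure, 1 LPD and 3 other out of 8, so its soft label is (0.5, 0.125, 0, 0, 0, 0.375).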

# <b><span style='color: #196f3d '>|</span> Cross Validation</b><a class='anchor' id='crossvalidation'></a> [↑](#top) 

***

Assign each recording to a fold with `GroupKFold`, grouped by `patient_id`, so that no patient appears in both the training and the validation split of any fold.

In [None]:
gkf = GroupKFold(n_splits=Config.FOLDS)
train['fold'] = -1  # integer column; every row is overwritten below
for fold, (train_idx, valid_idx) in enumerate(gkf.split(train, train.target, train.patient_id)):   
    train.loc[valid_idx, 'fold'] = fold  # positional indices equal labels because train has a RangeIndex

train.to_csv(f'{paths.OUTPUT_DIR}/5gkf_fold.csv', index=False) # save train metadata with folds

print(f"Using {Config.FOLDS} group folds")
display(train.groupby('fold').size())
sep()
display(train.head())
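A quick way to convince yourself that the split is leak-free is to check that every patient lands in exactly one validation fold. A minimal sketch on made-up data:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold

# Made-up frame: 10 recordings from 4 patients
toy = pd.DataFrame({
    'eeg_id':     np.arange(10),
    'patient_id': [1, 1, 1, 2, 2, 3, 3, 3, 4, 4],
    'target':     ['Seizure', 'LPD', 'GPD', 'LRDA', 'GRDA',
                   'Other', 'Seizure', 'LPD', 'GPD', 'Other'],
})

gkf = GroupKFold(n_splits=2)
toy['fold'] = -1
for fold, (train_idx, valid_idx) in enumerate(gkf.split(toy, toy.target, groups=toy.patient_id)):
    toy.loc[valid_idx, 'fold'] = fold

# Each patient must appear in exactly one validation fold
folds_per_patient = toy.groupby('patient_id')['fold'].nunique()
print(folds_per_patient)
```

The same check on the real data is `train.groupby('patient_id')['fold'].nunique().eq(1).all()`.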

# <b><span style='color: #196f3d '>|</span> Load Raw Data</b><a class='anchor' id='load_data_eeg'></a> [↑](#top) 

***

Select the EEG electrodes to use, then either build one array per recording from the raw parquet files (`CREATE_EEGS = True`) or load a prebuilt dictionary of arrays from `paths.EEG_DATASET`.

In [None]:
df = pd.read_parquet(f'{paths.RAW_EEGS}1000913311.parquet')  # read one recording to get the column names
eeg_features = df.columns
print(f'There are {len(eeg_features)} raw EEG features:')
print(list(eeg_features))

#CHOICE: Select features to use. 
if Config.CUSTOM_FEATURES:
    eeg_features = Config.CUSTOM_FEATURES
    print(f'We will use {len(eeg_features)} of these features:')
    print(list(eeg_features))
else:
    print('We will use all of them')

# Map each selected electrode name to its column index in the stored arrays
feature_to_index = {feature: index for index, feature in enumerate(eeg_features)}

In [None]:
%%time

# Create EEGs from parquet files in batches to avoid memory issues
if Config.CREATE_EEGS:
    all_eegs = {}
    DISPLAY = 4
    EEG_IDS = train.eeg_id.unique()           # List of unique EEG IDs from the training dataframe
    PATH = glob(paths.RAW_EEGS + "*.parquet") # List of paths to EEG parquet files
    BATCH_SIZE = 5000                         # Adjust the batch size based on your available memory

    for batch_start in range(0, len(EEG_IDS), BATCH_SIZE):
        batch_end = min(batch_start + BATCH_SIZE, len(EEG_IDS))
        print(f'Processing {batch_end} / {len(EEG_IDS)}  eeg parquets... ', end='')

        for i, eeg_id in enumerate(EEG_IDS[batch_start:batch_end]):
            if ((batch_start + i) % 100 == 0) and (i != 0):
                print(batch_start + i, ', ', end='')

            # Path to the current EEG file
            eeg_path = paths.RAW_EEGS + str(eeg_id) + ".parquet"

            # CHOICE: Load middle 50 seconds EEG data from parquet file
            data = eeg_from_parquet(eeg_path, display=(batch_start + i) < DISPLAY)

            # Add EEG data to the dictionary with the ID as the key
            all_eegs[eeg_id] = data

        # Save this batch's dictionary of EEG arrays to disk (np.save pickles the dict)
        np.save(f'{paths.OUTPUT_DIR}/eegs_batch_{batch_start}', all_eegs)

        # Clear the dictionary for the next batch to save memory
        all_eegs = {} 
        gc.collect()

    # Load all batches and combine them into one dictionary
    for np_file in glob(f'{paths.OUTPUT_DIR}/eegs_batch_*'):
        batch_data = np.load(np_file, allow_pickle=True).item()
        all_eegs.update(batch_data)
        os.remove(np_file) # Clear individual batch files once combined
        gc.collect()       # Clear memory once combined

    # Save the combined dictionary 
    np.save(f'{paths.OUTPUT_DIR}/all_eegs_{Config.NUM_CHANNELS}ch', all_eegs)
    print('Saved created EEGs to disk.')
    
    
# Otherwise load the prebuilt dictionary from a NumPy file
else:
    print(f'Reading {len(EEG_IDS)} EEG NumPys from disk.')
    # Load the {eeg_id: array} dictionary saved with np.save
    all_eegs = np.load(paths.EEG_DATASET, allow_pickle=True).item()

In [None]:
from IPython.display import FileLink 
FileLink(r'kaggle/working/.zip')

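# <b><span style='color: #196f3d '>|</span> Preprocessing Functions</b><a class='anchor' id='preprocessing'></a> [↑](#top)

***

Signal-processing helpers applied to every sample inside `EEGDataset`: a Butterworth low-pass filter and mu-law companding.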

In [None]:
def butter_lowpass_filter(data, cutoff_freq=20, sampling_rate=200, order=4):
    """Applies a causal Butterworth low-pass filter along the time axis (axis 0)."""
    #CHOICE: CUTOFF FREQUENCY
    nyquist = 0.5 * sampling_rate
    normal_cutoff = cutoff_freq / nyquist  # cutoff as a fraction of the Nyquist frequency
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    filtered_data = lfilter(b, a, data, axis=0)
    return filtered_data

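def quantize_data(data, classes):
    """Applies mu-law companding to the data.

    Despite the name, no quantization is performed: the binning step below is
    commented out, so the output stays continuous.
    Args:
        data (np.ndarray): Input signal.
        classes (int): Passed to `mu_law_encoding` as mu.
    """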
    mu_x = mu_law_encoding(data, classes)
#     bins = np.linspace(-1, 1, classes)  # Create equally spaced bins
#     quantized = np.digitize(mu_x, bins) - 1  # Assign data to bins
    return mu_x

def mu_law_encoding(data, mu):
    """Performs Mu-law encoding on a NumPy array."""
    mu_x = np.sign(data) * np.log(1 + mu * np.abs(data)) / np.log(mu + 1)
    return mu_x

def mu_law_expansion(data, mu):
    """Performs Mu-law expansion (decoding) on a NumPy array."""
    s = np.sign(data) * (np.exp(np.abs(data) * np.log(mu + 1)) - 1) / mu
    return s
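A small sanity check of both helpers, assuming the notebook's 200 Hz sampling rate: the Butterworth filter should pass a 5 Hz sine almost unchanged and cut a 60 Hz sine to well under 1% of its amplitude, and mu-law expansion should invert the encoding. `mu_encode`/`mu_expand` are local stand-ins written with `np.log1p`/`np.expm1`; mathematically they equal `mu_law_encoding`/`mu_law_expansion` above.

```python
import numpy as np
from scipy.signal import butter, lfilter

FS = 200                                   # sampling rate assumed throughout the notebook (Hz)
t = np.arange(2_000) / FS                  # 10 s of samples
slow = np.sin(2 * np.pi * 5 * t)           # 5 Hz: inside the 20 Hz passband
fast = np.sin(2 * np.pi * 60 * t)          # 60 Hz: well above the cutoff

# Same design as butter_lowpass_filter: 4th-order Butterworth, 20 Hz cutoff
b, a = butter(4, 20 / (0.5 * FS), btype='low')


def rms(v):
    """RMS after skipping the first second (the filter's start-up transient)."""
    return np.sqrt(np.mean(v[FS:] ** 2))


g_slow = rms(lfilter(b, a, slow)) / rms(slow)
g_fast = rms(lfilter(b, a, fast)) / rms(fast)
print(f"gain at 5 Hz: {g_slow:.3f}, gain at 60 Hz: {g_fast:.4f}")


# Mu-law companding F(x) = sign(x) * ln(1 + mu*|x|) / ln(1 + mu) and its inverse
def mu_encode(x, mu):
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)


def mu_expand(y, mu):
    return np.sign(y) * np.expm1(np.abs(y) * np.log1p(mu)) / mu


x = np.linspace(-32, 32, 9)                # value range after clipping to +-1024 and dividing by 32
y = mu_encode(x, 1)                        # mu=1, as used in EEGDataset
print(np.round(y, 3))
```

With mu=1 the companding is mild: the extreme value 32 maps to ln(33)/ln(2), roughly 5.04. Also note that `lfilter` is causal, so the filtered signal lags the input; `scipy.signal.filtfilt(b, a, data, axis=0)` would give a zero-phase result instead. Whether that difference matters for classification is an empirical question.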

# <b><span style='color: #196f3d '>|</span>Augmentation Functions</b><a class='anchor' id='augmentations'></a> [↑](#top) 

***
Candidate augmentations for the EEG signals. `get_transforms` supplies the pipeline that is passed to the `Dataset`; `get_train_transform` (Albumentations) is not used yet.

In [None]:
#NOT IMPLEMENTED YET
def get_train_transform():
    """Applies data augmentation during training to prevent overfitting.

    Returns:
        A.Compose: An Albumentations composition of transformations.
    """
    return A.Compose([
           A.HorizontalFlip(p=0.5), #Randomly flips the signal horizontally.
           A.OneOf([
                    A.Cutout(max_h_size=5, max_w_size=16), #Randomly removes rectangular sections from the signal.
                    A.CoarseDropout(max_holes=4),          #Randomly sets multiple rectangular regions to zero.
                   ], p=0.5),
           ])


def get_transforms(*, data):
    """Returns the torch-audiomentations pipeline for a data split.

    Args:
        data (str): 'train' adds colored noise to individual channels (p=0.15 per channel);
                    'valid' applies no transforms.
    Returns:
        tA.Compose: A torch-audiomentations composition of transformations.
    """
    if data == 'train':
        return tA.Compose(
                transforms=[
                     # tA.ShuffleChannels(p=0.25,mode="per_channel",p_mode="per_channel",),
                     tA.AddColoredNoise(p=0.15,mode="per_channel",p_mode="per_channel", max_snr_in_db = 15, sample_rate=200),
                ])

    elif data == 'valid':
        return tA.Compose([
        ])
    else:
        raise ValueError(f"Unknown data split: {data!r}")
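`AddColoredNoise` mixes noise whose power spectrum falls off as 1/f^decay into each channel at a randomly drawn signal-to-noise ratio (bounded above here by `max_snr_in_db = 15`). The sketch below is a hypothetical NumPy re-implementation of that idea for a single 1-D signal, written only to show how an SNR in dB fixes the noise scale; it is not torch-audiomentations' code, and `add_colored_noise` is an illustrative name.

```python
import numpy as np


def add_colored_noise(signal, snr_db, f_decay, rng):
    """Add noise with power spectral density ~ 1/f**f_decay (0 = white, 1 = pink,
    2 = brown) to a 1-D signal at the target signal-to-noise ratio snr_db."""
    n = signal.shape[-1]
    spectrum = np.fft.rfft(rng.standard_normal(n))   # start from white noise
    freqs = np.fft.rfftfreq(n)
    freqs[0] = freqs[1]                               # avoid division by zero at DC
    spectrum /= freqs ** (f_decay / 2)                # amplitude ~ f^(-decay/2) -> power ~ f^(-decay)
    noise = np.fft.irfft(spectrum, n=n)

    # Scale the noise so that 20 * log10(rms_signal / rms_noise) == snr_db
    rms_signal = np.sqrt(np.mean(signal ** 2))
    rms_noise = np.sqrt(np.mean(noise ** 2))
    noise *= rms_signal / (rms_noise * 10 ** (snr_db / 20))
    return signal + noise, noise


rng = np.random.default_rng(0)
t = np.arange(2_000) / 200                            # 10 s at 200 Hz
x = np.sin(2 * np.pi * 3 * t)
noisy, noise = add_colored_noise(x, snr_db=15, f_decay=1.0, rng=rng)   # pink noise at 15 dB
snr = 20 * np.log10(np.sqrt(np.mean(x ** 2)) / np.sqrt(np.mean(noise ** 2)))
print(f"achieved SNR: {snr:.2f} dB")
```

A higher SNR means weaker noise, so `max_snr_in_db` bounds how faint the added noise can be.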

# <b><span style='color: #196f3d '>|</span> Dataset</b><a class='anchor' id='dataset'></a> [↑](#top) 

***

Create a custom `Dataset` to load the data. This is where custom loading, preprocessing and augmentation are implemented, and it lets us use the PyTorch `DataLoader` to streamline loading and batching.

In [None]:
class EEGDataset(torch.utils.data.Dataset):
    def __init__(self, data, eegs=None, augmentations=None, test=False): 
        self.data = data  
        self.eegs = eegs 
        self.augmentations = augmentations  
        self.test = test  # Flag to indicate test mode

    def __len__(self):
        return len(self.data)  # Return the number of samples in the dataset

    def __getitem__(self, index):
        # Get a single data sample and its label (if not in test mode)
        row = self.data.iloc[index]       
        data = self.eegs[row.eeg_id]  # EEG array of shape (10_000, len(eeg_features))

        if Config.FEATURE_ENGINEERING:
            if len(Config.CUSTOM_FEATURES) == 8:
                X = np.zeros((10_000, Config.NUM_CHANNELS), dtype='float32')
                    
                # === Feature engineering: bipolar montage (differences along each chain) ===
                #LL Chain
                X[:,0] = data[:,feature_to_index['Fp1']] - data[:,feature_to_index['T3']]
                X[:,1] = data[:,feature_to_index['T3']]  - data[:,feature_to_index['O1']]
                
                #LP Chain
                X[:,2] = data[:,feature_to_index['Fp1']] - data[:,feature_to_index['C3']]
                X[:,3] = data[:,feature_to_index['C3']]  - data[:,feature_to_index['O1']]
                
                #RP Chain
                X[:,4] = data[:,feature_to_index['Fp2']] - data[:,feature_to_index['C4']]
                X[:,5] = data[:,feature_to_index['C4']]  - data[:,feature_to_index['O2']]
                
                #RR Chain
                X[:,6] = data[:,feature_to_index['Fp2']] - data[:,feature_to_index['T4']]
                X[:,7] = data[:,feature_to_index['T4']]  - data[:,feature_to_index['O2']]
                
            elif not Config.CUSTOM_FEATURES:
                # NOTE: this only updates Config in the current process (e.g. a DataLoader worker), not in the main process
                Config.NUM_CHANNELS = 16 #update number of channels
                X = np.zeros((10_000, Config.NUM_CHANNELS), dtype='float32')

                #LL Chain
                X[:,0] = data[:,feature_to_index['Fp1']] - data[:,feature_to_index['F7']]
                X[:,1] = data[:,feature_to_index['F7']]  - data[:,feature_to_index['T3']]
                X[:,2] = data[:,feature_to_index['T3']]  - data[:,feature_to_index['T5']]
                X[:,3] = data[:,feature_to_index['T5']]  - data[:,feature_to_index['O1']]

                #LP Chain
                X[:,4] = data[:,feature_to_index['Fp1']] - data[:,feature_to_index['F3']]
                X[:,5] = data[:,feature_to_index['F3']]  - data[:,feature_to_index['C3']]
                X[:,6] = data[:,feature_to_index['C3']]  - data[:,feature_to_index['P3']]
                X[:,7] = data[:,feature_to_index['P3']]  - data[:,feature_to_index['O1']]

                #RP Chain
                X[:,8] = data[:,feature_to_index['Fp2']] - data[:,feature_to_index['F4']]
                X[:,9] = data[:,feature_to_index['F4']]  - data[:,feature_to_index['C4']]
                X[:,10] = data[:,feature_to_index['C4']] - data[:,feature_to_index['P4']]
                X[:,11] = data[:,feature_to_index['P4']] - data[:,feature_to_index['O2']]

                #RR Chain
                X[:,12] = data[:,feature_to_index['Fp2']] - data[:,feature_to_index['F8']]
                X[:,13] = data[:,feature_to_index['F8']]  - data[:,feature_to_index['T4']]
                X[:,14] = data[:,feature_to_index['T4']]  - data[:,feature_to_index['T6']]
                X[:,15] = data[:,feature_to_index['T6']]  - data[:,feature_to_index['O2']]

            else:
                raise ValueError('FEATURE_ENGINEERING expects CUSTOM_FEATURES to be empty or to hold the 8 electrodes listed in Config')
                
        else:
            X = data

        # === Clip and scale ===
        X = np.clip(X,-1024,1024)    
        X = np.nan_to_num(X, nan=0) / 32.0  
        
        # === Preprocess ===
        X = butter_lowpass_filter(X)  # Apply a low-pass filter 
        X = quantize_data(X, 1)       # Apply mu-law companding (mu=1)
        samples = torch.from_numpy(X).float().T  # (features, time_steps)

        # Apply data augmentations; torch-audiomentations expects (batch, channels, samples)
        samples = self.augmentations(samples.unsqueeze(0), sample_rate=200)
        samples = samples.squeeze(0)  # (features, time_steps)

        if not self.test:
            label = torch.tensor(row[TARGETS].values.astype('float32'))  # soft labels as float32
            return samples, label        # Return preprocessed data and label
        else:
            return samples               # Return preprocessed data only (for testing)
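Hand-writing one line per montage channel is easy to get wrong. A table of (anode, cathode) electrode pairs plus a single fancy-indexing subtraction expresses the same bipolar montage more compactly. This is a sketch, not a drop-in replacement: `PAIRS_8CH`, `PAIRS_16CH` and `bipolar_montage` are hypothetical names, and the 16-channel pairs assume the standard longitudinal ("double banana") chains that the LL, LP, RP and RR blocks above follow.

```python
import numpy as np

# (anode, cathode) pairs of the 8-channel montage used in EEGDataset
PAIRS_8CH = [('Fp1', 'T3'), ('T3', 'O1'),     # LL chain
             ('Fp1', 'C3'), ('C3', 'O1'),     # LP chain
             ('Fp2', 'C4'), ('C4', 'O2'),     # RP chain
             ('Fp2', 'T4'), ('T4', 'O2')]     # RR chain

# Assumed 16-channel longitudinal montage
PAIRS_16CH = [('Fp1', 'F7'), ('F7', 'T3'), ('T3', 'T5'), ('T5', 'O1'),   # LL chain
              ('Fp1', 'F3'), ('F3', 'C3'), ('C3', 'P3'), ('P3', 'O1'),   # LP chain
              ('Fp2', 'F4'), ('F4', 'C4'), ('C4', 'P4'), ('P4', 'O2'),   # RP chain
              ('Fp2', 'F8'), ('F8', 'T4'), ('T4', 'T6'), ('T6', 'O2')]   # RR chain


def bipolar_montage(data, feature_to_index, pairs):
    """Return a (time_steps, len(pairs)) float32 array of electrode differences."""
    anodes = [feature_to_index[a] for a, _ in pairs]
    cathodes = [feature_to_index[c] for _, c in pairs]
    return (data[:, anodes] - data[:, cathodes]).astype('float32')


# Usage with the 8 electrodes from Config.CUSTOM_FEATURES and random data
features = ['Fp1', 'T3', 'C3', 'O1', 'Fp2', 'C4', 'T4', 'O2']
f2i = {f: i for i, f in enumerate(features)}
rng = np.random.default_rng(0)
data = rng.normal(size=(100, 8))
X = bipolar_montage(data, f2i, PAIRS_8CH)
print(X.shape)
```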

# <b><span style='color: #196f3d '>|</span> DataLoader</b><a class='anchor' id='dataloader'></a> [↑](#top) 

***

In [None]:
def get_fold_dls(df_train, df_valid):
    """Creates DataLoaders(dls) for training and validation, assuming data is already split.

    Args:
        df_train (pd.DataFrame): DataFrame containing training data metadata.
        df_valid (pd.DataFrame): DataFrame containing validation data metadata.

    Returns:
        tuple: 
            * dl_train (DataLoader): DataLoader for training.
            * dl_val (DataLoader)  : DataLoader for validation.
            * ds_train (EEGDataset): Training dataset object.
            * ds_val (EEGDataset)  : Validation dataset object.
    """

    # Create EEGDataset objects for training and validation
    ds_train = EEGDataset(
        df_train, 
        eegs=all_eegs,
        augmentations=get_transforms(data='valid'),  # no augmentation; use data='train' to enable colored noise
        test = False  # return labels
    )

    ds_val = EEGDataset(
        df_valid, 
        eegs=all_eegs,
        augmentations=get_transforms(data='valid'),  # no augmentation for validation
        test = False  # return labels
    )

    # Create DataLoaders 
    dl_train = DataLoader(ds_train, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=Config.NUM_WORKERS, pin_memory = True)
    dl_val = DataLoader(ds_val, batch_size=Config.BATCH_SIZE, num_workers=Config.NUM_WORKERS, pin_memory = True)

    return dl_train, dl_val, ds_train, ds_val

In [None]:
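# Let's plot a dummy batch. The validation loader is not shuffled, so the first
# BATCH_SIZE eeg_ids of dummy_valid match the order of the samples in its first batch.
dummy_train = train[train['fold']!=0].copy()
dummy_valid = train[train['fold']==0].copy()

dl_train, dl_val, ds_train, ds_val = get_fold_dls(dummy_train, dummy_valid)

batch_images, _ = next(iter(dl_val))
batch_ids = dummy_valid.eeg_id.values[:Config.BATCH_SIZE]

print(batch_images.size())

fig = show_batch(batch_images, batch_ids)

plt.tight_layout()
plt.show()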

# <b><span style='color: #196f3d '>|</span> Model</b><a class='anchor' id='model'></a> [↑](#top) 

***

In [None]:
# class ResNet_1D_Block(nn.Module):
#     """
#     Implements a ResNet-inspired 1D convolutional residual block 

#     Args:
#         in_channels (int): Number of input channels.
#         out_channels (int): Number of output channels.
#         kernel_size (int): Kernel size for the convolutions.
#         stride (int): Stride for the convolutions.
#         padding (int): Padding for the convolutions.
#         downsampling (bool): Whether to apply downsampling.
#     """

#     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsampling):
#         super(ResNet_1D_Block, self).__init__()
#         self.bn1 = nn.BatchNorm1d(num_features=in_channels)
#         self.relu = nn.ReLU(inplace=False)                               
#         self.dropout = nn.Dropout(p=Config.P_DROPOUT, inplace=False)    
#         self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
#                                stride=stride, padding=padding, bias=False) 
        
#         self.bn2 = nn.BatchNorm1d(num_features=out_channels)                
#         self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
#                                stride=stride, padding=padding, bias=False)  
        
#         self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)   
#         self.downsampling = downsampling
#         self.downsample_layer = nn.MaxPool1d(kernel_size=2, stride=2, padding=0) 

#     def forward(self, x):
#         """Basically applies a BN, Relu, Dropout and 1D Convolution twice and then maxpools
#            x: Tensor(batch_size, features, time_steps)"""
#         identity = x  # Store input for residual connection

#         out = self.bn1(x)
#         out = self.relu(out)
#         out = self.dropout(out)
#         out = self.conv1(out)
        
#         out = self.bn2(out)
#         out = self.relu(out)
#         out = self.dropout(out)
#         out = self.conv2(out)

#         out = self.maxpool(out)  
#         if self.downsampling:
#             identity = self.downsample_layer(x)  # Apply downsampling if needed

#         out += identity  # Add the original input (residual connection)
#         return out

    
# class EEGNet(nn.Module):
#     """
#     EEGNet architecture: A convolutional neural network with ResNet-style blocks designed for EEG signal classification.

#     Args:
#         kernels (list): A list of kernel sizes for the initial parallel convolutions.
#         in_channels (int, optional): Number of input channels. Each chain signal is processed
#                                     on its own, so this defaults to 1.
#         fixed_kernel_size (int, optional): Kernel size for subsequent shared convolutions. 
#                                            Defaults to Config.FIXED_KERNEL_SIZE.
#         num_classes (int, optional): Number of output classes for classification. 
#                                      Defaults to 6.
#     """

#     def __init__(self, kernels = Config.KERNELS, num_feature_maps = Config.NUM_FEATURE_MAPS, in_channels=1, 
#                        fixed_kernel_size=Config.FIXED_KERNEL_SIZE, num_classes=6):
#         super(EEGNet, self).__init__()
#         self.kernels = kernels
#         self.planes = num_feature_maps  # Number of feature maps (=planes) output for each kernel (and used in the 1D conv and ResNet blocks)
#         self.in_channels = 1
        
#         #----DEFINE LAYERS TO BE USED----
#         self.parallel_conv = nn.ModuleList()  # Parallel convolutions with varying kernel sizes
#         for kernel_size in kernels:
#             conv = nn.Conv1d(in_channels, self.planes, kernel_size, stride=1, padding=0, bias=False)
#             self.parallel_conv.append(conv)

#         # Batch normalizations, ReLU activation
#         self.bn1 = nn.BatchNorm1d(self.planes)
#         self.relu = nn.ReLU(inplace=False)

#         # Shared convolution and ResNet blocks
#         self.conv1 = nn.Conv1d(self.planes, self.planes, fixed_kernel_size, stride=2, padding=2, bias=False)
#         self.block = self._make_resnet_layer(fixed_kernel_size, stride=1, padding=fixed_kernel_size//2)

#         # Downsampling Layers 
#         self.bn2 = nn.BatchNorm1d(self.planes)
#         self.avgpool = nn.AvgPool1d(6, stride=6, padding=2)
        
#         # Gated Recurrent Unit (GRU) for processing sequential patterns in the EEG data
#         self.rnn = nn.GRU(self.in_channels, hidden_size=128, num_layers=1, bidirectional=True, batch_first=True)  # input is (batch, time_steps, features)
 

#     def _make_resnet_layer(self, kernel_size, stride, blocks=Config.RESNET_BLOCKS, padding=0):
#         """Creates a sequence of ResNet blocks."""
#         layers = []
#         for _ in range(blocks):
#             downsampling = nn.MaxPool1d(2, stride=2, padding=0)  # Downsampling if needed
#             layers.append(ResNet_1D_Block(self.planes, self.planes, kernel_size, stride, padding, downsampling))
#         return nn.Sequential(*layers)

#     def forward(self, x):
#         """Forward pass through the network. 
#         x(Tensor): (batch_size, features, time_steps)"""
#         # ===PARALLEL CONVOLUTION FEATURE EXTRACTOR====
#         conv_out_sep = [conv(x) for conv in self.parallel_conv]  #list containing outputs (batch_size,planes,time_steps)

#         # Concatenate outputs
#         conv_out = torch.cat(conv_out_sep, dim=2) #
        
#         # 1D Convolution
#         conv_out = self.bn1(conv_out)
#         conv_out = self.relu(conv_out)
#         conv_out = self.conv1(conv_out)
        
#         # ResNet Block Stack
#         conv_out = self.block(conv_out)
        
#         # Average Pooling
#         conv_out = self.bn2(conv_out)
#         conv_out = self.relu(conv_out)
#         conv_out = self.avgpool(conv_out)

#         # Flatten for dense layers
#         conv_out = conv_out.reshape(conv_out.shape[0], -1)  

#         # ===GRU FEATURE EXTRACTOR====
#         rnn_out, _ = self.rnn(x.permute(0, 2, 1))  # permute to shape (batch_size, time_steps, features)
#         rnn_out = rnn_out[:, -1, :]                # take the final hidden state of the GRU, shape (batch_size, 2 * hidden_size)
        
#         # Combine both parallel convolution features and GRU Features
#         out = torch.cat([conv_out, rnn_out], dim=1)  # concatenate over feature representation 
#         return out  #shape (batch_size, 2 * hidden_size + len(conv_out))
     
# class EEGNet_Parallel(nn.Module):
#     """Modified EEGNet with parallel chain processing inspired by the WaveNet starter."""

#     def __init__(self):
#         super(EEGNet_Parallel, self).__init__()
#         self.model = EEGNet()
#         self.global_avg_pooling = nn.AdaptiveAvgPool1d(1)
#         self.dropout = Config.P_DROPOUT
        
#         self.head = nn.Sequential(
#             nn.Linear(2272, 64),
#             nn.BatchNorm1d(64),
#             nn.ReLU(),
#             nn.Dropout(self.dropout),
#             nn.Linear(64, 6)
#         )

        
#     def forward(self, x: torch.Tensor):
#         """
#         Forward pass. Input shape is (batch_size, channels, time_steps)
#         """
#         x1 = self.model(x[:,0:1,:])
#         x2 = self.model(x[:,1:2,:])
#         z1 = torch.mean(torch.stack([x1, x2], dim = 1), dim=1) # Take the mean over both signals

#         x1 = self.model(x[:,2:3,:])
#         x2 = self.model(x[:,3:4,:])
#         z2 = torch.mean(torch.stack([x1, x2], dim = 1), dim=1)
    
#         x1 = self.model(x[:,4:5,:])
#         x2 = self.model(x[:,5:6,:])
#         z3 = torch.mean(torch.stack([x1, x2], dim = 1), dim=1)
        
#         x1 = self.model(x[:,6:7,:])
#         x2 = self.model(x[:,7:8,:])
#         z4 = torch.mean(torch.stack([x1, x2], dim = 1), dim=1)
        
#         y = torch.cat([z1, z2, z3, z4], dim=1)
#         y = self.head(y)
        
#         return y

In [None]:
# model = EEGNet_Parallel()
# total_params = sum(p.numel() for p in model.parameters())
# print(f"Total number of parameters: {total_params}")
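When `KERNELS`, `FIXED_KERNEL_SIZE` or `RESNET_BLOCKS` change, the input width of the head's first `nn.Linear` has to change with them. The helper below traces the sequence length through the commented-out `EEGNet` using PyTorch's output-length formula for `Conv1d`, `MaxPool1d` and `AvgPool1d` (with `ceil_mode=False`), and assumes the GRU contributes its final step of 2 x 128 features. It is a back-of-the-envelope sketch; a dummy forward pass remains the authoritative check. If the arithmetic is right, `RESNET_BLOCKS = 9` gives 424 features per chain (1696 for four chains), whereas the head's `nn.Linear(2272, 64)` corresponds to 8 blocks.

```python
def conv_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv1d, nn.MaxPool1d and nn.AvgPool1d (ceil_mode=False)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def eegnet_feature_width(time_steps=10_000, kernels=(3, 5, 7, 9), planes=24,
                         fixed_kernel_size=5, resnet_blocks=9, gru_hidden=128):
    """Width of the feature vector the commented-out EEGNet returns for one chain signal."""
    # Parallel convolutions (padding=0) are concatenated along the time axis (dim=2)
    length = sum(conv_out_len(time_steps, k) for k in kernels)
    # conv1: stride 2, padding 2
    length = conv_out_len(length, fixed_kernel_size, stride=2, padding=2)
    for _block in range(resnet_blocks):
        # conv1 and conv2 use padding=kernel//2, which keeps the length for odd kernels ...
        for _conv in range(2):
            length = conv_out_len(length, fixed_kernel_size, padding=fixed_kernel_size // 2)
        # ... then MaxPool1d(kernel_size=2, stride=2) halves it
        length = conv_out_len(length, 2, stride=2)
    # AvgPool1d(6, stride=6, padding=2)
    length = conv_out_len(length, 6, stride=6, padding=2)
    # Flattened conv features + last step of the bidirectional GRU
    return planes * length + 2 * gru_hidden


per_chain = eegnet_feature_width()
print(per_chain, 4 * per_chain)
```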

In [None]:
class ResNet_1D_Block(nn.Module):
    """
    Implements a ResNet-inspired residual block for 1D convolutional networks.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size for the convolutions.
        stride (int): Stride for the convolutions.
        padding (int): Padding for the convolutions.
        downsampling (bool): Whether to max-pool the identity path. Must be truthy, because the main path is always max-pooled by 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsampling):
        super(ResNet_1D_Block, self).__init__()
        self.bn1 = nn.BatchNorm1d(num_features=in_channels)
        self.relu = nn.ReLU(inplace=False)                               
        self.dropout = nn.Dropout(p=Config.P_DROPOUT, inplace=False)    

        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False) 
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)                

        self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)  
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)   
        self.downsampling = downsampling
        self.downsample_layer = nn.MaxPool1d(kernel_size=2, stride=2, padding=0) 

    def forward(self, x):
        """Apply BN -> ReLU -> Dropout -> Conv1d twice, max-pool the result and add the residual."""
        identity = x  # Store input for residual connection

        out = self.bn1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv2(out)

        out = self.maxpool(out)  # Halve the sequence length
        if self.downsampling:
            identity = self.downsample_layer(x)  # Halve the identity branch so both shapes match

        out += identity  # Add the original input (residual connection)
        return out

    
class EEGNet(nn.Module):
    """
    EEGNet architecture: a 1D convolutional network with ResNet-style blocks and a parallel
    bidirectional GRU branch, designed for EEG signal classification.

    Args:
        kernels (list, optional): Kernel sizes for the initial parallel convolutions.
                                  Defaults to Config.KERNELS.
        num_feature_maps (int, optional): Number of feature maps (planes) produced by each convolution.
                                          Defaults to Config.NUM_FEATURE_MAPS.
        in_channels (int, optional): Number of input channels (e.g., electrodes or montage features).
                                     Defaults to Config.NUM_CHANNELS.
        fixed_kernel_size (int, optional): Kernel size for the shared convolution and the ResNet blocks.
                                           Defaults to Config.FIXED_KERNEL_SIZE.
        num_classes (int, optional): Number of output classes. Defaults to 6.
    """

    def __init__(self, kernels = Config.KERNELS, num_feature_maps = Config.NUM_FEATURE_MAPS, in_channels=Config.NUM_CHANNELS, 
                       fixed_kernel_size=Config.FIXED_KERNEL_SIZE, num_classes=6):
        super(EEGNet, self).__init__()
        self.kernels = kernels
        self.planes = num_feature_maps  # Number of feature maps (planes) output by each kernel and used in the shared conv and ResNet blocks
        self.in_channels = in_channels
        
        #----DEFINE LAYERS TO BE USED----
        self.parallel_conv = nn.ModuleList()  # Parallel convolutions with varying kernel sizes
        for kernel_size in kernels:
            conv = nn.Conv1d(in_channels, self.planes, kernel_size, stride=1, padding=0, bias=False)
            self.parallel_conv.append(conv)

        # Batch normalizations, ReLU activation
        self.bn1 = nn.BatchNorm1d(self.planes)
        self.relu = nn.ReLU(inplace=False)

        # Shared convolution and ResNet blocks
        self.conv1 = nn.Conv1d(self.planes, self.planes, fixed_kernel_size, stride=2, padding=2, bias=False)
        self.block = self._make_resnet_layer(fixed_kernel_size, stride=1, padding=fixed_kernel_size//2)

        # Downsampling Layers 
        self.bn2 = nn.BatchNorm1d(self.planes)
        self.avgpool = nn.AvgPool1d(6, stride=6, padding=2)
        
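        # Bidirectional GRU over the raw input to capture sequential patterns in the EEG data.
        # forward() feeds it x.permute(0, 2, 1), i.e. (batch, seq_len, channels), so batch_first=True
        # is needed; with the default batch_first=False the recurrence runs across the samples of
        # the batch instead of across time.
        self.rnn = nn.GRU(self.in_channels, hidden_size=128, num_layers=1, bidirectional=True, batch_first=True)
        
        # Final FC layer: flattened CNN features + 2 * 128 = 256 GRU features (bidirectional).
        # 424 therefore assumes the CNN branch flattens to 168 features; update it if Config changes.
        self.fc = nn.Linear(424, num_classes) 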

    def _make_resnet_layer(self, kernel_size, stride, blocks=Config.RESNET_BLOCKS, padding=0):
        """Creates a sequence of ResNet blocks, each halving the sequence length."""
        layers = []
        for _ in range(blocks):
            # Each block max-pools its main branch, so the identity branch must be downsampled too
            layers.append(ResNet_1D_Block(self.planes, self.planes, kernel_size, stride, padding, downsampling=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass through the network. The outputs of the parallel convolutions are
           concatenated along the time axis and passed through a 1D conv and the ResNet blocks."""
        # Parallel convolutions, list containing output of different kernels applied to input
        out_sep = [conv(x) for conv in self.parallel_conv]  

        # Concatenate outputs
        out = torch.cat(out_sep, dim=2) 
        
        # 1D Convolution
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv1(out)
        
        # ResNet blocks [BN and ReLU are built in]
        out = self.block(out)
        
        # Average pooling
        out = self.bn2(out)
        out = self.relu(out)
        out = self.avgpool(out)

        # Flatten for dense layers
        out = out.reshape(out.shape[0], -1)  

        # Pass through GRU for sequence representations
        rnn_out, _ = self.rnn(x.permute(0, 2, 1))   
        new_rnn_h = rnn_out[:, -1, :]  
        new_out = torch.cat([out, new_rnn_h], dim=1)  # Combine features
        result = self.fc(new_out)  # Final classification layer

        return result

In [None]:
model = EEGNet()
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")
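The hard-coded `nn.Linear(424, num_classes)` only works if the CNN branch flattens to exactly 424 - 256 = 168 features. The pure-Python sketch below traces the time dimension through the layers of `EEGNet` using the standard Conv1d/pooling output-length formula (dilation 1, no `ceil_mode`). The helper names are hypothetical, and so are the example values (sequence length 10,000, kernels `[3, 5, 7, 9]`, 24 feature maps, fixed kernel 5, 9 ResNet blocks); whether they match this notebook's `Config` is an assumption.

```python
def conv1d_out_len(length: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of nn.Conv1d / nn.MaxPool1d / nn.AvgPool1d (dilation=1, ceil_mode=False)."""
    return (length + 2 * padding - kernel) // stride + 1


def eegnet_fc_in_features(seq_len, kernels, planes, fixed_kernel_size, resnet_blocks, gru_hidden=128):
    """Hypothetical helper: trace the time axis through EEGNet and return the input size of self.fc."""
    # Parallel convolutions (padding=0) are concatenated along the time axis
    t = sum(conv1d_out_len(seq_len, k) for k in kernels)
    # Shared convolution: stride 2, padding 2
    t = conv1d_out_len(t, fixed_kernel_size, stride=2, padding=2)
    for _ in range(resnet_blocks):
        # Two convolutions with padding k//2 (length-preserving for odd k), then MaxPool1d(2, 2)
        t = conv1d_out_len(t, fixed_kernel_size, padding=fixed_kernel_size // 2)
        t = conv1d_out_len(t, fixed_kernel_size, padding=fixed_kernel_size // 2)
        t = conv1d_out_len(t, 2, stride=2)
    # AvgPool1d(6, stride=6, padding=2)
    t = conv1d_out_len(t, 6, stride=6, padding=2)
    cnn_features = planes * t       # flattened (planes, t) feature map
    gru_features = 2 * gru_hidden   # bidirectional GRU, last time step
    return cnn_features + gru_features


print(eegnet_fc_in_features(10_000, [3, 5, 7, 9], planes=24, fixed_kernel_size=5, resnet_blocks=9))  # → 424
```

With these values the time axis shrinks to 7 after the average pooling, giving 24 × 7 + 256 = 424. Changing any of them changes the required `fc` size, which is why the constant has to be kept in sync with `Config`.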

# <b><span style='color: #196f3d '>|</span> Optimizer \& Scheduler</b><a class='anchor' id='optimizer'></a> [↑](#top) 

***

In [None]:
def get_optimizer(lr, params):
    """
    Creates an Adam optimizer and a CosineAnnealingWarmRestarts learning rate scheduler.

    Args:
        lr (float): The initial learning rate.
        params (iterable): An iterable of model parameters to be optimized.

    Returns:
        dict: A dictionary containing the optimizer and scheduler configuration.
    """

    # Create the Adam optimizer, filtering for trainable parameters only
    model_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, params),  # Only include trainable parameters
        lr=lr,
        weight_decay=Config.WEIGHT_DECAY            # Apply weight decay 
    )

    interval = "epoch"  # Update learning rate per epoch

    # Create Cosine Annealing scheduler with warm restarts
    lr_scheduler = scheduler.CosineAnnealingWarmRestarts(
        model_optimizer,
        T_0=Config.EPOCHS,              # Cycle length in epochs (equal to the training length)
        T_mult=1,                       # Every cycle has the same length (T_0 epochs)
        eta_min=1e-6,                   # Minimum learning rate
        last_epoch=-1                   # Start from the beginning of the schedule
    )

    # Package the optimizer and scheduler configuration 
    return {
        "optimizer": model_optimizer, 
        "lr_scheduler": {
            "scheduler": lr_scheduler,
            "interval": interval,
            "monitor": "val_loss",      # Only needed by schedulers such as ReduceLROnPlateau
            "frequency": 1              # Update the LR every epoch
        }
    }
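To see what the scheduler does, the sketch below evaluates the cosine-annealing formula used by `CosineAnnealingWarmRestarts`, η_t = η_min + (η_max - η_min)(1 + cos(π·T_cur/T_i))/2, for the case `T_mult=1` with one step per epoch. The function name and the example values (base LR `1e-3`, 10 epochs) are hypothetical.

```python
import math


def cosine_warm_restart_lr(epoch: int, base_lr: float, t_0: int, eta_min: float = 1e-6) -> float:
    """LR at a given epoch for CosineAnnealingWarmRestarts with T_mult=1, stepped once per epoch."""
    t_cur = epoch % t_0  # with T_mult=1 every cycle lasts t_0 epochs, then T_cur resets to 0
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_0)) / 2


# Hypothetical run: base LR 1e-3 and T_0 = EPOCHS = 10
for epoch in range(10):
    print(epoch, f"{cosine_warm_restart_lr(epoch, 1e-3, t_0=10):.2e}")
```

Because `T_0` equals `Config.EPOCHS` and the trainer stops after `Config.EPOCHS` epochs, the restart at epoch `T_0` is never reached: in practice this is a single cosine decay from `Config.LR` towards `eta_min`.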

# <b><span style='color: #196f3d '>|</span> Loss Function</b><a class='anchor' id='loss'></a> [↑](#top) 

***

In [None]:
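class KLDivLossWithLogits(nn.KLDivLoss):
    """
    KL divergence loss that accepts raw logits.

    nn.KLDivLoss expects log-probabilities as input and probabilities as target, so the logits
    are passed through log_softmax first. reduction="batchmean" divides the summed loss by the
    batch size, which matches the mathematical definition of the KL divergence.
    """

    def __init__(self):
        super().__init__(reduction="batchmean")

    def forward(self, y, t):
        y = nn.functional.log_softmax(y, dim=1)  # Logits -> log-probabilities over the classes
        loss = super().forward(y, t)              # t: target probabilities (each row sums to 1)

        return loss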
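For intuition, here is a NumPy sketch of what `KLDivLossWithLogits` computes: a log-softmax over the class axis, the pointwise terms t·(log t - log q) (zero where t = 0), summed and divided by the batch size. It is a re-implementation for illustration, not the PyTorch code itself; the function names are hypothetical.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the class axis (axis=1)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def kl_div_with_logits(logits: np.ndarray, target: np.ndarray) -> float:
    """Sum over classes of t * (log t - log q), summed over rows and divided by the batch size."""
    log_q = log_softmax(logits)
    safe_t = np.where(target > 0, target, 1.0)  # avoid log(0); those terms are masked out below
    pointwise = np.where(target > 0, target * (np.log(safe_t) - log_q), 0.0)
    return pointwise.sum() / logits.shape[0]    # reduction="batchmean"


# Uniform logits against one-hot targets over 3 classes: each row costs log(3)
onehot = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(kl_div_with_logits(np.zeros((2, 3)), onehot))
```

With `reduction="mean"` PyTorch would instead divide by the number of elements (batch size × classes), which underestimates the KL divergence by a factor equal to the number of classes.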

# <b><span style='color: #196f3d '>|</span> PyTorch Lightning Model</b><a class='anchor' id='lightning'></a> [↑](#top) 

***

To use PyTorch Lightning, we define the model as a class that inherits from the `pl.LightningModule` base class and implements:  
- `__init__()`: where we define the architecture (layers, loss function, etc.).  
- `forward()`: the computation performed on an input batch (used for training, validation and inference).  
- `training_step()`: the logic for a single training batch.   
- `validation_step()`: the logic for a single validation batch.    
- `configure_optimizers()`: where we set up the optimizer and learning-rate scheduler.  

Based on these methods, PyTorch Lightning runs the training loop, validation, logging, checkpointing and more under the hood.
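The hook order can be pictured with the heavily simplified, plain-Python sketch below. `ToyModule` and `toy_fit` are hypothetical stand-ins, not Lightning classes: the real `Trainer.fit` also runs `backward()`, optimizer and scheduler steps, callbacks and logging, which are left out here.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule that only records which hooks are called."""

    def __init__(self):
        self.calls = []

    def configure_optimizers(self):
        self.calls.append("configure_optimizers")
        return {"optimizer": None}

    def training_step(self, batch, batch_idx):
        self.calls.append(f"training_step[{batch_idx}]")
        return 0.0  # the loss Lightning would backpropagate

    def validation_step(self, batch, batch_idx):
        self.calls.append(f"validation_step[{batch_idx}]")

    def on_validation_epoch_end(self):
        self.calls.append("on_validation_epoch_end")


def toy_fit(module, train_batches, val_batches, max_epochs):
    """Simplified order of hook calls in one fit (no sanity check, no backward/optimizer step)."""
    module.configure_optimizers()
    for _ in range(max_epochs):
        for i, batch in enumerate(train_batches):
            module.training_step(batch, i)  # Lightning also runs backward() and optimizer.step() here
        for i, batch in enumerate(val_batches):
            module.validation_step(batch, i)
        module.on_validation_epoch_end()


module = ToyModule()
toy_fit(module, train_batches=[1, 2], val_batches=[3], max_epochs=1)
print(module.calls)
```

One difference worth keeping in mind: by default Lightning also runs a short validation sanity check (`num_sanity_val_steps=2`) before the first training epoch, so `on_validation_epoch_end` already fires once before any training has happened.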

In [None]:
class EEGModel(pl.LightningModule):
    """
    Pytorch Lightning module for EEG classification.
    """

    def __init__(self, num_classes = Config.NUM_CLASSES, pretrained = Config.PRETRAINED, fold = None):
        super().__init__()
        self.num_classes = num_classes
        self.fold = fold

        # Create the EEGNet backbone with specified kernel sizes
        self.backbone = EEGNet()

        # Loss function for multi-class classification
        self.loss_function = KLDivLossWithLogits() 

        self.validation_step_outputs = []  # Storage for validation results
        self.lin = nn.Softmax(dim=1)       # Softmax for probability outputs 
        self.best_score = 1000.0           # Track best validation score

    def forward(self, images):
        """Forward pass through the EEGNet model."""
        logits = self.backbone(images)  # Extract features and get predictions
        return logits

    def configure_optimizers(self):
        """Set up the optimizer"""
        return get_optimizer(lr=Config.LR, params=self.parameters())

    def train_with_mixup(self, X, y):
        """Compute the training loss on a mixup-augmented batch."""
        X, y_a, y_b, lam = mixup_data(X, y, alpha=Config.MIXUP_ALPHA)
        y_pred = self(X)
        loss_mixup = mixup_criterion(self.loss_function, y_pred, y_a, y_b, lam)
        return loss_mixup

    def training_step(self, batch, batch_idx):
        images, targets = batch        
        if Config.USE_MIXUP:
            loss = self.train_with_mixup(images, targets)
        else:
            y_pred = self(images)
            loss = self.loss_function(y_pred,targets)

        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger = True)
        
        if batch_idx % 1000 == 0:
            fig = show_batch(images, EEG_IDS) 
            self.logger.experiment.add_figure('EEG_data', fig, self.global_step)
        return loss        

    def validation_step(self, batch, batch_idx):
        image, target = batch 
        y_pred = self(image)
        val_loss = self.loss_function(y_pred, target)
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.validation_step_outputs.append({"val_loss": val_loss, "logits": y_pred, "targets": target})

        return {"val_loss": val_loss, "logits": y_pred, "targets": target}

    def train_dataloader(self):
        return self._train_dataloader 
    
    def validation_dataloader(self):
        return self._validation_dataloader
    
    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        # print(len(outputs))
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        output_val = nn.Softmax(dim=1)(torch.cat([x['logits'] for x in outputs],dim=0)).cpu().detach().numpy()
        target_val = torch.cat([x['targets'] for x in outputs],dim=0).cpu().detach().numpy()
        self.validation_step_outputs = []

        val_df = pd.DataFrame(target_val, columns = list(TARGETS))
        pred_df = pd.DataFrame(output_val, columns = list(TARGETS))

        val_df['id'] = [f'id_{i}' for i in range(len(val_df))] 
        pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))] 

        avg_score = score(val_df, pred_df, row_id_column_name = 'id')

        if avg_score < self.best_score:
            print(f'Fold {self.fold}: Epoch {self.current_epoch} validation loss {avg_loss}')
            print(f'Fold {self.fold}: Epoch {self.current_epoch} validation KL divergence score {avg_score}')
            self.best_score = avg_score
            # val_df.to_csv(f'{paths.OUTPUT_DIR}/val_df_f{self.fold}.csv',index=False)
            # pred_df.to_csv(f'{paths.OUTPUT_DIR}/pred_df_f{self.fold}.csv',index=False)
        
        return {'val_loss': avg_loss,'val_cmap':avg_score}
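`train_with_mixup` relies on `mixup_data` and `mixup_criterion`, which are defined elsewhere in the notebook. The NumPy sketch below shows the usual form of mixup, under the assumption that those helpers follow it: draw λ ~ Beta(α, α), blend each sample with a random partner from the same batch, and weight the loss against both targets. `mixup_batch` and `mixup_loss` are hypothetical names.

```python
import numpy as np


def mixup_batch(x, y, alpha, rng):
    """Hypothetical mixup: blend each sample with a randomly chosen partner from the same batch."""
    lam = rng.beta(alpha, alpha)          # mixing weight in [0, 1]
    perm = rng.permutation(x.shape[0])    # partner index for every sample
    mixed_x = lam * x + (1 - lam) * x[perm]
    return mixed_x, y, y[perm], lam


def mixup_loss(loss_fn, pred, y_a, y_b, lam):
    """Weight the loss against both original targets with the same lam."""
    return lam * loss_fn(pred, y_a) + (1 - lam) * loss_fn(pred, y_b)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2, 8))   # (batch, channels, time)
y = np.eye(4)                    # one-hot targets, one per sample
mixed_x, y_a, y_b, lam = mixup_batch(x, y, alpha=0.4, rng=rng)
```

With a KL divergence loss, λ·KL(y_a‖q) + (1 - λ)·KL(y_b‖q) differs from KL(λ·y_a + (1 - λ)·y_b‖q) only by entropy terms of the targets, which do not depend on the prediction q, so both variants produce the same gradients.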

# <b><span style='color: #196f3d '>|</span> Prediction Function </b><a class='anchor' id='prediction'></a> [↑](#top) 

***
We use a separate prediction function to generate predictions from the trained model outside the PyTorch Lightning environment.

In [None]:
def predict(data_loader, model):
    """
    Generates softmax predictions from a PyTorch model using a dataloader.
    Predictions are returned in dataloader order, one row per sample.
    """

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)  # Predict on a single device outside Lightning (multi-GPU prediction is not handled here)
    model.eval()      # Set the model to evaluation mode (disables dropout, etc.)

    predictions = []
    for batch in tqdm(data_loader):  # Iterate over batches in the dataloader
        with torch.no_grad():        # Disable gradient calculation for efficiency
            x, y = batch  
            x = x.to(device)           

            outputs = model(x)       # Get predictions from the model
            outputs = nn.Softmax(dim=1)(outputs)  # Apply softmax for probabilities

            # Collect and convert predictions to NumPy array
            predictions.extend(outputs.detach().cpu().numpy())  

    predictions = np.vstack(predictions)  # Stack predictions into a single array
    return predictions
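The NumPy sketch below mirrors the structure of `predict`: a row-wise softmax (like `nn.Softmax(dim=1)`) per batch, rows collected with `list.extend`, then stacked with `np.vstack`. `predict_batches` and the lambda used as a model are hypothetical stand-ins for the real dataloader and network.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, like nn.Softmax(dim=1)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def predict_batches(batches, model_fn):
    """Apply a (hypothetical) model_fn batch by batch, then stack the per-sample rows."""
    predictions = []
    for x, _ in batches:
        predictions.extend(softmax(model_fn(x)))  # extend adds one 1-D row per sample
    return np.vstack(predictions)                 # shape: (num_samples, num_classes)


batches = [(np.zeros((3, 6)), None), (np.ones((2, 6)), None)]
preds = predict_batches(batches, lambda x: x * np.arange(6))
print(preds.shape)
```

Since the rows come back in dataloader order and are later assigned positionally to `df_valid[pred_cols]`, the validation dataloader must not shuffle.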

# <b><span style='color: #196f3d '>|</span> Train Loop</b><a class='anchor' id='train_loop'></a> [↑](#top) 

***

In [None]:
torch.set_float32_matmul_precision('high')
def run_training_fold(fold_id):
    """
    Trains an EEGNet model on a specified data fold using PyTorch Lightning.
    """
    print(f"Running training for fold {fold_id}...")

    # Configure logger
    logger = TensorBoardLogger("logs", name = Config.EXPERIMENT_NAME) 
    # Column names for predictions
    pred_cols = [f'pred_{t}' for t in TARGETS]  

    # ======== SPLIT ==========
    df_train = train[train['fold']!=fold_id].copy()
    df_valid = train[train['fold']==fold_id].copy()
    print(len(df_train), 'train length')
    print(len(df_valid), 'valid length')
    
    # ======== DATASETS & DATALOADERS==========
    dl_train, dl_val, ds_train, ds_val = get_fold_dls(df_train, df_valid)

    # ======== MODEL ==========
    eeg_model = EEGModel(num_classes=Config.NUM_CLASSES, pretrained=Config.PRETRAINED, fold=fold_id)

    # Set up PyTorch Lightning callbacks
    early_stop_callback = EarlyStopping(monitor="val_loss",min_delta=0.00, patience=Config.PATIENCE, verbose=True, mode="min")
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=f"{paths.OUTPUT_DIR}/",
        save_top_k=1,
        save_last=True,
        save_weights_only=False,
        filename=f'eegnet_best_loss_fold{fold_id}',
        verbose=True,
        mode='min')
    
    callbacks_to_use = [checkpoint_callback, early_stop_callback]
    
    #Configure Trainer
    trainer = pl.Trainer(
        devices = "auto",  
        val_check_interval = 1.0,
        deterministic = True,
        max_epochs = Config.EPOCHS,
        logger = logger,
        callbacks = callbacks_to_use,
        precision = Config.PRECISION,  
        accelerator = "gpu",
        overfit_batches = 0.0)


    # ======= TRAIN ==========
    print("Running trainer.fit")
    trainer.fit(eeg_model, train_dataloaders=dl_train, val_dataloaders=dl_val)

    # ======= EVALUATION ==========
    # Load the best checkpoint (lowest val_loss) and predict with it, not with the last-epoch weights
    model = EEGModel.load_from_checkpoint(checkpoint_callback.best_model_path, fold=fold_id)
    predictions = predict(dl_val, model)

    # Save predictions to CSV
    df_valid[pred_cols] = predictions
    df_valid.to_csv(f'{paths.OUTPUT_DIR}/pred_df_f{fold_id}.csv', index=False)

    # Release some memory 
    gc.collect()

    return predictions
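The interplay of the two callbacks can be simulated in plain Python. The sketch below is a simplification of `EarlyStopping(mode="min", min_delta=0.0)` together with `ModelCheckpoint(save_top_k=1, mode="min")`: the checkpoint tracks the lowest `val_loss`, and training stops after `patience` validations without improvement. The function and the loss values are hypothetical.

```python
def simulate_early_stopping(val_losses, patience):
    """Return (best_epoch, stop_epoch) for a sequence of per-epoch validation losses.

    Simplified sketch assuming min_delta=0.0; if patience is never exhausted,
    stop_epoch is the last epoch.
    """
    best_loss = float("inf")
    best_epoch = -1
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:              # improvement: the "best" checkpoint is overwritten
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1                     # no improvement: the early-stopping counter grows
            if wait >= patience:
                return best_epoch, epoch  # training stops after this epoch
    return best_epoch, len(val_losses) - 1


print(simulate_early_stopping([1.0, 0.8, 0.85, 0.9, 0.7, 0.75, 0.76, 0.77], patience=2))
```

When training stops early, the model in memory holds the weights of the stopping epoch, which can be up to `patience` epochs past the best one. That is why evaluation should use the weights stored at `checkpoint_callback.best_model_path`.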

# <b><span style='color: #196f3d '>|</span> Train Model</b><a class='anchor' id='train'></a> [↑](#top) 

***

In [None]:
# Create a copy of the training dataset
oof_df = train.copy() 

# Initialize prediction columns with zeros
pred_cols = [f'pred_{t}' for t in TARGETS] 
oof_df[pred_cols] = 0.0

# K-Fold Cross-Validation Loop
for fold_nr in range(Config.FOLDS):
    # Get indices for the validation set in this fold  
    val_idx = list(train[train['fold']== fold_nr].index) 

    # Train the model on the current fold, obtain predictions for the validation set
    val_preds = run_training_fold(fold_nr)

    # Update the 'oof_df' dataframe with predictions for the current fold [each val_idx will only appear once over all folds]
    oof_df.loc[val_idx, pred_cols] = val_preds
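The loop above relies on each row belonging to exactly one fold, so every row receives exactly one out-of-fold prediction. The NumPy sketch below checks this property on a toy setup; the sizes, the fold assignment and the Dirichlet draws standing in for `run_training_fold` are all hypothetical.

```python
import numpy as np

# Hypothetical setup: 10 rows, 5 folds, 6 classes
rng = np.random.default_rng(42)
n_rows, n_folds, n_classes = 10, 5, 6
fold = np.arange(n_rows) % n_folds             # fold id per row, like train['fold']
oof = np.zeros((n_rows, n_classes))
times_filled = np.zeros(n_rows, dtype=int)

for fold_nr in range(n_folds):
    val_idx = np.flatnonzero(fold == fold_nr)  # rows held out in this fold
    fake_preds = rng.dirichlet(np.ones(n_classes), size=len(val_idx))  # stand-in for run_training_fold
    oof[val_idx] = fake_preds
    times_filled[val_idx] += 1

print(times_filled)
```

Every entry of `times_filled` ends up equal to 1, and each OOF row is a proper probability distribution, which is what the scoring cell below expects.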

# <b><span style='color: #196f3d '>|</span> Score </b><a class='anchor' id='score'></a> [↑](#top) 

***

In [None]:
# ======= SCORING ==========
oof_pred_df = oof_df[['eeg_id'] + list(['pred_'+i for i in TARGETS])].copy()

# Rename columns
oof_pred_df.columns = ['eeg_id'] + list(TARGETS) 

# Create a copy for true values from 'oof_df' 
oof_true_df = oof_df[oof_pred_df.columns].copy()

# Calculate evaluation score using OOF predictions and true values
oof_score = score(solution=oof_true_df, submission=oof_pred_df, row_id_column_name='eeg_id')

# Save out-of-fold predictions to a CSV file
oof_df.to_csv(f'{paths.OUTPUT_DIR}/oof.csv', index=False)

# Display the calculated OOF score
print('OOF Score for solution =', oof_score)
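For intuition about the OOF score, the sketch below computes a row-averaged KL divergence directly on probability arrays, assuming the competition metric behaves like this: predictions are clipped away from 0 so the logarithm stays finite, and classes with zero target probability contribute nothing. It is a simplified re-implementation, not the `score` function used above; `kl_score` and `epsilon=1e-15` are assumptions.

```python
import numpy as np


def kl_score(solution, submission, epsilon=1e-15):
    """Row-averaged KL(solution || submission) for arrays of shape (rows, classes)."""
    submission = np.clip(submission, epsilon, 1 - epsilon)  # keep log() finite
    safe_solution = np.where(solution > 0, solution, 1.0)   # zero-target terms are masked below
    per_class = np.where(solution > 0, solution * np.log(safe_solution / submission), 0.0)
    return per_class.sum(axis=1).mean()


onehot = np.eye(6)[[0, 3]]
print(kl_score(onehot, np.full((2, 6), 1 / 6)))  # uniform prediction: log(6) per row
```

The clipping bounds the penalty for a confident wrong prediction at about -log(1e-15) ≈ 34.5 per row, which shows how strongly a few overconfident mistakes can dominate the average.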

# <b><span style='color: #196f3d '>|</span> Further Explanations </b><a class='anchor' id='explanation'></a> [↑](#top)

**PARALLEL CONVOLUTIONS**  
<center><img src=https://www.googleapis.com/download/storage/v1/b/kaggle-forum-message-attachments/o/inbox%2F4712534%2F9317ffc90ca7cf9cd990cc87250b66ab%2FScreenshot%202024-01-29%20at%209.49.18%20AM.png?generation=1706503620778537&alt=media width=600></center>


**FEATURE ENGINEERING**  
Feature 1 is the signal at the beginning of the montage chains `LL` and `LP` minus the signal at the end of those chains, and feature 2 is the same difference for the `RL` and `RP` chains.
The model uses 8 features, grouped as 4 chains.
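A minimal NumPy sketch of such bipolar features follows. The electrode pairing is an assumption, recalled from the pairing in Chris Deotte's WaveNet starter and not taken from this notebook's preprocessing code; the chain keys follow the names used in the text above. Two consecutive differences within a chain telescope to "beginning minus end" of that chain, which links the 8 features to the description above.

```python
import numpy as np

# Assumed pairing: 4 chains x 2 bipolar differences = 8 features
CHAINS = {
    "LL": [("Fp1", "T3"), ("T3", "O1")],   # left lateral
    "LP": [("Fp1", "C3"), ("C3", "O1")],   # left parasagittal
    "RP": [("Fp2", "C4"), ("C4", "O2")],   # right parasagittal
    "RL": [("Fp2", "T4"), ("T4", "O2")],   # right lateral
}


def montage_features(eeg):
    """Build an (8, seq_len) array of bipolar differences from a dict of electrode signals."""
    feats = [eeg[a] - eeg[b] for pairs in CHAINS.values() for a, b in pairs]
    return np.stack(feats)


rng = np.random.default_rng(1)
eeg = {name: rng.normal(size=50) for name in ["Fp1", "T3", "O1", "C3", "Fp2", "C4", "O2", "T4"]}
feats = montage_features(eeg)
# Telescoping: (Fp1 - T3) + (T3 - O1) == Fp1 - O1, i.e. beginning minus end of the LL chain
print(np.allclose(feats[0] + feats[1], eeg["Fp1"] - eeg["O1"]))
```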

<center><img src=https://raw.githubusercontent.com/cdeotte/Kaggle_Images/main/Jan-2024/montage.png width=600></center>
