# <b><span style='color: #196f3d '>| HMS:</span>  EEGNET [Inference]</b> 
This notebook is primarily based on [Nischay Dhankar's EEGNET Starter](https://www.kaggle.com/code/nischaydnk/lightning-1d-eegnet-training-pipeline-hbs). The layout from [Moth's WaveNet Starter](https://www.kaggle.com/code/alejopaullier/hms-wavenet-pytorch-train) and certain ideas from [Chris Deotte's WaveNet Starter](https://www.kaggle.com/code/cdeotte/wavenet-starter-lb-0-52) were also implemented in this notebook. This notebook is used for inference only. The training notebook, which contains more information about the model, can be found [here](https://www.kaggle.com/code/jasperpieterse/hms-1d-eegnet-training).

### <b><span style='color: #196f3d '>Table of Contents</span></b> <a class='anchor' id='top'></a>
<div style=" background-color: #ecf0f1 ; padding: 13px 13px; border-radius: 8px; color: white">
<li><a href="#import_libraries">Import Libraries</a></li>
<li><a href="#configuration">Configuration</a></li>
<li><a href="#utils">Utility Functions</a></li>
<li><a href="#load_meta_data">Load Test Metadata</a></li>    
<li><a href="#load_data_eeg">Load Raw Data</a></li>
<li><a href="#preprocess">Preprocessing Functions</a></li>
<li><a href="#dataset">Dataset</a></li>
<li><a href="#dataloader">DataLoader</a></li>
<li><a href="#model">Model</a></li>
<li><a href="#optimizer">Optimizer & Scheduler</a></li>
<li><a href="#loss">Loss Function</a></li>
<li><a href="#lightning">Pytorch Lightning Model</a></li>
<li><a href="#prediction">Prediction Function</a></li>
<li><a href="#inference">Inference</a></li>
<li><a href="#submission">Submission</a></li>
    
</div>

# <b><span style='color:#196f3d'>|</span> Import Libraries</b><a class='anchor' id='import_libraries'></a> [↑](#top) 

***

Import all the required libraries for this notebook.

In [None]:
import sys
import os
import gc
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations as A
import librosa
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.io 
import torch.multiprocessing as mp
import torch.optim.lr_scheduler as scheduler 
import torchmetrics
import pytorch_lightning as pl

sys.path.append('/kaggle/input/kaggle-kl-div')
from kaggle_kl_div import score
from glob import glob
from tqdm import tqdm
from scipy.signal import butter, lfilter
from torch.utils.data import Dataset, DataLoader

# <b><span style='color:#196f3d'>|</span> Configuration</b><a class='anchor' id='configuration'></a> [↑](#top) 

***

In [None]:
class Config:
    """Notebook-wide settings, stored as plain class attributes."""

    # ---------------- Data ----------------
    # EEG electrodes to load from each parquet file.
    # Set to [] to use every column of the parquet instead.
    CUSTOM_FEATURES = ['Fp1', 'T3', 'C3', 'O1', 'Fp2', 'C4', 'T4', 'O2']
    NUM_CLASSES = 6
    FEATURE_ENGINEERING = True

    # ---------------- Training ----------------
    NUM_WORKERS = 2
    PRECISION = 32
    BATCH_SIZE = 32
    EPOCHS = 30
    PATIENCE = 20
    P_DROPOUT = 0.0
    SEED = 2024
    LR = 8e-3
    FOLDS = 5

    # ---------------- Model parameters ----------------
    KERNELS = [3, 5, 7, 9]   # kernel sizes of the parallel input convolutions
    FIXED_KERNEL_SIZE = 5    # kernel size used in the fixed (shared) blocks
    NUM_FEATURE_MAPS = 24    # feature maps (channels) produced by each parallel convolution
    RESNET_BLOCKS = 9        # number of ResNet blocks
    # Number of CNN input channels: one per selected electrode, or 20 when all columns are used.
    NUM_CHANNELS = len(CUSTOM_FEATURES) if CUSTOM_FEATURES else 20

    # ---------------- Augmentation / regularisation ----------------
    PRETRAINED = False    # whether to use a pretrained EEG net
    WEIGHT_DECAY = 0.01
    USE_MIXUP = False     # whether to use mixup augmentation
    MIXUP_ALPHA = 0.1     # alpha parameter for mixup

    
class paths:
    """Filesystem locations read from or written to by this notebook."""

    # Competition inputs
    TRAIN_CSV = '/kaggle/input/hms-harmful-brain-activity-classification/train.csv'
    TEST_CSV = '/kaggle/input/hms-harmful-brain-activity-classification/test.csv'
    RAW_EEGS = '/kaggle/input/hms-harmful-brain-activity-classification/test_eegs/'
    # Folder holding the per-fold model checkpoints
    WEIGHTS = '/kaggle/input/probably-best'
    # Writable working directory
    OUTPUT_DIR = '/kaggle/working/'

# Fix the global random seed through Lightning so that runs are repeatable.
pl.seed_everything(Config.SEED, workers=True)

# Suppress all Python warnings for the remainder of the notebook.
warnings.filterwarnings('ignore')
print("Using float 32 for intermediate calculations")
# NOTE(review): the 'high' setting may allow reduced-precision float32 matmul kernels
# on supporting hardware, so the message above may overstate the precision — confirm
# against the PyTorch documentation.
torch.set_float32_matmul_precision('high')

# <b><span style='color:#196f3d'>|</span> Utility Functions</b><a class='anchor' id='utils'></a> [↑](#top) 


In [None]:
def eeg_from_parquet(parquet_path: str, display: bool = False) -> np.ndarray:
    """
    Read one EEG parquet file and return the middle 50 seconds as a NumPy array.

    Only the columns named in the module-level ``eeg_features`` are read. Within each
    channel, NaN values are replaced by the channel mean (computed ignoring NaNs); a
    channel that is entirely NaN is set to zero.

    :param parquet_path: path to parquet file.
    :param display: whether to plot the extracted channels, stacked vertically.
    :return data: np.array of shape (time_steps, eeg_features) -> (10_000, len(eeg_features))
    :raises ValueError: if the recording contains fewer than 10_000 samples.
    """
    n_samples = 10_000  # 50s * 200 Hz = 10_000 samples

    # CHOICE: Extract middle 50 seconds of parquet
    eeg = pd.read_parquet(parquet_path, columns=list(eeg_features))
    rows = len(eeg)
    if rows < n_samples:
        # A negative start offset would silently select the wrong rows and only fail
        # later with an obscure broadcasting error, so fail early and clearly instead.
        raise ValueError(
            f"{parquet_path} has {rows} samples; at least {n_samples} are required"
        )
    start = (rows - n_samples) // 2
    eeg = eeg.iloc[start:start + n_samples]  # same amount of readings left and right

    if display:
        plt.figure(figsize=(10, 5))
    plot_offset = 0  # running vertical shift so plotted channels do not overlap

    # === Convert to numpy ===
    data = np.zeros((n_samples, len(eeg_features)))
    for index, feature in enumerate(eeg_features):
        # CHOICE: convert to float32, can also try float16 [gives overflow though]
        x = eeg[feature].values.astype('float32')

        # === Fill nan values ===
        if np.isnan(x).all():
            # No valid sample to take a mean from.
            x[:] = 0
        else:
            x = np.nan_to_num(x, nan=np.nanmean(x))
        data[:, index] = x

        # === Plot EEGs ===
        if display:
            if index != 0:
                plot_offset += x.max()
            plt.plot(range(n_samples), x - plot_offset, label=feature)
            plot_offset -= x.min()

    if display:
        plt.legend()
        name = parquet_path.split('/')[-1].split('.')[0]
        plt.yticks([])
        plt.title(f'EEG {name}', size=16)
        plt.show()

    return data

def sep():
    """Print a horizontal separator made of 100 dashes."""
    rule = "-" * 100
    print(rule)
    
def show_batch(batch, EEG_IDS, predict_arr=None):
    """Plot 8 randomly chosen EEG signals from a batch in a 2x4 grid.

    Indices are drawn with ``np.random.randint``, so the same signal can be
    selected more than once.

    Args:
        batch (array-like): Batch of EEG signals, indexed by sample along the first
            axis. ``torch.Tensor`` items are moved to NumPy and transposed before
            plotting; other items are plotted as-is and must already be laid out
            as (signal_length, num_channels).
        EEG_IDS (list): EEG ids, indexed like ``batch``; used for subplot titles.
        predict_arr (array-like, optional): Currently unused; kept for interface
            compatibility. Defaults to None.

    Returns:
        matplotlib.figure.Figure: the figure holding the 8 subplots.
    """
    num_items = 8  # Number of signals to display
    num_rows = 2
    num_cols = 4

    # Create a figure for plotting
    fig = plt.figure(figsize=(14, 8))

    # Randomly select `num_items` indices from the batch (with replacement)
    img_indices = np.random.randint(0, batch.shape[0], num_items)

    for index, img_index in enumerate(img_indices):
        img = batch[img_index]

        # Create subplot within the grid
        ax = fig.add_subplot(num_rows, num_cols, index + 1, xticks=[], yticks=[])

        # Tensors are converted to NumPy and transposed so channels sit on the last axis
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
            img = img.transpose(1, 0)

        # Plot each channel with a growing vertical offset so traces do not overlap
        offset = 0
        for j in range(img.shape[-1]):
            if j != 0:
                offset -= img[:, j].min()
            ax.plot(img[:, j] + offset, label=f'feature {j+1}')
            offset += img[:, j].max() + 1  # Update offset for next channel

        ax.set_title(f'EEG_Id = {EEG_IDS[img_index]}', size=14)

    return fig

# <b><span style='color:#196f3d'>|</span> Load Test Metadata</b><a class='anchor' id='load_meta_data'></a> [↑](#top) 

***
We load the metadata and filter it to only use unique EEG ids.

In [None]:
# Load the test metadata table; its eeg_id column identifies the recordings to predict.
test = pd.read_csv(paths.TEST_CSV)
print('Test shape:',test.shape)
# Last expression of the cell: the notebook renders the first rows as a preview.
test.head()

In [None]:
# Build the target lookups shared by the rest of the notebook.
df = pd.read_csv(paths.TRAIN_CSV)
TARGETS = df.columns[-6:]  # the last six columns of train.csv are used as target columns
TARS = dict(zip(['Seizure', 'LPD', 'GPD', 'LRDA', 'GRDA', 'Other'], range(6)))  # label -> index
TARS_INV = {index: label for label, index in TARS.items()}                      # index -> label
Config.NUM_CLASSES = len(TARS)

# <b><span style='color:#196f3d'>|</span> Load Raw Data</b><a class='anchor' id='load_data_eeg'></a> [↑](#top) 

***

Load the EEG features we will be using and the corresponding raw EEG data from the parquet files.

In [None]:
# Read one test recording only to discover the available EEG column names.
# The recording is taken from the test metadata instead of a hard-coded file name,
# so the cell keeps working when the test folder holds a different set of files.
sample_eeg_id = test.eeg_id.iloc[0]
parquet = pd.read_parquet(f'{paths.RAW_EEGS}{sample_eeg_id}.parquet')
eeg_features = parquet.columns
print(f'There are {len(eeg_features)} raw EEG features:')
print(list(eeg_features) )

#CHOICE: Select features to use.
if Config.CUSTOM_FEATURES:
    eeg_features = Config.CUSTOM_FEATURES
    print(f'We will use {len(eeg_features)} of these features:')
    print(list(eeg_features))
else:
    print('We will use all of them')

# Column position of every selected feature inside the arrays returned by eeg_from_parquet.
feature_to_index = {name: position for position, name in enumerate(eeg_features)}

In [None]:
DISPLAY = 1                       # how many of the first recordings get plotted
EEG_IDS = test.eeg_id.unique()
all_eegs = {}                     # eeg_id -> NumPy array returned by eeg_from_parquet
print('Processing Test EEG parquets...')
print()
for i, eeg_id in enumerate(EEG_IDS):
    # Keep every extracted recording in memory, keyed by its id
    data = eeg_from_parquet(f'{paths.RAW_EEGS}{eeg_id}.parquet', i < DISPLAY)
    all_eegs[eeg_id] = data

# Rebuild the test dataframe so it only holds the eeg_id column (one row per unique id)
test = pd.DataFrame({'eeg_id': EEG_IDS})

# <b><span style='color:#196f3d'>|</span> Preprocessing Functions</b><a class='anchor' id='preprocess'></a> [↑](#top) 

***
Functions of possible augmentations to apply to our data. Will be used in creating the dataset.

In [None]:
def butter_lowpass_filter(data, cutoff_freq=20, sampling_rate=200, order=4):
    """Low-pass filter ``data`` along axis 0 with a digital Butterworth filter.

    Args:
        data: array-like signal; filtering runs along the first axis.
        cutoff_freq: cutoff frequency, in the same unit as ``sampling_rate``.
        sampling_rate: sampling frequency of ``data``.
        order: order of the Butterworth filter.

    Returns:
        The filtered signal as returned by ``scipy.signal.lfilter``.
    """
    # CHOICE: CUTOFF FREQUENCY
    # Express the cutoff as a fraction of the Nyquist frequency, as `butter` expects.
    normalized_cutoff = cutoff_freq / (0.5 * sampling_rate)
    numerator, denominator = butter(order, normalized_cutoff, btype='low', analog=False)
    return lfilter(numerator, denominator, data, axis=0)

def quantize_data(data, classes):
    """Apply mu-law companding to ``data`` and return the companded values.

    Args:
        data: values to encode; passed straight to ``mu_law_encoding``.
        classes (int): forwarded as the ``mu`` parameter of ``mu_law_encoding``.

    Note:
        No binning happens here: the digitize step is disabled, so the result
        stays continuous rather than being mapped to discrete levels.
    """
    # Disabled binning step, kept for reference:
    #     bins = np.linspace(-1, 1, classes)       # equally spaced bins
    #     quantized = np.digitize(mu_x, bins) - 1  # assign data to bins
    return mu_law_encoding(data, classes)

def mu_law_encoding(data, mu):
    """Compress a NumPy array with the mu-law curve sign(x) * ln(1 + mu*|x|) / ln(1 + mu)."""
    signed_log = np.sign(data) * np.log(1 + mu * np.abs(data))
    return signed_log / np.log(mu + 1)

def mu_law_expansion(data, mu):
    """Invert mu-law companding on a NumPy array: sign(y) * ((1 + mu)**|y| - 1) / mu."""
    magnitude = np.exp(np.abs(data) * np.log(mu + 1)) - 1
    return np.sign(data) * magnitude / mu

# <b><span style='color:#196f3d'>|</span> Dataset</b><a class='anchor' id='dataset'></a> [↑](#top) 

***

Create a custom `Dataset` to load data. Here we can easily implement custom loading mechanisms, data pre-processing, and augmentation techniques. Also allows us to use the Pytorch Dataloader, streamlining data loading and batching.

In [None]:
class EEGDataset(torch.utils.data.Dataset):
    """Dataset that turns cached raw EEG arrays into filtered model inputs.

    Args:
        data (pd.DataFrame): metadata with one row per sample. Needs an ``eeg_id``
            column, plus the ``TARGETS`` columns when ``test`` is False.
        eegs (dict): maps ``eeg_id`` to an array indexed as [time_step, raw_feature],
            with columns ordered as in the module-level ``feature_to_index``.
        augmentations: stored on the instance but not applied by this class.
        test (bool): when True, ``__getitem__`` returns only the signal tensor.
    """

    # Electrode pairs (minuend, subtrahend) used when exactly 8 custom electrodes are loaded.
    _MONTAGE_8 = (
        ('Fp1', 'T3'), ('T3', 'O1'),   # LL chain
        ('Fp1', 'C3'), ('C3', 'O1'),   # LP chain
        ('Fp2', 'C4'), ('C4', 'O2'),   # RP chain
        ('Fp2', 'T4'), ('T4', 'O2'),   # RR chain
    )

    # Electrode pairs used when all raw features are loaded (CUSTOM_FEATURES empty).
    _MONTAGE_16 = (
        ('Fp1', 'F7'), ('F7', 'T3'), ('T3', 'T5'), ('T5', 'O1'),   # LL chain
        # NOTE(review): the last three LP pairs do not form a chain the way the other
        # rows do (Fp1-F3, F3-C3, C3-P3, P3-O1 would match the pattern). They are kept
        # unchanged because inference features must match whatever the checkpoints
        # were trained on — confirm against the training notebook before changing.
        ('Fp1', 'F3'), ('C4', 'C3'), ('Fp2', 'P3'), ('T4', 'O1'),   # LP chain
        ('Fp2', 'F4'), ('F4', 'C4'), ('C4', 'P4'), ('P4', 'O2'),   # RP chain
        ('Fp2', 'F8'), ('F8', 'T4'), ('T4', 'T6'), ('T6', 'O2'),   # RR chain
    )

    def __init__(self, data, eegs=None, augmentations=None, test=False):
        self.data = data
        self.eegs = eegs
        self.augmentations = augmentations
        self.test = test  # Flag to indicate test mode

    def __len__(self):
        """Return the number of samples (rows of the metadata frame)."""
        return len(self.data)

    @staticmethod
    def _bipolar_montage(data):
        """Return the electrode-difference channels selected by ``Config.CUSTOM_FEATURES``."""
        if len(Config.CUSTOM_FEATURES) == 8:
            pairs, dtype = EEGDataset._MONTAGE_8, 'float32'
        elif not Config.CUSTOM_FEATURES:
            # Side effect preserved from the original implementation.
            # NOTE(review): this runs inside __getitem__, i.e. possibly in a DataLoader
            # worker and after models were built — confirm it is still needed.
            Config.NUM_CHANNELS = 16
            # NOTE(review): float16 is kept to stay consistent with the original
            # pipeline; it loses precision compared to the 8-channel branch.
            pairs, dtype = EEGDataset._MONTAGE_16, 'float16'
        else:
            # Previously this case fell through and crashed with UnboundLocalError.
            raise ValueError(
                "Feature engineering supports either exactly 8 custom features or "
                f"none (all features); got {len(Config.CUSTOM_FEATURES)}"
            )

        minuend_idx = [feature_to_index[first] for first, _ in pairs]
        subtrahend_idx = [feature_to_index[second] for _, second in pairs]
        return (data[:, minuend_idx] - data[:, subtrahend_idx]).astype(dtype)

    def __getitem__(self, index):
        """Return the preprocessed signal for row ``index`` (and its label unless in test mode).

        The signal is returned as a float tensor laid out as (channels, time_steps).
        """
        row = self.data.iloc[index]
        data = self.eegs[row.eeg_id]  # raw EEG array for this id

        # === Feature engineering ===
        if Config.FEATURE_ENGINEERING:
            X = self._bipolar_montage(data)
        else:
            X = data

        # === Standardize ===
        X = np.clip(X, -1024, 1024)
        X = np.nan_to_num(X, nan=0) / 32.0

        # === Preprocess ===
        X = butter_lowpass_filter(X)  # low-pass filter along the time axis
        X = quantize_data(X, 1)       # mu-law companding with mu = 1
        samples = torch.from_numpy(X).float()
        samples = samples.permute(1, 0)  # (time_steps, channels) -> (channels, time_steps)

        if self.test:
            return samples               # signal only (for testing)

        label = row[TARGETS]             # target columns from the metadata row
        label = torch.tensor(label).float()
        return samples, label

# <b><span style='color:#196f3d'>|</span> DataLoader</b><a class='anchor' id='dataloader'></a> [↑](#top) 

***

In [None]:
def get_test_dls(df_test):
    """Build the test dataset and its DataLoader.

    Args:
        df_test: metadata frame handed to ``EEGDataset`` (one row per sample).

    Returns:
        tuple: ``(dataloader, dataset)``. The loader reads in order (no shuffling),
        using the batch size and worker count from ``Config``.
    """
    test_dataset = EEGDataset(df_test, eegs=all_eegs, augmentations=None, test=True)
    test_loader = DataLoader(
        test_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=Config.NUM_WORKERS,
    )
    return test_loader, test_dataset

# <b><span style='color:#196f3d'>|</span> Model</b><a class='anchor' id='model'></a> [↑](#top) 

***

In [None]:
class ResNet_1D_Block(nn.Module):
    """
    Residual block for 1D signals: (BatchNorm -> ReLU -> Dropout -> Conv1d) applied
    twice, followed by max-pooling and a skip connection.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size of both convolutions.
        stride (int): Stride of both convolutions.
        padding (int): Padding of both convolutions.
        downsampling: Truthy value makes the skip path go through a max-pool as well.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsampling):
        super().__init__()
        # Settings shared by the two convolutions of the block.
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        # First stage
        self.bn1 = nn.BatchNorm1d(num_features=in_channels)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=Config.P_DROPOUT, inplace=False)
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs)

        # Second stage
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs)

        # Pooling of the main path and (optionally) of the skip path
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.downsampling = downsampling
        self.downsample_layer = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        """Run both BN/ReLU/Dropout/Conv stages, max-pool, then add the skip path."""
        hidden = self.conv1(self.dropout(self.relu(self.bn1(x))))
        hidden = self.conv2(self.dropout(self.relu(self.bn2(hidden))))
        hidden = self.maxpool(hidden)

        # The skip path is pooled too when downsampling is enabled.
        shortcut = self.downsample_layer(x) if self.downsampling else x

        hidden += shortcut  # residual connection
        return hidden

    
class EEGNet(nn.Module):
    """
    EEGNet architecture: parallel 1D convolutions followed by ResNet-style blocks,
    combined with a bidirectional GRU branch and a final linear classifier.

    Args:
        kernels (list): Kernel sizes for the initial parallel convolutions.
        num_feature_maps (int, optional): Output channels of every parallel convolution
            and of all later convolutional layers. Defaults to Config.NUM_FEATURE_MAPS.
        in_channels (int, optional): Number of input channels (EEG signals).
            Defaults to Config.NUM_CHANNELS.
        fixed_kernel_size (int, optional): Kernel size of the shared convolution and of
            the ResNet blocks. Defaults to Config.FIXED_KERNEL_SIZE.
        num_classes (int, optional): Number of output classes. Defaults to 6.
        signal_length (int, optional): Number of time steps of the input signal; only
            used to size the final linear layer. Defaults to 10_000, which with the
            notebook's default settings gives 424 input features.
    """

    def __init__(self, kernels, num_feature_maps = Config.NUM_FEATURE_MAPS, in_channels=Config.NUM_CHANNELS, 
                       fixed_kernel_size=Config.FIXED_KERNEL_SIZE, num_classes=6, signal_length=10_000):
        super(EEGNet, self).__init__()
        self.kernels = kernels
        self.planes = num_feature_maps  # feature maps (=planes) produced by every convolution
        self.in_channels = in_channels
        gru_hidden_size = 128

        #----DEFINE LAYERS TO BE USED----
        self.parallel_conv = nn.ModuleList()  # Parallel convolutions with varying kernel sizes
        for kernel_size in kernels:
            conv = nn.Conv1d(in_channels, self.planes, kernel_size, stride=1, padding=0, bias=False)
            self.parallel_conv.append(conv)

        # Batch normalization and ReLU activation
        self.bn1 = nn.BatchNorm1d(self.planes)
        self.relu = nn.ReLU(inplace=False)

        # Shared convolution and ResNet blocks
        self.conv1 = nn.Conv1d(self.planes, self.planes, fixed_kernel_size, stride=2, padding=2, bias=False)
        self.block = self._make_resnet_layer(fixed_kernel_size, stride=1, padding=fixed_kernel_size//2)

        # Downsampling layers
        self.bn2 = nn.BatchNorm1d(self.planes)
        self.avgpool = nn.AvgPool1d(6, stride=6, padding=2)

        # Bidirectional GRU applied to the raw input signal
        self.rnn = nn.GRU(self.in_channels, hidden_size=gru_hidden_size, num_layers=1, bidirectional=True)

        # Final FC layer: flattened CNN features + both GRU directions
        cnn_features = self.planes * self._cnn_output_length(signal_length, fixed_kernel_size)
        self.fc = nn.Linear(cnn_features + 2 * gru_hidden_size, num_classes)

    def _make_resnet_layer(self, kernel_size, stride, blocks=Config.RESNET_BLOCKS, padding=0):
        """Creates a sequence of ResNet blocks, each pooling both main and skip path."""
        layers = []
        for _ in range(blocks):
            # The block only tests this flag for truthiness, so a plain bool is passed
            # (previously an unused MaxPool1d module was handed in as the flag).
            layers.append(ResNet_1D_Block(self.planes, self.planes, kernel_size, stride, padding, True))
        return nn.Sequential(*layers)

    def _cnn_output_length(self, signal_length, fixed_kernel_size):
        """Return the time length left after the convolutional path for a given input length."""
        def out_len(length, kernel, stride, padding):
            # Output-length formula shared by Conv1d / MaxPool1d / AvgPool1d (dilation 1).
            return (length + 2 * padding - kernel) // stride + 1

        # Parallel convolutions (stride 1, no padding) are concatenated along time.
        length = sum(out_len(signal_length, kernel, 1, 0) for kernel in self.kernels)
        length = out_len(length, fixed_kernel_size, 2, 2)              # self.conv1
        block_padding = fixed_kernel_size // 2
        for _ in self.block:
            length = out_len(length, fixed_kernel_size, 1, block_padding)  # block conv1
            length = out_len(length, fixed_kernel_size, 1, block_padding)  # block conv2
            length = out_len(length, 2, 2, 0)                              # block max-pool
        return out_len(length, 6, 6, 2)                                # self.avgpool

    def forward(self, x):
        """Forward pass. The outputs of the parallel convolutions are concatenated
           along the time axis, then go through the 1D conv and the ResNet blocks;
           the result is combined with the GRU branch and classified."""
        # Parallel convolutions, list containing output of different kernels applied to input
        out_sep = [conv(x) for conv in self.parallel_conv]

        # Concatenate outputs along dim 2 (time); the channel count stays self.planes
        out = torch.cat(out_sep, dim=2)

        # 1D Convolution
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv1(out)

        # ResNet blocks [have BN and ReLU built in]
        out = self.block(out)

        # Average pooling
        out = self.bn2(out)
        out = self.relu(out)
        out = self.avgpool(out)

        # Flatten for dense layers
        out = out.reshape(out.shape[0], -1)

        # GRU branch on the raw input.
        # NOTE(review): the GRU is built without batch_first=True but receives a tensor
        # permuted to (dim0, time, channels), so it treats dim 0 as the sequence axis.
        # Left unchanged because the checkpoints were presumably trained with this
        # behaviour; changing it would require retraining — confirm.
        rnn_out, _ = self.rnn(x.permute(0, 2, 1))
        new_rnn_h = rnn_out[:, -1, :]
        new_out = torch.cat([out, new_rnn_h], dim=1)  # Combine features
        result = self.fc(new_out)  # Final classification layer

        return result

# <b><span style='color:#196f3d'>|</span> Optimizer & Scheduler</b><a class='anchor' id='optimizer'></a> [↑](#top) 

***

In [None]:
def get_optimizer(lr, params):
    """Build the Adam optimizer and cosine warm-restart scheduler for Lightning.

    Args:
        lr (float): initial learning rate.
        params (iterable): model parameters; only those with ``requires_grad`` are optimized.

    Returns:
        dict: configuration in the format returned from ``configure_optimizers``,
        holding the optimizer and an epoch-wise scheduler entry.
    """
    trainable_params = (p for p in params if p.requires_grad)
    model_optimizer = torch.optim.Adam(
        trainable_params,
        lr=lr,
        # Config attributes are upper-case; the lower-case names used before
        # do not exist and raised AttributeError.
        weight_decay=Config.WEIGHT_DECAY,
    )
    interval = "epoch"

    # Referenced through torch.optim.lr_scheduler: the bare class name was never imported.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        model_optimizer,
        T_0=Config.EPOCHS,   # one cosine cycle spans the whole training run
        T_mult=1,
        eta_min=1e-5,
        last_epoch=-1,
    )

    return {
        "optimizer": model_optimizer,
        "lr_scheduler": {
            "scheduler": lr_scheduler,
            "interval": interval,
            "monitor": "val_loss",
            "frequency": 1
        }
    }

# <b><span style='color:#196f3d'>|</span> Loss Function</b><a class='anchor' id='loss'></a> [↑](#top) 

***

In [None]:
class KLDivLossWithLogits(nn.KLDivLoss):
    """KL-divergence loss ("batchmean" reduction) that accepts raw logits."""

    def __init__(self):
        super().__init__(reduction="batchmean")

    def forward(self, y, t):
        """Return KL(t || softmax(y)); ``y`` holds logits, ``t`` target probabilities."""
        # nn.KLDivLoss expects log-probabilities as its input argument.
        log_probs = y.log_softmax(dim=1)
        return super().forward(log_probs, t)


# <b><span style='color:#196f3d'>|</span> Pytorch Lightning Model</b><a class='anchor' id='lightning'></a> [↑](#top) 

***

In [None]:
class EEGModel(pl.LightningModule):
    """
    Pytorch Lightning module for EEG classification.

    Wraps an ``EEGNet`` backbone and a KL-divergence loss on logits. In this notebook
    the class is only used to restore checkpoints; the training and validation hooks
    below are not exercised by the visible inference code.
    """

    def __init__(self, num_classes = Config.NUM_CLASSES, pretrained = Config.PRETRAINED, fold = -1):
        """
        :param num_classes: stored on the instance; the backbone itself is sized from Config.NUM_CLASSES.
        :param pretrained: NOTE(review): accepted but not used anywhere in this class.
        :param fold: fold index, only used in the validation log messages.
        """
        super().__init__()
        self.num_classes = num_classes
        self.fold = fold

        # Create the EEGNet backbone with specified kernel sizes
        self.backbone = EEGNet(kernels=Config.KERNELS, in_channels=Config.NUM_CHANNELS, 
                               fixed_kernel_size=Config.FIXED_KERNEL_SIZE, num_classes=Config.NUM_CLASSES)

        # KL-divergence loss computed on the raw logits
        self.loss_function = KLDivLossWithLogits() 

        self.validation_step_outputs = []  # Storage for validation results
        # NOTE(review): self.lin is not referenced by any method of this class.
        self.lin = nn.Softmax(dim=1)       # Softmax for probability outputs 
        self.best_score = 1000.0           # Track best validation score

    def forward(self, images):
        """Forward pass through the EEGNet model; returns the raw logits."""
        logits = self.backbone(images)  # Extract features and get predictions
        return logits

    def configure_optimizers(self):
        """Set up the optimizer (using configuration parameters)."""
        return get_optimizer(lr=Config.LR, params=self.parameters())


    def train_with_mixup(self, X, y):
        """Compute the mixup training loss for one batch.

        NOTE(review): ``mixup_data`` and ``mixup_criterion`` are not defined in the
        visible part of this file; this path only runs when Config.USE_MIXUP is True.
        """
        X, y_a, y_b, lam = mixup_data(X, y, alpha=Config.MIXUP_ALPHA)
        y_pred = self(X)
        loss_mixup = mixup_criterion(KLDivLossWithLogits(), y_pred, y_a, y_b, lam)
        return loss_mixup

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch of (signal, target)."""
        image, target = batch        
        if Config.USE_MIXUP:
            loss = self.train_with_mixup(image, target)
        else:
            y_pred = self(image)
            loss = self.loss_function(y_pred,target)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss        

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; keep logits and targets for the epoch-end score."""
        image, target = batch 
        y_pred = self(image)
        val_loss = self.loss_function(y_pred, target)
        self.log("val_loss", val_loss, on_step=True, on_epoch=True, logger=True, prog_bar=True, sync_dist=True)
        self.validation_step_outputs.append({"val_loss": val_loss, "logits": y_pred, "targets": target})

        return {"val_loss": val_loss, "logits": y_pred, "targets": target}

    def train_dataloader(self):
        # NOTE(review): self._train_dataloader is never assigned in this class.
        return self._train_dataloader 
    
    def validation_dataloader(self):
        # NOTE(review): self._validation_dataloader is never assigned in this class, and
        # Lightning's validation hook is presumably `val_dataloader` — confirm this is called.
        return self._validation_dataloader
    
    def on_validation_epoch_end(self):
        """Aggregate the stored validation outputs and compute the KL-divergence score.

        Prints loss and score whenever the score improves on the best seen so far.
        """
        outputs = self.validation_step_outputs
        # print(len(outputs))
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        # Softmax over the concatenated logits -> predicted probabilities as NumPy
        output_val = nn.Softmax(dim=1)(torch.cat([x['logits'] for x in outputs],dim=0)).cpu().detach().numpy()
        target_val = torch.cat([x['targets'] for x in outputs],dim=0).cpu().detach().numpy()
        # Reset the per-epoch storage
        self.validation_step_outputs = []

        val_df = pd.DataFrame(target_val, columns = list(TARGETS))
        pred_df = pd.DataFrame(output_val, columns = list(TARGETS))

        # Both frames get the same synthetic row ids, which `score` uses as the row id column
        val_df['id'] = [f'id_{i}' for i in range(len(val_df))] 
        pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))] 


        avg_score = score(val_df, pred_df, row_id_column_name = 'id')

        # Lower is better: report only when the score improves
        if avg_score < self.best_score:
            print(f'Fold {self.fold}: Epoch {self.current_epoch} validation loss {avg_loss}')
            print(f'Fold {self.fold}: Epoch {self.current_epoch} validation KDL score {avg_score}')
            self.best_score = avg_score
            # val_df.to_csv(f'{paths.OUTPUT_DIR}/val_df_f{self.fold}.csv',index=False)
            # pred_df.to_csv(f'{paths.OUTPUT_DIR}/pred_df_f{self.fold}.csv',index=False)
        
        # NOTE(review): the 'val_cmap' key holds the KL-divergence score computed above.
        return {'val_loss': avg_loss,'val_cmap':avg_score}

# <b><span style='color:#196f3d'>|</span> Prediction Function </b><a class='anchor' id='prediction'></a> [↑](#top) 

***
We will use a separate prediction function to generate predictions from the trained model, outside the PyTorch Lightning environment.

In [None]:
def predict(data_loader, model, device=None):
    """Run ``model`` over ``data_loader`` and return the softmax probabilities.

    Args:
        data_loader: iterable yielding input tensors (one batch per item).
        model: module returning logits with the classes on dim 1.
        device: device to run on. Defaults to CUDA when available, otherwise CPU
            (previously CUDA was hard-coded).

    Returns:
        np.ndarray: probabilities for all batches, stacked along axis 0.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.to(device)
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in tqdm(data_loader):
            logits = model(batch.to(device))
            probabilities = torch.softmax(logits, dim=1)
            predictions.append(probabilities.cpu().numpy())
    return np.concatenate(predictions, axis=0)

# <b><span style='color:#196f3d'>|</span> Inference</b><a class='anchor' id='inference'></a> [↑](#top) 

***

In [None]:
def run_inference(fold_id, Config):
    """Load the checkpoint of one fold and predict on the whole test set.

    Args:
        fold_id (int): fold whose checkpoint ``eegnet_best_loss_fold{fold_id}.ckpt``
            is loaded from ``paths.WEIGHTS``.
        Config: configuration class, forwarded to ``load_from_checkpoint``.

    Returns:
        np.ndarray: the array returned by ``predict`` for the test set.
    """
    checkpoint_path = f'{paths.WEIGHTS}/eegnet_best_loss_fold{fold_id}.ckpt'
    dl_test, _ = get_test_dls(test.copy())

    print(f"Running inference model '{checkpoint_path}'..")

    # Fall back to CPU so the checkpoint can still be restored without a GPU.
    map_location = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = EEGModel.load_from_checkpoint(checkpoint_path, map_location=map_location,
                                          train_dataloader=None, validation_dataloader=None, config=Config)

    preds = predict(dl_test, model)

    # Drop the model before collecting, so its memory can be reclaimed before the next fold.
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return preds

# <b><span style='color:#196f3d'>|</span> Submission</b><a class='anchor' id='submission'></a> [↑](#top) 

***

In [None]:
# Start the submission frame from the eeg ids and create the target columns.
submission_df = test[['eeg_id']].copy()
submission_df[TARGETS] = 0.0

# Collect the predictions of every fold's checkpoint.
test_preds = []
for en,f in enumerate(range(Config.FOLDS)):
    preds = run_inference(f, Config)
    test_preds.append(preds)
    
# Ensemble by averaging the fold predictions element-wise (axis 0 = folds).
test_preds = np.mean(test_preds, 0)
submission_df[TARGETS] = test_preds
# Written relative to the current working directory.
submission_df.to_csv('submission.csv',index=False)