# Introduction
## Acknowledgements
The original base of this notebook was copied from @andreasbis. We thank them for supplying a useful baseline to expand upon. Please take a look at their work: https://www.kaggle.com/code/andreasbis/hms-train-efficientnetb1.

# Imports

In [None]:
import gc
import os
import random
import warnings
from IPython.display import display

import numpy as np
import pandas as pd

import timm
import torch
import torch.nn as nn  
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from scipy import signal

# Suppress every warning: `Warning` is the base class of all warning categories,
# so this filter hides deprecation and runtime warnings as well.
warnings.filterwarnings('ignore', category=Warning)
# Run a garbage-collection pass before the heavy cells below.
gc.collect()

# Setup

In [None]:
# Class names; the vote columns in the training table are named '<label>_vote'.
labels = ['seizure', 'lpd', 'gpd', 'lrda', 'grda', 'other']

class Config:
    """Settings shared by the preprocessing and training cells of this notebook."""
    # Seed handed to set_seed() for the Python, NumPy and PyTorch RNGs
    seed = 3131 
    # Resize applied to every spectrogram tensor before it is stacked into a batch
    image_transform = transforms.Resize((512,512))  
    # Number of samples per training / evaluation batch
    batch_size = 16
    # Number of epochs trained in each fold
    num_epochs = 9
    # Number of cross-validation folds
    num_folds = 5

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators for reproducible runs.

    Args:
        seed: Integer seed applied to all three generators.
    """
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner (benchmark=True) may select different convolution
    # algorithms from run to run, which defeats the deterministic flag above,
    # so it is switched off here.
    torch.backends.cudnn.benchmark = False

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def kl_loss(p, q):
    """Return the batch-mean KL divergence between targets `p` and logits `q`.

    `p` holds target probabilities (clamped away from 0 and 1 before the log
    is taken); `q` holds raw model outputs, which are converted to
    log-probabilities with a log-softmax over dimension 1.
    """
    tiny = 10 ** (-15)

    # Clamp so that log(p) stays finite for labels that are exactly 0.
    target = torch.clamp(p, tiny, 1 - tiny)
    log_pred = nn.functional.log_softmax(q, dim=1)

    # Sum p * (log p - log q) over the class dimension, then average the batch.
    per_sample = (target * (torch.log(target) - log_pred)).sum(dim=1)
    return per_sample.mean()

# Seed all random generators once, before any data shuffling or model creation.
set_seed(Config.seed)
gc.collect()

# Data Loading

In [None]:
# Build one training row per eeg_id: the expert votes of every row that shares
# an eeg_id are summed and then normalized into a probability distribution.

train_df = pd.read_csv("/kaggle/input/hms-harmful-brain-activity-classification/train.csv")

# Sum the votes of all labels per eeg_id in a single groupby
vote_cols = [f'{label}_vote' for label in labels]
vote_sums = train_df.groupby('eeg_id')[vote_cols].sum()

# Divide each label's vote sum by the total number of votes of that eeg_id.
# reset_index() turns eeg_id back into a regular column and returns a fresh
# frame, so adding the 'path' column below does not write into a slice.
train_features = vote_sums.div(vote_sums.sum(axis=1), axis=0).reset_index()

# Add a column with the path to the raw EEG parquet file of each eeg_id
train_features['path'] = train_features['eeg_id'].apply(lambda x: "/kaggle/input/hms-harmful-brain-activity-classification/train_eegs/" + str(x) + ".parquet")

# Reclaim memory no longer in use.
gc.collect()

# Creating spectrograms
Spectrograms are created using the 'Magic Formula to Convert EEG to Spectrograms' (https://www.kaggle.com/competitions/hms-harmful-brain-activity-classification/discussion/469760).

In [None]:
# Electrode chains of the 'magic formula'. Each region spectrogram is the mean
# of the spectrograms of the differences between consecutive electrodes:
#   LL Spec = ( spec(Fp1 - F7) + spec(F7 - T3) + spec(T3 - T5) + spec(T5 - O1) )/4
#   LP Spec = ( spec(Fp1 - F3) + spec(F3 - C3) + spec(C3 - P3) + spec(P3 - O1) )/4
#   RP Spec = ( spec(Fp2 - F4) + spec(F4 - C4) + spec(C4 - P4) + spec(P4 - O2) )/4
#   RL Spec = ( spec(Fp2 - F8) + spec(F8 - T4) + spec(T4 - T6) + spec(T6 - O2) )/4
_MONTAGE_CHAINS = (
    ('Fp1', 'F7', 'T3', 'T5', 'O1'),  # LL
    ('Fp1', 'F3', 'C3', 'P3', 'O1'),  # LP
    ('Fp2', 'F4', 'C4', 'P4', 'O2'),  # RP
    ('Fp2', 'F8', 'T4', 'T6', 'O2'),  # RL
)


def create_spectrogram(data):
    """Create the stacked LL/LP/RP/RL spectrogram of one EEG recording.

    Args:
        data: Column-indexable EEG table (e.g. a DataFrame) that provides the
            electrode columns named in `_MONTAGE_CHAINS`.

    Returns:
        A 2-D array with the four region spectrograms concatenated along
        axis 0 in the order LL, LP, RP, RL.
    """
    nperseg = 150  # Length of each segment
    noverlap = 128  # Overlap between segments
    nfft = max(256, 2 ** int(np.ceil(np.log2(nperseg))))

    region_spectra = []
    for chain in _MONTAGE_CHAINS:
        total = None
        # Walk the chain pairwise: (Fp1, F7), (F7, T3), ...
        for first, second in zip(chain, chain[1:]):
            _, _, spectrum = signal.spectrogram(
                data[first] - data[second],
                nfft=nfft,
                noverlap=noverlap,
                nperseg=nperseg,
            )
            total = spectrum if total is None else total + spectrum
        region_spectra.append(total / (len(chain) - 1))

    return np.concatenate(region_spectra, axis=0)

# Data Preprocessing

In [None]:
def datasetwide_mean(paths):
    """Calculate the mean log-spectrogram value over the entire dataset.

    Args:
        paths: Iterable of rows whose first element is the path of an EEG
            parquet file (e.g. `train_features[['path']].values`).

    Returns:
        The mean over every spectrogram element of every file.

    Raises:
        ValueError: If `paths` yields no spectrogram values.
    """
    total_sum = 0.0
    total_values = 0
    # Iterate over each path in the provided paths
    for path in paths:
        # Read data from parquet file
        data = pd.read_parquet(path[0])
        data = create_spectrogram(data)

        # Fill missing values with a constant; the clip below raises it to exp(-6)
        mask = np.isnan(data)
        data[mask] = -1

        # Clip values and apply logarithmic transformation
        data = np.clip(data, np.exp(-6), np.exp(10))
        data = np.log(data)

        # Accumulate the sum and the number of values across all files, so the
        # result is the mean of the whole dataset rather than of the last file.
        total_sum += data.sum()
        total_values += data.size

    if total_values == 0:
        raise ValueError("cannot compute the dataset mean: no spectrogram values were read")

    return total_sum / total_values

def datasetwide_std(paths, mean):
    """Return the sample standard deviation of all log-spectrogram values.

    `paths` is iterated row by row and the first element of each row is read
    as a parquet file; `mean` is the dataset-wide mean the deviations are
    taken from. The squared deviations are divided by (count - 1).
    """
    lower_bound, upper_bound = np.exp(-6), np.exp(10)
    squared_deviation = 0
    value_count = 0

    for entry in paths:
        spec = create_spectrogram(pd.read_parquet(entry[0]))

        # NaNs become -1, which the clip then lifts to the lower bound.
        spec[np.isnan(spec)] = -1
        spec = np.log(np.clip(spec, lower_bound, upper_bound))

        squared_deviation += np.sum((spec - mean) ** 2)
        n_rows, n_cols = spec.shape
        value_count += n_rows * n_cols

    return np.sqrt(squared_deviation / (value_count - 1))

# Calculate the mean and std of dataset.
# The std needs the mean computed on the line before it, so `data_mean` is
# passed here (the name `mean` is not defined anywhere in the notebook).
data_mean = datasetwide_mean(train_features[['path']].values)
data_std = datasetwide_std(train_features[['path']].values, data_mean)

In [None]:
def get_batch_datasetwidenorm(paths,data_mean,data_std, batch_size=Config.batch_size):
    """Load the EEG files in `paths` and return them as one normalized batch tensor.

    Each row of `paths` supplies a parquet path as its first element. The file
    is turned into a spectrogram, NaN-filled, clipped, log-transformed,
    standardized with `data_mean` / `data_std`, and resized with
    `Config.image_transform`. `batch_size` is accepted but not used here; the
    batch size is simply the number of rows in `paths`.
    """
    eps = 1e-6  # guards against a zero standard deviation
    lower_bound, upper_bound = np.exp(-6), np.exp(10)

    images = []
    for entry in paths:
        spec = create_spectrogram(pd.read_parquet(entry[0]))

        # NaNs become -1, which the clip then lifts to the lower bound.
        spec[np.isnan(spec)] = -1
        spec = np.log(np.clip(spec, lower_bound, upper_bound))

        # Standardize with the dataset-wide statistics
        spec = (spec - data_mean) / (data_std + eps)

        # Prepend a dimension of size 1 before resizing
        image = torch.Tensor(spec).unsqueeze(0)
        images.append(Config.image_transform(image))

    return torch.stack(images)

# Model Training

In [None]:
def train(lr = 0.001 ,data_mean=0,data_std=1):
    """Train one EfficientNet-B1 per cross-validation fold and record the losses.

    Relies on the module-level `total_idx` (shuffled row positions),
    `train_features`, `device` and `Config`. For every fold the weights with
    the lowest test loss are saved to "efficientnet_b1_fold{fold}.pth".

    Args:
        lr: Initial learning rate of the AdamW optimizer.
        data_mean: Dataset-wide mean used to normalize the spectrograms.
        data_std: Dataset-wide standard deviation used to normalize the spectrograms.

    Returns:
        A tuple `(train_losses_folds, test_losses_folds)`; each is a list with
        one entry per fold holding that fold's per-epoch mean loss.
    """
    vote_cols = ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
    n_samples = len(total_idx)

    # Cross-validation loop
    train_losses_folds = []
    test_losses_folds = []
    for fold in range(Config.num_folds):
        # Split data into train and test sets for this fold
        test_idx = total_idx[fold * n_samples // Config.num_folds:(fold + 1) * n_samples // Config.num_folds]

        # With a single fold the slice above would cover everything, so a
        # random contiguous 20% window is held out instead.
        if Config.num_folds==1:
            holdout_size = int(np.round(0.2*n_samples))
            start = np.random.choice(n_samples - holdout_size - 1)
            end = start + holdout_size
            test_idx = total_idx[start:end]

        # Vectorized, order-preserving set difference (the per-element
        # `idx not in test_idx` scan is quadratic in the dataset size).
        train_idx = total_idx[~np.isin(total_idx, test_idx)]

        # Initialize EfficientNet-B1 model with pretrained weights
        model = timm.create_model('efficientnet_b1', pretrained=True, num_classes=6, in_chans=1)
        model.to(device)

        optimizer = optim.AdamW(model.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.01)
        scheduler = CosineAnnealingLR(optimizer, T_max=Config.num_epochs)

        best_test_loss = float('inf')
        train_losses = []
        test_losses = []

        print(f"Starting training for fold {fold + 1}")

        # Training loop
        for epoch in range(Config.num_epochs):
            model.train()
            train_loss = []
            # Reshuffle the training samples at the start of every epoch
            random_num = np.arange(len(train_idx))
            np.random.shuffle(random_num)
            train_idx = train_idx[random_num]
            # Iterate over batches in the training set
            for idx in range(0, len(train_idx), Config.batch_size):
                optimizer.zero_grad()
                train_idx1 = train_idx[idx:idx + Config.batch_size]
                train_X1_path = train_features[['path']].iloc[train_idx1].values

                # Normalize data with given mean and std
                train_X1 = get_batch_datasetwidenorm(train_X1_path,data_mean,data_std, batch_size=Config.batch_size)

                train_y1 = train_features[vote_cols].iloc[train_idx1].values
                train_y1 = torch.Tensor(train_y1)
                train_pred = model(train_X1.to(device))
                loss = kl_loss(train_y1.to(device), train_pred)
                loss.backward()
                optimizer.step()
                train_loss.append(loss.item())

            epoch_train_loss = np.mean(train_loss)
            train_losses.append(epoch_train_loss)
            print(f"Epoch {epoch + 1}: Train Loss = {epoch_train_loss:.2f}")

            scheduler.step()

            # Evaluation loop
            model.eval()
            test_loss = []
            with torch.no_grad():
                for idx in range(0, len(test_idx), Config.batch_size):
                    test_idx1 = test_idx[idx:idx + Config.batch_size]
                    test_X1_path = train_features[['path']].iloc[test_idx1].values

                    # Normalize data with given mean and std.
                    test_X1= get_batch_datasetwidenorm(test_X1_path,data_mean,data_std, batch_size=Config.batch_size)

                    test_y1 = train_features[vote_cols].iloc[test_idx1].values
                    test_y1 = torch.Tensor(test_y1)

                    test_pred = model(test_X1.to(device))
                    loss = kl_loss(test_y1.to(device), test_pred)
                    test_loss.append(loss.item())

            epoch_test_loss = np.mean(test_loss)

            test_losses.append(epoch_test_loss)
            print(f"Epoch {epoch + 1}: Test Loss = {epoch_test_loss:.2f}")

            # Save the model if it has the best test loss so far
            if epoch_test_loss < best_test_loss:
                best_test_loss = epoch_test_loss
                torch.save(model.state_dict(), f"efficientnet_b1_fold{fold}.pth")

            gc.collect()

        print(f"Fold {fold + 1} Best Test Loss: {best_test_loss:.2f}")
        train_losses_folds.append(train_losses)
        test_losses_folds.append(test_losses)

    return train_losses_folds, test_losses_folds

In [None]:
def plot_loss_graph(train_losses_folds, test_losses_folds):
    """Save one train-vs-test loss figure per fold as 'fold_<index>.png'."""
    for fold_no, train_curve in enumerate(train_losses_folds):
        test_curve = test_losses_folds[fold_no]

        # Both curves of a fold share one figure.
        plt.plot(train_curve, label='Train Loss')
        plt.plot(test_curve, label='Test Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Write the figure to disk and close it so the next fold starts clean.
        plt.savefig(f'fold_{fold_no}.png')
        plt.close()

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Shuffled row positions of train_features; train() slices these into folds
total_idx = np.arange(len(train_features))
np.random.shuffle(total_idx)

# Train all folds with the dataset-wide normalization statistics, then save
# one loss figure per fold
train_losses_folds, test_losses_folds = train(lr = 0.001,data_mean=data_mean,data_std=data_std)
plot_loss_graph(train_losses_folds, test_losses_folds)
gc.collect()