In [2]:
import mne
from pathlib import Path

import numpy as np
import pandas as pd

from pathlib import Path

import numpy as np
from scipy import signal
import matplotlib.pyplot as plt

from scipy.stats import skew, kurtosis

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import MeanSquaredError, MeanAbsoluteError

from tqdm.notebook import tqdm

import wandb

from pytorch_lightning.loggers import WandbLogger

In [14]:
# Raise MNE's log threshold so that only CRITICAL-level messages are emitted
# by the MNE calls made later in the notebook.
mne.set_log_level('CRITICAL')

# Load dataset

In [16]:
# Dataset locations. ROOT holds the label CSVs read below; SOURCE is the
# sub-directory that is searched for the '*.fif' recordings.
# NOTE(review): absolute Kaggle input path — adjust when running elsewhere.
ROOT = Path('/kaggle/input/competition-dataset/Competition') # root directory of the competition dataset
SOURCE = ROOT / 'raw_files'

In [17]:
# Label tables for the two splits. The loading loop further down unpacks every
# row of these frames as an (id, age) pair, so two columns in that order are
# expected — TODO confirm against the CSV headers.
y_train = pd.read_csv(ROOT/ 'y_train.csv')
y_test = pd.read_csv(ROOT / 'y_test.csv')

In [18]:
# Number of '*.fif' recordings found directly inside SOURCE (a count, not a list).
all_files = sum(1 for _ in SOURCE.glob('*.fif'))

# Feature extraction

In [19]:
def PSD(input_signal, fs=250):
    """
    Estimate the power spectral density of a signal with Welch's method.

    Args:
    input_signal: Array of samples; the estimate is taken along the last axis.
    fs: Sampling frequency in Hz passed to `scipy.signal.welch`. Defaults to
        250, the value that used to be hard-coded here.
        NOTE(review): the loading loop resamples the recordings by a factor
        of 2 before segmenting — confirm that 250 is still the right rate.

    Returns:
    The PSD values (a Hamming window and SciPy's default segment length are
    used). The matching frequency axis is discarded.
    """
    _, Pxx_den = signal.welch(x=input_signal, fs=fs, window='hamming')
    return Pxx_den

In [20]:
def hjorth_parameters(eeg_signal):
    """
    Calculates Hjorth parameters (Activity, Mobility, Complexity) for an EEG signal.

    Args:
    eeg_signal: A 1D NumPy array representing the EEG signal.

    Returns:
    A NumPy float array of length 3: (activity, mobility, complexity).
    Degenerate inputs never produce NaN:
      * an empty or constant signal yields (0, 0, 0);
      * a signal whose first difference is constant (e.g. a linear ramp)
        yields (activity, 0, 0), because mobility is 0 and the complexity
        ratio would otherwise be 0/0.
    """
    eeg_signal = np.asarray(eeg_signal)
    if eeg_signal.size == 0:
        return np.zeros(3)

    # Activity is the variance of the signal itself.
    var_signal = np.var(eeg_signal)
    if var_signal == 0:
        return np.zeros(3)

    # Derivatives are approximated with first-order (forward) differences.
    first_derivative = np.diff(eeg_signal)
    var_first_derivative = np.var(first_derivative)

    activity = var_signal
    mobility = np.sqrt(var_first_derivative / var_signal)

    if var_first_derivative == 0:
        # Constant slope: the complexity formula would divide zero by zero.
        return np.array((activity, mobility, 0.0))

    second_derivative = np.diff(first_derivative)
    var_second_derivative = np.var(second_derivative)
    complexity = np.sqrt(var_second_derivative / var_first_derivative) / mobility

    return np.array((activity, mobility, complexity))



In [21]:
def extract_statistical_measures(eeg_signal):
    """
    Summarise an EEG signal with nine descriptive statistics.

    Args:
    eeg_signal: A 1D NumPy array representing the EEG signal.

    Returns:
    A NumPy array holding, in this order: mean, variance, standard deviation,
    root-mean-square, skewness, excess kurtosis, number of sign changes
    between consecutive samples, peak-to-peak range and inter-quartile range.
    NaN entries are replaced by 0 before returning.
    """
    upper_quartile = np.percentile(eeg_signal, 75)
    lower_quartile = np.percentile(eeg_signal, 25)

    # A negative product of neighbouring samples marks a sign change.
    sign_changes = (eeg_signal[:-1] * eeg_signal[1:]) < 0

    measures = [
        np.mean(eeg_signal),
        np.var(eeg_signal),
        np.std(eeg_signal),
        np.sqrt(np.mean(eeg_signal**2)),
        skew(eeg_signal),
        kurtosis(eeg_signal),  # Fisher definition: normal distribution -> 0
        sign_changes.sum(),
        np.max(eeg_signal) - np.min(eeg_signal),
        upper_quartile - lower_quartile,
    ]
    return np.nan_to_num(np.array(measures), nan=0)


In [22]:
def extract_features(input_signal):
    """
    Build the feature vector of one signal.

    The vector is the Welch PSD returned by `PSD` followed by the three Hjorth
    parameters returned by `hjorth_parameters`. The statistical measures from
    `extract_statistical_measures` are not part of the vector.
    """
    feature_parts = [PSD(input_signal), hjorth_parameters(input_signal)]
    return np.concatenate(feature_parts)

In [23]:
def remove_zero_product_elements(arr, label=None):
    """
    Drop every entry along the first axis that contains at least one exact zero.

    The previous implementation multiplied all values of an entry together and
    kept it when the product was non-zero. With thousands of small (or large)
    floating-point samples that product underflows to 0 (or overflows), so
    entries without any zero were discarded as well. The test is now done
    element-wise, which is exact.

    Args:
    arr: Array whose first axis enumerates the entries (e.g. segments); all
        remaining axes are inspected for zeros.
    label: Optional array with one label per entry of `arr`.

    Returns:
    (filtered_arr, filtered_label): the kept entries and their labels, or
    None in place of the labels when `label` is None.
    """
    arr = np.asarray(arr)

    # True for entries in which no value is exactly zero.
    inner_axes = tuple(range(1, arr.ndim))
    mask = np.all(arr != 0, axis=inner_axes)

    filtered_arr = arr[mask]
    filtered_label = None if label is None else np.asarray(label)[mask]

    return filtered_arr, filtered_label

In [None]:
include = range(0, 52)  # channel indices handed to MNE as `picks`

# Number of samples per segment, counted on the resampled signal.
# NOTE(review): the original comment here read "125 Hz"; the sampling rate of
# the recordings is not visible in this notebook — confirm the intended
# segment duration.
segment_length = 5 * 250
# num_segments = train_data.shape[2] // segment_length


def load_segments(subjects, use_features=False):
    """
    Read, filter, resample and segment the recording of every subject.

    Args:
    subjects: DataFrame whose rows unpack as (id, age); the recording is read
        from SOURCE / f'{id}_sflip_parc-raw.fif'.
    use_features: When True, every channel of every segment is replaced by
        its `extract_features` vector. The default keeps the raw samples,
        which is what the rest of the notebook consumes.

    Returns:
    (segments, labels): two lists with one array per subject. Each segments
    array is laid out as (segment, channel, sample); each labels array repeats
    the subject's age once per kept segment.
    """
    segments, labels = [], []

    for subject_id, age in tqdm(subjects.values):
        file = f'{subject_id}_sflip_parc-raw.fif'

        raw = mne.io.read_raw_fif(SOURCE / file, preload=True)
        raw.filter(l_freq=0.5, h_freq=100.0, picks=include)

        data = np.array(raw.get_data(picks=include))
        data = mne.filter.resample(data, down=2)
        num_segments = data.shape[1] // segment_length
        data = data[:, :num_segments * segment_length]

        # `data` is (channel, time). Split the time axis of each channel into
        # segments first, then move the segment axis to the front. Reshaping
        # straight to (-1, channel, segment_length) would mix samples of
        # different channels and times into the same row.
        data_reshaped = data.reshape(
            data.shape[0], num_segments, segment_length
        ).transpose(1, 0, 2)

        subject_labels = np.full(num_segments, age)
        data_reshaped, subject_labels = remove_zero_product_elements(
            data_reshaped, subject_labels
        )

        if use_features:
            # Expensive: runs the feature extractor once per channel per segment.
            data_reshaped = np.apply_along_axis(
                extract_features, axis=2, arr=data_reshaped
            )

        segments.append(data_reshaped)
        labels.append(subject_labels)

    return segments, labels


train_data, train_label = load_segments(y_train)
test_data, test_label = load_segments(y_test)

  0%|          | 0/120 [00:00<?, ?it/s]

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


In [None]:
# Stack the per-subject arrays along the first axis so that every split
# becomes one data array and one label array.
(
    train_data_reshaped,
    train_label_reshaped,
    test_data_reshaped,
    test_label_reshaped,
) = (
    np.concatenate(per_subject)
    for per_subject in (train_data, train_label, test_data, test_label)
)

# Dataloader definition

In [None]:
class EEGDatasetRegression(Dataset):
    """
    Map-style dataset that pairs each sample of `data` with its label.

    Indexing returns `(x, y)` as float32 tensors; `x` gets an extra leading
    axis of size 1 in front of the sample's own axes.
    """

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        # The number of labels defines the dataset size.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = torch.tensor(self.data[idx], dtype=torch.float32)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return sample.unsqueeze(0), target

class EEGDataModule(pl.LightningDataModule):
    """
    Lightning data module wrapping the train and test arrays.

    NOTE(review): `val_dataloader` and `test_dataloader` both serve the test
    dataset, so any model selection driven by validation metrics is performed
    on the test data — confirm this is intended.
    """

    def __init__(self, train_data, train_labels, test_data, test_labels, batch_size=8):
        super().__init__()
        self.train_data, self.train_labels = train_data, train_labels
        self.test_data, self.test_labels = test_data, test_labels
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Create the dataset objects; `stage` is ignored."""
        self.train_dataset = EEGDatasetRegression(self.train_data, self.train_labels)
        self.test_dataset = EEGDatasetRegression(self.test_data, self.test_labels)

    def _make_loader(self, dataset, shuffle=False):
        # Single place where the batch size is applied to every loader.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.test_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)


In [None]:
# Build the data module from the concatenated arrays. The batch size given
# here overrides the data module's default of 8.
batch_size = 128
eeg_dm = EEGDataModule(train_data_reshaped, train_label_reshaped, test_data_reshaped, test_label_reshaped, batch_size=batch_size)

# Model definition

In [None]:
class EEGNetAgeRegressor(pl.LightningModule):
    """
    EEGNet-style convolutional network that regresses one scalar per sample.

    The first convolution takes a single input channel and the depthwise
    convolution spans `Chans` rows, so the expected input layout is
    (batch, 1, Chans, time). `Samples` is stored with the hyperparameters but
    is not used to size any layer, because the time axis is collapsed by a
    global average pool.

    Training, validation and test steps all use the mean squared error as the
    loss; the test step additionally logs the mean absolute error.
    """

    def __init__(self,
                 Chans=52,
                 Samples=250,
                 F1=8,
                 D=2,
                 F2=16,
                 kernelLength=64,
                 dropoutRate=0.5,
                 learning_rate=1e-3,
                 weight_decay=1e-4):
        super().__init__()
        self.save_hyperparameters()

        spatial_filters = F1 * D  # channel count after the depthwise stage

        # Temporal convolution along the time axis.
        self.conv1 = nn.Conv2d(1, F1, kernel_size=(1, kernelLength),
                               padding=(0, kernelLength // 2), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)

        # Depthwise convolution across all `Chans` rows, D filters per temporal filter.
        self.depthwise_conv = nn.Conv2d(F1, spatial_filters, kernel_size=(Chans, 1),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(spatial_filters)

        self.elu = nn.ELU()

        self.pool1 = nn.AvgPool2d(kernel_size=(1, 4))
        self.dropout1 = nn.Dropout(dropoutRate)

        # Separable convolution: per-channel temporal filter followed by a 1x1 mix.
        self.sep_conv1 = nn.Conv2d(spatial_filters, spatial_filters, kernel_size=(1, 16),
                                   padding=(0, 8), groups=spatial_filters, bias=False)
        self.bn3 = nn.BatchNorm2d(spatial_filters)

        self.sep_conv2 = nn.Conv2d(spatial_filters, F2, kernel_size=(1, 1), bias=False)
        self.bn4 = nn.BatchNorm2d(F2)

        self.pool2 = nn.AvgPool2d(kernel_size=(1, 8))
        self.dropout2 = nn.Dropout(dropoutRate)

        # Regression head: global average pool, then one linear unit.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(F2, 1)

        self.train_mse = MeanSquaredError()
        self.val_mse = MeanSquaredError()
        self.test_mse = MeanSquaredError()
        self.test_mae = MeanAbsoluteError()

    def forward(self, x):
        """Return one prediction per sample, with the trailing unit axis removed."""
        x = self.bn1(self.conv1(x))

        x = self.elu(self.bn2(self.depthwise_conv(x)))
        x = self.dropout1(self.pool1(x))

        x = self.elu(self.bn3(self.sep_conv1(x)))
        x = self.elu(self.bn4(self.sep_conv2(x)))
        x = self.dropout2(self.pool2(x))

        pooled = self.gap(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat).squeeze(-1)

    def configure_optimizers(self):
        """Adam with the learning rate and weight decay stored in the hyperparameters."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )

    def _run_step(self, batch, loss_key, metrics):
        """
        Shared body of the three step hooks.

        Computes the MSE loss of the batch, logs it under `loss_key`, then
        updates and logs each `(key, metric)` pair of `metrics` in order.
        """
        inputs, targets = batch
        preds = self.forward(inputs)
        loss = F.mse_loss(preds, targets)
        self.log(loss_key, loss, on_epoch=True, prog_bar=True)

        for key, metric in metrics:
            value = metric(preds, targets)
            self.log(key, value, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._run_step(batch, "train_loss", (("train_mse", self.train_mse),))

    def validation_step(self, batch, batch_idx):
        return self._run_step(batch, "val_loss", (("val_mse", self.val_mse),))

    def test_step(self, batch, batch_idx):
        return self._run_step(
            batch,
            "test_loss",
            (("test_mse", self.test_mse), ("test_mae", self.test_mae)),
        )


# Training and Testing

In [None]:
# Authenticate with Weights & Biases without embedding the API key in the
# notebook. With no `key` argument, wandb falls back to previously configured
# credentials (for example the WANDB_API_KEY environment variable) and prompts
# when none are found.
# SECURITY: the key that used to be hard-coded on this line has been exposed
# in the notebook and should be revoked and rotated.
wandb.login()

In [None]:
# Experiment tracking; `log_model=True` is passed so the logger also handles
# model checkpoints.
wandb_logger = WandbLogger(project="EEG Competition", log_model=True)

# Keep only the checkpoint with the lowest validation loss.
# NOTE(review): the data module serves the test split as validation data, so
# this selection is made on the test set — confirm this is intended.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filename="best-model-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

# `Samples` follows the segment length of the loaded data; the remaining
# arguments repeat the model's defaults.
model = EEGNetAgeRegressor(
    Chans=52, Samples=train_data_reshaped.shape[2],
    F1=8, D=2, F2=16, kernelLength=64,
    dropoutRate=0.5,
    learning_rate=1e-3, weight_decay=1e-4,
)

trainer = pl.Trainer(
    max_epochs=30,
    accelerator="auto",
    devices=1,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

trainer.fit(model, datamodule=eeg_dm)

# NOTE(review): this evaluates the in-memory model as it is after the last
# epoch, not the checkpoint selected by `checkpoint_callback`.
trainer.test(model, datamodule=eeg_dm)
