In [None]:
import h5py
import wandb

import numpy as np 
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import MeanSquaredError, MeanAbsoluteError

from pytorch_lightning.loggers import WandbLogger

# Load dataset

In [2]:
wandb.login(key='0a4257fb4fa0625c3750f4f755c7ec9bdfd0c3c8')
y_train = pd.read_csv('/kaggle/input/eeg-misc-52-250hz/y_train.csv')
y_test = pd.read_csv('/kaggle/input/eeg-misc-52-250hz/y_test.csv')

train_data = []
test_data = []
train_label = []
test_label = []

with h5py.File('/kaggle/input/eeg-misc-52-250hz/data_250hz.h5', 'r') as f:
    min_len = 13500
    cur = 67750*2
    for (patient_id, age) in y_train.values:
        data = []
        for i in range(52):
            parcel = f[patient_id + '/parcel_' + str(i)][:]
            parcel = np.squeeze(parcel)
            cur = min(cur, len(parcel))
            data.append(parcel)
        data=np.array(data)
        if len(data[0]) == min_len:
            train_data.append(data)
            train_label.append(age)
        else:
            train_data.append(data[:, :min_len])
            train_data.append(data[:, -min_len:])
            train_label.append(age)
            train_label.append(age)
            
    for (patient_id, age) in y_test.values:
        data = []
        for i in range(52):
            parcel = f[patient_id + '/parcel_' + str(i)][:]
            parcel = np.squeeze(parcel)
            data.append(parcel)
            cur = min(cur, len(parcel))
        data=np.array(data)
        if len(data[0]) == min_len:
            test_data.append(data)
            test_label.append(age)
        else:
            test_data.append(data[:, :min_len])
            test_data.append(data[:, -min_len:])
            test_label.append(age)
            test_label.append(age)
print(cur)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m22022500[0m ([33m22022500-university-of-engineering-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


135499


In [3]:
train_data = np.array(train_data)
test_data = np.array(test_data)
train_label = np.array(train_label)
test_label = np.array(test_label)
print(train_data.shape, test_data.shape)

(240, 52, 13500) (80, 52, 13500)


In [4]:
segment_length = 250
num_segments = train_data.shape[2] // segment_length

train_data_reshaped = train_data.reshape(-1, train_data.shape[1], segment_length) 
test_data_reshaped = test_data.reshape(-1, test_data.shape[1], segment_length)
train_label_reshaped = np.repeat(train_label, num_segments)  
test_label_reshaped = np.repeat(test_label, num_segments)   
print(train_data_reshaped.shape)

(12960, 52, 250)


In [5]:
labels = pd.read_csv('/kaggle/input/eeg-misc-52-250hz/y_train.csv').values.squeeze()

age_labels = labels[:, 1].astype(int)

num_classes = len(np.unique(age_labels))
print(f"Number of classes: {num_classes}")


Number of classes: 62


# Dataloader definition

In [6]:
class EEGDatasetRegression(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
        y = torch.tensor(y, dtype=torch.float32)
        return x, y

class EEGDataModule(pl.LightningDataModule):
    def __init__(self, train_data, train_labels, test_data, test_labels, batch_size=8):
        super().__init__()
        self.train_data = train_data
        self.train_labels = train_labels
        self.test_data = test_data
        self.test_labels = test_labels
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Prepare datasets
        self.train_dataset = EEGDatasetRegression(self.train_data, self.train_labels)
        self.test_dataset = EEGDatasetRegression(self.test_data, self.test_labels)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)
batch_size = 128
eeg_dm = EEGDataModule(train_data_reshaped, train_label_reshaped, test_data_reshaped, test_label_reshaped, batch_size=batch_size)

# Model definition

In [7]:
class EEGNetAgeRegressor(pl.LightningModule):
    def __init__(self,
                 Chans=52,          
                 Samples=250,        
                 F1=8,
                 D=2,
                 F2=16,
                 kernelLength=64,
                 dropoutRate=0.5,
                 learning_rate=1e-3,
                 weight_decay=1e-4):
        super().__init__()
        
        self.save_hyperparameters()  
        self.conv1 = nn.Conv2d(
            in_channels=1, 
            out_channels=F1,
            kernel_size=(1, kernelLength),
            padding=(0, kernelLength // 2),
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(F1)

        self.depthwise_conv = nn.Conv2d(
            in_channels=F1,
            out_channels=F1*D,
            kernel_size=(Chans, 1),
            groups=F1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(F1*D)
      
        
        self.elu = nn.ELU()
        
        self.pool1 = nn.AvgPool2d(kernel_size=(1, 4))
        self.dropout1 = nn.Dropout(dropoutRate)
        
        self.sep_conv1 = nn.Conv2d(
            in_channels=F1*D,
            out_channels=F1*D,
            kernel_size=(1, 16), 
            padding=(0, 8),
            groups=F1*D,
            bias=False
        )
        self.bn3 = nn.BatchNorm2d(F1*D)
        
        self.sep_conv2 = nn.Conv2d(
            in_channels=F1*D,
            out_channels=F2,
            kernel_size=(1, 1),
            bias=False
        )
        self.bn4 = nn.BatchNorm2d(F2)
        
        self.pool2 = nn.AvgPool2d(kernel_size=(1, 8))
        self.dropout2 = nn.Dropout(dropoutRate)
        
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(F2, 1)  
        
        self.train_mse = MeanSquaredError()
        self.val_mse = MeanSquaredError()
        self.test_mse = MeanSquaredError()
        self.test_mae = MeanAbsoluteError()
        
    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        
        x = self.depthwise_conv(x)
        x = self.bn2(x)
        x = self.elu(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        
        x = self.sep_conv1(x)
        x = self.bn3(x)
        x = self.elu(x)
        
        x = self.sep_conv2(x)
        x = self.bn4(x)
        x = self.elu(x)
        
        x = self.pool2(x)
        x = self.dropout2(x)
        
        x = self.gap(x)
        
        x = x.view(x.size(0), -1)
        
        out = self.fc(x)
        return out.squeeze(-1)  
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )
        return optimizer
    
    def training_step(self, batch, batch_idx):
        x, y = batch  
        preds = self.forward(x)
        loss = F.mse_loss(preds, y)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        
        mse_val = self.train_mse(preds, y)
        self.log("train_mse", mse_val, on_epoch=True, prog_bar=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = F.mse_loss(preds, y)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        
        mse_val = self.val_mse(preds, y)
        self.log("val_mse", mse_val, on_epoch=True, prog_bar=True)
        
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self.forward(x)
        loss = F.mse_loss(preds, y)
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        
        mse_val = self.test_mse(preds, y)
        self.log("test_mse", mse_val, on_epoch=True, prog_bar=True)

        mae_val = self.test_mae(preds, y)
        self.log("test_mae", mae_val, on_epoch=True, prog_bar=True)
        
        return loss


# Training and testing

In [8]:
wandb_logger = WandbLogger(project="EEG Competition", log_model=True)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    filename="best-model-{epoch}-{val_loss:.2f}"
)
model = EEGNetAgeRegressor(
    Chans=52,
    Samples=250,
    F1=8,
    D=2,
    F2=16,
    kernelLength=64,
    dropoutRate=0.5,
    learning_rate=1e-3,
    weight_decay=1e-4
)

trainer = pl.Trainer(
    max_epochs=60, 
    accelerator="auto",  
    devices=2,
     logger=wandb_logger,
     callbacks=[checkpoint_callback]
)

trainer.fit(model, datamodule=eeg_dm)

trainer.test(model, datamodule=eeg_dm)


[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20241222_081247-92jrryjo[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33msplendid-wind-29[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/22022500-university-of-engineering-and-technology/EEG%20competition[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/22022500-university-of-engineering-and-technology/EEG%20competition/runs/92jrryjo[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_mse', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('train_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('train_mse', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20241222_085621-vy60jkn0[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mlunar-terrain-30[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/22022500-university-of-engineering-and-technology/EEG%20competition[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/22022500-university-of-engineering-and-technology/EEG%20competition/runs/vy60jkn0[0m
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:215: Using `DistributedSampler` with the dataloaders. During `trainer.test()`, it is recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets evaluated exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates some samples to make sure all devices have same batch size in case o

Testing: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('test_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('test_mse', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('test_mae', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


[{'test_loss': 553.847900390625,
  'test_mse': 553.847900390625,
  'test_mae': 19.745569229125977}]