# This Python 3 environment comes with many helpful analytics libraries installed.
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load:

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# For example, running this cell (by clicking Run or pressing Shift+Enter) prints every directory under the input directory.

import os
# Print each directory under the Kaggle input mount (file names are not printed).
for walk_entry in os.walk('/kaggle/input'):
    print(walk_entry[0])

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install mne
!pip install optuna pytorch_lightning


In [None]:
import mne
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
file_paths = [f"/kaggle/input/eeg-signal/BCICIV_2a_gdf/A0{i}T.gdf" for i in range(1, 10)]

all_data = []
all_labels = []

# Legacy integer event ids. mne.events_from_annotations assigns these ids per
# file, so they can differ between files. They are kept for reference only and
# are NOT used for epoching below.
event_ids = {
    'left_hand': 7,
    'right_hand': 8,
    'feet': 9,
    'tongue': 10
}

# Cue annotation description for each motor-imagery class (BCI Competition IV
# 2a event types 769-772). Dict order defines the label: left_hand -> 0,
# right_hand -> 1, feet -> 2, tongue -> 3.
mi_annotations = {
    'left_hand': '769',
    'right_hand': '770',
    'feet': '771',
    'tongue': '772',
}
class_index = {name: label for label, name in enumerate(mi_annotations)}

for file_path in file_paths:
    raw = mne.io.read_raw_gdf(file_path, preload=True)
    raw.notch_filter(freqs=50)  # Notch filter at 50 Hz
    raw.filter(8., 30., fir_design='firwin')  # Band-pass to keep 8-30 Hz
    raw.set_eeg_reference('average')  # Re-reference to the channel average
    # NOTE(review): no `picks` are given anywhere, so every channel the GDF
    # reader returns is kept; confirm whether EOG channels should be excluded.

    events, event_dict = mne.events_from_annotations(raw)
    print(f"Available events in {file_path}: {event_dict}")

    # Look up the integer id that *this* file assigned to each cue description,
    # instead of assuming the ids are identical across files.
    available_event_ids = {
        name: event_dict[description]
        for name, description in mi_annotations.items()
        if description in event_dict
    }

    if not available_event_ids:
        print(f"No motor imagery tasks found in {file_path}. Skipping this file.")
        continue

    epochs = mne.Epochs(raw, events, event_id=available_event_ids, tmin=0, tmax=4, baseline=None, preload=True)

    X = epochs.get_data()
    # Per-file z-scoring across trials, separately for every channel/time point.
    # NOTE(review): these statistics include trials that later end up in the
    # validation split.
    X_normalized = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

    # Translate this file's event ids into fixed class labels 0..3, so a missing
    # class cannot shift the labels of the remaining ones.
    code_to_label = {code: class_index[name] for name, code in available_event_ids.items()}
    y = np.array([code_to_label[code] for code in epochs.events[:, -1]], dtype=np.int64)

    all_data.append(X_normalized)
    all_labels.append(y)

X_combined = np.concatenate(all_data, axis=0)
y_combined = np.concatenate(all_labels, axis=0)

# Add a singleton axis at dim 1 (used as the CNN's input-channel axis later).
X_tensor = torch.tensor(X_combined, dtype=torch.float32).unsqueeze(1)
y_tensor = torch.tensor(y_combined, dtype=torch.long)

In [None]:
# Random 80/20 split of the pooled epochs. The permutation comes from NumPy's
# global RNG and is not seeded here, so the split differs between runs.
train_ratio = 0.8
N = len(X_tensor)
idx = np.arange(N)
np.random.shuffle(idx)  # in-place permutation of the sample indices

train_size = int(N * train_ratio)
train_idx, val_idx = idx[:train_size], idx[train_size:]

X_train, y_train = X_tensor[train_idx], y_tensor[train_idx]
X_val, y_val = X_tensor[val_idx], y_tensor[val_idx]

train_dataset, val_dataset = TensorDataset(X_train, y_train), TensorDataset(X_val, y_val)

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)



class EEGAutoencoder(nn.Module):
    """Fully connected autoencoder over a flattened EEG sample.

    Each input sample is flattened by ``nn.Flatten`` and must hold
    ``n_channels * n_times`` values. The reconstruction is reshaped to
    ``(-1, 1, n_channels, n_times)``.

    Args:
        n_channels: Rows of the reconstructed array.
        n_times: Columns of the reconstructed array.
        latent_dim: Width of the latent code ``z``.
    """

    def __init__(self, n_channels=22, n_times=1001, latent_dim=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_times = n_times
        flat_size = n_channels * n_times
        hidden_width = 256

        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_size, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, flat_size),
        )

    def forward(self, x):
        """Encode ``x`` and decode it back; return ``(reconstruction, latent)``."""
        latent = self.encoder(x)
        target_shape = (-1, 1, self.n_channels, self.n_times)
        return self.decoder(latent).view(*target_shape), latent

n_channels = X_tensor.shape[2]
n_times = X_tensor.shape[3]
latent_dim = 64
ae = EEGAutoencoder(n_channels=n_channels, n_times=n_times, latent_dim=latent_dim)
ae_optimizer = optim.Adam(ae.parameters(), lr=1e-3)
ae_criterion = nn.MSELoss()

# Unsupervised pre-training: reconstruct each training sample from its latent code.
num_epochs_ae = 50
for epoch in range(num_epochs_ae):
    ae.train()
    total_loss = 0.0
    for X_batch, _ in train_loader:
        ae_optimizer.zero_grad()
        X_recon, _ = ae(X_batch)
        loss = ae_criterion(X_recon, X_batch)
        loss.backward()
        ae_optimizer.step()
        total_loss += loss.item()
    print(f"AE Epoch {epoch+1}/{num_epochs_ae}, Loss: {total_loss/len(train_loader):.4f}")


def _encode_in_order(dataset):
    """Return ``(latent codes, labels)`` for every sample of ``dataset``, in dataset order."""
    features, labels = [], []
    # shuffle=False is required: the codes are later paired index-by-index with
    # X_train / X_val. The shuffling `train_loader` would scramble that pairing.
    for X_batch, y_batch in DataLoader(dataset, batch_size=batch_size, shuffle=False):
        _, z = ae(X_batch)
        features.append(z)
        labels.append(y_batch)
    return torch.cat(features, dim=0), torch.cat(labels, dim=0)


ae.eval()
with torch.no_grad():
    train_features, train_labels = _encode_in_order(train_dataset)
    val_features, val_labels = _encode_in_order(val_dataset)




In [None]:

class CombinedDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing raw samples with precomputed autoencoder features.

    Item ``i`` is ``(raw_data[i], ae_features[i], raw_labels[i])``. The dataset
    length is taken from ``raw_labels``. The three sequences are assumed to be
    aligned index-for-index.
    """

    def __init__(self, raw_data, raw_labels, ae_features):
        self.raw_data, self.raw_labels, self.ae_features = raw_data, raw_labels, ae_features

    def __len__(self):
        return len(self.raw_labels)

    def __getitem__(self, idx):
        sample = self.raw_data[idx]
        code = self.ae_features[idx]
        target = self.raw_labels[idx]
        return sample, code, target

# Rebind the loaders so the classifier cells receive (raw, ae_features, label)
# batches. NOTE(review): this assumes train_features / val_features are in the
# same order as X_train / X_val.
train_combined_dataset = CombinedDataset(X_train, y_train, train_features)
val_combined_dataset = CombinedDataset(X_val, y_val, val_features)

train_loader, val_loader = (
    DataLoader(train_combined_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(val_combined_dataset, batch_size=batch_size, shuffle=False),
)

In [None]:
# Print (class ids, counts) for the training split, then for the validation split.
for split_labels in (y_train, y_val):
    print(split_labels.unique(return_counts=True))

In [None]:
# train_feat_dataset = TensorDataset(train_features, train_labels)
# val_feat_dataset = TensorDataset(val_features, val_labels)
# batch_size = 64
# train_feat_loader = DataLoader(train_feat_dataset, batch_size=batch_size, shuffle=True)
# val_feat_loader = DataLoader(val_feat_dataset, batch_size=batch_size, shuffle=False)
import pytorch_lightning as pl


class MultiBranchLightningModule(pl.LightningModule):
    """Classifier fusing a 2-D CNN over raw samples with an MLP over AE features.

    Batches are ``(X_batch, Z_batch, y_batch)``:
    - ``X_batch`` is the raw input fed to ``cnn_branch``;
    - ``Z_batch`` is the autoencoder feature vector fed to ``ae_branch``;
    - ``y_batch`` holds integer class targets.

    Args:
        latent_dim: Width of the autoencoder feature vector (input of ``ae_branch``).
        n_channels: Only recorded in hparams. The CNN output width comes from
            ``cnn_output_dim`` and is not derived from this value.
        n_times: Only recorded in hparams (see ``n_channels``).
        num_classes: Number of output logits.
        dropout_rate: Dropout probability used in all sub-networks.
        hidden_dim: Output width of ``ae_branch``.
        lr: Adam learning rate.
        criterion: ``'crossentropy'`` (integer targets) or ``'mse'`` (one-hot targets).
        cnn_output_dim: Flattened output width of ``cnn_branch``. The default
            48000 equals 32 * 6 * 250. That is what the two 2x2 max-pools leave
            for an input with 24-27 rows and 1000-1003 samples.
            NOTE(review): 22 rows (the ``n_channels`` default) would give
            32 * 5 * 250 = 40000; confirm against ``X_tensor.shape``.

    Raises:
        ValueError: If ``criterion`` is not ``'crossentropy'`` or ``'mse'``.
    """

    def __init__(self,
                 latent_dim=64,
                 n_channels=22,
                 n_times=1001,
                 num_classes=4,
                 dropout_rate=0.3,
                 hidden_dim=64,
                 lr=1e-3,
                 criterion='crossentropy',
                 cnn_output_dim=48000):
        super().__init__()
        self.save_hyperparameters()

        # Choose criterion
        if criterion == 'crossentropy':
            self.criterion = nn.CrossEntropyLoss()
        elif criterion == 'mse':
            self.criterion = nn.MSELoss()
        else:
            raise ValueError("Unknown criterion")

        # CNN branch for the raw input. padding=1 with 3x3 kernels keeps the
        # spatial size, so only the two 2x2 max-pools shrink it (floor division).
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Dropout(dropout_rate),

            nn.Conv2d(16, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Dropout(dropout_rate),
            nn.Flatten()
        )

        # MLP branch for the autoencoder features.
        self.ae_branch = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Head over the concatenated branch outputs.
        self.classifier = nn.Sequential(
            nn.Linear(cnn_output_dim + hidden_dim, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(32, num_classes)
        )

    def forward(self, x_raw, x_ae):
        """Return class logits of shape ``[B, num_classes]``."""
        out_cnn = self.cnn_branch(x_raw)   # [B, cnn_output_dim]
        out_ae = self.ae_branch(x_ae)      # [B, hidden_dim]
        out = torch.cat((out_cnn, out_ae), dim=1)
        return self.classifier(out)

    def _compute_loss(self, logits, y_batch):
        """Return the configured loss of ``logits`` against integer targets ``y_batch``.

        Raises:
            ValueError: If ``hparams.criterion`` is not a supported name.
        """
        if self.hparams.criterion == 'crossentropy':
            return self.criterion(logits, y_batch)
        if self.hparams.criterion == 'mse':
            # MSELoss needs a float target of the same shape as the logits.
            num_classes = logits.shape[1]
            y_one_hot = torch.zeros(y_batch.size(0), num_classes, device=y_batch.device)
            y_one_hot.scatter_(1, y_batch.unsqueeze(1), 1.0)
            return self.criterion(logits, y_one_hot)
        raise ValueError("Unknown criterion")

    def training_step(self, batch, batch_idx):
        X_batch, Z_batch, y_batch = batch
        logits = self(X_batch, Z_batch)
        loss = self._compute_loss(logits, y_batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X_batch, Z_batch, y_batch = batch
        logits = self(X_batch, Z_batch)
        loss = self._compute_loss(logits, y_batch)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y_batch).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True, on_epoch=True)
        return {'val_loss': loss, 'val_acc': acc, 'preds': preds, 'targets': y_batch}

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams.lr)



In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import optuna


In [None]:
# Hyperparameters for this run. `best_params` is produced by the Optuna cell
# further down. If that cell has not been executed yet, fall back to the
# MultiBranchLightningModule constructor defaults instead of raising NameError.
try:
    run_params = best_params
except NameError:
    run_params = {
        'hidden_dim': 64,
        'dropout_rate': 0.3,
        'lr': 1e-3,
        'criterion': 'crossentropy',
    }

model = MultiBranchLightningModule(
            latent_dim=latent_dim,
            n_channels=22,
            n_times=1001,
            num_classes=4,
            hidden_dim=run_params['hidden_dim'],
            dropout_rate=run_params['dropout_rate'],
            lr=run_params['lr'],
            criterion=run_params['criterion']
            )

checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    mode='max',
    save_top_k=1,
    filename='best-checkpoint-{epoch:02d}-{val_acc:.2f}'
)

trainer = pl.Trainer(
    max_epochs=40,
    callbacks=[checkpoint_callback],
    enable_progress_bar=False,
    log_every_n_steps=1
)
trainer.fit(model, train_loader, val_loader)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

all_preds = []
all_targets = []

# Switch BatchNorm/Dropout to inference behaviour. After trainer.fit the module
# may still be in training mode; BatchNorm would then use batch statistics and
# update its running stats during evaluation.
model.eval()
with torch.no_grad():
    for X_batch, Z_batch, y_batch in val_loader:
        # Keep inputs on the same device as the model's parameters.
        logits = model(X_batch.to(model.device), Z_batch.to(model.device))
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu())
        all_targets.append(y_batch.cpu())

all_preds = torch.cat(all_preds)
all_targets = torch.cat(all_targets)

# Pin the label set so the matrix is always num_classes x num_classes, even if a
# class never appears in this split's targets or predictions.
class_labels = list(range(model.hparams.num_classes))
cm = confusion_matrix(all_targets.numpy(), all_preds.numpy(), labels=class_labels)

fig, ax = plt.subplots(figsize=(5, 5))
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap=plt.cm.Blues, ax=ax)
plt.title('Confusion Matrix')
plt.show()

In [None]:


def objective(trial):
    """Optuna objective: train one MultiBranchLightningModule and report ``val_acc``.

    Samples the hidden width, dropout rate, learning rate and loss type from
    ``trial``. Fits for 50 epochs on the notebook-level ``train_loader`` /
    ``val_loader`` and returns the ``val_acc`` value held in
    ``trainer.callback_metrics`` after fitting.
    """
    # Suggestions are drawn in the same order as the keyword arguments below.
    search_space = dict(
        hidden_dim=trial.suggest_int('hidden_dim', 32, 256, step=32),
        dropout_rate=trial.suggest_float('dropout_rate', 0.0, 0.5, step=0.1),
        lr=trial.suggest_float('lr', 1e-5, 1e-3, log=True),
        criterion=trial.suggest_categorical('criterion', ['crossentropy', 'mse']),
    )

    candidate = MultiBranchLightningModule(
        latent_dim=latent_dim,
        n_channels=22,
        n_times=1001,
        num_classes=4,
        **search_space,
    )

    trial_trainer = pl.Trainer(
        max_epochs=50,
        callbacks=[
            ModelCheckpoint(
                monitor='val_acc',
                mode='max',
                save_top_k=1,
                filename='best-checkpoint-{epoch:02d}-{val_acc:.2f}',
            )
        ],
        enable_progress_bar=False,
        log_every_n_steps=1,
    )

    trial_trainer.fit(candidate, train_loader, val_loader)
    final_val_acc = trial_trainer.callback_metrics['val_acc']
    return final_val_acc.item()

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

best_params = study.best_params

best_model = MultiBranchLightningModule(
            latent_dim=latent_dim,
            n_channels=22,
            n_times=1001,
            num_classes=4,
            hidden_dim=best_params['hidden_dim'],
            dropout_rate=best_params['dropout_rate'],
            lr=best_params['lr'],
            criterion=best_params['criterion']
            )

checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    mode='max',
    save_top_k=1,
    filename='best-checkpoint-{epoch:02d}-{val_acc:.2f}'
)

trainer = pl.Trainer(
    max_epochs=50,
    callbacks=[checkpoint_callback],
    enable_progress_bar=False,
    log_every_n_steps=1
)

# Retrain with the winning hyperparameters. Without this fit, best_model keeps
# its random initial weights and any evaluation of it is meaningless.
trainer.fit(best_model, train_loader, val_loader)

best_model.eval()

[I 2024-12-15 08:51:45,616] A new study created in memory with name: no-name-c4b7c1c2-6d17-460d-921e-6f449b584158
[I 2024-12-15 09:18:17,459] Trial 0 finished with value: 0.37142857909202576 and parameters: {'hidden_dim': 192, 'dropout_rate': 0.5, 'lr': 6.286522302899363e-05, 'criterion': 'crossentropy'}. Best is trial 0 with value: 0.37142857909202576.
[I 2024-12-15 09:45:20,352] Trial 1 finished with value: 0.45102041959762573 and parameters: {'hidden_dim': 160, 'dropout_rate': 0.1, 'lr': 2.1321485898878236e-05, 'criterion': 'crossentropy'}. Best is trial 1 with value: 0.45102041959762573.


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

all_preds = []
all_targets = []

# Inference behaviour for BatchNorm/Dropout. Repeated here so this cell is
# correct even if run on its own.
best_model.eval()
with torch.no_grad():
    for X_batch, Z_batch, y_batch in val_loader:
        # Keep inputs on the same device as the model's parameters.
        logits = best_model(X_batch.to(best_model.device), Z_batch.to(best_model.device))
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu())
        all_targets.append(y_batch.cpu())

all_preds = torch.cat(all_preds)
all_targets = torch.cat(all_targets)

# Pin the label set so the matrix is always num_classes x num_classes, even if a
# class never appears in this split's targets or predictions.
class_labels = list(range(best_model.hparams.num_classes))
cm = confusion_matrix(all_targets.numpy(), all_preds.numpy(), labels=class_labels)

fig, ax = plt.subplots(figsize=(5, 5))
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap=plt.cm.Blues, ax=ax)
plt.title('Confusion Matrix - Best Hyperparameters')
plt.show()