## Setup

In [1]:
!pip install einops ema_pytorch mat73 numpy scikit_learn torch tqdm --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m113.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m92.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m53.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m39.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/
%cd /content/drive/MyDrive/IDL_Project

Mounted at /content/drive
/content/drive/MyDrive
/content/drive/MyDrive/IDL_Project


In [3]:
import sys

# Make the Diff-E checkout importable so that `utils` and `models` (imported below) resolve.
# Guarded so that re-running this cell does not append duplicate sys.path entries.
_diffe_dir = '/content/drive/MyDrive/IDL_Project/Diff-E'
if _diffe_dir not in sys.path:
    sys.path.append(_diffe_dir)

In [4]:
# One-time setup: uncomment to clone the Diff-E repository into the current directory
# (the sys.path cell above adds its `Diff-E/` folder to the import path).
# !git clone https://github.com/diffe2023/Diff-E.git

## Load and Preprocess Data

In [None]:
import numpy as np
import torch
import pandas as pd
import os
import random
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from utils import zscore_norm, minmax_norm, EEGDataset

In [6]:
def load_eeg_data(root_dir, subject_id, data_type='thinking', eeg_channels=None, window_size=256, step_size=128,
                  label_map=None, norm_fn=None):
    """
    Load one subject's EEG recording from CSV and cut it into (possibly overlapping) windows.

    The file read is ``<root_dir>/<subject_id zero-padded to 2 digits>/<data_type>.csv``.

    Parameters:
    root_dir (str): Path to the data_eeg folder
    subject_id (str | int): Subject ID (e.g., '01', 2); zero-padded to two digits
    data_type (str): CSV file name without extension ('thinking', 'speaking', 'stimuli', etc.)
    eeg_channels (list): EEG column names to include; defaults to every column except
        'Time:256Hz', 'Epoch', 'Label', 'Stage' and 'Flag'
    window_size (int): Number of samples per window (must be >= 1)
    step_size (int): Hop between consecutive window starts (must be >= 1); windows
        overlap whenever step_size < window_size
    label_map (dict | None): Mapping from raw 'Label' values to class indices. If None,
        the sorted unique labels of THIS file are mapped to 0..n_classes-1. When combining
        several subjects, pass one shared mapping: otherwise a subject that lacks some class
        gets its class indices shifted relative to the other subjects. Ignored when the
        file has no 'Label' column.
    norm_fn (callable | None): Normalization applied to the window tensor; defaults to
        ``zscore_norm`` (from utils). Not applied when no window fits in the recording.

    Returns:
    X (torch.Tensor): float32 tensor of shape (n_windows, n_channels, window_size)
    Y (torch.Tensor): int64 tensor of shape (n_windows,) holding the label at each window's
        centre sample (index start + window_size // 2); all zeros if there is no 'Label' column.
        Both have n_windows == 0 when the recording is shorter than one window.

    Raises:
    FileNotFoundError: if the CSV file does not exist
    ValueError: if window_size or step_size is smaller than 1
    KeyError: if label_map is given and does not contain one of the file's labels
    """
    if window_size < 1 or step_size < 1:
        raise ValueError(
            f"window_size and step_size must be >= 1, got {window_size} and {step_size}"
        )

    # Ensure subject_id is properly formatted, then locate the subject's file
    subject_id = str(subject_id).zfill(2)
    file_path = os.path.join(root_dir, subject_id, f"{data_type}.csv")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Data file not found: {file_path}")

    df = pd.read_csv(file_path)

    # If no EEG channels specified, use every column that is not known metadata
    if eeg_channels is None:
        non_eeg_cols = ['Time:256Hz', 'Epoch', 'Label', 'Stage', 'Flag']
        eeg_channels = [col for col in df.columns if col not in non_eeg_cols]

    eeg_data = df[eeg_channels].values  # (n_samples, n_channels)
    n_channels = eeg_data.shape[1]

    if 'Label' in df.columns:
        raw_labels = df['Label'].values
        if label_map is None:
            # Per-file mapping: sorted unique labels -> 0..n_classes-1
            label_map = {label: i for i, label in enumerate(np.unique(raw_labels))}
        labels = np.array([label_map[label] for label in raw_labels], dtype=np.int64)
    else:
        # No labels available: use dummy zeros
        labels = np.zeros(len(eeg_data), dtype=np.int64)

    # Start index of every full window; empty when the recording is shorter than one window.
    starts = np.arange(0, len(eeg_data) - window_size + 1, step_size)
    if len(starts) == 0:
        return (torch.empty((0, n_channels, window_size), dtype=torch.float32),
                torch.empty((0,), dtype=torch.long))

    # Each window is transposed to (n_channels, window_size)
    windows = np.stack([eeg_data[s:s + window_size].T for s in starts])
    X = torch.tensor(windows, dtype=torch.float32)
    # Label each window with the label of its centre sample
    Y = torch.tensor(labels[starts + window_size // 2], dtype=torch.long)

    if norm_fn is None:
        norm_fn = zscore_norm
    X = norm_fn(X)

    return X, Y

In [7]:
def load_multiple_subjects(root_dir, subject_ids, data_type='thinking', **load_kwargs):
    """
    Load several subjects with ``load_eeg_data`` and concatenate them along the window axis.

    Parameters:
    root_dir (str): Path to the data_eeg folder
    subject_ids (list): Subject IDs to load, in order
    data_type (str): CSV file name (without extension) to load for each subject
    **load_kwargs: Extra keyword arguments forwarded to ``load_eeg_data``
        (e.g. eeg_channels, window_size, step_size)

    Returns:
    X (torch.Tensor): Combined EEG data tensor
    Y (torch.Tensor): Combined labels tensor

    Raises:
    ValueError: if no subject could be loaded

    Subjects that fail to load are reported and skipped (best effort), so the result may
    cover fewer subjects than requested. All loaded subjects must share the same channel
    count for the concatenation to succeed.
    NOTE(review): with load_eeg_data's default label handling, each subject's labels are
    re-indexed from that subject's own unique labels; the indices only agree across subjects
    if every subject contains the same label set.
    """
    all_X = []
    all_Y = []
    failed = []

    for subject_id in subject_ids:
        try:
            X, Y = load_eeg_data(root_dir, subject_id, data_type, **load_kwargs)
        except Exception as e:
            failed.append(subject_id)
            print(f"Error loading data from subject {subject_id}: {e}")
            continue
        all_X.append(X)
        all_Y.append(Y)
        print(f"Loaded data from subject {subject_id}")

    # torch.cat([]) would fail with an unhelpful message; report the real cause instead.
    if not all_X:
        raise ValueError(f"No subject data could be loaded from {root_dir!r} (failed: {failed})")

    return torch.cat(all_X, dim=0), torch.cat(all_Y, dim=0)

In [8]:
def get_dataloader(X, Y, batch_size, batch_size2, seed, shuffle=True, test_size=0.2):
    """
    Split (X, Y) into training and test sets and wrap each in a DataLoader.

    Parameters:
    X (torch.Tensor): EEG data tensor; the first dimension indexes samples
    Y (torch.Tensor): Labels tensor aligned with X
    batch_size (int): Batch size for training
    batch_size2 (int): Batch size for testing
    seed (int): random_state used for the split
    shuffle (bool): If True, shuffle before splitting (stratified by Y) and shuffle the
        training batches. If False, the split is a plain head/tail split without
        stratification, because scikit-learn rejects stratify together with shuffle=False.
    test_size (float | int): Fraction (or absolute number) of samples put in the test set

    Returns:
    training_loader (DataLoader): DataLoader for training data
    test_loader (DataLoader): DataLoader for test data (never shuffled)

    NOTE(review): this is a sample-level random split. If X holds overlapping windows
    (load_eeg_data's default step_size=128 < window_size=256), neighbouring windows share
    samples and can land on both sides of the split, which inflates test scores. Consider
    splitting by recording, epoch or subject instead.
    """
    X_train, X_test, Y_train, Y_test = train_test_split(
        X,
        Y,
        test_size=test_size,
        shuffle=shuffle,
        stratify=Y if shuffle else None,
        random_state=seed,
    )

    training_set = EEGDataset(X_train, Y_train)
    training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=shuffle)

    test_set = EEGDataset(X_test, Y_test)
    test_loader = DataLoader(test_set, batch_size=batch_size2, shuffle=False)

    return training_loader, test_loader

In [9]:
def pad_channels(X, target_channels=16):
    """
    Zero-pad the channel axis (dim 1) of a (batch, channels, time) tensor up to target_channels.

    New channels are appended after the existing ones and filled with zeros. Tensors that
    already have target_channels or more channels are returned unchanged (not truncated).
    The result keeps X's dtype and device.
    """
    _, channels, _ = X.shape
    if channels >= target_channels:
        return X
    # F.pad's tuple runs from the last dim backwards: (time_before, time_after, ch_before, ch_after)
    return torch.nn.functional.pad(X, (0, 0, 0, target_channels - channels))

In [10]:
# Relative to the project folder that the %cd cell switched into.
root_dir = "data_eeg"
subject_ids = [f"{n:02d}" for n in range(1, 22)]  # '01' ... '21'

# Load every subject's 'thinking' recordings, then zero-pad the channel axis to 16.
X, Y = load_multiple_subjects(root_dir, subject_ids, data_type='thinking')
X = pad_channels(X, target_channels=16)

Loaded data from subject 01
Loaded data from subject 02
Loaded data from subject 03
Loaded data from subject 04
Loaded data from subject 05
Loaded data from subject 06
Loaded data from subject 07
Loaded data from subject 08
Loaded data from subject 09
Loaded data from subject 10
Loaded data from subject 11
Loaded data from subject 12
Loaded data from subject 13
Loaded data from subject 14
Loaded data from subject 15
Loaded data from subject 16
Loaded data from subject 17
Loaded data from subject 18
Loaded data from subject 19
Loaded data from subject 20
Loaded data from subject 21


In [11]:
batch_size = 32    # training batch size
batch_size2 = 260  # evaluation batch size
seed = 42

# Seed every global RNG the pipeline may touch: Python, NumPy and PyTorch
# (torch.manual_seed seeds all devices, including CUDA).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader, test_loader = get_dataloader(X, Y, batch_size, batch_size2, seed, shuffle=True)

In [12]:
# Sanity check: batch counts, then the shapes of the first batch from each loader.
print(f"Number of batches in train loader: {len(train_loader)}")
print(f"Number of batches in test loader: {len(test_loader)}")

first_train_batch = next(iter(train_loader), None)
if first_train_batch is not None:
    X_train, y_train = first_train_batch
    print(f"Train Batch - Input shape: {X_train.shape}, Labels shape: {y_train.shape}")

first_test_batch = next(iter(test_loader), None)
if first_test_batch is not None:
    X_test, y_test = first_test_batch
    print(f"Test Batch - Input shape: {X_test.shape}, Labels shape: {y_test.shape}")

Number of batches in train loader: 828
Number of batches in test loader: 26
Train Batch - Input shape: torch.Size([32, 16, 256]), Labels shape: torch.Size([32])
Test Batch - Input shape: torch.Size([260, 16, 256]), Labels shape: torch.Size([260])


## Model

In [13]:
from models import *  # provides ConditionalUNet, DDPM, Encoder, Decoder, LinearClassifier, DiffE

# Task / data dimensions
num_classes = 16       # number of classifier outputs
channels = X.size(1)   # EEG channels after padding

# Model sizes
n_T = 1000             # number of DDPM steps (passed to DDPM as n_T)
ddpm_dim = 128         # feature width shared by the DDPM U-Net and the decoder
encoder_dim = 256      # encoder width, also the classifier head's input size
fc_dim = 512           # hidden width of the classifier head

In [14]:
# Build the models. `ddpm` (a DDPM wrapping a conditional U-Net) is trained with its own
# optimizer (optim1); `diffe` bundles the encoder, decoder and linear classifier head `fc`
# and is trained jointly with optim2.
# NOTE: keep this statement order - each constructor presumably draws its initial weights
# from the seeded torch RNG, so reordering would change the initialization.
ddpm_model = ConditionalUNet(in_channels=channels, n_feat=ddpm_dim).to(device)
ddpm = DDPM(nn_model=ddpm_model, betas=(1e-6, 1e-2), n_T=n_T, device=device).to(device)  # betas: presumably (start, end) of the noise schedule
encoder = Encoder(in_channels=channels, dim=encoder_dim).to(device)
decoder = Decoder(in_channels=channels, n_feat=ddpm_dim, encoder_dim=encoder_dim).to(device)
fc = LinearClassifier(encoder_dim, fc_dim, emb_dim=num_classes).to(device)  # one output per class
diffe = DiffE(encoder, decoder, fc).to(device)

In [15]:
# Number of parameters in each component.
for part_name, part in (("ddpm", ddpm), ("encoder", encoder), ("decoder", decoder), ("fc", fc)):
    print(f"{part_name} size: ", sum(param.numel() for param in part.parameters()))

ddpm size:  239174
encoder size:  137475
decoder size:  136466
fc size:  404498


## Train

In [16]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from ema_pytorch import EMA
from tqdm import tqdm
from sklearn.metrics import (
    f1_score,
    roc_auc_score,
    precision_score,
    recall_score,
    top_k_accuracy_score,
)

In [17]:
def evaluate(encoder, fc, generator, device, num_classes=16):
    """
    Classify every batch from ``generator`` and return macro-averaged metrics.

    Parameters:
    encoder: module whose output element [1] is fed to ``fc``
    fc: classification head producing one score per class
    generator: iterable of (x, y) batches, e.g. a DataLoader
    device: device the models live on; inputs are moved there
    num_classes (int): metrics are computed over the labels 0..num_classes-1

    Returns:
    dict with keys "accuracy", "f1", "recall", "precision" and "auc"

    Runs under torch.no_grad(). It does not switch the modules to eval mode; the caller must.
    """
    labels = np.arange(num_classes)
    Y = []
    Y_hat = []
    with torch.no_grad():
        for x, y in generator:
            encoder_out = encoder(x.to(device))
            y_hat = F.softmax(fc(encoder_out[1]), dim=1)
            # Labels never need to visit the device; keep them on CPU for the metrics.
            Y.append(y.long().cpu())
            Y_hat.append(y_hat.cpu())

    # List of tensors -> tensor -> numpy
    Y = torch.cat(Y, dim=0).numpy()          # (N,)
    Y_hat = torch.cat(Y_hat, dim=0).numpy()  # (N, num_classes); rows sum to 1 as roc_auc_score requires
    y_pred = Y_hat.argmax(axis=1)

    # zero_division=0 yields the same value as sklearn's default ("warn" -> 0) for classes that
    # are never predicted, but without emitting UndefinedMetricWarning on every evaluation.
    metrics = {
        "accuracy": top_k_accuracy_score(Y, Y_hat, k=1, labels=labels),
        "f1": f1_score(Y, y_pred, average="macro", labels=labels, zero_division=0),
        "recall": recall_score(Y, y_pred, average="macro", labels=labels, zero_division=0),
        "precision": precision_score(Y, y_pred, average="macro", labels=labels, zero_division=0),
        "auc": roc_auc_score(Y, Y_hat, average="macro", multi_class="ovo", labels=labels),
    }
    return metrics

In [18]:
# Criterion
# Losses: L1 between the decoder output and the DDPM's per-element error,
# MSE between the classifier output and the one-hot labels.
criterion = nn.L1Loss()
criterion_class = nn.MSELoss()

# `base_lr` / `lr` are the lower / upper bounds of the cyclic LR schedule set up below.
base_lr = 9e-5
lr = 1.5e-3
optim1 = optim.RMSprop(ddpm.parameters(), lr=base_lr)   # DDPM
optim2 = optim.RMSprop(diffe.parameters(), lr=base_lr)  # Diff-E (encoder, decoder, fc)

In [19]:
# Exponential moving average of the classifier head; the evaluation step scores this
# averaged head (not diffe.fc directly).
fc_ema = EMA(diffe.fc, beta=0.95, update_after_step=100, update_every=10)

# Same cyclic LR schedule for both optimizers; both are stepped once per training batch.
step_size = 150
scheduler1, scheduler2 = (
    optim.lr_scheduler.CyclicLR(
        optimizer=opt,
        base_lr=base_lr,
        max_lr=lr,
        step_size_up=step_size,
        mode="exp_range",
        cycle_momentum=False,
        gamma=0.9998,
    )
    for opt in (optim1, optim2)
)

In [20]:
# Train & Evaluate
num_epochs = 500
test_period = 1  # evaluate every `test_period` epochs
alpha = 0.1      # weight of the classification loss relative to the decoder gap loss

# Best value of each metric seen so far. Each metric is tracked independently, so the bests
# may come from different epochs. NOTE(review): they are selected on the test split itself,
# so they are optimistic; use a separate validation split for model selection.
best_acc = 0
best_f1 = 0
best_recall = 0
best_precision = 0
best_auc = 0

with tqdm(total=num_epochs) as pbar:
    for epoch in range(num_epochs):
        ddpm.train()
        diffe.train()

        ############################## Train ###########################################
        for x, y in train_loader:
            x, y = x.to(device), y.to(device).long()
            # Build the one-hot target directly on `device` (no GPU -> CPU -> GPU round trip).
            y_cat = F.one_hot(y, num_classes=num_classes).float()

            # Train DDPM: L1 reconstruction loss, kept unreduced as the decoder's target below
            optim1.zero_grad()
            x_hat, down, up, noise, t = ddpm(x)
            loss_ddpm = F.l1_loss(x_hat, x, reduction="none")
            loss_ddpm.mean().backward()
            optim1.step()
            ddpm_out = x_hat, down, up, t

            # Train Diff-E: decoder regresses the DDPM's per-element error, fc classifies
            optim2.zero_grad()
            decoder_out, fc_out = diffe(x, ddpm_out)
            loss_gap = criterion(decoder_out, loss_ddpm.detach())
            loss_c = criterion_class(fc_out, y_cat)
            loss = loss_gap + alpha * loss_c
            loss.backward()
            optim2.step()

            # LR schedules and the EMA head advance once per batch
            scheduler1.step()
            scheduler2.step()
            fc_ema.update()

        ############################## Test ###########################################
        if epoch % test_period == 0:
            ddpm.eval()
            diffe.eval()
            # Also put the EMA copy of the head in inference mode (matters only if it has
            # dropout/normalization layers whose behavior depends on the mode).
            fc_ema.eval()

            with torch.no_grad():
                metrics_test = evaluate(diffe.encoder, fc_ema, test_loader, device)

            acc = metrics_test["accuracy"]
            f1 = metrics_test["f1"]
            recall = metrics_test["recall"]
            precision = metrics_test["precision"]
            auc = metrics_test["auc"]

            best_acc = max(best_acc, acc)
            best_f1 = max(best_f1, f1)
            best_recall = max(best_recall, recall)
            best_precision = max(best_precision, precision)
            best_auc = max(best_auc, auc)

            description = f"Best accuracy: {best_acc*100:.2f}%"
            pbar.set_description(description)
        pbar.update(1)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Best accuracy: 30.62%: 100%|██████████| 500/500 [2:07:16<00:00, 15.27s/it]
