In [1]:
!pip install torch torchvision torchaudio
!pip install torch-geometric numpy scipy mne pandas scikit-learn
!pip install mne
!pip install seiz_eeg
!pip install PyWavelets

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
from google.colab import drive

drive.mount("/content/drive")

Mounted at /content/drive


In [3]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from scipy import signal
from scipy.signal import welch
import pywt
from seiz_eeg.dataset import EEGDataset
from torch.utils.data import WeightedRandomSampler
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset as PyGDataset
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.metrics import f1_score


from models import GCN, GNNSage, GAT
from filters import fft_filtering, stft_filtering, psd_filtering, wt_filtering
from train import train_epoch, evaluate
from cross_validation import CrossValidator

top-level pandera module will be **removed in a future version of pandera**.
If you're using pandera to validate pandas objects, we highly recommend updating
your import:

```
# old import
import pandera as pa

# new import
import pandera.pandas as pa
```

If you're using pandera to validate objects from other compatible libraries
like pyspark or polars, see the supported libraries section of the documentation
for more information on how to import pandera:

https://pandera.readthedocs.io/en/stable/supported_libraries.html


```
```



### Data Loading  
Reads the pre‑windowed EEG metadata from parquet files and initializes `EEGDataset` instances, applying any time‑or frequency‑domain transforms.


In [4]:
# EEG parameters used later for the creation for the adjacency matric
CH_NAMES = [
    "Fp1",
    "Fp2",
    "F7",
    "F3",
    "Fz",
    "F4",
    "F8",
    "T3",
    "C3",
    "Cz",
    "C4",
    "T4",
    "T5",
    "P3",
    "Pz",
    "P4",
    "T6",
    "O1",
    "O2",
]

# data Loading (using `EEGDataset`)
# we read the pre-windowed segments from parquet, then wrap them into graphs

DATA_ROOT = Path("/content/drive/MyDrive/EPFL/NML/epfl-network-machine-learning-2025")

# one row = one 12s window
clips_tr = pd.read_parquet(DATA_ROOT / "train" / "train" / "segments.parquet")
clips_te = pd.read_parquet(DATA_ROOT / "test" / "test" / "segments.parquet")

# create the EEGDataset instances
dataset_tr = EEGDataset(
    clips_tr,
    signals_root=DATA_ROOT / "train" / "train",
    signal_transform=wt_filtering,
    prefetch=True,
)

dataset_te = EEGDataset(
    clips_te,
    signals_root=DATA_ROOT / "test" / "test",
    signal_transform=wt_filtering,
    prefetch=True,
    return_id=True,
)

print(f"Loaded {len(dataset_tr):,} training windows, {len(dataset_te):,} test windows.")

Loaded 12,993 training windows, 3,614 test windows.


### Preprocessing & Graph Construction  
Loads from the distances from the given distances_3d.csv file, pivots it into a 19×19 distance matrix, applies an RBF kernel to convert distances into similarities, and thresholds to build the adjacency matrix.


In [5]:
def load_adjacency(dist_csv, ch_names, threshold_pct=75):
    """Read the 3-columns [from,to,distance] of distances_3d.csv and build a symmetric adjacency:"""
    # read and pivot
    df = pd.read_csv(dist_csv)
    dmat = df.pivot(index="from", columns="to", values="distance")
    dmat = dmat.reindex(index=ch_names, columns=ch_names)
    dist = dmat.values.astype(float)

    # zero the diagonal
    np.fill_diagonal(dist, 0.0)

    # mirror known entries to get symmetric matrix
    mask = np.isnan(dist)
    dist[mask] = dist.T[mask]

    # fill any remaining NaNs with the max so that missing pairs become “very far apart”
    max_dist = np.nanmax(dist)
    dist[np.isnan(dist)] = max_dist

    # build RBF weights
    sigma = dist.mean()
    W = np.exp(-(dist**2) / (2 * sigma**2))

    # sparsify by zeroing out the weakest edges
    cutoff = np.percentile(W, threshold_pct)
    W[W < cutoff] = 0.0

    # zero the diagonal again
    np.fill_diagonal(W, 0.0)

    return W


distances_csv = DATA_ROOT / "distances_3d.csv"
A = load_adjacency(
    distances_csv, CH_NAMES, threshold_pct=30
)  # changed the threshold to a lower value otherwise I was getting many 0s (only very strong connections were considered)
print("Adjacency shape:", A.shape, "  density:", (A > 0).mean())
print(A)

Adjacency shape: (19, 19)   density: 0.9473684210526315
[[0.         0.43858603 0.43858603 0.43858603 0.43858603 0.43858603
  0.43858603 0.43858603 0.43858603 0.43858603 0.43858603 0.43858603
  0.43858603 0.43858603 0.43858603 0.43858603 0.43858603 0.43858603
  0.43858603]
 [0.43858603 0.         0.43858603 0.43858603 0.43858603 0.43858603
  0.43858603 0.43858603 0.43858603 0.43858603 0.43858603 0.43858603
  0.43858603 0.43858603 0.43858603 0.43858603 0.43858603 0.43858603
  0.43858603]
 [0.43858603 0.43858603 0.         0.93576703 0.43858603 0.64896094
  0.58308472 0.92430892 0.83832074 0.43858603 0.52317844 0.47450356
  0.75218863 0.67583695 0.43858603 0.46869762 0.4385897  0.58304998
  0.47448915]
 [0.43858603 0.43858603 0.93576703 0.         0.43858603 0.78010715
  0.64896094 0.83036187 0.89836861 0.43858603 0.65241855 0.52819134
  0.67583695 0.68944871 0.43858603 0.53784387 0.46869762 0.5458095
  0.47460245]
 [0.43858603 0.43858603 0.43858603 0.43858603 0.         0.43858603
  0.4

### PyG Dataset Wrapper  
Defines `GraphFromEEG`, which takes each transformed EEG window, computes per‑channel features (mean, variance, peak-to-peak, zero-crossing rate etc.), and uses the fixed graph topology (edges + weights) to produce `torch_geometric.data.Data` objects.


In [6]:
from threading import Condition
from typing import Literal

# Hyperparameters
batch_size = 1
epochs = 20
learning_rate = 1e-3

SFREQ = 250  # Hz, matches the dataset’s sampling rate

class GraphFromEEG(PyGDataset):
    def __init__(self,
                 eeg_ds,
                 adj,
                 is_test=False,
                 condition: Literal["concat_beginning", "concat_ending", "only_signal", "only_9features"] = "concat_beginning",
                 transpose: bool = True
                 ):
        super().__init__()
        self.eeg_ds = eeg_ds
        self.is_test = is_test
        self.condition = condition
        self.T = transpose

        rows, cols = np.nonzero(adj > 0)
        self.edge_index = torch.tensor([rows, cols], dtype=torch.long)
        self.edge_weight = torch.tensor(adj[rows, cols], dtype=torch.float)

    def len(self):
        return len(self.eeg_ds)

    def extract_handcrafted_signal_features(self, signal):
        # 1) mean
        mean_ = signal.mean(axis=0)

        # 2) variance
        var_ = signal.var(axis=0)

        # 3) peak-to-peak
        ptp_ = np.ptp(signal, axis=0)

        # 4) zero-crossing rate
        zcr_ = np.mean(np.diff(np.sign(signal), axis=0) != 0, axis=0)

        # 5) PSD via Welch
        freqs, psd = welch(signal, fs=SFREQ, axis=0)

        def bandpower(pxx, freqs, fmin, fmax):
            mask = (freqs >= fmin) & (freqs <= fmax)
            return pxx[mask].mean(axis=0)

        # 6–10) Bandpower in δ (1–4), θ (4–8), α (8–12), β (12–30), γ (30–45)
        delta = bandpower(psd, freqs, 1, 4)
        theta = bandpower(psd, freqs, 4, 8)
        alpha = bandpower(psd, freqs, 8, 12)
        beta = bandpower(psd, freqs, 12, 30)
        gamma = bandpower(psd, freqs, 30, 45)

        # stack into (n_channels, 9) feature matrix
        features = np.stack(
            [mean_, var_, ptp_, zcr_, delta, theta, alpha, beta, gamma], axis=0
        )
        return features

    def get(self, idx):
        arr, meta = self.eeg_ds[idx]
        if self.is_test:
            signal_, sid = arr, meta
            label = None
        else:
            signal_, label = arr, meta
            sid = None

        # signal: (n_time_bins, n_channels)
        if self.condition == "concat_beginning":
          features = self.extract_handcrafted_signal_features(signal_)
          x_1 = torch.tensor(features, dtype=torch.float)
          x_2 = torch.tensor(signal_,  dtype=torch.float)
          x = torch.concat([x_1, x_2], axis=0)

        elif self.condition == "concat_ending":
          features = self.extract_handcrafted_signal_features(signal_)
          x_1 = torch.tensor(features, dtype=torch.float)
          x_2 = torch.tensor(signal_,  dtype=torch.float)
          x = torch.concat([x_2, x_1], axis=0)

        elif self.condition == "only_signal":
          x = torch.tensor(signal_,  dtype=torch.float)

        elif self.condition == "only_9features":
          x = self.extract_handcrafted_signal_features(signal_)

        if self.T:
          x = x.T

        y = torch.tensor([label], dtype=torch.long) if label is not None else None

        data = Data(x=x, edge_index=self.edge_index, edge_attr=self.edge_weight, y=y)
        # keep index for later id lookup
        data.idx = torch.tensor([idx], dtype=torch.long)
        return data


graph_tr = GraphFromEEG(dataset_tr, A, is_test=False)
graph_te = GraphFromEEG(dataset_te, A, is_test=True)

loader_tr = DataLoader(graph_tr, batch_size=batch_size, shuffle=True)
loader_te = DataLoader(graph_te, batch_size=batch_size, shuffle=False)
for sample in loader_tr:
  print(sample.x.shape)
  break

print(f"Graphified train size: {len(graph_tr)}, test size: {len(graph_te)}")

  self.edge_index = torch.tensor([rows, cols], dtype=torch.long)


torch.Size([19, 3009])
Graphified train size: 12993, test size: 3614


### DataLoader with balancing
Instantiate `DataLoader` over our training graph dataset—using a `WeightedRandomSampler` to balance seizure vs. non‑seizure windows during training


In [7]:
# Balance the labels by giving more weight to the minority class
# Difference in the number of samples per class: [10476, 2517]

train_labels = [data.y.item() for data in graph_tr]
counts = np.bincount(train_labels)
weights = 1.0 / counts  # gives more weight to the minority class

# sample‐wise weight vector
sample_weights = np.array([weights[l] for l in train_labels])
sampler = WeightedRandomSampler(
    weights=sample_weights, num_samples=len(sample_weights), replacement=True
)

# rebuild the train loader
loader_tr = DataLoader(graph_tr, batch_size=batch_size, sampler=sampler, num_workers=1)

In [9]:
n_feat = 0
for sample in loader_tr:
  n_feat = sample.x.shape[1]
  break

### Model Definition & Training
Builds the GCN model (`EEG_GCN`) with two graph‐convolution layers followed by global mean pooling and a linear layer for binary classification also trains the GCN over 20 epochs using cross‑entropy and prints train loss and accuracy.

In [13]:
# Train & Validation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GNNSage(
    num_layers = 4,
    nfeat = n_feat,
    nhid = 32,
    nclass = 2,
    dropout = 0.0,
    adj_weight = False,
    use_bn = False,
    # heads = 3
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, epochs + 1):
    loss = train_epoch(model, loader_tr, optimizer, device)
    acc = evaluate(model, loader_tr, device)
    print(f"Epoch {epoch:02d}: Train Loss = {loss:.4f}, Train Acc = {acc:.4f}")

Epoch 01: Train Loss = 0.8077, Train Acc = 0.4960
Epoch 02: Train Loss = 0.8125, Train Acc = 0.4980
Epoch 03: Train Loss = 0.8073, Train Acc = 0.5088
Epoch 04: Train Loss = 0.8138, Train Acc = 0.5043


KeyboardInterrupt: 

### Test Prediction & Submission  
Run inference on the test loader, map each prediction back to the original window IDs and write out a Kaggle‐compatible CSV of `id,label` rows.


In [14]:
## Test & Submission

model.eval()
all_idxs, all_preds = [], []

with torch.no_grad():
    for batch in loader_te:
        # batch.idx is a tensor of shape [batch_size] giving the original idx
        all_idxs.extend(batch.idx.cpu().tolist())
        batch = batch.to(device)
        logits = model(batch)
        preds = logits.argmax(dim=1).cpu().tolist()
        all_preds.extend(preds)

all_ids = clips_te.index.tolist()

# check
assert len(all_ids) == len(all_preds)

# write submission
submission = pd.DataFrame({"id": all_ids, "label": all_preds})
submission.to_csv("submission.csv", index=False)
print(f"Saved submission.csv with {len(submission)} rows")

Saved submission.csv with 3614 rows


## Cross Validation

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_kwargs = {
    "num_layers": 4,
    "nfeat": n_feat,
    "nhid": 32,
    "nclass": 2,
    "dropout": 0.0,
    "adj_weight": False,
    "use_bn": False,
}
cv = CrossValidator(
    dataset = graph_tr,
    model_class = GCN,
    model_kwargs = model_kwargs,
    train_fn = train_epoch,
    eval_fn = evaluate,
    device = device,
    batch_size = 32,
    epochs = 20,
    learning_rate = 1e-3,
    n_splits = 5,
    shuffle = True,
    random_state = 42
)
cv.run()


--- Fold 1/5 ---
Epoch 5: Val = 0.8060792612543286
Epoch 10: Val = 0.8060792612543286
