In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import pandas as pd
from tqdm import tqdm
import os, random




In [2]:
train_df = pd.read_parquet("../data/train/segments.parquet")
test_df = pd.read_parquet("../data/test/segments.parquet")
distances_df = pd.read_csv("../data/distances_3d.csv", index_col=0)

print(train_df.head())
print(distances_df.head())


                            label  start_time  end_time       date  \
patient  session   segment                                           
pqejgcff s001_t000 0            1         0.0      12.0 2003-01-01   
                   1            1        12.0      24.0 2003-01-01   
                   2            1        24.0      36.0 2003-01-01   
                   3            1        36.0      48.0 2003-01-01   
                   4            1        48.0      60.0 2003-01-01   

                            sampling_rate                        signals_path  
patient  session   segment                                                     
pqejgcff s001_t000 0                  250  signals/pqejgcff_s001_t000.parquet  
                   1                  250  signals/pqejgcff_s001_t000.parquet  
                   2                  250  signals/pqejgcff_s001_t000.parquet  
                   3                  250  signals/pqejgcff_s001_t000.parquet  
                   4         

In [3]:
# Pivot to get a square matrix
distances_df_square = distances_df.pivot_table(index='from', columns='to', values='distance')
distances_df_square = distances_df_square.astype(float)

In [4]:
from scipy.spatial.distance import squareform
from scipy.sparse.csgraph import minimum_spanning_tree

# Construction d’un graphe d’adjacence basé sur les distances
def build_adjacency_matrix(distances_df, threshold=0.1):
    mat = distances_df.values
    adj = (mat < threshold).astype(int)
    np.fill_diagonal(adj, 0)
    return adj

adj_matrix = build_adjacency_matrix(distances_df_square, threshold=1.0)

In [6]:
from torch.utils.data import Dataset

class EEGGraphDataset(Dataset):
    def __init__(self, df, root_dir, adj_matrix, transform=None):
        self.df = df.reset_index()
        self.root_dir = root_dir
        self.adj_matrix = torch.tensor(adj_matrix, dtype=torch.long)
        self.edge_index = torch.nonzero(self.adj_matrix).t().contiguous()
        self.transform = transform

        # ⚡ Préchargement
        self.signal_cache = {}
        for path in self.df['signals_path'].unique():
            full_path = os.path.join(root_dir, path)
            self.signal_cache[path] = pd.read_parquet(full_path).T  # (n_channels, time)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        segment = self.signal_cache[row['signals_path']]
        start = int(row['start_time'] * row['sampling_rate'])
        end = int(row['end_time'] * row['sampling_rate'])

        x = torch.tensor(segment.iloc[:, start:end].mean(axis=1).values, dtype=torch.float).unsqueeze(1)
        y = torch.tensor(row['label'], dtype=torch.long)

        data = Data(x=x, edge_index=self.edge_index, y=y)
        return self.transform(data) if self.transform else data


In [7]:
train_dataset = EEGGraphDataset(train_df, "../data/train", adj_matrix)


In [None]:
sample = train_dataset[0]
print(sample)
print(sample.x.shape, sample.edge_index.shape, sample.y)


print("Sample EEG segment:")
print(f"{sample.x.shape[0]} nodes (electrodes) with {sample.x.shape[1]} features")
print(f"{sample.edge_index.shape[1]} edges")
print(f"Label: {sample.y.item()} (0: healthy, 1: unhealthy)")



Data(x=[19, 1], edge_index=[2, 96], y=1)
torch.Size([19, 1]) torch.Size([2, 96]) tensor(1)
Sample EEG segment:
19 nodes (electrodes) with 1 features
96 edges
Label: 1 (0: healthy, 1: unhealthy)


## DataLoader

In [39]:
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader

indices = list(range(len(train_dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)

train_split = torch.utils.data.Subset(train_dataset, train_idx)
val_split = torch.utils.data.Subset(train_dataset, val_idx)

train_loader = DataLoader(train_split, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_split, batch_size=32, shuffle=False, num_workers=4)


## Model

In [None]:
from torch_geometric.nn import GCNConv, global_mean_pool

class NeuroGNN(torch.nn.Module):
    def __init__(self, in_channels=1, hidden_channels=32):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 1)

    def forward(self, batch):
        x, edge_index, batch_index = batch.x, batch.edge_index, batch.batch
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch_index)
        return self.lin(x)


## Training

In [None]:
from sklearn.metrics import accuracy_score, f1_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NeuroGNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, 21):
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation
    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            val_preds.append(pred.cpu())
            val_labels.append(data.y.cpu())

    val_preds = torch.cat(val_preds)
    val_labels = torch.cat(val_labels)

    acc = accuracy_score(val_labels, val_preds)
    f1 = f1_score(val_labels, val_preds)

    print(f"Epoch {epoch:02d} | Loss: {total_loss:.4f} | Val Acc: {acc:.4f} | F1: {f1:.4f}")


Epoch 01 | Loss: 165.8006 | Val Acc: 0.8045 | F1: 0.0342
Epoch 02 | Loss: 156.9076 | Val Acc: 0.8045 | F1: 0.0379
Epoch 03 | Loss: 156.9684 | Val Acc: 0.8045 | F1: 0.0342
Epoch 04 | Loss: 156.5931 | Val Acc: 0.8045 | F1: 0.0305
Epoch 05 | Loss: 156.3801 | Val Acc: 0.8045 | F1: 0.0342
Epoch 06 | Loss: 155.8330 | Val Acc: 0.8034 | F1: 0.0377
Epoch 07 | Loss: 155.2158 | Val Acc: 0.8045 | F1: 0.0342
Epoch 08 | Loss: 154.5912 | Val Acc: 0.8045 | F1: 0.0342
Epoch 09 | Loss: 154.1378 | Val Acc: 0.8045 | F1: 0.0342
Epoch 10 | Loss: 153.8465 | Val Acc: 0.8045 | F1: 0.0342
Epoch 11 | Loss: 154.0876 | Val Acc: 0.8045 | F1: 0.0342
Epoch 12 | Loss: 153.9364 | Val Acc: 0.8045 | F1: 0.0342
Epoch 13 | Loss: 153.7105 | Val Acc: 0.8026 | F1: 0.0000
Epoch 14 | Loss: 155.2776 | Val Acc: 0.8026 | F1: 0.0000


KeyboardInterrupt: 