In [1]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from itertools import combinations
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import DataLoader
from torch_geometric.utils import train_test_split_edges
from torch_geometric.transforms import RandomNodeSplit
import re

In [2]:
# Load EEG dataset
data_path = 'data/train.csv'  # Update if necessary
df = pd.read_csv(data_path)

# Identify PSD and Coherence columns
psd_columns = [col for col in df.columns if col.startswith('AB.')]
coh_columns = [col for col in df.columns if col.startswith('COH.')]

# Extract unique electrodes from PSD column names
psd_electrodes = set()
for col in psd_columns:
    match = re.search(r'\.([A-Za-z0-9]+)$', col)
    if match:
        psd_electrodes.add(match.group(1))

# Extract unique electrodes from Coherence column names
coh_electrodes = set()
for col in coh_columns:
    match = re.search(r'COH\.[A-Za-z]+\.[A-Za-z]+\.[a-z]+\.([A-Za-z0-9]+)\.[a-z]+\.([A-Za-z0-9]+)', col)
    if match:
        coh_electrodes.add(match.group(1))
        coh_electrodes.add(match.group(2))

# Combine electrodes from PSD and Coherence
electrodes = sorted(psd_electrodes.union(coh_electrodes))

# Create a mapping for electrodes to indices
electrode_to_idx = {e: i for i, e in enumerate(electrodes)}
print(f" Extracted {len(electrodes)} unique electrodes.")

# Extract edges and weights using coherence values
edge_index = []
edge_weights = []
missing_electrodes = set()

for col in coh_columns:
    match = re.search(r'COH\.[A-Za-z]+\.[A-Za-z]+\.[a-z]+\.([A-Za-z0-9]+)\.[a-z]+\.([A-Za-z0-9]+)', col)
    if match:
        src, tgt = match.groups()
        if src in electrode_to_idx and tgt in electrode_to_idx:
            edge_index.append([electrode_to_idx[src], electrode_to_idx[tgt]])
            edge_weights.append(df[col].mean())
        else:
            missing_electrodes.add((src, tgt))

print(f" Successfully created {len(edge_index)} edges.")
print(f" Missing Electrodes for Edges: {missing_electrodes}")

# Convert edge list to tensor
if len(edge_index) > 0:
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_weights = torch.tensor(edge_weights, dtype=torch.float32)
else:
    print(" Warning: No edges found!")

# Convert patient features (PSD & Coherence values)
num_patients = df.shape[0]
num_features_per_electrode = len(psd_columns) // len(electrodes)
patient_features = df[psd_columns].values.reshape(num_patients, len(electrodes), num_features_per_electrode)
patient_features_tensor = torch.tensor(patient_features, dtype=torch.float32)

# Extract labels
disorder_labels = df['main.disorder'].astype('category').cat.codes.values
target_tensor = torch.tensor(disorder_labels, dtype=torch.long)

✅ Extracted 19 unique electrodes.
✅ Successfully created 1026 edges.
⚠️ Missing Electrodes for Edges: set()


In [3]:
# Define GAT model for multiclass classification
class GATModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super(GATModel, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)
    
    def forward(self, x, edge_index, batch_size):
        out = []
        for i in range(batch_size):
            patient_x = x[i]
            patient_out = F.elu(self.conv1(patient_x, edge_index))
            patient_out = self.conv2(patient_out, edge_index)
            patient_out = patient_out.mean(dim=0)
            out.append(patient_out)
        out = torch.stack(out)
        return F.log_softmax(out, dim=1)

# Define training pipeline
in_channels = patient_features_tensor.shape[2]
hidden_channels = 16
out_channels = 7
model = GATModel(in_channels, hidden_channels, out_channels)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)


In [4]:
# Training function
def train():
    model.train()
    optimizer.zero_grad()
    batch_size = 32
    num_batches = num_patients // batch_size
    total_loss = 0
    
    for i in range(num_batches):
        batch_x = patient_features_tensor[i * batch_size : (i + 1) * batch_size]
        batch_y = target_tensor[i * batch_size : (i + 1) * batch_size]
        output = model(batch_x, edge_index, batch_size)
        loss = F.nll_loss(output, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    return total_loss / num_batches

train_losses = []
val_losses = []

# Training loop
epochs = 100
for epoch in range(epochs):
    loss = train()
    train_losses.append(loss)  # Store training loss
    
    # Validation loss computation
    model.eval()
    with torch.no_grad():
        batch_x = patient_features_tensor[:32]  # Use a small validation set
        batch_y = target_tensor[:32]
        output = model(batch_x, edge_index, batch_size=32)
        val_loss = F.nll_loss(output, batch_y).item()
        val_losses.append(val_loss)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Train Loss {loss:.4f}, Validation Loss {val_loss:.4f}")

Epoch 0: Loss 9.8500
Epoch 10: Loss 2.0543
Epoch 20: Loss 1.8634
Epoch 30: Loss 1.8595
Epoch 40: Loss 1.8463
Epoch 50: Loss 1.9777
Epoch 60: Loss 1.8286
Epoch 70: Loss 1.8606
Epoch 80: Loss 1.8363
Epoch 90: Loss 2.1565
✅ GAT model training complete and saved!


In [None]:
import matplotlib.pyplot as plt

# Plot Training and Validation Loss
plt.figure(figsize=(8, 5))
plt.plot(range(epochs), train_losses, label="Training Loss", marker="o", linestyle="-")
plt.plot(range(epochs), val_losses, label="Validation Loss", marker="s", linestyle="--")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training & Validation Loss Over Epochs")
plt.legend()
plt.grid(True)
plt.show()
