In [None]:
import os
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from pathlib import Path
from seiz_eeg.dataset import EEGDataset
from torch.utils.data import Dataset
from collections import defaultdict
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.nn import TransformerConv, global_mean_pool

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN to its
    deterministic, non-benchmarking mode.

    Args:
        seed: The integer seed applied to every generator.
    """
    # Exported for hash-based operations; the interpreter reads this variable
    # at start-up, so it matters for processes launched after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,                 # Python random module
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # Torch default generator
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device (multi-GPU)
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels, no auto-tuner (may slow down training).
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# Fix every RNG seed up front so a fresh "Restart & Run All" is repeatable.
seed_everything(seed=42)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Dataset location. Set the EEG_DATA_PATH environment variable to point the
# notebook at another copy of the data; the previous path remains the default.
DATA_PATH = os.environ.get("EEG_DATA_PATH", "/home/ogut/data")
DATA_ROOT = Path(DATA_PATH)

In [None]:
# Segment tables for the train and test splits, read from
# <DATA_ROOT>/<split>/segments.parquet.
clips_tr, clips_te = (
    pd.read_parquet(DATA_ROOT / split / "segments.parquet")
    for split in ("train", "test")
)

In [None]:
# Two nodes are connected when their listed distance is at most MAX_DISTANCE.
MAX_DISTANCE = 1

# Pivot the long (from, to, distance) table into a square matrix, then
# threshold it into a 0/1 integer adjacency matrix. Pairs missing from the CSV
# become NaN after the pivot and compare as "not connected".
adjacency_matrix = (
    pd.read_csv('distances_3d.csv')
    .pivot(index='from', columns='to', values='distance')
    .to_numpy()
    <= MAX_DISTANCE
).astype(int)

In [None]:
class GraphEEGDataset(Dataset):
    """Expose an ``EEGDataset`` as a dataset of ``torch_geometric`` graphs.

    The graph connectivity is derived once from ``adjacency_matrix`` and shared
    by every item: one edge per non-zero entry, with that entry's value as the
    edge attribute.

    Args:
        clips_df: Clip table forwarded to ``EEGDataset``.
        signals_root: Root folder of the signal files, forwarded to ``EEGDataset``.
        adjacency_matrix: Matrix whose non-zero entries define the edges.
        prefetch: Forwarded to ``EEGDataset``.
    """

    def __init__(
        self,
        clips_df: pd.DataFrame,
        signals_root: Path,
        adjacency_matrix: np.ndarray,
        prefetch: bool = True
    ):
        self.dataset = EEGDataset(
            clips_df=clips_df,
            signals_root=signals_root,
            prefetch=prefetch,
        )

        # Locate the non-zero entries once and reuse them for both tensors.
        nonzero_positions = np.nonzero(adjacency_matrix)
        edge_endpoints = np.array(nonzero_positions)
        edge_values = adjacency_matrix[nonzero_positions]

        self.edge_index = torch.tensor(edge_endpoints, dtype=torch.long)
        self.edge_attr = torch.tensor(edge_values, dtype=torch.float32)

    def __len__(self) -> int:
        """Return the number of items in the wrapped ``EEGDataset``."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Data:
        """Return item ``idx`` as a graph with the shared edges and its label."""
        signal, label = self.dataset[idx]

        # The first two axes are swapped so that nodes come first.
        # NOTE(review): assumes the raw signal is (time, channels) - confirm.
        node_features = torch.tensor(signal, dtype=torch.float32).transpose(0, 1)
        target = torch.tensor(label, dtype=torch.long)

        return Data(
            x=node_features,
            edge_index=self.edge_index,
            edge_attr=self.edge_attr,
            y=target,
        )

In [None]:
# Build the full labelled dataset under its own name so that `dataset_tr`
# always means "the training subset" and is never rebound from one kind of
# object to another within the cell.
dataset_full = GraphEEGDataset(
    clips_tr,
    signals_root=DATA_ROOT / "train",
    adjacency_matrix=adjacency_matrix,
    prefetch=True,
)

# 90 % / 10 % train/validation split. A dedicated, explicitly seeded generator
# makes the split independent of how many random numbers were drawn from the
# global RNG beforehand, so re-running this cell always yields the same
# subsets (no train/validation mix-up after a partial re-run).
dataset_tr, dataset_val = torch.utils.data.random_split(
    dataset_full,
    [0.9, 0.1],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Mini-batch loaders: the training set is reshuffled every epoch, the
# validation set is iterated in a fixed order.
loader_tr = DataLoader(
    dataset=dataset_tr,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
loader_val = DataLoader(
    dataset=dataset_val,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

In [None]:
class SlidingWindowBatcher(nn.Module):
    """Cut every graph of a batch into overlapping windows along the feature axis.

    Each input graph has node features of shape (N, T): N nodes with T feature
    columns per node. The feature axis is sliced into K windows of
    ``window_size`` columns taken every ``step_size`` columns (trailing columns
    that do not fill a whole window are dropped). Every window becomes its own
    graph that reuses the edges and the label of the graph it was cut from.

    The returned batch holds B * K graphs in graph-major order: all windows of
    graph 0, then all windows of graph 1, and so on.

    Args:
        window_size: Number of feature columns per window.
        step_size: Offset between the starts of consecutive windows.
    """

    def __init__(self, window_size: int = 50, step_size: int = 25):
        super().__init__()
        self.window_size = window_size
        self.step_size = step_size

    def forward(self, batch: Batch) -> Batch:
        """Return a new batch with one graph per (input graph, window) pair.

        Args:
            batch: Batch of B graphs that all have the same number of nodes.

        Returns:
            Batch of B * K graphs with ``x`` of shape (N, window_size) each.

        Raises:
            ValueError: If the graphs do not all have the same number of nodes.
        """
        x = batch.x
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr
        y = batch.y

        # One host transfer for all graph boundaries instead of two `.item()`
        # synchronisations per graph.
        boundaries = batch.ptr.tolist()
        batch_size = len(boundaries) - 1
        nodes_per_graph = x.size(0) // batch_size

        # The reshape below silently mixes nodes of neighbouring graphs when
        # node counts differ but still divide evenly, so check explicitly.
        if any(
            end - start != nodes_per_graph
            for start, end in zip(boundaries[:-1], boundaries[1:])
        ):
            raise ValueError(
                "SlidingWindowBatcher requires every graph in the batch "
                "to have the same number of nodes"
            )

        # Recover each graph's own edges, re-indexed to start at node 0.
        source_nodes = edge_index[0]
        edges_per_graph = []
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            mask = (source_nodes >= start) & (source_nodes < end)
            local_edge_index = edge_index[:, mask] - start
            local_edge_attr = None if edge_attr is None else edge_attr[mask]
            edges_per_graph.append((local_edge_index, local_edge_attr))

        windows = x.reshape(batch_size, nodes_per_graph, -1)  # (B, N, T)
        windows = windows.unfold(
            dimension=2, size=self.window_size, step=self.step_size
        )  # (B, N, K, window_size)
        windows_per_graph = windows.size(2)
        windows = windows.permute(0, 2, 1, 3)  # (B, K, N, window_size)
        windows = windows.reshape(
            -1, nodes_per_graph, self.window_size
        )  # (B * K, N, window_size)

        data_list = []
        for graph_idx, (local_edge_index, local_edge_attr) in enumerate(edges_per_graph):
            label = None if y is None else y[graph_idx]
            first_window = graph_idx * windows_per_graph
            for window_idx in range(first_window, first_window + windows_per_graph):
                data_list.append(
                    Data(
                        x=windows[window_idx],
                        edge_index=local_edge_index,
                        edge_attr=local_edge_attr,
                        y=label,
                    )
                )
        return Batch.from_data_list(data_list)

In [None]:
class GraphTransformer(nn.Module):
    """Two ``TransformerConv`` layers, mean pooling, and a linear projection.

    Produces one embedding of size ``out_channels`` per graph in the batch.
    Only ``x``, ``edge_index`` and ``batch`` of the input are used; edge
    attributes are not passed to the convolutions.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Per-head output size of the convolutions.
        out_channels: Size of the per-graph embedding.
        heads: Number of attention heads in the first convolution. Its output
            has ``hidden_channels * heads`` features, which is the input size
            of the single-head second convolution.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, heads: int = 4):
        super().__init__()
        self.layer_1 = TransformerConv(in_channels, hidden_channels, heads=heads)
        self.activation_1 = nn.GELU()
        self.layer_2 = TransformerConv(hidden_channels * heads, hidden_channels, heads=1)
        self.activation_2 = nn.GELU()
        self.linear = nn.Linear(hidden_channels, out_channels)

    def forward(self, data: Batch) -> torch.Tensor:
        """Embed every graph of ``data``.

        Args:
            data: Batch of graphs providing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tensor with one row of ``out_channels`` features per graph.
        """
        edge_index = data.edge_index
        x = self.activation_1(self.layer_1(data.x, edge_index))
        x = self.activation_2(self.layer_2(x, edge_index))
        # Average node embeddings within each graph, then project.
        pooled = global_mean_pool(x, data.batch)
        return self.linear(pooled)

In [None]:
class Encoder(nn.Module):
    """Transformer encoder with fixed sinusoidal positions and mean pooling.

    Adds a precomputed sinusoidal positional encoding to the input sequence,
    runs it through a stack of ``nn.TransformerEncoderLayer`` and averages the
    result over the sequence axis.

    Args:
        d_model: Feature size of every sequence element (even or odd).
        nhead: Number of attention heads per layer.
        dim_feedforward: Hidden size of each layer's feed-forward block.
        num_layers: Number of stacked encoder layers.
        max_seq_len: Longest sequence the positional table supports.
    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        num_layers: int = 4,
        max_seq_len: int = 512
    ):
        super().__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_seq_len, ceil(d_model / 2))

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer, not parameter: moves with the module but is never trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_seq_len, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of sequences into one vector per sequence.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, d_model): the sequence-wise mean of the
            encoder output.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.size(1)
        max_seq_len = self.pe.size(1)
        if seq_len > max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        x = x + self.pe[:, :seq_len]
        x = self.encoder(x)
        return x.mean(dim=1)

In [None]:
class SeizureClassifier(nn.Module):
    """Window-level graph embeddings followed by a sequence encoder.

    Pipeline: ``SlidingWindowBatcher`` cuts each input graph into windows,
    ``GraphTransformer`` embeds every window graph into ``d_model`` features,
    ``Encoder`` summarises each graph's sequence of window embeddings, and a
    linear layer maps that summary to a single logit.

    Args:
        window_size: Window length; also the node feature size seen by the
            graph transformer.
        step_size: Offset between consecutive windows.
        graph_hidden_channels: Hidden size of the graph transformer.
        graph_heads: Attention heads of the first graph convolution.
        d_model: Size of window embeddings and of the sequence encoder.
        nhead: Attention heads of the sequence encoder.
        dim_feedforward: Feed-forward size of the sequence encoder.
        num_layers: Number of sequence encoder layers.
        max_seq_len: Maximum number of windows per graph the encoder supports.
    """

    def __init__(
        self,
        window_size: int = 50,
        step_size: int = 25,
        graph_hidden_channels: int = 64,
        graph_heads: int = 4,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        num_layers: int = 4,
        max_seq_len: int = 512,
    ):
        super().__init__()
        self.d_model = d_model
        self.sliding_window_batcher = SlidingWindowBatcher(
            window_size=window_size,
            step_size=step_size
        )

        self.graph_transformer = GraphTransformer(
            in_channels=window_size,
            hidden_channels=graph_hidden_channels,
            out_channels=d_model,
            heads=graph_heads
        )
        self.encoder = Encoder(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            num_layers=num_layers,
            max_seq_len=max_seq_len
        )
        self.classifier = nn.Linear(d_model, 1)

    def forward(self, batch: Batch) -> torch.Tensor:
        """Compute one logit per input graph.

        Args:
            batch: Batch of B graphs; ``x`` stacks the node features of all
                graphs along the first axis.

        Returns:
            Tensor of shape (B, 1) holding raw (pre-sigmoid) logits.
        """
        # Count graphs from the batch structure rather than from the labels,
        # so the forward pass does not depend on `batch.y`.
        batch_size = batch.num_graphs

        # After windowing the batch holds B * K graphs (K windows per graph),
        # each with node features of width `window_size`.
        windowed = self.sliding_window_batcher(batch)
        windows_per_graph = windowed.num_graphs // batch_size

        graph_embeddings = self.graph_transformer(windowed)  # (B * K, D)
        graph_embeddings = graph_embeddings.reshape(
            batch_size, windows_per_graph, self.d_model
        )  # (B, K, D)

        time_series_embeddings = self.encoder(graph_embeddings)  # (B, D)

        logits = self.classifier(time_series_embeddings)  # (B, 1)
        return logits

In [None]:
def train_epoch(model, epoch, loader_tr, optimizer, criterion, metrics):
    """Run one optimisation pass over ``loader_tr`` and record its metrics.

    Accuracy and F1 are computed once over all predictions of the epoch.
    Averaging per-batch F1 scores is not equivalent: F1 does not decompose
    over batches, and batches without positive samples score 0.

    Args:
        model: Module mapping a batch to logits of shape (B, 1) or (B,).
        epoch: Zero-based epoch index, used only for the progress-bar label.
        loader_tr: Iterable of training batches exposing ``.to()`` and ``.y``.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss taking (logits, float targets).
        metrics: Mapping whose ``metrics["train"]`` lists for ``"loss"``,
            ``"accuracy"`` and ``"f1"`` each receive one appended value.
    """
    model.train()
    # Follow the model's own device rather than a notebook-level global.
    run_device = next(model.parameters()).device

    loss_sum = 0.0
    epoch_predictions = []
    epoch_labels = []
    for batch in tqdm(loader_tr, desc=f"Epoch {epoch +1}"):
        batch = batch.to(run_device)
        optimizer.zero_grad()
        logits = model(batch).view(-1)

        loss = criterion(logits, batch.y.float())
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

        # Predictions are the logits produced before this optimisation step.
        with torch.no_grad():
            predictions = torch.round(torch.sigmoid(logits.detach())).long()
        epoch_predictions.append(predictions.cpu())
        epoch_labels.append(batch.y.detach().view(-1).cpu())

    train_loss = loss_sum / len(loader_tr)  # mean of the per-batch losses

    y_true = torch.cat(epoch_labels).numpy()
    y_pred = torch.cat(epoch_predictions).numpy()
    # scikit-learn's argument order is (y_true, y_pred).
    accuracy = float(accuracy_score(y_true, y_pred))
    f1 = float(f1_score(y_true, y_pred))

    metrics["train"]["loss"].append(train_loss)
    metrics["train"]["accuracy"].append(accuracy)
    metrics["train"]["f1"].append(f1)

    print('Train Loss: {:.4f}, '.format(train_loss),'Train Accuracy: {:.4f}, '.format(accuracy),'Train F1-Score: {:.4f}, '.format(f1))

In [None]:
def eval_epoch(model, loader_val, criterion, metrics):
    """Evaluate ``model`` on ``loader_val`` and record its metrics.

    Accuracy and F1 are computed once over all predictions of the pass.
    Averaging per-batch F1 scores is not equivalent: F1 does not decompose
    over batches, and batches without positive samples score 0.

    Args:
        model: Module mapping a batch to logits of shape (B, 1) or (B,).
        loader_val: Iterable of validation batches exposing ``.to()`` and ``.y``.
        criterion: Loss taking (logits, float targets).
        metrics: Mapping whose ``metrics["eval"]`` lists for ``"loss"``,
            ``"acc"`` and ``"f1"`` each receive one appended value. Accuracy
            is also appended under ``"accuracy"``, the key used for training.
    """
    model.eval()
    # Follow the model's own device rather than a notebook-level global.
    run_device = next(model.parameters()).device

    loss_sum = 0.0
    epoch_predictions = []
    epoch_labels = []
    with torch.no_grad():
        for batch in loader_val:
            batch = batch.to(run_device)
            logits = model(batch).view(-1)

            loss_sum += criterion(logits, batch.y.float()).item()
            predictions = torch.round(torch.sigmoid(logits)).long()
            epoch_predictions.append(predictions.cpu())
            epoch_labels.append(batch.y.view(-1).cpu())

    eval_loss = loss_sum / len(loader_val)  # mean of the per-batch losses

    y_true = torch.cat(epoch_labels).numpy()
    y_pred = torch.cat(epoch_predictions).numpy()
    # scikit-learn's argument order is (y_true, y_pred).
    accuracy = float(accuracy_score(y_true, y_pred))
    f1 = float(f1_score(y_true, y_pred))

    metrics["eval"]["loss"].append(eval_loss)
    metrics["eval"]["acc"].append(accuracy)
    # Mirror under the training-side key; "acc" stays for existing readers.
    metrics["eval"].setdefault("accuracy", []).append(accuracy)
    metrics["eval"]["f1"].append(f1)

    print('Val Loss: {:.4f}, '.format(eval_loss),'Val Accuracy: {:.4f}, '.format(accuracy),'Val F1-Score: {:.4f}, '.format(f1))

In [None]:
def train(model, epochs, loader_tr, loader_val, optimizer, criterion, checkpoint_path="./best_model.pt"):
    """Train for ``epochs`` epochs, checkpointing the best validation F1.

    Each epoch runs ``train_epoch`` followed by ``eval_epoch``. Whenever the
    latest validation F1 beats the best seen so far, the model's
    ``state_dict`` is written to ``checkpoint_path``.

    Args:
        model: Module to train.
        epochs: Number of epochs to run.
        loader_tr: Training batches.
        loader_val: Validation batches.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss function.
        checkpoint_path: Destination of the best checkpoint; defaults to the
            previously hard-coded ``./best_model.pt``.

    Returns:
        Dict with ``"train"`` and ``"eval"`` entries, each mapping a metric
        name to its list of per-epoch values.
    """
    metrics = {
        "train": defaultdict(list),
        "eval": defaultdict(list),
    }

    # Start below any reachable score so the first epoch always writes a
    # checkpoint, even when the validation F1 is 0.
    best_f1 = float("-inf")
    for epoch in range(epochs):
        train_epoch(model, epoch, loader_tr, optimizer, criterion, metrics)
        eval_epoch(model, loader_val, criterion, metrics)

        current_f1 = metrics["eval"]["f1"][-1]
        if current_f1 > best_f1:
            best_f1 = current_f1
            torch.save(model.state_dict(), checkpoint_path)
            print(f"✅ Best model saved with F1: {best_f1:.4f} as {Path(checkpoint_path).name}")
    return metrics

In [None]:
# Build the classifier with this run's hyper-parameters and place it on
# `device` (`Module.to` returns the module itself, so the call can be chained).
model = SeizureClassifier(
    # windowing
    window_size=125,
    step_size=62,
    # graph transformer
    graph_hidden_channels=128,
    graph_heads=4,
    # sequence encoder
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    num_layers=8,
    max_seq_len=256,
).to(device)

In [None]:
# Binary cross-entropy applied to raw logits (the sigmoid is inside the loss).
criterion = nn.BCEWithLogitsLoss()
# Adam over all model parameters, without weight decay.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=0,
)

In [None]:
# Train for 10 epochs; `metrics` holds the per-epoch train/validation history.
metrics = train(
    model=model,
    epochs=10,
    loader_tr=loader_tr,
    loader_val=loader_val,
    optimizer=optimizer,
    criterion=criterion,
)