In [None]:
from pathlib import Path
import pandas as pd
from seiz_eeg.dataset import EEGDataset

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torch_geometric.nn import GCNConv

from training import train
from utils import display_metrics, count_parameters, seed_everything
from preprocessing import normalize_z_score

In [None]:
# Change `data_path` to wherever you store the data folder.
# The data folder is expected to have the following structure:
# data
# ├── distances_3d.csv      (electrode distances, read when building the graph)
# ├── train
# │   ├── signals/
# │   └── segments.parquet
# └── test
#     ├── signals/
#     └── segments.parquet

data_path = "data"

DATA_ROOT = Path(data_path)

In [None]:
SEED = 1  # global seed for this run
seed_everything(SEED)

In [None]:
# Segment tables for the train and test splits, read in that order.
clips_tr, clips_te = (
    pd.read_parquet(DATA_ROOT / rel_path)
    for rel_path in ("train/segments.parquet", "test/segments.parquet")
)

In [None]:
"""
Based on https://www.sciencedirect.com/science/article/pii/S1746809422004098#fig4
"""

class SlidingWindowBatch(nn.Module):
    """Split a batch of multichannel signals into overlapping time windows.

    The input ``data`` is unpacked as ``(batch_size, signal_len, channels)``.
    The output has shape ``(batch_size, num_windows, channels, window_size)``, with
    ``num_windows = (signal_len - window_size) // step_size + 1``. Trailing samples
    that do not fill a complete window are dropped.

    The result is a strided view of ``data`` (no copy), so it should not be
    modified in place.

    Args:
        window_size: Number of time samples per window (must be > 0).
        step_size: Stride between the starts of consecutive windows (must be > 0).

    Raises:
        ValueError: If ``window_size`` or ``step_size`` is not positive.
    """

    def __init__(self, window_size=125, step_size=62):
        super().__init__()
        if window_size <= 0 or step_size <= 0:
            raise ValueError(
                f"window_size and step_size must be positive, got {window_size} and {step_size}"
            )
        self.window_size = window_size
        self.step_size = step_size

    def forward(self, data):
        """Return the overlapping windows of ``data`` (see the class docstring for shapes).

        Raises:
            ValueError: If ``data`` is not 3-D or is shorter than one window.
        """
        if data.dim() != 3:
            raise ValueError(
                f"expected data of shape (batch, signal_len, channels), got {tuple(data.shape)}"
            )
        signal_len = data.shape[1]
        if signal_len < self.window_size:
            raise ValueError(
                f"signal length {signal_len} is shorter than window_size {self.window_size}"
            )
        # unfold(dim=1, size, step) slices the time axis and appends the window axis last:
        # (batch, num_windows, channels, window_size). These are the same values as
        # stacking data[:, start:start + window_size, :] and transposing the last two axes.
        return data.unfold(1, self.window_size, self.step_size)

class GATLayer(nn.Module):
    """Two stacked graph convolutions (GCN -> ReLU -> GCN) over an electrode graph.

    NOTE: despite the class name, this uses ``GCNConv`` layers, not ``GATConv``.
    The name is kept so existing callers keep working. ``num_heads_1`` and
    ``num_heads_2`` are accepted for interface compatibility with a GAT variant
    but are not used by the GCN layers.

    Args:
        in_channels: Feature size of each node on input.
        hidden_channels: Feature size of each node after each convolution.
        num_heads_1: Unused (see note above).
        num_heads_2: Unused (see note above).
    """

    def __init__(self, in_channels, hidden_channels=32, num_heads_1=8, num_heads_2=1):
        super().__init__()
        self.gcn1 = GCNConv(in_channels=in_channels, out_channels=hidden_channels)
        self.gcn2 = GCNConv(in_channels=hidden_channels, out_channels=hidden_channels)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Apply ``gcn1`` -> ReLU -> ``gcn2``. No activation is applied after ``gcn2``.

        Args:
            x: Node features. The caller passes ``(num_graphs, num_nodes, in_channels)``,
                with every graph sharing one ``edge_index``. This presumably relies on
                GCNConv treating dim -2 as the node dimension; TODO confirm.
            edge_index: Graph connectivity passed unchanged to both convolutions.

        Returns:
            Node features with last dimension ``hidden_channels``.
        """
        x = self.gcn1(x, edge_index)  # last dim: hidden_channels
        x = self.relu(x)
        return self.gcn2(x, edge_index)  # last dim: hidden_channels

class CombinedModel(nn.Module):
    """Sliding-window graph network followed by a BiLSTM; outputs one logit vector per sample.

    Pipeline, with the shapes produced in ``forward``:

    * ``x``: ``(batch, signal_len, num_nodes)``
    * ``sliding_window`` -> ``(batch, num_windows, num_nodes, window_size)``
    * ``projection`` (Linear) -> ``(batch, num_windows, num_nodes, gat_in_features)``
    * windows folded into the leading axis, ``gat`` applied ->
      ``(batch * num_windows, num_nodes, gat_out_features)``
    * each window flattened, ``fc`` + ReLU -> ``(batch, num_windows, 128)``
    * 3-layer bidirectional LSTM over the window sequence
    * ``output_layer`` on the concatenated final forward/backward hidden states ->
      ``(batch, output_size)``

    Args:
        gat_in_features: Node feature size after projecting each window.
        gat_out_features: Node feature size produced by the graph layers.
        output_size: Number of output logits.
        window_size: Samples per sliding window.
        step_size: Stride between consecutive windows.
        num_heads_1: Forwarded to ``GATLayer``.
        num_heads_2: Forwarded to ``GATLayer``.
        lstm_hidden_size: Hidden size of each LSTM direction.
        num_nodes: Number of channels / graph nodes in ``x`` (default 19, the value
            that was previously hard-coded).
    """

    def __init__(self, gat_in_features=64, gat_out_features=32, output_size=1, window_size=125, step_size=62, num_heads_1=8, num_heads_2=1, lstm_hidden_size=128, num_nodes=19):
        super().__init__()
        self.num_nodes = num_nodes

        self.sliding_window = SlidingWindowBatch(window_size=window_size, step_size=step_size)

        # Project each window (window_size samples per node) to gat_in_features.
        self.projection = nn.Linear(window_size, gat_in_features)

        self.gat = GATLayer(in_channels=gat_in_features, hidden_channels=gat_out_features, num_heads_1=num_heads_1, num_heads_2=num_heads_2)

        # One vector per window: num_nodes * gat_out_features -> 128, the LSTM input size.
        self.fc = nn.Linear(num_nodes * gat_out_features, 128)

        self.bilstm = nn.LSTM(input_size=128, hidden_size=lstm_hidden_size, num_layers=3, batch_first=True, bidirectional=True, dropout=0.5)

        # *2 because the forward and backward final states are concatenated.
        self.output_layer = nn.Linear(lstm_hidden_size * 2, output_size)

        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Return logits of shape ``(batch, output_size)``.

        Args:
            x: Signal batch, unpacked by the sliding window as ``(batch, signal_len, channels)``.
            edge_index: Graph connectivity shared by every window, passed unchanged to ``self.gat``.

        Raises:
            ValueError: If ``x`` does not have ``self.num_nodes`` channels.
        """
        windows = self.sliding_window(x)  # (batch, num_windows, nodes, window_size)
        batch_size, num_windows, nodes, _ = windows.shape
        if nodes != self.num_nodes:
            raise ValueError(f"expected {self.num_nodes} channels, got {nodes}")

        windows_proj = self.projection(windows)  # (batch, num_windows, nodes, gat_in_features)

        # Fold windows into the leading axis so each window is an independent graph
        # sharing the same edge_index.
        graph_batch = windows_proj.reshape(batch_size * num_windows, nodes, -1)
        gat_out = self.gat(graph_batch, edge_index)  # (batch * num_windows, nodes, gat_out_features)

        # Flatten every window's node features into one vector per time step.
        windows_flat = gat_out.reshape(batch_size, num_windows, -1)
        windows_flat = self.relu(self.fc(windows_flat))  # (batch, num_windows, 128)

        # h_n is (num_layers * 2, batch, lstm_hidden_size); its last two entries are the
        # top layer's final forward and backward states. Taking lstm_out[:, -1] instead
        # would give a backward state that has only seen the last window.
        _, (h_n, _) = self.bilstm(windows_flat)
        summary = torch.cat((h_n[-2], h_n[-1]), dim=-1)  # (batch, lstm_hidden_size * 2)

        return self.output_layer(summary)  # (batch, output_size)


In [None]:
# Use Apple-silicon MPS when present (as before), otherwise fall back to CUDA, then CPU,
# so the notebook also runs on non-Apple machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

MAX_DIST = 1  # electrodes closer than this (in distances_3d.csv units) are connected
NUM_EPOCHS = 13
model = CombinedModel(gat_in_features=64, gat_out_features=32, output_size=1, window_size=125, step_size=62, num_heads_1=8, num_heads_2=1, lstm_hidden_size=128)
model.to(device)

In [None]:
# Inspect model size with utils.count_parameters (result shown as cell output).
count_parameters(model)

In [None]:
# You can change the signal_transform, or remove it completely
dataset_tr = EEGDataset(
    clips_tr,
    signals_root=DATA_ROOT / "train",
    #signal_transform=normalize_z_score,
    prefetch=True,  # If your compute does not allow it, you can use `prefetch=False`
)

In [None]:
# Electrode graph: connect every pair of electrodes whose 3-D distance is <= MAX_DIST.
# Diagonal entries (an electrode to itself) produce self-loops whenever they are <= MAX_DIST.
# NOTE(review): `pivot` sorts the 'from'/'to' labels, so node i is the i-th electrode in
# sorted-name order. Confirm this matches the channel order of the signals from EEGDataset.
distance_matrix = torch.tensor(
    pd.read_csv(DATA_ROOT / "distances_3d.csv")
    .pivot(index="from", columns="to", values="distance")
    .to_numpy(),
    device=device,
    dtype=torch.float32,
)
adjacency = (distance_matrix <= MAX_DIST).int()
edge_index = adjacency.nonzero().t()  # (2, num_edges), int64, on `device`

train_set, val_set = torch.utils.data.random_split(dataset_tr, [0.9, 0.1])

# TODO: check worker_init_fn if num_workers is ever raised above 0.
loader_tr = DataLoader(train_set, batch_size=512, shuffle=True, num_workers=0)
# Validation order does not matter, so don't shuffle it.
loader_val = DataLoader(val_set, batch_size=512, shuffle=False, num_workers=0)

# Count labels in the training split to weight positives in the BCE loss.
positives = 0
negatives = 0
for batch in loader_tr:
    labels = batch[1]
    batch_positives = int(labels.sum())
    positives += batch_positives
    negatives += labels.numel() - batch_positives

if positives == 0:
    raise ValueError("training split contains no positive samples; cannot compute pos_weight")
pos_weight = negatives / positives
pos_weight_tensor = torch.tensor([pos_weight], dtype=torch.float32, device=device)

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

In [None]:
# Train for NUM_EPOCHS epochs; the returned metrics are passed to display_metrics below.
metrics = train(
    model, NUM_EPOCHS, device,
    loader_tr, loader_val,
    optimizer, criterion,
    verbose=True,
    edge_index=edge_index,
)

In [None]:
# Report the metrics collected by `train` (helper from utils).
display_metrics(NUM_EPOCHS, metrics)

# Submission

In [None]:
# Test dataset. With return_id=True each sample carries its id instead of a label.
# A `signal_transform` (e.g. fft_filtering) can be added here if one was used for training;
# set prefetch=False if prefetching causes memory issues.
dataset_te = EEGDataset(
    clips_te,
    signals_root=DATA_ROOT / "test",
    prefetch=True,
    return_id=True,
)

# shuffle=False keeps predictions in dataset order.
loader_te = DataLoader(dataset_te, batch_size=512, shuffle=False)

In [None]:
# Generate the Kaggle submission: one binary label per test sample.
SUBMISSION_PATH = "submission_seed1.csv"

# Evaluation mode (e.g. disables the LSTM's inter-layer dropout).
model.eval()

all_predictions = []
all_ids = []

with torch.no_grad():
    for x, ids in loader_te:
        # With return_id=True each batch is (signals, sample ids).
        x = x.to(torch.float32).to(device)

        logits = model(x, edge_index)

        # sigmoid(logit) > 0.5  <=>  logit > 0, so threshold the logits directly.
        predictions = (logits > 0).int().cpu().numpy()

        all_predictions.extend(predictions.flatten().tolist())
        all_ids.extend(ids)

# Kaggle format "id,label". Each raw id is split on "__", underscores are removed inside
# every piece, and the pieces are re-joined with "_"
# (e.g. "abc_s001__0" -> "abcs001_0").
submission_df = pd.DataFrame({"id": all_ids, "label": all_predictions})
submission_df["id"] = submission_df["id"].apply(lambda x: "_".join([txt.replace("_","") for txt in x.split("__")]))
submission_df["label"] = submission_df["label"].astype(int)

submission_df.to_csv(SUBMISSION_PATH, index=False)
print(f"Kaggle submission file generated: {SUBMISSION_PATH}")