In [1]:
#!pip install torch torch-geometric


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import numpy as np


In [3]:
#!pip install tqdm
#!pip install gputil

from tqdm import tqdm


import GPUtil
import time

def monitor_gpu(seconds=5):
    """Monitor GPU usage every n seconds"""
    try:
        while True:
            GPUtil.showUtilization()
            print('-' * 40)
            time.sleep(seconds)
    except KeyboardInterrupt:
        print("Monitoring stopped.")

# Check if GPU is available and print info
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Usage:")
    print(f"Allocated: {round(torch.cuda.memory_allocated(0)/1024**3, 1)} GB")
    print(f"Cached: {round(torch.cuda.memory_reserved(0)/1024**3, 1)} GB")

# To use, run in a separate cell and press Ctrl+C to stop when done
# monitor_gpu()

Using device: cpu


In [4]:
import pandas as pd
import pyarrow.parquet as pq
import torch
from torch_geometric.data import Data

In [5]:
data_path = '../data'

distances = pd.read_csv(f'{data_path}/distances_3d.csv')

train_segments = pq.read_table(f'{data_path}/train/segments.parquet').to_pandas()

train_segments

Unnamed: 0_level_0,label,start_time,end_time,date,sampling_rate,signals_path
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
pqejgcff_s001_t000_0,1,0.0,12.0,2003-01-01,250,signals/pqejgcff_s001_t000.parquet
pqejgcff_s001_t000_1,1,12.0,24.0,2003-01-01,250,signals/pqejgcff_s001_t000.parquet
pqejgcff_s001_t000_2,1,24.0,36.0,2003-01-01,250,signals/pqejgcff_s001_t000.parquet
pqejgcff_s001_t000_3,1,36.0,48.0,2003-01-01,250,signals/pqejgcff_s001_t000.parquet
pqejgcff_s001_t000_4,1,48.0,60.0,2003-01-01,250,signals/pqejgcff_s001_t000.parquet
...,...,...,...,...,...,...
pqejgvqb_s001_t013_8,1,96.0,108.0,2015-01-01,250,signals/pqejgvqb_s001_t013.parquet
pqejgvqb_s001_t013_9,1,108.0,120.0,2015-01-01,250,signals/pqejgvqb_s001_t013.parquet
pqejgvqb_s001_t013_10,1,120.0,132.0,2015-01-01,250,signals/pqejgvqb_s001_t013.parquet
pqejgvqb_s001_t013_11,1,132.0,144.0,2015-01-01,250,signals/pqejgvqb_s001_t013.parquet


We take a look at what the data looks like

In [6]:
display(
    pd.read_parquet(f"{data_path}/train/{train_segments.iloc[0]["signals_path"]}").iloc[
        12000:12015
    ]
)

Unnamed: 0,FP1,FP2,F3,F4,C3,C4,P3,P4,O1,O2,F7,F8,T3,T4,T5,T6,FZ,CZ,PZ
12000,59.718113,42.933436,14.85725,7.533027,-39.158892,-11.998233,-9.25165,-40.684771,42.933436,-21.458687,33.778158,10.889962,12.415842,11.500314,-28.78291,-25.425974,19.434889,-44.041707,-35.191604
12001,59.942979,44.684182,10.199301,6.53719,-39.544377,-12.99407,-6.585376,-40.154729,43.768654,-20.623469,33.087497,12.640709,14.471764,14.471764,-26.726988,-24.28558,15.082116,-48.089303,-35.882266
12002,63.075048,45.374844,7.533027,7.533027,-45.872762,-11.387881,-9.556825,-39.158892,44.459316,-15.355168,35.304038,15.162426,7.533027,22.791824,-32.750197,-19.932807,14.552074,-50.450401,-38.853716
12003,63.026862,45.326658,11.146953,9.315897,-48.667531,-10.825715,-13.267122,-39.512253,45.021482,-14.487826,34.6455,17.555647,8.095193,25.795398,-34.629438,-19.370641,13.283184,-52.939995,-39.512253
12004,61.227931,45.969134,7.211789,6.906613,-48.025056,-12.014295,-14.150526,-40.395657,47.495013,-15.37123,34.372448,17.282595,13.92566,27.658577,-26.357564,-20.864397,11.179077,-54.738926,-41.311185
12005,58.352852,45.840638,2.505655,5.252239,-48.153551,-13.363494,-12.447966,-42.050032,47.061342,-17.330781,35.769832,18.374804,18.985156,27.530082,-22.823948,-20.077364,11.355757,-53.646718,-41.134504
12006,59.830546,45.792453,5.204053,7.64546,-48.812088,-10.970272,-14.022031,-42.098218,46.707981,-16.768615,36.942351,18.93697,13.138627,32.059536,-28.670476,-20.430726,11.612748,-53.389727,-42.70857
12007,60.97094,45.712143,10.922086,10.311734,-48.892398,-9.829878,-16.848924,-40.957823,46.322495,-15.323045,38.082745,18.551485,9.701382,33.505106,-35.464657,-21.426563,13.66867,-54.385565,-44.619935
12008,58.706214,43.447417,15.981582,8.65736,-45.969134,-10.873901,-16.367067,-42.917374,45.583648,-19.418827,38.259426,14.760878,15.676406,28.798972,-28.879281,-24.911994,15.676406,-53.598532,-42.612198
12009,57.421262,43.993521,11.644871,5.846529,-42.676446,-14.600259,-13.684732,-43.89715,45.214225,-21.619306,39.721058,14.086279,20.189798,25.682965,-23.145186,-24.671065,14.391455,-51.221372,-42.676446


In [7]:
test_segments = pq.read_table(f'{data_path}/test/segments.parquet').to_pandas()
test_segments

Unnamed: 0_level_0,start_time,end_time,date,sampling_rate,signals_path
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
pqejgcvm_s001_t000_0,0.0,12.0,2002-01-01,250,signals/pqejgcvm_s001_t000.parquet
pqejgcvm_s001_t000_1,12.0,24.0,2002-01-01,250,signals/pqejgcvm_s001_t000.parquet
pqejgcvm_s001_t000_2,24.0,36.0,2002-01-01,250,signals/pqejgcvm_s001_t000.parquet
pqejgcvm_s001_t000_3,36.0,48.0,2002-01-01,250,signals/pqejgcvm_s001_t000.parquet
pqejgcvm_s001_t000_4,48.0,60.0,2002-01-01,250,signals/pqejgcvm_s001_t000.parquet
...,...,...,...,...,...
pqejgvej_s001_t000_153,1836.0,1848.0,2015-01-01,250,signals/pqejgvej_s001_t000.parquet
pqejgvej_s001_t000_154,1848.0,1860.0,2015-01-01,250,signals/pqejgvej_s001_t000.parquet
pqejgvej_s001_t000_155,1860.0,1872.0,2015-01-01,250,signals/pqejgvej_s001_t000.parquet
pqejgvej_s001_t000_156,1872.0,1884.0,2015-01-01,250,signals/pqejgvej_s001_t000.parquet


In [8]:
display(
    pd.read_parquet(f"{data_path}/test/{test_segments.iloc[0]['signals_path']}")
    .iloc[12000:12015]
)

Unnamed: 0,FP1,FP2,F3,F4,C3,C4,P3,P4,O1,O2,F7,F8,T3,T4,T5,T6,FZ,CZ,PZ
12000,-215.871822,-70.302899,1085.398383,-195.425034,1085.398383,-144.460652,598.947935,-252.798111,-1450.918848,-555.532643,25.522346,-136.526078,1095.774364,-420.339702,998.42324,-791.433644,-368.154616,-459.402222,171.701621
12001,-205.57615,-71.603912,1065.176461,-173.2275,1086.538777,-132.944276,589.407172,-244.943846,-1439.402472,-546.76285,8.65736,-150.339305,1095.388879,-414.926844,1004.446449,-787.24149,-370.981509,-485.422486,173.757543
12002,-231.275176,-30.774584,1073.351964,-154.981191,1093.798752,-125.989477,573.1686,-237.683871,-1449.537526,-582.837858,-6.970861,-123.54807,1095.019456,-394.239128,1016.894415,-812.63534,-360.97495,-515.699152,174.913999
12003,-258.034288,-49.599121,1082.298437,-143.59331,1098.777938,-119.179235,580.589193,-234.53574,-1457.070553,-585.48807,-11.452129,-110.023957,1116.783319,-426.796582,1025.535713,-843.666915,-352.944005,-489.052474,177.451778
12004,-214.972356,-61.468859,1074.701163,-156.683752,1084.161617,-109.686657,583.978253,-234.503616,-1464.362651,-572.638557,0.787033,-116.095352,1107.660164,-447.211246,1005.7314,-848.21243,-355.96364,-445.38019,170.159679
12005,-191.232881,-23.386114,1075.247267,-144.540962,1076.467971,-97.543867,558.584402,-234.262688,-1484.873687,-590.403009,-5.68591,-78.012607,1056.936711,-413.095788,960.80629,-826.30401,-343.82085,-457.3463,162.466033
12006,-198.926527,-37.793631,1086.474529,-77.161327,1089.221113,-82.95967,537.463015,-229.138945,-1507.82613,-630.140129,-10.02262,-75.635447,1063.281158,-403.394406,950.36606,-827.283786,-324.659014,-488.233317,166.369073
12007,-205.013984,-54.25707,1081.3026,-76.534913,1077.945665,-87.521247,531.070382,-234.310874,-1516.049818,-635.92241,4.641887,-79.586673,1087.711295,-377.438389,963.504688,-813.839982,-332.882702,-487.912079,155.093625
12008,-210.539274,-1.18858,1052.583938,-157.133485,1053.499466,-110.746742,536.226249,-245.024155,-1504.790432,-594.14543,4.304587,-59.78236,1081.575652,-343.290808,963.167388,-779.692401,-359.159957,-465.971535,140.10788
12009,-214.683242,-19.065465,1056.984896,-172.568963,1053.017609,-122.214933,540.322031,-258.933753,-1479.332334,-574.180499,-0.144557,-66.06256,1096.352593,-356.284878,993.203125,-776.817322,-358.115934,-476.219022,134.743208


In [9]:

distances

Unnamed: 0,from,to,distance
0,FP1,FP1,0.000000
1,FP1,FP2,0.618000
2,FP1,F3,0.618969
3,FP1,F4,1.030322
4,FP1,C3,1.250226
...,...,...,...
356,PZ,T5,1.081066
357,PZ,T6,1.081066
358,PZ,FZ,1.414200
359,PZ,CZ,0.765363


In [10]:
def create_edge_index(distances_df, threshold=1):
    """Create edge index based on electrode distances."""

    unique_electrodes = set()
    for idx, row in distances_df.iterrows():
        unique_electrodes.add(row['from'])
        unique_electrodes.add(row['to'])
    
    electrode_to_idx = {electrode: idx for idx, electrode in enumerate(unique_electrodes)}
    
    edge_index = []
    
    for idx, row in distances_df.iterrows():
        source_name = row['from']
        target_name = row['to']
        distance = row['distance']
        
        source_idx = electrode_to_idx[source_name]
        target_idx = electrode_to_idx[target_name]
        
        if distance < threshold:
            edge_index.append([source_idx, target_idx])
            edge_index.append([target_idx, source_idx])
    
    if not edge_index:  
        print("Warning: No edges were created with the current threshold. Try increasing the threshold.")
        # Create a fallback edge_index with at least one edge
        if len(electrode_to_idx) > 1:
            edge_index = [[0, 1], [1, 0]]
    
    return torch.tensor(edge_index, dtype=torch.long).t().contiguous()

In [11]:
def create_edge_index_and_attr(distances_df, connection_threshold=1, use_all_connections=True):
    """
    Create edge index and edge attributes based on electrode distances.
    
    Args:
        distances_df: DataFrame with electrode distances
        connection_threshold: Threshold for creating edges
        use_all_connections: If True, use all connections with distance-based weights
                            If False, only use connections below threshold
    """
    # First, create a mapping from electrode names to indices
    unique_electrodes = set()
    for idx, row in distances_df.iterrows():
        unique_electrodes.add(row['from'])
        unique_electrodes.add(row['to'])
    
    electrode_to_idx = {electrode: idx for idx, electrode in enumerate(sorted(unique_electrodes))}
    
    edge_index = []
    edge_attr = []
    

    all_distances = distances_df['distance'].values
    max_distance = np.max(all_distances)
    min_distance = np.min(all_distances)
    

    for idx, row in distances_df.iterrows():
        source_name = row['from']
        target_name = row['to']
        distance = row['distance']
        
        source_idx = electrode_to_idx[source_name]
        target_idx = electrode_to_idx[target_name]
        
        norm_distance = (distance - min_distance) / (max_distance - min_distance)
        
        inverse_dist = 1.0 / (norm_distance + 0.001)  # Avoid division by zero
        gaussian_weight = np.exp(-norm_distance**2)
        
        edge_features = [norm_distance, inverse_dist, gaussian_weight]
        
        if use_all_connections or distance < connection_threshold:
            edge_index.append([source_idx, target_idx])
            edge_attr.append(edge_features)
            
            edge_index.append([target_idx, source_idx])
            edge_attr.append(edge_features)
    
    if not edge_index:  # Check if edge_index is empty
        print("Warning: No edges were created. Using fallback edges.")
        if len(electrode_to_idx) > 1:
            edge_index = [[0, 1], [1, 0]]
            edge_attr = [[0.5, 2.0, 0.6], [0.5, 2.0, 0.6]]
    
    return (
        torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
        torch.tensor(edge_attr, dtype=torch.float)
    )

In [12]:
def extract_eeg_features(eeg_data, sampling_rate=250):
    """Extract meaningful features from EEG data for each electrode."""
    features = []
    
    for column in eeg_data.columns:
        signal = eeg_data[column].values
        
        mean = np.mean(signal)
        std = np.std(signal)
        min_val = np.min(signal)
        max_val = np.max(signal)
        
        signal_fft = np.abs(np.fft.rfft(signal))
        freq = np.fft.rfftfreq(len(signal), d=1/sampling_rate)
        
        # Extract power in standard EEG frequency bands
        # Delta: 0.5-4 Hz (deep sleep)
        delta_idx = np.logical_and(freq >= 0.5, freq < 4)
        delta_power = np.sum(signal_fft[delta_idx]**2)
        
        # Theta: 4-8 Hz (drowsiness)
        theta_idx = np.logical_and(freq >= 4, freq < 8)
        theta_power = np.sum(signal_fft[theta_idx]**2)
        
        # Alpha: 8-13 Hz (relaxed wakefulness)
        alpha_idx = np.logical_and(freq >= 8, freq < 13)
        alpha_power = np.sum(signal_fft[alpha_idx]**2)
        
        # Beta: 13-30 Hz (active thinking)
        beta_idx = np.logical_and(freq >= 13, freq < 30)
        beta_power = np.sum(signal_fft[beta_idx]**2)
        
        # Gamma: 30+ Hz (cognitive processing)
        gamma_idx = freq >= 30
        gamma_power = np.sum(signal_fft[gamma_idx]**2)
        
        # Dominant frequency (frequency with maximum power)
        dom_freq = freq[np.argmax(signal_fft)]
        
        # Spectral edge frequency (95% of power is below this frequency)
        total_power = np.sum(signal_fft**2)
        cumulative_power = np.cumsum(signal_fft**2)
        spectral_edge_idx = np.where(cumulative_power >= 0.95 * total_power)[0]
        spectral_edge = freq[spectral_edge_idx[0]] if len(spectral_edge_idx) > 0 else freq[-1]
        
        # Band power ratios 
        total_band_power = delta_power + theta_power + alpha_power + beta_power + gamma_power
        if total_band_power > 0:
            delta_ratio = delta_power / total_band_power
            theta_ratio = theta_power / total_band_power
            alpha_ratio = alpha_power / total_band_power
            beta_ratio = beta_power / total_band_power
            gamma_ratio = gamma_power / total_band_power
        else:
            delta_ratio = theta_ratio = alpha_ratio = beta_ratio = gamma_ratio = 0
        
        # Combining all features
        node_feature = [
            mean, std, min_val, max_val,
            delta_power, theta_power, alpha_power, beta_power, gamma_power,
            dom_freq, spectral_edge,
            delta_ratio, theta_ratio, alpha_ratio, beta_ratio, gamma_ratio
        ]
        features.append(node_feature)
    
    return torch.tensor(features, dtype=torch.float)

In [13]:
import torch_geometric.nn as pyg_nn
from torch_geometric.nn import GATConv

class EnhancedGATWithEdgeFeatures(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, edge_dim=3, num_heads=8, dropout=0.3):
        super(EnhancedGATWithEdgeFeatures, self).__init__()
        
        # Define attention layers with edge features
        self.gat1 = pyg_nn.GATv2Conv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            heads=num_heads,
            dropout=dropout,
            concat=True,
            edge_dim=edge_dim  # Include edge features
        )
        
        self.bn1 = nn.BatchNorm1d(hidden_dim * num_heads)
        
        self.gat2 = pyg_nn.GATv2Conv(
            in_channels=hidden_dim * num_heads,
            out_channels=hidden_dim, 
            heads=num_heads,
            dropout=dropout,
            concat=True,
            edge_dim=edge_dim  # Include edge features
        )
        
        self.bn2 = nn.BatchNorm1d(hidden_dim * num_heads)
        
        self.gat3 = pyg_nn.GATv2Conv(
            in_channels=hidden_dim * num_heads, 
            out_channels=hidden_dim,
            heads=1, 
            dropout=dropout,
            concat=False,
            edge_dim=edge_dim  # Include edge features
        )
        
        # Graph-level readout
        self.linear = nn.Linear(hidden_dim, output_dim)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr):
        # First GAT layer
        x = self.gat1(x, edge_index, edge_attr)
        x = self.bn1(x)
        x = F.elu(x)
        x = self.dropout(x)
        
        # Second GAT layer
        x = self.gat2(x, edge_index, edge_attr)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)
        
        # Final GAT layer
        x = self.gat3(x, edge_index, edge_attr)
        
        # Global pooling (graph-level prediction)
        x = torch.mean(x, dim=0)
        
        # Ensure correct shape for the linear layer
        if x.dim() == 1:
            x = x.unsqueeze(0)
            
        x = self.linear(x)  # Final prediction
        
        # For binary classification (epilepsy detection)
        return torch.sigmoid(x)

In [14]:
def train_gat_model(model, train_data_list, val_data_list, epochs=100, lr=0.001):
    """Train the GAT model for epilepsy detection."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on: {device}")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCELoss()
    
    best_val_loss = float('inf')
    best_model = None
    
    # Training loop
    for epoch in tqdm(range(epochs), desc="Training Progress"):
        # Add this at the beginning of your training loop
        print(f"Sample data dimensions:")
        for i, data in enumerate([train_data_list[0]]):
            print(f"  x shape: {data.x.shape}")
            print(f"  edge_index shape: {data.edge_index.shape}")
            print(f"  y shape: {data.y.shape}")
            print(f"  Model output shape: {model(data.x, data.edge_index).shape}")
            break
        model.train()
        epoch_loss = 0
        
        # Process each graph separately (batch size = 1 for graph data)
        for i, data in enumerate(train_data_list):
            optimizer.zero_grad()
            
            # Move data to device
            x = data.x.to(device)
            edge_index = data.edge_index.to(device)
            y = data.y.to(device)
            
            # Forward pass
            out = model(x, edge_index)
            
            # Handle dimension matching
            if out.shape != y.shape:
                if out.dim() > y.dim():
                    y = y.unsqueeze(-1)
                elif y.dim() > out.dim():
                    out = out.unsqueeze(-1)
            
            loss = loss_fn(out, y)
            epoch_loss += loss.item()
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Print mini-batch progress occasionally
            if (i+1) % 500 == 0:
                tqdm.write(f'  Batch {i+1}/{len(train_data_list)}, Loss: {loss.item():.4f}')
            
        
        # Calculate average epoch loss
        avg_epoch_loss = epoch_loss / len(train_data_list)
        
        # Validation
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for val_data in val_data_list:
                val_x = val_data.x.to(device)
                val_edge_index = val_data.edge_index.to(device)
                val_y = val_data.y.to(device)
                
                val_out = model(val_x, val_edge_index)
                
                # Handle dimension matching
                if val_out.shape != val_y.shape:
                    if val_out.dim() > val_y.dim():
                        val_y = val_y.unsqueeze(-1)
                    elif val_y.dim() > val_out.dim():
                        val_out = val_out.unsqueeze(-1)
                
                val_loss += loss_fn(val_out, val_y).item()
                
                # Calculate validation accuracy
                val_preds = (val_out > 0.5).float()
                correct += (val_preds == val_y).sum().item()
                total += val_y.numel()
        
        avg_val_loss = val_loss / len(val_data_list)
        val_accuracy = correct / total
        
        # Save the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model = model.state_dict()
        
        # Print progress each epoch
        tqdm.write(f'Epoch {epoch+1}/{epochs}, Loss: {avg_epoch_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}')
    
    # Load the best model
    model.load_state_dict(best_model)
    return model

In [15]:


def train_gat_model_with_edge_attr(model, train_data_list, val_data_list, epochs=100, lr=0.001):
    """Train the GAT model with edge attributes."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on: {device}")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCELoss()
    
    best_val_loss = float('inf')
    best_model = None
    
    # Training loop
    for epoch in tqdm(range(epochs), desc="Training Progress"):
        # Debug information
        print(f"Sample data dimensions:")
        for i, data in enumerate([train_data_list[0]]):
            print(f"  x shape: {data.x.shape}")
            print(f"  edge_index shape: {data.edge_index.shape}")
            print(f"  edge_attr shape: {data.edge_attr.shape}")
            print(f"  y shape: {data.y.shape}")
            print(f"  Model output shape: {model(data.x, data.edge_index, data.edge_attr).shape}")
            break
            
        model.train()
        epoch_loss = 0
        
        # Process each graph separately
        for i, data in enumerate(train_data_list):
            optimizer.zero_grad()
            
            # Move data to device
            x = data.x.to(device)
            edge_index = data.edge_index.to(device)
            edge_attr = data.edge_attr.to(device)
            y = data.y.to(device)
            
            # Forward pass
            out = model(x, edge_index, edge_attr)
            
            # Handle dimension matching
            if out.shape != y.shape:
                if out.dim() > y.dim():
                    y = y.unsqueeze(-1)
                elif y.dim() > out.dim():
                    out = out.unsqueeze(-1)
            
            loss = loss_fn(out, y)
            epoch_loss += loss.item()
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Print mini-batch progress occasionally
            if (i+1) % 100 == 0:
                tqdm.write(f'  Batch {i+1}/{len(train_data_list)}, Loss: {loss.item():.4f}')
        
        # Calculate average epoch loss
        avg_epoch_loss = epoch_loss / len(train_data_list)
        
        # Validation
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for val_data in val_data_list:
                val_x = val_data.x.to(device)
                val_edge_index = val_data.edge_index.to(device)
                val_edge_attr = val_data.edge_attr.to(device)
                val_y = val_data.y.to(device)
                
                val_out = model(val_x, val_edge_index, val_edge_attr)
                
                # Handle dimension matching
                if val_out.shape != val_y.shape:
                    if val_out.dim() > val_y.dim():
                        val_y = val_y.unsqueeze(-1)
                    elif val_y.dim() > val_out.dim():
                        val_out = val_out.unsqueeze(-1)
                
                val_loss += loss_fn(val_out, val_y).item()
                
                # Calculate validation accuracy
                val_preds = (val_out > 0.5).float()
                correct += (val_preds == val_y).sum().item()
                total += val_y.numel()
        
        avg_val_loss = val_loss / len(val_data_list)
        val_accuracy = correct / total
        
        # Save the best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model = model.state_dict()
        
        # Print progress each epoch
        tqdm.write(f'Epoch {epoch+1}/{epochs}, Loss: {avg_epoch_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}')
    
    # Load the best model
    model.load_state_dict(best_model)
    return model

In [16]:
# Check what's inside a segment to debug the path issue
print("Checking first segment:")
print(train_segments.iloc[0])
print("Full expected path:", f"{data_path}/train/{train_segments.iloc[0]['signals_path']}")

# Check if directory exists
import os
path_to_check = f"{data_path}/train"
print(f"Directory {path_to_check} exists:", os.path.exists(path_to_check))

# List directory contents
if os.path.exists(path_to_check):
    print("Contents:", os.listdir(path_to_check)[:10])  # Show first 10 items

Checking first segment:
label                                             1
start_time                                      0.0
end_time                                       12.0
date                            2003-01-01 00:00:00
sampling_rate                                   250
signals_path     signals/pqejgcff_s001_t000.parquet
Name: pqejgcff_s001_t000_0, dtype: object
Full expected path: ../data/train/signals/pqejgcff_s001_t000.parquet
Directory ../data/train exists: True
Contents: ['segments.parquet', 'signals']


In [17]:
def prepare_eeg_data(segments_df, data_path, distances_df, threshold=0.1):
    """Process EEG segments and prepare them for GAT model."""
    all_data = []
    
    # Create electrode mapping once
    edge_index = create_edge_index(distances_df, threshold)
    print(f"Created edge_index with shape {edge_index.shape}")
    
    for idx, segment in segments_df.iterrows():
        try:
            # Load EEG data for this segment
            eeg_data = pd.read_parquet(f"{data_path}/{segment['signals_path']}")
            
            # Create node features from EEG data
            node_features = extract_eeg_features(eeg_data)
            
            # Create label (1 for seizure, 0 for non-seizure)
            if 'seizure' in segment.index:
                label = torch.tensor([[1.0 if segment['seizure'] else 0.0]]).float()  # Shape: [1, 1]
            else:
                label = torch.tensor([[0.0]]).float()  # Shape: [1, 1] 
            
            # Create PyTorch Geometric Data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                y=label
            )
            
            all_data.append(data)
            
            # Print shapes for debug
            if idx == 0:
                print(f"Sample data shapes - x: {data.x.shape}, edge_index: {data.edge_index.shape}, y: {data.y.shape}")
                
        except Exception as e:
            print(f"Error processing segment {idx}: {e}")
            continue
    
    return all_data

In [18]:
def prepare_eeg_data_with_edge_features(segments_df, data_path, distances_df, threshold=0.3):
    """Process EEG segments and prepare them for GAT model with edge features."""
    all_data = []
    
    # Create electrode mapping and edge attributes once
    edge_index, edge_attr = create_edge_index_and_attr(distances_df, threshold)
    print(f"Created edge_index with shape {edge_index.shape}")
    print(f"Created edge_attr with shape {edge_attr.shape}")

    counter = 0
    for idx, segment in tqdm(segments_df.iterrows(), total=len(segments_df), desc="Processing EEG segments"):
        try:
            # Load EEG data for this segment
            eeg_data = pd.read_parquet(f"{data_path}/{segment['signals_path']}")
            
            # Create node features from EEG data
            node_features = extract_eeg_features(eeg_data)
            
            # Create label (1 for seizure, 0 for non-seizure)
            if 'seizure' in segment.index:
                label = torch.tensor([[1.0 if segment['seizure'] else 0.0]]).float()
            else:
                label = torch.tensor([[0.0]]).float()
            
            # Create PyTorch Geometric Data object with edge attributes
            data = Data(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=label
            )
            
            all_data.append(data)
            
            # Print shapes for the first few samples
            if counter < 3:  
                print(f"Sample {idx} shapes - x: {data.x.shape}, edge_index: {data.edge_index.shape}, edge_attr: {data.edge_attr.shape}")
                
            counter += 1
        except Exception as e:
            print(f"Error processing segment {idx}: {e}")
            continue
    
    return all_data

In [19]:
def extract_eeg_features_optimized(eeg_data, sampling_rate=256):
    """Extract meaningful features from EEG data for each electrode - optimized version."""
    # Preallocate output array for better performance
    features = np.zeros((len(eeg_data.columns), 16), dtype=np.float32)
    
    # Extract all signals at once as a numpy array
    all_signals = eeg_data.values.T  # Transpose to get electrode x timepoints
    
    # Calculate FFTs in batch (much faster than one-by-one)
    all_ffts = np.abs(np.fft.rfft(all_signals, axis=1))
    freq = np.fft.rfftfreq(all_signals.shape[1], d=1/sampling_rate)
    
    # Prepare frequency band indices once
    delta_idx = np.logical_and(freq >= 0.5, freq < 4)
    theta_idx = np.logical_and(freq >= 4, freq < 8)
    alpha_idx = np.logical_and(freq >= 8, freq < 13)
    beta_idx = np.logical_and(freq >= 13, freq < 30)
    gamma_idx = freq >= 30
    
    # Process each electrode
    for i, column in enumerate(eeg_data.columns):
        signal = all_signals[i]
        signal_fft = all_ffts[i]
        
        # Basic statistical features
        features[i, 0] = np.mean(signal)
        features[i, 1] = np.std(signal)
        features[i, 2] = np.min(signal)
        features[i, 3] = np.max(signal)
        
        # Power in frequency bands
        features[i, 4] = np.sum(signal_fft[delta_idx]**2)  # Delta
        features[i, 5] = np.sum(signal_fft[theta_idx]**2)  # Theta
        features[i, 6] = np.sum(signal_fft[alpha_idx]**2)  # Alpha
        features[i, 7] = np.sum(signal_fft[beta_idx]**2)   # Beta
        features[i, 8] = np.sum(signal_fft[gamma_idx]**2)  # Gamma
        
        # Dominant frequency
        features[i, 9] = freq[np.argmax(signal_fft)]
        
        # Spectral edge
        total_power = np.sum(signal_fft**2)
        if total_power > 0:
            cumulative_power = np.cumsum(signal_fft**2)
            spectral_edge_idx = np.where(cumulative_power >= 0.95 * total_power)[0]
            features[i, 10] = freq[spectral_edge_idx[0]] if len(spectral_edge_idx) > 0 else freq[-1]
        else:
            features[i, 10] = 0
        
        # Band power ratios
        band_total = features[i, 4] + features[i, 5] + features[i, 6] + features[i, 7] + features[i, 8]
        if band_total > 0:
            features[i, 11] = features[i, 4] / band_total  # Delta ratio
            features[i, 12] = features[i, 5] / band_total  # Theta ratio
            features[i, 13] = features[i, 6] / band_total  # Alpha ratio
            features[i, 14] = features[i, 7] / band_total  # Beta ratio
            features[i, 15] = features[i, 8] / band_total  # Gamma ratio
    
    return torch.tensor(features, dtype=torch.float)

In [20]:
# Check if GPU is available and print info
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Usage:")
    print(f"Allocated: {round(torch.cuda.memory_allocated(0)/1024**3, 1)} GB")
    print(f"Cached: {round(torch.cuda.memory_reserved(0)/1024**3, 1)} GB")



Using device: cpu


In [21]:
# Test with just a few samples first
sample_size = 5  # Start with just 5 segments
sample_segments = train_segments.iloc[:sample_size]
print(f"Testing with {len(sample_segments)} segments")

# Replace the extract_eeg_features function with the optimized version
# This is just for the cell that runs this test
import time
start_time = time.time()
sample_data = prepare_eeg_data_with_edge_features(sample_segments, f'{data_path}/train', distances, threshold=0.3)
elapsed = time.time() - start_time
print(f"Processed {len(sample_data)} segments in {elapsed:.2f} seconds ({elapsed/len(sample_data):.2f} seconds per segment)")

# Examine one processed sample
if len(sample_data) > 0:
    print("\nFirst sample details:")
    print(f"Node features shape: {sample_data[0].x.shape}")
    print(f"Edge index shape: {sample_data[0].edge_index.shape}")
    print(f"Edge attr shape: {sample_data[0].edge_attr.shape}")
    print(f"Label: {sample_data[0].y.item()}")

Testing with 5 segments
Created edge_index with shape torch.Size([2, 722])
Created edge_attr with shape torch.Size([722, 3])


Processing EEG segments:   0%|          | 0/5 [00:00<?, ?it/s]

Processing EEG segments:  20%|██        | 1/5 [00:00<00:02,  1.67it/s]

Sample pqejgcff_s001_t000_0 shapes - x: torch.Size([19, 16]), edge_index: torch.Size([2, 722]), edge_attr: torch.Size([722, 3])


Processing EEG segments:  40%|████      | 2/5 [00:01<00:01,  1.67it/s]

Sample pqejgcff_s001_t000_1 shapes - x: torch.Size([19, 16]), edge_index: torch.Size([2, 722]), edge_attr: torch.Size([722, 3])


Processing EEG segments:  60%|██████    | 3/5 [00:01<00:01,  1.88it/s]

Sample pqejgcff_s001_t000_2 shapes - x: torch.Size([19, 16]), edge_index: torch.Size([2, 722]), edge_attr: torch.Size([722, 3])


Processing EEG segments: 100%|██████████| 5/5 [00:02<00:00,  2.05it/s]

Processed 5 segments in 2.54 seconds (0.51 seconds per segment)

First sample details:
Node features shape: torch.Size([19, 16])
Edge index shape: torch.Size([2, 722])
Edge attr shape: torch.Size([722, 3])
Label: 0.0





In [None]:
# First, import tqdm if not already imported
from tqdm import tqdm

# Parameters
input_dim = 16  # Number of features per node (from extract_eeg_features)
hidden_dim = 32
output_dim = 1  # Binary classification
edge_dim = 3    # We're using 3 edge features now
num_heads = 8
threshold = 0.3  # Threshold for edge feature calculations

# Prepare data with edge features
print("Preparing training data...")
train_data = prepare_eeg_data_with_edge_features(train_segments, f'{data_path}/train', distances, threshold)
print(f"Created {len(train_data)} training samples")



Preparing training data...
Created edge_index with shape torch.Size([2, 722])
Created edge_attr with shape torch.Size([722, 3])


Processing EEG segments:   0%|          | 0/12993 [00:00<?, ?it/s]

Processing EEG segments:   0%|          | 1/12993 [00:00<1:27:59,  2.46it/s]

Sample pqejgcff_s001_t000_0 shapes - x: torch.Size([19, 16]), edge_index: torch.Size([2, 722]), edge_attr: torch.Size([722, 3])


Processing EEG segments:   0%|          | 2/12993 [00:00<1:26:53,  2.49it/s]

Sample pqejgcff_s001_t000_1 shapes - x: torch.Size([19, 16]), edge_index: torch.Size([2, 722]), edge_attr: torch.Size([722, 3])


Processing EEG segments:   0%|          | 3/12993 [00:01<1:26:07,  2.51it/s]

Sample pqejgcff_s001_t000_2 shapes - x: torch.Size([19, 16]), edge_index: torch.Size([2, 722]), edge_attr: torch.Size([722, 3])


Processing EEG segments:  20%|██        | 2633/12993 [26:44<5:14:09,  1.82s/it]

In [None]:

# Split into train and validation
from sklearn.model_selection import train_test_split
train_samples, val_samples = train_test_split(train_data, test_size=0.2, random_state=42)
print(f"Training on {len(train_samples)} samples, validating on {len(val_samples)} samples")

# Initialize model with edge features
model = EnhancedGATWithEdgeFeatures(input_dim, hidden_dim, output_dim, edge_dim, num_heads)


In [None]:

# Train model
print("Starting training...")
trained_model = train_gat_model_with_edge_attr(model, train_samples, val_samples, epochs=10, lr=0.01)


In [None]:
# Process test data
print("Processing test data...")
test_data = prepare_eeg_data_with_edge_features(test_segments, f'{data_path}/test', distances, threshold)

# Make predictions on test data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
predictions = []

for data in test_data:
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    edge_attr = data.edge_attr.to(device)
    
    with torch.no_grad():
        out = model(x, edge_index, edge_attr)
        pred = (out > 0.5).float().cpu().numpy()
        predictions.append(pred)

# Save predictions
test_segments['prediction'] = predictions
test_segments.to_csv('predictions.csv', index=False)

In [None]:
# Split into train and validation
from sklearn.model_selection import train_test_split
train_samples, val_samples = train_test_split(train_data, test_size=0.2, random_state=42)
print(f"Training on {len(train_samples)} samples, validating on {len(val_samples)} samples")

In [None]:

# Initialize model
model = EnhancedGATModel(input_dim, hidden_dim, output_dim, num_heads)

# Train model
print("Starting training...")
trained_model = train_gat_model(model, train_samples, val_samples, epochs=100, lr=0.001)

# Process test data
print("Processing test data...")
test_data = prepare_eeg_data(test_segments, f'{data_path}/test', distances, threshold)

# Make predictions on test data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
predictions = []

for data in test_data:
    x = data.x.to(device)
    edge_index = data.edge_index.to(device)
    
    with torch.no_grad():
        out = model(x, edge_index)
        pred = (out > 0.5).float().cpu().numpy()
        predictions.append(pred)

# Save predictions
test_segments['prediction'] = predictions
test_segments.to_csv('predictions.csv', index=False)