In [139]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

from matplotlib import pyplot as plt
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
from sklearn.model_selection import train_test_split

In [140]:
all_data = pd.read_csv("../Data/Train_and_Validate_EEG.csv", index_col=0).drop("Unnamed: 122",axis=1)
data = all_data.drop(['sex','eeg.date','education','specific.disorder'], axis=1).dropna(axis=0)

# Encode disorders
encoder = OrdinalEncoder()
data[['main.disorder']] = encoder.fit_transform(data[['main.disorder']])

X = data.drop(['main.disorder'], axis=1).dropna(axis=0)
y = data['main.disorder']
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

In [141]:
bands = ['gamma', 'highbeta', 'beta', 'alpha', 'theta', 'delta']
bands2 = ['F.gamma', 'E.highbeta', 'D.beta', 'C.alpha', 'B.theta', 'A.delta']

# Separate column names based on whether they are PSD or coherence
AB_per_band = {}
AB_cols = []
COH_per_band = {}
COH_cols = []

for band in bands:
    AB_per_band[band] = []
    COH_per_band[band] = []

for col_name in X_train.drop(['age','IQ'], axis=1).columns:
    col_split = col_name.split('.')
    if col_split[0] == 'AB':
        AB_per_band[col_split[2]].append(col_name)
        AB_cols.append(col_name)
    else:
        COH_per_band[col_split[2]].append(col_name)
        COH_cols.append(col_name)

In [142]:
# Get all electrode names
electrodes = set()
for col in COH_cols:
    col_split = col.split('.')
    # elec1 = col_split[4]
    electrodes.add(col_split[3]+'.'+col_split[4])
    # elec2 = col_split[6]
    electrodes.add(col_split[5]+'.'+col_split[6])

electrodes = list(electrodes)
len(electrodes)

19

## data preprocessing
coherence data shape = [num individuals, frequency bands, electrodes, electrodes]
PSD data shape = [num individuals, frequency bands, electrodes]

In [143]:
num_samples = X_train.shape[0]
num_nodes = len(electrodes)
num_bands = len(bands)

In [144]:
def getCoherenceStacked(coherence):
    adj_tensors = []

    # for each individual
    for i in range(coherence.shape[0]):
        indiv_tensors = []

        # for each band
        for band_idx in range(len(bands2)):
            # adjacency matrix electrodes x electrodes
            adj_matrix = np.zeros([len(electrodes),len(electrodes)])

            for elec1 in range(len(electrodes)):
                for elec2 in range(elec1+1, len(electrodes)):
                    col_name = 'COH.'+bands2[band_idx]+'.'+electrodes[elec1]+'.'+electrodes[elec2]
                    if col_name in coherence.columns:
                        adj_matrix[elec1][elec2] = coherence.iloc[i,:][col_name] #individual's row, then get value
                    else: 
                        col_name = 'COH.'+bands2[band_idx]+'.'+electrodes[elec2]+'.'+electrodes[elec1]
                        adj_matrix[elec1][elec2] = coherence.iloc[i,:][col_name] #individual's row, then get value
            indiv_tensors.append(torch.from_numpy(adj_matrix).fill_diagonal_(1.0))

        adj_tensors.append(torch.stack(indiv_tensors))
    
    adj_matrices = torch.stack(adj_tensors)
    # print(adj_matrices.shape)
    adj_matrices = (adj_matrices + adj_matrices.transpose(2, 3)) / 2  # Ensure symmetry
    print(adj_matrices.shape)
    return adj_matrices

In [145]:
def getPSDStacked(psd):
    psd_tensors = []
    # for each individual
    for i in range(psd.shape[0]):
        psd_mat = np.zeros([len(electrodes),len(bands2)])

        # for each band
        for band_idx in range(len(bands2)):
            for elec1 in range(len(electrodes)):
                col_name = 'AB.'+bands2[band_idx]+'.'+electrodes[elec1]
                psd_mat[elec1][band_idx] = psd.iloc[i,:][col_name] #individual's row, then get value
            
        psd_tensors.append(torch.from_numpy(psd_mat))

    node_features = torch.stack(psd_tensors)
    print(node_features.shape)
    return node_features

In [146]:
adj_matrices = getCoherenceStacked(X_train)

torch.Size([672, 6, 19, 19])


In [147]:
node_features = getPSDStacked(X_train)

torch.Size([672, 19, 6])


## modeling

In [148]:
label_strs = list(y_train.unique())
label_strs

[3.0, 0.0, 6.0, 4.0, 1.0, 5.0, 2.0]

In [149]:
# Simulated node features for each sample (replace with actual PSD values for each band)
node_features = torch.rand(num_samples, num_nodes, num_bands)  # Shape: (800, 20, 6)

labels = y_train.reset_index(drop=True)

# Prepare the dataset
graphs = []

for i in range(num_samples):
    for band_idx in range(num_bands):
        # Get the adjacency matrix and node features for the current band and sample
        adj_matrix = adj_matrices[i, band_idx]
        edge_index = adj_matrix.nonzero(as_tuple=False).T
        
        # Node features: Extract features for the current band for this sample
        node_features_sample = node_features[i, :, band_idx].unsqueeze(1)  # Shape: (20, 1)
        
        # Create a graph for this band of the sample
        graph = Data(x=node_features_sample, 
                     edge_index=edge_index, 
                     y=labels[i])  # Same label for all graphs in this sample
        
        graphs.append(graph)

# Create a DataLoader for batching
train_loader = DataLoader(graphs, batch_size=32, shuffle=True)

# Define GNN Model
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        # Apply GCNConv layers
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        
        # Global mean pooling to get a graph-level representation
        x = global_mean_pool(x, batch)  # Shape: [batch_size, out_channels]
        return x

# Instantiate model
model = GCN(1, 16, len(label_strs))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
model.train()
for epoch in range(100):
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        
        # Batch is a single Batch object containing the batched data
        out = model(batch.x, batch.edge_index, batch.batch)  # Model output shape: [batch_size, num_classes]
        loss = F.cross_entropy(out, batch.y.long())  # Target labels shape: [batch_size]
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {total_loss / len(train_loader)}')

print("Training complete!")




Epoch 0, Loss: 1.8395034689751883
Epoch 10, Loss: 1.8224248479283045
Epoch 20, Loss: 1.8220130763356648
Epoch 30, Loss: 1.8212487527302332
Epoch 40, Loss: 1.8216022328724937
Epoch 50, Loss: 1.8218663249697005
Epoch 60, Loss: 1.8218212629121446
Epoch 70, Loss: 1.821637371229747
Epoch 80, Loss: 1.8216487386870006
Epoch 90, Loss: 1.8218793774408006
Training complete!


In [151]:
# Set model to evaluation mode
model.eval()

# Store predictions and true labels for evaluation
all_preds = []
all_labels = []

with torch.no_grad():  # Disable gradient calculation during evaluation
    for batch in train_loader:  # Use test_loader or validation_loader here instead of train_loader
        # Get the model's output
        out = model(batch.x, batch.edge_index, batch.batch)
        
        # Convert output to predicted labels (e.g., using argmax for classification)
        preds = out.argmax(dim=1)  # Get the class with the highest score
        
        # Store predictions and true labels
        all_preds.append(preds.cpu().numpy())  # Convert to numpy for evaluation
        all_labels.append(batch.y.cpu().numpy())  # Convert to numpy for evaluation

# Flatten the lists
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

# Calculate accuracy
accuracy = accuracy_score(all_labels, all_preds)

print(f'Accuracy: {accuracy * 100:.2f}%')


Accuracy: 28.87%
