In [8]:
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
import numpy as np
import itertools
import time
import wandb
import random
from sklearn.model_selection import train_test_split

In [12]:
def load_graph_data(features, labels):
    print(features.shape)
    # Convert to PyTorch tensors
    y = torch.tensor(labels, dtype=torch.float32)
    x = torch.tensor(features, dtype=torch.float32)

    # fully connected graph for each graph
    edge_index = list(itertools.combinations(range(32), 2))
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    # Create a list of Data objects
    data_list = [Data(x=x[i], edge_index=edge_index, y=y[i]) for i in range(x.shape[0])]
    print(len(data_list))

    print(data_list[0].x.shape)
    return data_list

In [16]:
# Load labels and features
y = np.load('label_valence_no_neutral_PSD_gamma.npy')
x = np.load('eeg_data_no_neutral_PSD_gamma.npy')

# Split the data into training and testing sets
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

train_data = load_graph_data(x_train, y_train)
test_data = load_graph_data(x_test, y_test)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

(2304, 32, 30)
2304
torch.Size([32, 30])
(576, 32, 30)
576
torch.Size([32, 30])




In [61]:
# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="aml_mini_project",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.02,
    "architecture": "GRAPHSage",
    "dataset": "EEG",
    "epochs": 100,
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mniclas-classen[0m ([33mniclasclassen[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [19]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        return torch.sigmoid(x)

# Create the model
model = GraphSAGE(30, 32, 1)

# Define a loss function and an optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
model.train()
for epoch in range(10):
    acc_epoch = []
    for data in train_loader:
        data, target = data, data.y
        optimizer.zero_grad()
        out = model(data).squeeze()
        
        out = out.view(32,32)
        mean_out = torch.mean(out,dim = 1)
        
        loss = criterion(mean_out, target)
        loss.backward()
        optimizer.step()
        
        acc = (mean_out.round()==target).float().mean()
        acc_epoch.append(acc)

    print(epoch, np.mean(acc_epoch), acc_epoch[-1])

# Testing loop
model.eval()
with torch.no_grad():
    acc_test = []
    for data in test_loader:
        data, target = data, data.y
        out = model(data).squeeze()
        
        out = out.view(32,32)
        mean_out = torch.mean(out,dim = 1)
        
        acc = (mean_out.round()==target).float().mean()
        acc_test.append(acc)

    print('Test Accuracy:', np.mean(acc_test))

0 0.5651042 tensor(0.5312)
1 0.5651042 tensor(0.4062)
2 0.5651042 tensor(0.4375)
3 0.5651042 tensor(0.5312)
4 0.5651042 tensor(0.5625)
5 0.5651042 tensor(0.4688)
6 0.5651042 tensor(0.5938)
7 0.5651042 tensor(0.5625)
8 0.5651042 tensor(0.5625)
9 0.5651042 tensor(0.5625)
Test Accuracy: 0.5833333
