In [59]:
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
import numpy as np
import itertools
import time
import wandb
import random

In [60]:
def load_graph_data():
    # Load labels and features
    y = np.load('label_valence_no_neutral_PSD_gamma.npy')
    x = np.load('eeg_data_no_neutral_PSD_gamma.npy')
    print(x.shape)
    # Convert to PyTorch tensors
    y = torch.tensor(y, dtype=torch.float32)
    x = torch.tensor(x, dtype=torch.float32)

    # fully connected graph for each graph
    edge_index = list(itertools.combinations(range(32), 2))
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    # Create a list of Data objects
    data_list = [Data(x=x[i], edge_index=edge_index, y=y[i]) for i in range(x.shape[0])]
    print(len(data_list))

    print(data_list[0].x.shape)
    return data_list

# Load your graph data
data_list = load_graph_data()
# Create a DataLoader
loader = DataLoader(data_list, batch_size=5, shuffle=True)
print("loader", loader.dataset[0].x)

(2880, 32, 30)
2880
torch.Size([32, 30])
loader tensor([[3.8687e-07, 4.6613e-07, 5.3260e-07, 1.4376e-07, 5.5562e-07, 6.8038e-07,
         3.7298e-07, 3.8315e-07, 9.0654e-07, 9.3240e-07, 6.1481e-07, 3.4832e-07,
         4.6013e-07, 3.6838e-07, 6.7272e-07, 4.7111e-07, 5.5869e-07, 8.8953e-07,
         2.3090e-07, 8.8521e-07, 2.1437e-07, 4.8598e-07, 4.7125e-07, 2.9518e-07,
         5.3914e-07, 6.3746e-07, 7.3941e-07, 9.1913e-07, 4.0791e-07, 2.3668e-07],
        [1.0997e-06, 1.1365e-06, 1.0279e-06, 4.1245e-07, 9.6109e-07, 1.9613e-06,
         1.9759e-06, 1.7818e-06, 3.5109e-06, 2.1051e-06, 2.2427e-06, 7.2981e-07,
         5.6689e-07, 5.1236e-07, 2.8042e-06, 8.6886e-07, 1.2810e-06, 1.8242e-06,
         1.0040e-06, 1.0020e-06, 8.0691e-07, 1.9977e-06, 1.4499e-06, 6.6204e-07,
         8.7982e-07, 1.4272e-06, 1.2729e-06, 8.9667e-07, 5.4645e-07, 6.8404e-07],
        [6.6198e-08, 1.0031e-07, 4.6065e-08, 4.6944e-08, 6.8400e-08, 1.1024e-07,
         8.6374e-08, 5.8051e-08, 1.0431e-07, 9.8463e-08, 1.



In [61]:
# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="aml_mini_project",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.02,
    "architecture": "GRAPHSage",
    "dataset": "EEG",
    "epochs": 100,
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mniclas-classen[0m ([33mniclasclassen[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [63]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        return torch.sigmoid(x)

# Create the model
model = GraphSAGE(30, 32, 1)

# Define a loss function and an optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

model.train()
for epoch in range(100):
    acc_epoch = []
    for data in loader:
        data, target = data, data.y
        optimizer.zero_grad()
        out = model(data).squeeze()
        
        out = out.view(5,32)
        mean_out = torch.mean(out,dim = 1)
        
        loss = criterion(mean_out, target)
        loss.backward()
        optimizer.step()
        
        acc = (mean_out.round()==target).float().mean()
        acc_epoch.append(acc)

        # Print the model's output and the target for each batch
        # print(f"Output: {mean_out}, Target: {target}")
    wandb.log({"loss": loss.item(), "accuracy": np.mean(acc_epoch)})

    print(epoch, np.mean(acc_epoch), acc_epoch[-1])

0 0.56111115 tensor(0.8000)
1 0.56875 tensor(0.8000)
2 0.56874996 tensor(0.4000)
3 0.5687501 tensor(0.4000)
4 0.56875 tensor(0.6000)
5 0.56874996 tensor(0.6000)
6 0.5687501 tensor(0.4000)
7 0.5687501 tensor(0.6000)
8 0.56875 tensor(0.4000)
9 0.56875 tensor(0.4000)
10 0.56875 tensor(0.6000)
11 0.56875 tensor(0.4000)
12 0.56875 tensor(0.4000)
13 0.56875 tensor(0.6000)
14 0.56875 tensor(0.2000)
15 0.5687501 tensor(0.8000)
16 0.5687501 tensor(0.4000)
17 0.5687501 tensor(0.6000)
18 0.56875 tensor(0.8000)
19 0.56875 tensor(0.4000)
20 0.56875 tensor(0.4000)
21 0.5687501 tensor(0.8000)
22 0.56875 tensor(0.8000)
23 0.56875 tensor(0.8000)
24 0.56875 tensor(0.6000)
25 0.5687501 tensor(0.8000)
26 0.56875 tensor(0.6000)
27 0.5687501 tensor(0.6000)
28 0.56875 tensor(0.6000)
29 0.56875 tensor(0.2000)
30 0.56875 tensor(0.4000)
31 0.56875 tensor(0.6000)
32 0.5687501 tensor(0.4000)
33 0.56875 tensor(0.2000)
34 0.56875 tensor(1.)
35 0.56875 tensor(0.6000)
36 0.56875 tensor(0.6000)
37 0.56875 tensor(0.800