In [7]:
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader


## Import data

In [8]:
class CustomEEGDataset(Dataset):
    def __init__(self, annotations_file, eeg_file, transform=None, target_transform=None):
        self.eeg_labels = torch.from_numpy(np.load(annotations_file).reshape(-1,1))
        self.eeg_data = torch.from_numpy(np.load(eeg_file))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.eeg_labels)

    def __getitem__(self, idx):
        label = self.eeg_labels[idx]
        eeg = self.eeg_data[idx]
        if self.transform:
            eeg = self.transform(eeg)
        if self.target_transform:
            label = self.target_transform(label)
        return eeg, label

In [41]:
eeg_DE_dataset = CustomEEGDataset('label_valence_no_neutral_PSD_gamma.npy','eeg_data_no_neutral_PSD_gamma.npy' )

In [64]:
sample = eeg_DE_dataset[100]
print(len(eeg_DE_dataset))
sample[1]

2880


tensor([1.])

In [237]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Create a fully connected adjacency matrix
adjacency = torch.ones((32, 32)) - torch.eye(32)

edge_index = adjacency.nonzero().t()

data_list = []
for entry in eeg_DE_dataset:
    x = entry[0]
    y = entry[1]
    data = Data(x=x, edge_index=edge_index, y=y)
    data_list.append(data)


In [323]:
trainloader = DataLoader(data_list, batch_size=2, shuffle=True)

In [245]:
# Print each subgraph
for i, subgraph in enumerate(trainloader):
    print(f'Subgraph {i}: {subgraph}')
    print(subgraph.ptr)

Subgraph 0: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 1: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 2: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 3: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 4: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 5: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 6: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 7: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 8: DataBatch(x=[64, 30], edge_index=[2, 1984], y=[2], batch=[64], ptr=[3])
tensor([ 0, 32, 64])
Subgraph 9: DataBatch(x=[64, 30], edge_index=[2, 1984],

In [324]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv = SAGEConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        return torch.sigmoid(x)


In [326]:
# import graph sage
from torch_geometric.nn import SAGEConv
model = GraphSAGE(30, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

total_correct = 0
total_samples = 0

for epoch in range(400):
    for data in trainloader:
        transformed_data = []
        for ptr in data.ptr:
            transformed_data.append(data[ptr:ptr+data.num_nodes])
        optimizer.zero_grad()

        total_out = []
        # get out for each batch.x and batch.edge_index
        for batch in transformed_data[0]:
            out = model(batch.x, batch.edge_index)
            out_mean = out.mean(dim=0)
            # print(out_mean)
            # add only number without grad_fn=<SelectBackward0>
            total_out.append(out_mean)


        total_out_tensor = torch.cat(total_out)
        # print(total_out_tensor)

        loss = criterion(total_out_tensor, data.y.float())
        loss.backward()
        optimizer.step()

        # Compute accuracy
        pred = torch.sigmoid(out[0])  # Apply sigmoid to get probabilities
        pred = (pred > 0.5).float()  # Convert probabilities to binary predictions
        correct = (pred == data.y.float()).float().sum()  # Count number of correct predictions

        total_correct += correct.item()
        total_samples += len(data.y)

    accuracy = total_correct / total_samples  # Compute overall accuracy

    print(f'Epoch: {epoch}, Loss: {loss.item()}, Accuracy: {accuracy}')

Epoch: 0, Loss: 0.7195472717285156, Accuracy: 0.43125
Epoch: 1, Loss: 0.6931912899017334, Accuracy: 0.43125
Epoch: 2, Loss: 0.6971733570098877, Accuracy: 0.43125
Epoch: 3, Loss: 0.6912310719490051, Accuracy: 0.43125
Epoch: 4, Loss: 0.6931475400924683, Accuracy: 0.43125
Epoch: 5, Loss: 0.6931473016738892, Accuracy: 0.43125
Epoch: 6, Loss: 0.6933958530426025, Accuracy: 0.43125
Epoch: 7, Loss: 0.6932718753814697, Accuracy: 0.43125
Epoch: 8, Loss: 0.6930842399597168, Accuracy: 0.43125
Epoch: 9, Loss: 0.6931471824645996, Accuracy: 0.43125
Epoch: 10, Loss: 0.6931471824645996, Accuracy: 0.43125
Epoch: 11, Loss: 0.6931554079055786, Accuracy: 0.43125
Epoch: 12, Loss: 0.6931429505348206, Accuracy: 0.43125
Epoch: 13, Loss: 0.6931492686271667, Accuracy: 0.43125
Epoch: 14, Loss: 0.6931482553482056, Accuracy: 0.43125
Epoch: 15, Loss: 0.6931471824645996, Accuracy: 0.43125
Epoch: 16, Loss: 0.6931471824645996, Accuracy: 0.43125
Epoch: 17, Loss: 0.6931470632553101, Accuracy: 0.43125
Epoch: 18, Loss: 0.6

KeyboardInterrupt: 