In [7]:
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader


## Import data

In [8]:
class CustomEEGDataset(Dataset):
    def __init__(self, annotations_file, eeg_file, transform=None, target_transform=None):
        self.eeg_labels = torch.from_numpy(np.load(annotations_file).reshape(-1,1))
        self.eeg_data = torch.from_numpy(np.load(eeg_file))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.eeg_labels)

    def __getitem__(self, idx):
        label = self.eeg_labels[idx]
        eeg = self.eeg_data[idx]
        if self.transform:
            eeg = self.transform(eeg)
        if self.target_transform:
            label = self.target_transform(label)
        return eeg, label

In [41]:
eeg_DE_dataset = CustomEEGDataset('label_valence_no_neutral_PSD_gamma.npy','eeg_data_no_neutral_PSD_gamma.npy' )

In [64]:
sample = eeg_DE_dataset[100]
print(len(eeg_DE_dataset))
sample[1]

2880


tensor([1.])

In [169]:
import torch
from torch_geometric.data import Data, DataLoader

# Create a fully connected adjacency matrix
adjacency = torch.ones((32, 32)) - torch.eye(32)

edge_index = adjacency.nonzero().t()

data_list = []
for entry in eeg_DE_dataset:
    x = entry[0]
    y = entry[1]
    data = Data(x=x, edge_index=edge_index, y=y)
    data_list.append(data)

trainloader = DataLoader(data_list, batch_size=1, shuffle=True)

In [167]:
for data in trainloader:
    print(data)

DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32], batch=[1024], ptr=[33])
DataBatch(x=[1024, 30], edge_index=[2, 31744], y=[32

In [125]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv = SAGEConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        return torch.sigmoid(x)


In [171]:
model = GraphSAGE(30, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

total_correct = 0
total_samples = 0

for epoch in range(100):
    for data in trainloader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        out = out.mean(dim=0, keepdim=True)

        loss = criterion(out[0], data.y.float())
        loss.backward()
        optimizer.step()

        # Compute accuracy
        pred = torch.sigmoid(out[0])  # Apply sigmoid to get probabilities
        pred = (pred > 0.5).float()  # Convert probabilities to binary predictions
        correct = (pred == data.y.float()).float().sum()  # Count number of correct predictions

        total_correct += correct.item()
        total_samples += len(data.y)

    accuracy = total_correct / total_samples  # Compute overall accuracy

    print(f'Epoch: {epoch}, Loss: {loss.item()}, Accuracy: {accuracy}')

Epoch: 0, Loss: 0.679759681224823, Accuracy: 0.43125


KeyboardInterrupt: 