In [1]:
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
import numpy as np
import itertools
import time
import wandb
import random
from sklearn.model_selection import train_test_split
from torch_geometric.nn import global_mean_pool


In [2]:
def load_graph_data(features, labels):
    print(features.shape)
    # Convert to PyTorch tensors
    y = torch.tensor(labels, dtype=torch.float32)
    x = torch.tensor(features, dtype=torch.float32)

    print("y shape: ", y.shape)
    # fully connected graph for each graph
    edge_index = list(itertools.combinations(range(32), 2))
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    # Create a list of Data objects
    data_list = [Data(x=x[i], edge_index=edge_index, y=y[i]) for i in range(x.shape[0])]
    print(len(data_list))

    print(data_list[0].x.shape)
    return data_list

In [3]:
_batch_size = 32

In [4]:
# Load labels and features
y = np.load('label_based_on_movie_classification_movie.npy')
x = np.load('eeg_data_no_neutral_PSD_gamma.npy')

# Split the data into training and testing sets
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

train_data = load_graph_data(x_train, y_train)
test_data = load_graph_data(x_test, y_test)

train_loader = DataLoader(train_data, batch_size=_batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=_batch_size, shuffle=False)

(2304, 32, 30)
y shape:  torch.Size([2304, 2])
2304
torch.Size([32, 30])
(576, 32, 30)
y shape:  torch.Size([576, 2])
576
torch.Size([32, 30])




In [9]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        return torch.sigmoid(x)

# Create the model
model = GraphSAGE(30, 1)

# Define a loss function and an optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
model.train()
for epoch in range(400):
    acc_epoch = []
    loss_epoch = []
    acc_by_movie_train = {}
    correct = 0
    size = 0
    for data in train_loader:
        data, target = data, data.y

        labels = target[::2]  
        movie_numbers = target[1::2]


        optimizer.zero_grad()
        out = model(data).squeeze()
        out = out.view(_batch_size, 32)
        # print("out test", out)    
        mean_out = torch.mean(out,dim = 1)

        # print("mean_out", mean_out)
        
        loss = criterion(mean_out, labels)
        loss.backward()
        optimizer.step()
        
        acc = (mean_out.round()==labels).float().mean()
        # print("acc", acc)
        acc_epoch.append(acc)
        correct += (mean_out.round()==labels).sum()
        size += len(labels)

        # loss
        loss_epoch.append(loss.item())


        # Save accuracy by movie
        for movie, accuracy in zip(movie_numbers, acc_epoch):
            acc_by_movie_train[movie.item()] = accuracy.item()

    print(f"Epoch {epoch} Accuracy: {correct/size}, Loss: {np.mean(loss_epoch)}")

# Testing loop
# model.eval() 
with torch.no_grad():
    acc_test = []
    acc_by_movie_test = {}
    for data in test_loader:
        data, target = data, data.y

        labels = target[::2]  
        movie_numbers = target[1::2]

        out = model(data).squeeze()
        
        out = out.view(_batch_size, 32)
        mean_out = torch.mean(out,dim = 1)
        
        acc = (mean_out.round()==labels).float().mean()
        acc_test.append(acc)

        for movie, accuracy in zip(movie_numbers, acc_test):
            acc_by_movie_test[movie.item()] = accuracy.item()

    print('Test Accuracy:', np.mean(acc_test))
    print(acc_by_movie_test)

Epoch 0 Accuracy: 0.4366319477558136, Loss: 0.6945497476392322
Epoch 1 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 2 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 3 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 4 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 5 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 6 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 7 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 8 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 9 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 10 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 11 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 12 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 13 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 14 Accuracy: 0.5651041865348816, Loss: 0.6931471824645996
Epoch 15 Accuracy: 0.5651041865348816, Loss: 0.693

KeyboardInterrupt: 

In [8]:
for movie, accuracy in sorted(acc_by_movie_test.items()):
    print(f"Movie: {movie}, Accuracy: {accuracy}")

# mean acc for the first 12 movies
mean_acc = np.mean([acc_by_movie_test[i] for i in range(1, 13)])
print(f"Mean accuracy for the first 12 movies: {mean_acc}")

# mean acc for the last 12 movies
mean_acc = np.mean([acc_by_movie_test[i] for i in range(17, 29)])
print(f"Mean accuracy for the last 12 movies: {mean_acc}")

Movie: 1.0, Accuracy: 0.46875
Movie: 2.0, Accuracy: 0.625
Movie: 3.0, Accuracy: 0.625
Movie: 4.0, Accuracy: 0.46875
Movie: 5.0, Accuracy: 0.6875
Movie: 6.0, Accuracy: 0.53125
Movie: 7.0, Accuracy: 0.65625
Movie: 8.0, Accuracy: 0.59375
Movie: 9.0, Accuracy: 0.75
Movie: 10.0, Accuracy: 0.6875
Movie: 11.0, Accuracy: 0.5625
Movie: 12.0, Accuracy: 0.625
Movie: 17.0, Accuracy: 0.75
Movie: 18.0, Accuracy: 0.59375
Movie: 19.0, Accuracy: 0.625
Movie: 20.0, Accuracy: 0.59375
Movie: 21.0, Accuracy: 0.59375
Movie: 22.0, Accuracy: 0.59375
Movie: 23.0, Accuracy: 0.59375
Movie: 24.0, Accuracy: 0.65625
Movie: 25.0, Accuracy: 0.65625
Movie: 26.0, Accuracy: 0.53125
Movie: 27.0, Accuracy: 0.5
Movie: 28.0, Accuracy: 0.5625
Mean accuracy for the first 12 movies: 0.6067708333333334
Mean accuracy for the last 12 movies: 0.6041666666666666
