In [1]:
import torch
from matplotlib import pyplot as plt
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
import networkx as nx
import os
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset, download_url

from IPython.display import Javascript
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv,SAGEConv
from torch_geometric.nn import global_mean_pool

import wandb

In [2]:
test_patients = set(['F117', 'F130', 'F128', 'F133', 'F135', 'F138', 'F139', 'F136',
                 'F131', 'F344', 'F504', 'F510', 'M512', 'F341', 'F337', 'F324',
                 'F308', 'F317'])
elecs = ['Fp1', 'Fp2', 'Fpz', 'F3', 'F4', 'Fz', 'C3', 'C4',
         'Cz', 'P3', 'P4', 'Pz', '01', '02', '0z', 'F7', 
         'F8', 'T3', 'T4','T5', 'T6']

In [3]:
wandb.init(project = "neuroimaging_gnn_eeg_final_project", entity = "dmasny")

[34m[1mwandb[0m: Currently logged in as: [33mdmasny[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
class GNNDataset(InMemoryDataset):
    
    def __init__(self, 
                 root, 
                 data_dict, 
                 idx, 
                 feature_names, 
                 allow_loops, 
                 weighted, 
                 threshold = 0.65,
                 transform = None, 
                 pre_transform = None, 
                 pre_filter = None):
        
        self.data = np.load(data_dict, allow_pickle = True).item()
        self.stage = idx # idx 0 - train, idx 1 - test
        self.feature_names = feature_names
        self.allow_loops = allow_loops
        self.weighted = weighted
        self.threshold = threshold
        
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[self.stage])
        
    def vectorize_adj_mat_coo(self, matrix):
        source_nodes = []
        target_nodes = []
        if self.allow_loops:
            for i in range(matrix.shape[0]):
                for j in range(i, matrix.shape[1]):
                    source_nodes.append(i)
                    target_nodes.append(j)
        else:
            for i in range(matrix.shape[0]):
                for j in range(i + 1, matrix.shape[1]):
                    source_nodes.append(i)
                    target_nodes.append(j)
        return source_nodes, target_nodes

    def vectorize_adj_mat_weights(self, matrix):
        edge_weights = []
        if self.weighted == 'weighted':
            if self.allow_loops:
                for i in range(matrix.shape[0]):
                    for j in range(i, matrix.shape[1]):
                        edge_weights.append(matrix[i][j])
            else:
                for i in range(matrix.shape[0]):
                    for j in range(i + 1, matrix.shape[1]):
                        edge_weights.append(matrix[i][j])
        else:
            threshold = np.min(np.max(matrix, axis = 1)) if self.weighted == 'unweighted_dynamic_threshold' else self.threshold
            mask = np.array((matrix > threshold), dtype = np.uint8)
            if allow_loops:
                for i in range(mask.shape[0]):
                    for j in range(i, mask.shape[1]):
                        edge_weights.append(mask[i][j])
            else:
                for i in range(mask.shape[0]):
                    for j in range(i + 1, mask.shape[1]):
                        edge_weights.append(mask[i][j])
        return edge_weights

    def upload_data(self, patch_name):
        '''
        input:
            path_to_data: path to precomputed node representations
            path_to_adj_matr: path to precomputed adj matrices
            
        returns:
            Pygeometric Data object (see PyG docs)
        '''
        X, target, adj_matrix = self.data[patch_name] # triplet in format [X, target, A]
        edge_index = np.array(self.vectorize_adj_mat_coo(adj_matrix))
        edge_features = self.vectorize_adj_mat_weights(adj_matrix)
        return Data(x = torch.tensor(X), 
                    edge_index = torch.tensor(edge_index),
                    edge_attrs = edge_features, 
                    y = torch.tensor([target]))  
    
    @property
    def raw_file_names(self):
        return []
    
    @property
    def processed_file_names(self):
        return [f'gnn_dataset_train_{self.feature_names}.pt',
                f'gnn_dataset_test_{self.feature_names}.pt']

    def download(self):
        pass
        
    def process(self):
        data_list = []
        for elem in self.data.keys():
            data_list.append(self.upload_data(elem))
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[self.stage])
    

In [5]:
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels, output_dim = 2):
        super(GCN, self).__init__()
        torch.manual_seed(2022)
        self.conv1 = SAGEConv(num_node_features, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        self.conv4 = SAGEConv(hidden_channels, hidden_channels)
        self.linear = Linear(hidden_channels, output_dim)

    def forward(self, x, edge_index, batch):
        
        x = self.conv1(x, edge_index)
        x = x.relu()
        
        x = self.conv2(x, edge_index)
        x = x.relu()
        
        x = self.conv3(x, edge_index)
        x = x.relu()
        
        x = self.conv4(x ,edge_index)
        x = x.relu()
        

        # Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        # Classifier
        x = self.linear(x)
        
        return x


In [6]:
def train(model, 
          epoches, 
          criterion, 
          train_dataloader, 
          test_dataloader, 
          expirement_name, 
          path_to_save_weights = 'model_weights'):
    
    best_accuracy = 0
    for i in range(epoches):
        model.train()
        train_correct = 0
        train_loss = 0
        for data in train_dataloader:
            out = model(data.x.type(dtype = torch.float), 
                        data.edge_index, 
                        data.batch)  
            loss = criterion(out, data.y) 
            train_loss += loss
            pred = out.argmax(dim = 1)
            train_correct += int((pred == data.y).sum())
            loss.backward() 
            optimizer.step()  
            optimizer.zero_grad() 
        train_loss = train_loss / len(train_dataloader.dataset)
        train_accuracy = train_correct / len(train_dataloader.dataset)
        
        wandb.log({'train_loss': train_loss})
        wandb.log({'train_accuracy': train_accuracy})
        
        model.eval()
        with torch.no_grad():  
            test_correct = 0
            for data in test_dataloader:  
                out = model(data.x.type(dtype = torch.float), data.edge_index, data.batch)  
                pred = out.argmax(dim = 1)  
                test_correct += int((pred == data.y).sum())
            test_accuracy = test_correct / len(test_dataloader.dataset)
            
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                torch.save(model.state_dict(), f'{path_to_save_weights}/{expirement_name}.pth')
                
            print(f'Epoch:{i} Train acc:{train_accuracy} Test acc:{test_accuracy} Train loss:{train_loss}')
            
            wandb.log({'test_accuracy': test_accuracy})
            wandb.log({'best_test_accuracy': best_accuracy})
            
    
    wandb.finish()

In [None]:
loops = [True, False]
weighted = ['weighted', 'unweighted_dynamic_threshold', 'unweighted_static_threshold']
data = ['power_and_entropy', 'only_entropy', 'only_powerbands']

lr = 0.001
wd = 0.001
batch_size = 50
hidden_channels = 128 
epoches = 1000

for loops_config in loops:
    for weight_config in weighted:
        for dataset in data:
            
            experiment_name = f'self_loops={loops_config}_weighted={weight_config}_data={dataset}'
            
            train_dataset = GNNDataset(root = './', 
                                       data_dict = f'gnn_prepared_data/{dataset}_gnn_train.npy',
                                       idx = 0, 
                                       feature_names = experiment_name, 
                                       allow_loops = loops_config, 
                                       weighted = weight_config) # idx = 0 - train
            test_dataset = GNNDataset(root = './', 
                                      data_dict = f'gnn_prepared_data/{dataset}_gnn_test.npy', 
                                      idx = 1, 
                                      feature_names = experiment_name,
                                      allow_loops = loops_config, 
                                      weighted = weight_config) # idx = 1 - train

            train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
            test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

            model = GCN(num_node_features = train_dataset.num_node_features, 
                        hidden_channels = hidden_channels)
            optimizer = torch.optim.AdamW(model.parameters(), lr = lr, weight_decay = wd)
            criterion = torch.nn.CrossEntropyLoss()


            config = {
                        'learning_rate': lr,
                        'weight_decay': wd,
                        'epochs': epoches,
                        'training_batch_size' : batch_size,
                        'validation_batch_size' : batch_size,
                        'loops_config': loops_config,
                        'weight_config': weight_config,
                        'dataset': dataset,
                        'criterion': criterion,    
                        'node_representation_size': train_dataset.num_node_features, 
                        'model': {
                                    'num_graph_conv_blocks': 4,
                                    'hidden_channels' : hidden_channels,
                                    'activation' : 'ReLU',
                                    'readout': 'global_mean_pool'}}

            wandb.init(project = 'neuroimaging_gnn_eeg_final_project', 
                       entity = 'dmasny',
                       name = experiment_name, 
                       config = config)

            train(model = model,
                  epoches = epoches, 
                  criterion = criterion, 
                  train_dataloader = train_loader, 
                  test_dataloader = test_loader,
                  expirement_name = experiment_name)
            print(experiment_name)
            del model


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.144451…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668827567870418, max=1.0…

Epoch:0 Train acc:0.49415204678362573 Test acc:0.4988479262672811 Train loss:0.014838864095509052
Epoch:1 Train acc:0.5112085769980507 Test acc:0.5011520737327189 Train loss:0.014280869625508785
Epoch:2 Train acc:0.5082846003898636 Test acc:0.4988479262672811 Train loss:0.014163868501782417
Epoch:3 Train acc:0.49171539961013644 Test acc:0.4988479262672811 Train loss:0.014291990548372269
Epoch:4 Train acc:0.5160818713450293 Test acc:0.5011520737327189 Train loss:0.01417999155819416
Epoch:5 Train acc:0.5384990253411306 Test acc:0.5011520737327189 Train loss:0.014096962288022041
Epoch:6 Train acc:0.5014619883040936 Test acc:0.5011520737327189 Train loss:0.014203084632754326
Epoch:7 Train acc:0.5307017543859649 Test acc:0.5011520737327189 Train loss:0.014203093945980072
Epoch:8 Train acc:0.5389863547758285 Test acc:0.5023041474654378 Train loss:0.014072886668145657
Epoch:9 Train acc:0.5477582846003899 Test acc:0.5702764976958525 Train loss:0.013991250656545162
Epoch:10 Train acc:0.54775828

In [None]:
!tree processed/

In [None]:
# train_dataset = GNNDataset('./', 'gnn_prepared_data/power_and_entropy_gnn_train.npy', 
#                            0, 
#                            'all_features',
#                            allow_loops = True,
#                            weighted = 'weighted')
# test_dataset = GNNDataset('./', 'gnn_prepared_data/power_and_entropy_gnn_test.npy', 
#                           1, 
#                           'all_features',
#                           allow_loops = True,
#                           weighted = 'weighted') 

# model = GCN(num_node_features = train_dataset.num_node_features, 
#                             hidden_channels = hidden_channels)
# optimizer = torch.optim.AdamW(model.parameters(), lr = lr, weight_decay = wd)
# criterion = torch.nn.CrossEntropyLoss()


# train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
# test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)
    
# train(model = model,
#       epoches = epoches, 
#       criterion = criterion, 
#       train_dataloader = train_loader, 
#       test_dataloader = test_loader,
#       expirement_name = 'debug')

# buff = []