In [79]:

import torch
from matplotlib import pyplot as plt
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
from torch_geometric.data import InMemoryDataset
import networkx as nx
import os
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset, download_url

from IPython.display import Javascript
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv,SAGEConv
from torch_geometric.nn import global_mean_pool


In [80]:
test_patients = set(['F117', 'F130', 'F128', 'F133', 'F135', 'F138', 'F139', 'F136',
                 'F131', 'F344', 'F504', 'F510', 'M512', 'F341', 'F337', 'F324',
                 'F308', 'F317'])

In [81]:
# unq_set = set()
# for elem in os.listdir('clean_wcoh/kz_clean/health/'):
#     if 'log' in elem:
#         continue
#     name = elem.split('_')[1]
#     if name in test_patients:
#         if name not in unq_set:
#             unq_set.add(name)

In [82]:
elecs = ['Fp1', 'Fp2', 'Fpz', 'F3', 'F4', 'Fz', 'C3', 'C4',
         'Cz', 'P3', 'P4', 'Pz', '01', '02', '0z', 'F7', 
         'F8', 'T3', 'T4','T5', 'T6']

In [85]:
class GNNDataset(InMemoryDataset):
    
    def __init__(self, root, data_dict, idx, feature_names, allow_loops = True, weighted = True, threshold = 0.6,
                 transform = None, pre_transform = None, pre_filter = None):
        self.data = np.load(data_dict, allow_pickle = True).item()
        self.stage = idx # idx 0 - train, idx 1 - test
        self.feature_names = feature_names
        self.allow_loops = allow_loops
        self.weighted = weighted
        self.threshold = threshold
        
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[self.stage])
        
    def vectorize_adj_mat_coo(self, matrix, allow_loops = True):
        source_nodes = []
        target_nodes = []
        if allow_loops:
            for i in range(matrix.shape[0]):
                for j in range(i, matrix.shape[1]):
                    source_nodes.append(i)
                    target_nodes.append(j)
        else:
            for i in range(matrix.shape[0]):
                for j in range(i + 1, matrix.shape[1]):
                    source_nodes.append(i)
                    target_nodes.append(j)
        return source_nodes, target_nodes
    
    def vectorize_adj_mat_weights(self, matrix, allow_loops = True, weighted = True, threshold = 0.5):
        edge_weights = []
        if weighted:
            if allow_loops:
                for i in range(matrix.shape[0]):
                    for j in range(i, matrix.shape[1]):
                        edge_weights.append(matrix[i][j])
            else:
                for i in range(matrix.shape[0]):
                    for j in range(i + 1, matrix.shape[1]):
                        edge_weights.append(matrix[i][j])
        else:
            threshold = np.min(np.max(matrix, axis = 1))
            mask = np.array((matrix > threshold), dtype = np.uint8)
            if allow_loops:
                for i in range(mask.shape[0]):
                    for j in range(i, mask.shape[1]):
                        edge_weights.append(mask[i][j])
            else:
                for i in range(mask.shape[0]):
                    for j in range(i + 1, mask.shape[1]):
                        edge_weights.append(mask[i][j])
        return edge_weights

    def upload_data(self, patch_name):
        '''
        input:
            path_to_data: path to precomputed node representations
            path_to_adj_matr: path to precomputed adj matrices
            
        returns:
            Pygeometric Data object (see PyG docs)
        '''
        X, target, adj_matrix = self.data[patch_name] # triplet in format [X, target, A]
        edge_index = np.array(self.vectorize_adj_mat_coo(adj_matrix, 
                                                         allow_loops = self.allow_loops))
        edge_features = self.vectorize_adj_mat_weights(adj_matrix, 
                                                       allow_loops = self.allow_loops,
                                                       weighted = self.weighted, 
                                                       threshold = self.threshold)
        return Data(x = torch.tensor(X), 
                    edge_index = torch.tensor(edge_index),
                    edge_attrs = edge_features, 
                    y = torch.tensor([target]))  
    
    @property
    def raw_file_names(self):
        return []
    
    @property
    def processed_file_names(self):
        return [f'gnn_dataset_train_{self.feature_names}.pt',
                f'gnn_dataset_test_{self.feature_names}.pt']

    def download(self):
        pass
        
    def process(self):
        data_list = []
        for elem in self.data.keys():
            data_list.append(self.upload_data(elem))
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[self.stage])
    

In [65]:
train_dataset = GNNDataset('./', 'train_full_data.npy', 0, 'all_features') # idx = 0 - train
test_dataset = GNNDataset('./', 'test_full_data.npy', 1, 'all_features') # idx = 1 - train

In [87]:
train_dataset = GNNDataset('./', 'gnn_prepared_data/power_and_entropy_gnn_train.npy', 0, 'all_features',  weighted = False) # idx = 0 - train
test_dataset = GNNDataset('./', 'gnn_prepared_data/power_and_entropy_gnn_test.npy', 1, 'all_features', weighted = False) # idx = 1 - train


In [88]:
batch_size = 50
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

# for step, data in enumerate(train_loader):
#     print(f'Step {step + 1}:')
#     print('=======')
#     print(f'Number of graphs in the current batch: {data.num_graphs}')
#     print(data)
#     print()

In [89]:
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = SAGEConv(train_dataset.num_node_features, hidden_channels)
        self.batchnorm = torch.nn.BatchNorm1d(hidden_channels) #Vovan added
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        self.conv4 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, train_dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
#         x = self.batchnorm(x)
        x = x.relu()
        x = self.conv2(x, edge_index)
#         x = self.batchnorm(x)
        x = x.relu()
        x = self.conv3(x, edge_index)
        #added
        x = x.relu()
        x = self.conv4(x ,edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        #x = F.dropout(x, p=0.1, training=self.training)
        x = self.lin(x)
        
        return x


1) adj(2) + loops(2) + weighted(2) + data(3) = 2*2*2*3 = 12 exp
2) batch size, num_layers, num_epoches, lr, scheduler
3)

object : [X, target, A]
X - only use local info about the node
1) acc entropy <
2) acc power <
3) acc power + entropy < gnn(power + entropy + adj matrix)

In [90]:

model = GCN(hidden_channels = 128)
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.001, weight_decay = 0.001)
criterion = torch.nn.CrossEntropyLoss()


In [91]:
model

GCN(
  (conv1): SAGEConv(8, 128, aggr=mean)
  (batchnorm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): SAGEConv(128, 128, aggr=mean)
  (conv3): SAGEConv(128, 128, aggr=mean)
  (conv4): SAGEConv(128, 128, aggr=mean)
  (lin): Linear(in_features=128, out_features=2, bias=True)
)

In [92]:
def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x.type(dtype=torch.float), data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def test(loader):
    model.eval()
    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x.type(dtype=torch.float), data.edge_index, data.batch)  
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.

In [93]:
best_train, cur_epoch, best_val = -1, -1, -1 
for epoch in range(1, 500):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    best_train, cur_epoch, best_val = (train_acc, epoch, test_acc) if test_acc > best_val \
                                    else (best_train, cur_epoch, best_val)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'Best test accuracy - {best_val} on epoch {cur_epoch} with train accuracy - {best_train}')

Epoch: 001, Train Acc: 0.4990, Test Acc: 0.5012
Epoch: 002, Train Acc: 0.4990, Test Acc: 0.5012
Epoch: 003, Train Acc: 0.4990, Test Acc: 0.5012
Epoch: 004, Train Acc: 0.5317, Test Acc: 0.5588
Epoch: 005, Train Acc: 0.5385, Test Acc: 0.5012
Epoch: 006, Train Acc: 0.5400, Test Acc: 0.5081
Epoch: 007, Train Acc: 0.5346, Test Acc: 0.5012
Epoch: 008, Train Acc: 0.5521, Test Acc: 0.5230
Epoch: 009, Train Acc: 0.5911, Test Acc: 0.5161
Epoch: 010, Train Acc: 0.5941, Test Acc: 0.5726
Epoch: 011, Train Acc: 0.5941, Test Acc: 0.5541
Epoch: 012, Train Acc: 0.5892, Test Acc: 0.5288
Epoch: 013, Train Acc: 0.5234, Test Acc: 0.5334
Epoch: 014, Train Acc: 0.5911, Test Acc: 0.5380
Epoch: 015, Train Acc: 0.5463, Test Acc: 0.5588
Epoch: 016, Train Acc: 0.5833, Test Acc: 0.5841
Epoch: 017, Train Acc: 0.5916, Test Acc: 0.5380
Epoch: 018, Train Acc: 0.5180, Test Acc: 0.4988
Epoch: 019, Train Acc: 0.5770, Test Acc: 0.5726
Epoch: 020, Train Acc: 0.5814, Test Acc: 0.5818
Epoch: 021, Train Acc: 0.5975, Test Acc:

Epoch: 172, Train Acc: 0.6082, Test Acc: 0.6210
Epoch: 173, Train Acc: 0.6209, Test Acc: 0.6267
Epoch: 174, Train Acc: 0.6204, Test Acc: 0.6071
Epoch: 175, Train Acc: 0.6165, Test Acc: 0.6233
Epoch: 176, Train Acc: 0.6160, Test Acc: 0.6129
Epoch: 177, Train Acc: 0.5970, Test Acc: 0.5749
Epoch: 178, Train Acc: 0.6179, Test Acc: 0.6048
Epoch: 179, Train Acc: 0.6131, Test Acc: 0.6244
Epoch: 180, Train Acc: 0.6170, Test Acc: 0.6002
Epoch: 181, Train Acc: 0.6140, Test Acc: 0.6279
Epoch: 182, Train Acc: 0.6121, Test Acc: 0.6221
Epoch: 183, Train Acc: 0.6145, Test Acc: 0.5968
Epoch: 184, Train Acc: 0.6199, Test Acc: 0.6325
Epoch: 185, Train Acc: 0.5692, Test Acc: 0.5553
Epoch: 186, Train Acc: 0.6087, Test Acc: 0.5749
Epoch: 187, Train Acc: 0.6140, Test Acc: 0.6210
Epoch: 188, Train Acc: 0.6194, Test Acc: 0.6002
Epoch: 189, Train Acc: 0.5936, Test Acc: 0.5749
Epoch: 190, Train Acc: 0.6174, Test Acc: 0.6325
Epoch: 191, Train Acc: 0.6189, Test Acc: 0.6106
Epoch: 192, Train Acc: 0.5916, Test Acc:

Epoch: 343, Train Acc: 0.6291, Test Acc: 0.6406
Epoch: 344, Train Acc: 0.6287, Test Acc: 0.6382
Epoch: 345, Train Acc: 0.6291, Test Acc: 0.5829
Epoch: 346, Train Acc: 0.6111, Test Acc: 0.5737
Epoch: 347, Train Acc: 0.6267, Test Acc: 0.6325
Epoch: 348, Train Acc: 0.6053, Test Acc: 0.6233
Epoch: 349, Train Acc: 0.6267, Test Acc: 0.6244
Epoch: 350, Train Acc: 0.5999, Test Acc: 0.5657
Epoch: 351, Train Acc: 0.6287, Test Acc: 0.5841
Epoch: 352, Train Acc: 0.6038, Test Acc: 0.5703
Epoch: 353, Train Acc: 0.6277, Test Acc: 0.5945
Epoch: 354, Train Acc: 0.6243, Test Acc: 0.6198
Epoch: 355, Train Acc: 0.6243, Test Acc: 0.6221
Epoch: 356, Train Acc: 0.6179, Test Acc: 0.6175
Epoch: 357, Train Acc: 0.6204, Test Acc: 0.6336
Epoch: 358, Train Acc: 0.6126, Test Acc: 0.6267
Epoch: 359, Train Acc: 0.6121, Test Acc: 0.6221
Epoch: 360, Train Acc: 0.6199, Test Acc: 0.6348
Epoch: 361, Train Acc: 0.6101, Test Acc: 0.6302
Epoch: 362, Train Acc: 0.5161, Test Acc: 0.5046
Epoch: 363, Train Acc: 0.6014, Test Acc:

In [None]:
0.6509

In [30]:
t = np.load('test_full_data.npy', allow_pickle = True).item()

In [34]:
health = 0
mdd = 0
for elem in t.keys():
    if t[elem][1] == 0:
        health += 1
    elif t[elem][1] == 1:
        mdd += 1

In [35]:
health

433

In [36]:
mdd

435

In [37]:
(433 + 436) * 0.68

590.9200000000001