# Imports + settings

In [1]:
#Python basic
import os
import math

#Torch
import torch
import torch.nn.functional as F
from torch_geometric import seed_everything
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_networkx
from torch_geometric.nn import (GCNConv, aggr, 
                                global_mean_pool, 
                                global_max_pool, 
                                MessagePassing)

#Other
import networkx as nx
from scripts.prepare_data import create_data_split
import pandas as pd

In [2]:
file_list = os.listdir('data.nosync/networks_multi')
create_data_split(files = file_list, train_val_size=0.15)
class_dict = {'TD': 0, 'ASD-ADHD':1, 'ASD':2, 'ADHD':3}

Train size: 164, Test size: 36, Val size: 37
Train size: 126, Test size: 28, Val size: 28
Train size: 17, Test size: 3, Val size: 3
Train size: 16, Test size: 4, Val size: 4
Train size: 22, Test size: 4, Val size: 4
Train size: 40, Test size: 8, Val size: 8
Train size: 10, Test size: 2, Val size: 2
Train size: 66, Test size: 14, Val size: 14


In [3]:
def load_dataset(dataset:str):
    file_list = pd.read_csv(f'data.nosync/networks_multi/{dataset}_set_files.csv')['file'].to_list()
    data_list = []

    for i in file_list:
        network_class = i.split('_')[3]
        G = nx.read_gml(f'data.nosync/networks_multi/{i}')
        for e in G.edges():
            u, v = e
            for key, value in G[u][v].items():
                G[u][v][key]['edge_features'] = max(value['edge_features']+1, 0) if not math.isnan(value['edge_features']) else 0
        graph = from_networkx(G, group_node_attrs = 'node_features', group_edge_attrs='edge_features')
        graph.y = class_dict[network_class]
        data_list.append(graph)
    return data_list

In [5]:
train = load_dataset(dataset='train')
val = load_dataset(dataset='val')

In [6]:
import numpy as np
from math import log
classes = []
for i in train:
    classes.append(i.y)
np.unique(classes, return_counts= True)

(array([0, 1, 2, 3]), array([214,  27,  56, 164]))

In [7]:
#expected loss
class_dict = {0:214, 1:27, 2:56, 3:164}
total = len(classes)
total_loss = 0
for key, value in class_dict.items():
    total_loss += (-(value/total)*log(value/total))
print(total_loss)

1.1461870854236984


In [9]:
seed_everything(42)

class MyConv(MessagePassing):
    def __init__(self, in_, out_, agg_func):
        super().__init__(agg_func)
        self.conv = self.conv1 = GCNConv(in_, out_, 
                                         add_self_loops = True,
                                         normalize = True)

    def forward(self, x, edge_index, edge_weight):
        x = self.conv(x, edge_index, edge_weight)
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(x)
        return x

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = MyConv(16, 16, aggr.MultiAggregation(
                                    ['mean', 'std']))
        self.conv2 = MyConv(16, 16, aggr.MeanAggregation())
        self.conv3 = MyConv(16, 16, aggr.MeanAggregation())
        self.conv4 = MyConv(16, 8, aggr.MeanAggregation())
        self.conv5 = MyConv(8, 4, aggr.MeanAggregation())

    def after_each_layer(self, x):
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        return x

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr.to(torch.float32)
        batch = data.batch

        x = self.conv1(x, edge_index, edge_weight)
        #x = self.after_each_layer(x)
        
        x = self.conv2(x, edge_index, edge_weight)
        #x = self.after_each_layer(x)
        
        x = self.conv3(x, edge_index, edge_weight)
        #x = self.after_each_layer(x)
        
        x = self.conv4(x, edge_index, edge_weight)

        x = self.conv5(x, edge_index, edge_weight)

        #Get global value for network
        x = global_max_pool(x, batch)
        
        return F.log_softmax(x, dim=1)
    
    def predict(self, inputs):
        with torch.no_grad():
            y = self.forward(inputs, train=False)
            predictions = torch.argmax(y, dim=-1)  # Take the argmax 
            return predictions
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()

loader = DataLoader(train, batch_size=64, shuffle = True)
val_loder = DataLoader(val, batch_size=64, shuffle = True)

min_score = 10**6

for epoch in range(100):
    epoch_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss

    epoch_loss = epoch_loss/len(loader)
    
    if epoch_loss < min_score:
        min_score = epoch_loss
        
        val_loss = 0.0
        with torch.no_grad():  
                for val_batch in val_loder:
                    val_y = model.forward(val_batch)
                    val_loss_batch = F.nll_loss(val_y, val_batch.y)
                    val_loss += val_loss_batch
                        
        val_loss_avg = val_loss / len(val_loder)
        print(f"Epoch: {epoch}, Train loss: {round(min_score.item(), 5)}, Val loss: {round(val_loss_avg.item(), 5)}")

Epoch: 0, Train loss: 1.2887, Val loss: 1.2439
Epoch: 1, Train loss: 1.23631, Val loss: 1.21179
Epoch: 2, Train loss: 1.22288, Val loss: 1.19972
Epoch: 3, Train loss: 1.17042, Val loss: 1.19048
Epoch: 4, Train loss: 1.15512, Val loss: 1.1533
Epoch: 10, Train loss: 1.15259, Val loss: 1.13113
Epoch: 11, Train loss: 1.1391, Val loss: 1.12833
Epoch: 14, Train loss: 1.13307, Val loss: 1.13215
Epoch: 30, Train loss: 1.12015, Val loss: 1.14545
Epoch: 41, Train loss: 1.11517, Val loss: 1.14923
Epoch: 72, Train loss: 1.11184, Val loss: 1.12012
Epoch: 91, Train loss: 1.10896, Val loss: 1.16049
Epoch: 94, Train loss: 1.10355, Val loss: 1.15124
