# Imports + settings

In [1]:
#Python basic
import os
import math

#Torch
import torch
import torch.nn.functional as F
from torch_geometric import seed_everything
from torch_geometric.loader import DataLoader
from torch_geometric.utils import from_networkx
from torch_geometric.nn import (GCNConv, aggr, 
                                global_mean_pool, 
                                global_max_pool, 
                                MessagePassing)

#Other
import networkx as nx
from scripts.prepare_data import create_data_split
import pandas as pd

from sklearn.metrics import f1_score


In [2]:
file_list = os.listdir('data.nosync/networks_multi_test')
#create_data_split(files = file_list, train_val_size=0.15)
class_dict = {'TD': 0, 'ASD-ADHD':1, 'ASD':2, 'ADHD':3}

In [3]:
def load_dataset(dataset:str):
    file_list = pd.read_csv(f'data.nosync/networks_multi_test/{dataset}_set_files.csv')['file'].to_list()
    data_list = []

    for i in file_list:
        network_class = i.split('_')[3]
        G = nx.read_gml(f'data.nosync/networks_multi_test/{i}')
        for e in G.edges():
            u, v = e
            for key, value in G[u][v].items():
                G[u][v][key]['feature_value'] = max(value['feature_value'], 0) if not math.isnan(value['feature_value']) else 0
                if type(G[u][v][key]['feature_value']) != int and type(G[u][v][key]['feature_value']) != float:
                    print(type(G[u][v][key]['feature_value']))
        graph = from_networkx(G, group_node_attrs = ['var_bin_21_1', 'var_bin_21_2', 
                                                     'var_bin_21_3', 'var_bin_21_4',
                                                     'var_bin_21_5', 'var_bin_21_6',
                                                     'var_bin_21_7', 'var_bin_21_8',
                                                     ], 
                                                     group_edge_attrs=['feature_value'])
        graph.y = class_dict[network_class]
        data_list.append(graph)
    return data_list

In [4]:
train_data = load_dataset(dataset='train')
val_data = load_dataset(dataset='val')

In [5]:
class MyConv(MessagePassing):
    def __init__(self, in_, out_, agg_func, dropout, activation):
        super().__init__(agg_func)
        self.conv = self.conv1 = GCNConv(in_, out_, 
                                         add_self_loops = True,
                                         normalize = True)
        self.p = dropout
        if activation == 'sigmoid':
           self.activation = F.sigmoid
        elif activation == 'softmax':
           self.activation = F.softmax
        elif activation == 'relu':
           self.activation = F.relu


    def forward(self, x, edge_index, edge_weight):
        x = self.conv(x, edge_index, edge_weight)
        x = F.dropout(x, p=self.p, training=self.training)
        x = self.activation(x)
        return x

class GCN(torch.nn.Module):
    def __init__(self, out_, dropout_, activation_):
        super().__init__()
        self.conv1 = MyConv(in_ = 8, #Fixed to 16 features
                            out_ = out_, 
                            agg_func = aggr.MeanAggregation(),
                            dropout = dropout_,
                            activation = activation_)
        self.conv2 = MyConv(in_ = out_, 
                            out_ = 4, #Fixed to four classes
                            agg_func = aggr.MeanAggregation(),
                            dropout = dropout_,
                            activation = activation_)

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr.to(torch.float32)
        batch = data.batch

        x = self.conv1(x, edge_index, edge_weight)
        x = self.conv2(x, edge_index, edge_weight)
        
        #Get global value for network
        x = global_max_pool(x, batch)
        
        return F.log_softmax(x, dim=1)
    
    def predict(self, inputs):
        with torch.no_grad():
            y = self.forward(inputs)
            predictions = torch.argmax(y, dim=-1)  # Take the argmax 
            return predictions

def val_loop(model, val_loader, device):
    val_loss = 0
    with torch.no_grad():  
        for val_batch in val_loader:
            val_batch.to(device)
            val_y_hat = model.forward(val_batch)
            val_loss_batch = F.nll_loss(val_y_hat, val_batch.y)
            val_loss += val_loss_batch
    return val_loss/len(val_loader)

def train_loop(model, train_loader, optimizer, device):
    train_loss = 0
    model.train()
    for batch in train_loader:
        batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        train_loss += loss
    return train_loss/len(train_loader)

def main(layer_1_out, dropout, activation, batch_size, train_data, val_data, optimizer, learning_rate, num_epochs):
  seed_everything(42)

  print('Device:', 'cuda' if torch.cuda.is_available() else 'cpu', flush = True)

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = GCN(out_ = layer_1_out, 
              dropout_ = dropout, 
              activation_ = activation).to(device)
  
  #Make dataset loaders
  train_loader = DataLoader(train_data, batch_size= batch_size, shuffle = True)
  val_loader = DataLoader(val_data, batch_size= batch_size)
  
  # Optimizer
  if optimizer == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  else:
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
  min_val = 10**6
  
  
  for epoch in range(num_epochs):
    train_loss = train_loop(model, train_loader, optimizer, device)
    val_loss = val_loop(model, val_loader, device)
    if val_loss < min_val:
        min_val = val_loss
        print(f"Epoch: {epoch}, Train loss: {train_loss.item()}, Val loss: {val_loss.item()}")



In [8]:
main(layer_1_out = 32, 
     dropout = 0.5, 
     activation = 'relu', 
     batch_size = 32, 
     train_data = train_data, 
     val_data = val_data, 
     optimizer = 'sgd', 
     learning_rate = 0.001,
     num_epochs = 500)

Device: cpu
Epoch: 0, Train loss: 1.312562346458435, Val loss: 1.3357399702072144
Epoch: 1, Train loss: 1.2672550678253174, Val loss: 1.2697415351867676
Epoch: 2, Train loss: 1.237072229385376, Val loss: 1.220529556274414
Epoch: 3, Train loss: 1.2152904272079468, Val loss: 1.1509454250335693
Epoch: 5, Train loss: 1.2049084901809692, Val loss: 1.0887706279754639
Epoch: 8, Train loss: 1.1750200986862183, Val loss: 1.086126685142517
Epoch: 9, Train loss: 1.1940720081329346, Val loss: 1.0599455833435059
Epoch: 12, Train loss: 1.1894205808639526, Val loss: 1.055171251296997
Epoch: 20, Train loss: 1.1769447326660156, Val loss: 1.0198699235916138
Epoch: 44, Train loss: 1.1736016273498535, Val loss: 1.0197339057922363
Epoch: 47, Train loss: 1.1683127880096436, Val loss: 1.0098687410354614
Epoch: 159, Train loss: 1.1617472171783447, Val loss: 1.001219391822815
