# GNN4DM: A Graph Neural Network-based method to identify overlapping functional disease modules

This is the original implementation of GNN4DM, a graph neural network-based structured model that automates the discovery of overlapping functional disease modules. GNN4DM effectively integrates network topology with genomic data to learn the representations of the genes corresponding to functional modules and align these with known biological pathways for enhanced interpretability.

First, we import the necessary libraries.

In [None]:
import torch
import torch_geometric
import networkx as nx
import pandas as pd
import numpy as np
import scipy as sp

from sklearn.model_selection import ParameterSampler

from models import *
from data import *
from utils import *

import statistics

from cdlib import evaluation

from torch.utils.tensorboard import SummaryWriter

# Report the versions of the core libraries for reproducibility.
for _label, _version in (
    ("PyTorch", torch.__version__),
    ("CUDA", torch.version.cuda),
    ("PyG", torch_geometric.__version__),
    ("NetworkX", nx.__version__),
):
    print(f"{_label} version: {_version}")

Second, we import the graph from STRING, read the input features from GTEx and the GWAS Atlas, and compute centrality measures (as additional input features).
In addition, we import the output features (used as auxiliary prediction tasks) from MSigDB.

In [None]:
# Input node-feature sources and the pathway databases used as auxiliary outputs.
inputFeatures = ['GTEx', 'GWASAtlas_nonukb_512', 'centrality_measures']
outputFeatures = ['kegg', 'reactome', 'wikipathways', 'biocarta', 'pid']

dataset = get_STRING_dataset(inputFeatures=inputFeatures, outputFeatures=outputFeatures)
show_dataset(dataset)

# Unpack the graph data object and the NetworkX graph from the loaded dataset.
data, G = dataset['data'], dataset['graph']

Move the data to the GPU (CUDA), if one is available.

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

data = data.to(device)

Define training and validation functions.

In [None]:
def train( model, data, source_edge_index, pos_edge_index, neg_edge_index, lambda_bce_loss, lambda_l1_positives_loss, lambda_l2_positives_loss, optimizer, writer, iter = None):
    """Run one optimisation step on a mini-batch of sampled edge triples.

    The minimised loss is

        Bernoulli-Poisson link loss
        + lambda_bce_loss * (BCE loss of every pathway database on its training indices)
        + 0.0 * (GCN L1 and L2 penalties; computed and logged, but not weighted in)
        + lambda_l1_positives_loss * L1 penalty of the output models
        + lambda_l2_positives_loss * L2 penalty of the output models

    Args:
        model: the GAEL model (provides encode/decode and the loss helpers used below).
        data: graph data object with ``x``, ``edge_index``, ``ys`` and ``train_indices``.
        source_edge_index, pos_edge_index, neg_edge_index: parallel index tensors of
            (source node, positive target, negative target), as produced by
            ``structured_negative_sampling`` in the training loop.
        lambda_bce_loss: weight of each per-database BCE loss.
        lambda_l1_positives_loss, lambda_l2_positives_loss: weights of the output-model
            L1/L2 penalties.
        optimizer: optimizer stepped once per call.
        writer: TensorBoard SummaryWriter used when ``iter`` is not None.
        iter: global step for logging; nothing is logged when it is None.
    """
    model.train()
    optimizer.zero_grad()

    # compute node embeddings (F) and per-database pathway predictions with the encoder
    F, y_preds = model.encode( data.x, data.edge_index )
    # compute link strength with the decoder for the positive and the negative edges, respectively
    H_positive = model.decode( F, source_edge_index, pos_edge_index )
    H_negative = model.decode( F, source_edge_index, neg_edge_index )

    ## Bernoulli-Poisson loss
    bp_loss = model.nll_BernoulliPoisson_loss( H_positive, H_negative )

    ## Binary cross-entropy losses for all pathway databases (training indices only)
    bce_losses = dict()
    for key in y_preds:
        idx = data.train_indices[key]
        if idx.dim() == 1:
            # 1-D index tensor: select whole rows
            bce_losses[key] = model.bce_loss( y_preds[key][idx,:], data.ys[key][idx,:] )
        else:
            # multi-dimensional index (presumably an element mask — TODO confirm) applied directly
            bce_losses[key] = model.bce_loss( y_preds[key][idx], data.ys[key][idx] )

    ## GCN L1 and L2 losses
    l1_loss_gcn, l2_loss_gcn = model.gcn_l1_l2_losses()
    ## L1 and L2 losses of the output (PositiveLinear) models
    l1_positives_loss, l2_positives_loss = model.output_models_l1_l2_losses()

    ## compute final loss
    # Use out-of-place additions: `loss += x` on a tensor is in-place, and because
    # `loss` starts as the very same tensor object as `bp_loss`, it would overwrite
    # `bp_loss` with the running total (so the logged Bernoulli-Poisson loss would
    # silently equal the total loss).
    loss = bp_loss
    for bce_loss in bce_losses.values():
        loss = loss + lambda_bce_loss * bce_loss
    loss = loss + 0.0 * l1_loss_gcn
    loss = loss + 0.0 * l2_loss_gcn
    loss = loss + lambda_l1_positives_loss * l1_positives_loss
    loss = loss + lambda_l2_positives_loss * l2_positives_loss

    if iter is not None:
        writer.add_scalar('train/loss', loss.item(), iter)
        writer.add_scalar('train/bernoulli_poisson_loss', bp_loss, iter)
        for key, bce_loss in bce_losses.items():
            writer.add_scalar(f"train/bce_loss_{key}", bce_loss, iter)
        writer.add_scalar('train/l1_loss_gcn', l1_loss_gcn, iter)
        writer.add_scalar('train/l2_loss_gcn', l2_loss_gcn, iter)
        writer.add_scalar('train/l1_positives_loss', l1_positives_loss, iter)
        writer.add_scalar('train/l2_positives_loss', l2_positives_loss, iter)

    # NOTE(review): retain_graph=True looks unnecessary because encode() rebuilds the
    # graph on every call; kept in case the model reuses graph parts across calls — confirm.
    loss.backward(retain_graph=True)
    optimizer.step()

    return

@torch.no_grad()
def test(model, data, G : nx.Graph, writer=None, iter=None):
    """Evaluate the model on the validation split and score its predicted modules.

    Runs in eval mode without gradients and collects:
      * the BCE loss of each pathway database on ``data.valid_indices`` and their sum;
      * the metrics returned by ``calculateMetrics`` (per database and their means);
      * the cosine self-similarity loss and the module-size RMSE loss of the embeddings;
      * quality scores of the modules obtained by thresholding the embeddings
        (only when at least one module is predicted);
      * ONMI scores against ground-truth modules, when the module-level ``dataset``
        dict contains ``'groundtruth_communities'``.

    Args:
        model: the GAEL model to evaluate.
        data: graph data object with ``x``, ``edge_index``, ``ys`` and ``valid_indices``.
        G: the NetworkX graph the modules are defined on.
        writer: optional TensorBoard SummaryWriter; scores are logged as ``test/<name>``.
        iter: global step for logging; nothing is logged when it or ``writer`` is None.

    Returns:
        ``(internal_evaluation_scores, pred_communities, original_communities)``: the
        dict of scores plus the two values returned by ``getPredictecCommunities3``.
    """
    model.eval()

    # compute node embeddings (F) and per-database pathway predictions with the encoder
    F, y_preds = model.encode( data.x, data.edge_index )

    internal_evaluation_scores = {}

    # compute BCE loss in validation sets of pathway DBs
    bce_losses = dict()
    for key in y_preds:
        idx = data.valid_indices[key]
        if idx.dim() == 1:
            bce_losses[key] = model.bce_loss( y_preds[key][idx,:], data.ys[key][idx,:] )
        else:
            bce_losses[key] = model.bce_loss( y_preds[key][idx], data.ys[key][idx] )
    # record each loss and their sum
    valid_bce_loss = 0.0
    for key, bce_loss in bce_losses.items():
        internal_evaluation_scores[f"bce_loss_{key}"] = bce_loss.item()
        valid_bce_loss += bce_loss.item()
    internal_evaluation_scores["sum_bce_loss"] = valid_bce_loss

    # classification metrics in validation sets of pathway DBs (per DB, then means)
    accuracies, sensitivities, specificities, precisions, f1_scores, auprcs = calculateMetrics( y_preds, data.ys, data.valid_indices )
    metrics = {
        "accuracy": accuracies,
        "sensitivity": sensitivities,
        "specificity": specificities,
        "precision": precisions,
        "f1_score": f1_scores,
        "auprc": auprcs,
    }
    for key in accuracies:
        for metric_name, per_db in metrics.items():
            internal_evaluation_scores[f"{metric_name}_{key}"] = per_db[key]
    for metric_name, per_db in metrics.items():
        internal_evaluation_scores[f"mean_{metric_name}"] = statistics.mean(per_db.values())

    cosine_self_similarity_loss = model.cosine_similarity_loss( F, type = 'l2', min_threshold = 0.1 )
    internal_evaluation_scores["cosine_self_similarity_loss"] = cosine_self_similarity_loss.item()

    rmse_of_module_size_loss = model.rmse_of_module_size_loss( F = F, threshold = model.getThresholdOfCommunities(), expected_mean = 50.0 )
    internal_evaluation_scores["rmse_of_module_size_loss"] = rmse_of_module_size_loss.item()

    internal_evaluation_scores["thresholdOfCommunities"] = model.getThresholdOfCommunities().item()

    # ground-truth modules are optional and read from the module-level `dataset` dict
    has_groundtruth = 'groundtruth_communities' in dataset

    # get predicted communities
    pred_communities, original_communities = getPredictecCommunities3( F, G, threshold = model.getThresholdOfCommunities(), normalize=False )
    # and evaluate them (only meaningful when at least one module was found)
    if len(pred_communities.communities) > 0:
        sizes = np.array(pred_communities.size(summary=False))
        internal_evaluation_scores['count'] = len(pred_communities.communities)
        internal_evaluation_scores['average_size'] = pred_communities.size().score
        internal_evaluation_scores['max_size'] = np.max(sizes)
        for limit in (5, 10, 50, 100, 200):
            internal_evaluation_scores[f'num_smaller_{limit}'] = np.sum(sizes <= limit)
        internal_evaluation_scores['node_coverage'] = pred_communities.node_coverage
        internal_evaluation_scores['mean_module_per_node'] = (F >= model.getThresholdOfCommunities()).sum(dim=1).float().mean().item()
        internal_evaluation_scores['average_internal_degree'] = pred_communities.average_internal_degree().score
        internal_evaluation_scores['conductance'] = pred_communities.conductance().score
        internal_evaluation_scores['internal_edge_density'] = pred_communities.internal_edge_density().score
        internal_evaluation_scores['fraction_over_median_degree'] = pred_communities.fraction_over_median_degree().score
        # avg_embeddedness and modularity_overlap are skipped: they take ages to compute

        if has_groundtruth:
            internal_evaluation_scores['onmi'] = evaluation.overlapping_normalized_mutual_information_MGH( dataset['groundtruth_communities'], pred_communities ).score

    if has_groundtruth:
        pred_communities_max = getPredictecCommunities2( F, G, threshold = None )

        if len(pred_communities_max.communities) > 0:
            internal_evaluation_scores['nmi_max'] = evaluation.overlapping_normalized_mutual_information_MGH( dataset['groundtruth_communities'], pred_communities_max ).score

    # Log every collected score, including the validation losses/metrics,
    # even when no module was predicted.
    if writer is not None and iter is not None:
        for key, value in internal_evaluation_scores.items():
            writer.add_scalar(f"test/{key}", value, iter)

    return internal_evaluation_scores, pred_communities, original_communities

Define the hyper-parameter search space.

In [None]:
rng = np.random.RandomState(65)

# Random-search space: each key maps to the candidate values of one
# hyper-parameter; single-element lists pin that hyper-parameter.
param_grid = {
    'learning_rate': [0.001],
    'learning_rate_decay_step_size': [250],
    'weight_decay': [0.0],
    'hidden_channels_before_module_representation': [128],
    'module_representation_length': [500, 600, 700, 800, 900, 1000],
    'threshold': ['auto'],  # alternative values: 0.25, 0.5, 0.75, 1.0, 'auto'
    'batchnorm': [True],
    'dropout': [0.0],
    'lambda_bce_loss': [10.0],
    'lambda_l1_positives_loss': [0.0],
    'lambda_l2_positives_loss': [0.0],
    'model_type': ["GCN"],
    'model_output': ['kegg|reactome|wikipathways|biocarta'],  # '|'-separated pathway databases
    'transform_probability_method': ['tanh'],
}

# Draw 4 configurations from the space, reproducibly via the seeded rng.
param_list = list(ParameterSampler(param_grid, n_iter=4, random_state=rng))

Finally, train and validate the models.

In [7]:
import traceback

epochs = 5000
eval_steps = 50

batch_index = 0
run_index = 0

# number of mini-batches the shuffled edge triples are split into per epoch
num_parts = 10

final_results = list()

for params in param_list:

    # one PositiveLinear output head per requested pathway database present in the dataset
    output_models = {}
    for db in params['model_output'].split('|'):
        if db in dataset['msigdb_term_ids']:
            output_models[db] = PositiveLinear( in_features = params['module_representation_length'], 
                                                out_features = len(dataset['msigdb_term_ids'][db]), 
                                                ids = dataset['msigdb_term_ids'][db] ).to(device)

    model = GAEL( encoder = GNN( in_channels = data.num_features, 
                                 hidden_channels_before_module_representation = [params['hidden_channels_before_module_representation']], 
                                 module_representation_channels = params['module_representation_length'], 
                                 out_models = output_models, 
                                 dropout = params['dropout'], 
                                 batchnorm = params['batchnorm'], 
                                 transform_probability_method = params['transform_probability_method'],
                                 threshold = params['threshold'],
                                 type = params['model_type'] ),
                  decoder = InnerProductDecoder() ).to(device)

    print(model)

    total_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters:", total_params)

    writer = None
    optimizer = scheduler = None

    try:
        # every artefact of this run (TensorBoard events, modules, checkpoints, weights) goes here
        run_dir = f"runs/{G.name}/batch_{batch_index}/{model.name}_bce{params['lambda_bce_loss']}_{run_index:03d}"
        writer = SummaryWriter(run_dir)

        optimizer, scheduler = model.configure_optimizers( params )
        model.reset_parameters()

        step = 1  # global mini-batch counter, used as the TensorBoard step inside train()
        # start training
        for epoch in range(1, 1 + epochs):

            # fresh (source, positive target, negative target) triples every epoch
            source_edge_index, pos_edge_index, neg_edge_index = torch_geometric.utils.structured_negative_sampling( edge_index = data.edge_index )

            # shuffle the three index tensors with one shared permutation
            num_samples = len(source_edge_index)
            perm = torch.randperm(num_samples)
            source_edge_index = source_edge_index[perm]
            pos_edge_index = pos_edge_index[perm]
            neg_edge_index = neg_edge_index[perm]

            # split into num_parts (approximately) equal mini-batches
            part_size = num_samples // num_parts
            for i in range(num_parts):
                start_idx = i * part_size
                # the last part absorbs the remainder of the integer division
                end_idx = num_samples if i == num_parts - 1 else (i + 1) * part_size

                # training step
                train( model, data,
                       source_edge_index[start_idx:end_idx],
                       pos_edge_index[start_idx:end_idx],
                       neg_edge_index[start_idx:end_idx],
                       lambda_bce_loss = params['lambda_bce_loss'], 
                       lambda_l1_positives_loss = params['lambda_l1_positives_loss'], lambda_l2_positives_loss = params['lambda_l2_positives_loss'], 
                       optimizer = optimizer, writer = writer, iter = step )
                step += 1

            # validation
            if epoch % eval_steps == 0:
                valid_results, pred_communities, original_communities = test( model, data, G, writer = writer, iter = epoch )

                results = {'epoch': epoch, **valid_results}

                print( "\n", results )

                # export results
                saveResultsToJSON( valid_results, original_communities, f"./{run_dir}/modules_{epoch:05d}.json", dataset['node_index_dict'] )

                # save model
                torch.save(model.state_dict(), f"./{run_dir}/model_{epoch:05d}.pt")

                # export weights of final model layers (one column per pathway term id)
                ws, ids = [], []
                for name in model.encoder.output_models:
                    ws.append( model.encoder.output_models[name].weight.cpu().detach().numpy().T )
                    ids.extend( model.encoder.output_models[name].id_list )

                df_w = pd.DataFrame( data = np.hstack(ws), columns=ids)
                df_w.to_csv( f"./{run_dir}/weights_{epoch:05d}.csv" )

                final_results.append( { 'model': str(model), 'model-name': model.name, **params, 'epoch': epoch, 'run': run_index, **valid_results } )

                # rewrite the summary CSV after every evaluation so partial results survive
                df = pd.DataFrame( final_results )
                df.to_csv( f'randomsearch_results{batch_index}_{G.name}.csv' )

            # apply learning rate scheduler
            scheduler.step()

        df = pd.DataFrame( final_results )
        df.to_csv( f'randomsearch_results{batch_index}_{G.name}.csv' )

    except (Exception, KeyboardInterrupt) as error:
        # KeyboardInterrupt is caught as well, so Ctrl-C skips to the next
        # configuration instead of ending the whole search; SystemExit propagates.
        print( model )
        print('An exception occurred: {}'.format(error))
        traceback.print_exc()

    finally:
        # flush and close the TensorBoard event file of this run
        if writer is not None:
            writer.close()
        # drop references to the model and its optimizer state so cached GPU memory can be released
        del model
        optimizer = scheduler = None
        torch.cuda.empty_cache()
        # advance even after a failure so the next configuration gets its own directory and 'run' id
        run_index += 1


GAEL( Encoder: GCNConv(571, 128); ReLU; BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True); GCNConv(128, 800); SoftPlus; {PositiveLinear(in_features=800, out_features=186, bias=True)|PositiveLinear(in_features=800, out_features=1288, bias=True)|PositiveLinear(in_features=800, out_features=621, bias=True)|PositiveLinear(in_features=800, out_features=221, bias=True)}, Decoder: InnerProductDecoder )
An exception occurred: 
GAEL( Encoder: GCNConv(571, 128); ReLU; BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True); GCNConv(128, 900); SoftPlus; {PositiveLinear(in_features=900, out_features=186, bias=True)|PositiveLinear(in_features=900, out_features=1288, bias=True)|PositiveLinear(in_features=900, out_features=621, bias=True)|PositiveLinear(in_features=900, out_features=221, bias=True)}, Decoder: InnerProductDecoder )
Total number of parameters: 2276289
decay: {'encoder.conv_last.lin.weight', 'encoder.output_models.kegg.weight', 'en