**------------------------------------------------------------------------------------------------------------------------------------------------------**

**Input: Drug Repurposing Knowledge Graph (DRKG)**

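**This notebook trains GCN, GraphSAGE and GAT models on DRKG to predict `Compound`–`treats`–`Disease` links, saves the trained models and their node embeddings, and evaluates them.**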

**------------------------------------------------------------------------------------------------------------------------------------------------------**

# Libraries

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import json

import os
import re
import random
import itertools

import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F

import dgl
import dgl.nn as dglnn
import dgl.function as fn
from dgl import AddReverse
from dgl.nn import HeteroGraphConv, SAGEConv, GraphConv, GATConv

import warnings
# Silence all warnings for cleaner notebook output.
# NOTE(review): this also hides PyTorch/DGL deprecation warnings that may flag real problems.
warnings.simplefilter("ignore")

from src.utils import *
from src.gnn import *

In [None]:
# Target relation for link prediction: DrugBank "Compound treats Disease" edges,
# given as a canonical edge type (source node type, relation, destination node type).
etype2pred = ('Compound', 'DRUGBANK::treats::Compound:Disease', 'Disease')

# Seed the Python, NumPy and PyTorch RNGs so that random node features, the
# train/test split and weight initialisation are repeatable across runs.
# NOTE(review): the `src` helpers may also draw from DGL's own RNG (`dgl.seed`)
# or use nondeterministic GPU kernels — confirm runs are fully reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 1) Get Data

**Get DRKG**

In [None]:
# Load the raw DRKG triples from the tab-separated file (no header row) and drop
# incomplete rows. A chained `.dropna()` replaces `inplace=True` so re-running
# this cell always starts from the file.
# NOTE(review): presumably three columns — head entity, relation, tail entity; confirm.
df = pd.read_csv('Input/DRKG/drkg.tsv', sep='\t', header=None).dropna()

**Create HeteroGraph**

In [None]:
# Build the DGL heterograph from the DRKG triples in `df`.
# NOTE(review): `get_node_dict` / `get_edge_dict` come from the star-imported `src`
# helpers; presumably they map entity names to per-node-type integer ids and group
# the triples into {canonical edge type: (src ids, dst ids)} — confirm.
node_dict = get_node_dict(df)
edge_dict = get_edge_dict(df, node_dict)
g = dgl.heterograph(edge_dict)

**Add reverse edges so that the GNN can pass messages in both directions**

In [None]:
# DGL transform that adds reverse edges so messages can flow both ways (default options).
transform = AddReverse()

In [None]:
# Add the reverse edges to `g`. AddReverse names the new edge types 'rev_<etype>',
# so skip it when they already exist: re-running this cell would otherwise add
# reverse edges a second time (and reverse the reverse edges).
if not any(etype.startswith('rev_') for etype in g.etypes):
    g = transform(g)

**Add random node features**

In [None]:
# Attach random input features to the graph's nodes; `node_features` is passed to
# every `Model` below.
# NOTE(review): `add_node_features` comes from the star-imported `src` helpers and
# presumably draws from a global RNG, so features differ between runs unless seeded.
g, node_features = add_node_features(g)

**Save the processed graph and split it into train/test graphs**

In [None]:
# Persist `g` as it stands after the reverse-edge and node-feature steps above.
# NOTE(review): this writes into the Input/ folder, next to the raw drkg.tsv.
dgl.save_graphs('Input/DRKG/drkg', g)

In [None]:
# Split `g` into training and test graphs for the target relation `etype2pred`.
# NOTE(review): `split_train_test` comes from the star-imported `src` helpers;
# presumably it holds out a subset of `etype2pred` edges for testing — confirm.
g_train, g_test = split_train_test(g, etype2pred)

**Create folders**

In [None]:
# Create the output directories; `exist_ok=True` makes this a safe no-op when they
# already exist (no check-then-create race).
# `Output/GNNModels` receives the models and embeddings saved below.
# NOTE(review): `Output/GNNEmbeddings` is not written in this notebook — confirm it
# is used elsewhere or drop it.
for _output_dir in ('Output', 'Output/GNNEmbeddings', 'Output/GNNModels'):
    os.makedirs(_output_dir, exist_ok=True)

# 2) Train Graph Neural Networks

**Graph Convolutional Network**

In [None]:
# Train the GCN variant for the `etype2pred` link-prediction task, then save its
# node embeddings and the fitted model object.
# `Model` comes from the star-imported `src` helpers.
gcn_model = Model(
    gnn_variant='GCN',
    etypes=g.etypes,
    etype2pred=etype2pred,
    g_train=g_train,
    g_test=g_test,
    node_features=node_features,
)
gcn_model._train(epochs=150)

# Extract the embeddings first, then pickle the whole model object.
gcn_model_embedding = gcn_model.get_embeddings()
torch.save(gcn_model_embedding, 'Output/GNNModels/GCN_embedding')
torch.save(gcn_model, 'Output/GNNModels/GCN')

**GraphSAGE**

In [None]:
# Train the GraphSAGE variant on the shared train/test split and node features,
# then pickle the fitted model and save the embeddings it produces.
# `Model` comes from the star-imported `src` helpers.
# This cell no longer touches `gcn_model`, so it runs on a fresh kernel even if
# the GCN cell was skipped.
graphsage_model = Model(
    gnn_variant='GraphSAGE',
    etypes=g.etypes,
    etype2pred=etype2pred,
    g_train=g_train,
    g_test=g_test,
    node_features=node_features,
)
graphsage_model._train(epochs=150)
torch.save(graphsage_model, 'Output/GNNModels/GraphSAGE')
graphsage_model_embeddings = graphsage_model.get_embeddings()
torch.save(graphsage_model_embeddings, 'Output/GNNModels/GraphSAGE_embedding')


**Graph Attention Network**

In [None]:
# Train the GAT variant on the shared train/test split and node features, then
# pickle the fitted model and save the embeddings it produces.
# `Model` comes from the star-imported `src` helpers.
gat_model = Model(
    gnn_variant='GAT',
    etypes=g.etypes,
    etype2pred=etype2pred,
    g_train=g_train,
    g_test=g_test,
    node_features=node_features,
)
gat_model._train(epochs=150)
torch.save(gat_model, 'Output/GNNModels/GAT')
gat_model_embeddings = gat_model.get_embeddings()
torch.save(gat_model_embeddings, 'Output/GNNModels/GAT_embedding')

# 3) Evaluate Graph Neural Networks

In [None]:
# Reload the GCN model saved above and report its link-prediction metrics.
# The file is a pickled `Model` object written by this notebook (trusted), so
# `weights_only=False` is required: PyTorch >= 2.6 defaults to weights-only
# loading, which rejects arbitrary pickled objects. Needs PyTorch >= 1.13.
gcn_model = torch.load('Output/GNNModels/GCN', weights_only=False)
hits5, hits10, precision, recall, f1 = gcn_model._eval()
print(f'hits@5: {hits5:.3f}, hits@10: {hits10:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1-score: {f1:.3f}')

In [None]:
# Reload the GraphSAGE model saved above and report its link-prediction metrics.
# The file is a pickled `Model` object written by this notebook (trusted), so
# `weights_only=False` is required: PyTorch >= 2.6 defaults to weights-only
# loading, which rejects arbitrary pickled objects. Needs PyTorch >= 1.13.
graphsage_model = torch.load('Output/GNNModels/GraphSAGE', weights_only=False)
hits5, hits10, precision, recall, f1 = graphsage_model._eval()
print(f'hits@5: {hits5:.3f}, hits@10: {hits10:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1-score: {f1:.3f}')

In [None]:
# Reload the GAT model saved above and report its link-prediction metrics.
# The file is a pickled `Model` object written by this notebook (trusted), so
# `weights_only=False` is required: PyTorch >= 2.6 defaults to weights-only
# loading, which rejects arbitrary pickled objects. Needs PyTorch >= 1.13.
gat_model = torch.load('Output/GNNModels/GAT', weights_only=False)
hits5, hits10, precision, recall, f1 = gat_model._eval()
print(f'hits@5: {hits5:.3f}, hits@10: {hits10:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, f1-score: {f1:.3f}')