**------------------------------------------------------------------------------------------------------------------------------------------------------**

**Input: Drug Repurposing Knowledge Graph (DRKG)**

**Output: trained GCN, GraphSAGE and GAT models for the `DRUGBANK::treats::Compound:Disease` relation, saved under `Output/GNNModels/`**

**------------------------------------------------------------------------------------------------------------------------------------------------------**

# Libraries

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import json

import os
import re
import random
import itertools

import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F

import dgl
import dgl.nn as dglnn
import dgl.function as fn
from dgl.nn import HeteroGraphConv, SAGEConv, GraphConv, GATConv
from dgl.data.utils import save_graphs, load_graphs

from torch_geometric.explain import characterization_score

from captum.attr import Saliency, IntegratedGradients

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

from functools import partial

import warnings
warnings.simplefilter("ignore")

from src.utils import *
from src.gnn import *

# 1) Get Data

**Get DRKG**

In [2]:
# Read the DRKG file (tab-separated, no header row) and discard any row
# that has a missing field before building the graph from it.
df = (
    pd.read_csv('Input/DRKG/drkg.tsv', sep='\t', header=None)
    .dropna()
)

**Create HeteroGraph**

In [3]:
# Assign integer ids to the entities in `df`, translate the rows into
# per-edge-type id lists, and assemble them into one DGL heterograph.
node_dict = get_node_dict(df)
edge_dict = get_edge_dict(df, node_dict)
g = dgl.heterograph(data_dict=edge_dict)

**Add random node features**

In [4]:
# Seed every RNG the pipeline may draw from before the first stochastic step
# (the random node features below), so Restart & Run All reproduces the features
# and, as far as the helpers rely on these generators, the split and the trained models.
# NOTE(review): confirm src.utils / src.gnn do not create their own unseeded generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
dgl.seed(SEED)

g, node_features = add_node_features(g)

**Split into train/test graphs**

In [5]:
# Edge type the models learn to predict, written as a
# (source node type, relation name, destination node type) triple.
etype2pred = ('Compound', 'DRUGBANK::treats::Compound:Disease', 'Disease')

# Fail fast on a typo in the relation name instead of deep inside the split/training code.
assert etype2pred in g.canonical_etypes, f"{etype2pred} is not an edge type of the DRKG graph"

In [6]:
# Derive the training and test graphs for the target relation `etype2pred`.
# NOTE(review): the split ratio, and whether held-out edges are removed from g_train
# (to avoid leakage), are decided inside split_train_test (star-imported from src) — confirm there.
g_train, g_test = split_train_test(g, etype2pred)

# 2) Train Graph Neural Networks

**Graph Convolutional Network**

In [7]:
# torch.save does not create missing parent directories; create the output
# directory before training so a fresh checkout cannot lose the trained model.
os.makedirs('Output/GNNModels', exist_ok=True)

gcn_model = Model(gnn_variant='GCN',
                  etypes=g.etypes,
                  etype2pred=etype2pred,
                  g_train=g_train,
                  g_test=g_test,
                  node_features=node_features)
gcn_model._train()
gcn_model._eval()
# Saves the whole pickled object: reloading needs the module defining `Model` importable and,
# on PyTorch >= 2.6, torch.load(..., weights_only=False).
torch.save(gcn_model, 'Output/GNNModels/GCN')

Epoch: 0, Loss: 10.2109
Epoch: 50, Loss: 0.2222
Epoch: 100, Loss: 0.1219
Precision: 0.9512
Recall: 0.0392
F1-Score: 0.0754


**GraphSAGE**

In [8]:
# torch.save does not create missing parent directories; create the output
# directory before training so a fresh checkout cannot lose the trained model.
os.makedirs('Output/GNNModels', exist_ok=True)

graphsage_model = Model(gnn_variant='GraphSAGE',
                        etypes=g.etypes,
                        etype2pred=etype2pred,
                        g_train=g_train,
                        g_test=g_test,
                        node_features=node_features)
graphsage_model._train()
graphsage_model._eval()
# Saves the whole pickled object: reloading needs the module defining `Model` importable and,
# on PyTorch >= 2.6, torch.load(..., weights_only=False).
torch.save(graphsage_model, 'Output/GNNModels/GraphSAGE')

Epoch: 0, Loss: 9935.2969
Epoch: 50, Loss: 66.1691
Epoch: 100, Loss: 97.5386
Precision: 0.9417
Recall: 0.3089
F1-Score: 0.4652


**Graph Attention Network**

In [9]:
# torch.save does not create missing parent directories; create the output
# directory before training so a fresh checkout cannot lose the trained model.
os.makedirs('Output/GNNModels', exist_ok=True)

gat_model = Model(gnn_variant='GAT',
                  etypes=g.etypes,
                  etype2pred=etype2pred,
                  g_train=g_train,
                  g_test=g_test,
                  node_features=node_features)
gat_model._train()
gat_model._eval()
# Saves the whole pickled object: reloading needs the module defining `Model` importable and,
# on PyTorch >= 2.6, torch.load(..., weights_only=False).
torch.save(gat_model, 'Output/GNNModels/GAT')

Epoch: 0, Loss: 538.9611
Epoch: 50, Loss: 3.3144
Epoch: 100, Loss: 2.1257
Precision: 0.1747
Recall: 0.9688
F1-Score: 0.2960
