**------------------------------------------------------------------------------------------------------------------------------------------------------**

**Input: Prediction(s) obtained from a trained GNN - GraphSAGE, GCN and GAT**

**This notebook returns a dictionary of the nodes and edges that most influenced the prediction(s), computed with Integrated Gradients and Saliency Maps**

**------------------------------------------------------------------------------------------------------------------------------------------------------**

# Libraries

In [1]:
import pandas as pd
import numpy as np
import operator
import matplotlib.pyplot as plt
import networkx as nx
import json

import os
import re
import random
import itertools

import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F

import dgl
import dgl.nn as dglnn
import dgl.function as fn
from dgl import AddReverse
from dgl.data.utils import save_graphs

from captum.attr import Saliency, IntegratedGradients
from torch_geometric.explain import characterization_score

from functools import partial

import warnings
warnings.simplefilter("ignore")

from src.utils import *
from src.gnn import *
from src.explain import *

In [2]:
# Canonical edge type (source type, relation, destination type) we want to predict
etype = ('Compound', 'DRUGBANK::treats::Compound:Disease', 'Disease')

In [3]:
keys = ['Gene', 'Compound', 'Disease', 'Biological Process', 'Molecular Function', 'Pathway']

# 1) Get Subgraph

**Get DRKG**

In [4]:
df = pd.read_csv('Input/DRKG/drkg.tsv', sep='\t', header=None)
df = df.dropna()

**Define the subgraph for Alzheimer's disease: drop every 'Compound treats Disease' edge whose disease is not Alzheimer's (`Disease::MESH:D000544`)**

In [5]:
labels = df[(df[1] == 'DRUGBANK::treats::Compound:Disease') & (df[2] != 'Disease::MESH:D000544')].index
df = df.drop(labels=labels)
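
The filter above can be checked on a toy frame. This is a minimal sketch: the three rows and the compound and gene identifiers are made up for illustration, and only the column layout (head, relation, tail) and the Alzheimer's MeSH identifier come from the cell above.

```python
import pandas as pd

# Toy triples in the same (head, relation, tail) layout as drkg.tsv.
# The rows are illustrative assumptions, not real DRKG content.
toy = pd.DataFrame([
    ['Compound::A', 'DRUGBANK::treats::Compound:Disease', 'Disease::MESH:D000544'],
    ['Compound::B', 'DRUGBANK::treats::Compound:Disease', 'Disease::MESH:D000001'],
    ['Compound::A', 'DRUGBANK::target::Compound:Gene', 'Gene::1'],
])

TREATS = 'DRUGBANK::treats::Compound:Disease'
ALZHEIMER = 'Disease::MESH:D000544'

# Select "treats" edges whose tail is any disease other than Alzheimer's ...
to_drop = toy[(toy[1] == TREATS) & (toy[2] != ALZHEIMER)].index
# ... and drop them; every other relation type is kept untouched.
filtered = toy.drop(labels=to_drop)

print(filtered)
```

Only the second row is removed: the Alzheimer's "treats" edge and the non-"treats" edge both survive, so the remaining graph still carries the biological context around the compounds.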

**Create HeteroGraph**

In [6]:
node_dict = get_node_dict(df)
node_dict = {k:v for k,v in node_dict.items() if k in keys}
edge_dict = get_edge_dict(df, node_dict)
g = dgl.heterograph(edge_dict)
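
`get_node_dict` and `get_edge_dict` live in `src.utils` and are not shown in this notebook. The sketch below is one plausible way to build the same kind of structures, assuming DRKG entity names of the form `Type::identifier`; the helper names `build_node_dict` and `build_edge_dict` and the toy triples are hypothetical.

```python
from collections import defaultdict

def build_node_dict(triples):
    """Map each node type to {entity name: integer id}, ids in first-seen order."""
    node_dict = defaultdict(dict)
    for head, _, tail in triples:
        for entity in (head, tail):
            ntype = entity.split('::')[0]  # text before the first '::' is the node type
            node_dict[ntype].setdefault(entity, len(node_dict[ntype]))
    return dict(node_dict)

def build_edge_dict(triples, node_dict):
    """Map (src_type, relation, dst_type) to parallel lists of source and destination ids."""
    edges = defaultdict(lambda: ([], []))
    for head, rel, tail in triples:
        src_type, dst_type = head.split('::')[0], tail.split('::')[0]
        # Skip edges touching a node type that was filtered out of node_dict.
        if src_type not in node_dict or dst_type not in node_dict:
            continue
        src, dst = edges[(src_type, rel, dst_type)]
        src.append(node_dict[src_type][head])
        dst.append(node_dict[dst_type][tail])
    return dict(edges)

# Illustrative triples (made up); 'Anatomy' is deliberately outside the kept types.
triples = [
    ('Compound::A', 'DRUGBANK::treats::Compound:Disease', 'Disease::MESH:D000544'),
    ('Compound::B', 'DRUGBANK::treats::Compound:Disease', 'Disease::MESH:D000544'),
    ('Compound::A', 'DRUGBANK::target::Compound:Gene', 'Gene::1'),
    ('Anatomy::X', 'Hetionet::AeG::Anatomy:Gene', 'Gene::1'),
]
kept_types = ['Gene', 'Compound', 'Disease']

toy_nodes = build_node_dict(triples)
toy_nodes = {k: v for k, v in toy_nodes.items() if k in kept_types}
toy_edges = build_edge_dict(triples, toy_nodes)
print(toy_edges)
```

The resulting dictionary has the shape `dgl.heterograph` accepts: canonical edge types as keys and a pair of source and destination id sequences as values.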

In [7]:
with open('Output/Explainability/Alzheimer/node_mapping.json', 'w') as file:
    json.dump(node_dict, file)

**Add reverse edges**

In [8]:
transform = AddReverse()
g = transform(g)
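
Reverse edges let messages flow in both directions during message passing. The sketch below imitates, on a plain edge dictionary, what `AddReverse` does with its default settings as I understand the DGL documentation: an edge type `(A, E, B)` gains a companion `(B, rev_E, A)`, while an edge type whose two ends share a node type gets the flipped pairs appended in place. The function name `add_reverse_edges` and the toy relations are hypothetical.

```python
def add_reverse_edges(edge_dict):
    """Return a copy of edge_dict with a reversed counterpart for every edge."""
    out = {}
    for (src_type, rel, dst_type), (src, dst) in edge_dict.items():
        if src_type == dst_type:
            # Same node type on both ends: keep one relation, add the flipped pairs to it.
            out[(src_type, rel, dst_type)] = (list(src) + list(dst), list(dst) + list(src))
        else:
            out[(src_type, rel, dst_type)] = (list(src), list(dst))
            # Different node types: create a new 'rev_' relation pointing the other way.
            out[(dst_type, 'rev_' + rel, src_type)] = (list(dst), list(src))
    return out

toy_edges = {
    ('Compound', 'treats', 'Disease'): ([0, 1], [0, 0]),
    ('Gene', 'interacts', 'Gene'): ([0], [1]),
}
with_reverse = add_reverse_edges(toy_edges)
print(with_reverse)
```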

**Add node features**

In [9]:
g, node_features = add_node_features(g)

**Construct negative graph**

In [10]:
g_neg = construct_negative_graph(g, etype)
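
`construct_negative_graph` comes from `src.utils` and its sampling strategy is not shown here. A common approach for link prediction, sketched below as an assumption rather than as the notebook's actual implementation, keeps each positive edge's source node and pairs it with uniformly sampled destination nodes. The function name `sample_negative_edges` and the parameters `k` and `seed` are hypothetical.

```python
import random

def sample_negative_edges(src, num_dst_nodes, k=1, seed=0):
    """For each positive source node, draw k random destination nodes.

    Sampling is uniform, so a drawn pair may occasionally coincide with a
    true edge; this simple sketch does not filter those out.
    """
    rng = random.Random(seed)
    neg_src, neg_dst = [], []
    for u in src:
        for _ in range(k):
            neg_src.append(u)
            neg_dst.append(rng.randrange(num_dst_nodes))
    return neg_src, neg_dst

# Three positive 'treats' edges with sources 0, 1, 2 and five candidate diseases.
pos_src = [0, 1, 2]
neg_src, neg_dst = sample_negative_edges(pos_src, num_dst_nodes=5, k=2)
print(list(zip(neg_src, neg_dst)))
```

The model is then trained to score the real edges of `etype` above these corrupted ones.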

**Save subgraph**

In [11]:
save_graphs('Output/Explainability/Alzheimer/AlzheimerGraph', g_list=[g, g_neg])

# 2) Get Explainability

In [12]:
get_imp_node_dicts(g, etype, 'sal', keys, 500)

sal: 0
sal: 1
sal: 2
sal: 3
sal: 4
sal: 5
sal: 6
sal: 7
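
The two attribution methods named at the top of the notebook can be illustrated on a tiny model. This is a minimal NumPy sketch, not the code behind `get_imp_node_dicts`: the logistic model, its weights, the all-zero baseline and the midpoint Riemann sum are all assumptions chosen to keep the example self-contained. Saliency is the absolute gradient of the output with respect to the input; Integrated Gradients averages that gradient along a straight path from a baseline to the input and scales it by the input difference.

```python
import numpy as np

w = np.array([2.0, -1.0, 0.5])  # illustrative weights (assumption)

def f(x):
    """Toy differentiable model: logistic regression without bias."""
    return 1.0 / (1.0 + np.exp(-np.dot(w, x)))

def grad_f(x):
    """Analytic gradient of f with respect to x."""
    p = f(x)
    return p * (1.0 - p) * w

def saliency(x):
    """Saliency map: magnitude of the gradient at the input."""
    return np.abs(grad_f(x))

def integrated_gradients(x, baseline, n_steps=50):
    """Approximate the path integral of the gradient with a midpoint Riemann sum."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = np.stack([grad_f(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

x = np.array([1.0, 1.0, 1.0])
baseline = np.zeros_like(x)

sal = saliency(x)
ig = integrated_gradients(x, baseline)

print('saliency:', sal)
print('integrated gradients:', ig)
# Completeness check: attributions should sum to f(x) - f(baseline).
print('sum of IG:', ig.sum(), 'output difference:', f(x) - f(baseline))
```

Saliency only ranks features by local sensitivity, whereas Integrated Gradients is signed and its attributions add up to the change in the model output, which makes the scores easier to compare across nodes and edges.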
