# Comparison of Explainers : GNNExplainer, EdgeSHAPer, regSHAP

In [None]:
# Imports
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '../../EdgeSHAPer/src/')
sys.path.insert(0,'../../models/pcqm4m-v2_ogb/')
sys.path.insert(0,'../../utils/')
sys.path.insert(0,'../../../AMLD-2021-Graphs/src/')

import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from tkinter import *
matplotlib.use('TkAgg')

# PyTorch related
import torch
import torch.nn.functional as F
from torch_geometric.data.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.explain import Explainer,GNNExplainer,PGExplainer
from torch_geometric.explain.metric import fidelity
from gnn import GNN
from visualization import * # AMDL-2021-Graphs
from torch_geometric.utils import to_networkx
import networkx as nx

from utils import *

from regSHAPer import regSHAP
from edgeshaper import edgeshaper
from scipy.special import binom
from itertools import combinations

# Dataset-related
import ogb
from ogb.lsc import PCQM4Mv2Dataset, PygPCQM4Mv2Dataset
from ogb.utils import smiles2graph

# Chemistry related
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem

In [None]:
# ---------------------------------------------- #
# Dataset handles used throughout the notebook   #
# ---------------------------------------------- #
path_data = "../../data/"

# SMILES-only view of PCQM4Mv2 (only_smiles=True)
dataset_smiles = PCQM4Mv2Dataset(root=path_data, only_smiles=True)

# PyTorch Geometric view of the same dataset root
dataset_PyG = PygPCQM4Mv2Dataset(root=path_data)

# RDKit supplier reading the training-split SDF file located under path_data
suppl = Chem.SDMolSupplier(f"{path_data}pcqm4m-v2-train.sdf")

# Number of molecules in the train subset of pcqm4m-v2 dataset
n_train = 3_378_606

In [None]:
# =================== #
# Importing the model #
# =================== #
path_model = "../../models/pcqm4m-v2_ogb/"
# The checkpoint tensors are mapped onto the CPU regardless of where they were saved.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model_pt = torch.load(path_model+"model_trained.pt",map_location=torch.device('cpu'))

# Hyper-parameters forwarded to the GNN constructor; they have to match the
# architecture stored in the checkpoint for load_state_dict to succeed.
shared_params = {
        'num_layers': 5,
        'emb_dim': 600,
        'drop_ratio': 0,
        'graph_pooling': 'sum'
    }
# The previous expression read `args.device`, but no `args` object is defined
# anywhere in the visible notebook, so it raised NameError as soon as CUDA was
# available. The first CUDA device is used as the explicit default instead.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
# The model itself is deliberately kept on the CPU (as before); `device` is
# only reported above.
model = GNN(gnn_type = 'gcn', virtual_node = False, **shared_params).to('cpu')
model.load_state_dict(model_pt['model_state_dict'])
# NOTE(review): the model is left in training mode here; if GNN contains
# batch-norm or dropout layers, model.eval() may be needed -- confirm.
epoch = model_pt['epoch']
print(epoch)

In [None]:
# ------------------------------------------------------------ #
# Load the pre-computed explanation dictionaries per algorithm #
# ------------------------------------------------------------ #
path_script = "scripts/"

# Algorithm name -> explanation file, relative to path_script
algo_files = {
    "GNNExplainer": "gnnexplainer/model_5gr_gnnexpl_dict.pt",
    "EdgeSHAPer": "EdgeSHAPer/5gr_edgeSHAPexpl_dict.pt",
    "regSHAPer": "regSHAPer/diff_5gr_regSHAP_dict.pt",
}
algorithms = list(algo_files)

algo_dicts = {}        # algorithm name -> loaded explanation mapping
algo_mol_indices = {}  # algorithm name -> list of the keys of that mapping
for alg, rel_path in algo_files.items():
    # NOTE: torch.load unpickles the file; only load trusted files.
    algo_dicts[alg] = torch.load(path_script + rel_path)
    print(len(algo_dicts[alg]), "molecules for " + alg)
    algo_mol_indices[alg] = list(algo_dicts[alg])

In [None]:
# --------------------------------------------------------------- #
# GNNExplainer wrapped in the torch_geometric Explainer interface #
# --------------------------------------------------------------- #
# Only an edge mask is requested (node_mask_type=None); the wrapped model is
# declared as a graph-level regression model returning raw outputs.
gnn_explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=500),
    explanation_type="model",
    node_mask_type=None,
    edge_mask_type="object",
    model_config={
        "mode": 'regression',
        "task_level": 'graph',
        "return_type": 'raw',
    },
)

## Metrics and visualization

### Metrics

#### FID +/-

In [None]:
# Positions (0..4999) into the EdgeSHAPer key list; the same molecule keys are
# then used to look up the explanations of every algorithm.
idx_expl = np.arange(5000)
print("number of mols :", len(idx_expl))
edgeshaper_keys = algo_mol_indices["EdgeSHAPer"]
idx_mols = np.array([edgeshaper_keys[i] for i in idx_expl])
list_graphs = dataset_PyG[idx_mols]
positive = False
for alg in algorithms:
    # Absolute value of each stored explanation, one tensor per molecule
    explanations = [torch.tensor(algo_dicts[alg][i]).abs() for i in idx_mols]
    print(explanations[0])
    fid_result = compute_fid(model, list_graphs, positive, explanations, 5, None)
    print(fid_result)

#### GEF

In [None]:
# Creating list of graphs and explanations
# The molecule keys were previously taken from `alg`, a variable leaked by the
# `for alg in algorithms` loop of an earlier cell, so the result depended on
# which cell happened to run last. The reference algorithm is now explicit:
# algorithms[-1] is the value `alg` holds after that loop finishes.
ref_alg = algorithms[-1]
idx_mols = algo_mol_indices[ref_alg][:5000]
list_graphs = dataset_PyG[idx_mols]

# Explanations of the reference algorithm as tensors. Not passed to
# compute_gef below (which receives the full algo_dicts); kept available for
# interactive inspection.
expl = [torch.tensor(algo_dicts[ref_alg][i]) for i in idx_mols]
k = 5
dict_gef = compute_gef(idx_mols,list_graphs,algo_dicts,model,k,algorithms)

In [None]:
# Print entry [1] of each algorithm's compute_gef result, one line per algorithm
for name in ("GNNExplainer", "EdgeSHAPer", "regSHAPer"):
    print(f"{name}:", dict_gef[name][1])
# regSHAP is better for this metric. (the lower the better)

In [None]:
# Sanity check: mean of the targets of the selected graphs, next to the mean
# of column 1 of the GNNExplainer entry [0] of dict_gef
y = torch.tensor([graph.y for graph in list_graphs])
print(y.mean())
gnnexpl_values = dict_gef["GNNExplainer"][0]
print(np.mean(gnnexpl_values[:, 1]))

### Plots for the report

#### Graph vs Subgraph predictions

In [None]:
# Showing graph vs subgraph
# Plot colour assigned to each explanation algorithm
col_dict = {"GNNExplainer":"red",
            "EdgeSHAPer":"mediumblue",
            "regSHAPer":"orange"}
# Creating list of graphs and explanations
# The molecule keys were previously taken from `alg`, a loop variable leaked
# by an earlier cell, which made this cell depend on execution order.
# algorithms[-1] is the value `alg` holds after `for alg in algorithms`.
ref_alg = algorithms[-1]
idx_mols = algo_mol_indices[ref_alg][:5000]
list_graphs = dataset_PyG[idx_mols]

In [None]:
# Sweep over k: each pass recomputes dict_gef (overwriting the previous one)
# and hands it to graph_subgraph together with a k-specific output path.
for k in (3, 5, 7, 9):
    dict_gef = compute_gef(idx_mols, list_graphs, algo_dicts, model, k, algorithms)
    out_png = f"plots/pred_graph_subgraph/graph_subgraph_k_{k}.png"
    graph_subgraph(dict_gef, k, col_dict, (2, 14), (2, 14), out_png)

In [None]:
# Mean squared difference between columns 0 and 1 for the last k (verif).
# The earlier label said "MAE", but the expression squares the difference,
# so the printed value is a mean squared error, not a mean absolute error.
for alg in dict_gef:
    print(alg)
    values = dict_gef[alg][0]
    print(np.mean((values[:,0]-values[:,1])**2))

#### Threshold graph

In [None]:
# Threshold grid: start 0, stop 1.01 (exclusive), step 0.01
thresh = np.arange(0, 1.01, 0.01)
threshold_plot_path = "plots/threshold/threshold_plot.png"
compute_tresh(thresh, 5000, algo_dicts, col_dict, algorithms, threshold_plot_path)

#### Molecule visualization

In [None]:
# Two hand-picked molecule keys, visualised for every algorithm
idx_mols = [629596, 634384]

# The path prefix embeds the printed list of keys
viz_prefix = f"./plots/expl_viz/expl_mol{idx_mols}"
mol_viz(algorithms, idx_mols, (10, 10), dataset_PyG, algo_dicts, viz_prefix)

In [None]:
# Visualise four randomly drawn molecules, two per figure.
# The key list was previously looked up through `alg`, a loop variable leaked
# by an earlier cell; algorithms[-1] is the value it holds after
# `for alg in algorithms`, now referenced explicitly.
ref_alg = algorithms[-1]
ref_indices = algo_mol_indices[ref_alg]
# Draw positions among the first 5000 keys, without running past the end of
# the list when fewer than 5000 molecules are available.
n_candidates = min(5000, len(ref_indices))
idx_mols = [ref_indices[idx] for idx in np.random.randint(n_candidates,size=4)]
idx_mols = [[idx_mols[0],idx_mols[1]],[idx_mols[2],idx_mols[3]]]
for i in idx_mols:
    mol_viz(algorithms,i,(10,10),dataset_PyG,algo_dicts,"./plots/expl_viz/expl_mol"+str(i))

# Other snippets of code and/or old code

In [None]:

# Draw one random integer in [0, 5) and show it as the cell output
idx_random = np.random.randint(0, 5)
idx_random

In [None]:
# Show explanation -> original note: broken due to an index mismatch
# The lookup used the key "regSHAP", which does not exist: the dictionaries
# built from algo_files are keyed "GNNExplainer", "EdgeSHAPer" and
# "regSHAPer", so the cell failed with a KeyError before reaching the plot.
idx = algo_mol_indices["regSHAPer"][40]
print(idx)
show_explanation(algorithms,dataset_PyG,dataset_smiles,idx,algo_dicts,"jet")

In [None]:
# typical barplot instead
# The key must be "regSHAPer" (the name used in algo_files); the former
# "regSHAP" key is absent from algo_mol_indices and raised a KeyError.
idx = algo_mol_indices["regSHAPer"][40]
show_barplot_explanation(algorithms,dataset_PyG,idx,algo_dicts,"out.png")

In [None]:
# Verifying the efficiency Shapley property
# Builds a smaller graph from a subset of the edges of the first dataset
# graph, runs regSHAP on it, and prints the absolute sum of the returned
# values next to the model output on that same sub-graph.
data = dataset_PyG[0]
x = data.x
y = data.y
# Keep edge columns 4..15 plus the last two edge columns
e1 = data.edge_index[:,4:16]
e2 = data.edge_index[:,-2:]

edge_index = torch.cat((e1,e2),axis=1)
# Edge attributes for the same edges, in the same order as edge_index
new_edge_attr = torch.cat((data.edge_attr[4:16,:],data.edge_attr[-2:,:]),axis=0)
# Unique ids found in the first row of the kept edges
# (original note: selecting the nodes whose degree is > 0)
new_x_idx = np.unique(edge_index[0]) # selecting the nodes which degree is > 0
new_x = data.x[new_x_idx,:]
# presumably rank_arr relabels the node ids to consecutive indices so that
# they match the rows of new_x -- TODO confirm against utils.rank_arr
new_edge_index = torch.vstack((rank_arr(edge_index[0],False),rank_arr(edge_index[1],False)))
new_graph = Data(x=new_x,edge_index=new_edge_index,edge_attr=new_edge_attr,y=y)
# NOTE(review): 64 and 42 look like a sample count and a seed; the meaning of
# the two trailing False flags is not visible here -- confirm in regSHAPer
test_regSHAP = regSHAP(new_graph,64,model,42,False,False)
# All-zero batch vector: every node of new_graph is assigned to graph 0
batch = torch.zeros(new_graph.x.shape[0], dtype=int, device=new_graph.x.device)
print(abs(sum(test_regSHAP)))
print(model(new_graph.x,new_graph.edge_index,new_graph.edge_attr,batch))

In [None]:
# Computing the mean, min and max explanation length (one entry per edge,
# according to the original note) over up to 5000 molecules.
# Fixes relative to the earlier version of this cell:
#   * the explanations were looked up under "regSHAP", a key that does not
#     exist (the dictionaries are keyed "regSHAPer"), raising a KeyError;
#   * the molecule keys came from `alg`, a loop variable leaked by another
#     cell; they now come from the same algorithm as the explanations;
#   * the minimum was seeded with a hard-coded 50, which capped the result;
#   * the mean was divided by a hard-coded 5000 even when fewer molecules
#     were available.
ref_alg = "regSHAPer"
mol_idx = algo_mol_indices[ref_alg][:5000]
edge_counts = [len(algo_dicts[ref_alg][mol]) for mol in mol_idx]
if not edge_counts:
    raise ValueError("no molecules available for " + ref_alg)
c = sum(edge_counts)
c_min = min(edge_counts)
c_max = max(edge_counts)
print(c/len(edge_counts))
print(c_min)
print(c_max)