In [None]:
import torch

### Setup in Google Colab

Uncomment the code in the following cells to run this notebook in Google Colab.

In [None]:
# def format_pytorch_version(version):
#   return version.split('+')[0]
#
# TORCH_version = torch.__version__
# TORCH = format_pytorch_version(TORCH_version)
#
# def format_cuda_version(version):
#   return 'cu' + version.replace('.', '')
#
# CUDA_version = torch.version.cuda
# CUDA = "cpu"
#
# !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
# !pip install torch-geometric
#

In [None]:
# !git clone https://github.com/funket/zorro.git
#

In [None]:
# !pwd
#

In [None]:
# %cd zorro/
#

In [None]:
# !pwd

In [None]:
from explainer import *
from models import *
import torch
import matplotlib.pylab as plt

# Data loading and GNN training

In [None]:
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# `results_path` is returned by load_dataset but not used further in this notebook.
dataset, data, results_path = load_dataset("Cora")

model = GCNNet(dataset)
# Return value ignored: presumably GCNNet is a torch.nn.Module, whose .to()
# moves parameters in place -- confirm.
model.to(device)
data = data.to(device)

In [None]:
# Train the GNN on `data` before any explanation is computed. train_model is
# provided by the star imports above. NOTE(review): presumably it updates
# `model` in place; any non-None return value is shown as this cell's output.
train_model(model, data)

# Gradient-based explanation

In [None]:
from gnn_explainer import GNNExplainer

# GNNExplainer class needed for retrieval of computational graph
gnn_explainer = GNNExplainer(model, log=False)

explain_node = 0

In [None]:
def execute_model_with_gradient(model, node, x, edge_index):
    """Run a forward pass and backpropagate the loss of the predicted class.

    The target is the model's own prediction for ``node`` (its argmax class),
    and the loss is the negative log-probability of that class. Calling
    ``backward()`` fills ``.grad`` on every leaf tensor that requires grad and
    that the loss depends on (e.g. model parameters, or masks the caller
    multiplied into ``x``); callers read gradient-based importances from there.

    Args:
        model: callable invoked as ``model(x, edge_index)``. Its output is
            indexed as ``[node, :]``, so it is presumably a
            (num_nodes, num_classes) score matrix -- TODO confirm.
        node: index of the node to explain. Either an int or, as passed by
            ``get_grad_node_explanation``, the ``mapping`` returned by
            ``__subgraph__`` (presumably a one-element index tensor).
        x: node feature matrix fed to the model.
        edge_index: graph connectivity fed to the model.

    Returns:
        The loss tensor on which ``backward()`` was called.
    """
    ypred = model(x, edge_index)

    # argmax is not differentiable, so the target class carries no gradient.
    predicted_label = ypred.argmax(dim=-1)[node]
    # log_softmax is the stable, idiomatic form of log(softmax(.)); the
    # values are log-probabilities, not logits.
    log_probs = torch.nn.functional.log_softmax(ypred[node, :].squeeze(), dim=0)

    loss = -log_probs[predicted_label]
    loss.backward()
    return loss

In [None]:
def get_grad_node_explanation(model, node, data, explainer=None):
    """Compute gradient-based feature and node importances for ``node``.

    The computational subgraph's feature matrix is multiplied elementwise by
    ``outer(node_weights, feature_weights)``, where both weight vectors are all
    ones, so the forward pass itself is unchanged. The absolute gradient of the
    prediction loss with respect to those weights is the importance score.

    Args:
        model: the GNN to explain.
        node: index of the node to explain in the full graph.
        data: graph object providing ``x`` and ``edge_index``.
        explainer: object whose ``__subgraph__`` extracts the computational
            graph. Defaults to the notebook-global ``gnn_explainer``.

    Returns:
        ``(feature_mask, node_mask)`` as NumPy arrays with one entry per
        feature and one entry per subgraph node, respectively.
    """
    if explainer is None:
        explainer = gnn_explainer

    # From here on, work only on the node's computational subgraph; the last
    # two return values of __subgraph__ are not needed.
    x, edge_index, mapping, _hard_edge_mask, _kwargs = \
        explainer.__subgraph__(node, data.x, data.edge_index)

    # All-ones leaf tensors: scaling by them leaves x unchanged but makes the
    # loss differentiable per node (rows) and per feature (columns).
    num_nodes, num_features = x.size()
    node_weights = torch.ones(num_nodes, device=x.device, requires_grad=True)
    feature_weights = torch.ones(num_features, device=x.device, requires_grad=True)
    # Outer product: mask[i, j] = node_weights[i] * feature_weights[j].
    mask = node_weights.unsqueeze(1) * feature_weights.unsqueeze(0)

    # NOTE(review): if the model uses dropout and is still in training mode
    # after training, these gradients are stochastic; consider model.eval()
    # first -- confirm how GCNNet/train_model leave the model.
    model.zero_grad()
    execute_model_with_gradient(model, mapping, mask * x, edge_index)

    node_mask = node_weights.grad.detach().abs().cpu().numpy()
    feature_mask = feature_weights.grad.detach().abs().cpu().numpy()
    return feature_mask, node_mask

In [None]:
# (feature_mask, node_mask) for the node under explanation.
grad_explanation = get_grad_node_explanation(model=model, node=explain_node, data=data)

In [None]:
# Histogram of gradient feature importances. The y-axis is logarithmic so that
# rare large values stay visible next to the bulk of the distribution.
fig, ax = plt.subplots()
ax.hist(grad_explanation[0])
ax.set(title="Gradient: distribution of feature mask",
       xlabel="|d loss / d feature weight|", ylabel="count")
ax.set_yscale("log")

##### Possible task: implement GradInput (gradient × input)

# GNNExplainer

In [None]:
def get_gnn_explainer(node, data):
    """Explain ``node`` with the notebook-global GNNExplainer instance.

    Returns:
        A 2-tuple ``(feature_mask, edge_mask)`` as unpacked from
        ``gnn_explainer.explain_node``. Note that the second element is an
        edge mask, not a node mask as for the other explainers here.
    """
    features, connectivity = data.x, data.edge_index
    feature_mask, edge_mask = gnn_explainer.explain_node(node, features, connectivity)
    return (feature_mask, edge_mask)

In [None]:
# (feature_mask, edge_mask) from GNNExplainer for the node under explanation.
gnn_explanation = get_gnn_explainer(node=explain_node, data=data)

In [None]:
# The GNNExplainer mask may come back as a torch tensor on `device` (possibly
# CUDA) -- unverified here; matplotlib needs a host array, so convert defensively.
gnn_feature_mask = gnn_explanation[0]
if torch.is_tensor(gnn_feature_mask):
    gnn_feature_mask = gnn_feature_mask.detach().cpu().numpy()

fig, ax = plt.subplots()
ax.hist(gnn_feature_mask)
ax.set(title="GNNExplainer: distribution of feature mask",
       xlabel="feature mask value", ylabel="count")
ax.set_yscale("log")

# Zorro

In [None]:
from explainer import Zorro

zorro = Zorro(model, device)


def get_zorro(node):
    """Return Zorro's first explanation for ``node`` as ``(features, nodes)``.

    ``recursion_depth=1`` requests a single explanation. The first result is
    unpacked as ``(selected_nodes, selected_features, executed_selections)``,
    and the first entry of the feature and node selections is returned. The
    feature selection comes first, matching the ``(feature_mask, node_mask)``
    order of the other explainers here. Uses the notebook-global ``zorro`` and
    ``data``.
    """
    # tau = .03 is meant to match the paper's 0.98 setting. NOTE(review):
    # presumably tau is a tolerated fidelity loss, but 1 - 0.03 = 0.97 -- confirm.
    results = zorro.explain_node(node, data.x, data.edge_index,
                                 tau=.03, recursion_depth=1)
    selected_nodes, selected_features, _ = results[0]
    return selected_features[0], selected_nodes[0]

In [None]:
# (selected features, selected nodes) of Zorro's first explanation.
zorro_explanation = get_zorro(node=explain_node)

In [None]:
# The type of Zorro's selection is not visible here; if it is a torch tensor
# (possibly on CUDA), matplotlib needs a host array, so convert defensively.
zorro_feature_mask = zorro_explanation[0]
if torch.is_tensor(zorro_feature_mask):
    zorro_feature_mask = zorro_feature_mask.detach().cpu().numpy()

fig, ax = plt.subplots()
ax.hist(zorro_feature_mask)
ax.set(title="Zorro: distribution of feature mask",
       xlabel="feature mask value", ylabel="count")
ax.set_yscale("log")

# SoftZorro

In [None]:
from explainer import SoftZorro

soft_zorro = SoftZorro(model, device)


def get_soft_zorro(node):
    """Return SoftZorro's explanation for ``node`` as ``(features, nodes)``.

    The result of ``soft_zorro.explain_node`` is unpacked as
    ``(node_mask, feature_mask)``. The first entry of each is returned in the
    swapped ``(feature, node)`` order used by the other explainers here. Uses
    the notebook-global ``soft_zorro`` and ``data``.
    """
    masks = soft_zorro.explain_node(node, data.x, data.edge_index)
    node_mask, feature_mask = masks
    return feature_mask[0], node_mask[0]

In [None]:
# (feature mask, node mask) from SoftZorro for the node under explanation.
soft_zorro_explanation = get_soft_zorro(node=explain_node)

In [None]:
# The type of SoftZorro's mask is not visible here; if it is a torch tensor
# (possibly on CUDA), matplotlib needs a host array, so convert defensively.
soft_zorro_feature_mask = soft_zorro_explanation[0]
if torch.is_tensor(soft_zorro_feature_mask):
    soft_zorro_feature_mask = soft_zorro_feature_mask.detach().cpu().numpy()

fig, ax = plt.subplots()
ax.hist(soft_zorro_feature_mask)
ax.set(title="SoftZorro: distribution of feature mask",
       xlabel="feature mask value", ylabel="count")
ax.set_yscale("log")