# GNNExplainer on BA-Shapes dataset for 2-layer GIN

In [1]:
from dig.xgraph.dataset import SynGraphDataset
from dig.xgraph.models import *
import torch
from torch_geometric.data import DataLoader
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
import os.path as osp
import os

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Load dataset

In [2]:
def index_to_mask(index, size):
    mask = torch.zeros(size, dtype=torch.bool, device=index.device)
    mask[index] = 1
    return mask

def split_dataset(dataset):
    indices = []
    num_classes = 4
    train_percent = 0.7
    for i in range(num_classes):
        index = (dataset.data.y == i).nonzero().view(-1)
        index = index[torch.randperm(index.size(0))]
        indices.append(index)

    train_index = torch.cat([i[:int(len(i) * train_percent)] for i in indices], dim=0)

    rest_index = torch.cat([i[int(len(i) * train_percent):] for i in indices], dim=0)
    rest_index = rest_index[torch.randperm(rest_index.size(0))]

    dataset.data.train_mask = index_to_mask(train_index, size=dataset.data.num_nodes)
    dataset.data.val_mask = index_to_mask(rest_index[:len(rest_index) // 2], size=dataset.data.num_nodes)
    dataset.data.test_mask = index_to_mask(rest_index[len(rest_index) // 2:], size=dataset.data.num_nodes)

    dataset.data, dataset.slices = dataset.collate([dataset.data])

    return dataset

dataset = SynGraphDataset('./datasets', 'BA_shapes')
dataset.data.x = dataset.data.x.to(torch.float32)
dataset.data.x = dataset.data.x[:, :1]
# dataset.data.y = dataset.data.y[:, 2]
dim_node = dataset.num_node_features
dim_edge = dataset.num_edge_features
# num_targets = dataset.num_classes
num_classes = dataset.num_classes

splitted_dataset = split_dataset(dataset)
splitted_dataset.data.mask = splitted_dataset.data.test_mask
splitted_dataset.slices['mask'] = splitted_dataset.slices['test_mask']
dataloader = DataLoader(splitted_dataset, batch_size=1, shuffle=False)

### Load model and checkpoints

In [3]:
def check_checkpoints(root='./'):
    if osp.exists(osp.join(root, 'checkpoints')):
        return
    url = ('https://github.com/divelab/DIG_storage/raw/main/xgraph/checkpoints.zip')
    path = download_url(url, root)
    extract_zip(path, root)
    os.unlink(path)

model = GIN_2l(model_level='node', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
check_checkpoints()
ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GIN_2l', '0', 'GIN_2l_best.ckpt')
model.load_state_dict(torch.load(ckpt_path)['state_dict'])

<All keys matched successfully>

### Display example output

In [4]:
data = list(dataloader)[0].to(device)
out = model(data.x, data.edge_index)
print(out)

tensor([[  59.2552,    3.2258,   18.8490, -145.8462],
        [ 108.6987,  -13.9900,   67.3546, -267.5370],
        [  37.3232,   -1.5697,   13.8692,  -96.5077],
        ...,
        [  -6.7269,   -1.9655,   -4.3532,   10.0127],
        [  -5.7191,   -0.6677,   -5.7571,    8.5893],
        [  -5.4793,    3.3847,   -1.2214,    3.0312]], device='cuda:0',
       grad_fn=<AddmmBackward>)


### Load the explainer

In [5]:
from dig.xgraph.method import GNNExplainer
explainer = GNNExplainer(model, epochs=100, lr=0.01, explain_graph=False)

### Setup for evaluation

In [6]:
# --- Set the Sparsity to 0.5 ---
sparsity = 0.5

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
x_collector = XCollector(sparsity)
# x_processor = ExplanationProcessor(model=model, device=device)

### Run explainer on the given model and dataset

In [7]:
index = -1
for i, data in enumerate(dataloader):
    for j, node_idx in enumerate(torch.where(data.mask == True)[0].tolist()):
        index += 1
        print(f'explain graph {i} node {node_idx}')
        data.to(device)

        if torch.isnan(data.y[0].squeeze()):
            continue
    
        walks, masks, related_preds = \
            explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes, node_idx=node_idx)
    
        x_collector.collect_data(masks, related_preds, data.y[0].squeeze().long().item())
        
        # if you only have the edge masks without related_pred, please feed sparsity controlled mask to
        # obtain the result: x_processor(data, masks, x_collector)
        if index >= 99:
            break

    if index >= 99:
        break

explain graph 0 node 17
explain graph 0 node 20
explain graph 0 node 30
explain graph 0 node 31
explain graph 0 node 48
explain graph 0 node 57
explain graph 0 node 71
explain graph 0 node 74
explain graph 0 node 75
explain graph 0 node 95
explain graph 0 node 96
explain graph 0 node 99
explain graph 0 node 105
explain graph 0 node 109
explain graph 0 node 115
explain graph 0 node 120
explain graph 0 node 126
explain graph 0 node 136
explain graph 0 node 137
explain graph 0 node 138
explain graph 0 node 148
explain graph 0 node 159
explain graph 0 node 171
explain graph 0 node 172
explain graph 0 node 180
explain graph 0 node 187
explain graph 0 node 189
explain graph 0 node 191
explain graph 0 node 192
explain graph 0 node 195
explain graph 0 node 203
explain graph 0 node 209
explain graph 0 node 216
explain graph 0 node 219
explain graph 0 node 220
explain graph 0 node 222
explain graph 0 node 224
explain graph 0 node 230
explain graph 0 node 249
explain graph 0 node 256
explain grap

### Output metrics evaluation results

In [8]:
print(f'Fidelity: {x_collector.fidelity:.4f}\n'
      f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n'
      f'Sparsity: {x_collector.sparsity:.4f}')

Fidelity: 0.4989
Fidelity_inv: -0.0150
Sparsity: 0.5000


For more details, please refer to [https://github.com/divelab/DIG/tree/dig/benchmarks/xgraph/supp/](https://github.com/divelab/DIG/tree/dig/benchmarks/xgraph/supp/)