# GNNExplainer on BA-Shapes dataset for 2-layer GIN

In [1]:
from dig.xgraph.dataset import SynGraphDataset
import torch
from torch.utils.data import random_split
from torch_geometric.data import DataLoader
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from torch_geometric.data.dataset import files_exist
import os.path as osp
import os

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Load dataset

In [2]:
def split_dataset(dataset, dataset_split=[0.8, 0.1, 0.1]):
    dataset_len = len(dataset)
    dataset_split = [int(dataset_len * dataset_split[0]),
                     int(dataset_len * dataset_split[1]),
                     0]
    dataset_split[2] = dataset_len - dataset_split[0] - dataset_split[1]
    train_set, val_set, test_set = \
        random_split(dataset, dataset_split)

    return {'train': train_set, 'val': val_set, 'test': test_set}

dataset = SynGraphDataset('datasets', 'ba_shapes')
dataset.data.x = dataset.data.x.to(torch.float32)
dataset.data.y = dataset.data.y[:, 0] 
dim_node = dataset.num_node_features
dim_edge = dataset.num_edge_features
num_targets = dataset.num_classes
num_classes = 2

splitted_dataset = split_dataset(dataset)
dataloader = DataLoader(splitted_dataset['test'], batch_size=1, shuffle=False)

### Load model and checkpoints

In [3]:
from dig.xgraph.models import GIN_2l

def check_checkpoints(root='./'):
    if osp.exists(osp.join(root, 'checkpoints')):
        return
    url = ('https://github.com/divelab/DIG_storage/raw/main/xgraph/checkpoints.zip')
    path = download_url(url, root)
    extract_zip(path, root)
    os.unlink(path)

model = GIN_2l(model_level='graph', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
check_checkpoints()
ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GIN_2l', '0', 'GIN_2l_best.ckpt')
model.load_state_dict(torch.load(ckpt_path)['state_dict'])

<All keys matched successfully>

### Display example output

In [4]:
data = list(dataloader)[0].to(device)
out = model(data.x, data.edge_index)
print(out)

tensor([[-2.3228,  2.3888]], device='cuda:0', grad_fn=<AddmmBackward>)


### Load the explainer

In [5]:
from dig.xgraph.method import GNNExplainer
explainer = GNNExplainer(model, epochs=100, lr=0.01, explain_graph=False)

### Setup for evaluation

In [6]:
# --- Set the Sparsity to 0.5 ---
sparsity = 0.5

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
x_collector = XCollector(sparsity)
# x_processor = ExplanationProcessor(model=model, device=device)

### Run explainer on the given model and dataset

In [7]:
index = -1
for i, data in enumerate(dataloader):
    for j, node_idx in enumerate(torch.where(data.mask == True)[0].tolist()):
        index += 1
        print(f'explain graph {i} node {node_idx}')
        data.to(device)

        if torch.isnan(data.y[0].squeeze()):
            continue
    
        walks, masks, related_preds = \
            explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes, node_idx=node_idx)
    
        x_collector.collect_data(masks, related_preds, data.y[0].squeeze().long().item())
        
        # if you only have the edge masks without related_pred, please feed sparsity controlled mask to
        # obtain the result: x_processor(data, masks, x_collector)
        if index >= 99:
            break

    if index >= 99:
        break

explain graph line 359
explain graph line 1002
explain graph line 823
explain graph line 1285
explain graph line 315
explain graph line 1280
explain graph line 228
explain graph line 471
explain graph line 1143
explain graph line 845
explain graph line 890
explain graph line 570
explain graph line 37
explain graph line 1090
explain graph line 911
explain graph line 1237
explain graph line 1120
explain graph line 120
explain graph line 1206
explain graph line 371
explain graph line 489
explain graph line 1267
explain graph line 1343
explain graph line 384
explain graph line 1445
explain graph line 416
explain graph line 443
explain graph line 967
explain graph line 1196
explain graph line 403
explain graph line 1103
explain graph line 483
explain graph line 1055
explain graph line 838
explain graph line 914
explain graph line 34
explain graph line 1035
explain graph line 977
explain graph line 217
explain graph line 511
explain graph line 948
explain graph line 1330
explain graph line 2

### Output metrics evaluation results

In [8]:
print(f'Fidelity: {x_collector.fidelity:.4f}\n'
      f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n'
      f'Sparsity: {x_collector.sparsity:.4f}')

Fidelity: 0.5038
Fidelity_inv: 0.0988
Sparsity: 0.5000


For more details, please refer to [https://github.com/divelab/DIG/tree/dig/benchmarks/xgraph/supp/](https://github.com/divelab/DIG/tree/dig/benchmarks/xgraph/supp/)