# GNN-LRP on BA-LRP dataset for GCN

In [1]:
import os
import os.path as osp

import torch
from torch.utils.data import random_split
from torch_geometric.data import download_url, extract_zip
from torch_geometric.loader import DataLoader

from dig.xgraph.dataset import BA_LRP
from dig.xgraph.models import GCN_3l
from dig.xgraph.utils.compatibility import compatible_state_dict
from dig.xgraph.utils.init import fix_random_seed

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


### Load dataset

In [2]:
def split_dataset(dataset, dataset_split=(0.8, 0.1, 0.1)):
    dataset_len = len(dataset)
    dataset_split = [int(dataset_len * dataset_split[0]),
                     int(dataset_len * dataset_split[1]),
                     0]
    dataset_split[2] = dataset_len - dataset_split[0] - dataset_split[1]
    train_set, val_set, test_set = random_split(dataset, dataset_split)

    return {'train': train_set, 'val': val_set, 'test': test_set}


fix_random_seed(123)
dataset = BA_LRP('datasets')

dataset.data.x = dataset.data.x.to(torch.float32)
dataset.data.y = dataset.data.y[:, 0]
dim_node = dataset.num_node_features
dim_edge = dataset.num_edge_features
num_targets = dataset.num_classes
num_classes = 2

splitted_dataset = split_dataset(dataset)
dataloader = DataLoader(splitted_dataset['test'], batch_size=1, shuffle=False)

### Load model and checkpoints

In [3]:
def check_checkpoints(root='./'):
    if osp.exists(osp.join(root, 'checkpoints')):
        return
    url = ('https://github.com/divelab/DIG_storage/raw/main/xgraph/checkpoints.zip')
    path = download_url(url, root)
    extract_zip(path, root)
    os.unlink(path)


model = GCN_3l(model_level='graph', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
check_checkpoints()
ckpt_path = osp.join('checkpoints', 'ba_lrp', 'GCN_3l', '0', 'GCN_3l_best.ckpt')
state_dict = compatible_state_dict(torch.load(ckpt_path, map_location='cpu')['state_dict'])
model.load_state_dict(state_dict)

<All keys matched successfully>

### Display example output

In [4]:
data = list(dataloader)[0].to(device)
out = model(data.x, data.edge_index)
print(out)

tensor([[-2.5759,  2.6295]], device='cuda:0', grad_fn=<AddmmBackward0>)


### Load the explainer

In [5]:
from dig.xgraph.method import GNN_LRP

explainer = GNN_LRP(model, explain_graph=True)

### Setup for evaluation

In [6]:
# --- Set the Sparsity to 0. ---
sparsity = 0.5

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector

x_collector = XCollector(sparsity)
# x_processor = ExplanationProcessor(model=model, device=device)

### Run explainer on the given model and dataset

In [7]:
for index, data in enumerate(dataloader):
    print(f'explain graph line {dataloader.dataset.indices[index] + 2}')
    data.to(device)

    if torch.isnan(data.y[0].squeeze()):
        continue

    walks, masks, related_preds = explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes)

    x_collector.collect_data(masks, related_preds, data.y[0].squeeze().long().item())

    # if you only have the edge masks without related_pred, please feed sparsity controlled mask to
    # obtain the result: x_processor(data, masks, x_collector)

    if index >= 99:
        break

explain graph line 4835
explain graph line 3443
explain graph line 14435
explain graph line 6575
explain graph line 1521
explain graph line 2454
explain graph line 2551
explain graph line 15026
explain graph line 14862
explain graph line 18955
explain graph line 16517
explain graph line 4866
explain graph line 17951
explain graph line 7567
explain graph line 6920
explain graph line 14567
explain graph line 2847
explain graph line 11472
explain graph line 13305
explain graph line 6961
explain graph line 18725
explain graph line 13531
explain graph line 12320
explain graph line 2744
explain graph line 19356
explain graph line 14045
explain graph line 17569
explain graph line 5872
explain graph line 11646
explain graph line 4492
explain graph line 17075
explain graph line 9038
explain graph line 17290
explain graph line 64
explain graph line 974
explain graph line 17241
explain graph line 693
explain graph line 779
explain graph line 7396
explain graph line 18414
explain graph line 3176
e

### Output metrics evaluation results

In [8]:
print(f'Fidelity: {x_collector.fidelity:.4f}\n'
      f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n'
      f'Sparsity: {x_collector.sparsity:.4f}')

Fidelity: 0.5191
Fidelity_inv: 0.1258
Sparsity: 0.5000
