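# GNN-LRP on the BA-LRP dataset

This notebook explains the predictions of a pretrained `GCN_3l` graph classifier on the BA-LRP dataset with the GNN-LRP explainer from DIG, then reports Fidelity, Fidelity_inv and Sparsity over up to 100 graphs from the test split.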

In [1]:
from dig.xgraph.dataset import BA_LRP
from dig.xgraph.models import GCN_3l
import torch
from torch.utils.data import random_split
from torch_geometric.data import DataLoader
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from torch_geometric.data.dataset import files_exist
import os.path as osp
import os

# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Load dataset

In [2]:
def split_dataset(dataset, dataset_split=(0.8, 0.1, 0.1), seed=None):
    """Randomly partition ``dataset`` into train / validation / test subsets.

    Args:
        dataset: Any object accepted by ``torch.utils.data.random_split``
            (it must support ``len()``).
        dataset_split: Sequence of three fractions ``(train, val, test)``.
            Only the first two are used to size their subsets (rounded down);
            the test subset receives every remaining sample, so the three
            sizes always sum to ``len(dataset)``.
        seed: Optional integer. When given, the permutation is drawn from a
            dedicated ``torch.Generator`` seeded with this value, which makes
            the split reproducible across runs. The default ``None`` draws
            from torch's global random state.

    Returns:
        dict with the keys ``'train'``, ``'val'`` and ``'test'``, each mapping
        to the corresponding subset returned by ``random_split``.
    """
    total = len(dataset)
    n_train = int(total * dataset_split[0])
    n_val = int(total * dataset_split[1])
    # The remainder goes to the test subset so that no sample is lost to rounding.
    lengths = [n_train, n_val, total - n_train - n_val]

    if seed is None:
        parts = random_split(dataset, lengths)
    else:
        rng = torch.Generator().manual_seed(seed)
        parts = random_split(dataset, lengths, generator=rng)

    train_set, val_set, test_set = parts
    return {'train': train_set, 'val': val_set, 'test': test_set}

# Load BA_LRP; 'datasets' is presumably the dataset root directory — confirm against the DIG docs.
dataset = BA_LRP('datasets')
# Cast the node features to float32.
dataset.data.x = dataset.data.x.to(torch.float32)
# Keep only the first column of the labels.
# NOTE(review): this assumes y is 2-D with the class label in column 0 — confirm.
dataset.data.y = dataset.data.y[:, 0]
dim_node = dataset.num_node_features
# dim_edge and num_targets are computed here but not used by any visible cell.
dim_edge = dataset.num_edge_features
num_targets = dataset.num_classes
# Hard-coded class count used when building the model and when calling the explainer.
num_classes = 2

# Only the 'test' subset is fed to the loader; batch_size=1 with shuffle=False
# yields one graph per batch in a fixed order.
splitted_dataset = split_dataset(dataset)
dataloader = DataLoader(splitted_dataset['test'], batch_size=1, shuffle=False)

### Load model and checkpoints

In [3]:
def check_checkpoints(root='./'):
    """Make sure the pretrained checkpoints are available under ``root``.

    Does nothing when ``<root>/checkpoints`` already exists. Otherwise the
    checkpoint archive is downloaded into ``root``, extracted there, and the
    downloaded archive file is deleted afterwards.

    Args:
        root: Directory that holds (or will hold) the ``checkpoints`` folder.
    """
    checkpoint_dir = osp.join(root, 'checkpoints')
    if not osp.exists(checkpoint_dir):
        archive_url = 'https://github.com/divelab/DIG_storage/raw/main/xgraph/checkpoints.zip'
        archive_path = download_url(archive_url, root)
        extract_zip(archive_path, root)
        # The archive is no longer needed once its contents are unpacked.
        os.remove(archive_path)

# Build the graph-level GCN_3l classifier.
# NOTE(review): dim_hidden=300 presumably matches the architecture stored in the checkpoint — confirm.
model = GCN_3l(model_level='graph', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
check_checkpoints()
ckpt_path = osp.join('checkpoints', 'ba_lrp', 'GCN_3l', '0', 'GCN_3l_best.ckpt')
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# on a GPU still loads when this notebook falls back to the CPU.
# NOTE(security): torch.load unpickles the file; the checkpoint is downloaded
# from a remote URL, so only run this with a source you trust.
model.load_state_dict(torch.load(ckpt_path, map_location=device)['state_dict'])

<All keys matched successfully>

### Display example output

In [4]:
# Fetch only the first batch; the loader is not shuffled, so this is the same
# graph that materialising the whole loader and taking element 0 would give.
data = next(iter(dataloader)).to(device)
out = model(data.x, data.edge_index)
print(out)

tensor([[ 3.4417, -3.4062]], device='cuda:0', grad_fn=<AddmmBackward>)


### Load the explainer

In [5]:
from dig.xgraph.method import GNN_LRP
# Wrap the loaded model in the GNN-LRP explainer.
# NOTE(review): explain_graph=True presumably requests graph-level explanations,
# in line with the model being built with model_level='graph' — confirm against the DIG docs.
explainer = GNN_LRP(model, explain_graph=True)

### Setup for evaluation

In [6]:
# --- Set the Sparsity to 0.5 ---
# The same value is passed to the explainer call and to the XCollector below.
sparsity = 0.5

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
# Receives (masks, related_preds, label) for every explained graph; its
# fidelity, fidelity_inv and sparsity attributes are read in the final cell.
x_collector = XCollector(sparsity)
# ExplanationProcessor is only needed when an explainer yields edge masks
# without related_preds (see the note inside the explanation loop), so it is
# left disabled here.
# x_processor = ExplanationProcessor(model=model, device=device)

### Run explainer on the given model and dataset

In [7]:
# Upper bound on how many test graphs are visited by the loop below.
max_graphs = 100

# Pre-set so that `index` is defined even when the dataloader yields nothing.
index = -1
# Pairing with range(max_graphs) enforces the bound in the loop header. The
# previous end-of-body `if index >= 99: break` was bypassed by `continue`, so a
# NaN-labelled graph at the last position let one extra graph be processed.
for index, data in zip(range(max_graphs), dataloader):
    print(f'explain graph line {dataloader.dataset.indices[index] + 2}')
    data = data.to(device)

    # Graphs whose label is NaN are skipped and not collected.
    if torch.isnan(data.y[0].squeeze()):
        continue

    walks, masks, related_preds = \
        explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes)

    x_collector.collect_data(masks, related_preds, data.y[0].squeeze().long().item())

    # if you only have the edge masks without related_pred, please feed sparsity controlled mask to
    # obtain the result: x_processor(data, masks, x_collector)

explain graph line 10982
explain graph line 10544
explain graph line 8199
explain graph line 5612
explain graph line 16665
explain graph line 2517
explain graph line 18422
explain graph line 8885
explain graph line 16072
explain graph line 9213
explain graph line 12727
explain graph line 13930
explain graph line 981
explain graph line 15235
explain graph line 9196
explain graph line 11909
explain graph line 4167
explain graph line 11524
explain graph line 3282
explain graph line 4993
explain graph line 9617
explain graph line 7064
explain graph line 10585
explain graph line 18522
explain graph line 7461
explain graph line 14086
explain graph line 3027
explain graph line 16680
explain graph line 14267
explain graph line 4300
explain graph line 372
explain graph line 9251
explain graph line 8630
explain graph line 5466
explain graph line 6546
explain graph line 1637
explain graph line 9295
explain graph line 10958
explain graph line 3233
explain graph line 10715
explain graph line 1884
e

### Output metrics evaluation results

In [8]:
# Report the metrics accumulated by the collector, one per line; sep='' keeps
# the pieces from being joined with extra spaces.
print(
    f'Fidelity: {x_collector.fidelity:.4f}\n',
    f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n',
    f'Sparsity: {x_collector.sparsity:.4f}',
    sep='',
)

Fidelity: 0.4786
Fidelity_inv: 0.1102
Sparsity: 0.5000


For more details, please refer to [https://github.com/divelab/DIG/tree/dig/benchmarks/xgraph/supp/](https://github.com/divelab/DIG/tree/dig/benchmarks/xgraph/supp/)