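# GradCAM explanations for a GIN model on the ClinTox dataset

This notebook loads a pretrained 3-layer GIN (`GIN_3l`) graph classifier for ClinTox, explains graphs from the test split with GradCAM, and reports Fidelity, Fidelity_inv and Sparsity at a sparsity of 0.5.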

In [1]:
import os
import os.path as osp

import torch
from torch.utils.data import random_split
from torch_geometric.data import download_url, extract_zip
from torch_geometric.loader import DataLoader

from dig.xgraph.dataset import MoleculeDataset
from dig.xgraph.utils.compatibility import compatible_state_dict
from dig.xgraph.utils.init import fix_random_seed

# Run on the first visible CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  from .autonotebook import tqdm as notebook_tqdm


### Load dataset

In [2]:
def split_dataset(dataset, dataset_split=(0.8, 0.1, 0.1), generator=None):
    """Randomly split ``dataset`` into train / validation / test subsets.

    Args:
        dataset: Any sized dataset accepted by ``torch.utils.data.random_split``.
        dataset_split: Fractions for the train and validation subsets. Only the
            first two entries are used: the test subset receives every remaining
            sample, so the three sizes always sum to ``len(dataset)``.
        generator: Optional ``torch.Generator`` forwarded to ``random_split`` so
            the split can be made reproducible independently of the global RNG.
            When ``None`` the global default generator is used (original behavior).

    Returns:
        dict with keys ``'train'``, ``'val'`` and ``'test'`` mapping to the
        ``Subset`` objects produced by ``random_split``.
    """
    dataset_len = len(dataset)
    # int() floors, so any rounding remainder is absorbed by the test subset.
    train_len = int(dataset_len * dataset_split[0])
    val_len = int(dataset_len * dataset_split[1])
    test_len = dataset_len - train_len - val_len
    lengths = [train_len, val_len, test_len]

    if generator is None:
        train_set, val_set, test_set = random_split(dataset, lengths)
    else:
        train_set, val_set, test_set = random_split(dataset, lengths, generator=generator)

    return {'train': train_set, 'val': val_set, 'test': test_set}


# Seed before splitting so the random train/val/test split below is reproducible
# (fix_random_seed is a DIG helper; presumably seeds Python/NumPy/torch — confirm).
fix_random_seed(123)
# ClinTox molecules stored under ./datasets (presumably downloaded/processed on first use).
dataset = MoleculeDataset('datasets', 'clintox')
# Cast node features to float32, torch's default floating dtype for model weights.
dataset.data.x = dataset.data.x.to(torch.float32)
# Keep only the first label column so every graph has a single target
# (assumes column 0 is the task the GIN_3l checkpoint was trained on — TODO confirm).
dataset.data.y = dataset.data.y[:, 0]
dim_node = dataset.num_node_features
# NOTE(review): dim_edge is not used later in this notebook (edge features are not fed to the model).
dim_edge = dataset.num_edge_features
# NOTE(review): num_targets is not used later; it is read after y was reduced to one column.
num_targets = dataset.num_classes
# The classifier head outputs two logits; must match the checkpoint loaded below.
num_classes = 2

splitted_dataset = split_dataset(dataset)
# One graph per batch (batch_size=1) in a fixed order (shuffle=False), so the
# loader position can be mapped back to the test subset's indices.
dataloader = DataLoader(splitted_dataset['test'], batch_size=1, shuffle=False)

### Load model and checkpoints

In [3]:
from dig.xgraph.models import GIN_3l


def check_checkpoints(root='./'):
    """Ensure the DIG pretrained checkpoints are available under ``root``.

    If ``<root>/checkpoints`` already exists nothing happens. Otherwise the
    checkpoint archive is downloaded into ``root``, extracted there, and the
    downloaded archive file is removed afterwards.
    """
    checkpoint_dir = osp.join(root, 'checkpoints')
    if not osp.exists(checkpoint_dir):
        archive_url = 'https://github.com/divelab/DIG_storage/raw/main/xgraph/checkpoints.zip'
        archive_path = download_url(archive_url, root)
        extract_zip(archive_path, root)
        # The archive is only needed for extraction; drop it to save disk space.
        os.unlink(archive_path)


# Build the 3-layer GIN graph classifier. The architecture arguments must match the
# checkpoint loaded below, otherwise load_state_dict (strict by default) raises.
model = GIN_3l(model_level='graph', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
# nn.Module starts in training mode; switch to evaluation mode so any dropout /
# batch-norm layers behave deterministically while the model is queried and explained.
model.eval()
check_checkpoints()
ckpt_path = osp.join('checkpoints', 'clintox', 'GIN_3l', '0', 'GIN_3l_best.ckpt')
# Load tensors onto the CPU first so the file opens regardless of the device it was
# saved from; load_state_dict then copies the values into the model's parameters.
# compatible_state_dict is a DIG helper (presumably remaps legacy key names — confirm).
state_dict = compatible_state_dict(torch.load(ckpt_path, map_location='cpu')['state_dict'])
model.load_state_dict(state_dict)

<All keys matched successfully>

### Display example output

In [4]:
# Sanity check: run the model on the first test graph.
# next(iter(...)) fetches only the first batch instead of collating the entire
# test loader into a list just to take element 0.
data = next(iter(dataloader)).to(device)
# Plain inference: no gradients are needed here.
with torch.no_grad():
    out = model(data.x, data.edge_index)
print(out)

tensor([[-1.0756,  1.0493]], device='cuda:3', grad_fn=<AddmmBackward0>)


### Load the explainer

In [5]:
from dig.xgraph.method import GradCAM

# Wrap the trained model in a GradCAM explainer; explain_graph=True requests
# graph-level explanations, matching the graph-level (model_level='graph') classifier.
explainer = GradCAM(model, explain_graph=True)

### Setup for evaluation

In [6]:
from dig.xgraph.evaluation import XCollector

# Target sparsity shared by the explainer call and the metric collector.
sparsity = 0.5

# Collects masks / related predictions per graph and exposes the fidelity metrics.
x_collector = XCollector(sparsity)
# Alternative when only raw edge masks (no related_preds) are available:
# x_processor = ExplanationProcessor(model=model, device=device)

### Run explainer on the given model and dataset

In [7]:
# Upper bound on how many test graphs are explained, to keep the run short.
max_explained_graphs = 100

for index, data in enumerate(dataloader):
    # Check the budget before anything else so that graphs skipped below
    # (NaN labels) can never push the loop past the intended limit.
    if index >= max_explained_graphs:
        break

    # dataset.indices maps the loader position back to the original dataset index;
    # the +2 presumably gives the row in the raw ClinTox CSV (1-based + header) — confirm.
    print(f'explain graph line {dataloader.dataset.indices[index] + 2}')
    data = data.to(device)

    label = data.y[0].squeeze()
    # Graphs without a label (NaN) cannot be scored; skip them.
    if torch.isnan(label):
        continue

    _, masks, related_preds = explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes)

    x_collector.collect_data(masks, related_preds, label.long().item())

    # if you only have the edge masks without related_pred, please feed sparsity controlled mask to
    # obtain the result: x_processor(data, masks, x_collector)

explain graph line 1278
explain graph line 1153
explain graph line 25
explain graph line 721
explain graph line 292
explain graph line 186
explain graph line 1402
explain graph line 1095




explain graph line 1093
explain graph line 221
explain graph line 404
explain graph line 472
explain graph line 1180
explain graph line 419
explain graph line 1117
explain graph line 467
explain graph line 942
explain graph line 401
explain graph line 760
explain graph line 257
explain graph line 161
explain graph line 656
explain graph line 1350
explain graph line 744
explain graph line 56
explain graph line 835
explain graph line 1383
explain graph line 1063
explain graph line 18
explain graph line 174
explain graph line 1261
explain graph line 1341
explain graph line 973
explain graph line 1203
explain graph line 1280
explain graph line 671
explain graph line 1303
explain graph line 1311
explain graph line 1214
explain graph line 141
explain graph line 952
explain graph line 881
explain graph line 1213
explain graph line 7
explain graph line 130
explain graph line 1451
explain graph line 293
explain graph line 73
explain graph line 677
explain graph line 892
explain graph line 868
e

### Output metrics evaluation results

In [8]:
# Report the aggregated explanation metrics, one per line, to four decimals.
print('\n'.join([
    f'Fidelity: {x_collector.fidelity:.4f}',
    f'Fidelity_inv: {x_collector.fidelity_inv:.4f}',
    f'Sparsity: {x_collector.sparsity:.4f}',
]))

Fidelity: 0.4519
Fidelity_inv: -0.0001
Sparsity: 0.5000
