# DeepLIFT on Tox21 dataset for GCN

In [1]:
from dig.xgraph.dataset import MoleculeDataset
from dig.xgraph.models import GCN_3l
import torch
from torch.utils.data import random_split
from torch_geometric.data import DataLoader
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from torch_geometric.data.dataset import files_exist
import os.path as osp
import os

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Load dataset

In [2]:
def split_dataset(dataset, dataset_split=[0.8, 0.1, 0.1]):
    dataset_len = len(dataset)
    dataset_split = [int(dataset_len * dataset_split[0]),
                     int(dataset_len * dataset_split[1]),
                     0]
    dataset_split[2] = dataset_len - dataset_split[0] - dataset_split[1]
    train_set, val_set, test_set = \
        random_split(dataset, dataset_split)

    return {'train': train_set, 'val': val_set, 'test': test_set}

dataset = MoleculeDataset('datasets', 'Tox21')
dataset.data.x = dataset.data.x.to(torch.float32)
dataset.data.y = dataset.data.y[:, 2] # the target 2 task.
dim_node = dataset.num_node_features
dim_edge = dataset.num_edge_features
num_targets = dataset.num_classes
num_classes = 2

splitted_dataset = split_dataset(dataset)
dataloader = DataLoader(splitted_dataset['test'], batch_size=1, shuffle=False)



### Load model and checkpoints

In [9]:
def check_checkpoints(root='./'):
    if osp.exists(osp.join(root, 'checkpoints')):
        return
    url = ('https://github.com/divelab/DIG_storage/raw/main/xgraph/checkpoints.zip')
    path = download_url(url, root)
    extract_zip(path, root)
    os.unlink(path)

model = GCN_3l(model_level='graph', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
model.to(device)
check_checkpoints()
ckpt_path = osp.join('checkpoints', 'tox21', 'GCN_3l', '2', 'GCN_3l_best.ckpt')
ckpt_dict = torch.load(ckpt_path)['state_dict']
ckpt_dict["conv1.lin.weight"]=ckpt_dict["conv1.weight"].T
ckpt_dict["convs.0.lin.weight"]=ckpt_dict["convs.0.weight"]
ckpt_dict["convs.1.lin.weight"]=ckpt_dict["convs.1.weight"]
ckpt_dict.pop("conv1.weight")
ckpt_dict.pop("convs.0.weight")
ckpt_dict.pop("convs.1.weight")
model.load_state_dict(ckpt_dict)


<All keys matched successfully>

### Display example output

In [4]:
data = list(dataloader)[0].to(device)
out = model(data.x, data.edge_index)
print(out)

tensor([[ 1.0779, -1.0313]], device='cuda:0', grad_fn=<AddmmBackward0>)


### Load the explainer

In [5]:
from dig.xgraph.method import DeepLIFT
explainer = DeepLIFT(model, explain_graph=True)

### Setup for evaluation

In [6]:
# --- Set the Sparsity to 0.5 ---
sparsity = 0.5

# --- Create data collector and explanation processor ---
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
x_collector = XCollector(sparsity)
# x_processor = ExplanationProcessor(model=model, device=device)

### Run explainer on the given model and dataset

In [None]:
for index, data in enumerate(dataloader):
    print(f'explain graph line {dataloader.dataset.indices[index] + 2}')
    data.to(device)

    if torch.isnan(data.y[0].squeeze()):
        continue

    walks, masks, related_preds = \
        explainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes)

    x_collector.collect_data(masks, related_preds, data.y[0].squeeze().long().item())

    # if you only have the edge masks without related_pred, please feed sparsity controlled mask to
    # obtain the result: x_processor(data, masks, x_collector)

    # if index >= 99:
    #     break

In [11]:
x_collector

<dig.xgraph.evaluation.metrics.XCollector at 0x7f84d1e40250>

### Output metrics evaluation results

In [8]:
print(f'Fidelity: {x_collector.fidelity:.4f}\n'
      f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n'
      f'Sparsity: {x_collector.sparsity:.4f}')

Fidelity: 0.0007
Fidelity_inv: 0.0009
Sparsity: 0.5000
