In [1]:
import torch
import pandas as pd
from torch_geometric.data import DataLoader, Batch, Data
from embedding import GNN
from downstream import MLP
from tagexplainer import TAGExplainer, MLPExplainer
from loader import MoleculeDataset
from splitters import scaffold_split
from dig.xgraph.evaluation import XCollector

# Select the compute device for the whole notebook.
# GPU index 2 is preferred; when fewer than three GPUs are visible fall back to
# the first GPU instead of naming a device that does not exist, and use the CPU
# when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:2" if torch.cuda.device_count() > 2 else "cuda:0")
else:
    device = torch.device("cpu")

In [2]:
# Molecules used to train the explainer, served in shuffled batches of 256 graphs.
pre_train_dataset = MoleculeDataset(
    "dataset/zinc_standard_agent",
    dataset='zinc_standard_agent',
)
train_loader = DataLoader(pre_train_dataset, batch_size=256, shuffle=True)

# Graph encoder whose weights are restored from a saved checkpoint
# (loaded onto the CPU first via map_location).
embed_model = GNN(
    num_layer=5,
    emb_dim=600,
    JK='last',
    drop_ratio=0,
    gnn_type='gin',
)
embed_model.load_state_dict(
    torch.load('ckpts_model/chem_pretrained_contextpred.pth', map_location='cpu')
)

# Explainer wrapped around the encoder; configured for graph-level explanations.
enc_explainer = TAGExplainer(
    embed_model,
    embed_dim=600,
    device=device,
    explain_graph=True,
    grad_scale=0.2,
    coff_size=0.05,
    coff_ent=0.002,
    loss_type='JSE',
)

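#### To train the explainer yourself, uncomment and run the code in the following cell. Otherwise, the notebook loads a previously saved explainer checkpoint in the cell after it.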

In [3]:
# enc_explainer.train_explainer_graph(train_loader, lr=0.0001, epochs=1)
# torch.save(enc_explainer.explainer.state_dict(), 'ckpts_explainer/explain_mol_twostage.pt')

100%|██████████| 7813/7813 [23:56<00:00,  5.44it/s, loss=0.0217, log=75.2979, 1.0000, -21.9580, 0.0079]  


In [4]:
# Restore the trained explainer weights.
# map_location='cpu' (as used for the other checkpoints in this notebook) lets
# the file load even if it was saved from a GPU that is not present here;
# load_state_dict then copies the values into the explainer's own parameters.
state_dict = torch.load('ckpts_explainer/explain_mol_twostage.pt', map_location='cpu')
enc_explainer.explainer.load_state_dict(state_dict)

<All keys matched successfully>

In [5]:
def get_task(idx):
    """Build a per-sample transform that keeps only the label at position ``idx``.

    Args:
        idx: Position of the label to keep inside each sample's ``y``.

    Returns:
        A callable taking one sample and returning a new ``Data`` object with
        the sample's ``x``, ``edge_index`` and ``edge_attr`` unchanged and
        ``y`` replaced by the one-element slice ``y[idx:idx+1]`` after
        ``.long()``.
    """
    def transform(data):
        # Copy the graph fields through untouched, in a fixed order.
        fields = {
            attr: getattr(data, attr)
            for attr in ("x", "edge_index", "edge_attr")
        }
        # Slicing (rather than indexing) keeps the label as a length-1 sequence.
        fields["y"] = data.y[idx:idx + 1].long()
        return Data(**fields)

    return transform

def get_dataset(name, task=0):
    """Return the training portion of the scaffold split of dataset ``name``.

    Args:
        name: Dataset name (a string); used both as the folder under
            ``dataset/`` and as the ``dataset`` argument of ``MoleculeDataset``.
        task: Label position forwarded to ``get_task`` so each sample keeps a
            single label.

    Returns:
        The first of the three splits produced by ``scaffold_split`` with an
        80/10/10 train/valid/test ratio; the other two splits are discarded.
    """
    dataset_root = f"dataset/{name}"
    molecules = MoleculeDataset(dataset_root, dataset=name, transform=get_task(task))

    # The SMILES strings live in the first column of a header-less CSV.
    smiles_frame = pd.read_csv(f'{dataset_root}/processed/smiles.csv', header=None)
    smiles_list = smiles_frame[0].tolist()

    train_split, _valid_split, _test_split = scaffold_split(
        molecules,
        smiles_list,
        null_value=0,
        frac_train=0.8,
        frac_valid=0.1,
        frac_test=0.1,
    )
    return train_split

def get_results(task_name, top_k, pos=True):
    """Explain the eligible training graphs of ``task_name`` and print metrics.

    Loads the training split via ``get_dataset`` and the downstream MLP
    checkpoint for ``task_name``, then runs the notebook-level ``enc_explainer``
    on each graph one at a time and prints the mean and standard deviation of
    fidelity and sparsity collected by ``XCollector``.

    Args:
        task_name: Dataset name; also selects the downstream checkpoint file.
        top_k: Forwarded to the explainer call as ``top_k``.
        pos: Sample filter. A graph is skipped when ``pos == (data.y < 0)``,
            so ``pos=True`` skips graphs with a negative label and ``pos=False``
            skips graphs whose label is not negative.
            NOTE(review): presumably labels are -1/+1 - confirm.

    Returns:
        None. Results are printed.
    """
    train_split = get_dataset(task_name)

    # Downstream classifier head, restored from its saved checkpoint.
    downstream_mlp = MLP(num_layer=2, emb_dim=600, hidden_dim=600)
    checkpoint_path = f'ckpts_model/downstream_{task_name}_contextpred.pth'
    downstream_mlp.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    downstream_explainer = MLPExplainer(downstream_mlp, device)

    collector = XCollector()
    graph_loader = DataLoader(train_split, batch_size=1, shuffle=False, num_workers=1)
    for graph_idx, graph in enumerate(graph_loader):
        # Skip graphs on the wrong side of the label filter, then graphs
        # that have no edges at all.
        wrong_label_side = pos == (graph.y < 0)
        if wrong_label_side or graph.edge_index.shape[1] <= 0:
            continue

        # '\r' keeps the progress message on a single, overwritten line.
        print(f'explain graph {graph_idx}...', end='\r')
        _walks, masks, related_preds = enc_explainer(
            graph.to(device), downstream_explainer, top_k=top_k, mask_mode='split'
        )
        collector.collect_data(masks, related_preds)

    fidelity_mean, fidelity_std = collector.fidelity
    sparsity_mean, sparsity_std = collector.sparsity

    print()
    print(
        f'Fidelity: {fidelity_mean:.4f} +/- {fidelity_std:.4f}',
        f'Sparsity: {sparsity_mean:.4f} +/- {sparsity_std:.4f}',
        sep='\n',
    )

In [6]:
# BACE training split, top_k=5; the default pos=True skips graphs whose label is negative.
get_results('bace', top_k=5)

explain graph 965...
Fidelity: 0.3782 +/- 0.2934
Sparsity: 0.9026 +/- 0.0278


In [7]:
# HIV training split, top_k=5; the default pos=True skips graphs whose label is negative.
get_results('hiv', top_k=5)

explain graph 32795...
Fidelity: 0.5952 +/- 0.3200
Sparsity: 0.8806 +/- 0.0582


In [8]:
# SIDER training split (first label only, since get_dataset defaults to task=0), top_k=5.
get_results('sider', top_k=5)

explain graph 1140...
Fidelity: 0.4067 +/- 0.3226
Sparsity: 0.8545 +/- 0.0751


In [9]:
# BBBP training split, top_k=10; pos=False keeps only graphs whose label is negative.
get_results('bbbp', top_k=10, pos=False)

explain graph 1193...
Fidelity: 0.1878 +/- 0.1543
Sparsity: 0.7205 +/- 0.1162
