In [1]:
import torch
import pandas as pd

from torch_geometric.data import DataLoader, Data
from torch_geometric.datasets import PPI
from torch_geometric.utils import remove_isolated_nodes

from dig.sslgraph.utils import Encoder
from dig.sslgraph.dataset import get_node_dataset

from downstream import MLP, EndtoEnd, train_MLP
from dig.xgraph.evaluation import XCollector

# Select the compute device for the whole notebook.
# The second GPU ("cuda:1") is preferred, as in the original setup, but it only
# exists on multi-GPU hosts; fall back to the first GPU, then to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [2]:
def get_task(idx):
    """Build a dataset transform that keeps a single label column.

    Args:
        idx: Column index into ``data.y`` selecting the task.

    Returns:
        A callable that maps a graph ``data`` object to a new ``Data`` with the
        same ``x`` and ``edge_index`` and with ``y`` set to ``data.y[:, idx]``.
    """
    def transform(data):
        task_labels = data.y[:, idx]
        return Data(x=data.x, edge_index=data.edge_index, y=task_labels)

    return transform

def get_task_rm_iso(idx):
    """Build a transform that drops isolated nodes and keeps one label column.

    Args:
        idx: Column index into ``data.y`` selecting the task.

    Returns:
        A callable that maps a graph ``data`` object to a new ``Data`` whose
        ``edge_index`` and node mask come from ``remove_isolated_nodes``; ``x``
        and ``y`` are filtered by that mask and ``y`` is restricted to column
        ``idx``.
    """
    def transform(data):
        n_nodes = data.x.shape[0]
        # The middle return value (edge attributes) is not used here.
        kept_edges, _, keep = remove_isolated_nodes(data.edge_index, num_nodes=n_nodes)
        return Data(x=data.x[keep], edge_index=kept_edges, y=data.y[keep, idx])

    return transform
    
# PPI graphs with isolated nodes removed and label column 0 as the target.
ppi = PPI('node_dataset/ppi/', transform=get_task_rm_iso(0))
# Iterate over the dataset one graph per batch.
loader = DataLoader(ppi, batch_size=1)

In [3]:
# Node-level GCN encoder; the input width is read from the feature matrix of
# the first graph in the dataset.
encoder = Encoder(
    feat_dim=ppi[0].x.shape[1],
    hidden_dim=600,
    n_layers=2,
    gnn='gcn',
    node_level=True,
    graph_level=False,
)
# Restore the encoder weights from the checkpoint file, deserialising onto the CPU.
encoder.load_state_dict(
    torch.load('ckpts_model/ppi_pretrain_grace600_h2.pth', map_location='cpu')
)

<All keys matched successfully>

In [4]:
from tagexplainer import TAGExplainer, MLPExplainer

# Explainer wrapped around the encoder, configured for node-level
# explanations (explain_graph=False) with the JSE loss.
enc_explainer = TAGExplainer(
    encoder,
    embed_dim=600,
    device=device,
    explain_graph=False,
    grad_scale=0.1,
    coff_size=0.05,
    coff_ent=0.002,
    loss_type='JSE',
)

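#### To train the explainer yourself, uncomment the following cell; it saves the checkpoint that the next cell loads. Skip it to reuse an existing checkpoint.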

In [5]:
# enc_explainer.train_explainer_node(loader, batch_size=4, lr=5e-6, epochs=1)
# torch.save(enc_explainer.state_dict(), 'ckpts_explainer/explain_ppi_grace.pt')

100%|██████████| 390/390 [00:31<00:00, 12.25it/s, loss=-.24, log=0.4901, 0.9028, 0.1317, 0.1679]   
100%|██████████| 345/345 [00:27<00:00, 12.40it/s, loss=-.262, log=0.4368, 0.8754, -1.0137, 0.1058]  
100%|██████████| 566/566 [01:02<00:00,  9.04it/s, loss=-.359, log=2.0513, 0.9053, -1.1522, 0.1367]  
100%|██████████| 585/585 [01:10<00:00,  8.27it/s, loss=-.465, log=5.0124, 0.9107, -1.9695, 0.1184]  
100%|██████████| 395/395 [00:33<00:00, 11.84it/s, loss=0.133, log=3.8440, 0.9185, -2.5551, 0.1190] 
100%|██████████| 256/256 [00:17<00:00, 14.82it/s, loss=-.32, log=3.0553, 0.8555, -1.9011, 0.1751]  
100%|██████████| 456/456 [00:41<00:00, 10.87it/s, loss=-.127, log=3.7548, 0.8240, -2.5014, 0.0741] 
100%|██████████| 622/622 [01:17<00:00,  8.05it/s, loss=-.397, log=10.1655, 0.9261, -4.8972, 0.0603]  
100%|██████████| 148/148 [00:08<00:00, 18.12it/s, loss=-.394, log=10.5567, 0.9312, -1.6987, 0.2227]
100%|██████████| 828/828 [02:19<00:00,  5.94it/s, loss=-.27, log=13.6232, 0.9427, -9.1332, 0.03

In [6]:
# Deserialise onto the CPU so the checkpoint loads even when the device it was
# saved from (e.g. a specific GPU) is not present on this machine;
# load_state_dict then copies the values into the explainer's own parameters.
state_dict = torch.load('ckpts_explainer/explain_ppi_grace.pt', map_location='cpu')
enc_explainer.load_state_dict(state_dict)

<All keys matched successfully>

In [7]:
def get_results(task_id, top_k):
    """Explain the non-zero-labelled nodes of one PPI task and print summary metrics.

    For label column ``task_id`` the function reloads the PPI dataset (with
    isolated nodes removed), loads the matching downstream MLP checkpoint, and
    runs the module-level ``enc_explainer`` on every node whose label is
    non-zero. Results are accumulated in an ``XCollector`` and the aggregated
    fidelity and sparsity are printed.

    Args:
        task_id: Index of the label column to evaluate; also selects the
            downstream checkpoint file name.
        top_k: Forwarded unchanged to ``enc_explainer`` as ``top_k``.

    Returns:
        None. Progress and the final metrics are written to stdout.
    """
    ppi = PPI('node_dataset/ppi/', transform=get_task_rm_iso(task_id))
    loader = DataLoader(ppi, 1)

    mlp_model = MLP(num_layer = 2, emb_dim = 600, hidden_dim = 600, out_dim = 2)
    mlp_model.load_state_dict(torch.load('ckpts_model/downstream_ppi%d_grace600.pth'%task_id, map_location='cpu'))
    mlp_explainer = MLPExplainer(mlp_model, device)

    x_collector = XCollector()
    for i, data in enumerate(loader):
        # Select the target nodes before the transfer so the indices handed to
        # the explainer stay on the CPU, exactly as before.
        target_nodes = torch.where(data.y)[0]
        # Move the graph to the device once per graph rather than once per node.
        data = data.to(device)
        for node_idx in target_nodes:
            walks, masks, related_preds = \
                enc_explainer(data, mlp_explainer, node_idx=node_idx, top_k=top_k)
            # Per-node fidelity: prediction on the original input minus the
            # prediction with the explanation masked out.
            fidelity = related_preds[0]['origin'] - related_preds[0]['maskout']

            # '\r' keeps the progress report on a single, overwritten line.
            print(f'explain graph {i} node {node_idx}'+' fidelity %.4f'%fidelity, end='\r')
            x_collector.collect_data(masks, related_preds)

    # Each property is unpacked as a (value, spread) pair — presumably mean and
    # standard deviation; confirm against the XCollector in use.
    fid, fid_std = x_collector.fidelity
    spa, spa_std = x_collector.sparsity

    print()
    print(f'Fidelity: {fid:.4f} ±{fid_std:.4f}\n'
          f'Sparsity: {spa:.4f} ±{spa_std:.4f}')

In [8]:
# Evaluate explanations for PPI label column 0, forwarding top_k=100 to the explainer.
get_results(task_id=0, top_k=100)

explain graph 19 node 3020 fidelity 0.05188
Fidelity: 0.2694 ±0.3878
Sparsity: 0.8545 ±0.1814


In [9]:
# Evaluate explanations for PPI label column 1, forwarding top_k=120 to the explainer.
get_results(task_id=1, top_k=120)

explain graph 19 node 3020 fidelity -0.0007
Fidelity: 0.3038 ±0.4385
Sparsity: 0.8671 ±0.1770


In [10]:
# Evaluate explanations for PPI label column 2, forwarding top_k=100 to the explainer.
get_results(task_id=2, top_k=100)

explain graph 19 node 3015 fidelity 0.00002
Fidelity: 0.5042 ±0.4782
Sparsity: 0.8444 ±0.2278


In [11]:
# Evaluate explanations for PPI label column 3, forwarding top_k=400 to the explainer.
get_results(task_id=3, top_k=400)

explain graph 19 node 3015 fidelity 0.30530
Fidelity: 0.2763 ±0.4332
Sparsity: 0.8541 ±0.2171


In [12]:
# Evaluate explanations for PPI label column 4, forwarding top_k=400 to the explainer.
get_results(task_id=4, top_k=400)

explain graph 19 node 3015 fidelity 0.01282
Fidelity: 0.3234 ±0.4460
Sparsity: 0.8547 ±0.2490
