In [None]:
import sys
# Add the parent directory to the import search path; the non-library imports
# below (utils, gnn, explainers, module) presumably live there — TODO confirm.
sys.path.append('..')

In [None]:
import time
import random
import argparse
import os.path as osp
import torch
from tqdm import tqdm
from utils.get_subgraph import *
import torch.nn.functional as F
import networkx as nx

import matplotlib.pyplot as plt
from torch_geometric.data import Batch
from gnn import MNISTNet
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.datasets import MNISTSuperpixels
from utils import *
from explainers import *
from utils.dataset import MNISTTransform
# Drawing options keyed by model name. The option names look like networkx
# drawing keyword arguments — TODO confirm where they are consumed.
# NOTE: the fallback entry is keyed 'defult' (sic); the lookup below uses the
# same spelling, so the two must be changed together if the key is ever renamed.
_vis_dict_ = {
    'MutagNet': dict(node_size=400, linewidths=1, font_size=10, width=3),
    'Tox21Net': dict(node_size=400, linewidths=1, font_size=10, width=3),
    'BA3MotifNet': dict(node_size=300, linewidths=1, font_size=10, width=3),
    'TR3MotifNet': dict(node_size=300, linewidths=1, font_size=10, width=5),
    'defult': dict(node_size=150, linewidths=1, font_size=10, width=0.5),
}
vis_dict = _vis_dict_['defult']

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Evaluation data: MNISTSuperpixels built with train=False, restricted to its
# first 1000 graphs and served in fixed (unshuffled) order.
batch_size = 32
data_path = 'data/MNIST'
# transform = T.Cartesian(cat=False, max_value=9)
transform = MNISTTransform(cat=False, max_value=9)
test_dataset = MNISTSuperpixels(data_path, train=False, transform=transform)
test_loader = DataLoader(
    test_dataset[:1000],
    batch_size=batch_size,
    shuffle=False,
)

# Experiment configuration and pretrained models.
n_loop = 1
edge_ratio = 0.2  # ratio handed to get_broken_graph / used as top ratio further down
f = 'NEW_MNIST_FINAL_NoOverload'#'NEW_MNIST_ablation_cts'
path = "param/gnns/mnist_net.pt"
generative_model_path = 'param/cg/%s/generator/best.pkl' % f

# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoints
# from a trusted source.
# map_location is passed for BOTH models so that checkpoints saved on a CUDA
# device can still be loaded on a CPU-only machine.
G = torch.load(generative_model_path, map_location=device).to(device)
gnn = torch.load(path, map_location=device).to(device)
# NOTE(review): gnn.eval() is never called in the cells visible here; confirm
# the checkpoint was saved in eval mode if the network has dropout/batch-norm.

def gen_mnist_attr(edge_index, pos, max_value=9):
    """Build edge attributes from the relative positions of edge endpoints.

    For every edge (row -> col) the attribute is ``pos[col] - pos[row]``,
    rescaled as ``delta / (2 * max_value) + 0.5``.

    Args:
        edge_index: indexable pair ``(row, col)`` of node-index tensors
            (e.g. a tensor whose first dimension has size 2).
        pos: node positions, indexed by node id. Must not be None.
        max_value: scale used in the rescaling. Defaults to 9, the value that
            was previously hard-coded here.

    Returns:
        Tensor with one row per edge; a 1-D difference is reshaped to a
        single column first.

    Raises:
        AssertionError: if ``pos`` is None.
    """
    assert pos is not None, "pos is required to build edge attributes"

    row, col = edge_index

    cart = pos[col] - pos[row]
    # Scalar positions give a 1-D difference; promote it to an (E, 1) column.
    if cart.dim() == 1:
        cart = cart.view(-1, 1)

    return cart / (2 * max_value) + 0.5

# Registry mapping an integer id to an explainer class.
# Id 7 (PGExplainer) is commented out, so the ids are not contiguous.
_explainer_entries = [
    (0, RandomCaster),
    (1, SAExplainer),
    (2, IGExplainer),
    (3, DeepLIFTExplainer),
    (4, GradCam),
    (5, GNNExplainer),
    (6, CXplainer),
    # (7, PGExplainer),
    (8, PGMExplainer),
    (9, Screener),
]
Explainers = dict(_explainer_entries)

## Dataset statistics

In [None]:
# Sum node and edge counts over every batch, then report per-graph averages
# (the divisor is the number of graphs in the loader's dataset).
num_nodes = 0.
num_edges = 0.
n_graphs = len(test_loader.dataset)
for g in tqdm(iter(test_loader), total=len(test_loader)):
    num_nodes, num_edges = num_nodes + g.num_nodes, num_edges + g.num_edges
avg_nodes = num_nodes / n_graphs
avg_edges = num_edges / n_graphs
print("Average #Nodes {:.4f}".format(avg_nodes))
print("Average #Edges {:.4f}".format(avg_edges))

### CG

In [None]:
# Accuracy of the GNN on graphs rebuilt by the generator G from the sub-graph
# returned by get_mnist_ground_truth_graph. Each batch is filled three times
# and every fill contributes to the accuracy.
set_seed(1)
ID_ACC = torch.tensor([]).to(device)
OOD_ACC = torch.tensor([]).to(device)
with torch.no_grad():
    G.eval()
    # Per-fill correctness vectors, seeded with an empty tensor so that the
    # final concatenation also works when the loader yields nothing.
    id_hits = [torch.tensor([]).to(device)]
    for _g in tqdm(iter(test_loader), total=len(test_loader)):
        g = _g.clone()
        g.to(device)
        pos = g.pos
        broken_edge_index, broken_edge_attr, out_edge_ratio = get_mnist_ground_truth_graph(g)

        # Latent code of the sub-graph, concatenated with the code of the full graph.
        mu, log_var, z = G.encode(
            x=g.x,
            in_edge_index=broken_edge_index,
            in_edge_attr=broken_edge_attr,
        )
        _, _, cond_z = G.encode(
            x=g.x,
            in_edge_index=g.edge_index,
            in_edge_attr=g.edge_attr,
        )
        z = torch.cat([z, cond_z], dim=1)

        for _ in range(3):
            fake_edge_index, fake_edge_prob, fake_edge_attr, _ = G.fill(
                z=z,
                preserved_edge_index=broken_edge_index,
                preserved_edge_ratio=out_edge_ratio,
                batch=g.batch,
                pos=pos,
                neg_edge_index=None,
                threshold=False,
            )
            # print(fake_edge_prob.min(), fake_edge_prob.max())
            relabel_x, relabel_edge_index, relabel_batch, relabel_pos = relabel(
                g.x, fake_edge_index, g.batch, pos
            )
            new_g = Batch(
                batch=relabel_batch,
                x=relabel_x,
                edge_index=relabel_edge_index,
                edge_attr=fake_edge_attr,
                pos=relabel_pos,
            )
            readout = gnn(data=new_g)
            predicted = readout.argmax(dim=1)
            id_hits.append((g.y == predicted).view(-1).float())

    # Percentage of correct predictions over all fills.
    ID_perf = torch.cat(id_hits).mean().item() * 100
ID_perf

## CG 



In [None]:
# Accuracy of the GNN on the sub-graph returned by get_mnist_ground_truth_graph
# alone (no infilling), followed by the gap to ID_perf from the cell above.
#
# The per-batch results are collected in a cell-local list: OOD_ACC ends up as
# a plain float, so accumulating into OOD_ACC itself made the cell crash in
# torch.cat when it was run a second time.
ood_hits = [torch.tensor([]).to(device)]
with torch.no_grad():  # inference only, no autograd graph needed
    for _g in tqdm(iter(test_loader), total=len(test_loader)):
        g = _g.clone()
        g.to(device)
        pos = g.pos.clone()
        broken_edge_index, broken_edge_attr, out_edge_ratio = get_mnist_ground_truth_graph(g)

        # Compute the accuracy of the ground-truth graph.
        relabel_x_gd, relabel_edge_index_gd, relabel_batch_gd, relabel_pos_gd = relabel(
            g.x, broken_edge_index, g.batch, pos=pos)
        _g_gd = Batch(
            batch=relabel_batch_gd, x=relabel_x_gd,
            edge_index=relabel_edge_index_gd,
            edge_attr=broken_edge_attr, pos=relabel_pos_gd, y=g.y)
        readout = gnn(_g_gd)
        ood_hits.append((_g_gd.y == readout.argmax(dim=1)).view(-1).float())

OOD_ACC = torch.cat(ood_hits).mean().item() * 100
print("OOD_ACC:", OOD_ACC)
print('VAL value: {:.3f}'.format(ID_perf - OOD_ACC))

### Fidelity

Broken graph ratio: 0.2 (the `edge_ratio` setting passed to `get_broken_graph`).



In [None]:
# Fidelity of the generator: squared distance between the GNN's class
# probabilities on the original graph and (a) the mean probabilities over 20
# generator fills of the broken graph ("ID"), (b) the probabilities on the
# broken graph itself ("OOD"). Graphs are processed one at a time.
set_seed(0)
generative_model_path = 'param/cg/%s/generator/best.pkl' % f
test_loader = DataLoader(test_dataset[:1000], batch_size=1, shuffle=False)
# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
# map_location on both loads keeps the cell usable on CPU-only machines.
G = torch.load(generative_model_path, map_location=device).to(device)
gnn = torch.load("param/gnns/mnist_net.pt", map_location=device).to(device)
# G has just been reloaded, so the eval() call made in the validity cell does
# not apply to this object; switch it to eval mode again for consistency.
G.eval()
ID_ACC = torch.tensor([]).to(device)
OOD_ACC = torch.tensor([]).to(device)
fid_id = []
fid_ood = []
with torch.no_grad():  # inference only
    for _g in tqdm(iter(test_loader), total=len(test_loader)):
        g = _g.clone()
        g.to(device)
        pos = g.pos
        # The GNN gets its own clone of the graph, as in the original cell.
        Y_ori = gnn(g.clone()).softmax(dim=1).detach().cpu().numpy()[0]
        broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(
            g, edge_ratio, connectivity=False)

        # Prediction on the broken graph alone.
        relabel_x_gd, relabel_edge_index_gd, relabel_batch_gd, relabel_pos_gd = relabel(
            g.x, broken_edge_index, g.batch, pos=pos)
        _g_gd = Batch(batch=relabel_batch_gd,
                      x=relabel_x_gd, edge_index=relabel_edge_index_gd,
                      edge_attr=broken_edge_attr, pos=relabel_pos_gd, y=g.y)
        readout = gnn(_g_gd)
        Y_gs = readout.softmax(dim=1).detach().cpu().numpy()[0]

        # Prediction on the generated (filled) graphs.
        mu, log_var, z = G.encode(
            x=g.x, in_edge_index=broken_edge_index, in_edge_attr=broken_edge_attr)
        _, _, cond_z = G.encode(
            x=g.x, in_edge_index=g.edge_index,
            in_edge_attr=g.edge_attr
        )
        z = torch.cat([z, cond_z], dim=1)
        Y_g_hat_mt = []
        for _ in range(20):
            fake_edge_index, fake_edge_prob, fake_edge_attr, _ = G.fill(
                z=z, preserved_edge_index=broken_edge_index,
                preserved_edge_ratio=out_edge_ratio,
                batch=g.batch, pos=pos, neg_edge_index=None, threshold=False)

            relabel_x, relabel_edge_index, relabel_batch, relabel_pos = relabel(
                g.x, fake_edge_index, g.batch, pos=pos)
            # Separate name: the original rebound the outer loop variable `_g` here.
            filled_g = Batch(batch=relabel_batch, x=relabel_x,
                             edge_index=relabel_edge_index,
                             edge_attr=fake_edge_attr, pos=relabel_pos, y=g.y)
            readout = gnn(filled_g)
            Y_g_hat_mt.append(readout.softmax(dim=1).detach().cpu().numpy()[0])

        # Average the class probabilities over the 20 fills.
        Y_g_hat_mt = np.array(Y_g_hat_mt).mean(axis=0)
        fid_id.append(pow(Y_g_hat_mt - Y_ori, 2).sum())
        fid_ood.append(pow(Y_gs - Y_ori, 2).sum())

fid_id = torch.tensor(fid_id)
fid_ood = torch.tensor(fid_ood)
print('FID ID value: {:.3f}'.format(fid_id.mean()))
print('FID OOD value: {:.3f}'.format(fid_ood.mean()))

## Random 

### Validity

In [None]:
from module.random_gen import RandomGenerator

# Bound to `random_generator` rather than `random`, so the stdlib `random`
# module imported at the top of the notebook is not shadowed.
random_generator = RandomGenerator()

# One accuracy per repetition. This list must be created OUTSIDE the
# repetition loop: resetting it on every pass left only the last run in it,
# which made the reported max/mean meaningless.
All_Random_ACC = []
with torch.no_grad():  # inference only
    for i in range(3):
        Random_ACC = torch.tensor([]).to(device)
        for ori_g in tqdm(iter(test_loader), total=len(test_loader)):
            g = ori_g.clone()
            g.to(device)
            broken_edge_index, broken_edge_attr, out_edge_ratio = get_mnist_ground_truth_graph(g)

            # Replace the graph's edges by the sub-graph, then let the random
            # baseline fill it back up.
            tmp_g = g.clone()
            tmp_g.edge_index = broken_edge_index
            tmp_g.edge_attr = broken_edge_attr
            filled_g = random_generator.fill(tmp_g, out_edge_ratio)

            filled_x, filled_edge_index, filled_batch, filled_pos = relabel(
                filled_g.x, filled_g.edge_index, filled_g.batch, pos=tmp_g.pos)
            # Edge attributes are recomputed from the original positions and
            # the un-relabelled edge index, as in the original cell.
            filled_edge_attr = gen_mnist_attr(filled_g.edge_index, g.pos)
            relabel_filled_g = Batch(batch=filled_batch,
                                     x=filled_x, edge_index=filled_edge_index,
                                     edge_attr=filled_edge_attr, pos=filled_pos, y=g.y)
            readout = gnn(data=relabel_filled_g)
            random_acc = (g.y == readout.argmax(dim=1)).view(-1).float()
            Random_ACC = torch.cat([Random_ACC, random_acc])
        Random_ACC = Random_ACC.mean().item() * 100
        All_Random_ACC.append(Random_ACC)
All_Random_ACC = torch.tensor(All_Random_ACC)
print('ID  Max ACC {:.2f}  Mean ACC {:.2f}'.format(All_Random_ACC.max(), All_Random_ACC.mean()))

In [None]:
# 40.8 is a hard-coded reference value — presumably the OOD accuracy printed in
# the CG section; TODO confirm and keep in sync if that number changes.
random_val = All_Random_ACC.mean() - 40.8
print('VAL value: {:.3f}'.format(random_val))

### Fidelity

In [None]:
from module.random_gen import RandomGenerator

# Bound to `random_generator` rather than `random`, so the stdlib `random`
# module imported at the top of the notebook is not shadowed.
random_generator = RandomGenerator()

# Fidelity of the random baseline: squared distance between the GNN's class
# probabilities on the original graph and on one randomly re-filled graph.
fid_id = []
with torch.no_grad():  # inference only
    for ori_g in tqdm(iter(test_loader), total=len(test_loader)):
        g = ori_g.clone()
        g.to(device)
        Y_ori = gnn(data=g).softmax(dim=1).detach().cpu().numpy()[0]

        # Continue from a fresh clone of the loader's graph, exactly as the
        # original cell did (presumably because the GNN forward pass may alter
        # its input — TODO confirm). The clone is not moved to `device`; only
        # the final filled batch is.
        g = ori_g.clone()

        # Uses the shared `edge_ratio` setting instead of a literal 0.2, so this
        # baseline stays comparable with the CG fidelity cell if the ratio changes.
        broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(
            g, edge_ratio, connectivity=False)
        tmp_g = g.clone()
        tmp_g.edge_index = broken_edge_index
        tmp_g.edge_attr = broken_edge_attr
        filled_g = random_generator.fill(tmp_g, out_edge_ratio)
        filled_x, filled_edge_index, filled_batch, filled_pos = relabel(
            filled_g.x, filled_g.edge_index, filled_g.batch, pos=tmp_g.pos)
        filled_edge_attr = gen_mnist_attr(filled_g.edge_index, g.pos)
        relabel_filled_g = Batch(batch=filled_batch,
                                 x=filled_x, edge_index=filled_edge_index,
                                 edge_attr=filled_edge_attr, pos=filled_pos, y=g.y)
        relabel_filled_g.to(device)
        readout = gnn(data=relabel_filled_g)

        # A single fill per graph, so no averaging over samples is needed.
        Y_g_hat = readout.softmax(dim=1).detach().cpu().numpy()[0]
        fid_id.append(pow(Y_g_hat - Y_ori, 2).sum())

fid_id = torch.tensor(fid_id)
print('FID ID value: {:.3f}'.format(fid_id.mean()))

## VGAE 

### Validity

In [None]:
# Validity of the VGAE baseline: the sub-graph from get_mnist_ground_truth_graph
# is completed with the highest-scoring negative-sampled edges according to the
# VGAE decoder, and the GNN is evaluated on the result.
#
# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
# map_location keeps the cell usable on CPU-only machines.
vgae = torch.load("param/vgae/mnist.pt", map_location=device).to(device)
ID_ACC = torch.tensor([]).to(device)
with torch.no_grad():  # inference only
    for ori_g in tqdm(iter(test_loader), total=len(test_loader)):
        g = ori_g.clone()
        g.to(device)
        broken_edge_index, broken_edge_attr, out_edge_ratio = get_mnist_ground_truth_graph(g)
        _, relabel_edge_index, _, _ = relabel(g.x, broken_edge_index, g.batch, g.pos)
        # NOTE(review): encode() receives the original g.x together with the
        # relabelled edge index, while the candidates below use the original
        # node numbering — confirm this mix is intended.
        z = vgae.encode(g.x, relabel_edge_index, broken_edge_attr)

        # Number of edges to add back; three times as many candidates are
        # sampled and the top-scoring ones are kept.
        num_neg_samples = int((1-out_edge_ratio) * g.num_edges)
        neg_candidates = negative_sampling(broken_edge_index, g.num_nodes, 3 * num_neg_samples)
        prob = vgae.decode(z, neg_candidates)
        neg_idx = torch.argsort(-prob)[:num_neg_samples]
        fake_edge_index = torch.cat([broken_edge_index, neg_candidates[:, neg_idx]], dim=1)

        filled_x, filled_edge_index, filled_batch, filled_pos = relabel(
            g.x, fake_edge_index, g.batch, pos=g.pos)
        filled_edge_attr = gen_mnist_attr(fake_edge_index, g.pos)
        relabel_filled_g = Batch(batch=filled_batch,
                                 x=filled_x, edge_index=filled_edge_index,
                                 edge_attr=filled_edge_attr, pos=filled_pos, y=g.y).to(device)
        readout = gnn(data=relabel_filled_g)
        id_acc = (g.y == readout.argmax(dim=1)).view(-1).float() * 100
        ID_ACC = torch.cat([ID_ACC, id_acc])

print('ID   Mean ACC {:.2f}'.format(ID_ACC.mean()))
# 40.8 is a hard-coded reference value — presumably the OOD accuracy from the
# CG section; TODO confirm.
print('VAL value: {:.3f}'.format(ID_ACC.mean()- 40.8))

### Fidelity

In [None]:
# Fidelity of the VGAE baseline: squared distance between the GNN's class
# probabilities on the original graph and on the VGAE-completed broken graph.
fid_id = []
# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
# map_location keeps the cell usable on CPU-only machines.
vgae = torch.load("param/vgae/mnist.pt", map_location=device).to(device)
with torch.no_grad():  # inference only
    for ori_g in tqdm(iter(test_loader), total=len(test_loader)):
        g = ori_g.clone()
        g.to(device)
        readout = gnn(data=g)
        # Continue from a fresh clone, as in the original cell (presumably
        # because the GNN forward pass may alter its input — TODO confirm).
        g = ori_g.clone().to(device)
        Y_ori = readout.softmax(dim=1).detach().cpu().numpy()[0]

        # Uses the shared `edge_ratio` setting instead of a literal 0.2, so this
        # baseline stays comparable with the CG fidelity cell if the ratio changes.
        broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(
            g, edge_ratio, connectivity=False)
        _, relabel_edge_index, _, _ = relabel(g.x, broken_edge_index, g.batch, None)
        z = vgae.encode(g.x, relabel_edge_index, broken_edge_attr)

        # Number of edges to add back; twice as many candidates are sampled
        # and the top-scoring ones are kept.
        num_neg_samples = int((1-out_edge_ratio) * g.num_edges)
        neg_candidates = negative_sampling(broken_edge_index, g.num_nodes, 2 * num_neg_samples)
        prob = vgae.decode(z, neg_candidates)
        neg_idx = torch.argsort(-prob)[:num_neg_samples]
        fake_edge_index = torch.cat([broken_edge_index, neg_candidates[:, neg_idx]], dim=1)

        filled_x, filled_edge_index, filled_batch, filled_pos = relabel(
            g.x, fake_edge_index, g.batch, pos=g.pos)
        filled_edge_attr = gen_mnist_attr(fake_edge_index, g.pos)
        relabel_filled_g = Batch(batch=filled_batch,
                                 x=filled_x, edge_index=filled_edge_index,
                                 edge_attr=filled_edge_attr, pos=filled_pos, y=g.y).to(device)
        readout = gnn(data=relabel_filled_g)
        Y_g_hat = readout.softmax(dim=1).detach().cpu().numpy()[0]

        # A single completion per graph, so no averaging over samples is needed.
        squared_error = pow(Y_g_hat - Y_ori, 2).sum()
        if np.isnan(squared_error):
            # NaN results are printed and left out of the average.
            print(Y_g_hat, Y_ori)
        else:
            fid_id.append(squared_error)

fid_id = torch.tensor(fid_id)
print('FID ID value: {:.3f}'.format(fid_id.mean()))

### For Explainers

In [None]:
# Smoke test: run explainer #1 on the first graph of the loader only
# (the loop is left after a single iteration).
top_ratio_list = [0.4]
for ori_g in tqdm(iter(test_loader), total=len(test_loader)):
    g = ori_g.clone().to(device)
    print(g)
    graph_rep = gnn.get_graph_rep(g)
    print(graph_rep.size())
    explainer_cls = Explainers[1]
    SA = explainer_cls(gnn_model_path=path, gen_model_path=generative_model_path)

    # The explainer is given a fresh clone of the graph.
    g = ori_g.clone().to(device)
    SA.explain_graph(g)

    ood_acc, ood_prob = SA.evaluate_acc(top_ratio_list)
    id_acc, id_prob, _ = SA.evaluate_CounterSup_acc(top_ratio_list)
    precision = SA.evaluate_precision(topk=20)
    break

In [None]:
# Evaluate the selected explainers over the whole loader. For each explainer:
# mean precision, mean accuracy from evaluate_acc ("OOD") and from
# evaluate_CounterSup_acc ("ID"), plus the correlation of each accuracy
# with precision.
explainers_id = [1]
top_ratio_list = [edge_ratio]
explainers = [
    Explainers[idx](gnn_model_path=path, gen_model_path=generative_model_path)
    for idx in explainers_id
]

seq_precision, seq_ood_acc, seq_id_acc = [], [], []
seq_ood_sp, seq_id_sp = [], []

for e in explainers:
    print(e.name)
    id_acc_logger = []
    ood_acc_logger = []
    precision5_logger = []
    cnt = 0
    for g in tqdm(iter(test_loader), total=len(test_loader)):
        g.to(device)
        e.explain_graph(g)#, large_scale=True, C=5)
        ood_acc, ood_prob = e.evaluate_acc(top_ratio_list)
        id_acc, id_prob, _ = e.evaluate_CounterSup_acc(top_ratio_list)
        precision = e.evaluate_precision(topk=50)

        precision5_logger.append(precision)
        ood_acc_logger.append(float(ood_acc[0][0]))
        id_acc_logger.append(float(id_acc[0][0]))

    # Despite the "spearson" names these are Pearson coefficients
    # (np.corrcoef); the Spearman variant is kept in the trailing comments.
    ood_corr_matrix = np.corrcoef(precision5_logger, ood_acc_logger)
    id_corr_matrix = np.corrcoef(precision5_logger, id_acc_logger)
    ood_acc_spearson = ood_corr_matrix[0, -1] # scipy.stats.spearmanr(precision5_logger, ood_acc_logger)[0]
    id_acc_spearson = id_corr_matrix[0, -1] # scipy.stats.spearmanr(precision5_logger, id_acc_logger)[0]

    seq_precision.append(np.array(precision5_logger).mean())
    seq_ood_acc.append(np.array(ood_acc_logger).mean())
    seq_id_acc.append(np.array(id_acc_logger).mean())
    seq_ood_sp.append(ood_acc_spearson)
    seq_id_sp.append(id_acc_spearson)

    # Report the figures just appended for this explainer.
    for title, series in (
        ("w.r.t Precision", seq_precision),
        ("w.r.t OOD ACC", seq_ood_acc),
        ("w.r.t ID ACC", seq_id_acc),
    ):
        print(title)
        print("%.4f" % series[-1])

    print("w.r.t PearsonC")
    print('Before Counterfactual Infilling, PearsonC: %.4f' % seq_ood_sp[-1])
    print('After Counterfactual Infilling, PearsonC: %.4f' % seq_id_sp[-1])