In [None]:
import sys
sys.path.append('..')

In [None]:
import copy
import warnings
import scipy.stats
from tqdm import tqdm
import os.path as osp
from torch_geometric.data import DataLoader

from utils import *
from utils.parser import *
from utils.get_subgraph import *

from gnn import *
from datasets import *
from explainers import *


# ---- Global configuration shared by every experiment cell below ----
MAX_DIAM = 100

# Silence library warnings and print numpy arrays compactly (3 decimals, no scientific notation).
warnings.filterwarnings("ignore")
np.set_printoptions(suppress=True, precision=3)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

datasets = ["mutag", "reddit5k", "vg", "ba3"]

# Integer id -> explainer class. Later cells select explainers by these ids
# (e.g. `Explainers[1]`, `explainers_id = [1, 4, 5, 6, 8, 9]`).
# Id 7 (PGExplainer) is deliberately left out of the registry.
Explainers = dict([
    (0, RandomCaster),
    (1, SAExplainer),
    (2, IGExplainer),
    (3, DeepLIFTExplainer),
    (4, GradCam),
    (5, GNNExplainer),
    (6, CXplainer),
    (8, PGMExplainer),
    (9, Screener),
])


### Fetch Testing Dataset

In [None]:
set_seed(2021)
dataset_name = 'tr3'
path = '../param/gnns/%s_net.pt' % dataset_name
# NOTE(review): `random_model` is not read by any cell visible here; kept in case other cells use it.
random_model = torch.load(path).to(device)
random_model.reset_parameters()
folder = '../data/TR3'
dataset = TR3Motif(folder, mode='testing')
dataset_mask = []

# Boolean mask over the test set: True where the trained GNN predicts the label correctly.
# Cached on disk so the filtering pass only runs once.
flitered_path = folder + "/filtered_idx_test.pt"
if osp.exists(flitered_path):
    graph_mask = torch.load(flitered_path)
else:
    loader = DataLoader(dataset,
                        batch_size=1,
                        shuffle=False
                        )
    # filter graphs with right prediction
    model = torch.load(path).to(device)
    graph_mask = torch.zeros(len(loader.dataset), dtype=torch.bool)
    # Pure inference: no autograd graph is needed just to read the predicted class.
    with torch.no_grad():
        for idx, g in enumerate(tqdm(loader, total=len(loader))):
            g.to(device)
            model(g.x, g.edge_index, g.edge_attr, g.batch)
            # batch_size=1, so this compares one label with one prediction.
            if g.y == model.readout.argmax(dim=1):
                graph_mask[idx] = True

    torch.save(graph_mask, flitered_path)
# Record the mask on both the cached and the freshly-computed path
# (previously it was only appended when the cache had to be rebuilt).
dataset_mask.append(graph_mask)

gnn = torch.load(path).to(device)
n_filtered = graph_mask.nonzero().size(0)
loader = DataLoader(dataset[graph_mask], batch_size=32, shuffle=False, drop_last=False)
print("number of graphs(Testing): %4d" % n_filtered)

## Dataset statistics

In [None]:
# Average graph size (nodes / edges) over the filtered test graphs in `loader`.
num_nodes, num_edges = 0., 0.
for g in tqdm(loader, total=len(loader)):
    num_nodes = num_nodes + g.num_nodes
    num_edges = num_edges + g.num_edges
print(f"Average #Nodes {num_nodes / len(loader.dataset):.4f}")
print(f"Average #Edges {num_edges / len(loader.dataset):.4f}")

### Experiment 1

In [None]:
def get_single_ground_truth_graph(g):
    """Keep, for every graph in batch `g`, a connected fragment of its labelled edges.

    For each graph, the edges whose two endpoint labels (taken from `g.z`) have a
    positive product are selected. One node touched by those edges is drawn with
    numpy's global RNG. The subgraph returned by `bid_k_hop_subgraph(num_hops=5)`
    around that node, restricted to the selected edges, is kept.

    Args:
        g: a batched graph exposing `x`, `z`, `edge_index`, `edge_attr`; per-graph
            edge counts/offsets come from `split_batch(g)`.

    Returns:
        (edge_index, edge_attr, ratio):
            - edge_index: the kept fragments concatenated, using the same node ids
              as `g.edge_index`;
            - edge_attr: the matching rows of `g.edge_attr`;
            - ratio: a tensor on `g.x.device` holding, per graph, the number of
              kept edges divided by that graph's total edge count.

    Note:
        If a graph has no selected edge, `np.random.choice` is called on an empty
        array and raises ValueError.
    """
    _, _, _, per_graph_edges, edge_offsets = split_batch(g)
    node_label = np.concatenate(g.z, axis=0)
    src, dst = g.edge_index.detach().cpu().numpy()
    selected = torch.tensor(node_label[src] * node_label[dst] > 0, dtype=torch.bool)

    kept_index = torch.LongTensor([[], []]).to(device)
    kept_attr = torch.LongTensor([]).to(device)
    ratios = []
    for n_edges, offset in zip(per_graph_edges.tolist(), edge_offsets.tolist()):
        # Global ids of this graph's selected edges.
        selected_ids = torch.nonzero(selected[offset: offset + n_edges]).view(-1) + offset
        selected_index = g.edge_index[:, selected_ids]

        anchor = np.random.choice(np.unique(selected_index.detach().cpu().numpy()))
        anchor = torch.tensor([anchor]).to(device)
        _, frag_index, _, frag_mask = bid_k_hop_subgraph(anchor, num_hops=5, edge_index=selected_index)

        graph_attr = g.edge_attr[offset: offset + n_edges]
        frag_attr = graph_attr[selected_ids - offset][frag_mask]

        kept_index = torch.cat([kept_index, frag_index], dim=1)
        kept_attr = torch.cat([kept_attr, frag_attr], dim=0)
        ratios.append(float(frag_index.size(1)) / n_edges)

    return kept_index, kept_attr, torch.tensor(ratios).to(g.x.device)

### CG

In [None]:
# Experiment 1 (CG): for each batch, take the ground-truth fragment from
# get_single_ground_truth_graph and measure GNN accuracy on
#   - the fragment as-is                        -> OOD_ACC (only if cal_ground_truth)
#   - the fragment in-filled by generator G     -> ID_ACC
# repeated n_loop times. All_OOD_ACC is also read by the later random/VGAE cells.
cal_ground_truth = 1  # also evaluate the un-filled fragment (OOD accuracy)
# NOTE(review): the draw_graph branches index `gnn.readout[0, g.y]` and `ground_truth_mask[0]`,
# and compare `ood_acc < id_acc` with a Python `if`; with the batch_size=32 loader this
# presumably only works for batch_size=1 — confirm before enabling.
draw_graph = 0
n_loop = 5  # repetitions; this global is also read by later cells
num_test=1000  # cap on the number of *batches* (not graphs) per repetition
for f in ['TR3_FINAL_param4']:
    mean = []  # unused in this cell
    generative_model_path = '../param/cg/%s/generator/best.pkl' % f
    G = torch.load(generative_model_path, map_location=device).cpu().to(device)
    val_filled_likelihood = []  # unused in this cell
    All_ID_ACC = []
    All_OOD_ACC = []
    # Only used for exp.visualize when draw_graph is set.
    exp = Explainers[1](gnn_model_path=path, gen_model_path=generative_model_path)
    
    for i in range(n_loop):
        cnt = 0
        ID_ACC = torch.tensor([]).to(device)
        OOD_ACC = torch.tensor([]).to(device)
        for g in loader:
            cnt += 1
            if cnt > num_test:
                break
            g.to(device)
            broken_edge_index, broken_edge_attr, out_edge_ratio = get_single_ground_truth_graph(g)
            # ground truth
            if cal_ground_truth:
                # Evaluate the GNN on the fragment alone (no in-filling).
                fake_edge_index = broken_edge_index
                fake_edge_attr = broken_edge_attr
                relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)

                readout = gnn(relabel_x,
                              relabel_edge_index,
                              fake_edge_attr,
                              relabel_batch
                              )
                ood_acc = (g.y == readout.argmax(dim=1)).view(-1).float()
                # Edges in the evaluated graph that are absent from the original graph; only
                # consumed by exp.visualize, but computed unconditionally.
                fake_edge_index = set([(int(fake_edge_index[0][i]), int(fake_edge_index[1][i])) for i in range(fake_edge_index.size(1))])-set([(int(g.edge_index[0][i]), int(g.edge_index[1][i])) for i in range(g.edge_index.size(1))])
                fake_edge_index = torch.tensor(list(fake_edge_index)).T
                if draw_graph:
                    print("OOD ACC: %.4f" % gnn.readout[0, g.y])
                    exp.visualize(graph=g, edge_imp=g.ground_truth_mask[0], vis_ratio=out_edge_ratio, counter_edge_index=fake_edge_index, layout=False)
                OOD_ACC = torch.cat([OOD_ACC, ood_acc])
                
            # Encode the fragment and the full graph (reparameterize=False; presumably the
            # deterministic/mean latent), then concatenate them as the generator's condition.
            mu, log_var, z = G.encode(
                x=g.x, in_edge_index=broken_edge_index, 
                in_edge_attr=broken_edge_attr, reparameterize=False
            )
            _, _, cond_z = G.encode(
                    x=g.x, in_edge_index=g.edge_index, 
                    in_edge_attr=g.edge_attr, reparameterize=False
                    )
            z = torch.cat([z, cond_z], dim=1)
            # In-fill the fragment back towards the original size (out_edge_ratio is the kept fraction).
            fake_edge_index, fake_edge_prob, fake_edge_attr, _ = \
                G.fill(z=z, preserved_edge_index=broken_edge_index, preserved_edge_ratio=out_edge_ratio,
                            batch=g.batch, neg_edge_index=None, threshold=False)
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)
            
            readout = gnn(relabel_x,
                          relabel_edge_index,
                          fake_edge_attr,
                          relabel_batch
                          )
            id_acc = (g.y == readout.argmax(dim=1)).view(-1).float()
            # Generated edges not in the original graph (visualisation only, computed unconditionally).
            fake_edge_index = set([(int(fake_edge_index[0][i]), int(fake_edge_index[1][i])) for i in range(fake_edge_index.size(1))])-set([(int(g.edge_index[0][i]), int(g.edge_index[1][i])) for i in range(g.edge_index.size(1))])
            fake_edge_index = torch.tensor(list(fake_edge_index)).T
            if draw_graph:
                print("ID ACC: %.4f" % gnn.readout[0, g.y])
                exp.visualize(graph=g, edge_imp=g.ground_truth_mask[0], vis_ratio=out_edge_ratio, counter_edge_index=fake_edge_index, layout=False)
            ID_ACC = torch.cat([ID_ACC, id_acc])
            
            if draw_graph and cal_ground_truth:
                if ood_acc < id_acc:
                    print("Counterfactual Incrementation!")
                else:
                    print("Counterfactual Decrease!")

        # Per-repetition accuracy in percent. If cal_ground_truth is 0, OOD_ACC is empty and
        # its mean is NaN.
        ID_ACC = ID_ACC.mean().item() * 100
        All_ID_ACC.append(ID_ACC)
        OOD_ACC = OOD_ACC.mean().item() * 100
        All_OOD_ACC.append(OOD_ACC)
    
    All_ID_ACC = torch.tensor(All_ID_ACC)
    All_OOD_ACC = torch.tensor(All_OOD_ACC)
    print('ID  Max ACC {:.2f}  Mean ACC {:.2f}'.format(max(All_ID_ACC), All_ID_ACC.mean()))
    print('OOD Max ACC {:.2f}  Mean ACC {:.2f}'.format(max(All_OOD_ACC), All_OOD_ACC.mean()))
    # Gap between in-filled and un-filled accuracy.
    print('VAL value: {:.3f}'.format(All_ID_ACC.mean() - All_OOD_ACC.mean()))

### Random generator

In [None]:
from module.random_gen import RandomGenerator
# Baseline: fill the ground-truth fragment with random edges instead of the CG generator.
# NOTE: the name `random` shadows the stdlib module; it is kept because the fidelity cell
# further down calls `random.fill`.
random = RandomGenerator()
# One accuracy (percent) per repetition. Initialised once, outside the loop, so Max/Mean
# below summarise all n_loop runs (it used to be reset every iteration, keeping only the last).
All_Random_ACC = []
with torch.no_grad():
    for i in range(n_loop):
        Random_ACC = torch.tensor([]).to(device)
        for g in loader:
            g.to(device)
            broken_edge_index, broken_edge_attr, out_edge_ratio = get_single_ground_truth_graph(g)
            tmp_g = g.clone()
            tmp_g.edge_index = broken_edge_index
            tmp_g.edge_attr = broken_edge_attr
            filled_g = random.fill(tmp_g, out_edge_ratio)
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(filled_g.x, filled_g.edge_index, filled_g.batch, None)
            readout = gnn(relabel_x,
                          relabel_edge_index,
                          filled_g.edge_attr,
                          relabel_batch
                          )
            random_acc = (g.y == readout.argmax(dim=1)).view(-1).float()
            Random_ACC = torch.cat([Random_ACC, random_acc])
        Random_ACC = Random_ACC.mean().item() * 100
        All_Random_ACC.append(Random_ACC)
All_Random_ACC = torch.tensor(All_Random_ACC)
print('ID  Max ACC {:.2f}  Mean ACC {:.2f}'.format(All_Random_ACC.max(), All_Random_ACC.mean()))
# Compared against the un-filled (OOD) accuracy from the CG cell.
print('VAL value: {:.3f}'.format(All_Random_ACC.mean() - All_OOD_ACC.mean()))

### Regular VGAE

In [None]:
# Baseline: complete the ground-truth fragment with edges scored by a plain VGAE.
vgae = torch.load("../param/vgae/%s.pt" % dataset_name).to(device)
# NOTE: this rebinds the global `loader` to batch_size=1. The fidelity cells below rely on
# that (they index `[0]` into the readout), so this cell must run before them.
loader = DataLoader(dataset[graph_mask], batch_size=1, shuffle=False, drop_last=False)
ID_ACC = torch.tensor([]).to(device)
with torch.no_grad():
    for _ in range(n_loop):
        for g in loader:
            g.to(device)
            broken_edge_index, broken_edge_attr, out_edge_ratio = get_single_ground_truth_graph(g)
            # Encode with edges expressed in g's own node ids, matching g.x. This is also the id
            # space of `neg_candidates` decoded below. Previously the relabelled edge index
            # (ids into the relabelled node set) was paired with the un-relabelled g.x.
            z = vgae.encode(g.x, broken_edge_index, broken_edge_attr)
            # Number of edges to add so the graph returns to its original edge count.
            num_neg_samples = int((1 - out_edge_ratio) * g.num_edges)
            neg_candidates = negative_sampling(broken_edge_index, g.num_nodes, 2 * num_neg_samples)
            prob = vgae.decode(z, neg_candidates)
            # Keep the num_neg_samples highest-scoring candidate edges.
            neg_idx = torch.argsort(-prob)[:num_neg_samples]
            fake_edge_index = torch.cat([broken_edge_index, neg_candidates[:, neg_idx]], dim=1)
            fake_edge_attr = torch.ones((fake_edge_index.size(1), 1)).to(device)
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)
            readout = gnn(
                relabel_x,
                relabel_edge_index,
                fake_edge_attr,
                relabel_batch
                )
            # Already in percent, per graph.
            id_acc = (g.y == readout.argmax(dim=1)).view(-1).float() * 100
            ID_ACC = torch.cat([ID_ACC, id_acc])

print('ID   Mean ACC {:.2f}'.format(ID_ACC.mean()))
print('VAL value: {:.3f}'.format(ID_ACC.mean() - All_OOD_ACC.mean()))

### Experiment 2: Fidelity

### CG

In [None]:
# Fidelity of the CG in-filler, per graph:
#   fid_id  = || mean_{n_loop} f(G-filled broken graph) - f(original) ||^2
#   fid_ood = || f(broken graph, not filled)           - f(original) ||^2
# where f is gnn.readout. The broken graph is drawn once per graph and then filled n_loop times.
# NOTE: the `[0]` indexing relies on `loader` having batch_size=1 (rebound in the VGAE cell).
fid_id = []
fid_ood = []
for g in loader:
    g.to(device)
    gnn(g.x,
        g.edge_index,
        g.edge_attr,
        g.batch
        )
    Y_ori = gnn.readout.detach().cpu().numpy()[0]
    # get_broken_graph is defined elsewhere; 0.5 presumably is the fraction of edges kept — confirm.
    broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(g, 0.5, connectivity=False)
    fake_edge_index = broken_edge_index
    fake_edge_attr = broken_edge_attr
    relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)

    gnn(relabel_x,
        relabel_edge_index,
        fake_edge_attr,
        relabel_batch
        )
    Y_gs = gnn.readout.detach().cpu().numpy()[0]
    Y_g_hat_mt = []
    for _ in range(n_loop):
        # reparameterize=True here (unlike Experiment 1), so each repetition presumably samples
        # a different latent and hence a different in-filling.
        mu, log_var, z = G.encode(
                x=g.x, in_edge_index=broken_edge_index, 
                in_edge_attr=broken_edge_attr, reparameterize=True
        )
        _, _, cond_z = G.encode(
                x=g.x, in_edge_index=g.edge_index, 
                in_edge_attr=g.edge_attr, reparameterize=True
                )
        z = torch.cat([z, cond_z], dim=1)
        fake_edge_index, fake_edge_prob, fake_edge_attr, _ = \
            G.fill(z=z, preserved_edge_index=broken_edge_index, preserved_edge_ratio=out_edge_ratio,
                        batch=g.batch, neg_edge_index=None, threshold=False)
        relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)

        gnn(relabel_x,
            relabel_edge_index,
            fake_edge_attr,
            relabel_batch
            )
        Y_g_hat = gnn.readout.detach().cpu().numpy()[0]
        Y_g_hat_mt.append(Y_g_hat)
        
    # Average the readout over the n_loop in-fillings, then take squared L2 distances.
    Y_g_hat_mt = np.array(Y_g_hat_mt).mean(axis=0)
    fid_id.append(pow(Y_g_hat_mt-Y_ori, 2).sum())
    fid_ood.append(pow(Y_gs-Y_ori, 2).sum())
    
fid_id = torch.tensor(fid_id)
fid_ood = torch.tensor(fid_ood)
print('FID ID value: {:.3f}'.format(fid_id.mean()))
print('FID OOD value: {:.3f}'.format(fid_ood.mean()))

### Random generator

In [None]:
# Fidelity of the random in-filler: per graph, the squared L2 distance between the GNN
# readout on the original graph and the readout averaged over n_loop random completions
# of freshly broken copies. Relies on the batch_size=1 `loader` (see `[0]` indexing).
fid_id = []
for graph in loader:
    graph.to(device)
    gnn(graph.x, graph.edge_index, graph.edge_attr, graph.batch)
    y_full = gnn.readout.detach().cpu().numpy()[0]

    filled_outputs = []
    for _ in range(n_loop):
        kept_index, kept_attr, kept_ratio = get_broken_graph(graph, 0.5, connectivity=False)
        partial = graph.clone()
        partial.edge_index = kept_index
        partial.edge_attr = kept_attr
        completed = random.fill(partial, kept_ratio)
        rx, r_edge_index, r_batch, _ = relabel(completed.x, completed.edge_index, completed.batch, None)
        gnn(rx, r_edge_index, completed.edge_attr, r_batch)
        filled_outputs.append(gnn.readout.detach().cpu().numpy()[0])

    y_filled = np.array(filled_outputs).mean(axis=0)
    fid_id.append(pow(y_filled - y_full, 2).sum())

fid_id = torch.tensor(fid_id)
print('FID ID value: {:.3f}'.format(fid_id.mean()))

### VGAE

In [None]:
# Fidelity of the plain VGAE in-filler: per graph, the squared L2 distance between the GNN
# readout on the original graph and the readout averaged over n_loop VGAE completions of
# freshly broken copies. Relies on the batch_size=1 `loader` (see `[0]` indexing).
fid_id = []
with torch.no_grad():
    for g in loader:
        g.to(device)
        gnn(
            g.x,
            g.edge_index,
            g.edge_attr,
            g.batch
            )
        Y_ori = gnn.readout.detach().cpu().numpy()[0]
        Y_g_hat_mt = []
        for _ in range(n_loop):
            broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(g, 0.5, connectivity=False)
            # Encode with edges in g's own node ids so they index g.x, and match the ids of
            # `neg_candidates` below. Previously the relabelled edge index (ids into the
            # relabelled node set) was paired with the un-relabelled g.x.
            z = vgae.encode(g.x, broken_edge_index, broken_edge_attr)
            num_neg_samples = int((1 - out_edge_ratio) * g.num_edges)
            neg_candidates = negative_sampling(broken_edge_index, g.num_nodes, 2 * num_neg_samples)
            prob = vgae.decode(z, neg_candidates)
            # Keep the num_neg_samples highest-scoring candidate edges.
            neg_idx = torch.argsort(-prob)[:num_neg_samples]
            fake_edge_index = torch.cat([broken_edge_index, neg_candidates[:, neg_idx]], dim=1)
            fake_edge_attr = torch.ones((fake_edge_index.size(1), 1)).to(device)
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)
            gnn(
                relabel_x,
                relabel_edge_index,
                fake_edge_attr,
                relabel_batch
                )
            Y_g_hat = gnn.readout.detach().cpu().numpy()[0]
            Y_g_hat_mt.append(Y_g_hat)
        Y_g_hat_mt = np.array(Y_g_hat_mt).mean(axis=0)
        fid_id.append(pow(Y_g_hat_mt - Y_ori, 2).sum())

fid_id = torch.tensor(fid_id)
print('FID ID value: {:.3f}'.format(fid_id.mean()))

In [None]:
# Evaluate a set of explainers on the filtered test graphs. For each explainer, log the
# precision@3 of its explanation, the accuracy on the top-ratio subgraph without in-filling
# (evaluate_acc) and with counterfactual in-filling (evaluate_CounterSup_acc). Then report
# their means and the correlation between precision and each accuracy.
# NOTE: overrides the global n_loop (5 in the CG cell); re-running earlier cells afterwards uses 1.
n_loop = 1
top_ratio_list = [0.15]
# Ids into Explainers: SAExplainer, GradCam, GNNExplainer, CXplainer, PGMExplainer, Screener.
explainers_id = [1, 4, 5, 6, 8, 9]
# Single-iteration loop; it only sets generative_model_path.
for f in ['TR3_FINAL_param4']:
    generative_model_path = '../param/cg/%s/generator/best.pkl' % f
test_loader = DataLoader(dataset[graph_mask], batch_size=1, shuffle=False, drop_last=False)
explainers = [Explainers[i](gnn_model_path=path, gen_model_path=generative_model_path) for i in explainers_id]

seq_precision = []
seq_ood_acc = []
seq_id_acc = []
seq_ood_sp = []
seq_id_sp = []
for e in explainers:
    print(e.name)
    id_acc_logger = []
    ood_acc_logger = []
    precision3_logger = []
    cnt = 0  # incremented below but never read
    for _ in range(n_loop):
        for g in test_loader:
            cnt += 1
            g.to(device)
            e.explain_graph(g)
            # Both return nested results indexed [ratio][graph]; with one ratio and
            # batch_size=1 only [0][0] is used — presumably; confirm against the explainer API.
            ood_acc, ood_prob = e.evaluate_acc(top_ratio_list)
            id_acc, id_prob, _ = e.evaluate_CounterSup_acc(top_ratio_list)
            precision = e.evaluate_precision(topk=3)
            precision3_logger.append(precision)

            ood_acc_logger.append(float(ood_acc[0][0]))
            id_acc_logger.append(float(id_acc[0][0]))
    # Despite the "spearson" names these are Pearson correlations (np.corrcoef); the Spearman
    # alternative is left commented out. np.corrcoef yields NaN if either series is constant.
    ood_acc_spearson = np.corrcoef(precision3_logger, ood_acc_logger)[0, -1] # scipy.stats.spearmanr(precision3_logger, ood_acc_logger)[0]
    id_acc_spearson = np.corrcoef(precision3_logger, id_acc_logger)[0, -1] # scipy.stats.spearmanr(precision3_logger, id_acc_logger)[0]
    
    seq_precision.append(np.array(precision3_logger).mean())
    seq_ood_acc.append(np.array(ood_acc_logger).mean())
    seq_id_acc.append(np.array(id_acc_logger).mean())
    seq_ood_sp.append(ood_acc_spearson)
    seq_id_sp.append(id_acc_spearson)
    
    print("w.r.t Precision")
    print("%.4f" % seq_precision[-1])
    
    print("w.r.t OOD ACC")
    print("%.4f" % seq_ood_acc[-1])
    
    print("w.r.t ID ACC")
    print("%.4f" % seq_id_acc[-1])
    
    print("w.r.t PearsonC")
    print('Before Counterfactual Infilling, PearsonC: %.4f' % seq_ood_sp[-1])
    print('After Counterfactual Infilling, PearsonC: %.4f' % seq_id_sp[-1])
    