In [None]:
import sys

# Put the parent directory on the import path so the project packages imported
# in the next cell (utils, gnn, datasets, explainers, module) resolve.
# Assumes the notebook is launched from a subdirectory of the repo root.
sys.path.extend(['..'])

In [None]:
import copy
import warnings
import scipy.stats
from tqdm import tqdm
import os.path as osp
from torch_geometric.data import DataLoader, Dataset
from datasets.graphss2_dataset import get_dataset, get_dataloader  

from utils import *
from utils.parser import *
from utils.get_subgraph import *

from gnn import *
from datasets import *
from explainers import *

# Silence library warnings and make numpy print compact 3-decimal floats.
warnings.filterwarnings("ignore")
np.set_printoptions(suppress=True, precision=3)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Explainer classes addressable by integer id (selected via `explainers_id` later).
# Id 7 (PGExplainer) is disabled.
Explainers = dict([
    (0, RandomCaster),
    (1, SAExplainer),
    (2, IGExplainer),
    (3, DeepLIFTExplainer),
    (4, GradCam),
    (5, GNNExplainer),
    (6, CXplainer),
    (8, PGMExplainer),
    (9, Screener),
])

### Fetch Testing Dataset

In [None]:
# Experiment configuration shared by every cell below.
n_loop = 3        # repetitions / generator samples drawn per graph
edge_ratio = 0.2  # edge ratio handed to get_broken_graph and used as the explainer top ratio
MAX_DIAM = 100    # NOTE(review): not referenced anywhere in this notebook, so confirm it is still needed
set_seed(2021)

dataset_name = 'graphsst2'
path = '../param/gnns/%s_net.pt' % dataset_name

# Build the Graph-SST2 splits (80/10/10, split seed 2). The test split is what
# the rest of the notebook evaluates.
train_dataset = get_dataset(dataset_dir='../data/', dataset_name='Graph_SST2', task=None)
val_dataset = None
test_dataset = None
dataloader = get_dataloader(train_dataset,
                            batch_size=1,
                            random_split_flag=True,
                            data_split_ratio=[0.8, 0.1, 0.1],
                            seed=2)
dataset = dataloader['test'].dataset
traindataset = dataloader['train'].dataset
trainloader = DataLoader(traindataset, batch_size=1, shuffle=False)

# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
# Without it, torch.load tries to restore CUDA tensors and fails there.
# NOTE(review): these files are whole pickled modules. torch>=2.6 defaults to
# torch.load(weights_only=True) and will refuse them unless weights_only=False
# is passed. Only load trusted checkpoints, since this is pickle.
gnn = torch.load(path, map_location=device).to(device)
gmm_path = '../param/gmm/%s.pt' % dataset_name
gmm = torch.load(gmm_path, map_location=device).to(device)

# Generator checkpoint run name (an alternative run is 'SST2_FINAL_param17').
generator_run = 'GraphSST2_ablation_pen'
generative_model_path = '../param/cg/%s/generator/best.pkl' % generator_run
G = torch.load(generative_model_path, map_location=device).to(device)

# Keep only the test graphs the GNN classifies correctly. The boolean mask is
# cached on disk so later runs skip the forward passes.
# NOTE(review): the cache file is not keyed on the split seed or the GNN
# checkpoint. Delete it if either changes.
dataset_mask = []
folder = '../data/Graph_SST2'
filtered_path = folder + "/filtered_idx_test.pt"
if osp.exists(filtered_path):
    graph_mask = torch.load(filtered_path)
else:
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    graph_mask = torch.zeros(len(loader.dataset), dtype=torch.bool)
    # Inference only: reuse the already-loaded GNN and skip autograd bookkeeping.
    with torch.no_grad():
        for idx, g in enumerate(tqdm(loader, total=len(loader))):
            g.to(device)
            gnn(g.x, g.edge_index, g.edge_attr, g.batch)
            # batch_size=1, so this compares one label with one predicted class.
            if g.y == gnn.readout.argmax(dim=1):
                graph_mask[idx] = True
    torch.save(graph_mask, filtered_path)
# Record the mask whether it came from the cache or was freshly computed.
dataset_mask.append(graph_mask)

new = [dataset[i] for i in range(graph_mask.size(0)) if graph_mask[i]]
loader = DataLoader(new, batch_size=32, shuffle=False, drop_last=False)

print("number of graphs(Testing): %4d" % len(loader.dataset))

In [None]:
# Peek at the first retained test graph.
first_graph = loader.dataset[0]
print(first_graph)

### Dataset statistics

In [None]:
# Average graph size (nodes / edges per graph) over the retained test graphs.
num_nodes, num_edges = 0., 0.
for g in tqdm(loader, total=len(loader)):
    num_nodes += g.num_nodes
    num_edges += g.num_edges
n_graphs = len(loader.dataset)
print(f"Average #Nodes {num_nodes / n_graphs:.4f}")
print(f"Average #Edges {num_edges / n_graphs:.4f}")

### Validity

### CG

In [None]:
# Validity of the generator G. Score, with the GMM, the GNN embeddings of:
#   OOD: the broken graphs (edges removed by get_broken_graph),
#   ID:  the same graphs re-filled by G (n_loop samples each).
# The "VAL value" printed below is the difference of the two mean log-likelihoods.
ID_perf = torch.tensor([]).to(device)
OOD_perf = torch.tensor([]).to(device)
# Inference only. Without no_grad, every score tensor concatenated below would
# keep its whole autograd graph alive, so memory would grow with the dataset.
with torch.no_grad():
    for g in tqdm(loader, total=len(loader)):
        g.to(device)
        pos = g.pos
        broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(g, edge_ratio, connectivity=False)
        relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, broken_edge_index, g.batch, pos)
        broken_emb = gnn.get_graph_rep(relabel_x, relabel_edge_index, broken_edge_attr, relabel_batch)
        OOD_perf = torch.cat([OOD_perf, gmm.score_samples(broken_emb)])

        # z from the broken graph, concatenated with cond_z from the full graph.
        _, _, z = G.encode(
            x=g.x, in_edge_index=broken_edge_index,
            in_edge_attr=broken_edge_attr
        )
        _, _, cond_z = G.encode(
            x=g.x, in_edge_index=g.edge_index,
            in_edge_attr=g.edge_attr
        )
        z = torch.cat([z, cond_z], dim=1)
        for _ in range(n_loop):
            fake_edge_index, _, fake_edge_attr, _ = G.fill(
                z=z, preserved_edge_index=broken_edge_index,
                preserved_edge_ratio=out_edge_ratio,
                batch=g.batch, pos=pos, neg_edge_index=None, threshold=False
            )
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, pos)
            full_emb = gnn.get_graph_rep(relabel_x, relabel_edge_index, fake_edge_attr, relabel_batch)
            ID_perf = torch.cat([ID_perf, gmm.score_samples(full_emb)])

# The accumulators are already tensors. Wrapping them in torch.tensor(...)
# only made a copy and emitted a UserWarning.
ID_perf = ID_perf.mean().item()
OOD_perf = OOD_perf.mean().item()
print('OOD Mean LL {:.4f}'.format(OOD_perf))
print('ID  Mean LL {:.4f}'.format(ID_perf))
print('VAL value: {:.4f}'.format(ID_perf - OOD_perf))

### Random generator

In [None]:
from module.random_gen import RandomGenerator

# Validity baseline for a random generator.
# NOTE(review): with USE_RANDOM_FILL = False (the behaviour this cell had), the
# *unmodified* graphs are scored and `random` is never used. The printed number
# is therefore the likelihood of the original graphs, not of randomly re-filled
# ones. Set USE_RANDOM_FILL = True to score graphs whose removed edges are
# re-filled by RandomGenerator (this path was previously commented out).
USE_RANDOM_FILL = False
random = RandomGenerator()
ID_perf = torch.tensor([]).to(device)
# Inference only: avoid keeping autograd graphs alive inside ID_perf.
with torch.no_grad():
    for _ in range(n_loop):
        for g in loader:
            g.to(device)
            if USE_RANDOM_FILL:
                broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(g, edge_ratio, connectivity=False)
                tmp_g = g.clone()
                tmp_g.edge_index = broken_edge_index
                tmp_g.edge_attr = broken_edge_attr
                filled_g = random.fill(tmp_g, out_edge_ratio)
                relabel_x, relabel_edge_index, relabel_batch, _ = relabel(filled_g.x, filled_g.edge_index, filled_g.batch, None)
                full_emb = gnn.get_graph_rep(relabel_x, relabel_edge_index, filled_g.edge_attr, relabel_batch)
            else:
                full_emb = gnn.get_graph_rep(g.x, g.edge_index, g.edge_attr, g.batch)
            ID_perf = torch.cat([ID_perf, gmm.score_samples(full_emb)])

ID_perf = ID_perf.mean().item()
print('ID  Mean LL {:.4f}'.format(ID_perf))
print('VAL value: {:.4f}'.format(ID_perf - OOD_perf))

### Regular VGAE

In [None]:
# Validity of a plain VGAE baseline:
#   1. encode the broken graph;
#   2. add back the most probable candidate non-edges;
#   3. score the GNN embedding of the result with the GMM.
# NOTE(review): this loads '../param/vgae/%s.pt', but the VGAE fidelity cell
# loads '../param/vgae/_%s.pt'. Confirm which checkpoint is intended.
vgae = torch.load("../param/vgae/%s.pt" % dataset_name, map_location=device).to(device)
loader = DataLoader(new, batch_size=1, shuffle=False, drop_last=False)
ID_perf = torch.tensor([]).to(device)
# Inference only: avoid keeping autograd graphs alive inside ID_perf.
with torch.no_grad():
    for _ in range(n_loop):
        for g in tqdm(loader, total=len(loader)):
            g.to(device)
            if g.num_edges == 0:
                continue
            broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(g, edge_ratio, connectivity=False)
            # Encode on the original node numbering, so that rows of `z` line up
            # with the ids in `neg_candidates` (sampled from broken_edge_index
            # and g.num_nodes). relabel() returns its own x together with its
            # edge ids. Pairing its edge ids with the un-relabelled g.x (as
            # before) mismatches nodes whenever the two numberings differ.
            z = vgae.encode(g.x, broken_edge_index, broken_edge_attr)
            num_neg_samples = int((1 - out_edge_ratio) * g.num_edges)
            # Over-sample candidate non-edges 3x, then keep the most probable ones.
            neg_candidates = negative_sampling(broken_edge_index, g.num_nodes, 3 * num_neg_samples)
            prob = vgae.decode(z, neg_candidates)
            neg_idx = torch.argsort(-prob)[:num_neg_samples]
            fake_edge_index = torch.cat([broken_edge_index, neg_candidates[:, neg_idx]], dim=1)
            fake_edge_attr = torch.ones((fake_edge_index.size(1), 1)).to(device)
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)

            full_emb = gnn.get_graph_rep(relabel_x, relabel_edge_index, fake_edge_attr, relabel_batch)
            ID_perf = torch.cat([ID_perf, gmm.score_samples(full_emb).unsqueeze(0)])

# ID_perf is already a tensor; torch.tensor(...) on it only copied it and warned.
ID_perf = ID_perf.mean().item()
print('ID  Mean LL {:.4f}'.format(ID_perf))
print('VAL value: {:.4f}'.format(ID_perf - OOD_perf))

In [None]:
# VAL value against a fixed OOD reference log-likelihood instead of the
# OOD_perf computed in this session.
# NOTE(review): 24.0371 is hard-coded, presumably an OOD mean LL from an
# earlier run. Confirm its provenance before reporting this number.
REFERENCE_OOD_LL = 24.0371
print('VAL value: {:.4f}'.format(ID_perf - REFERENCE_OOD_LL))

### Fidelity

### CG

In [None]:
# Fidelity of G: squared L2 distance between the GNN readout on the original
# graph and on
#   OOD: the broken graph,
#   ID:  the G-refilled graph, averaged over n_loop samples.
# The `[0]` indexing below assumes one graph per batch. Build that loader here
# instead of relying on whichever loader an earlier cell left behind.
loader = DataLoader(new, batch_size=1, shuffle=False, drop_last=False)
fid_id = []
fid_ood = []
# Inference only: readouts are detached anyway, so skip autograd entirely.
with torch.no_grad():
    for g in loader:
        g.to(device)
        gnn(g.x, g.edge_index, g.edge_attr, g.batch)
        Y_ori = gnn.readout.detach().cpu().numpy()[0]

        broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(g, edge_ratio, connectivity=False)
        relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, broken_edge_index, g.batch, None)
        gnn(relabel_x, relabel_edge_index, broken_edge_attr, relabel_batch)
        Y_gs = gnn.readout.detach().cpu().numpy()[0]

        Y_g_hat_mt = []
        for _ in range(n_loop):
            _, _, z = G.encode(
                x=g.x, in_edge_index=broken_edge_index,
                in_edge_attr=broken_edge_attr, reparameterize=True
            )
            _, _, cond_z = G.encode(
                x=g.x, in_edge_index=g.edge_index,
                in_edge_attr=g.edge_attr, reparameterize=True
            )
            z = torch.cat([z, cond_z], dim=1)
            fake_edge_index, _, fake_edge_attr, _ = G.fill(
                z=z, preserved_edge_index=broken_edge_index, preserved_edge_ratio=out_edge_ratio,
                batch=g.batch, neg_edge_index=None, threshold=False
            )
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)
            gnn(relabel_x, relabel_edge_index, fake_edge_attr, relabel_batch)
            Y_g_hat_mt.append(gnn.readout.detach().cpu().numpy()[0])

        Y_g_hat_mt = np.array(Y_g_hat_mt).mean(axis=0)
        fid_id.append(pow(Y_g_hat_mt - Y_ori, 2).sum())
        fid_ood.append(pow(Y_gs - Y_ori, 2).sum())

fid_id = torch.tensor(fid_id)
fid_ood = torch.tensor(fid_ood)
print('FID ID value: {:.3f}'.format(fid_id.mean()))
print('FID OOD value: {:.3f}'.format(fid_ood.mean()))

### Random generator

In [None]:
from module.random_gen import RandomGenerator

# Fidelity of the random-generator baseline: squared L2 distance between the
# GNN readout on the original graph and the mean readout over n_loop randomly
# broken-and-refilled versions of it.
# The `[0]` indexing below assumes one graph per batch, so build that loader here.
loader = DataLoader(new, batch_size=1, shuffle=False, drop_last=False)
random = RandomGenerator()
fid_id = []
# Inference only: readouts are detached anyway, so skip autograd entirely.
with torch.no_grad():
    for g in loader:
        g.to(device)
        gnn(g.x, g.edge_index, g.edge_attr, g.batch)
        Y_ori = gnn.readout.detach().cpu().numpy()[0]
        Y_g_hat_mt = []
        for _ in range(n_loop):
            broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(g, edge_ratio, connectivity=False)
            # Work on a copy so `g` itself keeps its original edges.
            tmp_g = g.clone()
            tmp_g.edge_index = broken_edge_index
            tmp_g.edge_attr = broken_edge_attr
            filled_g = random.fill(tmp_g, out_edge_ratio)
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(filled_g.x, filled_g.edge_index, filled_g.batch, None)
            gnn(relabel_x, relabel_edge_index, filled_g.edge_attr, relabel_batch)
            Y_g_hat_mt.append(gnn.readout.detach().cpu().numpy()[0])
        Y_g_hat_mt = np.array(Y_g_hat_mt).mean(axis=0)
        fid_id.append(pow(Y_g_hat_mt - Y_ori, 2).sum())

fid_id = torch.tensor(fid_id)
print('FID ID value: {:.3f}'.format(fid_id.mean()))

### VGAE

In [None]:
# Fidelity of the VGAE baseline: squared L2 distance between the GNN readout on
# the original graph and the mean readout over n_loop VGAE-refilled versions of
# its broken graph. This matches the n_loop averaging of the other fidelity cells.
# NOTE(review): this loads '../param/vgae/_%s.pt' (leading underscore), but the
# VGAE validity cell loads '../param/vgae/%s.pt'. Confirm which is intended.
fid_id = []
vgae = torch.load("../param/vgae/_%s.pt" % dataset_name, map_location=device).to(device)
# One graph per batch: the `[0]` indexing below relies on it.
loader = DataLoader(new, batch_size=1, shuffle=False, drop_last=False)
# Inference only: readouts are detached anyway, so skip autograd entirely.
with torch.no_grad():
    for g in loader:
        g.to(device)
        # Edgeless graphs have nothing to break or refill; skip before any forward pass.
        if g.num_edges == 0:
            continue
        gnn(g.x, g.edge_index, g.edge_attr, g.batch)
        Y_ori = gnn.readout.detach().cpu().numpy()[0]

        broken_edge_index, broken_edge_attr, out_edge_ratio = get_broken_graph(g, edge_ratio, connectivity=False)
        num_neg_samples = int((1 - out_edge_ratio) * g.num_edges)
        Y_g_hat_mt = []
        for _ in range(n_loop):
            # Encode on the original node numbering, so rows of `z` line up with
            # the ids in `neg_candidates`. relabel() returns its own x together
            # with its edge ids, so its edge ids must not be paired with g.x.
            z = vgae.encode(g.x, broken_edge_index, broken_edge_attr)
            # Over-sample candidate non-edges 2x, then keep the most probable ones.
            neg_candidates = negative_sampling(broken_edge_index, g.num_nodes, 2 * num_neg_samples)
            prob = vgae.decode(z, neg_candidates)
            neg_idx = torch.argsort(-prob)[:num_neg_samples]
            fake_edge_index = torch.cat([broken_edge_index, neg_candidates[:, neg_idx]], dim=1)
            fake_edge_attr = torch.ones((fake_edge_index.size(1), 1)).to(device)
            relabel_x, relabel_edge_index, relabel_batch, _ = relabel(g.x, fake_edge_index, g.batch, None)
            gnn(relabel_x, relabel_edge_index, fake_edge_attr, relabel_batch)
            Y_g_hat_mt.append(gnn.readout.detach().cpu().numpy()[0])
        Y_g_hat_mt = np.array(Y_g_hat_mt).mean(axis=0)
        fid_id.append(pow(Y_g_hat_mt - Y_ori, 2).sum())

fid_id = torch.tensor(fid_id)
print('FID ID value: {:.3f}'.format(fid_id.mean()))

### For explainers

In [None]:
# Run each selected explainer:
#   - log the values reported as "OOD ACC" (from e.evaluate_acc) and
#     "ID ACC" (from e.evaluate_CounterSup_acc);
#   - save two visualizations per graph.
explainers_id = [4]  # other ids tried: 1, 4, 5, 6, 9
top_ratio_list = [edge_ratio]
# Restrict the run to a single sentence; set to None to evaluate every test graph.
TARGET_SENTENCE = "a little weak -- and it is n't that funny ."
loader = DataLoader(new, batch_size=1, shuffle=False, drop_last=False)
explainers = [Explainers[i](gnn_model_path=path, gen_model_path=generative_model_path) for i in explainers_id]

seq_ood_acc = []
seq_id_acc = []
for e in explainers:
    print(e.name)
    id_acc_logger = []
    ood_acc_logger = []
    for g in tqdm(loader, total=len(loader)):
        if g.num_edges == 0:
            continue
        sentence = ' '.join(g.sentence_tokens[0])
        if TARGET_SENTENCE is not None and sentence != TARGET_SENTENCE:
            continue

        g.to(device)
        e.explain_graph(g)
        _, ood_prob = e.evaluate_acc(top_ratio_list)
        _, id_prob, fake_edge_index = e.evaluate_CounterSup_acc(top_ratio_list, return_fake_ratio=0.4)

        # For visualization: edges present in the returned fake graph but not in g.
        fake_edges = {(int(u), int(v)) for u, v in fake_edge_index.t().tolist()}
        real_edges = {(int(u), int(v)) for u, v in g.edge_index.t().tolist()}
        counter_edge_index = torch.tensor(list(fake_edges - real_edges)).T
        # NOTE(review): `name` is presumably used by visualize() to label the
        # saved figure. Confirm against the explainer implementation.
        g.name = "a"
        e.visualize(save=True, vis_ratio=0.2)
        _g, imp = e.last_result
        _g.name = "d"
        e.last_result = (_g, imp)
        e.visualize(save=True, counter_edge_index=counter_edge_index)

        ood_acc_logger.append(float(ood_prob[0][0]))
        id_acc_logger.append(float(id_prob[0][0]))

        print(sentence)
        print(g.y, ood_prob, id_prob)

    if not ood_acc_logger:
        # The means below are nan in this case; say why instead of failing silently.
        print("No graphs were evaluated (check TARGET_SENTENCE).")
    seq_ood_acc.append(np.array(ood_acc_logger).mean())
    seq_id_acc.append(np.array(id_acc_logger).mean())

    print("w.r.t OOD ACC")
    print("%.4f" % seq_ood_acc[-1])

    print("w.r.t ID ACC")
    print("%.4f" % seq_id_acc[-1])
    