In [11]:
import sys
import os

# Make the project root (main/) importable so the `explainers` package can be found.
# Assumes the kernel's working directory is main/notebooks/.
current_dir = os.getcwd()  # expected: main/notebooks/
parent_dir = os.path.dirname(current_dir)  # expected: main/
# Only register the directory once, so re-running this cell in the same kernel
# does not keep appending duplicate entries to sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [13]:
# Select the Conda environment "Ultra" as the notebook kernel before running this notebook.

import torch
import pandas as pd
from explainers.data_util import batched_k_hop_subgraph
import networkx as nx
from pyvis.network import Network

In [15]:
# Directory holding the saved WN18RR dataset and the id -> name lookup tables.
misc_dir = '/storage/ryoji/Graph-Transformer/NBFNet-PyG/misc'

# NOTE(review): torch.load unpickles arbitrary Python objects; only load files from a trusted source.
dataset = torch.load(os.path.join(misc_dir, 'wn18rr_dataset.pt'))
# The second file stores a pair: (entity id -> name, relation id -> name).
id2entity, id2relation = torch.load(os.path.join(misc_dir, 'wn18rr_id2name.pt'))

In [16]:
split = 'valid'  # split whose graph is taken from `dataset`: 'valid' or 'test'
ratio = 0.001  # top-r ratio of edges; only referenced by the ratio-based file names commented out below

# Position of each supported split inside `dataset`.
split_to_data_index = {'valid': 1, 'test': 2}
if split not in split_to_data_index:
    # Fail fast: without this check an unsupported split would leave `data_index`
    # undefined and only surface later as a NameError.
    raise ValueError(f"Unsupported split {split!r}; expected one of {sorted(split_to_data_index)}")
data_index = split_to_data_index[split]

# Result files of the explainer run inspected in this notebook.
# NOTE(review): both file names start with "test_" while `split` above is 'valid' —
# confirm that the graph selected through `data_index` matches these results.
# Alternative (ratio-based) file names:
# output_file = f'/storage/ryoji/Graph-Transformer/NBFNet-PyG/explanation/NBFNet/WN18RR/explanation_output/{split}_output_hard_edge_mask_top_ratio_{ratio}.pt'
# explanation_file = f'/storage/ryoji/Graph-Transformer/NBFNet-PyG/explanation/NBFNet/WN18RR/explanation_output/{split}_explanations_hard_edge_mask_top_ratio_{ratio}.pt'
output_file = '/storage/ryoji/Graph-Transformer/NBFNet-PyG/explanation/RAWExplainer/NBFNet/WN18RR/2025-02-07-11-40-10/test_output_factual_eval_hard_edge_mask_top_k_100.pt'
explanation_file = '/storage/ryoji/Graph-Transformer/NBFNet-PyG/explanation/RAWExplainer/NBFNet/WN18RR/2025-02-07-11-40-10/test_explanations_factual_eval_hard_edge_mask_top_k_100.pt'

# NOTE(review): torch.load unpickles arbitrary Python objects; only load trusted files.
explanations = torch.load(explanation_file)  # indexed per instance; each entry is used as an edge mask
outputs = torch.load(output_file)
output_df = pd.DataFrame(outputs)  # one record per evaluated instance

# Graph whose edge_index the explanation masks are applied to.
data = dataset[data_index]

In [21]:
# Inspect a single evaluation record (positional row 1) of output_df.
output_df.iloc[1]

Ranking      19838
Heads           97
Tails        20114
Rel              2
Mode             1
Num_Edges       96
Num_Nodes       28
Inclusion        0
Name: 1, dtype: int64

In [22]:
# Number of edges switched on in the explanation mask of instance 1.
torch.sum(explanations[1])

tensor(96)

In [23]:
def vizualize_explanation(index, num_hops=6):
    """Build a directed graph of the explanation edges for one evaluated instance.

    Reads row ``index`` of the notebook-level ``output_df``, prints the query
    together with a few sanity checks, and returns a graph made of the edges of
    ``data.edge_index`` that are selected by ``explanations[index]``.

    Relies on the notebook globals ``output_df``, ``explanations``, ``data``,
    ``id2entity`` and ``id2relation``.

    Args:
        index: Positional row of ``output_df`` (and position in
            ``explanations``) of the instance to inspect.
        num_hops: Number of hops passed to ``batched_k_hop_subgraph`` for the
            neighbourhood check, also used as the maximum length of the
            head -> tail paths that get highlighted. Defaults to 6.

    Returns:
        networkx.DiGraph: The explanation graph. The head node (if present) is
        coloured red, the tail node (if present) green, and every edge lying on
        a simple head -> tail path of at most ``num_hops`` edges is coloured
        purple.
    """
    row = output_df.iloc[index]

    # Rows whose 'Mode' is not 1 are displayed with an '_inv' suffix on the relation name.
    rel = id2relation[row['Rel']]
    if row['Mode'] != 1:
        rel = rel + '_inv'

    # Convert the numpy integers coming out of the DataFrame to plain ints so
    # they match the node ids stored in the graph below.
    head_id = int(row['Heads'])
    tail_id = int(row['Tails'])
    head = id2entity[head_id]
    tail = id2entity[tail_id]

    print(f"*** Query: {head, rel}. Answer: {tail}, Rank given the explanation: {row['Ranking']} ***")

    # Only the returned node membership is needed here; the edge output is unused.
    nodes, _ = batched_k_hop_subgraph(
        torch.tensor([head_id]).unsqueeze(0), num_hops, data.edge_index, data.num_nodes
    )
    print(f'Is the tail included in the {num_hops}-hop neighbor of head in the original graph? {nodes[0, tail_id].item()}')

    # Keep only the columns of edge_index that the explanation mask selects.
    expl = explanations[index]
    expl_edge_index = data.edge_index[:, expl.to(torch.bool)]

    print(f'Is the head included in the Explanation? {torch.any((expl_edge_index == head_id)).item()}')
    print(f'Is the tail included in the Explanation? {torch.any((expl_edge_index == tail_id)).item()}')

    # Directed graph over the explanation edges (node ids are plain Python ints).
    G = nx.DiGraph()
    G.add_edges_from(zip(expl_edge_index[0].tolist(), expl_edge_index[1].tolist()))

    # Colour the endpoints; ids that are not nodes of G are ignored by networkx.
    nx.set_node_attributes(G, {head_id: "#FF0000", tail_id: "#00FF00"}, name='color')

    # all_simple_paths raises NodeNotFound when an endpoint is missing from the
    # graph, so check membership explicitly instead of hiding every possible
    # error behind a bare `except`.
    if head_id in G and tail_id in G:
        paths = list(nx.all_simple_paths(G, source=head_id, target=tail_id, cutoff=num_hops))
    else:
        paths = []

    connected = len(paths) > 0
    print(f'Is the tail connected to the head? {connected}')

    # Highlight every edge that lies on at least one head -> tail path.
    path_edge_attrs = {}
    for path in paths:
        for src, dst in zip(path, path[1:]):
            path_edge_attrs[(src, dst)] = {'color': "#b27ebd"}
    nx.set_edge_attributes(G, path_edge_attrs)

    return G
    

In [24]:
# Show the evaluation records loaded from output_file, one row per instance.
display(output_df)

Unnamed: 0,Ranking,Heads,Tails,Rel,Mode,Num_Edges,Num_Nodes,Inclusion
0,1,108,8037,1,1,92,26,1
1,19838,97,20114,2,1,96,28,0
2,4,261,28044,3,1,100,34,1
3,31259,297,13112,0,1,100,24,0
4,15603,314,16265,2,1,94,29,0
5,1,82,5590,1,1,100,31,1
6,11,296,22129,1,1,98,24,1
7,4,30,24549,4,1,100,39,1
8,1,8037,108,1,0,90,27,1
9,1,20114,97,2,0,98,24,1


In [25]:
# Positional row of output_df (and position in explanations) of the instance to inspect.
index = 3

In [26]:
# Print the sanity checks for the selected instance and build its explanation graph.
G = vizualize_explanation(index)

*** Query: ('travel.v.01', '_also_see'). Answer: advance.v.01, Rank given the explanation: 31259 ***
Is the tail included in the 6-hop neighbor of head in the original graph? True
Is the head included in the Explanation? True
Is the tail included in the Explanation? False
Is the tail connected to the head? False


In [27]:
# *** Explanation graph visualization ***
# Red node:     the query head (when it appears in the explanation)
# Green node:   the answer tail (when it appears in the explanation)
# Purple edges: edges lying on any path that connects the red node to the green node
network_options = dict(
    notebook=True,
    cdn_resources="remote",
    bgcolor="#222222",
    font_color="white",
    height="750px",
    width="100%",
    select_menu=True,
    filter_menu=True,
    directed=True,
)
net = Network(**network_options)

# Import the explanation graph, keeping the colours assigned to its nodes and edges.
net.from_nx(G)
net.inherit_edge_colors(False)  # edges keep their own colour instead of the node's
net.set_edge_smooth('dynamic')

# Write the interactive view to graph.html and display it in the notebook.
net.show('graph.html')

graph.html
