# Accuracy experiments with BA-Shapes

## Load and visualize the data

In [53]:
# Import the utility functions
from attention_analysis_utils import (
    get_attention_raw_dict,
    process_attention_dict,
    get_computation_graph,
    get_nodes_per_level_from_comp_graph_full,
    get_attention_raw_dict_multihead,
    reindex_nodes_per_level,
    translate_comp_graph,
    get_att_dict_per_layer,
    return_edges_in_k_hop,
    get_ATTATTTRIBUTE_edge,
    get_AVGATT_edge,
    average_attention_heads,
)

from torch_geometric.utils import get_num_hops
from visualization_utils import (
    visualize_computation_graph,
)
import torch
from typing import Tuple

def get_edge_scores(
    target_edge: Tuple, comp_graph, comp_graph_new, layer_att_dict, att
):
    """Score one edge with the three attention-based attribution methods.

    Args:
        target_edge: Edge to score. Must be exactly a ``tuple`` (checked with
            an assertion before anything else runs).
        comp_graph: Passed through to ``get_ATTATTTRIBUTE_edge``.
        comp_graph_new: Passed through to ``get_ATTATTTRIBUTE_edge``.
        layer_att_dict: Passed through to ``get_ATTATTTRIBUTE_edge``.
        att: Attention weights passed through to ``get_AVGATT_edge``.

    Returns:
        The 3-tuple ``(attattribute, attattribute_sim, avgatt)``: the two
        values returned by ``get_ATTATTTRIBUTE_edge`` followed by the value
        returned by ``get_AVGATT_edge``.

    Raises:
        AssertionError: If ``target_edge`` is not a plain ``tuple``.
    """
    assert type(target_edge) is tuple, "target_edge must be a tuple"

    # ATTATTRIBUTE and its "sim" variant come from one helper call.
    attribution_pair = get_ATTATTTRIBUTE_edge(
        comp_graph=comp_graph,
        comp_graph_new=comp_graph_new,
        layer_att_dict=layer_att_dict,
        target_edge=target_edge,
        verbose=False,
    )
    attribution_score, attribution_sim_score = attribution_pair

    # AVGATT only needs the attention weights and the edge itself.
    mean_attention_score = get_AVGATT_edge(att=att, edge=target_edge)

    return attribution_score, attribution_sim_score, mean_attention_score

def return_is_edge_list_BA_Shapes(edge_list, num_base_nodes=300):
    """Label each edge as ground truth (1) or not (0) by its endpoint indices.

    An edge is labelled ``1`` only when BOTH endpoints have an index of at
    least ``num_base_nodes``; if either endpoint is below that threshold the
    edge is labelled ``0``.

    Args:
        edge_list: Iterable of edges; each edge is indexable and its first two
            entries are the endpoint indices.
        num_base_nodes: Index threshold separating the two node groups.
            Defaults to 300, the value this notebook uses for BA-Shapes.

    Returns:
        A list of ``0``/``1`` labels, one per edge, in input order. An empty
        ``edge_list`` yields an empty list.

    Raises:
        ValueError: If an endpoint cannot be ordered against the threshold
            (for example a NaN), so neither branch applies.
    """
    ground_truth_edge_list = []
    for edge in edge_list:
        source, destination = edge[0], edge[1]
        if source < num_base_nodes or destination < num_base_nodes:
            ground_truth_edge_list.append(0)
        elif source >= num_base_nodes and destination >= num_base_nodes:
            ground_truth_edge_list.append(1)
        else:
            # Only reachable for values that compare False both ways (NaN).
            raise ValueError("Something wrong with the edge list")
    return ground_truth_edge_list


def experiment_on_target_node(
    target_idx: int, data, model, self_loops=True, multiheads=False, hop: int = 2,
):
    """Compute attribution scores and ground-truth labels around one node.

    For every edge returned by ``return_edges_in_k_hop`` for ``target_idx``,
    this collects the ATTATTRIBUTE, ATTATTRIBUTE_sim and AVGATT scores (via
    ``get_edge_scores``) and the 0/1 ground-truth label (via
    ``return_is_edge_list_BA_Shapes``).

    Side effect: the model is run once under ``torch.no_grad()`` with
    ``return_att=True`` and ``model.att`` is overwritten with the output of
    ``average_attention_heads`` (this happens for single-head models too).

    Args:
        target_idx: Index of the node whose neighbourhood is analysed.
        data: Graph object providing ``x`` and ``edge_index``.
        model: Trained model; called as
            ``model(data.x, data.edge_index, return_att=True)`` and expected
            to expose the attention weights as ``model.att`` afterwards.
        self_loops: Forwarded to ``return_edges_in_k_hop``.
        multiheads: If true, raw attention is collected with
            ``get_attention_raw_dict_multihead`` instead of
            ``get_attention_raw_dict``.
        hop: Neighbourhood radius forwarded to ``return_edges_in_k_hop``.
            Defaults to 2 (the value previously hard-coded). Note that the
            computation graph is built with ``k = get_num_hops(model)``,
            which is independent of this value.

    Returns:
        ``(attattribute_list, attattribute_sim_list, avgatt_list,
        ground_truth_edge_list)`` -- four lists aligned with the edges
        returned by ``return_edges_in_k_hop``.
    """
    num_layers = get_num_hops(model)

    edge_lists = return_edges_in_k_hop(
        data=data, target_idx=target_idx, hop=hop, self_loops=self_loops
    )

    # Prepare the ingredients shared by all edges of this target node:
    # raw/processed attention, the computation graph, its re-indexed variant,
    # and the per-layer attention dictionary.
    if multiheads:
        att_dict_raw = get_attention_raw_dict_multihead(model, data)
    else:
        att_dict_raw = get_attention_raw_dict(model, data)
    att_dict = process_attention_dict(att_dict_raw)
    comp_graph = get_computation_graph(
        edge_index=data.edge_index, k=num_layers, target_idx=target_idx
    )
    (
        nodes_per_level_original,
        num_nodes_per_level,
        true_node_label,
    ) = get_nodes_per_level_from_comp_graph_full(comp_graph=comp_graph)
    nodes_per_level_new = reindex_nodes_per_level(
        nodes_per_level_original, num_nodes_per_level
    )
    comp_graph_new = translate_comp_graph(
        comp_graph=comp_graph,
        nodes_per_level_new=nodes_per_level_new,
        nodes_per_level_original=nodes_per_level_original,
    )
    layer_att_dict = get_att_dict_per_layer(
        comp_graph=comp_graph, comp_graph_new=comp_graph_new, att_dict=att_dict
    )

    # Refresh the attention weights used by AVGATT and store the
    # head-averaged version back on the model.
    with torch.no_grad():
        model(data.x, data.edge_index, return_att=True)
        att = model.att
        att = average_attention_heads(att)
        model.att = att

    # Score every edge in the neighbourhood.
    attattribute_list, attattribute_sim_list, avgatt_list = [], [], []
    for current_edge in edge_lists:
        attattribute, attattribute_sim, avgatt = get_edge_scores(
            target_edge=tuple(current_edge),
            comp_graph=comp_graph,
            comp_graph_new=comp_graph_new,
            layer_att_dict=layer_att_dict,
            att=att,
        )
        attattribute_list.append(attattribute)
        attattribute_sim_list.append(attattribute_sim)
        avgatt_list.append(avgatt)

    ground_truth_edge_list = return_is_edge_list_BA_Shapes(edge_lists)

    return (
        attattribute_list,
        attattribute_sim_list,
        avgatt_list,
        ground_truth_edge_list,
    )

## GAT 3 layer 1 head

In [54]:
import torch
import networkx as nx
from torch_geometric.utils import to_networkx, k_hop_subgraph, remove_self_loops, add_self_loops
from torch_geometric.data import Data

# Experiment identifiers; the result files written and read by the following
# cells are named after these two values.
dataset_name = 'BA-Shapes'
model_name = 'GAT_BAShapes_3L1H'

# Load the data
# NOTE: torch.load unpickles the file contents -- only load trusted files.
data = torch.load(f'/workspace/{dataset_name}.pt',map_location ='cpu')
# Load the model
model = torch.load(f'/workspace/{model_name}.pt',map_location ='cpu')
model.eval()
# Get the attention weights
# The forward pass with return_att=True is what makes `model.att` available.
with torch.no_grad():
    out = model(data.x, data.edge_index, return_att=True)
    att = model.att 
    # Head averaging is left disabled in this single-head cell; note that
    # experiment_on_target_node applies average_attention_heads on its own.
    # att = average_attention_heads(att)
    # model.att = att

# data.edge_index = add_self_loops(remove_self_loops(data.edge_index)[0])[0]
# G = to_networkx(data, to_undirected=True)   

In [55]:
# Accumulators over all target nodes: three score lists plus 0/1 labels.
attattribute_list, attattribute_sim_list, avgatt_list = [], [], []
ground_truth_edge_list = []

# Target nodes 300..699 are analysed one by one.
for target_node in range(300, 700):
    target_node_results = experiment_on_target_node(
        target_idx=target_node, data=data, model=model, self_loops=True
    )
    # The helper returns exactly four aligned lists.
    (
        attattribute_list_curr,
        attattribute_sim_list_curr,
        avgatt_list_curr,
        ground_truth_edge_list_curr,
    ) = target_node_results

    attattribute_list += attattribute_list_curr
    attattribute_sim_list += attattribute_sim_list_curr
    avgatt_list += avgatt_list_curr
    ground_truth_edge_list += ground_truth_edge_list_curr

# Persist each accumulated list as a tensor, one file per quantity.
torch.save(
    torch.Tensor(attattribute_list),
    f"/workspace/{dataset_name}_{model_name}_attattribute_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(attattribute_sim_list),
    f"/workspace/{dataset_name}_{model_name}_attattribute_sim_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(avgatt_list),
    f"/workspace/{dataset_name}_{model_name}_avgatt_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(ground_truth_edge_list),
    f"/workspace/{dataset_name}_{model_name}_ground_truth_edge_list_Accuracy_test.pt",
)

In [56]:
# Load experiments from local
# Reloads the four tensors saved by the previous cell. The paths are rebuilt
# from the current `dataset_name` / `model_name`, so both must still refer to
# this experiment when the cell is run.

attattribute_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_attattribute_list_Accuracy_test.pt"
)
attattribute_sim_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_attattribute_sim_list_Accuracy_test.pt"
)
avgatt_list = torch.load(f"/workspace/{dataset_name}_{model_name}_avgatt_list_Accuracy_test.pt")
ground_truth_edge_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_ground_truth_edge_list_Accuracy_test.pt"
)

In [57]:
from scipy.stats import kendalltau, spearmanr, pearsonr
from sklearn.metrics import roc_auc_score

# ROC AUC of each attribution score against the 0/1 ground-truth edge labels.
# roc_auc_score takes the labels first and the scores second.
ground_truth_edge_list = torch.Tensor(ground_truth_edge_list)
# Also include a random baseline
# NOTE(review): no seed is set in this cell, so the baseline AUC presumably
# changes from run to run.
random_attr = torch.rand(ground_truth_edge_list.shape)
random_attr_roc_auc = roc_auc_score(ground_truth_edge_list, random_attr)
attattribute_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_list)
attattribute_sim_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_sim_list)
avgatt_roc_auc = roc_auc_score(ground_truth_edge_list, avgatt_list)

# Print results
# The values are printed in the same order as the header line.
print("ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM")
print(f"{attattribute_roc_auc:.4f}, {attattribute_sim_roc_auc:.4f}, {avgatt_roc_auc:.4f}, {random_attr_roc_auc:.4f}")

ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM
0.6897, 0.8416, 0.7559, 0.5085


## GAT 3 layer 2 head

In [58]:
import torch
import networkx as nx
from attention_analysis_utils import average_attention_heads
from torch_geometric.utils import to_networkx, k_hop_subgraph, remove_self_loops, add_self_loops
from torch_geometric.data import Data

# Experiment identifiers; the result files written and read by the following
# cells are named after these two values.
dataset_name = 'BA-Shapes'
model_name = 'GAT_BAShapes_3L2H'

# Load the data
# NOTE: torch.load unpickles the file contents -- only load trusted files.
data = torch.load(f'/workspace/{dataset_name}.pt',map_location ='cpu')
# Load the model
model = torch.load(f'/workspace/{model_name}.pt',map_location ='cpu')
model.eval()
# Get the attention weights
# The forward pass with return_att=True is what makes `model.att` available.
with torch.no_grad():
    out = model(data.x, data.edge_index, return_att=True)
    att = model.att 
    # Replace the stored attention with the output of average_attention_heads
    # (presumably the mean over heads -- confirm in attention_analysis_utils).
    att = average_attention_heads(att)
    model.att = att

# data.edge_index = add_self_loops(remove_self_loops(data.edge_index)[0])[0]
# G = to_networkx(data, to_undirected=True)   

In [59]:
# Accumulators over all target nodes: three score lists plus 0/1 labels.
attattribute_list, attattribute_sim_list, avgatt_list = [], [], []
ground_truth_edge_list = []

# Target nodes 300..699 are analysed one by one (multi-head attention path).
for target_node in range(300, 700):
    target_node_results = experiment_on_target_node(
        target_idx=target_node,
        data=data,
        model=model,
        self_loops=True,
        multiheads=True,
    )
    # The helper returns exactly four aligned lists.
    (
        attattribute_list_curr,
        attattribute_sim_list_curr,
        avgatt_list_curr,
        ground_truth_edge_list_curr,
    ) = target_node_results

    attattribute_list += attattribute_list_curr
    attattribute_sim_list += attattribute_sim_list_curr
    avgatt_list += avgatt_list_curr
    ground_truth_edge_list += ground_truth_edge_list_curr

# Persist each accumulated list as a tensor, one file per quantity.
torch.save(
    torch.Tensor(attattribute_list),
    f"/workspace/{dataset_name}_{model_name}_attattribute_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(attattribute_sim_list),
    f"/workspace/{dataset_name}_{model_name}_attattribute_sim_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(avgatt_list),
    f"/workspace/{dataset_name}_{model_name}_avgatt_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(ground_truth_edge_list),
    f"/workspace/{dataset_name}_{model_name}_ground_truth_edge_list_Accuracy_test.pt",
)

In [60]:
# Load experiments from local
# Reloads the four tensors saved by the previous cell. The paths are rebuilt
# from the current `dataset_name` / `model_name`, so both must still refer to
# this experiment when the cell is run.

attattribute_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_attattribute_list_Accuracy_test.pt"
)
attattribute_sim_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_attattribute_sim_list_Accuracy_test.pt"
)
avgatt_list = torch.load(f"/workspace/{dataset_name}_{model_name}_avgatt_list_Accuracy_test.pt")
ground_truth_edge_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_ground_truth_edge_list_Accuracy_test.pt"
)

In [61]:
from scipy.stats import kendalltau, spearmanr, pearsonr
from sklearn.metrics import roc_auc_score

# ROC AUC of each attribution score against the 0/1 ground-truth edge labels.
# roc_auc_score takes the labels first and the scores second.
ground_truth_edge_list = torch.Tensor(ground_truth_edge_list)
# Also include a random baseline
# NOTE(review): no seed is set in this cell, so the baseline AUC presumably
# changes from run to run.
random_attr = torch.rand(ground_truth_edge_list.shape)
random_attr_roc_auc = roc_auc_score(ground_truth_edge_list, random_attr)
attattribute_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_list)
attattribute_sim_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_sim_list)
avgatt_roc_auc = roc_auc_score(ground_truth_edge_list, avgatt_list)

# Print results
# The values are printed in the same order as the header line.
print("ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM")
print(f"{attattribute_roc_auc:.4f}, {attattribute_sim_roc_auc:.4f}, {avgatt_roc_auc:.4f}, {random_attr_roc_auc:.4f}")

ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM
0.6707, 0.6806, 0.7284, 0.5038


## GAT 3 layer 4 head

In [62]:
import torch
import networkx as nx
from attention_analysis_utils import average_attention_heads
from torch_geometric.utils import to_networkx, k_hop_subgraph, remove_self_loops, add_self_loops
from torch_geometric.data import Data

# Experiment identifiers; the result files written and read by the following
# cells are named after these two values.
dataset_name = 'BA-Shapes'
model_name = 'GAT_BAShapes_3L4H'

# Load the data
# NOTE: torch.load unpickles the file contents -- only load trusted files.
data = torch.load(f'/workspace/{dataset_name}.pt',map_location ='cpu')
# Load the model
model = torch.load(f'/workspace/{model_name}.pt',map_location ='cpu')
model.eval()
# Get the attention weights
# The forward pass with return_att=True is what makes `model.att` available.
with torch.no_grad():
    out = model(data.x, data.edge_index, return_att=True)
    att = model.att 
    # Replace the stored attention with the output of average_attention_heads
    # (presumably the mean over heads -- confirm in attention_analysis_utils).
    att = average_attention_heads(att)
    model.att = att

# data.edge_index = add_self_loops(remove_self_loops(data.edge_index)[0])[0]
# G = to_networkx(data, to_undirected=True)   

In [63]:
# Accumulators over all target nodes: three score lists plus 0/1 labels.
attattribute_list, attattribute_sim_list, avgatt_list = [], [], []
ground_truth_edge_list = []

# Target nodes 300..699 are analysed one by one (multi-head attention path).
for target_node in range(300, 700):
    target_node_results = experiment_on_target_node(
        target_idx=target_node,
        data=data,
        model=model,
        self_loops=True,
        multiheads=True,
    )
    # The helper returns exactly four aligned lists.
    (
        attattribute_list_curr,
        attattribute_sim_list_curr,
        avgatt_list_curr,
        ground_truth_edge_list_curr,
    ) = target_node_results

    attattribute_list += attattribute_list_curr
    attattribute_sim_list += attattribute_sim_list_curr
    avgatt_list += avgatt_list_curr
    ground_truth_edge_list += ground_truth_edge_list_curr

# Persist each accumulated list as a tensor, one file per quantity.
torch.save(
    torch.Tensor(attattribute_list),
    f"/workspace/{dataset_name}_{model_name}_attattribute_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(attattribute_sim_list),
    f"/workspace/{dataset_name}_{model_name}_attattribute_sim_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(avgatt_list),
    f"/workspace/{dataset_name}_{model_name}_avgatt_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(ground_truth_edge_list),
    f"/workspace/{dataset_name}_{model_name}_ground_truth_edge_list_Accuracy_test.pt",
)

In [64]:
# Load experiments from local
# Reloads the four tensors saved by the previous cell. The paths are rebuilt
# from the current `dataset_name` / `model_name`, so both must still refer to
# this experiment when the cell is run.

attattribute_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_attattribute_list_Accuracy_test.pt"
)
attattribute_sim_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_attattribute_sim_list_Accuracy_test.pt"
)
avgatt_list = torch.load(f"/workspace/{dataset_name}_{model_name}_avgatt_list_Accuracy_test.pt")
ground_truth_edge_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_ground_truth_edge_list_Accuracy_test.pt"
)

In [65]:
from scipy.stats import kendalltau, spearmanr, pearsonr
from sklearn.metrics import roc_auc_score

# ROC AUC of each attribution score against the 0/1 ground-truth edge labels.
# roc_auc_score takes the labels first and the scores second.
ground_truth_edge_list = torch.Tensor(ground_truth_edge_list)
# Also include a random baseline
# NOTE(review): no seed is set in this cell, so the baseline AUC presumably
# changes from run to run.
random_attr = torch.rand(ground_truth_edge_list.shape)
random_attr_roc_auc = roc_auc_score(ground_truth_edge_list, random_attr)
attattribute_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_list)
attattribute_sim_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_sim_list)
avgatt_roc_auc = roc_auc_score(ground_truth_edge_list, avgatt_list)

# Print results
# The values are printed in the same order as the header line.
print("ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM")
print(f"{attattribute_roc_auc:.4f}, {attattribute_sim_roc_auc:.4f}, {avgatt_roc_auc:.4f}, {random_attr_roc_auc:.4f}")

ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM
0.6852, 0.6734, 0.7453, 0.5087


## GAT 3 layer 8 head

In [66]:
import torch
import networkx as nx
from attention_analysis_utils import average_attention_heads
from torch_geometric.utils import to_networkx, k_hop_subgraph, remove_self_loops, add_self_loops
from torch_geometric.data import Data

# Experiment identifiers; the result files written and read by the following
# cells are named after these two values.
dataset_name = 'BA-Shapes'
model_name = 'GAT_BAShapes_3L8H'

# Load the data
# NOTE: torch.load unpickles the file contents -- only load trusted files.
data = torch.load(f'/workspace/{dataset_name}.pt',map_location ='cpu')
# Load the model
model = torch.load(f'/workspace/{model_name}.pt',map_location ='cpu')
model.eval()
# Get the attention weights
# The forward pass with return_att=True is what makes `model.att` available.
with torch.no_grad():
    out = model(data.x, data.edge_index, return_att=True)
    att = model.att 
    # Replace the stored attention with the output of average_attention_heads
    # (presumably the mean over heads -- confirm in attention_analysis_utils).
    att = average_attention_heads(att)
    model.att = att

# NOTE(review): unlike the 1-, 2- and 4-head cells (where these two lines are
# commented out), this cell rewrites data.edge_index with self-loops AFTER the
# attention was computed above and BEFORE the experiment loop runs, so the
# 8-head experiment sees a different edge_index than its siblings. Confirm
# whether this is intentional before comparing results across head counts.
data.edge_index = add_self_loops(remove_self_loops(data.edge_index)[0])[0]
# `G` is not used by any of the cells shown below.
G = to_networkx(data, to_undirected=True)   

In [67]:
# Accumulators over all target nodes: three score lists plus 0/1 labels.
attattribute_list, attattribute_sim_list, avgatt_list = [], [], []
ground_truth_edge_list = []

# Target nodes 300..699 are analysed one by one (multi-head attention path).
for target_node in range(300, 700):
    target_node_results = experiment_on_target_node(
        target_idx=target_node,
        data=data,
        model=model,
        self_loops=True,
        multiheads=True,
    )
    # The helper returns exactly four aligned lists.
    (
        attattribute_list_curr,
        attattribute_sim_list_curr,
        avgatt_list_curr,
        ground_truth_edge_list_curr,
    ) = target_node_results

    attattribute_list += attattribute_list_curr
    attattribute_sim_list += attattribute_sim_list_curr
    avgatt_list += avgatt_list_curr
    ground_truth_edge_list += ground_truth_edge_list_curr

# Persist each accumulated list as a tensor, one file per quantity.
torch.save(
    torch.Tensor(attattribute_list),
    f"/workspace/{dataset_name}_{model_name}_attattribute_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(attattribute_sim_list),
    f"/workspace/{dataset_name}_{model_name}_attattribute_sim_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(avgatt_list),
    f"/workspace/{dataset_name}_{model_name}_avgatt_list_Accuracy_test.pt",
)
torch.save(
    torch.Tensor(ground_truth_edge_list),
    f"/workspace/{dataset_name}_{model_name}_ground_truth_edge_list_Accuracy_test.pt",
)

In [68]:
# Load experiments from local
# Reloads the four tensors saved by the previous cell. The paths are rebuilt
# from the current `dataset_name` / `model_name`, so both must still refer to
# this experiment when the cell is run.

attattribute_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_attattribute_list_Accuracy_test.pt"
)
attattribute_sim_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_attattribute_sim_list_Accuracy_test.pt"
)
avgatt_list = torch.load(f"/workspace/{dataset_name}_{model_name}_avgatt_list_Accuracy_test.pt")
ground_truth_edge_list = torch.load(
    f"/workspace/{dataset_name}_{model_name}_ground_truth_edge_list_Accuracy_test.pt"
)

In [69]:
from scipy.stats import kendalltau, spearmanr, pearsonr
from sklearn.metrics import roc_auc_score

# ROC AUC of each attribution score against the 0/1 ground-truth edge labels.
# roc_auc_score takes the labels first and the scores second.
ground_truth_edge_list = torch.Tensor(ground_truth_edge_list)
# Also include a random baseline
# NOTE(review): no seed is set in this cell, so the baseline AUC presumably
# changes from run to run.
random_attr = torch.rand(ground_truth_edge_list.shape)
random_attr_roc_auc = roc_auc_score(ground_truth_edge_list, random_attr)
attattribute_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_list)
attattribute_sim_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_sim_list)
avgatt_roc_auc = roc_auc_score(ground_truth_edge_list, avgatt_list)

# Print results
# NOTE(review): the print statements are commented out in this cell only, so
# the four AUC values computed above are never displayed for the 8-head model.
# Re-enable them if the 8-head results should be reported like the others.
# print("ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM")
# print(f"{attattribute_roc_auc:.4f}, {attattribute_sim_roc_auc:.4f}, {avgatt_roc_auc:.4f}, {random_attr_roc_auc:.4f}")

## GNNExplainer (3L1H)

In [70]:
# import torch
# import networkx as nx
# from attention_analysis_utils import average_attention_heads
# from torch_geometric.utils import to_networkx, k_hop_subgraph, remove_self_loops, add_self_loops
# from torch_geometric.data import Data

# dataset_name = 'BA-Shapes'
# model_name = 'GAT_BAShapes_3L1H'

# # Load the data
# data = torch.load(f'/workspace/{dataset_name}.pt',map_location ='cpu')
# # Load the model
# model = torch.load(f'/workspace/{model_name}.pt',map_location ='cpu')
# model.eval()
# # Get the attention weights
# with torch.no_grad():
#     out = model(data.x, data.edge_index, return_att=True)
#     att = model.att 
#     att = average_attention_heads(att)
#     model.att = att

In [71]:
# from tqdm import tqdm
# from torch_geometric.explain import Explainer, GNNExplainer

# explainer = Explainer(
#     model=model,
#     algorithm=GNNExplainer(epochs=300),
#     explanation_type='phenomenon',
#     node_mask_type='attributes',
#     edge_mask_type='object',
#     model_config=dict(
#         mode='multiclass_classification',
#         task_level='node',
#         return_type='raw',
#     ),
# )



In [72]:
# # Explanation ROC AUC over all test nodes:
# targets, preds = [], []
# node_indices = range(300, 700)
# for node_index in tqdm(node_indices, leave=False, desc='Train Explainer'):
#     target = data.y
#     explanation = explainer(data.x, data.edge_index, index=node_index,
#                             target=target)

#     _, _, _, hard_edge_mask = k_hop_subgraph(node_index, num_hops=3,
#                                             edge_index=data.edge_index)

#     targets.append(data.edge_mask[hard_edge_mask].cpu())
#     preds.append(explanation.edge_mask[hard_edge_mask].cpu())

# auc = roc_auc_score(torch.cat(targets), torch.cat(preds))
# print(f'Mean ROC AUC (explanation type phenomenon): {auc:.4f}')

In [73]:
# auc

In [74]:
# from tqdm import tqdm
# from torch_geometric.explain import Explainer, PGExplainer

# explainer = Explainer(
#     model=model,
#     algorithm=PGExplainer(epochs=30, lr=0.003),
#     explanation_type='phenomenon',
#     # node_mask_type='attributes',
#     edge_mask_type='object',
#     model_config=dict(
#         mode='multiclass_classification',
#         task_level='node',
#         return_type='raw',
#     ),
# )

# for epoch in range(30):
#     for index in range(300, 700):  # Indices to train against.
#         loss = explainer.algorithm.train(epoch, model, data.x, data.edge_index,
#                                         target=target, index=index)
        
# # Explanation ROC AUC over all test nodes:
# targets, preds = [], []
# node_indices = range(300, 700)
# for node_index in tqdm(node_indices, leave=False, desc='Train Explainer'):
#     target = data.y
#     explanation = explainer(data.x, data.edge_index, index=node_index,
#                             target=target)

#     _, _, _, hard_edge_mask = k_hop_subgraph(node_index, num_hops=3,
#                                             edge_index=data.edge_index)

#     targets.append(data.edge_mask[hard_edge_mask].cpu())
#     preds.append(explanation.edge_mask[hard_edge_mask].cpu())

# auc_pgexpl = roc_auc_score(torch.cat(targets), torch.cat(preds))
# print(f'Mean ROC AUC (explanation type phenomenon): {auc_pgexpl:.4f}')

In [75]:
# auc_pgexpl