# Accuracy experiments on the Infection benchmark

In [57]:
# Import the utility functions
from attention_analysis_utils import (
    get_attention_raw_dict,
    process_attention_dict,
    get_computation_graph,
    get_nodes_per_level_from_comp_graph_full,
    get_attention_raw_dict_multihead,
    reindex_nodes_per_level,
    translate_comp_graph,
    get_att_dict_per_layer,
    return_edges_in_k_hop,
    get_ATTATTTRIBUTE_edge,
    get_AVGATT_edge,
    average_attention_heads,
)
from torch_geometric.utils import get_num_hops
import torch
from typing import Tuple

def get_edge_scores(
    target_edge: Tuple, comp_graph, comp_graph_new, layer_att_dict, att
):
    """Compute the ATTATTRIBUTE, ATTATTRIBUTE_sim and AVGATT scores of one edge.

    Args:
        target_edge: The edge to score, as a tuple of node ids (presumably
            ``(source, target)`` -- confirm against ``get_ATTATTTRIBUTE_edge``).
            Any ``tuple`` instance, including subclasses, is accepted.
        comp_graph: Computation graph of the explained node, as returned by
            ``get_computation_graph``.
        comp_graph_new: ``comp_graph`` translated to re-indexed node ids
            (output of ``translate_comp_graph``).
        layer_att_dict: Per-layer attention dictionary from
            ``get_att_dict_per_layer``.
        att: Attention weights forwarded to ``get_AVGATT_edge``.

    Returns:
        ``(attattribute, attattribute_sim, avgatt)`` for ``target_edge``.

    Raises:
        TypeError: If ``target_edge`` is not a tuple.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``,
    # and ``isinstance`` also accepts tuple subclasses (e.g. namedtuples) that
    # an exact ``type(...) == tuple`` comparison would reject.
    if not isinstance(target_edge, tuple):
        raise TypeError("target_edge must be a tuple")

    # ATTATTRIBUTE and its ATTATTRIBUTE_sim variant come from one helper call.
    attattribute, attattribute_sim = get_ATTATTTRIBUTE_edge(
        comp_graph=comp_graph,
        comp_graph_new=comp_graph_new,
        layer_att_dict=layer_att_dict,
        target_edge=target_edge,
        verbose=False,
    )
    # AVGATT only needs the raw attention weights and the edge itself.
    avgatt = get_AVGATT_edge(att=att, edge=target_edge)

    return attattribute, attattribute_sim, avgatt

def return_is_edge_list_Infection(edge_list, path_expl):
    """Label each edge 1 if it lies on the explanation path, else 0.

    The ground-truth explanation is a node path such as
    ``[1215, 1024, 606, 10]``. Its edges are the consecutive pairs
    ``(1215, 1024), (1024, 606), (606, 10)``. Only that orientation is
    matched, so e.g. ``(1024, 1215)`` is labelled 0.

    Args:
        edge_list: Iterable of edges, each an iterable of node ids
            (lists, tuples, or 1-D tensors/arrays are all accepted).
        path_expl: Sequence of node ids describing the explanation path.
            An empty or single-node path yields no positive edges.

    Returns:
        A list of ints (0 or 1), one per edge in ``edge_list``, in order.
    """

    def _as_key(nodes):
        # Normalize ids to plain Python ints. torch tensors hash by identity,
        # so tuples of 0-d tensors would never match in a set lookup and
        # every edge would silently be labelled 0.
        return tuple(int(node) for node in nodes)

    path_nodes = _as_key(path_expl)
    expl_edge_set = set(zip(path_nodes, path_nodes[1:]))
    return [int(_as_key(edge) in expl_edge_set) for edge in edge_list]

def experiment_on_target_node(
    target_idx: int,
    data,
    model,
    path_expl,
    self_loops=True,
    multiheads=False,
    hop=2,
):
    """Score every candidate edge around ``target_idx`` with three attributions.

    Candidate edges come from ``return_edges_in_k_hop``. Each edge is scored
    with ATTATTRIBUTE, ATTATTRIBUTE_sim and AVGATT (via ``get_edge_scores``)
    and labelled against the ground-truth explanation path.

    Args:
        target_idx: Index of the node whose prediction is being explained.
        data: Graph data object; ``data.x`` and ``data.edge_index`` are read.
        model: Attention GNN. It must accept ``return_att=True`` in its
            forward call and expose the resulting weights as ``model.att``.
        path_expl: Ground-truth explanation path (sequence of node ids).
        self_loops: Forwarded to ``return_edges_in_k_hop``.
        multiheads: If True, raw attention is read with
            ``get_attention_raw_dict_multihead`` instead of
            ``get_attention_raw_dict``.
        hop: Neighborhood radius forwarded to ``return_edges_in_k_hop``.
            Defaults to 2 to preserve the original behavior.

    Returns:
        ``(attattribute_list, attattribute_sim_list, avgatt_list,
        ground_truth_edge_list)``. Each list has one entry per candidate
        edge, in the order returned by ``return_edges_in_k_hop``.

    Side effects:
        Runs a forward pass of ``model`` and overwrites ``model.att`` with
        the head-averaged attention weights.
    """
    # Model depth = radius of the computation graph of ``target_idx``.
    num_layers = get_num_hops(model)

    # 1. Candidate edges to score.
    # NOTE(review): ``hop`` defaults to 2 while the computation graph below
    # spans ``num_layers`` hops (3 for the 3-layer GAT in this notebook). If
    # return_edges_in_k_hop only yields edges within ``hop`` hops, explanation
    # edges farther out are never scored. Consider ``hop=num_layers`` --
    # confirm intent before changing the default.
    edge_lists = return_edges_in_k_hop(
        data=data, target_idx=target_idx, hop=hop, self_loops=self_loops
    )

    # 2. Ingredients for ATTATTRIBUTE / ATTATTRIBUTE_sim: processed attention
    # plus the computation graph in both original and re-indexed node ids.
    if multiheads:
        att_dict_raw = get_attention_raw_dict_multihead(model, data)
    else:
        att_dict_raw = get_attention_raw_dict(model, data)
    att_dict = process_attention_dict(att_dict_raw)
    comp_graph = get_computation_graph(
        edge_index=data.edge_index, k=num_layers, target_idx=target_idx
    )
    (
        nodes_per_level_original,
        num_nodes_per_level,
        _true_node_label,  # returned by the helper but not needed here
    ) = get_nodes_per_level_from_comp_graph_full(comp_graph=comp_graph)
    nodes_per_level_new = reindex_nodes_per_level(
        nodes_per_level_original, num_nodes_per_level
    )
    comp_graph_new = translate_comp_graph(
        comp_graph=comp_graph,
        nodes_per_level_new=nodes_per_level_new,
        nodes_per_level_original=nodes_per_level_original,
    )
    layer_att_dict = get_att_dict_per_layer(
        comp_graph=comp_graph, comp_graph_new=comp_graph_new, att_dict=att_dict
    )

    # 3. Attention weights for AVGATT: fresh forward pass, heads averaged.
    # The averaged weights are also written back to ``model.att``.
    with torch.no_grad():
        model(data.x, data.edge_index, return_att=True)
        att = average_attention_heads(model.att)
        model.att = att

    # 4. Score every candidate edge.
    attattribute_list, attattribute_sim_list, avgatt_list = [], [], []
    for current_edge in edge_lists:
        attattribute, attattribute_sim, avgatt = get_edge_scores(
            target_edge=tuple(current_edge),
            comp_graph=comp_graph,
            comp_graph_new=comp_graph_new,
            layer_att_dict=layer_att_dict,
            att=att,
        )
        attattribute_list.append(attattribute)
        attattribute_sim_list.append(attattribute_sim)
        avgatt_list.append(avgatt)

    # 5. Ground-truth labels aligned with ``edge_lists``.
    ground_truth_edge_list = return_is_edge_list_Infection(edge_lists, path_expl)

    return (
        attattribute_list,
        attattribute_sim_list,
        avgatt_list,
        ground_truth_edge_list,
    )

GAT with 3 layers and 1 attention head

In [None]:
import torch
import networkx as nx
from torch_geometric.utils import to_networkx, k_hop_subgraph, remove_self_loops, add_self_loops
from torch_geometric.data import Data
import os

curr_work_dir = os.getcwd()

dataset_name = 'Infection_50003d_sp'
model_name = 'GAT_infection_3L1H_sp'

# os.getcwd() is already absolute. Prefixing it with '/' (as in
# f'/{curr_work_dir}/...') yields '//...' on POSIX and an invalid path on
# Windows, so join the components portably instead.
data_path = os.path.join(curr_work_dir, f'{dataset_name}.pt')
model_path = os.path.join(curr_work_dir, f'{model_name}.pt')

# NOTE: both files hold full pickled objects (graph data and a model), not
# state dicts, and torch.load unpickles them -- only load trusted files.
# PyTorch releases where torch.load defaults to weights_only=True may need
# weights_only=False here.
# Load the data
data = torch.load(data_path, map_location='cpu')
# Load the model
model = torch.load(model_path, map_location='cpu')
model.eval()
# Forward pass with return_att=True so the model exposes model.att
with torch.no_grad():
    out = model(data.x, data.edge_index, return_att=True)
    att = model.att

In [59]:
# Pool the per-edge scores and ground-truth labels of every target node that
# has a unique explanation. The four lists stay aligned entry-by-entry.
attattribute_list = []
attattribute_sim_list = []
avgatt_list = []
ground_truth_edge_list = []

for node_pos, target_node in enumerate(data.unique_solution_nodes):
    (
        node_attattribute,
        node_attattribute_sim,
        node_avgatt,
        node_ground_truth,
    ) = experiment_on_target_node(
        target_idx=target_node,
        data=data,
        model=model,
        path_expl=data.unique_solution_explanations[node_pos],
        self_loops=True,
    )
    attattribute_list += node_attattribute
    attattribute_sim_list += node_attattribute_sim
    avgatt_list += node_avgatt
    ground_truth_edge_list += node_ground_truth

In [61]:
from sklearn.metrics import roc_auc_score

# Binary ground-truth labels (1 = edge lies on the explanation path).
# as_tensor also accepts an existing tensor, so re-running this cell is safe.
ground_truth_edge_list = torch.as_tensor(ground_truth_edge_list, dtype=torch.float32)
# Random baseline, drawn from a seeded generator so the reported number is
# reproducible across runs and does not disturb the global RNG state.
random_generator = torch.Generator().manual_seed(0)
random_attr = torch.rand(ground_truth_edge_list.shape, generator=random_generator)
random_attr_roc_auc = roc_auc_score(ground_truth_edge_list, random_attr)
attattribute_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_list)
attattribute_sim_roc_auc = roc_auc_score(ground_truth_edge_list, attattribute_sim_list)
avgatt_roc_auc = roc_auc_score(ground_truth_edge_list, avgatt_list)

# Print results
print("ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM")
print(f"{attattribute_roc_auc:.4f}, {attattribute_sim_roc_auc:.4f}, {avgatt_roc_auc:.4f}, {random_attr_roc_auc:.4f}")

ROC AUC for ATTATTRIBUTE / ATTATTRIBUTE_SIM / AVGATT / RANDOM
0.9359, 0.9405, 0.8852, 0.5073
