In [None]:
import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm


In [None]:
def _load_json(path):
    """Read and parse a UTF-8 encoded JSON file."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _average_attention_scores(score_arrays):
    """
    Average equally-shaped score matrices element-wise, skipping missing entries.

    Entries that are exactly 0 are treated as "score unavailable" and are
    excluded from the mean. A cell that is 0 in every matrix comes out as NaN.

    Args:
        score_arrays (list of array-like): One score matrix per sentence.

    Returns:
        np.ndarray: Element-wise mean with the same shape as a single input matrix.

    Raises:
        ValueError: If `score_arrays` is empty, or (from np.stack) if the
            matrices do not all share the same shape.
    """
    if not score_arrays:
        raise ValueError("No attention scores selected; nothing to average.")

    # Stack first: the zero mask has to be applied to an ndarray. Applied to a
    # Python list, `lst[lst == 0] = nan` overwrites element 0 instead.
    # astype(float) guarantees the stacked array can hold NaN.
    stacked = np.stack(score_arrays).astype(float)

    # NOTE(review): an earlier comment described the sentinel for skipped
    # sentences as 1, while the code has always tested for 0. 0 is kept here;
    # confirm against the script that writes the score files.
    stacked[stacked == 0] = np.nan
    return np.nanmean(stacked, axis=0)


def plot_attention_heatmap_mpa(
    pro_scores_file,
    anti_scores_file,
    pro_indices_file,
    anti_indices_file,
    correct_key="correct_indices",   #key for correctness filtering
    gender_key=None,                 #optional: "female_indices", "male_indices", or None
    mode="encoder",                  #"encoder" or "cross"
    title="",
    save_path=None
):
    """
    Plot one heatmap of attention scores averaged over Pro-S and Anti-S sentences.

    Sentences are kept only if their index appears under `correct_key` in BOTH
    index files. If `gender_key` is given, the Pro-S selection is further
    restricted to indices listed under that key in the Pro-S index file, and the
    Anti-S selection to those listed in the Anti-S index file. The selected
    score matrices from both files are pooled and averaged element-wise, with
    entries equal to 0 ignored as missing.

    The averaged matrix is drawn with rows labelled "Layers" and columns
    labelled "Heads", both numbered from 1.

    Args:
        pro_scores_file (str): Path to JSON with Pro-S attention scores: a list of
            entries, each with "sentence_num" and "attributions"[f"{mode}_scores"].
        anti_scores_file (str): Path to JSON with Anti-S attention scores (same layout).
        pro_indices_file (str): Path to Pro-S index JSON (maps keys to index lists).
        anti_indices_file (str): Path to Anti-S index JSON (maps keys to index lists).
        correct_key (str): Key for correctness filtering (e.g., "correct_indices", "incorrect_indices").
        gender_key (str or None): Optional key for gender filtering (e.g., "female_indices", "male_indices").
        mode (str): "encoder" or "cross"; selects the f"{mode}_scores" entry.
        title (str): Title for the plot.
        save_path (str or None): If provided, saves the plot there; otherwise shows it.

    Raises:
        ValueError: If the filters select no sentences at all.
        KeyError: If a key or a selected sentence index is missing from the JSON data.
    """

    pro_scores = {entry["sentence_num"]: entry for entry in _load_json(pro_scores_file)}
    anti_scores = {entry["sentence_num"]: entry for entry in _load_json(anti_scores_file)}
    pro_idx = _load_json(pro_indices_file)
    anti_idx = _load_json(anti_indices_file)

    # minimal pairs: keep sentences that satisfy the correctness key in both sets
    common_correct = sorted(set(pro_idx[correct_key]) & set(anti_idx[correct_key]))

    # gender filtering (optional); each set is filtered with its own index file
    if gender_key:
        pro_gender = set(pro_idx[gender_key])
        anti_gender = set(anti_idx[gender_key])
        pro_final = [i for i in common_correct if i in pro_gender]
        anti_final = [i for i in common_correct if i in anti_gender]
    else:
        pro_final = anti_final = common_correct

    # extract the per-sentence score matrices for each set
    scores_key = f"{mode}_scores"
    pro_scores_list = [
        np.array(pro_scores[idx]["attributions"][scores_key])
        for idx in pro_final
    ]
    anti_scores_list = [
        np.array(anti_scores[idx]["attributions"][scores_key])
        for idx in anti_final
    ]

    # pool both sets and average
    all_scores = pro_scores_list + anti_scores_list
    if not all_scores:
        raise ValueError(
            f"No sentences match correct_key={correct_key!r} and "
            f"gender_key={gender_key!r}; nothing to plot."
        )
    avg_scores = _average_attention_scores(all_scores)

    # plot heatmap
    n_layers, n_heads = avg_scores.shape
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        avg_scores,
        annot=True,
        fmt=".2f",
        cmap="viridis",
        cbar=True,
        # standardized color scale for comparative purposes, can be adjusted
        norm=PowerNorm(gamma=0.5, vmin=0, vmax=0.4),
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Heads")
    ax.set_ylabel("Layers")
    # +0.5 centres each tick on its heatmap cell; labels are 1-based
    ax.set_xticks(np.arange(n_heads) + 0.5)
    ax.set_xticklabels(np.arange(1, n_heads + 1))
    ax.set_yticks(np.arange(n_layers) + 0.5)
    ax.set_yticklabels(np.arange(1, n_layers + 1))

    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        print(f"Heatmap saved to {save_path}")
    else:
        plt.show()


In [None]:
# OPUS-MT encoder attention, averaged over minimal pairs that were
# disambiguated correctly in both the Pro-S and the Anti-S set.
plot_attention_heatmap_mpa(
    # attention scores: Pro-S, then Anti-S
    "data/attribution_scores/attention/opus/attribution_scores_pro.json",
    "data/attribution_scores/attention/opus/attribution_scores_anti.json",
    # sentence index files: Pro-S, then Anti-S
    "indices/opus/opus_pro.json",
    "indices/opus/opus_anti.json",
    correct_key="correct_indices",  # keep correctly disambiguated sentences only
    gender_key=None,                # no gender filter: male and female together
    mode="encoder",                 # switch to "cross" for cross-attention
    title="OPUS-MT - Accurate Minimal Pairs (Attention)",
    save_path="figures/opus_encoder_all_correct.png",
)

