In [None]:
import pickle as pkl
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.stats import hmean
import json
import torch
import pandas as pd

import seaborn as sns

# Global matplotlib look shared by every figure in this notebook.
plt.style.use("seaborn-v0_8")
# Colour list of the active style's default property cycle.
color_cycle = plt.rcParams["axes.prop_cycle"]
pal = color_cycle.by_key()["color"]
import matplotlib as mpl

In [None]:
# Evaluated models: results sub-directory key -> label shown in the plots.
# Comment an entry out to drop that model from every figure below.
models = {
    # "naive_model": "Base",
    "llama3_ft": "ICL-FT",
    "rag_trained": "RAG-FT",
    "memoryllm": "MemLLM",
    "gnm": "GNM",
}

# Parallel lists, both in the dict's insertion order.
model_keys = [key for key in models]
model_names = [models[key] for key in model_keys]

In [None]:
# Sanity check: enabled model keys alongside their display labels.
(model_keys, model_names)

In [None]:
# Evaluation summaries for the mixed_documents_seq_10 run:
# model_key -> parsed summary.json, with the display label added under "model_name".
data = dict()

for model_key, model_name in models.items():
    summary_path = Path("../saved_evals/mixed_documents_seq_10") / model_key / "summary.json"

    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with summary_path.open("r", encoding="utf-8") as json_file:
        summary = json.load(json_file)

    summary["model_name"] = model_name
    data[model_key] = summary

In [None]:
def upper_triangle(mat, k=0):
    """Return a float copy of ``mat`` with every cell below the ``k``-th diagonal set to NaN.

    A cell (i, j) survives when ``j - i >= k``:
      * ``k=0`` keeps the main diagonal and everything above it;
      * ``k=1`` keeps only the cells strictly above the diagonal.

    The input is never modified. ``mat`` must be 2-D.
    """
    masked = np.array(mat, dtype=float, copy=True)
    # Lower-triangle indices up to (k-1) are exactly the cells with j - i < k.
    rows, cols = np.tril_indices_from(masked, k=k - 1)
    masked[rows, cols] = np.nan
    return masked

In [None]:
def annotate_heatmap(ax, mat, fmt="{:.2f}", fontsize=7, threshold=0.5):
    """Write the value of every non-NaN cell of the 2-D ``mat`` onto ``ax``.

    Each label is centred on its cell (x = column, y = row) and formatted with
    ``fmt``. Values below ``threshold`` are drawn in white; all others in black.
    NaN cells are left blank.
    """
    grid = np.asarray(mat)
    n_rows, n_cols = grid.shape
    for row in range(n_rows):
        for col in range(n_cols):
            value = grid[row, col]
            if np.isnan(value):
                continue
            if value < threshold:
                text_color = "white"
            else:
                text_color = "black"
            ax.text(
                col, row, fmt.format(value),
                ha="center", va="center",
                fontsize=fontsize, color=text_color, fontfamily="monospace",
            )


In [None]:

# Heatmap palette interpolated low -> high: red, olive/yellow, green.
colors = ['#C44E52', '#CCB974', '#55A868']

cmap = (
    mpl.colors.LinearSegmentedColormap
    .from_list("custom6", colors, N=256)
    .copy()
)
# NaN cells (e.g. the masked lower triangle) are rendered fully transparent.
cmap.set_bad(alpha=0.0)

In [None]:
# Per-model heatmaps for the mixed_documents_seq_10 run: one row of metric
# panels per model, each figure saved as ../plots/warmup_<model_key>_heatmaps.png.

# Panel title -> key of the 2-D matrix in summary.json. Values are multiplied
# by 100 for display (assumed to be fractions in [0, 1] -- TODO confirm).
METRIC_MATRICES = {
    "Fact Accuracy": "fact_accuracy_matrix",
    "Fact Selectivity": "fact_selectivity_matrix",
    "Fact Specificity": "fact_specificity_matrix",
    "Refusal Precision": "refusal_precision_matrix",
    "Refusal Recall": "refusal_recall_matrix",
}
vmin, vmax = 0, 100  # shared colour scale, in percent

plot_dir = Path("../plots")
# savefig does not create directories; make sure the target exists.
plot_dir.mkdir(parents=True, exist_ok=True)

for model_key, model_name in models.items():
    print(model_key, model_name)

    view_df = pd.DataFrame(data[model_key])
    display(view_df)
    print("-" * 20)

    item = view_df.iloc[0]
    mats = {
        title: upper_triangle(np.array(item[key]) * 100, k=0)
        for title, key in METRIC_MATRICES.items()
    }

    # squeeze=False keeps `axes` 2-D regardless of how many panels there are.
    fig, axes = plt.subplots(
        1, len(mats),
        figsize=(13, 3.5),
        squeeze=False,
        constrained_layout=True,
    )

    ims = []
    for ax, (title, mat) in zip(axes[0], mats.items()):
        im = ax.imshow(mat, vmin=vmin, vmax=vmax, cmap=cmap, interpolation="nearest")
        annotate_heatmap(ax, mat, fmt="{:.0f}", fontsize=7, threshold=(vmin + vmax) / 2)
        ims.append(im)

        n_rows, n_cols = mat.shape

        # Major ticks on cell centres carry 1-based step labels.
        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
        ax.set_xticklabels(np.arange(n_cols) + 1)
        ax.set_yticklabels(np.arange(n_rows) + 1)

        # Minor ticks on cell boundaries exist only to draw the cell grid.
        ax.set_xticks(np.arange(0.5, n_cols, 1), minor=True)
        ax.set_yticks(np.arange(0.5, n_rows, 1), minor=True)
        ax.grid(False)
        ax.grid(which="minor", linestyle="-")
        ax.tick_params(which="minor", bottom=False, left=False)  # hide minor tick marks

        ax.set_title(title)
        ax.set_xlabel("Documents Seen")
        ax.set_ylabel("Query Timestep")

    # One colorbar shared by every panel (same cmap and vmin/vmax everywhere).
    fig.colorbar(ims[0], ax=axes, shrink=0.9, pad=0.02)

    fig.savefig(plot_dir / f"warmup_{model_key}_heatmaps.png", dpi=600)
    plt.show()
    # Release the figure explicitly so a long model loop doesn't accumulate them.
    plt.close(fig)

In [None]:
# Combined figure for the mixed_documents_seq_10 run: one row per model, one
# column per metric, and a single horizontal colorbar shared by all panels.

# Panel title -> key of the 2-D matrix in summary.json. Values are multiplied
# by 100 for display (assumed to be fractions in [0, 1] -- TODO confirm).
METRIC_MATRICES = {
    "Fact Accuracy": "fact_accuracy_matrix",
    "Fact Selectivity": "fact_selectivity_matrix",
    "Fact Specificity": "fact_specificity_matrix",
    "Refusal Precision": "refusal_precision_matrix",
    "Refusal Recall": "refusal_recall_matrix",
}
vmin, vmax = 0, 100  # shared colour scale, in percent

n_models = len(models)
# squeeze=False keeps `axes` 2-D even when only one model is enabled, so
# axes[row_idx, :] below is always valid.
fig, axes = plt.subplots(
    n_models, len(METRIC_MATRICES),
    figsize=(20, 5 * n_models),
    squeeze=False,
    constrained_layout=True,
)

for row_idx, (model_key, model_name) in enumerate(models.items()):
    print(model_key, model_name)

    view_df = pd.DataFrame(data[model_key])
    display(view_df)
    print("-" * 20)

    item = view_df.iloc[0]
    mats = {
        title: upper_triangle(np.array(item[key]) * 100, k=0)
        for title, key in METRIC_MATRICES.items()
    }

    ims = []
    for col_idx, (ax, (title, mat)) in enumerate(zip(axes[row_idx, :], mats.items())):
        im = ax.imshow(mat, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
        annotate_heatmap(ax, mat, fmt="{:.0f}", fontsize=7, threshold=(vmin + vmax) / 2)
        ims.append(im)

        n_rows, n_cols = mat.shape

        # Major ticks on cell centres carry 1-based step labels.
        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
        ax.set_xticklabels(np.arange(n_cols) + 1)
        ax.set_yticklabels(np.arange(n_rows) + 1)

        # Minor ticks on cell boundaries exist only to draw the cell grid.
        ax.set_xticks(np.arange(0.5, n_cols, 1), minor=True)
        ax.set_yticks(np.arange(0.5, n_rows, 1), minor=True)
        ax.grid(False)
        ax.grid(which="minor", linestyle="-", linewidth=1)
        ax.tick_params(which="minor", bottom=False, left=False)  # hide minor tick marks

        ax.set_title(title, fontsize=16)
        ax.set_xlabel("Documents Seen", fontsize=16)

        if col_idx == 0:
            # First column: instead of a plain y-label, draw a bold model name
            # and a normal-weight "Query Timestep" subtitle left of the axes
            # (positions are in axes coordinates).
            ax.set_ylabel("")
            ax.text(
                -0.18, 0.5, model_name,
                transform=ax.transAxes,
                rotation=90,
                va="center", ha="center",
                fontsize=20, fontweight="bold",
                clip_on=False,
            )
            ax.text(
                -0.10, 0.5, "Query Timestep",
                transform=ax.transAxes,
                rotation=90,
                va="center", ha="center",
                fontsize=16, fontweight="normal",
                clip_on=False,
            )
        else:
            ax.set_ylabel("Query Timestep", fontsize=16)

# Every panel uses the same cmap and vmin/vmax, so any image can drive the colorbar.
fig.colorbar(ims[0], ax=axes, shrink=0.4, pad=0.02, orientation="horizontal")

plot_path = Path("../plots/mixed_documents_seq_10_all_heatmaps.png")
# savefig does not create directories; make sure the target exists.
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, dpi=600)
plt.show()

In [None]:
# Load data from mixed_documents (20 elements): model_key -> parsed
# summary.json, with the display label added under "model_name".
data_20 = dict()

for model_key, model_name in models.items():
    summary_path = Path("../saved_evals/mixed_documents") / model_key / "summary.json"
    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with summary_path.open("r", encoding="utf-8") as json_file:
        summary = json.load(json_file)
    summary["model_name"] = model_name
    data_20[model_key] = summary

# Combined figure: one row per model, one column per metric, and a single
# horizontal colorbar shared by all panels.

# Panel title -> key of the 2-D matrix in summary.json. Values are multiplied
# by 100 for display (assumed to be fractions in [0, 1] -- TODO confirm).
METRIC_MATRICES = {
    "Fact Accuracy": "fact_accuracy_matrix",
    "Fact Selectivity": "fact_selectivity_matrix",
    "Fact Specificity": "fact_specificity_matrix",
    "Refusal Precision": "refusal_precision_matrix",
    "Refusal Recall": "refusal_recall_matrix",
}
vmin, vmax = 0, 100  # shared colour scale, in percent

n_models = len(models)
# squeeze=False keeps `axes` 2-D even when only one model is enabled, so
# axes[row_idx, :] below is always valid.
fig, axes = plt.subplots(
    n_models, len(METRIC_MATRICES),
    figsize=(20, 5 * n_models),
    squeeze=False,
    constrained_layout=True,
)

for row_idx, (model_key, model_name) in enumerate(models.items()):
    print(model_key, model_name)

    view_df = pd.DataFrame(data_20[model_key])
    display(view_df)
    print("-" * 20)

    item = view_df.iloc[0]
    mats = {
        title: upper_triangle(np.array(item[key]) * 100, k=0)
        for title, key in METRIC_MATRICES.items()
    }

    ims = []
    for col_idx, (ax, (title, mat)) in enumerate(zip(axes[row_idx, :], mats.items())):
        im = ax.imshow(mat, cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
        annotate_heatmap(ax, mat, fmt="{:.0f}", fontsize=7, threshold=(vmin + vmax) / 2)
        ims.append(im)

        n_rows, n_cols = mat.shape

        # Major ticks on cell centres carry 1-based step labels.
        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
        ax.set_xticklabels(np.arange(n_cols) + 1)
        ax.set_yticklabels(np.arange(n_rows) + 1)

        # Minor ticks on cell boundaries exist only to draw the cell grid.
        ax.set_xticks(np.arange(0.5, n_cols, 1), minor=True)
        ax.set_yticks(np.arange(0.5, n_rows, 1), minor=True)
        ax.grid(False)
        ax.grid(which="minor", linestyle="-", linewidth=1)
        ax.tick_params(which="minor", bottom=False, left=False)  # hide minor tick marks

        ax.set_title(title, fontsize=16)
        ax.set_xlabel("Documents Seen", fontsize=16)

        if col_idx == 0:
            # First column: instead of a plain y-label, draw a bold model name
            # and a normal-weight "Query Timestep" subtitle left of the axes
            # (positions are in axes coordinates).
            ax.set_ylabel("")
            ax.text(
                -0.18, 0.5, model_name,
                transform=ax.transAxes,
                rotation=90,
                va="center", ha="center",
                fontsize=20, fontweight="bold",
                clip_on=False,
            )
            ax.text(
                -0.10, 0.5, "Query Timestep",
                transform=ax.transAxes,
                rotation=90,
                va="center", ha="center",
                fontsize=16, fontweight="normal",
                clip_on=False,
            )
        else:
            ax.set_ylabel("Query Timestep", fontsize=16)

# Every panel uses the same cmap and vmin/vmax, so any image can drive the colorbar.
fig.colorbar(ims[0], ax=axes, shrink=0.4, pad=0.02, orientation="horizontal")

plot_path = Path("../plots/mixed_documents_seq_20_all_heatmaps.png")
# savefig does not create directories; make sure the target exists.
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, dpi=600)
plt.show()