In [None]:
import pickle as pkl
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.stats import hmean
import json
import torch
import pandas as pd
import seaborn as sns

# Use the seaborn-v0_8 matplotlib style and take ten colours from seaborn's "deep" palette.
plt.style.use('seaborn-v0_8')
pal = sns.color_palette("deep", 10)

# Model display name -> slot in `pal`; resolved to the actual colour right below.
_palette_slot_by_model = {
    "Ablation": 7,
    "ICL-FT": 8,
    "RAG-FT": 6,
    "MemLLM": 9,
    "GNM": 0,
}
model_colors = {name: pal[slot] for name, slot in _palette_slot_by_model.items()}

# Format accuracy (%) per model, one row per model.
# NOTE: the 'Dataset' column holds model names; the column name is kept because the
# swapped-layout plot looks models up through data_df['Dataset'].
# An earlier, commented-out version of this table listed one-decimal values
# (train / test-id / test-ood):
#   RAG-FT: 100 / 49.4 / 0.0    ICL-FT: 100 / 63.3 / 4.4    GNM: 100 / 100 / 67.4
_accuracy_rows = [
    # (model,   train, test-id, test-ood)
    ('ICL-FT', 100, 63, 4.4),
    ('RAG-FT', 100, 49, 0.0),
    ('GNM', 100, 100, 67),
]
_table_columns = ('Dataset', 'train', 'test-id', 'test-ood')

# Transpose the row-wise table into the column-wise dict that pandas consumes.
data = {
    column: [row[position] for row in _accuracy_rows]
    for position, column in enumerate(_table_columns)
}

data_df = pd.DataFrame(data)

# One x-axis group per model. The order must match the row order of data_df, because the
# bars below are taken from its columns positionally.
metrics = ['ICL-FT', 'RAG-FT', "GNM"]
metric_labels = list(metrics)

fig, ax = plt.subplots(figsize=(7.5, 3))

x = np.arange(len(metrics))
bar_width = 0.25  # three bars of this width fill 0.75 of each unit-wide group

# Bar colours for the two held-out splits
id_color = "#CC8426"   # orange tone, used for the test-id bars
ood_color = "#7213AD"  # purple tone, used for the test-ood bars

# Within each model group: train on the left, test-id centred, test-ood on the right.
# Each spec is (slot relative to the tick, data_df column, legend label, fill colour).
_bar_specs = [
    (-1, 'train', 'train formats', '#6B7280'),
    (0, 'test-id', 'test-id formats', id_color),
    (1, 'test-ood', 'test-ood formats', ood_color),
]
bars1, bars2, bars3 = [
    ax.bar(x + slot * bar_width, data_df[column], width=bar_width,
           label=legend_label, color=fill, alpha=0.9)
    for slot, column, legend_label, fill in _bar_specs
]

# Add data labels on each bar (white text inside for tall bars, black text on top for short bars)
threshold = 15  # Values below this go on top in black


def _label_bars(ax, bars, values, threshold):
    """Write each bar's value (rounded to 0 decimals) onto the axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receive the text.
    bars : iterable of bar patches
        Bars as returned by ``ax.bar``; only ``get_x``, ``get_width`` and
        ``get_height`` are used.
    values : iterable of numbers
        The value plotted for each bar, in the same order as ``bars``.
    threshold : number
        Values at or above it are labelled in white just below the bar top;
        values in ``[0, threshold)`` are labelled in black just above the bar.
        Negative values get no label.
    """
    for bar, value in zip(bars, values):
        x_center = bar.get_x() + bar.get_width()/2
        if value >= threshold:
            ax.text(x_center, bar.get_height() - 5, f'{value:.0f}',
                    ha='center', va='top', fontsize=10, fontweight='normal', color='white')
        elif value >= 0:
            ax.text(x_center, bar.get_height() + 1.5, f'{value:.0f}',
                    ha='center', va='bottom', fontsize=10, fontweight='normal', color='black')


# Same labelling rule for all three series (previously three copy-pasted loops).
for bars, column in ((bars1, 'train'), (bars2, 'test-id'), (bars3, 'test-ood')):
    _label_bars(ax, bars, data_df[column], threshold)

# Axis cosmetics: y label, one tick per model group, legend without a frame, light y grid.
ax.set_ylabel('Format Accuracy (%)', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(metric_labels, fontsize=12)
ax.set_ylim(0, 105)  # a little headroom above 100 %
ax.legend(fontsize=11, loc='upper left', bbox_to_anchor=(0.12, 0.98), frameon=False, ncol=1)
ax.grid(axis='y', alpha=0.3, zorder=0)

plt.tight_layout()

# savefig does not create missing directories, so make sure ../plots exists first;
# otherwise a fresh checkout fails here with FileNotFoundError.
output_path = Path("../plots/format_generalization.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=600, bbox_inches="tight")
plt.show()


In [None]:
import seaborn as sns

# Use the seaborn-v0_8 matplotlib style and take ten colours from seaborn's "deep" palette.
plt.style.use('seaborn-v0_8')
pal = sns.color_palette("deep", 10)

# Label -> slot in `pal`. Both display names and lower-case keys are listed.
# NOTE(review): "RAG-FT" (slot 5) and "rag_trained" (slot 6) get different colours, whereas
# the other name/key pairs here ("ICL-FT"/"llama3_ft", "GNM"/"gnm") share one slot.
# Presumably these two refer to the same model; confirm the difference is intended.
_palette_slot_by_label = {
    "Ablation": 7,
    "ICL-FT": 8,
    "llama3_ft": 8,
    "RAG-FT": 5,
    "rag_trained": 6,
    "MemLLM": 9,
    "GNM": 0,
    "gnm": 0,
}
model_colors = {label: pal[slot] for label, slot in _palette_slot_by_label.items()}

# Wilson score interval helper - returns one symmetric error size for visual consistency
def wilson_ci(p, n, z=1.96):
    """Return a symmetric error-bar size (in percentage points) from the Wilson score interval.

    The Wilson interval is generally not centred on the observed rate, so the distance
    from ``p`` to the lower bound differs from the distance to the upper bound. This
    helper returns the larger of the two, so a single number can be used as a symmetric
    ``yerr``. Because of that, ``p - result`` or ``p + result`` can fall outside
    [0, 100] for rates close to 0 % or 100 %.

    Parameters
    ----------
    p : float
        Observed rate on a 0-100 (percent) scale.
    n : int
        Number of samples the rate was measured on; must be positive.
    z : float, optional
        Normal quantile for the confidence level (the default 1.96 corresponds to ~95 %).

    Returns
    -------
    float
        ``max(p - lower, upper - p)`` in percentage points; never negative.

    Raises
    ------
    ValueError
        If ``n`` is not positive, or ``p`` is below 0 or above 100.
    """
    # Fail loudly instead of dividing by zero or taking the square root of a
    # negative number further down.
    if n <= 0:
        raise ValueError(f"n must be a positive sample count, got {n!r}")
    if p < 0 or p > 100:
        raise ValueError(f"p must be a percentage in [0, 100], got {p!r}")

    p_prop = p / 100  # Convert percentage to proportion

    denom = 1 + z**2 / n
    center = (p_prop + z**2 / (2*n)) / denom
    margin = z * np.sqrt(p_prop * (1 - p_prop) / n + z**2 / (4 * n**2)) / denom

    lower = center - margin
    upper = center + margin

    # Distances from the observed rate to each bound, back on the percent scale.
    # max(0, ...) absorbs tiny negative values caused by floating-point rounding.
    lower_err = max(0, (p_prop - lower) * 100)
    upper_err = max(0, (upper - p_prop) * 100)

    # Use the larger distance so the error bar can be drawn symmetrically.
    return max(lower_err, upper_err)

# Sample counts for each category (the n passed to wilson_ci below)
n_samples = {
    'train': 200,
    'test-id': 35,
    'test-ood': 216
}

# Swapped layout of the format-generalization chart: the grouped bars are the format
# splits (train / test-id / test-ood) and the legend lists the models.
fig, ax = plt.subplots(figsize=(6.5, 2.1))
categories = ['train', 'test-id', 'test-ood']
x = np.arange(len(categories))
bar_width = 0.25  # three bars of this width fill 0.75 of each unit-wide group

# One row per model, indexed by model name (stored in the 'Dataset' column of data_df).
# Iterating rows directly also copes with repeated names, unlike filtering by equality.
accuracy_by_model = data_df.set_index('Dataset')[categories]
n_models = len(accuracy_by_model)

# Plot bars for each model and store them for labeling
all_bars = []
all_cis = []
for i, (model_name, model_row) in enumerate(accuracy_by_model.iterrows()):
    # Centre the whole group on its tick for any number of models
    # (for three models this equals the previous hard-coded (i - 1) * bar_width).
    offset = (i - (n_models - 1) / 2) * bar_width
    values = model_row.to_numpy(dtype=float)

    # Symmetric Wilson error size for each category
    cis = [wilson_ci(value, n_samples[cat]) for value, cat in zip(values, categories)]

    # Accuracy cannot leave [0, 100], so the whiskers are clipped to that range: a bar at
    # 100 % only gets a downward whisker and a bar at 0 % only an upward one. Everywhere
    # else the error bar stays symmetric.
    lower_errs = np.minimum(cis, np.clip(values, 0, None))
    upper_errs = np.minimum(cis, np.clip(100 - values, 0, None))

    # yerr with shape (2, N): first row is the downward size, second row the upward size.
    bars = ax.bar(x + offset, values,
           width=bar_width, label=model_name, color=model_colors[model_name], alpha=0.9,
           yerr=[lower_errs, upper_errs], capsize=0, error_kw={'linewidth': 1.5})
    all_bars.append((bars, values))
    all_cis.append(cis)

# Add data labels - at bottom of bars, except for short bars which go above error bar
label_inside_min = 25  # bars at least this tall carry their label inside, near the base
for (bars, values), cis in zip(all_bars, all_cis):
    for bar, value, ci in zip(bars, values, cis):
        x_center = bar.get_x() + bar.get_width()/2
        if value >= label_inside_min:
            y_text = 3
        else:
            # Short bars: place the label just above the top of the drawn whisker.
            y_text = value + min(ci, max(100 - value, 0)) + 2
        ax.text(x_center, y_text, f'{value:.0f}',
                ha='center', va='bottom', fontsize=10, fontweight='normal', color='black')

ax.set_xticks(x)
# set tick labels to pretty names for categories
# NOTE(review): the middle category key is 'test-id' but its tick label reads "Val-ID";
# confirm which name is the intended one.
ax.set_xticklabels(['Train Formats', 'Val-ID Formats', 'Test-OOD Formats'], fontsize=14)
ax.legend(fontsize=12, loc='upper left', bbox_to_anchor=(0.63, 1.05), frameon=False, ncol=1)
ax.set_ylabel('Format Accuracy (%)', fontsize=14)
ax.set_ylim(0, 105)
ax.grid(axis='y', alpha=0.3, zorder=0)

# savefig does not create missing directories, so make sure ../plots exists first.
output_path = Path("../plots/format_generalization_swapped.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=600, bbox_inches="tight")
plt.show()

In [None]:
import pickle as pkl
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.stats import hmean
import json
import pandas as pd

# Switch matplotlib to the seaborn-v0_8 style, then read the colour list of the active
# property cycle.
# NOTE: this rebinds `pal`, which the plotting cells above set to a seaborn palette.
plt.style.use('seaborn-v0_8')
_cycle_by_key = plt.rcParams['axes.prop_cycle'].by_key()
pal = _cycle_by_key['color']

In [None]:
def diag_means(A):
    """Return the mean of every diagonal of a square matrix.

    The result has ``2n - 1`` entries for an ``n x n`` input, ordered by diagonal
    offset from ``-(n-1)`` (bottom-left corner) to ``n-1`` (top-right corner), so
    index ``n-1`` holds the mean of the main diagonal.

    Raises ``AssertionError`` when the input is not square.
    """
    mat = np.asarray(A)
    n_rows, n_cols = mat.shape
    assert n_rows == n_cols, "expect square matrix"

    means = []
    for offset in range(1 - n_rows, n_rows):
        means.append(np.diagonal(mat, offset=offset).mean())
    return np.array(means)


In [None]:
# Model key -> display name. Insertion order is the order used by model_keys / model_names.
models = {
    "memoryllm": "MemLLM",
    "gnm_ablation": "GNM (Ablation)",
    "gnm": "GNM",
}

# Split the mapping into two parallel lists: keys and display names.
model_keys, model_names = (list(column) for column in zip(*models.items()))

In [None]:
# Show the key list and the display-name list as this cell's output.
model_keys, model_names

In [None]:
# Model key -> colour, given as a slot in `pal`.
# NOTE(review): `models` contains "gnm_ablation" but no colour is assigned to it here,
# while "llama3_ft" and "rag_trained" are not in `models`. Confirm that downstream
# lookups never index model_colors with "gnm_ablation".
_cycle_slot_by_key = {
    "llama3_ft": 1,
    "rag_trained": 2,
    "memoryllm": 4,
    "gnm": 0,
}
model_colors = {key: pal[slot] for key, slot in _cycle_slot_by_key.items()}

In [None]:
# Load every model's summary.json and tag it with the model's display name.
# NOTE: `data` is rebound here; the plotting cells above used the same name for the
# hand-entered accuracy table.
data = {}

for model_key, model_name in models.items():
    summary_path = Path(f"../data/gnm_experiments/mixed_documents/{model_key}/summary.json")

    with summary_path.open('r') as summary_file:
        summary = json.load(summary_file)

    summary["model_name"] = model_name
    data[model_key] = summary

In [None]:
len_rollout = 10
step_0 = len_rollout-1  # index of the main diagonal in diag_means() output for a len_rollout-wide matrix
scale = 100  # scores are reported as percentages


def _upper_triangle_stats(matrix, size):
    """Summarise the top-left ``size`` x ``size`` corner of a square metric matrix.

    Parameters
    ----------
    matrix : array-like, 2-D
        Metric matrix; only its top-left ``size`` x ``size`` corner is used.
    size : int
        Number of leading rows/columns to keep.

    Returns
    -------
    tuple
        ``(overall, by_offset)``: ``overall`` is the mean over the upper triangle
        including the main diagonal; ``by_offset[d]`` is the mean of the diagonal at
        offset ``d`` (``d = 0`` is the main diagonal).
    """
    mat = np.asarray(matrix)[:size, :size]
    overall = np.mean(mat[np.triu_indices_from(mat, k=0)])
    # diag_means lists offsets -(n-1)..(n-1), so index n-1 is the main diagonal. Using the
    # actual matrix size keeps the slice correct even if the matrix is smaller than `size`.
    by_offset = diag_means(mat)[mat.shape[0] - 1:]
    return overall, by_offset


rows = []

for k, v in data.items():

    # Refusal F1: harmonic mean of precision and recall.
    total_refusal_score = hmean(
        np.array([
            v['refusal_precision'][0],
            v['refusal_recall'][0],
        ])
    )

    #######################
    ### All Upper Triangle
    #######################

    print(k)

    # [0] takes the first entry along the leading axis of each stored matrix
    # (presumably a single run/seed; TODO confirm).
    all_acc, f_acc_diags = _upper_triangle_stats(np.array(v["fact_accuracy_matrix"])[0], len_rollout)
    print("all_acc", all_acc)

    all_spec, f_spec_diags = _upper_triangle_stats(np.array(v["fact_specificity_matrix"])[0], len_rollout)
    print("all_spec", all_spec)

    all_sel, f_sel_diags = _upper_triangle_stats(np.array(v["fact_selectivity_matrix"])[0], len_rollout)
    print("all_sel", all_sel)

    # Overall fact score: harmonic mean of accuracy, specificity and selectivity.
    total_all_score = hmean(
        np.array(
            [
                all_acc,
                all_spec,
                all_sel
            ]
        )
    )
    print("total all score", total_all_score)
    print("---")

    # NOTE(review): the per-offset score below is an arithmetic mean of the three metrics,
    # while total_all_score above is their harmonic mean. Left unchanged because the value
    # is stored on `data`; confirm which mean is intended.
    f_score_diags = np.mean(np.vstack([f_acc_diags, f_spec_diags, f_sel_diags]), 0)

    data[k]["fact_score_retention"] = f_score_diags

    # F1 is the harmonic mean of precision and recall (as in total_refusal_score); the
    # previous arithmetic mean over-reported it whenever the two differed.
    refusal_f1_over_time = hmean(
        np.vstack([v["refusal_precision_over_time"], v["refusal_recall_over_time"]]), axis=0
    )[:len_rollout]
    print(refusal_f1_over_time)

    row = [
        v["model_name"],
        total_all_score*scale,
        all_acc*scale,
        all_spec*scale,
        all_sel*scale,
        v['format_accuracy'][0]*scale,
        v['format_selectivity'][0]*scale,
        total_refusal_score*scale,
        v['refusal_precision'][0]*scale,
        v['refusal_recall'][0]*scale,
    ]

    rows.append(row)

# Column names, in the same order as the entries of each `row` above.
columns = [
    "model_name",
    "total_all_score",
    "all_fact_accuracy",
    "all_fact_specificity",
    "all_fact_selectivity",
    'format_accuracy',
    'format_selectivity',
    'total_refusal_score',
    'refusal_precision',
    'refusal_recall',
]

results_df = pd.DataFrame(rows, columns=columns)
results_df.round(1)

In [None]:
# List the column names of the results table as this cell's output.
results_df.columns

In [None]:
# Columns of results_df that go into the ablation table, in display order.
ablation_cols = [
    'model_name',
    'total_all_score',
    'all_fact_selectivity',
    'format_accuracy',
    'total_refusal_score',
]

# Models compared in the ablation table (row order follows results_df, not this list).
ablation_comp_models = ["MemLLM", "GNM", "GNM (Ablation)"]

# Printed header for each selected column, paired by position with ablation_cols.
_ablation_headers = dict(zip(
    ablation_cols,
    ["model_name", "Facts (Overall)", "Facts (Sel.)", "Format (Acc.)", "Refusal (F1)"],
))

_in_ablation = results_df["model_name"].isin(ablation_comp_models)
ablation_df = results_df.loc[_in_ablation, ablation_cols].rename(columns=_ablation_headers)
ablation_df

In [None]:
# Print the ablation table as LaTeX: no index column, floats formatted with one decimal.
print(ablation_df.to_latex(index=False, float_format="%.1f"))