In [None]:
import pickle as pkl
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.stats import hmean
import json
import torch
import pandas as pd
import seaborn as sns
# Shared figure theme, plus a fixed model -> colour assignment drawn from
# seaborn's 10-colour "deep" palette so every figure uses the same colours.
plt.style.use('seaborn-v0_8')
pal = sns.color_palette("deep", 10)
model_colors = dict(zip(
    ("Ablation", "ICL-FT", "RAG-FT", "MemLLM", "GNM"),
    (pal[7], pal[8], pal[6], pal[9], pal[0]),
))

# Wilson score interval CI function - returns symmetric error for visual consistency
def wilson_ci(p, n, z=1.96):
    """Symmetric error-bar half-width (in percentage points) from a Wilson score interval.

    The Wilson interval is generally not centred on the observed proportion, so
    the distances from ``p`` to the lower and upper bounds differ.  To keep the
    error bars visually symmetric, the larger of the two distances is returned.

    Parameters
    ----------
    p : float or array-like
        Observed proportion on a 0-100 (percent) scale.
    n : int or array-like
        Number of samples behind each proportion; must be > 0.  Broadcasts
        against ``p``.
    z : float, default 1.96
        Normal quantile of the interval (1.96 gives a 95% interval).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Half-width in percentage points; a scalar for scalar inputs, otherwise
        an array with the broadcast shape of ``p`` and ``n``.

    Raises
    ------
    ValueError
        If any ``n`` is not positive or any ``p`` lies outside [0, 100].
    """
    p_prop = np.asarray(p, dtype=float) / 100  # Convert percentage to proportion
    n_obs = np.asarray(n, dtype=float)

    # Fail loudly instead of dividing by zero or propagating NaN into the plot
    if np.any(n_obs <= 0):
        raise ValueError("n must be positive")
    if np.any((p_prop < 0) | (p_prop > 1)):
        raise ValueError("p must be a percentage in [0, 100]")

    z_sq = z**2
    denom = 1 + z_sq / n_obs
    center = (p_prop + z_sq / (2 * n_obs)) / denom
    margin = z * np.sqrt(p_prop * (1 - p_prop) / n_obs + z_sq / (4 * n_obs**2)) / denom

    lower = center - margin
    upper = center + margin

    # Distances from the observed value to each bound, clipped at 0 so that
    # floating-point round-off at p = 0 or p = 100 cannot give a negative error
    lower_err = np.maximum(0, (p_prop - lower) * 100)
    upper_err = np.maximum(0, (upper - p_prop) * 100)

    # Use the larger of the two for symmetric display
    return np.maximum(lower_err, upper_err)

# Evaluation sample sizes per metric; these are the `n` passed to wilson_ci
n_samples = {
    'Fact Score': 5940,   # 11,880 * 0.5
    'Fact Sel.': 5940,    # 11,880 * 0.5
    'Format Acc.': 216,
    'Refusal F1': 5940    # same as facts
}

# Scores (percent) copied from the results table, one entry per model in the
# order MemLLM, GNM (Ablation), GNM
ablation_data = {
    'Model': ['MemLLM', 'GNM (Ablation)', 'GNM'],
    'Fact Score': [38.7, 44.5, 84.1],
    'Fact Sel.': [19.2, 24.8, 77.4],
    'Format Acc.': [0.0, 0.0, 67.4],
    'Refusal F1': [0.1, 51.0, 93.2]
}
ablation_df = pd.DataFrame(ablation_data)

# All four metrics are drawn, including Format Acc. (which is 0 for two models)
metrics = ['Fact Score', 'Fact Sel.', "Format Acc.", 'Refusal F1']
metric_labels = list(metrics)

fig, ax = plt.subplots(figsize=(6.5, 2.5))

x = np.arange(len(metrics))
bar_width = 0.35

# Row positions in ablation_df: 0 = MemLLM, 1 = GNM (Ablation), 2 = GNM
memllm_values = [ablation_df.at[0, m] for m in metrics]
gnm_ablation_values = [ablation_df.at[1, m] for m in metrics]
gnm_values = [ablation_df.at[2, m] for m in metrics]

# Wilson-interval half-widths, one per metric, for the two bar series
gnm_cis = [wilson_ci(score, n_samples[name]) for name, score in zip(metrics, gnm_values)]
ablation_cis = [wilson_ci(score, n_samples[name]) for name, score in zip(metrics, gnm_ablation_values)]

# Grouped bars: GNM to the left of each tick, the ablation to the right
bars_gnm = ax.bar(
    x - bar_width/2, gnm_values, bar_width,
    yerr=gnm_cis, capsize=0, error_kw={'linewidth': 1.5},
    label='GNM', color=model_colors["GNM"], alpha=1.0,
)
bars_ablation = ax.bar(
    x + bar_width/2, gnm_ablation_values, bar_width,
    yerr=ablation_cis, capsize=0, error_kw={'linewidth': 1.5},
    label='GNM (Ablation)', color=model_colors["Ablation"], alpha=1.0,
)

# Value labels sit 2 points above the top of each error bar
label_style = dict(ha='center', va='bottom', fontsize=10, fontweight='normal', color='black')
for bars, values, cis in (
    (bars_gnm, gnm_values, gnm_cis),
    (bars_ablation, gnm_ablation_values, ablation_cis),
):
    for bar, value, ci in zip(bars, values, cis):
        ax.text(bar.get_x() + bar.get_width()/2, value + ci + 2, f'{value:.0f}', **label_style)

# MemLLM is shown as a dashed reference line spanning both bars of each group;
# the empty plot call only creates the matching legend entry
memllm_style = dict(color='black', linestyle='--', linewidth=2, dashes=(5, 2))
for i, val in enumerate(memllm_values):
    ax.plot([i - bar_width, i + bar_width], [val, val], **memllm_style)
ax.plot([], [], label='MemLLM', **memllm_style)

# Axes cosmetics; the y-limit leaves headroom above 100 for the value labels
ax.set_ylim(0, 110)
ax.set_ylabel('Percent (%)', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(metric_labels, fontsize=12)
ax.grid(axis='y', alpha=0.3)
ax.legend(fontsize=12, loc='upper center', bbox_to_anchor=(0.5, 1.22), frameon=False, ncol=3)

plt.tight_layout()
plt.savefig("../plots/ablation_comparison.png", dpi=600, bbox_inches="tight")
plt.show()

In [None]:
# Two-panel variant of the ablation figure.
#   left  ("Memory Ablation"): the grouped bars of the previous figure, without
#          error bars or value labels; the ablation is drawn as a faded GNM bar
#          and a red arrow marks the drop from the GNM value to the ablation value
#   right: mean target-distractor anticorrelation for the ablation vs. GNM
# Reuses metrics, metric_labels, gnm_values, gnm_ablation_values and
# memllm_values from the previous cell.
fig, axes = plt.subplots(1, 2, figsize=(6.5, 3.5), gridspec_kw={'width_ratios': [4, 1]})
ax1, ax2 = axes

# ---- Left panel -------------------------------------------------------------
x = np.arange(len(metrics))
bar_width = 0.35
gnm_color = model_colors['GNM']

bars_gnm = ax1.bar(x - bar_width/2, gnm_values, bar_width,
                   label='GNM', color=gnm_color, alpha=1.0)
bars_ablation = ax1.bar(x + bar_width/2, gnm_ablation_values, bar_width,
                        label='GNM (Ablation)', color=gnm_color, alpha=0.5)

# Dashed MemLLM reference line across each group
memllm_style = dict(color='black', linestyle='--', linewidth=2, dashes=(5, 2))
for i, val in enumerate(memllm_values):
    ax1.plot([i - bar_width, i + bar_width], [val, val], **memllm_style)

# Downward arrow over each ablation bar: starts 2 points below the GNM value
# and ends 2 points above the ablation value
arrow_style = dict(arrowstyle='->', color='red', lw=2.5)
for bar, gnm_val, ablation_val in zip(bars_ablation, gnm_values, gnm_ablation_values):
    arrow_x = bar.get_x() + bar.get_width() / 2
    ax1.annotate('', xy=(arrow_x, ablation_val + 2), xytext=(arrow_x, gnm_val - 2),
                 arrowprops=dict(arrow_style))

# Empty plot call only creates the MemLLM legend entry
ax1.plot([], [], label='MemLLM', **memllm_style)

ax1.set_title('Memory Ablation', fontsize=14)
ax1.set_ylabel('Percent (%)', fontsize=12)
ax1.set_xticks(x)
ax1.set_xticklabels(metric_labels, fontsize=12)
ax1.set_ylim(0, 100)
ax1.grid(axis='y', alpha=0.3)
ax1.legend(fontsize=12, loc='upper center', bbox_to_anchor=(0.4, 1.03), frameon=False, ncol=1)

# ---- Right panel ------------------------------------------------------------
# Hard-coded values, plotted as positive percentages, from the "layers 8-12 only" run:
#   GNM:      -0.1290 ± 0.0425 (95% CI)
#   Ablation: -0.0798 ± 0.0176 (95% CI)
# (An earlier version of this panel used 1.1602431757193781 for the ablation
#  and 4.5258498757097206 for GNM.)
# One bar call per model so each bar gets its own alpha.
for tick_label, height, opacity in (('Ablation', 7.98, 0.5), ('GNM', 12.90, 1.0)):
    ax2.bar([tick_label], [height], color=gnm_color, alpha=opacity)

ax2.set_title('Target-Distractor \n Anticorrelation', fontsize=14)
ax2.set_ylabel('Mean Anticorrelation (%)', fontsize=12)
ax2.tick_params(axis='both', labelsize=12)
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig("../plots/ablation_comparison_with_orthogonality.png", dpi=600, bbox_inches="tight")
plt.show()

In [None]:
# Memory Selectivity Index Chart (Line + Bar)
import json
import numpy as np
import matplotlib.pyplot as plt

# plt.style.use('seaborn-v0_8')
# pal = plt.rcParams['axes.prop_cycle'].by_key()['color']

plt.style.use('seaborn-v0_8')
pal = sns.color_palette("deep", 10)
# Each palette slot is registered under the display name and, where one exists,
# an alternate lower-case key, so either spelling resolves to the same colour.
model_colors = {
    alias: pal[slot]
    for slot, aliases in (
        (7, ("Ablation",)),
        (8, ("ICL-FT", "llama3_ft")),
        (6, ("RAG-FT", "rag_trained")),
        (9, ("MemLLM",)),
        (0, ("GNM", "gnm")),
    )
    for alias in aliases
}

# Read the per-layer memory-selectivity statistics
with open("../saved_evals/analysis_final/memory_selectivity_index.json", "r") as f:
    selectivity_data = json.load(f)

# Per-layer means, and per-layer spreads.
# NOTE: despite the "std" in the key and variable names, the stored spread
# values are already SEM (std / sqrt(N)), NOT raw std.
gnm_mean, abl_mean, gnm_std, abl_std = (
    np.array(selectivity_data[key])
    for key in (
        "gnm_mean_selectivity_per_layer",
        "abl_mean_selectivity_per_layer",
        "gnm_std_selectivity_per_layer",
        "abl_std_selectivity_per_layer",
    )
)

# Pre-computed mid-layer averages, used as the bar heights
gnm_mid_layer_mean = selectivity_data["gnm_mid_layer_mean"]
abl_mid_layer_mean = selectivity_data["abl_mid_layer_mean"]

# 95% CI band for the line chart: SEM * 1.96 (2.576 would give a 99% band)
z_95 = 1.96
gnm_ci = z_95 * gnm_std
abl_ci = z_95 * abl_std

# Bar-chart error bars cover layers 15-29 (15 layers).
# NOTE(review): the stored *_mid_layer_mean values were reportedly computed with
# slice(15, -1); that equals slice(15, 30) only if there are 31 per-layer
# entries — confirm against the analysis script.
mid_layers = slice(15, 30)
# Approximation: the SE of the mid-layer mean is taken as the average of the
# per-layer SEMs in that range
gnm_mid_se = np.mean(gnm_std[mid_layers])
abl_mid_se = np.mean(abl_std[mid_layers])
gnm_mid_ci = z_95 * gnm_mid_se
abl_mid_ci = z_95 * abl_mid_se

# x positions for the line chart: one per layer entry
layers = np.arange(gnm_mean.shape[0])

# Two-panel figure: per-layer target alignment (left) and its mid-layer
# summary (right).  Uses the arrays, CIs and `mid_layers` slice prepared above.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6.5, 3.0), gridspec_kw={'width_ratios': [3.5, 1]})

# Left subplot: per-layer mean with a shaded 95% CI band for each model
ax1.plot(layers, gnm_mean, 'o-', color=model_colors["GNM"], label='GNM', markersize=4, linewidth=1.5, alpha=1.0)
ax1.fill_between(layers, gnm_mean - gnm_ci, gnm_mean + gnm_ci, color=model_colors["GNM"], alpha=0.2)

ax1.plot(layers, abl_mean, 's-', color=model_colors["Ablation"], label='GNM (Ablation)', markersize=4, linewidth=1.5, alpha=1.0)
ax1.fill_between(layers, abl_mean - abl_ci, abl_mean + abl_ci, color=model_colors["Ablation"], alpha=0.2)

# Zero reference: positive values are target aligned, negative distractor aligned
ax1.axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.7)
ax1.set_xlabel('Layer', fontsize=12)
ax1.set_ylabel('Target Alignment', fontsize=12)
ax1.set_title('Memory Update Direction By Layer \n(more positive = more target aligned)', fontsize=12)
ax1.legend(fontsize=11, loc='upper left', frameon=True)
ax1.set_xlim(0, 30)
ax1.tick_params(axis='both', labelsize=12)

# Right subplot: mid-layer means, one bar call per model so each gets its own colour
ax2.bar(['Ablation'], [abl_mid_layer_mean], color=model_colors["Ablation"], alpha=1.0)
ax2.bar(['GNM'], [gnm_mid_layer_mean], color=model_colors["GNM"], alpha=1.0)
# Error bars are drawn separately so they stay black; capthick=0 hides the caps
ax2.errorbar(['Ablation', 'GNM'], [abl_mid_layer_mean, gnm_mid_layer_mean],
             yerr=[abl_mid_ci, gnm_mid_ci], fmt='none',
             capsize=5, capthick=0, elinewidth=1.5, color='black')

ax2.axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.7)
ax2.set_ylabel('Target Alignment', fontsize=12)
# Build the layer range from the slice itself so the label cannot drift from
# the data: slice(15, 30) is end-exclusive and therefore covers layers 15-29
# (the title previously read "Layers 15-30").
mid_first, mid_last = mid_layers.start, mid_layers.stop - 1
ax2.set_title(f'Memory Update Direction\n(Layers {mid_first}-{mid_last})', fontsize=12)
ax2.tick_params(axis='both', labelsize=12)

plt.tight_layout()
plt.savefig("../plots/memory_selectivity_index.png", dpi=600, bbox_inches="tight")
plt.show()

In [None]:
import pickle as pkl
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.stats import hmean
import json
import pandas as pd

# From here on `pal` is the active style's colour cycle (a list of colours),
# replacing the seaborn "deep" palette used by the figures above.
plt.style.use('seaborn-v0_8')
pal = [cycle_entry['color'] for cycle_entry in plt.rcParams['axes.prop_cycle']]

In [None]:
def diag_means(A):
    """Return the mean of every diagonal of a square matrix.

    The result has 2n-1 entries for an n x n input, ordered by diagonal offset
    from -(n-1) (bottom-left corner) through 0 (main diagonal, at index n-1)
    to n-1 (top-right corner).  Raises AssertionError if the matrix is not
    square.
    """
    mat = np.asarray(A)
    rows, cols = mat.shape
    assert rows == cols, "expect square matrix"

    means = []
    for offset in range(1 - rows, rows):
        means.append(np.diagonal(mat, offset=offset).mean())
    return np.array(means)


In [None]:
# Experiment key -> display name, in the order the models appear in the tables
models = dict(
    memoryllm="MemLLM",
    gnm_ablation="GNM (Ablation)",
    gnm="GNM",
)

# Parallel lists of keys and display names (same order as `models`)
model_keys = [*models]
model_names = [*models.values()]

In [None]:
# Display both lists to confirm that keys and display names line up
model_keys, model_names

In [None]:
# Colours for this part of the notebook, keyed by lower-case model key and
# indexed out of the current `pal`.
# NOTE(review): there is no entry for "gnm_ablation" although `models` contains
# that key — confirm nothing looks it up here.
model_colors = dict(
    llama3_ft=pal[1],
    rag_trained=pal[2],
    memoryllm=pal[4],
    gnm=pal[0],
)

In [None]:
# Load each model's summary.json and tag it with the model's display name
data = {}

for model_key, model_name in models.items():
    summary_path = f"../data/gnm_experiments/mixed_documents/{model_key}/summary.json"

    with open(summary_path, 'r') as json_file:
        summary = json.load(json_file)

    # Keep the display name next to the metrics so later cells need no second lookup
    summary["model_name"] = model_name
    data[model_key] = summary

In [None]:
len_rollout = 10
step_0 = len_rollout-1  # index of the main diagonal in diag_means output for a len_rollout-sized matrix
scale = 100  # scores are reported as percentages


def _upper_triangle_summary(matrix):
    """Summarise one fact-metric matrix over its upper triangle.

    Takes the first entry along the leading axis of ``matrix`` and crops it to
    at most ``len_rollout`` x ``len_rollout``.
    NOTE(review): the meaning of the leading axis is not visible here — confirm.

    Returns a tuple ``(overall, per_offset)``: ``overall`` is the mean of all
    entries on or above the main diagonal, and ``per_offset`` holds the mean of
    each diagonal from offset 0 upwards.
    """
    sub = np.array(matrix)[0][:len_rollout, :len_rollout]
    overall = np.mean(sub[np.triu_indices_from(sub, k=0)])
    # diag_means puts offset 0 at index n-1; use the actual matrix size rather
    # than len_rollout so a smaller matrix still starts at the main diagonal
    per_offset = diag_means(sub)[sub.shape[0] - 1:]
    return overall, per_offset


rows = []

for k, v in data.items():

    # Refusal F1 = harmonic mean of precision and recall
    total_refusal_score = hmean(
        np.array([
            v['refusal_precision'][0],
            v['refusal_recall'][0],
        ])
    )

    #######################
    ### All Upper Triangle
    #######################

    print(k)

    all_acc, f_acc_diags = _upper_triangle_summary(v["fact_accuracy_matrix"])
    print("all_acc", all_acc)

    all_spec, f_spec_diags = _upper_triangle_summary(v["fact_specificity_matrix"])
    print("all_spec", all_spec)

    all_sel, f_sel_diags = _upper_triangle_summary(v["fact_selectivity_matrix"])
    print("all_sel", all_sel)

    # Overall fact score: harmonic mean of accuracy, specificity and selectivity
    total_all_score = hmean(np.array([all_acc, all_spec, all_sel]))
    print("total all score", total_all_score)
    print("---")

    # Per-offset fact score stored back on the model's record.
    # NOTE(review): this is an arithmetic mean while total_all_score above is a
    # harmonic mean — confirm which is intended before comparing the two.
    f_score_diags = np.mean(np.vstack([f_acc_diags, f_spec_diags, f_sel_diags]), 0)
    data[k]["fact_score_retention"] = f_score_diags

    # F1 per step: harmonic (not arithmetic) mean of precision and recall, to
    # match how total_refusal_score is computed above
    refusal_f1_over_time = hmean(
        np.vstack([v["refusal_precision_over_time"], v["refusal_recall_over_time"]]),
        axis=0,
    )[:len_rollout]
    print(refusal_f1_over_time)

    rows.append([
        v["model_name"],
        total_all_score*scale,
        all_acc*scale,
        all_spec*scale,
        all_sel*scale,
        v['format_accuracy'][0]*scale,
        v['format_selectivity'][0]*scale,
        total_refusal_score*scale,
        v['refusal_precision'][0]*scale,
        v['refusal_recall'][0]*scale,
    ])

# Column names, in the same order as the entries of each row above
columns = [
    "model_name",
    "total_all_score",
    "all_fact_accuracy",
    "all_fact_specificity",
    "all_fact_selectivity",
    'format_accuracy',
    'format_selectivity',
    'total_refusal_score',
    'refusal_precision',
    'refusal_recall',
]

results_df = pd.DataFrame(rows, columns=columns)
results_df.round(1)

In [None]:
# List the available result columns before selecting the ablation subset
results_df.columns

In [None]:
# Columns and models that go into the ablation comparison table
ablation_cols = [
    'model_name',
    'total_all_score',
    'all_fact_selectivity',
    'format_accuracy',
    'total_refusal_score',
]
ablation_comp_models = ["MemLLM", "GNM", "GNM (Ablation)"]

# Keep only the chosen models (rows stay in results_df order, not list order)
# and the chosen columns, then switch to the headers used in the table
is_ablation_model = results_df["model_name"].isin(ablation_comp_models)
ablation_df = results_df.loc[is_ablation_model, ablation_cols]
ablation_df.columns = ["model_name", "Facts (Overall)", "Facts (Sel.)", "Format (Acc.)", "Refusal (F1)"]
ablation_df

In [None]:
# Render the ablation table as LaTeX: no index column, one decimal place
latex_table = ablation_df.to_latex(index=False, float_format="%.1f")
print(latex_table)