In [None]:
import pickle as pkl
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy.stats import hmean
import json
import torch
import pandas as pd
import seaborn as sns

# Global plot styling: seaborn-like matplotlib style plus the 10-colour "deep" palette.
plt.style.use('seaborn-v0_8')
pal = sns.color_palette("deep", 10)
# Colour per model, addressable by display name ("ICL-FT") or evaluation key ("llama3_ft").
# "Base" has no entry: the plotting helpers draw it in black and never look it up here.
# NOTE(review): "RAG-FT" (pal[5]) and "rag_trained" (pal[6]) map to different palette
# slots, unlike the other key/display-name pairs which share one — confirm this is intended.
model_colors = {
    "Ablation": pal[7],
    "ICL-FT": pal[8],
    "llama3_ft": pal[8],
    "RAG-FT": pal[5],
    "rag_trained": pal[6],
    "MemLLM": pal[9],
    "GNM": pal[0],
    "gnm": pal[0]
}

In [None]:
def diag_means(A):
    """Return the mean of every diagonal of a square matrix.

    The result has ``2 * n - 1`` entries ordered by diagonal offset, from
    ``-(n - 1)`` (bottom-left corner) through ``0`` (main diagonal) up to
    ``n - 1`` (top-right corner); index ``n - 1`` is therefore the main diagonal.

    Raises ``AssertionError`` when the input is not square.
    """
    mat = np.asarray(A)
    n_rows, n_cols = mat.shape
    assert n_rows == n_cols, "expect square matrix"

    means = []
    for offset in range(1 - n_rows, n_rows):
        means.append(np.diagonal(mat, offset=offset).mean())
    return np.array(means)


In [None]:
# Evaluation key (directory name under saved_evals) -> display name, in processing order.
# The untuned variants "llama3" ("ICL") and "rag_untrained" ("RAG") are currently disabled.
models = dict(
    naive_model="Base",
    llama3_ft="ICL-FT",
    rag_trained="RAG-FT",
    memoryllm="MemLLM",
    gnm="GNM",
)

model_keys = [*models]
model_names = [*models.values()]

In [None]:
# model_colors = {
#     "Base": pal[3],
#     "ICL-FT": pal[1],
#     "RAG-FT": pal[2],
#     "MemLLM": pal[4],
#     "GNM": pal[0],
# }

In [None]:
# Show the evaluation keys and their display names side by side as a quick sanity check.
model_keys, model_names

In [None]:
# Load every model's evaluation summary and tag it with its display name.
# `data` maps evaluation key -> parsed summary.json contents (+ "model_name").
data = {}

for model_key, model_name in models.items():
    summary_path = f"../saved_evals/warmup_seq_10/{model_key}/summary.json"

    with open(summary_path, 'r') as summary_file:
        summary = json.load(summary_file)

    summary["model_name"] = model_name
    data[model_key] = summary

In [None]:
# len_rollout = 10
# step_0 = len_rollout-1

# old_measure_step = 5

def _upper_triangle_mean(summary, key):
    """Mean over the upper triangle (main diagonal included) of ``summary[key][0]``."""
    mat = np.array(summary[key])[0]
    return np.mean(mat[np.triu_indices_from(mat, k=0)])


def _mean_or_zero(series):
    """Mean over every element of ``series``, or 0 when it holds no elements.

    ``np.size`` is used rather than ``len`` so that a nested-but-empty series
    such as ``[[]]`` (the layout the plotting cells index with ``[0]``) yields
    0 instead of NaN plus a RuntimeWarning.
    """
    return np.mean(series) if np.size(series) > 0 else 0


# Every reported value is multiplied by this factor (the figures label them as percent).
scale = 100

rows = []

for v in data.values():
    all_acc = _upper_triangle_mean(v, "fact_accuracy_matrix")
    all_spec = _upper_triangle_mean(v, "fact_specificity_matrix")
    all_sel = _upper_triangle_mean(v, "fact_selectivity_matrix")

    # Overall score: harmonic mean of the three aggregate metrics.
    total_all_score = hmean(np.array([all_acc, all_spec, all_sel]))

    # Per-step series; a summary without them falls back to ten zeros.
    fact_retention = v.get("fact_retention", [0] * 10)
    fact_accuracy_over_time = v.get("fact_accuracy_over_time", [0] * 10)

    rows.append([
        v["model_name"],
        total_all_score * scale,
        all_acc * scale,
        all_spec * scale,
        all_sel * scale,
        _mean_or_zero(fact_retention) * scale,
        _mean_or_zero(fact_accuracy_over_time) * scale,
    ])


columns = [
    "model_name",
    "total_all_score",
    "all_fact_accuracy",
    "all_fact_specificity",
    "all_fact_selectivity",
    "fact_retention",
    "fact_accuracy_over_time",
]

# One row per model; the rounded view below is for display only, results_df keeps full precision.
results_df = pd.DataFrame(rows, columns=columns)
results_df.round(1)


In [None]:
# List the result-table column names; the plotting cells select metrics by these names.
results_df.columns

In [None]:
# Models shown in the figures: evaluation key -> display name, in figure order.
plot_models = dict(
    naive_model="Base",
    memoryllm="MemLLM",
    llama3_ft="ICL-FT",
    rag_trained="RAG-FT",
    gnm="GNM",
)

# NOTE: despite its name this list holds the evaluation keys, not the display names.
plot_model_names = [*plot_models]

# Groups of results_df columns, each followed by the labels intended for them.
plot_cols_1 = ['total_all_score', 'all_fact_accuracy', 'all_fact_specificity', 'all_fact_selectivity']
pc1_names = ["Score", "Accuracy", "Specificity", "Selectivity"]

plot_cols_2 = ['fact_retention']
pc2_names = ["Retention"]

plot_cols_3 = ['fact_accuracy_over_time']
pc3_names = ["Accuracy"]


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- display order: every plotted model, and the subset that is drawn as bars/coloured lines ---
model_order = [*plot_models.values()]
model_order_bars = [name for name in model_order if name != "Base"]


def grouped_bar_panel(ax, df, metrics, model_order, model_order_bars, tick_names, title=None, ylim=None, show_base_lines=True):
    """
    Draw a grouped bar chart on ``ax``: one group per metric, one bar per model.

    ax               -- axes to draw on.
    df               -- table with a "model_name" column and one column per metric.
    metrics          -- column names of ``df``, one bar group each (left to right).
    model_order      -- display names to keep from ``df``; must include "Base"
                        when ``show_base_lines`` is true.
    model_order_bars -- display names drawn as bars, in within-group order;
                        each needs an entry in ``model_colors`` and must not be empty.
    tick_names       -- x tick labels, one per metric.
    title, ylim      -- optional axes title and ``(low, high)`` y limits.
    show_base_lines  -- draw the "Base" row as a dashed black line across each
                        group instead of as a bar, and add it to the legend.
    """
    # Restrict the table to the requested models and index it by display name.
    table = df.loc[df["model_name"].isin(model_order)].copy()
    table = table.set_index("model_name").reindex(model_order)

    group_width = 0.8                                # share of each unit slot taken by a group
    centers = np.arange(len(metrics))                # x position of every metric group
    bar_w = group_width / len(model_order_bars)
    left_edges = centers - group_width/2

    for slot, name in enumerate(model_order_bars):
        ax.bar(
            left_edges + (slot + 0.5)*bar_w,
            table.loc[name, metrics].values,
            width=bar_w,
            label=name,
            color=model_colors[name],
            alpha=1.0,
        )

    if show_base_lines:
        base_style = dict(color='black', linestyle='--', linewidth=2, dashes=(5, 2))
        half = group_width/2
        # One dashed segment per metric, spanning the full width of its bar group.
        for pos, level in enumerate(table.loc["Base", metrics].values):
            ax.plot([pos - half, pos + half], [level, level], **base_style)

        # Empty line whose only purpose is to give "Base" a legend entry.
        ax.plot([], [], label="Base", **base_style)

    ax.set_xticks(centers)
    ax.set_xticklabels(tick_names, size=16)

    if title:
        ax.set_title(title, fontsize=16)
    if ylim is not None:
        ax.set_ylim(*ylim)


def line_chart_panel(ax, data_dict, metric_name, model_order, title=None, ylim=None, xlabel="Steps Since Seen"):
    """
    Line chart showing metric values over time steps for each model.

    ax          -- axes to draw on.
    data_dict   -- mapping of model key -> summary dict.  Each summary must hold
                   "model_name"; the series is read from ``summary[metric_name][0]``
                   (the series is stored nested inside an outer list).
    metric_name -- key of the series to plot; models without it are skipped.
    model_order -- display names to plot; other models are ignored.
    title, ylim -- optional axes title and ``(low, high)`` y limits.
    xlabel      -- x axis label.

    Values are multiplied by 100 before plotting.  "Base" is drawn as a dashed
    black line; every other model uses its colour from ``model_colors``.
    X ticks cover the longest plotted series; when nothing is plotted the
    ticks are left untouched instead of raising.
    """
    longest = 0  # length of the longest series plotted so far

    for model_vals in data_dict.values():
        model_name = model_vals["model_name"]
        if model_name not in model_order:
            continue

        # Get the metric array (nested in a list); tolerate a missing or empty entry.
        nested = model_vals.get(metric_name, [[]])
        if len(nested) == 0:
            continue
        metric_values = np.array(nested[0]) * 100

        if len(metric_values) == 0:
            continue

        x = np.arange(len(metric_values))
        longest = max(longest, len(metric_values))

        if model_name == "Base":
            # Dashed line for Base
            ax.plot(x, metric_values, "--o", label=model_name,
                    color='black', linewidth=2, markersize=6, dashes=(5, 2))
        else:
            # Solid, fully opaque lines for all other models
            ax.plot(x, metric_values, "-o", label=model_name,
                    color=model_colors[model_name], linewidth=2, markersize=6, alpha=1.0)

    if title:
        ax.set_title(title, fontsize=16)
    if ylim is not None:
        ax.set_ylim(*ylim)

    ax.set_xlabel(xlabel, fontsize=14)
    # Ticks are derived from the longest series rather than from whichever
    # model happened to be plotted last (or from nothing at all).
    if longest > 0:
        ax.set_xticks(np.arange(longest))
    ax.grid(axis='y', alpha=0.3)


# --- build the figure ---
# Layout is handled by fig.tight_layout() below.  Do not also request
# constrained_layout here: tight_layout() replaces that engine anyway and
# matplotlib warns that the figure layout has changed to tight.
fig, (ax1, ax2, ax3) = plt.subplots(
    1, 3,
    width_ratios=(0.5, 0.25, 0.25),
    figsize=(13, 3), sharey=False
)

# Left: aggregate metrics as grouped bars, Base as dashed reference lines.
grouped_bar_panel(
    ax1,
    results_df,
    plot_cols_1,
    model_order,
    model_order_bars,
    pc1_names,
    title="Performance Over All Time Steps",
    ylim=(0, 100)
)

# Middle: stored "fact_retention" series per model.
line_chart_panel(
    ax2,
    data,
    "fact_retention",
    model_order,
    title="Retention",
    ylim=(15, 100),
    xlabel="Steps Since Seen"
)

# Right: stored "fact_accuracy_over_time" series per model.
line_chart_panel(
    ax3,
    data,
    "fact_accuracy_over_time",
    model_order,
    title="Accuracy Stability",
    ylim=(70, 100),
    xlabel="Total Documents Seen"
)

ax1.set_ylabel("Percent (%)", fontsize=14)
ax2.set_ylabel("Percent (%)", fontsize=14)

# Single shared legend below the panels, taken from the bar panel
# (which includes the dashed "Base" entry).
handles, labels = ax1.get_legend_handles_labels()
fig.legend(handles, labels, loc="center left", bbox_to_anchor=(0.16, -0.04), ncol=5, fontsize=16)
fig.tight_layout()

plt.savefig("../plots/warmup_bar.png", dpi=600, bbox_inches="tight")
plt.show()


In [None]:
# Alternative version using harmonic mean of all three metrics (accuracy, specificity, selectivity)

# First, compute the over_time and retention metrics for all three metrics
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import hmean

# Harmonic-mean "score" series per model, keyed by evaluation key.
# Each entry is {"model_name": display name, "values": 1-D array}.
retention_data = {}
over_time_data = {}


def _hmean_series(acc, spec, sel):
    """Element-wise harmonic mean of three series, over the indices of ``acc``."""
    return np.array([hmean([acc[i], spec[i], sel[i]]) for i in range(len(acc))])


for model_key, model_vals in data.items():
    model_name = model_vals["model_name"]

    specificity_matrix = np.array(model_vals["fact_specificity_matrix"])[0]
    selectivity_matrix = np.array(model_vals["fact_selectivity_matrix"])[0]

    # "Over time": the stored accuracy series combined with the main diagonal
    # of the specificity and selectivity matrices.
    accuracy_over_time = np.array(model_vals.get("fact_accuracy_over_time", [[]])[0])
    over_time_hmean = _hmean_series(
        accuracy_over_time,
        np.diag(specificity_matrix),
        np.diag(selectivity_matrix),
    )

    # Retention: the stored accuracy series combined with the per-diagonal means
    # of the two matrices.  diag_means() orders offsets from -(n-1) to n-1, so
    # slicing from index n-1 keeps the main diagonal and the offsets above it.
    accuracy_retention = np.array(model_vals.get("fact_retention", [[]])[0])
    retention_hmean = _hmean_series(
        accuracy_retention,
        diag_means(specificity_matrix)[len(specificity_matrix)-1:],
        diag_means(selectivity_matrix)[len(selectivity_matrix)-1:],
    )

    retention_data[model_key] = {"model_name": model_name, "values": retention_hmean}
    over_time_data[model_key] = {"model_name": model_name, "values": over_time_hmean}

def _plot_score_lines(ax, series_by_model):
    """Plot one line per model of ``model_order`` from a harmonic-mean series dict.

    ``series_by_model`` maps a key to {"model_name": ..., "values": array};
    values are multiplied by 100 before plotting.  "Base" is a dashed black
    line, GNM a fully opaque coloured line, the other models slightly faded.
    X ticks cover the longest plotted series and are left alone when nothing
    was plotted.
    """
    longest = 0
    for entry in series_by_model.values():
        name = entry["model_name"]
        if name not in model_order:
            continue

        percent = entry["values"] * 100
        steps = np.arange(len(percent))
        longest = max(longest, len(percent))

        if name == "Base":
            ax.plot(steps, percent, "--o", label=name,
                    color='black', linewidth=2, markersize=6, dashes=(5, 2))
        else:
            ax.plot(steps, percent, "-o", label=name,
                    color=model_colors[name], linewidth=2, markersize=6,
                    alpha=1.0 if name == "GNM" else 0.7)

    if longest > 0:
        ax.set_xticks(np.arange(longest))
    ax.grid(axis='y', alpha=0.3)


# --- build the figure ---
# Layout is handled by fig.tight_layout() below.  Do not also request
# constrained_layout here: tight_layout() replaces that engine anyway and
# matplotlib warns that the figure layout has changed to tight.
fig, (ax1, ax2, ax3) = plt.subplots(
    1, 3,
    width_ratios=(0.5, 0.25, 0.25),
    figsize=(13, 3), sharey=False
)

# First panel: Same bar chart as before
grouped_bar_panel(
    ax1,
    results_df,
    plot_cols_1,
    model_order,
    model_order_bars,
    pc1_names,
    title="Performance Over All Time Steps",
    ylim=(0, 100)
)

# Second panel: Line chart of retention (harmonic mean)
_plot_score_lines(ax2, retention_data)
ax2.set_title("Score Retention", fontsize=16)
ax2.set_xlabel("Steps Since Seen", fontsize=14)
ax2.set_ylabel("Percent (%)", fontsize=14)
ax2.set_ylim(30, 100)

# Third panel: Line chart of over_time (harmonic mean)
_plot_score_lines(ax3, over_time_data)
ax3.set_title("Score Stability", fontsize=16)
ax3.set_xlabel("Documents Seen", fontsize=14)
ax3.set_ylim(30, 100)

ax1.set_ylabel("Percent (%)", fontsize=14)

# legend once (outside), taken from the bar panel
handles, labels = ax1.get_legend_handles_labels()
fig.legend(handles, labels, loc="center left", bbox_to_anchor=(0.16, -0.04), ncol=5, fontsize=16)
fig.tight_layout()

plt.savefig("../plots/warmup_bar_hmean.png", dpi=600, bbox_inches="tight")
plt.show()


In [None]:
# Third version: Bar chart showing specific time steps from retention data

import numpy as np
import matplotlib.pyplot as plt

# Snapshot of the harmonic-mean retention series at indices 0, 4 and 8
# (in percent); an index beyond the end of a series is reported as 0.
snapshot_indices = {"most_recent": 0, "4_steps_ago": 4, "8_steps_ago": 8}

retention_snapshots = {}
for metric_data in retention_data.values():
    values = metric_data["values"] * 100
    retention_snapshots[metric_data["model_name"]] = {
        label: (values[idx] if len(values) > idx else 0)
        for label, idx in snapshot_indices.items()
    }

# One row per model, in display order, ready for grouped_bar_panel.
snapshot_rows = [
    [model_name, *retention_snapshots[model_name].values()]
    for model_name in model_order
    if model_name in retention_snapshots
]

snapshot_df = pd.DataFrame(snapshot_rows, columns=["model_name", "most_recent", "4_steps_ago", "8_steps_ago"])
# --- build the figure ---
# Layout is handled by fig.tight_layout() below.  Do not also request
# constrained_layout here: tight_layout() replaces that engine anyway and
# matplotlib warns that the figure layout has changed to tight.
fig, (ax1, ax2) = plt.subplots(
    1, 2,
    width_ratios=(0.6, 0.4),
    figsize=(13, 3), sharey=False
)

# First panel: Same bar chart as before
grouped_bar_panel(
    ax1,
    results_df,
    plot_cols_1,
    model_order,
    model_order_bars,
    pc1_names,
    title="Total Performance",
    ylim=(0, 100)
)

# Second panel: Bar chart with retention at specific time steps
# (no Base reference lines on this panel)
grouped_bar_panel(
    ax2,
    snapshot_df,
    ["most_recent", "4_steps_ago", "8_steps_ago"],
    model_order,
    model_order_bars,
    ["Most Recent", "4 Steps Ago", "8 Steps Ago"],
    title="Performance By Recency",
    ylim=(0, 100),
    show_base_lines=False
)

ax1.set_ylabel("Percent (%)", fontsize=14)
# ax2.set_ylabel("Percent (%)", fontsize=14)

# legend once (outside), taken from the first panel so it includes "Base"
handles, labels = ax1.get_legend_handles_labels()
fig.legend(handles, labels, loc="center left", bbox_to_anchor=(0.24, -0.04), ncol=5, fontsize=16)
fig.tight_layout()

plt.savefig("../plots/warmup_bar_snapshots.png", dpi=600, bbox_inches="tight")
plt.show()


In [None]:
# Fourth version: snapshot bar charts with 95% binomial confidence-interval error bars

import numpy as np
import matplotlib.pyplot as plt

# Half-width of a normal-approximation confidence interval for a proportion
def binomial_ci(p, n, z=1.96):
    """Return ``z * sqrt(p * (1 - p) / n)`` for proportion ``p`` and sample size ``n``.

    With the default ``z=1.96`` this is the half-width of a 95% interval.
    ``p`` is a proportion (0-1), not a percentage.
    """
    return z * np.sqrt(p * (1 - p) / n)

# Sample counts used as n for the binomial error bars
n_episodes = 216
n_upper_triangle = 55  # entries in the upper triangle (diagonal included) of a 10x10 matrix
n_total_performance = n_episodes * n_upper_triangle  # 11880

# Snapshot label -> index into the retention series.  The labels are internal
# column names; index 4 / 9 are shown on the figure as "4 / 9 Steps Ago".
snapshot_indices = {"most_recent": 0, "5_steps_ago": 4, "10_steps_ago": 9}

# Diagonal `idx` of a 10x10 matrix has 10 - idx entries, each seen once per
# episode: 2160, 1296 and 216 samples respectively.
n_samples_retention = {
    label: (10 - idx) * n_episodes
    for label, idx in snapshot_indices.items()
}

# Snapshot of the harmonic-mean retention series (in percent) at those
# indices; an index beyond the end of a series is reported as 0.
retention_snapshots = {}
for metric_data in retention_data.values():
    values = metric_data["values"] * 100
    retention_snapshots[metric_data["model_name"]] = {
        label: (values[idx] if len(values) > idx else 0)
        for label, idx in snapshot_indices.items()
    }

# One row per model, in display order.
snapshot_rows = [
    [model_name, *retention_snapshots[model_name].values()]
    for model_name in model_order
    if model_name in retention_snapshots
]

snapshot_df = pd.DataFrame(snapshot_rows, columns=["model_name", "most_recent", "5_steps_ago", "10_steps_ago"])

# Variant of grouped_bar_panel that adds binomial error bars to every bar
def grouped_bar_panel_with_ci(ax, df, metrics, model_order, model_order_bars, tick_names, 
                               n_samples_dict, title=None, ylim=None, show_base_lines=True):
    """
    Grouped bar chart (one group per metric, one bar per model) with error bars.

    ax               -- axes to draw on.
    df               -- table with a "model_name" column and one column per
                        metric; values are percentages.
    metrics          -- column names of ``df``, one bar group each.
    model_order      -- display names to keep from ``df``; must include "Base"
                        when ``show_base_lines`` is true.
    model_order_bars -- display names drawn as bars; each needs an entry in
                        ``model_colors`` and the list must not be empty.
    tick_names       -- x tick labels, one per metric.
    n_samples_dict   -- sample count per metric name (dict), or a single
                        number applied to every metric.
    title, ylim      -- optional axes title and ``(low, high)`` y limits.
    show_base_lines  -- draw the "Base" row as dashed black lines across each
                        group and add it to the legend.

    Error bars are ``binomial_ci(value / 100, n) * 100``.
    """
    # Restrict the table to the requested models and index it by display name.
    table = df.loc[df["model_name"].isin(model_order)].copy()
    table = table.set_index("model_name").reindex(model_order)

    group_width = 0.8
    centers = np.arange(len(metrics))
    bar_w = group_width / len(model_order_bars)
    left_edges = centers - group_width/2

    for slot, name in enumerate(model_order_bars):
        heights = table.loc[name, metrics].values

        # Resolve the sample count of every metric, then convert each bar
        # height to a proportion, get its CI half-width and scale it back.
        if isinstance(n_samples_dict, dict):
            counts = [n_samples_dict[metric] for metric in metrics]
        else:
            counts = [n_samples_dict for _ in metrics]
        cis = [binomial_ci(height / 100, count) * 100 for height, count in zip(heights, counts)]

        ax.bar(
            left_edges + (slot + 0.5)*bar_w,
            heights,
            width=bar_w,
            label=name,
            color=model_colors[name],
            alpha=1.0,
            yerr=cis,
            capsize=0,
            error_kw={'linewidth': 1.5},
        )

    if show_base_lines:
        base_style = dict(color='black', linestyle='--', linewidth=2, dashes=(5, 2))
        half = group_width/2
        # One dashed segment per metric, spanning the full width of its bar group.
        for pos, level in enumerate(table.loc["Base", metrics].values):
            ax.plot([pos - half, pos + half], [level, level], **base_style)

        # Empty line whose only purpose is to give "Base" a legend entry.
        ax.plot([], [], label="Base", **base_style)

    ax.set_xticks(centers)
    ax.set_xticklabels(tick_names, size=16)

    if title:
        ax.set_title(title, fontsize=16)
    if ylim is not None:
        ax.set_ylim(*ylim)

# --- build the figure ---
# Layout is handled by fig.tight_layout() below.  Do not also request
# constrained_layout here: tight_layout() replaces that engine anyway and
# matplotlib warns that the figure layout has changed to tight.
fig, (ax1, ax2) = plt.subplots(
    1, 2,
    width_ratios=(0.6, 0.4),
    figsize=(13, 2.5), sharey=False
)

# First panel: Bar chart with CI (all metrics use n_total_performance)
grouped_bar_panel_with_ci(
    ax1,
    results_df,
    plot_cols_1,
    model_order,
    model_order_bars,
    pc1_names,
    n_samples_dict=n_total_performance,  # 11880 samples
    title="(C) Total Performance",
    ylim=(0, 100)
)

# Second panel: Bar chart with retention at specific time steps (different n per metric).
# The column names are internal labels; the tick labels give the actual series offsets.
grouped_bar_panel_with_ci(
    ax2,
    snapshot_df,
    ["most_recent", "5_steps_ago", "10_steps_ago"],
    model_order,
    model_order_bars,
    ["Most Recent", "4 Steps Ago", "9 Steps Ago"],
    n_samples_dict=n_samples_retention,  # {2160, 1296, 216}
    title="(D) Performance By Recency",
    ylim=(0, 100),
    show_base_lines=False
)

ax1.set_ylabel("Percent (%)", fontsize=14)

# legend to the right of the panels, entries stacked vertically in a single column
handles, labels = ax1.get_legend_handles_labels()
fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1.0, 0.5), ncol=1, fontsize=14)
fig.tight_layout()

plt.savefig("../plots/warmup_bar_snapshots_10.png", dpi=600, bbox_inches="tight")
plt.show()

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# # --- choose the model display order / labels ---
# # If results_df["model_name"] already contains display names like "Base", "MetaICL", etc:
# model_order = list(plot_models.values())  # ["Base","MetaICL","RAG-FT","MemLLM","Ours"] in your chosen order
# # If instead results_df uses internal keys like "naive_model", then swap to:
# # model_order = list(plot_models.keys())
# model_order_bars = [m for m in model_order if m != "Base"]

# fig, (ax_left, ax_right) = plt.subplots(
#     1, 2, 
#     width_ratios=(0.7, 0.3),
#     figsize=(13, 3), 
#     sharey=False, 
#     constrained_layout=True
# )

# # Left panel: Bar chart for metrics averaged across time steps 0-7
# # This should match averaging the line values on the right panel
# rows_8steps = []
# curve_len_8 = 10

# for k, v in data.items():
#     # Use full matrix and compute step_0_offset dynamically
#     fa_mat = np.array(v["fact_accuracy_matrix"])[0]
#     mat_size = fa_mat.shape[0]
#     step_0_offset = mat_size - 1
#     # Get diagonal means for each time step (0-7 steps ago)
#     f_acc_diags = diag_means(fa_mat)[step_0_offset:step_0_offset+curve_len_8]
#     all_acc_8 = np.mean(f_acc_diags)  # Average across the 8 time steps
    
#     fs_mat = np.array(v["fact_specificity_matrix"])[0]
#     f_spec_diags = diag_means(fs_mat)[step_0_offset:step_0_offset+curve_len_8]
#     all_spec_8 = np.mean(f_spec_diags)
    
#     fsel_mat = np.array(v["fact_selectivity_matrix"])[0]
#     f_sel_diags = diag_means(fsel_mat)[step_0_offset:step_0_offset+curve_len_8]
#     all_sel_8 = np.mean(f_sel_diags)
    
#     # Score is the average of the three metrics (same as on right panel)
#     metric_values = np.mean(np.vstack([f_acc_diags, f_spec_diags, f_sel_diags]), 0)
#     total_score_8 = np.mean(metric_values)  # Average across the 8 time steps
    
#     row_8 = [
#         v["model_name"],
#         total_score_8 * 100,
#         all_acc_8 * 100,
#         all_spec_8 * 100,
#         all_sel_8 * 100,
#     ]
#     rows_8steps.append(row_8)

# results_df_8 = pd.DataFrame(rows_8steps, columns=['model_name', 'Score', 'Accuracy', 'Specificity', 'Selectivity'])

# # Plot bars on left panel
# metrics_8 = ['Score', 'Accuracy', 'Specificity', 'Selectivity']
# sub = results_df_8[results_df_8["model_name"].isin(model_order)].copy()
# sub = sub.set_index("model_name").reindex(model_order)

# n_models = len(model_order_bars)
# n_metrics = len(metrics_8)

# x_bar = np.arange(n_metrics)
# group_width = 0.8
# bar_w = group_width / n_models

# for i, m in enumerate(model_order_bars):
#     offsets = x_bar - group_width/2 + (i + 0.5)*bar_w
#     alpha = 1.0 if m == "GNM" else 0.7
#     ax_left.bar(offsets, sub.loc[m, metrics_8].values, width=bar_w, label=m, 
#                 color=model_colors[m], alpha=alpha)

# # Draw horizontal dashed lines for Base model
# base_values_8 = sub.loc["Base", metrics_8].values
# for i, val in enumerate(base_values_8):
#     x_left = i - group_width/2
#     x_right = i + group_width/2
#     ax_left.plot([x_left, x_right], [val, val], 
#                  color='black', linestyle='--', linewidth=2, dashes=(5, 2))

# ax_left.plot([], [], color='black', linestyle='--', linewidth=2, dashes=(5, 2), label="Base")

# ax_left.set_xticks(x_bar)
# ax_left.set_xticklabels(metrics_8, size=16)
# ax_left.set_title("Performance Across 8 Time Steps", fontsize=16)
# ax_left.set_ylim(0, 100)
# ax_left.set_ylabel("Percent (%)", fontsize=14)

# # Right panel: Line chart showing score over time
# x_line = np.arange(curve_len_8)

# # Plot non-Base models first (to match legend order from bars)
# for model_name in model_order_bars:
#     # Find the corresponding key in data
#     k = [key for key, val in data.items() if val["model_name"] == model_name][0]
#     v = data[k]
    
#     fa_mat = np.array(v["fact_accuracy_matrix"])[0]
#     mat_size = fa_mat.shape[0]
#     step_0_offset = mat_size - 1
#     f_acc_diags = diag_means(fa_mat)[step_0_offset:step_0_offset+curve_len_8]
    
#     fs_mat = np.array(v["fact_specificity_matrix"])[0]
#     f_spec_diags = diag_means(fs_mat)[step_0_offset:step_0_offset+curve_len_8]
    
#     fsel_mat = np.array(v["fact_selectivity_matrix"])[0]
#     f_sel_diags = diag_means(fsel_mat)[step_0_offset:step_0_offset+curve_len_8]
    
#     metric_values = np.mean(np.vstack([f_acc_diags, f_spec_diags, f_sel_diags]), 0)
    
#     ax_right.plot(x_line, metric_values * 100, "-o", label=model_name, 
#                   color=model_colors[model_name], linewidth=2, markersize=6)

# # Plot Base model last with dashed line
# k_base = [key for key, val in data.items() if val["model_name"] == "Base"][0]
# v_base = data[k_base]

# fa_mat = np.array(v_base["fact_accuracy_matrix"])[0]
# mat_size = fa_mat.shape[0]
# step_0_offset = mat_size - 1
# f_acc_diags = diag_means(fa_mat)[step_0_offset:step_0_offset+curve_len_8]

# fs_mat = np.array(v_base["fact_specificity_matrix"])[0]
# f_spec_diags = diag_means(fs_mat)[step_0_offset:step_0_offset+curve_len_8]

# fsel_mat = np.array(v_base["fact_selectivity_matrix"])[0]
# f_sel_diags = diag_means(fsel_mat)[step_0_offset:step_0_offset+curve_len_8]
# metric_values_base = np.mean(np.vstack([f_acc_diags, f_spec_diags, f_sel_diags]), 0)

# ax_right.plot(x_line, metric_values_base * 100, "--o", label="Base", 
#               color='black', linewidth=2, markersize=6, dashes=(5, 2))

# ax_right.set_title("Score Over Time", fontsize=16)
# ax_right.set_xticks(list(x_line))
# ax_right.set_xlabel("Steps Since Seen", fontsize=14)
# ax_right.set_ylim(50, 100)
# ax_right.set_ylabel("Score (%)", fontsize=14)
# ax_right.grid(axis='y', alpha=0.3)

# # Single legend for the entire figure
# handles, labels = ax_left.get_legend_handles_labels()
# fig.legend(handles, labels, loc="center left", bbox_to_anchor=(0.16, -0.04), ncol=5, fontsize=16)

# plt.savefig("../plots/warmup_combined.png", dpi=600, bbox_inches="tight")
# plt.show()

In [None]:
# # Create a figure with 4 line charts showing performance over time
# curve_len = 8
# metrics = ['score', 'accuracy', 'specificity', 'selectivity']
# metric_titles = ['Score', 'Accuracy', 'Specificity', 'Selectivity']

# fig, axs = plt.subplots(1, len(metrics), figsize=(13, 3.5))

# x = np.arange(curve_len)

# # Calculate metrics for all time steps
# for metric_idx, metric_name in enumerate(metrics):
#     for k, v in data.items():
#         # Calculate diagonal means for each metric
#         if metric_name == 'score':
#             # Score is average of the three metrics
#             fa_mat = np.array(v["fact_accuracy_matrix"])[0][:len_rollout, :len_rollout]
#             f_acc_diags = diag_means(fa_mat)[step_0:][:curve_len]
            
#             fs_mat = np.array(v["fact_specificity_matrix"])[0][:len_rollout, :len_rollout]
#             f_spec_diags = diag_means(fs_mat)[step_0:][:curve_len]
            
#             fs_mat = np.array(v["fact_selectivity_matrix"])[0][:len_rollout, :len_rollout]
#             f_sel_diags = diag_means(fs_mat)[step_0:][:curve_len]
            
#             metric_values = np.mean(np.vstack([f_acc_diags, f_spec_diags, f_sel_diags]), 0)
#         elif metric_name == 'accuracy':
#             fa_mat = np.array(v["fact_accuracy_matrix"])[0][:len_rollout, :len_rollout]
#             metric_values = diag_means(fa_mat)[step_0:][:curve_len]
#         elif metric_name == 'specificity':
#             fs_mat = np.array(v["fact_specificity_matrix"])[0][:len_rollout, :len_rollout]
#             metric_values = diag_means(fs_mat)[step_0:][:curve_len]
#         elif metric_name == 'selectivity':
#             fs_mat = np.array(v["fact_selectivity_matrix"])[0][:len_rollout, :len_rollout]
#             metric_values = diag_means(fs_mat)[step_0:][:curve_len]
        
#         # Plot the line
#         axs[metric_idx].plot(x, metric_values * 100, "--o", label=v["model_name"], color=model_colors[v["model_name"]])
    
#     axs[metric_idx].set_title(metric_titles[metric_idx], fontsize=14)
#     axs[metric_idx].set_xticks(list(x))
#     axs[metric_idx].set_xlabel("Steps Since Seen", fontsize=14)
#     axs[metric_idx].set_ylim(0, 100)

# # Add legend to last subplot
# axs[-1].legend(fontsize=12, loc="lower left")

# # Set y-label on first subplot
# axs[0].set_ylabel("% Performance", fontsize=14)
# fig.suptitle("Memory Retention Over Time", y=0.94, fontsize=16)

# fig.tight_layout()

# plt.savefig("../plots/warmup_bar_by_time.png", dpi=600, bbox_inches="tight")
# plt.show()
