In [1]:
import os
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

In [6]:
# Input/output locations for this notebook.
# The defaults are the original machine-specific paths; set the environment
# variables below to run the notebook elsewhere without editing this cell.
CHECKPOINT_FOLDER = os.environ.get(
    "PERF_BY_STEP_CHECKPOINT_FOLDER",
    "/home/morg/students/gottesman3/knowledge-analysis-suite/performance_by_step/",
)
OUTPUT_FOLDER = os.environ.get(
    "PERF_BY_STEP_OUTPUT_FOLDER",
    "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/performance_by_step_plots",
)
# Create the plot directory up front so later savefig calls cannot fail on a
# missing folder; exist_ok makes re-running the cell harmless.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

In [3]:
def load_checkpoints(folder):
    """Load every ``*.json`` file directly inside ``folder``.

    Args:
        folder: Path of the directory to scan (not recursive).

    Returns:
        dict mapping each file name without its ``.json`` extension to the
        parsed JSON content. Entries are inserted in lexicographic file-name
        order so repeated runs process checkpoints in the same order.

    Raises:
        OSError: If ``folder`` or one of its JSON files cannot be read.
        json.JSONDecodeError: If a ``.json`` file is not valid JSON.
    """
    checkpoints = {}
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(folder)):
        if not filename.endswith(".json"):
            continue
        checkpoint_name = os.path.splitext(filename)[0]
        # Decode explicitly as UTF-8 instead of relying on the locale default.
        with open(os.path.join(folder, filename), "r", encoding="utf-8") as f:
            checkpoints[checkpoint_name] = json.load(f)
    return checkpoints

In [43]:
def extract_entity_data(checkpoint_data):
    """Flatten one checkpoint's per-entity stats into a DataFrame.

    Args:
        checkpoint_data: Mapping of entity id to a stats mapping. The keys
            read from each stats mapping are ``"questions"``, ``"correct"``,
            ``"occurences"`` and ``"last_occurence"`` (spelled as in the data).

    Returns:
        pandas.DataFrame with one row per entity that has at least one
        question and the columns ``entity_id``, ``accuracy`` (``correct /
        questions``), ``questions``, ``correct``, ``occurences`` and
        ``last_occurence``. The last two are passed through unchanged and are
        ``None`` when absent. The columns are present even when no entity
        qualifies, so callers can index them on an empty result.
    """
    columns = [
        "entity_id",
        "accuracy",
        "questions",
        "correct",
        "occurences",
        "last_occurence",
    ]
    records = []
    for entity_id, stats in checkpoint_data.items():
        # `or 0` also covers an explicit null in the JSON, which would
        # otherwise fail the numeric comparison below.
        questions = stats.get("questions") or 0
        correct = stats.get("correct") or 0

        # Entities without questions have no defined accuracy; skip them.
        if questions <= 0:
            continue

        records.append({
            "entity_id": entity_id,
            "accuracy": correct / questions,
            "questions": questions,
            "correct": correct,
            "occurences": stats.get("occurences", None),
            "last_occurence": stats.get("last_occurence", None),
        })
    # Passing `columns` keeps the schema stable when `records` is empty.
    return pd.DataFrame(records, columns=columns)

In [None]:
def plot_accuracy_by_field(df, field, checkpoint_name, output_folder=None):
    """Bin entities by ``field`` and save a bar chart of per-bin accuracy.

    Per-bin accuracy is pooled: total correct answers divided by total
    questions over all entities in the bin. Each bar is annotated with the
    number of entities it contains.

    Args:
        df: DataFrame with the columns ``entity_id``, ``accuracy``,
            ``questions``, ``correct`` and ``field``.
        field: Column to bin on. ``"occurences"`` is split into 10
            equal-width bins; any other field into up to 10 quantile bins.
        checkpoint_name: Used in the plot title and the output file name.
        output_folder: Directory for the PNG. Defaults to the notebook-level
            ``OUTPUT_FOLDER`` when ``None``.

    Returns:
        None. Prints a message and returns early when there is nothing to
        plot or the values cannot be binned; otherwise writes
        ``{checkpoint_name}_binned_accuracy_by_{field}.png``.
    """
    # A frame built from zero records may have no columns at all; treat a
    # missing column the same as "no data" instead of raising KeyError.
    required_columns = {field, "accuracy", "questions", "correct", "entity_id"}
    if df.empty or not required_columns.issubset(df.columns):
        print(f"⚠️ No valid data for '{field}' in {checkpoint_name}, skipping plot.")
        return

    filtered_df = df[df[field].notnull()].copy()

    # Keep only rows whose accuracy is a valid fraction (NaN is dropped too).
    filtered_df = filtered_df[filtered_df["accuracy"].between(0, 1)]

    if filtered_df.empty:
        print(f"⚠️ No valid data for '{field}' in {checkpoint_name}, skipping plot.")
        return

    # Choose binning strategy: equal-width for counts, quantiles otherwise.
    try:
        if field == "occurences":
            filtered_df["bin"] = pd.cut(filtered_df[field], bins=10)
        else:
            filtered_df["bin"] = pd.qcut(filtered_df[field], q=10, duplicates="drop")
    except (ValueError, TypeError):
        # ValueError: degenerate bin edges; TypeError: non-numeric values.
        print(f"⚠️ Could not bin '{field}' in {checkpoint_name}. Skipping.")
        return

    # Pool questions/correct per bin and count the entities in each bin.
    # observed=True drops empty bins so every row corresponds to one bar.
    grouped = (
        filtered_df
        .groupby("bin", observed=True)
        .agg(
            total_questions=("questions", "sum"),
            total_correct=("correct", "sum"),
            entity_count=("entity_id", "count")
        )
        .reset_index()
    )
    grouped["accuracy"] = grouped["total_correct"] / grouped["total_questions"]
    grouped["bin_label"] = grouped["bin"].astype(str)  # Interval -> str for plotting

    if output_folder is None:
        output_folder = OUTPUT_FOLDER
    plot_path = os.path.join(output_folder, f"{checkpoint_name}_binned_accuracy_by_{field}.png")

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.barplot(x="bin_label", y="accuracy", data=grouped, ax=ax)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_title(f"{checkpoint_name}: Total Accuracy by Binned {field} (n={len(filtered_df)})")
        ax.set_ylabel("Bin Accuracy (Total Correct / Total Questions)")
        ax.set_xlabel(f"{field.replace('_', ' ').title()} Bin Ranges")
        ax.set_ylim(0, 1)
        ax.grid(True, linestyle="--", alpha=0.5)

        # Annotate entity count on each bar; bars sit at x = 0, 1, 2, ...
        # in the same order as the rows of `grouped`.
        for position, (accuracy, entity_count) in enumerate(
            zip(grouped["accuracy"], grouped["entity_count"])
        ):
            ax.text(
                position,
                accuracy + 0.03,
                f"{entity_count} ents",
                ha="center",
                va="bottom",
                fontsize=9,
            )

        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # loops over many checkpoints do not accumulate open figures.
        plt.close(fig)
    print(f"✅ Saved plot: {plot_path}")


In [47]:
# Render both binned-accuracy plots for every checkpoint found on disk.
checkpoints = load_checkpoints(CHECKPOINT_FOLDER)
# Iterate in sorted name order so the log and outputs are reproducible.
for checkpoint_name in sorted(checkpoints):
    checkpoint_data = checkpoints[checkpoint_name]
    print(f"📊 Processing checkpoint: {checkpoint_name}")
    df = extract_entity_data(checkpoint_data)
    if df.empty:
        # A frame with no rows may also lack the columns the plots index,
        # so skip it rather than letting one checkpoint abort the whole run.
        print(f"⚠️ No entities with questions in {checkpoint_name}, skipping.")
        continue
    for field in ("occurences", "last_occurence"):
        plot_accuracy_by_field(df, field, checkpoint_name)

📊 Processing checkpoint: checkpoint_2


KeyError: 'bin'

<Figure size 1000x600 with 0 Axes>