In [None]:
import json
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from datasets import load_dataset


In [None]:
# Inputs/outputs for the subject_chunks analysis.
# NOTE(review): machine-specific absolute paths; CHECKPOINT_FOLDER is not read by
# the visible plotting loop (which uses SHARED_CHECKPOINT_FOLDER).
CHECKPOINT_FOLDER = (
    "/home/morg/students/gottesman3/knowledge-analysis-suite/"
    "performance_by_step/"
)
OUTPUT_FOLDER = (
    "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/"
    "performance_by_step_plots/subject_chunks"
)

In [None]:
SHARED_CHECKPOINT_FOLDER = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/outputs"
SHARED_OUTPUT_FOLDER = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/performance_by_step_plots/shared_chunks"


In [None]:
def load_checkpoints(folder):
    """Load every ``*.json`` file in ``folder`` into a dict keyed by file stem.

    Files are visited in natural sort order (``checkpoint_2`` before
    ``checkpoint_10``), so the returned dict -- and any loop over it -- has a
    deterministic order instead of ``os.listdir``'s arbitrary one.

    Args:
        folder: Directory containing the checkpoint JSON files. Non-``.json``
            entries are ignored.

    Returns:
        dict mapping the file name without its extension to the parsed JSON.
    """
    import re

    def natural_key(name):
        # re.split with a capture group alternates text / digit runs, so odd
        # positions are always digit runs: compare those numerically. The name
        # itself breaks ties (e.g. "ckpt_01" vs "ckpt_1") deterministically.
        parts = re.split(r"(\d+)", name)
        return ([int(p) if i % 2 else p for i, p in enumerate(parts)], name)

    json_files = sorted(
        (name for name in os.listdir(folder) if name.endswith(".json")),
        key=natural_key,
    )

    checkpoints = {}
    for filename in json_files:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(os.path.join(folder, filename), "r", encoding="utf-8") as f:
            data = json.load(f)
        checkpoints[os.path.splitext(filename)[0]] = data
    return checkpoints

In [191]:
def extract_entity_data(checkpoint_data):
    """Flatten one checkpoint's per-entity stats into a DataFrame.

    Args:
        checkpoint_data: Mapping of entity id to a stats dict. The keys read are
            ``questions``, ``correct``, ``occurences`` and ``last_occurence``;
            all are optional.

    Returns:
        DataFrame with one row per entity that has at least one question, and
        columns ``entity_id``, ``accuracy`` (correct / questions), ``questions``,
        ``correct``, ``occurences``, ``last_occurence``. Entities whose
        ``questions`` is missing, ``None`` or 0 are skipped because accuracy is
        undefined for them. The columns exist even when no entity qualifies,
        so callers can index them without a KeyError.
    """
    columns = ["entity_id", "accuracy", "questions", "correct", "occurences", "last_occurence"]
    records = []
    for entity_id, stats in checkpoint_data.items():
        # `or 0` also covers an explicit None, which would break the `> 0` test.
        questions = stats.get("questions") or 0
        if questions <= 0:
            continue
        correct = stats.get("correct", 0)
        records.append({
            "entity_id": entity_id,
            "accuracy": correct / questions,
            "questions": questions,
            "correct": correct,
            "occurences": stats.get("occurences"),
            "last_occurence": stats.get("last_occurence"),
        })
    return pd.DataFrame(records, columns=columns)

In [200]:
def plot_accuracy_by_field(df, field, checkpoint_name, output_folder, n_bins=10):
    """Plot pooled accuracy per equal-width bin of ``field`` and save it as a PNG.

    Rows where ``field`` is null are dropped. The remaining values are split
    into ``n_bins`` equal-width bins spanning ``[0, max(field)]``. Each bin's
    accuracy is ``sum(correct) / sum(questions)``, or 0 for an empty bin. Each
    bar is annotated with its entity count.

    Args:
        df: Frame with ``field``, ``questions``, ``correct`` and ``entity_id``
            columns (e.g. the output of ``extract_entity_data``).
        field: Column to bin on, e.g. ``"occurences"`` or ``"last_occurence"``.
        checkpoint_name: Used in the plot title and the output file name.
        output_folder: Directory the PNG is written to (created if missing).
        n_bins: Number of equal-width bins (default 10).

    Returns:
        DataFrame with columns ``checkpoint``, ``bin_label``, ``accuracy`` and
        ``entity_count`` (one row per bin, empty bins included). Returns
        ``None`` if the data could not be binned.
    """
    skip_msg = f"Could not bin '{field}' in {checkpoint_name}. Skipping."
    if not {field, "questions", "correct", "entity_id"}.issubset(df.columns):
        print(skip_msg)
        return None

    filtered_df = df[df[field].notnull()].copy()
    if filtered_df.empty:
        print(skip_msg)
        return None

    try:
        max_val = filtered_df[field].max()
        # n_bins equal-width bins from 0 to the max value (n_bins + 1 edges).
        bin_edges = np.linspace(0, max_val, num=n_bins + 1)
        filtered_df["bin"] = pd.cut(
            filtered_df[field], bins=bin_edges, include_lowest=True
        )
    except ValueError:
        # pd.cut rejects duplicate / decreasing edges, e.g. when max_val <= 0.
        print(skip_msg)
        return None

    grouped = (
        filtered_df
        .groupby("bin", observed=False)  # keep empty bins so every plot has n_bins bars
        .agg(
            total_questions=("questions", "sum"),
            total_correct=("correct", "sum"),
            entity_count=("entity_id", "count"),
        )
        .reset_index()
    )
    # Vectorised pooled accuracy; bins with no questions get 0 instead of NaN.
    questions = grouped["total_questions"]
    grouped["accuracy"] = (grouped["total_correct"] / questions.where(questions > 0)).fillna(0.0)
    grouped["bin_label"] = grouped["bin"].astype(str)
    grouped["checkpoint"] = checkpoint_name

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.barplot(x="bin_label", y="accuracy", data=grouped, ax=ax)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_title(f"{checkpoint_name}: Accuracy by {field} (n={len(filtered_df)})")
        ax.set_ylabel("Accuracy (Total Correct / Total Questions)")
        ax.set_xlabel(field.replace("_", " ").title())
        # Fixed y-limits (presumably to keep plots comparable across checkpoints);
        # last_occurence uses a tighter 0-0.3 range.
        ax.set_ylim(0, 0.3 if field == "last_occurence" else 1.11)
        ax.grid(True, linestyle="--", alpha=0.5)

        # Annotate each bar with its entity count; bar i is drawn at x == i.
        for pos, (acc, count) in enumerate(zip(grouped["accuracy"], grouped["entity_count"])):
            ax.text(pos, acc + 0.03, f"n={count}", ha="center", va="bottom", fontsize=9)

        fig.tight_layout()
        os.makedirs(output_folder, exist_ok=True)
        plot_path = os.path.join(output_folder, f"{checkpoint_name}_binned_accuracy_by_{field}.png")
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Saved plot: {plot_path}")
    return grouped[["checkpoint", "bin_label", "accuracy", "entity_count"]]


In [199]:
def plot_occurences_of_last_seen_entities(df, checkpoint_name, output_folder, n_bins=10):
    """Plot the occurrence-count distribution of the "last-seen" entities and save it as a PNG.

    ``last_occurence`` is split into ``n_bins`` equal-width bins over
    ``[0, max(last_occurence)]``. The entities in the highest bin are then
    counted per fixed, left-closed ``occurences`` range: ``[0, 10)``,
    ``[10, 100)``, ``[100, 1000)`` and ``[1000, inf)``.

    Args:
        df: Frame with ``entity_id``, ``occurences`` and ``last_occurence``
            columns (e.g. the output of ``extract_entity_data``). Rows where
            either count is null are ignored.
        checkpoint_name: Used in the plot title and the output file name.
        output_folder: Directory the PNG is written to (created if missing).
        n_bins: Number of equal-width ``last_occurence`` bins (default 10).

    Returns:
        DataFrame with columns ``checkpoint``, ``bin_label`` and
        ``entity_count``, with one row per occurrence range including empty
        ranges. Returns ``None`` if ``last_occurence`` could not be binned.
    """
    skip_msg = f"Could not bin last_occurence in {checkpoint_name}, skipping."
    if not {"entity_id", "occurences", "last_occurence"}.issubset(df.columns):
        print(skip_msg)
        return None

    filtered_df = df[df["last_occurence"].notnull() & df["occurences"].notnull()].copy()
    if filtered_df.empty:
        print(skip_msg)
        return None

    try:
        max_val = filtered_df["last_occurence"].max()
        bin_edges = np.linspace(0, max_val, num=n_bins + 1)
        filtered_df["last_bin"] = pd.cut(
            filtered_df["last_occurence"], bins=bin_edges, include_lowest=True
        )
    except ValueError:
        # pd.cut rejects duplicate / decreasing edges, e.g. when max_val <= 0.
        print(skip_msg)
        return None

    # Keep only the entities whose last_occurence lands in the highest bin.
    last_bin = filtered_df["last_bin"].cat.categories[-1]
    last_seen_df = filtered_df[filtered_df["last_bin"] == last_bin].copy()

    # Fixed, left-closed occurrence ranges (right=False).
    occur_edges = [0, 10, 100, 1000, float("inf")]
    occur_labels = ["0-10", "10-100", "100-1000", "1000+"]
    last_seen_df["occur_bin"] = pd.cut(
        last_seen_df["occurences"],
        bins=occur_edges,
        labels=occur_labels,
        include_lowest=True,
        right=False,
    )

    # observed=False keeps empty ranges so every checkpoint's plot shows the same
    # four bars (consistent with plot_accuracy_by_field keeping empty bins).
    grouped = (
        last_seen_df
        .groupby("occur_bin", observed=False)
        .agg(entity_count=("entity_id", "count"))
        .reset_index()
    )
    grouped["bin_label"] = grouped["occur_bin"].astype(str)
    grouped["checkpoint"] = checkpoint_name

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.barplot(x="bin_label", y="entity_count", data=grouped, ax=ax)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_title(f"{checkpoint_name}: Occurence Bins of Last-Seen Entities")
        ax.set_ylabel("Entity Count")
        ax.set_xlabel("Occurences Bin")
        ax.grid(True, linestyle="--", alpha=0.5)
        fig.tight_layout()

        os.makedirs(output_folder, exist_ok=True)
        plot_path = os.path.join(output_folder, f"{checkpoint_name}_last_seen_entities_by_occurences.png")
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Saved last-seen occurence bin plot: {plot_path}")
    return grouped[["checkpoint", "bin_label", "entity_count"]]

In [204]:
# Render every per-checkpoint plot for the shared_chunks run, and keep the entity
# table of "checkpoint_final" for the comparison cells further down.
checkpoints = load_checkpoints(SHARED_CHECKPOINT_FOLDER)
# Bind final_df up front: otherwise, when no final checkpoint exists, later cells
# would silently reuse a stale value from a previous run (or hit a NameError).
final_df = None
for checkpoint_name, checkpoint_data in checkpoints.items():
    print(f"Processing checkpoint: {checkpoint_name}")
    df = extract_entity_data(checkpoint_data)
    if checkpoint_name == "checkpoint_final":
        final_df = df.copy()
    plot_accuracy_by_field(df, "occurences", checkpoint_name, SHARED_OUTPUT_FOLDER)
    plot_accuracy_by_field(df, "last_occurence", checkpoint_name, SHARED_OUTPUT_FOLDER)
    plot_occurences_of_last_seen_entities(df, checkpoint_name, SHARED_OUTPUT_FOLDER)

if final_df is None:
    print(f"Warning: no 'checkpoint_final' found in {SHARED_CHECKPOINT_FOLDER}; final_df is None.")

Processing checkpoint: checkpoint_1
Saved plot: /home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/performance_by_step_plots/shared_chunks/checkpoint_1_binned_accuracy_by_occurences.png
Saved plot: /home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/performance_by_step_plots/shared_chunks/checkpoint_1_binned_accuracy_by_last_occurence.png
Saved last-seen occurence bin plot: /home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/performance_by_step_plots/shared_chunks/checkpoint_1_last_seen_entities_by_occurences.png
Processing checkpoint: checkpoint_2
Saved plot: /home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/performance_by_step_plots/shared_chunks/checkpoint_2_binned_accuracy_by_occurences.png
Saved plot: /home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/performance_by_step_plots/shared_chunks/checkpoint_2_binned_accuracy_by_last_occurence.png
Saved last-seen occurence bin plot: /home/joberant/NLP_242

In [None]:
# PopQA-KAS questions; keep one row per subject URI with its subject_num_chunks.
ds = load_dataset("dhgottesman/popqa-kas", split="train")
popqa_questions_df = ds.to_pandas()

# subject_chunks --> shared_chunks
# NOTE(review): original note; presumably means these per-subject chunk counts are
# being compared against the shared_chunks run's checkpoints — confirm.
popqa_questions = popqa_questions_df[["s_uri", "subject_num_chunks"]].drop_duplicates("s_uri")

# s_uri -> subject_num_chunks, built from column iterators rather than iterrows()
# (which materialises a Series per row). The defaultdict(list) type is kept for
# compatibility, but note values are assigned, never appended: indexing an unknown
# URI silently inserts and returns [] — use .get() for lookups.
uris_to_num_chunks = defaultdict(
    list, zip(popqa_questions["s_uri"], popqa_questions["subject_num_chunks"])
)

In [None]:
# Entities of the final checkpoint whose occurrence count disagrees with the
# dataset's subject chunk count, and the PopQA questions about those entities.
# extract_entity_data() does not emit a "num_chunks" column, so derive it from
# uris_to_num_chunks (keyed by s_uri, which entity_id is matched against below)
# unless it is already present. .get() is used so Series.map doesn't trigger the
# defaultdict's __missing__ and insert unknown URIs into the mapping.
# NOTE(review): entities with no chunk count get NaN, which compares unequal and
# therefore lands in unequal_df — confirm that is intended.
if "num_chunks" in final_df.columns:
    final_with_chunks = final_df
else:
    final_with_chunks = final_df.assign(
        num_chunks=final_df["entity_id"].map(lambda uri: uris_to_num_chunks.get(uri))
    )
unequal_df = final_with_chunks[final_with_chunks["num_chunks"] != final_with_chunks["occurences"]]
filtered_df = popqa_questions_df[popqa_questions_df["s_uri"].isin(unequal_df["entity_id"])]


In [None]:
def has_unique_chunks(chunk_list):
    """Return True when no element of ``chunk_list`` occurs more than once.

    The elements must be hashable. The check compares the sequence's length
    with the size of its set of distinct elements.
    """
    distinct = set(chunk_list)
    return len(distinct) == len(chunk_list)