# Crosscoder Learning - Sparsity & CE Diff Results

In [1]:
import os
import json
import pandas as pd

from vis import aggregate_eval_res

  from .autonotebook import tqdm as notebook_tqdm


## L1 Only

In [2]:
# common settings
DATA_SPLIT = "val"
METRICS_SPARSITY = ["l0_loss-mean", "dead_count"]
METRICS_CE       = ["ce_diff_A-mean", "ce_diff_B-mean", "ce_diff_C-mean"]

# list out each model’s parameters and its token‐order list
MODEL_CONFIGS = [
    {
        "model_str": "pythia1b",
        "model_name": "pythia",
        "replace_string": "EleutherAI/pythia-1b/",
        "assert_num": 24,
        "tok_list": ["128M vs. 1B", "1B vs. 4B", "4B vs. 286B", "1B vs. 4B vs. 286B"],
    },
    {
        "model_str": "olmo1b",
        "model_name": "olmo",
        "replace_string": "allenai/OLMo-1B-0724-hf/",
        "assert_num": 24,
        "tok_list": ["2B vs. 4B", "4B vs. 33B", "33B vs. 287B", "33B vs. 3048B",
                     "287B vs. 3048B", "4B vs. 33B vs. 3048B"],
    },
    {
        "model_str": "bloom1b",
        "model_name": "bloom",
        "replace_string": "bigscience/bloom-1b1-intermediate/",
        "assert_num": 24,
        "tok_list": ["550M vs. 6B", "6B vs. 55B", "55B vs. 341B", "6B vs. 55B vs. 341B"],
    },
]

def prepare_df(df: pd.DataFrame, tok_list: list[str]) -> pd.DataFrame:
    # map token_name → order index
    tok_order = {tok: i for i, tok in enumerate(tok_list)}
    df = df.copy()
    df["order"] = df["token_name"].map(tok_order)

    # round floats and cast ints
    if "l0_loss-mean" in df:
        df["l0_loss-mean"] = df["l0_loss-mean"].round(2)
    if "dead_count" in df:
        df["dead_count"] = df["dead_count"].astype(int)
    for col in df.columns:
        if col.startswith("ce_diff_") and col.endswith("-mean"):
            df[col] = df[col].round(2)

    return df

def print_grouped(df: pd.DataFrame, metrics: list[str]) -> None:
    for m in metrics:
        print("_" * 30)
        print(m)
        grouped = (
            df[df["data_split"] == DATA_SPLIT]
              .groupby(["categories", "token_name", "data_split", "order"])[m]
              .agg(["mean", "std"])
              .reset_index()
              .sort_values("order")
        )
        print(grouped)

# main loop
for cfg in MODEL_CONFIGS:
    sparsity_df, ce_df = aggregate_eval_res(
        model_str=cfg["model_str"],
        model_name=cfg["model_name"],
        replace_string=cfg["replace_string"],
        assert_num=cfg["assert_num"],
        joint_eval_sparsity_path="workspace/results/sparsity_ce/joint_eval_sparsity_l1",
        joint_eval_ce_path="workspace/results/sparsity_ce/joint_eval_ce_l1",
    )

    sparsity_df = prepare_df(sparsity_df, cfg["tok_list"])
    ce_df       = prepare_df(ce_df,       cfg["tok_list"])

    print(f"\n=== Results for model: {cfg['model_name']} ===")
    print_grouped(sparsity_df, METRICS_SPARSITY)
    print_grouped(ce_df,       METRICS_CE)


=== Results for model: pythia ===
______________________________
l0_loss-mean
    categories          token_name data_split  order        mean       std
0  consecutive         128M vs. 1B        val      0   88.300000  0.260000
1  consecutive           1B vs. 4B        val      1  214.176667  0.089629
3  consecutive         4B vs. 286B        val      2  190.433333  0.366924
2  consecutive  1B vs. 4B vs. 286B        val      3  214.753333  0.572917
______________________________
dead_count
    categories          token_name data_split  order       mean       std
0  consecutive         128M vs. 1B        val      0   9.666667  1.154701
1  consecutive           1B vs. 4B        val      1   1.666667  1.154701
3  consecutive         4B vs. 286B        val      2   9.000000  1.000000
2  consecutive  1B vs. 4B vs. 286B        val      3  19.333333  3.785939
______________________________
ce_diff_A-mean
    categories          token_name data_split  order      mean       std
0  consecutive 

## L1 + BatchTopK

In [3]:
def prepare_df(df: pd.DataFrame, tok_list: list[str]) -> pd.DataFrame:
    tok_order = {tok: i for i, tok in enumerate(tok_list)}
    df = df.copy()
    df["order"] = df["token_name"].map(tok_order)
    if "l0_loss-mean" in df: df["l0_loss-mean"] = df["l0_loss-mean"].round(2)
    if "dead_count"   in df: df["dead_count"]    = df["dead_count"].astype(int)
    for col in df.columns:
        if col.startswith("ce_diff_") and col.endswith("-mean"):
            df[col] = df[col].round(2)
    return df

MODEL_CONFIGS = [
    {
        "model_str":      "pythia1b",
        "model_name":     "pythia",
        "latex_name":     "Pythia-1B L8",
        "replace_string": "EleutherAI/pythia-1b/",
        "assert_num":     24,
        "tok_list":       ["128M vs. 1B", "1B vs. 4B", "4B vs. 286B", "1B vs. 4B vs. 286B"],
    },
    {
        "model_str": "olmo1b",
        "model_name": "olmo",
        "latex_name":     "OLMo-1B L8",
        "replace_string": "allenai/OLMo-1B-0724-hf/",
        "assert_num": 24,
        "tok_list": ["2B vs. 4B", "4B vs. 33B", "33B vs. 3048B", "4B vs. 33B vs. 3048B"],
    },
    {
        "model_str":      "bloom1b",
        "model_name":     "bloom",
        "latex_name":     "BLOOM-1B L12",
        "replace_string": "bigscience/bloom-1b1-intermediate/",
        "assert_num":     24,
        "tok_list":       ["550M vs. 6B", "6B vs. 55B", "55B vs. 341B", "6B vs. 55B vs. 341B"],
    },
]

header = r"""\begin{table*}[t]
  \centering
  \resizebox{1.0\linewidth}{!}{
    \begin{tabular}{ 
      cc 
      @{\hspace{2em}}
      *{5}{c} 
      @{\hspace{2em}}
      *{5}{c} 
    }
      \toprule
      \multicolumn{2}{c}{} 
        & \multicolumn{5}{c}{\bfseries $\bm{\ell_1}$ Sparsity Crosscoder} 
        & \multicolumn{5}{c}{\bfseries BatchTopK Crosscoder} \\
      \cmidrule(r){3-7} \cmidrule(r){8-12}
      \textbf{Model} & \textbf{Comparison}
        & \textbf{$\bm{\ell_0}$} & \textbf{Dead Feats} 
        & \textbf{$\Delta$CE A} & \textbf{$\Delta$CE B}  & \textbf{$\Delta$CE C}
        & \textbf{$\bm{\ell_0}$} & \textbf{Dead Feats} 
        & \textbf{$\Delta$CE A} & \textbf{$\Delta$CE B}  & \textbf{$\Delta$CE C}\\
      \midrule
"""
footer = r"""    \bottomrule
    \end{tabular}
  }
  \caption{\textbf{Crosscoder statistics.} Results averaged over three seeds on validation set. $\Delta$CE is the change in cross-entropy loss when doing a forward pass using the original output versus the crosscoder reconstruction. A, B, C refer to the 1st, 2nd and 3rd checkpoints used for loss computation. $\ell_0$ and dead feature averages are rounded to integers.}
  \label{tab:crosscoder-stats}
\end{table*}
"""

body = []
for cfg in MODEL_CONFIGS:
    # L1 run
    l1_spars, l1_ce = aggregate_eval_res(
        model_str=cfg["model_str"],
        model_name=cfg["model_name"],
        replace_string=cfg["replace_string"],
        assert_num=cfg["assert_num"],
        joint_eval_sparsity_path="workspace/results/sparsity_ce/joint_eval_sparsity_l1",
        joint_eval_ce_path="workspace/results/sparsity_ce/joint_eval_ce_l1",
    )
    # TopK run
    tk_spars, tk_ce = aggregate_eval_res(
        model_str=cfg["model_str"],
        model_name=cfg["model_name"],
        replace_string=cfg["replace_string"],
        assert_num=cfg["assert_num"],
        joint_eval_sparsity_path="workspace/results/sparsity_ce/joint_eval_sparsity_topk_norm1",
        joint_eval_ce_path="workspace/results/sparsity_ce/joint_eval_ce_topk_norm1",
    )

    l1_spars = prepare_df(l1_spars, cfg["tok_list"])
    l1_ce    = prepare_df(l1_ce,    cfg["tok_list"])
    tk_spars = prepare_df(tk_spars, cfg["tok_list"])
    tk_ce    = prepare_df(tk_ce,    cfg["tok_list"])

    n = len(cfg["tok_list"])
    for i, tok in enumerate(cfg["tok_list"]):
        # select val‐split row
        sel_l1_s = l1_spars[(l1_spars["data_split"]=="val") & (l1_spars["token_name"]==tok)]
        sel_l1_c = l1_ce   [(l1_ce   ["data_split"]=="val") & (l1_ce   ["token_name"]==tok)]
        sel_tk_s = tk_spars[(tk_spars["data_split"]=="val") & (tk_spars["token_name"]==tok)]
        sel_tk_c = tk_ce   [(tk_ce   ["data_split"]=="val") & (tk_ce   ["token_name"]==tok)]

        # compute means
        def m(df, col): return float(df[col].mean()) if not df.empty else float("nan")
        l0_l1   = m(sel_l1_s, "l0_loss-mean")
        dead_l1 = m(sel_l1_s, "dead_count")
        ceA_l1  = m(sel_l1_c, "ce_diff_A-mean")
        ceB_l1  = m(sel_l1_c, "ce_diff_B-mean")
        ceC_l1  = m(sel_l1_c, "ce_diff_C-mean")

        l0_tk   = m(sel_tk_s, "l0_loss-mean")
        dead_tk = m(sel_tk_s, "dead_count")
        ceA_tk  = m(sel_tk_c, "ce_diff_A-mean")
        ceB_tk  = m(sel_tk_c, "ce_diff_B-mean")
        ceC_tk  = m(sel_tk_c, "ce_diff_C-mean")

        # format numeric & “-” for NaNs
        def fmt(v, intify=False):
            if pd.isna(v): return "-"
            return f"{int(v)}" if intify else f"{v:.2f}"
        left  = f"{fmt(l0_l1)} & {fmt(dead_l1, True)} & {fmt(ceA_l1)} & {fmt(ceB_l1)} & {fmt(ceC_l1)}"
        right = f"{fmt(l0_tk)} & {fmt(dead_tk, True)} & {fmt(ceA_tk)} & {fmt(ceB_tk)} & {fmt(ceC_tk)}"

        comp = tok.replace(" vs. ", r" \compar{} ")
        if i == 0:
            row = rf"\multirow{{{n}}}{{*}}{{{cfg['latex_name']}}} & {comp} & {left} & {right} \\"
        else:
            row = rf"                    & {comp} & {left} & {right} \\"
        body.append(row)
    body.append(r"      \midrule")

print(header + "\n".join(body) + "\n" + footer)


\begin{table*}[t]
  \centering
  \resizebox{1.0\linewidth}{!}{
    \begin{tabular}{ 
      cc 
      @{\hspace{2em}}
      *{5}{c} 
      @{\hspace{2em}}
      *{5}{c} 
    }
      \toprule
      \multicolumn{2}{c}{} 
        & \multicolumn{5}{c}{\bfseries $\bm{\ell_1}$ Sparsity Crosscoder} 
        & \multicolumn{5}{c}{\bfseries BatchTopK Crosscoder} \\
      \cmidrule(r){3-7} \cmidrule(r){8-12}
      \textbf{Model} & \textbf{Comparison}
        & \textbf{$\bm{\ell_0}$} & \textbf{Dead Feats} 
        & \textbf{$\Delta$CE A} & \textbf{$\Delta$CE B}  & \textbf{$\Delta$CE C}
        & \textbf{$\bm{\ell_0}$} & \textbf{Dead Feats} 
        & \textbf{$\Delta$CE A} & \textbf{$\Delta$CE B}  & \textbf{$\Delta$CE C}\\
      \midrule
\multirow{4}{*}{Pythia-1B L8} & 128M \compar{} 1B & 88.30 & 9 & 0.00 & 0.07 & - & 200.00 & 4 & -0.01 & 0.04 & - \\
                    & 1B \compar{} 4B & 214.18 & 1 & 0.05 & 0.18 & - & 200.00 & 0 & 0.01 & 0.11 & - \\
                    & 4B \compar{} 286B & 190.43