# Checkpoint Finder for Pythia, OLMo, and BLOOM

This notebook scans the `./workspace/logs/checkpoints` directory tree for experiment checkpoints from **Pythia**, **OLMo**, and **BLOOM** models. A folder counts as a checkpoint only if it contains exactly 21 JSON config files. For each matching folder the notebook records:

- **Folder path** containing the checkpoint
- **Seed** used during training
- **Revision list** (e.g., `step512-step2000`)
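- **Layer index** being analyzed. The `layers` filter is compared against the hook index stored in the config, which is one higher than the reported `layer`, so pass the layer you want plus one (e.g. `layers=[9]` to get layer 8).
- **Version number** taken from the folder name (e.g. `version_218` → `218`)
- **Model name** (BatchTopK search only)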
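The two derived fields can be reproduced in isolation. A minimal sketch, assuming hook names whose second dot-separated field is the layer index (the name `blocks.9.hook_resid_post` below is a hypothetical example; the notebook only relies on that field position):

```python
import os


def parse_version(folder):
    """Return the suffix after the last underscore of the folder's own name."""
    name = os.path.basename(os.path.normpath(folder))
    return name.split("_")[-1]


def parse_hook_index(hook_point):
    """Return the second dot-separated field as an int, or None if it is not numeric."""
    parts = hook_point.split(".")
    return int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else None


version = parse_version("./workspace/logs/checkpoints/version_218")
hook_idx = parse_hook_index("blocks.9.hook_resid_post")  # hypothetical hook name
reported_layer = hook_idx - 1  # the notebook reports the hook index minus one
print(version, hook_idx, reported_layer)  # 218 9 8
```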

### Workflow Overview
1. **Define revision lists**: choose which revisions to match and fix the order in which they are listed.
2. **Run `find_ckpts()`**: scan the subdirectories for matching checkpoints.
3. **Sort results**: order rows by the custom revision order, then by seed, for consistent output.
4. **Display results**:
   - Print the total number of checkpoints found.
   - Print the version numbers for easy reference.
   - Display the full Pandas DataFrame with checkpoint details.
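Step 3 relies on an ordered `pd.Categorical`, because plain string sorting would put `step2000-main` first (`"2" < "5" < "6"`). A minimal sketch of that step on four rows copied from the Pythia output below:

```python
import pandas as pd

# Custom evaluation order, as in the Pythia cell
revs = ["step64-step512", "step512-step2000", "step2000-main"]

df = pd.DataFrame({
    "rev": ["step2000-main", "step64-step512", "step512-step2000", "step64-step512"],
    "seed": [124, 153, 124, 124],
    "version": ["219", "311", "265", "218"],
})

# Ordered categories make sort_values follow the list order, not alphabetical order
df["rev"] = pd.Categorical(df["rev"], categories=revs, ordered=True)
df = df.sort_values(by=["rev", "seed"])
order = " ".join(df["version"])
print(order)  # 218 311 265 219
```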

### Supported Models
| Model | Matching Keyword | `layers` Argument | Reported `layer` | Example Revisions |
|-------|------------------|-------------------|------------------|-------------------|
| **Pythia** | `"pythia"` | `9` | `8` | `step512-step2000` |
| **OLMo** | `"olmo"` | `9` | `8` | `step2000-tokens4B_step16000-tokens33B` |
| **BLOOM** | `"bloom"` | `13` | `12` | `global_step10000-global_step100000` |
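The example revisions use two different separators between checkpoints. OLMo checkpoint names already contain hyphens (`step2000-tokens4B`), so OLMo checkpoints are joined with `_`; Pythia and BLOOM checkpoints are joined with `-` (BLOOM names contain underscores, as in `global_step1000`). The hypothetical helper below encodes that convention, inferred from the revision strings in this notebook rather than from any documented format:

```python
def split_revision_list(model, revision_list):
    """Split a revision_list string into its individual checkpoint names.

    Assumption: OLMo joins checkpoints with "_", Pythia and BLOOM with "-".
    """
    sep = "_" if model == "olmo" else "-"
    return revision_list.split(sep)


print(split_revision_list("olmo", "step2000-tokens4B_step16000-tokens33B"))
# ['step2000-tokens4B', 'step16000-tokens33B']
```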

In [1]:
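import os
import json
import pandas as pd
from vis import revision2tokens_pythia, revision2tokens_olmo, revision2tokens_bloom

def find_ckpts(
    root,
    revs=None,
    seeds=None,
    model=None,
    layers=None,
    batch_topk=False
):
    """Scan `root` for checkpoint folders and return one row per matching folder.

    A folder is inspected only if it contains exactly 21 JSON files. The layer
    index is parsed from each file's `hook_point` and compared with `layers`;
    the reported `layer` column is that index minus one.

    batch_topk=False: keep L1 runs (`batch_topk_final` is None) that pass the
    `revs`, `seeds`, and `model` filters (`model` is a case-insensitive
    substring of `model_name`).
    batch_topk=True: keep BatchTopK runs with dec_init_norm == 1.0 and
    n_models == 3 that pass the `seeds` filter; `revs` and `model` are ignored.
    """
    records = []

    for folder, _, files in os.walk(root):
        # Only folders holding a complete set of 21 JSON configs are considered
        json_files = [f for f in files if f.endswith(".json")]
        if len(json_files) != 21:
            continue

        # Version is the suffix of the folder's own name, e.g. "version_218" -> "218"
        version = os.path.basename(folder).split("_")[-1]

        for file in json_files:
            path = os.path.join(folder, file)
            try:
                with open(path, "r", encoding="utf-8") as f:
                    data = json.load(f)

                # Layer index = second dot-separated field of hook_point
                hook = data.get("hook_point") or ""
                parts = hook.split(".")
                idx = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else None
                if idx is None or (layers and idx not in layers):
                    continue

                # BatchTopK checkpoints
                if batch_topk:
                    if (
                        data.get("batch_topk_final") is not None
                        and (not seeds or data.get("seed") in seeds)
                        and data.get("dec_init_norm") == 1.0
                        and data.get("n_models") == 3.0
                    ):
                        records.append({
                            "folder": folder,
                            "seed": data.get("seed"),
                            "model": data.get("model_name"),
                            "rev": data.get("revision_list"),
                            "layer": idx - 1,
                            "version": version
                        })
                        break  # one row per folder
                    continue

                # L1 checkpoints
                if (
                    data.get("batch_topk_final") is None
                    and (not revs or data.get("revision_list") in revs)
                    and (not seeds or data.get("seed") in seeds)
                    and (not model or model.lower() in (data.get("model_name") or "").lower())
                ):
                    records.append({
                        "folder": folder,
                        "seed": data.get("seed"),
                        "rev": data.get("revision_list"),
                        "layer": idx - 1,
                        "version": version
                    })
                    break  # one row per folder
            except (json.JSONDecodeError, OSError):
                # Skip unreadable or malformed JSON files
                continue

    # Fix the column set so an empty result still has the columns later cells index
    columns = ["folder", "seed", "rev", "layer", "version"]
    if batch_topk:
        columns.insert(2, "model")
    return pd.DataFrame(records, columns=columns)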
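The 21-file gate in `find_ckpts` means that a run with missing or extra JSON files is skipped silently rather than reported. A minimal sketch of that gate on a throwaway directory tree (the folder and file names are hypothetical) shows which folders survive it:

```python
import os
import tempfile


def folders_with_n_json(root, n=21):
    """Yield every folder under root that holds exactly n *.json files."""
    for folder, _, files in os.walk(root):
        if sum(f.endswith(".json") for f in files) == n:
            yield folder


with tempfile.TemporaryDirectory() as root:
    # Hypothetical layout: one complete run and one partial run
    complete = os.path.join(root, "version_1")
    partial = os.path.join(root, "version_2")
    os.makedirs(complete)
    os.makedirs(partial)
    for i in range(21):
        open(os.path.join(complete, f"cfg_{i}.json"), "w").close()
    for i in range(3):
        open(os.path.join(partial, f"cfg_{i}.json"), "w").close()
    found = [os.path.basename(f) for f in folders_with_n_json(root)]

print(found)  # ['version_1']
```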




### Pythia

In [2]:
seeds = [124, 153, 6582]
root = "./workspace/logs/checkpoints"

revs = "step64-step512 step512-step2000 step2000-main".split(" ")
revs += ["step512-step2000-main"]
# print([revision2tokens_pythia(e) for i in revs for e in i.split("-")])

df = find_ckpts(root, revs, seeds, model="pythia", layers=[9])
df["rev"] = pd.Categorical(df["rev"], categories=revs, ordered=True)
df = df.sort_values(by=["rev", "seed"])
print(len(df))
print(" ".join(df["version"].values))
df

12
218 311 443 265 318 444 219 315 445 446 455 456


| | folder | seed | rev | layer | version |
|---|---|---|---|---|---|
| 0 | ./workspace/logs/checkpoints/version_218 | 124 | step64-step512 | 8 | 218 |
| 3 | ./workspace/logs/checkpoints/version_311 | 153 | step64-step512 | 8 | 311 |
| 6 | ./workspace/logs/checkpoints/version_443 | 6582 | step64-step512 | 8 | 443 |
| 2 | ./workspace/logs/checkpoints/version_265 | 124 | step512-step2000 | 8 | 265 |
| 5 | ./workspace/logs/checkpoints/version_318 | 153 | step512-step2000 | 8 | 318 |
| 7 | ./workspace/logs/checkpoints/version_444 | 6582 | step512-step2000 | 8 | 444 |
| 1 | ./workspace/logs/checkpoints/version_219 | 124 | step2000-main | 8 | 219 |
| 4 | ./workspace/logs/checkpoints/version_315 | 153 | step2000-main | 8 | 315 |
| 8 | ./workspace/logs/checkpoints/version_445 | 6582 | step2000-main | 8 | 445 |
| 9 | ./workspace/logs/checkpoints/version_446 | 124 | step512-step2000-main | 8 | 446 |
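With three seeds and four revision lists, a complete search returns 12 rows, so the printed count is a quick sanity check; it cannot, however, tell a missing pair from a duplicated one. A minimal sketch of a per-pair check on toy data (not the notebook's results) cross-tabulates `rev` against `seed` and lists every pair that does not appear exactly once:

```python
import pandas as pd

revs = ["step64-step512", "step512-step2000"]
seeds = [124, 153, 6582]

# Toy result with one seed missing for the second revision
df = pd.DataFrame({
    "rev": ["step64-step512"] * 3 + ["step512-step2000"] * 2,
    "seed": [124, 153, 6582, 124, 153],
})

# Count rows per (rev, seed); reindex so pairs never found show up as 0
counts = (
    pd.crosstab(df["rev"], df["seed"])
    .reindex(index=revs, columns=seeds, fill_value=0)
)
missing = [(r, s) for r in revs for s in seeds if counts.loc[r, s] != 1]
print(missing)  # [('step512-step2000', 6582)]
```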


### OLMo

In [3]:
seeds = [124, 153, 6582]
root = "./workspace/logs/checkpoints"

revs = "step1000-tokens2B_step2000-tokens4B step2000-tokens4B_step16000-tokens33B step16000-tokens33B_step1454000-tokens3048B".split(" ")
revs += ["step2000-tokens4B_step16000-tokens33B_step1454000-tokens3048B"]

df = find_ckpts(root, revs, seeds, model="olmo", layers=[9])
df["rev"] = pd.Categorical(df["rev"], categories=revs, ordered=True)
df = df.sort_values(by=["rev", "seed"])
print(len(df))
print(" ".join(df["version"].values))
df

12
440 439 429 430 431 442 433 432 437 447 460 459


| | folder | seed | rev | layer | version |
|---|---|---|---|---|---|
| 7 | ./workspace/logs/checkpoints/version_440 | 124 | step1000-tokens2B_step2000-tokens4B | 8 | 440 |
| 6 | ./workspace/logs/checkpoints/version_439 | 153 | step1000-tokens2B_step2000-tokens4B | 8 | 439 |
| 0 | ./workspace/logs/checkpoints/version_429 | 6582 | step1000-tokens2B_step2000-tokens4B | 8 | 429 |
| 1 | ./workspace/logs/checkpoints/version_430 | 124 | step2000-tokens4B_step16000-tokens33B | 8 | 430 |
| 2 | ./workspace/logs/checkpoints/version_431 | 153 | step2000-tokens4B_step16000-tokens33B | 8 | 431 |
| 8 | ./workspace/logs/checkpoints/version_442 | 6582 | step2000-tokens4B_step16000-tokens33B | 8 | 442 |
| 4 | ./workspace/logs/checkpoints/version_433 | 124 | step16000-tokens33B_step1454000-tokens3048B | 8 | 433 |
| 3 | ./workspace/logs/checkpoints/version_432 | 153 | step16000-tokens33B_step1454000-tokens3048B | 8 | 432 |
| 5 | ./workspace/logs/checkpoints/version_437 | 6582 | step16000-tokens33B_step1454000-tokens3048B | 8 | 437 |
| 9 | ./workspace/logs/checkpoints/version_447 | 124 | step2000-tokens4B_step16000-tokens33B_step1454... | 8 | 447 |
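OLMo revision names pack two numbers into each checkpoint: the training step and a token count with a `B` suffix, which this sketch assumes means billions. The notebook imports `revision2tokens_olmo` from its `vis` module for this kind of conversion, but that implementation is not shown here, so the hypothetical `olmo_step_and_tokens` below is a standalone regex-based stand-in rather than a copy of it:

```python
import re

# Assumed checkpoint format: "step<N>-tokens<M>B", with B = billions
_OLMO_CKPT = re.compile(r"^step(\d+)-tokens(\d+)B$")


def olmo_step_and_tokens(checkpoint):
    """Return (step, tokens) for one OLMo checkpoint name."""
    m = _OLMO_CKPT.match(checkpoint)
    if m is None:
        raise ValueError(f"unexpected OLMo checkpoint name: {checkpoint!r}")
    step, billions = map(int, m.groups())
    return step, billions * 10**9


rev = "step2000-tokens4B_step16000-tokens33B_step1454000-tokens3048B"
parsed = [olmo_step_and_tokens(c) for c in rev.split("_")]
print(parsed)
```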


### BLOOM

In [4]:
seeds = [124, 153, 6582]
root = "./workspace/logs/checkpoints"

revs = "global_step1000-global_step10000 global_step10000-global_step100000 global_step100000-main".split(" ")
revs += ["global_step10000-global_step100000-main"]

df = find_ckpts(root, revs, seeds, model="bloom", layers=[13])
df["rev"] = pd.Categorical(df["rev"], categories=revs, ordered=True)
df = df.sort_values(by=["rev", "seed"])
print(len(df))
print(" ".join(df["version"].values))
df

12
387 450 449 400 452 448 409 451 453 454 458 457


| | folder | seed | rev | layer | version |
|---|---|---|---|---|---|
| 0 | ./workspace/logs/checkpoints/version_387 | 124 | global_step1000-global_step10000 | 12 | 387 |
| 5 | ./workspace/logs/checkpoints/version_450 | 153 | global_step1000-global_step10000 | 12 | 450 |
| 4 | ./workspace/logs/checkpoints/version_449 | 6582 | global_step1000-global_step10000 | 12 | 449 |
| 1 | ./workspace/logs/checkpoints/version_400 | 124 | global_step10000-global_step100000 | 12 | 400 |
| 7 | ./workspace/logs/checkpoints/version_452 | 153 | global_step10000-global_step100000 | 12 | 452 |
| 3 | ./workspace/logs/checkpoints/version_448 | 6582 | global_step10000-global_step100000 | 12 | 448 |
| 2 | ./workspace/logs/checkpoints/version_409 | 124 | global_step100000-main | 12 | 409 |
| 6 | ./workspace/logs/checkpoints/version_451 | 153 | global_step100000-main | 12 | 451 |
| 8 | ./workspace/logs/checkpoints/version_453 | 6582 | global_step100000-main | 12 | 453 |
| 9 | ./workspace/logs/checkpoints/version_454 | 124 | global_step10000-global_step100000-main | 12 | 454 |
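Each cell spells out its revision order by hand. The hand-written lists here all follow one pattern: two-checkpoint ranges sorted by their starting step, then the three-checkpoint range. A hypothetical sort key that reproduces this pattern (number of checkpoints first, then the first step number) is sketched below; it assumes the naming stays consistent, and OLMo lists would need `sep="_"`:

```python
import math
import re


def rev_sort_key(rev, sep="-"):
    """Hypothetical key: (number of checkpoints, step number of the first checkpoint)."""
    checkpoints = rev.split(sep)
    m = re.search(r"\d+", checkpoints[0])
    first_step = int(m.group()) if m else math.inf  # non-numeric names sort last
    return (len(checkpoints), first_step)


bloom_revs = [
    "global_step10000-global_step100000-main",
    "global_step100000-main",
    "global_step1000-global_step10000",
    "global_step10000-global_step100000",
]
print(sorted(bloom_revs, key=rev_sort_key))
```

On the BLOOM and Pythia lists this yields the same order as the manual definitions; a revision whose first checkpoint has no digits would sort after all numbered ones of the same length.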


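### BatchTopK

This search runs `find_ckpts` with `batch_topk=True`: a run qualifies when `batch_topk_final` is set, `dec_init_norm` is `1.0`, and `n_models` is `3`. The `revs` and `model` filters are ignored in this mode, so only `seeds` narrows the results, and the output gains a `model` column.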

In [5]:
seeds = [124]
root = "./workspace/logs/checkpoints"

df_batch = find_ckpts(root, seeds=seeds, batch_topk=True)
df_batch = df_batch.sort_values(by=["model", "rev", "seed"])
print(len(df_batch))
print(" ".join(df_batch["version"].values))
df_batch

3
998 1064 1059


| | folder | seed | model | rev | layer | version |
|---|---|---|---|---|---|---|
| 0 | ./workspace/logs/checkpoints/version_998 | 124 | EleutherAI/pythia-1b | step512-step2000-main | 8 | 998 |
| 2 | ./workspace/logs/checkpoints/version_1064 | 124 | allenai/OLMo-1B-0724-hf | step2000-tokens4B_step16000-tokens33B_step1454... | 8 | 1064 |
| 1 | ./workspace/logs/checkpoints/version_1059 | 124 | bigscience/bloom-1b1-intermediate | global_step10000-global_step100000-main | 12 | 1059 |
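The BatchTopK rows carry full Hugging Face model IDs, whereas the L1 searches use short keywords (`pythia`, `olmo`, `bloom`). Sorting on the raw ID is case-sensitive, which is why `EleutherAI/...` lands before `allenai/...` above. A small sketch, assuming the three keywords from the table of supported models are the only families present, derives the keyword into a hypothetical `family` column so BatchTopK rows can be matched against the per-model tables:

```python
import pandas as pd

# Model IDs as they appear in the BatchTopK output above
df_batch = pd.DataFrame({
    "model": [
        "EleutherAI/pythia-1b",
        "allenai/OLMo-1B-0724-hf",
        "bigscience/bloom-1b1-intermediate",
    ],
})

# Lowercase first so "OLMo" matches the same keyword used by find_ckpts(model="olmo")
df_batch["family"] = df_batch["model"].str.lower().str.extract(
    r"(pythia|olmo|bloom)", expand=False
)
print(df_batch["family"].tolist())  # ['pythia', 'olmo', 'bloom']
```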
