# Inspect GRPO rollout traces

Rollouts are saved as **tokenized** JSONL: one JSON object per line with `prompt_tokens`, `response_tokens` (token IDs), and optionally `logprobs` (vLLM logprobs per response token). Files can be very large (multi-GB). This notebook streams them line-by-line to avoid loading everything into memory.

**RolloutRecord fields:** `step`, `sample_idx`, `prompt_idx`, `prompt_tokens`, `response_tokens`, `reward`, `advantage`, `finish_reason`, `dataset`, `ground_truth`, `request_info`, `logprobs`

**Note:** The saved `logprobs` are from **vLLM** at rollout time. To compare with local policy logprobs (e.g. for debugging `vllm_vs_local_logprob_diff_mean`), you would need to run a separate forward pass on the same tokens with the learner model; this notebook only inspects the stored vLLM logprobs and metrics.

## 1. Configuration

In [None]:
import json
import os
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd

# Path to the rollouts directory (contains *_metadata.jsonl and *_rollouts_*.jsonl)
# NOTE: this is a machine-specific absolute path; edit it to point at your own run output.
ROLLOUTS_DIR = "/mnt/vast/home/lh22zyta/shortcut-RL/open-instruct/output/RLVR-soofi-Olmo-IsomorphicRL/rollouts"

# Optional: restrict to a specific run (prefix of rollout filenames). If None, use first run found.
RUN_NAME = "RLVR-soofi-Olmo-IsomorphicRL__1__1772176743"  # or None for auto

# Limit number of lines to read per file (None = no limit). Use for quick checks on huge files.
MAX_LINES_PER_FILE = None  # e.g. 50_000

# Step range to load (None = all steps). Reduces memory when inspecting a window.
# Both bounds are inclusive: records with step < STEP_MIN or step > STEP_MAX are skipped.
STEP_MIN = None  # e.g. 400
STEP_MAX = None  # e.g. 600

# How many full records to keep in memory for detailed inspection (rest are aggregated as stats only)
NUM_SAMPLE_RECORDS = 5

## 2. List rollout files and load metadata

In [None]:
def get_rollout_files(rollouts_dir: str, run_name: str | None = None) -> tuple[list[str], str | None, dict | None]:
    """List rollout JSONL paths and return (paths, run_name, metadata_dict)."""
    base = Path(rollouts_dir)
    if not base.is_dir():
        raise FileNotFoundError(f"Not a directory: {rollouts_dir}")

    if run_name is None:
        # Infer run_name from first metadata file
        metas = sorted(base.glob("*_metadata.jsonl"))
        if not metas:
            raise FileNotFoundError(f"No *_metadata.jsonl in {rollouts_dir}")
        run_name = metas[0].stem.replace("_metadata", "")

    meta_path = base / f"{run_name}_metadata.jsonl"
    metadata = None
    if meta_path.exists():
        with open(meta_path) as f:
            metadata = json.loads(f.readline())

    paths = sorted(base.glob(f"{run_name}_rollouts_*.jsonl"))
    return [str(p) for p in paths], run_name, metadata


# Resolve the run and list its rollout files with their on-disk sizes.
paths, run_name, metadata = get_rollout_files(ROLLOUTS_DIR, RUN_NAME)
print(f"Run: {run_name}")
print(f"Rollout files: {len(paths)}")
for p in paths:
    size_mb = os.path.getsize(p) / (1024 * 1024)  # bytes -> MiB
    print(f"  {os.path.basename(p)}  ({size_mb:.1f} MB)")
# metadata is None (or empty) when the run has no usable metadata file.
if metadata:
    print(f"Metadata: {metadata}")

## 3. Stream rollouts and aggregate stats (memory-efficient)

In [None]:
def stream_rollouts(
    paths: list[str],
    step_min: int | None = None,
    step_max: int | None = None,
    max_lines_per_file: int | None = None,
    num_sample_records: int = 5,
):
    """
    Iterate over JSONL files line-by-line. For each record:
    - Aggregate per-step stats (reward, advantage, response length, logprobs).
    - Keep up to num_sample_records full records (spread across steps) for inspection.
    """
    step_records = defaultdict(list)  # step -> list of lightweight dicts
    sample_records = []  # full records for display
    steps_seen = set()
    total_lines = 0

    for filepath in paths:
        lines_read = 0
        with open(filepath) as f:
            for line in f:
                if max_lines_per_file is not None and lines_read >= max_lines_per_file:
                    break
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skip bad line in {filepath} line {lines_read + 1}: {e}")
                    continue

                step = rec.get("step", -1)
                if step_min is not None and step < step_min:
                    lines_read += 1
                    total_lines += 1
                    continue
                if step_max is not None and step > step_max:
                    lines_read += 1
                    total_lines += 1
                    continue

                reward = rec.get("reward", 0.0)
                advantage = rec.get("advantage", 0.0)
                resp_tokens = rec.get("response_tokens", [])
                logprobs = rec.get("logprobs")

                stat = {
                    "reward": reward,
                    "advantage": advantage,
                    "response_len": len(resp_tokens),
                    "prompt_len": len(rec.get("prompt_tokens", [])),
                    "finish_reason": rec.get("finish_reason", ""),
                }
                if logprobs is not None:
                    valid = [x for x in logprobs if isinstance(x, (int, float)) and not (isinstance(x, float) and np.isnan(x))]
                    if valid:
                        stat["logprob_mean"] = float(np.mean(valid))
                        stat["logprob_std"] = float(np.std(valid))
                        stat["logprob_min"] = float(np.min(valid))
                        stat["logprob_max"] = float(np.max(valid))
                    else:
                        stat["logprob_mean"] = stat["logprob_std"] = stat["logprob_min"] = stat["logprob_max"] = None
                else:
                    stat["logprob_mean"] = stat["logprob_std"] = stat["logprob_min"] = stat["logprob_max"] = None

                step_records[step].append(stat)

                # Keep a few full records for inspection (one per step span)
                if len(sample_records) < num_sample_records and step not in steps_seen:
                    steps_seen.add(step)
                    sample_records.append(rec)
                elif len(sample_records) < num_sample_records and step in steps_seen:
                    # Replace one of the samples with a later step to spread steps
                    pass  # keep first occurrence per step

                lines_read += 1
                total_lines += 1

    return dict(step_records), sample_records, total_lines


# Stream every rollout file once using the configuration cell's settings.
step_records, sample_records, total_lines = stream_rollouts(
    paths,
    step_min=STEP_MIN,
    step_max=STEP_MAX,
    max_lines_per_file=MAX_LINES_PER_FILE,
    num_sample_records=NUM_SAMPLE_RECORDS,
)
print(f"Total records read: {total_lines}")
print(f"Steps with data: {len(step_records)}")
# Guard against an empty result so the range printout does not index an empty list.
if step_records:
    steps = sorted(step_records.keys())
    print(f"Step range: {steps[0]} .. {steps[-1]}")

In [None]:
# Peek at the first record of the first rollout file.
p = paths[0]
# Use a context manager so the handle is closed instead of leaking for the kernel's lifetime.
with open(p) as file:
    j = json.loads(file.readline())


In [None]:
# Show the stored logprobs of record `j`; raises KeyError if that record has no 'logprobs' key.
print(j['logprobs'])

## 4. Per-step summary (rewards, lengths, logprobs)

In [None]:
def build_step_summary(step_records: dict) -> pd.DataFrame:
    """Collapse per-record stats into one summary row per step.

    Args:
        step_records: Mapping step -> list of stat dicts with the keys ``reward``,
            ``advantage``, ``response_len``, ``prompt_len``, ``finish_reason`` and
            optionally ``logprob_mean``.

    Returns:
        A DataFrame sorted by step with reward/advantage/length statistics, the
        fraction of records that finished with ``"stop"``, and the mean/std of the
        per-record mean logprob (NaN when a step has no logprobs). The columns are
        always present, even when ``step_records`` is empty, so callers can index
        ``df["step"]`` unconditionally.
    """
    columns = [
        "step",
        "n",
        "reward_mean",
        "reward_std",
        "reward_min",
        "reward_max",
        "advantage_mean",
        "resp_len_mean",
        "resp_len_max",
        "prompt_len_mean",
        "stop_rate",
        "logprob_mean_mean",
        "logprob_mean_std",
    ]
    rows = []
    for step in sorted(step_records.keys()):
        stats = step_records[step]
        n = len(stats)
        rewards = [s["reward"] for s in stats]
        advantages = [s["advantage"] for s in stats]
        resp_lens = [s["response_len"] for s in stats]
        prompt_lens = [s["prompt_len"] for s in stats]
        finish_stop = sum(1 for s in stats if s["finish_reason"] == "stop")
        # Records without usable logprobs carry None and are left out of the logprob stats.
        logprob_means = [s["logprob_mean"] for s in stats if s.get("logprob_mean") is not None]

        rows.append(
            {
                "step": step,
                "n": n,
                "reward_mean": np.mean(rewards),
                "reward_std": np.std(rewards),
                "reward_min": np.min(rewards),
                "reward_max": np.max(rewards),
                "advantage_mean": np.mean(advantages),
                "resp_len_mean": np.mean(resp_lens),
                "resp_len_max": np.max(resp_lens),
                "prompt_len_mean": np.mean(prompt_lens),
                "stop_rate": finish_stop / n if n else 0,
                "logprob_mean_mean": np.mean(logprob_means) if logprob_means else np.nan,
                "logprob_mean_std": np.std(logprob_means) if logprob_means else np.nan,
            }
        )
    # Passing the column list keeps the schema stable for an empty input.
    return pd.DataFrame(rows, columns=columns)


# One summary row per step; `display` is the notebook's rich renderer for the DataFrame.
df = build_step_summary(step_records)
display(df)

In [None]:
# Plot reward, response length, stop rate and vLLM logprob vs step (if matplotlib available)
try:
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # Top-left: mean reward with the min..max band per step.
    ax = axes[0, 0]
    ax.plot(df["step"], df["reward_mean"], label="reward mean")
    ax.fill_between(df["step"], df["reward_min"], df["reward_max"], alpha=0.2)
    ax.set_xlabel("step")
    ax.set_ylabel("reward")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Top-right: mean and max response length in tokens.
    ax = axes[0, 1]
    ax.plot(df["step"], df["resp_len_mean"], label="response length mean")
    ax.plot(df["step"], df["resp_len_max"], alpha=0.7, label="response length max")
    ax.set_xlabel("step")
    ax.set_ylabel("tokens")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Bottom-left: fraction of samples whose finish_reason is "stop".
    ax = axes[1, 0]
    ax.plot(df["step"], df["stop_rate"], color="green")
    ax.set_xlabel("step")
    ax.set_ylabel("stop_rate")
    ax.grid(True, alpha=0.3)

    # Bottom-right: mean stored vLLM logprob, only when at least one step has logprobs.
    ax = axes[1, 1]
    if df["logprob_mean_mean"].notna().any():
        ax.plot(df["step"], df["logprob_mean_mean"], label="mean logprob (vLLM)")
        # Draw the legend only when a labelled line exists; an empty legend makes matplotlib warn.
        ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("logprob")
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
except ImportError:
    print("matplotlib not available; skip plots.")

## 5. Sample full records (token IDs only)

In [None]:
# Print a compact summary (no token dump) of each retained full record.
for i, rec in enumerate(sample_records):
    print(f"--- Sample {i + 1} (step={rec.get('step')}, sample_idx={rec.get('sample_idx')}) ---")
    print(f"  reward={rec.get('reward')}, advantage={rec.get('advantage')}, finish_reason={rec.get('finish_reason')}")
    print(f"  prompt_tokens: len={len(rec.get('prompt_tokens', []))}")
    print(f"  response_tokens: len={len(rec.get('response_tokens', []))}")
    # Logprobs are optional; skip the line when they are missing or empty.
    if rec.get("logprobs"):
        lp = rec["logprobs"]
        # Keep only numeric, non-NaN entries before computing statistics.
        valid = [x for x in lp if isinstance(x, (int, float)) and not (isinstance(x, float) and np.isnan(x))]
        if valid:
            print(f"  logprobs: len={len(lp)}, mean={np.mean(valid):.4f}, min={np.min(valid):.4f}, max={np.max(valid):.4f}")
    print()

## 6. Optional: decode tokens to text (requires tokenizer)

In [None]:
# Load tokenizer from metadata model_name (or set MODEL_NAME explicitly)
# NOTE(review): confirm that the fallback model id below exists on the Hub; it is only
# used when the run metadata has no "model_name" entry.
MODEL_NAME = (metadata or {}).get("model_name", "allenai/OLMo-1B-7B")  # fallback
USE_TOKENIZER = True  # set False to skip tokenizer load and decode

if USE_TOKENIZER:
    try:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        print(f"Loaded tokenizer: {MODEL_NAME}")
    except Exception as e:
        # Best-effort: any load failure (missing package, download error, ...) leaves
        # `tokenizer` as None so the decode cells below are skipped instead of crashing.
        print(f"Could not load tokenizer: {e}")
        tokenizer = None
else:
    tokenizer = None
    print("Skipping tokenizer (USE_TOKENIZER=False).")

In [None]:
# Number of full records retained for inspection.
len(sample_records)

In [None]:
# Decode the first few retained records to text; special tokens are kept in the output.
if tokenizer is not None and sample_records:
    for i, rec in enumerate(sample_records[:3]):  # decode first 3 only
        print(f"=== Sample {i + 1} (step={rec.get('step')}) ===")
        prompt_ids = rec.get("prompt_tokens", [])
        response_ids = rec.get("response_tokens", [])
        prompt_text = tokenizer.decode(prompt_ids, skip_special_tokens=False)
        response_text = tokenizer.decode(response_ids, skip_special_tokens=False)
        print("[Prompt] (first 500 chars)")
        print(prompt_text[:500])
        print("[Response] (first 800 chars)")
        print(response_text[:800])
        print()
else:
    # Reached when the tokenizer failed to load / was disabled, or no records were retained.
    print("No tokenizer or no sample records; skip decode.")

## 7. Inspect a single step in detail (set `TARGET_STEP` below)

In [None]:
# Optional: stream again with a narrow step range and larger sample to get many records from one step
def stream_one_step(paths: list[str], target_step: int, max_records: int = 64):
    """Collect up to ``max_records`` full records whose ``step`` equals ``target_step``.

    Files are streamed line by line in the given order and reading stops as soon as
    enough records were found. Blank lines, malformed JSON lines and non-object JSON
    values are skipped, so a single corrupt line does not abort the scan.

    Args:
        paths: Rollout JSONL files to scan.
        target_step: Step value to match against each record's ``step`` field.
        max_records: Maximum number of records to return; values <= 0 return an empty list.

    Returns:
        The matching records, in file order.
    """
    records = []
    if max_records <= 0:
        return records
    for filepath in paths:
        with open(filepath) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(rec, dict):
                    continue
                if rec.get("step") == target_step:
                    records.append(rec)
                    if len(records) >= max_records:
                        return records
    return records


TARGET_STEP = 30  # change to step of interest (e.g. 400, 500)
# Re-stream the files and keep only records from TARGET_STEP (default cap of stream_one_step applies).
one_step_records = stream_one_step(paths, TARGET_STEP)
print(f"Records at step {TARGET_STEP}: {len(one_step_records)}")
if one_step_records:
    # Show details for the first matching record only.
    r0 = one_step_records[0]
    print(f"  reward: {r0.get('reward')}, advantage: {r0.get('advantage')}, finish_reason: {r0.get('finish_reason')}")
    print(f"  response_tokens length: {len(r0.get('response_tokens', []))}")
    if r0.get("logprobs"):
        # Drop non-numeric and NaN entries before computing statistics.
        lp = [x for x in r0["logprobs"] if isinstance(x, (int, float)) and not (isinstance(x, float) and np.isnan(x))]
        if lp:
            print(f"  logprobs: mean={np.mean(lp):.4f}, std={np.std(lp):.4f}")

In [None]:
# Compare first and last rollouts (2 each)

def _tail_lines(path: str, n: int):
    # read last n non-empty lines (binary-safe)
    with open(path, 'rb') as f:
        f.seek(0, 2)
        size = f.tell()
        block = bytearray()
        lines = []
        pos = size
        while pos > 0 and len(lines) <= n:
            toread = min(4096, pos)
            pos -= toread
            f.seek(pos)
            data = f.read(toread)
            block = data + block
            lines = block.splitlines()
        return [l.decode('utf-8', 'ignore') for l in lines[-n:]]


def get_first_last_rollouts(paths: list[str], n_each: int = 2):
    """Return ``(first, last)``: the earliest and the latest ``n_each`` records across ``paths``.

    ``first`` is filled by reading files front to back; ``last`` is filled from the
    tails of the files in reverse order and returned in chronological order. Tail
    lines that cannot be read or parsed are skipped.
    """
    # Earliest records: walk the files in order and stop once enough were parsed.
    first = []
    for path in paths:
        with open(path) as fh:
            for raw in fh:
                text = raw.strip()
                if text:
                    first.append(json.loads(text))
                    if len(first) >= n_each:
                        break
        if len(first) >= n_each:
            break

    # Latest records: walk the files backwards, newest line first.
    newest_first = []
    for path in reversed(paths):
        try:
            # Ask for twice as many lines so blank or unparsable ones can be skipped.
            tail = _tail_lines(path, n_each * 2)
        except Exception:
            tail = []
        for text in tail[::-1]:
            if not text.strip():
                continue
            try:
                newest_first.append(json.loads(text))
            except Exception:
                continue
            if len(newest_first) >= n_each:
                break
        if len(newest_first) >= n_each:
            break

    # Flip back so the late set reads oldest -> newest.
    return first, newest_first[::-1]


def summarize_and_decode(records: list[dict], label: str, tokenizer=None, max_chars=800):
    """Print a summary of each record and, if a tokenizer is given, its decoded text.

    Args:
        records: Rollout records (dicts) to summarize.
        label: Tag printed in the headers (e.g. ``'EARLY'``).
        tokenizer: Object with a ``decode(ids, skip_special_tokens=...)`` method, or
            None to print only the numeric summary.
        max_chars: Maximum number of characters shown for the prompt and for the
            response. Pass None to print the full text.
    """
    # The label states the truncation actually applied (nothing when printing in full).
    limit_note = '' if max_chars is None else f' (first {max_chars} chars)'
    print(f'=== {label} ({len(records)} records) ===')
    for i, r in enumerate(records):
        print(f'-- {label} #{i+1}: step={r.get("step")}, sample_idx={r.get("sample_idx")}')
        print(f'   reward={r.get("reward")}, advantage={r.get("advantage")}, finish_reason={r.get("finish_reason")}')
        prompt_ids = r.get('prompt_tokens', [])
        resp_ids = r.get('response_tokens', [])
        print(f'   prompt_len={len(prompt_ids)}, response_len={len(resp_ids)}')
        lp = r.get('logprobs')
        if lp:
            # Only numeric, non-NaN entries contribute to the statistics.
            valid = [x for x in lp if isinstance(x, (int, float)) and not (isinstance(x, float) and np.isnan(x))]
            if valid:
                print(f'   logprobs: n={len(lp)}, mean={np.mean(valid):.4f}, std={np.std(valid):.4f}')
        if tokenizer is not None:
            try:
                ptext = tokenizer.decode(prompt_ids, skip_special_tokens=False)
                rtext = tokenizer.decode(resp_ids, skip_special_tokens=False)
            except Exception as e:
                # Best-effort: a decode problem on one record must not stop the report.
                print(f'   decode failed: {e}')
            else:
                # Newlines are flattened so each text stays on one output line.
                print(f'   [Prompt]{limit_note}')
                print(ptext[:max_chars].replace('\n', ' '))
                print(f'   [Response]{limit_note}')
                print(rtext[:max_chars].replace('\n', ' '))
        print()


# Run comparison (2 earliest, 2 latest)
early, late = get_first_last_rollouts(paths, n_each=2)
# load tokenizer if available (reuse earlier `tokenizer` variable if set)
# globals().get avoids a NameError when the tokenizer cell was never run.
tk = globals().get('tokenizer', None)
summarize_and_decode(early, 'EARLY', tokenizer=tk)
summarize_and_decode(late, 'LATE', tokenizer=tk)


In [None]:
# Print one (or more) task-response examples per step (memory-efficient)

def load_examples_by_step(paths: list[str], steps: list[int] | None = None, max_examples_per_step: int = 1):
    """Load up to `max_examples_per_step` examples per step from the rollouts."""
    examples = defaultdict(list)  # step -> list[rec]

    for filepath in paths:
        try:
            with open(filepath) as f:
                for raw in f:
                    line = raw.strip()
                    if not line:
                        continue
                    try:
                        rec = json.loads(line)
                    except Exception:
                        continue
                    step = rec.get('step')
                    if step is None:
                        continue
                    if steps and step not in steps:
                        continue
                    if len(examples[step]) < max_examples_per_step:
                        examples[step].append(rec)
        except Exception:
            continue

    return dict(examples)


# Only the most recent rollout file is scanned here. Use a separate name instead of
# rebinding `paths`, so cells that are meant to scan every rollout file still see the full list.
latest_paths = sorted(paths)[-1:]

examples = load_examples_by_step(latest_paths, steps=None, max_examples_per_step=1000)


In [None]:
def print_examples(examples: dict, tokenizer):
    """Print every stored example, step by step, with decoded prompt and response.

    Args:
        examples: Mapping step -> list of rollout records (dicts).
        tokenizer: Object with a ``decode(ids, skip_special_tokens=...)`` method.
            When None, only the header and the numeric summary are printed.
    """
    def _rounded(value):
        # Round real numbers for display; pass anything else (e.g. a JSON null) through unchanged.
        return round(value, 2) if isinstance(value, (int, float)) else value

    for step in sorted(examples.keys()):
        for i, r in enumerate(examples[step]):
            pids = r.get('prompt_tokens', [])
            rids = r.get('response_tokens', [])
            print(f'=== STEP {step}  ({i+1}/{len(examples[step])}) - dataset={r.get("dataset")} ===')
            print(f"  reward={_rounded(r.get('reward', 0))}, advantage={_rounded(r.get('advantage', 0))}, finish_reason={r.get('finish_reason')},  prompt_len={len(pids)}, response_len={len(rids)}")
            if tokenizer is None:
                print('  decode skipped: no tokenizer')
                continue
            try:
                ptext = tokenizer.decode(pids, skip_special_tokens=False)
                rtext = tokenizer.decode(rids, skip_special_tokens=False)
            except Exception as e:
                # Best-effort: keep printing the remaining examples if one fails to decode.
                print('  decode failed:', e)
            else:
                # Newlines are flattened so each text stays on one output line.
                print('  [Prompt]')
                print('   ', ptext.replace('\n', ' '))
                print('  [Response]')
                print('   ', rtext.replace('\n', ' '))
steps = sorted(examples.keys())
# Currently a full copy, i.e. every loaded step; slice (e.g. steps[-10:]) to narrow the output.
recent_steps = steps[:]
# Run the per-step printer; it prints every example stored in `examples` for each selected step.
print_examples(tokenizer=globals().get('tokenizer', None), examples={step: examples[step] for step in recent_steps})


In [None]:
# NOTE: step 328 is hard-coded; indexing raises KeyError when that step is not in `examples`.
examples[328][0].keys()
# Print the scalar / metadata fields of that record (token lists and logprobs are left out).
for k in ['step', 'sample_idx', 'prompt_idx', 'reward', 'advantage', 'finish_reason', 'dataset', 'ground_truth', 'request_info']:
    print(f"{k}: {examples[328][0].get(k)}")

In [None]:
# Steps for which examples were loaded.
examples.keys()

## 8. Diagnose vLLM vs Local Logprob Divergence

**Goal:** Check whether the logprob divergence is caused by prompt-response misalignment (wrong responses mapped to wrong prompts) or by genuine model drift.

Checks performed:
1. **Prompt consistency within groups:** All `num_samples_per_prompt` samples sharing the same `prompt_idx` at a given step must have identical `prompt_tokens`.
2. **Logprob length matches response length:** `len(logprobs) == len(response_tokens)` for every record.
3. **Logprob NaN/invalid fraction:** Fraction of NaN or missing logprobs per step.
4. **Local forward pass vs vLLM logprobs:** Load model, run a forward pass on saved token IDs, and compare.

In [None]:
# --- Check 1 & 2: Prompt-response alignment + logprob length consistency ---
# Streams ALL rollout files for the current run and reports:
# - Whether prompt_tokens are identical for records sharing (step, prompt_idx)
# - Whether len(logprobs) == len(response_tokens) for every record
# - NaN fraction in logprobs per step

from collections import defaultdict

alignment_errors = []  # records whose prompt differs from the first prompt seen in their group
logprob_len_mismatches = []  # records where len(logprobs) != len(response_tokens)
step_nan_fractions = defaultdict(list)  # step -> list of nan_fraction per sample
step_logprob_means = defaultdict(list)  # step -> list of mean valid logprob per sample
total_records = 0
# (step, prompt_idx) -> {'hash', 'first_sidx', 'first_pt_head'} of the first record seen in that group
prompt_groups = {}

for filepath in paths:
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue
            total_records += 1
            step = rec['step']
            pidx = rec['prompt_idx']
            sidx = rec['sample_idx']
            pt = rec['prompt_tokens']
            rt = rec['response_tokens']
            # logprobs are optional: None (missing or null) means "not saved", which is not a
            # length mismatch, so such records are left out of Check 2 and Check 3.
            lp = rec.get('logprobs')

            # Check 1: prompt consistency within groups
            group_key = (step, pidx)
            # Only a hash of the prompt is kept per group to keep memory use small.
            pt_hash = hash(tuple(pt))
            if group_key not in prompt_groups:
                prompt_groups[group_key] = {'hash': pt_hash, 'first_sidx': sidx, 'first_pt_head': pt[:5]}
            elif prompt_groups[group_key]['hash'] != pt_hash:
                alignment_errors.append({
                    'step': step, 'prompt_idx': pidx, 'sample_idx': sidx,
                    'expected_head': prompt_groups[group_key]['first_pt_head'],
                    'got_head': pt[:5],
                    'first_sidx': prompt_groups[group_key]['first_sidx'],
                })

            # Check 2: logprob length matches response length
            if lp is not None and len(lp) != len(rt):
                logprob_len_mismatches.append({
                    'step': step, 'sample_idx': sidx,
                    'resp_len': len(rt), 'lp_len': len(lp),
                })

            # Check 3: NaN fraction
            if lp:
                nan_count = sum(1 for x in lp if x is None or (isinstance(x, float) and np.isnan(x)))
                step_nan_fractions[step].append(nan_count / len(lp) if len(lp) > 0 else 0)
                valid = [x for x in lp if x is not None and not (isinstance(x, float) and np.isnan(x))]
                if valid:
                    step_logprob_means[step].append(float(np.mean(valid)))

# Summary of the scan above; at most five offending records are listed per check.
print(f"Total records scanned: {total_records}")
print(f"Unique (step, prompt_idx) groups: {len(prompt_groups)}")
print()

# Report Check 1
if alignment_errors:
    print(f"❌ PROMPT ALIGNMENT ERRORS: {len(alignment_errors)}")
    for err in alignment_errors[:5]:
        print(f"   step={err['step']}, prompt_idx={err['prompt_idx']}, "
              f"sample_idx={err['sample_idx']} vs first_sidx={err['first_sidx']}")
        print(f"   expected prompt[:5]={err['expected_head']}, got={err['got_head']}")
else:
    print("✅ Prompt alignment: All samples within each (step, prompt_idx) group have identical prompt_tokens.")

# Report Check 2
if logprob_len_mismatches:
    print(f"\n❌ LOGPROB LENGTH MISMATCHES: {len(logprob_len_mismatches)}")
    for mm in logprob_len_mismatches[:5]:
        print(f"   step={mm['step']}, sample_idx={mm['sample_idx']}: resp_len={mm['resp_len']}, lp_len={mm['lp_len']}")
else:
    print("✅ Logprob lengths: len(logprobs) == len(response_tokens) for all records.")

# Report Check 3
# Only steps divisible by 50 are printed, so nothing is listed if no such step has logprobs.
print(f"\n--- NaN fraction in vLLM logprobs per step (sample of every 50 steps) ---")
for step in sorted(step_nan_fractions.keys()):
    if step % 50 == 0:
        fracs = step_nan_fractions[step]
        print(f"  step={step:5d}: mean_nan_frac={np.mean(fracs):.4f}, max_nan_frac={np.max(fracs):.4f}, n={len(fracs)}")

In [None]:
# --- Plot vLLM logprob mean over training steps ---
# This shows if vLLM logprobs change as training progresses (they should become
# slightly different from local logprobs as the learner model drifts between weight syncs).

# Per step: mean and standard deviation of the per-sample mean logprobs collected by the scan.
steps_sorted = sorted(step_logprob_means.keys())
lp_mean_per_step = [np.mean(step_logprob_means[s]) for s in steps_sorted]
lp_std_per_step = [np.std(step_logprob_means[s]) for s in steps_sorted]

try:
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(1, 1, figsize=(12, 4))
    ax.plot(steps_sorted, lp_mean_per_step, label='vLLM logprob mean (per step)')
    # Shaded band: mean +/- one standard deviation across the samples of each step.
    ax.fill_between(steps_sorted,
                     [m - s for m, s in zip(lp_mean_per_step, lp_std_per_step)],
                     [m + s for m, s in zip(lp_mean_per_step, lp_std_per_step)],
                     alpha=0.2)
    ax.set_xlabel('step')
    ax.set_ylabel('mean logprob')
    ax.set_title('vLLM logprob mean across training steps')
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.tight_layout()
    plt.show()
except ImportError:
    # Text fallback: print every 50th step instead of plotting.
    print("matplotlib not available; printing table instead.")
    for s, m, sd in zip(steps_sorted, lp_mean_per_step, lp_std_per_step):
        if s % 50 == 0:
            print(f"  step={s:5d}: lp_mean={m:.4f}, lp_std={sd:.4f}")

### 8.1 Local forward pass vs vLLM logprobs

Load the **base model** (before any training), run a forward pass on the saved `prompt_tokens + response_tokens`, and compare the local logprobs with the stored vLLM logprobs.

If at step 0 the logprobs match closely (mean absolute difference below 0.05, the threshold the comparison cell reports as a match), the prompt-response mapping is correct. If they diverge even at step 0, something is wrong with how tokens are saved or how logprobs are extracted.

**Note:** This requires a GPU. The model loaded here is the *initial* checkpoint, so it should match vLLM logprobs at early steps. At later steps, the learner model will have drifted.

In [None]:
# --- Local forward pass comparison with vLLM logprobs ---
# Pick a few samples from step 0 (where model = initial weights = vLLM weights) and
# a few from a later step to compare.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model id from the run metadata; the literal below is only used when metadata has no "model_name".
# NOTE: this rebinds MODEL_NAME and its fallback differs from the one in the tokenizer cell.
MODEL_NAME = (metadata or {}).get("model_name", "allenai/Olmo-3-7B-Think-DPO")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_COMPARE = 5  # number of samples to compare
TEMPERATURE = 1.0  # must match the generation temperature used during rollout

print(f"Loading model: {MODEL_NAME} on {DEVICE}")
# Weights are loaded in bfloat16, moved to DEVICE and switched to eval mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True
).to(DEVICE).eval()
# Separate tokenizer handle (`tok`) used by the comparison cell to decode single tokens.
tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
print(f"Model loaded. Vocab size: {model.config.vocab_size}")

In [None]:
# --- Compute local logprobs and compare with vLLM logprobs ---

def compute_local_logprobs(model, input_ids: list[int], device: str, temperature: float = 1.0) -> np.ndarray:
    """Run one forward pass and return the logprob of every token given its prefix.

    Args:
        model: Callable taking a ``(1, seq_len)`` LongTensor and returning an object
            whose ``.logits`` has shape ``(1, seq_len, vocab_size)``.
        input_ids: Full token sequence (prompt followed by response).
        device: Torch device on which the input tensor is created.
        temperature: Logits are divided by this value before normalisation.

    Returns:
        float64 array of length ``len(input_ids) - 1``; element ``t - 1`` is the
        logprob of ``input_ids[t]`` given ``input_ids[:t]``. The first token has no
        context and therefore no entry. Sequences shorter than two tokens yield an
        empty array without calling the model.
    """
    if len(input_ids) < 2:
        return np.array([], dtype=np.float64)

    with torch.no_grad():
        ids = torch.tensor([input_ids], dtype=torch.long, device=device)
        # Position t predicts token t + 1, so the last position has no target and is dropped.
        # Upcast to float32: normalising in bf16/fp16 adds rounding error of the same order
        # as the differences this diagnostic is trying to measure.
        logits = model(ids).logits[0, :-1].float()  # (seq_len - 1, vocab_size)
        if temperature != 1.0:
            logits = logits / temperature
        targets = ids[0, 1:].unsqueeze(-1)  # (seq_len - 1, 1)
        # log p(token) = logit(token) - logsumexp(logits); one gather replaces the per-token loop.
        picked = logits.gather(-1, targets).squeeze(-1)
        token_logprobs = picked - torch.logsumexp(logits, dim=-1)
    return token_logprobs.double().cpu().numpy()


def compare_logprobs_for_records(records: list[dict], model, device: str, temperature: float = 1.0, label: str = ""):
    """Compare local vs vLLM logprobs for a list of rollout records.

    For each record the prompt and response tokens are concatenated, scored with
    ``compute_local_logprobs`` and the response part is compared with the stored
    vLLM logprobs. Per-sample and overall absolute-difference statistics are
    printed; nothing is returned.

    Records are skipped (with a message) when they have no logprobs, an empty
    response, an empty prompt, a length mismatch, or only NaN vLLM logprobs.

    Args:
        records: Rollout records with ``prompt_tokens``, ``response_tokens`` and
            optionally ``logprobs``.
        model: Model passed through to ``compute_local_logprobs``.
        device: Torch device passed through to ``compute_local_logprobs``.
        temperature: Temperature passed through to ``compute_local_logprobs``.
        label: Text shown in the header.

    Note:
        Token text is decoded with the notebook-level ``tok`` tokenizer.
    """
    print(f"\n{'='*60}")
    print(f"Comparing local vs vLLM logprobs: {label} ({len(records)} samples)")
    print(f"{'='*60}")

    all_diffs = []
    for i, rec in enumerate(records):
        prompt_ids = rec['prompt_tokens']
        resp_ids = rec['response_tokens']
        vllm_lp = rec.get('logprobs', [])

        if not vllm_lp or not resp_ids:
            print(f"  Sample {i}: skipped (no logprobs or empty response)")
            continue
        if not prompt_ids:
            # Without a prompt the first response token has no local logprob, and the
            # slice below would start at index -1.
            print(f"  Sample {i}: skipped (empty prompt)")
            continue

        full_ids = prompt_ids + resp_ids
        local_lp_full = compute_local_logprobs(model, full_ids, device, temperature)

        # local_lp_full has logprobs for positions 1..len(full_ids)-1
        # vLLM logprobs correspond to response tokens only
        prompt_len = len(prompt_ids)
        # In local_lp_full, index (prompt_len - 1) corresponds to logprob of token at position prompt_len
        local_lp_response = local_lp_full[prompt_len - 1:]  # logprobs for resp tokens

        # dtype=float64 turns None entries into NaN so they are masked out below.
        vllm_lp_arr = np.array(vllm_lp, dtype=np.float64)
        # Filter out NaN vllm logprobs
        valid_mask = ~np.isnan(vllm_lp_arr)

        if len(local_lp_response) != len(vllm_lp_arr):
            print(f"  Sample {i}: LENGTH MISMATCH local={len(local_lp_response)} vs vllm={len(vllm_lp_arr)}")
            continue
        if not valid_mask.any():
            # mean()/max() of an empty selection would warn or raise.
            print(f"  Sample {i}: skipped (all vLLM logprobs are NaN)")
            continue

        diffs = np.abs(local_lp_response[valid_mask] - vllm_lp_arr[valid_mask])
        all_diffs.extend(diffs.tolist())

        mean_diff = diffs.mean()
        max_diff = diffs.max()
        median_diff = np.median(diffs)

        print(f"  Sample {i} (step={rec.get('step')}, sidx={rec.get('sample_idx')}): "
              f"mean_abs_diff={mean_diff:.6f}, max_abs_diff={max_diff:.6f}, median={median_diff:.6f}, "
              f"resp_len={len(resp_ids)}")

        # Show first 5 token comparisons
        for t in range(min(5, len(resp_ids))):
            if valid_mask[t]:
                token_text = tok.decode([resp_ids[t]])
                print(f"    token[{t}]={resp_ids[t]:6d} ({token_text!r:>12s}): "
                      f"local={local_lp_response[t]:.6f}, vllm={vllm_lp_arr[t]:.6f}, "
                      f"diff={abs(local_lp_response[t] - vllm_lp_arr[t]):.6f}")

    if all_diffs:
        all_diffs = np.array(all_diffs)
        print(f"\n  OVERALL: mean_abs_diff={all_diffs.mean():.6f}, max={all_diffs.max():.6f}, "
              f"median={np.median(all_diffs):.6f}, p95={np.percentile(all_diffs, 95):.6f}, "
              f"p99={np.percentile(all_diffs, 99):.6f}")
        # Verdict thresholds on the overall mean absolute difference: < 0.05 match, < 0.5 small, else large.
        if all_diffs.mean() < 0.05:
            print("  ✅ Logprobs match closely — prompt-response alignment looks correct.")
        elif all_diffs.mean() < 0.5:
            print("  ⚠️  Small divergence — likely due to bf16/fp16 precision differences or temperature.")
        else:
            print("  ❌ Large divergence — possible prompt-response misalignment or model mismatch!")


# Collect samples from step 0 (earliest) and a late step
step0_records = []
late_step_records = []
# Latest step that has stored logprobs, taken from the alignment scan (0 if that scan found none).
latest_step = max(step_logprob_means.keys()) if step_logprob_means else 0

for filepath in paths:
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            if rec['step'] == 0 and len(step0_records) < NUM_COMPARE:
                step0_records.append(rec)
            if rec['step'] == latest_step and len(late_step_records) < NUM_COMPARE:
                late_step_records.append(rec)
            # Stop reading this file once both sample sets are full ...
            if len(step0_records) >= NUM_COMPARE and len(late_step_records) >= NUM_COMPARE:
                break
    # ... and skip the remaining files as well.
    if len(step0_records) >= NUM_COMPARE and len(late_step_records) >= NUM_COMPARE:
        break

print(f"Collected {len(step0_records)} step-0 records, {len(late_step_records)} step-{latest_step} records")

# Compare step 0 (should match closely since model = initial weights = vLLM weights)
compare_logprobs_for_records(step0_records, model, DEVICE, TEMPERATURE, label=f"Step 0 (initial model)")

# Compare latest step (expected to diverge more since learner has trained)
compare_logprobs_for_records(late_step_records, model, DEVICE, TEMPERATURE, label=f"Step {latest_step} (latest)")

### 8.2 Response-prompt coherence check

An additional check: decode a few prompt-response pairs and verify that the response is actually a plausible continuation of the prompt (not a response to a different question). This catches subtle misalignment that token-level checks might miss.

In [None]:
# --- Response-prompt coherence: eyeball check with decoded text ---
# Besides printing text, this verifies that the samples drawn for one
# (step, prompt_idx) group are not all the same token sequence; identical
# samples would hint at a sampling problem.

def check_response_diversity_and_coherence(
    paths: list[str], steps_to_check: list[int], n_groups: int = 3, tokenizer=None
):
    """Print prompt/response samples per step and flag duplicated responses.

    For each step in ``steps_to_check`` every file in ``paths`` is streamed line
    by line and the records of that step are bucketed by ``prompt_idx``. Once a
    file has been read to the end, scanning stops if at least ``n_groups + 2``
    buckets exist. The ``n_groups`` smallest ``prompt_idx`` buckets are then
    reported, looking at no more than six records of each.

    Args:
        paths: Rollout JSONL files, one JSON record per line.
        steps_to_check: Values of the ``step`` field to inspect.
        n_groups: Number of prompt groups reported per step.
        tokenizer: Object offering ``decode(ids, skip_special_tokens=...)``.
            When falsy no text is printed, but the duplicate check still runs.
    """
    rule = '=' * 60
    max_shown = 6  # responses inspected per prompt group
    enough_groups = n_groups + 2

    for target_step in steps_to_check:
        print(f"\n{rule}")
        print(f"Step {target_step}: Coherence & Diversity Check")
        print(rule)

        by_prompt = defaultdict(list)  # prompt_idx -> records of this step
        for filepath in paths:
            with open(filepath) as fh:
                for raw_line in fh:
                    payload = raw_line.strip()
                    if payload:
                        rec = json.loads(payload)
                        if rec['step'] == target_step:
                            by_prompt[rec['prompt_idx']].append(rec)
            # Only tested between files, so a started file is always read completely.
            if len(by_prompt) >= enough_groups:
                break

        if len(by_prompt) == 0:
            print(f"  No records found for step {target_step}")
            continue

        for pidx in sorted(by_prompt)[:n_groups]:
            members = by_prompt[pidx]
            shown = members[:max_shown]
            print(f"\n  --- prompt_idx={pidx}, n_samples={len(members)} ---")

            # The prompt is taken from the first record of the group
            if tokenizer:
                prompt_text = tokenizer.decode(members[0]['prompt_tokens'], skip_special_tokens=False)
                print(f"  [Prompt] (first 400 chars)")
                print(f"  {prompt_text[:400]}")
            ground_truth = str(members[0].get('ground_truth', ''))
            print(f"  [Ground truth]: {ground_truth[:200]}")

            # Two responses count as duplicates when their token tuples hash equal
            seen = set()
            for i, sample in enumerate(shown):
                seen.add(hash(tuple(sample['response_tokens'])))
                if not tokenizer:
                    continue
                resp_text = tokenizer.decode(sample['response_tokens'], skip_special_tokens=False)
                print(f"  [Response {i}] reward={sample['reward']:.2f}, len={len(sample['response_tokens'])}, "
                      f"finish={sample['finish_reason']}")
                print(f"    {resp_text[:300]}")

            n_shown = len(shown)
            if len(seen) < n_shown:
                print(f"  ⚠️  Only {len(seen)} unique responses out of {n_shown} samples!")
            else:
                print(f"  ✅ All {n_shown} sampled responses are unique.")


# Inspect an early, a middle, and a late step (duplicates collapsed, ascending order)
all_steps = sorted(step_logprob_means)
if all_steps:
    mid_step = all_steps[len(all_steps) // 2]
    check_steps = sorted({0, mid_step, all_steps[-1]})
else:
    # No per-step statistics available: fall back to step 0 only
    mid_step = 0
    check_steps = [0]

# Use a global named `tokenizer` when present; otherwise `tok` if that name exists, else None
tk = globals().get('tokenizer', tok if 'tok' in dir() else None)
check_response_diversity_and_coherence(
    paths=paths,
    steps_to_check=check_steps,
    n_groups=2,
    tokenizer=tk,
)

### 8.3 Off-by-one check: first response token logprob

vLLM can sometimes return N-1 logprobs for N response tokens (missing the first token's logprob). Check if the first logprob in each record is NaN (indicating it was synthesized) vs a real value. Also check if shifting the vLLM logprobs by 1 gives a better match with local logprobs (would indicate an off-by-one in the alignment).

In [None]:
# --- Off-by-one check ---
# If vLLM logprobs are shifted by 1 relative to local logprobs, shifting them would
# reduce the disagreement. The normal alignment pairs vllm_lp[k] with
# local_lp[prompt_len-1+k]; the shifted variants move that pairing by exactly one:
#   1) Normal alignment: local_lp[prompt_len-1:] vs vllm_lp
#   2) Shifted by +1:    local_lp[prompt_len:]   vs vllm_lp[:-1]  (vLLM is ahead by 1)
#   3) Shifted by -1:    local_lp[prompt_len-1:] vs vllm_lp[1:]   (vLLM is behind by 1)
# (Pairing local_lp[prompt_len-2:] with vllm_lp[1:] would move BOTH sides and
# therefore test a shift of 2, not 1.)

def _masked_mean_abs_diff(local_lp, vllm_lp):
    """Mean |local - vLLM| over the common prefix, ignoring NaN vLLM entries.

    Returns None when the two sequences share no position with a non-NaN vLLM value.
    """
    n = min(len(local_lp), len(vllm_lp))
    if n == 0:
        return None
    local_part, vllm_part = local_lp[:n], vllm_lp[:n]
    keep = ~np.isnan(vllm_part)
    if not keep.any():
        return None
    return np.abs(local_part[keep] - vllm_part[keep]).mean()


def check_offbyone(records: list[dict], model, device: str, temperature: float = 1.0, label: str = ""):
    """Report whether shifting vLLM logprobs by one position matches local logprobs better.

    For each usable record the local logprobs of prompt + response are computed with
    ``compute_local_logprobs`` and compared with the stored vLLM logprobs under three
    alignments ('normal', 'shift+1', 'shift-1'). The per-alignment mean absolute
    difference (averaged over records) is printed, followed by a verdict and the
    number of records whose first stored logprob is NaN.

    Args:
        records: Rollout records with ``prompt_tokens``, ``response_tokens`` and
            optionally ``logprobs``. Records without logprobs or with fewer than
            five response tokens are skipped.
        model, device, temperature: Passed through to ``compute_local_logprobs``.
        label: Text shown in the section header.

    Returns:
        None. All results are printed.
    """
    print(f"\n--- Off-by-one check: {label} ---")
    results = {'normal': [], 'shift+1': [], 'shift-1': []}

    for rec in records:
        prompt_ids = rec['prompt_tokens']
        resp_ids = rec['response_tokens']
        vllm_lp = rec.get('logprobs', [])
        if not vllm_lp or len(resp_ids) < 5:
            continue

        full_ids = prompt_ids + resp_ids
        # assumes compute_local_logprobs returns an indexable array with one entry per
        # token after the first (the normal-alignment slice below relies on it) — TODO confirm
        local_lp_full = compute_local_logprobs(model, full_ids, device, temperature)
        prompt_len = len(prompt_ids)
        vllm_arr = np.array(vllm_lp, dtype=np.float64)

        pairs = {}
        if prompt_len >= 1:
            local_resp = local_lp_full[prompt_len - 1:]
            # Normal alignment is only meaningful when both sides cover the same tokens
            if len(local_resp) == len(vllm_arr):
                pairs['normal'] = (local_resp, vllm_arr)
            # vLLM behind by one: vllm[k] would describe response token k-1
            pairs['shift-1'] = (local_resp, vllm_arr[1:])
        # vLLM ahead by one: vllm[k] would describe response token k+1
        pairs['shift+1'] = (local_lp_full[prompt_len:], vllm_arr[:-1])

        for key, (local_part, vllm_part) in pairs.items():
            diff = _masked_mean_abs_diff(local_part, vllm_part)
            # Skip alignments with nothing to compare so no NaN mean pollutes the average
            if diff is not None:
                results[key].append(diff)

    for key in ['normal', 'shift+1', 'shift-1']:
        if results[key]:
            print(f"  {key:>10s}: mean_abs_diff={np.mean(results[key]):.6f} (over {len(results[key])} samples)")
        else:
            print(f"  {key:>10s}: no data")

    # Determine best alignment among those that produced data
    means = {k: np.mean(v) for k, v in results.items() if v}
    if not means:
        # Without any comparison a "no issue detected" verdict would be misleading
        print("  ⚠️  No comparable records — alignment could not be assessed.")
    else:
        best = min(means, key=means.get)
        if best == 'normal':
            print("  ✅ Normal alignment is best — no off-by-one issue detected.")
        else:
            print(f"  ⚠️  '{best}' alignment gives lower error — possible off-by-one in logprob indexing!")

    # Also check: is the first logprob typically NaN?
    first_lp_nan = sum(1 for rec in records if rec.get('logprobs') and np.isnan(rec['logprobs'][0]))
    first_lp_valid = sum(1 for rec in records if rec.get('logprobs') and not np.isnan(rec['logprobs'][0]))
    print(f"  First response token logprob: NaN in {first_lp_nan}/{first_lp_nan + first_lp_valid} records")

# Step 0 is always checked; the latest step only when records were collected for it.
check_offbyone(
    records=step0_records,
    model=model,
    device=DEVICE,
    temperature=TEMPERATURE,
    label="Step 0",
)
if late_step_records:
    check_offbyone(
        records=late_step_records,
        model=model,
        device=DEVICE,
        temperature=TEMPERATURE,
        label=f"Step {latest_step}",
    )

### 8.4 Per-token logprob diff distribution (line and bar plots for up to three step-0 samples)

Visualize the per-token absolute difference between local and vLLM logprobs across the full response. Helps spot if the divergence is concentrated at certain positions (e.g., beginning, end, or at tool call boundaries).

In [None]:
# --- Per-token logprob diff plot for step 0 samples ---
# Figure 1: local vs vLLM logprob curves per sample. Figure 2: their absolute difference.

try:
    # Only the import is guarded, so errors raised while plotting are not hidden
    import matplotlib.pyplot as plt
except ImportError:
    print("matplotlib not available; skip plots.")
else:
    # Collect what both figures need in one pass, so the local forward pass
    # runs once per sample instead of once per figure.
    plot_samples = []
    for idx, rec in enumerate(step0_records[:3]):
        vllm_lp = np.array(rec.get('logprobs') or [], dtype=np.float64)
        if len(vllm_lp) == 0:
            continue

        prompt_ids = rec['prompt_tokens']
        full_ids = prompt_ids + rec['response_tokens']
        local_lp_full = compute_local_logprobs(model, full_ids, DEVICE, TEMPERATURE)
        # Same "normal" alignment as used by the off-by-one check
        local_lp_resp = local_lp_full[len(prompt_ids) - 1:]
        if len(local_lp_resp) != len(vllm_lp):
            # A boolean mask built from vllm_lp cannot index an array of another length
            print(f"Sample {idx}: {len(local_lp_resp)} local vs {len(vllm_lp)} vLLM logprobs; skipped.")
            continue

        valid = ~np.isnan(vllm_lp)
        if not valid.any():
            print(f"Sample {idx}: all vLLM logprobs are NaN; skipped.")
            continue

        plot_samples.append({
            'idx': idx,
            'rec': rec,
            'positions': np.arange(len(vllm_lp))[valid],
            'local': local_lp_resp[valid],
            'vllm': vllm_lp[valid],
        })

    if not plot_samples:
        # plt.subplots() rejects zero rows, so bail out before creating figures
        print("No step-0 samples with comparable logprobs; skip plots.")
    else:
        n_rows = len(plot_samples)

        fig, axes = plt.subplots(n_rows, 1, figsize=(14, 3 * n_rows), squeeze=False)
        for row, sample in enumerate(plot_samples):
            rec = sample['rec']
            ax = axes[row, 0]
            # Plot both logprob curves
            ax.plot(sample['positions'], sample['local'], alpha=0.7, label='local', linewidth=0.5)
            ax.plot(sample['positions'], sample['vllm'], alpha=0.7, label='vLLM', linewidth=0.5)
            ax.set_xlabel('response token position')
            ax.set_ylabel('logprob')
            ax.set_title(f'Sample {sample["idx"]} (step={rec["step"]}, sidx={rec["sample_idx"]}): '
                         f'mean_abs_diff={np.abs(sample["local"] - sample["vllm"]).mean():.4f}')
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

        # Also plot the absolute difference
        fig2, axes2 = plt.subplots(n_rows, 1, figsize=(14, 2 * n_rows), squeeze=False)
        for row, sample in enumerate(plot_samples):
            ax = axes2[row, 0]
            ax.bar(sample['positions'], np.abs(sample['local'] - sample['vllm']),
                   width=1.0, alpha=0.7, color='red')
            ax.set_xlabel('response token position')
            ax.set_ylabel('|local - vLLM|')
            ax.set_title(f'Abs diff sample {sample["idx"]}')
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

### 8.5 Parse training log divergence metrics

Extract `vllm_vs_local_logprob_diff_mean` from the training log file and plot it vs training step to see exactly when the divergence begins.

In [None]:
# --- Parse training log to extract divergence metrics ---
import re
import subprocess

LOG_FILE = "/mnt/vast/home/lh22zyta/shortcut-RL/open-instruct/logs/RLVR-soofi-Olmo_IsomorphicRL_86561.out"

def parse_training_log(log_file: str) -> pd.DataFrame:
    """Extract training_step and debug metrics from the log file.

    Each metric is located independently with a regular expression and the
    i-th occurrence of every metric is combined into the i-th row.

    NOTE(review): because rows are assembled by order of occurrence, a step that
    logs no value for one metric shifts that metric for all later rows — confirm
    the log emits every metric once per step.

    Args:
        log_file: Path of the training log (read as UTF-8; undecodable bytes are
            replaced so stray binary output cannot abort parsing).

    Returns:
        DataFrame with columns ``training_step`` and ``logprob_diff_mean`` and,
        when found in the log, ``logprob_diff_max``, ``policy_loss`` and
        ``weight_sync_max``. The number of rows is the smaller of the
        ``training_step`` and ``vllm_vs_local_logprob_diff_mean`` match counts.
        If there are none, the frame is empty but still has the two mandatory
        columns, so column access and ``describe()`` keep working.
    """
    patterns = {
        'training_step': re.compile(r'training_step:\s*(\d+)'),
        'logprob_diff_mean': re.compile(r'vllm_vs_local_logprob_diff_mean:\s*([\d.e+-]+)'),
        'logprob_diff_max': re.compile(r'vllm_vs_local_logprob_diff_max:\s*([\d.e+-]+)'),
        'policy_loss': re.compile(r'policy_avg:\s*([\d.e+-]+)'),
        'weight_sync_max': re.compile(r'weight_sync_max:\s*([\d.e+-]+)'),
    }
    mandatory = ['training_step', 'logprob_diff_mean']
    optional = ['logprob_diff_max', 'policy_loss', 'weight_sync_max']

    with open(log_file, encoding="utf-8", errors="replace") as f:
        content = f.read()

    found = {name: pattern.findall(content) for name, pattern in patterns.items()}

    # A row needs both a step number and the mean diff
    n = min(len(found['training_step']), len(found['logprob_diff_mean']))
    rows = []
    for i in range(n):
        row = {
            'training_step': int(found['training_step'][i]),
            'logprob_diff_mean': float(found['logprob_diff_mean'][i]),
        }
        # Optional metrics may have fewer occurrences; missing cells become NaN
        for name in optional:
            if i < len(found[name]):
                row[name] = float(found[name][i])
        rows.append(row)

    if not rows:
        return pd.DataFrame(columns=mandatory)
    return pd.DataFrame(rows)


# Summarize the parsed log and plot divergence, policy loss and weight-sync time per step.
if os.path.exists(LOG_FILE):
    log_df = parse_training_log(LOG_FILE)
    print(f"Parsed {len(log_df)} training steps from log")

    if log_df.empty:
        # describe() and the column lookups below cannot work without any parsed rows
        print("No divergence metrics found in log; nothing to summarize or plot.")
    else:
        display(log_df.describe())

        try:
            # Only the import is guarded, so errors raised while plotting are not hidden
            import matplotlib.pyplot as plt
        except ImportError:
            print("matplotlib not available; skip plots")
        else:
            fig, axes = plt.subplots(2, 2, figsize=(14, 8))

            ax = axes[0, 0]
            ax.plot(log_df['training_step'], log_df['logprob_diff_mean'], linewidth=0.8)
            ax.set_xlabel('training step')
            ax.set_ylabel('logprob diff mean')
            ax.set_title('vLLM vs Local Logprob Diff (Mean)')
            ax.grid(True, alpha=0.3)
            # Mark the divergence onset: first row whose mean diff exceeds the threshold
            threshold = 0.5
            diverge_steps = log_df[log_df['logprob_diff_mean'] > threshold]
            if len(diverge_steps) > 0:
                first_diverge = diverge_steps.iloc[0]['training_step']
                ax.axvline(x=first_diverge, color='red', linestyle='--', alpha=0.7,
                           label=f'First diff > {threshold} at step {int(first_diverge)}')
                ax.legend()

            # Remaining panels are drawn only when their metric was found in the log;
            # otherwise the axes stay empty.
            optional_panels = [
                (axes[0, 1], 'logprob_diff_max', 'orange', 'logprob diff max', 'vLLM vs Local Logprob Diff (Max)'),
                (axes[1, 0], 'policy_loss', 'green', 'policy loss', 'Policy Loss'),
                (axes[1, 1], 'weight_sync_max', 'purple', 'weight sync max (s)', 'Weight Sync Time (Max)'),
            ]
            for ax, column, color, ylabel, title in optional_panels:
                if column not in log_df.columns:
                    continue
                ax.plot(log_df['training_step'], log_df[column], linewidth=0.8, color=color)
                ax.set_xlabel('training step')
                ax.set_ylabel(ylabel)
                ax.set_title(title)
                ax.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.show()
else:
    print(f"Log file not found: {LOG_FILE}")