# GRubrics — Análisis de Rúbricas Generadas

Notebook para inspeccionar y comparar rúbricas generadas por distintos modelos/checkpoints.

**Cómo usar:**
1. Correr las celdas de Setup (§1-§4) una sola vez por sesión.
2. En §5, cargar los checkpoints que querés comparar (tarda 1-2 min cada uno).
3. Navegar preguntas del holdout en §6 y generar rúbricas en §7.
4. Comparar múltiples checkpoints para la misma pregunta en §8.
5. Evaluación con Judge (opcional, ~$0.009/pregunta) en §9.
6. Mini-eval sobre N preguntas aleatorias en §10.
7. Comparar métricas entre checkpoints (CSVs de §10) en §11.
8. **Inspeccionar rúbricas guardadas durante el training en §12** (sin recargar modelos).

**Rúbricas durante training:** La reward function guarda automáticamente el texto de cada
rúbrica generada en `data/results/rubrics/step_XXXX.jsonl`. §12 las carga y grafica cómo
evolucionan alignment, reward y largo a lo largo del training.

---

## §1 — Imports

In [None]:
import sys
import json
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from IPython.display import display, HTML

# Project root (the notebook lives in notebooks/, the project one level up).
# ROOT is derived from the current working directory, so the kernel is
# expected to be started from the notebooks/ folder.
ROOT = Path(".").resolve().parent
# Make the project importable; this must happen before the imports below.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from grubrics_science.data.base import DatasetAdapter
from grubrics_science.evaluation.holdout import load_healthbench_with_cache, split_holdout
from grubrics_science.evaluation.metrics import (
    alignment_score, discrimination_score, format_validity,
    points_sum, info_value, rubric_length,
)

# Report the resolved root and which device is available.
print(f"ROOT: {ROOT}")
print(f"CUDA: {torch.cuda.is_available()} | device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

## §2 — Configuración de paths y checkpoints

In [None]:
# ─── Paths ────────────────────────────────────────────────────────────────────
CHECKPOINT_DIR  = ROOT / "checkpoints" / "grubrics-transfer" / "healthbench-grpo"
SFT_DIR         = ROOT / "checkpoints" / "grubrics-transfer" / "sft-healthbench" / "final"
HB_EVAL_PATH    = ROOT / "data" / "healthbench" / "oss_eval.jsonl"
HB_CACHE_PATH   = ROOT / "data" / "cache" / "healthbench_precompute.jsonl"
BASE_MODEL_ID   = "Qwen/Qwen3-8B"
HOLDOUT_SIZE    = 500
HOLDOUT_SEED    = 42

# ─── Auto-detect the available checkpoints ────────────────────────────────────
# Maps a short display name to either the base model id or a local directory.
checkpoints = {}

# Base model (no fine-tuning)
checkpoints["base_zeroshot"] = BASE_MODEL_ID

# SFT (only registered when the directory exists)
if SFT_DIR.exists():
    checkpoints["sft"] = str(SFT_DIR)

# GRPO steps (step_200/actor, step_400/actor, ...), sorted numerically by step.
# NOTE(review): a directory named step_<non-integer> would make int() raise here.
if CHECKPOINT_DIR.exists():
    for actor_dir in sorted(CHECKPOINT_DIR.glob("step_*/actor"),
                            key=lambda p: int(p.parent.name.split("_")[1])):
        step = int(actor_dir.parent.name.split("_")[1])
        checkpoints[f"grpo_step{step}"] = str(actor_dir)

# Print every registered checkpoint with a marker for whether its path exists
# (the base model id is always treated as available).
print("Checkpoints detectados:")
for name, path in checkpoints.items():
    exists = "✓" if (path == BASE_MODEL_ID or Path(path).exists()) else "✗ (no encontrado)"
    print(f"  [{exists}] {name:25s} → {path}")

## §3 — Cargar holdout HealthBench

In [None]:
# Load HealthBench (merged with the precompute cache) and take the holdout split.
hb_data = load_healthbench_with_cache(
    eval_path=str(HB_EVAL_PATH),
    cache_path=str(HB_CACHE_PATH),
)
_, holdout = split_holdout(hb_data, holdout_size=HOLDOUT_SIZE, seed=HOLDOUT_SEED)

# Keep only the entries that have gold_scores (complete precompute)
holdout_with_scores = [e for e in holdout if e.get("gold_scores")]

print(f"Holdout total         : {len(holdout)} preguntas")
print(f"Con gold_scores       : {len(holdout_with_scores)} preguntas")
print(f"Sin gold_scores (skip): {len(holdout) - len(holdout_with_scores)} preguntas")

# Quick look at the first entry. Guarded: when no entry has gold_scores the
# list is empty and indexing [0] would raise IndexError right after the
# summary above already reported the problem.
if holdout_with_scores:
    q = holdout_with_scores[0]
    print(f"\nEjemplo [0]:")
    print(f"  ID       : {q.get('question_id', q.get('prompt_id', '?'))}")
    print(f"  Pregunta : {q['question'][:150]}...")
    print(f"  Answers  : {len(q['answers'])}")
    print(f"  Gold     : {[round(s, 2) for s in q['gold_scores']]}")
    print(f"  Golden rubric (primeros 200 chars): {q['golden_rubric'][:200]}...")
else:
    print(f"\n⚠ Ninguna pregunta del holdout tiene gold_scores; revisá {HB_CACHE_PATH}")

## §4 — Helpers: carga de modelos y generación de rúbricas

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Session-wide cache of loaded checkpoints, filled by load_checkpoint and
# emptied by unload_checkpoint / unload_all.
_model_cache: dict = {}  # name → (model, tokenizer)


def load_checkpoint(name: str):
    """Load a registered checkpoint for inference, reusing the session cache.

    Three cases are handled:

    * the base model itself (``checkpoints[name] == BASE_MODEL_ID``);
    * a peft LoRA adapter directory (contains ``adapter_config.json``): the
      base model is loaded, the adapter applied and merged into it;
    * a directory without ``adapter_config.json``: loaded directly as a
      complete HF model.

    The adapter check runs *before* any weights are loaded, so the base
    model is only materialised when it is actually needed.

    Args:
        name: Key into the module-level ``checkpoints`` dict.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.

    Raises:
        ValueError: If ``name`` is not a registered checkpoint.
    """
    if name in _model_cache:
        print(f"✓ {name!r} ya está en caché")
        return _model_cache[name]

    path = checkpoints.get(name)
    if path is None:
        raise ValueError(f"Checkpoint {name!r} no encontrado. Disponibles: {list(checkpoints)}")

    is_base = (path == BASE_MODEL_ID)
    print(f"Cargando {name!r}...")
    print(f"  Base model : {BASE_MODEL_ID}")
    if not is_base:
        print(f"  Adapter    : {path}")

    # Tokenizer: prefer the one stored with the checkpoint, fall back to the
    # base model's. The broad except is a deliberate best-effort fallback.
    if is_base:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        except Exception:
            print(f"  ⚠ Tokenizer no en {path}, usando base model")
            tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)

    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    has_adapter = (not is_base) and (Path(path) / "adapter_config.json").exists()

    if is_base or has_adapter:
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **load_kwargs)
        if has_adapter:
            # peft-format LoRA adapter on top of the base model
            from peft import PeftModel
            print("  → Cargando como peft LoRA adapter")
            model = PeftModel.from_pretrained(model, path)
            model = model.merge_and_unload()  # merge for faster inference
    else:
        # Checkpoint with merged weights (complete HF model): load it
        # directly instead of loading the base model first and discarding it.
        print("  → No hay adapter_config.json, cargando como HF model completo")
        model = AutoModelForCausalLM.from_pretrained(path, **load_kwargs)

    model.eval()
    _model_cache[name] = (model, tokenizer)
    print(f"  ✓ Listo — {sum(p.numel() for p in model.parameters())/1e9:.1f}B params en {next(model.parameters()).device}")
    return model, tokenizer


def unload_checkpoint(name: str):
    """Drop a checkpoint from the cache and release its VRAM.

    Does nothing when ``name`` is not cached.

    Args:
        name: Key previously passed to ``load_checkpoint``.
    """
    import gc

    if name in _model_cache:
        del _model_cache[name]
        # Removing the cache entry is not enough when the model is still held
        # by reference cycles; collect them so the tensors are actually freed
        # before asking the CUDA allocator to release its cached blocks.
        gc.collect()
        torch.cuda.empty_cache()
        print(f"✓ {name!r} descargado de VRAM")


def unload_all():
    """Unload every checkpoint currently held in the cache."""
    # Snapshot the keys first: unload_checkpoint mutates the cache.
    cached_names = list(_model_cache)
    for cached_name in cached_names:
        unload_checkpoint(cached_name)

In [None]:
# Fixed context string passed as `context` to the rubric-generation prompt
# built in _build_messages.
HB_CONTEXT = (
    "This is a medical conversation between a patient and an AI assistant. "
    "The rubric should evaluate medical accuracy, completeness, safety, "
    "communication quality, and instruction following."
)


def _build_messages(entry: dict, use_contrastive: bool = True) -> list:
    """Build the chat messages for rubric generation.

    Intended to mirror the prompt construction of the HealthBench adapter.
    When ``use_contrastive`` is set and the entry has answers and gold
    scores, the first 400 characters of the highest- and lowest-scored
    answers are passed as excerpts, but only if there is more than one
    answer and the two indices differ.
    """
    top_excerpt = None
    bottom_excerpt = None

    wants_contrast = use_contrastive and entry.get("answers") and entry.get("gold_scores")
    if wants_contrast:
        candidate_answers = entry["answers"]
        gold = entry["gold_scores"]
        if len(candidate_answers) > 1:
            hi = int(np.argmax(gold))
            lo = int(np.argmin(gold))
            if hi != lo:
                top_excerpt = candidate_answers[hi][:400]
                bottom_excerpt = candidate_answers[lo][:400]

    prompt_kwargs = {
        "question": entry["question"],
        "context": HB_CONTEXT,
        "best_answer_excerpt": top_excerpt,
        "worst_answer_excerpt": bottom_excerpt,
    }
    return DatasetAdapter.build_rubric_generation_prompt(**prompt_kwargs)


@torch.no_grad()
def generate_rubric(
    checkpoint_name: str,
    entry: dict,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    use_contrastive: bool = True,
) -> str:
    """Generate one rubric for a holdout entry with the given checkpoint.

    The checkpoint is loaded (or taken from the cache), the chat prompt is
    rendered with the tokenizer's chat template, and only the newly
    generated tokens are decoded. ``temperature <= 0`` switches to greedy
    decoding.
    """
    model, tokenizer = load_checkpoint(checkpoint_name)

    prompt_text = tokenizer.apply_chat_template(
        _build_messages(entry, use_contrastive=use_contrastive),
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = tokenizer(prompt_text, return_tensors="pt")
    encoded = encoded.to(next(model.parameters()).device)

    sampling = temperature > 0
    generation = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        temperature=temperature if sampling else None,
        do_sample=sampling,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Strip the prompt tokens so only the completion is decoded.
    prompt_len = encoded["input_ids"].shape[1]
    completion_ids = generation[0][prompt_len:]
    decoded = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return decoded.strip()


def quick_metrics(rubric: str) -> dict:
    """Compute the metrics that need no Judge API call.

    Returns a dict with ``format_validity`` (rounded to 3 decimals),
    ``points_sum`` (rounded to 2), ``n_items`` (number of lines that start
    with ``Points:`` after stripping) and ``chars`` (``rubric_length``).
    """
    metrics = {}
    metrics["format_validity"] = round(format_validity(rubric), 3)
    metrics["points_sum"] = round(points_sum(rubric), 2)
    metrics["n_items"] = len(
        [line for line in rubric.splitlines() if line.strip().startswith("Points:")]
    )
    metrics["chars"] = rubric_length(rubric)
    return metrics


def show_rubric_comparison(rubrics: dict):
    """Print each rubric with its quick metrics, then a summary table.

    ``rubrics`` maps a model name to its rubric text. The summary is a
    DataFrame indexed by model name, rendered with ``display``.
    """
    heavy_rule = "═" * 65
    light_rule = "─" * 65

    summary = []
    for label in rubrics:
        text = rubrics[label]
        stats = quick_metrics(text)
        summary.append({"model": label, **stats})

        print(f"\n{heavy_rule}")
        print(f"  {label.upper()}")
        print(
            f"  format={stats['format_validity']} | points_sum={stats['points_sum']}"
            f" | items={stats['n_items']} | chars={stats['chars']}"
        )
        print(light_rule)
        print(text)

    print(f"\n{heavy_rule}")
    print("RESUMEN MÉTRICAS RÁPIDAS:")
    table = pd.DataFrame(summary)
    display(table.set_index("model"))


print("Helpers cargados ✓")

## §5 — Cargar checkpoints

Cargar uno a la vez. El H100 tiene 94GB VRAM — Qwen3-8B ocupa ~16GB en bfloat16,
así que podés tener 2-3 modelos cargados simultáneamente si querés comparar rápido.
Usá `unload_checkpoint(name)` para liberar VRAM.

In [None]:
# ─── Load the checkpoints you want to compare ─────────────────────────────────
# Uncomment the ones you need. The first load takes longer (downloads the base model).

load_checkpoint("base_zeroshot")
# load_checkpoint("sft")
# load_checkpoint("grpo_step200")
# load_checkpoint("grpo_step400")

# Report what is cached and the current CUDA memory usage of device 0.
print(f"\nModelos en caché: {list(_model_cache.keys())}")
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 1e9
    reserved  = torch.cuda.memory_reserved() / 1e9
    total     = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM: {allocated:.1f}GB allocada | {reserved:.1f}GB reservada | {total:.1f}GB total")

## §6 — Browser del holdout

Explorar preguntas. Cambiá `IDX` para navegar. Las que no tienen `gold_scores` fueron excluidas.

In [None]:
IDX = 0  # ← change to browse (0 to len(holdout_with_scores)-1)

# `entry` is the selected question; the generation cell (§7) reads it.
entry = holdout_with_scores[IDX]

print(f"[{IDX}/{len(holdout_with_scores)-1}] ID: {entry.get('question_id', entry.get('prompt_id', '?'))}")
print(f"{'─'*70}")
print("PREGUNTA:")
print(entry["question"])
print(f"\n{'─'*70}")
print("RUBRICA GOLDEN (referencia humana):")
print(entry["golden_rubric"])
print(f"\n{'─'*70}")
print(f"Answers disponibles: {len(entry['answers'])}")
print(f"Gold scores        : {[round(s,3) for s in entry['gold_scores']]}")
# Quick (Judge-free) metrics of the golden rubric, for reference.
m = quick_metrics(entry["golden_rubric"])
print(f"Golden — format={m['format_validity']} | points_sum={m['points_sum']} | items={m['n_items']}")

## §7 — Generar rúbrica con un checkpoint

Genera la rúbrica para la pregunta seleccionada arriba.

In [None]:
CHECKPOINT = "base_zeroshot"   # ← change to the checkpoint you want to use
TEMPERATURE = 0.7               # 0.0 = greedy, 0.7 = sampling (as in training)
USE_CONTRASTIVE = True          # Include best/worst answer excerpts (same as training)

# Generate a rubric for `entry`, the question selected in §6.
print(f"Generando rúbrica con {CHECKPOINT!r} (temp={TEMPERATURE})...")
generated_rubric = generate_rubric(
    CHECKPOINT, entry,
    temperature=TEMPERATURE,
    use_contrastive=USE_CONTRASTIVE,
)

print(f"\n{'═'*65}")
print(f"  GENERADA ({CHECKPOINT})")
print(f"{'─'*65}")
print(generated_rubric)

# Judge-free metrics of the generated rubric.
m = quick_metrics(generated_rubric)
print(f"\nMétricas: format={m['format_validity']} | points_sum={m['points_sum']} | items={m['n_items']} | chars={m['chars']}")

## §8 — Comparar múltiples checkpoints

Genera una rúbrica para la misma pregunta con cada uno de los modelos cargados y las compara side-by-side.

In [None]:
# Models to compare (only the ones currently in the cache)
MODELS_TO_COMPARE = list(_model_cache.keys())  # every loaded model
# Or list them manually:
# MODELS_TO_COMPARE = ["base_zeroshot", "sft", "grpo_step200"]

QUESTION_IDX = IDX   # use the question from §6 (or change it)
entry_cmp = holdout_with_scores[QUESTION_IDX]

# The golden rubric goes first so it appears as the reference in the output.
rubrics = {"golden": entry_cmp["golden_rubric"]}
for name in MODELS_TO_COMPARE:
    print(f"Generando con {name}...")
    # Reuse both generation settings from §7 (TEMPERATURE and USE_CONTRASTIVE)
    # so this comparison uses the same prompt variant as the single run.
    rubrics[name] = generate_rubric(
        name, entry_cmp,
        temperature=TEMPERATURE,
        use_contrastive=USE_CONTRASTIVE,
    )

print(f"\nPREGUNTA: {entry_cmp['question'][:150]}...\n")
show_rubric_comparison(rubrics)

## §9 — Evaluación con Judge (opcional)

Llama al Judge API para evaluar las rúbricas contra las respuestas precomputadas.
Computa alignment (Spearman) vs gold_scores.

**Costo**: ~$0.009 por pregunta × número de modelos.
Con 4 modelos y 1 pregunta: ~$0.036.

In [None]:
from grubrics_science.judge.judge import Judge

# Judge model name; overridable through the JUDGE_MODEL environment variable.
JUDGE_MODEL = os.environ.get("JUDGE_MODEL", "gpt-4.1")
print(f"Judge model: {JUDGE_MODEL}")
print(f"Asegurate de tener AZURE_API_KEY y AZURE_API_BASE en el entorno (o en .env)")


async def eval_rubrics_with_judge(entry: dict, rubrics: dict) -> pd.DataFrame:
    """Score every rubric in ``rubrics`` on one question using the Judge.

    ``rubrics`` maps a name to rubric text. Each rubric is used to score the
    entry's answers; the result is a DataFrame indexed by name with
    alignment, discrimination, info_value, format, points_sum and the raw
    judge / gold score lists (all rounded).
    """
    def _rounded(values):
        return [round(v, 3) for v in values]

    judge = Judge(model=JUDGE_MODEL, max_concurrent=10, max_cache_size=0)
    records = []
    for label, rubric_text in rubrics.items():
        print(f"  Evaluando {label}...")
        scores = await judge.evaluate_answers_batched(
            question=entry["question"],
            answers=entry["answers"],
            rubric=rubric_text,
        )
        reference = entry["gold_scores"]

        record = {"model": label}
        record["alignment"] = round(alignment_score(scores, reference), 3)
        record["discrimination"] = round(discrimination_score(scores), 3)
        record["info_value"] = round(info_value(scores), 3)
        record["format"] = round(format_validity(rubric_text), 3)
        record["points_sum"] = round(points_sum(rubric_text), 2)
        record["judge_scores"] = _rounded(scores)
        record["gold_scores"] = _rounded(reference)
        records.append(record)

    frame = pd.DataFrame(records)
    return frame.set_index("model")


# Jupyter supports top-level await directly in cells (IPython 7+)
print("\nEvaluando con Judge...")
# Uses `entry_cmp` and `rubrics` produced by the comparison cell (§8).
df_judge = await eval_rubrics_with_judge(entry_cmp, rubrics)

print("\nResultados:")
display(df_judge)

In [None]:
# ─── Plot raw judge scores against the gold scores ────────────────────────────
# One subplot per rubric, sharing the y axis.
fig, axes = plt.subplots(1, len(rubrics), figsize=(4 * len(rubrics), 4), sharey=True)
# With a single subplot, plt.subplots returns a bare Axes; wrap it so the
# zip below works uniformly.
if len(rubrics) == 1:
    axes = [axes]

gold = entry_cmp["gold_scores"]
x = range(len(gold))

for ax, (name, _) in zip(axes, rubrics.items()):
    if name in df_judge.index:
        js = df_judge.loc[name, "judge_scores"]
        # Gold and judge bars are drawn at the same x positions, overlaid
        # with transparency.
        ax.bar(x, gold, alpha=0.4, label="gold", color="green")
        ax.bar(x, js,   alpha=0.6, label="judge", color="blue")
        corr = df_judge.loc[name, "alignment"]
        ax.set_title(f"{name}\nalignment={corr}")
        ax.set_xlabel("Answer")
        ax.legend(fontsize=8)
    else:
        # Rubric that was not evaluated by the Judge: title only.
        ax.set_title(f"{name}\n(sin Judge eval)")

axes[0].set_ylabel("Score")
plt.tight_layout()
plt.show()

## §10 — Mini-eval: N preguntas aleatorias

Genera + evalúa con Judge una muestra aleatoria del holdout para **un checkpoint**.
Útil para comparar métricas agregadas entre checkpoints sin correr las 500 preguntas completas.

**Costo estimado**: `N × $0.009` por checkpoint.

In [None]:
EVAL_CHECKPOINT = "base_zeroshot"   # ← checkpoint to evaluate
N_QUESTIONS     = 20                # ← number of questions (adjust to your budget)
EVAL_SEED       = 42

# Seeded sample without replacement, capped at the number of available entries.
rng     = random.Random(EVAL_SEED)
sample  = rng.sample(holdout_with_scores, min(N_QUESTIONS, len(holdout_with_scores)))

print(f"Mini-eval: {len(sample)} preguntas con {EVAL_CHECKPOINT!r}")
print(f"Costo estimado: ~${len(sample) * 0.009:.2f}")


async def run_mini_eval(checkpoint_name: str, entries: list) -> pd.DataFrame:
    """Generate and judge one rubric per entry for a single checkpoint.

    For each entry a rubric is generated with ``generate_rubric`` (at the
    notebook-level ``TEMPERATURE``), the entry's answers are scored with the
    Judge using that rubric, and the metrics are collected.

    Args:
        checkpoint_name: Checkpoint key understood by ``generate_rubric``.
        entries: Holdout entries with ``question``, ``answers`` and
            ``gold_scores``.

    Returns:
        One row per entry with question_id, alignment, discrimination,
        info_value, format, points_sum, n_items and the rubric text.
    """
    judge = Judge(model=JUDGE_MODEL, max_concurrent=20, max_cache_size=0)
    rows  = []
    for i, e in enumerate(entries):
        # Resolve the id once, with the same fallback chain for the progress
        # line and the stored row; str() keeps the slice safe for non-string ids.
        question_id = e.get("question_id", e.get("prompt_id", i))
        print(f"  [{i+1:3d}/{len(entries)}] {str(question_id)[:40]}", end="\r")
        rubric = generate_rubric(checkpoint_name, e, temperature=TEMPERATURE)
        scores = await judge.evaluate_answers_batched(
            question=e["question"],
            answers=e["answers"],
            rubric=rubric,
        )
        rows.append({
            "question_id"   : question_id,
            "alignment"     : alignment_score(scores, e["gold_scores"]),
            "discrimination": discrimination_score(scores),
            "info_value"    : info_value(scores),
            "format"        : format_validity(rubric),
            "points_sum"    : points_sum(rubric),
            "n_items"       : sum(1 for l in rubric.splitlines() if l.strip().startswith("Points:")),
            "rubric_text"   : rubric,   # keep the text for later inspection
        })
    # End the carriage-return progress line.
    print()
    return pd.DataFrame(rows)


# Top-level await: relies on IPython's support for await in cells.
df_mini = await run_mini_eval(EVAL_CHECKPOINT, sample)

print(f"\nMini-eval completado — {len(df_mini)} preguntas, checkpoint: {EVAL_CHECKPOINT!r}")
print("\nEstadísticas agregadas:")
display(df_mini[["alignment", "discrimination", "info_value", "format", "points_sum"]].describe().round(3))

In [None]:
# ─── Distributions of alignment, discrimination and format over the mini-eval ─
# Each panel is a histogram with a dashed red line at the mean.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

axes[0].hist(df_mini["alignment"], bins=20, edgecolor="white", color="steelblue")
axes[0].axvline(df_mini["alignment"].mean(), color="red", linestyle="--", label=f"mean={df_mini['alignment'].mean():.3f}")
axes[0].set_title(f"Alignment (Spearman)\n{EVAL_CHECKPOINT}")
axes[0].set_xlabel("Spearman correlation")
axes[0].legend()

axes[1].hist(df_mini["discrimination"], bins=20, edgecolor="white", color="darkorange")
axes[1].axvline(df_mini["discrimination"].mean(), color="red", linestyle="--",
               label=f"mean={df_mini['discrimination'].mean():.3f}")
axes[1].set_title("Discrimination (std de scores)")
axes[1].set_xlabel("Std")
axes[1].legend()

axes[2].hist(df_mini["format"], bins=15, edgecolor="white", color="green")
axes[2].axvline(df_mini["format"].mean(), color="red", linestyle="--",
               label=f"mean={df_mini['format'].mean():.3f}")
axes[2].set_title("Format validity")
axes[2].set_xlabel("Fracción de líneas válidas")
axes[2].legend()

plt.tight_layout()
plt.show()

In [None]:
# ─── Inspect the extreme cases ─────────────────────────────────────────────────
# The best and worst rubrics by alignment (first 400 characters of each).

print("=== Top 3 mejores (mayor alignment) ===")
for _, row in df_mini.nlargest(3, "alignment").iterrows():
    print(f"\nalignment={row['alignment']:.3f} | id={row['question_id']}")
    print(row["rubric_text"][:400] + "...")

print("\n=== Bottom 3 peores (menor alignment) ===")
for _, row in df_mini.nsmallest(3, "alignment").iterrows():
    print(f"\nalignment={row['alignment']:.3f} | id={row['question_id']}")
    print(row["rubric_text"][:400] + "...")

In [None]:
# ─── Save the mini-eval results so they can be compared later ─────────────────
RESULTS_DIR = ROOT / "data" / "results"
# parents=True: also create missing intermediate directories instead of
# failing with FileNotFoundError when the parent folder does not exist yet.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# File name encodes the checkpoint and the number of evaluated questions;
# the comparison cell (§11) parses it back.
out_path = RESULTS_DIR / f"mini_eval_{EVAL_CHECKPOINT}_n{len(df_mini)}.csv"
df_mini.to_csv(out_path, index=False)
print(f"Guardado en: {out_path}")

## §11 — Comparar métricas entre checkpoints

Si corriste mini-eval para varios checkpoints y guardaste los CSVs,
esta celda los carga y los compara en una tabla.

In [None]:
RESULTS_DIR = ROOT / "data" / "results"

# Load every available mini-eval CSV (files named mini_eval_<checkpoint>_n<N>.csv)
all_evals = {}
for csv_path in sorted(RESULTS_DIR.glob("mini_eval_*.csv")):
    # Strip only the leading prefix (the glob guarantees it is there);
    # str.replace would also remove any later occurrence inside the name.
    stem = csv_path.stem[len("mini_eval_"):]
    # Drop the trailing "_n<N>" suffix, if any, to recover the checkpoint name.
    name = stem.rsplit("_n", 1)[0]
    df = pd.read_csv(csv_path)
    # Two CSVs for the same checkpoint (different N) must not overwrite each
    # other: keep the first under the plain name and later ones under the
    # full stem, which still carries the "_n<N>" suffix.
    key = name if name not in all_evals else stem
    all_evals[key] = df
    print(f"  {key}: {len(df)} preguntas")

if not all_evals:
    print("No hay CSVs de mini-eval. Corré §10 primero.")
else:
    metrics = ["alignment", "discrimination", "info_value", "format", "points_sum"]
    # NOTE: each column mean is taken over all rows of that CSV; the
    # comparison is only like-for-like when the CSVs were produced with the
    # same sample (same EVAL_SEED and N_QUESTIONS).
    comparison = pd.DataFrame({
        name: df[metrics].mean()
        for name, df in all_evals.items()
    }).T.round(3)
    print("\nComparación de checkpoints (media sobre preguntas compartidas):")
    display(comparison)

---

## Referencia rápida

| Métrica | Rango | Qué mide |
|---------|-------|----------|
| `alignment` | [-1, 1] | Spearman vs gold_scores. **Principal métrica del paper.** |
| `discrimination` | [0, 1] | Std de los scores del Judge. 0 = rúbrica degenerada. |
| `info_value` | [0, 1] | `4p(1-p)` donde p = frac. que pasa. Máx en p=0.5. |
| `format_validity` | [0, 1] | Fracción de líneas con formato `Points: X, Item: Y`. |
| `points_sum` | 0-∞ | Suma de todos los Points. Target: **10.0**. |

**Comandos útiles:**
```python
unload_checkpoint("base_zeroshot")  # liberar VRAM
unload_all()                         # liberar todo
list(_model_cache.keys())            # modelos en caché
```

## §12 — Rúbricas guardadas durante el training

Durante el GRPO, la reward function guarda automáticamente el texto de cada rúbrica generada
en `data/results/rubrics/step_XXXX.jsonl` (una línea JSON por rollout).

Cada entrada tiene: `step`, `question_id`, `question`, `rubric`, `alignment`, `reward`,
`judge_scores`, `gold_scores`, `n_chars`.

Esto permite ver cómo evolucionan las rúbricas **sin recargar checkpoints** — son las que el
modelo generó durante el training real (no regeneraciones a posteriori).

In [None]:
RUBRICS_DIR = ROOT / "data" / "results" / "rubrics"

# ─── Load every available step_XXXX.jsonl ─────────────────────────────────────
step_data = {}  # step_num → list of entries

if not RUBRICS_DIR.exists():
    print(f"No hay rúbricas de training todavía ({RUBRICS_DIR})")
    print("Asegurate de que SAVE_RUBRICS no esté en 0 y de haber hecho al menos 1 step.")
else:
    # Sort numerically by step so e.g. step_1000 comes after step_200.
    for jl in sorted(RUBRICS_DIR.glob("step_*.jsonl"),
                     key=lambda p: int(p.stem.split("_")[1])):
        step_num = int(jl.stem.split("_")[1])
        with open(jl, encoding="utf-8") as f:
            # One JSON object per non-blank line.
            entries = [json.loads(line) for line in f if line.strip()]
        if not entries:
            # An empty file would make np.mean([]) warn and return NaN here
            # and in the plots below; skip it instead of registering it.
            print(f"  step {step_num:4d} —   0 rúbricas (archivo vacío, se omite)")
            continue
        step_data[step_num] = entries
        print(f"  step {step_num:4d} — {len(entries):3d} rúbricas | "
              f"alignment mean={np.mean([e['alignment'] for e in entries]):.3f} "
              f"| reward mean={np.mean([e['reward'] for e in entries]):.3f}")

    print(f"\nTotal steps: {len(step_data)} | "
          f"Total rúbricas: {sum(len(v) for v in step_data.values())}")

In [None]:
# ─── Evolution of alignment, reward and rubric length during training ─────────
# Per-step means over all rubrics saved for that step.
if step_data:
    steps       = sorted(step_data.keys())
    mean_align  = [np.mean([e["alignment"] for e in step_data[s]]) for s in steps]
    mean_reward = [np.mean([e["reward"]    for e in step_data[s]]) for s in steps]
    mean_chars  = [np.mean([e["n_chars"]   for e in step_data[s]]) for s in steps]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].plot(steps, mean_align,  "o-", color="steelblue")
    axes[0].set_title("Alignment (Spearman) durante training")
    axes[0].set_xlabel("Step")
    axes[0].set_ylabel("Spearman correlation (media del batch)")
    axes[0].grid(alpha=0.3)

    axes[1].plot(steps, mean_reward, "o-", color="darkorange")
    axes[1].set_title("Reward total durante training")
    axes[1].set_xlabel("Step")
    axes[1].set_ylabel("Reward (media del batch)")
    axes[1].grid(alpha=0.3)

    axes[2].plot(steps, mean_chars,  "o-", color="green")
    axes[2].set_title("Largo de rúbricas durante training")
    axes[2].set_xlabel("Step")
    axes[2].set_ylabel("Chars (media del batch)")
    axes[2].grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# ─── Inspect the rubrics of one specific step ─────────────────────────────────
# Default to the earliest loaded step. Computed from step_data directly so the
# cell does not depend on `steps` from the plotting cell having been run.
INSPECT_STEP = min(step_data) if step_data else None  # ← change to the step you want to see
TOP_N        = 3   # show the N rubrics with highest and lowest alignment

# `is not None` rather than truthiness: step 0 is a valid step number.
if INSPECT_STEP is not None and INSPECT_STEP in step_data:
    entries_step = step_data[INSPECT_STEP]
    print(f"=== Step {INSPECT_STEP} — {len(entries_step)} rúbricas ===")
    print(f"alignment: mean={np.mean([e['alignment'] for e in entries_step]):.3f} "
          f"std={np.std([e['alignment'] for e in entries_step]):.3f}")
    print(f"reward   : mean={np.mean([e['reward'] for e in entries_step]):.3f}\n")

    # Highest alignment first.
    sorted_entries = sorted(entries_step, key=lambda e: e["alignment"], reverse=True)

    print(f"--- Top {TOP_N} (mayor alignment) ---")
    for e in sorted_entries[:TOP_N]:
        print(f"\nalignment={e['alignment']:.3f} | reward={e['reward']:.3f} | chars={e['n_chars']}")
        print(f"Q: {e['question'][:120]}...")
        print(f"Rubric:\n{e['rubric']}")
        print()

    print(f"\n--- Bottom {TOP_N} (menor alignment) ---")
    for e in sorted_entries[-TOP_N:]:
        print(f"\nalignment={e['alignment']:.3f} | reward={e['reward']:.3f} | chars={e['n_chars']}")
        print(f"Q: {e['question'][:120]}...")
        print(f"Rubric:\n{e['rubric']}")
        print()