# Gemma Probe Evaluation


**W&B Runs:**
- Gemma 27B: https://wandb.ai/seperability/outlines_probes/runs/9zr7uulg
- Gemma 12B: https://wandb.ai/seperability/outlines_probes/runs/940bo0gy
- Gemma 4B: https://wandb.ai/seperability/outlines_probes/runs/xn3z3h4f

**Evaluation**: Each model compared against its own regeneration (original outline vs probe-decoded outline)


In [1]:
import sys
from pathlib import Path
import os

# Set working directory to the outlines folder where all scripts are
OUTLINES_DIR = Path("/workspace/ALGOVERSE/yas/yulia/parascopes/src/yulia/outlines")
os.chdir(OUTLINES_DIR)
print(f"Working directory: {os.getcwd()}")

# IMPORTANT: Add outlines dir FIRST so it takes precedence over src/
# (there are duplicate utils_load_data.py files (not duplicates but Nicky's files)) 
SRC_DIR = OUTLINES_DIR.parent.parent  

# Remove src/ from path if present (it might have the old utils_load_data.py)
sys.path = [p for p in sys.path if str(SRC_DIR) not in p or "outlines" in p]

# Add outlines dir at the very front
if str(OUTLINES_DIR) in sys.path:
    sys.path.remove(str(OUTLINES_DIR))
sys.path.insert(0, str(OUTLINES_DIR))

# Add src/ AFTER outlines for yulia.outlines.* imports
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))  # append, not insert!

print(f"Path order (first 3):")
for p in sys.path[:3]:
    print(f"  {p}")

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import einops
import wandb
import os
import re
import shutil
import json
from typing import List, Dict
from openai import OpenAI

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
print("Using device:", device)


Working directory: /workspace/ALGOVERSE/yas/yulia/parascopes/src/yulia/outlines
Path order (first 3):
  /workspace/ALGOVERSE/yas/yulia/parascopes/src/yulia/outlines
  /usr/lib/python310.zip
  /usr/lib/python3.10
Using device: cuda


In [2]:

import os
DEEPINFRA_KEY = ""  


In [3]:
# ==== GEMMA PROBE CONFIGURATIONS ====
WANDB_ENTITY = "seperability"
WANDB_PROJECT = "outlines_probes"

# NOTE: n_layers and d_model are INFERRED from the checkpoint
GEMMA_CONFIGS = {
    "gemma27b": {
        "run_id": "9zr7uulg",
        "local_residuals_dir": "/mnt/hdd-8tb/hdd_cache/tensors/gemma-27b",
        "local_embeds_dir": "/workspace/ALGOVERSE/yas/yulia/parascopes/src/yulia/outlines/results/gemma27b-outlines-embeddings",
        "hf_texts_dataset": "annnettte/fineweb-gemma27b-texts",  # Original texts for regen
        "hf_outlines_repo": "yulia-volkova/parascopes-outlines-gemma27b",  # Original outlines
        # n_layers and d_model will be inferred from checkpoint
    },
    "gemma12b": {
        "run_id": "940bo0gy",
        "local_residuals_dir": "/mnt/hdd-8tb/hdd_cache/tensors/gemma-12b",
        "local_embeds_dir": "/workspace/ALGOVERSE/yas/yulia/parascopes/src/yulia/outlines/results/gemma12b-outlines-embeddings",
        "hf_texts_dataset": "annnettte/fineweb-gemma12b-texts",  # Original texts for regen
        "hf_outlines_repo": "yulia-volkova/parascopes-outlines-gemma12b",  # Original outlines
        # n_layers and d_model will be inferred from checkpoint
    },
    "gemma4b": {
        "run_id": "xn3z3h4f",
        "local_residuals_dir": "/mnt/hdd-8tb/hdd_cache/tensors/gemma-4b",
        "local_embeds_dir": "/workspace/ALGOVERSE/yas/yulia/parascopes/src/yulia/outlines/results/gemma4b-outlines-embeddings",
        "hf_texts_dataset": "annnettte/fineweb-gemma4b-texts",  # Original texts for regen (UPDATE IF DIFFERENT)
        "hf_outlines_repo": "yulia-volkova/parascopes-outlines-gemma4b",  # Original outlines (UPDATE IF DIFFERENT)
        # n_layers and d_model will be inferred from checkpoint
    },
}

# Output directory for artifacts and results
OUT_DIR = Path("./eval_gemma_probes")
OUT_DIR.mkdir(exist_ok=True, parents=True)

# Chunk 99 = validation chunk
EVAL_CHUNKS = [99] 
print(f"Will evaluate chunk: {EVAL_CHUNKS}")


Will evaluate chunk: [99]


In [21]:
from utils_normalizers import Normalizer

class LinearProbe(torch.nn.Module):
    """Simple linear probe: flattens [batch, n_layers, d_model] and maps to d_sonar."""
    def __init__(self, n_layers: int, d_model: int, d_sonar: int = 1024):
        super().__init__()
        self.n_layers = n_layers
        self.d_model = d_model
        self.d_sonar = d_sonar
        self.linear = torch.nn.Linear(n_layers * d_model, d_sonar)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.reshape(x.shape[0], -1)
        return self.linear(x)

D_SONAR = 1024


In [22]:

def download_probe_artifacts(model_name: str, config: dict) -> dict:
    run_id = config["run_id"]
    run_path = f"{WANDB_ENTITY}/{WANDB_PROJECT}/{run_id}"
    local_dir = OUT_DIR / model_name / run_id
    local_dir.mkdir(parents=True, exist_ok=True)
    
    api = wandb.Api()
    run = api.run(run_path)
    print(f"\n{'='*60}")
    print(f"Model: {model_name}")
    print(f"Run: {run.name} | ID: {run.id}")
    
    # List all artifacts for debugging
    all_artifacts = list(run.logged_artifacts())
    print(f"Available artifacts ({len(all_artifacts)}):")
    for art in all_artifacts:
        print(f"  - {art.name} (type={art.type})")
    
    # Find best epoch from run history
    hist = run.history(pandas=True)
    val_col_candidates = ["epoch/val_loss", "val/loss", "val_loss", "val/mse", "epoch/val_mse"]
    val_col = next((c for c in val_col_candidates if c in hist.columns and hist[c].notna().any()), None)
    
    if val_col:
        h2 = hist[["epoch", val_col]].dropna()
        best_epoch = int(h2.loc[h2[val_col].idxmin(), "epoch"])
        best_val = float(h2.loc[h2[val_col].idxmin(), val_col])
        print(f"Best epoch: {best_epoch} ({val_col}={best_val:.6f})")
    else:
        best_epoch = int(hist["epoch"].max()) if "epoch" in hist.columns else 1
        print(f"Using epoch: {best_epoch} (no val loss found)")
    
    # Download checkpoint artifact - try multiple naming patterns
    ckpt_path = None
    checkpoint_patterns = [
        f"checkpoint-epoch-{best_epoch}:",   
        f"checkpoint_epoch_{best_epoch}:",   
        f"checkpoint-epoch-{best_epoch+1}:", 
        f"checkpoint_epoch_{best_epoch+1}:", 
    ]
    
    for art in all_artifacts:
        if art.type == "model":
            # Check if artifact name matches any pattern
            for pattern in checkpoint_patterns:
                if art.name.startswith(pattern.rstrip(":")):
                    base_name = art.name.split(":")[0]
                    local_ckpt = local_dir / "checkpoints" / f"{base_name}.pkl"
                    local_ckpt.parent.mkdir(parents=True, exist_ok=True)
                    
                    if local_ckpt.exists():
                        print(f"Using cached checkpoint: {local_ckpt}")
                        ckpt_path = str(local_ckpt)
                    else:
                        print(f"Downloading checkpoint: {art.name}")
                        ckpt_dir = art.download(root=str(local_ckpt.parent))
                        # Find the actual downloaded file
                        for root, _, files in os.walk(ckpt_dir):
                            for f in files:
                                src = os.path.join(root, f)
                                shutil.copy2(src, local_ckpt)
                                print(f"  Copied to: {local_ckpt}")
                                ckpt_path = str(local_ckpt)
                                break
                            if ckpt_path:
                                break
                    break
            if ckpt_path:
                break
    
    if ckpt_path is None:
        print(f"WARNING: No checkpoint found for epoch {best_epoch}!")
    
    # Download normalizers
    norm_dir = local_dir / "normalizers"
    norm_dir.mkdir(parents=True, exist_ok=True)
    
    res_norm_path = None
    emb_norm_path = None
    
    for art in all_artifacts:
        if "res_normalizer" in art.name:
            local_path = norm_dir / "res_normalizer.pt"
            if not local_path.exists():
                print(f"Downloading: {art.name}")
                art.download(root=str(norm_dir))
            res_norm_path = str(local_path)
        elif "embed_normalizer" in art.name:
            local_path = norm_dir / "embed_normalizer.pt"
            if not local_path.exists():
                print(f"Downloading: {art.name}")
                art.download(root=str(norm_dir))
            emb_norm_path = str(local_path)
    
    return {
        "checkpoint_path": ckpt_path,
        "res_normalizer_path": res_norm_path,
        "embed_normalizer_path": emb_norm_path,
        "best_epoch": best_epoch,
    }


In [23]:
# Download artifacts for all models
model_artifacts = {}
for model_name, config in GEMMA_CONFIGS.items():
    try:
        artifacts = download_probe_artifacts(model_name, config)
        model_artifacts[model_name] = artifacts
        print(f"✓ Downloaded artifacts for {model_name}")
    except Exception as e:
        print(f"✗ Failed to download artifacts for {model_name}: {e}")
        model_artifacts[model_name] = None



Model: gemma27b
Run: gemma27b_linear_probe_20251130_151545 | ID: 9zr7uulg
Available artifacts (10):
  - checkpoint_epoch_1:v6 (type=model)
  - checkpoint_epoch_2:v6 (type=model)
  - checkpoint_epoch_3:v6 (type=model)
  - checkpoint_epoch_4:v6 (type=model)
  - checkpoint_epoch_5:v6 (type=model)
  - checkpoint_epoch_6:v6 (type=model)
  - checkpoint_epoch_7:v6 (type=model)
  - checkpoint_epoch_8:v6 (type=model)
  - checkpoint_epoch_9:v6 (type=model)
  - run-9zr7uulg-history:v0 (type=wandb-history)
Best epoch: 9 (epoch/val_loss=0.946803)
Using cached checkpoint: eval_gemma_probes/gemma27b/9zr7uulg/checkpoints/checkpoint_epoch_9.pkl
✓ Downloaded artifacts for gemma27b

Model: gemma12b
Run: gemma12b_linear_probe_20251130_045927 | ID: 940bo0gy
Available artifacts (12):
  - checkpoint_epoch_1:v5 (type=model)
  - checkpoint_epoch_2:v5 (type=model)
  - checkpoint_epoch_3:v5 (type=model)
  - checkpoint_epoch_4:v5 (type=model)
  - checkpoint_epoch_5:v5 (type=model)
  - checkpoint_epoch_6:v5 (type

In [None]:
import utils_sonar as sonar_utils

# Initialize SONAR models for decoding embeddings to text
text2vec, vec2text = sonar_utils.init_sonar()

def decode_embeddings(tensors, target_lang="eng_Latn"):
    if isinstance(tensors, torch.Tensor):
        if tensors.ndim == 1:
            batch = [tensors]
        elif tensors.ndim == 2:
            batch = [tensors[i] for i in range(tensors.size(0))]
        else:
            raise ValueError("Expected 1D or 2D tensor")
    else:
        batch = list(tensors)
    
    # Get decoder device/dtype
    for attr in ["model", "_model", "decoder"]:
        m = getattr(vec2text, attr, None)
        if m is not None:
            try:
                p = next(m.parameters())
                dec_device, dec_dtype = p.device, p.dtype
                break
            except:
                pass
    else:
        dec_device = device
        dec_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    
    batch = [t.detach().to(device=dec_device, dtype=dec_dtype) for t in batch]
    return vec2text.predict(batch, target_lang=target_lang)


[SONAR] init on device=cuda, dtype=torch.bfloat16


In [None]:
# ==== EVALUATE AND DECODE FUNCTION ====
import pickle


def residual_pre_diffs(res_all):
    """raw residual -> layer diffs (same as utils_train.py)."""
    states = res_all[:, 0, :].to(dtype=torch.float32)
    return states[1:, :] - states[:-1, :]


def get_probe_dims_from_checkpoint(state: dict) -> tuple:
    if "config" not in state:
        raise ValueError("Checkpoint missing 'config' key")
    config = state["config"]
    n_layers = config["n_layers"]
    d_model = config["d_model"]
    print(f"  From config: n_layers={n_layers}, d_model={d_model}")
    return n_layers, d_model


def load_probe_and_normalizers(model_name: str, config: dict, artifacts: dict):
    ckpt_path = artifacts["checkpoint_path"]
    
    try:
        with open(ckpt_path, "rb") as f:
            state = pickle.load(f)
        print(f"  Loaded checkpoint as pickle")
    except:
        state = torch.load(ckpt_path, map_location="cpu")
        print(f"  Loaded checkpoint as torch")
    
    n_layers, d_model = get_probe_dims_from_checkpoint(state)
    model_state = state.get("model", state.get("model_state_dict", {}))
    
    probe = LinearProbe(n_layers=n_layers, d_model=d_model, d_sonar=D_SONAR)
    probe.load_state_dict(model_state, strict=True)
    probe.eval()
    
    # Try GPU first, fall back to CPU if OOM
    probe_device = device
    try:
        probe = probe.to(device)
        print(f"  Probe on GPU")
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"  GPU OOM, using CPU")
            torch.cuda.empty_cache()
            probe = probe.cpu()
            probe_device = torch.device("cpu")
        else:
            raise
    
    res_norm = Normalizer(mean=state["res_norm_mean"], std=state["res_norm_std"])
    emb_norm = Normalizer(mean=state["emb_norm_mean"], std=state["emb_norm_std"])
    
    print(f"Loaded probe for {model_name}: {n_layers}L, d={d_model}, dev={probe_device}")
    return probe, res_norm, emb_norm, n_layers, d_model, probe_device


def eval_and_decode_model(
    model_name: str,
    config: dict,
    probe,
    res_norm,
    emb_norm,
    probe_device,
    chunk_ids: list,
    per_chunk: int = 1000,
    batch_size: int = 32,
) -> pd.DataFrame:
    residuals_path = Path(config["local_residuals_dir"])
    embeds_path = Path(config["local_embeds_dir"])
    
    rows = []
    all_mse, all_cos = [], []
    
    for chunk_id in chunk_ids:
        res_file = residuals_path / f"res_data_{chunk_id:03d}.pt"
        emb_file = embeds_path / f"outlines_{chunk_id:03d}.pt"
        
        if not res_file.exists() or not emb_file.exists():
            print(f"Skipping chunk {chunk_id}: files not found")
            continue
        
        res_list = torch.load(res_file, map_location="cpu")
        embeds = torch.load(emb_file, map_location="cpu").to(dtype=DTYPE)
        
        n_samples = min(len(res_list), len(embeds))
        print(f"\n=== {model_name} Chunk {chunk_id}: {n_samples} samples ===")
        
        # Evaluate metrics on all samples
        for i in range(0, n_samples, batch_size):
            res_batch, emb_batch = [], []
            end = min(i + batch_size, n_samples)
            
            for j in range(i, end):
                r = res_list[j]["res"].to(dtype=DTYPE)
                # Apply layer diffs (same as training)
                x = residual_pre_diffs(r).unsqueeze(0)
                x = res_norm.normalize(x)
                y = emb_norm.normalize(embeds[j]).unsqueeze(0)
                res_batch.append(x)
                emb_batch.append(y)
            
            X = torch.cat(res_batch, 0).to(probe_device)
            Y = torch.cat(emb_batch, 0).to(probe_device)
            
            with torch.no_grad():
                P = probe(X)
                all_mse.extend(((P - Y) ** 2).mean(dim=1).cpu().tolist())
                all_cos.extend(F.cosine_similarity(P, Y, dim=1).cpu().tolist())
        
        if per_chunk <= 0:
            continue
        
        step = max(n_samples // per_chunk, 1)
        decode_idxs = list(range(0, n_samples, step))[:per_chunk]
        
        # Print indices being evaluated
        print(f"  Evaluating {len(decode_idxs)} samples")
        print(f"    Local indices: {decode_idxs[:10]}{'...' if len(decode_idxs) > 10 else ''}")
        
        for idx in decode_idxs:
            r = res_list[idx]["res"].to(dtype=DTYPE)
            x = residual_pre_diffs(r).unsqueeze(0)
            x_n = res_norm.normalize(x)
            y_n = emb_norm.normalize(embeds[idx]).unsqueeze(0)
            
            with torch.no_grad():
                y_hat_n = probe(x_n.to(probe_device)).squeeze(0)
            
            mse = F.mse_loss(y_hat_n, y_n.squeeze(0).to(probe_device)).item()
            cos = F.cosine_similarity(y_hat_n.unsqueeze(0), y_n.to(probe_device), dim=1).item()
            
            # Restore from normalized
            # Training uses normalized embeddings: 
            # embeddings are normalized (mean=0, std=1) 
            # for training stability. The probe outputs embeddings 
            # in this normalized space.
            # The SONAR decoder expects original-scale embeddings: 
            # vec2text.predict() (used by decode_embeddings) 
            # was trained on embeddings in their original distribution, 
            # not normalized.
            y_hat_rest = emb_norm.restore(y_hat_n.cpu())
            y_gt_rest = emb_norm.restore(y_n.squeeze(0))
            
            decoded = decode_embeddings([y_gt_rest, y_hat_rest])
            
            rows.append({
                "model": model_name,
                "chunk": chunk_id,
                "index": idx,
                "mse": mse,
                "cosine": cos,
                "outline_generated": decoded[0],
                "decoded_predicted": decoded[1],
            })
    
    # Print summary metrics
    if all_mse:
        mse_t = torch.tensor(all_mse)
        cos_t = torch.tensor(all_cos)
        print(f"\n{model_name} Overall Metrics:")
        print(f"  MSE: mean={mse_t.mean():.4f}, median={mse_t.median():.4f}")
        print(f"  COS: mean={cos_t.mean():.4f}, median={cos_t.median():.4f}")
    
    return pd.DataFrame(rows)


In [None]:
# ==== GENERATE REGENERATION BASELINE ====
# Regenerate outlines from original FineWeb texts (all gemma models, Anna's data) using Llama 70B
# This represents the true "ceiling" - same model regenerating outline from same text

from datasets import load_dataset
from io_utils import extract_outline_for_model
from utils_parallel import process_in_parallel

REGEN_MODEL = "meta-llama/Llama-3.3-70B-Instruct"  # Same model used for original outlines

def regenerate_outline(item):
    idx, completion_text = item
    outline, _ = extract_outline_for_model(REGEN_MODEL, completion_text)
    return {"index": idx, "regenerated_outline": outline}


def generate_regen_baseline(
    model_name: str,
    config: dict,
    eval_indices: list,  # List of (chunk_id, local_idx, global_idx) tuples
    max_workers: int = 20,
) -> pd.DataFrame:
   
    hf_texts_dataset = config.get("hf_texts_dataset")
    if not hf_texts_dataset:
        raise ValueError(f"No hf_texts_dataset configured for {model_name}")
    
    print(f"\nLoading texts from {hf_texts_dataset}...")
    
    # Load the texts dataset (streaming to avoid loading everything)
    texts_ds = load_dataset(hf_texts_dataset, split="train", streaming=True)
    
    # Global index = chunk_id * 1000 + local_idx (assuming 1000 samples per chunk)
    global_indices = set(idx for _, _, idx in eval_indices)
    
    print(f"Fetching {len(global_indices)} texts from dataset...")
    
    # Collect texts for the indices we need
    texts_by_idx = {}
    for i, sample in enumerate(texts_ds):
        if i in global_indices:
            texts_by_idx[i] = sample["completion"]
        if len(texts_by_idx) >= len(global_indices):
            break
        if i > max(global_indices) + 100:  # Safety stop
            break
    
    print(f"Loaded {len(texts_by_idx)} texts")
    
    # Prepare items for parallel regeneration
    items = [(idx, texts_by_idx[idx]) for _, _, idx in eval_indices if idx in texts_by_idx]
    
    print(f"Regenerating {len(items)} outlines using {REGEN_MODEL}...")
    
    # Regenerate outlines in parallel
    results = process_in_parallel(items, regenerate_outline, max_workers=max_workers)
    
    rows = []
    for (chunk_id, local_idx, global_idx), result in zip(eval_indices, results):
        if result and global_idx in texts_by_idx:
            rows.append({
                "model": model_name,
                "chunk": chunk_id,
                "index": local_idx,
                "global_index": global_idx,
                "type": "regen_baseline",
                "completion_text": texts_by_idx[global_idx][:500] + "...",  # Truncated for storage, this is the original full text document from FineWeb
                "regenerated_outline": result["regenerated_outline"],
            })
    
    return pd.DataFrame(rows)


# First, collect all evaluation indices from the probe results
# We need (chunk_id, local_idx, global_idx) for each sample
def collect_eval_indices(all_results: dict, chunk_size: int = 1000) -> dict:
    """Collect evaluation indices from probe results for each model"""
    model_indices = {}
    for model_name, df in all_results.items():
        indices = []
        for _, row in df.iterrows():
            chunk_id = row["chunk"]
            local_idx = row["index"]
            global_idx = chunk_id * chunk_size + local_idx
            indices.append((chunk_id, local_idx, global_idx))
        model_indices[model_name] = indices
    return model_indices



[Outline Generation Config]
  Source Dataset:  annnettte/fineweb-gemma12b-texts
  Output HF Repo:  yulia-volkova/parascopes-outlines-gemma12b
  Local CSV Path:  /workspace/ALGOVERSE/yas/yulia/parascopes/src/yulia/outlines/results/outlines_0.0.csv
  Version:         0.0
  Outline Model:   meta-llama/Llama-3.3-70B-Instruct



In [None]:
# ==== RUN PROBE EVALUATION FOR ALL GEMMA MODELS ====

all_results = {}
model_dims = {}  # Store inferred dimensions for reference

for model_name, config in GEMMA_CONFIGS.items():
    if model_artifacts.get(model_name) is None:
        print(f"\nSkipping {model_name}: no artifacts")
        continue
    
    artifacts = model_artifacts[model_name]
    
    try:
        # Load probe and normalizers (dimensions are inferred from checkpoint)
        probe, res_norm, emb_norm, n_layers, d_model = load_probe_and_normalizers(
            model_name, config, artifacts
        )
        model_dims[model_name] = {"n_layers": n_layers, "d_model": d_model}
        
        # Evaluate and decode
        df = eval_and_decode_model(
            model_name=model_name,
            config=config,
            probe=probe,
            res_norm=res_norm,
            emb_norm=emb_norm,
            chunk_ids=EVAL_CHUNKS,
            per_chunk=25,  # 25 samples per chunk
            batch_size=32,
        )
        
        # Save intermediate results
        csv_path = OUT_DIR / f"decoded_outputs_{model_name}.csv"
        df.to_csv(csv_path, index=False)
        print(f"\nSaved {len(df)} samples to {csv_path}")
        
        all_results[model_name] = df
        
    except Exception as e:
        print(f"\nError evaluating {model_name}: {e}")
        import traceback
        traceback.print_exc()

# Print summary of inferred dimensions
print("\n" + "="*60)
print("INFERRED MODEL DIMENSIONS (from checkpoints)")
print("="*60)
for model_name, dims in model_dims.items():
    print(f"  {model_name}: n_layers={dims['n_layers']}, d_model={dims['d_model']}")


  Loaded checkpoint as pickle
  From config: n_layers=62, d_model=5376
  Probe on GPU
Loaded probe for gemma27b: 62L, d=5376, dev=cuda

Error evaluating gemma27b: too many values to unpack (expected 5)


Traceback (most recent call last):
  File "/tmp/ipykernel_2995670/627019282.py", line 15, in <module>
    probe, res_norm, emb_norm, n_layers, d_model = load_probe_and_normalizers(
ValueError: too many values to unpack (expected 5)


  Loaded checkpoint as pickle
  From config: n_layers=48, d_model=3840
  Probe on GPU
Loaded probe for gemma12b: 48L, d=3840, dev=cuda

Error evaluating gemma12b: too many values to unpack (expected 5)


Traceback (most recent call last):
  File "/tmp/ipykernel_2995670/627019282.py", line 15, in <module>
    probe, res_norm, emb_norm, n_layers, d_model = load_probe_and_normalizers(
ValueError: too many values to unpack (expected 5)


  Loaded checkpoint as pickle
  From config: n_layers=34, d_model=2560
  Probe on GPU
Loaded probe for gemma4b: 34L, d=2560, dev=cuda

Error evaluating gemma4b: too many values to unpack (expected 5)

INFERRED MODEL DIMENSIONS (from checkpoints)


Traceback (most recent call last):
  File "/tmp/ipykernel_2995670/627019282.py", line 15, in <module>
    probe, res_norm, emb_norm, n_layers, d_model = load_probe_and_normalizers(
ValueError: too many values to unpack (expected 5)


In [None]:
# ==== RUN REGEN BASELINE GENERATION ====
# This regenerates outlines from original texts for the same indices we evaluated

# Collect indices from probe results
model_eval_indices = collect_eval_indices(all_results)

regen_baselines = {}

for model_name, config in GEMMA_CONFIGS.items():
    if model_name not in model_eval_indices:
        print(f"\nSkipping {model_name}: no probe results")
        continue
    
    eval_indices = model_eval_indices[model_name]
    print(f"\n{'='*60}")
    print(f"Generating regen baseline for {model_name}")
    print(f"Indices to regenerate: {len(eval_indices)}")
    print(f"{'='*60}")
    
    try:
        df_regen = generate_regen_baseline(
            model_name=model_name,
            config=config,
            eval_indices=eval_indices,
            max_workers=20,
        )
        
        
        csv_path = OUT_DIR / f"regen_baseline_{model_name}.csv"
        df_regen.to_csv(csv_path, index=False)
        print(f"\nSaved {len(df_regen)} regen baseline samples to {csv_path}")
        
        regen_baselines[model_name] = df_regen
        
    except Exception as e:
        print(f"\nError generating regen baseline for {model_name}: {e}")
        import traceback
        traceback.print_exc()



Skipping gemma27b: no probe results

Skipping gemma12b: no probe results

Skipping gemma4b: no probe results


In [None]:
# ==== RUBRIC FOR EVALUATION ====

rubric = """
#### 0. Complexity
How complex is the outline text?
0: Trivial (e.g: just says "** Section **")
1. Simple (e.g: "** Section 1: Green Tea **")
2. Some detail (e.g: a short undetailed sentence or two about something)
3. Many details (e.g: a detailed paragraph with specific information)

#### 1. Coherence (Outline-Level)
Does Outline 2 make sense as an outline compared to Outline 1?
0: Completely incoherent (e.g., excessive repetition, nonsensical phrases, strange symbols).
1: Partially coherent, but repetitive or has formatting issues.
2: Mostly coherent with minor grouping/order issues.
3: Clear, logical, coherent outline structure.

#### 2. Hierarchy / Structure
How well does Outline 2 preserve the hierarchical levels (headings vs sub-bullets)?
0: No recognizable hierarchy; flat or malformed.
1: Basic levels exist but often mis-nested.
2: Mostly correct hierarchy with some mismatches.
3: Hierarchy closely matches with minimal deviations.

#### 3. Coverage of Key Sections
Do the major sections in Outline 1 appear in Outline 2?
0: Most key sections missing or unrelated.
1: About half of major sections appear.
2: Most sections present; minor omissions.
3: All major sections present (allow synonyms/regrouping).

#### 4. Ordering / Flow
Does the order of major sections and sub-sections follow Outline 1?
0: Largely shuffled or illogical.
1: Partial overlap but frequent swaps.
2: Mostly consistent with minor swaps.
3: Order closely matches.

#### 5. Subject Match
How similar is the subject of Outline 2 to Outline 1?
-1: No subjects to compare.
0: Completely unrelated subjects
1: Vaguely similar field
2: Related general domain or adjacent fields
3: Same subject
4: Identical focus

#### 6. Entities / Key Concepts
How well does Outline 2 preserve entities or technical terms from Outline 1?
-1: No entities to compare.
0: Unrelated entities.
1: Same category but little overlap.
2: Some overlap or synonyms.
3: Most entities/terms preserved.
4: Nearly all preserved.

#### 7. Details
How similar are the details in Outline 2 to Outline 1?
-1: Neither outline has details to compare.
0: Details differ completely.
1: Minimal depth.
2: Moderate depth.
3: Highly specific details.

#### 8. Conciseness of Headings
Are headings concise and outline-appropriate?
0: Often verbose, unclear, or sentence-like.
1: Mixed clarity.
2: Mostly concise, descriptive headings.

#### 9. Identical
Is Outline 2 essentially identical to Outline 1?
0: Not identical.
1: Identical.
---

JSON output: {
    "reasoning": {complexity, coherence, hierarchy, coverage, ordering, subject, entities, details, conciseness, identical}
    "scoring":  {Same keys as above} - each with number score
}
"""


In [None]:
# ==== RUBRIC EVALUATION FUNCTIONS ====
from utils_parallel import exponential_backoff, process_in_parallel

@exponential_backoff
def rubric_compare(ref_text: str, comp_text: str):
    """Call LLM to compare two outlines using the rubric"""
    prompt = (
        f"Using the following rubric, compare the two outlines:\n\n"
        f"Rubric: {rubric}\n\n"
        f"Outline 1 (reference): {ref_text}\n\n"
        f"Outline 2 (candidate): {comp_text}\n\n"
        "The output must be a valid JSON object and nothing else."
    )

    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))

    response = client.chat.completions.create(
        model="gpt-4o-mini-2024-07-18",
        messages=[{
            "role": "user",
            "content": "You are an expert evaluator.\n\n" + prompt
        }],
        temperature=0.3,
        max_tokens=4000,
        response_format={"type": "json_object"}
    )

    return response.choices[0].message.content


def get_rubric_parallel(ref_texts: List[str], comp_texts: List[str], 
                        indices: List[int], label=None, max_workers: int = 20):
    """Run rubric comparison in parallel"""
    items = list(zip(indices, ref_texts, comp_texts))
    print(f"Processing {len(items)} comparisons for {label}")

    def get_rubric(item):
        index, ref_text, comp_text = item
        result = rubric_compare(ref_text, comp_text)
        print(f"  [{index}] Done")
        return result

    results = process_in_parallel(items, get_rubric, max_workers=max_workers)
    return results


def evaluate_with_rubric(
    df: pd.DataFrame,
    ref_col: str = "outline_generated",
    cand_col: str = "decoded_predicted",
    model_name: str = "model",
    parallel_workers: int = 20,
) -> pd.DataFrame:
    """Run rubric evaluation on a dataframe"""
    ref_texts = df[ref_col].astype(str).tolist()
    comp_texts = df[cand_col].astype(str).tolist()
    indices = df["index"].tolist() if "index" in df.columns else list(range(len(df)))

    raw_results = get_rubric_parallel(
        ref_texts, comp_texts, indices, label=model_name, max_workers=parallel_workers
    )

    df_out = df.copy()
    df_out["rubric_json"] = raw_results

    # Parse and expand scores
    def _safe_load(js):
        try:
            return json.loads(js)
        except:
            return {}

    parsed = [_safe_load(s) for s in raw_results]
    score_keys = ["complexity", "coherence", "hierarchy", "coverage", "ordering",
                  "subject", "entities", "details", "conciseness", "identical"]
    
    for k in score_keys:
        df_out[f"score_{k}"] = [p.get("scoring", {}).get(k, None) for p in parsed]

    # Summary scores
    score_cols = [f"score_{k}" for k in score_keys]
    df_out["score_sum"] = df_out[score_cols].apply(
        lambda r: np.nansum([float(x) if x is not None else np.nan for x in r.values]), axis=1
    )
    df_out["score_mean"] = df_out[score_cols].apply(
        lambda r: np.nanmean([float(x) if x is not None else np.nan for x in r.values]), axis=1
    )

    return df_out


In [None]:
# ==== RUN RUBRIC EVALUATION FOR PROBE OUTPUTS ====

scored_results = {}

for model_name, df in all_results.items():
    print(f"\n{'='*60}")
    print(f"Running rubric evaluation for {model_name} (PROBE)")
    print(f"{'='*60}")
    
    df_scored = evaluate_with_rubric(
        df,
        ref_col="outline_generated",
        cand_col="decoded_predicted",
        model_name=f"{model_name}_probe",
        parallel_workers=20,
    )
    
    # Save scored results
    csv_path = OUT_DIR / f"rubric_scores_{model_name}_probe.csv"
    df_scored.to_csv(csv_path, index=False)
    print(f"\nSaved rubric scores to {csv_path}")
    
    scored_results[model_name] = df_scored

# ==== RUN RUBRIC EVALUATION FOR REGEN BASELINE ====
# Compare original outline vs regenerated outline (from same text using Llama 70B)
# This represents the "ceiling" - same model regenerating from same text

scored_regen = {}

for model_name, df_regen in regen_baselines.items():
    print(f"\n{'='*60}")
    print(f"Running rubric evaluation for {model_name} (REGEN BASELINE)")
    print(f"{'='*60}")
    
    # Merge with probe results to get the original outline
    if model_name in all_results:
        df_probe = all_results[model_name]
        
        # Merge on chunk and index to get original outline
        df_merged = df_regen.merge(
            df_probe[["chunk", "index", "outline_generated"]].rename(
                columns={"outline_generated": "original_outline"}
            ),
            on=["chunk", "index"],
            how="left"
        )
        
        if df_merged["original_outline"].notna().sum() > 0 and "regenerated_outline" in df_merged.columns:
            df_scored_regen = evaluate_with_rubric(
                df_merged,
                ref_col="original_outline",        # Original outline from probe/embedding
                cand_col="regenerated_outline",    # Newly regenerated outline from Llama 70B
                model_name=f"{model_name}_regen",
                parallel_workers=20,
            )
            
            csv_path = OUT_DIR / f"rubric_scores_{model_name}_regen.csv"
            df_scored_regen.to_csv(csv_path, index=False)
            print(f"\nSaved regen baseline rubric scores to {csv_path}")
            
            scored_regen[model_name] = df_scored_regen
        else:
            print(f"Warning: Missing data for {model_name}")
            print(f"  - original_outline notna: {df_merged['original_outline'].notna().sum()}")
            print(f"  - regenerated_outline present: {'regenerated_outline' in df_merged.columns}")
    else:
        print(f"Warning: No probe results for {model_name}, skipping regen baseline rubric")


The below cells are not checked, check if you want to rerun!!

In [None]:
# ==== PRINT STATISTICS ====

def print_rubric_statistics(df, model_name):
    print("=" * 80)
    print(f"RUBRIC STATISTICS - {model_name} (self-comparison: original vs decoded)")
    print("=" * 80)
    
    print(f"\nSample Size: {len(df)} outline pairs")
    
    # Total score statistics
    print("\nTotal Score Statistics:")
    print(f"  Average: {df['score_sum'].mean():.2f}")
    print(f"  Max: {df['score_sum'].max():.2f}")
    print(f"  Min: {df['score_sum'].min():.2f}")
    print(f"  Std: {df['score_sum'].std():.2f}")
    
    max_scores = {
        'score_complexity': 3, 'score_coherence': 3, 'score_hierarchy': 3,
        'score_coverage': 3, 'score_ordering': 3, 'score_subject': 4,
        'score_entities': 4, 'score_details': 3, 'score_conciseness': 2,
        'score_identical': 1
    }
    
    score_cols = [col for col in df.columns if col.startswith('score_') 
                  and col not in ['score_sum', 'score_mean']]
    
    print("\nCategory Statistics:")
    print("-" * 60)
    print(f"{'Category':<20} {'Average':<10} {'Max':<10} {'% of Max':<10}")
    print("-" * 60)
    
    for col in score_cols:
        avg = df[col].mean()
        max_val = df[col].max()
        max_possible = max_scores.get(col, 3)
        pct = (avg / max_possible) * 100 if max_possible > 0 else 0
        print(f"{col.replace('score_', ''):<20} {avg:>8.2f}  {max_val:>8.2f}  {pct:>8.1f}%")


for model_name, df in scored_results.items():
    print_rubric_statistics(df, model_name)
    print()


In [None]:
# ==== VISUALIZATION: PROBE vs REGEN BASELINE ====
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.patches import Patch
from matplotlib import colors as mcolors

sns.set_theme()
plt.rcParams['font.size'] = 12
plt.rcParams['axes.titlesize'] = 14
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['legend.fontsize'] = 11

score_cols = ['score_coverage', 'score_ordering', 'score_subject', 
              'score_entities', 'score_details']

column_name_map = {
    "score_coverage": "Coverage",
    "score_ordering": "Ordering",
    "score_subject": "Subject",
    "score_entities": "Entities",
    "score_details": "Details",
}

# Create comparison plots for each model: Probe vs Regen Baseline
for model_name in scored_results.keys():
    if model_name not in scored_regen:
        print(f"Skipping {model_name}: no regen baseline")
        continue
    
    df_probe = scored_results[model_name]
    df_regen = scored_regen[model_name]
    
    fig, ax = plt.subplots(figsize=(12, 6))
    
    x = np.arange(len(score_cols))
    width = 0.35
    
    # Get means for both
    probe_means = [df_probe[col].mean() for col in score_cols]
    regen_means = [df_regen[col].mean() for col in score_cols]
    
    # Create bars
    bars1 = ax.bar(x - width/2, regen_means, width, label='Regen Baseline (ceiling)', 
                   color='#5cb85c', alpha=0.8)
    bars2 = ax.bar(x + width/2, probe_means, width, label='Probe Decoded', 
                   color='#0275d8', alpha=0.8)
    
    # Add value labels
    for bar, val in zip(bars1, regen_means):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05,
                f'{val:.2f}', ha='center', va='bottom', fontsize=10)
    for bar, val in zip(bars2, probe_means):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05,
                f'{val:.2f}', ha='center', va='bottom', fontsize=10)
    
    ax.set_ylabel('Average Score')
    ax.set_title(f'{model_name}: Probe vs Regen Baseline\n(n={len(df_probe)} samples)')
    ax.set_xticks(x)
    ax.set_xticklabels([column_name_map.get(col, col) for col in score_cols])
    ax.legend()
    ax.set_ylim(0, 4.5)
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig(OUT_DIR / f"{model_name}_probe_vs_regen.png", dpi=300, bbox_inches='tight')
    plt.show()

# Combined comparison across all models
if scored_results and scored_regen:
    n_models = len(scored_results)
    fig, axes = plt.subplots(1, n_models, figsize=(5*n_models, 6), sharey=True)
    if n_models == 1:
        axes = [axes]
    
    for ax, model_name in zip(axes, scored_results.keys()):
        if model_name not in scored_regen:
            continue
            
        df_probe = scored_results[model_name]
        df_regen = scored_regen[model_name]
        
        x = np.arange(len(score_cols))
        width = 0.35
        
        probe_means = [df_probe[col].mean() for col in score_cols]
        regen_means = [df_regen[col].mean() for col in score_cols]
        
        ax.bar(x - width/2, regen_means, width, label='Regen', color='#5cb85c', alpha=0.8)
        ax.bar(x + width/2, probe_means, width, label='Probe', color='#0275d8', alpha=0.8)
        
        ax.set_title(f'{model_name}')
        ax.set_xticks(x)
        ax.set_xticklabels([c.replace('score_', '')[:4] for c in score_cols], rotation=45)
        if ax == axes[0]:
            ax.set_ylabel('Average Score')
            ax.legend()
        ax.set_ylim(0, 4.5)
        ax.grid(True, alpha=0.3, axis='y')
    
    plt.suptitle('Gemma Probes: Probe Decoded vs Regen Baseline', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(OUT_DIR / "all_models_probe_vs_regen.png", dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# ==== COMBINED COMPARISON TABLE ====

score_cols_summary = ['score_coverage', 'score_ordering', 'score_subject', 
                      'score_entities', 'score_details', 'score_mean']

# Create summary table comparing probe vs regen baseline for all models
summary_data = []

for model_name in scored_results.keys():
    # Probe scores
    df_probe = scored_results[model_name]
    row_probe = {'model': model_name, 'type': 'probe'}
    for col in score_cols_summary:
        if col in df_probe.columns:
            row_probe[col] = df_probe[col].mean()
    summary_data.append(row_probe)
    
    # Regen baseline scores (if available)
    if model_name in scored_regen:
        df_regen = scored_regen[model_name]
        row_regen = {'model': model_name, 'type': 'regen_baseline'}
        for col in score_cols_summary:
            if col in df_regen.columns:
                row_regen[col] = df_regen[col].mean()
        summary_data.append(row_regen)

if summary_data:
    df_summary = pd.DataFrame(summary_data)
    
    print("\n" + "="*100)
    print("SUMMARY: Probe vs Regen Baseline Scores Across All Gemma Models")
    print("="*100)
    print(df_summary.round(3).to_string(index=False))
    
    # Calculate gap between probe and regen baseline
    print("\n" + "-"*100)
    print("GAP ANALYSIS (Regen Baseline - Probe):")
    print("-"*100)
    for model_name in scored_results.keys():
        if model_name in scored_regen:
            df_probe = scored_results[model_name]
            df_regen = scored_regen[model_name]
            print(f"\n{model_name}:")
            for col in ['score_coverage', 'score_subject', 'score_entities', 'score_details']:
                if col in df_probe.columns and col in df_regen.columns:
                    gap = df_regen[col].mean() - df_probe[col].mean()
                    probe_val = df_probe[col].mean()
                    regen_val = df_regen[col].mean()
                    pct = (probe_val / regen_val * 100) if regen_val > 0 else 0
                    print(f"  {col.replace('score_', ''):12}: Regen={regen_val:.2f}, Probe={probe_val:.2f}, Gap={gap:+.2f} ({pct:.1f}% of ceiling)")

    # Save summary
    df_summary.to_csv(OUT_DIR / "gemma_probes_summary.csv", index=False)
    print(f"\n\nSaved summary to {OUT_DIR / 'gemma_probes_summary.csv'}")
