In [None]:
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import snapshot_download

# Pick the compute device once; later cells reuse `device` and `DTYPE`.
_cuda_ok = torch.cuda.is_available()
device = "cuda" if _cuda_ok else "cpu"

# float32 on CPU, bfloat16 on GPUs whose compute-capability major version is
# at least 8, float16 on any other GPU.
if not _cuda_ok:
    DTYPE = torch.float32
elif torch.cuda.get_device_capability(0)[0] >= 8:
    DTYPE = torch.bfloat16
else:
    DTYPE = torch.float16

print("Device:", device, "| DTYPE:", DTYPE)

Device: cuda | DTYPE: torch.bfloat16


# PRM Setup

In [None]:
# Hub repo holding the LoRA adapter, tokenizer files and the PRM head weights.
HF_REPO_ID = "devangb4/prm-qwen3-8b-bf16-6k"
# Base checkpoint the LoRA adapter is applied to.
BASE_MODEL_NAME = "Qwen/Qwen3-8B"
# Token budget used when encoding a single PRM prompt (see score_step).
MAX_SEQ_LENGTH = 384

# Training-only checkpoint artifacts are not needed for inference, so they are
# skipped (optimizer.pt alone is >100 MB). `snapshot_download` is already
# imported in the first cell.
_TRAINING_ONLY_FILES = ["optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json"]

local_dir = snapshot_download(
    HF_REPO_ID,
    repo_type="model",
    ignore_patterns=_TRAINING_ONLY_FILES,
)
print("Local snapshot dir:", local_dir)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 18 files:   0%|          | 0/18 [00:00<?, ?it/s]

adapter_model.safetensors:   0%|          | 0.00/61.4M [00:00<?, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

adapter_config.json:   0%|          | 0.00/876 [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

export_metadata.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

prm_head.bin:   0%|          | 0.00/10.1k [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/123M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

prm_head_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

trainer_state.json: 0.00B [00:00, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

Local snapshot dir: /root/.cache/huggingface/hub/models--devangb4--prm-qwen3-8b-bf16-6k/snapshots/6fc1795cba499357670fcf455a32cebc058c2e62


In [None]:
# Tokenizer shipped with the PRM repo. Left padding places each row's last
# real token at the final position of a padded batch.
tok = AutoTokenizer.from_pretrained(HF_REPO_ID, trust_remote_code=True)
tok.padding_side = "left"

# Base Qwen model, placed on the device selected in the first cell so the
# notebook does not hard-require CUDA.
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    dtype=DTYPE,
    device_map=device,
    trust_remote_code=True,
)
base.config.return_dict = True
base.config.use_cache = False

# LoRA adapter on top of the base model; used for scoring only.
peft_model = PeftModel.from_pretrained(base, HF_REPO_ID)
peft_model.eval()

# PRM head (Linear(hidden_size -> 1)) from the repo snapshot.
hidden_size = getattr(peft_model.config, "hidden_size", None) or getattr(
    peft_model.config, "hidden_sizes", [None]
)[0]
assert hidden_size is not None, "Could not infer hidden_size from config."

prm_head = nn.Linear(hidden_size, 1).to(device, dtype=DTYPE)

head_path = os.path.join(local_dir, "prm_head.bin")
assert os.path.exists(head_path), f"prm_head.bin not found at {head_path}"
# weights_only=True: the file is expected to be a plain tensor state dict, so
# there is no reason to allow arbitrary pickled objects to be loaded.
state_dict = torch.load(head_path, map_location="cpu", weights_only=True)

# strict=False tolerates extra keys in the checkpoint, but a *missing* key
# would leave the head randomly initialised and every score meaningless, so
# that case is surfaced as an error instead of passing silently.
_head_load = prm_head.load_state_dict(state_dict, strict=False)
if _head_load.missing_keys:
    raise RuntimeError(
        f"prm_head.bin is missing parameters {_head_load.missing_keys}; "
        f"checkpoint keys: {list(state_dict.keys())}"
    )
if _head_load.unexpected_keys:
    print("Ignoring unexpected keys in prm_head.bin:", _head_load.unexpected_keys)
prm_head.eval()

print("Loaded base model, LoRA adapter, and PRM head.")

config.json:   0%|          | 0.00/728 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

model-00004-of-00005.safetensors:   0%|          | 0.00/3.19G [00:00<?, ?B/s]

model-00001-of-00005.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00005-of-00005.safetensors:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

model-00003-of-00005.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00002-of-00005.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Loaded base model, LoRA adapter, and PRM head.


In [None]:
@torch.inference_mode()
def score_step(problem: str, partial: str, step: str) -> float:
    """
    Score one candidate reasoning step with the PRM.

    Builds the prompt "Problem / [Previous steps] / Current step / Is this step
    correct?", runs it through the LoRA-adapted model, feeds the hidden state of
    the final token to the linear PRM head and returns sigmoid(logit).

    Args:
        problem: problem statement.
        partial: text of the steps taken so far; the "Previous steps" line is
            omitted when this is empty or whitespace-only.
        step: the candidate next step to judge.

    Returns:
        A float in [0, 1]; higher means the head rates the step as more likely correct.
    """
    # The "Previous steps" line only appears when there is prior context.
    previous_line = f"Previous steps: {partial}\n" if partial.strip() else ""
    text = (
        f"Problem: {problem}\n"
        + previous_line
        + f"Current step: {step}\n"
        + f"Is this step correct?"
    )

    # Single example, so no padding; prompts longer than MAX_SEQ_LENGTH are truncated.
    encoded = tok(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding=False,
    ).to(device)

    outputs = peft_model(**encoded, output_hidden_states=True, use_cache=False)
    final_layer = outputs.hidden_states[-1]      # (1, T, H)
    final_token = final_layer[:, -1, :]          # (1, H)

    logit = prm_head(final_token.to(prm_head.weight.dtype)).squeeze(-1)
    return torch.sigmoid(logit.to(torch.float32)).item()

In [None]:
def demo_problem(problem, steps, partial=""):
    """
    Print PRM scores for several candidate steps of one problem, best first.

    Args:
        problem: problem statement shown to the PRM.
        steps: iterable of candidate step strings, each scored independently.
        partial: optional text of previously taken steps (printed when non-empty).
    """
    print("\nProblem:", problem)
    if partial:
        print("\nPrevious steps:", partial)
    print("\nScored steps:\n")

    ranked = [(score_step(problem, partial, candidate), candidate) for candidate in steps]
    # Stable descending sort: equal scores keep their input order.
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    for p, s in ranked:
        print(f"{p:0.3f} :: {s}")

In [None]:
# Hand-written sanity checks for the PRM: each case lists candidate next steps
# and demo_problem prints them ranked by predicted correctness.
# NOTE(review): the parenthesised labels such as "(correct)" / "(wrong)" are part
# of the step strings passed to score_step, so the model sees them as input.
# Consider keeping labels separate from the scored text — TODO confirm intent.

# 1) Simple algebra (no previous steps)
problem1 = "Solve the equation 3x - 5 = 16."
steps1 = [
    "Add 5 to both sides: 3x = 21. (correct)",
    "Subtract 5 from both sides: 3x = 11. (wrong)",
    "Divide both sides by 3: x = 7.(correct)" ,
    "Multiply both sides by 3: x = 48. (wrong)",
]
demo_problem(problem1, steps1)

# 2) With previous steps: the candidates are scored as the continuation of `partial2`
problem2 = "Compute the product 24 × 17."
partial2 = (
    "Break 17 into 10 + 7. "
    "Compute 24 × 10 = 240. "
    "Compute 24 × 7 = 168."
)
steps2 = [
    "Add the partial products: 240 + 168 = 408. (correct)",
    "Add the partial products: 240 + 168 = 398. (wrong)",
    "Ignore the partial products and answer 17. (nonsense)",
]
demo_problem(problem2, steps2, partial=partial2)

# 3) Commonsense (non-math) reasoning
problem3 = (
    "You put a glass of water in the freezer at 8 pm. "
    "The freezer is working normally. What happens by 9 pm?"
)
steps3 = [
    "The water will likely have started to freeze or be completely frozen. (good)",
    "The water will have boiled away due to the heat. (bad)" ,
    "The glass instantly explodes because of gravity. (bad)",
]
demo_problem(problem3, steps3)

# 4) Coding reasoning
problem4 = "You want a Python function that returns the sum of numbers in a list."
steps4 = [
    "Define a function that iterates over the list and adds each item to a running total. (good)",
    "Define a function that multiplies all the numbers instead of adding them. (bad)" ,
    "Define a function that always returns 0 regardless of the input list. (bad)",
]
demo_problem(problem4, steps4)



Problem: Solve the equation 3x - 5 = 16.

Scored steps:

0.996 :: Add 5 to both sides: 3x = 21. (correct)
0.991 :: Divide both sides by 3: x = 7.(correct)
0.967 :: Subtract 5 from both sides: 3x = 11. (wrong)
0.098 :: Multiply both sides by 3: x = 48. (wrong)

Problem: Compute the product 24 × 17.

Previous steps: Break 17 into 10 + 7. Compute 24 × 10 = 240. Compute 24 × 7 = 168.

Scored steps:

0.985 :: Add the partial products: 240 + 168 = 408. (correct)
0.369 :: Ignore the partial products and answer 17. (nonsense)
0.285 :: Add the partial products: 240 + 168 = 398. (wrong)

Problem: You put a glass of water in the freezer at 8 pm. The freezer is working normally. What happens by 9 pm?

Scored steps:

0.990 :: The water will likely have started to freeze or be completely frozen. (good)
0.614 :: The glass instantly explodes because of gravity. (bad)
0.590 :: The water will have boiled away due to the heat. (bad)

Problem: You want a Python function that returns the sum of numbers in

# Generator Setup

In [None]:
# Hub id of the model that generates candidate solutions (loaded in the next cell).
GENERATOR_MODEL="Qwen/Qwen3-0.6B"

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
# Generator tokenizer and model. The model is placed on the device selected in
# the first cell (instead of a hard-coded "cuda") so the CPU fallback computed
# there is actually honoured.
gen_tok = AutoTokenizer.from_pretrained(GENERATOR_MODEL, use_fast=True, trust_remote_code=True)
gen_model = AutoModelForCausalLM.from_pretrained(
    GENERATOR_MODEL, trust_remote_code=True, device_map=device, dtype=DTYPE
)

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

In [None]:
def build_generation_prompt(problem: str) -> str:
    """
    Build the one-shot prompt given to the generator model.

    The prompt contains an instruction line, one worked example (problem plus a
    five-step solution ending in "Final Answer: ..."), an end-of-example marker,
    the new problem, and a trailing "### Your Solution:" header for the model to
    continue from.

    Args:
        problem: text of the problem to solve.

    Returns:
        The complete prompt string, ending with a newline after the solution header.
    """
    example_problem = "A train travels at 60 mph for 2 hours and then at 40 mph for 1 hour. What is the train's average speed for the entire journey?"

    example_solution = "".join((
        "Step 1: Calculate the distance traveled during the first part of the journey.\n",
        "Distance1 = speed * time = 60 mph * 2 hours = 120 miles.\n",
        "Step 2: Calculate the distance traveled during the second part of the journey.\n",
        "Distance2 = speed * time = 40 mph * 1 hour = 40 miles.\n",
        "Step 3: Calculate the total distance traveled.\n",
        "Total Distance = Distance1 + Distance2 = 120 miles + 40 miles = 160 miles.\n",
        "Step 4: Calculate the total time taken for the journey.\n",
        "Total Time = Time1 + Time2 = 2 hours + 1 hour = 3 hours.\n",
        "Step 5: Calculate the average speed for the entire journey.\n",
        "Average Speed = Total Distance / Total Time = 160 miles / 3 hours = 53.33 mph.\n",
        "Final Answer: 53.33",
    ))

    # Assemble the sections in order: instruction, example, separator, new problem.
    sections = [
        "You are a careful math tutor. Solve the problem in clear, numbered steps exactly as shown in the example. STRICTLY FOLLOW THE FORMAT\n\n",
        "### Example Problem:\n",
        f"{example_problem}\n",
        "### Example Solution:\n",
        f"{example_solution}\n\n",
        "--- END OF EXAMPLE ---\n\n",
        "### New Problem:\n",
        f"{problem}\n",
        "### Your Solution:\n",
    ]
    return "".join(sections)

In [None]:
from transformers import GenerationConfig
def generate_one(problem: str, temperature: float = 0.7, max_new_tokens: int = 400) -> str:
    """
    Generate one step-by-step solution for `problem` with the generator model.

    Args:
        problem: problem text; it is wrapped by build_generation_prompt.
        temperature: sampling temperature. None falls back to 0.7. A value <= 0
            selects deterministic greedy decoding instead of sampling.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The generated completion only (prompt and special tokens removed),
        truncated after the first "Final Answer: ..." line when one is present.
    """
    import re  # local import: this cell runs before the notebook's top-level `import re`

    prompt = build_generation_prompt(problem)
    inputs = gen_tok(prompt, return_tensors="pt").to(gen_model.device)
    temp_val = float(temperature) if temperature is not None else 0.7

    # Pass sampling options as explicit generate() kwargs. Setting do_sample
    # explicitly makes sure the temperature is honoured; options not given here
    # (top_k, top_p, eos/pad ids) keep the model's own generation defaults.
    gen_kwargs = {"max_new_tokens": max_new_tokens}
    if temp_val > 0:
        gen_kwargs.update(do_sample=True, temperature=temp_val)
    else:
        gen_kwargs["do_sample"] = False

    with torch.inference_mode():
        out = gen_model.generate(**inputs, **gen_kwargs)

    # Slice the prompt off by token count rather than searching the decoded
    # text for the prompt string, which breaks if decoding does not reproduce
    # the prompt character-for-character.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = gen_tok.decode(out[0][prompt_len:], skip_special_tokens=True)

    # The model tends to keep going after its answer (inventing further
    # "### New Problem" sections). Keep everything up to the end of the first
    # final-answer line so later step splitting only sees this solution.
    final_line = re.search(r"Final\s*Answer\s*:\s*.+", completion, flags=re.I)
    if final_line:
        completion = completion[:final_line.end()]
    return completion.strip()

In [None]:
# Smoke test on a GSM8K-style problem. The HTML entities (&#039;) left over from
# copy/paste are replaced with real apostrophes so the model sees clean text.
print(generate_one("Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", max_new_tokens=512))

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'top_k': 20, 'top_p': 0.95, 'pad_token_id': 151643, 'bos_token_id': 151643, 'eos_token_id': [151645, 151643]}. If this is not desired, please set these values explicitly.


Step 1: Calculate the number of eggs Janet lays daily.
Number of Eggs = 16 eggs per day.
Step 2: Calculate the number of eggs she eats for breakfast and for baking muffins.
Breakfast_Eggs = 3 eggs
Bake_Muffins_Eggs = 4 eggs
Step 3: Calculate the number of eggs she eats for the day.
Eggs_Eaten = Breakfast_Eggs + Bake_Muffins_Eggs = 3 + 4 = 7 eggs.
Step 4: Calculate the number of eggs remaining in the market.
Eggs_Remaining = Total_Eggs - Eggs_Eaten = 16 - 7 = 9 eggs.
Step 5: Calculate the total amount she makes from eggs at the farmers' market.
Eggs_Make = Eggs_Remaining * 2 = 9 * 2 = 18 dollars.
Final Answer: 18

---

### New Problem:
A train travels at 60 mph for 2 hours and then at 40 mph for 1 hour. What is the train's average speed for the entire journey?
### Your Solution:
Step 1: Calculate the distance traveled during the first part of the journey.
Distance1 = speed * time = 60 mph * 2 hours = 120 miles.
Step 2: Calculate the distance traveled during the second part of the journe

# Generator–PRM Integration


In [None]:
# Generation defaults.
# These values are captured as default arguments by the functions defined in
# the following cells (generate_candidates, best_of_n_with_debug, ...).
N_SAMPLES = 8        # PRM best-of-N
M_SC      = 8        # Self-Consistency samples
MAX_NEW_TOKENS = 500
TEMPERATURE     = 0.2
# TOP_P / TOP_K are accepted by best_of_n_with_debug but are not forwarded to
# generate_candidates there.
TOP_P           = 0.95
TOP_K           = 0
# 'log_product' multiplies the per-step probabilities (see aggregate_probs).
PRM_AGGREGATION = "log_product"  # {'geo_mean','mean','min','log_product'}

In [None]:
from typing import List
def generate_candidates(problem: str, n: int = N_SAMPLES, max_new_tokens: int = MAX_NEW_TOKENS,
                        temperature: float = TEMPERATURE) -> List[str]:
    """
    Draw `n` independent solutions for `problem` from the generator.

    Args:
        problem: problem text.
        n: number of candidates to generate.
        max_new_tokens: per-candidate generation budget.
        temperature: sampling temperature forwarded to generate_one.

    Returns:
        A list of `n` solution strings, in generation order.
    """
    return [
        generate_one(problem, temperature=temperature, max_new_tokens=max_new_tokens)
        for _ in range(n)
    ]

In [None]:
import re
from typing import List, Optional
# Matches step headers such as "Step 1:", "**Step 1**:" or "#### Step 1",
# optionally followed by one of ":", ".", ")", "-".
# re.MULTILINE lets "^" match at the start of any line, not just the first;
# note that the leading "\s*" means a match may begin on a preceding blank line.
STEP_HEADER_PATTERN = re.compile(
    r"^(?P<header>(?:\s*Step\s*\d+|\*\*Step\s*\d+\*\*|####\s*Step\s*\d+)[:.)-]?\s*)",
    re.IGNORECASE | re.MULTILINE
)

def split_into_steps(text: str) -> List[str]:
    """
    Split a generated solution into its individual steps.

    Steps are delimited by headers matching STEP_HEADER_PATTERN; each returned
    step keeps its header and runs up to the next header (or the end of the
    text). Any text before the first header is folded into the first step,
    right after its header. If no header is found, the text is split into
    sentences on ".", "!" or "?" followed by whitespace.

    Args:
        text: raw solution text; empty or None yields an empty list.

    Returns:
        List of non-empty, whitespace-stripped step strings in original order.
    """
    if not text:
        return []

    matches = list(STEP_HEADER_PATTERN.finditer(text))

    if not matches:
        # No "Step X" headers: fall back to sentence splitting as a last resort.
        sents = re.split(r"(?<=[.!?])\s+", text.strip())
        return [s for s in sents if s]

    # Text before the first header (e.g. "Sure!").
    intro_text = text[:matches[0].start()].strip()

    # Each step ends where the next header begins; the last one ends at EOF.
    end_indices = [m.start() for m in matches[1:]] + [len(text)]

    steps = []
    for i, (match, end_index) in enumerate(zip(matches, end_indices)):
        step_text = text[match.start():end_index].strip()

        if i == 0 and intro_text:
            header = match.group("header").strip()
            # Take the content from the end of the header match in the
            # original text. Slicing the *stripped* step by len(header) is
            # wrong because the matched header can start with whitespace or a
            # newline, which would cut characters off the step content.
            content_after_header = text[match.end():end_index].strip()
            step_text = f"{header} {intro_text} {content_after_header}".strip()

        if step_text:
            steps.append(step_text)

    return steps

In [None]:
def extract_final_answer_strict(text: str) -> str:
    """
    Pull the final answer out of a generated solution.

    Looks for the first "Final Answer: ..." line (case-insensitive); if there is
    none, the whole text is used. Commas, leading "$" signs and a trailing unit
    (mph, km/h, unit(s), percent, percentage, %) are removed.

    Args:
        text: generated solution text.

    Returns:
        "" for empty input; the answer itself when it is a bare fraction such as
        "3/4" or "1 1/2" (so norm_num can evaluate it); otherwise the last
        number found in the answer, or the cleaned answer text if it contains
        no number.
    """
    if not text:
        return ""
    m = re.search(r"Final\s*Answer\s*:\s*(.+)", text, flags=re.I)
    ans = m.group(1).strip() if m else text.strip()
    ans = ans.replace(",", "")
    while ans.startswith("$"):
        ans = ans[1:].lstrip()
    ans = re.sub(r"\s*(mph|km/h|units?|percent|percentage|%)\s*$", "", ans, flags=re.I)
    # Taking "the last number" would turn "3/4" into "4"; hand fractions through
    # unchanged instead, in the forms norm_num knows how to evaluate.
    if re.fullmatch(r"-?\d+(?:\s+\d+)?/\d+", ans):
        return ans
    nums = re.findall(r"-?\d+(?:\.\d+)?", ans)
    return nums[-1] if nums else ans

In [None]:
def norm_num(s: Optional[str]) -> str:
    """
    Normalise an answer string to a bare numeric string where possible.

    Removes commas, leading "$" signs and a trailing unit (mph, km/h, unit(s),
    percent, percentage, %), evaluates mixed fractions ("1 1/2" -> "1.5") and
    pure fractions ("3/4" -> "0.75", "-1/2" -> "-0.5"), then returns the last
    number found.

    Args:
        s: raw answer; None yields "". Non-strings are converted with str().

    Returns:
        The normalised number as a string, or the cleaned text if it contains no
        number. Fractions with a zero denominator are not evaluated and fall
        through to the last-number rule.
    """
    if s is None:
        return ""
    s = str(s).strip().replace(",", "")
    while s.startswith("$"):
        s = s[1:].lstrip()
    s = re.sub(r"\s*(mph|km/h|units?|percent|percentage|%)\s*$", "", s, flags=re.I)

    def _mixed_to_decimal(m):
        # Leave "a b/0" untouched instead of raising ZeroDivisionError.
        denominator = float(m.group('c'))
        if denominator == 0:
            return m.group(0)
        return str(float(m.group('a')) + float(m.group('b')) / denominator)

    # mixed fraction "1 1/2" -> 1.5
    s = re.sub(r"(?P<a>\d+)\s+(?P<b>\d+)\/(?P<c>\d+)", _mixed_to_decimal, s)

    # pure fraction x/y (optionally negative)
    if re.fullmatch(r"-?\d+/\d+", s):
        num, den = s.split("/")
        if float(den) != 0:
            # Return directly: re-scanning e.g. "1e-05" with the number regex
            # below would mangle the value.
            return str(float(num) / float(den))

    nums = re.findall(r"-?\d+(?:\.\d+)?", s)
    return nums[-1] if nums else s

def ok(g, p) -> bool:
    """
    Decide whether prediction `p` matches gold answer `g`.

    Both values are normalised with norm_num. If both parse as floats they match
    when they differ by at most 1e-6; otherwise the normalised strings must be
    identical.
    """
    g2, p2 = norm_num(g), norm_num(p)
    try:
        return abs(float(g2) - float(p2)) <= 1e-6
    except ValueError:
        # At least one side is not numeric: fall back to exact string equality.
        # (Only ValueError is caught so unrelated errors are not hidden.)
        return g2 == p2

_STEP_PREFIX = re.compile(r"^\s*Step\s*\d+[:.)-]\s*", flags=re.IGNORECASE)
def _norm_step_text(s: str) -> str:
    return _STEP_PREFIX.sub("", s).strip()

In [None]:
def format_step_for_prm_trainstyle(problem: str, prev_steps: List[str], current_step: str) -> str:
    """
    Build the PRM prompt for judging `current_step` given the steps before it.

    "Step N" labels are removed from every step via _norm_step_text, and earlier
    steps are joined with single spaces. The two layouts differ in their blank
    lines: with previous steps there is an empty line after the problem and
    before the question; without previous steps there is none.

    Args:
        problem: problem statement.
        prev_steps: steps preceding the one being judged (may be empty).
        current_step: the step to judge.

    Returns:
        The prompt string ending in "Is this step correct?".
    """
    history = " ".join(map(_norm_step_text, prev_steps)).strip()
    candidate = _norm_step_text(current_step)

    # First step of a solution: no "Previous steps" section.
    if not history:
        return (
            f"Problem: {problem}\n"
            f"Current step: {candidate}\n"
            f"Is this step correct?"
        )

    return (
        f"Problem: {problem}\n\n"
        f"Previous steps: {history}\n"
        f"Current step: {candidate}\n\n"
        f"Is this step correct?"
    )

In [None]:
import numpy as np
def aggregate_probs(step_probs, mode='geo_mean') -> float:
    """
    Collapse per-step PRM probabilities into one solution-level score.

    Args:
        step_probs: sequence of per-step probabilities.
        mode: one of
            'geo_mean'    - geometric mean,
            'mean'        - arithmetic mean,
            'min'         - weakest step,
            'log_product' - product of all probabilities (computed via logs).
            For 'geo_mean' and 'log_product' values are clipped to [1e-12, 1]
            so a zero probability does not produce log(0).

    Returns:
        The aggregated score as a float; 0.0 when `step_probs` is empty.

    Raises:
        ValueError: if `mode` is not one of the supported names.
    """
    if mode not in ('geo_mean', 'mean', 'min', 'log_product'):
        raise ValueError(f"Unknown aggregation mode: {mode!r}")

    arr = np.asarray(step_probs, dtype=np.float64)
    if arr.size == 0:
        # No steps to aggregate: mean/geo_mean would be NaN and min would raise.
        return 0.0

    if mode == 'mean':
        return float(arr.mean())
    if mode == 'min':
        return float(arr.min())

    log_probs = np.log(np.clip(arr, 1e-12, 1.0))
    if mode == 'geo_mean':
        return float(np.exp(log_probs.mean()))
    return float(np.exp(log_probs.sum()))  # 'log_product'

In [None]:
@torch.inference_mode()
def score_steps(prompts: List[str], max_length: int = 512) -> List[float]:
    """
    Score a batch of PRM prompts.

    Args:
        prompts: fully formatted PRM prompts (see format_step_for_prm_trainstyle).
        max_length: token limit per prompt; longer prompts are truncated.
            NOTE(review): score_step uses MAX_SEQ_LENGTH (384) while this
            defaults to 512 — confirm which limit matches PRM training.

    Returns:
        One probability in [0, 1] per prompt, in input order; [] for no prompts.
    """
    if not prompts:
        return []

    enc = tok(
        prompts,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=True,           # prompts in a batch differ in length
    ).to(device)
    out = peft_model(**enc, output_hidden_states=True, use_cache=False)
    last_hidden = out.hidden_states[-1]             # (B, T, H)

    # Index of the last attended (non-padding) token in every row. With left
    # padding this is simply T-1; computing it from the attention mask keeps
    # the result correct if the tokenizer's padding side is ever changed.
    mask = enc["attention_mask"]
    last_idx = mask.size(1) - 1 - mask.flip(dims=[1]).argmax(dim=1)
    rows = torch.arange(last_hidden.size(0), device=last_hidden.device)
    h_last = last_hidden[rows, last_idx.to(last_hidden.device)]   # (B, H)

    logits = prm_head(h_last.to(prm_head.weight.dtype))

    probs = torch.sigmoid(logits.squeeze(-1).to(torch.float32)).tolist()
    return probs

In [None]:
def score_solution_with_prm(problem: str, solution_text: str, aggregation: str = PRM_AGGREGATION):
    """
    Score a whole solution with the PRM in one batched call.

    The solution is split into steps; step i is judged given steps 0..i-1, and
    the per-step probabilities are combined with aggregate_probs.

    Args:
        problem: problem statement.
        solution_text: generated solution.
        aggregation: aggregation mode passed to aggregate_probs.

    Returns:
        (aggregate score, per-step probabilities, step strings); (0.0, [], [])
        when no steps are found.
    """
    steps = split_into_steps(solution_text)
    if len(steps) == 0:
        return 0.0, [], []

    # One prompt per step, each seeing only the steps that precede it.
    prompts = []
    for idx, step in enumerate(steps):
        prompts.append(format_step_for_prm_trainstyle(problem, steps[:idx], step))

    step_probs = score_steps(prompts)
    overall = aggregate_probs(step_probs, mode=aggregation)
    return float(overall), [float(p) for p in step_probs], steps

In [None]:
import json
# Same marker the answer extractor looks for (case-insensitive, flexible spacing).
_FINAL_ANSWER_MARKER = re.compile(r"Final\s*Answer\s*:", re.IGNORECASE)


def _score_candidate_steps(problem: str, text: str) -> list:
    """Score each step of one candidate given the steps before it; returns one debug dict per step."""
    per_step = []
    prev = []
    for si, step in enumerate(split_into_steps(text)):
        prm_prompt = format_step_for_prm_trainstyle(problem, prev, step)
        prob = score_steps([prm_prompt])[0]
        per_step.append({
            "step_index": si,
            "step_raw": step,
            "step_norm": _norm_step_text(step),
            "prev_concat": " ".join(_norm_step_text(s) for s in prev),
            "prm_prompt": prm_prompt,
            "prm_prob": float(prob),
        })
        prev.append(step)
    return per_step


def best_of_n_with_debug(
    problem: str,
    n: int = N_SAMPLES,
    max_new_tokens: int = MAX_NEW_TOKENS,
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    aggregation: str = PRM_AGGREGATION,
    debug_jsonl_path: str = "prm_debug_steps.jsonl",
):
    """
    Generate `n` candidate solutions and pick the one the PRM scores highest.

    Every candidate is appended as one JSON line to `debug_jsonl_path`
    (problem, candidate index, text, per-step prompts/probabilities, aggregate
    score). Candidates without a "Final Answer:" marker are treated as
    incomplete: they are logged with score -1.0 and no steps, and can only be
    selected if no candidate has been selected before them.

    Args:
        problem: problem statement.
        n: number of candidates to generate.
        max_new_tokens: generation budget per candidate.
        temperature: sampling temperature for generation.
        top_p, top_k: accepted for interface compatibility; currently not
            forwarded to generation.
        aggregation: aggregation mode for aggregate_probs.
        debug_jsonl_path: JSONL file that debug records are appended to.

    Returns:
        {"problem": problem, "best": record}, where record is the debug record
        of the selected candidate (None if no candidates were generated).
        Ties keep the earliest candidate.
    """
    cands = generate_candidates(problem, n=n, max_new_tokens=max_new_tokens, temperature=temperature)

    best = None
    # Open the debug log once for the whole problem instead of once per candidate.
    with open(debug_jsonl_path, "a") as f:
        for ci, text in enumerate(cands):
            if not _FINAL_ANSWER_MARKER.search(text):
                # Incomplete solution: do not spend PRM calls, assign a penalty score.
                per_step, agg = [], -1.0
            else:
                per_step = _score_candidate_steps(problem, text)
                if per_step:
                    agg = aggregate_probs([r["prm_prob"] for r in per_step], mode=aggregation)
                else:
                    agg = 0.0

            rec = {
                "problem": problem,
                "candidate_index": ci,
                "text": text,
                "steps": per_step,
                "agg_score": float(agg),
                "n_steps": len(per_step),
            }
            f.write(json.dumps(rec) + "\n")
            f.flush()  # keep the log inspectable while a long run is in progress

            # Strict ">" keeps the earliest candidate on ties; a rejected
            # candidate (-1.0) therefore only wins when nothing preceded it.
            if best is None or agg > best["agg_score"]:
                best = rec

    return {"problem": problem, "best": best}

In [None]:
# Single-problem smoke test of the best-of-N pipeline with default settings.
# Debug records are appended to the default file "prm_debug_steps.jsonl".
out = best_of_n_with_debug("Pancho walks 20 miles a day. Except on weekends when he walks 10 miles. How many miles does he walk in a week?")


In [None]:
# Check the selected candidate against the expected answer for the problem above
# (5 weekdays * 20 + 2 weekend days * 10 = 120).
best = out["best"]
pred = extract_final_answer_strict(best["text"])
okv  = ok("120", pred)
print(pred, okv)

120 True


# Loading GSM8K and Running the Experiment

In [None]:
# Number of GSM8K test problems to evaluate (passed to load_gsm8k_subset below).
# NOTE(review): the saved outputs further down report 50 problems and the result
# files are named "*_same100.jsonl" — they appear to come from runs with other
# SUBSET values; confirm before quoting the numbers.
SUBSET=200

import random, re
from typing import Tuple, List, Dict

In [None]:
from datasets import load_dataset

def load_gsm8k_subset(n=100, seed=42) -> List[Tuple[int, str, str]]:
    """
    Sample a subset of the GSM8K test split.

    Args:
        n: number of problems to sample; capped at the size of the test split.
        seed: seed for the sampling RNG so every run (and every method being
            compared) sees the same problems. Pass None for a different random
            subset on each call.

    Returns:
        List of (dataset_index, question, gold_answer) tuples sorted by dataset
        index. The gold answer is the text after "####" in the reference
        solution, normalised with norm_num.
    """
    ds = load_dataset("gsm8k", "main", split="test")

    total = len(ds)
    rng = random.Random(seed)
    chosen = sorted(rng.sample(range(total), min(n, total)))

    # Only the sampled rows are parsed and normalised.
    subset = []
    for i in chosen:
        ex = ds[i]
        m = re.search(r"####\s*(.+)", ex["answer"])
        gold = m.group(1).strip() if m else ex["answer"]
        subset.append((i, ex["question"], norm_num(gold)))
    return subset

# Evaluation set shared by all three methods below: (ds_idx, question, gold) tuples.
SUBSET_ROWS = load_gsm8k_subset(n=SUBSET)
print("Subset size:", len(SUBSET_ROWS), "| First ds_idx:", SUBSET_ROWS[0][0])

README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Subset size: 200 | First ds_idx: 1


In [None]:
import time

# Experiment settings; these rebind the globals from the generation-defaults cell.
# Default argument values are evaluated when a function is defined, so functions
# defined in earlier cells (e.g. the `temperature` default of generate_candidates
# and best_of_n_with_debug) keep the values that were in effect back then.
# Rebinding here affects code that reads these globals at call time and the
# defaults of functions defined after this cell.
N_SAMPLES = 8        # PRM best-of-N
M_SC      = 8        # Self-Consistency samples
MAX_NEW_TOKENS = 500
TEMPERATURE     = 0.7

In [None]:
def save_jsonl(path, rows):
    """
    Write `rows` to `path` as JSON Lines (one JSON document per line).

    The file is overwritten. A short confirmation with the row count is printed.

    Args:
        path: destination file path.
        rows: sized collection of JSON-serialisable records.
    """
    with open(path, "w") as sink:
        sink.writelines(json.dumps(record) + "\n" for record in rows)
    print("Saved:", path, "| n=", len(rows))

In [None]:
from collections import Counter
from tqdm.auto import tqdm

def run_greedy(rows: List[Tuple[int,str,str]]) -> Dict:
    """
    Single-sample baseline: one generated solution per problem.

    NOTE(review): despite the name, generation is called with
    temperature=TEMPERATURE, so this is one sampled completion per problem
    rather than deterministic greedy decoding — confirm this is the intended
    baseline.

    Args:
        rows: (dataset_index, question, gold_answer) tuples.

    Returns:
        {"accuracy": float, "records": list of per-problem dicts}. Accuracy is
        0.0 when `rows` is empty.
    """
    recs, correct = [], 0
    t0 = time.time()
    with tqdm(rows, desc="Greedy", unit="problem", total=len(rows)) as pbar:
        for k, (ds_idx, q, gold) in enumerate(pbar):
            sol  = generate_one(q, temperature=TEMPERATURE, max_new_tokens=MAX_NEW_TOKENS)
            pred = extract_final_answer_strict(sol)
            okv  = ok(gold, pred)
            correct += int(okv)
            recs.append({"i": k, "ds_idx": ds_idx, "problem": q, "gold": gold, "pred": pred, "ok": okv, "text": sol})
            # Live running accuracy in the progress bar.
            pbar.set_postfix({"acc": f"{correct/(k+1):.3f}"})
    # Guard against an empty evaluation set (would otherwise divide by zero).
    acc = correct / len(rows) if rows else 0.0
    print(f"Greedy done in {time.time()-t0:.1f}s | acc={acc:.3f}")
    return {"accuracy": acc, "records": recs}

In [None]:
# Run the single-sample baseline and persist its per-problem records.
# The "same100" in the file name is a fixed label; the row count follows SUBSET.
gsm_greedy = run_greedy(SUBSET_ROWS)
save_jsonl("gsm8k_greedy_same100.jsonl", gsm_greedy["records"])

Greedy:   0%|          | 0/50 [00:00<?, ?problem/s]

Greedy done in 870.6s | acc=0.420
Saved: gsm8k_greedy_same100.jsonl | n= 50


In [None]:
def run_self_consistency(rows: List[Tuple[int,str,str]], m_samples=M_SC) -> Dict:
    """
    Self-consistency baseline: sample several solutions and majority-vote the answer.

    Answers are canonicalised before voting so numerically equal answers such as
    "18" and "18.00" count as the same vote, and samples from which no answer
    could be extracted do not vote. Ties are broken by the larger answer string.

    Args:
        rows: (dataset_index, question, gold_answer) tuples.
        m_samples: number of solutions sampled per problem.

    Returns:
        {"accuracy": float, "records": list of per-problem dicts including the
        vote histogram}. Accuracy is 0.0 when `rows` is empty.
    """
    def _vote_key(pred: str) -> str:
        # Canonical text for a prediction: integers without a decimal part,
        # other numbers via repr(float); non-numeric text stays as normalised.
        normed = norm_num(pred)
        try:
            value = float(normed)
        except ValueError:
            return normed
        return str(int(value)) if value.is_integer() else repr(value)

    recs, correct = [], 0
    t0 = time.time()
    with tqdm(rows, desc="Self-Consistency", unit="problem", total=len(rows)) as pbar:
        for k, (ds_idx, q, gold) in enumerate(pbar):
            sols = [generate_one(q, temperature=TEMPERATURE, max_new_tokens=MAX_NEW_TOKENS) for _ in range(m_samples)]
            preds = [_vote_key(extract_final_answer_strict(s)) for s in sols]
            # Empty predictions are failed extractions, not answers.
            vote = Counter(p for p in preds if p)
            pred = max(vote.items(), key=lambda kv: (kv[1], kv[0]))[0] if vote else ""
            okv  = ok(gold, pred)
            correct += int(okv)
            recs.append({"i": k, "ds_idx": ds_idx, "problem": q, "gold": gold, "pred": pred, "ok": okv, "vote_hist": vote.most_common()})
            pbar.set_postfix({"acc": f"{correct/(k+1):.3f}"})
    # Guard against an empty evaluation set (would otherwise divide by zero).
    acc = correct / len(rows) if rows else 0.0
    print(f"Self-Consistency done in {time.time()-t0:.1f}s | acc={acc:.3f}")
    return {"accuracy": acc, "records": recs}

In [None]:
# Run the self-consistency baseline on the same rows and persist its records.
gsm_sc     = run_self_consistency(SUBSET_ROWS)
save_jsonl("gsm8k_self_consistency_same100.jsonl", gsm_sc["records"])

Self-Consistency:   0%|          | 0/50 [00:00<?, ?problem/s]

Self-Consistency done in 7014.2s | acc=0.480
Saved: gsm8k_self_consistency_same100.jsonl | n= 50


In [None]:
def run_prm_best_of_n_with_debug(rows: List[Tuple[int,str,str]], n_samples=N_SAMPLES, debug_jsonl_path="prm_debug_steps.jsonl") -> Dict:
    """
    PRM best-of-N evaluation: per problem, generate `n_samples` candidates and
    answer with the one the PRM scores highest.

    The current global TEMPERATURE and MAX_NEW_TOKENS are passed explicitly.
    best_of_n_with_debug's own defaults were fixed when that function was
    defined, so relying on them would run this method with different sampling
    settings than the baselines, which read the globals at call time.

    Args:
        rows: (dataset_index, question, gold_answer) tuples.
        n_samples: candidates generated per problem.
        debug_jsonl_path: per-candidate debug log; truncated at the start of the run.

    Returns:
        {"accuracy": float, "records": list of per-problem dicts}. Accuracy is
        0.0 when `rows` is empty.
    """
    # Start with an empty debug log; best_of_n_with_debug appends to it.
    with open(debug_jsonl_path, "w"):
        pass

    recs, correct = [], 0
    t0 = time.time()
    with tqdm(rows, desc="PRM BoN", unit="problem", total=len(rows)) as pbar:
        for k, (ds_idx, q, gold) in enumerate(pbar):
            out  = best_of_n_with_debug(
                q,
                n=n_samples,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=TEMPERATURE,
                debug_jsonl_path=debug_jsonl_path,
            )
            best = out["best"]
            if best is None:
                # No candidate was generated (e.g. n_samples == 0): count as a miss.
                pred, best_score, n_steps = "", -1.0, 0
            else:
                pred = extract_final_answer_strict(best["text"])
                best_score, n_steps = float(best["agg_score"]), best["n_steps"]
            okv  = ok(gold, pred)
            correct += int(okv)
            recs.append({
                "i": k, "ds_idx": ds_idx, "problem": q, "gold": gold, "pred": pred, "ok": okv,
                "best_score": best_score, "n_steps": n_steps, "debug_file": debug_jsonl_path
            })
            pbar.set_postfix({"acc": f"{correct/(k+1):.3f}"})
    # Guard against an empty evaluation set (would otherwise divide by zero).
    acc = correct / len(rows) if rows else 0.0
    print(f"PRM BoN done in {time.time()-t0:.1f}s | acc={acc:.3f} | debug={debug_jsonl_path}")
    return {"accuracy": acc, "records": recs}

In [None]:
# Run PRM best-of-N on the same rows and persist its per-problem records.
gsm_prm    = run_prm_best_of_n_with_debug(SUBSET_ROWS)
save_jsonl("gsm8k_prm_best_of_n_same100.jsonl", gsm_prm["records"])

PRM BoN:   0%|          | 0/50 [00:00<?, ?problem/s]

PRM BoN done in 7350.8s | acc=0.560 | debug=prm_debug_steps.jsonl
Saved: gsm8k_prm_best_of_n_same100.jsonl | n= 50


In [None]:
# Side-by-side accuracy of the three methods from the in-memory results.
print("\nAccuracies (same100):")
print("Greedy          :", round(gsm_greedy["accuracy"],3))
print("Self-Consistency:", round(gsm_sc["accuracy"],3))
print("PRM best-of-N   :", round(gsm_prm["accuracy"],3))


Accuracies (same100):
Greedy          : 0.42
Self-Consistency: 0.48
PRM best-of-N   : 0.56


In [None]:
def map_by_i(path):
    """
    Load a JSONL results file and index its records by their "i" field.

    Blank lines are skipped. If several records share the same "i", the last
    one wins.

    Args:
        path: path to a JSONL file whose records contain an "i" key.

    Returns:
        Dict mapping each record's "i" value to the record.
    """
    by_index = {}
    with open(path, "r") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            by_index[record["i"]] = record
    return by_index

# Reload the saved per-problem records of each method, keyed by row position "i".
g = map_by_i("gsm8k_greedy_same100.jsonl")
sc = map_by_i("gsm8k_self_consistency_same100.jsonl")
prm = map_by_i("gsm8k_prm_best_of_n_same100.jsonl")

# Only compare on problems present in all three result files.
common = sorted(set(g)&set(sc)&set(prm))
print("Common indices across files:", len(common))

def acc_on(m, idxs):
    """Fraction of indices in `idxs` whose record in `m` has a truthy "ok"; 0.0 for no indices."""
    if not idxs:
        return 0.0
    hits = 0
    for i in idxs:
        hits += int(m[i]["ok"])
    return hits / len(idxs)

# Accuracy of each method restricted to the shared problem indices.
print("Greedy acc (common):", round(acc_on(g, common),3))
print("Self-Cons acc (common):", round(acc_on(sc, common),3))
print("PRM BoN acc (common):", round(acc_on(prm, common),3))

def compare(a, b, tag):
    """
    Print how method `b` differs from method `a` on the shared problems.

    Iterates over the module-level `common` indices and counts problems that
    `b` gets right but `a` gets wrong (improved), the reverse (regressed), and
    those where both agree (same).

    Args:
        a: baseline results, mapping index -> record with an "ok" field.
        b: results being compared against the baseline, same layout.
        tag: label printed at the start of the summary line.
    """
    imp = reg = same = 0
    for i in common:
        baseline_ok = bool(a[i]["ok"])
        other_ok = bool(b[i]["ok"])
        if baseline_ok == other_ok:
            same += 1
        elif other_ok:
            imp += 1
        else:
            reg += 1
    print(f"{tag}: improved={imp}, regressed={reg}, same={same}, total={len(common)}")

# Per-problem wins and losses of PRM best-of-N against each baseline.
compare(g, prm, "PRM vs Greedy")
compare(sc, prm, "PRM vs Self-Consistency")

Common indices across files: 50
Greedy acc (common): 0.42
Self-Cons acc (common): 0.48
PRM BoN acc (common): 0.56
PRM vs Greedy: improved=9, regressed=2, same=39, total=50
PRM vs Self-Consistency: improved=6, regressed=2, same=42, total=50
