In [1]:
import os
import json
import torch
import torch.nn as nn
import re
import random
import numpy as np
from collections import Counter
from tqdm.auto import tqdm
from typing import List, Dict, Tuple, Optional
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16

In [3]:
K_PATHS = 16
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.6

In [4]:
POOL_FILE = "fixed_pool_v2_fewshot_full_test.jsonl"
RESULTS_FILE = "final_scientific_results_v2_full_test.json"

In [5]:
def extract_final_answer_strict(text: str) -> str:
    if not text: return ""
    # Look for explicit header
    m = re.search(r"Final\s*Answer\s*:\s*(.+)", text, flags=re.I)
    ans = m.group(1).strip() if m else text.strip()

    # Cleanup
    ans = ans.replace(",", "")
    while ans.startswith("$"): ans = ans[1:].lstrip()
    ans = re.sub(r"\s*(mph|km/h|units?|percent|percentage|%)\s*$", "", ans, flags=re.I)

    # Extract last number
    nums = re.findall(r"-?\d+(?:\.\d+)?", ans)
    return nums[-1] if nums else ans

def normalize_answer(s: Optional[str]) -> str:
    if s is None: return ""
    s = str(s).strip().replace(",", "")
    while s.startswith("$"): s = s[1:].lstrip()

    # Handle fractions
    if re.fullmatch(r"\d+/\d+", s):
        try:
            n, d = s.split("/")
            return str(float(n)/float(d))
        except: pass

    nums = re.findall(r"-?\d+(?:\.\d+)?", s)
    return nums[-1] if nums else s

def check_correctness(gold: str, pred: str) -> bool:
    g = normalize_answer(gold)
    p = normalize_answer(pred)
    try:
        return abs(float(g) - float(p)) <= 1e-6
    except:
        return g == p

In [6]:
STEP_HEADER_PATTERN = re.compile(
    r"^(?P<header>(?:\s*Step\s*\d+|\*\*Step\s*\d+\*\*|####\s*Step\s*\d+)[:.)-]?\s*)",
    re.IGNORECASE | re.MULTILINE
)

def split_into_steps(text: str) -> List[str]:
    if not text: return []
    matches = list(STEP_HEADER_PATTERN.finditer(text))

    if not matches:
        # Fallback to sentence splitting if no explicit steps
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    steps = []
    first_start = matches[0].start()
    intro = text[:first_start].strip()

    for i, match in enumerate(matches):
        start = match.start()
        end = matches[i+1].start() if i + 1 < len(matches) else len(text)
        content = text[start:end].strip()

        if i == 0 and intro:
            header = match.group("header")
            body = content[len(header):]
            content = f"{header} {intro} {body}"

        steps.append(content)
    return steps


In [7]:
def format_step_for_prm(problem: str, prev_steps: List[str], current_step: str) -> str:
    _clean = lambda s: re.sub(r"^\s*Step\s*\d+[:.)-]\s*", "", s, flags=re.I).strip()

    prev_str = " ".join(_clean(s) for s in prev_steps).strip()
    curr_str = _clean(current_step)

    if prev_str:
        return f"Problem: {problem}\n\nPrevious steps: {prev_str}\nCurrent step: {curr_str}\n\nIs this step correct?"
    else:
        return f"Problem: {problem}\nCurrent step: {curr_str}\nIs this step correct?"

In [8]:
GENERATOR_MODEL_ID = "Qwen/Qwen3-0.6B"
class Generator:
    def __init__(self):
        print(f"Loading Generator: {GENERATOR_MODEL_ID}...")
        self.tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_ID, use_fast=True, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            GENERATOR_MODEL_ID,
            trust_remote_code=True,
            device_map=DEVICE,
            dtype=DTYPE
        )
        self.model.eval()

    def build_few_shot_prompt(self, problem: str) -> str:
        demo_q = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
        demo_a = (
            "Step 1: Natalia sold 48 clips in April.\n"
            "Step 2: In May, she sold half as many as in April, so she sold 48 / 2 = 24 clips.\n"
            "Step 3: To find the total, we add the clips sold in April and May: 48 + 24 = 72.\n"
            "Final Answer: 72"
        )

        return (
            "You are a precise math solver. Solve the problem using clear steps starting with 'Step 1:', 'Step 2:', etc. End with 'Final Answer:'.\n\n"
            f"Example Problem: {demo_q}\n"
            "Example Solution:\n"
            f"{demo_a}\n\n"
            f"Problem: {problem}\n"
            "Solution:\n"
        )

    def generate_batch(self, problem: str, n: int) -> List[str]:
        # UPGRADED: Use Few-Shot Prompt
        prompt = self.build_few_shot_prompt(problem)

        inputs = self.tokenizer([prompt] * n, return_tensors="pt").to(DEVICE)

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=TEMPERATURE,
                do_sample=True,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id
            )

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        results = []
        for text in decoded:
            if prompt in text:
                 # Best way: strict slicing
                 results.append(text[len(prompt):].strip())
            else:
                 # Fallback: Split by last "Solution:\n"
                 parts = text.rsplit("Solution:\n", 1)
                 results.append(parts[-1].strip() if len(parts) > 1 else text)
        return results

In [9]:
PRM_ADAPTER_ID = "devangb4/prm-qwen3-8b-bf16-6k"
BASE_PRM_MODEL = "Qwen/Qwen3-8B"
class PRMScanner:
    def __init__(self):
        print(f"Loading PRM Base: {BASE_PRM_MODEL}...")
        self.tokenizer = AutoTokenizer.from_pretrained(PRM_ADAPTER_ID, trust_remote_code=True)
        self.tokenizer.padding_side = "left"
        base = AutoModelForCausalLM.from_pretrained(
            BASE_PRM_MODEL,
            dtype=DTYPE,
            device_map=DEVICE,
            trust_remote_code=True
        )
        base.config.return_dict = True
        base.config.use_cache = False

        print(f"Loading LoRA Adapter: {PRM_ADAPTER_ID}...")
        self.model = PeftModel.from_pretrained(base, PRM_ADAPTER_ID)

        try:
            from huggingface_hub import snapshot_download
            local_dir = snapshot_download(PRM_ADAPTER_ID)
            head_path = os.path.join(local_dir, "prm_head.bin")

            hidden_size = getattr(self.model.config, "hidden_size", None) or getattr(self.model.config, "hidden_sizes", [None])[0]
            assert hidden_size is not None, "Could not infer hidden_size from config."

            self.head = nn.Linear(hidden_size, 1).to(DEVICE, dtype=DTYPE)
            self.head.load_state_dict(torch.load(head_path, map_location="cpu"))
            self.head.eval()
            print("PRM Head loaded successfully.")
        except Exception as e:
            print(f"CRITICAL WARNING: Could not load PRM head. {e}")
            raise e

    def _truncate_at_final_answer(self, text: str) -> str:
        patterns = [
            r"(Final\s*Answer\s*:.+?)(\n\n|\Z)",
            r"(####\s*.+?)(\n\n|\Z)",
            r"(The final answer is.+?)(\n\n|\Z)"
        ]
        for pat in patterns:
            match = re.search(pat, text, flags=re.IGNORECASE | re.DOTALL)
            if match:
                return text[:match.end()].strip()
        return text

    @torch.inference_mode()
    def score_path(self, problem: str, solution_text: str) -> float:
        # 1. Truncate garbage
        clean_text = self._truncate_at_final_answer(solution_text)

        # 2. Split steps
        steps = split_into_steps(clean_text)
        if not steps: return -100.0

        prompts = []
        prev_steps = []
        for step in steps:
            if not step.strip(): continue
            prompts.append(format_step_for_prm(problem, prev_steps, step))
            prev_steps.append(step)

        if not prompts: return -100.0

        # Batch score
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
        outputs = self.model(**inputs, output_hidden_states=True, use_cache=False)

        last_hidden = outputs.hidden_states[-1]
        h_last = last_hidden[:, -1, :]

        logits = self.head(h_last.to(self.head.weight.dtype))
        probs = torch.sigmoid(logits.squeeze(-1).to(torch.float32)).tolist()

        # 3. Aggregation (Product of probabilities)
        clean_probs = [max(p, 1e-6) for p in probs]
        score = np.exp(np.sum(np.log(clean_probs)))

        return float(score)

In [10]:
ds = load_dataset("gsm8k", "main", split="test")

problems = []
for i in range(len(ds)):
    row = ds[i]
    m = re.search(r"####\s*(.+)", row['answer'])
    gold = m.group(1).strip() if m else ""
    problems.append({
        "id": i,
        "question": row['question'],
        "gold_answer": gold
    })

N_PROBLEMS = len(problems)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [12]:
POOL_FILE_DRIVE = "/content/drive/MyDrive/prm_project/Fixed_Pool/fixed_pool_v2_fewshot_full_test.jsonl"
if not os.path.exists(POOL_FILE_DRIVE):
    print(f"Generating Fixed Pool ({N_PROBLEMS} probs * {K_PATHS} paths)...")
    generator = Generator()

    pool_data = []
    for p_data in tqdm(problems, desc="Generating"):
        candidates = generator.generate_batch(p_data['question'], K_PATHS)
        entry = {
            "id": p_data['id'],
            "question": p_data['question'],
            "gold_answer": p_data['gold_answer'],
            "candidates": candidates
        }
        pool_data.append(entry)

    del generator
    torch.cuda.empty_cache()

    with open(POOL_FILE, 'w') as f:
        for entry in pool_data:
            f.write(json.dumps(entry) + "\n")
    print(f"Pool saved to {POOL_FILE}")
else:
    print(f"Loading existing pool from {POOL_FILE_DRIVE}")
    with open(POOL_FILE_DRIVE, 'r') as f:
        pool_data = [json.loads(line) for line in f]

Loading existing pool from /content/drive/MyDrive/prm_project/Fixed_Pool/fixed_pool_v2_fewshot_full_test.jsonl


In [13]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
print(f"Phase 2: Running Self-Consistency Baseline on Full Test Set (N={len(pool_data)})")
sc_correct = 0

pbar = tqdm(pool_data, desc="SC Eval Full")
for i, entry in enumerate(pbar):
    answers = [extract_final_answer_strict(c) for c in entry['candidates']]
    valid_answers = [a for a in answers if a]

    if not valid_answers:
        prediction = ""
    else:
        counts = Counter(valid_answers)
        prediction = counts.most_common(1)[0][0]

    if check_correctness(entry['gold_answer'], prediction):
        sc_correct += 1

    pbar.set_postfix(acc=f"{sc_correct / (i + 1):.2%}")

sc_acc = sc_correct / len(pool_data)
print(f"\n>>> SELF-CONSISTENCY ACCURACY: {sc_acc:.2%}")

Phase 2: Running Self-Consistency Baseline on Full Test Set (N=1319)


SC Eval Full:   0%|          | 0/1319 [00:00<?, ?it/s]


>>> SELF-CONSISTENCY ACCURACY: 33.28%


In [15]:
prm = PRMScanner()

Loading PRM Base: Qwen/Qwen3-8B...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/728 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

model-00005-of-00005.safetensors:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

model-00003-of-00005.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00004-of-00005.safetensors:   0%|          | 0.00/3.19G [00:00<?, ?B/s]

model-00001-of-00005.safetensors:   0%|          | 0.00/4.00G [00:00<?, ?B/s]

model-00002-of-00005.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Loading LoRA Adapter: devangb4/prm-qwen3-8b-bf16-6k...


adapter_config.json:   0%|          | 0.00/876 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/61.4M [00:00<?, ?B/s]

Fetching 18 files:   0%|          | 0/18 [00:00<?, ?it/s]

.gitattributes: 0.00B [00:00, ?B/s]

export_metadata.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

prm_head_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

scheduler.pt:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/123M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

prm_head.bin:   0%|          | 0.00/10.1k [00:00<?, ?B/s]

trainer_state.json: 0.00B [00:00, ?B/s]

PRM Head loaded successfully.


In [16]:
prm_correct = 0
pbar = tqdm(pool_data, desc="PRM Eval")
for i, entry in enumerate(pbar):
    candidates = entry['candidates']
    best_score = -1.0
    best_candidate = candidates[0]

    for cand in candidates:
        try:
            score = prm.score_path(entry['question'], cand)
        except Exception as e:
            print(f"ERROR: {e}")
            score = 0.0

        if score > best_score:
            best_score = score
            best_candidate = cand

    prediction = extract_final_answer_strict(best_candidate)

    if check_correctness(entry['gold_answer'], prediction):
        prm_correct += 1

    pbar.set_postfix(acc=f"{prm_correct / (i + 1):.2%}")

prm_acc = prm_correct / len(pool_data)
print(f"PRM Accuracy: {prm_acc}")

PRM Eval:   0%|          | 0/1319 [00:00<?, ?it/s]

PRM Accuracy: 0.4040940106141016


In [17]:

print(f"N_PROBLEMS: {len(pool_data)}")
print(f"K_PATHS:    {K_PATHS}")
print("-" * 20)
print(f"Self-Consistency: {sc_acc:.4f} ({sc_correct}/{len(pool_data)})")
print(f"PRM Best-of-N:    {prm_acc:.4f} ({prm_correct}/{len(pool_data)})")
print("-" * 20)

delta = prm_acc - sc_acc
if delta > 0.02:
    print(f"RESULT: +{delta:.2%} improvement.")
elif delta < -0.02:
    print(f"RESULT: {delta:.2%} regression.")
else:
    print(f"RESULT: INCONCLUSIVE (Noise). Delta: {delta:.2%}")

N_PROBLEMS: 1319
K_PATHS:    16
--------------------
Self-Consistency: 0.3328 (439/1319)
PRM Best-of-N:    0.4041 (533/1319)
--------------------
RESULT: +7.13% improvement.
