## Process Reward Model (MVP)

Train a lightweight process reward model by fine-tuning `Qwen3-0.6B-Base` (LoRA + 4-bit) on PRM800K-style chain-of-thought traces. Each reasoning step has a {-1, 0, 1} label; we tokenize step segments, mask everything except the step terminator, and apply per-step cross-entropy loss. At the end we score an unseen PRM800K test trace to show how the model rates each intermediate step it never saw during training.


In [1]:
%pip install -q transformers datasets accelerate bitsandbytes peft


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from huggingface_hub import notebook_login
notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import itertools
import random
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model



### Configure model and LoRA


In [4]:
MODEL_ID = "Qwen/Qwen3-0.6B-Base"  # lighter backbone to avoid OOM
PRM_DATASET = "tasksource/PRM800K"  # community mirror of OpenAI PRM800K
SAMPLES = 2000
BATCH_SIZE = 1  # PRM traces are long; keep per-batch memory low
GRAD_ACCUM_STEPS = 2
MAX_STEPS_PER_SAMPLE = 20  # chunk very long traces to avoid OOM
MAX_TOKENS_PER_SAMPLE = 5500  # cap on total tokens per record
EPOCHS = 1
LR = 3e-5
SEED = 13
STEP_SEPARATOR = "\n<step>\n"
PRM_CLASS_VALUES = [-1, 0, 1]
PRM_CLASS_TO_IDX = {value: idx for idx, value in enumerate(PRM_CLASS_VALUES)}

random.seed(SEED)
torch.manual_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

lora = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

In [5]:
def to_plain_text(value):
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        for key in ("text", "value", "content"):
            if key in value and isinstance(value[key], str):
                return value[key]
        return " ".join(str(v) for v in value.values())
    if isinstance(value, list):
        return " ".join(str(v) for v in value)
    return str(value)



### Build PRM800K-derived process dataset


In [6]:
def get_problem_text(example: Dict) -> str:
    question_block = example.get("question") or {}
    raw = (
        question_block.get("problem")
        or question_block.get("question")
        or question_block.get("prompt")
        or question_block.get("problem_statement")
        or question_block.get("content")
        or example.get("problem")
        or example.get("prompt")
        or ""
    )
    return to_plain_text(raw)


def get_steps_and_labels(example: Dict):
    label_block = example.get("label") or {}
    steps_struct = label_block.get("steps") or []

    steps: List[str] = []
    parsed_labels: List[int] = []

    for step in steps_struct:
        completions = step.get("completions") or []
        found = False
        for comp in completions:
            text = comp.get("text")
            rating = comp.get("rating")
            if text is None or rating is None:
                continue
            text = to_plain_text(text).strip()
            if not text:
                continue
            try:
                rating_int = int(rating)
            except (TypeError, ValueError):
                continue
            steps.append(text)
            parsed_labels.append(rating_int)
            found = True
        if found:
            continue

        # Fallback to other text/rating fields in the step
        text = step.get("human_completion") or step.get("text") or step.get("completion")
        rating = step.get("rating") or step.get("score")
        if text and rating is not None:
            text = to_plain_text(text).strip()
            if text:
                try:
                    rating_int = int(rating)
                except (TypeError, ValueError):
                    continue
                steps.append(text)
                parsed_labels.append(rating_int)

    return steps, parsed_labels


def build_prm_dataset(limit: int) -> Dataset:
    stream = load_dataset(PRM_DATASET, split="train", streaming=True)
    records = []
    for example in stream:
        if len(records) >= limit:
            break
        problem = get_problem_text(example).strip()
        steps, labels = get_steps_and_labels(example)
        if not problem or not steps or not labels:
            continue
        prompt = f"Problem: {problem}\nReasoning trace:\n"
        prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]

        # Chunk very long traces so each record stays GPU-friendly
        for start in range(0, len(steps), MAX_STEPS_PER_SAMPLE):
            if len(records) >= limit:
                break
            chunk_steps = steps[start : start + MAX_STEPS_PER_SAMPLE]
            chunk_labels = labels[start : start + MAX_STEPS_PER_SAMPLE]
            if not chunk_steps or not chunk_labels:
                continue

            input_ids = list(prompt_ids)
            attention_mask = [1] * len(input_ids)
            label_ids = [-100] * len(input_ids)

            for step_text, lbl in zip(chunk_steps, chunk_labels):
                step_payload = step_text.strip() + STEP_SEPARATOR
                encoded = tokenizer(step_payload, add_special_tokens=False)["input_ids"]
                input_ids.extend(encoded)
                attention_mask.extend([1] * len(encoded))
                step_labels = [-100] * len(encoded)
                cls_id = PRM_CLASS_TO_IDX.get(int(lbl), PRM_CLASS_TO_IDX[0]) if isinstance(lbl, int) else PRM_CLASS_TO_IDX[0]
                step_labels[-1] = cls_id
                label_ids.extend(step_labels)

            if len(input_ids) > MAX_TOKENS_PER_SAMPLE:
                continue  # skip pathological traces that would still OOM

            records.append({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": label_ids,
            })
    if not records:
        raise ValueError("No PRM examples were loaded; check dataset path/permissions.")
    return Dataset.from_list(records[:limit])


prm_data = build_prm_dataset(SAMPLES)
print("Total PRM examples:", len(prm_data))



README.md:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Total PRM examples: 2000


In [7]:
def show_prm_example(dataset: Dataset, idx: int):
    row = dataset[idx]
    text = tokenizer.decode(row["input_ids"], skip_special_tokens=True)

    # split off the prompt header
    if "\nReasoning trace:\n" in text:
        header, body = text.split("\nReasoning trace:\n", 1)
    else:
        header, body = text, ""

    # extract problem text
    if header.startswith("Problem:"):
        problem = header[len("Problem:"):].strip()
    else:
        problem = header.strip()

    # split the reasoning body into steps
    segments = [seg.strip() for seg in body.split(STEP_SEPARATOR) if seg.strip()]

    print(f"Example {idx}")
    print("Problem:", problem)

    step_labels = [label for label in row["labels"] if label != -100]

    for i, (segment, lbl) in enumerate(zip(segments, step_labels)):
        label_name = PRM_CLASS_VALUES[lbl] if lbl < len(PRM_CLASS_VALUES) else lbl
        print(f"\nStep {i} (label {label_name}):\n{segment}")

    print("\nTotal steps:", len(step_labels))


# show_prm_example(prm_data, 0)
show_prm_example(prm_data, 1)



Example 1
Problem: How many positive two-digit integers leave a remainder of 2 when divided by 8?

Step 0 (label 0):
Let's call our two-digit integers x.

Step 1 (label 0):
Let's first think about the remainders when we divide by 8.

Step 2 (label 1):
So we need to find the number of positive two-digit integers that are 2 more than a multiple of 8.

Step 3 (label 0):
So we're looking for numbers that are two more than a multiple of 8.

Step 4 (label 0):
So we have to find the number of integers that are two more than a multiple of 8.

Step 5 (label 0):
Let's write out the first few multiples of 8.

Step 6 (label 1):
So if a number leaves a remainder of 2 when divided by 8, it's of the form 8n+2.

Step 7 (label 1):
So we want to know the number of positive two-digit integers of the form 8n+2.

Step 8 (label 1):
I think we should just plug in numbers and see what happens.

Step 9 (label 1):
Ok let's start with n=1.

Step 10 (label 1):
8*1+2=10 which is a two-digit integer.

Step 11 (labe

### Batch collation


In [8]:
def collate(batch: List[Dict[str, List[int]]]):
    max_len = max(len(item["input_ids"]) for item in batch)
    inputs = torch.full((len(batch), max_len), tokenizer.pad_token_id, dtype=torch.long)
    attn = torch.zeros_like(inputs)
    labels = torch.full((len(batch), max_len), -100, dtype=torch.long)
    for idx, item in enumerate(batch):
        length = len(item["input_ids"])
        inputs[idx, :length] = torch.tensor(item["input_ids"], dtype=torch.long)
        attn[idx, :length] = torch.tensor(item["attention_mask"], dtype=torch.long)
        labels[idx, :length] = torch.tensor(item["labels"], dtype=torch.long)
    return {"input_ids": inputs, "attention_mask": attn, "labels": labels}

loader = DataLoader(prm_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)



### Define LoRA process reward model


In [9]:
class ProcessRewardModel(nn.Module):
    def __init__(self):
        super().__init__()
        base = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb,
            device_map={"": 0} if torch.cuda.is_available() else None,
            trust_remote_code=True,
        )
        base = prepare_model_for_kbit_training(base)
        base.config.use_cache = False
        self.model = get_peft_model(base, lora)
        self.head = nn.Linear(self.model.config.hidden_size, len(PRM_CLASS_VALUES))

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        )
        hidden = outputs.hidden_states[-1]
        logits = self.head(hidden)
        loss = None
        if labels is not None:
            mask = labels != -100
            if mask.any():
                loss = F.cross_entropy(logits[mask], labels[mask])
            else:
                loss = logits.sum() * 0
        return loss, logits

prm_model = ProcessRewardModel().to(device)
trainable_params = sum(p.numel() for p in prm_model.parameters() if p.requires_grad)
print(f"{trainable_params/1e6:.2f}M trainable params")



config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.19G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

10.10M trainable params


### Train for one epoch


In [10]:
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, prm_model.parameters()), lr=LR)

autocast_enabled = torch.cuda.is_available()

for epoch in range(EPOCHS):
    prm_model.train()
    total_loss = 0.0
    total_correct = 0
    total_steps = 0
    optimizer.zero_grad()
    for step_idx, batch in enumerate(loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=autocast_enabled):
            loss, logits = prm_model(**batch)
        (loss / GRAD_ACCUM_STEPS).backward()

        if (step_idx + 1) % GRAD_ACCUM_STEPS == 0 or (step_idx + 1) == len(loader):
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item()
        mask = batch["labels"] != -100
        preds = logits[mask].argmax(dim=-1)
        total_correct += (preds == batch["labels"][mask]).sum().item()
        total_steps += mask.sum().item()
        if step_idx % 100 == 0:
            print(f"epoch {epoch} step {step_idx} loss {loss.item():.4f}")
    accuracy = total_correct / max(1, total_steps)
    print(f"epoch {epoch} loss {total_loss/len(loader):.4f} step-acc {accuracy:.3f}")



  with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=autocast_enabled):
  return fn(*args, **kwargs)


epoch 0 step 0 loss 2.2344
epoch 0 step 100 loss 1.1609
epoch 0 step 200 loss 0.6680
epoch 0 step 300 loss 0.6998
epoch 0 step 400 loss 1.2358
epoch 0 step 500 loss 0.6484
epoch 0 step 600 loss 0.9303
epoch 0 step 700 loss 1.2421
epoch 0 step 800 loss 1.4570
epoch 0 step 900 loss 1.5809
epoch 0 step 1000 loss 2.1276
epoch 0 step 1100 loss 0.3245
epoch 0 step 1200 loss 1.0799
epoch 0 step 1300 loss 0.8381
epoch 0 step 1400 loss 1.0316
epoch 0 step 1500 loss 0.9979
epoch 0 step 1600 loss 1.4293
epoch 0 step 1700 loss 0.2236
epoch 0 step 1800 loss 0.7233
epoch 0 step 1900 loss 1.2313
epoch 0 loss 0.9746 step-acc 0.503


### Score an unseen PRM800K test trace


In [12]:
def encode_trace(problem: str, steps: List[str], labels: List[int]):
    prompt = f"Problem: {problem}\nReasoning trace:\n"
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    input_ids = list(prompt_ids)
    attention_mask = [1] * len(input_ids)
    label_ids = [-100] * len(input_ids)
    step_boundaries = []
    for step_text, lbl in zip(steps, labels):
        step_payload = step_text.strip() + STEP_SEPARATOR
        encoded = tokenizer(step_payload, add_special_tokens=False)["input_ids"]
        input_ids.extend(encoded)
        attention_mask.extend([1] * len(encoded))
        step_labels = [-100] * len(encoded)
        cls_id = PRM_CLASS_TO_IDX.get(int(lbl), PRM_CLASS_TO_IDX[0])
        step_labels[-1] = cls_id
        step_boundaries.append(len(input_ids) - 1)
        label_ids.extend(step_labels)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": label_ids,
        "boundaries": step_boundaries,
    }


def score_unseen_trace(model, max_search: int = 1000):
    test_stream = load_dataset(PRM_DATASET, split="test", streaming=True)
    sample = None
    target_idx = random.randint(0, max_search)
    for idx, item in enumerate(test_stream):
        if idx == target_idx:
            sample = item
            break
    if sample is None:
        raise ValueError("Could not fetch a test example; try increasing max_search.")

    problem = get_problem_text(sample).strip()
    steps, labels = get_steps_and_labels(sample)
    example = encode_trace(problem, steps, labels)

    batch = {
        "input_ids": torch.tensor(example["input_ids"], dtype=torch.long).unsqueeze(0).to(device),
        "attention_mask": torch.tensor(example["attention_mask"], dtype=torch.long).unsqueeze(0).to(device),
        "labels": torch.tensor(example["labels"], dtype=torch.long).unsqueeze(0).to(device),
    }

    model.eval()
    with torch.no_grad():
        _, logits = model(**batch)
        probs = torch.softmax(logits[0], dim=-1)

    print("Problem:", problem[:300], "...")
    for idx, (step_text, lbl) in enumerate(zip(steps, labels)):
        token_idx = example["boundaries"][idx]
        label_name = PRM_CLASS_VALUES[int(lbl)] if int(lbl) in PRM_CLASS_VALUES else lbl
        prob_vec = probs[token_idx].cpu().tolist()
        print(f"\nStep {idx} (label {label_name}):")
        print(step_text.strip())
        print("Predicted probs:", {cls: round(prob_vec[PRM_CLASS_TO_IDX[cls]], 3) for cls in PRM_CLASS_VALUES})

score_unseen_trace(prm_model)



Problem: Let $f(x)=\left\lfloor\left(-\frac58\right)^x\right\rfloor$ be a function that is defined for all values of $x$ in $[0,\infty)$ such that $f(x)$ is a real number. How many distinct values exist in the range of $f(x)$? ...

Step 0 (label -1):
This is a challenging problem that involves exponents and floor functions.
Predicted probs: {-1: 0.163, 0: 0.404, 1: 0.433}

Step 1 (label 0):
I will start by trying to understand the behavior of the function $f(x)$ for different values of $x$.
Predicted probs: {-1: 0.093, 0: 0.554, 1: 0.353}

Step 2 (label 0):
First, I notice that the base of the exponent is negative, so $f(x)$ will alternate between positive and negative values depending on the parity of $x$.
Predicted probs: {-1: 0.218, 0: 0.453, 1: 0.329}

Step 3 (label 0):
For example, if $x$ is even, then $f(x)$ will be positive, and if $x$ is odd, then $f(x)$ will be negative.
Predicted probs: {-1: 0.137, 0: 0.424, 1: 0.438}

Step 4 (label 0):
Next, I notice that the base of the exp