## Outcome Reward Model (MVP)

Train a minimal outcome reward model (ORM) by fine-tuning `Qwen3-1.7B-Base` with LoRA on correct/incorrect math solutions derived from GSM8K. For each question we parse the gold numeric answer and build two completions:

- the gold solution (label 1);
- the gold solution followed by the sentence "Therefore, the answer is N", where N is the gold answer plus a random offset of 1–9 (label 0).

A linear head on the last hidden state is trained with per-token BCE loss on the completion tokens. At the end we score a fresh GSM8K test question to show how the trained model ranks a correct vs. corrupted completion it never saw during training.


In [None]:
# Install the Hugging Face stack plus bitsandbytes (4-bit loading) and peft (LoRA adapters).
%pip install -q transformers datasets accelerate bitsandbytes peft


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

In [None]:
# Authenticate with the Hugging Face Hub (interactive login widget in the notebook).
from huggingface_hub import notebook_login
notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import random
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model



### Configure model and LoRA


In [None]:
# --- experiment knobs -------------------------------------------------------
MODEL_ID = "Qwen/Qwen3-1.7B-Base"  # base checkpoint that gets LoRA adapters
DATASET = "gsm8k"                  # HF dataset id; its "main" config is used
SAMPLES = 200                      # train questions; each parsed one yields 2 rows
BATCH_SIZE = 4
EPOCHS = 1
LR = 5e-5
SEED = 7

# Seed both RNGs first: `random` draws the wrong-answer offsets, torch drives
# head initialisation and DataLoader shuffling.
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# Batch padding needs a pad id; reuse EOS when the checkpoint defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the base weights in 4 bits, computing in bfloat16.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# LoRA adapters on the modules named below (attention q/k/v/o and MLP gate/up/down).
lora = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)



### Build GSM8K-derived outcome dataset


In [None]:
def parse_answer(text: str):
    """Extract the final numeric answer from a GSM8K-style solution.

    The search region is the text after the last ``####`` marker when one is
    present, otherwise the last non-empty line. Within that region the LAST
    number wins, so a trailing "Therefore, the answer is N." overrides an
    earlier "#### M".

    Commas are removed first, so "1,000" parses as 1000. A minus sign counts
    only when it is not directly preceded by a digit ("12-3" yields 3).

    Returns:
        An ``int`` for integral values (including "10.00"), a ``float`` for
        non-integral ones (e.g. "2.5"), or ``None`` when no number is found.
    """
    import re  # local import: the notebook's import cell does not import `re`

    if "####" in text:
        tail = text.split("####")[-1]
    else:
        lines = [seg.strip() for seg in text.strip().split("\n") if seg.strip()]
        tail = lines[-1] if lines else text
    # Match numbers as whole units. Concatenating every digit of a token would
    # turn "2.5" into 25 and "<<0.2*50=10>>10." into 2501010.
    numbers = re.findall(r"(?<!\d)-?\d+(?:\.\d+)?", tail.replace(",", ""))
    if not numbers:
        return None
    last = numbers[-1]
    if "." not in last:
        return int(last)
    value = float(last)
    return int(value) if value.is_integer() else value


def pack(prompt: str, completion: str, label: int, sep: str = " "):
    """Tokenize a (prompt, completion) pair into one labeled, unpadded row.

    Args:
        prompt: Conditioning text; its tokens get label -100 (ignored by the loss).
        completion: Text being judged; each completion token, plus a trailing
            EOS token, gets ``label``.
        label: 1 for a correct completion, 0 for an incorrect one.
        sep: Inserted in front of ``completion`` when neither side already has
            whitespace at the boundary, so "Answer:" + "Natalia ..." is seen by
            the model as "Answer: Natalia ..." rather than "Answer:Natalia ...".
            Pass ``""`` to reproduce plain concatenation.

    Returns:
        Dict with equal-length lists ``input_ids``, ``attention_mask`` (all 1)
        and ``labels`` (-100 over the prompt, ``label`` over the completion).
    """
    needs_sep = (
        sep
        and prompt
        and completion
        and not prompt[-1].isspace()
        and not completion[0].isspace()
    )
    if needs_sep:
        # Attach the separator to the completion so the leading space belongs to
        # the first completion token (presumably the natural BPE split — confirm
        # against the tokenizer if it matters).
        completion = sep + completion
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    # Append the EOS id directly rather than tokenizing the EOS string, which
    # relies on the tokenizer re-recognising special-token text.
    completion_ids = tokenizer(completion, add_special_tokens=False)["input_ids"]
    completion_ids = completion_ids + [tokenizer.eos_token_id]
    input_ids = prompt_ids + completion_ids
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": [-100] * len(prompt_ids) + [label] * len(completion_ids),
    }


def build_dataset(limit: int) -> Dataset:
    """Turn the first ``limit`` GSM8K train problems into labeled ORM rows.

    Every problem whose gold solution yields a number (via ``parse_answer``)
    contributes two packed rows, in this order:

    * the gold solution, label 1;
    * the gold solution followed by a newline and
      "Therefore, the answer is {gold + k}." where ``k`` is drawn from 1..9
      with the global ``random`` module, label 0.

    Problems whose answer cannot be parsed are skipped.

    NOTE(review): the negative row is the positive row plus one extra sentence.
    The two rows therefore share nearly all completion tokens with opposite
    labels, and the mere presence of text after the "#### N" line is enough to
    tell them apart.
    """
    examples = load_dataset(DATASET, "main", split=f"train[:{limit}]")
    packed = []
    for record in examples:
        gold_solution = record["answer"].strip()
        gold_value = parse_answer(gold_solution)
        if gold_value is None:
            continue
        prompt = "Question: " + record["question"].strip() + "\nAnswer:"
        corrupted = f"{gold_solution}\nTherefore, the answer is {gold_value + random.randint(1, 9)}."
        packed.extend([pack(prompt, gold_solution, 1), pack(prompt, corrupted, 0)])
    return Dataset.from_list(packed)


# Build the labeled rows (a correct/incorrect pair per parsed question) and sanity-check them.
data = build_dataset(SAMPLES)
print(len(data))
# Decode the first 80 tokens of row 0, i.e. the correct completion of the first question.
print(tokenizer.decode(data[0]["input_ids"][:80]))


400
Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer:Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<4


In [None]:
# show examples
def show_dataset_example(dataset, idx):
    """Print one packed row as question, completion, label and parsed answer.

    Args:
        dataset: Rows produced by ``pack`` (unpadded), e.g. ``build_dataset``'s output.
        idx: Index of the row to display.
    """
    row = dataset[idx]
    ids = row["input_ids"]
    labels = row["labels"]
    # pack() labels exactly the prompt tokens with -100, so the first supervised
    # position is the prompt/completion boundary. Splitting there (instead of on
    # the decoded text "Answer:") is robust to questions containing that string.
    start = next((i for i, lab in enumerate(labels) if lab != -100), len(labels))

    prompt_text = tokenizer.decode(ids[:start], skip_special_tokens=True).strip()
    if prompt_text.endswith("Answer:"):
        prompt_text = prompt_text[: -len("Answer:")]
    if prompt_text.startswith("Question:"):
        prompt_text = prompt_text[len("Question:"):]
    question = prompt_text.strip()
    completion = tokenizer.decode(ids[start:], skip_special_tokens=True).strip()

    label = int(labels[-1])  # every completion token carries the row's label
    final_value = parse_answer(completion)
    print(f"Example {idx} | label {label} ({'correct' if label == 1 else 'incorrect'})")
    print("Question:", question)
    print("Completion:", completion)
    print("Extracted answer:", final_value)
    print("-" * 60)

# Rows come in (correct, incorrect) pairs, so rows 2 and 3 are the pair built
# from the second parsed question.
show_dataset_example(data, 2)
show_dataset_example(data, 3)



Example 2 | label 1 (correct)
Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
Completion: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
#### 10
Extracted answer: 10
------------------------------------------------------------
Example 3 | label 0 (incorrect)
Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
Completion: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
#### 10
Therefore, the answer is 13.
Extracted answer: 13
------------------------------------------------------------


### Batch collation


In [None]:
def collate(batch, *, pad_token_id=None):
    """Right-pad a list of packed rows into a batch of ``torch.long`` tensors.

    Args:
        batch: Rows as produced by ``pack`` (``input_ids``, ``attention_mask`` and
            ``labels`` lists of equal length within each row).
        pad_token_id: Id written into padded ``input_ids`` positions. Defaults to
            the global ``tokenizer.pad_token_id`` (looked up at call time), so
            ``DataLoader(..., collate_fn=collate)`` keeps working unchanged.

    Returns:
        Dict of ``(len(batch), max_len)`` tensors: padded ``input_ids``,
        ``attention_mask`` (0 on padding) and ``labels`` (-100 on padding, so
        padded positions are excluded from the loss). An empty batch yields
        ``(0, 0)`` tensors.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    max_len = max((len(item["input_ids"]) for item in batch), default=0)
    shape = (len(batch), max_len)
    inputs = torch.full(shape, pad_token_id, dtype=torch.long)
    attn = torch.zeros(shape, dtype=torch.long)
    labels = torch.full(shape, -100, dtype=torch.long)
    for i, item in enumerate(batch):
        length = len(item["input_ids"])
        # Real tokens occupy the left part of each row (right padding).
        inputs[i, :length] = torch.tensor(item["input_ids"], dtype=torch.long)
        attn[i, :length] = torch.tensor(item["attention_mask"], dtype=torch.long)
        labels[i, :length] = torch.tensor(item["labels"], dtype=torch.long)
    return {"input_ids": inputs, "attention_mask": attn, "labels": labels}

# Reshuffle rows on every pass; collate() right-pads each batch to its longest row.
loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)



### Define LoRA outcome reward model


In [None]:
class OutcomeRewardModel(nn.Module):
    """LoRA-adapted causal LM with a per-token scalar correctness head.

    The base model ``MODEL_ID`` is loaded with the global ``bnb`` quantization
    config, passed through ``prepare_model_for_kbit_training`` and wrapped with
    the global ``lora`` adapter config. A single ``nn.Linear(hidden_size, 1)``
    maps the last hidden state at every position to one logit.
    """

    def __init__(self):
        super().__init__()
        on_gpu = torch.cuda.is_available()
        backbone = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb,
            device_map={"": 0} if on_gpu else None,
            trust_remote_code=True,
        )
        backbone = prepare_model_for_kbit_training(backbone)
        # No KV cache: sequences are always processed in full.
        backbone.config.use_cache = False
        self.model = get_peft_model(backbone, lora)
        hidden_size = self.model.config.hidden_size
        self.head = nn.Linear(hidden_size, 1, bias=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(loss, logits)``, with one logit per input position.

        ``loss`` is the mean BCE-with-logits over positions whose label is not
        -100. It is ``None`` when ``labels`` is omitted. When every position is
        masked it is a zero that still depends on ``logits``, so ``backward()``
        does not fail.
        """
        final_hidden = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
            return_dict=True,
        ).hidden_states[-1]
        logits = self.head(final_hidden).squeeze(-1)

        if labels is None:
            return None, logits

        supervised = labels != -100
        if not supervised.any():
            return logits.sum() * 0, logits
        targets = labels[supervised].float()
        return F.binary_cross_entropy_with_logits(logits[supervised], targets), logits

# Build the reward model and move it (including the newly created head) to `device`.
orm_model = OutcomeRewardModel().to(device)
# Number of parameters that will receive gradients, in millions.
print(sum(p.numel() for p in orm_model.parameters() if p.requires_grad)/1e6, "M trainable params")


17.434625 M trainable params


### Train for one epoch


In [None]:
# Optimize only parameters that require gradients (presumably the LoRA adapters
# and the reward head; peft should leave the quantized base weights frozen).
optimizer = torch.optim.AdamW(
    [param for param in orm_model.parameters() if param.requires_grad],
    lr=LR,
)

for epoch in range(EPOCHS):
    orm_model.train()
    epoch_loss = 0.0
    hits = 0
    scored = 0
    for step, host_batch in enumerate(loader):
        batch = {name: tensor.to(device) for name, tensor in host_batch.items()}
        loss, logits = orm_model(**batch)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        epoch_loss += loss_value

        # Token-level accuracy over supervised (completion) positions only.
        # NOTE: a question's correct and corrupted completions share almost all
        # tokens but carry opposite labels. A causal model sees identical context
        # there, so this metric is effectively capped near 0.5.
        supervised = batch["labels"] != -100
        predicted = (torch.sigmoid(logits[supervised]) > 0.5).long()
        hits += (predicted == batch["labels"][supervised]).sum().item()
        scored += supervised.sum().item()

        if step % 10 == 0:
            print(f"epoch {epoch} step {step} loss {loss_value:.4f}")
    print(f"epoch {epoch} loss {epoch_loss/len(loader):.4f} acc {hits/scored:.3f}")


epoch 0 step 0 loss 0.9715
epoch 0 step 10 loss 1.0870
epoch 0 step 20 loss 0.6629
epoch 0 step 30 loss 0.8053
epoch 0 step 40 loss 0.7621
epoch 0 step 50 loss 0.6903
epoch 0 step 60 loss 0.8525
epoch 0 step 70 loss 0.6804
epoch 0 step 80 loss 0.5791
epoch 0 step 90 loss 0.7675
epoch 0 loss 0.7564 acc 0.486


### Score an unseen GSM8K test question


In [None]:
def score_unseen_pair(model, seed: int = SEED):
    """Score a correct and a corrupted completion for one GSM8K test question.

    The question comes from the ``test`` split (training reads only ``train``).
    The corrupted completion is built the same way as the training negatives:
    the gold solution followed by "Therefore, the answer is {gold + k}." with
    ``k`` in 1..9.

    Args:
        model: The reward model; called as ``model(**batch)`` and expected to
            return ``(loss, logits)`` with one logit per token.
        seed: Seeds a private RNG that picks the test question and the offset
            ``k``, so the same seed always scores the same pair without
            touching the global ``random`` state.

    Prints, per completion, the mean probability over its tokens and the
    probability at its final (EOS) token. The two completions share almost all
    of their tokens, which were trained with both labels, so the mean is pulled
    toward 0.5 and also depends on completion length. The final-token
    probability is the more informative outcome score.

    Raises:
        ValueError: If no number can be parsed from the gold answer.
    """
    rng = random.Random(seed)
    test_index = rng.randint(0, 1000)
    sample = load_dataset(DATASET, "main", split=f"test[{test_index}:{test_index + 1}]")[0]
    question = sample["question"].strip()
    prompt = f"Question: {question}\nAnswer:"
    answer = sample["answer"].strip()
    value = parse_answer(answer)
    if value is None:
        raise ValueError("Unable to parse numeric answer from GSM8K sample")
    wrong_value = value + rng.randint(1, 9)
    wrong_answer = answer + f"\nTherefore, the answer is {wrong_value}."

    rows = [pack(prompt, answer, 1), pack(prompt, wrong_answer, 0)]
    batch = {k: v.to(device) for k, v in collate(rows).items()}

    # Score in eval mode (disables LoRA dropout), then restore the caller's mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _, logits = model(**batch)
    finally:
        model.train(was_training)
    probs = torch.sigmoid(logits)

    print("Question:", question)
    for idx, row in enumerate(rows):
        text = tokenizer.decode(row["input_ids"], skip_special_tokens=True)
        completion = text.split("Answer:", 1)[-1].strip()
        mask = batch["labels"][idx] != -100
        completion_prob = probs[idx][mask].mean().item()
        # collate() right-pads, so this row's last real token (EOS) is at len - 1.
        final_prob = probs[idx, len(row["input_ids"]) - 1].item()
        label = "correct" if idx == 0 else "incorrect"
        print(f"\nCompletion ({label}):\n{completion}")
        print(f"Average token prob: {completion_prob:.3f}")
        print(f"Final token prob: {final_prob:.3f}")

# Score one held-out GSM8K test question with the trained reward model.
score_unseen_pair(orm_model)


Question: An electronics seller bought 5 phones for $700 each and gives the seller $4000 in dollar bills. How much will the seller give back in change?

Completion (correct):
The seller bought the 5 phones for $700 * 5 = $<<700*5=3500>>3500.
So the seller gives back $4000-$3500 = $<<4000-3500=500>>500.
#### 500
Average token prob: 0.435

Completion (incorrect):
The seller bought the 5 phones for $700 * 5 = $<<700*5=3500>>3500.
So the seller gives back $4000-$3500 = $<<4000-3500=500>>500.
#### 500
Therefore, the answer is 505.
Average token prob: 0.379
