In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Load Models and Tokenizer

In [None]:
quant_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    quantization_config=quant_config,
    token="...",
)

prm = AutoModelForCausalLM.from_pretrained(
    "mtzig/prm800k_llama_debug_full",
    token="...",
    quantization_config=quant_config,
)

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
    token="...",
)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

# Setup Prompt

In [4]:
system_message = """
You are a helpful AI Assistant who solves problems. Solve each problem step by step, ensuring that:
1. Each reasoning step is separated by a blank line for clarity.
2. The final answer is formatted as: \\boxed{<answer>}.
Always follow this format when solving problems.
"""

problem = "Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$"
answer = "\\left( 3, \\frac{\\pi}{2} \\right)"

messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": problem},
]

In [None]:
print(
    tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 07 Dec 2024

You are a helpful AI Assistant who solves problems. Solve each problem step by step, ensuring that:
1. Each reasoning step is separated by a blank line for clarity.
2. The final answer is formatted as: \boxed{<answer>}.
Always follow this format when solving problems.<|eot_id|><|start_header_id|>user<|end_header_id|>

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$<|eot_id|><|start_header_id|>assistant<|end_header_id|>




# Generate from Model

In [None]:
inputs = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
)
outputs = model.generate(
    inputs.to("cuda"),
    max_new_tokens=2500,
    max_length=None,
    num_beams=5,
    num_return_sequences=5,
    use_cache=True,
)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [7]:
print(tokenizer.decode(outputs[0]))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 07 Dec 2024

You are a helpful AI Assistant who solves problems. Solve each problem step by step, ensuring that:
1. Each reasoning step is separated by a blank line for clarity.
2. The final answer is formatted as: \boxed{<answer>}.
Always follow this format when solving problems.<|eot_id|><|start_header_id|>user<|end_header_id|>

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Step 1: To convert the point $(0,3)$ from rectangular coordinates to polar coordinates, we need to find the radius $r$ and the angle $\theta$.


Step 2: The radius $r$ can be calculated using the formula $r = \sqrt{x^2 + y^2}$. In this case, $x = 0$ and $y = 3$, so $r = \sqrt{0^2 + 3^2} = \sqrt{9} = 3$.


Step 3: The angle $\theta$ can b

# Make Outputs Ready for PRM

In [8]:
step_strs = []
step_ids = []

for t, i in tokenizer.vocab.items():
    if "\n\n" in t or "ĊĊ" in t:
        step_strs.append(t)
        step_ids.append(i)

prm_step_id = tokenizer.encode(" ки", add_special_tokens=False)[0]
eot_id = tokenizer.encode("<|end_of_text|>", add_special_tokens=False)[0]

In [None]:
x = outputs.clone()
mask = torch.zeros(x.shape, dtype=torch.bool, device=x.device)
for i in step_ids:
    mask |= x == i

response_mask = torch.arange(x.shape[1], device=x.device) > inputs.shape[1]
mask[
    torch.arange(len(mask), dtype=x.dtype),
    (torch.argmax((x[:, response_mask] == eot_id).byte(), dim=1) + inputs.shape[1]),
] = True
mask &= response_mask
x[mask] = prm_step_id

In [10]:
print(tokenizer.decode(x[0]).replace("ки", "\nки\n"))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 07 Dec 2024

You are a helpful AI Assistant who solves problems. Solve each problem step by step, ensuring that:
1. Each reasoning step is separated by a blank line for clarity.
2. The final answer is formatted as: \boxed{<answer>}.
Always follow this format when solving problems.<|eot_id|><|start_header_id|>user<|end_header_id|>

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Step 1: To convert the point $(0,3)$ from rectangular coordinates to polar coordinates, we need to find the radius $r$ and the angle $\theta$ 
ки
Step 2: The radius $r$ can be calculated using the formula $r = \sqrt{x^2 + y^2}$. In this case, $x = 0$ and $y = 3$, so $r = \sqrt{0^2 + 3^2} = \sqrt{9} = 3$ 
ки
Step 3: The angle $\theta$ can

# Get Scores From PRM

In [11]:
candidate_ids = [648, 387]

with torch.no_grad():
    logits = prm(x).logits[:, :, candidate_ids]
    scores = logits.softmax(dim=-1)[:, :, 0]

In [12]:
for i in range(len(x)):
    s = scores[i, mask[i]]
    print(f"average score: {torch.mean(s).item():.3f}, last score: {s[-1].item():.3f}")

average score: 0.944, last score: 0.925
average score: 0.944, last score: 0.925
average score: 0.946, last score: 0.938
average score: 0.944, last score: 0.967
average score: 0.946, last score: 0.938
