In [70]:
import torch
from torch import Tensor

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [71]:
SYSTEM_MESSAGE = """
You are a helpful AI Assistant who solves problems. Solve each problem step by step, ensuring that:
1. Each reasoning step is separated by a blank line for clarity.
2. The final answer is formatted as: \\boxed{<answer>}.
Always follow this format when solving problems.
"""

problem = "Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$"
answer = "\\left( 3, \\frac{\\pi}{2} \\right)"

In [None]:
class BeamSearch:
    def __init__(
        self,
        width=5,
        new_samples_per_beam=5,
        max_new_tokens=2500,
        max_generation_rounds=10,
        device=None,
    ):
        self.width = width
        self.new_samples_per_beam = new_samples_per_beam
        self.max_new_tokens = max_new_tokens
        self.max_generation_rounds = max_generation_rounds

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        quant_config = BitsAndBytesConfig(load_in_8bit=True)

        self.generator = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3B-Instruct",
            quantization_config=quant_config,
            token="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",
        )

        self.prm = AutoModelForCausalLM.from_pretrained(
            "mtzig/prm800k_llama_debug_full",
            token="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",
            quantization_config=quant_config,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.2-3B-Instruct",
            token="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",
        )

        self.generator = self_generator
        self.prm = self_prm
        self.tokenizer = self_tokenizer

        self.step_strs = []
        self.step_ids = []

        for t, i in self.tokenizer.vocab.items():
            if "\n\n" in t or "ĊĊ" in t:
                self.step_strs.append(t)
                self.step_ids.append(i)

        self.prm_step_id = self.tokenizer.encode(" ки", add_special_tokens=False)[0]
        self.eos_id = self.tokenizer.encode(
            "<|end_of_text|>", add_special_tokens=False
        )[0]

        self.candidate_ids = [648, 387]
        self.pad_id = self.tokenizer.eos_token_id
        self.generator.generation_config.pad_token_id = self.pad_id
        self.generator.generation_config.eos_token_id = self.step_ids + [
            self.eos_id,
            self.generator.generation_config.eos_token_id,
        ]

    def encode(self, question: str) -> Tensor:
        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": problem},
        ]

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
        )

        return input_ids.repeat(self.width, 1).to(
            self.device
        )  # start with `self.width` beams initialized to the prompt

    def convert_to_left_padding(self, x: Tensor) -> Tensor:
        converted_rows = []
        for row in x:
            pad_idxs = row == self.pad_id
            converted_rows.append(torch.cat((row[pad_idxs], row[~pad_idxs])))

        return torch.stack(converted_rows)

    def prepare_prm_inputs(self, output_ids: Tensor, initial_prompt_len: int) -> Tensor:
        input_ids = output_ids.clone()

        mask = torch.zeros(input_ids.shape, dtype=torch.bool, device=input_ids.device)
        for i in self.step_ids:
            mask |= input_ids == i

        mask |= input_ids == self.eos_id

        ignore_prefix_lens = (
            torch.sum((input_ids == self.pad_id).long(), dim=1) + initial_prompt_len
        )
        idxs = torch.arange(input_ids.shape[1], device=input_ids.device)
        idxs = idxs.repeat(len(mask)).view(len(mask), -1)
        response_mask = (idxs.T > ignore_prefix_lens).T

        mask &= response_mask
        input_ids[mask] = self.prm_step_id
        return input_ids

    def compute_scores(self, logits: Tensor) -> list[float]:
        return logits.softmax(dim=-1)[:, -1, 0].cpu().tolist()

    def __call__(self, question: str) -> str:
        gen_input_ids = self.encode(question)
        initial_prompt_len = gen_input_ids.shape[1]

        width = self.width
        completed_beams: list[tuple[Tensor, float]] = []

        for _ in range(self.max_generation_rounds):
            attention_mask = (gen_input_ids != self.pad_id).long()

            gen_output_ids = self.generator.generate(
                input_ids=gen_input_ids,
                attention_mask=attention_mask,
                do_sample=True,
                max_new_tokens=self.max_new_tokens,
                max_length=None,
                num_beams=width,
                num_return_sequences=width,
                eos_token_id=self.step_ids + [self.eos_id],
                pad_token_id=self.pad_id,
                use_cache=True,
            )

            gen_output_ids = self.convert_to_left_padding(gen_output_ids)

            prm_input_ids = self.prepare_prm_inputs(gen_output_ids, initial_prompt_len)

            with torch.no_grad():
                logits = self.prm(prm_input_ids).logits[:, :, self.candidate_ids]
                scores = self.compute_scores(logits)

            sorted_scored_idxs = sorted(
                enumerate(scores), key=lambda t: t[1], reverse=True
            )
            best_idxs = [i for i, _ in sorted_scored_idxs]

            gen_input_ids = gen_output_ids[best_idxs[:width]]

        return [self.tokenizer.decode(x) for x in gen_input_ids]

In [146]:
outputs = BeamSearch(width=2, new_samples_per_beam=2, max_generation_rounds=7)(problem)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [147]:
for output in outputs:
    print(output)

<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 11 Dec 2024

You are a helpful AI Assistant who solves problems. Solve each problem step by step, ensuring that:
1. Each reasoning step is separated by a blank line for clarity.
2. The final answer is formatted as: \boxed{<answer>}.
Always follow this format when solving problems.<|start_header_id|>user<|end_header_id|>

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$<|start_header_id|>assistant<|end_header_id|>

Step 1: To convert the point $(0,3)$ from rectangular coordinates to polar coordinates, we need to 