In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [12]:
class BestOfNReasoner:
    def __init__(self, model_id: str):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map=self.device
        )
        # Ensure pad token is set for batch generation
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate_candidates(self, prompt: str, n: int = 3):
        """
        The Generator: Creates N distinct chains of thought using high temperature.
        """
        print(f"--- Generating {n} candidates ---")

        system_prompt = (
            "You are a deep reasoning AI. You must NOT answer immediately. "
            "First, use a <think></think> block to break down the problem, check for edge cases, "
            "and verify your logic step-by-step. "
            "Only after thinking should you provide the final answer (after the ending </think> tag)."
        )

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        input_ids = (
            self
            .tokenizer(formatted_prompt, return_tensors="pt")
            .input_ids
            .to(self.device)
        )

        forced_prefix = "<think>\n"
        prefix_ids = (
            self
            .tokenizer(forced_prefix, add_special_tokens=False, return_tensors="pt")
            .input_ids
            .to(self.device)
        )

        input_ids = torch.cat([input_ids, prefix_ids], dim=1)

        # num_return_sequences=n generates N variations in parallel
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=512,
            temperature=0.9,      # High temp to ensure diversity in reasoning
            do_sample=True,
            num_return_sequences=n,
            pad_token_id=self.tokenizer.pad_token_id
        )

        candidates = []
        for out in outputs:
            # Decode and reconstruct the full text including the forced <think>
            text = self.tokenizer.decode(out, skip_special_tokens=True)
            candidates.append(text)

        return candidates

    def score_candidate(self, problem, candidate):
        judge_prompt = (
            f"Review the following problem and proposed solution.\n"
            f"Problem: {problem}\n"
            f"Proposed Solution:\n{candidate}\n\n"
            f"Critique the logic. Look for fallacies. Then, on the very last line, "
            f"output a score from 0 to 10. Format: 'Score: X'\n"
        )

        messages = [{"role": "user", "content": judge_prompt}]
        inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.device)

        # Low temperature for the judge (we want consistent, critical evaluation)
        outputs = self.model.generate(
            inputs,
            max_new_tokens=200,
            temperature=0.1,
            do_sample=False
        )

        critique = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)

        # Simple parsing to find "Score: X"
        try:
            # Look for the last occurrence of "Score:"
            score_text = critique.split("Score:")[-1].strip()
            # Take the first number found
            score = int(''.join(filter(str.isdigit, score_text.split()[0])))
        except:
            score = 0 # Fallback if parsing fails

        return score, critique

    def solve(self, problem, n=3):
        # 1. Generate N solutions
        candidates = self.generate_candidates(problem, n)

        scored_results = []

        # 2. Critique each solution
        for i, cand in enumerate(candidates):
            score, critique = self.score_candidate(problem, cand)
            print(f"\n[Candidate {i+1}] Score: {score}/10")
            print(f"Critique snippet: {critique[:100]}...") # Print start of critique
            scored_results.append((score, cand, critique))

        # 3. Select Best
        # Sort by score descending
        scored_results.sort(key=lambda x: x[0], reverse=True)
        best_result = scored_results[0]

        return best_result

In [10]:
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
reasoner = BestOfNReasoner(model_id)

problem = (
    "I have 3 apples today. Yesterday I ate one apple. "
    "How many apples do I have now?"
)

best_score, best_text, best_critique = reasoner.solve(problem, n=3)

print(f"\n{'='*20} WINNING SOLUTION (Score {best_score}) {'='*20}")
# Extract just the answer part for cleanliness
if "</think>" in best_text:
    print(best_text.split("</think>")[-1].strip())
else:
    print(best_text)

--- Generating 3 candidates ---

[Candidate 1] Score: 8/10
Critique snippet: Score: 8

The logic provided is correct but could be more explicit about the process. The breakdown ...

[Candidate 2] Score: 8/10
Critique snippet: Score: 8

Explanation:
- The logic provided correctly accounts for the fact that eating an apple doe...

[Candidate 3] Score: 8/10
Critique snippet: Score: 8

The logic provided correctly accounts for the fact that you did not eat an apple yesterday...

Therefore, the answer is 2 apples.


In [11]:
print(best_text)

system
You are a deep reasoning AI. You must NOT answer immediately. First, use a <think></think> block to break down the problem, check for edge cases, and verify your logic step-by-step. Only after thinking should you provide the final answer (after the ending </think> tag).
user
I have 3 apples today. Yesterday I ate one apple. How many apples do I have now?
assistant
<think>
To solve this, let's break it down:
1. Today: We start with 3 apples.
2. Yesterday: We consumed 1 apple.
3. Now we need to find out how many apples we have left.

Calculation:
- Start with 3 apples.
- Subtract the 1 apple that was eaten yesterday.
So, 3 - 1 = 2

We don't eat or lose any apples in between, so the number of apples remains unchanged from today.
</think>

Therefore, the answer is 2 apples.
