In [15]:
import torch
import re
import numpy as np
import tqdm.notebook as tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
class TreeOfThoughtsReasoner:
    def __init__(self, model_id: str):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, 
            torch_dtype=torch.bfloat16, 
            device_map=self.device
        )
    
    def query_model(
        self, prompt: str, k: int, max_tokens: int = 200
    ) -> list[str]:
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(self.device)

        output = self.model.generate(
            input_ids, 
            max_new_tokens=max_tokens, 
            temperature=0.7,
            do_sample=True,
            num_return_sequences=k,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # to remove the input prompt from the output
        input_len = input_ids.shape[1]
        # return a list of k outputs
        return [
            self.tokenizer.decode(out[input_len:], skip_special_tokens=True).strip()
            for out in output
        ]
    
    def get_next_steps(
        self,
        current_state: str,
        task_description: str,
        current_depth: int,
        max_depth: int,
        k: int = 3,
    ) -> list[str]:
        # given the current partial solution, generate k possible next steps
        prompt = f"""Task: {task_description}
Current partial solution: {current_state}
Based on this, write the single most logical next step and its solution.
There will be {max_depth} steps in total.
Right now, you need to generate the step {current_depth}.
Output ONLY the next logical step with its solution. Keep it brief."""

        possible_steps = self.query_model(prompt, max_tokens=100, k=k)
        return [current_state + "\n" + step for step in possible_steps]

    def evaluate_state(
        self, current_state: str, task_description: str
    ) -> float:
        # ask the model to score the current partial solution (0.1 to 1.0)
        prompt = f"""Task: {task_description}
Proposed partial solution: {current_state}
Evaluate if this solution is on the right track to solving the task.
Consider logical consistency and constraints.
Give a score between 0.1 and 1.0 where 1.0 is perfect.
Output ONLY the number."""
        
        response = self.query_model(prompt, max_tokens=10, k=1)[0]
        try:
            # extract the first floating point number found in the text
            score = float(re.findall(r"0\.\d+|1\.0|0", response)[0])
        except Exception:
            score = 0.5
            
        return score

    def solve(self, task: str, max_depth: int = 3, beam_width: int = 2) -> str:

        # start with 1 empty state
        current_beams = [""] 
        
        for current_depth in range(1, max_depth + 1):            
            all_candidates = []

            for beam in current_beams:
                next_steps = self.get_next_steps(beam, task, current_depth, max_depth, k=3)
                
                for possible_step in next_steps:
                    score = self.evaluate_state(possible_step, task)
                    all_candidates.append({"score": score, "step": possible_step})

            # keep only the top `beam_width` candidates for the next depth level
            all_candidates.sort(key=lambda x: x["score"], reverse=True)
            current_beams = [candidate["step"] for candidate in all_candidates[:beam_width]]

        return current_beams[0] # Return the single best path

In [3]:
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
tot_engine = TreeOfThoughtsReasoner(model_id)

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

In [19]:
problem = "Calculate (2 + 1) * 3"
solution = tot_engine.solve(problem, max_depth=3, beam_width=2)
print(solution)


Next logical step:
(2 + 1) = ?

Solution:
3
Step 2:
3 * 3 = ?
3 * 3 = 9


In [41]:
def get_problem() -> tuple[str, int]:
    a, b, c = np.random.randint(0, 100, 3)
    return f"Calculate ({a:.0f} + {b:.0f}) * {c:.0f}. Output just the final number", (a + b) * c


def extract_result(response) -> int:
    try:
        return int(re.findall(r"\d+", response)[-1])
    except Exception:
        return -1

In [42]:
np.random.seed(0)
problems, targets = zip(*[get_problem() for _ in range(30)])

answers_simple = [tot_engine.query_model(p, k=1, max_tokens=50)[0] for p in problems]

In [43]:
answers_tot = []
for problem in tqdm.tqdm(problems):
    answers_tot.append(tot_engine.solve(problem, max_depth=3, beam_width=2))

  0%|          | 0/30 [00:00<?, ?it/s]

In [44]:
targets = np.array(targets)
extracted_answers_simple = np.array([extract_result(ans) for ans in answers_simple])
extracted_answers_tot = np.array([extract_result(ans) for ans in answers_tot])

In [47]:
print(
    "Accuracy (base model):", (targets == extracted_answers_simple).mean()
)
print(
    "Accuracy (tree of thoughts):", (targets == extracted_answers_tot).mean()
)

Accuracy (base model): 0.06666666666666667
Accuracy (tree of thoughts): 0.16666666666666666
