In [None]:
!pip -q install -U transformers accelerate sentencepiece

import re
import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Dict, Any

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

print("Device:", device)
print("Model loaded:", MODEL_NAME)

@dataclass
class Node:
    depth: int
    numbers: List[float]
    exprs: List[str]
    thought: str = ""
    score: float = -1e9
    is_goal: bool = False
    parent: Optional["Node"] = None
    meta: Dict[str, Any] = field(default_factory=dict)

def pretty_state(nums: List[float], exprs: List[str]) -> str:
    pairs = [f"{e}={n:g}" for e, n in zip(exprs, nums)]
    return " | ".join(pairs)

In [None]:
OPS = ["+", "-", "*", "/"]

def safe_apply(a: float, b: float, op: str) -> Optional[float]:
    if op == "+": return a + b
    if op == "-": return a - b
    if op == "*": return a * b
    if op == "/":
        if abs(b) < 1e-12:
            return None
        return a / b
    return None

def combine_expr(ea: str, eb: str, op: str) -> str:
    return f"({ea} {op} {eb})"

def is_24(x: float, tol: float = 1e-6) -> bool:
    return abs(x - 24.0) <= tol

def one_step_closeness(nums: List[float]) -> float:
    if len(nums) == 1:
        return abs(nums[0] - 24.0)
    best = float("inf")
    n = len(nums)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            a, b = nums[i], nums[j]
            for op in OPS:
                r = safe_apply(a, b, op)
                if r is None:
                    continue
                best = min(best, abs(r - 24.0))
    return best if best != float("inf") else 1e9

def heuristic_score(node: Node) -> float:
    nums = node.numbers
    base = -one_step_closeness(nums)
    depth_penalty = 0.05 * node.depth
    exact_bonus = 2.0 if any(is_24(x) for x in nums) else 0.0
    return base - depth_penalty + exact_bonus

In [None]:
PROPOSER_PROMPT = """You are helping solve the 24 game.
We have current items, each item has an expression and its numeric value.
Pick TWO items and combine them with one operation from + - * / to create a new item.
Return between {k} and {k2} suggestions as lines using EXACT format:

i,j,op

Where i and j are 0-based indices into the list. Use i != j. Prefer moves that help reach 24.

Current items:
{items}
"""

def llm_generate_suggestions(items: str, k_min: int, k_max: int, max_new_tokens: int = 160) -> str:
    prompt = PROPOSER_PROMPT.format(k=k_min, k2=k_max, items=items)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.92,
            num_return_sequences=1,
        )
    txt = tokenizer.decode(out[0], skip_special_tokens=True)
    return txt.strip()

def parse_moves(text: str, n_items: int) -> List[Tuple[int, int, str]]:
    moves = []
    for line in text.splitlines():
        line = line.strip()
        m = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*,\s*([\+\-\*\/])\s*$", line)
        if not m:
            continue
        i, j, op = int(m.group(1)), int(m.group(2)), m.group(3)
        if 0 <= i < n_items and 0 <= j < n_items and i != j:
            moves.append((i, j, op))
    seen = set()
    uniq = []
    for mv in moves:
        if mv not in seen:
            uniq.append(mv)
            seen.add(mv)
    return uniq

def fallback_moves(nums: List[float], limit: int = 24) -> List[Tuple[int, int, str]]:
    scored = []
    n = len(nums)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            for op in OPS:
                r = safe_apply(nums[i], nums[j], op)
                if r is None:
                    continue
                scored.append((abs(r - 24.0), i, j, op))
    scored.sort(key=lambda x: x[0])
    out = [(i, j, op) for _, i, j, op in scored[:limit]]
    seen, uniq = set(), []
    for mv in out:
        if mv not in seen:
            uniq.append(mv)
            seen.add(mv)
    return uniq

In [None]:
def apply_move(node: Node, i: int, j: int, op: str) -> Optional[Node]:
    nums = node.numbers[:]
    exprs = node.exprs[:]

    a, b = nums[i], nums[j]
    r = safe_apply(a, b, op)
    if r is None:
        return None

    ea, eb = exprs[i], exprs[j]
    new_expr = combine_expr(ea, eb, op)

    for idx in sorted([i, j], reverse=True):
        nums.pop(idx)
        exprs.pop(idx)

    nums.append(r)
    exprs.append(new_expr)

    child = Node(
        depth=node.depth + 1,
        numbers=nums,
        exprs=exprs,
        parent=node,
        thought=f"Combine item {i} and {j} with '{op}' -> {new_expr} = {r:g}",
    )
    child.is_goal = (len(nums) == 1 and is_24(nums[0]))
    child.score = heuristic_score(child)
    return child

def expand(node: Node, branch_factor: int, proposer_kmin: int = 8, proposer_kmax: int = 14) -> List[Node]:
    items_str = "\n".join([f"{idx}: {node.exprs[idx]} = {node.numbers[idx]:g}" for idx in range(len(node.numbers))])

    raw = llm_generate_suggestions(items_str, proposer_kmin, proposer_kmax)
    moves = parse_moves(raw, len(node.numbers))

    if not moves:
        moves = fallback_moves(node.numbers, limit=30)

    moves = moves[: max(branch_factor * 2, branch_factor)]

    children = []
    for (i, j, op) in moves:
        ch = apply_move(node, i, j, op)
        if ch is not None:
            children.append(ch)

    children.sort(key=lambda x: x.score, reverse=True)
    return children[:branch_factor]

In [2]:
def reconstruct_solution(goal: Node) -> List[str]:
    path = []
    cur = goal
    while cur is not None:
        if cur.thought:
            path.append(cur.thought)
        cur = cur.parent
    return list(reversed(path))

def tot_solve_24(
    start_nums: List[int],
    beam_width: int = 10,
    branch_factor: int = 8,
    max_depth: int = 3,
    prune_threshold: float = -10.0,
    verbose: bool = True
) -> Dict[str, Any]:
    root = Node(
        depth=0,
        numbers=[float(x) for x in start_nums],
        exprs=[str(x) for x in start_nums],
    )
    root.score = heuristic_score(root)

    beam = [root]
    best_seen = root

    if verbose:
        print("\n=== ToT Search Start ===")
        print("Start:", pretty_state(root.numbers, root.exprs))
        print("Root score:", root.score)

    for d in range(max_depth):
        candidates: List[Node] = []

        if verbose:
            print(f"\n--- Depth {d} -> {d+1} expansion ---")
            print("Beam states:")
            for bidx, b in enumerate(beam[: min(len(beam), 6)]):
                print(f"  [{bidx}] score={b.score:.3f} | {pretty_state(b.numbers, b.exprs)}")

        for b in beam:
            kids = expand(b, branch_factor=branch_factor)
            candidates.extend(kids)

        if not candidates:
            break

        candidates = [c for c in candidates if c.score >= prune_threshold]

        goals = [c for c in candidates if c.is_goal]
        if goals:
            goals.sort(key=lambda x: x.score, reverse=True)
            sol = goals[0]
            steps = reconstruct_solution(sol)
            return {
                "solved": True,
                "start": start_nums,
                "expression": sol.exprs[0],
                "value": sol.numbers[0],
                "steps": steps,
                "final_score": sol.score
            }

        candidates.sort(key=lambda x: x.score, reverse=True)
        beam = candidates[:beam_width]

        if beam and beam[0].score > best_seen.score:
            best_seen = beam[0]

        if verbose:
            print("Top candidates after pruning/beam:")
            for cidx, c in enumerate(beam[: min(len(beam), 6)]):
                print(f"  [{cidx}] score={c.score:.3f} | {pretty_state(c.numbers, c.exprs)}")

    best_expr = best_seen.exprs[0] if len(best_seen.exprs) == 1 else " ; ".join(best_seen.exprs)
    best_val = best_seen.numbers[0] if len(best_seen.numbers) == 1 else None
    return {
        "solved": False,
        "start": start_nums,
        "best_state": pretty_state(best_seen.numbers, best_seen.exprs),
        "best_expression": best_expr,
        "best_value": best_val,
        "final_score": best_seen.score,
        "note": "Not solved within depth/beam limits; increase beam_width/branch_factor or adjust pruning."
    }

tests = [
    [4, 1, 8, 7],
    [3, 3, 8, 8],
    [6, 6, 6, 6],
    [9, 9, 4, 4],
]

for nums in tests:
    result = tot_solve_24(
        nums,
        beam_width=12,
        branch_factor=10,
        max_depth=3,
        prune_threshold=-12.0,
        verbose=True
    )
    print("\n=== RESULT ===")
    for k, v in result.items():
        if k == "steps":
            print("steps:")
            for s in v:
                print("  -", s)
        else:
            print(f"{k}: {v}")
    print("\n" + "="*80 + "\n")

print("""
To adapt this ToT agent beyond the 24 game:
1) Define a STATE representation (like numbers/exprs here).
2) Define a PROPOSER that generates candidate next steps (LLM tool or rule-based).
3) Define a HEURISTIC / SCORER:
   - for checkable tasks, use objective scoring
   - for open-ended tasks, use an LLM-critic scoring rubric
4) Run the same ToT loop:
   expand -> score -> prune -> keep top beam -> repeat until goal or depth limit.
""")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.3/10.3 MB[0m [31m52.8 MB/s[0m eta [36m0:00:00[0m
[?25h

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/282 [00:00<?, ?it/s]



generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Device: cpu
Model loaded: google/flan-t5-base

=== ToT Search Start ===
Start: 4=4 | 1=1 | 8=8 | 7=7
Root score: -4.0

--- Depth 0 -> 1 expansion ---
Beam states:
  [0] score=-4.000 | 4=4 | 1=1 | 8=8 | 7=7
Top candidates after pruning/beam:
  [0] score=-1.050 | 1=1 | 7=7 | (4 * 8)=32
  [1] score=-1.050 | 1=1 | 7=7 | (8 * 4)=32
  [2] score=-3.050 | 1=1 | 8=8 | (4 * 7)=28
  [3] score=-3.050 | 1=1 | 8=8 | (7 * 4)=28
  [4] score=-4.050 | 4=4 | 7=7 | (1 + 8)=9
  [5] score=-4.050 | 4=4 | 7=7 | (8 + 1)=9

--- Depth 1 -> 2 expansion ---
Beam states:
  [0] score=-1.050 | 1=1 | 7=7 | (4 * 8)=32
  [1] score=-1.050 | 1=1 | 7=7 | (8 * 4)=32
  [2] score=-3.050 | 1=1 | 8=8 | (4 * 7)=28
  [3] score=-3.050 | 1=1 | 8=8 | (7 * 4)=28
  [4] score=-4.050 | 4=4 | 7=7 | (1 + 8)=9
  [5] score=-4.050 | 4=4 | 7=7 | (8 + 1)=9
Top candidates after pruning/beam:
  [0] score=-0.100 | 1=1 | ((4 * 8) - 7)=25
  [1] score=-0.100 | 7=7 | ((4 * 8) - 1)=31
  [2] score=-0.100 | (4 * 8)=32 | (1 + 7)=8
  [3] score=-0.100 | (4