In [7]:
import sys

sys.path.append('..')


import os

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from src.mcts import MCTS

## Load model and tokenizer

In [8]:
path = 'mkurman/llama-3.2-MEDIT-3B-o1'
hf_cache = os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface'))

model = AutoModelForCausalLM.from_pretrained(
    path,
    cache_dir=hf_cache,
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(
    path,
    cache_dir=hf_cache,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 10.83it/s]


In [9]:
if isinstance(model.config.eos_token_id, int):
    model.config.eos_token_id = [model.config.eos_token_id]

def has_eos(output, eos_token: list[int]):
    return any(x for x in eos_token if x in output)    

## Create input prompt

In [10]:
chat_like_texts = [
    [
        {
            "role": "user",
            "content": f"Question: A 48-year-old man comes to the clinic because of a 10-year history of recurrent, intrusive thoughts that his house will bebroken into and damaged by criminals or accidentally destroyed by a fire when he is not home. These thoughts haveworsened during the past 2 months. He reports now spending 4 hours daily checking that the doors and windows are closedand locked and that the stove and oven are turned off; he previously spent 2 hours daily doing these tasks. He says he cannotkeep a job or leave the house very much because of the amount of time he spends checking these things. He has no otherhistory of serious illness and takes no medications. Physical examination shows no abnormalities. On mental statusexamination, he has an anxious mood and a sad affect. He is fully oriented. He is not having hallucinations or delusions. Themost effective pharmacotherapy for this patient is an agent that targets which of the following neurotransmitters?\nA. γ-Aminobutyric acid\nB. Dopamine\nC. Glutamate\nD. Norepinephrine\nE. Serotonin",
        },
    ]
]

input_ids = tokenizer(tokenizer.apply_chat_template(chat_like_texts[0], tokenize=False, add_generation_prompt=True), padding=True, truncation=True, return_tensors="pt")

## Configure MCTS

In [11]:
MAX_DEPTH = 6 # Maximum depth of the tree
MAX_SIMULATIONS = 4 # Maximum number of simulations per node
MAX_NEW_TOKENS = 32 # Maximum number of tokens to generate per simulation
TEMPERATURE = 0.5 # Temperature for the generation

MAX_TOTAL_TOKENS = 8192 # Maximum number of tokens to generate in total

## Run the MCTS

In [12]:
model.to("cuda")
input_tokens = input_ids['input_ids'].to('cuda') 

model.eval()

print(tokenizer.decode(input_tokens[0]))

while len(input_tokens[0]) < MAX_TOTAL_TOKENS:
    mcts = MCTS(model, tokenizer, max_depth=MAX_DEPTH, num_simulations=MAX_SIMULATIONS, temperature=TEMPERATURE, max_new_tokens=MAX_NEW_TOKENS, stop_tokens=model.config.eos_token_id)
    
    new_tokens = mcts.search(input_tokens)

    new_tokens = new_tokens[..., input_tokens.shape[-1]:]

    input_tokens = torch.cat([input_tokens, new_tokens], dim=-1)

    print(tokenizer.decode(new_tokens[0]), end='')

    if has_eos(new_tokens, eos_token=model.config.eos_token_id):
        break

<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 14 Jan 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Question: A 48-year-old man comes to the clinic because of a 10-year history of recurrent, intrusive thoughts that his house will bebroken into and damaged by criminals or accidentally destroyed by a fire when he is not home. These thoughts haveworsened during the past 2 months. He reports now spending 4 hours daily checking that the doors and windows are closedand locked and that the stove and oven are turned off; he previously spent 2 hours daily doing these tasks. He says he cannotkeep a job or leave the house very much because of the amount of time he spends checking these things. He has no otherhistory of serious illness and takes no medications. Physical examination shows no abnormalities. On mental statusexamination, he has an anxious mood and a sad affect. He is fully oriented. He is not 