In [1]:
import sys

# Add the parent directory to the module search path.
# NOTE(review): presumably so the local `reasonflow` package imported below
# resolves without being installed — confirm against the repository layout.
sys.path.append('..')


import os

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from reasonflow import ReasonFlowConfig, ReasonFlow

## Load model and tokenizer

In [2]:
# Identifier handed to both `from_pretrained` calls below.
path = 'mkurman/phi4-MedIT-10B-o1'

# Cache directory: $HF_HOME when it is set, otherwise ~/hf_cache.
hf_cache = os.getenv('HF_HOME', os.path.expanduser('~/hf_cache'))

# Load the causal-LM weights, requesting bfloat16 as the parameter dtype.
model = AutoModelForCausalLM.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    cache_dir=hf_cache,
)

# Load the matching tokenizer from the same identifier and cache directory.
tokenizer = AutoTokenizer.from_pretrained(path, cache_dir=hf_cache)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [3]:
# Normalise a scalar EOS id into a one-element list so that
# `model.config.eos_token_id` is a list whenever it started out as an int.
# NOTE(review): presumably so callers such as `has_eos` can always iterate
# over it — a value of None (or an existing list) is left untouched.
if isinstance(model.config.eos_token_id, int):
    model.config.eos_token_id = [model.config.eos_token_id]

def has_eos(output, eos_token: list[int]) -> bool:
    """Return True if any end-of-sequence id occurs in ``output``.

    Args:
        output: Any container that supports the ``in`` operator for token ids
            (for example a list of ints).
        eos_token: Candidate end-of-sequence token ids.

    Returns:
        True when at least one id from ``eos_token`` is contained in
        ``output``; False otherwise, including when ``eos_token`` is empty.
    """
    # Test membership itself rather than the truthiness of the matching id:
    # the previous form `any(x for x in eos_token if x in output)` returned
    # False when the only matching id was 0, because 0 is falsy.
    return any(token_id in output for token_id in eos_token)

## Create input prompt

In [4]:
# Example conversations in chat format: each entry is a list of message dicts
# with "role" and "content" keys, ready for `tokenizer.apply_chat_template`.
chat_like_texts = [
    # Prompt 0: open-ended question.
    [
        {
            "role": "user",
            # Plain string literal: the original `f` prefix had no placeholders
            # and would have misinterpreted any braces added to the prompt later.
            "content": "What is the strength (B1) of a 13C 90° pulse of duration 1 μs?",
        },
    ],
    # Prompt 1: multiple-choice question with the answer options inlined.
    [
        {
            "role": "user",
            "content": 'What is the maximum value of the z-component of the spin angular momentum of 43Ca? [I = 7⁄2] [ "10.492 x 10^-34 J s", "2.549 x 10^-34 J s", "5.493 x 10^-34 J s", "4.185 x 10^-34 J s", "1.238 x 10^-34 J s", "8.370 x 10^-34 J s", "6.234 x 10^-34 J s", "9.854 x 10^-34 J s", "3.691 x 10^-34 J s", "7.123 x 10^-34 J s" ]',
        },
    ],
]

# Render the second conversation (index 1) to a prompt string with the
# generation prompt appended, then tokenize that string into PyTorch tensors.
# The result is the tokenizer's output mapping, not a bare tensor; the ids
# themselves are read later via `input_ids['input_ids']`.
input_ids = tokenizer(
    tokenizer.apply_chat_template(
        chat_like_texts[1],
        tokenize=False,
        add_generation_prompt=True,
    ),
    padding=True,
    truncation=True,
    return_tensors="pt",
)

## Configure MCTS

In [5]:
# Tunable search parameters; each of the first five is forwarded to the
# ReasonFlowConfig field of the same (lower-case) name in the run cell.

# Maximum number of thoughts to validate.
NUM_OF_THOUGHTS = 4
# Maximum number of thinkers to reason with.
NUM_OF_THINKERS = 6
# Maximum number of thinkers to consider.
TOPK_THINKERS = 2
# Decision threshold for accepting a token with lower probability than the
# most probable token.
ACCEPTANCE_THRESHOLD = 0.5
# Temperature values for sampling less probable tokens.
TEMPERATURES = [0.7, 1.3]

# Maximum number of tokens to generate in total.
# NOTE(review): not referenced by the cells shown here — the generation call
# passes max_new_tokens=4096 directly. Confirm whether this should be wired in.
MAX_TOTAL_TOKENS = 8192

## Run the ReasonFlow algorithm

In [8]:
# Which conversation from `chat_like_texts` to run.
index = 1

# Put the model in evaluation mode and move it to the GPU.
model.eval()
model.to("cuda")

# Prompt token ids on the GPU.
# NOTE(review): not used by the remaining statements of this cell; kept in
# case later cells rely on it.
input_tokens = input_ids['input_ids'].to('cuda')

# Bundle the search parameters defined in the configuration cell.
config = ReasonFlowConfig(
    temperatures=TEMPERATURES,
    acceptance_threshold=ACCEPTANCE_THRESHOLD,
    topk_thinkers=TOPK_THINKERS,
    num_of_thoughts=NUM_OF_THOUGHTS,
    num_of_thinkers=NUM_OF_THINKERS,
)

reason_flow = ReasonFlow(config, model, tokenizer)

# Generate for the selected conversation. Sampling is disabled and the
# sampling-related knobs are passed explicitly as None.
output = reason_flow.generate_thoughts(
    chat_like_texts[index],
    max_new_tokens=4096,
    stream=True,
    device='cuda',
    torch_dtype='bfloat16',
    do_sample=False,
    temperature=None,
    top_p=None,
    top_k=None,
    repetition_penalty=None,
)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


<|im_start|>user<|im_sep|>What is the maximum value of the z-component of the spin angular momentum of 43Ca? [I = 7⁄2] [ "10.492 x 10^-34 J s", "2.549 x 10^-34 J s", "5.493 x 10^-34 J s", "4.185 x 10^-34 J s", "1.238 x 10^-34 J s", "8.370 x 10^-34 J s", "6.234 x 10^-34 J s", "9.854 x 10^-34 J s", "3.691 x 10^-34 J s", "7.123 x 10^-34 J s" ]<|im_end|><|im_start|>assistant<|im_sep|><Thought>
Alright, I need to determine the maximum value of the z-component of the spin angular momentum for the isotope 43Ca, which has a spin quantum number \( I = \frac{7}{2} \). The formula for the maximum value of the z-component of the spin angular momentum is given by:

\[ m_z = m \hbar \]

where:
- \( m \) is the quantum number,
- \( \hbar \) is the reduced Planck's constant, approximately \( 1.0545718 \times 10^{-34} \, \text{J} \cdot \text{s} \).

Given \( m = \frac{7}{2} \), I can calculate:

\[ m_z = \frac{7}{2} \times 1.0545718 \times 10^{-34} \, \text{J} \cdot \text{s} \]

Let's perform the multi