# Scaling Test-Time Compute for Longer Thinking in LLMs

## Simple Benchmark

How to run:
- Set your Hugging Face token (required for the gated Llama model).
- Select a T4 GPU runtime.
- Run all cells.

You will encounter errors at two points:
1. After `pip install`
2. On the first attempt to `import sal`

After each of them, choose "Restart session and run all".

❗️`pip install` and model downloads take a long time in every new **runtime**. If any error occurs, restart the **session** instead of starting a new runtime.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
methods = [
    'best_of_n',         # 0
    'beam_search',       # 1
    'dvts',              # 2
    'dynamic_beam',      # 3
    'beam_search_ev',    # 4
    'greedy_backtrack',  # 5
]

# Select the method to test by its index above
test_method = methods[3]

## 1. Install Dependencies

Colab ships with many pre-installed packages that cause hard-to-resolve version conflicts, so we resolved the dependencies in a local virtual environment and pinned the resulting versions here.

In [None]:
%%bash
echo "
accelerate==1.5.2
aiohappyeyeballs==2.6.1
aiohttp==3.11.14
aiosignal==1.3.2
annotated-types==0.7.0
antlr4-python3-runtime==4.7.2
anyio==4.9.0
attrs==25.3.0
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
cloudpickle==3.1.1
datasets==3.5.0
dill==0.3.8
diskcache==5.6.3
distro==1.9.0
einops==0.8.1
fastapi==0.115.12
filelock==3.18.0
frozenlist==1.5.0
fsspec==2024.12.0
gguf==0.10.0
h11==0.14.0
hf_transfer==0.1.9
httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.29.3
idna==3.10
importlib_metadata==8.6.1
iniconfig==2.1.0
interegular==0.3.3
isort==6.0.1
Jinja2==3.1.6
jiter==0.9.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
lark==1.2.2
latex2sympy2==1.9.1
llvmlite==0.44.0
lm-format-enforcer==0.10.6
MarkupSafe==3.0.2
mistral_common==1.5.4
mpmath==1.3.0
msgpack==1.1.0
msgspec==0.19.0
multidict==6.2.0
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.4.2
numba==0.61.0
numpy==1.26.4
nvidia-ml-py==12.570.86
openai==1.69.0
opencv-python-headless==4.11.0.86
outlines==0.0.46
packaging==24.2
pandas==2.2.3
partial-json-parser==0.2.1.1.post5
Pebble==5.1.1
pillow==11.1.0
pluggy==1.5.0
prometheus-fastapi-instrumentator==7.1.0
prometheus_client==0.21.1
propcache==0.3.1
protobuf==6.30.2
psutil==7.0.0
py-cpuinfo==9.0.0
pyairports==2.1.1
pyarrow==19.0.1
pycountry==24.6.1
pydantic==2.11.1
pydantic_core==2.33.0
pytest==8.3.5
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
pytz==2025.2
PyYAML==6.0.2
pyzmq==26.3.0
ray==2.44.1
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rpds-py==0.24.0
ruff==0.11.2
safetensors==0.5.3
sentencepiece==0.2.0
six==1.17.0
sniffio==1.3.1
starlette==0.46.1
sympy==1.13.3
tiktoken==0.9.0
tokenizers==0.21.1
torch==2.4.0
torchvision==0.19.0
tqdm==4.67.1
transformers==4.50.3
typing-inspection==0.4.0
typing_extensions==4.13.0
tzdata==2025.2
urllib3==2.3.0
uvicorn==0.34.0
uvloop==0.21.0
vllm==0.6.3
watchfiles==1.0.4
websockets==15.0.1
word2number==1.1
xxhash==3.5.0
yarl==1.18.3
zipp==3.21.0
" > requirements.txt

❗️The next cell ends with multiple dependency errors; ignore them, since we do not use those packages.

In [None]:
!pip install -r requirements.txt
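Because the install above reports errors that are safe to ignore, it can help to confirm that the packages this notebook actually relies on ended up at their pinned versions. The sketch below is a hypothetical helper (not part of the project): it parses `name==version` lines and compares them with what `importlib.metadata` reports.

```python
from importlib import metadata


def parse_pins(text):
    """Parse 'name==version' lines into {name: version}, skipping blanks, comments, and unpinned lines."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, version = line.split("==", 1)
        pins[name.strip()] = version.strip()
    return pins


def find_mismatches(pins, get_version=metadata.version):
    """Return {name: (pinned, installed)} for every package whose installed version differs.

    `installed` is None when the package is not installed at all.
    `get_version` defaults to importlib.metadata.version and can be swapped for testing.
    """
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches
```

In the notebook, you might restrict the check to the core packages, for example `find_mismatches({k: v for k, v in parse_pins(open("requirements.txt").read()).items() if k in ("torch", "vllm", "transformers")})`; an empty dict means those pins were honored.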

In [None]:
!git clone https://github.com/choyerhuang/CSCI544-Project

❗️If importing `sal` fails with `ModuleNotFoundError: No module named 'sal'`, restart the session and run again from here.

In [None]:
%cd /content/CSCI544-Project
!pip install -e '.[dev]'

Log in to Hugging Face to access [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), as it is a gated model! 🗝️  
If you haven't previously requested access, you'll need to submit a request before proceeding.

⚠️ Use your USC email to register an account. When requesting access, enter "University of Southern California" as your affiliation and select "Research Graduate"; otherwise, your request will be rejected.

In [None]:
from huggingface_hub import login

login(token="")  # Paste your Hugging Face access token here
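Pasting the token directly into the cell makes it easy to leak when sharing the notebook. A minimal sketch of an alternative, assuming you store the token either in the `HF_TOKEN` environment variable or as a Colab secret named `HF_TOKEN` (the helper name `resolve_hf_token` is hypothetical):

```python
import os


def resolve_hf_token(name="HF_TOKEN"):
    """Return a Hugging Face token from the environment or Colab secrets, else None."""
    token = os.environ.get(name)
    if token:
        return token
    try:
        # google.colab only exists inside Colab; elsewhere this raises ImportError.
        from google.colab import userdata
    except ImportError:
        return None
    # Inside Colab, userdata.get raises if the secret is missing or access is not granted.
    return userdata.get(name)
```

Usage: `login(token=resolve_hf_token())`. If the helper returns None, `login()` falls back to its interactive prompt.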

## 2. Set Up the Large Language Model (LLM) and the Process Reward Model (PRM)


⬇️ Start again from here after **Restart session**.

In [None]:
import torch
from vllm import LLM
from sal.models.reward_models import RLHFFlow

model_path = "meta-llama/Llama-3.2-1B-Instruct"
prm_path = "RLHFlow/Llama3.1-8B-PRM-Deepseek-Data"

llm = LLM(
    model=model_path,
    gpu_memory_utilization=0.5,  # Utilize 50% of GPU memory
    enable_prefix_caching=True,  # Optimize repeated prefix computations
    seed=42,                     # Set seed for reproducibility
    dtype='half',                # float16; the T4 does not support bfloat16
    max_model_len=8192,          # Cap the context length to bound KV-cache memory
)

prm = RLHFFlow(prm_path)
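The `gpu_memory_utilization=0.5` and `max_model_len=8192` settings trade GPU memory against how many tokens vLLM can keep in its KV cache. The back-of-the-envelope sketch below illustrates that arithmetic. The model figures (16 layers, 8 KV heads, head dimension 64, about 1.24B parameters) and the roughly 15 GiB of usable T4 memory are assumptions taken from public specs, so check the model's `config.json` and `nvidia-smi`; the estimate also ignores activation memory, so treat it as an upper bound.

```python
GIB = 2**30


def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    """Bytes of key and value cache stored per token across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def max_cached_tokens(gpu_bytes, utilization, weight_bytes, kv_bytes_per_token):
    """Upper bound on tokens that fit in the KV cache (ignores activation memory)."""
    budget = gpu_bytes * utilization - weight_bytes
    return max(0, int(budget // kv_bytes_per_token))


# Assumed figures for Llama-3.2-1B-Instruct in float16 on a T4 (verify before relying on them).
kv_per_token = kv_cache_bytes_per_token(num_layers=16, num_kv_heads=8, head_dim=64)
weights = 1.24e9 * 2  # 2 bytes per float16 parameter
tokens = max_cached_tokens(15 * GIB, 0.5, weights, kv_per_token)

print(f"KV cache per token: {kv_per_token} bytes")
print(f"Approx. tokens cacheable: {tokens}")
print(f"Full 8192-token sequences that fit: {tokens // 8192}")
```

Raising `gpu_memory_utilization` grows the cache budget but leaves less memory for the PRM, which shares the same GPU.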

## 3. Set Up Search Methods


In [None]:
import os

# Work from the cloned repository root
os.chdir('/content/CSCI544-Project')

from sal.config import Config
from sal.search import (
    beam_search,
    beam_search_ev,
    best_of_n,
    dvts,
    greedy_backtrack_search,
    run_dynamic_beam_search,
)

config = Config()

config.n = 4                  # Number of candidate solutions (beams) per problem
config.prm_batch_size = 1     # Batch size when scoring with the PRM
config.search_batch_size = 1  # Problems processed per search batch

if test_method == 'beam_search':
    config.sort_completed = True
    config.filter_duplicates = True
elif test_method == 'dvts':
    config.sort_completed = True
    config.filter_duplicates = True
    config.n_beams = config.n // config.beam_width  # Number of independent subtrees
elif test_method == 'dynamic_beam':
    config.approach = "dynamic_beam"
    config.sort_completed = True
    config.filter_duplicates = True
    config.num_iterations = 7
    config.dynamic_beam_delta = 0.3  # Beam score margin
    config.min_beams = 2
    config.max_beams = 4
elif test_method == 'beam_search_ev':
    config.approach = 'beam_search_ev'
    config.sort_completed = True
    config.filter_duplicates = True
elif test_method == 'greedy_backtrack':
    config.approach = "greedy_backtrack"
    config.sort_completed = True
    config.filter_duplicates = True
    config.num_iterations = 10
    config.max_backtrack_depth = 3
    config.early_stop_when_x_finished = 1
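Two of the knobs above involve simple arithmetic worth making explicit. For DVTS, `n` total beams are split into `n // beam_width` independent subtrees, which only works cleanly when `n` is divisible by `beam_width` (search-and-learn's `Config` appears to enforce this when `approach == "dvts"`). For `dynamic_beam`, the source only says that `dynamic_beam_delta` is a "beam score margin" bounded by `min_beams` and `max_beams`; the selection rule below is a **hypothetical** reading of those three settings, not the project's actual implementation.

```python
def dvts_subtrees(n, beam_width):
    """Number of independent DVTS subtrees when n beams are split into groups of beam_width."""
    if n % beam_width != 0:
        raise ValueError(f"n={n} must be divisible by beam_width={beam_width}")
    return n // beam_width


def select_beams_by_margin(scores, delta, min_beams, max_beams):
    """HYPOTHETICAL margin rule: keep beams whose score is within `delta` of the best,
    then clamp the count to [min_beams, max_beams].

    Returns beam indices ordered by descending score (ties keep their original order).
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    if not order:
        return []
    best = scores[order[0]]
    within_margin = [i for i in order if best - scores[i] <= delta]
    k = min(max(len(within_margin), min_beams), max_beams, len(order))
    return order[:k]
```

With the settings above (`delta=0.3`, `min_beams=2`, `max_beams=4`), scores `[0.9, 0.1, 0.2]` would keep two beams even though only one lies within the margin, because of the `min_beams` floor.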

In [None]:
import time

def generate_with_search_and_learn(question, config, llm, prm, method='best_of_n'):
    """
    Generate an answer for a given question using the search-and-learn pipeline.

    Args:
    - question (str): The input question to generate an answer for.
    - config (Config): Configuration object containing parameters for the search strategy.
    - llm (LLM): Pretrained large language model used for generating answers.
    - prm (RLHFFlow): Process reward model used for scoring solution steps.
    - method (str): Search strategy to use. Options are 'best_of_n', 'beam_search', 'dvts',
      'dynamic_beam', 'beam_search_ev', and 'greedy_backtrack'. Default is 'best_of_n'.

    Returns:
    - str: The predicted solution with the Llama 3 assistant header removed.
    """
    batch = {"problem": [question]}

    start_time = time.time()
    if method == 'best_of_n':
        result = best_of_n(x=batch, config=config, llm=llm, prm=prm)
    elif method == 'beam_search':
        result = beam_search(examples=batch, config=config, llm=llm, prm=prm)
    elif method == 'dvts':
        result = dvts(examples=batch, config=config, llm=llm, prm=prm)
    elif method == 'dynamic_beam':
        result = run_dynamic_beam_search(example_batch=batch, config=config, llm=llm, prm=prm)
    elif method == 'beam_search_ev':
        result = beam_search_ev(examples=batch, config=config, llm=llm, prm=prm)
    elif method == 'greedy_backtrack':
        result = greedy_backtrack_search(examples=batch, config=config, llm=llm, prm=prm)
        print("Result keys:", result.keys())
    else:
        raise ValueError(f"Unknown method: {method!r}")

    elapsed_time = time.time() - start_time
    print(f"\nFinished in {elapsed_time:.2f} seconds\n")

    # tokenizer = llm.get_tokenizer()
    # total_tokens = 0
    # for completion in result['completions']:
    #     for comp in completion:
    #         output_tokens = tokenizer.encode(comp)
    #         total_tokens += len(output_tokens)

    # print(f"Total tokens in all completions: {total_tokens}")

    # Strip the Llama 3 chat-template header that precedes the assistant's reply
    formatted_output = result['pred'][0].replace("<|start_header_id|>assistant<|end_header_id|>\n\n", "").strip()
    return formatted_output
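The `if`/`elif` chain in `generate_with_search_and_learn` exists mostly because the search functions name their batch argument differently (`x`, `examples`, `example_batch`). A sketch of an alternative, table-driven dispatch: `run_search` and `sal_registry` are hypothetical names, and the keyword names are copied from the calls in the function.

```python
def run_search(method, registry, batch, config, llm, prm):
    """Look up (function, batch keyword) for `method` and call it with a uniform signature."""
    try:
        fn, batch_kw = registry[method]
    except KeyError:
        raise ValueError(f"Unknown method {method!r}; expected one of {sorted(registry)}") from None
    return fn(**{batch_kw: batch}, config=config, llm=llm, prm=prm)


def sal_registry():
    """Registry for this notebook's search functions (imports sal lazily, on first call)."""
    from sal.search import (beam_search, beam_search_ev, best_of_n, dvts,
                            greedy_backtrack_search, run_dynamic_beam_search)
    return {
        "best_of_n": (best_of_n, "x"),
        "beam_search": (beam_search, "examples"),
        "dvts": (dvts, "examples"),
        "dynamic_beam": (run_dynamic_beam_search, "example_batch"),
        "beam_search_ev": (beam_search_ev, "examples"),
        "greedy_backtrack": (greedy_backtrack_search, "examples"),
    }
```

Usage: `result = run_search(test_method, sal_registry(), {"problem": [question]}, config, llm, prm)`. Adding a new method then means adding one registry entry instead of another `elif` branch.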

## 4. Load & Run


In [None]:
import json

math_500_path = "/content/CSCI544-Project/math_500.json"
sample_100_path = "/content/CSCI544-Project/sample_100.json"

with open(math_500_path, "r") as f:
    math_500_data = json.load(f)

with open(sample_100_path, "r") as f:
    sample_100_id = json.load(f)
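The run loop below scans every row of `math_500_data` for each sampled ID. A minimal sketch of an alternative that indexes the rows once by `unique_id` and reports IDs missing from the dataset up front; it assumes the `{"rows": [{"row": {...}}, ...]}` layout the loop already relies on, and the helper names are hypothetical.

```python
def index_rows_by_id(data):
    """Map unique_id -> row dict for data shaped like {"rows": [{"row": {...}}, ...]}."""
    index = {}
    for entry in data["rows"]:
        row = entry["row"]
        index[row["unique_id"]] = row
    return index


def missing_ids(index, ids):
    """Return the IDs (in their original order) that have no matching row."""
    return [uid for uid in ids if uid not in index]
```

Usage: `rows_by_id = index_rows_by_id(math_500_data)`, then `curr_row = rows_by_id.get(unique_id)` inside the loop, and `missing_ids(rows_by_id, sample_100_id)` as a sanity check before starting a long run.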

In [None]:
output_file_path = f"/content/gdrive/MyDrive/csci544_output_{test_method}.jsonl"

try:
    # Read the last written idx from the file, if it exists
    with open(output_file_path, "r") as f:
        lines = f.readlines()
        last_idx = int(json.loads(lines[-1])["sample_idx"]) if lines else -1
except FileNotFoundError:
    # If the file doesn't exist, start from 0
    last_idx = -1
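The resume logic above reads only the last line, so a run interrupted mid-write (common when a Colab session disconnects) can leave a truncated line that makes `json.loads` raise. A more defensive sketch, with a hypothetical helper name, that skips blank and unparsable lines and returns the largest `sample_idx` written so far:

```python
import json


def last_written_idx(path, key="sample_idx"):
    """Return the largest `key` among complete JSON lines in `path`, or -1 if none.

    Blank lines and truncated or otherwise invalid lines are skipped instead of raising.
    """
    last = -1
    try:
        with open(path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                last = max(last, int(record[key]))
    except FileNotFoundError:
        pass
    return last
```

Usage: `last_idx = last_written_idx(output_file_path)`. Because results are appended in order, the maximum equals the last complete entry.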

In [None]:
from datetime import datetime
from zoneinfo import ZoneInfo

# Report wall-clock time in US Pacific time (PDT or PST, depending on the date)
PACIFIC = ZoneInfo("America/Los_Angeles")

def pacific_now():
    return datetime.now(PACIFIC).strftime('%Y-%m-%d %H:%M:%S %Z')

print("start time: " + pacific_now())

with open(output_file_path, "a") as output_file:
    for idx, unique_id in enumerate(sample_100_id):
        # Skip samples already written by a previous (interrupted) run
        if idx <= last_idx:
            continue

        for row in math_500_data["rows"]:
            if row["row"]["unique_id"] == unique_id:
                curr_row = row["row"]
                print(f"sample_idx: {idx} - unique_id: {unique_id}")
                break
        else:
            print(f"sample_idx: {idx} - unique_id: {unique_id} not found in math_500.json")
            continue

        formatted_output = generate_with_search_and_learn(
            question=curr_row["problem"],
            config=config,
            llm=llm,
            prm=prm,
            method=test_method
        )

        output_file.write(json.dumps({
            "sample_idx": idx,
            "level": curr_row["level"],
            "unique_id": curr_row["unique_id"],
            "predict": formatted_output,
            "answer": curr_row["answer"],
            "correct": 0  # Placeholder; not graded in this notebook
        }) + "\n")
        output_file.flush()  # Persist each result immediately so an interrupted run can resume

print("end time: " + pacific_now())
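Each output line stores `"correct": 0` as a placeholder. One way to fill it in afterwards is sketched below, assuming the model ends its solution with `\boxed{...}` (as the default search-and-learn prompt requests). This is a deliberately naive exact-match grader with hypothetical helper names, not the project's evaluation method; equivalent but differently written answers (for example `0.5` vs `\frac{1}{2}`) will be scored as wrong.

```python
def last_boxed(text):
    """Return the content of the last \\boxed{...} in `text`, honoring nested braces, else None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    begin = start + len(marker)
    depth = 1
    for j in range(begin, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[begin:j]
    return None  # Unbalanced braces: treat as no answer


def normalize(ans):
    """Very light normalization: drop all whitespace and surrounding $ signs."""
    return "".join(ans.split()).strip("$")


def naive_correct(predict, answer):
    """1 if the last boxed answer in `predict` equals `answer` after normalize(), else 0."""
    boxed = last_boxed(predict)
    return int(boxed is not None and normalize(boxed) == normalize(answer))
```

Usage: read the JSONL file back, set `record["correct"] = naive_correct(record["predict"], record["answer"])` for each record, and report the mean, ideally alongside a symbolic checker for borderline cases.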

In [None]:
# from IPython.display import display, Markdown

# display(Markdown(formatted_output))