# Scaling Test-Time Compute for Longer Thinking in LLMs

In this example, we will extend the inference time of an **Instruct LLM system** using **test-time compute** so that it can solve more challenging problems, such as **complex math problems**. This approach, inspired by OpenAI's reasoning models, shows that **longer reasoning time** during inference can improve model performance.

The [blog post](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute) shows that smaller models, like the 1B and 3B Llama Instruct models, can outperform much larger ones on the MATH-500 benchmark when given enough "time to think".

Research from DeepMind suggests that **test-time compute** can be scaled optimally through strategies such as iterative self-refinement or searching against a reward model.

The blog also introduces the [`search-and-learn`](https://github.com/huggingface/search-and-learn) repository. In this example, we will build a small chatbot that engages in longer reasoning to tackle harder problems using small open models.

## Setup

We need to clone the `search-and-learn` repository and install it in editable mode.

In [None]:
!git clone https://github.com/huggingface/search-and-learn

In [None]:
%cd search-and-learn
!pip install -e '.[dev]'

## Set up the LLM and the Process Reward Model (PRM)

The entire system consists of:
- an LLM that generates intermediate answers based on user input,
- a PRM, based on the paper [*Solving math word problems with process- and outcome-based feedback*](https://huggingface.co/papers/2211.14275), that evaluates and scores these answers,
- a search strategy that uses the PRM feedback to guide the subsequent steps of the search until it reaches the final answer.

![system_LLM_PRM](https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/system.png)
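
Because the PRM scores each intermediate step of a solution, a search strategy needs some way to collapse those step scores into a single number per candidate. The sketch below shows three common aggregation choices; the function name and the strategy labels are our own illustration, not the search-and-learn API.

```python
import math
from typing import List


def aggregate_step_scores(step_scores: List[float], strategy: str = "last") -> float:
    """Collapse per-step PRM scores (each in [0, 1]) into one solution score.

    Illustrative helper only; not part of search-and-learn.
    """
    if not step_scores:
        raise ValueError("need at least one step score")
    if strategy == "last":
        # Trust the PRM's judgment of the final step only
        return step_scores[-1]
    if strategy == "min":
        # A solution is only as good as its weakest step
        return min(step_scores)
    if strategy == "prod":
        # Treat step scores as independent probabilities of being correct
        return math.prod(step_scores)
    raise ValueError(f"unknown strategy: {strategy!r}")


scores = [0.9, 0.6, 0.8]
print(aggregate_step_scores(scores, "last"))  # 0.8
print(aggregate_step_scores(scores, "min"))   # 0.6
print(aggregate_step_scores(scores, "prod"))
```

Taking the `min` penalizes a solution for any single weak step, while `last` relies entirely on how the PRM rates the final step.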


In this example, we will use the [`meta-llama/Llama-3.2-1B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) model as the LLM and the [`RLHFlow/Llama3.1-8B-PRM-Deepseek-Data`](https://huggingface.co/RLHFlow/Llama3.1-8B-PRM-Deepseek-Data) model as the PRM.

In [None]:
import torch
from vllm import LLM
from sal.models.reward_models import RLHFFlow # sal --> search-and-learn

llm_path = 'meta-llama/Llama-3.2-1B-Instruct'
prm_path = 'RLHFlow/Llama3.1-8B-PRM-Deepseek-Data'

llm = LLM(
    model=llm_path,
    gpu_memory_utilization=0.5, # utilize 50% of GPU memory
    enable_prefix_caching=True, # optimize repeated prefix computations
    seed=111
)

prm = RLHFFlow(prm_path)

### Instantiate the question, search strategy, and call the pipeline

1. **Instantiate the question**: Define the input question that the system will answer.
2. **Choose a search strategy**: The system supports `best_of_n`, `beam_search`, and `dvts` (diverse verifier tree search). We will use `best_of_n` in this example.
3. **Call the pipeline**: With the question and search strategy in place, call the inference pipeline, which passes the inputs through both the LLM and the PRM to produce the final answer.

The first step is to clearly define the question that the system will answer.

In [None]:
# Raw string so LaTeX backslashes (e.g. \theta) are not parsed as Python escape sequences
question_text = r"Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$"
input_batch = {'problem': [question_text]}
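
LaTeX commands in the question begin with a backslash, which Python treats as the start of an escape sequence in ordinary string literals: the `\t` in `\theta` silently becomes a tab character. A quick check shows why the question is best written as a raw string (`r"..."`).

```python
# "\t" is an escape sequence, so this string contains a TAB instead of "\t"
plain = "$(r,\theta)$"
# Raw string: the backslash is kept as a literal character
raw = r"$(r,\theta)$"

print("\t" in plain)         # True
print("\\theta" in raw)      # True
print(len(plain), len(raw))  # 11 12
```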

Next, we define the configuration, including parameters such as the number of candidate answers to generate, and choose the search strategy. The selected strategy generates multiple candidate answers to the question, the PRM scores each of them, and the strategy uses those scores to select the final answer that is returned.
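
Once every candidate has a score, best-of-N selection itself is simple. The sketch below contrasts vanilla best-of-N (return the single highest-scoring candidate) with the weighted variant discussed in the blog post, which sums the scores of all candidates that reach the same final answer. The answers and scores are made-up toy values, and both helper names are hypothetical rather than part of search-and-learn.

```python
from collections import defaultdict
from typing import Dict, List


def best_of_n_select(candidates: List[str], scores: List[float]) -> str:
    """Vanilla best-of-N: return the single highest-scoring candidate."""
    best_idx = max(range(len(candidates)), key=lambda i: scores[i])
    return candidates[best_idx]


def weighted_best_of_n_select(final_answers: List[str], scores: List[float]) -> str:
    """Weighted best-of-N: sum the scores of candidates with the same final answer."""
    totals: Dict[str, float] = defaultdict(float)
    for answer, score in zip(final_answers, scores):
        totals[answer] += score
    return max(totals, key=totals.get)


# Toy final answers extracted from 4 candidates, with made-up PRM scores
answers = ["(3, pi/2)", "(3, pi/2)", "(3, 0)", "(3, pi/2)"]
prm_scores = [0.4, 0.5, 0.9, 0.3]

print(best_of_n_select(answers, prm_scores))           # (3, 0)
print(weighted_best_of_n_select(answers, prm_scores))  # (3, pi/2)
```

Here three moderately scored candidates agree on (3, pi/2) and together outweigh the single high-scoring outlier, which is the intuition behind weighting by agreement.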

In [None]:
from sal.config import Config
from sal.search import beam_search, best_of_n, dvts

config = Config()
config.n = 32 # number of answers to generate during the search

search_result = best_of_n(
    x=input_batch,
    config=config,
    llm=llm,
    prm=prm
)

### Display the final result

Once the pipeline has processed the question through the LLM and PRM, we can display the final result. This result will be the model's output after considering the intermediate answers and scoring them using the PRM.

In [None]:
search_result['pred'][0]

Some models leave special tokens, such as the Llama 3 chat header, in their output, so we may need to remove them.

In [None]:
formatted_output = search_result['pred'][0].replace(
    "<|start_header_id|>assistant<|end_header_id|>\n\n",
    ""
).strip()
formatted_output
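
The cell above strips one specific header string. As a more general sketch, a regular expression can remove any Llama 3 role header together with the `<|eot_id|>` end-of-turn marker. The helper name is hypothetical, and other model families use different special tokens, so the patterns would need adjusting for them.

```python
import re

# Llama 3 chat templates wrap each turn in a header such as
# "<|start_header_id|>assistant<|end_header_id|>\n\n" and end it with "<|eot_id|>".
_HEADER_RE = re.compile(r"<\|start_header_id\|>.*?<\|end_header_id\|>\n*", re.DOTALL)
_EOT_RE = re.compile(r"<\|eot_id\|>")


def strip_llama3_special_tokens(text: str) -> str:
    """Remove Llama 3 role headers and end-of-turn markers, then trim whitespace."""
    text = _HEADER_RE.sub("", text)
    text = _EOT_RE.sub("", text)
    return text.strip()


raw_output = "<|start_header_id|>assistant<|end_header_id|>\n\n## Step 1: ...<|eot_id|>"
print(strip_llama3_special_tokens(raw_output))  # ## Step 1: ...
```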

After removing any special tokens, we can display the final answer to the user.

In [None]:
from IPython.display import display, Markdown
display(Markdown(formatted_output))

## Assemble it all

Now we will create a function that encapsulates the entire pipeline so that we can easily reuse the process in the future.

In [None]:
import time

def generate_with_search_and_learn(question, config, llm, prm, method='best_of_n'):
    """Generate an answer for a given question using the search-and-learn pipeline.

    Parameters
    ----------
    question: str
        The input question to generate an answer for
    config: Config
        Configuration object containing parameters for search strategy
    llm: LLM
        Pretrained large language model used for generating answers
    prm: RLHFFlow
        Process reward model used for evaluating answers
    method: str
        Search strategy to use: 'best_of_n', 'beam_search', or 'dvts'

    Returns
    -------
    str
        The formatted output after processing the question
    """
    batch = {'problem': [question]}

    start_time = time.time()
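    if method == 'best_of_n':
        result = best_of_n(x=batch, config=config, llm=llm, prm=prm)
    elif method == 'beam_search':
        result = beam_search(examples=batch, config=config, llm=llm, prm=prm)
    elif method == 'dvts':
        result = dvts(examples=batch, config=config, llm=llm, prm=prm)
    else:
        # Fail early instead of hitting an undefined `result` below
        raise ValueError(
            f"Unknown method {method!r}; expected 'best_of_n', 'beam_search' or 'dvts'"
        )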

    elapsed_time = time.time() - start_time
    print(f"\nFinished in {elapsed_time:.2f} seconds\n")

    tokenizer = llm.get_tokenizer()
    total_tokens = 0
    for completion in result['completions']:
        for comp in completion:
            output_tokens = tokenizer.encode(comp)
            total_tokens += len(output_tokens)

    print(f"Total tokens in all completions: {total_tokens}")

    formatted_output = result["pred"][0].replace(
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
        ""
    ).strip()
    return formatted_output
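
The function above selects the search strategy with an `if`/`elif` chain. An alternative design is a small registry that maps method names to adapters with a uniform signature, which also hides the fact that `best_of_n` takes `x` while `beam_search` and `dvts` take `examples`. In this sketch the `fake_*` functions are hypothetical stand-ins for the real search functions so that it runs without a GPU.

```python
from typing import Any, Callable, Dict


# Hypothetical stand-ins for sal.search.best_of_n / beam_search / dvts,
# mirroring the keyword names used in this notebook.
def fake_best_of_n(x, config, llm, prm):
    return {"pred": [f"best_of_n: {x['problem'][0]}"]}


def fake_beam_search(examples, config, llm, prm):
    return {"pred": [f"beam_search: {examples['problem'][0]}"]}


def fake_dvts(examples, config, llm, prm):
    return {"pred": [f"dvts: {examples['problem'][0]}"]}


# Registry: every entry accepts (batch, config, llm, prm) positionally.
SEARCH_METHODS: Dict[str, Callable[..., Dict[str, Any]]] = {
    "best_of_n": lambda batch, config, llm, prm: fake_best_of_n(
        x=batch, config=config, llm=llm, prm=prm),
    "beam_search": lambda batch, config, llm, prm: fake_beam_search(
        examples=batch, config=config, llm=llm, prm=prm),
    "dvts": lambda batch, config, llm, prm: fake_dvts(
        examples=batch, config=config, llm=llm, prm=prm),
}


def run_search(method, question, config=None, llm=None, prm=None):
    """Look up the strategy by name and run it on a single-question batch."""
    if method not in SEARCH_METHODS:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(SEARCH_METHODS)}")
    batch = {"problem": [question]}
    return SEARCH_METHODS[method](batch, config, llm, prm)


print(run_search("dvts", "What's the capital of Spain?")["pred"][0])
# dvts: What's the capital of Spain?
```

Adding a new strategy then means adding one registry entry rather than another branch.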

### `best_of_n`

In [None]:
question = r"Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$"

config.n = 8

formatted_output = generate_with_search_and_learn(
    question=question,
    config=config,
    llm=llm,
    prm=prm,
    method="best_of_n"
)

In [None]:
display(Markdown(formatted_output))

### `beam_search`

In [None]:
question = r"Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$"

config.n = 8
# beam search specific
config.sort_completed = True
config.filter_duplicates = True

formatted_output = generate_with_search_and_learn(
    question=question,
    config=config,
    llm=llm,
    prm=prm,
    method="beam_search"
)

In [None]:
display(Markdown(formatted_output))
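
To make `beam_search` more concrete, here is a toy, self-contained sketch of PRM-guided beam search, roughly as the blog post describes it: keep `n // beam_width` partial solutions, expand each into `beam_width` candidate next steps, score every candidate, and retain the best. The `toy_expand` and `toy_prm` functions are hypothetical stand-ins for the LLM and the PRM, and scoring a path by the mean quality of its steps is an assumption of this sketch.

```python
from statistics import mean
from typing import Callable, List

Path = List[str]


def toy_beam_search(
    expand: Callable[[Path, int], List[str]],
    score: Callable[[Path], float],
    n: int,
    beam_width: int,
    max_steps: int,
) -> Path:
    """Step-level beam search guided by a (stand-in) process reward model."""
    keep = max(1, n // beam_width)  # number of beams retained after each step
    beams: List[Path] = [[]]
    for depth in range(max_steps):
        # First step: n proposals from the empty prefix; later: beam_width per beam
        per_beam = n if depth == 0 else beam_width
        candidates = [beam + [step] for beam in beams for step in expand(beam, per_beam)]
        candidates.sort(key=score, reverse=True)
        beams = candidates[:keep]
    return beams[0]


# Toy "quality" of choosing option i at any step (higher is better)
STEP_QUALITY = {0: 0.2, 1: 0.9, 2: 0.5, 3: 0.1}


def toy_expand(prefix: Path, k: int) -> List[str]:
    """Propose k candidate steps labelled '<depth>.<option>'."""
    return [f"{len(prefix)}.{i}" for i in range(k)]


def toy_prm(path: Path) -> float:
    """Score a path by the mean quality of its steps' option indices."""
    return mean(STEP_QUALITY.get(int(step.split(".")[1]), 0.0) for step in path)


print(toy_beam_search(toy_expand, toy_prm, n=8, beam_width=4, max_steps=3))
# ['0.1', '1.1', '2.1']
```

With `n=8` and `beam_width=4`, two beams survive each step and about eight candidates are scored per step, so the PRM is consulted throughout the search rather than only on finished answers.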

### `dvts` (Diverse Verifier Tree Search)

In [None]:
question = r"Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$"

config.n = 8
# dvts specific
config.n_beams = config.n // config.beam_width

formatted_output = generate_with_search_and_learn(
    question=question,
    config=config,
    llm=llm,
    prm=prm,
    method="dvts"
)

In [None]:
display(Markdown(formatted_output))
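
DVTS splits the budget into `n // beam_width` independent subtrees (the `config.n_beams` line above) and expands each one greedily with the PRM, which keeps the subtrees from collapsing onto the same solution. The sketch below is a deterministic toy: to stand in for sampling diversity, each subtree is seeded with a different first step, and `toy_expand` and `toy_prm` are hypothetical stand-ins for the LLM and the PRM.

```python
from statistics import mean
from typing import Callable, List

Path = List[str]


def toy_dvts(
    expand: Callable[[Path, int], List[str]],
    score: Callable[[Path], float],
    n: int,
    beam_width: int,
    max_steps: int,
) -> List[Path]:
    """Toy diverse verifier tree search: independent, greedily expanded subtrees."""
    if n % beam_width != 0:
        raise ValueError("n must be divisible by beam_width")
    n_beams = n // beam_width  # number of independent subtrees
    # Seed each subtree with a distinct first step (stand-in for sampling diversity)
    subtrees: List[Path] = [[step] for step in expand([], n_beams)]
    for _ in range(1, max_steps):
        for i, path in enumerate(subtrees):
            # Greedy within a subtree: propose beam_width next steps, keep the best
            options = [path + [step] for step in expand(path, beam_width)]
            subtrees[i] = max(options, key=score)
    return subtrees


STEP_QUALITY = {0: 0.2, 1: 0.9, 2: 0.5, 3: 0.1}


def toy_expand(prefix: Path, k: int) -> List[str]:
    """Propose k candidate steps labelled '<depth>.<option>'."""
    return [f"{len(prefix)}.{i}" for i in range(k)]


def toy_prm(path: Path) -> float:
    """Score a path by the mean quality of its steps' option indices."""
    return mean(STEP_QUALITY.get(int(step.split(".")[1]), 0.0) for step in path)


paths = toy_dvts(toy_expand, toy_prm, n=8, beam_width=4, max_steps=3)
print(paths)                    # [['0.0', '1.1', '2.1'], ['0.1', '1.1', '2.1']]
print(max(paths, key=toy_prm))  # ['0.1', '1.1', '2.1']
```

Each subtree returns its own final path, and the best one is picked only at the end, so a weak start in one subtree does not prune the others.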

### Test the system with a simple question

We also want to test the system with a straightforward question to see how it performs in simpler cases. This lets us verify that the system works as expected even for basic queries.

In [None]:
question = "What's the capital of Spain?"

config.n = 32

formatted_output = generate_with_search_and_learn(
    question=question,
    config=config,
    llm=llm,
    prm=prm,
    method="best_of_n"
)

In [None]:
display(Markdown(formatted_output))