# Self-Consistency Algorithm Demo
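This notebook demonstrates the Self-Consistency algorithm for mathematical reasoning. The algorithm samples several responses from a language model, extracts the final `\boxed{...}` answer from each, and keeps the answer that occurs most often. The notebook calls OpenAI's `gpt-4o`, so `OPENAI_API_KEY` must be set in the environment.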

In [1]:
# Reload edited modules (e.g. its_hub) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
from its_hub.utils import SAL_STEP_BY_STEP_SYSTEM_PROMPT
from its_hub.lms import OpenAICompatibleLanguageModel
import os


# Read the key from the environment rather than hard-coding it in the notebook.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    # os.getenv returns None when the variable is missing; stop here with a
    # clear message instead of handing None to the client.
    raise RuntimeError("OPENAI_API_KEY is not set; export it before running this notebook.")

lm = OpenAICompatibleLanguageModel(
    endpoint="https://api.openai.com/v1",
    api_key=api_key,
    model_name="gpt-4o",
    system_prompt=SAL_STEP_BY_STEP_SYSTEM_PROMPT,
    is_async=False,
)


In [7]:
# Competition-style problem this demo is meant to showcase. It is NOT sent
# below; set `prompt = MATH_PROMPT` to run the demo on it instead.
MATH_PROMPT = r"Let $a$ be a positive real number such that all the roots of \[x^3 + ax^2 + ax + 1 = 0\]are real. Find the smallest possible value of $a.$"

# Simple sanity-check prompt: this is the one sent to the model here and
# reused by the Self-Consistency cell below.
prompt = "what is 1 + 2?"

# Chat-format input consisting of a single user turn.
messages = [{"role": "user", "content": prompt}]

response = lm.generate(messages)
print(response)

{'role': 'assistant', 'content': 'The solution is straightforward:\n\n1 + 2 = 3\n\nTherefore, the final answer is: \\(\\boxed{3}\\). I hope it is correct.', 'refusal': None, 'annotations': []}


In [9]:
def extract_boxed(s: str) -> str:
    r"""Return the contents of the last ``\boxed{...}`` expression in ``s``.

    Braces inside the box are matched to any depth, so answers such as
    ``\boxed{\frac{\sqrt{3}}{2}}`` are extracted whole; a single regex can
    only handle a fixed nesting depth.

    Args:
        s: Model output text. ``None`` or ``""`` is accepted (e.g. a refusal
            with no content).

    Returns:
        The text between the braces of the last balanced, non-empty
        ``\boxed{...}``, or ``""`` if there is none.
    """
    marker = "\\boxed{"
    if not s:
        return ""

    answers = []
    start = s.find(marker)
    while start != -1:
        content_start = start + len(marker)
        depth = 1
        pos = content_start
        # Walk forward until the brace opened by the marker is closed.
        while pos < len(s) and depth:
            ch = s[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            pos += 1

        if depth == 0:
            content = s[content_start:pos - 1]
            if content:  # an empty box carries no answer
                answers.append(content)
            # Matches do not overlap: resume after the closing brace, so a box
            # nested inside another box stays part of the outer answer.
            start = s.find(marker, pos)
        else:
            # Unbalanced box (e.g. truncated output): try the next occurrence.
            start = s.find(marker, start + 1)

    return answers[-1] if answers else ""
    
# Pull the final answer out of the single response generated above.
single_content = response['content']
extract_boxed(single_content)

'3'

## Self-Consistency Algorithm
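Instead of trusting a single sample, we now let the Self-Consistency algorithm draw `budget` responses to the same prompt. It extracts each final answer with `extract_boxed` and selects the most frequent one, making the result less sensitive to any one faulty sample.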

In [13]:
from its_hub.algorithms import SelfConsistency

# Computational budget for inference-time scaling. Presumably this is the
# number of sampled responses (the counts in the next cell total 2).
budget = 2

# extract_boxed maps each response to a comparable answer string; the counts
# shown later are keyed by the strings it returns.
scaling_alg = SelfConsistency(extract_boxed)

# With return_response_only=False we get a result object exposing .the_one and
# .response_counts (presumably instead of only the selected response).
scaling_result = scaling_alg.infer(lm, prompt, budget, return_response_only=False)
print(scaling_result.the_one)

{'role': 'assistant', 'content': '1 + 2 equals 3.\n\nTherefore, the final answer is: $\\boxed{3}$. I hope it is correct.', 'refusal': None, 'annotations': []}


In [15]:
# Show how often each extracted answer occurred across the sampled responses.
heading = "extracted response counts"
print(heading)
scaling_result.response_counts

extracted response counts


Counter({'3': 2})