In [1]:
example = {  
    "problem": "Find the number of positive integers n ≤ 100 such that n² + 1 is divisible by 3.",  
    "answer": "33"  
}  

In [2]:
query = example['problem']
ground_truth = example['answer']

In [3]:
import re

pattern = re.compile(r'\-?\d+\.\d+|\-?\d+')

def extract_label(text: str) -> str:

    if '\n####' in text:
        text = text.split('\n####')[-1].replace(',','')
    elif 'The answer is' in text:
        text = text.split('The answer is')[-1].replace(',','')
    numbers = pattern.findall(text)
    if not numbers:
        return None
    return numbers[0]

In [4]:
extract_label(ground_truth)

'33'

In [5]:
if extract_label(ground_truth).isdigit():
    ans_format = r'"[Final Answer] The answer is [number] \n#### [number]"'

In [6]:
ans_format

'"[Final Answer] The answer is [number] \\n#### [number]"'

In [7]:
hints_prompt = f'Question: {query}\nCould you provide me with the thought process to solve this problem, but please don’t give me the answer or calculation, just the thought process?'
print(hints_prompt)

Question: Find the number of positive integers n ≤ 100 such that n² + 1 is divisible by 3.
Could you provide me with the thought process to solve this problem, but please don’t give me the answer or calculation, just the thought process?


In [8]:
max_iter = 16

### Baseado nas informações acima, vamos ter as variaveis abaixo:

In [9]:
print(f"query: {query}")
print(f"ground_truth: {ground_truth}")
print(f"max_iter: {max_iter}")
print(f"ans_format: {ans_format}")

query: Find the number of positive integers n ≤ 100 such that n² + 1 is divisible by 3.
ground_truth: 33
max_iter: 16
ans_format: "[Final Answer] The answer is [number] \n#### [number]"


## Inicio da Main

In [46]:
None == None

True

In [10]:
# Lista de respostas para explorar
to_explore = []

# Dicionário com recompensas para cada resposta
to_explore_reward = {}

# Histórico de conversação para cada resposta
history_bank = {}

# Banco de dicas geradas
hints_bank = {}

# Valores UCB (Upper Confidence Bound) para seleção
ucb_bank = {}

# Estrutura de árvore para o MCTS
fathers = {}
childs = {}

In [11]:
import aisuite as ai

# Inicializa o cliente da AISuite
client = ai.Client()
client.configure({
    "ollama": {
        "timeout": 600
    }
})

def generate(prompt, history=None, timeout=150, truncate=True, model_name="ollama:qwen3:14b"):
    """
    Gera uma resposta usando o modelo especificado via AISuite e Ollama.

    Parâmetros:
    - prompt (str): A entrada do usuário.
    - history (list): Histórico de mensagens anteriores.
    - timeout (int): Tempo máximo de espera pela resposta.
    - truncate (bool): Se True, mantém apenas as duas últimas interações no histórico.
    - model_name (str): Nome do modelo a ser utilizado.

    Retorna:
    - response_content (str): Resposta gerada pelo modelo.
    - updated_history (list): Histórico atualizado com a nova interação.
    """
    if history is None:
        history = []

    # Formata o histórico de mensagens
    messages = [{"role": "user" if i % 2 == 0 else "assistant", "content": h} for i, h in enumerate(history)]

    # Trunca o histórico se necessário
    if truncate:
        messages = messages[-2:]

    # Adiciona o novo prompt
    messages.append({"role": "user", "content": prompt})

    # Chama o modelo através da AISuite
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=0.95,
        timeout=timeout
    )

    # Extrai o conteúdo da resposta
    response_content = response.choices[0].message.content

    # Atualiza o histórico com a nova interação
    updated_history = history + [prompt, response_content]

    return response_content, updated_history

In [12]:
def get_weak_answer(question,new_len=0,ans_format=''):
    query = f'Question: {question}\nThe response should begin with [reasoning process]...[Verification]... and end with {ans_format}\nLet\'s think step by step.'
    return generate(query,timeout=90)

In [13]:
weak_answer,history = get_weak_answer(query,ans_format=ans_format)

In [14]:
history

['Question: Find the number of positive integers n ≤ 100 such that n² + 1 is divisible by 3.\nThe response should begin with [reasoning process]...[Verification]... and end with "[Final Answer] The answer is [number] \\n#### [number]"\nLet\'s think step by step.',
 '[reasoning process] To solve the problem of finding how many positive integers $ n \\leq 100 $ make $ n^2 + 1 $ divisible by 3, we analyze the behavior of $ n^2 + 1 $ modulo 3.\n\nAny integer $ n $ modulo 3 can be 0, 1, or 2. We evaluate $ n^2 + 1 \\mod 3 $ for each case:\n\n- If $ n \\equiv 0 \\mod 3 $, then $ n^2 \\equiv 0 \\mod 3 $, and $ n^2 + 1 \\equiv 1 \\mod 3 $.\n- If $ n \\equiv 1 \\mod 3 $, then $ n^2 \\equiv 1 \\mod 3 $, and $ n^2 + 1 \\equiv 2 \\mod 3 $.\n- If $ n \\equiv 2 \\mod 3 $, then $ n^2 \\equiv 4 \\equiv 1 \\mod 3 $, and $ n^2 + 1 \\equiv 2 \\mod 3 $.\n\nIn all cases, $ n^2 + 1 $ is never congruent to 0 modulo 3. That is, $ n^2 + 1 $ is never divisible by 3.\n\nTo confirm this, we can test small values 

In [15]:
history_bank[weak_answer] = tuple(history)

In [16]:
answers_list = [weak_answer,]

In [17]:
to_explore = [weak_answer,]
to_explore

['[reasoning process] To solve the problem of finding how many positive integers $ n \\leq 100 $ make $ n^2 + 1 $ divisible by 3, we analyze the behavior of $ n^2 + 1 $ modulo 3.\n\nAny integer $ n $ modulo 3 can be 0, 1, or 2. We evaluate $ n^2 + 1 \\mod 3 $ for each case:\n\n- If $ n \\equiv 0 \\mod 3 $, then $ n^2 \\equiv 0 \\mod 3 $, and $ n^2 + 1 \\equiv 1 \\mod 3 $.\n- If $ n \\equiv 1 \\mod 3 $, then $ n^2 \\equiv 1 \\mod 3 $, and $ n^2 + 1 \\equiv 2 \\mod 3 $.\n- If $ n \\equiv 2 \\mod 3 $, then $ n^2 \\equiv 4 \\equiv 1 \\mod 3 $, and $ n^2 + 1 \\equiv 2 \\mod 3 $.\n\nIn all cases, $ n^2 + 1 $ is never congruent to 0 modulo 3. That is, $ n^2 + 1 $ is never divisible by 3.\n\nTo confirm this, we can test small values of $ n $ (e.g., $ n = 1 $ to $ n = 10 $) and observe that $ n^2 + 1 $ is not divisible by 3 for any of them. Moreover, a deeper number theory argument shows that the quadratic residues modulo 3 are only 0 and 1, and $ -1 \\mod 3 $ (i.e., 2) is not a quadratic resid

In [18]:
childs[weak_answer] = []

In [19]:
fathers[weak_answer] = None

In [20]:
def cal_reward(question,ans):
    query = f'Question: {question}\nAnswer:{ans}\nAnalyze this Answer Strictly and Critic, point out every flaw for ervery possible imperfect to minus every possible score! You need to be very harsh and mean in calculating grades, and never give full marks to ensure that the marks are authoritative. \nOutput a score between [-100,+100], ig. from -100 to +100. \nResponse format:\n[Analyst]...[Score]...'
    ret = generate(query)
    score = ret[0].split('Score')[-1]
    scores = pattern.findall(score)
    if not scores:
        raise Exception('no')
    else:
        ret = float(scores[-1])
        # if abs(ret - 100.0) < 1e-5:
        #     ret = 50.0
        if ret >= 95:
            ret = 50
        # elif ret <= -100:
        #     ret = -50
        return ret 

In [21]:
cal_reward(query,weak_answer)

50

In [22]:
def sampling_reward(answer):
    if answer not in to_explore_reward:
        to_explore_reward[answer] = []
    reward = cal_reward(query,answer)
    # if check(ground_truth,answer):
    #     reward += 100
    to_explore_reward[answer].append(reward)

In [23]:
sampling_reward(weak_answer)

In [24]:
len(to_explore_reward.keys())

1

In [25]:
to_explore_reward

{'[reasoning process] To solve the problem of finding how many positive integers $ n \\leq 100 $ make $ n^2 + 1 $ divisible by 3, we analyze the behavior of $ n^2 + 1 $ modulo 3.\n\nAny integer $ n $ modulo 3 can be 0, 1, or 2. We evaluate $ n^2 + 1 \\mod 3 $ for each case:\n\n- If $ n \\equiv 0 \\mod 3 $, then $ n^2 \\equiv 0 \\mod 3 $, and $ n^2 + 1 \\equiv 1 \\mod 3 $.\n- If $ n \\equiv 1 \\mod 3 $, then $ n^2 \\equiv 1 \\mod 3 $, and $ n^2 + 1 \\equiv 2 \\mod 3 $.\n- If $ n \\equiv 2 \\mod 3 $, then $ n^2 \\equiv 4 \\equiv 1 \\mod 3 $, and $ n^2 + 1 \\equiv 2 \\mod 3 $.\n\nIn all cases, $ n^2 + 1 $ is never congruent to 0 modulo 3. That is, $ n^2 + 1 $ is never divisible by 3.\n\nTo confirm this, we can test small values of $ n $ (e.g., $ n = 1 $ to $ n = 10 $) and observe that $ n^2 + 1 $ is not divisible by 3 for any of them. Moreover, a deeper number theory argument shows that the quadratic residues modulo 3 are only 0 and 1, and $ -1 \\mod 3 $ (i.e., 2) is not a quadratic resid

In [26]:
type(history), len(history)

(list, 2)

In [27]:
history[0]

'Question: Find the number of positive integers n ≤ 100 such that n² + 1 is divisible by 3.\nThe response should begin with [reasoning process]...[Verification]... and end with "[Final Answer] The answer is [number] \\n#### [number]"\nLet\'s think step by step.'

In [28]:
weak_answer

'[reasoning process] To solve the problem of finding how many positive integers $ n \\leq 100 $ make $ n^2 + 1 $ divisible by 3, we analyze the behavior of $ n^2 + 1 $ modulo 3.\n\nAny integer $ n $ modulo 3 can be 0, 1, or 2. We evaluate $ n^2 + 1 \\mod 3 $ for each case:\n\n- If $ n \\equiv 0 \\mod 3 $, then $ n^2 \\equiv 0 \\mod 3 $, and $ n^2 + 1 \\equiv 1 \\mod 3 $.\n- If $ n \\equiv 1 \\mod 3 $, then $ n^2 \\equiv 1 \\mod 3 $, and $ n^2 + 1 \\equiv 2 \\mod 3 $.\n- If $ n \\equiv 2 \\mod 3 $, then $ n^2 \\equiv 4 \\equiv 1 \\mod 3 $, and $ n^2 + 1 \\equiv 2 \\mod 3 $.\n\nIn all cases, $ n^2 + 1 $ is never congruent to 0 modulo 3. That is, $ n^2 + 1 $ is never divisible by 3.\n\nTo confirm this, we can test small values of $ n $ (e.g., $ n = 1 $ to $ n = 10 $) and observe that $ n^2 + 1 $ is not divisible by 3 for any of them. Moreover, a deeper number theory argument shows that the quadratic residues modulo 3 are only 0 and 1, and $ -1 \\mod 3 $ (i.e., 2) is not a quadratic residu

## Add uma resposta ruim


Nesse momento adicionamos uma resposta ruim no lugar da resposta fraca.


In [29]:
cal_reward(query,'I don\'t know how to solve this question.')

-100.0

In [30]:
sampling_reward(weak_answer)

In [31]:
def check(gt,ans):

    gt_label = extract_label(gt)
    ans_label = extract_label(ans)
    # print(gt_label,ans_label)
    if gt_label is None or ans_label is None:
        return False
    if ans_label == gt_label or abs(float(ans_label) - float(gt_label)) < 1e-5:
        return True
    else:
        return False


In [32]:
alpha = 0.45

In [33]:
ucb_bank

{}

In [None]:
def compute_ucb(r_c, N_n, N_c, C):
    return r_c + C * math.sqrt(math.log(N_n + 1) / (N_c + 1e-5))

In [None]:
def update_ucb(fathers, childs, to_explore, to_explore_reward, ucb_bank, C=1.4,gamma=0.85):
    # 计算所有节点的访问次数
    visit_count = {node: len(to_explore_reward[node]) for node in to_explore}

    # 计算所有节点的平均奖励
    # avg_reward = {node: sum(to_explore_reward[node]) / len(to_explore_reward[node]) for node in to_explore}
    avg_reward = {node: (min(to_explore_reward[node]) + np.mean(to_explore_reward[node])) / 2 for node in to_explore}

    # 获取所有叶子节点
    leaves = set(to_explore) - set(fathers.values())
    
    # 更新所有叶子节点的UCB值
    for leaf in leaves:
        # ucb_bank[leaf] = avg_reward[leaf]
        ucb_bank[leaf] = compute_ucb(avg_reward[leaf],len(to_explore_reward.get(fathers.get(leaf,None),[])),len(to_explore_reward.get(leaf,[])),C)
    
    # 从叶子节点向上更新父节点的UCB值
    nodes_to_update = list(leaves)
    while nodes_to_update:
        new_nodes_to_update = set()
        for node in nodes_to_update:
            father = fathers.get(node)
            if father is not None:
                if father not in ucb_bank:
                    new_nodes_to_update.add(father)
                if father in ucb_bank:
                    # 计算父节点的UCB值
                    ucb_values = []
                    child_reward = []
                    for child in childs[father]:
                        ucb_values.append(ucb_bank[child])
                        child_reward.append(avg_reward[child])
                    father_reward = (avg_reward[father] + max(child_reward))/2
                    ucb_bank[father] = compute_ucb(father_reward,len(to_explore_reward.get(fathers.get(father,None),[])),len(to_explore_reward.get(father,[])),C)
        nodes_to_update = list(new_nodes_to_update)


In [None]:
def compute_ucb(r_c, N_n, N_c, C):
    return r_c + C * math.sqrt(math.log(N_n + 1) / (N_c + 1e-5))

In [None]:
def filter_mature_node(childs, to_explore, to_explore_reward,max_expand=3):
    filterd_to_explore = []
    avg_reward = {node: (min(to_explore_reward[node]) + np.mean(to_explore_reward[node])) / 2 for node in to_explore}

    for node in to_explore:
        if len(childs.get(node,[])) < max_expand or max([avg_reward.get(child,-999) for child in childs.get(node,[])]) < avg_reward.get(node,-999):
            filterd_to_explore.append(node)
    
    return filterd_to_explore

In [35]:
import math

def ucb(mean, C, N, n):
    return mean + C*math.sqrt(math.log(N)/n)

mean = 0.5
for n in [1,5,10,50,100]:
    print('N=100, C=1.4, n=', n, '->', ucb(mean,1.4,100,n))

for N in [10,100,1000,10000]:
    print('n=10, C=1.4, N=', N, '->', ucb(mean,1.4,N,10))

for C in [0.5,1.0,1.4,2.0,3.0]:
    print('N=1000, n=10, C=', C, '->', ucb(mean,C,1000,10))

N=100, C=1.4, n= 1 -> 3.504352436805086
N=100, C=1.4, n= 5 -> 1.8435872554126627
N=100, C=1.4, n= 10 -> 1.4500596594181157
N=100, C=1.4, n= 50 -> 0.9248795962278409
N=100, C=1.4, n= 100 -> 0.8004352436805087
n=10, C=1.4, N= 10 -> 1.1717936277063314
n=10, C=1.4, N= 100 -> 1.4500596594181157
n=10, C=1.4, N= 1000 -> 1.663580695388377
n=10, C=1.4, N= 10000 -> 1.8435872554126627
N=1000, n=10, C= 0.5 -> 0.9155645340672776
N=1000, n=10, C= 1.0 -> 1.3311290681345551
N=1000, n=10, C= 1.4 -> 1.663580695388377
N=1000, n=10, C= 2.0 -> 2.1622581362691102
N=1000, n=10, C= 3.0 -> 2.993387204403665


In [36]:
import math

In [45]:
import pandas as pd, math

def compute_ucb(mean: float, C: float, N: int, n_i: int) -> float:
    """
    Calcula o valor de Upper Confidence Bound (UCB) para um braço/alternativa.

    Parameters
    ----------
    mean : float
        Média empírica das recompensas desse braço.
    C : float
        Constante de exploração.
    N : int
        Número total de interações (visitas) já realizadas.
    n_i : int
        Número de vezes que esse braço foi escolhido.

    Returns
    -------
    float
        Valor UCB.
    """
    return mean + C * math.sqrt(math.log(N) / n_i)

# ----- Exemplo prático: restaurantes X e Y -----
C = 1
avg_X = 4.2   # média observada do restaurante X
avg_Y = 3.8   # média observada do restaurante Y
n_Y  = 2      # Y foi visitado 2 vezes

rows = []
for n_X in range(2, 21):          # considerar de 2 a 20 visitas ao X
    N = n_X + n_Y                 # total de visitas
    ucb_X = compute_ucb(avg_X, C, N, n_X)
    ucb_Y = compute_ucb(avg_Y, C, N, n_Y)
    chosen = "X" if ucb_X >= ucb_Y else "Y"
    rows.append(
        {"n_X": n_X, "n_Y": n_Y, "N": N,
         "UCB_X": round(ucb_X, 2), "UCB_Y": round(ucb_Y, 2),
         "Escolhido": chosen}
    )

df = pd.DataFrame(rows)
df

Unnamed: 0,n_X,n_Y,N,UCB_X,UCB_Y,Escolhido
0,2,2,4,5.03,4.63,X
1,3,2,5,4.93,4.7,X
2,4,2,6,4.87,4.75,X
3,5,2,7,4.82,4.79,X
4,6,2,8,4.79,4.82,Y
5,7,2,9,4.76,4.85,Y
6,8,2,10,4.74,4.87,Y
7,9,2,11,4.72,4.89,Y
8,10,2,12,4.7,4.91,Y
9,11,2,13,4.68,4.93,Y
