# Set up Ollama for local inference


In [1]:
! pip install ollama


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import ollama
from ollama import ChatResponse

# Ollama model tag passed to ollama.chat() by get_llm_response below.
# Alternative model (uncomment to switch):
# MODEL = "deepseek-r1:1.5b"
MODEL = "llama3.2:3b"

def get_llm_response(prompt: str, model: "str | None" = None) -> str:
    """Send a single-turn user prompt to the local Ollama server and return the reply text.

    Args:
        prompt: The user message.
        model: Ollama model tag. Defaults to the module-level ``MODEL``, looked up at
            call time so re-assigning ``MODEL`` in a later cell takes effect.

    Returns:
        The assistant message content, or ``""`` if the server returned no content.
    """
    response: ChatResponse = ollama.chat(
        model=MODEL if model is None else model,
        messages=[
            {
                'role': 'user',
                'content': prompt
            }
        ]
    )
    # The client models `message.content` as optional; normalise to str so callers
    # (which .split()/regex the result) never receive None.
    return response.message.content or ""

# Smoke test: one round-trip to the local model.
prompt = 'Hello! write me a 2 liner poem about a dog?'
print(get_llm_response(prompt=prompt))


Here is a 2-line poem about a dog:

With wagging tail and loving eyes,
Faithful companion, always by our side.


# Chain-of-Thought (CoT) with Reflection

In [3]:
from typing import Tuple

def get_cot_response(question: str) -> Tuple[str, str]:
    """Ask the model to reason, verify, and answer; extract the final answer.

    Args:
        question: The question to answer.

    Returns:
        ``(final_answer, full_response)``. ``final_answer`` is the text after the
        LAST ``Answer:`` marker in the reply (leading markdown ``*`` removed). If the
        model omitted the marker, it is the whole stripped reply.
    """
    prompt = f"""
question: `{question}`
-------------------
Answer the given question, be critical in your thinking process.
Provide a detailed reasoning process, and a verification process.
Respond in the following format:

Reasoning:
detailed reasonging process

Verification:
Cross checking and reflection process

Answer:
"""
    response = get_llm_response(prompt)
    # rpartition picks the final "Answer:" section even if the marker appears
    # earlier in the reply, and does not IndexError when it is missing entirely.
    _, marker, tail = response.rpartition("Answer:")
    if marker:
        # Drop emphasis left over from a markdown header such as "**Answer:**".
        final_answer = tail.strip().lstrip("*").strip()
    else:
        final_answer = response.strip()
    return final_answer, response


In [5]:
# Multi-step arithmetic word problem to exercise the CoT prompt.
question = "A farmer has 17 chickens and 12 cows. If each chicken lays 2 eggs per day and each cow produces 10 liters of milk per day, how many eggs and liters of milk will the farmer collect in total in one day?"
answer, llm_resp = get_cot_response(question=question)
print(answer)

The farmer will collect a total of 154 (34 eggs + 120 liters) eggs and liters of milk in one day.


# Monte Carlo Tree Search Implementation

### Helper Functions
1. Critique
2. Improve
3. Rate / Scoring

In [6]:
def get_critique(question: str, answer: str) -> str:
    """Ask the LLM to critique ``answer`` for ``question`` without giving its own answer.

    Returns:
        The raw critique text (assessment bullets, suggested improvements, conclusion).
    """
    critique_prompt = f"""
question: `{question}`
answer: `{answer}`
-------------------
Please critique the answer carefully.
Assess for correctness, and point out any errors.
Be verbose and go step by step.
Do not give any answer, just suggest a few obvious improvements.
Give bullets of Assessment and suggested improvements, followed by the conclusion.
"""
    critique_text = get_llm_response(critique_prompt)
    return critique_text


# print(get_critique("What is the capital of India?", "Bangalore is the Capital of India."))

In [7]:
def improve_answer(question: str, answer: str, critique: str) -> str:
    """Ask the LLM to rewrite ``answer`` for ``question`` in light of ``critique``.

    Returns:
        The full LLM reply (Reasoning / Verification / Corrected Answer sections),
        unparsed.
    """
    improvement_prompt = f"""
question: `{question}`
answer: `{answer}`
critique: `{critique}`
-------------------

Improve the answer based on the critique. Respond in the following format:

Reasoning:
detailed reasonging process

Verification:
Cross checking and reflection process

Corrected Answer:
Improved answer
"""
    improved = get_llm_response(improvement_prompt)
    return improved

# print(improve_answer("What is the capital of India?", "India's capital is Bangalore.", "The answer is incorrect because it is not the capital of India."))

In [8]:
import re
def rate_answer(question: str, answer: str) -> float:
    """Have the LLM grade ``answer`` for ``question`` and return the score scaled to [0, 1].

    The model is asked to finish with ``Score: <0-100>``. The LAST such score in the
    reply is used. Markdown emphasis (``**Score:** 85``) and decimal scores are accepted.

    Returns:
        ``score / 100``, or ``0.0`` if no score is found or it lies outside 0-100.
        In the failure case an ``Error: ...`` message is printed first.
    """
    prompt = f"""
question: `{question}`
answer: `{answer}`
-------------------

You are an expert in evaluating student answers.
Go throught the answer carefully and give a score between 0 and 100.
Take into account the explanation steps as well for evaluation, and give partial marks for correct steps.
Respond in the following format:

Thought: <verification steps and critique>

Score: <Score between 0 and 100>
"""
    response = get_llm_response(prompt)
    # Allow ':' / '*' / whitespace between "Score" and the number so "**Score:** 85"
    # is parsed. "Score between 0 and 100" does not match.
    scores = re.findall(r"Score[:*\s]*(\d+(?:\.\d+)?)", response)
    if not scores:
        print("Error: Score not found")
        return 0.0
    # The requested format puts the verdict last; earlier matches may come from
    # the Thought section.
    score = float(scores[-1])
    if not 0 <= score <= 100:
        print("Error: Score out of range")
        return 0.0
    return score / 100

# print(rate_answer("What is the capital of India?", "India's capital is Delhi."))

## MCTS code

In [9]:
# Number of children created each time a leaf is expanded (TreeNode.expand).
MAX_CHILDREN = 3
# Number of select -> expand -> rate -> backprop rounds run by TreeSearch.search.
MAX_ITERATIONS = 2
# UCT exploration weight C used in TreeNode.best_child (1.414 ~ sqrt(2), the common UCT choice).
EXPLORATION_CONSTANT = 1.414

# Deliberately uninformative starting answers; one is picked at random for the
# root node, and the search must improve on it via critique/improve rounds.
SEED_ANSWERS = [
    "I do not know",
    "I can not understand",
    "I am quite confused",
]


In [64]:
import random
import math

class TreeNode:
    """A node in the MCTS tree holding one candidate answer to ``question``.

    ``visits`` counts how often a reward was backpropagated through this node.
    ``value`` is the SUM of those rewards, so the mean reward is ``value / visits``.
    """

    def __init__(self, question: str, answer: str, id: str, parent: "TreeNode | None" = None):
        # Ids are built by concatenating child indices onto the parent's id (see expand).
        self.id = id
        self.question = question
        self.answer = answer
        self.parent = parent
        self.children: list["TreeNode"] = []
        self.visits = 0
        self.value = 0.0  # cumulative reward, not a mean
        print(f">> INIT node {self.id}")

    def is_leaf(self) -> bool:
        """Return True if this node has not been expanded yet."""
        return len(self.children) == 0

    def is_root(self) -> bool:
        """Return True if this node has no parent."""
        return self.parent is None

    def is_full(self) -> bool:
        """Return True once this node has at least MAX_CHILDREN children."""
        return len(self.children) >= MAX_CHILDREN

    def add_child(self, child_node: "TreeNode") -> None:
        """Append ``child_node`` to this node's children."""
        self.children.append(child_node)

    def most_visited_child(self) -> "TreeNode":
        """Return the child with the most visits (first one on ties).

        Raises:
            ValueError: if this node has no children.
        """
        return max(self.children, key=lambda x: x.visits)

    def best_child(self, exploration_constant: "float | None" = None) -> "TreeNode":
        """Return the child with the highest UCT score.

        UCT(c) = value_c / visits_c + C * sqrt(ln(visits_parent) / visits_c)

        Unvisited children score +inf, so each child is tried once before any is
        revisited. The exploration term shrinks as a child is visited more, which
        steers selection toward under-explored children.

        Args:
            exploration_constant: C in the formula. Defaults to EXPLORATION_CONSTANT.
        """
        if exploration_constant is None:
            exploration_constant = EXPLORATION_CONSTANT
        # Guard log(0) in case visits were never backpropagated to this node.
        log_parent_visits = math.log(max(self.visits, 1))
        weights = []
        for c in self.children:
            if c.visits == 0:
                weight = float('inf')
            else:
                exploitation = c.value / c.visits
                exploration = math.sqrt(log_parent_visits / c.visits)
                weight = exploitation + (exploration_constant * exploration)
            weights.append(weight)
        max_weight = max(weights)
        best_child = self.children[weights.index(max_weight)]
        print(">>>> Best child of", self.id, "is", best_child.id, "with weight", max_weight)
        return best_child

    def expand(self) -> None:
        """Add MAX_CHILDREN children, each an LLM critique-and-improve rewrite of this answer.

        Every child gets identical prompts, so sibling diversity presumably relies
        on nondeterministic sampling by the model.
        """
        print(f">> Expanding node {self.id}")
        for i in range(MAX_CHILDREN):
            # NOTE: ids concatenate the index, so they are unique only while MAX_CHILDREN <= 10.
            child_node = TreeNode(self.question, self.answer, self.id + str(i), self)
            print(f">>> Child {child_node.id}")
            critique = get_critique(self.question, self.answer)
            improved_answer = improve_answer(self.question, self.answer, critique)
            child_node.answer = improved_answer
            print(f">>> Improved answer: {improved_answer}")
            self.add_child(child_node)

    def print_node(self) -> None:
        """Print this node's id, answer, statistics and child count."""
        print(f">> Node {self.id}")
        print(f">> Answer: {self.answer}")
        print(f">> Visits: {self.visits}")
        print(f">> Value: {self.value}")
        print(f">> Children: {len(self.children)}")
        print("-"*100)

class TreeSearch:
    """Monte Carlo Tree Search over LLM-refined answers to a single question.

    The root holds a random seed answer from SEED_ANSWERS. Each iteration does the following:
    - walks down by UCT (``best_child``) to a leaf;
    - expands that leaf with critique-and-improve children;
    - scores one randomly chosen new child with ``rate_answer``;
    - backpropagates that score up to the root.
    """

    def __init__(self, question: str):
        self.question = question
        self.root = TreeNode(question, random.choice(SEED_ANSWERS), "0")

    def print_tree(self, node: "TreeNode | None" = None) -> None:
        """Print ``node`` and its whole subtree depth-first (defaults to the root)."""
        if node is None:
            node = self.root
        node.print_node()
        for child in node.children:
            self.print_tree(child)

    def select_candidate(self) -> "TreeNode":
        """Descend from the root via ``best_child`` until a leaf is reached."""
        node = self.root
        while not node.is_leaf():
            node = node.best_child()
        return node

    def backprop(self, node: "TreeNode", reward: float) -> None:
        """Add one visit and ``reward`` to ``node`` and every ancestor up to the root."""
        while node is not None:
            node.visits += 1
            node.value += reward
            node = node.parent

    def search(self, max_iterations: "int | None" = None) -> "TreeNode":
        """Run the search and return the root's most-visited child.

        Args:
            max_iterations: Number of rounds; defaults to MAX_ITERATIONS.

        Returns:
            The most-visited child of the root. If the root was never expanded
            (e.g. zero iterations), the root itself is returned instead of raising.
        """
        iterations = MAX_ITERATIONS if max_iterations is None else max_iterations
        for i in range(iterations):
            print(f"> Iteration {i+1} of {iterations}")
            candidate = self.select_candidate()
            if candidate.is_leaf():
                candidate.expand()
                # Rate one of the freshly added children; if expansion produced
                # none (MAX_CHILDREN == 0), fall back to rating the leaf itself.
                if candidate.children:
                    candidate = random.choice(candidate.children)

            print(f"> Selected candidate: {candidate.id}")
            print(f"> Candidate answer: {candidate.answer}")
            rating = rate_answer(candidate.question, candidate.answer)
            print(f"> Rating: {rating}")
            self.backprop(candidate, rating)
        print("-------DONE-------")

        if self.root.is_leaf():
            return self.root
        return self.root.most_visited_child()


In [None]:
# Run MCTS on a sample question and print the answer held by the root's most-visited child.
result = TreeSearch(question="How many R's are there in the word 'strawberry'?").search()
print(result.answer)