# LATS for Ollama

Here is my primitive implementation of LATS using basic Langchain functionalities alongside with a simpler structure. I built this LATS structure to be as modifiable as possible, with different prompt schemas and value functions available.

In [55]:
import getpass
import os


def _set_if_undefined(var: str) -> None:
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(var)
#Configure Tavily API key here
_set_if_undefined("TAVILY_API_KEY")
#You won't need this if you won't use OpenAI LLMs
_set_if_undefined("OPENAI_API_KEY")

## Definitions
Here is the tree itself. Much of the code was taken / inspired from https://github.com/langchain-ai/langgraph/blob/main/examples/lats/lats.ipynb

Also check out this article: https://medium.com/pythoneers/power-up-ollama-chatbots-with-tools-113ed8229a7a

In [97]:
from __future__ import annotations

import math
from typing import Optional
import time

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from collections import deque
from langchain_core.pydantic_v1 import BaseModel, Field


class Node:
    def __init__(self, 
                 messages: list[BaseMessage],
                 reflection,
                 parent: Optional[Node] = None,
                 exploration_weight: Optional[float] = 1.0,
                 ):
        self.messages = messages
        self.reflection = reflection
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0
        self.exploration_weight = exploration_weight
        self.depth = parent.depth + 1 if parent is not None else 1
        self._is_solved = reflection.found_solution if reflection else False
        self.backpropagate(reflection.normalized_score)
    
    def __repr__(self):
        return (
            f"<Node average value={self.avg_value}, visits={self.visits},"
            f" solution={self.messages} reflection={self.reflection}/>"
        )

    @property
    def avg_value(self):
        """Average value. Stands for win/total visits ratio in a normal MCTS"""
        return self.value if self.visits == 0 else self.visits/self.value

    @property
    def upper_confidence_bound(self):
        """The most integral part of MCTS. Balances exploration by choosing nodes with high value and low exploration depending on the exploration_weight"""
        if self.parent is None:
            #raise ValueError("Cannot obtain UCT from root node")
            return 0
        if self.visits == 0:
            return self.value
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return self.avg_value + self.exploration_weight * exploration_term
    
    @property
    def is_tree_solved(self):
        """Determines if the tree is solved."""
        return any(node._is_solved for node in self.all_children)
    
    @property
    def solved_node(self):
        """Returns the node who solved the prompt."""
        return next((node for node in self.all_children if node._is_solved), None)
    
    def backpropagate(self, value: float):
        node = self
        while node:
            node.visits += 1
            node.value += value
            node = node.parent

    @property
    def all_children(self):
        #taken directly
        all_nodes = []
        nodes = deque()
        nodes.append(self)
        all_nodes.append(self)
        while nodes:
            node = nodes.popleft()
            all_nodes.extend(node.children)
            for n in node.children:
                nodes.append(n)
        return all_nodes

    @property
    def all_messages(self):
        node = self
        message_list = []
        while node.depth > 1:            
            message_list.extend([node.reflection.as_message(), node.messages])
            node = node.parent
        return message_list[::-1]
    
    @property 
    def best_child(self):
        """Gets the best children for the next step of MCTS iteration"""
        #A heap could be used instead of searching the whole structure all the time
        return max(self.all_children, key=lambda child: child.upper_confidence_bound)
    

# LLM & Tools
We define the LLM model and the tools in use (Tavily for this example) here.

In [102]:
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langchain_community.llms import Ollama
from langchain_openai import ChatOpenAI

llm = Ollama(model="llama3", format="json") #ChatOpenAI(model="gpt-4o") 
tool_llm = llm
reasoning_llm = llm
reflection_llm = llm
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)
tools = [tavily_tool]

# Prompt Templates & Prompt Models
To make LATS as accessible as possible, we avoid using API specific methods. This means we have to implement pydantic parsing, tool calling ourselves. 
We will have 3 chains to simulate LATS alongside with tool calls.
- Tool Chain: Here we will define our tavily tool (or more!) and get the query.
- Reasoning Chain: We feed the tool output into the LLM to answer the prompt itself.
- Reflection Chain: The LLM rates the prompt quality in the end and the node gets created.

In [101]:
from langchain.tools.render import render_text_description 
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.output_parsers import PydanticOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from operator import itemgetter
import json

class ToolCall(BaseModel):
    tool_name: str = Field(
        description="The name of the tool being called"
    )
    args: dict[str,str] = Field(
        description="Arguments for the tool being called in the form of a dictionary"
    )

#taken directly
class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        gte=0,
        lte=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        return HumanMessage(
            content=f"Reasoning: {self.reflections}\nScore: {self.score}"
        )

    @property
    def normalized_score(self) -> float:
        return self.score / 10.0

tool_parser = PydanticOutputParser(pydantic_object=ToolCall)
tool_format_instructions = tool_parser.get_format_instructions()
tool_descriptions = render_text_description(tools)
tool_system_prompt = """You are an assistant that has access to the following set of tools.
Here are the names and descriptions for each tool:
{tool_descriptions}
You must specify the tool name and give the method input variables into args as follows:
{tool_format_instructions}
"""

tool_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            tool_system_prompt,
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)
def tool_runner(model_output : ToolCall):
    tool_map = {tool.name: tool for tool in tools}
    chosen_tool = tool_map[model_output.tool_name]
    return chosen_tool(model_output.args)

tool_chain = tool_prompt | tool_llm | tool_parser | tool_runner
#tool_chain.invoke({input : "search me about lithium"})
#temp = tool_chain.invoke({
#    "input": "search me about lithium", 
#    "tool_format_instructions": tool_format_instructions, 
#    "tool_descriptions": tool_descriptions
#})

reasoning_system_prompt = """
Think step by step. You are an AI assistant that will answer the given question as a detailed paragraph on the given tool info, your previous thoughts and reflections.
"""
reasoning_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            reasoning_system_prompt,
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)

reasoning_chain = reasoning_prompt | reasoning_llm
reflection_parser = PydanticOutputParser(pydantic_object=Reflection)
reflection_format_instructions = reflection_parser.get_format_instructions()
reflection_system_prompt = """
Criticize the given line of thoughts and score it.
You must answer in this format as a json:
{reflection_format_instructions}
"""

reflection_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            reflection_system_prompt,
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)
reflection_chain = reflection_prompt | reflection_llm | reflection_parser
#usage: reflection_chain.invoke({"input": "user input", "reflection_format_instructions": "put the pydantic format here"})

# Tree States
We will handle the tree states here.

In [103]:
class LATS:
    def __init__(self,
                 user_prompt: str,
                 max_depth: Optional[int] = 5,
                 leaf_number: Optional[int] = 5,
                 exploration_weight: Optional[float] = 1.0,
                 max_tries: Optional[int] = 3
                 ):
            self.user_prompt = user_prompt
            self.max_depth = max_depth
            self.leaf_number = leaf_number
            self.max_tries = max_tries
            self.exploration_weight = exploration_weight
            init_messages = [HumanMessage(content = user_prompt),]
            init_reflection = Reflection(reflections = "", score = 0, found_solution = False) #Dummy values for root node
            #In theory, these dummy values shouldn't have any effect even though we start with 1 visit in root.
            self.root_node = Node(init_messages, init_reflection, exploration_weight = self.exploration_weight)

    def invoke(self):
        """Runs the tree."""
        ended = False
        iteration = 1
        while not ended:
          print("Iteration ", iteration)
          status = self.expand()
          iteration += 1
          if(status != 'CONTINUE'):
            ended = True
        self.print_solution()

    def expand(self):
          """Expands the best child."""
          node = self.root_node.best_child
          if(node.depth >= self.max_depth):
                return "END"
          #add parallelization here
          for i in range(self.leaf_number):
            print("Node iteration:", i + 1)
            messages = node.all_messages
            tries = 0
            new_node_messages = []  # messages for the node we will initialize

            # Attempt to invoke tool_chain
            while tries < self.max_tries:
                try:
                    tool_message = tool_chain.invoke({
                        "input": self.user_prompt, 
                        "tool_format_instructions": tool_format_instructions, 
                        "tool_descriptions": tool_descriptions,
                        "messages": messages,
                    })
                    break
                except Exception as e:
                    print(e)
                    tries += 1  # Increment tries
            else:
                print("Max tries exceeded for tool_chain")
                return "ENDED"

            tool_message = ToolMessage(content=tool_message, tool_call_id=0)  # call id not important for us
            messages.append(tool_message)

            # Reset tries for reasoning_chain
            tries = 0

            # Attempt to invoke reasoning_chain
            while tries < self.max_tries:
                try:
                    reasoning_message = reasoning_chain.invoke({
                        "input": self.user_prompt,
                        "messages": messages
                    })
                    break
                except Exception as e:
                    print(e)
                    tries += 1  # Increment tries
            else:
                print("Max tries exceeded for reasoning_chain")
                return "ENDED"

            reasoning_message = AIMessage(content=reasoning_message)
            messages.append(reasoning_message)

            # Reset tries for reflection_chain
            tries = 0

            # Attempt to invoke reflection_chain
            while tries < self.max_tries:
                try:
                    reflection_message = reflection_chain.invoke({
                        "input": self.user_prompt,
                        "reflection_format_instructions": reflection_format_instructions,
                        "messages": messages    
                    })
                    break
                except Exception as e:
                    print(e)
                    tries += 1  # Increment tries
            else:
                print("Max tries exceeded for reflection_chain")
                return "ENDED"

            new_node_messages = [tool_message, reasoning_message]
            new_node = Node(new_node_messages, reflection_message, self.root_node, exploration_weight=self.exploration_weight)
            node.children.append(new_node)

            if reflection_message.found_solution:
                return "SOLVED"
          return "CONTINUE"


    def print_solution(self):
      solution_node = self.root_node.solved_node
      if solution_node is None:
        print("No solution has been found in this tree.")
      else:
        print(solution_node.all_messages)
      



# Using The Tree
Congratulations! You have implemented the tree! Using it is relatively simple, like this:

In [104]:
question = "Give me a summary of what OllamaFunctions are in Langchain and how can I use them?"
tree = LATS(user_prompt = question)
tree.invoke()

Iteration  1
Node iteration: 1
[[ToolMessage(content=[{'url': 'https://medium.com/@mauryaanoop3/unleashing-structured-responses-functional-calling-with-langchain-ollama-and-phi-3-part-3-720b34203778', 'content': "In the previous articles, we explored functional calling with LangChain, Ollama, and Microsoft's Phi-3 model. We focused on functional calling, demonstrating how to interact with the LLM and ..."}, {'url': 'https://blog.langchain.dev/json-based-agents-with-ollama-and-langchain/', 'content': "When the LLM needs to call a function, it should use the following JSON structure:\nThat’s why it is called a JSON-based agent: we instruct the LLM to produce a JSON when it wants to use any available tools. Since one of the available tools of the agent is a recommender tool, it decided to utilize the recommender tool by providing the JSON syntax to define its input. Since the tool provided all the required information, the LLM decided that it had enough information to construct a final an