In [264]:
# %pip install langchain langgraph langchain-community langchain-ollama
# %pip install -q -U google-generativeai

In [265]:
# Export environment
!conda env export > environment.yml

# Create environment
!conda env create -f environment.yml


CondaValueError: prefix already exists: C:\Users\kelly\anaconda3



In [266]:
import math
from collections import deque
from typing import Optional
import typing
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field

In [267]:
class Reflection(BaseModel):
    """Structured critique of a candidate response, as returned by the reflection LLM."""

    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    # `ge`/`le` are pydantic's numeric-bound keywords. The previous `gte`/`lte`
    # were unrecognised extras, so out-of-range scores were silently accepted
    # and skewed `normalized_score` (and therefore UCT) beyond [0, 1].
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        ge=0,
        le=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        """Render this reflection as a HumanMessage so it can be placed in a trajectory."""
        return HumanMessage(
            content=f"Reasoning: {self.reflections}\nScore: {self.score}"
        )

    @property
    def normalized_score(self) -> float:
        """The score rescaled to [0, 1] (``score / 10``); used as the MCTS reward."""
        return self.score / 10.0

class Node:
    """One node of the LATS search tree.

    A node stores the message(s) produced at this step and the ``Reflection``
    that scored them. Constructing a node immediately backpropagates the
    reflection's normalized score through itself and all of its ancestors, and
    if the reflection reports a solution, marks every ancestor as solved.
    Constructing a node does NOT append it to ``parent.children``; the caller
    must do that.
    """

    def __init__(
        self,
        messages: list[BaseMessage],
        reflection: Optional[Reflection],
        parent: Optional["Node"] = None,
    ):
        self.messages = messages
        self.parent = parent
        self.children = []
        # Running mean of every reward backpropagated through this node.
        self.value = 0
        self.visits = 0
        self.reflection = reflection
        self.depth = parent.depth + 1 if parent is not None else 1
        self._is_solved = reflection.found_solution if reflection is not None else False
        if self._is_solved:
            self._mark_tree_as_solved()
        # A node without a reflection has no reward to propagate. Previously the
        # None check above was bypassed here and construction crashed.
        if reflection is not None:
            self.backpropagate(reflection.normalized_score)

    def __repr__(self) -> str:
        return (
            f"<Node value={self.value}, visits={self.visits},"
            f" solution={self.messages} reflection={self.reflection}/>"
        )

    @property
    def is_solved(self):
        """True if this node or any descendant was reported as a solution; the search can end."""
        return self._is_solved

    @property
    def is_terminal(self):
        """True if this node is a leaf (has no children yet)."""
        return not self.children

    @property
    def best_child_score(self):
        """Return the solved child with the highest value, or None if there are no children.

        Unsolved children get a key of 0, so if no child is solved the first
        child is returned.
        """
        if not self.children:
            return None
        return max(self.children, key=lambda child: int(child.is_solved) * child.value)

    @property
    def height(self) -> int:
        """Number of levels in the subtree rooted here (a leaf has height 1)."""
        if self.children:
            return 1 + max(child.height for child in self.children)
        return 1

    def upper_confidence_bound(self, exploration_weight=1.0):
        """Return the UCT score, balancing exploitation and exploration of this branch.

        UCT = value / visits + exploration_weight * sqrt(ln(parent.visits) / visits).
        An unvisited node returns its raw ``value``.

        Raises:
            ValueError: if called on the root (it has no parent).
        """
        if self.parent is None:
            raise ValueError("Cannot obtain UCT from root node")
        if self.visits == 0:
            return self.value
        # Encourages exploitation of high-value trajectories
        average_reward = self.value / self.visits
        # Encourages exploration of less-visited trajectories
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration_term

    def backpropagate(self, reward: float):
        """Fold ``reward`` into the running-mean value of this node and every ancestor."""
        node = self
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            node = node.parent

    def get_messages(self, include_reflections: bool = True):
        """Return this node's messages, optionally followed by its reflection as a HumanMessage."""
        if include_reflections and self.reflection is not None:
            return self.messages + [self.reflection.as_message()]
        return self.messages

    def get_trajectory(self, include_reflections: bool = True) -> list[BaseMessage]:
        """Get the messages of the branch from the root down to this node, in order."""
        messages = []
        node = self
        while node:
            # Each node's messages are added reversed while walking upward so a
            # single final reversal restores root-to-leaf order.
            messages.extend(
                node.get_messages(include_reflections=include_reflections)[::-1]
            )
            node = node.parent
        # Reverse the final back-tracked trajectory to return in the correct order
        return messages[::-1]  # root solution, reflection, child 1, ...

    def _get_all_children(self):
        """Return every descendant of this node in breadth-first order (excluding self)."""
        all_nodes = []
        nodes = deque()
        nodes.append(self)
        while nodes:
            node = nodes.popleft()
            all_nodes.extend(node.children)
            for n in node.children:
                nodes.append(n)
        return all_nodes

    def get_best_solution(self):
        """Return the highest-value solved leaf within this subtree.

        Non-leaf and unsolved nodes get a key of 0, so if no solved leaf exists
        ``self`` (the first candidate) is returned.
        """
        all_nodes = [self] + self._get_all_children()
        best_node = max(
            all_nodes,
            # We filter out all non-terminal, non-solution trajectories
            key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
        )
        return best_node

    def _mark_tree_as_solved(self):
        """Mark every ancestor of this node as solved."""
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent

In [268]:
from typing_extensions import TypedDict

# State dictionary threaded through the LATS loop:
#   root  -- the root Node of the full search tree
#   input -- the original query text
TreeState = TypedDict(
    "TreeState",
    {
        "root": Node,
        "input": str,
    },
)

In [269]:
import os

import google.generativeai as genai

# SECURITY: never hard-code API keys in a notebook. They end up in version
# control and in shared copies. Read the key from the environment instead,
# e.g. `set GOOGLE_API_KEY=...` (Windows) before launching Jupyter.
# The key that was previously hard-coded here is exposed and should be revoked.
_google_api_key = os.environ.get("GOOGLE_API_KEY")
if not _google_api_key:
    raise RuntimeError(
        "Set the GOOGLE_API_KEY environment variable before running this notebook."
    )
genai.configure(api_key=_google_api_key)

# Response schema passed to Gemini (``response_schema``) so that reflections
# come back as JSON objects with exactly these keys, matching ``Reflection``.
Reflection_Schema = TypedDict(
    "Reflection_Schema",
    {
        "reflections": str,
        "score": typing.Annotated[int, 0, 10],
        "found_solution": bool,
    },
)

def get_gemini_response(prompt, model_name="gemini-1.5-flash"):
    """Send ``prompt`` to a Gemini model and return the raw response.

    Args:
        prompt: Content accepted by ``GenerativeModel.generate_content``
            (callers in this notebook pass a string).
        model_name: Gemini model to query. Defaults to the model used
            throughout this notebook.

    Returns:
        The ``generate_content`` response, or None if the call raised. In that
        case the error is printed, so callers must check for None before
        reading ``.text``.
    """
    try:
        model = genai.GenerativeModel(model_name)
        response = model.generate_content(prompt)
        return response
    except Exception as e:
        print(f"Error querying Gemini model: {e}")
        return None
    
def get_gemini_reflection(system, prompt, model_name="gemini-1.5-flash", schema=Reflection_Schema):
    """Query Gemini for a JSON response that follows ``schema``.

    Args:
        system: System instruction for the model.
        prompt: User content sent to ``generate_content``.
        model_name: Gemini model to query. Defaults to the notebook's model.
        schema: Response schema. Defaults to ``Reflection_Schema``; pass
            another TypedDict (e.g. a grading schema) to reuse this helper.

    Returns:
        The ``generate_content`` response (JSON text in ``.text``), or None if
        the call raised. In that case the error is printed.
    """
    try:
        model = genai.GenerativeModel(model_name=model_name, system_instruction=system)
        response = model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json",
                response_schema=schema,
            ),
        )
        return response
    except Exception as e:
        print(f"Error querying Gemini model: {e}")
        return None


In [270]:
from langchain_core.runnables import chain as as_runnable

@as_runnable
def reflection_chain(inputs) -> Reflection:
    """Ask Gemini to critique a candidate response and return a ``Reflection``.

    Args:
        inputs: dict with ``"input"`` (the original query string) and
            ``"candidate"`` (list of messages whose contents are concatenated
            into the response being judged).

    Returns:
        A ``Reflection``. If Gemini returns nothing, its response cannot be
        read, or the JSON does not validate as a ``Reflection``, a zero-score
        placeholder is returned instead. ``found_solution`` is forced to False
        unless the last candidate message is an ``AIMessage``.
    """
    import json

    prompt = "Query: " + inputs["input"] + "\nResponse:"
    for candidate in inputs["candidate"]:
        prompt += candidate.content
        prompt += "\n"
    print(prompt)

    response = get_gemini_reflection(
        "You are an AI assistant tasked with reflecting on responses.", prompt
    )
    try:
        # get_gemini_reflection returns None on API errors, and `.text` itself
        # raises on blocked/empty responses. Both used to happen outside this
        # try block, so the fallback below never ran.
        if response is None:
            raise ValueError("no response from Gemini")
        raw_text = response.text
        print(raw_text)
        reflection = Reflection(**json.loads(raw_text))
    except Exception as e:
        print(f"Could not obtain a valid reflection: {e}")
        reflection = Reflection(
            reflections="Parsing error in reflection.",
            score=0,
            found_solution=False,
        )
    if not isinstance(inputs["candidate"][-1], AIMessage):
        reflection.found_solution = False
    return reflection

In [271]:
def generate_initial_response(state: TreeState) -> dict:
    """Generate and score the first candidate, storing it as the tree root.

    Args:
        state: TreeState with ``"input"`` set. ``"root"`` is (over)written.

    Returns:
        The same ``state`` dict, mutated in place with the new root ``Node``.

    Raises:
        RuntimeError: if Gemini returned no response (the underlying error has
            already been printed by ``get_gemini_response``).
    """
    print("Generating initial response...")
    res = get_gemini_response(state["input"])
    if res is None:
        # get_gemini_response swallows API errors and returns None; fail with a
        # clear message instead of an opaque AttributeError on `.text`.
        raise RuntimeError("Gemini returned no initial response; see the error printed above.")
    output_msg = [AIMessage(res.text)]
    print("Initial response generated.")

    reflection = reflection_chain.invoke(
        {"input": state["input"], "candidate": output_msg}
    )
    print("Reflection on initial response completed.")

    root = Node(output_msg, reflection=reflection)
    state["root"] = root
    return state

In [272]:
def select(root: Node) -> Node:
    """Walk from ``root`` down to a leaf, following the highest-UCT child at each level.

    When several children share the top UCT score, the earliest one wins.
    """
    current = root
    while True:
        if not current.children:
            return current
        current = max(current.children, key=lambda candidate: candidate.upper_confidence_bound())

In [273]:
def generate_candidates(state: TreeState, config: RunnableConfig):
    """Sample N new candidate responses conditioned on the most promising branch.

    The branch is chosen with ``select`` (UCT descent from the root). Its
    trajectory (earlier candidates interleaved with their reflections) is
    rendered as plain text ahead of the original task so the model can
    improve on prior attempts.

    Args:
        state: TreeState with ``"root"`` and ``"input"``.
        config: RunnableConfig. ``config["configurable"]["N"]`` sets the number
            of candidates (default 3).

    Returns:
        list[AIMessage], one per generated candidate.

    Raises:
        RuntimeError: if a Gemini call returned no response.
    """
    n = config.get("configurable", {}).get("N", 3)
    best_candidate = select(state["root"])
    trajectory = best_candidate.get_trajectory()
    # generate_content cannot turn LangChain message objects into content parts
    # (the previous nested list was rejected with a TypeError), so serialize the
    # trajectory to plain text with role labels.
    history = "\n\n".join(f"[{message.type}] {message.content}" for message in trajectory)
    prompt = (
        "Previous attempts and feedback:\n"
        f"{history}\n\n"
        "Task:\n"
        f"{state['input']}"
    )

    candidates = []
    for _ in range(n):
        res = get_gemini_response(prompt)
        if res is None:
            raise RuntimeError("Gemini returned no candidate; see the error printed above.")
        candidates.append(AIMessage(content=res.text))
    return candidates

In [274]:
def expand(state: TreeState, config: RunnableConfig) -> dict:
    """Run one LATS iteration: select a leaf, sample children, score them, and attach them.

    Args:
        state: TreeState with ``"root"`` and ``"input"``.
        config: RunnableConfig forwarded to ``generate_candidates``.

    Returns:
        The same ``state``. The tree is grown in place under the selected leaf.
    """
    # select() is deterministic and the tree is not modified until the new
    # children are attached below, so this is the same leaf that
    # generate_candidates conditions on.
    best_candidate = select(state["root"])
    print(f"Expanding node at depth {best_candidate.depth}...")

    candidates = generate_candidates(state, config)
    print(f"{len(candidates)} candidates generated.")

    # Reflect on each candidate
    reflections = []
    for idx, candidate in enumerate(candidates, start=1):
        print(f"Reflecting on candidate {idx}/{len(candidates)}...")
        reflections.append(
            reflection_chain.invoke({"input": state["input"], "candidate": [candidate]})
        )
    print("All reflections completed.")

    # Grow tree (Node.__init__ backpropagates each child's score to its ancestors)
    child_nodes = [
        Node([candidate], parent=best_candidate, reflection=reflection)
        for candidate, reflection in zip(candidates, reflections)
    ]
    best_candidate.children.extend(child_nodes)
    return state

In [275]:
# Loop control function
def should_loop(state: TreeState, N: int = 5):
    """Decide whether the tree search should run another expansion.

    Returns False (and prints why) once the root has been marked solved, or
    once the tree height has reached ``N``. Otherwise returns True.
    """
    tree = state["root"]
    stop_reason = None
    if tree.is_solved:
        stop_reason = "Solution found."
    elif tree.height >= N:
        stop_reason = "Maximum depth reached."
    if stop_reason is not None:
        print(stop_reason)
    return stop_reason is None

In [276]:
def _collect_leaves(root):
    """Return every childless node in the subtree rooted at ``root``."""
    leaves, stack = [], [root]
    while stack:
        node = stack.pop()
        if node.children:
            stack.extend(node.children)
        else:
            leaves.append(node)
    return leaves


def lats(state: TreeState, N=3, max_depth: Optional[int] = None):
    """Run Language Agent Tree Search on ``state["input"]`` and return the chosen answer.

    Args:
        state: TreeState with ``"input"`` set. ``"root"`` is created here.
        N: Number of candidates sampled per expansion.
        max_depth: Stop once the tree height reaches this value. Defaults to
            ``N``, which preserves the original behavior where one value
            controlled both the branching factor and the depth limit.

    Returns:
        The final message of the selected trajectory (reflections excluded).
    """
    print("Starting LATS...")
    if max_depth is None:
        max_depth = N
    state = generate_initial_response(state)
    config = RunnableConfig(configurable={"N": N})

    iteration = 0
    while should_loop(state, max_depth):
        iteration += 1
        print(f"Iteration: {iteration}")
        state = expand(state, config)

    # After search, get best solution
    print("Search complete.")
    root = state["root"]
    solution_node = root.get_best_solution()
    if not solution_node.is_solved:
        # With no solved leaf, get_best_solution scores every node 0 and returns
        # the root, discarding the whole search. Use the best-scoring leaf.
        solution_node = max(_collect_leaves(root), key=lambda node: node.value)
    best_trajectory = solution_node.get_trajectory(include_reflections=False)
    return best_trajectory[-1]


In [278]:
import json

# Number of dataset queries to run. Kept small to avoid the
# "429 Resource has been exhausted (e.g. check quota)" error.
MAX_QUERIES = 5

# JSON is UTF-8 by specification. Be explicit so the Windows locale codec
# (e.g. cp1252) is not used to decode the dataset.
with open("dataset/task_1_plan_generation.json", "r", encoding="utf-8") as file:
    data = json.load(file)
queries = [instance["query"] for instance in data["instances"]]

answers = []
responses = []
for query in queries[:MAX_QUERIES]:
    # The text before the first "[PLAN]" marker is the LATS input. The segment
    # after it, up to the next "[STATEMENT]", is kept as the reference answer.
    initial = query.split("[PLAN]")
    state = {"input": initial[0]}

    system = initial[1].split("[STATEMENT]")
    answers.append(system[0])

    final_response = lats(state, N=3)
    responses.append(final_response.content)

Starting LATS...
Generating initial response...
Initial response generated.
Query: I am playing with a set of blocks where I need to arrange the blocks into stacks. Here are the actions I can do

Pick up a block
Unstack a block from on top of another block
Put down a block
Stack a block on top of another block

I have the following restrictions on my actions:
I can only pick up or unstack one block at a time.
I can only pick up or unstack a block if my hand is empty.
I can only pick up a block if the block is on the table and the block is clear. A block is clear if the block has no other blocks on top of it and if the block is not picked up.
I can only unstack a block from on top of another block if the block I am unstacking was really on top of the other block.
I can only unstack a block from on top of another block if the block I am unstacking is clear.
Once I pick up or unstack a block, I am holding the block.
I can only put down a block that I am holding.
I can only stack a block o

In [None]:
class Answer(TypedDict):
    """JSON schema for the grader's verdict on one response."""
    correct: bool


system_prompt = "You are a teacher that must grade student responses. Given their response and the correct answer, mark it as either correct (True) or incorrect (False)."
# The grader model is identical for every pair, so build it once, outside the loop.
grader = genai.GenerativeModel(model_name="gemini-1.5-flash", system_instruction=system_prompt)

score = 0
for answer, response in zip(answers, responses):
    prompt = "Response: " + response + "\nCorrect Answer: " + answer
    try:
        response_result = grader.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json",
                response_schema=Answer,
            ),
        )
        correctness = json.loads(response_result.text)
        # The schema declares a JSON boolean; count only an explicit true.
        if correctness["correct"] is True:
            score += 1
    except Exception as e:
        print(f"Error querying Gemini model: {e}")

print(f"Score: {score}/{len(responses)}")

True
Expected:
unstack the blue block from on top of the orange block
put down the blue block
pick up the orange block
stack the orange block on top of the blue block
[PLAN END]


Actual: Here's a plan to achieve the goal, given the initial conditions and actions allowed:

**Plan:**

1. **Pick up the blue block:** The blue block is clear and on the table (although technically on top of the orange block, the rules allow unstacking it before picking up).  The hand is empty, so this is allowed.

2. **Put down the blue block:** The hand is not empty, and the blue block is held. We put it down on the table.

3. **Pick up the orange block:** The orange block is now clear and on the table. The hand is empty.

4. **Stack the orange block on top of the blue block:** The hand is not empty (holding the orange block), and the blue block is clear.

**Result:** The orange block is now on top of the blue block, achieving the goal.


**Why the initial statement of picking up the blue block as the firs