In [1]:
from __future__ import annotations

import math
from typing import List, Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from collections import defaultdict, deque


class Node:
    def __init__(
        self,
        messages: List[BaseMessage],
        reflection: Reflection,
        parent: Optional[Node] = None,
    ):
        self.messages = messages
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.reflection = reflection
        self.depth = parent.depth + 1 if parent is not None else 1
        self._is_solved = reflection.found_solution if reflection else False
        if self._is_solved:
            self._mark_tree_as_solved()
        self.backpropagate(reflection.normalized_score)

    def __repr__(self) -> str:
        return (
            f"<Node value={self.value}, visits={self.visits},"
            f" solution={self.messages} reflection={self.reflection}/>"
        )

    @property
    def is_solved(self):
        """If any solutions exist, we can end the search."""
        return self._is_solved

    @property
    def is_terminal(self):
        return not self.children

    @property
    def best_child(self):
        """Select the child with the highest UCT to search next."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.upper_confidence_bound())

    @property
    def best_child_score(self):
        """Return the child with the highest value."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: int(child.is_solved) * child.value)

    @property
    def height(self) -> int:
        """Check for how far we've rolled out the tree."""
        if self.children:
            return 1 + max([child.depth for child in self.children])
        return 1

    def upper_confidence_bound(self, exploration_weight=1.0):
        """Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
        if self.parent is None:
            raise ValueError("Cannot obtain UCT from root node")
        if self.visits == 0:
            return self.value
        # Encourages exploitation of high-value trajectories
        average_reward = self.value / self.visits
        # Encourages exploration of less-visited trajectories
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration_term

    def backpropagate(self, reward: float):
        """Update the score of this node and its parents."""
        node = self
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            node = node.parent

    def get_messages(self, include_reflections: bool = True):
        if include_reflections:
            return self.messages + [self.reflection.as_message()]
        return self.messages

    def get_trajectory(self, include_reflections: bool = True) -> List[BaseMessage]:
        """Get messages representing this search branch."""
        messages = []
        node = self
        while node:
            messages.extend(
                node.get_messages(include_reflections=include_reflections)[::-1]
            )
            node = node.parent
        # Reverse the final back-tracked trajectory to return in the correct order
        return messages[::-1]  # root solution, reflection, child 1, ...

    def get_best_solution(self):
        """Return the best solution from within the current sub-tree."""
        all_nodes = [self]
        nodes = deque()
        nodes.append(self)
        while nodes:
            node = nodes.popleft()
            all_nodes.extend(node.children)
            for n in node.children:
                nodes.append(n)
        best_node = max(
            all_nodes,
            # We filter out all non-terminal, non-solution trajectories
            key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
        )
        return best_node

    def _mark_tree_as_solved(self):
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent

In [2]:
from typing_extensions import TypedDict


class TreeState(TypedDict):
    # The full tree
    root: Node
    # The original input
    input: str

In [3]:
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")

In [4]:
import base64

def encode_image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return encoded_string

# Example usage
# image_path = "thumbnail.jpg"
# base64_image = encode_image_to_base64(image_path)

In [5]:
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

@tool
def get_image_description(question:str, image_path:str):
    """Useful for identifying points in an image that could be responsible for its bad or good CTR. It needs a question and the path to the image."""
    llm = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1028)
    image = encode_image_to_base64(image_path)
    return llm.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": "{question}"},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image}"},
                    },
                ]
            )
        ]
    ).content

# @tool
# def get_positive_points(image_path):
#     """Useful for identifying positive points in an image that could cause it to receive a good CTR."""
#     llm = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1028)
#     image = encode_image_to_base64(image_path)
#     return llm.invoke(
#         [
#             HumanMessage(
#                 content=[
#                     {"type": "text", "text": "Identify the positive points in this image that could cause it to receive a good CTR."},
#                     {
#                         "type": "image_url",
#                         "image_url": {"url": f"data:image/png;base64,{image}"},
#                     },
#                 ]
#             )
#         ]
#     ).content
tools = [get_image_description]
tool_executor = ToolExecutor(tools=tools)

In [6]:
from langchain.chains import create_structured_output_runnable
from langchain.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import chain as as_runnable


class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        gte=0,
        lte=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        return HumanMessage(
            content=f"Reasoning: {self.reflections}\nScore: {self.score}"
        )

    @property
    def normalized_score(self) -> float:
        return self.score / 10.0


prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Reflect and grade the assistant response to the user question below.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="candidate"),
    ]
)

reflection_llm_chain = (
    prompt
    | llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(
        run_name="Reflection"
    )
    | PydanticToolsParser(tools=[Reflection])
)


@as_runnable
def reflection_chain(inputs) -> Reflection:
    tool_choices = reflection_llm_chain.invoke(inputs)
    reflection = tool_choices[0]
    if not isinstance(inputs["candidate"][-1], AIMessage):
        reflection.found_solution = False
    return reflection

In [7]:
from typing import List

from langchain_core.prompt_values import ChatPromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.runnables import RunnableConfig

prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful AI assistant that answers questions about images.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)


initial_answer_chain = prompt_template | llm.bind_tools(tools=tools).with_config(
    run_name="GenerateInitialCandidate"
)


parser = JsonOutputToolsParser(return_id=True)

In [8]:
initial_response = initial_answer_chain.invoke(
    {"input": "Identify the positive and negative points in this image that may be important in predicting thumbnail CTR. Its path is ./thumbnainal.jpg."}
)
initial_response

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_QjzQjRPHBsG1aGr18I9jh4gP', 'function': {'arguments': '{"question":"Identify the positive and negative points in this image that may be important in predicting thumbnail CTR.","image_path":"./thumbnail.jpg"}', 'name': 'get_image_description'}, 'type': 'function'}]})

In [9]:
import json


# Define the node we will add to the graph
def generate_initial_response(state: TreeState) -> dict:
    """Generate the initial candidate response."""
    res = initial_answer_chain.invoke({"input": state["input"]})
    parsed = parser.invoke(res)
    tool_responses = tool_executor.batch(
        [ToolInvocation(tool=r["type"], tool_input=r["args"]) for r in parsed]
    )
    # print(tool_responses)
    output_messages = [res] + [
        ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
        for resp, tool_call in zip(tool_responses, parsed)
    ]
    reflection = reflection_chain.invoke(
        {"input": state["input"], "candidate": output_messages}
    )
    root = Node(output_messages, reflection=reflection)
    return {
        **state,
        "root": root,
    }

In [15]:
# This generates N candidate values
# for a single input to sample actions from the environment


def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
    n = config["configurable"].get("N", 5)
    bound_kwargs = llm.bind_tools(tools=tools).kwargs
    chat_result = llm.generate(
        [messages.to_messages()],
        n=n,
        callbacks=config["callbacks"],
        run_name="GenerateCandidates",
        **bound_kwargs
    )
    # print(f"bound_kwargs: {bound_kwargs}")
    return [gen.message for gen in chat_result.generations[0]]


expansion_chain = prompt_template | generate_candidates

In [None]:
res = expansion_chain.invoke({"input": "Identify the positive and negative points in this image that may be important in predicting thumbnail CTR. Its path is ./thumbnainal.jpg."})
res

In [16]:

def expand(state: TreeState, config: RunnableConfig) -> dict:
    """Starting from the "best" node in the tree, generate N candidates for the next step."""
    root = state["root"]
    best_candidate: Node = root.best_child if root.children else root
    # print(f"best_candidate: {best_candidate}")
    messages = best_candidate.get_trajectory()
    # Generate N candidates from the single child candidate
    new_candidates = expansion_chain.invoke(
        {"input": state["input"], "messages": messages}, config
    )
    parsed = parser.batch(new_candidates)
    flattened = [
        (i, tool_call)
        for i, tool_calls in enumerate(parsed)
        for tool_call in tool_calls
    ]
    tool_responses = tool_executor.batch(
        [
            ToolInvocation(tool=tool_call["type"], tool_input=tool_call["args"])
            for _, tool_call in flattened
        ]
    )
    collected_responses = defaultdict(list)
    for (i, tool_call), resp in zip(flattened, tool_responses):
        collected_responses[i].append(
            ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
        )
    output_messages = []
    for i, candidate in enumerate(new_candidates):
        output_messages.append([candidate] + collected_responses[i])

    # Reflect on each candidate
    # For tasks with external validation, you'd add that here.
    reflections = reflection_chain.batch(
        [{"input": state["input"], "candidate": msges} for msges in output_messages],
        config,
    )
    # Grow tree
    child_nodes = [
        Node(cand, parent=best_candidate, reflection=reflection)
        for cand, reflection in zip(output_messages, reflections)
    ]
    best_candidate.children.extend(child_nodes)
    # We have already extended the tree directly, so we just return the state
    return state

In [17]:
from langgraph.graph import END, StateGraph


def should_loop(state: TreeState):
    """Determine whether to continue the tree search."""
    root = state["root"]
    print(f"state: {state}")
    if root.is_solved:
        return END
    if root.height > 5:
        print("Reached max depth")
        return END
    return "expand"


builder = StateGraph(TreeState)
builder.add_node("start", generate_initial_response)
builder.add_node("expand", expand)
builder.set_entry_point("start")


builder.add_conditional_edges(
    "start",
    # Either expand/rollout or finish
    should_loop,
)
builder.add_conditional_edges(
    "expand",
    # Either continue to rollout or finish
    should_loop,
)

graph = builder.compile()

In [24]:
question = "Make a detailed report identifing the positive and negative points in this thumbnail image that could impact the thumbnail CTR. Its path is ./thumbnainal.jpg."
for step in graph.stream({"input": question}, {"recursion_limit":150},):
    step_name, step_state = next(iter(step.items()))
    print(step_name)
    print("rolled out: ", step_state["root"].height)
    print("---")

start
rolled out:  1
---
state: {'root': <Node value=0.9, visits=1, solution=[AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_jPDx1MvVZlNuOwDUNtjkzlC8', 'function': {'arguments': '{"question":"Identify the positive and negative points in this thumbnail image that could impact the thumbnail CTR.","image_path":"./thumbnail.jpg"}', 'name': 'get_image_description'}, 'type': 'function'}]}), ToolMessage(content='"In the image, there is a person who appears to be shooting a firearm at a can of Coca-Cola, as indicated by the visual effects of a bullet hole and a muzzle flash. The text \\"40M\\" suggests some form of celebration or achievement, possibly related to reaching a milestone of 40 million in some context, which could be subscribers, views, or something else significant to the individual. The person is wearing protective hearing gear typically used during shooting for safety. The image is likely edited to dramatize the action for effect or to illustrate a story or 

In [25]:
solution_node = step["__end__"]["root"].get_best_solution()
best_trajectory = solution_node.get_trajectory(include_reflections=False)
print(best_trajectory[-1].content)

I apologize for the oversight. Let me provide a more detailed analysis of the positive and negative points in the thumbnail image that could impact the thumbnail CTR based on the description provided earlier:

Positive Points:
1. Visual Effects: The image includes visual effects such as a bullet hole and a muzzle flash, which can attract attention and create intrigue for viewers.
2. Milestone Indicator: The text "40M" suggests a milestone achievement, which can generate curiosity and interest among viewers.
3. Safety Gear: The presence of protective hearing gear indicates safety consciousness, which can be perceived positively by the audience.

Negative Points:
1. Dramatized Action: The dramatized action of shooting a firearm at a can of Coca-Cola may convey violence or aggression, potentially leading to a negative perception among some viewers.
2. Context Ambiguity: The context of the image, with elements like the firearm and Coca-Cola can, may be unclear or confusing to some viewers,

In [None]:
best_trajectory