In [2]:
import base64
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage

def encode_image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return encoded_string

@tool
def get_image_description(question:str, image_path:str):
    """Useful for describing an image by answering questions about it. Need a question and the path to the image."""
    llm = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1028)
    image = encode_image_to_base64(image_path)
    result = llm.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": "{question}"},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image}"},
                    },
                ]
            )
        ]
    ).content
    return result

In [3]:
tools = [get_image_description]

In [4]:
from collections import defaultdict
from typing import List
import json

from langchain.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

tool_executor = ToolExecutor(tools)
parser = JsonOutputToolsParser(return_id=True)

def execute_tools(state: List[BaseMessage]) -> List[BaseMessage]:
    tool_invocation: AIMessage = state[-1]
    parsed_tool_calls = parser.invoke(tool_invocation)
    print(f"parsed_tool_calls: {parsed_tool_calls}")
    ids = []
    outputs = []
    print(f"parsed_tool_calls[0]['id']: {parsed_tool_calls[0]['id']}")
    for i in range(len(parsed_tool_calls[0]['args']['questions'])):
        tool_invocation = ToolInvocation(
                tool="get_image_description",
                tool_input={"question": parsed_tool_calls[0]['args']['questions'][i], "image_path": parsed_tool_calls[0]['args']['image_path'][0]},
            )
        ids.append(parsed_tool_calls[0]["id"])
        outputs.append(tool_executor.invoke(tool_invocation))
    print(f"outputs: {outputs}")
    outputs_map = defaultdict(dict)
    outputs_map[ids[0]][tool_invocation.tool] = outputs
    result = [ToolMessage(content=json.dumps(query_outputs), tool_call_id=id_) for id_, query_outputs in outputs_map.items()]
    return result

In [5]:
import datetime

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_openai import ChatOpenAI
from langsmith import traceable

actor_prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are an expert in video thumbnails for social media, provided with a question and an image.
             Current time: {time}
             1. {first_instruction}
             2. Reflect and critique your answer. Be severe to maximize improvement.
             3. Recommend questions to get information and improve your answer.""",
        ),
        MessagesPlaceholder(variable_name="messages"),
        ("system", "Answer the user's question above using the required format."),
    ]
).partial(
    time=lambda: datetime.datetime.now().isoformat(),
)

class Reflection(BaseModel):
    missing: str = Field(description="Critique of what is missing.")
    superfluous: str = Field(description="Critique of what is superfluous.")

class AnswerQuestion(BaseModel):
    """Answer a question."""
    answer: str = Field(description="~100 word detailed answer to the question.")
    reflection: Reflection = Field(description="Your reflection on the initial answer.")
    questions: List[str] = Field(description="1-3 questions for the image describing tool for improvements to address the critique of your current answer.")
    image_path: List[str] = Field(description="Path to the image.")

llm = ChatOpenAI(model="gpt-4-turbo-preview")

initial_answer_chain = actor_prompt_template.partial(
    first_instruction="Provide some questions you need to answer the user's question."
) | llm.bind_tools(tools=[AnswerQuestion], tool_choice="AnswerQuestion")

validator = PydanticToolsParser(tools=[AnswerQuestion])

class ResponderWithRetries:
    def __init__(self, runnable, validator):
        self.runnable = runnable
        self.validator = validator

    @traceable
    def respond(self, state: List[BaseMessage]):
        response = []
        for attempt in range(3):
            try:
                response = self.runnable.invoke({"messages":state})
                self.validator.invoke(response)
                return response
            except ValidationError as e:
                state = state + [HumanMessage(content=repr(e))]
        return response

In [6]:
first_responder = ResponderWithRetries(
    runnable=initial_answer_chain,
    validator=validator,
)

In [None]:
example_question = "What are the points that can be improved in this thumbnail?. Its path is ./thumbnail.jpg."
initial = first_responder.respond([HumanMessage(content=example_question)])

In [None]:
parsed = parser.invoke(initial)
parsed

In [7]:
revise_instructions = """Revise your previous answer using the new information.
    - You should use the previous critique to add important information to your answer.
    - You should use the previous critique to remove superfluous information from your answer and make SURE it is not more than 100 words.
"""

class ReviseAnswer(AnswerQuestion):
    """Revise your original answer to your question."""

    ideas: List[str] = Field(description="Ideas motivating your updated answer.")

revision_chain = actor_prompt_template.partial(
    first_instruction=revise_instructions
) | llm.bind_tools(tools=[ReviseAnswer], tool_choice="ReviseAnswer")

revision_validator = PydanticToolsParser(tools=[ReviseAnswer])

revisor = ResponderWithRetries(
    runnable=revision_chain,
    validator=revision_validator,
)

In [8]:
revised = revisor.respond(
    [
        HumanMessage(content=""),
        initial,
        ToolMessage(
            tool_call_id=initial.additional_kwargs["tool_calls"][0]["id"],
            content=json.dumps(
                get_image_description.batch([({"question": x, "image_path": parsed[0]['args']['image_path'][0]}) for x in parsed[0]['args']['questions']])
            ),
        ),
    ]
)

NameError: name 'initial' is not defined

In [None]:
parsed = parser.invoke(revised)
parsed

In [9]:
from langgraph.graph import END, MessageGraph

MAX_ITERATIONS = 5
builder = MessageGraph()
builder.add_node("draft", first_responder.respond)
builder.add_node("execute_tools", execute_tools)
builder.add_node("revise", revisor.respond)
# draft -> execute_tools
builder.add_edge("draft", "execute_tools")
# execute_tools -> revise
builder.add_edge("execute_tools", "revise")

# Define looping logic:


def _get_num_iterations(state: List[BaseMessage]):
    i = 0
    for m in state[::-1]:
        if not isinstance(m, (ToolMessage, AIMessage)):
            break
        i += 1
    return i


def event_loop(state: List[BaseMessage]) -> str:
    # in our case, we'll just stop after N plans
    num_iterations = _get_num_iterations(state)
    if num_iterations > MAX_ITERATIONS:
        return END
    return "execute_tools"


# revise -> execute_tools OR end
builder.add_conditional_edges("revise", event_loop)
builder.set_entry_point("draft")
graph = builder.compile()

In [10]:
events = graph.stream(
    [HumanMessage(content="What are the points that can be improved in this thumbnail?. It's path is ./thumbnail.jpg.")]
)
for i, step in enumerate(events):
    node, output = next(iter(step.items()))
    print(f"## {i+1}. {node}")
    print(str(output))
    print("---")

## 1. draft
content='' additional_kwargs={'tool_calls': [{'id': 'call_F1rs074EYo0WbcIkl0yiWUNO', 'function': {'arguments': '{"answer":"To provide a detailed critique and suggestions for improvement on the thumbnail, I would need to analyze its visual elements, composition, and text. However, without seeing the thumbnail, I can offer some general points that often need improvement in video thumbnails for social media: 1. Clarity and Quality: Ensure the thumbnail is high-resolution and the main subject is clear and focused. 2. Text Legibility: If text is used, it should be easy to read against the background. Consider font size, color, and contrast. 3. Composition: The thumbnail should have a balanced composition that draws the viewer\'s attention to the main subject or action. 4. Branding: Consistent use of colors, logos, or styles can help viewers instantly recognize your content. 5. Emotional Appeal: Thumbnails that evoke curiosity, excitement, or other emotions can be more effective 

In [14]:
print(parser.invoke(step[END][-1])[0]["args"]["answer"])

