# Agents Series - Episode 3
_Reflection agents_

# Installs

In [0]:
%pip install langgraph langchain-openai==0.3.0
# Installs the graph runtime and the Azure/OpenAI chat bindings.
# NOTE(review): only langchain-openai is pinned; langgraph is not, so a later
# re-run may resolve a different langgraph version. Consider pinning it too.

In [0]:
# Databricks: restart the notebook's Python process so the packages installed
# by the %pip cell above are picked up by subsequent imports.
dbutils.library.restartPython()

# Setup

In [0]:
import os

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import AzureChatOpenAI

In [0]:
# Azure OpenAI connection settings, read from a Databricks secret scope into the
# environment variables AzureChatOpenAI is expected to pick up.
# NOTE: <<SCOPE>> / <<KEY>> are template placeholders and this cell does not run
# as-is. Replace them with real scope/key names; the endpoint and the API key
# are two different secrets.
os.environ["AZURE_OPENAI_ENDPOINT"] = dbutils.secrets.get(<<SCOPE>>, <<KEY>>)
os.environ["AZURE_OPENAI_API_KEY"] = dbutils.secrets.get(<<SCOPE>>, <<KEY>>)

## LLM

In [0]:
# Higher-capability deployment: writes and revises the haiku (the "maker").
clever_llm = AzureChatOpenAI(
    temperature=0.9,
    openai_api_version="2024-08-01-preview",
    model_name="gpt4o",
)

In [0]:
# Weaker deployment: plays the critic (the "checker").
dumb_llm = AzureChatOpenAI(
    temperature=0.9,
    openai_api_version="2024-08-01-preview",
    model_name="gpt-35-turbo-16k",
)

# Maker

In [0]:
# Generator prompt: a fixed system instruction followed by the running
# conversation (request, previous drafts, critiques) under the "messages" key.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            # Adjacent string literals are concatenated with nothing in between,
            # so each sentence ends with a space to keep the sentences separate.
            "You are an essay assistant tasked with writing haikus. "
            "Generate the best haiku possible for the user's requested topic. "
            "If the user provides critique, respond with a revised version of your previous attempts.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)

# Maker chain: format the prompt, then call the stronger model.
generate = prompt | clever_llm

In [0]:
# The user's original request. It is reused as the first message of every
# generation and critique call below. (The former `essay = ""` placeholder was
# removed: `essay` is assigned by the next cell and is used via `.content`,
# which a str does not have.)
request = HumanMessage(
    content="Write a Haiku about turtles."
)

In [0]:
# First draft: the generator sees only the user's request.
essay = generate.invoke(
    dict(messages=[request])
)

# Checker

In [0]:
# Critic prompt: a deliberately harsh teacher persona, followed by the
# conversation under the "messages" key.
reflection_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            # Adjacent string literals are concatenated with nothing in between,
            # so each fragment ends with a space to keep the sentences separate.
            "You are a really mean poetry teacher marking a Haiku submission. Generate critique and recommendations for the user's submission. "
            "Provide detailed recommendations, including requests for themes, wording, and tone. "
            "If you cannot find anything wrong with the Haiku, make something up. Nothing you receive is ever good enough. "
            "Do not give away examples which the user can copy.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
# Checker chain: format the prompt, then call the weaker model (dumb_llm).
reflect = reflection_prompt | dumb_llm

In [0]:
# The critic sees the request and then the draft, presented as the user's submission.
submitted_draft = HumanMessage(content=essay.content)
reflection = reflect.invoke({"messages": [request, submitted_draft]})

# Repeat

In [0]:
# Second draft: replay the request, the first draft (as the AI's own turn) and
# the critique (as the user's turn), so the generator revises its haiku.
revision_history = [
    request,
    AIMessage(content=essay.content),
    HumanMessage(content=reflection.content),
]
generate.invoke({"messages": revision_history})

# Graph

In [0]:
from typing import Annotated, List, Sequence
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from typing_extensions import TypedDict

In [0]:
class State(TypedDict):
    """Graph state shared by the generate and reflect nodes.

    Node return values are combined into ``messages`` by the ``add_messages``
    reducer, so the history grows across steps (``should_continue`` counts it).
    """

    # Full conversation, in order; index 0 is the user's original request.
    messages: Annotated[list, add_messages]

In [0]:
async def generation_node(state: State) -> State:
    """Write a new (or revised) haiku from the whole conversation so far.

    Returns a partial state holding just the new model message; the graph's
    reducer adds it to the history.
    """
    draft = await generate.ainvoke({"messages": state["messages"]})
    return {"messages": [draft]}

In [0]:
async def reflection_node(state: State) -> State:
    """Critique the latest haiku and feed the critique back as user feedback.

    From the critic's point of view the roles are reversed: the generator's
    drafts become "human" submissions and earlier critiques become "ai" turns.
    The first message (the user's original request) is passed through as-is.
    Raises KeyError for any message whose type is neither "ai" nor "human".
    """
    history = state["messages"]
    role_swap = {"ai": HumanMessage, "human": AIMessage}
    swapped_tail = [role_swap[message.type](content=message.content) for message in history[1:]]
    critique = await reflect.ainvoke([history[0], *swapped_tail])
    # The generator should see the critique as coming from the user.
    return {"messages": [HumanMessage(content=critique.content)]}

In [0]:
# Two-node graph over State; every run starts at "generate".
workflow = StateGraph(State)
for node_name, node_fn in (
    ("generate", generation_node),
    ("reflect", reflection_node),
):
    workflow.add_node(node_name, node_fn)
workflow.add_edge(START, "generate")

In [0]:
def should_continue(state: State, max_iterations: int = 3):
    """Decide, after each generation, whether to critique again or stop.

    The history starts with the user's request (1 message) plus the first
    draft (1 message). Each critique/revision round then adds 2 more (a
    critique and a revised haiku). After ``k`` rounds the history therefore
    holds ``2 * k + 2`` messages, and it first exceeds ``2 * max_iterations``
    when ``k == max_iterations``.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.
        max_iterations: Number of critique/revision rounds to run before
            stopping. The default of 3 matches the original ``> 6`` cut-off.

    Returns:
        ``END`` once ``max_iterations`` rounds are done, otherwise ``"reflect"``.
    """
    if len(state["messages"]) > 2 * max_iterations:
        return END
    return "reflect"

In [0]:
# In-memory checkpointer: lets the run's state be read back later via `config`.
memory = MemorySaver()

# Loop wiring: after "generate", should_continue picks "reflect" or END;
# "reflect" always hands control back to "generate".
workflow.add_conditional_edges("generate", should_continue)
workflow.add_edge("reflect", "generate")

graph = workflow.compile(checkpointer=memory)

In [0]:
# Run config: the checkpointer stores and looks up state under this thread id.
# Re-running the graph with the same id (without recompiling) continues the
# stored conversation instead of starting a fresh one.
config = {"configurable": dict(thread_id="5")}

In [0]:
# Run the graph
async for event in graph.astream(
    {
        "messages": [
            HumanMessage(
                content="Generate a Haiku on the world's tallest tree."
            )
        ],
    },
    config,
):
    print(event)
    print("---")

In [0]:
# Snapshot of the checkpointed state for the thread identified by `config`.
state = graph.get_state(config=config)

In [0]:
# Pretty-print the whole stored conversation, using a throwaway prompt template
# as the formatter.
transcript = ChatPromptTemplate.from_messages(state.values["messages"])
transcript.pretty_print()

In [0]:
# Second compilation of the same workflow, this time without a checkpointer.
# NOTE(review): `app` is not referenced in any cell shown here (the diagram
# cell draws `graph`); consider removing this cell.
app = workflow.compile(checkpointer=None)

In [0]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a Mermaid PNG. MermaidDrawMethod.API renders via
# a remote service, so presumably this cell needs network access.
graph_png = graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
display(Image(graph_png))