# Lesson 6: Essay Writer


In [None]:
from dotenv import load_dotenv
import os
import sys
import json, re
import pprint
import boto3
from botocore.client import Config
import warnings
warnings.filterwarnings("ignore")
import logging

# import local modules
# Put the parent of the current working directory on sys.path before importing
# ``utils`` (presumably the project-level package lives there — confirm).
dir_current = os.path.abspath('')  # abspath('') resolves to the current working directory
dir_parent = os.path.dirname(dir_current)
if dir_parent not in sys.path:  # avoid adding a duplicate entry when the cell is re-run
    sys.path.append(dir_parent)
from utils import utils

# Set basic configs
# Both objects are created by helpers from the project's ``utils`` module.
logger = utils.set_logger()
pp = utils.set_pretty_printer()

# Set main parameters
tavily_api_key_name = "TAVILY_API_KEY"  # name passed to utils.get_from_secretstore_or_env below
aws_region = "us-east-1"  # used for the Bedrock runtime client and for the key lookup

# Set bedrock configs
# 120-second connect/read timeouts; "max_attempts": 0 presumably turns off
# botocore's automatic retries — TODO confirm against the botocore retry docs.
bedrock_config = Config(
    connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
)

# Create a bedrock runtime client
# This client is handed to ChatBedrockConverse further down in the notebook.
bedrock_rt = boto3.client(
    "bedrock-runtime",
    region_name=aws_region,
    config=bedrock_config
)

# Retrieve API KEY from env variables or secrets manager
# On success the key is also exported as TAVILY_API_KEY in the process
# environment so libraries that read the variable directly can find it.
try:
    tavily_ai_api_key = utils.get_from_secretstore_or_env(tavily_api_key_name, aws_region)
    os.environ["TAVILY_API_KEY"] = tavily_ai_api_key
except ValueError as ve:
    # Keep the name defined so later cells that reference it do not fail with
    # an unrelated NameError; the logged error below states the real cause.
    tavily_ai_api_key = None
    logger.error(
        "Could not retrieve the TAVILY API KEY, neither from the os environment variables, nor from AWS Secrets manager!"
    )
    logger.error(ve)

In [None]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, List
import operator
from langgraph.checkpoint.memory import MemorySaver

from langchain_core.messages import (
    AnyMessage,
    SystemMessage,
    HumanMessage,
    AIMessage,
    ChatMessage,
)

# Checkpointer instance; passed to builder.compile(checkpointer=memory) later in this notebook.
memory = MemorySaver()

In [None]:
class AgentState(TypedDict):
    # State schema handed to StateGraph; every node in this notebook reads
    # from it and returns a partial update to it.
    task: str  # essay request; read by plan_node, research_plan_node and generation_node
    plan: str  # outline text returned by plan_node
    draft: str  # latest essay text returned by generation_node
    critique: str  # feedback text returned by reflection_node
    content: List[str]  # search-result snippets collected by the research_* nodes
    revision_number: int  # incremented by generation_node each time a draft is produced
    max_revisions: int  # should_continue returns END once revision_number exceeds this

In [None]:
from langchain_aws import ChatBedrockConverse

# Chat model used by every node in this notebook; it reuses the ``bedrock_rt``
# client created in the setup cell.
model = ChatBedrockConverse(
    client=bedrock_rt,
    model="anthropic.claude-3-haiku-20240307-v1:0",
    temperature=0,
    max_tokens=None,  # no explicit cap set here — presumably the library/service default applies; confirm
)

In [None]:
# System prompt for plan_node: asks the model for a high-level essay outline.
PLAN_PROMPT = (
    "You are an expert writer tasked with writing a high level outline of an essay. "
    "Write such an outline for the user provided topic. "
    "Give an outline of the essay along with any relevant notes "
    "or instructions for the sections."
)

In [None]:
# System prompt for generation_node. The ``{content}`` placeholder is filled
# with the collected research snippets via ``str.format``.
# The first sentence now ends with a space before the next one; previously the
# line continuation glued them together as "essays.Generate".
WRITER_PROMPT = (
    "You are an essay assistant tasked with writing excellent 5-paragraph essays. "
    "Generate the best essay possible for the user's request and the initial outline. "
    "If the user provides critique, respond with a revised version of your previous attempts. "
    "Utilize all the information below as needed: \n"
    "\n"
    "------\n"
    "<content>\n"
    "{content}\n"
    "</content>"
)

In [None]:
# System prompt for reflection_node: the model grades the draft and returns
# critique plus recommendations.
REFLECTION_PROMPT = (
    "You are a teacher grading an essay submission. "
    "Generate critique and recommendations for the user's submission. "
    "Provide detailed recommendations, including requests for length, depth, style, etc."
)

In [None]:
# Prompt text for generating search queries ahead of the first draft.
# NOTE(review): research_plan_node in this notebook builds its own inline
# template and does not reference this constant — confirm whether it is still needed.
RESEARCH_PLAN_PROMPT = (
    "You are a researcher charged with providing information that can "
    "be used when writing the following essay. "
    "Generate a list of search queries that will gather "
    "any relevant information. Only generate 3 queries max."
)

In [None]:
# Prompt text for generating search queries that support requested revisions.
# NOTE(review): research_critique_node in this notebook builds its own inline
# template and does not reference this constant — confirm whether it is still needed.
RESEARCH_CRITIQUE_PROMPT = (
    "You are a researcher charged with providing information that can "
    "be used when making any requested revisions (as outlined below). "
    "Generate a list of search queries that will gather any relevant information. "
    "Only generate 3 queries max."
)

In [None]:
from langchain_core.pydantic_v1 import BaseModel

# Schema for a list of search queries.
# NOTE: a class with the same name is defined again in a later cell (with a
# Field description); once that cell runs, it replaces this definition.
class Queries(BaseModel):
    queries: List[str]  # the individual search query strings

In [None]:
from tavily import TavilyClient
import os

# Search client used by the research nodes; relies on ``tavily_ai_api_key``
# having been set by the setup cell at the top of the notebook.
tavily = TavilyClient(api_key=tavily_ai_api_key)

In [None]:
def plan_node(state: AgentState):
    """Ask the model for an essay outline of ``state["task"]``.

    Returns a partial state update with the outline text under ``"plan"``.
    """
    outline = model.invoke(
        [
            SystemMessage(content=PLAN_PROMPT),
            HumanMessage(content=state["task"]),
        ]
    )
    return {"plan": outline.content}

In [None]:
from typing import List
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
import json


# Schema the research nodes parse the model's JSON reply into; it is also
# given to PydanticOutputParser to produce the format instructions.
# (Documented with comments rather than a docstring so the generated schema is unchanged.)
class Queries(BaseModel):
    queries: List[str] = Field(description="List of research queries")


def research_plan_node(state: AgentState):
    """Generate research queries for the task and gather search snippets.

    The model is prompted, together with the ``Queries`` format instructions,
    to reply with a JSON object. Each query in it is sent to Tavily with
    ``max_results=2`` and every result's ``content`` text is collected.

    Args:
        state: Current graph state. ``task`` is required; ``content`` is
            optional and treated as empty when missing or ``None``.

    Returns:
        Partial state update ``{"content": [...]}``: the previously collected
        snippets followed by the newly fetched ones.

    Raises:
        json.JSONDecodeError: If no parseable JSON object is found in the
            model's reply.
    """
    # Set up the Pydantic output parser
    parser = PydanticOutputParser(pydantic_object=Queries)

    # Create a prompt template with format instructions
    prompt = PromptTemplate(
        template="Generate research queries based on the given task.\n{format_instructions}\nTask: {task}\n",
        input_variables=["task"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    # Use the model with the new prompt and parser
    queries_output = model.invoke(prompt.format_prompt(task=state["task"]))

    # Extract the content from the AIMessage
    queries_text = queries_output.content

    # The reply may wrap the JSON in extra prose, so keep only the span from
    # the first "{" to the last "}".
    json_start = queries_text.find("{")
    json_end = queries_text.rfind("}") + 1
    json_str = queries_text[json_start:json_end]

    # Parse the JSON string and validate it against the Queries schema
    parsed_queries = Queries(**json.loads(json_str))

    # "content" may be absent on the first pass (the run cell in this notebook
    # does not seed it), so read it with .get instead of indexing. Copy the
    # list so the one held in the incoming state is not mutated in place.
    content = list(state.get("content") or [])
    for q in parsed_queries.queries:
        response = tavily.search(query=q, max_results=2)
        content.extend(r["content"] for r in response["results"])
    return {"content": content}

In [None]:
def generation_node(state: AgentState):
    """Write an essay draft from the task, the plan and the research snippets.

    Returns a partial state update with the new ``"draft"`` and the
    ``"revision_number"`` increased by one (counting from 1 when unset).
    """
    research_notes = "\n\n".join(state["content"] or [])
    request_text = f"{state['task']}\n\nHere is my plan:\n\n{state['plan']}"
    system_text = WRITER_PROMPT.format(content=research_notes)

    reply = model.invoke(
        [SystemMessage(content=system_text), HumanMessage(content=request_text)]
    )

    next_revision = state.get("revision_number", 1) + 1
    return {"draft": reply.content, "revision_number": next_revision}

In [None]:
def reflection_node(state: AgentState):
    """Have the model critique the current draft.

    Returns a partial state update with the feedback text under ``"critique"``.
    """
    grader_prompt = SystemMessage(content=REFLECTION_PROMPT)
    submission = HumanMessage(content=state["draft"])
    feedback = model.invoke([grader_prompt, submission])
    return {"critique": feedback.content}

In [None]:
def research_critique_node(state: AgentState):
    """Generate research queries from the critique and gather search snippets.

    The model is prompted, together with the ``Queries`` format instructions,
    to reply with a JSON object. Each query in it is sent to Tavily with
    ``max_results=2`` and every result's ``content`` text is collected.

    Args:
        state: Current graph state. ``critique`` is required; ``content`` is
            optional and treated as empty when missing or ``None``.

    Returns:
        Partial state update ``{"content": [...]}``: the previously collected
        snippets followed by the newly fetched ones.

    Raises:
        json.JSONDecodeError: If no parseable JSON object is found in the
            model's reply.
    """
    # Set up the Pydantic output parser
    parser = PydanticOutputParser(pydantic_object=Queries)

    # Create a prompt template with format instructions
    prompt = PromptTemplate(
        template="Generate research queries based on the given critique.\n{format_instructions}\nCritique: {critique}\n",
        input_variables=["critique"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    # Use the model with the new prompt and parser
    queries_output = model.invoke(prompt.format_prompt(critique=state["critique"]))

    # Extract the content from the AIMessage
    queries_text = queries_output.content

    # The reply may wrap the JSON in extra prose, so keep only the span from
    # the first "{" to the last "}".
    json_start = queries_text.find("{")
    json_end = queries_text.rfind("}") + 1
    json_str = queries_text[json_start:json_end]

    # Parse the JSON string and validate it against the Queries schema
    parsed_queries = Queries(**json.loads(json_str))

    # Read "content" with .get so a missing key is treated as "no snippets
    # yet" instead of raising KeyError. Copy the list so the one held in the
    # incoming state is not mutated in place.
    content = list(state.get("content") or [])
    for q in parsed_queries.queries:
        response = tavily.search(query=q, max_results=2)
        content.extend(r["content"] for r in response["results"])
    return {"content": content}

In [None]:
def should_continue(state):
    """Pick the step that follows a draft.

    Returns ``END`` once ``revision_number`` is greater than
    ``max_revisions``; otherwise returns ``"reflect"``.
    """
    revisions_exhausted = state["revision_number"] > state["max_revisions"]
    return END if revisions_exhausted else "reflect"

In [None]:
# Graph builder whose state schema is the AgentState TypedDict defined above.
builder = StateGraph(AgentState)

In [None]:
# Register each node function under the name used when wiring the edges below.
builder.add_node("planner", plan_node)
builder.add_node("generate", generation_node)
builder.add_node("reflect", reflection_node)
builder.add_node("research_plan", research_plan_node)
builder.add_node("research_critique", research_critique_node)

In [None]:
# Execution starts at the "planner" node.
builder.set_entry_point("planner")

In [None]:
# After every "generate" step, should_continue chooses the next hop: END stops
# the run, "reflect" starts another critique round. The mapping translates the
# function's return value into the target node.
builder.add_conditional_edges(
    "generate", should_continue, {END: END, "reflect": "reflect"}
)

In [None]:
# First pass: planner -> research_plan -> generate
builder.add_edge("planner", "research_plan")
builder.add_edge("research_plan", "generate")

# Revision loop: reflect -> research_critique -> generate
builder.add_edge("reflect", "research_critique")
builder.add_edge("research_critique", "generate")

In [None]:
# Compile the wired builder into a runnable graph, using ``memory`` as its checkpointer.
graph = builder.compile(checkpointer=memory)

In [None]:
from IPython.display import Image
# Render the graph structure as PNG data and show it as this cell's output.
Image(graph.get_graph().draw_png())

In [None]:
# Run the essay workflow once and print every streamed update.
# ``thread`` carries the thread id under which the run is configured.
thread = {"configurable": {"thread_id": "1"}}
for s in graph.stream(
    {
        "task": "what is the difference between langchain and langsmith",
        "max_revisions": 2,
        "revision_number": 1,
        # Seed the research notes with an empty list so nodes that read
        # state["content"] always find the key, even on the first pass.
        "content": [],
    },
    thread,
):
    print(s)

## Essay Writer Interface


In [None]:
from helper import ewriter, writer_gui
import gradio as gr

# Create the essay-writer object from the project's ``helper`` module, pass
# its ``graph`` attribute to the GUI wrapper, and start the app.
MultiAgent = ewriter()
app = writer_gui(MultiAgent.graph)
app.launch()