In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import nest_asyncio

# Patch asyncio so an event loop can be re-entered. Presumably needed because
# the notebook kernel already runs a loop — TODO confirm which call below
# actually requires it.
nest_asyncio.apply()

In [None]:
import operator
import os
from typing import Annotated

from dotenv import load_dotenv
from IPython.display import Image, display
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import WikipediaLoader
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    get_buffer_string,
)
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field
from serpapi import GoogleSearch
from typing_extensions import TypedDict

# Load environment variables before anything reads them: SERPAPI_API_KEY is
# read later via os.getenv, and the chat model presumably takes its
# credentials from the environment as well.
load_dotenv()

# Single chat model instance shared by every node in both graphs.
llm = ChatOpenAI(model="gpt-4.1-mini", temperature=0)


def load_prompt(prompt_filename, partial_variables=None):
    """Read a Jinja2 prompt template from disk and render it to a string.

    Args:
        prompt_filename: Path to the template file.
        partial_variables: Optional mapping of template variable names to the
            values substituted while rendering. ``None`` (the default) means
            the template is rendered without any variables.

    Returns:
        The rendered prompt text.

    Raises:
        OSError: If the template file cannot be opened or read.
    """
    # Decode explicitly as UTF-8 so the rendered prompt does not depend on the
    # platform's default locale encoding.
    with open(prompt_filename, "r", encoding="utf-8") as f:
        file_content = f.read()

    # Every variable is bound up front as a partial, so format() takes no
    # arguments.
    return PromptTemplate.from_template(
        file_content,
        template_format="jinja2",
        partial_variables=partial_variables or {},
    ).format()

## Interview Agent

In [None]:
class InterviewState(MessagesState):
    """State carried through one run of the interview graph.

    Extends ``MessagesState``; the nodes below read and write the conversation
    under the ``messages`` key in addition to the fields declared here.
    """

    # Cap on expert answers before the interview is saved; route_messages
    # falls back to 2 when the key is absent.
    max_num_turns: int
    # Retrieved source text. Both search nodes return their results under this
    # key, and operator.add is attached as the reducer annotation.
    context: Annotated[list, operator.add]
    # Transcript string produced by save_interview.
    interview: str
    # Section text(s) produced by write_section.
    sections: list


# Structured-output schema: the search nodes bind it to the LLM through
# with_structured_output and read `search_query` from the result.
# NOTE(review): documented with `#` comments on purpose — a class docstring on
# a pydantic model may be included in the schema sent to the model; confirm
# before converting these comments into a docstring.
class SearchQuery(BaseModel):
    # NOTE(review): annotated as `str` but the default is None, so callers can
    # receive None when no query is produced — confirm this is intended.
    search_query: str = Field(None, description="Search query for retrieval.")


def generate_question(state: InterviewState):
    """Have the analyst persona ask the next interview question.

    The analyst system prompt is rendered from its template, placed ahead of
    the conversation so far, and the combined list is sent to the shared LLM.

    Args:
        state: Current interview state; only ``messages`` is read.

    Returns:
        A state update whose ``messages`` list contains the model's reply.
    """
    history = state["messages"]
    system_text = load_prompt("./assets/15_complex_agent/analyst_system_prompt.jinja2")
    reply = llm.invoke([SystemMessage(content=system_text), *history])
    return {"messages": [reply]}


def search_web(state: InterviewState):
    """Search Google through SerpAPI for material relevant to the interview.

    The LLM first turns the conversation into a structured ``SearchQuery``;
    that query is sent to SerpAPI and the organic results are rendered as
    ``<Document>`` snippets.

    Args:
        state: Current interview state; only ``messages`` is read.

    Returns:
        A state update whose ``context`` list holds one string with all
        formatted results (an empty string when there are no organic results).
    """
    structured_llm = llm.with_structured_output(SearchQuery)
    search_system_prompt = load_prompt(
        "assets/15_complex_agent/search_system_prompt.jinja2"
    )
    # load_prompt returns plain text. Wrap it in a SystemMessage: a bare string
    # in a message list is treated as human input rather than as system
    # instructions, which is how every other node passes its prompt.
    search_query = structured_llm.invoke(
        [SystemMessage(content=search_system_prompt)] + state["messages"]
    )

    params = {
        "q": search_query.search_query,
        "hl": "en",
        "google_domain": "google.com",
        "api_key": os.getenv("SERPAPI_API_KEY"),
    }

    results = GoogleSearch(params).get_dict()
    # The response may not contain organic results at all.
    organic_results = results.get("organic_results", [])

    # Read each field defensively so one result that lacks a link, title or
    # snippet does not abort the whole search with a KeyError.
    formatted_results = "\n\n---\n\n".join(
        f'<Document source="{doc.get("link", "")}" page="{doc.get("title", "")}"/>\n'
        f'{doc.get("snippet", "")}\n</Document>'
        for doc in organic_results
    )
    return {"context": [formatted_results]}


def search_wikipedia(state: InterviewState):
    """Search Wikipedia for material relevant to the interview.

    The LLM first turns the conversation into a structured ``SearchQuery``;
    up to two matching Wikipedia documents are then loaded and rendered as
    ``<Document>`` blocks.

    Args:
        state: Current interview state; only ``messages`` is read.

    Returns:
        A state update whose ``context`` list holds one string with all
        formatted documents (an empty string when nothing was loaded).
    """
    structured_llm = llm.with_structured_output(SearchQuery)
    search_system_prompt = load_prompt(
        "assets/15_complex_agent/search_system_prompt.jinja2"
    )
    # load_prompt returns plain text. Wrap it in a SystemMessage: a bare string
    # in a message list is treated as human input rather than as system
    # instructions, which is how every other node passes its prompt.
    search_query = structured_llm.invoke(
        [SystemMessage(content=search_system_prompt)] + state["messages"]
    )

    search_docs = WikipediaLoader(
        query=search_query.search_query, load_max_docs=2
    ).load()

    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )

    return {"context": [formatted_search_docs]}


def generate_answer(state: InterviewState):
    """Have the expert persona answer, using the retrieved context.

    The expert system prompt is rendered with the accumulated ``context`` and
    sent to the shared LLM ahead of the conversation. The reply is tagged with
    the name ``"expert"``.

    Args:
        state: Current interview state; ``context`` and ``messages`` are read.

    Returns:
        A state update whose ``messages`` list contains the named reply.
    """
    system_text = load_prompt(
        "assets/15_complex_agent/expert_system_prompt.jinja2",
        {"context": state["context"]},
    )
    prompt_messages = [SystemMessage(content=system_text), *state["messages"]]

    reply = llm.invoke(prompt_messages)
    # route_messages identifies expert turns by this name.
    reply.name = "expert"
    return {"messages": [reply]}


def save_interview(state: InterviewState):
    """Flatten the conversation into a single transcript string.

    Args:
        state: Current interview state; only ``messages`` is read.

    Returns:
        A state update with the transcript under ``interview``.
    """
    return {"interview": get_buffer_string(state["messages"])}


def route_messages(state: InterviewState, name: str = "expert"):
    """Decide whether the interview continues or is wrapped up.

    Args:
        state: Current interview state; ``messages`` is read, and
            ``max_num_turns`` if present (otherwise 2 is used).
        name: Message name that identifies the answering side's replies.

    Returns:
        ``"save_interview"`` once the number of named AI replies reaches the
        turn limit, or when the second-to-last message contains the closing
        phrase; ``"ask_question"`` otherwise.
    """
    history = state["messages"]
    turn_limit = state.get("max_num_turns", 2)

    answers_given = sum(
        1 for msg in history if isinstance(msg, AIMessage) and msg.name == name
    )
    if answers_given >= turn_limit:
        return "save_interview"

    # This router runs after answer_question, so the newest message is the
    # answer and the one before it is the question that prompted it.
    previous = history[-2]
    wants_to_stop = "Thank you. That's all." in previous.content
    return "save_interview" if wants_to_stop else "ask_question"


def write_section(state: InterviewState):
    """Ask the LLM to write a report section from the gathered context.

    Args:
        state: Current interview state; only ``context`` is read.

    Returns:
        A state update whose ``sections`` list holds the generated text.
    """
    sources = state["context"]
    system_text = load_prompt(
        "assets/15_complex_agent/section_writer_system_prompt.jinja2"
    )
    request = HumanMessage(
        content=f"Use this source to write your section: {sources}"
    )
    draft = llm.invoke([SystemMessage(content=system_text), request])
    return {"sections": [draft.content]}


# Assemble the interview graph: register one node per step defined above.
interview_builder = StateGraph(InterviewState)
interview_builder.add_node("ask_question", generate_question)
interview_builder.add_node("search_web", search_web)
interview_builder.add_node("search_wikipedia", search_wikipedia)
interview_builder.add_node("answer_question", generate_answer)
interview_builder.add_node("save_interview", save_interview)
interview_builder.add_node("write_section", write_section)

# Each question fans out to both search nodes, and both feed the answer node.
interview_builder.add_edge(START, "ask_question")
interview_builder.add_edge("ask_question", "search_web")
interview_builder.add_edge("ask_question", "search_wikipedia")
interview_builder.add_edge("search_web", "answer_question")
interview_builder.add_edge("search_wikipedia", "answer_question")
# After every answer, route_messages either loops back for another question or
# moves on to saving the interview.
interview_builder.add_conditional_edges(
    "answer_question", route_messages, ["ask_question", "save_interview"]
)
interview_builder.add_edge("save_interview", "write_section")
interview_builder.add_edge("write_section", END)

# Compile a standalone agent with an in-memory checkpointer; invocations must
# therefore pass a thread_id in their config.
memory = MemorySaver()
interview_agent = interview_builder.compile(checkpointer=memory)

# Render the compiled graph as a Mermaid PNG for visual inspection.
display(Image(interview_agent.get_graph().draw_mermaid_png()))

In [None]:
# Run a single interview on its own. The thread_id keys the checkpointer, and
# the human message opens the conversation.
config = {"configurable": {"thread_id": "1"}}
message = HumanMessage(content="So you said you were writing an article on LangGraph?")
# Left as the last expression so the resulting state is shown as cell output.
interview_agent.invoke({"messages": [message]}, config=config)

## Research Agent

In [None]:
class ResearchGraphState(TypedDict):
    """State shared by the nodes of the top-level research graph."""

    # Subject of the report; interpolated into the interview opener and passed
    # to the report prompts.
    topic: str
    # Number of interviews to launch (used as range(state["analysts"])).
    analysts: int
    # Section texts read by the report writers; operator.add is attached as
    # the reducer annotation.
    sections: Annotated[list, operator.add]
    # Written by write_introduction.
    introduction: str
    # Report body written by write_report.
    content: str
    # Written by write_conclusion.
    conclusion: str
    # Assembled output of finalize_report.
    final_report: str


def initiate_all_interviews(state: ResearchGraphState):
    """Create one ``Send`` to the interview node per requested analyst.

    Args:
        state: Research state; ``topic`` and ``analysts`` are read.

    Returns:
        A list of ``Send`` objects targeting ``"conduct_interview"``, each
        carrying an opening human message about the topic.
    """
    topic = state["topic"]

    dispatches = []
    for _ in range(state["analysts"]):
        opener = HumanMessage(
            content=f"So you said you were writing an article on {topic}?"
        )
        dispatches.append(Send("conduct_interview", {"messages": [opener]}))
    return dispatches


def write_report(state: ResearchGraphState):
    """Ask the LLM to write the report body from the collected sections.

    Args:
        state: Research state; ``sections`` and ``topic`` are read.

    Returns:
        A state update with the generated body under ``content``.
    """
    memos = "\n\n".join(format(section) for section in state["sections"])
    system_text = load_prompt(
        "assets/15_complex_agent/report_writer_system_prompt.jinja2",
        {"topic": state["topic"], "context": memos},
    )
    request = HumanMessage(content="Write a report based upon these memos.")
    draft = llm.invoke([SystemMessage(content=system_text), request])
    return {"content": draft.content}


def write_introduction(state: ResearchGraphState):
    """Ask the LLM to write the report introduction from the sections.

    Args:
        state: Research state; ``sections`` and ``topic`` are read.

    Returns:
        A state update with the generated text under ``introduction``.
    """
    joined_sections = "\n\n".join(format(section) for section in state["sections"])
    system_text = load_prompt(
        "assets/15_complex_agent/intro_conclusion_system_prompt.jinja2",
        {"topic": state["topic"], "formatted_str_sections": joined_sections},
    )
    request = HumanMessage(content="Write the report introduction")
    draft = llm.invoke([SystemMessage(content=system_text), request])
    return {"introduction": draft.content}


def write_conclusion(state: ResearchGraphState):
    """Ask the LLM to write the report conclusion from the sections.

    Args:
        state: Research state; ``sections`` and ``topic`` are read.

    Returns:
        A state update with the generated text under ``conclusion``.
    """
    joined_sections = "\n\n".join(format(section) for section in state["sections"])
    system_text = load_prompt(
        "assets/15_complex_agent/intro_conclusion_system_prompt.jinja2",
        {"topic": state["topic"], "formatted_str_sections": joined_sections},
    )
    request = HumanMessage(content="Write the report conclusion")
    draft = llm.invoke([SystemMessage(content=system_text), request])
    return {"conclusion": draft.content}


def finalize_report(state: ResearchGraphState):
    """Assemble introduction, body, conclusion and sources into one report.

    A leading ``## Insights`` header is removed from the body. If the body
    contains a ``## Sources`` block on its own line, that block is split off
    at its first occurrence and appended after the conclusion.

    Args:
        state: Research state; ``content``, ``introduction`` and
            ``conclusion`` are read.

    Returns:
        A state update with the assembled text under ``final_report``.
    """
    insights_header = "## Insights"
    sources_separator = "\n## Sources\n"

    content = state["content"]
    # Remove the header as an exact prefix. str.strip() would treat the header
    # as a set of characters and also eat legitimate text at either end of the
    # body (for example a trailing "s" or "ing").
    if content.startswith(insights_header):
        content = content[len(insights_header):]

    # partition() never raises: an empty separator in the result means there
    # is no sources block, and the body is returned unchanged.
    content, separator, sources = content.partition(sources_separator)
    if not separator:
        sources = None

    final_report = (
        state["introduction"]
        + "\n\n---\n\n"
        + content
        + "\n\n---\n\n"
        + state["conclusion"]
    )
    if sources is not None:
        final_report += "\n\n## Sources\n" + sources
    return {"final_report": final_report}


# Assemble the research graph. The interview graph is compiled again (without
# a checkpointer) and registered as a single node.
builder = StateGraph(ResearchGraphState)
builder.add_node("conduct_interview", interview_builder.compile())
builder.add_node("write_report", write_report)
builder.add_node("write_introduction", write_introduction)
builder.add_node("write_conclusion", write_conclusion)
builder.add_node("finalize_report", finalize_report)

# From START, initiate_all_interviews returns one Send per analyst, all
# targeting conduct_interview.
builder.add_conditional_edges(
    source=START, path=initiate_all_interviews, path_map=["conduct_interview"]
)
# The three writers all follow the interviews.
builder.add_edge("conduct_interview", "write_report")
builder.add_edge("conduct_interview", "write_introduction")
builder.add_edge("conduct_interview", "write_conclusion")
# A single edge from the list of all three writers into finalize_report.
builder.add_edge(
    ["write_conclusion", "write_report", "write_introduction"], "finalize_report"
)
builder.add_edge("finalize_report", END)

# NOTE: rebinds `memory`; the interview agent compiled earlier keeps the
# MemorySaver instance it was given at compile time.
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)
# xray=1 is passed to get_graph so the nested interview graph is expanded in
# the rendered diagram — presumably one level deep; confirm against the docs.
display(Image(graph.get_graph(xray=1).draw_mermaid_png()))

In [None]:
# Run the full research pipeline: four interviews on the topic, then the
# report writers and the final assembly. The result is kept in `state`.
state = graph.invoke(
    {"topic": "LangGraph", "analysts": 4},
    config={"configurable": {"thread_id": "1"}, "max_concurrency": 15},
)