In [1]:
import os
import sys

# Make project-local packages importable: append the directory two levels above
# the current working directory ("../..") to the Python module search path.
current_working_directory = os.getcwd()
sys.path.append(os.path.join(current_working_directory, "../.."))

In [2]:
import pprint

### Defining the Graph state

In [3]:
from typing import TypedDict, Annotated, List, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
import operator
from IPython.display import Image, display

In [4]:
class AgentState(TypedDict):
    """Shared state passed between the nodes of the literature-review graph.

    Every node receives the current state and returns a partial update keyed
    by the field names below.
    """

    # The user's query, interpolated into the prompts as "User Query: {input}".
    input: str
    # Messages emitted by the most recently executed node. Every node in this
    # notebook writes a *list* of messages here and the routers inspect the
    # last element, so the field is typed as a message list.
    agent_outcome: List[BaseMessage]
    # Context accumulated across nodes (tool outputs, the outline, ...). The
    # operator.add annotation is the reducer, so each update is concatenated
    # onto the existing list rather than replacing it.
    intermediate_step: Annotated[list, operator.add]
    # Outline written by plan_agent and read back by edit_agent.
    doc_schema: List[BaseMessage]
    # Number of completed drafts minus one. TypedDict fields cannot carry
    # default values (a class-level "= -1" is never placed into the state
    # dict), so the starting value is supplied by write_agent through
    # data.get("revision_number", -1).
    revision_number: int
    # Upper bound that route_writer compares revision_number against.
    max_revisions: int


### Defining Tools

In [5]:
# Shared registry of every tool handed to the research agent; later cells extend it in place.
tools = []

#### Search Tools

In [6]:
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults

arxiv_search = ArxivQueryRun()
tavily_tool = TavilySearchResults(max_results=5)

# Only arXiv search is enabled for the agents. `tavily_tool` is constructed but
# intentionally left out; add it to this list to enable web search as well.
search_tools = [arxiv_search]

In [7]:
# Add the search tools to the shared registry.
# Note: list.extend mutates `tools`, so re-running this cell appends the same tools again.
tools.extend(search_tools)

#### File Tools

In [8]:
from tempfile import TemporaryDirectory

from langchain_community.agent_toolkits import FileManagementToolkit
# Temporary directory used as the root directory for all file tools.
working_directory = TemporaryDirectory()

# Build the file toolkit rooted at the temporary directory, exposing only the
# read, write and list-directory tools.
file_tools = FileManagementToolkit(
    root_dir=str(working_directory.name),
    selected_tools=["read_file", "write_file", "list_directory"],
).get_tools()
# NOTE(review): this unpacking assumes get_tools() returns the tools in the
# same order as `selected_tools` — confirm against the toolkit documentation.
read_tool, write_tool, list_tool = file_tools

In [9]:
# Add the file tools to the shared registry so the research agent can persist results.
# Note: list.extend mutates `tools`, so re-running this cell appends the same tools again.
tools.extend(file_tools)

In [10]:
from models.llm import LLM

# Create the project's LLM wrapper for 'gpt-4o' and load the model object used by every agent.
# NOTE(review): `LLM` is project-local; presumably load_model() returns a chat model
# that supports bind_tools (the agent nodes call it) — confirm.
model = LLM('gpt-4o')
llm = model.load_model()

In [11]:
from langgraph.graph import END, StateGraph
# Graph builder whose shared state follows the AgentState schema; nodes and edges are added below.
workflow = StateGraph(AgentState)

In [12]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

def research_agent(data):
    """Research node: ask the tool-enabled LLM to search for and store sources.

    The prompt is made of a system instruction, the user query and the
    accumulated ``intermediate_step`` history. The LLM is bound to every tool
    in the module-level ``tools`` list, so its reply may contain tool calls.

    Args:
        data: Current graph state; the prompt reads ``input`` and
            ``intermediate_step`` from it.

    Returns:
        dict: ``{'agent_outcome': [reply]}`` where ``reply`` is the LLM response.
    """
    print("----research node----")

    system_text = (
        "You are a researcher charged with providing information that can be used when writing the following literature review."
        " Given a query, generate a few relevant search querries."
        " Use the appropriate search tools and chat history to progress towards finding the relevant results."
        " Once you have the relevant search results, create a directory retrieved_results and dump the each result (title, authors, summary or context) separately in a file using Title as the file name."
        "\nYou have access to the following search tools: {tool_names}."
    )
    human_text = "\nUser Query: {input}"
    history_slot = MessagesPlaceholder(variable_name="intermediate_step")

    # Fill in the tool names up front so only the state fields remain as prompt variables.
    available_tools = ", ".join(tool.name for tool in tools)
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_text), ("human", human_text), history_slot]
    ).partial(tool_names=available_tools)

    researcher = prompt | llm.bind_tools(tools)
    reply = researcher.invoke(data)
    return {'agent_outcome': [reply]}

In [None]:
# Register the research node and make it the entry point of the graph.
workflow.add_node("research", research_agent)
workflow.set_entry_point("research")

In [14]:
import json
from langchain_core.messages import ToolMessage

class BasicToolNode:
    """Graph node that executes the tool calls requested by the latest message.

    Args:
        tools: Tool objects exposing ``name`` and ``invoke``; they are indexed
            by name for dispatch.
    """

    def __init__(self, tools: list) -> None:
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        """Run every tool call found on the last ``agent_outcome`` message.

        Args:
            inputs: Current graph state; ``inputs["agent_outcome"][-1]`` must
                expose a ``tool_calls`` list of dicts with ``name``, ``args``
                and ``id`` keys.

        Returns:
            dict: ``agent_outcome`` holds one ``ToolMessage`` per tool call and
            ``intermediate_step`` holds the stringified list of those messages.
        """
        print("----tool calling----")
        message = inputs["agent_outcome"][-1]

        outputs = []
        for tool_call in message.tool_calls:
            tool_name = tool_call["name"]
            print(f"---- Calling {tool_name} with args: {tool_call['args']} ----")

            tool = self.tools_by_name.get(tool_name)
            if tool is None:
                # The model asked for a tool this node does not own. Report it
                # back as the tool result so the agent can retry, instead of
                # aborting the whole graph with a KeyError.
                available = ", ".join(self.tools_by_name)
                tool_result = (
                    f"Error: {tool_name} is not a valid tool, try one of [{available}]."
                )
            else:
                tool_result = tool.invoke(tool_call["args"])

            outputs.append(
                ToolMessage(
                    # default=str keeps results that are not JSON-serializable
                    # from raising TypeError.
                    content=json.dumps(tool_result, default=str),
                    name=tool_name,
                    tool_call_id=tool_call["id"],
                )
            )

        return {
                "agent_outcome": outputs,
                "intermediate_step": [str(outputs)]
            }

In [None]:
# Tool-execution node for the research agent; it can run every tool registered in `tools`.
research_tool_node = BasicToolNode(tools=tools)
workflow.add_node("research_tools", research_tool_node)

In [16]:
def plan_agent(data):
    """Planning node: have the LLM draft an outline for the literature review.

    The prompt is made of a system instruction, the user query and the
    messages currently stored in ``agent_outcome``. No tools are bound.

    Args:
        data: Current graph state; the prompt reads ``input`` and
            ``agent_outcome`` from it.

    Returns:
        dict: the LLM reply under ``agent_outcome`` and ``doc_schema`` (each as
        a one-element list) and its text content appended to
        ``intermediate_step``.
    """
    print("----plan node----")

    system_text = """You are an expert planner tasked with writing a high level outline of a literature review. \
                    Write an outline for the user provided topic based on the retrieved documents. \
                    Give an outline of the literature review along with any relevant notes or instructions for each of the sections."""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_text),
            ("human", "\nUser Query: {input}"),
            MessagesPlaceholder(variable_name="agent_outcome"),
        ]
    )

    planner = prompt | llm
    outline = planner.invoke(data)

    update = {'agent_outcome': [outline]}
    update['doc_schema'] = [outline]
    update['intermediate_step'] = [outline.content]
    return update

In [None]:
# Register the outline-planning node.
workflow.add_node("plan", plan_agent)

In [18]:
def route_researcher(
    state: AgentState,
):
    """
    Conditional-edge router for the research node.

    Looks at the most recent message (the last element of ``state`` when it is
    a list, otherwise the last element of ``state["agent_outcome"]``) and
    returns "tools" when that message carries at least one tool call, else
    "END". Raises ValueError when no message can be found.
    """
    print("----router----")

    if isinstance(state, list):
        candidates = state
    else:
        candidates = state.get("agent_outcome", [])
        if not candidates:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")
    latest = candidates[-1]

    wants_tools = hasattr(latest, "tool_calls") and len(latest.tool_calls) > 0
    return "tools" if wants_tools else "END"

In [None]:
# After "research": run the requested tools when the router returns "tools";
# when it returns "END" the research phase is over and control passes to "plan".
workflow.add_conditional_edges(
    "research",
    route_researcher,
    {"tools": "research_tools", "END": "plan"}
)

In [None]:
# Tool results go back to the research node, closing the research/tool loop.
workflow.add_edge("research_tools", "research")

In [21]:
def write_agent(data):
    """Writing node: have the LLM (bound to the file tools) draft the review.

    The prompt is made of a system instruction followed by the accumulated
    ``intermediate_step`` history only.

    Args:
        data: Current graph state; the prompt reads ``intermediate_step`` and
            the counter is read from ``revision_number`` (treated as -1 when
            absent).

    Returns:
        dict: always ``agent_outcome`` with the LLM reply. When the reply
        contains no tool calls it counts as a finished draft and
        ``revision_number`` is returned incremented by one.
    """
    print("----write node----")

    system_text = """You are an essay assistant tasked with writing excellent literature review.\
                    Given the document outline and retrieved documents generate the best literature review possible. \
                    First write all the sections other than Introduction and Conclusion.
                    Add necessary references with proper citations.
                    If the reviewer provides critique, respond with a revised version of your previous attempts. \
                    
                    Utilize all the documents from the directory retrieved_results. Do not add any other extra information on your own. """

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_text),
            MessagesPlaceholder(variable_name="intermediate_step"),
        ]
    )

    writer = prompt | llm.bind_tools(file_tools)
    reply = writer.invoke(data)

    update = {'agent_outcome': [reply]}
    requested_tools = hasattr(reply, "tool_calls") and len(reply.tool_calls) > 0
    if not requested_tools:
        # Only a reply without tool calls is a draft, so only then bump the counter.
        update['revision_number'] = data.get("revision_number", -1) + 1
    return update

In [None]:
# Register the drafting node.
workflow.add_node("write", write_agent)

In [None]:
# Once the outline exists, control passes straight to the drafting node.
workflow.add_edge("plan", "write")

In [None]:
# Tool-execution node for the writer; unlike the research tool node it only owns the file tools.
write_tool_node = BasicToolNode(tools=file_tools)
workflow.add_node("write_tools", write_tool_node)

In [25]:
def review_agent(data):
    """Review node: have the LLM critique the latest draft.

    The prompt is made of a system instruction followed by the messages in
    ``agent_outcome`` (the draft written by the write node).

    Args:
        data: Current graph state; reads ``revision_number`` and
            ``agent_outcome``.

    Returns:
        dict: the critique under ``agent_outcome`` and, appended to
        ``intermediate_step``, the reviewed draft followed by the critique.
    """
    print("----review node----")
    print(f"---- Revision Count {data['revision_number']+1} ----")
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert grading a literature review submission. \
                    Generate critique and recommendations for the user's submission. \
                    Provide detailed recommendations, including requests for length, depth, style, etc. \
                    If the generated draft looks perfect, reply appropriately."""
            ),
            
            MessagesPlaceholder(variable_name="agent_outcome"),
        ]
    )
    agent = prompt | llm
    result = agent.invoke(data)

    # The write node builds its prompt only from `intermediate_step`. Without
    # the entries below it would never see its previous draft or this
    # critique, and every "revision" would be regenerated from the same prompt.
    submitted = data.get("agent_outcome") or []
    feedback = [f"Draft under review:\n{draft.content}" for draft in submitted[-1:]]
    feedback.append(f"Reviewer critique:\n{result.content}")

    return {'agent_outcome': [result],
            'intermediate_step': feedback,
            }

In [None]:
# Register the critique node.
workflow.add_node("review", review_agent)

In [27]:
def route_writer(
    state: AgentState,
):
    """
    Conditional-edge router for the write node.

    Returns "tools" when the most recent message carries at least one tool
    call. Otherwise returns "review" while ``revision_number`` is still below
    ``max_revisions`` (printing the upcoming revision count), and "END" once
    the limit is reached. Raises ValueError when no message can be found.
    """
    print("----router----")

    if isinstance(state, list):
        latest = state[-1]
    else:
        outcome = state.get("agent_outcome", [])
        if not outcome:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")
        latest = outcome[-1]

    wants_tools = hasattr(latest, "tool_calls") and len(latest.tool_calls) > 0
    if wants_tools:
        return "tools"

    completed = state["revision_number"]
    if completed >= state["max_revisions"]:
        return "END"

    print(f"---- Revision Count {completed+1} ----")
    return "review"

In [None]:
# After "write": run requested file tools ("tools"), send a finished draft to the
# reviewer ("review"), or, when the router returns "END", move on to the "edit" node.
workflow.add_conditional_edges(
    "write",
    route_writer,
    {"tools":"write_tools", "review": "review", "END": "edit"}
)

In [None]:
# File-tool results go back to the writer, closing the write/tool loop.
workflow.add_edge("write_tools", "write")

In [None]:
# After a critique the writer runs again, forming the write/review revision loop.
workflow.add_edge("review", "write")

In [31]:
def edit_agent(data):
    """Editing node: have the LLM assemble the final document.

    The prompt is made of a system instruction, the messages in
    ``agent_outcome`` and then the messages in ``doc_schema``. No tools are
    bound.

    Args:
        data: Current graph state; the prompt reads ``agent_outcome`` and
            ``doc_schema`` from it.

    Returns:
        dict: ``{'agent_outcome': [final]}`` where ``final`` is the LLM reply.
    """
    print("---- edit node ----")

    system_text = """You are an experienced editor, expert at literature review submission. \
                    Combine drafts of multiple sections into a single, coherent final document, ensuring a consistent flow, tone, and structure throughout.
                    Follow the the doc_schema generated by the plan node to generate the final draft.
                    """

    prompt_parts = [("system", system_text)]
    for slot in ("agent_outcome", "doc_schema"):
        prompt_parts.append(MessagesPlaceholder(variable_name=slot))
    prompt = ChatPromptTemplate.from_messages(prompt_parts)

    editor = prompt | llm
    final = editor.invoke(data)
    return {'agent_outcome': [final]}

In [None]:
# Register the final editing node.
workflow.add_node("edit", edit_agent)

In [None]:
# The edited document is the last step: the graph terminates after "edit".
workflow.add_edge("edit", END)

In [None]:
# Compile the graph definition into a runnable application.
app = workflow.compile()

# Rendering the diagram is a best-effort convenience: it requires some extra
# dependencies and is optional, so a failure must not stop the notebook.
# Report the reason instead of discarding it silently.
try:
    display(Image(app.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    print(f"Skipping graph visualisation: {exc!r}")

In [None]:
# Initial state: the research question and the revision limit checked by the writer router.
inputs = dict(
    input="What are the recent research works on LLM Agents for report or article generation?",
    max_revisions=2,
)

state = AgentState(**inputs)

# Stream the run and, for each emitted update, print the content of the first
# message under 'agent_outcome' from the first entry, followed by a separator.
for s in app.stream(input=state, config={"recursion_limit": 50}):
    node_update = list(s.values())[0]
    first_message = node_update['agent_outcome'][0]
    print(first_message.content)
    print("-----" * 20)