In [1]:
import os
from typing import Annotated, Sequence, TypedDict, Literal
import operator
import functools
from datetime import datetime

from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    ToolMessage,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage

from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode

# Define Tools

In [2]:
os.environ["TAVILY_API_KEY"] = "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"
tavily_tool = TavilySearchResults(max_results=5)

repl = PythonREPL()

@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        result = repl.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return (
        result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    )

@tool
def emb_finder(
    message: Annotated[str, "The python code to execute to generate your chart."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    
    try:
        print("!!")
        print(message)
        date_str, question = message.split(":::")

        # Step 2: Remove the brackets around the dates
        date_str = date_str.strip("[]")
        # Step 3: Convert the string of dates to a list
        date_list = date_str.split(", ")
        print("@@")
        print(date_list)
        
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    return (
        message
    )

# State of Graph
A list of messages, along with a key to track the most recent sender

In [3]:
# This defines the object that is passed between each node
# in the graph. We will create different nodes for each agent and tool
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    sender: str

# Define Nodes
Agent Nodes, Tool Nodes

In [4]:
"""
Data Nodes
"""
def agent_node(state, agent, name):
    """Helper function to create a node for a given agent"""
    result = agent.invoke(state)
    # We convert the agent output into a format that is suitable to append to the global state
    if isinstance(result, ToolMessage):
        pass
    else:
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)    
    return {
        "messages": [result],
        # Since we have a strict workflow, we can
        # track the sender so we know who to pass to next.
        "sender": name,
    }


llm = ChatOpenAI(model="gpt-4o", api_key="sk-proj-DFEqmV2bESTGXITqzVrHT3BlbkFJ3ndYJrjURSkNmALp5kqS")

# Date Agent
current_date = datetime.now().strftime("%Y-%m-%d") 
file_list = ["2023-08-23.txt", "2024-01-10.txt", "2024-05-30.txt", "2024-06-19.txt", "2024-06-20.txt"]
prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You do not need to answer question itself."
                "You have two goals, one is to find date information from question and other is to pass the question for the next agent to answer."
                "Make '[date1, date2, ...]:::(question)' format to answer, when (question) is the original question from the user."
                "\n{system_message}"
                ,
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
prompt = prompt.partial(system_message=f"The current date is {current_date}. If the question does not have any clue about date, use the current date. If the question has a clue about date, find all the related dates from the list {file_list}. For example, If current date is '2023-08-23' and question indicating this year, you have to answer with all the date list with 2023, like '2023-01-15, 2023-02-11, 2023-05-19...'")
data_agent =  prompt | llm

date_node = functools.partial(agent_node, agent=data_agent, name="date_finder")

In [5]:
"""
Tool Nodes
"""
tools = [emb_finder]
tool_node = ToolNode(tools)

In [6]:
"""
Embedding Node
"""
prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You have to use tools to find the embedding with the proper date,"
                "Remember, you must use tool with the answer received from previous agent without any editing or deleting with the format of '[date1, date2, ...]:::(question)'."
                "Remember, don't miss or ignore single date element from received answer when you pass the state to tools."
                "You have access to the following tools: {tool_names}.",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
chart_agent =  prompt | llm.bind_tools(tools) 

chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")

# Define Edge Logic
edge logic that is needed to decide what to do based on results of the agents

In [7]:
def router(state) -> Literal["call_tool", "__end__", "continue"]:
    """
        This is the router
        Either agent can decide to end
    """
    messages = state["messages"]
    last_message = messages[-1]
    
    if last_message.tool_calls:
        # The previous agent is invoking a tool
        return "call_tool"
    return "continue"

def router2(state) -> Literal["call_tool", "__end__", "continue"]:
    return "continue"

# Define Graph

In [8]:
workflow = StateGraph(AgentState)

workflow.add_node("date_finder", date_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_node("call_tool", tool_node)

workflow.add_conditional_edges(
    "date_finder",
    router,
    {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "chart_generator",
    router,
    {"continue": END, "call_tool": "call_tool", "__end__": END},
)

workflow.add_conditional_edges(
    "call_tool",
    router2, # date_finder or chart_generator
    {
        "continue": END,
    },
)

workflow.set_entry_point("date_finder")
graph = workflow.compile()

In [9]:
import json
from typing import List, Dict, Any, Union

def extract_content_and_urls(value: Dict[str, Any]) -> List[Dict[str, Union[str, Dict[str, str]]]]:
    result = []
    possible_keys = ['call_tool', 'date_finder', 'chart_generator']

    for key in possible_keys:
        if key in value:
            data = value[key]
            if 'messages' in data:
                messages = data['messages']
                if isinstance(messages, list) and len(messages) > 0:
                    print(key + "/")
                    message = messages[0]
                    content = message.content
                    # Check if the content is a JSON string
                    try:
                        json_content = json.loads(content)
                        # Handle case where content is a JSON string
                        for item in json_content:
                            url = item.get('url')
                            content = item.get('content')
                            result.append({'url': url, 'content': content})
                    except json.JSONDecodeError:
                        # Handle case where content is a regular string
                        result.append({'content': content})
            break  # Stop after finding the first valid key
    return result

In [10]:
while True:
    user_input = input("User: ") # what's the chinese zodiac of last year?
    if user_input.lower() in ["quit", "exit", "q"]:
        print("Goodbye!")
        break
    events = graph.stream(
    {
        "messages": [
            HumanMessage(
                content=user_input
            )
        ],
    },
    # Maximum number of steps to take in the graph
    {"recursion_limit": 5},
    )
    for event in events:
        print(extract_content_and_urls(event))
        print("----")

date_finder/
[{'content': '[2024-01-10, 2024-05-30, 2024-06-19, 2024-06-20]:::(what happend this year?)'}]
----
chart_generator/
[{'content': ''}]
----
!!
[2024-01-10, 2024-05-30, 2024-06-19, 2024-06-20]:::(what happend this year?)
@@
['2024-01-10', '2024-05-30', '2024-06-19', '2024-06-20']
call_tool/
[{'content': '[2024-01-10, 2024-05-30, 2024-06-19, 2024-06-20]:::(what happend this year?)'}]
----
Goodbye!
