In [2]:
import os
from dotenv import load_dotenv

# Load variables from a local .env file, if one is found, into os.environ.
load_dotenv()

# ChatGroq needs GROQ_API_KEY in the environment. Fail early with a clear message
# instead of the opaque TypeError raised by assigning None to os.environ.
if not os.getenv('GROQ_API_KEY'):
    raise RuntimeError(
        "GROQ_API_KEY is not set. Add it to a .env file or export it before starting the kernel."
    )

In [3]:
def add(a: float, b: float) -> float:
    """Add two numbers.

    Args:
        a: The first number.
        b: The second number.

    Returns:
        The sum of a and b.
    """
    # Hinted as float (ints are accepted too) so the tool schema built from these
    # hints does not reject fractional arguments such as the result of divide().
    return a + b

def subtract(a: float, b: float) -> float:
    """Subtract the second number from the first.

    Args:
        a: The number to subtract from.
        b: The number to subtract.

    Returns:
        The difference a - b.
    """
    # Hinted as float (ints are accepted too) so the tool schema built from these
    # hints does not reject fractional arguments such as the result of divide().
    return a - b


def multiply(a: float, b: float) -> float:
    """Multiply two numbers.

    Args:
        a: The first factor.
        b: The second factor.

    Returns:
        The product a * b.
    """
    # Hinted as float (ints are accepted too) so the tool schema built from these
    # hints does not reject fractional arguments such as the result of divide().
    return a * b


def divide(a: float, b: float) -> float:
    """Divide the first number by the second.

    Args:
        a: The dividend.
        b: The divisor. Must not be zero.

    Returns:
        The quotient a / b (true division, so a float for int or float inputs).

    Raises:
        ZeroDivisionError: If b is zero.
    """
    # Hinted as float (ints are accepted too) so the tool schema built from these
    # hints does not reject fractional arguments, including this function's own results.
    return a / b


In [4]:
# Plain Python functions are used directly as tools: they are bound to the LLM
# and handed to ToolNode further down.
tools = [
    add,
    subtract,
    multiply,
    divide,
]

In [7]:
from langchain_groq import ChatGroq
# temperature=0 makes tool selection and tool arguments as repeatable as the model
# allows, so Restart & Run All reproduces the recorded transcripts more closely.
llm = ChatGroq(
    model='llama-3.1-8b-instant',
    temperature=0,
)

# parallel_tool_calls=False asks for at most one tool call per model turn, so a
# multi-step expression is evaluated one operation at a time, in order.
llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)

In [8]:
from typing_extensions import TypedDict
from langchain_core.messages import AnyMessage
from typing import Annotated
from langgraph.graph.message import add_messages

# Graph state: a single "messages" channel. Its add_messages reducer merges the
# messages each node returns into the existing list instead of overwriting it.
MessageState = TypedDict(
    'MessageState',
    {'messages': Annotated[list[AnyMessage], add_messages]},
)
    

In [9]:
from langchain_core.messages import SystemMessage, HumanMessage

# Prepended to every model call; it is never written into the graph state.
system_message = SystemMessage(
    content="You are a helpful assistant that can perform basic arithmetic operations."
)


def assistant(state: MessageState):
    """Graph node: ask the tool-bound LLM for the next step.

    Sends the system prompt followed by the full message history and returns a
    partial state update containing only the model's reply (a final answer or a
    tool call), which the state's reducer appends to ``messages``.
    """
    conversation = [system_message, *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}

In [14]:
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition, ToolNode
from IPython.display import display, Image

builder = StateGraph(MessageState)

# Nodes: the LLM step, and a prebuilt node that executes the tool calls it requests.
builder.add_node('assistant', assistant)
builder.add_node('tools', ToolNode(tools))

# Routing: START -> assistant. tools_condition sends the assistant's reply to
# 'tools' when it contains tool calls and to END otherwise. Tool results always
# flow back to the assistant, forming the ReAct loop.
builder.add_edge(START, 'assistant')
builder.add_conditional_edges('assistant', tools_condition)
builder.add_edge('tools', 'assistant')

react_graph = builder.compile()
# Optional diagram (not rendered here):
# display(Image(react_graph.get_graph().draw_mermaid_png()))

In [15]:
# Keep the input list and the result under separate names: invoke() returns the
# final graph state (a dict with a 'messages' key), not a list of messages.
multi_step_question = [HumanMessage(content="What is 2 + 23 - 2 multiply by 13 and divide by 3?")]
messages = react_graph.invoke({'messages': multi_step_question})

In [16]:
# Show the whole exchange: the question, each tool call and its result, and the final answer.
for msg in messages['messages']:
    msg.pretty_print()


What is 2 + 23 - 2 multiply by 13 and divide by 3?
Tool Calls:
  add (174m9efhq)
 Call ID: 174m9efhq
  Args:
    a: 2
    b: 23
Name: add

25
Tool Calls:
  subtract (97ps2039e)
 Call ID: 97ps2039e
  Args:
    a: 25
    b: 2
Name: subtract

23
Tool Calls:
  multiply (ghch0hg0r)
 Call ID: ghch0hg0r
  Args:
    a: 23
    b: 13
Name: multiply

299
Tool Calls:
  divide (2mewj5bme)
 Call ID: 2mewj5bme
  Args:
    a: 299
    b: 3
Name: divide

99.66666666666667

The result is 99.67.


In [17]:
# This graph has no checkpointer, so the run starts from this message alone.
question = [HumanMessage(content='What is 12+13?')]
response = react_graph.invoke({'messages': question})

for msg in response['messages']:
    msg.pretty_print()


What is 12+13?
Tool Calls:
  add (acwe3eqwd)
 Call ID: acwe3eqwd
  Args:
    a: 12
    b: 13
Name: add

25

That's the result of the operation.


In [19]:
from langgraph.errors import GraphRecursionError

# react_graph has no checkpointer yet, so this invoke sees only the new message and
# "that" refers to nothing. The failure is the point of the demo: catch and show it
# so Restart & Run All continues past this cell.
question = [HumanMessage(content='Multiply that by 2')]
try:
    response = react_graph.invoke({'messages': question})
except GraphRecursionError as exc:
    print(f"{type(exc).__name__}: {exc}")
else:
    for m in response['messages']:
        m.pretty_print()

GraphRecursionError: Recursion limit of 25 reached without hitting a stop condition. You can increase the limit by setting the `recursion_limit` config key.
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/GRAPH_RECURSION_LIMIT

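#### Without memory, the compiled graph starts every `invoke` from an empty message history, so the model has no way to know what "that" refers to. Instead of producing a final answer, it kept cycling between the `assistant` and `tools` nodes until LangGraph stopped the run at its default recursion limit of 25 steps.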

## Memory in Agents using MemorySaver

In [13]:
from langgraph.checkpoint.memory import MemorySaver

# Recompile the same builder with a checkpointer. Graph state is then saved under
# the thread_id passed in the invoke config, so later calls on that thread see the
# earlier messages. MemorySaver keeps these checkpoints in this process's memory.
memory = MemorySaver()
react_graph = builder.compile(checkpointer=memory)  # replaces the stateless graph above

In [14]:
# Identifies the conversation thread in the checkpointer; reuse it to continue the chat.
config = dict(configurable=dict(thread_id='1'))

In [15]:
# First turn on thread "1"; the checkpointer records the resulting state.
question = [HumanMessage(content='What is 12+13?')]
response = react_graph.invoke({'messages': question}, config=config)

In [16]:
# Print every message stored on the thread so far.
for message in response['messages']:
    message.pretty_print()


What is 12+13?
Tool Calls:
  add (call_rASGkCmyNVjcQcxBXb4u95Lx)
 Call ID: call_rASGkCmyNVjcQcxBXb4u95Lx
  Args:
    a: 12
    b: 13
Name: add

25

12 + 13 equals 25.


In [17]:
# Same thread_id as the previous turn, so the model also receives the earlier exchange.
question = [HumanMessage(content='Now multiply that by 2')]
response = react_graph.invoke({'messages': question}, config=config)

In [18]:
# The history now includes both turns, restored from the checkpointer.
for message in response['messages']:
    message.pretty_print()


What is 12+13?
Tool Calls:
  add (call_rASGkCmyNVjcQcxBXb4u95Lx)
 Call ID: call_rASGkCmyNVjcQcxBXb4u95Lx
  Args:
    a: 12
    b: 13
Name: add

25

12 + 13 equals 25.

Now multiply that by 2
Tool Calls:
  multiply (call_kj1rKiqDYfrngDl9lyULU1uf)
 Call ID: call_kj1rKiqDYfrngDl9lyULU1uf
  Args:
    a: 25
    b: 2
Name: multiply

50

25 multiplied by 2 equals 50.
