In [2]:
pip install langgraph langchain langchain-community langchainhub langchain-core yfinance httpx

Note: you may need to restart the kernel to use updated packages.


In [3]:
import math
import operator
import os
from decimal import ROUND_FLOOR, Decimal
from typing import Annotated, TypedDict, Union

import yfinance as yf
from langchain.agents import create_react_agent
from langchain_community.chat_models import ChatOllama
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation

class AgentState(TypedDict):
    """Graph state shared by the LangGraph nodes of this ReAct agent loop."""

    # The user's question for the current run.
    input: str
    # Earlier conversation messages; the whole state is handed to the agent runnable,
    # which renders this into the prompt's {chat_history} slot.
    chat_history: list[BaseMessage]
    # The agent's latest decision: an AgentAction (call a tool) or an AgentFinish
    # (final answer). None is allowed by the annotation, e.g. before the agent has run.
    agent_outcome: Union[AgentAction, AgentFinish, None]
    # (action, observation) pairs accumulated over the loop. The operator.add reducer
    # makes LangGraph concatenate each node's returned list onto the stored one
    # instead of replacing it.
    # NOTE(review): annotated as str, but the tools in this notebook return numbers.
    intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]


# Local Ollama server and model used by the agent. Set OLLAMA_BASE_URL to point at a
# different server without editing the notebook.
ollama_url = os.environ.get("OLLAMA_BASE_URL", "http://127.0.0.1:11434")
model = ChatOllama(base_url=ollama_url, model="mixtral:8x7b")
# Credit: https://smith.langchain.com/hub/hwchase17/react
# Adapted ReAct prompt. It adds a chat-history slot and a multi-argument
# "Action Input" syntax (comma-separated key=value pairs, parsed by `execute_tools`).
# {tools}, {tool_names} and {agent_scratchpad} are left as template variables on
# purpose: create_react_agent requires them and fills them in on every call.
# (The previous `agent_scratchpad=""` keyword was not a template partial and had
# no effect, so it has been dropped.)
prompt = PromptTemplate.from_template("""Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

New input: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Note when passing multiple inputs to tools, you can use the following syntax:
Action Input: param1=value1, param2=value2

Begin!

Previous conversation history:
{chat_history}

New input: {input}
Thought:{agent_scratchpad}""")

In [4]:
@tool
def get_stock_price(ticker: str):
    """
    Get the latest stock price
    """
    # Developer notes live in comments because @tool turns the docstring above into
    # the tool description shown to the model in the prompt's {tools} list.
    # Tolerate stray whitespace/quotes the model may leave around the symbol.
    symbol = ticker.strip().strip("\"'")
    history = yf.Ticker(symbol).history(period="1d")
    # yfinance may return an empty frame (e.g. unknown symbol or failed download);
    # report that clearly instead of failing with an opaque IndexError.
    if history.empty or "Close" not in history.columns:
        raise ValueError(f"No recent price data found for ticker {symbol!r}")
    # Take the last (most recent) row in case more than one row comes back.
    return float(history["Close"].iloc[-1])

@tool
def calculate_share_price(number_of_shares: int, price: float):
    """
    Calculate the cost of a given number of shares
    """
    # Developer note (the docstring above is the model-facing tool description):
    # the total is still truncated down to whole cents, but it is now computed in
    # decimal arithmetic. The float form `math.floor(x * 100) / 100` could lose a cent
    # to binary rounding error, e.g. 1 share at 0.29 gave 0.28.
    total = Decimal(str(number_of_shares)) * Decimal(str(price))
    return float(total.quantize(Decimal("0.01"), rounding=ROUND_FLOOR))

@tool
def calculate_share_price_with_a_budget(budget: float, price: float):
    """
    Calculate the max number of shares one can purchase on a given budget
    """
    # Developer note (the docstring above is the model-facing tool description):
    # divide in decimal arithmetic so exact multiples are not lost to float error
    # (0.3 / 0.1 == 2.9999999999999996 in floats, which floored to 2 instead of 3).
    if price <= 0:
        raise ValueError(f"price must be positive, got {price!r}")
    # math.floor on a Decimal returns an int, matching the previous return type.
    return math.floor(Decimal(str(budget)) / Decimal(str(price)))

# The tools the agent may call; this list feeds both the agent and the executor.
tools = [
    get_stock_price,
    calculate_share_price_with_a_budget,
    calculate_share_price,
]

# ReAct runnable: formats `prompt`, queries `model`, and yields the AgentAction or
# AgentFinish that the graph nodes below consume.
agent_runnable = create_react_agent(model, tools, prompt)

# Looks up and runs a tool by name on behalf of `execute_tools`.
tool_executor = ToolExecutor(tools)

In [5]:
def _parse_tool_input(raw_input, arg_names):
    """Convert a ReAct ``Action Input`` into the input handed to the tool executor.

    Items are separated by ", " (the multi-input syntax taught by the prompt):

    * ``"MSFT"``                          -> ``"MSFT"``
    * ``"ticker=MSFT"``                   -> ``{"ticker": "MSFT"}`` if ``ticker`` is in ``arg_names``
    * ``"number_of_shares=5, price=1.5"`` -> ``{"number_of_shares": "5", "price": "1.5"}``
    * ``"5, 1.5"``                        -> zipped positionally onto ``arg_names``

    Values stay strings; the tool's argument schema coerces them (e.g. "5" -> 5).
    Non-string input is returned unchanged.

    Args:
        raw_input: the ``tool_input`` of the agent's AgentAction.
        arg_names: the selected tool's argument names, in declaration order.

    Raises:
        ValueError: if a multi-item input cannot be mapped onto ``arg_names``.
    """
    if not isinstance(raw_input, str):
        return raw_input

    parts = [part.strip() for part in raw_input.split(", ")]
    has_equals = ["=" in part for part in parts]

    if all(has_equals):
        keyed = {}
        for part in parts:
            # Split on the first "=" only, so values may themselves contain "=".
            key, value = part.split("=", 1)
            keyed[key.strip()] = value.strip()
        # A lone "x=y" counts as keyword syntax only when x names a real argument;
        # otherwise it is passed through as a plain value.
        if len(parts) > 1 or next(iter(keyed)) in arg_names:
            return keyed
        return parts[0]

    if not any(has_equals):
        if len(parts) == 1:
            return parts[0]
        # Positional form such as "5, 213.38".
        if len(parts) == len(arg_names):
            return dict(zip(arg_names, parts))

    raise ValueError(
        f"Cannot map tool input {raw_input!r} onto tool arguments {list(arg_names)}"
    )


def execute_tools(state):
    """Run the tool selected by the agent and record the observation.

    Reads the AgentAction in ``state["agent_outcome"]`` and parses its ``tool_input``
    with ``_parse_tool_input``. It then invokes the tool through ``tool_executor`` and
    returns the ``(action, observation)`` pair as a one-element list. The
    ``operator.add`` reducer on ``AgentState.intermediate_steps`` appends that list.
    """
    agent_action = state["agent_outcome"]
    tool_name = agent_action.tool

    tools_by_name = {t.name: t for t in tools}
    selected = tools_by_name.get(tool_name)
    # Unknown tool names are still forwarded to the executor unchanged, as before.
    arg_names = list(selected.args) if selected is not None else []
    tool_input = _parse_tool_input(agent_action.tool_input, arg_names)

    print(f"Calling tool: {tool_name}", tool_input)

    action = ToolInvocation(tool=tool_name, tool_input=tool_input)
    response = tool_executor.invoke(action)
    print("Called `execute_tools`", response)
    return {"intermediate_steps": [(agent_action, response)]}


def run_agent(state):
    """Ask the ReAct agent for its next step.

    The full graph ``state`` goes straight to ``agent_runnable``. Its result (an
    AgentAction or AgentFinish) is written back under ``agent_outcome``.
    """
    return {"agent_outcome": agent_runnable.invoke(state)}


def should_continue(state):
    """Route after the agent node.

    Returns ``"end"`` when the agent produced a final answer (AgentFinish) and
    ``"continue"`` otherwise, sending the AgentAction on to ``execute_tools``.

    The check uses the outcome's type rather than searching its log for the word
    "Action". A final answer whose text merely mentions "Action" would otherwise be
    routed to ``execute_tools``, which reads ``.tool`` and fails on an AgentFinish.
    """
    if isinstance(state["agent_outcome"], AgentFinish):
        return "end"
    return "continue"

In [6]:
# Two-node ReAct loop: "agent" decides what to do and "action" runs the chosen tool.
# Control returns to "agent" until should_continue reports a final answer.
workflow = StateGraph(AgentState)
for node_name, node_fn in (("agent", run_agent), ("action", execute_tools)):
    workflow.add_node(node_name, node_fn)

workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {"continue": "action", "end": END},
)
workflow.add_edge("action", "agent")

app = workflow.compile()

In [11]:
inputs = {"input": "Whats the cost of 5 MSFT shares?", "chat_history": []}
results = []
for chunk in app.stream(inputs):
    # Each streamed chunk maps the node that just ran to the state update it returned.
    node_update = [*chunk.values()][0]
    results.append(node_update)
    if 'agent_outcome' not in node_update:
        continue
    outcome = node_update['agent_outcome']
    # Only the agent's final answer is printed; tool calls log themselves.
    if isinstance(outcome, AgentFinish):
        print(outcome.return_values['output'])

Calling tool: get_stock_price MSFT
Called `execute_tools` 441.5799865722656
Calling tool: calculate_share_price {'number_of_shares': '5', 'price': '441.5799865722656'}
Called `execute_tools` 2207.89
The cost of 5 MSFT shares is approximately 2207.90 dollars.
