# 2. import

In [None]:
import os
from dotenv import find_dotenv, load_dotenv
# Load environment variables from the .env file located by find_dotenv();
# the return value is discarded. TAVILY_API_KEY is read from the environment
# later in this notebook.
_ = load_dotenv(find_dotenv())

from llama_index.core.llms import ChatMessage

from llama_index.core.workflow import Event
from typing import Any, List
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import Memory
from llama_index.core.tools import ToolSelection, ToolOutput

from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)
from llama_index.llms.openai import OpenAI

from llama_index.utils.workflow import draw_all_possible_flows

# 3. event

In [None]:
class InputEvent(Event):
    """Event carrying the chat history that the LLM should answer next.

    Emitted by `prepare_chat_history` and `handle_tool_calls`; consumed by
    `handle_llm_input`.
    """

    # Chat messages read back from memory at the time the event was created.
    input: list[ChatMessage]

class ToolCallEvent(Event):
    """Event carrying the tool calls extracted from an LLM response.

    Emitted by `handle_llm_input` when the response contains at least one
    tool call; consumed by `handle_tool_calls`.
    """

    # Tool selections (tool name, id and kwargs) requested by the LLM.
    tool_calls: list[ToolSelection]

class FunctionOutputEvent(Event):
    """Event wrapping a single tool output.

    NOTE(review): no step in the visible workflow emits or consumes this
    event — confirm whether it is still needed.
    """

    # Output object returned by a tool invocation.
    output: ToolOutput

class StreamEvent(Event):
    """Event written to the workflow's event stream for each LLM response.

    `handle_llm_input` writes one of these per LLM call via
    `ctx.write_event_to_stream`.
    """

    # The complete chat message returned by the LLM.
    msg: ChatMessage

# 4. Workflow

In [None]:
class FunctionCallingAgent(Workflow):
    """Agent workflow that alternates between an LLM and its tools.

    Flow of events:

    1. ``prepare_chat_history`` (StartEvent -> InputEvent) stores the user
       message in memory.
    2. ``handle_llm_input`` (InputEvent -> ToolCallEvent | StopEvent) calls the
       LLM; it stops when the response contains no tool calls.
    3. ``handle_tool_calls`` (ToolCallEvent -> InputEvent) runs the requested
       tools and feeds their outputs back to step 2.

    The context store keeps two keys across steps: ``"memory"`` (the chat
    memory) and ``"sources"`` (the tool outputs collected during the run).
    """

    def __init__(
        self,
        *args: Any,
        llm: FunctionCallingLLM | None = None,
        tools: List[BaseTool] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the agent.

        Args:
            *args: Positional arguments forwarded to ``Workflow.__init__``.
            llm: Function-calling LLM to use; defaults to ``OpenAI()``.
            tools: Tools the LLM may call; defaults to an empty list.
            **kwargs: Keyword arguments forwarded to ``Workflow.__init__``.

        Raises:
            AssertionError: If the LLM's metadata says it is not a
                function-calling model.
        """
        super().__init__(*args, **kwargs)
        self.tools = tools or []

        self.llm = llm or OpenAI()
        assert (
            self.llm.metadata.is_function_calling_model
        ), "FunctionCallingAgent requires a function-calling LLM"

    @step
    async def prepare_chat_history(
        self, ctx: Context, ev: StartEvent
    ) -> InputEvent:
        """Record the user's message and emit the chat history for the LLM.

        Args:
            ctx: Workflow context; its store holds "memory" and "sources".
            ev: Start event; ``ev.input`` is used as the user message content.

        Returns:
            An ``InputEvent`` with the chat history read back from memory.
        """
        # clear sources collected by a previous run
        await ctx.store.set("sources", [])

        # create the memory on first use, otherwise reuse the stored one
        memory = await ctx.store.get("memory", default=None)
        if not memory:
            memory = Memory.from_defaults(token_limit=40000)

        # store the user input
        user_msg = ChatMessage(role="user", content=ev.input)
        # Use the async memory API: this step already runs inside an event
        # loop, so the sync wrappers would block it or nest another loop.
        await memory.aput(user_msg)

        # get chat history
        chat_history = await memory.aget()

        # update context
        await ctx.store.set("memory", memory)

        return InputEvent(input=chat_history)

    @step
    async def handle_llm_input(
        self, ctx: Context, ev: InputEvent
    ) -> ToolCallEvent | StopEvent:
        """Call the LLM and either finish or request tool execution.

        Args:
            ctx: Workflow context; its store holds "memory" and "sources".
            ev: Input event carrying the chat history to send to the LLM.

        Returns:
            A ``StopEvent`` whose result is
            ``{"response": <LLM response>, "sources": [<tool outputs>]}`` when
            the LLM requested no tools, otherwise a ``ToolCallEvent`` with the
            requested tool calls.
        """
        chat_history = ev.input

        # Non-streaming call: the complete message is published to the event
        # stream as a single StreamEvent.
        response = await self.llm.achat_with_tools(
            self.tools, chat_history=chat_history
        )
        ctx.write_event_to_stream(StreamEvent(msg=response.message))

        # save the final response, which should have all content
        memory = await ctx.store.get("memory")
        await memory.aput(response.message)
        await ctx.store.set("memory", memory)

        # get tool calls; with error_on_no_tool_call=False a response without
        # tool calls yields an empty list instead of raising
        tool_calls = self.llm.get_tool_calls_from_response(
            response, error_on_no_tool_call=False
        )

        if not tool_calls:
            sources = await ctx.store.get("sources", default=[])
            return StopEvent(
                result={"response": response, "sources": [*sources]}
            )

        return ToolCallEvent(tool_calls=tool_calls)

    @step
    async def handle_tool_calls(
        self, ctx: Context, ev: ToolCallEvent
    ) -> InputEvent:
        """Run the tools requested by the LLM and report their results.

        Every requested call produces exactly one ``role="tool"`` message: the
        tool's output, a "does not exist" notice for an unknown tool name, or
        the error text if the tool raised.

        Args:
            ctx: Workflow context; its store holds "memory" and "sources".
            ev: Event carrying the tool calls requested by the LLM.

        Returns:
            An ``InputEvent`` with the chat history including the tool messages.
        """
        tool_calls = ev.tool_calls  # tools the model wants to call
        # tools that are actually available, keyed by name
        tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}

        tool_msgs = []
        sources = await ctx.store.get("sources", default=[])

        # call tools -- safely!
        for tool_call in tool_calls:
            tool = tools_by_name.get(tool_call.tool_name)
            additional_kwargs = {
                "tool_call_id": tool_call.tool_id,
                # The requested name equals the lookup key, so it matches
                # tool.metadata.get_name() whenever the tool exists, and it is
                # still available when the lookup returned None.
                "name": tool_call.tool_name,
            }
            if tool is None:
                tool_msgs.append(
                    ChatMessage(
                        role="tool",
                        content=f"Tool {tool_call.tool_name} does not exist",
                        additional_kwargs=additional_kwargs,
                    )
                )
                continue

            try:
                tool_output = tool(**tool_call.tool_kwargs)
                sources.append(tool_output)
                tool_msgs.append(
                    ChatMessage(
                        role="tool",
                        content=tool_output.content,
                        additional_kwargs=additional_kwargs,
                    )
                )
            except Exception as e:
                # Deliberately broad: any tool failure is reported back to the
                # LLM as a tool message instead of aborting the workflow.
                tool_msgs.append(
                    ChatMessage(
                        role="tool",
                        content=f"Encountered error in tool call: {e}",
                        additional_kwargs=additional_kwargs,
                    )
                )

        # update memory
        memory = await ctx.store.get("memory")
        for msg in tool_msgs:
            await memory.aput(msg)

        await ctx.store.set("sources", sources)
        await ctx.store.set("memory", memory)

        chat_history = await memory.aget()
        return InputEvent(input=chat_history)

# 5. workflow VIZ

In [None]:
# Draw the possible step/event flows of the FunctionCallingAgent workflow
# class; the output goes to the HTML file named by `filename`.
draw_all_possible_flows(
    FunctionCallingAgent, filename="day14_FunctionCallingAgent.html"
)

# 6. tools

In [None]:
from llama_index.tools.tavily_research.base import TavilyToolSpec

# Build the Tavily tool list; the API key comes from the environment.
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
tavily_tool = TavilyToolSpec(api_key=TAVILY_API_KEY)
tools = tavily_tool.to_tool_list()

# Describe every tool as "name: description", one tool per line.
tool_descs = [
    f"{spec.metadata.name}: {spec.metadata.description}" for spec in tools
]
tools_str = "\n".join(tool_descs)
tools_str

In [None]:
# Print the tool descriptions so the newline separators are rendered.
print(tools_str)

# 7. init

In [None]:
from llama_index.llms.openai import OpenAI
# NOTE(review): `streamimg` looks like a typo (perhaps `streaming` was meant).
# Confirm whether the OpenAI LLM class accepts such an option at all; as
# spelled, the keyword presumably has no effect.
llm = OpenAI(model="gpt-5-mini", streamimg=False)

In [None]:
# Build the agent with the Tavily tools. `timeout` and `verbose` are not
# parameters of FunctionCallingAgent itself; they are forwarded to
# Workflow.__init__ through **kwargs.
# NOTE(review): this constructs a second OpenAI client with the same arguments
# as the `llm` variable from the previous cell (including the suspicious
# `streamimg` keyword); passing `llm=llm` would avoid the duplication.
agent = FunctionCallingAgent(
    llm=OpenAI(model="gpt-5-mini", streamimg=False), tools=tools, timeout=120, verbose=True
)

In [None]:
# Smoke test with a plain greeting. Top-level `await` works here because this
# is a notebook cell.
ret = await agent.run(input="Hello!")

In [None]:
# The workflow result is the dict built in handle_llm_input, so the keys are
# "response" and "sources".
ret.keys()

In [None]:
# Ask (in Chinese) for Tesla's current real-time stock price in USD, then
# display the full result dict.
ret = await agent.run(input="幫我查一下現在及時的特斯拉股價(美金)")
ret