Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions livekit-agents/livekit/agents/llm/_provider_format/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,12 @@ def to_fnc_ctx(tool_ctx: llm.ToolContext, *, strict: bool = True) -> list[dict[s
return schemas


def to_responses_fnc_ctx(tool_ctx: llm.ToolContext, *, strict: bool = True) -> list[dict[str, Any]]:
from livekit.plugins import openai

def to_responses_fnc_ctx(
tool_ctx: llm.ToolContext,
*,
strict: bool = True,
provider_tool_type: type[llm.ProviderTool],
) -> list[dict[str, Any]]:
schemas: list[dict[str, Any]] = []
for tool in tool_ctx.flatten():
if isinstance(tool, llm.RawFunctionTool):
Expand All @@ -230,7 +233,7 @@ def to_responses_fnc_ctx(tool_ctx: llm.ToolContext, *, strict: bool = True) -> l
elif isinstance(tool, llm.FunctionTool):
schema = llm.utils.build_legacy_openai_schema(tool, internally_tagged=True)
schemas.append(schema)
elif isinstance(tool, openai.tools.OpenAITool):
elif isinstance(tool, provider_tool_type) and hasattr(tool, "to_dict"):
schemas.append(tool.to_dict())

return schemas
11 changes: 10 additions & 1 deletion livekit-agents/livekit/agents/llm/tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,16 @@ def copy(self) -> ToolContext:

@overload
def parse_function_tools(
self, format: Literal["openai", "openai.responses"], *, strict: bool = True
self, format: Literal["openai"], *, strict: bool = True
) -> list[dict[str, Any]]: ...

@overload
def parse_function_tools(
self,
format: Literal["openai.responses"],
*,
strict: bool = True,
provider_tool_type: type[ProviderTool],
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unused?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this adds the provider_tool_type type argument so we can check is_instance against either OpenAIProviderTool or XAIProviderTool so we can avoid adding provider tools from other providers, before the is_instance hardcoded the OpenAIProviderTool type which meant XAIProviderTools were never sent.

) -> list[dict[str, Any]]: ...

@overload
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any

from livekit.agents import ProviderTool


class AnthropicTool(ProviderTool, ABC):
class AnthropicTool(ProviderTool):
"""Base class for Anthropic server-side provider tools."""

@abstractmethod
def to_dict(self) -> dict[str, Any]: ...

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any

from livekit.agents import ProviderTool


class MistralTool(ProviderTool, ABC):
class MistralTool(ProviderTool):
"""Base class for Mistral server-side provider tools."""

@abstractmethod
def to_dict(self) -> dict[str, Any]: ...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from ..log import logger
from ..models import _supports_reasoning_effort
from ..tools import OpenAITool

ServiceTier = Literal["auto", "default", "flex", "scale", "priority"]
Verbosity = Literal["low", "medium", "high"]
Expand Down Expand Up @@ -151,6 +152,10 @@ class _LLMOptions:


class LLM(llm.LLM):
# the plugin's ProviderTool subclass; subclasses (e.g. xAI) override this so server-side
# provider tools are recognized when serializing the request. See to_responses_fnc_ctx.
_provider_tool_type: type[llm.ProviderTool] = OpenAITool

def __init__(
self,
*,
Expand Down Expand Up @@ -406,7 +411,9 @@ async def _run_impl(self) -> None:
tool_schemas = cast(
list[ToolParam],
self._tool_ctx.parse_function_tools(
"openai.responses", strict=self._strict_tool_schema
"openai.responses",
strict=self._strict_tool_schema,
provider_tool_type=self._llm._provider_tool_type,
),
)

Expand Down Expand Up @@ -540,6 +547,21 @@ def _handle_response_created(self, event: ResponseCreatedEvent) -> None:
self._response_id = event.response.id

def _handle_response_completed(self, event: ResponseCompletedEvent) -> llm.ChatChunk | None:
for item in event.response.output:
# Every item.type is a discriminator of openai's ResponseOutputItem union.
# Of those, only these are produced/consumed by the agent itself; all other
# members of the union are tools the Responses API runs server-side (e.g.
# openai web_search, xAI web_search and x_search's custom_tool_call subcalls),
# so anything not in this set is a provider-executed tool.
if item.type not in ("message", "reasoning", "function_call", "function_call_output"):
logger.info(
"provider tool executed",
extra={
"tool_type": item.type,
"result": item.model_dump(exclude_none=True),
},
)

self._response_completed = True
self._llm._prev_chat_ctx = self._full_chat_ctx
self._llm._prev_resp_id = self._response_id
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any, Literal

from livekit.agents import ProviderTool
from openai.types import responses


class OpenAITool(ProviderTool, ABC):
class OpenAITool(ProviderTool):
"""Base class for OpenAI server-side provider tools."""

@abstractmethod
def to_dict(self) -> dict[str, Any]: ...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
from livekit.agents.utils import is_given
from livekit.plugins import openai

from ..tools import XAITool

XAI_BASE_URL = "https://api.x.ai/v1"


class LLM(openai.responses.LLM):
# xAI's server-side tools (web_search, x_search, file_search) subclass XAITool.
_provider_tool_type = XAITool

def __init__(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any

from livekit.agents import ProviderTool


class XAITool(ProviderTool, ABC):
class XAITool(ProviderTool):
"""Base class for xAI server-side provider tools."""

@abstractmethod
def to_dict(self) -> dict[str, Any]: ...

Expand Down
Loading