Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/strands/experimental/bidirectional_streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,37 @@
"""Bidirectional streaming package for real-time audio/text conversations."""
"""Bidirectional streaming package."""

# Main components - Primary user interface
from .agent.agent import BidirectionalAgent

# Advanced interfaces (for custom implementations)
from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession

# Model providers - What users need to create models
from .models.novasonic import NovaSonicBidirectionalModel

# Event types - For type hints and event handling
from .types.bidirectional_streaming import (
AudioInputEvent,
AudioOutputEvent,
BidirectionalStreamEvent,
InterruptionDetectedEvent,
TextOutputEvent,
UsageMetricsEvent,
)

# Public API of the package: names re-exported by `from <package> import *`.
# Every entry corresponds to one of the imports at the top of this module.
__all__ = [
# Main interface
"BidirectionalAgent",
# Model providers
"NovaSonicBidirectionalModel",
# Event types
"AudioInputEvent",
"AudioOutputEvent",
"TextOutputEvent",
"InterruptionDetectedEvent",
"BidirectionalStreamEvent",
"UsageMetricsEvent",
# Model interface
"BidirectionalModel",
"BidirectionalModelSession",
]
297 changes: 288 additions & 9 deletions src/strands/experimental/bidirectional_streaming/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,32 @@
"""

import asyncio
import json
import logging
from typing import AsyncIterable
import random
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterable, Callable, Mapping, Optional

from .... import _identifier
from ....hooks import HookProvider, HookRegistry
from ....telemetry.metrics import EventLoopMetrics
from ....tools.executors import ConcurrentToolExecutor
from ....tools.executors._executor import ToolExecutor
from ....tools.registry import ToolRegistry
from ....types.content import Messages
from ....tools.watcher import ToolWatcher
from ....types.content import Message, Messages
from ....types.tools import ToolResult, ToolUse
from ....types.traces import AttributeValue
from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection
from ..models.bidirectional_model import BidirectionalModel
from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Fallback values applied by BidirectionalAgent.__init__ when the caller passes no `name` / `agent_id`.
_DEFAULT_AGENT_NAME = "Strands Agents"
_DEFAULT_AGENT_ID = "default"


class BidirectionalAgent:
"""Agent for bidirectional streaming conversations.
Expand All @@ -34,12 +47,125 @@ class BidirectionalAgent:
sessions. Supports concurrent tool execution and interruption handling.
"""

class ToolCaller:
    """Expose the agent's registered tools as callable attributes.

    Accessing ``agent.tool.<name>`` yields a function that runs the named tool synchronously and, when
    recording is enabled, appends the call to the agent's message history.
    """

    def __init__(self, agent: "BidirectionalAgent") -> None:
        """Initialize tool caller with agent reference.

        Args:
            agent: The agent whose tool registry and message history are used.
        """
        # WARNING: Do not add any other member variables or methods as this could result in a name conflict with
        # agent tools and thus break their execution.
        self._agent = agent

    def __getattr__(self, name: str) -> Callable[..., Any]:
        """Return a function that invokes the tool called ``name``.

        This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
        It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').

        The lookup is lazy: attribute access itself always succeeds, and the tool name is only resolved
        when the returned function is called.

        Args:
            name: The name of the attribute (tool) being accessed.

        Returns:
            A function that when called will execute the named tool.
        """

        def caller(
            user_message_override: Optional[str] = None,
            record_direct_tool_call: Optional[bool] = None,
            **kwargs: Any,
        ) -> Any:
            """Call a tool directly by name.

            Args:
                user_message_override: Optional custom message to record instead of default
                record_direct_tool_call: Whether to record this call in the message history.
                    When None, the agent's ``record_direct_tool_call`` setting is used.
                **kwargs: Keyword arguments to pass to the tool.

            Returns:
                The result returned by the tool.

            Raises:
                AttributeError: If no registered tool matches the requested name.
            """
            normalized_name = self._find_normalized_tool_name(name)

            # Create unique tool ID and set up the tool request
            tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
            tool_use: ToolUse = {
                "toolUseId": tool_id,
                "name": normalized_name,
                # Shallow copy of the arguments; invocation_state below is the kwargs dict itself.
                "input": kwargs.copy(),
            }
            tool_results: list[ToolResult] = []
            invocation_state = kwargs

            async def acall() -> ToolResult:
                # Streamed events are discarded; the result is read from tool_results afterwards.
                async for _event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
                    pass

                return tool_results[0]

            def tcall() -> ToolResult:
                return asyncio.run(acall())

            # asyncio.run() refuses to start when an event loop is already running in the current thread,
            # so the coroutine is driven from a worker thread, which has no running loop of its own.
            with ThreadPoolExecutor() as executor:
                future = executor.submit(tcall)
                tool_result = future.result()

            # An explicit per-call argument wins; otherwise fall back to the agent-wide setting.
            if record_direct_tool_call is not None:
                should_record_direct_tool_call = record_direct_tool_call
            else:
                should_record_direct_tool_call = self._agent.record_direct_tool_call

            if should_record_direct_tool_call:
                # Create a record of this tool execution in the message history
                self._agent._record_tool_execution(tool_use, tool_result, user_message_override)

            return tool_result

        return caller

    def _find_normalized_tool_name(self, name: str) -> str:
        """Resolve a Python-style attribute name to the name of a registered tool.

        Args:
            name: The requested tool name, possibly using underscores where the tool name has dashes.

        Returns:
            ``name`` itself when it is a registry key, otherwise the first registry key that equals
            ``name`` once its dashes are replaced with underscores.

        Raises:
            AttributeError: If no registered tool matches.
        """
        tool_registry = self._agent.tool_registry.registry

        # Membership, not truthiness: a registered tool must be found even if its object is falsy.
        if name in tool_registry:
            return name

        # If the desired name contains underscores, it might be a placeholder for characters that can't be
        # represented as python identifiers but are valid as tool names, such as dashes.
        if "_" in name:
            # The registry itself defends against similar names, so we can just take the first match
            match = next(
                (tool_name for tool_name in tool_registry if tool_name.replace("-", "_") == name),
                None,
            )
            if match is not None:
                return match

        raise AttributeError(f"Tool '{name}' not found")

def __init__(
    self,
    model: BidirectionalModel,
    tools: list | None = None,
    system_prompt: str | None = None,
    messages: Messages | None = None,
    record_direct_tool_call: bool = True,
    load_tools_from_directory: bool = False,
    agent_id: Optional[str] = None,
    name: Optional[str] = None,
    tool_executor: Optional[ToolExecutor] = None,
    hooks: Optional[list[HookProvider]] = None,
    trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
    description: Optional[str] = None,
):
    """Initialize bidirectional agent with required model and optional configuration.

    Args:
        model: Bidirectional model used for the conversation.
        tools: Optional list of tools available to the model.
        system_prompt: Optional system prompt for conversations.
        messages: Optional conversation history to initialize with. A new empty list is used when
            this is omitted or empty.
        record_direct_tool_call: Whether to record direct tool calls in message history.
        load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
        agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios.
        name: Name of the Agent.
        tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
            Defaults to a ConcurrentToolExecutor.
        hooks: Hooks to be added to the agent hook registry.
        trace_attributes: Custom trace attributes to apply to the agent's trace span. Entries whose value
            is not a str, int, float, bool, or a list of those are dropped.
        description: Description of what the Agent does.
    """
    self.model = model
    self.system_prompt = system_prompt
    self.messages = messages or []

    # Agent identification
    self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
    self.name = name or _DEFAULT_AGENT_NAME
    self.description = description

    # Tool execution configuration
    self.record_direct_tool_call = record_direct_tool_call
    self.load_tools_from_directory = load_tools_from_directory

    # Process trace attributes to ensure they're of compatible types
    self.trace_attributes: dict[str, AttributeValue] = {}
    if trace_attributes:
        for k, v in trace_attributes.items():
            if isinstance(v, (str, int, float, bool)) or (
                isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v)
            ):
                self.trace_attributes[k] = v

    # Initialize tool registry
    self.tool_registry = ToolRegistry()

    if tools is not None:
        self.tool_registry.process_tools(tools)

    # Runs exactly once, after any explicitly supplied tools have been processed.
    self.tool_registry.initialize_tools(self.load_tools_from_directory)

    # Initialize tool watcher if directory loading is enabled
    if self.load_tools_from_directory:
        self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry)

    # Initialize tool executor
    self.tool_executor = tool_executor or ConcurrentToolExecutor()

    # Initialize hooks system
    self.hooks = HookRegistry()
    if hooks:
        for hook in hooks:
            self.hooks.add_hook(hook)

    # Initialize other components
    self.event_loop_metrics = EventLoopMetrics()
    self.tool_caller = BidirectionalAgent.ToolCaller(self)

    # Session management
    self._session = None
    self._output_queue = asyncio.Queue()

@property
def tool(self) -> ToolCaller:
    """Access the agent's tools through a method-style interface.

    Returns:
        The tool caller stored on this agent; each attribute of it is a function that invokes
        the tool of that name.

    Example:
        ```
        agent = BidirectionalAgent(model=model, tools=[calculator])
        agent.tool.calculator(expression="2+2")
        ```
    """
    caller = self.tool_caller
    return caller

@property
def tool_names(self) -> list[str]:
    """List the names of every tool registered with this agent.

    Returns:
        The keys of the tool registry's configuration mapping, in the mapping's iteration order.
    """
    return list(self.tool_registry.get_all_tools_config().keys())

def _record_tool_execution(
    self,
    tool: ToolUse,
    tool_result: ToolResult,
    user_message_override: Optional[str],
) -> None:
    """Append a direct tool call to the message history.

    Four messages are appended, in this order:

    1. A user message describing the tool call (preceded by the override text, if any)
    2. An assistant message with the tool use
    3. A user message with the tool result
    4. An assistant message acknowledging the tool call

    Args:
        tool: The tool call information.
        tool_result: The result returned by the tool.
        user_message_override: Optional custom message placed before the default description.
    """
    tool_name = tool["name"]

    # Only parameters declared in the tool spec are written to the history.
    recorded_input = self._filter_tool_parameters_for_recording(tool_name, tool["input"])

    # Values json cannot encode are replaced by a placeholder naming their type.
    serialized_input = json.dumps(
        recorded_input, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>"
    )

    description = {"text": f"agent.tool.{tool_name} direct tool call.\nInput parameters: {serialized_input}\n"}
    if user_message_override:
        request_content = [{"text": f"{user_message_override}\n"}, description]
    else:
        request_content = [description]

    # Same id and name as the executed call, but carrying the filtered input.
    recorded_tool_use: ToolUse = {
        "toolUseId": tool["toolUseId"],
        "name": tool_name,
        "input": recorded_input,
    }

    history: list[Message] = [
        {"role": "user", "content": request_content},
        {"role": "assistant", "content": [{"toolUse": recorded_tool_use}]},
        {"role": "user", "content": [{"toolResult": tool_result}]},
        {"role": "assistant", "content": [{"text": f"agent.tool.{tool_name} was called."}]},
    ]
    for entry in history:
        self.messages.append(entry)

    logger.debug("Direct tool call recorded in message history: %s", tool_name)

def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
"""Filter input parameters to only include those defined in the tool specification.

Args:
tool_name: Name of the tool to get specification for
input_params: Original input parameters

Returns:
Filtered parameters containing only those defined in tool spec
"""
all_tools_config = self.tool_registry.get_all_tools_config()
tool_spec = all_tools_config.get(tool_name)

if not tool_spec or "inputSchema" not in tool_spec:
return input_params.copy()

properties = tool_spec["inputSchema"]["json"]["properties"]
return {k: v for k, v in input_params.items() if k in properties}

async def start(self) -> None:
"""Start a persistent bidirectional conversation session.

Expand Down
Loading
Loading