diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 52822711a..0f842ee9f 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,2 +1,37 @@ -"""Bidirectional streaming package for real-time audio/text conversations.""" +"""Bidirectional streaming package.""" +# Main components - Primary user interface +from .agent.agent import BidirectionalAgent + +# Advanced interfaces (for custom implementations) +from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession + +# Model providers - What users need to create models +from .models.novasonic import NovaSonicBidirectionalModel + +# Event types - For type hints and event handling +from .types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, + UsageMetricsEvent, +) + +__all__ = [ + # Main interface + "BidirectionalAgent", + # Model providers + "NovaSonicBidirectionalModel", + # Event types + "AudioInputEvent", + "AudioOutputEvent", + "TextOutputEvent", + "InterruptionDetectedEvent", + "BidirectionalStreamEvent", + "UsageMetricsEvent", + # Model interface + "BidirectionalModel", + "BidirectionalModelSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 68d371a51..26b964c53 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -13,12 +13,22 @@ """ import asyncio +import json import logging -from typing import AsyncIterable +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Any, AsyncIterable, Callable, Mapping, Optional +from .... import _identifier +from ....hooks import HookProvider, HookRegistry +from ....telemetry.metrics import EventLoopMetrics from ....tools.executors import ConcurrentToolExecutor +from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry -from ....types.content import Messages +from ....tools.watcher import ToolWatcher +from ....types.content import Message, Messages +from ....types.tools import ToolResult, ToolUse +from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent @@ -26,6 +36,9 @@ logger = logging.getLogger(__name__) +_DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" + class BidirectionalAgent: """Agent for bidirectional streaming conversations. @@ -34,12 +47,125 @@ class BidirectionalAgent: sessions. Supports concurrent tool execution and interruption handling. """ + class ToolCaller: + """Call tool as a function for bidirectional agent.""" + + def __init__(self, agent: "BidirectionalAgent") -> None: + """Initialize tool caller with agent reference.""" + # WARNING: Do not add any other member variables or methods as this could result in a name conflict with + # agent tools and thus break their execution. + self._agent = agent + + def __getattr__(self, name: str) -> Callable[..., Any]: + """Call tool as a function. + + This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). + + Args: + name: The name of the attribute (tool) being accessed. + + Returns: + A function that when called will execute the named tool. + + Raises: + AttributeError: If no tool with the given name exists or if multiple tools match the given name. + """ + + def caller( + user_message_override: Optional[str] = None, + record_direct_tool_call: Optional[bool] = None, + **kwargs: Any, + ) -> Any: + """Call a tool directly by name. + + Args: + user_message_override: Optional custom message to record instead of default + record_direct_tool_call: Whether to record direct tool calls in message history. + For bidirectional agents, this is always True to maintain conversation history. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result returned by the tool. + + Raises: + AttributeError: If the tool doesn't exist. + """ + normalized_name = self._find_normalized_tool_name(name) + + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs + + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + _ = event + + return tool_results[0] + + def tcall() -> ToolResult: + return asyncio.run(acall()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(tcall) + tool_result = future.result() + + # Always record direct tool calls for bidirectional agents to maintain conversation history + # Use agent's record_direct_tool_call setting if not overridden + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + + return tool_result + + return caller + + def _find_normalized_tool_name(self, name: str) -> str: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + def __init__( self, model: BidirectionalModel, tools: list | None = None, system_prompt: str | None = None, messages: Messages | None = None, + record_direct_tool_call: bool = True, + load_tools_from_directory: bool = False, + agent_id: Optional[str] = None, + name: Optional[str] = None, + tool_executor: Optional[ToolExecutor] = None, + hooks: Optional[list[HookProvider]] = None, + trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + description: Optional[str] = None, ): """Initialize bidirectional agent with required model and optional configuration. @@ -48,24 +174,177 @@ def __init__( tools: Optional list of tools available to the model. system_prompt: Optional system prompt for conversations. messages: Optional conversation history to initialize with. + record_direct_tool_call: Whether to record direct tool calls in message history. + load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. + agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios. + name: Name of the Agent. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + hooks: Hooks to be added to the agent hook registry. + trace_attributes: Custom trace attributes to apply to the agent's trace span. + description: Description of what the Agent does. """ self.model = model self.system_prompt = system_prompt self.messages = messages or [] - - # Initialize tool registry using existing Strands infrastructure + + # Agent identification + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) + self.name = name or _DEFAULT_AGENT_NAME + self.description = description + + # Tool execution configuration + self.record_direct_tool_call = record_direct_tool_call + self.load_tools_from_directory = load_tools_from_directory + + # Process trace attributes to ensure they're of compatible types + self.trace_attributes: dict[str, AttributeValue] = {} + if trace_attributes: + for k, v in trace_attributes.items(): + if isinstance(v, (str, int, float, bool)) or ( + isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v) + ): + self.trace_attributes[k] = v + + # Initialize tool registry self.tool_registry = ToolRegistry() - if tools: + + if tools is not None: self.tool_registry.process_tools(tools) - self.tool_registry.initialize_tools() - - # Initialize tool executor for concurrent execution - self.tool_executor = ConcurrentToolExecutor() + + self.tool_registry.initialize_tools(self.load_tools_from_directory) + + # Initialize tool watcher if directory loading is enabled + if self.load_tools_from_directory: + self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) + + # Initialize tool executor + self.tool_executor = tool_executor or ConcurrentToolExecutor() + + # Initialize hooks system + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + + # Initialize other components + self.event_loop_metrics = EventLoopMetrics() + self.tool_caller = BidirectionalAgent.ToolCaller(self) # Session management self._session = None self._output_queue = asyncio.Queue() + @property + def tool(self) -> ToolCaller: + """Call tool as a function. + + Returns: + Tool caller through which user can invoke tool as a function. + + Example: + ``` + agent = BidirectionalAgent(model=model, tools=[calculator]) + agent.tool.calculator(expression="2+2") + ``` + """ + return self.tool_caller + + @property + def tool_names(self) -> list[str]: + """Get a list of all registered tool names. + + Returns: + Names of all tools available to this agent. + """ + all_tools = self.tool_registry.get_all_tools_config() + return list(all_tools.keys()) + + def _record_tool_execution( + self, + tool: ToolUse, + tool_result: ToolResult, + user_message_override: Optional[str], + ) -> None: + """Record a tool execution in the message history. + + Creates a sequence of messages that represent the tool execution: + + 1. A user message describing the tool call + 2. An assistant message with the tool use + 3. A user message with the tool result + 4. An assistant message acknowledging the tool call + + Args: + tool: The tool call information. + tool_result: The result returned by the tool. + user_message_override: Optional custom message to include. + """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + + # Create user message describing the tool call + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") + + user_msg_content = [ + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} + ] + + # Add override message if provided + if user_message_override: + user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + + # Create the message sequence + user_msg: Message = { + "role": "user", + "content": user_msg_content, + } + tool_use_msg: Message = { + "role": "assistant", + "content": [{"toolUse": filtered_tool}], + } + tool_result_msg: Message = { + "role": "user", + "content": [{"toolResult": tool_result}], + } + assistant_msg: Message = { + "role": "assistant", + "content": [{"text": f"agent.tool.{tool['name']} was called."}], + } + + # Add to message history + self.messages.append(user_msg) + self.messages.append(tool_use_msg) + self.messages.append(tool_result_msg) + self.messages.append(assistant_msg) + + logger.debug("Direct tool call recorded in message history: %s", tool["name"]) + + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} + async def start(self) -> None: """Start a persistent bidirectional conversation session. diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 16be08aaf..69f5d759d 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -12,12 +12,13 @@ """ import asyncio -import json import logging import traceback import uuid from ....tools._validator import validate_and_prepare_tools +from ....telemetry.metrics import Trace +from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession @@ -59,6 +60,9 @@ def __init__(self, model_session: BidirectionalModelSession, agent: "Bidirection # Interruption handling (model-agnostic) self.interrupted = False self.interruption_lock = asyncio.Lock() + + # Tool execution tracking + self.tool_count = 0 async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: @@ -195,11 +199,11 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: # Cancel all pending tool execution tasks cancelled_tools = 0 - for task_id, task in list(session.pending_tool_tasks.items()): + for _task_id, task in list(session.pending_tool_tasks.items()): if not task.done(): task.cancel() cancelled_tools += 1 - logger.debug("Tool task cancelled: %s", task_id) + logger.debug("Tool task cancelled: %s", _task_id) if cancelled_tools > 0: logger.debug("Tool tasks cancelled: %d", cancelled_tools) @@ -274,7 +278,8 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Queue tool requests for concurrent execution if strands_event.get("toolUse"): - logger.debug("Tool queued: %s", strands_event["toolUse"].get("name")) + tool_name = strands_event["toolUse"].get("name") + logger.debug("Tool usage detected: %s", tool_name) await session.tool_queue.put(strands_event["toolUse"]) continue @@ -316,7 +321,13 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: while session.active: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - logger.debug("Tool execution started: %s (id: %s)", tool_use.get("name"), tool_use.get("toolUseId")) + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + + session.tool_count += 1 + print(f"\nTool #{session.tool_count}: {tool_name}") + + logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) @@ -330,11 +341,11 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: # Log completion status if completed_task.cancelled(): - logger.debug("Tool task cleanup cancelled: %s", task_id) + logger.debug("Tool task cancelled: %s", task_id) elif completed_task.exception(): - logger.error("Tool task cleanup error: %s - %s", task_id, str(completed_task.exception())) + logger.error("Tool task error: %s - %s", task_id, str(completed_task.exception())) else: - logger.debug("Tool task cleanup success: %s", task_id) + logger.debug("Tool task completed: %s", task_id) except Exception as e: logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) @@ -365,94 +376,106 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: - """Execute tool using Strands infrastructure with interruption support. - - Executes tools using the existing Strands tool system with proper asyncio - cancellation handling. Tool execution is stopped via task cancellation, - not manual state checks. - + """Execute tool using the complete Strands tool execution system. + + Uses proper Strands ToolExecutor system with validation, error handling, + and event streaming. + Args: session: BidirectionalConnection for context. tool_use: Tool use event to execute. """ tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - + + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) + try: - # Create message structure for existing tool system + # Create message structure for validation tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - + + # Use Strands validation system tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - - # Validate using existing Strands validation + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - - # Filter valid tool uses + + # Filter valid tools valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: - logger.warning("Tool validation failed: %s (id: %s)", tool_name, tool_id) + logger.warning("No valid tools after validation: %s", tool_name) return - - # Execute tools directly (simpler approach for bidirectional) - for tool_use in valid_tool_uses: - tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) - - if tool_func: - try: - actual_func = _extract_callable_function(tool_func) - - # Execute tool function with provided input - result = actual_func(**tool_use.get("input", {})) - - tool_result = _create_success_result(tool_use["toolUseId"], result) - tool_results.append(tool_result) - - except Exception as e: - logger.error("Tool execution failed: %s - %s", tool_name, str(e)) - tool_result = _create_error_result(tool_use["toolUseId"], str(e)) - tool_results.append(tool_result) - else: - logger.warning("Tool not found: %s", tool_name) - - # Send results through provider-specific session - for result in tool_results: - await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) - - logger.debug("Tool execution completed: %s (%d results)", tool_name, len(tool_results)) - + + # Create invocation state for tool execution + invocation_state = { + "agent": session.agent, + "model": session.agent.model, + "messages": session.agent.messages, + "system_prompt": session.agent.system_prompt, + } + + # Create cycle trace and span + cycle_trace = Trace("Bidirectional Tool Execution") + cycle_span = None + + tool_events = session.agent.tool_executor._execute( + session.agent, + valid_tool_uses, + tool_results, + cycle_trace, + cycle_span, + invocation_state + ) + + # Process tool events and send results to provider + async for tool_event in tool_events: + if isinstance(tool_event, ToolResultEvent): + tool_result = tool_event.tool_result + tool_use_id = tool_result.get("toolUseId") + + # Send result through provider-specific session + await session.model_session.send_tool_result(tool_use_id, tool_result) + logger.debug("Tool result sent: %s", tool_use_id) + + # Handle streaming events if needed later + elif isinstance(tool_event, ToolStreamEvent): + logger.debug("Tool stream event: %s", tool_event) + pass + + # Add tool result message to conversation history + if tool_results: + from ....hooks import MessageAddedEvent + + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + session.agent.messages.append(tool_result_message) + session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) + logger.debug("Tool result message added to history: %s", tool_name) + + logger.debug("Tool execution completed: %s", tool_name) + except asyncio.CancelledError: - # Task was cancelled due to interruption - this is expected behavior - logger.debug("Tool task cancelled gracefully: %s (id: %s)", tool_name, tool_id) - raise # Re-raise to properly handle cancellation + logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) + raise except Exception as e: - logger.error("Tool execution error: %s - %s", tool_use.get("name"), str(e)) + logger.error("Tool execution error: %s - %s", tool_name, str(e)) + # Send error result + error_result: ToolResult = { + "toolUseId": tool_id, + "status": "error", + "content": [{"text": f"Error: {str(e)}"}] + } try: - await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) - except Exception as send_error: - logger.error("Tool error send failed: %s", str(send_error)) - - -def _extract_callable_function(tool_func: any) -> any: - """Extract the callable function from different tool object types.""" - if hasattr(tool_func, "_tool_func"): - return tool_func._tool_func - elif hasattr(tool_func, "func"): - return tool_func.func - elif callable(tool_func): - return tool_func - else: - raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") - - -def _create_success_result(tool_use_id: str, result: any) -> dict[str, any]: - """Create a successful tool result.""" - return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} + await session.model_session.send_tool_result(tool_id, error_result) + logger.debug("Error result sent: %s", tool_id) + except Exception: + logger.error("Failed to send error result: %s", tool_id) + pass # Session might be closed -def _create_error_result(tool_use_id: str, error: str) -> dict[str, any]: - """Create an error tool result.""" - return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 6cba974e0..3a785e98a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -3,4 +3,9 @@ from .bidirectional_model import BidirectionalModel, BidirectionalModelSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession -__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] +__all__ = [ + "BidirectionalModel", + "BidirectionalModelSession", + "NovaSonicBidirectionalModel", + "NovaSonicSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 7f7937ef1..7f35a3c1c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -35,6 +35,7 @@ BidirectionalConnectionStartEvent, InterruptionDetectedEvent, TextOutputEvent, + UsageMetricsEvent ) from .bidirectional_model import BidirectionalModel, BidirectionalModelSession @@ -121,7 +122,7 @@ async def initialize( init_events = self._build_initialization_events(system_prompt, tools or [], messages) - logger.debug(f"Nova Sonic initialization - sending {len(init_events)} events") + logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) await self._send_initialization_events(init_events) logger.info("Nova Sonic connection initialized successfully") @@ -146,7 +147,7 @@ def _build_initialization_events( async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" - for i, event in enumerate(events): + for _i, event in enumerate(events): await self._send_nova_event(event) await asyncio.sleep(EVENT_DELAY) @@ -167,12 +168,12 @@ async def _process_responses(self) -> None: await asyncio.sleep(0.1) continue except Exception as e: - logger.warning(f"Nova Sonic response error: {e}") + logger.warning("Nova Sonic response error: %s", e) await asyncio.sleep(0.1) continue except Exception as e: - logger.error(f"Nova Sonic fatal error: {e}") + logger.error("Nova Sonic fatal error: %s", e) finally: logger.debug("Nova Sonic response processor stopped") @@ -190,7 +191,7 @@ async def _handle_response_data(self, response_data: str) -> None: await self._event_queue.put(nova_event) except json.JSONDecodeError as e: - logger.warning(f"Nova Sonic JSON decode error: {e}") + logger.warning("Nova Sonic JSON decode error: %s", e) def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" @@ -383,11 +384,9 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No self._get_content_end_event(content_name), ] - for i, event in enumerate(events): + for _i, event in enumerate(events): await self._send_nova_event(event) - - async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: @@ -490,7 +489,14 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Handle usage events (ignore) elif "usageEvent" in nova_event: - return None + usage_data = nova_event["usageEvent"] + usage_metrics: UsageMetricsEvent = { + "totalTokens": usage_data.get("totalTokens"), + "inputTokens": usage_data.get("totalInputTokens"), + "outputTokens": usage_data.get("totalOutputTokens"), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens"), + } + return {"usageMetrics": usage_metrics} # Handle content start events (track role) elif "contentStart" in nova_event: diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py similarity index 89% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py rename to src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py index b31607966..8c3ae3b4c 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -10,6 +10,7 @@ # Add the src directory to Python path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import os import time import pyaudio @@ -19,6 +20,29 @@ from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +def test_direct_tools(): + """Test direct tool calling.""" + print("Testing direct tool calling...") + + # Check AWS credentials + if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")]): + print("AWS credentials not set - skipping test") + return + + try: + model = NovaSonicBidirectionalModel() + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + async def play(context): """Play audio output with responsive interruption support.""" audio = pyaudio.PyAudio() @@ -195,4 +219,7 @@ async def main(duration=180): if __name__ == "__main__": + # Test direct tool calling first + test_direct_tools() + asyncio.run(main()) diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 01d72356a..c0f6eb209 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -116,10 +116,28 @@ class BidirectionalConnectionEndEvent(TypedDict): metadata: Provider-specific connection metadata. """ - reason: Literal["user_request", "timeout", "error"] + reason: Literal["user_request", "timeout", "error", "connection_complete"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] +class UsageMetricsEvent(TypedDict): + """Token usage and performance tracking. + + Provides standardized usage metrics across providers for cost monitoring + and performance optimization. + + Attributes: + totalTokens: Total tokens used in the interaction. + inputTokens: Tokens used for input processing. + outputTokens: Tokens used for output generation. + audioTokens: Tokens used specifically for audio processing. + """ + + totalTokens: Optional[int] + inputTokens: Optional[int] + outputTokens: Optional[int] + audioTokens: Optional[int] + class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. @@ -134,11 +152,14 @@ class BidirectionalStreamEvent(StreamEvent, total=False): interruptionDetected: User interruption detection. BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. + usageMetrics: Token usage and performance metrics. """ - audioOutput: AudioOutputEvent - audioInput: AudioInputEvent - textOutput: TextOutputEvent - interruptionDetected: InterruptionDetectedEvent - BidirectionalConnectionStart: BidirectionalConnectionStartEvent - BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + audioOutput: Optional[AudioOutputEvent] + audioInput: Optional[AudioInputEvent] + textOutput: Optional[TextOutputEvent] + interruptionDetected: Optional[InterruptionDetectedEvent] + BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] + BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] + usageMetrics: Optional[UsageMetricsEvent] +