From 107f03533c2abde1586109ec4d4709a556801c21 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 25 Sep 2025 10:47:46 -0400 Subject: [PATCH 01/15] feat(bidirectional_streaming): Add experimental bidirectional streaming MVP POC implementation --- pyproject.toml | 9 +- .../bidirectional_streaming/agent/__init__.py | 2 + .../bidirectional_streaming/agent/agent.py | 167 ++++ .../event_loop/__init__.py | 2 + .../event_loop/bidirectional_event_loop.py | 539 ++++++++++++ .../models/__init__.py | 2 + .../models/bidirectional_model.py | 115 +++ .../models/novasonic.py | 777 ++++++++++++++++++ .../tests/test_bidirectional_streaming.py | 203 +++++ .../bidirectional_streaming/types/__init__.py | 3 + .../types/bidirectional_streaming.py | 167 ++++ .../bidirectional_streaming/utils/debug.py | 45 + 12 files changed, 2030 insertions(+), 1 deletion(-) create mode 100644 src/strands/experimental/bidirectional_streaming/agent/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/agent/agent.py create mode 100644 src/strands/experimental/bidirectional_streaming/event_loop/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/novasonic.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py create mode 100644 src/strands/experimental/bidirectional_streaming/types/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py create mode 100644 src/strands/experimental/bidirectional_streaming/utils/debug.py diff --git a/pyproject.toml b/pyproject.toml index 3c2243299..d4f7e6eee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,13 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +bidirectional-streaming = [ + "pyaudio>=0.2.13", + "rx>=3.2.0", + "smithy-aws-core>=0.0.1", + "pytz", + "aws_sdk_bedrock_runtime", +] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -68,7 +75,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py new file mode 100644 index 000000000..bbd2c91f3 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming agent package.""" +# Agent package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py new file mode 100644 index 000000000..cfc005576 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -0,0 +1,167 @@ +"""Bidirectional Agent for real-time streaming conversations. 
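+
+USAGE SKETCH:
+------------
+Illustrative only; mirrors the experimental test harness in
+tests/test_bidirectional_streaming.py rather than a finalized public API.
+
+    model = NovaSonicBidirectionalModel(region="us-east-1")
+    agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.")
+
+    await agent.start_conversation()
+    await agent.send_text("What is 2 + 2?")
+    async for event in agent.receive():
+        if "textOutput" in event:
+            print(event["textOutput"]["text"])
+        elif "audioOutput" in event:
+            pcm_bytes = event["audioOutput"]["audioData"]  # raw PCM to feed a playback queue
+    await agent.end_conversation()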
+ +AGENT PURPOSE: +------------- +Provides type-safe constructor and session management for real-time audio/text +interaction. Serves as the bidirectional equivalent to invoke_async() → stream_async() +but establishes sessions that continue indefinitely with concurrent task management. + +ARCHITECTURAL APPROACH: +---------------------- +While invoke_async() creates single request-response cycles that terminate after +stop_reason: "end_turn" with sequential tool processing, start_conversation() +establishes persistent sessions with concurrent processing of model events, tool +execution, and user input without session termination. + +DESIGN CHOICE: +------------- +Uses dedicated BidirectionalAgent class (Option 1 from design document) for: +- Type safety with no conditional behavior based on model type +- Separation of concerns - solely focused on bidirectional streaming +- Future proofing - allows changes without implications to existing Agent class +""" + +import asyncio +import logging +from typing import AsyncIterable, List, Optional + +from strands.tools.registry import ToolRegistry +from strands.types.content import Messages + +from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection +from ..models.bidirectional_model import BidirectionalModel +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent +from ..utils.debug import log_event, log_flow + +logger = logging.getLogger(__name__) + + +class BidirectionalAgent: + """Agent for bidirectional streaming conversations. + + Provides type-safe constructor and session management for real-time + audio/text interaction with concurrent processing capabilities. + """ + + def __init__( + self, + model: BidirectionalModel, + tools: Optional[List] = None, + system_prompt: Optional[str] = None, + messages: Optional[Messages] = None + ): + """Initialize bidirectional agent with required model and optional configuration. + + Args: + model: BidirectionalModel instance supporting streaming sessions. + tools: Optional list of tools available to the model. + system_prompt: Optional system prompt for conversations. + messages: Optional conversation history to initialize with. + """ + self.model = model + self.system_prompt = system_prompt + self.messages = messages or [] + + # Initialize tool registry using existing Strands infrastructure + self.tool_registry = ToolRegistry() + if tools: + self.tool_registry.process_tools(tools) + self.tool_registry.initialize_tools() + + # Initialize tool executor for concurrent execution + from strands.tools.executors import ConcurrentToolExecutor + self.tool_executor = ConcurrentToolExecutor() + + # Session management + self._session = None + self._output_queue = asyncio.Queue() + + async def start_conversation(self) -> None: + """Initialize persistent bidirectional session for real-time interaction. + + Creates provider-specific session and starts concurrent background tasks + for model events, tool execution, and session lifecycle management. + + Raises: + ValueError: If conversation already active. + ConnectionError: If session creation fails. + """ + if self._session and self._session.active: + raise ValueError("Conversation already active. 
Call end_conversation() first.") + + log_flow("conversation_start", "initializing session") + self._session = await start_bidirectional_connection(self) + log_event("conversation_ready") + + async def send_text(self, text: str) -> None: + """Send text input during active session without interrupting model generation. + + Args: + text: Text message to send to the model. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + log_event("text_sent", length=len(text)) + await self._session.model_session.send_text_content(text) + + async def send_audio(self, audio_input: AudioInputEvent) -> None: + """Send audio input during active session for real-time speech interaction. + + Args: + audio_input: AudioInputEvent containing audio data and configuration. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + await self._session.model_session.send_audio_content(audio_input) + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive output events from the model including audio, text. + + Provides access to model output events processed by background tasks. + Events include audio output, text responses, tool calls, and session updates. + + Yields: + BidirectionalStreamEvent: Events from the model session. + """ + while self._session and self._session.active: + try: + event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) + yield event + except asyncio.TimeoutError: + continue + + async def interrupt(self) -> None: + """Interrupt current model generation and switch to listening mode. + + Sends interruption signal to immediately stop generation and clear + pending audio output for responsive conversational experience. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + await self._session.model_session.send_interrupt() + + async def end_conversation(self) -> None: + """End session and cleanup resources including background tasks. + + Performs graceful session termination with proper resource cleanup + including background task cancellation and connection closure. + """ + if self._session: + await stop_bidirectional_connection(self._session) + self._session = None + + def _validate_active_session(self) -> None: + """Validate that an active session exists. + + Raises: + ValueError: If no active session. + """ + if not self._session or not self._session.active: + raise ValueError("No active conversation. Call start_conversation() first.") + diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py new file mode 100644 index 000000000..24080b703 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming event loop package.""" +# Event Loop package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py new file mode 100644 index 000000000..2164115d8 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -0,0 +1,539 @@ +"""Bidirectional session management for concurrent streaming conversations. 
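+
+TASK LAYOUT SKETCH:
+------------------
+Abridged view of the concurrent tasks created by start_bidirectional_connection() below:
+
+    BidirectionalConnection
+        main_cycle_task      -> bidirectional_event_loop_cycle()  (supervision)
+        background_tasks[0]  -> _process_model_events()           (model events -> agent output queue)
+        background_tasks[1]  -> _process_tool_execution()         (tool queue -> per-tool tasks)
+        pending_tool_tasks   -> _execute_tool_with_strands()      (one task per queued toolUse)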
+ +SESSION PURPOSE: +--------------- +Session wrapper for bidirectional communication that manages concurrent tasks for +model events, tool execution, and audio processing while providing simple interface +for Agent interaction. + +CONCURRENT ARCHITECTURE: +----------------------- +Unlike existing event_loop_cycle() that processes events sequentially where tool +execution blocks conversation, this module coordinates concurrent tasks through +asyncio queues and background task management. +""" + +import asyncio +import json +import logging +import traceback +import uuid +from typing import Any, Dict + +from strands.tools._validator import validate_and_prepare_tools +from strands.types.content import Message +from strands.types.tools import ToolResult, ToolUse + +from ..models.bidirectional_model import BidirectionalModelSession +from ..utils.debug import log_event, log_flow + +logger = logging.getLogger(__name__) + +# Session constants +TOOL_QUEUE_TIMEOUT = 0.5 +SUPERVISION_INTERVAL = 0.1 + + +class BidirectionalConnection: + """Session wrapper for bidirectional communication. + + Manages concurrent tasks for model events, tool execution, and audio processing + while providing simple interface for Agent interaction. + """ + + def __init__(self, model_session: BidirectionalModelSession, agent): + """Initialize session with model session and agent reference. + + Args: + model_session: Provider-specific bidirectional model session. + agent: BidirectionalAgent instance for tool registry access. + """ + self.model_session = model_session + self.agent = agent + self.active = True + + # Background processing coordination + self.background_tasks = [] + self.tool_queue = asyncio.Queue() + self.audio_output_queue = asyncio.Queue() + + # Task management for cleanup + self.pending_tool_tasks: Dict[str, asyncio.Task] = {} + + # Interruption handling (model-agnostic) + self.interrupted = False + +async def start_bidirectional_connection(agent) -> BidirectionalConnection: + """Initialize bidirectional session with concurrent background tasks. + + Creates provider-specific session and starts concurrent tasks for model events, + tool execution, and session lifecycle management. + + Args: + agent: BidirectionalAgent instance. + + Returns: + BidirectionalConnection: Active session with background tasks running. 
+ """ + log_flow("session_start", "initializing model session") + + # Create provider-specific session + model_session = await agent.model.create_bidirectional_connection( + system_prompt=agent.system_prompt, + tools=agent.tool_registry.get_all_tool_specs(), + messages=agent.messages + ) + + # Create session wrapper for background processing + session = BidirectionalConnection(model_session=model_session, agent=agent) + + # Start concurrent background processors IMMEDIATELY after session creation + # This is critical - Nova Sonic needs response processing during initialization + log_flow("background_tasks", "starting processors") + session.background_tasks = [ + asyncio.create_task(_process_model_events(session)), # Handle model responses + asyncio.create_task(_process_tool_execution(session)) # Execute tools concurrently + ] + + # Start main coordination cycle + session.main_cycle_task = asyncio.create_task( + bidirectional_event_loop_cycle(session) + ) + + # Give background tasks a moment to start + await asyncio.sleep(0.1) + log_event("session_ready", tasks=len(session.background_tasks)) + + return session + + +async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: + """End session and cleanup resources including background tasks. + + Args: + session: BidirectionalConnection to cleanup. + """ + if not session.active: + return + + log_flow("session_cleanup", "starting") + session.active = False + + # Cancel pending tool tasks + for _, task in session.pending_tool_tasks.items(): + if not task.done(): + task.cancel() + + # Cancel background tasks + for task in session.background_tasks: + if not task.done(): + task.cancel() + + # Cancel main cycle task + if hasattr(session, 'main_cycle_task') and not session.main_cycle_task.done(): + session.main_cycle_task.cancel() + + # Wait for tasks to complete + all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) + if hasattr(session, 'main_cycle_task'): + all_tasks.append(session.main_cycle_task) + + if all_tasks: + await asyncio.gather(*all_tasks, return_exceptions=True) + + # Close model session + await session.model_session.close() + log_event("session_closed") + + +async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: + """Main bidirectional event loop coordinator - runs continuously during session. + + Coordinates background tasks and manages session lifecycle. Unlike the + sequential event_loop_cycle() that processes events one by one, this coordinator + manages concurrent tasks and session state. + + Args: + session: BidirectionalConnection to coordinate. 
+ """ + while session.active: + try: + # Check if background processors are still running + if all(task.done() for task in session.background_tasks): + log_event("session_end", reason="all_processors_completed") + session.active = False + break + + # Check for failed background tasks + for i, task in enumerate(session.background_tasks): + if task.done() and not task.cancelled(): + exception = task.exception() + if exception: + log_event("session_error", processor=i, error=str(exception)) + session.active = False + raise exception + + # Brief pause before next supervision check + await asyncio.sleep(SUPERVISION_INTERVAL) + + except asyncio.CancelledError: + break + except Exception as e: + log_event("event_loop_error", error=str(e)) + session.active = False + raise + + +async def _handle_interruption(session: BidirectionalConnection) -> None: + """Handle interruption detection with comprehensive task cancellation. + + Sets interruption flag, cancels pending tool tasks, and aggressively + clears audio output queue following Nova Sonic example patterns. + + Args: + session: BidirectionalConnection to handle interruption for. + """ + log_event("interruption_detected") + session.interrupted = True + + # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) + cancelled_tools = 0 + for task_id, task in list(session.pending_tool_tasks.items()): + if not task.done(): + task.cancel() + cancelled_tools += 1 + log_event("tool_task_cancelled", task_id=task_id) + + if cancelled_tools > 0: + log_event("tool_tasks_cancelled", count=cancelled_tools) + + # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) + cleared_count = 0 + while True: + try: + session.audio_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break + + # Also clear the agent's audio output queue if it exists + if hasattr(session.agent, '_output_queue'): + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) + + if cleared_count > 0: + log_event("session_audio_queue_cleared", count=cleared_count) + + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) + await asyncio.sleep(0.05) + + # Reset interruption flag after clearing (automatic recovery) + session.interrupted = False + log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + + +async def _process_model_events(session: BidirectionalConnection) -> None: + """Process model events using existing Strands event types. + + This background task handles all model responses and converts + them to existing StreamEvent format for integration with Strands. + + Args: + session: BidirectionalConnection containing model session. 
+ """ + log_flow("model_events", "processor started") + try: + async for provider_event in session.model_session.receive_events(): + if not session.active: + break + + # Convert provider events to Strands format + strands_event = _convert_to_strands_event(provider_event) + + # Handle interruption detection (multiple patterns) + if strands_event.get("interruptionDetected"): + log_event("interruption_forwarded") + await _handle_interruption(session) + # Forward interruption event to agent for application-level handling + await session.agent._output_queue.put(strands_event) + continue + + # Check for text-based interruption (Nova Sonic pattern) + if strands_event.get("textOutput"): + text_content = strands_event["textOutput"].get("content", "") + if '{ "interrupted" : true }' in text_content: + log_event("text_interruption_detected") + await _handle_interruption(session) + # Still forward the text event + await session.agent._output_queue.put(strands_event) + continue + + # Queue tool requests for concurrent execution + if strands_event.get("toolUse"): + log_event("tool_queued", name=strands_event["toolUse"].get("name")) + await session.tool_queue.put(strands_event["toolUse"]) + continue + + # Send output events to Agent for receive() method + if strands_event.get("audioOutput") or strands_event.get("textOutput"): + await session.agent._output_queue.put(strands_event) + + # Update Agent conversation history using existing patterns + if strands_event.get("messageStop"): + log_event("message_added_to_history") + session.agent.messages.append(strands_event["messageStop"]["message"]) + + except Exception as e: + log_event("model_events_error", error=str(e)) + traceback.print_exc() + finally: + log_flow("model_events", "processor stopped") + + +async def _process_tool_execution(session: BidirectionalConnection) -> None: + """Execute tools concurrently using existing Strands infrastructure with barge-in support. + + This background task manages tool execution without blocking + model event processing or user interaction. Includes proper + task cleanup and cancellation handling. + + Args: + session: BidirectionalConnection containing tool queue. 
+ """ + log_flow("tool_execution", "processor started") + while session.active: + try: + tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) + log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) + + if not session.active: + break + + task_id = str(uuid.uuid4()) + task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) + session.pending_tool_tasks[task_id] = task + + # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) + def cleanup_task(completed_task): + try: + # Remove from pending tasks + if task_id in session.pending_tool_tasks: + del session.pending_tool_tasks[task_id] + + # Log completion status + if completed_task.cancelled(): + log_event("tool_task_cleanup_cancelled", task_id=task_id) + elif completed_task.exception(): + log_event("tool_task_cleanup_error", task_id=task_id, + error=str(completed_task.exception())) + else: + log_event("tool_task_cleanup_success", task_id=task_id) + except Exception as e: + log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) + + task.add_done_callback(cleanup_task) + + except asyncio.TimeoutError: + if not session.active: + break + # 🔥 PERIODIC CLEANUP OF COMPLETED TASKS + completed_tasks = [ + task_id for task_id, task in session.pending_tool_tasks.items() + if task.done() + ] + for task_id in completed_tasks: + if task_id in session.pending_tool_tasks: + del session.pending_tool_tasks[task_id] + + if completed_tasks: + log_event("periodic_task_cleanup", count=len(completed_tasks)) + + continue + except Exception as e: + log_event("tool_execution_error", error=str(e)) + if not session.active: + break + + log_flow("tool_execution", "processor stopped") + + +def _convert_to_strands_event(provider_event: Dict) -> Dict: + """Pass-through for events already normalized by provider sessions. + + Providers convert their raw events to standard format before reaching here. + This just validates and passes through the normalized events. + + Args: + provider_event: Already normalized event from provider session. + + Returns: + Dict: The same event, validated and passed through. + """ + # Basic validation - ensure we have a dict + if not isinstance(provider_event, dict): + return {} + + # Pass through - conversion already done by provider session + return provider_event + + +async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: + """Execute tool using existing Strands infrastructure with barge-in support. + + Model-agnostic tool execution that uses existing Strands tool system, + handles interruption during execution, and delegates result formatting + to provider-specific session. + + Args: + session: BidirectionalConnection for context. + tool_use: Tool use event to execute. 
+ """ + tool_name = tool_use.get('name') + tool_id = tool_use.get('toolUseId') + + try: + # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) + if session.interrupted or not session.active: + log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) + return + + # Create message structure for existing tool system + tool_message: Message = { + "role": "assistant", + "content": [{"toolUse": tool_use}] + } + + tool_uses: list[ToolUse] = [] + tool_results: list[ToolResult] = [] + invalid_tool_use_ids: list[str] = [] + + # Validate using existing Strands validation + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) + + # Filter valid tool uses + valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] + + if not valid_tool_uses: + log_event("tool_validation_failed", name=tool_name, id=tool_id) + return + + # Execute tools directly (simpler approach for bidirectional) + for tool_use in valid_tool_uses: + # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION + if session.interrupted or not session.active: + log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) + return + + tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) + + if tool_func: + try: + actual_func = _extract_callable_function(tool_func) + + # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK + # For async tools, we could wrap with asyncio.wait_for with cancellation + # For sync tools, we execute directly but check interruption after + result = actual_func(**tool_use.get("input", {})) + + # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION + if session.interrupted or not session.active: + log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) + return + + tool_result = _create_success_result(tool_use["toolUseId"], result) + tool_results.append(tool_result) + + except asyncio.CancelledError: + # Tool was cancelled due to interruption + log_event("tool_execution_cancelled", name=tool_name, id=tool_id) + return + except Exception as e: + # 🔥 CHECK FOR INTERRUPTION EVEN ON ERROR + if session.interrupted or not session.active: + log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) + return + + log_event("tool_execution_failed", name=tool_name, error=str(e)) + tool_result = _create_error_result(tool_use["toolUseId"], str(e)) + tool_results.append(tool_result) + else: + log_event("tool_not_found", name=tool_name) + + # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS + if session.interrupted or not session.active: + log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) + return + + # Send results through provider-specific session + for result in tool_results: + await session.model_session.send_tool_result( + tool_use.get("toolUseId"), + result + ) + + log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) + + except asyncio.CancelledError: + # Task was cancelled due to interruption - this is expected behavior + log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) + raise # Re-raise to properly handle cancellation + except Exception as e: + log_event("tool_execution_error", name=tool_use.get('name'), error=str(e)) + + # Only send error if not interrupted + if not session.interrupted and session.active: + try: + await session.model_session.send_tool_error( + tool_use.get("toolUseId"), + str(e) + ) + except Exception as send_error: + log_event("tool_error_send_failed", error=str(send_error)) + + 
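+
+# Hedged sketch (not wired into the execution path above): one possible way to support
+# coroutine-based tools with barge-in, as hinted at by the asyncio.wait_for comment in
+# _execute_tool_with_strands().
+async def _await_async_tool_result(session: BidirectionalConnection, coro, poll_interval: float = 0.05):
+    """Await an async tool while honoring interruption (illustrative only).
+
+    Polls the session's interruption flag while the tool task runs and cancels the task
+    as soon as the session is interrupted or deactivated, so a long-running async tool
+    cannot outlive a barge-in.
+    """
+    task = asyncio.create_task(coro)
+    while not task.done():
+        if session.interrupted or not session.active:
+            task.cancel()
+            await asyncio.gather(task, return_exceptions=True)
+            raise asyncio.CancelledError("tool cancelled due to interruption")
+        await asyncio.sleep(poll_interval)
+    return task.result()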
+def _extract_callable_function(tool_func): + """Extract the callable function from different tool object types.""" + if hasattr(tool_func, '_tool_func'): + return tool_func._tool_func + elif hasattr(tool_func, 'func'): + return tool_func.func + elif callable(tool_func): + return tool_func + else: + raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") + + +def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: + """Create a successful tool result.""" + return { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": json.dumps(result)}] + } + + +def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: + """Create an error tool result.""" + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error}"}] + } \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py new file mode 100644 index 000000000..b2b10a5f2 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming models package.""" +# Models package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py new file mode 100644 index 000000000..32727105d --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -0,0 +1,115 @@ +"""Bidirectional model interface for real-time streaming conversations. + +INTERFACE PURPOSE: +----------------- +Declares bidirectional capabilities separate from existing Model hierarchy to maintain +clean separation of concerns. Models choose to implement this interface explicitly +for bidirectional streaming support. + +PROVIDER ABSTRACTION: +-------------------- +Abstracts incompatible initialization patterns: Nova Sonic's event-driven sequences, +Google's WebSocket setup, OpenAI's dual protocol support. Normalizes different tool +calling approaches and handles provider-specific session management with varying +time limits and connection patterns. + +SESSION-BASED APPROACH: +---------------------- +Unlike existing Model interface's stateless request-response pattern where each +stream() call processes complete messages independently, BidirectionalModel introduces +session-based approach where create_bidirectional_connection() establishes persistent +connections supporting real-time bidirectional communication during active generation. +""" + +import abc +import logging +from typing import Any, AsyncIterable, Dict, List, Optional + +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.bidirectional_streaming import AudioInputEvent + +logger = logging.getLogger(__name__) + +class BidirectionalModelSession(abc.ABC): + """Model-specific session interface for bidirectional communication.""" + + @abc.abstractmethod + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive events from model in provider-agnostic format. + + Normalizes different provider event formats so the event loop + can process all providers uniformly. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to model during session. 
+ + Manages complex audio encoding and provider-specific event sequences + while presenting simple AudioInputEvent interface to Agent. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content processed concurrently with ongoing generation. + + Enables natural interruption and follow-up questions without session restart. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_interrupt(self) -> None: + """Send interruption signal to immediately stop generation. + + Critical for responsive conversational experiences where users + can naturally interrupt mid-response. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool execution result to model in provider-specific format. + + Each provider handles result formatting according to their protocol: + - Nova Sonic: toolResult events with JSON content + - Google Live API: toolResponse with specific structure + - OpenAI Realtime: function call responses with call_id correlation + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool execution error to model in provider-specific format.""" + raise NotImplementedError + + @abc.abstractmethod + async def close(self) -> None: + """Close session and cleanup resources with graceful termination.""" + raise NotImplementedError + + +class BidirectionalModel(abc.ABC): + """Interface for models that support bidirectional streaming. + + Separate from Model to maintain clean separation of concerns. + Models choose to implement this interface explicitly. + """ + + @abc.abstractmethod + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create bidirectional session with model-specific implementation. + + Abstracts complex provider-specific initialization while presenting + uniform interface to Agent. + """ + raise NotImplementedError + diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py new file mode 100644 index 000000000..ba71cd4d3 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -0,0 +1,777 @@ +"""Nova Sonic bidirectional model provider for real-time streaming conversations. + +PROVIDER PURPOSE: +---------------- +Implements BidirectionalModel and BidirectionalModelSession interfaces for Nova Sonic, +handling the complex three-tier event management and structured event cleanup sequences +required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. + +NOVA SONIC SPECIFICS: +-------------------- +- Requires hierarchical event sequences: sessionStart → promptStart → content streaming +- Uses hex-encoded base64 audio format that needs conversion to raw bytes +- Implements toolUse/toolResult with content containers and identifier tracking +- Manages 8-minute session limits with proper cleanup sequences +- Handles stopReason: "INTERRUPTED" events for interruption detection + +INTEGRATION APPROACH: +-------------------- +Adapts existing Nova Sonic sample patterns to work with Strands bidirectional +infrastructure while maintaining provider-specific protocol requirements. 
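+
+EVENT SEQUENCE SKETCH:
+---------------------
+Abridged view of the event flow implemented by the template methods in this module:
+
+    sessionStart   { inferenceConfiguration }
+    promptStart    { audio/text output configuration, optional toolConfiguration }
+    contentStart / textInput / contentEnd       (system prompt)
+    contentStart / audioInput ... / contentEnd  (user audio, one content block per utterance)
+    ... audioOutput / textOutput / toolUse events stream back ...
+    promptEnd
+    sessionEnd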
+""" + +import asyncio +import base64 +import json +import logging +import time +import traceback +import uuid +from typing import Any, AsyncIterable, Dict, List, Optional + +from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput +from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme +from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk +from smithy_aws_core.credentials_resolvers.environment import EnvironmentCredentialsResolver + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + InterruptionDetectedEvent, + TextOutputEvent, +) +from ..utils.debug import log_event, log_flow, time_it_async +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# Nova Sonic configuration constants +NOVA_INFERENCE_CONFIG = { + "maxTokens": 1024, + "topP": 0.9, + "temperature": 0.7 +} + +NOVA_AUDIO_INPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64" +} + +NOVA_AUDIO_OUTPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 24000, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": "matthew", + "encoding": "base64", + "audioType": "SPEECH" +} + +NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} +NOVA_TOOL_CONFIG = {"mediaType": "application/json"} + +# Timing constants +SILENCE_THRESHOLD = 2.0 +EVENT_DELAY = 0.1 +RESPONSE_TIMEOUT = 1.0 + + +class NovaSonicSession(BidirectionalModelSession): + """Nova Sonic session handling protocol-specific details.""" + + def __init__(self, stream, config: Dict[str, Any]): + """Initialize Nova Sonic session. + + Args: + stream: Nova Sonic bidirectional stream. + config: Model configuration. + """ + self.stream = stream + self.config = config + self.prompt_name = str(uuid.uuid4()) + self._active = True + + # Nova Sonic requires unique content names + self.audio_content_name = str(uuid.uuid4()) + self.text_content_name = str(uuid.uuid4()) + + # Audio session state + self.audio_session_active = False + self.last_audio_time = None + self.silence_threshold = SILENCE_THRESHOLD + self.silence_task = None + + # Validate stream + if not stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + + logger.debug("Nova Sonic session initialized with prompt: %s", self.prompt_name) + + async def initialize( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None + ) -> None: + """Initialize Nova Sonic session with required protocol sequence.""" + try: + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." 
+ + init_events = self._build_initialization_events(system_prompt, tools or [], messages) + + log_flow("nova_init", f"sending {len(init_events)} events") + await self._send_initialization_events(init_events) + + log_event("nova_session_initialized") + self._response_task = asyncio.create_task(self._process_responses()) + + except Exception as e: + logger.error("Error during Nova Sonic initialization: %s", e) + raise + + def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec], + messages: Optional[Messages]) -> List[str]: + """Build the sequence of initialization events.""" + events = [ + self._get_session_start_event(), + self._get_prompt_start_event(tools) + ] + + events.extend(self._get_system_prompt_events(system_prompt)) + + # Message history would be processed here if needed in the future + # Currently not implemented as it's not used in the existing test cases + + return events + + async def _send_initialization_events(self, events: List[str]) -> None: + """Send initialization events with required delays.""" + for i, event in enumerate(events): + await time_it_async(f"send_init_event_{i+1}", lambda: self._send_nova_event(event)) + await asyncio.sleep(EVENT_DELAY) + + async def _process_responses(self) -> None: + """Process Nova Sonic responses continuously.""" + log_flow("nova_responses", "processor started") + + try: + while self._active: + try: + output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) + result = await output[1].receive() + + if result.value and result.value.bytes_: + await self._handle_response_data(result.value.bytes_.decode('utf-8')) + + except asyncio.TimeoutError: + await asyncio.sleep(0.1) + continue + except Exception as e: + log_event("nova_response_error", error=str(e)) + await asyncio.sleep(0.1) + continue + + except Exception as e: + log_event("nova_fatal_error", error=str(e)) + finally: + log_flow("nova_responses", "processor stopped") + + async def _handle_response_data(self, response_data: str) -> None: + """Handle decoded response data from Nova Sonic.""" + try: + json_data = json.loads(response_data) + + if 'event' in json_data: + nova_event = json_data['event'] + self._log_event_type(nova_event) + + if not hasattr(self, '_event_queue'): + self._event_queue = asyncio.Queue() + + await self._event_queue.put(nova_event) + except json.JSONDecodeError as e: + log_event("nova_json_error", error=str(e)) + + def _log_event_type(self, nova_event: Dict[str, Any]) -> None: + """Log specific Nova Sonic event types for debugging.""" + if 'usageEvent' in nova_event: + log_event("nova_usage", usage=nova_event['usageEvent']) + elif 'textOutput' in nova_event: + log_event("nova_text_output") + elif 'toolUse' in nova_event: + tool_use = nova_event['toolUse'] + log_event("nova_tool_use", name=tool_use['toolName'], id=tool_use['toolUseId']) + elif 'audioOutput' in nova_event: + audio_content = nova_event['audioOutput']['content'] + audio_bytes = base64.b64decode(audio_content) + log_event("nova_audio_output", bytes=len(audio_bytes)) + + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive Nova Sonic events and convert to provider-agnostic format.""" + if not self.stream: + logger.error("Stream is None") + return + + log_flow("nova_events", "starting event stream") + + # Emit session start event to Strands event system + session_start: BidirectionalConnectionStartEvent = { + "sessionId": self.prompt_name, + "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} + } + 
yield { + "BidirectionalConnectionStart": session_start + } + + # Initialize event queue if not already done + if not hasattr(self, '_event_queue'): + self._event_queue = asyncio.Queue() + + try: + while self._active: + try: + # Get events from the queue populated by _process_responses + nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + + # Convert to provider-agnostic format + provider_event = self._convert_nova_event(nova_event) + if provider_event: + yield provider_event + + except asyncio.TimeoutError: + # No events in queue - continue waiting + continue + + except Exception as e: + logger.error("Error receiving Nova Sonic event: %s", e) + logger.error(traceback.format_exc()) + finally: + # Emit session end event when exiting + session_end: BidirectionalConnectionEndEvent = { + "sessionId": self.prompt_name, + "reason": "session_complete", + "metadata": {"provider": "nova_sonic"} + } + yield { + "BidirectionalConnectionEnd": session_end + } + + async def start_audio_session(self) -> None: + """Start audio input session (call once before sending audio chunks).""" + if self.audio_session_active: + return + + log_event("nova_audio_session_start") + + audio_content_start = json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG + } + } + }) + + await self._send_nova_event(audio_content_start) + self.audio_session_active = True + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio using Nova Sonic protocol-specific format.""" + if not self._active: + return + + # Start audio session if not already active + if not self.audio_session_active: + await self.start_audio_session() + + # Update last audio time and cancel any pending silence task + self.last_audio_time = time.time() + if self.silence_task and not self.silence_task.done(): + self.silence_task.cancel() + + # Convert audio to Nova Sonic base64 format + nova_audio_data = base64.b64encode(audio_input["audioData"]).decode('utf-8') + + # Send audio input event + audio_event = json.dumps({ + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "content": nova_audio_data + } + } + }) + + await self._send_nova_event(audio_event) + + # Start silence detection task + self.silence_task = asyncio.create_task(self._check_silence()) + + async def _check_silence(self): + """Check for silence and automatically end audio session.""" + try: + await asyncio.sleep(self.silence_threshold) + if self.audio_session_active and self.last_audio_time: + elapsed = time.time() - self.last_audio_time + if elapsed >= self.silence_threshold: + log_event("nova_silence_detected", elapsed=elapsed) + await self.end_audio_input() + except asyncio.CancelledError: + pass + + async def end_audio_input(self) -> None: + """End current audio input session to trigger Nova Sonic processing.""" + if not self.audio_session_active: + return + + log_event("nova_audio_session_end") + + audio_content_end = json.dumps({ + "event": { + "contentEnd": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name + } + } + }) + + await self._send_nova_event(audio_content_end) + self.audio_session_active = False + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content using Nova Sonic format.""" + if not self._active: + return + + content_name = 
str(uuid.uuid4()) + events = [ + self._get_text_content_start_event(content_name), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name) + ] + + for event in events: + await self._send_nova_event(event) + + async def send_interrupt(self) -> None: + """Send interruption signal to Nova Sonic.""" + if not self._active: + return + + # Nova Sonic handles interruption through special input events + interrupt_event = { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "stopReason": "INTERRUPTED" + } + } + } + await self._send_nova_event(interrupt_event) + + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool result using Nova Sonic toolResult format.""" + if not self._active: + return + + log_event("nova_tool_result_send", id=tool_use_id) + content_name = str(uuid.uuid4()) + events = [ + self._get_tool_content_start_event(content_name, tool_use_id), + self._get_tool_result_event(content_name, result), + self._get_content_end_event(content_name) + ] + + for i, event in enumerate(events): + await time_it_async(f"send_tool_event_{i+1}", lambda: self._send_nova_event(event)) + + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool error using Nova Sonic format.""" + log_event("nova_tool_error_send", id=tool_use_id, error=error) + error_result = {"error": error} + await self.send_tool_result(tool_use_id, error_result) + + async def close(self) -> None: + """Close Nova Sonic session with proper cleanup sequence.""" + if not self._active: + return + + log_flow("nova_cleanup", "starting session close") + self._active = False + + # Cancel response processing task if running + if hasattr(self, '_response_task') and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + # End audio session if active + if self.audio_session_active: + await self.end_audio_input() + + # Send cleanup events + cleanup_events = [ + self._get_prompt_end_event(), + self._get_session_end_event() + ] + + for event in cleanup_events: + try: + await self._send_nova_event(event) + except Exception as e: + logger.warning("Error during Nova Sonic cleanup: %s", e) + + # Close stream + try: + await self.stream.input_stream.close() + except Exception as e: + logger.warning("Error closing Nova Sonic stream: %s", e) + + except Exception as e: + log_event("nova_cleanup_error", error=str(e)) + finally: + log_event("nova_session_closed") + + def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Convert Nova Sonic events to provider-agnostic format.""" + # Handle audio output + if "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + + audio_output: AudioOutputEvent = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": "base64" + } + + return { + "audioOutput": audio_output + } + + # Handle text output + elif "textOutput" in nova_event: + text_content = nova_event["textOutput"]["content"] + # Use stored role from contentStart event, fallback to event role + role = getattr(self, '_current_role', nova_event["textOutput"].get("role", "assistant")) + + # Check for Nova Sonic interruption pattern (matches working sample) + if '{ "interrupted" : true }' in text_content: + log_event("nova_interruption_in_text") + 
interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + return { + "interruptionDetected": interruption + } + + # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag + if role == "USER": + print(f"User: {text_content}") + elif role == "ASSISTANT": + print(f"Assistant: {text_content}") + + text_output: TextOutputEvent = { + "text": text_content, + "role": role.lower() + } + + return { + "textOutput": text_output + } + + # Handle tool use + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + + tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]) + } + + return { + "toolUse": tool_use_event + } + + # Handle interruption + elif nova_event.get("stopReason") == "INTERRUPTED": + log_event("nova_interruption_stop_reason") + + interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + + return { + "interruptionDetected": interruption + } + + # Handle usage events (ignore) + elif "usageEvent" in nova_event: + return None + + # Handle content start events (track role) + elif "contentStart" in nova_event: + role = nova_event["contentStart"].get("role", "unknown") + # Store role for subsequent text output events + self._current_role = role + return None + + # Handle other events + else: + return None + + # Nova Sonic event template methods + def _get_session_start_event(self) -> str: + """Generate Nova Sonic session start event.""" + return json.dumps({ + "event": { + "sessionStart": { + "inferenceConfiguration": NOVA_INFERENCE_CONFIG + } + } + }) + + def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + prompt_start_event = { + "event": { + "promptStart": { + "promptName": self.prompt_name, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: + """Build tool configuration from tool specs.""" + tool_config = [] + for tool in tools: + input_schema = ({"json": json.dumps(tool['inputSchema']['json'])} + if 'json' in tool['inputSchema'] + else {"json": json.dumps(tool['inputSchema'])}) + + tool_config.append({ + "toolSpec": { + "name": tool["name"], + "description": tool["description"], + "inputSchema": input_schema + } + }) + return tool_config + + def _get_system_prompt_events(self, system_prompt: str) -> List[str]: + """Generate system prompt events.""" + content_name = str(uuid.uuid4()) + return [ + self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_input_event(content_name, system_prompt), + self._get_content_end_event(content_name) + ] + + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: + """Generate text content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG + } + } + }) + + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: + """Generate tool 
content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG + } + } + } + }) + + def _get_text_input_event(self, content_name: str, text: str) -> str: + """Generate text input event.""" + return json.dumps({ + "event": { + "textInput": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": text + } + } + }) + + def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: + """Generate tool result event.""" + return json.dumps({ + "event": { + "toolResult": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": json.dumps(result) + } + } + }) + + def _get_content_end_event(self, content_name: str) -> str: + """Generate content end event.""" + return json.dumps({ + "event": { + "contentEnd": { + "promptName": self.prompt_name, + "contentName": content_name + } + } + }) + + def _get_prompt_end_event(self) -> str: + """Generate prompt end event.""" + return json.dumps({ + "event": { + "promptEnd": { + "promptName": self.prompt_name + } + } + }) + + def _get_session_end_event(self) -> str: + """Generate session end event.""" + return json.dumps({ + "event": { + "sessionEnd": {} + } + }) + + async def _send_nova_event(self, event: str) -> None: + """Send event JSON string to Nova Sonic stream.""" + try: + + # Event is already a JSON string + bytes_data = event.encode('utf-8') + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=bytes_data) + ) + await self.stream.input_stream.send(chunk) + logger.debug("Successfully sent Nova Sonic event") + + except Exception as e: + logger.error("Error sending Nova Sonic event: %s", e) + logger.error("Event was: %s", event) + raise + + +class NovaSonicBidirectionalModel(BidirectionalModel): + """Nova Sonic model implementing bidirectional capabilities.""" + + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): + """Initialize Nova Sonic bidirectional model. + + Args: + model_id: Nova Sonic model identifier. + region: AWS region. + **config: Additional configuration. 
+ """ + self.model_id = model_id + self.region = region + self.config = config + self._client = None + + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) + + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create Nova Sonic bidirectional session.""" + log_flow("nova_session_create", "starting") + + # Initialize client if needed + if not self._client: + await time_it_async("initialize_client", lambda: self._initialize_client()) + + # Start Nova Sonic bidirectional stream + try: + stream = await time_it_async("invoke_model_with_bidirectional_stream", + lambda: self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + )) + + # Create and initialize session + session = NovaSonicSession(stream, self.config) + await time_it_async("initialize_session", + lambda: session.initialize(system_prompt, tools, messages)) + + log_event("nova_session_created") + return session + except Exception as e: + log_event("nova_session_create_error", error=str(e)) + logger.error("Failed to create Nova Sonic session: %s", e) + raise + + async def _initialize_client(self) -> None: + """Initialize Nova Sonic client.""" + try: + + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + http_auth_scheme_resolver=HTTPAuthSchemeResolver(), + http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()} + ) + + self._client = BedrockRuntimeClient(config=config) + logger.debug("Nova Sonic client initialized") + + except ImportError as e: + logger.error("Nova Sonic dependencies not available: %s", e) + raise + except Exception as e: + logger.error("Error initializing Nova Sonic client: %s", e) + raise + diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py new file mode 100644 index 000000000..f35fd4462 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -0,0 +1,203 @@ +"""Simple bidirectional streaming test with enhanced interruption support.""" + +import asyncio +import time +import pyaudio + +from src.strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from src.strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands_tools import calculator + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for( + context["audio_out"].get(), + timeout=0.1 + ) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for 
interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + # Handle audio output + if "audioOutput" in event: + if not context.get("interrupted", False): + context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + + # Handle interruption events + elif "interruptionDetected" in event: + context["interrupted"] = True + elif "interrupted" in event: + context["interrupted"] = True + + # Handle text output with interruption detection + elif "textOutput" in event: + text_content = event["textOutput"].get("content", "") + role = event["textOutput"].get("role", "unknown") + + # Check for text-based interruption patterns + if '{ "interrupted" : true }' in text_content: + context["interrupted"] = True + elif "interrupted" in text_content.lower(): + context["interrupted"] = True + + # Log text output + if role.upper() == "USER": + print(f"User: {text_content}") + elif role.upper() == "ASSISTANT": + print(f"Assistant: {text_content}") + + except asyncio.CancelledError: + pass + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 16000 + } + await agent.send_audio(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) # Restored to working timing + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for bidirectional streaming test.""" + print("Starting bidirectional streaming test...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + + # Initialize model and agent + model = NovaSonicBidirectionalModel(region="us-east-1") + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant." + ) + + await agent.start_conversation() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "session": agent._session, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + } + + print("Speak into microphone. 
Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.end_conversation() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py new file mode 100644 index 000000000..f6441d2f0 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -0,0 +1,3 @@ +"""Bidirectional streaming types package.""" +# Types package + diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py new file mode 100644 index 000000000..2b1480e62 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -0,0 +1,167 @@ +"""Bidirectional streaming types for real-time audio/text conversations. + +PROBLEM ADDRESSED: +----------------- +Strands currently uses a request-response architecture without bidirectional streaming +support. Users cannot interrupt ongoing responses, provide additional context during +processing, or engage in real-time conversations. Each interaction requires a complete +request-response cycle. + +ARCHITECTURAL TRANSFORMATION: +---------------------------- +Current Limitations: Strands' unidirectional architecture follows sequential +request-response cycles that prevent real-time interaction. This represents a +pull-based architecture where the model receives the request, processes it, and +sends a response back. + +Bidirectional Solution: Uses persistent session-based connections with continuous +input and output flow. This implements a push-based architecture where the model +sends updates to the client as soon as response becomes available, without explicit +client requests. + +KEY CHARACTERISTICS: +------------------- +- Persistent Sessions: Connections remain open for extended periods (Nova Sonic: 8 minutes, + Google Live API: 15 minutes, OpenAI Realtime: 30 minutes) maintaining conversation context +- Bidirectional Communication: Users can send input while models generate responses +- Interruption Handling: Users can interrupt ongoing model responses in real-time without + terminating the session +- Tool Execution: Tools execute concurrently within the conversation flow rather than + requiring requests rebuilding + +PROVIDER NORMALIZATION: +---------------------- +Must normalize incompatible audio formats: Nova Sonic's hex-encoded base64, Google's +LINEAR16 PCM, OpenAI's Base64-encoded PCM16. Requires unified interruption event types +to handle Nova Sonic's stopReason = INTERRUPTED events, Google's VAD cancellation, and +OpenAI's conversation.item.truncate. + +This module extends existing StreamEvent types while maintaining backward compatibility +with existing Strands streaming patterns. 
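For example, a normalized chunk of microphone audio is just a plain dict matching the
AudioInputEvent type defined below (a sketch; pcm_bytes stands in for real 16-bit PCM samples):

    pcm_bytes = b"\x00\x00" * 1600  # roughly 100 ms of silence at 16 kHz mono, 16-bit

    audio_event = {
        "audioData": pcm_bytes,   # raw bytes, never base64 or hex encoded
        "format": "pcm",
        "sampleRate": 16000,
        "channels": 1,
    }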
+""" + +from typing import Any, Dict, Literal, Optional + +from strands.types.content import Role +from strands.types.streaming import StreamEvent +from typing_extensions import TypedDict + +# Audio format constants +SUPPORTED_AUDIO_FORMATS = ['pcm', 'wav', 'opus', 'mp3'] +SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] +SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo +DEFAULT_SAMPLE_RATE = 16000 +DEFAULT_CHANNELS = 1 + +class AudioOutputEvent(TypedDict): + """Audio output event from the model. + + Standardizes audio output across different providers using raw bytes + instead of provider-specific encodings (base64, hex, etc.). + + Attributes: + audioData: Raw audio bytes (not base64 or hex encoded). + format: Audio format from SUPPORTED_AUDIO_FORMATS. + sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. + channels: Channel count from SUPPORTED_CHANNELS. + encoding: Original provider encoding for debugging purposes. + """ + + audioData: bytes + format: Literal['pcm', 'wav', 'opus', 'mp3'] + sampleRate: Literal[16000, 24000, 48000] + channels: Literal[1, 2] + encoding: Optional[str] + + +class AudioInputEvent(TypedDict): + """Audio input event for sending audio to the model. + + Used when sending audio data through send_audio() method. + + Attributes: + audioData: Raw audio bytes to send to model. + format: Audio format from SUPPORTED_AUDIO_FORMATS. + sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. + channels: Channel count from SUPPORTED_CHANNELS. + """ + + audioData: bytes + format: Literal['pcm', 'wav', 'opus', 'mp3'] + sampleRate: Literal[16000, 24000, 48000] + channels: Literal[1, 2] + + +class TextOutputEvent(TypedDict): + """Text output event from the model during bidirectional streaming. + + Attributes: + text: The text content from the model. + role: The role of the message sender. + """ + + text: str + role: Role + + +class InterruptionDetectedEvent(TypedDict): + """Interruption detection event. + + Signals when user interruption is detected during model generation. + + Attributes: + reason: Interruption reason from predefined set. + """ + + reason: Literal['user_input', 'vad_detected', 'manual'] + + +class BidirectionalConnectionStartEvent(TypedDict, total=False): + """Session start event for bidirectional streaming. + + Attributes: + sessionId: Unique session identifier. + metadata: Provider-specific session metadata. + """ + + sessionId: Optional[str] + metadata: Optional[Dict[str, Any]] + + +class BidirectionalConnectionEndEvent(TypedDict): + """Session end event for bidirectional streaming. + + Attributes: + reason: Reason for session end from predefined set. + sessionId: Unique session identifier. + metadata: Provider-specific session metadata. + """ + + reason: Literal['user_request', 'timeout', 'error'] + sessionId: Optional[str] + metadata: Optional[Dict[str, Any]] + + +class BidirectionalStreamEvent(StreamEvent, total=False): + """Bidirectional stream event extending existing StreamEvent. + + Inherits all existing StreamEvent fields (contentBlockDelta, toolUse, + messageStart, etc.) while adding bidirectional-specific events. + Maintains full backward compatibility with existing Strands streaming. + + Attributes: + audioOutput: Audio output from the model. + audioInput: Audio input sent to the model. + textOutput: Text output from the model. + interruptionDetected: User interruption detection. + BidirectionalConnectionStart: Session start event. + BidirectionalConnectionEnd: Session end event. 
+ """ + + audioOutput: AudioOutputEvent + audioInput: AudioInputEvent + textOutput: TextOutputEvent + interruptionDetected: InterruptionDetectedEvent + BidirectionalConnectionStart: BidirectionalConnectionStartEvent + BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py new file mode 100644 index 000000000..1e88b6ead --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/utils/debug.py @@ -0,0 +1,45 @@ +"""Debug utilities for Strands bidirectional streaming. + +Provides consistent debug logging across all bidirectional streaming components +with configurable output control matching the Nova Sonic tool use example. +""" + +import datetime +import inspect +import time + +# Debug logging system matching successful tool use example +DEBUG = False # Disable debug logging for clean output like tool use example + +def debug_print(message): + """Print debug message with timestamp and function name.""" + if DEBUG: + function_name = inspect.stack()[1].function + if function_name == 'time_it_async': + function_name = inspect.stack()[2].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + print(f"{timestamp} {function_name} {message}") + +def log_event(event_type, **context): + """Log important events with structured context.""" + if DEBUG: + function_name = inspect.stack()[1].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" + print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") + +def log_flow(step, details=""): + """Log important flow steps without excessive detail.""" + if DEBUG: + function_name = inspect.stack()[1].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + print(f"{timestamp} {function_name} FLOW: {step} {details}") + +async def time_it_async(label, method_to_run): + """Time asynchronous method execution.""" + start_time = time.perf_counter() + result = await method_to_run() + end_time = time.perf_counter() + debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") + return result + From 9165a2074eaa3a35f1e7df01ddfdd04c7d6e523a Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 30 Sep 2025 10:41:16 -0400 Subject: [PATCH 02/15] Updated doc strings, updated method from send_text() and send_audio() to send(), Updated imports --- pyproject.toml | 2 +- .../bidirectional_streaming/agent/agent.py | 105 +++++++------ .../event_loop/bidirectional_event_loop.py | 62 ++++---- .../models/bidirectional_model.py | 75 +++++----- .../models/novasonic.py | 141 +++++++++--------- .../tests/test_bidirectional_streaming.py | 26 +++- .../types/bidirectional_streaming.py | 86 ++++------- 7 files changed, 234 insertions(+), 263 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d4f7e6eee..dd01ebde3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git 
a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index cfc005576..023997551 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -1,30 +1,22 @@ """Bidirectional Agent for real-time streaming conversations. -AGENT PURPOSE: -------------- -Provides type-safe constructor and session management for real-time audio/text -interaction. Serves as the bidirectional equivalent to invoke_async() → stream_async() -but establishes sessions that continue indefinitely with concurrent task management. +Provides real-time audio and text interaction through persistent streaming sessions. +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive +continuous responses including audio output. -ARCHITECTURAL APPROACH: ----------------------- -While invoke_async() creates single request-response cycles that terminate after -stop_reason: "end_turn" with sequential tool processing, start_conversation() -establishes persistent sessions with concurrent processing of model events, tool -execution, and user input without session termination. - -DESIGN CHOICE: -------------- -Uses dedicated BidirectionalAgent class (Option 1 from design document) for: -- Type safety with no conditional behavior based on model type -- Separation of concerns - solely focused on bidirectional streaming -- Future proofing - allows changes without implications to existing Agent class +Key capabilities: +- Persistent conversation sessions with concurrent processing +- Real-time audio input/output streaming +- Mid-conversation interruption and tool execution +- Event-driven communication with model providers """ import asyncio import logging -from typing import AsyncIterable, List, Optional +from typing import AsyncIterable, List, Optional, Union +from strands.tools.executors import ConcurrentToolExecutor from strands.tools.registry import ToolRegistry from strands.types.content import Messages @@ -39,8 +31,8 @@ class BidirectionalAgent: """Agent for bidirectional streaming conversations. - Provides type-safe constructor and session management for real-time - audio/text interaction with concurrent processing capabilities. + Enables real-time audio and text interaction with AI models through persistent + sessions. Supports concurrent tool execution and interruption handling. """ def __init__( @@ -69,60 +61,63 @@ def __init__( self.tool_registry.initialize_tools() # Initialize tool executor for concurrent execution - from strands.tools.executors import ConcurrentToolExecutor self.tool_executor = ConcurrentToolExecutor() # Session management self._session = None self._output_queue = asyncio.Queue() - async def start_conversation(self) -> None: - """Initialize persistent bidirectional session for real-time interaction. + async def start(self) -> None: + """Start a persistent bidirectional conversation session. - Creates provider-specific session and starts concurrent background tasks - for model events, tool execution, and session lifecycle management. + Initializes the streaming session and starts background tasks for processing + model events, tool execution, and session management. Raises: ValueError: If conversation already active. ConnectionError: If session creation fails. 
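        A minimal lifecycle sketch (illustrative; assumes a configured
        NovaSonicBidirectionalModel, the calculator tool from strands_tools, and an
        enclosing async function):

            model = NovaSonicBidirectionalModel(region="us-east-1")
            agent = BidirectionalAgent(model=model, tools=[calculator])

            await agent.start()
            await agent.send("What is 2 + 2?")   # unified send() accepts text or audio
            async for event in agent.receive():
                if "textOutput" in event:
                    print(event["textOutput"].get("text", ""))
                    break
            await agent.end()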
""" if self._session and self._session.active: - raise ValueError("Conversation already active. Call end_conversation() first.") + raise ValueError("Conversation already active. Call end() first.") log_flow("conversation_start", "initializing session") self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - async def send_text(self, text: str) -> None: - """Send text input during active session without interrupting model generation. + async def send(self, input_data: Union[str, AudioInputEvent]) -> None: + """Send input to the model (text or audio). - Args: - text: Text message to send to the model. - - Raises: - ValueError: If no active session. - """ - self._validate_active_session() - log_event("text_sent", length=len(text)) - await self._session.model_session.send_text_content(text) - - async def send_audio(self, audio_input: AudioInputEvent) -> None: - """Send audio input during active session for real-time speech interaction. + Unified method for sending both text and audio input to the model during + an active conversation session. Args: - audio_input: AudioInputEvent containing audio data and configuration. + input_data: Either a string for text input or AudioInputEvent for audio input. Raises: - ValueError: If no active session. + ValueError: If no active session or invalid input type. """ self._validate_active_session() - await self._session.model_session.send_audio_content(audio_input) + + if isinstance(input_data, str): + # Handle text input + log_event("text_sent", length=len(input_data)) + await self._session.model_session.send_text_content(input_data) + elif isinstance(input_data, dict) and "audioData" in input_data: + # Handle audio input (AudioInputEvent) + await self._session.model_session.send_audio_content(input_data) + else: + raise ValueError( + "Input must be either a string (text) or AudioInputEvent " + "(dict with audioData, format, sampleRate, channels)" + ) + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive output events from the model including audio, text. + """Receive events from the model including audio, text, and tool calls. - Provides access to model output events processed by background tasks. - Events include audio output, text responses, tool calls, and session updates. + Yields model output events processed by background tasks including audio output, + text responses, tool calls, and session updates. Yields: BidirectionalStreamEvent: Events from the model session. @@ -135,10 +130,10 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: continue async def interrupt(self) -> None: - """Interrupt current model generation and switch to listening mode. + """Interrupt the current model generation and clear audio buffers. - Sends interruption signal to immediately stop generation and clear - pending audio output for responsive conversational experience. + Sends interruption signal to stop generation immediately and clears + pending audio output for responsive conversation flow. Raises: ValueError: If no active session. @@ -146,11 +141,11 @@ async def interrupt(self) -> None: self._validate_active_session() await self._session.model_session.send_interrupt() - async def end_conversation(self) -> None: - """End session and cleanup resources including background tasks. + async def end(self) -> None: + """End the conversation session and cleanup all resources. 
- Performs graceful session termination with proper resource cleanup - including background task cancellation and connection closure. + Terminates the streaming session, cancels background tasks, and + closes the connection to the model provider. """ if self._session: await stop_bidirectional_connection(self._session) @@ -163,5 +158,5 @@ def _validate_active_session(self) -> None: ValueError: If no active session. """ if not self._session or not self._session.active: - raise ValueError("No active conversation. Call start_conversation() first.") + raise ValueError("No active conversation. Call start() first.") diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 2164115d8..3884750d5 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -1,16 +1,14 @@ """Bidirectional session management for concurrent streaming conversations. -SESSION PURPOSE: ---------------- -Session wrapper for bidirectional communication that manages concurrent tasks for -model events, tool execution, and audio processing while providing simple interface -for Agent interaction. +Manages bidirectional communication sessions with concurrent processing of model events, +tool execution, and audio processing. Provides coordination between background tasks +while maintaining a simple interface for agent interaction. -CONCURRENT ARCHITECTURE: ------------------------ -Unlike existing event_loop_cycle() that processes events sequentially where tool -execution blocks conversation, this module coordinates concurrent tasks through -asyncio queues and background task management. +Features: +- Concurrent task management for model events and tool execution +- Interruption handling with audio buffer clearing +- Tool execution with cancellation support +- Session lifecycle management """ import asyncio @@ -35,10 +33,10 @@ class BidirectionalConnection: - """Session wrapper for bidirectional communication. + """Session wrapper for bidirectional communication with concurrent task management. - Manages concurrent tasks for model events, tool execution, and audio processing - while providing simple interface for Agent interaction. + Coordinates background tasks for model event processing, tool execution, and audio + handling while providing a simple interface for agent interactions. """ def __init__(self, model_session: BidirectionalModelSession, agent): @@ -66,8 +64,8 @@ def __init__(self, model_session: BidirectionalModelSession, agent): async def start_bidirectional_connection(agent) -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. - Creates provider-specific session and starts concurrent tasks for model events, - tool execution, and session lifecycle management. + Creates a model-specific session and starts background tasks for processing + model events, executing tools, and managing the session lifecycle. Args: agent: BidirectionalAgent instance. @@ -147,11 +145,10 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: - """Main bidirectional event loop coordinator - runs continuously during session. + """Main event loop coordinator that runs continuously during the session. 
- Coordinates background tasks and manages session lifecycle. Unlike the - sequential event_loop_cycle() that processes events one by one, this coordinator - manages concurrent tasks and session state. + Monitors background tasks, manages session state, and handles session lifecycle. + Provides supervision for concurrent model event processing and tool execution. Args: session: BidirectionalConnection to coordinate. @@ -185,10 +182,10 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No async def _handle_interruption(session: BidirectionalConnection) -> None: - """Handle interruption detection with comprehensive task cancellation. + """Handle interruption detection with task cancellation and audio buffer clearing. - Sets interruption flag, cancels pending tool tasks, and aggressively - clears audio output queue following Nova Sonic example patterns. + Cancels pending tool tasks and clears audio output queues to ensure responsive + interruption handling during conversations. Args: session: BidirectionalConnection to handle interruption for. @@ -251,10 +248,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async def _process_model_events(session: BidirectionalConnection) -> None: - """Process model events using existing Strands event types. + """Process model events and convert them to Strands format. - This background task handles all model responses and converts - them to existing StreamEvent format for integration with Strands. + Background task that handles all model responses, converts provider-specific + events to standardized formats, and manages interruption detection. Args: session: BidirectionalConnection containing model session. @@ -309,11 +306,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async def _process_tool_execution(session: BidirectionalConnection) -> None: - """Execute tools concurrently using existing Strands infrastructure with barge-in support. + """Execute tools concurrently with interruption support. - This background task manages tool execution without blocking - model event processing or user interaction. Includes proper - task cleanup and cancellation handling. + Background task that manages tool execution without blocking model event + processing or user interaction. Includes proper task cleanup and cancellation + handling for interruptions. Args: session: BidirectionalConnection containing tool queue. @@ -396,11 +393,10 @@ def _convert_to_strands_event(provider_event: Dict) -> Dict: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: - """Execute tool using existing Strands infrastructure with barge-in support. + """Execute tool using Strands infrastructure with interruption support. - Model-agnostic tool execution that uses existing Strands tool system, - handles interruption during execution, and delegates result formatting - to provider-specific session. + Executes tools using the existing Strands tool system, handles interruption + during execution, and sends results back to the model provider. Args: session: BidirectionalConnection for context. 
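The interruption handling described in the hunks above reduces to cancelling in-flight tool
tasks and draining queued audio; a condensed sketch of that behavior, using the
BidirectionalConnection attributes defined earlier (not the exact implementation):

    import asyncio

    async def handle_interruption(session) -> None:
        """Stop playback promptly: cancel pending tools and drop queued audio."""
        session.interrupted = True
        for task in session.pending_tool_tasks.values():
            task.cancel()
        while not session.audio_output_queue.empty():
            try:
                session.audio_output_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
        session.interrupted = False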
diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 32727105d..81e5cd9d6 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -1,24 +1,14 @@ """Bidirectional model interface for real-time streaming conversations. -INTERFACE PURPOSE: ------------------ -Declares bidirectional capabilities separate from existing Model hierarchy to maintain -clean separation of concerns. Models choose to implement this interface explicitly -for bidirectional streaming support. +Defines the interface for models that support bidirectional streaming capabilities. +Provides abstractions for different model providers with connection-based communication +patterns that support real-time audio and text interaction. -PROVIDER ABSTRACTION: --------------------- -Abstracts incompatible initialization patterns: Nova Sonic's event-driven sequences, -Google's WebSocket setup, OpenAI's dual protocol support. Normalizes different tool -calling approaches and handles provider-specific session management with varying -time limits and connection patterns. - -SESSION-BASED APPROACH: ----------------------- -Unlike existing Model interface's stateless request-response pattern where each -stream() call processes complete messages independently, BidirectionalModel introduces -session-based approach where create_bidirectional_connection() establishes persistent -connections supporting real-time bidirectional communication during active generation. +Features: +- connection-based persistent connections +- Real-time bidirectional communication +- Provider-agnostic event normalization +- Tool execution integration """ import abc @@ -32,51 +22,54 @@ logger = logging.getLogger(__name__) class BidirectionalModelSession(abc.ABC): - """Model-specific session interface for bidirectional communication.""" + """Abstract interface for model-specific bidirectional communication connections. + + Defines the contract for managing persistent streaming connections with individual + model providers, handling audio/text input, receiving events, and managing + tool execution results. + """ @abc.abstractmethod async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: - """Receive events from model in provider-agnostic format. + """Receive events from the model in standardized format. - Normalizes different provider event formats so the event loop - can process all providers uniformly. + Converts provider-specific events to a common format that can be + processed uniformly by the event loop. """ raise NotImplementedError @abc.abstractmethod async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to model during session. + """Send audio content to the model during an active connection. - Manages complex audio encoding and provider-specific event sequences - while presenting simple AudioInputEvent interface to Agent. + Handles audio encoding and provider-specific formatting while presenting + a simple AudioInputEvent interface. """ raise NotImplementedError @abc.abstractmethod async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content processed concurrently with ongoing generation. + """Send text content to the model during ongoing generation. - Enables natural interruption and follow-up questions without session restart. 
+ Allows natural interruption and follow-up questions without requiring + connection restart. """ raise NotImplementedError @abc.abstractmethod async def send_interrupt(self) -> None: - """Send interruption signal to immediately stop generation. + """Send interruption signal to stop generation immediately. - Critical for responsive conversational experiences where users - can naturally interrupt mid-response. + Enables responsive conversational experiences where users can + naturally interrupt during model responses. """ raise NotImplementedError @abc.abstractmethod async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: - """Send tool execution result to model in provider-specific format. + """Send tool execution result to the model. - Each provider handles result formatting according to their protocol: - - Nova Sonic: toolResult events with JSON content - - Google Live API: toolResponse with specific structure - - OpenAI Realtime: function call responses with call_id correlation + Formats and sends tool results according to the provider's specific protocol. """ raise NotImplementedError @@ -87,15 +80,15 @@ async def send_tool_error(self, tool_use_id: str, error: str) -> None: @abc.abstractmethod async def close(self) -> None: - """Close session and cleanup resources with graceful termination.""" + """Close the connection and cleanup resources.""" raise NotImplementedError class BidirectionalModel(abc.ABC): """Interface for models that support bidirectional streaming. - Separate from Model to maintain clean separation of concerns. - Models choose to implement this interface explicitly. + Defines the contract for creating persistent streaming connections that support + real-time audio and text communication with AI models. """ @abc.abstractmethod @@ -106,10 +99,10 @@ async def create_bidirectional_connection( messages: Optional[Messages] = None, **kwargs ) -> BidirectionalModelSession: - """Create bidirectional session with model-specific implementation. + """Create a bidirectional connection with the model. - Abstracts complex provider-specific initialization while presenting - uniform interface to Agent. + Establishes a persistent connection for real-time communication while + abstracting provider-specific initialization requirements. """ raise NotImplementedError diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index ba71cd4d3..4332181b5 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,23 +1,15 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. -PROVIDER PURPOSE: ----------------- -Implements BidirectionalModel and BidirectionalModelSession interfaces for Nova Sonic, -handling the complex three-tier event management and structured event cleanup sequences -required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. +Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the +complex event sequencing and audio processing required by Nova Sonic's +InvokeModelWithBidirectionalStream protocol. 
-NOVA SONIC SPECIFICS: --------------------- -- Requires hierarchical event sequences: sessionStart → promptStart → content streaming -- Uses hex-encoded base64 audio format that needs conversion to raw bytes -- Implements toolUse/toolResult with content containers and identifier tracking -- Manages 8-minute session limits with proper cleanup sequences -- Handles stopReason: "INTERRUPTED" events for interruption detection - -INTEGRATION APPROACH: --------------------- -Adapts existing Nova Sonic sample patterns to work with Strands bidirectional -infrastructure while maintaining provider-specific protocol requirements. +Nova Sonic specifics: +- Hierarchical event sequences: connectionStart → promptStart → content streaming +- Base64-encoded audio format with hex encoding +- Tool execution with content containers and identifier tracking +- 8-minute connection limits with proper cleanup sequences +- Interruption detection through stopReason events """ import asyncio @@ -85,10 +77,15 @@ class NovaSonicSession(BidirectionalModelSession): - """Nova Sonic session handling protocol-specific details.""" + """Nova Sonic connection implementation handling the provider's specific protocol. + + Manages Nova Sonic's complex event sequencing, audio format conversion, and + tool execution patterns while providing the standard BidirectionalModelSession + interface. + """ def __init__(self, stream, config: Dict[str, Any]): - """Initialize Nova Sonic session. + """Initialize Nova Sonic connection. Args: stream: Nova Sonic bidirectional stream. @@ -103,8 +100,8 @@ def __init__(self, stream, config: Dict[str, Any]): self.audio_content_name = str(uuid.uuid4()) self.text_content_name = str(uuid.uuid4()) - # Audio session state - self.audio_session_active = False + # Audio connection state + self.audio_connection_active = False self.last_audio_time = None self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None @@ -114,7 +111,7 @@ def __init__(self, stream, config: Dict[str, Any]): logger.error("Stream is None") raise ValueError("Stream cannot be None") - logger.debug("Nova Sonic session initialized with prompt: %s", self.prompt_name) + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) async def initialize( self, @@ -122,7 +119,7 @@ async def initialize( tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None ) -> None: - """Initialize Nova Sonic session with required protocol sequence.""" + """Initialize Nova Sonic connection with required protocol sequence.""" try: system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." 
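For orientation, the per-chunk audio payload sent to Nova Sonic has roughly the shape below.
This is a sketch: the audioInput field names are an assumption modeled on the other event
templates in this file, and json/base64 are assumed to be imported at module level.

    audio_event = json.dumps({
        "event": {
            "audioInput": {
                "promptName": self.prompt_name,
                "contentName": self.audio_content_name,
                "content": base64.b64encode(audio_input["audioData"]).decode("utf-8"),
            }
        }
    })
    await self._send_nova_event(audio_event)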
@@ -131,7 +128,7 @@ async def initialize( log_flow("nova_init", f"sending {len(init_events)} events") await self._send_initialization_events(init_events) - log_event("nova_session_initialized") + log_event("nova_connection_initialized") self._response_task = asyncio.create_task(self._process_responses()) except Exception as e: @@ -142,7 +139,7 @@ def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec] messages: Optional[Messages]) -> List[str]: """Build the sequence of initialization events.""" events = [ - self._get_session_start_event(), + self._get_connection_start_event(), self._get_prompt_start_event(tools) ] @@ -223,13 +220,13 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: log_flow("nova_events", "starting event stream") - # Emit session start event to Strands event system - session_start: BidirectionalConnectionStartEvent = { - "sessionId": self.prompt_name, + # Emit connection start event to Strands event system + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.prompt_name, "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} } yield { - "BidirectionalConnectionStart": session_start + "BidirectionalConnectionStart": connection_start } # Initialize event queue if not already done @@ -255,22 +252,22 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) finally: - # Emit session end event when exiting - session_end: BidirectionalConnectionEndEvent = { - "sessionId": self.prompt_name, - "reason": "session_complete", + # Emit connection end event when exiting + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.prompt_name, + "reason": "connection_complete", "metadata": {"provider": "nova_sonic"} } yield { - "BidirectionalConnectionEnd": session_end + "BidirectionalConnectionEnd": connection_end } - async def start_audio_session(self) -> None: - """Start audio input session (call once before sending audio chunks).""" - if self.audio_session_active: + async def start_audio_connection(self) -> None: + """Start audio input connection (call once before sending audio chunks).""" + if self.audio_connection_active: return - log_event("nova_audio_session_start") + log_event("nova_audio_connection_start") audio_content_start = json.dumps({ "event": { @@ -286,16 +283,16 @@ async def start_audio_session(self) -> None: }) await self._send_nova_event(audio_content_start) - self.audio_session_active = True + self.audio_connection_active = True async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio using Nova Sonic protocol-specific format.""" if not self._active: return - # Start audio session if not already active - if not self.audio_session_active: - await self.start_audio_session() + # Start audio connection if not already active + if not self.audio_connection_active: + await self.start_audio_connection() # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() @@ -322,10 +319,10 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: self.silence_task = asyncio.create_task(self._check_silence()) async def _check_silence(self): - """Check for silence and automatically end audio session.""" + """Check for silence and automatically end audio connection.""" try: await asyncio.sleep(self.silence_threshold) - if self.audio_session_active and self.last_audio_time: + if 
self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: log_event("nova_silence_detected", elapsed=elapsed) @@ -334,11 +331,11 @@ async def _check_silence(self): pass async def end_audio_input(self) -> None: - """End current audio input session to trigger Nova Sonic processing.""" - if not self.audio_session_active: + """End current audio input connection to trigger Nova Sonic processing.""" + if not self.audio_connection_active: return - log_event("nova_audio_session_end") + log_event("nova_audio_connection_end") audio_content_end = json.dumps({ "event": { @@ -350,7 +347,7 @@ async def end_audio_input(self) -> None: }) await self._send_nova_event(audio_content_end) - self.audio_session_active = False + self.audio_connection_active = False async def send_text_content(self, text: str, **kwargs) -> None: """Send text content using Nova Sonic format.""" @@ -407,11 +404,11 @@ async def send_tool_error(self, tool_use_id: str, error: str) -> None: await self.send_tool_result(tool_use_id, error_result) async def close(self) -> None: - """Close Nova Sonic session with proper cleanup sequence.""" + """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - log_flow("nova_cleanup", "starting session close") + log_flow("nova_cleanup", "starting connection close") self._active = False # Cancel response processing task if running @@ -423,14 +420,14 @@ async def close(self) -> None: pass try: - # End audio session if active - if self.audio_session_active: + # End audio connection if active + if self.audio_connection_active: await self.end_audio_input() # Send cleanup events cleanup_events = [ self._get_prompt_end_event(), - self._get_session_end_event() + self._get_connection_end_event() ] for event in cleanup_events: @@ -448,7 +445,7 @@ async def close(self) -> None: except Exception as e: log_event("nova_cleanup_error", error=str(e)) finally: - log_event("nova_session_closed") + log_event("nova_connection_closed") def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Convert Nova Sonic events to provider-agnostic format.""" @@ -542,8 +539,8 @@ def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, return None # Nova Sonic event template methods - def _get_session_start_event(self) -> str: - """Generate Nova Sonic session start event.""" + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" return json.dumps({ "event": { "sessionStart": { @@ -676,11 +673,11 @@ def _get_prompt_end_event(self) -> str: } }) - def _get_session_end_event(self) -> str: - """Generate session end event.""" + def _get_connection_end_event(self) -> str: + """Generate connection end event.""" return json.dumps({ "event": { - "sessionEnd": {} + "connectionEnd": {} } }) @@ -703,7 +700,11 @@ async def _send_nova_event(self, event: str) -> None: class NovaSonicBidirectionalModel(BidirectionalModel): - """Nova Sonic model implementing bidirectional capabilities.""" + """Nova Sonic model implementation for bidirectional streaming. + + Provides access to Amazon's Nova Sonic model through the bidirectional + streaming interface, handling AWS authentication and connection management. + """ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): """Initialize Nova Sonic bidirectional model. 
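At the connection level, the normalized stream from receive_events() can be consumed
directly (a sketch; assumes an already-initialized connection and uses a local bytearray
named audio_buffer for collected output):

    audio_buffer = bytearray()
    async for event in connection.receive_events():
        if "BidirectionalConnectionStart" in event:
            print("connected:", event["BidirectionalConnectionStart"].get("connectionId"))
        elif "audioOutput" in event:
            audio_buffer.extend(event["audioOutput"]["audioData"])
        elif "BidirectionalConnectionEnd" in event:
            break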
@@ -727,8 +728,8 @@ async def create_bidirectional_connection( messages: Optional[Messages] = None, **kwargs ) -> BidirectionalModelSession: - """Create Nova Sonic bidirectional session.""" - log_flow("nova_session_create", "starting") + """Create Nova Sonic bidirectional connection.""" + log_flow("nova_connection_create", "starting") # Initialize client if needed if not self._client: @@ -741,16 +742,16 @@ async def create_bidirectional_connection( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) )) - # Create and initialize session - session = NovaSonicSession(stream, self.config) - await time_it_async("initialize_session", - lambda: session.initialize(system_prompt, tools, messages)) + # Create and initialize connection + connection = NovaSonicSession(stream, self.config) + await time_it_async("initialize_connection", + lambda: connection.initialize(system_prompt, tools, messages)) - log_event("nova_session_created") - return session + log_event("nova_connection_created") + return connection except Exception as e: - log_event("nova_session_create_error", error=str(e)) - logger.error("Failed to create Nova Sonic session: %s", e) + log_event("nova_connection_create_error", error=str(e)) + logger.error("Failed to create Nova Sonic connection: %s", e) raise async def _initialize_client(self) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index f35fd4462..d650aba9b 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -1,11 +1,20 @@ -"""Simple bidirectional streaming test with enhanced interruption support.""" +"""Test suite for bidirectional streaming with real-time audio interaction. + +Tests the complete bidirectional streaming system including audio input/output, +interruption handling, and concurrent tool execution using Nova Sonic. +""" import asyncio +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) import time import pyaudio -from src.strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from src.strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel from strands_tools import calculator @@ -139,9 +148,10 @@ async def send(agent, context): audio_event = { "audioData": audio_bytes, "format": "pcm", - "sampleRate": 16000 + "sampleRate": 16000, + "channels": 1 } - await agent.send_audio(audio_event) + await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing except asyncio.CancelledError: @@ -165,14 +175,14 @@ async def main(duration=180): system_prompt="You are a helpful assistant." 
) - await agent.start_conversation() + await agent.start() # Create shared context for all tasks context = { "active": True, "audio_in": asyncio.Queue(), "audio_out": asyncio.Queue(), - "session": agent._session, + "connection": agent._session, "duration": duration, "start_time": time.time(), "interrupted": False, @@ -196,7 +206,7 @@ async def main(duration=180): finally: print("Cleaning up...") context["active"] = False - await agent.end_conversation() + await agent.end() if __name__ == "__main__": diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 2b1480e62..fabe53ac9 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -1,43 +1,20 @@ """Bidirectional streaming types for real-time audio/text conversations. -PROBLEM ADDRESSED: ------------------ -Strands currently uses a request-response architecture without bidirectional streaming -support. Users cannot interrupt ongoing responses, provide additional context during -processing, or engage in real-time conversations. Each interaction requires a complete -request-response cycle. - -ARCHITECTURAL TRANSFORMATION: ----------------------------- -Current Limitations: Strands' unidirectional architecture follows sequential -request-response cycles that prevent real-time interaction. This represents a -pull-based architecture where the model receives the request, processes it, and -sends a response back. - -Bidirectional Solution: Uses persistent session-based connections with continuous -input and output flow. This implements a push-based architecture where the model -sends updates to the client as soon as response becomes available, without explicit -client requests. - -KEY CHARACTERISTICS: -------------------- -- Persistent Sessions: Connections remain open for extended periods (Nova Sonic: 8 minutes, - Google Live API: 15 minutes, OpenAI Realtime: 30 minutes) maintaining conversation context -- Bidirectional Communication: Users can send input while models generate responses -- Interruption Handling: Users can interrupt ongoing model responses in real-time without - terminating the session -- Tool Execution: Tools execute concurrently within the conversation flow rather than - requiring requests rebuilding - -PROVIDER NORMALIZATION: ----------------------- -Must normalize incompatible audio formats: Nova Sonic's hex-encoded base64, Google's -LINEAR16 PCM, OpenAI's Base64-encoded PCM16. Requires unified interruption event types -to handle Nova Sonic's stopReason = INTERRUPTED events, Google's VAD cancellation, and -OpenAI's conversation.item.truncate. - -This module extends existing StreamEvent types while maintaining backward compatibility -with existing Strands streaming patterns. +Type definitions for bidirectional streaming that extends Strands' existing streaming +capabilities with real-time audio and persistent connection support. 
+ +Key features: +- Audio input/output events with standardized formats +- Interruption detection and handling +- connection lifecycle management +- Provider-agnostic event types +- Backwards compatibility with existing StreamEvent types + +Audio format normalization: +- Supports PCM, WAV, Opus, and MP3 formats +- Standardizes sample rates (16kHz, 24kHz, 48kHz) +- Normalizes channel configurations (mono/stereo) +- Abstracts provider-specific encodings """ from typing import Any, Dict, Literal, Optional @@ -56,8 +33,8 @@ class AudioOutputEvent(TypedDict): """Audio output event from the model. - Standardizes audio output across different providers using raw bytes - instead of provider-specific encodings (base64, hex, etc.). + Provides standardized audio output format across different providers using + raw bytes instead of provider-specific encodings. Attributes: audioData: Raw audio bytes (not base64 or hex encoded). @@ -77,7 +54,7 @@ class AudioOutputEvent(TypedDict): class AudioInputEvent(TypedDict): """Audio input event for sending audio to the model. - Used when sending audio data through send_audio() method. + Used for sending audio data through the send() method. Attributes: audioData: Raw audio bytes to send to model. @@ -117,45 +94,44 @@ class InterruptionDetectedEvent(TypedDict): class BidirectionalConnectionStartEvent(TypedDict, total=False): - """Session start event for bidirectional streaming. + """connection start event for bidirectional streaming. Attributes: - sessionId: Unique session identifier. - metadata: Provider-specific session metadata. + connectionId: Unique connection identifier. + metadata: Provider-specific connection metadata. """ - sessionId: Optional[str] + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalConnectionEndEvent(TypedDict): - """Session end event for bidirectional streaming. + """connection end event for bidirectional streaming. Attributes: - reason: Reason for session end from predefined set. - sessionId: Unique session identifier. - metadata: Provider-specific session metadata. + reason: Reason for connection end from predefined set. + connectionId: Unique connection identifier. + metadata: Provider-specific connection metadata. """ reason: Literal['user_request', 'timeout', 'error'] - sessionId: Optional[str] + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. - Inherits all existing StreamEvent fields (contentBlockDelta, toolUse, - messageStart, etc.) while adding bidirectional-specific events. - Maintains full backward compatibility with existing Strands streaming. + Extends the existing StreamEvent type with bidirectional-specific events + while maintaining full backward compatibility with existing Strands streaming. Attributes: audioOutput: Audio output from the model. audioInput: Audio input sent to the model. textOutput: Text output from the model. interruptionDetected: User interruption detection. - BidirectionalConnectionStart: Session start event. - BidirectionalConnectionEnd: Session end event. + BidirectionalConnectionStart: connection start event. + BidirectionalConnectionEnd: connection end event. 
""" audioOutput: AudioOutputEvent From 15df9f9c06748c06376b596c7186e3712192e3cd Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 30 Sep 2025 10:45:29 -0400 Subject: [PATCH 03/15] Updated minimum python runtime dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index dd01ebde3..f45794d12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ bidirectional-streaming = [ "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", + "python>=3.12" ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ From 3a0e7d5c360107ea4a0c890bf1c9f18ee3f1c603 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 1 Oct 2025 23:54:05 -0400 Subject: [PATCH 04/15] fix imports --- .../bidirectional_streaming/__init__.py | 5 + .../bidirectional_streaming/agent/__init__.py | 7 +- .../bidirectional_streaming/agent/agent.py | 70 ++- .../event_loop/__init__.py | 17 +- .../event_loop/bidirectional_event_loop.py | 243 ++++---- .../models/__init__.py | 8 +- .../models/bidirectional_model.py | 38 +- .../models/novasonic.py | 546 ++++++++---------- .../tests/test_bidirectional_streaming.py | 65 +-- .../bidirectional_streaming/types/__init__.py | 32 +- .../types/bidirectional_streaming.py | 53 +- .../bidirectional_streaming/utils/__init__.py | 5 + .../bidirectional_streaming/utils/debug.py | 13 +- 13 files changed, 530 insertions(+), 572 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/utils/__init__.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py new file mode 100644 index 000000000..f6a3b41bf --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -0,0 +1,5 @@ +"""Bidirectional streaming package for real-time audio/text conversations.""" + +from .utils import log_event, log_flow, time_it_async + +__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py index bbd2c91f3..c490e001d 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -1,2 +1,5 @@ -"""Bidirectional streaming agent package.""" -# Agent package \ No newline at end of file +"""Bidirectional agent for real-time streaming conversations.""" + +from .agent import BidirectionalAgent + +__all__ = ["BidirectionalAgent"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 023997551..d7a5f17a3 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -1,13 +1,13 @@ """Bidirectional Agent for real-time streaming conversations. Provides real-time audio and text interaction through persistent streaming sessions. -Unlike traditional request-response patterns, this agent maintains long-running -conversations where users can interrupt, provide additional input, and receive +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive continuous responses including audio output. 
Key capabilities: - Persistent conversation sessions with concurrent processing -- Real-time audio input/output streaming +- Real-time audio input/output streaming - Mid-conversation interruption and tool execution - Event-driven communication with model providers """ @@ -16,10 +16,9 @@ import logging from typing import AsyncIterable, List, Optional, Union -from strands.tools.executors import ConcurrentToolExecutor -from strands.tools.registry import ToolRegistry -from strands.types.content import Messages - +from ....tools.executors import ConcurrentToolExecutor +from ....tools.registry import ToolRegistry +from ....types.content import Messages from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent @@ -30,20 +29,20 @@ class BidirectionalAgent: """Agent for bidirectional streaming conversations. - + Enables real-time audio and text interaction with AI models through persistent sessions. Supports concurrent tool execution and interruption handling. """ - + def __init__( self, model: BidirectionalModel, tools: Optional[List] = None, system_prompt: Optional[str] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, ): """Initialize bidirectional agent with required model and optional configuration. - + Args: model: BidirectionalModel instance supporting streaming sessions. tools: Optional list of tools available to the model. @@ -53,51 +52,51 @@ def __init__( self.model = model self.system_prompt = system_prompt self.messages = messages or [] - + # Initialize tool registry using existing Strands infrastructure self.tool_registry = ToolRegistry() if tools: self.tool_registry.process_tools(tools) self.tool_registry.initialize_tools() - + # Initialize tool executor for concurrent execution self.tool_executor = ConcurrentToolExecutor() - + # Session management self._session = None self._output_queue = asyncio.Queue() - + async def start(self) -> None: """Start a persistent bidirectional conversation session. - + Initializes the streaming session and starts background tasks for processing model events, tool execution, and session management. - + Raises: ValueError: If conversation already active. ConnectionError: If session creation fails. """ if self._session and self._session.active: raise ValueError("Conversation already active. Call end() first.") - + log_flow("conversation_start", "initializing session") self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - + async def send(self, input_data: Union[str, AudioInputEvent]) -> None: """Send input to the model (text or audio). - + Unified method for sending both text and audio input to the model during an active conversation session. - + Args: input_data: Either a string for text input or AudioInputEvent for audio input. - + Raises: ValueError: If no active session or invalid input type. """ self._validate_active_session() - + if isinstance(input_data, str): # Handle text input log_event("text_sent", length=len(input_data)) @@ -110,15 +109,13 @@ async def send(self, input_data: Union[str, AudioInputEvent]) -> None: "Input must be either a string (text) or AudioInputEvent " "(dict with audioData, format, sampleRate, channels)" ) - - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive events from the model including audio, text, and tool calls. 
- + Yields model output events processed by background tasks including audio output, text responses, tool calls, and session updates. - + Yields: BidirectionalStreamEvent: Events from the model session. """ @@ -128,35 +125,34 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: yield event except asyncio.TimeoutError: continue - + async def interrupt(self) -> None: """Interrupt the current model generation and clear audio buffers. - - Sends interruption signal to stop generation immediately and clears + + Sends interruption signal to stop generation immediately and clears pending audio output for responsive conversation flow. - + Raises: ValueError: If no active session. """ self._validate_active_session() await self._session.model_session.send_interrupt() - + async def end(self) -> None: """End the conversation session and cleanup all resources. - - Terminates the streaming session, cancels background tasks, and + + Terminates the streaming session, cancels background tasks, and closes the connection to the model provider. """ if self._session: await stop_bidirectional_connection(self._session) self._session = None - + def _validate_active_session(self) -> None: """Validate that an active session exists. - + Raises: ValueError: If no active session. """ if not self._session or not self._session.active: raise ValueError("No active conversation. Call start() first.") - diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py index 24080b703..af8c4e1e1 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -1,2 +1,15 @@ -"""Bidirectional streaming event loop package.""" -# Event Loop package \ No newline at end of file +"""Event loop management for bidirectional streaming.""" + +from .bidirectional_event_loop import ( + BidirectionalConnection, + bidirectional_event_loop_cycle, + start_bidirectional_connection, + stop_bidirectional_connection, +) + +__all__ = [ + "BidirectionalConnection", + "start_bidirectional_connection", + "stop_bidirectional_connection", + "bidirectional_event_loop_cycle", +] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 3884750d5..c90d118ff 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -18,10 +18,9 @@ import uuid from typing import Any, Dict -from strands.tools._validator import validate_and_prepare_tools -from strands.types.content import Message -from strands.types.tools import ToolResult, ToolUse - +from ....tools._validator import validate_and_prepare_tools +from ....types.content import Message +from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -34,14 +33,14 @@ class BidirectionalConnection: """Session wrapper for bidirectional communication with concurrent task management. - + Coordinates background tasks for model event processing, tool execution, and audio handling while providing a simple interface for agent interactions. 
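This wrapper is also what makes caller-initiated barge-in possible through the agent API; as a loose sketch (assuming `agent` is an active BidirectionalAgent and `mic_chunk` is a freshly captured 16 kHz mono PCM buffer):

async def barge_in(agent, mic_chunk: bytes) -> None:
    # Stop the current generation and flush queued audio output, then hand over the new input.
    await agent.interrupt()
    await agent.send({"audioData": mic_chunk, "format": "pcm", "sampleRate": 16000, "channels": 1})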
""" - + def __init__(self, model_session: BidirectionalModelSession, agent): """Initialize session with model session and agent reference. - + Args: model_session: Provider-specific bidirectional model session. agent: BidirectionalAgent instance for tool registry access. @@ -49,96 +48,93 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.model_session = model_session self.agent = agent self.active = True - + # Background processing coordination self.background_tasks = [] self.tool_queue = asyncio.Queue() self.audio_output_queue = asyncio.Queue() - + # Task management for cleanup self.pending_tool_tasks: Dict[str, asyncio.Task] = {} - + # Interruption handling (model-agnostic) self.interrupted = False -async def start_bidirectional_connection(agent) -> BidirectionalConnection: + +async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. - + Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. - + Args: agent: BidirectionalAgent instance. - + Returns: BidirectionalConnection: Active session with background tasks running. - """ + """ log_flow("session_start", "initializing model session") - + # Create provider-specific session model_session = await agent.model.create_bidirectional_connection( - system_prompt=agent.system_prompt, - tools=agent.tool_registry.get_all_tool_specs(), - messages=agent.messages + system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) - + # Create session wrapper for background processing session = BidirectionalConnection(model_session=model_session, agent=agent) - + # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization log_flow("background_tasks", "starting processors") session.background_tasks = [ - asyncio.create_task(_process_model_events(session)), # Handle model responses - asyncio.create_task(_process_tool_execution(session)) # Execute tools concurrently + asyncio.create_task(_process_model_events(session)), # Handle model responses + asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently ] - + # Start main coordination cycle - session.main_cycle_task = asyncio.create_task( - bidirectional_event_loop_cycle(session) - ) - + session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) + # Give background tasks a moment to start await asyncio.sleep(0.1) log_event("session_ready", tasks=len(session.background_tasks)) - + return session async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: """End session and cleanup resources including background tasks. - + Args: session: BidirectionalConnection to cleanup. 
""" if not session.active: return - + log_flow("session_cleanup", "starting") session.active = False - + # Cancel pending tool tasks for _, task in session.pending_tool_tasks.items(): if not task.done(): task.cancel() - + # Cancel background tasks for task in session.background_tasks: if not task.done(): task.cancel() - + # Cancel main cycle task - if hasattr(session, 'main_cycle_task') and not session.main_cycle_task.done(): + if hasattr(session, "main_cycle_task") and not session.main_cycle_task.done(): session.main_cycle_task.cancel() - + # Wait for tasks to complete all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) - if hasattr(session, 'main_cycle_task'): + if hasattr(session, "main_cycle_task"): all_tasks.append(session.main_cycle_task) - + if all_tasks: await asyncio.gather(*all_tasks, return_exceptions=True) - + # Close model session await session.model_session.close() log_event("session_closed") @@ -146,10 +142,10 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: """Main event loop coordinator that runs continuously during the session. - + Monitors background tasks, manages session state, and handles session lifecycle. Provides supervision for concurrent model event processing and tool execution. - + Args: session: BidirectionalConnection to coordinate. """ @@ -160,7 +156,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No log_event("session_end", reason="all_processors_completed") session.active = False break - + # Check for failed background tasks for i, task in enumerate(session.background_tasks): if task.done() and not task.cancelled(): @@ -169,10 +165,10 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No log_event("session_error", processor=i, error=str(exception)) session.active = False raise exception - + # Brief pause before next supervision check await asyncio.sleep(SUPERVISION_INTERVAL) - + except asyncio.CancelledError: break except Exception as e: @@ -183,16 +179,16 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No async def _handle_interruption(session: BidirectionalConnection) -> None: """Handle interruption detection with task cancellation and audio buffer clearing. - + Cancels pending tool tasks and clears audio output queues to ensure responsive interruption handling during conversations. - + Args: session: BidirectionalConnection to handle interruption for. 
""" log_event("interruption_detected") session.interrupted = True - + # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) cancelled_tools = 0 for task_id, task in list(session.pending_tool_tasks.items()): @@ -200,10 +196,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: task.cancel() cancelled_tools += 1 log_event("tool_task_cancelled", task_id=task_id) - + if cancelled_tools > 0: log_event("tool_tasks_cancelled", count=cancelled_tools) - + # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) cleared_count = 0 while True: @@ -212,9 +208,9 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: cleared_count += 1 except asyncio.QueueEmpty: break - + # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, '_output_queue'): + if hasattr(session.agent, "_output_queue"): audio_cleared = 0 # Create a temporary list to hold non-audio events temp_events = [] @@ -228,20 +224,20 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: temp_events.append(event) except asyncio.QueueEmpty: pass - + # Put back non-audio events for event in temp_events: session.agent._output_queue.put_nowait(event) - + if audio_cleared > 0: log_event("agent_audio_queue_cleared", count=audio_cleared) - + if cleared_count > 0: log_event("session_audio_queue_cleared", count=cleared_count) - + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) await asyncio.sleep(0.05) - + # Reset interruption flag after clearing (automatic recovery) session.interrupted = False log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) @@ -249,10 +245,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async def _process_model_events(session: BidirectionalConnection) -> None: """Process model events and convert them to Strands format. - + Background task that handles all model responses, converts provider-specific events to standardized formats, and manages interruption detection. - + Args: session: BidirectionalConnection containing model session. 
""" @@ -261,10 +257,10 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async for provider_event in session.model_session.receive_events(): if not session.active: break - + # Convert provider events to Strands format strands_event = _convert_to_strands_event(provider_event) - + # Handle interruption detection (multiple patterns) if strands_event.get("interruptionDetected"): log_event("interruption_forwarded") @@ -272,7 +268,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Forward interruption event to agent for application-level handling await session.agent._output_queue.put(strands_event) continue - + # Check for text-based interruption (Nova Sonic pattern) if strands_event.get("textOutput"): text_content = strands_event["textOutput"].get("content", "") @@ -282,22 +278,22 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Still forward the text event await session.agent._output_queue.put(strands_event) continue - + # Queue tool requests for concurrent execution if strands_event.get("toolUse"): log_event("tool_queued", name=strands_event["toolUse"].get("name")) await session.tool_queue.put(strands_event["toolUse"]) continue - + # Send output events to Agent for receive() method if strands_event.get("audioOutput") or strands_event.get("textOutput"): await session.agent._output_queue.put(strands_event) - + # Update Agent conversation history using existing patterns if strands_event.get("messageStop"): log_event("message_added_to_history") session.agent.messages.append(strands_event["messageStop"]["message"]) - + except Exception as e: log_event("model_events_error", error=str(e)) traceback.print_exc() @@ -307,11 +303,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. - + Background task that manages tool execution without blocking model event processing or user interaction. Includes proper task cleanup and cancellation handling for interruptions. - + Args: session: BidirectionalConnection containing tool queue. 
""" @@ -320,143 +316,136 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) - + if not session.active: break - + task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task - + # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) def cleanup_task(completed_task): try: # Remove from pending tasks if task_id in session.pending_tool_tasks: del session.pending_tool_tasks[task_id] - + # Log completion status if completed_task.cancelled(): log_event("tool_task_cleanup_cancelled", task_id=task_id) elif completed_task.exception(): - log_event("tool_task_cleanup_error", task_id=task_id, - error=str(completed_task.exception())) + log_event("tool_task_cleanup_error", task_id=task_id, error=str(completed_task.exception())) else: log_event("tool_task_cleanup_success", task_id=task_id) except Exception as e: log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) - + task.add_done_callback(cleanup_task) - + except asyncio.TimeoutError: if not session.active: break # 🔥 PERIODIC CLEANUP OF COMPLETED TASKS - completed_tasks = [ - task_id for task_id, task in session.pending_tool_tasks.items() - if task.done() - ] + completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] for task_id in completed_tasks: if task_id in session.pending_tool_tasks: del session.pending_tool_tasks[task_id] - + if completed_tasks: log_event("periodic_task_cleanup", count=len(completed_tasks)) - + continue except Exception as e: log_event("tool_execution_error", error=str(e)) if not session.active: break - + log_flow("tool_execution", "processor stopped") def _convert_to_strands_event(provider_event: Dict) -> Dict: """Pass-through for events already normalized by provider sessions. - + Providers convert their raw events to standard format before reaching here. This just validates and passes through the normalized events. - + Args: provider_event: Already normalized event from provider session. - + Returns: Dict: The same event, validated and passed through. """ # Basic validation - ensure we have a dict if not isinstance(provider_event, dict): return {} - + # Pass through - conversion already done by provider session return provider_event async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: """Execute tool using Strands infrastructure with interruption support. - + Executes tools using the existing Strands tool system, handles interruption during execution, and sends results back to the model provider. - + Args: session: BidirectionalConnection for context. tool_use: Tool use event to execute. 
""" - tool_name = tool_use.get('name') - tool_id = tool_use.get('toolUseId') - + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + try: # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) if session.interrupted or not session.active: log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) return - + # Create message structure for existing tool system - tool_message: Message = { - "role": "assistant", - "content": [{"toolUse": tool_use}] - } - + tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - + # Validate using existing Strands validation validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - + # Filter valid tool uses valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: log_event("tool_validation_failed", name=tool_name, id=tool_id) return - + # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION if session.interrupted or not session.active: log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) return - + tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) - + if tool_func: try: actual_func = _extract_callable_function(tool_func) - + # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK # For async tools, we could wrap with asyncio.wait_for with cancellation # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - + # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION if session.interrupted or not session.active: log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) return - + tool_result = _create_success_result(tool_use["toolUseId"], result) tool_results.append(tool_result) - + except asyncio.CancelledError: # Tool was cancelled due to interruption log_event("tool_execution_cancelled", name=tool_name, id=tool_id) @@ -466,50 +455,44 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: if session.interrupted or not session.active: log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) return - + log_event("tool_execution_failed", name=tool_name, error=str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: log_event("tool_not_found", name=tool_name) - + # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS if session.interrupted or not session.active: log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) return - + # Send results through provider-specific session for result in tool_results: - await session.model_session.send_tool_result( - tool_use.get("toolUseId"), - result - ) - + await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) + log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) - + except asyncio.CancelledError: # Task was cancelled due to interruption - this is expected behavior log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) raise # Re-raise to properly handle cancellation except Exception as e: - log_event("tool_execution_error", name=tool_use.get('name'), error=str(e)) - + log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) + # Only send error if 
not interrupted if not session.interrupted and session.active: try: - await session.model_session.send_tool_error( - tool_use.get("toolUseId"), - str(e) - ) + await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) except Exception as send_error: log_event("tool_error_send_failed", error=str(send_error)) def _extract_callable_function(tool_func): """Extract the callable function from different tool object types.""" - if hasattr(tool_func, '_tool_func'): + if hasattr(tool_func, "_tool_func"): return tool_func._tool_func - elif hasattr(tool_func, 'func'): + elif hasattr(tool_func, "func"): return tool_func.func elif callable(tool_func): return tool_func @@ -519,17 +502,9 @@ def _extract_callable_function(tool_func): def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: """Create a successful tool result.""" - return { - "toolUseId": tool_use_id, - "status": "success", - "content": [{"text": json.dumps(result)}] - } + return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: """Create an error tool result.""" - return { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error}"}] - } \ No newline at end of file + return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index b2b10a5f2..6cba974e0 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,2 +1,6 @@ -"""Bidirectional streaming models package.""" -# Models package \ No newline at end of file +"""Bidirectional model interfaces and implementations.""" + +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession + +__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 81e5cd9d6..cc803458b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -7,7 +7,7 @@ Features: - connection-based persistent connections - Real-time bidirectional communication -- Provider-agnostic event normalization +- Provider-agnostic event normalization - Tool execution integration """ @@ -21,63 +21,64 @@ logger = logging.getLogger(__name__) + class BidirectionalModelSession(abc.ABC): """Abstract interface for model-specific bidirectional communication connections. - + Defines the contract for managing persistent streaming connections with individual model providers, handling audio/text input, receiving events, and managing tool execution results. """ - + @abc.abstractmethod async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: """Receive events from the model in standardized format. - + Converts provider-specific events to a common format that can be processed uniformly by the event loop. 
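A provider implementation typically wraps its own wire format in an async generator that emits the shared event keys; as a loose sketch (the raw iterator and its "kind"/"value" fields are hypothetical, not part of any real SDK):

async def normalize_provider_events(raw_events):
    # raw_events: a hypothetical provider-specific async iterator of dicts
    async for raw in raw_events:
        if raw.get("kind") == "text":
            yield {"textOutput": {"text": raw["value"], "role": "assistant"}}
        elif raw.get("kind") == "audio":
            yield {
                "audioOutput": {
                    "audioData": raw["value"],
                    "format": "pcm",
                    "sampleRate": 24000,
                    "channels": 1,
                    "encoding": None,
                }
            }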
""" raise NotImplementedError - + @abc.abstractmethod async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio content to the model during an active connection. - + Handles audio encoding and provider-specific formatting while presenting a simple AudioInputEvent interface. """ raise NotImplementedError - + @abc.abstractmethod async def send_text_content(self, text: str, **kwargs) -> None: """Send text content to the model during ongoing generation. - + Allows natural interruption and follow-up questions without requiring connection restart. """ raise NotImplementedError - + @abc.abstractmethod async def send_interrupt(self) -> None: """Send interruption signal to stop generation immediately. - + Enables responsive conversational experiences where users can naturally interrupt during model responses. """ raise NotImplementedError - + @abc.abstractmethod async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: """Send tool execution result to the model. - + Formats and sends tool results according to the provider's specific protocol. """ raise NotImplementedError - + @abc.abstractmethod async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool execution error to model in provider-specific format.""" raise NotImplementedError - + @abc.abstractmethod async def close(self) -> None: """Close the connection and cleanup resources.""" @@ -86,23 +87,22 @@ async def close(self) -> None: class BidirectionalModel(abc.ABC): """Interface for models that support bidirectional streaming. - + Defines the contract for creating persistent streaming connections that support real-time audio and text communication with AI models. """ - + @abc.abstractmethod async def create_bidirectional_connection( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None, - **kwargs + **kwargs, ) -> BidirectionalModelSession: """Create a bidirectional connection with the model. - + Establishes a persistent connection for real-time communication while abstracting provider-specific initialization requirements. """ raise NotImplementedError - diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 4332181b5..0efd2413c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,7 +1,7 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the -complex event sequencing and audio processing required by Nova Sonic's +complex event sequencing and audio processing required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. 
Nova Sonic specifics: @@ -42,11 +42,7 @@ logger = logging.getLogger(__name__) # Nova Sonic configuration constants -NOVA_INFERENCE_CONFIG = { - "maxTokens": 1024, - "topP": 0.9, - "temperature": 0.7 -} +NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} NOVA_AUDIO_INPUT_CONFIG = { "mediaType": "audio/lpcm", @@ -54,7 +50,7 @@ "sampleSizeBits": 16, "channelCount": 1, "audioType": "SPEECH", - "encoding": "base64" + "encoding": "base64", } NOVA_AUDIO_OUTPUT_CONFIG = { @@ -64,7 +60,7 @@ "channelCount": 1, "voiceId": "matthew", "encoding": "base64", - "audioType": "SPEECH" + "audioType": "SPEECH", } NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} @@ -78,15 +74,15 @@ class NovaSonicSession(BidirectionalModelSession): """Nova Sonic connection implementation handling the provider's specific protocol. - + Manages Nova Sonic's complex event sequencing, audio format conversion, and tool execution patterns while providing the standard BidirectionalModelSession interface. """ - + def __init__(self, stream, config: Dict[str, Any]): """Initialize Nova Sonic connection. - + Args: stream: Nova Sonic bidirectional stream. config: Model configuration. @@ -95,80 +91,78 @@ def __init__(self, stream, config: Dict[str, Any]): self.config = config self.prompt_name = str(uuid.uuid4()) self._active = True - + # Nova Sonic requires unique content names self.audio_content_name = str(uuid.uuid4()) self.text_content_name = str(uuid.uuid4()) - + # Audio connection state self.audio_connection_active = False self.last_audio_time = None self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None - + # Validate stream if not stream: logger.error("Stream is None") raise ValueError("Stream cannot be None") - + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) - + async def initialize( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, ) -> None: """Initialize Nova Sonic connection with required protocol sequence.""" try: system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." 
- + init_events = self._build_initialization_events(system_prompt, tools or [], messages) - + log_flow("nova_init", f"sending {len(init_events)} events") await self._send_initialization_events(init_events) - + log_event("nova_connection_initialized") self._response_task = asyncio.create_task(self._process_responses()) - + except Exception as e: logger.error("Error during Nova Sonic initialization: %s", e) raise - - def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec], - messages: Optional[Messages]) -> List[str]: + + def _build_initialization_events( + self, system_prompt: str, tools: List[ToolSpec], messages: Optional[Messages] + ) -> List[str]: """Build the sequence of initialization events.""" - events = [ - self._get_connection_start_event(), - self._get_prompt_start_event(tools) - ] - + events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] + events.extend(self._get_system_prompt_events(system_prompt)) - + # Message history would be processed here if needed in the future # Currently not implemented as it's not used in the existing test cases - + return events - + async def _send_initialization_events(self, events: List[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i+1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_init_event_{i + 1}", lambda: self._send_nova_event(event)) await asyncio.sleep(EVENT_DELAY) - + async def _process_responses(self) -> None: """Process Nova Sonic responses continuously.""" log_flow("nova_responses", "processor started") - + try: while self._active: try: output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) result = await output[1].receive() - + if result.value and result.value.bytes_: - await self._handle_response_data(result.value.bytes_.decode('utf-8')) - + await self._handle_response_data(result.value.bytes_.decode("utf-8")) + except asyncio.TimeoutError: await asyncio.sleep(0.1) continue @@ -176,39 +170,39 @@ async def _process_responses(self) -> None: log_event("nova_response_error", error=str(e)) await asyncio.sleep(0.1) continue - + except Exception as e: log_event("nova_fatal_error", error=str(e)) finally: log_flow("nova_responses", "processor stopped") - + async def _handle_response_data(self, response_data: str) -> None: """Handle decoded response data from Nova Sonic.""" try: json_data = json.loads(response_data) - - if 'event' in json_data: - nova_event = json_data['event'] + + if "event" in json_data: + nova_event = json_data["event"] self._log_event_type(nova_event) - - if not hasattr(self, '_event_queue'): + + if not hasattr(self, "_event_queue"): self._event_queue = asyncio.Queue() - + await self._event_queue.put(nova_event) except json.JSONDecodeError as e: log_event("nova_json_error", error=str(e)) - + def _log_event_type(self, nova_event: Dict[str, Any]) -> None: """Log specific Nova Sonic event types for debugging.""" - if 'usageEvent' in nova_event: - log_event("nova_usage", usage=nova_event['usageEvent']) - elif 'textOutput' in nova_event: + if "usageEvent" in nova_event: + log_event("nova_usage", usage=nova_event["usageEvent"]) + elif "textOutput" in nova_event: log_event("nova_text_output") - elif 'toolUse' in nova_event: - tool_use = nova_event['toolUse'] - log_event("nova_tool_use", name=tool_use['toolName'], id=tool_use['toolUseId']) - elif 'audioOutput' in nova_event: - audio_content = nova_event['audioOutput']['content'] + 
elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + log_event("nova_tool_use", name=tool_use["toolName"], id=tool_use["toolUseId"]) + elif "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) log_event("nova_audio_output", bytes=len(audio_bytes)) @@ -217,37 +211,35 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: if not self.stream: logger.error("Stream is None") return - + log_flow("nova_events", "starting event stream") - + # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.prompt_name, - "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} - } - yield { - "BidirectionalConnectionStart": connection_start + "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")}, } - + yield {"BidirectionalConnectionStart": connection_start} + # Initialize event queue if not already done - if not hasattr(self, '_event_queue'): + if not hasattr(self, "_event_queue"): self._event_queue = asyncio.Queue() - + try: while self._active: try: # Get events from the queue populated by _process_responses nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - + # Convert to provider-agnostic format provider_event = self._convert_nova_event(nova_event) if provider_event: yield provider_event - + except asyncio.TimeoutError: # No events in queue - continue waiting continue - + except Exception as e: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) @@ -256,68 +248,70 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: connection_end: BidirectionalConnectionEndEvent = { "connectionId": self.prompt_name, "reason": "connection_complete", - "metadata": {"provider": "nova_sonic"} + "metadata": {"provider": "nova_sonic"}, } - yield { - "BidirectionalConnectionEnd": connection_end - } - + yield {"BidirectionalConnectionEnd": connection_end} + async def start_audio_connection(self) -> None: """Start audio input connection (call once before sending audio chunks).""" if self.audio_connection_active: return - + log_event("nova_audio_connection_start") - - audio_content_start = json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "type": "AUDIO", - "interactive": True, - "role": "USER", - "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG + + audio_content_start = json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG, + } } } - }) - + ) + await self._send_nova_event(audio_content_start) self.audio_connection_active = True - + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio using Nova Sonic protocol-specific format.""" if not self._active: return - + # Start audio connection if not already active if not self.audio_connection_active: await self.start_audio_connection() - + # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() if self.silence_task and not self.silence_task.done(): self.silence_task.cancel() - + # Convert audio to Nova Sonic base64 format - nova_audio_data = base64.b64encode(audio_input["audioData"]).decode('utf-8') - + nova_audio_data = 
base64.b64encode(audio_input["audioData"]).decode("utf-8") + # Send audio input event - audio_event = json.dumps({ - "event": { - "audioInput": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "content": nova_audio_data + audio_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "content": nova_audio_data, + } } } - }) - + ) + await self._send_nova_event(audio_event) - + # Start silence detection task self.silence_task = asyncio.create_task(self._check_silence()) - + async def _check_silence(self): """Check for silence and automatically end audio connection.""" try: @@ -329,226 +323,195 @@ async def _check_silence(self): await self.end_audio_input() except asyncio.CancelledError: pass - + async def end_audio_input(self) -> None: """End current audio input connection to trigger Nova Sonic processing.""" if not self.audio_connection_active: return - + log_event("nova_audio_connection_end") - - audio_content_end = json.dumps({ - "event": { - "contentEnd": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name - } - } - }) - + + audio_content_end = json.dumps( + {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} + ) + await self._send_nova_event(audio_content_end) self.audio_connection_active = False - + async def send_text_content(self, text: str, **kwargs) -> None: """Send text content using Nova Sonic format.""" if not self._active: return - + content_name = str(uuid.uuid4()) events = [ self._get_text_content_start_event(content_name), self._get_text_input_event(content_name, text), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + for event in events: await self._send_nova_event(event) - + async def send_interrupt(self) -> None: """Send interruption signal to Nova Sonic.""" if not self._active: return - + # Nova Sonic handles interruption through special input events interrupt_event = { "event": { "audioInput": { "promptName": self.prompt_name, "contentName": self.audio_content_name, - "stopReason": "INTERRUPTED" + "stopReason": "INTERRUPTED", } } } await self._send_nova_event(interrupt_event) - + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: """Send tool result using Nova Sonic toolResult format.""" if not self._active: return - + log_event("nova_tool_result_send", id=tool_use_id) content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), self._get_tool_result_event(content_name, result), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i+1}", lambda: self._send_nova_event(event)) - + await time_it_async(f"send_tool_event_{i + 1}", lambda: self._send_nova_event(event)) + async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool error using Nova Sonic format.""" log_event("nova_tool_error_send", id=tool_use_id, error=error) error_result = {"error": error} await self.send_tool_result(tool_use_id, error_result) - + async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - + log_flow("nova_cleanup", "starting connection close") self._active = False - + # Cancel response processing task if running - if hasattr(self, '_response_task') and not self._response_task.done(): + if 
hasattr(self, "_response_task") and not self._response_task.done(): self._response_task.cancel() try: await self._response_task except asyncio.CancelledError: pass - + try: # End audio connection if active if self.audio_connection_active: await self.end_audio_input() - + # Send cleanup events - cleanup_events = [ - self._get_prompt_end_event(), - self._get_connection_end_event() - ] - + cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] + for event in cleanup_events: try: await self._send_nova_event(event) except Exception as e: logger.warning("Error during Nova Sonic cleanup: %s", e) - + # Close stream try: await self.stream.input_stream.close() except Exception as e: logger.warning("Error closing Nova Sonic stream: %s", e) - + except Exception as e: log_event("nova_cleanup_error", error=str(e)) finally: log_event("nova_connection_closed") - + def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Convert Nova Sonic events to provider-agnostic format.""" # Handle audio output if "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) - + audio_output: AudioOutputEvent = { "audioData": audio_bytes, "format": "pcm", "sampleRate": 24000, "channels": 1, - "encoding": "base64" - } - - return { - "audioOutput": audio_output + "encoding": "base64", } - + + return {"audioOutput": audio_output} + # Handle text output elif "textOutput" in nova_event: text_content = nova_event["textOutput"]["content"] # Use stored role from contentStart event, fallback to event role - role = getattr(self, '_current_role', nova_event["textOutput"].get("role", "assistant")) - + role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant")) + # Check for Nova Sonic interruption pattern (matches working sample) if '{ "interrupted" : true }' in text_content: log_event("nova_interruption_in_text") - interruption: InterruptionDetectedEvent = { - "reason": "user_input" - } - return { - "interruptionDetected": interruption - } - + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + return {"interruptionDetected": interruption} + # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag if role == "USER": print(f"User: {text_content}") elif role == "ASSISTANT": print(f"Assistant: {text_content}") - - text_output: TextOutputEvent = { - "text": text_content, - "role": role.lower() - } - - return { - "textOutput": text_output - } - + + text_output: TextOutputEvent = {"text": text_content, "role": role.lower()} + + return {"textOutput": text_output} + # Handle tool use elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - + tool_use_event: ToolUse = { "toolUseId": tool_use["toolUseId"], "name": tool_use["toolName"], - "input": json.loads(tool_use["content"]) - } - - return { - "toolUse": tool_use_event + "input": json.loads(tool_use["content"]), } - + + return {"toolUse": tool_use_event} + # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": log_event("nova_interruption_stop_reason") - - interruption: InterruptionDetectedEvent = { - "reason": "user_input" - } - - return { - "interruptionDetected": interruption - } - + + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + + return {"interruptionDetected": interruption} + # Handle usage events (ignore) elif "usageEvent" in nova_event: return None - + # Handle content start events (track role) elif "contentStart" in nova_event: role 
= nova_event["contentStart"].get("role", "unknown") # Store role for subsequent text output events self._current_role = role return None - + # Handle other events else: return None - + # Nova Sonic event template methods def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" - return json.dumps({ - "event": { - "sessionStart": { - "inferenceConfiguration": NOVA_INFERENCE_CONFIG - } - } - }) - + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" prompt_start_event = { @@ -556,143 +519,121 @@ def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: "promptStart": { "promptName": self.prompt_name, "textOutputConfiguration": NOVA_TEXT_CONFIG, - "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } } } - + if tools: tool_config = self._build_tool_configuration(tools) prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} - + return json.dumps(prompt_start_event) - + def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: """Build tool configuration from tool specs.""" tool_config = [] for tool in tools: - input_schema = ({"json": json.dumps(tool['inputSchema']['json'])} - if 'json' in tool['inputSchema'] - else {"json": json.dumps(tool['inputSchema'])}) - - tool_config.append({ - "toolSpec": { - "name": tool["name"], - "description": tool["description"], - "inputSchema": input_schema - } - }) + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": json.dumps(tool["inputSchema"])} + ) + + tool_config.append( + {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} + ) return tool_config - + def _get_system_prompt_events(self, system_prompt: str) -> List[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ self._get_text_content_start_event(content_name, "SYSTEM"), self._get_text_input_event(content_name, system_prompt), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: """Generate text content start event.""" - return json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": content_name, - "type": "TEXT", - "role": role, - "interactive": True, - "textInputConfiguration": NOVA_TEXT_CONFIG + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG, + } } } - }) - + ) + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: """Generate tool content start event.""" - return json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": content_name, - "interactive": False, - "type": "TOOL", - "role": "TOOL", - "toolResultInputConfiguration": { - "toolUseId": tool_use_id, - "type": "TEXT", - "textInputConfiguration": NOVA_TEXT_CONFIG + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + 
"contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG, + }, } } } - }) - + ) + def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" - return json.dumps({ - "event": { - "textInput": { - "promptName": self.prompt_name, - "contentName": content_name, - "content": text - } - } - }) - + return json.dumps( + {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} + ) + def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: """Generate tool result event.""" - return json.dumps({ - "event": { - "toolResult": { - "promptName": self.prompt_name, - "contentName": content_name, - "content": json.dumps(result) + return json.dumps( + { + "event": { + "toolResult": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": json.dumps(result), + } } } - }) - + ) + def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({ - "event": { - "contentEnd": { - "promptName": self.prompt_name, - "contentName": content_name - } - } - }) - + return json.dumps({"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": content_name}}}) + def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({ - "event": { - "promptEnd": { - "promptName": self.prompt_name - } - } - }) - + return json.dumps({"event": {"promptEnd": {"promptName": self.prompt_name}}}) + def _get_connection_end_event(self) -> str: """Generate connection end event.""" - return json.dumps({ - "event": { - "connectionEnd": {} - } - }) - + return json.dumps({"event": {"connectionEnd": {}}}) + async def _send_nova_event(self, event: str) -> None: """Send event JSON string to Nova Sonic stream.""" try: - # Event is already a JSON string - bytes_data = event.encode('utf-8') - chunk = InvokeModelWithBidirectionalStreamInputChunk( - value=BidirectionalInputPayloadPart(bytes_=bytes_data) - ) + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk(value=BidirectionalInputPayloadPart(bytes_=bytes_data)) await self.stream.input_stream.send(chunk) logger.debug("Successfully sent Nova Sonic event") - + except Exception as e: logger.error("Error sending Nova Sonic event: %s", e) logger.error("Event was: %s", event) @@ -701,14 +642,14 @@ async def _send_nova_event(self, event: str) -> None: class NovaSonicBidirectionalModel(BidirectionalModel): """Nova Sonic model implementation for bidirectional streaming. - + Provides access to Amazon's Nova Sonic model through the bidirectional streaming interface, handling AWS authentication and connection management. """ - + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): """Initialize Nova Sonic bidirectional model. - + Args: model_id: Nova Sonic model identifier. region: AWS region. 
@@ -718,61 +659,60 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e self.region = region self.config = config self._client = None - + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - + async def create_bidirectional_connection( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None, - **kwargs + **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" log_flow("nova_connection_create", "starting") - + # Initialize client if needed if not self._client: await time_it_async("initialize_client", lambda: self._initialize_client()) - + # Start Nova Sonic bidirectional stream try: - stream = await time_it_async("invoke_model_with_bidirectional_stream", + stream = await time_it_async( + "invoke_model_with_bidirectional_stream", lambda: self._client.invoke_model_with_bidirectional_stream( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - )) - + ), + ) + # Create and initialize connection connection = NovaSonicSession(stream, self.config) - await time_it_async("initialize_connection", - lambda: connection.initialize(system_prompt, tools, messages)) - + await time_it_async("initialize_connection", lambda: connection.initialize(system_prompt, tools, messages)) + log_event("nova_connection_created") return connection except Exception as e: log_event("nova_connection_create_error", error=str(e)) logger.error("Failed to create Nova Sonic connection: %s", e) raise - + async def _initialize_client(self) -> None: """Initialize Nova Sonic client.""" try: - config = Config( endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", region=self.region, aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), http_auth_scheme_resolver=HTTPAuthSchemeResolver(), - http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()} + http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()}, ) - + self._client = BedrockRuntimeClient(config=config) logger.debug("Nova Sonic client initialized") - + except ImportError as e: logger.error("Nova Sonic dependencies not available: %s", e) raise except Exception as e: logger.error("Error initializing Nova Sonic client: %s", e) raise - diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index d650aba9b..6ef96f919 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -11,12 +11,13 @@ # Add the src directory to Python path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) import time -import pyaudio -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +import pyaudio from strands_tools import calculator +from ..agent.agent import BidirectionalAgent +from ..models.novasonic import NovaSonicBidirectionalModel + async def play(context): """Play audio output with responsive interruption support.""" @@ -26,7 +27,7 @@ async def play(context): format=pyaudio.paInt16, output=True, rate=24000, - frames_per_buffer=1024, + frames_per_buffer=1024, ) try: @@ -40,36 +41,33 @@ async def play(context): context["audio_out"].get_nowait() except asyncio.QueueEmpty: break 
- + context["interrupted"] = False - await asyncio.sleep(0.05) + await asyncio.sleep(0.05) continue - + # Get next audio data - audio_data = await asyncio.wait_for( - context["audio_out"].get(), - timeout=0.1 - ) - + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + if audio_data and context["active"]: - chunk_size = 1024 + chunk_size = 1024 for i in range(0, len(audio_data), chunk_size): # Check for interruption before each chunk if context.get("interrupted", False) or not context["active"]: break - + end = min(i + chunk_size, len(audio_data)) chunk = audio_data[i:end] speaker.write(chunk) await asyncio.sleep(0.001) - + except asyncio.TimeoutError: continue # No audio available except asyncio.QueueEmpty: await asyncio.sleep(0.01) except asyncio.CancelledError: break - + except asyncio.CancelledError: pass finally: @@ -111,30 +109,30 @@ async def receive(agent, context): if "audioOutput" in event: if not context.get("interrupted", False): context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) - + # Handle interruption events elif "interruptionDetected" in event: context["interrupted"] = True elif "interrupted" in event: context["interrupted"] = True - + # Handle text output with interruption detection elif "textOutput" in event: text_content = event["textOutput"].get("content", "") role = event["textOutput"].get("role", "unknown") - + # Check for text-based interruption patterns if '{ "interrupted" : true }' in text_content: context["interrupted"] = True elif "interrupted" in text_content.lower(): context["interrupted"] = True - + # Log text output if role.upper() == "USER": print(f"User: {text_content}") elif role.upper() == "ASSISTANT": print(f"Assistant: {text_content}") - + except asyncio.CancelledError: pass @@ -145,18 +143,13 @@ async def send(agent, context): while time.time() - context["start_time"] < context["duration"]: try: audio_bytes = context["audio_in"].get_nowait() - audio_event = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } + audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing except asyncio.CancelledError: break - + context["active"] = False except asyncio.CancelledError: pass @@ -166,14 +159,10 @@ async def main(duration=180): """Main function for bidirectional streaming test.""" print("Starting bidirectional streaming test...") print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") - + # Initialize model and agent model = NovaSonicBidirectionalModel(region="us-east-1") - agent = BidirectionalAgent( - model=model, - tools=[calculator], - system_prompt="You are a helpful assistant." - ) + agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() @@ -189,15 +178,11 @@ async def main(duration=180): } print("Speak into microphone. 
Press Ctrl+C to exit.") - + try: # Run all tasks concurrently await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - return_exceptions=True + play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True ) except KeyboardInterrupt: print("\nInterrupted by user") diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index f6441d2f0..510285f06 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -1,3 +1,31 @@ -"""Bidirectional streaming types package.""" -# Types package +"""Type definitions for bidirectional streaming.""" +from .bidirectional_streaming import ( + DEFAULT_CHANNELS, + DEFAULT_SAMPLE_RATE, + SUPPORTED_AUDIO_FORMATS, + SUPPORTED_CHANNELS, + SUPPORTED_SAMPLE_RATES, + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, +) + +__all__ = [ + "AudioInputEvent", + "AudioOutputEvent", + "BidirectionalConnectionEndEvent", + "BidirectionalConnectionStartEvent", + "BidirectionalStreamEvent", + "InterruptionDetectedEvent", + "TextOutputEvent", + "SUPPORTED_AUDIO_FORMATS", + "SUPPORTED_SAMPLE_RATES", + "SUPPORTED_CHANNELS", + "DEFAULT_SAMPLE_RATE", + "DEFAULT_CHANNELS", +] diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index fabe53ac9..01d72356a 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -19,23 +19,25 @@ from typing import Any, Dict, Literal, Optional -from strands.types.content import Role -from strands.types.streaming import StreamEvent from typing_extensions import TypedDict +from ....types.content import Role +from ....types.streaming import StreamEvent + # Audio format constants -SUPPORTED_AUDIO_FORMATS = ['pcm', 'wav', 'opus', 'mp3'] +SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo DEFAULT_SAMPLE_RATE = 16000 DEFAULT_CHANNELS = 1 + class AudioOutputEvent(TypedDict): """Audio output event from the model. - + Provides standardized audio output format across different providers using raw bytes instead of provider-specific encodings. - + Attributes: audioData: Raw audio bytes (not base64 or hex encoded). format: Audio format from SUPPORTED_AUDIO_FORMATS. @@ -43,9 +45,9 @@ class AudioOutputEvent(TypedDict): channels: Channel count from SUPPORTED_CHANNELS. encoding: Original provider encoding for debugging purposes. """ - + audioData: bytes - format: Literal['pcm', 'wav', 'opus', 'mp3'] + format: Literal["pcm", "wav", "opus", "mp3"] sampleRate: Literal[16000, 24000, 48000] channels: Literal[1, 2] encoding: Optional[str] @@ -53,78 +55,78 @@ class AudioOutputEvent(TypedDict): class AudioInputEvent(TypedDict): """Audio input event for sending audio to the model. - + Used for sending audio data through the send() method. - + Attributes: audioData: Raw audio bytes to send to model. format: Audio format from SUPPORTED_AUDIO_FORMATS. sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. 
channels: Channel count from SUPPORTED_CHANNELS. """ - + audioData: bytes - format: Literal['pcm', 'wav', 'opus', 'mp3'] + format: Literal["pcm", "wav", "opus", "mp3"] sampleRate: Literal[16000, 24000, 48000] channels: Literal[1, 2] class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. - + Attributes: text: The text content from the model. role: The role of the message sender. """ - + text: str role: Role class InterruptionDetectedEvent(TypedDict): """Interruption detection event. - + Signals when user interruption is detected during model generation. - + Attributes: reason: Interruption reason from predefined set. """ - - reason: Literal['user_input', 'vad_detected', 'manual'] + + reason: Literal["user_input", "vad_detected", "manual"] class BidirectionalConnectionStartEvent(TypedDict, total=False): """connection start event for bidirectional streaming. - + Attributes: connectionId: Unique connection identifier. metadata: Provider-specific connection metadata. """ - + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalConnectionEndEvent(TypedDict): """connection end event for bidirectional streaming. - + Attributes: reason: Reason for connection end from predefined set. connectionId: Unique connection identifier. metadata: Provider-specific connection metadata. """ - - reason: Literal['user_request', 'timeout', 'error'] + + reason: Literal["user_request", "timeout", "error"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. - + Extends the existing StreamEvent type with bidirectional-specific events while maintaining full backward compatibility with existing Strands streaming. - + Attributes: audioOutput: Audio output from the model. audioInput: Audio input sent to the model. @@ -133,11 +135,10 @@ class BidirectionalStreamEvent(StreamEvent, total=False): BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. 
""" - + audioOutput: AudioOutputEvent audioInput: AudioInputEvent textOutput: TextOutputEvent interruptionDetected: InterruptionDetectedEvent BidirectionalConnectionStart: BidirectionalConnectionStartEvent BidirectionalConnectionEnd: BidirectionalConnectionEndEvent - diff --git a/src/strands/experimental/bidirectional_streaming/utils/__init__.py b/src/strands/experimental/bidirectional_streaming/utils/__init__.py new file mode 100644 index 000000000..579478436 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utility functions for bidirectional streaming.""" + +from .debug import log_event, log_flow, time_it_async + +__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py index 1e88b6ead..6a7fc3982 100644 --- a/src/strands/experimental/bidirectional_streaming/utils/debug.py +++ b/src/strands/experimental/bidirectional_streaming/utils/debug.py @@ -11,30 +11,34 @@ # Debug logging system matching successful tool use example DEBUG = False # Disable debug logging for clean output like tool use example + def debug_print(message): """Print debug message with timestamp and function name.""" if DEBUG: function_name = inspect.stack()[1].function - if function_name == 'time_it_async': + if function_name == "time_it_async": function_name = inspect.stack()[2].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] print(f"{timestamp} {function_name} {message}") + def log_event(event_type, **context): """Log important events with structured context.""" if DEBUG: function_name = inspect.stack()[1].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") + def log_flow(step, details=""): """Log important flow steps without excessive detail.""" if DEBUG: function_name = inspect.stack()[1].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] print(f"{timestamp} {function_name} FLOW: {step} {details}") + async def time_it_async(label, method_to_run): """Time asynchronous method execution.""" start_time = time.perf_counter() @@ -42,4 +46,3 @@ async def time_it_async(label, method_to_run): end_time = time.perf_counter() debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") return result - From f7e67aec65640b9e262e88d4f82d020308143250 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 1 Oct 2025 23:59:44 -0400 Subject: [PATCH 05/15] fix linting issues --- pyproject.toml | 1 - .../event_loop/bidirectional_event_loop.py | 5 +++-- .../experimental/bidirectional_streaming/models/novasonic.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f45794d12..dd01ebde3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ bidirectional-streaming = [ "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", - "python>=3.12" ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ diff --git 
a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index c90d118ff..4fbae3992 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,6 +21,7 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse +from ..agent.agent import BidirectionalAgent from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -61,7 +62,7 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.interrupted = False -async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: +async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. Creates a model-specific session and starts background tasks for processing @@ -325,7 +326,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: session.pending_tool_tasks[task_id] = task # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) - def cleanup_task(completed_task): + def cleanup_task(completed_task, task_id=task_id): try: # Remove from pending tasks if task_id in session.pending_tool_tasks: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 0efd2413c..22912354d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -147,7 +147,7 @@ def _build_initialization_events( async def _send_initialization_events(self, events: List[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i + 1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) await asyncio.sleep(EVENT_DELAY) async def _process_responses(self) -> None: @@ -384,7 +384,7 @@ async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> No ] for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i + 1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool error using Nova Sonic format.""" From c654621d9c345316c90e6895a430e2f1918a9b8c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 00:07:34 -0400 Subject: [PATCH 06/15] Remove typing module and rely on python's built-in types --- .../bidirectional_streaming/agent/agent.py | 10 ++--- .../event_loop/bidirectional_event_loop.py | 13 +++---- .../models/bidirectional_model.py | 12 +++--- .../models/novasonic.py | 38 +++++++++---------- 4 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index d7a5f17a3..997a0d1df 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ 
b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -14,7 +14,7 @@ import asyncio import logging -from typing import AsyncIterable, List, Optional, Union +from typing import AsyncIterable from ....tools.executors import ConcurrentToolExecutor from ....tools.registry import ToolRegistry @@ -37,9 +37,9 @@ class BidirectionalAgent: def __init__( self, model: BidirectionalModel, - tools: Optional[List] = None, - system_prompt: Optional[str] = None, - messages: Optional[Messages] = None, + tools: list | None = None, + system_prompt: str | None = None, + messages: Messages | None = None, ): """Initialize bidirectional agent with required model and optional configuration. @@ -83,7 +83,7 @@ async def start(self) -> None: self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - async def send(self, input_data: Union[str, AudioInputEvent]) -> None: + async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 4fbae3992..65ee6b905 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -16,7 +16,6 @@ import logging import traceback import uuid -from typing import Any, Dict from ....tools._validator import validate_and_prepare_tools from ....types.content import Message @@ -56,14 +55,14 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.audio_output_queue = asyncio.Queue() # Task management for cleanup - self.pending_tool_tasks: Dict[str, asyncio.Task] = {} + self.pending_tool_tasks: dict[str, asyncio.Task] = {} # Interruption handling (model-agnostic) self.interrupted = False async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: - """Initialize bidirectional session with concurrent background tasks. + """Initialize bidirectional session with conycurrent background tasks. Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. @@ -365,7 +364,7 @@ def cleanup_task(completed_task, task_id=task_id): log_flow("tool_execution", "processor stopped") -def _convert_to_strands_event(provider_event: Dict) -> Dict: +def _convert_to_strands_event(provider_event: dict) -> dict: """Pass-through for events already normalized by provider sessions. Providers convert their raw events to standard format before reaching here. @@ -385,7 +384,7 @@ def _convert_to_strands_event(provider_event: Dict) -> Dict: return provider_event -async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: +async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. 
Executes tools using the existing Strands tool system, handles interruption @@ -501,11 +500,11 @@ def _extract_callable_function(tool_func): raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") -def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: +def _create_success_result(tool_use_id: str, result) -> dict[str, any]: """Create a successful tool result.""" return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} -def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: +def _create_error_result(tool_use_id: str, error: str) -> dict[str, any]: """Create an error tool result.""" return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index cc803458b..1432b112a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -13,7 +13,7 @@ import abc import logging -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import AsyncIterable from ....types.content import Messages from ....types.tools import ToolSpec @@ -31,7 +31,7 @@ class BidirectionalModelSession(abc.ABC): """ @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive events from the model in standardized format. Converts provider-specific events to a common format that can be @@ -67,7 +67,7 @@ async def send_interrupt(self) -> None: raise NotImplementedError @abc.abstractmethod - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: """Send tool execution result to the model. Formats and sends tool results according to the provider's specific protocol. @@ -95,9 +95,9 @@ class BidirectionalModel(abc.ABC): @abc.abstractmethod async def create_bidirectional_connection( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, **kwargs, ) -> BidirectionalModelSession: """Create a bidirectional connection with the model. diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 22912354d..969cac159 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -19,7 +19,7 @@ import time import traceback import uuid -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import AsyncIterable from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme @@ -80,7 +80,7 @@ class NovaSonicSession(BidirectionalModelSession): interface. """ - def __init__(self, stream, config: Dict[str, Any]): + def __init__(self, stream, config: dict[str, any]): """Initialize Nova Sonic connection. 
Args: @@ -111,9 +111,9 @@ def __init__(self, stream, config: Dict[str, Any]): async def initialize( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, ) -> None: """Initialize Nova Sonic connection with required protocol sequence.""" try: @@ -132,8 +132,8 @@ async def initialize( raise def _build_initialization_events( - self, system_prompt: str, tools: List[ToolSpec], messages: Optional[Messages] - ) -> List[str]: + self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None + ) -> list[str]: """Build the sequence of initialization events.""" events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] @@ -144,7 +144,7 @@ def _build_initialization_events( return events - async def _send_initialization_events(self, events: List[str]) -> None: + async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) @@ -192,7 +192,7 @@ async def _handle_response_data(self, response_data: str) -> None: except json.JSONDecodeError as e: log_event("nova_json_error", error=str(e)) - def _log_event_type(self, nova_event: Dict[str, Any]) -> None: + def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: log_event("nova_usage", usage=nova_event["usageEvent"]) @@ -206,7 +206,7 @@ def _log_event_type(self, nova_event: Dict[str, Any]) -> None: audio_bytes = base64.b64decode(audio_content) log_event("nova_audio_output", bytes=len(audio_bytes)) - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" if not self.stream: logger.error("Stream is None") @@ -370,7 +370,7 @@ async def send_interrupt(self) -> None: } await self._send_nova_event(interrupt_event) - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: """Send tool result using Nova Sonic toolResult format.""" if not self._active: return @@ -433,7 +433,7 @@ async def close(self) -> None: finally: log_event("nova_connection_closed") - def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: """Convert Nova Sonic events to provider-agnostic format.""" # Handle audio output if "audioOutput" in nova_event: @@ -512,7 +512,7 @@ def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) - def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: + def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" prompt_start_event = { "event": { @@ -531,7 +531,7 @@ def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: return json.dumps(prompt_start_event) - def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: + def 
_build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict]: """Build tool configuration from tool specs.""" tool_config = [] for tool in tools: @@ -546,7 +546,7 @@ def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: ) return tool_config - def _get_system_prompt_events(self, system_prompt: str) -> List[str]: + def _get_system_prompt_events(self, system_prompt: str) -> list[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ @@ -599,7 +599,7 @@ def _get_text_input_event(self, content_name: str, text: str) -> str: {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} ) - def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: + def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: """Generate tool result event.""" return json.dumps( { @@ -664,9 +664,9 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e async def create_bidirectional_connection( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" From 1f1abacd839cd6ed26ebd9a84bfa2e8aeb50be01 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 00:12:15 -0400 Subject: [PATCH 07/15] add typing to methods --- .../event_loop/bidirectional_event_loop.py | 8 ++++---- .../bidirectional_streaming/models/novasonic.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 65ee6b905..ea00468bb 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -38,7 +38,7 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent): + def __init__(self, model_session: BidirectionalModelSession, agent: BidirectionalAgent) -> None: """Initialize session with model session and agent reference. 
Args: @@ -325,7 +325,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: session.pending_tool_tasks[task_id] = task # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) - def cleanup_task(completed_task, task_id=task_id): + def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: try: # Remove from pending tasks if task_id in session.pending_tool_tasks: @@ -488,7 +488,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: log_event("tool_error_send_failed", error=str(send_error)) -def _extract_callable_function(tool_func): +def _extract_callable_function(tool_func: any) -> any: """Extract the callable function from different tool object types.""" if hasattr(tool_func, "_tool_func"): return tool_func._tool_func @@ -500,7 +500,7 @@ def _extract_callable_function(tool_func): raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") -def _create_success_result(tool_use_id: str, result) -> dict[str, any]: +def _create_success_result(tool_use_id: str, result: any) -> dict[str, any]: """Create a successful tool result.""" return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 969cac159..89472350b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -80,7 +80,7 @@ class NovaSonicSession(BidirectionalModelSession): interface. """ - def __init__(self, stream, config: dict[str, any]): + def __init__(self, stream: any, config: dict[str, any]) -> None: """Initialize Nova Sonic connection. Args: @@ -312,7 +312,7 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: # Start silence detection task self.silence_task = asyncio.create_task(self._check_silence()) - async def _check_silence(self): + async def _check_silence(self) -> None: """Check for silence and automatically end audio connection.""" try: await asyncio.sleep(self.silence_threshold) @@ -647,7 +647,7 @@ class NovaSonicBidirectionalModel(BidirectionalModel): streaming interface, handling AWS authentication and connection management. """ - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: """Initialize Nova Sonic bidirectional model. 
Args: From eb543b52434dbe6af1f2f309f77446a97ed08871 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:00:04 -0400 Subject: [PATCH 08/15] Improve comments and remove unused method _convert_to_strands_event --- .../bidirectional_streaming/agent/agent.py | 1 - .../event_loop/bidirectional_event_loop.py | 45 +++++++------------ 2 files changed, 15 insertions(+), 31 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 997a0d1df..e27885c7e 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -98,7 +98,6 @@ async def send(self, input_data: str | AudioInputEvent) -> None: self._validate_active_session() if isinstance(input_data, str): - # Handle text input log_event("text_sent", length=len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index ea00468bb..fddd1245a 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -189,7 +189,7 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: log_event("interruption_detected") session.interrupted = True - # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) + # Cancel all pending tool execution tasks cancelled_tools = 0 for task_id, task in list(session.pending_tool_tasks.items()): if not task.done(): @@ -200,7 +200,7 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: if cancelled_tools > 0: log_event("tool_tasks_cancelled", count=cancelled_tools) - # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) + # Clear all queued audio output events cleared_count = 0 while True: try: @@ -258,8 +258,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: if not session.active: break - # Convert provider events to Strands format - strands_event = _convert_to_strands_event(provider_event) + # Basic validation - skip invalid events + if not isinstance(provider_event, dict): + continue + + strands_event = provider_event # Handle interruption detection (multiple patterns) if strands_event.get("interruptionDetected"): @@ -269,7 +272,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: await session.agent._output_queue.put(strands_event) continue - # Check for text-based interruption (Nova Sonic pattern) + # Check for text-based interruption if strands_event.get("textOutput"): text_content = strands_event["textOutput"].get("content", "") if '{ "interrupted" : true }' in text_content: @@ -324,7 +327,6 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task - # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: try: # Remove from pending tasks @@ -346,7 +348,7 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: except asyncio.TimeoutError: if not session.active: break - # 🔥 PERIODIC 
CLEANUP OF COMPLETED TASKS + # Remove completed tasks from tracking completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] for task_id in completed_tasks: if task_id in session.pending_tool_tasks: @@ -364,24 +366,7 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: log_flow("tool_execution", "processor stopped") -def _convert_to_strands_event(provider_event: dict) -> dict: - """Pass-through for events already normalized by provider sessions. - - Providers convert their raw events to standard format before reaching here. - This just validates and passes through the normalized events. - - Args: - provider_event: Already normalized event from provider session. - - Returns: - Dict: The same event, validated and passed through. - """ - # Basic validation - ensure we have a dict - if not isinstance(provider_event, dict): - return {} - # Pass through - conversion already done by provider session - return provider_event async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: @@ -398,7 +383,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_id = tool_use.get("toolUseId") try: - # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) + # Skip execution if session is interrupted or inactive if session.interrupted or not session.active: log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) return @@ -422,7 +407,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: - # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION + # Return early if session was interrupted during execution if session.interrupted or not session.active: log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) return @@ -433,12 +418,12 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: try: actual_func = _extract_callable_function(tool_func) - # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK + # Execute tool function with provided input # For async tools, we could wrap with asyncio.wait_for with cancellation # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION + # Discard result if session was interrupted during execution if session.interrupted or not session.active: log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) return @@ -451,7 +436,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: log_event("tool_execution_cancelled", name=tool_name, id=tool_id) return except Exception as e: - # 🔥 CHECK FOR INTERRUPTION EVEN ON ERROR + # Discard error result if session was interrupted if session.interrupted or not session.active: log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) return @@ -462,7 +447,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: else: log_event("tool_not_found", name=tool_name) - # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS + # Skip sending results if session was interrupted if session.interrupted or not session.active: log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) return From 5921f8bdb24740adb2b6ad2af609218674b4b4b5 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:23:21 
-0400 Subject: [PATCH 09/15] Updated: fixed module imports based on the new smithy python release on 09-29, added a lock for interruption handling --- .../event_loop/bidirectional_event_loop.py | 118 ++++++++++-------- .../models/novasonic.py | 6 +- .../tests/test_bidirectional_streaming.py | 4 +- 3 files changed, 68 insertions(+), 60 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index fddd1245a..358fdcea3 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -20,7 +20,7 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..agent.agent import BidirectionalAgent + from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -38,7 +38,7 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent: BidirectionalAgent) -> None: + def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent") -> None: """Initialize session with model session and agent reference. Args: @@ -59,9 +59,10 @@ def __init__(self, model_session: BidirectionalModelSession, agent: Bidirectiona # Interruption handling (model-agnostic) self.interrupted = False + self.interruption_lock = asyncio.Lock() -async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: +async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: """Initialize bidirectional session with conycurrent background tasks. Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. @@ -181,66 +182,73 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: """Handle interruption detection with task cancellation and audio buffer clearing. Cancels pending tool tasks and clears audio output queues to ensure responsive - interruption handling during conversations. + interruption handling during conversations. Protected by async lock to prevent + concurrent execution and race conditions. Args: session: BidirectionalConnection to handle interruption for. 
""" - log_event("interruption_detected") - session.interrupted = True + async with session.interruption_lock: + # If already interrupted, skip duplicate processing + if session.interrupted: + log_event("interruption_already_in_progress") + return - # Cancel all pending tool execution tasks - cancelled_tools = 0 - for task_id, task in list(session.pending_tool_tasks.items()): - if not task.done(): - task.cancel() - cancelled_tools += 1 - log_event("tool_task_cancelled", task_id=task_id) + log_event("interruption_detected") + session.interrupted = True - if cancelled_tools > 0: - log_event("tool_tasks_cancelled", count=cancelled_tools) + # Cancel all pending tool execution tasks + cancelled_tools = 0 + for task_id, task in list(session.pending_tool_tasks.items()): + if not task.done(): + task.cancel() + cancelled_tools += 1 + log_event("tool_task_cancelled", task_id=task_id) - # Clear all queued audio output events - cleared_count = 0 - while True: - try: - session.audio_output_queue.get_nowait() - cleared_count += 1 - except asyncio.QueueEmpty: - break + if cancelled_tools > 0: + log_event("tool_tasks_cancelled", count=cancelled_tools) - # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, "_output_queue"): - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] - try: - while True: - event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) - - if cleared_count > 0: - log_event("session_audio_queue_cleared", count=cleared_count) - - # Brief sleep to allow audio system to settle (matches Nova Sonic timing) - await asyncio.sleep(0.05) - - # Reset interruption flag after clearing (automatic recovery) - session.interrupted = False - log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + # Clear all queued audio output events + cleared_count = 0 + while True: + try: + session.audio_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break + + # Also clear the agent's audio output queue if it exists + if hasattr(session.agent, "_output_queue"): + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) + + if cleared_count > 0: + log_event("session_audio_queue_cleared", count=cleared_count) + + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) + await asyncio.sleep(0.05) + + # Reset interruption flag after clearing (automatic recovery) + session.interrupted = False + log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) async def _process_model_events(session: BidirectionalConnection) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py 
b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 89472350b..e79229623 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -24,7 +24,7 @@ from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk -from smithy_aws_core.credentials_resolvers.environment import EnvironmentCredentialsResolver +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from ....types.content import Messages from ....types.tools import ToolSpec, ToolUse @@ -703,8 +703,8 @@ async def _initialize_client(self) -> None: endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", region=self.region, aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), - http_auth_scheme_resolver=HTTPAuthSchemeResolver(), - http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()}, + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, ) self._client = BedrockRuntimeClient(config=config) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index 6ef96f919..b31607966 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -15,8 +15,8 @@ import pyaudio from strands_tools import calculator -from ..agent.agent import BidirectionalAgent -from ..models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel async def play(context): From 8cb4d98ba035d021cdff1953cf9705cca114e270 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:37:33 -0400 Subject: [PATCH 10/15] Removed unnecessary _output_queue check as the queue will always be initialized, and removed asyncio.sleep() as they were mainly for defensive purposes and following the pattern of nova sonic samples. 
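The two commits above converge on a simpler interruption path: a single asyncio.Lock serializes the handler, a duplicate trigger that arrives while the flag is still set returns early, and the defensive hasattr checks and settling sleeps are dropped. The sketch below illustrates that pattern in isolation; it is not the module's actual code, and the _SketchSession class and _demo driver are hypothetical stand-ins for BidirectionalConnection and the real event loop.

import asyncio


class _SketchSession:
    """Minimal stand-in for BidirectionalConnection, for illustration only."""

    def __init__(self) -> None:
        self.interrupted = False
        self.interruption_lock = asyncio.Lock()
        self.pending_tool_tasks: dict[str, asyncio.Task] = {}
        self.audio_output_queue: asyncio.Queue = asyncio.Queue()


async def handle_interruption(session: _SketchSession) -> None:
    """Lock-protected, idempotent interruption handling without settling sleeps."""
    async with session.interruption_lock:
        if session.interrupted:
            return  # a concurrent handler already processed this interruption
        session.interrupted = True

        # Cancel any in-flight tool tasks so stale results are never sent.
        for task in list(session.pending_tool_tasks.values()):
            if not task.done():
                task.cancel()

        # Drain queued audio so stale output is never played.
        while True:
            try:
                session.audio_output_queue.get_nowait()
            except asyncio.QueueEmpty:
                break

        # Reset the flag once queues are clear (automatic recovery).
        session.interrupted = False


async def _demo() -> None:
    session = _SketchSession()
    session.audio_output_queue.put_nowait(b"\x00" * 1024)
    await handle_interruption(session)
    print("audio queue drained:", session.audio_output_queue.empty())


if __name__ == "__main__":
    asyncio.run(_demo())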
--- .../event_loop/bidirectional_event_loop.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 358fdcea3..b4395f38e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -20,7 +20,6 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse - from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -95,10 +94,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start main coordination cycle session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) - # Give background tasks a moment to start - await asyncio.sleep(0.1) log_event("session_ready", tasks=len(session.background_tasks)) - return session @@ -217,35 +213,31 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: except asyncio.QueueEmpty: break - # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, "_output_queue"): - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] - try: - while True: - event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) + # Also clear the agent's audio output queue + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) if cleared_count > 0: log_event("session_audio_queue_cleared", count=cleared_count) - # Brief sleep to allow audio system to settle (matches Nova Sonic timing) - await asyncio.sleep(0.05) - # Reset interruption flag after clearing (automatic recovery) session.interrupted = False log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) From 7a6e53efdf669352bd18f19531178d46589c214d Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 13:03:42 -0400 Subject: [PATCH 11/15] Remove redundant interruption checks --- .../event_loop/bidirectional_event_loop.py | 67 +++---------------- 1 file changed, 11 insertions(+), 56 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index b4395f38e..cc4f416b7 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ 
b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -264,7 +264,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: strands_event = provider_event - # Handle interruption detection (multiple patterns) + # Handle interruption detection (provider converts raw patterns to interruptionDetected) if strands_event.get("interruptionDetected"): log_event("interruption_forwarded") await _handle_interruption(session) @@ -272,16 +272,6 @@ async def _process_model_events(session: BidirectionalConnection) -> None: await session.agent._output_queue.put(strands_event) continue - # Check for text-based interruption - if strands_event.get("textOutput"): - text_content = strands_event["textOutput"].get("content", "") - if '{ "interrupted" : true }' in text_content: - log_event("text_interruption_detected") - await _handle_interruption(session) - # Still forward the text event - await session.agent._output_queue.put(strands_event) - continue - # Queue tool requests for concurrent execution if strands_event.get("toolUse"): log_event("tool_queued", name=strands_event["toolUse"].get("name")) @@ -308,8 +298,8 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Includes proper task cleanup and cancellation - handling for interruptions. + processing or user interaction. Uses proper asyncio cancellation for + interruption handling rather than manual state checks. Args: session: BidirectionalConnection containing tool queue. @@ -320,9 +310,6 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) - if not session.active: - break - task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task @@ -372,8 +359,9 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. - Executes tools using the existing Strands tool system, handles interruption - during execution, and sends results back to the model provider. + Executes tools using the existing Strands tool system with proper asyncio + cancellation handling. Tool execution is stopped via task cancellation, + not manual state checks. Args: session: BidirectionalConnection for context. 
@@ -383,11 +371,6 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_id = tool_use.get("toolUseId") try: - # Skip execution if session is interrupted or inactive - if session.interrupted or not session.active: - log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) - return - # Create message structure for existing tool system tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} @@ -407,11 +390,6 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: - # Return early if session was interrupted during execution - if session.interrupted or not session.active: - log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) - return - tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) if tool_func: @@ -419,39 +397,18 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: actual_func = _extract_callable_function(tool_func) # Execute tool function with provided input - # For async tools, we could wrap with asyncio.wait_for with cancellation - # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - # Discard result if session was interrupted during execution - if session.interrupted or not session.active: - log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) - return - tool_result = _create_success_result(tool_use["toolUseId"], result) tool_results.append(tool_result) - except asyncio.CancelledError: - # Tool was cancelled due to interruption - log_event("tool_execution_cancelled", name=tool_name, id=tool_id) - return except Exception as e: - # Discard error result if session was interrupted - if session.interrupted or not session.active: - log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) - return - log_event("tool_execution_failed", name=tool_name, error=str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: log_event("tool_not_found", name=tool_name) - # Skip sending results if session was interrupted - if session.interrupted or not session.active: - log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) - return - # Send results through provider-specific session for result in tool_results: await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) @@ -464,13 +421,11 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: raise # Re-raise to properly handle cancellation except Exception as e: log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - - # Only send error if not interrupted - if not session.interrupted and session.active: - try: - await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) - except Exception as send_error: - log_event("tool_error_send_failed", error=str(send_error)) + + try: + await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) + except Exception as send_error: + log_event("tool_error_send_failed", error=str(send_error)) def _extract_callable_function(tool_func: any) -> any: From a58626107b21dad40a52bf27320f35e1af9a5df8 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 13:25:51 -0400 Subject: [PATCH 12/15] Unified tool result and tool error methods, Added 
implementation to add user messages to the agent messages --- .../bidirectional_streaming/agent/agent.py | 8 ++++++-- .../event_loop/bidirectional_event_loop.py | 19 ++++++++++++------- .../models/bidirectional_model.py | 6 +----- .../models/novasonic.py | 6 ------ 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index e27885c7e..46bc38ef2 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -87,7 +87,8 @@ async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during - an active conversation session. + an active conversation session. User input is automatically added to + conversation history for complete message tracking. Args: input_data: Either a string for text input or AudioInputEvent for audio input. @@ -98,10 +99,13 @@ async def send(self, input_data: str | AudioInputEvent) -> None: self._validate_active_session() if isinstance(input_data, str): + # Add user text message to history + self.messages.append({"role": "user", "content": input_data}) + log_event("text_sent", length=len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input (AudioInputEvent) + # Handle audio input await self._session.model_session.send_audio_content(input_data) else: raise ValueError( diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index cc4f416b7..684c0037e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -261,7 +261,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Basic validation - skip invalid events if not isinstance(provider_event, dict): continue - + strands_event = provider_event # Handle interruption detection (provider converts raw patterns to interruptionDetected) @@ -287,6 +287,14 @@ async def _process_model_events(session: BidirectionalConnection) -> None: log_event("message_added_to_history") session.agent.messages.append(strands_event["messageStop"]["message"]) + # Handle user audio transcripts - add to message history + if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": + user_transcript = strands_event["textOutput"]["text"] + if user_transcript.strip(): # Only add non-empty transcripts + user_message = {"role": "user", "content": user_transcript} + session.agent.messages.append(user_message) + log_event("user_transcript_added_to_history") + except Exception as e: log_event("model_events_error", error=str(e)) traceback.print_exc() @@ -298,7 +306,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for + processing or user interaction. Uses proper asyncio cancellation for interruption handling rather than manual state checks. 
Args: @@ -353,9 +361,6 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: log_flow("tool_execution", "processor stopped") - - - async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. @@ -421,9 +426,9 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: raise # Re-raise to properly handle cancellation except Exception as e: log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - + try: - await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) + await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) except Exception as send_error: log_event("tool_error_send_failed", error=str(send_error)) diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 1432b112a..4cd9cc6b8 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -71,14 +71,10 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No """Send tool execution result to the model. Formats and sends tool results according to the provider's specific protocol. + Handles both successful results and error cases. """ raise NotImplementedError - @abc.abstractmethod - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool execution error to model in provider-specific format.""" - raise NotImplementedError - @abc.abstractmethod async def close(self) -> None: """Close the connection and cleanup resources.""" diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index e79229623..dfd911172 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -386,12 +386,6 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No for i, event in enumerate(events): await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool error using Nova Sonic format.""" - log_event("nova_tool_error_send", id=tool_use_id, error=error) - error_result = {"error": error} - await self.send_tool_result(tool_use_id, error_result) - async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: From 16d9b461d187b45ee6d3305268ef23293accd3b0 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:00:25 -0400 Subject: [PATCH 13/15] Modified logging to use python logger --- .../bidirectional_streaming/agent/agent.py | 8 +- .../event_loop/bidirectional_event_loop.py | 89 ++++++++++--------- .../models/novasonic.py | 67 +++++++------- 3 files changed, 83 insertions(+), 81 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 46bc38ef2..68d371a51 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -22,7 +22,7 @@ from 
..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent -from ..utils.debug import log_event, log_flow + logger = logging.getLogger(__name__) @@ -79,9 +79,9 @@ async def start(self) -> None: if self._session and self._session.active: raise ValueError("Conversation already active. Call end() first.") - log_flow("conversation_start", "initializing session") + logger.debug("Conversation start - initializing session") self._session = await start_bidirectional_connection(self) - log_event("conversation_ready") + logger.debug("Conversation ready") async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). @@ -102,7 +102,7 @@ async def send(self, input_data: str | AudioInputEvent) -> None: # Add user text message to history self.messages.append({"role": "user", "content": input_data}) - log_event("text_sent", length=len(input_data)) + logger.debug("Text sent: %d characters", len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 684c0037e..16be08aaf 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,7 +21,7 @@ from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession -from ..utils.debug import log_event, log_flow + logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec Returns: BidirectionalConnection: Active session with background tasks running. 
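With the custom `log_event`/`log_flow` helpers replaced by module-level loggers created via `logging.getLogger(__name__)`, applications opt into this output through standard `logging` configuration. A minimal sketch is below; the handler and format choices are illustrative, and the logger name assumes the package path used in this patch.

```python
# Enable the debug output that previously went through log_event / log_flow.
import logging

logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s %(message)s")
logging.getLogger("strands.experimental.bidirectional_streaming").setLevel(logging.DEBUG)
```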
""" - log_flow("session_start", "initializing model session") + logger.debug("Starting bidirectional session - initializing model session") # Create provider-specific session model_session = await agent.model.create_bidirectional_connection( @@ -85,7 +85,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization - log_flow("background_tasks", "starting processors") + logger.debug("Starting background processors for concurrent processing") session.background_tasks = [ asyncio.create_task(_process_model_events(session)), # Handle model responses asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently @@ -94,7 +94,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start main coordination cycle session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) - log_event("session_ready", tasks=len(session.background_tasks)) + logger.debug("Session ready with %d background tasks", len(session.background_tasks)) return session @@ -107,7 +107,7 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non if not session.active: return - log_flow("session_cleanup", "starting") + logger.debug("Session cleanup starting") session.active = False # Cancel pending tool tasks @@ -134,7 +134,7 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non # Close model session await session.model_session.close() - log_event("session_closed") + logger.debug("Session closed") async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: @@ -150,7 +150,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No try: # Check if background processors are still running if all(task.done() for task in session.background_tasks): - log_event("session_end", reason="all_processors_completed") + logger.debug("Session end - all processors completed") session.active = False break @@ -159,7 +159,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No if task.done() and not task.cancelled(): exception = task.exception() if exception: - log_event("session_error", processor=i, error=str(exception)) + logger.error("Session error in processor %d: %s", i, str(exception)) session.active = False raise exception @@ -169,7 +169,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No except asyncio.CancelledError: break except Exception as e: - log_event("event_loop_error", error=str(e)) + logger.error("Event loop error: %s", str(e)) session.active = False raise @@ -187,10 +187,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async with session.interruption_lock: # If already interrupted, skip duplicate processing if session.interrupted: - log_event("interruption_already_in_progress") + logger.debug("Interruption already in progress") return - log_event("interruption_detected") + logger.debug("Interruption detected") session.interrupted = True # Cancel all pending tool execution tasks @@ -199,10 +199,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: if not task.done(): task.cancel() cancelled_tools += 1 - log_event("tool_task_cancelled", task_id=task_id) + logger.debug("Tool task cancelled: %s", task_id) if cancelled_tools > 0: - 
log_event("tool_tasks_cancelled", count=cancelled_tools) + logger.debug("Tool tasks cancelled: %d", cancelled_tools) # Clear all queued audio output events cleared_count = 0 @@ -233,14 +233,14 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: session.agent._output_queue.put_nowait(event) if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) + logger.debug("Agent audio queue cleared: %d events", audio_cleared) if cleared_count > 0: - log_event("session_audio_queue_cleared", count=cleared_count) + logger.debug("Session audio queue cleared: %d events", cleared_count) # Reset interruption flag after clearing (automatic recovery) session.interrupted = False - log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count) async def _process_model_events(session: BidirectionalConnection) -> None: @@ -252,7 +252,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: Args: session: BidirectionalConnection containing model session. """ - log_flow("model_events", "processor started") + logger.debug("Model events processor started") try: async for provider_event in session.model_session.receive_events(): if not session.active: @@ -261,12 +261,12 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Basic validation - skip invalid events if not isinstance(provider_event, dict): continue - + strands_event = provider_event # Handle interruption detection (provider converts raw patterns to interruptionDetected) if strands_event.get("interruptionDetected"): - log_event("interruption_forwarded") + logger.debug("Interruption forwarded") await _handle_interruption(session) # Forward interruption event to agent for application-level handling await session.agent._output_queue.put(strands_event) @@ -274,7 +274,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Queue tool requests for concurrent execution if strands_event.get("toolUse"): - log_event("tool_queued", name=strands_event["toolUse"].get("name")) + logger.debug("Tool queued: %s", strands_event["toolUse"].get("name")) await session.tool_queue.put(strands_event["toolUse"]) continue @@ -284,39 +284,39 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Update Agent conversation history using existing patterns if strands_event.get("messageStop"): - log_event("message_added_to_history") + logger.debug("Message added to history") session.agent.messages.append(strands_event["messageStop"]["message"]) - + # Handle user audio transcripts - add to message history if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": user_transcript = strands_event["textOutput"]["text"] if user_transcript.strip(): # Only add non-empty transcripts user_message = {"role": "user", "content": user_transcript} session.agent.messages.append(user_message) - log_event("user_transcript_added_to_history") + logger.debug("User transcript added to history") except Exception as e: - log_event("model_events_error", error=str(e)) + logger.error("Model events error: %s", str(e)) traceback.print_exc() finally: - log_flow("model_events", "processor stopped") + logger.debug("Model events processor stopped") async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. 
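The interruption path shown above reduces to two stdlib-asyncio moves: cancel any in-flight tool tasks, then drain queued output so stale audio is never played while non-audio events are preserved. The sketch below is a simplified illustration of that pattern, not the module code; it omits the interruption lock and flag handling.

```python
import asyncio


def handle_interruption(pending_tasks: dict[str, asyncio.Task], output_queue: asyncio.Queue) -> None:
    # Cancel in-flight tool tasks; each one exits via asyncio.CancelledError.
    for task in pending_tasks.values():
        if not task.done():
            task.cancel()

    # Drain the output queue, dropping audioOutput events and re-queuing the rest.
    keep: list[dict] = []
    while True:
        try:
            event = output_queue.get_nowait()
        except asyncio.QueueEmpty:
            break
        if not event.get("audioOutput"):
            keep.append(event)
    for event in keep:
        output_queue.put_nowait(event)
```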
Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for + processing or user interaction. Uses proper asyncio cancellation for interruption handling rather than manual state checks. Args: session: BidirectionalConnection containing tool queue. """ - log_flow("tool_execution", "processor started") + logger.debug("Tool execution processor started") while session.active: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) + logger.debug("Tool execution started: %s (id: %s)", tool_use.get("name"), tool_use.get("toolUseId")) task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) @@ -330,13 +330,13 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: # Log completion status if completed_task.cancelled(): - log_event("tool_task_cleanup_cancelled", task_id=task_id) + logger.debug("Tool task cleanup cancelled: %s", task_id) elif completed_task.exception(): - log_event("tool_task_cleanup_error", task_id=task_id, error=str(completed_task.exception())) + logger.error("Tool task cleanup error: %s - %s", task_id, str(completed_task.exception())) else: - log_event("tool_task_cleanup_success", task_id=task_id) + logger.debug("Tool task cleanup success: %s", task_id) except Exception as e: - log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) + logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) task.add_done_callback(cleanup_task) @@ -350,15 +350,18 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: del session.pending_tool_tasks[task_id] if completed_tasks: - log_event("periodic_task_cleanup", count=len(completed_tasks)) + logger.debug("Periodic task cleanup: %d tasks", len(completed_tasks)) continue except Exception as e: - log_event("tool_execution_error", error=str(e)) + logger.error("Tool execution error: %s", str(e)) if not session.active: break - log_flow("tool_execution", "processor stopped") + logger.debug("Tool execution processor stopped") + + + async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: @@ -390,7 +393,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] if not valid_tool_uses: - log_event("tool_validation_failed", name=tool_name, id=tool_id) + logger.warning("Tool validation failed: %s (id: %s)", tool_name, tool_id) return # Execute tools directly (simpler approach for bidirectional) @@ -408,29 +411,29 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_results.append(tool_result) except Exception as e: - log_event("tool_execution_failed", name=tool_name, error=str(e)) + logger.error("Tool execution failed: %s - %s", tool_name, str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: - log_event("tool_not_found", name=tool_name) + logger.warning("Tool not found: %s", tool_name) # Send results through provider-specific session for result in tool_results: await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) - log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) + logger.debug("Tool execution completed: %s (%d 
results)", tool_name, len(tool_results)) except asyncio.CancelledError: # Task was cancelled due to interruption - this is expected behavior - log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) + logger.debug("Tool task cancelled gracefully: %s (id: %s)", tool_name, tool_id) raise # Re-raise to properly handle cancellation except Exception as e: - log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - + logger.error("Tool execution error: %s - %s", tool_use.get("name"), str(e)) + try: await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) except Exception as send_error: - log_event("tool_error_send_failed", error=str(send_error)) + logger.error("Tool error send failed: %s", str(send_error)) def _extract_callable_function(tool_func: any) -> any: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index dfd911172..7f7937ef1 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -36,7 +36,7 @@ InterruptionDetectedEvent, TextOutputEvent, ) -from ..utils.debug import log_event, log_flow, time_it_async + from .bidirectional_model import BidirectionalModel, BidirectionalModelSession logger = logging.getLogger(__name__) @@ -121,10 +121,10 @@ async def initialize( init_events = self._build_initialization_events(system_prompt, tools or [], messages) - log_flow("nova_init", f"sending {len(init_events)} events") + logger.debug(f"Nova Sonic initialization - sending {len(init_events)} events") await self._send_initialization_events(init_events) - log_event("nova_connection_initialized") + logger.info("Nova Sonic connection initialized successfully") self._response_task = asyncio.create_task(self._process_responses()) except Exception as e: @@ -147,12 +147,12 @@ def _build_initialization_events( async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) + await self._send_nova_event(event) await asyncio.sleep(EVENT_DELAY) async def _process_responses(self) -> None: """Process Nova Sonic responses continuously.""" - log_flow("nova_responses", "processor started") + logger.debug("Nova Sonic response processor started") try: while self._active: @@ -167,14 +167,14 @@ async def _process_responses(self) -> None: await asyncio.sleep(0.1) continue except Exception as e: - log_event("nova_response_error", error=str(e)) + logger.warning(f"Nova Sonic response error: {e}") await asyncio.sleep(0.1) continue except Exception as e: - log_event("nova_fatal_error", error=str(e)) + logger.error(f"Nova Sonic fatal error: {e}") finally: - log_flow("nova_responses", "processor stopped") + logger.debug("Nova Sonic response processor stopped") async def _handle_response_data(self, response_data: str) -> None: """Handle decoded response data from Nova Sonic.""" @@ -190,21 +190,21 @@ async def _handle_response_data(self, response_data: str) -> None: await self._event_queue.put(nova_event) except json.JSONDecodeError as e: - log_event("nova_json_error", error=str(e)) + logger.warning(f"Nova Sonic JSON decode error: {e}") def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" if 
"usageEvent" in nova_event: - log_event("nova_usage", usage=nova_event["usageEvent"]) + logger.debug("Nova usage: %s", nova_event["usageEvent"]) elif "textOutput" in nova_event: - log_event("nova_text_output") + logger.debug("Nova text output") elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - log_event("nova_tool_use", name=tool_use["toolName"], id=tool_use["toolUseId"]) + logger.debug("Nova tool use: %s (id: %s)", tool_use["toolName"], tool_use["toolUseId"]) elif "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) - log_event("nova_audio_output", bytes=len(audio_bytes)) + logger.debug("Nova audio output: %d bytes", len(audio_bytes)) async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" @@ -212,7 +212,7 @@ async def receive_events(self) -> AsyncIterable[dict[str, any]]: logger.error("Stream is None") return - log_flow("nova_events", "starting event stream") + logger.debug("Nova events - starting event stream") # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { @@ -257,7 +257,7 @@ async def start_audio_connection(self) -> None: if self.audio_connection_active: return - log_event("nova_audio_connection_start") + logger.debug("Nova audio connection start") audio_content_start = json.dumps( { @@ -319,7 +319,7 @@ async def _check_silence(self) -> None: if self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: - log_event("nova_silence_detected", elapsed=elapsed) + logger.debug("Nova silence detected: %.2f seconds", elapsed) await self.end_audio_input() except asyncio.CancelledError: pass @@ -329,7 +329,7 @@ async def end_audio_input(self) -> None: if not self.audio_connection_active: return - log_event("nova_audio_connection_end") + logger.debug("Nova audio connection end") audio_content_end = json.dumps( {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} @@ -375,7 +375,7 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No if not self._active: return - log_event("nova_tool_result_send", id=tool_use_id) + logger.debug("Nova tool result send: %s", tool_use_id) content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), @@ -384,14 +384,16 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No ] for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) + await self._send_nova_event(event) + + async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - log_flow("nova_cleanup", "starting connection close") + logger.debug("Nova cleanup - starting connection close") self._active = False # Cancel response processing task if running @@ -423,9 +425,9 @@ async def close(self) -> None: logger.warning("Error closing Nova Sonic stream: %s", e) except Exception as e: - log_event("nova_cleanup_error", error=str(e)) + logger.error("Nova cleanup error: %s", str(e)) finally: - log_event("nova_connection_closed") + logger.debug("Nova connection closed") def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: """Convert Nova Sonic events to provider-agnostic format.""" @@ -452,7 
+454,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Check for Nova Sonic interruption pattern (matches working sample) if '{ "interrupted" : true }' in text_content: - log_event("nova_interruption_in_text") + logger.debug("Nova interruption detected in text") interruption: InterruptionDetectedEvent = {"reason": "user_input"} return {"interruptionDetected": interruption} @@ -480,7 +482,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": - log_event("nova_interruption_stop_reason") + logger.debug("Nova interruption stop reason") interruption: InterruptionDetectedEvent = {"reason": "user_input"} @@ -664,29 +666,26 @@ async def create_bidirectional_connection( **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" - log_flow("nova_connection_create", "starting") + logger.debug("Nova connection create - starting") # Initialize client if needed if not self._client: - await time_it_async("initialize_client", lambda: self._initialize_client()) + await self._initialize_client() # Start Nova Sonic bidirectional stream try: - stream = await time_it_async( - "invoke_model_with_bidirectional_stream", - lambda: self._client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ), + stream = await self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) ) # Create and initialize connection connection = NovaSonicSession(stream, self.config) - await time_it_async("initialize_connection", lambda: connection.initialize(system_prompt, tools, messages)) + await connection.initialize(system_prompt, tools, messages) - log_event("nova_connection_created") + logger.debug("Nova connection created") return connection except Exception as e: - log_event("nova_connection_create_error", error=str(e)) + logger.error("Nova connection create error: %s", str(e)) logger.error("Failed to create Nova Sonic connection: %s", e) raise From 04265baa9267865fe9686dbe89440c552e77f2da Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:02:24 -0400 Subject: [PATCH 14/15] Removed logging utility --- .../bidirectional_streaming/utils/__init__.py | 5 -- .../bidirectional_streaming/utils/debug.py | 48 ------------------- 2 files changed, 53 deletions(-) delete mode 100644 src/strands/experimental/bidirectional_streaming/utils/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/utils/debug.py diff --git a/src/strands/experimental/bidirectional_streaming/utils/__init__.py b/src/strands/experimental/bidirectional_streaming/utils/__init__.py deleted file mode 100644 index 579478436..000000000 --- a/src/strands/experimental/bidirectional_streaming/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Utility functions for bidirectional streaming.""" - -from .debug import log_event, log_flow, time_it_async - -__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py deleted file mode 100644 index 6a7fc3982..000000000 --- a/src/strands/experimental/bidirectional_streaming/utils/debug.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Debug utilities for Strands bidirectional streaming. 
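Commit 14 deletes the debug module, including `time_it_async`, whose call sites were already removed in the logging migration above. If coarse timing is still wanted at a call site, a stdlib-only stand-in such as the sketch below can be used; it is illustrative and not part of this patch.

```python
import logging
import time
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")
logger = logging.getLogger(__name__)


async def timed(label: str, thunk: Callable[[], Awaitable[T]]) -> T:
    # Stdlib-only stand-in for the removed time_it_async helper.
    start = time.perf_counter()
    try:
        return await thunk()
    finally:
        logger.debug("%s took %.4f s", label, time.perf_counter() - start)
```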
- -Provides consistent debug logging across all bidirectional streaming components -with configurable output control matching the Nova Sonic tool use example. -""" - -import datetime -import inspect -import time - -# Debug logging system matching successful tool use example -DEBUG = False # Disable debug logging for clean output like tool use example - - -def debug_print(message): - """Print debug message with timestamp and function name.""" - if DEBUG: - function_name = inspect.stack()[1].function - if function_name == "time_it_async": - function_name = inspect.stack()[2].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - print(f"{timestamp} {function_name} {message}") - - -def log_event(event_type, **context): - """Log important events with structured context.""" - if DEBUG: - function_name = inspect.stack()[1].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" - print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") - - -def log_flow(step, details=""): - """Log important flow steps without excessive detail.""" - if DEBUG: - function_name = inspect.stack()[1].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - print(f"{timestamp} {function_name} FLOW: {step} {details}") - - -async def time_it_async(label, method_to_run): - """Time asynchronous method execution.""" - start_time = time.perf_counter() - result = await method_to_run() - end_time = time.perf_counter() - debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") - return result From 8a7396cf0715409b7fb35deb2c51b1164541a307 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:36:36 -0400 Subject: [PATCH 15/15] Updated types --- .../experimental/bidirectional_streaming/__init__.py | 3 --- .../bidirectional_streaming/models/bidirectional_model.py | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index f6a3b41bf..52822711a 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,5 +1,2 @@ """Bidirectional streaming package for real-time audio/text conversations.""" -from .utils import log_event, log_flow, time_it_async - -__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 4cd9cc6b8..d5c3c9b65 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -17,7 +17,7 @@ from ....types.content import Messages from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class BidirectionalModelSession(abc.ABC): """ @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[dict[str, any]]: + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive events from the model in standardized format. 
         Converts provider-specific events to a common format that can be
@@ -71,7 +71,7 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No
         """Send tool execution result to the model.

         Formats and sends tool results according to the provider's specific protocol.
-        Handles both successful results and error cases.
+        Handles both successful results and error cases through the result dictionary.
         """
         raise NotImplementedError
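Taken together, the series leaves providers with a single result channel: `receive_events()` yields typed `BidirectionalStreamEvent` dicts, and both successful tool results and errors flow through `send_tool_result()`. The sketch below shows one way a caller might drive a session under those assumptions; `session` is any `BidirectionalModelSession` implementation and `run_tool` is a caller-supplied stand-in for the tool executor, both hypothetical wiring rather than code from this patch.

```python
from typing import Any, Awaitable, Callable


async def pump_events(
    session: "BidirectionalModelSession",
    run_tool: Callable[[dict], Awaitable[dict[str, Any]]],
) -> None:
    """Drive a provider session: forward tool results, stop on interruption."""
    async for event in session.receive_events():  # AsyncIterable[BidirectionalStreamEvent]
        if event.get("toolUse"):
            tool_use = event["toolUse"]
            try:
                result = await run_tool(tool_use)
                await session.send_tool_result(tool_use["toolUseId"], result)
            except Exception as exc:
                # Error case now travels through the same call as a result dict.
                await session.send_tool_result(tool_use["toolUseId"], {"error": str(exc)})
        elif event.get("interruptionDetected"):
            break
```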