diff --git a/pyproject.toml b/pyproject.toml index 3c2243299..dd01ebde3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,13 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +bidirectional-streaming = [ + "pyaudio>=0.2.13", + "rx>=3.2.0", + "smithy-aws-core>=0.0.1", + "pytz", + "aws_sdk_bedrock_runtime", +] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -68,7 +75,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py new file mode 100644 index 000000000..52822711a --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming package for real-time audio/text conversations.""" + diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py new file mode 100644 index 000000000..c490e001d --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -0,0 +1,5 @@ +"""Bidirectional agent for real-time streaming conversations.""" + +from .agent import BidirectionalAgent + +__all__ = ["BidirectionalAgent"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py new file mode 100644 index 000000000..68d371a51 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -0,0 +1,161 @@ 
+"""Bidirectional Agent for real-time streaming conversations. + +Provides real-time audio and text interaction through persistent streaming sessions. +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive +continuous responses including audio output. + +Key capabilities: +- Persistent conversation sessions with concurrent processing +- Real-time audio input/output streaming +- Mid-conversation interruption and tool execution +- Event-driven communication with model providers +""" + +import asyncio +import logging +from typing import AsyncIterable + +from ....tools.executors import ConcurrentToolExecutor +from ....tools.registry import ToolRegistry +from ....types.content import Messages +from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection +from ..models.bidirectional_model import BidirectionalModel +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent + + +logger = logging.getLogger(__name__) + + +class BidirectionalAgent: + """Agent for bidirectional streaming conversations. + + Enables real-time audio and text interaction with AI models through persistent + sessions. Supports concurrent tool execution and interruption handling. + """ + + def __init__( + self, + model: BidirectionalModel, + tools: list | None = None, + system_prompt: str | None = None, + messages: Messages | None = None, + ): + """Initialize bidirectional agent with required model and optional configuration. + + Args: + model: BidirectionalModel instance supporting streaming sessions. + tools: Optional list of tools available to the model. + system_prompt: Optional system prompt for conversations. + messages: Optional conversation history to initialize with. 
+ """ + self.model = model + self.system_prompt = system_prompt + self.messages = messages or [] + + # Initialize tool registry using existing Strands infrastructure + self.tool_registry = ToolRegistry() + if tools: + self.tool_registry.process_tools(tools) + self.tool_registry.initialize_tools() + + # Initialize tool executor for concurrent execution + self.tool_executor = ConcurrentToolExecutor() + + # Session management + self._session = None + self._output_queue = asyncio.Queue() + + async def start(self) -> None: + """Start a persistent bidirectional conversation session. + + Initializes the streaming session and starts background tasks for processing + model events, tool execution, and session management. + + Raises: + ValueError: If conversation already active. + ConnectionError: If session creation fails. + """ + if self._session and self._session.active: + raise ValueError("Conversation already active. Call end() first.") + + logger.debug("Conversation start - initializing session") + self._session = await start_bidirectional_connection(self) + logger.debug("Conversation ready") + + async def send(self, input_data: str | AudioInputEvent) -> None: + """Send input to the model (text or audio). + + Unified method for sending both text and audio input to the model during + an active conversation session. User input is automatically added to + conversation history for complete message tracking. + + Args: + input_data: Either a string for text input or AudioInputEvent for audio input. + + Raises: + ValueError: If no active session or invalid input type. 
+ """ + self._validate_active_session() + + if isinstance(input_data, str): + # Add user text message to history + self.messages.append({"role": "user", "content": input_data}) + + logger.debug("Text sent: %d characters", len(input_data)) + await self._session.model_session.send_text_content(input_data) + elif isinstance(input_data, dict) and "audioData" in input_data: + # Handle audio input + await self._session.model_session.send_audio_content(input_data) + else: + raise ValueError( + "Input must be either a string (text) or AudioInputEvent " + "(dict with audioData, format, sampleRate, channels)" + ) + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive events from the model including audio, text, and tool calls. + + Yields model output events processed by background tasks including audio output, + text responses, tool calls, and session updates. + + Yields: + BidirectionalStreamEvent: Events from the model session. + """ + while self._session and self._session.active: + try: + event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) + yield event + except asyncio.TimeoutError: + continue + + async def interrupt(self) -> None: + """Interrupt the current model generation and clear audio buffers. + + Sends interruption signal to stop generation immediately and clears + pending audio output for responsive conversation flow. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + await self._session.model_session.send_interrupt() + + async def end(self) -> None: + """End the conversation session and cleanup all resources. + + Terminates the streaming session, cancels background tasks, and + closes the connection to the model provider. + """ + if self._session: + await stop_bidirectional_connection(self._session) + self._session = None + + def _validate_active_session(self) -> None: + """Validate that an active session exists. + + Raises: + ValueError: If no active session. 
+ """ + if not self._session or not self._session.active: + raise ValueError("No active conversation. Call start() first.") diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py new file mode 100644 index 000000000..af8c4e1e1 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -0,0 +1,15 @@ +"""Event loop management for bidirectional streaming.""" + +from .bidirectional_event_loop import ( + BidirectionalConnection, + bidirectional_event_loop_cycle, + start_bidirectional_connection, + stop_bidirectional_connection, +) + +__all__ = [ + "BidirectionalConnection", + "start_bidirectional_connection", + "stop_bidirectional_connection", + "bidirectional_event_loop_cycle", +] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py new file mode 100644 index 000000000..16be08aaf --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -0,0 +1,458 @@ +"""Bidirectional session management for concurrent streaming conversations. + +Manages bidirectional communication sessions with concurrent processing of model events, +tool execution, and audio processing. Provides coordination between background tasks +while maintaining a simple interface for agent interaction. 
logger = logging.getLogger(__name__)

# Session constants
TOOL_QUEUE_TIMEOUT = 0.5  # seconds to wait on the tool queue before re-checking session state
SUPERVISION_INTERVAL = 0.1  # seconds between supervision passes of the main cycle


class BidirectionalConnection:
    """Session wrapper for bidirectional communication with concurrent task management.

    Holds the provider session, the owning agent, the background tasks started
    for the session, and the queues/flags those tasks coordinate through.
    """

    def __init__(self, model_session: "BidirectionalModelSession", agent: "BidirectionalAgent") -> None:
        """Initialize session with model session and agent reference.

        Args:
            model_session: Provider-specific bidirectional model session.
            agent: BidirectionalAgent instance for tool registry access.
        """
        self.model_session = model_session
        self.agent = agent
        self.active = True
        # Set once teardown has run. Kept separate from `active`, which the
        # supervisor also clears when the processors stop on their own.
        self.closed = False

        # Background processing coordination
        self.background_tasks: list[asyncio.Task] = []
        self.main_cycle_task: asyncio.Task | None = None
        self.tool_queue = asyncio.Queue()
        self.audio_output_queue = asyncio.Queue()

        # Task management for cleanup
        self.pending_tool_tasks: dict[str, asyncio.Task] = {}

        # Interruption handling (model-agnostic)
        self.interrupted = False
        self.interruption_lock = asyncio.Lock()


async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection:
    """Initialize bidirectional session with concurrent background tasks.

    Creates a model-specific session, then starts the model-event processor,
    the tool-execution processor and the supervising main cycle.

    Args:
        agent: BidirectionalAgent instance.

    Returns:
        BidirectionalConnection: Active session with background tasks running.
    """
    logger.debug("Starting bidirectional session - initializing model session")

    # Create provider-specific session
    model_session = await agent.model.create_bidirectional_connection(
        system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages
    )

    session = BidirectionalConnection(model_session=model_session, agent=agent)

    # Start the processors right after session creation so provider responses
    # are consumed from the very beginning of the session.
    logger.debug("Starting background processors for concurrent processing")
    session.background_tasks = [
        asyncio.create_task(_process_model_events(session)),  # Handle model responses
        asyncio.create_task(_process_tool_execution(session)),  # Execute tools concurrently
    ]

    # Start main coordination cycle
    session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session))

    logger.debug("Session ready with %d background tasks", len(session.background_tasks))
    return session


async def stop_bidirectional_connection(session: BidirectionalConnection) -> None:
    """End session and cleanup resources including background tasks.

    Teardown runs once per session. It also runs when the supervisor has
    already set ``session.active`` to False, because the tasks and the model
    session still need releasing in that case.

    Args:
        session: BidirectionalConnection to cleanup.
    """
    if getattr(session, "closed", False):
        return

    logger.debug("Session cleanup starting")
    session.closed = True
    session.active = False

    tasks = [*session.pending_tool_tasks.values(), *session.background_tasks]
    main_cycle_task = getattr(session, "main_cycle_task", None)
    if main_cycle_task is not None:
        tasks.append(main_cycle_task)

    for task in tasks:
        if not task.done():
            task.cancel()

    try:
        # return_exceptions=True: a failed or cancelled task must not abort teardown.
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
    finally:
        # Always release the provider connection, even if waiting was interrupted.
        await session.model_session.close()
        logger.debug("Session closed")


async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None:
    """Main event loop coordinator that runs continuously during the session.

    Every ``SUPERVISION_INTERVAL`` seconds it checks the background processors.
    A processor that finished with an exception deactivates the session and
    re-raises that exception; if all processors finished cleanly the session
    is deactivated and the cycle ends.

    Args:
        session: BidirectionalConnection to coordinate.
    """
    while session.active:
        try:
            # Look for failures first so a crashed processor is never reported
            # as a normal "all processors completed" shutdown.
            for index, task in enumerate(session.background_tasks):
                if task.done() and not task.cancelled():
                    failure = task.exception()
                    if failure:
                        logger.error("Session error in processor %d: %s", index, str(failure))
                        session.active = False
                        raise failure

            if all(task.done() for task in session.background_tasks):
                logger.debug("Session end - all processors completed")
                session.active = False
                break

            # Brief pause before next supervision check
            await asyncio.sleep(SUPERVISION_INTERVAL)

        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error("Event loop error: %s", str(e))
            session.active = False
            raise


async def _handle_interruption(session: BidirectionalConnection) -> None:
    """Handle interruption detection with task cancellation and audio buffer clearing.

    Cancels pending tool tasks, empties the session audio queue, and removes
    ``audioOutput`` events from the agent's output queue while keeping all other
    events in their original order. Runs under ``session.interruption_lock``.

    Args:
        session: BidirectionalConnection to handle interruption for.
    """
    async with session.interruption_lock:
        if session.interrupted:
            logger.debug("Interruption already in progress")
            return

        logger.debug("Interruption detected")
        session.interrupted = True
        try:
            # Cancel all pending tool execution tasks
            cancelled_tools = 0
            for task_id, task in list(session.pending_tool_tasks.items()):
                if not task.done():
                    task.cancel()
                    cancelled_tools += 1
                    logger.debug("Tool task cancelled: %s", task_id)

            if cancelled_tools > 0:
                logger.debug("Tool tasks cancelled: %d", cancelled_tools)

            # Drop everything queued on the session-level audio queue.
            session_cleared = 0
            while True:
                try:
                    session.audio_output_queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                session_cleared += 1

            # Drain the agent queue, set aside non-audio events, then put them
            # back in order so only audio output is discarded.
            agent_cleared = 0
            kept_events = []
            while True:
                try:
                    event = session.agent._output_queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                if event.get("audioOutput"):
                    agent_cleared += 1
                else:
                    kept_events.append(event)

            for event in kept_events:
                session.agent._output_queue.put_nowait(event)

            if agent_cleared > 0:
                logger.debug("Agent audio queue cleared: %d events", agent_cleared)

            if session_cleared > 0:
                logger.debug("Session audio queue cleared: %d events", session_cleared)

            logger.debug(
                "Interruption handled - tools cancelled: %d, audio cleared: %d",
                cancelled_tools,
                session_cleared + agent_cleared,
            )
        finally:
            # Always reset the flag; left True, every later interruption would
            # be skipped as "already in progress".
            session.interrupted = False


async def _process_model_events(session: BidirectionalConnection) -> None:
    """Process model events and route them to the agent, tool queue and history.

    Background task that consumes ``session.model_session.receive_events()``:
    interruption events trigger ``_handle_interruption`` and are forwarded,
    tool-use events are queued for execution, audio/text output is forwarded to
    the agent's output queue, and completed messages / user transcripts are
    appended to the agent's message history.

    Args:
        session: BidirectionalConnection containing model session.
    """
    logger.debug("Model events processor started")
    try:
        async for event in session.model_session.receive_events():
            if not session.active:
                break

            # Basic validation - skip invalid events
            if not isinstance(event, dict):
                continue

            # Handle interruption detection (provider emits interruptionDetected)
            if event.get("interruptionDetected"):
                logger.debug("Interruption forwarded")
                await _handle_interruption(session)
                # Forward interruption event to agent for application-level handling
                await session.agent._output_queue.put(event)
                continue

            # Queue tool requests for concurrent execution
            if event.get("toolUse"):
                logger.debug("Tool queued: %s", event["toolUse"].get("name"))
                await session.tool_queue.put(event["toolUse"])
                continue

            # Send output events to Agent for receive() method
            if event.get("audioOutput") or event.get("textOutput"):
                await session.agent._output_queue.put(event)

            # Record completed messages in the conversation history
            if event.get("messageStop"):
                logger.debug("Message added to history")
                session.agent.messages.append(event["messageStop"]["message"])

            # Record user audio transcripts in the conversation history
            text_output = event.get("textOutput")
            if text_output and text_output.get("role") == "user":
                transcript = text_output["text"]
                if transcript.strip():  # Only add non-empty transcripts
                    # Content-block list, the same shape as the tool-use
                    # Message built in this module.
                    session.agent.messages.append({"role": "user", "content": [{"text": transcript}]})
                    logger.debug("User transcript added to history")

    except Exception as e:
        # logger.exception sends the traceback through logging rather than stderr.
        logger.exception("Model events error: %s", str(e))
    finally:
        logger.debug("Model events processor stopped")


async def _process_tool_execution(session: BidirectionalConnection) -> None:
    """Execute queued tool requests as independent tasks.

    Background task that takes tool-use requests from ``session.tool_queue`` and
    runs each in its own asyncio task, tracked in ``session.pending_tool_tasks``
    so interruption and teardown can cancel it.

    Args:
        session: BidirectionalConnection containing tool queue.
    """
    logger.debug("Tool execution processor started")
    while session.active:
        try:
            tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT)
            logger.debug("Tool execution started: %s (id: %s)", tool_use.get("name"), tool_use.get("toolUseId"))

            task_id = str(uuid.uuid4())
            task = asyncio.create_task(_execute_tool_with_strands(session, tool_use))
            session.pending_tool_tasks[task_id] = task

            # task_id is bound as a default so each callback keeps its own id.
            def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None:
                try:
                    session.pending_tool_tasks.pop(task_id, None)

                    # cancelled() must be checked first: exception() raises on
                    # a cancelled task.
                    if completed_task.cancelled():
                        logger.debug("Tool task cleanup cancelled: %s", task_id)
                        return

                    failure = completed_task.exception()
                    if failure:
                        logger.error("Tool task cleanup error: %s - %s", task_id, str(failure))
                    else:
                        logger.debug("Tool task cleanup success: %s", task_id)
                except Exception as e:
                    logger.error("Tool task cleanup failed: %s - %s", task_id, str(e))

            task.add_done_callback(cleanup_task)

        except asyncio.TimeoutError:
            if not session.active:
                break
            # Periodic sweep of finished tasks still tracked as pending.
            finished = [tracked_id for tracked_id, tracked in session.pending_tool_tasks.items() if tracked.done()]
            for tracked_id in finished:
                session.pending_tool_tasks.pop(tracked_id, None)

            if finished:
                logger.debug("Periodic task cleanup: %d tasks", len(finished))
        except Exception as e:
            logger.error("Tool execution error: %s", str(e))
            if not session.active:
                break

    logger.debug("Tool execution processor stopped")


async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None:
    """Execute one tool request and send its result(s) back to the model.

    Validates the request with ``validate_and_prepare_tools``, looks the tool up
    in the agent's registry, calls it with the provided input (awaiting the
    result if the tool is asynchronous), and sends every collected result via
    ``session.model_session.send_tool_result``. Unknown or invalid tools get an
    error result instead of silence.

    NOTE(review): synchronous tools are still called inline and block the event
    loop while they run; cancellation only takes effect at an ``await``.

    Args:
        session: BidirectionalConnection for context.
        tool_use: Tool use event to execute.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled.
    """
    import inspect  # local import: only needed to detect awaitable tool results

    tool_name = tool_use.get("name")
    tool_id = tool_use.get("toolUseId")

    try:
        # Create message structure for existing tool system
        tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}

        tool_uses: list[ToolUse] = []
        tool_results: list[ToolResult] = []
        invalid_tool_use_ids: list[str] = []

        # Validate using existing Strands validation
        validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids)

        valid_tool_uses = [candidate for candidate in tool_uses if candidate.get("toolUseId") not in invalid_tool_use_ids]

        if not valid_tool_uses:
            logger.warning("Tool validation failed: %s (id: %s)", tool_name, tool_id)
            # The validator receives tool_results and may already have added an
            # error entry; add one only if it did not, so the model always
            # gets an answer.
            if not tool_results:
                tool_results.append(_create_error_result(tool_id, "Tool validation failed"))

        for valid_use in valid_tool_uses:
            use_id = valid_use["toolUseId"]
            use_name = valid_use["name"]
            tool_func = session.agent.tool_registry.registry.get(use_name)

            if tool_func is None:
                logger.warning("Tool not found: %s", use_name)
                tool_results.append(_create_error_result(use_id, f"Tool not found: {use_name}"))
                continue

            try:
                actual_func = _extract_callable_function(tool_func)

                outcome = actual_func(**valid_use.get("input", {}))
                if inspect.isawaitable(outcome):
                    # Async tools return an awaitable; resolve it before serializing.
                    outcome = await outcome

                tool_results.append(_create_success_result(use_id, outcome))
            except Exception as e:
                logger.error("Tool execution failed: %s - %s", use_name, str(e))
                tool_results.append(_create_error_result(use_id, str(e)))

        # Send each result under its own toolUseId.
        for result in tool_results:
            await session.model_session.send_tool_result(result.get("toolUseId", tool_id), result)

        logger.debug("Tool execution completed: %s (%d results)", tool_name, len(tool_results))

    except asyncio.CancelledError:
        # Task was cancelled due to interruption - this is expected behavior
        logger.debug("Tool task cancelled gracefully: %s (id: %s)", tool_name, tool_id)
        raise  # Re-raise to properly handle cancellation
    except Exception as e:
        logger.error("Tool execution error: %s - %s", tool_name, str(e))

        try:
            await session.model_session.send_tool_result(tool_id, {"error": str(e)})
        except Exception as send_error:
            logger.error("Tool error send failed: %s", str(send_error))


def _extract_callable_function(tool_func: object) -> object:
    """Extract the callable function from different tool object types.

    Checks, in order, a ``_tool_func`` attribute, a ``func`` attribute, and
    finally whether the object itself is callable.

    Raises:
        ValueError: If none of the above yields a callable.
    """
    if hasattr(tool_func, "_tool_func"):
        return tool_func._tool_func
    if hasattr(tool_func, "func"):
        return tool_func.func
    if callable(tool_func):
        return tool_func
    raise ValueError(f"Tool function not callable: {type(tool_func).__name__}")


def _create_success_result(tool_use_id: str, result: object) -> dict[str, object]:
    """Create a successful tool result.

    The result is JSON-encoded; values that cannot be JSON-encoded fall back to
    ``str(result)`` so a successful call is not reported as a failure.
    """
    try:
        text = json.dumps(result)
    except (TypeError, ValueError):
        text = str(result)
    return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": text}]}


def _create_error_result(tool_use_id: str, error: str) -> dict[str, object]:
    """Create an error tool result."""
    return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]}
BidirectionalModelSession +from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession + +__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py new file mode 100644 index 000000000..d5c3c9b65 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -0,0 +1,104 @@ +"""Bidirectional model interface for real-time streaming conversations. + +Defines the interface for models that support bidirectional streaming capabilities. +Provides abstractions for different model providers with connection-based communication +patterns that support real-time audio and text interaction. + +Features: +- connection-based persistent connections +- Real-time bidirectional communication +- Provider-agnostic event normalization +- Tool execution integration +""" + +import abc +import logging +from typing import AsyncIterable + +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent + +logger = logging.getLogger(__name__) + + +class BidirectionalModelSession(abc.ABC): + """Abstract interface for model-specific bidirectional communication connections. + + Defines the contract for managing persistent streaming connections with individual + model providers, handling audio/text input, receiving events, and managing + tool execution results. + """ + + @abc.abstractmethod + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive events from the model in standardized format. + + Converts provider-specific events to a common format that can be + processed uniformly by the event loop. 
+ """ + raise NotImplementedError + + @abc.abstractmethod + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to the model during an active connection. + + Handles audio encoding and provider-specific formatting while presenting + a simple AudioInputEvent interface. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content to the model during ongoing generation. + + Allows natural interruption and follow-up questions without requiring + connection restart. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_interrupt(self) -> None: + """Send interruption signal to stop generation immediately. + + Enables responsive conversational experiences where users can + naturally interrupt during model responses. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: + """Send tool execution result to the model. + + Formats and sends tool results according to the provider's specific protocol. + Handles both successful results and error cases through the result dictionary. + """ + raise NotImplementedError + + @abc.abstractmethod + async def close(self) -> None: + """Close the connection and cleanup resources.""" + raise NotImplementedError + + +class BidirectionalModel(abc.ABC): + """Interface for models that support bidirectional streaming. + + Defines the contract for creating persistent streaming connections that support + real-time audio and text communication with AI models. + """ + + @abc.abstractmethod + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create a bidirectional connection with the model. 
+ + Establishes a persistent connection for real-time communication while + abstracting provider-specific initialization requirements. + """ + raise NotImplementedError diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py new file mode 100644 index 000000000..7f7937ef1 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -0,0 +1,711 @@ +"""Nova Sonic bidirectional model provider for real-time streaming conversations. + +Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the +complex event sequencing and audio processing required by Nova Sonic's +InvokeModelWithBidirectionalStream protocol. + +Nova Sonic specifics: +- Hierarchical event sequences: connectionStart → promptStart → content streaming +- Base64-encoded audio format with hex encoding +- Tool execution with content containers and identifier tracking +- 8-minute connection limits with proper cleanup sequences +- Interruption detection through stopReason events +""" + +import asyncio +import base64 +import json +import logging +import time +import traceback +import uuid +from typing import AsyncIterable + +from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput +from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme +from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + InterruptionDetectedEvent, + TextOutputEvent, +) + +from .bidirectional_model import 
BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# Nova Sonic configuration constants +NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} + +NOVA_AUDIO_INPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64", +} + +NOVA_AUDIO_OUTPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 24000, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": "matthew", + "encoding": "base64", + "audioType": "SPEECH", +} + +NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} +NOVA_TOOL_CONFIG = {"mediaType": "application/json"} + +# Timing constants +SILENCE_THRESHOLD = 2.0 +EVENT_DELAY = 0.1 +RESPONSE_TIMEOUT = 1.0 + + +class NovaSonicSession(BidirectionalModelSession): + """Nova Sonic connection implementation handling the provider's specific protocol. + + Manages Nova Sonic's complex event sequencing, audio format conversion, and + tool execution patterns while providing the standard BidirectionalModelSession + interface. + """ + + def __init__(self, stream: any, config: dict[str, any]) -> None: + """Initialize Nova Sonic connection. + + Args: + stream: Nova Sonic bidirectional stream. + config: Model configuration. 
+ """ + self.stream = stream + self.config = config + self.prompt_name = str(uuid.uuid4()) + self._active = True + + # Nova Sonic requires unique content names + self.audio_content_name = str(uuid.uuid4()) + self.text_content_name = str(uuid.uuid4()) + + # Audio connection state + self.audio_connection_active = False + self.last_audio_time = None + self.silence_threshold = SILENCE_THRESHOLD + self.silence_task = None + + # Validate stream + if not stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + + async def initialize( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + ) -> None: + """Initialize Nova Sonic connection with required protocol sequence.""" + try: + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." + + init_events = self._build_initialization_events(system_prompt, tools or [], messages) + + logger.debug(f"Nova Sonic initialization - sending {len(init_events)} events") + await self._send_initialization_events(init_events) + + logger.info("Nova Sonic connection initialized successfully") + self._response_task = asyncio.create_task(self._process_responses()) + + except Exception as e: + logger.error("Error during Nova Sonic initialization: %s", e) + raise + + def _build_initialization_events( + self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None + ) -> list[str]: + """Build the sequence of initialization events.""" + events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] + + events.extend(self._get_system_prompt_events(system_prompt)) + + # Message history would be processed here if needed in the future + # Currently not implemented as it's not used in the existing test cases + + return events + + async def _send_initialization_events(self, events: list[str]) -> None: + """Send 
async def _handle_response_data(self, response_data: str) -> None:
    """Parse one decoded Nova Sonic payload and queue its event for consumers.

    Payloads that are not valid JSON are logged and dropped. Payloads without
    an ``event`` key are ignored.

    Args:
        response_data: UTF-8 decoded JSON text received from the stream.
    """
    try:
        payload = json.loads(response_data)
    except json.JSONDecodeError as e:
        logger.warning(f"Nova Sonic JSON decode error: {e}")
        return

    if "event" not in payload:
        return

    event = payload["event"]
    self._log_event_type(event)

    # The queue is created lazily because receive_events() may not have run yet.
    if not hasattr(self, "_event_queue"):
        self._event_queue = asyncio.Queue()

    await self._event_queue.put(event)
async def receive_events(self) -> AsyncIterable[dict[str, any]]:
    """Yield Nova Sonic events converted to the provider-agnostic format.

    The stream starts with a ``BidirectionalConnectionStart`` event, then
    relays every queued Nova Sonic event that ``_convert_nova_event`` maps to
    a non-empty result, and finishes with a ``BidirectionalConnectionEnd``
    event once the session is no longer active or an error stops the loop.

    Yields:
        Single-key dicts holding one provider-agnostic event each.
    """
    if not self.stream:
        logger.error("Stream is None")
        return

    logger.debug("Nova events - starting event stream")

    # Emit connection start event to Strands event system
    connection_start: BidirectionalConnectionStartEvent = {
        "connectionId": self.prompt_name,
        "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")},
    }
    yield {"BidirectionalConnectionStart": connection_start}

    # Initialize event queue if not already done
    if not hasattr(self, "_event_queue"):
        self._event_queue = asyncio.Queue()

    end_reason = "connection_complete"
    try:
        while self._active:
            try:
                # Wake up once a second so a cleared _active flag is noticed.
                nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue

            provider_event = self._convert_nova_event(nova_event)
            if provider_event:
                yield provider_event
    except Exception as e:
        logger.error("Error receiving Nova Sonic event: %s", e)
        logger.error(traceback.format_exc())
        end_reason = "error"

    # The end event is yielded after the try block rather than in a `finally`:
    # when the consumer closes the generator early, GeneratorExit is raised at
    # the pending yield, and yielding again while it unwinds raises
    # "RuntimeError: async generator ignored GeneratorExit".
    connection_end: BidirectionalConnectionEndEvent = {
        "connectionId": self.prompt_name,
        "reason": end_reason,
        "metadata": {"provider": "nova_sonic"},
    }
    yield {"BidirectionalConnectionEnd": connection_end}
"audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG, + } + } + } + ) + + await self._send_nova_event(audio_content_start) + self.audio_connection_active = True + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio using Nova Sonic protocol-specific format.""" + if not self._active: + return + + # Start audio connection if not already active + if not self.audio_connection_active: + await self.start_audio_connection() + + # Update last audio time and cancel any pending silence task + self.last_audio_time = time.time() + if self.silence_task and not self.silence_task.done(): + self.silence_task.cancel() + + # Convert audio to Nova Sonic base64 format + nova_audio_data = base64.b64encode(audio_input["audioData"]).decode("utf-8") + + # Send audio input event + audio_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "content": nova_audio_data, + } + } + } + ) + + await self._send_nova_event(audio_event) + + # Start silence detection task + self.silence_task = asyncio.create_task(self._check_silence()) + + async def _check_silence(self) -> None: + """Check for silence and automatically end audio connection.""" + try: + await asyncio.sleep(self.silence_threshold) + if self.audio_connection_active and self.last_audio_time: + elapsed = time.time() - self.last_audio_time + if elapsed >= self.silence_threshold: + logger.debug("Nova silence detected: %.2f seconds", elapsed) + await self.end_audio_input() + except asyncio.CancelledError: + pass + + async def end_audio_input(self) -> None: + """End current audio input connection to trigger Nova Sonic processing.""" + if not self.audio_connection_active: + return + + logger.debug("Nova audio connection end") + + audio_content_end = json.dumps( + {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} + ) + + await self._send_nova_event(audio_content_end) + 
async def send_interrupt(self) -> None:
    """Send an interruption signal to Nova Sonic.

    Does nothing when the session is no longer active.
    """
    if not self._active:
        return

    # _send_nova_event() expects a JSON string and calls .encode() on it, so
    # the payload must be serialized here like every other outgoing event.
    interrupt_event = json.dumps(
        {
            "event": {
                "audioInput": {
                    "promptName": self.prompt_name,
                    "contentName": self.audio_content_name,
                    "stopReason": "INTERRUPTED",
                }
            }
        }
    )
    await self._send_nova_event(interrupt_event)
[self._get_prompt_end_event(), self._get_connection_end_event()] + + for event in cleanup_events: + try: + await self._send_nova_event(event) + except Exception as e: + logger.warning("Error during Nova Sonic cleanup: %s", e) + + # Close stream + try: + await self.stream.input_stream.close() + except Exception as e: + logger.warning("Error closing Nova Sonic stream: %s", e) + + except Exception as e: + logger.error("Nova cleanup error: %s", str(e)) + finally: + logger.debug("Nova connection closed") + + def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: + """Convert Nova Sonic events to provider-agnostic format.""" + # Handle audio output + if "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + + audio_output: AudioOutputEvent = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": "base64", + } + + return {"audioOutput": audio_output} + + # Handle text output + elif "textOutput" in nova_event: + text_content = nova_event["textOutput"]["content"] + # Use stored role from contentStart event, fallback to event role + role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant")) + + # Check for Nova Sonic interruption pattern (matches working sample) + if '{ "interrupted" : true }' in text_content: + logger.debug("Nova interruption detected in text") + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + return {"interruptionDetected": interruption} + + # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag + if role == "USER": + print(f"User: {text_content}") + elif role == "ASSISTANT": + print(f"Assistant: {text_content}") + + text_output: TextOutputEvent = {"text": text_content, "role": role.lower()} + + return {"textOutput": text_output} + + # Handle tool use + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + + 
tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]), + } + + return {"toolUse": tool_use_event} + + # Handle interruption + elif nova_event.get("stopReason") == "INTERRUPTED": + logger.debug("Nova interruption stop reason") + + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + + return {"interruptionDetected": interruption} + + # Handle usage events (ignore) + elif "usageEvent" in nova_event: + return None + + # Handle content start events (track role) + elif "contentStart" in nova_event: + role = nova_event["contentStart"].get("role", "unknown") + # Store role for subsequent text output events + self._current_role = role + return None + + # Handle other events + else: + return None + + # Nova Sonic event template methods + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + + def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + prompt_start_event = { + "event": { + "promptStart": { + "promptName": self.prompt_name, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict]: + """Build tool configuration from tool specs.""" + tool_config = [] + for tool in tools: + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": 
def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str:
    """Serialize a tool result as a Nova Sonic ``toolResult`` event.

    Args:
        content_name: Content block name the result belongs to.
        result: Tool result; embedded in the event as a JSON-encoded string.

    Returns:
        The event as a JSON string.
    """
    tool_result = {
        "promptName": self.prompt_name,
        "contentName": content_name,
        # The result is serialized on its own, so it travels as a string field.
        "content": json.dumps(result),
    }
    envelope = {"event": {"toolResult": tool_result}}
    return json.dumps(envelope)
def _get_connection_end_event(self) -> str:
    """Generate the event that closes the Nova Sonic session.

    The session is opened with ``sessionStart`` (see
    ``_get_connection_start_event``), and the matching closing event in the
    Nova Sonic protocol is ``sessionEnd``, not ``connectionEnd``.

    Returns:
        The event as a JSON string.
    """
    return json.dumps({"event": {"sessionEnd": {}}})
+ """ + self.model_id = model_id + self.region = region + self.config = config + self._client = None + + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) + + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create Nova Sonic bidirectional connection.""" + logger.debug("Nova connection create - starting") + + # Initialize client if needed + if not self._client: + await self._initialize_client() + + # Start Nova Sonic bidirectional stream + try: + stream = await self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + + # Create and initialize connection + connection = NovaSonicSession(stream, self.config) + await connection.initialize(system_prompt, tools, messages) + + logger.debug("Nova connection created") + return connection + except Exception as e: + logger.error("Nova connection create error: %s", str(e)) + logger.error("Failed to create Nova Sonic connection: %s", e) + raise + + async def _initialize_client(self) -> None: + """Initialize Nova Sonic client.""" + try: + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, + ) + + self._client = BedrockRuntimeClient(config=config) + logger.debug("Nova Sonic client initialized") + + except ImportError as e: + logger.error("Nova Sonic dependencies not available: %s", e) + raise + except Exception as e: + logger.error("Error initializing Nova Sonic client: %s", e) + raise diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py 
b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py new file mode 100644 index 000000000..b31607966 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -0,0 +1,198 @@ +"""Test suite for bidirectional streaming with real-time audio interaction. + +Tests the complete bidirectional streaming system including audio input/output, +interruption handling, and concurrent tool execution using Nova Sonic. +""" + +import asyncio +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import time + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + 
async def receive(agent, context):
    """Consume agent events and update the shared playback context.

    Audio chunks are queued on ``context["audio_out"]`` unless an interruption
    is pending. Interruption events set ``context["interrupted"]``, and text
    output is printed with a speaker label.

    Args:
        agent: Object whose ``receive()`` returns an async iterator of events.
        context: Shared dict holding at least ``audio_out`` (an asyncio.Queue).
    """
    try:
        async for event in agent.receive():
            # Handle audio output
            if "audioOutput" in event:
                if not context.get("interrupted", False):
                    context["audio_out"].put_nowait(event["audioOutput"]["audioData"])

            # Handle interruption events
            elif "interruptionDetected" in event or "interrupted" in event:
                context["interrupted"] = True

            # Handle text output with interruption detection
            elif "textOutput" in event:
                text_output = event["textOutput"]
                # TextOutputEvent carries its payload under "text"; "content"
                # is kept as a fallback for older event shapes.
                text_content = text_output.get("text", text_output.get("content", ""))
                role = text_output.get("role", "unknown")

                # Only the exact marker counts as an interruption. Matching the
                # bare word "interrupted" would fire on ordinary speech.
                if '{ "interrupted" : true }' in text_content:
                    context["interrupted"] = True

                # Log text output
                if role.upper() == "USER":
                    print(f"User: {text_content}")
                elif role.upper() == "ASSISTANT":
                    print(f"Assistant: {text_content}")

    except asyncio.CancelledError:
        pass
async def main(duration=180):
    """Run an interactive bidirectional streaming session against Nova Sonic.

    Starts the agent, then runs the play/record/receive/send workers
    concurrently until ``send`` stops the session or the user interrupts.
    The agent is always ended on the way out.

    Args:
        duration: Session length in seconds that ``send`` enforces.
    """
    print("Starting bidirectional streaming test...")
    print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption")

    # Initialize model and agent
    model = NovaSonicBidirectionalModel(region="us-east-1")
    agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.")

    await agent.start()

    # Create shared context for all tasks
    context = {
        "active": True,
        "audio_in": asyncio.Queue(),
        "audio_out": asyncio.Queue(),
        "connection": agent._session,
        "duration": duration,
        "start_time": time.time(),
        "interrupted": False,
    }

    print("Speak into microphone. Press Ctrl+C to exit.")

    worker_names = ("play", "record", "receive", "send")
    try:
        # Run all tasks concurrently
        results = await asyncio.gather(
            play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True
        )
        # return_exceptions=True turns a worker crash into a return value, so
        # report it here instead of letting it vanish.
        for worker, outcome in zip(worker_names, results):
            if isinstance(outcome, Exception):
                print(f"{worker} task failed: {outcome!r}")
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except asyncio.CancelledError:
        print("\nTest cancelled")
    finally:
        print("Cleaning up...")
        context["active"] = False
        await agent.end()
class AudioOutputEvent(TypedDict):
    """Audio output event from the model.

    Provides standardized audio output format across different providers using
    raw bytes instead of provider-specific encodings.

    The TypedDict is total, so all five keys must be present; only the value
    of ``encoding`` may be None.

    Attributes:
        audioData: Raw audio bytes (not base64 or hex encoded).
        format: Audio format from SUPPORTED_AUDIO_FORMATS.
        sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES.
        channels: Channel count from SUPPORTED_CHANNELS.
        encoding: Original provider encoding for debugging purposes.
    """

    audioData: bytes
    # The Literal values below repeat the SUPPORTED_* module constants.
    format: Literal["pcm", "wav", "opus", "mp3"]
    sampleRate: Literal[16000, 24000, 48000]
    channels: Literal[1, 2]
    encoding: Optional[str]
class BidirectionalConnectionEndEvent(TypedDict):
    """Connection end event for bidirectional streaming.

    Attributes:
        reason: Why the connection ended. ``"connection_complete"`` is the value
            the Nova Sonic session reports when its event stream finishes; the
            other values cover explicit shutdown, timeouts and failures.
        connectionId: Unique connection identifier.
        metadata: Provider-specific connection metadata.
    """

    reason: Literal["user_request", "timeout", "error", "connection_complete"]
    connectionId: Optional[str]
    metadata: Optional[Dict[str, Any]]
+ BidirectionalConnectionStart: connection start event. + BidirectionalConnectionEnd: connection end event. + """ + + audioOutput: AudioOutputEvent + audioInput: AudioInputEvent + textOutput: TextOutputEvent + interruptionDetected: InterruptionDetectedEvent + BidirectionalConnectionStart: BidirectionalConnectionStartEvent + BidirectionalConnectionEnd: BidirectionalConnectionEndEvent