Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/strands/experimental/bidirectional_streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""Bidirectional streaming package."""

# Main components - Primary user interface
from .agent.agent import BidirectionalAgent
from .agent.agent import BidiAgent

# IO channels - Hardware abstraction
from .io.audio import AudioIO

# Model interface (for custom implementations)
from .models.bidirectional_model import BidirectionalModel
from .models.bidirectional_model import BidiModel

# Model providers - What users need to create models
from .models.gemini_live import GeminiLiveModel
from .models.novasonic import NovaSonicModel
from .models.openai import OpenAIRealtimeModel
from .models.gemini_live import BidiGeminiLiveModel
from .models.novasonic import BidiNovaSonicModel
from .models.openai import BidiOpenAIRealtimeModel

# Event types - For type hints and event handling
from .types.bidirectional_streaming import (
Expand All @@ -29,13 +29,13 @@

__all__ = [
# Main interface
"BidirectionalAgent",
"BidiAgent",
# IO channels
"AudioIO",
# Model providers
"GeminiLiveModel",
"NovaSonicModel",
"OpenAIRealtimeModel",
"BidiGeminiLiveModel",
"BidiNovaSonicModel",
"BidiOpenAIRealtimeModel",

# Event types
"AudioInputEvent",
Expand All @@ -48,5 +48,5 @@
"VoiceActivityEvent",
"UsageMetricsEvent",
# Model interface
"BidirectionalModel",
"BidiModel",
]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Bidirectional agent for real-time streaming conversations."""

from .agent import BidirectionalAgent
from .agent import BidiAgent

__all__ = ["BidirectionalAgent"]
__all__ = ["BidiAgent"]
34 changes: 17 additions & 17 deletions src/strands/experimental/bidirectional_streaming/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from ....types.content import Message, Messages
from ....types.tools import ToolResult, ToolUse, AgentTool

from ..event_loop.bidirectional_event_loop import BidirectionalAgentLoop
from ..models.bidirectional_model import BidirectionalModel
from ..models.novasonic import NovaSonicModel
from ..event_loop.bidirectional_event_loop import BidiAgentLoop
from ..models.bidirectional_model import BidiModel
from ..models.novasonic import BidiNovaSonicModel
from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent
from ..types import BidiIO
from ....experimental.tools import ToolProvider
Expand All @@ -42,7 +42,7 @@
BidirectionalInput = str | AudioInputEvent | ImageInputEvent


class BidirectionalAgent:
class BidiAgent:
"""Agent for bidirectional streaming conversations.

Enables real-time audio and text interaction with AI models through persistent
Expand All @@ -51,7 +51,7 @@ class BidirectionalAgent:

def __init__(
self,
model: BidirectionalModel| str | None = None,
model: BidiModel| str | None = None,
tools: list[str| AgentTool| ToolProvider]| None = None,
system_prompt: str | None = None,
messages: Messages | None = None,
Expand All @@ -66,7 +66,7 @@ def __init__(
"""Initialize bidirectional agent.

Args:
model: BidirectionalModel instance, string model_id, or None for default detection.
model: BidiModel instance, string model_id, or None for default detection.
tools: Optional list of tools with flexible format support.
system_prompt: Optional system prompt for conversations.
messages: Optional conversation history to initialize with.
Expand All @@ -83,9 +83,9 @@ def __init__(
TypeError: If model type is unsupported.
"""
self.model = (
NovaSonicModel()
BidiNovaSonicModel()
if not model
else NovaSonicModel(model_id=model)
else BidiNovaSonicModel(model_id=model)
if isinstance(model, str)
else model
)
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
self._tool_caller = _ToolCaller(self)

# connection management
self._agent_loop: "BidirectionalAgentLoop" | None = None
self._agent_loop: "BidiAgentLoop" | None = None
self._output_queue = asyncio.Queue()
self._current_adapters = [] # Track adapters for cleanup

Expand All @@ -134,7 +134,7 @@ def tool(self) -> _ToolCaller:

Example:
```
agent = BidirectionalAgent(model=model, tools=[calculator])
agent = BidiAgent(model=model, tools=[calculator])
agent.tool.calculator(expression="2+2")
```
"""
Expand Down Expand Up @@ -252,11 +252,11 @@ async def start(self) -> None:
logger.debug("Conversation start - initializing connection")

# Create model session and event loop directly
await self.model.connect(
await self.model.start(
system_prompt=self.system_prompt, tools=self.tool_registry.get_all_tool_specs(), messages=self.messages
)

self._agent_loop = BidirectionalAgentLoop(model=self.model, agent=self)
self._agent_loop = BidiAgentLoop(model=self.model, agent=self)
await self._agent_loop.start()

logger.debug("Conversation ready")
Expand Down Expand Up @@ -306,7 +306,7 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]:
except asyncio.TimeoutError:
continue

async def end(self) -> None:
async def stop(self) -> None:
"""End the conversation connection and cleanup all resources.

Terminates the streaming connection, cancels background tasks, and
Expand All @@ -316,7 +316,7 @@ async def end(self) -> None:
await self._agent_loop.stop()
self._agent_loop = None

async def __aenter__(self) -> "BidirectionalAgent":
async def __aenter__(self) -> "BidiAgent":
"""Async context manager entry point.

Automatically starts the bidirectional connection when entering the context.
Expand Down Expand Up @@ -350,7 +350,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
for adapter in self._current_adapters:
if hasattr(adapter, "cleanup"):
try:
adapter.end()
adapter.stop()
logger.debug(f"Cleaned up adapter: {type(adapter).__name__}")
except Exception as adapter_error:
logger.warning(f"Error cleaning up adapter: {adapter_error}")
Expand All @@ -359,7 +359,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
self._current_adapters = []

# Cleanup agent connection
await self.end()
await self.stop()

except Exception as cleanup_error:
if exc_type is None:
Expand Down Expand Up @@ -393,7 +393,7 @@ async def run(self, io_channels: list[BidiIO | tuple[Callable, Callable]]) -> No
```python
# With IO channel
audio_io = AudioIO(audio_config={"input_sample_rate": 16000})
agent = BidirectionalAgent(model=model, tools=[calculator])
agent = BidiAgent(model=model, tools=[calculator])
await agent.run(io_channels=[audio_io])

# With tuple (backward compatibility)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ....types._events import ToolResultEvent, ToolStreamEvent
from ....types.content import Message
from ....types.tools import ToolResult, ToolUse
from ..models.bidirectional_model import BidirectionalModel
from ..models.bidirectional_model import BidiModel

logger = logging.getLogger(__name__)

Expand All @@ -37,12 +37,12 @@ class BidirectionalConnection:
handling while providing a simple interface for agent interactions.
"""

def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> None:
def __init__(self, model: BidiModel, agent: "BidiAgent") -> None:
"""Initialize connection with model and agent reference.

Args:
model: Bidirectional model instance.
agent: BidirectionalAgent instance for tool registry access.
agent: BidiAgent instance for tool registry access.
"""
self.model = model
self.agent = agent
Expand All @@ -64,22 +64,22 @@ def __init__(self, model: BidirectionalModel, agent: "BidirectionalAgent") -> No
self.tool_count = 0


async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection:
async def start_bidirectional_connection(agent: "BidiAgent") -> BidirectionalConnection:
"""Initialize bidirectional session with concurrent background tasks.

Creates a model-specific session and starts background tasks for processing
model events, executing tools, and managing the session lifecycle.

Args:
agent: BidirectionalAgent instance.
agent: BidiAgent instance.

Returns:
BidirectionalConnection: Active session with background tasks running.
"""
logger.debug("Starting bidirectional session - initializing model connection")

# Connect to model
await agent.model.connect(
await agent.model.start(
system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages
)

Expand Down Expand Up @@ -136,7 +136,7 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non
await asyncio.gather(*all_tasks, return_exceptions=True)

# Close model connection
await session.model.close()
await session.model.stop()
logger.debug("Connection closed")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def receive(self, event: dict) -> None:
elif role.upper() == "USER":
print(f"User: {text}")

def end(self) -> None:
def stop(self) -> None:
"""Clean up IO channel resources."""
try:
if self.input_stream:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Bidirectional model interfaces and implementations."""

from .bidirectional_model import BidirectionalModel
from .gemini_live import GeminiLiveModel
from .novasonic import NovaSonicModel
from .openai import OpenAIRealtimeModel
from .bidirectional_model import BidiModel
from .gemini_live import BidiGeminiLiveModel
from .novasonic import BidiNovaSonicModel
from .openai import BidiOpenAIRealtimeModel

__all__ = [
"BidirectionalModel",
"GeminiLiveModel",
"NovaSonicModel",
"OpenAIRealtimeModel",
"BidiModel",
"BidiGeminiLiveModel",
"BidiNovaSonicModel",
"BidiOpenAIRealtimeModel",
]
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@
logger = logging.getLogger(__name__)


class BidirectionalModel(Protocol):
class BidiModel(Protocol):
"""Protocol for bidirectional streaming models.

This interface defines the contract for models that support persistent streaming
connections with real-time audio and text communication. Implementations handle
provider-specific protocols while exposing a standardized event-based API.
"""

async def connect(
async def start(
self,
system_prompt: str | None = None,
tools: list[ToolSpec] | None = None,
Expand All @@ -56,12 +56,12 @@ async def connect(
"""
...

async def close(self) -> None:
async def stop(self) -> None:
"""Close the streaming connection and release resources.

Terminates the active bidirectional connection and cleans up any associated
resources such as network connections, buffers, or background tasks. After
calling close(), the model instance cannot be used until connect() is called again.
calling stop(), the model instance cannot be used until start() is called again.
"""
...

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Gemini Live API bidirectional model provider using official Google GenAI SDK.

Implements the BidirectionalModel interface for Google's Gemini Live API using the
Implements the BidiModel interface for Google's Gemini Live API using the
official Google GenAI SDK for simplified and robust WebSocket communication.

Key improvements over custom WebSocket implementation:
Expand Down Expand Up @@ -34,7 +34,7 @@
TextOutputEvent,
TranscriptEvent,
)
from .bidirectional_model import BidirectionalModel
from .bidirectional_model import BidiModel

logger = logging.getLogger(__name__)

Expand All @@ -44,7 +44,7 @@
GEMINI_CHANNELS = 1


class GeminiLiveModel(BidirectionalModel):
class BidiGeminiLiveModel(BidiModel):
"""Gemini Live API implementation using official Google GenAI SDK.

Combines model configuration and connection state in a single class.
Expand Down Expand Up @@ -82,13 +82,13 @@ def __init__(

self.client = genai.Client(**client_kwargs)

# Connection state (initialized in connect())
# Connection state (initialized in start())
self.live_session = None
self.live_session_context_manager = None
self.session_id = None
self._active = False

async def connect(
async def start(
self,
system_prompt: Optional[str] = None,
tools: Optional[List[ToolSpec]] = None,
Expand Down Expand Up @@ -404,7 +404,7 @@ async def _send_tool_result(self, tool_result: ToolResult) -> None:
except Exception as e:
logger.error("Error sending tool result: %s", e)

async def close(self) -> None:
async def stop(self) -> None:
"""Close Gemini Live API connection."""
if not self._active:
return
Expand Down Expand Up @@ -435,7 +435,7 @@ def _build_live_config(
if self.live_config:
config_dict.update(self.live_config)

# Override with any kwargs from connect()
# Override with any kwargs from start()
config_dict.update(kwargs)

# Add system instruction if provided
Expand Down
Loading
Loading