diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py
index 3af8d707f..88d7f5a0c 100644
--- a/src/strands/experimental/bidi/models/gemini_live.py
+++ b/src/strands/experimental/bidi/models/gemini_live.py
@@ -25,7 +25,6 @@
 from ....types.content import Messages
 from ....types.tools import ToolResult, ToolSpec, ToolUse
 from .._async import stop_all
-from ..types.model import AudioConfig
 from ..types.events import (
     AudioChannel,
     AudioSampleRate,
@@ -41,6 +40,7 @@
     BidiUsageEvent,
     ModalityUsage,
 )
+from ..types.model import AudioConfig
 from .model import BidiModel, BidiModelTimeoutError
 
 logger = logging.getLogger(__name__)
@@ -70,7 +70,7 @@ def __init__(
 
         Args:
            model_id: Model identifier (default: gemini-2.5-flash-native-audio-preview-09-2025)
-            provider_config: Model behavior (audio, response_modalities, speech_config, transcription)
+            provider_config: Model behavior (audio, inference)
             client_config: Authentication (api_key, http_options)
             **kwargs: Reserved for future parameters.
 
@@ -108,44 +108,28 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:
 
     def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
         """Merge user config with defaults (user takes precedence)."""
-        # Extract voice from provider-specific speech_config.voice_config.prebuilt_voice_config.voice_name if present
-        provider_voice = None
-        if "speech_config" in config and isinstance(config["speech_config"], dict):
-            provider_voice = (
-                config["speech_config"].get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name")
-            )
-
-        # Define default audio configuration
         default_audio: AudioConfig = {
             "input_rate": GEMINI_INPUT_SAMPLE_RATE,
             "output_rate": GEMINI_OUTPUT_SAMPLE_RATE,
             "channels": GEMINI_CHANNELS,
             "format": "pcm",
         }
-
-        if provider_voice:
-            default_audio["voice"] = provider_voice
-
-        user_audio = config.get("audio", {})
-        merged_audio = {**default_audio, **user_audio}
-
-        default_provider_settings = {
+        default_inference = {
             "response_modalities": ["AUDIO"],
             "outputAudioTranscription": {},
             "inputAudioTranscription": {},
         }
 
         resolved = {
-            **default_provider_settings,
-            **config,
-            "audio": merged_audio,  # Audio always uses merged version
+            "audio": {
+                **default_audio,
+                **config.get("audio", {}),
+            },
+            "inference": {
+                **default_inference,
+                **config.get("inference", {}),
+            },
         }
-
-        if user_audio:
-            logger.debug("audio_config | merged user-provided config with defaults")
-        else:
-            logger.debug("audio_config | using default Gemini Live audio configuration")
-
         return resolved
 
     async def start(
@@ -505,9 +489,7 @@ def _build_live_config(
         Simply passes through all config parameters from provider_config, allowing
         users to configure any Gemini Live API parameter directly.
         """
-        config_dict: dict[str, Any] = {}
-        if self.config:
-            config_dict.update({k: v for k, v in self.config.items() if k != "audio"})
+        config_dict: dict[str, Any] = self.config["inference"].copy()
 
         config_dict["session_resumption"] = {"handle": kwargs.get("live_session_handle")}
 
diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py
index 0cfa51181..6a2477e22 100644
--- a/src/strands/experimental/bidi/models/nova_sonic.py
+++ b/src/strands/experimental/bidi/models/nova_sonic.py
@@ -37,7 +37,6 @@
 from ....types.content import Messages
 from ....types.tools import ToolResult, ToolSpec, ToolUse
 from .._async import stop_all
-from ..types.model import AudioConfig
 from ..types.events import (
     AudioChannel,
     AudioSampleRate,
@@ -53,12 +52,16 @@
     BidiTranscriptStreamEvent,
     BidiUsageEvent,
 )
+from ..types.model import AudioConfig
 from .model import BidiModel, BidiModelTimeoutError
 
 logger = logging.getLogger(__name__)
 
-# Nova Sonic configuration constants
-NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7}
+_NOVA_INFERENCE_CONFIG_KEYS = {
+    "max_tokens": "maxTokens",
+    "temperature": "temperature",
+    "top_p": "topP",
+}
 
 NOVA_AUDIO_INPUT_CONFIG = {
     "mediaType": "audio/lpcm",
@@ -156,8 +159,7 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:
 
     def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
         """Merge user config with defaults (user takes precedence)."""
-        # Define default audio configuration
-        default_audio_config: AudioConfig = {
+        default_audio: AudioConfig = {
             "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]),
             "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]),
             "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]),
@@ -165,19 +167,13 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
             "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]),
         }
 
-        user_audio_config = config.get("audio", {})
-        merged_audio = {**default_audio_config, **user_audio_config}
-
         resolved = {
-            "audio": merged_audio,
-            **{k: v for k, v in config.items() if k != "audio"},
+            "audio": {
+                **default_audio,
+                **config.get("audio", {}),
+            },
+            "inference": config.get("inference", {}),
         }
-
-        if user_audio_config:
-            logger.debug("audio_config | merged user-provided config with defaults")
-        else:
-            logger.debug("audio_config | using default Nova Sonic audio configuration")
-
         return resolved
 
     async def start(
@@ -577,7 +573,8 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N
 
     def _get_connection_start_event(self) -> str:
         """Generate Nova Sonic connection start event."""
-        return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}})
+        inference_config = {_NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in self.config["inference"].items()}
+        return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}})
 
     def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str:
         """Generate Nova Sonic prompt start event with tool configuration."""
diff --git a/src/strands/experimental/bidi/models/openai_realtime.py b/src/strands/experimental/bidi/models/openai_realtime.py
index 39312c7d3..9196a39d5 100644
--- a/src/strands/experimental/bidi/models/openai_realtime.py
+++ b/src/strands/experimental/bidi/models/openai_realtime.py
@@ -19,7 +19,6 @@
 from ....types.content import Messages
 from ....types.tools import ToolResult, ToolSpec, ToolUse
 from .._async import stop_all
-from ..types.model import AudioConfig
 from ..types.events import (
     AudioSampleRate,
     BidiAudioInputEvent,
@@ -37,6 +36,7 @@
     Role,
     StopReason,
 )
+from ..types.model import AudioConfig
 from .model import BidiModel, BidiModelTimeoutError
 
 logger = logging.getLogger(__name__)
@@ -160,34 +160,21 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:
 
     def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
         """Merge user config with defaults (user takes precedence)."""
-        # Extract voice from provider-specific audio.output.voice if present
-        provider_voice = None
-        if "audio" in config and isinstance(config["audio"], dict):
-            if "output" in config["audio"] and isinstance(config["audio"]["output"], dict):
-                provider_voice = config["audio"]["output"].get("voice")
-
-        # Define default audio configuration
         default_audio: AudioConfig = {
             "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE),
             "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE),
             "channels": 1,
             "format": "pcm",
-            "voice": provider_voice or "alloy",
+            "voice": "alloy",
         }
 
-        user_audio = config.get("audio", {})
-        merged_audio = {**default_audio, **user_audio}
-
         resolved = {
-            "audio": merged_audio,
-            **{k: v for k, v in config.items() if k != "audio"},
+            "audio": {
+                **default_audio,
+                **config.get("audio", {}),
+            },
+            "inference": config.get("inference", {}),
         }
-
-        if user_audio:
-            logger.debug("audio_config | merged user-provided config with defaults")
-        else:
-            logger.debug("audio_config | using default OpenAI Realtime audio configuration")
-
         return resolved
 
     async def start(
@@ -277,22 +264,12 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec]
 
         # Apply user-provided session configuration
         supported_params = {
-            "type",
+            "max_output_tokens",
             "output_modalities",
-            "instructions",
-            "voice",
-            "tools",
             "tool_choice",
-            "input_audio_format",
-            "output_audio_format",
-            "input_audio_transcription",
-            "turn_detection",
         }
-
-        for key, value in self.config.items():
-            if key == "audio":
-                continue
-            elif key in supported_params:
+        for key, value in self.config["inference"].items():
+            if key in supported_params:
                 config[key] = value
             else:
                 logger.warning("parameter=<%s> | ignoring unsupported session parameter", key)
diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py
index c92211816..da516d4a0 100644
--- a/tests/strands/experimental/bidi/models/test_gemini_live.py
+++ b/tests/strands/experimental/bidi/models/test_gemini_live.py
@@ -98,9 +98,9 @@ def test_model_initialization(mock_genai_client, model_id, api_key):
     assert model_default.api_key is None
     assert model_default._live_session is None
     # Check default config includes transcription
-    assert model_default.config["response_modalities"] == ["AUDIO"]
-    assert "outputAudioTranscription" in model_default.config
-    assert "inputAudioTranscription" in model_default.config
+    assert model_default.config["inference"]["response_modalities"] == ["AUDIO"]
+    assert "outputAudioTranscription" in model_default.config["inference"]
+    assert "inputAudioTranscription" in model_default.config["inference"]
 
     # Test with API key
     model_with_key = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key})
@@ -108,13 +108,13 @@
     assert model_with_key.api_key == api_key
 
     # Test with custom config (merges with defaults)
-    provider_config = {"temperature": 0.7, "top_p": 0.9}
+    provider_config = {"inference": {"temperature": 0.7, "top_p": 0.9}}
     model_custom = BidiGeminiLiveModel(model_id=model_id, provider_config=provider_config)
     # Custom config should be merged with defaults
-    assert model_custom.config["temperature"] == 0.7
-    assert model_custom.config["top_p"] == 0.9
+    assert model_custom.config["inference"]["temperature"] == 0.7
+    assert model_custom.config["inference"]["top_p"] == 0.9
     # Defaults should still be present
-    assert "response_modalities" in model_custom.config
+    assert "response_modalities" in model_custom.config["inference"]
 
 
 # Connection Tests
diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py
index 7ec0c32a1..04f8043be 100644
--- a/tests/strands/experimental/bidi/models/test_nova_sonic.py
+++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py
@@ -58,7 +58,7 @@ def mock_stream():
 @pytest.fixture
 def mock_client(mock_stream):
     """Mock Bedrock Runtime client."""
-    with patch("strands.experimental.bidi.models.novasonic.BedrockRuntimeClient") as mock_cls:
+    with patch("strands.experimental.bidi.models.nova_sonic.BedrockRuntimeClient") as mock_cls:
         mock_instance = AsyncMock()
         mock_instance.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream)
         mock_cls.return_value = mock_instance
diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py
index 805144446..5c9c0900d 100644
--- a/tests/strands/experimental/bidi/models/test_openai_realtime.py
+++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py
@@ -46,7 +46,7 @@ def mock_websockets_connect(mock_websocket):
     async def async_connect(*args, **kwargs):
         return mock_websocket
 
-    with unittest.mock.patch("strands.experimental.bidi.models.openai.websockets.connect") as mock_connect:
+    with unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.websockets.connect") as mock_connect:
         mock_connect.side_effect = async_connect
         yield mock_connect, mock_websocket
 
@@ -515,7 +515,7 @@ async def test_receive_lifecycle_events(mock_websocket, model):
     assert tru_events == exp_events
 
 
-@unittest.mock.patch("strands.experimental.bidi.models.openai.time.time")
+@unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.time.time")
 @pytest.mark.asyncio
 async def test_receive_timeout(mock_time, model):
     mock_time.side_effect = [1, 2]
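
Reviewer note (not part of the patch): across all three providers, _resolve_provider_config now nests user settings under two sub-dicts, "audio" and "inference". The sketch below illustrates the resulting provider_config shape using the Gemini Live model touched above. The import path, credential-free construction, and the "voice" value are assumptions taken from the test files in this diff, not verified behavior; inference values are illustrative only.

# Reviewer sketch -- illustrative only, not part of the patch.
# Assumes BidiGeminiLiveModel is importable from the gemini_live module above and,
# as in test_gemini_live.py, that construction does not require live credentials.
from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel

model = BidiGeminiLiveModel(
    model_id="gemini-2.5-flash-native-audio-preview-09-2025",
    provider_config={
        # Overrides merge on top of the provider's audio defaults
        # (sample rates, channels, "pcm" format); "voice" here is illustrative.
        "audio": {"voice": "Puck"},
        # Inference settings merge with default_inference and are passed
        # through to the Live API session by _build_live_config().
        "inference": {"temperature": 0.7, "top_p": 0.9},
    },
)

# Per the updated tests, defaults survive the merge:
assert model.config["inference"]["temperature"] == 0.7
assert "response_modalities" in model.config["inference"]
assert model.config["audio"]["format"] == "pcm"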