Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 12 additions & 30 deletions src/strands/experimental/bidi/models/gemini_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from ....types.content import Messages
from ....types.tools import ToolResult, ToolSpec, ToolUse
from .._async import stop_all
from ..types.model import AudioConfig
from ..types.events import (
AudioChannel,
AudioSampleRate,
Expand All @@ -41,6 +40,7 @@
BidiUsageEvent,
ModalityUsage,
)
from ..types.model import AudioConfig
from .model import BidiModel, BidiModelTimeoutError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(

Args:
model_id: Model identifier (default: gemini-2.5-flash-native-audio-preview-09-2025)
provider_config: Model behavior (audio, response_modalities, speech_config, transcription)
provider_config: Model behavior (audio, inference)
client_config: Authentication (api_key, http_options)
**kwargs: Reserved for future parameters.

Expand Down Expand Up @@ -108,44 +108,28 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:

def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
"""Merge user config with defaults (user takes precedence)."""
# Extract voice from provider-specific speech_config.voice_config.prebuilt_voice_config.voice_name if present
provider_voice = None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Voice is passed in through the "audio" config.

if "speech_config" in config and isinstance(config["speech_config"], dict):
provider_voice = (
config["speech_config"].get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name")
)

# Define default audio configuration
default_audio: AudioConfig = {
"input_rate": GEMINI_INPUT_SAMPLE_RATE,
"output_rate": GEMINI_OUTPUT_SAMPLE_RATE,
"channels": GEMINI_CHANNELS,
"format": "pcm",
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have a default voice here

Copy link
Collaborator Author

@pgrayy pgrayy Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does work without specifying a voice. I tested that on all the models actually. With that said, we could remove the default voice setting on all configs but I didn't want to make too many changes here.

}

if provider_voice:
default_audio["voice"] = provider_voice

user_audio = config.get("audio", {})
merged_audio = {**default_audio, **user_audio}

default_provider_settings = {
default_inference = {
"response_modalities": ["AUDIO"],
"outputAudioTranscription": {},
"inputAudioTranscription": {},
}

resolved = {
**default_provider_settings,
**config,
"audio": merged_audio, # Audio always uses merged version
"audio": {
**default_audio,
**config.get("audio", {}),
},
"inference": {
**default_inference,
**config.get("inference", {}),
},
}

if user_audio:
logger.debug("audio_config | merged user-provided config with defaults")
else:
logger.debug("audio_config | using default Gemini Live audio configuration")

return resolved

async def start(
Expand Down Expand Up @@ -505,9 +489,7 @@ def _build_live_config(
Simply passes through all config parameters from provider_config, allowing users
to configure any Gemini Live API parameter directly.
"""
config_dict: dict[str, Any] = {}
if self.config:
config_dict.update({k: v for k, v in self.config.items() if k != "audio"})
config_dict: dict[str, Any] = self.config["inference"].copy()

config_dict["session_resumption"] = {"handle": kwargs.get("live_session_handle")}

Expand Down
31 changes: 14 additions & 17 deletions src/strands/experimental/bidi/models/nova_sonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from ....types.content import Messages
from ....types.tools import ToolResult, ToolSpec, ToolUse
from .._async import stop_all
from ..types.model import AudioConfig
from ..types.events import (
AudioChannel,
AudioSampleRate,
Expand All @@ -53,12 +52,16 @@
BidiTranscriptStreamEvent,
BidiUsageEvent,
)
from ..types.model import AudioConfig
from .model import BidiModel, BidiModelTimeoutError

logger = logging.getLogger(__name__)

# Nova Sonic configuration constants
NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to explicitly provide defaults. Nova already has implicit defaults for these that we can rely on.

_NOVA_INFERENCE_CONFIG_KEYS = {
"max_tokens": "maxTokens",
"temperature": "temperature",
"top_p": "topP",
}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using this mapping to promote consistency. We use snake_case everywhere else.


NOVA_AUDIO_INPUT_CONFIG = {
"mediaType": "audio/lpcm",
Expand Down Expand Up @@ -156,28 +159,21 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:

def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
"""Merge user config with defaults (user takes precedence)."""
# Define default audio configuration
default_audio_config: AudioConfig = {
default_audio: AudioConfig = {
"input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]),
"output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]),
"channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]),
"format": "pcm",
"voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]),
}

user_audio_config = config.get("audio", {})
merged_audio = {**default_audio_config, **user_audio_config}

resolved = {
"audio": merged_audio,
**{k: v for k, v in config.items() if k != "audio"},
"audio": {
**default_audio,
**config.get("audio", {}),
},
"inference": config.get("inference", {}),
}

if user_audio_config:
logger.debug("audio_config | merged user-provided config with defaults")
else:
logger.debug("audio_config | using default Nova Sonic audio configuration")

return resolved

async def start(
Expand Down Expand Up @@ -577,7 +573,8 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N

def _get_connection_start_event(self) -> str:
"""Generate Nova Sonic connection start event."""
return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}})
inference_config = {_NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in self.config["inference"].items()}
return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}})

def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str:
"""Generate Nova Sonic prompt start event with tool configuration."""
Expand Down
43 changes: 10 additions & 33 deletions src/strands/experimental/bidi/models/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ....types.content import Messages
from ....types.tools import ToolResult, ToolSpec, ToolUse
from .._async import stop_all
from ..types.model import AudioConfig
from ..types.events import (
AudioSampleRate,
BidiAudioInputEvent,
Expand All @@ -37,6 +36,7 @@
Role,
StopReason,
)
from ..types.model import AudioConfig
from .model import BidiModel, BidiModelTimeoutError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -160,34 +160,21 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]:

def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]:
"""Merge user config with defaults (user takes precedence)."""
# Extract voice from provider-specific audio.output.voice if present
provider_voice = None
Copy link
Collaborator Author

@pgrayy pgrayy Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Voice provided through "audio" config of type AudioConfig.

if "audio" in config and isinstance(config["audio"], dict):
if "output" in config["audio"] and isinstance(config["audio"]["output"], dict):
provider_voice = config["audio"]["output"].get("voice")

# Define default audio configuration
default_audio: AudioConfig = {
"input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE),
"output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE),
"channels": 1,
"format": "pcm",
"voice": provider_voice or "alloy",
"voice": "alloy",
}

user_audio = config.get("audio", {})
merged_audio = {**default_audio, **user_audio}

resolved = {
"audio": merged_audio,
**{k: v for k, v in config.items() if k != "audio"},
"audio": {
**default_audio,
**config.get("audio", {}),
},
"inference": config.get("inference", {}),
}

if user_audio:
logger.debug("audio_config | merged user-provided config with defaults")
else:
logger.debug("audio_config | using default OpenAI Realtime audio configuration")

return resolved

async def start(
Expand Down Expand Up @@ -277,22 +264,12 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec]

# Apply user-provided session configuration
supported_params = {
"type",
"max_output_tokens",
"output_modalities",
"instructions",
"voice",
"tools",
"tool_choice",
"input_audio_format",
"output_audio_format",
"input_audio_transcription",
"turn_detection",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • type always has to be realtime and is already set by us.
  • instructions is set by us through system prompt.
  • voice is set by us through "audio" config.
  • tools is set by us through the passed in tools param.
  • input_audio_format, output_audio_format, input_audio_transcription, and turn_detection are not top-level configs and so would lead to exceptions if set.

For more details on supported settings, see https://platform.openai.com/docs/api-reference/realtime-client-events/session/update#realtime_client_events-session-update-session.

}

for key, value in self.config.items():
if key == "audio":
continue
elif key in supported_params:
for key, value in self.config["inference"].items():
if key in supported_params:
config[key] = value
else:
logger.warning("parameter=<%s> | ignoring unsupported session parameter", key)
Expand Down
14 changes: 7 additions & 7 deletions tests/strands/experimental/bidi/models/test_gemini_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,23 @@ def test_model_initialization(mock_genai_client, model_id, api_key):
assert model_default.api_key is None
assert model_default._live_session is None
# Check default config includes transcription
assert model_default.config["response_modalities"] == ["AUDIO"]
assert "outputAudioTranscription" in model_default.config
assert "inputAudioTranscription" in model_default.config
assert model_default.config["inference"]["response_modalities"] == ["AUDIO"]
assert "outputAudioTranscription" in model_default.config["inference"]
assert "inputAudioTranscription" in model_default.config["inference"]

# Test with API key
model_with_key = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key})
assert model_with_key.model_id == model_id
assert model_with_key.api_key == api_key

# Test with custom config (merges with defaults)
provider_config = {"temperature": 0.7, "top_p": 0.9}
provider_config = {"inference": {"temperature": 0.7, "top_p": 0.9}}
model_custom = BidiGeminiLiveModel(model_id=model_id, provider_config=provider_config)
# Custom config should be merged with defaults
assert model_custom.config["temperature"] == 0.7
assert model_custom.config["top_p"] == 0.9
assert model_custom.config["inference"]["temperature"] == 0.7
assert model_custom.config["inference"]["top_p"] == 0.9
# Defaults should still be present
assert "response_modalities" in model_custom.config
assert "response_modalities" in model_custom.config["inference"]


# Connection Tests
Expand Down
2 changes: 1 addition & 1 deletion tests/strands/experimental/bidi/models/test_nova_sonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def mock_stream():
@pytest.fixture
def mock_client(mock_stream):
"""Mock Bedrock Runtime client."""
with patch("strands.experimental.bidi.models.novasonic.BedrockRuntimeClient") as mock_cls:
with patch("strands.experimental.bidi.models.nova_sonic.BedrockRuntimeClient") as mock_cls:
mock_instance = AsyncMock()
mock_instance.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream)
mock_cls.return_value = mock_instance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def mock_websockets_connect(mock_websocket):
async def async_connect(*args, **kwargs):
return mock_websocket

with unittest.mock.patch("strands.experimental.bidi.models.openai.websockets.connect") as mock_connect:
with unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.websockets.connect") as mock_connect:
mock_connect.side_effect = async_connect
yield mock_connect, mock_websocket

Expand Down Expand Up @@ -515,7 +515,7 @@ async def test_receive_lifecycle_events(mock_websocket, model):
assert tru_events == exp_events


@unittest.mock.patch("strands.experimental.bidi.models.openai.time.time")
@unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.time.time")
@pytest.mark.asyncio
async def test_receive_timeout(mock_time, model):
mock_time.side_effect = [1, 2]
Expand Down
Loading