isolate model inference configs #100
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,6 @@ | |
| from ....types.content import Messages | ||
| from ....types.tools import ToolResult, ToolSpec, ToolUse | ||
| from .._async import stop_all | ||
| from ..types.model import AudioConfig | ||
| from ..types.events import ( | ||
| AudioChannel, | ||
| AudioSampleRate, | ||
|
|
@@ -41,6 +40,7 @@ | |
| BidiUsageEvent, | ||
| ModalityUsage, | ||
| ) | ||
| from ..types.model import AudioConfig | ||
| from .model import BidiModel, BidiModelTimeoutError | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -70,7 +70,7 @@ def __init__( | |
|
|
||
| Args: | ||
| model_id: Model identifier (default: gemini-2.5-flash-native-audio-preview-09-2025) | ||
| provider_config: Model behavior (audio, response_modalities, speech_config, transcription) | ||
| provider_config: Model behavior (audio, inference) | ||
| client_config: Authentication (api_key, http_options) | ||
| **kwargs: Reserved for future parameters. | ||
|
|
||
|
|
@@ -108,44 +108,28 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: | |
|
|
||
| def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: | ||
| """Merge user config with defaults (user takes precedence).""" | ||
| # Extract voice from provider-specific speech_config.voice_config.prebuilt_voice_config.voice_name if present | ||
| provider_voice = None | ||
| if "speech_config" in config and isinstance(config["speech_config"], dict): | ||
| provider_voice = ( | ||
| config["speech_config"].get("voice_config", {}).get("prebuilt_voice_config", {}).get("voice_name") | ||
| ) | ||
|
|
||
| # Define default audio configuration | ||
| default_audio: AudioConfig = { | ||
| "input_rate": GEMINI_INPUT_SAMPLE_RATE, | ||
| "output_rate": GEMINI_OUTPUT_SAMPLE_RATE, | ||
| "channels": GEMINI_CHANNELS, | ||
| "format": "pcm", | ||
|
Owner
I think we should have a default voice here.
Collaborator (Author)
It does work without specifying. I tested that on all the models actually. With that said, we could remove the default voice setting on all configs, but I didn't want to make too many changes here. |
||
| } | ||
|
|
||
| if provider_voice: | ||
| default_audio["voice"] = provider_voice | ||
|
|
||
| user_audio = config.get("audio", {}) | ||
| merged_audio = {**default_audio, **user_audio} | ||
|
|
||
| default_provider_settings = { | ||
| default_inference = { | ||
| "response_modalities": ["AUDIO"], | ||
| "outputAudioTranscription": {}, | ||
| "inputAudioTranscription": {}, | ||
| } | ||
|
|
||
| resolved = { | ||
| **default_provider_settings, | ||
| **config, | ||
| "audio": merged_audio, # Audio always uses merged version | ||
| "audio": { | ||
| **default_audio, | ||
| **config.get("audio", {}), | ||
| }, | ||
| "inference": { | ||
| **default_inference, | ||
| **config.get("inference", {}), | ||
| }, | ||
| } | ||
|
|
||
| if user_audio: | ||
| logger.debug("audio_config | merged user-provided config with defaults") | ||
| else: | ||
| logger.debug("audio_config | using default Gemini Live audio configuration") | ||
|
|
||
| return resolved | ||
|
|
||
| async def start( | ||
|
|
@@ -505,9 +489,7 @@ def _build_live_config( | |
| Simply passes through all config parameters from provider_config, allowing users | ||
| to configure any Gemini Live API parameter directly. | ||
| """ | ||
| config_dict: dict[str, Any] = {} | ||
| if self.config: | ||
| config_dict.update({k: v for k, v in self.config.items() if k != "audio"}) | ||
| config_dict: dict[str, Any] = self.config["inference"].copy() | ||
|
|
||
| config_dict["session_resumption"] = {"handle": kwargs.get("live_session_handle")} | ||
|
|
||
|
|
||
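To make the new Gemini Live config shape concrete, here is a minimal standalone sketch of the merge above. The free function name and the sample-rate constants are stand-ins for the module's own constants and private `_resolve_provider_config` method; user-supplied "audio" and "inference" dicts override the defaults key by key, and any other keys in `provider_config` pass through untouched.

```python
from typing import Any

# Illustrative stand-ins for the module's constants.
GEMINI_INPUT_SAMPLE_RATE = 16000
GEMINI_OUTPUT_SAMPLE_RATE = 24000
GEMINI_CHANNELS = 1


def resolve_provider_config(config: dict[str, Any]) -> dict[str, Any]:
    """Sketch of the new merge: 'audio' and 'inference' are isolated keys."""
    default_audio = {
        "input_rate": GEMINI_INPUT_SAMPLE_RATE,
        "output_rate": GEMINI_OUTPUT_SAMPLE_RATE,
        "channels": GEMINI_CHANNELS,
        "format": "pcm",
    }
    default_inference = {
        "response_modalities": ["AUDIO"],
        "outputAudioTranscription": {},
        "inputAudioTranscription": {},
    }
    return {
        **config,
        "audio": {**default_audio, **config.get("audio", {})},
        "inference": {**default_inference, **config.get("inference", {})},
    }


# Users override only what they care about; defaults fill the rest.
resolved = resolve_provider_config(
    {"audio": {"voice": "Puck"}, "inference": {"response_modalities": ["TEXT"]}}
)
print(resolved["audio"]["voice"])                    # Puck
print(resolved["inference"]["response_modalities"])  # ['TEXT']
```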
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,7 +37,6 @@ | |
| from ....types.content import Messages | ||
| from ....types.tools import ToolResult, ToolSpec, ToolUse | ||
| from .._async import stop_all | ||
| from ..types.model import AudioConfig | ||
| from ..types.events import ( | ||
| AudioChannel, | ||
| AudioSampleRate, | ||
|
|
@@ -53,12 +52,16 @@ | |
| BidiTranscriptStreamEvent, | ||
| BidiUsageEvent, | ||
| ) | ||
| from ..types.model import AudioConfig | ||
| from .model import BidiModel, BidiModelTimeoutError | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # Nova Sonic configuration constants | ||
| NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} | ||
|
Collaborator (Author)
No need to explicitly provide defaults. Nova already has implicit defaults for these that we can rely on. |
||
| _NOVA_INFERENCE_CONFIG_KEYS = { | ||
| "max_tokens": "maxTokens", | ||
| "temperature": "temperature", | ||
| "top_p": "topP", | ||
| } | ||
|
Collaborator (Author)
Using snake_case keys here to promote consistency. We use snake_case everywhere else. |
||
|
|
||
| NOVA_AUDIO_INPUT_CONFIG = { | ||
| "mediaType": "audio/lpcm", | ||
|
|
@@ -156,28 +159,21 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: | |
|
|
||
| def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: | ||
| """Merge user config with defaults (user takes precedence).""" | ||
| # Define default audio configuration | ||
| default_audio_config: AudioConfig = { | ||
| default_audio: AudioConfig = { | ||
| "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), | ||
| "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), | ||
| "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), | ||
| "format": "pcm", | ||
| "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), | ||
| } | ||
|
|
||
| user_audio_config = config.get("audio", {}) | ||
| merged_audio = {**default_audio_config, **user_audio_config} | ||
|
|
||
| resolved = { | ||
| "audio": merged_audio, | ||
| **{k: v for k, v in config.items() if k != "audio"}, | ||
| "audio": { | ||
| **default_audio, | ||
| **config.get("audio", {}), | ||
| }, | ||
| "inference": config.get("inference", {}), | ||
| } | ||
|
|
||
| if user_audio_config: | ||
| logger.debug("audio_config | merged user-provided config with defaults") | ||
| else: | ||
| logger.debug("audio_config | using default Nova Sonic audio configuration") | ||
|
|
||
| return resolved | ||
|
|
||
| async def start( | ||
|
|
@@ -577,7 +573,8 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N | |
|
|
||
| def _get_connection_start_event(self) -> str: | ||
| """Generate Nova Sonic connection start event.""" | ||
| return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) | ||
| inference_config = {_NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in self.config["inference"].items()} | ||
| return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}}) | ||
|
|
||
| def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: | ||
| """Generate Nova Sonic prompt start event with tool configuration.""" | ||
|
|
||
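The substance of the Nova Sonic change is that callers now write snake_case inference keys and the provider translates them into Nova's camelCase `inferenceConfiguration` fields at session start, sending an empty object (and relying on Nova's implicit defaults) when nothing is provided. A rough self-contained sketch; the function name is hypothetical, and the real code builds this inside `_get_connection_start_event`:

```python
import json
from typing import Any

# Snake_case keys exposed to users, mapped to Nova Sonic's camelCase fields
# (mirrors _NOVA_INFERENCE_CONFIG_KEYS in the diff above).
NOVA_INFERENCE_CONFIG_KEYS = {
    "max_tokens": "maxTokens",
    "temperature": "temperature",
    "top_p": "topP",
}


def connection_start_event(inference: dict[str, Any]) -> str:
    """Sketch of the sessionStart payload built from a user's snake_case inference config."""
    inference_config = {NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in inference.items()}
    return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}})


print(connection_start_event({"max_tokens": 1024, "temperature": 0.7}))
# {"event": {"sessionStart": {"inferenceConfiguration": {"maxTokens": 1024, "temperature": 0.7}}}}

# An empty config sends an empty inferenceConfiguration, falling back to Nova's defaults.
print(connection_start_event({}))
# {"event": {"sessionStart": {"inferenceConfiguration": {}}}}
```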
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,6 @@ | |
| from ....types.content import Messages | ||
| from ....types.tools import ToolResult, ToolSpec, ToolUse | ||
| from .._async import stop_all | ||
| from ..types.model import AudioConfig | ||
| from ..types.events import ( | ||
| AudioSampleRate, | ||
| BidiAudioInputEvent, | ||
|
|
@@ -37,6 +36,7 @@ | |
| Role, | ||
| StopReason, | ||
| ) | ||
| from ..types.model import AudioConfig | ||
| from .model import BidiModel, BidiModelTimeoutError | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -160,34 +160,21 @@ def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: | |
|
|
||
| def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: | ||
| """Merge user config with defaults (user takes precedence).""" | ||
| # Extract voice from provider-specific audio.output.voice if present | ||
| provider_voice = None | ||
|
Collaborator (Author)
Voice is provided through the "audio" config of type AudioConfig. |
||
| if "audio" in config and isinstance(config["audio"], dict): | ||
| if "output" in config["audio"] and isinstance(config["audio"]["output"], dict): | ||
| provider_voice = config["audio"]["output"].get("voice") | ||
|
|
||
| # Define default audio configuration | ||
| default_audio: AudioConfig = { | ||
| "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), | ||
| "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), | ||
| "channels": 1, | ||
| "format": "pcm", | ||
| "voice": provider_voice or "alloy", | ||
| "voice": "alloy", | ||
| } | ||
|
|
||
| user_audio = config.get("audio", {}) | ||
| merged_audio = {**default_audio, **user_audio} | ||
|
|
||
| resolved = { | ||
| "audio": merged_audio, | ||
| **{k: v for k, v in config.items() if k != "audio"}, | ||
| "audio": { | ||
| **default_audio, | ||
| **config.get("audio", {}), | ||
| }, | ||
| "inference": config.get("inference", {}), | ||
| } | ||
|
|
||
| if user_audio: | ||
| logger.debug("audio_config | merged user-provided config with defaults") | ||
| else: | ||
| logger.debug("audio_config | using default OpenAI Realtime audio configuration") | ||
|
|
||
| return resolved | ||
|
|
||
| async def start( | ||
|
|
@@ -277,22 +264,12 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | |
|
|
||
| # Apply user-provided session configuration | ||
| supported_params = { | ||
| "type", | ||
| "max_output_tokens", | ||
| "output_modalities", | ||
| "instructions", | ||
| "voice", | ||
| "tools", | ||
| "tool_choice", | ||
| "input_audio_format", | ||
| "output_audio_format", | ||
| "input_audio_transcription", | ||
| "turn_detection", | ||
|
Collaborator (Author)
For more details on supported settings, see https://platform.openai.com/docs/api-reference/realtime-client-events/session/update#realtime_client_events-session-update-session. |
||
| } | ||
|
|
||
| for key, value in self.config.items(): | ||
| if key == "audio": | ||
| continue | ||
| elif key in supported_params: | ||
| for key, value in self.config["inference"].items(): | ||
| if key in supported_params: | ||
| config[key] = value | ||
| else: | ||
| logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) | ||
|
|
||
Voice is passed in through the "audio" config.
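For the OpenAI Realtime provider the same split applies: only keys under "inference" are screened against the supported session parameters, while voice travels in the "audio" config rather than alongside the session fields. A minimal sketch of that filtering, assuming a hypothetical `build_session_config` helper standing in for the class's `_build_session_config`:

```python
import logging
from typing import Any

logger = logging.getLogger(__name__)

# Session parameters accepted by the Realtime session.update event (per the diff above).
SUPPORTED_PARAMS = {
    "type", "max_output_tokens", "output_modalities", "instructions", "voice",
    "tools", "tool_choice", "input_audio_format", "output_audio_format",
    "input_audio_transcription", "turn_detection",
}


def build_session_config(inference: dict[str, Any]) -> dict[str, Any]:
    """Sketch: copy supported inference keys into the session config, warn on the rest."""
    config: dict[str, Any] = {}
    for key, value in inference.items():
        if key in SUPPORTED_PARAMS:
            config[key] = value
        else:
            logger.warning("parameter=<%s> | ignoring unsupported session parameter", key)
    return config


# Voice lives under "audio"; session settings live under "inference".
provider_config = {
    "audio": {"voice": "alloy"},
    "inference": {"turn_detection": {"type": "server_vad"}, "max_output_tokens": 512},
}
print(build_session_config(provider_config["inference"]))
# {'turn_detection': {'type': 'server_vad'}, 'max_output_tokens': 512}
```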