Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class _LLMOptions:
seed: NotGivenOr[int]
safety_settings: NotGivenOr[list[types.SafetySettingOrDict]]
service_tier: NotGivenOr[types.ServiceTier]
media_resolution: NotGivenOr[types.MediaResolution]


BLOCKED_REASONS = [
Expand Down Expand Up @@ -119,6 +120,7 @@ def __init__(
seed: NotGivenOr[int] = NOT_GIVEN,
safety_settings: NotGivenOr[list[types.SafetySettingOrDict]] = NOT_GIVEN,
service_tier: NotGivenOr[types.ServiceTier] = NOT_GIVEN,
media_resolution: NotGivenOr[types.MediaResolution] = NOT_GIVEN,
credentials: google.auth.credentials.Credentials | None = None,
) -> None:
"""
Expand Down Expand Up @@ -151,6 +153,7 @@ def __init__(
seed (int, optional): Random seed for reproducible generation. Defaults to None.
safety_settings (list[SafetySettingOrDict], optional): Safety settings for content filtering. Defaults to None.
service_tier (types.ServiceTier, optional): The service tier for the request (e.g. types.ServiceTier.PRIORITY). Defaults to None.
media_resolution (types.MediaResolution, optional): The media resolution for the request. Defaults to None.
""" # noqa: E501
super().__init__()
gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
Expand Down Expand Up @@ -224,6 +227,7 @@ def __init__(
seed=seed,
safety_settings=safety_settings,
service_tier=service_tier,
media_resolution=media_resolution,
)
self._client = Client(
api_key=gemini_api_key,
Expand Down Expand Up @@ -391,6 +395,9 @@ def chat(
if is_given(self._opts.service_tier):
extra["service_tier"] = self._opts.service_tier

if is_given(self._opts.media_resolution):
extra["media_resolution"] = self._opts.media_resolution

return LLMStream(
self,
client=self._client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class _RealtimeOptions:
image_encode_options: NotGivenOr[images.EncodeOptions]
conn_options: APIConnectOptions
http_options: NotGivenOr[types.HttpOptions]
media_resolution: NotGivenOr[types.MediaResolution] = NOT_GIVEN
enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN
proactivity: NotGivenOr[bool] = NOT_GIVEN
realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN
Expand Down Expand Up @@ -216,6 +217,7 @@ def __init__(
api_version: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
http_options: NotGivenOr[types.HttpOptions] = NOT_GIVEN,
media_resolution: NotGivenOr[types.MediaResolution] = NOT_GIVEN,
thinking_config: NotGivenOr[types.ThinkingConfig] = NOT_GIVEN,
credentials: google.auth.credentials.Credentials | None = None,
) -> None:
Expand Down Expand Up @@ -248,6 +250,7 @@ def __init__(
input_audio_transcription (AudioTranscriptionConfig | None, optional): The configuration for input audio transcription. Defaults to None.)
output_audio_transcription (AudioTranscriptionConfig | None, optional): The configuration for output audio transcription. Defaults to AudioTranscriptionConfig().
image_encode_options (images.EncodeOptions, optional): The configuration for image encoding. Defaults to DEFAULT_ENCODE_OPTIONS.
media_resolution (MediaResolution, optional): The media resolution for the session. Defaults to None.
enable_affective_dialog (bool, optional): Whether to enable affective dialog. Defaults to False.
proactivity (bool, optional): Whether to enable proactive audio. Defaults to False.
realtime_input_config (RealtimeInputConfig, optional): The configuration for realtime input. Defaults to None.
Expand Down Expand Up @@ -371,6 +374,7 @@ def __init__(
tool_response_scheduling=tool_response_scheduling,
conn_options=conn_options,
http_options=http_options,
media_resolution=media_resolution,
thinking_config=thinking_config,
session_resumption=session_resumption,
credentials=credentials,
Expand Down Expand Up @@ -1082,6 +1086,9 @@ def _build_connect_config(self) -> types.LiveConnectConfig:
thinking_config=self._opts.thinking_config
if is_given(self._opts.thinking_config)
else None,
media_resolution=self._opts.media_resolution
if is_given(self._opts.media_resolution)
else None,
),
system_instruction=types.Content(parts=[types.Part(text=self._opts.instructions)])
if is_given(self._opts.instructions)
Expand Down
66 changes: 65 additions & 1 deletion tests/test_plugin_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pytest
from google.genai import types

from livekit.plugins.google.llm import LLMStream
from livekit.agents import llm
from livekit.plugins.google.llm import LLM, LLMStream
from livekit.plugins.google.realtime.realtime_api import RealtimeModel, RealtimeSession


@pytest.fixture
Expand Down Expand Up @@ -63,3 +65,65 @@ def test_empty_text_part_returns_none(self, llm_stream: LLMStream):
chunk = llm_stream._parse_part("test-id", part)

assert chunk is None


class TestMediaResolution:
def test_llm_media_resolution_is_passed_to_stream_kwargs(self):
model = LLM(
api_key="test-api-key",
media_resolution=types.MediaResolution.MEDIA_RESOLUTION_LOW,
)

with patch.object(
LLMStream,
"__init__",
lambda self, *a, **kw: self.__dict__.update(_extra_kwargs=kw["extra_kwargs"]),
):
stream = model.chat(chat_ctx=llm.ChatContext.empty())

assert (
stream._extra_kwargs["media_resolution"] == types.MediaResolution.MEDIA_RESOLUTION_LOW
)

def test_llm_media_resolution_is_omitted_by_default(self):
model = LLM(api_key="test-api-key")

with patch.object(
LLMStream,
"__init__",
lambda self, *a, **kw: self.__dict__.update(_extra_kwargs=kw["extra_kwargs"]),
):
stream = model.chat(chat_ctx=llm.ChatContext.empty())

assert "media_resolution" not in stream._extra_kwargs

def test_realtime_media_resolution_is_passed_to_connect_config(self):
model = RealtimeModel(
api_key="test-api-key",
media_resolution=types.MediaResolution.MEDIA_RESOLUTION_LOW,
)
session = RealtimeSession.__new__(RealtimeSession)
session._opts = model._opts
session._tools = llm.ToolContext.empty()
session._realtime_model = model
session._session_resumption_handle = None

config = session._build_connect_config()

assert config.generation_config
assert (
config.generation_config.media_resolution == types.MediaResolution.MEDIA_RESOLUTION_LOW
)

def test_realtime_media_resolution_is_unset_by_default(self):
model = RealtimeModel(api_key="test-api-key")
session = RealtimeSession.__new__(RealtimeSession)
session._opts = model._opts
session._tools = llm.ToolContext.empty()
session._realtime_model = model
session._session_resumption_handle = None

config = session._build_connect_config()

assert config.generation_config
assert config.generation_config.media_resolution is None
Loading