Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,16 @@ def _FunctionCallingConfig_to_mldev(
return to_object


def _is_text_embedding_batch(
contents: Union[types.ContentListUnion, types.ContentListUnionDict],
) -> bool:
return (
isinstance(contents, list)
and bool(contents)
and all(isinstance(content, str) for content in contents)
)


def _GenerateContentConfig_to_mldev(
api_client: BaseApiClient,
from_object: Union[dict[str, Any], object],
Expand Down Expand Up @@ -6335,7 +6345,9 @@ def embed_content(
)
"""
if not self._api_client.vertexai:
if 'gemini-embedding-2' in model:
if 'gemini-embedding-2' in model and not _is_text_embedding_batch(
contents
):
contents = t.t_contents(contents) # type: ignore[assignment]
return self._embed_content(model=model, contents=contents, config=config)

Expand Down Expand Up @@ -9296,7 +9308,9 @@ async def embed_content(
)
"""
if not self._api_client.vertexai:
if 'gemini-embedding-2' in model:
if 'gemini-embedding-2' in model and not _is_text_embedding_batch(
contents
):
contents = t.t_contents(contents) # type: ignore[assignment]
return await self._embed_content(
model=model, contents=contents, config=config
Expand Down
73 changes: 73 additions & 0 deletions google/genai/tests/models/test_embed_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from ... import _transformers as t
from ... import models
from ... import types
from .. import pytest_helper

Expand Down Expand Up @@ -227,6 +228,78 @@ def _get_bytes_from_file(relative_path: str) -> bytes:
)


class _FakeApiClient:
vertexai = False


def test_gemini_embedding_2_text_list_stays_batched(
    monkeypatch, use_vertex, replays_prefix, http_options
):
  """Checks that a list of plain strings reaches `_embed_content` unchanged."""
  seen_kwargs = {}

  def record_call(**kwargs):
    # Remember the keyword arguments forwarded by `embed_content`.
    seen_kwargs.update(kwargs)
    return types.EmbedContentResponse()

  models_api = models.Models(_FakeApiClient())
  # Replace the private request method so the call is only recorded.
  monkeypatch.setattr(models_api, '_embed_content', record_call)

  texts = ['first text', 'second text']
  models_api.embed_content(model='gemini-embedding-2-preview', contents=texts)

  # Compare against a fresh literal, as the list must still hold two strings.
  assert seen_kwargs['contents'] == ['first text', 'second text']


def test_gemini_embedding_2_mixed_content_still_combines_parts(
    monkeypatch, use_vertex, replays_prefix, http_options
):
  """Checks that a string plus a `Part` is forwarded as one two-part entry."""
  seen_kwargs = {}

  def record_call(**kwargs):
    # Remember the keyword arguments forwarded by `embed_content`.
    seen_kwargs.update(kwargs)
    return types.EmbedContentResponse()

  models_api = models.Models(_FakeApiClient())
  # Replace the private request method so the call is only recorded.
  monkeypatch.setattr(models_api, '_embed_content', record_call)

  image_part = types.Part.from_uri(
      file_uri='gs://generativeai-downloads/images/scones.jpg',
      mime_type='image/jpeg',
  )
  mixed_contents = ['Similar things to the following image:', image_part]
  models_api.embed_content(
      model='gemini-embedding-2-preview',
      contents=mixed_contents,
  )

  forwarded = seen_kwargs['contents']
  # One forwarded entry whose `parts` holds both the text and the image.
  assert len(forwarded) == 1
  assert len(forwarded[0].parts) == 2


@pytest.mark.asyncio
async def test_async_gemini_embedding_2_text_list_stays_batched(
    monkeypatch, use_vertex, replays_prefix, http_options
):
  """Async variant: a list of strings reaches `_embed_content` unchanged."""
  seen_kwargs = {}

  async def record_call(**kwargs):
    # Remember the keyword arguments forwarded by `embed_content`.
    seen_kwargs.update(kwargs)
    return types.EmbedContentResponse()

  async_models_api = models.AsyncModels(_FakeApiClient())
  # Replace the private request coroutine so the call is only recorded.
  monkeypatch.setattr(async_models_api, '_embed_content', record_call)

  texts = ['first text', 'second text']
  await async_models_api.embed_content(
      model='gemini-embedding-2-preview',
      contents=texts,
  )

  # Compare against a fresh literal, as the list must still hold two strings.
  assert seen_kwargs['contents'] == ['first text', 'second text']


def test_gemini_embedding_2_content_combination(client):
response = client.models.embed_content(
model='gemini-embedding-2-preview',
Expand Down