Skip to content

Commit

Permalink
fixed forced tool calls on gemini-preview models (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
alx13 committed May 21, 2024
1 parent 37e637f commit 2f7be6f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
7 changes: 7 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ class GoogleModelFamily(str, Enum):

@classmethod
def _missing_(cls, value: Any) -> "GoogleModelFamily":
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning
if value.lower() in [
"gemini-1.5-flash-preview-0514",
"gemini-1.5-pro-preview-0514",
"gemini-1.5-pro-preview-0409",
]:
return GoogleModelFamily.GEMINI_ADVANCED
if "gemini" in value.lower():
return GoogleModelFamily.GEMINI
if "code" in value.lower():
Expand Down
2 changes: 1 addition & 1 deletion libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ def bind_tools(
)
vertexai_tool = _format_to_gapic_tool(tools)
if tool_choice:
all_names = [f["name"] for f in vertexai_tool.function_declarations]
all_names = [f.name for f in vertexai_tool.function_declarations]
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
# Bind dicts for easier serialization/deserialization.
return self.bind(tools=[vertexai_tool], tool_config=tool_config, **kwargs)
Expand Down
47 changes: 47 additions & 0 deletions libs/vertexai/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import List

import pytest

from langchain_google_vertexai._utils import GoogleModelFamily


@pytest.mark.parametrize(
"srcs,exp",
[
(
[
"chat-bison@001",
"text-bison@002",
],
GoogleModelFamily.PALM,
),
(
[
"code-bison@002",
"code-gecko@002",
],
GoogleModelFamily.CODEY,
),
(
[
"gemini-1.0-pro-001",
"gemini-1.0-pro-002",
"gemini-1.0-pro-vision-001",
"gemini-1.0-pro-vision",
],
GoogleModelFamily.GEMINI,
),
(
[
"gemini-1.5-flash-preview-0514",
"gemini-1.5-pro-preview-0514",
"gemini-1.5-pro-preview-0409",
],
GoogleModelFamily.GEMINI_ADVANCED,
),
],
)
def test_google_model_family(srcs: List[str], exp: GoogleModelFamily):
for src in srcs:
res = GoogleModelFamily(src)
assert res == exp

0 comments on commit 2f7be6f

Please sign in to comment.