diff --git a/libs/vertexai/langchain_google_vertexai/_utils.py b/libs/vertexai/langchain_google_vertexai/_utils.py index 08c0543b..d343647f 100644 --- a/libs/vertexai/langchain_google_vertexai/_utils.py +++ b/libs/vertexai/langchain_google_vertexai/_utils.py @@ -123,6 +123,13 @@ class GoogleModelFamily(str, Enum): @classmethod def _missing_(cls, value: Any) -> "GoogleModelFamily": + # https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning + if value.lower() in [ + "gemini-1.5-flash-preview-0514", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-pro-preview-0409", + ]: + return GoogleModelFamily.GEMINI_ADVANCED if "gemini" in value.lower(): return GoogleModelFamily.GEMINI if "code" in value.lower(): diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index 33d26284..4979addd 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -1137,7 +1137,7 @@ def bind_tools( ) vertexai_tool = _format_to_gapic_tool(tools) if tool_choice: - all_names = [f["name"] for f in vertexai_tool.function_declarations] + all_names = [f.name for f in vertexai_tool.function_declarations] tool_config = _tool_choice_to_tool_config(tool_choice, all_names) # Bind dicts for easier serialization/deserialization. return self.bind(tools=[vertexai_tool], tool_config=tool_config, **kwargs) diff --git a/libs/vertexai/tests/unit_tests/test_utils.py b/libs/vertexai/tests/unit_tests/test_utils.py new file mode 100644 index 00000000..74db2009 --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_utils.py @@ -0,0 +1,47 @@ +from typing import List + +import pytest + +from langchain_google_vertexai._utils import GoogleModelFamily + + +@pytest.mark.parametrize( + "srcs,exp", + [ + ( + [ + "chat-bison@001", + "text-bison@002", + ], + GoogleModelFamily.PALM, + ), + ( + [ + "code-bison@002", + "code-gecko@002", + ], + GoogleModelFamily.CODEY, + ), + ( + [ + "gemini-1.0-pro-001", + "gemini-1.0-pro-002", + "gemini-1.0-pro-vision-001", + "gemini-1.0-pro-vision", + ], + GoogleModelFamily.GEMINI, + ), + ( + [ + "gemini-1.5-flash-preview-0514", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-pro-preview-0409", + ], + GoogleModelFamily.GEMINI_ADVANCED, + ), + ], +) +def test_google_model_family(srcs: List[str], exp: GoogleModelFamily): + for src in srcs: + res = GoogleModelFamily(src) + assert res == exp