diff --git a/libs/vertexai/langchain_google_vertexai/__init__.py b/libs/vertexai/langchain_google_vertexai/__init__.py
index 5d08c169..6f37300f 100644
--- a/libs/vertexai/langchain_google_vertexai/__init__.py
+++ b/libs/vertexai/langchain_google_vertexai/__init__.py
@@ -2,7 +2,10 @@
 from langchain_google_vertexai.chains import create_structured_runnable
 from langchain_google_vertexai.chat_models import ChatVertexAI
 from langchain_google_vertexai.embeddings import VertexAIEmbeddings
-from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
+from langchain_google_vertexai.functions_utils import (
+    PydanticFunctionsOutputParser,
+    ToolConfig,
+)
 from langchain_google_vertexai.gemma import (
     GemmaChatLocalHF,
     GemmaChatLocalKaggle,
@@ -42,6 +45,7 @@
     "HarmBlockThreshold",
     "HarmCategory",
     "PydanticFunctionsOutputParser",
+    "ToolConfig",
     "create_structured_runnable",
     "VertexAIImageCaptioning",
     "VertexAIImageCaptioningChat",
diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index 8a75a584..6e6ad005 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -88,6 +88,7 @@
     is_gemini_model,
 )
 from langchain_google_vertexai.functions_utils import (
+    _format_tool_config,
     _format_tools_to_vertex_tool,
 )
 
@@ -164,17 +165,22 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
     convert_system_message_to_human_content = None
     system_instruction = None
     for i, message in enumerate(history):
-        if (
-            i == 0
-            and isinstance(message, SystemMessage)
-            and not convert_system_message_to_human
-        ):
+        if isinstance(message, SystemMessage):
             if system_instruction is not None:
                 raise ValueError(
                     "Detected more than one SystemMessage in the list of messages."
-                    "Gemini APIs support the insertion of only SystemMessage."
+                    " Gemini APIs support the insertion of only one SystemMessage."
                 )
             else:
+                if convert_system_message_to_human:
+                    logger.warning(
+                        "Gemini models released from April 2024 support "
+                        "SystemMessages natively. For best performance, "
+                        "when working with these models, "
+                        "set convert_system_message_to_human to False."
+                    )
+                    convert_system_message_to_human_content = message
+                    continue
                 system_instruction = Content(
                     role="user", parts=_convert_to_parts(message)
                 )
@@ -184,11 +190,6 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
             and isinstance(message, SystemMessage)
             and convert_system_message_to_human
         ):
-            logger.warning(
-                "gemini models released from April 2024 support SystemMessages"
-                "natively. For best performances, when working with these models,"
-                "set convert_system_message_to_human to False"
-            )
             convert_system_message_to_human_content = message
             continue
         elif isinstance(message, AIMessage):
@@ -299,12 +300,10 @@ def _get_client_with_sys_instruction(
     client: GenerativeModel,
     system_instruction: Content,
     model_name: str,
-    safety_settings: Optional[Dict] = None,
 ):
     if client._system_instruction != system_instruction:
         client = GenerativeModel(
             model_name=model_name,
-            safety_settings=safety_settings,
             system_instruction=system_instruction,
         )
     return client
@@ -472,24 +471,27 @@ def _generate(
             project=self.project,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
-        message = history_gemini.pop()
         self.client = _get_client_with_sys_instruction(
             client=self.client,
             system_instruction=system_instruction,
             model_name=self.model_name,
-            safety_settings=safety_settings,
         )
-        with telemetry.tool_context_manager(self._user_agent):
-            chat = self.client.start_chat(history=history_gemini)
 
         # set param to `functions` until core tool/function calling implemented
         raw_tools = params.pop("functions") if "functions" in params else None
         tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
+        raw_tool_config = (
+            params.pop("tool_config") if "tool_config" in params else None
+        )
+        tool_config = (
+            _format_tool_config(raw_tool_config) if raw_tool_config else None
+        )
         with telemetry.tool_context_manager(self._user_agent):
-            response = chat.send_message(
-                message,
+            response = self.client.generate_content(
+                history_gemini,
                 generation_config=params,
                 tools=tools,
+                tool_config=tool_config,
                 safety_settings=safety_settings,
             )
         generations = [
@@ -562,24 +564,27 @@ async def _agenerate(
             project=self.project,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
-        message = history_gemini.pop()
         self.client = _get_client_with_sys_instruction(
             client=self.client,
             system_instruction=system_instruction,
             model_name=self.model_name,
-            safety_settings=safety_settings,
        )
-        with telemetry.tool_context_manager(self._user_agent):
-            chat = self.client.start_chat(history=history_gemini)
 
        # set param to `functions` until core tool/function calling implemented
         raw_tools = params.pop("functions") if "functions" in params else None
         tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
+        raw_tool_config = (
+            params.pop("tool_config") if "tool_config" in params else None
+        )
+        tool_config = (
+            _format_tool_config(raw_tool_config) if raw_tool_config else None
+        )
         with telemetry.tool_context_manager(self._user_agent):
-            response = await chat.send_message_async(
-                message,
+            response = await self.client.generate_content_async(
+                history_gemini,
                 generation_config=params,
                 tools=tools,
+                tool_config=tool_config,
                 safety_settings=safety_settings,
             )
         generations = [
@@ -630,25 +635,28 @@ def _stream(
             project=self.project,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
-        message = history_gemini.pop()
         self.client = _get_client_with_sys_instruction(
             client=self.client,
             system_instruction=system_instruction,
             model_name=self.model_name,
-            safety_settings=safety_settings,
         )
-        with telemetry.tool_context_manager(self._user_agent):
-            chat = self.client.start_chat(history=history_gemini)
 
         # set param to `functions` until core tool/function calling implemented
         raw_tools = params.pop("functions") if "functions" in params else None
         tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
+        raw_tool_config = (
+            params.pop("tool_config") if "tool_config" in params else None
if "tool_config" in params else None + ) + tool_config = ( + _format_tool_config(raw_tool_config) if raw_tool_config else None + ) with telemetry.tool_context_manager(self._user_agent): - responses = chat.send_message( - message, + responses = self.client.generate_content( + history_gemini, stream=True, generation_config=params, - safety_settings=safety_settings, tools=tools, + tool_config=tool_config, + safety_settings=safety_settings, ) for response in responses: message = _parse_response_candidate( @@ -716,19 +724,19 @@ async def _astream( client=self.client, system_instruction=system_instruction, model_name=self.model_name, - safety_settings=safety_settings, ) - with telemetry.tool_context_manager(self._user_agent): - chat = self.client.start_chat(history=history_gemini) raw_tools = params.pop("functions") if "functions" in params else None tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + raw_tool_config = params.pop("tool_config") if "tool_config" in params else None + tool_config = _format_tool_config(raw_tool_config) if raw_tool_config else None with telemetry.tool_context_manager(self._user_agent): - async for chunk in await chat.send_message_async( - message, + async for chunk in await self.client.generate_content_async( + history_gemini, stream=True, generation_config=params, - safety_settings=safety_settings, tools=tools, + tool_config=tool_config, + safety_settings=safety_settings, ): message = _parse_response_candidate(chunk.candidates[0], streaming=True) generation_info = get_generation_info( diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index 89dbb7cc..abb0d00a 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -15,6 +15,11 @@ Tool as VertexTool, ) +# FIXME: vertexai is not exporting ToolConfig +from vertexai.generative_models._generative_models import ( # type: ignore + ToolConfig, +) + def _format_pydantic_to_vertex_function( pydantic_model: Type[BaseModel], @@ -75,6 +80,16 @@ def _format_tools_to_vertex_tool( return [VertexTool(function_declarations=function_declarations)] +def _format_tool_config(tool_config: Dict[str, Any]) -> Union[ToolConfig, None]: + if "function_calling_config" not in tool_config: + return None + return ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + **tool_config["function_calling_config"] + ) + ) + + class ParametersSchema(BaseModel): """ This is a schema of currently supported definitions in function calling. 
diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py
index 3a30d56d..66f07bdc 100644
--- a/libs/vertexai/tests/integration_tests/test_chat_models.py
+++ b/libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -15,7 +15,12 @@
 
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.tools import tool
 
-from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory
+from langchain_google_vertexai import (
+    ChatVertexAI,
+    HarmBlockThreshold,
+    HarmCategory,
+    ToolConfig,
+)
 
 model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]
@@ -338,3 +343,68 @@ def my_tool(name: str, age: int) -> None:
     tool_call_chunk = gathered.tool_call_chunks[0]
     assert tool_call_chunk["name"] == "my_tool"
     assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'
+
+
+@pytest.mark.release
+def test_chat_vertexai_gemini_function_calling_tool_config_any() -> None:
+    class MyModel(BaseModel):
+        name: str
+        age: int
+
+    safety = {
+        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
+    }
+    model = ChatVertexAI(
+        model_name="gemini-1.5-pro-preview-0409", safety_settings=safety
+    ).bind(
+        functions=[MyModel],
+        tool_config={
+            "function_calling_config": {
+                "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
+                "allowed_function_names": ["MyModel"],
+            }
+        },
+    )
+    message = HumanMessage(content="My name is Erick and I am 27 years old")
+    response = model.invoke([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert response.content == ""
+    function_call = response.additional_kwargs.get("function_call")
+    assert function_call
+    assert function_call["name"] == "MyModel"
+    arguments_str = function_call.get("arguments")
+    assert arguments_str
+    arguments = json.loads(arguments_str)
+    assert arguments == {
+        "name": "Erick",
+        "age": 27.0,
+    }
+
+
+@pytest.mark.release
+def test_chat_vertexai_gemini_function_calling_tool_config_none() -> None:
+    class MyModel(BaseModel):
+        name: str
+        age: int
+
+    safety = {
+        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
+    }
+    model = ChatVertexAI(
+        model_name="gemini-1.5-pro-preview-0409", safety_settings=safety
+    ).bind(
+        functions=[MyModel],
+        tool_config={
+            "function_calling_config": {
+                "mode": ToolConfig.FunctionCallingConfig.Mode.NONE,
+            }
+        },
+    )
+    message = HumanMessage(content="My name is Erick and I am 27 years old")
+    response = model.invoke([message])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert response.content != ""
+    function_call = response.additional_kwargs.get("function_call")
+    assert function_call is None
diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py
index 9265ad4f..5f6f04ef 100644
--- a/libs/vertexai/tests/unit_tests/test_chat_models.py
+++ b/libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -7,10 +7,14 @@
 import pytest
 from google.cloud.aiplatform_v1beta1.types import (
-    Content,
+    Content as Content,
+)
+from google.cloud.aiplatform_v1beta1.types import (
     FunctionCall,
     FunctionResponse,
-    Part,
+)
+from google.cloud.aiplatform_v1beta1.types import (
+    Part as Part,
 )
 from google.cloud.aiplatform_v1beta1.types import (
     content as gapic_content_types,
 )
@@ -176,7 +180,7 @@ def test_parse_history_gemini_converted_message() -> None:
     message2 = AIMessage(content=text_answer1)
     message3 = HumanMessage(content=text_question2)
     messages = [system_message, message1, message2, message3]
-    system_instructions, history = _parse_chat_history_gemini(
+    _, history = _parse_chat_history_gemini(
         messages, convert_system_message_to_human=True
     )
     assert len(history) == 3
@@ -314,18 +318,17 @@ def test_default_params_gemini() -> None:
                 citation_metadata=None,
             )
         ]
-        mock_chat = MagicMock()
-        mock_send_message = MagicMock(return_value=mock_response)
-        mock_chat.send_message = mock_send_message
-
+        mock_generate_content = MagicMock(return_value=mock_response)
         mock_model = MagicMock()
-        mock_start_chat = MagicMock(return_value=mock_chat)
-        mock_model.start_chat = mock_start_chat
+        mock_model.generate_content = mock_generate_content
         gm.return_value = mock_model
+
         model = ChatVertexAI(model_name="gemini-pro")
         message = HumanMessage(content=user_prompt)
-        _ = model([message])
-        mock_start_chat.assert_called_once_with(history=[])
+        _ = model.invoke([message])
+        mock_generate_content.assert_called_once()
+        assert mock_generate_content.call_args.args[0][0].parts[0].text == user_prompt
+
 
 @pytest.mark.parametrize(
diff --git a/libs/vertexai/tests/unit_tests/test_function_utils.py b/libs/vertexai/tests/unit_tests/test_function_utils.py
index b69a4008..f79336a1 100644
--- a/libs/vertexai/tests/unit_tests/test_function_utils.py
+++ b/libs/vertexai/tests/unit_tests/test_function_utils.py
@@ -3,8 +3,10 @@
 
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_core.tools import tool
+from vertexai.generative_models._generative_models import ToolConfig
 
 from langchain_google_vertexai.functions_utils import (
+    _format_tool_config,
     _format_tool_to_vertex_function,
     _get_parameters_from_schema,
 )
@@ -52,6 +54,22 @@ def do_something_optional(a: float, b: float = 0) -> str:
     assert len(schema["parameters"]["required"]) == 1
 
 
+def test_format_tool_config():
+    tool_config = _format_tool_config({})
+    assert tool_config is None
+
+    tool_config = _format_tool_config(
+        {
+            "function_calling_config": {
+                "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
+                "allowed_function_names": ["my_fun"],
+            }
+        }
+    )
+    assert isinstance(tool_config, ToolConfig)
+
+
 def test_get_parameters_from_schema():
     class StringEnum(str, Enum):
         pear = "pear"
diff --git a/libs/vertexai/tests/unit_tests/test_imports.py b/libs/vertexai/tests/unit_tests/test_imports.py
index fc578c88..7708c113 100644
--- a/libs/vertexai/tests/unit_tests/test_imports.py
+++ b/libs/vertexai/tests/unit_tests/test_imports.py
@@ -14,6 +14,7 @@
     "HarmBlockThreshold",
     "HarmCategory",
     "PydanticFunctionsOutputParser",
+    "ToolConfig",
     "create_structured_runnable",
     "VertexAIImageCaptioning",
     "VertexAIImageCaptioningChat",
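
[Editor's note] End to end, the surface added by this patch can be exercised roughly as below; a hedged sketch only (model name and prompt borrowed from the integration test above, not a definitive API reference):

    from langchain_core.messages import HumanMessage
    from langchain_core.pydantic_v1 import BaseModel
    from langchain_google_vertexai import ChatVertexAI, ToolConfig

    class MyModel(BaseModel):
        name: str
        age: int

    # Mode.ANY forces the model to call one of the allowed functions;
    # Mode.NONE would disable function calling entirely.
    model = ChatVertexAI(model_name="gemini-1.5-pro-preview-0409").bind(
        functions=[MyModel],
        tool_config={
            "function_calling_config": {
                "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
                "allowed_function_names": ["MyModel"],
            }
        },
    )
    response = model.invoke(
        [HumanMessage(content="My name is Erick and I am 27 years old")]
    )
    # Expect empty text content plus a "function_call" entry in
    # response.additional_kwargs containing the extracted arguments.
    print(response.additional_kwargs.get("function_call"))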