Skip to content

Commit

Permalink
added tool config (#135)
Browse files Browse the repository at this point in the history
* tool-config

---------

Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
  • Loading branch information
alx13 and lkuligin authored Apr 11, 2024
1 parent 5523121 commit 09978f1
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 53 deletions.
6 changes: 5 additions & 1 deletion libs/vertexai/langchain_google_vertexai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from langchain_google_vertexai.chains import create_structured_runnable
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
from langchain_google_vertexai.functions_utils import (
PydanticFunctionsOutputParser,
ToolConfig,
)
from langchain_google_vertexai.gemma import (
GemmaChatLocalHF,
GemmaChatLocalKaggle,
Expand Down Expand Up @@ -42,6 +45,7 @@
"HarmBlockThreshold",
"HarmCategory",
"PydanticFunctionsOutputParser",
"ToolConfig",
"create_structured_runnable",
"VertexAIImageCaptioning",
"VertexAIImageCaptioningChat",
Expand Down
88 changes: 48 additions & 40 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
is_gemini_model,
)
from langchain_google_vertexai.functions_utils import (
_format_tool_config,
_format_tools_to_vertex_tool,
)

Expand Down Expand Up @@ -164,17 +165,22 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
convert_system_message_to_human_content = None
system_instruction = None
for i, message in enumerate(history):
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
if isinstance(message, SystemMessage):
if system_instruction is not None:
raise ValueError(
"Detected more than one SystemMessage in the list of messages."
"Gemini APIs support the insertion of only SystemMessage."
"Gemini APIs support the insertion of only one SystemMessage."
)
else:
if convert_system_message_to_human:
logger.warning(
"gemini models released from April 2024 support"
"SystemMessages natively. For best performances,"
"when working with these models,"
"set convert_system_message_to_human to False"
)
convert_system_message_to_human_content = message
continue
system_instruction = Content(
role="user", parts=_convert_to_parts(message)
)
Expand All @@ -184,11 +190,6 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
and isinstance(message, SystemMessage)
and convert_system_message_to_human
):
logger.warning(
"gemini models released from April 2024 support SystemMessages"
"natively. For best performances, when working with these models,"
"set convert_system_message_to_human to False"
)
convert_system_message_to_human_content = message
continue
elif isinstance(message, AIMessage):
Expand Down Expand Up @@ -299,12 +300,10 @@ def _get_client_with_sys_instruction(
client: GenerativeModel,
system_instruction: Content,
model_name: str,
safety_settings: Optional[Dict] = None,
):
if client._system_instruction != system_instruction:
client = GenerativeModel(
model_name=model_name,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
return client
Expand Down Expand Up @@ -472,24 +471,27 @@ def _generate(
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
self.client = _get_client_with_sys_instruction(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)

# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
raw_tool_config = (
params.pop("tool_config") if "tool_config" in params else None
)
tool_config = (
_format_tool_config(raw_tool_config) if raw_tool_config else None
)
with telemetry.tool_context_manager(self._user_agent):
response = chat.send_message(
message,
response = self.client.generate_content(
history_gemini,
generation_config=params,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
)
generations = [
Expand Down Expand Up @@ -562,24 +564,27 @@ async def _agenerate(
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()

self.client = _get_client_with_sys_instruction(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
raw_tool_config = (
params.pop("tool_config") if "tool_config" in params else None
)
tool_config = (
_format_tool_config(raw_tool_config) if raw_tool_config else None
)
with telemetry.tool_context_manager(self._user_agent):
response = await chat.send_message_async(
message,
response = await self.client.generate_content_async(
history_gemini,
generation_config=params,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
)
generations = [
Expand Down Expand Up @@ -630,25 +635,28 @@ def _stream(
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
self.client = _get_client_with_sys_instruction(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
raw_tool_config = (
params.pop("tool_config") if "tool_config" in params else None
)
tool_config = (
_format_tool_config(raw_tool_config) if raw_tool_config else None
)
with telemetry.tool_context_manager(self._user_agent):
responses = chat.send_message(
message,
responses = self.client.generate_content(
history_gemini,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
)
for response in responses:
message = _parse_response_candidate(
Expand All @@ -659,7 +667,7 @@ def _stream(
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
)
if run_manager:
if run_manager and isinstance(message.content, str):
run_manager.on_llm_new_token(message.content)
if isinstance(message, AIMessageChunk):
yield ChatGenerationChunk(
Expand Down Expand Up @@ -716,27 +724,27 @@ async def _astream(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
raw_tool_config = params.pop("tool_config") if "tool_config" in params else None
tool_config = _format_tool_config(raw_tool_config) if raw_tool_config else None
with telemetry.tool_context_manager(self._user_agent):
async for chunk in await chat.send_message_async(
message,
async for chunk in await self.client.generate_content_async(
history_gemini,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
):
message = _parse_response_candidate(chunk.candidates[0], streaming=True)
generation_info = get_generation_info(
chunk.candidates[0],
self._is_gemini_model,
usage_metadata=chunk.to_dict().get("usage_metadata"),
)
if run_manager:
if run_manager and isinstance(message.content, str):
await run_manager.on_llm_new_token(message.content)
if isinstance(message, AIMessageChunk):
yield ChatGenerationChunk(
Expand Down
15 changes: 15 additions & 0 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
Tool as VertexTool,
)

# FIXME: vertexai is not exporting ToolConfig
from vertexai.generative_models._generative_models import ( # type: ignore
ToolConfig,
)


def _format_pydantic_to_vertex_function(
pydantic_model: Type[BaseModel],
Expand Down Expand Up @@ -75,6 +80,16 @@ def _format_tools_to_vertex_tool(
return [VertexTool(function_declarations=function_declarations)]


def _format_tool_config(tool_config: Dict[str, Any]) -> Union[ToolConfig, None]:
    """Build a Vertex AI ``ToolConfig`` from a plain dictionary.

    Args:
        tool_config: Mapping that may carry a ``"function_calling_config"``
            entry. The items of that entry are forwarded as keyword
            arguments to ``ToolConfig.FunctionCallingConfig``.

    Returns:
        A ``ToolConfig`` wrapping the function-calling configuration, or
        ``None`` when the ``"function_calling_config"`` key is absent.
    """
    if "function_calling_config" in tool_config:
        calling_kwargs = tool_config["function_calling_config"]
        calling_config = ToolConfig.FunctionCallingConfig(**calling_kwargs)
        return ToolConfig(function_calling_config=calling_config)
    # Nothing to translate: callers receive None and send no tool config.
    return None


class ParametersSchema(BaseModel):
"""
This is a schema of currently supported definitions in function calling.
Expand Down
72 changes: 71 additions & 1 deletion libs/vertexai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import tool

from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory
from langchain_google_vertexai import (
ChatVertexAI,
HarmBlockThreshold,
HarmCategory,
ToolConfig,
)

model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]

Expand Down Expand Up @@ -338,3 +343,68 @@ def my_tool(name: str, age: int) -> None:
tool_call_chunk = gathered.tool_call_chunks[0]
assert tool_call_chunk["name"] == "my_tool"
assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'


@pytest.mark.release
def test_chat_vertexai_gemini_function_calling_tool_config_any() -> None:
    """Check that function-calling mode ``ANY`` yields a call to ``MyModel``.

    The model is bound with ``MyModel`` as its only function and a tool
    config that restricts allowed function names to ``"MyModel"``. The
    response must have empty text content and a ``function_call`` whose
    arguments decode to the name and age given in the prompt.
    """

    class MyModel(BaseModel):
        name: str
        age: int

    safety_settings = {
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
    }
    forced_call_config = {
        "function_calling_config": {
            "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
            "allowed_function_names": ["MyModel"],
        }
    }
    llm = ChatVertexAI(
        model_name="gemini-1.5-pro-preview-0409", safety_settings=safety_settings
    )
    bound_llm = llm.bind(functions=[MyModel], tool_config=forced_call_config)

    prompt = HumanMessage(content="My name is Erick and I am 27 years old")
    result = bound_llm.invoke([prompt])

    # With mode ANY the reply carries no text, only the function call.
    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    assert result.content == ""

    call = result.additional_kwargs.get("function_call")
    assert call
    assert call["name"] == "MyModel"

    raw_arguments = call.get("arguments")
    assert raw_arguments
    decoded = json.loads(raw_arguments)
    expected = {
        "name": "Erick",
        "age": 27.0,
    }
    assert decoded == expected


@pytest.mark.release
def test_chat_vertexai_gemini_function_calling_tool_config_none() -> None:
    """Check that function-calling mode ``NONE`` yields plain text only.

    The model is bound with ``MyModel`` as a function, but the tool config
    sets the mode to ``NONE``. The response must carry non-empty text
    content and no ``function_call`` entry in ``additional_kwargs``.
    """

    class MyModel(BaseModel):
        name: str
        age: int

    safety_settings = {
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
    }
    no_call_config = {
        "function_calling_config": {
            "mode": ToolConfig.FunctionCallingConfig.Mode.NONE,
        }
    }
    llm = ChatVertexAI(
        model_name="gemini-1.5-pro-preview-0409", safety_settings=safety_settings
    )
    bound_llm = llm.bind(functions=[MyModel], tool_config=no_call_config)

    prompt = HumanMessage(content="My name is Erick and I am 27 years old")
    result = bound_llm.invoke([prompt])

    # With mode NONE the reply is text and no function call is attached.
    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    assert result.content != ""
    assert result.additional_kwargs.get("function_call") is None
24 changes: 13 additions & 11 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@

import pytest
from google.cloud.aiplatform_v1beta1.types import (
Content,
Content as Content,
)
from google.cloud.aiplatform_v1beta1.types import (
FunctionCall,
FunctionResponse,
Part,
)
from google.cloud.aiplatform_v1beta1.types import (
Part as Part,
)
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
Expand Down Expand Up @@ -176,7 +180,7 @@ def test_parse_history_gemini_converted_message() -> None:
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
system_instructions, history = _parse_chat_history_gemini(
_, history = _parse_chat_history_gemini(
messages, convert_system_message_to_human=True
)
assert len(history) == 3
Expand Down Expand Up @@ -314,18 +318,16 @@ def test_default_params_gemini() -> None:
citation_metadata=None,
)
]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message

mock_generate_content = MagicMock(return_value=mock_response)
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
mock_model.generate_content = mock_generate_content
gm.return_value = mock_model

model = ChatVertexAI(model_name="gemini-pro")
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(history=[])
_ = model.invoke([message])
mock_generate_content.assert_called_once()
assert mock_generate_content.call_args.args[0][0].parts[0].text == user_prompt


@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit 09978f1

Please sign in to comment.