added tool config #135

Merged · 2 commits · Apr 11, 2024
libs/vertexai/langchain_google_vertexai/__init__.py (6 changes: 5 additions & 1 deletion)

@@ -2,7 +2,10 @@
from langchain_google_vertexai.chains import create_structured_runnable
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
from langchain_google_vertexai.functions_utils import (
PydanticFunctionsOutputParser,
ToolConfig,
)
from langchain_google_vertexai.gemma import (
GemmaChatLocalHF,
GemmaChatLocalKaggle,
@@ -42,6 +45,7 @@
"HarmBlockThreshold",
"HarmCategory",
"PydanticFunctionsOutputParser",
"ToolConfig",
"create_structured_runnable",
"VertexAIImageCaptioning",
"VertexAIImageCaptioningChat",
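With this hunk, ToolConfig is re-exported from the package root. A minimal import sketch (ANY and NONE are the modes exercised by the integration tests later in this PR; treating AUTO as the SDK default is an assumption, not something this diff shows):

from langchain_google_vertexai import ToolConfig

# Nested enum used to steer Gemini function calling; ANY and NONE are
# covered by the tests below, AUTO is the assumed default mode.
mode = ToolConfig.FunctionCallingConfig.Mode.ANY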
libs/vertexai/langchain_google_vertexai/chat_models.py (88 changes: 48 additions & 40 deletions)

@@ -88,6 +88,7 @@
is_gemini_model,
)
from langchain_google_vertexai.functions_utils import (
_format_tool_config,
_format_tools_to_vertex_tool,
)

@@ -164,17 +165,22 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
convert_system_message_to_human_content = None
system_instruction = None
for i, message in enumerate(history):
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
if isinstance(message, SystemMessage):
if system_instruction is not None:
raise ValueError(
"Detected more than one SystemMessage in the list of messages."
"Gemini APIs support the insertion of only SystemMessage."
"Gemini APIs support the insertion of only one SystemMessage."
)
else:
if convert_system_message_to_human:
logger.warning(
"gemini models released from April 2024 support"
"SystemMessages natively. For best performances,"
"when working with these models,"
"set convert_system_message_to_human to False"
)
convert_system_message_to_human_content = message
continue
system_instruction = Content(
role="user", parts=_convert_to_parts(message)
)
@@ -184,11 +190,6 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
and isinstance(message, SystemMessage)
and convert_system_message_to_human
):
logger.warning(
"gemini models released from April 2024 support SystemMessages"
"natively. For best performances, when working with these models,"
"set convert_system_message_to_human to False"
)
convert_system_message_to_human_content = message
continue
elif isinstance(message, AIMessage):
@@ -299,12 +300,10 @@ def _get_client_with_sys_instruction(
client: GenerativeModel,
system_instruction: Content,
model_name: str,
safety_settings: Optional[Dict] = None,
):
if client._system_instruction != system_instruction:
client = GenerativeModel(
model_name=model_name,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
return client
@@ -472,24 +471,27 @@ def _generate(
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
self.client = _get_client_with_sys_instruction(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)

# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
raw_tool_config = (
params.pop("tool_config") if "tool_config" in params else None
)
tool_config = (
_format_tool_config(raw_tool_config) if raw_tool_config else None
)
with telemetry.tool_context_manager(self._user_agent):
response = chat.send_message(
message,
response = self.client.generate_content(
history_gemini,
generation_config=params,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
)
generations = [
@@ -562,24 +564,27 @@ async def _agenerate(
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()

self.client = _get_client_with_sys_instruction(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
raw_tool_config = (
params.pop("tool_config") if "tool_config" in params else None
)
tool_config = (
_format_tool_config(raw_tool_config) if raw_tool_config else None
)
with telemetry.tool_context_manager(self._user_agent):
response = await chat.send_message_async(
message,
response = await self.client.generate_content_async(
history_gemini,
generation_config=params,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
)
generations = [
@@ -630,25 +635,28 @@ def _stream(
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
self.client = _get_client_with_sys_instruction(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
# set param to `functions` until core tool/function calling implemented
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
raw_tool_config = (
params.pop("tool_config") if "tool_config" in params else None
)
tool_config = (
_format_tool_config(raw_tool_config) if raw_tool_config else None
)
with telemetry.tool_context_manager(self._user_agent):
responses = chat.send_message(
message,
responses = self.client.generate_content(
history_gemini,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
)
for response in responses:
message = _parse_response_candidate(
@@ -659,7 +667,7 @@ def _stream(
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
)
if run_manager:
if run_manager and isinstance(message.content, str):
run_manager.on_llm_new_token(message.content)
if isinstance(message, AIMessageChunk):
yield ChatGenerationChunk(
@@ -716,27 +724,27 @@ async def _astream(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
raw_tool_config = params.pop("tool_config") if "tool_config" in params else None
tool_config = _format_tool_config(raw_tool_config) if raw_tool_config else None
with telemetry.tool_context_manager(self._user_agent):
async for chunk in await chat.send_message_async(
message,
async for chunk in await self.client.generate_content_async(
history_gemini,
stream=True,
generation_config=params,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
safety_settings=safety_settings,
):
message = _parse_response_candidate(chunk.candidates[0], streaming=True)
generation_info = get_generation_info(
chunk.candidates[0],
self._is_gemini_model,
usage_metadata=chunk.to_dict().get("usage_metadata"),
)
if run_manager:
if run_manager and isinstance(message.content, str):
await run_manager.on_llm_new_token(message.content)
if isinstance(message, AIMessageChunk):
yield ChatGenerationChunk(
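Taken together, these hunks replace the start_chat()/send_message() flow with a single generate_content() (or generate_content_async()) call that forwards both tools and tool_config, which is why the history no longer needs its last message popped off. A hedged end-to-end sketch of the new path from the caller's side (the model name and schema are illustrative, mirroring the integration tests below):

from langchain_core.messages import HumanMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_google_vertexai import ChatVertexAI, ToolConfig


class MyModel(BaseModel):
    name: str
    age: int


# bind() stores "functions" and "tool_config" in params; _generate pops
# them, runs them through _format_tools_to_vertex_tool and
# _format_tool_config, and hands both to client.generate_content.
llm = ChatVertexAI(model_name="gemini-pro").bind(
    functions=[MyModel],
    tool_config={
        "function_calling_config": {
            "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
            "allowed_function_names": ["MyModel"],
        }
    },
)
response = llm.invoke([HumanMessage(content="My name is Erick and I am 27")])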
libs/vertexai/langchain_google_vertexai/functions_utils.py (15 changes: 15 additions & 0 deletions)

@@ -15,6 +15,11 @@
Tool as VertexTool,
)

# FIXME: vertexai is not exporting ToolConfig
from vertexai.generative_models._generative_models import ( # type: ignore
ToolConfig,
)


def _format_pydantic_to_vertex_function(
pydantic_model: Type[BaseModel],
@@ -75,6 +80,16 @@ def _format_tools_to_vertex_tool(
return [VertexTool(function_declarations=function_declarations)]


def _format_tool_config(tool_config: Dict[str, Any]) -> Union[ToolConfig, None]:
if "function_calling_config" not in tool_config:
return None
return ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
**tool_config["function_calling_config"]
)
)


class ParametersSchema(BaseModel):
"""
This is a schema of currently supported definitions in function calling.
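For reference, a sketch of the conversion _format_tool_config performs and its fallback behavior (an illustration written for this review, not code from the PR):

from langchain_google_vertexai.functions_utils import (
    ToolConfig,
    _format_tool_config,
)

# A dict carrying "function_calling_config" becomes a ToolConfig whose
# FunctionCallingConfig is built from that sub-dict...
cfg = _format_tool_config(
    {
        "function_calling_config": {
            "mode": ToolConfig.FunctionCallingConfig.Mode.NONE,
        }
    }
)
assert isinstance(cfg, ToolConfig)

# ...while a dict without that key is silently mapped to None.
assert _format_tool_config({}) is None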
libs/vertexai/tests/integration_tests/test_chat_models.py (72 changes: 71 additions & 1 deletion)

@@ -15,7 +15,12 @@
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import tool

from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory
from langchain_google_vertexai import (
ChatVertexAI,
HarmBlockThreshold,
HarmCategory,
ToolConfig,
)

model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]

@@ -338,3 +343,68 @@ def my_tool(name: str, age: int) -> None:
tool_call_chunk = gathered.tool_call_chunks[0]
assert tool_call_chunk["name"] == "my_tool"
assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'


@pytest.mark.release
def test_chat_vertexai_gemini_function_calling_tool_config_any() -> None:
class MyModel(BaseModel):
name: str
age: int

safety = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
}
model = ChatVertexAI(
model_name="gemini-1.5-pro-preview-0409", safety_settings=safety
).bind(
functions=[MyModel],
tool_config={
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["MyModel"],
}
},
)
message = HumanMessage(content="My name is Erick and I am 27 years old")
response = model.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == ""
function_call = response.additional_kwargs.get("function_call")
assert function_call
assert function_call["name"] == "MyModel"
arguments_str = function_call.get("arguments")
assert arguments_str
arguments = json.loads(arguments_str)
assert arguments == {
"name": "Erick",
"age": 27.0,
}


@pytest.mark.release
def test_chat_vertexai_gemini_function_calling_tool_config_none() -> None:
class MyModel(BaseModel):
name: str
age: int

safety = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
}
model = ChatVertexAI(
model_name="gemini-1.5-pro-preview-0409", safety_settings=safety
).bind(
functions=[MyModel],
tool_config={
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.NONE,
}
},
)
message = HumanMessage(content="My name is Erick and I am 27 years old")
response = model.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content != ""
function_call = response.additional_kwargs.get("function_call")
assert function_call is None
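Between them, the two tests pin down the contract: ANY forces a function_call (and empty text content), while NONE suppresses function calls entirely. A small hypothetical helper that mirrors the ANY-mode assertions (extract_forced_call is not part of the PR):

import json

from langchain_core.messages import AIMessage


def extract_forced_call(response: AIMessage) -> dict:
    # Assumes the model honored mode=ANY, so additional_kwargs
    # carries a function_call with JSON-encoded arguments.
    function_call = response.additional_kwargs["function_call"]
    return json.loads(function_call["arguments"])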
libs/vertexai/tests/unit_tests/test_chat_models.py (24 changes: 13 additions & 11 deletions)

@@ -7,10 +7,14 @@

import pytest
from google.cloud.aiplatform_v1beta1.types import (
Content,
Content as Content,
)
from google.cloud.aiplatform_v1beta1.types import (
FunctionCall,
FunctionResponse,
Part,
)
from google.cloud.aiplatform_v1beta1.types import (
Part as Part,
)
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
@@ -176,7 +180,7 @@ def test_parse_history_gemini_converted_message() -> None:
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
system_instructions, history = _parse_chat_history_gemini(
_, history = _parse_chat_history_gemini(
messages, convert_system_message_to_human=True
)
assert len(history) == 3
@@ -314,18 +318,16 @@ def test_default_params_gemini() -> None:
citation_metadata=None,
)
]
mock_chat = MagicMock()
mock_send_message = MagicMock(return_value=mock_response)
mock_chat.send_message = mock_send_message

mock_generate_content = MagicMock(return_value=mock_response)
mock_model = MagicMock()
mock_start_chat = MagicMock(return_value=mock_chat)
mock_model.start_chat = mock_start_chat
mock_model.generate_content = mock_generate_content
gm.return_value = mock_model

model = ChatVertexAI(model_name="gemini-pro")
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(history=[])
_ = model.invoke([message])
mock_generate_content.assert_called_once()
assert mock_generate_content.call_args.args[0][0].parts[0].text == user_prompt


@pytest.mark.parametrize(