update vertex to use tool call attribute on AIMessage #123

Merged · 8 commits · Apr 10, 2024
118 changes: 93 additions & 25 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -28,14 +28,18 @@
     BaseMessage,
     FunctionMessage,
     HumanMessage,
+    InvalidToolCall,
     SystemMessage,
+    ToolCall,
+    ToolCallChunk,
     ToolMessage,
 )
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_functions import (
     JsonOutputFunctionsParser,
     PydanticOutputFunctionsParser,
 )
+from langchain_core.output_parsers.openai_tools import parse_tool_calls
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, root_validator
 from langchain_core.runnables import Runnable, RunnablePassthrough
@@ -284,7 +288,9 @@ def _get_client_with_sys_instruction(
     return client


-def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
+def _parse_response_candidate(
+    response_candidate: "Candidate", streaming: bool = False
+) -> AIMessage:
     try:
         content = response_candidate.text
     except AttributeError:
@@ -302,6 +308,52 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
             {k: function_call_args_dict[k] for k in function_call_args_dict}
         )
         additional_kwargs["function_call"] = function_call
+        if streaming:
+            tool_call_chunks = [
+                ToolCallChunk(
+                    name=function_call.get("name"),
+                    args=function_call.get("arguments"),
+                    id=function_call.get("id"),
+                    index=function_call.get("index"),
+                )
+            ]
+            return AIMessageChunk(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                tool_call_chunks=tool_call_chunks,
+            )
+        else:
+            tool_calls = []
+            invalid_tool_calls = []
+            try:
+                tool_calls_dicts = parse_tool_calls(
+                    [{"function": function_call}],
+                    return_id=False,
+                )
+                tool_calls = [
+                    ToolCall(
+                        name=tool_call["name"],
+                        args=tool_call["args"],
+                        id=tool_call.get("id"),
+                    )
+                    for tool_call in tool_calls_dicts
+                ]
+            except Exception as e:
+                invalid_tool_calls = [
+                    InvalidToolCall(
+                        name=function_call.get("name"),
+                        args=function_call.get("arguments"),
+                        id=function_call.get("id"),
+                        error=str(e),
+                    )
+                ]
+
+            return AIMessage(
+                content=content,
+                additional_kwargs=additional_kwargs,
+                tool_calls=tool_calls,
+                invalid_tool_calls=invalid_tool_calls,
+            )
     return AIMessage(content=content, additional_kwargs=additional_kwargs)

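For readers skimming the diff: the non-streaming branch funnels Gemini's function call through langchain-core's OpenAI-style tool-call parser. A minimal sketch of that conversion, using an illustrative payload (the function_call dict below is hypothetical, not taken from a real response):

    from langchain_core.output_parsers.openai_tools import parse_tool_calls

    # An OpenAI-style function call dict, the shape the parser expects:
    # "arguments" is a JSON-encoded string, not a dict.
    function_call = {"name": "MyModel", "arguments": '{"age": 27.0, "name": "Erick"}'}

    # parse_tool_calls decodes "arguments"; the diff then wraps each result in
    # a typed ToolCall on the AIMessage. If decoding fails, the except branch
    # above records an InvalidToolCall carrying the error string instead.
    parsed = parse_tool_calls([{"function": function_call}], return_id=False)
    print(parsed[0]["name"], parsed[0]["args"])  # MyModel {'age': 27.0, 'name': 'Erick'}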
@@ -577,20 +629,29 @@ def _stream(
                 tools=tools,
             )
             for response in responses:
-                message = _parse_response_candidate(response.candidates[0])
+                message = _parse_response_candidate(
+                    response.candidates[0], streaming=True
+                )
+                generation_info = get_generation_info(
+                    response.candidates[0],
+                    self._is_gemini_model,
+                    usage_metadata=response.to_dict().get("usage_metadata"),
+                )
                 if run_manager:
                     run_manager.on_llm_new_token(message.content)
-                yield ChatGenerationChunk(
-                    message=AIMessageChunk(
-                        content=message.content,
-                        additional_kwargs=message.additional_kwargs,
-                    ),
-                    generation_info=get_generation_info(
-                        response.candidates[0],
-                        self._is_gemini_model,
-                        usage_metadata=response.to_dict().get("usage_metadata"),
-                    ),
-                )
+                if isinstance(message, AIMessageChunk):
+                    yield ChatGenerationChunk(
+                        message=message,
+                        generation_info=generation_info,
+                    )
+                else:
+                    yield ChatGenerationChunk(
+                        message=AIMessageChunk(
+                            content=message.content,
+                            additional_kwargs=message.additional_kwargs,
+                        ),
+                        generation_info=generation_info,
+                    )
         else:
             question = _get_question(messages)
             history = _parse_chat_history(messages[:-1])
@@ -647,20 +708,27 @@ async def _astream(
                 safety_settings=safety_settings,
                 tools=tools,
             ):
-                message = _parse_response_candidate(chunk.candidates[0])
+                message = _parse_response_candidate(chunk.candidates[0], streaming=True)
+                generation_info = get_generation_info(
+                    chunk.candidates[0],
+                    self._is_gemini_model,
+                    usage_metadata=chunk.to_dict().get("usage_metadata"),
+                )
                 if run_manager:
                     await run_manager.on_llm_new_token(message.content)
-                yield ChatGenerationChunk(
-                    message=AIMessageChunk(
-                        content=message.content,
-                        additional_kwargs=message.additional_kwargs,
-                    ),
-                    generation_info=get_generation_info(
-                        chunk.candidates[0],
-                        self._is_gemini_model,
-                        usage_metadata=chunk.to_dict().get("usage_metadata"),
-                    ),
-                )
+                if isinstance(message, AIMessageChunk):
+                    yield ChatGenerationChunk(
+                        message=message,
+                        generation_info=generation_info,
+                    )
+                else:
+                    yield ChatGenerationChunk(
+                        message=AIMessageChunk(
+                            content=message.content,
+                            additional_kwargs=message.additional_kwargs,
+                        ),
+                        generation_info=generation_info,
+                    )

     def with_structured_output(
         self,
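Taken together, callers no longer need to hand-decode the raw function call stashed in additional_kwargs. A hedged before/after sketch (the response object here is assumed to come from a ChatVertexAI call with a bound tool, as in the integration test further down):

    import json

    # Before this change: decode the raw Gemini function call by hand.
    raw = response.additional_kwargs.get("function_call", {})
    args = json.loads(raw.get("arguments", "{}"))

    # After: parsed calls are first-class on AIMessage. ToolCall is dict-like,
    # so fields are read by key; malformed calls land in invalid_tool_calls.
    for tool_call in response.tool_calls:
        print(tool_call["name"], tool_call["args"])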
5 changes: 3 additions & 2 deletions libs/vertexai/poetry.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions libs/vertexai/pyproject.toml
@@ -1,6 +1,7 @@
 [tool.poetry]
 name = "langchain-google-vertexai"
-version = "0.1.2"
+
+version = "0.1.3-rc.1"
 description = "An integration package connecting Google VertexAI and LangChain"
 authors = []
 readme = "README.md"
@@ -12,7 +13,7 @@ license = "MIT"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-langchain-core = ">=0.1.27,<0.2"
+langchain-core = {version ="^0.1.42-rc.1", allow-prereleases=true}
 google-cloud-aiplatform = "^1.47.0"
 google-cloud-storage = "^2.14.0"
 types-requests = "^2.31.0"
20 changes: 20 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -9,6 +9,8 @@
     AIMessageChunk,
     HumanMessage,
     SystemMessage,
+    ToolCall,
+    ToolCallChunk,
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.pydantic_v1 import BaseModel
@@ -287,3 +289,21 @@ class MyModel(BaseModel):
         "name": "Erick",
         "age": 27.0,
     }
+    assert response.tool_calls == [
+        ToolCall(name="MyModel", args={"age": 27.0, "name": "Erick"}, id=None)
+    ]
+
+    stream = model.stream([message])
+    first = True
+    for chunk in stream:
+        if first:
+            gathered = chunk
+            first = False
+        else:
+            gathered = gathered + chunk  # type: ignore
+    assert isinstance(gathered, AIMessageChunk)
+    assert gathered.tool_call_chunks == [
+        ToolCallChunk(
+            name="MyModel", args='{"age": 27.0, "name": "Erick"}', id=None, index=None
+        )
+    ]
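One detail worth flagging in these assertions: on the aggregated chunk, args is still the raw JSON string, whereas on a full AIMessage it is the decoded dict. A small sketch bridging the two, reusing the gathered chunk from the test above:

    import json

    # ToolCallChunk carries args as a JSON string; once the stream has been
    # fully gathered it can be decoded into the dict form ToolCall uses.
    chunk_args = gathered.tool_call_chunks[0]["args"]
    assert json.loads(chunk_args) == {"age": 27.0, "name": "Erick"}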