
Commit

Merge pull request #123 from langchain-ai/cc/tool_calls_msg
update vertex to use tool call attribute on AIMessage
efriis committed Apr 10, 2024
2 parents 8cac88a + c7cfdd2 commit 22c2056
Showing 4 changed files with 119 additions and 29 deletions.
118 changes: 93 additions & 25 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -28,14 +28,18 @@
BaseMessage,
FunctionMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
ToolCallChunk,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain_core.output_parsers.openai_tools import parse_tool_calls
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
@@ -284,7 +288,9 @@ def _get_client_with_sys_instruction(
return client


def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
def _parse_response_candidate(
response_candidate: "Candidate", streaming: bool = False
) -> AIMessage:
try:
content = response_candidate.text
except AttributeError:
@@ -302,6 +308,52 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
{k: function_call_args_dict[k] for k in function_call_args_dict}
)
additional_kwargs["function_call"] = function_call
if streaming:
tool_call_chunks = [
ToolCallChunk(
name=function_call.get("name"),
args=function_call.get("arguments"),
id=function_call.get("id"),
index=function_call.get("index"),
)
]
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
)
else:
tool_calls = []
invalid_tool_calls = []
try:
tool_calls_dicts = parse_tool_calls(
[{"function": function_call}],
return_id=False,
)
tool_calls = [
ToolCall(
name=tool_call["name"],
args=tool_call["args"],
id=tool_call.get("id"),
)
for tool_call in tool_calls_dicts
]
except Exception as e:
invalid_tool_calls = [
InvalidToolCall(
name=function_call.get("name"),
args=function_call.get("arguments"),
id=function_call.get("id"),
error=str(e),
)
]

return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
return AIMessage(content=content, additional_kwargs=additional_kwargs)
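
For reference, the non-streaming branch above can be exercised in isolation. A minimal sketch using the same langchain_core helpers the diff imports; the function_call payload here is invented for illustration:

from langchain_core.messages import AIMessage, ToolCall
from langchain_core.output_parsers.openai_tools import parse_tool_calls

# OpenAI-style function call dict, in the shape built from the Gemini candidate above.
function_call = {"name": "MyModel", "arguments": '{"name": "Erick", "age": 27.0}'}
parsed = parse_tool_calls([{"function": function_call}], return_id=False)
message = AIMessage(
    content="",
    additional_kwargs={"function_call": function_call},
    tool_calls=[
        ToolCall(name=tc["name"], args=tc["args"], id=tc.get("id")) for tc in parsed
    ],
)
# message.tool_calls -> [{"name": "MyModel", "args": {"name": "Erick", "age": 27.0}, "id": None}]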


@@ -577,20 +629,29 @@ def _stream(
tools=tools,
)
for response in responses:
message = _parse_response_candidate(response.candidates[0])
message = _parse_response_candidate(
response.candidates[0], streaming=True
)
generation_info = get_generation_info(
response.candidates[0],
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
)
if run_manager:
run_manager.on_llm_new_token(message.content)
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
),
generation_info=get_generation_info(
response.candidates[0],
self._is_gemini_model,
usage_metadata=response.to_dict().get("usage_metadata"),
),
)
if isinstance(message, AIMessageChunk):
yield ChatGenerationChunk(
message=message,
generation_info=generation_info,
)
else:
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
),
generation_info=generation_info,
)
else:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
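
Because _stream now yields AIMessageChunk directly, downstream consumers can keep summing chunks with "+" and the tool_call_chunks survive the merge. A small illustration outside the diff; the chunk values are made up:

from langchain_core.messages import AIMessageChunk, ToolCallChunk

first = AIMessageChunk(content="")
second = AIMessageChunk(
    content="",
    tool_call_chunks=[
        ToolCallChunk(name="MyModel", args='{"age": 27.0}', id=None, index=None)
    ],
)
merged = first + second
# merged is an AIMessageChunk whose tool_call_chunks holds the single chunk from `second`.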
@@ -647,20 +708,27 @@ async def _astream(
safety_settings=safety_settings,
tools=tools,
):
message = _parse_response_candidate(chunk.candidates[0])
message = _parse_response_candidate(chunk.candidates[0], streaming=True)
generation_info = get_generation_info(
chunk.candidates[0],
self._is_gemini_model,
usage_metadata=chunk.to_dict().get("usage_metadata"),
)
if run_manager:
await run_manager.on_llm_new_token(message.content)
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
),
generation_info=get_generation_info(
chunk.candidates[0],
self._is_gemini_model,
usage_metadata=chunk.to_dict().get("usage_metadata"),
),
)
if isinstance(message, AIMessageChunk):
yield ChatGenerationChunk(
message=message,
generation_info=generation_info,
)
else:
yield ChatGenerationChunk(
message=AIMessageChunk(
content=message.content,
additional_kwargs=message.additional_kwargs,
),
generation_info=generation_info,
)

def with_structured_output(
self,
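The async path mirrors the sync one, so a consumer can accumulate the chunks under async for. A hypothetical helper sketch (model construction and tool binding are elided and assumed to be set up elsewhere):

async def collect_tool_calls(model, messages):
    # Accumulate AIMessageChunk objects; "+" merges content and tool_call_chunks.
    gathered = None
    async for chunk in model.astream(messages):
        gathered = chunk if gathered is None else gathered + chunk
    return gathered
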
5 changes: 3 additions & 2 deletions libs/vertexai/poetry.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions libs/vertexai/pyproject.toml
@@ -1,6 +1,7 @@
[tool.poetry]
name = "langchain-google-vertexai"
version = "0.1.2"

version = "0.1.3-rc.1"
description = "An integration package connecting Google VertexAI and LangChain"
authors = []
readme = "README.md"
@@ -12,7 +13,7 @@ license = "MIT"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.1.27,<0.2"
langchain-core = {version ="^0.1.42-rc.1", allow-prereleases=true}
google-cloud-aiplatform = "^1.47.0"
google-cloud-storage = "^2.14.0"
types-requests = "^2.31.0"
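The bump to a langchain-core prerelease presumably supplies the ToolCall, ToolCallChunk, and InvalidToolCall message types, the AIMessage tool_calls attribute, and the parse_tool_calls helper that chat_models.py now imports.
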
20 changes: 20 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -9,6 +9,8 @@
AIMessageChunk,
HumanMessage,
SystemMessage,
ToolCall,
ToolCallChunk,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.pydantic_v1 import BaseModel
@@ -287,3 +289,21 @@ class MyModel(BaseModel):
"name": "Erick",
"age": 27.0,
}
assert response.tool_calls == [
ToolCall(name="MyModel", args={"age": 27.0, "name": "Erick"}, id=None)
]

stream = model.stream([message])
first = True
for chunk in stream:
if first:
gathered = chunk
first = False
else:
gathered = gathered + chunk # type: ignore
assert isinstance(gathered, AIMessageChunk)
assert gathered.tool_call_chunks == [
ToolCallChunk(
name="MyModel", args='{"age": 27.0, "name": "Erick"}', id=None, index=None
)
]
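
Beyond the assertions in the diff, the accumulated args string is plain JSON, so a consumer could decode it back into a dict. A small follow-on sketch (not part of the test):

import json

args = json.loads(gathered.tool_call_chunks[0]["args"])
assert args == {"age": 27.0, "name": "Erick"}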
