community[patch]: Extend Baichuan model with tool support (langchain-ai#24529)

**Description:** Extends the `ChatBaichuan` chat model in `baichuan.py` with tool support: updates the module imports, adds tool-call handling to the message-conversion helpers, implements `bind_tools`, and adds related unit tests. Together these changes let the model work with tool-like objects (dictionaries, Pydantic models, callables, and BaseTools).
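For readers, a minimal usage sketch of the new `bind_tools` API (it mirrors the docstring example added in this diff; the model name and a configured `BAICHUAN_API_KEY` are assumptions):

```python
from langchain_community.chat_models import ChatBaichuan
from langchain_core.pydantic_v1 import BaseModel, Field


class get_current_weather(BaseModel):
    """Get current weather."""

    location: str = Field(description="City or province, such as Shanghai")


# bind_tools converts the Pydantic class to an OpenAI-style tool schema
# and forwards it to the Baichuan API on every call.
llm_with_tools = ChatBaichuan(model="Baichuan3-Turbo").bind_tools([get_current_weather])

# Tool invocations requested by the model are exposed on AIMessage.tool_calls.
msg = llm_with_tools.invoke("How is the weather in Shanghai today?")
print(msg.tool_calls)
```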

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
nobbbbby and ccurme committed Jul 26, 2024
1 parent ee399e3 commit 4f3b4fc
Showing 2 changed files with 138 additions and 10 deletions.
135 changes: 125 additions & 10 deletions libs/community/langchain_community/chat_models/baichuan.py
@@ -1,13 +1,26 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Type
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)

import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
@@ -24,14 +37,27 @@
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_community.chat_models.llamacpp import (
_lc_invalid_tool_call_to_openai_tool_call,
_lc_tool_call_to_openai_tool_call,
)

logger = logging.getLogger(__name__)

@@ -40,14 +66,33 @@

def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
content = message.content
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
message_dict = {"role": message.role, "content": content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
message_dict = {"role": "user", "content": content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
message_dict = {"role": "assistant", "content": content}
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]

elif message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_openai_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"tool_call_id": message.tool_call_id,
"content": content,
"name": message.name or message.additional_kwargs.get("name"),
}

elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
message_dict = {"role": "system", "content": content}
else:
raise TypeError(f"Got unknown type {message}")

@@ -56,14 +101,43 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:

def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
content = _dict.get("content", "")
if role == "user":
return HumanMessage(content=_dict["content"])
return HumanMessage(content=content)
elif role == "assistant":
return AIMessage(content=_dict.get("content", "") or "")
tool_calls = []
invalid_tool_calls = []
additional_kwargs = {}

if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)

return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls, # type: ignore[arg-type]
invalid_tool_calls=invalid_tool_calls,
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=content,
tool_call_id=_dict.get("tool_call_id"), # type: ignore[arg-type]
additional_kwargs=additional_kwargs,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
return SystemMessage(content=content)
else:
return ChatMessage(content=_dict["content"], role=role)
return ChatMessage(content=content, role=role)


def _convert_delta_to_message_chunk(
@@ -226,6 +300,24 @@ class ChatBaichuan(BaseChatModel):
},
id='run-952509ed-9154-4ff9-b187-e616d7ddfbba-0'
)
Tool calling:
.. code-block:: python
class get_current_weather(BaseModel):
'''Get current weather.'''
location: str = Field('City or province, such as Shanghai')
llm_with_tools = ChatBaichuan(model='Baichuan3-Turbo').bind_tools([get_current_weather])
llm_with_tools.invoke('How is the weather today?')
.. code-block:: python
[{'name': 'get_current_weather',
'args': {'location': 'New York'},
'id': '3951017OF8doB0A',
'type': 'tool_call'}]
Response metadata
.. code-block:: python
@@ -486,6 +578,7 @@ def _create_payload_parameters( # type: ignore[no-untyped-def]
model = parameters.pop("model")
with_search_enhance = parameters.pop("with_search_enhance", False)
stream = parameters.pop("stream", False)
tools = parameters.pop("tools", [])

payload = {
"model": model,
@@ -495,7 +588,9 @@ def _create_payload_parameters( # type: ignore[no-untyped-def]
"temperature": temperature,
"with_search_enhance": with_search_enhance,
"stream": stream,
"tools": tools,
}

return payload

def _create_headers_parameters(self, **kwargs) -> Dict[str, Any]: # type: ignore[no-untyped-def]
@@ -526,3 +621,23 @@ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
@property
def _llm_type(self) -> str:
return "baichuan-chat"

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""

formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
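For context on what `bind_tools` sends to the API: `convert_to_openai_tool` produces an OpenAI-style function schema, and `_create_payload_parameters` forwards the bound list under the `tools` key of the request payload. A rough sketch of the conversion (the exact JSON-schema details may vary with the tool definition):

```python
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_tool


class get_current_weather(BaseModel):
    """Get current weather."""

    location: str = Field(description="City or province, such as Shanghai")


print(convert_to_openai_tool(get_current_weather))
# Approximately:
# {
#     "type": "function",
#     "function": {
#         "name": "get_current_weather",
#         "description": "Get current weather.",
#         "parameters": {
#             "type": "object",
#             "properties": {
#                 "location": {"type": "string", "description": "City or province, such as Shanghai"}
#             },
#             "required": ["location"],
#         },
#     },
# }
```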
13 changes: 13 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_baichuan.py
@@ -8,6 +8,7 @@
HumanMessage,
HumanMessageChunk,
SystemMessage,
ToolMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
@@ -58,6 +59,18 @@ def test__convert_message_to_dict_system() -> None:
assert result == expected_output


def test__convert_message_to_dict_tool() -> None:
message = ToolMessage(name="foo", content="bar", tool_call_id="abc123")
result = _convert_message_to_dict(message)
expected_output = {
"name": "foo",
"content": "bar",
"tool_call_id": "abc123",
"role": "tool",
}
assert result == expected_output


def test__convert_message_to_dict_function() -> None:
message = FunctionMessage(name="foo", content="bar")
with pytest.raises(TypeError) as e:
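Not part of this diff, but as an illustration of the new assistant-side parsing in `_convert_dict_to_message`, a complementary test could look like the sketch below (it assumes `_convert_dict_to_message` and `AIMessage` are imported alongside the existing helpers in this test module):

```python
def test__convert_dict_to_message_with_tool_calls() -> None:
    # An OpenAI/Baichuan-style assistant response carrying a single tool call.
    raw = {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "abc123",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Shanghai"}',
                },
            }
        ],
    }
    result = _convert_dict_to_message(raw)
    assert isinstance(result, AIMessage)
    assert len(result.tool_calls) == 1
    tool_call = result.tool_calls[0]
    assert tool_call["name"] == "get_current_weather"
    assert tool_call["args"] == {"location": "Shanghai"}
    assert tool_call["id"] == "abc123"
```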
