core: bind_tools interface on basechatmodel (#20360)
efriis committed Apr 12, 2024
1 parent e6806a0 commit 2928237
Showing 3 changed files with 129 additions and 114 deletions.
57 changes: 36 additions & 21 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -9,11 +9,13 @@
     TYPE_CHECKING,
     Any,
     AsyncIterator,
+    Callable,
     Dict,
     Iterator,
     List,
     Optional,
     Sequence,
+    Type,
     Union,
     cast,
 )
@@ -53,7 +55,9 @@
 from langchain_core.tracers.log_stream import LogStreamCallbackHandler
 
 if TYPE_CHECKING:
-    from langchain_core.runnables import RunnableConfig
+    from langchain_core.pydantic_v1 import BaseModel
+    from langchain_core.runnables import Runnable, RunnableConfig
+    from langchain_core.tools import BaseTool
 
 
 def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
@@ -599,16 +603,18 @@ def _generate_with_cache(
         # astream_events() or astream_log(). Bail out if _stream not implemented
         if type(self)._stream != BaseChatModel._stream and kwargs.pop(
             "stream",
-            next(
-                (
-                    True
-                    for h in run_manager.handlers
-                    if isinstance(h, LogStreamCallbackHandler)
-                ),
-                False,
-            )
-            if run_manager
-            else False,
+            (
+                next(
+                    (
+                        True
+                        for h in run_manager.handlers
+                        if isinstance(h, LogStreamCallbackHandler)
+                    ),
+                    False,
+                )
+                if run_manager
+                else False
+            ),
         ):
             chunks: List[ChatGenerationChunk] = []
             for chunk in self._stream(messages, stop=stop, **kwargs):
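Both the sync hunk above and the async hunk below are formatting-only: the conditional expression used as the default for `kwargs.pop` is wrapped in parentheses (Black style), with no behavioral change. The default itself relies on the `next(generator, fallback)` idiom to check whether any handler is a `LogStreamCallbackHandler`. A minimal, self-contained sketch of that idiom, using hypothetical stand-ins (`Handler`, `LogStreamHandler`, and `should_stream` are illustrations, not LangChain code):

```python
from typing import List

class Handler:
    """Stand-in for a generic callback handler."""

class LogStreamHandler(Handler):
    """Stand-in for LogStreamCallbackHandler."""

def should_stream(handlers: List[Handler]) -> bool:
    # next() returns the first value the generator yields; if no handler
    # matches, the generator is empty and the fallback (False) is returned
    # instead of raising StopIteration.
    return next(
        (True for h in handlers if isinstance(h, LogStreamHandler)),
        False,
    )

print(should_stream([Handler()]))                      # False
print(should_stream([Handler(), LogStreamHandler()]))  # True
```

An equivalent and arguably clearer spelling of the same check is `any(isinstance(h, LogStreamHandler) for h in handlers)`.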
@@ -680,16 +686,18 @@ async def _agenerate_with_cache(
             or type(self)._stream != BaseChatModel._stream
         ) and kwargs.pop(
             "stream",
-            next(
-                (
-                    True
-                    for h in run_manager.handlers
-                    if isinstance(h, LogStreamCallbackHandler)
-                ),
-                False,
-            )
-            if run_manager
-            else False,
+            (
+                next(
+                    (
+                        True
+                        for h in run_manager.handlers
+                        if isinstance(h, LogStreamCallbackHandler)
+                    ),
+                    False,
+                )
+                if run_manager
+                else False
+            ),
         ):
             chunks: List[ChatGenerationChunk] = []
             async for chunk in self._astream(messages, stop=stop, **kwargs):
@@ -896,6 +904,13 @@ def dict(self, **kwargs: Any) -> Dict:
         starter_dict["_type"] = self._llm_type
         return starter_dict
 
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        raise NotImplementedError()
+
 
 class SimpleChatModel(BaseChatModel):
     """A simplified implementation for a chat model to inherit from."""
@@ -32,7 +32,7 @@ def chat_model_params(self) -> dict:
     def chat_model_has_tool_calling(
         self, chat_model_class: Type[BaseChatModel]
     ) -> bool:
-        return hasattr(chat_model_class, "bind_tools")
+        return chat_model_class.bind_tools is not BaseChatModel.bind_tools
 
     @pytest.fixture
     def chat_model_has_structured_output(
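This fixture change is the flip side of defining `bind_tools` on the base class: `hasattr(chat_model_class, "bind_tools")` is now `True` for every subclass, so the test instead checks whether the class actually overrides the base method. A minimal sketch of that override-detection idiom, using hypothetical stand-ins (`Base`, `NoTools`, `WithTools`):

```python
class Base:
    def bind_tools(self):
        # Default stub, mirroring BaseChatModel.bind_tools.
        raise NotImplementedError()

class NoTools(Base):
    """Inherits the stub unchanged."""

class WithTools(Base):
    def bind_tools(self):
        return "bound"

def has_tool_calling(cls: type) -> bool:
    # hasattr(cls, "bind_tools") is True for every subclass now that the
    # base defines the method, so compare method identity instead.
    return cls.bind_tools is not Base.bind_tools

print(has_tool_calling(NoTools))    # False
print(has_tool_calling(WithTools))  # True
```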