diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index b41aa2371707dd..4f5f126793e38e 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -2,7 +2,7 @@ import inspect import warnings from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Any, Coroutine, Dict, List, Optional from pydantic import Extra, Field, root_validator @@ -197,6 +197,14 @@ def _generate( generation = ChatGeneration(message=message) return ChatResult(generations=[generation]) + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + ) -> ChatResult: + return self._generate(messages, stop=stop, run_manager=run_manager) + @abstractmethod def _call( self, diff --git a/langchain/utils.py b/langchain/utils.py index 7420b371c2cb5e..5b7757c3a859f4 100644 --- a/langchain/utils.py +++ b/langchain/utils.py @@ -4,6 +4,8 @@ from requests import HTTPError, Response +from langchain.schema import BaseMessage + def get_from_dict_or_env( data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None @@ -78,3 +80,9 @@ def stringify_dict(data: dict) -> str: for key, value in data.items(): text += key + ": " + stringify_value(value) + "\n" return text + + +def serialize_msgs(msgs: "list[BaseMessage]", include_type: bool = False) -> str: + return "\n\n".join( + (f"{msg.type}: {msg.content}" if include_type else msg.content) for msg in msgs + ) diff --git a/langchain/wrappers/__init__.py b/langchain/wrappers/__init__.py new file mode 100644 index 00000000000000..505ff66e9f2ff5 --- /dev/null +++ b/langchain/wrappers/__init__.py @@ -0,0 +1,7 @@ +from langchain.wrappers.chat_model_facade import ChatModelFacade +from langchain.wrappers.llm_facade import LLMFacade + +__all__ = [ + "ChatModelFacade", + "LLMFacade", +] diff --git a/langchain/wrappers/chat_model_facade.py b/langchain/wrappers/chat_model_facade.py 
new file mode 100644 index 00000000000000..d24d01f2594dbf --- /dev/null +++ b/langchain/wrappers/chat_model_facade.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import List, Optional + +from langchain.chat_models.base import BaseChatModel, SimpleChatModel +from langchain.schema import BaseMessage +from langchain.llms.base import BaseLanguageModel +from langchain.utils import serialize_msgs + + +class ChatModelFacade(SimpleChatModel): + llm: BaseLanguageModel + + def _call(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> str: + if isinstance(self.llm, BaseChatModel): + return self.llm(messages, stop=stop).content + elif isinstance(self.llm, BaseLanguageModel): + return self.llm(serialize_msgs(messages), stop=stop) + else: + raise ValueError( + f"Invalid llm type: {type(self.llm)}. Must be a chat model or language model." + ) + + @classmethod + def of(cls, llm): + if isinstance(llm, BaseChatModel): + return llm + elif isinstance(llm, BaseLanguageModel): + return cls(llm=llm) + else: + raise ValueError( + f"Invalid llm type: {type(llm)}. Must be a chat model or language model." 
+ ) diff --git a/langchain/wrappers/llm_facade.py b/langchain/wrappers/llm_facade.py new file mode 100644 index 00000000000000..71b31459581ac2 --- /dev/null +++ b/langchain/wrappers/llm_facade.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any, List, Mapping, Optional + +from langchain.chat_models.base import BaseChatModel +from langchain.llms.base import LLM, BaseLanguageModel + + +class LLMFacade(LLM): + chat_model: BaseChatModel + + @property + def _llm_type(self) -> str: + return self.chat_model._llm_type + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + ) -> str: + return self.chat_model.call_as_llm(prompt, stop=stop) + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return self.chat_model._identifying_params + + @staticmethod + def of(llm) -> LLMFacade: + if isinstance(llm, BaseChatModel): + return LLMFacade(chat_model=llm) + elif isinstance(llm, BaseLanguageModel): + return llm + else: + raise ValueError( + f"Invalid llm type: {type(llm)}. Must be a chat model or language model." 
+ ) diff --git a/tests/unit_tests/wrappers/__init__.py b/tests/unit_tests/wrappers/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/unit_tests/wrappers/test_chat_model_facade.py b/tests/unit_tests/wrappers/test_chat_model_facade.py new file mode 100644 index 00000000000000..4264f678bcc720 --- /dev/null +++ b/tests/unit_tests/wrappers/test_chat_model_facade.py @@ -0,0 +1,12 @@ +from langchain.llms.fake import FakeListLLM +from langchain.schema import SystemMessage +from langchain.wrappers.chat_model_facade import ChatModelFacade + + +def test_chat_model_facade(): + llm = FakeListLLM(responses=["hello", "goodbye"]) + chat_model = ChatModelFacade.of(llm) + input_message = SystemMessage(content="hello") + output_message = chat_model([input_message]) + assert output_message.content == "hello" + assert output_message.type == "ai" diff --git a/tests/unit_tests/wrappers/test_llm_facade.py b/tests/unit_tests/wrappers/test_llm_facade.py new file mode 100644 index 00000000000000..ffb90ac88e945a --- /dev/null +++ b/tests/unit_tests/wrappers/test_llm_facade.py @@ -0,0 +1,2 @@ +def test_llm_facade(): + pass