
Bug/4299: Can't instantiate SimpleChatModel subclass without defining _agenerate #4300

10 changes: 9 additions & 1 deletion langchain/chat_models/base.py
@@ -2,7 +2,7 @@
import inspect
import warnings
from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Any, Coroutine, Dict, List, Optional

from pydantic import Extra, Field, root_validator

@@ -197,6 +197,14 @@ def _generate(
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: List[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
    ) -> Coroutine[Any, Any, ChatResult]:
        return self._generate(messages, stop=stop, run_manager=run_manager)

Contributor (commenting on def _agenerate):

Thanks for the PR!

This will have to be an async method. Will defer to @agola11 on whether it makes sense to have it default to

    await asyncio.get_event_loop().run_in_executor(self._generate(messages, stop=stop, run_manager=run_manager))

Collaborator:

Yeah, this needs to be async, and we should not be calling blocking code in async methods. A good solution is having the default implementation of async def _agenerate call run_in_executor, as @vowelparrot suggested.

Contributor:

#4701

How does this look?
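For reference, a minimal sketch of the default the reviewers are converging on. Note that run_in_executor expects an executor and a callable rather than the result of calling _generate, so the synchronous call gets wrapped in functools.partial. The signature below mirrors the PR's; this is an illustration of the suggestion, not the merged implementation:

    # assumes module-level: import asyncio; from functools import partial
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        # Offload the blocking _generate call to the default thread pool so
        # the event loop stays free while the model call is in flight.
        return await asyncio.get_event_loop().run_in_executor(
            None, partial(self._generate, messages, stop=stop, run_manager=run_manager)
        )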

    @abstractmethod
    def _call(
        self,
8 changes: 8 additions & 0 deletions langchain/utils.py
@@ -4,6 +4,8 @@

from requests import HTTPError, Response

from langchain.schema import BaseMessage


def get_from_dict_or_env(
data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
@@ -78,3 +80,9 @@ def stringify_dict(data: dict) -> str:
    for key, value in data.items():
        text += key + ": " + stringify_value(value) + "\n"
    return text


def serialize_msgs(msgs: list[BaseMessage], include_type: bool = False) -> str:
    """Join messages into one string, optionally prefixing each with its type."""
    return "\n\n".join(
        (f"{msg.type}: {msg.content}" if include_type else msg.content) for msg in msgs
    )
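To illustrate, a quick usage sketch for serialize_msgs (hypothetical messages; expected results shown in the comments):

    from langchain.schema import HumanMessage, SystemMessage
    from langchain.utils import serialize_msgs

    msgs = [SystemMessage(content="You are helpful."), HumanMessage(content="Hi!")]
    serialize_msgs(msgs)                     # "You are helpful.\n\nHi!"
    serialize_msgs(msgs, include_type=True)  # "system: You are helpful.\n\nhuman: Hi!"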
7 changes: 7 additions & 0 deletions langchain/wrappers/__init__.py
@@ -0,0 +1,7 @@
from langchain.wrappers.chat_model_facade import ChatModelFacade
from langchain.wrappers.llm_facade import LLMFacade

__all__ = [
    "ChatModelFacade",
    "LLMFacade",
]
33 changes: 33 additions & 0 deletions langchain/wrappers/chat_model_facade.py
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel, SimpleChatModel
from langchain.llms.base import BaseLanguageModel
from langchain.schema import BaseMessage
from langchain.utils import serialize_msgs


class ChatModelFacade(SimpleChatModel):
    llm: BaseLanguageModel

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        if isinstance(self.llm, BaseChatModel):
            return self.llm(messages, stop=stop).content
        elif isinstance(self.llm, BaseLanguageModel):
            return self.llm(serialize_msgs(messages), stop=stop)
        else:
            raise ValueError(
                f"Invalid llm type: {type(self.llm)}. Must be a chat model or language model."
            )

    @classmethod
    def of(cls, llm: BaseLanguageModel) -> BaseChatModel:
        if isinstance(llm, BaseChatModel):
            return llm
        elif isinstance(llm, BaseLanguageModel):
            return cls(llm=llm)  # pydantic models require keyword arguments
        else:
            raise ValueError(
                f"Invalid llm type: {type(llm)}. Must be a chat model or language model."
            )
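A short usage sketch, mirroring the unit test below: ChatModelFacade.of exposes a plain completion LLM through the chat interface (FakeListLLM just replays canned responses):

    from langchain.llms.fake import FakeListLLM
    from langchain.schema import HumanMessage
    from langchain.wrappers import ChatModelFacade

    llm = FakeListLLM(responses=["hi there"])
    chat_model = ChatModelFacade.of(llm)                 # wrap the LLM as a chat model
    reply = chat_model([HumanMessage(content="hello")])  # returns an AIMessage
    assert reply.content == "hi there"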
37 changes: 37 additions & 0 deletions langchain/wrappers/llm_facade.py
@@ -0,0 +1,37 @@
from __future__ import annotations

from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import LLM, BaseLanguageModel


class LLMFacade(LLM):
    chat_model: BaseChatModel

    @property
    def _llm_type(self) -> str:
        return self.chat_model._llm_type

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        return self.chat_model.call_as_llm(prompt, stop=stop)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return self.chat_model._identifying_params

    @staticmethod
    def of(llm: BaseLanguageModel) -> BaseLanguageModel:
        if isinstance(llm, BaseChatModel):
            return LLMFacade(chat_model=llm)  # pydantic models require keyword arguments
        elif isinstance(llm, BaseLanguageModel):
            return llm
        else:
            raise ValueError(
                f"Invalid llm type: {type(llm)}. Must be a chat model or language model."
            )
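In the other direction, LLMFacade lets a chat model stand in wherever a plain-string LLM is expected. A hedged sketch (assumes ChatOpenAI credentials are configured; any BaseChatModel would do):

    from langchain.chat_models import ChatOpenAI
    from langchain.wrappers import LLMFacade

    llm = LLMFacade.of(ChatOpenAI())      # wrap the chat model behind the LLM interface
    text = llm("Say hello in one word.")  # string prompt in, string completion out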
Empty file.
12 changes: 12 additions & 0 deletions tests/unit_tests/wrappers/test_chat_model_facade.py
@@ -0,0 +1,12 @@
from langchain.llms.fake import FakeListLLM
from langchain.schema import SystemMessage
from langchain.wrappers.chat_model_facade import ChatModelFacade


def test_chat_model_facade():
    llm = FakeListLLM(responses=["hello", "goodbye"])
    chat_model = ChatModelFacade.of(llm)
    input_message = SystemMessage(content="hello")
    output_message = chat_model([input_message])
    assert output_message.content == "hello"
    assert output_message.type == "ai"
2 changes: 2 additions & 0 deletions tests/unit_tests/wrappers/test_llm_facade.py
@@ -0,0 +1,2 @@
def test_llm_facade():
    pass