Generic LLM wrapper to support chat model interface with configurable chat prompt format
Showing 6 changed files with 1,100 additions and 0 deletions.
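In practice, the new wrappers turn any completion-style LLM into a chat model by folding the message history into the prompt format the model expects. Below is a minimal usage sketch, not part of the diff: HuggingFaceTextGenInference and its inference_server_url are illustrative stand-ins for whichever langchain LLM is actually wrapped.

# Illustrative usage of the new Llama2Chat wrapper; the wrapped LLM and its
# inference_server_url are placeholders, any langchain LLM could be used instead.
from langchain.chat_models import Llama2Chat
from langchain.llms import HuggingFaceTextGenInference
from langchain.schema import HumanMessage, SystemMessage

llm = HuggingFaceTextGenInference(inference_server_url="http://127.0.0.1:8080/")
model = Llama2Chat(llm=llm)

messages = [
    SystemMessage(content="You are a concise assistant."),
    HumanMessage(content="What is the capital of France?"),
]
# Llama2Chat renders these messages as "<s>[INST] <<SYS>>...<</SYS>>... [/INST]"
# before calling the wrapped LLM, and returns the completion as an AIMessage.
print(model.predict_messages(messages).content)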
163 changes: 163 additions & 0 deletions
libs/langchain/langchain/chat_models/llm_wrapper.py
"""Generic Wrapper for chat LLMs, with sample implementations | ||
for Llama-2-chat, Llama-2-instruct and Vicuna models. | ||
""" | ||
from typing import Any, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import LLM
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    LLMResult,
    SystemMessage,
)

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""  # noqa: E501


class ChatWrapper(BaseChatModel):
    llm: LLM
    sys_beg: str
    sys_end: str
    ai_n_beg: str
    ai_n_end: str
    usr_n_beg: str
    usr_n_end: str
    usr_0_beg: Optional[str] = None
    usr_0_end: Optional[str] = None

    system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)
        llm_result = self.llm._generate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        llm_input = self._to_chat_prompt(messages)
        llm_result = await self.llm._agenerate(
            prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
        )
        return self._to_chat_result(llm_result)

    def _to_chat_prompt(
        self,
        messages: List[BaseMessage],
    ) -> str:
        """Convert a list of messages into a prompt format expected by wrapped LLM."""
        if not messages:
            raise ValueError("at least one HumanMessage must be provided")

        if not isinstance(messages[0], SystemMessage):
            messages = [self.system_message] + messages

        if not isinstance(messages[1], HumanMessage):
            raise ValueError(
                "messages list must start with a SystemMessage or UserMessage"
            )

        if not isinstance(messages[-1], HumanMessage):
            raise ValueError("last message must be a HumanMessage")

        prompt_parts = []

        if self.usr_0_beg is None:
            self.usr_0_beg = self.usr_n_beg

        if self.usr_0_end is None:
            self.usr_0_end = self.usr_n_end

        prompt_parts.append(self.sys_beg + messages[0].content + self.sys_end)
        prompt_parts.append(self.usr_0_beg + messages[1].content + self.usr_0_end)

        for ai_message, human_message in zip(messages[2::2], messages[3::2]):
            if not isinstance(ai_message, AIMessage) or not isinstance(
                human_message, HumanMessage
            ):
                raise ValueError(
                    "messages must be alternating human- and ai-messages, "
                    "optionally prepended by a system message"
                )

            prompt_parts.append(self.ai_n_beg + ai_message.content + self.ai_n_end)
            prompt_parts.append(self.usr_n_beg + human_message.content + self.usr_n_end)

        return "".join(prompt_parts)

    @staticmethod
    def _to_chat_result(llm_result: LLMResult) -> ChatResult:
        chat_generations = []

        for g in llm_result.generations[0]:
            chat_generation = ChatGeneration(
                message=AIMessage(content=g.text), generation_info=g.generation_info
            )
            chat_generations.append(chat_generation)

        return ChatResult(
            generations=chat_generations, llm_output=llm_result.llm_output
        )


class Llama2Chat(ChatWrapper):
    @property
    def _llm_type(self) -> str:
        return "llama-2-chat"

    sys_beg: str = "<s>[INST] <<SYS>>\n"
    sys_end: str = "\n<</SYS>>\n\n"
    ai_n_beg: str = " "
    ai_n_end: str = " </s>"
    usr_n_beg: str = "<s>[INST] "
    usr_n_end: str = " [/INST]"
    usr_0_beg: str = ""
    usr_0_end: str = " [/INST]"


class Orca(ChatWrapper):
    @property
    def _llm_type(self) -> str:
        return "orca-style"

    sys_beg: str = "### System:\n"
    sys_end: str = "\n\n"
    ai_n_beg: str = "### Assistant:\n"
    ai_n_end: str = "\n\n"
    usr_n_beg: str = "### User:\n"
    usr_n_end: str = "\n\n"


class Vicuna(ChatWrapper):
    @property
    def _llm_type(self) -> str:
        return "vicuna-style"

    sys_beg: str = ""
    sys_end: str = " "
    ai_n_beg: str = "ASSISTANT: "
    ai_n_end: str = " </s>"
    usr_n_beg: str = "USER: "
    usr_n_end: str = " "
157 changes: 157 additions & 0 deletions
libs/langchain/tests/unit_tests/chat_models/test_llm_wrapper_llama2chat.py
from typing import Any, List, Optional

import pytest

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models import Llama2Chat
from langchain.chat_models.llm_wrapper import DEFAULT_SYSTEM_PROMPT
from langchain.llms.base import LLM
from langchain.schema import AIMessage, HumanMessage, SystemMessage


class FakeLLM(LLM):
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    @property
    def _llm_type(self) -> str:
        return "fake-llm"


@pytest.fixture
def model() -> Llama2Chat:
    return Llama2Chat(llm=FakeLLM())


@pytest.fixture
def model_cfg_sys_msg() -> Llama2Chat:
    return Llama2Chat(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))


def test_default_system_message(model: Llama2Chat) -> None:
    messages = [HumanMessage(content="usr-msg-1")]

    actual = model.predict_messages(messages).content  # type: ignore
    expected = (
        f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\nusr-msg-1 [/INST]"
    )

    assert actual == expected


def test_configured_system_message(
    model_cfg_sys_msg: Llama2Chat,
) -> None:
    messages = [HumanMessage(content="usr-msg-1")]

    actual = model_cfg_sys_msg.predict_messages(messages).content  # type: ignore
    expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"

    assert actual == expected


@pytest.mark.asyncio
async def test_configured_system_message_async(
    model_cfg_sys_msg: Llama2Chat,
) -> None:
    messages = [HumanMessage(content="usr-msg-1")]

    actual = await model_cfg_sys_msg.apredict_messages(messages)  # type: ignore
    expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"

    assert actual.content == expected


def test_provided_system_message(
    model_cfg_sys_msg: Llama2Chat,
) -> None:
    messages = [
        SystemMessage(content="custom-sys-msg"),
        HumanMessage(content="usr-msg-1"),
    ]

    actual = model_cfg_sys_msg.predict_messages(messages).content
    expected = "<s>[INST] <<SYS>>\ncustom-sys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"

    assert actual == expected


def test_human_ai_dialogue(model_cfg_sys_msg: Llama2Chat) -> None:
    messages = [
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
        AIMessage(content="ai-msg-2"),
        HumanMessage(content="usr-msg-3"),
    ]

    actual = model_cfg_sys_msg.predict_messages(messages).content
    expected = (
        "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST] ai-msg-1 </s>"
        "<s>[INST] usr-msg-2 [/INST] ai-msg-2 </s><s>[INST] usr-msg-3 [/INST]"
    )

    assert actual == expected


def test_no_message(model: Llama2Chat) -> None:
    with pytest.raises(ValueError) as info:
        model.predict_messages([])

    assert info.value.args[0] == "at least one HumanMessage must be provided"


def test_ai_message_first(model: Llama2Chat) -> None:
    with pytest.raises(ValueError) as info:
        model.predict_messages([AIMessage(content="ai-msg-1")])

    assert (
        info.value.args[0]
        == "messages list must start with a SystemMessage or UserMessage"
    )


def test_human_ai_messages_not_alternating(model: Llama2Chat) -> None:
    messages = [
        HumanMessage(content="usr-msg-1"),
        HumanMessage(content="usr-msg-2"),
        HumanMessage(content="ai-msg-1"),
    ]

    with pytest.raises(ValueError) as info:
        model.predict_messages(messages)  # type: ignore

    assert info.value.args[0] == (
        "messages must be alternating human- and ai-messages, "
        "optionally prepended by a system message"
    )


def test_last_message_not_human_message(model: Llama2Chat) -> None:
    messages = [
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
    ]

    with pytest.raises(ValueError) as info:
        model.predict_messages(messages)

    assert info.value.args[0] == "last message must be a HumanMessage"
29 changes: 29 additions & 0 deletions
libs/langchain/tests/unit_tests/chat_models/test_llm_wrapper_orca.py
import pytest

from langchain.chat_models import Orca
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM


@pytest.fixture
def model() -> Orca:
    return Orca(llm=FakeLLM())


@pytest.fixture
def model_cfg_sys_msg() -> Orca:
    return Orca(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))


def test_prompt(model: Orca) -> None:
    messages = [
        SystemMessage(content="sys-msg"),
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
    ]

    actual = model.predict_messages(messages).content  # type: ignore
    expected = "### System:\nsys-msg\n\n### User:\nusr-msg-1\n\n### Assistant:\nai-msg-1\n\n### User:\nusr-msg-2\n\n"  # noqa: E501

    assert actual == expected
29 changes: 29 additions & 0 deletions
libs/langchain/tests/unit_tests/chat_models/test_llm_wrapper_vicuna.py
import pytest

from langchain.chat_models import Vicuna
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM


@pytest.fixture
def model() -> Vicuna:
    return Vicuna(llm=FakeLLM())


@pytest.fixture
def model_cfg_sys_msg() -> Vicuna:
    return Vicuna(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))


def test_prompt(model: Vicuna) -> None:
    messages = [
        SystemMessage(content="sys-msg"),
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
    ]

    actual = model.predict_messages(messages).content  # type: ignore
    expected = "sys-msg USER: usr-msg-1 ASSISTANT: ai-msg-1 </s>USER: usr-msg-2 "

    assert actual == expected