
Commit

Generic LLM wrapper to support chat model interface with configurable chat prompt format
krasserm committed Sep 23, 2023
1 parent b809c24 commit bc99b88
Showing 6 changed files with 1,100 additions and 0 deletions.
718 changes: 718 additions & 0 deletions docs/extras/integrations/chat/llama2_chat.ipynb

Large diffs are not rendered by default.
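
Since the notebook diff above is not rendered, here is a minimal usage sketch of the new wrapper. The sketch is not part of the commit; it assumes a Llama-2 chat model served behind langchain.llms.HuggingFaceTextGenInference, and the endpoint URL and generation parameters are placeholders.

from langchain.chat_models import Llama2Chat
from langchain.llms import HuggingFaceTextGenInference
from langchain.schema import HumanMessage, SystemMessage

# Wrap a plain completion LLM so it can be used through the chat model interface.
# Llama2Chat converts the message list into Llama-2's <<SYS>>/[INST] prompt format.
llm = HuggingFaceTextGenInference(
    inference_server_url="http://127.0.0.1:8080/",  # placeholder endpoint
    max_new_tokens=512,
    temperature=0.1,
)
model = Llama2Chat(llm=llm)

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What can you do?"),
]
print(model.predict_messages(messages).content)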

4 changes: 4 additions & 0 deletions libs/langchain/langchain/chat_models/__init__.py
@@ -30,6 +30,7 @@
from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.konko import ChatKonko
from langchain.chat_models.litellm import ChatLiteLLM
from langchain.chat_models.llm_wrapper import Llama2Chat, Orca, Vicuna
from langchain.chat_models.minimax import MiniMaxChat
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
@@ -49,6 +50,9 @@
"ChatOllama",
"ChatVertexAI",
"JinaChat",
"Llama2Chat",
"Orca",
"Vicuna",
"HumanInputChatModel",
"MiniMaxChat",
"ChatAnyscale",
163 changes: 163 additions & 0 deletions libs/langchain/langchain/chat_models/llm_wrapper.py
@@ -0,0 +1,163 @@
"""Generic Wrapper for chat LLMs, with sample implementations
for Llama-2-chat, Llama-2-instruct and Vicuna models.
"""
from typing import Any, List, Optional

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import LLM
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
LLMResult,
SystemMessage,
)

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" # noqa: E501


class ChatWrapper(BaseChatModel):
llm: LLM
sys_beg: str
sys_end: str
ai_n_beg: str
ai_n_end: str
usr_n_beg: str
usr_n_end: str
usr_0_beg: Optional[str] = None
usr_0_end: Optional[str] = None

system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
return self._to_chat_result(llm_result)

def _to_chat_prompt(
self,
messages: List[BaseMessage],
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
if not messages:
raise ValueError("at least one HumanMessage must be provided")

if not isinstance(messages[0], SystemMessage):
messages = [self.system_message] + messages

if not isinstance(messages[1], HumanMessage):
raise ValueError(
"messages list must start with a SystemMessage or UserMessage"
)

if not isinstance(messages[-1], HumanMessage):
raise ValueError("last message must be a HumanMessage")

prompt_parts = []

if self.usr_0_beg is None:
self.usr_0_beg = self.usr_n_beg

if self.usr_0_end is None:
self.usr_0_end = self.usr_n_end

prompt_parts.append(self.sys_beg + messages[0].content + self.sys_end)
prompt_parts.append(self.usr_0_beg + messages[1].content + self.usr_0_end)

for ai_message, human_message in zip(messages[2::2], messages[3::2]):
if not isinstance(ai_message, AIMessage) or not isinstance(
human_message, HumanMessage
):
raise ValueError(
"messages must be alternating human- and ai-messages, "
"optionally prepended by a system message"
)

prompt_parts.append(self.ai_n_beg + ai_message.content + self.ai_n_end)
prompt_parts.append(self.usr_n_beg + human_message.content + self.usr_n_end)

return "".join(prompt_parts)

@staticmethod
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []

for g in llm_result.generations[0]:
chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
)
chat_generations.append(chat_generation)

return ChatResult(
generations=chat_generations, llm_output=llm_result.llm_output
)


class Llama2Chat(ChatWrapper):
@property
def _llm_type(self) -> str:
return "llama-2-chat"

sys_beg: str = "<s>[INST] <<SYS>>\n"
sys_end: str = "\n<</SYS>>\n\n"
ai_n_beg: str = " "
ai_n_end: str = " </s>"
usr_n_beg: str = "<s>[INST] "
usr_n_end: str = " [/INST]"
usr_0_beg: str = ""
usr_0_end: str = " [/INST]"


class Orca(ChatWrapper):
@property
def _llm_type(self) -> str:
return "orca-style"

sys_beg: str = "### System:\n"
sys_end: str = "\n\n"
ai_n_beg: str = "### Assistant:\n"
ai_n_end: str = "\n\n"
usr_n_beg: str = "### User:\n"
usr_n_end: str = "\n\n"


class Vicuna(ChatWrapper):
@property
def _llm_type(self) -> str:
return "vicuna-style"

sys_beg: str = ""
sys_end: str = " "
ai_n_beg: str = "ASSISTANT: "
ai_n_end: str = " </s>"
usr_n_beg: str = "USER: "
usr_n_end: str = " "
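
The Llama2Chat, Orca and Vicuna classes above only set the delimiter fields of ChatWrapper, so other prompt conventions can be added the same way. As an illustration only (this class is hypothetical and not part of the commit), a ChatML-style wrapper could be sketched like this:

from langchain.chat_models.llm_wrapper import ChatWrapper


class ChatML(ChatWrapper):  # hypothetical example, not included in this commit
    @property
    def _llm_type(self) -> str:
        return "chatml-style"

    sys_beg: str = "<|im_start|>system\n"
    sys_end: str = "<|im_end|>\n"
    # An AI turn needs no opening marker of its own because the preceding
    # user turn already emits "<|im_start|>assistant\n" (see usr_n_end).
    ai_n_beg: str = ""
    ai_n_end: str = "<|im_end|>\n"
    usr_n_beg: str = "<|im_start|>user\n"
    # Close the user turn and open the assistant turn, so the assembled
    # prompt always ends with "<|im_start|>assistant\n", ready for generation.
    usr_n_end: str = "<|im_end|>\n<|im_start|>assistant\n"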
157 changes: 157 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_llm_wrapper_llama2chat.py
@@ -0,0 +1,157 @@
from typing import Any, List, Optional

import pytest

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models import Llama2Chat
from langchain.chat_models.llm_wrapper import DEFAULT_SYSTEM_PROMPT
from langchain.llms.base import LLM
from langchain.schema import AIMessage, HumanMessage, SystemMessage


class FakeLLM(LLM):
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    @property
    def _llm_type(self) -> str:
        return "fake-llm"


@pytest.fixture
def model() -> Llama2Chat:
    return Llama2Chat(llm=FakeLLM())


@pytest.fixture
def model_cfg_sys_msg() -> Llama2Chat:
    return Llama2Chat(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))


def test_default_system_message(model: Llama2Chat) -> None:
    messages = [HumanMessage(content="usr-msg-1")]

    actual = model.predict_messages(messages).content  # type: ignore
    expected = (
        f"<s>[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>>\n\nusr-msg-1 [/INST]"
    )

    assert actual == expected


def test_configured_system_message(
    model_cfg_sys_msg: Llama2Chat,
) -> None:
    messages = [HumanMessage(content="usr-msg-1")]

    actual = model_cfg_sys_msg.predict_messages(messages).content  # type: ignore
    expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"

    assert actual == expected


@pytest.mark.asyncio
async def test_configured_system_message_async(
    model_cfg_sys_msg: Llama2Chat,
) -> None:
    messages = [HumanMessage(content="usr-msg-1")]

    actual = await model_cfg_sys_msg.apredict_messages(messages)  # type: ignore
    expected = "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"

    assert actual.content == expected


def test_provided_system_message(
    model_cfg_sys_msg: Llama2Chat,
) -> None:
    messages = [
        SystemMessage(content="custom-sys-msg"),
        HumanMessage(content="usr-msg-1"),
    ]

    actual = model_cfg_sys_msg.predict_messages(messages).content
    expected = "<s>[INST] <<SYS>>\ncustom-sys-msg\n<</SYS>>\n\nusr-msg-1 [/INST]"

    assert actual == expected


def test_human_ai_dialogue(model_cfg_sys_msg: Llama2Chat) -> None:
    messages = [
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
        AIMessage(content="ai-msg-2"),
        HumanMessage(content="usr-msg-3"),
    ]

    actual = model_cfg_sys_msg.predict_messages(messages).content
    expected = (
        "<s>[INST] <<SYS>>\nsys-msg\n<</SYS>>\n\nusr-msg-1 [/INST] ai-msg-1 </s>"
        "<s>[INST] usr-msg-2 [/INST] ai-msg-2 </s><s>[INST] usr-msg-3 [/INST]"
    )

    assert actual == expected


def test_no_message(model: Llama2Chat) -> None:
    with pytest.raises(ValueError) as info:
        model.predict_messages([])

    assert info.value.args[0] == "at least one HumanMessage must be provided"


def test_ai_message_first(model: Llama2Chat) -> None:
    with pytest.raises(ValueError) as info:
        model.predict_messages([AIMessage(content="ai-msg-1")])

    assert (
        info.value.args[0]
        == "messages list must start with a SystemMessage or UserMessage"
    )


def test_human_ai_messages_not_alternating(model: Llama2Chat) -> None:
    messages = [
        HumanMessage(content="usr-msg-1"),
        HumanMessage(content="usr-msg-2"),
        HumanMessage(content="ai-msg-1"),
    ]

    with pytest.raises(ValueError) as info:
        model.predict_messages(messages)  # type: ignore

    assert info.value.args[0] == (
        "messages must be alternating human- and ai-messages, "
        "optionally prepended by a system message"
    )


def test_last_message_not_human_message(model: Llama2Chat) -> None:
    messages = [
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
    ]

    with pytest.raises(ValueError) as info:
        model.predict_messages(messages)

    assert info.value.args[0] == "last message must be a HumanMessage"
29 changes: 29 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_llm_wrapper_orca.py
@@ -0,0 +1,29 @@
import pytest

from langchain.chat_models import Orca
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM


@pytest.fixture
def model() -> Orca:
    return Orca(llm=FakeLLM())


@pytest.fixture
def model_cfg_sys_msg() -> Orca:
    return Orca(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))


def test_prompt(model: Orca) -> None:
    messages = [
        SystemMessage(content="sys-msg"),
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
    ]

    actual = model.predict_messages(messages).content  # type: ignore
    expected = "### System:\nsys-msg\n\n### User:\nusr-msg-1\n\n### Assistant:\nai-msg-1\n\n### User:\nusr-msg-2\n\n"  # noqa: E501

    assert actual == expected
29 changes: 29 additions & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_llm_wrapper_vicuna.py
@@ -0,0 +1,29 @@
import pytest

from langchain.chat_models import Vicuna
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from tests.unit_tests.chat_models.test_llm_wrapper_llama2chat import FakeLLM


@pytest.fixture
def model() -> Vicuna:
    return Vicuna(llm=FakeLLM())


@pytest.fixture
def model_cfg_sys_msg() -> Vicuna:
    return Vicuna(llm=FakeLLM(), system_message=SystemMessage(content="sys-msg"))


def test_prompt(model: Vicuna) -> None:
    messages = [
        SystemMessage(content="sys-msg"),
        HumanMessage(content="usr-msg-1"),
        AIMessage(content="ai-msg-1"),
        HumanMessage(content="usr-msg-2"),
    ]

    actual = model.predict_messages(messages).content  # type: ignore
    expected = "sys-msg USER: usr-msg-1 ASSISTANT: ai-msg-1 </s>USER: usr-msg-2 "

    assert actual == expected
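
Because the wrapped model behaves like any other chat model, it also plugs into chains. A short sketch, not part of the commit, assuming model is a Llama2Chat instance constructed as in the usage example near the top of this page; the question text is a placeholder.

from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import SystemMessage

# Build a chat prompt with one templated human turn.
prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
)

chain = LLMChain(llm=model, prompt=prompt)
print(chain.run(question="Name three South American camelids."))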
