add concept of prompt collection #1507

Merged: 7 commits, Mar 8, 2023
9 changes: 9 additions & 0 deletions docs/modules/chains/key_concepts.md
@@ -9,3 +9,12 @@ This is a specific type of chain where multiple other chains are run in sequence,
to the next. A subtype of this type of chain is the [`SimpleSequentialChain`](./generic/sequential_chains.html#simplesequentialchain), where all subchains have only one input and one output,
and the output of one is therefore used as sole input to the next chain.

## Prompt Selectors
One thing we have noticed is that the best prompt to use really depends on the model you are using.
Some prompts work very well with some models, but not so well with others.
One of our goals is to provide good chains that "just work" out of the box.
A big part of chains like that is having prompts that "just work".
So rather than hard-coding a default prompt for each chain, we are moving towards a paradigm where, if a prompt is not explicitly
provided, we select one with a PromptSelector. This class takes in the model passed in and returns a default prompt suited to it.
The inner workings of the PromptSelector can look at any aspect of the model - LLM vs ChatModel, OpenAI vs Cohere, GPT3 vs GPT4, etc.
Because this is a newer feature, it may not be implemented for all chains yet, but this is the direction we are moving in.
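
As a rough illustration of the pattern, here is a minimal sketch using the ConditionalPromptSelector and is_chat_model helpers added in this PR; the prompts themselves are made-up placeholders, and the OpenAI/ChatOpenAI classes in the comments are only examples of a completion LLM and a chat model:

```python
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

# Plain-text prompt for completion-style LLMs.
DEFAULT_PROMPT = PromptTemplate(
    template="Answer the following question:\n{question}",
    input_variables=["question"],
)

# Message-based prompt for chat models.
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template("You are a helpful assistant."),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
)

PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=DEFAULT_PROMPT,
    conditionals=[(is_chat_model, CHAT_PROMPT)],
)

# get_prompt checks each conditional against the model and returns the first
# matching prompt, falling back to default_prompt:
#   PROMPT_SELECTOR.get_prompt(OpenAI())     -> DEFAULT_PROMPT
#   PROMPT_SELECTOR.get_prompt(ChatOpenAI()) -> CHAT_PROMPT
```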
10 changes: 5 additions & 5 deletions langchain/chains/chat_vector_db/base.py
@@ -1,17 +1,17 @@
"""Chain for chatting with a vector database."""
from __future__ import annotations

from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel

from langchain.chains.base import Chain
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.vectorstores.base import VectorStore


@@ -58,10 +58,10 @@ def output_keys(self) -> List[str]:
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
vectorstore: VectorStore,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
qa_prompt: BasePromptTemplate = QA_PROMPT,
qa_prompt: Optional[BasePromptTemplate] = None,
chain_type: str = "stuff",
**kwargs: Any,
) -> ChatVectorDBChain:
38 changes: 38 additions & 0 deletions langchain/chains/prompt_selector.py
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple

from pydantic import BaseModel, Field

from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel


class BasePromptSelector(BaseModel, ABC):
@abstractmethod
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
"""Get default prompt for a language model."""


class ConditionalPromptSelector(BasePromptSelector):
"""Prompt collection that goes through conditionals."""

default_prompt: BasePromptTemplate
conditionals: List[
Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
] = Field(default_factory=list)

def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
for condition, prompt in self.conditionals:
if condition(llm):
return prompt
return self.default_prompt


def is_llm(llm: BaseLanguageModel) -> bool:
return isinstance(llm, BaseLLM)


def is_chat_model(llm: BaseLanguageModel) -> bool:
return isinstance(llm, BaseChatModel)
53 changes: 34 additions & 19 deletions langchain/chains/question_answering/__init__.py
@@ -14,19 +14,21 @@
refine_prompts,
stuff_prompt,
)
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel


class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""

def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""


def _load_map_rerank_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
verbose: bool = False,
document_variable_name: str = "context",
@@ -50,13 +52,14 @@ def _load_map_rerank_chain(


def _load_stuff_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(
llm=llm, prompt=_prompt, verbose=verbose, callback_manager=callback_manager
)
@@ -71,28 +74,34 @@ def _load_map_reduce_chain(


def _load_map_reduce_chain(
llm: BaseLLM,
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
llm: BaseLanguageModel,
question_prompt: Optional[BasePromptTemplate] = None,
combine_prompt: Optional[BasePromptTemplate] = None,
combine_document_variable_name: str = "summaries",
map_reduce_document_variable_name: str = "context",
collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
reduce_llm: Optional[BaseLanguageModel] = None,
collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
_question_prompt = (
question_prompt or map_reduce_prompt.QUESTION_PROMPT_SELECTOR.get_prompt(llm)
)
_combine_prompt = (
combine_prompt or map_reduce_prompt.COMBINE_PROMPT_SELECTOR.get_prompt(llm)
)
map_chain = LLMChain(
llm=llm,
prompt=question_prompt,
prompt=_question_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
_reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(
llm=_reduce_llm,
prompt=combine_prompt,
prompt=_combine_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
@@ -135,26 +144,32 @@ def _load_map_reduce_chain(


def _load_refine_chain(
llm: BaseLLM,
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
llm: BaseLanguageModel,
question_prompt: Optional[BasePromptTemplate] = None,
refine_prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context_str",
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
refine_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
_question_prompt = (
question_prompt or refine_prompts.QUESTION_PROMPT_SELECTOR.get_prompt(llm)
)
_refine_prompt = refine_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt(
llm
)
initial_chain = LLMChain(
llm=llm,
prompt=question_prompt,
prompt=_question_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
_refine_llm = refine_llm or llm
refine_chain = LLMChain(
llm=_refine_llm,
prompt=refine_prompt,
prompt=_refine_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
@@ -170,7 +185,7 @@ def _load_refine_chain(


def load_qa_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
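
The practical effect, sketched below (assuming the OpenAI and ChatOpenAI model classes as stand-ins for a completion LLM and a chat model), is that load_qa_chain no longer needs an explicit prompt to work reasonably with either kind of model:

```python
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI  # example chat model
from langchain.llms import OpenAI             # example completion LLM

# With no prompt argument, each loader falls back to its PROMPT_SELECTOR,
# which picks a plain-text or message-based prompt based on the model passed in.
llm_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")
chat_chain = load_qa_chain(ChatOpenAI(temperature=0), chain_type="stuff")
```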
40 changes: 39 additions & 1 deletion langchain/chains/question_answering/map_reduce_prompt.py
@@ -1,5 +1,14 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.chat import (
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
ChatPromptTemplate,
)
from langchain.chains.prompt_selector import (
ConditionalPromptSelector,
is_chat_model,
)

question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
Return any relevant text verbatim.
@@ -9,6 +18,20 @@
QUESTION_PROMPT = PromptTemplate(
template=question_prompt_template, input_variables=["context", "question"]
)
system_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
Return any relevant text verbatim.
______________________
{context}"""
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)


QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=QUESTION_PROMPT, conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)]
)

combine_prompt_template = """Given the following extracted parts of a long document and a question, create a final answer.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
@@ -43,3 +66,18 @@
COMBINE_PROMPT = PromptTemplate(
template=combine_prompt_template, input_variables=["summaries", "question"]
)

system_template = """Given the following extracted parts of a long document and a question, create a final answer.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
______________________
{summaries}"""
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_COMBINE_PROMPT = ChatPromptTemplate.from_messages(messages)


COMBINE_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=COMBINE_PROMPT, conditionals=[(is_chat_model, CHAT_COMBINE_PROMPT)]
)
50 changes: 49 additions & 1 deletion langchain/chains/question_answering/refine_prompts.py
@@ -1,5 +1,16 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.chat import (
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
ChatPromptTemplate,
AIMessagePromptTemplate,
)
from langchain.chains.prompt_selector import (
ConditionalPromptSelector,
is_chat_model,
)


DEFAULT_REFINE_PROMPT_TMPL = (
"The original question is as follows: {question}\n"
@@ -17,6 +28,26 @@
input_variables=["question", "existing_answer", "context_str"],
template=DEFAULT_REFINE_PROMPT_TMPL,
)
refine_template = (
"We have the opportunity to refine the existing answer"
"(only if needed) with some more context below.\n"
"------------\n"
"{context_str}\n"
"------------\n"
"Given the new context, refine the original answer to better "
"answer the question. "
"If the context isn't useful, return the original answer."
)
messages = [
HumanMessagePromptTemplate.from_template("{question}"),
AIMessagePromptTemplate.from_template("{existing_answer}"),
HumanMessagePromptTemplate.from_template(refine_template),
]
CHAT_REFINE_PROMPT = ChatPromptTemplate.from_messages(messages)
REFINE_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_REFINE_PROMPT,
conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)],
)


DEFAULT_TEXT_QA_PROMPT_TMPL = (
@@ -30,3 +61,20 @@
DEFAULT_TEXT_QA_PROMPT = PromptTemplate(
input_variables=["context_str", "question"], template=DEFAULT_TEXT_QA_PROMPT_TMPL
)
chat_qa_prompt_template = (
"Context information is below. \n"
"---------------------\n"
"{context_str}"
"\n---------------------\n"
"Given the context information and not prior knowledge, "
"answer any questions"
)
messages = [
SystemMessagePromptTemplate.from_template(chat_qa_prompt_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_TEXT_QA_PROMPT,
conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)],
)
25 changes: 25 additions & 0 deletions langchain/chains/question_answering/stuff_prompt.py
@@ -1,5 +1,15 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.chains.prompt_selector import (
ConditionalPromptSelector,
is_chat_model,
)
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)


prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

@@ -10,3 +20,18 @@
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)

system_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
{context}"""
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)


PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
)