Add create_conv_retrieval_chain func #15084

Merged · 26 commits · Dec 27, 2023
9 changes: 7 additions & 2 deletions libs/core/langchain_core/retrievers.py
@@ -9,7 +9,7 @@

from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
-from langchain_core.runnables import RunnableConfig, RunnableSerializable
+from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable

if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -18,8 +18,13 @@
Callbacks,
)

+RetrieverInput = str
+RetrieverOutput = List[Document]
+RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
+RetrieverOutputLike = Runnable[Any, RetrieverOutput]


-class BaseRetriever(RunnableSerializable[str, List[Document]], ABC):
+class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"""Abstract base class for a Document retrieval system.

A retrieval system is defined as something that can take string queries and return
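To make the new aliases concrete: any runnable that maps a string query to a list of `Document`s now fits `RetrieverLike`, not only `BaseRetriever` subclasses. A minimal sketch, assuming a `langchain-core` build that includes this change; the toy `lookup` function is purely illustrative:

```python
# Illustrative sketch, not PR code; assumes langchain-core with the aliases above.
from langchain_core.documents import Document
from langchain_core.retrievers import RetrieverLike, RetrieverOutput
from langchain_core.runnables import RunnableLambda


def lookup(query: str) -> RetrieverOutput:
    # Toy "retrieval": return a single Document that echoes the query.
    return [Document(page_content=f"docs about {query}")]


# Any runnable from str to List[Document] can now be annotated as RetrieverLike.
retriever: RetrieverLike = RunnableLambda(lookup)
print(retriever.invoke("LCEL"))  # -> [Document(page_content='docs about LCEL')]
```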
9 changes: 8 additions & 1 deletion libs/langchain/langchain/chains/__init__.py
@@ -42,6 +42,7 @@
from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain
from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain
from langchain.chains.graph_qa.sparql import GraphSparqlQAChain
+from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.llm import LLMChain
from langchain.chains.llm_checker.base import LLMCheckerChain
@@ -65,7 +66,11 @@
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
-from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
+from langchain.chains.retrieval import create_retrieval_chain
+from langchain.chains.retrieval_qa.base import (
+    RetrievalQA,
+    VectorDBQA,
+)
from langchain.chains.router import (
LLMRouterChain,
MultiPromptChain,
@@ -133,4 +138,6 @@
"generate_example",
"load_chain",
"create_sql_query_chain",
"create_retrieval_chain",
"create_history_aware_retriever",
]
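With these `__init__.py` changes, both helpers become part of the top-level public API:

```python
# Both new helpers are re-exported from langchain.chains (requires a build with this PR).
from langchain.chains import (
    create_history_aware_retriever,
    create_retrieval_chain,
)
```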
@@ -13,7 +13,7 @@
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
-from langchain_core.runnables.config import RunnableConfig
+from langchain_core.runnables import RunnableConfig
from langchain_core.vectorstores import VectorStore

from langchain.callbacks.manager import (
67 changes: 67 additions & 0 deletions libs/langchain/langchain/chains/history_aware_retriever.py
@@ -0,0 +1,67 @@
from __future__ import annotations

from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.retrievers import RetrieverLike, RetrieverOutputLike
from langchain_core.runnables import RunnableBranch


def create_history_aware_retriever(
    llm: LanguageModelLike,
    retriever: RetrieverLike,
    prompt: BasePromptTemplate,
) -> RetrieverOutputLike:
    """Create a chain that takes conversation history and returns documents.

    If there is no `chat_history`, then the `input` is just passed directly to the
    retriever. If there is `chat_history`, then the prompt and LLM will be used
    to generate a search query. That search query is then passed to the retriever.

    Args:
        llm: Language model to use for generating a search term given chat history.
        retriever: RetrieverLike object that takes a string as input and outputs
            a list of Documents.
        prompt: The prompt used to generate the search query for the retriever.

    Returns:
        An LCEL Runnable. The Runnable input must include `input`, and if there
        is chat history it should be provided as `chat_history`.
        The Runnable output is a list of Documents.

    Example:
        .. code-block:: python

            # pip install -U langchain langchain-community

            from langchain_community.chat_models import ChatOpenAI
            from langchain.chains import create_history_aware_retriever
            from langchain import hub

            rephrase_prompt = hub.pull("langchain-ai/chat-langchain-rephrase")
            llm = ChatOpenAI()
            retriever = ...
            chat_retriever_chain = create_history_aware_retriever(
                llm, retriever, rephrase_prompt
            )

            chat_retriever_chain.invoke({"input": "...", "chat_history": []})

    """
    if "input" not in prompt.input_variables:
        raise ValueError(
            "Expected `input` to be a prompt variable, "
            f"but got {prompt.input_variables}"
        )

    retrieve_documents: RetrieverOutputLike = RunnableBranch(
        (
            # Both empty string and empty list evaluate to False
            lambda x: not x.get("chat_history", False),
            # If no chat history, then we just pass input to retriever
            (lambda x: x["input"]) | retriever,
        ),
        # If chat history, then we pass inputs to LLM chain, then to retriever
        prompt | llm | StrOutputParser() | retriever,
    ).with_config(run_name="chat_retriever_chain")
    return retrieve_documents
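A hedged sketch of the branching behavior (not from this PR): the echo retriever makes visible which query string actually reaches retrieval, and a `RunnableLambda` stands in for the LLM so the snippet runs without API keys; with a real chat model the rephrased query would depend on the prompt and history.

```python
# Illustrative sketch; the fake "llm" and echo retriever are stand-ins, not PR code.
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda

from langchain.chains import create_history_aware_retriever

# Echo retriever: shows which query string was used for retrieval.
retriever = RunnableLambda(
    lambda query: [Document(page_content=f"retrieved for: {query}")]
)

# Stand-in "LLM" that returns a canned rephrased query; any chat model works here.
llm = RunnableLambda(lambda _prompt: "standalone question about streaming")

rephrase_prompt = ChatPromptTemplate.from_messages(
    [
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        ("human", "Rephrase the above as a standalone search query."),
    ]
)

chain = create_history_aware_retriever(llm, retriever, rephrase_prompt)

# Empty chat history: `input` goes straight to the retriever.
print(chain.invoke({"input": "what is LCEL?", "chat_history": []}))
# -> [Document(page_content='retrieved for: what is LCEL?')]

# Non-empty chat history: the prompt | llm branch generates the search query.
history = [HumanMessage(content="what is LCEL?"), AIMessage(content="LCEL is ...")]
print(chain.invoke({"input": "how do I stream it?", "chat_history": history}))
# -> [Document(page_content='retrieved for: standalone question about streaming')]
```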
71 changes: 71 additions & 0 deletions libs/langchain/langchain/chains/retrieval.py
@@ -0,0 +1,71 @@
from __future__ import annotations

from typing import Any, Dict, Union

from langchain_core.retrievers import (
    BaseRetriever,
    RetrieverOutput,
)
from langchain_core.runnables import Runnable, RunnablePassthrough


def create_retrieval_chain(
    retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
    combine_docs_chain: Runnable[Dict[str, Any], str],
) -> Runnable:
    """Create retrieval chain that retrieves documents and then passes them on.

    Args:
        retriever: Retriever-like object that returns a list of documents. Should
            either be a subclass of BaseRetriever or a Runnable that returns
            a list of documents. If a subclass of BaseRetriever, then it
            is expected that an `input` key be passed in - this is what
            will be passed into the retriever. If this is NOT a
            subclass of BaseRetriever, then all the inputs will be passed
            into this runnable, meaning that runnable should take a dictionary
            as input.
        combine_docs_chain: Runnable that takes inputs and produces a string output.
            The inputs to this will be any original inputs to this chain, a new
            `context` key with the retrieved documents, and `chat_history` (if not
            present in the inputs) with a value of `[]` (to easily enable
            conversational retrieval).

    Returns:
        An LCEL Runnable. The Runnable return is a dictionary containing at the very
        least a `context` and `answer` key.

    Example:
        .. code-block:: python

            # pip install -U langchain langchain-community

            from langchain_community.chat_models import ChatOpenAI
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain.chains import create_retrieval_chain
            from langchain import hub

            retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
            llm = ChatOpenAI()
            retriever = ...
            combine_docs_chain = create_stuff_documents_chain(
                llm, retrieval_qa_chat_prompt
            )
            retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)

            retrieval_chain.invoke({"input": "..."})

    """
    if not isinstance(retriever, BaseRetriever):
        retrieval_docs: Runnable[dict, RetrieverOutput] = retriever
    else:
        retrieval_docs = (lambda x: x["input"]) | retriever

    retrieval_chain = (
        RunnablePassthrough.assign(
            context=retrieval_docs.with_config(run_name="retrieve_documents"),
            chat_history=lambda x: x.get("chat_history", []),
        )
        | RunnablePassthrough.assign(answer=combine_docs_chain)
    ).with_config(run_name="retrieval_chain")

    return retrieval_chain
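Because the history-aware retriever above is a `Runnable` over a dict rather than a `BaseRetriever` subclass, it drops straight into the `retriever` argument here, giving the conversational retrieval chain the PR title refers to. Another hedged sketch with toy stand-ins; a real setup would use a chat model and something like `create_stuff_documents_chain` for `combine_docs_chain`:

```python
# Illustrative end-to-end sketch; all toy components are stand-ins, not PR code.
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda

from langchain.chains import create_history_aware_retriever, create_retrieval_chain

retriever = RunnableLambda(
    lambda query: [Document(page_content=f"retrieved for: {query}")]
)
llm = RunnableLambda(lambda _prompt: "standalone question")  # swap in a real chat model

rephrase_prompt = ChatPromptTemplate.from_messages(
    [MessagesPlaceholder(variable_name="chat_history"), ("human", "{input}")]
)
history_aware_retriever = create_history_aware_retriever(llm, retriever, rephrase_prompt)

# Placeholder combine_docs_chain: Runnable[Dict[str, Any], str] that just joins the
# retrieved documents; a real chain would prompt an LLM with them instead.
combine_docs_chain = RunnableLambda(
    lambda inputs: "\n".join(doc.page_content for doc in inputs["context"])
)

conversational_chain = create_retrieval_chain(history_aware_retriever, combine_docs_chain)

result = conversational_chain.invoke({"input": "what is LCEL?", "chat_history": []})
print(result["context"])  # the retrieved Documents
print(result["answer"])   # output of combine_docs_chain
```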