Skip to content

Commit

Permalink
enable serde retrieval qa with sources (#10132)
Browse files Browse the repository at this point in the history
#3983 mentions serialization/deserialization issues with both
`RetrievalQA` & `RetrievalQAWithSourcesChain`.
`RetrievalQA` has already been fixed in #5818. 

Mimicing #5818, I added the logic for `RetrievalQAWithSourcesChain`.

---------

Co-authored-by: Markus Tretzmüller <markus.tretzmueller@cortecs.at>
Co-authored-by: Bagatur <baskaryan@gmail.com>
  • Loading branch information
3 people authored Sep 8, 2023
1 parent 62fa2bc commit b3a8fc7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
26 changes: 26 additions & 0 deletions libs/langchain/langchain/chains/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_requests import LLMRequestsChain
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
from langchain.llms.loading import load_llm, load_llm_from_config
Expand Down Expand Up @@ -424,6 +425,30 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
)


def _load_retrieval_qa_with_sources_chain(
config: dict, **kwargs: Any
) -> RetrievalQAWithSourcesChain:
if "retriever" in kwargs:
retriever = kwargs.pop("retriever")
else:
raise ValueError("`retriever` must be present.")
if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
else:
raise ValueError(
"One of `combine_documents_chain` or "
"`combine_documents_chain_path` must be present."
)
return RetrievalQAWithSourcesChain(
combine_documents_chain=combine_documents_chain,
retriever=retriever,
**config,
)


def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
if "vectorstore" in kwargs:
vectorstore = kwargs.pop("vectorstore")
Expand Down Expand Up @@ -537,6 +562,7 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
"vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
"vector_db_qa": _load_vector_db_qa,
"retrieval_qa": _load_retrieval_qa,
"retrieval_qa_with_sources_chain": _load_retrieval_qa_with_sources_chain,
"graph_cypher_chain": _load_graph_cypher_chain,
}

Expand Down
5 changes: 5 additions & 0 deletions libs/langchain/langchain/chains/qa_with_sources/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,8 @@ async def _aget_docs(
question, callbacks=run_manager.get_child()
)
return self._reduce_tokens_below_limit(docs)

@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "retrieval_qa_with_sources_chain"
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Test RetrievalQA functionality."""
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.loading import load_chain
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS


def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None:
"""Test saving and loading."""
loader = DirectoryLoader("docs/extras/modules/", glob="*.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
embeddings = OpenAIEmbeddings()
docsearch = FAISS.from_documents(texts, embeddings)
qa = RetrievalQAWithSourcesChain.from_llm(
llm=OpenAI(), retriever=docsearch.as_retriever()
)
qa("What did the president say about Ketanji Brown Jackson?")

file_path = tmp_path + "/RetrievalQAWithSourcesChain.yaml"
qa.save(file_path=file_path)
qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever())

assert qa_loaded == qa

0 comments on commit b3a8fc7

Please sign in to comment.