-
Notifications
You must be signed in to change notification settings - Fork 15.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enable serde retrieval qa with sources (#10132)
#3983 mentions serialization/deserialization issues with both `RetrievalQA` & `RetrievalQAWithSourcesChain`. `RetrievalQA` has already been fixed in #5818. Mimicing #5818, I added the logic for `RetrievalQAWithSourcesChain`. --------- Co-authored-by: Markus Tretzmüller <markus.tretzmueller@cortecs.at> Co-authored-by: Bagatur <baskaryan@gmail.com>
- Loading branch information
1 parent
62fa2bc
commit b3a8fc7
Showing
3 changed files
with
59 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
libs/langchain/tests/integration_tests/chains/test_retrieval_qa_with_sources.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""Test RetrievalQA functionality.""" | ||
from langchain.chains import RetrievalQAWithSourcesChain | ||
from langchain.chains.loading import load_chain | ||
from langchain.document_loaders import DirectoryLoader | ||
from langchain.embeddings.openai import OpenAIEmbeddings | ||
from langchain.llms import OpenAI | ||
from langchain.text_splitter import CharacterTextSplitter | ||
from langchain.vectorstores import FAISS | ||
|
||
|
||
def test_retrieval_qa_with_sources_chain_saving_loading(tmp_path: str) -> None: | ||
"""Test saving and loading.""" | ||
loader = DirectoryLoader("docs/extras/modules/", glob="*.txt") | ||
documents = loader.load() | ||
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) | ||
texts = text_splitter.split_documents(documents) | ||
embeddings = OpenAIEmbeddings() | ||
docsearch = FAISS.from_documents(texts, embeddings) | ||
qa = RetrievalQAWithSourcesChain.from_llm( | ||
llm=OpenAI(), retriever=docsearch.as_retriever() | ||
) | ||
qa("What did the president say about Ketanji Brown Jackson?") | ||
|
||
file_path = tmp_path + "/RetrievalQAWithSourcesChain.yaml" | ||
qa.save(file_path=file_path) | ||
qa_loaded = load_chain(file_path, retriever=docsearch.as_retriever()) | ||
|
||
assert qa_loaded == qa |