Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement saving and loading of RetrievalQA chain #5818

Merged
merged 2 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 22 additions & 1 deletion langchain/chains/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from langchain.chains.pal.base import PALChain
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import VectorDBQA
from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.llms.loading import load_llm, load_llm_from_config
from langchain.prompts.loading import load_prompt, load_prompt_from_config
Expand Down Expand Up @@ -371,6 +371,26 @@ def _load_vector_db_qa_with_sources_chain(
**config,
)

def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
    """Load a ``RetrievalQA`` chain from a config dict.

    Args:
        config: Chain configuration. It must contain either
            ``combine_documents_chain`` (a nested config passed to
            ``load_chain_from_config``) or ``combine_documents_chain_path``
            (passed to ``load_chain``). Every remaining entry is forwarded to
            the ``RetrievalQA`` constructor as a keyword argument.
        **kwargs: Must include ``retriever``, which is passed to the
            ``RetrievalQA`` constructor. Other keyword arguments are ignored.

    Returns:
        The constructed ``RetrievalQA`` chain.

    Raises:
        ValueError: If ``retriever`` is not supplied, or if ``config`` has
            neither ``combine_documents_chain`` nor
            ``combine_documents_chain_path``.
    """
    if "retriever" not in kwargs:
        raise ValueError("`retriever` must be present.")
    retriever = kwargs.pop("retriever")

    # Pop from a shallow copy so the caller's dict is not stripped of its
    # top-level keys as a side effect of loading.
    config = dict(config)
    if "combine_documents_chain" in config:
        combine_documents_chain = load_chain_from_config(
            config.pop("combine_documents_chain")
        )
    elif "combine_documents_chain_path" in config:
        combine_documents_chain = load_chain(
            config.pop("combine_documents_chain_path")
        )
    else:
        raise ValueError(
            "One of `combine_documents_chain` or "
            "`combine_documents_chain_path` must be present."
        )

    return RetrievalQA(
        combine_documents_chain=combine_documents_chain,
        retriever=retriever,
        **config,
    )

def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
if "vectorstore" in kwargs:
Expand Down Expand Up @@ -459,6 +479,7 @@ def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
"sql_database_chain": _load_sql_database_chain,
"vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
"vector_db_qa": _load_vector_db_qa,
"retrieval_qa": _load_retrieval_qa,
}


Expand Down
5 changes: 5 additions & 0 deletions langchain/chains/retrieval_qa/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ def _get_docs(self, question: str) -> List[Document]:
async def _aget_docs(self, question: str) -> List[Document]:
    """Asynchronously fetch documents for ``question`` from ``self.retriever``."""
    documents = await self.retriever.aget_relevant_documents(question)
    return documents

@property
def _chain_type(self) -> str:
    """Chain type identifier for this class.

    Returns:
        The constant string ``"retrieval_qa"``.
    """
    return "retrieval_qa"


class VectorDBQA(BaseRetrievalQA):
"""Chain for question-answering against a vector database."""
Expand Down
29 changes: 29 additions & 0 deletions tests/unit_tests/chains/test_retrieval_qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Test RetrievalQA functionality."""
from pathlib import Path

import pytest

from langchain.chains import RetrievalQA
from langchain.chains.loading import load_chain
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma


def test_retrieval_qa_saving_loading(tmp_path: Path) -> None:
    """Round-trip a ``RetrievalQA`` chain through ``save`` and ``load_chain``."""
    # Build a Chroma-backed retriever over the chunked sample document.
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    raw_docs = TextLoader('docs/modules/state_of_the_union.txt').load()
    chunks = splitter.split_documents(raw_docs)
    store = Chroma.from_documents(chunks, OpenAIEmbeddings())
    chain = RetrievalQA.from_llm(llm=OpenAI(), retriever=store.as_retriever())

    # Save to a temporary YAML file, then load it back. The retriever is
    # passed to ``load_chain`` explicitly as a keyword argument.
    saved_to = tmp_path / "RetrievalQA_chain.yaml"
    chain.save(file_path=saved_to)
    reloaded = load_chain(saved_to, retriever=store.as_retriever())

    assert reloaded == chain