langchain[minor]: Add retriever for Knowledge Bases for Amazon Bedrock (#13980)

- **Description:** Adds a retriever implementation for [Knowledge Bases
for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/), a
new service announced at AWS re:Invent, shortly before this PR was
opened. This depends on the `bedrock-agent-runtime` service, which will
be included in a future version of `boto3` and of `botocore`. We will
open a follow-up PR documenting the minimum required versions of `boto3`
and `botocore` after that information is available.
  - **Issue:** N/A
  - **Dependencies:** `boto3>=1.33.2, botocore>=1.33.2`
  - **Tag maintainer:** @baskaryan
  - **Twitter handles:** `@pjain7` `@dead_letter_q`

This PR includes a documentation notebook under
`docs/docs/integrations/retrievers`, which I (@dlqqq) have verified
independently.

EDIT: `bedrock-agent-runtime` service is now included in
`boto3>=1.33.2`:
boto/boto3@5cf793f
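
To check whether a local environment meets the `boto3>=1.33.2, botocore>=1.33.2` requirement listed above, something like the following could be used. This is a minimal sketch using only the standard library; the helper names `parse_version`, `meets_minimum`, and `check_package` are hypothetical and not part of this PR, and the version parsing is deliberately simple (it ignores pre-release suffixes).

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text: str) -> tuple:
    """Turn '1.33.2' into (1, 33, 2), stopping at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "1.33.2") -> bool:
    """Return True if the installed version is at least the minimum."""
    return parse_version(installed) >= parse_version(minimum)


def check_package(name: str, minimum: str = "1.33.2") -> str:
    """Describe whether the named package is installed and recent enough."""
    try:
        installed = version(name)
    except PackageNotFoundError:
        return f"{name} is not installed"
    status = "ok" if meets_minimum(installed, minimum) else f"older than {minimum}"
    return f"{name} {installed}: {status}"


if __name__ == "__main__":
    for package in ("boto3", "botocore"):
        print(check_package(package))
```

If either package is reported as older than 1.33.2, upgrading it should make the `bedrock-agent-runtime` service available to `session.client(...)`.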

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
4 people committed Nov 28, 2023
1 parent 1aed2d1 commit 9fb6805
Showing 4 changed files with 244 additions and 0 deletions.
117 changes: 117 additions & 0 deletions docs/docs/integrations/retrievers/bedrock.ipynb
@@ -0,0 +1,117 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b6636c27-35da-4ba7-8313-eca21660cab3",
"metadata": {},
"source": [
"# Amazon Bedrock (Knowledge Bases)\n",
"\n",
"> [Knowledge Bases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/) is an Amazon Web Services (AWS) offering that lets you quickly build RAG applications by using your private data to customize foundation model (FM) responses.\n",
"\n",
"> Implementing RAG requires organizations to perform several cumbersome steps to convert data into embeddings (vectors), store the embeddings in a specialized vector database, and build custom integrations into the database to search and retrieve text relevant to the user’s query. This can be time-consuming and inefficient.\n",
"\n",
"> With Knowledge Bases for Amazon Bedrock, simply point to the location of your data in Amazon S3, and Knowledge Bases for Amazon Bedrock takes care of the entire ingestion workflow into your vector database. If you do not have an existing vector database, Amazon Bedrock creates an Amazon OpenSearch Serverless vector store for you. For retrievals, use the LangChain - Amazon Bedrock integration via the Retrieve API to retrieve relevant results for a user query from knowledge bases.\n",
"\n",
"> A knowledge base can be configured through the [AWS Console](https://aws.amazon.com/console/) or by using the [AWS SDKs](https://aws.amazon.com/developer/tools/)."
]
},
{
"cell_type": "markdown",
"id": "b34c8cbe-c6e5-4398-adf1-4925204bcaed",
"metadata": {},
"source": [
"## Using the Knowledge Bases Retriever"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26c97d36-911c-4fe0-a478-546192728f30",
"metadata": {},
"outputs": [],
"source": [
"%pip install boto3"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "30337664-8844-4dfe-97db-077abb51af68",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import AmazonKnowledgeBasesRetriever\n",
"\n",
"retriever = AmazonKnowledgeBasesRetriever(\n",
"    knowledge_base_id=\"PUIJP4EQUA\",\n",
"    retrieval_config={\"vectorSearchConfiguration\": {\"numberOfResults\": 4}},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9fefa50-f0fb-40e3-b4e4-67c5b232a090",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown?\"\n",
"\n",
"retriever.get_relevant_documents(query=query)"
]
},
{
"cell_type": "markdown",
"id": "7de9b61b-597b-4aba-95fb-49d11e84510e",
"metadata": {},
"source": [
"### Using in a QA Chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fd71709-aaed-42b5-a990-e3067bfa7143",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import RetrievalQA\n",
"from langchain.llms import Bedrock\n",
"\n",
"model_kwargs_claude = {\"temperature\": 0, \"top_k\": 10, \"max_tokens_to_sample\": 3000}\n",
"\n",
"llm = Bedrock(model_id=\"anthropic.claude-v2\", model_kwargs=model_kwargs_claude)\n",
"\n",
"qa = RetrievalQA.from_chain_type(\n",
"    llm=llm, retriever=retriever, return_source_documents=True\n",
")\n",
"\n",
"qa(query)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions libs/langchain/langchain/retrievers/__init__.py
@@ -21,6 +21,7 @@
from langchain.retrievers.arcee import ArceeRetriever
from langchain.retrievers.arxiv import ArxivRetriever
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.retrievers.bedrock import AmazonKnowledgeBasesRetriever
from langchain.retrievers.bm25 import BM25Retriever
from langchain.retrievers.chaindesk import ChaindeskRetriever
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
@@ -72,6 +73,7 @@

__all__ = [
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AzureCognitiveSearchRetriever",
124 changes: 124 additions & 0 deletions libs/langchain/langchain/retrievers/bedrock.py
@@ -0,0 +1,124 @@
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.retrievers import BaseRetriever


class VectorSearchConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
    numberOfResults: int = 4


class RetrievalConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
    vectorSearchConfiguration: VectorSearchConfig


class AmazonKnowledgeBasesRetriever(BaseRetriever):
    """A retriever class for `Amazon Bedrock Knowledge Bases`.

    See https://aws.amazon.com/bedrock/knowledge-bases for more info.

    Args:
        knowledge_base_id: Knowledge Base ID.
        region_name: The aws region e.g., `us-west-2`.
            Fallback to AWS_DEFAULT_REGION env variable or region specified in
            ~/.aws/config.
        credentials_profile_name: The name of the profile in the ~/.aws/credentials
            or ~/.aws/config files, which has either access keys or role information
            specified. If not specified, the default credential profile or, if on an
            EC2 instance, credentials from IMDS will be used.
        client: boto3 client for bedrock agent runtime.
        retrieval_config: Configuration for retrieval.

    Example:
        .. code-block:: python

            from langchain.retrievers import AmazonKnowledgeBasesRetriever

            retriever = AmazonKnowledgeBasesRetriever(
                knowledge_base_id="<knowledge-base-id>",
                retrieval_config={
                    "vectorSearchConfiguration": {
                        "numberOfResults": 4
                    }
                },
            )
    """

    knowledge_base_id: str
    region_name: Optional[str] = None
    credentials_profile_name: Optional[str] = None
    endpoint_url: Optional[str] = None
    client: Any
    retrieval_config: RetrievalConfig

    @root_validator(pre=True)
    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # A caller-supplied client takes precedence; skip boto3 setup entirely.
        if values.get("client") is not None:
            return values

        try:
            import boto3
            from botocore.client import Config
            from botocore.exceptions import UnknownServiceError

            if values.get("credentials_profile_name"):
                session = boto3.Session(profile_name=values["credentials_profile_name"])
            else:
                # use default credentials
                session = boto3.Session()

            client_params = {
                "config": Config(
                    connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
                )
            }
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]

            if values.get("endpoint_url"):
                client_params["endpoint_url"] = values["endpoint_url"]

            values["client"] = session.client("bedrock-agent-runtime", **client_params)

            return values
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except UnknownServiceError as e:
            raise ModuleNotFoundError(
                "Ensure that you have installed the latest boto3 package "
                "that contains the API for `bedrock-agent-runtime`."
            ) from e
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        response = self.client.retrieve(
            retrievalQuery={"text": query.strip()},
            knowledgeBaseId=self.knowledge_base_id,
            retrievalConfiguration=self.retrieval_config.dict(),
        )
        results = response["retrievalResults"]
        documents = []
        for result in results:
            documents.append(
                Document(
                    page_content=result["content"]["text"],
                    metadata={
                        "location": result["location"],
                        # Fall back to 0 when the result carries no score.
                        "score": result["score"] if "score" in result else 0,
                    },
                )
            )

        return documents
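
The request and response mapping in `_get_relevant_documents` can be exercised without AWS credentials by substituting a fake client. The sketch below is an illustration only and is not part of this commit: `Doc` is a hypothetical stand-in for `langchain_core.documents.Document`, `retrieve_documents` is a hypothetical free function mirroring the method above, and the sample response is invented data shaped the way the method expects (`retrievalResults` entries with `content.text`, `location`, and an optional `score`).

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List
from unittest.mock import MagicMock


@dataclass
class Doc:
    """Hypothetical stand-in for langchain_core.documents.Document."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def retrieve_documents(
    client: Any, knowledge_base_id: str, query: str, number_of_results: int = 4
) -> List[Doc]:
    """Mirror the retriever's call to `retrieve` and its result mapping."""
    response = client.retrieve(
        retrievalQuery={"text": query.strip()},
        knowledgeBaseId=knowledge_base_id,
        retrievalConfiguration={
            "vectorSearchConfiguration": {"numberOfResults": number_of_results}
        },
    )
    return [
        Doc(
            page_content=result["content"]["text"],
            metadata={
                "location": result["location"],
                # Same fallback as the retriever: a missing score becomes 0.
                "score": result.get("score", 0),
            },
        )
        for result in response["retrievalResults"]
    ]


# Fake client returning invented sample data; no AWS call is made.
fake_client = MagicMock()
fake_client.retrieve.return_value = {
    "retrievalResults": [
        {
            "content": {"text": "first passage"},
            "location": {"type": "S3"},
            "score": 0.87,
        },
        {"content": {"text": "second passage"}, "location": {"type": "S3"}},
    ]
}

docs = retrieve_documents(fake_client, "<knowledge-base-id>", "  example query  ")
for doc in docs:
    print(doc.page_content, doc.metadata["score"])
```

The same idea applies to the real class: because `create_client` returns early when `client` is supplied, passing a stub as `client=` should let the retriever be unit tested without boto3 creating a session.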
1 change: 1 addition & 0 deletions libs/langchain/tests/unit_tests/retrievers/test_imports.py
@@ -2,6 +2,7 @@

EXPECTED_ALL = [
"AmazonKendraRetriever",
"AmazonKnowledgeBasesRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AzureCognitiveSearchRetriever",
