langchain[minor]: Add retriever for Knowledge Bases for Amazon Bedrock (#13980)

- **Description:** Adds a retriever implementation for [Knowledge Bases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/), a new service announced at AWS re:Invent, shortly before this PR was opened. This depends on the `bedrock-agent-runtime` service, which will be included in a future version of `boto3` and of `botocore`. We will open a follow-up PR documenting the minimum required versions of `boto3` and `botocore` after that information is available.
- **Issue:** N/A
- **Dependencies:** `boto3>=1.33.2, botocore>=1.33.2`
- **Tag maintainer:** @baskaryan
- **Twitter handles:** `@pjain7` `@dead_letter_q`

This PR includes a documentation notebook under `docs/docs/integrations/retrievers`, which I (@dlqqq) have verified independently.

EDIT: `bedrock-agent-runtime` service is now included in `boto3>=1.33.2`: boto/boto3@5cf793f

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
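The dependency floor stated above (`boto3>=1.33.2`) could be checked at runtime before the retriever is constructed. The following is a minimal sketch, assuming plain dotted release versions; `parse_version` and `meets_minimum` are hypothetical helpers written for illustration and are not part of this PR.

```python
import re
from typing import Tuple

# Floor taken from the commit message; everything else here is illustrative.
MIN_BOTO3: Tuple[int, ...] = (1, 33, 2)


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.33.2' into (1, 33, 2), stopping at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(version: str, minimum: Tuple[int, ...] = MIN_BOTO3) -> bool:
    """Return True when the installed version is at or above the floor."""
    return parse_version(version) >= minimum
```

With `boto3` installed, the check would be `meets_minimum(boto3.__version__)`. Comparing integer tuples rather than raw strings avoids ordering `1.33.10` before `1.33.2`.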
1 parent 1aed2d1, commit 9fb6805
Showing 4 changed files with 244 additions and 0 deletions.
@@ -0,0 +1,117 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b6636c27-35da-4ba7-8313-eca21660cab3",
   "metadata": {},
   "source": [
    "# Amazon Bedrock (Knowledge Bases)\n",
    "\n",
    "> [Knowledge Bases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/) is an Amazon Web Services (AWS) offering that lets you quickly build RAG applications by using your private data to customize foundation model (FM) responses.\n",
    "\n",
    "> Implementing RAG requires organizations to perform several cumbersome steps to convert data into embeddings (vectors), store the embeddings in a specialized vector database, and build custom integrations into the database to search and retrieve text relevant to the user’s query. This can be time-consuming and inefficient.\n",
    "\n",
    "> With Knowledge Bases for Amazon Bedrock, simply point to the location of your data in Amazon S3, and Knowledge Bases for Amazon Bedrock takes care of the entire ingestion workflow into your vector database. If you do not have an existing vector database, Amazon Bedrock creates an Amazon OpenSearch Serverless vector store for you. For retrievals, use the LangChain - Amazon Bedrock integration via the Retrieve API to retrieve relevant results for a user query from knowledge bases.\n",
    "\n",
    "> A knowledge base can be configured through the [AWS Console](https://aws.amazon.com/console/) or by using the [AWS SDKs](https://aws.amazon.com/developer/tools/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b34c8cbe-c6e5-4398-adf1-4925204bcaed",
   "metadata": {},
   "source": [
    "## Using the Knowledge Bases Retriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c97d36-911c-4fe0-a478-546192728f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install boto3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "30337664-8844-4dfe-97db-077abb51af68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers import AmazonKnowledgeBasesRetriever\n",
    "\n",
    "retriever = AmazonKnowledgeBasesRetriever(\n",
    "    knowledge_base_id=\"PUIJP4EQUA\",\n",
    "    retrieval_config={\"vectorSearchConfiguration\": {\"numberOfResults\": 4}},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9fefa50-f0fb-40e3-b4e4-67c5b232a090",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What did the president say about Ketanji Brown?\"\n",
    "\n",
    "retriever.get_relevant_documents(query=query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7de9b61b-597b-4aba-95fb-49d11e84510e",
   "metadata": {},
   "source": [
    "### Using in a QA Chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fd71709-aaed-42b5-a990-e3067bfa7143",
   "metadata": {},
   "outputs": [],
   "source": [
    "from botocore.client import Config\n",
    "\n",
    "from langchain.chains import RetrievalQA\n",
    "from langchain.llms import Bedrock\n",
    "\n",
    "model_kwargs_claude = {\"temperature\": 0, \"top_k\": 10, \"max_tokens_to_sample\": 3000}\n",
    "\n",
    "llm = Bedrock(model_id=\"anthropic.claude-v2\", model_kwargs=model_kwargs_claude)\n",
    "\n",
    "qa = RetrievalQA.from_chain_type(\n",
    "    llm=llm, retriever=retriever, return_source_documents=True\n",
    ")\n",
    "\n",
    "qa(query)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
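The retriever added in this commit stores each result's relevance under `metadata["score"]` (defaulting to 0 when the service omits it), so the documents returned by the notebook's `get_relevant_documents` call can be post-filtered. The sketch below assumes nothing beyond that metadata key; `Doc` is a stand-in for LangChain's `Document`, and `filter_by_score` is a hypothetical helper, not something this PR provides.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Doc:
    """Stand-in for langchain_core.documents.Document (illustration only)."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def filter_by_score(docs: List[Doc], min_score: float) -> List[Doc]:
    """Keep documents scoring at least min_score, highest score first."""
    kept = [d for d in docs if d.metadata.get("score", 0) >= min_score]
    return sorted(kept, key=lambda d: d.metadata.get("score", 0), reverse=True)


# Hand-written sample data; real scores come from the Retrieve API.
sample = [
    Doc("low relevance", {"score": 0.2}),
    Doc("high relevance", {"score": 0.9}),
    Doc("no score reported"),
    Doc("medium relevance", {"score": 0.6}),
]
top = filter_by_score(sample, min_score=0.5)
```

Against the real retriever the call would look like `filter_by_score(retriever.get_relevant_documents(query=query), 0.5)`. The threshold is arbitrary and would need tuning per knowledge base.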
@@ -0,0 +1,124 @@
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.retrievers import BaseRetriever


class VectorSearchConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
    numberOfResults: int = 4


class RetrievalConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
    vectorSearchConfiguration: VectorSearchConfig


class AmazonKnowledgeBasesRetriever(BaseRetriever):
    """A retriever class for `Amazon Bedrock Knowledge Bases`.

    See https://aws.amazon.com/bedrock/knowledge-bases for more info.

    Args:
        knowledge_base_id: Knowledge Base ID.
        region_name: The AWS region, e.g., `us-west-2`.
            Falls back to the AWS_DEFAULT_REGION env variable or the region
            specified in ~/.aws/config.
        credentials_profile_name: The name of the profile in the ~/.aws/credentials
            or ~/.aws/config files, which has either access keys or role information
            specified. If not specified, the default credential profile or, if on an
            EC2 instance, credentials from IMDS will be used.
        client: boto3 client for bedrock agent runtime.
        retrieval_config: Configuration for retrieval.

    Example:
        .. code-block:: python

            from langchain.retrievers import AmazonKnowledgeBasesRetriever

            retriever = AmazonKnowledgeBasesRetriever(
                knowledge_base_id="<knowledge-base-id>",
                retrieval_config={
                    "vectorSearchConfiguration": {
                        "numberOfResults": 4
                    }
                },
            )
    """

    knowledge_base_id: str
    region_name: Optional[str] = None
    credentials_profile_name: Optional[str] = None
    endpoint_url: Optional[str] = None
    client: Any
    retrieval_config: RetrievalConfig

    @root_validator(pre=True)
    def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if values.get("client") is not None:
            return values

        try:
            import boto3
            from botocore.client import Config
            from botocore.exceptions import UnknownServiceError

            if values.get("credentials_profile_name"):
                session = boto3.Session(profile_name=values["credentials_profile_name"])
            else:
                # use default credentials
                session = boto3.Session()

            client_params = {
                "config": Config(
                    connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
                )
            }
            if values.get("region_name"):
                client_params["region_name"] = values["region_name"]

            if values.get("endpoint_url"):
                client_params["endpoint_url"] = values["endpoint_url"]

            values["client"] = session.client("bedrock-agent-runtime", **client_params)

            return values
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        except UnknownServiceError as e:
            raise ModuleNotFoundError(
                "Ensure that you have installed the latest boto3 package "
                "that contains the API for `bedrock-agent-runtime`."
            ) from e
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        response = self.client.retrieve(
            retrievalQuery={"text": query.strip()},
            knowledgeBaseId=self.knowledge_base_id,
            retrievalConfiguration=self.retrieval_config.dict(),
        )
        results = response["retrievalResults"]
        documents = []
        for result in results:
            documents.append(
                Document(
                    page_content=result["content"]["text"],
                    metadata={
                        "location": result["location"],
                        "score": result["score"] if "score" in result else 0,
                    },
                )
            )

        return documents
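Because `create_client` returns early when a `client` is supplied, the response-to-document mapping in `_get_relevant_documents` can be exercised without AWS access. The following sketch mirrors that mapping against a stub; `StubClient`, `Doc`, and `to_documents` are hypothetical names introduced here, and the sample `location` payload is an assumed shape used only as test data.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Doc:
    """Stand-in for langchain_core.documents.Document (illustration only)."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


class StubClient:
    """Fake client exposing only the `retrieve` call the retriever uses."""

    def __init__(self, results: List[Dict[str, Any]]) -> None:
        self.results = results
        self.calls: List[Dict[str, Any]] = []

    def retrieve(self, **kwargs: Any) -> Dict[str, Any]:
        self.calls.append(kwargs)  # record arguments for later inspection
        return {"retrievalResults": self.results}


def to_documents(client: Any, knowledge_base_id: str, query: str) -> List[Doc]:
    """Mirror the retriever's mapping: text -> page_content, location/score -> metadata."""
    response = client.retrieve(
        retrievalQuery={"text": query.strip()},
        knowledgeBaseId=knowledge_base_id,
        retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 4}},
    )
    return [
        Doc(
            page_content=result["content"]["text"],
            metadata={"location": result["location"], "score": result.get("score", 0)},
        )
        for result in response["retrievalResults"]
    ]


# Assumed sample payload; the second result deliberately omits "score".
stub = StubClient(
    [
        {"content": {"text": "first chunk"}, "location": {"type": "S3"}, "score": 0.8},
        {"content": {"text": "second chunk"}, "location": {"type": "S3"}},
    ]
)
docs = to_documents(stub, "<knowledge-base-id>", "  example query  ")
```

The same stub could be passed as `client=` to `AmazonKnowledgeBasesRetriever` in a unit test, which skips session creation entirely.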