# Bedrock Knowledge Base Retrieval and Generation with SageMaker Inference and Metadata Filtering  

### Description:  
This notebook shows how to query an Amazon Bedrock knowledge base, narrow retrieval with metadata filters, and generate answers through a SageMaker inference endpoint. It covers configuring the query, applying a metadata filter, retrieving context, generating a response, and extracting the citations behind the generated result.

![Metadata Filtering](./metadata_filtering.png)

## 1. Load Configuration Variables

In [1]:
# Load configuration variables from a JSON file to access knowledge base ID, account number, and guardrail info.
import json

with open("../Lab 1/variables.json", "r") as f:
    variables = json.load(f)

variables  # Display the loaded variables for confirmation

{'accountNumber': '307297743176',
 'regionName': 'us-west-2',
 'collectionArn': 'arn:aws:aoss:us-west-2:307297743176:collection/h7cmj732p9d3v91spkhd',
 'collectionId': 'h7cmj732p9d3v91spkhd',
 'vectorIndexName': 'ws-index-',
 'bedrockExecutionRoleArn': 'arn:aws:iam::307297743176:role/advanced-rag-workshop-bedrock_execution_role-us-west-2',
 's3Bucket': '307297743176-us-west-2-advanced-rag-workshop',
 'kbFixedChunk': '4P6PBDDEGL',
 'kbSemanticChunk': 'IC3ZCBORXT',
 'kbCustomChunk': 'Q2T9CZ5VFA',
 'kbHierarchicalChunk': '1YIFVW0Z5E',
 'sagemakerLLMEndpoint': 'endpoint-llama-3-2-3b-instruct-2025-04-07-16-05-17',
 'guardrail_id': 'fe7ryshi7i7b',
 'guardrail_version': '1'}
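
Before moving on, it can help to confirm that the keys used later in this notebook are present in the loaded file. The following is a minimal sketch, not part of the original workshop code: `REQUIRED_KEYS` and `missing_keys` are hypothetical names, and the in-memory `example_config` stands in for the contents of `variables.json`.

```python
import json

# Keys this notebook reads in later cells (hypothetical list; adjust to your variables.json)
REQUIRED_KEYS = ("regionName", "kbFixedChunk", "sagemakerLLMEndpoint")


def missing_keys(config, required=REQUIRED_KEYS):
    """Return the required keys that are absent or empty in config."""
    return [key for key in required if not config.get(key)]


# In-memory stand-in for the loaded file; the values are placeholders
example_config = json.loads('{"regionName": "us-west-2", "kbFixedChunk": "EXAMPLEKBID"}')
print(missing_keys(example_config))  # prints ['sagemakerLLMEndpoint']
```

In the notebook itself, calling `missing_keys(variables)` right after loading should return an empty list if Lab 1 completed successfully.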

## 2. Set Up Required IDs and Model ARNs

In [2]:
# Knowledge Base Selection  
kb_id = variables["kbFixedChunk"]  # Options: "kbFixedChunk", "kbSemanticChunk", "kbCustomChunk", "kbHierarchicalChunk"

# Retrieval-Augmented Generation (RAG) Configuration  
number_of_results = 3  # Number of relevant documents to retrieve  
generation_configuration = {
    "temperature": 0,  # Lower temperature for more deterministic responses  
    "top_k": 10,  # Consider top 10 tokens at each generation step  
    "max_new_tokens": 5000,  # Maximum number of tokens to generate  
    "stop": "<|eot_id|>"  # Stop sequence to end the response generation  
}

# User Query
query = "what was the % increase in sales?"  # Sample query to retrieve data from the knowledge base


## 3. Define Metadata Filter

In [3]:
# Metadata filter: keep only chunks whose docType is "10K Report" AND whose year is 2023
one_group_filter = {
    "andAll": [
        {
            "equals": {
                "key": "docType",
                "value": "10K Report"
            }
        },
        {
            "equals": {
                "key": "year",
                "value": 2023
            }
        }
    ]
}
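
Writing the nested filter by hand gets repetitive once more metadata keys are involved. The sketch below builds the same structure from a plain dictionary. `build_equals_filter` is a hypothetical helper, not part of the Bedrock SDK. It assumes, to the best of my understanding of the Bedrock filter schema, that `andAll` expects at least two sub-filters, so a single condition is returned as a bare `equals` clause.

```python
def build_equals_filter(conditions):
    """Build a metadata filter that requires every key to equal its value."""
    clauses = [
        {"equals": {"key": key, "value": value}}
        for key, value in conditions.items()
    ]
    if not clauses:
        raise ValueError("at least one condition is required")
    # Assumption: andAll needs two or more sub-filters, so return a lone clause as-is
    if len(clauses) == 1:
        return clauses[0]
    return {"andAll": clauses}


# Produces the same structure as the hand-written filter in this section
example_filter = build_equals_filter({"docType": "10K Report", "year": 2023})
```

Note that `2023` is passed as a number, not a string. The filter value type should match how the `year` attribute was stored in the document metadata.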


## 4. Define SageMaker & Bedrock helper functions

In [4]:
import boto3

# Initialize Bedrock client to interact with the Bedrock Knowledge Base
bedrock_agent_runtime = boto3.client("bedrock-agent-runtime", region_name=variables["regionName"])

# Constants for Knowledge Base ID, SageMaker endpoint, and number of results to retrieve
KNOWLEDGE_BASE_ID = kb_id
ENDPOINT_NAME = variables['sagemakerLLMEndpoint']
NUM_RESULTS = number_of_results

# Function to retrieve relevant context from the Bedrock Knowledge Base
def retrieve_from_bedrock(query):
    """Retrieve relevant context from Bedrock Knowledge Base"""
    try:
        # Retrieve context based on the query using vector search configuration
        response = bedrock_agent_runtime.retrieve(
            knowledgeBaseId=KNOWLEDGE_BASE_ID,
            retrievalQuery={
                'text': query  # The query text to search in the knowledge base
            },
            retrievalConfiguration={
                'vectorSearchConfiguration': {
                    'numberOfResults': NUM_RESULTS,  # Adjust based on needs
                    'filter': one_group_filter  # Metadata filter defined in section 3
                }
            }
        )
        # Extract the 'text' from the retrieval results and return as a list
        return [result['content']['text'] for result in response['retrievalResults']]
    except Exception as e:
        # Chain the original exception so the underlying AWS error stays in the traceback
        raise RuntimeError(f"Bedrock retrieval failed: {e}") from e

# Function to format the prompt for the Llama 3 model using the retrieved context
def format_prompt(query, context):
    """Format a Llama 3 instruct prompt from the query and the retrieved chunks"""
    # Join the retrieved chunks, separated by blank lines
    context_text = "\n\n".join(context)

    # System prompt: instructions followed by the retrieved context
    system_prompt = (
        "Use the following context to answer the question. "
        "If you don't know the answer, say 'I don't know'.\n"
        f"Context:\n{context_text}"
    )

    # Llama 3 chat template: each turn ends with <|eot_id|>, and the prompt
    # ends with an open assistant header so the model writes the answer
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\nQuestion: {query}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

# Function to generate a response from the SageMaker endpoint based on the formatted prompt
def generate_response(prompt):
    """Generate response using SageMaker endpoint"""
    # Initialize SageMaker runtime client in the same region as the Bedrock client
    runtime = boto3.client('sagemaker-runtime', region_name=variables["regionName"])
    
    # Prepare the payload with prompt and generation parameters
    payload = {
        "inputs": prompt,  # The formatted prompt to pass to the model
        "parameters": generation_configuration  # Additional parameters for the model (e.g., temperature, tokens)
    }
    try:
        # Call the SageMaker endpoint to generate the response
        response = runtime.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,  # SageMaker endpoint name
            ContentType='application/json',  # Content type for the request
            Body=json.dumps(payload)  # Send the payload as JSON
        )

        # Parse the response body
        result = json.loads(response['Body'].read().decode("utf-8"))
        
        # Handle different response formats (list or dictionary)
        if isinstance(result, list):
            # If the result is a list, extract the generated text from the first element
            return result[0]['generated_text']
        elif 'generated_text' in result:
            # If the result is a dictionary with 'generated_text', return the generated text
            return result['generated_text']
        elif 'generation' in result:
            # Alternative format with 'generation' key
            return result['generation']
        else:
            # Raise an error if the response format is unexpected
            raise RuntimeError("Unexpected response format")
            
    except Exception as e:
        # Chain the original exception so the underlying error stays in the traceback
        raise RuntimeError(f"Generation failed: {e}") from e
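
`retrieve_from_bedrock` keeps only the chunk text, but each retrieval result also carries the information needed to cite its source. The sketch below shows one way to pull that out of a `retrieve` response. `extract_sources` is a hypothetical helper, and `sample_response` is a hand-written stand-in with illustrative values, not real output. It assumes the documents are stored in S3, so the location appears under `location.s3Location.uri`.

```python
def extract_sources(retrieval_response):
    """Collect source URI, relevance score, and metadata for each retrieved chunk."""
    sources = []
    for result in retrieval_response.get("retrievalResults", []):
        location = result.get("location", {})
        sources.append({
            # Assumption: S3 data source; other source types use different location keys
            "uri": location.get("s3Location", {}).get("uri"),
            "score": result.get("score"),
            "metadata": result.get("metadata", {}),
        })
    return sources


# Hand-written stand-in for a retrieve() response; all values are illustrative
sample_response = {
    "retrievalResults": [
        {
            "content": {"text": "Example chunk text."},
            "location": {"type": "S3", "s3Location": {"uri": "s3://example-bucket/report.pdf"}},
            "score": 0.5,
            "metadata": {"docType": "10K Report", "year": 2023},
        }
    ]
}

for source in extract_sources(sample_response):
    print(source["uri"], source["score"], source["metadata"])
```

To use this in the notebook, `retrieve_from_bedrock` would need to return the full response (or both the texts and the sources) instead of only the text list.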


## 5. Generate Response with Metadata Filter

In [5]:
# Retrieve relevant context from the Bedrock Knowledge Base based on the query
context = retrieve_from_bedrock(query)

# Format the prompt by combining the user's query and the retrieved context
prompt = format_prompt(query, context)

# Generate the response using the formatted prompt by calling the SageMaker endpoint
response = generate_response(prompt)

# Print the user's query
print("Question:", query)

# Uncomment the line below to inspect the retrieved context when debugging
# print(f"Context: {context}")

# Print the generated answer from the model based on the query and context
print("Answer:", response)


Question: what was the % increase in sales?
Answer: 

According to the text, the sales increased 9% in 2022, compared to the prior year.
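
To see what the metadata filter contributes, it is useful to run the same query with and without it. Since `retrieve_from_bedrock` always applies `one_group_filter`, one option is to build the `retrievalConfiguration` argument separately so the filter becomes optional. The following is a minimal sketch. `build_retrieval_config` is a hypothetical helper, not part of boto3.

```python
def build_retrieval_config(num_results, metadata_filter=None):
    """Build the retrievalConfiguration argument, adding a filter only when one is given."""
    vector_config = {"numberOfResults": num_results}
    if metadata_filter is not None:
        vector_config["filter"] = metadata_filter
    return {"vectorSearchConfiguration": vector_config}


# Same number of results, with and without a metadata filter
unfiltered = build_retrieval_config(3)
filtered = build_retrieval_config(3, {"equals": {"key": "year", "value": 2023}})
```

Passing either dictionary as `retrievalConfiguration` to `bedrock_agent_runtime.retrieve` would let you compare the retrieved chunks, and the resulting answers, side by side.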
