# Bedrock Knowledge Base Retrieval and Generation with SageMaker Inference and Reranking

![Reranking](./reranking.png)

## 1: Import Required Functions

In [None]:
# Import necessary functions from advanced_rag_utils
from advanced_rag_utils import (
    load_variables,
    setup_bedrock_client,
    format_llama3_prompt,
    search_knowledge_base_with_reranking,
    enhanced_generate_sagemaker_response,
    compare_reranking
)

# AWS SDK for Python (not referenced directly by the cells below)
import boto3

## 2: Load Configuration Variables

In [None]:
# Read the shared configuration file. Later cells look up the knowledge base
# IDs, the SageMaker endpoint name and the AWS region from it.
config_file = "../variables.json"
variables = load_variables(config_file)
variables  # last expression in the cell: Jupyter displays it as the output

## 3: Define Configuration Details

In [None]:
# Knowledge base to query. Swap the key to try a different chunking strategy.
kb_id = variables["kbFixedChunk"]  # Options: "kbFixedChunk", "kbHierarchicalChunk", "kbSemanticChunk"

# SageMaker endpoint that serves the generation LLM
sagemaker_endpoint = variables["sagemakerLLMEndpoint"]

# Retrieval-Augmented Generation (RAG) configuration
number_of_results = 10  # Number of relevant documents to retrieve
generation_configuration = {
    "temperature": 0,  # Lower temperature for more deterministic responses
    "top_k": 10,  # Consider top 10 tokens at each generation step
    "max_new_tokens": 5000,  # Maximum number of tokens to generate
    "stop": "<|eot_id|>",  # Llama 3 end-of-turn token; ends the response generation
}

# ARN of the Bedrock reranking model. It is built from the configured region
# rather than hard-coded to us-west-2. That keeps the reranker in the same
# region as the Bedrock client and the region_name passed to every search below.
# NOTE(review): Cohere Rerank 3.5 is offered only in some Bedrock regions;
# confirm it is available in variables["regionName"].
rerank_model_arn = (
    f"arn:aws:bedrock:{variables['regionName']}::foundation-model/cohere.rerank-v3-5:0"
)

# Default user query. The comparison cells below assign their own query
# before performing any search.
query = "what was the % increase in sales?"

## 4: Initialize Bedrock Client

In [None]:
# Create the Bedrock agent runtime client in the configured AWS region.
# Every knowledge-base search below receives this client.
aws_region = variables["regionName"]
bedrock_client = setup_bedrock_client(aws_region)

## 5: Compare Search Results With & Without Reranking

In [None]:
# Replace the default query with a question that spans two years of results
query = "Compare the results between 2022 and 2023"

# Baseline: fetch `number_of_results` chunks with the reranking step disabled
print("WITHOUT RERANKING:")
baseline_search = search_knowledge_base_with_reranking(
    query=query,
    knowledge_base_id=kb_id,
    bedrock_client=bedrock_client,
    num_results=number_of_results,
    use_reranking=False,
    region_name=variables["regionName"],
)
context_without_reranking, details_without_reranking = baseline_search

In [None]:
# Same query and result count, this time with the Bedrock reranking model enabled
print("\nWITH RERANKING:")
reranked_search = search_knowledge_base_with_reranking(
    query=query,
    knowledge_base_id=kb_id,
    bedrock_client=bedrock_client,
    num_results=number_of_results,
    use_reranking=True,
    rerank_model_arn=rerank_model_arn,
    region_name=variables["regionName"],
)
context_with_reranking, details_with_reranking = reranked_search

## 6: Compare Generated Responses With & Without Reranking

In [None]:
# Answer the question from the context retrieved WITHOUT reranking
print("WITHOUT RERANKING:")

# Wrap the question and the retrieved context in the Llama 3 prompt template
prompt_without_reranking = format_llama3_prompt(query, context_without_reranking)

# Run the prompt through the SageMaker-hosted model with the shared settings
response_without_reranking = enhanced_generate_sagemaker_response(
    prompt=prompt_without_reranking,
    endpoint_name=sagemaker_endpoint,
    generation_config=generation_configuration,
)

# Show the question alongside the generated answer
for label, text in (("Question:", query), ("Answer:", response_without_reranking)):
    print(label, text)

In [None]:
# Answer the same question from the context retrieved WITH reranking
print("WITH RERANKING:")

# Wrap the question and the reranked context in the Llama 3 prompt template
prompt_with_reranking = format_llama3_prompt(query, context_with_reranking)

# Run the prompt through the SageMaker-hosted model with the shared settings
response_with_reranking = enhanced_generate_sagemaker_response(
    prompt=prompt_with_reranking,
    endpoint_name=sagemaker_endpoint,
    generation_config=generation_configuration,
)

# Show the question alongside the generated answer
for label, text in (("Question:", query), ("Answer:", response_with_reranking)):
    print(label, text)

## 7: All-in-One Comparison (Alternative Approach)

In [None]:
# Single-call alternative. compare_reranking returns results for both the
# "without_reranking" and "with_reranking" variants of the same query.
query = "What are the key financial metrics for Amazon in 2023?"

comparison_results = compare_reranking(
    query=query,
    knowledge_base_id=kb_id,
    sagemaker_endpoint=sagemaker_endpoint,
    rerank_model_arn=rerank_model_arn,
    generation_config=generation_configuration,
    bedrock_client=bedrock_client,
    num_results=number_of_results,
    region_name=variables["regionName"],
)

# Print the two generated answers one after the other for comparison
print("\n-------- FINAL RESPONSE COMPARISON --------\n")
print("QUESTION: ", query)

for heading, variant in (
    ("\nRESPONSE WITHOUT RERANKING:", "without_reranking"),
    ("\nRESPONSE WITH RERANKING:", "with_reranking"),
):
    print(heading)
    print(comparison_results[variant]["response"])