# Bedrock Knowledge Base Retrieval and Generation with Reranking

The Rerank API in Amazon Bedrock is a new feature that improves the accuracy and relevance of responses in Retrieval-Augmented Generation (RAG) applications. It supports reranker models that rank a set of retrieved documents based on their relevance to a user's query, helping to prioritize the most relevant content for response generation.

## Key features and use cases

1. **Enhancing RAG applications**: The Rerank API addresses challenges in semantic search, particularly with complex or ambiguous queries. For example, it can help a customer service chatbot focus on return policies rather than shipping guidelines when asked about returning an online purchase.

2. **Improving search relevance**: It adds a second relevance pass on top of an existing search system, so developers can improve result ordering and content ranking without replacing the underlying retriever.

3. **Optimizing context window usage**: Sending only the most relevant passages to the foundation model can reduce token costs and improve response accuracy.

4. **Flexible integration**: The Rerank API can be used independently to rerank documents even if you're not using Amazon Bedrock Knowledge Bases.

5. **Multiple model support**: At launch, it supports Amazon Rerank 1.0 and Cohere Rerank 3.5 models.

6. **Customizable configurations**: Developers can specify additional model configurations as key-value pairs for more tailored reranking.

The Rerank API is available in select AWS Regions, including US West (Oregon), Canada (Central), Europe (Frankfurt), and Asia Pacific (Tokyo). It can be integrated into existing systems at scale, whether keyword-based or semantic, through a single API call in Amazon Bedrock.

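Because the Rerank API can be called on its own, a request can be assembled without a knowledge base. The sketch below only builds the keyword arguments for the `rerank` operation of the `bedrock-agent-runtime` client; it does not call AWS. `build_rerank_request` and its parameters are hypothetical names introduced for illustration, and passing the key-value model configuration through `additionalModelRequestFields` is an assumption to verify against the boto3 reference.

```python
from typing import Any, Dict, List, Optional


def build_rerank_request(
    query: str,
    documents: List[str],
    model_arn: str,
    top_n: int = 5,
    extra_model_fields: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build keyword arguments for a standalone rerank call (hypothetical helper)."""
    if not documents:
        raise ValueError("At least one document is required for reranking")

    model_config: Dict[str, Any] = {"modelArn": model_arn}
    if extra_model_fields:
        # Assumption: extra key-value model settings are passed in this field
        model_config["additionalModelRequestFields"] = dict(extra_model_fields)

    return {
        "queries": [{"type": "TEXT", "textQuery": {"text": query}}],
        "sources": [
            {
                "type": "INLINE",
                "inlineDocumentSource": {"type": "TEXT", "textDocument": {"text": doc}},
            }
            for doc in documents
        ],
        "rerankingConfiguration": {
            "type": "BEDROCK_RERANKING_MODEL",
            "bedrockRerankingConfiguration": {
                # Never ask for more results than there are documents
                "numberOfResults": min(top_n, len(documents)),
                "modelConfiguration": model_config,
            },
        },
    }
```

With AWS credentials configured, the result could be passed as `boto3.client("bedrock-agent-runtime").rerank(**request)`.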

![Reranking](./reranking.png)

## 1: Import and Load Variables

In [None]:
import json

# Load the configuration variables from a JSON file
with open("../Lab 1/variables.json", "r") as f:
    variables = json.load(f)

variables

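The later cells read `regionName`, `accountNumber`, and `kbFixedChunk` from this file, so it can help to check for them up front. The following is a minimal sketch; `find_missing_keys` is a hypothetical helper, and the example dictionary is made up rather than taken from a real `variables.json`.

```python
REQUIRED_KEYS = ("regionName", "accountNumber", "kbFixedChunk")


def find_missing_keys(config, required=REQUIRED_KEYS):
    """Return the required keys that are absent or empty in the config dict."""
    return [key for key in required if not config.get(key)]


# Made-up example standing in for the loaded `variables` dictionary
example_config = {"regionName": "us-west-2", "accountNumber": "123456789012"}

missing = find_missing_keys(example_config)
if missing:
    print(f"variables.json is missing: {', '.join(missing)}")
```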

## 2: Define ARN and Configuration Details

In [None]:
# Set up configuration for Bedrock from the loaded variables
regionName = variables['regionName']
accountNumber = variables['accountNumber']
knowledge_base_id = variables['kbFixedChunk']
model_id = 'us.amazon.nova-pro-v1:0'

# Define ARNs (Amazon Resource Names) for the generation and reranking models.
# Both use regionName so they match the Region of the clients created below.
model_arn = f"arn:aws:bedrock:{regionName}:{accountNumber}:inference-profile/{model_id}"
rerank_model_arn = f"arn:aws:bedrock:{regionName}::foundation-model/cohere.rerank-v3-5:0"


## 3: Set Up Bedrock Client

In [None]:
import boto3

# Configure the Bedrock Agent Runtime client (provides the Retrieve and Rerank operations)
bedrock_agent_runtime = boto3.client('bedrock-agent-runtime', region_name=regionName)


## 4: Define Function for Reranking

In [None]:
import boto3
import json

def search_kb_with_optional_rerank(query, kb_id, model_arn=None, use_reranking=False):
    """Search KB and optionally rerank results"""
    client = boto3.client("bedrock-agent-runtime", region_name=regionName)
    
    # 1. Retrieve from knowledge base
    kb_response = client.retrieve(
        knowledgeBaseId=kb_id,
        retrievalQuery={"text": query},
        retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 10}}
    )
    
    # Extract documents
    documents = []
    original_results = []
    
    for i, result in enumerate(kb_response.get("retrievalResults", [])):
        # Each retrieval result carries its chunk text under content.text
        text = result.get("content", {}).get("text", "") or ""

        # Store a preview of the original result; the Retrieve API reports
        # the vector-search relevance under the "score" key
        original_results.append({
            "position": i + 1,
            "score": result.get("score", 0),
            "text": text[:300] + "..." if len(text) > 300 else text
        })
        documents.append(text)
    
    # 2. Rerank if enabled
    if use_reranking and model_arn and documents:
        reranked = client.rerank(
            queries=[{"textQuery": {"text": query}, "type": "TEXT"}],
            rerankingConfiguration={
                "bedrockRerankingConfiguration": {
                    "modelConfiguration": {"modelArn": model_arn},
                    "numberOfResults": min(10, len(documents))
                },
                "type": "BEDROCK_RERANKING_MODEL"
            },
            sources=[{
                "inlineDocumentSource": {"textDocument": {"text": doc}, "type": "TEXT"},
                "type": "INLINE"
            } for doc in documents]
        )
        
        # Process reranked results; "index" points back into the `sources` list
        reranked_results = []
        for new_position, result in enumerate(reranked.get("results", []), start=1):
            idx = result.get("index", 0)
            text = documents[idx]
            reranked_results.append({
                "original_position": idx + 1,
                "new_position": new_position,
                "relevance_score": result.get("relevanceScore", 0),  # Full precision score
                "text": text[:300] + "..." if len(text) > 300 else text
            })
        return {"original_results": original_results, "reranked_results": reranked_results}
        
    return {"original_results": original_results}
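
The reranking step works because each entry in the response carries an `index` into the list of submitted documents. The sketch below isolates that mapping so it can be checked without calling AWS; `apply_rerank_order` is a hypothetical helper, and the documents and scores are invented for illustration.

```python
def apply_rerank_order(documents, rerank_results):
    """Map rerank results (index + relevanceScore) back onto the source documents."""
    ordered = []
    for new_position, item in enumerate(rerank_results, start=1):
        idx = item["index"]
        ordered.append({
            "original_position": idx + 1,   # 1-based position from vector search
            "new_position": new_position,   # 1-based position after reranking
            "relevance_score": item["relevanceScore"],
            "text": documents[idx],
        })
    return ordered


# Invented data shaped like the "results" list of a rerank response
docs = ["shipping guidelines", "return policy", "gift cards"]
mock_results = [
    {"index": 1, "relevanceScore": 0.92},
    {"index": 0, "relevanceScore": 0.31},
]

for row in apply_rerank_order(docs, mock_results):
    print(row["original_position"], "->", row["new_position"], row["text"])
```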

In [None]:
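def retrieve_and_generate(query, kb_id, model_arn, use_reranking=True):
    """Full RAG pipeline for Amazon Nova Pro: retrieve, optionally rerank, then generate.

    Note: `model_arn` is the ARN of the reranker model passed to the search step;
    the generation model comes from the module-level `model_id`.
    """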
    import json
    import boto3
    from botocore.exceptions import ClientError

    # 1. Search and get documents
    results = search_kb_with_optional_rerank(query, kb_id, model_arn, use_reranking)
    
    # 2. Prepare context
    if use_reranking and "reranked_results" in results:
        docs = [doc["text"] for doc in results["reranked_results"]]
        source_type = "reranked"
    else:
        docs = [doc["text"] for doc in results["original_results"]]
        source_type = "vector search"
    
    # Documents were already truncated to 300 characters in the search step
    context = "\n".join([f"Document {i+1}: {doc}" for i, doc in enumerate(docs[:3])])
    prompt_data = f"Query: {query}\n\nContext from {source_type}:\n{context}\n\nAnswer:"

    # 3. Format request body for Nova Pro Model
    body = json.dumps(
        {
            "messages": [{"role": "user", "content": [{"text": prompt_data}]}],
            "inferenceConfig": {
                "max_new_tokens": 2000,
                "top_p": 0.9,
                "top_k": 20,
                "temperature": 0.1
            }
        }
    )
    
    # 4. Set up model invocation parameters (model_id is defined in section 2)
    modelId = model_id
    accept = "application/json"
    contentType = "application/json"

    client = boto3.client("bedrock-runtime", region_name=regionName)

    try:
        # Invoke the Nova Pro Model
        response = client.invoke_model(
            body=body,
            modelId=modelId,
            accept=accept,
            contentType=contentType
        )
        
        # Parse the response body and return it whole; the generated text
        # is under ['output']['message']['content'][0]['text']
        response_body = json.loads(response.get("body").read())
        return response_body

    except ClientError as error:
        if error.response['Error']['Code'] == 'AccessDeniedException':
            print(
                f"\x1b[41m{error.response['Error']['Message']}\n"
                "To troubleshoot this issue, please refer to the following resources:\n"
                "https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_access-denied.html\n"
                "https://docs.aws.amazon.com/bedrock/latest/userguide/security-iam.html\x1b[0m\n"
            )
        else:
            raise error

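The prompt assembly in step 2 can be pulled out into small pure functions, which makes the truncation behavior easy to test. This is a sketch under the same limits used above (top 3 documents, 300 characters each); `build_context` and `build_prompt` are hypothetical names, not part of the notebook.

```python
def build_context(docs, max_docs=3, max_chars=300):
    """Number the top documents and truncate long ones with an ellipsis."""
    lines = []
    for i, doc in enumerate(docs[:max_docs], start=1):
        snippet = doc if len(doc) <= max_chars else doc[:max_chars] + "..."
        lines.append(f"Document {i}: {snippet}")
    return "\n".join(lines)


def build_prompt(query, docs, source_type):
    """Assemble the query and retrieved context into a single prompt string."""
    context = build_context(docs)
    return f"Query: {query}\n\nContext from {source_type}:\n{context}\n\nAnswer:"
```

Short documents pass through unchanged, so the ellipsis appears only where text was actually cut.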



In [None]:
def retrieve_and_generate_claude(query, kb_id, model_arn, use_reranking=True):
    """Full RAG pipeline with Claude-specific request formatting.

    `model_arn` is the ARN of the reranker model; the generation model comes
    from the module-level `model_id`.
    """
    import json
    import boto3
    from botocore.exceptions import ClientError

    # 1. Search and get documents
    results = search_kb_with_optional_rerank(query, kb_id, model_arn, use_reranking)
    
    # 2. Prepare context
    if use_reranking and "reranked_results" in results:
        docs = [doc["text"] for doc in results["reranked_results"]]
        source_type = "reranked"
    else:
        docs = [doc["text"] for doc in results["original_results"]]
        source_type = "vector search"
    
    # Limit to top 3 documents
    docs = docs[:3]
    
    # 3. Format request body based on model type
    client = boto3.client("bedrock-runtime", region_name=regionName)
    
    if "claude" in model_id.lower():
        # Claude-specific formatting
        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 2000,
            "top_p": 0.9,
            "top_k": 20,
            "temperature": 0.1,
            "messages": [
                {
                    "role": "user", 
                    "content": [
                        {
                            "type": "text", 
                            "text": f"Answer the following question based on the provided context.\n\nQuestion: {query}\n\nContext:\n{' '.join(docs)}"
                        }
                    ]
                }
            ]
        })
    else:
        # Fallback for models that accept a query/documents request body;
        # the documents are passed as plain strings
        body = json.dumps({
            "query": query,
            "documents": list(docs),
            "api_version": 1  # Integer instead of string
        })

    try:
        # Invoke the model
        response = client.invoke_model(
            body=body.encode("utf-8"),  # Ensure body is encoded as bytes
            modelId=model_id,
            accept="application/json",
            contentType="application/json"
        )
        
        # Parse the response body
        response_body = json.loads(response.get("body").read())
        
        return response_body

    except ClientError as error:
        if error.response['Error']['Code'] == 'AccessDeniedException':
            print(
                f"\x1b[41m{error.response['Error']['Message']}\n\n"
                "To troubleshoot this issue, please refer to the following resources:\n"
                "https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_access-denied.html\n"
                "https://docs.aws.amazon.com/bedrock/latest/userguide/security-iam.html\x1b[0m\n"
            )
        else:
            raise error

## 5: Define Function for Retrieve and Generate

## 6: Compare the Retrieved results WITH & WITHOUT Reranking

In [None]:
# Example usage

query = "What was the margin percentage change in Amazon's international segment between 2022 and 2023?"
# Without reranking
print("WITHOUT RERANKING:")
results_no_rerank = search_kb_with_optional_rerank(
    query, knowledge_base_id, rerank_model_arn, use_reranking=False
)

# Display original results (the search retrieves at most 10 documents)
print("\nTOP DOCUMENTS WITHOUT RERANKING:")
for doc in results_no_rerank["original_results"][:10]:
    print(f"Position {doc['position']} (Score: {doc['score']}):")
    print(f"{doc['text']}\n")


In [None]:
# With reranking
print("\nWITH RERANKING:")
results_with_rerank = search_kb_with_optional_rerank(
    query, knowledge_base_id, rerank_model_arn, use_reranking=True
)

# Show reranked results with full precision scores
print("\nTOP DOCUMENTS AFTER RERANKING:")
for doc in results_with_rerank["reranked_results"][:10]:
    print(f"Moved from position {doc['original_position']} to {doc['new_position']}")
    print(f"Relevance score: {doc['relevance_score']}")  # Full precision
    print(f"{doc['text']}\n")

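To judge at a glance how much the reranker changed the ordering, the position shifts can be tallied. The sketch below assumes the `reranked_results` entries have the `original_position` and `new_position` keys produced in section 4; `summarize_movement` is a hypothetical helper and the sample rows are invented.

```python
def summarize_movement(reranked_results):
    """Count how many documents moved up, moved down, or kept their position."""
    summary = {"promoted": 0, "demoted": 0, "unchanged": 0}
    for row in reranked_results:
        if row["new_position"] < row["original_position"]:
            summary["promoted"] += 1
        elif row["new_position"] > row["original_position"]:
            summary["demoted"] += 1
        else:
            summary["unchanged"] += 1
    return summary


# Invented rows shaped like results_with_rerank["reranked_results"]
sample_rows = [
    {"original_position": 3, "new_position": 1},
    {"original_position": 1, "new_position": 2},
    {"original_position": 2, "new_position": 3},
    {"original_position": 4, "new_position": 4},
]
print(summarize_movement(sample_rows))
```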

## 7: Compare the Generated results WITH & WITHOUT Reranking

In [None]:
print("\nGENERATED ANSWER WITHOUT RERANKING:")
answer_no_rerank = retrieve_and_generate(query, knowledge_base_id, rerank_model_arn, use_reranking=False)
# Nova responses nest the text under 'output'; Claude responses use ['content'][0]['text'] instead
print(answer_no_rerank['output']['message']['content'][0]['text'])


In [None]:
print("\nGENERATED ANSWER WITH RERANKING:")
answer_with_rerank = retrieve_and_generate(query, knowledge_base_id, rerank_model_arn, use_reranking=True)
# Nova responses nest the text under 'output'; Claude responses use ['content'][0]['text'] instead
print(answer_with_rerank['output']['message']['content'][0]['text'])

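The two generation functions return differently shaped response bodies, and both return `None` when access is denied. A small accessor can hide that difference when printing answers. This is a sketch, not part of the notebook: `extract_text` is a hypothetical helper, and the sample responses are invented and trimmed to the fields used here.

```python
def extract_text(response_body):
    """Return the generated text from a Nova-style or Claude-style response body."""
    if not response_body:
        # The generation functions return None after an AccessDeniedException
        return ""
    if "output" in response_body:
        # Nova shape: output -> message -> content -> [{"text": ...}]
        blocks = response_body["output"]["message"]["content"]
    else:
        # Claude shape: content -> [{"type": "text", "text": ...}]
        blocks = response_body.get("content", [])
    return "".join(block.get("text", "") for block in blocks)


# Invented, trimmed-down sample responses
nova_style = {"output": {"message": {"role": "assistant", "content": [{"text": "Nova answer"}]}}}
claude_style = {"content": [{"type": "text", "text": "Claude answer"}]}

print(extract_text(nova_style))
print(extract_text(claude_style))
```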