# Bedrock Knowledge Base Retrieval and Generation with SageMaker Inference and Reranking

## 1: Import and Load Variables

In [None]:
import json

# Load the configuration variables from the JSON file kept in the Lab 1 folder.
# The encoding is pinned to UTF-8 so parsing does not depend on the
# platform's default locale encoding.
with open("../Lab 1/variables.json", "r", encoding="utf-8") as f:
    variables = json.load(f)

# Bare expression: echoes the loaded dictionary as the notebook cell output
variables


## 2: Define ARN and Configuration Details

In [None]:
# Knowledge Base Selection
kb_id = variables["kbFixedChunk"]  # Options: "kbFixedChunk", "kbHierarchicalChunk", "kbSemanticChunk"

# Retrieval-Augmented Generation (RAG) Configuration
number_of_results = 3  # Number of relevant documents to retrieve
generation_configuration = {
    "temperature": 0,  # Lower temperature for more deterministic responses
    "top_k": 10,  # Consider top 10 tokens at each generation step
    "max_new_tokens": 5000,  # Maximum number of tokens to generate
    # NOTE(review): some serving containers expect `stop` as a list of strings;
    # confirm against the container behind the SageMaker endpoint.
    "stop": "<|eot_id|>"  # Stop sequence to end the response generation
}

# Define ARNs (Amazon Resource Names) for the model.
# The region is taken from the lab configuration so the ARN always matches the
# region in which the Bedrock clients are created (previously hard-coded to us-west-2).
rerank_model_arn = f"arn:aws:bedrock:{variables['regionName']}::foundation-model/cohere.rerank-v3-5:0"

# User Query
query = "what was the % increase in sales?"  # Sample query to retrieve data from the knowledge base


## 3: Set Up Bedrock Client

In [None]:
import boto3
import json
from typing import *

# Configure the Bedrock Agent Runtime client in the region recorded in the lab
# configuration, matching every other client created in this notebook
bedrock_agent_runtime = boto3.client("bedrock-agent-runtime", region_name=variables["regionName"])


## 4: Define Function for Reranking

In [None]:
import boto3

# Both Bedrock clients are created in the region recorded in the lab configuration
_region_name = variables["regionName"]
bedrock_client = boto3.client(service_name="bedrock-runtime", region_name=_region_name)
bedrock_agent_runtime = boto3.client(service_name="bedrock-agent-runtime", region_name=_region_name)

# Upper-case aliases for the knowledge base ID, the number of results to
# retrieve, and the SageMaker endpoint name
KNOWLEDGE_BASE_ID, NUM_RESULTS = kb_id, number_of_results
ENDPOINT_NAME = variables["sagemakerLLMEndpoint"]

def _extract_result_text(result):
    """Return the text of one retrieval result, or '' when it carries no text."""
    content = result.get("content") or {}
    content_text = content.get("text")
    if content_text is None:
        return ""
    if isinstance(content_text, list):
        # Defensive: tolerate a list of spans/strings instead of a single string
        return " ".join(
            item.get("span", "") if isinstance(item, dict) else str(item)
            for item in content_text
        )
    return str(content_text)


def _preview(text, limit=300):
    """Return `text` shortened to `limit` characters (with '...') for display."""
    return text[:limit] + "..." if len(text) > limit else text


def search_knowledge_base(query, region_name, kb_id, num_results=5, use_reranking=False, model_arn=None):
    """
    Search a Bedrock Knowledge Base and optionally rerank the retrieved documents.

    Args:
        query: Query text sent to the knowledge base (and to the reranker).
        region_name: AWS region used to create the bedrock-agent-runtime client.
        kb_id: Identifier of the knowledge base to search.
        num_results: Number of results requested from the vector search; also
            the upper bound on the number of reranked results.
        use_reranking: When True (and `model_arn` is given), rerank the
            retrieved documents before returning them.
        model_arn: ARN of the reranking model.

    Returns:
        A list of document texts. The order is the reranked order when
        reranking succeeded, otherwise the retrieval order. An empty list is
        returned when the retrieval step fails.

    Side effects:
        Prints a preview of the top 3 documents before and (when applicable)
        after reranking, and prints a message when a step fails.
    """
    client = boto3.client("bedrock-agent-runtime", region_name=region_name)

    # 1. Retrieve from knowledge base
    try:
        kb_response = client.retrieve(
            knowledgeBaseId=kb_id,
            retrievalQuery={"text": query},
            retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": num_results}}
        )

        # Extract documents and metadata
        documents = []
        original_results = []

        for position, result in enumerate(kb_response.get("retrievalResults", []), start=1):
            text = _extract_result_text(result)

            original_results.append({
                "position": position,
                # The Retrieve API reports relevance under "score"; the previous
                # "scoreValue" key does not exist, so the score always printed as 0.
                "score": result.get("score", 0),
                "text": _preview(text)
            })
            documents.append(text)

        # Display original results
        print("\nTOP 3 DOCUMENTS WITHOUT RERANKING:")
        for doc in original_results[:3]:
            print(f"Position {doc['position']} (Score: {doc['score']}):")
            print(f"{doc['text']}\n")

    except Exception as e:
        # Best-effort notebook helper: report the failure and return no context
        print(f"Search failed: {e}")
        return []

    # 2. Rerank if enabled
    if use_reranking and model_arn and documents:
        try:
            reranked = client.rerank(
                queries=[{"textQuery": {"text": query}, "type": "TEXT"}],
                rerankingConfiguration={
                    "bedrockRerankingConfiguration": {
                        "modelConfiguration": {"modelArn": model_arn},
                        # Never ask for more results than there are sources
                        "numberOfResults": min(num_results, len(documents))
                    },
                    "type": "BEDROCK_RERANKING_MODEL"
                },
                sources=[{
                    "inlineDocumentSource": {"textDocument": {"text": doc}, "type": "TEXT"},
                    "type": "INLINE"
                } for doc in documents]
            )

            # Process reranked results
            reranked_results = []
            reranked_documents = []

            for new_position, result in enumerate(reranked.get("results", []), start=1):
                # "index" refers back to the position of the document in `sources`
                idx = result.get("index", 0)
                if 0 <= idx < len(documents):
                    reranked_results.append({
                        "original_position": idx + 1,
                        "new_position": new_position,
                        "relevance_score": result.get("relevanceScore", 0),
                        "text": _preview(documents[idx])
                    })
                    reranked_documents.append(documents[idx])

            # Display reranked results
            print("\nTOP 3 DOCUMENTS AFTER RERANKING:")
            for doc in reranked_results[:3]:
                print(f"Moved from position {doc['original_position']} to {doc['new_position']}")
                print(f"Relevance score: {doc['relevance_score']}")
                print(f"{doc['text']}\n")

            return reranked_documents

        except Exception as e:
            # Fall through to the un-reranked documents below
            print(f"Reranking failed: {e}")
            print("Using original search results instead")

    # Return document texts for format_prompt
    return documents

## 5: Define SageMaker Helper Functions (Prompt Formatting & Generation)

In [None]:
# Function to format the prompt for Llama 3 model using retrieved context
def format_prompt(query, context):
    """
    Build a Llama 3 chat prompt from the user's question and retrieved context.

    Args:
        query: The user's question.
        context: Iterable of document texts to ground the answer. A single
            string is treated as one document; None or empty yields an empty
            context section.

    Returns:
        The prompt string: a system turn (instructions + context), a user turn
        (the question), and an open assistant header for the model to complete.
    """
    # A bare string would otherwise be joined character by character
    if isinstance(context, str):
        context = [context]
    context_text = " ".join(context or [])

    # System prompt that carries the instructions and the retrieved context
    system_prompt = (
        "Use the following context to answer the question. "
        "If you don't know the answer, say 'I don't know'.\n"
        "Context:\n"
        f"{context_text}"
    )

    # Each completed turn is terminated with <|eot_id|>; the assistant header is
    # left open so the model generates the answer. The prompt is assembled from
    # explicit pieces so no source-code indentation leaks into it.
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"Question: {query}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

# Function to generate a response from the SageMaker endpoint based on the formatted prompt
def generate_response(prompt):
    """
    Generate a response for `prompt` using the SageMaker endpoint.

    Args:
        prompt: The fully formatted prompt string to send to the model.

    Returns:
        The generated text extracted from the endpoint response.

    Raises:
        RuntimeError: If invoking the endpoint or decoding its response fails
            ("Generation failed: ..."), or if the decoded response has none of
            the recognised shapes ("Unexpected response format").
    """
    # Create the client in the region recorded in the lab configuration, as the
    # Bedrock clients in this notebook do
    runtime = boto3.client("sagemaker-runtime", region_name=variables["regionName"])

    # Prepare the payload with prompt and generation parameters
    payload = {
        "inputs": prompt,  # The formatted prompt to pass to the model
        "parameters": generation_configuration  # Generation parameters (e.g., temperature, tokens)
    }

    # Only the remote call and the decoding are guarded, so the format error
    # raised below is not caught and re-wrapped by this handler.
    try:
        response = runtime.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,  # SageMaker endpoint name
            ContentType="application/json",  # Content type for the request
            Body=json.dumps(payload)  # Send the payload as JSON
        )
        result = json.loads(response["Body"].read().decode("utf-8"))
    except Exception as e:
        # Chain the cause so the original traceback stays visible
        raise RuntimeError(f"Generation failed: {str(e)}") from e

    # Handle different response formats: a list whose first element holds the
    # generation, or a dictionary holding it directly
    if isinstance(result, list):
        result = result[0] if result else None
    if isinstance(result, dict):
        for key in ("generated_text", "generation"):
            if key in result:
                return result[key]

    raise RuntimeError("Unexpected response format")

## 6: Compare the Retrieved Results With & Without Reranking

In [None]:
query = "Compare the results between 2022 and 2023"

print("WITHOUT RERANKING:")
context_without_reranking = search_knowledge_base(
    query=query,
    region_name=variables["regionName"],
    kb_id=kb_id,  # Use the knowledge base selected in the configuration cell
    num_results=number_of_results,
    use_reranking=False
)

In [None]:
print("\nWITH RERANKING:")
context_with_reranking = search_knowledge_base(
    query=query,
    region_name=variables["regionName"],
    kb_id=kb_id,  # Use the knowledge base selected in the configuration cell
    num_results=number_of_results,
    use_reranking=True,
    model_arn=rerank_model_arn
)

## 7: Compare the Generated Results With & Without Reranking

In [None]:
print("WITHOUT RERANKING:")

# Build the prompt from the query and the un-reranked context, then send it
# to the endpoint
prompt_without_reranking = format_prompt(query=query, context=context_without_reranking)
response_without_reranking = generate_response(prompt=prompt_without_reranking)

# Show the question followed by the generated answer
for label, value in (("Question:", query), ("Answer:", response_without_reranking)):
    print(label, value)

In [None]:
print("WITH RERANKING:")

# Build the prompt from the query and the reranked context, then send it to
# the endpoint
prompt_with_reranking = format_prompt(query=query, context=context_with_reranking)
response_with_reranking = generate_response(prompt=prompt_with_reranking)

# Show the question followed by the generated answer
for label, value in (("Question:", query), ("Answer:", response_with_reranking)):
    print(label, value)