# Retrieval and Generation with SageMaker Endpoint LLM

## Overview  
This notebook demonstrates retrieval-augmented generation (RAG) with a large language model (LLM) hosted on a SageMaker endpoint. We retrieve relevant passages from an Amazon Bedrock Knowledge Base and pass them to the LLM as context for its answer, first through LangChain and then with plain boto3 calls.

## Key Steps:  
- Configure and query a knowledge base for relevant documents.  
- Use a SageMaker-hosted LLM to generate contextual responses.  
- Configure retrieval and generation parameters (number of retrieved documents, temperature, token limits) to control answer quality.

By the end of this notebook, you'll understand how to integrate SageMaker-hosted models into a RAG pipeline to enhance answer generation with domain-specific knowledge.  
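The retrieve, augment, generate loop can be illustrated without any AWS resources. The following is a toy sketch, not how a Bedrock Knowledge Base works internally: the passages in `CORPUS` are made up, word overlap stands in for embedding similarity, and the hypothetical helpers `retrieve` and `build_prompt` stop at the prompt that would be sent to the LLM.

```python
import re

# Made-up passages standing in for a knowledge base (illustration only)
CORPUS = [
    "Knowledge base question answering maps a natural-language question to facts in a knowledge base.",
    "Amazon SageMaker endpoints host models for real-time inference.",
    "Chunking splits documents into passages before they are embedded.",
]
QUERY = "What are three sub-tasks in question answering over knowledge bases?"

def tokenize(text: str) -> set[str]:
    """Lowercased word set; a crude stand-in for an embedding model."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def retrieve(query: str, corpus: list[str], k: int = 2) -> list[str]:
    """Rank passages by word overlap with the query (stand-in for vector search)."""
    query_words = tokenize(query)
    ranked = sorted(corpus, key=lambda p: len(query_words & tokenize(p)), reverse=True)
    return ranked[:k]

def build_prompt(query: str, passages: list[str]) -> str:
    """Augment the question with the retrieved passages."""
    context = "\n".join(f"- {p}" for p in passages)
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"

print(build_prompt(QUERY, retrieve(QUERY, CORPUS, k=2)))
```

In the real pipeline below, the Knowledge Base performs the retrieval step and the SageMaker-hosted LLM consumes the assembled prompt.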

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
!pip install -U sagemaker boto3 "langchain<1.0" langchain-aws 2>/dev/null

In [None]:
import json
with open("variables.json", "r") as f:
    variables = json.load(f)

variables

In [None]:
# LLM Configuration  
model_id = "meta-textgeneration-llama-3-8b-instruct"  # JumpStart model ID; the stop token and prompt template in this notebook assume Llama 3
instance_type = "ml.g5.4xlarge"  # Define the SageMaker instance type for model inference

# Knowledge Base Selection  
kb_id = variables["kbFixedChunk"]  # Options: "kbFixedChunk", "kbHierarchicalChunk", "kbSemanticChunk"

# Retrieval-Augmented Generation (RAG) Configuration  
number_of_results = 3  # Number of relevant documents to retrieve  
generation_configuration = {
    "temperature": 0,  # Lower temperature for more deterministic responses  
    "top_k": 10,  # Consider top 10 tokens at each generation step  
    "max_new_tokens": 5000,  # Maximum number of tokens to generate  
    "stop": "<|eot_id|>"  # Stop sequence to end the response generation  
}

# User Query  
query = "What are three sub-tasks in question answering over knowledge bases?"
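To make the `temperature`, `top_k`, and `stop` settings concrete, here is a simplified, self-contained sketch of a single decoding step over a toy `{token: logit}` table. It is a conceptual model only: the serving container's decoder is more involved, and exactly how it treats `temperature: 0` depends on the container and its version. The helpers `sample_next_token` and `apply_stop` are hypothetical.

```python
import math
import random

def sample_next_token(logits: dict, temperature: float, top_k: int, rng: random.Random) -> str:
    """Choose the next token from a {token: logit} table with top-k and temperature."""
    # Keep only the top_k highest-scoring candidates
    candidates = sorted(logits.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    if temperature == 0:
        # Greedy decoding: always return the single highest-scoring token
        return candidates[0][0]
    # Softmax weights over temperature-scaled logits (max subtracted for numerical stability)
    scaled = [score / temperature for _, score in candidates]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    # random.choices normalizes the weights itself
    return rng.choices([token for token, _ in candidates], weights=weights, k=1)[0]

def apply_stop(text: str, stop: str) -> str:
    """Truncate generated text at the first occurrence of the stop sequence."""
    index = text.find(stop)
    return text if index == -1 else text[:index]

toy_logits = {"Paris": 3.0, "Lyon": 1.5, "Nice": 0.5, "Rome": -1.0}
rng = random.Random(0)
print(sample_next_token(toy_logits, temperature=0, top_k=10, rng=rng))  # → Paris
print(apply_stop("The capital is Paris.<|eot_id|>extra", "<|eot_id|>"))  # → The capital is Paris.
```

In this sketch, `top_k` has no effect when `temperature` is 0, because only the single highest-scoring token is ever chosen.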

> **Note**: The model deployment process for the SageMaker endpoint will take approximately 10-15 minutes to complete. During this time, the system is:
> 1. Provisioning the required compute resources (GPU instances)
> 2. Downloading and installing the model artifacts
> 3. Configuring the inference environment
> 4. Running health checks until the endpoint status becomes `InService`
>
> No further action is needed during this time. The cell will continue to execute until the endpoint is fully deployed and ready for inference. This is a one-time setup that will be used throughout the workshop.
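If the deployment cell is interrupted, or you want to check the endpoint from another session, its status can be polled through the SageMaker `DescribeEndpoint` API. The sketch below takes the describe function as an argument (for example `boto3.client("sagemaker").describe_endpoint`) so it can be exercised offline; `wait_until_in_service` and its default intervals are illustrative assumptions. boto3 also ships a built-in waiter, `boto3.client("sagemaker").get_waiter("endpoint_in_service")`, that serves the same purpose.

```python
import time
from typing import Callable

def wait_until_in_service(
    describe: Callable[..., dict],
    endpoint_name: str,
    poll_seconds: float = 30.0,
    timeout_seconds: float = 1800.0,
) -> str:
    """Poll an endpoint until it is InService; raise if it fails or the timeout passes."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        description = describe(EndpointName=endpoint_name)
        status = description["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            reason = description.get("FailureReason", "no reason given")
            raise RuntimeError(f"Endpoint {endpoint_name} failed: {reason}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} is still {status} after {timeout_seconds} s")
        time.sleep(poll_seconds)

# Usage against the real endpoint (requires AWS credentials):
# import boto3
# wait_until_in_service(boto3.client("sagemaker").describe_endpoint, "llm-inference-advanced-rag")
```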

In [None]:
from sagemaker.jumpstart.model import JumpStartModel

# Load the JumpStart model with the specified model ID and instance type  
llm_model = JumpStartModel(model_id=model_id, instance_type=instance_type)

# Deploy the model as a SageMaker endpoint  
llm_endpoint = llm_model.deploy(
    accept_eula=True,  # Accept the model's End User License Agreement (EULA)  
    initial_instance_count=1,  # Number of instances for hosting the model  
    endpoint_name="llm-inference-advanced-rag"  # Custom name for the deployed endpoint  
)

In [None]:
# Save the SageMaker endpoint name to the variables JSON file  
with open("variables.json", "w") as f:
    json.dump({**variables, "sagemakerLLMEndpoint": llm_endpoint.endpoint_name}, f)

# Print or return the deployed SageMaker endpoint name  
llm_endpoint.endpoint_name
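Because the endpoint name is persisted in `variables.json`, a later notebook can reuse it instead of redeploying. A minimal sketch, assuming the file layout written above; `load_endpoint_name` is a hypothetical helper. The commented lines show one way to attach to the running endpoint with the `Predictor` class of the SageMaker Python SDK (v2).

```python
import json
from pathlib import Path

def load_endpoint_name(path: str = "variables.json", key: str = "sagemakerLLMEndpoint") -> str:
    """Return the endpoint name stored in the variables file, with a clear error if absent."""
    saved = json.loads(Path(path).read_text())
    if key not in saved:
        raise KeyError(f"{key!r} not found in {path}; run the deployment cell first")
    return saved[key]

# In a later notebook:
# from sagemaker.predictor import Predictor
# predictor = Predictor(endpoint_name=load_endpoint_name())
```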

## LangChain with AmazonKnowledgeBasesRetriever and SagemakerEndpoint

In [None]:
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler

# Define a custom content handler for SageMaker LLM endpoint
class ContentHandler(LLMContentHandler):
    # Specify content type for input and output
    content_type = "application/json"
    accepts = "application/json"

    # Method to transform user input into the format expected by SageMaker
    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})  # Format input as JSON
        return input_str.encode("utf-8")  # Encode to bytes

    # Method to process the output from SageMaker
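    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))  # Decode response JSON
        # Some serving containers wrap the result in a one-element list
        if isinstance(response_json, list):
            response_json = response_json[0]
        return response_json["generated_text"]  # Extract the generated text from response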
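The handler's request and response formats can be checked offline. The sketch below mirrors `transform_input` and `transform_output` as plain functions (the names `encode_request` and `decode_response` are hypothetical), with `io.BytesIO` standing in for the streaming body that boto3 returns. It assumes the endpoint responds with either `{"generated_text": ...}` or a one-element list of such objects, two of the shapes handled by the boto3 section of this notebook.

```python
import io
import json

def encode_request(prompt: str, parameters: dict) -> bytes:
    """Serialize the request body in the same {"inputs", "parameters"} shape as transform_input."""
    return json.dumps({"inputs": prompt, "parameters": parameters}).encode("utf-8")

def decode_response(body) -> str:
    """Parse a response body (any object with .read()), accepting a dict or a one-element list."""
    payload = json.loads(body.read().decode("utf-8"))
    if isinstance(payload, list):
        payload = payload[0]
    return payload["generated_text"]

# Offline round trip: BytesIO stands in for the botocore streaming body
request = encode_request("Hello", {"max_new_tokens": 16})
fake_body = io.BytesIO(json.dumps([{"generated_text": "Hi there"}]).encode("utf-8"))
print(json.loads(request)["parameters"])  # → {'max_new_tokens': 16}
print(decode_response(fake_body))  # → Hi there
```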

In [None]:
from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever

# Initialize the retriever to fetch relevant documents from the Amazon Knowledge Base
retriever = AmazonKnowledgeBasesRetriever(
    knowledge_base_id=kb_id,  # Specify the Knowledge Base ID to retrieve data from
    region_name=variables["regionName"],  # Define the AWS region where the Knowledge Base is located
    retrieval_config={
        "vectorSearchConfiguration": {
            "numberOfResults": number_of_results  # Set the number of relevant documents to retrieve
        }
    },
)
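The `retrieval_config` dictionary follows the Bedrock Knowledge Base `vectorSearchConfiguration` structure, which is also where retrieval can be tuned. The hypothetical helper below builds it and can optionally set `overrideSearchType` (`"HYBRID"` or `"SEMANTIC"`), a field of the Bedrock Retrieve API. Hybrid search is only available for some vector stores, and whether your installed `langchain-aws` version passes this field through is an assumption to verify.

```python
from typing import Optional

def make_retrieval_config(number_of_results: int, search_type: Optional[str] = None) -> dict:
    """Build a vectorSearchConfiguration dict for a Bedrock Knowledge Base query."""
    if number_of_results < 1:
        raise ValueError("number_of_results must be at least 1")
    vector_config = {"numberOfResults": number_of_results}
    if search_type is not None:
        if search_type not in ("HYBRID", "SEMANTIC"):
            raise ValueError("search_type must be 'HYBRID' or 'SEMANTIC'")
        vector_config["overrideSearchType"] = search_type  # not every vector store supports HYBRID
    return {"vectorSearchConfiguration": vector_config}

print(make_retrieval_config(3))  # → {'vectorSearchConfiguration': {'numberOfResults': 3}}
```

The same dictionary can be passed as `retrieval_config=` to `AmazonKnowledgeBasesRetriever` or as `retrievalConfiguration=` to the boto3 `retrieve` call.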


In [None]:
import boto3
import json
from langchain.chains import RetrievalQA
from langchain_aws.llms import SagemakerEndpoint

# Initialize content handler for processing model inputs/outputs
content_handler = ContentHandler()

# Define the query to be answered by the model
query = "What are three sub-tasks in question answering over knowledge bases?"

# Create a SageMaker runtime client to interact with the deployed model endpoint
sagemaker_runtime = boto3.client("sagemaker-runtime")

# Initialize the LLM with the SageMaker endpoint
llm = SagemakerEndpoint(
    endpoint_name=llm_endpoint.endpoint_name,  # Specify the SageMaker endpoint name
    client=sagemaker_runtime,  # Attach the SageMaker runtime client
    model_kwargs=generation_configuration,  # Pass the model configuration parameters
    content_handler=content_handler,  # Use the custom content handler for formatting
)

# Create a Retrieval-Augmented Generation (RAG) pipeline with a retriever and LLM
qa = RetrievalQA.from_chain_type(
    llm=llm,  # Use the initialized LLM for answering queries
    retriever=retriever,  # Use the retriever to fetch relevant documents
    return_source_documents=True  # Enable returning source documents along with the answer
)

# Execute the query using the RAG pipeline
answer = qa.invoke({"query": query})

# Print the question and the generated answer
print("Question:", query)
print("Answer:", answer["result"])
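Because `return_source_documents=True`, the result dictionary also contains `source_documents`, a list of LangChain `Document` objects. A short sketch for inspecting them; the helper `summarize_sources` is hypothetical, and the `score` metadata key is an assumption, since the metadata fields the retriever sets can vary between `langchain-aws` versions (print `doc.metadata` to check).

```python
def summarize_sources(docs, max_chars: int = 80) -> list:
    """Return one line per document: its score (if present) and a whitespace-collapsed preview."""
    lines = []
    for i, doc in enumerate(docs, start=1):
        score = doc.metadata.get("score")  # assumed key; inspect doc.metadata for your version
        score_text = f"{score:.3f}" if isinstance(score, (int, float)) else "n/a"
        preview = " ".join(doc.page_content.split())[:max_chars]
        lines.append(f"[{i}] score={score_text} {preview}")
    return lines

# for line in summarize_sources(answer["source_documents"]):
#     print(line)
```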


## RAG using boto3 and SageMaker invoke_endpoint

In [None]:
# Initialize Bedrock client to interact with the Bedrock Knowledge Base
bedrock_agent_runtime = boto3.client("bedrock-agent-runtime", region_name=variables["regionName"])

# Constants for Knowledge Base ID, SageMaker endpoint, and number of results to retrieve
KNOWLEDGE_BASE_ID = kb_id
ENDPOINT_NAME = llm_endpoint.endpoint_name
NUM_RESULTS = number_of_results

# Function to retrieve relevant context from the Bedrock Knowledge Base
def retrieve_from_bedrock(query):
    """Retrieve relevant context from Bedrock Knowledge Base"""
    try:
        # Retrieve context based on the query using vector search configuration
        response = bedrock_agent_runtime.retrieve(
            knowledgeBaseId=KNOWLEDGE_BASE_ID,
            retrievalQuery={
                'text': query  # The query text to search in the knowledge base
            },
            retrievalConfiguration={
                'vectorSearchConfiguration': {
                    'numberOfResults': NUM_RESULTS  # Adjust based on the number of results required
                }
            }
        )
        # Extract the 'text' from the retrieval results and return as a list
        return [result['content']['text'] for result in response['retrievalResults']]
    except Exception as e:
        # Raise an error if the retrieval process fails
        raise RuntimeError(f"Bedrock retrieval failed: {str(e)}")

# Function to format the prompt for Llama 3 model using retrieved context
def format_prompt(query, context):
    """Format a prompt with the Llama 3 Instruct chat template"""
    # System prompt: answering instructions followed by the retrieved passages
    system_prompt = (
        "Use the following context to answer the question. "
        "If you don't know the answer, say 'I don't know'.\n\n"
        "Context:\n" + "\n\n".join(context)
    )

    # Each turn is wrapped in header tokens and closed with <|eot_id|>;
    # the prompt ends with an open assistant header so the model writes the answer next
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"Question: {query}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

# Function to generate a response from the SageMaker endpoint based on the formatted prompt
def generate_response(prompt):
    """Generate response using SageMaker endpoint"""
    # Initialize SageMaker runtime client
    runtime = boto3.client('sagemaker-runtime')
    
    # Prepare the payload with prompt and generation parameters
    payload = {
        "inputs": prompt,  # The formatted prompt to pass to the model
        "parameters": generation_configuration  # Additional parameters for the model (e.g., temperature, tokens)
    }
    try:
        # Call the SageMaker endpoint to generate the response
        response = runtime.invoke_endpoint(
            EndpointName=ENDPOINT_NAME,  # SageMaker endpoint name
            ContentType='application/json',  # Content type for the request
            Body=json.dumps(payload)  # Send the payload as JSON
        )

        # Parse the response body
        result = json.loads(response['Body'].read().decode("utf-8"))
        
        # Handle different response formats (list or dictionary)
        if isinstance(result, list):
            # If the result is a list, extract the generated text from the first element
            return result[0]['generated_text']
        elif 'generated_text' in result:
            # If the result is a dictionary with 'generated_text', return the generated text
            return result['generated_text']
        elif 'generation' in result:
            # Alternative format with 'generation' key
            return result['generation']
        else:
            # Raise an error if the response format is unexpected
            raise RuntimeError("Unexpected response format")
            
    except Exception as e:
        # Raise an error if the generation process fails
        raise RuntimeError(f"Generation failed: {str(e)}")
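`retrieve_from_bedrock` keeps only the passage text, but each entry of `retrievalResults` also carries a relevance `score`. A sketch of putting it to use: keep passages at or above a threshold, highest score first, within a character budget that bounds the prompt length. The helper `select_passages`, the threshold, and the budget are illustrative assumptions; score scales depend on the vector store, so any threshold has to be tuned on your own data.

```python
def select_passages(retrieval_results: list, min_score: float = 0.0, max_chars: int = 6000) -> list:
    """Return passage texts scoring at least min_score, best first, within a character budget."""
    # Sort by relevance score, highest first (results without a score sort last)
    ranked = sorted(retrieval_results, key=lambda r: r.get("score", float("-inf")), reverse=True)
    selected, used = [], 0
    for result in ranked:
        if result.get("score", float("-inf")) < min_score:
            break  # every remaining result scores lower still
        text = result["content"]["text"]
        if used + len(text) > max_chars:
            break  # adding this passage would exceed the budget
        selected.append(text)
        used += len(text)
    return selected

# raw = bedrock_agent_runtime.retrieve(knowledgeBaseId=KNOWLEDGE_BASE_ID, retrievalQuery={"text": query})["retrievalResults"]
# context = select_passages(raw, min_score=0.4)
```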


In [None]:
# Retrieve relevant context from the Bedrock Knowledge Base based on the query
context = retrieve_from_bedrock(query)

# Format the prompt by combining the user's query and the retrieved context
prompt = format_prompt(query, context)

# Generate the response using the formatted prompt by calling the SageMaker endpoint
response = generate_response(prompt)

# Print the user's query
print("Question:", query)

# Uncomment the line below to inspect the retrieved context
# print(f"Context: {context}")

# Print the generated answer from the model based on the query and context
print("Answer:", response)
