# Retrieval and Generation with SageMaker Endpoint LLM

## Overview  
This notebook demonstrates how to perform retrieval-augmented generation (RAG) using a SageMaker-hosted large language model (LLM). We will retrieve relevant documents from a knowledge base and use the LLM to generate responses based on the retrieved information.  

## Key Steps:  
- Configure and query a knowledge base for relevant documents.  
- Use a SageMaker-hosted LLM to generate contextual responses.  
- Optimize retrieval and generation parameters for improved accuracy.  

By the end of this notebook, you'll understand how to integrate SageMaker-hosted models into a RAG pipeline to enhance answer generation with domain-specific knowledge.  

In [1]:
import warnings
warnings.filterwarnings("ignore")

!pip install -Uq sagemaker boto3 langchain-aws

Fetching existing resource information

In [2]:
import json
with open("variables.json", "r") as f:
    variables = json.load(f)

variables

{'accountNumber': '307297743176',
 'regionName': 'us-west-2',
 'collectionArn': 'arn:aws:aoss:us-west-2:307297743176:collection/h7cmj732p9d3v91spkhd',
 'collectionId': 'h7cmj732p9d3v91spkhd',
 'vectorIndexName': 'ws-index-',
 'bedrockExecutionRoleArn': 'arn:aws:iam::307297743176:role/advanced-rag-workshop-bedrock_execution_role-us-west-2',
 's3Bucket': '307297743176-us-west-2-advanced-rag-workshop',
 'kbFixedChunk': '4P6PBDDEGL',
 'kbSemanticChunk': 'IC3ZCBORXT',
 'kbCustomChunk': 'Q2T9CZ5VFA',
 'kbHierarchicalChunk': '1YIFVW0Z5E'}

In this example, you will use a model from [SageMaker Jumpstart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html). Amazon SageMaker JumpStart is a machine learning (ML) hub that can help you accelerate your ML journey. With SageMaker JumpStart, you can evaluate, compare, and select FMs quickly based on pre-defined quality and responsibility metrics to perform tasks like article summarization and image generation.

To load a model from SageMaker Jumpstart you need to specify a `model_id` and a `model_version`. The current list of models and versions can be found [here](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html).

The Llama 3.2 3B model has a `model_id` of `meta-textgeneration-llama-3-2-3b-instruct`. To always use the latest version of the model, you can set `model_version` to `*`, but pining to a specific version is recommended to ensure consistency.

Llama 3.2 3B was selected for this example because it is small, fast, and still supports a long context length (128k) to support larger retrievals if necessary for generation.

In [3]:
# LLM Configuration  
model_id, model_version = "meta-textgeneration-llama-3-8b-instruct", "2.11.2"
instance_type = "ml.g5.4xlarge"  # Define the SageMaker instance type for model inference

> **Note**: The model deployment process for the SageMaker endpoint will take approximately 8-10 minutes to complete. During this time, the system is:
> 1. Provisioning the required compute resources (GPU instances)
> 2. Downloading and installing the model artifacts
> 3. Configuring the inference environment
> 4. Setting up auto-scaling and monitoring for the endpoint
>
> No further action is needed during this time. The cell will continue to execute until the endpoint is fully deployed and ready for inference. This is a one-time setup that will be used throughout the workshop.

In [None]:
import time
import boto3
from sagemaker.jumpstart.model import JumpStartModel

# Initialize SageMaker client
sagemaker_client = boto3.client('sagemaker')

# Generate timestamp-based endpoint name
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
endpoint_name = f"endpoint-llama-3-2-3b-instruct-{timestamp}"

# First check for any existing endpoints
llm_endpoint_name = None
try:
    endpoints = sagemaker_client.list_endpoints()
    for endpoint in endpoints['Endpoints']:
        if 'llama-3-2-3b-instruct' in endpoint['EndpointName']:
            llm_endpoint_name = endpoint['EndpointName']
            print(f"Found existing endpoint: {llm_endpoint_name}")
            break
except Exception as e:
    print(f"Error checking for existing endpoints: {e}")

# If no existing endpoint found, try to deploy a new one
if not llm_endpoint_name:
    try:
        # Load the JumpStart model
        llm_model = JumpStartModel(model_id=model_id, instance_type=instance_type)
        
        # Deploy the model
        llm_endpoint = llm_model.deploy(
            accept_eula=True,
            initial_instance_count=1,
            endpoint_name=endpoint_name
        )
        llm_endpoint_name = llm_endpoint.endpoint_name
        print(f"Deployed new endpoint: {llm_endpoint_name}")
    except Exception as e:
        print(e)
        print("New endpoint cannot be created. Looking for any existing endpoints...")
        
        # Try again to find any existing endpoint if deployment failed
        try:
            endpoints = sagemaker_client.list_endpoints()
            for endpoint in endpoints['Endpoints']:
                if 'llama-3-2-3b-instruct' in endpoint['EndpointName']:
                    llm_endpoint_name = endpoint['EndpointName']
                    print(f"Using existing endpoint as fallback: {llm_endpoint_name}")
                    break
        except Exception:
            pass

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


Model 'meta-textgeneration-llama-3-8b-instruct' requires accepting end-user license agreement (EULA). See https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/fmhMetadata/eula/llama3Eula.txt for terms of use.


Using model 'meta-textgeneration-llama-3-8b-instruct' with wildcard version identifier '*'. You can pin to version '2.11.2' for more stable results. Note that models may have different input/output signatures after a major version upgrade.


---------!Deployed new endpoint: endpoint-llama-3-2-3b-instruct-2025-04-07-16-05-17


#### Check the progress of a SageMaker Endpoint deployment [here](https://console.aws.amazon.com/sagemaker/home#/endpoints). 

Store the SageMaker endpoint name for later use.

In [None]:
# Save the SageMaker endpoint name to the variables JSON file  
with open("variables.json", "w") as f:
    json.dump({**variables, "sagemakerLLMEndpoint": llm_endpoint_name}, f)

# Print or return the deployed SageMaker endpoint name
llm_endpoint_name

'endpoint-llama-3-2-3b-instruct-2025-04-07-16-05-17'

# Retrieval and Generation using Bedrock Knowledge Bases and SageMaker hosted models

With your endpoint successfully created, you can now use it as an output model in your RAG workflow. The following examples use the Amazon Bedrock Knowledge Bases that you created earlier for retrieval, combined with your SageMaker hosted model for generation. This hybrid approach results in a robust solution, combining the ease of use and managed aspects of Bedrock Knowledge Bases with the model flexibility and configuration controls of SageMaker hosting.

## RAG Orchestration with LangChain

To integrate LangChain with SageMaker endpoints, you first need to define a `ContentHandler`. Its purpose is to perform any transformations of the input/output data to match what the model expects and provide a processed output to client applications.

This content handler specifies the input/output content types as UTF-8 encoded `application/json` and pulls the `generated_text` parameter from the json response as the output.

In [17]:
from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler

# Define a custom content handler for SageMaker LLM endpoint
class ContentHandler(LLMContentHandler):
    # Specify content type for input and output
    content_type = "application/json"
    accepts = "application/json"

    # Method to transform user input into the format expected by SageMaker
    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})  # Format input as JSON
        return input_str.encode("utf-8")  # Encode to bytes

    # Method to process the output from SageMaker
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))  # Decode response JSON
        return response_json["generated_text"]  # Extract the generated text from response

With your `ContentHandler` defined, the next step is to setup your retriever. This retriever is responsible for fetching the results from your Bedrock Knowledge Base so it can be provided as contextual input for generation.

The `AmazonKnowledgeBasesRetriever` takes in a parameter of `knowledge_base_id` to select the appropriate knowledge base.  In this example the ids of `kbFixedChunk`, `kbHierarchicalChunk`, `kbSemanticChunk` refer to saved variables in your `variables.json` file that hold the actual knowledge base id.

It also takes a `retrieval_config`, which at this time consists of a `vectorSearchConfiguration` with `numberOfResults` as the only configurable parameter. The `numberOfResults` parameter controls the maximum number of search results from the knowledge base.

In [18]:
from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever

# Knowledge Base Selection
kb_id = variables["kbSemanticChunk"]  # Options: "kbFixedChunk", "kbHierarchicalChunk", "kbSemanticChunk"

# Retrieval-Augmented Generation (RAG) Configuration
number_of_results = 3  # Number of relevant documents to retrieve

# Initialize the retriever to fetch relevant documents from the Amazon Knowledge Base
retriever = AmazonKnowledgeBasesRetriever(
    knowledge_base_id=kb_id,  # Specify the Knowledge Base ID to retrieve data from
    region_name=variables["regionName"],  # Define the AWS region where the Knowledge Base is located
    retrieval_config={
        "vectorSearchConfiguration": {
            "numberOfResults": number_of_results  # Set the number of relevant documents to retrieve
        }
    },
)


Next, define a prompt template for your call to the output model. 

Since you are using a Llama-3 model in this example, it needs to follow the [correct prompt format](https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/).

This template uses the following roles:
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that help the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `assistant`: Represents the response generated by the AI model based on the context provided in the system and user prompts.

The fields `{context}` and `{question}` in the template will by dynamically injected as part of your RAG chain in a later step. These names are not hardcoded, but need to match what you specify when you build your chain.

In [19]:
prompt_template = """
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are an assistant for question-answering tasks. Answer the following question using the provided context. If you don't know the answer, just say "I don't know.".
<|start_header_id|>user<|end_header_id|>
Context: {context} 
Question: {question}
<|start_header_id|>assistant<|end_header_id|> 
Answer:
"""

Specify the parameters for generation.

`temperature` – Affects the shape of the probability distribution for the predicted output and influences the likelihood of the model selecting lower-probability outputs.
    - Choose a lower value to influence the model to select higher-probability outputs.
    - Choose a higher value to influence the model to select lower-probability outputs.
    - In technical terms, the temperature modulates the probability mass function for the next token. A lower temperature steepens the function and leads to more deterministic responses, and a higher temperature flattens the function and leads to more random responses.

`top_k` – The number of most-likely candidates that the model considers for the next token.
    - Choose a lower value to decrease the size of the pool and limit the options to more likely outputs.
    - Choose a higher value to increase the size of the pool and allow the model to consider less likely outputs.
    - For example, if you choose a value of 50 for Top K, the model selects from 50 of the most probable tokens that could be next in the sequence.

`top_p` – The percentage of most-likely candidates that the model considers for the next token.
    - Choose a lower value to decrease the size of the pool and limit the options to more likely outputs.
    - Choose a higher value to increase the size of the pool and allow the model to consider less likely outputs.
    - In technical terms, the model computes the cumulative probability distribution for the set of responses and considers only the top P% of the distribution. For example, if you choose a value of 0.8 for Top P, the model selects from the top 80% of the probability distribution of tokens that could be next in the sequence.

`max_new_tokens` - The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.

`stop` - Specify sequences of characters that stop the model from generating further tokens. If the model generates a stop sequence that you specify, it will stop generating after that sequence.

In [20]:
generation_configuration = {
    "temperature": 0,  # Lower temperature for more deterministic responses  
    "top_k": 10,  # Consider top 10 tokens at each generation step  
    "max_new_tokens": 512,  # Maximum number of tokens to generate  
    "stop": "<|eot_id|>"  # Stop sequence to end the response generation  
}

Here you will create your chain.

1. Initialize the `ContentHandler` from above
2. Create a `sagemaker-runtime` boto3 client for calling the endpoint
3. Initialize the `PromptTemplate` from above
4. Define a function to process the documents from the retriever. In this example, the document array is iterated through and the content is joined together using `\n\n` between them to break up the context.
5. Finally, define your chain. Here, you'll define your chain using LangChain's [LangChain Expression Language (LCEL)](https://python.langchain.com/docs/concepts/lcel/) to replace deprecated methods like [RetrievalQA](https://python.langchain.com/docs/versions/migrating_chains/retrieval_qa/). LCEL is designed to streamline the process of building useful apps with LLMs and combining related components.

Your `qa_chain` will fill pass the `question` parameter from the invocation of the chain, and the context parameter by invoking the retriever and processing the result with the `format_docs` function. From there, those outputs are piped to the prompt template to fill in the defined placeholders, then sent to the `llm` SageMaker endpoint for generation. Finally, the model output is sent to the `StrOutputParser` to convert into a usable string.

In [21]:
import boto3
from botocore.client import Config
from langchain_aws.llms import SagemakerEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough


# Initialize content handler for processing model inputs/outputs
content_handler = ContentHandler()

# Create a SageMaker runtime client to interact with the deployed model endpoint
sagemaker_runtime = boto3.client("sagemaker-runtime")

# Initialize the LLM with the SageMaker endpoint
llm = SagemakerEndpoint(
        endpoint_name=llm_endpoint_name,  # Specify the SageMaker endpoint name
        client=sagemaker_runtime,  # Attach the SageMaker runtime client
        model_kwargs=generation_configuration,  # Pass the model configuration parameters
        content_handler=content_handler,  # Use the custom content handler for formatting
    )

prompt = PromptTemplate.from_template(prompt_template)


def format_docs(docs):
    results = "\n\n".join(doc.page_content for doc in docs)
    return results


qa_chain = (
    {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)

You can now test your model with an example query. This query will get converted to an embedding and used for Knowledge Base search prior to question answering.

In [22]:
query = "In CONSOLIDATED STATEMENTS OF CASH FLOWS, How much did net income change in years 2022, 2023, 2024?"

print(f"Question: {query}")
print(f"Answer: {qa_chain.invoke(query)}")

Question: In CONSOLIDATED STATEMENTS OF CASH FLOWS, How much did net income change in years 2022, 2023, 2024?
Answer: According to the provided Consolidated Statements of Cash Flows, the net income (loss) for the years 2022, 2023, and 2024 are:

* 2022: (2,722)
* 2023: 30,425
* 2024: 59,248

So, the net income changed from a loss of $2,722 in 2022 to a profit of $30,425 in 2023, and then further increased to a profit of $59,248 in 2024.


## RAG using boto3

If you are not using LangChain, you can still perform the same tasks using the standard boto3 apis. This example shows how to use the Bedrock Knowledge Base `retrieve` API for retrieval, manually building the generation prompt, then using the SageMaker `invoke_endpoint` API to generate the output. This approach provides the most flexibility by leveraging low level constructs to build your own orchestration flow.

First, set up resources using configuration from above and define the boto3 client for Bedrock, you'll use this to perform retrievals from your knowledge base.

In [23]:
# Initialize Bedrock client to interact with the Bedrock Knowledge Base
bedrock_agent_runtime = boto3.client("bedrock-agent-runtime", region_name=variables["regionName"])

# Constants for Knowledge Base ID, SageMaker endpoint, and number of results to retrieve
KNOWLEDGE_BASE_ID = kb_id
ENDPOINT_NAME = llm_endpoint_name
NUM_RESULTS = number_of_results

Next, you'll define a series of wrapper functions to simplify the steps of retrieval, prompt formatting, and generation.

The `retrieve_from_bedrock` function takes an input query, Bedrock Knowledge Base id, the max number of results to retrieve from the knowledge base, and returns an array of text elements.

In [24]:
# Function to retrieve relevant context from the Bedrock Knowledge Base
def retrieve_from_bedrock(query, kb_id, num_results=5):
    """Retrieve relevant context from Bedrock Knowledge Base"""
    try:
        # Retrieve context based on the query using vector search configuration
        response = bedrock_agent_runtime.retrieve(
            knowledgeBaseId=kb_id,
            retrievalQuery={
                'text': query  # The query text to search in the knowledge base
            },
            retrievalConfiguration={
                'vectorSearchConfiguration': {
                    'numberOfResults': num_results  # Adjust based on the number of results required
                }
            }
        )
        # Extract the 'text' from the retrieval results and return as a list
        return [result['content']['text'] for result in response['retrievalResults']]
    except Exception as e:
        # Raise an error if the retrieval process fails
        raise RuntimeError(f"Bedrock retrieval failed: {str(e)}")

The `format_prompt` function takes in a user query and a context string from your knowledge base, then formats that into the desired prompt template for generation.

In [25]:
# Function to format the prompt for Llama 3 model using retrieved context
def format_prompt(query, context):
    """Format prompt for Llama 3"""
    # Format the complete prompt including system and user instructions
    return f"""
        <|begin_of_text|>
        <|start_header_id|>system<|end_header_id|>
        You are an assistant for question-answering tasks. Answer the following question using the provided context. If you don't know the answer, just say "I don't know.".
        <|start_header_id|>user<|end_header_id|>
        Context: {context} 
        Question: {query}
        <|start_header_id|>assistant<|end_header_id|> 
        Answer:
        """.strip()

The `generate_response` function takes the fully formatted prompt and  SageMaker endpoint name, then uses it to invoke the endpoint to generate the RAG response.

In [26]:
# Function to generate a response from the SageMaker endpoint based on the formatted prompt
def generate_response(prompt, endpoint_name):
    """Generate response using SageMaker endpoint"""
    # Initialize SageMaker runtime client
    runtime = boto3.client('sagemaker-runtime')
    
    # Prepare the payload with prompt and generation parameters
    payload = {
        "inputs": prompt,  # The formatted prompt to pass to the model
        "parameters": generation_configuration  # Additional parameters for the model (e.g., temperature, tokens)
    }
    try:
        # Call the SageMaker endpoint to generate the response
        response = runtime.invoke_endpoint(
            EndpointName=endpoint_name,  # SageMaker endpoint name
            ContentType='application/json',  # Content type for the request
            Body=json.dumps(payload)  # Send the payload as JSON
        )

        # Parse the response body
        result = json.loads(response['Body'].read().decode("utf-8"))
        
        # Handle different response formats (list or dictionary)
        if isinstance(result, list):
            # If the result is a list, extract the generated text from the first element
            return result[0]['generated_text']
        elif 'generated_text' in result:
            # If the result is a dictionary with 'generated_text', return the generated text
            return result['generated_text']
        elif 'generation' in result:
            # Alternative format with 'generation' key
            return result['generation']
        else:
            # Raise an error if the response format is unexpected
            raise RuntimeError("Unexpected response format")
            
    except Exception as e:
        # Raise an error if the generation process fails
        raise RuntimeError(f"Generation failed: {str(e)}")

Finally, you can call the series of functions in order to invoke the workflow and view the results.

In [27]:
# Retrieve relevant context from the Bedrock Knowledge Base based on the query
context = retrieve_from_bedrock(query,KNOWLEDGE_BASE_ID,NUM_RESULTS)

# Format the prompt by combining the user's query and the retrieved context
prompt = format_prompt(query, context)

# Generate the response using the formatted prompt by calling the SageMaker endpoint
response = generate_response(prompt, ENDPOINT_NAME)

# Print the user's query
print(f"Question: {query}")

# Uncomment below line if you want to debug and see the retrieved context
# print(f"Context: {context}")

# Print the generated answer from the model based on the query and context
print(f"Answer: {response}")


Question: In CONSOLIDATED STATEMENTS OF CASH FLOWS, How much did net income change in years 2022, 2023, 2024?
Answer:  According to the provided context, the net income changed as follows:

* 2022: Net income (loss) was -2,722
* 2023: Net income was 30,425
* 2024: Net income was 59,248

So, the net income changed from a loss of $2,722 in 2022 to a profit of $30,425 in 2023, and then further increased to a profit of $59,248 in 2024.
