# Retrieval and Generation with SageMaker Endpoint LLM

## Overview  
This notebook demonstrates how to perform retrieval-augmented generation (RAG) using a SageMaker-hosted large language model (LLM). We will retrieve relevant documents from a knowledge base and use the LLM to generate responses based on the retrieved information.  

## Key Steps:  
- Configure and query a knowledge base for relevant documents.  
- Use a SageMaker-hosted LLM to generate contextual responses.  
- Optimize retrieval and generation parameters for improved accuracy.  

By the end of this notebook, you'll understand how to integrate SageMaker-hosted models into a RAG pipeline to enhance answer generation with domain-specific knowledge.  

In [2]:

import importlib
import advanced_rag_utils

# Reload module
importlib.reload(advanced_rag_utils)

# Re-import all functions
from advanced_rag_utils import *

from datetime import datetime, timedelta, UTC


In [3]:
!pip install -Uq sagemaker boto3 langchain-aws

Fetching existing resource information

In [4]:
import json

# Load variables from JSON file
with open("../variables.json", "r") as f:
    variables = json.load(f)

variables

{'accountNumber': '989679345636',
 'regionName': 'us-west-2',
 'collectionArn': 'arn:aws:aoss:us-west-2:989679345636:collection/ny2d41n7rmju74rh4ue2',
 'collectionId': 'ny2d41n7rmju74rh4ue2',
 'vectorIndexName': 'ws-index-',
 'bedrockExecutionRoleArn': 'arn:aws:iam::989679345636:role/advanced-rag-workshop-bedrock_execution_role-us-west-2',
 's3Bucket': '989679345636-us-west-2-advanced-rag-workshop',
 'kbFixedChunk': 'TYG3IXCHCX',
 'kbSemanticChunk': 'N7ZHYZVLOX',
 'kbHierarchicalChunk': 'UDPUVOULM1',
 'kbCustomChunk': 'AD07GOEBQ2'}

In this example, you will use a model from [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html). Amazon SageMaker JumpStart is a machine learning (ML) hub that can help you accelerate your ML journey. With SageMaker JumpStart, you can quickly evaluate, compare, and select foundation models (FMs) based on pre-defined quality and responsibility metrics to perform tasks like article summarization and image generation.

To load a model from SageMaker JumpStart, you need to specify a `model_id` and a `model_version`. The current list of models and versions can be found in the [SageMaker Python SDK documentation](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html).

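The Llama 3.2 3B model has a `model_id` of `meta-textgeneration-llama-3-2-3b-instruct`. To always use the latest version of the model, you can set `model_version` to `*`, but pinning to a specific version is recommended to ensure consistency.

Llama 3.2 3B suits this kind of example because it is small, fast, and still supports a long context length (128k tokens), which leaves room for larger retrievals if generation needs them. Note that the configuration cell below deploys the Llama 3 8B Instruct model (`meta-textgeneration-llama-3-8b-instruct`) instead. That model has an 8k-token context window, so keep the retrieved context correspondingly smaller when using it.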
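
As a rough illustration of what loading a JumpStart model by `model_id` and `model_version` involves, the sketch below uses `JumpStartModel` from the SageMaker Python SDK. The function names `jumpstart_model_kwargs` and `deploy_jumpstart_model` are hypothetical. They are not the workshop's `get_or_deploy_sagemaker_endpoint` helper, whose internals are not shown in this notebook.

```python
import warnings


def jumpstart_model_kwargs(model_id, model_version, instance_type):
    """Collect the arguments that identify a JumpStart model (hypothetical helper)."""
    if model_version == "*":
        # "*" resolves to the latest version, which can change between runs.
        warnings.warn("model_version='*' is not pinned; behavior may change over time.")
    return {
        "model_id": model_id,
        "model_version": model_version,
        "instance_type": instance_type,
    }


def deploy_jumpstart_model(model_id, model_version, instance_type):
    """Deploy the model and return the endpoint name. Creates billable resources."""
    # Imported lazily so the rest of this sketch runs without the SDK installed.
    from sagemaker.jumpstart.model import JumpStartModel

    model = JumpStartModel(
        **jumpstart_model_kwargs(model_id, model_version, instance_type)
    )
    # Llama models require explicitly accepting the EULA at deploy time.
    predictor = model.deploy(accept_eula=True)
    return predictor.endpoint_name
```

Calling `deploy_jumpstart_model` provisions a GPU instance, so only run it if no suitable endpoint already exists.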

In [5]:
# LLM Configuration  
model_id = "meta-textgeneration-llama-3-8b-instruct" 
model_version = "2.11.2"
instance_type = "ml.g5.4xlarge"

> **Note**: The model deployment process for the SageMaker endpoint will take approximately 8-10 minutes to complete. During this time, the system is:
> 1. Provisioning the required compute resources (GPU instances)
> 2. Downloading and installing the model artifacts
> 3. Configuring the inference environment
> 4. Setting up auto-scaling and monitoring for the endpoint
>
> No further action is needed during this time. The cell will continue to execute until the endpoint is fully deployed and ready for inference. This is a one-time setup that will be used throughout the workshop.

In [6]:
# Deploy or find an existing SageMaker endpoint
llm_endpoint_name = get_or_deploy_sagemaker_endpoint(
    model_id=model_id, 
    model_version=model_version, 
    instance_type=instance_type,
    region_name=variables["regionName"]
)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


Model 'meta-textgeneration-llama-3-8b-instruct' requires accepting end-user license agreement (EULA). See https://jumpstart-cache-prod-us-east-1.s3.us-east-1.amazonaws.com/fmhMetadata/eula/llama3Eula.txt for terms of use.


Using model 'meta-textgeneration-llama-3-8b-instruct' with version '2.11.2'. You can upgrade to version '2.11.3' to get the latest model specifications. Note that models may have different input/output signatures after a major version upgrade.


An error occurred (ValidationException) when calling the CreateEndpointConfig operation: 1 validation error detected: Value 'endpoint-meta-textgeneration-llama-3-8b-instruct-2025-04-29-22-23-26' at 'endpointConfigName' failed to satisfy constraint: Member must have length less than or equal to 63
New endpoint cannot be created. Looking for any existing endpoints...
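
The `ValidationException` in the output above occurs because the generated endpoint config name is longer than the 63 characters SageMaker allows. One possible way to avoid this is to build a shorter name, as in the minimal sketch below. The function `short_endpoint_name` and its `ep` prefix are assumptions for illustration, not part of `advanced_rag_utils`.

```python
import hashlib
import re
from datetime import datetime, timezone

MAX_NAME_LENGTH = 63  # limit quoted in the ValidationException above


def short_endpoint_name(model_id, now=None, prefix="ep"):
    """Build an endpoint name that fits within MAX_NAME_LENGTH (hypothetical helper)."""
    now = now or datetime.now(timezone.utc)
    timestamp = now.strftime("%Y%m%d-%H%M%S")
    # A short hash keeps names distinct even when the readable part is truncated.
    digest = hashlib.sha1(model_id.encode("utf-8")).hexdigest()[:6]
    # Endpoint names may only contain alphanumeric characters and hyphens.
    cleaned = re.sub(r"[^a-zA-Z0-9-]", "-", model_id)
    # Three hyphens join the four parts of the name.
    budget = MAX_NAME_LENGTH - len(prefix) - len(digest) - len(timestamp) - 3
    readable = cleaned[:budget].strip("-")
    return f"{prefix}-{readable}-{digest}-{timestamp}"
```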


#### Check the progress of a SageMaker Endpoint deployment [here](https://console.aws.amazon.com/sagemaker/home#/endpoints). 

Store the SageMaker endpoint name for later use.

In [None]:
# Save the SageMaker endpoint name to the variables JSON file
variables = save_sagemaker_endpoint_to_variables(
    variables=variables,
    endpoint_name=llm_endpoint_name
)

# Display the endpoint name
llm_endpoint_name
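
For reference, persisting the endpoint name amounts to adding one key to the JSON file loaded earlier. The sketch below shows the idea. The function `save_endpoint_name` and the key `sagemakerLLMEndpoint` are hypothetical, and the workshop's `save_sagemaker_endpoint_to_variables` helper may use a different key or file path.

```python
import json
from pathlib import Path


def save_endpoint_name(path, endpoint_name, key="sagemakerLLMEndpoint"):
    """Add the endpoint name to a JSON variables file and return the updated dict."""
    path = Path(path)
    # Start from the existing variables so other entries are preserved.
    variables = json.loads(path.read_text()) if path.exists() else {}
    variables[key] = endpoint_name
    path.write_text(json.dumps(variables, indent=4))
    return variables
```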

# Retrieval and Generation using Bedrock Knowledge Bases and SageMaker hosted models

With your endpoint successfully created, you can now use it as an output model in your RAG workflow. The following examples use the Amazon Bedrock Knowledge Bases that you created earlier for retrieval, combined with your SageMaker hosted model for generation. This hybrid approach results in a robust solution, combining the ease of use and managed aspects of Bedrock Knowledge Bases with the model flexibility and configuration controls of SageMaker hosting.

## RAG Orchestration with LangChain

To integrate LangChain with SageMaker endpoints, you first need to define a `ContentHandler`. Its purpose is to perform any transformations of the input/output data to match what the model expects and provide a processed output to client applications.

This content handler specifies the input/output content types as UTF-8 encoded `application/json` and pulls the `generated_text` field from the JSON response as the output.
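
To make the handler's role concrete, here is a minimal sketch of the two transformations it performs. The class `LlamaContentHandler` is an illustrative stand-in written in plain Python. In a LangChain application you would typically subclass `LLMContentHandler` from `langchain_aws.llms.sagemaker_endpoint` and implement the same two methods. The `inputs`/`parameters` payload shape is an assumption about the endpoint's request format.

```python
import io
import json


class LlamaContentHandler:
    """Illustrative content handler mirroring LangChain's handler interface."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt, model_kwargs):
        # Assumed request shape: prompt under "inputs", settings under "parameters".
        payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output):
        # `output` is a file-like response body, so read and decode it first.
        body = json.loads(output.read().decode("utf-8"))
        # Some containers wrap the result in a single-element list.
        if isinstance(body, list):
            body = body[0]
        return body["generated_text"]
```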

In [None]:
# Create content handler for SageMaker endpoint
content_handler = create_sagemaker_content_handler()

With your `ContentHandler` defined, the next step is to set up your retriever. This retriever is responsible for fetching results from your Bedrock Knowledge Base so they can be provided as contextual input for generation.

The `AmazonKnowledgeBasesRetriever` takes a `knowledge_base_id` parameter to select the appropriate knowledge base. In this example, `kbFixedChunk`, `kbHierarchicalChunk`, and `kbSemanticChunk` are keys in your `variables.json` file that hold the actual knowledge base IDs.

It also takes a `retrieval_config`. In this notebook it consists of a `vectorSearchConfiguration` with `numberOfResults` as the only parameter that is set. The `numberOfResults` parameter controls the maximum number of search results returned from the knowledge base.
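
The sketch below shows roughly how such a retriever could be constructed directly with `langchain-aws`. The function names `build_retrieval_config` and `make_retriever` are hypothetical. The workshop's `create_bedrock_retriever` helper may be implemented differently.

```python
def build_retrieval_config(number_of_results):
    """Return the retrieval_config dict described above."""
    if number_of_results < 1:
        raise ValueError("number_of_results must be at least 1")
    return {"vectorSearchConfiguration": {"numberOfResults": number_of_results}}


def make_retriever(kb_id, number_of_results, region_name):
    """Create a LangChain retriever backed by a Bedrock Knowledge Base."""
    # Imported lazily so the config helper works without langchain-aws installed.
    from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever

    return AmazonKnowledgeBasesRetriever(
        knowledge_base_id=kb_id,
        retrieval_config=build_retrieval_config(number_of_results),
        region_name=region_name,
    )
```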

In [None]:
# Knowledge Base Selection and configuration
kb_id = variables["kbSemanticChunk"]
number_of_results = 3

# Create retriever for Bedrock Knowledge Base
retriever = create_bedrock_retriever(
    kb_id=kb_id,
    number_of_results=number_of_results,
    region_name=variables["regionName"]
)

Next, define a prompt template for your call to the output model. 

Since you are using a Llama-3 model in this example, it needs to follow the [correct prompt format](https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/).

This template uses the following roles:
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that help the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `assistant`: Represents the response generated by the AI model based on the context provided in the system and user prompts.

The fields `{context}` and `{question}` in the template will be dynamically injected as part of your RAG chain in a later step. These names are not hardcoded, but they need to match what you specify when you build your chain.
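
For illustration, a template in the Llama 3 prompt format with these three roles might look like the sketch below. The system instructions and the name `LLAMA3_RAG_TEMPLATE` are assumptions. The template returned by `get_llama_prompt_template()` may be worded differently.

```python
# Each role is opened with header tokens and closed with <|eot_id|>.
# The assistant header is left open so the model writes the answer next.
LLAMA3_RAG_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "Answer the question using only the context below. "
    "If the context does not contain the answer, say so.\n\n"
    "Context:\n{context}<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "{question}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

# Fill the placeholders the same way the RAG chain does at run time.
prompt = LLAMA3_RAG_TEMPLATE.format(
    context="Net income was 10.",
    question="What was net income?",
)
```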

In [None]:
# Get prompt template for Llama model
prompt_template = get_llama_prompt_template()

Specify the parameters for generation.

- `temperature`: Affects the shape of the probability distribution for the predicted output and influences the likelihood of the model selecting lower-probability outputs.
    - Choose a lower value to influence the model to select higher-probability outputs.
    - Choose a higher value to influence the model to select lower-probability outputs.
    - In technical terms, the temperature modulates the probability mass function for the next token. A lower temperature steepens the function and leads to more deterministic responses, and a higher temperature flattens the function and leads to more random responses.
- `top_k`: The number of most-likely candidates that the model considers for the next token.
    - Choose a lower value to decrease the size of the pool and limit the options to more likely outputs.
    - Choose a higher value to increase the size of the pool and allow the model to consider less likely outputs.
    - For example, if you choose a value of 50 for `top_k`, the model selects from the 50 most probable tokens that could be next in the sequence.
- `top_p`: The percentage of most-likely candidates that the model considers for the next token.
    - Choose a lower value to decrease the size of the pool and limit the options to more likely outputs.
    - Choose a higher value to increase the size of the pool and allow the model to consider less likely outputs.
    - In technical terms, the model computes the cumulative probability distribution for the set of candidate tokens and considers only the top P% of the distribution. For example, if you choose a value of 0.8 for `top_p`, the model selects from the top 80% of the probability distribution of tokens that could be next in the sequence.
- `max_new_tokens`: The maximum number of tokens to generate, not counting the tokens in the prompt.
- `stop`: Sequences of characters that stop the model from generating further tokens. If the model generates a stop sequence that you specify, it will stop generating after that sequence.
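
To see how `temperature`, `top_k`, and `top_p` shape the candidate pool, the toy sketch below applies each one to a three-token distribution. It is a simplified illustration in plain Python, not the sampling code the endpoint actually runs, and the function names are made up for this example. The sketch requires a temperature greater than 0.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities; lower temperature sharpens the result."""
    if temperature <= 0:
        raise ValueError("this sketch requires temperature > 0")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_filter(probs, k):
    """Return the indices of the k most probable tokens."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return sorted(order[:k])


def top_p_filter(probs, p):
    """Return the smallest set of top tokens whose cumulative probability reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return sorted(kept)


logits = [2.0, 1.0, 0.0]
cold = softmax_with_temperature(logits, 0.5)  # steeper distribution
hot = softmax_with_temperature(logits, 2.0)   # flatter distribution
```

Comparing `cold[0]` with `hot[0]` shows the most likely token receiving a larger share of the probability at the lower temperature.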

In [None]:
# Get default generation configuration for SageMaker
generation_configuration = get_default_sagemaker_generation_config(
    temperature=0,
    top_k=10,
    max_new_tokens=512,
    stop="<|eot_id|>"
)

Here you will create your chain.

1. Initialize the `ContentHandler` from above
2. Create a `sagemaker-runtime` boto3 client for calling the endpoint
3. Initialize the `PromptTemplate` from above
4. Define a function to process the documents from the retriever. In this example, the document array is iterated through and the content is joined together using `\n\n` between them to break up the context.
5. Finally, define your chain. Here, you'll define your chain using LangChain's [LangChain Expression Language (LCEL)](https://python.langchain.com/docs/concepts/lcel/) to replace deprecated methods like [RetrievalQA](https://python.langchain.com/docs/versions/migrating_chains/retrieval_qa/). LCEL is designed to streamline the process of building useful apps with LLMs and combining related components.

Your `qa_chain` will pass through the `question` parameter from the invocation of the chain, and fill the `context` parameter by invoking the retriever and processing the result with the `format_docs` function. From there, those outputs are piped to the prompt template to fill in the defined placeholders, then sent to the `llm` SageMaker endpoint for generation. Finally, the model output is sent to the `StrOutputParser` to convert it into a usable string.
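
The steps above can be sketched in LCEL roughly as follows. This is an assumption about what a helper such as `create_rag_chain` does internally, and `build_rag_chain` is a hypothetical name. It expects `retriever` and `llm` to be LangChain runnables and `prompt_template` to be a string containing `{context}` and `{question}`.

```python
def format_docs(docs):
    """Join the retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def build_rag_chain(retriever, llm, prompt_template):
    """Compose retriever, prompt, model, and output parser with LCEL."""
    # Imported lazily so format_docs can be used without LangChain installed.
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import PromptTemplate
    from langchain_core.runnables import RunnablePassthrough

    prompt = PromptTemplate.from_template(prompt_template)
    return (
        {
            # The retriever output is formatted into the {context} placeholder.
            "context": retriever | format_docs,
            # The chain input is passed through unchanged as {question}.
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
```

A chain built this way is invoked with the question string, for example `build_rag_chain(retriever, llm, template).invoke("your question")`.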

In [None]:
# Create SageMaker LLM
llm = create_sagemaker_llm(
    endpoint_name=llm_endpoint_name,
    generation_config=generation_configuration,
    content_handler=content_handler,
    region_name=variables["regionName"]
)

# Create RAG chain
qa_chain = create_rag_chain(
    retriever=retriever,
    llm=llm,
    prompt_template=prompt_template
)

You can now test your model with an example query. This query will get converted to an embedding and used for Knowledge Base search prior to question answering.

In [None]:
# Test query
query = "In CONSOLIDATED STATEMENTS OF CASH FLOWS, How much did net income change in years 2022, 2023, 2024?"

# Invoke RAG chain
response = invoke_rag_chain(qa_chain, query)

print(f"Question: {query}")
print(f"Answer: {response}")

## RAG using boto3

If you are not using LangChain, you can still perform the same tasks using the standard boto3 APIs. This example shows how to use the Bedrock Knowledge Base `retrieve` API for retrieval, build the generation prompt manually, and then use the SageMaker `invoke_endpoint` API to generate the output. This approach provides the most flexibility by using low-level constructs to build your own orchestration flow.

First, set up resources using the configuration from above and define the boto3 client for Bedrock. You'll use this client to perform retrievals from your knowledge base.

In [None]:
# Initialize constants
KNOWLEDGE_BASE_ID = kb_id
ENDPOINT_NAME = llm_endpoint_name
NUM_RESULTS = number_of_results

Now, call the utility function to perform RAG using the direct boto3 approach. It combines all the steps: retrieval, prompt formatting, and generation.

In [None]:
# Perform RAG using boto3
query, response = setup_and_run_rag_with_sagemaker(
    query=query,
    kb_id=KNOWLEDGE_BASE_ID,
    endpoint_name=ENDPOINT_NAME,
    generation_config=generation_configuration,
    number_of_results=NUM_RESULTS,
    region_name=variables["regionName"]
)

print(f"Question: {query}")
print(f"Answer: {response}")
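
The utility function hides the individual API calls. The sketch below shows roughly what the three steps look like with plain boto3 clients. It is an assumed implementation for illustration, not the body of `setup_and_run_rag_with_sagemaker`. The names `build_prompt` and `rag_with_boto3`, the prompt wording, and the `inputs`/`parameters` payload shape are all assumptions. The clients are passed in as arguments so the function can be exercised without AWS access.

```python
import json


def build_prompt(context, question):
    """Build a Llama 3 style prompt from retrieved context and the question."""
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"Answer using only this context:\n{context}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{question}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )


def rag_with_boto3(query, kb_id, endpoint_name, kb_client, smr_client,
                   number_of_results=3, parameters=None):
    """Retrieve from a knowledge base, then generate with a SageMaker endpoint."""
    # Step 1: retrieval with the Bedrock Knowledge Bases `retrieve` API.
    retrieved = kb_client.retrieve(
        knowledgeBaseId=kb_id,
        retrievalQuery={"text": query},
        retrievalConfiguration={
            "vectorSearchConfiguration": {"numberOfResults": number_of_results}
        },
    )
    context = "\n\n".join(
        result["content"]["text"] for result in retrieved["retrievalResults"]
    )

    # Step 2: build the generation prompt manually.
    payload = {"inputs": build_prompt(context, query), "parameters": parameters or {}}

    # Step 3: generation with the SageMaker `invoke_endpoint` API.
    response = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    body = json.loads(response["Body"].read())
    # Some containers wrap the result in a single-element list.
    if isinstance(body, list):
        body = body[0]
    return body["generated_text"]
```

In practice, `kb_client` would be `boto3.client("bedrock-agent-runtime", region_name=...)` and `smr_client` would be `boto3.client("sagemaker-runtime", region_name=...)`.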