# Large Language Model Customization using Retrieval-Augmented Generation (RAG) Pattern and Faiss Similarity Search Library

---
This Amazon SageMaker Studio Notebook demonstrates how to use SageMaker and LangChain python libraries to generate text following the Retrieval-Augmented Generation (RAG) pattern. The notebook implements semantic search using [Faiss](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) similarity search library.

This notebook has the following prerequisites:
- Select an AWS region where [Amazon SageMaker JumpStart](https://aws.amazon.com/sagemaker/jumpstart) is available. 
- [Setup Amazon SageMaker Domain](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html).
- [Available service queta](https://docs.aws.amazon.com/general/latest/gr/sagemaker.html) for "ml.g5.12xlarge for endpoint usage" and "ml.g5.24xlarge for endpoint usage".
- Select the [Amazon SageMaker Kernel](https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-kernels.html), "Python 3 (Data Science 3.0) with Python 3.10" or higher.
- Familiarity with [LangChain framework](https://python.langchain.com/docs/get_started/introduction.html) used for developing applications powered by language models.
- Basic understanding of [Retrieval Augmented Generation (RAG)](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-customize-rag.html) pattern.
- About $10 per hour to spend on Amazon SageMaker JumpStart model deployments and usage of other AWS services.
---

In [5]:
!pip install --upgrade sagemaker --quiet
!pip install ipywidgets --quiet
!pip install langchain --quiet
!pip install faiss-cpu --quiet

[0m

In [6]:
# important required libraries

import sagemaker, boto3, json
from sagemaker.session import Session
from sagemaker.model import Model
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base
from typing import Any, Dict, List, Optional
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()
model_version = "*"

In [7]:
# define some helper functions

def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type="application/json"):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json
    )
    return response


def parse_response_model_flan_t5(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    generated_text = model_predictions["generated_texts"]
    return generated_text


def parse_response_multiple_texts_bloomz(query_response):
    generated_text = []
    model_predictions = json.loads(query_response["Body"].read())
    for x in model_predictions[0]:
        generated_text.append(x["generated_text"])
    return generated_text

#### Step 1: Deploy two SageMaker endpoints for Flan-T5-XXL 11B language model and GPT-J 6B embeddings model on ml.g5.12xlarge instances.

In [8]:
%%time

_MODEL_CONFIG_ = {
    "huggingface-text2text-flan-t5-xxl": {
        "instance type": "ml.g5.12xlarge",
        "env": {"SAGEMAKER_MODEL_SERVER_WORKERS": "1", "TS_DEFAULT_WORKERS_PER_MODEL": "1"},
        "parse_function": parse_response_model_flan_t5,
        "prompt": """Answer based on context:\n\n{context}\n\n{question}""",
    },
    "huggingface-textembedding-gpt-j-6b": {
        "instance type": "ml.g5.12xlarge",
        "env": {"SAGEMAKER_MODEL_SERVER_WORKERS": "1", "TS_DEFAULT_WORKERS_PER_MODEL": "1"},
    }
}

newline, bold, unbold = "\n", "\033[1m", "\033[0m"

for model_id in _MODEL_CONFIG_:
    endpoint_name = name_from_base(f"jumpstart-example-raglc-{model_id}")
    inference_instance_type = _MODEL_CONFIG_[model_id]["instance type"]

    # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.
    deploy_image_uri = image_uris.retrieve(
        region=None,
        framework=None,  # automatically inferred from model_id
        image_scope="inference",
        model_id=model_id,
        model_version=model_version,
        instance_type=inference_instance_type,
    )
    # Retrieve the model uri.
    model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="inference"
    )
    model_inference = Model(
        image_uri=deploy_image_uri,
        model_data=model_uri,
        role=aws_role,
        predictor_cls=Predictor,
        name=endpoint_name,
        env=_MODEL_CONFIG_[model_id]["env"],
    )
    model_predictor_inference = model_inference.deploy(
        initial_instance_count=1,
        instance_type=inference_instance_type,
        predictor_cls=Predictor,
        endpoint_name=endpoint_name,
    )
    print(f"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}")
    _MODEL_CONFIG_[model_id]["endpoint_name"] = endpoint_name

-------------![1mModel huggingface-text2text-flan-t5-xxl has been deployed successfully.[0m

------------![1mModel huggingface-textembedding-gpt-j-6b has been deployed successfully.[0m



#### Step 2: Wrap the two SageMaker endpoints for LLM and embeddings model into langchain Framework and define LLM hyperparameters 

In [9]:
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
        results = []
        _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size

        for i in range(0, len(texts), _chunk_size):
            response = self._embedding_func(texts[i : i + _chunk_size])
            print
            results.extend(response)
        return results

class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
        input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        embeddings = response_json["embedding"]
        return embeddings

content_handler = ContentHandler()

embeddings = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name=_MODEL_CONFIG_["huggingface-textembedding-gpt-j-6b"]["endpoint_name"],
    region_name=aws_region,
    content_handler=content_handler,
)

In [10]:
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint

# define LLM hyperparameters 
parameters = {
    "max_length": 200,
    "num_return_sequences": 1,
    "top_k": 250,
    "top_p": 0.95,
    "do_sample": False,
    "temperature": 1,
}

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
        input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["generated_texts"][0]

content_handler = ContentHandler()

sm_llm = SagemakerEndpoint(
    endpoint_name=_MODEL_CONFIG_["huggingface-text2text-flan-t5-xxl"]["endpoint_name"],
    region_name=aws_region,
    model_kwargs=parameters,
    content_handler=content_handler,
)

#### Step 3: Generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embeddings model.
This notebook uses an example from SageMaker FAQ dataset. You can replace the example dataset of your own to build a custom question and answering application.

In [11]:
%%time

original_data = "s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/"

!mkdir -p rag_data
!aws s3 cp --recursive $original_data rag_data

download: s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv to rag_data/Amazon_SageMaker_FAQs.csv
CPU times: user 24.8 ms, sys: 18.1 ms, total: 42.9 ms
Wall time: 1.63 s


#### Step 4: Format the data by concatenation of any files ending with .csv and dropping the `Question` column as it is not used in this demonstration.

In [12]:
import glob
import os
import pandas as pd

all_files = glob.glob(os.path.join("rag_data/", "*.csv"))

df_knowledge = pd.concat(
    (pd.read_csv(f, header=None, names=["Question", "Answer"]) for f in all_files),
    axis=0,
    ignore_index=True,
)

df_knowledge.drop(["Question"], axis=1, inplace=True)
df_knowledge.to_csv("rag_data/processed_data.csv", header=False, index=False)

print("Top 5 documents after loading and formating:")
df_knowledge.head(5)

Top 5 documents after loading and formating:


Unnamed: 0,Answer
0,Amazon SageMaker is a fully managed service to...
1,For a list of the supported Amazon SageMaker A...
2,Amazon SageMaker is designed for high availabi...
3,Amazon SageMaker stores code in ML storage vol...
4,Amazon SageMaker ensures that ML model artifac...


#### Setp 5: Configure Langchain to read the `csv` data

In [13]:
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.document_loaders import TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import Chroma, AtlasDB, FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders.csv_loader import CSVLoader

loader = CSVLoader(file_path="rag_data/processed_data.csv")
documents = loader.load()

#### Step 6: Generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.

In [17]:
%%time

docsearch = FAISS.from_documents(documents, embeddings)

top_doc = docsearch.similarity_search("Which instances can I use with Managed Spot Training in SageMaker?", k=1)
print(top_doc)

[Document(page_content='Amazon SageMaker is a fully managed service to prepare data and build, train, and deploy machine learning (ML) models for any use case with fully managed infrastructure, tools, and workflows.: Once a Managed Spot Training job is completed, you can see the savings in the AWS Management Console and also calculate the cost savings as the percentage difference between the duration for which the training job ran and the duration for which you were billed. Regardless of how many times your Managed Spot Training jobs are interrupted, you are charged only once for the duration for which the data was downloaded.', metadata={'source': 'rag_data/processed_data.csv', 'row': 88})]
CPU times: user 392 ms, sys: 13 ms, total: 405 ms
Wall time: 24.7 s


#### Step 7: Define the LLM prompt template

In [18]:
prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
chain = load_qa_chain(llm=sm_llm, prompt=PROMPT)

#### Step 8: Now let us build a little fun app to test prompting with and without RAG

In [20]:
import ipywidgets as widgets
from IPython.display import display, clear_output

def on_send_button_click(button):
    question = input_field.value
        
    if rag_check.value:
        topk_docs = docsearch.similarity_search(question, k=3)
        response = chain({"input_documents": topk_docs, "question": question}, return_only_outputs=True)["output_text"]
    else:
        response = sm_llm(question)
    
    with output:
        print("Q:", question)
        print("A:", response.strip())
        print("-"*40)

    input_field.value = ""

def on_input_field_submit(text):
    on_send_button_click(None)

# Create the input field and send button
input_field = widgets.Text(placeholder='Type your question here...')
rag_check = widgets.Checkbox(value=True, description='Enable RAG', indent=False)
send_button = widgets.Button(description='Send')
top_box = widgets.HBox([input_field, rag_check])
bottom_box = widgets.HBox([send_button])
v_box = widgets.VBox([top_box, bottom_box])
output = widgets.Output()

# Assign the function to the button click event and the input field submit event
send_button.on_click(on_send_button_click)
input_field.on_submit(on_input_field_submit)

# Display the chat interface
display(output, v_box)

Output()

VBox(children=(HBox(children=(Text(value='', placeholder='Type your question here...'), Checkbox(value=True, d…