# Retrieval-Augmented Generation: Question Answering Based on a Custom Dataset with the Open-Source [LangChain](https://python.langchain.com/en/latest/index.html) Library


<div class="alert alert-block alert-warning">
    This notebook should be run using the <b>Data Science 2.0</b> Python kernel.  </div>

In this notebook, we demonstrate how to use **Flan T5 XL** to answer questions using a library of documents as a reference, by combining document embeddings with retrieval. The embeddings are generated by the **GPT-J-6B-FP16** embedding model.

**This notebook serves as a template: you can easily replace the example dataset with your own to build a custom question answering application.**

---
## Step 1. Deploy the large language model (LLM) and the embedding model with SageMaker JumpStart


In [None]:
# update libraries
!pip install --upgrade pip
!pip install --upgrade sagemaker --quiet
!pip install ipywidgets==7.0.0 --quiet
!pip install langchain==0.0.148 --quiet
!pip install faiss-cpu --quiet

In [None]:
import json
from time import gmtime, strftime, sleep
from typing import Dict, List, Optional

import boto3
from sagemaker.session import Session
from langchain.embeddings import SagemakerEndpointEmbeddings

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()  # IAM role attached to this notebook
aws_region = boto3.Session().region_name
sagemaker_client = boto3.client('sagemaker')  # low-level client, used to poll and delete endpoints
model_version = "*"

In [None]:
# define the names of the endpoints
endpoint_name_llm = "llm-endpoint-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
endpoint_name_embeddings = "embeddings-endpoint-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

In [None]:
# deploy the LLM to an endpoint
from sagemaker.jumpstart.model import JumpStartModel

model_llm = JumpStartModel(
    model_id="huggingface-text2text-flan-t5-xl",  # alternative: huggingface-llm-falcon-7b-instruct-bf16 (may need a different payload format)
    instance_type="ml.g5.2xlarge"
)
predictor_llm = model_llm.deploy(
    endpoint_name=endpoint_name_llm,
    wait=False  # allow notebook to continue to be responsive
)

In [None]:
# deploy the embedding model to an endpoint
model_embedding = JumpStartModel(
    model_id="huggingface-textembedding-gpt-j-6b-fp16",
    instance_type="ml.p3.2xlarge"
)
predictor_embedding = model_embedding.deploy(
    endpoint_name=endpoint_name_embeddings,
    wait=False  # allow notebook to continue to be responsive
)

In [None]:
# Poll every 30 seconds until both endpoints are InService (fail fast if one fails)

endpoint_names = [endpoint_name_llm, endpoint_name_embeddings]

while True:
    statuses = {
        name: sagemaker_client.describe_endpoint(EndpointName=name)['EndpointStatus']
        for name in endpoint_names
    }
    print('LLM endpoint:', statuses[endpoint_name_llm],
          '| Embedding endpoint:', statuses[endpoint_name_embeddings])

    if all(status == 'InService' for status in statuses.values()):
        break
    if any(status == 'Failed' for status in statuses.values()):
        raise RuntimeError(f'Endpoint deployment failed: {statuses}')
    sleep(30)
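The polling loop above waits until both endpoints are ready. If you want the wait to be bounded, the following is a minimal, generic sketch. The helper name `wait_for_endpoints` and its `timeout_seconds` parameter are our own, not part of the SageMaker SDK. It takes any function that maps an endpoint name to a status string, so it can be tried without AWS using fake statuses.

```python
import time
from typing import Callable, Dict, Iterable


def wait_for_endpoints(
    get_status: Callable[[str], str],
    endpoint_names: Iterable[str],
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
) -> Dict[str, str]:
    """Poll until every endpoint reports InService.

    Raises RuntimeError as soon as any endpoint reports Failed, and
    TimeoutError if the endpoints are not ready within timeout_seconds.
    """
    names = list(endpoint_names)
    deadline = time.monotonic() + timeout_seconds
    while True:
        statuses = {name: get_status(name) for name in names}
        print(statuses)
        if all(status == "InService" for status in statuses.values()):
            return statuses
        failed = [name for name, status in statuses.items() if status == "Failed"]
        if failed:
            raise RuntimeError(f"Endpoint(s) failed to deploy: {failed}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoints not ready after {timeout_seconds} s: {statuses}")
        time.sleep(poll_seconds)


# Try it offline with fake status sequences (no AWS calls are made).
fake_statuses = {
    "llm": iter(["Creating", "Creating", "InService"]),
    "embeddings": iter(["Creating", "InService", "InService"]),
}
final = wait_for_endpoints(lambda name: next(fake_statuses[name]), ["llm", "embeddings"], poll_seconds=0)
print(final)
```

With the real endpoints, you would pass a function built on `describe_endpoint`, for example `wait_for_endpoints(lambda name: sagemaker_client.describe_endpoint(EndpointName=name)['EndpointStatus'], [endpoint_name_llm, endpoint_name_embeddings])`. boto3 also provides an `endpoint_in_service` waiter on the SageMaker client, which is an alternative when waiting on one endpoint at a time.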

---
## Step 2. Ask the LLM a question without providing context

To illustrate why we need a retrieval-augmented generation (RAG) approach to question answering, let's first ask the model a question directly and see how it responds.

In [None]:
# helper function to interact with the 2 endpoints

def query_endpoint_with_json_payload(payload, endpoint_name, model_type="llm"):
    """Send a JSON payload to a SageMaker endpoint and parse the response."""
    client = boto3.client("sagemaker-runtime")
    encoded_json = json.dumps(payload).encode("utf-8")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, 
        ContentType="application/json", 
        Body=encoded_json
    )
    model_predictions = json.loads(response["Body"].read())
    if model_type == 'llm':
        # the text2text model returns a list of {"generated_text": ...} dicts
        output = model_predictions[0]["generated_text"]
    elif model_type == 'embeddings':
        # the embedding model returns {"embedding": [[...], ...]}
        output = model_predictions["embedding"]
    else:
        raise ValueError(f"Unknown model_type: {model_type!r}")
    return output


In [None]:
# define a question string to be used for testing. You can change this question and try new things.

question = "Which instances can I use with Managed Spot Training in SageMaker?"

In [None]:
# test the embeddings endpoint

payload = {
    "text_inputs": question
}

response = query_endpoint_with_json_payload(
    payload=payload, 
    endpoint_name=endpoint_name_embeddings,
    model_type='embeddings'
)

print('Vector size:', len(response[0]), '\nFirst 20 numbers:', response[0][:20])

In [None]:
# test the LLM endpoint

payload = {
    "inputs": question,
    "parameters": {
        "max_length": 100,
        "num_return_sequences": 1,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.5,
        "do_sample": True,
    }
}

response = query_endpoint_with_json_payload(
    payload=payload, 
    endpoint_name=endpoint_name_llm,
    model_type='llm'
)

print('Question:', question, '\nAnswer:', response)


As you can see, the generated answer is likely wrong or doesn't make much sense.

---
## Step 3. Improve the answer to the same question using **prompt engineering** with insightful context


To answer the question better, we provide extra contextual information, combine it with a prompt, and send it to the model together with the question. Below is an example.

In [None]:
context = "Managed Spot Training can be used with all instances supported in Amazon SageMaker. Managed Spot Training is supported in all AWS Regions where Amazon SageMaker is currently available."

In [None]:
payload = {
    "inputs": context + "\n\n" + question,  # separate the context from the question
    "parameters": {
        "max_length": 100,
        "num_return_sequences": 1,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.5,
        "do_sample": True,
    }
}

response = query_endpoint_with_json_payload(
    payload=payload, 
    endpoint_name=endpoint_name_llm,
    model_type='llm'
)

print('Context:', context, '\nQuestion:', question, '\nAnswer:', response)
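The cell above simply puts the context in front of the question. A small helper makes that structure explicit and reusable. This is a sketch: `build_prompt` is a hypothetical helper, and the instruction wording mirrors the custom prompt template used in Step 5 rather than anything Flan T5 requires.

```python
def build_prompt(context: str, question: str) -> str:
    """Combine optional context and a question into one prompt string."""
    context = context.strip()
    question = question.strip()
    if not context:
        # No context available: send the bare question, as in Step 2.
        return question
    # Blank lines keep the context visibly separate from the question.
    return f"Answer based on context:\n\n{context}\n\n{question}"


context_example = "Managed Spot Training can be used with all instances supported in Amazon SageMaker."
question_example = "Which instances can I use with Managed Spot Training in SageMaker?"
print(build_prompt(context_example, question_example))
```

You could then send `{"inputs": build_prompt(context, question), "parameters": {...}}` to the LLM endpoint in the same way as above.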

The output from Step 3 shows that the chance of getting a correct response depends heavily on the relevant context you send to the LLM.

**<span style="color:red">Now the question becomes: how do we find relevant context for a given user query? The answer is to retrieve it from a pre-built knowledge base using retrieval-augmented generation, as shown in Step 4 below</span>.**

---
## Step 4. Use RAG-based approach with [LangChain](https://python.langchain.com/en/latest/index.html) and SageMaker endpoints 


We use document embeddings to fetch the most relevant documents from our document knowledge library and combine them with the prompt that we send to the LLM.

To achieve that, we do the following:

1. **Generate embeddings for each document in the knowledge library with the SageMaker GPT-J-6B FP16 embedding model.**
2. **Identify the top K most relevant documents based on the user query.**
    - 2.1 **For a query of interest, generate the embedding of the query using the same embedding model.**
    - 2.2 **Search for the indexes of the top K most relevant documents in the embedding space using in-memory Faiss search.**
    - 2.3 **Use the indexes to retrieve the corresponding documents.**
3. **Combine the retrieved documents with the prompt and question and send them to the SageMaker LLM.**
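Steps 2.1 to 2.3 boil down to: embed the query, score every document embedding against it, and map the best-scoring indexes back to their text. The sketch below shows this with tiny made-up 3-dimensional vectors and a brute-force cosine-similarity search in pure Python. It is a stand-in for Faiss used only to illustrate the idea; Faiss performs the same kind of nearest-neighbour search far more efficiently, and the distance metric it uses depends on the index type.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # treat zero vectors as unrelated to everything
    return dot / (norm_a * norm_b)


def top_k_documents(
    query_embedding: Sequence[float],
    doc_embeddings: List[Sequence[float]],
    documents: List[str],
    k: int = 3,
) -> List[Tuple[int, float, str]]:
    """Score every document (step 2.2), then map indexes back to text (step 2.3)."""
    scores = [
        (i, cosine_similarity(query_embedding, emb))
        for i, emb in enumerate(doc_embeddings)
    ]
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return [(i, score, documents[i]) for i, score in scores[:k]]


# Toy 3-dimensional "embeddings"; real GPT-J vectors are much longer.
docs = ["spot training", "model hosting", "data labeling"]
doc_vecs = [[1.0, 0.1, 0.0], [0.0, 1.0, 0.2], [0.1, 0.0, 1.0]]
query_vec = [0.9, 0.2, 0.0]  # step 2.1 would produce this with the embedding model

for idx, score, text in top_k_documents(query_vec, doc_vecs, docs, k=2):
    print(idx, round(score, 3), text)
```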



Note: The retrieved documents/text should be large enough to contain the information needed to answer a question, but small enough to fit into the LLM prompt (maximum sequence length of 1024 tokens).
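One way to respect that limit is to add retrieved documents in rank order until a token budget is used up. The sketch below assumes the documents are already sorted by relevance; `pack_context`, `reserved_tokens`, and the whitespace-based `approx_token_count` are illustrative assumptions. Word counts only approximate real token counts (the Flan T5 subword tokenizer usually produces more tokens than words), so pass a real tokenizer's counting function if you need accuracy.

```python
from typing import Callable, List


def approx_token_count(text: str) -> int:
    # Rough proxy: count whitespace-separated words. Real subword tokenizers
    # usually produce more tokens than this.
    return len(text.split())


def pack_context(
    ranked_docs: List[str],
    question: str,
    max_tokens: int = 1024,
    reserved_tokens: int = 64,
    count_tokens: Callable[[str], int] = approx_token_count,
) -> List[str]:
    """Keep the most relevant documents that still fit in the prompt budget.

    ranked_docs must be sorted from most to least relevant. reserved_tokens
    leaves room for the instruction text of the prompt template.
    """
    budget = max_tokens - reserved_tokens - count_tokens(question)
    packed: List[str] = []
    for doc in ranked_docs:
        cost = count_tokens(doc)
        if cost > budget:
            break  # stop at the first document that no longer fits
        packed.append(doc)
        budget -= cost
    return packed


ranked = ["a b c", "d e", "f g h i"]
print(pack_context(ranked, "q", max_tokens=10, reserved_tokens=2))
```

With the documents retrieved in Step 5, a call could look like `pack_context([d.page_content for d in retrieved_documents], question)`.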

---
To build a simplified QA application with LangChain, we need to:
1. Wrap our SageMaker endpoints for the embedding model and the LLM into `langchain.embeddings.SagemakerEndpointEmbeddings` and `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`, respectively. This requires a small override of the `SagemakerEndpointEmbeddings` class to make it compatible with the SageMaker embedding model.
2. Prepare the dataset to build the knowledge base.

---

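First, we wrap the SageMaker endpoint for the embedding model into `langchain.embeddings.SagemakerEndpointEmbeddings`. The subclass below overrides `embed_documents` so that texts are sent to the endpoint in small batches (`chunk_size` texts per request), and a custom content handler converts between LangChain's inputs and the JumpStart embedding payload format.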

In [None]:
import json
from typing import Dict, List, Optional

from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler


class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: How many input texts are grouped together into a
                single request to the endpoint.

        Returns:
            List of embeddings, one for each text.
        """
        if not texts:
            return []
        results = []
        _chunk_size = min(chunk_size, len(texts))

        for i in range(0, len(texts), _chunk_size):
            # one endpoint call per batch of up to `_chunk_size` texts
            response = self._embedding_func(texts[i : i + _chunk_size])
            results.extend(response)
        return results


class JumpStartEmbeddingsContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Optional[Dict] = None) -> bytes:
        # JumpStart embedding models expect {"text_inputs": [...]}
        input_str = json.dumps({"text_inputs": inputs, **(model_kwargs or {})})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        # `output` is the streaming response body; the vectors are under "embedding"
        response_json = json.loads(output.read().decode("utf-8"))
        embeddings = response_json["embedding"]
        return embeddings


embeddings_content_handler = JumpStartEmbeddingsContentHandler()

embeddings = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name=endpoint_name_embeddings,
    region_name=aws_region,
    content_handler=embeddings_content_handler,
)
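To check the request and response format without calling the endpoint, the two content-handler methods can be exercised as plain functions. In this sketch `io.BytesIO` stands in for the streaming response body, and the embedding numbers are made up. As far as we can tell, `SagemakerEndpointEmbeddings` passes one batch of texts (a list of strings) to `transform_input` per request, which is why the request body carries a list under `text_inputs`.

```python
import io
import json
from typing import IO, Dict, List, Optional


def transform_input(inputs: List[str], model_kwargs: Optional[Dict] = None) -> bytes:
    # Request body in the JumpStart embedding format: {"text_inputs": [...]}
    return json.dumps({"text_inputs": inputs, **(model_kwargs or {})}).encode("utf-8")


def transform_output(output: IO[bytes]) -> List[List[float]]:
    # Read the response stream, decode the JSON, return the vectors under "embedding"
    response_json = json.loads(output.read().decode("utf-8"))
    return response_json["embedding"]


request_body = transform_input(["What is SageMaker?", "What is SageMaker Neo?"])
print(request_body.decode("utf-8"))

# io.BytesIO stands in for response["Body"]; the numbers are made up.
fake_body = io.BytesIO(json.dumps({"embedding": [[0.1, 0.2], [0.3, 0.4]]}).encode("utf-8"))
vectors = transform_output(fake_body)
print(len(vectors), "vectors of length", len(vectors[0]))
```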

Next, we wrap the SageMaker LLM endpoint into `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`.

In [None]:
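import json
from typing import Dict, Optional

from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint

# With do_sample=False the model decodes greedily, so top_k, top_p, and
# temperature have no effect on the output.
parameters = {
    "max_length": 200,
    "num_return_sequences": 1,
    "top_k": 250,
    "top_p": 0.95,
    "do_sample": False,
    "temperature": 1,
}


class JumpStartLLMContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Optional[Dict] = None) -> bytes:
        # Nest the generation settings under "parameters", the same payload
        # format used for the direct endpoint calls in Step 2.
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs or {}})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # The endpoint returns a list of {"generated_text": ...} dicts.
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


llm_content_handler = JumpStartLLMContentHandler()

sm_llm = SagemakerEndpoint(
    endpoint_name=endpoint_name_llm,
    region_name=aws_region,
    model_kwargs=parameters,
    content_handler=llm_content_handler,
)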

Now, let's download the example data and prepare it for the demonstration. We use the [Amazon SageMaker FAQs](https://aws.amazon.com/sagemaker/faqs/) as the knowledge library. The data are formatted in a CSV file with two columns, Question and Answer. We use the Answer column as the documents of the knowledge library, from which relevant documents are retrieved based on a query.

**For your own purposes, you can replace the example dataset with your own to build a custom question answering application.**

### Prepare a text dataset to be used as context for RAG

In [None]:
original_data = "s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/"

!mkdir -p rag_data
!aws s3 cp --recursive $original_data rag_data

If your data is split across multiple files, the following code reads all files that end with `.csv` and concatenates them. Make sure each `csv` file has the same format.

In [None]:
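import glob
import os
import pandas as pd

processed_file = os.path.join("rag_data", "processed_data.csv")

# read every raw CSV file, skipping the processed output so the cell can be re-run safely
all_files = [
    f for f in glob.glob(os.path.join("rag_data", "*.csv"))
    if os.path.abspath(f) != os.path.abspath(processed_file)
]

df_knowledge = pd.concat(
    (pd.read_csv(f, header=None, names=["Question", "Answer"]) for f in all_files),
    axis=0,
    ignore_index=True,
)

df_knowledge.drop(["Question"], axis=1, inplace=True)  # Drop the Question column as it is not used in this demonstration.
# Keep the "Answer" header row: LangChain's CSVLoader reads the first row as column names.
df_knowledge.to_csv(processed_file, header=True, index=False)

df_knowledge.head(10)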

### Set up RAG with LangChain

In [None]:
from langchain.document_loaders import TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import FAISS  # in-memory vector store
from langchain.text_splitter import CharacterTextSplitter
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders.csv_loader import CSVLoader

Use LangChain to read the `csv` data. LangChain provides multiple built-in document loaders for different file formats, such as `txt`, `html`, and `pdf`. For details, see [LangChain document loaders](https://python.langchain.com/en/latest/modules/indexes/document_loaders.html).

In [None]:
loader = CSVLoader(file_path="rag_data/processed_data.csv")  # define loader object for the csv file
documents = loader.load()

# create an in-memory vector store
index_creator = VectorstoreIndexCreator(
    vectorstore_cls=FAISS,
    embedding=embeddings,
    text_splitter=CharacterTextSplitter(
        chunk_size=300, 
        chunk_overlap=0
    ),
)

# index all the documents
index = index_creator.from_loaders([loader])  
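Each row of `processed_data.csv` becomes one LangChain document. The sketch below approximates how `CSVLoader` builds the text of each document, assuming it reads rows with Python's `csv.DictReader` and writes one `column: value` line per column; `rows_to_page_contents` is a hypothetical helper, not LangChain code, and the sample answers are made up. It also shows why it matters whether the file has a header row: without one, `DictReader` treats the first data row as the column names.

```python
import csv
import io
from typing import List

# Two made-up rows in the same single-column layout as processed_data.csv,
# this time written WITH an "Answer" header row.
sample_csv = "Answer\nSageMaker is a managed ML service.\nNeo optimizes models.\n"


def rows_to_page_contents(csv_text: str) -> List[str]:
    """Approximate CSVLoader: one document per row, 'column: value' lines."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return [
        "\n".join(f"{key.strip()}: {value.strip()}" for key, value in row.items())
        for row in reader
    ]


for content in rows_to_page_contents(sample_csv):
    print(content)

# Without a header row, the first answer is consumed as the column name.
print(rows_to_page_contents("First answer.\nSecond answer.\n"))
```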

In [None]:
# test the RAG chain

question = "What is SageMaker?"

rag_output = index.query(
    question=question, 
    llm=sm_llm
)

print('Question:',question, '\nRAG answer:', rag_output)

### Compare RAG vs. the original LLM

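Uncomment the following lines one at a time and try different questions to see the differences between RAG and the "out-of-the-box" LLM. Try each question multiple times: the out-of-the-box output varies from run to run because the direct call samples (`do_sample=True`, `temperature=0.5`), whereas the RAG output is more stable because `sm_llm` is configured with `do_sample=False` (greedy decoding).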

In [None]:
question = "What is SageMaker?"
# question = "What algorithms does SageMaker Autopilot support?"
# question = "What is the cost for SageMaker Kubeflow Pipelines?"
# question = "What is SageMaker Neo?"
# question = "Which instances can I use with Managed Spot Training in SageMaker?"
# question = "What is SageMaker Ground Truth?"
# question = "What is SageMaker Geospatial?"
# question = "Who is the president of Greece?"


rag_output = index.query(
    question=question, 
    llm=sm_llm
)

payload = {
    "inputs": question,
    "parameters": {
        "max_length": 100,
        "num_return_sequences": 1,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.5,
        "do_sample": True,
    }
}

llm_output = query_endpoint_with_json_payload(
    payload=payload, 
    endpoint_name=endpoint_name_llm,
    model_type='llm'
)

print('Question:',question, '\nLLM (out-of-the-box) answer:',llm_output, '\nRAG answer:', rag_output)
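The run-to-run variation comes from the decoding settings: the direct call above samples (`do_sample=True`, `temperature=0.5`), while `sm_llm` is configured with `do_sample=False`, i.e. greedy decoding. The toy sketch below contrasts the two strategies on made-up next-token scores. It illustrates the idea only and is not the model server's actual decoding code.

```python
import math
import random
from typing import Dict, List


def softmax_with_temperature(scores: Dict[str, float], temperature: float) -> Dict[str, float]:
    # Lower temperature sharpens the distribution; higher flattens it.
    scaled = {tok: s / temperature for tok, s in scores.items()}
    max_scaled = max(scaled.values())  # subtract the max for numerical stability
    exps = {tok: math.exp(s - max_scaled) for tok, s in scaled.items()}
    total = sum(exps.values())
    return {tok: e / total for tok, e in exps.items()}


def greedy_pick(scores: Dict[str, float]) -> str:
    # do_sample=False: always take the highest-scoring token
    return max(scores, key=scores.get)


def sample_pick(scores: Dict[str, float], temperature: float, rng: random.Random) -> str:
    # do_sample=True: draw a token at random, weighted by its probability
    probs = softmax_with_temperature(scores, temperature)
    tokens: List[str] = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]


# Made-up next-token scores.
toy_scores = {"all": 2.0, "GPU": 1.5, "CPU": 1.0}
rng = random.Random(0)
print("greedy :", [greedy_pick(toy_scores) for _ in range(5)])
print("sampled:", [sample_pick(toy_scores, temperature=0.5, rng=rng) for _ in range(5)])
```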

---
## Step 5. Customizing the RAG approach

We have now seen how simple it is to build a question answering application with LangChain in just a few lines of code. Let's break down the `VectorstoreIndexCreator` above to see what happens under the hood. We will also see how to use a customized prompt instead of the default prompt used by `VectorstoreIndexCreator`.

We generate embeddings for each document in the knowledge library with the SageMaker GPT-J-6B FP16 embedding model. We then identify the top K most relevant documents for the question above, with K = 3 in this setup, and print them out below.

In [None]:
document_embeddings = FAISS.from_documents(documents, embeddings)  # generate embeddings for all the documents

In [None]:
print('Question:', question, '\n\n')

retrieved_documents = document_embeddings.similarity_search(question, k=3)  # the top 3 most similar documents to the given question

for i,document in enumerate(retrieved_documents):
    print('Document', i, '----------')
    print(document)
    print('   ')

Finally, we **combine the retrieved documents with prompt and question and send them into SageMaker LLM.** 

We define a customized prompt as below.

In [None]:
# define a custom prompt template

prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""

PROMPT = PromptTemplate(
    template=prompt_template, 
    input_variables=["context", "question"]
)

In [None]:
# form a chain

chain = load_qa_chain(
    llm=sm_llm, 
    prompt=PROMPT, 
    verbose=True,  # in order to show what happens under the hood
    chain_type="stuff"
)
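With `chain_type="stuff"`, all retrieved documents are "stuffed" into the single `{context}` slot of the prompt. The sketch below approximates the resulting prompt; `stuff_prompt` is a hypothetical helper, and the blank-line separator between documents is our assumption about LangChain's default. With `verbose=True`, as set above, the chain prints the exact prompt it sends.

```python
from typing import List

STUFF_TEMPLATE = """Answer based on context:\n\n{context}\n\n{question}"""


def stuff_prompt(page_contents: List[str], question: str, separator: str = "\n\n") -> str:
    """Approximate the "stuff" strategy: join all documents into one context block."""
    context = separator.join(page_contents)
    return STUFF_TEMPLATE.format(context=context, question=question)


print(stuff_prompt(["Answer: doc one.", "Answer: doc two."], "What is SageMaker?"))
```

For the real chain, the input would be `[d.page_content for d in retrieved_documents]`. Because every document lands in one prompt, the "stuff" strategy only works while the combined text stays within the model's input limit.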

In [None]:
# get an answer from the prompt template

rag_output = chain(
    {"input_documents": retrieved_documents, "question": question}, 
    return_only_outputs=True
)["output_text"]

print('\n\nAnswer:', rag_output)

---
## Clean up the environment
Delete the endpoints when you no longer need them, because they incur charges for every hour they are running.

In [None]:
# Create a low-level SageMaker service client.
sagemaker_client = boto3.client('sagemaker', region_name=aws_region)

# Delete endpoint
print('Deleting endpoint:', endpoint_name_llm)
sagemaker_client.delete_endpoint(EndpointName=endpoint_name_llm)
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name_llm)

# Delete endpoint
print('Deleting endpoint:', endpoint_name_embeddings)
sagemaker_client.delete_endpoint(EndpointName=endpoint_name_embeddings)
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name_embeddings)