# Retrieval-Augmented Generation: Question Answering based on Custom Dataset with Open-sourced [LangChain](https://python.langchain.com/en/latest/index.html) Library


# 1. 기본 환경 설정

In [2]:
%load_ext autoreload
%autoreload 2

# Add the sibling ../common_code directory to the import path
# (presumably where inference_lib lives — TODO confirm).
import sys
sys.path.append('../common_code')

In [3]:
import time
import sagemaker, boto3, json
from sagemaker.session import Session
from sagemaker.model import Model
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


# SageMaker session used to look up the ARN of the calling identity.
sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
# Region of the default boto3 session; passed to the endpoint wrappers further down.
aws_region = boto3.Session().region_name
# NOTE(review): a second, separate Session object; not referenced in the visible cells.
sess = sagemaker.Session()
# NOTE(review): not referenced in the visible cells — confirm it is still needed.
model_version = "*"

In [4]:

# def parse_response_model_KoAlpaca(query_response):
#     model_predictions = json.loads(query_response["Body"].read())
#     print("model_predictions \n", model_predictions)
#     generated_text = model_predictions["generated_texts"]
#     return generated_text


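Deploy SageMaker endpoint(s) for the large language model and the embedding model. The configuration below references a KoAlpaca-12.8B LLM endpoint and a KoSimCSE-roberta embedding endpoint. Please add (or uncomment) entries below if you want to deploy multiple LLMs and compare their performance.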

In [5]:
# Per-model settings, keyed by model id. The visible cells only read
# "endpoint_name"; the other fields are not consumed in this notebook.
_MODEL_CONFIG_ = {
    # Text-generation LLM.
    "KoAlpaca-12-8B": {
        "instance type": "ml.g5.12xlarge",
        "endpoint_name" : "KoAlpaca-12-8B-2023-05-27-06-52-39",
        "env": {"TS_DEFAULT_WORKERS_PER_MODEL": "1"},
        # NOTE(review): the only visible definition of this function is commented out.
        "parse_function": "parse_response_model_KoAlpaca",
        # Template with {context} and {question} placeholders.
        "prompt": """Answer based on context:\n\n{context}\n\n{question}""",
    },
    # Embedding model used for the vector index.
    "KoSimCSE-roberta": {
        "instance type": "ml.g5.12xlarge",
        "endpoint_name" : "KoSimCSE-roberta-2023-05-28-02-08-55",        
        "env": {"TS_DEFAULT_WORKERS_PER_MODEL": "1"},
    },
}

# embedding inference

In [6]:
# Sample request body: a single Korean sentence under the "inputs" key.
payload_0 = dict(inputs="이번 주 일요일에 분당 이마트 점은 문을 여나요")


## Step 2. Ask a question to LLM without providing the context

To better illustrate why we need a retrieval-augmented generation (RAG) based approach to solve the question answering problem, let's directly ask the model a question and see how it responds.

In [7]:
# Build the prompt for the "no context" experiment.
# Alternative English sample question:
#   "Which instances can I use with Managed Spot Training in SageMaker?"
q = "홈플러스 중계점은 몇시까지 장사해?"
c = None

# The context section is only included when a context is supplied.
# (The section labels could equally be written in Korean: 질문 / 맥락 / 답변.)
if c:
    prompt_wo_c = f"### question: {q}\n\n### context: {c}\n\n### answer:"
else:
    prompt_wo_c = f"### question: {q}\n\n### answer:"

print("prompt_wo_c: \n", prompt_wo_c)

prompt_wo_c: 
 ### question: 홈플러스 중계점은 몇시까지 장사해?

### answer:


In [8]:

from inference_lib import invoke_inference, query_endpoint_with_text_payload
from inference_lib import parse_response_text_model

# Look up the endpoint configured for the chosen LLM.
model_id = "KoAlpaca-12-8B"
endpoint_name = _MODEL_CONFIG_[model_id]["endpoint_name"]

# Send the context-free prompt to the endpoint and parse the raw response in
# one expression; only the parsed result is kept in `query_response`.
query_response = parse_response_text_model(
    query_endpoint_with_text_payload(prompt_wo_c, endpoint_name=endpoint_name)
)
print(query_response)

ImportError: cannot import name 'EmbeddingsContentHandler' from 'langchain.embeddings.sagemaker_endpoint' (/opt/conda/lib/python3.9/site-packages/langchain/embeddings/sagemaker_endpoint.py)

You can see the generated answer is wrong or doesn't make much sense. 

# 데이터 준비

In [9]:
# S3 prefix that holds the FAQ dataset.
original_data = "s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/"

# Create the local data directory, then copy everything under the prefix into it.
!mkdir -p rag_data
!aws s3 cp --recursive $original_data rag_data

download: s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv to rag_data/Amazon_SageMaker_FAQs.csv


In [10]:
import glob
import os
import pandas as pd

# Pick the FAQ CSV file to load. Only the English sample is active; the other
# candidates are rag_data/Amazon_SageMaker_FAQs.csv and
# rag_data/Korean_Sample_FAQ.csv.
# NOTE(review): the logged download output lists only Amazon_SageMaker_FAQs.csv —
# confirm where English_Sample_FAQ.csv comes from.
all_files = glob.glob(os.path.join("rag_data/", "English_Sample_FAQ.csv"))

# Read every matched file as a headerless two-column table and stack the rows
# into one frame with a fresh 0..n-1 index.
df_knowledge = pd.concat(
    [
        pd.read_csv(csv_path, header=None, names=["Question", "Answer"])
        for csv_path in all_files
    ],
    ignore_index=True,
    axis=0,
)

In [11]:
# df_knowledge.drop(["Question"], axis=1, inplace=True)

In [12]:
# Write the knowledge table to disk without a header row or index column.
df_knowledge.to_csv("rag_data/processed_data.csv", header=False, index=False)

In [13]:
# Keep only the first `s_num` rows as a small sample and save it in the same
# headerless, index-free CSV layout as the full table.
s_num = 6
simple_knowledge = df_knowledge.head(s_num)
simple_knowledge.to_csv(
    "rag_data/simple-processed_data.csv",
    header=False,
    index=False,
)

## Sample File
- LangChain CSV Loader Code
    - https://github.com/hwchase17/langchain/blob/master/langchain/document_loaders/csv_loader.py

In [14]:
from langchain.document_loaders.csv_loader import CSVLoader

# Load the small sample CSV; the logged output shows one Document per CSV row.
loader = CSVLoader(file_path="rag_data/simple-processed_data.csv", encoding="utf-8")
documents = loader.load()
# Echo the loaded documents as the cell output.
documents






[Document(page_content='FAQDoc: Is the Bundang E-Mart store open this Sunday?\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 0}),
 Document(page_content='FAQDoc: Is Bundang E-Mart open on Sunday?\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 1}),
 Document(page_content='FAQDoc: What time is the Bundang E-Mart store open on Saturdays?\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 2}),
 Document(page_content='FAQDoc: Hi welcome to E-Mart Bundang store\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 3}),
 Document(page_content='FAQDoc: Parking lot is next to the main building\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 4})]

In [15]:

# text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0)
# texts = text_splitter.split_documents(documents) ### if you use langchain.document_loaders.TextLoader to load text file. You can uncomment the code
## to split the text.

# SageMaker Embedding Model Wrapper

In [2]:
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

TypeError: issubclass() arg 1 must be a class

In [3]:
! pip list | grep langchain

langchain                  0.0.148


In [16]:
from inference_lib import SagemakerEndpointEmbeddingsJumpStart
from inference_lib import EmbeddingContentHandler



ImportError: cannot import name 'EmbeddingsContentHandler' from 'langchain.embeddings.sagemaker_endpoint' (/opt/conda/lib/python3.9/site-packages/langchain/embeddings/sagemaker_endpoint.py)

In [None]:

# Instantiate the content handler under its own name. Rebinding the imported
# class name `EmbeddingContentHandler` to an instance would make this cell fail
# when re-run, because the second run would try to call the instance.
embedding_content_handler = EmbeddingContentHandler()

# Embedding client bound to the KoSimCSE-roberta endpoint from the model config.
embeddings = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name=_MODEL_CONFIG_["KoSimCSE-roberta"]["endpoint_name"],
    region_name=aws_region,
    content_handler=embedding_content_handler,
)

Next, we wrap up our SageMaker endpoints for LLM into `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`. 

# SageMaker LLM Wrapper

In [None]:
# Content handler instance that is passed to the LLM endpoint wrapper below.
from inference_lib import KorLLMContentHandler
LLM_contentHandler = KorLLMContentHandler()

In [None]:
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint

In [None]:


# No extra model arguments are sent with the request.
parameters = {}

# LLM client bound to the KoAlpaca endpoint from the model config.

sm_llm = SagemakerEndpoint(
    endpoint_name=_MODEL_CONFIG_["KoAlpaca-12-8B"]["endpoint_name"],
    region_name=aws_region,
    model_kwargs=parameters,
    content_handler=LLM_contentHandler,
)

In [None]:
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.document_loaders import TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import Chroma, AtlasDB, FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain


**Now, we can build a QA application. <span style="color:red">LangChain makes it extremely simple with the following few lines of code</span>.**

In [None]:
# Query string reused by the retrieval and QA cells below.
question = "hi mart"

In [None]:
# Index factory: FAISS as the vector store, the SageMaker embedding client for
# vectors, and a character splitter with no overlap between chunks.
index_creator = VectorstoreIndexCreator(
    vectorstore_cls=FAISS,
    embedding=embeddings,
    # Alternative kept for reference: a smaller chunk size of 300.
#    text_splitter=CharacterTextSplitter(chunk_size=300, chunk_overlap=0),
    # NOTE(review): presumably 10000 keeps each FAQ row in a single chunk — confirm.
    text_splitter=CharacterTextSplitter(chunk_size=10000, chunk_overlap=0),    
)

In [None]:
# a = [[[[-0.043768156319856644, -0.2643952965736389]]]]
# import numpy as np

# np.array(a).shape

In [None]:
# Build the vector-store index from the documents produced by the CSV loader.
index = index_creator.from_loaders([loader])

In [None]:
# Display the vector store's index_to_docstore_id attribute for inspection.
index.vectorstore.index_to_docstore_id

In [None]:
# Quick retrieval check with a short query string.
index.vectorstore.similarity_search("hi")

In [None]:
# Ask the question through the index, using the SageMaker LLM client.
index.query(question=question, llm=sm_llm)

## Step 5. Customize the QA application above with different prompt.

Now we have seen how simple it is to use LangChain to build a question answering application with just a few lines of code. Let's break down the above `VectorstoreIndexCreator` and see what's happening under the hood. Furthermore, we will see how to incorporate a customized prompt rather than using the default prompt with `VectorstoreIndexCreator`.

Firstly, we **generate embeddings for each document in the knowledge library with the SageMaker embedding model (the KoSimCSE-roberta endpoint configured above).**

In [None]:
# Build a FAISS store directly from the loaded documents and the embedding client.
docsearch = FAISS.from_documents(documents, embeddings)

In [None]:
question

Based on the question above, we then **identify top K most relevant documents based on user query, where K = 3 in this setup**.

In [None]:
# Retrieve the k=3 documents most similar to the question.
docs = docsearch.similarity_search(question, k=3)

Print out the top 3 most relevant documents as below.

In [None]:
docs

Finally, we **combine the retrieved documents with prompt and question and send them into SageMaker LLM.** 

We define a customized prompt as below.

In [None]:
# Custom prompt: an instruction line, then the retrieved context, then the
# question, each separated by a blank line.
prompt_template = "Answer based on context:\n\n{context}\n\n{question}"

# Template object exposing the two placeholders used above.
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

In [None]:
PROMPT

In [None]:
# Question-answering chain that calls the SageMaker LLM with the custom prompt.
chain = load_qa_chain(llm=sm_llm, prompt=PROMPT)

Send the top 3 most relevant documents and the question into the LLM to get an answer.

In [None]:
# Run the chain on the retrieved documents plus the question and keep only the
# generated answer text from the returned mapping.
result = chain(
    {"input_documents": docs, "question": question},
    return_only_outputs=True,
)["output_text"]

Print the final answer from LLM as below, which is accurate.

In [None]:
result

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answerIng_retrieval_augmented_generation_jumpstart|question_answerIng_langchain_jumpstart.ipynb)
