# Retrieval-Augmented Generation: Question Answering based on Custom Dataset with Open-sourced [LangChain](https://python.langchain.com/en/latest/index.html) Library
- Original code
    - https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/question_answering_retrieval_augmented_generation/question_answering_langchain_jumpstart.ipynb

# 1. Basic environment setup

In [2]:
%load_ext autoreload
%autoreload 2

# Add the shared code folder to the module search path
import sys
sys.path.append('../common_code')

In [3]:
import time
import sagemaker, boto3, json
from sagemaker.session import Session
from sagemaker.model import Model
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()
model_version = "*"

## Model configuration
- Enter the SageMaker endpoint names and related settings

In [4]:
_MODEL_CONFIG_ = {
    "KoAlpaca-12-8B": {
        "instance type": "ml.g5.12xlarge",
        "endpoint_name" : "KoAlpaca-12-8B-2023-05-27-06-52-39",
        "env": {"TS_DEFAULT_WORKERS_PER_MODEL": "1"},
        "parse_function": "parse_response_model_KoAlpaca",
        "prompt": """Answer based on context:\n\n{context}\n\n{question}""",
    },
    "KoSimCSE-roberta": {
        "instance type": "ml.g5.12xlarge",
        "endpoint_name" : "KoSimCSE-roberta-2023-05-28-02-08-55",        
        "env": {"TS_DEFAULT_WORKERS_PER_MODEL": "1"},
    },
}

# 2. Test LLM inference without context

In [5]:
# question = "Which instances can I use with Managed Spot Training in SageMaker?"
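q = "홈플러스 중계점은 몇시까지 장사해?"  # "How late is the Homeplus Junggye store open?"
c = None
# Korean-keyword variant of the prompt (질문 = question, 맥락 = context, 답변 = answer):
# prompt_wo_c = f"### 질문: {q}\n\n### 맥락: {c}\n\n### 답변:" if c else f"### 질문: {q}\n\n### 답변:"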
prompt_wo_c = f"### question: {q}\n\n### context: {c}\n\n### answer:" if c else f"### question: {q}\n\n### answer:" 
print("prompt_wo_c: \n", prompt_wo_c)

prompt_wo_c: 
 ### question: 홈플러스 중계점은 몇시까지 장사해?

### answer:
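The conditional f-string above switches between two prompt layouts depending on whether a context is supplied. A minimal sketch of the same logic as a reusable function; `build_prompt` is a hypothetical helper name and is not part of `inference_lib`.

```python
from typing import Optional


def build_prompt(question: str, context: Optional[str] = None) -> str:
    """Build the '### question / ### context / ### answer' prompt used in this notebook.

    Hypothetical helper that mirrors the inline f-string above.
    """
    if context:
        # With context: question, then context, then an open answer slot
        return f"### question: {question}\n\n### context: {context}\n\n### answer:"
    # Without context: the model has to answer from its own parameters
    return f"### question: {question}\n\n### answer:"


print(build_prompt("What time does the store open?"))
print(build_prompt("What time does the store open?", "The store opens at 10 AM."))
```

Like the original expression, the check uses truthiness, so both `None` and an empty string produce the context-free prompt.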


In [6]:

from inference_lib import invoke_inference, query_endpoint_with_text_payload
from inference_lib import parse_response_text_model

model_id = "KoAlpaca-12-8B"
endpoint_name = _MODEL_CONFIG_[model_id]["endpoint_name"]

query_response = query_endpoint_with_text_payload(
    prompt_wo_c, endpoint_name=endpoint_name, 
)

query_response = parse_response_text_model(query_response)
print(query_response)

### question: 홈플러스 중계점은 몇시까지 장사해?

### answer: 홈플러스 중계점은 자정까지 영업합니다.


Without any context, the model confidently claims the store is open until midnight, but it has no source for this, so the answer is unreliable and may well be wrong.

# 3. Data preparation

In [7]:
# original_data = "s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/"

# !mkdir -p rag_data
# !aws s3 cp --recursive $original_data rag_data

In [8]:
import glob
import os
import pandas as pd

# all_files = glob.glob(os.path.join("rag_data/", "Amazon_SageMaker_FAQs.csv"))
# all_files = glob.glob(os.path.join("rag_data/", "Korean_Sample_FAQ.csv"))
# all_files = glob.glob(os.path.join("rag_data/", "English_Sample_FAQ.csv"))



# df_knowledge = pd.concat(
#     (pd.read_csv(f, header=None, names=["Question", "Answer"]) for f in all_files),
#     axis=0,
#     ignore_index=True,
# )

In [9]:
# df_knowledge.drop(["Question"], axis=1, inplace=True)

In [10]:
# df_knowledge.to_csv("rag_data/processed_data.csv", header=False, index=False)

In [11]:
# s_num = 6
# simple_knowledge = df_knowledge[0:s_num]
# simple_knowledge.to_csv("rag_data/simple-processed_data.csv", header=False, index=False)

## Sample File
- LangChain CSV loader source code
    - https://github.com/hwchase17/langchain/blob/master/langchain/document_loaders/csv_loader.py

In [12]:
from langchain.document_loaders.csv_loader import CSVLoader

loader = CSVLoader(file_path="rag_data/simple-processed_data.csv", encoding="utf-8")
documents = loader.load()
documents


[Document(page_content='FAQDoc: Is the Bundang E-Mart store open this Sunday?\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 0}),
 Document(page_content='FAQDoc: Is Bundang E-Mart open on Sunday?\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 1}),
 Document(page_content='FAQDoc: What time is the Bundang E-Mart store open on Saturdays?\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 2}),
 Document(page_content='FAQDoc: Hi welcome to E-Mart Bundang store\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 3}),
 Document(page_content='FAQDoc: Parking lot is next to the main building\n: ', metadata={'source': 'rag_data/simple-processed_data.csv', 'row': 4})]
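Per the linked `csv_loader.py` source, `CSVLoader` reads the file with Python's `csv.DictReader` and turns each row into `column: value` lines. The sketch below is a simplified re-implementation (not the LangChain class itself, and without the `source` metadata) that reproduces the `page_content` format shown above. It also makes visible that `csv.DictReader` uses the first row as column names, so if a file is written with `header=False`, its first data row becomes the header rather than a document.

```python
import csv
import io
from typing import Dict, List


def rows_to_page_contents(csv_text: str) -> List[Dict[str, object]]:
    """Simplified sketch of how CSVLoader turns CSV rows into documents.

    csv.DictReader treats the first row as the header; every later row becomes
    one document whose text is "column: value" lines joined by newlines.
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    docs = []
    for i, row in enumerate(reader):
        # Missing values come back as None, so fall back to an empty string
        content = "\n".join(f"{k.strip()}: {(v or '').strip()}" for k, v in row.items())
        docs.append({"page_content": content, "metadata": {"row": i}})
    return docs


sample = "FAQDoc,\nIs Bundang E-Mart open on Sunday?,\nParking lot is next to the main building,\n"
for doc in rows_to_page_contents(sample):
    print(repr(doc["page_content"]), doc["metadata"])
```

The trailing comma on each line creates a second, unnamed column, which is where the trailing `\n: ` in each loaded document comes from.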

In [13]:

# text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0)
# texts = text_splitter.split_documents(documents)
# Uncomment the two lines above to split the text if you load a plain text file
# with langchain.document_loaders.TextLoader instead of the CSV loader.
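Splitting only matters for long texts. The idea is to cut text into chunks of at most `chunk_size` characters, with neighbouring chunks sharing `chunk_overlap` characters. LangChain's `CharacterTextSplitter` actually splits on a separator (by default `"\n\n"`) and then merges the pieces up to `chunk_size`, so the fixed-window function below is a simplified illustration of the two parameters, not its algorithm; `fixed_size_chunks` is a hypothetical name.

```python
from typing import List


def fixed_size_chunks(text: str, chunk_size: int, chunk_overlap: int = 0) -> List[str]:
    """Cut text into windows of at most chunk_size characters.

    Neighbouring windows share chunk_overlap characters. Simplified
    illustration only; not LangChain's CharacterTextSplitter algorithm.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size")
    if not text:
        return []
    step = chunk_size - chunk_overlap
    # Windows starting at or after len(text) - chunk_overlap would only repeat
    # the tail already covered by the previous window, so stop before them
    last_start = max(len(text) - chunk_overlap, 1)
    return [text[i:i + chunk_size] for i in range(0, last_start, step)]


print(fixed_size_chunks("abcdefghij", chunk_size=4, chunk_overlap=1))  # ['abcd', 'defg', 'ghij']
```

With a very large `chunk_size` such as the `10000` used for the vector store below, every short FAQ row stays a single chunk.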

Next, we wrap the SageMaker endpoints as LangChain objects, starting with the LLM endpoint in `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`.

# 4. Prepare SageMaker endpoint wrappers

## SageMaker LLM Wrapper

In [14]:
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint

In [15]:
from inference_lib import KoAlpacaContentHandler
_KoAlpacaContentHandler = KoAlpacaContentHandler()

In [16]:
parameters = {}

sm_llm = SagemakerEndpoint(
    endpoint_name=_MODEL_CONFIG_["KoAlpaca-12-8B"]["endpoint_name"],
    region_name=aws_region,
    model_kwargs=parameters,
    content_handler=_KoAlpacaContentHandler,
)
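`SagemakerEndpoint` delegates request and response formatting to its content handler: `transform_input(prompt, model_kwargs)` returns the request body as bytes, and `transform_output` receives the response body stream and returns the generated text. The real `KoAlpacaContentHandler` lives in `inference_lib` and is not shown here, so the class below is a hypothetical stand-in. The JSON keys `text_inputs` and `generated_text` are assumptions about the endpoint's schema, and the class deliberately does not subclass LangChain's `LLMContentHandler` so that it runs without LangChain or AWS.

```python
import io
import json
from typing import Any, Dict


class SketchLLMContentHandler:
    """Hypothetical stand-in for inference_lib.KoAlpacaContentHandler.

    In the notebook, a real handler would subclass
    langchain.llms.sagemaker_endpoint.LLMContentHandler. The JSON keys
    "text_inputs" and "generated_text" are assumptions.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict[str, Any]) -> bytes:
        # Serialize the prompt plus any generation parameters as the request body
        payload = {"text_inputs": prompt, **model_kwargs}
        return json.dumps(payload, ensure_ascii=False).encode("utf-8")

    def transform_output(self, output: Any) -> str:
        # `output` is a readable response body; decode the JSON it contains
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


handler = SketchLLMContentHandler()
body = handler.transform_input("### question: Q\n\n### answer:", {"max_new_tokens": 64})
fake_response = io.BytesIO(json.dumps([{"generated_text": "10 AM"}]).encode("utf-8"))
print(body)
print(handler.transform_output(fake_response))
```

With `parameters = {}` as above, this sketch would send a body containing only the prompt; any generation settings added to `parameters` are merged into the same JSON object.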

## SageMaker Embedding Model Wrapper

In [17]:
from inference_lib import SagemakerEndpointEmbeddingsJumpStart
from inference_lib import KoSimCSERobertaContentHandler

In [18]:

_KoSimCSERobertaContentHandler = KoSimCSERobertaContentHandler()

# content_handler = ContentHandler()

embeddings = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name=_MODEL_CONFIG_["KoSimCSE-roberta"]["endpoint_name"],
    region_name=aws_region,
    content_handler=_KoSimCSERobertaContentHandler,
)
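The embedding wrapper follows the same pattern, except that the handler takes a list of texts and must return one vector per text. The actual `KoSimCSERobertaContentHandler` is defined in `inference_lib`; the class below is a hypothetical sketch in which the request key `inputs` and the bare-JSON-list response are assumptions. It shows one detail such a handler typically needs: normalizing a single flat vector into a list of vectors.

```python
import io
import json
from typing import Any, Dict, List


class SketchEmbeddingsContentHandler:
    """Hypothetical stand-in for inference_lib.KoSimCSERobertaContentHandler.

    Texts go in, one embedding vector per text comes out. The request key
    "inputs" and the response layout (a bare JSON list) are assumptions.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, texts: List[str], model_kwargs: Dict[str, Any]) -> bytes:
        return json.dumps({"inputs": texts, **model_kwargs}, ensure_ascii=False).encode("utf-8")

    def transform_output(self, output: Any) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        # A single text may come back as a flat vector [0.1, 0.2, ...];
        # wrap it so callers always receive a list of vectors
        if response_json and not isinstance(response_json[0], list):
            return [response_json]
        return response_json


emb_handler = SketchEmbeddingsContentHandler()
print(emb_handler.transform_output(io.BytesIO(b"[0.1, 0.2]")))
print(emb_handler.transform_output(io.BytesIO(b"[[0.1], [0.2]]")))
```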

**Now we can build a QA application. <span style="color:red">LangChain makes this extremely simple with just a few lines of code</span>.**

# 5. Create the vector store
- Create a FAISS vector store

In [19]:
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.document_loaders import TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import Chroma, AtlasDB, FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain


In [20]:
index_creator = VectorstoreIndexCreator(
    vectorstore_cls=FAISS,
    embedding=embeddings,
#    text_splitter=CharacterTextSplitter(chunk_size=300, chunk_overlap=0),
    text_splitter=CharacterTextSplitter(chunk_size=10000, chunk_overlap=0),    
)

In [21]:
index = index_creator.from_loaders([loader])



In [22]:
index.vectorstore.index_to_docstore_id

{0: 'd590d001-c122-46d5-b188-869f38fc3dcd',
 1: 'acfe291a-6fc5-4e02-a5e4-266522ada42e',
 2: '75aceacf-d757-4eca-b881-c14fdb6a7c29',
 3: 'acce2082-738c-405d-9e0e-5dee1a9f0b30',
 4: '418a4dc9-6f53-4f7c-b163-a638d05b7c3a'}
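`index_to_docstore_id` maps each row position in the FAISS index to the id of the stored `Document`. Conceptually, a similarity search embeds the query, finds the nearest stored vectors, and uses this mapping to look up the matching documents. The sketch below imitates that with a brute-force squared-L2 search in plain Python (to our knowledge, LangChain's FAISS wrapper builds an exact `IndexFlatL2` index by default, where a lower distance means more similar). The toy 2-D vectors, ids, and texts are made up for illustration.

```python
from typing import Dict, List, Sequence, Tuple


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def top_k(
    query: Sequence[float],
    vectors: List[Sequence[float]],
    index_to_docstore_id: Dict[int, str],
    docstore: Dict[str, str],
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Return the k stored texts closest to the query, with their distances."""
    # Score every stored vector; ties are broken by index position
    scored = sorted((squared_l2(query, v), i) for i, v in enumerate(vectors))
    # Map index positions -> docstore ids -> stored texts
    return [(docstore[index_to_docstore_id[i]], d) for d, i in scored[:k]]


vectors = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
index_to_docstore_id = {0: "id-a", 1: "id-b", 2: "id-c"}
docstore = {"id-a": "Open on Sunday", "id-b": "Saturday hours", "id-c": "Parking lot"}
print(top_k([1.0, 0.0], vectors, index_to_docstore_id, docstore, k=2))
```

A real FAISS index does the distance computation in optimized native code, but the mapping step from index position to document id is the same idea.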

In [23]:
question = 'What time does the store open?'

In [None]:
index.query(question=question, llm=sm_llm)

# 6. Test the QA application with a different prompt

In [None]:
docsearch = FAISS.from_documents(documents, embeddings)

In [None]:
question

For the question above, we then **retrieve the top K documents most relevant to the user query, with K = 3 in this setup**.

In [None]:
docs = docsearch.similarity_search(question, k=3)
docs

This prints the top 3 most relevant documents.

Finally, we **combine the retrieved documents with the prompt and the question, and send them to the SageMaker LLM.**

We define a customized prompt as below.

In [None]:
prompt_template = """Answer based on context:\n\n{context}\n\n{question}"""

PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
PROMPT

In [None]:
chain = load_qa_chain(llm=sm_llm, prompt=PROMPT)
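`load_qa_chain` uses the `"stuff"` chain type by default: it concatenates the `page_content` of all retrieved documents into the `{context}` slot of `PROMPT` and sends a single request to the LLM. A plain-Python sketch of that prompt assembly; `stuff_prompt` is a hypothetical helper, and the `"\n\n"` separator is, to our knowledge, the default `document_separator` of LangChain's stuff chain.

```python
from typing import List

PROMPT_TEMPLATE = """Answer based on context:\n\n{context}\n\n{question}"""


def stuff_prompt(page_contents: List[str], question: str, template: str = PROMPT_TEMPLATE) -> str:
    """Join all retrieved documents into one context block and fill the prompt."""
    context = "\n\n".join(page_contents)
    return template.format(context=context, question=question)


sample_docs = [
    "FAQDoc: Is Bundang E-Mart open on Sunday?\n: ",
    "FAQDoc: Parking lot is next to the main building\n: ",
]
print(stuff_prompt(sample_docs, "What time does the store open?"))
```

Because every document lands in one prompt, this approach only works while the combined context fits in the model's input length, which is not a concern for three short FAQ rows.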

Send the top 3 most relevant documents and the question to the LLM to get an answer.

In [None]:
result = chain({"input_documents": docs, "question": question}, return_only_outputs=True)[
    "output_text"
]
result

Print the final answer from the LLM, which should now be grounded in the retrieved documents.

In [None]:
r2 = chain({"input_documents": docs, "question": question}, return_only_outputs=True)[
    "output_text"
]
r2