# Build Document Index and Retrieval Augmented Generation (RAG) with LangChain

---

This workshop runs in a SageMaker notebook. Ensure the kernel is set to **conda_python3**.

### Scenarios

In post-call analysis, we want to give agents guidance on how to respond to customer requests or demands. To do this, we send the transcript to the model as a prompt, then search the documents in the vector store for the specific guidance the agent needs.

This notebook demonstrates how to:

1. Build an index in [Chroma](https://www.trychroma.com/) from a document in the local directory.

2. Implement Retrieval Augmented Generation (RAG) with [LangChain](https://python.langchain.com/docs/get_started/introduction.html) as the orchestration tool and [HuggingFaceEmbeddings](https://huggingface.co/sentence-transformers) as the embedding model.


### Contents

- [1. Environment Setup](#1.-Environment-Setup)
- [2. Query the Endpoint with Testing Prompts](#2.-Query-the-Endpoint-with-Testing-Prompts)
- [3. Retrieval Augmented Generation with LangChain and Chroma](#3.-Retrieval-Augmented-Generation-with-LangChain-and-Chroma)

**Note**

This notebook serves as a template: you can replace the example data with your own to build a custom question-answering application.

## 1. Environment Setup

In [None]:
!pip install langchain==0.0.275 --quiet
!pip install pypdf==3.15.4 --quiet
!pip install chromadb==0.4.7 --quiet
!pip install nltk --quiet
!pip install sentence_transformers --quiet

In [None]:
import json
import boto3
import sagemaker
from sagemaker.session import Session

# variables
sagemaker_session = Session()
region = boto3.session.Session().region_name

# boto3 clients
s3 = boto3.client('s3')
print(f"Region is {region}")

## 2. Query the Endpoint with Testing Prompts

To provide guidance to agents on handling customer requests and demands, we begin by asking the large language model (LLM) to summarize the intent from the transcript.

Note the endpoint name of the Flan-T5 XL model deployed in `1-model-selection-and-deployment.ipynb`; you will need it in the next cell.

In [None]:
# TODO: replace "None" with the endpoint name
ENDPOINT_NAME = None  # e.g. jp-huggingface-text2text-flan-t5-xl-XXX
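
If `ENDPOINT_NAME` is left as `None`, the error only surfaces later, when the endpoint is first invoked. A minimal sketch of an early check follows; `require_endpoint_name` is a hypothetical helper, not part of the workshop code.

```python
def require_endpoint_name(name):
    """Return the endpoint name, or raise if the placeholder was not replaced."""
    # Hypothetical helper: rejects None, non-strings, and blank strings.
    if not isinstance(name, str) or not name.strip():
        raise ValueError(
            "ENDPOINT_NAME is not set; copy it from 1-model-selection-and-deployment.ipynb"
        )
    return name.strip()
```

Calling `require_endpoint_name(ENDPOINT_NAME)` right after the assignment would stop the notebook at this cell instead of at the first request.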


In [None]:
def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type="application/json"):
    client = boto3.client("sagemaker-runtime")  # SageMaker runtime client used for inference
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json
    )
    return response

def parse_response_model_flan_t5(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    generated_text = model_predictions["generated_texts"]
    return generated_text

_MODEL_CONFIG_ = {
    "huggingface-text2text-flan-t5-xl": {
        "instance type": "ml.g5.2xlarge",
        "env": {},
        "parse_function": parse_response_model_flan_t5,
        "prompt": """Answer based on context:\n\n{context}\n\n{question}""",  
        "endpoint_name": ENDPOINT_NAME
    }
}

In [None]:
prompt_intent = "Please summarize the transcript above."

In [None]:
prompt_template_1 = """
Here is what the customer said in the call:
{transcript}
    
    
{objective}
"""

In [None]:
# sample transcript (the context manager closes the file after reading)
with open("transcripts/negative-refund.txt", "r") as f:
    transcript_neg_refund = f.read()

In [None]:
prompt_summary = prompt_template_1.format(transcript=transcript_neg_refund, objective=prompt_intent)

payload = {
    "text_inputs": prompt_summary,
    "max_length": 100,
    "num_return_sequences": 1,
    "top_k": 10,
    "top_p": 0.95,
    "do_sample": True,
}

In [None]:
for model_id in _MODEL_CONFIG_.keys():
    endpoint_name = _MODEL_CONFIG_[model_id]['endpoint_name']
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = _MODEL_CONFIG_[model_id]["parse_function"](query_response)
    print(f"For model: {model_id}, the generated output is: {generated_texts[0]}\n")

In [None]:
# keep the response for the next section
customer_ask = generated_texts[0]
customer_ask

## 3. Retrieval Augmented Generation with LangChain and Chroma

Next, we divide the internal customer-facing guidance into chunks and perform an embedding search to retrieve the target documents. We then prompt the LLM to provide appropriate guidance according to the customer's intent.

In [None]:
from typing import Dict
import nltk

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import NLTKTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

import warnings
warnings.filterwarnings('ignore')

nltk.download('punkt')

**[Step 1]** Split the document into chunks.

In [None]:
# split docs into chunks
loader = PyPDFLoader("StaffCapacityBuilding.pdf")
docs = loader.load()

text_splitter = NLTKTextSplitter(chunk_size=1000)
texts = text_splitter.split_documents(docs)
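
To make the chunking step concrete, the sketch below packs whole sentences greedily into chunks of at most `chunk_size` characters. It is a simplified illustration under stated assumptions, not `NLTKTextSplitter`'s actual implementation: sentences are split with a naive regular expression instead of NLTK's `punkt` tokenizer, and `pack_sentences` is a hypothetical name.

```python
import re

def pack_sentences(text, chunk_size=1000, separator=" "):
    """Greedily pack whole sentences into chunks of at most chunk_size characters.

    A single sentence longer than chunk_size is kept whole in its own chunk.
    """
    # Naive sentence split (assumption): break after ., ! or ? followed by whitespace.
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    chunks, current = [], ""
    for sentence in sentences:
        candidate = sentence if not current else current + separator + sentence
        if current and len(candidate) > chunk_size:
            # Adding this sentence would overflow: close the chunk and start a new one.
            chunks.append(current)
            current = sentence
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks

sample = "Refunds take five days. Apologize first. Offer a voucher if the delay exceeds a week."
for chunk in pack_sentences(sample, chunk_size=45):
    print(len(chunk), chunk)
```

Splitting on sentence boundaries keeps each chunk readable on its own, which matters because the retrieved chunks are later pasted into the prompt verbatim.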

**[Step 2]** Save the chunks and their embeddings in the vector store. Ingesting the data takes some time.

In [None]:
# choose your embedding model
embeddings = HuggingFaceEmbeddings()

# ingest documents into your vector store
vectordb = Chroma.from_documents(texts, embeddings)

**[Step 3]** Define a content handler for the endpoint and create a reusable prompt template.

In [None]:
class QAContentHandler(LLMContentHandler):
    
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
        return input_str.encode('utf-8')
    
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["generated_texts"][0]

qa_content_handler = QAContentHandler()
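
The handler logic can be checked offline before it is attached to an endpoint. The sketch below mirrors the two transforms as plain functions and feeds the output transform an `io.BytesIO` object as a stand-in for the streaming response body. The response shape (`{"generated_texts": [...]}`) is taken from the handler above; the sample prompt and generated strings are made up.

```python
import io
import json

def transform_input(prompt, model_kwargs):
    # Same serialization as QAContentHandler.transform_input:
    # the prompt and the generation parameters share one JSON object.
    return json.dumps({"text_inputs": prompt, **model_kwargs}).encode("utf-8")

def transform_output(output):
    # Same parsing as QAContentHandler.transform_output:
    # read the body, decode it, and keep the first generated sequence.
    response_json = json.loads(output.read().decode("utf-8"))
    return response_json["generated_texts"][0]

request_body = transform_input("Summarize the call.", {"max_length": 100, "top_k": 5})
print(json.loads(request_body))

# Stand-in for the endpoint response body (assumption: any object with .read() works here).
fake_body = io.BytesIO(json.dumps({"generated_texts": ["first", "second"]}).encode("utf-8"))
print(transform_output(fake_body))
```

Only the first element of `generated_texts` is returned, so when `num_return_sequences` is greater than 1 the additional sequences are discarded.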

In [None]:
prompt_template_2 = """ 
Given the following context from a document, please respond to the question below. 
If you cannot find a proper answer, then just say "I don't know."

Document: {document}
Question: {question}
Answer:
"""

prompt = PromptTemplate(
    input_variables = ["document", "question"], 
    template = prompt_template_2)

FLAN_T5_PARAMETERS = {
    "max_length": 100,
    "num_return_sequences": 2,
    "top_k": 5,
    "top_p": 0.6,
    "do_sample": True,
}
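
To see what the model actually receives, a template of this kind can be rendered with plain `str.format`, and its placeholders can be listed to confirm they match `input_variables`. The sketch below uses a shortened stand-in template and made-up document and question text; `placeholder_names` is a hypothetical helper.

```python
from string import Formatter

# Shortened stand-in for the notebook's template (assumption, for illustration only).
template = "Document: {document}\nQuestion: {question}\nAnswer:"

def placeholder_names(template):
    """Return the named placeholders found in a format-style template."""
    return [field for _, field, _, _ in Formatter().parse(template) if field]

print(placeholder_names(template))

rendered = template.format(
    document="Refunds are issued within five business days.",
    question="How long does a refund take?",
)
print(rendered)
```

Printing the rendered prompt before sending it is a quick way to catch a missing variable or an unexpectedly long context.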

**[Step 4]** Chain the model and the prompt together.

In [None]:
# build a question-answer chain
qa_chain = LLMChain(
    llm = SagemakerEndpoint(
        endpoint_name = ENDPOINT_NAME, # replace with your endpoint name if needed
        region_name = region,
        model_kwargs = FLAN_T5_PARAMETERS,
        content_handler = qa_content_handler
    ),
    prompt = prompt
)

**[Step 5]** Retrieve the most relevant chunks with a similarity search.

In [None]:
# conduct the similarity search through the vector store
similar_docs = vectordb.similarity_search(customer_ask, k=3) 
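
Conceptually, a similarity search embeds the query and returns the `k` stored chunks whose vectors are closest to it. The sketch below does this with cosine similarity over hand-written three-dimensional vectors. The vectors and chunk titles are invented for illustration, and cosine similarity is an assumption: the metric Chroma actually uses depends on how the collection is configured.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_k(query_vector, indexed_chunks, k=3):
    """Return the k (text, score) pairs most similar to the query, best first."""
    scored = [
        (text, cosine_similarity(query_vector, vector))
        for text, vector in indexed_chunks
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]

# Toy index: (chunk text, embedding). Real embeddings have hundreds of dimensions.
index = [
    ("Refund policy", [0.9, 0.1, 0.0]),
    ("Escalation steps", [0.1, 0.9, 0.1]),
    ("Greeting script", [0.0, 0.2, 0.9]),
]
query = [0.8, 0.3, 0.0]

for text, score in top_k(query, index, k=2):
    print(f"{score:.3f}  {text}")
```

A brute-force scan like this is fine for a handful of chunks; a vector store exists so that the same lookup stays fast as the index grows.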

In [None]:
# print out the retrieved documents
context_list = [a.page_content for a in similar_docs]
metadata_list = [a.metadata.get('source') for a in similar_docs]
context = "\n\n=============\n\n".join(context_list)
print(context)

**[Step 6]** Get the guidance given by the LLM.

In [None]:
qa_chain.run({
    'document': context,
    'question': f"How to react to the following situation: {customer_ask}?",
})