# Building a simple RAG chatbot with LangChain, Hugging Face, FAISS, Amazon SageMaker and Amazon Textract

In [None]:
%%sh
# Install the notebook's dependencies quietly (-q), upgrading any that are already present (-U).
pip install -qU sagemaker langchain amazon-textract-caller amazon-textract-textractor sentence-transformers pypdf faiss-cpu

In [None]:
import boto3, json, sagemaker
from typing import Dict
from langchain import LLMChain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

## Deploy LLM on SageMaker

In [None]:
# Environment passed to the serving container: the Hugging Face model to load
# and the number of GPUs to use.
hub = dict(
    HF_MODEL_ID="mistralai/Mistral-7B-Instruct-v0.1",
    SM_NUM_GPUS="1",
)

# Execution role that the SageMaker model is created with.
role = sagemaker.get_execution_role()

# Model definition based on the Hugging Face LLM container image, version 1.1.0.
huggingface_model = HuggingFaceModel(
    role=role,
    env=hub,
    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.1.0"),
)

# Deploy the model on a single ml.g5.2xlarge instance; `predictor` is reused
# later to read the endpoint name and to clean up.
predictor = huggingface_model.deploy(
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
    container_startup_health_check_timeout=300,
)

In [None]:
# Keep the endpoint name: the LangChain LLM wrapper is configured with it.
# The bare expression on the last line shows the value as the cell output.
endpoint_name = predictor.endpoint_name
endpoint_name

## Configure LLM in LangChain

In [None]:
# Generation parameters sent along with every prompt to the endpoint.
model_kwargs = dict(max_new_tokens=512, top_p=0.8, temperature=0.8)

In [None]:
class ContentHandler(LLMContentHandler):
    """Convert between LangChain prompts and the endpoint's JSON payloads.

    Requests wrap the prompt in the Mistral instruction template; responses
    are parsed and the echoed prompt is removed so only the answer remains.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the UTF-8 encoded JSON request body for a prompt.

        Args:
            prompt: Text to send to the model.
            model_kwargs: Generation parameters, copied into "parameters".

        Returns:
            The JSON document, encoded as UTF-8 bytes.
        """
        input_str = json.dumps(
            # Mistral prompt, see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
            {"inputs": f"<s>[INST] {prompt} [/INST]", "parameters": {**model_kwargs}}
        )
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the model's answer from the endpoint response.

        Args:
            output: Response body. Despite the annotation, it is read through
                ``.read()``, so it must be a file-like object yielding bytes.

        Returns:
            The text following the first "[/INST]" marker, without the single
            space that separates the marker from the answer. If the marker is
            missing (the prompt was not echoed back), the generated text is
            returned unchanged instead of raising IndexError.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        generated_text = response_json[0]["generated_text"]

        # partition() cuts at the first marker only, so an answer that itself
        # mentions "[/INST]" is kept whole rather than truncated.
        _, marker, answer = generated_text.partition("[/INST]")
        if not marker:
            return generated_text
        return answer[1:] if answer.startswith(" ") else answer


content_handler = ContentHandler()

In [None]:
# Runtime client used to invoke the endpoint; creating it through boto3 is
# what supplies the AWS credentials.
sm_client = boto3.client("sagemaker-runtime")

# LangChain LLM that sends prompts to the SageMaker endpoint.
llm = SagemakerEndpoint(
    client=sm_client,
    endpoint_name=endpoint_name,
    content_handler=content_handler,
    model_kwargs=model_kwargs,
)

## Zero-shot example

In [None]:
# Instructions placed in front of every zero-shot query.
system_prompt = """
As a helpful energy specialist, please answer the question, focusing on numerical data.
Don't invent facts. If you can't provide a factual answer, say you don't know what the answer is.
"""

# The template has a single placeholder, {content}, appended after the instructions.
prompt = PromptTemplate.from_template(system_prompt + "{content}")

In [None]:
# Chain that fills the prompt template and sends the result to the LLM.
llm_chain = LLMChain(llm=llm, prompt=prompt)

In [None]:
question = "What is the latest trend for solar investments in China?"

# Label the question so the prompt reads "question: <text>".
query = "question: " + question

In [None]:
# Pass the query as a plain string: the prompt has a single input variable, so
# the chain maps the value to it. Writing `{query}` would build a set, whose
# repr (braces and quotes included) would be formatted into the prompt.
answer = llm_chain.run(query)
print(answer)

## RAG example with PDF files

In [None]:
from langchain.document_loaders import AmazonTextractPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

### Upload local PDF files to S3

Sources:
* https://www.iea.org/reports/world-energy-investment-2023
* https://www.iea.org/reports/coal-2022
* https://www.iea.org/reports/world-energy-outlook-2023

Feel free to use your own files; the code below should work without any changes.

In [None]:
# Define S3 bucket and prefix for PDF storage:
# the session's default bucket, with all PDFs stored under a fixed key prefix.

bucket = sagemaker.Session().default_bucket()
prefix = "langchain-rag-demo"

In [None]:
%%sh -s $bucket $prefix
# $1 is the bucket name and $2 the key prefix, both passed on the magic line.
# Recursively copies the local "pdfs" directory to that S3 location.
aws s3 cp --recursive pdfs s3://$1/$2/

In [None]:
# Build list of S3 URIs

s3 = boto3.client("s3")

# A single list_objects_v2 call returns at most 1,000 keys and omits "Contents"
# when nothing matches the prefix, so page through the results and treat a
# missing "Contents" entry as an empty page.
paginator = s3.get_paginator("list_objects_v2")
objs = [
    obj
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix)
    for obj in page.get("Contents", [])
]
uris = [f's3://{bucket}/{obj["Key"]}' for obj in objs]
uris

### Analyze documents with Amazon Textract and split them into chunks

In [None]:
%%time

textract_client = boto3.client('textract')
splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=0)

all_chunks = []

for uri in uris:
    loader = AmazonTextractPDFLoader(uri, client=textract_client)
    document = loader.load()
    chunks = splitter.split_documents(document)
    all_chunks += chunks
    print(f"Loaded {uri}, {len(document)} pages, {len(chunks)} chunks")

### Embed document chunks and store them in FAISS
https://github.com/facebookresearch/faiss 

In [None]:
# Embedding model used to vectorize the document chunks.
# Alternatives are compared at https://huggingface.co/spaces/mteb/leaderboard

embedding_model_id = "BAAI/bge-small-en-v1.5"
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_id)

In [None]:
%%time
# Embed chunks: run every chunk through the embedding model and build a
# FAISS vector store from the results.
embeddings_db = FAISS.from_documents(all_chunks, embeddings)

In [None]:
# Save database to the local path "faiss_index" so it can be reloaded
# without embedding the chunks again.
embeddings_db.save_local("faiss_index")

### Shortcut: load an existing embedding database

In [None]:
# Reload the store from the local path "faiss_index"; the embeddings object is
# passed again so that queries are embedded with the same model.
# NOTE(review): newer LangChain releases reportedly refuse to load without
# allow_dangerous_deserialization=True - confirm against the installed version,
# and only load index files you created yourself.
embeddings_db = FAISS.load_local("faiss_index", embeddings)

********

### Configure RAG chain

In [None]:
# Retriever over the vector store; "k": 10 requests 10 chunks per query
# (presumably the closest matches - confirm in the LangChain docs).
retriever = embeddings_db.as_retriever(search_kwargs={"k": 10})

In [None]:
# Define prompt template
# It has two placeholders: {question} for the user's question and {context}
# for the text the answer must be based on.
prompt_template = """
As a helpful energy specialist, please answer the question below, focusing on numerical data and using only the context below.
Don't invent facts. If you can't provide a factual answer, say you don't know what the answer is.

question: {question}

context: {context}
"""

# Rebinds `prompt`, which held the zero-shot template earlier in the notebook.
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

In [None]:
# Question-answering chain: documents come from the retriever and the answer
# is generated by the LLM using the custom prompt.
# NOTE(review): the "stuff" chain type presumably places all retrieved chunks
# into a single prompt - confirm in the LangChain docs.
chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt},
)

### Ask our question again

In [None]:
# Same question as in the zero-shot example, now sent through the RAG chain
# under the "query" key.
question = "What is the latest trend for solar investments in China?"
answer = chain.run({"query": question})
print(answer)

In [None]:
# A second question for the RAG chain, sent under the "query" key.
question = "What does STEPS mean?"
answer = chain.run({"query": question})
print(answer)

## Delete endpoint and model

In [None]:
# Remove the model first, then the endpoint.
# NOTE(review): the order looks deliberate - delete_model() presumably looks
# the model up through the live endpoint's configuration. Confirm before reordering.
predictor.delete_model()
predictor.delete_endpoint()