# 6: Retrieval Augmented Generation Chatbot with LangChain

- SageMaker Notebook Kernel: `conda_python3`
- SageMaker Notebook Instance Type: ml.m5d.large | ml.t3.large

In this notebook, we will bring together all of the pieces to form the foundation of a retrieval augmented generation chatbot using [Amazon Bedrock](https://aws.amazon.com/bedrock/), [Amazon OpenSearch](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-vector-search.html) Vector Search, and [LangChain](https://python.langchain.com/). 

LangChain is a framework for developing applications powered by language models. The library provides LLM adapters, data retrieval components, text splitters, conversation memory and storage, and components that wire all of these things together. It also provides agents, which orchestrate LLMs and tools to go beyond chat interfaces.

You'll learn how to construct a simple chain with just a prompt and a model using LangChain expression language. Then you'll learn how to add retrieval components, chat history, and question rephrasing to the chain to build a chatbot that can answer questions with private data. 

## Runtime 

This notebook takes approximately 15 minutes to run.

## Contents

1. [Prerequisites](#prerequisites)
1. [Setup](#setup)
1. [LangChain Expression Language](#langchain-expression-language-lcel)
1. [Simple chain](#simple-chain)
1. [Retrieval chain](#retrieval-chain)
1. [Chat chain](#chat-chain)
1. [Question rephrasing](#question-rephrasing)
1. [Retriever chain](#retriever-chain)
1. [Chat with retriever chain](#chat-with-retriever-chain)

## Prerequisites

- Deployed Llama2 inference endpoint on Amazon SageMaker
- Bedrock user guide documentation ingested into Amazon OpenSearch Serverless Vector Store
- `amazon.titan-embed-text-v1` embeddings model enabled in the Amazon Bedrock console in `us-west-2`


## Setup

Let's start by installing and importing the required packages for this notebook. 

<div class="alert alert-block alert-warning"><b>Note:</b> Verify that the notebook kernel is `conda_python3`. Also, if you run into an issue where a module can't be imported after installation, restart the notebook kernel, then rerun the import notebook cell.</div>

In [None]:
%pip install langchain==0.0.317 --quiet
%pip install opensearch-py==2.3.2 --quiet

In [None]:
%reload_ext autoreload
%autoreload 2

import os
import sys
import json
import boto3
import langchain.vectorstores.opensearch_vector_search as ovs

from typing import Dict, List, Optional, Sequence
from pprint import pprint
from IPython.display import display
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, helpers
from langchain.embeddings import BedrockEmbeddings
from langchain.vectorstores import OpenSearchVectorSearch
from operator import itemgetter
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableLambda, RunnableMap, RunnableBranch
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage, ChatMessage

# load langchain helper classes from our app
sys.path.append("../app/lib/langchain/")
from opensearch import create_ovs_client
from llama2 import Llama2Chat


***

Next, we will create a boto3 session and look up the AWS Region it uses. We'll create the service clients from this session.

***


In [None]:
# boto3 session
boto3_session = boto3.Session()
region = boto3_session.region_name

print(f"Boto3 region: {region}")

***

Let's retrieve the Amazon OpenSearch Serverless collection ID. The cell below takes the first collection returned by `list_collections`.

***

In [None]:
aoss_client = boto3_session.client("opensearchserverless")
list_collections_response = aoss_client.list_collections()
collection_id = list_collections_response.get("collectionSummaries")[0].get("id")
index_name = "bedrock-docs"
print(f"OpenSearch collection id: {collection_id}")
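Taking index `[0]` works when the account has a single collection. If there are several, it can pick the wrong one. Here is a minimal sketch of selecting the collection by name instead. It assumes the `list_collections` response has the shape `{"collectionSummaries": [{"id": ..., "name": ...}]}`. The helper `find_collection_id` and the sample names are hypothetical.

```python
from typing import Any, Dict, Optional


def find_collection_id(response: Dict[str, Any], name: str) -> Optional[str]:
    """Return the id of the collection called `name`, or None if it is absent.

    Assumes `response` looks like the output of the OpenSearch Serverless
    list_collections call: {"collectionSummaries": [{"id": ..., "name": ...}, ...]}.
    """
    for summary in response.get("collectionSummaries", []):
        if summary.get("name") == name:
            return summary.get("id")
    return None


# Hypothetical response, for illustration only.
sample_response = {
    "collectionSummaries": [
        {"id": "abc123", "name": "other-collection"},
        {"id": "def456", "name": "rag-workshop"},
    ]
}
print(find_collection_id(sample_response, "rag-workshop"))  # def456
```

In the notebook, you would pass `list_collections_response` and the name of your own collection.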

***

Finally, initialize the foundational LangChain components: the Bedrock embeddings model, the OpenSearch vector store, and the Llama 2 LLM.

***

In [None]:
# Embeddings
bedrock_client = boto3_session.client("bedrock-runtime")
embeddings_model_id = "amazon.titan-embed-text-v1"
bedrock_embeddings = BedrockEmbeddings(
    client=bedrock_client, model_id=embeddings_model_id
)

# VectorStore (using a helper function to create LangChain OpenSearchVectorSearch with client patch)
vector_store = create_ovs_client(
    collection_id, index_name, region, boto3_session, bedrock_embeddings
)

# Our custom Llama2Chat class (see the /app/lib/langchain/llama2.py file)
llm = Llama2Chat(
    endpoint_name="llama-2-7b",
    client=boto3_session.client("sagemaker-runtime"),
    streaming=False,
)

***

A retriever is a lightweight wrapper around a vector store object that makes it conform to the retriever interface used by LangChain components. We initialize it here with the parameter `k`, which is the number of documents the vector store returns when a search is run.

***

In [None]:
num_docs_to_return = 2
retriever = vector_store.as_retriever(search_kwargs={"k": num_docs_to_return})
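Conceptually, a vector-store retriever embeds the query and returns the `k` documents whose embeddings are most similar to it. The plain-Python sketch below illustrates the idea with cosine similarity over tiny hand-written vectors. It is not how OpenSearch implements k-NN search, and the vectors are made up for illustration.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query: Sequence[float], docs: List[Tuple[str, Sequence[float]]], k: int) -> List[str]:
    """Return the text of the k documents most similar to the query vector."""
    ranked = sorted(docs, key=lambda d: cosine_similarity(query, d[1]), reverse=True)
    return [text for text, _ in ranked[:k]]


# Toy 3-dimensional "embeddings", for illustration only.
docs = [
    ("Bedrock pricing", [0.9, 0.1, 0.0]),
    ("Bedrock models", [0.1, 0.9, 0.1]),
    ("Unrelated page", [0.0, 0.1, 0.9]),
]
print(top_k([0.2, 0.8, 0.0], docs, k=2))  # ['Bedrock models', 'Bedrock pricing']
```

Here `k` plays the same role as `search_kwargs={"k": num_docs_to_return}`. A larger value gives the LLM more context but also more tokens to read.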

## LangChain Expression Language (LCEL)

LangChain Expression Language (LCEL) is a declarative way to compose chains. Instead of writing glue code, you connect core components (prompts, models, retrievers, output parsers) with the pipe operator `|`, and the output of each component becomes the input of the next. Every chain built this way exposes the same interface, so it can be invoked, batched, or streamed, and it can itself be used as a component in a larger chain.


To dive deeper into LCEL, see the [LCEL documentation](https://python.langchain.com/docs/expression_language/).
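To build intuition for the `|` syntax, here is a minimal plain-Python sketch of how a pipe operator can compose steps by overloading `__or__`. The `Step` class is hypothetical and much simpler than LangChain's `Runnable`: it has no streaming or async support and no automatic conversion of dictionaries or functions. It does show the core idea. `a | b` returns a new object whose `invoke` feeds `a`'s output into `b`.

```python
from typing import Any, Callable, List


class Step:
    """A hypothetical, minimal stand-in for a LangChain runnable."""

    def __init__(self, func: Callable[[Any], Any]):
        self.func = func

    def invoke(self, value: Any) -> Any:
        return self.func(value)

    def batch(self, values: List[Any]) -> List[Any]:
        # Run invoke on each input in turn.
        return [self.invoke(v) for v in values]

    def __or__(self, other: "Step") -> "Step":
        # `self | other` builds a new Step that runs self, then feeds the result to other.
        return Step(lambda value: other.invoke(self.invoke(value)))


format_prompt = Step(lambda inputs: f"Question: {inputs['question']}")
fake_llm = Step(lambda prompt: prompt.upper())  # stands in for a real model call
parse_output = Step(lambda text: text.strip())

chain = format_prompt | fake_llm | parse_output
print(chain.invoke({"question": "what is lcel?"}))  # QUESTION: WHAT IS LCEL?
```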


## Simple chain

Let's start with a simple chain that invokes our LLM with a single prompt. Remember that Llama 2 expects prompts to be formatted with special tokens to produce good output. The [Llama2Chat](../app/lib/langchain/llama2.py) class converts LangChain's role-based chat messages (human, AI, system) into the Llama 2 format before invoking the model, so we need to use the chat prompt template and chat message components.
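For reference, the publicly documented Llama 2 chat format wraps each user turn in `[INST] ... [/INST]`. It puts an optional system prompt inside `<<SYS>> ... <</SYS>>` in the first turn and encloses each completed exchange in `<s> ... </s>`, shown here as literal strings. The sketch below converts `(role, content)` tuples into that format. It is an illustrative assumption, not the code of the workshop's `Llama2Chat` class, which may differ in details such as whitespace. It also assumes the turns alternate between human and AI.

```python
from typing import List, Tuple

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def to_llama2_prompt(messages: List[Tuple[str, str]]) -> str:
    """Render (role, content) pairs, roles in {"system", "human", "ai"}, as a Llama 2 prompt."""
    system = ""
    if messages and messages[0][0] == "system":
        system = B_SYS + messages[0][1] + E_SYS
        messages = messages[1:]

    prompt = ""
    # Walk the remaining messages as (human, ai) pairs; the last pair may lack an AI reply.
    for i in range(0, len(messages), 2):
        human = messages[i][1]
        if i == 0:
            human = system + human  # the system prompt goes inside the first [INST] block
        prompt += f"<s>{B_INST} {human} {E_INST}"
        if i + 1 < len(messages):
            prompt += f" {messages[i + 1][1]} </s>"
    return prompt


print(to_llama2_prompt([
    ("system", "Be brief."),
    ("human", "Hi"),
    ("ai", "Hello!"),
    ("human", "What is Bedrock?"),
]))
```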

Let's take a look at the code that creates the chain:

```python
simple_chain = simple_prompt | llm | StrOutputParser()
```

The syntax may look unfamiliar, so here's one way to think about what's happening.

When the chain is invoked, the input is piped into the prompt template for formatting. The formatted prompt is piped into the LLM to generate a response. Then the LLM's response is piped into the string output parser, whose result is returned to the caller.

The components are considered [runnables](https://api.python.langchain.com/en/latest/schema/langchain.schema.runnable.base.Runnable.html#langchain.schema.runnable.base.Runnable). A runnable is a unit of work that can be invoked, batched, streamed, transformed and composed. 

If you look at `SIMPLE_TEMPLATE`, you will see the placeholder `{question}`. When we call the chain's `invoke` method, we pass in a dictionary that maps `question` to its value, and that value is substituted during the formatting step. Later in this notebook, we will show more complex placeholders whose values are provided by other chains.

<div class="alert alert-block alert-info"><b>Note: </b> LangChain supports response streaming, but for the sake of simplicity, we'll be utilizing the non-streaming method in this notebook. In the next module, you will have the option to use either method.</div>


In [None]:
SIMPLE_TEMPLATE = "{question}"

simple_prompt = ChatPromptTemplate.from_template(SIMPLE_TEMPLATE)

simple_chain = (
    simple_prompt | llm | StrOutputParser()
)

simple_chain.invoke({"question": "What is the capital of France?"})

## Retrieval chain

Let's add the vector store retriever to create a retrieval augmented generation (RAG) chain. Our prompt template now has two placeholders: one for the context retrieved by the vector store and one for the question.

The input to the `invoke` method is the question. The question has to go to the retriever to fetch the context, and both the question and the context have to be passed to the chat prompt (`rag_prompt`). We do this by starting the chain with a dictionary whose values are runnables. `itemgetter` picks values out of the input. `itemgetter("question") | retriever` pipes the question to the retriever and binds the result to the `context` key, while `itemgetter("question")` passes the question through unchanged. The resulting dictionary is then passed to the prompt formatter.
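The dictionary at the start of `rag_chain` behaves like a map step: every value receives the same input dictionary, and the results are collected under the same keys. LangChain may run the entries concurrently, while this plain-Python sketch runs them one after another. `fake_retriever` is a hypothetical stand-in for the vector store retriever.

```python
from operator import itemgetter
from typing import Any, Callable, Dict, List


def fake_retriever(question: str) -> List[str]:
    """Hypothetical stand-in for the vector store retriever."""
    return [f"doc about: {question}"]


def run_map(steps: Dict[str, Callable[[Any], Any]], inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Apply every step to the same input and collect the results by key."""
    return {key: step(inputs) for key, step in steps.items()}


steps = {
    # Pick the question out of the input, then pass it to the retriever.
    "context": lambda inputs: fake_retriever(itemgetter("question")(inputs)),
    # Pass the question through unchanged.
    "question": itemgetter("question"),
}

print(run_map(steps, {"question": "What is Amazon Bedrock?"}))
# {'context': ['doc about: What is Amazon Bedrock?'], 'question': 'What is Amazon Bedrock?'}
```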

In [None]:
RAG_TEMPLATE = """\
Answer the question based only on the following context:
{context}

Question: {question}\
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)

rag_chain = (
    {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
    }
    | rag_prompt
    | llm
    | StrOutputParser()
)

rag_chain.invoke({"question": "What is Amazon Bedrock?"})

## Chat chain 

So far, we have only looked at single prompts, but chats are dialogs with many question and answer pairs. For conversational interfaces, we need to include the dialog between the human and the AI so that the LLM can understand the full context when responding to the latest question.

We will update our prompt template to include a placeholder for `chat_history` and then pass the history when invoking the chain. LangChain uses message types to represent the dialog entities: `SystemMessage`, `HumanMessage`, and `AIMessage`. Remember that Llama 2 expects a specific format for these messages, which is built into the `Llama2Chat` class. We just need to make sure that the chat history we pass always alternates between human and AI messages, with an optional system message at the beginning. Otherwise, you will get an exception.
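Because a malformed history raises an exception, it can help to validate the history before invoking the chain. Here is a minimal sketch using a hypothetical `validate_chat_history` helper that works on role strings. With real LangChain messages, you could build the role list as `[m.type for m in chat_history]`. The check that the history ends with an AI message is an assumption that follows from the rule above, because the chat prompt appends the new question as a human message.

```python
from typing import List


def validate_chat_history(roles: List[str]) -> None:
    """Raise ValueError unless roles alternate human/ai after an optional system message.

    The history must also end with an "ai" message, because the chat prompt
    appends the new question as a human message after the history.
    """
    turns = roles[1:] if roles and roles[0] == "system" else roles
    for i, role in enumerate(turns):
        expected = "human" if i % 2 == 0 else "ai"
        if role != expected:
            raise ValueError(f"turn {i}: expected {expected!r}, got {role!r}")
    if turns and turns[-1] != "ai":
        raise ValueError("history must end with an 'ai' message")


validate_chat_history([])  # an empty history is fine
validate_chat_history(["system", "human", "ai"])
try:
    validate_chat_history(["human", "human"])
except ValueError as err:
    print(err)  # turn 1: expected 'ai', got 'human'
```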


In [None]:
chat_prompt = ChatPromptTemplate.from_messages(
    [
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)

chat_chain = (
    chat_prompt
    | llm
    | StrOutputParser()
)

chat_history = [
    HumanMessage(content="What is Amazon Bedrock?"),
    AIMessage(content="Amazon Bedrock is a fully managed service that makes base models from Amazon and third-party model providers accessible through an API."),
]

chat_chain.invoke({"question": "What does fully managed mean?",  "chat_history": chat_history})

## Question rephrasing 

Before we combine chat and retrieval, there is a challenge: a user message often can't be understood without the preceding conversation. Consider a user who asks `What's Amazon Bedrock?` followed by `What models does it support?`. Searching the vector store for `What models does it support?` alone won't retrieve the right documents, because the query never mentions Amazon Bedrock. The search input needs to contain all of the relevant context from the dialog.

We can accomplish this with a question rephrasing step. It uses the LLM, the chat history, and the user's question to generate a standalone question that the vector store can use to retrieve documents.

Run the cell below to see how the user's question `What models does it support?` is rephrased into a standalone question that can be passed to the vector store for search.
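In the cell below, `chat_history` is a list of message objects, so `PromptTemplate` inserts the list's string representation into the prompt. A readable transcript is an alternative. The sketch below renders the same template with plain `str.format`. It uses a hypothetical `format_transcript` helper and `(role, text)` tuples instead of LangChain messages. Whether a transcript improves rephrasing quality for this model is an assumption worth testing.

```python
from typing import List, Tuple

REPHRASE_TEMPLATE = """\
Given the following conversation and a follow up question, rephrase the follow up \
question to be a standalone question. Only return the standalone question and nothing else.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone Question:\
"""


def format_transcript(history: List[Tuple[str, str]]) -> str:
    """Render (role, text) pairs as 'Human: ...' / 'AI: ...' lines (hypothetical helper)."""
    labels = {"human": "Human", "ai": "AI"}
    return "\n".join(f"{labels[role]}: {text}" for role, text in history)


history = [
    ("human", "What is Amazon Bedrock?"),
    ("ai", "Amazon Bedrock is a fully managed service."),
]
prompt = REPHRASE_TEMPLATE.format(
    chat_history=format_transcript(history),
    question="What models does it support?",
)
print(prompt)
```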

In [None]:
REPHRASE_TEMPLATE = """\
Given the following conversation and a follow up question, rephrase the follow up \
question to be a standalone question. Only return the standalone question and nothing else.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone Question:\
"""

condense_question_chain = (
    PromptTemplate.from_template(REPHRASE_TEMPLATE)
    | llm
    | StrOutputParser()
)

condense_question_chain.invoke({
    "chat_history": chat_history,
    "question": "What models does it support?"
})

***

Let's combine the `condense_question_chain` and the `retriever` and see what the response looks like. Does the content of the returned documents seem correct? 

***

In [None]:
retriever_condense_chain = condense_question_chain | retriever

retriever_condense_chain.invoke({
    "chat_history": chat_history,
    "question": "What models does it support?"
})

## Retriever chain

We don't want to run the condense question chain when there is no chat history. So we will use the `RunnableBranch` helper, which executes `retriever_condense_chain` only if `chat_history` is present and non-empty. Otherwise, it passes the question directly to the `retriever`.
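`RunnableBranch` takes `(condition, runnable)` pairs followed by a default runnable. It invokes the first runnable whose condition returns true, or the default if none does. Here is a plain-Python sketch of that control flow. `condense_then_retrieve` and `retrieve_directly` are hypothetical stand-ins for the condense-and-retrieve chain and the direct retriever.

```python
from typing import Any, Callable, Dict

Func = Callable[[Dict[str, Any]], Any]


def branch(*args: Any) -> Func:
    """Build a function that runs the first step whose condition is true, else the default.

    Mirrors the argument order of RunnableBranch: (condition, step) pairs, then a default.
    """
    *pairs, default = args

    def run(inputs: Dict[str, Any]) -> Any:
        for condition, step in pairs:
            if condition(inputs):
                return step(inputs)
        return default(inputs)

    return run


def condense_then_retrieve(inputs: Dict[str, Any]) -> str:
    return f"retrieved for: standalone({inputs['question']})"


def retrieve_directly(inputs: Dict[str, Any]) -> str:
    return f"retrieved for: {inputs['question']}"


retriever_branch = branch(
    (lambda x: bool(x.get("chat_history")), condense_then_retrieve),
    retrieve_directly,
)

print(retriever_branch({"chat_history": None, "question": "What models?"}))
# retrieved for: What models?
print(retriever_branch({"chat_history": ["..."], "question": "What models?"}))
# retrieved for: standalone(What models?)
```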

In [None]:
retriever_chain = RunnableBranch(
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))),
        retriever_condense_chain,
    ),
    (RunnableLambda(itemgetter("question")) | retriever)
)

***

Let's run the retriever chain without chat history and see what happens. Notice that the standalone question prompt isn't executed and the question is passed through to the vector store.

***

In [None]:
retriever_chain.invoke({
    "chat_history": None, 
    "question": "What models does it support?"
})

***

Now run it with history to verify that the retriever chain is working correctly.

***

In [None]:
retriever_chain.invoke({
    "chat_history": chat_history, 
    "question": "What models does it support?"
})

## Chat with retriever chain

Let's combine all of the pieces to create a retrieval augmented chat chain. First, let's define the system prompt template, which contains instructions for the model to follow and a placeholder for the content returned by the retriever.

In [None]:
SYSTEM_TEMPLATE = """\
Generate a comprehensive and informative answer of 80 words or less for the \
given question based solely on the provided search results (URL and content). You must \
only use information from the provided search results. Use an unbiased and \
journalistic tone. Combine search results together into a coherent answer. Do not \
repeat text. Cite search results using [${{number}}] notation. Only cite the most \
relevant results that answer the question accurately. Place these citations at the end \
of the sentence or paragraph that reference them - do not put them all at the end. If \
different results refer to different entities within the same name, write separate \
answers for each entity.

You should use bullet points in your answer for readability. Put citations where they apply
rather than putting them all at the end.

If there is nothing in the context relevant to the question at hand, just say "Hmm, \
I'm not sure." Don't try to make up an answer.

Anything between the following <context></context>  html blocks is retrieved from a knowledge \
bank, not part of the conversation with the user. 

<context>
    {context} 
</context>
"""

***

The retriever returns a list of `Document` objects, and we want to format them so the LLM can tell them apart. `format_docs` is a helper method that wraps each document's text in a `<doc id='N'></doc>` tag, where `N` is the document's position in the list. The LLM can use these IDs to cite specific sources in its responses.

***

In [None]:
def format_docs(docs: Sequence[Document]) -> str:
    formatted_docs = []
    for i, doc in enumerate(docs):
        doc_string = f"<doc id='{i}'>{doc.page_content}</doc>"
        formatted_docs.append(doc_string)
    return "\n".join(formatted_docs)

***

Previously, we used the `HumanMessage` and `AIMessage` classes directly. Now let's create a helper method that takes a list of simple dictionaries (`{"human": ...}` or `{"ai": ...}`) and converts them into the correct message types. Despite its name, `serialize_history` performs the deserialization step, from dictionaries to message objects. This will be useful when we build the chat application: in production, you need to serialize and deserialize the chat history between user requests, and simple dictionaries are easy to store. It also makes it easier to experiment with the inputs.

***

In [None]:
def serialize_history(request):
    chat_history = request["chat_history"] or []
    converted_chat_history = []
    for message in chat_history:
        if message.get("human") is not None:
            converted_chat_history.append(HumanMessage(content=message["human"]))
        if message.get("ai") is not None:
            converted_chat_history.append(AIMessage(content=message["ai"]))
    return converted_chat_history
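A production application also needs the reverse step: turning the conversation back into simple dictionaries so it can be stored, for example as JSON, between requests. Here is a minimal sketch that assumes the same `{"human": ...}` / `{"ai": ...}` shape. The helper names and the `(role, text)` tuples used in place of LangChain message objects are hypothetical. With real messages, the role is available as `message.type` (`"human"` or `"ai"`).

```python
import json
from typing import Dict, List, Optional, Tuple

Message = Tuple[str, str]  # (role, text), standing in for HumanMessage / AIMessage


def to_simple_history(messages: List[Message]) -> List[Dict[str, str]]:
    """Convert (role, text) messages into the {"human": ...} / {"ai": ...} dict shape."""
    return [{role: text} for role, text in messages if role in ("human", "ai")]


def from_simple_history(history: Optional[List[Dict[str, str]]]) -> List[Message]:
    """Inverse of to_simple_history, mirroring the logic of serialize_history."""
    messages: List[Message] = []
    for entry in history or []:
        if entry.get("human") is not None:
            messages.append(("human", entry["human"]))
        if entry.get("ai") is not None:
            messages.append(("ai", entry["ai"]))
    return messages


conversation = [("human", "What is Amazon Bedrock?"), ("ai", "A fully managed service.")]
stored = json.dumps(to_simple_history(conversation))  # e.g. saved in a session store
restored = from_simple_history(json.loads(stored))
print(restored == conversation)  # True
```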


***

Now we will complete the chain by combining the previous parts. The following list explains each part in the code below.  

- `input_map` - maps over the input and converts the simple `chat_history` dictionaries into message objects.
- `context_chain` - passes the output of `input_map` to the retriever chain, which rephrases the question (when there is chat history) and searches for content. It then formats the retrieved documents and maps them to the `context` variable. The `question` and `chat_history` values are passed through.
- `rag_prompt` - combines the context, chat history, and question with the system prompt to be executed by the LLM.
- `chat_retrieval_chain` - combines the workflow steps into a chain: `input -> input_map -> context_chain -> rag_prompt -> llm -> StrOutputParser`.

![](./assets/images/langchain-chat-retrieval-flow.png)
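The following plain-Python sketch traces the shape of the data at each stage of `chat_retrieval_chain`. Every function is a hypothetical stand-in that calls no model or vector store. The point is only to show which keys exist after `input_map` and after `context_chain`, and what the prompt receives.

```python
from typing import Any, Dict, List, Tuple


def input_map(request: Dict[str, Any]) -> Dict[str, Any]:
    # Convert the simple {"human"/"ai": text} dicts into (role, text) tuples.
    history: List[Tuple[str, str]] = []
    for entry in request["chat_history"] or []:
        for role in ("human", "ai"):
            if entry.get(role) is not None:
                history.append((role, entry[role]))
    return {"question": request["question"], "chat_history": history}


def context_chain(state: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in for retriever_chain | format_docs (the real chain would first rephrase
    # the question because history is present); question and history pass through.
    context = f"<doc id='0'>text about: {state['question']}</doc>"
    return {"context": context, **state}


def build_messages(state: Dict[str, Any]) -> List[Tuple[str, str]]:
    # Stand-in for rag_prompt: system message, then history, then the new question.
    system = f"Answer using only this context:\n{state['context']}"
    return [("system", system), *state["chat_history"], ("human", state["question"])]


request = {
    "chat_history": [{"human": "What is Amazon Bedrock?"}, {"ai": "A managed service."}],
    "question": "What models does it support?",
}
messages = build_messages(context_chain(input_map(request)))
for role, text in messages:
    print(role, "->", text.splitlines()[0])
```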

***

In [None]:
input_map = {
    "question": RunnableLambda(itemgetter("question")),
    "chat_history": RunnableLambda(serialize_history),
}

context_chain = RunnableMap(
    {
        "context": retriever_chain | format_docs,
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    }
)

rag_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_TEMPLATE),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)

chat_retrieval_chain = (
    input_map
    | context_chain
    | rag_prompt
    | llm
    | StrOutputParser()
)

***

Finally, let's run the chain by providing some chat history and a question.

***

In [None]:
chat_retrieval_chain.invoke({
    "chat_history": [
        {"human": "What is Amazon Bedrock?"},
        {"ai": "Amazon Bedrock is a fully managed service that makes base models from Amazon and third-party model providers accessible through an API."},
    ],
    "question": "What models does it support?",
})

## Notebook complete

You've now learned all of the foundational elements needed to build a retrieval augmented generation chatbot that can answer questions with private data. Next, head back to the workshop content to learn how to incorporate these components into a Streamlit app to build an end-to-end solution.
