# How to Build a RAG-Powered Chatbot with Chat, Embed, and Rerank

*Read the accompanying [blog post here](https://txt.cohere.com/rag-chatbot).*

![Feature](images/rag-chatbot.png)

In this notebook, you’ll learn how to build a chatbot that has RAG capabilities, enabling it to connect to external documents, ground its responses on these documents, and produce document citations in its responses.

Below is a diagram that provides an overview of what we’ll build, followed by a list of the key steps involved.

![Overview](images/rag-chatbot-flow.png)

Setup phase:
- Step 0: Ingest the documents – get documents, chunk, embed, and index.

For each user-chatbot interaction:
- Step 1: Get the user message
- Step 2: Call the Chat endpoint in query-generation mode
- If at least one query is generated
    - Step 3: Retrieve and rerank relevant documents
    - Step 4: Call the Chat endpoint in document mode to generate a grounded response with citations
- If no query is generated
    - Step 4: Call the Chat endpoint in normal mode to generate a response

Throughout the conversation:
- Append the user-chatbot interaction to the conversation thread
- Repeat with every interaction
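
The branching in the steps above can be sketched without any API calls. This is a minimal illustration, not the notebook's implementation: `respond` and the three `fake_*` helpers are hypothetical stand-ins for the query-generation call, the retrieve-and-rerank step, and the Chat endpoint.

```python
from typing import Callable, Dict, List


def respond(
    message: str,
    generate_queries: Callable[[str], List[str]],
    retrieve: Callable[[str], List[Dict[str, str]]],
    chat: Callable[..., str],
) -> str:
    """Route one user message through Steps 2 to 4."""
    queries = generate_queries(message)  # Step 2: query-generation mode
    if queries:
        documents: List[Dict[str, str]] = []
        for query in queries:  # Step 3: retrieve (and rerank) per query
            documents.extend(retrieve(query))
        return chat(message, documents=documents)  # Step 4: document mode
    return chat(message, documents=None)  # Step 4: normal mode


# Hypothetical stand-ins so the sketch runs without an API key.
def fake_generate_queries(message: str) -> List[str]:
    # Pretend that only questions need a search query.
    return [message] if message.endswith("?") else []


def fake_retrieve(query: str) -> List[Dict[str, str]]:
    return [{"title": "Stub", "text": f"Notes about: {query}", "url": "https://example.com"}]


def fake_chat(message: str, documents=None) -> str:
    mode = "document" if documents else "normal"
    return f"[{mode} mode] reply to: {message}"


print(respond("Hello", fake_generate_queries, fake_retrieve, fake_chat))
print(respond("What is attention?", fake_generate_queries, fake_retrieve, fake_chat))
```

The first call produces no query and falls through to normal mode; the second produces one query and takes the document-mode path.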

In [None]:
# TODO: upgrade to "cohere>5"
! pip install "cohere<5" hnswlib unstructured -q

In [1]:
import cohere
import hnswlib
import json
import uuid
from typing import List, Dict
from unstructured.partition.html import partition_html
from unstructured.chunking.title import chunk_by_title

co = cohere.Client("COHERE_API_KEY")  # replace the placeholder with your Cohere API key

In [3]:
#@title Enable text wrapping in Google Colab

from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

### Datastore component

In [2]:
class Datastore:
    """
    A class representing a collection of document chunks that can be searched and reranked.

    Parameters:
    raw_documents (list): A list of dictionaries representing the source documents. Each dictionary should have 'title' and 'url' keys.

    Attributes:
    raw_documents (list): The source documents, as passed to the constructor.
    chunks (list): A list of dictionaries representing the chunks, with 'title', 'text', and 'url' keys.
    chunks_embs (list): A list of the associated embeddings for the chunks.
    chunks_len (int): The number of chunks in the collection (set by embed()).
    idx (hnswlib.Index): The index used for chunk retrieval (set by index()).
    retrieve_top_k (int): The number of chunks to retrieve from the index per query.
    rerank_top_k (int): The number of chunks to keep after reranking.

    Methods:
    load_and_chunk(): Loads the HTML content from each source URL and partitions it into chunks.
    embed(): Embeds the chunks using the Cohere API.
    index(): Indexes the chunk embeddings for efficient retrieval.
    search_and_rerank(query): Retrieves the chunks closest to the query, then reranks them.
    """

    def __init__(self, raw_documents: List[Dict[str, str]]):
        self.raw_documents = raw_documents  # raw documents
        self.chunks = []            # chunked version of documents
        self.chunks_embs = []       # embeddings of chunked documents
        self.retrieve_top_k = 10
        self.rerank_top_k = 3
        self.load_and_chunk()  # load raw documents and break into chunks
        self.embed() # generate embeddings for each chunk
        self.index() # store embeddings in an index


    def load_and_chunk(self) -> None:
        """
        Loads the text from the sources and chunks the HTML content.
        """
        print("Loading documents...")

        for source in self.raw_documents:
            elements = partition_html(url=source["url"])
            chunks = chunk_by_title(elements)
            for chunk in chunks:
                self.chunks.append(
                    {
                        "title": source["title"],
                        "text": str(chunk),
                        "url": source["url"],
                    }
                )

    def embed(self) -> None:
        """
        Embeds the document chunks using the Cohere API.
        """
        print("Embedding document chunks...")

        batch_size = 90
        self.chunks_len = len(self.chunks)

        for i in range(0, self.chunks_len, batch_size):
            batch = self.chunks[i : min(i + batch_size, self.chunks_len)]
            texts = [item["text"] for item in batch]
            chunks_embs_batch = co.embed(
                texts=texts, model="embed-english-v3.0", input_type="search_document"
            ).embeddings
            self.chunks_embs.extend(chunks_embs_batch)

    def index(self) -> hnswlib.Index:
        """
        Indexes the document chunks for efficient retrieval and returns the index.
        """
        print("Indexing documents...")

        self.idx = hnswlib.Index(space="ip", dim=1024)
        self.idx.init_index(max_elements=self.chunks_len, ef_construction=512, M=64)
        self.idx.add_items(self.chunks_embs, list(range(len(self.chunks_embs))))

        print(f"Indexing complete with {self.idx.get_current_count()} documents.")

        return self.idx

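    def search_and_rerank(self, query: str) -> List[Dict[str, str]]:
        """
        Retrieves the closest chunks for a query from the index, then reranks them.

        Parameters:
        query (str): The search query.

        Returns:
        List[Dict[str, str]]: The reranked chunks, each with 'title', 'text', and 'url' keys.
        """
        # SEARCH
        query_emb = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings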

        chunk_ids = self.idx.knn_query(query_emb, k=self.retrieve_top_k)[0][0]

        # RERANK
        chunks_to_rerank = [self.chunks[chunk_id]["text"] for chunk_id in chunk_ids]

        rerank_results = co.rerank(
            query=query,
            documents=chunks_to_rerank,
            top_n=self.rerank_top_k,
            model="rerank-english-v2.0",
        )

        chunk_ids_reranked = [chunk_ids[result.index] for result in rerank_results]

        chunks_retrieved = []
        for chunk_id in chunk_ids_reranked:
            chunks_retrieved.append(
                {
                    "title": self.chunks[chunk_id]["title"],
                    "text": self.chunks[chunk_id]["text"],
                    "url": self.chunks[chunk_id]["url"],
                }
            )

        return chunks_retrieved
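
The one subtle step in `search_and_rerank` is the index mapping: the reranker reports positions within the candidate list it was given, so each position has to be mapped back to a chunk ID. The sketch below illustrates that two-stage pattern with toy data. It makes several assumptions that differ from the class above: an exhaustive inner-product search stands in for the HNSW index, and `keyword_score` is a hypothetical scorer standing in for the Rerank endpoint.

```python
from typing import Callable, List, Sequence


def inner_product(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def search_and_rerank_ids(
    query_emb: Sequence[float],
    chunk_embs: List[Sequence[float]],
    chunk_texts: List[str],
    score_fn: Callable[[str], float],
    retrieve_top_k: int = 3,
    rerank_top_k: int = 2,
) -> List[int]:
    # SEARCH: rank every chunk by inner product with the query embedding
    # (exhaustive search; an assumption made here to avoid needing hnswlib).
    ranked = sorted(
        range(len(chunk_embs)),
        key=lambda i: inner_product(query_emb, chunk_embs[i]),
        reverse=True,
    )
    chunk_ids = ranked[:retrieve_top_k]

    # RERANK: score only the retrieved candidates. The resulting positions
    # are relative to `candidates`, not to the full chunk list.
    candidates = [chunk_texts[i] for i in chunk_ids]
    positions = sorted(
        range(len(candidates)),
        key=lambda pos: score_fn(candidates[pos]),
        reverse=True,
    )[:rerank_top_k]

    # Map candidate positions back to chunk IDs.
    return [chunk_ids[pos] for pos in positions]


chunk_texts = [
    "cats purr",
    "word embeddings map words to vectors",
    "embeddings",
    "transformers use attention",
]
chunk_embs = [[0.0, 1.0], [0.9, 0.1], [1.0, 0.0], [0.2, 0.8]]  # toy 2-d vectors
query_words = {"word", "embeddings"}


def keyword_score(text: str) -> float:
    # Hypothetical scorer: number of query words that appear in the text.
    return float(len(query_words & set(text.split())))


result = search_and_rerank_ids([1.0, 0.0], chunk_embs, chunk_texts, keyword_score)
print(result)  # [1, 2]
```

Dense search alone ranks chunk 2 first, but the second stage promotes chunk 1 because it matches both query words. This is why the class retrieves more chunks (`retrieve_top_k`) than it finally returns (`rerank_top_k`).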

### Define and process documents

In [4]:
# Define the sources for the documents
# As an example, we'll use LLM University's Module 1: What are Large Language Models?
# https://docs.cohere.com/docs/intro-large-language-models

sources = [
    {
        "title": "Text Embeddings",
        "url": "https://docs.cohere.com/docs/text-embeddings",
    },
    {
        "title": "Similarity Between Words and Sentences",
        "url": "https://docs.cohere.com/docs/similarity-between-words-and-sentences",
    },
    {
        "title": "The Attention Mechanism",
        "url": "https://docs.cohere.com/docs/the-attention-mechanism",
    },
    {
        "title": "Transformer Models",
        "url": "https://docs.cohere.com/docs/transformer-models",
    },
]

# Create an instance of the Datastore class with the given sources
datastore = Datastore(sources)

Loading documents...
Embedding document chunks...
Indexing documents...
Indexing complete with 136 documents.
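
The `embed()` method sends the chunks to the Embed endpoint in batches of 90. As a small illustration of that batching, the sketch below splits 136 placeholder items (the count reported above) the same way. The `batched` helper is a hypothetical name, not part of the notebook.

```python
from typing import Iterator, List


def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        # Slicing past the end of a list is safe in Python, so no min() is needed.
        yield items[start : start + batch_size]


chunks = [f"chunk {i}" for i in range(136)]
batch_sizes = [len(batch) for batch in batched(chunks, 90)]
print(batch_sizes)  # [90, 46]
```

With 136 chunks, this results in two embedding requests: one with 90 texts and one with the remaining 46.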


In [5]:
# Test retrieving documents from the datastore
datastore.search_and_rerank("word embeddings")

[{'title': 'Similarity Between Words and Sentences',
  'text': 'In the previous chapter, I explained the concept of word embeddings. In a nutshell, a word embedding is an assignment of a list of numbers (vector) to every word, in a way that semantic properties of the word translate into mathematical properties of the numbers. What do we mean by this? For example, two similar words will have similar vectors, and two different words will have different vectors. But most importantly, each entry in the vector corresponding to a word keeps track of some property ',
  'url': 'https://docs.cohere.com/docs/similarity-between-words-and-sentences'},
 {'title': 'The Attention Mechanism',
  'text': 'In the previous chapters, you learned about word and sentence embeddings and similarity between words and sentences. In short, a word embedding is a way to associate words with lists of numbers (vectors) in such a way that similar words are associated with numbers that are close by, and dissimilar word

### Chatbot component

In [45]:
class Chatbot:
    def __init__(self, datastore: Datastore):
        """
        Initializes an instance of the Chatbot class.

        Parameters:
        datastore (Datastore): An instance of the Datastore class.

        """
        self.datastore = datastore
        self.conversation_id = str(uuid.uuid4())

    def run(self):
        """
        Runs the chatbot application.

        """
        while True:
            # Get the user message
            message = input("User: ")

            # Typing "quit" ends the conversation
            if message.lower() == "quit":
                print("Ending chat.")
                break
            else:
                print(f"User: {message}")

            # Generate search queries, if any
            response_queries = co.chat(message=message, search_queries_only=True)

            if response_queries.search_queries:
                print("Retrieving information...", end="")

                # Get the text of each generated query
                queries = []
                for search_query in response_queries.search_queries:
                    queries.append(search_query["text"])

                # Retrieve documents for each query
                chunks = []
                for query in queries:
                    chunks.extend(self.datastore.search_and_rerank(query))
            
                response = co.chat(
                    message=message,
                    documents=chunks,
                    conversation_id=self.conversation_id,
                    stream=True,
                )

            else:
                response = co.chat(
                    message=message,
                    conversation_id=self.conversation_id,
                    stream=True,
                )

            # Print the chatbot response
            print("\nChatbot:")
            
            citations_flag = False
            
            for event in response:
                                
                # Text
                if event.event_type == "text-generation":
                    print(event.text, end="")

                # Citations
                if event.event_type == "citation-generation":
                    if not citations_flag:
                        print("\n\nCITATIONS:")
                        citations_flag = True
                    print(event.citations[0])
            
            # Documents
            if citations_flag:
                print("\n\nDOCUMENTS:")
                documents = [{'id': doc['id'],
                                'text': doc['text'][:50] + '...',
                                'title': doc['title'],
                                'url': doc['url']} 
                                for doc in response.documents]
                for doc in documents:
                    print(doc)

            print(f"\n{'-'*100}\n")
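
Each citation printed by `run()` carries `start` and `end` character offsets into the response text, plus the IDs of the supporting documents. The sketch below shows one way those offsets could be used to mark cited spans inline. The `annotate` helper is hypothetical, and the sample text and citation are taken from the example run shown in the next section.

```python
from typing import Dict, List


def annotate(text: str, citations: List[Dict]) -> str:
    """Append a [doc ids] marker after each cited span."""
    # Work from the last citation to the first so that inserting a marker
    # does not shift the offsets of citations still to be processed.
    for citation in sorted(citations, key=lambda c: c["start"], reverse=True):
        marker = "[" + ", ".join(citation["document_ids"]) + "]"
        text = text[: citation["end"]] + marker + text[citation["end"] :]
    return text


text = "My name is Coral! I'm a large language model (LLM) developed by Cohere"
citations = [
    {"start": 64, "end": 70, "text": "Cohere", "document_ids": ["doc_0", "doc_1", "doc_2"]}
]

# The offsets are a half-open slice into the response text.
print(text[64:70])  # Cohere
print(annotate(text, citations))
```

The `end` offset is exclusive, so `text[start:end]` reproduces the citation's `text` field exactly.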

### Run the chatbot

In [46]:
# Create an instance of the Chatbot class with the Datastore instance
chatbot = Chatbot(datastore)

# Run the chatbot
chatbot.run()

User: Hello

Chatbot:
Hi there! How can I help you today?

If you would like, you can skip the small talk and get right to what you need assistance with. I'm a chatbot trained to be helpful, respectful and truthful, and I aim to make this conversation a pleasant and meaningful one for you :) 

Let me know if I can help you with anything specific, or you can even present me with a challenge!
----------------------------------------------------------------------------------------------------

User: who are you
Retrieving information...
Chatbot:
My name is Coral! I'm a large language model (LLM) developed by Cohere, and I'm designed to have polite, helpful, and inclusive conversations with users like you. I'm powered by Command, a powerful LLM built by Cohere.

Currently, I'm only fluent in English, but models like me can be trained to understand and interpret various languages, too!

CITATIONS:
{'start': 64, 'end': 70, 'text': 'Cohere', 'document_ids': ['doc_0', 'doc_1', 'doc_2']}


DOCU