# How to Build a RAG-Powered Chatbot with Chat, Embed, and Rerank

*Read the accompanying [blog post here](https://txt.cohere.com/rag-chatbot).*

![Feature](https://github.com/cohere-ai/notebooks/blob/main/notebooks/images/rag-chatbot.png?raw=1)

In this notebook, you’ll learn how to build a chatbot with RAG capabilities: it connects to external documents, grounds its responses in those documents, and cites them in its responses.

Below is a diagram that provides an overview of what we’ll build, followed by a list of the key steps involved.

![Overview](https://github.com/cohere-ai/notebooks/blob/main/notebooks/images/rag-chatbot-flow.png?raw=1)

Setup phase:
- Step 0: Ingest the documents – get documents, chunk, embed, and index.

For each user-chatbot interaction:
- Step 1: Get the user message
- Step 2: Call the Chat endpoint in query-generation mode
- If at least one query is generated
    - Step 3: Retrieve and rerank relevant documents
    - Step 4: Call the Chat endpoint in document mode to generate a grounded response with citations
- If no query is generated
    - Step 4: Call the Chat endpoint in normal mode to generate a response

Throughout the conversation:
- Append the user-chatbot interaction to the conversation thread
- Repeat with every interaction
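
The per-interaction steps above can be sketched as plain control flow, independent of any API. In this minimal sketch, `generate_queries`, `retrieve`, and `respond` are hypothetical stand-ins for the query-generation call, the retrieve-and-rerank step, and the final Chat call; the stubs exist only to exercise the branching. The notebook's actual implementation follows in the `Chatbot` class.

```python
from typing import Callable, Dict, List, Optional

Doc = Dict[str, str]

def handle_turn(
    message: str,
    generate_queries: Callable[[str], List[str]],
    retrieve: Callable[[str], List[Doc]],
    respond: Callable[[str, Optional[List[Doc]]], str],
) -> str:
    """Run one user-chatbot interaction (Steps 1-4)."""
    # Step 2: ask the model whether the message needs any search queries
    queries = generate_queries(message)

    if queries:
        # Step 3: retrieve (and rerank) documents for every generated query
        documents: List[Doc] = []
        for query in queries:
            documents.extend(retrieve(query))
        # Step 4 (document mode): respond grounded on the retrieved documents
        return respond(message, documents)

    # Step 4 (normal mode): no query was generated, so respond directly
    return respond(message, None)


# Hypothetical stubs, used only to exercise the control flow
def fake_queries(message: str) -> List[str]:
    return [message] if message.endswith("?") else []

def fake_retrieve(query: str) -> List[Doc]:
    return [{"title": "stub", "text": f"passage about {query}", "url": ""}]

def fake_respond(message: str, documents: Optional[List[Doc]]) -> str:
    mode = "document" if documents else "normal"
    return f"[{mode} mode] reply to: {message}"

print(handle_turn("Hello", fake_queries, fake_retrieve, fake_respond))
print(handle_turn("What is HNSW?", fake_queries, fake_retrieve, fake_respond))
```

Passing the three steps in as callables keeps the branching logic testable without network access.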

In [1]:
! pip install openai tiktoken


In [2]:
! pip install PyPDF2


In [3]:
! pip install cohere hnswlib unstructured -q

In [4]:

import os
import uuid
from io import BytesIO
from typing import List, Dict

import cohere
import hnswlib
import PyPDF2
import requests

# Set COHERE_API_KEY in your environment before running this cell.
# Never hard-code a real API key in a notebook.
co = cohere.Client(os.environ["COHERE_API_KEY"])



In [5]:
#@title Enable text wrapping in Google Colab

from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

### Documents component

In [6]:
class Documents:
    """
    A class representing a collection of documents.

    Parameters:
    sources (list): A list of dictionaries representing the sources of the documents. Each dictionary should have 'title' and 'url' keys.

    Attributes:
    sources (list): A list of dictionaries representing the sources of the documents.
    docs (list): A list of dictionaries representing the document chunks, with 'title', 'text', and 'url' keys.
    docs_embs (list): A list of the associated embeddings for the documents.
    retrieve_top_k (int): The number of documents to retrieve during search.
    rerank_top_k (int): The number of documents to keep after reranking.
    docs_len (int): The number of documents in the collection.
    idx (hnswlib.Index): The index used for document retrieval.

    Methods:
    load(): Downloads the PDFs from the sources and splits the extracted text into chunks.
    embed(): Embeds the documents using the Cohere API.
    index(): Indexes the documents for efficient retrieval.
    retrieve(query): Retrieves and reranks documents for the given query.

    """

    def __init__(self, sources: List[Dict[str, str]]):
        self.sources = sources
        self.docs = []
        self.docs_embs = []
        self.retrieve_top_k = 10
        self.rerank_top_k = 3
        self.load()
        self.embed()
        self.index()

    def is_title_line(self, line):
      # A simple heuristic to identify title lines; adjust it for your documents.
      return line.isupper() and len(line) > 10

    def chunk_by_title(self, text, max_tokens=300):
      chunks = []
      current_chunk = []
      current_token_count = 0

      for line in text.split('\n'):
          line_tokens = line.split()
          line_token_count = len(line_tokens)

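          # Start a new chunk if adding this line would exceed the word budget
          # (skip the flush when the current chunk is empty, to avoid emitting empty chunks)
          if current_chunk and current_token_count + line_token_count > max_tokens:
              chunks.append('\n'.join(current_chunk))
              current_chunk = []
              current_token_count = 0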

          if self.is_title_line(line):
              if current_chunk:
                  chunks.append('\n'.join(current_chunk))
                  current_chunk = []
                  current_token_count = 0

          current_chunk.append(line)
          current_token_count += line_token_count

      if current_chunk:
          chunks.append('\n'.join(current_chunk))

      return chunks


    def load(self) -> None:
        """
        Downloads each source PDF, extracts its text page by page, and chunks it.
        """
        print("Loading documents...")

        for source in self.sources:
            response = requests.get(source["url"], timeout=30)
            if response.status_code != 200:
                # Skip this source but keep loading the others
                print(f"Failed to retrieve the PDF: {source['url']}")
                continue

            reader = PyPDF2.PdfReader(BytesIO(response.content))
            for page in reader.pages:
                text = page.extract_text()
                if not text:
                    continue
                for chunk in self.chunk_by_title(text):
                    self.docs.append(
                        {
                            "title": source["title"],
                            "text": chunk,
                            "url": source["url"],
                        }
                    )

    def embed(self) -> None:
        """
        Embeds the documents using the Cohere API.
        """
        print("Embedding documents...")

        batch_size = 90
        self.docs_len = len(self.docs)

        for i in range(0, self.docs_len, batch_size):
            batch = self.docs[i : min(i + batch_size, self.docs_len)]
            texts = [item["text"] for item in batch]
            docs_embs_batch = co.embed(
                texts=texts, model="embed-english-v3.0", input_type="search_document"
            ).embeddings
            self.docs_embs.extend(docs_embs_batch)

    def index(self) -> None:
        """
        Indexes the documents for efficient retrieval.
        """
        print("Indexing documents...")

        self.idx = hnswlib.Index(space="ip", dim=1024)
        self.idx.init_index(max_elements=self.docs_len, ef_construction=512, M=64)
        self.idx.add_items(self.docs_embs, list(range(len(self.docs_embs))))

        print(f"Indexing complete with {self.idx.get_current_count()} documents.")

    def retrieve(self, query: str) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the given query.

        Parameters:
        query (str): The query to retrieve documents for.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the retrieved documents, with 'title', 'text', and 'url' keys.
        """
        docs_retrieved = []
        query_emb = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings

        doc_ids = self.idx.knn_query(query_emb, k=self.retrieve_top_k)[0][0]

        docs_to_rerank = []
        for doc_id in doc_ids:
            docs_to_rerank.append(self.docs[doc_id]["text"])

        rerank_results = co.rerank(
            query=query,
            documents=docs_to_rerank,
            top_n=self.rerank_top_k,
            model="rerank-english-v2.0",
        )

        doc_ids_reranked = []
        for result in rerank_results:
            doc_ids_reranked.append(doc_ids[result.index])

        for doc_id in doc_ids_reranked:
            docs_retrieved.append(
                {
                    "title": self.docs[doc_id]["title"],
                    "text": self.docs[doc_id]["text"],
                    "url": self.docs[doc_id]["url"],
                }
            )

        return docs_retrieved
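
The `retrieve` method is a two-stage search: a wide vector search (`retrieve_top_k` candidates) followed by a narrower rerank (`rerank_top_k` results). The sketch below reproduces that shape in pure Python on toy data. It is an illustration under stated assumptions, not the notebook's implementation: the exact inner-product scan stands in for the approximate `hnswlib` search, and `rerank_score` is a hypothetical callable standing in for the Rerank endpoint.

```python
from typing import Callable, List, Sequence

def inner_product(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def retrieve_then_rerank(
    query_emb: Sequence[float],
    doc_embs: List[Sequence[float]],
    rerank_score: Callable[[int], float],
    retrieve_top_k: int = 10,
    rerank_top_k: int = 3,
) -> List[int]:
    """Return document ids after a wide vector search and a narrower rerank."""
    # Stage 1: exact inner-product search (hnswlib does this approximately)
    k = min(retrieve_top_k, len(doc_embs))
    by_similarity = sorted(
        range(len(doc_embs)),
        key=lambda i: inner_product(query_emb, doc_embs[i]),
        reverse=True,
    )[:k]
    # Stage 2: reorder the candidates with a second scorer and keep the best few
    by_relevance = sorted(by_similarity, key=rerank_score, reverse=True)
    return by_relevance[:rerank_top_k]

# Toy data: 2-dimensional "embeddings" and a made-up reranker score per document id
doc_embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.7, 0.7]]
toy_rerank_scores = {0: 0.2, 1: 0.9, 2: 0.1, 3: 0.5}

top = retrieve_then_rerank(
    query_emb=[1.0, 0.0],
    doc_embs=doc_embs,
    rerank_score=lambda i: toy_rerank_scores[i],
    retrieve_top_k=3,
    rerank_top_k=2,
)
print(top)  # [1, 3]
```

Document 0 is the nearest neighbour by inner product, yet it drops out after reranking, which is the reason for running the second stage.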

### Chatbot component

In [7]:
class Chatbot:
    """
    A class representing a chatbot.

    Parameters:
    docs (Documents): An instance of the Documents class representing the collection of documents.

    Attributes:
    conversation_id (str): The unique ID for the conversation.
    docs (Documents): An instance of the Documents class representing the collection of documents.

    Methods:
    generate_response(message): Generates a response to the user's message.
    retrieve_docs(response): Retrieves documents based on the search queries in the response.

    """

    def __init__(self, docs: Documents):
        self.docs = docs
        self.conversation_id = str(uuid.uuid4())

    def generate_response(self, message: str):
        """
        Generates a response to the user's message.

        Parameters:
        message (str): The user's message.

        Yields:
        Event: A streamed response event (text chunks and, in document mode, citations).

        """
        # Generate search queries (if any)
        response = co.chat(message=message, search_queries_only=True)

        # If there are search queries, retrieve documents and respond
        if response.search_queries:
            print("Retrieving information...")

            documents = self.retrieve_docs(response)

            response = co.chat(
                message=message,
                documents=documents,
                conversation_id=self.conversation_id,
                stream=True,
            )
            for event in response:
                yield event

        # If there is no search query, directly respond
        else:
            response = co.chat(
                message=message,
                conversation_id=self.conversation_id,
                stream=True
            )
            for event in response:
                yield event

    def retrieve_docs(self, response) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the search queries in the response.

        Parameters:
        response: The response object containing search queries.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the retrieved documents.

        """
        # Get the query(s)
        queries = []
        for search_query in response.search_queries:
            queries.append(search_query["text"])

        # Retrieve documents for each query
        retrieved_docs = []
        for query in queries:
            retrieved_docs.extend(self.docs.retrieve(query))

        # # Uncomment this code block to display the chatbot's retrieved documents
        # print("DOCUMENTS RETRIEVED:")
        # for idx, doc in enumerate(retrieved_docs):
        #     print(f"doc_{idx}: {doc}")
        # print("\n")

        return retrieved_docs
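
Because `retrieve_docs` runs one retrieval per search query and concatenates the results, the same chunk can appear more than once when queries overlap. Below is a minimal sketch of one way to de-duplicate the list before passing it to the Chat endpoint. `dedupe_docs` is a hypothetical helper, not part of the notebook or the Cohere SDK, and the sample documents are invented.

```python
from typing import Dict, List

def dedupe_docs(docs: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Drop repeated chunks, keeping the first occurrence and the original order."""
    seen = set()
    unique = []
    for doc in docs:
        key = (doc["url"], doc["text"])  # identify a chunk by its source and text
        if key not in seen:
            seen.add(key)
            unique.append(doc)
    return unique

docs = [
    {"title": "A", "text": "chunk 1", "url": "https://example.com/a.pdf"},
    {"title": "B", "text": "chunk 2", "url": "https://example.com/b.pdf"},
    {"title": "A", "text": "chunk 1", "url": "https://example.com/a.pdf"},
]
unique_docs = dedupe_docs(docs)
print([doc["title"] for doc in unique_docs])  # ['A', 'B']
```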

### App component

In [8]:
class App:
    def __init__(self, chatbot: Chatbot):
        """
        Initializes an instance of the App class.

        Parameters:
        chatbot (Chatbot): An instance of the Chatbot class.

        """
        self.chatbot = chatbot

    def run(self):
        """
        Runs the chatbot application.

        """
        while True:
            # Get the user message
            message = input("User: ")

            # Typing "quit" ends the conversation
            if message.lower() == "quit":
                print("Ending chat.")
                break
            else:
                print(f"User: {message}")

            # Get the chatbot response
            response = self.chatbot.generate_response(message)

            # Print the chatbot response
            print("Chatbot:")
            flag = False
            for event in response:
                # Text
                if event.event_type == "text-generation":
                    print(event.text, end="")

                # Citations
                if event.event_type == "citation-generation":
                    if not flag:
                        print("\n\nCITATIONS:")
                        flag = True
                    print(event.citations)

            print(f"\n{'-'*100}\n")
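
The event loop in `App.run` can be exercised without calling the API. The sketch below feeds a similar loop hand-built events that expose the same `event_type`, `text`, and `citations` attributes the loop reads. The event contents are invented for illustration, and real streaming objects may carry additional fields; `collect_stream` is a hypothetical helper that accumulates output instead of printing it.

```python
from types import SimpleNamespace
from typing import Iterable, List, Tuple

def collect_stream(events: Iterable) -> Tuple[str, List]:
    """Accumulate streamed text and citations instead of printing them."""
    text_parts = []
    citations = []
    for event in events:
        if event.event_type == "text-generation":
            text_parts.append(event.text)
        elif event.event_type == "citation-generation":
            citations.extend(event.citations)
        # Any other event type (e.g. stream start/end) is ignored
    return "".join(text_parts), citations

# Invented events that mimic the attributes read by App.run
fake_events = [
    SimpleNamespace(event_type="stream-start"),
    SimpleNamespace(event_type="text-generation", text="Text lines "),
    SimpleNamespace(event_type="text-generation", text="are segmented first."),
    SimpleNamespace(
        event_type="citation-generation",
        citations=[{"start": 0, "end": 10, "text": "Text lines", "document_ids": ["doc_0"]}],
    ),
    SimpleNamespace(event_type="stream-end"),
]

text, citations = collect_stream(fake_events)
print(text)
print(citations)
```

Returning the text and citations, rather than printing inside the loop, makes it easier to reuse the same logic in a web UI or a test.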

### Define the documents

In [9]:
# Define the sources for the documents
# As an example, we'll use a set of arXiv papers, each downloaded as a PDF

sources = [
    {
        "title": "Text Line Segmentation of Historical Documents: a Survey",
        "url": "https://arxiv.org/pdf/0704.1267.pdf"
    },
    {
        "title": "Riemannian level-set methods for tensor-valued data",
        "url": "https://arxiv.org/pdf/0705.0214.pdf"
    },
    {
        "title": "Multiresolution Approximation of Polygonal Curves in Linear Complexity",
        "url": "https://arxiv.org/pdf/0705.0449.pdf"
    },
    {
        "title": "Medical Image Segmentation and Localization using Deformable Templates",
        "url": "https://arxiv.org/pdf/0705.0781.pdf"
    },
    {
        "title": "Enhancement of Noisy Planar Nuclear Medicine Images using Mean Field Annealing",
        "url": "https://arxiv.org/pdf/0705.0828.pdf"
    },
    {
        "title": "An Independent Evaluation of Subspace Face Recognition Algorithms",
        "url": "https://arxiv.org/pdf/0705.0952.pdf"
    },
    {
        "title": "MI image registration using prior knowledge",
        "url": "https://arxiv.org/pdf/0705.3593.pdf"
    },
    {
        "title": "Automatic Detection of Pulmonary Embolism using Computational Intelligence",
        "url": "https://arxiv.org/pdf/0706.0300.pdf"
    },
    {
        "title": "Variational local structure estimation for image super-resolution",
        "url": "https://arxiv.org/pdf/0709.1771.pdf"
    },
    {
        "title": "Bandwidth selection for kernel estimation in mixed multi-dimensional spaces",
        "url": "https://arxiv.org/pdf/0709.1920.pdf"
    },
    {
        "title": "Supervised learning on graphs of spatio-temporal similarity in satellite image sequences",
        "url": "https://arxiv.org/pdf/0709.3013.pdf"
    },
    {
        "title": "Graph rigidity, Cyclic Belief Propagation and Point Pattern Matching",
        "url": "https://arxiv.org/pdf/0710.0043.pdf"
    },
    {
        "title": "High-Order Nonparametric Belief-Propagation for Fast Image Inpainting",
        "url": "https://arxiv.org/pdf/0710.0243.pdf"
    },
    {
        "title": "An Affinity Propagation Based method for Vector Quantization Codebook Design",
        "url": "https://arxiv.org/pdf/0710.2037.pdf"
    },
    {
        "title": "Comparison and Combination of State-of-the-art Techniques for Handwritten Character Recognition: Topping the MNIST Benchmark",
        "url": "https://arxiv.org/pdf/0710.2231.pdf"
    }
]


### Process the documents

In [10]:
# Create an instance of the Documents class with the given sources
documents = Documents(sources)

Loading documents...
Embedding documents...
Indexing documents...
Indexing complete with 391 documents.
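
The count reported above is the number of chunks, not the number of PDFs: every page of every paper is split by `chunk_by_title`. To see how the title heuristic and the word budget interact, here is a standalone sketch of the same chunking logic on a toy string. Note that "tokens" here are whitespace-separated words, not model tokens. The sample text and `max_tokens=12` are made up for illustration, and this version folds the two flush conditions into one check that never emits an empty chunk.

```python
from typing import List

def is_title_line(line: str) -> bool:
    # Same heuristic as Documents.is_title_line: an all-caps line longer than 10 characters
    return line.isupper() and len(line) > 10

def chunk_by_title(text: str, max_tokens: int = 300) -> List[str]:
    chunks, current, count = [], [], 0
    for line in text.split("\n"):
        n = len(line.split())
        # Start a new chunk at a title line or when the word budget would be exceeded
        if current and (is_title_line(line) or count + n > max_tokens):
            chunks.append("\n".join(current))
            current, count = [], 0
        current.append(line)
        count += n
    if current:
        chunks.append("\n".join(current))
    return chunks

sample = (
    "1. INTRODUCTION\n"
    "Text lines are a basic unit of layout.\n"
    "Segmenting them is a common first step.\n"
    "2. RELATED WORK\n"
    "Earlier surveys cover printed documents."
)
chunks = chunk_by_title(sample, max_tokens=12)
for chunk in chunks:
    print(repr(chunk))
```

With this input the function yields three chunks: the first heading with its first sentence, the second sentence on its own (it would have exceeded the 12-word budget), and the second heading with its sentence.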


### Run the chatbot

In [None]:
# Create an instance of the Chatbot class with the Documents instance
chatbot = Chatbot(documents)

# Create an instance of the App class with the Chatbot instance
app = App(chatbot)

# Run the chatbot
app.run()