# How to Build a RAG-Powered Chatbot with Chat, Embed, and Rerank

*Read the accompanying [blog post here](https://txt.cohere.com/rag-chatbot).*

![Feature](https://github.com/cohere-ai/notebooks/blob/main/notebooks/images/rag-chatbot.png?raw=1)

In this notebook, you’ll learn how to build a chatbot that has RAG capabilities, enabling it to connect to external documents, ground its responses on these documents, and produce document citations in its responses.

Below is a diagram that provides an overview of what we’ll build, followed by a list of the key steps involved.

![Overview](https://github.com/cohere-ai/notebooks/blob/main/notebooks/images/rag-chatbot-flow.png?raw=1)

Setup phase:
- Step 0: Ingest the documents – get documents, chunk, embed, and index.

For each user-chatbot interaction:
- Step 1: Get the user message
- Step 2: Call the Chat endpoint in query-generation mode
- If at least one query is generated
    - Step 3: Retrieve and rerank relevant documents
    - Step 4: Call the Chat endpoint in document mode to generate a grounded response with citations
- If no query is generated
    - Step 4: Call the Chat endpoint in normal mode to generate a response

Throughout the conversation:
- Append the user-chatbot interaction to the conversation thread
- Repeat with every interaction

In [44]:
! pip install cohere hnswlib unstructured lime -q

In [45]:
import cohere
import os
import hnswlib
import json
import uuid
from typing import List, Dict
from unstructured.partition.html import partition_html
from unstructured.chunking.title import chunk_by_title

import lime
import re

co = cohere.Client("LWOoWCmAD3YSNqSNUG70VgveqeNAs5Q2j3cNcoDD")

## Lime

In [46]:
import sklearn
import numpy as np
import sklearn
import sklearn.ensemble
import sklearn.metrics

from sklearn.datasets import fetch_20newsgroups

In [47]:
import lime
import sklearn
import numpy as np
import sklearn
import sklearn.ensemble
import sklearn.metrics
from __future__ import print_function

from sklearn.datasets import fetch_20newsgroups
categories = ['alt.atheism', 'soc.religion.christian']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
class_names = ['atheism', 'christian']

vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(lowercase=False)
train_vectors = vectorizer.fit_transform(newsgroups_train.data)
test_vectors = vectorizer.transform(newsgroups_test.data)

rf = sklearn.ensemble.RandomForestClassifier(n_estimators=500)
rf.fit(train_vectors, newsgroups_train.target)


prompt = "Classifying new articles between Atheism and Christianity"

pred = rf.predict(test_vectors)
sklearn.metrics.f1_score(newsgroups_test.target, pred, average='binary')


from lime import lime_text
from sklearn.pipeline import make_pipeline
c = make_pipeline(vectorizer, rf)

print(c.predict_proba([newsgroups_test.data[0]]))


from lime.lime_text import LimeTextExplainer
explainer = LimeTextExplainer(class_names=class_names)



idx = 83


exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6)
print('Document id: %d' % idx)
print('Probability(christian) =', c.predict_proba([newsgroups_test.data[idx]])[0,1])
print('True class: %s' % class_names[newsgroups_test.target[idx]])


# using output from here
exp.as_list()


print('Original prediction:', rf.predict_proba(test_vectors[idx])[0,1])
tmp = test_vectors[idx].copy()
tmp[0,vectorizer.vocabulary_['Posting']] = 0
tmp[0,vectorizer.vocabulary_['Host']] = 0
print('Prediction removing some features:', rf.predict_proba(tmp)[0,1])
print('Difference:', rf.predict_proba(tmp)[0,1] - rf.predict_proba(test_vectors[idx])[0,1])



all_explanations = []

for idx in range(10):
    # Explain the instance
    exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6)

    # Extract the top three feature weights
    top_features = sorted(exp.as_list(), key=lambda x: abs(x[1]), reverse=True)[:3]

    # Store the results
    all_explanations.append({
        'probability_christian': c.predict_proba([newsgroups_test.data[idx]])[0,1],
        'true_class': class_names[newsgroups_test.target[idx]],
        'top_features': top_features
    })

print(all_explanations)

[[0.304 0.696]]
Document id: 83
Probability(christian) = 0.424
True class: atheism
Original prediction: 0.424
Prediction removing some features: 0.682
Difference: 0.25800000000000006
[{'probability_christian': 0.696, 'true_class': 'christian', 'top_features': [('article', -0.078300159574246), ('au', -0.052026541580377964), ('deleted', -0.023290267290638107)]}, {'probability_christian': 0.662, 'true_class': 'christian', 'top_features': [('morality', -0.0539254567963974), ('alt', -0.024569979106880337), ('atheist', -0.019500083271421782)]}, {'probability_christian': 0.21, 'true_class': 'atheism', 'top_features': [('Keith', -0.10664864563822533), ('Re', -0.07006049475826955), ('California', -0.060473063276174854)]}, {'probability_christian': 0.876, 'true_class': 'christian', 'top_features': [('Christ', 0.025473411718284086), ('rutgers', 0.023079071237483352), ('Christians', 0.022271285304222493)]}, {'probability_christian': 0.406, 'true_class': 'atheism', 'top_features': [('Posting', -0.0

In [48]:
#@title Enable text wrapping in Google colab

from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

### Documents component

In [49]:
class Documents:
    """
    A class representing a collection of documents.

    Parameters:
    sources (list): A list of dictionaries representing the sources of the documents. Each dictionary should have 'title' and 'url' keys.

    Attributes:
    sources (list): A list of dictionaries representing the sources of the documents.
    docs (list): A list of dictionaries representing the documents, with 'title', 'content', and 'url' keys.
    docs_embs (list): A list of the associated embeddings for the documents.
    retrieve_top_k (int): The number of documents to retrieve during search.
    rerank_top_k (int): The number of documents to rerank after retrieval.
    docs_len (int): The number of documents in the collection.
    index (hnswlib.Index): The index used for document retrieval.

    Methods:
    load(): Loads the data from the sources and partitions the HTML content into chunks.
    embed(): Embeds the documents using the Cohere API.
    index(): Indexes the documents for efficient retrieval.
    retrieve(query): Retrieves documents based on the given query.

    """

    def __init__(self, sources: List[Dict[str, str]]):
        self.sources = sources
        self.docs = []
        self.docs_embs = []
        self.retrieve_top_k = 10
        self.rerank_top_k = 3
        self.load()
        self.embed()
        self.index()

    def load(self) -> None:
        """
        Loads the documents from the sources and chunks the HTML content.
        """
        print("Loading documents...")

        for source in self.sources:
            elements = partition_html(url=source["url"])
            chunks = chunk_by_title(elements)
            for chunk in chunks:
                self.docs.append(
                    {
                        "title": source["title"],
                        "text": str(chunk),
                        "url": source["url"],
                    }
                )

    def embed(self) -> None:
        """
        Embeds the documents using the Cohere API.
        """
        print("Embedding documents...")

        batch_size = 90
        self.docs_len = len(self.docs)

        for i in range(0, self.docs_len, batch_size):
            batch = self.docs[i : min(i + batch_size, self.docs_len)]
            texts = [item["text"] for item in batch]
            docs_embs_batch = co.embed(
                texts=texts, model="embed-english-v3.0", input_type="search_document"
            ).embeddings
            self.docs_embs.extend(docs_embs_batch)

    def index(self) -> None:
        """
        Indexes the documents for efficient retrieval.
        """
        print("Indexing documents...")

        self.idx = hnswlib.Index(space="ip", dim=1024)
        self.idx.init_index(max_elements=self.docs_len, ef_construction=512, M=64)
        self.idx.add_items(self.docs_embs, list(range(len(self.docs_embs))))

        print(f"Indexing complete with {self.idx.get_current_count()} documents.")

    def retrieve(self, query: str) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the given query.

        Parameters:
        query (str): The query to retrieve documents for.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the retrieved documents, with 'title', 'text', and 'url' keys.
        """
        docs_retrieved = []
        query_emb = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings

        doc_ids = self.idx.knn_query(query_emb, k=self.retrieve_top_k)[0][0]

        docs_to_rerank = []
        for doc_id in doc_ids:
            docs_to_rerank.append(self.docs[doc_id]["text"])

        rerank_results = co.rerank(
            query=query,
            documents=docs_to_rerank,
            top_n=self.rerank_top_k,
            model="rerank-english-v2.0",
        )

        doc_ids_reranked = []
        for result in rerank_results:
            doc_ids_reranked.append(doc_ids[result.index])

        for doc_id in doc_ids_reranked:
            docs_retrieved.append(
                {
                    "title": self.docs[doc_id]["title"],
                    "text": self.docs[doc_id]["text"],
                    "url": self.docs[doc_id]["url"],
                }
            )

        return docs_retrieved

### Chatbot component

In [50]:
class Chatbot:
    """
    A class representing a chatbot.

    Parameters:
    docs (Documents): An instance of the Documents class representing the collection of documents.

    Attributes:
    conversation_id (str): The unique ID for the conversation.
    docs (Documents): An instance of the Documents class representing the collection of documents.

    Methods:
    generate_response(message): Generates a response to the user's message.
    retrieve_docs(response): Retrieves documents based on the search queries in the response.

    """

    def __init__(self, docs: Documents):
        self.docs = docs
        self.conversation_id = str(uuid.uuid4())

    def generate_response(self, message: str):
        """
        Generates a response to the user's message.

        Parameters:
        message (str): The user's message.

        Yields:
        Event: A response event generated by the chatbot.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the retrieved documents.

        """
        # Generate search queries (if any)
        response = co.chat(message=message, search_queries_only=True)

        # If there are search queries, retrieve documents and respond
        if response.search_queries:
            print("Retrieving information...")

            documents = self.retrieve_docs(response)

            response = co.chat(
                message=message,
                documents=documents,
                conversation_id=self.conversation_id,
                stream=True,
            )
            for event in response:
                yield event

        # If there is no search query, directly respond
        else:
            response = co.chat(
                message=message,
                conversation_id=self.conversation_id,
                stream=True
            )
            for event in response:
                yield event

    def retrieve_docs(self, response) -> List[Dict[str, str]]:
        """
        Retrieves documents based on the search queries in the response.

        Parameters:
        response: The response object containing search queries.

        Returns:
        List[Dict[str, str]]: A list of dictionaries representing the retrieved documents.

        """
        # Get the query(s)
        queries = []
        for search_query in response.search_queries:
            queries.append(search_query["text"])

        # Retrieve documents for each query
        retrieved_docs = []
        for query in queries:
            retrieved_docs.extend(self.docs.retrieve(query))

        # # Uncomment this code block to display the chatbot's retrieved documents
        # print("DOCUMENTS RETRIEVED:")
        # for idx, doc in enumerate(retrieved_docs):
        #     print(f"doc_{idx}: {doc}")
        # print("\n")

        return retrieved_docs

### App component

In [51]:
class App:
    def __init__(self, chatbot: Chatbot):
        """
        Initializes an instance of the App class.

        Parameters:
        chatbot (Chatbot): An instance of the Chatbot class.

        """
        self.chatbot = chatbot

    def run(self):
        """
        Runs the chatbot application.

        """
        while True:
            # Get the user message
            message = input("User: ")

            # Typing "quit" ends the conversation
            if message.lower() == "quit":
                print("Ending chat.")
                break
            else:
                print(f"User: {message}")

            # Get the chatbot response
            response = self.chatbot.generate_response(message)

            # Print the chatbot response
            print("Chatbot:")
            flag = False
            for event in response:
                # Text
                if event.event_type == "text-generation":
                    print(event.text, end="")

                # Citations
                if event.event_type == "citation-generation":
                    if not flag:
                        print("\n\nCITATIONS:")
                        flag = True
                    print(event.citations)

            print(f"\n{'-'*100}\n")

### Run the chatbot

In [92]:
message = """I am trying to gain additional context regarding explainablilty of model predictions.
                        Can you suggest about 50-100 wikipedia links (just the links)
                        that can be helpful in providing more contextual information
                        for the following model purpose:""" + prompt
response = co.chat(message=message)

In [73]:
print(response.text)

Sure, here are a few Wikipedia links that might provide some context about the classification of articles between Atheism and Christianity:

1. https://en.wikipedia.org/wiki/Atheism
2. https://en.wikipedia.org/wiki/Christianity
3. https://en.wikipedia.org/wiki/Religion
4. https://en.wikipedia.org/wiki/Atheist
5. https://en.wikipedia.org/wiki/Agnosticism
6. https://en.wikipedia.org/wiki/List_of_atheists
7. https://en.wikipedia.org/wiki/List_of_Christians
8. https://en.wikipedia.org/wiki/Freethinking
9. https://en.wikipedia.org/wiki/Secularism
10. https://en.wikipedia.org/wiki/Rational_atheism

Note that these links provide a wide range of information, from definitions and distinctions between concepts, to lists of examples and groups relevant to the topic. This variety of sources can help provide a more holistic understanding of the context necessary to classify articles between Atheism and Christianity. 

Would you like me to provide more focused recommendations based on this prelimina

In [93]:
import re

def extract_links(text):
    # Define a regular expression pattern to match URLs
    url_pattern = r'https?://\S+|www\.\S+'

    # Use re.findall() to extract all URLs from the text
    links = re.findall(url_pattern, text)

    return links

extract_links(response.text)


['https://en.wikipedia.org/wiki/Atheism',
 'https://en.wikipedia.org/wiki/Christianity',
 'https://en.wikipedia.org/wiki/Category:Religion_and_philosophy',
 'https://en.wikipedia.org/wiki/Category:Belief_systems',
 'https://en.wikipedia.org/wiki/Comparison_of_religions',
 'https://en.wikipedia.org/wiki/List_of_atheistic_philosophers',
 'https://en.wikipedia.org/wiki/List_of_Christian_philosophers',
 'https://en.wikipedia.org/wiki/Religion_and_science',
 'https://en.wikipedia.org/wiki/Rationalism',
 'https://en.wikipedia.org/wiki/Empiricism']

In [94]:
def create_dict(links):
  res =[]
  for link in links:
    res.append(
        {
            'title' : link,
            'url' : link
        })
  return res

import requests

def filter_urls(url_list):
    existing_urls = []

    for url in url_list:
        try:
            response = requests.head(url, allow_redirects=True)
            if response.status_code == 200:
                existing_urls.append(url)
        except requests.RequestException as e:
          continue

    return existing_urls

In [95]:
links = extract_links(response.text)
sources = create_dict(filter_urls(links))
documents = Documents(sources)

Loading documents...
Embedding documents...
Indexing documents...
Indexing complete with 2054 documents.


In [96]:
summaries = ""
for link in links[:2]:
  response = co.chat(message=f'Summarize all the data in this link: {link} in 100 words or less')
  summaries+=response.text
  print(link)

https://en.wikipedia.org/wiki/Atheism
https://en.wikipedia.org/wiki/Christianity


In [118]:
# Create an instance of the Chatbot class with the Documents instance
chatbot = Chatbot(documents)

final_explanations = ""
for i in all_explanations:
  final_explanations += str(i)

message = "Given a machine learning model with the following purpose: " + prompt + "and the following explanations for the features" + final_explanations + "can you explain why this model contextually makes sense with the following references: "   + summaries

response = chatbot.generate_response(message=message)


In [108]:
len(summaries)

2309

In [119]:
for event in response:
  try:
    if event.event_type =='text-generation':
      print(event.text, end="")

  except:
    continue


Retrieving information...
This model appears to be a classifier attempting to distinguish between articles likely written from the perspective of an atheist and articles likely written from the perspective of a Christian. 

Atheism is a disbelief in the existence of deities. More broadly, atheists reject the belief that any deities exist. In this sense, atheists argue that since there is an absence of evidence for the existence of deities, it can be concluded that they do not exist. Atheists often promote the scientific method and empirical evidence as the best ways to understand the universe, and they argue that religious beliefs should be subject to questioning and criticism like any other belief system. 

Christianity, on the other hand, is based on the teachings of Jesus Christ, who Christians believe to be the Son of God and the savior of humanity. Christians make up a significant portion of the global population and can be found in all parts of the world. Christianity plays a sig