# RAG Pipeline With Keras NLP, MongoDB and OpenAI

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mongodb-developer/GenAI-Showcase/blob/main/notebooks/rag/rag_pipeline_kerasnlp_mongodb_gemma2.ipynb)

## Set Up Libraries

In [5]:
# Install all deps
!pip --quiet install keras
!pip --quiet install keras-nlp
!pip --quiet install --upgrade --quiet datasets pandas pymongo
!pip --quiet install openai

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/327.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━[0m [32m204.8/327.5 kB[0m [31m6.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m327.5/327.5 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.6/75.6 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.3/58.3 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25h

## Set Up Environment Variables

In [53]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["OPENAI_API_KEY"] = ""

## Data Loading

In [23]:
# Load Dataset
import pandas as pd
from datasets import load_dataset

# Make sure you have an Hugging Face token(HF_TOKEN) in your development environemnt before running the code below
# How to get a token: https://huggingface.co/docs/hub/en/security-tokens

# https://huggingface.co/datasets/MongoDB/embedded_movies
dataset = load_dataset(
    "MongoDB/subset_arxiv_papers_with_embeddings", split="train", streaming=True
)
dataset = dataset.take(4000)

# Convert the dataset to a pandas dataframe
dataset_df = pd.DataFrame(dataset)

dataset_df.head(5)

Repo card metadata block was not found. Setting CardData to empty.


Unnamed: 0,id,submitter,authors,title,comments,journal-ref,doi,report-no,categories,license,abstract,versions,update_date,authors_parsed,embedding
0,704.0001,Pavel Nadolsky,"C. Bal\'azs, E. L. Berger, P. M. Nadolsky, C.-...",Calculation of prompt diphoton production cros...,"37 pages, 15 figures; published version","Phys.Rev.D76:013009,2007",10.1103/PhysRevD.76.013009,ANL-HEP-PR-07-12,hep-ph,,A fully differential calculation in perturba...,"[{'version': 'v1', 'created': 'Mon, 2 Apr 2007...",2008-11-26,"[[Balázs, C., ], [Berger, E. L., ], [Nadolsky,...","[0.0594153292, -0.0440569334, -0.0487333685, -..."
1,704.0002,Louis Theran,Ileana Streinu and Louis Theran,Sparsity-certifying Graph Decompositions,To appear in Graphs and Combinatorics,,,,math.CO cs.CG,http://arxiv.org/licenses/nonexclusive-distrib...,"We describe a new algorithm, the $(k,\ell)$-...","[{'version': 'v1', 'created': 'Sat, 31 Mar 200...",2008-12-13,"[[Streinu, Ileana, ], [Theran, Louis, ]]","[0.0247399714, -0.065658465, 0.0201423876, -0...."
2,704.0003,Hongjun Pan,Hongjun Pan,The evolution of the Earth-Moon system based o...,"23 pages, 3 figures",,,,physics.gen-ph,,The evolution of Earth-Moon system is descri...,"[{'version': 'v1', 'created': 'Sun, 1 Apr 2007...",2008-01-13,"[[Pan, Hongjun, ]]","[0.0491479263, 0.0728017688, 0.0604138002, 0.0..."
3,704.0004,David Callan,David Callan,A determinant of Stirling cycle numbers counts...,11 pages,,,,math.CO,,We show that a determinant of Stirling cycle...,"[{'version': 'v1', 'created': 'Sat, 31 Mar 200...",2007-05-23,"[[Callan, David, ]]","[0.0389556214, -0.0410280302, 0.0410280302, -0..."
4,704.0005,Alberto Torchinsky,Wael Abu-Shammala and Alberto Torchinsky,From dyadic $\Lambda_{\alpha}$ to $\Lambda_{\a...,,"Illinois J. Math. 52 (2008) no.2, 681-689",,,math.CA math.FA,,In this paper we show how to compute the $\L...,"[{'version': 'v1', 'created': 'Mon, 2 Apr 2007...",2013-10-15,"[[Abu-Shammala, Wael, ], [Torchinsky, Alberto, ]]","[0.118412666, -0.0127423415, 0.1185125113, 0.0..."


## Data Cleaning

In [24]:
# Remove rows where 'abstract' or 'title' is NA or empty
dataset_df = dataset_df.dropna(subset=["abstract", "title"])

# Remove the embedding from each data point in the dataset as we are going to create new embeddings
dataset_df = dataset_df.drop(columns=["embedding"])

## Embedding Generation

In [25]:
import openai
from tqdm.notebook import tqdm

openai.api_key = os.environ["OPENAI_API_KEY"]

EMBEDDING_MODEL = "text-embedding-3-small"


def get_embedding(text):
    """Generate an embedding for the given text using OpenAI's API."""
    if not text or not isinstance(text, str):
        return None

    try:
        embedding = (
            openai.embeddings.create(input=text, model=EMBEDDING_MODEL, dimensions=1536)
            .data[0]
            .embedding
        )
        return embedding
    except Exception as e:
        print(f"Error in get_embedding: {e}")
        return None


def combine_columns(row, columns):
    """Combine the contents of specified columns into a single string."""
    return " ".join(str(row[col]) for col in columns if pd.notna(row[col]))


def apply_embedding_with_progress(df, columns):
    """Apply embedding to concatenated text from multiple dataframe columns with a progress bar."""
    if not all(col in df.columns for col in columns):
        missing_cols = [col for col in columns if col not in df.columns]
        raise ValueError(f"Columns {missing_cols} not found in the DataFrame.")

    tqdm.pandas(desc=f"Generating embeddings for columns: {', '.join(columns)}")

    # Combine specified columns
    df["combined_text"] = df.apply(lambda row: combine_columns(row, columns), axis=1)

    # Generate embeddings
    df["embedding"] = df["combined_text"].progress_apply(get_embedding)

    # Remove the temporary 'combined_text' column
    df = df.drop(columns=["combined_text"])

    return df


# Ggenerate embeddings based on 'abstract' and 'title' columns
try:
    # Ensure 'embedding' column is dropped if it exists
    dataset_df = dataset_df.drop(columns=["embedding"], errors="ignore")

    # Apply embeddings using multiple columns
    columns_to_embed = [
        "abstract",
        "title",
    ]  # Add or remove columns as needed (text only)
    dataset_df = apply_embedding_with_progress(dataset_df, columns_to_embed)
except Exception as e:
    print(f"An error occurred: {e}")

Generating embeddings for columns: abstract, title:   0%|          | 0/4000 [00:00<?, ?it/s]

In [30]:
# Display the first few rows of the result
dataset_df[columns_to_embed + ["embedding"]].head()

Unnamed: 0,abstract,title,embedding
0,A fully differential calculation in perturba...,Calculation of prompt diphoton production cros...,"[0.04978983476758003, -0.027831584215164185, -..."
1,"We describe a new algorithm, the $(k,\ell)$-...",Sparsity-certifying Graph Decompositions,"[0.021434221416711807, -0.030077634379267693, ..."
2,The evolution of Earth-Moon system is descri...,The evolution of the Earth-Moon system based o...,"[0.023649143055081367, 0.04319588467478752, 0...."
3,We show that a determinant of Stirling cycle...,A determinant of Stirling cycle numbers counts...,"[0.013857707381248474, -0.016583219170570374, ..."
4,In this paper we show how to compute the $\L...,From dyadic $\Lambda_{\alpha}$ to $\Lambda_{\a...,"[0.05201460048556328, 0.00613348139449954, 0.0..."


## MongoDB Vector Database and Connection Setup


MongoDB acts as both an operational and a vector database for the RAG system.
MongoDB Atlas specifically provides a database solution that efficiently stores, queries and retrieves vector embeddings.

Creating a database and collection within MongoDB is made simple with MongoDB Atlas.

1. First, register for a [MongoDB Atlas account](https://www.mongodb.com/cloud/atlas/register). For existing users, sign into MongoDB Atlas.
2. [Follow the instructions](https://www.mongodb.com/docs/atlas/tutorial/deploy-free-tier-cluster/). Select Atlas UI as the procedure to deploy your first cluster.
3. Create the database: `knowledge`.
4. Within the database ` research_papers`, create the collection ‘listings_reviews’.
5. Create a [vector search index](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure/) named vector_index for the ‘listings_reviews’ collection. This index enables the RAG application to retrieve records as additional context to supplement user queries via vector search. Below is the JSON definition of the data collection vector search index.

Your vector search index created on MongoDB Atlas should look like below:

```
{
  "fields": [
    {
      "numDimensions": 1536,
      "path": "embedding",
      "similarity": "cosine",
      "type": "vector"
    }
  ]
}

```

Follow MongoDB’s [steps to get the connection](https://www.mongodb.com/docs/manual/reference/connection-string/) string from the Atlas UI. After setting up the database and obtaining the Atlas cluster connection URI, securely store the URI within your development environment.

This guide uses Google Colab, which offers a feature for securely storing environment secrets. These secrets can then be accessed within the development environment. Specifically, the line mongo_uri = userdata.get('MONGO_URI') retrieves the URI from the secure storage.

In [None]:
os.environ["MONGO_URI"] = ""

In [31]:
import pymongo
from google.colab import userdata


def get_mongo_client(mongo_uri):
    """Establish and validate connection to the MongoDB."""

    client = pymongo.MongoClient(mongo_uri, appname="devrel.showcase.gemma2.python")

    # Validate the connection
    ping_result = client.admin.command("ping")
    if ping_result.get("ok") == 1.0:
        # Connection successful
        print("Connection to MongoDB successful")
        return client
    print("Connection to MongoDB failed")
    return None


mongo_uri = os.environ["MONGO_URI"]

if not mongo_uri:
    print("MONGO_URI not set in environment variables")

mongo_client = get_mongo_client(mongo_uri)

DB_NAME = "knowledge"
COLLECTION_NAME = "research_papers"

db = mongo_client.get_database(DB_NAME)
collection = db.get_collection(COLLECTION_NAME)

Connection to MongoDB successful


In [32]:
# Delete any existing records in the collection
collection.delete_many({})

DeleteResult({'n': 0, 'electionId': ObjectId('7fffffff000000000000002a'), 'opTime': {'ts': Timestamp(1719597926, 1), 't': 42}, 'ok': 1.0, '$clusterTime': {'clusterTime': Timestamp(1719597926, 1), 'signature': {'hash': b'\xb3\xc2\xbaK\x7f\x82\xe0m`\xea\xfa\x94H\x15/\xc7M!*i', 'keyId': 7320226449804230662}}, 'operationTime': Timestamp(1719597926, 1)}, acknowledged=True)

## Data Ingestion

In [33]:
# Ingest data into MongoDB
try:
    collection.insert_many(dataset_df.to_dict("records"))
    print("Data ingestion into MongoDB completed")
except Exception as e:
    print(f"An error occurred during data ingestion: {e}")

Data ingestion into MongoDB completed


## Vector Search Operation

In [34]:
def vector_search(user_query, collection):
    """
    Perform a vector search in the MongoDB collection based on the user query.

    Args:
    user_query (str): The user's query string.
    collection (MongoCollection): The MongoDB collection to search.

    Returns:
    list: A list of matching documents.
    """

    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search pipeline
    vector_search_stage = {
        "$vectorSearch": {
            "index": "vector_index",
            "queryVector": query_embedding,
            "path": "embedding",
            "numCandidates": 150,  # Number of candidate matches to consider
            "limit": 4,  # Return top 4 matches
        }
    }

    project_stage = {
        "$project": {
            "_id": 0,  # Exclude the _id field
            "fullplot": 1,  # Include the plot field
            "title": 1,  # Include the title field
            "genres": 1,  # Include the genres field
            "score": {
                "$meta": "vectorSearchScore"  # Include the search score
            },
        }
    }

    pipeline = [vector_search_stage, project_stage]

    # Execute the search
    results = collection.aggregate(pipeline)
    return list(results)

## Handle User Results

In [35]:
def get_search_result(query, collection):
    get_knowledge = vector_search(query, collection)

    search_result = ""
    for result in get_knowledge:
        search_result += f"Title: {result.get('title', 'N/A')}, Plot: {result.get('fullplot', 'N/A')}\n"

    return search_result

In [49]:
# Conduct query with retrival of sources
query = "Give me a recommended paper on machine learning"
source_information = get_search_result(query, collection)
combined_information = f"Query: {query}\nContinue to answer the query by using the Search Results:\n{source_information}."

print(combined_information)

Query: Give me a recommended paper on machine learning
Continue to answer the query by using the Search Results:
Title: Using Access Data for Paper Recommendations on ArXiv.org, Plot: N/A
Title: Missing Data: A Comparison of Neural Network and Expectation
  Maximisation Techniques, Plot: N/A
Title: An Adaptive Strategy for the Classification of G-Protein Coupled
  Receptors, Plot: N/A
Title: A multivariate approach to heavy flavour tagging with cascade training, Plot: N/A
.


## Keras Config and Markdown

In [39]:
import textwrap

import keras
import keras_nlp
from IPython.display import Markdown

# Run at half precision.
keras.config.set_floatx("bfloat16")


def to_markdown(text):
    text = text.replace("•", "  *")
    return Markdown(textwrap.indent(text, "> ", predicate=lambda _: True))

## Handle Response Generation and History

In [40]:
from typing import Dict, Optional


class GemmaChat:
    __START_TURN__ = "<start_of_turn>"
    __END_TURN__ = "<end_of_turn>"
    __SYSTEM_STOP__ = "<eos>"

    def __init__(
        self, model, system: str = "", history: Optional[Dict[str, str]] = None
    ):
        self.model = model
        self.system = system
        self.history_params = history or {}
        self.client = pymongo.MongoClient(
            self.history_params.get("connection_string", "mongodb://localhost:27017/")
        )
        self.db = self.client[self.history_params.get("database", "gemma_chat")]
        self.collection = self.db[self.history_params.get("collection", "chat_history")]
        self.session_id = self.history_params.get("session_id", "default_session")

    def format_message(self, message: str, prefix: str = "") -> str:
        return f"{self.__START_TURN__}{prefix}\n{message}{self.__END_TURN__}\n"

    def add_to_history(self, message: str, prefix: str = ""):
        formatted_message = self.format_message(message, prefix)
        self.collection.insert_one(
            {"session_id": self.session_id, "message": formatted_message}
        )

    def get_full_prompt(self) -> str:
        history = self.collection.find({"session_id": self.session_id}).sort("_id", 1)
        prompt = self.system + "\n" + "\n".join([item["message"] for item in history])
        return prompt

    def send_message(self, message: str) -> str:
        self.add_to_history(message, "user")
        prompt = self.get_full_prompt()
        response = self.model.generate(prompt, max_length=2048)
        result = response.replace(prompt, "").replace(self.__SYSTEM_STOP__, "")
        self.add_to_history(result, "model")
        return result

    def show_history(self):
        history = self.collection.find({"session_id": self.session_id}).sort("_id", 1)
        for item in history:
            print(item["message"])

## Gemma2 Model Initalisation

In [41]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(
    "hf://gg-tt/gemma-2-instruct-9b-keras"
)
gemma_lm.summary()

metadata.json:   0%|          | 0.00/143 [00:00<?, ?B/s]

task.json:   0%|          | 0.00/2.25k [00:00<?, ?B/s]

model.weights.h5:   0%|          | 0.00/18.5G [00:00<?, ?B/s]

vocabulary.spm:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

In [42]:
# Testing Gemma
%time result = gemma_lm.generate("What are your current capabilities?", max_length=256)
to_markdown(result)  # noqa: F821

CPU times: user 56.1 s, sys: 741 ms, total: 56.8 s
Wall time: 36.7 s


> What are your current capabilities?
> 
> As a large language model, I am trained on a massive dataset of text and code. This allows me to perform a variety of tasks, including:
> 
> * **Generating text:** I can write stories, articles, poems, and other types of creative content.
> * **Translating languages:** I can translate text from one language to another.
> * **Summarizing text:** I can provide concise summaries of long pieces of text.
> * **Answering questions:** I can answer questions based on the information I have been trained on.
> * **Coding:** I can generate and understand code in multiple programming languages.
> 
> **However, it is important to note that I am still under development and my abilities are constantly evolving.** I am not able to access real-time information or interact with the physical world. I also do not have personal opinions or beliefs.
> 
> My purpose is to assist users with their language-based tasks and provide helpful information.<end_of_turn>


## Query Gemma 2 with Retrieved Data

In [50]:
history_params = {
    "connection_string": userdata.get("MONGO_URI"),
    "database": DB_NAME,
    "collection": "chat_history",
    "session_id": "unique_session_id",
}

gemma_chat = GemmaChat(
    gemma_lm, system="You are a research assistant", history=history_params
)

In [51]:
result = gemma_chat.send_message(combined_information)
to_markdown(result)

> 
> 
> Based on your search results, I'd recommend **"Missing Data: A Comparison of Neural Network and Expectation Maximisation Techniques"**. 
> 
> Here's why:
> 
> * **Relevance to Machine Learning:** This paper directly addresses a common challenge in machine learning: handling missing data. 
> * **Comparison of Methods:** It compares two popular approaches for dealing with missing data – neural networks and Expectation Maximisation – which is valuable for understanding the strengths and weaknesses of different techniques.
> 
> 
> Let me know if you'd like to explore other papers based on specific aspects of machine learning! 
> <end_of_turn>


## View Chat History

In [52]:
gemma_chat.show_history()

<start_of_turn>user
Query: What is the best romantic movie to watch and why?
Continue to answer the query by using the Search Results:
Title: Non-Associativity of Lorentz Transformation and Associative Reflection
  Symmetric Transformation, Plot: N/A
Title: Erwin Schroedinger, Francis Crick and epigenetic stability, Plot: N/A
Title: Time and motion in physics: the Reciprocity Principle, relativistic
  invariance of the lengths of rulers and time dilatation, Plot: N/A
Title: XMM-Newton X-ray Observations of the Wolf-Rayet Binary System WR 147, Plot: N/A
.<end_of_turn>

<start_of_turn>model


It seems like you've provided me with some scientific research papers rather than movie titles!  

To recommend a great romantic movie, I need information about what kind of romance you're looking for. 

For example, do you prefer:

* **Classic romantic comedies?**
* **Heartbreaking dramas?**
* **Something lighthearted and fun?**
* **A movie with a historical setting?**


Tell me more about your tas