# Evaluating RAG quality with MLflow
This notebook demonstrates how to use MLflow to evaluate the quality of a Retrieval-Augmented Generation (RAG) system. We will:
- Split, vectorize, and index a text with ChromaDB
- Configure an MLflow model that queries the vector DB based on a user prompt and summarizes the results
- Evaluate the model's answers to a set of test questions with `mlflow.evaluate`

## Setting up the vector database

In [0]:
!pip install --upgrade torch
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install einops
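!pip install chromadb
!pip install --upgrade typing-extensions
# Also needed by later cells; openai is pinned below 1.0 because the code uses openai.ChatCompletion
!pip install -q mlflow "openai<1" python-dotenv beautifulsoup4 requests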

def is_databricks():
    try:
        dbutils
        return True
    except NameError:
        return False
  
if is_databricks():
  dbutils.library.restartPython()

In [0]:
# set up chromadb collection
import chromadb
import openai
import mlflow
import os
from dotenv import load_dotenv
import torch

chroma_client = chromadb.Client()
docs = chroma_client.create_collection("retrieval_docs", get_or_create=True)

In [0]:
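# Redefined here because restartPython() clears earlier definitions on Databricks.
def is_databricks():
    try:
        dbutils
        return True
    except NameError:
        return False

if is_databricks():
    os.environ["OPENAI_API_KEY"] = dbutils.secrets.get(scope="daniel.liden", key="OPENAI_API_KEY")
    os.environ["TRANSFORMERS_CACHE"] = "/dbfs/daniel.liden/cache/hf/"
else:
    load_dotenv()

# Set the key explicitly in both branches rather than relying on the library reading the environment.
openai.api_key = os.environ.get("OPENAI_API_KEY")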

assert (
    "OPENAI_API_KEY" in os.environ
), "Please set the OPENAI_API_KEY environment variable."

For simplicity, we'll restrict our attention to one document, the [MLflow Concepts](https://mlflow.org/docs/latest/concepts.html) page. Let's extract its text and split it into sentences.

In [0]:
# Extract text from https://mlflow.org/docs/latest/concepts.html
import requests
from bs4 import BeautifulSoup


def extract_text(url):
    response = requests.get(url)
    soup = BeautifulSoup(response.text, "html.parser")

    # remove script and style elements
    for script in soup(["script", "style"]):
        script.decompose()

    # find the header and get all text after it
    text = ""
    start_collecting = False
    for tag in soup.find_all(True):
        if tag.name == "h1" and tag.text.strip().lower() == "concepts":
            start_collecting = True
        if start_collecting:
            text += " " + tag.get_text()
    # get text
    # text = soup.get_text()

    # split into sentences (a naive split on periods)
    text = text.replace("\n", " ")
    sentences = text.split(".")
    # strip surrounding whitespace and drop fragments that are empty afterwards
    sentences = [sentence.strip() for sentence in sentences if sentence.strip()]

    return sentences


url = "https://mlflow.org/docs/latest/concepts.html"
concepts = extract_text(url)

# remove footer/navigation components
concepts = concepts[:-4]

Now we can add our texts to our ChromaDB vector database. In a production setting, it would be worthwhile to spend more time on document formatting, e.g. grouping (or omitting) code blocks and removing strings that do not contain meaningful information.
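
As a rough illustration of that kind of cleanup, the sketch below drops fragments that contain very little text and removes exact duplicates. The function name `clean_sentences`, the `min_letters` threshold, and the sample strings are all assumptions made for this example; they are not part of the notebook's pipeline.

```python
import re


def clean_sentences(sentences, min_letters=15):
    """Drop near-empty fragments and exact duplicates, keeping the original order."""
    seen = set()
    cleaned = []
    for sentence in sentences:
        # Collapse runs of whitespace left over from HTML extraction.
        normalized = re.sub(r"\s+", " ", sentence).strip()
        # Count alphabetic characters to filter out stray symbols and numbers.
        n_letters = sum(ch.isalpha() for ch in normalized)
        if n_letters < min_letters or normalized in seen:
            continue
        seen.add(normalized)
        cleaned.append(normalized)
    return cleaned


# Made-up sample input, for illustration only.
sample = [
    "MLflow Tracking is an API and UI for logging parameters",
    "  ",
    "0",
    "MLflow Tracking is an API   and UI for logging parameters",
    "py",
]
print(clean_sentences(sample))
```

A threshold on letter count is a blunt instrument; the right value depends on the documents being indexed.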

[By default](https://docs.trychroma.com/embeddings#default-all-minilm-l6-v2), ChromaDB uses the `all-MiniLM-L6-v2` model to generate embeddings from the texts; this can be changed easily.

In [0]:
docs.add(
    documents=concepts,
    ids=[f"id_{i}" for i in range(len(concepts))],
)

Now we can `peek()` at the first few entries.

In [0]:
docs.peek()

Now we can run a sample query against this database.

In [0]:
results = docs.query(
    query_texts=["How can an individual data scientist use MLFlow?"]
)
results["documents"][0]
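
The query result is a dictionary of nested lists, with one inner list per query text, which is why the cell above indexes `[0]`. The sketch below pairs each retrieved document with its id and distance. It assumes the default `include` settings (so that `distances` is present), uses a hand-written stand-in instead of a real query result, and `top_matches` is a hypothetical helper name.

```python
def top_matches(results, query_index=0):
    """Pair each retrieved document with its id and distance for one query."""
    return list(
        zip(
            results["ids"][query_index],
            results["documents"][query_index],
            results["distances"][query_index],
        )
    )


# Hand-written stand-in with the same nesting as a query result (not real output).
mock_results = {
    "ids": [["id_3", "id_7"]],
    "documents": [["first sentence", "second sentence"]],
    "distances": [[0.41, 0.58]],
}

for doc_id, text, distance in top_matches(mock_results):
    print(f"{doc_id}\t{distance:.2f}\t{text}")
```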

## Configuring the MLflow Model
We're going to write a pyfunc wrapper around an OpenAI model. We want the model to connect to the ChromaDB collection we initialized above *without* saving the collection as an artifact in the MLflow tracking system. This works here because the model runs in the same Python process as the in-memory collection. Depending on your needs, you may want to log the database as an artifact instead; we're opting for a lighter-weight approach here.

Note that we also add a `gen_context` instance method that takes the top n results from the vector database and formats them for insertion into the prompt template.
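
Stripped of the MLflow and ChromaDB plumbing, the context-formatting step amounts to joining the top n retrieved texts with a separator and substituting them into the template. The sketch below shows only that step; `build_prompt`, the shortened separator, and the abbreviated template are illustrative assumptions, not the notebook's exact strings.

```python
# Shortened separator and template, for illustration only.
SEPARATOR = "\n-----\n"

PROMPT_TEMPLATE = """Context: {context}

Question: {question}"""


def build_prompt(retrieved_texts, question, top_n=3):
    """Join the first top_n retrieved texts and insert them into the template."""
    context = SEPARATOR.join(retrieved_texts[:top_n])
    return PROMPT_TEMPLATE.format(context=context, question=question)


print(build_prompt(["a", "b", "c", "d"], "What is MLflow?"))
```

Because `str.format` does not re-parse substituted values, retrieved text that happens to contain braces is inserted verbatim.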

In [0]:
import pandas as pd

class PyfuncWithRetrieval(mlflow.pyfunc.PythonModel):
    """
    A custom MLflow model for text generation with retrieval functionality.

    Extends the mlflow.pyfunc.PythonModel class and utilizes a pre-trained
    OpenAI transformer model for text generation based on an external vector
    database for retrieval of relevant context.
    """

    def __init__(self, db_name):
        """
        Initialize an instance of the PyfuncWithRetrieval class.

        Args:
            db_name (str): The name of the external vector database for context retrieval.
        """
        self.db_name = db_name
        super().__init__()

    def load_context(self, context):
        """
        Load the MLflow context.

        Args:
            context: The MLflow context.
        """
        self.prompt_template = """
You are a question answering assistant. Answer the user question below based on the provided context.

Context: {context}

Question: {question}

If the context is not relevant to the question, respond that you have no relevant information."""

    def gen_context(self, db, prompt, top_n=3):
        """
        Generate context from the external database based on the input prompt.

        Args:
            db: The external vector database.
            prompt: The input prompt for the query.
            top_n (int, optional): Number of top results to retrieve. Defaults to 3.

        Returns:
            str: Retrieved context from the database.
        """
        results = db.query(query_texts=prompt, n_results=top_n)
        texts = results["documents"][0]
        texts = "\n-----------------------------------------------------------------------------------\n".join(
            texts
        )
        return texts

    def predict(self, context, model_input):
        """
        Generate text based on the provided model input.

        Args:
            context: The MLflow context.
            model_input: The input used for generating the text.

        Returns:
            list: A list of generated texts.
        """
        chroma_client = chromadb.Client()
        collection = chroma_client.get_collection(self.db_name)

        if isinstance(model_input, pd.DataFrame):
            model_input = model_input.values.flatten().tolist()
        elif not isinstance(model_input, list):
            model_input = [model_input]

        generated_text = []
        for input_text in model_input:
            # use a separate name so the MLflow `context` argument is not shadowed
            retrieved_context = self.gen_context(collection, input_text, top_n=3)
            prompt = self.prompt_template.format(
                context=retrieved_context, question=input_text
            )
            output = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful question-answering assistant.",
                    },
                    {"role": "user", "content": prompt},
                ],
            )
            output_text = output.choices[0].message.content
            generated_text.append(output_text)

        return generated_text

Now we can instantiate and log the model, specifying the name of the relevant ChromaDB collection as we do so.

In [0]:
#mlflow.set_experiment("retrieval-eval")


gpt_3_5_retrieval = PyfuncWithRetrieval(db_name="retrieval_docs")
with mlflow.start_run(run_name="log_model_gpt_3_5_retrieval"):
    pyfunc_model = gpt_3_5_retrieval
    artifact_path = "gpt_3_5_retrieval_model"
    gpt3_5_retrieval_model_info = mlflow.pyfunc.log_model(
        artifact_path=artifact_path,
        python_model=pyfunc_model,
    )

In [0]:
model = mlflow.pyfunc.load_model(gpt3_5_retrieval_model_info.model_uri)
print(model.predict("Where would you likely find a whale?"))
print(model.predict("Who can benefit from MLFlow?"))

## Evaluating the Retrieval-Augmented Model
Now we can use `mlflow.evaluate()` (as described in [this post](https://medium.com/@dliden/comparing-llms-with-mlflow-1c69553718df)) to try out our retrieval system on a few different prompts.

First, we set up our evaluation dataset.

In [0]:
eval_df = pd.DataFrame(
    {
        "question": [
            "Which MLflow component helps manage the machine learning workflow by logging parameters, metrics, and artifacts?",
            "True or False: MLflow Projects allow packaging data science code in a reusable format with configuration files describing its dependencies and how to run it.",
            "What syntax does the MLflow Tracking API use to reference the location of artifacts?",
            "How can large organizations use MLflow Model Registry?",
            "Which API allows deploying models in multiple flavors for diverse platforms like Docker and Apache Spark?",
            "What is the largest country in the world by area?",
            "How many legs does a spider have?",
        ]
    }
)

Lastly, we evaluate!

In [0]:
with mlflow.start_run(
    run_id=gpt3_5_retrieval_model_info.run_id,
):  # reopen the run with the stored run ID
    evaluation_results = mlflow.evaluate(
        model=f"runs:/{gpt3_5_retrieval_model_info.run_id}/{gpt3_5_retrieval_model_info.artifact_path}",
        model_type="text",
        data=eval_df,
    )

In [0]:
mlflow.load_table("eval_results_table.json")
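
The last two questions in the evaluation set are deliberately unrelated to the indexed document, so a well-behaved model should decline to answer them. One rough way to check this in the results table is a keyword heuristic, sketched below. The marker phrases, the helper names, and the sample answers are all invented placeholders for illustration, not real model output.

```python
# Phrases that suggest the model declined to answer (an assumed, incomplete list).
REFUSAL_MARKERS = ("no relevant information", "i don't know", "not relevant")


def looks_like_refusal(answer):
    """Return True if the answer contains one of the refusal phrases."""
    lowered = answer.lower()
    return any(marker in lowered for marker in REFUSAL_MARKERS)


def unexpected_behavior(rows):
    """Return questions where refusing (or not) differs from what was expected.

    rows: iterable of (question, answer, is_off_topic) tuples.
    """
    return [
        question
        for question, answer, is_off_topic in rows
        if looks_like_refusal(answer) != is_off_topic
    ]


# Placeholder answers, not real model output.
rows = [
    ("How many legs does a spider have?", "I have no relevant information.", True),
    ("How can large organizations use MLflow Model Registry?",
     "Placeholder answer about the Model Registry.", False),
    ("What is the largest country in the world by area?",
     "Placeholder answer that ignores the context.", True),
]
unexpected = unexpected_behavior(rows)
print(unexpected)
```

A substring check like this misses paraphrased refusals, so it is best treated as a first pass before reading the outputs.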


# Compare Multiple Models
Now we will compare the OpenAI model defined above to `falcon-7b-instruct`.

In [0]:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


## Set up the pyfunc model

In [0]:
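class PyfuncFalconInstr7bWithRetrieval(mlflow.pyfunc.PythonModel):
    """
    A custom MLflow model for text generation with retrieval functionality.

    Extends the mlflow.pyfunc.PythonModel class and uses the
    falcon-7b-instruct model, loaded with Transformers, to generate text
    based on context retrieved from an external vector database.
    """

    def __init__(self, db_name):
        """
        Initialize an instance of the PyfuncFalconInstr7bWithRetrieval class.

        Args:
            db_name (str): The name of the external vector database for context retrieval.
        """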
        self.db_name = db_name
        super().__init__()

    def load_context(self, context):
        """
        Load the falcon-7b-instruct model and tokenizer and set the prompt template.

        Args:
            context: The MLflow context.
        """
        model_id = "tiiuae/falcon-7b-instruct"
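        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            # pass the 4-bit settings through an explicit quantization config
            quantization_config=transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            ),
            device_map="auto",
            trust_remote_code=True,
        )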
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        self.model = pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )
        self.eos_token_id = tokenizer.eos_token_id
        # see https://huggingface.co/tiiuae/falcon-7b-instruct/discussions/1#6478508e9c1f42c1f4d8b0bf
        self.prompt_template = """Answer the question as truthfully as possible using the provided text, and if the answer is not contained within the text below, say "I don't know"

Context: {context}

{question}\n\n"""

    def gen_context(self, db, prompt, top_n=3):
        """
        Generate context from the external database based on the input prompt.

        Args:
            db: The external vector database.
            prompt: The input prompt for the query.
            top_n (int, optional): Number of top results to retrieve. Defaults to 3.

        Returns:
            str: Retrieved context from the database.
        """
        results = db.query(query_texts=prompt, n_results=top_n)
        texts = results["documents"][0]
        texts = "\n-----------------------------------------------------------------------------------\n".join(
            texts
        )
        return texts

    def predict(self, context, model_input):
        """
        Generates text based on the provided model_input using the loaded model.

        Args:
            context: The MLflow context.
            model_input: The input used for generating the text.

        Returns:
            list: A list of generated texts.
        """

        chroma_client = chromadb.Client()
        collection = chroma_client.get_collection(self.db_name)

        if isinstance(model_input, pd.DataFrame):
            model_input = model_input.values.flatten().tolist()
        elif not isinstance(model_input, list):
            model_input = [model_input]

        generated_text = []
        for input_text in model_input:
            # use a separate name so the MLflow `context` argument is not shadowed
            retrieved_context = self.gen_context(collection, input_text, top_n=3)
            prompt = self.prompt_template.format(context=retrieved_context, question=input_text)
            output = self.model(
                prompt,
                return_full_text=False,
                max_new_tokens=50,
                do_sample=True,
                top_k=10,
                num_return_sequences=1,
                eos_token_id=self.eos_token_id,
            )
            output_text = output[0]["generated_text"]
            cutoff_index = output_text.find("\n\n")
            # Keep only the text before the first blank line ('\n\n'); if there is none, keep the full text.
            short_output = (
                output_text if cutoff_index == -1 else output_text[:cutoff_index]
            )
            generated_text.append(short_output)

        return generated_text

In [0]:
# Log the model
falcon_7b_retrieval = PyfuncFalconInstr7bWithRetrieval(db_name="retrieval_docs")
with mlflow.start_run(run_name="log_model_falcon_7b_retrieval"):
    pyfunc_model = falcon_7b_retrieval
    artifact_path = "falcon_7b_retrieval_model"
    falcon_7b_retrieval_model_info = mlflow.pyfunc.log_model(
        artifact_path=artifact_path,
        python_model=pyfunc_model,
    )

In [0]:
model = mlflow.pyfunc.load_model(falcon_7b_retrieval_model_info.model_uri)

In [0]:
# Print some test outputs
print(model.predict("Where do whales live?"))
print(model.predict("Who can benefit from MLFlow?"))

## Run the comparison
Note that we do not need to re-run the gpt-3.5-turbo evaluation; its results were already logged above.

In [0]:
with mlflow.start_run(
    run_id=falcon_7b_retrieval_model_info.run_id,
):  # reopen the run with the stored run ID
    evaluation_results = mlflow.evaluate(
        model=f"runs:/{falcon_7b_retrieval_model_info.run_id}/{falcon_7b_retrieval_model_info.artifact_path}",
        model_type="text",
        data=eval_df,
    )