# Introduction

### About Retrieval Augmented Generation
Retrieval Augmented Generation (RAG) is a versatile pattern that can unlock a number of use cases requiring factual recall of information, such as querying a knowledge base in natural language.

In its simplest form, RAG requires 3 steps:

- Index knowledge base passages (once)
- Retrieve relevant passage(s) from knowledge base (for every user query)
- Generate a response by feeding the retrieved passage(s) into a large language model (for every user query)
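The three steps can be illustrated end to end with a toy, dependency-free sketch. The bag-of-words `toy_embed` function is a hypothetical stand-in for a real embedding model, and `build_prompt` stops at the prompt instead of calling an LLM. The cookbook itself uses `all-MiniLM-L6-v2` embeddings, Milvus, and a watsonx.ai model for these roles.

```python
import math
import re
from collections import Counter


def toy_embed(text: str) -> Counter:
    """Hypothetical stand-in for an embedding model: a bag-of-words count vector."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine_similarity(a: Counter, b: Counter) -> float:
    dot = sum(count * b[token] for token, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


# Step 1 (once): index the knowledge base passages
passages = [
    "Milvus is an open-source vector database.",
    "Paris is the capital of France.",
    "The Nile is a river in Africa.",
]
index = [(passage, toy_embed(passage)) for passage in passages]


def retrieve(query: str, k: int = 1) -> list[str]:
    """Step 2 (per query): return the k passages most similar to the query."""
    query_vector = toy_embed(query)
    ranked = sorted(index, key=lambda item: cosine_similarity(query_vector, item[1]), reverse=True)
    return [passage for passage, _ in ranked[:k]]


def build_prompt(query: str, contexts: list[str]) -> str:
    """Step 3 (per query): the prompt that would be sent to a large language model."""
    return "Context:\n" + "\n".join(contexts) + f"\n\nQuestion: {query}"


question = "What is the capital of France?"
print(build_prompt(question, retrieve(question)))
```

Indexing runs once, while retrieval and prompt construction run for every query, which is why the cookbook caches the prepared passages in section 2.5.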

### About cookbook 1.2: watsonx.ai tech preview + Milvus vectordb (long-form question answering)
This cookbook is a variant of [cookbook 1.1](/rag-1.1-vectordb.ipynb), focusing on long-form question answering. The main change is the use of Milvus as a vector datastore. This cookbook uses an indexing strategy (an HNSW index, see section 2.5) that allows Milvus to perform as well as ChromaDB.

Users may want to pass the answer generated with this RAG pattern to a second LLM prompt that paraphrases it according to a desired template.

### About the example dataset
The dataset used in this cookbook is a subset of [nq_open](https://huggingface.co/datasets/nq_open), an open-source question answering dataset based on contents from Wikipedia. The selected subset includes the gold standard passages to answer the queries in the dataset, which enables evaluating the retrieval quality.

You can select one of the two datasets available:
1. **nq910** - an information retrieval (a.k.a. search) dataset extracted from Google's Natural Questions dataset. This dataset is an example of extractive, short-form question answering.
2. **LongNQ** - an end-to-end retrieval and answer dataset extracted from the same NQ dataset, but focused more on abstractive, longer-form question answering. The answers were modified for fluency by IBM Research. This is the default dataset for this pattern.

These datasets are available in the [data/rag](data/rag/) folder.

**Disclaimer: to use this cookbook you need a REST API key compatible with the ibm-generative-ai SDK. Note that this API is currently in beta and will change in the future.**

### Limitations
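Data ingestion and querying can be slow, because every chunk and every question is embedded through calls to a remotely hosted embedding model (`sentence-transformers/all-minilm-l6-v2`, via the `ibm-generative-ai` API).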

### Cookbook Structure
1. Set up dependencies
2. Index knowledge base
3. Generate a retrieval-augmented response
4. Evaluate RAG performance on your data

In [1]:
# Improve code auto-completion by disabling the Jedi completer
%config Completer.use_jedi = False

# 1. Set up dependencies

### 1.1 Install the required dependencies

Note that `ibm-generative-ai` requires `python>=3.9` and `pip>=22.0.1`. Make sure these prerequisites are met before running this notebook.
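A quick way to check both prerequisites from inside the notebook is sketched below. It is a minimal sketch that assumes pip is installed as a regular package (so its version is readable with `importlib.metadata`); `version_tuple` is a hypothetical helper that only handles plain dotted versions.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def version_tuple(v: str) -> tuple[int, ...]:
    """Parse a plain dotted version such as '22.0.1' (pre-release suffixes are not supported)."""
    return tuple(int(part) for part in v.split(".")[:3])


python_ok = sys.version_info >= (3, 9)
try:
    pip_ok = version_tuple(version("pip")) >= (22, 0, 1)
except (PackageNotFoundError, ValueError):
    pip_ok = None  # pip metadata not found or unusual version string: check with `pip --version`

print(f"python>=3.9: {python_ok}, pip>=22.0.1: {pip_ok}")
```

Tuple comparison handles versions of different lengths correctly here: `(22, 0)` compares lower than `(22, 0, 1)`, and `(24, 0)` compares higher.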

In [14]:
!pip install numpy
!pip install matplotlib
!pip install python-dotenv
!pip install pandas
!pip install tqdm
!pip install unitxt
!pip install --upgrade ibm-generative-ai
!pip install pymilvus
!pip install milvus
!pip install langchain


### 1.2. Import necessary modules

In [27]:
import logging
import os
import pickle
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from genai import Client, Credentials
from genai.schema import (
    TextEmbeddingParameters,
    TextGenerationParameters,
    DecodingMethod,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from milvus import default_server  # Embedded Milvus server, started in section 2.3
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    utility,
)
from tqdm.notebook import tqdm
from unitxt import add_to_catalog
from unitxt.eval_utils import evaluate
from unitxt.metrics import MetricPipeline
from unitxt.operators import CopyFields

logging.getLogger("unitxt").setLevel(logging.ERROR)

### 1.3. Load credentials for `ibm-generative-ai`

Your `.env` file needs to have the following lines without spaces around `=`.

```
GENAI_KEY=your-genai-key
GENAI_API=your-genai-api
```

By default, `ibm-generative-ai` uses the API endpoint `https://bam-api.res.ibm.com`. To target a different Gen AI API, set `GENAI_API` in your `.env` file to your custom endpoint.

In [28]:
load_dotenv(override=True)

creds = Credentials.from_env()
if creds.api_endpoint:
    print(f"Your API endpoint is: {creds.api_endpoint}")

Your API endpoint is: https://bam-api.res.ibm.com


### 1.4. Initialize the `ibm-generative-ai` SDK client

In [29]:
client = Client(credentials=creds)

# 2. Index knowledge base

### 2.1. Load data

Select one of the two datasets available:
1. *nq910* - an information retrieval (a.k.a. search) dataset extracted from Google's Natural Questions dataset.
2. *LongNQ* - an end-to-end retrieval and answer dataset extracted from the same NQ dataset, but focused more on abstractive question answering.

In [32]:
datasets = ["LongNQ", "nq910", "LongNQ_docs"]
dataset = datasets[2]  # The current dataset to use ("LongNQ_docs")
data_root = "./data/rag"
data_dir = os.path.join(data_root, dataset)

print("Data directory:", data_dir)
print("Selected dataset:", dataset)

Data directory: ./data/rag/LongNQ_docs
Selected dataset: LongNQ_docs


In [33]:
def load_data_v1(data_dir, data_root):
    if not os.path.exists(data_dir):
        # The dataset folder is missing: try to extract it from the matching .zip archive
        from zipfile import ZipFile

        zip_path = data_dir + ".zip"
        if not os.path.exists(zip_path):
            raise FileNotFoundError(f"Neither {data_dir} nor {zip_path} exists; check `data_root` and `dataset`.")
        with ZipFile(zip_path, "r") as zObject:
            zObject.extractall(data_root)

    # Knowledge base passages
    psgs = pd.read_csv(os.path.join(data_dir, "psgs.tsv"), sep="\t", header=0)

    # Test questions, with columns renamed to `question` and `qid`
    qas = pd.read_csv(os.path.join(data_dir, "questions.tsv"), sep="\t", header=0).rename(
        columns={"text": "question", "id": "qid"}
    )

    return psgs, qas


documents, questions = load_data_v1(data_dir, data_root)


In [None]:
questions.head()

In [None]:
documents.head()


The dataset we are using is already split into self-contained passages. In this cookbook, each passage is additionally split into chunks of at most 1,000 characters, with a 20-character overlap, before it is ingested by the vector store (see section 2.5).

The size of each chunk is limited by the embedding model's context window (256 tokens for `all-MiniLM-L6-v2`); longer inputs are truncated when they are embedded.

In case your dataset requires chunking, it is recommended to chunk according to the document's structure and to include contextual metadata, such as a title, with each passage. For lengthier passages, you may need an overlapping (stride) window if there is a risk of cutting off important context. Getting chunking right usually requires some experimentation, so it helps to have a test dataset to evaluate the impact of passage chunking on retrieval quality (see section 4.2).
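A stride window with a title prefix can be sketched as follows. `chunk_with_stride` is a hypothetical helper that counts whitespace-separated words as a rough proxy for tokens; a real pipeline would count tokens with the embedding model's tokenizer. Section 2.5 of this cookbook uses LangChain's `RecursiveCharacterTextSplitter` instead, which measures length in characters.

```python
def chunk_with_stride(text: str, chunk_size: int, stride: int, title: str = "") -> list[str]:
    """Split text into windows of `chunk_size` words, starting a new window every `stride` words.

    Consecutive windows share `chunk_size - stride` words, so text near a boundary
    appears in two neighbouring chunks. If `title` is given, it is prepended to every
    chunk as contextual metadata.
    """
    if not 0 < stride <= chunk_size:
        raise ValueError("stride must satisfy 0 < stride <= chunk_size")
    words = text.split()
    if not words:
        return []
    prefix = f"{title}\n" if title else ""
    chunks = []
    start = 0
    while True:
        chunks.append(prefix + " ".join(words[start : start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # the last window reached the end of the text
        start += stride
    return chunks


print(chunk_with_stride("a b c d e f g", chunk_size=4, stride=2, title="Doc"))
```

A smaller stride gives more overlap (and more chunks to embed and store), so the stride trades index size against the risk of splitting an answer across chunks.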

### 2.2. Create embedding function

In [None]:
# Simple function that converts the texts to embeddings


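def get_embeddings(texts: list[str]) -> list[list[float]]:
    embeddings: list[list[float]] = []
    # The SDK returns a generator of responses; collect the embeddings from each of them
    for response in client.text.embedding.create(
        model_id="sentence-transformers/all-minilm-l6-v2",
        inputs=texts,
        # Truncate inputs that exceed the model's context window
        parameters=TextEmbeddingParameters(truncate_input_tokens=True),
    ):
        embeddings.extend(response.results)

    return embeddings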

### 2.3 Start Milvus

Start the Milvus embedded server.

In [None]:
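if not default_server.running:
    default_server.start()
    print("Milvus server started")
else:
    # Restart with a clean state: cleanup() removes the embedded server's local data
    default_server.stop()
    default_server.cleanup()
    default_server.start()
    print("Milvus server was already running; restarted it with a clean data directory")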

Establish a connection with the embedded server and print its version information.

In [None]:
connections.connect(host="localhost", port=default_server.listen_port)
print(utility.get_server_version())

### 2.4 Define a collection

In [None]:
COLLECTION_NAME = dataset + "_collection"
INDEX_NAME = dataset + "_index"

In [None]:
# Run if you want to drop your old data
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)
    print("Collection has been deleted")

In [None]:
# Auto-generated primary key (named `id_field` to avoid shadowing Python's built-in `id`)
id_field = FieldSchema(
    name="id",
    dtype=DataType.INT64,
    is_primary=True,
    auto_id=True,
)

# Chunk text
text = FieldSchema(
    name="text",
    dtype=DataType.VARCHAR,
    max_length=4096,
)

# Chunk embedding; dim=384 matches the output size of all-MiniLM-L6-v2
text_vector = FieldSchema(name="text_vector", dtype=DataType.FLOAT_VECTOR, dim=384)

# ID of the source document the chunk was split from
qid = FieldSchema(name="qid", dtype=DataType.INT64)

title = FieldSchema(
    name="title",
    dtype=DataType.VARCHAR,
    max_length=4096,
)

schema = CollectionSchema(
    fields=[id_field, text, text_vector, qid, title],
    description="Demo vector store",
    enable_dynamic_field=True,
)

collection = Collection(name=COLLECTION_NAME, schema=schema, using="default", shards_num=2)

### 2.5 Prepare collection

Prepare the chunk texts, embeddings, titles and document IDs (stored in the `qid` field) for insertion into the collection.

In [None]:
# Split documents into chunks of at most 1000 characters, with 20 characters of overlap
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=20, length_function=len, add_start_index=False
)


def split_and_prepare_document_new(doc_id: int, title: str, text: str):
    """Split one document into chunks and embed them; every chunk keeps the document's ID and title."""
    split_text = text_splitter.split_text(text)
    ids = [doc_id] * len(split_text)
    titles = [title] * len(split_text)
    embeddings = get_embeddings(split_text)
    return split_text, ids, titles, embeddings


def process_batch(document_list):
    """Return one (doc_id, title, chunk_text, embedding) tuple per chunk in the batch."""
    batch_results = []

    for doc_id, title, text in zip(
        document_list["id"].values.tolist(),
        document_list["title"].values.tolist(),
        document_list["text"].values.tolist(),
    ):
        for sub_text, sub_id, sub_title, sub_embedding in zip(*split_and_prepare_document_new(doc_id, title, text)):
            batch_results.append((sub_id, sub_title, sub_text, sub_embedding))
    return batch_results

In [None]:
%%time

batch_size = 10
processed_docs = []
# The cache file is specific to the selected dataset, so switching datasets does not load stale chunks
cache_filename = Path(f"data/.cache/rag-1.2-{dataset}-prepared-docs.pkl")
allow_cache = True

if allow_cache and cache_filename.exists():
    print("Prepared docs cache file exists, loading.")
    with open(cache_filename, "rb") as f:
        processed_docs = pickle.load(f)

    print("Processed docs loaded from pickle checkpoint")
else:
    for i in tqdm(range(0, len(documents), batch_size), desc="Processing Documents in Batches"):
        # find end of batch
        i_end = min(i + batch_size, len(documents))
        documents_batch = documents.iloc[i:i_end]

        # Process the batch
        processed = process_batch(documents_batch)
        processed_docs.extend(processed)

    # Save results for potential reuse
    cache_filename.parent.mkdir(exist_ok=True, parents=True)
    with open(cache_filename, "wb") as f:
        pickle.dump(processed_docs, f)

    print("Processed docs saved to pickle checkpoint")

Insert the chunk texts, embeddings, document IDs and titles into the collection.

In [None]:
if default_server.running:
    collection = Collection(COLLECTION_NAME)

    batch_size = 500
    for i in tqdm(
        range(0, len(processed_docs), batch_size),
        desc="Inserting documents batches to Milvus VectorDB",
    ):
        # find end of batch
        i_end = min(i + batch_size, len(processed_docs))
        id_l, title_l, text_l, embed_l = list(zip(*processed_docs[i:i_end]))

        data_to_insert = [text_l, embed_l, id_l, title_l]
        try:
            collection.insert(data_to_insert)
        except Exception as ex:
            print(f"Failed to insert: {ex}")
            print(title_l)
else:
    print("Milvus server is not running! Rerun related notebook cells.")
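The insert cell relies on `zip(*...)` to transpose the row tuples in `processed_docs` into the column-based layout passed to `collection.insert`: one list per field, in the order of the schema's non-auto-ID fields. The toy rows below are hypothetical and only illustrate the reshaping; no Milvus connection is needed.

```python
# Hypothetical rows in the same (doc_id, title, chunk_text, embedding) layout as `processed_docs`,
# with 3-dimensional toy embeddings instead of 384-dimensional ones
rows = [
    (101, "Title A", "first chunk of A", [0.1, 0.2, 0.3]),
    (101, "Title A", "second chunk of A", [0.4, 0.5, 0.6]),
    (202, "Title B", "only chunk of B", [0.7, 0.8, 0.9]),
]

# zip(*rows) transposes the rows into one tuple per field
id_l, title_l, text_l, embed_l = zip(*rows)

# Reorder the columns to match the schema order of the non-auto-ID fields: text, text_vector, qid, title
data_to_insert = [list(text_l), list(embed_l), list(id_l), list(title_l)]
print(data_to_insert[2])  # [101, 101, 202]
```

If the column order does not match the schema, Milvus will try to store the values in the wrong fields, so this ordering is worth double-checking whenever the schema changes.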

Create an index on the vector field (`text_vector`, which holds the embeddings).

**NOTE: use HNSW as the index type**

In [None]:
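# HNSW parameters go inside "params": M is the maximum number of links per graph node,
# efConstruction is the candidate list size used while building the graph.
# (`nlist` only applies to IVF indexes, so it is not set here.)
index_params = {
    "metric_type": "L2",
    "index_type": "HNSW",
    "params": {"M": 16, "efConstruction": 200},
}

collection.create_index(field_name="text_vector", index_params=index_params)

print("Collection index has been successfully created!")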
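HNSW is an approximate index, so its results can differ slightly from an exhaustive search. One way to estimate how much it loses is to compare Milvus results for a few queries with an exact brute-force search over the same embeddings (for example `np.array([row[3] for row in processed_docs])`). The NumPy helpers below are a hypothetical sketch; ranking by squared L2 distance gives the same order as ranking by L2 distance.

```python
import numpy as np


def exact_l2_search(query: np.ndarray, vectors: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Exact k-nearest-neighbour search; returns row indices and squared L2 distances, closest first."""
    diffs = vectors - query  # broadcasts (n, d) - (d,) -> (n, d)
    sq_dists = np.einsum("ij,ij->i", diffs, diffs)  # row-wise sum of squares
    order = np.argsort(sq_dists)[:k]
    return order, sq_dists[order]


def overlap_at_k(approx_ids, exact_ids) -> float:
    """Fraction of the exact top-k neighbours that the approximate search also returned."""
    return len(set(approx_ids) & set(exact_ids)) / len(exact_ids)


vectors = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]])
ids, dists = exact_l2_search(np.array([0.9, 0.0]), vectors, k=2)
print(ids)  # [1 0]
```

If the overlap is low, raising the search-time `ef` parameter (section 3.3) typically improves recall at the cost of slower queries.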

# 3. Generate a retrieval-augmented response to a question 

### 3.1. Setup Generative Model

In [None]:
# get the list of supported models from the API
models = pd.DataFrame(data=(model.model_dump() for model in client.model.list().results))
models = models.set_index("id")
models

In [None]:
# select generative model to use
model_id = "google/flan-ul2"

# Find model token limit
model_token_limit = models.loc[model_id].token_limits[0]["token_limit"]
print(f"Model token limit: {model_token_limit}")

In [None]:
# Set up inference parameters: greedy decoding, between 1 and 100 new tokens
parameters = TextGenerationParameters(decoding_method=DecodingMethod.GREEDY, max_new_tokens=100, min_new_tokens=1)

The input token limit depends on the selected generative model's maximum sequence length: the total number of input tokens in the RAG prompt must not exceed the model's maximum sequence length minus the number of desired output tokens. The number of passages retrieved as context directly affects the number of tokens in the prompt.

In [None]:
# Input token limit = model token limit - max_new_tokens (tokens to be generated) - 1 extra token
input_token_limit = model_token_limit - parameters.max_new_tokens - 1
print(f"Input token limit: {input_token_limit}")

### 3.2. Select a question

In [None]:
qidx = 2
# Make sure the question ends with a single question mark
question_text = questions.question[qidx].strip("?") + "?"
print(question_text)

### 3.3. Retrieve relevant context

That is, fetch the passages most similar to the question.

In [None]:
collection.load()
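# For HNSW, `ef` is the size of the candidate list explored at search time; it must be at least `limit`
search_params = {"metric_type": "L2", "params": {"ef": 10}}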


@dataclass
class RetrievedContext:
    id: int
    text: str
    title: str
    distance: Optional[float] = None


def query_documents(question_text: str, n_results=5) -> list[RetrievedContext]:
    question_embedding = get_embeddings([question_text])[0]
    response = collection.search(
        data=[question_embedding],
        anns_field="text_vector",
        param=search_params,
        limit=n_results,  # number of chunks to return
        expr=None,
        output_fields=[
            "qid",
            "text",
            "title",
        ],  # name of the field to retrieve from the search result
        consistency_level="Strong",
    )
    results = []
    for raw_results in response:
        for document in raw_results:
            results.append(
                RetrievedContext(
                    id=document.entity.get("qid"),
                    text=document.entity.get("text"),
                    title=document.entity.get("title"),
                    distance=document.distance,
                )
            )
    return results

In [None]:
relevant_documents = query_documents(question_text)
pd.DataFrame(relevant_documents).set_index("id")

### 3.4. Feed the context and the question to the `genai` model

In [None]:
# Token counting function
def token_count(doc: str):
    response = list(client.text.tokenization.create(input=[doc], model_id=model_id))
    return response[0].results[0].token_count

`prompt_template` creates a prompt from the given context and question. Changing the prompt can sometimes yield much more appropriate answers, or it can degrade quality significantly. The template below is best suited to short-form extractive use cases.

`make_prompt` truncates the context if the full prompt would exceed the model's input token limit. Passages are added in order of increasing distance, so the least similar passages (largest distance) are dropped or cut first. This is helpful when the embedded passages differ in length.

In [None]:
def prompt_template(context, question_text):
    return (
        f'Please answer the question using the context provided. If the question is unanswerable, say "unanswerable". Question: {question_text}.\n\n'
        + "Context:\n\n"
        + f"{context}:\n\n"
        + f'Question: {question_text}. If the question is unanswerable, say "unanswerable".'
    )


def make_prompt(
    relevant_documents: list[RetrievedContext],
    question_text: str,
    max_input_tokens: int,
):
    context = "\n\n\n".join(doc.text for doc in relevant_documents)
    prompt = prompt_template(context, question_text)

    prompt_token_count = token_count(prompt)

    if prompt_token_count <= max_input_tokens:
        return prompt

    print("Exceeded input token limit, truncating context:", prompt_token_count)
    # The template and question also consume tokens, so only the remainder is available for context
    context_budget = max_input_tokens - token_count(prompt_template("", question_text))

    # Documents with lower distance scores (closer matches) are included in the truncated context first
    sorted_documents = sorted(relevant_documents, key=lambda doc: doc.distance)

    truncated_context = ""
    token_count_so_far = 0

    for doc in sorted_documents:
        doc_token_count = token_count(doc.text)

        if token_count_so_far + doc_token_count <= context_budget:
            truncated_context += doc.text + "\n\n\n"
            token_count_so_far += doc_token_count
        else:
            # Approximate cut: keep `remaining_tokens` characters, which is usually
            # fewer tokens than `remaining_tokens`
            remaining_tokens = max(context_budget - token_count_so_far, 0)
            truncated_context += doc.text[:remaining_tokens]
            break

    return prompt_template(truncated_context, question_text)
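The greedy packing step inside `make_prompt` can be exercised offline, without API calls, by passing in a stand-in token counter. `Passage`, `pack_context` and `whitespace_tokens` are hypothetical names for this sketch; in the notebook, token counts come from `client.text.tokenization.create` via `token_count`.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Passage:
    text: str
    distance: float


def pack_context(passages: list[Passage], budget: int, count_tokens: Callable[[str], int]) -> str:
    """Greedily add passages, closest (lowest distance) first, until `budget` tokens are used.

    The first passage that does not fit is cut to the remaining budget (by characters,
    like `make_prompt`), and packing stops there.
    """
    context, used = "", 0
    for passage in sorted(passages, key=lambda p: p.distance):
        n_tokens = count_tokens(passage.text)
        if used + n_tokens <= budget:
            context += passage.text + "\n\n\n"
            used += n_tokens
        else:
            context += passage.text[: budget - used]
            break
    return context


def whitespace_tokens(text: str) -> int:
    """Hypothetical stand-in tokenizer: counts whitespace-separated words."""
    return len(text.split())


docs = [Passage("far away passage text", 0.9), Passage("closest passage", 0.1), Passage("middle one here", 0.5)]
print(repr(pack_context(docs, budget=6, count_tokens=whitespace_tokens)))
```

With a budget of 6 words, the two closest passages fit (2 + 3 words) and the farthest one is cut to a single character, which shows why the least similar passages are the ones that get truncated.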

In [None]:
prompt = make_prompt(relevant_documents, question_text, input_token_limit)
print(prompt)

**Generate response**

In [None]:
responses = list(client.text.generation.create(model_id=model_id, inputs=prompt, parameters=parameters))
response = responses[0].results[0]

In [None]:
print("Question = ", question_text)
print("Answer = ", response.generated_text)
print(
    "Expected Answer(s) (may not appear with exact wording in the dataset) = ",
    questions.answers[qidx],
)
print(f"Question index: {qidx}")

# 4. Evaluate RAG performance on your data

This step requires a test dataset that includes, for each question:
- The IDs of the passage(s) that contain the answer, i.e. the gold-standard passages (if the question is answerable by the knowledge base)
- The question's gold-standard answer (this can be short- or long-form)
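Judging from the parsing code in section 4.1, the `questions.tsv` file in this cookbook encodes these fields as follows: `relevant` holds comma-separated passage IDs (with `-1` when no passage is relevant) and `answers` holds `::`-separated answer formulations (with `-` when there is no answer). A minimal sketch of parsing helpers under that assumption, with hypothetical names:

```python
def parse_relevant_ids(raw) -> list[int]:
    """Parse comma-separated gold passage IDs, dropping the -1 placeholder (no relevant passage)."""
    ids = [int(part) for part in str(raw).split(",")]
    return [passage_id for passage_id in ids if passage_id >= 0]


def parse_answers(raw) -> list[str]:
    """Split '::'-separated gold answers; a lone '-' means there is no answer."""
    answers = str(raw).split("::")
    return ["unanswerable"] if answers == ["-"] else answers


print(parse_relevant_ids("12,40"), parse_relevant_ids(-1))  # [12, 40] []
print(parse_answers("-"))  # ['unanswerable']
```

Mapping a missing answer to the literal `"unanswerable"` lets the gold answers be compared directly with the model's output, since the prompt instructs the model to reply with that word.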

### 4.1 Retrieve context documents and generate answers for all questions
We now run the RAG pipeline on a sample of the test questions.

In [None]:
# Prepare all data for evaluation
@dataclass
class QuestionData:
    qid: int  # Question ID
    prompt: str  # Generated prompt
    question: str  # Original question
    ground_truth_contexts: list[str]  # Text content of ground truth contexts
    ground_truths_context_ids: list[int]  # IDs of ground truth contexts
    contexts: list[str]  # Retrieved contexts from vector database
    context_ids: list[int]  # IDs of retrieved contexts
    ground_truths: list[str]  # Possible ground truth formulations of answer
    answer: Optional[str] = None  # Answer from LLM (will be filled later)


# NOTE: Sampling fewer questions for example purposes.
# Change this to len(questions) for full evaluation
num_eval_questions = 250
eval_questions = questions.sample(num_eval_questions)
num_retrieve_relevant = 5

# Index passages by ID once, instead of re-indexing for every question
documents_by_id = documents.set_index("id")


def to_int(data: list[str]) -> list[int]:
    return list(map(int, data))


data_for_evaluation: list[QuestionData] = []

for _, question in tqdm(eval_questions.iterrows(), total=num_eval_questions):
    retrieved_documents = query_documents(question.question, n_results=num_retrieve_relevant)
    prompt = make_prompt(retrieved_documents, question.question, input_token_limit)
    contexts = [doc.text for doc in retrieved_documents]
    context_ids = [doc.id for doc in retrieved_documents]

    relevant_ids = to_int(str(question.relevant).split(","))
    # Drop the -1 placeholder, which marks questions with no relevant context (unanswerable)
    relevant_ids = [document_id for document_id in relevant_ids if document_id >= 0]
    ground_truth_contexts = [documents_by_id.loc[document_id].text for document_id in relevant_ids]
    ground_truths = str(question.answers).split("::")
    # '-' means that no answer exists, so the expected answer is "unanswerable"
    if ground_truths == ["-"]:
        ground_truths = ["unanswerable"]

    data_for_evaluation.append(
        QuestionData(
            qid=question.qid,
            prompt=prompt,
            question=question.question,
            ground_truth_contexts=ground_truth_contexts,
            ground_truths_context_ids=relevant_ids,
            contexts=contexts,
            context_ids=context_ids,
            ground_truths=ground_truths,
        )
    )

In [None]:
# Generate answers by LLM
inputs = [datapoint.prompt for datapoint in data_for_evaluation]
for idx, response in tqdm(
    enumerate(client.text.generation.create(model_id=model_id, inputs=inputs, parameters=parameters)),
    total=len(eval_questions),
):
    data_for_evaluation[idx].answer = response.results[0].generated_text

In [None]:
# Create final Data frame for evaluation
data_frame_for_evaluation = pd.DataFrame(data_for_evaluation).set_index("qid")
data_frame_for_evaluation

### 4.2 Evaluate Retrieval quality

There are many ways to measure retrieval quality, that is, how well the retrieved documents cover the information needed to answer the question. Here we focus on success at a given number of returns (also known as recall at k): given a fixed number of returned documents (e.g., 1, 3, 5), is the question's answer contained in them? Since retrieving more documents can only add matches, the scores never decrease as the recall level grows.
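The two common definitions can be written in a few lines of plain Python. These helpers are a hypothetical illustration of the idea, not the implementation of unitxt's `metrics.retrieval_at_k`, whose exact definitions may differ.

```python
def success_at_k(retrieved_ids: list[int], relevant_ids: list[int], k: int) -> float:
    """1.0 if at least one relevant ID is among the top-k retrieved IDs, else 0.0."""
    relevant = set(relevant_ids)
    return 1.0 if any(doc_id in relevant for doc_id in retrieved_ids[:k]) else 0.0


def recall_at_k(retrieved_ids: list[int], relevant_ids: list[int], k: int) -> float:
    """Fraction of the distinct relevant IDs found in the top-k (relevant_ids must be non-empty)."""
    relevant = set(relevant_ids)
    return len(relevant & set(retrieved_ids[:k])) / len(relevant)


retrieved = [8, 3, 5, 1, 9]  # ranked best first
relevant = [5, 9]
print([success_at_k(retrieved, relevant, k) for k in (1, 3, 5)])  # [0.0, 1.0, 1.0]
print([recall_at_k(retrieved, relevant, k) for k in (1, 3, 5)])  # [0.0, 0.5, 1.0]
```

Both scores are non-decreasing in k; success at k reaches 1 as soon as one relevant passage is found, while recall at k needs all relevant passages.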


In [None]:
# Prepare custom unitxt metrics
recall_levels = [1, 3, 5]
for recall_level in recall_levels:
    metric = MetricPipeline(
        main_score=f"recall_at_{recall_level}",
        preprocess_steps=[
            CopyFields(field_to_field=[("context_ids", "prediction")], use_query=True),
            CopyFields(
                field_to_field=[("ground_truths_context_ids", "references")],
                use_query=True,
            ),
        ],
        metric="metrics.retrieval_at_k",
    )
    add_to_catalog(metric, f"metrics.rag.recall_at_{recall_level}", overwrite=True)


def evaluate_metrics(data: pd.DataFrame, metric_names: list[str]):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        result_df = evaluate(data, metric_names=metric_names)
    return result_df

In [None]:
recall_metric_names = [f"metrics.rag.recall_at_{level}" for level in recall_levels]
# Filter out questions with no relevant context
non_empty_context_mask = data_frame_for_evaluation.ground_truth_contexts.apply(lambda x: len(x) > 0)
result_df = evaluate_metrics(data_frame_for_evaluation[non_empty_context_mask], recall_metric_names)

**Note**:
The evaluation does not take chunking into account.
Because retrieval can return several chunks of the same document, while the ground truth is not given per chunk, the recall metric can exceed 1.
Therefore, the numbers are heuristic and not directly comparable to other notebooks.
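To see how chunking can push a recall-style score above 1, consider a scorer that counts every retrieved chunk whose document ID is relevant. The functions below are hypothetical and only illustrate the effect; they are not a description of how unitxt computes `recall_at_k`. Deduplicating the retrieved document IDs before scoring is one way to keep such a score within [0, 1].

```python
def hit_count_recall_at_k(retrieved_ids: list[int], relevant_ids: list[int], k: int) -> float:
    """Counts every retrieved chunk whose document ID is relevant, including repeats."""
    relevant = set(relevant_ids)
    hits = sum(1 for doc_id in retrieved_ids[:k] if doc_id in relevant)
    return hits / len(relevant)


def dedupe_preserving_order(ids: list[int]) -> list[int]:
    """Keep the first occurrence of each document ID, preserving the ranking order."""
    return list(dict.fromkeys(ids))


# Document 7 was split into several chunks, and three of them were retrieved
retrieved_doc_ids = [7, 7, 3, 7, 9]
print(hit_count_recall_at_k(retrieved_doc_ids, [7], k=5))  # 3.0
print(hit_count_recall_at_k(dedupe_preserving_order(retrieved_doc_ids), [7], k=5))  # 1.0
```

Note that deduplicating before taking the top k changes the meaning of k from "top k chunks" to "top k distinct documents".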

In [None]:
%matplotlib inline
ax = result_df[recall_metric_names].mean().plot(kind="bar")
ax.set_xticklabels(f"Recall at {k}" for k in recall_levels)
plt.show()

### 4.3 Evaluate answered and unanswered questions
The following table breaks down the count of question/answer pairs by whether the RAG system returned an answer (rows) and whether the test dataset has an answer (columns).

In [None]:
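answer_matrix = pd.crosstab(
    # System returned an answer (case and surrounding whitespace are ignored when matching "unanswerable")
    data_frame_for_evaluation.answer.str.strip().str.lower() != "unanswerable",
    # Test dataset has an answer
    data_frame_for_evaluation.ground_truths.apply(lambda x: x != ["unanswerable"]),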
    rownames=["System"],
    colnames=["Ground Truth"],
)
answer_matrix = answer_matrix.rename({True: "Answered", False: "Not answered"})
answer_matrix = answer_matrix.rename(columns={True: "Has answer", False: "No answer"})
answer_matrix
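The table crosses whether the system answered or abstained with whether the dataset has an answer. The "Answered / No answer" cell is worth watching: those are answers to questions that the dataset marks as unanswerable. On hypothetical toy data, the same `pd.crosstab` construction looks like this; the sketch normalizes case and surrounding whitespace before comparing with `"unanswerable"`, since a model may not reproduce the marker verbatim.

```python
import pandas as pd

# Hypothetical evaluation results: model answers and gold answers
toy = pd.DataFrame(
    {
        "answer": ["Paris", "unanswerable", " Unanswerable ", "1912"],
        "ground_truths": [["Paris"], ["unanswerable"], ["Nile"], ["unanswerable"]],
    }
)
# Normalize case and whitespace before comparing with the "unanswerable" marker
system_answered = toy.answer.str.strip().str.lower() != "unanswerable"
has_answer = toy.ground_truths.apply(lambda refs: refs != ["unanswerable"])

matrix = pd.crosstab(system_answered, has_answer, rownames=["System"], colnames=["Ground Truth"])
matrix = matrix.rename({True: "Answered", False: "Not answered"})
matrix = matrix.rename(columns={True: "Has answer", False: "No answer"})
print(matrix)
```

Here each of the four cells holds one question: the "1912" row is an answer to an unanswerable question, and the padded " Unanswerable " row counts as an abstention on an answerable question.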

### 4.4 Comprehensive evaluation of retrieval quality and generated answers
We will leverage [unitxt](https://github.com/IBM/unitxt) metrics to evaluate the system more comprehensively.
Please refer to [this document](https://github.ibm.com/conversational-ai/rag-metrics/blob/d771becd557d01d9c20a7479b3883b9c40d9fde6/README.md) for a full explanation of the metrics.

(This can take several minutes.)

In [None]:
metric_names = [
    "metrics.rag.mrr",
    "metrics.rag.map",
    "metrics.rag.context_correctness",
    "metrics.rag.context_perplexity",
    "metrics.rag.context_relevance",
    "metrics.rag.faithfulness",  # Requires both context and answer, but it makes sense only for answerable questions
]
# Evaluate metrics that take into account only context on answerable questions,
# because their score for an unanswerable question is always 0
result_df = evaluate_metrics(data_frame_for_evaluation[non_empty_context_mask], metric_names)
result_df

In [None]:
result_df["metrics.rag.context_perplexity"] = result_df["metrics.rag.context_perplexity"].apply(
    lambda perplexities: np.mean(perplexities)
)
result_df[metric_names].mean()

In [None]:
metric_names = [
    "metrics.rag.answer_reward",
    "metrics.rag.answer_correctness",
]
# Evaluate metrics that take into account answers on all questions
result_df = evaluate_metrics(data_frame_for_evaluation, metric_names)
result_df

In [None]:
result_df[metric_names].mean()