# Improving Fine-tuned Retrieval Models in Okareo

<a target="_blank" href="https://colab.research.google.com/github/okareo-ai/okareo-python-sdk/blob/main/examples/retrieval_embedding_finetuning_eval.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## 🎯 Goals

After using this notebook, you will be able to:
- Evaluate a pre-trained embedding model in Okareo
- Filter the results of the retrieval evaluation
- Generate fine-tuning data based on the filtered results
- Fine-tune the model with the generated data
- Compare the performance of the embedding models pre/post fine-tuning in Okareo 

## Problem Statement: Retrieval Model

Suppose we are developing a RAG system that answers user questions about an online retailer called WebBizz.

This notebook focuses on fine-tuning an open-source embedding model, [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), for the Context Retrieval component of the RAG pipeline.

The purpose of Context Retrieval is to fetch relevant documents/chunks to build the context for a downstream generative model. The better the retrieval performance, the higher the quality of the RAG system's final output.

## Load the embedding model + WebBizz retrieval dataset

We start by loading a pre-trained `sentence_transformers` model.

In [1]:
from sentence_transformers import SentenceTransformer

# Load a model to train/finetune
huggingface_model_name = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(huggingface_model_name)



Now we load our WebBizz data, including:
- Corpus: {"id": string, "doc": string}
    - Corpus of WebBizz articles
- Queries: {"id": string, "query": string}
    - User queries
- Relevancy data: {"qid": string, "dids": List[string]}
    - Maps query ID -> relevant doc IDs

In [2]:
from datasets import load_dataset

wb_corpus = load_dataset("json", data_files="data/webbizz_corpus.jsonl")
wb_queries = load_dataset("json", data_files="data/webbizz_queries.jsonl")
wb_relevant_docs_data = load_dataset("json", data_files="data/webbizz_qrels.jsonl")

In [3]:
# Convert the datasets to dictionaries
corpus = dict(zip(wb_corpus['train']["id"], wb_corpus['train']["doc"]))  # Our corpus (cid => document)
queries = dict(zip(wb_queries['train']["id"], wb_queries['train']["query"]))  # Our queries (qid => question)
relevant_docs = {}  # Query ID to relevant documents (qid => set of relevant cids)
for qid, corpus_ids in zip(wb_relevant_docs_data['train']["query-id"], wb_relevant_docs_data['train']["corpus-id"]):
    qid = str(qid)
    corpus_ids = str(corpus_ids)
    if qid not in relevant_docs:
        relevant_docs[qid] = set()
    relevant_docs[qid].add(corpus_ids)
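
To make the resulting structure concrete, here is a minimal sketch on made-up rows. The IDs below are hypothetical and not taken from the WebBizz files, and the sketch assumes each relevancy row holds a single `query-id`/`corpus-id` pair, as the loop above implies.

```python
from collections import defaultdict

# Hypothetical relevancy rows: one (query-id, corpus-id) pair per row
toy_qrels = [
    {"query-id": "q1", "corpus-id": "d1"},
    {"query-id": "q1", "corpus-id": "d3"},
    {"query-id": "q2", "corpus-id": "d2"},
]

# Group the relevant document IDs by query ID
toy_relevant_docs = defaultdict(set)
for row in toy_qrels:
    toy_relevant_docs[str(row["query-id"])].add(str(row["corpus-id"]))

print(dict(toy_relevant_docs))
```

A query with several relevant documents ends up with a single entry whose value is the set of all its document IDs.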

## Evaluate the pre-trained model in Okareo

Here, we perform a retrieval evaluation in Okareo by:
1. Uploading the WebBizz retrieval data as a scenario in Okareo
2. Defining a CustomModel for retrieval using ChromaDB to store our model's embeddings
3. Running a retrieval evaluation on our CustomModel with the WebBizz scenario

In [4]:
import os
from okareo import Okareo

OKAREO_API_KEY = os.environ["OKAREO_API_KEY"]
okareo = Okareo(OKAREO_API_KEY)

Upload the user queries as a scenario to evaluate.

In [5]:
from okareo_api_client.models import SeedData
from okareo_api_client.models.scenario_set_create import ScenarioSetCreate

# use queries + qrels to upload eval scenario
scenario_points = [
    SeedData(
        input_=queries[qid],
        result=list(dids),
    ) for qid, dids in relevant_docs.items()
]

# create the scenario set for evaluation
create_request = ScenarioSetCreate(
    name="WebBizz Retrieval - Queries",
    seed_data = scenario_points,
)

seed_scenario = okareo.create_scenario_set(create_request)
seed_scenario.app_link

'http://localhost:3000/project/d38b3714-8c8f-4d69-8c07-cc7285bbe1b5/scenario/3048b921-36e9-4b47-a631-3922ed4c2e80'

Upload a custom retrieval model using a Chroma collection of embeddings.

In [6]:
import chromadb

chroma_client = chromadb.Client()

collection = chroma_client.create_collection(name="retrieval_test", metadata={"hnsw:space": "cosine"})

collection.add(
    documents=list(corpus.values()),
    ids=list(corpus.keys()),
    embeddings=model.encode(list(corpus.values())),
)

In [7]:
# Convert the query results from our ChromaDB collection into a list of dictionaries,
# each holding the document ID, score, metadata, and label
def query_results_to_score(results):
    parsed_ids_with_scores = []
    for i in range(0, len(results['distances'][0])):
        # The collection uses cosine distance, which lies in [0, 2]; rescale it to a score in [0, 1]
        score = (2 - results['distances'][0][i]) / 2
        parsed_ids_with_scores.append(
            {
                "id": results['ids'][0][i],
                "score": score,
                "metadata": {"article": corpus[results['ids'][0][i]]},
                "label": f"WebBizz Article w/ ID: {results['ids'][0][i]}"
            }
        )
    return parsed_ids_with_scores
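
As a quick check of the score formula, the sketch below applies the same rescaling to three reference points. It assumes the cosine distance is defined as `1 - cosine_similarity`, which is how the `cosine` space is documented for Chroma; `distance_to_score` is a hypothetical helper name used only for illustration.

```python
def distance_to_score(cosine_distance: float) -> float:
    """Map a cosine distance in [0, 2] to a score in [0, 1]."""
    return (2 - cosine_distance) / 2

# Assumed relationship: cosine distance = 1 - cosine similarity
for cos_sim in (1.0, 0.0, -1.0):
    distance = 1 - cos_sim
    print(f"similarity={cos_sim:+.1f} distance={distance:.1f} score={distance_to_score(distance):.1f}")
```

Under that assumption the score equals `(1 + cosine_similarity) / 2`, so identical directions score 1.0, orthogonal vectors 0.5, and opposite vectors 0.0.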

In [8]:
from okareo.model_under_test import CustomModel, ModelInvocation

mut_name = "Retrieval Model - all-MiniLM-L6-v2"

class RetrievalModel(CustomModel):
    def invoke(self, input: str):
        embeddings = model.encode([input])
        results = collection.query(
            query_embeddings=embeddings,
            n_results=5
        )
        # wrap the scored document IDs and the original input in a ModelInvocation
        return ModelInvocation(
            model_prediction=query_results_to_score(results),
            model_input=input,
            raw_model_output={'model_data': input}
        )

# Register the model to use in the test run
model_under_test = okareo.register_model(
    name=mut_name,
    model=[RetrievalModel(name=RetrievalModel.__name__)],
    update=True
)

In [9]:
from okareo_api_client.models import TestRunType

# Import the datetime module for timestamping
from datetime import datetime

# Define thresholds for the evaluation metrics
at_k_intervals = [1, 2, 3, 4, 5]
metrics_kwargs = {
    "accuracy_at_k": at_k_intervals,
    "precision_recall_at_k": at_k_intervals,
    "ndcg_at_k": at_k_intervals,
    "mrr_at_k": at_k_intervals,
    "map_at_k": at_k_intervals,
}

# Perform a test run using the uploaded scenario set
test_run_item = model_under_test.run_test(
    scenario=seed_scenario, # use the scenario from the scenario set uploaded earlier
    name=f"WebBizz Retrieval Test Run - {datetime.now().strftime('%m-%d %H:%M:%S')}", # add a timestamp to the test run name
    test_run_type=TestRunType.INFORMATION_RETRIEVAL, # specify that we are running an information retrieval test
    calculate_metrics=True,
    # Define the evaluation metrics to calculate
    metrics_kwargs=metrics_kwargs
)

# Generate a link back to Okareo for evaluation visualization
model_results = test_run_item.model_metrics.to_dict()
app_link = test_run_item.app_link
print(f"See results in Okareo: {app_link}")

See results in Okareo: http://localhost:3000/project/d38b3714-8c8f-4d69-8c07-cc7285bbe1b5/eval/9f6e1424-b48b-46cc-86d9-fb7416acb804
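
For intuition about what the `@k` metrics measure, here is a minimal sketch of precision, recall, and reciprocal rank at k on a single made-up ranking. These are the common textbook definitions; Okareo's exact implementation may differ in details, and the document IDs are hypothetical.

```python
def precision_recall_at_k(retrieved, relevant, k):
    """Return (precision@k, recall@k) for one query."""
    hits = sum(1 for doc_id in retrieved[:k] if doc_id in relevant)
    return hits / k, hits / len(relevant)

def reciprocal_rank_at_k(retrieved, relevant, k):
    """Return 1/rank of the first relevant document in the top k, or 0.0 if none."""
    for rank, doc_id in enumerate(retrieved[:k], start=1):
        if doc_id in relevant:
            return 1 / rank
    return 0.0

# Hypothetical ranking for one query with two relevant documents
retrieved = ["d7", "d1", "d9", "d3", "d4"]
relevant = {"d1", "d3"}

for k in (1, 2, 5):
    precision, recall = precision_recall_at_k(retrieved, relevant, k)
    rr = reciprocal_rank_at_k(retrieved, relevant, k)
    print(f"k={k}: precision={precision:.2f} recall={recall:.2f} reciprocal_rank={rr:.2f}")
```

Averaging the reciprocal rank over all queries gives MRR@k; precision and recall are averaged the same way.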


## Expand the fine-tuning set with Failure Rows

To improve our pre-trained model, we need a fine-tuning set that is similar to our WebBizz data. To build one, we extract the rows from our retrieval evaluation that meet a failure criterion, then generate new queries based on these failed rows.
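
The filtering step can be sketched on toy data as follows. The nested layout (row ID, then K as a string, then metric name) mirrors how the notebook indexes `row_level_metrics`; the row IDs, values, and the `select_failed_rows` helper are hypothetical.

```python
# Hypothetical row-level metrics: row ID -> K (as a string) -> metric name -> value
toy_row_metrics = {
    "row-a": {"5": {"recall": 1.0}},
    "row-b": {"5": {"recall": 0.5}},
    "row-c": {"5": {"recall": 0.0}},
}

def select_failed_rows(row_level_metrics, k, metric, threshold):
    """Return the IDs of rows whose metric@k is at or below the threshold."""
    return [
        row_id
        for row_id, metrics in row_level_metrics.items()
        if metrics[k][metric] <= threshold
    ]

toy_failed = select_failed_rows(toy_row_metrics, k="5", metric="recall", threshold=0.5)
print(toy_failed)
```

Because the comparison is `<=`, a row sitting exactly on the threshold counts as a failure.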

In [10]:
# filter IDs based on failure criteria
K = "5"
filter_thresh = 0.5
filter_metric = "recall"

failed_ids = []
for row_id, metrics in model_results['row_level_metrics'].items():
    if metrics[K][filter_metric] <= filter_thresh:
        failed_ids.append(row_id)
print(f"-> {len(failed_ids)} queries have {filter_metric}@{K} <= {filter_thresh}")

-> 17 queries have recall@5 <= 0.5


In [11]:
# use queries + qrels to upload failure scenario
sdp = okareo.get_scenario_data_points(seed_scenario.scenario_id)

scenario_points = []
query_to_id = {v: k for k, v in queries.items()}
for dp in sdp:
    # get the qid for failed okareo scenario IDs
    if dp.id in failed_ids:
        qid = query_to_id[dp.input_]
        dids = relevant_docs[qid]
        scenario_points.append(
            SeedData(
                input_=queries[qid],
                result=list(dids),
            )
        )

# create the scenario set for evaluation
create_request = ScenarioSetCreate(
    name="WebBizz Retrieval - Queries (Failure Rows)",
    seed_data = scenario_points,
)

failure_scenario = okareo.create_scenario_set(create_request)
failure_scenario.app_link

'http://localhost:3000/project/d38b3714-8c8f-4d69-8c07-cc7285bbe1b5/scenario/9fca01cb-7cd9-4672-93ec-ea7533b3c5bc'

In [12]:
# generate new queries based on the failure rows
from okareo_api_client.models.scenario_set_generate import ScenarioSetGenerate
from okareo_api_client.models.generation_tone import GenerationTone
from okareo_api_client.models.scenario_type import ScenarioType

generate_request = ScenarioSetGenerate(
    source_scenario_id=failure_scenario.scenario_id,
    name="WebBizz Retrieval - Queries (Rephrased Failure Rows)",
    number_examples=2,
    generation_type=ScenarioType.REPHRASE_INVARIANT,
    generation_tone=GenerationTone.INFORMAL,
)

rephrased_scenario = okareo.generate_scenario_set(generate_request)

In [13]:
# format the scenario data points properly for the huggingface trainer
# this requires that we pivot each {'query': ..., 'answers': [1, ..., N]}
# to {'query': ..., 'answer': 1}, ..., {'query': ..., 'answer': N}
sdp = okareo.get_scenario_data_points(rephrased_scenario.scenario_id)

finetuning_embedding_data = []
for dp in sdp:
    for did in dp.result:
        finetuning_embedding_data.append({'query': dp.input_, 'answer': corpus[did]})

print(len(finetuning_embedding_data))

114
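
The pivot from one query with N relevant documents to N (query, answer) pairs can be illustrated on toy data. The corpus texts, the query, and the dictionary keys of `toy_points` are made up for this sketch and are not Okareo objects.

```python
# Hypothetical corpus and scenario rows
toy_corpus = {"d1": "Return policy text", "d2": "Shipping policy text"}
toy_points = [
    {"input": "how do i send stuff back?", "result": ["d1", "d2"]},
]

# One training pair per (query, relevant document) combination
toy_pairs = [
    {"query": point["input"], "answer": toy_corpus[doc_id]}
    for point in toy_points
    for doc_id in point["result"]
]

for pair in toy_pairs:
    print(pair)
```

A query with two relevant documents therefore contributes two training rows that share the same query text.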


In [14]:
import json

file_path = f"./data/webbizz_finetuning_embedding_data.jsonl"

# write the finetuning data to a jsonl file
with open(file_path, "w") as file:
    for row in finetuning_embedding_data:
        file.write(json.dumps(row) + "\n")

## Fine-tune the Model on the Generated Scenario

Use the generated data to fine-tune the model, then repeat the evaluation steps from before and compare the performance of the two models in Okareo.

For more details on fine-tuning embedding models, see this [Hugging Face blog post on training sentence transformers](https://huggingface.co/blog/train-sentence-transformers#local-data-that-requires-pre-processing).

In [15]:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Load a model to train/finetune
huggingface_model_name = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(huggingface_model_name)

# Initialize the MultipleNegativesRankingLoss
# This loss requires pairs of queries and related document chunks
loss = MultipleNegativesRankingLoss(model)

# Load the generated fine-tuning pairs, which match the (query, answer) format our loss expects:
dataset = load_dataset("json", data_files="./data/webbizz_finetuning_embedding_data.jsonl")

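
For intuition about the loss, here is a simplified pure-Python sketch of the in-batch-negatives idea behind `MultipleNegativesRankingLoss`: each query's paired document is the positive, and every other document in the batch acts as a negative. The scale of 20 and the cosine similarity are assumed to mirror the library defaults; the 2-dimensional vectors are toy values, and this is not the library's actual implementation.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def in_batch_negatives_loss(query_vecs, doc_vecs, scale=20.0):
    """Mean cross-entropy where document i is the correct match for query i."""
    losses = []
    for i, query in enumerate(query_vecs):
        logits = [scale * cosine(query, doc) for doc in doc_vecs]
        max_logit = max(logits)  # subtract the max for numerical stability
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)

docs = [[1.0, 0.0], [0.0, 1.0]]
aligned_queries = [[1.0, 0.0], [0.0, 1.0]]   # each query matches its own document
swapped_queries = [[0.0, 1.0], [1.0, 0.0]]   # each query matches the other document

print(in_batch_negatives_loss(aligned_queries, docs))
print(in_batch_negatives_loss(swapped_queries, docs))
```

The loss is near zero when each query is closest to its own document and large when it is closer to another document in the batch, which is why duplicate documents within a batch are undesirable.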

In [16]:
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{huggingface_model_name}-webbizz",
    # Optional training parameters:
    num_train_epochs=20,
    per_device_train_batch_size=32, # smaller batch sizes for our smaller dataset
    per_device_eval_batch_size=32,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if your GPU can't handle FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # Losses using "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name=f"{huggingface_model_name}-webbizz",  # Used in W&B if `wandb` is installed
)

In [17]:
from sentence_transformers import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    loss=loss,
)
trainer.train()



{'train_runtime': 20.7249, 'train_samples_per_second': 110.013, 'train_steps_per_second': 3.86, 'train_loss': 1.2328235626220703, 'epoch': 9.5}


TrainOutput(global_step=80, training_loss=1.2328235626220703, metrics={'train_runtime': 20.7249, 'train_samples_per_second': 110.013, 'train_steps_per_second': 3.86, 'total_flos': 0.0, 'train_loss': 1.2328235626220703, 'epoch': 9.5})
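
The `global_step=80` in the output can be reproduced with simple arithmetic from the 114 training pairs and the training arguments. The sketch assumes a single device, no gradient accumulation, a partial final batch that is kept, and warmup steps obtained by rounding the ratio up; treat it as a back-of-the-envelope estimate rather than the trainer's exact logic.

```python
import math

num_examples = 114    # fine-tuning pairs written to the JSONL file
batch_size = 32       # per_device_train_batch_size
num_epochs = 20       # num_train_epochs
warmup_ratio = 0.1

# Assumption: the last, partial batch of each epoch is kept
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
# Assumption: warmup steps are the ratio of total steps, rounded up
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(steps_per_epoch, total_steps, warmup_steps)
```

With only 80 optimizer steps in total, `eval_steps`, `save_steps`, and `logging_steps` of 100 are never reached during this run.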

## Evaluate the fine-tuned model

Create a new ChromaDB collection with the updated document embeddings to use with our new fine-tuned CustomModel.

In [18]:
finetuned_collection = chroma_client.create_collection(name="retrieval_finetune_test", metadata={"hnsw:space": "cosine"})

finetuned_collection.add(
    documents=list(corpus.values()),
    ids=list(corpus.keys()),
    embeddings=model.encode(list(corpus.values())),
)

In [19]:
mut_name = "Finetuned Retrieval Model - all-MiniLM-L6-v2"

class FinetunedRetrievalModel(CustomModel):
    def invoke(self, input: str):
        embeddings = model.encode([input])
        results = finetuned_collection.query(
            query_embeddings=embeddings,
            n_results=5
        )
        # wrap the scored document IDs and the original input in a ModelInvocation
        return ModelInvocation(
            model_prediction=query_results_to_score(results),
            model_input=input,
            raw_model_output={'model_data': input}
        )

# Register the model to use in the test run
model_under_test = okareo.register_model(
    name=mut_name,
    model=[FinetunedRetrievalModel(name=FinetunedRetrievalModel.__name__)],
    update=True
)

In [20]:
# Perform a test run using the uploaded scenario set
test_run_item = model_under_test.run_test(
    scenario=seed_scenario, # use the scenario from the scenario set uploaded earlier
    name=f"WebBizz Finetuned Retrieval Test Run - {datetime.now().strftime('%m-%d %H:%M:%S')}", # add a timestamp to the test run name
    test_run_type=TestRunType.INFORMATION_RETRIEVAL, # specify that we are running an information retrieval test
    calculate_metrics=True,
    # Define the evaluation metrics to calculate
    metrics_kwargs=metrics_kwargs
)

# Generate a link back to Okareo for evaluation visualization
finetuned_model_results = test_run_item.model_metrics.to_dict()
app_link = test_run_item.app_link
print(f"See results in Okareo: {app_link}")

See results in Okareo: http://localhost:3000/project/d38b3714-8c8f-4d69-8c07-cc7285bbe1b5/eval/1c5e368f-73f4-4907-943c-43a203c66e24


In [28]:
# compare the results pre/post fine-tuning

print(f"Pre Fine-tuning | Post Fine-tuning | Difference")
for key in model_results.keys():
    if key == "row_level_metrics":
        continue
    print(f"------ {key} ------")
    for K in at_k_intervals:
        pre = model_results[key][str(K)]
        post = finetuned_model_results[key][str(K)]
        diff = post - pre
        print(f"K={K}: {pre:4.3f} | {post:4.3f} | {diff:+.3f}")

Pre Fine-tuning | Post Fine-tuning | Difference
------ Accuracy@k ------
K=1: 0.850 | 0.950 | +0.100
K=2: 0.950 | 0.967 | +0.017
K=3: 0.967 | 0.983 | +0.017
K=4: 0.983 | 1.000 | +0.017
K=5: 0.983 | 1.000 | +0.017
------ Precision@k ------
K=1: 0.850 | 0.950 | +0.100
K=2: 0.600 | 0.717 | +0.117
K=3: 0.478 | 0.578 | +0.100
K=4: 0.383 | 0.483 | +0.100
K=5: 0.333 | 0.407 | +0.073
------ Recall@k ------
K=1: 0.491 | 0.532 | +0.041
K=2: 0.631 | 0.729 | +0.098
K=3: 0.711 | 0.824 | +0.113
K=4: 0.748 | 0.882 | +0.135
K=5: 0.793 | 0.910 | +0.117
------ NDCG@k ------
K=1: 0.850 | 0.950 | +0.100
K=2: 0.757 | 0.870 | +0.113
K=3: 0.759 | 0.879 | +0.120
K=4: 0.768 | 0.899 | +0.131
K=5: 0.787 | 0.907 | +0.120
------ MRR@k ------
K=1: 0.850 | 0.950 | +0.100
K=2: 0.900 | 0.958 | +0.058
K=3: 0.906 | 0.964 | +0.058
K=4: 0.910 | 0.968 | +0.058
K=5: 0.910 | 0.968 | +0.058
------ MAP@k ------
K=1: 0.850 | 0.950 | +0.100
K=2: 0.704 | 0.842 | +0.137
K=3: 0.694 | 0.844 | +0.151
K=4: 0.697 | 0.862 | +0.165
K=5: 

Every retrieval metric reported above improves at every K after fine-tuning; for example, Accuracy@1 rises from 0.850 to 0.950 and Recall@5 from 0.793 to 0.910.
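
When printing a difference column like the one above, a signed format specifier keeps regressions easy to spot. The sketch below uses the Accuracy@1 values from the table plus a hypothetical regression; `format_comparison` is an illustrative helper, not part of Okareo.

```python
def format_comparison(k, pre, post):
    """Format one comparison row; the '+' flag prints an explicit sign for the difference."""
    diff = post - pre
    return f"K={k}: {pre:.3f} | {post:.3f} | {diff:+.3f}"

print(format_comparison(1, 0.850, 0.950))  # improvement taken from the table above
print(format_comparison(1, 0.950, 0.850))  # hypothetical regression
```

The `+` flag in `{diff:+.3f}` emits `+` for non-negative values and `-` for negative ones, so no manual sign handling is needed.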