# Improving Fine-tuned Retrieval Models in Okareo

<a target="_blank" href="https://colab.research.google.com/github/okareo-ai/okareo-python-sdk/blob/main/examples/retrieval_embedding_finetuning_eval.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## 🎯 Goals

After using this notebook, you will be able to:
- Evaluate a pre-trained embedding model in Okareo
- Filter the results of the retrieval evaluation
- Generate fine-tuning data based on the filtered results
- Fine-tune the model with the generated data
- Compare the performance of the embedding models pre/post fine-tuning in Okareo 

## Problem Statement: Retrieval Model

Suppose we are developing a RAG system that answers user questions.

This notebook focuses on finetuning an open source embedding model, [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), for the Context Retrieval component of the RAG pipeline.

The purpose of Context Retrieval is to fetch relevant documents/chunks to build the context for a downstream generative model. The better performance we achieve on Retrieval, the higher quality the RAG's final output will be.

# Evaluate the pre-trained model in Okareo

Here, we perform a retrieval evaluation in Okareo by:
1. Uploading the retrieval data as a scenario in Okareo
2. Defining a CustomModel for retrieval using ChromaDB to store our model's embeddings
3. Run a retrieval evaluation on our CustomModel with the scenario

In [1]:
import os
import random
import string
from okareo import Okareo

OKAREO_API_KEY = os.environ["OKAREO_API_KEY"]
okareo = Okareo(OKAREO_API_KEY)
random_string = ''.join(random.choices(string.ascii_letters, k=5))


In [None]:
from okareo_api_client.models import SeedData
from okareo_api_client.models.scenario_set_create import ScenarioSetCreate

import json

scenarios = {}
# LOAD BOTH TRAIN AND TEST DATA SCENARIOS
for file_info in [
    ("./data/squad_qa_train_210.jsonl", "train"), 
    ("./data/squad_qa_test_90.jsonl", "test")
]:
    filename, split = file_info
    scenario_data = []
    
    with open(filename, "r") as f:
        for line in f:
            data = json.loads(line)
            scenario_data.append(
                SeedData(
                    input_=data["input"],
                    result=[data["passage_id"]]
                )
            )

    # Create scenario
    request = ScenarioSetCreate(
        name=f"Autogen Retrieval - {split} - {random_string}",
        seed_data=scenario_data
    )
    scenario = okareo.create_scenario_set(request)
    scenarios[split] = scenario
    print(f"{split}: {scenario.app_link}")

Upload a custom retrieval model using a Chroma collection of embeddings.

In [4]:
# CREATE INSTANCE OF CHROMADB AND LOAD CORPUS
import chromadb
import pandas as pd
from chromadb.utils import embedding_functions
from okareo.model_under_test import CustomModel, ModelInvocation


corpus_file = "./data/squad_corpus.jsonl"
embedding_model_name = "all-MiniLM-L6-v2"
#embedding_model_name = "/Users/mrpositive/Downloads/ft-models/autogen-retrieval-finetune"

collection_name = "squad-corpus"
embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model_name)


def load_and_initialize_collection(collection_name: str, embedding_function: callable, corpus_file="./data/squad_corpus.jsonl"):
    # Load corpus from file
    jsonObj = pd.read_json(path_or_buf=corpus_file, lines=True)
    corpus = dict(zip(jsonObj.result, jsonObj.input))

    chroma_client = chromadb.Client(chromadb.config.Settings(allow_reset=True))
    chroma_client.reset() # chromadb has some weird'a bug that hangs on to some state

    collection = chroma_client.get_or_create_collection(collection_name,
                                                metadata={"hnsw:space": "cosine"},
                                                embedding_function=embedding_function)

    # Add the documents to the collection with the corresponding metadata
    collection.add(
        documents=list(jsonObj.input),
        ids=[str(x) for x in list(jsonObj.result)],
    )
    
    return collection, corpus

collection, corpus = load_and_initialize_collection(collection_name, embedding_function)

# A funtion to convert the query results from our ChromaDB collection into a list of dictionaries with the document ID, score, metadata, and label
def query_results_to_score(results):
    parsed_ids_with_scores = []
    for i in range(0, len(results['distances'][0])):
        # Create a score based on cosine similarity
        score = (2 - results['distances'][0][i]) / 2
        parsed_ids_with_scores.append(
            {
                "id": results['ids'][0][i],
                "score": score,
                "metadata": {"context": corpus[results['ids'][0][i]]},
                "label": f"Context w/ ID: {results['ids'][0][i]}"
            }
        )
    return parsed_ids_with_scores


mut_name = f"Retrieval Model - {embedding_model_name}"

class RetrievalModel(CustomModel):
    def invoke(self, input: str):
        results = collection.query(
            query_texts=[input],
            n_results=5
        )
        # return a tuple of (parsed_ids_with_scores, overall model response context)
        return ModelInvocation(
            model_prediction=query_results_to_score(results),
            model_input=input,
            model_output_metadata={'model_data': input}
        )

# Register the model to use in the test run
model_under_test = okareo.register_model(
    name=mut_name,
    model=[RetrievalModel(name=RetrievalModel.__name__)],
    update=True
)

  from tqdm.autonotebook import tqdm, trange


In [5]:
# EVALUATE THE RETRIEVAL MODEL
from okareo_api_client.models import TestRunType

# Import the datetime module for timestamping
from datetime import datetime

# Define thresholds for the evaluation metrics
at_k_intervals = [1, 2, 3, 4, 5] 
metrics_kwargs={
    "accuracy_at_k": at_k_intervals ,
    "precision_recall_at_k": at_k_intervals ,
    "ndcg_at_k": at_k_intervals,
    "mrr_at_k": at_k_intervals,
    "map_at_k": at_k_intervals,
}

# Perform a test run using the uploaded scenario set
test_runs = {}
for split_name, seed_scenario in scenarios.items():
    if split_name == "train":
        continue
    test_run_item = model_under_test.run_test(
        scenario=seed_scenario, # use the scenario from the scenario set uploaded earlier
        name=f"Retrieval ({split_name}) - {datetime.now().strftime('%m-%d %H:%M:%S')}", # add a timestamp to the test run name
        test_run_type=TestRunType.INFORMATION_RETRIEVAL, # specify that we are running an information retrieval test
        calculate_metrics=True,
        # Define the evaluation metrics to calculate
        metrics_kwargs=metrics_kwargs
    )

    # Generate a link back to Okareo for evaluation visualization
    model_results = test_run_item.model_metrics.to_dict()
    test_runs[split_name] = model_results
    app_link = test_run_item.app_link
    print(f"See {split_name} eval results in Okareo: {app_link}")

See test eval results in Okareo: http://localhost:3000/project/47895342-8441-426b-bc86-e8dd831d2971/eval/79598552-4731-4230-bb8c-06443ac5667f


## Expand the finetuning set with Failure Rows

To improve our finetuned model, we need a fine-tuning set that is similar to our data. To do this, we extract rows from our retrieval evaluation based on some failure criteria, and we generate new queries based on these failed rows.

In [10]:
# filter IDs based on failure criteria
K = "3"
filter_thresh = 0.5
filter_metric = "recall"

# get failure rows from train split eval
failed_ids = []
for id, metrics in test_runs['train']['row_level_metrics'].items():
    if metrics[K][filter_metric] <= filter_thresh:
        failed_ids.append(id)
print(f"-> {len(failed_ids)} queries have {filter_metric}@{K} <= {filter_thresh}")

sdp = okareo.get_scenario_data_points(scenarios['train'].scenario_id)

scenario_points = [
    SeedData(input_=dp.input_, result=dp.result)
    for dp in sdp
    if dp.id in failed_ids
]

# create the scenario set for evaluation
create_request = ScenarioSetCreate(
    name=f"Autogen Retrieval - train - failure rows - {random_string}",
    seed_data = scenario_points,
)

failure_scenario = okareo.create_scenario_set(create_request)
failure_scenario.app_link


-> 37 queries have recall@3 <= 0.5


'https://app.okareo.com/project/394c2c12-be7a-47a6-911b-d6c673bc543b/scenario/8946d33d-cd76-4a74-a2b0-aff48ce47f7e'

In [56]:
# generate new queries based on the failure rows
from okareo_api_client.models.scenario_set_generate import ScenarioSetGenerate
from okareo_api_client.models.generation_tone import GenerationTone
from okareo_api_client.models.scenario_type import ScenarioType

generate_request = ScenarioSetGenerate(
    source_scenario_id=failure_scenario.scenario_id,
    name=f"Autogen Retrieval - train - augmented failure rows - {random_string}",
    number_examples=3,
    generation_type=ScenarioType.REPHRASE_INVARIANT,
    generation_tone=GenerationTone.NEUTRAL,
)

rephrased_scenario = okareo.generate_scenario_set(generate_request)

In [None]:
import json

# format the scenario data points properly for the huggingface trainer
# this requires that we pivot each {'query': ..., 'answers': [1, ..., N]}
# to {'query': ..., 'answer': 1}, ..., {'query': ..., 'answer': N}
rephrased_sdp = okareo.get_scenario_data_points('b4b0fff6-4e80-41fa-b6dd-f6c96a1a13b2')

sdp = okareo.get_scenario_data_points(scenarios['train'].scenario_id)

# augment the train data with the rephrased failure rows
# use the augmented train set for finetuning
finetuning_embedding_data = []
for dp in sdp + rephrased_sdp:
    for did in dp.result:
        finetuning_embedding_data.append({'query': dp.input_, 'answer': corpus[did]})

print(len(finetuning_embedding_data))


file_path = f"autogen_retrieval_finetuning_embedding_data.jsonl"

# write the finetuning data to a jsonl file
with open(file_path, "w") as file:
    for row in finetuning_embedding_data:
        file.write(json.dumps(row) + "\n")

## Fine-tune the Model on the Augmented Train Split

Use the train split and the rephrased failures to repeat fine-tune the embedding model, and then compare the retrieval performance of the fine-tuned model to the pretrained model in Okareo.

For more details on fine-tuning embedding models, see this [huggingface blog on training sentence transformers](https://huggingface.co/blog/train-sentence-transformers#local-data-that-requires-pre-processing).

In [None]:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Load a model to train/finetune
huggingface_model_name = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(huggingface_model_name)

# Initialize the MultipleNegativesRankingLoss
# This loss requires pairs of queries and related document chunks
loss = MultipleNegativesRankingLoss(model)

# Load an example training dataset that works with our loss function:
dataset = load_dataset("json", data_files=file_path)

In [77]:
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

output_dir = f"models/{huggingface_model_name}/autogen-retrieval-finetune"

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Optional training parameters:
    num_train_epochs=4,
    per_device_train_batch_size=32, # reduced from 32
    per_device_eval_batch_size=32, # reduced from 32
    warmup_ratio=0.1,
    fp16=False,  # Set to False if your GPU can't handle FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # Losses using "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    learning_rate=5e-6, # added for smaller dataset
)

In [78]:
from sentence_transformers.trainer import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    loss=loss,
)
trainer.train()


Step,Training Loss
100,0.4782
200,0.2885


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

TrainOutput(global_step=212, training_loss=0.38013167313809665, metrics={'train_runtime': 17.1361, 'train_samples_per_second': 392.154, 'train_steps_per_second': 12.372, 'total_flos': 0.0, 'train_loss': 0.38013167313809665, 'epoch': 3.7735849056603774})

In [7]:
from sentence_transformers import SentenceTransformer

# Save the trained model
model.save(output_dir)

# Load the saved model
ft_model = SentenceTransformer("models/sentence-transformers/all-MiniLM-L6-v2/autogen-retrieval-finetune")

## Evaluate the fine-tuned model

Create a new ChromaDB collection with the updated document embeddings to use with our new fine-tuned CustomModel.

In [11]:
import chromadb
import pandas as pd

chroma_client = chromadb.Client(chromadb.config.Settings(allow_reset=True))
chroma_client.reset() # chromadb has some weird'a bug that hangs on to some state

collection_name = "retrieval_finetune_test"
# if collection_name in [col.name for col in chroma_client.list_collections()]:
#     chroma_client.delete_collection(collection_name)
finetuned_collection = chroma_client.create_collection(name=collection_name, metadata={"hnsw:space": "cosine"})

corpus_file = "./data/squad_corpus.jsonl"
jsonObj = pd.read_json(path_or_buf=corpus_file, lines=True)

finetuned_collection.add(
    documents=list(jsonObj.input),
    ids=[str(x) for x in list(jsonObj.result)],
    embeddings=ft_model.encode(list(jsonObj.input)),
)

In [9]:
from okareo.model_under_test import CustomModel, ModelInvocation


mut_name = f"Finetuned Retrieval Model - MiniLM-L6-v2"

class FinetunedRetrievalModel(CustomModel):
    def invoke(self, input: str):
        embeddings = ft_model.encode([input])
        results = finetuned_collection.query(
            query_embeddings=embeddings,
            n_results=5
        )
        # return a tuple of (parsed_ids_with_scores, overall model response context)
        return ModelInvocation(
            model_prediction=query_results_to_score(results),
            model_input=input,
            model_output_metadata={'model_data': input}
        )

# Register the model to use in the test run
ft_model_under_test = okareo.register_model(
    name=mut_name,
    model=[FinetunedRetrievalModel(name=FinetunedRetrievalModel.__name__)],
    update=True
)

In [None]:
from okareo_api_client.models import TestRunType

# Import the datetime module for timestamping
from datetime import datetime

# Define thresholds for the evaluation metrics
at_k_intervals = [1, 2, 3, 4, 5] 
metrics_kwargs={
    "accuracy_at_k": at_k_intervals ,
    "precision_recall_at_k": at_k_intervals ,
    "ndcg_at_k": at_k_intervals,
    "mrr_at_k": at_k_intervals,
    "map_at_k": at_k_intervals,
}

# Perform a test run using the uploaded scenario set
finetuned_test_runs = {}
for split_name, seed_scenario in scenarios.items():
    test_run_item = ft_model_under_test.run_test(
        scenario=seed_scenario, # use the scenario from the scenario set uploaded earlier
        name=f"Retrieval ({split_name}) - {datetime.now().strftime('%m-%d %H:%M:%S')}", # add a timestamp to the test run name
        test_run_type=TestRunType.INFORMATION_RETRIEVAL, # specify that we are running an information retrieval test
        calculate_metrics=True,
        # Define the evaluation metrics to calculate
        metrics_kwargs=metrics_kwargs
    )

    # Generate a link back to Okareo for evaluation visualization
    finetuned_model_results = test_run_item.model_metrics.to_dict()
    finetuned_test_runs[split_name] = finetuned_model_results
    app_link = test_run_item.app_link
    print(f"See {split_name} eval results in Okareo: {app_link}")

In [None]:
# compare the results pre/post fine-tuning

print(f"Pre Fine-tuning | Post Fine-tuning | Difference")
for split_name in ["train", "test"]:
    print(f"------ {split_name} ------")
    model_results = test_runs[split_name]
    finetuned_model_results = finetuned_test_runs[split_name]
    for key in finetuned_model_results.keys():
        if key == "row_level_metrics":
            continue
        print(f"------ {key} ------")
        for K in at_k_intervals:
            pre = model_results[key][str(K)]
            post = finetuned_model_results[key][str(K)]
            diff = post - pre
            print(f"K={K}: {pre:4.3f} | {post:4.3f} | {'+' if diff >= 0 else ''}{diff:4.3f}")