# Finetuning a Linear Adapter on Top of any Black-Box Embedding Model


LlamaIndex lets you fine-tune a linear adapter on top of embeddings produced by any model (sentence_transformers, OpenAI, and more).

This allows you to transform your embedding representations into a new latent space that's optimized for retrieval over your specific data and queries. This can lead to small increases in retrieval performance that in turn translate to better performing RAG systems.

We do this via our `EmbeddingAdapterFinetuneEngine` abstraction.
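
As a rough illustration of what "a linear adapter on top of embeddings" means, the sketch below applies a learned matrix (and an optional bias) to a vector that the black-box model has already produced. This is a minimal pure-Python sketch, not the LlamaIndex implementation; the names `apply_linear_adapter` and `identity` are hypothetical and exist only for this example.

```python
from typing import List, Optional

Vector = List[float]
Matrix = List[List[float]]


def identity(dim: int) -> Matrix:
    """Return a dim x dim identity matrix (an adapter that changes nothing)."""
    return [[1.0 if i == j else 0.0 for j in range(dim)] for i in range(dim)]


def apply_linear_adapter(
    embedding: Vector, weight: Matrix, bias: Optional[Vector] = None
) -> Vector:
    """Compute weight @ embedding (+ bias) for a single embedding vector."""
    out = [sum(w * x for w, x in zip(row, embedding)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


# The base embedding comes from the black-box model; only the adapter is trained.
base_embedding = [0.2, -0.5, 0.1]
adapted = apply_linear_adapter(base_embedding, identity(3))
print(adapted)
```

Because the base model is never modified, only the adapter's weights need to be stored and trained, which is why this approach works with embedding models you can only call through an API.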

## Generate Corpus

We use our helper function, `generate_qa_embedding_pairs`, to generate our training and evaluation datasets. This function takes in any set of text nodes (chunks) and generates a structured dataset containing (question, context) pairs.

In [1]:
import json

from llama_index import SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from llama_index.schema import MetadataMode

In [2]:
TRAIN_FILES = ["../../../examples/data/10k/lyft_2021.pdf"]
VAL_FILES = ["../../../examples/data/10k/uber_2021.pdf"]

TRAIN_CORPUS_FPATH = "./data/train_corpus.json"
VAL_CORPUS_FPATH = "./data/val_corpus.json"

In [3]:
def load_corpus(files, verbose=False):
    if verbose:
        print(f"Loading files {files}")

    reader = SimpleDirectoryReader(input_files=files)
    docs = reader.load_data()
    if verbose:
        print(f"Loaded {len(docs)} docs")

    parser = SimpleNodeParser.from_defaults()
    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)

    if verbose:
        print(f"Parsed {len(nodes)} nodes")

    return nodes

We do a very naive train/val split: the Lyft corpus is the training dataset, and the Uber corpus is the validation dataset.

In [4]:
train_nodes = load_corpus(TRAIN_FILES, verbose=True)
val_nodes = load_corpus(VAL_FILES, verbose=True)

Loading files ['../../../examples/data/10k/lyft_2021.pdf']
Loaded 238 docs


Parsed 349 nodes
Loading files ['../../../examples/data/10k/uber_2021.pdf']
Loaded 307 docs


Parsed 418 nodes


### Generate synthetic queries

Now, we use an LLM (gpt-3.5-turbo) to generate questions using each text chunk in the corpus as context.

Each pair of (generated question, text chunk used as context) becomes a datapoint in the finetuning dataset (either for training or evaluation).
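
To make the shape of such a dataset concrete, here is a minimal sketch that pairs questions with the chunk they were generated from. It is an illustration only: the questions are hand-written stand-ins for LLM output, and the helper `build_pairs` and the field names `queries`, `corpus`, and `relevant_docs` are assumptions for this example rather than a description of the library's internal layout.

```python
import uuid
from typing import Dict, List, Tuple


def build_pairs(
    questions_per_chunk: Dict[str, List[str]],
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
    """Map each question to a fresh query id and to the chunk it came from."""
    queries: Dict[str, str] = {}
    relevant_docs: Dict[str, List[str]] = {}
    for node_id, questions in questions_per_chunk.items():
        for question in questions:
            query_id = str(uuid.uuid4())
            queries[query_id] = question
            relevant_docs[query_id] = [node_id]
    return queries, relevant_docs


# Hypothetical corpus: node id -> chunk text.
corpus = {
    "node-1": "Revenue increased year over year.",
    "node-2": "The company faces regulatory risk.",
}

# Stand-in for LLM-generated questions, keyed by the chunk used as context.
generated = {
    "node-1": ["How did revenue change?", "Did revenue grow?"],
    "node-2": ["What regulatory risks are described?"],
}

queries, relevant_docs = build_pairs(generated)
for query_id, question in queries.items():
    print(question, "->", relevant_docs[query_id])
```

Each query has exactly one ground-truth chunk in this sketch, which is the property the retrieval evaluation relies on later.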

In [1]:
from llama_index.finetuning import (
    generate_qa_embedding_pairs,
    EmbeddingQAFinetuneDataset,
)

In [None]:
train_dataset = generate_qa_embedding_pairs(train_nodes)
val_dataset = generate_qa_embedding_pairs(val_nodes)

train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")

In [2]:
# [Optional] Load
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")

## Run Embedding Finetuning

We then fine-tune our linear adapter on top of an existing embedding model. We import our new `EmbeddingAdapterFinetuneEngine` abstraction, which takes in an existing embedding model and a set of training parameters.

In [11]:
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.embeddings import resolve_embed_model
import torch

base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")

finetune_engine = EmbeddingAdapterFinetuneEngine(
    train_dataset,
    base_embed_model,
    model_output_path="model_output_test",
    # bias=True,
    epochs=4,
    verbose=True,
    # optimizer_class=torch.optim.SGD,
    # optimizer_params={"lr": 0.01}
)

In [None]:
finetune_engine.finetune()

In [21]:
embed_model = finetune_engine.get_finetuned_model()

# alternatively import model
# from llama_index.embeddings import LinearAdapterEmbeddingModel
# embed_model = LinearAdapterEmbeddingModel(base_embed_model, "model_output_test")

## Evaluate Finetuned Model

We compare the fine-tuned model against the base model, as well as against text-embedding-ada-002.

We evaluate with two ranking metrics:
- **Hit-rate metric**: For each (query, context) pair, we retrieve the top-k documents with the query. It's a hit if the results contain the ground-truth context. The hit rate is the fraction of queries that are hits.
- **Mean Reciprocal Rank (MRR)**: A slightly more granular ranking metric that looks at the "reciprocal rank" of the ground-truth context in the top-k retrieved set. The reciprocal rank is defined as 1/rank; if the results don't contain the context, the reciprocal rank is 0. MRR is the mean of the reciprocal ranks over all queries.
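
The two metrics can be sketched in a few lines of plain Python. This is a minimal illustration under the definitions above, not the code in `eval_utils`; the function names and the toy retrieval results are hypothetical.

```python
from typing import List, Sequence, Tuple


def reciprocal_rank(retrieved_ids: Sequence[str], expected_id: str) -> float:
    """Return 1/rank of the expected id (ranks start at 1), or 0.0 if absent."""
    for rank, doc_id in enumerate(retrieved_ids, start=1):
        if doc_id == expected_id:
            return 1.0 / rank
    return 0.0


def hit_rate_and_mrr(
    results: List[Tuple[Sequence[str], str]],
) -> Tuple[float, float]:
    """Aggregate (top-k retrieved ids, ground-truth id) pairs into both metrics."""
    hits = [expected in retrieved for retrieved, expected in results]
    rr = [reciprocal_rank(retrieved, expected) for retrieved, expected in results]
    return sum(hits) / len(results), sum(rr) / len(results)


# Toy results: each entry is (top-k retrieved ids in rank order, ground-truth id).
toy_results = [
    (["a", "b", "c"], "a"),  # hit at rank 1
    (["a", "b", "c"], "b"),  # hit at rank 2
    (["a", "b", "c"], "z"),  # miss
    (["c", "a"], "a"),       # hit at rank 2
]

hit_rate, mrr = hit_rate_and_mrr(toy_results)
print(f"hit_rate={hit_rate}, mrr={mrr}")
```

Hit rate only records whether the ground-truth context appears at all, while MRR also rewards placing it nearer the top, so two retrievers with the same hit rate can still differ in MRR.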

In [25]:
from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import TextNode
from tqdm.notebook import tqdm
import pandas as pd

from eval_utils import evaluate, display_results

In [26]:
ada = OpenAIEmbedding()
ada_val_results = evaluate(val_dataset, ada)

100%|██████████| 790/790 [02:54<00:00,  4.53it/s]


In [27]:
display_results(["ada"], [ada_val_results])

|   | retrievers | hit_rate | mrr |
| --- | --- | --- | --- |
| 0 | ada | 0.870886 | 0.730105 |


In [28]:
bge = "local:BAAI/bge-small-en"
bge_val_results = evaluate(val_dataset, bge)

100%|██████████| 790/790 [00:20<00:00, 38.41it/s]


In [29]:
display_results(["bge"], [bge_val_results])

|   | retrievers | hit_rate | mrr |
| --- | --- | --- | --- |
| 0 | bge | 0.787342 | 0.643038 |


In [31]:
ft_val_results = evaluate(val_dataset, embed_model)

100%|██████████| 790/790 [00:21<00:00, 36.95it/s]


In [32]:
display_results(["ft"], [ft_val_results])

|   | retrievers | hit_rate | mrr |
| --- | --- | --- | --- |
| 0 | ft | 0.798734 | 0.662152 |


Here we show all the results concatenated together.

In [33]:
display_results(
    ["ada", "bge", "ft"], [ada_val_results, bge_val_results, ft_val_results]
)

|   | retrievers | hit_rate | mrr |
| --- | --- | --- | --- |
| 0 | ada | 0.870886 | 0.730105 |
| 1 | bge | 0.787342 | 0.643038 |
| 2 | ft | 0.798734 | 0.662152 |
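
To put the size of the improvement in perspective, the short sketch below computes the relative gain of the fine-tuned adapter (`ft`) over the base model (`bge`) from the numbers reported above. The helper `relative_gain` is a hypothetical name used only for this example.

```python
# Metrics copied from the results reported above.
results = {
    "ada": {"hit_rate": 0.870886, "mrr": 0.730105},
    "bge": {"hit_rate": 0.787342, "mrr": 0.643038},
    "ft": {"hit_rate": 0.798734, "mrr": 0.662152},
}


def relative_gain(new: float, old: float) -> float:
    """Return the relative change of `new` with respect to `old`."""
    return (new - old) / old


# Compare the fine-tuned adapter against its own base model.
gains = {
    metric: relative_gain(results["ft"][metric], results["bge"][metric])
    for metric in ("hit_rate", "mrr")
}

for metric, gain in gains.items():
    print(f"{metric}: {gain:+.2%} relative to bge")
```

On this validation set the adapter improves hit rate by roughly 1.4% and MRR by roughly 3% relative to the base model, while both remain below `ada`.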
