<a href="https://colab.research.google.com/github/marib00/llamaindex-embedding-lora/finetune_embedding_lora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LoRA finetuning of any Black-Box Embedding Model

This notebook is based on https://github.com/run-llama/llama_index/blob/3e5d0a146fcda01a984818d381f31a19287aead8/docs/examples/finetuning/embeddings/finetune_embedding_adapter.ipynb and demonstrates how to:

- Generate a fine-tuning corpus using a local LLM
- Fine-tune a local embedding model using LoRA

The latter is achieved by subclassing `EmbeddingAdapterFinetuneEngine` and applying a few workarounds, so that it trains the embedding model's own LoRA parameters instead of a separate adapter layer.

## Generate Corpus

We use LlamaIndex's helper function `generate_qa_embedding_pairs` to generate our training and evaluation datasets. It takes any set of text nodes (chunks) and generates a structured dataset of (question, context) pairs.

In [1]:
import torch
from typing import Any, List, Optional, Tuple
from llama_index.core import SimpleDirectoryReader
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding
from llama_index.embeddings.huggingface.pooling import Pooling
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.finetuning.embeddings.adapter_utils import BaseAdapter

Download Data

In [2]:
!mkdir -p 'data/10k/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf' -O 'data/10k/uber_2021.pdf'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'

--2024-03-18 14:51:34--  https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1880483 (1.8M) [application/octet-stream]
Saving to: ‘data/10k/uber_2021.pdf’


2024-03-18 14:51:34 (41.6 MB/s) - ‘data/10k/uber_2021.pdf’ saved [1880483/1880483]

--2024-03-18 14:51:34--  https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1440303 (1.4M) [application/oc

In [3]:
TRAIN_FILES = ["./data/10k/lyft_2021.pdf"]
VAL_FILES = ["./data/10k/uber_2021.pdf"]

TRAIN_CORPUS_FPATH = "./data/train_corpus.json"
VAL_CORPUS_FPATH = "./data/val_corpus.json"

In [4]:
def load_corpus(files, verbose=False):
    if verbose: print(f"Loading files {files}")

    reader = SimpleDirectoryReader(input_files=files)
    docs = reader.load_data()
    if verbose: print(f"Loaded {len(docs)} docs")

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)
    if verbose: print(f"Parsed {len(nodes)} nodes")

    return nodes

We do a very naive train/val split by having the Lyft corpus as the train dataset, and the Uber corpus as the val dataset.

In [5]:
train_nodes = load_corpus(TRAIN_FILES, verbose=True)
val_nodes = load_corpus(VAL_FILES, verbose=True)

Loading files ['./data/10k/lyft_2021.pdf']
Loaded 238 docs



Parsed 344 nodes
Loading files ['./data/10k/uber_2021.pdf']
Loaded 307 docs



Parsed 410 nodes


### Generate synthetic queries

Now we use a local LLM (a 4-bit GPTQ build of Mixtral-8x7B) to generate questions, using each text chunk in the corpus as context.

Each pair of (generated question, text chunk used as context) becomes a datapoint in the finetuning dataset (either for training or evaluation).

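The generated dataset is an `EmbeddingQAFinetuneDataset`, which keeps the questions (`queries`), the chunks (`corpus`) and the question-to-chunk mapping (`relevant_docs`) as dictionaries keyed by id. The sketch below uses plain dicts with placeholder ids and text (hypothetical, not the real data) to show how the (question, context) training pairs are recovered; it assumes each question lists exactly one relevant chunk, the one it was generated from.

```python
from typing import Dict, List, Tuple

# Toy stand-ins for the three dicts held by EmbeddingQAFinetuneDataset
# (ids and text are placeholders, not real data).
queries: Dict[str, str] = {
    "q1": "<question generated from chunk A>",
    "q2": "<question generated from chunk B>",
}
corpus: Dict[str, str] = {
    "node-a": "<text of chunk A>",
    "node-b": "<text of chunk B>",
}
relevant_docs: Dict[str, List[str]] = {"q1": ["node-a"], "q2": ["node-b"]}


def to_training_pairs(
    queries: Dict[str, str],
    corpus: Dict[str, str],
    relevant_docs: Dict[str, List[str]],
) -> List[Tuple[str, str]]:
    """Turn the dataset dicts into (question, context chunk) pairs."""
    pairs = []
    for query_id, query in queries.items():
        # each question points at the chunk it was generated from
        node_id = relevant_docs[query_id][0]
        pairs.append((query, corpus[node_id]))
    return pairs


pairs = to_training_pairs(queries, corpus, relevant_docs)
for question, context in pairs:
    print(f"{question!r} -> {context!r}")
```
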
In [6]:
from llama_index.finetuning import generate_qa_embedding_pairs
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset

In [7]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.prompts import PromptTemplate

model_id = 'TheBloke/Mixtral-8x7B-v0.1-GPTQ'
code_revision = 'gptq-4bit-32g-actorder_True'
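tokenizer = AutoTokenizer.from_pretrained(model_id)
# `revision` selects the GPTQ branch of the repo; flash attention is a model option,
# not a tokenizer one (requires the flash-attn package; drop the argument if unavailable)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=code_revision,
    attn_implementation='flash_attention_2',
    device_map='auto',
)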

llm = HuggingFaceLLM(
    model=model,
    tokenizer=tokenizer,
    query_wrapper_prompt=PromptTemplate('[INST] {query_str} [/INST]'),
    context_window=16*1024,
    max_new_tokens=1024,
)



In [None]:
train_dataset = generate_qa_embedding_pairs(train_nodes, llm=llm)
train_dataset.save_json("train_dataset.json")

val_dataset = generate_qa_embedding_pairs(val_nodes, llm=llm)
val_dataset.save_json("val_dataset.json")

In [2]:
# Generation is done: at this point it's a good idea to restart the kernel to release the
# LLM's CUDA memory, re-run the imports cell at the top, and reload the saved datasets.
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset

train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")

## Run Embedding Finetuning

Here we first define the subclasses needed for LoRA finetuning.

In [3]:
class UniversalAdapter(torch.nn.Identity, BaseAdapter):
    """Adapter that passes embeddings through unchanged (identity) but holds the
    embedding model as a submodule, so that the model's trainable parameters
    (e.g. LoRA weights) are exposed to the FinetuneEngine's optimizer."""
    def __init__(self, embed_model):
        super().__init__()
        self.embed_model = embed_model

    def save(self, output_path):
        # PeftModel.save_pretrained writes the adapter weights and adapter config
        self.embed_model.save_pretrained(output_path)

In [4]:
class UniversalEmbeddingFinetuneEngine(EmbeddingAdapterFinetuneEngine):
    """Finetune any parameters of embed_model that have requires_grad=True, e.g. LoRA adapters."""
    def __init__(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        embed_model: BaseEmbedding,
        batch_size: int = 10,
        epochs: int = 1,
        dim: Optional[int] = None,
        device: Optional[str] = None,
        model_output_path: str = "model_output",
        model_checkpoint_path: Optional[str] = None,
        checkpoint_save_steps: int = 100,
        verbose: bool = False,
        bias: bool = False,
        **train_kwargs: Any,
    ) -> None:
        super().__init__(
            dataset=dataset,
            embed_model=embed_model,
            batch_size=batch_size,
            epochs=epochs,
            adapter_model=UniversalAdapter(embed_model._model),
            dim=dim,
            device=device,
            model_output_path=model_output_path,
            model_checkpoint_path=model_checkpoint_path,
            checkpoint_save_steps=checkpoint_save_steps,
            verbose=verbose,
            bias=bias,
            **train_kwargs,
        )

    def smart_batching_collate(self, batch: List) -> Tuple[Any, Any]:
        """Smart batching collate."""
        import torch
        from torch import Tensor

        query_embeddings: List[Tensor] = []
        text_embeddings: List[Tensor] = []

        for query, text in batch:
            query_embedding = self.embed_model.get_query_embedding(query)
            text_embedding = self.embed_model.get_text_embedding(text)

            # keep the tensors attached to the autograd graph; the parent class wraps
            # them in torch.tensor(...), which would drop the gradients
            query_embeddings.append(query_embedding)
            text_embeddings.append(text_embedding)

        query_embeddings_t = torch.stack(query_embeddings)
        text_embeddings_t = torch.stack(text_embeddings)

        return query_embeddings_t, text_embeddings_t

In [5]:
class HuggingFaceEmbeddingWithGrad(HuggingFaceEmbedding):
    """HuggingFaceEmbedding with gradient support."""

    def __getattr__(self, name: str) -> Any:
        # forward attributes not defined on this class to the wrapped model
        return getattr(self._model, name)

    def _embed(self, sentences: List[str]) -> torch.Tensor:
        """Embed sentences."""
        encoded_input = self._tokenizer(
            sentences,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # pop token_type_ids
        encoded_input.pop("token_type_ids", None)

        # move tokenizer inputs to device
        encoded_input = {
            key: val.to(self._device) for key, val in encoded_input.items()
        }

        model_output = self._model(**encoded_input)

        context_layer: "torch.Tensor" = model_output[0]
        if self.pooling == Pooling.CLS:
            embeddings = self.pooling.cls_pooling(context_layer)
        elif self.pooling == Pooling.LAST:
            embeddings = self.pooling.last_pooling(context_layer)           
        else:
            embeddings = self._mean_pooling(
                token_embeddings=context_layer,
                attention_mask=encoded_input["attention_mask"],
            )

        if self.normalize:
            import torch
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings  # keep the tensor so gradients flow (the parent returns embeddings.tolist())

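`HuggingFaceEmbeddingWithGrad` relies on the fact that Python calls `__getattr__` only when normal attribute lookup fails: attributes defined on the wrapper itself are found first, and anything else is forwarded to the wrapped model. This appears to be what lets `hf_qlora_model.model` further down resolve to an attribute of the PEFT model. The toy classes below are hypothetical stand-ins (no pydantic involved):

```python
from typing import Any


class Inner:
    """Hypothetical stand-in for the wrapped model."""

    def __init__(self) -> None:
        self.model = "inner model object"

    def save_pretrained(self, path: str) -> str:
        return f"saved to {path}"


class Wrapper:
    """Hypothetical stand-in for HuggingFaceEmbeddingWithGrad."""

    def __init__(self, inner: Inner) -> None:
        self._model = inner
        self.pooling = "last"  # found by normal lookup, never forwarded

    def __getattr__(self, name: str) -> Any:
        # only reached when `name` is not found on the wrapper itself
        return getattr(self._model, name)


w = Wrapper(Inner())
print(w.pooling)                    # last
print(w.model)                      # inner model object
print(w.save_pretrained("/tmp/x"))  # saved to /tmp/x
```

Unknown names still raise `AttributeError`, because the forwarded `getattr` on the inner object fails in the usual way.
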
In [25]:
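# ModelField only exists in the pydantic v1 API, which pydantic 2.x ships as pydantic.v1
try:
    from pydantic.v1 import fields as pydantic_fields
except ImportError:  # pydantic 1.x
    from pydantic import fields as pydantic_fields

class disable_pydantic:
    """Context manager that temporarily disables pydantic (v1) field validation,
    so that tensors can pass through fields typed as lists of floats."""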

    def __enter__(self) -> None:
        self.validate = pydantic_fields.ModelField.validate
        pydantic_fields.ModelField.validate = lambda *args, **kwargs: (args[1], None)

    def __exit__(self, *args) -> None:
        pydantic_fields.ModelField.validate = self.validate

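`disable_pydantic` is a patch-and-restore context manager: it swaps out `ModelField.validate` on entry and puts the original back in `__exit__`, which runs even if the body raises. The same pattern can be written with `contextlib.contextmanager` and `try`/`finally`; the sketch below applies it to a hypothetical toy class rather than to pydantic itself.

```python
from contextlib import contextmanager
from typing import Any, Iterator, Tuple


class ToyField:
    """Hypothetical validating field (not pydantic's ModelField)."""

    def validate(self, value: Any) -> Tuple[Any, Any]:
        # returns (value, error), mimicking the (value, errors) convention
        if not isinstance(value, list):
            return value, "expected a list"
        return value, None


@contextmanager
def patched(owner: type, name: str, replacement: Any) -> Iterator[None]:
    """Temporarily replace owner.<name>, restoring the original on exit."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)


field = ToyField()
print(field.validate("tensor"))  # ('tensor', 'expected a list')
with patched(ToyField, "validate", lambda self, value: (value, None)):
    print(field.validate("tensor"))  # ('tensor', None)
print(field.validate("tensor"))  # ('tensor', 'expected a list')
```
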
### Fine-tune SFR-Embedding-Mistral

As of March 2024, SFR-Embedding-Mistral is at the top of the Massive Text Embedding Benchmark (MTEB) leaderboard: https://huggingface.co/spaces/mteb/leaderboard

We first quantize the model to 4-bit (NF4 with double quantization, via bitsandbytes) and save it locally:

In [7]:
model_id = 'Salesforce/SFR-Embedding-Mistral'
quant_path = f'/tmp/models/{model_id.replace("/","-")}-quant'

In [8]:
from transformers import BitsAndBytesConfig, AutoModel, AutoTokenizer

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(quant_path)

model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

# freeze all weights (requires_grad is not stored in the checkpoint, so this is repeated after reloading)
for param in model.parameters():
    param.requires_grad = False

model.save_pretrained(quant_path)
print(f'Quantized model saved to {quant_path}')


Quantized model saved to /tmp/models/Salesforce-SFR-Embedding-Mistral-quant

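A rough estimate of why 4-bit quantization matters here, assuming about 7.1 billion parameters (approximately the size of a Mistral-7B backbone) and the QLoRA storage layout: NF4 weights in blocks of 64 with one absmax constant per block, and double quantization storing those constants in 8 bits plus one 32-bit constant per 256 blocks. The estimate ignores layers that stay in higher precision, such as the token embeddings, so real checkpoints are somewhat larger.

```python
N_PARAMS = 7.1e9  # assumption: rough parameter count of a Mistral-7B backbone

BF16_BITS = 16
# NF4 weight (4 bits) + 8-bit absmax per 64-weight block
# + one fp32 constant per 256 blocks (double quantization)
NF4_DQ_BITS = 4 + 8 / 64 + 32 / (64 * 256)


def gib(n_params: float, bits_per_param: float) -> float:
    """Storage in GiB for n_params at the given bits per parameter."""
    return n_params * bits_per_param / 8 / 2**30


print(f"bits per parameter (nf4 + dq): {NF4_DQ_BITS:.3f}")
print(f"bf16:     {gib(N_PARAMS, BF16_BITS):.1f} GiB")
print(f"nf4 + dq: {gib(N_PARAMS, NF4_DQ_BITS):.1f} GiB")
```
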

In [9]:
# release CUDA memory
del model, tokenizer, bnb_config
import gc
gc.collect()
torch.cuda.empty_cache()

In [10]:
lora_adapters_path = '/tmp/whatever'  # any writable directory; trained LoRA adapters are saved to / loaded from here

In [11]:
from transformers import AutoModel, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

embed_tokenizer = AutoTokenizer.from_pretrained(quant_path)
embed_model = AutoModel.from_pretrained(quant_path, low_cpu_mem_usage=True)
embed_model.to = lambda _: embed_model  # transformers raises on .to() for bitsandbytes 4-bit models; make it a no-op so HuggingFaceEmbedding can call it
for param in embed_model.parameters():
    param.requires_grad = False

In [12]:
hf_base_model = HuggingFaceEmbedding(
    model=embed_model, 
    tokenizer=embed_tokenizer, 
    query_instruction='Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery:',
    pooling='last',
    embed_batch_size=1
)

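With `pooling='last'`, each text is represented by the hidden state of its final token, which is how SFR-Embedding-Mistral is meant to be pooled. If the pooling simply takes the last position of a padded batch, shorter inputs would be represented by a padding token; `embed_batch_size=1` avoids padding altogether. For larger batches, a padding-aware variant would look up each sequence's last real token from the attention mask, as in this hypothetical pure-Python helper on toy numbers:

```python
from typing import List

Vector = List[float]


def last_token_pool(
    hidden_states: List[List[Vector]],  # [batch][seq_len][dim]
    attention_mask: List[List[int]],    # [batch][seq_len], 1 = real token
) -> List[Vector]:
    """Return the hidden state of each sequence's last non-padding token.

    Works for right padding (mask 1 1 0) and left padding (mask 0 1 1).
    """
    pooled = []
    for states, mask in zip(hidden_states, attention_mask):
        # index of the last position whose mask is 1
        last = max(i for i, m in enumerate(mask) if m == 1)
        pooled.append(states[last])
    return pooled


# Two sequences of length 3 with 2-d toy hidden states;
# the first is right-padded, so its last real token is at index 1.
hidden = [
    [[1.0, 0.0], [2.0, 0.0], [9.0, 9.0]],
    [[0.0, 1.0], [0.0, 2.0], [0.0, 3.0]],
]
mask = [[1, 1, 0], [1, 1, 1]]
print(last_token_pool(hidden, mask))  # [[2.0, 0.0], [0.0, 3.0]]
```

Taking position `-1` blindly would have returned the padding state `[9.0, 9.0]` for the first sequence.
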
Evaluate the base model:

In [13]:
from eval_utils import evaluate, display_results  # local helper module, not shown in this notebook

with torch.no_grad():
    base_sfr_val_results = evaluate(val_dataset, hf_base_model)
display_results(["base_sfr"], [base_sfr_val_results])


100%|██████████| 861/861 [01:56<00:00,  7.39it/s]


| retrievers | hit_rate | mrr |
|---|---|---|
| base_sfr | 0.872242 | 0.68494 |

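`evaluate` and `display_results` come from a local `eval_utils` module that is not shown here. Assuming it uses the usual definitions (rank all chunks by cosine similarity to each question; hit rate is the fraction of questions whose source chunk lands in the top k; MRR averages the reciprocal rank of that chunk, counting misses as 0), the two metrics can be sketched in pure Python as follows. The `hit_rate_and_mrr` helper, its `top_k=5` default and the toy vectors are hypothetical:

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norms


def hit_rate_and_mrr(
    query_vecs: Dict[str, Vector],
    doc_vecs: Dict[str, Vector],
    relevant: Dict[str, str],  # query id -> id of the chunk it was generated from
    top_k: int = 5,
) -> Tuple[float, float]:
    hits, reciprocal_ranks = 0, 0.0
    for qid, qvec in query_vecs.items():
        # rank every chunk by similarity to the query, best first, keep top_k
        ranked = sorted(
            doc_vecs, key=lambda d: cosine(qvec, doc_vecs[d]), reverse=True
        )[:top_k]
        expected = relevant[qid]
        if expected in ranked:
            hits += 1
            reciprocal_ranks += 1.0 / (ranked.index(expected) + 1)
    n = len(query_vecs)
    return hits / n, reciprocal_ranks / n


toy_docs = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [0.7, 0.7]}
toy_queries = {"q1": [0.9, 0.1], "q2": [0.6, 0.8]}
toy_relevant = {"q1": "a", "q2": "b"}
result = hit_rate_and_mrr(toy_queries, toy_docs, toy_relevant, top_k=2)
print(result)  # (1.0, 0.75): q1 finds "a" at rank 1, q2 finds "b" at rank 2
```
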

In [16]:
# create the peft model
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="FEATURE_EXTRACTION",
)

kbit_model = prepare_model_for_kbit_training(embed_model)
peft_model = get_peft_model(kbit_model, peft_config)

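LoRA freezes each targeted weight matrix W (d_out × d_in) and learns a low-rank update scaled by `lora_alpha / r`: y = W x + (alpha / r) · B (A x), with A of shape r × d_in and B of shape d_out × r, so each adapted matrix adds r · (d_in + d_out) trainable parameters. The arithmetic below estimates the total for this config, assuming Mistral-7B's shapes (hidden size 4096, 32 layers, grouped-query attention with 8 key/value heads of dimension 128, so `v_proj` maps 4096 → 1024) and ignoring `lora_dropout`; on the real model, `peft_model.print_trainable_parameters()` reports the actual count. The toy forward pass is a hypothetical pure-Python illustration:

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]

# Assumed Mistral-7B shapes (d_out, d_in) of the targeted projections.
TARGET_SHAPES: Dict[str, Tuple[int, int]] = {
    "q_proj": (4096, 4096),
    "v_proj": (1024, 4096),  # 8 KV heads x 128 dims (grouped-query attention)
}
NUM_LAYERS = 32
R = 8  # r from the LoraConfig above


def lora_params(d_out: int, d_in: int, r: int) -> int:
    """Trainable parameters of one LoRA pair: A (r x d_in) plus B (d_out x r)."""
    return r * d_in + d_out * r


per_layer = sum(lora_params(d_out, d_in, R) for d_out, d_in in TARGET_SHAPES.values())
total = per_layer * NUM_LAYERS
print(f"{total:,} trainable LoRA parameters")  # 3,407,872


def matvec(m: Matrix, v: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(
    W: Matrix, A: Matrix, B: Matrix, x: List[float], alpha: float, r: int
) -> List[float]:
    """y = W x + (alpha / r) * B (A x); W stays frozen, only A and B are trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + (alpha / r) * u for b, u in zip(base, update)]


# Toy 2x2 layer with rank 1. PEFT initialises B to zeros by default,
# so the adapted layer starts out identical to the frozen one.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, 0.5]]
B_init = [[0.0], [0.0]]
B_trained = [[1.0], [0.0]]
print(lora_forward(W, A, B_init, [2.0, 4.0], alpha=2.0, r=1))     # [2.0, 4.0]
print(lora_forward(W, A, B_trained, [2.0, 4.0], alpha=2.0, r=1))  # [8.0, 4.0]
```

Only `A` and `B` receive gradients, which is why the finetune engine can train the model simply through the parameters that have `requires_grad=True`.
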
In [None]:
# ...or load previously trained adapters instead (inference only; pass is_trainable=True to keep training them)
from peft import PeftModel
peft_model = PeftModel.from_pretrained(embed_model, lora_adapters_path)

In [17]:
hf_qlora_model = HuggingFaceEmbeddingWithGrad(
    model=peft_model, 
    tokenizer=embed_tokenizer, 
    query_instruction='Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery:',
    pooling='last',
    embed_batch_size=1
)

In [26]:
finetune_engine = UniversalEmbeddingFinetuneEngine(
    train_dataset,
    embed_model=hf_qlora_model,
    dim=4096,
    model_output_path=lora_adapters_path,
    epochs=5,
    verbose=False,
)

with disable_pydantic():
    finetune_engine.finetune()




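To our understanding, the parent `EmbeddingAdapterFinetuneEngine` trains with a multiple-negatives ranking loss using in-batch negatives: for each query in a batch, its own chunk is the positive, the other chunks in the batch act as negatives, and the loss is the cross-entropy over scaled cosine similarities. Because `UniversalAdapter` is an identity, those similarities are computed directly on the LoRA model's embeddings. A pure-Python sketch under that assumption (the scale of 20 is also an assumption, and the vectors are toy data):

```python
import math
from typing import List

Vector = List[float]


def cos_sim(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norms


def mnr_loss(queries: List[Vector], contexts: List[Vector], scale: float = 20.0) -> float:
    """Mean cross-entropy where query i must pick context i among all contexts in the batch."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [scale * cos_sim(q, c) for c in contexts]
        # -log softmax(logits)[i], computed stably via log-sum-exp
        m = max(logits)
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_z - logits[i]
    return total / len(queries)


contexts = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[0.9, 0.1], [0.1, 0.9]]  # each query closest to its own chunk
swapped = [[0.1, 0.9], [0.9, 0.1]]  # each query closest to the other chunk
print(f"aligned: {mnr_loss(aligned, contexts):.4f}")
print(f"swapped: {mnr_loss(swapped, contexts):.4f}")
```

Minimising this loss pulls each question towards its own chunk and away from the other chunks in the batch, which is what the MRR and hit-rate numbers below measure.
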
In [19]:
# repackage as a plain HuggingFaceEmbedding for evaluation: it returns embeddings as lists,
# as pydantic expects, rather than the tensors returned by HuggingFaceEmbeddingWithGrad
hf_embedding_model = HuggingFaceEmbedding(
    model=hf_qlora_model.model,
    tokenizer=hf_qlora_model._tokenizer,
    query_instruction=hf_qlora_model.query_instruction,
    pooling=hf_qlora_model.pooling,
    embed_batch_size=hf_qlora_model.embed_batch_size
)

Evaluate the fine-tuned model:

In [20]:
from eval_utils import evaluate, display_results

with torch.no_grad():
    lora_sfr_val_results = evaluate(val_dataset, hf_embedding_model)
display_results(["lora_sfr"], [lora_sfr_val_results])


100%|██████████| 861/861 [01:59<00:00,  7.19it/s]


| retrievers | hit_rate | mrr |
|---|---|---|
| lora_sfr | 0.941928 | 0.803949 |
