# Fine-tuning the model using `BAAI/bge-m3` as baseline.

Customize the embedding model (BAAI/bge-m3) for a specific domain (Legal) and language (Spanish).

In [1]:
!pip install sentence-transformers==3.0.1 accelerate==0.32.1 transformers[torch]==4.42.3 datasets==2.20.0

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
import os

from sentence_transformers import (
    SentenceTransformerModelCardData,
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from datasets import DatasetDict, load_dataset, concatenate_datasets
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import cos_sim
from transformers import set_seed
import torch

  from tqdm.autonotebook import tqdm, trange


# Set seed

In [3]:
set_seed(42)

In [4]:
import numpy as np


def set_own_seed(seed):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_own_seed(42)

# Config

In [5]:
INPUT_DATASET = "dariolopez/justicio-rag-embedding-qa-tmp-2"

In [6]:
INPUT_MODEL = "BAAI/bge-m3"

In [7]:
OUTPUT_MODEL = "dariolopez/bge-m3-es-legal-tmp-6"

In [8]:
CONFIG = {
    'num_train_epochs': 6,  # TODO
    'per_device_train_batch_size': 16,
    'per_device_eval_batch_size': 16,
    'gradient_accumulation_steps': 16,
    'learning_rate': 2e-5,
}

# Check GPU

In [9]:
!nvidia-smi

Fri Jul 12 13:59:43 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000               On  | 00000000:44:00.0 Off |                  Off |
| 30%   39C    P8              24W / 300W |      2MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# Load & Prepare Dataset

In [11]:
dataset = load_dataset(INPUT_DATASET)

In [12]:
# Crear una columna 'id' autoincremental
ids = list(range(1, len(dataset['train']) + 1))

# Añadir la columna 'id' al dataset
dataset['train'] = dataset['train'].add_column('id', ids)

In [13]:
test_size = 0.1

dataset = dataset['train'].train_test_split(test_size=test_size)
dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'context', 'answer', 'id'],
        num_rows: 2947
    })
    test: Dataset({
        features: ['question', 'context', 'answer', 'id'],
        num_rows: 328
    })
})

# Model & Evaluator

In [14]:
matryoshka_dimensions=[1024, 768, 512, 256, 128, 64]

In [15]:
model = SentenceTransformer(
    INPUT_MODEL,
    device="cuda",
    # model_kwargs={"attn_implementation": "sdpa"},  # needs Ampere GPU or newer
    model_card_data=SentenceTransformerModelCardData(
        language="es",
        license="apache-2.0",
        model_name="BGE large Legal Spanish",
    ),
)

In [16]:
def create_evaluator(
    train_dataset, test_dataset, matryoshka_dimensions=[1024, 768, 512, 256, 128, 64]
):
    corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

    # Convert the datasets to dictionaries
    corpus = dict(
        zip(corpus_dataset["id"], corpus_dataset["context"])
    )  # Our corpus (cid => document)
    queries = dict(
        zip(test_dataset["id"], test_dataset["question"])
    )  # Our queries (qid => question)

    # Create a mapping of relevant document (1 in our case) for each query
    relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
    for q_id in queries:
        relevant_docs[q_id] = [q_id]

    matryoshka_evaluators = []
    # Iterate over the different dimensions
    for dim in matryoshka_dimensions:
        ir_evaluator = InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            name=f"dim_{dim}",
            truncate_dim=dim,  # Truncate the embeddings to a certain dimension
            score_functions={"cosine": cos_sim},
        )
        matryoshka_evaluators.append(ir_evaluator)

    # Create a sequential evaluator
    return SequentialEvaluator(matryoshka_evaluators)

In [17]:
evaluator = create_evaluator(
    dataset['train'], dataset['test'], matryoshka_dimensions=matryoshka_dimensions
)

# Loss Function

In [18]:
# create Matryoshka loss function with MultipleNegativesRankingLoss
inner_train_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
)

# Training

In [19]:
import json
from datetime import datetime


now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = os.path.join('output', now)

# os.makedirs(model_save_path, exist_ok=True)
# with open(os.path.join(model_save_path, 'train_config.json'), 'w') as file:
#     file.write(json.dumps(CONFIG, indent=4))

In [20]:
training_args = SentenceTransformerTrainingArguments(
    output_dir=model_save_path,
    num_train_epochs=CONFIG['num_train_epochs'],  # number of epochs
    per_device_train_batch_size=CONFIG['per_device_train_batch_size'],  # training batch size
    per_device_eval_batch_size=CONFIG['per_device_eval_batch_size'],  # evaluation batch size
    gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],  # gradient accumulation steps
    warmup_ratio=0.1,  # warmup ratio
    learning_rate=CONFIG['learning_rate'],  # learning rate
    lr_scheduler_type="cosine",  # use constant learning rate scheduler
    optim="adamw_torch_fused",  # use fused adamw optimizer
    tf32=True,  # use tf32 precision # needs Ampere GPU or newer # TODO
    bf16=True,  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",  # evaluate after each epoch
    save_strategy="epoch",  # "no",  # "epoch",  # save after each epoch
    logging_steps=5,  # log every 10 steps
    save_total_limit=3,  # save only the last 3 models
    load_best_model_at_end=True,  # load the best model when training ends
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension  # TODO
)

In [21]:
trainer = SentenceTransformerTrainer(
    model=model,  # bge-bm3
    args=training_args,
    train_dataset=dataset['train'].select_columns(
        ["context", "question"]
    ),
    eval_dataset=dataset['test'].select_columns(
        ['context', 'question']
    ),
    loss=train_loss,
    evaluator=evaluator
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


# Train model

In [22]:
%%time

# start training, the model will be automatically saved to the hub and the output directory
trainer.train()

Epoch,Training Loss,Validation Loss,Dim 1024 Cosine Accuracy@1,Dim 1024 Cosine Accuracy@3,Dim 1024 Cosine Accuracy@5,Dim 1024 Cosine Accuracy@10,Dim 1024 Cosine Precision@1,Dim 1024 Cosine Precision@3,Dim 1024 Cosine Precision@5,Dim 1024 Cosine Precision@10,Dim 1024 Cosine Recall@1,Dim 1024 Cosine Recall@3,Dim 1024 Cosine Recall@5,Dim 1024 Cosine Recall@10,Dim 1024 Cosine Ndcg@10,Dim 1024 Cosine Mrr@10,Dim 1024 Cosine Map@100,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,0.9598,0.547653,0.54878,0.789634,0.844512,0.878049,0.54878,0.263211,0.168902,0.087805,0.54878,0.789634,0.844512,0.878049,0.727703,0.677843,0.683303,0.536585,0.780488,0.847561,0.884146,0.536585,0.260163,0.169512,0.088415,0.536585,0.780488,0.847561,0.884146,0.722369,0.669048,0.674367,0.542683,0.777439,0.838415,0.887195,0.542683,0.259146,0.167683,0.08872,0.542683,0.777439,0.838415,0.887195,0.724475,0.671195,0.675803,0.557927,0.795732,0.820122,0.884146,0.557927,0.265244,0.164024,0.088415,0.557927,0.795732,0.820122,0.884146,0.729455,0.679181,0.683644,0.545732,0.753049,0.792683,0.847561,0.545732,0.251016,0.158537,0.084756,0.545732,0.753049,0.792683,0.847561,0.702934,0.655904,0.661607,0.472561,0.692073,0.737805,0.814024,0.472561,0.230691,0.147561,0.081402,0.472561,0.692073,0.737805,0.814024,0.646596,0.592962,0.599376,0.599376
1,0.3858,0.424167,0.536585,0.783537,0.835366,0.89939,0.536585,0.261179,0.167073,0.089939,0.536585,0.783537,0.835366,0.89939,0.727262,0.671047,0.674756,0.527439,0.783537,0.844512,0.89939,0.527439,0.261179,0.168902,0.089939,0.527439,0.783537,0.844512,0.89939,0.723423,0.665855,0.669687,0.533537,0.79878,0.838415,0.893293,0.533537,0.26626,0.167683,0.089329,0.533537,0.79878,0.838415,0.893293,0.7251,0.669898,0.67403,0.557927,0.786585,0.838415,0.875,0.557927,0.262195,0.167683,0.0875,0.557927,0.786585,0.838415,0.875,0.726772,0.678004,0.683274,0.521341,0.768293,0.820122,0.856707,0.521341,0.256098,0.164024,0.085671,0.521341,0.768293,0.820122,0.856707,0.700064,0.648496,0.654429,0.493902,0.716463,0.780488,0.85061,0.493902,0.238821,0.156098,0.085061,0.493902,0.716463,0.780488,0.85061,0.675184,0.618683,0.623347,0.623347
2,0.1703,0.394022,0.530488,0.79878,0.85061,0.89939,0.530488,0.26626,0.170122,0.089939,0.530488,0.79878,0.85061,0.89939,0.727805,0.671495,0.675518,0.533537,0.801829,0.847561,0.905488,0.533537,0.267276,0.169512,0.090549,0.533537,0.801829,0.847561,0.905488,0.731124,0.674018,0.677587,0.539634,0.804878,0.844512,0.89939,0.539634,0.268293,0.168902,0.089939,0.539634,0.804878,0.844512,0.89939,0.730987,0.675754,0.679721,0.551829,0.789634,0.832317,0.881098,0.551829,0.263211,0.166463,0.08811,0.551829,0.789634,0.832317,0.881098,0.727623,0.677364,0.68227,0.515244,0.771341,0.820122,0.871951,0.515244,0.257114,0.164024,0.087195,0.515244,0.771341,0.820122,0.871951,0.702748,0.647689,0.652332,0.487805,0.716463,0.77439,0.841463,0.487805,0.238821,0.154878,0.084146,0.487805,0.716463,0.77439,0.841463,0.669486,0.614018,0.619572,0.619572
3,0.0594,0.373521,0.551829,0.804878,0.85061,0.902439,0.551829,0.268293,0.170122,0.090244,0.551829,0.804878,0.85061,0.902439,0.736904,0.682802,0.68673,0.54878,0.79878,0.85061,0.89939,0.54878,0.26626,0.170122,0.089939,0.54878,0.79878,0.85061,0.89939,0.732951,0.67838,0.682615,0.551829,0.814024,0.847561,0.89939,0.551829,0.271341,0.169512,0.089939,0.551829,0.814024,0.847561,0.89939,0.735266,0.681377,0.685362,0.554878,0.795732,0.829268,0.884146,0.554878,0.265244,0.165854,0.088415,0.554878,0.795732,0.829268,0.884146,0.73163,0.681611,0.686539,0.527439,0.756098,0.820122,0.868902,0.527439,0.252033,0.164024,0.08689,0.527439,0.756098,0.820122,0.868902,0.706632,0.653749,0.658763,0.484756,0.719512,0.780488,0.859756,0.484756,0.239837,0.156098,0.085976,0.484756,0.719512,0.780488,0.859756,0.67383,0.614289,0.618874,0.618874
4,0.0524,0.364199,0.54878,0.804878,0.85061,0.902439,0.54878,0.268293,0.170122,0.090244,0.54878,0.804878,0.85061,0.902439,0.737197,0.68301,0.686972,0.54878,0.807927,0.85061,0.902439,0.54878,0.269309,0.170122,0.090244,0.54878,0.807927,0.85061,0.902439,0.73594,0.681302,0.685259,0.554878,0.810976,0.85061,0.896341,0.554878,0.270325,0.170122,0.089634,0.554878,0.810976,0.85061,0.896341,0.73577,0.682937,0.687114,0.554878,0.795732,0.832317,0.884146,0.554878,0.265244,0.166463,0.088415,0.554878,0.795732,0.832317,0.884146,0.731181,0.68095,0.685772,0.52439,0.765244,0.814024,0.865854,0.52439,0.255081,0.162805,0.086585,0.52439,0.765244,0.814024,0.865854,0.705001,0.652323,0.65773,0.490854,0.728659,0.786585,0.859756,0.490854,0.242886,0.157317,0.085976,0.490854,0.728659,0.786585,0.859756,0.676922,0.618148,0.622775,0.622775
5,0.0491,0.362644,0.551829,0.804878,0.844512,0.902439,0.551829,0.268293,0.168902,0.090244,0.551829,0.804878,0.844512,0.902439,0.737986,0.684161,0.688087,0.54878,0.804878,0.85061,0.902439,0.54878,0.268293,0.170122,0.090244,0.54878,0.804878,0.85061,0.902439,0.736128,0.681556,0.685489,0.557927,0.810976,0.85061,0.893293,0.557927,0.270325,0.170122,0.089329,0.557927,0.810976,0.85061,0.893293,0.736263,0.684515,0.68893,0.554878,0.795732,0.832317,0.884146,0.554878,0.265244,0.166463,0.088415,0.554878,0.795732,0.832317,0.884146,0.730738,0.680399,0.685134,0.521341,0.762195,0.814024,0.865854,0.521341,0.254065,0.162805,0.086585,0.521341,0.762195,0.814024,0.865854,0.702848,0.649508,0.654997,0.484756,0.72561,0.780488,0.853659,0.484756,0.24187,0.156098,0.085366,0.484756,0.72561,0.780488,0.853659,0.672942,0.614667,0.619832,0.619832


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

CPU times: user 7min 7s, sys: 48.2 s, total: 7min 55s
Wall time: 6min 47s


TrainOutput(global_step=66, training_loss=0.3664314911672563, metrics={'train_runtime': 407.2155, 'train_samples_per_second': 43.422, 'train_steps_per_second': 0.162, 'total_flos': 0.0, 'train_loss': 0.3664314911672563, 'epoch': 5.708108108108108})

In [23]:
# save the best model
trainer.save_model(model_save_path)

# Push model

In [24]:
import huggingface_hub

huggingface_hub.login()

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [25]:
# Save the model to the Hugging Face Hub!
# model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
model.push_to_hub(OUTPUT_MODEL)

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.27G [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

'https://huggingface.co/dariolopez/bge-m3-es-legal-tmp-6/commit/42d0a03ceecf430ecfd7f3f49843b5dadb594bf9'

In [26]:
OUTPUT_MODEL

'dariolopez/bge-m3-es-legal-tmp-6'