## Обучение embedding модели

In [None]:
# install libs
!pip install -U sentence-transformers datasets



#### Проверка CUDA

In [None]:
import torch

# Show the installed PyTorch build, then whether a CUDA device is usable,
# each on its own line.
print(torch.__version__, torch.cuda.is_available(), sep="\n")

2.5.1+cu124
True


#### Сборка датасета

In [None]:
from datasets import load_dataset

# Fetch the dataset from the Hugging Face Hub.
dataset = load_dataset("fitlemon/rag-labor-codex-dataset")

# Keep a handle on each of the two splits used by this notebook.
training_dataset, test_dataset = dataset["train"], dataset["test"]

In [None]:
# Hold out 10% of the Hub "train" split for validation (fixed seed for
# reproducibility). The split is taken from `dataset["train"]` rather than from
# `training_dataset` itself so the cell is idempotent: previously a second run
# called train_test_split on the DatasetDict produced by the first run, which
# has no such method.
training_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)

In [None]:
# Persist both halves of the split so later cells can reload them from disk.
# The "test" half of the split is stored as the validation set.
train_split = training_dataset["train"]
val_split = training_dataset["test"]

train_split.to_json("data/train_dataset.json", orient="records")
val_split.to_json("data/val_dataset.json", orient="records")

Creating json from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

959850

#### Инициализация моделей

In [None]:
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets

model_id = "BAAI/bge-m3"  # Hugging Face model ID
# Embedding sizes evaluated below, ordered from largest to smallest.
matryoshka_dimensions = [1024, 768, 512, 256, 128, 64]

# NOTE: the model is deliberately NOT instantiated in this cell. Nothing uses
# it before the "Функция Loss" section creates the model that is actually
# trained, so loading it here as well only doubled the load time and peak
# memory (the checkpoint is ~2.27 GB).

# Reload the splits written to disk earlier; the "json" builder exposes the
# contents of a single data file under the split name "train".
val_dataset = load_dataset("json", data_files="data/val_dataset.json", split="train")
train_dataset = load_dataset("json", data_files="data/train_dataset.json", split="train")

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Key every validation chunk and question by its row position (as a string),
# so that query "i" and chunk "i" originate from the same validation row.
corpus = {
    str(row_idx): chunk_text
    for row_idx, chunk_text in enumerate(val_dataset['chunk'])
}  # Our corpus (cid => document)
queries = {
    str(row_idx): question_text
    for row_idx, question_text in enumerate(val_dataset['question'])
}  # Our queries (qid => question)

In [None]:
# --- Step 1: Deduplicate the corpus ---
# Identical chunk texts are collapsed onto one document. New IDs are
# consecutive integers (as strings) assigned in order of first appearance.

seen_texts = {}   # chunk text -> new document ID already assigned to it
new_corpus = {}   # new document ID -> chunk text (deduplicated corpus)
old_to_new = {}   # old corpus ID -> new document ID

for old_id, text in corpus.items():
    new_id_str = seen_texts.get(text)
    if new_id_str is None:
        # First time this text appears: the next free ID equals the number of
        # unique texts stored so far.
        new_id_str = str(len(new_corpus))
        seen_texts[text] = new_id_str
        new_corpus[new_id_str] = text
    # Duplicates and first occurrences alike point at the shared new ID.
    old_to_new[old_id] = new_id_str

print(f"Original corpus size: {len(corpus)}")
print(f"Deduplicated corpus size: {len(new_corpus)}")

# --- Step 2: Update the relevant docs mapping ---
# Query IDs and old corpus IDs share the same row index, so the relevant
# document for a query is whatever its own row's chunk was remapped to.

new_relevant_docs = {qid: {old_to_new[qid]} for qid in queries}

Original corpus size: 527
Deduplicated corpus size: 377


In [None]:
# Build one retrieval evaluator per Matryoshka dimension; each one truncates
# the embeddings to `dim` before scoring with cosine similarity.
matryoshka_evaluators = [
    InformationRetrievalEvaluator(
        queries=queries,
        corpus=new_corpus,
        relevant_docs=new_relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,
        score_functions={"cosine": cos_sim},
    )
    for dim in matryoshka_dimensions
]

# Chain the per-dimension evaluators so they run one after another.
evaluator = SequentialEvaluator(matryoshka_evaluators)

#### Функция Loss

In [None]:
from sentence_transformers import SentenceTransformerModelCardData, SentenceTransformer


model_id = "BAAI/bge-m3"

# Metadata attached to the model (language, license, display name).
model_card = SentenceTransformerModelCardData(
    language="uz",
    license="apache-2.0",
    model_name="BGE m3 Uzbek Legal Matryoshka",
)

# Load the base checkpoint, requesting the "sdpa" attention implementation
# from the underlying transformer.
model = SentenceTransformer(
    model_id,
    model_kwargs={"attn_implementation": "sdpa"},
    model_card_data=model_card,
)

In [None]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

# Target embedding sizes. Order matters: keep them sorted from large to small.
matryoshka_dimensions = [1024, 768, 512, 256, 128, 64]

# Base ranking loss, wrapped by MatryoshkaLoss with the dimensions listed above.
inner_train_loss = MultipleNegativesRankingLoss(model=model)
train_loss = MatryoshkaLoss(
    model=model,
    loss=inner_train_loss,
    matryoshka_dims=matryoshka_dimensions,
)

#### Finetuning

In [None]:
# Authenticate the Weights & Biases CLI via a shell command; the training
# arguments in this notebook are configured with report_to="wandb".
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mihmatullaev[0m ([33mfitlemon[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# Reload the training split from the JSON file written earlier.
train_dataset = load_dataset("json", data_files="data/train_dataset.json", split="train")

# Training configuration, grouped by concern.
args = SentenceTransformerTrainingArguments(
    # --- identity / reporting ---
    output_dir="bge-m3-uz-legal-matryoshka",    # local checkpoint directory
    run_name="bge-m3-uz-legal-matryoshka",      # run name shown in the tracker
    report_to="wandb",                          # send logs to Weights & Biases
    logging_steps=10,                           # log every 10 steps
    # --- schedule ---
    num_train_epochs=4,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",                 # cosine learning-rate schedule
    warmup_ratio=0.1,
    # --- batching ---
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,              # no accumulation: 8 samples per device per step
    per_device_eval_batch_size=8,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # --- optimizer / precision ---
    optim="adamw_torch_fused",                  # fused AdamW
    fp16=True,                                  # fp16 mixed precision is the only one enabled
    bf16=False,
    tf32=False,
    # --- evaluation / checkpointing ---
    eval_strategy="epoch",                      # evaluate after each epoch
    save_strategy="epoch",                      # checkpoint after each epoch
    save_total_limit=3,                         # cap the number of checkpoints kept
    load_best_model_at_end=True,                # restore the best checkpoint when training ends
    metric_for_best_model="eval_dim_1024_cosine_recall@3",  # "best" = highest recall@3 at dimension 1024
)

In [None]:
from sentence_transformers import SentenceTransformerTrainer

# Only the (question, chunk) pairs are fed to the loss; other columns are dropped.
train_pairs = train_dataset.select_columns(["question", "chunk"])

trainer = SentenceTransformerTrainer(
    model=model,          # the BAAI/bge-m3 model loaded in the loss section
    args=args,            # training arguments defined in the previous cell
    train_dataset=train_pairs,
    loss=train_loss,      # Matryoshka-wrapped ranking loss
    evaluator=evaluator,  # sequential per-dimension retrieval evaluator
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [None]:
# Start fine-tuning; the call's return value is displayed as the cell output.
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mihmatullaev[0m ([33mfitlemon[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch,Training Loss,Validation Loss,Dim 1024 Cosine Accuracy@1,Dim 1024 Cosine Accuracy@3,Dim 1024 Cosine Accuracy@5,Dim 1024 Cosine Accuracy@10,Dim 1024 Cosine Precision@1,Dim 1024 Cosine Precision@3,Dim 1024 Cosine Precision@5,Dim 1024 Cosine Precision@10,Dim 1024 Cosine Recall@1,Dim 1024 Cosine Recall@3,Dim 1024 Cosine Recall@5,Dim 1024 Cosine Recall@10,Dim 1024 Cosine Ndcg@10,Dim 1024 Cosine Mrr@10,Dim 1024 Cosine Map@100,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 
64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
1,0.58,No log,0.633776,0.827324,0.867173,0.929791,0.633776,0.275775,0.173435,0.092979,0.633776,0.827324,0.867173,0.929791,0.786384,0.740065,0.743645,0.629981,0.814042,0.872865,0.931689,0.629981,0.271347,0.174573,0.093169,0.629981,0.814042,0.872865,0.931689,0.782972,0.735142,0.738591,0.614801,0.812144,0.882353,0.929791,0.614801,0.270715,0.176471,0.092979,0.614801,0.812144,0.882353,0.929791,0.777013,0.727552,0.730828,0.59962,0.812144,0.857685,0.916509,0.59962,0.270715,0.171537,0.091651,0.59962,0.812144,0.857685,0.916509,0.763099,0.713605,0.717701,0.57685,0.772296,0.83871,0.912713,0.57685,0.257432,0.167742,0.091271,0.57685,0.772296,0.83871,0.912713,0.741432,0.686916,0.691056,0.518027,0.749526,0.821632,0.888046,0.518027,0.249842,0.164326,0.088805,0.518027,0.749526,0.821632,0.888046,0.70459,0.645557,0.650749,0.70459
2,0.4215,No log,0.624288,0.819734,0.87666,0.937381,0.624288,0.273245,0.175332,0.093738,0.624288,0.819734,0.87666,0.937381,0.782126,0.732248,0.735768,0.620493,0.814042,0.872865,0.935484,0.620493,0.271347,0.174573,0.093548,0.620493,0.814042,0.872865,0.935484,0.779124,0.728937,0.732604,0.614801,0.812144,0.870968,0.931689,0.614801,0.270715,0.174194,0.093169,0.614801,0.812144,0.870968,0.931689,0.775282,0.724921,0.728738,0.590133,0.815939,0.863378,0.916509,0.590133,0.27198,0.172676,0.091651,0.590133,0.815939,0.863378,0.916509,0.761006,0.710297,0.715149,0.588235,0.789374,0.848197,0.924099,0.588235,0.263125,0.169639,0.09241,0.588235,0.789374,0.848197,0.924099,0.756176,0.702495,0.706149,0.555977,0.759013,0.844402,0.908918,0.555977,0.253004,0.16888,0.090892,0.555977,0.759013,0.844402,0.908918,0.732603,0.675943,0.680494,0.732603
3,1.3753,No log,0.639469,0.827324,0.893738,0.939279,0.639469,0.275775,0.178748,0.093928,0.639469,0.827324,0.893738,0.939279,0.792213,0.744627,0.748146,0.635674,0.827324,0.889943,0.933586,0.635674,0.275775,0.177989,0.093359,0.635674,0.827324,0.889943,0.933586,0.788561,0.741475,0.745437,0.631879,0.821632,0.895636,0.929791,0.631879,0.273877,0.179127,0.092979,0.631879,0.821632,0.895636,0.929791,0.785555,0.738571,0.742679,0.618596,0.815939,0.888046,0.922201,0.618596,0.27198,0.177609,0.09222,0.618596,0.815939,0.888046,0.922201,0.775188,0.727171,0.731633,0.601518,0.814042,0.867173,0.920304,0.601518,0.271347,0.173435,0.09203,0.601518,0.814042,0.867173,0.920304,0.765598,0.71546,0.719915,0.586338,0.791271,0.857685,0.910816,0.586338,0.263757,0.171537,0.091082,0.586338,0.791271,0.857685,0.910816,0.751065,0.699495,0.704197,0.751065
4,0.2071,No log,0.647059,0.834915,0.891841,0.935484,0.647059,0.278305,0.178368,0.093548,0.647059,0.834915,0.891841,0.935484,0.794629,0.748923,0.752815,0.643264,0.836812,0.895636,0.933586,0.643264,0.278937,0.179127,0.093359,0.643264,0.836812,0.895636,0.933586,0.793255,0.747568,0.75153,0.641366,0.833017,0.893738,0.931689,0.641366,0.277672,0.178748,0.093169,0.641366,0.833017,0.893738,0.931689,0.791454,0.74581,0.749677,0.620493,0.821632,0.88425,0.925996,0.620493,0.273877,0.17685,0.0926,0.620493,0.821632,0.88425,0.925996,0.77794,0.72966,0.73377,0.622391,0.817837,0.867173,0.918406,0.622391,0.272612,0.173435,0.091841,0.622391,0.817837,0.867173,0.918406,0.772742,0.725653,0.73031,0.593928,0.795066,0.857685,0.914611,0.593928,0.265022,0.171537,0.091461,0.593928,0.795066,0.857685,0.914611,0.755534,0.704293,0.708815,0.755534


TrainOutput(global_step=2372, training_loss=0.6182019080822223, metrics={'train_runtime': 2866.6157, 'train_samples_per_second': 6.61, 'train_steps_per_second': 0.827, 'total_flos': 0.0, 'train_loss': 0.6182019080822223, 'epoch': 4.0})

In [None]:
# Upload the trainer's model to the Hub; exist_ok=True allows the target
# repository to exist already.
trained_model = trainer.model
trained_model.push_to_hub(repo_id="bge-m3-uz-legal-matryoshka", exist_ok=True)

model.safetensors:   0%|          | 0.00/2.27G [00:00<?, ?B/s]

'https://huggingface.co/fitlemon/bge-m3-uz-legal-matryoshka/commit/b151c48a1ed6c95b7da368bf879c877e0c2b88ff'

In [None]:
# Save the trainer's current model locally. With load_best_model_at_end=True
# in the training arguments, this is expected to be the best checkpoint.
trainer.save_model()

# Then upload that same model to the Hub (repository may already exist).
best_model = trainer.model
best_model.push_to_hub(repo_id="bge-m3-uz-legal-matryoshka", exist_ok=True)