In [1]:
!pip install sentence_transformers datasets accelerate



# base

In [2]:
!pip install -U datasets



In [3]:
import random
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from datasets import load_dataset

In [26]:
corpus = load_dataset("khanglt0004/vietnamese_legal_chunks", split = "train")
queries = load_dataset("khanglt0004/questions", split = "train")
relevant_docs_data = load_dataset("khanglt0004/links", split = "train")
# Convert the datasets to dictionaries
corpuss = dict(zip(corpus["id"], corpus["text"]))  # Our corpus (cid => document)
corpus = {}
for cid, text in corpuss.items():
    cid = str(cid)
    corpus[cid] = text
queries = dict(zip(queries["qid"], queries["question"]))  # Our queries (qid => question)
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
for qid, corpus_ids in zip(relevant_docs_data["q_id"], relevant_docs_data["chunk_id"]):
    qid = str(qid)
    corpus_ids = str(corpus_ids)
    if qid not in relevant_docs:
        relevant_docs[qid] = set()
    relevant_docs[qid].add(corpus_ids)

In [28]:
import json
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets

model = SentenceTransformer("bkai-foundation-models/vietnamese-bi-encoder")
matryoshka_dimensions = [768, 512, 256, 128, 64] # Important: large to small
matryoshka_evaluators = []
# Iterate over the different dimensions
for dim in matryoshka_dimensions:
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to a certain dimension
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(ir_evaluator)

# Create a sequential evaluator
evaluator = SequentialEvaluator(matryoshka_evaluators)

In [29]:
# Evaluate the model
results = evaluator(model)
for k,v in results.items():
    print(k, v)

dim_768_cosine_accuracy@1 0.4070760471736478
dim_768_cosine_accuracy@3 0.607157381049207
dim_768_cosine_accuracy@5 0.6856445709638064
dim_768_cosine_accuracy@10 0.7669784465229769
dim_768_cosine_precision@1 0.4070760471736478
dim_768_cosine_precision@3 0.20238579368306897
dim_768_cosine_precision@5 0.13712891419276127
dim_768_cosine_precision@10 0.07669784465229768
dim_768_cosine_recall@1 0.4070760471736478
dim_768_cosine_recall@3 0.607157381049207
dim_768_cosine_recall@5 0.6856445709638064
dim_768_cosine_recall@10 0.7669784465229769
dim_768_cosine_ndcg@10 0.5845575816547848
dim_768_cosine_mrr@10 0.5264370598449489
dim_768_cosine_map@100 0.5332781076975427
dim_512_cosine_accuracy@1 0.4062627084180561
dim_512_cosine_accuracy@3 0.601870679137861
dim_512_cosine_accuracy@5 0.6738511590077267
dim_512_cosine_accuracy@10 0.7559983733224889
dim_512_cosine_precision@1 0.4062627084180561
dim_512_cosine_precision@3 0.2006235597126203
dim_512_cosine_precision@5 0.13477023180154535
dim_512_cosine_p

# Training

In [30]:
import pandas as pd
from datasets import Dataset

def prepare_training_dataset(queries, corpus, relevant_docs):
    anchors = []
    positives = []
    for query_id, docs in relevant_docs.items():
        for doc_id in docs:
          anchors.append(queries[query_id])
          positives.append(corpus[str(doc_id)] )
    df = {
        "anchor": anchors,
        "positive": positives
    }

    return Dataset.from_dict(df)

pairs = prepare_training_dataset(queries, corpus, relevant_docs)
pairs

Dataset({
    features: ['anchor', 'positive'],
    num_rows: 2459
})

In [31]:
pairs[0]

{'anchor': 'Quy định này áp dụng cho những đối tượng nào liên quan đến chính sách dân số và kế hoạch hóa gia đình?',
 'positive': 'Đối tượng áp dụng\n\nQuy định này quy định tiêu chuẩn, điều kiện, thẩm quyền xem xét kết nạp lại vào Đảng đối với đảng viên đã bị đưa ra khỏi Đảng do vi phạm chính sách dân số và kế hoạch hoá gia đình, kết nạp quần chúng vi phạm chính sách dân số và kế hoạch hoá gia đình có nguyện vọng phấn đấu vào Đảng.\n\nĐiều 2. Những trường hợp sinh con không bị coi là vi phạm chính sách dân số và kế hoạch hoá gia đình\n\n1. Cặp vợ chồng sinh con thứ ba, nếu cả hai hoặc một trong hai người thuộc dân tộc có số dân dưới 10.000 người hoặc thuộc dân tộc có nguy cơ suy giảm số dân (tỉ lệ nhỏ hơn hoặc bằng tỉ lệ chết) theo công bố chính thức của Bộ Kế hoạch và Đầu tư.\n\n2. Cặp vợ chồng sinh lần thứ nhất mà sinh ba con trở lên.\n\n3. Cặp vợ chồng đã có một con đẻ, sinh lần thứ hai mà sinh hai con trở lên.\n\n4. Cặp vợ chồng sinh lần thứ ba trở lên, nếu tại thời điểm sinh chỉ 

In [32]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small
inner_train_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
)

In [33]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# define training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="sample", # output directory and hugging face model ID
    num_train_epochs=1,                         # number of epochs
    per_device_train_batch_size=8,             # train batch size
    gradient_accumulation_steps=4,             # for a global batch size of 512
    per_device_eval_batch_size=4,              # evaluation batch size
    #gradient_checkpointing=True,
    warmup_ratio=0.1,                           # warmup ratio
    learning_rate=2e-5,                         # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",                 # use constant learning rate scheduler
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    #tf32=True,                                  # use tf32 precision
    bf16=True,                                  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="steps",                      # evaluate after each epoch
    #save_strategy="epoch",                      # save after each epoch
    save_steps = 500,
    logging_steps=10,                           # log every 10 steps
    save_total_limit=3,                         # save only the last 3 models
    load_best_model_at_end=True,                # load the best model when training ends
    metric_for_best_model="eval_dim_768_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension
)

In [34]:
from sentence_transformers import SentenceTransformerTrainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,  # training arguments
    train_dataset=pairs,
    loss=train_loss,
    evaluator=evaluator,
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [35]:
# start training, the model will be automatically saved to the hub and the output directory
trainer.train()

# save the best model
trainer.save_model()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mluu92241[0m ([33mluu92241-hanoi-university-of-science-and-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
10,5.2058,No log,0.45181,0.665311,0.736885,0.812119,0.45181,0.22177,0.147377,0.081212,0.45181,0.665311,0.736885,0.812119,0.631425,0.573517,0.579415,0.446116,0.653924,0.727125,0.803985,0.446116,0.217975,0.145425,0.080399,0.446116,0.653924,0.727125,0.803985,0.624372,0.56692,0.573096,0.414396,0.636844,0.707605,0.786905,0.414396,0.212281,0.141521,0.078691,0.414396,0.636844,0.707605,0.786905,0.600391,0.540635,0.546959,0.390809,0.588044,0.672631,0.760065,0.390809,0.196015,0.134526,0.076007,0.390809,0.588044,0.672631,0.760065,0.571187,0.511191,0.517903,0.321269,0.501017,0.572997,0.666124,0.321269,0.167006,0.114599,0.066612,0.321269,0.501017,0.572997,0.666124,0.487354,0.430866,0.439904,0.487354
20,3.3976,No log,0.472956,0.688898,0.759658,0.840179,0.472956,0.229633,0.151932,0.084018,0.472956,0.688898,0.759658,0.840179,0.655933,0.596966,0.602303,0.47621,0.681578,0.753152,0.834892,0.47621,0.227193,0.15063,0.083489,0.47621,0.681578,0.753152,0.834892,0.653377,0.595467,0.60097,0.443676,0.661244,0.739325,0.816999,0.443676,0.220415,0.147865,0.0817,0.443676,0.661244,0.739325,0.816999,0.629497,0.569503,0.575349,0.416836,0.630744,0.700691,0.786499,0.416836,0.210248,0.140138,0.07865,0.416836,0.630744,0.700691,0.786499,0.598614,0.538793,0.545439,0.352582,0.551037,0.615291,0.707605,0.352582,0.183679,0.123058,0.07076,0.352582,0.551037,0.615291,0.707605,0.525111,0.467267,0.476059,0.525111
30,3.4291,No log,0.470923,0.694591,0.771858,0.843839,0.470923,0.23153,0.154372,0.084384,0.470923,0.694591,0.771858,0.843839,0.657437,0.597594,0.603228,0.477837,0.687271,0.764132,0.840992,0.477837,0.22909,0.152826,0.084099,0.477837,0.687271,0.764132,0.840992,0.657704,0.599101,0.604738,0.445303,0.672224,0.747458,0.825539,0.445303,0.224075,0.149492,0.082554,0.445303,0.672224,0.747458,0.825539,0.635492,0.574523,0.580499,0.422123,0.640911,0.716145,0.799512,0.422123,0.213637,0.143229,0.079951,0.422123,0.640911,0.716145,0.799512,0.608445,0.547493,0.554166,0.368849,0.563644,0.637658,0.729972,0.368849,0.187881,0.127532,0.072997,0.368849,0.563644,0.637658,0.729972,0.5438,0.484862,0.493522,0.5438
40,3.4904,No log,0.475397,0.708825,0.781212,0.855226,0.475397,0.236275,0.156242,0.085523,0.475397,0.708825,0.781212,0.855226,0.666844,0.606236,0.611411,0.476617,0.702725,0.776739,0.854819,0.476617,0.234242,0.155348,0.085482,0.476617,0.702725,0.776739,0.854819,0.664245,0.603247,0.608395,0.450183,0.678731,0.764538,0.843026,0.450183,0.226244,0.152908,0.084303,0.450183,0.678731,0.764538,0.843026,0.646405,0.583381,0.588583,0.436763,0.645791,0.727938,0.816592,0.436763,0.215264,0.145588,0.081659,0.436763,0.645791,0.727938,0.816592,0.622539,0.560875,0.566937,0.379829,0.578691,0.652298,0.752338,0.379829,0.192897,0.13046,0.075234,0.379829,0.578691,0.652298,0.752338,0.559524,0.498642,0.506968,0.559524
50,3.379,No log,0.47987,0.709638,0.785685,0.857259,0.47987,0.236546,0.157137,0.085726,0.47987,0.709638,0.785685,0.857259,0.670198,0.609933,0.615229,0.475397,0.705978,0.781212,0.859699,0.475397,0.235326,0.156242,0.08597,0.475397,0.705978,0.781212,0.859699,0.667209,0.605528,0.61051,0.45669,0.686458,0.771045,0.847499,0.45669,0.228819,0.154209,0.08475,0.45669,0.686458,0.771045,0.847499,0.651505,0.588711,0.593963,0.435136,0.653518,0.732818,0.817812,0.435136,0.217839,0.146564,0.081781,0.435136,0.653518,0.732818,0.817812,0.623706,0.561766,0.568281,0.383489,0.584791,0.666124,0.762912,0.383489,0.19493,0.133225,0.076291,0.383489,0.584791,0.666124,0.762912,0.566707,0.504714,0.512996,0.566707
60,2.5151,No log,0.483123,0.714111,0.789752,0.859699,0.483123,0.238037,0.15795,0.08597,0.483123,0.714111,0.789752,0.859699,0.673225,0.613163,0.618397,0.481903,0.708825,0.783652,0.860512,0.481903,0.236275,0.15673,0.086051,0.481903,0.708825,0.783652,0.860512,0.671043,0.610296,0.615413,0.458723,0.689305,0.775112,0.848312,0.458723,0.229768,0.155022,0.084831,0.458723,0.689305,0.775112,0.848312,0.653621,0.591128,0.596561,0.437576,0.657178,0.741358,0.823505,0.437576,0.219059,0.148272,0.082351,0.437576,0.657178,0.741358,0.823505,0.62789,0.565502,0.571812,0.389589,0.592517,0.674664,0.769012,0.389589,0.197506,0.134933,0.076901,0.389589,0.592517,0.674664,0.769012,0.571992,0.509842,0.518016,0.571992
70,2.8592,No log,0.48353,0.715738,0.789752,0.859292,0.48353,0.238579,0.15795,0.085929,0.48353,0.715738,0.789752,0.859292,0.673932,0.614148,0.619504,0.482717,0.710451,0.787312,0.861326,0.482717,0.236817,0.157462,0.086133,0.482717,0.710451,0.787312,0.861326,0.672354,0.611699,0.61682,0.46157,0.690118,0.777552,0.848719,0.46157,0.230039,0.15551,0.084872,0.46157,0.690118,0.777552,0.848719,0.655264,0.593147,0.598642,0.43961,0.658804,0.742985,0.825132,0.43961,0.219601,0.148597,0.082513,0.43961,0.658804,0.742985,0.825132,0.629716,0.567388,0.573625,0.390403,0.595771,0.676698,0.772672,0.390403,0.19859,0.13534,0.077267,0.390403,0.595771,0.676698,0.772672,0.574506,0.511952,0.519903,0.574506


# Re-evaluate

In [36]:
from sentence_transformers import SentenceTransformer
import torch
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Evaluate the model
results = evaluator(fine_tuned_model)

for k,v in results.items():
    print(k, v)

dim_768_cosine_accuracy@1 0.48393655957706383
dim_768_cosine_accuracy@3 0.7157381049206994
dim_768_cosine_accuracy@5 0.7897519316795445
dim_768_cosine_accuracy@10 0.8605124034160228
dim_768_cosine_precision@1 0.48393655957706383
dim_768_cosine_precision@3 0.23857936830689977
dim_768_cosine_precision@5 0.1579503863359089
dim_768_cosine_precision@10 0.08605124034160226
dim_768_cosine_recall@1 0.48393655957706383
dim_768_cosine_recall@3 0.7157381049206994
dim_768_cosine_recall@5 0.7897519316795445
dim_768_cosine_recall@10 0.8605124034160228
dim_768_cosine_ndcg@10 0.674188755112397
dim_768_cosine_mrr@10 0.6141569356494113
dim_768_cosine_map@100 0.6194318226528878
dim_512_cosine_accuracy@1 0.4843432289548597
dim_512_cosine_accuracy@3 0.7092313948759659
dim_512_cosine_accuracy@5 0.7864985766571777
dim_512_cosine_accuracy@10 0.862545750305002
dim_512_cosine_precision@1 0.4843432289548597
dim_512_cosine_precision@3 0.23641046495865525
dim_512_cosine_precision@5 0.15729971533143552
dim_512_cosi

# Push model to hub

In [None]:
from huggingface_hub import login

login(token="", add_to_git_credential=True)  # ADD YOUR TOKEN HERE

In [38]:
# push model to hub
trainer.model.push_to_hub("ltk_embedding")

model.safetensors:   0%|          | 0.00/540M [00:00<?, ?B/s]

'https://huggingface.co/khanglt0004/ltk_embedding/commit/f4d64f308e223af63fcb5cf3fe3f1ca4ef6f6d0b'