# Overview

This pipeline handles the fine-tuning of the bge-base-en-v1.5 model. In this notebook I will:
1. download the train and validation data produced by the previous step from S3
2. fine-tune the model for a couple of rounds
3. compare the fine-tuned model with other versions of the same model (for each one I build a FAISS index and run a search for 10 chunks per query, then calculate several metrics)
4. save the embeddings generated by the fine-tuned embedding model to OpenSearch

---

# Setup

In [1]:
!pip install -q -r model_training/requirements.txt

In [2]:
from general_utils import load_config
import torch

CONFIG = load_config()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
from general_utils import S3Manager
s3_client = S3Manager.get_client()

In [4]:
S3Manager.download_files(
    s3_client,
    'model_training/data',
     ["training.json",
     "test_queries.jsonl",
     "corpus.jsonl",
     "test_qrels.jsonl"
    ],
    "medical-qa-data")

---

# Fine tuning
I will use the FlagEmbedding library's implementation of the fine-tuning pipeline, as suggested by the model's developers.

In [5]:
import subprocess

lr = float(CONFIG['LR'])
epochs = int(CONFIG['EPOCHS'])
warmup = float(CONFIG['WARMUP_RATIO'])
batch_size = int(CONFIG['BATCH_SIZE'])
embedding_model = CONFIG['EMBEDDING_MODEL']
query_instruction = CONFIG["QUERY_INSTRUCTION_AT_RETRIEVAL"]

cmd = [
    "torchrun", "--nproc_per_node", "1",
    "-m", "FlagEmbedding.finetune.embedder.encoder_only.base",
    "--model_name_or_path", f"BAAI/{embedding_model}",
    "--train_data", "model_training/data/training.json",
    "--query_instruction_for_retrieval", query_instruction,
    "--output_dir", "model_training/model",
    "--learning_rate", str(lr),
    "--fp16",
    "--num_train_epochs", str(epochs),
    "--per_device_train_batch_size", str(batch_size),
    "--query_max_len", "256",
    "--passage_max_len", "512",
    "--warmup_ratio", str(warmup),
    "--normalize_embeddings", "True",
    "--logging_steps", "10",
]

print(" ".join(cmd))
result = subprocess.run(cmd, check=True)

torchrun --nproc_per_node 1 -m FlagEmbedding.finetune.embedder.encoder_only.base --model_name_or_path BAAI/bge-base-en-v1.5 --train_data model_training/data/training.json --query_instruction_for_retrieval Represent this sentence for searching relevant passages: --output_dir model_training/model --learning_rate 1e-05 --fp16 --num_train_epochs 1 --per_device_train_batch_size 8 --query_max_len 256 --passage_max_len 512 --warmup_ratio 0.05 --normalize_embeddings True --logging_steps 10


2025-07-11 23:27:17.955527: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752276437.971267    2309 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752276437.976325    2309 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-11 23:27:17.991838: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
07/11/2025 23:27:21 - INFO - FlagEmbedding.abc.finetune.embedder.AbsRunner -   Training/evaluation paramete

{'loss': 0.2228, 'grad_norm': 0.2872343063354492, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.01}


  1%|▏         | 20/1498 [00:13<16:05,  1.53it/s]

{'loss': 0.4406, 'grad_norm': 0.5305854082107544, 'learning_rate': 2.133333333333334e-06, 'epoch': 0.01}


  2%|▏         | 30/1498 [00:19<15:59,  1.53it/s]

{'loss': 0.383, 'grad_norm': 4.104724407196045, 'learning_rate': 3.4666666666666672e-06, 'epoch': 0.02}


  3%|▎         | 40/1498 [00:26<15:51,  1.53it/s]

{'loss': 0.2743, 'grad_norm': 12.460972785949707, 'learning_rate': 4.800000000000001e-06, 'epoch': 0.03}


  3%|▎         | 50/1498 [00:33<15:42,  1.54it/s]

{'loss': 0.1724, 'grad_norm': 0.3199245035648346, 'learning_rate': 6.133333333333334e-06, 'epoch': 0.03}


  4%|▍         | 60/1498 [00:39<15:37,  1.53it/s]

{'loss': 0.3378, 'grad_norm': 0.15537141263484955, 'learning_rate': 7.4666666666666675e-06, 'epoch': 0.04}


  5%|▍         | 70/1498 [00:46<15:28,  1.54it/s]

{'loss': 0.0934, 'grad_norm': 14.955554962158203, 'learning_rate': 8.8e-06, 'epoch': 0.05}


  5%|▌         | 80/1498 [00:52<15:04,  1.57it/s]

{'loss': 0.0784, 'grad_norm': 15.120832443237305, 'learning_rate': 9.992972593113141e-06, 'epoch': 0.05}


  6%|▌         | 90/1498 [00:58<15:10,  1.55it/s]

{'loss': 0.0411, 'grad_norm': 0.9340715408325195, 'learning_rate': 9.922698524244555e-06, 'epoch': 0.06}


  7%|▋         | 100/1498 [01:05<15:10,  1.53it/s]

{'loss': 0.0061, 'grad_norm': 0.02038661390542984, 'learning_rate': 9.852424455375968e-06, 'epoch': 0.07}


  7%|▋         | 110/1498 [01:11<15:05,  1.53it/s]

{'loss': 0.0494, 'grad_norm': 0.20235000550746918, 'learning_rate': 9.782150386507379e-06, 'epoch': 0.07}


  8%|▊         | 120/1498 [01:18<14:58,  1.53it/s]

{'loss': 0.0057, 'grad_norm': 0.17647592723369598, 'learning_rate': 9.711876317638792e-06, 'epoch': 0.08}


  9%|▊         | 130/1498 [01:25<14:52,  1.53it/s]

{'loss': 0.1015, 'grad_norm': 1.0885107517242432, 'learning_rate': 9.641602248770204e-06, 'epoch': 0.09}


  9%|▉         | 140/1498 [01:31<14:45,  1.53it/s]

{'loss': 0.0088, 'grad_norm': 0.022232770919799805, 'learning_rate': 9.571328179901617e-06, 'epoch': 0.09}


 10%|█         | 150/1498 [01:38<14:40,  1.53it/s]

{'loss': 0.1201, 'grad_norm': 12.321647644042969, 'learning_rate': 9.50105411103303e-06, 'epoch': 0.1}


 11%|█         | 160/1498 [01:44<14:31,  1.54it/s]

{'loss': 0.0237, 'grad_norm': 0.2354845106601715, 'learning_rate': 9.430780042164443e-06, 'epoch': 0.11}


 11%|█▏        | 170/1498 [01:51<14:25,  1.53it/s]

{'loss': 0.0012, 'grad_norm': 0.3331996500492096, 'learning_rate': 9.360505973295854e-06, 'epoch': 0.11}


 12%|█▏        | 180/1498 [01:57<14:21,  1.53it/s]

{'loss': 0.0686, 'grad_norm': 28.511619567871094, 'learning_rate': 9.290231904427267e-06, 'epoch': 0.12}


 13%|█▎        | 190/1498 [02:04<14:13,  1.53it/s]

{'loss': 0.015, 'grad_norm': 0.08075923472642899, 'learning_rate': 9.219957835558679e-06, 'epoch': 0.13}


 13%|█▎        | 200/1498 [02:10<14:06,  1.53it/s]

{'loss': 0.0788, 'grad_norm': 0.0035928403958678246, 'learning_rate': 9.149683766690092e-06, 'epoch': 0.13}


 14%|█▍        | 210/1498 [02:17<13:44,  1.56it/s]

{'loss': 0.0464, 'grad_norm': 1.007736086845398, 'learning_rate': 9.079409697821505e-06, 'epoch': 0.14}


 15%|█▍        | 220/1498 [02:23<13:52,  1.54it/s]

{'loss': 0.0317, 'grad_norm': 0.02787395939230919, 'learning_rate': 9.009135628952918e-06, 'epoch': 0.15}


 15%|█▌        | 230/1498 [02:30<13:47,  1.53it/s]

{'loss': 0.0139, 'grad_norm': 0.0015129104722291231, 'learning_rate': 8.93886156008433e-06, 'epoch': 0.15}


 16%|█▌        | 240/1498 [02:36<13:39,  1.54it/s]

{'loss': 0.0559, 'grad_norm': 0.14703983068466187, 'learning_rate': 8.868587491215742e-06, 'epoch': 0.16}


 17%|█▋        | 250/1498 [02:43<13:35,  1.53it/s]

{'loss': 0.0048, 'grad_norm': 0.1100415363907814, 'learning_rate': 8.798313422347154e-06, 'epoch': 0.17}


 17%|█▋        | 260/1498 [02:49<13:27,  1.53it/s]

{'loss': 0.0341, 'grad_norm': 13.640917778015137, 'learning_rate': 8.728039353478567e-06, 'epoch': 0.17}


 18%|█▊        | 270/1498 [02:56<13:22,  1.53it/s]

{'loss': 0.0633, 'grad_norm': 0.00011624133912846446, 'learning_rate': 8.657765284609978e-06, 'epoch': 0.18}


 19%|█▊        | 280/1498 [03:02<13:13,  1.54it/s]

{'loss': 0.0344, 'grad_norm': 0.8813758492469788, 'learning_rate': 8.587491215741393e-06, 'epoch': 0.19}


 19%|█▉        | 290/1498 [03:09<13:06,  1.54it/s]

{'loss': 0.004, 'grad_norm': 0.513365626335144, 'learning_rate': 8.517217146872804e-06, 'epoch': 0.19}


 20%|██        | 300/1498 [03:15<12:56,  1.54it/s]

{'loss': 0.0438, 'grad_norm': 0.002510434715077281, 'learning_rate': 8.446943078004218e-06, 'epoch': 0.2}


 21%|██        | 310/1498 [03:22<12:55,  1.53it/s]

{'loss': 0.0208, 'grad_norm': 1.275662899017334, 'learning_rate': 8.376669009135629e-06, 'epoch': 0.21}


 21%|██▏       | 320/1498 [03:28<12:49,  1.53it/s]

{'loss': 0.0832, 'grad_norm': 0.005666928365826607, 'learning_rate': 8.306394940267042e-06, 'epoch': 0.21}


 22%|██▏       | 330/1498 [03:35<12:42,  1.53it/s]

{'loss': 0.0027, 'grad_norm': 0.0005533059593290091, 'learning_rate': 8.236120871398453e-06, 'epoch': 0.22}


 23%|██▎       | 340/1498 [03:42<12:34,  1.53it/s]

{'loss': 0.0842, 'grad_norm': 0.03399350866675377, 'learning_rate': 8.165846802529868e-06, 'epoch': 0.23}


 23%|██▎       | 350/1498 [03:48<12:29,  1.53it/s]

{'loss': 0.0012, 'grad_norm': 0.0014170610811561346, 'learning_rate': 8.09557273366128e-06, 'epoch': 0.23}


 24%|██▍       | 360/1498 [03:55<12:23,  1.53it/s]

{'loss': 0.0004, 'grad_norm': 0.024806996807456017, 'learning_rate': 8.025298664792693e-06, 'epoch': 0.24}


 25%|██▍       | 370/1498 [04:01<12:15,  1.53it/s]

{'loss': 0.0689, 'grad_norm': 4.044215679168701, 'learning_rate': 7.955024595924104e-06, 'epoch': 0.25}


 25%|██▌       | 380/1498 [04:08<12:09,  1.53it/s]

{'loss': 0.0275, 'grad_norm': 0.0018002022989094257, 'learning_rate': 7.884750527055517e-06, 'epoch': 0.25}


 26%|██▌       | 390/1498 [04:14<12:03,  1.53it/s]

{'loss': 0.0842, 'grad_norm': 4.649872779846191, 'learning_rate': 7.814476458186929e-06, 'epoch': 0.26}


 27%|██▋       | 400/1498 [04:21<11:54,  1.54it/s]

{'loss': 0.0275, 'grad_norm': 0.022798078134655952, 'learning_rate': 7.744202389318343e-06, 'epoch': 0.27}


 27%|██▋       | 410/1498 [04:27<11:49,  1.53it/s]

{'loss': 0.0041, 'grad_norm': 1.2741225957870483, 'learning_rate': 7.673928320449755e-06, 'epoch': 0.27}


 28%|██▊       | 420/1498 [04:34<11:44,  1.53it/s]

{'loss': 0.0116, 'grad_norm': 0.00047184681170620024, 'learning_rate': 7.603654251581167e-06, 'epoch': 0.28}


 29%|██▊       | 430/1498 [04:40<11:37,  1.53it/s]

{'loss': 0.0891, 'grad_norm': 0.07220727205276489, 'learning_rate': 7.533380182712579e-06, 'epoch': 0.29}


 29%|██▉       | 440/1498 [04:47<11:30,  1.53it/s]

{'loss': 0.0004, 'grad_norm': 0.18373744189739227, 'learning_rate': 7.463106113843992e-06, 'epoch': 0.29}


 30%|███       | 450/1498 [04:53<11:22,  1.53it/s]

{'loss': 0.0489, 'grad_norm': 2.5354275703430176, 'learning_rate': 7.3928320449754046e-06, 'epoch': 0.3}


 31%|███       | 460/1498 [05:00<11:16,  1.53it/s]

{'loss': 0.0402, 'grad_norm': 14.390166282653809, 'learning_rate': 7.322557976106818e-06, 'epoch': 0.31}


 31%|███▏      | 470/1498 [05:06<11:10,  1.53it/s]

{'loss': 0.0295, 'grad_norm': 0.06840209662914276, 'learning_rate': 7.25228390723823e-06, 'epoch': 0.31}


 32%|███▏      | 480/1498 [05:13<11:04,  1.53it/s]

{'loss': 0.012, 'grad_norm': 0.0005931578925810754, 'learning_rate': 7.182009838369642e-06, 'epoch': 0.32}


 33%|███▎      | 490/1498 [05:19<10:57,  1.53it/s]

{'loss': 0.0043, 'grad_norm': 0.029638569802045822, 'learning_rate': 7.111735769501054e-06, 'epoch': 0.33}


 33%|███▎      | 500/1498 [05:26<10:51,  1.53it/s]07/11/2025 23:32:52 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to model_training/model/checkpoint-500


{'loss': 0.0062, 'grad_norm': 0.01664602942764759, 'learning_rate': 7.041461700632467e-06, 'epoch': 0.33}


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
 34%|███▍      | 510/1498 [05:44<13:07,  1.26it/s]  

{'loss': 0.0531, 'grad_norm': 0.0051517244428396225, 'learning_rate': 6.97118763176388e-06, 'epoch': 0.34}


 35%|███▍      | 520/1498 [05:51<10:42,  1.52it/s]

{'loss': 0.0731, 'grad_norm': 1.578456794959493e-05, 'learning_rate': 6.900913562895292e-06, 'epoch': 0.35}


 35%|███▌      | 530/1498 [05:58<10:29,  1.54it/s]

{'loss': 0.0354, 'grad_norm': 0.008438973687589169, 'learning_rate': 6.830639494026705e-06, 'epoch': 0.35}


 36%|███▌      | 540/1498 [06:04<10:23,  1.54it/s]

{'loss': 0.0567, 'grad_norm': 0.3208337128162384, 'learning_rate': 6.760365425158117e-06, 'epoch': 0.36}


 37%|███▋      | 550/1498 [06:11<10:19,  1.53it/s]

{'loss': 0.0172, 'grad_norm': 0.8457298278808594, 'learning_rate': 6.6900913562895295e-06, 'epoch': 0.37}


 37%|███▋      | 560/1498 [06:17<10:07,  1.54it/s]

{'loss': 0.0555, 'grad_norm': 0.0037350100465118885, 'learning_rate': 6.619817287420942e-06, 'epoch': 0.37}


 38%|███▊      | 570/1498 [06:24<10:04,  1.54it/s]

{'loss': 0.063, 'grad_norm': 0.024603016674518585, 'learning_rate': 6.549543218552354e-06, 'epoch': 0.38}


 39%|███▊      | 580/1498 [06:30<09:58,  1.53it/s]

{'loss': 0.0011, 'grad_norm': 0.8704374432563782, 'learning_rate': 6.479269149683767e-06, 'epoch': 0.39}


 39%|███▉      | 590/1498 [06:37<09:52,  1.53it/s]

{'loss': 0.0079, 'grad_norm': 0.12570367753505707, 'learning_rate': 6.40899508081518e-06, 'epoch': 0.39}


 40%|████      | 600/1498 [06:43<09:44,  1.54it/s]

{'loss': 0.0347, 'grad_norm': 1.631155252456665, 'learning_rate': 6.338721011946592e-06, 'epoch': 0.4}


 41%|████      | 610/1498 [06:50<09:38,  1.53it/s]

{'loss': 0.0003, 'grad_norm': 0.04360812157392502, 'learning_rate': 6.268446943078005e-06, 'epoch': 0.41}


 41%|████▏     | 620/1498 [06:56<09:32,  1.53it/s]

{'loss': 0.0653, 'grad_norm': 24.29164695739746, 'learning_rate': 6.198172874209417e-06, 'epoch': 0.41}


 42%|████▏     | 630/1498 [07:03<09:25,  1.53it/s]

{'loss': 0.0132, 'grad_norm': 0.0032785367220640182, 'learning_rate': 6.127898805340829e-06, 'epoch': 0.42}


 43%|████▎     | 640/1498 [07:09<09:17,  1.54it/s]

{'loss': 0.0295, 'grad_norm': 0.0019707975443452597, 'learning_rate': 6.057624736472241e-06, 'epoch': 0.43}


 43%|████▎     | 650/1498 [07:16<09:13,  1.53it/s]

{'loss': 0.0003, 'grad_norm': 0.002090967958793044, 'learning_rate': 5.987350667603655e-06, 'epoch': 0.43}


 44%|████▍     | 660/1498 [07:22<09:05,  1.54it/s]

{'loss': 0.0433, 'grad_norm': 7.764810288790613e-05, 'learning_rate': 5.9170765987350676e-06, 'epoch': 0.44}


 45%|████▍     | 670/1498 [07:29<09:00,  1.53it/s]

{'loss': 0.0206, 'grad_norm': 0.002593463519588113, 'learning_rate': 5.84680252986648e-06, 'epoch': 0.45}


 45%|████▌     | 680/1498 [07:35<08:52,  1.54it/s]

{'loss': 0.0075, 'grad_norm': 2.8137617111206055, 'learning_rate': 5.776528460997892e-06, 'epoch': 0.45}


 46%|████▌     | 690/1498 [07:42<08:43,  1.54it/s]

{'loss': 0.0001, 'grad_norm': 0.08308595418930054, 'learning_rate': 5.706254392129304e-06, 'epoch': 0.46}


 47%|████▋     | 700/1498 [07:48<08:41,  1.53it/s]

{'loss': 0.0187, 'grad_norm': 0.02304450236260891, 'learning_rate': 5.6359803232607165e-06, 'epoch': 0.47}


 47%|████▋     | 710/1498 [07:55<08:34,  1.53it/s]

{'loss': 0.0053, 'grad_norm': 0.5275506973266602, 'learning_rate': 5.5657062543921305e-06, 'epoch': 0.47}


 48%|████▊     | 720/1498 [08:01<08:27,  1.53it/s]

{'loss': 0.0126, 'grad_norm': 0.0021512776147574186, 'learning_rate': 5.495432185523543e-06, 'epoch': 0.48}


 49%|████▊     | 730/1498 [08:08<08:19,  1.54it/s]

{'loss': 0.0427, 'grad_norm': 5.724028960685246e-05, 'learning_rate': 5.425158116654955e-06, 'epoch': 0.49}


 49%|████▉     | 740/1498 [08:14<08:14,  1.53it/s]

{'loss': 0.0385, 'grad_norm': 26.72345733642578, 'learning_rate': 5.354884047786367e-06, 'epoch': 0.49}


 50%|█████     | 750/1498 [08:21<08:00,  1.56it/s]

{'loss': 0.1036, 'grad_norm': 22.272626876831055, 'learning_rate': 5.2846099789177794e-06, 'epoch': 0.5}


 51%|█████     | 760/1498 [08:27<08:01,  1.53it/s]

{'loss': 0.0016, 'grad_norm': 0.03281472250819206, 'learning_rate': 5.214335910049192e-06, 'epoch': 0.51}


 51%|█████▏    | 770/1498 [08:34<07:55,  1.53it/s]

{'loss': 0.0245, 'grad_norm': 0.0003843794693239033, 'learning_rate': 5.144061841180604e-06, 'epoch': 0.51}


 52%|█████▏    | 780/1498 [08:40<07:46,  1.54it/s]

{'loss': 0.0021, 'grad_norm': 3.085937023162842, 'learning_rate': 5.073787772312018e-06, 'epoch': 0.52}


 53%|█████▎    | 790/1498 [08:47<07:42,  1.53it/s]

{'loss': 0.0181, 'grad_norm': 0.03226613998413086, 'learning_rate': 5.00351370344343e-06, 'epoch': 0.53}


 53%|█████▎    | 800/1498 [08:53<07:34,  1.53it/s]

{'loss': 0.0081, 'grad_norm': 0.0007956930203363299, 'learning_rate': 4.933239634574842e-06, 'epoch': 0.53}


 54%|█████▍    | 810/1498 [09:00<07:29,  1.53it/s]

{'loss': 0.0001, 'grad_norm': 0.1493021845817566, 'learning_rate': 4.862965565706255e-06, 'epoch': 0.54}


 55%|█████▍    | 820/1498 [09:06<07:21,  1.54it/s]

{'loss': 0.0235, 'grad_norm': 0.026879537850618362, 'learning_rate': 4.792691496837668e-06, 'epoch': 0.55}


 55%|█████▌    | 830/1498 [09:13<07:15,  1.53it/s]

{'loss': 0.0228, 'grad_norm': 0.3993976414203644, 'learning_rate': 4.72241742796908e-06, 'epoch': 0.55}


 56%|█████▌    | 840/1498 [09:20<07:08,  1.54it/s]

{'loss': 0.0003, 'grad_norm': 0.0015805925941094756, 'learning_rate': 4.652143359100492e-06, 'epoch': 0.56}


 57%|█████▋    | 850/1498 [09:26<07:02,  1.53it/s]

{'loss': 0.0159, 'grad_norm': 0.0005396092310547829, 'learning_rate': 4.581869290231905e-06, 'epoch': 0.57}


 57%|█████▋    | 860/1498 [09:33<06:56,  1.53it/s]

{'loss': 0.0235, 'grad_norm': 11.488312721252441, 'learning_rate': 4.5115952213633175e-06, 'epoch': 0.57}


 58%|█████▊    | 870/1498 [09:39<06:49,  1.53it/s]

{'loss': 0.001, 'grad_norm': 0.00032758706947788596, 'learning_rate': 4.44132115249473e-06, 'epoch': 0.58}


 59%|█████▊    | 880/1498 [09:46<06:43,  1.53it/s]

{'loss': 0.0001, 'grad_norm': 0.0004016447637695819, 'learning_rate': 4.371047083626142e-06, 'epoch': 0.59}


 59%|█████▉    | 890/1498 [09:52<06:37,  1.53it/s]

{'loss': 0.0059, 'grad_norm': 0.00017022380779962987, 'learning_rate': 4.300773014757555e-06, 'epoch': 0.59}


 60%|██████    | 900/1498 [09:59<06:30,  1.53it/s]

{'loss': 0.0086, 'grad_norm': 0.06104109436273575, 'learning_rate': 4.230498945888967e-06, 'epoch': 0.6}


 61%|██████    | 910/1498 [10:05<06:23,  1.53it/s]

{'loss': 0.0559, 'grad_norm': 0.00017051940085366368, 'learning_rate': 4.1602248770203795e-06, 'epoch': 0.61}


 61%|██████▏   | 920/1498 [10:12<06:17,  1.53it/s]

{'loss': 0.0665, 'grad_norm': 3.943494903069222e-06, 'learning_rate': 4.089950808151793e-06, 'epoch': 0.61}


 62%|██████▏   | 930/1498 [10:18<06:11,  1.53it/s]

{'loss': 0.0523, 'grad_norm': 0.003912900574505329, 'learning_rate': 4.019676739283205e-06, 'epoch': 0.62}


 63%|██████▎   | 940/1498 [10:25<06:03,  1.54it/s]

{'loss': 0.0002, 'grad_norm': 0.0022401735186576843, 'learning_rate': 3.949402670414617e-06, 'epoch': 0.63}


 63%|██████▎   | 950/1498 [10:31<05:58,  1.53it/s]

{'loss': 0.0336, 'grad_norm': 1.3069478273391724, 'learning_rate': 3.87912860154603e-06, 'epoch': 0.63}


 64%|██████▍   | 960/1498 [10:38<05:50,  1.54it/s]

{'loss': 0.0222, 'grad_norm': 0.010462482459843159, 'learning_rate': 3.8088545326774424e-06, 'epoch': 0.64}


 65%|██████▍   | 970/1498 [10:44<05:45,  1.53it/s]

{'loss': 0.0179, 'grad_norm': 0.10365218669176102, 'learning_rate': 3.7385804638088547e-06, 'epoch': 0.65}


 65%|██████▌   | 980/1498 [10:51<05:37,  1.54it/s]

{'loss': 0.0711, 'grad_norm': 6.747337341308594, 'learning_rate': 3.6683063949402673e-06, 'epoch': 0.65}


 66%|██████▌   | 990/1498 [10:57<05:31,  1.53it/s]

{'loss': 0.0563, 'grad_norm': 0.15940849483013153, 'learning_rate': 3.59803232607168e-06, 'epoch': 0.66}


 67%|██████▋   | 1000/1498 [11:04<05:24,  1.53it/s]07/11/2025 23:38:30 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to model_training/model/checkpoint-1000


{'loss': 0.0001, 'grad_norm': 0.007190367206931114, 'learning_rate': 3.5277582572030923e-06, 'epoch': 0.67}


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
 67%|██████▋   | 1010/1498 [11:23<06:30,  1.25it/s]

{'loss': 0.038, 'grad_norm': 5.659230709075928, 'learning_rate': 3.457484188334505e-06, 'epoch': 0.67}


 68%|██████▊   | 1020/1498 [11:29<05:13,  1.53it/s]

{'loss': 0.0005, 'grad_norm': 0.0004365237837191671, 'learning_rate': 3.3872101194659176e-06, 'epoch': 0.68}


 69%|██████▉   | 1030/1498 [11:36<05:05,  1.53it/s]

{'loss': 0.0001, 'grad_norm': 0.0074431621469557285, 'learning_rate': 3.31693605059733e-06, 'epoch': 0.69}


 69%|██████▉   | 1040/1498 [11:42<04:58,  1.53it/s]

{'loss': 0.0503, 'grad_norm': 0.0005899499519728124, 'learning_rate': 3.2466619817287425e-06, 'epoch': 0.69}


 70%|███████   | 1050/1498 [11:49<04:52,  1.53it/s]

{'loss': 0.021, 'grad_norm': 0.013082698918879032, 'learning_rate': 3.1763879128601547e-06, 'epoch': 0.7}


 71%|███████   | 1060/1498 [11:55<04:46,  1.53it/s]

{'loss': 0.0357, 'grad_norm': 0.004151010420173407, 'learning_rate': 3.1061138439915674e-06, 'epoch': 0.71}


 71%|███████▏  | 1070/1498 [12:02<04:38,  1.54it/s]

{'loss': 0.0012, 'grad_norm': 0.030608082190155983, 'learning_rate': 3.03583977512298e-06, 'epoch': 0.71}


 72%|███████▏  | 1080/1498 [12:08<04:29,  1.55it/s]

{'loss': 0.0226, 'grad_norm': 0.014499650336802006, 'learning_rate': 2.9655657062543923e-06, 'epoch': 0.72}


 73%|███████▎  | 1090/1498 [12:15<04:26,  1.53it/s]

{'loss': 0.0011, 'grad_norm': 0.018620064482092857, 'learning_rate': 2.895291637385805e-06, 'epoch': 0.73}


 73%|███████▎  | 1100/1498 [12:21<04:19,  1.53it/s]

{'loss': 0.031, 'grad_norm': 0.010325251147150993, 'learning_rate': 2.8250175685172176e-06, 'epoch': 0.73}


 74%|███████▍  | 1110/1498 [12:28<04:13,  1.53it/s]

{'loss': 0.0312, 'grad_norm': 0.02561090886592865, 'learning_rate': 2.75474349964863e-06, 'epoch': 0.74}


 75%|███████▍  | 1120/1498 [12:34<04:06,  1.54it/s]

{'loss': 0.0603, 'grad_norm': 0.012307639233767986, 'learning_rate': 2.684469430780042e-06, 'epoch': 0.75}


 75%|███████▌  | 1130/1498 [12:41<03:59,  1.53it/s]

{'loss': 0.0029, 'grad_norm': 1.2785083055496216, 'learning_rate': 2.6141953619114548e-06, 'epoch': 0.75}


 76%|███████▌  | 1140/1498 [12:47<03:52,  1.54it/s]

{'loss': 0.0064, 'grad_norm': 3.6805906295776367, 'learning_rate': 2.5439212930428674e-06, 'epoch': 0.76}


 77%|███████▋  | 1150/1498 [12:54<03:47,  1.53it/s]

{'loss': 0.0181, 'grad_norm': 0.0004166977305430919, 'learning_rate': 2.4736472241742797e-06, 'epoch': 0.77}


 77%|███████▋  | 1160/1498 [13:01<03:40,  1.53it/s]

{'loss': 0.0006, 'grad_norm': 0.0038504544645547867, 'learning_rate': 2.4033731553056924e-06, 'epoch': 0.77}


 78%|███████▊  | 1170/1498 [13:07<03:34,  1.53it/s]

{'loss': 0.0, 'grad_norm': 0.029274195432662964, 'learning_rate': 2.333099086437105e-06, 'epoch': 0.78}


 79%|███████▉  | 1180/1498 [13:14<03:27,  1.53it/s]

{'loss': 0.0442, 'grad_norm': 0.12505193054676056, 'learning_rate': 2.2628250175685173e-06, 'epoch': 0.79}


 79%|███████▉  | 1190/1498 [13:20<03:20,  1.54it/s]

{'loss': 0.024, 'grad_norm': 0.6312962770462036, 'learning_rate': 2.19255094869993e-06, 'epoch': 0.79}


 80%|████████  | 1200/1498 [13:27<03:14,  1.53it/s]

{'loss': 0.0362, 'grad_norm': 4.650592745747417e-05, 'learning_rate': 2.122276879831342e-06, 'epoch': 0.8}


 81%|████████  | 1210/1498 [13:33<03:08,  1.53it/s]

{'loss': 0.0169, 'grad_norm': 4.689202785491943, 'learning_rate': 2.052002810962755e-06, 'epoch': 0.81}


 81%|████████▏ | 1220/1498 [13:40<03:01,  1.53it/s]

{'loss': 0.0713, 'grad_norm': 0.0010695622768253088, 'learning_rate': 1.9817287420941675e-06, 'epoch': 0.81}


 82%|████████▏ | 1230/1498 [13:46<02:53,  1.54it/s]

{'loss': 0.0207, 'grad_norm': 0.00016365396731998771, 'learning_rate': 1.9114546732255797e-06, 'epoch': 0.82}


 83%|████████▎ | 1240/1498 [13:53<02:48,  1.53it/s]

{'loss': 0.0017, 'grad_norm': 0.0009108020458370447, 'learning_rate': 1.8411806043569924e-06, 'epoch': 0.83}


 83%|████████▎ | 1250/1498 [13:59<02:42,  1.53it/s]

{'loss': 0.0018, 'grad_norm': 0.000542948953807354, 'learning_rate': 1.770906535488405e-06, 'epoch': 0.83}


 84%|████████▍ | 1260/1498 [14:06<02:35,  1.53it/s]

{'loss': 0.0178, 'grad_norm': 11.916707992553711, 'learning_rate': 1.7006324666198173e-06, 'epoch': 0.84}


 85%|████████▍ | 1270/1498 [14:12<02:28,  1.53it/s]

{'loss': 0.0008, 'grad_norm': 0.1760421097278595, 'learning_rate': 1.63035839775123e-06, 'epoch': 0.85}


 85%|████████▌ | 1280/1498 [14:19<02:21,  1.54it/s]

{'loss': 0.0315, 'grad_norm': 8.214435577392578, 'learning_rate': 1.5600843288826426e-06, 'epoch': 0.85}


 86%|████████▌ | 1290/1498 [14:25<02:15,  1.53it/s]

{'loss': 0.0021, 'grad_norm': 0.005861240904778242, 'learning_rate': 1.4898102600140549e-06, 'epoch': 0.86}


 87%|████████▋ | 1300/1498 [14:32<02:09,  1.53it/s]

{'loss': 0.0339, 'grad_norm': 0.0013637726660817862, 'learning_rate': 1.4195361911454676e-06, 'epoch': 0.87}


 87%|████████▋ | 1310/1498 [14:38<02:02,  1.53it/s]

{'loss': 0.027, 'grad_norm': 5.392150796978967e-06, 'learning_rate': 1.3492621222768798e-06, 'epoch': 0.87}


 88%|████████▊ | 1320/1498 [14:45<01:56,  1.53it/s]

{'loss': 0.0409, 'grad_norm': 0.14365530014038086, 'learning_rate': 1.2789880534082925e-06, 'epoch': 0.88}


 89%|████████▉ | 1330/1498 [14:51<01:49,  1.53it/s]

{'loss': 0.005, 'grad_norm': 8.911045733839273e-05, 'learning_rate': 1.208713984539705e-06, 'epoch': 0.89}


 89%|████████▉ | 1340/1498 [14:58<01:43,  1.53it/s]

{'loss': 0.0254, 'grad_norm': 13.946990013122559, 'learning_rate': 1.1384399156711176e-06, 'epoch': 0.89}


 90%|█████████ | 1350/1498 [15:05<01:36,  1.53it/s]

{'loss': 0.033, 'grad_norm': 0.0004055551835335791, 'learning_rate': 1.06816584680253e-06, 'epoch': 0.9}


 91%|█████████ | 1360/1498 [15:11<01:30,  1.53it/s]

{'loss': 0.0197, 'grad_norm': 3.191890239715576, 'learning_rate': 9.978917779339425e-07, 'epoch': 0.91}


 91%|█████████▏| 1370/1498 [15:18<01:23,  1.53it/s]

{'loss': 0.0386, 'grad_norm': 0.012581920251250267, 'learning_rate': 9.27617709065355e-07, 'epoch': 0.91}


 92%|█████████▏| 1380/1498 [15:24<01:16,  1.54it/s]

{'loss': 0.0201, 'grad_norm': 6.996563911437988, 'learning_rate': 8.573436401967675e-07, 'epoch': 0.92}


 93%|█████████▎| 1390/1498 [15:31<01:10,  1.53it/s]

{'loss': 0.0001, 'grad_norm': 0.0003349929756950587, 'learning_rate': 7.8706957132818e-07, 'epoch': 0.93}


 93%|█████████▎| 1400/1498 [15:37<01:04,  1.53it/s]

{'loss': 0.0006, 'grad_norm': 0.0004560559755191207, 'learning_rate': 7.167955024595925e-07, 'epoch': 0.93}


 94%|█████████▍| 1410/1498 [15:44<00:57,  1.53it/s]

{'loss': 0.0079, 'grad_norm': 1.229417324066162, 'learning_rate': 6.46521433591005e-07, 'epoch': 0.94}


 95%|█████████▍| 1420/1498 [15:50<00:50,  1.53it/s]

{'loss': 0.0471, 'grad_norm': 0.0058892397210001945, 'learning_rate': 5.762473647224174e-07, 'epoch': 0.95}


 95%|█████████▌| 1430/1498 [15:57<00:44,  1.53it/s]

{'loss': 0.0009, 'grad_norm': 0.02863295190036297, 'learning_rate': 5.0597329585383e-07, 'epoch': 0.95}


 96%|█████████▌| 1440/1498 [16:03<00:37,  1.53it/s]

{'loss': 0.0463, 'grad_norm': 6.582356929779053, 'learning_rate': 4.356992269852425e-07, 'epoch': 0.96}


 97%|█████████▋| 1450/1498 [16:10<00:31,  1.53it/s]

{'loss': 0.0379, 'grad_norm': 0.04216351732611656, 'learning_rate': 3.65425158116655e-07, 'epoch': 0.97}


 97%|█████████▋| 1460/1498 [16:16<00:24,  1.53it/s]

{'loss': 0.0326, 'grad_norm': 0.019962292164564133, 'learning_rate': 2.951510892480675e-07, 'epoch': 0.97}


 98%|█████████▊| 1470/1498 [16:23<00:18,  1.53it/s]

{'loss': 0.0268, 'grad_norm': 15.850833892822266, 'learning_rate': 2.2487702037948e-07, 'epoch': 0.98}


 99%|█████████▉| 1480/1498 [16:29<00:11,  1.54it/s]

{'loss': 0.0932, 'grad_norm': 0.00018782119150273502, 'learning_rate': 1.5460295151089248e-07, 'epoch': 0.99}


 99%|█████████▉| 1490/1498 [16:36<00:05,  1.53it/s]

{'loss': 0.0028, 'grad_norm': 0.0001386524672852829, 'learning_rate': 8.432888264230499e-08, 'epoch': 0.99}


100%|██████████| 1498/1498 [16:41<00:00,  1.81it/s]07/11/2025 23:44:07 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to model_training/model/checkpoint-1498
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
100%|██████████| 1498/1498 [16:53<00:00,  1.81it/s]

{'train_runtime': 1013.7359, 'train_samples_per_second': 11.818, 'train_steps_per_second': 1.478, 'train_loss': 0.039941850994951095, 'epoch': 1.0}


100%|██████████| 1498/1498 [16:54<00:00,  1.48it/s]
07/11/2025 23:44:20 - INFO - FlagEmbedding.finetune.embedder.encoder_only.base.trainer -   Saving model checkpoint to model_training/model
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
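
The `learning_rate` values in the log above are consistent with a linear warmup followed by a linear decay. The sketch below reproduces them, assuming the default linear scheduler of the Hugging Face Trainer (the command does not set one explicitly). With a warmup ratio of 0.05 and 1498 steps, the warmup lasts ceil(74.9) = 75 steps. The logged values appear to lag the progress bar by 4 steps (the value logged around step 80 matches scheduler step 76); one possible explanation, which is an assumption and not confirmed by the log, is that fp16 loss scaling skipped a few early optimizer steps. `linear_schedule_lr` is a hypothetical helper written for this illustration.

```python
import math


def linear_schedule_lr(step, base_lr, total_steps, warmup_ratio):
    """Learning rate at a given scheduler step: linear warmup, then linear decay to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warmup phase: ramp linearly from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay phase: ramp linearly from base_lr down to 0 at the last step
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, remaining)


# Values taken from the run above: LR=1e-05, 1498 steps, warmup ratio 0.05.
# Scheduler steps 6, 16, 76 and 86 correspond to the entries logged around
# progress-bar steps 10, 20, 80 and 90 (assuming the 4-step lag described above).
for step in (6, 16, 76, 86):
    print(step, linear_schedule_lr(step, 1e-5, 1498, 0.05))
```

The peak value of 1e-05 is never logged exactly because logging happens every 10 steps and the warmup ends between two log lines.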


Save the model files to S3 for later use.

In [6]:
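import os

model_dir = 'model_training/model/'
# Only collect the files at the top level of the model directory
# (files inside subfolders such as checkpoint-* are skipped)
file_list = [
    name for name in os.listdir(model_dir)
    if os.path.isfile(os.path.join(model_dir, name))
]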

S3Manager.upload_bulk(
    s3_client,
    'model_training/model/',
    file_list,
    "medical-qa-data",
    "finetuned_model/"
)

---

# Evaluate
First, I'll load and generate the needed files

In [7]:
from datasets import load_dataset

queries = load_dataset("json", data_files="model_training/data/test_queries.jsonl")["train"]
corpus = load_dataset("json", data_files="model_training/data/corpus.jsonl")["train"]
qrels = load_dataset("json", data_files="model_training/data/test_qrels.jsonl")["train"]

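queries_text = queries["text"]
corpus_text = list(corpus["text"])

# Build {qid: {docid: relevance}} with string keys for the evaluation below
qrels_dict = {}
for line in qrels:
    qid = str(line['qid'])
    # Test membership with the string key; a non-string qid would never match
    # and the entry would be reset on every row
    if qid not in qrels_dict:
        qrels_dict[qid] = {}
    for doc in line['docid']:
        qrels_dict[qid][str(doc)] = line['relevance']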

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

I will now load the metrics from FlagEmbedding, the library's wrapper for loading models, and a custom wrapper, "Validator", which embeds the queries and the corpus, indexes them in FAISS, and then runs the semantic search to retrieve the top 10 neighbors.
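
To make the search step concrete, here is a minimal sketch of an exact inner-product top-k search over normalized embeddings, which is what a flat inner-product FAISS index computes. It is written in plain Python so it runs without FAISS; it is not the code of `Validator`, whose internals are not shown in this notebook. The function names, the toy vectors, and the nested `{query_id: {doc_id: score}}` output shape are assumptions made for illustration.

```python
import heapq
import math


def normalize(vec):
    """Scale a vector to unit length so the dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def top_k_search(query_vecs, corpus_vecs, k=10):
    """Exact inner-product search: for each query, keep the k best-scoring documents.

    query_vecs and corpus_vecs map an id to an embedding (list of floats).
    Returns {query_id: {doc_id: score}}, with documents ordered by decreasing score.
    """
    corpus_unit = {doc_id: normalize(vec) for doc_id, vec in corpus_vecs.items()}
    results = {}
    for qid, qvec in query_vecs.items():
        q = normalize(qvec)
        scored = (
            (sum(a * b for a, b in zip(q, d)), doc_id)
            for doc_id, d in corpus_unit.items()
        )
        best = heapq.nlargest(k, scored)
        results[qid] = {doc_id: score for score, doc_id in best}
    return results


# Toy 2-dimensional embeddings, made up for illustration only
toy_corpus = {"d1": [1.0, 0.0], "d2": [0.0, 1.0], "d3": [1.0, 1.0]}
toy_queries = {"q1": [0.9, 0.1]}
toy_results = top_k_search(toy_queries, toy_corpus, k=2)
print(toy_results)
```

A brute-force scan like this is only practical for small corpora; an index library does the same comparison far more efficiently.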

In [10]:
!pip install -q opensearch_py==3.0.0

Since this retrieval will be used without any reranking, I will focus my analysis on k = 1 and k = 3 (I still compute the metrics for k = 5 and k = 10).
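
As a reminder of what "@k" means, the following sketch computes Recall@k and MRR@k by hand on a toy example. It is a simplified illustration and not FlagEmbedding's implementation; `recall_at_k`, `mrr_at_k`, and the toy data are hypothetical, and the inputs are assumed to use the same nested `{qid: {docid: value}}` shape as `qrels_dict`.

```python
def ranked_doc_ids(scores, k):
    """Return the ids of the k highest-scoring documents, best first."""
    ordered = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [doc_id for doc_id, _ in ordered[:k]]


def recall_at_k(qrels, results, k):
    """Average share of each query's relevant documents found in its top k."""
    total = 0.0
    for qid, judged in qrels.items():
        relevant = {doc_id for doc_id, rel in judged.items() if rel > 0}
        top = ranked_doc_ids(results.get(qid, {}), k)
        hits = sum(1 for doc_id in top if doc_id in relevant)
        total += hits / len(relevant) if relevant else 0.0
    return total / len(qrels)


def mrr_at_k(qrels, results, k):
    """Average of 1/rank of the first relevant document within the top k (0 if none)."""
    total = 0.0
    for qid, judged in qrels.items():
        relevant = {doc_id for doc_id, rel in judged.items() if rel > 0}
        for rank, doc_id in enumerate(ranked_doc_ids(results.get(qid, {}), k), start=1):
            if doc_id in relevant:
                total += 1.0 / rank
                break
    return total / len(qrels)


# Toy relevance judgments and search scores, made up for illustration only
toy_qrels = {"q1": {"d1": 1, "d2": 1}, "q2": {"d3": 1}}
toy_scores = {
    "q1": {"d1": 0.9, "d4": 0.8, "d2": 0.7},
    "q2": {"d5": 0.9, "d3": 0.6},
}

for k in (1, 3):
    print(k, recall_at_k(toy_qrels, toy_scores, k), mrr_at_k(toy_qrels, toy_scores, k))
```

With no reranker downstream, only the first few positions reach the user, which is why the values at k = 1 and k = 3 matter most here.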

In [15]:
from FlagEmbedding.abc.evaluation.utils import evaluate_metrics, evaluate_mrr
from FlagEmbedding import FlagModel
from model_training.utils import Validator

k_values = [1, 3, 5, 10]
raw_name = f"BAAI/{CONFIG['EMBEDDING_MODEL']}"
finetuned_path = "model_training/model/"

### Raw model w/o prompting
Let's first check the metrics for the raw model without any prompting technique.

In [16]:
raw_model = FlagModel(
    raw_name, 
    query_instruction_for_retrieval="",
    devices=[0],
    use_fp16=True
)


results, _ = Validator.search(raw_model, queries_text, corpus_text, queries)
eval_res = evaluate_metrics(qrels_dict, results, k_values)
mrr = evaluate_mrr(qrels_dict, results, k_values)

for res in eval_res:
    print(res)
print(mrr)

pre tokenize: 100%|██████████| 12/12 [00:00<00:00, 31.16it/s]
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Inference Embeddings: 100%|██████████| 12/12 [00:00<00:00, 40.79it/s]
pre tokenize: 100%|██████████| 74/74 [00:04<00:00, 15.96it/s]
Inference Embeddings: 100%|██████████| 74/74 [00:21<00:00,  3.48it/s]
100%|██████████| 94/94 [00:01<00:00, 91.57it/s]


defaultdict(<class 'list'>, {'NDCG@1': 0.62016, 'NDCG@3': 0.6928, 'NDCG@5': 0.71499, 'NDCG@10': 0.73232})
defaultdict(<class 'list'>, {'MAP@1': 0.55161, 'MAP@3': 0.64691, 'MAP@5': 0.664, 'MAP@10': 0.67436})
defaultdict(<class 'list'>, {'Recall@1': 0.55161, 'Recall@3': 0.74314, 'Recall@5': 0.80182, 'Recall@10': 0.85531})
defaultdict(<class 'list'>, {'P@1': 0.62016, 'P@3': 0.29194, 'P@5': 0.19339, 'P@10': 0.10621})
defaultdict(<class 'list'>, {'MRR@1': 0.62016, 'MRR@3': 0.70572, 'MRR@5': 0.71795, 'MRR@10': 0.72384})


The metrics @1 and @3 are typical for a model that hasn't been fine-tuned on our task and isn't being used as intended, since the queries are not preceded by a prompt.

### Raw model with prompting
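"Prompting" here means that an instruction string is placed in front of every query before it is embedded, while the passages are embedded unchanged. The sketch below illustrates the idea with a hypothetical helper. The instruction actually used in this notebook comes from `CONFIG` and is not shown, so the string below (the one documented for the English BGE v1.5 models) is only an assumed example, and the plain concatenation is my understanding of the default behavior rather than a copy of the library code.

```python
def add_instruction(queries, instruction):
    """Prefix each query with the retrieval instruction (hypothetical helper)."""
    return [f"{instruction}{query}" for query in queries]


# Assumed example instruction; the real value lives in CONFIG.
example_instruction = "Represent this sentence for searching relevant passages: "
example_queries = ["what are the symptoms of anemia?"]

prompted_queries = add_instruction(example_queries, example_instruction)
print(prompted_queries[0])
```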

In [17]:
prompted_raw_model = FlagModel(
    raw_name, 
    query_instruction_for_retrieval=CONFIG['QUERY_INSTRUCTION_AT_RETRIEVAL'],
    devices=[0],
    use_fp16=True
)


results, _ = Validator.search(prompted_raw_model, queries_text, corpus_text, queries)
eval_res = evaluate_metrics(qrels_dict, results, k_values)
mrr = evaluate_mrr(qrels_dict, results, k_values)

for res in eval_res:
    print(res)
print(mrr)

pre tokenize: 100%|██████████| 12/12 [00:00<00:00, 109.04it/s]
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Inference Embeddings: 100%|██████████| 12/12 [00:00<00:00, 31.36it/s]
pre tokenize: 100%|██████████| 74/74 [00:04<00:00, 15.98it/s]
Inference Embeddings: 100%|██████████| 74/74 [00:21<00:00,  3.49it/s]
100%|██████████| 94/94 [00:01<00:00, 92.55it/s]


defaultdict(<class 'list'>, {'NDCG@1': 0.63718, 'NDCG@3': 0.705, 'NDCG@5': 0.72482, 'NDCG@10': 0.74046})
defaultdict(<class 'list'>, {'MAP@1': 0.56459, 'MAP@3': 0.65901, 'MAP@5': 0.67493, 'MAP@10': 0.68473})
defaultdict(<class 'list'>, {'Recall@1': 0.56459, 'Recall@3': 0.75307, 'Recall@5': 0.80703, 'Recall@10': 0.85567})
defaultdict(<class 'list'>, {'P@1': 0.63718, 'P@3': 0.29717, 'P@5': 0.19546, 'P@10': 0.10684})
defaultdict(<class 'list'>, {'MRR@1': 0.63718, 'MRR@3': 0.7189, 'MRR@5': 0.73017, 'MRR@10': 0.7352})


Adding a prompt gives small increases across the metrics (e.g. NDCG@1 goes from 0.620 to 0.637 and NDCG@3 from 0.693 to 0.705). Let's now check our fine-tuned model.

### Finetuned model with prompting

In [18]:
ft_model = FlagModel(
    finetuned_path, 
    query_instruction_for_retrieval=CONFIG['QUERY_INSTRUCTION_AT_RETRIEVAL'],
    devices=[0],
    use_fp16=True
)

results, corpus_embeddings = Validator.search(ft_model, queries_text, corpus_text, queries)
eval_res = evaluate_metrics(qrels_dict, results, k_values)
mrr = evaluate_mrr(qrels_dict, results, k_values)

for res in eval_res:
    print(res)
print(mrr)

pre tokenize: 100%|██████████| 12/12 [00:00<00:00, 28.85it/s] 
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Inference Embeddings: 100%|██████████| 12/12 [00:00<00:00, 31.20it/s]
pre tokenize: 100%|██████████| 74/74 [00:04<00:00, 15.74it/s]
Inference Embeddings: 100%|██████████| 74/74 [00:21<00:00,  3.48it/s]
100%|██████████| 94/94 [00:01<00:00, 92.15it/s]


defaultdict(<class 'list'>, {'NDCG@1': 0.78605, 'NDCG@3': 0.82023, 'NDCG@5': 0.82805, 'NDCG@10': 0.83965})
defaultdict(<class 'list'>, {'MAP@1': 0.70468, 'MAP@3': 0.7826, 'MAP@5': 0.79265, 'MAP@10': 0.80183})
defaultdict(<class 'list'>, {'Recall@1': 0.70468, 'Recall@3': 0.84114, 'Recall@5': 0.86802, 'Recall@10': 0.90394})
defaultdict(<class 'list'>, {'P@1': 0.78605, 'P@3': 0.33311, 'P@5': 0.21155, 'P@10': 0.11452})
defaultdict(<class 'list'>, {'MRR@1': 0.78605, 'MRR@3': 0.84001, 'MRR@5': 0.84423, 'MRR@10': 0.84715})


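The metrics increased greatly. Compared with the prompted raw model, NDCG@1 went from 0.637 to 0.786 and NDCG@3 from 0.705 to 0.820, which is roughly +23% and +16% in relative terms (about 15 and 12 points in absolute terms). We can say that the embedding fine-tune worked correctly and that we ended up with a decent model to be used in a RAG system.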
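The size of the improvement depends on the baseline and on whether it is expressed in absolute points or in relative terms. The short sketch below recomputes both from the NDCG values printed above; the `gains` helper is written just for this illustration.

```python
# NDCG values copied from the three evaluation outputs above.
ndcg = {
    "raw": {"NDCG@1": 0.62016, "NDCG@3": 0.6928},
    "raw + prompt": {"NDCG@1": 0.63718, "NDCG@3": 0.705},
    "fine-tuned + prompt": {"NDCG@1": 0.78605, "NDCG@3": 0.82023},
}


def gains(baseline, candidate):
    """Return {metric: (absolute_gain, relative_gain_in_percent)}."""
    out = {}
    for name, base in baseline.items():
        diff = candidate[name] - base
        out[name] = (round(diff, 5), round(100 * diff / base, 1))
    return out


for baseline_name in ("raw", "raw + prompt"):
    result = gains(ndcg[baseline_name], ndcg["fine-tuned + prompt"])
    print(f"fine-tuned + prompt vs {baseline_name}: {result}")
```

Against the prompted raw model the relative gains come out at about 23% for NDCG@1 and 16% for NDCG@3. Against the unprompted raw model they are a few points larger.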

---

# Ingest into OpenSearch

Finally, we ingest the embeddings produced by the fine-tuned model into OpenSearch so they can be used at inference time.

In [22]:
from model_training.utils import OpenSearchManager
index_name = 'embedding-finetuned-v1'
host = CONFIG['OPENSEARCH_INDEX_URL'].removeprefix('https://').removesuffix('/'+index_name)
opens_mngr = OpenSearchManager(host)

In [23]:
opens_mngr.create_index(index_name)
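The body that `OpenSearchManager.create_index` sends is not shown in this notebook. As a rough sketch, a k-NN enabled index for these vectors could be described as below. The field names `text` and `embedding` and the helper function are assumptions made for illustration; the dimension of 768 matches the output size of the bge-base models.

```python
def knn_index_body(dimension=768, vector_field="embedding", text_field="text"):
    """Build a minimal index definition with one text and one k-NN vector field.

    Field names are hypothetical; adapt them to the real index.
    """
    return {
        "settings": {"index": {"knn": True}},  # enable k-NN search on the index
        "mappings": {
            "properties": {
                text_field: {"type": "text"},
                vector_field: {"type": "knn_vector", "dimension": dimension},
            }
        },
    }


index_body = knn_index_body()
print(index_body)
```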



In [24]:
opens_mngr.bulk_ingestion(index_name, corpus_text, corpus_embeddings)
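The internals of `bulk_ingestion` are not shown here either. One plausible way to prepare the documents is to turn each (chunk, embedding) pair into a bulk action, as sketched below. The function name, the field names and the use of the position as document id are assumptions. The resulting actions have the shape accepted by `opensearchpy.helpers.bulk`, which is only mentioned in a comment so the sketch runs without a cluster.

```python
def build_bulk_actions(index_name, texts, embeddings,
                       vector_field="embedding", text_field="text"):
    """Yield one bulk action per (text, embedding) pair (hypothetical helper)."""
    if len(texts) != len(embeddings):
        raise ValueError("texts and embeddings must have the same length")
    for i, (text, vector) in enumerate(zip(texts, embeddings)):
        yield {
            "_index": index_name,
            "_id": str(i),  # assumption: position in the corpus used as id
            "_source": {
                text_field: text,
                # plain floats so the document is JSON serializable
                vector_field: [float(x) for x in vector],
            },
        }


# Toy data, made up for illustration only.
toy_texts = ["first chunk", "second chunk"]
toy_embeddings = [[0.1, 0.2], [0.3, 0.4]]
actions = list(build_bulk_actions("embedding-finetuned-v1", toy_texts, toy_embeddings))
print(actions[0])

# With a live cluster, the actions could then be sent with something like:
#   from opensearchpy import helpers
#   helpers.bulk(client, actions)
```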

Now we are ready to consume these vectors at inference time