In [2]:
!pip install -q -r model_training/requirements.txt

In [1]:
from general_utils import load_config
import torch

CONFIG = load_config()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
from general_utils import S3Manager
s3_client = S3Manager.get_client()

In [5]:
S3Manager.download_files(
    s3_client,
    'model_training/data',
     ["training.json",
     "test_queries.jsonl",
     "corpus.jsonl",
     "test_qrels.jsonl"
    ],
    "medical-qa-data")

----

In [6]:
import subprocess

lr = float(CONFIG['LR'])
epochs = int(CONFIG['EPOCHS'])
warmup = float(CONFIG['WARMUP_RATIO'])
batch_size = int(CONFIG['BATCH_SIZE'])
embedding_model = "bge-small-en-v1.5"
query_instruction = CONFIG["QUERY_INSTRUCTION_AT_RETRIEVAL"]

cmd = [
    "torchrun", "--nproc_per_node", "1",
    "-m", "FlagEmbedding.finetune.embedder.encoder_only.base",
    "--model_name_or_path", f"BAAI/{embedding_model}",
    "--train_data", "model_training/data/training.json",
    "--query_instruction_for_retrieval", query_instruction,
    "--output_dir", "model_training/model",
    "--learning_rate", str(lr),
    "--fp16",
    "--num_train_epochs", str(epochs),
    "--per_device_train_batch_size", str(batch_size),
    "--query_max_len", "256",
    "--passage_max_len", "512",
    "--warmup_ratio", str(warmup),
    "--normalize_embeddings", "True",
    "--logging_steps", "10",
]

# see exactly what you'll run:
print(" ".join(cmd))

# actually run it
result = subprocess.run(cmd, check=True)

torchrun --nproc_per_node 1 -m FlagEmbedding.finetune.embedder.encoder_only.base --model_name_or_path BAAI/bge-small-en-v1.5 --train_data model_training/data/training.json --query_instruction_for_retrieval Represent this sentence for searching relevant passages: --output_dir model_training/model --learning_rate 1e-05 --fp16 --num_train_epochs 1 --per_device_train_batch_size 20 --query_max_len 256 --passage_max_len 512 --warmup_ratio 0.05 --normalize_embeddings True --logging_steps 10


2025-07-11 22:16:43.835107: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752272203.851190    8285 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752272203.856200    8285 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-11 22:16:43.871930: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
07/11/2025 22:16:47 - INFO - FlagEmbedding.abc.finetune.embedder.AbsRunner -   Training/evaluation paramete

{'loss': 0.5164, 'grad_norm': 11.60119915008545, 'learning_rate': 2.3333333333333336e-06, 'epoch': 0.02}


  3%|▎         | 20/599 [00:14<06:50,  1.41it/s]

{'loss': 0.5128, 'grad_norm': 7.791902542114258, 'learning_rate': 5.666666666666667e-06, 'epoch': 0.03}


  4%|▎         | 22/599 [00:16<06:49,  1.41it/s]W0711 22:17:08.925000 8283 site-packages/torch/distributed/elastic/agent/server/api.py:719] Received 2 death signal, shutting down workers
W0711 22:17:08.926000 8283 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 8285 closing signal SIGINT
[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py", line 31, in <module>
[rank0]:     main()
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/FlagEmbedding/finetune/embedder/encoder_only/base/__main__.py", line 27, in main
[rank0]:     runner.run()
[rank0]:   File "/opt/conda/lib/python3.12/site-packages/FlagEmbedding/abc/finetune/embedder/AbsRunner.py", line 149, in run
[rank0]:     self.trainer.train(resume_from_checkpoint=se

KeyboardInterrupt: 

Save the model files to S3 for later usage

In [None]:
import os
file_list = []
for root, dirs, files in os.walk('model_training/model/'):
    if root=='model_training/model/':
        for file in files:
            file_list.append(file)

S3Manager.upload_bulk(
    s3_client,
    'model_training/model/',
    file_list,
    "medical-qa-data",
    "finetuned_model/"
)

---

# Evaluate

In [None]:
from datasets import load_dataset

queries = load_dataset("json", data_files="model_training/data/test_queries.jsonl")["train"]
corpus = load_dataset("json", data_files="model_training/data/corpus.jsonl")["train"]
qrels = load_dataset("json", data_files="model_training/data/test_qrels.jsonl")["train"]

queries_text = queries["text"]
corpus_text = [text for text in corpus["text"]]
qrels_dict = {}
for line in qrels:
    if line['qid'] not in qrels_dict:
        qrels_dict[str(line['qid'])] = {}
    for doc in line['docid']:
        qrels_dict[str(line['qid'])][str(doc)] = line['relevance']

In [None]:
from FlagEmbedding.abc.evaluation.utils import evaluate_metrics, evaluate_mrr
from FlagEmbedding import FlagModel
from model_training.utils import Validator

k_values = [3, 5, 10]

raw_name = "BAAI/bge-large-en-v1.5"
finetuned_path = "model_training/model/"

Raw model w/o prompting

In [None]:
raw_model = FlagModel(
    raw_name, 
    query_instruction_for_retrieval="",
    devices=[0],
    use_fp16=True
)


results, _ = Validator.search(raw_model, queries_text, corpus_text, queries)
eval_res = evaluate_metrics(qrels_dict, results, k_values)
mrr = evaluate_mrr(qrels_dict, results, k_values)

for res in eval_res:
    print(res)
print(mrr)

Raw model with prompting

In [None]:
prompted_raw_model = FlagModel(
    raw_name, 
    query_instruction_for_retrieval=CONFIG['QUERY_INSTRUCTION_AT_RETRIEVAL'],
    devices=[0],
    use_fp16=True
)


results, _ = Validator.search(prompted_raw_model, queries_text, corpus_text, queries)
eval_res = evaluate_metrics(qrels_dict, results, k_values)
mrr = evaluate_mrr(qrels_dict, results, k_values)

for res in eval_res:
    print(res)
print(mrr)

Finetuned model with prompting

In [None]:
ft_model = FlagModel(
    finetuned_path, 
    query_instruction_for_retrieval=CONFIG['QUERY_INSTRUCTION_AT_RETRIEVAL'],
    devices=[0],
    use_fp16=True
)

results, corpus_embeddings = Validator.search(ft_model, queries_text, corpus_text, queries)
eval_res = evaluate_metrics(qrels_dict, results, k_values)
mrr = evaluate_mrr(qrels_dict, results, k_values)

for res in eval_res:
    print(res)
print(mrr)

  4%|▎         | 22/599 [00:18<07:58,  1.21it/s]


---

# Ingest into OpenSearch

Finally we will ingest our finetuned embeddings into opensearch so the can be used at inference time

In [3]:
!pip install -q opensearch_py==3.0.0

In [2]:
from model_training.utils import OpenSearchManager
index_name = 'embedding-finetuned-v1'
host = CONFIG['OPENSEARCH_INDEX_URL'].removeprefix('https://').removesuffix('/'+index_name)
opens_mngr = OpenSearchManager(host)

In [3]:
opens_mngr.create_index(index_name)



In [None]:
opens_mngr.bulk_ingestion(index_name, corpus_text, corpus_embeddings)