# 3) Dense Retriever Training with Hard Negatives (PyTorch + FAISS)

In [1]:
%%capture
!pip -q install --upgrade pip
!pip -q install datasets transformers sentence-transformers faiss-cpu rank-bm25 torchmetrics scikit-learn lightgbm langdetect unidecode pandas matplotlib tqdm nltk

In [2]:

import numpy as np, torch, faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"
SEED=42; np.random.seed(SEED); torch.manual_seed(SEED)

<torch._C.Generator at 0x7996f7f50b90>

In [3]:

train_ds = load_dataset("ms_marco", "v2.1", split="train[:2%]")
pairs = []
for r in train_ds:
    q = r["query"]
    doc = None
    if r.get("wellFormedAnswers") and len(r["wellFormedAnswers"])>0:
        doc = r["wellFormedAnswers"][0]
    elif r.get("passages") and r["passages"]["passage_text"]:
        doc = r["passages"]["passage_text"][0]
    if doc: pairs.append(InputExample(texts=[q, doc]))
len(pairs)

README.md: 0.00B [00:00, ?B/s]

v2.1/validation-00000-of-00001.parquet:   0%|          | 0.00/210M [00:00<?, ?B/s]

v2.1/train-00000-of-00007.parquet:   0%|          | 0.00/240M [00:00<?, ?B/s]

v2.1/train-00001-of-00007.parquet:   0%|          | 0.00/240M [00:00<?, ?B/s]

v2.1/train-00002-of-00007.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

v2.1/train-00003-of-00007.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

v2.1/train-00004-of-00007.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

v2.1/train-00005-of-00007.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

v2.1/train-00006-of-00007.parquet:   0%|          | 0.00/244M [00:00<?, ?B/s]

v2.1/test-00000-of-00001.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/101093 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/808731 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/101092 [00:00<?, ? examples/s]

16175

In [4]:

model = SentenceTransformer("sentence-transformers/msmarco-distilbert-base-tas-b", device=device)
loader = DataLoader(pairs, batch_size=128, shuffle=True, drop_last=True)
loss = losses.MultipleNegativesRankingLoss(model)
model.fit([(loader, loss)], epochs=1, warmup_steps=100, output_path="artifacts_dense_mnr")

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/265M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mchanderjayaraman[0m ([33mchanderjayaraman-yahooinc[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss


In [5]:

corpus = [p.texts[1] for p in pairs[:40000]]
queries = [p.texts[0] for p in pairs[:2000]]
doc_vec = model.encode(corpus, batch_size=128, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True).astype("float32")
q_vec = model.encode(queries, batch_size=128, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True).astype("float32")
index = faiss.IndexFlatIP(doc_vec.shape[1]); index.add(doc_vec)
scores, idx = index.search(q_vec, 10)
hits = sum([np.any(row < min(len(corpus), len(queries))) for row in idx])
print("Approx Recall@10:", hits/len(queries))

Batches:   0%|          | 0/127 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Approx Recall@10: 0.974
