# Dense Retriever (SBERT)

In [1]:
! pip install sentence-transformers datasets


## Load Model

- I will use `all-MiniLM-L6-v2`, a sentence-transformers model that maps sentences and paragraphs to a 384-dimensional dense vector space. It can be used for tasks such as clustering or semantic search.

https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

In [2]:
from sentence_transformers import SentenceTransformer, util
import torch


model = SentenceTransformer("all-MiniLM-L6-v2")


In [3]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)
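
The printed pipeline has three stages: a transformer that produces one vector per token, a pooling layer with `pooling_mode_mean_tokens` enabled, and a `Normalize()` layer. The following is a minimal pure-Python sketch of what the last two stages compute. The helper names `mean_pool` and `l2_normalize` and the 2-dimensional toy vectors are assumptions made for illustration, not part of the library.

```python
import math


def mean_pool(token_vectors, attention_mask):
    """Average the token vectors whose mask entry is 1 (padding is skipped)."""
    dim = len(token_vectors[0])
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


def l2_normalize(vector):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


# Toy input: three "tokens" with 2-dimensional vectors; the last one is padding.
token_vectors = [[1.0, 2.0], [5.0, 6.0], [9.0, 9.0]]
attention_mask = [1, 1, 0]

pooled = mean_pool(token_vectors, attention_mask)
sentence_embedding = l2_normalize(pooled)
print(pooled)              # [3.0, 4.0]
print(sentence_embedding)  # [0.6, 0.8]
```

In the real model the same idea is applied to 384-dimensional token vectors, which is why every sentence ends up as a single unit-length vector of size 384.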

## Load & Prepare Data

- I will use `MedQuAD`, a dataset of questions and answers in the medical field. Only the first 2,000 rows of the training split are loaded.

https://huggingface.co/datasets/lavita/MedQuAD

In [4]:
from datasets import load_dataset

dataset = load_dataset("lavita/MedQuAD", split="train[:2000]")

dataset


Dataset({
    features: ['document_id', 'document_source', 'document_url', 'category', 'umls_cui', 'umls_semantic_types', 'umls_semantic_group', 'synonyms', 'question_id', 'question_focus', 'question_type', 'question', 'answer'],
    num_rows: 2000
})

In [5]:
answers = dataset["answer"]
questions = dataset["question"]

In [6]:
print(questions[0])
print("=="*30)
print(answers[0])

What is (are) keratoderma with woolly hair ?
Keratoderma with woolly hair is a group of related conditions that affect the skin and hair and in many cases increase the risk of potentially life-threatening heart problems. People with these conditions have hair that is unusually coarse, dry, fine, and tightly curled. In some cases, the hair is also sparse. The woolly hair texture typically affects only scalp hair and is present from birth. Starting early in life, affected individuals also develop palmoplantar keratoderma, a condition that causes skin on the palms of the hands and the soles of the feet to become thick, scaly, and calloused.  Cardiomyopathy, which is a disease of the heart muscle, is a life-threatening health problem that can develop in people with keratoderma with woolly hair. Unlike the other features of this condition, signs and symptoms of cardiomyopathy may not appear until adolescence or later. Complications of cardiomyopathy can include an abnormal heartbeat (arrhyt

## Embedding

- Embed the answers only

In [7]:
doc_embeddings_aswer = model.encode(answers, convert_to_tensor=True)
doc_embeddings_aswer.shape

torch.Size([2000, 384])

- Combine each question with its answer and embed the combined text

In [8]:
combined_text = [q + " " + a for q, a in zip(questions, answers)]

# Embed the combined documents
doc_embeddings = model.encode(combined_text, convert_to_tensor=True)
doc_embeddings.shape

torch.Size([2000, 384])

In [9]:
import torch
from sentence_transformers import SentenceTransformer, util


def retrieve_documents(query, k, model, doc_embeddings, questions, answers):
    """
    Retrieve top-k relevant answers to a query using dense embeddings.

    Args:
        query (str): The input query string.
        k (int): Number of top results to return.
        model (SentenceTransformer): Pretrained Sentence-BERT model.
        doc_embeddings (torch.Tensor): Precomputed document embeddings.
        questions (List[str]): List of original questions.
        answers (List[str]): List of original answers.

    Returns:
        List[Tuple[str, str, float]]: A list of tuples (question, answer, similarity score).
    """
    # Encode query
    query_embedding = model.encode(query, convert_to_tensor=True)

    # Compute cosine similarity
    cos_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)
    top_results = torch.topk(cos_scores, k=k)

    # Collect results (convert tensor indices to plain ints before indexing the lists)
    results = []
    for score, idx in zip(top_results[0][0], top_results[1][0]):
        results.append((questions[int(idx)], answers[int(idx)], score.item()))

    return results
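
The function above ranks documents by cosine similarity and keeps the `k` highest scores. As a rough illustration of that ranking step without any tensors, here is a pure-Python sketch. The names `cosine` and `top_k` and the toy 2-dimensional vectors are hypothetical and exist only for this example.

```python
import heapq
import math


def cosine(u, v):
    """Cosine similarity: dot product divided by the product of the norms."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def top_k(query_vec, doc_vecs, k):
    """Return (index, score) pairs for the k most similar documents, best first."""
    scored = [(i, cosine(query_vec, d)) for i, d in enumerate(doc_vecs)]
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


# Toy corpus of three "documents" and one query.
doc_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query_vec = [2.0, 0.0]

hits = top_k(query_vec, doc_vecs, k=2)
print(hits)  # document 0 first, then document 2
```

Because the model ends with a `Normalize()` layer, its embeddings already have unit length, so for those vectors the cosine score reduces to a plain dot product.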

In [10]:
query = "How many people are affected by primary ciliary dyskinesia ?"
top_k = 5

# Run retrieval
results = retrieve_documents(query, top_k, model, doc_embeddings_aswer, questions, answers)

# Print results
for q, a, score in results:
    print(f"\nScore: {score:.4f}")
    print(f"Q: {q}")
    print(f"A: {a}")


Score: 0.8613
Q: How many people are affected by primary ciliary dyskinesia ?
A: Primary ciliary dyskinesia occurs in approximately 1 in 16,000 individuals.

Score: 0.6707
Q: What are the treatments for primary ciliary dyskinesia ?
A: These resources address the diagnosis or management of primary ciliary dyskinesia:  - Gene Review: Gene Review: Primary Ciliary Dyskinesia  - Genetic Testing Registry: Ciliary dyskinesia, primary, 17  - Genetic Testing Registry: Kartagener syndrome  - Genetic Testing Registry: Primary ciliary dyskinesia   These resources from MedlinePlus offer information about the diagnosis and management of various health conditions:  - Diagnostic Tests  - Drug Therapy  - Surgery and Rehabilitation  - Genetic Counseling   - Palliative Care

Score: 0.6320
Q: How many people are affected by familial paroxysmal kinesigenic dyskinesia ?
A: Familial paroxysmal kinesigenic dyskinesia is estimated to occur in 1 in 150,000 individuals. For unknown reasons, this condition affec

## Save

In [11]:
import joblib

# Save embeddings and original data
# (joblib pickles the tensors; the .pt suffix here is only a naming choice)
joblib.dump(doc_embeddings_aswer, 'doc_embeddings_aswer.pt')  # Tensor: answers only
joblib.dump(doc_embeddings, 'doc_embeddings.pt')              # Tensor: questions + answers
joblib.dump(questions, 'questions.pkl')                       # List of question strings
joblib.dump(answers, 'answers.pkl')                           # List of answer strings

['answers.pkl']
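
Retrieval relies on row `i` of each embedding tensor lining up with `questions[i]` and `answers[i]`, so a quick length check before saving or after loading can catch a mismatch early. The helper below is a hypothetical sketch, not something the notebook defines.

```python
def check_alignment(num_embedding_rows, questions, answers):
    """Raise ValueError unless embeddings, questions and answers have equal length."""
    sizes = {
        "embeddings": num_embedding_rows,
        "questions": len(questions),
        "answers": len(answers),
    }
    if len(set(sizes.values())) != 1:
        raise ValueError(f"Misaligned artifacts: {sizes}")
    return True


# Toy check with two rows of each kind.
print(check_alignment(2, ["q1", "q2"], ["a1", "a2"]))  # True
```

With the notebook's variables this could be called as `check_alignment(doc_embeddings.shape[0], questions, answers)`.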

## Inference

In [12]:
import torch
import joblib
from sentence_transformers import SentenceTransformer, util

# Load saved items
doc_embeddings_aswer = joblib.load('doc_embeddings_aswer.pt')
doc_embeddings = joblib.load('doc_embeddings.pt')
questions = joblib.load('questions.pkl')
answers = joblib.load('answers.pkl')

# Load the model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Retrieval function (same as before)
def retrieve_documents(query, top_k, model, doc_embeddings, questions, answers):
    query_embedding = model.encode(query, convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)
    top_results = torch.topk(cos_scores, k=top_k)

    results = []
    for score, idx in zip(top_results[0][0], top_results[1][0]):
        # Convert tensor indices to plain ints before indexing the lists
        results.append((questions[int(idx)], answers[int(idx)], score.item()))
    return results
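
The notebook builds two indexes (answers only, and question plus answer) and queries each one by hand. One way to compare them more systematically is hit rate at `k`: the fraction of queries whose known correct document appears among the top `k` results. The sketch below assumes a variant of the retrieval step that returns document indices rather than text. The function name `hit_rate_at_k` and the toy rankings are made up for illustration.

```python
def hit_rate_at_k(ranked_ids_per_query, gold_ids, k):
    """Fraction of queries whose gold document id is among the top-k ranked ids."""
    if len(ranked_ids_per_query) != len(gold_ids):
        raise ValueError("Need exactly one gold id per query")
    hits = sum(
        1
        for ranked, gold in zip(ranked_ids_per_query, gold_ids)
        if gold in ranked[:k]
    )
    return hits / len(gold_ids)


# Hypothetical rankings for three queries, best match first.
ranked = [[4, 1, 7], [2, 9, 3], [5, 0, 8]]
gold = [4, 3, 6]

print(hit_rate_at_k(ranked, gold, k=1))  # 1 of 3 queries hit
print(hit_rate_at_k(ranked, gold, k=3))  # 2 of 3 queries hit
```

Note that queries copied from the dataset's own questions favor the combined index, since each combined document contains the question text itself. Held-out or paraphrased queries would give a fairer comparison.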

In [13]:
# Use for inference
query = "What are the genetic changes related to capillary malformation-arteriovenous malformation syndrome ?"
top_k = 5

results = retrieve_documents(query, top_k, model, doc_embeddings, questions, answers)

for q, a, score in results:
    print(f"\nScore: {score:.4f}")
    print(f"Q: {q}")
    print(f"A: {a}")


Score: 0.8456
Q: Is capillary malformation-arteriovenous malformation syndrome inherited ?
A: This condition is inherited in an autosomal dominant pattern, which means one copy of the altered gene in each cell is sufficient to cause the disorder.  In most cases, an affected person inherits the mutation from one affected parent. Other cases result from new mutations in the gene and occur in people with no history of the disorder in their family.

Score: 0.7341
Q: What are the treatments for capillary malformation-arteriovenous malformation syndrome ?
A: These resources address the diagnosis or management of CM-AVM:  - Gene Review: Gene Review: RASA1-Related Disorders  - Genetic Testing Registry: Capillary malformation-arteriovenous malformation   These resources from MedlinePlus offer information about the diagnosis and management of various health conditions:  - Diagnostic Tests  - Drug Therapy  - Surgery and Rehabilitation  - Genetic Counseling   - Palliative Care

Score: 0.7070
Q: W