In [None]:
from pymilvus import MilvusClient
from transformers import AutoTokenizer, AutoModel
import torch

class MilvusSearcher:
    def __init__(self, uri: str, collection_name: str):
        self.milvus_client = MilvusClient(uri=uri)
        self.collection_name = collection_name

        # Check collection existence
        if not self.milvus_client.has_collection(self.collection_name):
            print(f"Collection '{self.collection_name}' does not exist.")
            return
        else:
            print(f"Collection '{self.collection_name}' exists.")

        # Check for GPU availability
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained("pritamdeka/S-PubMedBert-MS-MARCO")
        self.model = AutoModel.from_pretrained("pritamdeka/S-PubMedBert-MS-MARCO").to(self.device)
        self.model.to(self.device)

    def encode_text(self, title, abstract):
        """Encode text using PubMedBERT with GPU support."""
        margin = 12
        max_length = 512 - margin # Maximum length for PubMedBERT
        text = f"{title} {abstract}"
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)

        # Move inputs to the same device as model
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.model(**inputs)
    
        # Move embeddings back to CPU for numpy conversion
        embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
        return embeddings

    def search(self, query: str, limit: int = 10):
        query_vector = self.encode_text(query, "")
        search_res = self.milvus_client.search(
            collection_name=self.collection_name,
            data=[query_vector],
            limit=limit,
            search_params={"metric_type": "COSINE", "params": {"nprobe": 10}},
            output_fields=["doc"],
        )
        return search_res

In [4]:
milvus_searcher = MilvusSearcher(uri="./milvus_pmc.db", collection_name = "pmc_trec_2016")

Collection 'pmc_trec_2016' exists.
Using device: cuda


In [8]:
query_1 = "A 58-year-old African-American woman presents to the ER with episodic pressing/burning anterior chest pain that began two days earlier for the first time in her life. The pain started while she was walking, radiates to the back, and is accompanied by nausea, diaphoresis and mild dyspnea, but is not increased on inspiration. The latest episode of pain ended half an hour prior to her arrival. She is known to have hypertension and obesity. She denies smoking, diabetes, hypercholesterolemia, or a family history of heart disease. She currently takes no medications. Physical examination is normal. The EKG shows nonspecific changes."
query_2 = "An 8-year-old male presents in March to the ER with fever up to 39 C, dyspnea and cough for 2 days. He has just returned from a 5 day vacation in Colorado. Parents report that prior to the onset of fever and cough, he had loose stools. He denies upper respiratory tract symptoms. On examination he is in respiratory distress and has bronchial respiratory sounds on the left. A chest x-ray shows bilateral lung infiltrates."

search_res = milvus_searcher.search(query_2)

for hits in search_res:
    for hit in hits:
        print(f"ID: {hit.id}, Distance: {hit.distance}, Doc: {hit.entity.get('doc')}")
    print("-------------------")

ID: 4251250, Distance: 0.9582884311676025, Doc: {'title': 'A Woman with Dyspnea and Hemoptysis', 'abstract': 'A 55-year-old female presented to the emergency department at a small community hospital with cough, fever, dyspnea and blood-streaked sputum. A chest radiograph was ordered. She was diagnosed with pneumonia and discharged home with antibiotics. She returned three days later, afebrile, with worsening dyspnea and gross hemoptysis. She was found to have a murmur reported as chronic but had never been evaluated by echocardiography. A computed tomography chest and echocardiography were performed ( Figure Video'}
ID: 4663880, Distance: 0.9546623229980469, Doc: {'title': 'A 35-year old woman with productive cough and breathlessness', 'abstract': 'A 35-year-old lady was seen in the outpatient clinic owing to fever, cough with mucopurulent expectoration, and breathlessness for the duration of 1 month. She had history of similar episodes treated with antibiotics four times during last 2

### Perform Evaluation Search On Trec 2016 Queries

load queries from Trec 2014

In [10]:
import xml.etree.ElementTree as ET

topics_path = "topics2014.xml"
tree = ET.parse(topics_path)
root = tree.getroot()

queries = []
for topic in root.findall("topic"):
    topic_id = topic.get("number") or topic.findtext("number")
    query_text = topic.findtext("description")
    queries.append({"id": topic_id.strip(), "query": query_text.strip()})

print(f"Loaded {len(queries)} queries")

Loaded 30 queries


Load Qrels

In [11]:
import pandas as pd

qrels_path = "qrels-treceval-2014.txt"
qrels_df = pd.read_csv(
    qrels_path, 
    sep=r"\s+", 
    names=["topic", "iter", "docid", "relevance"]
)

print(qrels_df.head())

   topic  iter    docid  relevance
0      1     0  1033658          0
1      1     0  1033958          0
2      1     0  1034932          0
3      1     0  1034982          0
4      1     0  1035890          0


Perform search

In [None]:
results = []

for query in queries:
    search_res = milvus_searcher.search(query["query"])

    for hit in search_res[0]:
        results.append({
            "topic": query["id"],
            "docid": hit["id"],
            "score": hit["distance"]
        })

results_df = pd.DataFrame(results)
print(results_df.head())

  topic    docid     score
0     1  4168866  0.953263
1     1  4326113  0.952289
2     1  4163392  0.951331
3     1  3807683  0.949312
4     1  2917397  0.948998


Perform Evaluation

In [18]:
import numpy as np
import pandas as pd

# Assuming:
# results_df = DataFrame with columns ["topic", "docid", "score"]
# qrels_df = DataFrame with columns ["topic", "iter", "docid", "relevance"]

def dcg_at_k(rels, k):
    rels = rels[:k]
    return np.sum((2**np.array(rels) - 1) / np.log2(np.arange(2, len(rels) + 2)))

def ndcg_at_k(rels, k):
    if len(rels) == 0:
        return 0.0
    dcg = dcg_at_k(rels, k)
    idcg = dcg_at_k(sorted(rels, reverse=True), k)
    return dcg / idcg if idcg > 0 else 0.0

def average_precision(rels):
    """rels is a binary or graded relevance list sorted by retrieved order"""
    hits = np.array(rels) > 0
    if hits.sum() == 0:
        return 0.0
    precisions = [hits[:i+1].sum() / (i+1) for i in range(len(hits)) if hits[i]]
    return np.mean(precisions)

metrics = {
    "MAP": [],
    "nDCG@5": [],
    "nDCG@10": [],
    "Acc@5": [],
    "Acc@10": [],
    "Recall@5": [],
    "Recall@10": []
}


# Ensure topic IDs are both strings
results_df["topic"] = results_df["topic"].astype(str)
qrels_df["topic"] = qrels_df["topic"].astype(str)

results_df["docid"] = results_df["docid"].astype(str)
qrels_df["docid"] = qrels_df["docid"].astype(str)


for topic_id, group in results_df.groupby("topic"):
    # Merge search results with qrels to get relevance
    merged = pd.merge(group, qrels_df, on=["topic", "docid"], how="left").fillna(0)
    rels = merged["relevance"].tolist()

    # MAP
    metrics["MAP"].append(average_precision(rels))

    # nDCG
    metrics["nDCG@5"].append(ndcg_at_k(rels, 5))
    metrics["nDCG@10"].append(ndcg_at_k(rels, 10))

    # Accuracy@k (1 if at least one relevant doc is in top-k)
    metrics["Acc@5"].append(1.0 if np.any(np.array(rels[:5]) > 0) else 0.0)
    metrics["Acc@10"].append(1.0 if np.any(np.array(rels[:10]) > 0) else 0.0)

    # Recall@k (relevant docs retrieved / total relevant docs)
    total_rel = (qrels_df.query("topic == @topic_id")["relevance"] > 0).sum()
    metrics["Recall@5"].append(min(np.sum(np.array(rels[:5]) > 0) / total_rel, 1.0) if total_rel > 0 else 0.0)
    metrics["Recall@10"].append(min(np.sum(np.array(rels[:10]) > 0) / total_rel, 1.0) if total_rel > 0 else 0.0)

# Compute mean scores
final_scores = {k: np.mean(v) for k, v in metrics.items()}

print("Retrieval Results:")
for k, v in final_scores.items():
    print(f"{k}: {v:.4f}")


Retrieval Results:
MAP: 0.2154
nDCG@5: 0.2090
nDCG@10: 0.2842
Acc@5: 0.3000
Acc@10: 0.4667
Recall@5: 0.0029
Recall@10: 0.0065
