In [1]:
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from langchain_chroma.vectorstores import Chroma
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.manifold import TSNE
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random, json, os
from tqdm import tqdm

In [2]:
# Bi-encoder used for first-stage dense retrieval: query and candidate texts
# are embedded independently and compared afterwards.
biobert_model_name = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
# Cross-encoder used to re-score (query, document) pairs jointly.
# NOTE(review): the checkpoint name suggests an MS MARCO (general web search)
# model rather than a biomedical one — confirm it suits medical queries.
reranker_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# Instantiate both models (presumably fetching weights on first run).
biobert_embedder = SentenceTransformer(biobert_model_name)
reranker = CrossEncoder(reranker_model_name)

In [3]:
# Load the two raw corpora whose question columns serve as retrieval
# candidates in the cells below.
medquad = pd.read_csv("../data/raw/med_quad.csv")
healthcare = pd.read_json("../data/raw/HealthCareMagic-100k.json")

# Report the size of each corpus, one per line.
print(
    f"MedQuAD: {len(medquad)} items",
    f"HealthCareMagic: {len(healthcare)} items",
    sep="\n",
)

MedQuAD: 16407 items
HealthCareMagic: 112165 items


In [4]:
def semantic_search_biobert(query, texts, top_k=10, embedder=None):
    """Rank candidate texts by cosine similarity to a query using a bi-encoder.

    Parameters
    ----------
    query : str
        The search query.
    texts : iterable of str
        Candidate documents (list, tuple, pandas Series, ...). The input is
        materialised as a list, so results are always looked up by position,
        never by index label.
    top_k : int or None, default 10
        Maximum number of results to return. ``None`` returns every
        candidate; values <= 0 return an empty list.
    embedder : object with an ``encode`` method, optional
        Bi-encoder used to embed the query and the texts. Defaults to the
        notebook-level ``biobert_embedder``.

    Returns
    -------
    list of (str, float)
        ``(text, cosine_similarity)`` pairs sorted by descending score.
        Equal scores keep the candidates' input order.
    """
    model = biobert_embedder if embedder is None else embedder
    docs = list(texts)
    # Nothing to rank (or nothing requested): skip the expensive encode calls.
    if not docs or (top_k is not None and top_k <= 0):
        return []

    # Ask for NumPy output directly instead of tensor -> CPU -> NumPy.
    q_emb = np.asarray(model.encode([query], convert_to_numpy=True), dtype=np.float64)[0]
    doc_embs = np.asarray(model.encode(docs, convert_to_numpy=True), dtype=np.float64)

    # Cosine similarity = dot product divided by the product of L2 norms.
    # A zero-norm vector has a zero dot product, so dividing by 1 instead of 0
    # gives it a score of 0 rather than NaN.
    denom = np.linalg.norm(doc_embs, axis=1) * np.linalg.norm(q_emb)
    sims = (doc_embs @ q_emb) / np.where(denom == 0, 1.0, denom)

    # Stable sort on negated scores: highest first, ties in input order.
    order = np.argsort(-sims, kind="stable")[:top_k]
    return [(docs[i], float(sims[i])) for i in order]

# Smoke-test the bi-encoder retriever on 40 random candidates
# (20 MedQuAD questions + 20 HealthCareMagic inputs, fixed seed).
test_query = "What are the symptoms of diabetes?"
candidates = (
    medquad["Question"].sample(20, random_state=0).tolist()
    + healthcare["input"].sample(20, random_state=0).tolist()
)

results = semantic_search_biobert(test_query, candidates)

print(f"Query: {test_query}\n")
for i, (txt, score) in enumerate(results):
    # Display ranks as 1-based.
    print("[{}] ({:.3f}) {}".format(i + 1, score, txt))

Query: What are the symptoms of diabetes?

[1] (0.572) What are the symptoms of Danon disease ?
[2] (0.525) What are the symptoms of Renal dysplasia-limb defects syndrome ?
[3] (0.451) What is (are) Limb-girdle muscular dystrophy ?
[4] (0.437) What are the treatments for Orthostatic Hypotension ?
[5] (0.417) What are the symptoms of Dystonia 7, torsion ?
[6] (0.399) What are the symptoms of Erdheim-Chester disease ?
[7] (0.391) What is (are) Tay-Sachs disease ?
[8] (0.372) What are the symptoms of Cataract, autosomal recessive congenital 2 ?
[9] (0.354) What are the treatments for Acquired Cystic Kidney Disease ?
[10] (0.347) What is (are) Gorham's disease ?


In [5]:
def rerank_with_cross_encoder(query, retrieved_results, top_n=5, cross_encoder=None):
    """Re-score first-stage hits with a cross-encoder and keep the best ones.

    Parameters
    ----------
    query : str
        The search query.
    retrieved_results : iterable of (str, float)
        ``(document, first_stage_score)`` pairs, e.g. the output of
        ``semantic_search_biobert``. The first-stage scores are ignored.
        Any iterable (including a generator) is accepted; it is consumed once.
    top_n : int or None, default 5
        Number of re-ranked results to keep. ``None`` keeps all; values <= 0
        return an empty list.
    cross_encoder : object with a ``predict`` method, optional
        Model that scores a list of ``(query, document)`` pairs. Defaults to
        the notebook-level ``reranker``.

    Returns
    -------
    list of (str, float)
        ``(document, cross_encoder_score)`` sorted by descending score.
        The scores are the model's raw outputs and are not bounded to
        [0, 1] (negative values are possible).
    """
    model = reranker if cross_encoder is None else cross_encoder
    # Materialise once: the original iterated the input twice, which silently
    # produced an empty result for generators.
    candidates = [doc for doc, _ in retrieved_results]
    # Avoid calling predict() on an empty batch.
    if not candidates or (top_n is not None and top_n <= 0):
        return []

    scores = model.predict([(query, doc) for doc in candidates])

    # sorted() is stable (also with reverse=True), so equal scores keep
    # their first-stage order.
    reranked = sorted(
        ((doc, float(score)) for doc, score in zip(candidates, scores)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return reranked[:top_n]

In [6]:
# Two-stage retrieval demo on 100 sampled candidates (50 per corpus, fixed seed).
query = "What are the symptoms of diabetes?"
texts = (
    medquad["Question"].sample(50, random_state=42).tolist()
    + healthcare["input"].sample(50, random_state=42).tolist()
)

# Stage 1: dense retrieval of the 10 closest candidates with the BioBERT bi-encoder.
retrieved = semantic_search_biobert(query, texts, top_k=10)
print("\n [BioBERT Embedding Results]")
for i, (t, s) in enumerate(retrieved, 1):
    print("[{}] ({:.3f}) {}".format(i, s, t))

# Stage 2: re-score those hits jointly with the cross-encoder and keep the top 5.
reranked = rerank_with_cross_encoder(query, retrieved, top_n=5)
print("\n [Cross-Encoder Re-ranked Results]")
for i, (t, s) in enumerate(reranked, 1):
    print("[{}] ({:.3f}) {}".format(i, s, t))


 [BioBERT Embedding Results]
[1] (0.746) What causes Causes of Diabetes ?
[2] (0.570) What are the treatments for Prevent diabetes problems: Keep your nervous system healthy ?
[3] (0.539) What are the treatments for glucose phosphate isomerase deficiency ?
[4] (0.486) What are the symptoms of Retinitis pigmentosa, deafness, mental retardation, and hypogonadism ?
[5] (0.485) What is (are) Parathyroid Disorders ?
[6] (0.452) What is the outlook for Lesch-Nyhan Syndrome ?
[7] (0.435) What are the symptoms of Charcot-Marie-Tooth disease type 2O ?
[8] (0.413) Is diastrophic dysplasia inherited ?
[9] (0.413) What are the symptoms of Krabbe disease atypical due to Saposin A deficiency ?
[10] (0.384) What is the outlook for Gaucher Disease ?

 [Cross-Encoder Re-ranked Results]
[1] (-3.980) What causes Causes of Diabetes ?
[2] (-5.343) What are the symptoms of Retinitis pigmentosa, deafness, mental retardation, and hypogonadism ?
[3] (-5.815) What are the treatments for Prevent diabetes proble