In [None]:

!pip install -q rank_bm25 sentence-transformers datasets faiss-cpu scikit-learn

import pandas as pd
import numpy as np
import faiss
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
import random


from datasets import load_dataset
# Load the English configuration of the disease database and expose its
# ``common_symptom`` column a second time under the name the rest of this
# notebook reads (``symptom_text``); the original column is kept.
ds = load_dataset("FreedomIntelligence/Disease_Database", "en")


def _with_symptom_text(example):
    """Return the extra column that ``Dataset.map`` merges into *example*."""
    return {'symptom_text': example['common_symptom']}


ds['train'] = ds['train'].map(_with_symptom_text)
df = pd.DataFrame(ds['train'])


def generate_synthetic_symptoms(symptom_text, n_variations=5, rng=None):
    """Create shuffled first-person paraphrases of a comma-separated symptom list.

    Every variation contains exactly the same (stripped, non-blank) symptoms as
    *symptom_text*, re-ordered at random, joined with ", " and prefixed with
    "I am experiencing ".

    Args:
        symptom_text: Comma-separated symptoms. Non-string values (e.g. None or
            a pandas NaN) and strings with no non-blank items produce no
            variations, rather than crashing or emitting a content-free
            "I am experiencing " query labelled with a disease.
        n_variations: Number of variations to generate.
        rng: Optional ``random.Random`` instance used for shuffling, so the
            augmentation can be made reproducible. Defaults to the global
            ``random`` module (the previous behaviour).

    Returns:
        A list of ``n_variations`` strings, or an empty list when there are no
        symptoms to shuffle.
    """
    if not isinstance(symptom_text, str):
        return []
    words = [w.strip() for w in symptom_text.split(",") if w.strip()]
    if not words:
        # Nothing to paraphrase: avoid emitting identical empty queries.
        return []
    sampler = random if rng is None else rng
    # sample(words, len(words)) returns a shuffled copy without mutating words.
    return [
        "I am experiencing " + ", ".join(sampler.sample(words, len(words)))
        for _ in range(n_variations)
    ]

# Augment every disease with five shuffled, first-person paraphrases of its
# symptom list; each synthetic row keeps the disease label of its source row.
synthetic_rows = [
    {'symptom_text': paraphrase, 'disease': label}
    for label, symptom_list in zip(df['disease'], df['symptom_text'])
    for paraphrase in generate_synthetic_symptoms(symptom_list, n_variations=5)
]
df_synthetic = pd.DataFrame(synthetic_rows)

# Original rows first, synthetic rows after, renumbered 0..n-1.
df_expanded = pd.concat([df, df_synthetic], ignore_index=True)
print(f"Expanded retriever dataset size: {len(df_expanded)}")


# Hold out 10% of the expanded rows for evaluation (fixed random_state for a
# repeatable split) and renumber each split's index as 0..n-1.
#
# NOTE(review): the split is drawn over *expanded* rows, so shuffled variants
# of the same symptom list (the same bag of words) end up in both train and
# test. A bag-of-words scorer like BM25 cannot tell such a test query from its
# training twins, so the Recall@5 figures printed below are likely inflated.
# A leakage-free evaluation needs queries that are not permutations of
# indexed documents.
train_df, test_df = (
    split.reset_index(drop=True)
    for split in train_test_split(df_expanded, test_size=0.1, random_state=42)
)

# --------------------------
# 4️⃣ Prepare BM25 retriever
# --------------------------
def _bm25_tokenize(text):
    """Lower-case *text* and split it on runs of whitespace."""
    return text.lower().split()


# One token list per training row, in train_df order, so BM25 document
# position i corresponds to train_df row i.
tokenized_corpus = [_bm25_tokenize(doc) for doc in train_df['symptom_text']]
bm25 = BM25Okapi(tokenized_corpus)

# --------------------------
# 5️⃣ Prepare FAISS retriever
# --------------------------
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Embed every training row (in train_df order), then L2-normalise the vectors
# in place so the inner-product index below ranks by cosine similarity.
train_embeddings = model.encode(
    train_df['symptom_text'],
    convert_to_numpy=True,
    show_progress_bar=True,
)
faiss.normalize_L2(train_embeddings)

dim = train_embeddings.shape[1]  # embedding width produced by the model
index = faiss.IndexFlatIP(dim)   # exact (brute-force) inner-product search
index.add(train_embeddings)

# --------------------------
# 6️⃣ Helper: Deduplicate results
# --------------------------
def dedup(results):
    """Drop repeated items from *results*, keeping each item's first occurrence.

    Relative order is preserved. Items must be hashable.

    Returns:
        A new list with the unique items of *results*.
    """
    # dicts keep insertion order and retain the first key object on collision.
    return list(dict.fromkeys(results))

# --------------------------
# 7️⃣ Retrieval functions
# --------------------------
def bm25_retrieve(user_symptoms, k=5):
    """Return up to *k* distinct diseases ranked by BM25 score for a query.

    The training corpus holds several near-identical rows per disease (the
    original description plus its shuffled synthetic variants). Truncating the
    ranking to ``2*k`` rows before de-duplicating can therefore yield fewer
    than *k* diseases. Instead, the full ranking is walked and collection
    stops as soon as *k* distinct diseases have been seen.

    Args:
        user_symptoms: Free-text symptom description.
        k: Number of distinct diseases to return; ``k <= 0`` returns ``[]``.

    Returns:
        List of at most *k* disease names, best match first (fewer only if
        the corpus has fewer distinct diseases).
    """
    if k <= 0:
        return []
    tokenized_query = user_symptoms.lower().split()
    scores = np.asarray(bm25.get_scores(tokenized_query))
    diseases = train_df['disease'].to_numpy()

    results = []
    seen = set()
    for i in np.argsort(scores)[::-1]:  # highest score first
        disease = diseases[i]
        if disease in seen:
            continue
        seen.add(disease)
        results.append(disease)
        if len(results) == k:
            break
    return results

def faiss_retrieve(user_symptoms, k=5):
    """Return up to *k* distinct diseases ranked by embedding cosine similarity.

    Each disease has several near-duplicate rows in the index, so the top
    ``2*k`` hits may cover fewer than *k* diseases. The search is widened
    (doubling, capped at the index size) until *k* distinct diseases are
    found or the whole index has been ranked.

    The search size never exceeds ``index.ntotal``, and ``-1`` labels are
    skipped. FAISS pads results with ``-1`` when asked for more hits than
    exist, and ``train_df.iloc[-1]`` would otherwise silently map such a
    label to the last training row.

    Args:
        user_symptoms: Free-text symptom description.
        k: Number of distinct diseases to return; ``k <= 0`` returns ``[]``.

    Returns:
        List of at most *k* disease names, most similar first.
    """
    n_total = index.ntotal
    if k <= 0 or n_total == 0:
        return []

    user_emb = model.encode([user_symptoms], convert_to_numpy=True)
    faiss.normalize_L2(user_emb)  # unit vectors -> inner product == cosine
    diseases = train_df['disease'].to_numpy()

    fetch = min(k * 2, n_total)
    while True:
        _, indices = index.search(user_emb, fetch)
        results = []
        seen = set()
        for i in indices[0]:
            if i < 0:  # padding label: no more hits
                break
            disease = diseases[i]
            if disease in seen:
                continue
            seen.add(disease)
            results.append(disease)
            if len(results) == k:
                return results
        if fetch >= n_total:
            return results
        fetch = min(fetch * 2, n_total)

def hybrid_retrieve(user_symptoms, k=5, alpha=0.5):
    """Return up to *k* distinct diseases ranked by a BM25 / FAISS score blend.

    Both per-document score vectors are min-max normalised to [0, 1] before
    mixing. The final score is ``alpha * dense + (1 - alpha) * lexical``, so
    ``alpha`` is a real weight between the two signals. Dividing by the
    maximum alone (the previous approach) can leave the two vectors on
    different spreads, and it misbehaves when cosine scores are negative.

    The training corpus has several near-duplicate rows per disease. Rather
    than truncating at ``2*k`` rows (which could return fewer than *k*
    diseases), the blended ranking is walked until *k* distinct diseases have
    been collected.

    Args:
        user_symptoms: Free-text symptom description.
        k: Number of distinct diseases to return; ``k <= 0`` returns ``[]``.
        alpha: Weight of the FAISS (dense) score; ``1 - alpha`` goes to BM25.

    Returns:
        List of at most *k* disease names, best blended score first.
    """
    n_docs = len(train_df)
    if k <= 0 or n_docs == 0:
        return []

    def _min_max(scores):
        # Affine rescale to [0, 1]; a constant vector (e.g. no BM25 term
        # overlap at all) maps to all zeros thanks to the epsilon.
        lo, hi = scores.min(), scores.max()
        return (scores - lo) / (hi - lo + 1e-8)

    # Lexical signal: BM25 score for every training row.
    tokenized_query = user_symptoms.lower().split()
    bm25_scores = _min_max(np.asarray(bm25.get_scores(tokenized_query), dtype=float))

    # Dense signal: cosine similarity for every training row. Searching for
    # n_docs hits returns the whole index; FAISS results come sorted by score,
    # so scatter them back into train_df row order.
    user_emb = model.encode([user_symptoms], convert_to_numpy=True)
    faiss.normalize_L2(user_emb)
    sims, indices = index.search(user_emb, n_docs)
    valid = indices[0] >= 0  # skip FAISS padding labels, if any
    faiss_scores = np.zeros(n_docs)
    faiss_scores[indices[0][valid]] = sims[0][valid]
    faiss_scores = _min_max(faiss_scores)

    hybrid_scores = alpha * faiss_scores + (1 - alpha) * bm25_scores

    diseases = train_df['disease'].to_numpy()
    results = []
    seen = set()
    for i in np.argsort(hybrid_scores)[::-1]:  # highest blended score first
        disease = diseases[i]
        if disease in seen:
            continue
        seen.add(disease)
        results.append(disease)
        if len(results) == k:
            break
    return results

# --------------------------
# 8️⃣ Evaluation function
# --------------------------
def compute_recall_at_k(method_fn, test_pairs, k=5):
    """Fraction of queries whose gold disease appears in the top-*k* retrieval.

    Args:
        method_fn: Retriever called as ``method_fn(query, k=k)``; must return
            a container of disease names.
        test_pairs: Iterable of ``(query, true_disease)`` pairs. Any iterable
            is accepted (lists, zips, generators); it is consumed once.
        k: Cut-off passed to *method_fn*.

    Returns:
        Recall@k as a float in [0, 1].

    Raises:
        ValueError: if *test_pairs* is empty. Recall is undefined in that
            case, and the old code failed with a bare ZeroDivisionError.
    """
    pairs = list(test_pairs)  # allow one-shot iterators; len() needed below
    if not pairs:
        raise ValueError("test_pairs must contain at least one (query, disease) pair")
    correct = sum(
        1 for user_symptoms, true_disease in pairs
        if true_disease in method_fn(user_symptoms, k=k)
    )
    return correct / len(pairs)

# Evaluation set: (query text, gold disease) pairs from the held-out split.
test_pairs = list(zip(test_df['symptom_text'], test_df['disease']))


def _hybrid_alpha_0_6(query, k):
    """Hybrid retriever with the FAISS weight fixed at 0.6."""
    return hybrid_retrieve(query, k, alpha=0.6)


print("\n🔹 Evaluating retrievers (Recall@5)...")
# Evaluated in order BM25 -> FAISS -> hybrid, then reported together.
bm25_recall, faiss_recall, hybrid_recall = (
    compute_recall_at_k(retriever, test_pairs, k=5)
    for retriever in (bm25_retrieve, faiss_retrieve, _hybrid_alpha_0_6)
)

print(f"BM25 Recall@5:  {bm25_recall:.4f}")
print(f"FAISS Recall@5: {faiss_recall:.4f}")
print(f"Hybrid Recall@5 (α=0.6): {hybrid_recall:.4f}")

# --------------------------
# 9️⃣ Example queries
# --------------------------
# Hand-written queries for eyeballing what each retriever returns.
queries = [
    "I have nausea and acid reflux",
    "My chest feels tight and I am short of breath",
    "I am coughing with fever and chills",
    "I feel dizzy and have blurred vision",
    "There is pain in my lower back and frequent urination"
]

# (printed label, retriever) pairs, queried in this order for each query.
retrievers = (
    ("BM25 →", lambda text: bm25_retrieve(text, k=5)),
    ("FAISS →", lambda text: faiss_retrieve(text, k=5)),
    ("Hybrid →", lambda text: hybrid_retrieve(text, k=5, alpha=0.6)),
)

for q in queries:
    print(f"\nQuery: {q}")
    for label, retrieve in retrievers:
        print(label, retrieve(q))


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m51.8 MB/s[0m eta [36m0:00:00[0m
[?25h

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/719 [00:00<?, ?B/s]

disease_database_en.json:   0%|          | 0.00/23.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9620 [00:00<?, ? examples/s]

Map:   0%|          | 0/9620 [00:00<?, ? examples/s]

Expanded retriever dataset size: 57720


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/1624 [00:00<?, ?it/s]


🔹 Evaluating retrievers (Recall@5)...
BM25 Recall@5:  0.9823
FAISS Recall@5: 0.9870
Hybrid Recall@5 (α=0.6): 0.9936

Query: I have nausea and acid reflux
BM25 → ['Postpartum Dyspepsia', 'Gastrointestinal Hypoperfusion', 'Nausea and Vomiting', 'Posterior Wall Perforating Ulcer']
FAISS → ['Gastrointestinal Hypoperfusion', 'Non-ulcer Dyspepsia', 'Postpartum Dyspepsia', 'Nausea and Vomiting']
Hybrid → ['Gastrointestinal Hypoperfusion', 'Postpartum Dyspepsia', 'Remnant Stomach Syndrome', 'Nausea and Vomiting']

Query: My chest feels tight and I am short of breath
BM25 → ['Aortic Valve Calcification', 'Pulmonary Valve Stenosis', 'Variant Angina in the Elderly', 'Pregnancy Complicated with Heart Disease', 'Splenic Tumor']
FAISS → ['Pregnancy Complicated with Heart Disease', 'Aortic Valve Calcification', 'Tension Pneumothorax', 'Asymptomatic Myocardial Ischemia', 'Old Myocardial Infarction']
Hybrid → ['Aortic Valve Calcification', 'Pregnancy Complicated with Heart Disease', 'Old Myocardial In