In [1]:
!pip install numpy sentence-transformers bertopic hdbscan nltk scann
import nltk
nltk.download('punkt')
import nltk
nltk.download('punkt_tab')
!pip install sentence-transformers bertopic hdbscan umap-learn scann nltk datasets
!pip install gensim

Collecting bertopic
  Downloading bertopic-0.17.3-py3-none-any.whl.metadata (24 kB)
Collecting scann
  Downloading scann-1.4.0-cp311-cp311-manylinux_2_27_x86_64.whl.metadata (5.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence-transform

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


Collecting gensim
  Downloading gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Downloading gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.7/26.7 MB[0m [31m74.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━

In [1]:

# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support

# Seed every random number generator the pipeline touches so repeated runs
# are comparable.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
# NOTE: hash randomization of the *current* interpreter is fixed at start-up;
# this assignment only influences processes spawned from here on.
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # The all-device seeding helper lives in torch.cuda; `torch.manual_seed_all`
    # does not exist and raised AttributeError on CUDA machines.
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuBLAS kernels require a fixed workspace configuration;
    # keep any value the user already exported.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.use_deterministic_algorithms(True)

# Tokenizer data for sent_tokenize. Newer NLTK releases look up "punkt_tab",
# so fetch both to keep this cell runnable on a fresh kernel.
nltk.download("punkt")
nltk.download("punkt_tab")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    """Normalize a string for matching and tokenization.

    The text is lower-cased, every character that is neither a word
    character nor whitespace is removed, runs of whitespace are collapsed to
    a single space, and leading/trailing whitespace is trimmed.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string.
    """
    lowered = text.lower()
    # Punctuation is deleted outright (not replaced by a space), so
    # "low-sodium" becomes "lowsodium".
    without_punctuation = re.sub(r'[^\w\s]', '', lowered)
    single_spaced = re.sub(r'\s+', ' ', without_punctuation)
    return single_spaced.strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    """Pair every entity with the cleaned text surrounding its mention.

    Each chunk is split into sentences on the *raw* text and every sentence is
    then normalized with ``clean_text``. (Cleaning first would strip the
    punctuation the sentence tokenizer relies on, so a chunk would always come
    back as a single sentence.) Entities are matched in their cleaned form so
    that entities containing punctuation, e.g. "low-sodium diet", can be found
    in the cleaned sentences.

    Args:
        chunks: Sequence of raw text chunks.
        entities_per_chunk: Sequence parallel to ``chunks``; item ``i`` is the
            iterable of entity strings belonging to ``chunks[i]``.
        use_multi_sentence: When True the context is the matching sentence
            plus its immediate neighbours (previous and next); otherwise only
            the matching sentence.

    Returns:
        A list of ``(entity_lowercased, context)`` tuples, one per entity, in
        input order. When an entity is not found in any sentence, the whole
        cleaned chunk is used as its context.

    Raises:
        IndexError: If ``entities_per_chunk`` is longer than ``chunks``.
    """
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        raw_chunk = chunks[idx]
        chunk = clean_text(raw_chunk)
        # Tokenize before cleaning, then drop sentences that clean to nothing.
        sentences = [s for s in map(clean_text, sent_tokenize(raw_chunk)) if s]

        for ent in ents:
            ent_lower = ent.lower()
            # Match on the cleaned entity so it is comparable to the cleaned
            # sentences; the returned entity label stays merely lower-cased.
            ent_key = clean_text(ent)

            context = chunk  # fallback when the entity is not found
            if ent_key:
                for i, sent in enumerate(sentences):
                    if ent_key in sent:
                        if use_multi_sentence:
                            context = " ".join(sentences[max(0, i - 1): i + 2])
                        else:
                            context = sent
                        break  # only the first mention is used

            entity_context_pairs.append((ent_lower, context.strip()))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    """Topic-level semantic search over entity contexts taken from text chunks.

    On construction the instance:
      1. extracts one context per (chunk, entity) pair,
      2. embeds the contexts with a SentenceTransformer model,
      3. clusters them into topics with BERTopic (UMAP + HDBSCAN),
      4. averages the context embeddings of each topic into one unit-length
         topic vector, and
      5. indexes the topic vectors with ScaNN (dot-product similarity).

    ``search`` maps a query to its closest topics and, inside each topic, to
    the sentences most similar to the query.

    Attributes:
        chunks: Raw text chunks passed to the constructor.
        entities_per_chunk: Entities per chunk, parallel to ``chunks``.
        embedding_model: The loaded SentenceTransformer.
        umap_params / hdbscan_params: Keyword arguments for UMAP / HDBSCAN.
        topic_model: The fitted BERTopic instance.
        topic_metadata: One dict per topic with keys ``topic_id``,
            ``entities``, ``sentences`` and ``sentence_embeddings``; the list
            position equals the row in ``topic_embeddings``.
        topic_embeddings: Array with one normalized mean embedding per topic.
        searcher: The ScaNN index built over ``topic_embeddings``.
    """

    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        """Load the embedding model and build the topic index immediately.

        Args:
            chunks: Sequence of raw text chunks.
            entities_per_chunk: Sequence parallel to ``chunks`` holding the
                entity strings of each chunk.
            umap_params: Keyword arguments forwarded to ``UMAP``.
            hdbscan_params: Keyword arguments forwarded to ``HDBSCAN``
                (``prediction_data=True`` is added automatically).
            model_name: SentenceTransformer model identifier.

        Raises:
            ValueError: If no entity-context pair could be extracted.
            RuntimeError: If clustering produced no topic to index.
        """
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        """Extract contexts, fit the topic model and build the ScaNN index."""
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        contextual_texts = [context for _, context in entity_context_pairs]
        # The un-reduced sentence embeddings are kept for search; UMAP only
        # reduces them inside BERTopic for clustering.
        contextual_embeddings = self.embedding_model.encode(
            contextual_texts, normalize_embeddings=True
        )

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Group contexts, entities and embeddings by assigned topic. Dict
        # insertion order (first occurrence of each topic) defines the row
        # order of the index built below.
        # NOTE(review): topic -1 (BERTopic's outlier bucket) is grouped and
        # indexed like any other topic - confirm that this is intended.
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            # Topic vector = re-normalized mean of its member embeddings; the
            # epsilon avoids a division by zero for a degenerate mean.
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                # Sorted so the entity order does not depend on set iteration
                # order (which varies between interpreter runs).
                "entities": sorted(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        # Never ask ScaNN to probe more leaves than the tree has (matters when
        # only a single topic was found).
        leaves_to_search = min(2, num_clusters)

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(
                num_leaves=num_clusters,
                num_leaves_to_search=leaves_to_search,
                training_sample_size=len(self.topic_embeddings),
            )
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        """Return the topics closest to ``query`` with their best sentences.

        Args:
            query: Free-text query string.
            top_k_topics: Maximum number of topics to return; values larger
                than the number of indexed topics are clamped.
            top_k_sents: Maximum number of sentences to return per topic.

        Returns:
            A list of dicts, one per retrieved topic, with keys:
              * ``topic_id``: the topic identifier,
              * ``topic_score``: similarity between the query and the topic
                vector as reported by the index,
              * ``entities``: entities assigned to the topic,
              * ``sentences``: ``(sentence, similarity)`` tuples, most similar
                first, de-duplicated by sentence text.
        """
        # Clamp the request to what the index can actually return.
        n_topics = min(top_k_topics, len(self.topic_metadata))
        if n_topics < 1:
            return []

        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=n_topics)

        results = []
        for rank, idx in enumerate(neighbors):
            meta = self.topic_metadata[idx]
            topic_score = float(scores[rank])  # similarity score with topic embedding

            # Several entities can share one context; keep each sentence once
            # (first occurrence wins) together with its embedding.
            seen = set()
            unique_sentences = []
            unique_embeddings = []
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    unique_sentences.append(sent)
                    unique_embeddings.append(emb)

            sent_embs = np.array(unique_embeddings)
            if len(sent_embs) == 0:
                continue

            # Re-normalize defensively; the floor avoids NaNs for a zero vector.
            norms = np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sent_embs_norm = sent_embs / np.maximum(norms, 1e-10)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[j], float(sims[j])) for j in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "topic_score": topic_score,
                "entities": list(meta["entities"]),
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    """Compute the c_v coherence of the model's topics over their own sentences.

    The reference corpus is built from every sentence stored in
    ``topic_metadata`` (cleaned with ``clean_text`` and split on whitespace).
    For each topic the first ``topk`` representation words that occur in that
    corpus vocabulary are scored; topics with no such word are skipped, since
    a word list with no known token cannot be scored.

    Args:
        topic_model: Fitted model exposing ``get_topic(topic_id)``, which
            returns a sequence of ``(word, weight)`` pairs (or a falsy value
            for an unknown topic).
        topic_metadata: Iterable of dicts with keys ``topic_id`` and
            ``sentences``.
        topk: Maximum number of words taken per topic.

    Returns:
        The c_v coherence as a float, or ``nan`` when no topic has any word
        in the corpus vocabulary.
    """
    texts = [
        clean_text(sent).split()
        for meta in topic_metadata
        for sent in meta["sentences"]
    ]
    dictionary = Dictionary(texts)
    vocabulary = dictionary.token2id

    topic_word_lists = []
    for meta in topic_metadata:
        # get_topic may return a falsy value for a topic it does not know.
        topic = topic_model.get_topic(meta["topic_id"]) or []
        # Keep only words the dictionary knows; this also drops empty padding
        # entries and multi-word phrases that never occur as single tokens.
        words = [word for word, _ in topic[:topk] if word in vocabulary]
        if words:
            topic_word_lists.append(words)

    if not topic_word_lists:
        return float("nan")

    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    """Return the share of unique words among the top words of all topics.

    For every topic in ``topic_metadata`` the first ``topk`` representation
    words are collected; the result is ``unique words / collected words``.
    When every topic supplies ``topk`` non-empty words this equals
    ``unique / (number_of_topics * topk)``.

    Empty-string entries are ignored (presumably padding added by the
    representation model for short topics - confirm for the model in use), and
    topics for which ``get_topic`` returns a falsy value are skipped.

    Args:
        topic_model: Fitted model exposing ``get_topic(topic_id)``, which
            returns a sequence of ``(word, weight)`` pairs.
        topic_metadata: Iterable of dicts with a ``topic_id`` key.
        topk: Maximum number of words taken per topic.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when no words are available (instead of
        raising ZeroDivisionError).
    """
    all_words = []
    for meta in topic_metadata:
        topic = topic_model.get_topic(meta["topic_id"]) or []
        all_words.extend(word for word, _ in topic[:topk] if word)

    if not all_words:
        return 0.0
    return len(set(all_words)) / len(all_words)

def compute_silhouette_score_custom(topic_metadata):
    """Cosine silhouette score of the sentence embeddings grouped by topic.

    Every row of a topic's ``sentence_embeddings`` is labelled with that
    topic's ``topic_id``; all rows are stacked and scored together.

    Args:
        topic_metadata: Iterable of dicts with keys ``topic_id`` and
            ``sentence_embeddings`` (a 2-D array, one row per sentence).

    Returns:
        The silhouette score, or ``None`` when it is not defined: no topics at
        all, fewer than two distinct labels, or at least as many labels as
        samples.
    """
    blocks = []
    labels = []
    for meta in topic_metadata:
        block = meta["sentence_embeddings"]
        blocks.append(block)
        labels += [meta["topic_id"]] * len(block)

    if not blocks:
        return None

    points = np.vstack(blocks)
    distinct_labels = len(set(labels))

    # The score needs 2 <= n_labels <= n_samples - 1.
    if not (2 <= distinct_labels <= points.shape[0] - 1):
        return None

    return silhouette_score(points, labels, metric="cosine")

# === DATASET & INITIALIZATION ===
allergy_dataset = {
  "chunks": [
    "Mrs. M is a 68-year-old female with a complex medical history including hypertension, type 2 diabetes mellitus, and stage 3 chronic kidney disease.",
    "She presents to the internal medicine clinic with ongoing complaints of dull, persistent occipital headaches for the past six weeks.",
    "The headaches are associated with episodes of dizziness and intermittent blurred vision, particularly in the morning hours.",
    "She reports that these symptoms have gradually increased in frequency and intensity.",
    "At-home blood pressure monitoring consistently shows elevated readings above 160/95 mmHg.",
    "These elevated blood pressure values are most pronounced during early mornings.",
    "Her current antihypertensive regimen includes amlodipine 10 mg daily and hydrochlorothiazide 25 mg daily.",
    "Despite adherence, her blood pressure remains poorly controlled.",
    "She also has a history of hyperlipidemia managed with atorvastatin.",
    "Her most recent HbA1c is 7.8%, indicating suboptimal glycemic control.",
    "She suffers from osteoarthritis, predominantly affecting her knees and hips.",
    "Osteoarthritis significantly limits her mobility and daily activity.",
    "She averages fewer than 1,000 steps per day, as recorded by her wearable tracker.",
    "She has recently developed bilateral ankle edema, particularly in the evenings.",
    "Mild shortness of breath is present with moderate physical exertion.",
    "These symptoms raise concerns for evolving heart failure.",
    "Her clinic vitals reveal a blood pressure of 168/98 mmHg while seated.",
    "Heart rate is 88 bpm, respiratory rate 18, and oxygen saturation is 96% on room air.",
    "Temperature is within normal limits.",
    "Cardiovascular exam reveals a displaced apical impulse at the 6th intercostal space.",
    "A soft systolic murmur (grade 2/6) is heard over the cardiac apex.",
    "No jugular venous distension is appreciated.",
    "Pulmonary examination shows clear lung fields bilaterally.",
    "Mild 1+ pitting edema is noted at both ankles.",
    "Fundoscopic exam shows arteriolar narrowing and scattered cotton wool spots.",
    "There is no evidence of papilledema.",
    "Recent labs show stable serum creatinine at 1.4 mg/dL.",
    "Her eGFR is estimated at 48 mL/min/1.73 m².",
    "Electrolyte levels remain within normal limits.",
    "Her lipid panel reveals LDL cholesterol at 130 mg/dL.",
    "HDL is 38 mg/dL and triglycerides are measured at 160 mg/dL.",
    "A 12-lead ECG reveals left ventricular hypertrophy with strain pattern.",
    "There are no arrhythmias or conduction defects observed on ECG.",
    "Echocardiography demonstrates concentric LV hypertrophy with LVEF of 60%.",
    "Mild left atrial enlargement is present on echocardiogram.",
    "Trace mitral regurgitation is also identified.",
    "She expresses challenges adhering to a low-sodium diet.",
    "Meal preparation is difficult due to joint pain and fatigue.",
    "She frequently consumes processed or ready-made meals.",
    "Her husband assists with shopping but she manages meal planning.",
    "Despite previous dietary counseling, her weight has remained stable.",
    "She reports feeling frustrated with her lack of progress.",
    "Her hydrochlorothiazide dose was reduced due to borderline hypokalemia.",
    "Lisinopril 10 mg daily was added to improve BP control and renal protection.",
    "She was instructed to record BP twice daily and maintain a log.",
    "Additional labs were ordered to reassess renal function and electrolytes.",
    "A follow-up lipid panel was also requested.",
    "She was counseled on symptoms of hypertensive urgency.",
    "Educational materials were provided about recognizing chest pain, confusion, or sudden weakness.",
    "She was encouraged to gradually increase activity within tolerance.",
    "Physical therapy referral was made to support her with joint-friendly exercises.",
    "Dietitian consultation was arranged to address nutritional gaps.",
    "The diet plan focuses on kidney-friendly and low-sodium strategies.",
    "Home health services were initiated to aid with BP monitoring.",
    "Medication adherence support and lifestyle reinforcement are included in home visits.",
    "Cardiology follow-up is scheduled within 4 weeks.",
    "Nephrology will reassess renal function and medication tolerance.",
    "She lives with her husband in a single-story home.",
    "Her husband offers daily assistance and emotional support.",
    "She is the primary caregiver for her elderly mother-in-law with dementia.",
    "Caregiver duties contribute to emotional stress and fatigue.",
    "She reports occasional insomnia and feelings of anxiety.",
    "There is no history of tobacco or alcohol use.",
    "She denies symptoms of depression.",
    "Social work referral was made to explore caregiver resources.",
    "Multidisciplinary care was emphasized to support long-term BP management.",
    "She was advised to follow up within 6 weeks or sooner if symptoms worsen.",
    "She agreed to maintain a BP log and bring it to her next visit.",
    "Nutritional handouts and physical therapy instructions were given in print.",
    "She was educated on avoiding NSAIDs due to kidney function.",
    "She occasionally used over-the-counter pain relievers for joint discomfort.",
    "Acetaminophen was recommended instead of ibuprofen.",
    "Fluid status will be monitored to avoid volume overload.",
    "She was advised to elevate legs periodically throughout the day.",
    "Compression stockings were discussed for managing ankle edema.",
    "Psychosocial support and stress management were reviewed during the visit.",
    "Behavioral therapy referral was considered if insomnia persists.",
    "Cognitive function was grossly intact during the visit.",
    "There was no evidence of delirium or memory impairment.",
    "Speech, gait, and motor function appeared normal during physical exam.",
    "Her last eye exam was over a year ago; ophthalmology referral was placed.",
    "Diabetic foot screening was normal, with intact sensation and no ulceration.",
    "She is due for her pneumococcal and influenza vaccinations.",
    "Vaccination updates were administered during this visit.",
    "Her BMI is 31, placing her in the obese category.",
    "Weight reduction strategies were discussed, including dietary adjustments.",
    "She agreed to log meals and reduce sugary beverages.",
    "She was encouraged to attend group diabetes education sessions.",
    "A glucose meter was prescribed for home use.",
    "Target fasting blood glucose levels were reviewed with her.",
    "She demonstrated appropriate technique for fingerstick glucose testing.",
    "Pharmacist consultation was arranged to review medication regimen.",
    "Her medication list was reconciled and updated in the EHR.",
    "There are no current drug allergies reported.",
    "Her creatinine will be monitored every 3 months going forward.",
    "UACR (urine albumin-to-creatinine ratio) was ordered to assess proteinuria.",
    "Recent lab results were reviewed and explained in detail.",
    "Patient verbalized understanding of the treatment plan.",
    "Emergency contact information was verified and updated.",
    "Instructions were provided on when to seek urgent care.",
    "She was advised to avoid high-potassium foods temporarily.",
    "A potassium supplement was not needed at this time.",
    "She requested information on local support groups for caregivers.",
    "Social worker will follow up within 2 weeks.",
    "Progress notes and care plan summary were printed for her records.",
    "Next appointment was scheduled before she left the clinic.",
    "Clinic staff reviewed transportation resources if needed.",
    "She prefers morning appointments due to caregiver duties in the afternoon.",
    "Her husband accompanied her to the visit and asked questions about medications.",
    "Together, they expressed appreciation for coordinated care.",
    "She was reminded of the importance of medication timing consistency.",
    "BP monitor calibration was reviewed and verified as accurate.",
    "Telehealth follow-up was discussed as an option for future visits.",
    "Patient was comfortable with the use of video visits.",
    "Portal access instructions were provided to view lab results online.",
    "Patient's adherence barriers were reviewed in depth.",
    "She reported no issues obtaining her prescriptions.",
    "Medication copays are affordable with current insurance.",
    "She uses a pill organizer to stay consistent with medications.",
    "No adverse effects were noted with lisinopril initiation.",
    "She will follow up earlier if she experiences cough or dizziness.",
    "Lab orders were sent electronically to her preferred lab facility.",
    "Her primary care physician was updated with a detailed note.",
    "All referrals were placed and communicated via the EHR.",
    "She verbalized a commitment to improve her dietary habits.",
    "Caregiver support remains a critical concern in her daily life.",
    "A family meeting was suggested to discuss shared caregiving responsibilities.",
    "Advance directives were briefly discussed and documented in the chart.",
    "She has not completed a living will but expressed interest in doing so.",
    "Goals of care conversation was scheduled for her next visit.",
    "The care team concluded the visit with a review of next steps.",
    "Mrs. M was thanked for her active participation and engagement in her care."
  ]
,
"entities" : [
    ["68-year-old woman", "hypertension", "type 2 diabetes mellitus", "stage 3 chronic kidney disease"],
    ["internal medicine clinic", "occipital headaches", "six weeks"],
    ["headaches", "dizziness", "blurred vision", "morning hours"],
    ["increased symptom frequency", "intensity"],
    ["home blood pressure monitoring", "elevated readings", "160/95 mmHg"],
    ["elevated blood pressure", "early mornings"],
    ["amlodipine 10 mg daily", "hydrochlorothiazide 25 mg daily", "antihypertensive regimen"],
    ["medication adherence", "poorly controlled blood pressure"],
    ["hyperlipidemia", "atorvastatin"],
    ["HbA1c 7.8%", "suboptimal glycemic control"],
    ["osteoarthritis", "knees", "hips"],
    ["limited mobility", "daily activity"],
    ["<1,000 steps per day", "wearable tracker"],
    ["bilateral ankle edema", "evenings"],
    ["mild shortness of breath", "moderate physical exertion"],
    ["concerns for heart failure"],
    ["blood pressure 168/98 mmHg", "seated vitals"],
    ["heart rate 88 bpm", "respiratory rate 18", "oxygen saturation 96%"],
    ["temperature", "normal limits"],
    ["displaced apical impulse", "6th intercostal space"],
    ["soft systolic murmur", "grade 2/6", "cardiac apex"],
    ["no jugular venous distension"],
    ["clear lung fields", "pulmonary examination"],
    ["1+ pitting edema", "both ankles"],
    ["arteriolar narrowing", "cotton wool spots", "fundoscopic exam"],
    ["no papilledema"],
    ["serum creatinine 1.4 mg/dL", "stable renal function"],
    ["eGFR 48 mL/min/1.73 m²"],
    ["electrolyte levels", "normal limits"],
    ["LDL cholesterol 130 mg/dL", "lipid panel"],
    ["HDL 38 mg/dL", "triglycerides 160 mg/dL"],
    ["left ventricular hypertrophy", "strain pattern", "ECG"],
    ["no arrhythmias", "no conduction defects"],
    ["echocardiography", "concentric LV hypertrophy", "LVEF 60%"],
    ["mild left atrial enlargement", "echocardiogram"],
    ["trace mitral regurgitation"],
    ["low-sodium diet", "adherence challenges"],
    ["joint pain", "fatigue"],
    ["processed meals", "ready-made meals"],
    ["husband assistance", "meal planning"],
    ["dietary counseling", "weight stable"],
    ["frustration", "lack of progress"],
    ["hydrochlorothiazide dose reduction", "borderline hypokalemia"],
    ["lisinopril 10 mg daily", "BP control", "renal protection"],
    ["BP recording", "twice daily", "log maintenance"],
    ["labs", "renal function", "electrolytes"],
    ["lipid panel", "follow-up"],
    ["hypertensive urgency", "counseling"],
    ["chest pain", "confusion", "sudden weakness", "educational materials"],
    ["activity increase", "tolerance"],
    ["physical therapy referral", "joint-friendly exercises"],
    ["dietitian consultation", "nutritional gaps"],
    ["kidney-friendly diet", "low-sodium strategies"],
    ["home health services", "BP monitoring"],
    ["medication adherence support", "lifestyle reinforcement"],
    ["cardiology follow-up", "4 weeks"],
    ["nephrology", "renal function", "medication tolerance"],
    ["single-story home", "husband support"],
    ["daily assistance", "emotional support"],
    ["primary caregiver", "elderly mother-in-law", "dementia"],
    ["emotional stress", "fatigue"],
    ["insomnia", "feelings of anxiety"],
    ["no tobacco use", "no alcohol use"],
    ["no depression"],
    ["social work referral", "caregiver resources"],
    ["multidisciplinary care", "BP management"],
    ["follow-up", "6 weeks", "symptom worsening"],
    ["BP log", "visit agreement"],
    ["nutritional handouts", "physical therapy instructions"],
    ["NSAID avoidance", "kidney function"],
    ["OTC pain relievers", "joint discomfort"],
    ["acetaminophen", "ibuprofen substitution"],
    ["fluid status", "volume overload"],
    ["leg elevation"],
    ["compression stockings", "ankle edema"],
    ["psychosocial support", "stress management"],
    ["behavioral therapy referral", "insomnia"],
    ["cognitive function", "grossly intact"],
    ["no delirium", "no memory impairment"],
    ["speech", "gait", "motor function", "normal exam"],
    ["eye exam", "overdue", "ophthalmology referral"],
    ["diabetic foot screening", "normal", "no ulceration"],
    ["pneumococcal vaccination", "influenza vaccination"],
    ["vaccinations", "administered"],
    ["BMI 31", "obese category"],
    ["weight reduction", "dietary adjustments"],
    ["meal log", "sugary beverage reduction"],
    ["group diabetes education sessions"],
    ["glucose meter", "home use"],
    ["target fasting blood glucose levels"],
    ["fingerstick glucose testing", "appropriate technique"],
    ["pharmacist consultation", "medication review"],
    ["medication list", "EHR", "reconciled"],
    ["no drug allergies"],
    ["creatinine", "monitoring every 3 months"],
    ["UACR", "proteinuria assessment"],
    ["lab results", "reviewed", "explained"],
    ["treatment plan", "understanding verbalized"],
    ["emergency contact information", "updated"],
    ["urgent care instructions"],
    ["high-potassium foods", "dietary avoidance"],
    ["potassium supplement", "not needed"],
    ["support groups", "caregivers"],
    ["social worker", "2-week follow-up"],
    ["progress notes", "care plan summary", "printed"],
    ["next appointment", "scheduled"],
    ["transportation resources"],
    ["morning appointments", "caregiver duties"],
    ["husband", "medication questions"],
    ["coordinated care", "appreciation"],
    ["medication timing", "consistency"],
    ["BP monitor", "calibration verified"],
    ["telehealth follow-up", "future visits"],
    ["video visits", "comfort"],
    ["portal access", "lab results online"],
    ["adherence barriers", "reviewed"],
    ["prescriptions", "no issues obtaining"],
    ["medication copays", "insurance coverage"],
    ["pill organizer", "medication consistency"],
    ["lisinopril", "no adverse effects"],
    ["cough", "dizziness", "monitoring advised"],
    ["lab orders", "electronic submission"],
    ["primary care physician", "updated note"],
    ["referrals", "communicated via EHR"],
    ["dietary habits", "improvement commitment"],
    ["caregiver support", "critical concern"],
    ["family meeting", "caregiving responsibilities"],
    ["advance directives", "discussion", "chart documentation"],
    ["living will", "not completed", "interest expressed"],
    ["goals of care", "next visit plan"],
    ["care team", "review", "next steps"],
    ["patient engagement", "acknowledged"]
]

}

# Keyword arguments handed to UMAP and HDBSCAN inside the searcher.
best_umap = dict(n_neighbors=5, n_components=5, min_dist=0.25, metric="cosine")
best_hdbscan = dict(min_cluster_size=2, min_samples=1, metric="euclidean")

# Constructing the searcher loads the embedding model, fits the topic model
# and builds the search index in one go.
searcher = AllergyTopicSearcher(
    allergy_dataset["chunks"],
    allergy_dataset["entities"],
    best_umap,
    best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
)
print("✅ Model ready for querying.")

# === METRICS ===
# Score the fitted topics: coherence (c_v), diversity and silhouette.
coherence = compute_bertopic_coherence(
    searcher.topic_model, searcher.topic_metadata, topk=15
)
diversity = compute_topic_diversity(
    searcher.topic_model, searcher.topic_metadata, topk=10
)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
# The silhouette helper returns None when the score is undefined.
print(
    f"📐 Silhouette Score: {sil_score:.4f}"
    if sil_score is not None
    else "📐 Silhouette Score: Not applicable."
)

# === GROUND TRUTH TOPICS ===
ground_truth_topics = [
  { "topic_id": "T0", "entities": ["hypertension", "headaches", "dizziness", "blurred vision", "poorly controlled blood pressure", "medication adherence"] },
  { "topic_id": "T1", "entities": ["amlodipine", "hydrochlorothiazide", "lisinopril", "medication regimen", "potassium", "borderline hypokalemia"] },
  { "topic_id": "T2", "entities": ["type 2 diabetes mellitus", "chronic kidney disease", "hyperlipidemia", "osteoarthritis", "comorbidities"] },
  { "topic_id": "T3", "entities": ["ankle edema", "shortness of breath", "fatigue", "exercise limitation", "limited mobility", "fewer than 1,000 steps per day"] },
  { "topic_id": "T4", "entities": ["blood pressure 168/98 mmHg", "heart rate", "systolic murmur", "displaced apical impulse", "fundoscopic changes", "retinopathy", "edema"] },
  { "topic_id": "T5", "entities": ["creatinine", "eGFR", "HbA1c", "lipid panel", "LDL", "HDL", "triglycerides", "electrolytes"] },
  { "topic_id": "T6", "entities": ["ECG", "left ventricular hypertrophy", "strain pattern", "echocardiography", "LVEF", "mitral regurgitation", "left atrial enlargement"] },
  { "topic_id": "T7", "entities": ["low-sodium diet", "processed meals", "ready-made meals", "dietary noncompliance", "weight stable", "obesity", "BMI 31", "weight loss"] },
  { "topic_id": "T8", "entities": ["dietitian", "physical therapy", "cardiology follow-up", "nephrology follow-up", "ophthalmology referral", "social work referral"] },
  { "topic_id": "T9", "entities": ["married", "lives with husband", "caregiver stress", "caregiving responsibilities", "anxiety", "insomnia", "emotional stress", "no tobacco use", "no alcohol use"] },
  { "topic_id": "T10", "entities": ["patient education", "hypertension urgency signs", "chest pain", "confusion", "weakness", "treatment plan understanding", "symptom awareness"] },
  { "topic_id": "T11", "entities": ["follow-up", "BP monitoring", "home health", "medication reminders", "lab monitoring", "pill organizer", "telehealth visits", "portal access"] },
  { "topic_id": "T12", "entities": ["multidisciplinary care", "nutritional support", "psychosocial support", "pharmacist consultation", "care coordination", "support groups"] }
]


# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    """Return the entities lower-cased and stripped of surrounding whitespace.

    Args:
        entities: Iterable of entity strings.

    Returns:
        A new list with one normalized string per input item, in input order.
    """
    cleaned = []
    for entity in entities:
        cleaned.append(entity.lower().strip())
    return cleaned

def jaccard_similarity(set1, set2):
    """Jaccard index of two collections: |intersection| / |union|.

    Args:
        set1: First iterable of hashable items (duplicates are ignored).
        set2: Second iterable of hashable items (duplicates are ignored).

    Returns:
        A float in ``[0, 1]``; ``0.0`` when both collections are empty.
    """
    left = set(set1)
    right = set(set2)
    union = left | right
    if not union:
        return 0.0
    return len(left & right) / len(union)

# Prepare model topics: one {topic_id, normalized entities} record per topic.
model_topics = [
    {
        "topic_id": meta["topic_id"],
        "entities": normalize(meta["entities"]),
    }
    for meta in searcher.topic_metadata
]

# Match every model topic to the ground-truth topic with the highest Jaccard
# overlap. A model topic with zero overlap everywhere stays unmatched.
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_gt, best_score = None, 0
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score <= best_score:
            continue  # ties keep the earlier ground-truth topic
        best_gt, best_score = gt, score
    if not best_gt:
        continue

    matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
    matched_gt_ids.add(best_gt["topic_id"])

    # Collect entities of matched pairs for entity-level precision/recall.
    all_model_entities += mt["entities"]
    all_gt_entities += normalize(best_gt["entities"])

# Entity-level metrics over the union of entities seen on either side.
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set([*model_entity_counter.keys(), *gt_entity_counter.keys()]))
y_true = [entity in gt_entity_counter for entity in unique_entities]
y_pred = [entity in model_entity_counter for entity in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="binary"
)

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

# Guard both ratios: with no matches (or no ground-truth topics) the plain
# divisions would raise ZeroDivisionError instead of reporting 0.
avg_jaccard = (
    sum(score for _, _, score in matches) / len(matches) if matches else 0.0
)
coverage_pct = (
    (len(matched_gt_ids) / len(ground_truth_topics)) * 100
    if ground_truth_topics else 0.0
)

print(f"\n🧮 Average Jaccard Similarity: {avg_jaccard:.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({coverage_pct:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
# Interactive prompt: each query is searched against the topic index until the
# user types "exit"/"quit" or the input stream ends.

# Minimum best-topic similarity for a query to count as relevant. Defined once
# here rather than on every loop iteration.
threshold = 0.2  # Set your similarity threshold here

print("\n=== Allergy Topic Search ===")
while True:
    try:
        query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # No interactive stdin, or the prompt was interrupted: leave the loop
        # cleanly instead of ending the cell with a traceback.
        print("Goodbye!")
        break

    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break
    if not query:
        # Nothing to search for; ask again without running the model.
        continue

    results = searcher.search(query, top_k_topics=3, top_k_sents=3)

    # Find the max topic score among the results (0 if no results)
    max_topic_score = max((res["topic_score"] for res in results), default=0)

    if not results or max_topic_score < threshold:
        print("Sorry, the query is irrelevant.")
    else:
        print(f"\n🔎 Top results for: '{query}'")
        for res in results:
            print(f"🧠 Topic ID: {res['topic_id']} (Score: {res['topic_score']:.4f})")
            print(f"🔗 Related Entities: {', '.join(res['entities'])}")
            for sent, score in res["sentences"]:
                print(f"✓ [{score:.4f}] {sent}")



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/433M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/412 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.5878
🌈 Topic Diversity: 0.4284
📐 Silhouette Score: 0.4214

=== 📊 Topic Matching Summary ===
🔗 Model Topic 10 ↔ Ground Truth T2 — Jaccard: 0.11
🔗 Model Topic 19 ↔ Ground Truth T0 — Jaccard: 0.25
🔗 Model Topic 53 ↔ Ground Truth T0 — Jaccard: 0.14
🔗 Model Topic 0 ↔ Ground Truth T1 — Jaccard: 0.07
🔗 Model Topic 75 ↔ Ground Truth T0 — Jaccard: 0.33
🔗 Model Topic 72 ↔ Ground Truth T2 — Jaccard: 0.17
🔗 Model Topic 68 ↔ Ground Truth T4 — Jaccard: 0.12
🔗 Model Topic 66 ↔ Ground Truth T2 — Jaccard: 0.17
🔗 Model Topic 13 ↔ Ground Truth T3 — Jaccard: 0.22
🔗 Model Topic 3 ↔ Ground Truth T10 — Jaccard: 0.18
🔗 Model Topic -1 ↔ Ground Truth T5 — Jaccard: 0.07
🔗 Model Topic 37 ↔ Ground Truth T4 — Jaccard: 0.11
🔗 Model Topic 8 ↔ Ground Truth T11 — Jaccard: 0.08
🔗 Model Topic 41 ↔ Ground Truth T3 — Jaccard: 0.12
🔗 Model Topic 69 ↔ Ground Truth T5 — Jaccard: 0.11
🔗 Model Topic 44 ↔ Ground Truth T6 — Jaccard: 0.25
🔗 Model