In [1]:
!pip install numpy sentence-transformers bertopic hdbscan nltk scann
import nltk
nltk.download('punkt')
import nltk
nltk.download('punkt_tab')
!pip install sentence-transformers bertopic hdbscan umap-learn scann nltk datasets
!pip install gensim

Collecting bertopic
  Downloading bertopic-0.17.3-py3-none-any.whl.metadata (24 kB)
Collecting scann
  Downloading scann-1.4.0-cp311-cp311-manylinux_2_27_x86_64.whl.metadata (5.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence-transform

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


Collecting gensim
  Downloading gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Downloading gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.7/26.7 MB[0m [31m59.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━

In [None]:
#Improved code tried on new dataset:

In [2]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support
from sklearn.metrics.pairwise import cosine_similarity

# === SEED FIXING ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === CLEANING & CONTEXT EXTRACTION (IMPROVED) ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()


def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = " ".join(sentences[max(0, i - 1): i + 2]) if use_multi_sentence else sent.strip()
                    enriched = f"The concept '{ent_lower}' appears in the following context: {context}"
                    entity_context_pairs.append((ent_lower, enriched.strip()))
                    matched = True
                    break
            if not matched:
                fallback = f"The concept '{ent_lower}' appears in the following context: {chunk}"
                entity_context_pairs.append((ent_lower, fallback.strip()))
    return entity_context_pairs


# === TOPIC SEARCHER CLASS (WITH DEDUPLICATION, NOISE FILTERING) ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(self.chunks, self.entities_per_chunk)
        contextual_texts = [ctx for _, ctx in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=False)

        umap_model = UMAP(**self.umap_params, random_state=SEED)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            if topic == -1:
                continue  # Skip noisy topics
            ent, ctx = entity_context_pairs[i]
            topic_to_contexts[topic].append(ctx)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            emb = topic_to_embeddings[topic_id]
            centroid = np.mean(emb, axis=0)
            centroid /= np.linalg.norm(centroid) + 1e-10
            topic_embeddings.append(centroid)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(emb)
            })

        # === OPTIONAL: Merge semantically similar topics (cosine sim > 0.95)
        deduped_metadata = []
        used = set()

        for i, emb_i in enumerate(topic_embeddings):
            if i in used:
                continue
            group = [i]
            sim_scores = cosine_similarity([emb_i], topic_embeddings)[0]
            for j in range(i + 1, len(sim_scores)):
                if sim_scores[j] > 0.95:
                    group.append(j)
                    used.add(j)

            merged = {
                "topic_id": i,
                "sentences": [],
                "entities": [],
                "sentence_embeddings": []
            }
            for g in group:
                merged["sentences"] += topic_metadata[g]["sentences"]
                merged["entities"] += topic_metadata[g]["entities"]
                merged["sentence_embeddings"] += list(topic_metadata[g]["sentence_embeddings"])

            merged["sentence_embeddings"] = np.array(merged["sentence_embeddings"])
            merged["entities"] = list(set(merged["entities"]))
            deduped_metadata.append(merged)

        self.topic_metadata = deduped_metadata
        self.topic_embeddings = np.array([
            np.mean(m["sentence_embeddings"], axis=0) /
            (np.linalg.norm(np.mean(m["sentence_embeddings"], axis=0)) + 1e-10)
            for m in deduped_metadata
        ])

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=10, num_leaves_to_search=5, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(5)
            .build()
        )

    import re

    def search(self, query, top_k_topics=3, top_k_sents=3):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        prefix_pattern = r"^the concept '.*?' appears in (the following )?context:\s*"

        for i, idx in enumerate(neighbors):
            meta = self.topic_metadata[idx]
            topic_score = float(scores[i])

            # Deduplicate sentences
            seen = set()
            cleaned_sentences = []
            cleaned_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
            # Apply regex to remove beginning prefix
                cleaned = re.sub(prefix_pattern, "", sent, flags=re.IGNORECASE).strip()

        # No duplicates
                if cleaned not in seen:
                    seen.add(cleaned)
                    cleaned_sentences.append(cleaned)
                    cleaned_embeddings.append(emb)

            if not cleaned_sentences:
                continue

            emb_array = np.array(cleaned_embeddings)
            sims = np.dot(emb_array / np.linalg.norm(emb_array, axis=1, keepdims=True), query_emb)
            top_ids = sims.argsort()[::-1][:top_k_sents]

            top_sents = [(cleaned_sentences[j], float(sims[j])) for j in top_ids]
            results.append({
            "topic_id": meta["topic_id"],
            "topic_score": topic_score,
            "entities": meta["entities"],
            "sentences": top_sents,
            })

        return results




# === METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for s in meta["sentences"]:
            tokens = clean_text(s).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(t) for t in texts]

    cm = CoherenceModel(
        topics=word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return cm.get_coherence()


def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topic_words = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    unique_words = set(word for topic in topic_words for word, _ in topic)
    return len(unique_words) / (len(topic_words) * topk)


def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        emb = meta["sentence_embeddings"]
        if len(emb) < 2:  # skip small clusters
            continue
        all_embeddings.extend(emb)
        all_labels.extend([meta["topic_id"]] * len(emb))

    if len(all_embeddings) < 3:
        return None

    all_embeddings = np.vstack(all_embeddings)
    return silhouette_score(all_embeddings, all_labels, metric="cosine")


# === DATASET & INITIALIZATION ===
allergy_dataset = {
  "chunks": [
  "Mr. James H., a 79-year-old male with a long-standing history of cardiovascular and metabolic diseases, was brought to the emergency department due to acute confusion and generalized weakness.",
  "According to his daughter, he had wandered outside disoriented and was unable to identify family members or recall events from the previous day.",
  "He has known medical conditions including hypertension, heart failure with reduced ejection fraction, insulin-dependent diabetes mellitus, stage 4 chronic kidney disease, and major depressive disorder.",
  "His medication regimen includes daily doses of lisinopril, furosemide, carvedilol, insulin glargine, sertraline, and donepezil.",
  "In the past 24 hours, the patient experienced decreased appetite, an episode of vomiting, and two instances of urinary incontinence.",
  "Vital signs upon arrival included a blood pressure of 98/56 mmHg, heart rate of 112 beats per minute (irregularly irregular), respiratory rate of 20, oxygen saturation of 93% on room air, and a temperature of 37.6°C.",
  "Physical examination revealed dry mucous membranes, poor skin turgor, moderate lower limb pitting edema, and delayed capillary refill.",
  "Auscultation of the lungs revealed bilateral basal crackles, and cardiac exam confirmed an irregularly irregular heartbeat without murmurs.",
  "Neurological examination showed fluctuating attention span but no signs of focal deficits or lateralizing neurological signs.",
  "Initial lab studies demonstrated an elevated blood glucose of 421 mg/dL, serum sodium of 129 mmol/L, potassium at 5.7 mmol/L, and creatinine at 2.9 mg/dL.",
  "Serum BUN was elevated at 59 mg/dL and the patient’s anion gap was calculated to be 19, consistent with an anion-gap metabolic acidosis.",
  "Urinalysis revealed glucosuria and ketonuria without signs of infection, and serum ketones were modestly elevated.",
  "His HbA1c on record from two months ago was 8.1%, confirming chronic poor glycemic control.",
  "An ECG showed atrial fibrillation with rapid ventricular response but no acute ischemic changes.",
  "Chest radiograph revealed cardiomegaly and pulmonary vascular congestion with mild bilateral pleural effusions.",
  "CT head without contrast was negative for acute infarct, hemorrhage, or mass effect, but showed chronic microvascular changes.",
  "Given the presentation, he was admitted to the medical ward for acute hyperosmolar hyperglycemic state (HHS) and acute on chronic kidney injury.",
  "A diagnosis of acute delirium, likely secondary to metabolic derangements, volume depletion, and possible infection, was made.",
  "He was started on intravenous normal saline, correctional insulin, and telemetry monitoring.",
  "Furosemide was temporarily held due to volume depletion, and electrolytes were repleted cautiously under nephrology guidance.",
  "Blood cultures, urine cultures, and chest x-ray were obtained to rule out infection as a potential delirium trigger.",
  "Empiric antibiotics (ceftriaxone and azithromycin) were initiated pending culture data due to concern for possible aspiration pneumonia.",
  "On day two, the patient’s mental status began to improve with the resolution of hyperglycemia and normalization of serum osmolarity.",
  "Repeat labs showed trending down of BUN and creatinine, with sodium rising to 134 and potassium corrected to 4.5 mmol/L.",
  "He remained in atrial fibrillation and required continuation of beta-blocker therapy to manage ventricular rate.",
  "Apixaban was continued upon nephrology clearance given acceptable bleeding risk and stable renal function.",
  "He was evaluated by geriatrics for worsening cognitive decline and safety evaluation related to home discharge.",
  "PT/OT performed a bedside mobility assessment showing weakness, unsteadiness, and need for moderate assistance with transfers.",
  "Case management consulted social work regarding home safety, fall prevention, and caregiver support.",
  "His hospital stay was complicated by mild hypoglycemia on hospital day 3, prompting insulin dose adjustments.",
  "Nutritional support was consulted to optimize diabetic-friendly, renal-adjusted diet appropriate for age and mobility.",
  "His depression management was reviewed with psychiatry, and sertraline was continued at 100 mg/day with no suggestion for dose change.",
  "A Montreal Cognitive Assessment (MoCA) was done revealing a score of 19/30, indicating significant mild cognitive impairment.",
  "Audiology was recommended due to hearing difficulty interfering with care discussions.",
  "Oral exam noted poor dentition; dental evaluation was recommended for follow-up to address suspected pain and poor appetite.",
  "After 6 days, the patient was clinically improved, mentally oriented, and ambulatory with the help of physical therapy.",
  "Cardiac and renal parameters stabilized sufficiently to permit safe discharge planning.",
  "The final hospital diagnosis included hyperosmolar hyperglycemic state, volume depletion, acute-on-chronic kidney injury, atrial fibrillation with RVR, and acute delirium.",
  "He was discharged on a simplified diabetic regimen including basal insulin and correctional sliding scale doses only.",
  "Apixaban, carvedilol, donepezil, and sertraline were continued with no changes.",
  "Discharge medication reconciliation included temporary hold of furosemide with plan for outpatient reassessment after fluid status recovery.",
  "Caregiver role was assumed by daughter who had durable power of attorney and assisted with all home-based needs.",
  "Written instructions and red flags for hyperglycemia, dizziness, and recurrent confusion were provided.",
  "A follow-up with his primary care physician, nephrologist, and endocrinologist were scheduled within one and two weeks respectively.",
  "Home health nursing was arranged to provide medication support and glucose monitoring.",
  "Nutritionist and physical therapy were ordered for continued improvement in diet and mobility.",
  "Advanced care planning was briefly discussed including code status, proxy, and end-of-life preferences.",
  "He is currently listed as full code but family is open to further discussion at next provider visit.",
  "Patient was grateful for hospital care and expressed motivation to remain active and well at home.",
  "The overall prognosis remains guarded due to progressive cognitive decline and limited renal reserve.",
  "Close monitoring for new signs of decompensation or medication nonadherence was advised.",
  "Pulmonology follow-up was discussed due to prior mild restrictive spirometry suggestive of early interstitial lung disease.",
  "Family history reveals mother died of complications from dementia and father from ischemic stroke.",
  "No reported use of tobacco, alcohol, or recreational drugs throughout his life.",
  "Lives in a single-story home with grab bars and minimal clutter, although risks for falls still persist.",
  "Wears eyeglasses but rarely uses his hearing aids, sometimes leading to miscommunication or withdrawal.",
  "History of previous admission 1 year ago for pneumonia requiring IV antibiotics and 6-day hospitalization.",
  "Documentation from that admission revealed transient delirium and impaired oral intake similar to current episode.",
  "Goals-of-care conversations were initiated during this admission but deferred for primary care setting follow-up.",
  "Social isolation remains a concern, especially since his wife passed away 3 years ago.",
  "Patient receives Meals on Wheels but misses many meal deliveries due to lack of reliable caregiver at times.",
  "Transportation to medical appointments is provided by his daughter, who balances full-time work responsibilities.",
  "No current enrollment in adult day health programs; options discussed with case management on discharge.",
  "Insurance covers home nursing and outpatient labs but does not cover custodial care.",
  "Patient was educated about Medicare Advantage benefits and encouraged to review covered services with the plan coordinator.",
  "He was also reminded of the importance of daily glucose checks and hydration in summer months.",
  "Foot exam demonstrated mild calluses and intact sensation; he denies new ulcers or foot injuries.",
  "Vaccination status confirmed: received influenza and COVID vaccines last fall, but is due for pneumococcal booster.",
  "Dentition issues may be contributing to decreased intake; dental clinic referral was sent through EHR.",
  "Assistive device for walking was provided (four-point cane) after physical therapy evaluation.",
  "Contact dermatitis on legs due to prolonged pressure and incontinence was treated with barrier cream.",
  "Skin care and bathing guidance were reviewed with family nursing staff prior to discharge.",
  "Patient verbalized understanding of all discharge instructions with support from daughter.",
  "Hospital team closed chart after discussing active problems list, response to therapy, and continued plan.",
  "Patient left the hospital in a wheelchair, accompanied by family, and appeared in good spirits.",
  "The full discharge plan was documented and faxed to his primary provider for continuity of care.",
  "Medication reconciliation showed no potential drug interactions or allergy mismatches.",
  "He was warned against use of NSAIDs due to underlying CKD and risk of acute worsening.",
  "Hydration goals of at least 1.5 liters per day were set; urination logs and symptom review were encouraged.",
  "Emergency instructions included what to do in case of unresponsiveness, low blood glucose, sudden confusion, or chest pain.",
  "Digital blood glucose monitor was reviewed at bedside; daughter demonstrated appropriate calibration and use.",
  "All prescriptions were sent electronically to their local pharmacy located eight blocks from their home.",
  "Patient prefers morning appointments due to increased alertness and energy early in the day.",
  "A follow-up MoCA test was recommended in 3–6 months to assess cognitive trajectory.",
  "Updated advance directives were placed in the chart and a copy was given to the daughter.",
  "Fall prevention strategies were emphasized including appropriate lighting, footwear, and scheduled ambulation.",
  "Use of automatic pill organizers was encouraged to improve adherence across complex medication schedules.",
  "Daily weights will be tracked at home to monitor for unexpected fluid retention or heart failure.",
  "Serum creatinine will be rechecked in one week given borderline rise during admission.",
  "A nephrology note was sent to alert about potential need for long-term planning if GFR continues to decline.",
  "Patient qualifies for shared savings Medicare model and was assigned a care coordinator temporarily.",
  "Patient support group information was handed out, including resources for caregivers.",
  "He is open to exploring telehealth check-ins for medication titration and early symptom triage.",
  "Daughter confirmed she has portal access to review labs and visit summaries on his behalf.",
  "Patient and daughter expressed appreciation for the hospital care coordination team.",
  "Case closed with summary of diagnosis, medications, specialists involved, and plan for 30-day transitional care.",
  "Status post discharge: stable, safe for home, alert and oriented with supervision."
]

,
"entities":[
  ["confusion", "weakness", "cardiovascular", "metabolic"],
  ["disorientation", "memory"],
  ["hypertension", "failure", "diabetes", "kidney", "depression"],
  ["lisinopril", "furosemide", "carvedilol", "insulin", "sertraline", "donepezil"],
  ["appetite", "vomiting", "incontinence"],
  ["pressure", "rate", "rhythm", "respiration", "saturation", "temperature"],
  ["mucosa", "turgor", "edema", "refill"],
  ["crackles", "heartbeat", "murmurs"],
  ["attention", "deficits"],
  ["glucose", "sodium", "potassium", "creatinine"],
  ["bun", "acidosis"],
  ["glucosuria", "ketonuria", "ketones"],
  ["hba1c", "control"],
  ["ecg", "fibrillation", "response", "ischemia"],
  ["cardiomegaly", "congestion", "effusions"],
  ["infarct", "hemorrhage", "microvascular"],
  ["hyperglycemia", "injury"],
  ["delirium", "derangements", "infection"],
  ["saline", "insulin", "telemetry"],
  ["furosemide", "depletion", "electrolytes"],
  ["cultures", "infection"],
  ["antibiotics", "ceftriaxone", "azithromycin", "pneumonia"],
  ["status", "hyperglycemia", "osmolarity"],
  ["bun", "creatinine", "sodium", "potassium"],
  ["fibrillation", "rate", "blocker"],
  ["apixaban", "function", "bleeding"],
  ["geriatrics", "cognition"],
  ["pt", "ot", "mobility", "weakness", "transfers"],
  ["safety", "falls"],
  ["hypoglycemia", "insulin"],
  ["nutrition", "diet"],
  ["depression", "psychiatry", "sertraline"],
  ["moca", "impairment"],
  ["audiology", "hearing"],
  ["dentition", "pain"],
  ["therapy", "ambulation"],
  ["parameters"],
  ["hyperglycemia", "depletion", "injury", "fibrillation", "delirium"],
  ["regimen", "insulin"],
  ["apixaban", "carvedilol", "donepezil", "sertraline"],
  ["reconciliation", "furosemide"],
  ["power"],
  ["hyperglycemia", "dizziness", "confusion"],
  ["nephrologist", "endocrinologist"],
  ["nursing", "glucose"],
  ["nutritionist", "therapy"],
  ["planning", "status", "proxy"],
  ["code"],
  ["prognosis", "cognition", "reserve"],
  ["monitoring", "decompensation", "adherence"],
  ["pulmonology", "spirometry", "disease"],
  ["dementia", "stroke"],
  ["tobacco", "alcohol", "drugs"],
  ["falls"],
  ["hearing"],
  ["pneumonia", "antibiotics"],
  ["delirium"],
  ["conversations"],
  ["isolation"],
  ["meals"],
  ["transportation"],
  ["enrollment"],
  ["insurance", "nursing", "labs"],
  ["medicare"],
  ["glucose", "hydration"],
  ["exam", "calluses", "ulcers"],
  ["vaccination", "influenza", "covid", "booster"],
  ["dentition", "referral"],
  ["cane"],
  ["dermatitis", "cream"],
  ["skin"],
  ["instructions"],
  ["problems", "therapy", "plan"],
  ["wheelchair"],
  ["continuity"],
  ["reconciliation", "interactions", "allergies"],
  ["nsaids"],
  ["hydration", "urination", "symptoms"],
  ["instructions", "glucose", "confusion", "pain"],
  ["monitor", "calibration"],
  ["prescriptions", "pharmacy"],
  ["appointments", "alertness", "energy"],
  ["moca"],
  ["directives"],
  ["prevention", "lighting", "footwear", "ambulation"],
  ["organizer", "adherence"],
  ["weight", "retention", "failure"],
  ["creatinine"],
  ["nephrology", "gfr"],
  ["medicare", "coordinator"],
  ["group", "caregivers"],
  ["telehealth", "titration", "triage"],
  ["portal", "labs", "summaries"],
  ["coordination"],
  ["diagnosis", "medications", "specialists", "care"],
  ["discharge", "supervision"]
]



}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist":0.25,  "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")
# print("\n=== 🧠 Generated Topics and Entities ===")
# for meta in searcher.topic_metadata:
#     topic_id = meta["topic_id"]
#     entities = ", ".join(meta["entities"])
#     print(f"🔹 Topic ID: {topic_id} — Entities: {entities}")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics = [
  {
    "topic_id": "T0",
    "entities": [
      "confusion", "disorientation", "delirium", "weakness", "memory", "attention"
    ]
  },
  {
    "topic_id": "T1",
    "entities": [
      "hypertension", "diabetes", "kidney", "failure", "depression", "dementia"
    ]
  },
  {
    "topic_id": "T2",
    "entities": [
      "cardiovascular", "fibrillation", "carvedilol", "apixaban", "ecg", "rate", "rhythm", "ischemia"
    ]
  },
  {
    "topic_id": "T3",
    "entities": [
      "glucose", "insulin", "hba1c", "hyperglycemia", "ketones", "glucosuria", "ketonuria"
    ]
  },
  {
    "topic_id": "T4",
    "entities": [
      "creatinine", "bun", "potassium", "sodium", "hydration", "urination", "nsaids", "furosemide"
    ]
  },
  {
    "topic_id": "T5",
    "entities": [
      "sertraline", "psychiatry", "moca", "cognition", "stroke", "donepezil"
    ]
  },
  {
    "topic_id": "T6",
    "entities": [
      "infection", "cultures", "pneumonia", "ceftriaxone", "azithromycin", "delirium"
    ]
  },
  {
    "topic_id": "T7",
    "entities": [
      "weakness", "mobility", "falls", "therapy", "ambulation", "transfers", "edema"
    ]
  },
  {
    "topic_id": "T8",
    "entities": [
      "pressure", "temperature", "saturation", "rate", "mucosa", "turgor", "refill"
    ]
  },
  {
    "topic_id": "T9",
    "entities": [
      "nutrition", "appetite", "vomiting", "diet", "dentition", "pain"
    ]
  },
  {
    "topic_id": "T10",
    "entities": [
      "respiration", "spirometry", "saturation", "congestion", "crackles"
    ]
  },
  {
    "topic_id": "T11",
    "entities": [
      "medications", "reconciliation", "organizer", "interactions", "allergies"
    ]
  },
  {
    "topic_id": "T12",
    "entities": [
      "glucose", "creatinine", "bun", "sodium", "potassium", "ecg", "infarct", "hemorrhage", "microvascular"
    ]
  },
  {
    "topic_id": "T13",
    "entities": [
      "diagnosis", "specialists", "prognosis", "supervision", "summary", "discharge"
    ]
  }
]



# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
#print("\n=== 📊 Topic Matching Summary ===")
#for model_id, gt_id, score in matches:
    #print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=3, top_k_sents=3)

    threshold = 0.2  # Set your similarity threshold here

    # Find the max topic score among the results (0 if no results)
    max_topic_score = max((res["topic_score"] for res in results), default=0)

    if max_topic_score < threshold or not results:
        print("Sorry, the query is irrelevant.")
    else:
        print(f"\n🔎 Top results for: '{query}'")
        for res in results:
            print(f"🧠 Topic ID: {res['topic_id']} (Score: {res['topic_score']:.4f})")
            print(f"🔗 Related Entities: {', '.join(res['entities'])}")
            for sent, score in res["sentences"]:
                print(f"✓ [{score:.4f}] {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.5826
🌈 Topic Diversity: 0.4672
📐 Silhouette Score: 0.5386

🧮 Average Jaccard Similarity: 0.2279
📈 Ground Truth Coverage: 14/14 (100.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.5664
🧲 Recall:    0.9878
🏅 F1 Score:  0.7200

=== Allergy Topic Search ===

Enter a query (or type 'exit' to quit): # What is the patient’s main medical history?

🔎 Top results for: '# What is the patient’s main medical history?'
🧠 Topic ID: 29 (Score: 0.3305)
🔗 Related Entities: sertraline, depression, psychiatry
✓ [0.3127] his depression management was reviewed with psychiatry and sertraline was continued at 100 mgday with no suggestion for dose change
🧠 Topic ID: 46 (Score: 0.3276)
🔗 Related Entities: conversations, interactions
✓ [0.3114] the full discharge plan was documented and faxed to his primary provider for continuity of care
✓ [0.2733] documentation from that admission revealed transient delirium and impair

In [None]:
# What is the patient’s main medical history?
# What symptoms or problems did the patient present with initially?
# What is the patient’s current medication list?
# What did the physical examination reveal on admission?
# What dietary advice or nutrition support does the patient require?
# Has the patient had any recent imaging or scans? What were the results?

In [3]:
#Dataset 2:



# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support
from sklearn.metrics.pairwise import cosine_similarity

# === SEED FIXING ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === CLEANING & CONTEXT EXTRACTION (IMPROVED) ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()


def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = " ".join(sentences[max(0, i - 1): i + 2]) if use_multi_sentence else sent.strip()
                    enriched = f"The concept '{ent_lower}' appears in the following context: {context}"
                    entity_context_pairs.append((ent_lower, enriched.strip()))
                    matched = True
                    break
            if not matched:
                fallback = f"The concept '{ent_lower}' appears in the following context: {chunk}"
                entity_context_pairs.append((ent_lower, fallback.strip()))
    return entity_context_pairs


# === TOPIC SEARCHER CLASS (WITH DEDUPLICATION, NOISE FILTERING) ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(self.chunks, self.entities_per_chunk)
        contextual_texts = [ctx for _, ctx in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=False)

        umap_model = UMAP(**self.umap_params, random_state=SEED)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            if topic == -1:
                continue  # Skip noisy topics
            ent, ctx = entity_context_pairs[i]
            topic_to_contexts[topic].append(ctx)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            emb = topic_to_embeddings[topic_id]
            centroid = np.mean(emb, axis=0)
            centroid /= np.linalg.norm(centroid) + 1e-10
            topic_embeddings.append(centroid)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(emb)
            })

        # === OPTIONAL: Merge semantically similar topics (cosine sim > 0.95)
        deduped_metadata = []
        used = set()

        for i, emb_i in enumerate(topic_embeddings):
            if i in used:
                continue
            group = [i]
            sim_scores = cosine_similarity([emb_i], topic_embeddings)[0]
            for j in range(i + 1, len(sim_scores)):
                if sim_scores[j] > 0.95:
                    group.append(j)
                    used.add(j)

            merged = {
                "topic_id": i,
                "sentences": [],
                "entities": [],
                "sentence_embeddings": []
            }
            for g in group:
                merged["sentences"] += topic_metadata[g]["sentences"]
                merged["entities"] += topic_metadata[g]["entities"]
                merged["sentence_embeddings"] += list(topic_metadata[g]["sentence_embeddings"])

            merged["sentence_embeddings"] = np.array(merged["sentence_embeddings"])
            merged["entities"] = list(set(merged["entities"]))
            deduped_metadata.append(merged)

        self.topic_metadata = deduped_metadata
        self.topic_embeddings = np.array([
            np.mean(m["sentence_embeddings"], axis=0) /
            (np.linalg.norm(np.mean(m["sentence_embeddings"], axis=0)) + 1e-10)
            for m in deduped_metadata
        ])

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=10, num_leaves_to_search=5, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(5)
            .build()
        )

    import re

    def search(self, query, top_k_topics=3, top_k_sents=3):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        prefix_pattern = r"^the concept '.*?' appears in (the following )?context:\s*"

        for i, idx in enumerate(neighbors):
            meta = self.topic_metadata[idx]
            topic_score = float(scores[i])

            # Deduplicate sentences
            seen = set()
            cleaned_sentences = []
            cleaned_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
            # Apply regex to remove beginning prefix
                cleaned = re.sub(prefix_pattern, "", sent, flags=re.IGNORECASE).strip()

        # No duplicates
                if cleaned not in seen:
                    seen.add(cleaned)
                    cleaned_sentences.append(cleaned)
                    cleaned_embeddings.append(emb)

            if not cleaned_sentences:
                continue

            emb_array = np.array(cleaned_embeddings)
            sims = np.dot(emb_array / np.linalg.norm(emb_array, axis=1, keepdims=True), query_emb)
            top_ids = sims.argsort()[::-1][:top_k_sents]

            top_sents = [(cleaned_sentences[j], float(sims[j])) for j in top_ids]
            results.append({
            "topic_id": meta["topic_id"],
            "topic_score": topic_score,
            "entities": meta["entities"],
            "sentences": top_sents,
            })

        return results




# === METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for s in meta["sentences"]:
            tokens = clean_text(s).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(t) for t in texts]

    cm = CoherenceModel(
        topics=word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return cm.get_coherence()


def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topic_words = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    unique_words = set(word for topic in topic_words for word, _ in topic)
    return len(unique_words) / (len(topic_words) * topk)


def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        emb = meta["sentence_embeddings"]
        if len(emb) < 2:  # skip small clusters
            continue
        all_embeddings.extend(emb)
        all_labels.extend([meta["topic_id"]] * len(emb))

    if len(all_embeddings) < 3:
        return None

    all_embeddings = np.vstack(all_embeddings)
    return silhouette_score(all_embeddings, all_labels, metric="cosine")


# === DATASET & INITIALIZATION ===
allergy_dataset = {
  "chunks": [
    "Mrs. M is a 68-year-old female with a complex medical history including hypertension, type 2 diabetes mellitus, and stage 3 chronic kidney disease.",
    "She presents to the internal medicine clinic with ongoing complaints of dull, persistent occipital headaches for the past six weeks.",
    "The headaches are associated with episodes of dizziness and intermittent blurred vision, particularly in the morning hours.",
    "She reports that these symptoms have gradually increased in frequency and intensity.",
    "At-home blood pressure monitoring consistently shows elevated readings above 160/95 mmHg.",
    "These elevated blood pressure values are most pronounced during early mornings.",
    "Her current antihypertensive regimen includes amlodipine 10 mg daily and hydrochlorothiazide 25 mg daily.",
    "Despite adherence, her blood pressure remains poorly controlled.",
    "She also has a history of hyperlipidemia managed with atorvastatin.",
    "Her most recent HbA1c is 7.8%, indicating suboptimal glycemic control.",
    "She suffers from osteoarthritis, predominantly affecting her knees and hips.",
    "Osteoarthritis significantly limits her mobility and daily activity.",
    "She averages fewer than 1,000 steps per day, as recorded by her wearable tracker.",
    "She has recently developed bilateral ankle edema, particularly in the evenings.",
    "Mild shortness of breath is present with moderate physical exertion.",
    "These symptoms raise concerns for evolving heart failure.",
    "Her clinic vitals reveal a blood pressure of 168/98 mmHg while seated.",
    "Heart rate is 88 bpm, respiratory rate 18, and oxygen saturation is 96% on room air.",
    "Temperature is within normal limits.",
    "Cardiovascular exam reveals a displaced apical impulse at the 6th intercostal space.",
    "A soft systolic murmur (grade 2/6) is heard over the cardiac apex.",
    "No jugular venous distension is appreciated.",
    "Pulmonary examination shows clear lung fields bilaterally.",
    "Mild 1+ pitting edema is noted at both ankles.",
    "Fundoscopic exam shows arteriolar narrowing and scattered cotton wool spots.",
    "There is no evidence of papilledema.",
    "Recent labs show stable serum creatinine at 1.4 mg/dL.",
    "Her eGFR is estimated at 48 mL/min/1.73 m².",
    "Electrolyte levels remain within normal limits.",
    "Her lipid panel reveals LDL cholesterol at 130 mg/dL.",
    "HDL is 38 mg/dL and triglycerides are measured at 160 mg/dL.",
    "A 12-lead ECG reveals left ventricular hypertrophy with strain pattern.",
    "There are no arrhythmias or conduction defects observed on ECG.",
    "Echocardiography demonstrates concentric LV hypertrophy with LVEF of 60%.",
    "Mild left atrial enlargement is present on echocardiogram.",
    "Trace mitral regurgitation is also identified.",
    "She expresses challenges adhering to a low-sodium diet.",
    "Meal preparation is difficult due to joint pain and fatigue.",
    "She frequently consumes processed or ready-made meals.",
    "Her husband assists with shopping but she manages meal planning.",
    "Despite previous dietary counseling, her weight has remained stable.",
    "She reports feeling frustrated with her lack of progress.",
    "Her hydrochlorothiazide dose was reduced due to borderline hypokalemia.",
    "Lisinopril 10 mg daily was added to improve BP control and renal protection.",
    "She was instructed to record BP twice daily and maintain a log.",
    "Additional labs were ordered to reassess renal function and electrolytes.",
    "A follow-up lipid panel was also requested.",
    "She was counseled on symptoms of hypertensive urgency.",
    "Educational materials were provided about recognizing chest pain, confusion, or sudden weakness.",
    "She was encouraged to gradually increase activity within tolerance.",
    "Physical therapy referral was made to support her with joint-friendly exercises.",
    "Dietitian consultation was arranged to address nutritional gaps.",
    "The diet plan focuses on kidney-friendly and low-sodium strategies.",
    "Home health services were initiated to aid with BP monitoring.",
    "Medication adherence support and lifestyle reinforcement are included in home visits.",
    "Cardiology follow-up is scheduled within 4 weeks.",
    "Nephrology will reassess renal function and medication tolerance.",
    "She lives with her husband in a single-story home.",
    "Her husband offers daily assistance and emotional support.",
    "She is the primary caregiver for her elderly mother-in-law with dementia.",
    "Caregiver duties contribute to emotional stress and fatigue.",
    "She reports occasional insomnia and feelings of anxiety.",
    "There is no history of tobacco or alcohol use.",
    "She denies symptoms of depression.",
    "Social work referral was made to explore caregiver resources.",
    "Multidisciplinary care was emphasized to support long-term BP management.",
    "She was advised to follow up within 6 weeks or sooner if symptoms worsen.",
    "She agreed to maintain a BP log and bring it to her next visit.",
    "Nutritional handouts and physical therapy instructions were given in print.",
    "She was educated on avoiding NSAIDs due to kidney function.",
    "She occasionally used over-the-counter pain relievers for joint discomfort.",
    "Acetaminophen was recommended instead of ibuprofen.",
    "Fluid status will be monitored to avoid volume overload.",
    "She was advised to elevate legs periodically throughout the day.",
    "Compression stockings were discussed for managing ankle edema.",
    "Psychosocial support and stress management were reviewed during the visit.",
    "Behavioral therapy referral was considered if insomnia persists.",
    "Cognitive function was grossly intact during the visit.",
    "There was no evidence of delirium or memory impairment.",
    "Speech, gait, and motor function appeared normal during physical exam.",
    "Her last eye exam was over a year ago; ophthalmology referral was placed.",
    "Diabetic foot screening was normal, with intact sensation and no ulceration.",
    "She is due for her pneumococcal and influenza vaccinations.",
    "Vaccination updates were administered during this visit.",
    "Her BMI is 31, placing her in the obese category.",
    "Weight reduction strategies were discussed, including dietary adjustments.",
    "She agreed to log meals and reduce sugary beverages.",
    "She was encouraged to attend group diabetes education sessions.",
    "A glucose meter was prescribed for home use.",
    "Target fasting blood glucose levels were reviewed with her.",
    "She demonstrated appropriate technique for fingerstick glucose testing.",
    "Pharmacist consultation was arranged to review medication regimen.",
    "Her medication list was reconciled and updated in the EHR.",
    "There are no current drug allergies reported.",
    "Her creatinine will be monitored every 3 months going forward.",
    "UACR (urine albumin-to-creatinine ratio) was ordered to assess proteinuria.",
    "Recent lab results were reviewed and explained in detail.",
    "Patient verbalized understanding of the treatment plan.",
    "Emergency contact information was verified and updated.",
    "Instructions were provided on when to seek urgent care.",
    "She was advised to avoid high-potassium foods temporarily.",
    "A potassium supplement was not needed at this time.",
    "She requested information on local support groups for caregivers.",
    "Social worker will follow up within 2 weeks.",
    "Progress notes and care plan summary were printed for her records.",
    "Next appointment was scheduled before she left the clinic.",
    "Clinic staff reviewed transportation resources if needed.",
    "She prefers morning appointments due to caregiver duties in the afternoon.",
    "Her husband accompanied her to the visit and asked questions about medications.",
    "Together, they expressed appreciation for coordinated care.",
    "She was reminded of the importance of medication timing consistency.",
    "BP monitor calibration was reviewed and verified as accurate.",
    "Telehealth follow-up was discussed as an option for future visits.",
    "Patient was comfortable with the use of video visits.",
    "Portal access instructions were provided to view lab results online.",
    "Patient's adherence barriers were reviewed in depth.",
    "She reported no issues obtaining her prescriptions.",
    "Medication copays are affordable with current insurance.",
    "She uses a pill organizer to stay consistent with medications.",
    "No adverse effects were noted with lisinopril initiation.",
    "She will follow up earlier if she experiences cough or dizziness.",
    "Lab orders were sent electronically to her preferred lab facility.",
    "Her primary care physician was updated with a detailed note.",
    "All referrals were placed and communicated via the EHR.",
    "She verbalized a commitment to improve her dietary habits.",
    "Caregiver support remains a critical concern in her daily life.",
    "A family meeting was suggested to discuss shared caregiving responsibilities.",
    "Advance directives were briefly discussed and documented in the chart.",
    "She has not completed a living will but expressed interest in doing so.",
    "Goals of care conversation was scheduled for her next visit.",
    "The care team concluded the visit with a review of next steps.",
    "Mrs. M was thanked for her active participation and engagement in her care."
  ]
,
"entities": [
  ["Mrs. M", "68-year-old", "female", "medical history", "hypertension", "type 2 diabetes mellitus", "stage 3 chronic kidney disease"],
  ["internal medicine clinic", "ongoing complaints", "occipital headaches", "six weeks"],
  ["headaches", "dizziness", "intermittent blurred vision", "morning hours"],
  ["symptoms", "increased frequency", "increased intensity"],
  ["at-home blood pressure monitoring", "elevated readings", "160/95 mmHg"],
  ["elevated blood pressure values", "early mornings"],
  ["amlodipine 10 mg daily", "hydrochlorothiazide 25 mg daily", "antihypertensive regimen"],
  ["medication adherence", "poorly controlled blood pressure"],
  ["hyperlipidemia", "atorvastatin"],
  ["HbA1c", "7.8%", "suboptimal glycemic control"],
  ["osteoarthritis", "knees", "hips"],
  ["osteoarthritis", "mobility", "daily activity"],
  ["fewer than 1,000 steps per day", "wearable tracker"],
  ["bilateral ankle edema", "evenings"],
  ["mild shortness of breath", "moderate physical exertion"],
  ["symptoms", "evolving heart failure"],
  ["clinic vitals", "blood pressure 168/98 mmHg", "seated"],
  ["heart rate 88 bpm", "respiratory rate 18", "oxygen saturation 96% on room air"],
  ["temperature", "normal limits"],
  ["cardiovascular exam", "displaced apical impulse", "6th intercostal space"],
  ["soft systolic murmur", "grade 2/6", "cardiac apex"],
  ["jugular venous distension", "not appreciated"],
  ["pulmonary examination", "clear lung fields bilaterally"],
  ["1+ pitting edema", "both ankles"],
  ["fundoscopic exam", "arteriolar narrowing", "cotton wool spots"],
  ["no evidence of papilledema"],
  ["recent labs", "serum creatinine", "1.4 mg/dL"],
  ["eGFR", "estimated at 48 mL/min/1.73 m²"],
  ["electrolyte levels", "within normal limits"],
  ["lipid panel", "LDL cholesterol", "130 mg/dL"],
  ["HDL", "38 mg/dL", "triglycerides", "160 mg/dL"],
  ["12-lead ECG", "left ventricular hypertrophy", "strain pattern"],
  ["no arrhythmias", "no conduction defects"],
  ["echocardiography", "concentric LV hypertrophy", "LVEF 60%"],
  ["left atrial enlargement", "echocardiogram"],
  ["trace mitral regurgitation"],
  ["low-sodium diet", "adherence challenges"],
  ["meal preparation", "joint pain", "fatigue"],
  ["processed meals", "ready-made meals"],
  ["husband", "shopping", "meal planning"],
  ["dietary counseling", "weight", "remained stable"],
  ["lack of progress", "frustration"],
  ["hydrochlorothiazide", "dose reduction", "borderline hypokalemia"],
  ["lisinopril 10 mg daily", "BP control", "renal protection"],
  ["BP log", "twice daily recording"],
  ["additional labs", "renal function", "electrolytes"],
  ["follow-up lipid panel"],
  ["counseled", "hypertensive urgency"],
  ["educational materials", "chest pain", "confusion", "sudden weakness"],
  ["gradual activity increase", "tolerance"],
  ["physical therapy referral", "joint-friendly exercises"],
  ["dietitian consultation", "nutritional gaps"],
  ["kidney-friendly diet plan", "low-sodium strategies"],
  ["home health services", "blood pressure monitoring"],
  ["medication adherence support", "lifestyle reinforcement", "home visits"],
  ["cardiology follow-up", "4 weeks"],
  ["nephrology", "reassess renal function", "medication tolerance"],
  ["husband", "single-story home"],
  ["daily assistance", "emotional support"],
  ["primary caregiver", "elderly mother-in-law", "dementia"],
  ["caregiver duties", "emotional stress", "fatigue"],
  ["occasional insomnia", "feelings of anxiety"],
  ["no tobacco use", "no alcohol use"],
  ["denies depression"],
  ["social work referral", "caregiver resources"],
  ["multidisciplinary care", "long-term BP management"],
  ["6 weeks", "follow up", "symptoms"],
  ["BP log", "next visit"],
  ["nutritional handouts", "physical therapy instructions"],
  ["avoiding NSAIDs", "kidney function"],
  ["over-the-counter pain relievers", "joint discomfort"],
  ["acetaminophen", "ibuprofen substitution"],
  ["fluid status", "volume overload"],
  ["elevate legs", "periodically"],
  ["compression stockings", "ankle edema"],
  ["psychosocial support", "stress management"],
  ["behavioral therapy", "insomnia"],
  ["cognitive function", "grossly intact"],
  ["no delirium", "no memory impairment"],
  ["speech", "gait", "motor function", "physical exam"],
  ["eye exam", "over a year ago", "ophthalmology referral"],
  ["diabetic foot screening", "intact sensation", "no ulceration"],
  ["pneumococcal vaccination", "influenza vaccination"],
  ["vaccination updates"],
  ["BMI", "31", "obese category"],
  ["weight reduction strategies", "dietary adjustments"],
  ["log meals", "reduce sugary beverages"],
  ["group diabetes education sessions"],
  ["glucose meter", "home use"],
  ["target fasting blood glucose levels"],
  ["fingerstick glucose testing", "technique"],
  ["pharmacist consultation", "medication regimen"],
  ["medication list", "reconciled", "EHR"],
  ["no drug allergies"],
  ["creatinine", "monitored", "every 3 months"],
  ["UACR", "urine albumin-to-creatinine ratio", "proteinuria"],
  ["lab results", "reviewed", "explained"],
  ["treatment plan", "understanding", "verbalized"],
  ["emergency contact", "updated"],
  ["urgent care", "instructions"],
  ["high-potassium foods", "avoidance"],
  ["potassium supplement", "not needed"],
  ["local support groups", "caregivers"],
  ["social worker", "follow-up", "2 weeks"],
  ["progress notes", "care plan summary", "printed"],
  ["next appointment", "scheduled"],
  ["transportation resources"],
  ["morning appointments", "caregiver duties"],
  ["husband", "medication questions"],
  ["coordinated care", "appreciation"],
  ["medication timing", "consistency"],
  ["BP monitor calibration", "verified"],
  ["telehealth follow-up", "future visits"],
  ["video visits", "comfortable"],
  ["portal access", "lab results"],
  ["adherence barriers", "reviewed"],
  ["prescriptions", "no issues"],
  ["medication copays", "insurance"],
  ["pill organizer", "medication consistency"],
  ["lisinopril", "no adverse effects"],
  ["cough", "dizziness", "monitoring"],
  ["lab orders", "electronic", "preferred lab"],
  ["primary care physician", "updated note"],
  ["referrals", "EHR"],
  ["dietary habits", "improvement commitment"],
  ["caregiver support", "daily life"],
  ["family meeting", "caregiving responsibilities"],
  ["advance directives", "discussion", "chart documentation"],
  ["living will", "interest expressed"],
  ["goals of care", "next visit"],
  ["care team", "review", "next steps"],
  ["active participation", "engagement", "thanked"]
]


}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist":0.25,  "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")
# print("\n=== 🧠 Generated Topics and Entities ===")
# for meta in searcher.topic_metadata:
#     topic_id = meta["topic_id"]
#     entities = ", ".join(meta["entities"])
#     print(f"🔹 Topic ID: {topic_id} — Entities: {entities}")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics = [
  {
    "topic_id": "T0",
    "entities": [
      "occipital headaches", "dizziness", "intermittent blurred vision",
      "morning hours", "persistent symptoms", "symptom worsening"
    ]
  },
  {
    "topic_id": "T1",
    "entities": [
      "hypertension", "poorly controlled blood pressure", "blood pressure 168/98 mmHg",
      "elevated readings", "home BP monitoring", "BP log", "hypertension urgency signs"
    ]
  },
  {
    "topic_id": "T2",
    "entities": [
      "amlodipine", "hydrochlorothiazide", "lisinopril",
      "medication adherence", "medication regimen", "pill organizer",
      "borderline hypokalemia", "potassium", "medication reminders"
    ]
  },
  {
    "topic_id": "T3",
    "entities": [
      "type 2 diabetes mellitus", "HbA1c", "glucose meter", "target fasting blood glucose levels",
      "suboptimal glycemic control", "dietary habits", "group diabetes education", "fingerstick glucose testing"
    ]
  },
  {
    "topic_id": "T4",
    "entities": [
      "chronic kidney disease", "creatinine", "eGFR",
      "proteinuria", "electrolytes", "renal function", "UACR",
      "avoid NSAIDs", "monitoring every 3 months"
    ]
  },
  {
    "topic_id": "T5",
    "entities": [
      "hyperlipidemia", "LDL", "HDL", "triglycerides",
      "lipid panel", "atorvastatin", "lab monitoring", "follow-up lipid panel"
    ]
  },
  {
    "topic_id": "T6",
    "entities": [
      "shortness of breath", "exercise limitation", "ankle edema",
      "1+ pitting edema", "volume overload", "fatigue"
    ]
  },
  {
    "topic_id": "T7",
    "entities": [
      "ECG", "left ventricular hypertrophy", "strain pattern",
      "echocardiography", "concentric LV hypertrophy",
      "LVEF", "mitral regurgitation", "left atrial enlargement"
    ]
  },
  {
    "topic_id": "T8",
    "entities": [
      "fundoscopic changes", "arteriolar narrowing", "cotton wool spots",
      "retinopathy", "ophthalmology referral", "no papilledema"
    ]
  },
  {
    "topic_id": "T9",
    "entities": [
      "low-sodium diet", "processed meals", "ready-made meals",
      "meal prep difficulty", "weight stable", "obesity",
      "BMI 31", "weight reduction", "dietitian consultation"
    ]
  },
  {
    "topic_id": "T10",
    "entities": [
      "osteoarthritis", "joint pain", "limited mobility",
      "knees", "hips", "physical therapy", "joint-friendly exercises",
      "fewer than 1,000 steps per day"
    ]
  },
  {
    "topic_id": "T11",
    "entities": [
      "caregiver stress", "caregiving responsibilities", "lives with husband",
      "emotional stress", "insomnia", "feelings of anxiety",
      "caregiver duties", "support groups", "social work referral"
    ]
  },
  {
    "topic_id": "T12",
    "entities": [
      "multidisciplinary care", "cardiology follow-up", "nephrology follow-up",
      "physical therapy referral", "ophthalmology referral", "dietitian consultation",
      "pharmacist consultation", "social work referral", "telehealth visits",
      "home health", "coordinated care"
    ]
  },
  {
    "topic_id": "T13",
    "entities": [
      "education", "hypertension urgency signs", "chest pain", "confusion",
      "symptom awareness", "treatment plan understanding", "patient engagement",
      "emergency instructions", "next steps", "follow-up plan"
    ]
  },
  {
    "topic_id": "T14",
    "entities": [
      "portal access", "electronic lab orders", "updated medication list",
      "EHR documentation", "primary care physician", "BP monitor calibration"
    ]
  }
]


# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
# print("\n=== 📊 Topic Matching Summary ===")
# for model_id, gt_id, score in matches:
#     print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=3, top_k_sents=3)

    threshold = 0.2  # Set your similarity threshold here

    # Find the max topic score among the results (0 if no results)
    max_topic_score = max((res["topic_score"] for res in results), default=0)

    if max_topic_score < threshold or not results:
        print("Sorry, the query is irrelevant.")
    else:
        print(f"\n🔎 Top results for: '{query}'")
        for res in results:
            print(f"🧠 Topic ID: {res['topic_id']} (Score: {res['topic_score']:.4f})")
            print(f"🔗 Related Entities: {', '.join(res['entities'])}")
            for sent, score in res["sentences"]:
                print(f"✓ [{score:.4f}] {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.5728
🌈 Topic Diversity: 0.3620
📐 Silhouette Score: 0.5220

🧮 Average Jaccard Similarity: 0.1391
📈 Ground Truth Coverage: 15/15 (100.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.4224
🧲 Recall:    0.5862
🏅 F1 Score:  0.4910

=== Allergy Topic Search ===

Enter a query (or type 'exit' to quit): 'what the symptoms of patient?'

🔎 Top results for: ''what the symptoms of patient?''
🧠 Topic ID: 42 (Score: 0.4749)
🔗 Related Entities: dizziness, 6 weeks, hypertensive urgency, counseled, follow up, symptoms, monitoring, cough
✓ [0.4574] she will follow up earlier if she experiences cough or dizziness
✓ [0.3649] she was counseled on symptoms of hypertensive urgency
✓ [0.2751] she was advised to follow up within 6 weeks or sooner if symptoms worsen
🧠 Topic ID: 17 (Score: 0.4453)
🔗 Related Entities: symptoms, confusion, evolving heart failure
✓ [0.4847] these symptoms raise concerns for evolving heart fai

In [None]:
#Medium size dataset(10000 characters,70 chunks)

In [6]:
#Data 1
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support
from sklearn.metrics.pairwise import cosine_similarity

# === SEED FIXING ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === CLEANING & CONTEXT EXTRACTION (IMPROVED) ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()


def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = " ".join(sentences[max(0, i - 1): i + 2]) if use_multi_sentence else sent.strip()
                    enriched = f"The concept '{ent_lower}' appears in the following context: {context}"
                    entity_context_pairs.append((ent_lower, enriched.strip()))
                    matched = True
                    break
            if not matched:
                fallback = f"The concept '{ent_lower}' appears in the following context: {chunk}"
                entity_context_pairs.append((ent_lower, fallback.strip()))
    return entity_context_pairs


# === TOPIC SEARCHER CLASS (WITH DEDUPLICATION, NOISE FILTERING) ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(self.chunks, self.entities_per_chunk)
        contextual_texts = [ctx for _, ctx in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=False)

        umap_model = UMAP(**self.umap_params, random_state=SEED)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            if topic == -1:
                continue  # Skip noisy topics
            ent, ctx = entity_context_pairs[i]
            topic_to_contexts[topic].append(ctx)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            emb = topic_to_embeddings[topic_id]
            centroid = np.mean(emb, axis=0)
            centroid /= np.linalg.norm(centroid) + 1e-10
            topic_embeddings.append(centroid)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(emb)
            })

        # === OPTIONAL: Merge semantically similar topics (cosine sim > 0.95)
        deduped_metadata = []
        used = set()

        for i, emb_i in enumerate(topic_embeddings):
            if i in used:
                continue
            group = [i]
            sim_scores = cosine_similarity([emb_i], topic_embeddings)[0]
            for j in range(i + 1, len(sim_scores)):
                if sim_scores[j] > 0.95:
                    group.append(j)
                    used.add(j)

            merged = {
                "topic_id": i,
                "sentences": [],
                "entities": [],
                "sentence_embeddings": []
            }
            for g in group:
                merged["sentences"] += topic_metadata[g]["sentences"]
                merged["entities"] += topic_metadata[g]["entities"]
                merged["sentence_embeddings"] += list(topic_metadata[g]["sentence_embeddings"])

            merged["sentence_embeddings"] = np.array(merged["sentence_embeddings"])
            merged["entities"] = list(set(merged["entities"]))
            deduped_metadata.append(merged)

        self.topic_metadata = deduped_metadata
        self.topic_embeddings = np.array([
            np.mean(m["sentence_embeddings"], axis=0) /
            (np.linalg.norm(np.mean(m["sentence_embeddings"], axis=0)) + 1e-10)
            for m in deduped_metadata
        ])

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=10, num_leaves_to_search=5, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(5)
            .build()
        )

    import re

    def search(self, query, top_k_topics=3, top_k_sents=3):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        prefix_pattern = r"^the concept '.*?' appears in (the following )?context:\s*"

        for i, idx in enumerate(neighbors):
            meta = self.topic_metadata[idx]
            topic_score = float(scores[i])

            # Deduplicate sentences
            seen = set()
            cleaned_sentences = []
            cleaned_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
            # Apply regex to remove beginning prefix
                cleaned = re.sub(prefix_pattern, "", sent, flags=re.IGNORECASE).strip()

        # No duplicates
                if cleaned not in seen:
                    seen.add(cleaned)
                    cleaned_sentences.append(cleaned)
                    cleaned_embeddings.append(emb)

            if not cleaned_sentences:
                continue

            emb_array = np.array(cleaned_embeddings)
            sims = np.dot(emb_array / np.linalg.norm(emb_array, axis=1, keepdims=True), query_emb)
            top_ids = sims.argsort()[::-1][:top_k_sents]

            top_sents = [(cleaned_sentences[j], float(sims[j])) for j in top_ids]
            results.append({
            "topic_id": meta["topic_id"],
            "topic_score": topic_score,
            "entities": meta["entities"],
            "sentences": top_sents,
            })

        return results




# === METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for s in meta["sentences"]:
            tokens = clean_text(s).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(t) for t in texts]

    cm = CoherenceModel(
        topics=word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return cm.get_coherence()


def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topic_words = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    unique_words = set(word for topic in topic_words for word, _ in topic)
    return len(unique_words) / (len(topic_words) * topk)


def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        emb = meta["sentence_embeddings"]
        if len(emb) < 2:  # skip small clusters
            continue
        all_embeddings.extend(emb)
        all_labels.extend([meta["topic_id"]] * len(emb))

    if len(all_embeddings) < 3:
        return None

    all_embeddings = np.vstack(all_embeddings)
    return silhouette_score(all_embeddings, all_labels, metric="cosine")


# === DATASET & INITIALIZATION ===
allergy_dataset = {
  "chunks": [
  "A 45-year-old male construction worker presented to the infectious disease clinic with intermittent fever over the past month.",
  "He also reported night sweats, generalized fatigue, and a 10 kg weight loss during this period.",
  "He had recently returned from a prolonged work assignment in Southeast Asia.",
  "The patient experienced a persistent dry cough and occasional shortness of breath.",
  "He denied hemoptysis but admitted to occasional chest discomfort without radiation.",
  "Past medical history was significant for well-controlled asthma from childhood.",
  "He was a current smoker with a 15-pack-year history and consumed alcohol socially.",
  "On initial assessment, his temperature was 38.3°C, heart rate 100 bpm, and respiratory rate 20 per minute.",
  "Physical examination revealed cervical and axillary lymphadenopathy.",
  "Lung auscultation identified scattered crackles bilaterally without wheezes.",
  "Skin inspection showed multiple painless, erythematous nodules on the anterior forearms.",
  "No oral thrush or mucosal lesions were observed.",
  "There was moderate hepatomegaly but no splenomegaly.",
  "Neurologic examination was unremarkable without focal deficits.",
  "No signs of meningismus or altered mental status were noted.",
  "Laboratory studies revealed normocytic anemia with hemoglobin of 10.5 g/dL.",
  "White blood cell count was within normal limits but with mild lymphopenia.",
  "Platelets were mildly reduced at 130,000/μL.",
  "Erythrocyte sedimentation rate was elevated at 65 mm/hr.",
  "C-reactive protein was elevated at 45 mg/L.",
  "Liver function tests showed mild elevation of alkaline phosphatase and transaminases.",
  "Renal function tests were within normal limits.",
  "HIV serology was negative.",
  "Tuberculin skin test was positive with 18 mm induration.",
  "Chest X-ray revealed bilateral patchy infiltrates and hilar adenopathy.",
  "High-resolution CT scan showed multiple micronodular opacities and mediastinal lymphadenopathy.",
  "Bronchoscopy was performed with bronchial washings sent for microbiological analysis.",
  "Sputum smears were negative for acid-fast bacilli on initial evaluation.",
  "PCR assays from bronchial washings indicated presence of Mycobacterium tuberculosis complex DNA.",
  "Blood cultures were negative for bacterial growth.",
  "Skin biopsy of the nodules showed granulomatous inflammation with caseating necrosis.",
  "Staining for acid-fast bacilli revealed rare organisms within the granulomas.",
  "Cultures from the biopsy later confirmed Mycobacterium tuberculosis growth.",
  "An interferon-gamma release assay was positive.",
  "The patient was diagnosed with disseminated tuberculosis involving lungs, lymph nodes, liver, and skin.",
  "He was initiated on a four-drug anti-tuberculous therapy regimen including isoniazid, rifampin, pyrazinamide, and ethambutol.",
  "Vitamin B6 supplementation was started to prevent neuropathy.",
  "Baseline visual acuity and color vision testing were documented before initiating therapy.",
  "The patient was counseled on medication adherence and potential side effects.",
  "Regular monitoring plans for liver function tests and drug toxicity were scheduled.",
  "Symptomatically, supplemental oxygen was provided intermittently during exacerbations of dyspnea.",
  "Physiotherapy was initiated to improve pulmonary function and overall stamina.",
  "The patient was advised to cease smoking and offered support for cessation.",
  "Hepatomegaly resolved after two months of therapy.",
  "Skin lesions gradually regressed over the treatment course.",
  "Repeat sputum samples became negative for Mycobacterium tuberculosis after six weeks.",
  "He experienced transient hyperuricemia secondary to pyrazinamide which was managed conservatively.",
  "No episodes of drug-induced hepatitis occurred during the therapy.",
  "His anemia improved with nutritional supplementation and resolution of chronic inflammation.",
  "A multidisciplinary team including infectious disease specialists, pulmonologists, dermatologists, and dieticians managed his care.",
  "At three months, the patient showed substantial weight gain and resolution of systemic symptoms.",
  "Chest imaging demonstrated significant reduction in infiltrates and lymphadenopathy.",
  "He was transitioned to a continuation phase of therapy with isoniazid and rifampin planned for six additional months.",
  "Liver function tests remained stable throughout treatment.",
  "Psychological support was offered due to anxiety related to the diagnosis and isolation precautions.",
  "Family screening and contact tracing were initiated to identify and manage exposed individuals.",
  "The patient was educated about infection control measures to prevent disease spread.",
  "A structured follow-up schedule was established including monthly clinical and laboratory assessments.",
  "Compliance with therapy was monitored via pill counts and pharmacy records.",
  "The patient maintained good adherence and reported improved quality of life.",
  "Potential drug interactions with his asthma medications were reviewed and managed.",
  "No adjustments were needed as inhaled corticosteroids and bronchodilators were compatible.",
  "He underwent vaccination review and received updated influenza and pneumococcal vaccines after starting treatment.",
  "After nine months, the patient was declared clinically cured with no relapse signs.",
  "Final imaging showed complete resolution of pulmonary and lymph node involvement.",
  "Skin lesions healed with residual scarring but no active disease.",
  "His hematologic and inflammatory markers normalized.",
  "Regular long-term follow-up was recommended focusing on pulmonary function and prevention of reactivation.",
  "He remained asymptomatic with normal exercise tolerance at one year.",
  "The case highlights the complexity of disseminated tuberculosis with multi-organ involvement and the importance of early diagnosis and multidisciplinary management."
]

,
  "entities":[
  ["male", "worker", "fever"],
  ["night", "sweats", "fatigue", "weight"],
  ["assignment", "asia"],
  ["cough", "shortness", "breath"],
  ["hemoptysis", "chest", "discomfort"],
  ["asthma", "childhood"],
  ["smoking", "history", "alcohol"],
  ["temperature", "heart", "rate", "respiratory", "rate"],
  ["lymphadenopathy", "cervical", "axillary"],
  ["lung", "crackles"],
  ["skin", "nodules", "forearm", "erythematous"],
  ["oral", "thrush", "mucosal", "lesions"],
  ["hepatomegaly", "splenomegaly"],
  ["neurologic", "deficits"],
  ["meningismus", "status"],
  ["anemia", "hemoglobin", "g/dl"],
  ["white", "cells", "lymphopenia"],
  ["platelets"],
  ["esr"],
  ["crp"],
  ["liver", "alkaline", "phosphatase", "transaminases"],
  ["renal", "function"],
  ["hiv", "serology"],
  ["tuberculin", "test", "induration"],
  ["chest", "x-ray", "infiltrates", "adenopathy"],
  ["ct", "micronodular", "opacities", "lymphadenopathy"],
  ["bronchoscopy", "washings", "microbiological"],
  ["sputum", "acid-fast", "bacilli"],
  ["pcr", "mycobacterium", "tuberculosis", "dna"],
  ["blood", "cultures"],
  ["biopsy", "granulomatous", "inflammation", "necrosis"],
  ["staining", "acid-fast", "bacilli"],
  ["culture", "mycobacterium", "tuberculosis"],
  ["interferon-gamma", "release", "assay"],
  ["therapy", "tuberculosis", "lungs", "lymph", "nodes", "liver", "skin"],
  ["isoniazid", "rifampin", "pyrazinamide", "ethambutol"],
  ["vitamin", "supplementation", "neuropathy"],
  ["visual", "acuity", "color", "vision", "testing"],
  ["counseling", "medication", "adherence", "side", "effects"],
  ["monitoring", "liver", "toxicity"],
  ["oxygen", "supplemental"],
  ["physiotherapy", "pulmonary", "function", "stamina"],
  ["smoking", "cessation", "support"],
  ["hepatomegaly", "resolution"],
  ["skin", "lesions", "regression"],
  ["sputum", "negative", "mycobacterium", "tuberculosis"],
  ["hyperuricemia", "pyrazinamide"],
  ["hepatitis"],
  ["anemia", "nutrition", "inflammation"],
  ["team", "infectious", "disease", "pulmonology", "dermatology", "dietician"],
  ["weight", "gain", "symptoms", "resolution"],
  ["imaging", "reduction", "infiltrates", "lymphadenopathy"],
  ["therapy", "continuation", "isoniazid", "rifampin"],
  ["liver", "function", "stability"],
  ["psychological", "support", "anxiety", "isolation"],
  ["family", "screening", "contact", "tracing"],
  ["education", "infection", "control"],
  ["follow-up", "clinical", "laboratory"],
  ["compliance", "pill", "counts", "pharmacy"],
  ["adherence", "quality", "life"],
  ["drug", "interaction", "asthma", "medication"],
  ["inhaled", "corticosteroid", "bronchodilators"],
  ["vaccination", "influenza", "pneumococcal"],
  ["clinical", "cure", "relapse"],
  ["imaging", "resolution", "pulmonary", "lymph", "nodes"],
  ["skin", "scarring"],
  ["hematologic", "inflammatory", "normalization"],
  ["follow-up", "pulmonary", "function", "reactivation"],
  ["exercise", "tolerance"],
  ["disseminated", "tuberculosis", "multi-organ", "management"]
]




}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist":0.25,  "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")
# print("\n=== 🧠 Generated Topics and Entities ===")
# for meta in searcher.topic_metadata:
#     topic_id = meta["topic_id"]
#     entities = ", ".join(meta["entities"])
#     print(f"🔹 Topic ID: {topic_id} — Entities: {entities}")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics =[
  {
    "topic_id": "T0",
    "entities": [
      "fever", "night", "sweats", "fatigue", "weight", "cough", "dyspnea",
      "chest", "discomfort", "smoking", "alcohol", "asthma"
    ]
  },
  {
    "topic_id": "T1",
    "entities": [
      "lymphadenopathy", "hepatomegaly", "splenomegaly", "skin", "nodules", "erythematous"
    ]
  },
  {
    "topic_id": "T2",
    "entities": [
      "anemia", "hemoglobin", "platelets", "white", "cells", "lymphopenia",
      "esr", "crp", "alkaline", "phosphatase", "transaminases", "renal", "function"
    ]
  },
  {
    "topic_id": "T3",
    "entities": [
      "hiv", "tuberculin", "test", "induration", "interferon-gamma", "release", "assay"
    ]
  },
  {
    "topic_id": "T4",
    "entities": [
      "chest", "x-ray", "infiltrates", "adenopathy", "ct", "micronodular", "opacities", "lymphadenopathy"
    ]
  },
  {
    "topic_id": "T5",
    "entities": [
      "bronchoscopy", "sputum", "acid-fast", "bacilli", "pcr", "mycobacterium", "tuberculosis", "dna",
      "blood", "cultures"
    ]
  },
  {
    "topic_id": "T6",
    "entities": [
      "biopsy", "granulomatous", "inflammation", "necrosis", "staining", "acid-fast", "bacilli", "culture"
    ]
  },
  {
    "topic_id": "T7",
    "entities": [
      "therapy", "isoniazid", "rifampin", "pyrazinamide", "ethambutol", "vitamin", "supplementation", "neuropathy"
    ]
  },
  {
    "topic_id": "T8",
    "entities": [
      "visual", "acuity", "color", "vision", "testing", "medication", "adherence", "side", "effects", "liver", "toxicity", "monitoring"
    ]
  },
  {
    "topic_id": "T9",
    "entities": [
      "oxygen", "supplemental", "physiotherapy", "pulmonary", "function", "stamina", "smoking", "cessation", "support"
    ]
  },
  {
    "topic_id": "T10",
    "entities": [
      "sputum", "negative", "mycobacterium", "tuberculosis", "hyperuricemia", "pyrazinamide", "hepatitis"
    ]
  },
  {
    "topic_id": "T11",
    "entities": [
      "anemia", "nutrition", "inflammation"
    ]
  },
  {
    "topic_id": "T12",
    "entities": [
      "team", "infectious", "disease", "pulmonology", "dermatology", "dietician"
    ]
  },
  {
    "topic_id": "T13",
    "entities": [
      "weight", "gain", "symptoms", "resolution", "imaging", "reduction", "infiltrates", "lymphadenopathy"
    ]
  },
  {
    "topic_id": "T14",
    "entities": [
      "therapy", "continuation", "liver", "function", "stability"
    ]
  },
  {
    "topic_id": "T15",
    "entities": [
      "psychological", "support", "anxiety", "isolation"
    ]
  },
  {
    "topic_id": "T16",
    "entities": [
      "family", "screening", "contact", "tracing", "education", "infection", "control"
    ]
  },
  {
    "topic_id": "T17",
    "entities": [
      "follow-up", "clinical", "laboratory", "compliance", "pill", "counts", "pharmacy",
      "adherence", "quality", "life"
    ]
  },
  {
    "topic_id": "T18",
    "entities": [
      "drug", "interaction", "asthma", "medication", "inhaled", "corticosteroid", "bronchodilators"
    ]
  },
  {
    "topic_id": "T19",
    "entities": [
      "vaccination", "influenza", "pneumococcal"
    ]
  },
  {
    "topic_id": "T20",
    "entities": [
      "clinical", "cure", "relapse", "imaging", "resolution", "pulmonary", "lymph", "nodes"
    ]
  },
  {
    "topic_id": "T21",
    "entities": [
      "skin", "scarring"
    ]
  },
  {
    "topic_id": "T22",
    "entities": [
      "hematologic", "inflammatory", "normalization"
    ]
  },
  {
    "topic_id": "T23",
    "entities": [
      "follow-up", "pulmonary", "function", "reactivation"
    ]
  },
  {
    "topic_id": "T24",
    "entities": [
      "exercise", "tolerance"
    ]
  },
  {
    "topic_id": "T25",
    "entities": [
      "disseminated", "tuberculosis", "multi-organ", "management"
    ]
  }
]





# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
#print("\n=== 📊 Topic Matching Summary ===")
#for model_id, gt_id, score in matches:
    #print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=3, top_k_sents=3)

    threshold = 0.2  # Set your similarity threshold here

    # Find the max topic score among the results (0 if no results)
    max_topic_score = max((res["topic_score"] for res in results), default=0)

    if max_topic_score < threshold or not results:
        print("Sorry, the query is irrelevant.")
    else:
        print(f"\n🔎 Top results for: '{query}'")
        for res in results:
            print(f"🧠 Topic ID: {res['topic_id']} (Score: {res['topic_score']:.4f})")
            print(f"🔗 Related Entities: {', '.join(res['entities'])}")
            for sent, score in res["sentences"]:
                print(f"✓ [{score:.4f}] {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.6685
🌈 Topic Diversity: 0.4030
📐 Silhouette Score: 0.6300

🧮 Average Jaccard Similarity: 0.3985
📈 Ground Truth Coverage: 26/26 (100.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.8765
🧲 Recall:    0.9793
🏅 F1 Score:  0.9251

=== Allergy Topic Search ===

Enter a query (or type 'exit' to quit): What abnormalities are present in the complete blood count and inflammatory markers?

🔎 Top results for: 'What abnormalities are present in the complete blood count and inflammatory markers?'
🧠 Topic ID: 63 (Score: 0.4633)
🔗 Related Entities: normalization, hematologic, inflammatory
✓ [0.4231] his hematologic and inflammatory markers normalized
🧠 Topic ID: 27 (Score: 0.4276)
🔗 Related Entities: biopsy, necrosis, granulomatous, inflammation
✓ [0.3975] skin biopsy of the nodules showed granulomatous inflammation with caseating necrosis
🧠 Topic ID: 17 (Score: 0.4235)
🔗 Related Entities: opacities, micronodul

In [7]:
#Data 2
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support
from sklearn.metrics.pairwise import cosine_similarity

# === SEED FIXING ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === CLEANING & CONTEXT EXTRACTION (IMPROVED) ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()


def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = " ".join(sentences[max(0, i - 1): i + 2]) if use_multi_sentence else sent.strip()
                    enriched = f"The concept '{ent_lower}' appears in the following context: {context}"
                    entity_context_pairs.append((ent_lower, enriched.strip()))
                    matched = True
                    break
            if not matched:
                fallback = f"The concept '{ent_lower}' appears in the following context: {chunk}"
                entity_context_pairs.append((ent_lower, fallback.strip()))
    return entity_context_pairs


# === TOPIC SEARCHER CLASS (WITH DEDUPLICATION, NOISE FILTERING) ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(self.chunks, self.entities_per_chunk)
        contextual_texts = [ctx for _, ctx in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=False)

        umap_model = UMAP(**self.umap_params, random_state=SEED)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            if topic == -1:
                continue  # Skip noisy topics
            ent, ctx = entity_context_pairs[i]
            topic_to_contexts[topic].append(ctx)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            emb = topic_to_embeddings[topic_id]
            centroid = np.mean(emb, axis=0)
            centroid /= np.linalg.norm(centroid) + 1e-10
            topic_embeddings.append(centroid)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(emb)
            })

        # === OPTIONAL: Merge semantically similar topics (cosine sim > 0.95)
        deduped_metadata = []
        used = set()

        for i, emb_i in enumerate(topic_embeddings):
            if i in used:
                continue
            group = [i]
            sim_scores = cosine_similarity([emb_i], topic_embeddings)[0]
            for j in range(i + 1, len(sim_scores)):
                if sim_scores[j] > 0.95:
                    group.append(j)
                    used.add(j)

            merged = {
                "topic_id": i,
                "sentences": [],
                "entities": [],
                "sentence_embeddings": []
            }
            for g in group:
                merged["sentences"] += topic_metadata[g]["sentences"]
                merged["entities"] += topic_metadata[g]["entities"]
                merged["sentence_embeddings"] += list(topic_metadata[g]["sentence_embeddings"])

            merged["sentence_embeddings"] = np.array(merged["sentence_embeddings"])
            merged["entities"] = list(set(merged["entities"]))
            deduped_metadata.append(merged)

        self.topic_metadata = deduped_metadata
        self.topic_embeddings = np.array([
            np.mean(m["sentence_embeddings"], axis=0) /
            (np.linalg.norm(np.mean(m["sentence_embeddings"], axis=0)) + 1e-10)
            for m in deduped_metadata
        ])

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=10, num_leaves_to_search=5, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(5)
            .build()
        )

    import re

    def search(self, query, top_k_topics=3, top_k_sents=3):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        prefix_pattern = r"^the concept '.*?' appears in (the following )?context:\s*"

        for i, idx in enumerate(neighbors):
            meta = self.topic_metadata[idx]
            topic_score = float(scores[i])

            # Deduplicate sentences
            seen = set()
            cleaned_sentences = []
            cleaned_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
            # Apply regex to remove beginning prefix
                cleaned = re.sub(prefix_pattern, "", sent, flags=re.IGNORECASE).strip()

        # No duplicates
                if cleaned not in seen:
                    seen.add(cleaned)
                    cleaned_sentences.append(cleaned)
                    cleaned_embeddings.append(emb)

            if not cleaned_sentences:
                continue

            emb_array = np.array(cleaned_embeddings)
            sims = np.dot(emb_array / np.linalg.norm(emb_array, axis=1, keepdims=True), query_emb)
            top_ids = sims.argsort()[::-1][:top_k_sents]

            top_sents = [(cleaned_sentences[j], float(sims[j])) for j in top_ids]
            results.append({
            "topic_id": meta["topic_id"],
            "topic_score": topic_score,
            "entities": meta["entities"],
            "sentences": top_sents,
            })

        return results




# === METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for s in meta["sentences"]:
            tokens = clean_text(s).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(t) for t in texts]

    cm = CoherenceModel(
        topics=word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return cm.get_coherence()


def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topic_words = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    unique_words = set(word for topic in topic_words for word, _ in topic)
    return len(unique_words) / (len(topic_words) * topk)


def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        emb = meta["sentence_embeddings"]
        if len(emb) < 2:  # skip small clusters
            continue
        all_embeddings.extend(emb)
        all_labels.extend([meta["topic_id"]] * len(emb))

    if len(all_embeddings) < 3:
        return None

    all_embeddings = np.vstack(all_embeddings)
    return silhouette_score(all_embeddings, all_labels, metric="cosine")


# === DATASET & INITIALIZATION ===
allergy_dataset = {
  "chunks": [
  "A 58-year-old female with a history of invasive ductal carcinoma of the breast presented with unintentional weight loss and malaise over two months.",
  "She reported new onset of bone pain localized to the lower back and hips, progressively worsening in intensity.",
  "There was a history of intermittent fever and night sweats but no reported cough or hemoptysis.",
  "Her past medical history includes Stage II breast cancer treated with lumpectomy and adjuvant chemotherapy three years ago.",
  "She completed radiation therapy 18 months prior and was on hormonal therapy with tamoxifen.",
  "Family history revealed a mother diagnosed with ovarian cancer at age 62.",
  "Physical examination showed pallor, mild lymphadenopathy in the supraclavicular region, and tenderness over the lumbar spine.",
  "Vital signs were stable with a low-grade fever of 37.8°C documented.",
  "Neurological examination was non-focal with intact motor and sensory function.",
  "Skin inspection revealed scattered petechiae over the lower extremities.",
  "Initial laboratory tests showed anemia with a hemoglobin of 9.6 g/dL and elevated alkaline phosphatase.",
  "Serum calcium was elevated at 11.2 mg/dL with normal renal function.",
  "Liver enzymes revealed mildly elevated transaminases.",
  "Tumor markers demonstrated elevated CA 15-3 and carcinoembryonic antigen (CEA).",
  "Bone marrow biopsy was scheduled given pancytopenia and bone pain.",
  "Plain radiographs of the lumbar spine displayed lytic lesions involving vertebral bodies L3 to L5.",
  "A whole-body bone scan indicated multiple areas of increased uptake consistent with skeletal metastases.",
  "Contrast-enhanced CT scan of the chest, abdomen, and pelvis showed multiple hypodense liver lesions.",
  "No pulmonary nodules or mediastinal lymphadenopathy were identified.",
  "Ultrasound-guided liver biopsy confirmed metastatic adenocarcinoma consistent with a breast primary.",
  "Immunohistochemistry revealed estrogen and progesterone receptor positivity with HER2 negativity.",
  "The oncology team initiated treatment with a combination of bisphosphonates and aromatase inhibitors.",
  "Supportive care included analgesics for bone pain and intravenous hydration for hypercalcemia.",
  "She was referred to pain management and physiotherapy services early in the treatment course.",
  "The patient was counseled on prognosis, treatment goals, and potential side effects, emphasizing quality of life.",
  "Follow-up imaging after three cycles of systemic therapy showed partial response with reduction in liver lesions.",
  "Her hemoglobin levels improved gradually, and bone pain was better controlled after bisphosphonate therapy.",
  "Repeat tumor markers demonstrated a downward trend correlating with clinical improvement.",
  "She experienced mild nausea and fatigue as common adverse effects of aromatase inhibitors.",
  "Endocrinology was consulted to monitor bone mineral density and calcium homeostasis.",
  "Psychological support was provided as the patient reported anxiety and mood changes linked to diagnosis.",
  "Genetic counseling was offered due to family history and implications for targeted therapies and relatives.",
  "The patient was screened for vitamin D deficiency and started supplementation accordingly.",
  "Regular liver function tests and hematologic parameters were closely monitored throughout treatment.",
  "No evidence of brain metastases was found on screening MRI performed due to headaches.",
  "The multidisciplinary tumor board reviewed her case regularly to optimize management plans.",
  "Physical rehabilitation focused on maintaining mobility and preventing skeletal complications.",
  "She was advised on diet modifications to address nutritional deficits and maintain strength.",
  "Symptoms of potential pathological fractures were carefully monitored, with orthopedics in consultation.",
  "The patient's preferences strongly influenced decision-making, with goals centered on symptom control.",
  "Anemia was managed conservatively with erythropoiesis-stimulating agents as hemoglobin fell below threshold.",
  "The endocrinology team adjusted hormonal therapy due to emerging side effects including arthralgia.",
  "Bone biopsies excluded secondary malignancies and confirmed disease progression status.",
  "A plan for cyclic imaging surveillance was established to assess disease stability.",
  "The patient was enrolled in a clinical trial investigating novel agents for metastatic breast cancer.",
  "Treatment adherence was supported via nursing follow-ups and patient education materials.",
  "The role of palliative care was discussed early with emphasis on holistic support.",
  "Complications such as thrombocytopenia developed transiently, managed with dose adjustments.",
  "Cardiotoxicity surveillance was instituted due to prior chemotherapy exposure, with echocardiograms scheduled.",
  "The patient reported improved mood and coping strategies following psychiatric intervention.",
  "End-of-life care planning was briefly addressed with patient and family input incorporated.",
  "The patient participated actively in support groups facilitated by the oncology center.",
  "Pain control was optimized using a multimodal analgesic regimen including opioids and NSAIDs.",
  "Bone-modifying agents helped reduce skeletal related events and improve quality of life.",
  "Serum biomarkers were periodically assessed as part of response monitoring.",
  "The impact of hormonal therapy on bone health was managed with calcium and vitamin D supplements.",
  "Anemia-related fatigue was addressed with physical activity modulation and energy conservation techniques.",
  "The patient was closely observed for signs of disease progression or treatment resistance.",
  "Interdisciplinary communication between oncology, nursing, rehabilitation, and social workers was continuous.",
  "Access to financial and social support services was coordinated to assist with treatment costs.",
  "The patient’s diary was used to track symptoms and side effects to inform ongoing care.",
  "Reassessment of goals of care was planned regularly to align with changing clinical status.",
  "International guidelines for metastatic breast cancer management were applied in treatment decisions.",
  "The case exemplified challenges in balancing disease control with minimizing treatment toxicity.",
  "The patient was encouraged to engage in meaningful activities and social interactions.",
  "Long-term follow-up focused on symptom management, surveillance, and survivorship issues.",
  "The clinical summary contributed to educational efforts for trainees in complex cancer care.",
  "Transitions between inpatient, outpatient, and supportive community care were smoothly implemented.",
  "Family meetings facilitated shared decision-making and education about prognosis and care options.",
  "The clinical documentation included detailed medication reconciliation and allergy assessments.",
  "Data captured support future research and quality improvement initiatives in metastatic breast cancer.",
  "Final diagnoses included metastatic estrogen receptor positive breast cancer with skeletal and hepatic involvement.",
  "The patient’s overall prognosis was communicated with sensitivity and clarity aligning care expectations."
]


,
  "entities":[
  ["female", "carcinoma", "breast", "weight", "malaise"],
  ["bone", "pain", "back", "hips"],
  ["fever", "sweats", "cough", "hemoptysis"],
  ["cancer", "lumpectomy", "chemotherapy"],
  ["radiation", "therapy", "tamoxifen"],
  ["family", "ovarian", "cancer"],
  ["pallor", "lymphadenopathy", "supraclavicular", "tenderness", "lumbar", "spine"],
  ["vital", "fever"],
  ["neurological", "motor", "sensory"],
  ["skin", "petechiae", "extremities"],
  ["anemia", "hemoglobin", "alkaline", "phosphatase"],
  ["calcium", "renal", "function"],
  ["liver", "enzymes", "transaminases"],
  ["tumor", "markers", "ca15-3", "cea"],
  ["bone", "marrow", "biopsy", "pancytopenia"],
  ["radiographs", "vertebral", "lytic", "lesions"],
  ["bone", "scan", "metastases"],
  ["ct", "chest", "abdomen", "pelvis", "lesions"],
  ["pulmonary", "nodules", "lymphadenopathy"],
  ["biopsy", "adenocarcinoma", "breast", "primary"],
  ["immunohistochemistry", "estrogen", "progesterone", "her2"],
  ["bisphosphonates", "aromatase", "inhibitors"],
  ["analgesics", "hypercalcemia", "hydration"],
  ["pain", "management", "physiotherapy"],
  ["counseling", "prognosis", "quality", "life"],
  ["imaging", "therapy", "response"],
  ["hemoglobin", "bisphosphonate", "therapy"],
  ["tumor", "markers", "trend"],
  ["nausea", "fatigue", "aromatase"],
  ["endocrinology", "bone", "density", "calcium"],
  ["psychological", "anxiety", "mood"],
  ["genetic", "counseling", "therapy"],
  ["vitamin", "d", "supplementation"],
  ["liver", "function", "hematologic"],
  ["brain", "metastases", "mri", "headaches"],
  ["tumor", "board", "management"],
  ["rehabilitation", "mobility", "complications"],
  ["diet", "nutritional", "deficits"],
  ["fractures", "orthopedics"],
  ["patient", "goals", "symptom", "control"],
  ["anemia", "erythropoiesis"],
  ["side-effects", "arthralgia", "hormonal"],
  ["bone", "biopsies", "progression"],
  ["imaging", "surveillance"],
  ["clinical", "trial", "breast", "cancer"],
  ["adherence", "nursing", "education"],
  ["palliative", "care", "support"],
  ["thrombocytopenia", "dose", "adjustments"],
  ["cardiotoxicity", "echocardiogram", "chemotherapy"],
  ["psychiatric", "intervention", "mood"],
  ["care", "planning", "family"],
  ["support", "groups", "oncology"],
  ["analgesic", "opioids", "nsaids"],
  ["bone-modifying", "skeletal", "quality"],
  ["biomarkers", "monitoring"],
  ["hormonal", "therapy", "calcium", "vitamin"],
  ["fatigue", "activity", "conservation"],
  ["disease", "progression", "resistance"],
  ["communication", "oncology", "rehabilitation", "social"],
  ["financial", "support", "treatment"],
  ["symptoms", "side-effects", "diary"],
  ["goals", "care", "status"],
  ["guidelines", "management"],
  ["balance", "control", "toxicity"],
  ["activities", "social", "interaction"],
  ["survivorship", "monitoring"],
  ["education", "trainees", "cancer"],
  ["care", "transitions", "community"],
  ["decision-making", "family", "prognosis"],
  ["documentation", "medication", "allergy"],
  ["research", "quality", "improvement"],
  ["diagnosis", "breast", "cancer", "metastatic", "skeletal", "hepatic"],
  ["prognosis", "communication", "care"]
]




}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist":0.25,  "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")
# print("\n=== 🧠 Generated Topics and Entities ===")
# for meta in searcher.topic_metadata:
#     topic_id = meta["topic_id"]
#     entities = ", ".join(meta["entities"])
#     print(f"🔹 Topic ID: {topic_id} — Entities: {entities}")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics =[
  {
    "topic_id": "T0",
    "entities": [
      "breast", "carcinoma", "adenocarcinoma", "metastatic", "primary",
      "estrogen", "progesterone", "her2", "tumor", "markers", "ca15-3", "cea",
      "immunohistochemistry", "biopsy", "bone", "marrow", "lesions", "skeletal",
      "hepatic", "liver", "metastases", "imaging", "ct", "radiographs", "bone", "scan",
      "mri", "brain", "metastases"
    ]
  },
  {
    "topic_id": "T1",
    "entities": [
      "weight", "loss", "malaise", "fatigue",
      "fever", "sweats", "night", "anemia", "hemoglobin", "pancytopenia",
      "thrombocytopenia", "platelets", "calcium", "hypercalcemia",
      "alkaline", "phosphatase", "liver", "enzymes", "transaminases",
      "renal", "function", "hematologic", "inflammatory", "markers", "crp", "esr"
    ]
  },
  {
    "topic_id": "T2",
    "entities": [
      "pain", "back", "hips", "bone", "fractures",
      "analgesics", "opioids", "nsaids", "bisphosphonates",
      "bone-modifying", "skeletal", "therapy", "physiotherapy", "pain", "management"
    ]
  },
  {
    "topic_id": "T3",
    "entities": [
      "chemotherapy", "radiation", "therapy", "tamoxifen",
      "aromatase", "inhibitors", "hormonal", "therapy", "side-effects", "toxicity",
      "cardiotoxicity", "echocardiogram", "dose", "adjustments"
    ]
  },
  {
    "topic_id": "T4",
    "entities": [
      "family", "genetic", "counseling", "history", "ovarian", "cancer"
    ]
  },
  {
    "topic_id": "T5",
    "entities": [
      "vital", "signs", "fever", "pallor", "lymphadenopathy", "supraclavicular",
      "neurological", "examination", "motor", "sensory", "non-focal", "skin", "petechiae"
    ]
  },
  {
    "topic_id": "T6",
    "entities": [
      "patient", "goals", "care", "prognosis", "quality", "life", "psychological",
      "support", "anxiety", "mood", "psychiatric", "intervention"
    ]
  },
  {
    "topic_id": "T7",
    "entities": [
      "support", "groups", "oncology", "rehabilitation", "mobility", "social", "interaction"
    ]
  },
  {
    "topic_id": "T8",
    "entities": [
      "nutrition", "diet", "supplementation", "vitamin", "d", "calcium",
      "hydration", "nutrition", "therapy"
    ]
  },
  {
    "topic_id": "T9",
    "entities": [
      "clinical", "trial", "enrollment", "research", "quality", "improvement"
    ]
  },
  {
    "topic_id": "T10",
    "entities": [
      "management", "team", "oncology", "nursing", "endocrinology", "orthopedics",
      "pain", "management", "physiotherapy", "multidisciplinary", "tumor", "board"
    ]
  },
  {
    "topic_id": "T11",
    "entities": [
      "medication", "adherence", "education", "counseling", "follow-up",
      "monitoring", "labs", "surveillance", "compliance"
    ]
  },
  {
    "topic_id": "T12",
    "entities": [
      "support", "financial", "social", "services", "family", "decision-making", "communication"
    ]
  }
]


# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
#print("\n=== 📊 Topic Matching Summary ===")
#for model_id, gt_id, score in matches:
    #print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=3, top_k_sents=3)

    threshold = 0.2  # Set your similarity threshold here

    # Find the max topic score among the results (0 if no results)
    max_topic_score = max((res["topic_score"] for res in results), default=0)

    if max_topic_score < threshold or not results:
        print("Sorry, the query is irrelevant.")
    else:
        print(f"\n🔎 Top results for: '{query}'")
        for res in results:
            print(f"🧠 Topic ID: {res['topic_id']} (Score: {res['topic_score']:.4f})")
            print(f"🔗 Related Entities: {', '.join(res['entities'])}")
            for sent, score in res["sentences"]:
                print(f"✓ [{score:.4f}] {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.7093
🌈 Topic Diversity: 0.4000
📐 Silhouette Score: 0.7744

🧮 Average Jaccard Similarity: 0.1628
📈 Ground Truth Coverage: 13/13 (100.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.7059
🧲 Recall:    0.8571
🏅 F1 Score:  0.7742

=== Allergy Topic Search ===

Enter a query (or type 'exit' to quit): How do you diagnose and treat community-acquired pneumonia?

🔎 Top results for: 'How do you diagnose and treat community-acquired pneumonia?'
🧠 Topic ID: 43 (Score: 0.3604)
🔗 Related Entities: goals, control, symptom, patient
✓ [0.3555] the patients preferences strongly influenced decisionmaking with goals centered on symptom control
🧠 Topic ID: 50 (Score: 0.3368)
🔗 Related Entities: care, support, palliative
✓ [0.3494] the role of palliative care was discussed early with emphasis on holistic support
🧠 Topic ID: 67 (Score: 0.3360)
🔗 Related Entities: control, toxicity, balance
✓ [0.2677] the case exemplif

In [None]:
#Small size dataset:(5000 characters(15-20chunks))

In [None]:
#data 1:

In [1]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support
from sklearn.metrics.pairwise import cosine_similarity

# === SEED FIXING ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === CLEANING & CONTEXT EXTRACTION (IMPROVED) ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()


def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = " ".join(sentences[max(0, i - 1): i + 2]) if use_multi_sentence else sent.strip()
                    enriched = f"The concept '{ent_lower}' appears in the following context: {context}"
                    entity_context_pairs.append((ent_lower, enriched.strip()))
                    matched = True
                    break
            if not matched:
                fallback = f"The concept '{ent_lower}' appears in the following context: {chunk}"
                entity_context_pairs.append((ent_lower, fallback.strip()))
    return entity_context_pairs


# === TOPIC SEARCHER CLASS (WITH DEDUPLICATION, NOISE FILTERING) ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(self.chunks, self.entities_per_chunk)
        contextual_texts = [ctx for _, ctx in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=False)

        umap_model = UMAP(**self.umap_params, random_state=SEED)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            if topic == -1:
                continue  # Skip noisy topics
            ent, ctx = entity_context_pairs[i]
            topic_to_contexts[topic].append(ctx)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            emb = topic_to_embeddings[topic_id]
            centroid = np.mean(emb, axis=0)
            centroid /= np.linalg.norm(centroid) + 1e-10
            topic_embeddings.append(centroid)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(emb)
            })

        # === OPTIONAL: Merge semantically similar topics (cosine sim > 0.95)
        deduped_metadata = []
        used = set()

        for i, emb_i in enumerate(topic_embeddings):
            if i in used:
                continue
            group = [i]
            sim_scores = cosine_similarity([emb_i], topic_embeddings)[0]
            for j in range(i + 1, len(sim_scores)):
                if sim_scores[j] > 0.95:
                    group.append(j)
                    used.add(j)

            merged = {
                "topic_id": i,
                "sentences": [],
                "entities": [],
                "sentence_embeddings": []
            }
            for g in group:
                merged["sentences"] += topic_metadata[g]["sentences"]
                merged["entities"] += topic_metadata[g]["entities"]
                merged["sentence_embeddings"] += list(topic_metadata[g]["sentence_embeddings"])

            merged["sentence_embeddings"] = np.array(merged["sentence_embeddings"])
            merged["entities"] = list(set(merged["entities"]))
            deduped_metadata.append(merged)

        self.topic_metadata = deduped_metadata
        self.topic_embeddings = np.array([
            np.mean(m["sentence_embeddings"], axis=0) /
            (np.linalg.norm(np.mean(m["sentence_embeddings"], axis=0)) + 1e-10)
            for m in deduped_metadata
        ])

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=10, num_leaves_to_search=5, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(5)
            .build()
        )

    import re

    def search(self, query, top_k_topics=3, top_k_sents=3):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        prefix_pattern = r"^the concept '.*?' appears in (the following )?context:\s*"

        for i, idx in enumerate(neighbors):
            meta = self.topic_metadata[idx]
            topic_score = float(scores[i])

            # Deduplicate sentences
            seen = set()
            cleaned_sentences = []
            cleaned_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
            # Apply regex to remove beginning prefix
                cleaned = re.sub(prefix_pattern, "", sent, flags=re.IGNORECASE).strip()

        # No duplicates
                if cleaned not in seen:
                    seen.add(cleaned)
                    cleaned_sentences.append(cleaned)
                    cleaned_embeddings.append(emb)

            if not cleaned_sentences:
                continue

            emb_array = np.array(cleaned_embeddings)
            sims = np.dot(emb_array / np.linalg.norm(emb_array, axis=1, keepdims=True), query_emb)
            top_ids = sims.argsort()[::-1][:top_k_sents]

            top_sents = [(cleaned_sentences[j], float(sims[j])) for j in top_ids]
            results.append({
            "topic_id": meta["topic_id"],
            "topic_score": topic_score,
            "entities": meta["entities"],
            "sentences": top_sents,
            })

        return results




# === METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for s in meta["sentences"]:
            tokens = clean_text(s).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(t) for t in texts]

    cm = CoherenceModel(
        topics=word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return cm.get_coherence()


def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topic_words = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    unique_words = set(word for topic in topic_words for word, _ in topic)
    return len(unique_words) / (len(topic_words) * topk)


def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        emb = meta["sentence_embeddings"]
        if len(emb) < 2:  # skip small clusters
            continue
        all_embeddings.extend(emb)
        all_labels.extend([meta["topic_id"]] * len(emb))

    if len(all_embeddings) < 3:
        return None

    all_embeddings = np.vstack(all_embeddings)
    return silhouette_score(all_embeddings, all_labels, metric="cosine")


# === DATASET & INITIALIZATION ===
allergy_dataset = {
  "chunks": [
  "A 75-year-old male with ischemic cardiomyopathy (LVEF 28%) and persistent atrial fibrillation on warfarin presented with worsening dyspnea and orthopnea over 2 weeks.",
  "He has a history of COPD, stage 4 chronic kidney disease, type 2 diabetes mellitus (HbA1c 9.1%), and prior CABG in 2008.",
  "On admission, his vitals were BP 155/95, HR 112 irregular, RR 28, and SpO₂ 86% on room air, improving with oxygen via nasal cannula.",
  "Physical exam revealed elevated JVP, bibasilar crackles, 3+ pitting edema, and an S3 gallop.",
  "Labs showed Cr 2.8, eGFR 25, BNP 2,400 pg/mL, and INR 2.3. ABG showed pH 7.36, PaCO₂ 50, PaO₂ 70. CXR indicated cardiomegaly and pleural effusions.",
  "ECG confirmed atrial fibrillation with QTc 480 ms. Pacemaker function was intact.",
  "He was treated with IV furosemide, oxygen therapy, and continued on carvedilol and lisinopril. Metformin was held due to renal impairment.",
  "Diabetes was managed with basal-bolus insulin and sliding scale. A low-sodium renal diet was started.",
  "Education included fluid restriction, daily weight tracking, and signs of heart failure. Warfarin was continued with INR monitoring.",
  "He showed improvement with a 3 kg net negative fluid balance and SpO₂ rising to 94%.",
  "An echocardiogram showed LVEF 28%, moderate mitral regurgitation, and left atrial enlargement.",
  "Discharge medications included furosemide 80 mg, carvedilol 12.5 mg BID, lisinopril 5 mg, spironolactone 25 mg, basal insulin, and warfarin.",
  "Follow-up appointments were scheduled with cardiology, nephrology, and the diabetes clinic. Anticoagulation clinic and home health visits arranged.",
  "Patient is a widower living alone with limited mobility and uses a walker. Transportation issues were noted.",
  "Social work coordinated support services including meal delivery, medication management, and transport assistance."
]
,
  "entities": [
  ["male", "cardiomyopathy", "lvef", "fibrillation", "warfarin", "dyspnea", "orthopnea"],
  ["copd", "kidney", "disease", "diabetes", "hba1c", "cabg"],
  ["bp", "hr", "irregularity", "rr", "spo2", "oxygen", "nasal", "cannula"],
  ["jvp", "crackles", "edema", "gallop"],
  ["cr", "egfr", "bnp", "inr", "abg", "ph", "paco2", "pao2", "cxr", "cardiomegaly", "effusions"],
  ["ecg", "fibrillation", "qtc", "pacemaker"],
  ["furosemide", "oxygen", "therapy", "carvedilol", "lisinopril", "metformin", "renal", "impairment"],
  ["diabetes", "insulin", "scale", "diet"],
  ["fluid", "weight", "tracking", "signs", "failure", "warfarin", "inr", "monitoring"],
  ["improvement", "fluid", "balance", "spo2"],
  ["echocardiogram", "lvef", "regurgitation", "atrium", "enlargement"],
  ["furosemide", "carvedilol", "lisinopril", "spironolactone", "insulin", "warfarin"],
  ["follow-up", "cardiology", "nephrology", "clinic", "anticoagulation", "health", "visits"],
  ["widower", "mobility", "walker"],
  ["social", "support", "meal", "delivery", "medication", "management", "transport"]
]



}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist":0.25,  "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")
# print("\n=== 🧠 Generated Topics and Entities ===")
# for meta in searcher.topic_metadata:
#     topic_id = meta["topic_id"]
#     entities = ", ".join(meta["entities"])
#     print(f"🔹 Topic ID: {topic_id} — Entities: {entities}")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics =[
  {
    "topic_id": "T0",
    "entities": [
      "cardiomyopathy", "lvef", "fibrillation", "warfarin", "dyspnea", "orthopnea",
      "ecg", "qtc", "pacemaker", "regurgitation", "atrium", "enlargement",
      "echocardiogram", "carvedilol", "lisinopril", "spironolactone"
    ]
  },
  {
    "topic_id": "T1",
    "entities": [
      "copd", "kidney", "disease", "egfr", "renal", "impairment", "cr"
    ]
  },
  {
    "topic_id": "T2",
    "entities": [
      "diabetes", "hba1c", "insulin", "scale", "metformin", "diet"
    ]
  },
  {
    "topic_id": "T3",
    "entities": [
      "bp", "hr", "rr", "spo2", "oxygen", "nasal", "cannula"
    ]
  },
  {
    "topic_id": "T4",
    "entities": [
      "jvp", "crackles", "edema", "gallop", "cardiomegaly", "effusions",
      "bnp"
    ]
  },
  {
    "topic_id": "T5",
    "entities": [
      "abg", "ph", "paco2", "pao2"
    ]
  },
  {
    "topic_id": "T6",
    "entities": [
      "furosemide", "oxygen", "therapy"
    ]
  },
  {
    "topic_id": "T7",
    "entities": [
      "fluid", "weight", "balance", "tracking"
    ]
  },
  {
    "topic_id": "T8",
    "entities": [
      "follow-up", "cardiology", "nephrology", "clinic", "anticoagulation", "health", "visits"
    ]
  },
  {
    "topic_id": "T9",
    "entities": [
      "mobility", "walker"
    ]
  },
  {
    "topic_id": "T10",
    "entities": [
      "meal", "delivery", "medication", "management", "transport"
    ]
  }
]





# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
#print("\n=== 📊 Topic Matching Summary ===")
#for model_id, gt_id, score in matches:
    #print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=3, top_k_sents=3)

    threshold = 0.2  # Set your similarity threshold here

    # Find the max topic score among the results (0 if no results)
    max_topic_score = max((res["topic_score"] for res in results), default=0)

    if max_topic_score < threshold or not results:
        print("Sorry, the query is irrelevant.")
    else:
        print(f"\n🔎 Top results for: '{query}'")
        for res in results:
            print(f"🧠 Topic ID: {res['topic_id']} (Score: {res['topic_score']:.4f})")
            print(f"🔗 Related Entities: {', '.join(res['entities'])}")
            for sent, score in res["sentences"]:
                print(f"✓ [{score:.4f}] {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/433M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/412 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.7297
🌈 Topic Diversity: 0.5375
📐 Silhouette Score: 0.7615

🧮 Average Jaccard Similarity: 0.4470
📈 Ground Truth Coverage: 11/11 (100.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.8442
🧲 Recall:    0.9701
🏅 F1 Score:  0.9028

=== Allergy Topic Search ===

Enter a query (or type 'exit' to quit): What are the main presenting symptoms?

🔎 Top results for: 'What are the main presenting symptoms?'
🧠 Topic ID: 5 (Score: 0.3312)
🔗 Related Entities: jvp, crackles, gallop, edema
✓ [0.2830] physical exam revealed elevated jvp bibasilar crackles 3 pitting edema and an s3 gallop
🧠 Topic ID: 14 (Score: 0.2294)
🔗 Related Entities: lvef, atrium, echocardiogram, enlargement, regurgitation
✓ [0.2386] an echocardiogram showed lvef 28 moderate mitral regurgitation and left atrial enlargement
🧠 Topic ID: 16 (Score: 0.2134)
🔗 Related Entities: nephrology, clinic, cardiology, anticoagulation, health, follow-up, visit

In [2]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support
from sklearn.metrics.pairwise import cosine_similarity

# === SEED FIXING ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# === CLEANING & CONTEXT EXTRACTION (IMPROVED) ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()


def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = " ".join(sentences[max(0, i - 1): i + 2]) if use_multi_sentence else sent.strip()
                    enriched = f"The concept '{ent_lower}' appears in the following context: {context}"
                    entity_context_pairs.append((ent_lower, enriched.strip()))
                    matched = True
                    break
            if not matched:
                fallback = f"The concept '{ent_lower}' appears in the following context: {chunk}"
                entity_context_pairs.append((ent_lower, fallback.strip()))
    return entity_context_pairs


# === TOPIC SEARCHER CLASS (WITH DEDUPLICATION, NOISE FILTERING) ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(self.chunks, self.entities_per_chunk)
        contextual_texts = [ctx for _, ctx in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=False)

        umap_model = UMAP(**self.umap_params, random_state=SEED)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            if topic == -1:
                continue  # Skip noisy topics
            ent, ctx = entity_context_pairs[i]
            topic_to_contexts[topic].append(ctx)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            emb = topic_to_embeddings[topic_id]
            centroid = np.mean(emb, axis=0)
            centroid /= np.linalg.norm(centroid) + 1e-10
            topic_embeddings.append(centroid)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(emb)
            })

        # === OPTIONAL: Merge semantically similar topics (cosine sim > 0.95)
        deduped_metadata = []
        used = set()

        for i, emb_i in enumerate(topic_embeddings):
            if i in used:
                continue
            group = [i]
            sim_scores = cosine_similarity([emb_i], topic_embeddings)[0]
            for j in range(i + 1, len(sim_scores)):
                if sim_scores[j] > 0.95:
                    group.append(j)
                    used.add(j)

            merged = {
                "topic_id": i,
                "sentences": [],
                "entities": [],
                "sentence_embeddings": []
            }
            for g in group:
                merged["sentences"] += topic_metadata[g]["sentences"]
                merged["entities"] += topic_metadata[g]["entities"]
                merged["sentence_embeddings"] += list(topic_metadata[g]["sentence_embeddings"])

            merged["sentence_embeddings"] = np.array(merged["sentence_embeddings"])
            merged["entities"] = list(set(merged["entities"]))
            deduped_metadata.append(merged)

        self.topic_metadata = deduped_metadata
        self.topic_embeddings = np.array([
            np.mean(m["sentence_embeddings"], axis=0) /
            (np.linalg.norm(np.mean(m["sentence_embeddings"], axis=0)) + 1e-10)
            for m in deduped_metadata
        ])

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=10, num_leaves_to_search=5, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(5)
            .build()
        )

    import re

    def search(self, query, top_k_topics=3, top_k_sents=3):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        prefix_pattern = r"^the concept '.*?' appears in (the following )?context:\s*"

        for i, idx in enumerate(neighbors):
            meta = self.topic_metadata[idx]
            topic_score = float(scores[i])

            # Deduplicate sentences
            seen = set()
            cleaned_sentences = []
            cleaned_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
            # Apply regex to remove beginning prefix
                cleaned = re.sub(prefix_pattern, "", sent, flags=re.IGNORECASE).strip()

        # No duplicates
                if cleaned not in seen:
                    seen.add(cleaned)
                    cleaned_sentences.append(cleaned)
                    cleaned_embeddings.append(emb)

            if not cleaned_sentences:
                continue

            emb_array = np.array(cleaned_embeddings)
            sims = np.dot(emb_array / np.linalg.norm(emb_array, axis=1, keepdims=True), query_emb)
            top_ids = sims.argsort()[::-1][:top_k_sents]

            top_sents = [(cleaned_sentences[j], float(sims[j])) for j in top_ids]
            results.append({
            "topic_id": meta["topic_id"],
            "topic_score": topic_score,
            "entities": meta["entities"],
            "sentences": top_sents,
            })

        return results




# === METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for s in meta["sentences"]:
            tokens = clean_text(s).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(t) for t in texts]

    cm = CoherenceModel(
        topics=word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return cm.get_coherence()


def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topic_words = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    unique_words = set(word for topic in topic_words for word, _ in topic)
    return len(unique_words) / (len(topic_words) * topk)


def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        emb = meta["sentence_embeddings"]
        if len(emb) < 2:  # skip small clusters
            continue
        all_embeddings.extend(emb)
        all_labels.extend([meta["topic_id"]] * len(emb))

    if len(all_embeddings) < 3:
        return None

    all_embeddings = np.vstack(all_embeddings)
    return silhouette_score(all_embeddings, all_labels, metric="cosine")


# === DATASET & INITIALIZATION ===
allergy_dataset = {
  "chunks": [
  "A 68-year-old female with a history of rheumatoid arthritis and congestive heart failure presented with progressive fatigue and swelling in her lower extremities over one month.",
  "She reports worsening shortness of breath on exertion and occasional palpitations but denies chest pain or syncope.",
  "Past medical records show diagnosis of hypertension, type 2 diabetes mellitus, and chronic kidney disease stage 3.",
  "Home medications include methotrexate, lisinopril, metformin, furosemide, and low-dose aspirin.",
  "On admission, vital signs were notable for BP 145/90 mmHg, HR 90 bpm and irregular, RR 20 per minute, and oxygen saturation 94% on room air.",
  "Physical examination revealed jugular venous distension, bilateral pitting edema to mid-shin, and crackles over the lower lung fields.",
  "Cardiac auscultation noted an S3 gallop with no murmurs. Neurological exam was unremarkable without focal deficits.",
  "Laboratory investigations showed elevated B-type natriuretic peptide (BNP) at 1200 pg/mL and serum creatinine of 1.8 mg/dL, indicating moderate renal impairment.",
  "Electrolytes revealed mild hyponatremia and potassium within normal limits. HbA1c was 7.8%.",
  "Chest X-ray demonstrated cardiomegaly and moderate bilateral pleural effusions.",
  "An echocardiogram revealed a reduced left ventricular ejection fraction of 35%, moderate mitral regurgitation, and left atrial enlargement.",
  "ECG confirmed atrial fibrillation with controlled ventricular response but no acute ischemic changes.",
  "She was started on rate control therapy with carvedilol and her diuretic dose was adjusted to furosemide 40 mg twice daily.",
  "Methotrexate was held temporarily due to rising creatinine and concerns for nephrotoxicity.",
  "Diabetes management was optimized with basal insulin added to metformin for better glycemic control.",
  "She received education on sodium and fluid restriction, daily weight monitoring, and recognizing signs of volume overload.",
  "Over the hospitalization course, her peripheral edema improved significantly and oxygen requirement decreased.",
  "Renal function stabilized, and repeat labs showed improved electrolyte balance and decreased BNP levels.",
  "Coordinated care included nephrology consultation, cardiology follow-up, and physical therapy assessment for mobility.",
  "Discharge planning involved initiating anticoagulation with apixaban due to atrial fibrillation and arranging home health nursing visits."
]

,
  "entities":[
  ["female", "arthritis", "heart", "failure", "fatigue", "swelling", "extremities"],
  ["dyspnea", "palpitations", "chest", "pain", "syncope"],
  ["hypertension", "diabetes", "kidney", "disease"],
  ["methotrexate", "lisinopril", "metformin", "furosemide", "aspirin"],
  ["bp", "hr", "irregular", "rr", "oxygen", "saturation"],
  ["jugular", "venous", "edema", "crackles", "lungs"],
  ["cardiac", "auscultation", "gallop", "murmurs", "neurological", "deficits"],
  ["bnp", "creatinine", "renal", "impairment"],
  ["electrolytes", "hyponatremia", "potassium", "hba1c"],
  ["chest", "xray", "cardiomegaly", "effusions"],
  ["echocardiogram", "ventricular", "ejection", "regurgitation", "atrium", "enlargement"],
  ["ecg", "fibrillation", "ventricular", "response", "ischemic"],
  ["carvedilol", "furosemide", "diuretic"],
  ["methotrexate", "creatinine", "nephrotoxicity"],
  ["insulin", "metformin", "glycemic", "control"],
  ["education", "sodium", "fluid", "restriction", "weight", "overload"],
  ["edema", "oxygen", "requirement"],
  ["renal", "function", "electrolytes", "bnp"],
  ["nephrology", "cardiology", "therapy", "mobility"],
  ["anticoagulation", "apixaban", "atrial", "fibrillation", "nursing"]
]




}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist":0.25,  "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")
# print("\n=== 🧠 Generated Topics and Entities ===")
# for meta in searcher.topic_metadata:
#     topic_id = meta["topic_id"]
#     entities = ", ".join(meta["entities"])
#     print(f"🔹 Topic ID: {topic_id} — Entities: {entities}")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics =[
  {
    "topic_id": "T0",
    "entities": [
      "heart", "failure", "fibrillation", "carvedilol", "apixaban",
      "ventricular", "ejection", "regurgitation", "atrium", "enlargement",
      "ecg", "echocardiogram"
    ]
  },
  {
    "topic_id": "T1",
    "entities": [
      "arthritis", "methotrexate", "nephrotoxicity"
    ]
  },
  {
    "topic_id": "T2",
    "entities": [
      "hypertension", "diabetes", "kidney", "renal", "creatinine",
      "impairment", "nephrology"
    ]
  },
  {
    "topic_id": "T3",
    "entities": [
      "bp", "hr", "rr", "oxygen", "saturation", "edema", "jugular",
      "venous", "crackles", "lungs", "cardiomegaly", "effusions"
    ]
  },
  {
    "topic_id": "T4",
    "entities": [
      "palpitations", "pain", "syncope", "ischemic"
    ]
  },
  {
    "topic_id": "T5",
    "entities": [
      "aspirin", "lisinopril", "furosemide", "diuretic", "insulin", "metformin"
    ]
  },
  {
    "topic_id": "T6",
    "entities": [
      "electrolytes", "hyponatremia", "potassium", "bnp", "hba1c"
    ]
  },
  {
    "topic_id": "T7",
    "entities": [
      "glycemic", "control", "education", "sodium", "fluid", "restriction",
      "weight", "overload"
    ]
  },
  {
    "topic_id": "T8",
    "entities": [
      "oxygen", "requirement", "function"
    ]
  },
  {
    "topic_id": "T9",
    "entities": [
      "cardiology", "therapy", "mobility", "nursing", "visits", "follow-up"
    ]
  }
]






# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
#print("\n=== 📊 Topic Matching Summary ===")
#for model_id, gt_id, score in matches:
    #print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=3, top_k_sents=3)

    threshold = 0.2  # Set your similarity threshold here

    # Find the max topic score among the results (0 if no results)
    max_topic_score = max((res["topic_score"] for res in results), default=0)

    if max_topic_score < threshold or not results:
        print("Sorry, the query is irrelevant.")
    else:
        print(f"\n🔎 Top results for: '{query}'")
        for res in results:
            print(f"🧠 Topic ID: {res['topic_id']} (Score: {res['topic_score']:.4f})")
            print(f"🔗 Related Entities: {', '.join(res['entities'])}")
            for sent, score in res["sentences"]:
                print(f"✓ [{score:.4f}] {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.7344
🌈 Topic Diversity: 0.5136
📐 Silhouette Score: 0.7405

🧮 Average Jaccard Similarity: 0.3682
📈 Ground Truth Coverage: 10/10 (100.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.8356
🧲 Recall:    0.9385
🏅 F1 Score:  0.8841

=== Allergy Topic Search ===

Enter a query (or type 'exit' to quit): What are the patient’s chronic medical conditions and comorbidities?

🔎 Top results for: 'What are the patient’s chronic medical conditions and comorbidities?'
🧠 Topic ID: 24 (Score: 0.4931)
🔗 Related Entities: cardiology, therapy, mobility, nephrology
✓ [0.4404] coordinated care included nephrology consultation cardiology followup and physical therapy assessment for mobility
🧠 Topic ID: 3 (Score: 0.4429)
🔗 Related Entities: kidney, diabetes, disease, hypertension
✓ [0.4264] past medical records show diagnosis of hypertension type 2 diabetes mellitus and chronic kidney disease stage 3
🧠 Topic ID: 0 (Score