In [1]:
!pip install numpy sentence-transformers bertopic hdbscan nltk scann
import nltk
nltk.download('punkt')
import nltk
nltk.download('punkt_tab')
!pip install sentence-transformers bertopic hdbscan umap-learn scann nltk datasets
!pip install gensim

Collecting bertopic
  Downloading bertopic-0.17.3-py3-none-any.whl.metadata (24 kB)
Collecting scann
  Downloading scann-1.4.0-cp311-cp311-manylinux_2_27_x86_64.whl.metadata (5.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence-transform

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


Collecting gensim
  Downloading gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Downloading gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.7/26.7 MB[0m [31m63.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━

In [2]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support

SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    return re.sub(r'\s+', ' ', text).strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = (
                        " ".join(sentences[max(0, i - 1): i + 2])
                        if use_multi_sentence else sent.strip()
                    )
                    entity_context_pairs.append((ent_lower, context.strip()))
                    matched = True
                    break
            if not matched:
                entity_context_pairs.append((ent_lower, chunk))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        contextual_texts = [context for _, context in entity_context_pairs]
        # UPDATED: use full 768-dim embeddings for search
        contextual_embeddings = self.embedding_model.encode(
            contextual_texts, normalize_embeddings=True
        )

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            seen = set()
            unique_sentences = []
            unique_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    unique_sentences.append(sent)
                    unique_embeddings.append(emb)

            sent_embs = np.array(unique_embeddings)
            if len(sent_embs) == 0:
                continue

            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], float(sims[i])) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": list(meta["entities"]),  # <- ensure it's a list in case it's a set
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    topic_word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for sent in meta["sentences"]:
            tokens = clean_text(sent).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(text) for text in texts]

    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    all_words = [word for topic in topics for word, _ in topic]
    return len(set(all_words)) / (len(topics) * topk)

def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        embeddings = meta["sentence_embeddings"]
        labels = [meta["topic_id"]] * len(embeddings)
        all_embeddings.append(embeddings)
        all_labels.extend(labels)

    if len(all_embeddings) == 0:
        return None

    all_embeddings = np.vstack(all_embeddings)
    n_samples = all_embeddings.shape[0]
    n_labels = len(set(all_labels))

    if n_labels < 2 or n_labels > n_samples - 1:
        return None

    return silhouette_score(all_embeddings, all_labels, metric="cosine")

# === DATASET & INITIALIZATION ===
allergy_dataset = {
   "chunks": [
  "A 75-year-old male with ischemic cardiomyopathy (LVEF 28%) and persistent atrial fibrillation on warfarin presented with worsening dyspnea and orthopnea over 2 weeks.",
  "He has a history of COPD, stage 4 chronic kidney disease, type 2 diabetes mellitus (HbA1c 9.1%), and prior CABG in 2008.",
  "On admission, his vitals were BP 155/95, HR 112 irregular, RR 28, and SpO₂ 86% on room air, improving with oxygen via nasal cannula.",
  "Physical exam revealed elevated JVP, bibasilar crackles, 3+ pitting edema, and an S3 gallop.",
  "Labs showed Cr 2.8, eGFR 25, BNP 2,400 pg/mL, and INR 2.3. ABG showed pH 7.36, PaCO₂ 50, PaO₂ 70. CXR indicated cardiomegaly and pleural effusions.",
  "ECG confirmed atrial fibrillation with QTc 480 ms. Pacemaker function was intact.",
  "He was treated with IV furosemide, oxygen therapy, and continued on carvedilol and lisinopril. Metformin was held due to renal impairment.",
  "Diabetes was managed with basal-bolus insulin and sliding scale. A low-sodium renal diet was started.",
  "Education included fluid restriction, daily weight tracking, and signs of heart failure. Warfarin was continued with INR monitoring.",
  "He showed improvement with a 3 kg net negative fluid balance and SpO₂ rising to 94%.",
  "An echocardiogram showed LVEF 28%, moderate mitral regurgitation, and left atrial enlargement.",
  "Discharge medications included furosemide 80 mg, carvedilol 12.5 mg BID, lisinopril 5 mg, spironolactone 25 mg, basal insulin, and warfarin.",
  "Follow-up appointments were scheduled with cardiology, nephrology, and the diabetes clinic. Anticoagulation clinic and home health visits arranged.",
  "Patient is a widower living alone with limited mobility and uses a walker. Transportation issues were noted.",
  "Social work coordinated support services including meal delivery, medication management, and transport assistance."
]
,
  "entities": [
  [
    "ischemic cardiomyopathy",
    "LVEF 28%",
    "atrial fibrillation",
    "warfarin",
    "dyspnea",
    "orthopnea",
    "2 weeks"
  ],
  [
    "chronic obstructive pulmonary disease",
    "stage 4 chronic kidney disease",
    "type 2 diabetes mellitus",
    "HbA1c 9.1%",
    "CABG 2008"
  ],
  [
    "BP 155/95",
    "HR 112",
    "RR 28",
    "SpO₂ 86%",
    "room air",
    "oxygen therapy",
    "nasal cannula"
  ],
  [
    "jugular venous pressure",
    "bibasilar crackles",
    "3+ pitting edema",
    "S3 gallop"
  ],
  [
    "Cr 2.8",
    "eGFR 25",
    "BNP 2,400 pg/mL",
    "INR 2.3",
    "ABG pH 7.36",
    "PaCO₂ 50",
    "PaO₂ 70",
    "CXR",
    "cardiomegaly",
    "pleural effusions"
  ],
  [
    "ECG",
    "atrial fibrillation",
    "QTc 480 ms",
    "pacemaker function"
  ],
  [
    "IV furosemide",
    "oxygen therapy",
    "carvedilol",
    "lisinopril",
    "metformin",
    "renal impairment"
  ],
  [
    "basal-bolus insulin",
    "sliding scale insulin",
    "low-sodium renal diet"
  ],
  [
    "fluid restriction",
    "daily weight tracking",
    "heart failure education",
    "warfarin",
    "INR monitoring"
  ],
  [
    "3 kg fluid loss",
    "SpO₂ 94%"
  ],
  [
    "echocardiogram",
    "LVEF 28%",
    "moderate mitral regurgitation",
    "left atrial enlargement"
  ],
  [
    "furosemide 80 mg",
    "carvedilol 12.5 mg BID",
    "lisinopril 5 mg",
    "spironolactone 25 mg",
    "basal insulin",
    "warfarin"
  ],
  [
    "cardiology follow-up",
    "nephrology follow-up",
    "diabetes clinic follow-up",
    "anticoagulation clinic",
    "home health nursing"
  ],
  [
    "widowed",
    "lives alone",
    "limited mobility",
    "walker use",
    "transportation issues"
  ],
  [
    "social work",
    "meal delivery",
    "medication management",
    "transport assistance"
  ]
]

}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist": 0.1, "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

print("Preparing Allergy Topic Searcher...")
searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics = [
  { "topic_id": "T1", "entities": ["ischemic cardiomyopathy", "LVEF 28%", "moderate mitral regurgitation", "mild tricuspid regurgitation", "echocardiogram"] },
  { "topic_id": "T2", "entities": ["atrial fibrillation", "warfarin", "INR 2.3", "pacemaker function", "ECG", "QTc 480 ms"] },
  { "topic_id": "T3", "entities": ["chronic obstructive pulmonary disease", "dyspnea", "orthopnea", "two-pillow PND", "SpO₂ 86%", "nasal cannula"] },
  { "topic_id": "T4", "entities": ["ischemic stroke 2012", "CABG 2008"] },
  { "topic_id": "T5", "entities": ["type 2 diabetes mellitus", "HbA1c 9.1%", "metformin", "basal-bolus insulin", "sliding scale insulin", "diabetes clinic follow-up"] },
  { "topic_id": "T6", "entities": ["stage 4 chronic kidney disease", "Cr 2.8", "eGFR 25", "nephrology follow-up"] },
  { "topic_id": "T7", "entities": ["hypertension", "hyperlipidemia", "atorvastatin", "lisinopril"] },
  { "topic_id": "T8", "entities": ["osteoarthritis"] },
  { "topic_id": "T9", "entities": ["BP 155/95", "HR 112", "RR 28", "temperature 37.3 °C", "SpO₂ 86%", "SpO₂ 93%"] },
  { "topic_id": "T10", "entities": ["jugular venous pressure", "bibasilar crackles", "scattered wheezes", "3+ pitting edema", "S3 gallop", "displaced PMI", "dullness to percussion"] },
  { "topic_id": "T11", "entities": ["WBC 10.2", "Hb 12.0", "platelets 210", "Na 138", "K 4.5", "BUN 42", "BNP 2,400 pg/mL", "INR 2.3", "TSH normal", "troponin I negative"] },
  { "topic_id": "T12", "entities": ["ABG pH 7.36", "PaCO₂ 50", "PaO₂ 70"] },
  { "topic_id": "T13", "entities": ["CXR", "cardiomegaly", "interstitial edema", "small bilateral pleural effusions"] },
  { "topic_id": "T14", "entities": ["IV furosemide", "oxygen via nasal cannula"] },
  { "topic_id": "T15", "entities": ["carvedilol", "lisinopril", "spironolactone", "furosemide 80 mg"] },
  { "topic_id": "T16", "entities": ["warfarin", "INR monitoring", "anticoagulation clinic"] },
  { "topic_id": "T17", "entities": ["oxygen therapy", "nasal cannula", "SpO₂ 94%"] },
  { "topic_id": "T18", "entities": ["3 kg net negative fluid balance", "daily weight tracking", "fluid restriction"] },
  { "topic_id": "T19", "entities": ["pulmonary rehab evaluation", "physical therapy evaluation", "6-minute walk test"] },
  { "topic_id": "T20", "entities": ["heart failure red flags", "dietician visit", "home health nursing"] },
  { "topic_id": "T21", "entities": ["influenza vaccination", "pneumococcal vaccination", "herpes zoster vaccination"] },
  { "topic_id": "T22", "entities": ["cardiology follow-up", "nephrology follow-up", "diabetes clinic follow-up"] },
  { "topic_id": "T23", "entities": ["widowed", "lives alone", "former smoker", "limited mobility", "walker use", "transportation issues"] },
  { "topic_id": "T24", "entities": ["social work", "meal delivery", "medication management", "appointment coordination"] }
]


# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Preparing Allergy Topic Searcher...
✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.7842
🌈 Topic Diversity: 0.5619
📐 Silhouette Score: 0.3062

=== 📊 Topic Matching Summary ===
🔗 Model Topic -1 ↔ Ground Truth T14 — Jaccard: 0.20
🔗 Model Topic 12 ↔ Ground Truth T1 — Jaccard: 0.17
🔗 Model Topic 17 ↔ Ground Truth T16 — Jaccard: 0.25
🔗 Model Topic 5 ↔ Ground Truth T1 — Jaccard: 0.29
🔗 Model Topic 11 ↔ Ground Truth T4 — Jaccard: 0.25
🔗 Model Topic 15 ↔ Ground Truth T6 — Jaccard: 0.20
🔗 Model Topic 1 ↔ Ground Truth T9 — Jaccard: 0.36
🔗 Model Topic 3 ↔ Ground Truth T10 — Jaccard: 0.50
🔗 Model Topic 0 ↔ Ground Truth T12 — Jaccard: 0.30
🔗 Model Topic 13 ↔ Ground Truth T2 — Jaccard: 0.33
🔗 Model Topic 19 ↔ Ground Truth T2 — Jaccard: 0.33
🔗 Model Topic 7 ↔ Ground Truth T17 — Jaccard: 0.20
🔗 Model Topic 4 ↔ Ground Truth T15 — Jaccard: 0.29
🔗 Model Topic 9 ↔ Ground Truth T16 — Jaccard: 0.20
🔗 Model Topic 14 ↔ Ground Truth T16 — Jaccard: 0.25
🔗 Model Topic 6 ↔ Gr

In [4]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support

SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    return re.sub(r'\s+', ' ', text).strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = (
                        " ".join(sentences[max(0, i - 1): i + 2])
                        if use_multi_sentence else sent.strip()
                    )
                    entity_context_pairs.append((ent_lower, context.strip()))
                    matched = True
                    break
            if not matched:
                entity_context_pairs.append((ent_lower, chunk))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        contextual_texts = [context for _, context in entity_context_pairs]
        # UPDATED: use full 768-dim embeddings for search
        contextual_embeddings = self.embedding_model.encode(
            contextual_texts, normalize_embeddings=True
        )

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            seen = set()
            unique_sentences = []
            unique_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    unique_sentences.append(sent)
                    unique_embeddings.append(emb)

            sent_embs = np.array(unique_embeddings)
            if len(sent_embs) == 0:
                continue

            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], float(sims[i])) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": list(meta["entities"]),  # <- ensure it's a list in case it's a set
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    topic_word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for sent in meta["sentences"]:
            tokens = clean_text(sent).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(text) for text in texts]

    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    all_words = [word for topic in topics for word, _ in topic]
    return len(set(all_words)) / (len(topics) * topk)

def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        embeddings = meta["sentence_embeddings"]
        labels = [meta["topic_id"]] * len(embeddings)
        all_embeddings.append(embeddings)
        all_labels.extend(labels)

    if len(all_embeddings) == 0:
        return None

    all_embeddings = np.vstack(all_embeddings)
    n_samples = all_embeddings.shape[0]
    n_labels = len(set(all_labels))

    if n_labels < 2 or n_labels > n_samples - 1:
        return None

    return silhouette_score(all_embeddings, all_labels, metric="cosine")

# === DATASET & INITIALIZATION ===
allergy_dataset = {
   "chunks": [
  "A 68-year-old female with a history of chronic obstructive pulmonary disease (COPD) presented with worsening shortness of breath, increased sputum production, and wheezing over the past week.",
  "She reports using her rescue inhaler more frequently and having difficulty climbing stairs due to breathlessness.",
  "Home medications include tiotropium inhaler, albuterol as needed, and a prednisone taper started 3 days ago by her primary care provider.",
  "She denies chest pain or fever but has noticed fatigue and reduced exercise tolerance.",
  "She was started on nebulized bronchodilators and continued her inhaled therapies during the visit.",
  "Pulmonary rehabilitation referral was placed, and smoking cessation counseling was reinforced.",
  "Discharge plan included follow-up with her pulmonologist in one week and continuation of home medications.",
  "She lives with her daughter, is a former smoker, and uses oxygen at home as needed."
]

,
  "entities": [
  ["chronic obstructive pulmonary disease", "shortness of breath", "sputum production", "wheezing", "past week"],
  ["rescue inhaler", "breathlessness", "difficulty climbing stairs"],
  ["tiotropium inhaler", "albuterol", "prednisone taper", "primary care provider"],
  ["chest pain (denied)", "fever (denied)", "fatigue", "reduced exercise tolerance"],
  ["nebulized bronchodilators", "inhaled therapies"],
  ["pulmonary rehabilitation", "smoking cessation counseling"],
  ["pulmonologist follow-up", "home medications"],
  ["lives with daughter", "former smoker", "home oxygen"]
]

}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist": 0.1, "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

print("Preparing Allergy Topic Searcher...")
searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics=[
  { "topic_id": "T1", "entities": ["chronic obstructive pulmonary disease", "shortness of breath", "wheezing", "sputum production", "breathlessness"] },
  { "topic_id": "T2", "entities": ["tiotropium inhaler", "albuterol", "prednisone taper", "nebulized bronchodilators", "inhaled therapies", "rescue inhaler"] },
  { "topic_id": "T3", "entities": ["fatigue", "reduced exercise tolerance", "difficulty climbing stairs"] },
  { "topic_id": "T4", "entities": ["pulmonary rehabilitation", "smoking cessation counseling"] },
  { "topic_id": "T5", "entities": ["pulmonologist follow-up", "home medications", "primary care provider"] },
  { "topic_id": "T6", "entities": ["former smoker", "home oxygen", "lives with daughter"] }
]

# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Preparing Allergy Topic Searcher...
✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.5139
🌈 Topic Diversity: 0.8500
📐 Silhouette Score: 0.4215

=== 📊 Topic Matching Summary ===
🔗 Model Topic 1 ↔ Ground Truth T1 — Jaccard: 0.42
🔗 Model Topic 0 ↔ Ground Truth T2 — Jaccard: 0.36

🧮 Average Jaccard Similarity: 0.3869
📈 Ground Truth Coverage: 2/6 (33.3%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.4400
🧲 Recall:    1.0000
🏅 F1 Score:  0.6111

=== Allergy Topic Search ===

🔎 Top results for: 'What symptoms did the patient present with?'
🧠 Topic ID: 1
🔗 Related Entities: past week, shortness of breath, difficulty climbing stairs, rescue inhaler, breathlessness, wheezing, sputum production, chest pain (denied), chronic obstructive pulmonary disease, reduced exercise tolerance, fatigue, fever (denied)
✓ she denies chest pain or fever but has noticed fatigue and reduced exercise tolerance
✓ a 68yearold female with a history of chronic obstructive pulmon

In [16]:

# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict, Counter
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score, precision_recall_fscore_support

SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
nltk.download("punkt")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    return re.sub(r'\s+', ' ', text).strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = (
                        " ".join(sentences[max(0, i - 1): i + 2])
                        if use_multi_sentence else sent.strip()
                    )
                    entity_context_pairs.append((ent_lower, context.strip()))
                    matched = True
                    break
            if not matched:
                entity_context_pairs.append((ent_lower, chunk))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params,
                 model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        contextual_texts = [context for _, context in entity_context_pairs]
        # UPDATED: use full 768-dim embeddings for search
        contextual_embeddings = self.embedding_model.encode(
            contextual_texts, normalize_embeddings=True
        )

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            seen = set()
            unique_sentences = []
            unique_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    unique_sentences.append(sent)
                    unique_embeddings.append(emb)

            sent_embs = np.array(unique_embeddings)
            if len(sent_embs) == 0:
                continue

            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], float(sims[i])) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": list(meta["entities"]),  # <- ensure it's a list in case it's a set
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    topic_word_lists = [[word for word, _ in topic] for topic in topics]

    texts = []
    for meta in topic_metadata:
        for sent in meta["sentences"]:
            tokens = clean_text(sent).split()
            texts.append(tokens)

    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(text) for text in texts]

    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    all_words = [word for topic in topics for word, _ in topic]
    return len(set(all_words)) / (len(topics) * topk)

def compute_silhouette_score_custom(topic_metadata):
    all_embeddings = []
    all_labels = []

    for meta in topic_metadata:
        embeddings = meta["sentence_embeddings"]
        labels = [meta["topic_id"]] * len(embeddings)
        all_embeddings.append(embeddings)
        all_labels.extend(labels)

    if len(all_embeddings) == 0:
        return None

    all_embeddings = np.vstack(all_embeddings)
    n_samples = all_embeddings.shape[0]
    n_labels = len(set(all_labels))

    if n_labels < 2 or n_labels > n_samples - 1:
        return None

    return silhouette_score(all_embeddings, all_labels, metric="cosine")

# === DATASET & INITIALIZATION ===
allergy_dataset = {
  "chunks": [
  "A 68-year-old woman with a 15-year history of hypertension and type 2 diabetes mellitus presents complaining of persistent, dull headaches localized to the occipital region, occasional dizziness, and intermittent blurred vision over the last 6 weeks.",
  "She has struggled with blood pressure control despite taking amlodipine 10 mg daily and hydrochlorothiazide 25 mg daily for several years. Home BP measurements often exceed 160/95 mmHg, particularly in the mornings.",
  "Her past medical history also includes stage 3 chronic kidney disease with baseline creatinine around 1.4 mg/dL, hyperlipidemia treated with atorvastatin, and osteoarthritis limiting her mobility.",
  "She reports new symptoms of swelling in both ankles and occasional shortness of breath during moderate exertion, which she attributes to her limited physical activity due to joint pain.",
  "On examination, her blood pressure measured in clinic is 168/98 mmHg sitting, with a heart rate of 88 beats per minute, respiratory rate 18, temperature 36.8 °C, and oxygen saturation 96% on room air.",
  "Cardiovascular examination reveals a displaced apical impulse palpable in the 6th intercostal space anterior axillary line, a grade 2/6 systolic murmur best heard at the cardiac apex, and no jugular venous distension.",
  "Fundoscopic exam shows mild hypertensive changes, including arteriolar narrowing and scattered cotton wool spots, but no papilledema.",
  "Respiratory exam is clear with normal breath sounds bilaterally; there is mild 1+ pitting edema in the ankles.",
  "Laboratory studies from her recent outpatient evaluation indicate stable renal function with creatinine 1.4 mg/dL and estimated glomerular filtration rate (eGFR) of 48 mL/min/1.73 m², normal electrolytes, and HbA1c of 7.8%.",
  "Her lipid panel reveals LDL cholesterol of 130 mg/dL despite ongoing statin therapy, with triglycerides at 160 mg/dL and HDL at 38 mg/dL.",
  "A 12-lead ECG demonstrates left ventricular hypertrophy with strain pattern but no arrhythmias or conduction abnormalities.",
  "Echocardiography confirms left ventricular hypertrophy with preserved left ventricular ejection fraction of 60%, mild left atrial enlargement, and no valvular abnormalities aside from mild mitral regurgitation.",
  "She admits difficulty adhering to a low-sodium diet because of a preference for processed foods and challenges preparing meals at home.",
  "Her physical activity is limited by osteoarthritis pain affecting knees and hips, and she rarely exceeds 1000 steps per day according to her activity tracker.",
  "Medications were adjusted to include lisinopril 10 mg daily to provide better blood pressure control and nephroprotection; hydrochlorothiazide dose was decreased due to borderline low potassium levels.",
  "She was counseled extensively on lifestyle modifications, including sodium restriction, weight reduction, and gradual increase in physical activity tailored to her joint limitations.",
  "Referrals were made to a dietitian for nutritional counseling focusing on kidney-friendly, low-sodium meals, and to physical therapy for osteoarthritis management and tailored exercise program.",
  "A cardiology appointment was scheduled within 4 weeks for blood pressure reassessment and echocardiographic follow-up, and nephrology follow-up planned to monitor kidney function closely.",
  "Home health services were arranged to assist with blood pressure monitoring, medication reminders, and reinforcement of dietary adherence.",
  "Social history reveals she lives with her husband in a single-story home, has a strong family support network, but experiences caregiver stress related to her mother-in-law’s dementia care.",
  "She denies tobacco use and alcohol consumption but expresses feelings of anxiety and occasional insomnia due to stress.",
  "Patient education was provided on the importance of medication adherence, recognizing signs of hypertensive crisis, and understanding potential complications such as stroke and heart failure.",
  "Follow-up labs including renal function, electrolytes, and lipid profile were ordered to be done prior to her next clinic visit.",
  "She was encouraged to maintain a daily blood pressure log and report any symptoms such as chest pain, worsening shortness of breath, or neurological changes promptly.",
  "The care team emphasized the importance of a multidisciplinary approach, including medical, nutritional, and psychosocial support to optimize her hypertension management."
]
,
  "entities":[
  ["68-year-old woman", "hypertension", "type 2 diabetes mellitus", "headaches", "dizziness", "blurred vision", "6 weeks"],
  ["poorly controlled BP", "amlodipine 10 mg", "hydrochlorothiazide 25 mg", "home BP readings", "160/95 mmHg", "mornings"],
  ["stage 3 chronic kidney disease", "creatinine 1.4 mg/dL", "hyperlipidemia", "atorvastatin", "osteoarthritis", "limited mobility"],
  ["bilateral ankle swelling", "shortness of breath", "limited physical activity", "joint pain"],
  ["clinic BP 168/98 mmHg", "HR 88 bpm", "RR 18", "Temp 36.8°C", "SpO2 96%"],
  ["displaced apical impulse", "grade 2/6 systolic murmur", "no JVD"],
  ["hypertensive retinopathy", "arteriolar narrowing", "cotton wool spots", "no papilledema"],
  ["clear lungs", "1+ pitting edema ankles"],
  ["renal function stable", "creatinine 1.4 mg/dL", "eGFR 48 mL/min/1.73 m²", "normal electrolytes", "HbA1c 7.8%"],
  ["LDL 130 mg/dL", "statin therapy", "triglycerides 160 mg/dL", "HDL 38 mg/dL"],
  ["ECG", "left ventricular hypertrophy", "strain pattern", "no arrhythmias"],
  ["echocardiogram", "LV hypertrophy", "EF 60%", "left atrial enlargement", "mild mitral regurgitation"],
  ["dietary noncompliance", "processed foods", "meal prep challenges"],
  ["limited physical activity", "osteoarthritis pain", "knee and hip involvement", "1000 steps/day"],
  ["lisinopril 10 mg", "hydrochlorothiazide dose decreased", "borderline low potassium"],
  ["lifestyle counseling", "sodium restriction", "weight loss", "exercise modification"],
  ["dietitian referral", "nutritional counseling", "kidney-friendly diet", "physical therapy referral", "tailored exercise program"],
  ["cardiology follow-up", "nephrology follow-up"],
  ["home health services", "BP monitoring", "medication reminders", "diet adherence support"],
  ["lives with husband", "family support", "caregiver stress", "mother-in-law dementia"],
  ["no tobacco", "no alcohol", "anxiety", "insomnia", "stress"],
  ["patient education", "medication adherence", "hypertensive crisis signs", "complications education"],
  ["follow-up labs", "renal function", "electrolytes", "lipid profile"],
  ["daily BP log", "symptom monitoring", "chest pain", "neurological symptoms"],
  ["multidisciplinary care", "medical", "nutritional", "psychosocial support"]
]


}

best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist": 0.25, "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
)
print("✅ Model ready for querying.")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
ground_truth_topics=[
  { "topic_id": "T1", "entities": ["hypertension", "headaches", "dizziness", "blurred vision", "poorly controlled BP", "medication adherence"] },
  { "topic_id": "T2", "entities": ["medications", "amlodipine", "hydrochlorothiazide", "lisinopril", "potassium"] },
  { "topic_id": "T3", "entities": ["comorbidities", "type 2 diabetes mellitus", "chronic kidney disease", "hyperlipidemia", "osteoarthritis"] },
  { "topic_id": "T4", "entities": ["symptoms", "ankle swelling", "shortness of breath", "limited physical activity", "fatigue"] },
  { "topic_id": "T5", "entities": ["physical exam", "BP 168/98", "heart rate", "murmur", "apical impulse", "retinopathy", "edema"] },
  { "topic_id": "T6", "entities": ["labs", "creatinine", "eGFR", "HbA1c", "lipid panel", "electrolytes"] },
  { "topic_id": "T7", "entities": ["imaging", "ECG", "LV hypertrophy", "strain pattern", "echocardiogram", "mitral regurgitation"] },
  { "topic_id": "T8", "entities": ["lifestyle", "dietary noncompliance", "processed foods", "exercise limitation", "weight loss"] },
  { "topic_id": "T9", "entities": ["referrals", "dietitian", "physical therapy", "cardiology", "nephrology"] },
  { "topic_id": "T10", "entities": ["social history", "married", "lives with spouse", "caregiver stress", "anxiety", "insomnia"] },
  { "topic_id": "T11", "entities": ["patient education", "hypertension complications", "medication adherence", "lifestyle modifications", "crisis signs"] },
  { "topic_id": "T12", "entities": ["follow-up", "home health", "blood pressure monitoring", "medication reminders", "labs"] },
  { "topic_id": "T13", "entities": ["multidisciplinary care", "medical", "nutritional", "psychosocial support"] }
]


# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    return [e.lower().strip() for e in entities]

def jaccard_similarity(set1, set2):
    set1, set2 = set(set1), set(set2)
    return len(set1 & set2) / len(set1 | set2) if set1 | set2 else 0.0

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {sum(score for _, _, score in matches) / len(matches):.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.7728
🌈 Topic Diversity: 0.6667
📐 Silhouette Score: 0.5705

=== 📊 Topic Matching Summary ===
🔗 Model Topic 2 ↔ Ground Truth T1 — Jaccard: 0.44
🔗 Model Topic 9 ↔ Ground Truth T1 — Jaccard: 0.11
🔗 Model Topic 18 ↔ Ground Truth T3 — Jaccard: 0.14
🔗 Model Topic 21 ↔ Ground Truth T3 — Jaccard: 0.17
🔗 Model Topic 1 ↔ Ground Truth T4 — Jaccard: 0.17
🔗 Model Topic 12 ↔ Ground Truth T7 — Jaccard: 0.25
🔗 Model Topic 11 ↔ Ground Truth T7 — Jaccard: 0.25
🔗 Model Topic 0 ↔ Ground Truth T8 — Jaccard: 0.21
🔗 Model Topic 10 ↔ Ground Truth T11 — Jaccard: 0.29
🔗 Model Topic 16 ↔ Ground Truth T12 — Jaccard: 0.14
🔗 Model Topic 5 ↔ Ground Truth T13 — Jaccard: 0.12
🔗 Model Topic 17 ↔ Ground Truth T10 — Jaccard: 0.12
🔗 Model Topic 22 ↔ Ground Truth T10 — Jaccard: 0.14
🔗 Model Topic 19 ↔ Ground Truth T6 — Jaccard: 0.12
🔗 Model Topic 20 ↔ Ground Truth T13 — Jaccard: 0.75

🧮 Average Jaccard Similarity: 0.2295
📈 Ground Truth Cov

In [13]:
import re
import numpy as np
from itertools import product
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.metrics import silhouette_score
from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
import nltk
from nltk.tokenize import sent_tokenize

nltk.download("punkt")

# === Helper Functions ===
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        chunk = clean_text(chunks[idx])
        sentences = sent_tokenize(chunk)
        for ent in ents:
            ent_lower = ent.lower()
            matched = False
            for i, sent in enumerate(sentences):
                if ent_lower in sent:
                    context = (
                        " ".join(sentences[max(0, i - 1): i + 2])
                        if use_multi_sentence else sent.strip()
                    )
                    entity_context_pairs.append((ent_lower, context.strip()))
                    matched = True
                    break
            if not matched:
                entity_context_pairs.append((ent_lower, chunk))
    return entity_context_pairs

def compute_coherence_and_diversity(topic_model, texts, topic_metadata, topk=10):
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    topic_word_lists = [[word for word, _ in topic] for topic in topics]

    tokenized_texts = [clean_text(doc).split() for doc in texts]
    dictionary = Dictionary(tokenized_texts)
    corpus = [dictionary.doc2bow(text) for text in tokenized_texts]

    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=tokenized_texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    coherence = coherence_model.get_coherence()

    all_words = [word for topic in topic_word_lists for word in topic]
    diversity = len(set(all_words)) / (len(topic_word_lists) * topk)
    return coherence, diversity

# === Your Dataset ===
chunks = [
  "A 68-year-old woman with a 15-year history of hypertension and type 2 diabetes mellitus presents complaining of persistent, dull headaches localized to the occipital region, occasional dizziness, and intermittent blurred vision over the last 6 weeks.",
  "She has struggled with blood pressure control despite taking amlodipine 10 mg daily and hydrochlorothiazide 25 mg daily for several years. Home BP measurements often exceed 160/95 mmHg, particularly in the mornings.",
  "Her past medical history also includes stage 3 chronic kidney disease with baseline creatinine around 1.4 mg/dL, hyperlipidemia treated with atorvastatin, and osteoarthritis limiting her mobility.",
  "She reports new symptoms of swelling in both ankles and occasional shortness of breath during moderate exertion, which she attributes to her limited physical activity due to joint pain.",
  "On examination, her blood pressure measured in clinic is 168/98 mmHg sitting, with a heart rate of 88 beats per minute, respiratory rate 18, temperature 36.8 °C, and oxygen saturation 96% on room air.",
  "Cardiovascular examination reveals a displaced apical impulse palpable in the 6th intercostal space anterior axillary line, a grade 2/6 systolic murmur best heard at the cardiac apex, and no jugular venous distension.",
  "Fundoscopic exam shows mild hypertensive changes, including arteriolar narrowing and scattered cotton wool spots, but no papilledema.",
  "Respiratory exam is clear with normal breath sounds bilaterally; there is mild 1+ pitting edema in the ankles.",
  "Laboratory studies from her recent outpatient evaluation indicate stable renal function with creatinine 1.4 mg/dL and estimated glomerular filtration rate (eGFR) of 48 mL/min/1.73 m², normal electrolytes, and HbA1c of 7.8%.",
  "Her lipid panel reveals LDL cholesterol of 130 mg/dL despite ongoing statin therapy, with triglycerides at 160 mg/dL and HDL at 38 mg/dL.",
  "A 12-lead ECG demonstrates left ventricular hypertrophy with strain pattern but no arrhythmias or conduction abnormalities.",
  "Echocardiography confirms left ventricular hypertrophy with preserved left ventricular ejection fraction of 60%, mild left atrial enlargement, and no valvular abnormalities aside from mild mitral regurgitation.",
  "She admits difficulty adhering to a low-sodium diet because of a preference for processed foods and challenges preparing meals at home.",
  "Her physical activity is limited by osteoarthritis pain affecting knees and hips, and she rarely exceeds 1000 steps per day according to her activity tracker.",
  "Medications were adjusted to include lisinopril 10 mg daily to provide better blood pressure control and nephroprotection; hydrochlorothiazide dose was decreased due to borderline low potassium levels.",
  "She was counseled extensively on lifestyle modifications, including sodium restriction, weight reduction, and gradual increase in physical activity tailored to her joint limitations.",
  "Referrals were made to a dietitian for nutritional counseling focusing on kidney-friendly, low-sodium meals, and to physical therapy for osteoarthritis management and tailored exercise program.",
  "A cardiology appointment was scheduled within 4 weeks for blood pressure reassessment and echocardiographic follow-up, and nephrology follow-up planned to monitor kidney function closely.",
  "Home health services were arranged to assist with blood pressure monitoring, medication reminders, and reinforcement of dietary adherence.",
  "Social history reveals she lives with her husband in a single-story home, has a strong family support network, but experiences caregiver stress related to her mother-in-law’s dementia care.",
  "She denies tobacco use and alcohol consumption but expresses feelings of anxiety and occasional insomnia due to stress.",
  "Patient education was provided on the importance of medication adherence, recognizing signs of hypertensive crisis, and understanding potential complications such as stroke and heart failure.",
  "Follow-up labs including renal function, electrolytes, and lipid profile were ordered to be done prior to her next clinic visit.",
  "She was encouraged to maintain a daily blood pressure log and report any symptoms such as chest pain, worsening shortness of breath, or neurological changes promptly.",
  "The care team emphasized the importance of a multidisciplinary approach, including medical, nutritional, and psychosocial support to optimize her hypertension management."
]
entities = [
  ["68-year-old woman", "hypertension", "type 2 diabetes mellitus", "headaches", "dizziness", "blurred vision", "6 weeks"],
  ["poorly controlled BP", "amlodipine 10 mg", "hydrochlorothiazide 25 mg", "home BP readings", "160/95 mmHg", "mornings"],
  ["stage 3 chronic kidney disease", "creatinine 1.4 mg/dL", "hyperlipidemia", "atorvastatin", "osteoarthritis", "limited mobility"],
  ["bilateral ankle swelling", "shortness of breath", "limited physical activity", "joint pain"],
  ["clinic BP 168/98 mmHg", "HR 88 bpm", "RR 18", "Temp 36.8°C", "SpO2 96%"],
  ["displaced apical impulse", "grade 2/6 systolic murmur", "no JVD"],
  ["hypertensive retinopathy", "arteriolar narrowing", "cotton wool spots", "no papilledema"],
  ["clear lungs", "1+ pitting edema ankles"],
  ["renal function stable", "creatinine 1.4 mg/dL", "eGFR 48 mL/min/1.73 m²", "normal electrolytes", "HbA1c 7.8%"],
  ["LDL 130 mg/dL", "statin therapy", "triglycerides 160 mg/dL", "HDL 38 mg/dL"],
  ["ECG", "left ventricular hypertrophy", "strain pattern", "no arrhythmias"],
  ["echocardiogram", "LV hypertrophy", "EF 60%", "left atrial enlargement", "mild mitral regurgitation"],
  ["dietary noncompliance", "processed foods", "meal prep challenges"],
  ["limited physical activity", "osteoarthritis pain", "knee and hip involvement", "1000 steps/day"],
  ["lisinopril 10 mg", "hydrochlorothiazide dose decreased", "borderline low potassium"],
  ["lifestyle counseling", "sodium restriction", "weight loss", "exercise modification"],
  ["dietitian referral", "nutritional counseling", "kidney-friendly diet", "physical therapy referral", "tailored exercise program"],
  ["cardiology follow-up", "nephrology follow-up"],
  ["home health services", "BP monitoring", "medication reminders", "diet adherence support"],
  ["lives with husband", "family support", "caregiver stress", "mother-in-law dementia"],
  ["no tobacco", "no alcohol", "anxiety", "insomnia", "stress"],
  ["patient education", "medication adherence", "hypertensive crisis signs", "complications education"],
  ["follow-up labs", "renal function", "electrolytes", "lipid profile"],
  ["daily BP log", "symptom monitoring", "chest pain", "neurological symptoms"],
  ["multidisciplinary care", "medical", "nutritional", "psychosocial support"]
]

# === Embedding & Context Preparation ===
embedding_model = SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb")
entity_context_pairs = extract_entity_contexts(chunks, entities)
texts = [context for _, context in entity_context_pairs]
embeddings = embedding_model.encode(texts, normalize_embeddings=True)

# === Parameter Grid ===
umap_grid = [{"n_neighbors": n, "n_components": 5, "min_dist": d}
             for n, d in product([5, 10], [0.1, 0.25])]
hdbscan_grid = [{"min_cluster_size": c, "min_samples": s}
                for c, s in product([2, 5], [1, 5])]

best_score = -1
best_model = None
best_params = {}
best_coherence = 0
best_diversity = 0

# === Grid Search Loop ===
for umap_params in umap_grid:
    for hdbscan_params in hdbscan_grid:
        try:
            umap_model = UMAP(metric="cosine", random_state=42, **umap_params)
            hdbscan_model = HDBSCAN(metric="euclidean", prediction_data=True, **hdbscan_params)

            topic_model = BERTopic(
                embedding_model=embedding_model,
                umap_model=umap_model,
                hdbscan_model=hdbscan_model,
                representation_model=KeyBERTInspired(),
                calculate_probabilities=False,
                verbose=False,
            )

            topics, _ = topic_model.fit_transform(texts, embeddings)
            labels = np.array(topics)
            valid_idx = labels != -1
            if sum(valid_idx) < 2:
                continue

            sil_score = silhouette_score(np.array(embeddings)[valid_idx], labels[valid_idx], metric="cosine")
            topic_metadata = []
            for tid in set(topics):
                if tid == -1:
                    continue
                inds = [i for i, t in enumerate(topics) if t == tid]
                topic_metadata.append({
                    "topic_id": tid,
                    "sentences": [texts[i] for i in inds]
                })

            coherence, diversity = compute_coherence_and_diversity(topic_model, texts, topic_metadata)

            if sil_score > best_score:
                best_score = sil_score
                best_model = topic_model
                best_params = {"umap": umap_params, "hdbscan": hdbscan_params}
                best_coherence = coherence
                best_diversity = diversity

        except Exception as e:
            continue

# === Print Best Model Results ===
print("\n📊 Best Hyperparameter Combination Found:")
print(f"UMAP Params: {best_params['umap']}")
print(f"HDBSCAN Params: {best_params['hdbscan']}")
print(f"✅ Silhouette Score: {best_score:.4f}")
print(f"🧠 Coherence Score (c_v): {best_coherence:.4f}")
print(f"🌈 Topic Diversity: {best_diversity:.4f}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!



📊 Best Hyperparameter Combination Found:
UMAP Params: {'n_neighbors': 5, 'n_components': 5, 'min_dist': 0.25}
HDBSCAN Params: {'min_cluster_size': 2, 'min_samples': 1}
✅ Silhouette Score: 0.6048
🧠 Coherence Score (c_v): 0.7722
🌈 Topic Diversity: 0.6957


In [None]:
#top 3-topics-score
#dynamic thrsholding on the output
#3x of existing data(10000 characters)- fie tune parameters
#entity occuring at multi places