In [None]:
!pip install numpy sentence-transformers bertopic hdbscan nltk scann
import nltk
nltk.download('punkt')
import nltk
nltk.download('punkt_tab')
!pip install sentence-transformers bertopic hdbscan umap-learn scann nltk datasets
!pip install gensim

Collecting bertopic
  Downloading bertopic-0.17.3-py3-none-any.whl.metadata (24 kB)
Collecting scann
  Downloading scann-1.4.0-cp311-cp311-manylinux_2_27_x86_64.whl.metadata (5.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence-transform

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


Collecting gensim
  Downloading gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Downloading gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.7/26.7 MB[0m [31m64.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━

In [None]:
# === IMPORTS ===
import os
import random
import numpy as np
import torch
import nltk
import logging

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score

# === ENVIRONMENT SETUP ===
SEED = 42

# Process-level settings first, so they are in place before any CUDA work.
# NOTE: assigning PYTHONHASHSEED here cannot change the hash seed of the
# interpreter that is already running; only child processes inherit it.
os.environ.update({
    "PYTHONHASHSEED": str(SEED),
    "CUBLAS_WORKSPACE_CONFIG": ":16:8",
})

# Seed every random number generator used by the pipeline.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Ask PyTorch to raise instead of silently running non-deterministic kernels.
torch.use_deterministic_algorithms(True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sentence tokenizer data needed by sent_tokenize.
nltk.download("punkt")

# === CLASS FOR TOPIC SEARCHER ===
class AllergyTopicSearcher:
    """Group entity mentions into BERTopic topics and search them with ScaNN.

    On construction the instance embeds one ``"entity: sentence"`` string per
    entity mention, clusters those strings with BERTopic (UMAP + HDBSCAN),
    averages the embeddings of each topic into a normalised topic vector and
    indexes the topic vectors with ScaNN. ``search`` maps a free-text query to
    its nearest topics and to the most similar sentences inside each of them.

    Attributes populated by ``_prepare``:
        topic_model: the fitted ``BERTopic`` instance.
        topic_metadata: one dict per topic with the keys ``topic_id``,
            ``entities``, ``sentences`` and ``sentence_embeddings``.
        topic_embeddings: array holding one mean embedding per topic, in the
            same order as ``topic_metadata``.
        searcher: the ScaNN index built over ``topic_embeddings``.
    """

    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-MiniLM-L6-v2"):
        """Load the embedding model and build the topic index immediately.

        Args:
            chunks: list of text chunks.
            entities_per_chunk: list aligned with ``chunks``; element ``i``
                holds the entity strings to look for in ``chunks[i]``.
            umap_params: keyword arguments forwarded to ``UMAP``.
            hdbscan_params: keyword arguments forwarded to ``HDBSCAN``
                (``prediction_data=True`` is added on top).
            model_name: SentenceTransformer model identifier.

        Raises:
            ValueError: if no entity is found in any sentence.
            RuntimeError: if clustering yields no topic to index.
        """
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        """Extract entity contexts, fit BERTopic and build the ScaNN index."""
        entity_context_pairs = []

        for idx, ents in enumerate(self.entities_per_chunk):
            chunk = self.chunks[idx].lower()
            sentences = sent_tokenize(chunk)
            for ent in ents:
                ent_lower = ent.lower()
                # Context = first sentence of the chunk that mentions the entity.
                context = next((sent for sent in sentences if ent_lower in sent), None)
                if context is None:
                    # Make dropped entities visible instead of losing them silently.
                    logging.getLogger(__name__).warning(
                        "Entity %r not found in chunk %d; skipping it.", ent, idx
                    )
                    continue
                entity_context_pairs.append((ent_lower, context.strip()))

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        # Embed entity-contexts
        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        # Topic Modeling
        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Metadata aggregation
        # NOTE(review): every topic id returned by fit_transform is kept, so an
        # outlier topic (-1), if produced, is indexed like a regular topic.
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for (ent, context), topic, embedding in zip(entity_context_pairs, topics, contextual_embeddings):
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(embedding)

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            # Topic vector = re-normalised mean of its member embeddings.
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                # Sorted so the entity order does not depend on set iteration
                # order, which varies between interpreter runs.
                "entities": sorted(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        # === Print topics and their entities ===
        print("\n=== Topics and Associated Entities ===")
        for meta in self.topic_metadata:
            print(f"Topic ID: {meta['topic_id']}, Entities: {', '.join(meta['entities'])}")

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(
                num_leaves=num_clusters,
                # Never ask ScaNN to probe more leaves than the tree has.
                num_leaves_to_search=min(2, num_clusters),
                training_sample_size=len(self.topic_embeddings),
            )
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        """Return the topics closest to ``query`` with their best sentences.

        Args:
            query: free-text query string.
            top_k_topics: number of topics requested from the ScaNN index.
            top_k_sents: maximum number of sentences returned per topic.

        Returns:
            A list with one dict per retrieved topic, holding ``topic_id``,
            ``entities`` and ``sentences``; ``sentences`` is a list of
            ``(sentence, similarity)`` tuples sorted by decreasing similarity.
        """
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[int(idx)]

            # Several entities can share a context sentence; keep the first
            # embedding seen for each distinct sentence (dicts keep insertion order).
            first_embedding = {}
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                first_embedding.setdefault(sent, emb)

            unique_sentences = list(first_embedding)
            sent_embs = np.array(list(first_embedding.values()))
            # Same epsilon as in _prepare: a zero vector must not yield NaNs.
            norms = np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs / (norms + 1e-10), query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
            })

        return results

# === COHERENCE SCORE ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=10):
    """Compute the gensim ``c_v`` coherence of the model's topic words.

    Args:
        topic_model: fitted model exposing ``get_topic(topic_id)``, which is
            expected to return ``(word, score)`` pairs.
        topic_metadata: list of dicts with the keys ``topic_id`` and
            ``sentences``; the sentences serve as the reference texts.
        topk: number of leading words taken from each topic.

    Returns:
        The coherence score reported by ``CoherenceModel.get_coherence()``.
    """
    topic_word_lists = [
        [word for word, _ in topic_model.get_topic(meta["topic_id"])[:topk]]
        for meta in topic_metadata
    ]

    # Reference texts: every stored sentence, lower-cased and whitespace-split.
    texts = [
        sent.lower().split()
        for meta in topic_metadata
        for sent in meta["sentences"]
    ]

    # The bag-of-words corpus the original built here was never used, so it
    # is no longer computed; CoherenceModel only needs texts and dictionary.
    dictionary = Dictionary(texts)

    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

# === TOPIC DIVERSITY ===
def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    """Return the share of unique words among the top words of all topics.

    Args:
        topic_model: fitted model exposing ``get_topic(topic_id)``, which is
            expected to return ``(word, score)`` pairs.
        topic_metadata: list of dicts that each carry a ``topic_id``.
        topk: number of leading words taken from each topic.

    Returns:
        ``unique_words / (number_of_topics * topk)``, or ``0.0`` when that
        denominator is zero (no topics, or ``topk == 0``).
    """
    top_words = [
        word
        for meta in topic_metadata
        for word, _ in topic_model.get_topic(meta["topic_id"])[:topk]
    ]
    word_slots = len(topic_metadata) * topk
    if not word_slots:
        # Nothing to measure; avoid a ZeroDivisionError.
        return 0.0
    return len(set(top_words)) / word_slots

# === SILHOUETTE SCORE (sentence embeddings, avoid errors) ===
def compute_silhouette_score(topic_metadata):
    """Cosine silhouette score of the stored embeddings, labelled by topic.

    Args:
        topic_metadata: list of dicts with the keys ``topic_id`` and
            ``sentence_embeddings`` (one row per sample).

    Returns:
        The silhouette score, or ``None`` when it is undefined: no topics,
        fewer than two distinct labels, or more labels than ``samples - 1``.
    """
    per_topic_vectors = [meta["sentence_embeddings"] for meta in topic_metadata]
    if not per_topic_vectors:
        return None

    # One label per embedding row, in the same order as the stacked matrix.
    sample_labels = [
        meta["topic_id"]
        for meta in topic_metadata
        for _ in range(len(meta["sentence_embeddings"]))
    ]

    stacked = np.vstack(per_topic_vectors)
    sample_count = stacked.shape[0]
    distinct_labels = len(set(sample_labels))

    # silhouette score constraints: 2 <= n_labels <= n_samples - 1
    if not (2 <= distinct_labels <= sample_count - 1):
        return None

    return silhouette_score(stacked, sample_labels, metric="cosine")

# === DATASET ===
# Toy corpus: five short text chunks plus, for each chunk, the entity strings
# to look for in it. "chunks" and "entities" are index-aligned.
allergy_dataset = {
    "name": "Allergy Dataset",
    "chunks": [
        "Patient has peanut allergy causing hives and swelling. Anaphylaxis noted once during a reaction.",
        "Allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.",
        "Severe anaphylaxis symptoms require immediate treatment with epinephrine.",
        "Food allergies to milk and eggs can cause skin reactions like urticaria and eczema.",
        "Cold weather does not cause allergy symptoms in this patient."
    ],
    "entities": [
        ["peanut allergy", "hives", "swelling", "anaphylaxis"],
        ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
        ["anaphylaxis", "epinephrine", "treatment"],
        ["food allergies", "milk", "eggs", "urticaria", "eczema"],
        ["cold weather", "allergy symptoms"]
    ]
}

# === BEST PARAMETERS ===
# Keyword arguments passed straight through to UMAP(...) and HDBSCAN(...).
# NOTE(review): presumably selected by an earlier tuning run that is not part
# of this notebook — confirm before reusing on other data.
best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist": 0.1, "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

# === INITIALIZE SEARCHER ===
# The constructor does all the work: embedding, topic fitting and index build.
print("Preparing Allergy Topic Searcher...")
searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
)
print("✅ Model ready for querying.")

# === EVALUATION METRICS ===
# All three metrics are computed from the fitted model and its per-topic metadata.
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=10)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f} (Higher is better, usually >0.4 is decent)")
print(f"🌈 Topic Diversity: {diversity:.4f} (Closer to 1 means more unique topics)")
# compute_silhouette_score returns None when the score is undefined.
if sil_score is not None:
    print(f"📐 Silhouette Score (cosine): {sil_score:.4f} (Closer to 1 means better cluster separation)")
else:
    print("📐 Silhouette Score: Not applicable (need at least 2 topics and enough samples).")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    # input() raises EOFError when stdin is closed or non-interactive and
    # KeyboardInterrupt when the cell is interrupted; treat both as "exit"
    # instead of ending the cell with a traceback.
    try:
        query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        print("Goodbye!")
        break

    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    if not query:
        # Nothing to embed; prompt again rather than searching for "".
        continue

    # Retrieve the single best topic and up to three supporting sentences.
    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Preparing Allergy Topic Searcher...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]


=== Topics and Associated Entities ===
Topic ID: 3, Entities: swelling, peanut allergy, hives
Topic ID: 2, Entities: epinephrine, treatment, anaphylaxis
Topic ID: 0, Entities: pollen, allergic rhinitis, dust, hay fever, allergy symptoms, pet dander, cold weather
Topic ID: 1, Entities: food allergies, eggs, urticaria, milk, eczema
✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.6967 (Higher is better, usually >0.4 is decent)
🌈 Topic Diversity: 0.8750 (Closer to 1 means more unique topics)
📐 Silhouette Score (cosine): 0.7373 (Closer to 1 means better cluster separation)

=== Allergy Topic Search ===

Enter a query (or type 'exit' to quit): allergy

🔎 Top results for: 'allergy'
🧠 Topic ID: 0
🔗 Related Entities: pollen, allergic rhinitis, dust, hay fever, allergy symptoms, pet dander, cold weather
✓ cold weather does not cause allergy symptoms in this patient.
✓ allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.

En

In [None]:
# Code after improving the topic-modelling metrics

In [None]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score

SEED = 42

# Process-level settings first, so they are in place before any CUDA work.
# NOTE: assigning PYTHONHASHSEED here cannot change the hash seed of the
# interpreter that is already running; only child processes inherit it.
os.environ.update({
    "PYTHONHASHSEED": str(SEED),
    "CUBLAS_WORKSPACE_CONFIG": ":16:8",
})

# Seed every random number generator used by the pipeline.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Ask PyTorch to raise instead of silently running non-deterministic kernels.
torch.use_deterministic_algorithms(True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sentence tokenizer data needed by sent_tokenize.
nltk.download("punkt")

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    """Lower-case ``text``, drop punctuation and collapse whitespace runs.

    Every character that is neither a word character nor whitespace is
    removed; runs of whitespace become one space and the ends are trimmed.
    """
    without_punctuation = re.sub(r'[^\w\s]', '', text.lower())
    single_spaced = re.sub(r'\s+', ' ', without_punctuation)
    return single_spaced.strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    """Pair every entity with the cleaned text around its first mention.

    Args:
        chunks: list of raw text chunks.
        entities_per_chunk: list aligned with ``chunks``; element ``i`` holds
            the entity strings to look for in ``chunks[i]``.
        use_multi_sentence: when true, the context is the matching sentence
            together with its immediate neighbours (previous and next);
            otherwise only the matching sentence.

    Returns:
        A list of ``(entity_lowercase, context)`` tuples, one per entity.
        Entities that match no sentence get the whole cleaned chunk as context.
    """
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        raw_chunk = chunks[idx]
        # Whole-chunk fallback context for entities that match no sentence.
        chunk = clean_text(raw_chunk)

        # Split into sentences on the RAW text and clean afterwards:
        # clean_text strips the sentence-ending punctuation the tokenizer
        # splits on, so tokenizing the cleaned chunk left the multi-sentence
        # window with nothing to select from.
        sentences = [clean_text(raw_sent) for raw_sent in sent_tokenize(raw_chunk)]
        sentences = [sent for sent in sentences if sent]

        for ent in ents:
            ent_lower = ent.lower()
            # Clean the entity like the text so punctuation inside an entity
            # cannot prevent a match against the punctuation-free sentences.
            ent_key = clean_text(ent)

            context = chunk
            if ent_key:
                for i, sent in enumerate(sentences):
                    if ent_key in sent:
                        context = (
                            " ".join(sentences[max(0, i - 1): i + 2])
                            if use_multi_sentence else sent
                        )
                        break

            entity_context_pairs.append((ent_lower, context.strip()))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    """Group entity mentions into BERTopic topics and search them with ScaNN.

    On construction the instance embeds one ``"entity: context"`` string per
    entity, clusters those strings with BERTopic (UMAP + HDBSCAN), averages
    the embeddings of each topic into a normalised topic vector and indexes
    the topic vectors with ScaNN. ``search`` maps a free-text query to its
    nearest topics and to the most similar contexts inside each of them.

    Attributes populated by ``_prepare``:
        topic_model: the fitted ``BERTopic`` instance.
        topic_metadata: one dict per topic with the keys ``topic_id``,
            ``entities``, ``sentences`` and ``sentence_embeddings``.
        topic_embeddings: array holding one mean embedding per topic, in the
            same order as ``topic_metadata``.
        searcher: the ScaNN index built over ``topic_embeddings``.
    """

    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-mpnet-base-v2"):
        """Load the embedding model and build the topic index immediately.

        Args:
            chunks: list of text chunks.
            entities_per_chunk: list aligned with ``chunks``; element ``i``
                holds the entity strings to look for in ``chunks[i]``.
            umap_params: keyword arguments forwarded to ``UMAP``.
            hdbscan_params: keyword arguments forwarded to ``HDBSCAN``
                (``prediction_data=True`` is added on top).
            model_name: SentenceTransformer model identifier.

        Raises:
            ValueError: if no entity-context pair could be extracted.
            RuntimeError: if clustering yields no topic to index.
        """
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        """Extract entity contexts, fit BERTopic and build the ScaNN index."""
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Group contexts, entities and embeddings by assigned topic id.
        # NOTE(review): every topic id returned by fit_transform is kept, so an
        # outlier topic (-1), if produced, is indexed like a regular topic.
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for (ent, context), topic, embedding in zip(entity_context_pairs, topics, contextual_embeddings):
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(embedding)

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            # Topic vector = re-normalised mean of its member embeddings.
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                # Sorted so the entity order does not depend on set iteration
                # order, which varies between interpreter runs.
                "entities": sorted(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        print("\n=== Topics and Associated Entities ===")
        for meta in self.topic_metadata:
            print(f"Topic ID: {meta['topic_id']}, Entities: {', '.join(meta['entities'])}")

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(
                num_leaves=num_clusters,
                # Never ask ScaNN to probe more leaves than the tree has.
                num_leaves_to_search=min(2, num_clusters),
                training_sample_size=len(self.topic_embeddings),
            )
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        """Return the topics closest to ``query`` with their best contexts.

        Args:
            query: free-text query string.
            top_k_topics: number of topics requested from the ScaNN index.
            top_k_sents: maximum number of contexts returned per topic.

        Returns:
            A list with one dict per retrieved topic, holding ``topic_id``,
            ``entities`` and ``sentences``; ``sentences`` is a list of
            ``(context, similarity)`` tuples sorted by decreasing similarity.
        """
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[int(idx)]

            # Several entities can share a context; keep the first embedding
            # seen for each distinct context (dicts keep insertion order).
            first_embedding = {}
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                first_embedding.setdefault(sent, emb)

            unique_sentences = list(first_embedding)
            sent_embs = np.array(list(first_embedding.values()))
            # Same epsilon as in _prepare: a zero vector must not yield NaNs.
            norms = np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs / (norms + 1e-10), query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    """Compute the gensim ``c_v`` coherence of the model's topic words.

    Args:
        topic_model: fitted model exposing ``get_topic(topic_id)``, which is
            expected to return ``(word, score)`` pairs.
        topic_metadata: list of dicts with the keys ``topic_id`` and
            ``sentences``; the sentences serve as the reference texts.
        topk: number of leading words taken from each topic.

    Returns:
        The coherence score reported by ``CoherenceModel.get_coherence()``.
    """
    topic_word_lists = [
        [word for word, _ in topic_model.get_topic(meta["topic_id"])[:topk]]
        for meta in topic_metadata
    ]

    # Reference texts: every stored context, cleaned with clean_text and
    # split on whitespace.
    texts = [
        clean_text(sent).split()
        for meta in topic_metadata
        for sent in meta["sentences"]
    ]

    # The bag-of-words corpus the original built here was never used, so it
    # is no longer computed; CoherenceModel only needs texts and dictionary.
    dictionary = Dictionary(texts)

    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    """Return the share of unique words among the top words of all topics.

    Args:
        topic_model: fitted model exposing ``get_topic(topic_id)``, which is
            expected to return ``(word, score)`` pairs.
        topic_metadata: list of dicts that each carry a ``topic_id``.
        topk: number of leading words taken from each topic.

    Returns:
        ``unique_words / (number_of_topics * topk)``, or ``0.0`` when that
        denominator is zero (no topics, or ``topk == 0``).
    """
    all_words = [
        word
        for meta in topic_metadata
        for word, _ in topic_model.get_topic(meta["topic_id"])[:topk]
    ]
    word_slots = len(topic_metadata) * topk
    if not word_slots:
        # Nothing to measure; avoid a ZeroDivisionError.
        return 0.0
    return len(set(all_words)) / word_slots

def compute_silhouette_score_custom(topic_metadata):
    """Cosine silhouette score of the stored embeddings, labelled by topic.

    Args:
        topic_metadata: list of dicts with the keys ``topic_id`` and
            ``sentence_embeddings`` (one row per sample).

    Returns:
        The silhouette score, or ``None`` when it is undefined: no topics,
        fewer than two distinct labels, or more labels than ``samples - 1``.
    """
    per_topic_vectors = [meta["sentence_embeddings"] for meta in topic_metadata]
    if not per_topic_vectors:
        return None

    # One label per embedding row, in the same order as the stacked matrix.
    sample_labels = [
        meta["topic_id"]
        for meta in topic_metadata
        for _ in range(len(meta["sentence_embeddings"]))
    ]

    stacked = np.vstack(per_topic_vectors)
    sample_count = stacked.shape[0]
    distinct_labels = len(set(sample_labels))

    # The score needs 2 <= n_labels <= n_samples - 1.
    if not (2 <= distinct_labels <= sample_count - 1):
        return None

    return silhouette_score(stacked, sample_labels, metric="cosine")

# === DATASET & INITIALIZATION ===
# Toy corpus: five short text chunks plus, for each chunk, the entity strings
# to look for in it. "chunks" and "entities" are index-aligned.
allergy_dataset = {
    "chunks": [
        "Patient has peanut allergy causing hives and swelling. Anaphylaxis noted once during a reaction.",
        "Allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.",
        "Severe anaphylaxis symptoms require immediate treatment with epinephrine.",
        "Food allergies to milk and eggs can cause skin reactions like urticaria and eczema.",
        "Cold weather does not cause allergy symptoms in this patient."
    ],
    "entities": [
        ["peanut allergy", "hives", "swelling", "anaphylaxis"],
        ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
        ["anaphylaxis", "epinephrine", "treatment"],
        ["food allergies", "milk", "eggs", "urticaria", "eczema"],
        ["cold weather", "allergy symptoms"]
    ]
}

# Keyword arguments passed straight through to UMAP(...) and HDBSCAN(...).
# NOTE(review): presumably selected by an earlier tuning run that is not part
# of this notebook — confirm before reusing on other data.
best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist": 0.1, "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}

# The constructor does all the work: embedding, topic fitting and index build.
print("Preparing Allergy Topic Searcher...")
searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="all-mpnet-base-v2"
)
print("✅ Model ready for querying.")

# === METRICS ===
# Coherence uses the top 15 words per topic, diversity the top 10.
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
# compute_silhouette_score_custom returns None when the score is undefined.
if sil_score is not None:
    print(f"📐 Silhouette Score: {sil_score:.4f}")
else:
    print("📐 Silhouette Score: Not applicable.")

# === GROUND TRUTH TOPICS ===
# Hand-written reference grouping of entities that the model topics are
# compared against.
# NOTE(review): "peanut" in T2 is not one of the dataset's entity strings
# (the dataset lists "peanut allergy"), so it can never be matched — confirm.
ground_truth_topics = [
    {"topic_id": "T1", "entities": ["peanut allergy", "allergic rhinitis", "hay fever", "food allergies"]},
    {"topic_id": "T2", "entities": ["peanut", "pollen", "dust", "pet dander", "milk", "eggs"]},
    {"topic_id": "T3", "entities": ["hives", "swelling", "anaphylaxis", "urticaria", "eczema", "allergy symptoms"]},
    {"topic_id": "T4", "entities": ["epinephrine", "treatment"]},
    {"topic_id": "T5", "entities": ["cold weather"]}
]

# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    """Return the entity strings lower-cased and stripped of outer whitespace."""
    return list(map(lambda name: name.lower().strip(), entities))

def jaccard_similarity(set1, set2):
    """Return |A ∩ B| / |A ∪ B| for two iterables, or 0.0 when both are empty."""
    left = set(set1)
    right = set(set2)
    union = left | right
    if not union:
        return 0.0
    return len(left & right) / len(union)

# Prepare model topics
# Entity lists of the topics the model produced, normalised for comparison.
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Normalise each ground-truth topic once instead of once per model topic.
gt_candidates = [(gt, normalize(gt["entities"])) for gt in ground_truth_topics]

# Matching model topics to ground truth: each model topic is paired with the
# ground-truth topic of highest Jaccard overlap. Several model topics may map
# to the same ground-truth topic.
# NOTE(review): model topics with zero overlap are left out of `matches`, so
# they do not lower the average Jaccard; likewise only matched ground-truth
# topics contribute entities to the recall computation below.
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    best_gt_entities = []
    for gt, gt_entities in gt_candidates:
        score = jaccard_similarity(mt["entities"], gt_entities)
        if score > best_score:
            best_score = score
            best_gt = gt
            best_gt_entities = gt_entities
    if best_gt is not None:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(best_gt_entities)

# Entity-level metrics: presence/absence of each entity on either side.
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(list(model_entity_counter.keys()) + list(gt_entity_counter.keys())))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

if unique_entities:
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
else:
    # No model topic overlapped any ground-truth topic: nothing to score.
    precision = recall = f1 = 0.0

# Guard the average against an empty match list (previously ZeroDivisionError).
average_jaccard = sum(score for _, _, score in matches) / len(matches) if matches else 0.0
# Same guard for the coverage ratio when no ground-truth topics are defined.
gt_coverage = len(matched_gt_ids) / len(ground_truth_topics) if ground_truth_topics else 0.0

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {average_jaccard:.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({gt_coverage*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    # input() raises EOFError when stdin is closed or non-interactive and
    # KeyboardInterrupt when the cell is interrupted; treat both as "exit"
    # instead of ending the cell with a traceback.
    try:
        query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        print("Goodbye!")
        break

    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    if not query:
        # Nothing to embed; prompt again rather than searching for "".
        continue

    # Retrieve the single best topic and up to three supporting contexts.
    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")



Preparing Allergy Topic Searcher...


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]


=== Topics and Associated Entities ===
Topic ID: 0, Entities: peanut allergy, swelling, anaphylaxis, hives, allergy symptoms, cold weather
Topic ID: 1, Entities: pollen, allergic rhinitis, dust, hay fever, pet dander
Topic ID: 3, Entities: epinephrine, treatment, anaphylaxis
Topic ID: 2, Entities: food allergies, eggs, urticaria, milk, eczema
✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.8963
🌈 Topic Diversity: 0.8250
📐 Silhouette Score: 0.7315

=== 📊 Topic Matching Summary ===
🔗 Model Topic 0 ↔ Ground Truth T3 — Jaccard: 0.50
🔗 Model Topic 1 ↔ Ground Truth T2 — Jaccard: 0.38
🔗 Model Topic 3 ↔ Ground Truth T4 — Jaccard: 0.67
🔗 Model Topic 2 ↔ Ground Truth T2 — Jaccard: 0.22

🧮 Average Jaccard Similarity: 0.4410
📈 Ground Truth Coverage: 3/5 (60.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.7222
🧲 Recall:    0.9286
🏅 F1 Score:  0.8125

=== Allergy Topic Search ===

Enter a query (or type 'exit' to quit): allergy

🔎 Top results for: 'allerg

In [None]:
# Experiment on dataset 2

In [None]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score

SEED = 42

# Process-level settings first, so they are in place before any CUDA work.
# NOTE: assigning PYTHONHASHSEED here cannot change the hash seed of the
# interpreter that is already running; only child processes inherit it.
os.environ.update({
    "PYTHONHASHSEED": str(SEED),
    "CUBLAS_WORKSPACE_CONFIG": ":16:8",
})

# Seed every random number generator used by the pipeline.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Ask PyTorch to raise instead of silently running non-deterministic kernels.
torch.use_deterministic_algorithms(True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sentence tokenizer data needed by sent_tokenize.
nltk.download("punkt")

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    """Lower-case ``text``, drop punctuation and collapse whitespace runs.

    Every character that is neither a word character nor whitespace is
    removed; runs of whitespace become one space and the ends are trimmed.
    """
    without_punctuation = re.sub(r'[^\w\s]', '', text.lower())
    single_spaced = re.sub(r'\s+', ' ', without_punctuation)
    return single_spaced.strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    """Pair every entity with the cleaned text around its first mention.

    Args:
        chunks: list of raw text chunks.
        entities_per_chunk: list aligned with ``chunks``; element ``i`` holds
            the entity strings to look for in ``chunks[i]``.
        use_multi_sentence: when true, the context is the matching sentence
            together with its immediate neighbours (previous and next);
            otherwise only the matching sentence.

    Returns:
        A list of ``(entity_lowercase, context)`` tuples, one per entity.
        Entities that match no sentence get the whole cleaned chunk as context.
    """
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        raw_chunk = chunks[idx]
        # Whole-chunk fallback context for entities that match no sentence.
        chunk = clean_text(raw_chunk)

        # Split into sentences on the RAW text and clean afterwards:
        # clean_text strips the sentence-ending punctuation the tokenizer
        # splits on, so tokenizing the cleaned chunk left the multi-sentence
        # window with nothing to select from.
        sentences = [clean_text(raw_sent) for raw_sent in sent_tokenize(raw_chunk)]
        sentences = [sent for sent in sentences if sent]

        for ent in ents:
            ent_lower = ent.lower()
            # Clean the entity like the text so punctuation inside an entity
            # cannot prevent a match against the punctuation-free sentences.
            ent_key = clean_text(ent)

            context = chunk
            if ent_key:
                for i, sent in enumerate(sentences):
                    if ent_key in sent:
                        context = (
                            " ".join(sentences[max(0, i - 1): i + 2])
                            if use_multi_sentence else sent
                        )
                        break

            entity_context_pairs.append((ent_lower, context.strip()))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    """Topic index over entity contexts with nearest-topic retrieval.

    Construction runs the whole pipeline: entity contexts are extracted from
    the chunks, embedded with a SentenceTransformer, clustered by BERTopic
    (UMAP + HDBSCAN), and the normalised mean embedding of every topic is
    stored in a ScaNN dot-product index.  ``search`` maps a free-text query
    to its nearest topics and the best-matching context sentences of each.
    """

    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-mpnet-base-v2"):
        """Load the embedding model and build the topic index.

        Args:
            chunks: Text chunks, indexable by position.
            entities_per_chunk: One list of entity strings per chunk, aligned
                with ``chunks`` by index.
            umap_params: Keyword arguments forwarded to ``UMAP``.
            hdbscan_params: Keyword arguments forwarded to ``HDBSCAN``
                (``prediction_data=True`` is added on top).
            model_name: Name of the SentenceTransformer model to load.

        Raises:
            ValueError: If no entity-context pairs can be extracted.
            RuntimeError: If clustering yields no topic embeddings.
        """
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        # Filled in by _prepare().
        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        """Fit the topic model and build the ScaNN index over topic centroids."""
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        # Each document is "<entity>: <context>" so the entity itself
        # contributes to the embedding.
        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Group contexts, entities and embeddings by assigned topic id.
        # NOTE(review): no topic id is filtered out here, so an outlier label
        # (if the clusterer emits one) is indexed like any other topic --
        # confirm that is intended.
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for topic, (ent, context), embedding in zip(topics, entity_context_pairs, contextual_embeddings):
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(embedding)

        topic_embeddings = []
        topic_metadata = []

        for topic_id, contexts in topic_to_contexts.items():
            embeddings = topic_to_embeddings[topic_id]
            # Unit-length centroid, so dot product equals cosine similarity.
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": contexts,
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        print("\n=== Topics and Associated Entities ===")
        for meta in self.topic_metadata:
            print(f"Topic ID: {meta['topic_id']}, Entities: {', '.join(meta['entities'])}")

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        # Never ask the index to probe more leaves than the tree contains
        # (a single-topic corpus yields a one-leaf tree).
        leaves_to_search = min(2, num_clusters)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(
                num_leaves=num_clusters,
                num_leaves_to_search=leaves_to_search,
                training_sample_size=len(self.topic_embeddings),
            )
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        """Return the topics nearest to *query* with their best sentences.

        Args:
            query: Free-text query string.
            top_k_topics: Number of topics requested from the index.
            top_k_sents: Maximum number of sentences returned per topic.

        Returns:
            A list with one dict per retrieved topic holding ``"topic_id"``,
            ``"entities"`` and ``"sentences"``; the latter is a list of
            ``(sentence, similarity)`` tuples ordered by descending
            similarity to the query.
        """
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]

            # Several entities can share one context sentence; keep the first
            # occurrence of each sentence together with its embedding.
            first_seen = {}
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                first_seen.setdefault(sent, emb)
            unique_sentences = list(first_seen)

            sent_embs = np.array(list(first_seen.values()))
            norms = np.linalg.norm(sent_embs, axis=1, keepdims=True)
            # A zero vector would divide to NaN, and NaN sorts to the front
            # of the reversed argsort; dividing by 1 scores it 0 instead.
            norms[norms == 0] = 1.0
            sims = np.dot(sent_embs / norms, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    """Compute the gensim ``c_v`` coherence of the fitted topics.

    Args:
        topic_model: Object whose ``get_topic(topic_id)`` yields
            ``(word, score)`` pairs.
        topic_metadata: Dicts carrying ``"topic_id"`` and ``"sentences"``.
        topk: Maximum number of leading words taken from each topic.

    Returns:
        The value of ``CoherenceModel.get_coherence()``.
    """
    topic_word_lists = [
        [word for word, _ in topic_model.get_topic(meta["topic_id"])[:topk]]
        for meta in topic_metadata
    ]

    # Reference corpus: every stored context sentence, whitespace-tokenised.
    texts = [
        clean_text(sent).split()
        for meta in topic_metadata
        for sent in meta["sentences"]
    ]
    dictionary = Dictionary(texts)

    # The coherence model receives the tokenised texts directly; no
    # bag-of-words corpus is needed for it.
    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    """Return the share of unique words among the top words of all topics.

    The result is ``len(unique words) / (number of topics * topk)``.

    Args:
        topic_model: Object whose ``get_topic(topic_id)`` yields
            ``(word, score)`` pairs.  A falsy result contributes no words
            (NOTE(review): BERTopic appears to return ``False`` for an
            unknown topic id -- confirm).
        topic_metadata: Dicts carrying ``"topic_id"``.
        topk: Number of leading words considered per topic.

    Returns:
        The diversity ratio, or ``0.0`` when there are no topics or ``topk``
        is not positive (instead of dividing by zero).
    """
    denominator = len(topic_metadata) * topk
    if denominator <= 0:
        return 0.0

    unique_words = {
        word
        for meta in topic_metadata
        for word, _ in (topic_model.get_topic(meta["topic_id"]) or [])[:topk]
    }
    return len(unique_words) / denominator

def compute_silhouette_score_custom(topic_metadata):
    """Cosine silhouette score of the stored embeddings, labelled by topic.

    Args:
        topic_metadata: Dicts carrying ``"topic_id"`` and
            ``"sentence_embeddings"``.

    Returns:
        The silhouette score, or ``None`` when there is no metadata or the
        number of distinct topics is below 2 or above ``n_samples - 1``.
    """
    per_topic_embeddings = [meta["sentence_embeddings"] for meta in topic_metadata]
    if not per_topic_embeddings:
        return None

    # One label per embedding row, in the same order the rows are stacked.
    point_labels = [
        meta["topic_id"]
        for meta in topic_metadata
        for _ in range(len(meta["sentence_embeddings"]))
    ]

    stacked = np.vstack(per_topic_embeddings)
    sample_count = stacked.shape[0]
    label_count = len(set(point_labels))

    if not (2 <= label_count <= sample_count - 1):
        return None

    return silhouette_score(stacked, point_labels, metric="cosine")

# === DATASET & INITIALIZATION ===
# Five single-sentence chunks about atopic dermatitis, each with the
# entities annotated for it (aligned by position).
allergy_dataset = {
    "chunks": [
        "Atopic dermatitis is a chronic skin condition characterized by itchy and inflamed skin.",
        "Patients with atopic dermatitis often have dry, scaly patches and may experience infections.",
        "Treatment includes moisturizers, corticosteroids, and avoiding irritants.",
        "Severe cases may require systemic immunosuppressants.",
        "Triggers include allergens such as dust mites, pet dander, and pollen.",
    ],
    "entities": [
        ["atopic dermatitis", "skin condition", "itchy", "inflamed skin"],
        ["dry", "scaly patches", "infections"],
        ["treatment", "moisturizers", "corticosteroids", "irritants"],
        ["severe cases", "systemic immunosuppressants"],
        ["triggers", "allergens", "dust mites", "pet dander", "pollen"],
    ],
}

# Clustering hyper-parameters handed to UMAP and HDBSCAN.
best_umap = dict(n_neighbors=5, n_components=5, min_dist=0.1, metric="cosine")
best_hdbscan = dict(min_cluster_size=2, min_samples=1, metric="euclidean")

print("Preparing Allergy Topic Searcher...")
searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="all-mpnet-base-v2",
)
print("✅ Model ready for querying.")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
# The silhouette helper returns None when the score is undefined.
print(
    "📐 Silhouette Score: Not applicable."
    if sil_score is None
    else f"📐 Silhouette Score: {sil_score:.4f}"
)

# === GROUND TRUTH TOPICS ===
# Reference grouping used by the evaluation: one topic per chunk.
ground_truth_topics = [
    {"topic_id": topic_id, "entities": entities}
    for topic_id, entities in (
        ("T1", ["atopic dermatitis", "skin condition", "itchy", "inflamed skin"]),
        ("T2", ["dry", "scaly patches", "infections"]),
        ("T3", ["treatment", "moisturizers", "corticosteroids", "irritants"]),
        ("T4", ["severe cases", "systemic immunosuppressants"]),
        ("T5", ["triggers", "allergens", "dust mites", "pet dander", "pollen"]),
    )
]

# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    """Return the entities lowercased and stripped of surrounding whitespace."""
    cleaned = []
    for entity in entities:
        cleaned.append(entity.lower().strip())
    return cleaned

def jaccard_similarity(set1, set2):
    """Return |A ∩ B| / |A ∪ B| for two collections, or 0.0 if both are empty."""
    left = set(set1)
    right = set(set2)
    union = left | right
    if not union:
        return 0.0
    return len(left & right) / len(union)

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth.
# Each model topic is paired with the ground-truth topic of highest Jaccard
# overlap.  A model topic with zero overlap everywhere stays unmatched, and
# several model topics may land on the same ground-truth topic.
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics: over the union of entities seen in matched topics,
# an entity is "true" if ground truth contains it and "predicted" if the
# model contains it.
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(model_entity_counter) | set(gt_entity_counter))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

if matches:
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    average_jaccard = sum(score for _, _, score in matches) / len(matches)
else:
    # Nothing matched: report zeros instead of dividing by zero or scoring
    # empty label lists.
    precision = recall = f1 = 0.0
    average_jaccard = 0.0

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {average_jaccard:.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    try:
        query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # Input stream closed or the user interrupted: leave the loop cleanly.
        print("Goodbye!")
        break
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break
    if not query:
        # Nothing to search for; prompt again.
        continue

    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Preparing Allergy Topic Searcher...

=== Topics and Associated Entities ===
Topic ID: 1, Entities: skin condition, itchy, inflamed skin, atopic dermatitis
Topic ID: 3, Entities: infections, scaly patches, dry
Topic ID: 2, Entities: treatment, moisturizers, corticosteroids, irritants
Topic ID: 4, Entities: severe cases, systemic immunosuppressants
Topic ID: 0, Entities: pollen, triggers, allergens, dust mites, pet dander
✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.9420
🌈 Topic Diversity: 0.7800
📐 Silhouette Score: 0.8139

=== 📊 Topic Matching Summary ===
🔗 Model Topic 1 ↔ Ground Truth T1 — Jaccard: 1.00
🔗 Model Topic 3 ↔ Ground Truth T2 — Jaccard: 1.00
🔗 Model Topic 2 ↔ Ground Truth T3 — Jaccard: 1.00
🔗 Model Topic 4 ↔ Ground Truth T4 — Jaccard: 1.00
🔗 Model Topic 0 ↔ Ground Truth T5 — Jaccard: 1.00

🧮 Average Jaccard Similarity: 1.0000
📈 Ground Truth Coverage: 5/5 (100.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 1.0000
🧲 Recall:    1.00

In [None]:
# Experiment on dataset 3

In [None]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score

# Reproducibility setup: seed every RNG this cell controls and request
# deterministic kernels from PyTorch.
SEED = 42

# NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
# assignment presumably only reaches child processes -- confirm whether the
# running kernel's hash seed matters for this experiment.
os.environ["PYTHONHASHSEED"] = str(SEED)
# Exported before deterministic mode is enabled and before any CUDA work, so
# cuBLAS already sees the workspace setting when it is first used.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sentence-tokenizer data for sent_tokenize.  NOTE(review): newer NLTK
# releases appear to load "punkt_tab" instead of "punkt"; fetching both keeps
# this cell runnable on its own -- confirm against the installed NLTK version.
nltk.download("punkt")
nltk.download("punkt_tab")

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    """Lowercase *text*, drop punctuation, and collapse whitespace.

    Every character that is neither a word character nor whitespace is
    removed; each run of whitespace then becomes a single space and the
    result is trimmed at both ends.
    """
    lowered = text.lower()
    without_punctuation = re.sub(r'[^\w\s]', '', lowered)
    single_spaced = re.sub(r'\s+', ' ', without_punctuation)
    return single_spaced.strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    """Pair every entity with the cleaned sentence context it occurs in.

    Each chunk is split into sentences on its raw text and every sentence is
    then passed through ``clean_text``.  Splitting first matters: cleaning
    removes the punctuation the sentence tokenizer relies on, so cleaning the
    whole chunk up front would leave it as one unsplittable sentence.

    Args:
        chunks: Text chunks, indexed in step with ``entities_per_chunk``.
        entities_per_chunk: One iterable of entity strings per chunk.
        use_multi_sentence: When true, the context is the first matching
            sentence plus its immediate neighbours; otherwise only the
            matching sentence.

    Returns:
        A list of ``(lowercased_entity, context)`` tuples, one per entity.
        An entity found in no sentence is paired with the whole cleaned
        chunk.
    """
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        raw_chunk = chunks[idx]
        chunk = clean_text(raw_chunk)
        # Tokenize the raw text, then clean sentence by sentence; sentences
        # that clean down to nothing (pure punctuation) are dropped.
        cleaned_sentences = (clean_text(raw) for raw in sent_tokenize(raw_chunk))
        sentences = [sent for sent in cleaned_sentences if sent]

        for ent in ents:
            ent_lower = ent.lower()
            # Sentences carry no punctuation, so match on the cleaned entity;
            # fall back to the raw lowercase form if cleaning empties it.
            needle = clean_text(ent) or ent_lower
            context = chunk
            for i, sent in enumerate(sentences):
                if needle in sent:
                    if use_multi_sentence:
                        context = " ".join(sentences[max(0, i - 1): i + 2]).strip()
                    else:
                        context = sent.strip()
                    break
            entity_context_pairs.append((ent_lower, context))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    """Topic index over entity contexts with nearest-topic retrieval.

    Construction runs the whole pipeline: entity contexts are extracted from
    the chunks, embedded with a SentenceTransformer, clustered by BERTopic
    (UMAP + HDBSCAN), and the normalised mean embedding of every topic is
    stored in a ScaNN dot-product index.  ``search`` maps a free-text query
    to its nearest topics and the best-matching context sentences of each.
    """

    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-mpnet-base-v2"):
        """Load the embedding model and build the topic index.

        Args:
            chunks: Text chunks, indexable by position.
            entities_per_chunk: One list of entity strings per chunk, aligned
                with ``chunks`` by index.
            umap_params: Keyword arguments forwarded to ``UMAP``.
            hdbscan_params: Keyword arguments forwarded to ``HDBSCAN``
                (``prediction_data=True`` is added on top).
            model_name: Name of the SentenceTransformer model to load.

        Raises:
            ValueError: If no entity-context pairs can be extracted.
            RuntimeError: If clustering yields no topic embeddings.
        """
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        # Filled in by _prepare().
        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        """Fit the topic model and build the ScaNN index over topic centroids."""
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        # Each document is "<entity>: <context>" so the entity itself
        # contributes to the embedding.
        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Group contexts, entities and embeddings by assigned topic id.
        # NOTE(review): no topic id is filtered out here, so an outlier label
        # (if the clusterer emits one) is indexed like any other topic --
        # confirm that is intended.
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for topic, (ent, context), embedding in zip(topics, entity_context_pairs, contextual_embeddings):
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(embedding)

        topic_embeddings = []
        topic_metadata = []

        for topic_id, contexts in topic_to_contexts.items():
            embeddings = topic_to_embeddings[topic_id]
            # Unit-length centroid, so dot product equals cosine similarity.
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": contexts,
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        print("\n=== Topics and Associated Entities ===")
        for meta in self.topic_metadata:
            print(f"Topic ID: {meta['topic_id']}, Entities: {', '.join(meta['entities'])}")

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        # Never ask the index to probe more leaves than the tree contains
        # (a single-topic corpus yields a one-leaf tree).
        leaves_to_search = min(2, num_clusters)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(
                num_leaves=num_clusters,
                num_leaves_to_search=leaves_to_search,
                training_sample_size=len(self.topic_embeddings),
            )
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        """Return the topics nearest to *query* with their best sentences.

        Args:
            query: Free-text query string.
            top_k_topics: Number of topics requested from the index.
            top_k_sents: Maximum number of sentences returned per topic.

        Returns:
            A list with one dict per retrieved topic holding ``"topic_id"``,
            ``"entities"`` and ``"sentences"``; the latter is a list of
            ``(sentence, similarity)`` tuples ordered by descending
            similarity to the query.
        """
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]

            # Several entities can share one context sentence; keep the first
            # occurrence of each sentence together with its embedding.
            first_seen = {}
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                first_seen.setdefault(sent, emb)
            unique_sentences = list(first_seen)

            sent_embs = np.array(list(first_seen.values()))
            norms = np.linalg.norm(sent_embs, axis=1, keepdims=True)
            # A zero vector would divide to NaN, and NaN sorts to the front
            # of the reversed argsort; dividing by 1 scores it 0 instead.
            norms[norms == 0] = 1.0
            sims = np.dot(sent_embs / norms, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    """Compute the gensim ``c_v`` coherence of the fitted topics.

    Args:
        topic_model: Object whose ``get_topic(topic_id)`` yields
            ``(word, score)`` pairs.
        topic_metadata: Dicts carrying ``"topic_id"`` and ``"sentences"``.
        topk: Maximum number of leading words taken from each topic.

    Returns:
        The value of ``CoherenceModel.get_coherence()``.
    """
    topic_word_lists = [
        [word for word, _ in topic_model.get_topic(meta["topic_id"])[:topk]]
        for meta in topic_metadata
    ]

    # Reference corpus: every stored context sentence, whitespace-tokenised.
    texts = [
        clean_text(sent).split()
        for meta in topic_metadata
        for sent in meta["sentences"]
    ]
    dictionary = Dictionary(texts)

    # The coherence model receives the tokenised texts directly; no
    # bag-of-words corpus is needed for it.
    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    """Return the share of unique words among the top words of all topics.

    The result is ``len(unique words) / (number of topics * topk)``.

    Args:
        topic_model: Object whose ``get_topic(topic_id)`` yields
            ``(word, score)`` pairs.  A falsy result contributes no words
            (NOTE(review): BERTopic appears to return ``False`` for an
            unknown topic id -- confirm).
        topic_metadata: Dicts carrying ``"topic_id"``.
        topk: Number of leading words considered per topic.

    Returns:
        The diversity ratio, or ``0.0`` when there are no topics or ``topk``
        is not positive (instead of dividing by zero).
    """
    denominator = len(topic_metadata) * topk
    if denominator <= 0:
        return 0.0

    unique_words = {
        word
        for meta in topic_metadata
        for word, _ in (topic_model.get_topic(meta["topic_id"]) or [])[:topk]
    }
    return len(unique_words) / denominator

def compute_silhouette_score_custom(topic_metadata):
    """Cosine silhouette score of the stored embeddings, labelled by topic.

    Args:
        topic_metadata: Dicts carrying ``"topic_id"`` and
            ``"sentence_embeddings"``.

    Returns:
        The silhouette score, or ``None`` when there is no metadata or the
        number of distinct topics is below 2 or above ``n_samples - 1``.
    """
    per_topic_embeddings = [meta["sentence_embeddings"] for meta in topic_metadata]
    if not per_topic_embeddings:
        return None

    # One label per embedding row, in the same order the rows are stacked.
    point_labels = [
        meta["topic_id"]
        for meta in topic_metadata
        for _ in range(len(meta["sentence_embeddings"]))
    ]

    stacked = np.vstack(per_topic_embeddings)
    sample_count = stacked.shape[0]
    label_count = len(set(point_labels))

    if not (2 <= label_count <= sample_count - 1):
        return None

    return silhouette_score(stacked, point_labels, metric="cosine")

# === DATASET & INITIALIZATION ===
# Five single-sentence chunks about food allergies, each with the entities
# annotated for it (aligned by position).
allergy_dataset = {
    "chunks": [
        "Food allergies can cause a range of reactions from mild hives to severe anaphylaxis.",
        "Common food allergens include peanuts, tree nuts, milk, eggs, wheat, soy, fish, and shellfish.",
        "Anaphylaxis requires immediate treatment with epinephrine.",
        "Symptoms can include swelling, difficulty breathing, and rash.",
        "Avoidance of allergens is key to management.",
    ],
    "entities": [
        ["food allergies", "reactions", "hives", "anaphylaxis"],
        ["common food allergens", "peanuts", "tree nuts", "milk", "eggs", "wheat", "soy", "fish", "shellfish"],
        ["anaphylaxis", "treatment", "epinephrine"],
        ["symptoms", "swelling", "difficulty breathing", "rash"],
        ["avoidance", "management"],
    ],
}

# Clustering hyper-parameters handed to UMAP and HDBSCAN.
best_umap = dict(n_neighbors=5, n_components=5, min_dist=0.1, metric="cosine")
best_hdbscan = dict(min_cluster_size=2, min_samples=1, metric="euclidean")

print("Preparing Allergy Topic Searcher...")
searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="all-mpnet-base-v2",
)
print("✅ Model ready for querying.")

# === METRICS ===
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
# The silhouette helper returns None when the score is undefined.
print(
    "📐 Silhouette Score: Not applicable."
    if sil_score is None
    else f"📐 Silhouette Score: {sil_score:.4f}"
)

# === GROUND TRUTH TOPICS ===
# Reference grouping used by the evaluation: one topic per chunk.
ground_truth_topics = [
    {"topic_id": topic_id, "entities": entities}
    for topic_id, entities in (
        ("T1", ["food allergies", "reactions", "hives", "anaphylaxis"]),
        ("T2", ["common food allergens", "peanuts", "tree nuts", "milk", "eggs", "wheat", "soy", "fish", "shellfish"]),
        ("T3", ["anaphylaxis", "treatment", "epinephrine"]),
        ("T4", ["symptoms", "swelling", "difficulty breathing", "rash"]),
        ("T5", ["avoidance", "management"]),
    )
]

# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    """Return the entities lowercased and stripped of surrounding whitespace."""
    cleaned = []
    for entity in entities:
        cleaned.append(entity.lower().strip())
    return cleaned

def jaccard_similarity(set1, set2):
    """Return |A ∩ B| / |A ∪ B| for two collections, or 0.0 if both are empty."""
    left = set(set1)
    right = set(set2)
    union = left | right
    if not union:
        return 0.0
    return len(left & right) / len(union)

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Matching model topics to ground truth.
# Each model topic is paired with the ground-truth topic of highest Jaccard
# overlap.  A model topic with zero overlap everywhere stays unmatched, and
# several model topics may land on the same ground-truth topic.
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    for gt in ground_truth_topics:
        score = jaccard_similarity(mt["entities"], normalize(gt["entities"]))
        if score > best_score:
            best_score = score
            best_gt = gt
    if best_gt:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(normalize(best_gt["entities"]))

# Entity-level metrics: over the union of entities seen in matched topics,
# an entity is "true" if ground truth contains it and "predicted" if the
# model contains it.
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(model_entity_counter) | set(gt_entity_counter))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

if matches:
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    average_jaccard = sum(score for _, _, score in matches) / len(matches)
else:
    # Nothing matched: report zeros instead of dividing by zero or scoring
    # empty label lists.
    precision = recall = f1 = 0.0
    average_jaccard = 0.0

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {average_jaccard:.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({(len(matched_gt_ids)/len(ground_truth_topics))*100:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
print("\n=== Allergy Topic Search ===")
while True:
    try:
        query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # Input stream closed or the user interrupted: leave the loop cleanly.
        print("Goodbye!")
        break
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break
    if not query:
        # Nothing to search for; prompt again.
        continue

    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Preparing Allergy Topic Searcher...

=== Topics and Associated Entities ===
Topic ID: 1, Entities: reactions, food allergies, anaphylaxis, hives
Topic ID: 0, Entities: peanuts, common food allergens, eggs, wheat, shellfish, milk, tree nuts, fish, soy
Topic ID: 3, Entities: treatment, anaphylaxis, epinephrine
Topic ID: 2, Entities: swelling, difficulty breathing, symptoms, rash
Topic ID: 4, Entities: avoidance, management
✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.9326
🌈 Topic Diversity: 0.5200
📐 Silhouette Score: 0.8557

=== 📊 Topic Matching Summary ===
🔗 Model Topic 1 ↔ Ground Truth T1 — Jaccard: 1.00
🔗 Model Topic 0 ↔ Ground Truth T2 — Jaccard: 1.00
🔗 Model Topic 3 ↔ Ground Truth T3 — Jaccard: 1.00
🔗 Model Topic 2 ↔ Ground Truth T4 — Jaccard: 1.00
🔗 Model Topic 4 ↔ Ground Truth T5 — Jaccard: 1.00

🧮 Average Jaccard Similarity: 1.0000
📈 Ground Truth Coverage: 5/5 (100.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 1.0000
🧲 Recall:    1.0

In [None]:
# Experiment on dataset 4

In [None]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score

# Reproducibility setup: seed every RNG this cell controls and request
# deterministic kernels from PyTorch.
SEED = 42

# NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
# assignment presumably only reaches child processes -- confirm whether the
# running kernel's hash seed matters for this experiment.
os.environ["PYTHONHASHSEED"] = str(SEED)
# Exported before deterministic mode is enabled and before any CUDA work, so
# cuBLAS already sees the workspace setting when it is first used.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sentence-tokenizer data for sent_tokenize.  NOTE(review): newer NLTK
# releases appear to load "punkt_tab" instead of "punkt"; fetching both keeps
# this cell runnable on its own -- confirm against the installed NLTK version.
nltk.download("punkt")
nltk.download("punkt_tab")

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    """Lowercase *text*, drop punctuation, and collapse whitespace.

    Every character that is neither a word character nor whitespace is
    removed; each run of whitespace then becomes a single space and the
    result is trimmed at both ends.
    """
    lowered = text.lower()
    without_punctuation = re.sub(r'[^\w\s]', '', lowered)
    single_spaced = re.sub(r'\s+', ' ', without_punctuation)
    return single_spaced.strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    """Pair every entity with the cleaned sentence context it occurs in.

    Each chunk is split into sentences on its raw text and every sentence is
    then passed through ``clean_text``.  Splitting first matters: cleaning
    removes the punctuation the sentence tokenizer relies on, so cleaning the
    whole chunk up front would leave it as one unsplittable sentence.

    Args:
        chunks: Text chunks, indexed in step with ``entities_per_chunk``.
        entities_per_chunk: One iterable of entity strings per chunk.
        use_multi_sentence: When true, the context is the first matching
            sentence plus its immediate neighbours; otherwise only the
            matching sentence.

    Returns:
        A list of ``(lowercased_entity, context)`` tuples, one per entity.
        An entity found in no sentence is paired with the whole cleaned
        chunk.
    """
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        raw_chunk = chunks[idx]
        chunk = clean_text(raw_chunk)
        # Tokenize the raw text, then clean sentence by sentence; sentences
        # that clean down to nothing (pure punctuation) are dropped.
        cleaned_sentences = (clean_text(raw) for raw in sent_tokenize(raw_chunk))
        sentences = [sent for sent in cleaned_sentences if sent]

        for ent in ents:
            ent_lower = ent.lower()
            # Sentences carry no punctuation, so match on the cleaned entity;
            # fall back to the raw lowercase form if cleaning empties it.
            needle = clean_text(ent) or ent_lower
            context = chunk
            for i, sent in enumerate(sentences):
                if needle in sent:
                    if use_multi_sentence:
                        context = " ".join(sentences[max(0, i - 1): i + 2]).strip()
                    else:
                        context = sent.strip()
                    break
            entity_context_pairs.append((ent_lower, context))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    """Topic index over entity contexts with nearest-topic retrieval.

    Construction runs the whole pipeline: entity contexts are extracted from
    the chunks, embedded with a SentenceTransformer, clustered by BERTopic
    (UMAP + HDBSCAN), and the normalised mean embedding of every topic is
    stored in a ScaNN dot-product index.  ``search`` maps a free-text query
    to its nearest topics and the best-matching context sentences of each.
    """

    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-mpnet-base-v2"):
        """Load the embedding model and build the topic index.

        Args:
            chunks: Text chunks, indexable by position.
            entities_per_chunk: One list of entity strings per chunk, aligned
                with ``chunks`` by index.
            umap_params: Keyword arguments forwarded to ``UMAP``.
            hdbscan_params: Keyword arguments forwarded to ``HDBSCAN``
                (``prediction_data=True`` is added on top).
            model_name: Name of the SentenceTransformer model to load.

        Raises:
            ValueError: If no entity-context pairs can be extracted.
            RuntimeError: If clustering yields no topic embeddings.
        """
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        # Filled in by _prepare().
        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        """Fit the topic model and build the ScaNN index over topic centroids."""
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        # Each document is "<entity>: <context>" so the entity itself
        # contributes to the embedding.
        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Group contexts, entities and embeddings by assigned topic id.
        # NOTE(review): no topic id is filtered out here, so an outlier label
        # (if the clusterer emits one) is indexed like any other topic --
        # confirm that is intended.
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for topic, (ent, context), embedding in zip(topics, entity_context_pairs, contextual_embeddings):
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(embedding)

        topic_embeddings = []
        topic_metadata = []

        for topic_id, contexts in topic_to_contexts.items():
            embeddings = topic_to_embeddings[topic_id]
            # Unit-length centroid, so dot product equals cosine similarity.
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": contexts,
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        print("\n=== Topics and Associated Entities ===")
        for meta in self.topic_metadata:
            print(f"Topic ID: {meta['topic_id']}, Entities: {', '.join(meta['entities'])}")

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        # Never ask the index to probe more leaves than the tree contains
        # (a single-topic corpus yields a one-leaf tree).
        leaves_to_search = min(2, num_clusters)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(
                num_leaves=num_clusters,
                num_leaves_to_search=leaves_to_search,
                training_sample_size=len(self.topic_embeddings),
            )
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        """Return the topics nearest to *query* with their best sentences.

        Args:
            query: Free-text query string.
            top_k_topics: Number of topics requested from the index.
            top_k_sents: Maximum number of sentences returned per topic.

        Returns:
            A list with one dict per retrieved topic holding ``"topic_id"``,
            ``"entities"`` and ``"sentences"``; the latter is a list of
            ``(sentence, similarity)`` tuples ordered by descending
            similarity to the query.
        """
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]

            # Several entities can share one context sentence; keep the first
            # occurrence of each sentence together with its embedding.
            first_seen = {}
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                first_seen.setdefault(sent, emb)
            unique_sentences = list(first_seen)

            sent_embs = np.array(list(first_seen.values()))
            norms = np.linalg.norm(sent_embs, axis=1, keepdims=True)
            # A zero vector would divide to NaN, and NaN sorts to the front
            # of the reversed argsort; dividing by 1 scores it 0 instead.
            norms[norms == 0] = 1.0
            sims = np.dot(sent_embs / norms, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    """Compute the gensim ``c_v`` coherence of the topics in ``topic_metadata``.

    Args:
        topic_model: Object whose ``get_topic(topic_id)`` returns a sequence
            of ``(word, score)`` pairs (the fitted BERTopic model).
        topic_metadata: Iterable of dicts with at least ``"topic_id"`` and
            ``"sentences"`` keys.
        topk: Maximum number of top words taken from each topic.

    Returns:
        The value of ``CoherenceModel.get_coherence()``.
    """
    topic_word_lists = [
        [word for word, _ in topic_model.get_topic(meta["topic_id"])[:topk]]
        for meta in topic_metadata
    ]

    # Reference corpus: every context sentence of every topic, tokenised by
    # whitespace after the same cleaning used elsewhere in the pipeline.
    texts = [
        clean_text(sent).split()
        for meta in topic_metadata
        for sent in meta["sentences"]
    ]

    # c_v works from the tokenised texts and the dictionary; the original
    # bag-of-words ``corpus`` was built but never used, so it is dropped.
    dictionary = Dictionary(texts)
    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    """Return the share of unique words among the top words of all topics.

    Args:
        topic_model: Object whose ``get_topic(topic_id)`` returns a sequence
            of ``(word, score)`` pairs.
        topic_metadata: Iterable of dicts with a ``"topic_id"`` key.
        topk: Number of top words considered per topic.

    Returns:
        ``unique words / (number of topics * topk)``, or ``0.0`` when there
        are no topics or ``topk`` is not positive (the original raised
        ``ZeroDivisionError`` there).
    """
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    total_slots = len(topics) * topk
    if total_slots <= 0:
        return 0.0
    unique_words = {word for topic in topics for word, _ in topic}
    return len(unique_words) / total_slots

def compute_silhouette_score_custom(topic_metadata):
    """Cosine silhouette score of the sentence embeddings grouped by topic.

    Args:
        topic_metadata: Iterable of dicts with ``"sentence_embeddings"``
            (one row per sentence) and ``"topic_id"`` keys.

    Returns:
        The silhouette score, or ``None`` when it is undefined: no topics,
        fewer than two distinct topic ids, or more distinct ids than
        ``n_samples - 1``.
    """
    per_topic = [
        (meta["sentence_embeddings"], meta["topic_id"]) for meta in topic_metadata
    ]
    if not per_topic:
        return None

    matrix = np.vstack([embs for embs, _ in per_topic])
    # One label per embedding row, in the same order as the stacked matrix.
    labels = [tid for embs, tid in per_topic for _ in range(len(embs))]

    distinct = len(set(labels))
    if not (2 <= distinct <= matrix.shape[0] - 1):
        return None

    return silhouette_score(matrix, labels, metric="cosine")

# === DATASET & INITIALIZATION ===
# Toy corpus: "chunks" holds the raw texts and "entities" the entity list
# annotated for the chunk at the same position.
allergy_dataset = {
    "chunks": [
        "Allergic rhinitis is caused by an allergic response to airborne particles like pollen, dust, and pet dander.",
        "Symptoms include sneezing, nasal congestion, runny nose, and itchy eyes.",
        "Treatment options include antihistamines, nasal corticosteroids, and avoiding triggers.",
        "Common triggers include pollen, dust mites, mold, and pet dander.",
        "Seasonal allergic rhinitis is often worse during pollen season.",
    ],
    "entities": [
        ["allergic rhinitis", "allergic response", "airborne particles", "pollen", "dust", "pet dander"],
        ["symptoms", "sneezing", "nasal congestion", "runny nose", "itchy eyes"],
        ["treatment", "antihistamines", "nasal corticosteroids", "avoiding triggers"],
        ["common triggers", "pollen", "dust mites", "mold", "pet dander"],
        ["seasonal allergic rhinitis", "pollen season"],
    ],
}

# Settings handed to the searcher as umap_params / hdbscan_params.
best_umap = dict(n_neighbors=5, n_components=5, min_dist=0.1, metric="cosine")
best_hdbscan = dict(min_cluster_size=2, min_samples=1, metric="euclidean")

print("Preparing Allergy Topic Searcher...")
searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="all-mpnet-base-v2",
)
print("✅ Model ready for querying.")

# === METRICS ===
# Intrinsic topic-quality scores computed from the fitted model.
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
# The silhouette helper returns None when the score is undefined.
if sil_score is None:
    print("📐 Silhouette Score: Not applicable.")
else:
    print(f"📐 Silhouette Score: {sil_score:.4f}")

# === GROUND TRUTH TOPICS ===
# Reference grouping: one expected topic per chunk, holding that chunk's entities.
ground_truth_topics = [
    {
        "topic_id": "T1",
        "entities": ["allergic rhinitis", "allergic response", "airborne particles", "pollen", "dust", "pet dander"],
    },
    {
        "topic_id": "T2",
        "entities": ["symptoms", "sneezing", "nasal congestion", "runny nose", "itchy eyes"],
    },
    {
        "topic_id": "T3",
        "entities": ["treatment", "antihistamines", "nasal corticosteroids", "avoiding triggers"],
    },
    {
        "topic_id": "T4",
        "entities": ["common triggers", "pollen", "dust mites", "mold", "pet dander"],
    },
    {
        "topic_id": "T5",
        "entities": ["seasonal allergic rhinitis", "pollen season"],
    },
]

# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    """Return the entities lower-cased and stripped of surrounding whitespace."""
    cleaned = []
    for entity in entities:
        cleaned.append(entity.lower().strip())
    return cleaned

def jaccard_similarity(set1, set2):
    """Return |A ∩ B| / |A ∪ B| for two iterables, or 0.0 when both are empty."""
    left = set(set1)
    right = set(set2)
    union = left | right
    if not union:
        return 0.0
    return len(left & right) / len(union)

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Normalise every ground-truth entity list once instead of once per comparison.
normalized_gt_topics = [(gt, normalize(gt["entities"])) for gt in ground_truth_topics]

# Matching model topics to ground truth: each model topic is paired with the
# ground-truth topic of highest Jaccard overlap. Topics with no overlap at all
# stay unmatched, and several model topics may pick the same ground truth.
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    best_gt_entities = []
    for gt, gt_entities in normalized_gt_topics:
        score = jaccard_similarity(mt["entities"], gt_entities)
        if score > best_score:
            best_score = score
            best_gt = gt
            best_gt_entities = gt_entities
    if best_gt is not None:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(best_gt_entities)

# Entity-level metrics: every entity seen on either side is one binary sample.
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(model_entity_counter) | set(gt_entity_counter))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

if unique_entities:
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
else:
    # Nothing matched, so there are no samples to score.
    precision = recall = f1 = 0.0

# Guard the averages: an empty match list or empty ground truth used to raise
# ZeroDivisionError here.
average_jaccard = sum(score for _, _, score in matches) / len(matches) if matches else 0.0
coverage_pct = (
    (len(matched_gt_ids) / len(ground_truth_topics)) * 100 if ground_truth_topics else 0.0
)

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {average_jaccard:.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({coverage_pct:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
# Interactive prompt: look up the best topic for each query until the user
# types "exit"/"quit", closes stdin, or interrupts the prompt.
print("\n=== Allergy Topic Search ===")
while True:
    try:
        query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # End of piped input or Ctrl-C: leave the loop instead of crashing
        # with a traceback.
        print("Goodbye!")
        break
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Preparing Allergy Topic Searcher...

=== Topics and Associated Entities ===
Topic ID: 0, Entities: pollen, allergic rhinitis, seasonal allergic rhinitis, dust, pollen season, allergic response, airborne particles, pet dander
Topic ID: 1, Entities: runny nose, itchy eyes, sneezing, symptoms, nasal congestion
Topic ID: 3, Entities: treatment, nasal corticosteroids, avoiding triggers, antihistamines
Topic ID: 2, Entities: pollen, common triggers, dust mites, mold, pet dander
✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.8624
🌈 Topic Diversity: 0.7250
📐 Silhouette Score: 0.6324

=== 📊 Topic Matching Summary ===
🔗 Model Topic 0 ↔ Ground Truth T1 — Jaccard: 0.75
🔗 Model Topic 1 ↔ Ground Truth T2 — Jaccard: 1.00
🔗 Model Topic 3 ↔ Ground Truth T3 — Jaccard: 1.00
🔗 Model Topic 2 ↔ Ground Truth T4 — Jaccard: 1.00

🧮 Average Jaccard Similarity: 0.9375
📈 Ground Truth Coverage: 4/5 (80.0%)

=== 🧠 Entity-Level Evaluation ===
🎯 Precision: 0.9000
🧲 Recall:    1.

In [None]:
# Experiment on dataset 5

In [None]:
# === IMPORTS & SETUP ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import re

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora import Dictionary
from sklearn.metrics import silhouette_score

# Fix every random source this cell controls so repeated runs are comparable.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
# NOTE(review): Python normally reads PYTHONHASHSEED at interpreter start-up,
# so setting it here presumably only affects child processes — confirm.
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
# NOTE(review): looks like the cuBLAS workspace setting that PyTorch asks for
# when deterministic algorithms are enabled on CUDA — confirm it is picked up
# when assigned after the call above.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Tokenizer data for sent_tokenize.
# NOTE(review): the install cell also downloads "punkt_tab"; confirm "punkt"
# alone is enough for the installed NLTK version.
nltk.download("punkt")

# === CLEANING & CONTEXT EXTRACTION ===
def clean_text(text):
    """Lower-case ``text``, drop punctuation and collapse runs of whitespace.

    Every character that is neither a word character nor whitespace is
    removed, whitespace runs become a single space, and the result is
    stripped at both ends.
    """
    lowered = text.lower()
    without_punct = re.sub(r'[^\w\s]', '', lowered)
    return re.sub(r'\s+', ' ', without_punct).strip()

def extract_entity_contexts(chunks, entities_per_chunk, use_multi_sentence=True):
    """Pair every annotated entity with the cleaned text surrounding it.

    Args:
        chunks: Sequence of raw text chunks.
        entities_per_chunk: Sequence of entity lists; the list at position
            ``i`` belongs to ``chunks[i]``.
        use_multi_sentence: When True the context is the matching sentence
            plus its immediate neighbours; otherwise only the matching
            sentence.

    Returns:
        A list of ``(entity_lowercased, context)`` tuples, one per entity.
        If the entity is not found in any sentence, the context is the whole
        cleaned chunk.
    """
    entity_context_pairs = []
    for idx, ents in enumerate(entities_per_chunk):
        raw_chunk = chunks[idx]
        chunk = clean_text(raw_chunk)
        # Split the RAW text into sentences and clean afterwards. The original
        # cleaned first, which strips the punctuation and casing the sentence
        # tokenizer relies on, so every chunk came back as a single sentence
        # and the multi-sentence window below never took effect.
        sentences = [
            cleaned
            for cleaned in (clean_text(sent) for sent in sent_tokenize(raw_chunk))
            if cleaned
        ]
        for ent in ents:
            ent_lower = ent.lower()
            # Match on the cleaned entity so punctuation inside an entity
            # (hyphens, for example) is treated like the cleaned sentences.
            needle = clean_text(ent)
            context = chunk  # fallback when no sentence mentions the entity
            if needle:
                for i, sent in enumerate(sentences):
                    if needle in sent:
                        context = (
                            " ".join(sentences[max(0, i - 1): i + 2])
                            if use_multi_sentence else sent
                        )
                        break
            entity_context_pairs.append((ent_lower, context))
    return entity_context_pairs

# === TOPIC SEARCHER CLASS ===
class AllergyTopicSearcher:
    """Cluster entity contexts into topics and search them by free-text query.

    On construction the instance embeds one ``"entity: context"`` text per
    annotated entity, fits a BERTopic model (UMAP + HDBSCAN +
    KeyBERTInspired) on those texts, averages the embeddings of each topic
    into an L2-normalised topic vector and indexes those vectors with ScaNN.

    Attributes filled in by ``_prepare``:
        topic_model: The fitted ``BERTopic`` instance.
        topic_metadata: One dict per topic with keys ``"topic_id"``,
            ``"entities"``, ``"sentences"`` and ``"sentence_embeddings"``.
        topic_embeddings: Array of topic vectors, row-aligned with
            ``topic_metadata``.
        searcher: The ScaNN searcher built over ``topic_embeddings``.
    """

    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-mpnet-base-v2"):
        """Load the embedding model and immediately build topics and index.

        Args:
            chunks: Sequence of raw text chunks.
            entities_per_chunk: Entity lists, position-aligned with ``chunks``.
            umap_params: Keyword arguments for ``UMAP``.
            hdbscan_params: Keyword arguments for ``HDBSCAN``.
            model_name: Name passed to ``SentenceTransformer``.
        """
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        """Fit the topic model and build the ScaNN index over topic vectors.

        Raises:
            ValueError: If no entity/context pair could be extracted.
            RuntimeError: If there are no topic embeddings to index.
        """
        entity_context_pairs = extract_entity_contexts(
            self.chunks,
            self.entities_per_chunk,
            use_multi_sentence=True
        )

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Group contexts, entities and embeddings by assigned topic id.
        # NOTE(review): an outlier label, if the clusterer emits one, is kept
        # as a regular topic here — confirm that is intended.
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            # Topic vector = mean of member embeddings, re-normalised; the
            # epsilon avoids dividing by zero for a degenerate mean.
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        print("\n=== Topics and Associated Entities ===")
        for meta in self.topic_metadata:
            print(f"Topic ID: {meta['topic_id']}, Entities: {', '.join(meta['entities'])}")

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        # Never ask the tree to search more leaves than it has (the single
        # topic case would otherwise request 2 leaves out of 1).
        leaves_to_search = min(2, num_clusters)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(
                num_leaves=num_clusters,
                num_leaves_to_search=leaves_to_search,
                training_sample_size=len(self.topic_embeddings),
            )
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        """Return the best-matching topics for ``query`` with their top sentences.

        Args:
            query: Free-text query string.
            top_k_topics: Number of topics requested from the index.
            top_k_sents: Maximum number of sentences returned per topic.

        Returns:
            A list of dicts with keys ``"topic_id"``, ``"entities"`` and
            ``"sentences"``; ``"sentences"`` is a list of
            ``(sentence, similarity)`` tuples, most similar first.
        """
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]

            # Several entities can share one context sentence; keep only the
            # first occurrence (dicts preserve insertion order).
            first_embedding = {}
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                first_embedding.setdefault(sent, emb)
            unique_sentences = list(first_embedding)

            top_sents = []
            if unique_sentences:
                sent_embs = np.array(list(first_embedding.values()))
                norms = np.linalg.norm(sent_embs, axis=1, keepdims=True)
                # A zero vector would divide to NaN, and NaN sorts to the
                # front of the reversed argsort below, i.e. it would be
                # reported as the *best* match. Leave such rows at zero.
                norms[norms == 0] = 1.0
                sims = np.dot(sent_embs / norms, query_emb)
                top_indices = sims.argsort()[::-1][:top_k_sents]
                top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
            })

        return results

# === EVALUATION METRICS ===
def compute_bertopic_coherence(topic_model, topic_metadata, topk=15):
    """Compute the gensim ``c_v`` coherence of the topics in ``topic_metadata``.

    Args:
        topic_model: Object whose ``get_topic(topic_id)`` returns a sequence
            of ``(word, score)`` pairs (the fitted BERTopic model).
        topic_metadata: Iterable of dicts with at least ``"topic_id"`` and
            ``"sentences"`` keys.
        topk: Maximum number of top words taken from each topic.

    Returns:
        The value of ``CoherenceModel.get_coherence()``.
    """
    topic_word_lists = [
        [word for word, _ in topic_model.get_topic(meta["topic_id"])[:topk]]
        for meta in topic_metadata
    ]

    # Reference corpus: every context sentence of every topic, tokenised by
    # whitespace after cleaning.
    texts = [
        clean_text(sent).split()
        for meta in topic_metadata
        for sent in meta["sentences"]
    ]

    # c_v works from the tokenised texts and the dictionary; the original
    # bag-of-words ``corpus`` was built but never used, so it is dropped.
    dictionary = Dictionary(texts)
    coherence_model = CoherenceModel(
        topics=topic_word_lists,
        texts=texts,
        dictionary=dictionary,
        coherence="c_v"
    )
    return coherence_model.get_coherence()

def compute_topic_diversity(topic_model, topic_metadata, topk=10):
    """Return the share of unique words among the top words of all topics.

    Args:
        topic_model: Object whose ``get_topic(topic_id)`` returns a sequence
            of ``(word, score)`` pairs.
        topic_metadata: Iterable of dicts with a ``"topic_id"`` key.
        topk: Number of top words considered per topic.

    Returns:
        ``unique words / (number of topics * topk)``, or ``0.0`` when there
        are no topics or ``topk`` is not positive (the original raised
        ``ZeroDivisionError`` there).
    """
    topics = [topic_model.get_topic(meta["topic_id"])[:topk] for meta in topic_metadata]
    total_slots = len(topics) * topk
    if total_slots <= 0:
        return 0.0
    unique_words = {word for topic in topics for word, _ in topic}
    return len(unique_words) / total_slots

def compute_silhouette_score_custom(topic_metadata):
    """Cosine silhouette score of the sentence embeddings grouped by topic.

    Args:
        topic_metadata: Iterable of dicts with ``"sentence_embeddings"``
            (one row per sentence) and ``"topic_id"`` keys.

    Returns:
        The silhouette score, or ``None`` when it is undefined: no topics,
        fewer than two distinct topic ids, or more distinct ids than
        ``n_samples - 1``.
    """
    per_topic = [
        (meta["sentence_embeddings"], meta["topic_id"]) for meta in topic_metadata
    ]
    if not per_topic:
        return None

    matrix = np.vstack([embs for embs, _ in per_topic])
    # One label per embedding row, in the same order as the stacked matrix.
    labels = [tid for embs, tid in per_topic for _ in range(len(embs))]

    distinct = len(set(labels))
    if not (2 <= distinct <= matrix.shape[0] - 1):
        return None

    return silhouette_score(matrix, labels, metric="cosine")

# === DATASET & INITIALIZATION ===
# Toy corpus: "chunks" holds the raw texts and "entities" the entity list
# annotated for the chunk at the same position.
allergy_dataset = {
    "chunks": [
        "The patient reports frequent migraines, especially after long periods of screen exposure.",
        "A family history of migraines is noted, particularly on the maternal side.",
        "Neurological imaging (MRI) showed no abnormalities.",
        "The patient consumes high amounts of caffeine and has irregular sleep patterns.",
        "Preventive strategies include regular sleep hygiene, reduced caffeine, and stress management.",
        "The patient has tried multiple over-the-counter pain relievers with limited success.",
        "Genetic testing revealed a polymorphism in the CACNA1A gene, associated with familial hemiplegic migraine.",
    ],
    "entities": [
        ["migraines", "screen exposure", "frequent headaches"],
        ["family history", "maternal side", "migraines"],
        ["neurological imaging", "MRI", "no abnormalities"],
        ["caffeine", "irregular sleep", "sleep patterns"],
        ["preventive strategies", "sleep hygiene", "stress management", "reduced caffeine"],
        ["pain relievers", "limited success", "over-the-counter medications"],
        ["genetic testing", "CACNA1A", "familial hemiplegic migraine", "polymorphism"],
    ],
}

# Settings handed to the searcher as umap_params / hdbscan_params.
best_umap = dict(n_neighbors=5, n_components=5, min_dist=0.1, metric="cosine")
best_hdbscan = dict(min_cluster_size=2, min_samples=1, metric="euclidean")

print("Preparing Allergy Topic Searcher...")
searcher = AllergyTopicSearcher(
    chunks=allergy_dataset["chunks"],
    entities_per_chunk=allergy_dataset["entities"],
    umap_params=best_umap,
    hdbscan_params=best_hdbscan,
    model_name="all-mpnet-base-v2",
)
print("✅ Model ready for querying.")

# === METRICS ===
# Intrinsic topic-quality scores computed from the fitted model.
coherence = compute_bertopic_coherence(searcher.topic_model, searcher.topic_metadata, topk=15)
diversity = compute_topic_diversity(searcher.topic_model, searcher.topic_metadata, topk=10)
sil_score = compute_silhouette_score_custom(searcher.topic_metadata)

print("\n=== Topic Quality Metrics ===")
print(f"🧪 Coherence Score (c_v): {coherence:.4f}")
print(f"🌈 Topic Diversity: {diversity:.4f}")
# The silhouette helper returns None when the score is undefined.
if sil_score is None:
    print("📐 Silhouette Score: Not applicable.")
else:
    print(f"📐 Silhouette Score: {sil_score:.4f}")

# === GROUND TRUTH TOPICS ===
# Reference grouping: one expected topic per chunk, holding that chunk's entities.
ground_truth_topics = [
    {
        "topic_id": "T1",
        "entities": ["migraines", "screen exposure", "frequent headaches"],
    },
    {
        "topic_id": "T2",
        "entities": ["family history", "maternal side", "migraines"],
    },
    {
        "topic_id": "T3",
        "entities": ["neurological imaging", "MRI", "no abnormalities"],
    },
    {
        "topic_id": "T4",
        "entities": ["caffeine", "irregular sleep", "sleep patterns"],
    },
    {
        "topic_id": "T5",
        "entities": ["preventive strategies", "sleep hygiene", "stress management", "reduced caffeine"],
    },
    {
        "topic_id": "T6",
        "entities": ["pain relievers", "limited success", "over-the-counter medications"],
    },
    {
        "topic_id": "T7",
        "entities": ["genetic testing", "CACNA1A", "familial hemiplegic migraine", "polymorphism"],
    },
]

# === EVALUATION CODE ===
from collections import Counter
from sklearn.metrics import precision_recall_fscore_support

def normalize(entities):
    """Return the entities lower-cased and stripped of surrounding whitespace."""
    cleaned = []
    for entity in entities:
        cleaned.append(entity.lower().strip())
    return cleaned

def jaccard_similarity(set1, set2):
    """Return |A ∩ B| / |A ∪ B| for two iterables, or 0.0 when both are empty."""
    left = set(set1)
    right = set(set2)
    union = left | right
    if not union:
        return 0.0
    return len(left & right) / len(union)

# Prepare model topics
model_topics = [
    {"topic_id": meta["topic_id"], "entities": normalize(meta["entities"])}
    for meta in searcher.topic_metadata
]

# Normalise every ground-truth entity list once instead of once per comparison.
normalized_gt_topics = [(gt, normalize(gt["entities"])) for gt in ground_truth_topics]

# Matching model topics to ground truth: each model topic is paired with the
# ground-truth topic of highest Jaccard overlap. Topics with no overlap at all
# stay unmatched, and several model topics may pick the same ground truth.
matched_gt_ids = set()
matches = []
all_model_entities = []
all_gt_entities = []

for mt in model_topics:
    best_score = 0
    best_gt = None
    best_gt_entities = []
    for gt, gt_entities in normalized_gt_topics:
        score = jaccard_similarity(mt["entities"], gt_entities)
        if score > best_score:
            best_score = score
            best_gt = gt
            best_gt_entities = gt_entities
    if best_gt is not None:
        matches.append((mt["topic_id"], best_gt["topic_id"], best_score))
        matched_gt_ids.add(best_gt["topic_id"])

        # Collect entities for entity-level precision/recall
        all_model_entities.extend(mt["entities"])
        all_gt_entities.extend(best_gt_entities)

# Entity-level metrics: every entity seen on either side is one binary sample.
model_entity_counter = Counter(all_model_entities)
gt_entity_counter = Counter(all_gt_entities)

unique_entities = list(set(model_entity_counter) | set(gt_entity_counter))
y_true = [gt_entity_counter[e] > 0 for e in unique_entities]
y_pred = [model_entity_counter[e] > 0 for e in unique_entities]

if unique_entities:
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
else:
    # Nothing matched, so there are no samples to score.
    precision = recall = f1 = 0.0

# Guard the averages: an empty match list or empty ground truth used to raise
# ZeroDivisionError here.
average_jaccard = sum(score for _, _, score in matches) / len(matches) if matches else 0.0
coverage_pct = (
    (len(matched_gt_ids) / len(ground_truth_topics)) * 100 if ground_truth_topics else 0.0
)

# Print evaluation
print("\n=== 📊 Topic Matching Summary ===")
for model_id, gt_id, score in matches:
    print(f"🔗 Model Topic {model_id} ↔ Ground Truth {gt_id} — Jaccard: {score:.2f}")

print(f"\n🧮 Average Jaccard Similarity: {average_jaccard:.4f}")
print(f"📈 Ground Truth Coverage: {len(matched_gt_ids)}/{len(ground_truth_topics)} "
      f"({coverage_pct:.1f}%)")

print("\n=== 🧠 Entity-Level Evaluation ===")
print(f"🎯 Precision: {precision:.4f}")
print(f"🧲 Recall:    {recall:.4f}")
print(f"🏅 F1 Score:  {f1:.4f}")

# === QUERY LOOP ===
# Interactive prompt: look up the best topic for each query until the user
# types "exit"/"quit", closes stdin, or interrupts the prompt.
print("\n=== Allergy Topic Search ===")
while True:
    try:
        query = input("\nEnter a query (or type 'exit' to quit): ").strip()
    except (EOFError, KeyboardInterrupt):
        # End of piped input or Ctrl-C: leave the loop instead of crashing
        # with a traceback.
        print("Goodbye!")
        break
    if query.lower() in {"exit", "quit"}:
        print("Goodbye!")
        break

    results = searcher.search(query, top_k_topics=1, top_k_sents=3)
    print(f"\n🔎 Top results for: '{query}'")
    for res in results:
        print(f"🧠 Topic ID: {res['topic_id']}")
        print(f"🔗 Related Entities: {', '.join(res['entities'])}")
        for sent, _ in res["sentences"]:
            print(f"✓ {sent}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Preparing Allergy Topic Searcher...

=== Topics and Associated Entities ===
Topic ID: 2, Entities: migraines, screen exposure, frequent headaches
Topic ID: 3, Entities: migraines, family history, maternal side
Topic ID: 4, Entities: neurological imaging, no abnormalities, mri
Topic ID: 5, Entities: irregular sleep, caffeine, sleep patterns
Topic ID: 0, Entities: reduced caffeine, stress management, sleep hygiene, preventive strategies
Topic ID: 6, Entities: limited success, pain relievers, over-the-counter medications
Topic ID: 1, Entities: polymorphism, cacna1a, familial hemiplegic migraine, genetic testing
✅ Model ready for querying.

=== Topic Quality Metrics ===
🧪 Coherence Score (c_v): 0.9453
🌈 Topic Diversity: 0.7571
📐 Silhouette Score: 0.8630

=== 📊 Topic Matching Summary ===
🔗 Model Topic 2 ↔ Ground Truth T1 — Jaccard: 1.00
🔗 Model Topic 3 ↔ Ground Truth T2 — Jaccard: 1.00
🔗 Model Topic 4 ↔ Ground Truth T3 — Jaccard: 1.00
🔗 Model Topic 5 ↔ Ground Truth T4 — Jaccard: 1.00
🔗 Mode

# Reference notes for metrics used on topic correctness

In [None]:
# Topic quality / correctness evaluation metrics

| Metric     | What It Tells You                | Your Result | Interpretation                   |
| ---------- | -------------------------------- | ----------- | -------------------------------- |
| Coherence  | Topic semantic quality           | 0.8963      | Very coherent, meaningful topics |
| Diversity  | How distinct topics are          | 0.8250      | High diversity, little overlap   |
| Silhouette | Cluster separation & compactness | 0.7315      | Well-separated, tight clusters   |


**Coherence Score** (c_v): 0.8963

What it measures: How semantically coherent or meaningful the words within each topic are when considered together.

How it works: It compares how often top words in a topic appear together in the actual data (using word co-occurrence and semantic similarity).

Interpretation:

Values range roughly from 0 to 1 (sometimes slightly above 1 in some implementations).

Higher values (closer to 1) mean the topic words are more related and make more sense together.

A score of 0.8963 is quite high, indicating your topics are well-defined and meaningful.

🌈 **Topic Diversity** : 0.8250

What it measures: How diverse or distinct the topics are from each other based on their top words.

How it works: It's the proportion of unique top words across all topics compared to the total number of top words considered.

Interpretation:

Values range from 0 to 1.

Closer to 1 means topics are very different from each other (good diversity).

Lower means topics overlap a lot, sharing many words.

0.8250 means your topics cover a broad range of concepts with relatively little overlap.

📐**Silhouette Score**: 0.7315

What it measures: How well-separated the clusters/topics are based on the embeddings of their sentences/documents.

How it works: It compares the average distance between points in the same cluster to the distance between points in different clusters.

Interpretation:

Ranges from -1 to +1.

Closer to +1 means clusters are well-separated and compact.

Around 0 means clusters are overlapping.

Negative means points may be assigned to the wrong cluster.

0.7315 is a strong positive value, indicating your topic clusters are clearly separated in embedding space.



In [None]:
# Ground-truth evaluation metrics explanation

1. Average Jaccard Similarity
✅ What is it?
This tells us how similar each model-generated topic is to a ground truth topic based on the overlap of entities.

📐 Formula:
For two sets of entities $A$ and $B$:

$$\text{Jaccard}(A, B) = \frac{|A \cap B|}{|A \cup B|}$$

- $|A \cap B|$: number of shared entities

- $|A \cup B|$: total unique entities in both

📊 How it’s used:
For each model topic, we compute the Jaccard score with every ground truth topic and select the best match. Then we take the average of all best-match scores.

📈 Why it matters:
It gives a quantitative view of how well each generated topic matches a real one. High Jaccard → good topic separation and entity grouping.

📈 2. Coverage of Ground Truth Topics
✅ What is it?
This measures how many of the ground truth topics were matched by at least one model topic.

📐 Formula:
$$\text{Coverage} = \frac{\text{number of ground truth topics matched}}{\text{total ground truth topics}} \times 100\%$$

📊 How it’s used:
If the model only clusters around a few topics, this score will be low. A good model should cover most or all ground truth topics.

📈 Why it matters:
Ensures the model isn’t ignoring certain areas or clustering everything into too few topics.

🧠 3. Entity-level Precision, Recall, and F1-score
✅ What is it?
This treats entity extraction like a classification problem:

Does the model include the right entities across topics?

📐 Definitions:
Precision = % of model entities that are correct

Recall = % of ground truth entities that were found

F1 Score = harmonic mean of precision and recall

$$\text{Precision} = \frac{TP}{TP + FP}, \qquad \text{Recall} = \frac{TP}{TP + FN}, \qquad F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$$

Where:

TP = entity exists in both model and ground truth

FP = entity found by model but not in ground truth

FN = ground truth entity missed by the model

📊 How it’s used:
We build a list of all entities found by the model vs ground truth, and compute scores over this list.

📈 Why it matters:
Even if topics are not perfectly matched, accurate entities still matter — e.g., for medical use cases like yours.

