In [1]:
!pip install numpy sentence-transformers bertopic hdbscan nltk scann
import nltk
nltk.download('punkt')
import nltk
nltk.download('punkt_tab')
!pip install sentence-transformers bertopic hdbscan umap-learn scann nltk datasets


Collecting bertopic
  Downloading bertopic-0.17.0-py3-none-any.whl.metadata (23 kB)
Collecting scann
  Downloading scann-1.4.0-cp311-cp311-manylinux_2_27_x86_64.whl.metadata (5.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence-transform

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.




In [None]:
#One Allergic data

In [2]:
# === IMPORTS ===
import os
import random
import numpy as np
import torch
import nltk
import logging

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

# === ENVIRONMENT SETUP ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
nltk.download("punkt")


# === CLASS FOR TOPIC SEARCHER ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-MiniLM-L6-v2"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = []

        for idx, ents in enumerate(self.entities_per_chunk):
            chunk = self.chunks[idx].lower()
            sentences = sent_tokenize(chunk)
            for ent in ents:
                ent_lower = ent.lower()
                for sent in sentences:
                    if ent_lower in sent:
                        entity_context_pairs.append((ent_lower, sent.strip()))
                        break

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        # Embed entity-contexts
        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        # Topic Modeling
        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Metadata aggregation
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=3):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx, score in zip(neighbors, scores):
            meta = self.topic_metadata[idx]
            seen = set()
            unique_sentences = []
            unique_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    unique_sentences.append(sent)
                    unique_embeddings.append(emb)

            sent_embs = np.array(unique_embeddings)
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
                "score": score,
            })

        return results


# === EVALUATION ===
def evaluate(searcher, ground_truth, k=2):
    all_recalls, all_precisions, all_f1s = [], [], []
    better_queries = []

    for query, gt_texts in ground_truth.items():
        results = searcher.search(query, top_k_topics=1, top_k_sents=k)
        retrieved_sents = set()
        for r in results:
            retrieved_sents.update([s for s, _ in r["sentences"]])

        gt_set = set(gt_texts)
        tp = len(retrieved_sents & gt_set)
        fp = len(retrieved_sents) - tp
        fn = len(gt_set) - tp

        recall = tp / (tp + fn) if (tp + fn) else 0.0
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

        all_recalls.append(recall)
        all_precisions.append(precision)
        all_f1s.append(f1)

        if f1 > 0:
            better_queries.append((query, recall, precision, f1, retrieved_sents, gt_set))

    return {
        "avg_recall": np.mean(all_recalls),
        "avg_precision": np.mean(all_precisions),
        "avg_f1": np.mean(all_f1s),
        "better_queries": better_queries
    }


# === DATA: CLINICAL SUMMARIES, ENTITIES, GROUND TRUTH ===
chunks = [
    "Patient has peanut allergy causing hives and swelling. Anaphylaxis noted once during a reaction.",
    "Allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.",
    "Severe anaphylaxis symptoms require immediate treatment with epinephrine.",
    "Food allergies to milk and eggs can cause skin reactions like urticaria and eczema.",
    "Cold weather does not cause allergy symptoms in this patient."
]

entities = [
    ["peanut allergy", "hives", "swelling", "anaphylaxis"],
    ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
    ["anaphylaxis", "epinephrine", "treatment"],
    ["food allergies", "milk", "eggs", "urticaria", "eczema"],
    ["cold weather", "allergy symptoms"]
]

ground_truth = {
    "peanut allergy": [
        "patient has peanut allergy causing hives and swelling. anaphylaxis noted once during a reaction.",
        "severe anaphylaxis symptoms require immediate treatment with epinephrine."
    ],
    "symptoms of anaphylaxis": [
        "severe anaphylaxis symptoms require immediate treatment with epinephrine.",
        "patient has peanut allergy causing hives and swelling. anaphylaxis noted once during a reaction."
    ],
    "hay fever": [
        "allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.",
        "food allergies to milk and eggs can cause skin reactions like urticaria and eczema."
    ],
    "eczema treatment": [
        "food allergies to milk and eggs can cause skin reactions like urticaria and eczema.",
        "severe anaphylaxis symptoms require immediate treatment with epinephrine."
    ],
    "cold weather allergy": [
        "cold weather does not cause allergy symptoms in this patient.",
        "food allergies to milk and eggs can cause skin reactions like urticaria and eczema."
    ]
}


# === GRID SEARCH CONFIGS ===
umap_grid = [
    {'n_neighbors': 15, 'n_components': 5, 'min_dist': 0.0, 'metric': 'cosine', 'random_state': SEED},
    {'n_neighbors': 10, 'n_components': 4, 'min_dist': 0.1, 'metric': 'cosine', 'random_state': SEED}
]

hdbscan_grid = [
    {'min_cluster_size': 2, 'min_samples': 1, 'metric': 'euclidean', 'prediction_data': True},
    {'min_cluster_size': 3, 'min_samples': 1, 'metric': 'euclidean', 'prediction_data': True}
]

results = []
for u in umap_grid:
    for h in hdbscan_grid:
        print(f"\n🔍 Trying UMAP={u} | HDBSCAN={h}")
        try:
            searcher = AllergyTopicSearcher(chunks, entities, u, h)
            metrics = evaluate(searcher, ground_truth)
            results.append((metrics, u, h))
        except Exception as e:
            print(f"❌ Failed config: {e}")

# === Best Result ===
best = sorted(results, key=lambda x: x[0]["avg_f1"], reverse=True)[0]
metrics, best_umap, best_hdbscan = best

print("\n🏆 Best Configuration:")
print(f"UMAP: {best_umap}")
print(f"HDBSCAN: {best_hdbscan}")
print(f"Avg Recall: {metrics['avg_recall']:.2f} | Precision: {metrics['avg_precision']:.2f} | F1: {metrics['avg_f1']:.2f}")

print("\n🎯 Best 2 performing queries and their results:")
for q, recall, precision, f1, retrieved, gt in metrics["better_queries"][:2]:
    print(f"\n🔹 Query: '{q}'")
    print(f"Recall: {recall:.2f} | Precision: {precision:.2f} | F1: {f1:.2f}")
    print("Expected:")
    for s in gt:
        print(f"  ✓ {s}")
    print("Retrieved:")
    for s in retrieved:
        print(f"  → {s}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!



🔍 Trying UMAP={'n_neighbors': 15, 'n_components': 5, 'min_dist': 0.0, 'metric': 'cosine', 'random_state': 42} | HDBSCAN={'min_cluster_size': 2, 'min_samples': 1, 'metric': 'euclidean', 'prediction_data': True}


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]


🔍 Trying UMAP={'n_neighbors': 15, 'n_components': 5, 'min_dist': 0.0, 'metric': 'cosine', 'random_state': 42} | HDBSCAN={'min_cluster_size': 3, 'min_samples': 1, 'metric': 'euclidean', 'prediction_data': True}

🔍 Trying UMAP={'n_neighbors': 10, 'n_components': 4, 'min_dist': 0.1, 'metric': 'cosine', 'random_state': 42} | HDBSCAN={'min_cluster_size': 2, 'min_samples': 1, 'metric': 'euclidean', 'prediction_data': True}

🔍 Trying UMAP={'n_neighbors': 10, 'n_components': 4, 'min_dist': 0.1, 'metric': 'cosine', 'random_state': 42} | HDBSCAN={'min_cluster_size': 3, 'min_samples': 1, 'metric': 'euclidean', 'prediction_data': True}

🏆 Best Configuration:
UMAP: {'n_neighbors': 15, 'n_components': 5, 'min_dist': 0.0, 'metric': 'cosine', 'random_state': 42}
HDBSCAN: {'min_cluster_size': 2, 'min_samples': 1, 'metric': 'euclidean', 'prediction_data': True}
Avg Recall: 0.40 | Precision: 0.70 | F1: 0.50

🎯 Best 2 performing queries and their results:

🔹 Query: 'symptoms of anaphylaxis'
Recall: 0.50 

In [None]:
#Multi dataset with 2 top outputs in ground truth

In [3]:
# === IMPORTS ===
import os
import random
import numpy as np
import torch
import nltk
import logging

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

# === ENVIRONMENT SETUP ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
nltk.download("punkt")


# === CLASS FOR TOPIC SEARCHER ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-MiniLM-L6-v2"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = []

        for idx, ents in enumerate(self.entities_per_chunk):
            chunk = self.chunks[idx].lower()
            sentences = sent_tokenize(chunk)
            for ent in ents:
                ent_lower = ent.lower()
                for sent in sentences:
                    if ent_lower in sent:
                        entity_context_pairs.append((ent_lower, sent.strip()))
                        break

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        # Embed entity-contexts
        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        # Topic Modeling
        umap_model = UMAP(**self.umap_params)
        hdbscan_model = HDBSCAN(**self.hdbscan_params)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Metadata aggregation
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=3):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx, score in zip(neighbors, scores):
            meta = self.topic_metadata[idx]
            seen = set()
            unique_sentences = []
            unique_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    unique_sentences.append(sent)
                    unique_embeddings.append(emb)

            sent_embs = np.array(unique_embeddings)
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
                "score": score,
            })

        return results


# === EVALUATION FUNCTION ===
def evaluate(searcher, ground_truth, k=2):
    all_recalls, all_precisions, all_f1s = [], [], []
    better_queries = []

    for query, gt_texts in ground_truth.items():
        results = searcher.search(query, top_k_topics=1, top_k_sents=k)
        retrieved_sents = set()
        for r in results:
            retrieved_sents.update([s for s, _ in r["sentences"]])

        gt_set = set(gt_texts)
        tp = len(retrieved_sents & gt_set)
        fp = len(retrieved_sents) - tp
        fn = len(gt_set) - tp

        recall = tp / (tp + fn) if (tp + fn) else 0.0
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

        all_recalls.append(recall)
        all_precisions.append(precision)
        all_f1s.append(f1)

        if f1 > 0:
            better_queries.append((query, recall, precision, f1, retrieved_sents, gt_set))

    return {
        "avg_recall": np.mean(all_recalls),
        "avg_precision": np.mean(all_precisions),
        "avg_f1": np.mean(all_f1s),
        "better_queries": better_queries
    }


# === ORIGINAL DATASET ===
original_chunks = [
    "Patient has peanut allergy causing hives and swelling. Anaphylaxis noted once during a reaction.",
    "Allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.",
    "Severe anaphylaxis symptoms require immediate treatment with epinephrine.",
    "Food allergies to milk and eggs can cause skin reactions like urticaria and eczema.",
    "Cold weather does not cause allergy symptoms in this patient."
]

original_entities = [
    ["peanut allergy", "hives", "swelling", "anaphylaxis"],
    ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
    ["anaphylaxis", "epinephrine", "treatment"],
    ["food allergies", "milk", "eggs", "urticaria", "eczema"],
    ["cold weather", "allergy symptoms"]
]

original_ground_truth = {
    "peanut allergy": [
        "patient has peanut allergy causing hives and swelling. anaphylaxis noted once during a reaction.",
        "severe anaphylaxis symptoms require immediate treatment with epinephrine."
    ],
    "symptoms of anaphylaxis": [
        "severe anaphylaxis symptoms require immediate treatment with epinephrine.",
        "patient has peanut allergy causing hives and swelling. anaphylaxis noted once during a reaction."
    ],
    "hay fever": [
        "allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.",
        "food allergies to milk and eggs can cause skin reactions like urticaria and eczema."
    ],
    "eczema treatment": [
        "food allergies to milk and eggs can cause skin reactions like urticaria and eczema.",
        "severe anaphylaxis symptoms require immediate treatment with epinephrine."
    ],
    "cold weather allergy": [
        "cold weather does not cause allergy symptoms in this patient.",
        "food allergies to milk and eggs can cause skin reactions like urticaria and eczema."
    ]
}

# === ADDITIONAL DATASET 1: ASTHMA-RELATED NOTES ===
asthma_chunks = [
    "The patient has a long history of asthma triggered by dust and pollen exposure.",
    "During cold weather, patient experiences wheezing and shortness of breath.",
    "Inhalers like salbutamol are prescribed to manage asthma symptoms.",
]

asthma_entities = [
    ["asthma", "dust", "pollen"],
    ["cold weather", "wheezing", "shortness of breath"],
    ["inhalers", "salbutamol", "asthma symptoms"],
]

asthma_ground_truth = {
    "asthma management": [
        "inhalers like salbutamol are prescribed to manage asthma symptoms.",
        "the patient has a long history of asthma triggered by dust and pollen exposure."
    ],
    "wheezing and cold weather": [
        "during cold weather, patient experiences wheezing and shortness of breath."
    ]
}

# === ADDITIONAL DATASET 2: DRUG ALLERGY NOTES ===
drug_allergy_chunks = [
    "Patient experienced rash and difficulty breathing after taking amoxicillin.",
    "Anaphylaxis suspected due to penicillin exposure. Epinephrine administered.",
    "Patient advised to avoid beta-lactam antibiotics.",
]

drug_allergy_entities = [
    ["rash", "difficulty breathing", "amoxicillin"],
    ["anaphylaxis", "penicillin", "epinephrine"],
    ["beta-lactam antibiotics"],
]

drug_allergy_ground_truth = {
    "amoxicillin allergy": [
        "patient experienced rash and difficulty breathing after taking amoxicillin.",
        "anaphylaxis suspected due to penicillin exposure. epinephrine administered."
    ],
    "epinephrine use": [
        "anaphylaxis suspected due to penicillin exposure. epinephrine administered."
    ],
    "drug allergy avoidance": [
        "patient advised to avoid beta-lactam antibiotics."
    ]
}

# === BEST PARAMETERS FROM YOUR GRID SEARCH ===
best_umap = {'n_neighbors': 15, 'n_components': 5, 'min_dist': 0.0, 'metric': 'cosine', 'random_state': SEED}
best_hdbscan = {'min_cluster_size': 2, 'min_samples': 1, 'metric': 'euclidean', 'prediction_data': True}

# === DATASETS LIST FOR EVALUATION ===
datasets = [
    (original_chunks, original_entities, original_ground_truth, "Original Allergy Dataset"),
    (asthma_chunks, asthma_entities, asthma_ground_truth, "Asthma-Related Notes"),
    (drug_allergy_chunks, drug_allergy_entities, drug_allergy_ground_truth, "Drug Allergy Notes"),
]

# === EVALUATION LOOP ===
evaluation_results = []

for chunks, entities, gt, name in datasets:
    print(f"\n=== Evaluating: {name} ===")
    try:
        searcher = AllergyTopicSearcher(chunks, entities, best_umap, best_hdbscan)
        metrics = evaluate(searcher, gt, k=2)

        print(f"Avg Recall: {metrics['avg_recall']:.2f} | Avg Precision: {metrics['avg_precision']:.2f} | F1: {metrics['avg_f1']:.2f}")

        if metrics["avg_f1"] > 0.0:
            top_queries = sorted(metrics["better_queries"], key=lambda x: x[3], reverse=True)[:2]
            for query, recall, precision, f1, retrieved, gt_set in top_queries:
                print(f"\n🔹 Query: '{query}'")
                print(f"Recall: {recall:.2f} | Precision: {precision:.2f} | F1: {f1:.2f}")
                print("Expected:")
                for s in gt_set:
                    print(f"  ✓ {s}")
                print("Retrieved:")
                for s in retrieved:
                    print(f"  → {s}")
    except Exception as e:
        print(f"Error evaluating {name}: {e}")



=== Evaluating: Original Allergy Dataset ===


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Avg Recall: 0.40 | Avg Precision: 0.70 | F1: 0.50

🔹 Query: 'hay fever'
Recall: 0.50 | Precision: 1.00 | F1: 0.67
Expected:
  ✓ allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.
  ✓ food allergies to milk and eggs can cause skin reactions like urticaria and eczema.
Retrieved:
  → allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.

🔹 Query: 'eczema treatment'
Recall: 0.50 | Precision: 1.00 | F1: 0.67
Expected:
  ✓ severe anaphylaxis symptoms require immediate treatment with epinephrine.
  ✓ food allergies to milk and eggs can cause skin reactions like urticaria and eczema.
Retrieved:
  → food allergies to milk and eggs can cause skin reactions like urticaria and eczema.

=== Evaluating: Asthma-Related Notes ===
Avg Recall: 0.75 | Avg Precision: 1.00 | F1: 0.83

🔹 Query: 'wheezing and cold weather'
Recall: 1.00 | Precision: 1.00 | F1: 1.00
Expected:
  ✓ during cold weather, patient experiences wheezing and shortness

In [None]:
#Multi dataset with 1 top outputs in ground truth

In [4]:
# === IMPORTS ===
import os
import random
import numpy as np
import torch
import nltk
import logging

from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

# === ENVIRONMENT SETUP ===
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
nltk.download("punkt")


# === CLASS FOR TOPIC SEARCHER ===
class AllergyTopicSearcher:
    def __init__(self, chunks, entities_per_chunk, umap_params, hdbscan_params, model_name="all-MiniLM-L6-v2"):
        self.chunks = chunks
        self.entities_per_chunk = entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)

        self.umap_params = umap_params
        self.hdbscan_params = hdbscan_params

        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        entity_context_pairs = []

        for idx, ents in enumerate(self.entities_per_chunk):
            chunk = self.chunks[idx].lower()
            sentences = sent_tokenize(chunk)
            for ent in ents:
                ent_lower = ent.lower()
                for sent in sentences:
                    if ent_lower in sent:
                        entity_context_pairs.append((ent_lower, sent.strip()))
                        break

        if not entity_context_pairs:
            raise ValueError("No entity-context pairs extracted!")

        # Embed entity-contexts
        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        # Topic Modeling
        umap_model = UMAP(**self.umap_params)
        # IMPORTANT: remove 'prediction_data' from hdbscan_params and add it explicitly here
        hdbscan_model = HDBSCAN(**self.hdbscan_params, prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False,
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        # Metadata aggregation
        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(embeddings)
            })

        self.topic_embeddings = np.array(topic_embeddings)
        self.topic_metadata = topic_metadata

        if len(self.topic_embeddings) < 1:
            raise RuntimeError("No topic embeddings to index.")

        num_clusters = min(len(self.topic_embeddings), 3)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=1):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        results = []
        for idx, score in zip(neighbors, scores):
            meta = self.topic_metadata[idx]
            seen = set()
            unique_sentences = []
            unique_embeddings = []

            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    unique_sentences.append(sent)
                    unique_embeddings.append(emb)

            sent_embs = np.array(unique_embeddings)
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            top_sents = [(unique_sentences[i], sims[i]) for i in top_indices]

            results.append({
                "topic_id": meta["topic_id"],
                "entities": meta["entities"],
                "sentences": top_sents,
                "score": score,
            })

        return results


# === EVALUATION FUNCTION ===
def evaluate(searcher, ground_truth, k=1):
    all_recalls, all_precisions, all_f1s = [], [], []
    better_queries = []

    for query, gt_texts in ground_truth.items():
        results = searcher.search(query, top_k_topics=1, top_k_sents=k)
        retrieved_sents = set()
        for r in results:
            retrieved_sents.update([s for s, _ in r["sentences"]])

        gt_set = set(gt_texts)
        tp = len(retrieved_sents & gt_set)
        fp = len(retrieved_sents) - tp
        fn = len(gt_set) - tp

        recall = tp / (tp + fn) if (tp + fn) else 0.0
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

        all_recalls.append(recall)
        all_precisions.append(precision)
        all_f1s.append(f1)

        if f1 > 0:
            better_queries.append((query, recall, precision, f1, retrieved_sents, gt_set))

    return {
        "avg_recall": np.mean(all_recalls),
        "avg_precision": np.mean(all_precisions),
        "avg_f1": np.mean(all_f1s),
        "better_queries": better_queries
    }


# === DATASETS ===
datasets = [
    {
        "name": "Allergy Dataset",
        "chunks": [
            "Patient has peanut allergy causing hives and swelling. Anaphylaxis noted once during a reaction.",
            "Allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.",
            "Severe anaphylaxis symptoms require immediate treatment with epinephrine.",
            "Food allergies to milk and eggs can cause skin reactions like urticaria and eczema.",
            "Cold weather does not cause allergy symptoms in this patient."
        ],
        "entities": [
            ["peanut allergy", "hives", "swelling", "anaphylaxis"],
            ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
            ["anaphylaxis", "epinephrine", "treatment"],
            ["food allergies", "milk", "eggs", "urticaria", "eczema"],
            ["cold weather", "allergy symptoms"]
        ],
        "ground_truth": {
            "peanut allergy": [
                "patient has peanut allergy causing hives and swelling. anaphylaxis noted once during a reaction."
            ],
            "symptoms of anaphylaxis": [
                "severe anaphylaxis symptoms require immediate treatment with epinephrine."
            ],
            "hay fever": [
                "allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander."
            ],
            "eczema treatment": [
                "food allergies to milk and eggs can cause skin reactions like urticaria and eczema."
            ],
            "cold weather allergy": [
                "cold weather does not cause allergy symptoms in this patient."
            ]
        }
    },
    {
        "name": "Atopic Dermatitis",
        "chunks": [
            "Atopic dermatitis is a chronic skin condition characterized by itchy and inflamed skin.",
            "Patients with atopic dermatitis often have dry, scaly patches and may experience infections.",
            "Treatment includes moisturizers, corticosteroids, and avoiding irritants.",
            "Severe cases may require systemic immunosuppressants.",
            "Triggers include allergens such as dust mites, pet dander, and pollen."
        ],
        "entities": [
            ["atopic dermatitis", "skin condition", "itchy", "inflamed skin"],
            ["dry", "scaly patches", "infections"],
            ["treatment", "moisturizers", "corticosteroids", "irritants"],
            ["severe cases", "systemic immunosuppressants"],
            ["triggers", "allergens", "dust mites", "pet dander", "pollen"]
        ],
        "ground_truth": {
            "treatment for atopic dermatitis": [
                "treatment includes moisturizers, corticosteroids, and avoiding irritants."
            ],
            "symptoms of atopic dermatitis": [
                "atopic dermatitis is a chronic skin condition characterized by itchy and inflamed skin.",
                "patients with atopic dermatitis often have dry, scaly patches and may experience infections."
            ],
            "severe atopic dermatitis": [
                "severe cases may require systemic immunosuppressants."
            ],
            "common triggers": [
                "triggers include allergens such as dust mites, pet dander, and pollen."
            ]
        }
    },
    {
        "name": "Food Allergy Reactions",
        "chunks": [
            "Food allergies can cause a range of reactions from mild hives to severe anaphylaxis.",
            "Common food allergens include peanuts, tree nuts, milk, eggs, wheat, soy, fish, and shellfish.",
            "Anaphylaxis requires immediate treatment with epinephrine.",
            "Symptoms can include swelling, difficulty breathing, and rash.",
            "Avoidance of allergens is key to management."
        ],
        "entities": [
            ["food allergies", "reactions", "hives", "anaphylaxis"],
            ["common food allergens", "peanuts", "tree nuts", "milk", "eggs", "wheat", "soy", "fish", "shellfish"],
            ["anaphylaxis", "treatment", "epinephrine"],
            ["symptoms", "swelling", "difficulty breathing", "rash"],
            ["avoidance", "management"]
        ],
        "ground_truth": {
            "anaphylaxis treatment": [
                "anaphylaxis requires immediate treatment with epinephrine."
            ],
            "common food allergens": [
                "common food allergens include peanuts, tree nuts, milk, eggs, wheat, soy, fish, and shellfish."
            ],
            "symptoms of food allergy": [
                "symptoms can include swelling, difficulty breathing, and rash."
            ],
            "management of food allergies": [
                "avoidance of allergens is key to management."
            ]
        }
    },

    {
        "name": "Allergic Rhinitis",
        "chunks": [
            "Allergic rhinitis is caused by an allergic response to airborne particles like pollen, dust, and pet dander.",
            "Symptoms include sneezing, nasal congestion, runny nose, and itchy eyes.",
            "Treatment options include antihistamines, nasal corticosteroids, and avoiding triggers.",
            "Common triggers include pollen, dust mites, mold, and pet dander.",
            "Seasonal allergic rhinitis is often worse during pollen season."
        ],
        "entities": [
            ["allergic rhinitis", "allergic response", "airborne particles", "pollen", "dust", "pet dander"],
            ["symptoms", "sneezing", "nasal congestion", "runny nose", "itchy eyes"],
            ["treatment", "antihistamines", "nasal corticosteroids", "avoiding triggers"],
            ["common triggers", "pollen", "dust mites", "mold", "pet dander"],
            ["seasonal allergic rhinitis", "pollen season"]
        ],
        "ground_truth": {
            "symptoms of allergic rhinitis": [
                "symptoms include sneezing, nasal congestion, runny nose, and itchy eyes."
            ],
            "treatment options": [
                "treatment options include antihistamines, nasal corticosteroids, and avoiding triggers."
            ],
            "common triggers": [
                "common triggers include pollen, dust mites, mold, and pet dander."
            ],
            "seasonal allergic rhinitis": [
                "seasonal allergic rhinitis is often worse during pollen season."
            ]
        }
    }
]

# === BEST PARAMETERS ===
best_umap = {"n_neighbors": 5, "n_components": 5, "min_dist": 0.1, "metric": "cosine"}
best_hdbscan = {"min_cluster_size": 2, "min_samples": 1, "metric": "euclidean"}  # prediction_data removed here

# === RUN EVALUATION ===
for dataset in datasets:
    print(f"\n=== Evaluating Dataset: {dataset['name']} ===")
    try:
        searcher = AllergyTopicSearcher(
            chunks=dataset["chunks"],
            entities_per_chunk=dataset["entities"],
            umap_params=best_umap,
            hdbscan_params=best_hdbscan,
        )
    except Exception as e:
        print(f"❌ Failed to create searcher for {dataset['name']}: {e}")
        continue

    eval_res = evaluate(searcher, dataset["ground_truth"], k=2)

    print(f"Avg Recall: {eval_res['avg_recall']:.2f} | Avg Precision: {eval_res['avg_precision']:.2f} | F1: {eval_res['avg_f1']:.2f}")

    # Show top 2 queries with results
    for q, r, p, f1, retrieved, gt in sorted(eval_res["better_queries"], key=lambda x: x[3], reverse=True)[:2]:
        print(f"\nQuery: '{q}'")
        print(f"Recall: {r:.2f} | Precision: {p:.2f} | F1: {f1:.2f}")
        print("Expected:\n ", "\n  ✓ ".join(gt))
        print("Retrieved:\n ", "\n  ✓ ".join(retrieved))





=== Evaluating Dataset: Allergy Dataset ===


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Avg Recall: 0.60 | Avg Precision: 0.30 | F1: 0.40

Query: 'symptoms of anaphylaxis'
Recall: 1.00 | Precision: 0.50 | F1: 0.67
Expected:
  severe anaphylaxis symptoms require immediate treatment with epinephrine.
Retrieved:
  severe anaphylaxis symptoms require immediate treatment with epinephrine.
  ✓ anaphylaxis noted once during a reaction.

Query: 'hay fever'
Recall: 1.00 | Precision: 0.50 | F1: 0.67
Expected:
  allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.
Retrieved:
  allergic rhinitis, or hay fever, results from exposure to pollen, dust, or pet dander.
  ✓ cold weather does not cause allergy symptoms in this patient.

=== Evaluating Dataset: Atopic Dermatitis ===
Avg Recall: 0.38 | Avg Precision: 0.50 | F1: 0.42

Query: 'common triggers'
Recall: 1.00 | Precision: 1.00 | F1: 1.00
Expected:
  triggers include allergens such as dust mites, pet dander, and pollen.
Retrieved:
  triggers include allergens such as dust mites, pet dander, and poll