In [2]:
!pip install numpy sentence-transformers bertopic hdbscan nltk scann
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [7]:
# === Shared Imports & Setup ===
import os
import random
import numpy as np
import torch
import nltk
import logging
from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

# Reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# Download punkt tokenizer
nltk.download("punkt")

# === Code 1 ===
class AllergyTopicSearcherModel1:
    def __init__(self, chunks, manual_entities_per_chunk, model_name="emilyalsentzer/Bio_ClinicalBERT"):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model_name = model_name

        self.embedding_model = None
        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        entity_to_chunk = defaultdict(list)
        all_entities = []
        for idx, ents in enumerate(self.manual_entities_per_chunk):
            for ent in ents:
                ent_lower = ent.lower()
                all_entities.append(ent_lower)
                entity_to_chunk[ent_lower].append(idx)

        unique_entities = sorted(set(all_entities))
        entity_embeddings = self.embedding_model.encode(unique_entities, normalize_embeddings=True)

        umap_model = UMAP(n_neighbors=15, n_components=5, metric='cosine', random_state=SEED)
        hdbscan_model = HDBSCAN(min_cluster_size=2, min_samples=1, metric='euclidean',
                                prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False
        )

        topics, _ = self.topic_model.fit_transform(unique_entities, embeddings=entity_embeddings)

        topic_to_entities = defaultdict(list)
        for ent, topic in zip(unique_entities, topics):
            topic_to_entities[topic].append(ent)

        topic_contexts = defaultdict(list)
        for topic, entities in topic_to_entities.items():
            for ent in entities:
                for chunk_id in entity_to_chunk[ent]:
                    for sent in sent_tokenize(self.chunks[chunk_id]):
                        if ent in sent.lower():
                            topic_contexts[topic].append(sent)

        for topic in topic_contexts:
            topic_contexts[topic] = list(set(topic_contexts[topic]))

        topic_embeddings = []
        topic_metadata = []

        for topic_id, sentences in topic_contexts.items():
            if not sentences:
                continue
            sent_embs = self.embedding_model.encode(sentences, normalize_embeddings=True)
            mean_emb = np.mean(sent_embs, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": topic_to_entities[topic_id],
                "sentences": sentences,
                "sentence_embeddings": sent_embs
            })

        self.topic_embeddings = np.array(topic_embeddings)

        num_clusters = min(len(self.topic_embeddings), 5)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

        self.topic_metadata = topic_metadata

    def search(self, query, top_k_topics=1, top_k_sents=2):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        output = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            sents = meta["sentences"]
            sent_embs = meta["sentence_embeddings"]
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            output += [sents[i] for i in top_indices]
        return output


# === Code 2 ===
class AllergyTopicSearcherModel2:
    def __init__(self, chunks, manual_entities_per_chunk, model_name="emilyalsentzer/Bio_ClinicalBERT"):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model_name = model_name

        self.embedding_model = None
        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        entity_context_pairs = []
        entity_to_chunk = defaultdict(list)

        for idx, ents in enumerate(self.manual_entities_per_chunk):
            chunk = self.chunks[idx].lower()
            sentences = sent_tokenize(chunk)
            for ent in ents:
                ent_lower = ent.lower()
                for sent in sentences:
                    if ent_lower in sent:
                        entity_context_pairs.append((ent_lower, sent.strip()))
                        entity_to_chunk[ent_lower].append(idx)
                        break

        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        umap_model = UMAP(n_neighbors=15, n_components=5, metric='cosine', random_state=SEED)
        hdbscan_model = HDBSCAN(min_cluster_size=2, min_samples=1, metric='euclidean', prediction_data=True)

        self.topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False
        )

        topics, _ = self.topic_model.fit_transform(contextual_texts, embeddings=contextual_embeddings)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10

            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(topic_to_embeddings[topic_id])
            })

        self.topic_embeddings = np.array(topic_embeddings)

        num_clusters = min(len(self.topic_embeddings), 5)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

        self.topic_metadata = topic_metadata

    def search(self, query, top_k_topics=1, top_k_sents=2):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        output = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            seen = set()
            sents = []
            embs = []
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    sents.append(sent)
                    embs.append(emb)
            sent_embs = np.array(embs)
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            output += [sents[i] for i in top_indices]
        return output


# === Evaluation ===
chunks = [
    "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis.",
    "Allergic rhinitis, commonly known as hay fever, is an allergic response to pollen, dust, or pet dander.",
    "Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly.",
    "Patients with food allergies, such as milk or eggs, need to be careful with their diet.",
    "Skin reactions like urticaria (hives) and eczema are often signs of allergies.",
    "He walks in cold weather but has no allergy symptoms or reactions."
]

manual_entities_per_chunk = [
    ["peanut allergy", "hives", "swelling", "anaphylaxis"],
    ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
    ["anaphylaxis", "allergic reaction"],
    ["food allergies", "milk", "eggs"],
    ["urticaria", "hives", "eczema", "allergies"],
    ["cold weather", "allergy symptoms", "reactions"]
]

queries = [
    "peanut allergy",
    "symptoms of anaphylaxis",
    "hay fever",
    "eczema treatment",
    "allergic reaction to milk",
    "signs of food allergy",
    "urticaria causes",
    "pet dander allergies",
    "cold weather allergy",
    "hives and swelling"
]

print("🔧 Initializing Model 1...")
m1 = AllergyTopicSearcherModel1(chunks, manual_entities_per_chunk)

print("🔧 Initializing Model 2...")
m2 = AllergyTopicSearcherModel2(chunks, manual_entities_per_chunk)

print("\n\n🔍 Starting Evaluation")
for q in queries:
    res1 = m1.search(q, top_k_sents=2)
    res2 = m2.search(q, top_k_sents=2)

    print(f"\n\n🔎 Query: {q}")
    print("-" * 90)
    print("📘 Model 1:")
    for i, r in enumerate(res1, 1):
        print(f"  {i}. {r}")
    print("📙 Model 2:")
    for i, r in enumerate(res2, 1):
        print(f"  {i}. {r}")
    print("=" * 90)


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


🔧 Initializing Model 1...




🔧 Initializing Model 2...


🔍 Starting Evaluation


🔎 Query: peanut allergy
------------------------------------------------------------------------------------------
📘 Model 1:
  1. He walks in cold weather but has no allergy symptoms or reactions.
  2. Peanut allergy is one of the most common causes of severe allergic reactions.
📙 Model 2:
  1. he walks in cold weather but has no allergy symptoms or reactions.


🔎 Query: symptoms of anaphylaxis
------------------------------------------------------------------------------------------
📘 Model 1:
  1. Symptoms can include hives, swelling, and anaphylaxis.
  2. Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly.
📙 Model 2:
  1. symptoms can include hives, swelling, and anaphylaxis.


🔎 Query: hay fever
------------------------------------------------------------------------------------------
📘 Model 1:
  1. He walks in cold weather but has no allergy symptoms or reactions.
  2. Symptoms can i

In [8]:
# === Shared Imports & Setup ===
import os
import random
import numpy as np
import torch
import nltk
import logging
from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

# Reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# Download punkt tokenizer
nltk.download("punkt")


# === Data ===
chunks = [
    "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis.",
    "Allergic rhinitis, commonly known as hay fever, is an allergic response to pollen, dust, or pet dander.",
    "Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly.",
    "Patients with food allergies, such as milk or eggs, need to be careful with their diet.",
    "Skin reactions like urticaria (hives) and eczema are often signs of allergies.",
    "He walks in cold weather but has no allergy symptoms or reactions."
]

manual_entities_per_chunk = [
    ["peanut allergy", "hives", "swelling", "anaphylaxis"],
    ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
    ["anaphylaxis", "allergic reaction"],
    ["food allergies", "milk", "eggs"],
    ["urticaria", "hives", "eczema", "allergies"],
    ["cold weather", "allergy symptoms", "reactions"]
]

queries = [
    "peanut allergy",
    "symptoms of anaphylaxis",
    "hay fever",
    "eczema treatment",
    "allergic reaction to milk",
    "signs of food allergy",
    "urticaria causes",
    "pet dander allergies",
    "cold weather allergy",
    "hives and swelling"
]

# Ground truth answers (2 outputs each query)
ground_truth = {
    "peanut allergy": [
        "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis.",
        "Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly."
    ],
    "symptoms of anaphylaxis": [
        "Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly.",
        "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis."
    ],
    "hay fever": [
        "Allergic rhinitis, commonly known as hay fever, is an allergic response to pollen, dust, or pet dander.",
        "Skin reactions like urticaria (hives) and eczema are often signs of allergies."
    ],
    "eczema treatment": [
        "Skin reactions like urticaria (hives) and eczema are often signs of allergies.",
        "Patients with food allergies, such as milk or eggs, need to be careful with their diet."
    ],
    "allergic reaction to milk": [
        "Patients with food allergies, such as milk or eggs, need to be careful with their diet.",
        "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis."
    ],
    "signs of food allergy": [
        "Patients with food allergies, such as milk or eggs, need to be careful with their diet.",
        "Skin reactions like urticaria (hives) and eczema are often signs of allergies."
    ],
    "urticaria causes": [
        "Skin reactions like urticaria (hives) and eczema are often signs of allergies.",
        "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis."
    ],
    "pet dander allergies": [
        "Allergic rhinitis, commonly known as hay fever, is an allergic response to pollen, dust, or pet dander.",
        "Skin reactions like urticaria (hives) and eczema are often signs of allergies."
    ],
    "cold weather allergy": [
        "He walks in cold weather but has no allergy symptoms or reactions.",
        "Skin reactions like urticaria (hives) and eczema are often signs of allergies."
    ],
    "hives and swelling": [
        "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis.",
        "Skin reactions like urticaria (hives) and eczema are often signs of allergies."
    ]
}

# === Helper functions to compute metrics ===
def normalize_text(text):
    return text.lower().strip()

def precision_recall_f1(preds, truths):
    preds_norm = set(normalize_text(p) for p in preds)
    truths_norm = set(normalize_text(t) for t in truths)
    tp = len(preds_norm.intersection(truths_norm))
    precision = tp / len(preds_norm) if preds_norm else 0.0
    recall = tp / len(truths_norm) if truths_norm else 0.0
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# === Initialize Models ===
print("🔧 Initializing Model 1...")
m1 = AllergyTopicSearcherModel1(chunks, manual_entities_per_chunk)

print("🔧 Initializing Model 2...")
m2 = AllergyTopicSearcherModel2(chunks, manual_entities_per_chunk)

# === Evaluation ===
print("\n\n🔍 Starting Evaluation\n")

metrics_m1 = []
metrics_m2 = []

for q in queries:
    gt_answers = ground_truth[q]
    res1 = m1.search(q, top_k_sents=2)
    res2 = m2.search(q, top_k_sents=2)

    prec1, rec1, f1_1 = precision_recall_f1(res1, gt_answers)
    prec2, rec2, f1_2 = precision_recall_f1(res2, gt_answers)

    metrics_m1.append((prec1, rec1, f1_1))
    metrics_m2.append((prec2, rec2, f1_2))

    print(f"🔎 Query: {q}")
    print("-" * 90)
    print("📘 Model 1 Results:")
    for i, r in enumerate(res1, 1):
        print(f"  {i}. {r}")
    print(f"  Precision: {prec1:.3f} | Recall: {rec1:.3f} | F1: {f1_1:.3f}")
    print("📙 Model 2 Results:")
    for i, r in enumerate(res2, 1):
        print(f"  {i}. {r}")
    print(f"  Precision: {prec2:.3f} | Recall: {rec2:.3f} | F1: {f1_2:.3f}")
    print("=" * 90)

# Aggregate overall metrics
def aggregate_metrics(metrics):
    precs, recs, f1s = zip(*metrics)
    return np.mean(precs), np.mean(recs), np.mean(f1s)

avg_prec_m1, avg_rec_m1, avg_f1_m1 = aggregate_metrics(metrics_m1)
avg_prec_m2, avg_rec_m2, avg_f1_m2 = aggregate_metrics(metrics_m2)

print("\n\n=== Overall Evaluation ===")
print(f"Model 1 - Precision@2: {avg_prec_m1:.3f}, Recall@2: {avg_rec_m1:.3f}, F1@2: {avg_f1_m1:.3f}")
print(f"Model 2 - Precision@2: {avg_prec_m2:.3f}, Recall@2: {avg_rec_m2:.3f}, F1@2: {avg_f1_m2:.3f}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


🔧 Initializing Model 1...




🔧 Initializing Model 2...


🔍 Starting Evaluation

🔎 Query: peanut allergy
------------------------------------------------------------------------------------------
📘 Model 1 Results:
  1. He walks in cold weather but has no allergy symptoms or reactions.
  2. Peanut allergy is one of the most common causes of severe allergic reactions.
  Precision: 0.000 | Recall: 0.000 | F1: 0.000
📙 Model 2 Results:
  1. he walks in cold weather but has no allergy symptoms or reactions.
  Precision: 0.000 | Recall: 0.000 | F1: 0.000
🔎 Query: symptoms of anaphylaxis
------------------------------------------------------------------------------------------
📘 Model 1 Results:
  1. Symptoms can include hives, swelling, and anaphylaxis.
  2. Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly.
  Precision: 0.500 | Recall: 0.500 | F1: 0.500
📙 Model 2 Results:
  1. symptoms can include hives, swelling, and anaphylaxis.
  Precision: 0.000 | Recall: 0.000 | F1: 0.0

In [13]:
# === Shared Imports & Setup ===
import os
import random
import numpy as np
import torch
import nltk
import logging
from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

# Reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

nltk.download("punkt")
logging.basicConfig(level=logging.INFO)

# === Model 1 ===
class AllergyTopicSearcherModel1:
    def __init__(self, chunks, manual_entities_per_chunk, model_name="all-MiniLM-L6-v2"):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)
        self._prepare()

    def _prepare(self):
        entity_to_chunk = defaultdict(list)
        all_entities = []
        for idx, ents in enumerate(self.manual_entities_per_chunk):
            for ent in ents:
                e = ent.lower()
                all_entities.append(e)
                entity_to_chunk[e].append(idx)

        unique_entities = sorted(set(all_entities))
        entity_embeddings = self.embedding_model.encode(unique_entities, normalize_embeddings=True)

        umap_model = UMAP(n_neighbors=15, n_components=5, metric='cosine', random_state=SEED)
        hdbscan_model = HDBSCAN(min_cluster_size=2, min_samples=1, metric='euclidean', prediction_data=True)

        topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False
        )
        topics, _ = topic_model.fit_transform(unique_entities, embeddings=entity_embeddings)

        topic_to_entities = defaultdict(list)
        for ent, t in zip(unique_entities, topics):
            topic_to_entities[t].append(ent)

        topic_metadata = []
        topic_embeddings = []
        for t, ents in topic_to_entities.items():
            sents = []
            for ent in ents:
                for idx in entity_to_chunk[ent]:
                    for sent in sent_tokenize(self.chunks[idx].lower()):
                        if ent in sent:
                            sents.append(sent)
            sents = list(set(sents))
            if not sents:
                continue
            sent_embs = self.embedding_model.encode(sents, normalize_embeddings=True)
            mean_emb = np.mean(sent_embs, axis=0)
            mean_emb /= np.linalg.norm(mean_emb)

            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": t,
                "entities": ents,
                "sentences": sents,
                "sentence_embeddings": sent_embs
            })

        self.topic_embeddings = np.vstack(topic_embeddings)
        self.topic_metadata = topic_metadata
        self.searcher = scann.scann_ops_pybind.builder(
            self.topic_embeddings, 3, "dot_product"
        ).tree(num_leaves=min(5, len(self.topic_embeddings)), num_leaves_to_search=2)\
         .score_brute_force().reorder(3).build()

    def search(self, query, top_k_topics=1, top_k_sents=2):
        q_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(q_emb, final_num_neighbors=top_k_topics)
        output = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            sent_embs = meta["sentence_embeddings"]
            sims = np.dot(sent_embs, q_emb)
            for i in sims.argsort()[::-1][:top_k_sents]:
                output.append(meta["sentences"][i])
        return output

# === Model 2 ===
class AllergyTopicSearcherModel2:
    def __init__(self, chunks, manual_entities_per_chunk, model_name="all-MiniLM-L6-v2"):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model = SentenceTransformer(model_name)
        self._prepare()

    def _prepare(self):
        pairs = []
        for idx, ents in enumerate(self.manual_entities_per_chunk):
            for ent in ents:
                for sent in sent_tokenize(self.chunks[idx].lower()):
                    if ent.lower() in sent:
                        pairs.append((ent.lower(), sent.strip()))
                        break

        texts = [f"{ent}: {sent}" for ent, sent in pairs]
        embeddings = self.embedding_model.encode(texts, normalize_embeddings=True)

        umap_model = UMAP(n_neighbors=15, n_components=5, metric='cosine', random_state=SEED)
        hdbscan_model = HDBSCAN(min_cluster_size=2, min_samples=1, metric='euclidean', prediction_data=True)

        topic_model = BERTopic(
            embedding_model=self.embedding_model,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            representation_model=KeyBERTInspired(),
            calculate_probabilities=True,
            verbose=False
        )
        topics, _ = topic_model.fit_transform(texts, embeddings=embeddings)

        topic_metadata = defaultdict(lambda: {"ents": [], "sents": [], "embs": []})
        for i, t in enumerate(topics):
            ent, sent = pairs[i]
            topic_metadata[t]["ents"].append(ent)
            topic_metadata[t]["sents"].append(sent)
            topic_metadata[t]["embs"].append(embeddings[i])

        self.topic_metadata = []
        topic_embeddings = []
        for t, d in topic_metadata.items():
            embs = np.vstack(d["embs"])
            mean_emb = np.mean(embs, axis=0)
            mean_emb /= np.linalg.norm(mean_emb)
            topic_embeddings.append(mean_emb)
            self.topic_metadata.append({
                "topic_id": t,
                "entities": list(set(d["ents"])),
                "sentences": d["sents"],
                "sentence_embeddings": embs
            })

        self.topic_embeddings = np.vstack(topic_embeddings)
        self.searcher = scann.scann_ops_pybind.builder(
            self.topic_embeddings, 3, "dot_product"
        ).tree(num_leaves=min(5, len(self.topic_embeddings)), num_leaves_to_search=2)\
         .score_brute_force().reorder(3).build()

    def search(self, query, top_k_topics=1, top_k_sents=2):
        q_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(q_emb, final_num_neighbors=top_k_topics)
        output = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            sent_embs = meta["sentence_embeddings"]
            sims = np.dot(sent_embs, q_emb)
            for i in sims.argsort()[::-1][:top_k_sents]:
                output.append(meta["sentences"][i])
        return output

# === New Example Data ===
chunks = [
    "60-year-old female with chronic kidney disease stage 3, presenting with fatigue and ankle swelling.",
    "Medical history includes hypertension, hyperlipidemia, and gout. Medications: ACE inhibitor, statin.",
    "Lab results: serum creatinine 2.1 mg/dL, eGFR 38 mL/min/1.73m2, uric acid 9.0 mg/dL.",
    "Patient reports nocturia and decreased appetite. No chest pain or dyspnea.",
    "Exam shows 2+ pitting edema in lower extremities. Blood pressure 145/90 mmHg.",
    "Lifestyle: overweight, diet high in sodium, minimal exercise."
]

manual_entities_per_chunk = [
    ["kidney disease", "fatigue", "ankle swelling"],
    ["hypertension", "hyperlipidemia", "gout", "ACE inhibitor", "statin"],
    ["serum creatinine", "eGFR", "uric acid"],
    ["nocturia", "decreased appetite"],
    ["pitting edema", "blood pressure"],
    ["overweight", "diet", "salt", "exercise"]
]

queries = [
    "kidney function lab results",
    "edema assessment",
    "gout medication",
    "blood pressure management",
    "patient lifestyle",
]

# === Evaluation ===
print("🔧 Initializing Model 1...")
m1 = AllergyTopicSearcherModel1(chunks, manual_entities_per_chunk)
print("🔧 Initializing Model 2...")
m2 = AllergyTopicSearcherModel2(chunks, manual_entities_per_chunk)

print("\n🔍 Running Queries and Results")
for q in queries:
    res1 = m1.search(q, top_k_sents=2)
    res2 = m2.search(q, top_k_sents=2)
    print(f"\nQuery: {q}")
    print("Model 1:", res1)
    print("Model 2:", res2)


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


🔧 Initializing Model 1...
🔧 Initializing Model 2...

🔍 Running Queries and Results

Query: kidney function lab results
Model 1: ['lab results: serum creatinine 2.1 mg/dl, egfr 38 ml/min/1.73m2, uric acid 9.0 mg/dl.', 'patient reports nocturia and decreased appetite.']
Model 2: ['lab results: serum creatinine 2.1 mg/dl, egfr 38 ml/min/1.73m2, uric acid 9.0 mg/dl.', 'lab results: serum creatinine 2.1 mg/dl, egfr 38 ml/min/1.73m2, uric acid 9.0 mg/dl.']

Query: edema assessment
Model 1: ['exam shows 2+ pitting edema in lower extremities.', 'lab results: serum creatinine 2.1 mg/dl, egfr 38 ml/min/1.73m2, uric acid 9.0 mg/dl.']
Model 2: ['exam shows 2+ pitting edema in lower extremities.', 'lab results: serum creatinine 2.1 mg/dl, egfr 38 ml/min/1.73m2, uric acid 9.0 mg/dl.']

Query: gout medication
Model 1: ['medical history includes hypertension, hyperlipidemia, and gout.']
Model 2: ['medical history includes hypertension, hyperlipidemia, and gout.', 'medical history includes hypertension

In [14]:
# === Shared Imports & Setup ===
import os
import random
import numpy as np
import torch
import nltk
import logging
from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann

# Reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# Download punkt tokenizer
nltk.download("punkt")


# === Data ===
chunks = [
    "60-year-old female with chronic kidney disease stage 3, presenting with fatigue and ankle swelling.",
    "Medical history includes hypertension, hyperlipidemia, and gout. Medications: ACE inhibitor, statin.",
    "Lab results: serum creatinine 2.1 mg/dL, eGFR 38 mL/min/1.73m2, uric acid 9.0 mg/dL.",
    "Patient reports nocturia and decreased appetite. No chest pain or dyspnea.",
    "Exam shows 2+ pitting edema in lower extremities. Blood pressure 145/90 mmHg.",
    "Lifestyle: overweight, diet high in sodium, minimal exercise."
]

queries = [
    "kidney function lab results",
    "edema assessment",
    "gout medication",
    "blood pressure management",
    "patient lifestyle",
]

# If your models rely on manual entities per chunk, here is an example:
manual_entities_per_chunk = [
    ["chronic kidney disease", "fatigue", "ankle swelling"],
    ["hypertension", "hyperlipidemia", "gout", "ACE inhibitor", "statin"],
    ["serum creatinine", "eGFR", "uric acid"],
    ["nocturia", "decreased appetite", "chest pain", "dyspnea"],
    ["pitting edema", "blood pressure"],
    ["overweight", "diet high in sodium", "minimal exercise"]
]

# Ground truth answers for evaluation (2 outputs per query)
ground_truth = {
    "kidney function lab results": [
        "Lab results: serum creatinine 2.1 mg/dL, eGFR 38 mL/min/1.73m2, uric acid 9.0 mg/dL.",
        "60-year-old female with chronic kidney disease stage 3, presenting with fatigue and ankle swelling."
    ],
    "edema assessment": [
        "Exam shows 2+ pitting edema in lower extremities. Blood pressure 145/90 mmHg.",
        "60-year-old female with chronic kidney disease stage 3, presenting with fatigue and ankle swelling."
    ],
    "gout medication": [
        "Medical history includes hypertension, hyperlipidemia, and gout. Medications: ACE inhibitor, statin.",
        "Lab results: serum creatinine 2.1 mg/dL, eGFR 38 mL/min/1.73m2, uric acid 9.0 mg/dL."
    ],
    "blood pressure management": [
        "Exam shows 2+ pitting edema in lower extremities. Blood pressure 145/90 mmHg.",
        "Medical history includes hypertension, hyperlipidemia, and gout. Medications: ACE inhibitor, statin."
    ],
    "patient lifestyle": [
        "Lifestyle: overweight, diet high in sodium, minimal exercise.",
        "Patient reports nocturia and decreased appetite. No chest pain or dyspnea."
    ]
}

# === Helper functions to compute metrics ===
def normalize_text(text):
    return text.lower().strip()

def precision_recall_f1(preds, truths):
    preds_norm = set(normalize_text(p) for p in preds)
    truths_norm = set(normalize_text(t) for t in truths)
    tp = len(preds_norm.intersection(truths_norm))
    precision = tp / len(preds_norm) if preds_norm else 0.0
    recall = tp / len(truths_norm) if truths_norm else 0.0
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# === Initialize Models ===
print("🔧 Initializing Model 1...")
m1 = AllergyTopicSearcherModel1(chunks, manual_entities_per_chunk)

print("🔧 Initializing Model 2...")
m2 = AllergyTopicSearcherModel2(chunks, manual_entities_per_chunk)

# === Evaluation ===
print("\n\n🔍 Starting Evaluation\n")

metrics_m1 = []
metrics_m2 = []

for q in queries:
    gt_answers = ground_truth[q]
    res1 = m1.search(q, top_k_sents=2)
    res2 = m2.search(q, top_k_sents=2)

    prec1, rec1, f1_1 = precision_recall_f1(res1, gt_answers)
    prec2, rec2, f1_2 = precision_recall_f1(res2, gt_answers)

    metrics_m1.append((prec1, rec1, f1_1))
    metrics_m2.append((prec2, rec2, f1_2))

    print(f"🔎 Query: {q}")
    print("-" * 90)
    print("📘 Model 1 Results:")
    for i, r in enumerate(res1, 1):
        print(f"  {i}. {r}")
    print(f"  Precision: {prec1:.3f} | Recall: {rec1:.3f} | F1: {f1_1:.3f}")
    print("📙 Model 2 Results:")
    for i, r in enumerate(res2, 1):
        print(f"  {i}. {r}")
    print(f"  Precision: {prec2:.3f} | Recall: {rec2:.3f} | F1: {f1_2:.3f}")
    print("=" * 90)

# Aggregate overall metrics
def aggregate_metrics(metrics):
    precs, recs, f1s = zip(*metrics)
    return np.mean(precs), np.mean(recs), np.mean(f1s)

avg_prec_m1, avg_rec_m1, avg_f1_m1 = aggregate_metrics(metrics_m1)
avg_prec_m2, avg_rec_m2, avg_f1_m2 = aggregate_metrics(metrics_m2)

print("\n\n=== Overall Evaluation ===")
print(f"Model 1 - Precision@2: {avg_prec_m1:.3f}, Recall@2: {avg_rec_m1:.3f}, F1@2: {avg_f1_m1:.3f}")
print(f"Model 2 - Precision@2: {avg_prec_m2:.3f}, Recall@2: {avg_rec_m2:.3f}, F1@2: {avg_f1_m2:.3f}")


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


🔧 Initializing Model 1...
🔧 Initializing Model 2...


🔍 Starting Evaluation

🔎 Query: kidney function lab results
------------------------------------------------------------------------------------------
📘 Model 1 Results:
  1. lab results: serum creatinine 2.1 mg/dl, egfr 38 ml/min/1.73m2, uric acid 9.0 mg/dl.
  2. patient reports nocturia and decreased appetite.
  Precision: 0.500 | Recall: 0.500 | F1: 0.500
📙 Model 2 Results:
  1. lab results: serum creatinine 2.1 mg/dl, egfr 38 ml/min/1.73m2, uric acid 9.0 mg/dl.
  2. lab results: serum creatinine 2.1 mg/dl, egfr 38 ml/min/1.73m2, uric acid 9.0 mg/dl.
  Precision: 1.000 | Recall: 0.500 | F1: 0.667
🔎 Query: edema assessment
------------------------------------------------------------------------------------------
📘 Model 1 Results:
  1. exam shows 2+ pitting edema in lower extremities.
  2. medical history includes hypertension, hyperlipidemia, and gout.
  Precision: 0.000 | Recall: 0.000 | F1: 0.000
📙 Model 2 Results:
  1. exam sh

| Method                    | Type                  | Embeddings | Suitable for Short Texts | Notes                            |
| ------------------------- | --------------------- | ---------- | ------------------------ | -------------------------------- |
| LDA                       | Probabilistic         | No         | No                       | Classic baseline                 |
| NMF                       | Matrix Factorization  | No         | No                       | Fast, interpretable              |
| Neural Topic Models (VAE) | Neural Probabilistic  | Optional   | Yes                      | Powerful but complex             |
| Top2Vec                   | Embedding Clustering  | Yes        | Yes                      | No predefined topic number       |
| GSDMM                     | Probabilistic         | No         | Yes                      | Good for short texts             |
| CTM                       | Contextualized Neural | Yes (BERT) | Yes                      | State-of-art for semantic topics |
| LDA2Vec                   | Hybrid                | Yes        | No                       | Combines LDA & embeddings        |


In [12]:
#CTM method

# Install or upgrade cython, which gensim sometimes needs
!pip install --upgrade cython

# Now install gensim (specific stable version recommended)
!pip install gensim

# Finally install contextualized-topic-models (which depends on gensim)
!pip install contextualized-topic-models


Collecting contextualized-topic-models
  Using cached contextualized_topic_models-2.5.0-py2.py3-none-any.whl.metadata (24 kB)
Collecting gensim==4.2.0 (from contextualized-topic-models)
  Using cached gensim-4.2.0.tar.gz (23.2 MB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ipywidgets==7.5.1 (from contextualized-topic-models)
  Using cached ipywidgets-7.5.1-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting ipython==8.10.0 (from contextualized-topic-models)
  Using cached ipython-8.10.0-py3-none-any.whl.metadata (5.7 kB)
Collecting jedi>=0.16 (from ipython==8.10.0->contextualized-topic-models)
  Using cached jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting stack-data (from ipython==8.10.0->contextualized-topic-models)
  Using cached stack_data-0.6.3-py3-none-any.whl.metadata (18 kB)
Collecting widgetsnbextension~=3.5.0 (from ipywidgets==7.5.1->contextualized-topic-models)
  Using cached widgetsnbextension-3.5.2-py2.py3-none-any.whl.metadata (1.3 kB)
Colle

In [3]:
!apt-get install -y build-essential python3-dev
!pip install --upgrade pip setuptools wheel cython


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
build-essential is already the newest version (12.9ubuntu3).
python3-dev is already the newest version (3.10.6-1~22.04.1).
python3-dev set to manually installed.
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.


In [5]:
import gensim
print(gensim.__version__)


ImportError: cannot import name 'triu' from 'scipy.linalg' (/usr/local/lib/python3.11/dist-packages/scipy/linalg/__init__.py)

In [1]:
# === Shared Imports & Setup ===
import os
import random
import numpy as np
import torch
import nltk
import logging
from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from contextualized_topic_models.models.ctm import CombinedTM
from contextualized_topic_models.datasets.dataset import CTMDataset
from contextualized_topic_models.utils.data_preparation import TextHandler
import scann

# Reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# Download punkt tokenizer
nltk.download("punkt")


# === AllergyTopicSearcherModel1 with CTM instead of BERTopic ===
class AllergyTopicSearcherModel1:
    def __init__(self, chunks, manual_entities_per_chunk, model_name="emilyalsentzer/Bio_ClinicalBERT", num_topics=6):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model_name = model_name
        self.num_topics = num_topics

        self.embedding_model = None
        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        # SentenceTransformer model
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        entity_to_chunk = defaultdict(list)
        all_entities = []
        for idx, ents in enumerate(self.manual_entities_per_chunk):
            for ent in ents:
                ent_lower = ent.lower()
                all_entities.append(ent_lower)
                entity_to_chunk[ent_lower].append(idx)

        unique_entities = sorted(set(all_entities))
        entity_embeddings = self.embedding_model.encode(unique_entities, normalize_embeddings=True)

        # Prepare tokenized corpus for CTM
        text_handler = TextHandler()
        tokenized_entities = [text_handler.tokenize(ent) for ent in unique_entities]

        # Create CTM Dataset: bag of words + contextual embeddings
        ctm_dataset = CTMDataset(tokenized_entities, entity_embeddings)

        # Initialize CTM model
        self.topic_model = CombinedTM(
            bow_size=ctm_dataset.get_vocab_size(),
            contextual_size=entity_embeddings.shape[1],
            n_components=self.num_topics,
            num_epochs=300,
            train_embeddings=False,
            seed=SEED
        )

        # Train CTM on entities + embeddings
        self.topic_model.fit(ctm_dataset)

        # Get topic assignments (theta)
        topics_prob = self.topic_model.get_thetas(ctm_dataset)  # shape (num_entities, num_topics)
        topics = np.argmax(topics_prob, axis=1)

        # Organize entities per topic
        topic_to_entities = defaultdict(list)
        for ent, topic in zip(unique_entities, topics):
            topic_to_entities[topic].append(ent)

        # Collect sentences per topic from chunks based on entities
        topic_contexts = defaultdict(list)
        for topic, entities in topic_to_entities.items():
            for ent in entities:
                for chunk_id in entity_to_chunk[ent]:
                    for sent in sent_tokenize(self.chunks[chunk_id]):
                        if ent in sent.lower():
                            topic_contexts[topic].append(sent)

        # Remove duplicates in sentences
        for topic in topic_contexts:
            topic_contexts[topic] = list(set(topic_contexts[topic]))

        # Compute topic embeddings by averaging sentence embeddings
        topic_embeddings = []
        topic_metadata = []
        for topic_id, sentences in topic_contexts.items():
            if not sentences:
                continue
            sent_embs = self.embedding_model.encode(sentences, normalize_embeddings=True)
            mean_emb = np.mean(sent_embs, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": topic_to_entities[topic_id],
                "sentences": sentences,
                "sentence_embeddings": sent_embs
            })

        self.topic_embeddings = np.array(topic_embeddings)

        num_clusters = min(len(self.topic_embeddings), 5)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

        self.topic_metadata = topic_metadata

    def search(self, query, top_k_topics=1, top_k_sents=2):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        output = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            sents = meta["sentences"]
            sent_embs = meta["sentence_embeddings"]
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            output += [sents[i] for i in top_indices]
        return output


# === AllergyTopicSearcherModel2 with CTM instead of BERTopic ===
class AllergyTopicSearcherModel2:
    def __init__(self, chunks, manual_entities_per_chunk, model_name="emilyalsentzer/Bio_ClinicalBERT", num_topics=6):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model_name = model_name
        self.num_topics = num_topics

        self.embedding_model = None
        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    def _prepare(self):
        self.embedding_model = SentenceTransformer(self.embedding_model_name)

        entity_context_pairs = []
        entity_to_chunk = defaultdict(list)

        for idx, ents in enumerate(self.manual_entities_per_chunk):
            chunk = self.chunks[idx].lower()
            sentences = sent_tokenize(chunk)
            for ent in ents:
                ent_lower = ent.lower()
                for sent in sentences:
                    if ent_lower in sent:
                        entity_context_pairs.append((ent_lower, sent.strip()))
                        entity_to_chunk[ent_lower].append(idx)
                        break

        contextual_texts = [f"{ent}: {context}" for ent, context in entity_context_pairs]
        contextual_embeddings = self.embedding_model.encode(contextual_texts, normalize_embeddings=True)

        # Tokenize texts for CTM
        text_handler = TextHandler()
        tokenized_contexts = [text_handler.tokenize(text) for text in contextual_texts]

        ctm_dataset = CTMDataset(tokenized_contexts, contextual_embeddings)

        # Train CTM
        self.topic_model = CombinedTM(
            bow_size=ctm_dataset.get_vocab_size(),
            contextual_size=contextual_embeddings.shape[1],
            n_components=self.num_topics,
            num_epochs=300,
            train_embeddings=False,
            seed=SEED
        )

        self.topic_model.fit(ctm_dataset)

        # Get topic assignments
        topics_prob = self.topic_model.get_thetas(ctm_dataset)
        topics = np.argmax(topics_prob, axis=1)

        topic_to_contexts = defaultdict(list)
        topic_to_entities = defaultdict(set)
        topic_to_embeddings = defaultdict(list)

        for i, topic in enumerate(topics):
            ent, context = entity_context_pairs[i]
            topic_to_contexts[topic].append(context)
            topic_to_entities[topic].add(ent)
            topic_to_embeddings[topic].append(contextual_embeddings[i])

        topic_embeddings = []
        topic_metadata = []

        for topic_id in topic_to_contexts:
            embeddings = topic_to_embeddings[topic_id]
            mean_emb = np.mean(embeddings, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10

            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": list(topic_to_entities[topic_id]),
                "sentences": topic_to_contexts[topic_id],
                "sentence_embeddings": np.array(topic_to_embeddings[topic_id])
            })

        self.topic_embeddings = np.array(topic_embeddings)

        num_clusters = min(len(self.topic_embeddings), 5)
        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=num_clusters, num_leaves_to_search=2, training_sample_size=len(self.topic_embeddings))
            .score_brute_force()
            .reorder(3)
            .build()
        )

        self.topic_metadata = topic_metadata

    def search(self, query, top_k_topics=1, top_k_sents=2):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        output = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            seen = set()
            sents = []
            embs = []
            for sent, emb in zip(meta["sentences"], meta["sentence_embeddings"]):
                if sent not in seen:
                    seen.add(sent)
                    sents.append(sent)
                    embs.append(emb)
            sent_embs = np.array(embs)
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            output += [sents[i] for i in top_indices]
        return output


# === Evaluation ===
chunks = [
    "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis.",
    "Allergic rhinitis, commonly known as hay fever, is an allergic response to pollen, dust, or pet dander.",
    "Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly.",
    "Patients with food allergies, such as milk or eggs, need to be careful with their diet.",
    "Skin reactions like urticaria (hives) and eczema are often signs of allergies.",
    "He walks in cold weather but has no allergy symptoms or reactions."
]

manual_entities_per_chunk = [
    ["peanut allergy", "hives", "swelling", "anaphylaxis"],
    ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
    ["anaphylaxis", "allergic reaction"],
    ["food allergies", "milk", "eggs"],
    ["urticaria", "hives", "eczema", "allergies"],
    ["cold weather", "allergy symptoms", "reactions"]
]

queries = [
    "peanut allergy",
    "symptoms of anaphylaxis",
    "hay fever",
    "eczema treatment",
    "allergic reaction to milk",
    "signs of food allergy",
    "urticaria causes",
    "pet dander allergies",
    "cold weather allergy",
    "hives and swelling"
]

print("🔧 Initializing Model 1 (CTM)...")
m1 = AllergyTopicSearcherModel1(chunks, manual_entities_per_chunk)

print("🔧 Initializing Model 2 (CTM)...")
m2 = AllergyTopicSearcherModel2(chunks, manual_entities_per_chunk)

print("\n\n🔍 Starting Evaluation")
for q in queries:
    res1 = m1.search(q, top_k_sents=2)
    res2 = m2.search(q, top_k_sents=2)

    print(f"\n\n🔎 Query: {q}")
    print("-" * 90)
    print("📘 Model 1:")
    for i, r in enumerate(res1, 1):
        print(f"  {i}. {r}")
    print("📙 Model 2:")
    for i, r in enumerate(res2, 1):
        print(f"  {i}. {r}")
    print("=" * 90)


ModuleNotFoundError: No module named 'contextualized_topic_models'

In [7]:
#topic2vec
!pip install top2vec[sentence_encoders]


Collecting top2vec[sentence_encoders]
  Downloading top2vec-1.0.36-py3-none-any.whl.metadata (22 kB)
Downloading top2vec-1.0.36-py3-none-any.whl (33 kB)
Installing collected packages: top2vec
Successfully installed top2vec-1.0.36


In [1]:
!pip install --upgrade scipy
!pip install --upgrade gensim
!pip install --upgrade top2vec[sentence_encoders]

Collecting gensim
  Using cached gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting scipy<1.14.0,>=1.7.0 (from gensim)
  Using cached scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Using cached gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m92.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.6/38.6 MB[0m [31m60.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy, scipy, g



In [1]:
from top2vec import Top2Vec



stdout:



stderr:

Traceback (most recent call last):
  File "<string>", line 4, in <module>
  File "/usr/local/lib/python3.11/dist-packages/numba_cuda/numba/cuda/cudadrv/driver.py", line 314, in __getattr__
    raise CudaSupportError("Error at driver init: \n%s:" %
numba.cuda.cudadrv.error.CudaSupportError: Error at driver init: 

CUDA driver library cannot be found.
If you are sure that a CUDA driver is installed,
try setting environment variable NUMBA_CUDA_DRIVER
with the file path of the CUDA driver shared library.
:


Not patching Numba


In [10]:
# === Full Robust Top2Vec Code with Two Models and Evaluation ===

import os
import random
import numpy as np
import torch
import nltk
import logging
from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from top2vec import Top2Vec
import scann

# Setup logging
logging.basicConfig(level=logging.INFO)

# Reproducibility
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Download punkt tokenizer
nltk.download("punkt")

# === Model 1 ===
class AllergyTopicSearcherTop2VecModel1:
    def __init__(self, chunks, manual_entities_per_chunk, embedding_model_name="all-MiniLM-L6-v2"):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model = SentenceTransformer(embedding_model_name)
        self.topic_embeddings = None
        self.topic_metadata = []
        self.searcher = None
        self._prepare()

    def _prepare(self):
        entity_to_chunk = defaultdict(list)
        all_entities = []
        for idx, ents in enumerate(self.manual_entities_per_chunk):
            for ent in ents:
                ent_lower = ent.lower()
                all_entities.append(ent_lower)
                entity_to_chunk[ent_lower].append(idx)
        unique_entities = sorted(set(all_entities))

        if not unique_entities:
            raise ValueError("No unique entities found. Check input data.")

        topic_model = Top2Vec(
            documents=unique_entities,
            embedding_model="universal-sentence-encoder",
            min_count=1,
            verbose=False
        )

        topic_words, _, topic_nums = topic_model.get_topics()

        topic_to_entities = defaultdict(list)
        for i, topic in enumerate(topic_nums):
            topic_to_entities[topic].append(unique_entities[i])

        topic_contexts = defaultdict(list)
        for topic, entities in topic_to_entities.items():
            for ent in entities:
                for chunk_id in entity_to_chunk[ent]:
                    for sent in sent_tokenize(self.chunks[chunk_id]):
                        if ent in sent.lower():
                            topic_contexts[topic].append(sent)

        topic_embeddings = []
        topic_metadata = []

        for topic_id, sents in topic_contexts.items():
            sents = list(set(sents))
            if not sents:
                continue
            sent_embs = self.embedding_model.encode(sents, normalize_embeddings=True)
            mean_emb = np.mean(sent_embs, axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": topic_id,
                "entities": topic_to_entities[topic_id],
                "sentences": sents,
                "sentence_embeddings": sent_embs
            })

        if not topic_embeddings:
            raise ValueError("No topic embeddings found in Model 1.")

        self.topic_embeddings = np.vstack(topic_embeddings)
        self.topic_metadata = topic_metadata

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=min(5, len(self.topic_embeddings)), num_leaves_to_search=2)
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=2):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        output = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            sent_embs = meta["sentence_embeddings"]
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            output += [meta["sentences"][i] for i in top_indices]
        return output


# === Model 2 ===
class AllergyTopicSearcherTop2VecModel2:
    def __init__(self, chunks, manual_entities_per_chunk, embedding_model_name="all-MiniLM-L6-v2"):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model = SentenceTransformer(embedding_model_name)
        self.topic_embeddings = None
        self.topic_metadata = []
        self.searcher = None
        self._prepare()

    def _prepare(self):
        entity_context_pairs = []
        for idx, ents in enumerate(self.manual_entities_per_chunk):
            chunk = self.chunks[idx].lower()
            for sent in sent_tokenize(chunk):
                for ent in ents:
                    ent_lower = ent.lower()
                    if ent_lower in sent:
                        entity_context_pairs.append((ent_lower, sent.strip()))
                        break

        contextual_texts = [f"{ent}: {ctx}" for ent, ctx in entity_context_pairs]

        if not contextual_texts:
            raise ValueError("No contextual texts found. Check entity-chunk pairings.")

        topic_model = Top2Vec(
            documents=contextual_texts,
            embedding_model="universal-sentence-encoder",
            min_count=1,
            verbose=False
        )

        topic_words, _, topic_nums = topic_model.get_topics()

        topic_to_data = defaultdict(lambda: {"contexts": [], "ents": [], "embs": []})
        for i, topic in enumerate(topic_nums):
            ent, ctx = entity_context_pairs[i]
            topic_to_data[topic]["contexts"].append(ctx)
            topic_to_data[topic]["ents"].append(ent)
            topic_to_data[topic]["embs"].append(self.embedding_model.encode(f"{ent}: {ctx}", normalize_embeddings=True))

        topic_embeddings = []
        topic_metadata = []

        for tid, data in topic_to_data.items():
            if not data["embs"]:
                continue
            mean_emb = np.mean(data["embs"], axis=0)
            mean_emb /= np.linalg.norm(mean_emb) + 1e-10
            topic_embeddings.append(mean_emb)
            topic_metadata.append({
                "topic_id": tid,
                "entities": list(set(data["ents"])),
                "sentences": list(set(data["contexts"])),
                "sentence_embeddings": np.array(data["embs"])
            })

        if not topic_embeddings:
            raise ValueError("No topic embeddings found in Model 2.")

        self.topic_embeddings = np.vstack(topic_embeddings)
        self.topic_metadata = topic_metadata

        self.searcher = (
            scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
            .tree(num_leaves=min(5, len(self.topic_embeddings)), num_leaves_to_search=2)
            .score_brute_force()
            .reorder(3)
            .build()
        )

    def search(self, query, top_k_topics=1, top_k_sents=2):
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]
        neighbors, _ = self.searcher.search(query_emb, final_num_neighbors=top_k_topics)

        output = []
        for idx in neighbors:
            meta = self.topic_metadata[idx]
            sent_embs = meta["sentence_embeddings"]
            sent_embs_norm = sent_embs / np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]
            output += [meta["sentences"][i] for i in top_indices]
        return output


# === Evaluation ===
chunks = [
    "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis.",
    "Allergic rhinitis, commonly known as hay fever, is an allergic response to pollen, dust, or pet dander.",
    "Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly.",
    "Patients with food allergies, such as milk or eggs, need to be careful with their diet.",
    "Skin reactions like urticaria (hives) and eczema are often signs of allergies.",
    "He walks in cold weather but has no allergy symptoms or reactions."
]

manual_entities_per_chunk = [
    ["peanut allergy", "hives", "swelling", "anaphylaxis"],
    ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"],
    ["anaphylaxis", "allergic reaction"],
    ["food allergies", "milk", "eggs"],
    ["urticaria", "hives", "eczema", "allergies"],
    ["cold weather", "allergy symptoms", "reactions"]
]

queries = [
    "peanut allergy",
    "symptoms of anaphylaxis",
    "hay fever",
    "eczema treatment",
    "allergic reaction to milk",
    "signs of food allergy",
    "urticaria causes",
    "pet dander allergies",
    "cold weather allergy",
    "hives and swelling"
]

print("🔧 Initializing Model 1...")
m1 = AllergyTopicSearcherTop2VecModel1(chunks, manual_entities_per_chunk)

print("🔧 Initializing Model 2...")
m2 = AllergyTopicSearcherTop2VecModel2(chunks, manual_entities_per_chunk)

print("\\n🔍 Evaluating Queries:")
for q in queries:
    print(f"\\n🧪 Query: {q}")
    print("Model 1:", m1.search(q))
    print("Model 2:", m2.search(q))


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


🔧 Initializing Model 1...


2025-07-06 15:10:23,258 - top2vec - INFO - Downloading universal-sentence-encoder model
INFO:top2vec:Downloading universal-sentence-encoder model


ValueError: need at least one array to concatenate