In [None]:
# === Install dependencies ===
# Notebook shell escape: installs the third-party packages imported by the
# cells below (BigQuery client, pandas, ScaNN, HDBSCAN, UMAP, NLTK,
# sentence-transformers, BERTopic). `--quiet` suppresses most pip output.
!pip install google-cloud-bigquery pandas scann hdbscan umap-learn nltk sentence-transformers bertopic --quiet
import nltk
# Fetch the "punkt" resource through the NLTK downloader.
nltk.download('punkt')



In [None]:
# === Imports ===
import os
import random
import numpy as np
import torch
import nltk
import logging
import pandas as pd
from collections import defaultdict
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from hdbscan import HDBSCAN
from umap import UMAP
import scann
from google.cloud import bigquery
import ast
import json

# === Set random seed for reproducibility ===
SEED = 42

# NOTE: PYTHONHASHSEED is only read at interpreter start-up, so this assignment
# is inherited by child processes but does not change hash randomization of
# the already-running interpreter.
os.environ["PYTHONHASHSEED"] = str(SEED)
# Export the cuBLAS workspace setting before any torch/CUDA call so it is
# already in place whenever cuBLAS gets initialised.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.use_deterministic_algorithms(True)

# === Setup logging ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# === Download tokenizer ===
# `sent_tokenize` loads "punkt" on older NLTK releases and "punkt_tab" on
# recent ones (3.8.2+). NLTK is installed unpinned, so fetch both resources;
# a download that is not needed is harmless.
nltk.download("punkt")
nltk.download("punkt_tab")


# === Function to read and normalize BigQuery data ===
def parse_entities(val):
    """Normalize one raw ``entities`` cell into a list of clean strings.

    Accepted inputs:
      * ``list`` / ``tuple`` / ``numpy.ndarray`` -- an ARRAY column value
        (BigQuery's ``to_dataframe()`` usually hands ARRAY columns back as
        NumPy arrays, not Python lists);
      * ``str`` holding a JSON array (``'["a", "b"]'``), a Python literal
        (``"['a', 'b']"``) or a comma-separated list (``"a, b"``);
      * anything else (``None``, NaN, numbers) -- treated as "no entities".

    Returns:
        A list of stripped, non-empty strings. ``None`` elements are dropped.
    """
    if isinstance(val, np.ndarray):
        items = val.tolist()
    elif isinstance(val, (list, tuple)):
        items = list(val)
    elif isinstance(val, str):
        items = _parse_entity_string(val.strip())
    else:
        return []

    cleaned = []
    for item in items:
        if item is None:
            continue
        text = str(item).strip()
        if text:
            cleaned.append(text)
    return cleaned


def _parse_entity_string(text):
    """Decode a string-encoded entity list; fall back to comma splitting."""
    try:
        # Try JSON first (e.g., '["a", "b"]')
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, list):
        return parsed

    try:
        # Try Python literal (e.g., "['a', 'b']")
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        return list(parsed)

    # Fallback: comma-separated string
    return text.split(',')


def load_data_from_bigquery(project_id, dataset_id, table_id):
    """Read the ``chunk`` and ``entities`` columns of a BigQuery table.

    Args:
        project_id: Project that owns the table.
        dataset_id: Dataset that contains the table.
        table_id: Table name.

    Returns:
        ``(chunks, manual_entities_per_chunk)``: two lists of equal length.
        ``chunks[i]`` is the text of row *i* (NULL becomes ``""``) and
        ``manual_entities_per_chunk[i]`` is its entity list as produced by
        ``parse_entities``.
    """
    client = bigquery.Client()
    full_table_path = f"{project_id}.{dataset_id}.{table_id}"

    # Table identifiers cannot be bound as query parameters, so the path is
    # interpolated. NOTE(review): only pass trusted identifiers here.
    query = f"""
    SELECT chunk, entities
    FROM `{full_table_path}`
    """
    df = client.query(query).to_dataframe()

    # Final output format
    chunks = df['chunk'].fillna("").tolist()
    manual_entities_per_chunk = [parse_entities(val) for val in df['entities']]
    return chunks, manual_entities_per_chunk


# === AllergyTopicSearcher class ===
class AllergyTopicSearcher:
    """Topic-level semantic search over text chunks annotated with entities.

    On construction the entities are clustered into topics with BERTopic
    (UMAP + HDBSCAN). For every topic, the sentences of the source chunks
    that mention one of its entities are embedded and averaged into a single
    topic vector, and the topic vectors are indexed with ScaNN. ``search``
    then retrieves the closest topics for a query and prints their most
    similar sentences.

    Args:
        chunks: List of text chunks.
        manual_entities_per_chunk: ``manual_entities_per_chunk[i]`` is the
            list of entity strings annotated for ``chunks[i]``.
        model_name: Name passed to ``SentenceTransformer``.
    """

    def __init__(self, chunks, manual_entities_per_chunk, model_name="emilyalsentzer/Bio_ClinicalBERT"):
        self.chunks = chunks
        self.manual_entities_per_chunk = manual_entities_per_chunk
        self.embedding_model_name = model_name

        self.embedding_model = None
        self.topic_model = None
        self.topic_metadata = []
        self.topic_embeddings = None
        self.searcher = None

        self._prepare()

    @staticmethod
    def _dedupe(items):
        """Return *items* without duplicates, keeping first-seen order.

        Unlike ``list(set(items))`` the result does not depend on string hash
        randomization, so repeated runs produce the same ordering.
        """
        return list(dict.fromkeys(items))

    def _index_entities(self):
        """Map each lower-cased entity to the ids of the chunks that list it.

        Non-string and blank entities are skipped.

        Returns:
            ``(unique_entities, entity_to_chunk)`` where ``unique_entities``
            is sorted and ``entity_to_chunk[entity]`` is a list of chunk
            indices in ascending order without repeats.
        """
        entity_to_chunk = defaultdict(list)
        for idx, ents in enumerate(self.manual_entities_per_chunk):
            for ent in ents or ():
                if not isinstance(ent, str):
                    continue
                key = ent.strip().lower()
                if not key:
                    continue
                chunk_ids = entity_to_chunk[key]
                # idx only grows, so comparing with the last id is enough to
                # avoid listing the same chunk twice.
                if not chunk_ids or chunk_ids[-1] != idx:
                    chunk_ids.append(idx)
        return sorted(entity_to_chunk), entity_to_chunk

    def _prepare(self):
        """Build the topic model, the topic embeddings and the ScaNN index.

        Raises:
            ValueError: If there are more entity lists than chunks, no usable
                entity exists, or no topic embedding could be generated.
            Exception: Any error raised by the underlying libraries is logged
                and re-raised unchanged.
        """
        try:
            if len(self.manual_entities_per_chunk) > len(self.chunks):
                raise ValueError("More entity lists than chunks were provided.")

            unique_entities, entity_to_chunk = self._index_entities()
            logger.info(f"Unique entities found: {len(unique_entities)}")
            if not unique_entities:
                raise ValueError("No entities provided; cannot build topics.")

            logger.info("Loading embedding model...")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)

            entity_embeddings = self.embedding_model.encode(unique_entities, normalize_embeddings=True)

            umap_model = UMAP(n_neighbors=15, n_components=5, metric='cosine', random_state=SEED)
            hdbscan_model = HDBSCAN(min_cluster_size=2, min_samples=1, metric='euclidean', prediction_data=True)

            self.topic_model = BERTopic(
                embedding_model=self.embedding_model,
                umap_model=umap_model,
                hdbscan_model=hdbscan_model,
                representation_model=KeyBERTInspired(),
                calculate_probabilities=True,
                verbose=False
            )

            topics, _ = self.topic_model.fit_transform(unique_entities, embeddings=entity_embeddings)

            topic_to_entities = defaultdict(list)
            for ent, topic in zip(unique_entities, topics):
                topic_to_entities[topic].append(ent)

            # Tokenize (and lower-case) every chunk at most once, no matter
            # how many entities reference it.
            sentence_cache = {}
            topic_contexts = defaultdict(list)
            for topic, entities in topic_to_entities.items():
                for ent in entities:
                    for chunk_id in entity_to_chunk[ent]:
                        if chunk_id not in sentence_cache:
                            sentence_cache[chunk_id] = [
                                (sent, sent.lower())
                                for sent in sent_tokenize(self.chunks[chunk_id])
                            ]
                        for sent, sent_lower in sentence_cache[chunk_id]:
                            if ent in sent_lower:
                                topic_contexts[topic].append(sent)

            topic_embeddings = []
            topic_metadata = []

            for topic_id, raw_sentences in topic_contexts.items():
                sentences = self._dedupe(raw_sentences)
                if not sentences:
                    continue
                sent_embs = self.embedding_model.encode(sentences, normalize_embeddings=True)
                # Topic vector = re-normalized mean of its sentence vectors.
                mean_emb = np.mean(sent_embs, axis=0)
                mean_emb /= np.linalg.norm(mean_emb) + 1e-10
                topic_embeddings.append(mean_emb)
                topic_metadata.append({
                    "topic_id": topic_id,
                    "entities": topic_to_entities[topic_id],
                    "sentences": sentences,
                    "sentence_embeddings": sent_embs
                })

            self.topic_embeddings = np.array(topic_embeddings)

            num_topics = len(self.topic_embeddings)
            if num_topics == 0:
                raise ValueError("No topic embeddings generated.")

            # Never ask ScaNN to search more leaves than the tree has.
            num_leaves = min(num_topics, 5)
            self.searcher = (
                scann.scann_ops_pybind.builder(self.topic_embeddings, 3, "dot_product")
                .tree(
                    num_leaves=num_leaves,
                    num_leaves_to_search=min(2, num_leaves),
                    training_sample_size=num_topics,
                )
                .score_brute_force()
                .reorder(3)
                .build()
            )

            self.topic_metadata = topic_metadata
            logger.info("Preparation complete. Ready to search.")

        except Exception as e:
            # logger.exception keeps the traceback next to the message.
            logger.exception("Error during preparation: %s", e)
            raise

    def search(self, query, top_k_topics=1, top_k_sents=3):
        """Print the topics closest to *query* and their best sentences.

        Args:
            query: Free-text query string.
            top_k_topics: Number of topics to retrieve; capped at the number
                of indexed topics.
            top_k_sents: Maximum number of sentences printed per topic.

        Raises:
            RuntimeError: If the index has not been built.
        """
        if self.searcher is None or not self.topic_metadata:
            raise RuntimeError("Searcher not initialized. Call _prepare() first.")

        print(f"\n🔎 Query: '{query}'\n{'=' * 60}")
        query_emb = self.embedding_model.encode([query], normalize_embeddings=True)[0]

        # Do not request more neighbours than there are topics in the index.
        num_neighbors = min(top_k_topics, len(self.topic_metadata))
        neighbors, scores = self.searcher.search(query_emb, final_num_neighbors=num_neighbors)

        for rank, (idx, score) in enumerate(zip(neighbors, scores), 1):
            meta = self.topic_metadata[idx]
            print(f"\nRank {rank} [Score: {score:.4f}]")
            print(f"  Topic ID   : {meta['topic_id']}")
            print(f"  Entities   : {meta['entities']}")

            sents = meta["sentences"]
            sent_embs = meta["sentence_embeddings"]
            # Guard the normalization against zero-length vectors.
            norms = np.linalg.norm(sent_embs, axis=1, keepdims=True)
            sent_embs_norm = sent_embs / np.maximum(norms, 1e-10)
            sims = np.dot(sent_embs_norm, query_emb)
            top_indices = sims.argsort()[::-1][:top_k_sents]

            print(f"  Top {top_k_sents} Relevant Sentences:")
            for i in top_indices:
                print(f"    - {sents[i]} (score: {sims[i]:.4f})")


# === Entry point ===
# Demo run: load chunks and entities from BigQuery, build the searcher
# (construction runs the full preparation step), then issue one sample query.
if __name__ == "__main__":
    # 🔁 Replace with your actual values
    project_id = "your_project"
    dataset_id = "your_dataset"
    table_id = "chunk_entities_table"

    # Two parallel lists: chunk texts and the entity list of each chunk.
    chunks, manual_entities_per_chunk = load_data_from_bigquery(project_id, dataset_id, table_id)

    # Uses the default embedding model declared on AllergyTopicSearcher.__init__.
    searcher = AllergyTopicSearcher(chunks, manual_entities_per_chunk)

    # Test it with a query; results are printed to stdout.
    searcher.search("What are skin allergy symptoms?")


In [None]:
-- Step 1: Create the table with ARRAY<STRING> for entities
-- (same table name and columns that the Python loader selects: chunk, entities)
CREATE TABLE `your_project.your_dataset.chunk_entities_table` (
  chunk STRING,
  entities ARRAY<STRING>
);

-- Step 2: Insert example chunks and associated entities
-- Each row pairs one text chunk with the entities annotated for it.
INSERT INTO `your_project.your_dataset.chunk_entities_table` (chunk, entities)
VALUES
(
  "Peanut allergy is one of the most common causes of severe allergic reactions. Symptoms can include hives, swelling, and anaphylaxis.",
  ["peanut allergy", "hives", "swelling", "anaphylaxis"]
),
(
  "Allergic rhinitis, commonly known as hay fever, is an allergic response to pollen, dust, or pet dander.",
  ["allergic rhinitis", "hay fever", "pollen", "dust", "pet dander"]
),
(
  "Anaphylaxis is a serious, potentially life-threatening allergic reaction that can occur rapidly.",
  ["anaphylaxis", "allergic reaction"]
),
(
  "Patients with food allergies, such as milk or eggs, need to be careful with their diet.",
  ["food allergies", "milk", "eggs"]
),
(
  "Skin reactions like urticaria (hives) and eczema are often signs of allergies.",
  ["urticaria", "hives", "eczema", "allergies"]
),
(
  "He walks in cold weather but has no allergy symptoms or reactions.",
  ["cold weather", "allergy symptoms", "reactions"]
);







# Recommended BigQuery schema:

| Column Name | Type                                        | Description                                              |
| ----------- | ------------------------------------------- | -------------------------------------------------------- |
| `chunk`     | `STRING`                                    | The full chunk of text (a paragraph or document snippet) |
| `entities`  | `ARRAY<STRING>` (preferred) **or** `STRING` | The list of entities found in the chunk                  |
