In [None]:
# Google Cloud Document AI client
!pip install google-cloud-documentai

# Google Cloud BigQuery (for reading CUI vectors)
!pip install pandas-gbq google-cloud-bigquery

# Numpy
!pip install numpy

# Google GenAI SDK (for Vertex AI embeddings)
!pip install google-genai

# Optional but useful: pandas (for dataframe handling)
!pip install pandas


In [None]:
# Updated code 1 with 3 CUI embeddings: from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any
import numpy as np
import re
from datetime import datetime
import os
import pandas_gbq

# ----------------------------
# CONFIG
# ----------------------------
# Placeholders: replace with your own project / table before running.
PROJECT_ID = "YOUR_GCP_PROJECT"
LOCATION = "global"
# Expected schema: cui STRING, embedding ARRAY<FLOAT64>
BQ_CUI_TABLE = "your_project.your_dataset.cui_embeddings"

# Route the google-genai SDK to Vertex AI for this project/location.
os.environ.update({
    "GOOGLE_CLOUD_PROJECT": PROJECT_ID,
    "GOOGLE_CLOUD_LOCATION": LOCATION,
    "GOOGLE_GENAI_USE_VERTEXAI": "True",
})

# ----------------------------
# Vertex AI Client
# ----------------------------
# google-genai SDK. Called with no arguments, genai.Client() is expected to read its
# Vertex AI settings from the GOOGLE_GENAI_USE_VERTEXAI / GOOGLE_CLOUD_PROJECT /
# GOOGLE_CLOUD_LOCATION environment variables set in the CONFIG section (per the
# SDK's documented env-var configuration).
from google import genai
from google.genai.types import EmbedContentConfig
# Module-level client used by gemini_embed(); it is created once, when this cell runs.
_client = genai.Client()

def gemini_embed(
    texts: List[str],
    *,
    task_type: str = "RETRIEVAL_DOCUMENT",
    batch_size: int = 1,
) -> np.ndarray:
    """Embed ``texts`` with the Vertex AI ``gemini-embedding-001`` model.

    Args:
        texts: Strings to embed. Output rows follow the same order.
        task_type: Embedding task type passed to ``EmbedContentConfig``
            (e.g. ``"RETRIEVAL_DOCUMENT"`` for corpus items,
            ``"RETRIEVAL_QUERY"`` for search queries).
        batch_size: Maximum number of texts per ``embed_content`` request.
            NOTE(review): Vertex AI has documented gemini-embedding-001 as
            accepting a single input per request. Confirm against current
            quotas and raise this value for models that accept batches.

    Returns:
        A float32 array with one row per input text. It is ``(0, 0)`` when
        ``texts`` is empty, because the embedding width is unknown without a call.

    Raises:
        ValueError: If ``batch_size`` < 1.
        RuntimeError: If the service returns a different number of embeddings
            than texts sent. Otherwise callers that ``zip`` results with their
            inputs would silently misalign.
    """
    if not texts:
        return np.zeros((0, 0), dtype=np.float32)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    cfg = EmbedContentConfig(task_type=task_type)
    rows: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        chunk = list(texts[start:start + batch_size])
        resp = _client.models.embed_content(
            model="gemini-embedding-001",
            contents=chunk,
            config=cfg,
        )
        got = list(resp.embeddings or [])
        if len(got) != len(chunk):
            raise RuntimeError(
                f"embed_content returned {len(got)} embeddings for {len(chunk)} texts"
            )
        rows.extend(e.values for e in got)
    return np.array(rows, dtype=np.float32)

# ----------------------------
# BigQuery: load CUI embeddings
# ----------------------------
def load_cui_vectors_bq(table_fqn: str = BQ_CUI_TABLE) -> Dict[str, np.ndarray]:
    """Load a ``cui -> embedding`` mapping from a BigQuery table.

    Args:
        table_fqn: Fully-qualified table name (``project.dataset.table``) with
            columns ``cui`` and ``embedding``. It is interpolated into the SQL
            as an identifier, so it must come from trusted configuration, not
            user input.

    Returns:
        A dict mapping each CUI string to a 1-D float32 vector. Rows are skipped
        when the CUI is missing or not a string, or when the embedding is missing
        or empty. BigQuery returns NULL arrays as empty arrays, and a 0-length or
        0-d vector would make later ``np.dot`` calls fail. If a CUI appears more
        than once, the last row wins.
    """
    df = pandas_gbq.read_gbq(f"SELECT cui, embedding FROM `{table_fqn}`", project_id=PROJECT_ID)
    out: Dict[str, np.ndarray] = {}
    # Column-wise zip avoids building a Series per row (as iterrows does).
    for cui, raw in zip(df["cui"], df["embedding"]):
        if not isinstance(cui, str) or not cui:
            continue
        vec = np.asarray(raw, dtype=np.float32)
        # None/NaN become 0-d arrays; NULL/empty arrays become size-0 arrays.
        if vec.ndim != 1 or vec.size == 0:
            continue
        out[cui] = vec
    return out

CUI_VECTORS = load_cui_vectors_bq(BQ_CUI_TABLE)

# ----------------------------
# Dummy UMLS linker — replace with your logic
# ----------------------------
def umls_link(text: str) -> List[Dict[str, Any]]:
    """Placeholder UMLS concept linker; replace with a real implementation.

    A real linker should return candidate dicts such as
    ``[{"cui": "C0001449", "score": 0.97}, {"cui": "C0002336", "score": 0.88}]``.
    Callers in this notebook read the ``"cui"`` key and an optional ``"score"`` key.

    This stub ignores ``text`` and always reports no candidates.
    """
    candidates: List[Dict[str, Any]] = []
    return candidates

# ----------------------------
# Date handling
# ----------------------------
# Month-first date, "/" or "-" separated, 2-4 digit year; optionally followed by a
# time such as "10:30" / "10.30" (captured in group 2, currently unused).
DATE_REGEX = re.compile(r"\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})(?:[ ,;]+(\d{1,2}[:.]\d{2}))?\b")

# US month-first formats, tried in order (4-digit year before 2-digit year).
_DATE_FORMATS = ("%m/%d/%Y", "%m/%d/%y", "%m-%d-%Y", "%m-%d-%y")

def norm_date(s: str) -> Optional[str]:
    """Return the first valid month-first date in ``s`` as ``"YYYY-MM-DD"``.

    Every ``DATE_REGEX`` match is tried in order. A match that is not a real
    calendar date (e.g. ``"13/45/2020"``) no longer hides a later valid one.

    Args:
        s: Free text that may contain a date. ``None`` or empty returns ``None``.

    Returns:
        The ISO date string, or ``None`` if no match parses.
    """
    if not s:
        return None
    # Same normalisation as before: "." -> ":" (time separator), "," -> " ".
    cleaned = s.strip().replace('.', ':').replace(',', ' ')
    for match in DATE_REGEX.finditer(cleaned):
        raw = match.group(1)
        for fmt in _DATE_FORMATS:
            try:
                return datetime.strptime(raw, fmt).strftime("%Y-%m-%d")
            except ValueError:
                continue
    return None

# ----------------------------
# Entity class (UPDATED)
# ----------------------------
@dataclass
class Entity:
    """One piece of content extracted from a Document AI page.

    ``parse_docai_entities`` creates these entities and
    ``attach_cuis_and_embeddings_local`` fills the CUI/embedding fields in place.

    NOTE(review): the dataclass-generated ``__eq__`` compares the ndarray fields
    (``embeddings`` / ``emb``) with ``==``. Comparing two distinct entities that
    carry vectors will likely raise "truth value of an array is ambiguous".
    """
    entity_id: str                     # "<prefix>:<page>:<n>" with prefix kv / date / sec / datei
    page: int                          # zero-based page index (enumerate order of document pages)
    text: str                          # text that is linked/embedded and returned as "evidence"
    kind: str                          # "date" | "kv" | "section" | "other"
    section_title: Optional[str]       # upper-cased current heading ("" if none yet) for line entities; None for form-field entities
    value_norm_date: Optional[str]     # "YYYY-MM-DD" from norm_date(), or None when no date parses

    # NEW: store multiple CUIs + embeddings
    # Top-k umls_link() candidates, best first; cui_scores is parallel to cui_candidates.
    cui_candidates: List[str] = field(default_factory=list)
    cui_scores: List[float] = field(default_factory=list)
    # Vectors of the candidates present in CUI_VECTORS (candidate order), or a single
    # text-embedding fallback. NOT index-aligned with cui_candidates when some CUIs lack a vector.
    embeddings: List[np.ndarray] = field(default_factory=list)

    # Backward compatibility: best single values
    cui: Optional[str] = None          # cui_candidates[0]
    cui_score: Optional[float] = None  # cui_scores[0]
    emb: Optional[np.ndarray] = None   # embeddings[0]; the vector used for nearest-neighbour search and graph edges

# ----------------------------
# Parse DocAI entities
# ----------------------------
def _anchor_text(layout: Optional[Dict], full_text: str) -> str:
    """Return the text of a DocAI layout-like dict (``Layout``, ``fieldName``, ``fieldValue``).

    ``textAnchor.content`` is preferred, then a top-level ``content``. Document AI
    does not guarantee ``content`` for every layout; its reference only promises it
    for form fields. When it is empty, the text is rebuilt from the
    ``textAnchor.textSegments`` index ranges into ``document.text``.
    """
    if not layout:
        return ""
    anchor = layout.get("textAnchor") or {}
    content = anchor.get("content") or layout.get("content") or ""
    if content or not full_text:
        return content
    pieces: List[str] = []
    for seg in anchor.get("textSegments") or []:
        # In JSON exports int64 indices are strings, and a zero startIndex is omitted.
        try:
            start = int(seg.get("startIndex") or 0)
            end = int(seg.get("endIndex") or 0)
        except (TypeError, ValueError):
            continue
        if 0 <= start < end:
            pieces.append(full_text[start:end])
    return "".join(pieces)


def parse_docai_entities(docai: Dict) -> List[Entity]:
    """Extract KV, section-line and date entities from a Document AI result dict.

    Args:
        docai: Either a wrapper with a top-level ``"document"`` key or a bare
            ``Document`` dict, using camelCase JSON field names
            (``pages``, ``formFields``, ``lines``, ``textAnchor``, ...).

    Returns:
        Entities in page order. On each page:

        * Each non-empty form field becomes a ``"kv"`` entity (``"name: value"``),
          plus a ``"date"`` entity when that text contains a date.
        * Each line that ends with ``":"`` or is all upper-case becomes the current
          heading.
        * Every other line becomes a ``"section"`` entity under that heading,
          plus one ``"date"`` entity per date found in the line.

        Ids embed the page and the running entity count, so they are unique
        within one call.
    """
    entities: List[Entity] = []
    document = docai.get("document", docai) or {}
    full_text = document.get("text") or ""
    for p_idx, page in enumerate(document.get("pages") or []):
        # KVs
        for ff in page.get("formFields") or []:
            name = _anchor_text(ff.get("fieldName"), full_text).strip()
            value = _anchor_text(ff.get("fieldValue"), full_text).strip()
            kv_text = f"{name}: {value}".strip(": ")
            if not kv_text:
                continue  # neither name nor value resolved: nothing to link or embed
            kv_date = norm_date(kv_text)
            entities.append(Entity(
                entity_id=f"kv:{p_idx}:{len(entities)}",
                page=p_idx, text=kv_text, kind="kv",
                section_title=None, value_norm_date=kv_date,
            ))
            if DATE_REGEX.search(kv_text):
                entities.append(Entity(
                    entity_id=f"date:{p_idx}:{len(entities)}",
                    page=p_idx, text=kv_text, kind="date",
                    section_title=None, value_norm_date=kv_date,
                ))

        # Headings + paragraphs
        heading = None
        for line in page.get("lines") or []:
            # Strip first: DocAI line text usually carries a trailing newline, which
            # would otherwise defeat the endswith(":") heading test.
            t = _anchor_text(line.get("layout"), full_text).strip()
            if not t:
                continue
            if t.endswith(":") or t.isupper():
                heading = t.strip(":")
                continue
            title = (heading or "").upper()
            entities.append(Entity(
                entity_id=f"sec:{p_idx}:{len(entities)}",
                page=p_idx,
                text=f"{title}: {t}",
                kind="section",
                section_title=title,
                value_norm_date=norm_date(t),
            ))
            for m in DATE_REGEX.finditer(t):
                entities.append(Entity(
                    entity_id=f"datei:{p_idx}:{len(entities)}",
                    page=p_idx, text=t, kind="date",
                    section_title=title,
                    value_norm_date=norm_date(m.group(0)),
                ))
    return entities

# ----------------------------
# Attach CUIs and embeddings (UPDATED)
# ----------------------------
def attach_cuis_and_embeddings_local(entities: List[Entity], topk: int = 3) -> None:
    """Fill CUI candidates and embedding vectors on each entity, in place.

    For each entity, ``umls_link`` supplies candidate CUIs. The ``topk``
    highest-scoring candidates go into ``cui_candidates`` / ``cui_scores``, and
    any of them found in ``CUI_VECTORS`` add their stored vector to
    ``embeddings``. ``cui`` / ``cui_score`` mirror the best candidate. ``emb``
    is the first vector found, which is not necessarily the best candidate's
    vector when that CUI has no stored vector.

    Entities still without a vector are text-embedded in one ``gemini_embed``
    call. That vector is appended to ``embeddings`` and also stored as ``emb``.
    """
    pending: List[Entity] = []

    for ent in entities:
        candidates = umls_link(ent.text)
        if candidates:
            ranked = sorted(candidates, key=lambda c: c.get("score", 0.0), reverse=True)
            best = ranked[:topk]
            ent.cui_candidates = [c["cui"] for c in best]
            ent.cui_scores = [c.get("score", 0.0) for c in best]
            ent.embeddings.extend(
                CUI_VECTORS[code] for code in ent.cui_candidates if code in CUI_VECTORS
            )
            if ent.cui_candidates:
                ent.cui, ent.cui_score = ent.cui_candidates[0], ent.cui_scores[0]
                if ent.embeddings:
                    ent.emb = ent.embeddings[0]

        # No CUI match, or no stored vector for any match: queue a text embedding.
        if not ent.embeddings:
            pending.append(ent)

    if not pending:
        return
    vectors = gemini_embed([f"[ENTITY:{ent.kind}] {ent.text[:500]}" for ent in pending])
    for ent, vec in zip(pending, vectors):
        ent.embeddings.append(vec)
        ent.emb = vec

# ----------------------------
# Cosine similarity + graph building
# ----------------------------
def _cos(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of ``a`` and ``b``, with 1e-8 added to each norm.

    The epsilon keeps the division finite, so a zero vector scores 0.0.
    """
    denominator = float(np.linalg.norm(a) + 1e-8) * float(np.linalg.norm(b) + 1e-8)
    return float(np.dot(a, b) / denominator)

def build_edges(entities: List[Entity], sim_threshold: float=0.65) -> Dict[str, List[Tuple[str, float]]]:
    """Build an undirected cosine-similarity graph over entities that have ``emb``.

    Every pair whose similarity (same formula as ``_cos``: dot / ((|a|+1e-8)(|b|+1e-8)))
    is >= ``sim_threshold`` is linked in both directions. Each neighbour list is
    ordered by the neighbours' position in ``entities``, as before.

    Each vector is normalised once, and each row is compared with all later rows
    in a single matrix-vector product. The previous version recomputed both norms
    for every pair.

    Args:
        entities: Entities to connect. Those with ``emb is None`` are left out.
        sim_threshold: Minimum similarity for an edge.

    Returns:
        A dict ``entity_id -> [(neighbour_id, similarity), ...]`` with a key for
        every embedded entity (empty list when it has no neighbours).

    Raises:
        ValueError: If the embedded vectors do not all have the same length.
            Mixed-dimension vectors could not be compared before either.
    """
    ids = [e.entity_id for e in entities if e.emb is not None]
    by_id = {e.entity_id: e for e in entities}
    adj: Dict[str, List[Tuple[str, float]]] = {i: [] for i in ids}
    n = len(ids)
    if n < 2:
        return adj

    mat = np.stack([np.asarray(by_id[i].emb, dtype=np.float64) for i in ids])
    unit = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-8)
    for r in range(n - 1):
        # Similarities of row r against every later row (the upper triangle).
        sims = unit[r + 1:] @ unit[r]
        for off in np.flatnonzero(sims >= sim_threshold):
            c = r + 1 + int(off)
            s = float(sims[off])
            adj[ids[r]].append((ids[c], s))
            adj[ids[c]].append((ids[r], s))
    return adj

# ----------------------------
# Query utilities
# ----------------------------
def _embed_query_vector(
    query_text: str,
    *,
    topk: int = 3,
    task_type: str = "RETRIEVAL_QUERY",
) -> np.ndarray:
    """Return the vector used to search entities for ``query_text``.

    The ``topk`` best ``umls_link`` candidates are tried in descending score
    order, and the first one with a stored vector in ``CUI_VECTORS`` is used.
    This matches how entities receive their CUI vector. Previously only the
    single best candidate was tried. If no candidate has a vector, the query text
    is embedded with ``gemini_embed``.

    Args:
        query_text: Free-text query. Only the first 800 characters are embedded.
        topk: How many ranked CUI candidates to try.
        task_type: Task type for the text-embedding fallback. ``RETRIEVAL_QUERY``
            is the documented counterpart of the ``RETRIEVAL_DOCUMENT`` type used
            for entity text.
    """
    candidates = umls_link(query_text) or []
    ranked = sorted(candidates, key=lambda x: x.get("score", 0.0), reverse=True)
    for cand in ranked[:topk]:
        cui = cand.get("cui")
        if cui in CUI_VECTORS:
            return CUI_VECTORS[cui]
    return gemini_embed([f"[QUERY] {query_text[:800]}"], task_type=task_type)[0]

def _is_date_like(q: str) -> Optional[str]:
    """Return the ISO date that ``norm_date`` finds in ``q``, or None if it finds none."""
    iso = norm_date(q)
    return iso

def _nearest_neighbors(vec: np.ndarray, entities: List[Entity], topk: int = 10) -> List[Tuple[Entity, float]]:
    """Rank embedded entities by cosine similarity to ``vec`` and keep the best ``topk``.

    Entities whose ``emb`` is None are ignored. The sort is stable, so ties keep
    their input order.
    """
    scored = []
    for ent in entities:
        if ent.emb is None:
            continue
        scored.append((ent, _cos(vec, ent.emb)))
    scored.sort(key=lambda pair: -pair[1])
    return scored[:topk]

def build_model_local(docai: Dict, min_edge: float = 0.60):
    """Parse, embed and connect the entities of one Document AI result.

    Returns:
        A tuple ``(entities, adjacency, entities_by_id)``. ``adjacency`` comes
        from ``build_edges`` using ``min_edge`` as the similarity threshold.
    """
    entities = parse_docai_entities(docai)
    attach_cuis_and_embeddings_local(entities)
    graph = build_edges(entities, sim_threshold=min_edge)
    return entities, graph, {ent.entity_id: ent for ent in entities}

def query_any_local(docai: Dict, query_text: str,
                    min_direct: float = 0.60, min_indirect: float = 0.45, decay: float = 0.8,
                    topk: int = 8):
    """Answer ``query_text`` against a Document AI result with direct and 2-hop matches.

    The entity graph is rebuilt on every call, with edges at ``min_direct``.

    * direct: the nearest entities to the query vector with score >= ``min_direct``
      (at most ``topk``). If the query contains a date, ``kind == "date"``
      entities with that normalised date are placed first with score 1.0.
    * indirect: entities two edges away from the start nodes. The start nodes are
      the date anchors if there are any, otherwise the first 3 direct hits. Each
      is scored ``s1 * s2 * decay`` and kept when the score is >= ``min_indirect``.
      Entities already returned, and the start node itself, are never reported
      again.

    Returns:
        A dict with keys ``query``, ``query_iso_date``, ``direct`` and
        ``indirect``. Both lists are sorted by descending score.
    """
    ents, adj, by_id = build_model_local(docai, min_edge=min_direct)

    q_iso = _is_date_like(query_text)
    q_vec = _embed_query_vector(query_text)

    nn = _nearest_neighbors(q_vec, ents, topk=topk)

    date_anchors = []
    if q_iso:
        date_anchors = [e for e in ents if e.kind == "date" and e.value_norm_date == q_iso]
        # Put exact date matches ahead of the similarity hits with score 1.0,
        # last anchor first (the order the previous insert(0, ...) loop produced).
        # Then drop repeats so each entity keeps only its first entry.
        pinned = [(e, 1.0) for e in reversed(date_anchors)] + nn
        seen = set()
        nn = []
        for e, s in pinned:
            if e.entity_id not in seen:
                seen.add(e.entity_id)
                nn.append((e, s))

    direct = []
    kept = set()
    for e, score in nn:
        if score < min_direct:
            continue
        direct.append({
            "entity_id": e.entity_id,
            "kind": e.kind,
            "section_title": e.section_title,
            "cui": e.cui,
            "cui_candidates": e.cui_candidates,
            "cui_scores": e.cui_scores,
            "score": round(float(score), 3),
            "evidence": e.text[:240]
        })
        kept.add(e.entity_id)
        if len(direct) >= topk:
            break

    indirect = []
    start_nodes = date_anchors if date_anchors else [by_id[d["entity_id"]] for d in direct[:3]]
    for src in start_nodes:
        for mid, s1 in adj.get(src.entity_id, []):
            if s1 < min_indirect:
                continue
            for tgt, s2 in adj.get(mid, []):
                # Skip already-reported entities and the walk src -> mid -> src.
                # The latter can reach the start node when a date anchor did not
                # make it into `direct` (e.g. more anchors than `topk`).
                if tgt in kept or tgt == src.entity_id:
                    continue
                ind_score = s1 * s2 * decay
                if ind_score < min_indirect:
                    continue
                t_ent = by_id[tgt]
                indirect.append({
                    "target_entity_id": tgt,
                    "kind": t_ent.kind,
                    "section_title": t_ent.section_title,
                    "score": round(float(ind_score), 3),
                    "path": [
                        {"from": src.entity_id, "to": mid, "score": round(float(s1), 3)},
                        {"from": mid, "to": tgt, "score": round(float(s2), 3)}
                    ],
                    "evidence": t_ent.text[:240]
                })
                kept.add(tgt)

    direct.sort(key=lambda x: -x["score"])
    indirect.sort(key=lambda x: -x["score"])
    return {
        "query": query_text,
        "query_iso_date": q_iso,
        "direct": direct,
        "indirect": indirect
    }
