In [None]:
%pip install chromadb sentence-transformers google-generativeai

In [None]:
# SECURITY: never commit API keys to a notebook. The Gemini key is read from the
# GOOGLE_API_KEY environment variable (the name the config cell's error message
# asks for). A key that was previously hardcoded here must be treated as leaked
# and revoked/rotated in the Google Cloud console.
import os

# Empty string when the variable is unset; the config cell raises in that case.
API_KEY = os.environ.get("GOOGLE_API_KEY", "").strip()

In [None]:
import os
import glob
import json
import hashlib
from typing import List, Dict, Any, Optional
import datetime as dt
import chromadb
from chromadb.utils import embedding_functions
from sentence_transformers import CrossEncoder
import google.generativeai as genai

In [None]:

# CONFIG
# Paths, collection name and model identifiers used by the rest of the notebook.


DATA_DIR = "data"                     # .txt policy files
CHROMA_PATH = "chroma_store"          # local Chroma directory

# Cache files
METADATA_CACHE_PATH = "metadata_cache.json"   # per-document metadata extracted by the LLM
ANSWER_CACHE_PATH   = "answer_cache.json"     # final answers keyed by md5(query)

COLLECTION_NAME = "nebula_policies"   # name of the Chroma collection

EMBED_MODEL   = "all-MiniLM-L6-v2"                        # sentence-transformers embedding model
RERANK_MODEL  = "cross-encoder/ms-marco-MiniLM-L-6-v2"    # CrossEncoder used for reranking
GEMINI_MODEL  = "gemini-2.5-flash"                        # model name passed to genai.GenerativeModel



# Fail fast when no API key was provided by the earlier cell.
if not API_KEY:
    raise RuntimeError(" Please set GOOGLE_API_KEY before running this script.")

# Configure the google.generativeai client with the key (module-level side effect).
genai.configure(api_key=API_KEY)



#  SMALL UTILS: LOAD/SAVE JSON CACHES


def load_json(path: str) -> dict:
    """
    Load a JSON cache file and return its contents as a dict.

    Best-effort by design: a missing, unreadable, or corrupt file yields an
    empty dict so the caller starts with a fresh cache. A file whose top-level
    JSON value is not an object (e.g. a list) also yields ``{}``, because the
    callers use the result as a dict.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded dict, or ``{}`` when nothing usable could be loaded.
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return data if isinstance(data, dict) else {}

def save_json(path: str, data: dict):
    """
    Persist ``data`` as pretty-printed JSON at ``path`` without risking the old file.

    The payload is serialised before anything is opened, so a non-serialisable
    value raises while the existing cache is still intact. The text is then
    written to a sibling ``<path>.tmp`` file and moved into place with
    ``os.replace``, so an interrupted write never leaves a truncated cache.

    Args:
        path: Destination file.
        data: JSON-serialisable mapping to store.

    Raises:
        TypeError / ValueError: if ``data`` cannot be serialised.
        OSError: if the file cannot be written.
    """
    payload = json.dumps(data, indent=2)
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a stray temp file behind on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# In-memory caches (backed by JSON files); both are loaded once, when this cell runs.
# metadata_cache maps a document filename to the metadata dict extracted for it.
metadata_cache = load_json(METADATA_CACHE_PATH)  # key: filename
# answer_cache maps the md5 hex digest of a query string to the final answer text.
answer_cache   = load_json(ANSWER_CACHE_PATH)    # key: md5(query)



#  BASIC HELPERS


def clean_metadata(meta: dict) -> dict:
    """Return a copy of ``meta`` in which every ``None`` value is replaced by ``""``."""
    cleaned = {}
    for key, value in meta.items():
        cleaned[key] = value if value is not None else ""
    return cleaned


def parse_date(d: str) -> Optional[dt.date]:
    """
    Parse a date from loosely formatted input.

    Accepted inputs:
      * strings in ``YYYY-MM-DD``, ``DD-MM-YYYY``, ``YYYY/MM/DD`` or ``DD/MM/YYYY``,
      * any other string understood by ``date.fromisoformat`` or
        ``datetime.fromisoformat`` (e.g. ``2024-01-15T10:30:00``),
      * ``date`` / ``datetime`` objects (returned as a ``date``).

    Args:
        d: The value to parse; expected to be a string, but other types are tolerated.

    Returns:
        The parsed ``date``, or ``None`` for empty, non-string, or unparseable input.
    """
    # datetime is a subclass of date, so it must be checked first.
    if isinstance(d, dt.datetime):
        return d.date()
    if isinstance(d, dt.date):
        return d
    if not isinstance(d, str):
        # e.g. None, or a number that leaked in from LLM-produced metadata
        return None

    d = d.strip()
    if not d:
        return None

    for fmt in ("%Y-%m-%d", "%d-%m-%Y", "%Y/%m/%d", "%d/%m/%Y"):
        try:
            return dt.datetime.strptime(d, fmt).date()
        except ValueError:
            continue

    # Fall back to the ISO parsers (date first, then full timestamps).
    for iso_parser in (dt.date.fromisoformat, dt.datetime.fromisoformat):
        try:
            parsed = iso_parser(d)
        except ValueError:
            continue
        return parsed.date() if isinstance(parsed, dt.datetime) else parsed

    return None


def parse_version(v: str) -> float:
    """
    Convert a version label such as ``"v2"`` or ``"V2.1"`` into a float.

    Rules:
      * a leading ``v``/``V`` and surrounding whitespace are ignored,
      * labels with more than two dotted parts keep only ``major.minor``
        (``"v2.1.3"`` -> ``2.1``),
      * plain ``int``/``float`` input is accepted as-is,
      * anything empty, unparseable, or non-finite (``"inf"``, ``"nan"``)
        yields ``0.0`` so it can never dominate a score.

    Args:
        v: Version label (normally a string).

    Returns:
        The numeric version, or ``0.0`` when no usable version is present.
    """
    if isinstance(v, bool):
        # bool is an int subclass; True/False is not a version number.
        return 0.0

    if isinstance(v, (int, float)):
        value = float(v)
    elif isinstance(v, str):
        label = v.lower().strip()
        if label.startswith("v"):
            label = label[1:]
        parts = label.split(".")
        if len(parts) > 2:
            label = ".".join(parts[:2])
        try:
            value = float(label)
        except ValueError:
            return 0.0
    else:
        return 0.0

    # NaN fails both comparisons; +/-inf fails one of them.
    if not (float("-inf") < value < float("inf")):
        return 0.0
    return value



#  METADATA EXTRACTION (LLM) WITH CACHE


def _parse_metadata_json(raw: str) -> Optional[Dict[str, Any]]:
    """Return the JSON object embedded in an LLM reply, or None if there is none."""
    text = raw if isinstance(raw, str) else ""
    # Models often wrap JSON in ```json fences or add a preamble; take the
    # outermost {...} span instead of requiring the whole reply to be JSON.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        parsed = json.loads(text[start:end + 1])
    except ValueError:
        return None
    return parsed if isinstance(parsed, dict) else None


def extract_metadata_llm(text: str, filename: str) -> Dict[str, Any]:
    """
    Extract policy metadata for one document with the LLM, using metadata_cache.

    A cached entry for ``filename`` is returned immediately. Otherwise the
    model is asked for a JSON object, which is parsed tolerantly (code fences
    and surrounding prose are ignored). Every field is normalised to a
    stripped string, with a default when it is missing, null, or empty.

    Only successfully parsed replies are cached: when the reply contains no
    JSON object the defaults are returned but NOT stored, so a later run can
    retry instead of freezing placeholder metadata.

    Args:
        text: Full document text (only the first 4000 characters are sent).
        filename: Document name; also the cache key.

    Returns:
        Dict with keys audience_scope, policy_type, effective_date, version,
        specificity_level and override_notes (all string values).
    """

    if filename in metadata_cache:
        return metadata_cache[filename]

    prompt = f"""
You are an HR policy metadata extractor.

Return ONLY valid JSON (no explanations).

Extract:
- audience_scope: one of
  ["interns","full_time_employees","managers","contractors","company_wide","all_employees","other"]
- policy_type: short snake_case like
  "remote_work", "vacation", "sick_leave", "benefits", "travel", "code_of_conduct", etc.
- effective_date: ISO "YYYY-MM-DD" if explicitly stated, else "".
- version: like "v1", "v2", "v2.1" if present, else "".
- specificity_level: one of ["role_specific","team_specific","company_wide","unclear"].
- override_notes: short string explaining if this document updates/overrides earlier policies (or "" if not clear).

Document filename: {filename}

Document content (trimmed):
\"\"\"{text[:4000]}\"\"\"
"""

    model = genai.GenerativeModel(GEMINI_MODEL)
    resp = model.generate_content(prompt)

    try:
        # Accessing .text is kept inside the guard, as in the original code,
        # so a reply without usable text degrades to the defaults.
        raw_reply = resp.text
    except Exception:
        raw_reply = ""

    parsed = _parse_metadata_json(raw_reply)
    meta = parsed or {}

    def field(key: str, default: str) -> str:
        # Coerce to text so downstream parsing and the vector store only see strings.
        value = meta.get(key)
        if value is None:
            return default
        return str(value).strip() or default

    final = {
        "audience_scope":    field("audience_scope", "other"),
        "policy_type":       field("policy_type", "uncategorized"),
        "effective_date":    field("effective_date", ""),
        "version":           field("version", ""),
        "specificity_level": field("specificity_level", "unclear"),
        "override_notes":    field("override_notes", ""),
    }

    if parsed is not None:
        metadata_cache[filename] = final
        save_json(METADATA_CACHE_PATH, metadata_cache)

    return final



# QUERY CONTEXT (user_role + policy_type)


def _parse_query_context_json(raw: str) -> Dict[str, Any]:
    """Return the JSON object embedded in an LLM reply, or {} if there is none."""
    text = raw if isinstance(raw, str) else ""
    # Tolerate ```json fences and surrounding prose: use the outermost {...} span.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return {}
    try:
        parsed = json.loads(text[start:end + 1])
    except ValueError:
        return {}
    return parsed if isinstance(parsed, dict) else {}


def extract_query_context(query: str) -> Dict[str, Any]:
    """
    Classify a user question into a role and a policy type with the LLM.

    The reply is parsed tolerantly (code fences / extra prose are ignored) and
    normalised: both labels are lower-cased and stripped, and a ``user_role``
    outside the allowed set becomes ``"unknown"``. Any parsing problem falls
    back to the defaults rather than raising.

    Args:
        query: The user's question.

    Returns:
        ``{"user_role": ..., "policy_type": ...}`` where user_role is one of
        intern / full_time_employee / manager / contractor / unknown and
        policy_type defaults to ``"uncategorized"``.
    """
    prompt = f"""
You are a classifier for HR questions.

Return ONLY JSON (no explanation) with:
- user_role: one of ["intern","full_time_employee","manager","contractor","unknown"]
- policy_type: short snake_case label like "remote_work", "vacation", "sick_leave", etc.

User question:
\"\"\"{query}\"\"\"
"""

    model = genai.GenerativeModel(GEMINI_MODEL)
    resp = model.generate_content(prompt)

    try:
        # Kept inside the guard, as in the original code, so a reply without
        # usable text degrades to the defaults.
        raw_reply = resp.text
    except Exception:
        raw_reply = ""

    ctx = _parse_query_context_json(raw_reply)

    allowed_roles = ("intern", "full_time_employee", "manager", "contractor", "unknown")
    user_role = str(ctx.get("user_role") or "").strip().lower()
    if user_role not in allowed_roles:
        user_role = "unknown"

    policy_type = str(ctx.get("policy_type") or "").strip().lower() or "uncategorized"

    return {
        "user_role":  user_role,
        "policy_type": policy_type,
    }



#  INIT & INGEST (NO RE-INGEST IF ALREADY DONE)


def init_chroma():
    """
    Open the persistent Chroma store and return the policy collection.

    The collection is created on first use and is bound to a
    sentence-transformers embedding function built from EMBED_MODEL.
    """
    store = chromadb.PersistentClient(path=CHROMA_PATH)
    embedder = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBED_MODEL)
    collection = store.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedder,
    )
    return collection


def ingest_all_txt(collection):
    """
    Ingest every ``.txt`` file under DATA_DIR into ``collection``.

    Only ingests if collection is empty, which saves LLM + Chroma cost on
    subsequent calls. Consequence: files added to DATA_DIR after the first
    ingestion are not picked up until the collection is cleared.

    Each file becomes one document whose id is the file's base name and whose
    metadata is ``{"source": <id>, **LLM-extracted metadata}`` with ``None``
    values replaced by ``""``.

    Args:
        collection: Chroma collection (needs ``count``, ``get`` and ``add``).
    """
    if collection.count() > 0:
        return

    # Sorted so the ingestion order (and therefore LLM call order) is deterministic.
    paths = sorted(glob.glob(os.path.join(DATA_DIR, "*.txt")))
    if not paths:
        print("❌ No .txt files found.")
        return

    for path in paths:
        doc_id = os.path.basename(path)

        # Not strictly needed if count()==0, but safe if reused in future
        if collection.get(ids=[doc_id])["ids"]:
            continue

        with open(path, "r", encoding="utf-8") as f:
            text = f.read()

        metadata = extract_metadata_llm(text, doc_id)
        safe_meta = clean_metadata({"source": doc_id, **metadata})

        collection.add(
            ids=[doc_id],
            documents=[text],
            metadatas=[safe_meta]
        )

    print(" Ingestion complete with metadata.")


#  RETRIEVAL


def retrieve_candidates(collection, query: str, q_ctx: Dict[str, Any], top_n=50):
    """
    Fetch up to ``top_n`` candidate documents for ``query`` from the collection.

    When the query context carries a concrete ``policy_type`` the search is
    first restricted to documents with that exact metadata value. Because the
    document label and the query label come from separate LLM calls they can
    disagree (e.g. "remote_work" vs "work_from_home"); if the filtered search
    finds nothing, the search is repeated without the filter instead of
    returning an empty result.

    Args:
        collection: Chroma collection exposing ``query``.
        query: User question used as the query text.
        q_ctx: Query context; only ``policy_type`` is read here.
        top_n: Maximum number of results requested.

    Returns:
        List of ``{"id", "text", "meta"}`` dicts (possibly empty).
    """

    def run_query(where):
        # Run one vector search and flatten the first (only) query's results.
        result = collection.query(
            query_texts=[query],
            n_results=top_n,
            where=where
        )
        if not result["ids"]:
            return []
        return [
            {
                "id": doc_id,
                "text": result["documents"][0][i],
                "meta": result["metadatas"][0][i],
            }
            for i, doc_id in enumerate(result["ids"][0])
        ]

    policy_type = q_ctx.get("policy_type")
    if policy_type and policy_type != "uncategorized":
        filtered = run_query({"policy_type": policy_type})
        if filtered:
            return filtered
        # Label mismatch or no such policy type: fall through to an unfiltered search.

    return run_query(None)



#  RERANK WITH CONFLICT LOGIC


reranker = None  # CrossEncoder instance; created lazily on the first rerank call


def _conflict_bonus(meta: Dict[str, Any], user_role: str) -> float:
    """Score bonus from metadata: role match, specificity, recency and version."""
    aud  = str(meta.get("audience_scope") or "").lower()
    spec = str(meta.get("specificity_level") or "").lower()
    eff  = parse_date(meta.get("effective_date", ""))
    ver  = parse_version(meta.get("version", ""))

    bonus = 0.0

    # Role-specific bonus
    if user_role != "unknown":
        if user_role in aud:  # e.g. intern in "interns"
            bonus += 2.0
        elif aud in ("all_employees", "company_wide"):
            bonus += 0.5

    # Specificity bonus
    if spec == "role_specific":
        bonus += 0.5
    elif spec == "company_wide":
        bonus += 0.2

    # Recency bonus: 0.02 per year elapsed since 2000-01-01
    if eff:
        years_since_2000 = (eff - dt.date(2000, 1, 1)).days / 365.0
        bonus += 0.02 * years_since_2000

    # Version bonus: 0.1 per version unit
    if ver > 0:
        bonus += 0.1 * ver

    return bonus


def rerank_with_conflict_logic(query: str,
                               docs: List[Dict],
                               q_ctx: Dict[str, Any],
                               top_k=3) -> List[Dict[str, Any]]:
    """
    Rerank candidates with the cross-encoder plus metadata-based bonuses.

    The final score of a document is the cross-encoder score for
    ``(query, text)`` plus a bonus that favours documents matching the user's
    role, role-specific documents, more recent effective dates and higher
    versions.

    Args:
        query: User question.
        docs: Candidates as ``{"id", "text", "meta"}`` dicts.
        q_ctx: Query context; only ``user_role`` is read here.
        top_k: Number of documents to return.

    Returns:
        The ``top_k`` best documents, each copied with an added ``"score"``
        key, sorted by descending score. An empty ``docs`` list returns ``[]``
        without loading or calling the model.
    """
    if not docs:
        return []

    global reranker
    if reranker is None:
        reranker = CrossEncoder(RERANK_MODEL)

    user_role = q_ctx.get("user_role", "unknown")

    pairs = [[query, d["text"]] for d in docs]
    base_scores = reranker.predict(pairs)

    scored_docs = []
    for d, s in zip(docs, base_scores):
        # A missing or None meta is treated as "no metadata".
        bonus = _conflict_bonus(d.get("meta") or {}, user_role)
        scored_docs.append({**d, "score": float(s) + float(bonus)})

    scored_docs.sort(key=lambda x: x["score"], reverse=True)
    return scored_docs[:top_k]



#   RELEVANT SEGMENT EXTRACTION


# Words too common to say anything about where the relevant passage is.
_SEGMENT_STOPWORDS = frozenset("""
a an the i me my we our you your he she they it its this that these those
is are am was were be been being do does did have has had can could will would
should shall may might must to of in on at for from as by with about into and
or but if so not no yes what when where which who whom how why just also any
""".split())


def extract_relevant_segment(query: str,
                             text: str,
                             max_chars: int = 1200,
                             context_margin: int = 400) -> str:

    """
    Heuristic, no extra LLM calls.
    Reduce tokens passed to Gemini while keeping relevant context.

    How the excerpt is chosen:
      * The query is split into alphanumeric words; stopwords and single
        characters are dropped, so "I" or "a" can no longer anchor the window.
      * Keywords are matched case-insensitively at the START of a word in the
        text ("intern" matches "Interns", "pto" does not match "laptop").
      * Every hit is tried as an anchor; the window covering the most distinct
        keywords wins (earliest window on ties).
      * A text that already fits in ``max_chars`` is returned whole.

    Args:
        query: User question.
        text: Document text (non-strings are converted with ``str``).
        max_chars: Maximum length of the returned excerpt.
        context_margin: Characters of context kept before the anchoring hit
            (capped at half of ``max_chars`` so the hit stays inside the window).

    Returns:
        An excerpt of at most ``max_chars`` characters; the beginning of the
        text when no keyword is found.
    """
    import re  # local import: only this helper needs it

    text_str = text if isinstance(text, str) else str(text)
    if max_chars <= 0:
        return ""
    total = len(text_str)
    if total <= max_chars:
        return text_str

    words = "".join(ch if ch.isalnum() else " " for ch in query.lower()).split()
    keywords = {w for w in words if len(w) > 1 and w not in _SEGMENT_STOPWORDS}
    if not keywords:
        return text_str[:max_chars]

    # Longest alternatives first so a longer keyword wins over its own prefix.
    ordered = sorted(keywords, key=lambda k: (-len(k), k))
    pattern = re.compile(r"\b(" + "|".join(re.escape(k) for k in ordered) + ")",
                         re.IGNORECASE)
    # Matching on the original text (not a lower-cased copy) keeps indices aligned.
    hits = [(m.start(), m.group(1).lower()) for m in pattern.finditer(text_str)]
    if not hits:
        # fallback: just return the beginning
        return text_str[:max_chars]

    margin = max(0, min(context_margin, max_chars // 2))

    best_start, best_score = 0, 0
    in_window = {}   # keyword -> number of hits currently inside the window
    lo = hi = 0      # hits[lo:hi] are the hits inside the current window
    for pos, _ in hits:
        # Pull the window back near the end of the text so the budget is fully used.
        start = max(0, min(pos - margin, total - max_chars))
        end = start + max_chars
        while hi < len(hits) and hits[hi][0] < end:
            kw = hits[hi][1]
            in_window[kw] = in_window.get(kw, 0) + 1
            hi += 1
        while lo < hi and hits[lo][0] < start:
            kw = hits[lo][1]
            in_window[kw] -= 1
            if not in_window[kw]:
                del in_window[kw]
            lo += 1
        if len(in_window) > best_score:
            best_start, best_score = start, len(in_window)

    return text_str[best_start:best_start + max_chars]


# FINAL REASONING OVER MULTIPLE DOCS 


def reason_over_docs(query: str,
                     docs: List[Dict[str, Any]],
                     q_ctx: Dict[str, Any]) -> str:
    """
    Uses relevant segments instead of whole documents to save tokens.
    """

    context_blocks = []
    for i, d in enumerate(docs, 1):
        m = d["meta"]
        segment = extract_relevant_segment(query, d["text"])

        context_blocks.append(
            f"[DOC {i}]\n"
            f"source: {m.get('source')}\n"
            f"audience_scope: {m.get('audience_scope')}\n"
            f"policy_type: {m.get('policy_type')}\n"
            f"effective_date: {m.get('effective_date')}\n"
            f"version: {m.get('version')}\n"
            f"specificity_level: {m.get('specificity_level')}\n"
            f"override_notes: {m.get('override_notes')}\n"
            f"--- RELEVANT EXCERPT ---\n"
            f"{segment}\n"
            f"--- END EXCERPT ---\n"
        )

    context_str = "\n\n".join(context_blocks)
    user_role = q_ctx.get("user_role", "unknown")

    prompt = f"""
You are an HR policy assistant.

User role (from query): {user_role}
User question:
\"\"\"{query}\"\"\"

You are given multiple policy document excerpts (not full text). They may conflict.

Excerpts:
{context_str}

INSTRUCTIONS (VERY IMPORTANT):
1. Use ONLY the provided excerpts. Do not invent rules.
2. If there is a document that is specific to the user's role (e.g. interns),
   that role-specific document OVERRIDES more general documents for that user.
3. If multiple documents apply to the same role, prefer:
   - the one with the most recent effective_date, or
   - the higher version (e.g. v2 > v1), or
   - explicit override_notes that say it updates/replaces earlier policies.
4. Your answer MUST:
   - Directly answer the user's question in 1–3 sentences.
   - Briefly explain why (mentioning the conflict resolution in 1–2 sentences).
   - Explicitly mention which document(s) provided the final ruling, by filename.

FORMAT:
Final Answer: <very short direct answer dont give reasoning>
Sources: <comma-separated list of source filenames>
"""

    model = genai.GenerativeModel(GEMINI_MODEL)
    resp = model.generate_content(prompt)
    return resp.text.strip()



# ANSWER CACHE


def cached_answer(query: str) -> Optional[str]:
    """Look up a stored answer by the MD5 hex digest of the query; None if absent."""
    digest = hashlib.md5(query.encode("utf-8"))
    cache_key = digest.hexdigest()
    return answer_cache.get(cache_key)

def store_answer(query: str, answer: str):
    """Record ``answer`` under the query's MD5 hex digest and write the cache to disk."""
    cache_key = hashlib.md5(query.encode("utf-8")).hexdigest()
    answer_cache.update({cache_key: answer})
    save_json(ANSWER_CACHE_PATH, answer_cache)





In [None]:
def answer(query: str) -> str:
    """
    Answer an HR policy question end to end.

    Pipeline: answer cache -> Chroma init + one-time ingestion -> LLM query
    classification -> vector retrieval -> rerank with conflict logic -> final
    LLM reasoning over the top 3 documents.

    Only real, non-empty answers are written to the persistent answer cache.
    The "No matching documents found" fallback is returned but not cached, so
    an empty retrieval (e.g. before any documents were ingested) is not served
    forever for that query.

    Args:
        query: The user's question.

    Returns:
        The answer text in the "Final Answer: ... / Sources: ..." format.
    """

    # Answer cache (no LLM cost if repeat)
    cached = cached_answer(query)
    if cached:
        return cached

    #  Init + ingest (ingest only first time)
    collection = init_chroma()
    ingest_all_txt(collection)

    #  Query understanding
    q_ctx = extract_query_context(query)

    # Retrieval
    candidates = retrieve_candidates(collection, query, q_ctx)
    if not candidates:
        # Deliberately not cached: the result may change once documents exist.
        return "Final Answer: No matching documents found.\nSources: none"

    # Rerank with conflict logic
    ranked = rerank_with_conflict_logic(query, candidates, q_ctx, top_k=3)

    # Final reasoning over multiple docs (with trimmed segments)
    final = reason_over_docs(query, ranked, q_ctx)

    # Store in cache (an empty reply would never be served from it anyway)
    if final:
        store_answer(query, final)
    return final

In [None]:
# Smoke test: run the full pipeline on a sample intern question and print the answer.
print(answer("I just joined as a new intern. Can I work from home?"))