In [0]:
%run ./00_constants

In [0]:

# The * means: All parameters after * must be passed as keyword arguments.
def embed_texts(texts, *, single: bool = None):
    """
    Generate embeddings for a single string or a list of strings.

    Args:
        texts: str or list[str]
        single: Optional[bool]. If True, returns a single embedding.
                If False, returns a list of embeddings.
                If None, inferred from input type.

    Returns:
        If single=True or input is str -> list[float]
        If single=False or input is list[str] -> list[list[float]]
    """
    # Normalize input
    if isinstance(texts, str):
        inputs = [texts]
        inferred_single = True
    elif isinstance(texts, list):
        inputs = texts
        inferred_single = False
    else:
        raise TypeError("texts must be a string or list of strings")

    if single is None:
        single = inferred_single

    response = aoai.embeddings.create(
        model=EMBEDDING_DEPLOYMENT,
        input=inputs
    )

    embeddings = [d.embedding for d in response.data]

    if single:
        return embeddings[0]
    return embeddings

In [0]:
# cosine similarity function
import numpy as np

def cosine_similarity(a, b):
    """
    Return the cosine similarity between two vectors.

    Args:
        a, b: array-likes of numbers with matching length.

    Returns:
        dot(a, b) / (||a|| * ||b||), or 0.0 when either vector has zero norm.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    # A zero vector has no direction; dividing would yield NaN (plus a
    # RuntimeWarning), and NaN scores make any later sort order unreliable.
    if denom == 0:
        return 0.0
    return np.dot(a, b) / denom

In [0]:
# ---------- Option A: brute-force cosine similarity ----------

def retrieve_top_k_optionA(query_embedding, k=5, limit=2000):
    """
    Brute-force retrieval: score stored embeddings against the query.

    Reads at most `limit` rows from the table named by EMB_TABLE, computes the
    cosine similarity between `query_embedding` and each row's embedding on
    the driver, and returns the `k` highest-scoring rows as dicts
    (chunk metadata, chunk_text and a float "score"), best first.
    """
    wanted_columns = [
        "chunk_id",
        "doc_id",
        "title",
        "url",
        "chunk_index",
        "chunk_text",
        "category",
        "embedding",
    ]

    # Pull the candidate rows to the driver in one go.
    candidates = (
        spark.table(EMB_TABLE)
        .select(*wanted_columns)
        .limit(limit)
        .collect()
    )

    def _to_hit(row):
        # Build the unified result dict for one candidate row.
        similarity = cosine_similarity(query_embedding, row.embedding)
        position = row.chunk_index
        return {
            "chunk_id": row.chunk_id,
            "doc_id": row.doc_id,
            "title": row.title,
            "url": row.url,
            "chunk_index": None if position is None else int(position),
            "chunk_text": row.chunk_text,
            "category": row.category,
            "score": float(similarity),
        }

    ranked = sorted(
        (_to_hit(row) for row in candidates),
        key=lambda hit: hit["score"],
        reverse=True,
    )
    return ranked[:k]

# ---------- Option B: Databricks Vector Search ----------

def retrieve_top_k_optionB(query_embedding, k=5, filters=None):
    """
    Retrieve the top-k chunks through a Databricks Vector Search index.

    Args:
        query_embedding: query vector passed as `query_vector`.
        k: number of results requested (`num_results`).
        filters: optional filters forwarded unchanged to `similarity_search`.

    Returns:
        list[dict] with keys chunk_id, doc_id, title, url, chunk_index,
        chunk_text, category and score. Returns [] when the search yields
        no rows. "score" is None when the row carries no trailing score value.
    """
    from databricks.vector_search.client import VectorSearchClient

    VS_ENDPOINT = "vs_azure_compute"
    VS_INDEX_FULLNAME = f"{CATALOG}.{SCHEMA}.azure_compute_docs_vs_index"

    # Requested columns, in the order they come back inside each result row.
    columns = [
        "chunk_id",
        "doc_id",
        "title",
        "url",
        "chunk_index",
        "chunk_text",
        "category",
    ]

    index = VectorSearchClient().get_index(
        endpoint_name=VS_ENDPOINT,
        index_name=VS_INDEX_FULLNAME,
    )

    vs_result = index.similarity_search(
        query_vector=query_embedding,
        columns=columns,
        num_results=k,
        filters=filters,
    )

    # An empty result set may omit "data_array" entirely; treat it as no rows
    # instead of failing with KeyError.
    result = vs_result.get("result") or {}
    rows = result.get("data_array") or []

    score_pos = len(columns)
    normalized = []
    for r in rows:
        hit = {name: r[i] for i, name in enumerate(columns)}

        # Keep chunk_index an int (or None), matching option A's output; the
        # service may hand numeric values back as floats — TODO confirm.
        if hit["chunk_index"] is not None:
            hit["chunk_index"] = int(hit["chunk_index"])

        # Vector Search is documented to append the similarity score after the
        # requested columns; fall back to None if it is not present.
        raw_score = r[score_pos] if len(r) > score_pos else None
        hit["score"] = float(raw_score) if raw_score is not None else None

        normalized.append(hit)

    return normalized


# ---------- Wrapper: choose A or B, same output ----------
# Both will return this structure:
# {
#   "chunk_id": ...,
#   "doc_id": ...,
#   "title": ...,
#   "url": ...,
#   "chunk_index": ...,
#   "chunk_text": ...,
#   "category": ...,
#   "score": ...
# }

def retrieve_top_k(query_embedding, k=5, option="A", **kwargs) -> list[dict]:
    """
    Single entry point for retrieval.

    `option` is matched case-insensitively after trimming whitespace:
    - "A" dispatches to the brute-force cosine similarity retriever
    - "B" dispatches to the vector search retriever
    Extra keyword arguments are forwarded to the chosen retriever.
    Raises ValueError for any other option.
    """
    choice = option.upper().strip()

    # Reject unknown options before resolving a backend.
    if choice not in ("A", "B"):
        raise ValueError("option must be 'A' or 'B'")

    backend = retrieve_top_k_optionA if choice == "A" else retrieve_top_k_optionB
    return backend(query_embedding, k=k, **kwargs)

In [0]:
def ensure_rag_log_table(spark, table_name: str):
    """
    Create the RAG query log Delta table if it does not exist yet.

    Issues a single CREATE TABLE IF NOT EXISTS statement, so repeated calls
    are harmless, then prints a confirmation line.

    Parameters
    ----------
    spark : SparkSession
        Session used to run the DDL.
    table_name : str
        Fully qualified table name, e.g. "catalog.schema.rag_query_logs"
    """
    create_stmt = f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
      query_id STRING,
      question STRING,
      top_k INT,
      retriever_type STRING,

      retrieved_chunks ARRAY<STRUCT<
        chunk_id: STRING,
        doc_id: STRING,
        title: STRING,
        url: STRING,
        chunk_index: INT,
        category: STRING,
        score: DOUBLE
      >>,

      prompt STRING,
      answer STRING,
      embedding_deployment STRING,
      chat_deployment STRING,
      created_at TIMESTAMP
    )
    USING DELTA
    """

    spark.sql(create_stmt)
    print(f"✅ RAG log table ensured: {table_name}")

def log_rag_event(
    event: dict,
    table_name: str = "databricks_rag_demo.default.rag_query_logs",
):
    """
    Append one RAG query event to the query log Delta table.

    Parameters
    ----------
    event : dict
        Must contain: question, top_k, retriever_type, retrieved_chunks,
        prompt, answer, embedding_deployment, chat_deployment.
        Each dict in retrieved_chunks is reduced to the fields the log schema
        stores (chunk_text and any other extra keys are not logged).
    table_name : str
        Fully qualified target table. Defaults to the table this function
        has always written to.

    Returns
    -------
    str
        The generated query_id (a UUID4 string).
    """
    import uuid
    from datetime import datetime, timezone

    # Imported here so the function does not depend on a later notebook cell
    # having already pulled these names into the global namespace.
    from pyspark.sql.types import (
        ArrayType,
        DoubleType,
        IntegerType,
        StringType,
        StructField,
        StructType,
        TimestampType,
    )

    chunk_schema = StructType([
        StructField("chunk_id", StringType(), True),
        StructField("doc_id", StringType(), True),
        StructField("title", StringType(), True),
        StructField("url", StringType(), True),
        StructField("chunk_index", IntegerType(), True),
        StructField("category", StringType(), True),
        StructField("score", DoubleType(), True),
    ])

    rag_log_schema = StructType([
        StructField("query_id", StringType(), True),
        StructField("question", StringType(), True),
        StructField("top_k", IntegerType(), True),
        StructField("retriever_type", StringType(), True),
        StructField("retrieved_chunks", ArrayType(chunk_schema), True),
        StructField("prompt", StringType(), True),
        StructField("answer", StringType(), True),
        StructField("embedding_deployment", StringType(), True),
        StructField("chat_deployment", StringType(), True),
        StructField("created_at", TimestampType(), True),
    ])

    def _project_chunk(chunk):
        # Keep only the logged fields and coerce the numeric ones so that a
        # float chunk_index or a non-float score cannot fail schema checks.
        if not isinstance(chunk, dict):
            return chunk  # pass Row-like values through untouched
        position = chunk.get("chunk_index")
        score = chunk.get("score")
        return {
            "chunk_id": chunk.get("chunk_id"),
            "doc_id": chunk.get("doc_id"),
            "title": chunk.get("title"),
            "url": chunk.get("url"),
            "chunk_index": None if position is None else int(position),
            "category": chunk.get("category"),
            "score": None if score is None else float(score),
        }

    raw_chunks = event["retrieved_chunks"]
    logged_chunks = (
        None if raw_chunks is None else [_project_chunk(c) for c in raw_chunks]
    )

    query_id = str(uuid.uuid4())

    # Timezone-aware UTC timestamp.
    now = datetime.now(timezone.utc)

    row_data = [{
        "query_id": query_id,
        "question": str(event["question"]),
        "top_k": int(event["top_k"]),
        "retriever_type": event["retriever_type"],
        "retrieved_chunks": logged_chunks,
        "prompt": str(event["prompt"]),
        "answer": str(event["answer"]),
        "embedding_deployment": str(event["embedding_deployment"]),
        "chat_deployment": str(event["chat_deployment"]),
        "created_at": now,
    }]

    df = spark.createDataFrame(row_data, schema=rag_log_schema)

    (
        df.write
        .format("delta")
        .mode("append")
        .saveAsTable(table_name)
    )

    return query_id

In [0]:
# Prompt assembly

def build_prompt(question, contexts):
    """
    Assemble the answer-generation prompt.

    The context strings are joined with a blank line between them and placed
    under "Context:", followed by the question under "Question:" and a
    trailing "Answer:" cue.
    """
    separator = "\n\n"
    context_block = separator.join(contexts)

    prompt = f"""
You are a helpful assistant answering questions about Azure Compute.

Use the following documentation excerpts to answer the question.

Context:
{context_block}

Question:
{question}

Answer:
"""
    return prompt

In [0]:
##### LLM judge

###### Judge prompt builder
import json

def build_judge_prompt(question: str, answer: str, retrieved_chunks: list[dict]) -> str:
    """
    Build the prompt given to the LLM judge.

    Only the first 6 chunks are included, and each chunk's text is cut to its
    first 1200 characters, to keep the prompt compact. Every source is
    rendered as a numbered entry with its URL and excerpt; entries are
    separated by a "---" line.
    """
    def _render(rank, chunk):
        # A missing or None chunk_text becomes an empty excerpt.
        body = (chunk.get("chunk_text") or "")[:1200]
        link = chunk.get("url")
        return f"[{rank}] URL: {link}\nEXCERPT:\n{body}"

    sources_block = "\n\n---\n\n".join(
        _render(rank, chunk)
        for rank, chunk in enumerate(retrieved_chunks[:6], start=1)
    )

    return f"""
You are evaluating a Retrieval-Augmented Generation (RAG) system for Azure Compute documentation.

Score the system on a 1 - 5 scale (integers only):
- retrieval_relevance: are the retrieved excerpts relevant to the question?
- answer_relevance: does the answer address the question?
- faithfulness: is the answer supported by the provided excerpts (no hallucination)?

Return ONLY valid JSON with keys:
retrieval_relevance, answer_relevance, faithfulness, notes

Question:
{question}

Answer:
{answer}

Retrieved sources:
{sources_block}
"""

###### Judge function (Azure OpenAI chat)

def _parse_judge_json(text: str) -> dict:
    """Parse the judge's reply into a dict, tolerating text around the JSON object."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Fallback: the model may have wrapped the JSON in prose or a code
        # fence, so try the outermost {...} span.
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            raise  # nothing that looks like an object; surface the parse error
        data = json.loads(text[start:end + 1])

    if not isinstance(data, dict):
        raise ValueError("judge response is not a JSON object")
    return data


def _clamp_score(value, default: int = 3) -> int:
    """Coerce a judge score to an int in 1..5, using `default` when it is unusable."""
    try:
        score = int(float(value))  # accepts 4, 4.0, "4" and "4.0"
    except (TypeError, ValueError, OverflowError):
        return default
    return max(1, min(5, score))


def judge_rag(question: str, answer: str, retrieved_chunks: list[dict]) -> dict:
    """
    Ask the chat deployment to score one RAG answer.

    Sends the prompt from build_judge_prompt at temperature 0.0 and parses the
    reply as JSON.

    Returns:
        dict with integer scores clamped to 1..5 under retrieval_relevance,
        answer_relevance and faithfulness (missing, null or non-numeric
        values fall back to 3), plus "notes" trimmed to at most 2000 chars.

    Raises:
        json.JSONDecodeError: if no JSON can be parsed from the reply.
        ValueError: if the parsed JSON is not an object.
    """
    prompt = build_judge_prompt(question, answer, retrieved_chunks)

    resp = aoai.chat.completions.create(
        model=CHAT_DEPLOYMENT,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )

    # content can be None; treat that as an empty reply rather than crashing
    # on .strip().
    text = (resp.choices[0].message.content or "").strip()
    data = _parse_judge_json(text)

    return {
        "retrieval_relevance": _clamp_score(data.get("retrieval_relevance", 3)),
        "answer_relevance": _clamp_score(data.get("answer_relevance", 3)),
        "faithfulness": _clamp_score(data.get("faithfulness", 3)),
        "notes": str(data.get("notes", "")).strip()[:2000],
    }


In [0]:

import uuid
from datetime import datetime, UTC
from pyspark.sql import Row
from pyspark.sql.types import *

def ensure_rag_eval_table(spark, table_name: str):
    """
    Create the RAG evaluation Delta table if it does not exist yet.

    Issues a single CREATE TABLE IF NOT EXISTS statement, so repeated calls
    are harmless, then prints a confirmation line.

    Parameters
    ----------
    spark : SparkSession
        Session used to run the DDL.
    table_name : str
        Fully qualified table name, e.g. "catalog.schema.rag_evaluations"
    """
    create_stmt = f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
      evaluation_id STRING,
      query_id STRING,
      question STRING,
      answer STRING,
      retrieval_relevance INT,
      answer_relevance INT,
      faithfulness INT,
      evaluator STRING,
      notes STRING,
      created_at TIMESTAMP
    )
    USING DELTA
    """

    spark.sql(create_stmt)
    print(f"✅ RAG evaluation table ensured: {table_name}")

### Write evaluation results to Delta
def write_evaluation(query_id: str, question: str, answer: str, scores: dict, evaluator="llm_judge_v1"):
    """
    Append one evaluation record to the table named by RAG_EVAL_TABLE.

    `scores` must provide retrieval_relevance, answer_relevance and
    faithfulness (each converted with int()); "notes" is optional.
    Returns the generated evaluation_id (a UUID4 string).
    """
    # Column name -> Spark type, in table order; every column is nullable.
    column_types = [
        ("evaluation_id", StringType),
        ("query_id", StringType),
        ("question", StringType),
        ("answer", StringType),
        ("retrieval_relevance", IntegerType),
        ("answer_relevance", IntegerType),
        ("faithfulness", IntegerType),
        ("evaluator", StringType),
        ("notes", StringType),
        ("created_at", TimestampType),
    ]
    rag_eval_schema = StructType(
        [StructField(col_name, col_type(), True) for col_name, col_type in column_types]
    )

    evaluation_id = str(uuid.uuid4())
    record = Row(
        evaluation_id=evaluation_id,
        query_id=query_id,
        question=str(question),
        answer=str(answer),
        # int() guards against scores arriving as strings or floats
        retrieval_relevance=int(scores["retrieval_relevance"]),
        answer_relevance=int(scores["answer_relevance"]),
        faithfulness=int(scores["faithfulness"]),
        evaluator=str(evaluator),
        notes=str(scores.get("notes", "")),
        created_at=datetime.now(UTC),
    )

    eval_df = spark.createDataFrame([record], schema=rag_eval_schema)
    eval_df.write.format("delta").mode("append").saveAsTable(RAG_EVAL_TABLE)

    return evaluation_id

In [0]:
from pyspark.sql import functions as F

# Enrich chunks with chunk_text
def enrich_chunks_with_text(retrieved_chunks_py: list[dict], chunks_df) -> list[dict]:
    """
    Fill in "chunk_text" on each chunk dict by looking it up in chunks_df.

    Parameters
    ----------
    retrieved_chunks_py : list[dict]
        Chunk dicts; those with a truthy "chunk_id" are looked up.
    chunks_df : Spark DataFrame
        Must contain columns: chunk_id, chunk_text

    Returns
    -------
    list[dict]
        The same list object, modified in place. Chunks whose id is not
        found get an empty string. If no chunk has a usable id, the list is
        returned untouched and chunks_df is not queried.
    """
    wanted_ids = [
        chunk["chunk_id"] for chunk in retrieved_chunks_py if chunk.get("chunk_id")
    ]
    if len(wanted_ids) == 0:
        return retrieved_chunks_py

    id_filter = F.col("chunk_id").isin(wanted_ids)
    matches = chunks_df.where(id_filter).select("chunk_id", "chunk_text").collect()

    text_by_id = {}
    for match in matches:
        text_by_id[match["chunk_id"]] = match["chunk_text"]

    for chunk in retrieved_chunks_py:
        chunk["chunk_text"] = text_by_id.get(chunk.get("chunk_id"), "")

    return retrieved_chunks_py

# Convert a logged row into Python dict chunks. The log table stores retrieved_chunks as Spark structs; this helper converts them to plain Python dicts.
def normalize_logged_chunks(log_row) -> list[dict]:
    """
    Convert a log row's retrieved_chunks struct array into a list of dicts.

    Each output dict carries chunk_id, doc_id, title, url, chunk_index,
    category and score copied from the logged struct, plus an empty
    "chunk_text" placeholder to be filled in later.

    Returns [] when retrieved_chunks is None or empty (the column is
    nullable, so a row may carry no array at all).
    """
    logged_fields = (
        "chunk_id",
        "doc_id",
        "title",
        "url",
        "chunk_index",
        "category",
        "score",
    )

    out = []
    for c in log_row["retrieved_chunks"] or []:
        chunk = {name: c[name] for name in logged_fields}
        chunk["chunk_text"] = ""  # will be enriched later
        out.append(chunk)
    return out

# Run LLM evaluation on recent logs
def evaluate_recent_logs(
    spark,
    rag_log_table: str,
    chunks_table: str,
    n: int,
    judge_fn,
    write_eval_fn,
    evaluator_name="llm_judge_with_text"
):
    """
    Judge the N most recently logged RAG queries and store the scores.

    For each of the newest `n` rows (by created_at) in the log table, the
    logged chunks are converted to dicts, their text is looked up in the
    chunks table, the judge is called, and the result is written out. One
    progress line is printed per evaluated query.

    Parameters
    ----------
    spark : SparkSession
    rag_log_table : str
        Fully qualified table name
    chunks_table : str
        Fully qualified table name; needs chunk_id and chunk_text columns
    n : int
        Number of recent logs to evaluate
    judge_fn : function
        judge_rag(question, answer, retrieved_chunks) -> scores
    write_eval_fn : function
        write_evaluation(query_id, question, answer, scores, evaluator=...)
    evaluator_name : str
        Passed to write_eval_fn as `evaluator`
    """
    from pyspark.sql import functions as F

    log_source = spark.table(rag_log_table)
    chunk_lookup_df = spark.table(chunks_table).select("chunk_id", "chunk_text")

    newest_first = log_source.orderBy(F.col("created_at").desc())

    for log_row in newest_first.limit(n).collect():
        query_id, question, answer = (
            log_row["query_id"],
            log_row["question"],
            log_row["answer"],
        )

        # Rebuild the chunk dicts and re-attach the text the log does not store.
        chunk_dicts = enrich_chunks_with_text(
            normalize_logged_chunks(log_row),
            chunk_lookup_df,
        )

        scores = judge_fn(question, answer, chunk_dicts)
        write_eval_fn(
            query_id=query_id,
            question=question,
            answer=answer,
            scores=scores,
            evaluator=evaluator_name,
        )

        print("✅ Evaluated:", query_id, scores)

In [0]:
# The unified pipeline function: ask()
# embed → retrieve → prompt → Azure OpenAI → log → (optional) evaluate → return JSON.

def ask(
    question: str,
    *,
    k: int = 6,
    retriever: str = "A",     # "A" brute-force, "B" vector search
    do_eval: bool = True,     # run LLM judge + write to rag_evaluations
    filters: dict = None,     # optional metadata filters for Vector Search
    temperature: float = 0.2,
    verbose: bool = True
) -> dict:
    """
    Run one question through the whole RAG pipeline.

    Steps, in order: embed the question, retrieve k chunks, build the prompt,
    call the chat deployment, log the event, and (when do_eval is True) judge
    the answer and store the evaluation. `filters` is forwarded only when the
    vector search retriever ("B") is selected.

    Returns:
      {
        query_id: str,
        question: str,
        answer: str,
        sources: [...],
        eval: {...} | None
      }

    Raises:
      ValueError: if question is not a non-empty string, or retriever is
                  not "A" or "B" (case-insensitive).
    """
    def _say(*parts):
        # Progress output, emitted only in verbose mode.
        if verbose:
            print(*parts)

    if not (isinstance(question, str) and question.strip()):
        raise ValueError("question must be a non-empty string")

    backend = retriever.upper().strip()
    if backend not in {"A", "B"}:
        raise ValueError("retriever must be 'A' or 'B'")

    _say(f"\n🔹 Question: {question}")
    _say(f"🔹 Retriever: {backend}")

    # 1) Embed
    query_vector = embed_texts(question)

    # 2) Retrieve (only vector search accepts metadata filters)
    extra_args = {"filters": filters} if backend == "B" else {}
    chunks = retrieve_top_k(query_vector, option=backend, k=k, **extra_args)
    _validate_chunks(chunks)

    # 3) Prompt
    prompt = build_prompt(question, [chunk["chunk_text"] for chunk in chunks])

    # 4) Call LLM
    completion = aoai.chat.completions.create(
        model=CHAT_DEPLOYMENT,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    answer = completion.choices[0].message.content

    _say("\nAnswer:")
    _say(answer)

    # 5) Log
    query_id = log_rag_event({
        "question": question,
        "top_k": k,
        "retriever_type": backend,
        "retrieved_chunks": chunks,
        "prompt": prompt,
        "answer": answer,
        "embedding_deployment": EMBEDDING_DEPLOYMENT,
        "chat_deployment": CHAT_DEPLOYMENT,
    })

    _say("Logged query_id:", query_id)

    # 6) Optional evaluation
    eval_result = None
    if do_eval:
        eval_scores = judge_rag(question, answer, chunks)
        eval_id = write_evaluation(
            query_id=query_id,
            question=question,
            answer=answer,
            scores=eval_scores,
            evaluator="llm_judge_with_text",
        )
        eval_result = {"evaluation_id": eval_id, **eval_scores}
        _say("Eval:", eval_result)

    return {
        "query_id": query_id,
        "question": question,
        "answer": answer,
        "sources": _compact_sources(chunks),
        "eval": eval_result,
    }

def batch_ask(
    questions: list[str],
    *,
    k: int = 6,
    retrievers=("A", "B"),
    do_eval: bool = True,
    filters: dict = None,
    temperature: float = 0.2,
    verbose: bool = True
):
    """
    Run multiple questions across multiple retrievers.

    Every question is passed to ask() once per entry in `retrievers`, with
    the remaining keyword arguments forwarded unchanged. A banner is printed
    per question and per retriever.

    Parameters
    ----------
    questions : list[str]
    retrievers : iterable of str
        Retriever codes to try for each question. The default is an immutable
        tuple so the default value can never be altered between calls.

    Returns:
      list of ask() results, ordered by question and then by retriever
    """
    results = []

    for q in questions:
        print("\n" + "=" * 80)
        print("Question:", q)
        print("=" * 80)

        for r in retrievers:
            print(f"\n--- Retriever {r} ---")

            res = ask(
                q,
                k=k,
                retriever=r,
                do_eval=do_eval,
                filters=filters,
                temperature=temperature,
                verbose=verbose,
            )
            results.append(res)

    return results

def _validate_chunks(chunks: list[dict]):
    if not isinstance(chunks, list):
        raise TypeError("retrieve_top_k must return a list[dict]")

    if len(chunks) == 0:
        raise ValueError("No chunks were retrieved")

    required_keys = {
        "chunk_id",
        "doc_id",
        "title",
        "url",
        "chunk_index",
        "category",
        "score",
        "chunk_text",
    }

    for i, c in enumerate(chunks):
        if not isinstance(c, dict):
            raise TypeError(f"Chunk at index {i} is not a dict")

        missing = required_keys - set(c.keys())
        if missing:
            raise ValueError(f"Chunk at index {i} is missing keys: {missing}")

def _compact_sources(chunks: list[dict], max_chars: int = 200):
    """
    Compact retrieved chunks for UI / API output.
    """
    sources = []

    for c in chunks:
        text = c.get("chunk_text", "") or ""
        preview = text[:max_chars] + ("…" if len(text) > max_chars else "")

        sources.append({
            "chunk_id": c.get("chunk_id"),
            "doc_id": c.get("doc_id"),
            "title": c.get("title"),
            "url": c.get("url"),
            "chunk_index": c.get("chunk_index"),
            "category": c.get("category"),
            "score": c.get("score"),
            "preview": preview,
        })

    return sources