In [1]:
"""
AIDP PySpark â€“ OCI Bucket Document Q&A Pipeline (SAFE VERSION)
--------------------------------------------------------------
Fixed: compute CHUNK_EMB, keep real path, safe prompt length, enable prompt_truncation, minor cleanups.
"""

from typing import List, Tuple
import math
import json
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, lower, regexp_extract, monotonically_increasing_id, explode, udf, expr
)
from pyspark.sql.types import StringType, ArrayType

# -----------------------------
# User Configuration (explicit)
# -----------------------------
BUCKET = "test_doc"
NAMESPACE = "idseylbmv0mm"
PREFIX = "documents/"
BASE_URI = f"oci://{BUCKET}@{NAMESPACE}/{PREFIX}"

GENERATION_MODEL = "default.oci_ai_models.xai.grok-4"  # use your deployed name
EMBEDDING_MODEL = "default.oci_ai_models.xai.grok-4"

# Chunking / retrieval
CHUNK_SIZE = 1200            # slightly smaller to reduce prompt size pressure
CHUNK_OVERLAP = 200
TOP_K = 8
MAX_FILE_CHARS = 200_000

# Prompt safety budget (characters, not tokens, but good enough)
MAX_PROMPT_CHARS = 10_000

QUESTION = "Give all words similar to anaconda"

# -----------------------------
# Spark Session
# -----------------------------
spark = SparkSession.builder.appName("oci-doc-qa").getOrCreate()

# -----------------------------
# Read files from OCI
# -----------------------------
SUPPORTED_TEXT_EXT = {".txt", ".md"}

binary_df = (
    spark.read.format("binaryFile")
    .option("recursiveFileLookup", "true")
    .load(BASE_URI)
    .withColumn("ext", lower(regexp_extract(col("path"), r"(\.[^./\\]+)$", 1)))
    .filter(col("ext").isin(*list(SUPPORTED_TEXT_EXT)))
)

print(f"[INFO] Found {binary_df.count()} supported files")

# -----------------------------
# Parse binary files to text
# -----------------------------
def _safe_truncate(text: str) -> str:
    if not text:
        return ""
    if len(text) > MAX_FILE_CHARS:
        return text[:MAX_FILE_CHARS] + f"\n...[TRUNCATED {len(text)-MAX_FILE_CHARS} chars]"
    return text

def _bytes_to_text(path: str, ext: str, content: bytes) -> str:
    try:
        return _safe_truncate(content.decode("utf-8", errors="ignore"))
    except Exception as e:
        return f"[PARSE_ERROR] {path}: {e}"

@udf(returnType=StringType())
def parse_binary_to_text(path: str, ext: str, content: bytes) -> str:
    try:
        if content is None:
            return ""
        return _bytes_to_text(path, ext, content)
    except Exception as e:
        return f"[PARSE_ERROR] {path}: {e}"

text_df = (
    binary_df
    .withColumn("content", parse_binary_to_text(col("path"), col("ext"), col("content")))
    .select(col("path"), col("content"))
)

text_df.show(truncate=80)

# -----------------------------
# Chunking
# -----------------------------
def split_into_chunks(text: str, chunk_size: int, overlap: int):
    chunks = []
    if not text:
        return chunks
    start, n = 0, len(text)
    step = max(1, chunk_size - overlap)
    while start < n:
        end = min(start + chunk_size, n)
        chunks.append(text[start:end])
        start += step
    return chunks

@udf(returnType=ArrayType(StringType()))
def chunk_udf(content: str):
    if content is None:
        return []
    try:
        return split_into_chunks(str(content), CHUNK_SIZE, CHUNK_OVERLAP)
    except Exception as e:
        return [f"UDF ERROR: {str(e)}"]

chunks_df = (
    text_df
    .withColumn("chunks", chunk_udf(col("content")))
    .withColumn("chunk", explode(col("chunks")))
    .withColumn("chunk_id", monotonically_increasing_id())   # unique id
    # IMPORTANT: keep the real path; do NOT overwrite it
)

print("[INFO] Total chunks:", chunks_df.count())

# Pull a manageable slice to avoid huge local memory during testing.
# Remove `.limit(200)` if you want everything.
chunk_rows = chunks_df.select("chunk_id", "path", "chunk").limit(200).collect()
print("[INFO] Collected chunk rows:", len(chunk_rows))

CHUNK_IDS  = [r["chunk_id"] for r in chunk_rows]
CHUNK_PATHS = [r["path"] for r in chunk_rows]
CHUNK_TEXTS = [r["chunk"] for r in chunk_rows]

# -----------------------------
# Embeddings
# -----------------------------
def ensure_vector(vec) -> List[float]:
    """Make sure embeddings are always a list of floats."""
    if not vec:
        return []
    if isinstance(vec, str):
        try:
            vec = json.loads(vec)
        except Exception as e:
            print("[ERROR] Could not parse embedding:", e, str(vec)[:80])
            return []
    # Handle potential dict format {"embedding":[...]}
    if isinstance(vec, dict) and "embedding" in vec:
        vec = vec["embedding"]
    return [float(x) for x in vec]

def embed_texts(texts: List[str]) -> List[List[float]]:
    if not texts:
        return []
    print(f"[DEBUG] Sending {len(texts)} texts to {EMBEDDING_MODEL}")
    df = spark.createDataFrame([(t,) for t in texts], ["text"])
    embed_df = df.select(expr(f"query_model('{EMBEDDING_MODEL}', text) as embedding"))
    result = embed_df.collect()
    vectors = []
    for row in result:
        emb = row["embedding"]
        vectors.append(ensure_vector(emb))
    print(f"[DEBUG] Got {len(vectors)} embeddings")
    return vectors

print("[INFO] Embedding all chunks...")
CHUNK_EMB = embed_texts(CHUNK_TEXTS)
print(f"[INFO] Got {len(CHUNK_EMB)} chunk embeddings")

# -----------------------------
# Retrieval + Generation
# -----------------------------
def cosine_sim(a: List[float], b: List[float]) -> float:
    if not a or not b:
        return 0.0
    m = min(len(a), len(b))
    dot = sum(float(a[i]) * float(b[i]) for i in range(m))
    na = math.sqrt(sum(float(a[i])**2 for i in range(m)))
    nb = math.sqrt(sum(float(b[i])**2 for i in range(m)))
    return dot / (na*nb) if na and nb else 0.0

def retrieve(question: str, k: int = TOP_K) -> List[Tuple[str, str, float]]:
    q_embs = embed_texts([question])
    if not q_embs:
        return []
    q_emb = q_embs[0]
    scored = [
        (path, text, cosine_sim(q_emb, ensure_vector(vec)))
        for path, text, vec in zip(CHUNK_PATHS, CHUNK_TEXTS, CHUNK_EMB)
    ]
    scored.sort(key=lambda x: x[2], reverse=True)
    return scored[:k]

def _build_context_under_budget(scored_ctx: List[Tuple[str, str, float]], max_chars: int) -> str:
    """
    Concatenate [Source i] blocks until we hit max_chars.
    """
    blocks = []
    total = 0
    for i, (path, text, score) in enumerate(scored_ctx, 1):
        block = f"[Source {i}] {path}\n{text}"
        # If single block is huge, hard-trim it.
        if len(block) > max_chars // max(1, len(scored_ctx)):
            block = block[: max_chars // max(1, len(scored_ctx))] + "\n...[TRUNCATED]\n"
        if total + len(block) > max_chars:
            break
        blocks.append(block)
        total += len(block)
    return "\n\n".join(blocks)

def _make_prompt(context_text: str, question: str) -> str:
    return f"""
You are a helpful enterprise assistant. Answer the question strictly using the CONTEXT.
If the answer is not in the context, say you don't know. Provide concise, factual answers.

CONTEXT:
{context_text}

QUESTION:
{question}

Answer with bullet points where useful and include a short citations list as [Source N].
""".strip()

def answer_question(question: str) -> Tuple[str, List[str]]:
    top = retrieve(question, TOP_K)
    if not top:
        return "No relevant content found in bucket.", []

    # Build safe context under character budget
    context_text = _build_context_under_budget(top, MAX_PROMPT_CHARS)
    prompt = _make_prompt(context_text, question)

    # Escape single quotes for Spark SQL string
    escaped_prompt = prompt.replace("'", "''")

    # Use prompt_truncation='AUTO' so the service trims if still too big
    df = spark.range(1).select(
        expr(
            f"""query_model('{GENERATION_MODEL}', '{escaped_prompt}', map('prompt_truncation','AUTO')) as response"""
        )
    )

    resp = df.collect()[0]["response"]
    used_paths = [path for (path, _text, _score) in top]
    return (resp or "(No response text)"), used_paths

# -----------------------------
# Public API
# -----------------------------
def ask(question: str):
    if not CHUNK_TEXTS:
        print("[ERROR] No chunks found!")
        return
    if not CHUNK_EMB:
        print("[ERROR] No embeddings found!")
        return

    ans, sources = answer_question(question)
    print("\n===== ANSWER =====\n")
    print(ans)
    print("\n===== SOURCES =====\n")
    for s in sorted(set(sources)):
        print(s)

if __name__ == "__main__":
    ask(QUESTION)


[INFO] Found 1 supported files


+-----------------------------------------------+--------------------------------------------------------------------------------+
|                                           path|                                                                         content|
+-----------------------------------------------+--------------------------------------------------------------------------------+
|oci://test_doc@idseylbmv0mm/documents/words.txt|2\n1080\n&c\n10-point\n10th\n11-point\n12-point\n16-point\n18-point\n1st\n2,4...|
+-----------------------------------------------+--------------------------------------------------------------------------------+



[INFO] Total chunks: 201


[INFO] Collected chunk rows: 200
[INFO] Embedding all chunks...
[DEBUG] Sending 200 texts to hive.oci_ai_models.cohere.embed-english-v3.0


[DEBUG] Got 200 embeddings
[INFO] Got 200 chunk embeddings
[DEBUG] Sending 1 texts to hive.oci_ai_models.cohere.embed-english-v3.0


[DEBUG] Got 1 embeddings



===== ANSWER =====

- anacahuita
- anacahuite
- anacalypsis
- an

===== SOURCES =====

oci://test_doc@idseylbmv0mm/documents/words.txt
