In [1]:
import json
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# -------------------------------
# 1. Load FAISS + Metadata
# -------------------------------
index = faiss.read_index("../data/rag_chunks/faiss_index.idx")

with open("../data/rag_chunks/metadata.json", "r", encoding="utf-8") as f:
    chunks = json.load(f)

# The retrieval helpers below map FAISS row ids straight into `chunks`
# (chunks[idx]["text"]), so the index and the metadata must describe the same
# rows. Fail fast here instead of returning mismatched text at query time.
assert index.ntotal == len(chunks), (
    f"FAISS index holds {index.ntotal} vectors but metadata has {len(chunks)} chunks"
)

print(f"📂 Loaded {len(chunks)} metadata chunks")

📂 Loaded 837 metadata chunks


In [3]:
# -------------------------------
# 2. Embedding Model (same as before)
# -------------------------------
# Query vectors are searched against the FAISS index loaded above, so this is
# presumably the same model that produced the indexed embeddings — confirm if
# the index is ever rebuilt with a different encoder.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)

In [8]:
# -------------------------------
# 4. Retrieval Function
# -------------------------------
def search(query, top_k=5):
    """Return the text of the ``top_k`` chunks nearest to ``query``.

    The query is embedded with the global ``embedder`` and looked up in the
    global FAISS ``index``. Each returned row id is mapped back into
    ``chunks``. Ids equal to -1 (what FAISS reports for missing results) are
    skipped, so fewer than ``top_k`` texts may come back.
    """
    query_matrix = embedder.encode([query]).astype("float32")
    _distances, neighbor_ids = index.search(query_matrix, top_k)

    texts = []
    for chunk_id in neighbor_ids[0]:
        if chunk_id == -1:
            continue
        texts.append(chunks[chunk_id]["text"])
    return texts

In [31]:
import re
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# ---- Model: Flan-T5 (seq2seq) wrapped in a text2text-generation pipeline, CPU only
model_name = "google/flan-t5-base"   # try flan-t5-large if CPU allows
tokenizer = AutoTokenizer.from_pretrained(model_name)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="cpu",
)
llm = pipeline(
    task="text2text-generation",
    model=llm_model,
    tokenizer=tokenizer,
)

# Tokens ignored by the keyword-overlap re-ranker: common English function
# words plus fiscal-period vocabulary (fy, q1–q4, year, quarter).
# NOTE(review): because q1–q4 are listed here, the quarter named in a question
# does not influence re-ranking — confirm that is intended.
STOPWORDS = set(
    "a an and are as at be by for from has have in is it of on or that the to was were will with "
    "fy q1 q2 q3 q4 year quarter".split()
)

def count_tokens(txt: str) -> int:
    """Length of ``txt`` in tokens of the global Flan-T5 ``tokenizer``.

    Special tokens the tokenizer adds are included in the count.
    """
    token_ids = tokenizer.encode(txt, add_special_tokens=True)
    return len(token_ids)

def simple_overlap_score(q: str, p: str, stopwords=None) -> int:
    """Naive keyword-overlap score used to push relevant passages up.

    Counts how many non-stopword tokens of passage ``p`` also occur in
    question ``q``. A passage token is counted every time it appears, so
    repeating a query term raises the score.

    Tokens are lowercase runs of letters, digits, ``,``, ``.`` and ``-``, so
    figures such as ``51,488`` stay intact. Leading/trailing ``,.-`` are
    stripped so sentence punctuation (``FY24.``, ``revenue,``) neither blocks
    a match nor lets a stopword slip through the filter.

    Args:
        q: The user question.
        p: A candidate passage.
        stopwords: Tokens to ignore; defaults to the notebook-level ``STOPWORDS``.

    Returns:
        The overlap count (0 when nothing matches).
    """
    if stopwords is None:
        stopwords = STOPWORDS

    def tokens(text: str) -> list[str]:
        raw = re.findall(r"[A-Za-z0-9,.-]+", text.lower())
        stripped = (w.strip(",.-") for w in raw)
        # Tokens made only of punctuation (e.g. "--") strip to "" and are dropped.
        return [w for w in stripped if w and w not in stopwords]

    query_terms = set(tokens(q))
    return sum(1 for w in tokens(p) if w in query_terms)

def search_with_text(question: str, top_k: int = 12):
    """Retrieve the ``top_k`` nearest chunks for ``question`` with their FAISS scores.

    Uses the global ``embedder``, FAISS ``index`` and ``chunks`` metadata
    loaded earlier in the notebook.

    Returns:
        list[tuple[str, float]]: ``(chunk_text, score)`` pairs in the order
        FAISS returned them. ``score`` is the raw FAISS value. It is a
        distance for L2 indexes and a similarity for inner-product ones; which
        one applies depends on how the index was built. Row ids of -1 (FAISS's
        marker for missing results) are skipped, so the list may be shorter
        than ``top_k``.
    """
    query_matrix = embedder.encode([question]).astype("float32")
    distances, ids = index.search(query_matrix, top_k)
    return [
        (chunks[idx]["text"], float(dist))
        for idx, dist in zip(ids[0], distances[0])
        if idx != -1
    ]

def build_prompt(question: str, passages: list[str], token_budget: int, token_counter=None):
    """Assemble the Flan-T5 prompt, packing as many passages as fit in ``token_budget``.

    Layout: system instructions, ``"Context:"``, one ``"- "`` bullet per
    passage, then the question. Passages are tried in the given order. Each is
    kept only if the *fully rendered* prompt, bullets included, still fits.
    A passage that doesn't fit is skipped and the next one is tried.

    If no passage fits, the first one is halved repeatedly (never below 32
    characters) until it fits. If even that is too long, the context section
    is left empty. In that case the returned prompt may still exceed
    ``token_budget`` when the instructions and question alone do.

    Args:
        question: The user question, appended after the context.
        passages: Candidate context passages, highest priority first.
        token_budget: Maximum prompt length in tokens.
        token_counter: Callable returning the token count of a string;
            defaults to ``count_tokens`` (the notebook's Flan-T5 tokenizer).

    Returns:
        The prompt string.
    """
    count = count_tokens if token_counter is None else token_counter

    system = (
        "You are a financial analyst assistant.\n"
        "Use ONLY the provided context to answer. If the answer is not present, say you don't know.\n"
        "Cite the period identifiers (e.g., Q4 FY24, FY23) when relevant.\n\n"
    )
    prefix = system + "Context:\n"
    suffix = f"\n\nQuestion: {question}\nAnswer in 1–2 complete sentences:\n"

    def render(context: list[str]) -> str:
        # Single source of truth for the layout, so the budget check measures
        # exactly the text that is returned (including the "- " bullets).
        return prefix + "\n".join(f"- {s}" for s in context) + suffix

    def fits(context: list[str]) -> bool:
        return count(render(context)) <= token_budget

    selected: list[str] = []
    for p in passages:
        if fits(selected + [p]):
            selected.append(p)

    # If nothing fit (very long first passage), hard-truncate the first one.
    if not selected and passages:
        p = passages[0]
        while p and not fits([p]):
            new_len = max(32, len(p) // 2)
            if new_len >= len(p):
                # Can't shrink any further: even a 32-character slice is over
                # budget (e.g. a long question). Give up instead of looping forever.
                p = ""
                break
            p = p[:new_len]
        if p:
            selected = [p]

    return render(selected)

def rag_query(question: str, top_k_retrieval: int = 12, max_ctx_passages: int = 6,
              input_token_budget: int = 480, max_new_tokens: int = 96,
              min_new_tokens: int = 24):
    """Answer ``question`` with retrieval-augmented Flan-T5 generation.

    Pipeline:
        1. retrieve ``top_k_retrieval`` chunks via ``search_with_text``;
        2. re-rank them by ``simple_overlap_score`` and keep the best
           ``max_ctx_passages``;
        3. pack them into a prompt of at most ``input_token_budget`` tokens;
        4. generate with beam search (deterministic, ``do_sample=False``).

    Args:
        question: The user question.
        top_k_retrieval: Number of FAISS neighbours to fetch.
        max_ctx_passages: Maximum number of re-ranked passages offered to the prompt.
        input_token_budget: Prompt token budget (kept below Flan-T5's 512-token inputs).
        max_new_tokens: Upper bound on generated tokens.
        min_new_tokens: Lower bound on generated tokens, used to avoid one-word
            answers. Clamped to ``max_new_tokens``.

    Returns:
        tuple[str, list[str], int]: ``(answer, passages, prompt_tokens)``.
        ``passages`` are the re-ranked candidates handed to ``build_prompt``.
        Some may have been dropped there to respect the budget.
    """
    # 1) Retrieve
    hits = search_with_text(question, top_k=top_k_retrieval)
    passages = [text for text, _score in hits]

    # 2) Re-rank by keyword overlap. sorted() is stable (also with reverse=True),
    #    so FAISS order breaks ties.
    passages = sorted(passages, key=lambda p: simple_overlap_score(question, p), reverse=True)
    passages = passages[:max_ctx_passages]

    # 3) Build prompt within budget (Flan-T5 <= 512)
    prompt = build_prompt(question, passages, token_budget=input_token_budget)

    # 4) Generate. A forced minimum length avoids one-word answers but can push
    #    the model to append extra sentences; never let it exceed the maximum.
    out = llm(
        prompt,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min(min_new_tokens, max_new_tokens),
        do_sample=False,
        num_beams=4,                    # better completeness on CPU
        no_repeat_ngram_size=3,
        truncation=True,
        clean_up_tokenization_spaces=True,
    )
    return out[0]["generated_text"].strip(), passages, count_tokens(prompt)

def debug_query(q, show_context: bool = False):
    """Run ``rag_query`` on ``q`` and print the question and the answer.

    Args:
        q: The question to ask.
        show_context: When True, also print the prompt token count and the
            passages offered to the prompt builder. Off by default to keep
            notebook output short.
    """
    ans, used_passages, n_tokens = rag_query(q)
    print(f"Question: {q}")
    if show_context:
        print(f"Input tokens: {n_tokens} | Passages used: {len(used_passages)}")
        print("\n--- Context used ---")
        for i, p in enumerate(used_passages, 1):
            print(f"[{i}] {p}\n")
    print("--- Answer ---")
    print(ans)


Device set to use cpu


In [36]:
sample_question = "What was the revenue from operations in Q4 FY24?"
debug_query(sample_question)




Question: What was the revenue from operations in Q4 FY24?

--- Context used ---
--- Answer ---
51,488. (Source: financial_statement_fixed_2024) Revenue from operations in Q3 FY23 was 50,844.
