In [1]:
import os
import pickle
import re
from string import punctuation
from typing import List, Dict, Optional

import faiss
import nltk
import numpy as np
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForCausalLM, pipeline

# Fetch the NLTK stopword corpus (no-op when already present) and cache the
# English stopwords as a set for O(1) membership tests in preprocess_query.
nltk.download("stopwords")
STOPWORDS = set(stopwords.words("english"))
# ---------------- Config ----------------
# Model identifiers: sentence embedder (dense retrieval + guardrail),
# cross-encoder (reranking) and extractive QA model (answer extraction).
EMBED_MODEL   = "sentence-transformers/all-MiniLM-L6-v2"
CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L-6-v2"
QA_MODEL      = "deepset/roberta-base-squad2"


# Locations of the prebuilt artifacts: FAISS index, pickled BM25 object and
# pickled chunk metadata. Paths are relative to the working directory.
OUT_DIR     = "data/index_merged"
FAISS_PATH  = os.path.join(OUT_DIR, "faiss_merged.index")
BM25_PATH   = os.path.join(OUT_DIR, "bm25_merged.pkl")
META_PATH   = os.path.join(OUT_DIR, "meta_merged.pkl")

# ---------------- Load Indexes & Models ----------------
# NOTE: everything below runs at import time and performs disk/network I/O.
print("Loading FAISS, BM25, metadata, and models ...")

faiss_index = faiss.read_index(FAISS_PATH)

# SECURITY: pickle.load executes arbitrary code from the file. Only load
# index files that were produced by a trusted build step.
with open(BM25_PATH, "rb") as f:
    bm25_obj = pickle.load(f)
# The pickle is a dict; the BM25 scorer itself is stored under the "bm25" key.
bm25 = bm25_obj["bm25"]

# Chunk metadata, indexed by the same integer ids used for retrieval.
# The rest of this file reads the "id", "chunk_size" and "content" keys.
with open(META_PATH, "rb") as f:
    meta: List[Dict] = pickle.load(f)

embed_model = SentenceTransformer(EMBED_MODEL)
reranker = CrossEncoder(CROSS_ENCODER)

# Hugging Face question-answering pipeline used by answer_question.
qa_pipeline = pipeline(
    "question-answering",
    model=QA_MODEL,
    tokenizer=QA_MODEL
)

print("Indexes and models loaded!")

# ---------------- Guardrails ----------------
# Off-topic keywords. validate_query rejects a query when any of these occurs
# as a substring of the lowercased query (so "game" also matches "endgame").
BLOCKED_TERMS = ["weather", "cricket", "movie", "song", "football", "holiday",
                 "travel", "recipe", "music", "game", "sports", "politics", "election"]

# Reference phrases describing the in-scope (finance) domain. A query is
# accepted when it is semantically close enough to at least one of them.
FINANCE_DOMAINS = [
    "financial reporting", "balance sheet", "income statement",
    "assets and liabilities", "equity", "revenue", "profit and loss",
    "goodwill impairment", "cash flow", "dividends", "taxation",
    "investment", "valuation", "capital structure", "ownership interests",
    "subsidiaries", "shareholders equity", "expenses", "earnings",
    "debt", "amortization", "depreciation"
]
# Embedded once at import time so validate_query only has to encode the query.
finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)

def validate_query(query: str, threshold: float = 0.5) -> bool:
    """Decide whether a query is in scope for the finance assistant.

    The query is rejected when its lowercased text contains any entry of
    BLOCKED_TERMS as a substring. Otherwise it is embedded and compared
    (cosine similarity) with the precomputed FINANCE_DOMAINS embeddings; it
    is accepted only when the best similarity is strictly above ``threshold``.

    Args:
        query: Raw user question.
        threshold: Minimum (exclusive) cosine similarity to accept the query.

    Returns:
        True when the query is accepted, False otherwise. The decision is
        also printed with a ``[Guardrail]`` prefix.
    """
    lowered = query.lower()
    for term in BLOCKED_TERMS:
        if term in lowered:
            print("[Guardrail] Rejected by blocklist.")
            return False

    query_vec = embed_model.encode(query, convert_to_tensor=True)
    best = float(util.cos_sim(query_vec, finance_embeds).max())

    accepted = best > threshold
    if accepted:
        print(f"[Guardrail] Accepted (semantic match {best:.2f})")
    else:
        print(f"[Guardrail] Rejected (low semantic score {best:.2f})")
    return accepted

def validate_output(answer: str, context_docs: List[Dict]) -> str:
    """Return ``answer`` only if it can be found in the retrieved context.

    The comparison is case-insensitive and whitespace-insensitive (runs of
    spaces/newlines are treated as a single space), so an answer span that
    crosses a line break in the source text still verifies. A second pass
    ignores thousands separators, so a value such as "2500" verifies against
    a context that prints it as "2,500".

    Args:
        answer: Candidate answer produced by the extraction step. ``None``,
            empty and whitespace-only answers are never considered verified.
        context_docs: Retrieved chunks; each must have a "content" string.

    Returns:
        The original ``answer`` when it is grounded in the context, otherwise
        a fixed "could not be verified" message.
    """
    fallback = "The information could not be verified in the financial statements."

    def _canon(text: str) -> str:
        # Lowercase and collapse every whitespace run to a single space.
        return " ".join(text.lower().split())

    needle = _canon(answer or "")
    # An empty string is a substring of everything; never treat it as verified.
    if not needle:
        return fallback

    haystack = _canon(" ".join(doc["content"] for doc in context_docs))
    if needle in haystack:
        return answer

    # Retry ignoring thousands separators on both sides.
    needle_plain = needle.replace(",", "")
    if needle_plain.strip() and needle_plain in haystack.replace(",", ""):
        return answer

    return fallback

def preprocess_query(query: str, remove_stopwords: bool = True) -> str:
    """Normalize a query string for retrieval.

    The text is lowercased, every character other than ``a-z``, ``0-9`` and
    whitespace is replaced by a space, and the result is split on whitespace.
    When ``remove_stopwords`` is true, tokens found in STOPWORDS are dropped.

    Args:
        query: Raw query text.
        remove_stopwords: Whether to filter out English stopwords.

    Returns:
        The remaining tokens joined by single spaces.
    """
    cleaned = re.sub(r"[^a-z0-9\s]", " ", query.lower())
    words = cleaned.split()
    if not remove_stopwords:
        return " ".join(words)
    return " ".join(word for word in words if word not in STOPWORDS)

# ---------------- Hybrid Candidate Retrieval ----------------
def hybrid_candidates(query: str, candidate_k: int = 50, alpha: float = 0.5) -> List[int]:
    """Retrieve candidate chunk ids by fusing dense (FAISS) and BM25 scores.

    At least 50 hits are pulled from each retriever, their id sets are
    unioned, each score list is min-max normalized over the union, and the
    two are blended as ``alpha * dense + (1 - alpha) * bm25``.

    Args:
        query: Raw user question.
        candidate_k: Number of fused candidates to return.
        alpha: Weight of the dense score in the blend (0 = BM25 only,
            1 = dense only).

    Returns:
        Up to ``candidate_k`` chunk ids, best first. Empty when neither
        retriever returns anything.

    Note:
        Higher FAISS scores are treated as better, i.e. the index is assumed
        to hold similarity (inner-product) scores rather than distances --
        TODO confirm against the index build step.
    """
    top_n = max(candidate_k, 50)

    # Dense retrieval: stopwords are kept because the sentence embedder
    # works on natural-language text.
    q_emb = embed_model.encode(
        [preprocess_query(query, remove_stopwords=False)],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    faiss_scores, faiss_ids = faiss_index.search(q_emb, top_n)
    faiss_scores = faiss_scores[0]
    faiss_ids = faiss_ids[0]

    # FAISS pads the result with id -1 when the index holds fewer than
    # ``top_n`` vectors; those must not be used to index bm25/meta.
    valid = faiss_ids >= 0
    faiss_ids = faiss_ids[valid]
    faiss_scores = faiss_scores[valid]

    # Sparse retrieval: stopwords are removed for lexical matching.
    tokenized_query = preprocess_query(query, remove_stopwords=True).split()
    bm25_scores = np.asarray(bm25.get_scores(tokenized_query), dtype=float)
    bm25_top = np.argsort(bm25_scores)[::-1][:top_n]

    union_ids = np.unique(np.concatenate([bm25_top, faiss_ids])).astype(np.int64)
    if union_ids.size == 0:
        return []

    # Ids found only by BM25 have no dense score. Mark them with NaN (not a
    # numeric sentinel, since real similarity scores can be negative) and
    # back-fill with the lowest dense score actually observed.
    faiss_score_map = {int(i): float(s) for i, s in zip(faiss_ids, faiss_scores)}
    f_arr = np.array(
        [faiss_score_map.get(int(i), np.nan) for i in union_ids], dtype=float
    )
    missing = np.isnan(f_arr)
    if missing.all():
        f_arr = np.zeros_like(f_arr)
    elif missing.any():
        f_arr[missing] = f_arr[~missing].min()

    b_arr = bm25_scores[union_ids]

    def _norm(x):
        # Min-max scale to [0, 1]; the epsilon avoids 0/0 on constant input.
        rng = np.ptp(x)
        return (x - np.min(x)) / (rng + 1e-9)

    combined = alpha * _norm(f_arr) + (1 - alpha) * _norm(b_arr)
    order = np.argsort(combined)[::-1]
    ranked_ids = union_ids[order][:candidate_k]
    return ranked_ids.tolist()

# ---------------- Cross-Encoder Rerank ----------------
def rerank_cross_encoder(query: str, cand_ids: List[int], top_k: int = 10) -> List[Dict]:
    """Rerank candidate chunks with the cross-encoder and keep the best ones.

    Args:
        query: Raw user question, scored against each candidate's content.
        cand_ids: Chunk ids (indices into ``meta``) to score.
        top_k: Maximum number of results to return.

    Returns:
        Up to ``top_k`` dicts with keys "id", "chunk_size", "content" and
        "rerank_score", sorted by descending cross-encoder score. Empty when
        there are no candidates or ``top_k`` is not positive.
    """
    # Nothing to score: avoid handing an empty batch to the model, and avoid
    # the surprising "all but the last |top_k|" slice for negative top_k.
    if len(cand_ids) == 0 or top_k <= 0:
        return []

    pairs = [(query, meta[i]["content"]) for i in cand_ids]
    scores = reranker.predict(pairs)
    best_positions = np.argsort(scores)[::-1][:top_k]

    results = []
    for pos in best_positions:
        entry = meta[cand_ids[pos]]
        results.append({
            "id": entry["id"],
            "chunk_size": entry["chunk_size"],
            "content": entry["content"],
            "rerank_score": float(scores[pos]),
        })
    return results

# ---------------- QA Answer Extraction (Hugging Face extractive QA pipeline) ----------------
def answer_question(query: str, context_docs: List[Dict]) -> str:
    """Extract an answer span for ``query`` from the retrieved chunks.

    The chunk contents are joined with newlines into a single context and
    passed, together with the question, to the extractive QA pipeline.

    Args:
        query: The user question.
        context_docs: Retrieved chunks; each must have a "content" string.

    Returns:
        The answer text reported by the QA pipeline, or an empty string when
        there is no question or no context to read from.
    """
    context = "\n".join(doc["content"] for doc in context_docs)
    # The QA pipeline rejects empty inputs; with nothing to read there is
    # no answer to extract, so report that as an empty answer instead.
    if not query or not context.strip():
        return ""

    result = qa_pipeline({"question": query, "context": context})
    return result["answer"]


# ---------------- Concept + Year Extraction ----------------
def extract_value_for_year_and_concept(year: str, concept: str, context_docs: List[Dict]) -> Optional[str]:
    """Look up a table cell by row label (``concept``) and column year.

    Each document's content is scanned as a plain-text table whose cells are
    separated by a tab or by two or more whitespace characters:

    1. Only lines containing at least one digit are considered.
    2. The first such line containing a standalone ``20xx`` year is taken as
       the header; the order of its years defines the value columns.
    3. The first later line whose text contains ``concept``
       (case-insensitive) and has a cell for the requested year wins.

    Value cells are aligned from the right-hand side of the row, so the row
    label (and any extra leading column such as a note reference) is never
    mistaken for a year column.

    Args:
        year: Target year, e.g. "2024" (non-strings are converted with str()).
        concept: Row label to look for, e.g. "Revenue from operations".
        context_docs: Retrieved chunks; the "content" key is read if present.

    Returns:
        The matching cell with thousands separators removed, or ``None``
        when no document yields a value.
    """
    target_year = str(year)
    concept_lower = concept.lower().strip()
    # An empty concept would match every row; treat it as "not found".
    if not concept_lower:
        return None

    # Standalone 4-digit years only, so "120234" is not read as "2023".
    year_pattern = re.compile(r"(?<!\d)20\d{2}(?!\d)")
    cell_splitter = re.compile(r"\s{2,}|\t")

    for doc in context_docs:
        text = doc.get("content") or ""
        lines = [ln for ln in text.split("\n") if any(ch.isdigit() for ch in ln)]

        # Step 1: locate the header row and the year -> column mapping.
        header_idx = None
        header_years = []
        for idx, ln in enumerate(lines):
            header_years = year_pattern.findall(ln)
            if header_years:
                header_idx = idx
                break

        if header_idx is None or target_year not in header_years:
            continue

        year_to_col = {y: col for col, y in enumerate(header_years)}
        col_idx = year_to_col[target_year]
        n_years = len(header_years)

        # Step 2: find the concept row below the header.
        for ln in lines[header_idx + 1:]:
            if concept_lower not in ln.lower():
                continue
            cells = [c.strip() for c in cell_splitter.split(ln.strip()) if c.strip()]
            if len(cells) > n_years:
                # Label (+ optional extra columns) followed by one value per year.
                values = cells[-n_years:]
            else:
                # Short row: assume a single leading label cell.
                values = cells[1:]
            if col_idx < len(values):
                return values[col_idx].replace(",", "")
    return None

# ---------------- End-to-End RAG Pipeline ----------------
def rag_pipeline(query: str, top_k: int = 5, candidate_k: int = 50, alpha: float = 0.6):
    """Answer a finance question end to end.

    Steps: guardrail check, hybrid candidate retrieval, cross-encoder
    reranking, table lookup by year/concept when the query names a year
    (falling back to extractive QA), and finally verification of the answer
    against the retrieved chunks.

    Args:
        query: Raw user question.
        top_k: Number of reranked chunks to keep as context.
        candidate_k: Number of fused candidates passed to the reranker.
        alpha: Dense-score weight used by the hybrid retriever.

    Returns:
        A ``(answer, docs)`` tuple. For a rejected query, ``answer`` is a
        rejection message and ``docs`` is an empty list.
    """
    if not validate_query(query):
        return "Query rejected: Please ask finance-related questions.", []

    candidates = hybrid_candidates(query, candidate_k=candidate_k, alpha=alpha)
    context_docs = rerank_cross_encoder(query, candidates, top_k=top_k)

    # The concept is whatever remains after removing a "for the year 20xx" phrase.
    concept = re.sub(r"for the year 20\d{2}", "", query, flags=re.IGNORECASE).strip()
    found_year = re.search(r"(20\d{2})", query)

    answer = None
    if found_year and concept:
        answer = extract_value_for_year_and_concept(
            found_year.group(0), concept, context_docs
        )
    # No usable table value: fall back to the extractive QA model.
    if not answer:
        answer = answer_question(query, context_docs)

    return validate_output(answer, context_docs), context_docs

# ---------------- Example ----------------
# Demo run: ask one sample question and print the answer together with the
# reranked chunks that were used as context.
if __name__ == "__main__":
    q = "What is the revenue from air ticketing for  year 2024?"
    final_answer, top_docs = rag_pipeline(q, top_k=5, candidate_k=60, alpha=0.6)

    print(f"\nQuery: {q}")
    print("\nFinal Answer:\n", final_answer)
    print("\nTop supporting docs:")
    # One line per chunk: id, chunk size, rerank score, first 120 characters.
    for r in top_docs:
        print(f"[{r['id']}] (chunk={r['chunk_size']}, score={r['rerank_score']:.3f}) -> {r['content'][:120]}...")


[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/kundankumar/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Loading FAISS, BM25, metadata, and models ...


Device set to use mps:0


Indexes and models loaded!
[Guardrail] Accepted (semantic match 0.54)





Query: What is the revenue from air ticketing for  year 2024?

Final Answer:
 201246

Top supporting docs:
[100-3] (chunk=100, score=1.601) -> Total current liabilities values: {"2023": "453658", "2024": "297574"} Total liabilities values: {"2023": "483769", "202...
[100-12] (chunk=100, score=-1.283) -> Hotels and packages values: {"2023": "337686", "2024": "435542", "2025": "520411"} Bus ticketing values: {"2023": "74873...
[100-11] (chunk=100, score=-1.500) -> Deferred tax liabilities, net values: {"2024": "4754", "2025": "2526"} Other non-current liabilities values: {"2024": "1...
[400-3] (chunk=400, score=-4.645) -> Total revenue values: {"2023": "593036", "2024": "782524", "2025": "978336"} Other income values: {"2023": "2798", "2024...
[100-15] (chunk=100, score=-5.389) -> Profit (loss) for the year values: {"2023": "(11168)", "2024": "216743", "2025": "95274"} Owners of the Company values: ...
