In [None]:
import os
import json
import difflib
import pandas as pd
import datetime
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass
from typing import Annotated, TypedDict, List
import re
import pickle
import faiss

# LangChain / LangGraph Imports
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI

# ==============================================================================
# 0. RAG RETRIEVER (Single File Implementation)
# ==============================================================================

class SystemKBRetriever:
    """
    Loads FAISS index + metadata created from your KB excel.
    Uses ONLY the contract paragraph to retrieve similar examples.

    Construction only reads the two KB files from disk; the query embedder
    (and its sentence_transformers import) is created lazily on the first
    non-empty query.
    """

    # Model used to embed queries. NOTE(review): presumably this must be the
    # same model the KB builder used for the index vectors — confirm.
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, kb_dir: str):
        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build the KB first using your KB builder script."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        # NOTE: pickle can execute code on load — only point kb_dir at KB
        # files you built yourself.
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded: {len(self.meta)} rows")

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Return up to ``top_k`` KB rows most similar to ``query_text``.

        Returns:
        [
          {"score": float, "meta": {...all 29 cols...}},
          ...
        ]
        An empty list is returned for a blank query or when ``top_k < 1``,
        without loading the embedding model.
        """
        query_text = str(query_text).strip()
        if not query_text or top_k < 1:
            return []

        # Lazy-load the embedder (and its heavy import) only when needed.
        if getattr(self, "embedder", None) is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer(self.EMBEDDING_MODEL)

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            # FAISS reports empty result slots with a negative id (-1).
            if idx < 0:
                continue
            results.append({"score": float(score), "meta": self.meta[int(idx)]})

        return results


# ==============================================================================
# 1. CONFIGURATION & FILE PATHS
# ==============================================================================

# UPDATE PATHS HERE — or override any of them with the environment variable
# named in each call, so the notebook can run on another machine unedited.
TAXONOMY_PATH = os.environ.get(
    "DEFENSE_TAXONOMY_PATH",
    r'C:\Users\mukeshkr\Desktop\DefenseExtraction\testing\taxonomy.json',
)
SUPPLIERS_PATH = os.environ.get(
    "DEFENSE_SUPPLIERS_PATH",
    r'C:\Users\mukeshkr\Desktop\DefenseExtraction\testing\suppliers.json',
)
INPUT_EXCEL_PATH = os.environ.get(
    "DEFENSE_INPUT_EXCEL_PATH",
    r"C:\Users\mukeshkr\Desktop\DefenseExtraction\data\source_file.xlsx",
)
OUTPUT_EXCEL_PATH = os.environ.get("DEFENSE_OUTPUT_EXCEL_PATH", "Processed_Defense_Data.xlsx")

# RAG KB Directory (must contain system_kb.faiss + system_kb_meta.pkl)
RAG_KB_DIR = os.environ.get(
    "DEFENSE_RAG_KB_DIR",
    r"C:\Users\mukeshkr\Desktop\DefenseExtraction\testing\system_kb_store",
)

# Setup API Key: prompt interactively only if it is not already in the environment.
if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

# Shared OpenAI-compatible client pointed at the LLM Foundry endpoint.
client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
    base_url="https://llmfoundry.straive.com/openai/v1/",
)

# Load retriever once globally (SystemKBRetriever raises FileNotFoundError
# if the KB files are missing from RAG_KB_DIR).
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)

# --- FILE LOADING HELPERS ---
def load_json_file(filename, default_value):
    """
    Load and return the JSON document stored at ``filename``.

    Best-effort: if the file cannot be opened or decoded, or does not hold
    valid JSON, a warning is printed and ``default_value`` is returned.
    The success message is printed only after parsing has succeeded.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value
    print(f"‚úÖ Loaded {filename}")
    return data

# 1. Load Taxonomy (falls back to {} if the file cannot be loaded).
# NOTE(review): TAXONOMY_STR is not referenced by any prompt in this cell —
# confirm whether it was meant to be injected into system_classifier.
raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
# Compact JSON serialisation of the taxonomy (no whitespace separators).
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(',', ':'))

# 2. Load Suppliers: canonical names used by get_best_supplier_match.
# The hard-coded list is only a fallback for when suppliers.json fails to load.
SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, [
    "Dell Inc", "Boeing", "Lockheed Martin", "Raytheon Technologies",
    "Northrop Grumman", "L3Harris", "BAE Systems", "General Dynamics"
])

# 3. System Rules: each rule's "triggers" are checked against the lower-cased
# contract text in system_classifier, and each matching rule's "guidance" is
# added to the LLM prompt as a "RULE: ..." override hint.
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'"
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'"
    },
    "ammunition": {
        "triggers": ["cartridge", "round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'"
    }
}

# 4. Geography Mapping: region -> country / organisation names, looked up
# case-insensitively by get_region_for_country.
GEOGRAPHY_MAPPING = {
    "North America": ["USA", "United States", "US", "Canada", "America"],
    "Europe": ["UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain", "Poland", "Netherlands", "Norway", "Sweden", "Finland", "Denmark", "Belgium"],
    "Asia-Pacific": ["Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand"],
    "Middle East and North Africa": ["Israel", "Saudi Arabia", "UAE", "Egypt", "Qatar", "Kuwait", "Iraq"],
    "International Organisations": ["NATO", "EU", "IFU", "UN", "NSPA"]
}

# ==============================================================================
# 2. HELPER FUNCTIONS
# ==============================================================================

def get_best_supplier_match(extracted_name, suppliers=None):
    """
    Map an LLM-extracted supplier name onto the canonical supplier list.

    Args:
        extracted_name: Raw supplier text. May be None or a non-string.
        suppliers: Candidate canonical names. Defaults to SUPPLIER_LIST.

    Returns:
        "Unknown" for empty or placeholder input. Otherwise returns, in order
        of preference:
        - the canonical name from a case-insensitive exact match;
        - the closest case-insensitive difflib match (ratio >= 0.6);
        - the cleaned input unchanged.
    """
    if suppliers is None:
        suppliers = SUPPLIER_LIST
    if extracted_name is None:
        return "Unknown"

    clean_name = str(extracted_name).strip()
    if not clean_name or clean_name.lower() in ("unknown", "n/a", "not applicable"):
        return "Unknown"

    # Lower-cased lookup so exact AND fuzzy matching both ignore case.
    supplier_map = {s.lower(): s for s in suppliers if isinstance(s, str)}
    key = clean_name.lower()
    if key in supplier_map:
        return supplier_map[key]

    matches = difflib.get_close_matches(key, list(supplier_map), n=1, cutoff=0.6)
    return supplier_map[matches[0]] if matches else clean_name


def calculate_mro_months(start_date_str, end_date_text, program_type):
    """
    Return the MRO/Support contract duration in whole months, as a string.

    The duration is only computed when ``program_type`` is exactly
    "MRO/Support". The start date is parsed with ``dayfirst=True``; the end
    text is parsed fuzzily. Leftover days are truncated and negative spans
    are clamped to "0".

    "Not Applicable" is returned for other program types, for a missing or
    placeholder date, and for a date that cannot be parsed.
    """
    if program_type != "MRO/Support":
        return "Not Applicable"
    if not start_date_str or not end_date_text:
        return "Not Applicable"

    end_text = str(end_date_text).strip()
    if end_text.lower() in ("", "unknown", "n/a", "not applicable"):
        return "Not Applicable"

    try:
        start = pd.to_datetime(start_date_str, dayfirst=True)
        end = parser.parse(end_text, fuzzy=True)
        diff = relativedelta(end, start)
    except Exception:
        # Best-effort: unparseable or incompatible dates are not fatal, but
        # (unlike a bare except) KeyboardInterrupt/SystemExit still propagate.
        return "Not Applicable"

    total_months = (diff.years * 12) + diff.months
    return str(max(0, int(total_months)))


def get_region_for_country(country_name, mapping=None):
    """
    Return the region for a country or organisation name, or "Unknown".

    Matching ignores case and surrounding whitespace. Common US/UK spellings
    are resolved directly. Other names are looked up in ``mapping``, which
    defaults to GEOGRAPHY_MAPPING. Non-string input is coerced with str().
    """
    if mapping is None:
        mapping = GEOGRAPHY_MAPPING
    if not country_name:
        return "Unknown"

    clean = str(country_name).strip().lower()
    if clean in ("", "unknown", "n/a", "not applicable"):
        return "Unknown"

    if clean in ("us", "usa", "united states", "united states of america"):
        return "North America"
    if clean in ("uk", "britain", "great britain"):
        return "Europe"

    for region, countries in mapping.items():
        if any(c.strip().lower() == clean for c in countries):
            return region
    return "Unknown"


# Regex designator extractors (for System Name + Piloting).
# Every group is non-capturing so re.findall / re.finditer yield the full
# designator (e.g. "AIM-9"); a capturing group would make re.findall return
# only the prefix ("AIM").
DESIGNATOR_PATTERNS = [
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    r"\bF-\d+\b", r"\bB-\d+\b", r"\bC-\d+\b", r"\bA-\d+\b",
    r"\bMQ-\d+\b", r"\bRQ-\d+\b",
    r"\bAN\/[A-Z0-9\-]+\b",
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+\b",
]

def extract_designators(text: str, patterns=None) -> List[str]:
    """
    Return unique military designators (e.g. "DDG51", "F-35", "AIM-9") found in text.

    Each match is upper-cased, all whitespace is removed, and "--" is
    collapsed to "-". Duplicates are dropped, keeping first-seen order
    (patterns are applied in list order).

    Args:
        text: Text to scan (coerced with str()).
        patterns: Regexes to apply, matched case-insensitively. Defaults to
            DESIGNATOR_PATTERNS.
    """
    if patterns is None:
        patterns = DESIGNATOR_PATTERNS
    text = str(text)

    found = []
    for pat in patterns:
        # Use the whole match: re.findall would return only the captured
        # group for patterns that contain one (e.g. "AIM" instead of "AIM-9").
        for match in re.finditer(pat, text, flags=re.IGNORECASE):
            found.append(re.sub(r"\s+", "", match.group(0).upper()).replace("--", "-"))

    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(found))

# Word-anchored keyword patterns for the text-based piloting rules.
# The leading \b keeps e.g. "discuss " / "Russ " from counting as a "USS"
# ship prefix, while plurals such as "drones" / "UAVs" still match.
_UNCREWED_TEXT_RE = re.compile(r"\b(?:unmanned|uncrewed|uav|drone|autonomous)")
_CREWED_SHIP_RE = re.compile(r"\buss\b")


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """
    Deterministically classify System Piloting from text and designators.

    Returns "Uncrewed" for MQ-/RQ- designators or uncrewed-vehicle keywords,
    and "Crewed" for naval hull designators (DDG, CVN, SSN, LCS, LPD, LHA,
    LHD) or a "USS" ship prefix. Otherwise returns "Not Applicable", meaning
    no rule fired. Uncrewed rules are checked before crewed ones.
    """
    t = str(text).lower()

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if _UNCREWED_TEXT_RE.search(t):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    if _CREWED_SHIP_RE.search(t):
        return "Crewed"

    return "Not Applicable"


# ==============================================================================
# 3. TOOL DEFINITIONS (AGENTS)
# ==============================================================================

# --- AGENT 1: SOURCING ---
class SourcingInput(BaseModel):
    paragraph: str = Field(description="Contract text.")
    url: str = Field(description="Source URL.")
    date: str = Field(description="Contract Date.")


# args_schema wires the declared input model into the tool so its field
# descriptions and validation are actually used.
@tool("sourcing_extractor", args_schema=SourcingInput)
def sourcing_extractor(paragraph: str, url: str, date: str):
    """Stage 1: Prepares Metadata."""
    lower_text = paragraph.lower()

    # Accumulate notes so a split modification keeps both flags instead of
    # the later one overwriting the earlier one.
    notes = []
    if "modification" in lower_text:
        notes.append("Contract Modification.")
    if "split" in lower_text:
        notes.append("Split award detected.")

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": " ".join(notes) if notes else "Standard extraction.",
        "Source Link(s)": url,
        "Contract Date": date,
        # Local date on which this row was processed.
        "Reported Date (By SGA)": datetime.datetime.now().strftime("%Y-%m-%d"),
    }


# --- AGENT 2: GEOGRAPHY ---
class GeographyInput(BaseModel):
    paragraph: str = Field(description="Contract text.")


# Spellings folded to one key when deciding Domestic Content.
_GEO_COUNTRY_ALIASES = {
    "us": "united states",
    "usa": "united states",
    "united states of america": "united states",
    "uk": "united kingdom",
    "britain": "united kingdom",
    "great britain": "united kingdom",
}
_GEO_UNKNOWN_VALUES = {"", "unknown", "n/a", "not applicable", "none", "null"}


def _geo_field(raw, key, default="Unknown"):
    """Return raw[key] as a stripped string, or ``default`` if missing, null or blank."""
    value = raw.get(key) if isinstance(raw, dict) else None
    if value is None:
        return default
    text = str(value).strip()
    return text if text else default


def _geo_country_key(name):
    """Return a lower-cased, alias-folded country key, or None if unknown."""
    key = str(name).strip().lower()
    if key in _GEO_UNKNOWN_VALUES:
        return None
    return _GEO_COUNTRY_ALIASES.get(key, key)


@tool("geography_extractor", args_schema=GeographyInput)
def geography_extractor(paragraph: str):
    """Stage 2: Geography Logic.

    Asks the LLM for customer/supplier countries and the customer operator,
    then derives regions and Domestic Content:
    - "Indigenous" when both countries are the same (US/UK aliases folded);
    - "Imported" when they differ;
    - "Unknown" when either country is unknown.
    """
    sys_prompt = """
Extract: Customer Country, Supplier Country, Customer Operator.
Logic: If 'Navy awarded...', Operator is Navy.
Return JSON:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}
"""
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": paragraph}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
    except Exception as e:
        # Best-effort stage: fall back to "Unknown" fields rather than failing the row.
        print(f"Warning: geography extraction failed ({e}); using Unknown values.")
        raw = {}

    # The LLM may emit null or non-string values; normalise before using them.
    cust = _geo_field(raw, "Customer Country")
    supp = _geo_field(raw, "Supplier Country")

    cust_key = _geo_country_key(cust)
    supp_key = _geo_country_key(supp)
    if cust_key is None or supp_key is None:
        domestic = "Unknown"
    elif cust_key == supp_key:
        domestic = "Indigenous"
    else:
        domestic = "Imported"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": _geo_field(raw, "Customer Operator"),
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": domestic
    }


# --- AGENT 3: SYSTEM (RAG + Evidence + Reason) ---
class SystemInput(BaseModel):
    paragraph: str = Field(description="Contract text.")


def _rule_trigger_matches(trigger, lower_text):
    """True if ``trigger`` occurs in ``lower_text`` at the start of a word.

    Requiring a non-alphanumeric character (or the start of the text) before
    the trigger stops "round" firing on "ground"/"background" and "ale-" on
    "sale-", while plurals such as "rounds"/"flares" still match.
    """
    needle = str(trigger).strip().lower()
    if not needle:
        return False
    return re.search(r"(?<![a-z0-9])" + re.escape(needle), lower_text) is not None


def _kb_meta_text(meta, key):
    """Return meta[key] as a string; None/NaN values become "".

    NaN presumably comes from empty cells in the KB spreadsheet. It must not
    reach string slicing or the prompt.
    """
    value = meta.get(key, "")
    if value is None or (isinstance(value, float) and value != value):
        return ""
    return str(value)


def _build_rag_examples(rag_hits):
    """Turn retriever hits into the compact example dicts shown to the LLM."""
    examples = []
    for hit in rag_hits:
        meta = hit["meta"]
        snippet = _kb_meta_text(meta, "Description of Contract")
        examples.append({
            "score": round(hit["score"], 4),
            "Market Segment": _kb_meta_text(meta, "Market Segment"),
            "System Type (General)": _kb_meta_text(meta, "System Type (General)"),
            "System Type (Specific)": _kb_meta_text(meta, "System Type (Specific)"),
            "System Name (General)": _kb_meta_text(meta, "System Name (General)"),
            "System Name (Specific)": _kb_meta_text(meta, "System Name (Specific)"),
            "System Piloting": _kb_meta_text(meta, "System Piloting"),
            "Supplier Name": _kb_meta_text(meta, "Supplier Name"),
            "Customer Operator": _kb_meta_text(meta, "Customer Operator"),
            "Snippet": snippet[:220] + ("..." if len(snippet) > 220 else "")
        })
    return examples


@tool("system_classifier", args_schema=SystemInput)
def system_classifier(paragraph: str):
    """Stage 3: System classification using RAG + Rule Book + Evidence & Reason."""
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    lower_text = paragraph.lower()

    # RULE BOOK triggers (word-start matching, see _rule_trigger_matches)
    hints = [
        f"RULE: {rule['guidance']}"
        for rule in RULE_BOOK.values()
        if any(_rule_trigger_matches(t, lower_text) for t in rule["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    # Local extraction + rule-based piloting
    designators = extract_designators(paragraph)
    piloting_rule = detect_piloting_rule_based(paragraph, designators)

    # RAG retrieval is an enrichment. If it fails (model load, missing
    # package, index error), classify without examples instead of aborting
    # the whole run.
    try:
        rag_hits = retriever.retrieve(paragraph, top_k=3)
    except Exception as e:
        print(f"Warning: RAG retrieval failed ({e}); continuing without examples.")
        rag_hits = []
    rag_examples = _build_rag_examples(rag_hits)

    # NOTE(review): TAXONOMY_STR is loaded at module level but is not part of
    # this prompt — confirm whether it should be.
    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

Use these inputs:
1) RAG Similar Examples (top 3)
2) Rule Book Overrides
3) Extracted Designators

RULE BOOK OVERRIDES:
{hint_str}

OUTPUT RULES:
- Return ONLY a FLAT JSON object.
- Every value must be a STRING.
- Do NOT return nested JSON or lists.
- Evidence must be copied EXACTLY from paragraph.
- If evidence not present, use "Not Found".

Return JSON exactly:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    user_prompt = f"""
INPUT PARAGRAPH:
{paragraph}

DESIGNATORS (regex extracted):
{designators if designators else "None"}

RULE-BASED PILOTING:
{piloting_rule}

RAG SIMILAR EXAMPLES:
{json.dumps(rag_examples, indent=2)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
        if not isinstance(result, dict):
            raise ValueError(f"Expected a JSON object, got {type(result).__name__}")

        # Hard override piloting (best accuracy)
        result["System Piloting"] = piloting_rule
        if not result.get("System Piloting Reason"):
            result["System Piloting Reason"] = "Piloting derived using deterministic rules."
        if not result.get("System Piloting Evidence"):
            result["System Piloting Evidence"] = "Not Found"

        # Ensure flat string values: null -> "", anything else non-string -> str().
        for k, v in list(result.items()):
            if v is None:
                result[k] = ""
            elif not isinstance(v, str):
                result[k] = str(v)

        return result

    except Exception as e:
        return {
            "Market Segment": "",
            "System Type (General)": "",
            "System Type (Specific)": "",
            "System Name (General)": "",
            "System Name (Specific)": "",
            "System Piloting": piloting_rule,
            "Market Segment Evidence": "Not Found",
            "System Type (General) Evidence": "Not Found",
            "System Type (Specific) Evidence": "Not Found",
            "System Name (General) Evidence": "Not Found",
            "System Name (Specific) Evidence": "Not Found",
            "System Piloting Evidence": "Not Found",
            "Market Segment Reason": "",
            "System Type (General) Reason": "",
            "System Type (Specific) Reason": "",
            "System Name (General) Reason": "",
            "System Name (Specific) Reason": "",
            "System Piloting Reason": "Piloting derived using deterministic rules.",
            "Confidence": "Low",
            "Error": str(e)
        }


# --- AGENT 4: CONTRACT ---
class ContractInfoInput(BaseModel):
    paragraph: str = Field(description="Contract text.")
    contract_date: str = Field(description="Signed date.")


# First decimal number in a string (optionally signed / in exponent form).
_CONTRACT_NUMBER_RE = re.compile(r"-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")


def _contract_field(raw, key, default):
    """Return raw[key], or ``default`` when the key is missing, null or a blank string.

    Strings are stripped; other JSON values (numbers, bools) are returned as-is.
    """
    value = raw.get(key)
    if value is None:
        return default
    if isinstance(value, str):
        value = value.strip()
        return value if value else default
    return value


def _parse_value_millions(raw_value):
    """Format the first number in ``raw_value`` with 3 decimals, or "0.000" if none.

    Tolerates thousands separators and surrounding text such as
    "$1,234.5 million" or "USD 12.5".
    """
    match = _CONTRACT_NUMBER_RE.search(str(raw_value).replace(",", ""))
    if not match:
        return "0.000"
    return "{:.3f}".format(float(match.group(0)))


@tool("contract_extractor", args_schema=ContractInfoInput)
def contract_extractor(paragraph: str, contract_date: str):
    """Stage 4: Extracts Financials, Program Type, and Dates based on strict SOP."""

    system_instruction = """
You are a Defense Contract Financial Analyst. Extract data strictly following these SOP rules:

1) Supplier Name: Extract the exact company name text found in paragraph.
2) Program Type: Procurement / Training / MRO/Support / RDT&E / Upgrade / Other Service
3) Value Certainty:
   - Confirmed: fixed price/obligated stated
   - Estimated: ceiling/potential/IDIQ/multi-award
4) Quantity: number or Not Applicable
5) G2G/B2G:
   - G2G only if Foreign Military Sales (FMS)
   - else B2G
6) Completion Date: only needed for MRO duration calc

Return JSON only.
"""

    user_prompt = f"""
Analyze contract:
"{paragraph}"

Signed Date: {contract_date}

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        if not isinstance(raw, dict):
            raise ValueError(f"Expected a JSON object, got {type(raw).__name__}")
    except Exception as e:
        return {"Error": str(e)}

    final_supplier = get_best_supplier_match(_contract_field(raw, "raw_supplier_name", ""))
    prog_type = _contract_field(raw, "program_type", "Unknown")
    mro_months = calculate_mro_months(
        contract_date, _contract_field(raw, "completion_date_text", ""), prog_type
    )

    val_formatted = _parse_value_millions(raw.get("value_million_raw", "0"))

    try:
        dt = pd.to_datetime(contract_date)
        sign_month = dt.strftime("%B")
        sign_year = str(dt.year)
    except Exception:
        # Unparseable (or NaT) contract dates are reported, not fatal.
        sign_month, sign_year = "Unknown", "Unknown"

    val_note = _contract_field(raw, "value_note", "Not Applicable")
    if "split" in paragraph.lower() and val_note == "Not Applicable":
        val_note = "Split contract; value distribution unclear."

    return {
        "Supplier Name": final_supplier,
        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": _contract_field(raw, "quantity", "Not Applicable"),
        "Value Certainty": _contract_field(raw, "value_certainty", "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": _contract_field(raw, "currency_code", "USD$"),
        # NOTE(review): no FX conversion is applied; for non-USD contracts this
        # column currently repeats the local-currency value — confirm intent.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": val_note,
        "G2G/B2G": _contract_field(raw, "g2g_b2g", "B2G"),
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


# ==============================================================================
# 4. LANGGRAPH PIPELINE
# ==============================================================================

class AgentState(TypedDict):
    """Shared LangGraph state passed through the four extraction stages.

    The ``input_*`` fields are set once per Excel row by the caller. Each
    stage returns a copy of ``final_data`` with its own output columns merged in.
    """
    input_text: str  # contract paragraph ("Contract Description" column)
    input_date: str  # contract date, already converted to a string by the caller
    input_url: str  # source URL ("Source URL" column)
    final_data: dict  # accumulated output columns for the current row
    # Message channel using LangGraph's add_messages reducer; no stage in this
    # pipeline writes to it (the caller initialises it to []).
    messages: Annotated[List[AnyMessage], add_messages]

def _merge_stage_output(state: AgentState, stage_output: dict) -> dict:
    """Build a LangGraph update that layers ``stage_output`` over the current final_data.

    The state's own dict is never mutated; a fresh merged dict is returned
    under the "final_data" key.
    """
    return {"final_data": {**state.get("final_data", {}), **stage_output}}


def stage_1_sourcing(state: AgentState):
    """Graph node: run sourcing_extractor on the row's text, URL and date."""
    payload = {
        "paragraph": state["input_text"],
        "url": state["input_url"],
        "date": state["input_date"],
    }
    return _merge_stage_output(state, sourcing_extractor.invoke(payload))


def stage_2_geography(state: AgentState):
    """Graph node: run geography_extractor on the row's text."""
    payload = {"paragraph": state["input_text"]}
    return _merge_stage_output(state, geography_extractor.invoke(payload))


def stage_3_system(state: AgentState):
    """Graph node: run system_classifier on the row's text."""
    payload = {"paragraph": state["input_text"]}
    return _merge_stage_output(state, system_classifier.invoke(payload))


def stage_4_contract(state: AgentState):
    """Graph node: run contract_extractor on the row's text and date."""
    payload = {
        "paragraph": state["input_text"],
        "contract_date": state["input_date"],
    }
    return _merge_stage_output(state, contract_extractor.invoke(payload))

def _build_workflow():
    """Wire the stages into a linear graph: START -> Stage1 -> ... -> Stage4 -> END."""
    graph = StateGraph(AgentState)
    stages = [
        ("Stage1", stage_1_sourcing),
        ("Stage2", stage_2_geography),
        ("Stage3", stage_3_system),
        ("Stage4", stage_4_contract),
    ]
    for node_name, node_fn in stages:
        graph.add_node(node_name, node_fn)

    # Connect each node to the next one in order, bracketed by START and END.
    chain = [START, *(node_name for node_name, _ in stages), END]
    for src, dst in zip(chain, chain[1:]):
        graph.add_edge(src, dst)
    return graph


workflow = _build_workflow()
app = workflow.compile()

# ==============================================================================
# 5. EXECUTION & FORMATTING
# ==============================================================================

if __name__ == "__main__":

    print(f"üìÇ Loading Input File: {INPUT_EXCEL_PATH}...")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        missing_cols = [col for col in required_cols if col not in df_input.columns]
        if missing_cols:
            raise ValueError(f"Excel file must contain columns: {required_cols} (missing: {missing_cols})")

        print(f"üöÄ Processing {len(df_input)} rows...")
        results = []

        # enumerate gives a 1-based row counter independent of the DataFrame index.
        for row_num, (_, row) in enumerate(df_input.iterrows(), start=1):
            print(f"   -> Row {row_num}...")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "messages": []
            }

            try:
                output_state = app.invoke(initial_state)
                results.append(output_state["final_data"])
            except Exception as row_error:
                # One failing row (API, network or parsing error) must not abort
                # the run and discard every row processed so far: keep the
                # row's source fields, record the failure, and continue.
                print(f"   !! Row {row_num} failed: {row_error}")
                results.append({
                    "Description of Contract": desc,
                    "Additional Notes (Internal Only)": f"Processing failed: {row_error}",
                    "Source Link(s)": c_url,
                    "Contract Date": c_date,
                })

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # Fixed column order; columns no stage produced are filled with "".
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")

        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        print("\n‚úÖ Processing Complete!")
        print(f"üíæ File saved to: {OUTPUT_EXCEL_PATH}")
        print(df_final.head().to_string())

    except Exception as e:
        print(f"\n‚ùå Error: {e}")


In [None]:
import os
import json
import pandas as pd
import datetime
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass
from typing import Annotated, TypedDict, List
import re
import pickle
import faiss
from difflib import SequenceMatcher

# LangChain / LangGraph Imports
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI


# ==============================================================================
# RAG RETRIEVER (Single File Implementation)
# ==============================================================================

class SystemKBRetriever:
    """
    Loads FAISS index + metadata created from your KB excel.
    Uses ONLY the contract paragraph to retrieve similar examples.

    Construction only reads the two KB files from disk; the query embedder
    (and its sentence_transformers import) is created lazily on the first
    non-empty query.
    """

    # Model used to embed queries. NOTE(review): presumably this must be the
    # same model the KB builder used for the index vectors — confirm.
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, kb_dir: str):
        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build the KB first using your KB builder script."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        # NOTE: pickle can execute code on load — only point kb_dir at KB
        # files you built yourself.
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded: {len(self.meta)} rows")

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Return up to ``top_k`` KB rows most similar to ``query_text``.

        Returns:
        [
          {"score": float, "meta": {...all 29 cols...}},
          ...
        ]
        An empty list is returned for a blank query or when ``top_k < 1``,
        without loading the embedding model.
        """
        query_text = str(query_text).strip()
        if not query_text or top_k < 1:
            return []

        # Lazy-load the embedder (and its heavy import) only when needed.
        if getattr(self, "embedder", None) is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer(self.EMBEDDING_MODEL)

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            # FAISS reports empty result slots with a negative id (-1).
            if idx < 0:
                continue
            results.append({"score": float(score), "meta": self.meta[int(idx)]})

        return results


# ==============================================================================
# 1. CONFIGURATION & FILE PATHS
# ==============================================================================

# Defaults below can be overridden with the environment variable named in
# each call, so the notebook can run on another machine without code edits.
TAXONOMY_PATH = os.environ.get(
    "DEFENSE_TAXONOMY_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\taxonomy.json",
)
SUPPLIERS_PATH = os.environ.get(
    "DEFENSE_SUPPLIERS_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\suppliers.json",
)
INPUT_EXCEL_PATH = os.environ.get(
    "DEFENSE_INPUT_EXCEL_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\source_file.xlsx",
)
OUTPUT_EXCEL_PATH = os.environ.get("DEFENSE_OUTPUT_EXCEL_PATH", "Processed_Defense_Data.xlsx")

# RAG KB directory (must contain system_kb.faiss + system_kb_meta.pkl)
RAG_KB_DIR = os.environ.get(
    "DEFENSE_RAG_KB_DIR",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\system_kb_store",
)


# ==============================================================================
# 2. API Setup
# ==============================================================================

# Prompt for the API key only if it is not already in the environment.
if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

# Shared OpenAI-compatible client pointed at the LLM Foundry endpoint.
client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
    base_url="https://llmfoundry.straive.com/openai/v1/",
)

# Load the retriever once (SystemKBRetriever raises FileNotFoundError if the
# KB files are missing from RAG_KB_DIR).
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)


# ==============================================================================
# 3. JSON Loaders
# ==============================================================================

def load_json_file(filename, default_value):
    """
    Load and return the JSON document stored at ``filename``.

    Best-effort: if the file cannot be opened or decoded, or does not hold
    valid JSON, a warning is printed and ``default_value`` is returned.
    The success message is printed only after parsing has succeeded.
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value
    print(f"‚úÖ Loaded: {filename}")
    return data


raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
# Compact JSON serialisation of the taxonomy (no whitespace separators).
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(",", ":"))


# IMPORTANT: Supplier list MUST come from suppliers.json (a JSON list of names).
SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, [])
if not isinstance(SUPPLIER_LIST, list):
    raise ValueError(
        f"suppliers.json must contain a JSON list of supplier names, "
        f"got {type(SUPPLIER_LIST).__name__}: {SUPPLIERS_PATH}"
    )
# Keep only non-blank string names: the matcher lower-cases every entry, and
# an empty name would "contain"-match every extracted supplier.
SUPPLIER_LIST = [s.strip() for s in SUPPLIER_LIST if isinstance(s, str) and s.strip()]
if not SUPPLIER_LIST:
    raise ValueError("‚ùå suppliers.json loaded 0 suppliers. Please verify the path.")


# ==============================================================================
# 4. Rule Book + Geography Mapping
# ==============================================================================

# Keyword rules: each "triggers" list holds lower-case strings, and "guidance"
# is the classification hint associated with the rule.
# NOTE(review): the consumer of RULE_BOOK in this cell is outside the visible
# code — confirm how triggers are matched (substring vs word boundary).
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'",
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'",
    },
    "ammunition": {
        "triggers": ["cartridge", "round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'",
    },
}

# Region -> country / organisation names used for region lookup.
GEOGRAPHY_MAPPING = {
    "North America": ["USA", "United States", "US", "Canada", "America"],
    "Europe": ["UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain", "Poland", "Netherlands", "Norway", "Sweden", "Finland", "Denmark", "Belgium"],
    "Asia-Pacific": ["Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand"],
    "Middle East and North Africa": ["Israel", "Saudi Arabia", "UAE", "Egypt", "Qatar", "Kuwait", "Iraq"],
    "International Organisations": ["NATO", "EU", "IFU", "UN", "NSPA"],
}


# ==============================================================================
# ‚úÖ 5. HELPER FUNCTIONS (Supplier Matching FIXED ‚úÖ‚úÖ‚úÖ)
# ==============================================================================

def normalize_supplier_text(x: str) -> str:
    if not x:
        return ""
    x = str(x).strip()

    # normalize unicode dashes
    x = x.replace("‚Äì", "-").replace("‚Äî", "-")

    # remove multiple spaces
    x = re.sub(r"\s+", " ", x).strip()

    # remove extra location after comma
    # ex: "General Dynamics NASSCO - San Diego, San Diego, California"
    x = x.split(",")[0].strip()

    return x


def token_overlap_score(a: str, b: str) -> float:
    a_tokens = set(re.findall(r"[a-z0-9]+", a.lower()))
    b_tokens = set(re.findall(r"[a-z0-9]+", b.lower()))
    if not a_tokens or not b_tokens:
        return 0.0
    return len(a_tokens & b_tokens) / max(len(a_tokens), len(b_tokens))


def get_best_supplier_match(extracted_supplier: str):
    """Map an extracted supplier string onto a canonical name from SUPPLIER_LIST.

    The final Supplier Name must come from SUPPLIER_LIST (suppliers.json);
    anything that cannot be matched confidently becomes "Unknown".

    Matching stages (first hit wins):
      1. Exact, case-insensitive match -> score 1.0.
      2. Whole-word containment: a list entry appears inside the extracted
         text; the longest such entry wins -> score 0.95.
      3. Hybrid fuzzy score ``0.65 * token overlap + 0.35 * difflib ratio``;
         the best candidate is accepted only when the score is >= 0.45.

    Returns:
        (best_supplier_name, best_score)
    """
    if not extracted_supplier:
        return "Unknown", 0.0

    extracted_supplier = normalize_supplier_text(extracted_supplier)
    low = extracted_supplier.lower()

    if low in ("", "unknown", "n/a", "not applicable"):
        return "Unknown", 0.0

    # Ignore blank / non-string entries: an empty name would "contain"-match
    # every input in stage 2.
    suppliers = [s for s in SUPPLIER_LIST if isinstance(s, str) and s.strip()]

    # 1) Exact match
    supplier_map = {s.lower(): s for s in suppliers}
    if low in supplier_map:
        return supplier_map[low], 1.0

    # 2) Containment match (best practical for long extracted strings).
    #    Alphanumeric look-arounds act as word boundaries so a short name such
    #    as "BAE" cannot match inside an unrelated word; they also work for
    #    names ending in punctuation ("Dell Inc.") where \b would not.
    containment_hits = [
        s for s in suppliers
        if re.search(r"(?<![a-z0-9])" + re.escape(s.lower()) + r"(?![a-z0-9])", low)
    ]
    if containment_hits:
        # longest = most specific; max() keeps the first on ties, like a stable sort
        return max(containment_hits, key=len), 0.95

    # 3) Hybrid fuzzy match over the supplier list.
    #    difflib.SequenceMatcher is used via the module (only `difflib` is imported).
    best_name, best_score = "Unknown", 0.0
    for s in suppliers:
        s_clean = normalize_supplier_text(s)
        seq = difflib.SequenceMatcher(None, low, s_clean.lower()).ratio()
        tok = token_overlap_score(extracted_supplier, s_clean)
        final = (0.65 * tok) + (0.35 * seq)
        if final > best_score:
            best_name, best_score = s, final

    # Strict cutoff: prefer "Unknown" over a likely-wrong match.
    if best_score < 0.45:
        return "Unknown", best_score
    return best_name, best_score


def calculate_mro_months(start_date_str, end_date_text, program_type):
    """Return the MRO contract duration in whole months, as a string.

    Only computed when ``program_type == "MRO/Support"``. Otherwise, or when
    either date is missing, a placeholder ("unknown", "n/a", "not applicable")
    or unparseable, returns "Not Applicable". Negative durations clamp to "0".

    Args:
        start_date_str: signing date; parsed by pandas with ``dayfirst=True``.
        end_date_text: free-text completion date, parsed fuzzily by dateutil.
        program_type: program type label from the contract extractor.
    """
    if program_type != "MRO/Support":
        return "Not Applicable"

    placeholders = ("", "unknown", "n/a", "not applicable")
    if (
        not start_date_str
        or not end_date_text
        or str(end_date_text).strip().lower() in placeholders
    ):
        return "Not Applicable"

    try:
        start = pd.to_datetime(start_date_str, dayfirst=True)
        end = parser.parse(str(end_date_text), fuzzy=True)
        diff = relativedelta(end, start)
    except (ValueError, TypeError, OverflowError):
        # pandas/dateutil parse errors are ValueError subclasses; TypeError
        # covers NaT or mixing tz-aware and naive datetimes.
        return "Not Applicable"

    total_months = (diff.years * 12) + diff.months
    return str(max(0, int(total_months)))


def get_region_for_country(country_name):
    """Map a country name to its region label, or "Unknown".

    Case- and whitespace-insensitive. Non-string input (e.g. a null or number
    from the LLM) is coerced with ``str`` instead of raising. Common US/UK
    spellings, including dotted forms, are resolved first; then
    GEOGRAPHY_MAPPING is scanned in order and the first matching region wins.
    """
    if country_name is None:
        return "Unknown"
    clean = str(country_name).strip().lower()
    if clean in ("", "unknown", "n/a", "not applicable"):
        return "Unknown"

    if clean in ("us", "u.s.", "usa", "u.s.a.", "united states", "united states of america"):
        return "North America"
    if clean in ("uk", "u.k.", "britain", "great britain", "united kingdom"):
        return "Europe"

    for region, countries in GEOGRAPHY_MAPPING.items():
        if any(str(c).strip().lower() == clean for c in countries):
            return region
    return "Unknown"


# ==============================================================================
# ✅ 6. Designators (System Name + Piloting)
# ==============================================================================

# Regexes for common defense designators: ship hull numbers, aircraft / UAV
# designations, AN/ electronics nomenclature and missiles. They are applied
# with re.IGNORECASE. Groups are non-capturing so re.findall returns the whole
# designator (a capturing group made it return only "AIM"/"SM"/...).
# The optional "(?-i:[A-Z])" suffix accepts an upper-case variant letter
# (F-35A, MQ-9B, AIM-9X) without swallowing a lower-case plural "s".
DESIGNATOR_PATTERNS = [
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    r"\bF-\d+(?-i:[A-Z])?\b", r"\bB-\d+(?-i:[A-Z])?\b", r"\bC-\d+(?-i:[A-Z])?\b", r"\bA-\d+(?-i:[A-Z])?\b",
    r"\bMQ-\d+(?-i:[A-Z])?\b", r"\bRQ-\d+(?-i:[A-Z])?\b",
    r"\bAN\/[A-Z0-9\-]+\b",
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+(?-i:[A-Z])?\b",
]

def extract_designators(text: str, patterns=None):
    """Extract unique defense designators (e.g. DDG-51, MQ-9, AN/APY-10) from text.

    Args:
        text: paragraph to scan; coerced with ``str``.
        patterns: optional list of regexes; defaults to DESIGNATOR_PATTERNS.
            Each pattern is applied case-insensitively.

    Returns:
        Upper-cased designators without duplicates, in pattern order and then
        order of appearance. Runs of whitespace/hyphens become one "-", so
        "DDG 51" and "DDG-51" are reported once as "DDG-51".
    """
    if patterns is None:
        patterns = DESIGNATOR_PATTERNS
    text = str(text)

    found = []
    for pat in patterns:
        # finditer + group(0) yields the whole match even when a pattern has a
        # capturing group (re.findall would return only the group's text).
        for m in re.finditer(pat, text, flags=re.IGNORECASE):
            found.append(re.sub(r"[\s\-]+", "-", m.group(0).upper()))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(found))


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """Classify piloting deterministically from keywords and designators.

    Checks in order:
      1. MQ-/RQ- designators, or uncrewed keywords (unmanned, uncrewed, UAV(s),
         UAS, USV(s), UUV(s), drone(s), autonomous/autonomously) -> "Uncrewed".
      2. Ship hull designators (DDG, CVN, SSN, LCS, LPD, LHA, LHD) or the word
         "USS" -> "Crewed".
      3. Otherwise -> "Not Applicable".

    Keywords are matched as whole words. Plain substring checks misfired, for
    example "uss " inside "discuss " or "uas" inside "quasi".
    """
    t = str(text).lower()

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if re.search(r"\b(?:unmanned|uncrewed|uavs?|uas|usvs?|uuvs?|drones?|autonomous(?:ly)?)\b", t):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    if re.search(r"\buss\b", t):
        return "Crewed"

    return "Not Applicable"


# ==============================================================================
# ✅ 7. TOOL DEFINITIONS (Agents)
# ==============================================================================

# --- AGENT 1: SOURCING ---
class SourcingInput(BaseModel):
    paragraph: str = Field(description="Contract text.")
    url: str = Field(description="Source URL.")
    date: str = Field(description="Contract Date.")

@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """Build the sourcing columns for a contract paragraph.

    Echoes the paragraph, URL and contract date under their output column
    names and stamps today's date (YYYY-MM-DD) as the reported date. The
    internal note is "Split award detected." if the text mentions "split",
    otherwise "Contract Modification." if it mentions "modification",
    otherwise "Standard extraction.".
    """
    text_lc = paragraph.lower()
    if "split" in text_lc:
        notes = "Split award detected."
    elif "modification" in text_lc:
        notes = "Contract Modification."
    else:
        notes = "Standard extraction."

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": notes,
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": datetime.datetime.now().strftime("%Y-%m-%d"),
    }


# --- AGENT 2: GEOGRAPHY ---
class GeographyInput(BaseModel):
    """Input fields for the geography_extractor tool (the contract paragraph).

    NOTE(review): not passed as ``args_schema`` to the @tool decorator.
    """
    paragraph: str = Field(description="Contract text.")

# Country aliases folded together before the Domestic Content comparison,
# so e.g. "USA" and "United States" count as the same country.
_COUNTRY_ALIASES = {
    "us": "united states", "u.s.": "united states", "usa": "united states",
    "u.s.a.": "united states", "united states of america": "united states",
    "america": "united states",
    "uk": "united kingdom", "u.k.": "united kingdom", "britain": "united kingdom",
    "great britain": "united kingdom",
    "uae": "united arab emirates",
}


def _canonical_country(name) -> str:
    """Lower-case, strip and alias-fold a country name ("" for None)."""
    clean = "" if name is None else str(name).strip().lower()
    return _COUNTRY_ALIASES.get(clean, clean)


@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """Extract customer/supplier geography for a contract paragraph via the LLM.

    Returns Customer Region/Country/Operator, Supplier Region/Country and
    Domestic Content. Missing or null LLM fields become "Unknown". Domestic
    Content is "Indigenous" when both countries match after alias folding
    (e.g. "USA" vs "United States"), "Imported" when they differ, and
    "Unknown" when either country is unknown. An LLM/JSON failure is reported
    on stdout and yields "Unknown" fields instead of aborting the row.
    """
    sys_prompt = """
Extract: Customer Country, Supplier Country, Customer Operator.
Logic: If 'Navy awarded...', Operator is Navy.
Return JSON:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}
"""
    raw = {}
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": paragraph}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        parsed = json.loads(completion.choices[0].message.content)
        if isinstance(parsed, dict):
            raw = parsed
    except Exception as exc:
        # Best effort: keep the row, but do not hide the failure.
        print(f"Warning: geography_extractor LLM call failed: {exc}")

    # `or` also covers keys present with null/blank values.
    cust = str(raw.get("Customer Country") or "Unknown")
    supp = str(raw.get("Supplier Country") or "Unknown")
    operator = str(raw.get("Customer Operator") or "Unknown")

    cust_key, supp_key = _canonical_country(cust), _canonical_country(supp)
    unknown = {"", "unknown", "n/a", "not applicable"}
    if cust_key in unknown or supp_key in unknown:
        domestic = "Unknown"
    elif cust_key == supp_key:
        domestic = "Indigenous"
    else:
        domestic = "Imported"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": operator,
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": domestic
    }


# --- AGENT 3: SYSTEM (RAG + Evidence + Reason) ---
class SystemInput(BaseModel):
    """Input fields for the system_classifier tool (the contract paragraph).

    NOTE(review): not passed as ``args_schema`` to the @tool decorator.
    """
    paragraph: str = Field(description="Contract text.")

def _trigger_in_text(trigger: str, lower_text: str) -> bool:
    """Return True when a RULE_BOOK trigger occurs in lower-cased text.

    A trigger that starts with a letter or digit must start at a word boundary.
    So "round" matches "rounds" but not "ground"/"around", and "ale-" does not
    match "sale-". Triggers that start with punctuation or a space (" 5.56")
    keep plain substring matching.
    """
    if trigger[:1].isalnum():
        return re.search(r"(?<![a-z0-9])" + re.escape(trigger), lower_text) is not None
    return trigger in lower_text


def _rag_examples_for_prompt(paragraph: str):
    """Retrieve the top-3 KB examples and flatten them for the LLM prompt.

    If retrieval fails (e.g. the embedding model cannot be loaded), the error
    is printed and an empty list is returned, so classification can proceed
    without examples instead of aborting the whole pipeline.
    """
    try:
        rag_hits = retriever.retrieve(paragraph, top_k=3)
    except Exception as exc:
        print(f"Warning: RAG retrieval failed: {exc}")
        return []

    examples = []
    for hit in rag_hits:
        meta = hit["meta"]
        snippet = meta.get("Description of Contract", "")
        examples.append({
            "score": round(hit["score"], 4),
            "Market Segment": meta.get("Market Segment", ""),
            "System Type (General)": meta.get("System Type (General)", ""),
            "System Type (Specific)": meta.get("System Type (Specific)", ""),
            "System Name (General)": meta.get("System Name (General)", ""),
            "System Name (Specific)": meta.get("System Name (Specific)", ""),
            "System Piloting": meta.get("System Piloting", ""),
            "Supplier Name": meta.get("Supplier Name", ""),
            "Customer Operator": meta.get("Customer Operator", ""),
            # Non-string KB cells (e.g. NaN from Excel) cannot be sliced.
            "Snippet": (snippet if isinstance(snippet, str) else "")[:220] + "..."
        })
    return examples


def _empty_system_result(piloting_rule: str, error: str) -> dict:
    """Fallback classification used when the LLM call or its JSON fails."""
    return {
        "Market Segment": "",
        "System Type (General)": "",
        "System Type (Specific)": "",
        "System Name (General)": "",
        "System Name (Specific)": "",
        "System Piloting": piloting_rule,
        "Market Segment Evidence": "Not Found",
        "System Type (General) Evidence": "Not Found",
        "System Type (Specific) Evidence": "Not Found",
        "System Name (General) Evidence": "Not Found",
        "System Name (Specific) Evidence": "Not Found",
        "System Piloting Evidence": "Not Found",
        "Market Segment Reason": "",
        "System Type (General) Reason": "",
        "System Type (Specific) Reason": "",
        "System Name (General) Reason": "",
        "System Name (Specific) Reason": "",
        "System Piloting Reason": "Piloting derived using deterministic rules.",
        "Confidence": "Low",
        "Error": error
    }


@tool("system_classifier")
def system_classifier(paragraph: str):
    """Classify the defense system described in a contract paragraph.

    Combines RULE_BOOK trigger hints, regex-extracted designators, the top-3
    similar KB examples from the FAISS retriever, and one LLM call
    (gpt-4o-mini, JSON mode). Returns a flat dict of string values:
    value / Evidence / Reason for each system field, plus "Confidence".
    "System Piloting" is always the deterministic rule-based value.
    Returns {} for empty input. If the LLM call fails, returns blank fields,
    Confidence "Low" and an "Error" key.
    """
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    lower_text = paragraph.lower()

    hints = [
        f"RULE: {rule['guidance']}"
        for rule in RULE_BOOK.values()
        if any(_trigger_in_text(t, lower_text) for t in rule["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    designators = extract_designators(paragraph)
    piloting_rule = detect_piloting_rule_based(paragraph, designators)

    rag_examples = _rag_examples_for_prompt(paragraph)

    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

Use these inputs:
1) RAG Similar Examples (top 3)
2) Rule Book Overrides
3) Extracted Designators

RULE BOOK OVERRIDES:
{hint_str}

OUTPUT RULES:
- Return ONLY a FLAT JSON object.
- Every value must be a STRING.
- Do NOT return nested JSON or lists.
- Evidence must be copied EXACTLY from paragraph.
- If evidence not present, use "Not Found".

Return JSON exactly:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    user_prompt = f"""
INPUT PARAGRAPH:
{paragraph}

DESIGNATORS (regex extracted):
{designators if designators else "None"}

RULE-BASED PILOTING:
{piloting_rule}

RAG SIMILAR EXAMPLES:
{json.dumps(rag_examples, indent=2)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
    except Exception as e:
        return _empty_system_result(piloting_rule, str(e))

    if not isinstance(result, dict):
        return _empty_system_result(piloting_rule, "LLM did not return a JSON object")

    # Piloting override: deterministic rules take precedence over the LLM.
    result["System Piloting"] = piloting_rule
    if not result.get("System Piloting Reason"):
        result["System Piloting Reason"] = "Piloting derived using deterministic rules."
    if not result.get("System Piloting Evidence"):
        result["System Piloting Evidence"] = "Not Found"

    # Ensure flat string values (nested JSON / numbers stringified, null -> "").
    return {
        k: v if isinstance(v, str) else ("" if v is None else str(v))
        for k, v in result.items()
    }


# --- AGENT 4: CONTRACT (Supplier Match FIXED ✅✅✅) ---
class ContractInfoInput(BaseModel):
    """Input fields for the contract_extractor tool (paragraph, signing date).

    NOTE(review): not passed as ``args_schema`` to the @tool decorator.
    """
    paragraph: str = Field(description="Contract text.")
    contract_date: str = Field(description="Signed date.")

# Fallback supplier extraction: text before "... is [being] awarded",
# "... has been awarded" or "... was awarded" at the start of the paragraph.
_AWARD_SUPPLIER_RE = re.compile(
    r"^(.*?)\s+(?:is|has been|was)\s+(?:being\s+)?awarded",
    flags=re.IGNORECASE,
)


def _llm_field(raw: dict, key: str, default):
    """Return raw[key], or default when the key is missing, null or blank."""
    value = raw.get(key)
    if value is None or (isinstance(value, str) and not value.strip()):
        return default
    return value


def _format_value_million(raw_value) -> str:
    """Format an LLM-reported value (in millions) with 3 decimals.

    Thousands separators and surrounding text ("$", "USD", "million") are
    tolerated: the first number found is used. Returns "0.000" when the text
    contains no number.
    """
    match = re.search(r"-?\d+(?:\.\d+)?", str(raw_value).replace(",", ""))
    if not match:
        return "0.000"
    return "{:.3f}".format(float(match.group(0)))


@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """Extract supplier, program and financial fields from a contract paragraph.

    One LLM call (gpt-4o-mini, JSON mode) extracts raw fields per the SOP
    prompt. They are then post-processed:
      - Supplier Name: matched to SUPPLIER_LIST via get_best_supplier_match.
        The raw text and match score are also returned for debugging. If the
        LLM gives no supplier, the text before "... awarded" is used.
      - MRO duration via calculate_mro_months.
      - Value formatted with 3 decimals ("0.000" when no number is present).
      - Signing month/year parsed from contract_date ("Unknown" on failure).
    Missing or blank LLM fields fall back to the defaults shown below.
    Returns {"Error": ...} if the LLM call fails or returns non-object JSON.
    """
    system_instruction = """
You are a Defense Contract Financial Analyst. Extract data strictly following these SOP rules:

1) Supplier Name: Extract exact company name as written in paragraph.
2) Program Type: Procurement / Training / MRO/Support / RDT&E / Upgrade / Other Service
3) Value Certainty:
   - Confirmed: fixed price/obligated stated
   - Estimated: ceiling/potential/IDIQ/multi-award
4) Quantity: number or Not Applicable
5) G2G/B2G:
   - G2G only if Foreign Military Sales (FMS)
   - else B2G
6) Completion Date: only needed for MRO duration calc

Return JSON only.
"""

    user_prompt = f"""
Analyze contract:
"{paragraph}"

Signed Date: {contract_date}

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
    except Exception as e:
        return {"Error": str(e)}
    if not isinstance(raw, dict):
        return {"Error": "LLM did not return a JSON object"}

    # `or ""` keeps a null supplier from becoming the literal string "None".
    raw_supplier = str(raw.get("raw_supplier_name") or "").strip()

    # Fallback supplier extraction if the LLM returns a blank name.
    if not raw_supplier:
        m = _AWARD_SUPPLIER_RE.search(paragraph)
        if m:
            raw_supplier = m.group(1).strip()

    # Final supplier must come from the suppliers.json best match.
    final_supplier, supplier_match_score = get_best_supplier_match(raw_supplier)

    prog_type = _llm_field(raw, "program_type", "Unknown")
    mro_months = calculate_mro_months(contract_date, raw.get("completion_date_text"), prog_type)

    val_formatted = _format_value_million(raw.get("value_million_raw", "0"))

    # NOTE(review): parsed without dayfirst, while calculate_mro_months uses
    # dayfirst=True; an ambiguous date such as "03/04/2024" is read differently.
    try:
        dt = pd.to_datetime(contract_date)
        sign_month = dt.strftime("%B")
        sign_year = str(dt.year)
    except (ValueError, TypeError, OverflowError):
        # includes NaT (strftime raises ValueError) and unparseable strings
        sign_month, sign_year = "Unknown", "Unknown"

    val_note = _llm_field(raw, "value_note", "Not Applicable")
    if "split" in paragraph.lower() and val_note == "Not Applicable":
        val_note = "Split contract; value distribution unclear."

    return {
        # Final output supplier, standardized against suppliers.json
        "Supplier Name": final_supplier,

        # Debug columns: raw LLM/fallback supplier text and its match score
        "Supplier Name Raw (LLM)": raw_supplier,
        "Supplier Match Score": round(float(supplier_match_score), 3),

        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": _llm_field(raw, "quantity", "Not Applicable"),
        "Value Certainty": _llm_field(raw, "value_certainty", "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": _llm_field(raw, "currency_code", "USD$"),
        # NOTE(review): no FX conversion is applied; for non-USD currencies this
        # repeats the native-currency value.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": val_note,
        "G2G/B2G": _llm_field(raw, "g2g_b2g", "B2G"),
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


# ==============================================================================
# ✅ 8. LANGGRAPH PIPELINE
# ==============================================================================

class AgentState(TypedDict):
    """LangGraph state shared by the pipeline stages while processing one input row."""

    input_text: str  # contract paragraph (Excel "Contract Description"; "" if blank)
    input_date: str  # contract date string (Excel "Contract Date"; today's date if blank)
    input_url: str  # source URL (Excel "Source URL"; "" if blank)
    final_data: dict  # output columns accumulated by the stages
    messages: Annotated[List[AnyMessage], add_messages]  # merged via add_messages; not written by the visible stages

def _merge_final_data(state: AgentState, res: dict) -> dict:
    """Return a state update: a copy of state["final_data"] updated with res.

    The previous dict is copied, not mutated. On key clashes the new stage's
    values win. A missing or None final_data starts from an empty dict.
    """
    merged = dict(state.get("final_data") or {})
    merged.update(res)
    return {"final_data": merged}


def stage_1_sourcing(state: AgentState):
    """Stage 1: description, source link, contract/reported dates, internal notes."""
    res = sourcing_extractor.invoke({
        "paragraph": state["input_text"],
        "url": state["input_url"],
        "date": state["input_date"]
    })
    return _merge_final_data(state, res)


def stage_2_geography(state: AgentState):
    """Stage 2: customer/supplier country, region, operator and domestic content."""
    res = geography_extractor.invoke({"paragraph": state["input_text"]})
    return _merge_final_data(state, res)


def stage_3_system(state: AgentState):
    """Stage 3: system classification (market segment, system type/name, piloting)."""
    res = system_classifier.invoke({"paragraph": state["input_text"]})
    return _merge_final_data(state, res)


def stage_4_contract(state: AgentState):
    """Stage 4: supplier, program type, value, quantity and signing-date fields."""
    res = contract_extractor.invoke({
        "paragraph": state["input_text"],
        "contract_date": state["input_date"]
    })
    return _merge_final_data(state, res)

# Linear pipeline: START -> Stage1 (sourcing) -> Stage2 (geography)
# -> Stage3 (system) -> Stage4 (contract) -> END.
# Each stage returns an updated copy of final_data for the row.
workflow = StateGraph(AgentState)
workflow.add_node("Stage1", stage_1_sourcing)
workflow.add_node("Stage2", stage_2_geography)
workflow.add_node("Stage3", stage_3_system)
workflow.add_node("Stage4", stage_4_contract)

# Sequential edges: stages run one after another, never in parallel.
workflow.add_edge(START, "Stage1")
workflow.add_edge("Stage1", "Stage2")
workflow.add_edge("Stage2", "Stage3")
workflow.add_edge("Stage3", "Stage4")
workflow.add_edge("Stage4", END)

# Compiled runnable; invoked once per Excel row in the execution section.
app = workflow.compile()


# ==============================================================================
# ✅ 9. EXECUTION & OUTPUT FORMAT
# ==============================================================================

if __name__ == "__main__":
    # Batch runner: read the input Excel, run the 4-stage graph per row and
    # write one output row per input row to OUTPUT_EXCEL_PATH.
    import traceback

    print(f"üìÇ Loading Input File: {INPUT_EXCEL_PATH}...")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        missing_cols = [col for col in required_cols if col not in df_input.columns]
        if missing_cols:
            raise ValueError(
                f"Excel file must contain columns: {required_cols} (missing: {missing_cols})"
            )

        print(f"üöÄ Processing {len(df_input)} rows...")
        results = []

        for index, row in df_input.iterrows():
            print(f"   -> Row {index + 1}...")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "messages": []
            }

            # Isolate failures per row so one bad row does not abort the run
            # and discard every row processed so far.
            try:
                output_state = app.invoke(initial_state)
                results.append(output_state["final_data"])
            except Exception as row_err:
                print(f"   !! Row {index + 1} failed: {row_err}")
                traceback.print_exc()
                results.append({
                    "Description of Contract": desc,
                    "Additional Notes (Internal Only)": f"Processing error: {row_err}",
                    "Source Link(s)": c_url,
                    "Contract Date": c_date,
                })

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            # Supplier debug columns (optional but recommended)
            "Supplier Name Raw (LLM)",
            "Supplier Match Score",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # Fixed column order; columns no stage produced are filled with "".
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")
        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        print("\n‚úÖ Processing Complete!")
        print(f"üíæ File saved to: {OUTPUT_EXCEL_PATH}")
        print(df_final.head().to_string())

    except Exception as e:
        print(f"\n‚ùå Error: {e}")
        traceback.print_exc()


In [None]:
import os
import re
import json
import difflib
import pickle
import datetime
from typing import Annotated, TypedDict, List, Dict, Any

import pandas as pd
import faiss
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass

# LangGraph / LangChain
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI

# Excel formatting
from openpyxl import load_workbook
from openpyxl.styles import PatternFill, Font


# ==============================================================================
# 0) DEBUG LOGGING HELPERS
# ==============================================================================

def log_block(title: str, content: str):
    """
    Print ``content`` beneath a banner: a blank line, a 100-character "="
    rule, the title, and another rule.

    Intended for eyeballing per-row LLM extraction results (e.g. system
    classification or split logic) in the console.
    """
    rule = "=" * 100
    print(f"\n{rule}\n{title}\n{rule}")
    print(content)


# ==============================================================================
# 1) RAG RETRIEVER (FAISS + Metadata)
# ==============================================================================

class SystemKBRetriever:
    """
    RAG retriever over the defense system-classification knowledge base.

    Loads a FAISS index (``system_kb.faiss``) and its per-row metadata
    (``system_kb_meta.pkl``) from ``kb_dir``. ``retrieve`` returns the top-k
    KB rows most similar to a query paragraph, so the classifier sees labelled
    historical examples.

    retrieve(query_text) returns:
      [
        {"score": float, "meta": {...KB row...}},
        ...
      ]

    NOTE(review): what ``score`` means (similarity vs. distance) depends on the
    FAISS index type chosen by the KB builder. Queries are embedded with
    ``normalize_embeddings=True``, which assumes the KB was built the same way.
    """

    def __init__(self, kb_dir: str, embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.kb_dir = kb_dir
        self.embed_model = embed_model

        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build the KB first using your KB builder script."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        # SECURITY: pickle.load can execute arbitrary code from the file; only
        # load metadata produced by your own trusted KB builder.
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded rows: {len(self.meta)}")

        # SentenceTransformer instance, created on first retrieve() call.
        self.embedder = None

    def _lazy_load_embedder(self):
        """
        Load the SentenceTransformer model on first use.

        Keeps start-up fast and avoids holding the model in memory until the
        first retrieval.
        """
        if self.embedder is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer(self.embed_model)

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Return up to ``top_k`` KB rows most similar to ``query_text``.

        Parameters:
        - query_text: input paragraph (coerced with str and stripped)
        - top_k: number of examples; values <= 0 return []

        Returns:
        - List of {"score": float, "meta": row}. Empty for a blank query.
        """
        query_text = str(query_text).strip()
        if not query_text or top_k <= 0:
            return []

        self._lazy_load_embedder()

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            idx = int(idx)
            # FAISS pads with -1 when the index holds fewer than top_k vectors;
            # also skip ids with no metadata row (index and metadata out of sync).
            if idx < 0 or idx >= len(self.meta):
                continue
            results.append({"score": float(score), "meta": self.meta[idx]})

        return results


# ==============================================================================
# 2) CONFIGURATION & FILE PATHS
# ==============================================================================

# Machine-specific absolute Windows paths; adjust for your environment.
TAXONOMY_PATH = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\taxonomy.json"
SUPPLIERS_PATH = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\suppliers.json"
INPUT_EXCEL_PATH = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\source_file.xlsx"
# Relative path: the output workbook is written to the current working directory.
OUTPUT_EXCEL_PATH = "Processed_Defense_Data.xlsx"

# Directory expected to contain system_kb.faiss and system_kb_meta.pkl.
RAG_KB_DIR = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\system_kb_store"


# Setup API key once: prompt interactively only when LLMFOUNDRY_TOKEN is not already set.
if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

# OpenAI-compatible client routed through the LLM Foundry proxy. The API key is
# sent as "<token>:my-test-project" (presumably a project tag; confirm with the proxy docs).
client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
    base_url="https://llmfoundry.straive.com/openai/v1/",
)

# Loads the FAISS index and metadata at import time; raises FileNotFoundError
# if the KB files are missing. The embedding model itself is loaded lazily.
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)


# ==============================================================================
# 3) LOAD JSON HELPERS
# ==============================================================================

def load_json_file(filename, default_value):
    """
    Load and return the parsed contents of a UTF-8 JSON file.

    Falls back to ``default_value`` with a warning when the file cannot be
    opened or decoded, so a missing taxonomy/supplier file does not crash the
    pipeline. The success message is printed only after parsing succeeds;
    previously it was printed before json.load, so a failing load logged both.

    Args:
        filename: path to the JSON file.
        default_value: returned unchanged when loading fails.
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError covers JSONDecodeError
        # and UnicodeDecodeError.
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value
    print(f"‚úÖ Loaded: {filename}")
    return data


# Taxonomy dict ({} when taxonomy.json cannot be loaded) and its compact JSON
# string form (no spaces after separators).
raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(",", ":"))

# Canonical supplier names from suppliers.json (presumably a list of name
# strings; TODO confirm the file schema). The short list below is only the
# fallback used when the file cannot be loaded.
SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, [
    "Dell Inc", "Boeing", "Lockheed Martin", "Raytheon Technologies",
    "Northrop Grumman", "L3Harris", "BAE Systems", "General Dynamics"
])


# ==============================================================================
# 4) RULE BOOK + GEOGRAPHY
# ==============================================================================

# Keyword rules that add classification hints to the system-classification
# prompt. Triggers are lower-case strings meant to be found in lower-cased
# contract text. A leading space (" round", " 5.56") stops a trigger from
# firing inside longer words such as "ground", "around" or "background" when
# it is matched as a plain substring.
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'"
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'"
    },
    "ammunition": {
        "triggers": ["cartridge", " round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'"
    }
}

# Region -> country / organisation names. get_region_for_country scans the
# regions in insertion order and compares names case-insensitively.
GEOGRAPHY_MAPPING = {
    "North America": [
        "USA", "United States", "US", "United States of America", "Canada", "America",
    ],
    "Europe": [
        "UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain",
        "Poland", "Netherlands", "Norway", "Sweden", "Finland", "Denmark", "Belgium",
    ],
    "Asia-Pacific": [
        "Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand",
    ],
    "Middle East and North Africa": [
        "Israel", "Saudi Arabia", "UAE", "United Arab Emirates",
        "Egypt", "Qatar", "Kuwait", "Iraq",
    ],
    "International Organisations": ["NATO", "EU", "IFU", "UN", "NSPA"],
}


# ==============================================================================
# 5) BASE HELPERS (Supplier + Dates + Region + Designators)
# ==============================================================================

def get_best_supplier_match(extracted_name: str, suppliers=None):
    """
    Supplier standardization function.

    Steps:
    1) Exact, case-insensitive match against the supplier list.
    2) Case-insensitive fuzzy match (difflib, cutoff 0.6); the canonical
       spelling from the list is returned.
    3) If no match, return the stripped extracted text unchanged.

    Args:
        extracted_name: supplier name as extracted from the contract text.
        suppliers: canonical names to match against; defaults to SUPPLIER_LIST
            (loaded from suppliers.json).

    Returns:
        "Unknown" for empty/blank/placeholder input, otherwise a canonical or
        raw supplier name.
    """
    if not extracted_name or str(extracted_name).strip().lower() in ["", "unknown", "n/a", "not applicable"]:
        return "Unknown"

    if suppliers is None:
        suppliers = SUPPLIER_LIST
    clean_name = str(extracted_name).strip()

    # lower-cased name -> canonical spelling; non-string entries are ignored
    supplier_map = {s.lower(): s for s in suppliers if isinstance(s, str)}
    key = clean_name.lower()
    if key in supplier_map:
        return supplier_map[key]

    # difflib is case-sensitive, so compare lower-cased names: otherwise
    # "LOCKHEED MARTIN CORP." could never reach "Lockheed Martin".
    matches = difflib.get_close_matches(key, list(supplier_map), n=1, cutoff=0.6)
    return supplier_map[matches[0]] if matches else clean_name


def calculate_mro_months(start_date_str, end_date_text, program_type):
    """
    Calculates MRO duration in whole months, returned as a string.

    Rule:
    - ONLY valid if program_type == "MRO/Support"; otherwise "Not Applicable".
    - Missing, placeholder ("unknown", "n/a", "not applicable") or unparseable
      dates also give "Not Applicable". Negative durations clamp to "0".

    The start date is parsed by pandas with dayfirst=True; the end date is
    free text parsed fuzzily by dateutil.
    """
    if program_type != "MRO/Support":
        return "Not Applicable"

    placeholders = ("", "unknown", "n/a", "not applicable")
    if (
        not start_date_str
        or not end_date_text
        or str(end_date_text).strip().lower() in placeholders
    ):
        return "Not Applicable"

    try:
        start = pd.to_datetime(start_date_str, dayfirst=True)
        end = parser.parse(str(end_date_text), fuzzy=True)
        diff = relativedelta(end, start)
    except (ValueError, TypeError, OverflowError):
        # pandas/dateutil parse errors are ValueError subclasses; TypeError
        # covers NaT or mixing tz-aware and naive datetimes.
        return "Not Applicable"

    total_months = diff.years * 12 + diff.months
    return str(max(0, int(total_months)))


def get_region_for_country(country_name):
    """
    Return the region label for a country name, or "Unknown".

    The comparison ignores case and surrounding whitespace. Common US/UK
    spellings (including dotted forms) are resolved first; then
    GEOGRAPHY_MAPPING is searched region by region, in insertion order.
    """
    if not country_name:
        return "Unknown"
    key = str(country_name).strip().lower()
    if key in ("unknown", "n/a", "not applicable"):
        return "Unknown"

    us_aliases = {"us", "usa", "u.s.", "united states", "united states of america"}
    uk_aliases = {"uk", "u.k.", "britain", "great britain"}
    if key in us_aliases:
        return "North America"
    if key in uk_aliases:
        return "Europe"

    return next(
        (
            region
            for region, countries in GEOGRAPHY_MAPPING.items()
            if any(name.lower() == key for name in countries)
        ),
        "Unknown",
    )


# Regexes for common defense designators: ship hull numbers, aircraft / UAV
# designations, AN/ electronics nomenclature and missiles. They are applied
# with re.IGNORECASE. Groups are non-capturing so re.findall returns the whole
# designator (a capturing group made it return only "AIM"/"SM"/...).
# The optional "(?-i:[A-Z])" suffix accepts an upper-case variant letter
# (F-35A, MQ-9B, AIM-9X) without swallowing a lower-case plural "s".
DESIGNATOR_PATTERNS = [
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    r"\bF-\d+(?-i:[A-Z])?\b", r"\bB-\d+(?-i:[A-Z])?\b", r"\bC-\d+(?-i:[A-Z])?\b", r"\bA-\d+(?-i:[A-Z])?\b",
    r"\bMQ-\d+(?-i:[A-Z])?\b", r"\bRQ-\d+(?-i:[A-Z])?\b",
    r"\bAN\/[A-Z0-9\-]+\b",
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+(?-i:[A-Z])?\b",
]

def extract_designators(text: str, patterns=None):
    """
    Extracts common defense platform designators from a paragraph.

    Examples (exact coverage depends on the patterns used):
    - DDG-51, CVN-78
    - MQ-9, RQ-4
    - AIM-120
    - AN/APY-10

    Args:
        text: paragraph to scan; coerced with str.
        patterns: optional list of regexes; defaults to DESIGNATOR_PATTERNS.
            Each pattern is applied case-insensitively.

    Returns:
        Upper-cased designators without duplicates, in pattern order and then
        order of appearance. Runs of whitespace/hyphens become one "-", so
        "DDG 51" and "DDG-51" are reported once as "DDG-51".
    """
    if patterns is None:
        patterns = DESIGNATOR_PATTERNS
    text = str(text)

    found = []
    for pat in patterns:
        # finditer + group(0) yields the whole match even when a pattern has a
        # capturing group (re.findall would return only the group's text).
        for m in re.finditer(pat, text, flags=re.IGNORECASE):
            found.append(re.sub(r"[\s\-]+", "-", m.group(0).upper()))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(found))


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """
    Deterministic piloting classification to reduce model errors.

    Checks in order:
    - MQ-/RQ- designators, or uncrewed keywords (unmanned, uncrewed, UAV(s),
      UAS, USV(s), UUV(s), drone(s), autonomous/autonomously) -> "Uncrewed"
    - Ship hull designators (DDG, CVN, SSN, LCS, LPD, LHA, LHD) or the word
      "USS" -> "Crewed"
    - otherwise -> "Not Applicable"

    Keywords are matched as whole words. Plain substring checks misfired, for
    example "uss " inside "discuss " or "uas" inside "quasi".
    """
    t = str(text).lower()

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if re.search(r"\b(?:unmanned|uncrewed|uavs?|uas|usvs?|uuvs?|drones?|autonomous(?:ly)?)\b", t):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    if re.search(r"\buss\b", t):
        return "Crewed"

    return "Not Applicable"


# ==============================================================================
# 6) ENHANCED SPLIT ENGINE (Multi-operator / Multi-country / Multi-supplier / Multi-value)
# ==============================================================================

def parse_operator_quantity_allocations(paragraph: str):
    """
    Detect quantity allocations by operator.

    Example patterns:
      - "212 for the Navy"
      - "1,200 for the Air Force"
      - "84 for Foreign Military Sales (FMS) customers"

    Quantities are returned as digit strings with thousands separators removed.

    Output:
      [
        {"operator": "Navy", "quantity": "212", "g2g_b2g": "B2G"},
        {"operator": "Foreign Assistance", "quantity": "84", "g2g_b2g": "G2G"}
      ]
    """
    text = str(paragraph)

    # Quantity token: a comma-grouped number ("1,200") or plain digits. The
    # lookbehind rejects matches that begin mid-token, e.g. the "35" of
    # "F-35 for the Navy" or the "200" of "1,200".
    qty = r"(?<![\w.,-])(\d{1,3}(?:,\d{3})+|\d+)"

    allocations = []

    pattern = qty + r"\s+for\s+the\s+(Navy|Air Force|Army|Marine Corps)"
    for q, op in re.findall(pattern, text, flags=re.IGNORECASE):
        allocations.append({"operator": op.title(), "quantity": q.replace(",", ""), "g2g_b2g": "B2G"})

    fms_pattern = qty + r"\s+for\s+(?:Foreign Military Sales\s*\(FMS\)\s*customers|FMS\s*customers|a\s*FMS\s*customer|FMS)"
    for q in re.findall(fms_pattern, text, flags=re.IGNORECASE):
        allocations.append({"operator": "Foreign Assistance", "quantity": q.replace(",", ""), "g2g_b2g": "G2G"})

    # De-duplicate on (operator, quantity, g2g_b2g), keeping first occurrence.
    unique = {}
    for a in allocations:
        unique.setdefault((a["operator"], a["quantity"], a["g2g_b2g"]), a)
    return list(unique.values())


def parse_fms_countries(paragraph: str):
    """
    Extract the FMS customer country list.

    Looks for:
      'governments of Australia, Bahrain, Belgium...'

    The list ends at a sentence-ending period, a known follow-on phrase
    (" Work will be performed", " Fiscal", " This contract") or end of text.
    A leading "the" and trailing punctuation are removed from each name.
    Names outside 3-40 characters are dropped. Names are de-duplicated
    case-insensitively.

    Output: ["Australia", "Bahrain", ...]
    """
    text = str(paragraph)

    m = re.search(
        r"governments of (.+?)(?:\.(?:\s|$)| Work will be performed| Fiscal| This contract|$)",
        text,
        flags=re.IGNORECASE | re.DOTALL
    )
    if not m:
        return []

    final = []
    seen = set()
    for part in re.split(r",|\band\b", m.group(1)):
        # "the Netherlands" -> "Netherlands"; strip stray sentence punctuation.
        name = re.sub(r"^the\s+", "", part.strip(), flags=re.IGNORECASE).strip(" .;:")
        key = name.lower()
        if 2 < len(name) <= 40 and key not in seen:
            final.append(name)
            seen.add(key)

    return final


def parse_multiple_suppliers(paragraph: str, supplier_list=None):
    """
    Detect multi-supplier contract statements.

    Examples:
    - "Lockheed Martin and Raytheon were awarded..."
    - "Boeing, Northrop Grumman, and General Dynamics..."

    Args:
        paragraph: Contract paragraph.
        supplier_list: Known supplier names to look for. Defaults to the
            module-level SUPPLIER_LIST.

    Returns:
        Matched supplier names (in supplier_list order) when at least two
        distinct suppliers are found, otherwise [].
    """
    text = str(paragraph)
    lower = text.lower()

    if " and " not in lower and "," not in lower:
        return []

    if supplier_list is None:
        supplier_list = SUPPLIER_LIST

    candidates = []
    for supplier in supplier_list:
        needle = str(supplier).strip().lower()
        if not needle:
            continue
        # Whole-word match: a bare substring test lets short names (e.g. "GE")
        # hit inside unrelated words such as "general".
        if re.search(r"(?<!\w)" + re.escape(needle) + r"(?!\w)", lower):
            candidates.append(supplier)
    candidates = list(dict.fromkeys(candidates))

    # Drop names that are just a fragment of a longer matched name
    # ("Raytheon" inside "Raytheon Technologies") so one company is not
    # counted as two suppliers.
    lowered = [str(c).strip().lower() for c in candidates]
    candidates = [
        c for c, lc in zip(candidates, lowered)
        if not any(lc != other and lc in other for other in lowered)
    ]

    # Two or more suppliers -> split required.
    return candidates if len(candidates) >= 2 else []


def parse_multiple_values(paragraph: str):
    """
    Detect dollar amounts in a paragraph.

    Examples:
    - "$328,156,454"
    - "ceiling value of $500 million"  -> "500"
    - "$1.2 billion"                   -> "1.2"

    Returns:
        The numeric text after each "$" as raw strings (thousands separators
        kept), de-duplicated in first-seen order. A comma that merely ends the
        clause ("$328,156,454, which ...") is not part of the amount.
    """
    text = str(paragraph)

    # Either a properly comma-grouped amount or plain digits, each with an
    # optional decimal part.
    money_pattern = r"\$(\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+(?:\.\d+)?)"
    return list(dict.fromkeys(re.findall(money_pattern, text)))


def parse_share_percentages(paragraph: str, supplier_matcher=None):
    """
    Detect share splits like:
      '60% for Company A and 40% for Company B'

    The supplier name is a run of capitalised words following the percentage
    (optionally after "for"). It therefore stops at lower-case joiners such
    as "and", and does not swallow the next percentage.

    Args:
        paragraph: Contract paragraph.
        supplier_matcher: Callable mapping a raw name to a standardised
            supplier name. Defaults to get_best_supplier_match.

    Output (only when two or more shares are found, else []):
      [
        {"supplier": "Company A", "percentage": "60"},
        {"supplier": "Company B", "percentage": "40"}
      ]
    """
    text = str(paragraph)
    if supplier_matcher is None:
        supplier_matcher = get_best_supplier_match

    percent_pattern = (
        r"(?<![\d.])(\d+(?:\.\d+)?)\s?%\s+(?:for\s+)?"
        r"([A-Z][A-Za-z0-9&.\-]*(?:\s+[A-Z][A-Za-z0-9&.\-]*)*)"
    )
    result = [
        {"supplier": supplier_matcher(name.strip()), "percentage": pct}
        for pct, name in re.findall(percent_pattern, text)
    ]

    # Keep only genuine multi-way splits.
    return result if len(result) >= 2 else []


def detect_multi_country_presence(paragraph: str, mapping=None):
    """
    Lightweight country scan to detect mentions of several countries.

    Args:
        paragraph: Contract paragraph.
        mapping: {region: [country, ...]}. Defaults to GEOGRAPHY_MAPPING.

    Returns:
        Sorted list of country names from the mapping that occur in the text
        as whole words (case-insensitive). Whole-word matching prevents hits
        such as "India" in "Indiana", "Oman" in "woman", "Niger" in "Nigeria".
    """
    text = str(paragraph).lower()
    if mapping is None:
        mapping = GEOGRAPHY_MAPPING

    countries_found = set()
    for countries in mapping.values():
        for c in countries:
            needle = str(c).strip().lower()
            if needle and re.search(r"(?<!\w)" + re.escape(needle) + r"(?!\w)", text):
                countries_found.add(c)

    # Sorted for deterministic output (set iteration order is not).
    return sorted(countries_found, key=str)


def _fan_out_rows(rows, items, apply_item, reason):
    """
    Return one shallow copy of every row in *rows* for every entry in *items*.

    Each copy is passed to apply_item(copy, item) to set its split-driving
    columns, then stamped with Split Flag "Yes" and the given Split Reason.
    Output order is row-major: all items for rows[0], then rows[1], ...
    """
    out = []
    for r in rows:
        for item in items:
            rr = r.copy()
            apply_item(rr, item)
            rr["Split Flag"] = "Yes"
            rr["Split Reason"] = reason
            out.append(rr)
    return out


def split_rows_engine(base_row: dict, paragraph: str):
    """
    MASTER SPLIT ENGINE.

    Goal:
    - Takes one extracted row (base_row). It is copied, never modified.
    - Produces 1..N output rows depending on split conditions.

    Split triggers supported:
    1) Multi supplier mentions (skipped when percentage shares exist, because
       the share split also assigns Supplier Name)
    2) Operator allocation splits (Navy/Air Force/FMS)
    3) Percent share splits
    4) Multi values inside same paragraph (annotates Value Note only)
    5) Multi-country FMS customer lists (applied to G2G rows only)
    6) Multi-country region scan (recorded in Split Reason only)

    Important:
    - Only split-driving columns change between output rows.
    - Shared columns remain consistent.
    - Any row not explicitly stamped gets Split Flag "Yes" and the combined
      reason whenever at least one trigger fired.
    """
    paragraph = str(paragraph)
    # Private copy: the caller's dict (e.g. LangGraph final_data) must not be mutated.
    base_row = dict(base_row or {})

    allocations = parse_operator_quantity_allocations(paragraph)
    fms_countries = parse_fms_countries(paragraph)
    multi_suppliers = parse_multiple_suppliers(paragraph)
    multi_values = parse_multiple_values(paragraph)
    share_splits = parse_share_percentages(paragraph)
    multi_countries = detect_multi_country_presence(paragraph)

    split_reasons = []
    if allocations:
        split_reasons.append("Multi-operator allocation found")
    if fms_countries:
        split_reasons.append("FMS multi-country list found")
    if multi_suppliers:
        split_reasons.append("Multi-supplier mention found")
    if len(multi_values) >= 2:
        split_reasons.append("Multiple financial values found")
    if share_splits:
        split_reasons.append("Percentage share split found")
    if len(multi_countries) >= 2:
        split_reasons.append("Multiple countries detected in text")

    if not split_reasons:
        base_row["Split Flag"] = "No"
        base_row["Split Reason"] = "No split condition found"
        return [base_row]

    rows = [base_row]
    base_reason = " | ".join(split_reasons)

    # 1) Multi supplier split. When percentage shares were found, the share
    #    split below assigns Supplier Name per share. Running both would create
    #    suppliers x shares rows with duplicate or mismatched suppliers.
    if multi_suppliers and not share_splits:
        rows = _fan_out_rows(
            rows, multi_suppliers,
            lambda rr, s: rr.update({"Supplier Name": s}),
            f"{base_reason} (Supplier split)",
        )

    # 2) Operator / quantity split
    if allocations:
        rows = _fan_out_rows(
            rows, allocations,
            lambda rr, a: rr.update({
                "Customer Operator": a["operator"],
                "Quantity": a["quantity"],
                "G2G/B2G": a["g2g_b2g"],
            }),
            f"{base_reason} (Operator/Quantity split)",
        )

    # 3) Share split -> supplier + Value Note
    if share_splits:
        rows = _fan_out_rows(
            rows, share_splits,
            lambda rr, sh: rr.update({
                "Supplier Name": sh["supplier"],
                "Value Note (If Any)": f"Share split detected: {sh['percentage']}%",
            }),
            f"{base_reason} (Share % split)",
        )

    # 4) Multiple financial values: Value (Million) comes from contract_extractor
    #    and is NOT overridden; the detected amounts are recorded as a note only.
    if len(multi_values) >= 2:
        extra = f"Multiple values detected: {multi_values[:5]}"
        for r in rows:
            note = str(r.get("Value Note (If Any)") or "").strip()
            r["Value Note (If Any)"] = extra if note in ("", "Not Applicable") else f"{note} | {extra}"

    # 5) FMS country split, applied ONLY to G2G rows
    if fms_countries:
        expanded = []
        for r in rows:
            if r.get("G2G/B2G") == "G2G":
                expanded.extend(_fan_out_rows(
                    [r], fms_countries,
                    lambda rr, c: rr.update({
                        "Customer Country": c,
                        "Customer Region": get_region_for_country(c),
                    }),
                    f"{base_reason} (FMS country split)",
                ))
            else:
                expanded.append(r)
        rows = expanded

    # Always ensure flags exist
    for r in rows:
        r.setdefault("Split Flag", "Yes")
        r.setdefault("Split Reason", base_reason)

    return rows


# ==============================================================================
# 7) AGENTS / TOOLS
# ==============================================================================

# --- Stage 1: Sourcing ---
# Input schema describing sourcing_extractor's arguments (not passed to @tool below).
class SourcingInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    url: str = Field(description="Source URL of the contract announcement/news.")
    date: str = Field(description="Contract date in Excel (string).")

@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """
    Stage 1: SOURCING EXTRACTOR

    Goal:
    - Build the base skeleton row containing:
      - Description of Contract (raw paragraph)
      - Source Link(s)
      - Contract Date
      - Reported Date (By SGA) - today's local date, YYYY-MM-DD
      - Additional Notes (Internal Only) - every matching keyword note
        ("modification", "split") or "Standard extraction."

    Why needed:
    - These columns remain SAME even if contract splits into multiple rows.
    """
    text = str(paragraph).lower()
    reported_date = datetime.datetime.now().strftime("%Y-%m-%d")

    # Collect all matching notes so a "split" hit does not hide a modification.
    notes = []
    if "modification" in text:
        notes.append("Contract Modification.")
    if "split" in text:
        notes.append("Split award detected.")

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": " ".join(notes) if notes else "Standard extraction.",
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": reported_date
    }


# --- Stage 2: Geography ---
# Input schema describing geography_extractor's arguments (not passed to @tool below).
class GeographyInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")


def _as_known_value(value, default="Unknown"):
    """Return *value* as a stripped string, or *default* for None/blank/null-like model output."""
    if value is None:
        return default
    text = str(value).strip()
    if not text or text.lower() in ("none", "null", "n/a", "unknown"):
        return default
    return text


@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """
    Stage 2: GEOGRAPHY EXTRACTOR

    Goal:
    - Identify (via the LLM):
      - Customer Country
      - Customer Operator
      - Supplier Country
    - Derive:
      - Customer Region
      - Supplier Region
      - Domestic Content: "Indigenous" when customer and supplier countries
        match (case-insensitive), "Imported" when they differ, "Unknown" when
        either country could not be determined.

    Missing, blank or null fields fall back to "Unknown". LLM failures are
    logged and treated as an empty response.

    Why needed:
    - Geography can be split-driving when multiple customer countries exist.
    """
    sys_prompt = """
Extract: Customer Country, Supplier Country, Customer Operator.
Logic: If 'Navy awarded...', operator is Navy.
Return JSON:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}
"""
    log_block("HUMAN MESSAGE (Stage2 - Geography)", paragraph)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": sys_prompt},
                      {"role": "user", "content": paragraph}],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage2 - Geography)", json.dumps(raw, indent=2))
    except Exception as e:
        raw = {}
        log_block("AI ERROR (Stage2 - Geography)", str(e))

    if not isinstance(raw, dict):
        raw = {}

    cust = _as_known_value(raw.get("Customer Country"))
    supp = _as_known_value(raw.get("Supplier Country"))

    # Two unknown countries are not evidence of domestic content.
    if cust == "Unknown" or supp == "Unknown":
        domestic = "Unknown"
    else:
        # NOTE(review): exact name comparison; aliases ("USA" vs "United States") count as Imported.
        domestic = "Indigenous" if cust.lower() == supp.lower() else "Imported"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": _as_known_value(raw.get("Customer Operator")),
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": domestic
    }


# --- Stage 3: System Classification (RAG + Evidence + Reason) ---
# Input schema describing system_classifier's arguments (not passed to @tool below).
class SystemInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")


# Label columns copied from retrieved KB rows into the prompt, in prompt order.
_SYSTEM_LABEL_COLUMNS = (
    "Market Segment",
    "System Type (General)",
    "System Type (Specific)",
    "System Name (General)",
    "System Name (Specific)",
    "System Piloting",
)


def _kb_text(value) -> str:
    """Return a KB metadata cell as text; None/NaN (empty Excel cells) become ""."""
    if value is None:
        return ""
    try:
        if pd.isna(value):
            return ""
    except (TypeError, ValueError):
        # pd.isna on list-like values returns an array; treat those as real values.
        pass
    return str(value)


def _build_rag_examples(paragraph: str, top_k: int = 3) -> list:
    """
    Retrieve similar labeled KB rows and flatten them into JSON-safe dicts.

    The score is converted to a Python float (numpy scalars are not JSON
    serialisable). Retrieval errors are logged and yield [] so a KB/embedding
    problem degrades the prompt instead of aborting the whole batch.
    """
    try:
        hits = retriever.retrieve(paragraph, top_k=top_k)
    except Exception as e:
        log_block("RAG ERROR (Stage3 - System)", str(e))
        return []

    examples = []
    for hit in hits or []:
        meta = hit.get("meta") or {}
        example = {"score": round(float(hit.get("score", 0.0)), 4)}
        for col in _SYSTEM_LABEL_COLUMNS:
            example[col] = _kb_text(meta.get(col))
        example["Snippet"] = _kb_text(meta.get("Description of Contract"))[:220] + "..."
        examples.append(example)
    return examples


def _apply_piloting_rule(result: dict, piloting_rule: str) -> None:
    """
    Merge the deterministic piloting verdict into the model's result (in place).

    A positive rule verdict ("Crewed"/"Uncrewed") overrides the model, and the
    reason is rewritten so it does not contradict the new value. The rule's
    "Not Applicable" only means no rule fired, so it is used as a fallback
    when the model left the field empty rather than as an override.
    """
    model_value = str(result.get("System Piloting") or "").strip()
    if piloting_rule in ("Crewed", "Uncrewed") or not model_value:
        result["System Piloting"] = piloting_rule
        result["System Piloting Reason"] = "Derived from deterministic piloting rules."
    result.setdefault("System Piloting Evidence", "Not Found")
    result.setdefault("System Piloting Reason", "Derived from deterministic piloting rules.")


def _empty_system_result(piloting_rule: str, error: str) -> dict:
    """Fallback Stage 3 row used when the LLM call or its JSON fails."""
    out = {}
    for col in _SYSTEM_LABEL_COLUMNS:
        if col == "System Piloting":
            out[col] = piloting_rule
            out[f"{col} Evidence"] = "Not Found"
            out[f"{col} Reason"] = "Derived from deterministic piloting rules."
        else:
            out[col] = ""
            out[f"{col} Evidence"] = "Not Found"
            out[f"{col} Reason"] = ""
    out["Confidence"] = "Low"
    out["Error"] = error
    return out


@tool("system_classifier")
def system_classifier(paragraph: str):
    """
    Stage 3: SYSTEM CLASSIFIER (RAG-ENHANCED)

    Goal:
    - Extract system-level labels using:
      - Taxonomy reference
      - Rule book triggers
      - RAG similar examples (skipped if retrieval fails)
      - Deterministic piloting rules (override when they find Crewed/Uncrewed)

    Output:
    - A flat dict of string values with Evidence + Reason for:
      - Market Segment
      - System Type (General)
      - System Type (Specific)
      - System Name (General)
      - System Name (Specific)
      - System Piloting
    - {} for an empty paragraph; a low-confidence fallback row (with "Error")
      when the LLM call fails.

    Why this matters:
    - Your biggest accuracy issues were in:
      Market, System Type, System Name, System Piloting
    - RAG makes results consistent with your labeled history dataset
    """
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    log_block("HUMAN MESSAGE (Stage3 - System)", paragraph)

    lower_text = paragraph.lower()
    hints = [
        f"RULE: {v['guidance']}"
        for v in RULE_BOOK.values()
        if any(t in lower_text for t in v["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    designators = extract_designators(paragraph)
    piloting_rule = detect_piloting_rule_based(paragraph, designators)
    rag_examples = _build_rag_examples(paragraph, top_k=3)

    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

REFERENCE TAXONOMY:
{TAXONOMY_STR}

RULE BOOK OVERRIDES:
{hint_str}

OUTPUT RULES:
- Return ONLY a FLAT JSON object.
- Every value must be a STRING.
- Do NOT return nested objects or lists.
- Evidence must be copied EXACTLY from paragraph text.
- If evidence not present, output "Not Found".

Return JSON:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

DESIGNATORS (regex extracted):
{designators if designators else "None"}

RULE BASED PILOTING:
{piloting_rule}

RAG EXAMPLES:
{json.dumps(rag_examples, indent=2)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": sys_prompt},
                      {"role": "user", "content": user_prompt}],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage3 - System)", json.dumps(result, indent=2))
        if not isinstance(result, dict):
            raise ValueError(f"Expected a JSON object, got {type(result).__name__}")
    except Exception as e:
        log_block("AI ERROR (Stage3 - System)", str(e))
        return _empty_system_result(piloting_rule, str(e))

    _apply_piloting_rule(result, piloting_rule)

    # Ensure flat: nested values are stringified (keys are only reassigned,
    # so iterating while updating is safe).
    for k, v in result.items():
        if isinstance(v, (dict, list)):
            result[k] = str(v)

    return result


# --- Stage 4: Contract Extractor ---
# Input schema describing contract_extractor's arguments (not passed to @tool below).
class ContractInfoInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    contract_date: str = Field(description="Contract date as string.")


def _llm_field(raw: dict, key: str, default: str) -> str:
    """Return raw[key] as a stripped string, or *default* when it is missing, null or blank."""
    value = raw.get(key)
    if value is None:
        return default
    text = str(value).strip()
    return text if text else default


def _format_value_million(raw_value) -> str:
    """Format a model-reported value (already in millions) as "%.3f"; unparsable -> "0.000"."""
    text = str(raw_value if raw_value is not None else "0").replace(",", "").replace("$", "").strip()
    try:
        return "{:.3f}".format(float(text))
    except ValueError:
        return "0.000"


def _signing_month_year(contract_date):
    """Return (full month name, year string) for *contract_date*, or ("Unknown", "Unknown")."""
    try:
        dt = pd.to_datetime(contract_date)
        if pd.isna(dt):
            return "Unknown", "Unknown"
        return dt.strftime("%B"), str(dt.year)
    except (TypeError, ValueError, OverflowError, AttributeError):
        return "Unknown", "Unknown"


@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """
    Stage 4: CONTRACT EXTRACTOR

    Goal:
    - Extract the contract financial + program details:
      - Supplier Name (standardized)
      - Program Type
      - Quantity
      - Value (Million)
      - Currency
      - Value Certainty
      - G2G/B2G
      - Completion Date Text (used for the expected MRO duration)
    - Signing Month / Signing Year from the contract date.

    Missing or blank model fields fall back to fixed defaults. On an LLM
    failure, only Signing Month/Year plus an "Error" entry are returned.

    Critical improvement:
    - Supplier Name returned from model is ALWAYS standardized
      using suppliers.json fuzzy match.
    """
    system_instruction = """
You are a Defense Contract Financial Analyst.

Rules:
1) raw_supplier_name: extract exact supplier from paragraph
2) program_type: Procurement/Training/MRO/Support/RDT&E/Upgrade/Other Service
3) value_certainty: Confirmed vs Estimated
4) quantity: extract numeric units if found else Not Applicable
5) g2g_b2g: G2G only if FMS mentioned else B2G
6) completion_date_text: for MRO calc only
7) value_million_raw: total contract value converted to MILLIONS of the stated currency, number only (e.g. $328,156,454 -> 328.156454)
8) currency_code: currency of the value (e.g. USD$)
9) value_note: caveats about the value (ceiling, cumulative, options) else Not Applicable

Return JSON only.
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

SIGNED DATE:
{contract_date}

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    log_block("HUMAN MESSAGE (Stage4 - Contract)", paragraph)

    # Signing month/year depend only on the date, so compute them up front and
    # keep them even when the LLM call fails.
    sign_month, sign_year = _signing_month_year(contract_date)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": system_instruction},
                      {"role": "user", "content": user_prompt}],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage4 - Contract)", json.dumps(raw, indent=2))
        if not isinstance(raw, dict):
            raise ValueError(f"Expected a JSON object, got {type(raw).__name__}")
    except Exception as e:
        log_block("AI ERROR (Stage4 - Contract)", str(e))
        return {"Signing Month": sign_month, "Signing Year": sign_year, "Error": str(e)}

    final_supplier = get_best_supplier_match(raw.get("raw_supplier_name"))

    prog_type = _llm_field(raw, "program_type", "Unknown")
    mro_months = calculate_mro_months(contract_date, raw.get("completion_date_text"), prog_type)

    val_formatted = _format_value_million(raw.get("value_million_raw", "0"))

    return {
        "Supplier Name": final_supplier,
        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": _llm_field(raw, "quantity", "Not Applicable"),
        "Value Certainty": _llm_field(raw, "value_certainty", "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": _llm_field(raw, "currency_code", "USD$"),
        # NOTE(review): no FX conversion is applied; this equals Value (Million)
        # even when Currency is not USD.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": _llm_field(raw, "value_note", "Not Applicable"),
        "G2G/B2G": _llm_field(raw, "g2g_b2g", "B2G"),
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


# --- Stage 5: Split Agent ---
# Input schema describing splitter_agent's arguments (not passed to @tool below).
class SplitterInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    base_row: dict = Field(description="Final extracted row after Stage1-4.")

@tool("splitter_agent")
def splitter_agent(paragraph: str, base_row: dict):
    """
    Stage 5: SPLITTER AGENT

    Goal:
    - Detect whether the extracted row must be split into multiple output rows.
    - Uses split_rows_engine() to apply deterministic split logic.
    - Guarantees every row has "Split Flag" and "Split Reason" keys.

    Output:
    - Returns {"rows": [row1, row2, ...]}
    - If splitting raises, returns a single copy of base_row with
      Split Flag "Error" so the extracted data is not lost.
    """
    try:
        rows = split_rows_engine(base_row, paragraph)
        for r in rows:
            r.setdefault("Split Flag", "No")
            r.setdefault("Split Reason", "")
        return {"rows": rows}
    except Exception as e:
        log_block("SPLIT ERROR (Stage5 - Split)", str(e))
        # Copy so the caller's dict is not modified as a side effect.
        fallback = dict(base_row or {})
        fallback["Split Flag"] = "Error"
        fallback["Split Reason"] = f"Split failed: {str(e)}"
        return {"rows": [fallback]}


# ==============================================================================
# 8) LANGGRAPH PIPELINE
# ==============================================================================

class AgentState(TypedDict):
    """
    State object passed between the LangGraph stage nodes (Stage1 -> Stage5).

    Contains:
    - Raw inputs for one Excel row (text, date, url)
    - Aggregated extraction dict (final_data), built up by Stage1-4, each of
      which returns a merged copy
    - Final output rows (final_rows) produced by the Stage5 split
    """
    # Contract paragraph (the "Contract Description" cell in __main__).
    input_text: str
    # Contract date as text; __main__ substitutes today's date when the cell is empty.
    input_date: str
    # Source URL of the announcement ("Source URL" cell).
    input_url: str
    # Column name -> value for the single extracted row, merged stage by stage.
    final_data: dict
    # List of row dicts after splitting (Stage5 output).
    final_rows: list
    # Uses the add_messages reducer; no node in this pipeline writes to it and
    # __main__ seeds it with [] -- NOTE(review): possibly vestigial.
    messages: Annotated[List[AnyMessage], add_messages]


def _merge_stage_result(state: dict, res) -> dict:
    """
    Build the partial state update for a Stage1-4 node.

    Copies the current final_data (a missing or None value counts as empty)
    and overlays the tool's result dict. Earlier stages' fields are kept, and
    the incoming state object is never mutated.
    """
    merged = dict(state.get("final_data") or {})
    merged.update(res or {})
    return {"final_data": merged}


def stage_1_sourcing(state: AgentState):
    """
    Node Stage1: runs sourcing_extractor and merges its fields
    (description, source link, dates, internal notes) into final_data.
    """
    res = sourcing_extractor.invoke({
        "paragraph": state["input_text"],
        "url": state["input_url"],
        "date": state["input_date"],
    })
    return _merge_stage_result(state, res)


def stage_2_geography(state: AgentState):
    """
    Node Stage2: extracts geography fields and merges them into final_data.
    """
    res = geography_extractor.invoke({"paragraph": state["input_text"]})
    return _merge_stage_result(state, res)


def stage_3_system(state: AgentState):
    """
    Node Stage3: system classification using RAG + evidence + reason.
    """
    res = system_classifier.invoke({"paragraph": state["input_text"]})
    return _merge_stage_result(state, res)


def stage_4_contract(state: AgentState):
    """
    Node Stage4: extracts supplier/program/financial/quantity fields.
    Supplier name is standardized via suppliers.json fuzzy match.
    """
    res = contract_extractor.invoke({
        "paragraph": state["input_text"],
        "contract_date": state["input_date"]
    })
    return _merge_stage_result(state, res)


def stage_5_split(state: AgentState):
    """
    Node Stage5: applies split logic to create 1..N rows from final_data.
    Falls back to the unsplit row if the splitter returns no rows.
    """
    base_row = dict(state.get("final_data") or {})
    res = splitter_agent.invoke({
        "paragraph": state["input_text"],
        "base_row": base_row
    })
    return {"final_rows": res.get("rows") or [base_row]}


def _build_workflow() -> StateGraph:
    """
    Assemble the (uncompiled) linear extraction graph:
    START -> Stage1 -> Stage2 -> Stage3 -> Stage4 -> Stage5 -> END.
    """
    stages = (
        ("Stage1", stage_1_sourcing),
        ("Stage2", stage_2_geography),
        ("Stage3", stage_3_system),
        ("Stage4", stage_4_contract),
        ("Stage5", stage_5_split),
    )

    graph = StateGraph(AgentState)
    for node_name, node_fn in stages:
        graph.add_node(node_name, node_fn)

    # Chain consecutive pairs of START, the stage names, and END.
    chain = [START, *(node_name for node_name, _ in stages), END]
    for src, dst in zip(chain, chain[1:]):
        graph.add_edge(src, dst)
    return graph


workflow = _build_workflow()
app = workflow.compile()


# ==============================================================================
# 9) GRAPH VISUALIZATION (OFFLINE SAFE)
# ==============================================================================

def export_workflow_mermaid(app_obj, out_file="workflow.mmd"):
    """
    Write the compiled graph's Mermaid source to a local text file.

    Rendering happens offline: only app_obj.get_graph().draw_mermaid() is
    called, so no mermaid.ink request is needed (useful where that API is
    blocked).

    Args:
        app_obj: Compiled graph exposing get_graph().draw_mermaid().
        out_file: Destination path (UTF-8 text), "workflow.mmd" by default.

    Returns:
        out_file, unchanged.
    """
    mermaid_text = app_obj.get_graph().draw_mermaid()
    with open(out_file, mode="w", encoding="utf-8") as handle:
        handle.write(mermaid_text)
    print(f"‚úÖ Workflow Mermaid saved locally: {out_file}")
    return out_file


# ==============================================================================
# 10) EXCEL HIGHLIGHTING FEATURE
# ==============================================================================

def highlight_evidence_reason_columns(excel_path: str):
    """
    Colour the Evidence and Reason columns of the active sheet in *excel_path*.

    - Header containing "Evidence": light yellow fill (FFF2CC)
    - Header containing "Reason":   light blue fill (D9E1F2)

    Matching header cells are also bolded, and the fill is applied to every
    data row. The Reason fill is applied after the Evidence fill, so a header
    containing both words ends up blue. The workbook is saved in place.

    Goal:
    - Your team can validate 'why this label was chosen' without confusion.
    """
    workbook = load_workbook(excel_path)
    sheet = workbook.active

    header_names = [cell.value for cell in sheet[1]]

    def _columns_containing(word):
        # 1-based column indexes whose header text contains *word*.
        return [
            position
            for position, name in enumerate(header_names, start=1)
            if isinstance(name, str) and word in name
        ]

    evidence_fill = PatternFill(start_color="FFF2CC", end_color="FFF2CC", fill_type="solid")  # Yellow
    reason_fill = PatternFill(start_color="D9E1F2", end_color="D9E1F2", fill_type="solid")    # Blue

    # (column, fill) pairs; Evidence first so Reason wins where both match.
    targets = [(col, evidence_fill) for col in _columns_containing("Evidence")]
    targets += [(col, reason_fill) for col in _columns_containing("Reason")]

    bold = Font(bold=True)
    for col_idx, fill in targets:
        header_cell = sheet.cell(row=1, column=col_idx)
        header_cell.fill = fill
        header_cell.font = bold

    for row_idx in range(2, sheet.max_row + 1):
        for col_idx, fill in targets:
            sheet.cell(row=row_idx, column=col_idx).fill = fill

    workbook.save(excel_path)
    print("‚úÖ Evidence + Reason columns highlighted successfully.")


# ==============================================================================
# 11) MAIN EXECUTION
# ==============================================================================

if __name__ == "__main__":
    import traceback

    print(f"\nüìå Loading Input File: {INPUT_EXCEL_PATH}")

    # Offline safe workflow graph. The diagram is documentation only, so an
    # export failure is reported but does not stop the extraction run.
    try:
        export_workflow_mermaid(app, out_file="workflow.mmd")
    except Exception as e:
        print(f"\nWARNING: could not export workflow diagram: {e}")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        missing_cols = [col for col in required_cols if col not in df_input.columns]
        if missing_cols:
            raise ValueError(f"Excel file must contain columns: {required_cols} (missing: {missing_cols})")

        total = len(df_input)
        print(f"üöÄ Processing {total} rows...")
        results = []

        # enumerate() gives a 1-based position even when the sheet's index is not 0..n-1.
        for position, (_, row) in enumerate(df_input.iterrows(), start=1):
            print(f"\nüîπ Row {position}/{total}")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "final_rows": [],
                "messages": []
            }

            try:
                output_state = app.invoke(initial_state)
            except Exception as e:
                # Isolate failures per row: keep a marked placeholder row and
                # continue, instead of losing every processed row.
                print(f"\nERROR on row {position}: {e}")
                traceback.print_exc()
                results.append({
                    "Split Flag": "Error",
                    "Split Reason": f"Pipeline failed: {e}",
                    "Description of Contract": desc,
                    "Additional Notes (Internal Only)": f"Pipeline error: {e}",
                    "Source Link(s)": c_url,
                    "Contract Date": c_date,
                })
                continue

            rows = output_state.get("final_rows") or [output_state.get("final_data", {})]
            results.extend(rows)

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Split Flag", "Split Reason",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")
        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        # Highlight Evidence + Reason
        highlight_evidence_reason_columns(OUTPUT_EXCEL_PATH)

        print("\n‚úÖ Processing Complete!")
        print(f"üíæ Output File Saved: {OUTPUT_EXCEL_PATH}")
        print(df_final.head(3).to_string(index=False))

    except Exception as e:
        print(f"\n‚ùå ERROR: {e}")
        traceback.print_exc()


In [None]:
import os
import re
import json
import difflib
import pickle
import datetime
from typing import Annotated, TypedDict, List, Dict, Any, Optional

import pandas as pd
import faiss
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass

# LangGraph / LangChain
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI

# Excel formatting
from openpyxl import load_workbook
from openpyxl.styles import PatternFill, Font


# ==============================================================================
# 0) DEBUG LOGGING HELPERS
# ==============================================================================

def log_block(title: str, content: str):
    """
    Print a titled, visually separated debug section to the console.

    Used throughout the Stage1 -> Stage6 pipeline to dump prompts, model
    responses and errors so individual rows can be audited.

    LAYOUT:
    - blank line, then a 130-character "=" rule
    - the title
    - another rule
    - the content

    NOTE:
    - Console only; nothing is written to the output Excel.
    """
    rule = "=" * 130
    # One print call with "\n" as separator produces the same text as four
    # consecutive prints of the same four pieces.
    print("\n" + rule, title, rule, content, sep="\n")


# ==============================================================================
# 1) RAG RETRIEVER (FAISS + Metadata)
# ==============================================================================

class SystemKBRetriever:
    """
    RAG Retriever for Defense System Classification.

    PURPOSE:
    This retriever provides semantic retrieval over a historical labeled dataset
    stored inside a FAISS index + metadata pickle.

    EXPECTED FILES INSIDE kb_dir:
    - system_kb.faiss       -> FAISS index
    - system_kb_meta.pkl    -> metadata rows, indexed by FAISS id

    WHY RAG IMPROVES ACCURACY:
    - System taxonomy classification becomes inconsistent if the LLM guesses.
    - RAG provides labeled examples similar to the current paragraph.
    - This improves consistency for:
      Market Segment, System Types, System Names, Piloting, etc.

    RETURN FORMAT:
    retrieve(query_text, top_k=3) returns:
      [
        {
          "score": float,
          "meta": {... column values from KB ...}
        },
        ...
      ]

    SECURITY NOTE:
    - The metadata file is loaded with pickle, which can execute code. Only
      point kb_dir at a KB you built yourself.
    """

    def __init__(self, kb_dir: str, embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initializes and loads FAISS + metadata.

        PARAMETERS:
        - kb_dir: directory where FAISS + metadata exist
        - embed_model: sentence-transformers model name

        RAISES:
        - FileNotFoundError if either KB file is missing.

        IMPORTANT:
        - embedder is lazy-loaded so script startup is fast.
        """
        self.kb_dir = kb_dir
        self.embed_model = embed_model

        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build the KB first before running this pipeline."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded rows: {len(self.meta)}")

        self.embedder = None

    def _lazy_load_embedder(self):
        """
        Lazy-load the embedding model only on demand.

        WHY:
        - Avoid memory overhead at script startup
        - Faster pipeline initialization
        """
        if self.embedder is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer(self.embed_model)

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Retrieve top-k semantically similar KB examples.

        INPUT:
        - query_text: contract paragraph (any value; converted with str())
        - top_k: number of similar KB rows to retrieve; values < 1 return []

        OUTPUT:
        - list of dicts: [{"score":..., "meta":...}, ...] in FAISS result order.
          Negative ids (FAISS pads missing results with -1) and ids that have
          no metadata row (index and pickle out of sync) are skipped instead
          of raising.
        """
        query_text = str(query_text).strip()
        if not query_text:
            return []
        # FAISS rejects k <= 0; treat it as "no examples requested".
        if top_k < 1:
            return []

        self._lazy_load_embedder()

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, top_k)

        n_meta = len(self.meta)
        results = []
        for score, idx in zip(scores[0], idxs[0]):
            idx = int(idx)
            if idx < 0 or idx >= n_meta:
                continue
            results.append({"score": float(score), "meta": self.meta[idx]})

        return results


# ==============================================================================
# 2) CONFIGURATION & FILE PATHS
# ==============================================================================

"""
PATH CONFIGURATION

Update these paths based on your system.

IMPORTANT:
- taxonomy.json is your classification constraints
- suppliers.json is your supplier normalization list
- input excel must have the required columns:
    ["Source URL", "Contract Date", "Contract Description"]
"""

# Every path below can be overridden without editing this cell by exporting an
# environment variable with the same name (e.g. RAG_KB_DIR=/data/kb_store).
# The literal is used only when that variable is not set, so existing runs
# on the original machine behave exactly as before.
TAXONOMY_PATH = os.environ.get(
    "TAXONOMY_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\taxonomy.json",
)
SUPPLIERS_PATH = os.environ.get(
    "SUPPLIERS_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\suppliers.json",
)
INPUT_EXCEL_PATH = os.environ.get(
    "INPUT_EXCEL_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\source_file.xlsx",
)
OUTPUT_EXCEL_PATH = os.environ.get("OUTPUT_EXCEL_PATH", "Processed_Defense_Data.xlsx")

# Directory holding system_kb.faiss + system_kb_meta.pkl for SystemKBRetriever.
RAG_KB_DIR = os.environ.get(
    "RAG_KB_DIR",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\system_kb_store",
)


# ==============================================================================
# 3) SETUP LLM CLIENT
# ==============================================================================

"""
LLM Client Setup Notes:

You are using LLM Foundry OpenAI-compatible endpoint.

IMPORTANT:
- Key must be passed properly: api_key="TOKEN:project"
- base_url is Foundry endpoint
"""

# Prompt for the token only when it is not already exported in the environment.
# getpass does not echo the typed key, so it stays out of the notebook output.
if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

# OpenAI-compatible client pointed at the LLM Foundry endpoint.
# The api_key is sent as "<token>:<project>", with "my-test-project" as the project suffix.
# NOTE(review): project name and base_url are hard-coded — confirm before sharing/deploying.
client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
    base_url="https://llmfoundry.straive.com/openai/v1/",
)

# Module-level retriever used by the Stage3 system_classifier tool.
# Construction raises FileNotFoundError if the FAISS index or metadata pickle
# is missing, which stops the cell before any rows are processed.
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)


# ==============================================================================
# 4) LOAD JSON HELPERS
# ==============================================================================

def load_json_file(filename, default_value):
    """
    Loads a JSON file safely with fallback.

    WHY:
    - taxonomy.json and suppliers.json are critical
    - pipeline must not crash if JSON file is missing/corrupt

    PARAMETERS:
    - filename: path to a UTF-8 encoded JSON file
    - default_value: returned unchanged when the file cannot be read or parsed

    RETURNS:
    - json object (dict/list) if loaded successfully
    - default_value otherwise (a warning is printed)

    NOTE:
    - The success message is printed only after parsing succeeds, so a corrupt
      file no longer logs both "Loaded" and a warning.
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    # OSError: missing/unreadable file; ValueError: bad JSON (JSONDecodeError) or
    # bad UTF-8 (UnicodeDecodeError); TypeError: filename is not a path (e.g. None).
    except (OSError, ValueError, TypeError) as e:
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value

    print(f"‚úÖ Loaded: {filename}")
    return data


# Taxonomy constraints for the Stage3 classifier; {} when taxonomy.json is unavailable.
raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
# Serialised once with compact separators (no spaces) to keep the system prompt short.
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(",", ":"))

# Fallback supplier names, used only if suppliers.json cannot be loaded.
_DEFAULT_SUPPLIER_LIST = [
    "Dell Inc",
    "Boeing",
    "Lockheed Martin",
    "Raytheon Technologies",
    "Northrop Grumman",
    "L3Harris",
    "BAE Systems",
    "General Dynamics",
]
SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, _DEFAULT_SUPPLIER_LIST)


# ==============================================================================
# 5) RULE BOOK + GEOGRAPHY
# ==============================================================================

"""
RULE_BOOK

This is deterministic "hint injection" logic to reduce misclassification
for very frequent patterns (radars, countermeasures, ammo).

If any triggers match the paragraph text, we provide a guidance string
inside system prompt for Stage3 classifier.
"""

# Trigger matching: system_classifier tests each trigger with a plain
# `trigger in paragraph.lower()` substring check. A leading space is therefore
# used wherever the bare word also occurs inside common longer words:
#   " round"  -> does NOT fire on "ground", "around", "background"
#                (bare "round" pushed the Ammunition hint onto ground-vehicle
#                and support contracts)
#   " 5.56"   -> caliber must start a token
# "ammunition" is listed explicitly because "ammo" is not a substring of it.
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'"
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'"
    },
    "ammunition": {
        "triggers": ["cartridge", " round", "projectile", " 5.56", " 7.62", "ammo", "ammunition"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'"
    }
}

# Region -> country names / aliases. get_region_for_country() compares these
# case-insensitively and returns the first region whose list contains the name.
GEOGRAPHY_MAPPING = {
    "North America": ["USA", "United States", "US", "United States of America", "Canada", "America"],
    "Europe": ["UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain", "Poland", "Netherlands",
               "Norway", "Sweden", "Finland", "Denmark", "Belgium"],
    "Asia-Pacific": ["Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand"],
    "Middle East and North Africa": ["Israel", "Saudi Arabia", "UAE", "United Arab Emirates", "Egypt", "Qatar", "Kuwait", "Iraq"],
    "International Organisations": ["NATO", "EU", "IFU", "UN", "NSPA"]
}


# ==============================================================================
# 6) BASE HELPERS (Supplier + Dates + Region + Designators)
# ==============================================================================

def get_best_supplier_match(extracted_name: str, suppliers=None):
    """
    Standardizes supplier names using suppliers.json.

    WHY:
    - LLM outputs inconsistent supplier strings
    - Your final Excel must match standardized naming

    PARAMETERS:
    - extracted_name: raw supplier string (any value; converted with str())
    - suppliers: optional iterable of canonical names; defaults to SUPPLIER_LIST

    MATCH STRATEGY:
    1) If blank/unknown -> "Unknown"
    2) Exact match ignoring case
    3) Fuzzy match using difflib, also ignoring case (so "LOCKHEED MARTIN CORP"
       still resolves to "Lockheed Martin")
    4) Otherwise return cleaned extracted string

    RETURNS:
    - canonical supplier spelling, "Unknown", or the stripped input
    """
    if not extracted_name or str(extracted_name).strip().lower() in ["unknown", "n/a", "not applicable", ""]:
        return "Unknown"

    candidates = SUPPLIER_LIST if suppliers is None else suppliers
    clean_name = str(extracted_name).strip()
    key = clean_name.lower()

    # lower-cased name -> canonical spelling; empty entries are ignored.
    supplier_map = {str(s).lower(): str(s) for s in candidates if s}

    if key in supplier_map:
        return supplier_map[key]

    # Fuzzy-match on lower-cased strings: SequenceMatcher is case-sensitive, so
    # an all-caps name would otherwise score far below the cutoff.
    matches = difflib.get_close_matches(key, list(supplier_map), n=1, cutoff=0.6)
    return supplier_map[matches[0]] if matches else clean_name


def calculate_mro_months(start_date_str, end_date_text, program_type):
    """
    Calculates MRO duration in months.

    RULE:
    - Only valid for program_type == "MRO/Support"
    - otherwise return "Not Applicable"

    PARAMETERS:
    - start_date_str: contract/signing date, parsed with pandas (dayfirst=True)
    - end_date_text: free text containing the end date, parsed fuzzily by dateutil
    - program_type: program type label from Stage4

    OUTPUT:
    - whole-month count as string (days are ignored; negative spans become "0")
    - "Not Applicable" when not MRO, when a date is missing, or when either
      date cannot be parsed/compared
    """
    if program_type != "MRO/Support":
        return "Not Applicable"
    try:
        # Kept inside the try: `not pd.NA` raises TypeError.
        if not start_date_str or not end_date_text:
            return "Not Applicable"

        start = pd.to_datetime(start_date_str, dayfirst=True)
        if pd.isna(start):
            return "Not Applicable"
        end = parser.parse(str(end_date_text), fuzzy=True)

        diff = relativedelta(end, start)
    # ValueError: unparseable dates (pandas/dateutil parse errors subclass it);
    # OverflowError: absurd years; TypeError: e.g. tz-aware vs naive comparison.
    except (ValueError, TypeError, OverflowError):
        return "Not Applicable"

    total_months = diff.years * 12 + diff.months
    return str(max(0, int(total_months)))


def get_region_for_country(country_name):
    """
    Resolve a country name (or common alias) to its GEOGRAPHY_MAPPING region.

    Blank values, the placeholders "unknown" / "n/a" / "not applicable",
    and names found in no region all resolve to "Unknown". US and UK
    shorthands are answered directly before the mapping is scanned. Every
    comparison is case-insensitive and ignores surrounding whitespace.
    """
    if not country_name:
        return "Unknown"

    key = str(country_name).strip().lower()
    if key in ("unknown", "n/a", "not applicable", ""):
        return "Unknown"

    shortcuts = {
        "us": "North America",
        "usa": "North America",
        "u.s.": "North America",
        "united states": "North America",
        "united states of america": "North America",
        "uk": "Europe",
        "u.k.": "Europe",
        "britain": "Europe",
        "great britain": "Europe",
    }
    if key in shortcuts:
        return shortcuts[key]

    for region, names in GEOGRAPHY_MAPPING.items():
        if key in {name.lower() for name in names}:
            return region

    return "Unknown"


DESIGNATOR_PATTERNS = [
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    r"\bF-\d+\b", r"\bB-\d+\b", r"\bC-\d+\b", r"\bA-\d+\b",
    r"\bMQ-\d+\b", r"\bRQ-\d+\b",
    r"\bAN\/[A-Z0-9\-]+\b",
    # Non-capturing group: with a capturing group, re.findall returned only the
    # prefix ("AIM") instead of the full designator ("AIM-120").
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+\b",
]

def extract_designators(text: str):
    """
    Extracts defense platform identifiers/designators from contract text.

    WHY:
    - Helps classify system piloting deterministically
    - Helps identify platform family quickly

    NORMALIZATION:
    - upper-cased, all whitespace removed ("DDG 51" -> "DDG51"), "--" -> "-"

    OUTPUT:
    - list[str] of unique designators, grouped in DESIGNATOR_PATTERNS order
      and first-seen order within each pattern
    """
    text = str(text)
    found = []
    for pat in DESIGNATOR_PATTERNS:
        # group(0) is always the whole match, so a capture group added to a
        # pattern later can never truncate the designator.
        found.extend(m.group(0) for m in re.finditer(pat, text, flags=re.IGNORECASE))

    # The [-\s]? separators can match tabs/newlines, so strip every whitespace
    # character, not only spaces.
    cleaned = (re.sub(r"\s+", "", f.upper()).replace("--", "-") for f in found)

    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(cleaned))


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """
    Deterministic piloting classifier.

    OUTPUT MUST BE ONE OF:
    - "Crewed"
    - "Uncrewed"
    - "Not Applicable"

    RULES (checked in this order; first hit wins):
    - MQ-/RQ- designator, or text mentions unmanned/uncrewed/UAV/drone/autonomous
      -> Uncrewed
    - DDG/CVN/SSN/LCS/LPD/LHA/LHD designator, or the whole word "USS" -> Crewed
    - Else -> Not Applicable

    PARAMETERS:
    - text: contract paragraph (any value; converted with str())
    - designators: output of extract_designators(); None is treated as []
    """
    t = str(text).lower()
    designators = designators or []

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if any(k in t for k in ("unmanned", "uncrewed", "uav", "drone", "autonomous")):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    # Whole-word match: the old "uss " substring check also fired inside "discuss ".
    if re.search(r"\buss\b", t):
        return "Crewed"

    return "Not Applicable"


# ==============================================================================
# 7) SPLIT ENGINE (DETERMINISTIC)
# ==============================================================================

def _normalize_spaces(text: str) -> str:
    """
    Normalizes whitespace to simplify parsing.

    WHY:
    - DoD contract paragraphs often contain newlines, multiple spaces.
    - Regex extraction becomes more stable.
    """
    return re.sub(r"\s+", " ", str(text or "")).strip()

def _safe_int(value: str):
    """
    Safely converts string -> integer.
    Returns None if conversion fails.
    """
    try:
        return int(str(value).replace(",", "").strip())
    except Exception:
        return None

def _extract_supplier_from_description(paragraph: str) -> str:
    """
    Fallback supplier extraction from DoD style contracts.

    EXAMPLES (all return the text before the first comma):
    "Raytheon Technologies, Tucson, Arizona, is awarded $XYZ..."
    "Boeing Co., St. Louis, Missouri, has been awarded $XYZ..."
    "General Dynamics, Sterling Heights, Michigan, was awarded $XYZ..."

    Recognised award verbs: is/are/was/were (optionally "being") awarded and
    has/have been awarded. A plain "has awarded" is NOT accepted, because
    "The Navy, ..., has awarded X" names the customer, not the supplier.

    RETURNS:
    - leading entity before the first comma, or "Unknown" if no match / empty
    """
    p = _normalize_spaces(paragraph)
    m = re.match(
        r"^(.*?),\s+.*?\s+(?:(?:is|are|was|were)(?:\s+being)?|(?:has|have)\s+been)\s+awarded",
        p,
        flags=re.IGNORECASE,
    )
    if m:
        return m.group(1).strip() or "Unknown"
    return "Unknown"

def parse_line_items(paragraph: str):
    """
    Extracts quantity-prefixed line items, typically after "as follows:".

    The text after the first "as follows:" (or the whole paragraph when that
    marker is absent) is split on ";". Every chunk that starts with a quantity
    becomes one item. A leading "and" is tolerated, so the last entry of a list
    such as "as follows: 212 for the Navy; 187 for the Air Force; and 84 for
    FMS customers" is no longer dropped.

    OUTPUT (for the example above):
      [
        {"item_total_qty": 212, "item_name": "for the Navy", "allocation_text": ""},
        {"item_total_qty": 187, "item_name": "for the Air Force", "allocation_text": ""},
        {"item_total_qty": 84, "item_name": "for FMS customers", "allocation_text": ""},
      ]

    FIELDS:
    - item_total_qty: int parsed from the leading number (thousands commas allowed)
    - item_name: rest of the chunk with every "(...)" group removed
    - allocation_text: first "(...)" group including parentheses, or ""

    NOTE:
    - Best-effort parsing; the final chunk may carry trailing sentences
      (e.g. "Work will be performed in ...").
    - If no chunks found, returns [].
    """
    paragraph = _normalize_spaces(paragraph)

    marker = "as follows:"
    marker_pos = paragraph.lower().find(marker)
    if marker_pos != -1:
        split_part = paragraph[marker_pos + len(marker):].strip()
    else:
        split_part = paragraph

    items = []
    for chunk in split_part.split(";"):
        chunk = chunk.strip()
        if not chunk:
            continue

        # Optional leading "and" covers the final entry of "a; b; and c" lists.
        m = re.match(r"^(?:and\s+)?(\d{1,6}(?:,\d{3})*)\s+(.*)$", chunk, flags=re.IGNORECASE)
        if not m:
            continue

        qty = _safe_int(m.group(1))
        rest = m.group(2).strip()

        paren = re.search(r"\((.*?)\)", rest)
        if paren:
            allocation_text = paren.group(0).strip()
            item_name = re.sub(r"\(.*?\)", "", rest).strip()
        else:
            allocation_text = ""
            item_name = rest

        items.append(
            {
                "item_total_qty": qty,
                "item_name": item_name,
                "allocation_text": allocation_text,
            }
        )
    return items

def parse_operator_allocations(allocation_text: str):
    """
    Find "<qty> for the <service>" and "<qty> for FMS ..." allocations in text.

    Examples of recognised fragments:
    - "212 for the Navy"
    - "187 for the Air Force"
    - "84 for FMS customers"

    Each hit becomes a dict:
      {"allocation_type": "operator",   "operator": "Navy", "qty": 212, "g2g_b2g": "B2G"}
      {"allocation_type": "fms_bucket", "operator": "Foreign Assistance", "qty": 84, "g2g_b2g": "G2G"}

    Services (Navy / Air Force / Army / Marine Corps) are B2G; FMS buckets are
    G2G with operator "Foreign Assistance". Service hits are listed before FMS
    hits, and exact duplicates keep only their first occurrence. An empty input
    yields [].
    """
    if not allocation_text:
        return []

    lowered = allocation_text.lower()
    operator_pattern = r"(\d{1,6}(?:,\d{3})*)\s+for\s+the\s+(navy|air force|army|marine corps)"
    fms_pattern = r"(\d{1,6}(?:,\d{3})*)\s+for\s+(?:foreign military sales\s*\(fms\)\s*customers|fms\s*customers|fms)"

    candidates = []

    for number, service in re.findall(operator_pattern, lowered):
        amount = _safe_int(number)
        if amount is None:
            continue
        candidates.append({
            "allocation_type": "operator",
            "operator": "Air Force" if service == "air force" else service.title(),
            "qty": amount,
            "g2g_b2g": "B2G",
        })

    for number in re.findall(fms_pattern, lowered):
        amount = _safe_int(number)
        if amount is None:
            continue
        candidates.append({
            "allocation_type": "fms_bucket",
            "operator": "Foreign Assistance",
            "qty": amount,
            "g2g_b2g": "G2G",
        })

    # dicts keep insertion order, so setdefault keeps the first occurrence of each signature.
    unique = {}
    for entry in candidates:
        signature = (entry["allocation_type"], entry["operator"], entry["qty"], entry["g2g_b2g"])
        unique.setdefault(signature, entry)

    return list(unique.values())

def parse_fms_countries(paragraph: str):
    """
    Pull the FMS country list that follows "governments of ...".

    The list ends at the first ". ", " Work will be performed", " Fiscal",
    " This contract", or the end of the text. Names are split on commas and
    the word "and", trimmed, and kept only when 3-50 characters long.
    Duplicates (case-insensitive) keep their first spelling.

    Example:
      "governments of Australia, Bahrain, and Belgium." -> ["Australia", "Bahrain", "Belgium"]
    """
    found = re.search(
        r"governments of (.+?)(?:\.\s| Work will be performed| Fiscal| This contract|$)",
        str(paragraph),
        flags=re.IGNORECASE | re.DOTALL
    )
    if found is None:
        return []

    # lower-cased name -> first spelling seen (dicts preserve insertion order)
    ordered = {}
    for piece in re.split(r",|\band\b", found.group(1)):
        name = piece.strip()
        if 2 < len(name) <= 50:
            ordered.setdefault(name.lower(), name)

    return list(ordered.values())

def _make_split_row(base_row: dict, overrides: dict, split_flag: str,
                    reason: str, evidence: str, line_item) -> dict:
    """Copy base_row, apply field overrides, then stamp the four Split* validation columns."""
    row = base_row.copy()
    row.update(overrides)
    row["Split Flag"] = split_flag
    row["Split Reason"] = reason
    row["Split Evidence"] = evidence
    row["Split Line Item"] = str(line_item)
    return row


def _allocation_source_text(item_total_qty, item_name, allocation_text) -> str:
    """
    Rebuild the text a line item's operator/FMS allocations are parsed from.

    parse_line_items strips the leading quantity and every "(...)" group off
    item_name. Parsing item_name alone therefore can never see "212 for the
    Navy" or "(212 for the Navy, 187 for the Army)". Re-attach both here.
    The quantity is prefixed only when it is a parsed int, not a fallback string.
    """
    parts = []
    if isinstance(item_total_qty, int) and not isinstance(item_total_qty, bool):
        parts.append(str(item_total_qty))
    if item_name:
        parts.append(str(item_name))
    if allocation_text:
        parts.append(str(allocation_text))
    return " ".join(parts)


def split_rows_engine(base_row: dict, paragraph: str):
    """
    MASTER SPLIT ENGINE (DETERMINISTIC)

    INPUTS:
    - base_row: dict output after Stage1-4 (not mutated; a copy is used)
    - paragraph: contract paragraph text

    OUTPUT:
    - list of rows (1..N)

    SPLIT RULES:
    - Operator allocations ("212 for the Navy", "187 for the Air Force"):
      one row per operator allocation, G2G/B2G = B2G.
    - FMS allocations ("84 for FMS customers"):
      one row with operator "Foreign Assistance" and G2G.
    - Multi-country FMS ("governments of Australia, Bahrain, Belgium"):
      one row per country, Quantity = "1" each.

    Allocations are searched in the line item's quantity + name + "(...)"
    text. Previously only the stripped item name was searched, so the
    documented "as follows: 212 for the Navy; ..." splits never fired.

    VALIDATION COLUMNS (set on every returned row):
    - Split Flag
    - Split Reason
    - Split Evidence
    - Split Line Item
    """
    paragraph = _normalize_spaces(paragraph)
    base_row = base_row.copy()

    # Supplier fallback (important)
    if not base_row.get("Supplier Name") or base_row.get("Supplier Name") in ["", "Unknown"]:
        extracted = _extract_supplier_from_description(paragraph)
        base_row["Supplier Name"] = get_best_supplier_match(extracted)

    line_items = parse_line_items(paragraph)
    fms_countries = parse_fms_countries(paragraph)

    # fallback as one chunk
    if not line_items:
        line_items = [{
            "item_total_qty": base_row.get("Quantity", "Not Applicable"),
            "item_name": base_row.get("System Name (Specific)", "Unknown Deliverable"),
            "allocation_text": ""
        }]

    final_rows = []

    for item in line_items:
        item_total_qty = item.get("item_total_qty")
        item_name = item.get("item_name", "Unknown Deliverable")
        allocation_text = item.get("allocation_text", "")

        alloc_source = _allocation_source_text(item_total_qty, item_name, allocation_text)
        allocations = parse_operator_allocations(alloc_source)

        # No allocation: emit the line item as a single unsplit row.
        if not allocations:
            overrides = {} if item_total_qty is None else {"Quantity": str(item_total_qty)}
            final_rows.append(_make_split_row(
                base_row, overrides, "No",
                "No operator/FMS allocation detected.", "Not Found", item_name,
            ))
            continue

        for alloc in allocations:
            if alloc["allocation_type"] == "operator":
                final_rows.append(_make_split_row(
                    base_row,
                    {
                        "Customer Operator": alloc["operator"],
                        "Quantity": str(alloc["qty"]),
                        "G2G/B2G": alloc["g2g_b2g"],
                    },
                    "Yes", "Operator allocation split detected.", alloc_source, item_name,
                ))

            elif alloc["allocation_type"] == "fms_bucket":
                if fms_countries:
                    evidence = f"Countries: {', '.join(fms_countries)}"
                    for c in fms_countries:
                        final_rows.append(_make_split_row(
                            base_row,
                            {
                                "Customer Country": c,
                                "Customer Region": get_region_for_country(c),
                                "Customer Operator": "Foreign Assistance",
                                "Quantity": "1",
                                "G2G/B2G": "G2G",
                            },
                            "Yes", "FMS multi-country detected -> 1 row per country", evidence, item_name,
                        ))
                else:
                    final_rows.append(_make_split_row(
                        base_row,
                        {
                            "Customer Operator": "Foreign Assistance",
                            "Quantity": str(alloc["qty"]),
                            "G2G/B2G": "G2G",
                        },
                        "Yes", "FMS quantity detected but no explicit countries", alloc_source, item_name,
                    ))

    # Defensive guard: every line item above appends at least one row.
    if not final_rows:
        base_row["Split Flag"] = "No"
        base_row["Split Reason"] = "Split fallback (no rows created)"
        base_row["Split Evidence"] = "Not Found"
        base_row["Split Line Item"] = "Not Found"
        return [base_row]

    return final_rows


# ==============================================================================
# 8) TOOLS / AGENTS (Stage1 -> Stage6)
# ==============================================================================

class SourcingInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    url: str = Field(description="Source URL of the contract announcement/news.")
    date: str = Field(description="Contract date string.")

@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """
    STAGE 1 TOOL: SOURCING EXTRACTOR

    GOAL:
    Build base row skeleton metadata fields which remain constant across splits.

    OUTPUT FIELDS:
    - Description of Contract
    - Additional Notes (Internal Only)
    - Source Link(s)
    - Contract Date
    - Reported Date (By SGA)

    RULES:
    ‚úÖ Must not guess/modify system classification
    ‚úÖ Must not guess financial values
    ‚úÖ Only sourcing metadata is allowed here
    """
    # The docstring above doubles as the LangChain tool description, so it is
    # kept verbatim. Processing date uses local time, formatted YYYY-MM-DD.
    reported_on = datetime.datetime.now().strftime("%Y-%m-%d")
    text_lc = str(paragraph).lower()

    # "split" takes priority over "modification" when both words appear.
    if "split" in text_lc:
        internal_note = "Split award detected."
    elif "modification" in text_lc:
        internal_note = "Contract Modification."
    else:
        internal_note = "Standard extraction."

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": internal_note,
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": reported_on,
    }


class GeographyInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")


# Alias -> canonical lower-case country key, used only to decide Domestic Content
# (so "USA" vs "United States" is not mislabelled "Imported").
_DOMESTIC_COUNTRY_ALIASES = {
    "us": "united states",
    "u.s.": "united states",
    "usa": "united states",
    "u.s.a.": "united states",
    "united states of america": "united states",
    "america": "united states",
    "uk": "united kingdom",
    "u.k.": "united kingdom",
    "britain": "united kingdom",
    "great britain": "united kingdom",
    "uae": "united arab emirates",
}
_MISSING_COUNTRY_VALUES = {"unknown", "n/a", "not applicable", ""}


def _domestic_content(customer_country, supplier_country) -> str:
    """Return 'Indigenous' / 'Imported', or 'Unknown' when either country is missing."""
    cust_key = str(customer_country).strip().lower()
    supp_key = str(supplier_country).strip().lower()
    # Two "Unknown" values must not be reported as a domestic (Indigenous) deal.
    if cust_key in _MISSING_COUNTRY_VALUES or supp_key in _MISSING_COUNTRY_VALUES:
        return "Unknown"
    cust_key = _DOMESTIC_COUNTRY_ALIASES.get(cust_key, cust_key)
    supp_key = _DOMESTIC_COUNTRY_ALIASES.get(supp_key, supp_key)
    return "Indigenous" if cust_key == supp_key else "Imported"


@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """
    STAGE 2 TOOL: GEOGRAPHY EXTRACTOR

    GOAL:
    Extract and derive geography information:
    - Customer Country
    - Customer Operator
    - Supplier Country
    - Customer Region (derived)
    - Supplier Region (derived)
    - Domestic Content (derived)

    RULES FOR LLM:
    ‚úÖ Return JSON only
    ‚úÖ Keep short and clean values
    ‚úÖ If missing -> "Unknown"
    """
    # The docstring above is the LangChain tool description and is kept verbatim.
    sys_prompt = """
You are a Defense Contract Geography Analyst.

Extract ONLY:
1) Customer Country
2) Customer Operator
3) Supplier Country

Return JSON ONLY:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}

Rules:
- If missing -> "Unknown"
- Keep values short and clean
"""

    log_block("HUMAN MESSAGE (Stage2 - Geography)", paragraph)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": paragraph}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage2 - Geography)", json.dumps(raw, indent=2))
    except Exception as e:
        raw = {}
        log_block("AI ERROR (Stage2 - Geography)", str(e))

    # json_object mode should give a dict; guard anyway so .get() cannot crash.
    if not isinstance(raw, dict):
        raw = {}

    # JSON null / "" from the model are normalised to the prompt's "Unknown" rule.
    cust = raw.get("Customer Country") or "Unknown"
    supp = raw.get("Supplier Country") or "Unknown"
    operator = raw.get("Customer Operator") or "Unknown"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": operator,
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": _domestic_content(cust, supp)
    }


class SystemInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")


def _stage3_snippet(text, limit: int = 200) -> str:
    """First `limit` chars of a KB description plus '...'; None/NaN cells become ''."""
    # KB metadata can come from spreadsheet cells, where blanks may be float NaN;
    # NaN is the only value not equal to itself.
    if text is None or (isinstance(text, float) and text != text):
        text = ""
    return str(text)[:limit] + "..."


def _stage3_rag_examples(paragraph: str, top_k: int = 3) -> list:
    """Retrieve KB neighbours for the Stage3 prompt; [] if retrieval fails."""
    try:
        rag_hits = retriever.retrieve(paragraph, top_k=top_k)
    except Exception as e:
        # Best-effort: an embedding/FAISS failure degrades to "no examples"
        # instead of aborting the whole classification.
        log_block("RAG ERROR (Stage3 - System)", str(e))
        return []

    rag_examples = []
    for hit in rag_hits:
        meta = hit["meta"]
        rag_examples.append({
            "score": round(hit["score"], 4),
            "Market Segment": meta.get("Market Segment", ""),
            "System Type (General)": meta.get("System Type (General)", ""),
            "System Type (Specific)": meta.get("System Type (Specific)", ""),
            "System Name (General)": meta.get("System Name (General)", ""),
            "System Name (Specific)": meta.get("System Name (Specific)", ""),
            "System Piloting": meta.get("System Piloting", ""),
            "Snippet": _stage3_snippet(meta.get("Description of Contract", ""))
        })
    return rag_examples


@tool("system_classifier")
def system_classifier(paragraph: str):
    """
    STAGE 3 TOOL: SYSTEM CLASSIFIER (RAG + TAXONOMY + EVIDENCE/REASON)

    GOAL:
    Assign system classification fields using:
    ‚úÖ taxonomy.json
    ‚úÖ deterministic rule-book triggers
    ‚úÖ RAG examples from FAISS KB
    ‚úÖ deterministic piloting override

    REQUIRED OUTPUT (FLAT JSON):
    - Market Segment
    - System Type (General)
    - System Type (Specific)
    - System Name (General)
    - System Name (Specific)
    - System Piloting (overridden in python)
    - Confidence

    REQUIRED SUPPORT FIELDS:
    - Evidence fields must match EXACT paragraph text
    - Reason fields are short and logical

    STRICT RULES:
    ‚úÖ Output FLAT JSON ONLY
    ‚úÖ All values must be STRING
    ‚úÖ Evidence must be EXACT substring from paragraph
    ‚úÖ If evidence not found -> "Not Found"
    """
    # The docstring above is the LangChain tool description and is kept verbatim.
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    log_block("HUMAN MESSAGE (Stage3 - System)", paragraph)

    # Inject guidance for every RULE_BOOK category whose trigger substring appears.
    lower_text = paragraph.lower()
    hints = [
        f"RULE: {v['guidance']}"
        for v in RULE_BOOK.values()
        if any(t in lower_text for t in v["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    designators = extract_designators(paragraph)
    piloting_rule = detect_piloting_rule_based(paragraph, designators)

    rag_examples = _stage3_rag_examples(paragraph, top_k=3)

    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

REFERENCE TAXONOMY:
{TAXONOMY_STR}

RULE BOOK OVERRIDES:
{hint_str}

STRICT OUTPUT RULES:
- Return ONLY a FLAT JSON object
- Every value MUST be a STRING
- Evidence MUST be copied EXACTLY from paragraph
- If evidence not found -> "Not Found"

Return JSON:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

DESIGNATORS:
{designators if designators else "None"}

RULE PILOTING OVERRIDE:
{piloting_rule}

RAG EXAMPLES:
{json.dumps(rag_examples, indent=2)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage3 - System)", json.dumps(result, indent=2))

        # Deterministic piloting always overrides the LLM's answer.
        result["System Piloting"] = piloting_rule
        result.setdefault("System Piloting Evidence", "Not Found")
        result.setdefault("System Piloting Reason", "Derived using deterministic piloting rules.")

        # Enforce the flat, all-string contract: nested values and non-str
        # scalars are stringified, and JSON null becomes "".
        for k, v in list(result.items()):
            if isinstance(v, str):
                continue
            result[k] = "" if v is None else str(v)

        return result

    except Exception as e:
        log_block("AI ERROR (Stage3 - System)", str(e))
        return {
            "Market Segment": "",
            "Market Segment Evidence": "Not Found",
            "Market Segment Reason": "",

            "System Type (General)": "",
            "System Type (General) Evidence": "Not Found",
            "System Type (General) Reason": "",

            "System Type (Specific)": "",
            "System Type (Specific) Evidence": "Not Found",
            "System Type (Specific) Reason": "",

            "System Name (General)": "",
            "System Name (General) Evidence": "Not Found",
            "System Name (General) Reason": "",

            "System Name (Specific)": "",
            "System Name (Specific) Evidence": "Not Found",
            "System Name (Specific) Reason": "",

            "System Piloting": piloting_rule,
            "System Piloting Evidence": "Not Found",
            "System Piloting Reason": "Derived using deterministic piloting rules (fallback).",

            "Confidence": "Low",
            "Error": str(e)
        }


# Pydantic argument schema mirroring contract_extractor's parameters.
# NOTE(review): the @tool("contract_extractor") decorator that follows does not
# pass args_schema=ContractInfoInput, so these field descriptions may be unused — confirm.
class ContractInfoInput(BaseModel):
    # Raw contract paragraph that the supplier/program/financial fields are extracted from.
    paragraph: str = Field(description="Full contract paragraph/description text.")
    # Signing date string; contract_extractor's docstring says signing month/year derive from it.
    contract_date: str = Field(description="Contract date string.")

@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """
    STAGE 4 TOOL: CONTRACT EXTRACTOR (Financial + Program)

    GOAL:
    Extract supplier/program/financial fields in a structured JSON output.

    REQUIRED OUTPUT KEYS:
    - raw_supplier_name
    - program_type
    - value_million_raw
    - currency_code
    - value_certainty
    - quantity
    - completion_date_text
    - g2g_b2g
    - value_note

    RULES:
    - Supplier name must come EXACTLY from paragraph
    - program_type must be one of:
       Procurement, Training, MRO/Support, RDT&E, Upgrade, Other Service
    - g2g_b2g: if FMS -> G2G else B2G
    - quantity numeric else "Not Applicable"

    POST-PROCESSING:
    - supplier standardized using suppliers.json
    - program_type mapped case-insensitively onto the canonical list above
    - blank/null LLM answers fall back to the documented column defaults
    - values formatted to 3 decimals ("0.000" when not parseable)
    - signing month/year derived from contract_date ("Unknown" when not parseable)

    RETURNS:
    - dict of output columns, or {"Error": "<message>"} when the LLM call,
      JSON parsing, or payload shape check fails
    """
    # Program types are comma-separated so the model cannot read "MRO/Support"
    # as two separate types; calculate_mro_months keys on the exact "MRO/Support".
    system_instruction = """
You are a Defense Contract Financial Analyst.

Return JSON ONLY.

Rules:
1) raw_supplier_name: extract exact supplier from paragraph text
2) program_type: exactly one of Procurement, Training, MRO/Support, RDT&E, Upgrade, Other Service
3) value_certainty: Confirmed vs Estimated
4) quantity: numeric if found else Not Applicable
5) g2g_b2g: G2G only if FMS mentioned else B2G
6) completion_date_text: end date if mentioned
7) value_million_raw: numeric only
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

SIGNED DATE:
{contract_date}

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    log_block("HUMAN MESSAGE (Stage4 - Contract)", paragraph)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage4 - Contract)", json.dumps(raw, indent=2))
    except Exception as e:
        log_block("AI ERROR (Stage4 - Contract)", str(e))
        return {"Error": str(e)}

    # json_object mode should yield an object; guard anyway so the `.get`
    # calls below cannot raise AttributeError on a list/str payload.
    if not isinstance(raw, dict):
        err = f"Expected a JSON object from the model, got {type(raw).__name__}."
        log_block("AI ERROR (Stage4 - Contract)", err)
        return {"Error": err}

    def _field(key, default):
        # The prompt template pre-fills every key with "", so the model often
        # echoes "" (or null) for fields it cannot find. dict.get's default only
        # applies to *missing* keys, so blank answers are treated as missing too.
        value = raw.get(key)
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    final_supplier = get_best_supplier_match(raw.get("raw_supplier_name"))

    # Canonical program types, keyed by lower-case text with all whitespace
    # removed. "MRO" / "Support" alone are mapped to "MRO/Support" because
    # earlier prompt wording listed them with "/" separators.
    program_type_aliases = {
        "procurement": "Procurement",
        "training": "Training",
        "mro/support": "MRO/Support",
        "mro": "MRO/Support",
        "support": "MRO/Support",
        "rdt&e": "RDT&E",
        "upgrade": "Upgrade",
        "otherservice": "Other Service",
    }
    prog_type = _field("program_type", "Unknown")
    prog_key = "".join(str(prog_type).split()).lower()
    prog_type = program_type_aliases.get(prog_key, prog_type)

    mro_months = calculate_mro_months(contract_date, raw.get("completion_date_text"), prog_type)

    try:
        val_str = str(raw.get("value_million_raw", "0")).replace(",", "").replace("$", "")
        val_formatted = "{:.3f}".format(float(val_str))
    except (TypeError, ValueError):
        val_formatted = "0.000"

    try:
        dt = pd.to_datetime(contract_date)
        sign_month = dt.strftime("%B")
        sign_year = str(dt.year)
    except Exception:
        # Best-effort: unparseable dates (including NaT) become "Unknown".
        sign_month, sign_year = "Unknown", "Unknown"

    return {
        "Supplier Name": final_supplier,
        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": _field("quantity", "Not Applicable"),
        "Value Certainty": _field("value_certainty", "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": _field("currency_code", "USD$"),
        # NOTE(review): no currency conversion happens here, so this column
        # equals "Value (Million)" even for non-USD currencies. Confirm intent.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": _field("value_note", "Not Applicable"),
        "G2G/B2G": _field("g2g_b2g", "B2G"),
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


class SplitterInput(BaseModel):
    # Argument schema mirroring the parameters of the `splitter_agent` tool.
    # `...` marks each field as required (no default value).
    paragraph: str = Field(..., description="Full contract paragraph/description text.")
    base_row: dict = Field(..., description="Final extracted row after Stage1-4.")

@tool("splitter_agent")
def splitter_agent(paragraph: str, base_row: dict):
    """
    STAGE 5 TOOL: SPLITTER AGENT (DETERMINISTIC)

    GOAL:
    Delegate to ``split_rows_engine`` to turn one extracted row into one or
    more rows. It is intended to handle:
    - Multi-operator allocations (Navy/Air Force/etc)
    - FMS allocations (G2G)
    - Multi-country FMS list

    Every returned row carries the four split columns. Missing ones are
    defaulted: Split Flag "No", Split Reason "", and "Not Found" for both
    Split Evidence and Split Line Item.

    On any exception, a single copy of ``base_row`` is returned instead. That
    copy has Split Flag "Error" and the exception text in Split Reason.

    OUTPUT:
    {"rows": [row1, row2, ...]}
    """
    split_defaults = (
        ("Split Flag", "No"),
        ("Split Reason", ""),
        ("Split Evidence", "Not Found"),
        ("Split Line Item", "Not Found"),
    )
    try:
        produced = split_rows_engine(base_row, paragraph)
        for produced_row in produced:
            for column, default in split_defaults:
                produced_row.setdefault(column, default)
    except Exception as exc:
        failed_row = base_row.copy()
        failed_row.update({
            "Split Flag": "Error",
            "Split Reason": f"Split failed: {str(exc)}",
            "Split Evidence": "Not Found",
            "Split Line Item": "Not Found",
        })
        return {"rows": [failed_row]}
    return {"rows": produced}


# ==============================================================================
# 9) STAGE 6 QUALITY VALIDATOR (RULE-BASED + LIGHT AUTO-FIX)
# ==============================================================================

def _contains_any(text: str, keywords: List[str]) -> bool:
    """Returns True if any keyword appears in text (case-insensitive)."""
    t = str(text or "").lower()
    return any(k.lower() in t for k in keywords)

def _supplier_appears_in_paragraph(supplier: str, paragraph: str) -> bool:
    """
    Checks if standardized supplier appears in paragraph text.

    WHY:
    - After standardization, supplier text might not match exactly
      but partial match should still be detected.

    RULE:
    - If supplier == Unknown -> return False
    """
    supplier = str(supplier or "").strip()
    if not supplier or supplier.lower() == "unknown":
        return False

    p = str(paragraph or "").lower()
    s = supplier.lower()

    # try direct match
    if s in p:
        return True

    # try partial token match (first 2 words)
    parts = s.split()
    if len(parts) >= 2:
        key = " ".join(parts[:2])
        return key in p

    return False

def validate_single_row(row: dict) -> dict:
    """
    VALIDATOR FUNCTION (RULE-BASED)

    GOAL:
    Validate one extracted output row for common errors:
    1) Supplier mismatch (supplier not present in paragraph)
    2) Piloting mismatch (row disagrees with deterministic piloting rules)
    3) G2G/B2G mismatch (FMS mentioned but row not marked G2G)
    4) Value mismatch (Value missing/zero but paragraph has $)
    5) Segment sanity (radar keyword but Market Segment not C4ISR) - warn only

    OUTPUT:
    Returns a copy of ``row`` (the input dict is not mutated) with:
    - Validation Flag: PASS / FAIL / WARN
      (FAIL when the supplier or value check fails, otherwise WARN)
    - Validation Issues: " | "-joined summary, or "No issues detected."
    - AutoFix Applied: Yes/No
    - AutoFix Notes: " | "-joined notes, or "Not Applicable"

    AUTOFIX POLICY (SAFE FIXES ONLY):
    - G2G/B2G set to G2G if the paragraph mentions FMS
    - System Piloting overridden when the deterministic rules give a definite answer
    - Supplier Name filled from the "<Supplier>, <City>, ... is awarded"
      sentence when it is Unknown/blank
    - Market Segment / System Name are never modified (too risky)
    - Value Certainty is NOT auto-fixed.
    """
    # Plain `dict` annotations: this cell's visible imports do not provide
    # typing.Dict / typing.Any.
    updated = row.copy()
    issues = []
    autofix_notes = []
    autofix_applied = False

    paragraph = updated.get("Description of Contract", "")
    supplier = updated.get("Supplier Name", "Unknown")

    # 1) Supplier mismatch
    if supplier and supplier != "Unknown":
        if not _supplier_appears_in_paragraph(supplier, paragraph):
            issues.append("Supplier Name not found in paragraph (possible wrong supplier extraction).")

    # 2) Piloting mismatch
    designators = extract_designators(paragraph)
    piloting_expected = detect_piloting_rule_based(paragraph, designators)
    piloting_current = updated.get("System Piloting", "Not Applicable")

    if piloting_expected != "Not Applicable":
        if piloting_current != piloting_expected:
            issues.append(f"System Piloting mismatch (expected={piloting_expected}, got={piloting_current}).")
            updated["System Piloting"] = piloting_expected
            updated["System Piloting Reason"] = "Validator override: deterministic piloting rules applied."
            autofix_applied = True
            autofix_notes.append("Corrected System Piloting using deterministic rules.")

    # 3) FMS mismatch
    if _contains_any(paragraph, ["fms", "foreign military sales"]):
        current_g2g = updated.get("G2G/B2G", "B2G")
        if current_g2g != "G2G":
            # NOTE(review): rows produced by the Stage5 splitter presumably share
            # the same "Description of Contract", so this also forces G2G onto
            # split rows for domestic (non-FMS) operators. Confirm whether split
            # rows should be exempt.
            issues.append(f"Paragraph mentions FMS but row marked as {current_g2g or 'blank'}.")
            updated["G2G/B2G"] = "G2G"
            autofix_applied = True
            autofix_notes.append("Fixed G2G/B2G to G2G due to FMS keyword.")

    # 4) Value mismatch (simple)
    if "$" in str(paragraph) and str(updated.get("Value (Million)", "")).strip() in ["", "0.000", "0", "Not Applicable"]:
        issues.append("Paragraph contains $ value but Value (Million) is missing/zero.")

    # 5) Segment sanity hint (warn only)
    if _contains_any(paragraph, ["radar", "an/apy", "an/tpy"]) and updated.get("Market Segment", ""):
        if updated.get("Market Segment") not in ["C4ISR Systems", "C4ISR"]:
            issues.append("Radar keyword detected but Market Segment not C4ISR (check classification).")

    # 6) Supplier unknown autofix (no issue is recorded; the fill is noted only)
    if supplier in ["Unknown", "", None]:
        fallback_supplier = _extract_supplier_from_description(paragraph)
        if fallback_supplier != "Unknown":
            updated["Supplier Name"] = get_best_supplier_match(fallback_supplier)
            autofix_applied = True
            autofix_notes.append("Filled Supplier Name from DoD sentence fallback extraction.")

    updated["AutoFix Applied"] = "Yes" if autofix_applied else "No"
    updated["AutoFix Notes"] = " | ".join(autofix_notes) if autofix_notes else "Not Applicable"

    if not issues:
        updated["Validation Flag"] = "PASS"
        updated["Validation Issues"] = "No issues detected."
    else:
        # FAIL if supplier mismatch or value mismatch; every other issue is WARN
        severe = any(
            "Supplier Name not found" in x or "Value (Million) is missing" in x
            for x in issues
        )
        updated["Validation Flag"] = "FAIL" if severe else "WARN"
        updated["Validation Issues"] = " | ".join(issues)

    return updated


class ValidatorInput(BaseModel):
    # Argument schema mirroring the `rows` parameter of `quality_validator_agent`.
    # `...` marks the field as required (no default value).
    rows: list = Field(..., description="List of extracted rows AFTER splitting. Each row is a dict.")

@tool("quality_validator_agent")
def quality_validator_agent(rows: list):
    """
    STAGE 6 TOOL: QUALITY VALIDATOR AGENT

    Runs the rule-based ``validate_single_row`` over every row produced by the
    Stage5 splitter. Input order is preserved.

    INPUT:
    - rows: list of extracted row dicts

    OUTPUT:
    - {"validated_rows": [row1, row2, ...]}

    THIS TOOL:
    - Detects likely extraction errors
    - Applies SAFE auto-fixes only
    - Adds validation columns

    IT MUST NEVER:
    - Re-run system classification with LLM (expensive + unstable)
    - Override Market Segment/System Name based only on heuristic
    """
    checked_rows = [validate_single_row(candidate) for candidate in rows]
    return {"validated_rows": checked_rows}


# ==============================================================================
# 10) LANGGRAPH PIPELINE
# ==============================================================================

class AgentState(TypedDict):
    """
    Shared state passed between the LangGraph pipeline nodes (Stage1 -> Stage6).

    INPUTS (set by the caller before app.invoke):
    - input_text: contract paragraph
    - input_date: contract date string
    - input_url: source url

    INTERNAL:
    - final_data: flat dict merged by Stage1-4 (each stage overlays its fields)
    - final_rows: list of rows produced by the Stage5 splitter

    FINAL:
    - validated_rows: list after Stage6 quality validation

    messages:
    - reserved for message tracing; none of the stage nodes writes it
    """
    # Contract paragraph for the current input row.
    input_text: str
    # Contract date string; passed to Stage1 (as "date") and Stage4.
    input_date: str
    # Source URL; passed to Stage1.
    input_url: str
    # Output columns accumulated by Stage1-4; each stage returns a merged copy.
    final_data: dict
    # Rows returned by Stage5 (splitter).
    final_rows: list
    # Rows returned by Stage6 (one validated row per Stage5 row).
    validated_rows: list
    # Annotated with LangGraph's add_messages reducer, so updates are merged
    # into the list rather than replacing it.
    messages: Annotated[List[AnyMessage], add_messages]


def stage_1_sourcing(state: AgentState):
    """
    LangGraph node for Stage1 (sourcing).

    Calls the ``sourcing_extractor`` tool with the row's paragraph, URL and
    date. Returns a copy of ``final_data`` with the tool's fields merged on
    top; tool values win on key clashes.
    """
    tool_input = {
        "paragraph": state["input_text"],
        "url": state["input_url"],
        "date": state["input_date"],
    }
    sourcing_fields = sourcing_extractor.invoke(tool_input)

    merged = state.get("final_data", {}).copy()
    merged.update(sourcing_fields)
    return {"final_data": merged}


def stage_2_geography(state: AgentState):
    """
    LangGraph node for Stage2 (geography).

    Calls ``geography_extractor`` on the paragraph. Returns a copy of
    ``final_data`` with the geography fields merged on top.
    """
    geo_fields = geography_extractor.invoke({"paragraph": state["input_text"]})

    merged = state.get("final_data", {}).copy()
    merged.update(geo_fields)
    return {"final_data": merged}


def stage_3_system(state: AgentState):
    """
    LangGraph node for Stage3 (system classification).

    Calls ``system_classifier`` on the paragraph. Returns a copy of
    ``final_data`` with the classification fields merged on top.
    """
    system_fields = system_classifier.invoke({"paragraph": state["input_text"]})

    merged = state.get("final_data", {}).copy()
    merged.update(system_fields)
    return {"final_data": merged}


def stage_4_contract(state: AgentState):
    """
    LangGraph node for Stage4 (contract / financial fields).

    Calls ``contract_extractor`` with the paragraph and contract date.
    Returns a copy of ``final_data`` with the contract fields merged on top.
    """
    tool_input = {
        "paragraph": state["input_text"],
        "contract_date": state["input_date"],
    }
    contract_fields = contract_extractor.invoke(tool_input)

    merged = state.get("final_data", {}).copy()
    merged.update(contract_fields)
    return {"final_data": merged}


def stage_5_split(state: AgentState):
    """
    LangGraph node for Stage5 (row splitting).

    Sends the merged Stage1-4 row to ``splitter_agent``. Stores its "rows"
    as ``final_rows``, falling back to ``[final_data]`` if the key is absent.
    """
    tool_input = {
        "paragraph": state["input_text"],
        "base_row": state["final_data"],
    }
    split_result = splitter_agent.invoke(tool_input)
    fallback_rows = [state["final_data"]]
    return {"final_rows": split_result.get("rows", fallback_rows)}


def stage_6_validate(state: AgentState):
    """
    LangGraph node for Stage6 (quality validation).

    Runs ``quality_validator_agent`` over the Stage5 rows. If the tool result
    lacks "validated_rows", the unvalidated Stage5 rows are passed through.

    OUTPUT:
    - validated_rows: list of post-validated rows
    """
    stage5_rows = state.get("final_rows", [])
    validation = quality_validator_agent.invoke({"rows": stage5_rows})
    passthrough = state.get("final_rows", [])
    return {"validated_rows": validation.get("validated_rows", passthrough)}


# Linear pipeline: START -> Stage1 -> ... -> Stage6 -> END.
workflow = StateGraph(AgentState)

_STAGE_NODES = (
    ("Stage1", stage_1_sourcing),
    ("Stage2", stage_2_geography),
    ("Stage3", stage_3_system),
    ("Stage4", stage_4_contract),
    ("Stage5", stage_5_split),
    ("Stage6", stage_6_validate),
)
for _node_name, _node_fn in _STAGE_NODES:
    workflow.add_node(_node_name, _node_fn)

# Chain consecutive entries of [START, Stage1..Stage6, END] with edges.
_edge_chain = [START, *(name for name, _ in _STAGE_NODES), END]
for _src, _dst in zip(_edge_chain, _edge_chain[1:]):
    workflow.add_edge(_src, _dst)

# Drop the temporary helper names so only `workflow` and `app` remain.
del _STAGE_NODES, _node_name, _node_fn, _edge_chain, _src, _dst

app = workflow.compile()


# ==============================================================================
# 11) GRAPH VISUALIZATION
# ==============================================================================

def export_workflow_mermaid(app_obj, out_file="workflow.mmd"):
    """
    Write the compiled LangGraph workflow as a Mermaid diagram file (.mmd).

    The diagram text comes from ``app_obj.get_graph().draw_mermaid()`` and is
    written locally as UTF-8. No external rendering service (e.g. mermaid.ink)
    is contacted.

    Returns:
        The path that was written (``out_file``).
    """
    graph_view = app_obj.get_graph()
    mermaid_source = graph_view.draw_mermaid()
    with open(out_file, "w", encoding="utf-8") as mmd_handle:
        mmd_handle.write(mermaid_source)
    print(f"‚úÖ Workflow Mermaid saved locally: {out_file}")
    return out_file


# ==============================================================================
# 12) EXCEL HIGHLIGHTING FEATURE
# ==============================================================================

def highlight_evidence_reason_columns(excel_path: str):
    """
    Highlight the Evidence and Reason columns of an Excel file in place.

    Columns are selected by header text (row 1):
    - header contains "Evidence" -> yellow fill (FFF2CC)
    - header contains "Reason"   -> blue fill   (D9E1F2)
    Matching header cells are also made bold. If a header contains both
    words, the Reason fill wins because it is applied last.

    WHY:
    - Makes validation easy for reviewers
    - Improves interpretability

    NOTE:
    - Only the workbook's active sheet is modified; the file is overwritten.
    """
    # Imported locally so this function does not depend on openpyxl having
    # been imported at the top of the current notebook cell.
    from openpyxl import load_workbook
    from openpyxl.styles import PatternFill, Font

    wb = load_workbook(excel_path)
    ws = wb.active

    header = [cell.value for cell in ws[1]]
    evidence_cols = [
        idx for idx, name in enumerate(header, start=1)
        if isinstance(name, str) and "Evidence" in name
    ]
    reason_cols = [
        idx for idx, name in enumerate(header, start=1)
        if isinstance(name, str) and "Reason" in name
    ]

    evidence_fill = PatternFill(start_color="FFF2CC", end_color="FFF2CC", fill_type="solid")
    reason_fill = PatternFill(start_color="D9E1F2", end_color="D9E1F2", fill_type="solid")
    header_font = Font(bold=True)

    # Header row: coloured and bold.
    for col_idx in evidence_cols:
        ws.cell(row=1, column=col_idx).fill = evidence_fill
        ws.cell(row=1, column=col_idx).font = header_font

    for col_idx in reason_cols:
        ws.cell(row=1, column=col_idx).fill = reason_fill
        ws.cell(row=1, column=col_idx).font = header_font

    # Data rows: coloured only.
    for row in range(2, ws.max_row + 1):
        for col_idx in evidence_cols:
            ws.cell(row=row, column=col_idx).fill = evidence_fill
        for col_idx in reason_cols:
            ws.cell(row=row, column=col_idx).fill = reason_fill

    wb.save(excel_path)
    print("‚úÖ Evidence + Reason columns highlighted successfully.")


# ==============================================================================
# 13) MAIN EXECUTION
# ==============================================================================

if __name__ == "__main__":
    import traceback

    print(f"\nüìå Loading Input File: {INPUT_EXCEL_PATH}")

    # Offline-safe workflow graph output
    export_workflow_mermaid(app, out_file="workflow.mmd")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        if not all(col in df_input.columns for col in required_cols):
            raise ValueError(f"Excel file must contain columns: {required_cols}")

        print(f"üöÄ Processing {len(df_input)} rows...")
        results = []

        for index, row in df_input.iterrows():
            print(f"\nüîπ Row {index + 1}/{len(df_input)}")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "final_rows": [],
                "validated_rows": [],
                "messages": []
            }

            # Isolate failures per input row: one unexpected exception must not
            # discard every row already processed. The failed row is kept in
            # the output, flagged FAIL, so it can be reviewed or re-run.
            try:
                output_state = app.invoke(initial_state)
            except Exception as row_err:
                log_block(f"PIPELINE ERROR (Row {index + 1})", traceback.format_exc())
                results.append({
                    "Description of Contract": desc,
                    "Source Link(s)": c_url,
                    "Contract Date": c_date,
                    "Validation Flag": "FAIL",
                    "Validation Issues": f"Pipeline error: {row_err}",
                })
                continue

            # Prefer validated rows, then raw split rows, then the merged row.
            rows = output_state.get("validated_rows", [])
            if not rows:
                rows = output_state.get("final_rows", [])
            if not rows:
                rows = [output_state.get("final_data", {})]

            results.extend(rows)

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Split Flag", "Split Reason", "Split Evidence", "Split Line Item",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            # Stage 6 validation columns
            "Validation Flag", "Validation Issues",
            "AutoFix Applied", "AutoFix Notes",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # Fixed output column order; columns no row produced are filled with "".
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")
        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        highlight_evidence_reason_columns(OUTPUT_EXCEL_PATH)

        print("\n‚úÖ Processing Complete!")
        print(f"üíæ Output File Saved: {OUTPUT_EXCEL_PATH}")
        print(df_final.head(3).to_string(index=False))

    except Exception as e:
        print(f"\n‚ùå ERROR: {e}")
        # Keep the traceback; a one-line message alone hides where it failed.
        traceback.print_exc()


In [None]:
import os
import re
import json
import difflib
import pickle
import datetime
from typing import Annotated, TypedDict, List, Dict, Any, Optional

import pandas as pd
import faiss
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass

# LangGraph / LangChain
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI

# Excel formatting
from openpyxl import load_workbook
from openpyxl.styles import PatternFill, Font


# ==============================================================================
# 0) DEBUG LOGGING HELPERS
# ==============================================================================

def log_block(title: str, content: str):
    """
    Print a titled debug section framed by two 130-character "=" rules.

    Printed lines: an empty line, the rule, the title, the rule again, then
    the content. Used to keep per-stage prompts, responses and errors visually
    separated when debugging row by row.
    """
    rule = "=" * 130
    print("\n" + rule, title, rule, content, sep="\n")


# ==============================================================================
# 1) RAG RETRIEVER (FAISS + Metadata)
# ==============================================================================

class SystemKBRetriever:
    """
    RAG Retriever for Defense System Classification.

    PURPOSE:
    - Loads FAISS index and metadata from kb_dir
    - Retrieves top-k similar labeled examples

    WHY:
    - Makes system classification consistent against your labeled KB dataset
    - Reduces hallucination and random taxonomy mapping

    EXPECTED FILES:
    - system_kb.faiss
    - system_kb_meta.pkl

    OUTPUT:
    retrieve(query_text) returns:
      [
        {"score": float, "meta": dict_of_kb_row},
        ...
      ]
    """

    def __init__(self, kb_dir: str, embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initialize the retriever:
        - loads the FAISS index
        - loads the metadata rows (indexed by FAISS row id)
        - defers loading the sentence-transformer embedder until first use

        Raises:
            FileNotFoundError: if either KB file is missing from kb_dir.
        """
        self.kb_dir = kb_dir
        self.embed_model = embed_model

        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build the KB first before running this pipeline."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        # SECURITY: pickle.load can execute arbitrary code. Only point kb_dir
        # at a KB that you built yourself or otherwise trust.
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded rows: {len(self.meta)}")
        self.embedder = None

    def _lazy_load_embedder(self):
        """
        Load the embedding model on first use only.

        WHY:
        - Faster script startup
        - Saves memory until retrieval is actually requested
        """
        if self.embedder is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer(self.embed_model)

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Retrieve the top-k nearest KB rows for a paragraph.

        INPUT:
        - query_text: paragraph string (None / blank -> no results)
        - top_k: number of matches (<= 0 -> no results)

        OUTPUT:
        - list of {"score": float, "meta": kb_row}; FAISS placeholder hits
          (id -1, returned when fewer than top_k vectors exist) are skipped
        """
        # Without this check, None would be embedded as the literal text "None".
        query_text = "" if query_text is None else str(query_text).strip()
        # FAISS requires k > 0, so short-circuit instead of raising inside search.
        if not query_text or top_k <= 0:
            return []

        self._lazy_load_embedder()

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            if idx < 0:
                continue
            results.append({"score": float(score), "meta": self.meta[idx]})
        return results


# ==============================================================================
# 2) CONFIGURATION & FILE PATHS
# ==============================================================================

# Each path can be overridden through an environment variable, so the
# notebook runs on machines other than the original author's. The defaults
# are the original hard-coded values.
TAXONOMY_PATH = os.environ.get(
    "DEFENSE_TAXONOMY_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\taxonomy.json",
)
SUPPLIERS_PATH = os.environ.get(
    "DEFENSE_SUPPLIERS_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\suppliers.json",
)
INPUT_EXCEL_PATH = os.environ.get(
    "DEFENSE_INPUT_EXCEL_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\source_file.xlsx",
)
OUTPUT_EXCEL_PATH = os.environ.get("DEFENSE_OUTPUT_EXCEL_PATH", "Processed_Defense_Data.xlsx")
RAG_KB_DIR = os.environ.get(
    "DEFENSE_RAG_KB_DIR",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\system_kb_store",
)


# ==============================================================================
# 3) SETUP LLM CLIENT
# ==============================================================================

# Ask for the LLM Foundry API key interactively only when LLMFOUNDRY_TOKEN is
# not already set; the entered value is stored in os.environ for this process.
if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

# OpenAI-compatible client pointed at the LLM Foundry gateway. The api_key is
# built as "<token>:my-test-project".
# NOTE(review): the project suffix is hard-coded; confirm it is the intended
# project before production use.
client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
    base_url="https://llmfoundry.straive.com/openai/v1/",
)

# Loaded eagerly when this cell runs. SystemKBRetriever raises
# FileNotFoundError if the KB files are missing from RAG_KB_DIR.
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)


# ==============================================================================
# 4) LOAD JSON HELPERS
# ==============================================================================

def load_json_file(filename, default_value):
    """
    Safe JSON loader.

    Returns the parsed content of ``filename`` (read as UTF-8). On any
    failure (missing file, unreadable file, invalid JSON) it prints a warning
    and returns ``default_value`` unchanged.

    WHY:
    - taxonomy.json / suppliers.json are required for extraction + normalization
    - pipeline should not crash if missing; should fallback
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        # Deliberate best-effort fallback: any load/parse problem -> default.
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value
    # Report success only after the file has actually been parsed; previously
    # an invalid file logged both "Loaded" and the warning.
    print(f"‚úÖ Loaded: {filename}")
    return data


# Taxonomy: the parsed dict ({} if loading fails) and a compact JSON string
# copy of it (no whitespace after "," or ":").
raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(",", ":"))

# Canonical supplier names used by get_best_supplier_match. The tuple below is
# only the fallback used when suppliers.json cannot be loaded.
_DEFAULT_SUPPLIERS = (
    "Dell Inc",
    "Boeing",
    "Lockheed Martin",
    "Raytheon Technologies",
    "Northrop Grumman",
    "L3Harris",
    "BAE Systems",
    "General Dynamics",
)
SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, list(_DEFAULT_SUPPLIERS))


# ==============================================================================
# 5) RULE BOOK + GEOGRAPHY
# ==============================================================================

# Keyword rules: each entry pairs lower-case trigger substrings with
# classification guidance text. Leading spaces on some triggers (e.g. " 5.56",
# " round") stop them firing inside longer tokens; "round" alone matched
# "ground", "around" and "background".
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'"
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'"
    },
    "ammunition": {
        "triggers": ["cartridge", " round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'"
    }
}

# Region -> country / organisation names. get_region_for_country compares
# against these case-insensitively and by exact (whole-string) equality.
GEOGRAPHY_MAPPING = {
    "North America": ["USA", "United States", "US", "United States of America", "Canada", "America"],
    "Europe": ["UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain", "Poland", "Netherlands",
               "Norway", "Sweden", "Finland", "Denmark", "Belgium"],
    "Asia-Pacific": ["Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand"],
    "Middle East and North Africa": ["Israel", "Saudi Arabia", "UAE", "United Arab Emirates", "Egypt", "Qatar", "Kuwait", "Iraq"],
    "International Organisations": ["NATO", "EU", "IFU", "UN", "NSPA"]
}


# ==============================================================================
# 6) BASE HELPERS (Supplier + Dates + Region + Designators)
# ==============================================================================

def get_best_supplier_match(extracted_name: str, suppliers: Optional[List[str]] = None):
    """
    Standardize a supplier name against the canonical supplier list.

    STRATEGY:
    1) Unknown/blank/"n/a"/"not applicable" -> "Unknown"
    2) Exact match ignoring case -> canonical spelling
    3) Fuzzy match with difflib (cutoff 0.6), also case-insensitive, so
       "LOCKHEED MARTIN CORP" still maps to "Lockheed Martin"
    4) Otherwise, the stripped extracted string is returned unchanged

    PARAMS:
    - extracted_name: raw supplier text (any type; converted with str)
    - suppliers: canonical names to match against; defaults to SUPPLIER_LIST
    """
    if not extracted_name or str(extracted_name).strip().lower() in ["unknown", "n/a", "not applicable", ""]:
        return "Unknown"

    candidates = SUPPLIER_LIST if suppliers is None else suppliers
    clean_name = str(extracted_name).strip()
    lookup_key = clean_name.lower()

    # lower-cased name -> canonical spelling
    supplier_map = {str(s).lower(): s for s in candidates}

    if lookup_key in supplier_map:
        return supplier_map[lookup_key]

    # Compare in lower case: difflib's ratio is case-sensitive, so mixed-case
    # inputs would otherwise fall below the cutoff.
    matches = difflib.get_close_matches(lookup_key, list(supplier_map), n=1, cutoff=0.6)
    return supplier_map[matches[0]] if matches else clean_name


def calculate_mro_months(start_date_str, end_date_text, program_type):
    """
    Calculate the expected MRO contract duration in whole months.

    RULES:
    - Only computed for the "MRO/Support" program type. The comparison is
      case-insensitive and ignores spaces, so "mro / support" also counts.
      Any other program type -> "Not Applicable".
    - Missing start date or end-date text -> "Not Applicable"
    - Start date parsed with pandas (day-first); end date parsed fuzzily from
      free text with dateutil
    - Leftover days are ignored; negative durations are clamped to "0"
    - Any parse/compare failure -> "Not Applicable"

    RETURNS:
    - str: number of months, or "Not Applicable"
    """
    if "".join(str(program_type or "").split()).lower() != "mro/support":
        return "Not Applicable"
    try:
        if not start_date_str or not end_date_text:
            return "Not Applicable"

        start = pd.to_datetime(start_date_str, dayfirst=True)
        end = parser.parse(str(end_date_text), fuzzy=True)

        diff = relativedelta(end, start)
        total_months = diff.years * 12 + diff.months
        return str(max(0, int(total_months)))
    except Exception:
        # Best-effort: unparseable or incomparable dates mean "unknown duration".
        # (A bare `except:` would also swallow KeyboardInterrupt/SystemExit.)
        return "Not Applicable"


def get_region_for_country(country_name):
    """
    Map a country (or organisation) name to its region bucket.

    Matching is case-insensitive and ignores periods, so dotted abbreviations
    such as "U.S.A." or "U.A.E." match their undotted forms.

    OUTPUT:
    - region string from GEOGRAPHY_MAPPING, or "Unknown"
    """
    if not country_name or str(country_name).strip().lower() in ["unknown", "n/a", "not applicable", ""]:
        return "Unknown"

    # Drop periods so "U.S.A." == "usa" and "U.K." == "uk".
    clean = str(country_name).strip().lower().replace(".", "")

    if clean in ["us", "usa", "united states", "united states of america"]:
        return "North America"
    if clean in ["uk", "britain", "great britain"]:
        return "Europe"

    for region, countries in GEOGRAPHY_MAPPING.items():
        if any(c.lower() == clean for c in countries):
            return region

    return "Unknown"


# Regexes for common defense platform/equipment designators. extract_designators
# applies them with re.findall(..., flags=re.IGNORECASE), so every pattern must
# be free of capturing groups: findall returns the group, not the whole match.
DESIGNATOR_PATTERNS = [
    # Ship hull numbers, e.g. DDG-51, CVN 78
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    # Aircraft / UAV designators. The optional *upper-case* series letter
    # ((?-i:...) overrides IGNORECASE locally) keeps F-35A / C-130J / MQ-9B.
    # Without it, \b fails before the letter and nothing matches. Plurals such
    # as "F-16s" are still not matched.
    r"\bF-\d+(?-i:[A-Z])?\b", r"\bB-\d+(?-i:[A-Z])?\b",
    r"\bC-\d+(?-i:[A-Z])?\b", r"\bA-\d+(?-i:[A-Z])?\b",
    r"\bMQ-\d+(?-i:[A-Z])?\b", r"\bRQ-\d+(?-i:[A-Z])?\b",
    # Joint electronics type designators, e.g. AN/APY-10
    r"\bAN\/[A-Z0-9\-]+\b",
    # Missiles, e.g. AIM-120D. The group is non-capturing so findall returns
    # "AIM-120D" rather than just "AIM".
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+(?-i:[A-Z])?\b",
]

def extract_designators(text: str):
    """
    Find defense platform designators in *text* using DESIGNATOR_PATTERNS.

    Each pattern is matched case-insensitively. Every hit is upper-cased and
    has its spaces removed ("ddg 51" -> "DDG51"), and "--" is collapsed to
    "-". The result is de-duplicated, keeping the first-seen order.

    EXAMPLES:
    - DDG-51, CVN-78
    - MQ-9, RQ-4
    - AN/APY-10
    """
    haystack = str(text)
    hits = [
        hit
        for pattern in DESIGNATOR_PATTERNS
        for hit in re.findall(pattern, haystack, flags=re.IGNORECASE)
    ]
    normalized = (hit.upper().replace(" ", "").replace("--", "-") for hit in hits)
    # dict keys keep insertion order, so this is an order-preserving de-dup.
    return list(dict.fromkeys(normalized))


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """
    Determines System Piloting using deterministic rules.

    Precedence:
    1) MQ-/RQ- designators, or the keywords "unmanned"/"uav"/"drone"/"autonomous" -> "Uncrewed"
    2) ship-hull designators (DDG, CVN, SSN, LCS, LPD, LHA, LHD) or a "USS <name>" mention -> "Crewed"
    3) otherwise -> "Not Applicable"

    Args:
        text: contract paragraph (any value; converted with str()).
        designators: output of extract_designators(); None is treated as [].

    OUTPUT:
    - "Crewed"
    - "Uncrewed"
    - "Not Applicable"
    """
    t = str(text).lower()
    designators = designators or []

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if any(k in t for k in ["unmanned", "uav", "drone", "autonomous"]):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    # Whole-word match: a bare "uss " substring also fires on "discuss ", "fuss ", ...
    if re.search(r"\buss\s", t):
        return "Crewed"

    return "Not Applicable"


# ==============================================================================
# 7) SPLIT ENGINE (DETERMINISTIC)
# ==============================================================================

def _normalize_spaces(text: str) -> str:
    """Normalizes whitespace for stable regex parsing."""
    return re.sub(r"\s+", " ", str(text or "")).strip()

def _safe_int(value: str):
    """Safely converts string to int; returns None if conversion fails."""
    try:
        return int(str(value).replace(",", "").strip())
    except Exception:
        return None

def _extract_supplier_from_description(paragraph: str) -> str:
    """
    Fallback supplier extraction from the usual DoD announcement lead-in.

    Examples:
    "Raytheon Technologies, Tucson, Arizona, is awarded $X..."
    "The Boeing Co., St. Louis, Missouri, has been awarded $X..."
    "Lockheed Martin Corp., Grand Prairie, Texas, was awarded $X..."

    Extraction:
    - whitespace is normalized first
    - take the text before the first comma, if it is followed by an
      "is / was / has been / have been [being] awarded" lead-in (case-insensitive)

    Returns:
    - the supplier text, or "Unknown" if the pattern is absent or the captured
      name is empty
    """
    p = _normalize_spaces(paragraph)
    m = re.match(
        r"^(.*?),\s+.*?\s+(?:is|was|has been|have been)\s+(?:being\s+)?awarded",
        p,
        flags=re.IGNORECASE,
    )
    if m:
        name = m.group(1).strip()
        # A paragraph starting with "," would otherwise yield an empty supplier.
        if name:
            return name
    return "Unknown"

def parse_fms_countries(paragraph: str):
    """
    Extracts FMS countries from a "governments of ..." clause, e.g.
      "... Foreign Military Sales to the governments of Australia (20%), Japan and the Netherlands."

    Processing:
    - the clause runs from "governments of" to the first ". ", a
      " Work will be performed" / " Fiscal" / " This contract" marker, or end of text
    - parenthetical notes such as "(20%)" are dropped
    - the clause is split on commas, semicolons and the word "and"
    - surrounding whitespace/punctuation (".;:") and a leading "the " are removed
    - names shorter than 3 or longer than 50 characters are discarded
    - duplicates are removed case-insensitively (first spelling kept)

    OUTPUT:
    - list[str] in order of appearance (empty when no clause is found)
    """
    text = str(paragraph)

    m = re.search(
        r"governments of (.+?)(?:\.\s| Work will be performed| Fiscal| This contract|$)",
        text,
        flags=re.IGNORECASE | re.DOTALL
    )
    if not m:
        return []

    # Allocation notes like "(20%)" must not become part of the country name.
    block = re.sub(r"\([^)]*\)", " ", m.group(1))

    final, seen = [], set()
    for piece in re.split(r",|;|\band\b", block):
        # When the clause ends the text it keeps its final period ("Japan.").
        name = piece.strip().strip(".;:").strip()
        # "the Netherlands" -> "Netherlands" so exact region lookups can match.
        if name.lower().startswith("the "):
            name = name[4:].strip()
        if not 2 < len(name) <= 50:
            continue
        key = name.lower()
        if key not in seen:
            seen.add(key)
            final.append(name)

    return final

def parse_operator_allocations(text: str):
    """
    Detects operator allocations like:
    "212 for the Navy"
    "187 for the U.S. Air Force"
    "84 for FMS customers" / "84 for Foreign Military Sales (FMS) customers"

    A quantity is never taken from inside a larger number or a dollar amount
    ("$5,000,000 for the Army" is money, not units).

    OUTPUT (operator matches first, then FMS matches; exact duplicates removed):
      [
        {"operator": "Navy", "qty": 212, "g2g_b2g": "B2G"},
        {"operator": "Foreign Assistance", "qty": 84, "g2g_b2g": "G2G"}
      ]
    """
    if not text:
        return []

    txt = str(text).lower()
    allocations = []

    # Look-behind: the quantity must not continue a number ("1234567") or follow "$".
    qty_re = r"(?<![\d,.$])(\d{1,6}(?:,\d{3})*)"

    ops = re.findall(
        qty_re + r"\s+for\s+(?:the\s+)?(?:u\.s\.\s+)?(navy|air force|army|marine corps)",
        txt
    )
    for qty_word, op in ops:
        op_norm = "Air Force" if op == "air force" else op.title()
        # The regex only admits digits and thousands separators, so int() cannot fail.
        allocations.append({"operator": op_norm, "qty": int(qty_word.replace(",", "")), "g2g_b2g": "B2G"})

    # Covers "fms", "fms customers", "foreign military sales (fms) customers",
    # and "foreign military sales customers".
    fms = re.findall(qty_re + r"\s+for\s+(?:foreign military sales|fms)", txt)
    for qty_word in fms:
        allocations.append(
            {"operator": "Foreign Assistance", "qty": int(qty_word.replace(",", "")), "g2g_b2g": "G2G"}
        )

    uniq, seen = [], set()
    for a in allocations:
        key = (a["operator"], a["qty"], a["g2g_b2g"])
        if key not in seen:
            uniq.append(a)
            seen.add(key)

    return uniq

def split_rows_engine(base_row: dict, paragraph: str):
    """
    Split engine that expands rows when multi-operator allocations exist.

    Splits supported:
    ✅ Navy/Air Force/Army/Marine allocations (one B2G row per allocation)
    ✅ FMS allocations (G2G)
    ✅ FMS countries -> 1 row per country (only when an FMS allocation is present)

    Adds:
    - Split Flag
    - Split Reason
    - Split Evidence
    - Split Line Item

    The caller's dict is not mutated; every returned row is a shallow copy.
    If Supplier Name is empty or "Unknown", a name is parsed from the paragraph
    lead-in and standardized; when none is found the field is set to "Unknown".
    """
    paragraph = _normalize_spaces(paragraph)
    base_row = base_row.copy()

    # Supplier fallback (critical)
    supplier = base_row.get("Supplier Name")
    if not supplier or supplier == "Unknown":
        extracted = _extract_supplier_from_description(paragraph)
        # Only standardize a name that was actually found: passing the "Unknown"
        # placeholder to the matcher could attach an unrelated supplier. The
        # Stage 6 rule-based validator applies the same guard.
        base_row["Supplier Name"] = (
            get_best_supplier_match(extracted) if extracted != "Unknown" else "Unknown"
        )

    fms_countries = parse_fms_countries(paragraph)
    allocations = parse_operator_allocations(paragraph)

    if not allocations:
        base_row["Split Flag"] = "No"
        base_row["Split Reason"] = "No allocation split detected."
        base_row["Split Evidence"] = "Not Found"
        base_row["Split Line Item"] = "Not Found"
        return [base_row]

    # Loop-invariant evidence strings.
    operator_evidence = paragraph[:250] + ("..." if len(paragraph) > 250 else "")
    country_evidence = f"Countries: {', '.join(fms_countries)}"

    rows = []
    for alloc in allocations:
        if alloc["g2g_b2g"] == "G2G" and fms_countries:
            for c in fms_countries:
                rr = base_row.copy()
                rr["Customer Country"] = c
                rr["Customer Region"] = get_region_for_country(c)
                rr["Customer Operator"] = "Foreign Assistance"
                # NOTE(review): per-country quantities are not stated in the text;
                # "1" is a placeholder, not the FMS allocation (alloc["qty"]).
                # Confirm this is the intended reporting convention.
                rr["Quantity"] = "1"
                rr["G2G/B2G"] = "G2G"

                rr["Split Flag"] = "Yes"
                rr["Split Reason"] = "FMS multi-country split"
                rr["Split Evidence"] = country_evidence
                rr["Split Line Item"] = "FMS countries"
                rows.append(rr)
        else:
            rr = base_row.copy()
            rr["Customer Operator"] = alloc["operator"]
            rr["Quantity"] = str(alloc["qty"])
            rr["G2G/B2G"] = alloc["g2g_b2g"]

            rr["Split Flag"] = "Yes"
            rr["Split Reason"] = "Operator allocation split"
            rr["Split Evidence"] = operator_evidence
            rr["Split Line Item"] = alloc["operator"]
            rows.append(rr)

    return rows


# ==============================================================================
# 8) TOOLS / AGENTS (Stage1 -> Stage6)
# ==============================================================================

# Pydantic argument model describing the Stage 1 `sourcing_extractor` inputs.
# NOTE(review): the @tool decorator below does not pass this as `args_schema=`;
# confirm whether it is referenced elsewhere or can be wired in / removed.
class SourcingInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    url: str = Field(description="Source URL of the contract announcement/news.")
    date: str = Field(description="Contract date string.")  # copied as-is into "Contract Date"

@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """
    Stage 1: Sourcing Extractor

    Produces base metadata fields that stay constant:
    - Description of Contract
    - Source Link(s)
    - Contract Date
    - Reported Date (By SGA)
    - Additional Notes (Internal Only)
    """
    # (The docstring above doubles as the tool description, so extra notes live here.)
    # Reported Date is today's local date (YYYY-MM-DD), not the contract date.
    reported_date = datetime.datetime.now().strftime("%Y-%m-%d")

    lowered = str(paragraph).lower()
    # Collect every keyword-based note. Previously a "split" mention replaced the
    # "modification" note, so a split modification lost one of the two flags.
    flags = []
    if "modification" in lowered:
        flags.append("Contract Modification.")
    if "split" in lowered:
        flags.append("Split award detected.")
    notes = " ".join(flags) if flags else "Standard extraction."

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": notes,
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": reported_date
    }


# Pydantic argument model describing the Stage 2 `geography_extractor` input.
# NOTE(review): not passed as `args_schema=` to the @tool below; confirm usage.
class GeographyInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")

@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """
    Stage 2: Geography Extractor

    Extract:
    - Customer Country
    - Customer Operator
    - Supplier Country

    Derive:
    - Customer Region
    - Supplier Region
    - Domestic Content
    """
    # (Docstring above is the tool description; implementation notes go here.)
    # Domestic Content: "Indigenous" when customer and supplier are the same
    # country (US/UK aliases unified), "Imported" when they differ, and
    # "Unknown" when either side is unknown.
    sys_prompt = """
You are a Defense Contract Geography Analyst.

Extract ONLY:
1) Customer Country
2) Customer Operator
3) Supplier Country

Return JSON ONLY:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}

Rules:
- If missing -> "Unknown"
- Keep values short and clean
"""

    log_block("HUMAN MESSAGE (Stage2 - Geography)", paragraph)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": paragraph}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage2 - Geography)", json.dumps(raw, indent=2))
    except Exception as e:
        raw = {}
        log_block("AI ERROR (Stage2 - Geography)", str(e))

    # A JSON array/scalar response would otherwise crash on .get().
    if not isinstance(raw, dict):
        raw = {}

    def _clean(key):
        # The model may return null, "" or non-strings; normalize to a trimmed string.
        value = raw.get(key)
        text = "" if value is None else str(value).strip()
        return text or "Unknown"

    cust = _clean("Customer Country")
    supp = _clean("Supplier Country")

    # Same aliases get_region_for_country recognizes, so "USA" == "United States".
    aliases = {
        "us": "united states",
        "usa": "united states",
        "u.s.": "united states",
        "u.s.a.": "united states",
        "united states of america": "united states",
        "uk": "united kingdom",
        "u.k.": "united kingdom",
        "britain": "united kingdom",
        "great britain": "united kingdom",
    }
    unknown_markers = {"unknown", "n/a", "not applicable", ""}

    def _country_key(name):
        key = name.strip().lower()
        return aliases.get(key, key)

    cust_key, supp_key = _country_key(cust), _country_key(supp)
    if cust_key in unknown_markers or supp_key in unknown_markers:
        # Two unknowns are not evidence of a domestic award.
        domestic = "Unknown"
    elif cust_key == supp_key:
        domestic = "Indigenous"
    else:
        domestic = "Imported"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": _clean("Customer Operator"),
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": domestic
    }


# Pydantic argument model describing the Stage 3 `system_classifier` input.
# NOTE(review): not passed as `args_schema=` to the @tool below; confirm usage.
class SystemInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")

@tool("system_classifier")
def system_classifier(paragraph: str):
    """
    Stage 3: System Classifier (RAG enhanced)

    Output fields (flat):
    - Market Segment (+Evidence+Reason)
    - System Type General (+Evidence+Reason)
    - System Type Specific (+Evidence+Reason)
    - System Name General (+Evidence+Reason)
    - System Name Specific (+Evidence+Reason)
    - System Piloting (overridden deterministically)
    - Confidence
    """
    # (Docstring above is the tool description; implementation notes go here.)
    # Returns {} for an empty paragraph. If RAG retrieval fails the model is
    # still called, just without examples; if the LLM call fails a fallback row
    # with "Confidence": "Low" and an "Error" field is returned.
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    log_block("HUMAN MESSAGE (Stage3 - System)", paragraph)

    lower_text = paragraph.lower()
    hints = [
        f"RULE: {v['guidance']}"
        for _, v in RULE_BOOK.items()
        if any(t in lower_text for t in v["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    designators = extract_designators(paragraph)
    piloting_rule = detect_piloting_rule_based(paragraph, designators)

    # RAG only enriches the prompt; a retrieval failure must not abort the graph.
    try:
        rag_hits = retriever.retrieve(paragraph, top_k=3)
    except Exception as e:
        log_block("RAG ERROR (Stage3 - System)", str(e))
        rag_hits = []

    def _as_text(value):
        # KB metadata originates from Excel: empty cells can surface as None/NaN
        # and numbers as numpy scalars. Plain strings keep slicing and
        # json.dumps below from failing.
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return str(value)

    rag_examples = []
    for hit in rag_hits:
        meta = hit["meta"]
        snippet = _as_text(meta.get("Description of Contract", ""))
        rag_examples.append({
            # float(): FAISS scores may be numpy float32, which json.dumps rejects.
            "score": round(float(hit["score"]), 4),
            "Market Segment": _as_text(meta.get("Market Segment", "")),
            "System Type (General)": _as_text(meta.get("System Type (General)", "")),
            "System Type (Specific)": _as_text(meta.get("System Type (Specific)", "")),
            "System Name (General)": _as_text(meta.get("System Name (General)", "")),
            "System Name (Specific)": _as_text(meta.get("System Name (Specific)", "")),
            "System Piloting": _as_text(meta.get("System Piloting", "")),
            "Snippet": snippet[:200] + ("..." if len(snippet) > 200 else "")
        })

    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

REFERENCE TAXONOMY:
{TAXONOMY_STR}

RULE BOOK OVERRIDES:
{hint_str}

STRICT OUTPUT RULES:
- Return ONLY a FLAT JSON object
- Every value MUST be a STRING
- Evidence MUST be copied EXACTLY from paragraph
- If evidence not found -> "Not Found"

Return JSON:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

DESIGNATORS:
{designators if designators else "None"}

RULE PILOTING OVERRIDE:
{piloting_rule}

RAG EXAMPLES:
{json.dumps(rag_examples, indent=2)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage3 - System)", json.dumps(result, indent=2))

        # hard override piloting
        result["System Piloting"] = piloting_rule
        result.setdefault("System Piloting Evidence", "Not Found")
        result.setdefault("System Piloting Reason", "Derived using deterministic piloting rules.")

        # Enforce the "every value is a string" contract (null -> "").
        for k, v in result.items():
            if not isinstance(v, str):
                result[k] = "" if v is None else str(v)

        return result

    except Exception as e:
        log_block("AI ERROR (Stage3 - System)", str(e))
        return {
            "Market Segment": "",
            "Market Segment Evidence": "Not Found",
            "Market Segment Reason": "",

            "System Type (General)": "",
            "System Type (General) Evidence": "Not Found",
            "System Type (General) Reason": "",

            "System Type (Specific)": "",
            "System Type (Specific) Evidence": "Not Found",
            "System Type (Specific) Reason": "",

            "System Name (General)": "",
            "System Name (General) Evidence": "Not Found",
            "System Name (General) Reason": "",

            "System Name (Specific)": "",
            "System Name (Specific) Evidence": "Not Found",
            "System Name (Specific) Reason": "",

            "System Piloting": piloting_rule,
            "System Piloting Evidence": "Not Found",
            "System Piloting Reason": "Derived using deterministic piloting rules (fallback).",

            "Confidence": "Low",
            "Error": str(e)
        }


# Pydantic argument model describing the Stage 4 `contract_extractor` inputs.
# NOTE(review): not passed as `args_schema=` to the @tool below; confirm usage.
class ContractInfoInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    contract_date: str = Field(description="Contract date string.")  # also used for Signing Month/Year

@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """
    Stage 4: Contract Extractor

    Extract:
    - Supplier Name (raw -> standardized)
    - Program Type
    - Quantity
    - Value (Million)
    - Currency
    - Value Certainty
    - G2G/B2G
    - Completion date text (for MRO)
    """
    # (Docstring above is the tool description; implementation notes go here.)
    # On an LLM/parse failure returns {"Error": "..."} only. Fields the model
    # leaves empty fall back to defaults; Value (Million) falls back to "0.000".
    system_instruction = """
You are a Defense Contract Financial Analyst.

Return JSON ONLY.

Rules:
1) raw_supplier_name: exact supplier as in paragraph
2) program_type: Procurement/Training/MRO/Support/RDT&E/Upgrade/Other Service
3) value_certainty: Confirmed vs Estimated
4) quantity: numeric if available else Not Applicable
5) g2g_b2g: G2G only if FMS mentioned else B2G
6) completion_date_text: end date if mentioned
7) value_million_raw: numeric only
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

SIGNED DATE:
{contract_date}

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    log_block("HUMAN MESSAGE (Stage4 - Contract)", paragraph)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage4 - Contract)", json.dumps(raw, indent=2))
        if not isinstance(raw, dict):
            raise ValueError(f"Expected a JSON object, got {type(raw).__name__}")
    except Exception as e:
        log_block("AI ERROR (Stage4 - Contract)", str(e))
        return {"Error": str(e)}

    def _field(key, default):
        # The prompt template pre-lists every key with "", so an unanswered field
        # usually arrives as "" or null rather than missing; use the default then.
        value = raw.get(key)
        if value is None or str(value).strip() == "":
            return default
        return value

    def _format_million(value):
        # Normalize the model's value-in-millions to a 3-decimal string.
        text = str(value).replace(",", "").replace("$", "").strip()
        try:
            number = float(text)
        except ValueError:
            # Tolerate unit suffixes such as "12.5M", "12.5 million", "1.2 billion".
            match = re.search(r"-?\d+(?:\.\d+)?", text)
            if not match:
                return "0.000"
            number = float(match.group(0))
            if re.search(r"billion|\d\s*bn?\b", text, flags=re.IGNORECASE):
                number *= 1000
        if number != number or number in (float("inf"), float("-inf")):
            return "0.000"
        return "{:.3f}".format(number)

    final_supplier = get_best_supplier_match(raw.get("raw_supplier_name"))
    prog_type = _field("program_type", "Unknown")
    mro_months = calculate_mro_months(contract_date, raw.get("completion_date_text"), prog_type)

    val_formatted = _format_million(raw.get("value_million_raw", "0"))

    # NOTE(review): calculate_mro_months parses contract dates with dayfirst=True,
    # while this uses pandas' default (month-first for ambiguous "05/06/2024");
    # confirm which convention the input dates follow.
    try:
        dt = pd.to_datetime(contract_date)
        sign_month = dt.strftime("%B")
        sign_year = str(dt.year)
    except Exception:
        # Covers unparseable strings, NaT and None (pd.to_datetime(None) -> None).
        sign_month, sign_year = "Unknown", "Unknown"

    return {
        "Supplier Name": final_supplier,
        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": _field("quantity", "Not Applicable"),
        "Value Certainty": _field("value_certainty", "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": _field("currency_code", "USD$"),
        # NOTE(review): copied without currency conversion, so this is only a USD
        # figure when Currency is USD.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": _field("value_note", "Not Applicable"),
        "G2G/B2G": _field("g2g_b2g", "B2G"),
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


# Pydantic argument model describing the Stage 5 `splitter_agent` inputs.
# NOTE(review): not passed as `args_schema=` to the @tool below; confirm usage.
class SplitterInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    base_row: dict = Field(description="Final extracted row after Stage1-4.")  # merged Stage 1-4 fields

@tool("splitter_agent")
def splitter_agent(paragraph: str, base_row: dict):
    """
    Stage 5: Splitter Agent

    Splits row using deterministic engine:
    - operator allocations
    - FMS multi-countries
    """
    # (Docstring above is the tool description.) Always returns {"rows": [...]};
    # if the engine raises, a single copy of base_row flagged "Error" is returned.
    split_defaults = (
        ("Split Flag", "No"),
        ("Split Reason", ""),
        ("Split Evidence", "Not Found"),
        ("Split Line Item", "Not Found"),
    )
    try:
        expanded = split_rows_engine(base_row, paragraph)
        for row in expanded:
            for field_name, fallback_value in split_defaults:
                row.setdefault(field_name, fallback_value)
        return {"rows": expanded}
    except Exception as err:
        error_row = base_row.copy()
        error_row.update({
            "Split Flag": "Error",
            "Split Reason": f"Split failed: {str(err)}",
            "Split Evidence": "Not Found",
            "Split Line Item": "Not Found",
        })
        return {"rows": [error_row]}


# ==============================================================================
# 9) STAGE 6 QUALITY VALIDATOR (RULE-BASED + FAIL-ONLY LLM VALIDATION)
# ==============================================================================

def _contains_any(text: str, keywords: List[str]) -> bool:
    """True if any keyword appears in text (case-insensitive)."""
    t = str(text or "").lower()
    return any(k.lower() in t for k in keywords)

def _supplier_appears_in_paragraph(supplier: str, paragraph: str) -> bool:
    """
    Checks if supplier appears in paragraph text.

    NOTE:
    - Works on standardized supplier (may not match exactly)
    - Also tries first 2 tokens partial match
    """
    supplier = str(supplier or "").strip()
    if not supplier or supplier.lower() == "unknown":
        return False

    p = str(paragraph or "").lower()
    s = supplier.lower()

    if s in p:
        return True

    parts = s.split()
    if len(parts) >= 2:
        key = " ".join(parts[:2])
        return key in p

    return False

def validate_single_row_rule_based(row: dict) -> dict:
    """
    RULE-BASED VALIDATION (Stage6.1)

    Runs deterministic consistency checks on one row and applies a few safe
    auto-fixes. The input row is not mutated; a shallow copy is returned.

    Adds:
    - Validation Flag: PASS/WARN/FAIL
    - Validation Issues
    - AutoFix Applied
    - AutoFix Notes

    SAFE FIXES:
    ✅ fix piloting by deterministic rules
    ✅ fix G2G/B2G if FMS exists (not on operator-allocation split rows,
       which are intentionally B2G)
    ✅ fill supplier if unknown using fallback supplier extraction

    DOES NOT:
    ❌ overwrite Market Segment / System Names

    Severity: a supplier-not-found or missing/zero-value issue -> FAIL;
    any other issue -> WARN; no issues -> PASS.
    """
    # Annotated with builtin `dict`: `Dict`/`Any` are not in the file's visible
    # typing import, and annotations are evaluated when the def executes.
    updated = row.copy()
    issues = []
    autofix_notes = []
    autofix_applied = False

    paragraph = updated.get("Description of Contract", "")
    supplier = updated.get("Supplier Name", "Unknown")

    # supplier mismatch
    if supplier and supplier != "Unknown":
        if not _supplier_appears_in_paragraph(supplier, paragraph):
            issues.append("Supplier Name not found in paragraph (possible wrong supplier extraction).")

    # piloting mismatch: the deterministic rule wins whenever it reaches a verdict
    designators = extract_designators(paragraph)
    piloting_expected = detect_piloting_rule_based(paragraph, designators)
    piloting_current = updated.get("System Piloting", "Not Applicable")

    if piloting_expected != "Not Applicable" and piloting_current != piloting_expected:
        issues.append(f"System Piloting mismatch (expected={piloting_expected}, got={piloting_current}).")
        updated["System Piloting"] = piloting_expected
        updated["System Piloting Reason"] = "Validator override: deterministic piloting rules applied."
        autofix_applied = True
        autofix_notes.append("Corrected System Piloting using deterministic rules.")

    # FMS mismatch. "fms" is matched as a whole word so it cannot fire inside
    # unrelated tokens.
    mentions_fms = re.search(
        r"\bfms\b|foreign military sales", str(paragraph or ""), flags=re.IGNORECASE
    ) is not None
    # Operator-allocation split rows (e.g. "212 for the Navy") are deliberately
    # B2G even when the same paragraph also funds FMS customers; forcing them to
    # G2G would undo the Stage 5 split.
    is_operator_split_row = (
        updated.get("Split Flag") == "Yes"
        and updated.get("Customer Operator") != "Foreign Assistance"
    )
    if mentions_fms and not is_operator_split_row:
        if updated.get("G2G/B2G", "B2G") != "G2G":
            issues.append("Paragraph mentions FMS but row marked as B2G.")
            updated["G2G/B2G"] = "G2G"
            autofix_applied = True
            autofix_notes.append("Fixed G2G/B2G to G2G due to FMS keyword.")

    # Value mismatch
    if "$" in str(paragraph) and str(updated.get("Value (Million)", "")).strip() in ["", "0.000", "0"]:
        issues.append("Paragraph contains $ value but Value (Million) is missing/zero.")

    # segment heuristic warn
    if _contains_any(paragraph, ["radar", "an/apy", "an/tpy"]) and updated.get("Market Segment", ""):
        if updated.get("Market Segment") not in ["C4ISR Systems", "C4ISR"]:
            issues.append("Radar keyword detected but Market Segment not C4ISR (check classification).")

    # Supplier unknown autofix (only when the fallback actually finds a name)
    if supplier in ["Unknown", "", None]:
        fallback_supplier = _extract_supplier_from_description(paragraph)
        if fallback_supplier != "Unknown":
            updated["Supplier Name"] = get_best_supplier_match(fallback_supplier)
            autofix_applied = True
            autofix_notes.append("Filled Supplier Name from DoD fallback extraction.")

    updated["AutoFix Applied"] = "Yes" if autofix_applied else "No"
    updated["AutoFix Notes"] = " | ".join(autofix_notes) if autofix_notes else "Not Applicable"

    if not issues:
        updated["Validation Flag"] = "PASS"
        updated["Validation Issues"] = "No issues detected."
    else:
        severe = any(
            "Supplier Name not found" in x or "Value (Million) is missing" in x
            for x in issues
        )
        updated["Validation Flag"] = "FAIL" if severe else "WARN"
        updated["Validation Issues"] = " | ".join(issues)

    return updated


def llm_validate_fail_row(row: dict) -> dict:
    """
    LLM VALIDATION (Stage6.2) - RUN ONLY FOR FAIL ROWS

    WHY THIS IS IMPORTANT:
    - Some FAIL rows require semantic judgment:
      Supplier confusion, wrong system label, wrong G2G/B2G logic.

    INPUT:
    - Row dict (already extracted) including paragraph in Description of Contract.
      The input dict is not mutated; a shallow copy is returned.

    OUTPUT:
    Adds:
    - LLM Validation Run: Yes
    - LLM Validation Verdict: CONFIRMED / WRONG / NEEDS_REVIEW
      (any other model answer is recorded as NEEDS_REVIEW)
    - LLM Validation Notes: explanation
    - LLM Suggested Fixes: JSON string (flat), recorded even when not applied
    - LLM Fix Applied: Yes only if at least one field value actually changed

    SAFE APPLICATION RULES:
    - Fixes applied automatically:
      ✅ Supplier Name (only if the standardized suggestion appears in paragraph)
      ✅ Program Type (only if suggested value in allowed list)
      ✅ Value Certainty (only "Confirmed" / "Estimated")
      ✅ G2G/B2G (only "G2G" / "B2G")              -- not on split rows
      ✅ Customer Operator (not "Unknown")          -- not on split rows
      ✅ Quantity (digits only)                     -- not on split rows
    - Split rows (Split Flag == "Yes") keep the operator / quantity / G2G values
      assigned by the deterministic split engine: the LLM sees one row in
      isolation and cannot judge a per-operator allocation.
    - We NEVER auto-override:
      ❌ Market Segment / System Types / System Names (only suggest, do not apply)
    - On any LLM/parse error the row is marked NEEDS_REVIEW with no fixes.
    """
    allowed_verdicts = {"CONFIRMED", "WRONG", "NEEDS_REVIEW"}
    allowed_programs = {"Procurement", "Training", "MRO/Support", "RDT&E", "Upgrade", "Other Service"}

    updated = row.copy()
    paragraph = updated.get("Description of Contract", "")
    is_split_row = updated.get("Split Flag") == "Yes"

    updated["LLM Validation Run"] = "Yes"
    updated.setdefault("LLM Fix Applied", "No")
    updated.setdefault("LLM Suggested Fixes", "Not Applicable")
    updated.setdefault("LLM Validation Verdict", "NEEDS_REVIEW")
    updated.setdefault("LLM Validation Notes", "Not Applicable")

    sys_prompt = """
You are a Defense Contract Quality Auditor.

Your job:
- Validate if the extracted row fields match the contract paragraph.
- Detect mismatches or hallucinations.
- Suggest corrected fields ONLY when strongly supported by paragraph text.

IMPORTANT RULES:
1) NEVER fabricate any supplier/system/value not present in paragraph.
2) Output must be a FLAT JSON object only (no nesting).
3) You must include a verdict:
   - "CONFIRMED" (row looks correct)
   - "WRONG" (row clearly incorrect)
   - "NEEDS_REVIEW" (unclear)
4) If suggesting a fix, give corrected values.

Return JSON:
{
  "verdict": "CONFIRMED/WRONG/NEEDS_REVIEW",
  "notes": "short explanation",
  "suggest_supplier_name": "",
  "suggest_g2g_b2g": "",
  "suggest_program_type": "",
  "suggest_value_certainty": "",
  "suggest_customer_operator": "",
  "suggest_quantity": ""
}

Allowed program_type values:
Procurement / Training / MRO/Support / RDT&E / Upgrade / Other Service
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

CURRENT EXTRACTED ROW:
Supplier Name: {updated.get("Supplier Name")}
Market Segment: {updated.get("Market Segment")}
System Type (General): {updated.get("System Type (General)")}
System Type (Specific): {updated.get("System Type (Specific)")}
System Name (General): {updated.get("System Name (General)")}
System Name (Specific): {updated.get("System Name (Specific)")}
Program Type: {updated.get("Program Type")}
Quantity: {updated.get("Quantity")}
Value (Million): {updated.get("Value (Million)")}
Currency: {updated.get("Currency")}
G2G/B2G: {updated.get("G2G/B2G")}
Customer Operator: {updated.get("Customer Operator")}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )

        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage6.2 - LLM FAIL VALIDATOR)", json.dumps(raw, indent=2))

        verdict = str(raw.get("verdict", "NEEDS_REVIEW")).strip().upper()
        if verdict not in allowed_verdicts:
            # e.g. the template placeholder "CONFIRMED/WRONG/NEEDS_REVIEW" echoed back.
            verdict = "NEEDS_REVIEW"
        notes = str(raw.get("notes", "Not Applicable")).strip()

        updated["LLM Validation Verdict"] = verdict
        updated["LLM Validation Notes"] = notes

        # suggested fixes (always recorded for human review)
        suggested = {
            "suggest_supplier_name": raw.get("suggest_supplier_name", ""),
            "suggest_g2g_b2g": raw.get("suggest_g2g_b2g", ""),
            "suggest_program_type": raw.get("suggest_program_type", ""),
            "suggest_value_certainty": raw.get("suggest_value_certainty", ""),
            "suggest_customer_operator": raw.get("suggest_customer_operator", ""),
            "suggest_quantity": raw.get("suggest_quantity", "")
        }
        updated["LLM Suggested Fixes"] = json.dumps(suggested, ensure_ascii=False)

        # ===== APPLY SAFE FIXES ONLY =====
        changed_fields = []

        def _apply(field, value):
            # Re-asserting the current value is not a fix; only count real changes.
            if updated.get(field) != value:
                updated[field] = value
                changed_fields.append(field)

        # Supplier
        sug_supplier = str(raw.get("suggest_supplier_name", "")).strip()
        if sug_supplier:
            sug_std = get_best_supplier_match(sug_supplier)
            if _supplier_appears_in_paragraph(sug_std, paragraph):
                _apply("Supplier Name", sug_std)

        # Program Type (strict allowed list)
        sug_prog = str(raw.get("suggest_program_type", "")).strip()
        if sug_prog in allowed_programs:
            _apply("Program Type", sug_prog)

        # Value certainty
        sug_cert = str(raw.get("suggest_value_certainty", "")).strip()
        if sug_cert in ["Confirmed", "Estimated"]:
            _apply("Value Certainty", sug_cert)

        if not is_split_row:
            # G2G/B2G (case-insensitive suggestion)
            sug_g2g = str(raw.get("suggest_g2g_b2g", "")).strip().upper()
            if sug_g2g in ["G2G", "B2G"]:
                _apply("G2G/B2G", sug_g2g)

            # Operator
            sug_op = str(raw.get("suggest_customer_operator", "")).strip()
            if sug_op and sug_op.lower() != "unknown":
                _apply("Customer Operator", sug_op)

            # Quantity (only if numeric)
            sug_qty = str(raw.get("suggest_quantity", "")).strip()
            if sug_qty and sug_qty.isdigit():
                _apply("Quantity", sug_qty)

        updated["LLM Fix Applied"] = "Yes" if changed_fields else "No"
        return updated

    except Exception as e:
        log_block("AI ERROR (Stage6.2 - LLM FAIL VALIDATOR)", str(e))
        updated["LLM Validation Verdict"] = "NEEDS_REVIEW"
        updated["LLM Validation Notes"] = f"LLM validator failed: {str(e)}"
        updated["LLM Suggested Fixes"] = "Not Applicable"
        updated["LLM Fix Applied"] = "No"
        return updated


# Pydantic argument model describing the Stage 6 `quality_validator_agent` input.
# NOTE(review): not passed as `args_schema=` to the @tool below; confirm usage.
class ValidatorInput(BaseModel):
    rows: list = Field(description="List of extracted rows AFTER splitting. Each row is a dict.")  # Stage 5 output

@tool("quality_validator_agent")
def quality_validator_agent(rows: list):
    """
    Stage 6: Quality Validator Agent (Hybrid)

    PIPELINE:
    Stage 6.1 -> Rule-based validator for all rows
    Stage 6.2 -> LLM validator ONLY for FAIL rows

    WHY:
    - Cost controlled (LLM only when needed)
    - Increases accuracy for tricky mismatches

    OUTPUT:
    - {"validated_rows": [validated_row1, validated_row2, ...]}
    """
    # (Docstring above is the tool description.) Every row gets the LLM-audit
    # columns pre-filled so the output schema is uniform, even for rows the
    # LLM never sees.
    llm_column_defaults = {
        "LLM Validation Run": "No",
        "LLM Validation Verdict": "Not Applicable",
        "LLM Validation Notes": "Not Applicable",
        "LLM Suggested Fixes": "Not Applicable",
        "LLM Fix Applied": "No",
    }

    def _validate_one(source_row):
        checked = validate_single_row_rule_based(source_row)
        for column, default in llm_column_defaults.items():
            checked.setdefault(column, default)
        # Escalate to the (costly) LLM auditor only for rule-based FAILs.
        if checked.get("Validation Flag") == "FAIL":
            checked = llm_validate_fail_row(checked)
        return checked

    return {"validated_rows": [_validate_one(source_row) for source_row in rows]}


# ==============================================================================
# 10) LANGGRAPH PIPELINE
# ==============================================================================

class AgentState(TypedDict):
    """
    Shared state between LangGraph nodes.

    The stage nodes read the input_* keys and return partial dicts that
    update final_data (Stages 1-4), final_rows (Stage 5) and
    validated_rows (Stage 6).
    """
    input_text: str        # contract paragraph; read by every stage node
    input_date: str        # contract date string; read by Stage 1 and Stage 4
    input_url: str         # source link; read by Stage 1
    final_data: dict       # single merged row, copied and extended by Stages 1-4
    final_rows: list       # rows after Stage 5 splitting
    validated_rows: list   # rows after Stage 6 validation
    # Uses the add_messages reducer; not written by the stage nodes in this section.
    messages: Annotated[List[AnyMessage], add_messages]


def stage_1_sourcing(state: AgentState):
    """LangGraph node for Stage 1: merge sourcing metadata into a copy of ``final_data``."""
    tool_args = {
        "paragraph": state["input_text"],
        "url": state["input_url"],
        "date": state["input_date"],
    }
    sourcing_fields = sourcing_extractor.invoke(tool_args)
    merged_row = state.get("final_data", {}).copy()
    merged_row.update(sourcing_fields)
    return {"final_data": merged_row}


def stage_2_geography(state: AgentState):
    """LangGraph node for Stage 2: merge geography fields into a copy of ``final_data``."""
    geography_fields = geography_extractor.invoke({"paragraph": state["input_text"]})
    merged_row = state.get("final_data", {}).copy()
    merged_row.update(geography_fields)
    return {"final_data": merged_row}


def stage_3_system(state: AgentState):
    """LangGraph node for Stage 3: merge system-classification fields into a copy of ``final_data``."""
    system_fields = system_classifier.invoke({"paragraph": state["input_text"]})
    merged_row = state.get("final_data", {}).copy()
    merged_row.update(system_fields)
    return {"final_data": merged_row}


def stage_4_contract(state: AgentState):
    """LangGraph node for Stage 4: merge contract/financial fields into a copy of ``final_data``."""
    tool_args = {
        "paragraph": state["input_text"],
        "contract_date": state["input_date"],
    }
    contract_fields = contract_extractor.invoke(tool_args)
    merged_row = state.get("final_data", {}).copy()
    merged_row.update(contract_fields)
    return {"final_data": merged_row}


def stage_5_split(state: AgentState):
    """
    LangGraph node for Stage 5 (row splitting).

    Asks `splitter_agent` to expand the merged row into one or more rows. When
    the tool result has no "rows" key, the merged row is returned as a
    one-element list.
    """
    split_result = splitter_agent.invoke(
        {"paragraph": state["input_text"], "base_row": state["final_data"]}
    )
    fallback = [state["final_data"]]
    return {"final_rows": split_result.get("rows", fallback)}


def stage_6_validate(state: AgentState):
    """
    LangGraph node for Stage 6 (quality validation).

    Hands the split rows to `quality_validator_agent`. When the tool result has
    no "validated_rows" key, the unvalidated `final_rows` are passed through.
    """
    verdict = quality_validator_agent.invoke({"rows": state.get("final_rows", [])})
    passthrough = state.get("final_rows", [])
    return {"validated_rows": verdict.get("validated_rows", passthrough)}


def _build_stage_graph():
    """Wire Stage1..Stage6 into a linear StateGraph: START -> Stage1 -> ... -> Stage6 -> END."""
    graph = StateGraph(AgentState)
    stages = [
        ("Stage1", stage_1_sourcing),
        ("Stage2", stage_2_geography),
        ("Stage3", stage_3_system),
        ("Stage4", stage_4_contract),
        ("Stage5", stage_5_split),
        ("Stage6", stage_6_validate),
    ]
    for node_name, node_fn in stages:
        graph.add_node(node_name, node_fn)

    # Consecutive pairs of the chain give every edge, in the same order as before.
    chain = [START, *(name for name, _ in stages), END]
    for src, dst in zip(chain, chain[1:]):
        graph.add_edge(src, dst)
    return graph


workflow = _build_stage_graph()

app = workflow.compile()


# ==============================================================================
# 11) WORKFLOW GRAPH EXPORT (OFFLINE SAFE)
# ==============================================================================

def export_workflow_mermaid(app_obj, out_file="workflow.mmd"):
    """
    Write the compiled graph's Mermaid source to a local text file.

    Only `app_obj.get_graph().draw_mermaid()` is called: it produces Mermaid
    text locally, and no rendering service is contacted.

    Args:
        app_obj: compiled LangGraph app (anything exposing get_graph().draw_mermaid()).
        out_file: destination path, written as UTF-8 (overwritten if present).

    Returns:
        The path that was written (`out_file`).
    """
    diagram_text = app_obj.get_graph().draw_mermaid()
    with open(out_file, mode="w", encoding="utf-8") as handle:
        handle.write(diagram_text)
    print(f"‚úÖ Workflow Mermaid saved locally: {out_file}")
    return out_file


# ==============================================================================
# 12) EXCEL HIGHLIGHTING FEATURE
# ==============================================================================

def highlight_evidence_reason_columns(excel_path: str):
    """
    Colour-code Evidence / Reason columns of the active sheet and save in place.

    - Header containing "Evidence" -> light-yellow fill (FFF2CC) on every cell of the column
    - Header containing "Reason"   -> light-blue fill (D9E1F2) on every cell of the column
    - Matching header cells are also made bold.

    A header containing both words ends up blue, because the Reason fill is applied
    after the Evidence fill.
    """
    workbook = load_workbook(excel_path)
    sheet = workbook.active

    headers = [cell.value for cell in sheet[1]]

    def _columns_containing(word):
        # 1-based column numbers whose (string) header contains `word`.
        return [
            col for col, name in enumerate(headers, start=1)
            if isinstance(name, str) and word in name
        ]

    evidence_fill = PatternFill(start_color="FFF2CC", end_color="FFF2CC", fill_type="solid")
    reason_fill = PatternFill(start_color="D9E1F2", end_color="D9E1F2", fill_type="solid")
    bold = Font(bold=True)
    last_row = sheet.max_row

    # Evidence first, Reason second, so overlapping columns keep the Reason fill.
    for word, fill in (("Evidence", evidence_fill), ("Reason", reason_fill)):
        for col in _columns_containing(word):
            header_cell = sheet.cell(row=1, column=col)
            header_cell.fill = fill
            header_cell.font = bold
            for r in range(2, last_row + 1):
                sheet.cell(row=r, column=col).fill = fill

    workbook.save(excel_path)
    print("‚úÖ Evidence + Reason columns highlighted successfully.")


# ==============================================================================
# 13) MAIN EXECUTION
# ==============================================================================

if __name__ == "__main__":
    import traceback  # only needed for error reporting in this entry point

    print(f"\nüìå Loading Input File: {INPUT_EXCEL_PATH}")

    export_workflow_mermaid(app, out_file="workflow.mmd")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        missing_cols = [col for col in required_cols if col not in df_input.columns]
        if missing_cols:
            # Name the absent columns so the input file can be fixed without guesswork.
            raise ValueError(
                f"Excel file must contain columns: {required_cols} (missing: {missing_cols})"
            )

        total_rows = len(df_input)
        print(f"üöÄ Processing {total_rows} rows...")
        results = []
        failed_rows = []

        # enumerate() gives the true 1-based position even if the DataFrame index
        # is not a 0..n-1 RangeIndex.
        for position, (_, row) in enumerate(df_input.iterrows(), start=1):
            print(f"\nüîπ Row {position}/{total_rows}")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "final_rows": [],
                "validated_rows": [],
                "messages": []
            }

            try:
                output_state = app.invoke(initial_state)
            except Exception as row_err:
                # A single failing row (LLM/API error, unexpected text) must not
                # discard every row already processed: record it and move on.
                print(f"\nERROR while processing row {position}: {row_err}")
                traceback.print_exc()
                failed_rows.append(position)
                results.append({
                    "Description of Contract": desc,
                    "Additional Notes (Internal Only)": f"Pipeline error: {row_err}",
                    "Source Link(s)": c_url,
                    "Contract Date": c_date,
                })
                continue

            # Prefer validated rows, then split rows, then the single merged row.
            rows = output_state.get("validated_rows", [])
            if not rows:
                rows = output_state.get("final_rows", [])
            if not rows:
                rows = [output_state.get("final_data", {})]

            results.extend(rows)

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Split Flag", "Split Reason", "Split Evidence", "Split Line Item",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            # Stage6 rule validation
            "Validation Flag", "Validation Issues",
            "AutoFix Applied", "AutoFix Notes",

            # Stage6 FAIL-only LLM validator
            "LLM Validation Run", "LLM Validation Verdict", "LLM Validation Notes",
            "LLM Suggested Fixes", "LLM Fix Applied",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # Enforce output column order; columns no row produced are filled with "".
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")
        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        highlight_evidence_reason_columns(OUTPUT_EXCEL_PATH)

        print("\n‚úÖ Processing Complete!")
        print(f"üíæ Output File Saved: {OUTPUT_EXCEL_PATH}")
        print(df_final.head(3).to_string(index=False))

        if failed_rows:
            print(
                f"WARNING: {len(failed_rows)} row(s) failed and were written with an "
                f"error note: {failed_rows}"
            )

    except Exception as e:
        # Top-level boundary: keep the notebook cell alive, but show the full
        # traceback instead of only the message.
        print(f"\n‚ùå ERROR: {e}")
        traceback.print_exc()


In [None]:
## Final Version 

# ======================================================================================
# DEFENSE CONTRACT DATA EXTRACTION PIPELINE (AGENTIC + RAG + SPLIT + QA VALIDATION)
# ======================================================================================
# Author: Mukesh Kumar Sharma (Customized Agentic Pipeline)
# Maintained by: ChatGPT
#
# GOAL:
# - Read contract paragraphs from Excel
# - Extract structured fields (supplier/program/value/system/etc.)
# - Apply split logic for multi-operator / multi-supplier / multi-country cases
# - Validate outputs using deterministic + LLM validator (ONLY FAIL rows)
# - Export final Excel with evidence/reason highlighted
#
# ======================================================================================

import os
import re
import json
import difflib
import pickle
import datetime
import getpass
from typing import Annotated, TypedDict, List, Dict, Any, Optional

import pandas as pd
import faiss
from dateutil import parser
from dateutil.relativedelta import relativedelta

# LangGraph / LangChain
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI

# Excel formatting
from openpyxl import load_workbook
from openpyxl.styles import PatternFill, Font

# ======================================================================================
# 0) DEBUG LOGGING HELPERS
# ======================================================================================

def log_block(title: str, content: str):
    """
    Print a titled, ruled section to stdout for stage-by-stage debugging.

    Layout: a blank line, a 110-character "=" rule, the title, another rule, and
    then the content. This makes it easy to see which stage produced which
    prompt, response or transformation in a long multi-stage run.
    """
    rule = "=" * 110
    print(f"\n{rule}\n{title}\n{rule}\n{content}")


def _normalize_spaces(text: str) -> str:
    """
    Normalize whitespace spacing for safer regex and matching.

    Example:
    "Lockheed   Martin,   Fort Worth , TX" -> "Lockheed Martin, Fort Worth , TX"
    """
    return re.sub(r"\s+", " ", str(text or "")).strip()


# ======================================================================================
# 1) RAG RETRIEVER (FAISS + Metadata)
# ======================================================================================

class SystemKBRetriever:
    """
    RAG Retriever for defense system classification (FAISS + metadata).

    PURPOSE:
    - Loads FAISS index and metadata created from your labeled KB dataset
      (system_kb.faiss + system_kb_meta.pkl)
    - Retrieves top-k similar examples using semantic embedding similarity

    WHY THIS IMPROVES ACCURACY:
    - Defense system classification is highly taxonomy-driven.
    - Your KB contains "known correct" historical labeled examples.
    - RAG reduces random hallucinations and increases classification stability.

    OUTPUT:
    retrieve(query_text) returns:
      [
        {"score": float, "meta": {...KB row columns...}},
        ...
      ]

    NOTE: each FAISS vector id is used directly as an index into `self.meta`,
    so the index and the metadata must be built together and stay aligned.
    """

    def __init__(self, kb_dir: str, embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Load the FAISS index and pickled metadata from `kb_dir`.

        Raises:
            FileNotFoundError: if either system_kb.faiss or system_kb_meta.pkl is missing.
        """
        self.kb_dir = kb_dir
        self.embed_model = embed_model

        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n"
                f"Build KB first using your KB builder script."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only point kb_dir at KB files you built yourself or otherwise trust.
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded rows: {len(self.meta)}")

        # A size mismatch means retrieve() would return the wrong metadata rows
        # (or hit IndexError), so surface it at load time.
        n_vectors = getattr(self.index, "ntotal", None)
        if n_vectors is not None and n_vectors != len(self.meta):
            print(
                f"Warning: FAISS index has {n_vectors} vectors but KB metadata has "
                f"{len(self.meta)} rows; rebuild the KB so they stay aligned."
            )

        self.embedder = None

    def _lazy_load_embedder(self):
        """
        Lazy-load embedder only when retrieval is actually requested.

        WHY:
        - Speeds up pipeline startup
        - Avoids heavy RAM usage until needed
        """
        if self.embedder is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer(self.embed_model)

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Retrieve top-k semantic matches from KB.

        Args:
            query_text: contract paragraph. None or blank text returns [] without
                loading the embedder (previously None was queried as the string "None").
            top_k: number of examples requested; values <= 0 return [].

        Returns:
            list of dict {score, meta}, in the order FAISS returns them. Fewer than
            top_k entries are returned when the index holds fewer vectors.
        """
        if query_text is None or top_k <= 0:
            return []

        query_text = str(query_text).strip()
        if not query_text:
            return []

        self._lazy_load_embedder()

        # Embeddings are L2-normalised; this yields cosine similarity only if the
        # index is inner-product based (assumption; depends on the KB builder).
        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            # FAISS pads the result with id -1 when fewer than top_k hits exist.
            if idx < 0:
                continue
            results.append({"score": float(score), "meta": self.meta[int(idx)]})

        return results


# ======================================================================================
# 2) CONFIGURATION (PATHS + MODELS)
# ======================================================================================

# Project root shared by every input path below (raw strings keep the Windows
# backslashes literal, so the resulting paths are unchanged).
_PROJECT_ROOT = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction"

TAXONOMY_PATH = _PROJECT_ROOT + r"\notebook\taxonomy.json"
SUPPLIERS_PATH = _PROJECT_ROOT + r"\notebook\suppliers.json"
INPUT_EXCEL_PATH = _PROJECT_ROOT + r"\data\source_file.xlsx"
OUTPUT_EXCEL_PATH = "Processed_Defense_Data.xlsx"

RAG_KB_DIR = _PROJECT_ROOT + r"\notebook\system_kb_store"

# Closed vocabulary for "Program Type"; normalize_program_type() falls back to
# "Other Service" for anything outside this set.
ALLOWED_PROGRAM_TYPES = {"Procurement", "Training", "MRO/Support", "RDT&E", "Upgrade", "Other Service"}

# Prompt for the LLM Foundry token only when it is not already in the environment.
if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

client = OpenAI(
    base_url="https://llmfoundry.straive.com/openai/v1/",
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
)

# Loaded at import time: reads the FAISS index + metadata from RAG_KB_DIR
# (the sentence-transformer itself is only loaded on the first retrieve()).
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)


# ======================================================================================
# 3) JSON LOADERS
# ======================================================================================

def load_json_file(filename, default_value):
    """
    Load a JSON file, falling back to `default_value` if it cannot be read.

    WHY:
    - Taxonomy + suppliers list are critical to the pipeline.
    - This prevents full crash if file missing during deployment.

    Args:
        filename: path of the UTF-8 JSON file.
        default_value: returned (and a warning printed) when the file is missing,
            unreadable, not valid UTF-8, or not valid JSON.

    Returns:
        The parsed JSON value (dict or list, depending on the file), or `default_value`.
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError covers json.JSONDecodeError
        # and UnicodeDecodeError (malformed content).
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value

    # Report success only after parsing actually succeeded (previously "Loaded"
    # was printed before json.load, even for malformed files).
    print(f"‚úÖ Loaded JSON: {filename}")
    return data


# Built-in supplier names used only when suppliers.json cannot be loaded.
_DEFAULT_SUPPLIERS = [
    "Dell Inc", "Boeing", "Lockheed Martin", "Raytheon Technologies",
    "Northrop Grumman", "L3Harris", "BAE Systems", "General Dynamics",
]

raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
# Compact JSON (no spaces after separators) to keep the taxonomy short when embedded in prompts.
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(",", ":"))

SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, _DEFAULT_SUPPLIERS)


# ======================================================================================
# 4) RULE BOOK + GEOGRAPHY
# ======================================================================================

# Keyword-triggered classification hints. Each entry maps a rule name to:
#   "triggers": lowercase keyword fragments to look for in a contract paragraph
#   "guidance": the Market Segment / System Type labels the rule points to
# NOTE(review): how triggers are matched is not visible in this section. If they
# are tested as plain substrings, "round" also hits "ground"/"around" and "ale-"
# also hits "sale-"; the leading space in " 5.56" / " 7.62" looks intended to
# avoid mid-number hits. Confirm against the consumer.
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment='C4ISR Systems', System Type (General)='Defensive Systems', Specific='Defensive Aid Suite'"
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment='C4ISR Systems', System Type (General)='Sensors'"
    },
    "ammunition": {
        "triggers": ["cartridge", "round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment='Weapon Systems', System Type (General)='Ammunition'"
    }
}

# Region -> country / organisation names. Consumed by get_region_for_country(),
# which compares case-insensitively against the whole name (no substring
# matching), so any alias not special-cased in that function must be listed here.
GEOGRAPHY_MAPPING = {
    "North America": ["USA", "United States", "US", "United States of America", "Canada", "America"],
    "Europe": ["UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain", "Poland", "Netherlands",
               "Norway", "Sweden", "Finland", "Denmark", "Belgium"],
    "Asia-Pacific": ["Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand"],
    "Middle East and North Africa": ["Israel", "Saudi Arabia", "UAE", "United Arab Emirates", "Egypt", "Qatar", "Kuwait", "Iraq"],
    "International Organisations": ["NATO", "EU", "IFU", "UN", "NSPA"]
}


# ======================================================================================
# 5) BASE HELPERS (Supplier + Dates + Region + Designators)
# ======================================================================================

def get_region_for_country(country_name: str) -> str:
    """
    Map a country (or organisation) name to its region.

    Matching is case-insensitive and ignores surrounding whitespace. Common
    US/UK spellings are handled directly; everything else is looked up by exact
    name in GEOGRAPHY_MAPPING.

    Returns:
        Region string (e.g. "Europe"), or "Unknown" for empty, placeholder, or
        unrecognised names.
    """
    if not country_name:
        return "Unknown"

    key = str(country_name).strip().lower()

    if key in ("unknown", "n/a", "not applicable"):
        return "Unknown"
    if key in ("us", "usa", "u.s.", "united states", "united states of america"):
        return "North America"
    if key in ("uk", "u.k.", "britain", "great britain"):
        return "Europe"

    return next(
        (
            region
            for region, names in GEOGRAPHY_MAPPING.items()
            if any(name.lower() == key for name in names)
        ),
        "Unknown",
    )


def calculate_mro_months(start_date_str: str, end_date_text: str, program_type: str) -> str:
    """
    Calculate MRO duration in whole months.

    IMPORTANT RULE:
    - Only compute when program_type == "MRO/Support"
    - For all other program types return "Not Applicable"

    Args:
        start_date_str: signed date from Excel (parsed day-first by pandas)
        end_date_text: completion date text extracted from paragraph (fuzzy-parsed)
        program_type: must be normalized program type

    Returns:
        Whole months between start and end as a string (negative durations clamp
        to "0"), or "Not Applicable" when not MRO, an input is empty, or a date
        cannot be parsed.
    """
    if program_type != "MRO/Support":
        return "Not Applicable"
    if not start_date_str or not end_date_text:
        return "Not Applicable"

    try:
        start = pd.to_datetime(start_date_str, dayfirst=True)
        if pd.isna(start):
            return "Not Applicable"
        # dateutil fills fields missing from the text (day/month/year) from
        # `default`; without one it uses *today*, so "June 2027" gave a different
        # duration depending on the run date. Anchor to the 1st of the start month.
        anchor = start.to_pydatetime().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        end = parser.parse(str(end_date_text), fuzzy=True, default=anchor)
        diff = relativedelta(end, start)
    except (ValueError, TypeError, OverflowError):
        # ValueError: unparseable/out-of-range dates (incl. dateutil ParserError,
        # pandas DateParseError/OutOfBoundsDatetime); TypeError: naive vs tz-aware
        # comparison; OverflowError: absurd numeric dates.
        return "Not Applicable"

    total_months = diff.years * 12 + diff.months
    return str(max(0, total_months))


# ----------------------------- SUPPLIER HELPERS (CRITICAL FIX) -----------------------------

def _extract_supplier_from_description(paragraph: str) -> str:
    """
    Extract supplier name from typical DoD award structure using regex.

    WHY THIS EXISTS:
    - LLM sometimes returns incomplete supplier names ("Lockheed", "Raytheon Co.")
    - Fuzzy match can accidentally map to wrong supplier
    - DoD structure is VERY consistent: first entity before comma is supplier

    Supported patterns:
    1) "Company Name, City, State, is awarded..."
    2) "Company Name, City, State, was awarded..."
    3) "Company Name, City, State, has been awarded..."  (previously unsupported)

    Returns:
        extracted supplier string (text before the first ", " that is followed
        by an award verb, at least 3 characters) OR "Unknown"
    """
    p = _normalize_spaces(paragraph)

    # One pattern for all award verbs. Group 1 is lazy, so it always ends at the
    # first ", " that has an award verb somewhere after it, which is the same
    # result the former separate "is"/"was" patterns produced.
    m = re.match(r"^(.*?),\s+.*?\s+(?:is|was|has been) awarded", p, flags=re.IGNORECASE)
    if m:
        supplier = m.group(1).strip()
        if len(supplier) >= 3:
            return supplier

    return "Unknown"


def get_best_supplier_match(extracted_name: str, paragraph: str = "") -> str:
    """
    Standardise a supplier name conservatively, preferring evidence over guessing.

    Resolution order (first hit wins):
    1) empty / placeholder name            -> "Unknown"
    2) name appears verbatim in paragraph   -> keep the name as given (trimmed)
    3) case-insensitive exact SUPPLIER_LIST -> canonical list spelling
    4) exactly one SUPPLIER_LIST entry is mentioned in the paragraph -> that entry
    5) difflib close match with cutoff 0.75 -> best match
    6) otherwise                            -> keep the name as given (no hallucination)

    The strict ordering keeps fuzzy matching from overriding names that are
    already supported by the paragraph text.
    """
    if not extracted_name:
        return "Unknown"

    candidate = str(extracted_name).strip()
    lowered_candidate = candidate.lower()
    if lowered_candidate in ("unknown", "n/a", "not applicable", ""):
        return "Unknown"

    haystack = str(paragraph or "").lower()

    # Rule 2: the paragraph itself vouches for the extracted text.
    if lowered_candidate in haystack:
        return candidate

    # Rule 3: exact (case-insensitive) hit in the known supplier list.
    canonical = {name.lower(): name for name in SUPPLIER_LIST}
    if lowered_candidate in canonical:
        return canonical[lowered_candidate]

    # Rule 4: a single known supplier named in the paragraph.
    mentioned = list(dict.fromkeys(name for name in SUPPLIER_LIST if name.lower() in haystack))
    if len(mentioned) == 1:
        return mentioned[0]

    # Rule 5/6: strong fuzzy match, else leave untouched.
    close = difflib.get_close_matches(candidate, SUPPLIER_LIST, n=1, cutoff=0.75)
    return close[0] if close else candidate


def normalize_program_type(program_type_raw: str) -> str:
    """
    Normalize program type into strict allowed values.

    IMPORTANT:
    - NEVER output "MRO"
    - If maintenance/sustainment/support -> MUST be "MRO/Support"
    - Canonical values are matched case-insensitively ("procurement",
      "rdt&e", "OTHER SERVICE" ...); previously these fell through to
      "Other Service" because only the exact spelling was accepted.

    Returns:
        Strict program type string from ALLOWED_PROGRAM_TYPES
        ("Other Service" when nothing matches).
    """
    key = str(program_type_raw or "").strip().lower()

    prog_type_map = {
        "mro": "MRO/Support",
        "support": "MRO/Support",
        "mro/support": "MRO/Support",
        "maintenance": "MRO/Support",
        "sustainment": "MRO/Support",
        "logistics": "MRO/Support",
        "repair": "MRO/Support",
        "spares": "MRO/Support",
        "procurements": "Procurement",
        "rdte": "RDT&E",
        "r&d": "RDT&E"
    }

    # Synonyms first; every mapped value is already a canonical program type.
    if key in prog_type_map:
        return prog_type_map[key]

    # Case-insensitive match against the canonical set, returning its exact spelling.
    canonical_by_lower = {t.lower(): t for t in ALLOWED_PROGRAM_TYPES}
    return canonical_by_lower.get(key, "Other Service")


# ----------------------------- DESIGNATORS + PILOTING RULES -----------------------------

# Regexes for common U.S. defense designators (applied with re.IGNORECASE).
# - All groups are non-capturing so re.findall() returns the whole designator;
#   a capturing group made it return only the prefix ("AIM" instead of "AIM-9").
# - (?-i:[A-Z])? accepts one *uppercase* variant letter (AIM-9X, F-35A, MQ-9B)
#   even under IGNORECASE, so a lowercase plural such as "F-35s" is not taken as
#   a variant suffix. Previously "\d+\b" rejected every suffixed designator.
DESIGNATOR_PATTERNS = [
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    r"\bF-\d+(?-i:[A-Z])?\b", r"\bB-\d+(?-i:[A-Z])?\b", r"\bC-\d+(?-i:[A-Z])?\b", r"\bA-\d+(?-i:[A-Z])?\b",
    r"\bMQ-\d+(?-i:[A-Z])?\b", r"\bRQ-\d+(?-i:[A-Z])?\b",
    r"\bAN\/[A-Z0-9\-]+\b",
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+(?-i:[A-Z])?\b",
]


def extract_designators(text: str, patterns: Optional[List[str]] = None) -> List[str]:
    """
    Extract defense platform designators from paragraph.

    Examples:
    - DDG-51, CVN-78, SSN-774
    - MQ-9, RQ-4
    - AIM-9X, SM-6
    - AN/APY-10

    Args:
        text: paragraph to scan (None/empty -> []).
        patterns: regexes to apply (case-insensitively); defaults to
            DESIGNATOR_PATTERNS. Capturing groups inside a pattern are harmless:
            the whole match is always used.

    Returns:
        list of unique designator tokens, uppercased, in first-seen order.
        Simple letter+number forms are canonicalised to "PREFIX-NUMBER", so
        "DDG 51", "DDG51" and "DDG-51" all yield "DDG-51".
    """
    text = str(text or "")
    if patterns is None:
        patterns = DESIGNATOR_PATTERNS

    found = []
    for pat in patterns:
        # finditer + group(0) returns the full match; re.findall would return only
        # the captured group for patterns like r"\b(AIM|SM)-\d+\b" ("AIM", not "AIM-9").
        found.extend(m.group(0) for m in re.finditer(pat, text, flags=re.IGNORECASE))

    cleaned = []
    for token in found:
        token = token.upper().replace("--", "-")
        simple = re.fullmatch(r"([A-Z]+)[-\s]*(\d+)", token)
        if simple:
            # One canonical spelling so de-duplication treats "DDG 51"/"DDG-51" as the same.
            token = f"{simple.group(1)}-{simple.group(2)}"
        else:
            token = token.replace(" ", "")
        cleaned.append(token)

    # Order-preserving de-duplication.
    return list(dict.fromkeys(cleaned))


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """
    Deterministic piloting classification.

    WHY:
    - System Piloting is often wrongly inferred by LLM
    - Designators MQ- / RQ- strongly indicate UAV/uncrewed systems

    Precedence: uncrewed cues (MQ-/RQ- designators, or the substrings "unmanned",
    "uav", "drone", "autonomous") win over crewed cues (naval hull designators,
    or the word "USS" followed by whitespace).

    Output:
        "Crewed" | "Uncrewed" | "Not Applicable"
    """
    t = str(text).lower()

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if any(k in t for k in ["unmanned", "uav", "drone", "autonomous"]):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    # Word boundary so "discuss ", "truss ", "fuss " are not mistaken for the ship prefix "USS ".
    if re.search(r"\buss\s", t):
        return "Crewed"

    return "Not Applicable"


# ======================================================================================
# 6) ENHANCED SPLIT ENGINE
# ======================================================================================

def parse_operator_quantity_allocations(paragraph: str) -> List[Dict[str, str]]:
    """
    Detect quantity allocations by operator.

    Example:
      "212 for the Navy, 1,187 for the Air Force, and 84 for FMS customers"

    Recognised forms (case-insensitive):
      - "<qty> for the [U.S.] Navy|Air Force|Army|Marine Corps"       -> B2G
      - "<qty> for [a|an] FMS ..." / "<qty> for Foreign Military Sales [(FMS)] customer(s)" -> G2G
    Quantities may use thousands separators ("1,200"); commas are removed.

    Returns:
      [
        {"operator": "Navy", "quantity": "212", "g2g_b2g": "B2G"},
        {"operator": "Foreign Assistance", "quantity": "84", "g2g_b2g": "G2G"}
      ]
      Duplicates (same operator, quantity and type) are dropped; order is first-seen.
    """
    text = str(paragraph or "")
    allocations = []

    # Plain digits or comma-grouped thousands. The leading \b keeps a match from
    # starting mid-number (the old r"(\d+)" read "1,200" as "200").
    qty = r"\b(\d{1,3}(?:,\d{3})+|\d+)"

    pattern = qty + r"\s+for\s+the\s+(?:U\.S\.\s+)?(Navy|Air Force|Army|Marine Corps)"
    for q, op in re.findall(pattern, text, flags=re.IGNORECASE):
        allocations.append({"operator": op.title(), "quantity": q.replace(",", ""), "g2g_b2g": "B2G"})

    fms_pattern = qty + (
        r"\s+for\s+(?:an?\s+)?"
        r"(?:Foreign Military Sales\s*(?:\(FMS\)\s*)?customers?|FMS)"
    )
    for q in re.findall(fms_pattern, text, flags=re.IGNORECASE):
        allocations.append({"operator": "Foreign Assistance", "quantity": q.replace(",", ""), "g2g_b2g": "G2G"})

    # Order-preserving de-duplication keyed on the full allocation.
    unique = {}
    for a in allocations:
        unique.setdefault((a["operator"], a["quantity"], a["g2g_b2g"]), a)
    return list(unique.values())


def parse_fms_countries(paragraph: str) -> List[str]:
    """
    Extract FMS customer countries list.

    Looks for patterns like:
      "governments of Australia, Bahrain, Belgium..."
    The list ends at ". ", " Work will be performed", " Fiscal", " This contract",
    or the end of the text; items are split on commas and the word "and".

    Returns:
      ["Australia", "Bahrain", ...], de-duplicated case-insensitively in first-seen
      order. Two-letter uppercase codes such as "UK"/"US" are kept.
    """
    text = str(paragraph or "")

    m = re.search(
        r"governments of (.+?)(?:\.\s| Work will be performed| Fiscal| This contract|$)",
        text,
        flags=re.IGNORECASE | re.DOTALL
    )
    if not m:
        return []

    final = []
    seen = set()
    for raw in re.split(r",|\band\b", m.group(1)):
        c = raw.strip()
        # When the list runs to the end of the text, the closing sentence period is
        # captured ("Japan."); drop it unless the token is an abbreviation ("U.K.").
        if c.endswith(".") and "." not in c[:-1]:
            c = c[:-1].rstrip()
        # Normal names are 3-40 chars; also accept 2-letter uppercase codes, which
        # the old `len > 2` filter silently dropped.
        is_code = len(c) == 2 and c.isalpha() and c.isupper()
        if not (is_code or 2 < len(c) <= 40):
            continue
        key = c.lower()
        if key not in seen:
            seen.add(key)
            final.append(c)

    return final


def parse_multiple_suppliers(paragraph: str) -> List[str]:
    """
    Detect multi-supplier mention using supplier taxonomy list.

    Example:
    - "Lockheed Martin and Raytheon..."
    - "multiple awardees include..."

    Returns:
      list of supplier names from SUPPLIER_LIST found in paragraph.
      Returns [] if only one or none.
    """
    text = str(paragraph or "")
    lower = text.lower()

    if " and " not in lower and "," not in lower:
        return []

    candidates = []
    for supplier in SUPPLIER_LIST:
        if supplier.lower() in lower:
            candidates.append(supplier)

    candidates = list(dict.fromkeys(candidates))
    return candidates if len(candidates) >= 2 else []


def parse_multiple_values(paragraph: str) -> List[str]:
    """
    Detect distinct dollar amounts in a paragraph.

    Examples:
    - "$328,156,454"
    - "ceiling value of $500 million"
    - "base value $20 million and option value $10 million"

    Returns:
      ["328,156,454", "500", ...]: raw numbers only (text after "$"), de-duplicated
      in first-seen order.
    """
    text = str(paragraph or "")
    money_pattern = r"\$([\d,]+(?:\.\d+)?)"

    vals = []
    for raw in re.findall(money_pattern, text):
        # DoD announcements write "$328,156,454, firm-fixed-price ..."; the greedy
        # [\d,]+ swallows that trailing comma, which made the same amount count as
        # a second value (and triggered a spurious "multiple values" split reason).
        val = raw.rstrip(",")
        if val:
            vals.append(val)

    return list(dict.fromkeys(vals))


def split_rows_engine(base_row: dict, paragraph: str) -> List[dict]:
    """
    MASTER SPLIT ENGINE

    PURPOSE:
    - Convert 1 extracted base_row into 1..N final output rows
      depending on split conditions.

    SUPPORTED SPLITS:
    1) Multi supplier mentions (if paragraph contains 2+ known suppliers)
    2) Operator quantity allocation (Navy/Air Force + qty)
    3) FMS multi-country split (only applied for G2G rows)
    4) Multi financial values -> added only as Value Note (no overwrite)

    IMPORTANT:
    - Shared columns remain same
    - Split-driven columns are modified
    - Split Flag + Split Reason are always filled
    - The caller's base_row is never mutated (a shallow copy is used); None is
      treated as an empty row.

    Returns:
        List of row dicts (at least one).
    """
    paragraph = str(paragraph or "")
    # Private shallow copy: previously the no-split path wrote Split Flag/Reason
    # straight into the caller's dict (e.g. LangGraph's final_data).
    base_row = dict(base_row or {})

    allocations = parse_operator_quantity_allocations(paragraph)
    fms_countries = parse_fms_countries(paragraph)
    multi_suppliers = parse_multiple_suppliers(paragraph)
    multi_values = parse_multiple_values(paragraph)

    split_reasons = []

    if multi_suppliers:
        split_reasons.append("Multi-supplier mention found")
    if allocations:
        split_reasons.append("Multi-operator allocation found")
    if fms_countries:
        split_reasons.append("FMS multi-country list found")
    if len(multi_values) >= 2:
        split_reasons.append("Multiple financial values found")

    if not split_reasons:
        base_row["Split Flag"] = "No"
        base_row["Split Reason"] = "No split condition found"
        return [base_row]

    base_reason = " | ".join(split_reasons)
    rows = [base_row]

    # Supplier split: one row per detected supplier.
    if multi_suppliers:
        new_rows = []
        for r in rows:
            for s in multi_suppliers:
                rr = r.copy()
                rr["Supplier Name"] = s
                rr["Split Flag"] = "Yes"
                rr["Split Reason"] = f"{base_reason} (Supplier split)"
                new_rows.append(rr)
        rows = new_rows

    # Operator split: each row x each operator/quantity allocation.
    if allocations:
        new_rows = []
        for r in rows:
            for alloc in allocations:
                rr = r.copy()
                rr["Customer Operator"] = alloc["operator"]
                rr["Quantity"] = alloc["quantity"]
                rr["G2G/B2G"] = alloc["g2g_b2g"]
                rr["Split Flag"] = "Yes"
                rr["Split Reason"] = f"{base_reason} (Operator/Quantity split)"
                new_rows.append(rr)
        rows = new_rows

    # Multi values: annotate only, never overwrite the extracted value.
    if len(multi_values) >= 2:
        value_note = f"Multiple values detected: {multi_values[:5]}"
        for r in rows:
            existing = str(r.get("Value Note (If Any)") or "").strip()
            # Replace an empty/placeholder note instead of producing the
            # contradictory "Not Applicable | Multiple values detected: ...".
            if existing in ("", "Not Applicable"):
                r["Value Note (If Any)"] = value_note
            else:
                r["Value Note (If Any)"] = f"{existing} | {value_note}"

    # FMS country split: only G2G rows fan out per customer country.
    if fms_countries:
        final_rows = []
        for r in rows:
            if r.get("G2G/B2G") == "G2G":
                for c in fms_countries:
                    rr = r.copy()
                    rr["Customer Country"] = c
                    rr["Customer Region"] = get_region_for_country(c)
                    rr["Split Flag"] = "Yes"
                    rr["Split Reason"] = f"{base_reason} (FMS country split)"
                    final_rows.append(rr)
            else:
                final_rows.append(r)
        rows = final_rows

    # Rows that no split touched still get a flag/reason, unless the base row already had them.
    for r in rows:
        r.setdefault("Split Flag", "Yes")
        r.setdefault("Split Reason", base_reason)

    return rows


# ======================================================================================
# 7) AGENTS / TOOLS
# ======================================================================================

# ---------------------------------- Stage 1: Sourcing ----------------------------------

# Pydantic schema describing the Stage 1 sourcing tool's arguments
# (paragraph, url, date), with a per-field description.
# NOTE(review): sourcing_extractor is declared with @tool("sourcing_extractor")
# and no args_schema=SourcingInput, so this model is not wired to it in the
# visible code. Confirm whether it is used elsewhere. Comments are used here
# instead of a class docstring because pydantic would copy a docstring into the
# model's JSON-schema description.
class SourcingInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    url: str = Field(description="Source URL of the contract announcement/news.")
    date: str = Field(description="Contract date in Excel (string).")


@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """
    STAGE 1 TOOL: SOURCING EXTRACTOR

    GOAL:
    - Build the base skeleton row of the dataset with sourcing fields.
    - These fields MUST remain unchanged even when row split happens later.

    OUTPUT COLUMNS:
    - Description of Contract
    - Source Link(s)
    - Contract Date
    - Reported Date (By SGA)
    - Additional Notes (Internal Only)

    WHY IMPORTANT:
    - Split engine later will duplicate rows.
    - Without this stage, split rows lose traceability to paragraph and URL.
    """
    # NOTE: the docstring above is kept verbatim because LangChain's @tool
    # presumably exposes it as the tool description; implementation notes live
    # in comments instead.
    reported_date = datetime.datetime.now().strftime("%Y-%m-%d")

    lowered = str(paragraph).lower()
    # Split hints take priority over the modification note; otherwise use the default note.
    if "multiple award" in lowered or "split" in lowered:
        notes = "Potential split award detected."
    elif "modification" in lowered:
        notes = "Contract Modification."
    else:
        notes = "Standard extraction."

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": notes,
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": reported_date,
    }


# ---------------------------------- Stage 2: Geography ----------------------------------

# Input schema for the Stage 2 `geography_extractor` tool.
# NOTE(review): not wired as `args_schema=` on the @tool decorator below.
class GeographyInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")  # sent as the user message to the LLM


@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """
    STAGE 2 TOOL: GEOGRAPHY EXTRACTOR

    GOAL:
    - Extract geography context (buyer + supplier location) with one LLM call.

    OUTPUT:
    - Customer Country   ("Unknown" when missing/blank)
    - Customer Operator  ("Unknown" when missing/blank)
    - Supplier Country   ("Unknown" when missing/blank)
    - Customer Region    (get_region_for_country(Customer Country))
    - Supplier Region    (get_region_for_country(Supplier Country))
    - Domestic Content:
        "Unknown"    if either country is unknown
        "Indigenous" if both countries are known and equal (case-insensitive)
        "Imported"   otherwise

    RULES:
    - Customer Operator is the defense branch / customer body, e.g.
      Navy, Air Force, Army, Marine Corps, Foreign Assistance.
    - LLM failures and non-object JSON replies are logged and treated as an
      empty response, so every field falls back to "Unknown".
    """
    sys_prompt = """
Extract: Customer Country, Supplier Country, Customer Operator.

STRICT RULES:
- Return JSON only.
- If not found, return "Unknown".
- Operator examples:
  "Navy", "Air Force", "Army", "Marine Corps", "Foreign Assistance"

Return JSON:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}
"""

    log_block("HUMAN MESSAGE (Stage2 - Geography)", paragraph)

    raw = {}
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": paragraph}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        parsed = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage2 - Geography)", json.dumps(parsed, indent=2))
        if isinstance(parsed, dict):
            raw = parsed
        else:
            # A JSON array/scalar would crash on raw.get(...) below.
            log_block("AI ERROR (Stage2 - Geography)", f"Expected JSON object, got {type(parsed).__name__}")
    except Exception as e:
        raw = {}
        log_block("AI ERROR (Stage2 - Geography)", str(e))

    def _clean(value):
        # Normalize None / blank LLM values to the documented "Unknown" sentinel.
        text = "" if value is None else str(value).strip()
        return text or "Unknown"

    cust = _clean(raw.get("Customer Country"))
    supp = _clean(raw.get("Supplier Country"))
    operator = _clean(raw.get("Customer Operator"))

    cust_key, supp_key = cust.lower(), supp.lower()
    if cust_key == "unknown" or supp_key == "unknown":
        # Two unknown countries compare equal but prove nothing about origin.
        domestic = "Unknown"
    elif cust_key == supp_key:
        domestic = "Indigenous"
    else:
        domestic = "Imported"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": operator,
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": domestic
    }


# ---------------------------------- Stage 3: System Classifier (RAG) ---------------------

# Input schema for the Stage 3 `system_classifier` tool.
# NOTE(review): not wired as `args_schema=` on the @tool decorator below.
class SystemInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")  # also used as the RAG query text


@tool("system_classifier")
def system_classifier(paragraph: str):
    """
    STAGE 3 TOOL: SYSTEM CLASSIFIER (RAG + RULE BOOK + EVIDENCE/REASON)

    GOAL:
    - Identify taxonomy-based defense system classification fields with evidence + reasoning.

    OUTPUT FIELDS:
    - Market Segment (+ Evidence + Reason)
    - System Type (General) (+ Evidence + Reason)
    - System Type (Specific) (+ Evidence + Reason)
    - System Name (General) (+ Evidence + Reason)
    - System Name (Specific) (+ Evidence + Reason)
    - System Piloting (+ Evidence + Reason)
    - Confidence

    HOW:
    - RULE_BOOK entries whose triggers occur in the lower-cased paragraph are
      injected into the system prompt as extra guidance.
    - The top-3 KB matches from the RAG retriever are sent as labeled examples.
      If retrieval fails, the error is logged and classification continues
      without examples instead of aborting the pipeline.
    - "System Piloting" is always overwritten with detect_piloting_rule_based(...),
      computed from regex-extracted designators (the original design intends e.g.
      MQ-/RQ- -> Uncrewed, USS/DDG/CVN -> Crewed).

    STRICT RULES (enforced via prompt):
    - Evidence text must be copied EXACTLY from paragraph if present.
    - If not found -> "Not Found"
    - Return flat JSON only; any dict/list values in the reply are stringified.

    Returns:
        {} for an empty paragraph; otherwise a flat dict. On LLM failure a
        fallback dict with empty classifications, Confidence "Low" and an
        "Error" key is returned.
    """
    paragraph = str(paragraph or "").strip()
    if not paragraph:
        return {}

    log_block("HUMAN MESSAGE (Stage3 - System)", paragraph)

    lower_text = paragraph.lower()
    hints = [
        f"RULE: {v['guidance']}"
        for v in RULE_BOOK.values()
        if any(t in lower_text for t in v["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    designators = extract_designators(paragraph)
    piloting_rule = detect_piloting_rule_based(paragraph, designators)

    def _meta_text(meta, key):
        # KB metadata is built from an Excel sheet; empty cells may surface as
        # None or float NaN. NaN would crash the [:220] slice below and would
        # serialize as a bare NaN token in the prompt.
        value = meta.get(key, "")
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return str(value)

    try:
        rag_hits = retriever.retrieve(paragraph, top_k=3)
    except Exception as e:
        # Retrieval is an enhancement; never let it take down the whole stage.
        log_block("AI ERROR (Stage3 - RAG)", str(e))
        rag_hits = []

    rag_examples = []
    for hit in rag_hits:
        meta = hit.get("meta") or {}
        description = _meta_text(meta, "Description of Contract")
        rag_examples.append({
            # float() so numpy scalar scores stay JSON-serializable.
            "score": round(float(hit["score"]), 4),
            "Market Segment": _meta_text(meta, "Market Segment"),
            "System Type (General)": _meta_text(meta, "System Type (General)"),
            "System Type (Specific)": _meta_text(meta, "System Type (Specific)"),
            "System Name (General)": _meta_text(meta, "System Name (General)"),
            "System Name (Specific)": _meta_text(meta, "System Name (Specific)"),
            "System Piloting": _meta_text(meta, "System Piloting"),
            "Snippet": (description[:220] + "...") if description else ""
        })

    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

REFERENCE TAXONOMY (JSON):
{TAXONOMY_STR}

RULE BOOK OVERRIDES:
{hint_str}

OUTPUT RULES:
- Return ONLY a FLAT JSON object.
- Every value must be a STRING.
- Do NOT return nested objects or lists.
- Evidence must be copied EXACTLY from paragraph.
- If evidence not present, output "Not Found".

Return JSON:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

DESIGNATORS (regex extracted):
{designators if designators else "None"}

RULE BASED PILOTING:
{piloting_rule}

RAG EXAMPLES (top matches):
{json.dumps(rag_examples, indent=2)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage3 - System)", json.dumps(result, indent=2))
        if not isinstance(result, dict):
            raise ValueError(f"Expected a JSON object from the LLM, got {type(result).__name__}")

        # Hard override piloting with the deterministic rule.
        result["System Piloting"] = piloting_rule
        result.setdefault("System Piloting Evidence", "Not Found")
        result.setdefault("System Piloting Reason", "Derived from deterministic piloting rules (designator + keywords).")

        # Ensure flat values (the output row is written to Excel cells).
        for k, v in result.items():
            if isinstance(v, (dict, list)):
                result[k] = str(v)

        return result

    except Exception as e:
        log_block("AI ERROR (Stage3 - System)", str(e))
        return {
            "Market Segment": "",
            "Market Segment Evidence": "Not Found",
            "Market Segment Reason": "",

            "System Type (General)": "",
            "System Type (General) Evidence": "Not Found",
            "System Type (General) Reason": "",

            "System Type (Specific)": "",
            "System Type (Specific) Evidence": "Not Found",
            "System Type (Specific) Reason": "",

            "System Name (General)": "",
            "System Name (General) Evidence": "Not Found",
            "System Name (General) Reason": "",

            "System Name (Specific)": "",
            "System Name (Specific) Evidence": "Not Found",
            "System Name (Specific) Reason": "",

            "System Piloting": piloting_rule,
            "System Piloting Evidence": "Not Found",
            "System Piloting Reason": "Derived from deterministic piloting rules.",

            "Confidence": "Low",
            "Error": str(e)
        }


# ---------------------------------- Stage 4: Contract Extractor --------------------------

# Input schema for the Stage 4 `contract_extractor` tool.
# NOTE(review): not wired as `args_schema=` on the @tool decorator below.
class ContractInfoInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")  # sent to the LLM and used for supplier matching
    contract_date: str = Field(description="Contract signed date (string from Excel).")  # parsed with pd.to_datetime for Signing Month/Year


@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """
    STAGE 4 TOOL: CONTRACT EXTRACTOR (SUPPLIER + FINANCIAL + PROGRAM)

    GOAL:
    Extract supplier, program, financial and timing fields from a paragraph.

    OUTPUT FIELDS:
    - Supplier Name
    - Program Type (via normalize_program_type)
    - Expected MRO Contract Duration (Months) (via calculate_mro_months)
    - Quantity               ("Not Applicable" when missing/blank)
    - Value Certainty        ("Confirmed" when missing/blank)
    - Value (Million)        (3 decimals; "0.000" when unparseable or non-finite)
    - Currency               ("USD$" when missing/blank)
    - Value (USD$ Million)   (same number as Value (Million); see NOTE below)
    - Value Note (If Any)    ("Not Applicable" when missing/blank)
    - G2G/B2G                (normalized to "G2G" or "B2G"; anything else -> "B2G")
    - Signing Month / Signing Year (from contract_date; "Unknown" when unparseable)

    SUPPLIER NAME LOGIC:
    - The LLM's raw supplier name is used when non-blank, otherwise the
      regex fallback _extract_supplier_from_description(paragraph)
      (intended for DoD-style "Company Name, City, ST, is awarded..." text).
    - The candidate is then standardized with
      get_best_supplier_match(candidate, paragraph=paragraph).

    PROGRAM TYPE STRICT RULE (requested in the prompt, then normalized):
    - Must be EXACTLY one of:
      Procurement / Training / MRO/Support / RDT&E / Upgrade / Other Service
    - NEVER output "MRO"

    Returns:
        dict of the fields above, or {"Error": "..."} when the LLM call fails
        or does not return a JSON object.
    """
    import math

    system_instruction = f"""
You are a Defense Contract Financial Analyst.

Return JSON ONLY.

STRICT RULES:
1) raw_supplier_name:
   - Extract EXACT supplier company name as written in paragraph (no guessing).
   - If not found, return "" (empty string).

2) program_type MUST be EXACTLY one of:
   - Procurement
   - Training
   - MRO/Support
   - RDT&E
   - Upgrade
   - Other Service

   IMPORTANT:
   - Do NOT return "MRO"
   - Do NOT return "Support"
   - It MUST be "MRO/Support" exactly for sustainment/maintenance/support/spares.

3) value_certainty MUST be one of:
   - Confirmed
   - Estimated

4) quantity:
   - Extract numeric quantity if present else "Not Applicable"

5) g2g_b2g:
   - "G2G" ONLY if paragraph mentions FMS/Foreign Military Sales
   - Else "B2G"

6) value_million_raw:
   - numeric only, may contain comma or decimal
   - DO NOT include "$"

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

SIGNED DATE:
{contract_date}
"""

    log_block("HUMAN MESSAGE (Stage4 - Contract)", paragraph)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage4 - Contract)", json.dumps(raw, indent=2))
        if not isinstance(raw, dict):
            # Everything below uses raw.get(...); treat a JSON array/scalar as a failed call.
            raise ValueError(f"Expected a JSON object from the LLM, got {type(raw).__name__}")
    except Exception as e:
        log_block("AI ERROR (Stage4 - Contract)", str(e))
        return {"Error": str(e)}

    def _value_or(value, default):
        # The LLM often answers "" instead of omitting a key; treat None and
        # blank strings as "not provided" so documented defaults apply.
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    # ------------------ Supplier Fix: do NOT break supplier name ------------------

    raw_supplier = str(raw.get("raw_supplier_name") or "").strip()
    fallback_supplier = _extract_supplier_from_description(paragraph)
    candidate_supplier = raw_supplier if raw_supplier else fallback_supplier

    final_supplier = get_best_supplier_match(candidate_supplier, paragraph=paragraph)

    # ------------------ Program Type Fix: must be EXACT ------------------

    prog_type = normalize_program_type(raw.get("program_type", ""))

    # ------------------ MRO months only when MRO/Support ------------------

    mro_months = calculate_mro_months(contract_date, raw.get("completion_date_text"), prog_type)

    # ------------------ Value formatting ------------------

    val_formatted = "0.000"
    val_str = str(raw.get("value_million_raw", "0")).replace(",", "").replace("$", "").strip()
    try:
        val_float = float(val_str)
    except ValueError:
        val_float = None
    # float() accepts "nan"/"inf"; those must not leak into the output as text.
    if val_float is not None and math.isfinite(val_float):
        val_formatted = "{:.3f}".format(val_float)

    # ------------------ Signing Month/Year ------------------

    try:
        dt = pd.to_datetime(contract_date)
    except (TypeError, ValueError, OverflowError):
        dt = None
    # pd.to_datetime(None) returns None and pd.to_datetime("") returns NaT;
    # neither supports strftime.
    if dt is None or pd.isna(dt):
        sign_month, sign_year = "Unknown", "Unknown"
    else:
        sign_month = dt.strftime("%B")
        sign_year = str(dt.year)

    # ------------------ G2G/B2G ------------------

    # The split engine compares this field to "G2G" exactly.
    g2g_b2g = str(raw.get("g2g_b2g") or "").strip().upper()
    if g2g_b2g not in ("G2G", "B2G"):
        g2g_b2g = "B2G"

    return {
        "Supplier Name": final_supplier,
        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": _value_or(raw.get("quantity"), "Not Applicable"),
        "Value Certainty": _value_or(raw.get("value_certainty"), "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": _value_or(raw.get("currency_code"), "USD$"),
        # NOTE(review): no FX conversion is applied here; this is only a true
        # USD value when Currency is USD.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": _value_or(raw.get("value_note"), "Not Applicable"),
        "G2G/B2G": g2g_b2g,
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


# ---------------------------------- Stage 5: Split Agent --------------------------------

# Input schema for the Stage 5 `splitter_agent` tool.
# NOTE(review): not wired as `args_schema=` on the @tool decorator below.
class SplitterInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")  # scanned by split_rows_engine
    base_row: dict = Field(description="Final extracted row after Stage1-4.")  # the merged `final_data` row from the graph state


@tool("splitter_agent")
def splitter_agent(paragraph: str, base_row: dict):
    """
    STAGE 5 TOOL: SPLITTER AGENT

    GOAL:
    - Decide, via split_rows_engine(base_row, paragraph), whether one
      extracted row must become several rows.

    WHEN SPLIT HAPPENS (per split_rows_engine):
    - Multiple suppliers in one paragraph
    - Operator-wise quantity allocation (Navy/Air Force split)
    - FMS multi-country list (only for G2G rows)
    - Multiple values -> stored in Value Note

    GUARANTEES:
    - Always returns at least one row: an empty engine result falls back to
      a copy of base_row, so the contract never silently disappears.
    - Rows missing "Split Flag"/"Split Reason" get "No"/"" defaults.
    - On engine failure, returns a COPY of base_row flagged
      Split Flag="Error"; the caller's dict is never mutated.

    OUTPUT:
    Returns JSON:
    {
      "rows": [row1, row2, ...]
    }
    """
    try:
        rows = split_rows_engine(base_row, paragraph)
        if not rows:
            rows = [dict(base_row)]
        for r in rows:
            r.setdefault("Split Flag", "No")
            r.setdefault("Split Reason", "")
        return {"rows": rows}
    except Exception as e:
        fallback_row = dict(base_row)
        fallback_row["Split Flag"] = "Error"
        fallback_row["Split Reason"] = f"Split failed: {str(e)}"
        return {"rows": [fallback_row]}


# ---------------------------------- Stage 6: Quality Validator (Rule-Based) --------------

# Input schema for the Stage 6 rule-based `quality_validator` tool.
# NOTE(review): not wired as `args_schema=` on the @tool decorator below.
class QualityValidatorInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")  # evidence text the supplier is checked against
    row: dict = Field(description="One fully extracted row after splitting.")  # one element of `final_rows`


@tool("quality_validator")
def quality_validator(paragraph: str, row: dict):
    """
    STAGE 6 TOOL: QUALITY VALIDATOR (RULE-BASED)

    GOAL:
    - Detect obviously wrong output rows and flag them.

    THIS VALIDATOR DOES NOT "FIX" DATA.
    It ONLY flags with:
    - QA Status = PASS/FAIL
    - QA Flags = "; "-joined list of issues ("None" when clean)
    - QA Notes = explanation

    VALIDATIONS:
    - Supplier:
       - blank/"Unknown" (case-insensitive) -> "Supplier missing or Unknown"
       - not found in paragraph (full name, or any token of >= 4 chars) -> mismatch flag
    - Program Type:
       - not in ALLOWED_PROGRAM_TYPES -> flag
       - exactly "MRO" (case-insensitive) -> additional explicit flag
    - System sanity:
       - empty Market Segment / System Type (General) / System Name (General) -> flag

    Missing keys, None and float NaN values all count as empty; previously
    str(None) turned them into the non-empty text "None".

    OUTPUT:
    - QA Status
    - QA Flags
    - QA Notes
    """
    def _field(name):
        # Normalize a row value to stripped text, treating None/NaN as empty.
        value = row.get(name, "")
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return str(value).strip()

    paragraph_l = str(paragraph or "").lower()

    supplier = _field("Supplier Name")
    program_type = _field("Program Type")

    flags = []

    # Supplier validation
    if supplier.lower() in ("", "unknown"):
        flags.append("Supplier missing or Unknown")
    elif supplier.lower() not in paragraph_l:
        # Allow a looser check: some significant token of the name appears.
        partial_hit = any(tok.lower() in paragraph_l for tok in supplier.split() if len(tok) >= 4)
        if not partial_hit:
            flags.append("Supplier does not match paragraph evidence")

    # Program type validation
    if program_type not in ALLOWED_PROGRAM_TYPES:
        flags.append("Program Type not in allowed list")
    if program_type.lower() == "mro":
        flags.append("Program Type invalid: must be 'MRO/Support' not 'MRO'")

    # System sanity
    for column in ("Market Segment", "System Type (General)", "System Name (General)"):
        if _field(column) == "":
            flags.append(f"{column} missing")

    status = "PASS" if not flags else "FAIL"

    return {
        "QA Status": status,
        "QA Flags": "; ".join(flags) if flags else "None",
        "QA Notes": "Rule-based QA validation applied."
    }


# ---------------------------------- Stage 6B: LLM Validator (ONLY FAIL rows) -------------

# Input schema for the Stage 6B `llm_fail_validator` tool.
# NOTE(review): not wired as `args_schema=` on the @tool decorator below.
class LLMValidatorInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")  # sent to the LLM auditor
    row: dict = Field(description="Row that failed rule-based QA and requires LLM validation.")  # serialized to JSON in the prompt


@tool("llm_fail_validator")
def llm_fail_validator(paragraph: str, row: dict):
    """
    STAGE 6B TOOL: LLM FAIL VALIDATOR (ONLY RUNS WHEN QA STATUS = FAIL)

    WHY THIS EXISTS:
    - Rule-based QA sometimes produces false positives.
    - An LLM can judge whether a supplier is acceptable even if not literally
      present (example: "Lockheed Martin Corp." vs "Lockheed Martin").

    WHAT IT DOES:
    - Sends the paragraph + extracted row (JSON; non-JSON values such as
      timestamps are stringified) to the LLM auditor.
    - Normalizes the verdict: "QA Status Final" is always exactly "PASS" or
      "FAIL"; any other answer (including "PASS/FAIL") becomes "FAIL".
    - The three text fields are always present as strings ("" when missing).

    OUTPUT:
    {
      "QA Status Final": "PASS" or "FAIL",
      "QA LLM Notes": "short reason",
      "Supplier Fix Suggestion": "...",
      "Program Type Fix Suggestion": "..."
    }
    Any extra keys returned by the LLM are passed through unchanged.

    IMPORTANT:
    - It does NOT overwrite the row automatically.
    - It provides suggestions only for debugging + manual review.
    """
    sys_prompt = """
You are a Defense Contract QA Auditor.

You will receive:
1) Contract paragraph text
2) Extracted output row data

TASK:
- Validate if supplier and program type make sense based on paragraph.
- If supplier seems wrong -> suggest correction (if possible).
- If program type seems wrong -> suggest correction.

STRICT RULES:
- Do NOT hallucinate suppliers that are not evidenced in the paragraph.
- Program Type must be exactly one of:
  Procurement, Training, MRO/Support, RDT&E, Upgrade, Other Service

Return JSON only:
{
  "QA Status Final": "PASS/FAIL",
  "QA LLM Notes": "",
  "Supplier Fix Suggestion": "",
  "Program Type Fix Suggestion": ""
}
"""

    # default=str: rows can carry values json cannot encode natively; this
    # text only feeds the prompt, so a string rendering is sufficient.
    row_json = json.dumps(row, indent=2, default=str)

    user_prompt = f"""
PARAGRAPH:
{paragraph}

ROW:
{row_json}
"""

    log_block("HUMAN MESSAGE (Stage6B - LLM Validator)", user_prompt)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage6B - LLM Validator)", json.dumps(raw, indent=2))
    except Exception as e:
        log_block("AI ERROR (Stage6B - LLM Validator)", str(e))
        return {
            "QA Status Final": "FAIL",
            "QA LLM Notes": f"LLM validator failed: {str(e)}",
            "Supplier Fix Suggestion": "",
            "Program Type Fix Suggestion": ""
        }

    result = dict(raw) if isinstance(raw, dict) else {}

    verdict = str(result.get("QA Status Final") or "").strip().upper()
    result["QA Status Final"] = verdict if verdict in ("PASS", "FAIL") else "FAIL"

    for key in ("QA LLM Notes", "Supplier Fix Suggestion", "Program Type Fix Suggestion"):
        value = result.get(key)
        result[key] = "" if value is None else str(value)

    return result


# ======================================================================================
# 8) LANGGRAPH PIPELINE
# ======================================================================================

class AgentState(TypedDict):
    """
    State container passed between the LangGraph pipeline nodes.

    Fields:
    - input_text: contract paragraph; every stage passes it to its tool
    - input_date: contract/signing date string from the input Excel
    - input_url: source URL string from the input Excel

    - final_data: single merged row; Stage1-Stage4 each return a copy of it
      with their tool's output merged in
    - final_rows: rows produced by Stage5 (split) and then annotated with QA
      fields by Stage6

    - messages: message channel annotated with LangGraph's `add_messages`
      reducer. The stage nodes in this file do not write to it.
    """
    input_text: str  # raw contract paragraph
    input_date: str  # contract date as a string (parsed later with pd.to_datetime in Stage4)
    input_url: str  # announcement/news URL
    final_data: dict  # merged single row built up by Stage1-Stage4
    final_rows: list  # list of row dicts after splitting/QA
    messages: Annotated[List[AnyMessage], add_messages]  # reducer-managed message list (unused by the stages here)


def stage_1_sourcing(state: AgentState):
    """
    Graph node for Stage 1 (sourcing).

    Runs `sourcing_extractor` on the input paragraph, URL and date, then
    merges its traceability fields (description, source link, dates,
    internal notes) into a fresh copy of `final_data`. The split stage later
    copies these fields onto every derived row.
    """
    sourcing_fields = sourcing_extractor.invoke(
        {
            "paragraph": state["input_text"],
            "url": state["input_url"],
            "date": state["input_date"],
        }
    )
    merged_row = {**state.get("final_data", {}), **sourcing_fields}
    return {"final_data": merged_row}


def stage_2_geography(state: AgentState):
    """
    Graph node for Stage 2 (geography).

    Calls `geography_extractor` on the paragraph and overlays the customer /
    supplier country, operator, regions and domestic-content fields onto a
    fresh copy of `final_data`.
    """
    tool_args = {"paragraph": state["input_text"]}
    geography_fields = geography_extractor.invoke(tool_args)
    return {"final_data": {**state.get("final_data", {}), **geography_fields}}


def stage_3_system(state: AgentState):
    """
    Graph node for Stage 3 (system classification).

    Delegates to `system_classifier`, which combines the taxonomy JSON,
    RULE_BOOK hints, RAG examples and a deterministic piloting override,
    then merges the classification fields (values plus Evidence/Reason
    columns) into a fresh copy of `final_data`.
    """
    classification = system_classifier.invoke({"paragraph": state["input_text"]})
    updated = {**state.get("final_data", {}), **classification}
    return {"final_data": updated}


def stage_4_contract(state: AgentState):
    """
    Graph node for Stage 4 (contract / financial extraction).

    Passes the paragraph and contract date to `contract_extractor`. That tool
    handles supplier resolution (regex fallback plus evidence-based matching
    before fuzzy matching), program-type normalization, value formatting and
    signing month/year. Its result is merged into a fresh copy of `final_data`.
    """
    paragraph_text = state["input_text"]
    signed_on = state["input_date"]
    contract_fields = contract_extractor.invoke(
        {"paragraph": paragraph_text, "contract_date": signed_on}
    )
    return {"final_data": {**state.get("final_data", {}), **contract_fields}}


def stage_5_split(state: AgentState):
    """
    NODE Stage5: Split Engine

    PURPOSE:
    - Convert 1 extracted row -> N rows if the paragraph describes multiple
      suppliers, multiple operators/quantity allocations, or multiple FMS
      countries.

    OUTPUT:
    - final_rows: list used downstream for validation + export. It is never
      empty: if the splitter returns no rows (missing key, None or []), the
      unsplit `final_data` row is kept. Without this, `res.get("rows", ...)`
      returned [] and the contract vanished from the output.
    """
    base_row = state.get("final_data", {})
    res = splitter_agent.invoke({
        "paragraph": state["input_text"],
        "base_row": base_row
    })
    rows = res.get("rows") if isinstance(res, dict) else None
    return {"final_rows": rows or [base_row]}


def stage_6_quality_validation(state: AgentState):
    """
    Graph node for Stage 6 (quality validation).

    Each row in `final_rows` is run through the rule-based `quality_validator`
    and its QA fields are merged into a copy of the row.
    - Rows that PASS get "QA Status Final" = "PASS" and "Not Applicable"
      placeholders for the LLM fields.
    - Rows that FAIL are sent, with their rule-based QA fields, to
      `llm_fail_validator` (Stage6B). Its verdict and suggestions are copied
      onto the row, defaulting to "FAIL" / "".

    This keeps broken supplier or program-type extractions from reaching the
    output Excel unflagged.
    """
    text = state["input_text"]
    pass_defaults = {
        "QA Status Final": "PASS",
        "QA LLM Notes": "Not Applicable",
        "Supplier Fix Suggestion": "Not Applicable",
        "Program Type Fix Suggestion": "Not Applicable",
    }

    checked_rows = []
    for source_row in state.get("final_rows", []):
        rule_result = quality_validator.invoke({"paragraph": text, "row": source_row})
        checked = {**source_row, **rule_result}

        if checked.get("QA Status") != "FAIL":
            checked.update(pass_defaults)
        else:
            # Second opinion only for rows the rules rejected.
            verdict = llm_fail_validator.invoke({"paragraph": text, "row": checked})
            checked["QA Status Final"] = verdict.get("QA Status Final", "FAIL")
            checked["QA LLM Notes"] = verdict.get("QA LLM Notes", "")
            checked["Supplier Fix Suggestion"] = verdict.get("Supplier Fix Suggestion", "")
            checked["Program Type Fix Suggestion"] = verdict.get("Program Type Fix Suggestion", "")

        checked_rows.append(checked)

    return {"final_rows": checked_rows}


def _build_pipeline_graph():
    """
    Assemble the linear extraction pipeline:
    START -> Stage1 -> Stage2 -> Stage3 -> Stage4 -> Stage5 -> Stage6 -> END.

    All nodes are registered first, then consecutive pairs are joined by edges.
    Returns the uncompiled StateGraph.
    """
    graph = StateGraph(AgentState)
    ordered_nodes = [
        ("Stage1", stage_1_sourcing),
        ("Stage2", stage_2_geography),
        ("Stage3", stage_3_system),
        ("Stage4", stage_4_contract),
        ("Stage5", stage_5_split),
        ("Stage6", stage_6_quality_validation),
    ]
    for node_name, node_fn in ordered_nodes:
        graph.add_node(node_name, node_fn)

    chain = [START] + [node_name for node_name, _ in ordered_nodes] + [END]
    for source, target in zip(chain, chain[1:]):
        graph.add_edge(source, target)
    return graph


workflow = _build_pipeline_graph()

app = workflow.compile()


# ======================================================================================
# 9) GRAPH VISUALIZATION (OFFLINE SAFE)
# ======================================================================================

def export_workflow_mermaid(app_obj, out_file="workflow.mmd"):
    """
    Write the compiled graph's Mermaid source to a local file.

    Only `app_obj.get_graph().draw_mermaid()` is used, so no external
    rendering service is contacted. That matters where mermaid.ink is
    blocked (corporate networks, Streamlit Cloud).

    Args:
        app_obj: compiled LangGraph application (anything exposing get_graph()).
        out_file: destination path; written as UTF-8, overwriting existing content.

    Returns:
        The `out_file` argument, unchanged.
    """
    mermaid_source = app_obj.get_graph().draw_mermaid()
    with open(out_file, mode="w", encoding="utf-8") as handle:
        handle.write(mermaid_source)
    print(f"‚úÖ Workflow Mermaid saved locally: {out_file}")
    return out_file


# ======================================================================================
# 10) EXCEL HIGHLIGHTING FEATURE
# ======================================================================================

def highlight_evidence_reason_columns(excel_path: str):
    """
    Colour the Evidence and Reason columns of an output workbook in place.

    Columns are selected by substring match on the first-row header text:
      - header contains "Evidence" -> light yellow fill (FFF2CC)
      - header contains "Reason"   -> light blue fill (D9E1F2)
    The fill covers the header and every data row of the active sheet, and
    the matching header cells are made bold. A header containing both words
    ends up blue, because the Reason fill is applied last. The workbook is
    saved back to ``excel_path``.
    """
    wb = load_workbook(excel_path)
    ws = wb.active

    evidence_cols = []
    reason_cols = []
    for col_idx, header_cell in enumerate(ws[1], start=1):
        title = header_cell.value
        if not isinstance(title, str):
            continue
        if "Evidence" in title:
            evidence_cols.append(col_idx)
        if "Reason" in title:
            reason_cols.append(col_idx)

    evidence_fill = PatternFill(start_color="FFF2CC", end_color="FFF2CC", fill_type="solid")
    reason_fill = PatternFill(start_color="D9E1F2", end_color="D9E1F2", fill_type="solid")
    bold_font = Font(bold=True)

    # One pass over header + data rows; Reason is painted after Evidence so
    # it takes precedence on any column matching both.
    for row_idx in range(1, ws.max_row + 1):
        for col_idx in evidence_cols:
            ws.cell(row=row_idx, column=col_idx).fill = evidence_fill
        for col_idx in reason_cols:
            ws.cell(row=row_idx, column=col_idx).fill = reason_fill

    for col_idx in (*evidence_cols, *reason_cols):
        ws.cell(row=1, column=col_idx).font = bold_font

    wb.save(excel_path)
    print("‚úÖ Evidence + Reason columns highlighted successfully.")


# ======================================================================================
# 11) MAIN EXECUTION
# ======================================================================================

if __name__ == "__main__":
    # Script entry point: read the input workbook, run each row through the
    # compiled LangGraph pipeline (`app`), and write every emitted result row
    # to OUTPUT_EXCEL_PATH with a fixed column order, then highlight the
    # Evidence/Reason columns. Any failure is reported with a traceback and
    # the process exits with status 1.
    import sys
    import traceback

    print(f"\nüìå Loading Input File: {INPUT_EXCEL_PATH}")

    # The Mermaid export only documents the pipeline; a failure here (e.g. an
    # unwritable working directory) must not block the actual Excel run.
    try:
        export_workflow_mermaid(app, out_file="workflow.mmd")
    except Exception as mmd_err:
        print(f"WARNING: could not export workflow Mermaid file: {mmd_err}")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        missing_cols = [col for col in required_cols if col not in df_input.columns]
        if missing_cols:
            raise ValueError(
                f"Excel file must contain columns: {required_cols} "
                f"(missing: {missing_cols})"
            )

        total_rows = len(df_input)
        print(f"üöÄ Processing {total_rows} rows...")
        results = []

        # enumerate() gives a 1-based position regardless of the frame's index labels.
        for position, (_, row) in enumerate(df_input.iterrows(), start=1):
            print(f"\nüîπ Row {position}/{total_rows}")

            # Empty (NaN) cells become "", except the date, which falls back
            # to today's date.
            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "final_rows": [],
                "messages": []
            }

            output_state = app.invoke(initial_state)

            # The pipeline may emit several output rows per input ("final_rows");
            # when it emits none, keep the single "final_data" record so every
            # input still produces at least one output row.
            rows = output_state.get("final_rows", [])
            if not rows:
                rows = [output_state.get("final_data", {})]

            results.extend(rows)

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "QA Status", "QA Flags", "QA Notes",
            "QA Status Final", "QA LLM Notes",
            "Supplier Fix Suggestion", "Program Type Fix Suggestion",

            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Split Flag", "Split Reason",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # Fixed output schema: unknown keys are dropped, missing ones become "".
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")
        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        # Highlight Evidence + Reason columns
        highlight_evidence_reason_columns(OUTPUT_EXCEL_PATH)

        print("\n‚úÖ Processing Complete!")
        print(f"üíæ Output File Saved: {OUTPUT_EXCEL_PATH}")
        print(df_final.head(3).to_string(index=False))

    except Exception as e:
        print(f"\n‚ùå ERROR: {e}")
        # Show where it failed and exit non-zero so callers (shells, CI,
        # schedulers) can detect the failure instead of seeing status 0.
        traceback.print_exc()
        sys.exit(1)
