In [None]:
import os
import json
import difflib
import pandas as pd
import datetime
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass
from typing import Annotated, TypedDict, List
import re
import pickle
import faiss

# LangChain / LangGraph Imports
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI

# ==============================================================================
# ✅ 0. RAG RETRIEVER (Single File Implementation)
# ==============================================================================

class SystemKBRetriever:
    """
    Loads FAISS index + metadata created from your KB excel.
    Uses ONLY the contract paragraph to retrieve similar examples.

    ``kb_dir`` must contain ``system_kb.faiss`` (the FAISS index) and
    ``system_kb_meta.pkl`` (pickled per-row metadata, looked up by the ids the
    index returns).

    NOTE(review): ``pickle.load`` can execute arbitrary code - only load KB
    files produced by your own builder script.
    """
    def __init__(self, kb_dir: str):
        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build the KB first using your KB builder script."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded: {len(self.meta)} rows")

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Return up to ``top_k`` KB rows most similar to ``query_text``.

        Returns:
        [
          {"score": float, "meta": {...all 29 cols...}},
          ...
        ]
        An empty list is returned for a blank query, a non-positive ``top_k``
        or an empty index. Hits whose id has no metadata row are skipped.
        """
        query_text = str(query_text).strip()
        if not query_text or top_k <= 0:
            return []

        # Never ask FAISS for more neighbours than the index holds.
        k = min(int(top_k), int(self.index.ntotal))
        if k <= 0:
            return []

        # ✅ Load the embedding model only when first needed (lazy load); the
        # import lives here so sentence-transformers is only imported then.
        # NOTE(review): must be the same model the KB builder used - confirm.
        if not hasattr(self, "embedder"):
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            if idx < 0:  # FAISS pads with -1 when it has no neighbour
                continue
            try:
                meta = self.meta[int(idx)]
            except (IndexError, KeyError):
                # Index and metadata pickle are out of sync: skip, don't crash.
                continue
            results.append({"score": float(score), "meta": meta})

        return results


# ==============================================================================
# 1. CONFIGURATION & FILE PATHS
# ==============================================================================

# ⬇️ UPDATE PATHS HERE ⬇️
# Each path can also be overridden with an environment variable, so the
# notebook can run on another machine without editing the source.
TAXONOMY_PATH = os.environ.get(
    "DEFENSE_TAXONOMY_PATH",
    r'C:\Users\mukeshkr\Desktop\DefenseExtraction\testing\taxonomy.json',
)
SUPPLIERS_PATH = os.environ.get(
    "DEFENSE_SUPPLIERS_PATH",
    r'C:\Users\mukeshkr\Desktop\DefenseExtraction\testing\suppliers.json',
)
INPUT_EXCEL_PATH = os.environ.get(
    "DEFENSE_INPUT_EXCEL_PATH",
    r"C:\Users\mukeshkr\Desktop\DefenseExtraction\data\source_file.xlsx",
)
OUTPUT_EXCEL_PATH = os.environ.get("DEFENSE_OUTPUT_EXCEL_PATH", "Processed_Defense_Data.xlsx")

# ✅ RAG KB Directory (must contain system_kb.faiss + system_kb_meta.pkl)
RAG_KB_DIR = os.environ.get(
    "DEFENSE_RAG_KB_DIR",
    r"C:\Users\mukeshkr\Desktop\DefenseExtraction\testing\system_kb_store",
)

# Setup API Key - prompt when the variable is missing *or* set but empty.
if not os.environ.get("LLMFOUNDRY_TOKEN"):
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

# Shared Client (the API key is sent as "<token>:my-test-project")
client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
    base_url="https://llmfoundry.straive.com/openai/v1/",
)

# ✅ Load retriever once globally (raises FileNotFoundError if the KB is missing)
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)

# --- FILE LOADING HELPERS ---
def load_json_file(filename, default_value):
    """Load JSON from *filename*, returning *default_value* if it cannot be read.

    Args:
        filename: Path of a UTF-8 JSON file.
        default_value: Returned (after printing a warning) when the file is
            missing/unreadable or is not valid JSON.

    Returns:
        The parsed JSON value, or *default_value*.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError covers JSONDecodeError
        # and UnicodeDecodeError. Loading is best-effort by design.
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value
    # Report success only after parsing actually succeeded.
    print(f"‚úÖ Loaded {filename}")
    return data

# 1. Load Taxonomy
# Parsed taxonomy JSON ({} if the file cannot be loaded) plus a compact
# (no-whitespace) JSON string form of it.
raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(',', ':'))

# 2. Load Suppliers
# Canonical supplier names used by get_best_supplier_match; the literal list
# below is only used when suppliers.json cannot be loaded.
# NOTE(review): assumes suppliers.json holds a JSON array of strings - confirm.
SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, [
    "Dell Inc", "Boeing", "Lockheed Martin", "Raytheon Technologies",
    "Northrop Grumman", "L3Harris", "BAE Systems", "General Dynamics"
])

# 3. System Rules
# Keyword triggers -> guidance text injected into the system_classifier prompt.
# system_classifier matches triggers as plain substrings of the lower-cased
# paragraph, so short triggers can fire inside longer words
# (e.g. "round" inside "ground").
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'"
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'"
    },
    "ammunition": {
        "triggers": ["cartridge", "round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'"
    }
}

# 4. Geography Mapping
# Region -> country names. get_region_for_country compares these entries
# case-insensitively against the stripped country name.
GEOGRAPHY_MAPPING = {
    "North America": ["USA", "United States", "US", "Canada", "America"],
    "Europe": ["UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain", "Poland", "Netherlands", "Norway", "Sweden", "Finland", "Denmark", "Belgium"],
    "Asia-Pacific": ["Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand"],
    "Middle East and North Africa": ["Israel", "Saudi Arabia", "UAE", "Egypt", "Qatar", "Kuwait", "Iraq"],
    "International Organisations": ["NATO", "EU", "IFU", "UN", "NSPA"]
}

# ==============================================================================
# 2. HELPER FUNCTIONS
# ==============================================================================

def get_best_supplier_match(extracted_name, suppliers=None):
    """Fuzzy-match an extracted supplier name against the known supplier list.

    Args:
        extracted_name: Supplier name extracted by the LLM (may be None/blank).
        suppliers: Optional list of canonical names; defaults to the
            module-level ``SUPPLIER_LIST``.

    Returns:
        The canonical supplier spelling for a case-insensitive exact match,
        otherwise the closest case-insensitive fuzzy match (difflib ratio
        >= 0.6), otherwise the cleaned extracted name. Empty or placeholder
        input ("unknown", "n/a", "not applicable") yields "Unknown".
    """
    if not extracted_name:
        return "Unknown"

    # Strip *before* the placeholder test so " Unknown " is recognised too.
    clean_name = str(extracted_name).strip()
    if not clean_name or clean_name.lower() in ("unknown", "n/a", "not applicable"):
        return "Unknown"

    if suppliers is None:
        suppliers = SUPPLIER_LIST

    # lower-cased -> canonical spelling; both passes below are case-insensitive.
    # Non-string entries (possible in a hand-edited JSON file) are ignored.
    supplier_map = {s.lower(): s for s in suppliers if isinstance(s, str)}
    key = clean_name.lower()
    if key in supplier_map:
        return supplier_map[key]

    matches = difflib.get_close_matches(key, list(supplier_map), n=1, cutoff=0.6)
    return supplier_map[matches[0]] if matches else clean_name


def calculate_mro_months(start_date_str, end_date_text, program_type):
    """Calculates duration only if Program Type is MRO/Support.

    Args:
        start_date_str: Contract signing date (parsed with ``dayfirst=True``).
        end_date_text: Free-text completion date from the LLM, e.g. "March 2026".
        program_type: Only "MRO/Support" produces a duration.

    Returns:
        Whole months between the two dates as a string (negative spans are
        clamped to "0"), or "Not Applicable" when the program is not MRO or
        either date is missing or cannot be parsed.
    """
    if program_type != "MRO/Support":
        return "Not Applicable"
    if not start_date_str or not end_date_text:
        return "Not Applicable"

    end_text = str(end_date_text).strip()
    if end_text.lower() in ("", "unknown", "n/a", "not applicable"):
        return "Not Applicable"

    try:
        start = pd.to_datetime(start_date_str, dayfirst=True)
        if pd.isna(start):
            return "Not Applicable"
        start = start.to_pydatetime()
        # Components missing from the text (e.g. the day in "March 2026") are
        # taken from the signing date. Without an explicit default dateutil
        # fills them from *today*, so the result depended on the run date.
        anchor = start.replace(hour=0, minute=0, second=0, microsecond=0)
        end = parser.parse(end_text, default=anchor, fuzzy=True)
        diff = relativedelta(end, start)
    except (ValueError, TypeError, OverflowError):
        # ValueError covers pandas/dateutil parse errors and out-of-range
        # dates; TypeError covers naive-vs-aware datetime mixes.
        return "Not Applicable"

    total_months = (diff.years * 12) + diff.months
    return str(max(0, total_months))


def get_region_for_country(country_name):
    """Resolve *country_name* to its region, or "Unknown" when it cannot be resolved.

    Matching ignores case and surrounding whitespace. Common US/UK spellings
    are resolved directly; everything else is looked up in GEOGRAPHY_MAPPING.
    """
    is_placeholder = not country_name or country_name.lower() in ("unknown", "n/a", "not applicable")
    if is_placeholder:
        return "Unknown"

    key = country_name.strip().lower()

    direct_hits = {
        "us": "North America",
        "usa": "North America",
        "united states": "North America",
        "united states of america": "North America",
        "uk": "Europe",
        "britain": "Europe",
        "great britain": "Europe",
    }
    if key in direct_hits:
        return direct_hits[key]

    for region, members in GEOGRAPHY_MAPPING.items():
        if key in {member.lower() for member in members}:
            return region
    return "Unknown"


# ✅ Regex Designator Extractors (for System Name + Piloting)
# Applied case-insensitively. Any group must be non-capturing "(?:...)":
# with a capturing group re.findall returns only the group text
# (e.g. "AIM" instead of "AIM-120").
DESIGNATOR_PATTERNS = [
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    r"\bF-\d+\b", r"\bB-\d+\b", r"\bC-\d+\b", r"\bA-\d+\b",
    r"\bMQ-\d+\b", r"\bRQ-\d+\b",
    r"\bAN\/[A-Z0-9\-]+\b",
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+\b",
]

def extract_designators(text: str):
    """Return the unique designators found in *text*, normalised, in first-seen order.

    Matches are collected pattern by pattern (in DESIGNATOR_PATTERNS order),
    upper-cased, stripped of all whitespace ("DDG 51" -> "DDG51"), with "--"
    collapsed to "-", and de-duplicated.
    """
    text = str(text)
    found = []
    for pat in DESIGNATOR_PATTERNS:
        # group(0) is always the whole match, regardless of groups in the pattern.
        found.extend(m.group(0) for m in re.finditer(pat, text, flags=re.IGNORECASE))
    cleaned = (re.sub(r"\s+", "", f.upper()).replace("--", "-") for f in found)
    # dict keeps insertion order -> order-preserving de-duplication.
    return list(dict.fromkeys(cleaned))

def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """Classify piloting from keywords and designators, without calling the LLM.

    Uncrewed evidence wins over crewed evidence. The rules are checked in this
    order:
      1. MQ-/RQ- designators, or the words unmanned/uav/drone/autonomous
         -> "Uncrewed".
      2. Ship designators (DDG, CVN, SSN, LCS, LPD, LHA, LHD), or the word
         "USS" -> "Crewed".
      3. Otherwise -> "Not Applicable".
    """
    t = str(text).lower()
    designators = designators or []

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if any(k in t for k in ["unmanned", "uav", "drone", "autonomous"]):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    # Whole-word match: a plain substring test for "uss " also fires on
    # words such as "discuss " or "Russ ".
    if re.search(r"\buss\b", t):
        return "Crewed"

    return "Not Applicable"


# ==============================================================================
# 3. TOOL DEFINITIONS (AGENTS)
# ==============================================================================

# --- AGENT 1: SOURCING ---
class SourcingInput(BaseModel):
    # Input fields for the sourcing stage.
    # NOTE(review): not passed to @tool below (no args_schema), so the tool
    # schema is inferred from the function signature instead.
    paragraph: str = Field(description="Contract text.")
    url: str = Field(description="Source URL.")
    date: str = Field(description="Contract Date.")

@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """Stage 1: Prepares Metadata."""
    # The docstring above is used by @tool as the tool description - keep it.
    # This stage passes the inputs through and adds a note plus today's date.
    lowered = paragraph.lower()
    # "split" takes precedence over "modification" when both words appear.
    if "split" in lowered:
        notes = "Split award detected."
    elif "modification" in lowered:
        notes = "Contract Modification."
    else:
        notes = "Standard extraction."

    today = datetime.datetime.now().strftime("%Y-%m-%d")
    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": notes,
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": today,
    }


# --- AGENT 2: GEOGRAPHY ---
class GeographyInput(BaseModel):
    paragraph: str = Field(description="Contract text.")


# Spellings collapsed to one key when comparing customer vs supplier country,
# so "USA" vs "United States" (or "UK" vs "United Kingdom") counts as domestic.
_COUNTRY_ALIASES = {
    "us": "united states",
    "usa": "united states",
    "u.s.": "united states",
    "u.s.a.": "united states",
    "united states of america": "united states",
    "america": "united states",
    "uk": "united kingdom",
    "u.k.": "united kingdom",
    "britain": "united kingdom",
    "great britain": "united kingdom",
}
_PLACEHOLDER_COUNTRIES = {"", "unknown", "n/a", "not applicable"}


def _llm_text(value, default="Unknown") -> str:
    """Coerce a value from the LLM's JSON to a stripped string; *default* if null/blank."""
    if value is None:
        return default
    text = str(value).strip()
    return text if text else default


def _country_key(name: str) -> str:
    """Lower-cased, alias-resolved country name, used only for comparisons."""
    key = name.strip().lower()
    return _COUNTRY_ALIASES.get(key, key)


@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """Stage 2: Geography Logic."""
    # The docstring above is used by @tool as the tool description - keep it.
    # The LLM extracts the countries and operator. Regions come from
    # get_region_for_country. Domestic Content is "Indigenous" when both
    # countries match, "Imported" when they differ, and "Unknown" when either
    # country is unknown.
    sys_prompt = """
Extract: Customer Country, Supplier Country, Customer Operator.
Logic: If 'Navy awarded...', Operator is Navy.
Return JSON:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}
"""
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": paragraph}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
    except Exception:
        # Best-effort: an LLM/network/JSON failure yields all-"Unknown" output.
        raw = {}
    if not isinstance(raw, dict):
        raw = {}

    # The model may answer null/"" - dict.get's default only covers missing keys.
    cust = _llm_text(raw.get("Customer Country"))
    supp = _llm_text(raw.get("Supplier Country"))

    cust_key, supp_key = _country_key(cust), _country_key(supp)
    if cust_key in _PLACEHOLDER_COUNTRIES or supp_key in _PLACEHOLDER_COUNTRIES:
        # Two unknowns must not be reported as "Indigenous".
        domestic = "Unknown"
    elif cust_key == supp_key:
        domestic = "Indigenous"
    else:
        domestic = "Imported"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": _llm_text(raw.get("Customer Operator")),
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": domestic
    }


# --- ✅ AGENT 3: SYSTEM (UPGRADED WITH RAG + Evidence + Reason) ---
class SystemInput(BaseModel):
    paragraph: str = Field(description="Contract text.")


# KB metadata columns copied into each RAG example shown to the LLM.
_RAG_EXAMPLE_FIELDS = (
    "Market Segment",
    "System Type (General)",
    "System Type (Specific)",
    "System Name (General)",
    "System Name (Specific)",
    "System Piloting",
    "Supplier Name",
    "Customer Operator",
)
_DETERMINISTIC_PILOTING_REASON = "Piloting derived using deterministic rules."


def _kb_text(value) -> str:
    """Return a KB metadata cell as text; None and float NaN (empty Excel cells) become ""."""
    if value is None or (isinstance(value, float) and value != value):  # NaN != NaN
        return ""
    return str(value)


def _system_fallback(piloting_rule: str, error: str) -> dict:
    """Low-confidence placeholder result used when the LLM step fails."""
    return {
        "Market Segment": "",
        "System Type (General)": "",
        "System Type (Specific)": "",
        "System Name (General)": "",
        "System Name (Specific)": "",
        "System Piloting": piloting_rule,
        "Market Segment Evidence": "Not Found",
        "System Type (General) Evidence": "Not Found",
        "System Type (Specific) Evidence": "Not Found",
        "System Name (General) Evidence": "Not Found",
        "System Name (Specific) Evidence": "Not Found",
        "System Piloting Evidence": "Not Found",
        "Market Segment Reason": "",
        "System Type (General) Reason": "",
        "System Type (Specific) Reason": "",
        "System Name (General) Reason": "",
        "System Name (Specific) Reason": "",
        "System Piloting Reason": _DETERMINISTIC_PILOTING_REASON,
        "Confidence": "Low",
        "Error": error
    }


@tool("system_classifier")
def system_classifier(paragraph: str):
    """Stage 3: System classification using RAG + Rule Book + Evidence & Reason."""
    # The docstring above is used by @tool as the tool description - keep it.
    # The LLM gets rule-book hints, regex designators, the rule-based piloting
    # and the top-3 RAG examples in one JSON-mode call. "System Piloting" is
    # then always overwritten with the rule-based value. Returns {} for a
    # blank paragraph.
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    lower_text = paragraph.lower()

    # RULE BOOK triggers (plain substring match on the lower-cased text)
    hints = [
        f"RULE: {rule['guidance']}"
        for rule in RULE_BOOK.values()
        if any(t in lower_text for t in rule["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    # Local extraction
    designators = extract_designators(paragraph)

    # Rule-based piloting
    piloting_rule = detect_piloting_rule_based(paragraph, designators)

    # RAG Retrieval - best-effort: an embedding/index failure degrades to
    # "no examples" instead of aborting the whole row.
    rag_error = None
    try:
        rag_hits = retriever.retrieve(paragraph, top_k=3)
    except Exception as e:
        rag_hits, rag_error = [], str(e)

    rag_examples = []
    for hit in rag_hits:
        meta = hit["meta"]
        example = {"score": round(hit["score"], 4)}
        example.update({field: _kb_text(meta.get(field)) for field in _RAG_EXAMPLE_FIELDS})
        # Empty Excel cells arrive as NaN, which cannot be sliced.
        snippet = _kb_text(meta.get("Description of Contract"))
        example["Snippet"] = snippet[:220] + "..." if len(snippet) > 220 else snippet
        rag_examples.append(example)

    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

Use these inputs:
1) RAG Similar Examples (top 3)
2) Rule Book Overrides
3) Extracted Designators

RULE BOOK OVERRIDES:
{hint_str}

OUTPUT RULES:
- Return ONLY a FLAT JSON object.
- Every value must be a STRING.
- Do NOT return nested JSON or lists.
- Evidence must be copied EXACTLY from paragraph.
- If evidence not present, use "Not Found".

Return JSON exactly:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    user_prompt = f"""
INPUT PARAGRAPH:
{paragraph}

DESIGNATORS (regex extracted):
{designators if designators else "None"}

RULE-BASED PILOTING:
{piloting_rule}

RAG SIMILAR EXAMPLES:
{json.dumps(rag_examples, indent=2)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
        if not isinstance(result, dict):
            raise ValueError(f"Expected a JSON object, got {type(result).__name__}")

        # Hard override piloting (best accuracy). If the model proposed a
        # different value, its reason/evidence justify a value we discard,
        # so replace them instead of leaving contradictory columns.
        llm_piloting = str(result.get("System Piloting") or "").strip()
        result["System Piloting"] = piloting_rule
        if llm_piloting.lower() != piloting_rule.lower():
            result["System Piloting Reason"] = _DETERMINISTIC_PILOTING_REASON
            result["System Piloting Evidence"] = "Not Found"
        if not result.get("System Piloting Reason"):
            result["System Piloting Reason"] = _DETERMINISTIC_PILOTING_REASON
        if not result.get("System Piloting Evidence"):
            result["System Piloting Evidence"] = "Not Found"

        # Ensure flat, string-only values (reassigning existing keys while
        # iterating is safe; the dict size does not change).
        for k, v in result.items():
            if v is None:
                result[k] = ""
            elif not isinstance(v, str):
                result[k] = str(v)

        if rag_error:
            result["RAG Error"] = rag_error
        return result

    except Exception as e:
        fallback = _system_fallback(piloting_rule, str(e))
        if rag_error:
            fallback["RAG Error"] = rag_error
        return fallback


# --- AGENT 4: CONTRACT ---
class ContractInfoInput(BaseModel):
    paragraph: str = Field(description="Contract text.")
    contract_date: str = Field(description="Signed date.")


def _llm_value(raw: dict, key: str, default):
    """Return ``raw[key]``, using *default* when it is missing, null or a blank string.

    The prompt template below shows every field as "", so the model often
    answers "" for "no value". A plain ``dict.get(key, default)`` would keep
    that empty string instead of the intended default.
    """
    value = raw.get(key)
    if value is None:
        return default
    if isinstance(value, str):
        value = value.strip()
        return value if value else default
    return value


@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """Stage 4: Extracts Financials, Program Type, and Dates based on strict SOP."""
    # The docstring above is used by @tool as the tool description - keep it.
    # Returns {"Error": ...} if the LLM call/JSON parse fails.

    system_instruction = """
You are a Defense Contract Financial Analyst. Extract data strictly following these SOP rules:

1) Supplier Name: Extract the exact company name text found in paragraph.
2) Program Type: Procurement / Training / MRO/Support / RDT&E / Upgrade / Other Service
3) Value Certainty:
   - Confirmed: fixed price/obligated stated
   - Estimated: ceiling/potential/IDIQ/multi-award
4) Quantity: number or Not Applicable
5) G2G/B2G:
   - G2G only if Foreign Military Sales (FMS)
   - else B2G
6) Completion Date: only needed for MRO duration calc

Return JSON only.
"""

    user_prompt = f"""
Analyze contract:
"{paragraph}"

Signed Date: {contract_date}

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
    except Exception as e:
        return {"Error": str(e)}
    if not isinstance(raw, dict):
        return {"Error": f"Unexpected LLM response type: {type(raw).__name__}"}

    final_supplier = get_best_supplier_match(_llm_value(raw, "raw_supplier_name", None))
    prog_type = _llm_value(raw, "program_type", "Unknown")
    mro_months = calculate_mro_months(
        contract_date, _llm_value(raw, "completion_date_text", None), prog_type
    )

    # Value in millions, formatted with 3 decimals; "0.000" if not numeric.
    val_str = str(_llm_value(raw, "value_million_raw", "0")).replace(",", "").replace("$", "")
    try:
        val_formatted = "{:.3f}".format(float(val_str))
    except ValueError:
        val_formatted = "0.000"

    # NOTE(review): parsed without dayfirst, whereas calculate_mro_months
    # parses the same date with dayfirst=True - confirm which is intended.
    try:
        dt = pd.to_datetime(contract_date)
        sign_month = dt.strftime("%B")
        sign_year = str(dt.year)
    except (ValueError, TypeError, OverflowError):
        # ValueError also covers NaT, whose strftime raises.
        sign_month, sign_year = "Unknown", "Unknown"

    val_note = _llm_value(raw, "value_note", "Not Applicable")
    if "split" in paragraph.lower() and str(val_note).strip().lower() == "not applicable":
        val_note = "Split contract; value distribution unclear."

    return {
        "Supplier Name": final_supplier,
        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": _llm_value(raw, "quantity", "Not Applicable"),
        "Value Certainty": _llm_value(raw, "value_certainty", "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": _llm_value(raw, "currency_code", "USD$"),
        # TODO(review): no FX conversion is applied - for non-USD contracts
        # this repeats the native-currency value.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": val_note,
        "G2G/B2G": _llm_value(raw, "g2g_b2g", "B2G"),
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


# ==============================================================================
# 4. LANGGRAPH PIPELINE
# ==============================================================================

class AgentState(TypedDict):
    """LangGraph state shared by the four pipeline stages (Stage1 -> Stage4)."""
    # Contract paragraph (the "Contract Description" cell in the __main__ loop).
    input_text: str
    # Contract date as a string; the __main__ loop falls back to today's date.
    input_date: str
    # Source URL ("" when the cell is empty in the __main__ loop).
    input_url: str
    # Output columns accumulated so far; each stage returns a merged copy.
    final_data: dict
    # Chat history combined with LangGraph's add_messages reducer. The visible
    # stages never write to it; the __main__ loop seeds it with [].
    messages: Annotated[List[AnyMessage], add_messages]

def _with_update(state: AgentState, update: dict) -> dict:
    """Build a state patch: a copy of ``final_data`` with *update* merged on top."""
    return {"final_data": {**state.get("final_data", {}), **update}}

def stage_1_sourcing(state: AgentState):
    """Stage 1 node: add sourcing metadata (description, link, dates, notes)."""
    payload = {
        "paragraph": state["input_text"],
        "url": state["input_url"],
        "date": state["input_date"],
    }
    return _with_update(state, sourcing_extractor.invoke(payload))

def stage_2_geography(state: AgentState):
    """Stage 2 node: add customer/supplier geography columns."""
    payload = {"paragraph": state["input_text"]}
    return _with_update(state, geography_extractor.invoke(payload))

def stage_3_system(state: AgentState):
    """Stage 3 node: add system classification columns."""
    payload = {"paragraph": state["input_text"]}
    return _with_update(state, system_classifier.invoke(payload))

def stage_4_contract(state: AgentState):
    """Stage 4 node: add supplier, value and date columns."""
    payload = {
        "paragraph": state["input_text"],
        "contract_date": state["input_date"],
    }
    return _with_update(state, contract_extractor.invoke(payload))

# Linear pipeline: START -> Stage1 -> Stage2 -> Stage3 -> Stage4 -> END.
_PIPELINE_STAGES = [
    ("Stage1", stage_1_sourcing),
    ("Stage2", stage_2_geography),
    ("Stage3", stage_3_system),
    ("Stage4", stage_4_contract),
]

workflow = StateGraph(AgentState)
for _stage_name, _stage_fn in _PIPELINE_STAGES:
    workflow.add_node(_stage_name, _stage_fn)

_route = [START, *(stage for stage, _ in _PIPELINE_STAGES), END]
for _src, _dst in zip(_route, _route[1:]):
    workflow.add_edge(_src, _dst)

app = workflow.compile()

# ==============================================================================
# 5. EXECUTION & FORMATTING
# ==============================================================================

if __name__ == "__main__":

    print(f"üìÇ Loading Input File: {INPUT_EXCEL_PATH}...")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        missing_cols = [col for col in required_cols if col not in df_input.columns]
        if missing_cols:
            raise ValueError(
                f"Excel file must contain columns: {required_cols} (missing: {missing_cols})"
            )

        print(f"üöÄ Processing {len(df_input)} rows...")
        results = []
        failed_rows = 0

        for row_no, (_, row) in enumerate(df_input.iterrows(), start=1):
            print(f"   -> Row {row_no}...")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "messages": []
            }

            # One failing row (LLM/network/parse error) must not discard every
            # row processed so far: record the failure and continue.
            try:
                output_state = app.invoke(initial_state)
                results.append(output_state["final_data"])
            except Exception as row_err:
                failed_rows += 1
                print(f"   !! Row {row_no} failed: {row_err}")
                results.append({
                    "Description of Contract": desc,
                    "Source Link(s)": c_url,
                    "Contract Date": c_date,
                    "Additional Notes (Internal Only)": f"Processing failed: {row_err}",
                })

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # Fixed column order; keys not listed (e.g. "Error") are dropped.
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")

        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        print("\n‚úÖ Processing Complete!")
        if failed_rows:
            print(f"   !! {failed_rows} row(s) failed; see 'Additional Notes (Internal Only)'.")
        print(f"üíæ File saved to: {OUTPUT_EXCEL_PATH}")
        print(df_final.head().to_string())

    except Exception as e:
        import traceback
        print(f"\n‚ùå Error: {e}")
        traceback.print_exc()


In [None]:
import os
import json
import pandas as pd
import datetime
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass
from typing import Annotated, TypedDict, List
import re
import pickle
import faiss
from difflib import SequenceMatcher

# LangChain / LangGraph Imports
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI


# ==============================================================================
# ✅ RAG RETRIEVER (Single File Implementation)
# ==============================================================================

class SystemKBRetriever:
    """
    Loads FAISS index + metadata created from your KB excel.
    Uses ONLY the contract paragraph to retrieve similar examples.

    ``kb_dir`` must contain ``system_kb.faiss`` (the FAISS index) and
    ``system_kb_meta.pkl`` (pickled per-row metadata, looked up by the ids the
    index returns).

    NOTE(review): ``pickle.load`` can execute arbitrary code - only load KB
    files produced by your own builder script.
    """

    def __init__(self, kb_dir: str):
        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build the KB first using your KB builder script."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded: {len(self.meta)} rows")

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Return up to ``top_k`` KB rows most similar to ``query_text``.

        Returns:
        [
          {"score": float, "meta": {...all 29 cols...}},
          ...
        ]
        An empty list is returned for a blank query, a non-positive ``top_k``
        or an empty index. Hits whose id has no metadata row are skipped.
        """
        query_text = str(query_text).strip()
        if not query_text or top_k <= 0:
            return []

        # Never ask FAISS for more neighbours than the index holds.
        k = min(int(top_k), int(self.index.ntotal))
        if k <= 0:
            return []

        # Load the embedding model only when first needed (lazy load); the
        # import lives here so sentence-transformers is only imported then.
        # NOTE(review): must be the same model the KB builder used - confirm.
        if not hasattr(self, "embedder"):
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            if idx < 0:  # FAISS pads with -1 when it has no neighbour
                continue
            try:
                meta = self.meta[int(idx)]
            except (IndexError, KeyError):
                # Index and metadata pickle are out of sync: skip, don't crash.
                continue
            results.append({"score": float(score), "meta": meta})

        return results


# ==============================================================================
# ✅ 1. CONFIGURATION & FILE PATHS
# ==============================================================================

# Each path can also be overridden with an environment variable, so the
# notebook can run on another machine without editing the source.
TAXONOMY_PATH = os.environ.get(
    "DEFENSE_TAXONOMY_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\taxonomy.json",
)
SUPPLIERS_PATH = os.environ.get(
    "DEFENSE_SUPPLIERS_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\suppliers.json",
)
INPUT_EXCEL_PATH = os.environ.get(
    "DEFENSE_INPUT_EXCEL_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\source_file.xlsx",
)
OUTPUT_EXCEL_PATH = os.environ.get("DEFENSE_OUTPUT_EXCEL_PATH", "Processed_Defense_Data.xlsx")

# Must contain system_kb.faiss + system_kb_meta.pkl
RAG_KB_DIR = os.environ.get(
    "DEFENSE_RAG_KB_DIR",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\system_kb_store",
)


# ==============================================================================
# ✅ 2. API Setup
# ==============================================================================

# Prompt for the token when the variable is missing *or* set but empty.
if not os.environ.get("LLMFOUNDRY_TOKEN"):
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

# The API key is sent as "<token>:my-test-project".
client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
    base_url="https://llmfoundry.straive.com/openai/v1/",
)

# Loaded once globally; raises FileNotFoundError if the KB files are missing.
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)


# ==============================================================================
# ✅ 3. JSON Loaders
# ==============================================================================

def load_json_file(filename, default_value):
    """Load JSON from *filename*, returning *default_value* if it cannot be read.

    Args:
        filename: Path of a UTF-8 JSON file.
        default_value: Returned (after printing a warning) when the file is
            missing/unreadable or is not valid JSON.

    Returns:
        The parsed JSON value, or *default_value*.
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError covers JSONDecodeError
        # and UnicodeDecodeError. Loading is best-effort by design.
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value
    # Report success only after parsing actually succeeded.
    print(f"‚úÖ Loaded: {filename}")
    return data


raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(",", ":"))


# ✅ IMPORTANT: Supplier list MUST come from suppliers.json
SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, [])
# A JSON object iterates as its keys, so use those as the names; any other
# non-list value (string, number, ...) is unusable as a supplier list.
if isinstance(SUPPLIER_LIST, dict):
    SUPPLIER_LIST = list(SUPPLIER_LIST)
elif not isinstance(SUPPLIER_LIST, list):
    SUPPLIER_LIST = []
# Keep only non-blank string names (stripped), so every entry is a usable,
# lower-caseable supplier name.
SUPPLIER_LIST = [s.strip() for s in SUPPLIER_LIST if isinstance(s, str) and s.strip()]
if not SUPPLIER_LIST:
    raise ValueError("‚ùå suppliers.json loaded 0 suppliers. Please verify the path.")


# ==============================================================================
# ✅ 4. Rule Book + Geography Mapping
# ==============================================================================

# Keyword triggers -> guidance text for the system classification prompt.
# NOTE(review): the consumer is outside this view. In the first cell,
# triggers are matched as plain substrings of the lower-cased paragraph, so
# short triggers such as "round" can fire inside longer words ("ground") - confirm.
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'",
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'",
    },
    "ammunition": {
        "triggers": ["cartridge", "round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'",
    },
}

# Region -> country names.
# NOTE(review): presumably looked up case-insensitively by a region helper
# defined later in this cell (as in the first cell) - confirm.
GEOGRAPHY_MAPPING = {
    "North America": ["USA", "United States", "US", "Canada", "America"],
    "Europe": ["UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain", "Poland", "Netherlands", "Norway", "Sweden", "Finland", "Denmark", "Belgium"],
    "Asia-Pacific": ["Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand"],
    "Middle East and North Africa": ["Israel", "Saudi Arabia", "UAE", "Egypt", "Qatar", "Kuwait", "Iraq"],
    "International Organisations": ["NATO", "EU", "IFU", "UN", "NSPA"],
}


# ==============================================================================
# ✅ 5. HELPER FUNCTIONS (Supplier Matching FIXED ✅✅✅)
# ==============================================================================

def normalize_supplier_text(x: str) -> str:
    """Clean a raw supplier string before matching.

    The steps are:
      1. Fold dash variants to "-".
      2. Collapse whitespace runs to a single space.
      3. Drop everything from the first comma on (usually a location suffix).

    Returns "" for empty or None input.
    """
    if not x:
        return ""
    x = str(x).strip()

    # Normalize unicode dashes: hyphen, non-breaking hyphen, figure dash, en
    # dash, em dash and minus sign. Written as escapes so the file encoding
    # cannot corrupt them. The last two entries are the mis-encoded en/em dash
    # sequences (UTF-8 decoded as Mac Roman); they are folded too, for text
    # that went through that mis-decoding.
    for dash in ("\u2010", "\u2011", "\u2012", "\u2013", "\u2014", "\u2212",
                 "\u201a\u00c4\u00ec", "\u201a\u00c4\u00ee"):
        x = x.replace(dash, "-")

    # remove multiple spaces
    x = re.sub(r"\s+", " ", x).strip()

    # remove extra location after comma
    # ex: "General Dynamics NASSCO - San Diego, San Diego, California"
    x = x.split(",")[0].strip()

    return x


def token_overlap_score(a: str, b: str) -> float:
    """Share of alphanumeric tokens common to *a* and *b*, relative to the larger token set.

    Matching is case-insensitive. Returns 0.0 when either side has no tokens.
    """
    def _tokens(s: str) -> set:
        return set(re.findall(r"[a-z0-9]+", s.lower()))

    left = _tokens(a)
    right = _tokens(b)
    if not (left and right):
        return 0.0
    shared = left.intersection(right)
    return len(shared) / max(len(left), len(right))


def get_best_supplier_match(extracted_supplier: str):
    """
    Snap an extracted supplier name to the closest entry in SUPPLIER_LIST.

    The FINAL supplier name must come from SUPPLIER_LIST (suppliers.json).
    Matching stages (first hit wins):
      1) exact, case-insensitive match                        -> score 1.0
      2) a list entry appears as a whole-token span inside
         the extracted text (longest entry wins)              -> score 0.95
      3) hybrid fuzzy score 0.65 * token overlap
         + 0.35 * difflib.SequenceMatcher ratio,
         accepted only when >= 0.45

    Args:
        extracted_supplier: Supplier text from the LLM or regex fallback.

    Returns:
        (best_supplier_name, best_score). ("Unknown", score) is returned when
        nothing clears the cutoff or the input is blank / "unknown" / "n/a".
    """
    if not extracted_supplier:
        return "Unknown", 0.0

    extracted_supplier = normalize_supplier_text(extracted_supplier)
    low = extracted_supplier.lower()

    if not low or low in ("unknown", "n/a", "not applicable"):
        return "Unknown", 0.0

    # Ignore blank / non-string entries. They would break .lower(), and "" would
    # "contain-match" every input in step 2.
    candidates = [s for s in SUPPLIER_LIST if isinstance(s, str) and s.strip()]

    # 1) Exact match
    supplier_map = {s.lower(): s for s in candidates}
    if low in supplier_map:
        return supplier_map[low], 1.0

    # 2) Containment match (best practical for long extracted text). Token
    #    boundaries are required so a short name (e.g. "Ada") does not match
    #    inside an unrelated word ("Sierra Nevada").
    def _appears_in_text(name: str) -> bool:
        pattern = r"(?<![a-z0-9])" + re.escape(name.lower()) + r"(?![a-z0-9])"
        return re.search(pattern, low) is not None

    containment_hits = [s for s in candidates if _appears_in_text(s)]
    if containment_hits:
        # longest = most specific; max() keeps the first of equally long hits
        return max(containment_hits, key=len), 0.95

    # 3) Hybrid fuzzy match over the supplier list. SequenceMatcher is reached
    #    through `difflib`, because the file's visible imports bring in the
    #    module, not the bare SequenceMatcher name.
    best_name = "Unknown"
    best_score = 0.0
    for s in candidates:
        s_clean = normalize_supplier_text(s)
        seq = difflib.SequenceMatcher(None, low, s_clean.lower()).ratio()
        tok = token_overlap_score(extracted_supplier, s_clean)
        final = (0.65 * tok) + (0.35 * seq)
        if final > best_score:
            best_score = final
            best_name = s

    # strict cutoff: "Unknown" is preferable to a wrong match
    if best_score < 0.45:
        return "Unknown", best_score

    return best_name, best_score


def calculate_mro_months(start_date_str, end_date_text, program_type):
    """
    Return the expected MRO contract duration in whole months, as a string.

    A duration is computed only when ``program_type == "MRO/Support"``. Every
    other case, and any parse failure, yields "Not Applicable".

    Args:
        start_date_str: Contract/signing date (parsed day-first by pandas).
        end_date_text: Free-text completion date, e.g. "September 2027".
        program_type: Program Type label produced by the contract extractor.

    Returns:
        str: a non-negative month count (years * 12 + months; leftover days are
        dropped), or "Not Applicable".
    """
    if program_type != "MRO/Support":
        return "Not Applicable"
    if not start_date_str or not end_date_text:
        return "Not Applicable"

    end_text = str(end_date_text).strip()
    if end_text.lower() in ("", "unknown", "n/a", "not applicable"):
        return "Not Applicable"

    try:
        start = pd.to_datetime(start_date_str, dayfirst=True)
        # Components missing from the end text (e.g. the day in "September 2027")
        # are filled from the start date. dateutil's default fills them from
        # today's date, which made the result depend on the day of the run.
        anchor = datetime.datetime(start.year, start.month, start.day)
        end = parser.parse(end_text, fuzzy=True, default=anchor)

        diff = relativedelta(end, anchor)
        total_months = (diff.years * 12) + diff.months
        return str(max(0, int(total_months)))
    except Exception:
        # Deliberate best-effort: unparseable dates become "Not Applicable".
        # (Exception rather than a bare except, so Ctrl-C still interrupts.)
        return "Not Applicable"


def get_region_for_country(country_name):
    """
    Map a country (or organisation) name to its region label.

    The lookup ignores case and surrounding whitespace and accepts non-string
    input. Common US/UK abbreviations (with or without dots) are recognised
    directly; everything else is looked up in GEOGRAPHY_MAPPING.

    Returns:
        The region name, or "Unknown" for empty, "unknown", "n/a",
        "not applicable" or unrecognised input.
    """
    if not country_name:
        return "Unknown"

    # Coerce first: the LLM may return a non-string, and the sentinel check must
    # see the stripped value.
    clean = str(country_name).strip().lower()
    if clean in ("", "unknown", "n/a", "not applicable"):
        return "Unknown"

    if clean in ("us", "usa", "u.s.", "u.s.a.", "united states", "united states of america"):
        return "North America"
    if clean in ("uk", "u.k.", "britain", "great britain"):
        return "Europe"

    for region, countries in GEOGRAPHY_MAPPING.items():
        if any(str(c).strip().lower() == clean for c in countries):
            return region
    return "Unknown"


# ==============================================================================
# ✅ 6. Designators (System Name + Piloting)
# ==============================================================================

# Regexes for common defense designators. extract_designators applies each one
# with re.findall and re.IGNORECASE, so no pattern may contain a CAPTURING group:
# findall would then return only the group (e.g. "AIM" instead of "AIM-9").
DESIGNATOR_PATTERNS = [
    # Naval hull numbers (DDG-51, CVN 78, ...)
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    # Aircraft, UAV and missile designators may end in ONE uppercase variant
    # letter (F-35A, C-130J, MQ-9B, AIM-9X). (?-i:...) keeps that letter
    # case-sensitive despite re.IGNORECASE, so plurals like "F-16s" are not
    # misread as a variant.
    r"\bF-\d+(?-i:[A-Z])?\b", r"\bB-\d+(?-i:[A-Z])?\b", r"\bC-\d+(?-i:[A-Z])?\b", r"\bA-\d+(?-i:[A-Z])?\b",
    r"\bMQ-\d+(?-i:[A-Z])?\b", r"\bRQ-\d+(?-i:[A-Z])?\b",
    # AN/ equipment designators (AN/APY-10, AN/TPY-2, ...)
    r"\bAN\/[A-Z0-9\-]+\b",
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+(?-i:[A-Z])?\b",
]

def extract_designators(text: str):
    """
    Return the distinct defense designators found in *text*, in first-seen order.

    Every regex in DESIGNATOR_PATTERNS is applied case-insensitively, one
    pattern after another. Each hit is upper-cased, stripped of spaces, and has
    "--" collapsed to "-". Duplicates remaining after that normalisation are
    dropped; the first occurrence is kept.
    """
    source = str(text)
    raw_hits = (
        hit
        for pattern in DESIGNATOR_PATTERNS
        for hit in re.findall(pattern, source, flags=re.IGNORECASE)
    )
    normalised = (hit.upper().replace(" ", "").replace("--", "-") for hit in raw_hits)
    # dict keys keep insertion order, which gives an order-preserving de-duplication
    return list(dict.fromkeys(normalised))


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """
    Deterministic piloting classification used alongside the LLM.

    Checks, in order (first hit wins):
      1) an MQ-/RQ- designator, or the words "unmanned", "uav", "drone",
         "autonomous" anywhere in the text                 -> "Uncrewed"
      2) a ship-hull designator (DDG, CVN, SSN, LCS, LPD, LHA, LHD),
         or the standalone word "USS"                       -> "Crewed"
      3) otherwise                                          -> "Not Applicable"

    Args:
        text: Contract paragraph.
        designators: Output of extract_designators (upper-cased). None is
            treated as an empty list.
    """
    t = str(text).lower()
    designators = designators or []

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if any(k in t for k in ["unmanned", "uav", "drone", "autonomous"]):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    # Match "uss" as a whole word. The plain substring test ("uss " in t) also
    # fired on "discuss ", "fuss " or "Russ " and marked those texts "Crewed".
    if re.search(r"\buss\s", t):
        return "Crewed"

    return "Not Applicable"


# ==============================================================================
# ✅ 7. TOOL DEFINITIONS (Agents)
# ==============================================================================

# --- AGENT 1: SOURCING ---
class SourcingInput(BaseModel):
    paragraph: str = Field(description="Contract text.")
    url: str = Field(description="Source URL.")
    date: str = Field(description="Contract Date.")

@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """
    Build the sourcing / provenance columns for one contract paragraph.

    Args:
        paragraph: Contract text, copied verbatim into "Description of Contract".
        url: Source URL of the announcement.
        date: Contract date as read from the input sheet.

    Returns:
        dict with the description; an internal note ("Split award detected."
        if the text mentions "split", otherwise "Contract Modification." if it
        mentions "modification", otherwise "Standard extraction."); the source
        link; the contract date; and today's local date (YYYY-MM-DD) as the
        reported date.
    """
    today = datetime.datetime.now().strftime("%Y-%m-%d")
    lowered = paragraph.lower()

    if "split" in lowered:
        notes = "Split award detected."
    elif "modification" in lowered:
        notes = "Contract Modification."
    else:
        notes = "Standard extraction."

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": notes,
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": today,
    }


# --- AGENT 2: GEOGRAPHY ---
class GeographyInput(BaseModel):
    # Argument schema mirroring geography_extractor's single parameter.
    # NOTE(review): not referenced as args_schema in the visible code - confirm it is needed.
    paragraph: str = Field(description="Contract text.")  # contract paragraph

def _country_key(name) -> str:
    """Normalise a country name so common aliases compare equal ("USA" == "United States")."""
    key = re.sub(r"\s+", " ", str(name or "")).strip().lower().replace(".", "")
    aliases = {
        "us": "united states",
        "usa": "united states",
        "united states of america": "united states",
        "america": "united states",
        "uk": "united kingdom",
        "britain": "united kingdom",
        "great britain": "united kingdom",
        "uae": "united arab emirates",
    }
    return aliases.get(key, key)


@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """
    Extract customer / supplier geography for a contract paragraph.

    Asks the LLM for Customer Country, Supplier Country and Customer Operator,
    maps both countries to regions with get_region_for_country, and labels
    Domestic Content:
      - "Indigenous" when both countries are the same (aliases such as
        "USA" / "United States" count as the same),
      - "Imported" when they differ,
      - "Unknown" when either country is missing or unknown.

    Args:
        paragraph: Contract text.

    Returns:
        dict with Customer Region/Country/Operator, Supplier Region/Country and
        Domestic Content. An LLM failure degrades to "Unknown" values.
    """
    sys_prompt = """
Extract: Customer Country, Supplier Country, Customer Operator.
Logic: If 'Navy awarded...', Operator is Navy.
Return JSON:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}
"""
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": paragraph}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
    except Exception as e:
        # Best-effort: a failed call degrades to "Unknown" fields instead of
        # aborting the row; the warning makes the failure visible.
        print(f"Warning: geography_extractor LLM call failed ({e}); using 'Unknown'.")
        raw = {}
    if not isinstance(raw, dict):
        raw = {}

    def _field(key: str) -> str:
        # The model may return null, a number or whitespace; normalise to a
        # non-empty string so the .lower()-based logic below cannot crash.
        value = str(raw.get(key) or "").strip()
        return value or "Unknown"

    cust = _field("Customer Country")
    supp = _field("Supplier Country")

    cust_key, supp_key = _country_key(cust), _country_key(supp)
    unknown_keys = {"", "unknown", "n/a", "not applicable"}
    if cust_key in unknown_keys or supp_key in unknown_keys:
        # Two unknowns used to compare equal and were labelled "Indigenous".
        domestic = "Unknown"
    elif cust_key == supp_key:
        domestic = "Indigenous"
    else:
        domestic = "Imported"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": _field("Customer Operator"),
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": domestic
    }


# --- AGENT 3: SYSTEM (RAG + Evidence + Reason) ---
class SystemInput(BaseModel):
    # Argument schema mirroring system_classifier's single parameter.
    # NOTE(review): not referenced as args_schema in the visible code - confirm it is needed.
    paragraph: str = Field(description="Contract text.")  # contract paragraph

def _kb_text(value) -> str:
    """Return a KB metadata cell as a string; None and NaN (empty Excel cells) become ""."""
    if value is None:
        return ""
    if isinstance(value, float) and value != value:  # NaN is the only float unequal to itself
        return ""
    return str(value)


def _rag_examples_from_hits(rag_hits) -> list:
    """Flatten retriever hits into the compact example dicts shown to the LLM."""
    columns = (
        "Market Segment",
        "System Type (General)",
        "System Type (Specific)",
        "System Name (General)",
        "System Name (Specific)",
        "System Piloting",
        "Supplier Name",
        "Customer Operator",
    )
    examples = []
    for hit in rag_hits:
        meta = hit.get("meta")
        if meta is None:
            meta = {}
        example = {"score": round(float(hit.get("score", 0.0)), 4)}
        for col in columns:
            example[col] = _kb_text(meta.get(col, ""))
        # Slicing a NaN description used to raise TypeError and crash the row.
        example["Snippet"] = _kb_text(meta.get("Description of Contract", ""))[:220] + "..."
        examples.append(example)
    return examples


def _system_fallback(piloting_rule: str, error: Exception) -> dict:
    """Result used when the LLM call fails or returns something other than a JSON object."""
    return {
        "Market Segment": "",
        "System Type (General)": "",
        "System Type (Specific)": "",
        "System Name (General)": "",
        "System Name (Specific)": "",
        "System Piloting": piloting_rule,
        "Market Segment Evidence": "Not Found",
        "System Type (General) Evidence": "Not Found",
        "System Type (Specific) Evidence": "Not Found",
        "System Name (General) Evidence": "Not Found",
        "System Name (Specific) Evidence": "Not Found",
        "System Piloting Evidence": "Not Found",
        "Market Segment Reason": "",
        "System Type (General) Reason": "",
        "System Type (Specific) Reason": "",
        "System Name (General) Reason": "",
        "System Name (Specific) Reason": "",
        "System Piloting Reason": "Piloting derived using deterministic rules.",
        "Confidence": "Low",
        "Error": str(error)
    }


@tool("system_classifier")
def system_classifier(paragraph: str):
    """
    Classify the defense system described in a contract paragraph.

    Three signals are combined into one LLM prompt:
      - RULE_BOOK keyword hints,
      - regex-extracted designators,
      - the top-3 similar KB rows from the RAG retriever.

    When the deterministic piloting rule fires ("Crewed" / "Uncrewed"), it
    overrides the LLM's piloting answer. When the rule has no signal, the LLM's
    answer is kept.

    Args:
        paragraph: Contract text.

    Returns:
        A flat dict of string values (value / Evidence / Reason for each
        taxonomy level, System Piloting, Confidence), {} for an empty
        paragraph, or a fallback dict with an "Error" key if the LLM fails.
    """
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    lower_text = paragraph.lower()

    hints = [
        f"RULE: {v['guidance']}"
        for _, v in RULE_BOOK.items()
        if any(t in lower_text for t in v["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    designators = extract_designators(paragraph)
    piloting_rule = detect_piloting_rule_based(paragraph, designators)

    try:
        rag_hits = retriever.retrieve(paragraph, top_k=3)
    except Exception as e:
        # Retrieval improves accuracy but is not required: classify without
        # examples rather than failing the whole row.
        print(f"Warning: RAG retrieval failed ({e}); classifying without KB examples.")
        rag_hits = []
    rag_examples = _rag_examples_from_hits(rag_hits)

    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

Use these inputs:
1) RAG Similar Examples (top 3)
2) Rule Book Overrides
3) Extracted Designators

RULE BOOK OVERRIDES:
{hint_str}

OUTPUT RULES:
- Return ONLY a FLAT JSON object.
- Every value must be a STRING.
- Do NOT return nested JSON or lists.
- Evidence must be copied EXACTLY from paragraph.
- If evidence not present, use "Not Found".

Return JSON exactly:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    user_prompt = f"""
INPUT PARAGRAPH:
{paragraph}

DESIGNATORS (regex extracted):
{designators if designators else "None"}

RULE-BASED PILOTING:
{piloting_rule}

RAG SIMILAR EXAMPLES:
{json.dumps(rag_examples, indent=2)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
    except Exception as e:
        return _system_fallback(piloting_rule, e)

    if not isinstance(result, dict):
        return _system_fallback(
            piloting_rule,
            TypeError(f"LLM returned {type(result).__name__}, expected a JSON object"),
        )

    # Piloting. The rule only recognises a few clear signals (MQ-/RQ- UAVs,
    # uncrewed keywords, ship hulls, "USS"). When it fires, it wins. When it
    # returns "Not Applicable" (no signal), the LLM's answer is kept instead of
    # being overwritten, which previously left a Reason contradicting the value.
    llm_piloting = str(result.get("System Piloting") or "").strip()
    if piloting_rule != "Not Applicable":
        if llm_piloting != piloting_rule or not result.get("System Piloting Reason"):
            result["System Piloting Reason"] = "Piloting derived using deterministic rules."
        result["System Piloting"] = piloting_rule
    elif not llm_piloting:
        result["System Piloting"] = piloting_rule
        if not result.get("System Piloting Reason"):
            result["System Piloting Reason"] = "Piloting derived using deterministic rules."
    if not result.get("System Piloting Evidence"):
        result["System Piloting Evidence"] = "Not Found"

    # Ensure flat string values ("Every value must be a STRING").
    for k, v in list(result.items()):
        if v is None:
            result[k] = ""
        elif not isinstance(v, str):
            result[k] = str(v)

    return result


# --- AGENT 4: CONTRACT (Supplier Match FIXED ✅✅✅) ---
class ContractInfoInput(BaseModel):
    # Argument schema mirroring contract_extractor's parameters.
    # NOTE(review): not referenced as args_schema in the visible code - confirm it is needed.
    paragraph: str = Field(description="Contract text.")  # contract paragraph
    contract_date: str = Field(description="Signed date.")  # signing date string from the input sheet

@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """
    Extract supplier, program, value and signing fields from a contract paragraph.

    The LLM returns raw fields following the SOP below, which are then
    post-processed:
      - the supplier is snapped to SUPPLIER_LIST via get_best_supplier_match;
      - the MRO duration comes from calculate_mro_months;
      - the value is formatted to three decimals ("0.000" if unparseable);
      - Signing Month / Signing Year are derived from contract_date.

    Args:
        paragraph: Contract text.
        contract_date: Signing date string from the input sheet.

    Returns:
        A flat dict of output columns, or {"Error": message} if the LLM call fails.
    """
    system_instruction = """
You are a Defense Contract Financial Analyst. Extract data strictly following these SOP rules:

1) Supplier Name: Extract exact company name as written in paragraph.
2) Program Type: Procurement / Training / MRO/Support / RDT&E / Upgrade / Other Service
3) Value Certainty:
   - Confirmed: fixed price/obligated stated
   - Estimated: ceiling/potential/IDIQ/multi-award
4) Quantity: number or Not Applicable
5) G2G/B2G:
   - G2G only if Foreign Military Sales (FMS)
   - else B2G
6) Completion Date: only needed for MRO duration calc

Return JSON only.
"""

    user_prompt = f"""
Analyze contract:
"{paragraph}"

Signed Date: {contract_date}

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
    except Exception as e:
        return {"Error": str(e)}
    if not isinstance(raw, dict):
        raw = {}

    def pick(key, default):
        # The prompt template lists every key with "", so the model usually
        # returns blanks rather than omitting keys, and dict.get defaults alone
        # would never apply. Treat None and blank strings as missing.
        value = raw.get(key)
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value.strip() if isinstance(value, str) else value

    # `or ""` avoids str(None) turning a null into the supplier name "None".
    raw_supplier = str(raw.get("raw_supplier_name") or "").strip()

    # Fallback supplier extraction when the LLM returns a blank: paragraphs that
    # open with "<supplier ...> is / was / has been / is being awarded".
    if not raw_supplier:
        m = re.search(
            r"^(.*?)\s+(?:is|was|has been|have been|are|were)\s+(?:being\s+)?awarded",
            paragraph,
            flags=re.IGNORECASE,
        )
        if m:
            raw_supplier = m.group(1).strip()

    # FINAL supplier must come from the suppliers.json best match
    final_supplier, supplier_match_score = get_best_supplier_match(raw_supplier)

    prog_type = pick("program_type", "Unknown")
    mro_months = calculate_mro_months(contract_date, raw.get("completion_date_text"), prog_type)

    try:
        val_str = str(raw.get("value_million_raw", "0")).replace(",", "").replace("$", "")
        val_formatted = "{:.3f}".format(float(val_str))
    except (TypeError, ValueError):
        val_formatted = "0.000"

    # NOTE(review): parsed without dayfirst, while calculate_mro_months parses
    # the same string with dayfirst=True - ambiguous dates (05/03/2024) may be
    # read differently by the two; confirm the input format.
    try:
        dt = pd.to_datetime(contract_date)
        sign_month = dt.strftime("%B")
        sign_year = str(dt.year)
    except Exception:
        sign_month, sign_year = "Unknown", "Unknown"

    val_note = pick("value_note", "Not Applicable")
    if "split" in paragraph.lower() and val_note == "Not Applicable":
        val_note = "Split contract; value distribution unclear."

    return {
        # Final output supplier (from suppliers.json)
        "Supplier Name": final_supplier,

        # Debug columns (optional but very useful)
        "Supplier Name Raw (LLM)": raw_supplier,
        "Supplier Match Score": round(float(supplier_match_score), 3),

        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": pick("quantity", "Not Applicable"),
        "Value Certainty": pick("value_certainty", "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": pick("currency_code", "USD$"),
        # NOTE(review): no currency conversion is applied here; for non-USD
        # contracts this repeats the native-currency figure - TODO add FX.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": val_note,
        "G2G/B2G": pick("g2g_b2g", "B2G"),
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


# ==============================================================================
# ✅ 8. LANGGRAPH PIPELINE
# ==============================================================================

class AgentState(TypedDict):
    """Shared LangGraph state for one input row, passed through Stage1 -> Stage4."""
    input_text: str  # contract paragraph ("Contract Description" column)
    input_date: str  # contract date string ("Contract Date" column, or today's date when blank)
    input_url: str  # source URL ("Source URL" column)
    final_data: dict  # output columns accumulated so far; each stage returns a merged copy
    messages: Annotated[List[AnyMessage], add_messages]  # message channel with the add_messages reducer; the stages in this file neither read nor write it

def stage_1_sourcing(state: AgentState):
    """
    LangGraph node 1 - sourcing.

    Runs sourcing_extractor on the row's text, URL and date, and returns a NEW
    final_data dict: the previous columns overlaid with the sourcing columns.
    """
    tool_args = {
        "paragraph": state["input_text"],
        "url": state["input_url"],
        "date": state["input_date"],
    }
    sourcing_columns = sourcing_extractor.invoke(tool_args)
    merged = {**state.get("final_data", {}), **sourcing_columns}
    return {"final_data": merged}

def stage_2_geography(state: AgentState):
    """LangGraph node 2 - geography: overlay geography_extractor's columns onto a copy of final_data."""
    geography_columns = geography_extractor.invoke({"paragraph": state["input_text"]})
    return {"final_data": {**state.get("final_data", {}), **geography_columns}}

def stage_3_system(state: AgentState):
    """LangGraph node 3 - system classification: overlay system_classifier's columns onto a copy of final_data."""
    system_columns = system_classifier.invoke({"paragraph": state["input_text"]})
    return {"final_data": {**state.get("final_data", {}), **system_columns}}

def stage_4_contract(state: AgentState):
    """
    LangGraph node 4 - contract details.

    Runs contract_extractor on the row's text and date, and returns a NEW
    final_data dict with the contract columns overlaid on the earlier ones.
    """
    tool_args = {
        "paragraph": state["input_text"],
        "contract_date": state["input_date"],
    }
    contract_columns = contract_extractor.invoke(tool_args)
    return {"final_data": {**state.get("final_data", {}), **contract_columns}}

def _build_pipeline_graph():
    """
    Build the linear pipeline graph: START -> Stage1 -> Stage2 -> Stage3 -> Stage4 -> END.

    Each stage node merges its tool's columns into ``final_data``.
    """
    graph = StateGraph(AgentState)
    stages = [
        ("Stage1", stage_1_sourcing),
        ("Stage2", stage_2_geography),
        ("Stage3", stage_3_system),
        ("Stage4", stage_4_contract),
    ]
    for node_name, node_fn in stages:
        graph.add_node(node_name, node_fn)

    # Chain the nodes in order, bracketed by the graph's START and END markers.
    route = [START, *(name for name, _ in stages), END]
    for source, target in zip(route, route[1:]):
        graph.add_edge(source, target)
    return graph


workflow = _build_pipeline_graph()
app = workflow.compile()


# ==============================================================================
# ✅ 9. EXECUTION & OUTPUT FORMAT
# ==============================================================================

if __name__ == "__main__":
    import traceback

    print(f"üìÇ Loading Input File: {INPUT_EXCEL_PATH}...")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        if not all(col in df_input.columns for col in required_cols):
            raise ValueError(f"Excel file must contain columns: {required_cols}")

        print(f"üöÄ Processing {len(df_input)} rows...")
        results = []

        # enumerate gives a 1-based row number independent of the DataFrame index.
        for row_no, (_, row) in enumerate(df_input.iterrows(), start=1):
            print(f"   -> Row {row_no}...")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "messages": []
            }

            # Isolate failures per row. Previously one bad row (or a transient
            # API error) aborted the loop, and nothing was written to Excel.
            try:
                output_state = app.invoke(initial_state)
                results.append(output_state["final_data"])
            except Exception as row_err:
                print(f"   !! Row {row_no} failed: {row_err}")
                results.append({
                    "Description of Contract": desc,
                    "Additional Notes (Internal Only)": f"Pipeline error: {row_err}",
                    "Source Link(s)": c_url,
                    "Contract Date": c_date,
                })

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            # Supplier debug columns (optional but recommended)
            "Supplier Name Raw (LLM)",
            "Supplier Match Score",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # reindex enforces column order and fills columns a row never produced.
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")
        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        print("\n‚úÖ Processing Complete!")
        print(f"üíæ File saved to: {OUTPUT_EXCEL_PATH}")
        print(df_final.head().to_string())

    except Exception as e:
        print(f"\n‚ùå Error: {e}")
        # Keep the traceback: the one-line message alone hides where it failed.
        traceback.print_exc()


In [None]:
import os
import re
import json
import difflib
import pickle
import datetime
from typing import Annotated, TypedDict, List, Dict, Any

import pandas as pd
import faiss
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass

# LangGraph / LangChain
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI

# Excel formatting
from openpyxl import load_workbook
from openpyxl.styles import PatternFill, Font


# ==============================================================================
# 0) DEBUG LOGGING HELPERS
# ==============================================================================

def log_block(title: str, content: str):
    """
    Print *content* under a banner headed by *title*.

    The output is a blank line, a 100-character "=" rule, the title, another
    rule, and then the content, each on its own line.

    Used to eyeball LLM extraction decisions row by row while debugging stages
    such as system classification or the split logic.
    """
    rule = "=" * 100
    print("\n" + rule, title, rule, content, sep="\n")


# ==============================================================================
# 1) RAG RETRIEVER (FAISS + Metadata)
# ==============================================================================

class SystemKBRetriever:
    """
    RAG retriever for defense system classification.

    Purpose:
    - Loads the FAISS vector index (system_kb.faiss).
    - Loads the row-aligned metadata (system_kb_meta.pkl).
    - Retrieves the top-k most similar historical examples for a query
      paragraph, using sentence-transformer embeddings.

    The classifier sees these labelled KB rows as examples, which keeps
    taxonomy choices consistent with the KB.

    Output:
    retrieve(query_text) returns:
      [
        {"score": float, "meta": <KB metadata row>},
        ...
      ]

    NOTE(review): whether a higher score means "more similar" depends on the
    index type the KB builder used (inner product vs. L2) - confirm.
    """

    def __init__(self, kb_dir: str, embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.kb_dir = kb_dir
        self.embed_model = embed_model

        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"‚ùå KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build the KB first using your KB builder script."
            )

        print(f"‚úÖ Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"‚úÖ Loading KB Metadata: {meta_path}")
        # SECURITY: pickle.load can execute arbitrary code. Only load a metadata
        # file produced by your own KB builder, never one from an untrusted source.
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"‚úÖ KB Loaded rows: {len(self.meta)}")

        self.embedder = None

    def _lazy_load_embedder(self):
        """
        Load the embedding model on first use only.

        This keeps pipeline start-up fast and defers the model's memory cost
        until the first retrieval call.
        """
        if self.embedder is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer(self.embed_model)

    def retrieve(self, query_text: str, top_k: int = 3):
        """
        Retrieve the top-k semantic matches from the KB.

        Parameters:
        - query_text: input paragraph (a blank query returns []).
        - top_k: number of examples (<= 0 returns []).

        Returns:
        - A list of {"score": float, "meta": <KB row>} dicts, in index order.

        Raises:
        - ValueError if the embedding size differs from the FAISS index
          dimension (the KB was built with a different embedding model).
        """
        query_text = str(query_text).strip()
        if not query_text or top_k <= 0:
            return []

        self._lazy_load_embedder()

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")

        index_dim = getattr(self.index, "d", None)
        if index_dim is not None and q_emb.shape[1] != index_dim:
            raise ValueError(
                f"Embedding dimension {q_emb.shape[1]} from '{self.embed_model}' does not match "
                f"the FAISS index dimension {index_dim}; rebuild the KB with the same model."
            )

        scores, idxs = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            # FAISS pads with -1 when fewer than top_k vectors exist. Also guard
            # against an index that is longer than its metadata list.
            if idx < 0 or idx >= len(self.meta):
                continue
            results.append({"score": float(score), "meta": self.meta[idx]})

        return results


# ==============================================================================
# 2) CONFIGURATION & FILE PATHS
# ==============================================================================

# Input/output locations. The defaults are the original author's workstation
# paths; set an environment variable of the same name to run elsewhere without
# editing the notebook.
TAXONOMY_PATH = os.environ.get(
    "TAXONOMY_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\taxonomy.json",
)
SUPPLIERS_PATH = os.environ.get(
    "SUPPLIERS_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\suppliers.json",
)
INPUT_EXCEL_PATH = os.environ.get(
    "INPUT_EXCEL_PATH",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\source_file.xlsx",
)
OUTPUT_EXCEL_PATH = os.environ.get("OUTPUT_EXCEL_PATH", "Processed_Defense_Data.xlsx")

RAG_KB_DIR = os.environ.get(
    "RAG_KB_DIR",
    r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\system_kb_store",
)


# Set up the API key once: prompt only when it is not already in the environment.
if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

# The key is sent as "<token>:<project>" (format kept from the original code).
# The project name and endpoint can be overridden through the environment.
LLMFOUNDRY_PROJECT = os.environ.get("LLMFOUNDRY_PROJECT", "my-test-project")
LLMFOUNDRY_BASE_URL = os.environ.get("LLMFOUNDRY_BASE_URL", "https://llmfoundry.straive.com/openai/v1/")

client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:{LLMFOUNDRY_PROJECT}',
    base_url=LLMFOUNDRY_BASE_URL,
)

retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)


# ==============================================================================
# 3) LOAD JSON HELPERS
# ==============================================================================

def load_json_file(filename, default_value):
    """
    Load a JSON file, falling back to a default on any failure.

    The taxonomy and supplier list must load reliably, and a missing or
    malformed file should not crash the pipeline.

    Args:
        filename: Path to a UTF-8 JSON file.
        default_value: Returned (with a warning) when the file cannot be read
            or parsed.

    Returns:
        The parsed JSON value, or default_value.
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        # Deliberate best-effort fallback (see docstring).
        print(f"‚ö†Ô∏è Warning: Could not load {filename} ({e}). Using default.")
        return default_value
    # Report success only after parsing succeeded. Previously "Loaded" was
    # printed before json.load and was followed by a warning for bad files.
    print(f"‚úÖ Loaded: {filename}")
    return data


raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
# Compact JSON for the prompt. ensure_ascii=False keeps any non-ASCII taxonomy
# labels readable (and shorter) instead of \u-escaping them.
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(",", ":"), ensure_ascii=False)

_DEFAULT_SUPPLIERS = [
    "Dell Inc", "Boeing", "Lockheed Martin", "Raytheon Technologies",
    "Northrop Grumman", "L3Harris", "BAE Systems", "General Dynamics"
]
_loaded_suppliers = load_json_file(SUPPLIERS_PATH, _DEFAULT_SUPPLIERS)

# The supplier matchers call .lower() on every entry, so keep only non-empty
# strings. A JSON object contributes its keys, which is what iterating it
# yielded before this check.
SUPPLIER_LIST = [
    s.strip()
    for s in (_loaded_suppliers if isinstance(_loaded_suppliers, (list, dict)) else [])
    if isinstance(s, str) and s.strip()
]
if not SUPPLIER_LIST:
    print(f"Warning: no usable supplier names in {SUPPLIERS_PATH}. Using default list.")
    SUPPLIER_LIST = list(_DEFAULT_SUPPLIERS)


# ==============================================================================
# 4) RULE BOOK + GEOGRAPHY
# ==============================================================================

# Keyword -> taxonomy hints injected into the system-classification prompt.
# Triggers are matched as plain substrings of the lower-cased paragraph, so each
# one must be specific enough not to fire inside unrelated words.
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'"
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'"
    },
    "ammunition": {
        # " round" (with a leading space) instead of "round": the bare substring
        # also matched "ground", "around" and "background", so e.g. ground-support
        # equipment received the ammunition hint.
        "triggers": ["cartridge", " round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'"
    }
}

# Region -> names (countries and international bodies) belonging to it.
# get_region_for_country lower-cases both sides before comparing, so casing here
# is cosmetic; every spelling variant to be recognised needs its own entry.
GEOGRAPHY_MAPPING = dict([
    ("North America", [
        "USA", "United States", "US", "United States of America", "Canada", "America",
    ]),
    ("Europe", [
        "UK", "United Kingdom", "Ukraine", "Germany", "France", "Italy", "Spain",
        "Poland", "Netherlands", "Norway", "Sweden", "Finland", "Denmark", "Belgium",
    ]),
    ("Asia-Pacific", [
        "Australia", "Japan", "South Korea", "Taiwan", "India", "Singapore", "New Zealand",
    ]),
    ("Middle East and North Africa", [
        "Israel", "Saudi Arabia", "UAE", "United Arab Emirates", "Egypt", "Qatar", "Kuwait", "Iraq",
    ]),
    ("International Organisations", [
        "NATO", "EU", "IFU", "UN", "NSPA",
    ]),
])


# ==============================================================================
# 5) BASE HELPERS (Supplier + Dates + Region + Designators)
# ==============================================================================

def get_best_supplier_match(extracted_name: str, suppliers=None):
    """
    Standardise an extracted supplier name against the supplier taxonomy.

    Steps:
    1) Blank, "unknown", "n/a" or "not applicable" -> "Unknown".
    2) Case-insensitive exact match against the supplier list.
    3) Case-insensitive fuzzy match (difflib, cutoff 0.6), returning the list's
       canonical spelling.
    4) If there is no match, return the extracted text (whitespace-collapsed).

    Goal:
    - Make the supplier name output match the standard supplier taxonomy.

    Args:
        extracted_name: Supplier name as extracted from the contract text.
        suppliers: Candidate canonical names. Defaults to the module-level
            SUPPLIER_LIST.

    Returns:
        str: the canonical supplier name, "Unknown", or the cleaned extracted name.
    """
    if not extracted_name:
        return "Unknown"

    # Collapse internal whitespace so "Lockheed  Martin" still matches exactly.
    clean_name = " ".join(str(extracted_name).split())
    if clean_name.lower() in ("", "unknown", "n/a", "not applicable"):
        return "Unknown"

    if suppliers is None:
        suppliers = SUPPLIER_LIST

    supplier_map = {}
    for s in suppliers:
        if isinstance(s, str) and s.strip():
            supplier_map[s.lower()] = s

    key = clean_name.lower()
    if key in supplier_map:
        return supplier_map[key]

    # difflib is case-sensitive: compare lower-cased forms, so that e.g.
    # "LOCKHEED MARTIN CORP" still finds "Lockheed Martin".
    matches = difflib.get_close_matches(key, list(supplier_map), n=1, cutoff=0.6)
    return supplier_map[matches[0]] if matches else clean_name


def calculate_mro_months(start_date_str, end_date_text, program_type):
    """
    Calculate the MRO contract duration in months.

    Rule:
    - ONLY computed when program_type == "MRO/Support"; every other case, and
      any parse failure, returns "Not Applicable".

    Args:
        start_date_str: Contract/signing date (parsed day-first by pandas).
        end_date_text: Free-text completion date, e.g. "September 2027".
        program_type: Program Type label.

    Returns:
        str: a non-negative month count (years * 12 + months; leftover days are
        dropped), or "Not Applicable".
    """
    if program_type != "MRO/Support":
        return "Not Applicable"
    if not start_date_str or not end_date_text:
        return "Not Applicable"

    end_text = str(end_date_text).strip()
    if end_text.lower() in ("", "unknown", "n/a", "not applicable"):
        return "Not Applicable"

    try:
        start = pd.to_datetime(start_date_str, dayfirst=True)
        # Components missing from the end text (e.g. the day in "September 2027")
        # are filled from the start date. dateutil's default uses today's date,
        # which made the result depend on the day of the run.
        anchor = datetime.datetime(start.year, start.month, start.day)
        end = parser.parse(end_text, fuzzy=True, default=anchor)

        diff = relativedelta(end, anchor)
        total_months = diff.years * 12 + diff.months
        return str(max(0, int(total_months)))
    except Exception:
        # Deliberate best-effort: unparseable dates become "Not Applicable".
        # (Exception rather than a bare except, so Ctrl-C still interrupts.)
        return "Not Applicable"


def get_region_for_country(country_name):
    """
    Translate a country (or organisation) name into its region label.

    Comparison ignores case and surrounding whitespace. A handful of US/UK
    abbreviations are resolved directly; all other names are looked up in
    GEOGRAPHY_MAPPING. Empty input, "unknown", "n/a", "not applicable" and
    unrecognised names all give "Unknown".
    """
    if not country_name:
        return "Unknown"

    key = str(country_name).strip().lower()
    if key in ("unknown", "n/a", "not applicable"):
        return "Unknown"

    shortcuts = {
        **dict.fromkeys(("us", "usa", "u.s.", "united states", "united states of america"), "North America"),
        **dict.fromkeys(("uk", "u.k.", "britain", "great britain"), "Europe"),
    }
    if key in shortcuts:
        return shortcuts[key]

    return next(
        (
            region
            for region, names in GEOGRAPHY_MAPPING.items()
            if any(name.lower() == key for name in names)
        ),
        "Unknown",
    )


# Regexes for common defense designators. extract_designators applies each one
# with re.findall and re.IGNORECASE, so no pattern may contain a CAPTURING group:
# findall would then return only the group (e.g. "AIM" instead of "AIM-9").
DESIGNATOR_PATTERNS = [
    # Naval hull numbers (DDG-51, CVN 78, ...)
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    # Aircraft, UAV and missile designators may end in ONE uppercase variant
    # letter (F-35A, C-130J, MQ-9B, AIM-9X). (?-i:...) keeps that letter
    # case-sensitive despite re.IGNORECASE, so plurals like "F-16s" are not
    # misread as a variant.
    r"\bF-\d+(?-i:[A-Z])?\b", r"\bB-\d+(?-i:[A-Z])?\b", r"\bC-\d+(?-i:[A-Z])?\b", r"\bA-\d+(?-i:[A-Z])?\b",
    r"\bMQ-\d+(?-i:[A-Z])?\b", r"\bRQ-\d+(?-i:[A-Z])?\b",
    # AN/ equipment designators (AN/APY-10, AN/TPY-2, ...)
    r"\bAN\/[A-Z0-9\-]+\b",
    r"\b(?:AIM|AGM|SM|RIM|MIM)-\d+(?-i:[A-Z])?\b",
]

def extract_designators(text: str):
    """
    Collect distinct defense platform designators mentioned in a paragraph.

    Examples of what DESIGNATOR_PATTERNS targets:
    - DDG-51, CVN-78
    - MQ-9, RQ-4
    - AIM-9X
    - AN/APY-10

    Each pattern is applied case-insensitively, in list order. Hits are
    upper-cased, stripped of spaces, and "--" is collapsed to "-". The result
    keeps the first occurrence of each normalised designator, in discovery order.
    """
    source = str(text)
    raw_hits = (
        hit
        for pattern in DESIGNATOR_PATTERNS
        for hit in re.findall(pattern, source, flags=re.IGNORECASE)
    )
    normalised = (hit.upper().replace(" ", "").replace("--", "-") for hit in raw_hits)
    # dict keys keep insertion order, which gives an order-preserving de-duplication
    return list(dict.fromkeys(normalised))


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    """
    Deterministic piloting classification to reduce model errors.

    Checks, in order (first hit wins):
      1) an MQ-/RQ- designator, or the words "unmanned", "uav", "drone",
         "autonomous" anywhere in the text                 -> "Uncrewed"
      2) a ship-hull designator (DDG, CVN, SSN, LCS, LPD, LHA, LHD),
         or the standalone word "USS"                       -> "Crewed"
      3) otherwise                                          -> "Not Applicable"

    Args:
        text: Contract paragraph.
        designators: Output of extract_designators (upper-cased). None is
            treated as an empty list.
    """
    t = str(text).lower()
    designators = designators or []

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if any(k in t for k in ["unmanned", "uav", "drone", "autonomous"]):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    # Match "uss" as a whole word. The plain substring test ("uss " in t) also
    # fired on "discuss ", "fuss " or "Russ " and marked those texts "Crewed".
    if re.search(r"\buss\s", t):
        return "Crewed"

    return "Not Applicable"


# ==============================================================================
# 6) ENHANCED SPLIT ENGINE (Multi-operator / Multi-country / Multi-supplier / Multi-value)
# ==============================================================================

def parse_operator_quantity_allocations(paragraph: str):
    """
    Detect quantity allocations by operator.

    Example patterns:
      - "212 for the Navy"
      - "1,187 for the Air Force"   (thousands separators allowed)
      - "84 for Foreign Military Sales (FMS) customers"

    Output (quantities are returned without thousands separators):
      [
        {"operator": "Navy", "quantity": "212", "g2g_b2g": "B2G"},
        {"operator": "Foreign Assistance", "quantity": "84", "g2g_b2g": "G2G"}
      ]
    Duplicate (operator, quantity, g2g_b2g) entries are dropped, keeping the
    first occurrence.
    """
    text = str(paragraph)

    # A quantity is a plain integer ("212") or a comma-grouped one ("1,200").
    # The lookbehind rejects digits that continue a longer token: the tail of a
    # dollar amount ("$5,000,000") or of a contract number ("N00019-21-C-0012").
    # A bare (\d+) would read either one as a unit count.
    qty = r"(?<![\w$.,-])(\d{1,3}(?:,\d{3})+|\d+)"

    allocations = []

    pattern = qty + r"\s+for\s+the\s+(Navy|Air Force|Army|Marine Corps)\b"
    for amount, op in re.findall(pattern, text, flags=re.IGNORECASE):
        allocations.append({"operator": op.title(), "quantity": amount.replace(",", ""), "g2g_b2g": "B2G"})

    fms_pattern = qty + (
        r"\s+for\s+(?:Foreign Military Sales\s*\(FMS\)\s*customers"
        r"|FMS\s*customers|a\s*FMS\s*customer|FMS\b)"
    )
    for amount in re.findall(fms_pattern, text, flags=re.IGNORECASE):
        allocations.append({"operator": "Foreign Assistance", "quantity": amount.replace(",", ""), "g2g_b2g": "G2G"})

    # Order-preserving de-duplication keyed on the full allocation.
    unique = {}
    for a in allocations:
        unique.setdefault((a["operator"], a["quantity"], a["g2g_b2g"]), a)
    return list(unique.values())


def parse_fms_countries(paragraph: str):
    """
    Extract the FMS customer country list.

    Looks for:
      'governments of Australia, Bahrain, and the Netherlands...'
      'governments of Norway (5%); Australia (3%) ...'

    The list ends at ". ", " Work will be performed", " Fiscal",
    " This contract", or the end of the text. Items are split on ",", ";"
    and the word "and". Parenthetical notes and a leading "the" are removed.

    Output: ["Australia", "Bahrain", "Netherlands", ...]. Names are
    de-duplicated case-insensitively, keeping first-seen order; items shorter
    than 3 or longer than 40 characters are dropped.
    """
    text = str(paragraph)

    m = re.search(
        r"governments of (.+?)(?:\.\s| Work will be performed| Fiscal| This contract|$)",
        text,
        flags=re.IGNORECASE | re.DOTALL
    )
    if not m:
        return []

    final = []
    seen = set()
    for part in re.split(r"[,;]|\band\b", m.group(1)):
        # Drop notes such as "(10%)" and a leading article ("the Netherlands")
        # so the names line up with the country lookup in get_region_for_country().
        name = re.sub(r"\([^)]*\)", "", part).strip()
        name = re.sub(r"^the\s+", "", name, flags=re.IGNORECASE).strip()
        if not 2 < len(name) <= 40:
            continue
        key = name.lower()
        if key not in seen:
            final.append(name)
            seen.add(key)

    return final


def parse_multiple_suppliers(paragraph: str, suppliers=None):
    """
    Detect multi-supplier contract statements.

    Examples:
    - "Lockheed Martin and Raytheon were awarded..."
    - "Boeing, Northrop Grumman, and General Dynamics..."

    A supplier counts only when its name appears as a whole word or phrase in
    the paragraph (case-insensitive). Names that matched only as part of a
    longer matched name are dropped, so "Raytheon Technologies" does not also
    produce "Raytheon".

    Args:
        paragraph: Contract text.
        suppliers: Optional iterable of canonical supplier names. Defaults to
            the module-level SUPPLIER_LIST.

    Output:
      ["Lockheed Martin", "Raytheon"], or [] when fewer than two suppliers are
      found or the text contains neither " and " nor ",".
    """
    if suppliers is None:
        suppliers = SUPPLIER_LIST
    text = str(paragraph)
    lower = text.lower()

    if " and " not in lower and "," not in lower:
        return []

    # Keyed by lowercased name so case-variant duplicates in the list collapse.
    matched = {}
    for supplier in suppliers:
        key = str(supplier).strip().lower()
        # Cheap substring pre-check, then require word-ish boundaries so short
        # names (e.g. "GE") do not match inside longer words ("general").
        if not key or key in matched or key not in lower:
            continue
        if re.search(r"(?<!\w)" + re.escape(key) + r"(?!\w)", lower):
            matched[key] = supplier

    keys = list(matched)
    # NOTE(review): a name is dropped whenever it is a substring of another
    # matched name, even if it also appears on its own elsewhere in the text.
    candidates = [matched[k] for k in keys if not any(k != other and k in other for other in keys)]

    # If multiple suppliers found -> split required
    return candidates if len(candidates) >= 2 else []


def parse_multiple_values(paragraph: str):
    """
    Detect the financial values written with a "$" sign in a paragraph.

    Examples:
    - "$328,156,454"  -> "328,156,454"
    - "ceiling value of $500 million" -> "500"
    - "base value $20 million and option value $10 million" -> "20", "10"

    Output:
      Raw number strings (commas and decimals kept, unit words dropped),
      de-duplicated in first-seen order, e.g. ["328,156,454", "500", ...].
    """
    text = str(paragraph)

    # Digits with optional ",ddd" groups and an optional decimal part. Unlike
    # [\d,]+ this does not swallow a trailing punctuation comma ("$1,000, and").
    money_pattern = r"\$(\d+(?:,\d{3})*(?:\.\d+)?)"
    vals = re.findall(money_pattern, text)

    # remove duplicates (order preserved)
    return list(dict.fromkeys(vals))


# "<pct>% [for] <Name>". The name starts with a capital and continues word by
# word, but stops before the word "and" or the next "<number>%". Otherwise
# "60% for Company A and 40% for Company B" yields a single match whose name
# swallows "and 40".
SHARE_PERCENT_PATTERN = re.compile(
    r"(\d+(?:\.\d+)?)\s?%\s+(?:for\s+)?"
    r"([A-Z][A-Za-z0-9&.\-]*"
    r"(?:\s+(?!and\b|\d+(?:\.\d+)?\s?%)[A-Za-z0-9&.\-]+)*)"
)


def parse_share_percentages(paragraph: str):
    """
    Detect share splits like:
      '60% for Company A and 40% for Company B'

    Each captured name is standardized with get_best_supplier_match().
    Decimal percentages ("12.5%") are kept intact.

    Output (only when at least two shares are found, otherwise []):
      [
        {"supplier": "Company A", "percentage": "60"},
        {"supplier": "Company B", "percentage": "40"}
      ]
    """
    result = [
        # Try to best-match the supplier name list
        {"supplier": get_best_supplier_match(name.strip()), "percentage": pct}
        for pct, name in SHARE_PERCENT_PATTERN.findall(str(paragraph))
    ]

    # keep only valid multi splits
    return result if len(result) >= 2 else []


def detect_multi_country_presence(paragraph: str, geography_mapping=None):
    """
    Lightweight country scan over the mapping lists to detect multi-country mentions.

    A country counts only when its name appears as a whole word or phrase
    (case-insensitive). This avoids hits such as "India" inside "Indiana",
    "Oman" inside "Romania" or "Niger" inside "Nigeria".

    Args:
        paragraph: Contract text.
        geography_mapping: Optional {region: [country, ...]} mapping. Defaults
            to the module-level GEOGRAPHY_MAPPING.

    Returns:
        Matched country names, as spelled in the mapping, in mapping order
        with no duplicates.
    """
    if geography_mapping is None:
        geography_mapping = GEOGRAPHY_MAPPING
    text = str(paragraph).lower()

    # dict keeps insertion order, so the result is deterministic (a set was not).
    countries_found = {}
    for countries in geography_mapping.values():
        for c in countries:
            key = str(c).strip().lower()
            # Substring pre-check keeps the regex off the common no-match path.
            if not key or key not in text:
                continue
            if re.search(r"(?<!\w)" + re.escape(key) + r"(?!\w)", text):
                countries_found.setdefault(c, None)

    return list(countries_found)


def split_rows_engine(base_row: dict, paragraph: str):
    """
    MASTER SPLIT ENGINE.

    Goal:
    - Takes one extracted row (base_row)
    - Produces 1..N output rows depending on split conditions

    Split triggers supported:
    1) Operator allocation splits (Navy/Air Force/FMS)
    2) Multi-country FMS customer lists (applied only to G2G rows)
    3) Multi supplier mentions
    4) Multiple values inside the same paragraph (recorded in the value note, no new rows)
    5) Percent share splits
    6) Multi-country region scan (sets flag/reason only, no new rows)

    Important:
    - Only split-driving columns change; shared columns are copied unchanged.
    - base_row itself is never modified; every returned row is a copy.
    - When percentage share splits are found they define the supplier rows, and
      the generic multi-supplier split is skipped. Applying both would create
      duplicate rows, because the share split overwrites "Supplier Name" on
      every supplier-split row.

    Returns:
        List of row dicts. Each carries "Split Flag" and "Split Reason".
    """
    paragraph = str(paragraph)
    # Work on a copy so the caller's dict is never mutated.
    base = dict(base_row)

    allocations = parse_operator_quantity_allocations(paragraph)
    fms_countries = parse_fms_countries(paragraph)
    multi_suppliers = parse_multiple_suppliers(paragraph)
    multi_values = parse_multiple_values(paragraph)
    share_splits = parse_share_percentages(paragraph)
    multi_countries = detect_multi_country_presence(paragraph)

    split_reasons = []

    if allocations:
        split_reasons.append("Multi-operator allocation found")
    if fms_countries:
        split_reasons.append("FMS multi-country list found")
    if multi_suppliers:
        split_reasons.append("Multi-supplier mention found")
    if len(multi_values) >= 2:
        split_reasons.append("Multiple financial values found")
    if share_splits:
        split_reasons.append("Percentage share split found")
    if len(multi_countries) >= 2:
        split_reasons.append("Multiple countries detected in text")

    if not split_reasons:
        base["Split Flag"] = "No"
        base["Split Reason"] = "No split condition found"
        return [base]

    # Start with one base row
    rows = [base]
    base_reason = " | ".join(split_reasons)

    # 1) Multi supplier split. Skipped when share splits exist: step 3 assigns
    #    "Supplier Name" from the share list, so splitting here as well would
    #    only multiply identical rows.
    if multi_suppliers and not share_splits:
        new_rows = []
        for r in rows:
            for s in multi_suppliers:
                rr = r.copy()
                rr["Supplier Name"] = s
                rr["Split Flag"] = "Yes"
                rr["Split Reason"] = f"{base_reason} (Supplier split)"
                new_rows.append(rr)
        rows = new_rows

    # 2) Operator split
    if allocations:
        new_rows = []
        for r in rows:
            for alloc in allocations:
                rr = r.copy()
                rr["Customer Operator"] = alloc["operator"]
                rr["Quantity"] = alloc["quantity"]
                rr["G2G/B2G"] = alloc["g2g_b2g"]
                rr["Split Flag"] = "Yes"
                rr["Split Reason"] = f"{base_reason} (Operator/Quantity split)"
                new_rows.append(rr)
        rows = new_rows

    # 3) Share split -> sets Supplier Name and the Value Note
    if share_splits:
        new_rows = []
        for r in rows:
            for sh in share_splits:
                rr = r.copy()
                rr["Supplier Name"] = sh["supplier"]
                rr["Value Note (If Any)"] = f"Share split detected: {sh['percentage']}%"
                rr["Split Flag"] = "Yes"
                rr["Split Reason"] = f"{base_reason} (Share % split)"
                new_rows.append(rr)
        rows = new_rows

    # 4) Multiple financial values: Value (Million) comes from contract_extractor
    #    and is NOT overridden; the detected values are appended to the Value Note
    #    as a cross-reference.
    if len(multi_values) >= 2:
        for r in rows:
            note = r.get("Value Note (If Any)", "Not Applicable")
            r["Value Note (If Any)"] = f"{note} | Multiple values detected: {multi_values[:5]}"

    # 5) FMS country split applied ONLY to G2G rows
    if fms_countries:
        final_rows = []
        for r in rows:
            if r.get("G2G/B2G") == "G2G":
                for c in fms_countries:
                    rr = r.copy()
                    rr["Customer Country"] = c
                    rr["Customer Region"] = get_region_for_country(c)
                    rr["Split Flag"] = "Yes"
                    rr["Split Reason"] = f"{base_reason} (FMS country split)"
                    final_rows.append(rr)
            else:
                final_rows.append(r)
        rows = final_rows

    # Always ensure flags exist (e.g. when only informational triggers fired)
    for r in rows:
        r.setdefault("Split Flag", "Yes")
        r.setdefault("Split Reason", base_reason)

    return rows


# ==============================================================================
# 7) AGENTS / TOOLS
# ==============================================================================

# --- Stage 1: Sourcing ---
class SourcingInput(BaseModel):
    """
    Argument schema describing the inputs of the ``sourcing_extractor`` tool.

    NOTE(review): the ``@tool("sourcing_extractor")`` decorator is not given
    ``args_schema=SourcingInput`` in the visible code, so this model does not
    appear to be wired to the tool. Confirm before relying on it.
    """
    paragraph: str = Field(description="Full contract paragraph/description text.")
    url: str = Field(description="Source URL of the contract announcement/news.")
    date: str = Field(description="Contract date in Excel (string).")

@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """
    Stage 1: SOURCING EXTRACTOR

    Builds the skeleton columns that stay identical across every row when a
    contract is later split:
      - Description of Contract (the raw paragraph)
      - Additional Notes (Internal Only)
      - Source Link(s)
      - Contract Date
      - Reported Date (By SGA), which is today's date as YYYY-MM-DD

    The note is "Split award detected." if the paragraph mentions "split",
    otherwise "Contract Modification." if it mentions "modification",
    otherwise "Standard extraction." (case-insensitive substring checks).
    """
    lowered = str(paragraph).lower()

    # "split" deliberately wins over "modification" when both appear.
    if "split" in lowered:
        notes = "Split award detected."
    elif "modification" in lowered:
        notes = "Contract Modification."
    else:
        notes = "Standard extraction."

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": notes,
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": datetime.datetime.now().strftime("%Y-%m-%d"),
    }


# --- Stage 2: Geography ---
class GeographyInput(BaseModel):
    """
    Argument schema describing the input of the ``geography_extractor`` tool.

    NOTE(review): not passed as ``args_schema`` to ``@tool("geography_extractor")``
    in the visible code. Confirm whether it is used anywhere.
    """
    paragraph: str = Field(description="Full contract paragraph/description text.")

@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """
    Stage 2: GEOGRAPHY EXTRACTOR

    Goal:
    - Identify (via the LLM):
      - Customer Country
      - Customer Operator
      - Supplier Country
    - Derive:
      - Customer Region / Supplier Region (via get_region_for_country)
      - Domestic Content: "Indigenous" when the customer and supplier countries
        match (case-insensitive), "Imported" when they differ, and "Unknown"
        when either country is unknown.

    Missing, null or empty model fields are reported as "Unknown". If the
    model call fails, every identified field falls back to "Unknown".

    Why needed:
    - Geography can be split-driving when multiple customer countries exist.
    """
    sys_prompt = """
Extract: Customer Country, Supplier Country, Customer Operator.
Logic: If 'Navy awarded...', operator is Navy.
Return JSON:
{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}
"""
    log_block("HUMAN MESSAGE (Stage2 - Geography)", paragraph)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": sys_prompt},
                      {"role": "user", "content": paragraph}],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage2 - Geography)", json.dumps(raw, indent=2))
    except Exception as e:
        # Best-effort stage: log and continue with empty fields.
        raw = {}
        log_block("AI ERROR (Stage2 - Geography)", str(e))

    # A JSON array/scalar reply would otherwise crash on .get() below.
    if not isinstance(raw, dict):
        raw = {}

    def _field(key):
        # The model may return null or "" for a field it cannot find; treat
        # those like a missing key.
        value = raw.get(key)
        text = "" if value is None else str(value).strip()
        return text or "Unknown"

    cust = _field("Customer Country")
    supp = _field("Supplier Country")

    if "unknown" in (cust.lower(), supp.lower()):
        # Two unknown countries would otherwise compare equal and be reported
        # as "Indigenous".
        domestic = "Unknown"
    elif cust.lower() == supp.lower():
        domestic = "Indigenous"
    else:
        domestic = "Imported"

    return {
        "Customer Region": get_region_for_country(cust),
        "Customer Country": cust,
        "Customer Operator": _field("Customer Operator"),
        "Supplier Region": get_region_for_country(supp),
        "Supplier Country": supp,
        "Domestic Content": domestic
    }


# --- Stage 3: System Classification (RAG + Evidence + Reason) ---
class SystemInput(BaseModel):
    """
    Argument schema describing the input of the ``system_classifier`` tool.

    NOTE(review): not passed as ``args_schema`` to ``@tool("system_classifier")``
    in the visible code. Confirm whether it is used anywhere.
    """
    paragraph: str = Field(description="Full contract paragraph/description text.")

@tool("system_classifier")
def system_classifier(paragraph: str):
    """
    Stage 3: SYSTEM CLASSIFIER (RAG-ENHANCED)

    Goal:
    - Extract system-level labels using:
      - Taxonomy reference (TAXONOMY_STR)
      - Rule book triggers (RULE_BOOK)
      - RAG similar examples (top 3 from the KB retriever)
      - Deterministic piloting override

    Output:
    - A flat dict with a value, Evidence and Reason for:
      - Market Segment
      - System Type (General)
      - System Type (Specific)
      - System Name (General)
      - System Name (Specific)
      - System Piloting (always the rule-based value)
      plus "Confidence".
    - {} for an empty paragraph.

    Failure handling:
    - If KB retrieval fails, the error is logged and the model runs without
      RAG examples.
    - If the model call fails, a placeholder dict with empty labels,
      "Confidence": "Low" and an "Error" field is returned.
    """
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    log_block("HUMAN MESSAGE (Stage3 - System)", paragraph)

    lower_text = paragraph.lower()
    hints = [
        f"RULE: {v['guidance']}"
        for _, v in RULE_BOOK.items()
        if any(t in lower_text for t in v["triggers"])
    ]
    hint_str = "\n".join(hints) if hints else "No special override rules triggered."

    designators = extract_designators(paragraph)
    piloting_rule = detect_piloting_rule_based(paragraph, designators)

    rag_examples = []
    try:
        for hit in retriever.retrieve(paragraph, top_k=3):
            meta = hit.get("meta") or {}
            snippet = meta.get("Description of Contract", "")
            # KB metadata was built from an Excel sheet; a blank cell may load
            # as a float NaN, which cannot be sliced.
            if not isinstance(snippet, str):
                snippet = ""
            rag_examples.append({
                # float(): FAISS scores may be numpy scalars, which json.dumps rejects.
                "score": round(float(hit["score"]), 4),
                "Market Segment": meta.get("Market Segment", ""),
                "System Type (General)": meta.get("System Type (General)", ""),
                "System Type (Specific)": meta.get("System Type (Specific)", ""),
                "System Name (General)": meta.get("System Name (General)", ""),
                "System Name (Specific)": meta.get("System Name (Specific)", ""),
                "System Piloting": meta.get("System Piloting", ""),
                "Snippet": snippet[:220] + "..."
            })
    except Exception as e:
        # RAG is an enrichment step: degrade to "no examples" instead of
        # raising out of the pipeline for this row.
        log_block("AI ERROR (Stage3 - RAG)", str(e))
        rag_examples = []

    sys_prompt = f"""
You are a Senior Defense System Classification Analyst.

REFERENCE TAXONOMY:
{TAXONOMY_STR}

RULE BOOK OVERRIDES:
{hint_str}

OUTPUT RULES:
- Return ONLY a FLAT JSON object.
- Every value must be a STRING.
- Do NOT return nested objects or lists.
- Evidence must be copied EXACTLY from paragraph text.
- If evidence not present, output "Not Found".

Return JSON:
{{
  "Market Segment": "",
  "Market Segment Evidence": "",
  "Market Segment Reason": "",

  "System Type (General)": "",
  "System Type (General) Evidence": "",
  "System Type (General) Reason": "",

  "System Type (Specific)": "",
  "System Type (Specific) Evidence": "",
  "System Type (Specific) Reason": "",

  "System Name (General)": "",
  "System Name (General) Evidence": "",
  "System Name (General) Reason": "",

  "System Name (Specific)": "",
  "System Name (Specific) Evidence": "",
  "System Name (Specific) Reason": "",

  "System Piloting": "",
  "System Piloting Evidence": "",
  "System Piloting Reason": "",

  "Confidence": "High/Medium/Low"
}}
"""

    # default=str: any remaining non-JSON metadata value (e.g. a numpy scalar)
    # is stringified instead of raising.
    user_prompt = f"""
PARAGRAPH:
{paragraph}

DESIGNATORS (regex extracted):
{designators if designators else "None"}

RULE BASED PILOTING:
{piloting_rule}

RAG EXAMPLES:
{json.dumps(rag_examples, indent=2, default=str)}
"""

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": sys_prompt},
                      {"role": "user", "content": user_prompt}],
            temperature=0,
            response_format={"type": "json_object"}
        )
        result = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage3 - System)", json.dumps(result, indent=2))

        # Hard override piloting with the deterministic rule.
        # NOTE(review): setdefault keeps any Evidence/Reason the model wrote for
        # its own piloting value, which may disagree with the overridden value.
        result["System Piloting"] = piloting_rule
        result.setdefault("System Piloting Evidence", "Not Found")
        result.setdefault("System Piloting Reason", "Derived from deterministic piloting rules.")

        # Ensure flat: stringify any nested value the model returned.
        for k, v in result.items():
            if isinstance(v, (dict, list)):
                result[k] = str(v)

        return result

    except Exception as e:
        log_block("AI ERROR (Stage3 - System)", str(e))
        return {
            "Market Segment": "",
            "Market Segment Evidence": "Not Found",
            "Market Segment Reason": "",

            "System Type (General)": "",
            "System Type (General) Evidence": "Not Found",
            "System Type (General) Reason": "",

            "System Type (Specific)": "",
            "System Type (Specific) Evidence": "Not Found",
            "System Type (Specific) Reason": "",

            "System Name (General)": "",
            "System Name (General) Evidence": "Not Found",
            "System Name (General) Reason": "",

            "System Name (Specific)": "",
            "System Name (Specific) Evidence": "Not Found",
            "System Name (Specific) Reason": "",

            "System Piloting": piloting_rule,
            "System Piloting Evidence": "Not Found",
            "System Piloting Reason": "Derived from deterministic piloting rules.",

            "Confidence": "Low",
            "Error": str(e)
        }


# --- Stage 4: Contract Extractor ---
class ContractInfoInput(BaseModel):
    """
    Argument schema describing the inputs of the ``contract_extractor`` tool.

    NOTE(review): not passed as ``args_schema`` to ``@tool("contract_extractor")``
    in the visible code. Confirm whether it is used anywhere.
    """
    paragraph: str = Field(description="Full contract paragraph/description text.")
    contract_date: str = Field(description="Contract date as string.")

@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """
    Stage 4: CONTRACT EXTRACTOR

    Goal:
    - Extract the contract financial + program details:
      - Supplier Name (standardized with get_best_supplier_match)
      - Program Type
      - Expected MRO Contract Duration (Months) (via calculate_mro_months)
      - Quantity
      - Value (Million), formatted with 3 decimals ("0.000" if unparseable)
      - Currency
      - Value Certainty
      - G2G/B2G
      - Signing Month / Signing Year, derived from contract_date
        ("Unknown" if it cannot be parsed)

    Model fields that are missing, null or empty fall back to the defaults
    shown in the return dict. On a failed model call, or a reply that is not a
    JSON object, {"Error": "..."} is returned instead.
    """
    system_instruction = """
You are a Defense Contract Financial Analyst.

Rules:
1) raw_supplier_name: extract exact supplier from paragraph
2) program_type: Procurement/Training/MRO/Support/RDT&E/Upgrade/Other Service
3) value_certainty: Confirmed vs Estimated
4) quantity: extract numeric units if found else Not Applicable
5) g2g_b2g: G2G only if FMS mentioned else B2G
6) completion_date_text: for MRO calc only

Return JSON only.
"""

    user_prompt = f"""
PARAGRAPH:
{paragraph}

SIGNED DATE:
{contract_date}

Return JSON:
{{
  "raw_supplier_name": "",
  "program_type": "",
  "value_million_raw": "",
  "currency_code": "",
  "value_certainty": "",
  "quantity": "",
  "completion_date_text": "",
  "g2g_b2g": "",
  "value_note": ""
}}
"""

    log_block("HUMAN MESSAGE (Stage4 - Contract)", paragraph)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": system_instruction},
                      {"role": "user", "content": user_prompt}],
            temperature=0,
            response_format={"type": "json_object"}
        )
        raw = json.loads(completion.choices[0].message.content)
        log_block("AI RESPONSE (Stage4 - Contract)", json.dumps(raw, indent=2))
    except Exception as e:
        log_block("AI ERROR (Stage4 - Contract)", str(e))
        return {"Error": str(e)}

    # A JSON array/scalar reply would otherwise crash on .get() below.
    if not isinstance(raw, dict):
        log_block("AI ERROR (Stage4 - Contract)", "Response is not a JSON object")
        return {"Error": "Response is not a JSON object"}

    def _field(key, default):
        # The prompt template uses "" placeholders, so the model may echo ""
        # (or null) for fields it cannot find; treat those like a missing key.
        value = raw.get(key)
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    final_supplier = get_best_supplier_match(raw.get("raw_supplier_name"))

    prog_type = _field("program_type", "Unknown")
    mro_months = calculate_mro_months(contract_date, raw.get("completion_date_text"), prog_type)

    try:
        val_str = str(raw.get("value_million_raw", "0")).replace(",", "").replace("$", "")
        val_formatted = "{:.3f}".format(float(val_str))
    except ValueError:
        val_formatted = "0.000"

    try:
        dt = pd.to_datetime(contract_date)
        # NaT raises ValueError on strftime, which lands in the fallback below.
        sign_month = dt.strftime("%B")
        sign_year = str(dt.year)
    except (TypeError, ValueError, OverflowError):
        sign_month, sign_year = "Unknown", "Unknown"

    return {
        "Supplier Name": final_supplier,
        "Program Type": prog_type,
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": _field("quantity", "Not Applicable"),
        "Value Certainty": _field("value_certainty", "Confirmed"),
        "Value (Million)": val_formatted,
        "Currency": _field("currency_code", "USD$"),
        # NOTE(review): copied from Value (Million) without currency conversion.
        "Value (USD$ Million)": val_formatted,
        "Value Note (If Any)": _field("value_note", "Not Applicable"),
        "G2G/B2G": _field("g2g_b2g", "B2G"),
        "Signing Month": sign_month,
        "Signing Year": sign_year
    }


# --- Stage 5: Split Agent ---
class SplitterInput(BaseModel):
    """
    Argument schema describing the inputs of the ``splitter_agent`` tool.

    NOTE(review): not passed as ``args_schema`` to ``@tool("splitter_agent")``
    in the visible code. Confirm whether it is used anywhere.
    """
    paragraph: str = Field(description="Full contract paragraph/description text.")
    base_row: dict = Field(description="Final extracted row after Stage1-4.")

@tool("splitter_agent")
def splitter_agent(paragraph: str, base_row: dict):
    """
    Stage 5: SPLITTER AGENT

    Goal:
    - Decide whether the extracted row must become several output rows, using
      the deterministic rules in split_rows_engine().
    - Guarantee every returned row has "Split Flag" and "Split Reason" keys.

    Output:
    - {"rows": [row1, row2, ...]}
    - If splitting raises, a single copy of base_row with "Split Flag" set to
      "Error" and the exception text in "Split Reason". base_row itself is not
      modified.
    """
    try:
        rows = split_rows_engine(base_row, paragraph)
        for r in rows:
            r.setdefault("Split Flag", "No")
            r.setdefault("Split Reason", "")
        return {"rows": rows}
    except Exception as e:
        # Best-effort fallback; copy so the caller's dict is not mutated.
        fallback = dict(base_row)
        fallback["Split Flag"] = "Error"
        fallback["Split Reason"] = f"Split failed: {str(e)}"
        return {"rows": [fallback]}


# ==============================================================================
# 8) LANGGRAPH PIPELINE
# ==============================================================================

class AgentState(TypedDict):
    """
    State object passed between LangGraph stages.

    Stages 1-4 each return an updated copy of ``final_data``; Stage 5 writes
    ``final_rows``.
    """
    input_text: str  # contract paragraph for the current Excel row
    input_date: str  # contract date as a string
    input_url: str  # source URL of the announcement
    final_data: dict  # merged single-row extraction built up by Stages 1-4
    final_rows: list  # 1..N output rows produced by the Stage 5 split
    # Message channel merged with add_messages; not written by the visible stages.
    messages: Annotated[List[AnyMessage], add_messages]


def stage_1_sourcing(state: AgentState):
    """
    Graph node for Stage 1: run sourcing_extractor on the row inputs and merge
    its columns into a fresh copy of final_data.
    """
    sourcing = sourcing_extractor.invoke(
        {
            "paragraph": state["input_text"],
            "url": state["input_url"],
            "date": state["input_date"],
        }
    )
    return {"final_data": {**state.get("final_data", {}), **sourcing}}


def stage_2_geography(state: AgentState):
    """
    Graph node for Stage 2: run geography_extractor on the paragraph and merge
    the geography fields into a fresh copy of final_data.
    """
    geography = geography_extractor.invoke({"paragraph": state["input_text"]})
    return {"final_data": {**state.get("final_data", {}), **geography}}


def stage_3_system(state: AgentState):
    """
    Graph node for Stage 3: run system_classifier (RAG + evidence + reason) and
    merge its labels into a fresh copy of final_data.
    """
    labels = system_classifier.invoke({"paragraph": state["input_text"]})
    return {"final_data": {**state.get("final_data", {}), **labels}}


def stage_4_contract(state: AgentState):
    """
    Graph node for Stage 4: run contract_extractor (supplier, program,
    financial and quantity fields) and merge its output into a fresh copy of
    final_data.
    """
    contract = contract_extractor.invoke(
        {
            "paragraph": state["input_text"],
            "contract_date": state["input_date"],
        }
    )
    return {"final_data": {**state.get("final_data", {}), **contract}}


def stage_5_split(state: AgentState):
    """
    Graph node for Stage 5: run splitter_agent on the merged row and store the
    resulting 1..N rows in final_rows. If the tool result has no "rows" key,
    the single merged row is used.
    """
    merged = state["final_data"]
    outcome = splitter_agent.invoke({"paragraph": state["input_text"], "base_row": merged})
    return {"final_rows": outcome.get("rows", [merged])}


# Linear pipeline built at import time:
#   START -> Stage1 (sourcing) -> Stage2 (geography) -> Stage3 (system)
#         -> Stage4 (contract) -> Stage5 (split) -> END
# Each node returns a partial state update (final_data or final_rows).
workflow = StateGraph(AgentState)
workflow.add_node("Stage1", stage_1_sourcing)
workflow.add_node("Stage2", stage_2_geography)
workflow.add_node("Stage3", stage_3_system)
workflow.add_node("Stage4", stage_4_contract)
workflow.add_node("Stage5", stage_5_split)

workflow.add_edge(START, "Stage1")
workflow.add_edge("Stage1", "Stage2")
workflow.add_edge("Stage2", "Stage3")
workflow.add_edge("Stage3", "Stage4")
workflow.add_edge("Stage4", "Stage5")
workflow.add_edge("Stage5", END)

# Compiled graph; the __main__ block invokes it once per Excel input row.
app = workflow.compile()


# ==============================================================================
# 9) GRAPH VISUALIZATION (OFFLINE SAFE)
# ==============================================================================

def export_workflow_mermaid(app_obj, out_file="workflow.mmd"):
    """
    Write the compiled graph's Mermaid definition to a local text file.

    Only the Mermaid source text is produced (``draw_mermaid``); no image is
    rendered, so no call to the mermaid.ink service is needed. This matters on
    networks where that service is blocked.

    Args:
        app_obj: Compiled graph exposing ``get_graph()``.
        out_file: Destination path, written as UTF-8.

    Returns:
        ``out_file``, unchanged.
    """
    diagram = app_obj.get_graph().draw_mermaid()
    with open(out_file, mode="w", encoding="utf-8") as handle:
        handle.write(diagram)
    print(f"‚úÖ Workflow Mermaid saved locally: {out_file}")
    return out_file


# ==============================================================================
# 10) EXCEL HIGHLIGHTING FEATURE
# ==============================================================================

def highlight_evidence_reason_columns(excel_path: str):
    """
    Color the Evidence and Reason columns of the active sheet, in place.

    - Any header containing "Evidence" -> light yellow fill (FFF2CC).
    - Any header containing "Reason" (including e.g. a "Split Reason"
      column) -> light blue fill (D9E1F2).
    Header cells of those columns are also made bold. The file is saved back
    to ``excel_path``.

    Purpose: reviewers can check why each label was chosen at a glance.

    NOTE(review): load_workbook / PatternFill / Font (openpyxl) are not among
    the imports visible in this chunk; confirm they are imported elsewhere.
    """
    workbook = load_workbook(excel_path)
    sheet = workbook.active

    evidence_fill = PatternFill(start_color="FFF2CC", end_color="FFF2CC", fill_type="solid")  # light yellow
    reason_fill = PatternFill(start_color="D9E1F2", end_color="D9E1F2", fill_type="solid")  # light blue
    bold = Font(bold=True)

    headers = [cell.value for cell in sheet[1]]
    evidence_cols = [
        pos for pos, title in enumerate(headers, start=1)
        if isinstance(title, str) and "Evidence" in title
    ]
    reason_cols = [
        pos for pos, title in enumerate(headers, start=1)
        if isinstance(title, str) and "Reason" in title
    ]

    # Evidence is painted before Reason so a header containing both words
    # ends up blue.
    for columns, fill in ((evidence_cols, evidence_fill), (reason_cols, reason_fill)):
        for col in columns:
            header_cell = sheet.cell(row=1, column=col)
            header_cell.fill = fill
            header_cell.font = bold
            for row_no in range(2, sheet.max_row + 1):
                sheet.cell(row=row_no, column=col).fill = fill

    workbook.save(excel_path)
    print("‚úÖ Evidence + Reason columns highlighted successfully.")


# ==============================================================================
# 11) MAIN EXECUTION
# ==============================================================================

if __name__ == "__main__":
    import traceback

    print(f"\nüìå Loading Input File: {INPUT_EXCEL_PATH}")

    # Offline safe workflow graph
    export_workflow_mermaid(app, out_file="workflow.mmd")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)

        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        if not all(col in df_input.columns for col in required_cols):
            raise ValueError(f"Excel file must contain columns: {required_cols}")

        print(f"üöÄ Processing {len(df_input)} rows...")
        results = []

        for index, row in df_input.iterrows():
            print(f"\nüîπ Row {index + 1}/{len(df_input)}")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "final_rows": [],
                "messages": []
            }

            try:
                output_state = app.invoke(initial_state)
            except Exception as e:
                # Isolate per-row failures. Otherwise one bad row would raise
                # out of this loop into the outer handler, and no output file
                # would be written for any row.
                print(f"Row {index + 1} failed: {e}")
                traceback.print_exc()
                results.append({
                    "Description of Contract": desc,
                    "Additional Notes (Internal Only)": f"Pipeline error: {e}",
                    "Source Link(s)": c_url,
                    "Contract Date": c_date,
                    "Split Flag": "Error",
                    "Split Reason": "Row not processed",
                })
                continue

            rows = output_state.get("final_rows", [])
            if not rows:
                rows = [output_state.get("final_data", {})]

            results.extend(rows)

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Split Flag", "Split Reason",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            "Supplier Name", "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # Fixed column order; columns missing from every row are filled with "".
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")
        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False)

        # Highlight Evidence + Reason
        highlight_evidence_reason_columns(OUTPUT_EXCEL_PATH)

        print("\n‚úÖ Processing Complete!")
        print(f"üíæ Output File Saved: {OUTPUT_EXCEL_PATH}")
        print(df_final.head(3).to_string(index=False))

    except Exception as e:
        print(f"\n‚ùå ERROR: {e}")
        traceback.print_exc()
