# **Stage-1:- Creating a Knowledge Base**

- **Problem Statement**
  - Contract / defense data is stored in Excel but it‚Äôs hard to search intelligently
  - Traditional keyword search fails when the query wording is different (synonyms / rephrasing)
  - Finding the most relevant past contract descriptions takes too much manual effort
  - Even if you find a match, getting the full context (supplier, program, amount, etc.) from the correct row is difficult
  - A scalable system is needed to support quick retrieval + future AI extraction workflows

- **Proposed Solution**
  - Build a Vector Knowledge Base (KB) from the Excel dataset
  - Convert contract descriptions into semantic embeddings using a transformer model
  - Store embeddings in a FAISS vector index for fast similarity-based retrieval
  - Store all original Excel columns as metadata to return complete structured information
  - Enable searching based on meaning, not only exact words

- **Outcome**
  * Your Excel knowledge base will become **searchable by meaning (semantic search)** instead of only keywords
  * When you give a **new contract description/query**, the system will return the **most similar past contracts** instantly
  * You will be able to retrieve the **best-matching row** even if the words are different (synonyms, rephrasing, short forms)
  * Along with the match, you will also get the **complete row details** (Supplier, Program, Amount, Dates, etc.) because metadata is stored
  * Your extraction pipeline will become **more accurate**, since the LLM can be grounded with relevant historical examples
  * It will reduce **manual lookup time**, improve consistency, and make the process scalable as data grows
  * You will have a reusable **system KB (FAISS + metadata files)** that can be loaded anytime without rebuilding every time
  * This becomes a strong base for building an **agentic workflow** like: Retrieve ‚Üí Validate ‚Üí Extract ‚Üí Store ‚Üí Report




In [21]:
import os
import re
import pickle
import pandas as pd
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# We use a better model for detailed context
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

def clean_text(text: str) -> str:
    if pd.isna(text) or text is None:
        return ""
    text = str(text)
    return re.sub(r"\s+", " ", text).strip()

def safe_to_str(x):
    if pd.isna(x):
        return ""
    return str(x).strip()

def build_system_kb_store_enriched(
    excel_path: str,
    save_dir: str = "system_kb_store",
    model_name: str = DEFAULT_MODEL_NAME,
    batch_size: int = 32,
    embed_column: str = "Description of Contract",
):
    os.makedirs(save_dir, exist_ok=True)

    print(f"\nüìÇ Loading Knowledge Base: {excel_path}")
    # Auto-detect file type
    if excel_path.endswith(".csv"):
        df = pd.read_csv(excel_path)
    else:
        df = pd.read_excel(excel_path)
        
    # Normalize column names
    df.columns = [c.strip() for c in df.columns]
    print(f"   Loaded rows={len(df)} cols={len(df.columns)}")

    kb_texts = []
    kb_meta = []

    print("‚ú® Enriching text with important columns...")

    for idx, row in df.iterrows():
        # 1. Get the base description
        desc = clean_text(row.get(embed_column, ""))

        # 2. Extract Important Columns
        # We use .get() so the code doesn't crash if a column is missing
        market = safe_to_str(row.get("Market Segment", ""))
        sys_type = safe_to_str(row.get("System Type (Specific)", ""))
        sys_name = safe_to_str(row.get("System Name (Specific)", ""))
        supplier = safe_to_str(row.get("Supplier Name", ""))
        customer = safe_to_str(row.get("Customer Country", ""))
        program = safe_to_str(row.get("Program Type", ""))
        value_cert = safe_to_str(row.get("Value Certainty", ""))

        # 3. Fallback: If description is empty, build one from metadata
        if not desc:
            desc = f"Contract for {sys_name} ({sys_type}) supplied by {supplier}."

        # 4. Create the "Rich Context" String
        # This is what gets embedded. The model will now "know" these fields.
        enriched_text = (
            f"Market: {market}. "
            f"System: {sys_name} ({sys_type}). "
            f"Supplier: {supplier}. "
            f"Customer: {customer}. "
            f"Program: {program}. "
            f"Certainty: {value_cert}. "
            f"Details: {desc}"
        )

        # 5. Save Metadata
        meta = {"row_id": int(idx), "original_text": desc, "enriched_text": enriched_text}
        
        # Save all columns to metadata for retrieval later
        for col in df.columns:
            meta[col] = safe_to_str(row[col])

        kb_texts.append(enriched_text)
        kb_meta.append(meta)

    print(f"   Prepared {len(kb_texts)} rows for embedding.")

    print(f"üß† Loading Model: {model_name}")
    embedder = SentenceTransformer(model_name)

    print("   Creating embeddings...")
    embeddings = embedder.encode(
        kb_texts, 
        batch_size=batch_size, 
        show_progress_bar=True, 
        normalize_embeddings=True
    )

    embeddings = np.vstack(embeddings).astype("float32")
    dim = embeddings.shape[1]

    # Indexing
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

    index_path = os.path.join(save_dir, "system_kb.faiss")
    meta_path = os.path.join(save_dir, "system_kb_meta.pkl")

    faiss.write_index(index, index_path)
    with open(meta_path, "wb") as f:
        pickle.dump(kb_meta, f)

    print("\n‚úÖ System KB Created Successfully!")
    return index_path, meta_path

# Part 2: Retriever (Updated to use Enriched Text)

class SystemKBRetriever:
    def __init__(self, kb_dir="system_kb_store", model_name=DEFAULT_MODEL_NAME):
        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path):
            raise FileNotFoundError("‚ùå KB missing. Build it first.")

        print(f"\nüöÄ Loading Retriever...")
        self.index = faiss.read_index(index_path)
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        self.embedder = SentenceTransformer(model_name)

    def retrieve(self, query_text: str, top_k: int = 5):
        query_text = clean_text(query_text)
        if not query_text: return []

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, top_k)
        
        results = []
        for score, idx in zip(scores[0], idxs[0]):
            if idx < 0: continue
            results.append({"score": float(score), "meta": self.meta[idx]})

        return results

# Run Pipeline
if __name__ == "__main__":
    # UPDATE THIS PATH TO YOUR FILE
    EXCEL_PATH = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\sample_data.xlsx"
    KB_DIR = "system_kb_store"

    # 1. Build KB with Enriched Columns
    build_system_kb_store_enriched(
        excel_path=EXCEL_PATH,
        save_dir=KB_DIR,
        embed_column="Description of Contract"
    )

    # 2. Test Search
    r = SystemKBRetriever(kb_dir=KB_DIR)

    # Test query that relies on the columns we just added
    query = "Lockheed Martin fighter jets for the US Navy"
    
    hits = r.retrieve(query, top_k=3)

    print("\n" + "=" * 60)
    print(f"QUERY: {query}")
    print("=" * 60)

    for i, h in enumerate(hits, start=1):
        m = h["meta"]
        print(f"\nüîπ Rank: {i} (Score: {h['score']:.4f})")
        # Print the Enriched text to prove it's working
        print(f"   Context Used: {m['enriched_text'][:200]}...")


üìÇ Loading Knowledge Base: C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\sample_data.xlsx
   Loaded rows=2979 cols=29
‚ú® Enriching text with important columns...
   Prepared 2979 rows for embedding.
üß† Loading Model: sentence-transformers/all-mpnet-base-v2
   Creating embeddings...


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 94/94 [25:48<00:00, 16.47s/it]



‚úÖ System KB Created Successfully!

üöÄ Loading Retriever...

QUERY: Lockheed Martin fighter jets for the US Navy

üîπ Rank: 1 (Score: 0.7450)
   Context Used: Market: Air Platforms. System: Ancillary mission equipment (Fighter). Supplier: Lockheed Martin. Customer: USA. Program: Procurement. Certainty: Confirmed. Details: Lockheed Martin Corp., Lockheed Mar...

üîπ Rank: 2 (Score: 0.7448)
   Context Used: Market: Air Platforms. System: Ancillary Mission Equipment (Fighter). Supplier: Lockheed Martin Aeronautics. Customer: Unknown. Program: Procurement. Certainty: Confirmed. Details: Lockheed Martin Cor...

üîπ Rank: 3 (Score: 0.7445)
   Context Used: Market: Air Platforms. System: Ancillary Mission Equipment (Fighter). Supplier: Lockheed Martin Aeronautics. Customer: USA. Program: Procurement. Certainty: Confirmed. Details: Lockheed Martin Corp., ...


## **Stage 2**

In this stage I will be creating AI agest that can help in extraction of data based on the input. 

In [22]:
import os, sys
import re
import json
import pickle
import difflib
import datetime
from typing import Annotated, TypedDict, List, Dict, Any, Optional, Tuple

import pandas as pd
import faiss
from dateutil import parser
from dateutil.relativedelta import relativedelta
import getpass

# LangGraph / LangChain
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from openai import OpenAI

# Excel formatting
from openpyxl import load_workbook
from openpyxl.styles import PatternFill, Font

sys.path.append(os.getcwd())

from prompts import (
    GEOGRAPHY_PROMPT, 
    SYSTEM_CLASSIFIER_PROMPT, 
    CONTRACT_EXTRACTOR_PROMPT, 
    VALIDATOR_FIX_PROMPT
)

**Configurations and Supporting File Path**

In [None]:
## LLM CLIENT SETUP (OpenRouter)

if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter the LLM Foundry API Key: ")

client = OpenAI(
    api_key=f'{os.environ.get("LLMFOUNDRY_TOKEN")}:my-test-project',
    base_url="https://llmfoundry.straive.com/openai/v1/",
)
OPENROUTER_MODEL = "gpt-4o-mini"

# CONFIGURATION & FILE PATHS
TAXONOMY_PATH = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\taxonomy.json"
SUPPLIERS_PATH = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\suppliers.json"
INPUT_EXCEL_PATH = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\source_file.xlsx"
OUTPUT_CSV_PATH = "Processed_Defense_Data.csv"
RAG_KB_DIR = r"C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\system_kb_store"

**Helper Functions for Stage-2**

In [23]:
## RULE BOOK + GEOGRAPHY

# CORRECTED: Guidance values now strictly match standard Taxonomy keys
RULE_BOOK = {
    "defensive_countermeasures": {
        "triggers": ["flare", "chaff", "countermeasure", "decoy", "mju-", "ale-"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Defensive Systems', Specific: 'Defensive Aid Suite'"
    },
    "radars_and_sensors": {
        "triggers": ["radar", "sonar", "sensor", "an/apy", "an/tpy"],
        "guidance": "Market Segment: 'C4ISR Systems', System Type (General): 'Sensors'"
    },
    "ammunition": {
        "triggers": ["cartridge", "round", "projectile", " 5.56", " 7.62", "ammo"],
        "guidance": "Market Segment: 'Weapon Systems', System Type (General): 'Ammunition'"
    },
    "infrastructure_labs": {
        "triggers": [
            "lab hardware", 
            "laboratory", 
            "test bed", 
            "facility", 
            "infrastructure", 
            "test environment"
        ],
        # FIXED: "Infrastructure" -> "Infrastructure & Construction" to match Category 1 below
        "guidance": "Market Segment: 'Infrastructure & Construction', System Type (General): 'RDT&E Facilities', System Type (Specific): 'Not Applicable'. Priority Override: Ignore platform mentions (like DDG, F-35) if the deliverable is clearly for a lab or facility."
    },

    # --- Category 1: Construction & Facilities (High Confidence) ---
    "construction_projects": {
        "triggers": [
            "construction of", "design and construction", "paving", "dredging", 
            "renovation of", "roof repair", "hvac", "hangar", "architect-engineer"
        ],
        "guidance": {
            "Market Segment": "Infrastructure & Construction",
            "Program Type": "Construction/Facilities",
            "Value Certainty": "Confirmed"
        }
    },

    # --- Category 2: IT & Software (ESA/ESI/Licensing) ---
    "it_software_licensing": {
        "triggers": [
            "microsoft", "cisco", "software licensing", "perpetual licenses", 
            "enterprise software initiative", "dod esi", "cloud services", "aws", "azure"
        ],
        "guidance": {
            "Market Segment": "ICT & Cyber",
            "System Type (General)": "Software & Licensing",
            "Program Type": "Procurement (Software)",
            "System Piloting": "Not Applicable"
        }
    },

    # --- Category 3: R&D and Engineering Services ---
    "research_development": {
        "triggers": [
            "research", "developing", "studies", "analysis", "modeling and simulation",
            "prototyping", "sbir", "sttr", "darpa", "demonstration"
        ],
        "guidance": {
            "Program Type": "RDT&E",
            "Value Certainty": "Estimated" # R&D is often cost-plus and varies
        }
    },

    # --- Category 4: Logistics & Sustainment ---
    "logistics_support": {
        "triggers": [
            "contractor logistics support", "cls", "sustainment", "supply chain management",
            "performance based logistics", "depot maintenance", "obsolescence"
        ],
        "guidance": {
            "Program Type": "MRO/Support",
            "System Piloting": "Not Applicable" # The service itself is N/A
        }
    },

    # --- Category 5: Specific Weapon System Overrides ---
    "missile_production": {
        "triggers": ["production of lot", "all up round", "guidance units", "tail caps"],
        "guidance": {
            "Program Type": "Procurement",
            "System Piloting": "Not Applicable" # Missiles are not piloted vehicles
        }
    }
}


GEOGRAPHY_MAPPING = {
    "Sub-Saharan Africa": [
        "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cameroon", "Cape Verde",
        "Central African Republic", "Chad", "Congo, Democratic Republic of", "Congo, Republic of",
        "Djibouti", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon", "Gambia",
        "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
        "Madagascar", "Malawi", "Mali", "Mauritius", "Mozambique", "Namibia", "Niger",
        "Nigeria", "Rwanda", "Senegal", "Seychelles", "Sierra Leone", "Somalia", "South Africa",
        "South Sudan", "Sudan", "Tanzania", "Togo", "Uganda", "Zambia", "Zimbabwe"
    ],
    "Asia-Pacific": [
        "Australia", "Brunei", "Cambodia", "China", "Hong Kong", "Indonesia", "Japan", "Laos",
        "Malaysia", "Mongolia", "Myanmar", "New Zealand", "North Korea", "Papua New Guinea",
        "Philippines", "Singapore", "South Korea", "Taiwan", "Thailand", "Vietnam"
    ],
    "Europe": [
        "Albania", "Austria", "Belgium", "Bosnia and Herzegovina", "Bulgaria", "Croatia", "Cyprus",
        "Czech Republic", "Denmark", "Estonia", "Finland", "France", "Georgia", "Germany", "Greece",
        "Hungary", "Iceland", "Ireland", "Italy", "Kosovo", "Latvia", "Lithuania", "Luxembourg",
        "Malta", "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal",
        "Romania", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland", "Turkey",
        "Ukraine", "United Kingdom"
    ],
    "Latin America": [
        "Argentina", "Bahamas", "Barbados", "Belize", "Bolivia", "Brazil", "Chile", "Colombia",
        "Costa Rica", "Cuba", "Curacao", "Dominican Republic", "Ecuador", "El Salvador", "Guatemala",
        "Guyana", "Haiti", "Honduras", "Jamaica", "Mexico", "Nicaragua", "Panama", "Paraguay",
        "Peru", "Suriname", "Trinidad and Tobago", "Uruguay", "Venezuela"
    ],
    "Middle East and North Africa": [
        "Algeria", "Bahrain", "Egypt", "Iran", "Iraq", "Israel", "Jordan", "Kuwait", "Lebanon",
        "Libya", "Mauritania", "Morocco", "Oman", "Qatar", "Saudi Arabia", "Syria", "Tunisia",
        "United Arab Emirates", "Yemen"
    ],
    "North America": ["Canada", "USA"],
    "Russia & CIS": [
        "Armenia", "Azerbaijan", "Belarus", "Kazakhstan", "Kyrgyzstan", "Moldova", "Russia",
        "Tajikistan", "Turkmenistan", "Uzbekistan"
    ],
    "South Asia": [
        "Afghanistan", "Bangladesh", "India", "Maldives", "Nepal", "Pakistan", "Sri Lanka"
    ],
    "Unknown": [
        "Andorra", "Antigua and Barbuda", "Bhutan", "Comoros", "Dominica", "Federated States of Micronesia",
        "Fiji", "Grenada", "Kiribati", "Liechtenstein", "Marshall Islands", "Monaco", "Nauru", "Palau",
        "Palestine", "Puerto Rico", "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines",
        "Samoa", "San Marino", "Sao Tom and Principe", "Solomon Islands", "Timor-Leste", "Tonga", "Tuvalu",
        "Unknown", "Vanuatu", "Vatican City", "Western Sahara"
    ]
}

ALLOWED_OPERATORS = [
    "Army",
    "Navy",
    "Air Force",
    "Defence Wide",
    "Ukraine (Assistance)",
    "Foreign Assistance",
    "Other"
]

PROGRAM_TYPE_ALLOWED = [
    "Procurement",
    "Training",
    "MRO/Support",
    "RDT&E",
    "Upgrade",
    "Other Service"
]

DESIGNATOR_PATTERNS = [
    r"\bDDG[-\s]?\d+\b", r"\bCVN[-\s]?\d+\\b", r"\bSSN[-\s]?\d+\b",
    r"\bLCS[-\s]?\d+\b", r"\bLPD[-\s]?\d+\b", r"\bLHA[-\s]?\d+\b", r"\bLHD[-\s]?\d+\b",
    r"\bF-\d+\b", r"\bB-\d+\b", r"\bC-\d+\b", r"\bA-\d+\b",
    r"\bMQ-\d+\b", r"\bRQ-\d+\b",
    r"\bAN\/[A-Z0-9\-]+\b",
    r"\b(AIM|AGM|SM|RIM|MIM)-\d+\b",
]

In [24]:
def chunk_text(text: str, chunk_size: int = 1800, overlap: int = 250) -> List[str]:
    """
    Deterministic chunking: stable + safe + prevents token overflow.
    - chunk_size: number of characters per chunk
    - overlap: overlapping characters to preserve boundaries
    """
    text = str(text)
    if not text.strip():
        return []

    chunks = []
    start = 0
    n = len(text)

    while start < n:
        end = min(start + chunk_size, n)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)

        if end >= n:
            break

        start = max(0, end - overlap)

    return chunks
#--------------------------------------------

def pick_best_non_empty(values: List[str]) -> str:
    """Return first strong value else empty."""
    for v in values:
        if v and str(v).strip() and str(v).strip().lower() not in ["unknown", "n/a", "not applicable", "none"]:
            return str(v).strip()
    return ""
#--------------------------------------------

def normalize_customer_operator(raw_text: str) -> str:
    """
    Strictly maps input text to the allowed SOP drop-down list.
    """
    if not raw_text:
        return "Unknown"
        
    text = str(raw_text).strip().lower()

    # 1. Ukraine (Assistance) - High Priority
    # Detects: Ukraine Security Assistance Initiative, USAI, or explicitly "for Ukraine"
    if "ukraine" in text:
        return "Ukraine (Assistance)"

    # 2. Foreign Assistance
    # Detects: FMS, Foreign Military Sales, Foreign Customers
    if any(k in text for k in ["foreign military sales", "fms", "foreign customer", "foreign government"]):
        return "Foreign Assistance"

    # 3. Air Force (Includes Space Force per SOP)
    if any(k in text for k in ["space force", "ussf", "air force", "usaf"]):
        return "Air Force"

    # 4. Navy (Includes Marine Corps as they are Dept of Navy, unless 'Other' is preferred)
    # Note: SOP didn't explicitly list Marine Corps. Standard practice maps USMC to Navy. 
    # If strict separation is needed, move 'marine' to Other. 
    if any(k in text for k in ["navy", "usn", "marine", "usmc", "naval"]):
        return "Navy"

    # 5. Army
    if any(k in text for k in ["army", "usa ", "u.s. army"]): # Space padding to avoid matching "army" inside other words
        return "Army"

    # 6. Defence Wide
    # Detects: DLA, MDA, DTRA, DARPA, DISA, DCMA
    if any(k in text for k in ["defense logistics", "dla", "missile defense", "mda", "defense wide", "defence wide", "darpa", "disa"]):
        return "Defence Wide"

    # 7. Other
    # Detects: Coast Guard (per SOP), DIA, Joint, Special Ops
    if any(k in text for k in ["coast guard", "uscg", "defense intelligence", "dia", "joint", "special operations", "socom"]):
        return "Other"

    # Fallback to Other if it looks like a government agency but didn't match above
    if len(text) > 2:
        return "Other"

    return "Unknown"

#--------------------------------------------

def load_json_file(filename, default_value):
    try:
        with open(filename, "r", encoding="utf-8") as f:
            print(f"Loaded: {filename}")
            return json.load(f)
    except Exception as e:
        print(f"Warning: Could not load {filename} ({e}). Using default.")
        return default_value
#--------------------------------------------

def normalize_country_name(c: str) -> str:
    """
    Standardizes country names to ensure consistent Region mapping.
    """
    if not c:
        return "Unknown"
    
    t = str(c).strip().lower()
    
    # Common variations map
    mappings = {
        "usa": "USA", "us": "USA", "u.s.": "USA", "united states": "USA", 
        "united states of america": "USA", "america": "USA",
        "uk": "United Kingdom", "u.k.": "United Kingdom", "britain": "United Kingdom", 
        "great britain": "United Kingdom", "england": "United Kingdom",
        "uae": "United Arab Emirates",
        "sk": "South Korea", "rok": "South Korea", "republic of korea": "South Korea",
        "prc": "China", "people's republic of china": "China"
    }

    if t in mappings:
        return mappings[t]
        
    # Return capitalized version if no match found
    return str(c).strip().title()
#--------------------------------------------

def normalize_program_type(pt: str) -> str:
    if not pt:
        return "Other Service"

    t = str(pt).strip().lower()

    if any(k in t for k in ["mro", "support", "maintenance", "repair", "overhaul", "sustainment", "logistics"]):
        return "MRO/Support"
    if "training" in t:
        return "Training"
    if any(k in t for k in ["rdte", "research", "development", "prototype", "test and evaluation"]):
        return "RDT&E"
    if any(k in t for k in ["upgrade", "modernization", "modification"]):
        return "Upgrade"
    if any(k in t for k in ["procure", "buy", "purchase", "production", "delivery"]):
        return "Procurement"

    return "Other Service"
#--------------------------------------------

raw_taxonomy = load_json_file(TAXONOMY_PATH, {})
TAXONOMY_STR = json.dumps(raw_taxonomy, separators=(",", ":"))

SUPPLIER_LIST = load_json_file(SUPPLIERS_PATH, {})
print(SUPPLIER_LIST)

#--------------------------------------------
def get_best_supplier_match(extracted_name: str) -> str:
    """
    Refines the LLM-extracted supplier name against the official SUPPLIER_LIST.
    
    Logic:
    1. Detect "Multiple" scenarios (semicolons, ' and ', etc).
    2. Check for Exact Matches (Case Insensitive).
    3. Check for Substring Matches (e.g., "Raytheon" inside "Raytheon Technologies").
    4. Use Fuzzy Matching for typos (e.g., "Lokheed Martin").
    """
    if not extracted_name or str(extracted_name).lower() in ["unknown", "n/a", ""]:
        return "Unknown"

    raw = str(extracted_name).strip()
    raw_lower = raw.lower()

    # --- RULE 1: DETECT "MULTIPLE" ---
    # If LLM explicitly said "Multiple", keep it.
    if raw_lower == "multiple":
        return "Multiple"
    
    # If the text contains list delimiters, force "Multiple"
    # Example: "Boeing; Lockheed Martin; and Raytheon"
    if ";" in raw or " and " in raw_lower or " vs " in raw_lower:
        return "Multiple"

    # --- PREPARE LIST ---
    # Ensure list is strings and sort by length (longest first to catch specific names before generic ones)
    valid_suppliers = sorted([str(s) for s in SUPPLIER_LIST], key=len, reverse=True)

    # --- RULE 2: EXACT MATCH (Case Insensitive) ---
    for s in valid_suppliers:
        if s.lower() == raw_lower:
            return s

    # --- RULE 3: SUBSTRING MATCH ---
    # Case A: Extracted is inside Official (e.g., Extracted "Raytheon" -> Official "Raytheon Co.")
    for s in valid_suppliers:
        if len(raw) > 4 and raw_lower in s.lower(): 
            return s
            
    # Case B: Official is inside Extracted (e.g., Extracted "The Boeing Company" -> Official "Boeing")
    for s in valid_suppliers:
        if len(s) > 4 and s.lower() in raw_lower:
            return s

    # --- RULE 4: FUZZY MATCH (Typos) ---
    # uses difflib to find close matches (cutoff=0.8 means 80% similarity required)
    matches = difflib.get_close_matches(raw, valid_suppliers, n=1, cutoff=0.8)
    if matches:
        return matches[0]

    # --- FALLBACK ---
    # If no match found in JSON list, return what the LLM found (it might be a new supplier)
    return raw

#--------------------------------------------

def extract_awardee_supplier_strict(paragraph: str) -> Tuple[str, str]:
    text = str(paragraph).strip()

    patterns = [
        r"^([A-Z][A-Za-z0-9&\-\.\s]+?),\s+.*?\s+(?:is|was|has been)\s+awarded\b",
        r"^([A-Z][A-Za-z0-9&\-\.\s]+?),\s+.*?\s+received\s+an?\s+award\b",
    ]

    for pat in patterns:
        m = re.search(pat, text, flags=re.IGNORECASE)
        if m:
            raw_supplier = m.group(1).strip()
            final_supplier = get_best_supplier_match(raw_supplier)
            return final_supplier, raw_supplier

    return "Unknown", "Not Found"

#--------------------------------------------

def calculate_mro_months(start_date_str, end_date_text, program_type):
    if program_type != "MRO/Support":
        return "Not Applicable"
    try:
        if not start_date_str or not end_date_text:
            return "Not Applicable"

        start = pd.to_datetime(start_date_str, dayfirst=True)
        end = parser.parse(str(end_date_text), fuzzy=True)

        diff = relativedelta(end, start)
        total_months = diff.years * 12 + diff.months
        return str(max(0, int(total_months)))
    except:
        return "Not Applicable"

#--------------------------------------------

def get_region_for_country(country_name):
    if not country_name or str(country_name).strip().lower() in ["unknown", "n/a", "not applicable"]:
        return "Unknown"

    clean = str(country_name).strip().lower()

    if clean in ["us", "usa", "u.s.", "united states", "united states of america"]:
        return "North America"
    if clean in ["uk", "u.k.", "britain", "great britain"]:
        return "Europe"

    for region, countries in GEOGRAPHY_MAPPING.items():
        if any(c.lower() == clean for c in countries):
            return region

    return "Unknown"

def extract_designators(text: str):
    text = str(text)
    found = []
    for pat in DESIGNATOR_PATTERNS:
        found.extend(re.findall(pat, text, flags=re.IGNORECASE))
    cleaned = [f.upper().replace(" ", "").replace("--", "-") for f in found]
    final, seen = [], set()
    for x in cleaned:
        if x not in seen:
            final.append(x)
            seen.add(x)
    return final


def detect_piloting_rule_based(text: str, designators: List[str]) -> str:
    t = str(text).lower()

    if any(d.startswith(("MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    if any(k in t for k in ["unmanned", "uav", "drone", "autonomous"]):
        return "Uncrewed"

    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD")) for d in designators):
        return "Crewed"
    if "uss " in t:
        return "Crewed"

    return "Not Applicable"

def detect_piloting_strict(text: str, designators: List[str], system_type: str = "") -> str:
    """
    Strictly determines System Piloting based on SOP definitions.
    """
    t = str(text).lower()
    
    # 1. OPTIONAL (Explicitly stated)
    if any(k in t for k in ["optionally manned", "optional piloting", "manned-unmanned teaming", "manned/unmanned"]):
        return "Optional"

    # 2. UNCREWED (UAVs, Drones, Autonomous)
    # Check designators first (Strong signal)
    if any(d.startswith(("MQ-", "RQ-", "XQ-", "MQ-", "RQ-")) for d in designators):
        return "Uncrewed"
    # Check keywords
    if any(k in t for k in ["unmanned", "uav", "uas", "drone", "autonomous", "remotely piloted", "rpa"]):
        return "Uncrewed"

    # 3. CREWED (Requires human pilot/driver)
    # Check designators (Ships, Fighters, Bombers, Transports)
    if any(d.startswith(("DDG", "CVN", "SSN", "LCS", "LPD", "LHA", "LHD", "CG-", "FFG")) for d in designators):
        return "Crewed"
    if any(d.startswith(("F-", "B-", "C-", "A-", "AH-", "UH-", "CH-", "MH-")) and not d.startswith("C-UAS") for d in designators):
        return "Crewed"
    # Check keywords
    if any(k in t for k in ["manned", "crew", "pilot", "cockpit", "fighter aircraft", "helicopter", "submarine", "frigate", "destroyer", "carrier"]):
        # Exclusion: "unmanned" check above handles "unmanned surface vessel"
        if "unmanned" not in t: 
            return "Crewed"

    # 4. NOT APPLICABLE (Default for non-vehicles)
    # If the system is clearly a component, weapon, or support item, it's N/A.
    # We use a broad check for "Not Applicable" candidates.
    na_indicators = [
        "missile", "munition", "bomb", "ammunition", "round", "cartridge", 
        "radar", "sensor", "radio", "software", "training", "simulator", 
        "engine", "spare part", "container", "gun", "artillery", "howitzer",
        "maintenance", "support", "logistics", "service"
    ]
    if any(k in t for k in na_indicators):
        return "Not Applicable"
        
    # If specific system type is known from Stage 3 and it's not a platform
    if system_type and system_type not in ["Air Platform", "Naval Platform", "Land Platform"]:
         return "Not Applicable"

    # Default fallback if ambiguous (often best to leave blank or N/A, but SOP says N/A for non-driving)
    return "Not Applicable"


def normalize_program_type_improved(pt_llm: str, text: str) -> str:
    """
    Improved Program Type logic with expanded keywords and better conflict resolution.
    """
    text_lower = str(text).lower()
    pt_llm = str(pt_llm).strip()

    # 1. RDT&E (Strongest signal)
    if any(k in text_lower for k in ["rdt&e", "research", "development", "prototype", "demonstration", "sbir", "sttr", "study", "design phase"]):
        return "RDT&E"

    # 2. MRO/Support (Strong signal)
    # Added: "sustainment", "cls" (contractor logistics support), "pbl" (performance based logistics)
    if any(k in text_lower for k in ["mro", "sustainment", "maintenance", "repair", "overhaul", "logistics", "support services", "cls", "technical support", "engineering support"]):
        return "MRO/Support"

    # 3. Upgrade (Specific action on existing)
    if any(k in text_lower for k in ["upgrade", "modernization", "retrofit", "modification", "life extension", "sle", "update"]):
        return "Upgrade"

    # 4. Procurement (Broadest category, tricky vs Training)
    # Explicit "Procurement" signals
    procure_signals = ["production", "delivery", "procurement", "acquisition", "supply", "purchase", "manufacture", "fabrication", "assembly"]
    if any(k in text_lower for k in procure_signals):
        return "Procurement"

    # Training HARDWARE check (Simulators -> Procurement)
    if "training" in text_lower or "simulator" in text_lower:
        if any(k in text_lower for k in ["simulator", "device", "trainer", "hardware", "system", "kit", "aids"]):
             return "Procurement"
    
    # 5. Training (Services Only)
    # If it says "training" and wasn't caught by Procurement above
    if "training" in text_lower or "instruction" in text_lower:
        return "Training"

    # 6. Other Service (Catch-all for services)
    if "service" in text_lower or "labor" in text_lower:
        return "Other Service"

    # Fallback: Trust LLM if it found a valid category
    valid_types = ["Procurement", "Training", "MRO/Support", "RDT&E", "Upgrade", "Other Service"]
    if pt_llm in valid_types:
        return pt_llm

    # Default if truly unknown but looks like a contract award
    if "award" in text_lower:
        return "Procurement" # Safe default for "awarded contract for X"

    return "Unknown"


def calculate_value_certainty_strict(text: str, value_str: str) -> str:
    """
    Highly recommends 'Confirmed'. Only 'Estimated' if explicitly stated.
    """
    text_lower = text.lower()
    
    # 1. Strong "Confirmed" Default
    # If we extracted a specific non-zero value, assume Confirmed unless proven otherwise.
    if not value_str or value_str == "0.000":
        return "Estimated" # No value = Estimated/Unknown

    # 2. Explicit Estimation Modifiers (Strict)
    # Only tag 'Estimated' if these words explicitly modify the value context.
    # "Ceiling" implies IDIQ max, which is an estimate of potential spend.
    estimation_signals = ["estimated value", "estimated cost", "ceiling", "maximum value", "potential value", "not to exceed"]
    
    if any(k in text_lower for k in estimation_signals):
        return "Confirmed"

    # 3. IDIQ / BPA (Usually Ceiling = Estimated)
    # However, if user wants "Confirmed" for the stated face value, we can be lenient.
    # SOP says: "Select Estimated when the value is not confirmed and possibility of future modifications."
    # IDIQ ceilings ARE limits, not confirmed spend. So we keep them as Estimated.
    if "indefinite-delivery" in text_lower or "idiq" in text_lower:
         # Often IDIQs have a "face value" (initial task order) vs "ceiling".
         # If the text says "awarded a $X task order", it's confirmed.
         # If it says "awarded a $X IDIQ contract", it's a ceiling (Estimated).
         if "task order" in text_lower or "delivery order" in text_lower:
             return "Estimated" 
         return "Confirmed"

    return "Confirmed"

def clean_money_string(val_str):
    """Converts '$12,345.67' or '12.5' to float 12345.67"""
    try:
        # Remove currency symbols, commas, and whitespace
        clean = re.sub(r'[^\d\.]', '', str(val_str))
        return float(clean)
    except:
        return 0.0

def smart_value_extraction(text: str) -> str:
    """
    Robust extraction: Finds the largest 'Award' value while ignoring 'Obligated' funds.
    Returns string in MILLIONS (e.g. "12.500")
    """
    text = str(text)
    
    # 1. Find ALL dollar patterns with their context indices
    # Matches: $100, $100.00, $100 million, $100,000
    money_pattern = r'\$([\d,]+(?:\.\d+)?)\s*(million|billion|b|m)?'
    matches = []
    
    for m in re.finditer(money_pattern, text, re.IGNORECASE):
        raw_val = m.group(1)
        suffix = m.group(2)
        start_idx = m.start()
        
        # Convert to raw float
        val_float = clean_money_string(raw_val)
        
        # Handle "million/billion" word suffix
        if suffix:
            if suffix.lower().startswith('b'):
                val_float *= 1_000_000_000
            elif suffix.lower().startswith('m'):
                val_float *= 1_000_000
        
        matches.append({
            "val": val_float,
            "start": start_idx,
            "context": text[max(0, start_idx-50):min(len(text), start_idx+50)].lower() # Look at words around it
        })

    if not matches:
        return ""

    # 2. Filter out "Obligated" amounts (The Trap)
    # If the text near the money says "obligated", "fiscal", "funds", it's likely NOT the total award
    candidates = []
    for m in matches:
        # Penalize if "obligated" or "expiration" is very close
        if "obligated" in m["context"] or "expire" in m["context"]:
            continue 
        candidates.append(m["val"])

    # 3. Fallback: If all values were obligated, use the largest one found (better than zero)
    if not candidates:
        best_val = max(m["val"] for m in matches)
    else:
        # 4. Selection: Usually the Award Value is the LARGEST non-obligated amount
        best_val = max(candidates)

    # 5. Convert to Millions string
    return f"{best_val / 1_000_000:.3f}"

def detect_quantity_should_be_na(paragraph: str) -> bool:
    """
    If paragraph contains many item quantities -> treat Quantity as Not Applicable.
    Example3: 483 missiles, 82 missiles, 156 missiles, 198 containers...
    That is NOT a single contract-level quantity.
    """
    t = str(paragraph)

    # Count numeric patterns that look like item quantities
    qty_candidates = re.findall(r"\b(\d{1,5})\b", t)

    # If too many numbers appear, it's almost always a line-item contract
    if len(qty_candidates) >= 12:
        return True

    # Specific strong markers for line-item heavy paragraphs
    if any(k in t.lower() for k in ["as follows:", "lot", "containers", "spare", "tail cap", "guidance unit"]):
        return True

    return False


def normalize_currency(cur: str) -> str:
    if not cur:
        return "USD$"
    c = str(cur).strip().upper()
    if c in ["USD", "US$", "$", "US DOLLAR", "DOLLARS"]:
        return "USD$"
    return c

def word_to_int(token: str) -> Optional[int]:
    """
    Converts word numbers to int:
    eight -> 8, twenty one -> 21

    Supports 0-99 safely.
    Returns None if not a valid word-number.
    """
    if not token:
        return None

    t = str(token).strip().lower()
    t = t.replace("-", " ")
    parts = [p for p in t.split() if p]

    ones = {
        "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
        "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
        "eleven": 11, "twelve": 12, "thirteen": 13, "fourteen": 14,
        "fifteen": 15, "sixteen": 16, "seventeen": 17, "eighteen": 18,
        "nineteen": 19
    }

    tens = {
        "twenty": 20, "thirty": 30, "forty": 40, "fifty": 50,
        "sixty": 60, "seventy": 70, "eighty": 80, "ninety": 90
    }

    # single word
    if len(parts) == 1:
        if parts[0] in ones:
            return ones[parts[0]]
        if parts[0] in tens:
            return tens[parts[0]]
        return None

    # two words: "twenty one"
    if len(parts) == 2:
        if parts[0] in tens and parts[1] in ones:
            return tens[parts[0]] + ones[parts[1]]
        return None

    return None


def parse_qty_token(qty_token: str) -> Optional[int]:
    """
    Converts numeric token or word-number token into int.
    """
    if qty_token is None:
        return None
    s = str(qty_token).strip().lower()

    # numeric
    if re.fullmatch(r"\d+", s):
        return int(s)

    # word-number
    return word_to_int(s)


def parse_fms_countries(paragraph: str) -> List[str]:
    """
    Extract FMS customer countries list.
    Example: 'governments of Australia, Bahrain, Belgium...'
    """
    text = str(paragraph)

    m = re.search(
        r"governments of (.+?)(?:\.\s| Work will be performed| Fiscal| This contract|$)",
        text,
        flags=re.IGNORECASE | re.DOTALL
    )
    if not m:
        return []

    block = m.group(1)
    raw = re.split(r",|\band\b", block)

    countries = []
    for c in raw:
        c = c.strip()
        if 2 < len(c) <= 40:
            countries.append(c)

    final, seen = [], set()
    for c in countries:
        if c.lower() not in seen:
            final.append(c)
            seen.add(c.lower())

    return final


def detect_multiple_supplier_award(paragraph: str) -> bool:
    """
    Detect Example2 pattern:
    Many suppliers listed with semicolons + "are awarded ..."
    """
    t = str(paragraph).strip()
    if ";" in t and re.search(r"\bare awarded\b", t, flags=re.IGNORECASE):
        if re.search(r"\([A-Z0-9]{6,}\)", t):
            return True
    return False


def parse_line_item_operator_allocations(paragraph: str) -> List[Dict[str, str]]:
    """
    Extract line-item splits like:

    '483 AIM-9X ... missiles (212 for the Navy, 187 for the Air Force and 84 for FMS customers);'
    '82 AIM-9X ... missiles (eight for the Navy, eight for the Air Force and 66 for FMS customers);'

    Output:
      [
        {"item_name": "...", "operator":"Navy", "quantity":"212", "g2g_b2g":"B2G"},
        {"item_name": "...", "operator":"Air Force", "quantity":"187", "g2g_b2g":"B2G"},
        {"item_name": "...", "operator":"Foreign Assistance", "quantity":"84", "g2g_b2g":"G2G"}
      ]
    """
    text = str(paragraph)

    # This finds: <total qty> <item desc> ( <allocation block> )
    # Example: 483 AIM-9X ... missiles (212 for the Navy, 187 for the Air Force and 84 for FMS customers)
    item_pattern = r"\b(\d+)\s+(.+?)\s*\(([^)]*for[^)]*)\)"
    matches = re.findall(item_pattern, text, flags=re.IGNORECASE)

    results = []

    for total_qty, item_desc, alloc_block in matches:
        item_desc_clean = re.sub(r"\s+", " ", item_desc).strip(" ;,. -")

        # --- BRANCH allocations (digit or word-number)
        branch_pattern = r"\b(\d+|[A-Za-z\-]+)\s+for\s+the\s+(Navy|Air Force|Army|Marine Corps|Space Force|Coast Guard|Defense Logistics Agency|Missile Defense Agency)\b"
        for qty_token, op_raw in re.findall(branch_pattern, alloc_block, flags=re.IGNORECASE):
            qty_val = parse_qty_token(qty_token)
            if qty_val is None:
                continue
            normalized_op = normalize_customer_operator(op_raw)       
            results.append({
                "item_name": item_desc_clean,
                "operator": normalized_op,
                "quantity": str(qty_val),
                "g2g_b2g": "B2G"
            })

        # --- FMS allocations (digit or word-number)
        fms_pattern = r"\b(\d+|[A-Za-z\-]+)\s+for\s+(?:Foreign Military Sales\s*\(FMS\)\s*customers|Foreign Military Sales\s*customers|FMS\s*customers|FMS)\b"
        for qty_token in re.findall(fms_pattern, alloc_block, flags=re.IGNORECASE):
            qty_val = parse_qty_token(qty_token)
            if qty_val is None:
                continue

            results.append({
                "item_name": item_desc_clean,
                "operator": "Foreign Assistance",
                "quantity": str(qty_val),
                "g2g_b2g": "G2G"
            })

    return results


def split_rows_engine(base_row: dict, paragraph: str) -> List[dict]:
    """
    FINAL STAGE5 Split Engine (Your required behavior)

    SPLITS:
    1) Multi-supplier award => Supplier Name = Multiple (NO supplier split)
    2) Line-item + operator allocation split (Navy/AirForce/FMS)
    3) FMS country split ONLY for G2G rows

    Only split fields are modified, rest remain unchanged.
    """
    paragraph = str(paragraph).strip()
    row = base_row.copy()

    split_reasons = []

    # ------------------------------------------------------------------
    # 1) MULTIPLE SUPPLIER AWARD (Example2)
    # ------------------------------------------------------------------
    if detect_multiple_supplier_award(paragraph):
        row["Supplier Name"] = "Multiple"
        row["Supplier Name Evidence"] = "Multiple Suppliers (Detected)"
        row["Split Flag"] = "No"
        row["Split Reason"] = "Multiple supplier award detected (no supplier split)"
        return [row]

    rows = [row]

    # ------------------------------------------------------------------
    # 2) LINE-ITEM + OPERATOR SPLIT (Your expected output)
    # ------------------------------------------------------------------
    item_allocs = parse_line_item_operator_allocations(paragraph)

    if item_allocs:
        split_reasons.append("Line-item operator allocation split")

        new_rows = []
        for r in rows:
            for alloc in item_allocs:
                rr = r.copy()

                # Only split fields change
                rr["Customer Operator"] = alloc["operator"]
                rr["Quantity"] = alloc["quantity"]
                rr["G2G/B2G"] = alloc["g2g_b2g"]

                # Make system labels reflect line-item name (matches your expected output)
                rr["System Name (General)"] = alloc["item_name"]
                rr["System Name (Specific)"] = alloc["item_name"]

                new_rows.append(rr)

        rows = new_rows

    # ------------------------------------------------------------------
    # 3) FMS COUNTRY SPLIT (ONLY for G2G rows)
    # ------------------------------------------------------------------
    fms_countries = parse_fms_countries(paragraph)
    if fms_countries:
        split_reasons.append("FMS country split")

        final_rows = []
        for r in rows:
            is_g2g = str(r.get("G2G/B2G", "")).strip().upper() == "G2G"
            is_fms_operator = "foreign" in str(r.get("Customer Operator", "")).lower()

            if is_g2g or is_fms_operator:
                for c in fms_countries:
                    rr = r.copy()
                    rr["Customer Country"] = normalize_country_name(c)
                    rr["Customer Region"] = get_region_for_country(rr["Customer Country"])
                    final_rows.append(rr)
            else:
                final_rows.append(r)

        rows = final_rows

    # ------------------------------------------------------------------
    # Final flags
    # ------------------------------------------------------------------
    if split_reasons:
        reason = " | ".join(split_reasons)
        for r in rows:
            r["Split Flag"] = "Yes"
            r["Split Reason"] = reason
    else:
        for r in rows:
            r["Split Flag"] = "No"
            r["Split Reason"] = "No split condition found"

    return rows

## LLM CALL HELPER (OpenRouter Safe Wrapper)
def call_llm_json(system_prompt: str, user_prompt: str, max_tokens: int):
    """
    Safe OpenRouter call wrapper
    - JSON response enforced
    - max_tokens enforced
    """
    completion = client.chat.completions.create(
        model=OPENROUTER_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0,
        max_tokens=max_tokens,
        response_format={"type": "json_object"},
    )
    return json.loads(completion.choices[0].message.content)

# 11) EXCEL HIGHLIGHTING FEATURE

def highlight_evidence_reason_columns(excel_path: str):
    wb = load_workbook(excel_path)

    # Always ensure at least one visible sheet
    ws = wb.active
    ws.sheet_state = "visible"

    header = [cell.value for cell in ws[1]]

    evidence_cols = []
    reason_cols = []

    for idx, col_name in enumerate(header, start=1):
        if isinstance(col_name, str) and "Evidence" in col_name:
            evidence_cols.append(idx)
        if isinstance(col_name, str) and "Reason" in col_name:
            reason_cols.append(idx)

    evidence_fill = PatternFill(start_color="FFF2CC", end_color="FFF2CC", fill_type="solid")
    reason_fill = PatternFill(start_color="D9E1F2", end_color="D9E1F2", fill_type="solid")
    header_font = Font(bold=True)

    for col_idx in evidence_cols:
        ws.cell(row=1, column=col_idx).fill = evidence_fill
        ws.cell(row=1, column=col_idx).font = header_font

    for col_idx in reason_cols:
        ws.cell(row=1, column=col_idx).fill = reason_fill
        ws.cell(row=1, column=col_idx).font = header_font

    for row in range(2, ws.max_row + 1):
        for col_idx in evidence_cols:
            ws.cell(row=row, column=col_idx).fill = evidence_fill
        for col_idx in reason_cols:
            ws.cell(row=row, column=col_idx).fill = reason_fill

    wb.save(excel_path)
    print("Evidence + Reason columns highlighted successfully.")

def apply_rule_book(text: str):
    """
    Scans text against CUSTOM_RULE_BOOK triggers.
    Returns the guidance dict if a match is found, else None.
    """
    text_lower = text.lower()
    best_match = None
    
    for rule_name, rule_data in RULE_BOOK.items():
        for trigger in rule_data["triggers"]:
            if trigger in text_lower:
                # If we find a trigger, we return the guidance
                # You can add logic here to prioritize longer triggers
                return rule_data["guidance"]
    return None

Loaded: C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\taxonomy.json
Loaded: C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\notebook\suppliers.json
['22nd Century Tech', 'A&P Group', 'A&R Pacific -Garney Federal', 'AAR Supply Chain Inc.', 'Aardvark Clear', 'AASKI Technology', 'AAVCO', 'Abacus Tech Corp', 'Abdallah Al-Faris', 'Abeking Rasmuss', 'ABG Shipyards', 'ABM Shipyard', 'Absher Construction Co.', 'Abu Dhabi MAR', 'Abu Dhabi SB', 'ACC Construction Co.', 'Accenture', 'Accurate Energetic Systems', 'ACE Technology', 'Aceinfo Solutions', 'ACHILE Consortium', 'Achleitner', 'ACMI', 'ACT-Corp', 'ActioNet', 'ADCOM Systems', 'ADI Group', 'Admiralty Ship', 'Advanced Navigation and Positioning Corp.', 'Advanced Technology International', 'AdvElect Co (AEC)', 'AECOM', 'Aegis Technologies', 'Aeraccess', 'Aero Def Systems', 'Aero Synergie', 'Aero Vodochody', 'Aerodata AG', 'Aerodyca', 'Aerojet Rocketdyne', 'Aeromaritime Grp', 'Aeromot', 'Aeronautical Development Establishme

**Using Knowledgebase to support my extraction**

In [25]:
## RAG RETRIEVER (FAISS + METADATA)
class SystemKBRetriever:
    def __init__(self, kb_dir: str, embed_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.kb_dir = kb_dir
        self.embed_model = embed_model

        index_path = os.path.join(kb_dir, "system_kb.faiss")
        meta_path = os.path.join(kb_dir, "system_kb_meta.pkl")

        if not os.path.exists(index_path) or not os.path.exists(meta_path):
            raise FileNotFoundError(
                f"KB files not found in: {kb_dir}\n"
                f"Expected:\n- {index_path}\n- {meta_path}\n\n"
                f"Build KB first using your KB builder script."
            )
        print(f"Loading FAISS Index: {index_path}")
        self.index = faiss.read_index(index_path)

        print(f"Loading KB Metadata: {meta_path}")
        with open(meta_path, "rb") as f:
            self.meta = pickle.load(f)

        print(f"KB Loaded rows: {len(self.meta)}")

        self.embedder = None

    def _lazy_load_embedder(self):
        if self.embedder is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer(self.embed_model)

    def retrieve(self, query_text: str, top_k: int = 3):
        import numpy as np
        query_text = str(query_text).strip()
        if not query_text:
            return []

        self._lazy_load_embedder()

        q_emb = self.embedder.encode([query_text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], idxs[0]):
            if idx < 0:
                continue
            results.append({"score": float(score), "meta": self.meta[idx]})
        return results

In [26]:
retriever = SystemKBRetriever(kb_dir=RAG_KB_DIR)

Loading FAISS Index: C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\system_kb_store\system_kb.faiss
Loading KB Metadata: C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\system_kb_store\system_kb_meta.pkl
KB Loaded rows: 600


In [27]:
def safe_str(x, default=""):
    """
    Safely converts input to string, returning default if None or empty/whitespace.
    """
    if x is None:
        return default
    s = str(x).strip()
    return s if s else default

def rag_best_hit(paragraph: str, top_k: int = 3):
    """
    Retrieves the single best match from the Knowledge Base (FAISS).
    Returns: (hit_object, score, metadata_dict)
    """
    # Ensure the global retriever is loaded
    if 'retriever' not in globals():
        print("Warning: 'retriever' is not defined. Returning empty hit.")
        return None, 0.0, {}

    hits = retriever.retrieve(paragraph, top_k=top_k)
    
    if not hits:
        return None, 0.0, {}
        
    best = hits[0]
    return best, float(best.get("score", 0.0)), best.get("meta", {})

#-------------------------------------------
RAG_STRONG_THRESHOLD = 0.78   # if >= this, trust KB fully
RAG_MEDIUM_THRESHOLD = 0.70   # if >= this, use KB as strong hint

def normalize_country_name(c: str) -> str:
    if not c:
        return "Unknown"
    t = str(c).strip().lower()
    if t in ["united states", "united states of america", "us", "u.s.", "usa", "america"]:
        return "USA"
    if t in ["united kingdom", "uk", "u.k.", "britain", "great britain"]:
        return "UK"
    return str(c).strip()

def rag_best_hit(paragraph: str, top_k: int = 3):
    hits = retriever.retrieve(paragraph, top_k=top_k)
    if not hits:
        return None, 0.0, {}
    best = hits[0]
    return best, float(best.get("score", 0.0)), best.get("meta", {})

def safe_str(x, default=""):
    if x is None:
        return default
    s = str(x).strip()
    return s if s else default

def normalize_program_type_strict(pt_llm: str, text: str) -> str:
    """
    Decides Program Type based on strict SOP hierarchy.
    Priority: RDT&E > MRO > Upgrade > Procurement (Hardware/Simulators) > Training (Services).
    """
    text_lower = str(text).lower()
    pt_llm = str(pt_llm).strip()

    # 1. RDT&E (Highest Priority)
    # Detects: Prototypes, Research, Development, SBIR
    if any(k in text_lower for k in ["rdt&e", "research", "development", "prototype", "sbir", "demonstration"]):
        return "RDT&E"

    # 2. MRO/Support
    # Detects: Sustainment, Maintenance, Logistics, Repair, Overhaul
    if any(k in text_lower for k in ["mro", "sustainment", "maintenance", "logistics", "repair", "overhaul", "support services"]):
        return "MRO/Support"

    # 3. Upgrade
    # Detects: Modernization, Retrofit, Modification of existing platforms
    if any(k in text_lower for k in ["upgrade", "modernization", "retrofit", "modification", "life extension"]):
        return "Upgrade"

    # 4. Training (Strict Split)
    # Rule: Training SERVICES = Training.
    # Rule: Training HARDWARE (Simulators, Aircraft) = Procurement.
    if "training" in text_lower or "training" in pt_llm.lower():
        # Check for hardware indicators
        if any(k in text_lower for k in ["simulator", "device", "trainer aircraft", "hardware", "system", "kit"]):
            return "Procurement"
        return "Training"

    # 5. Procurement
    # Detects: Production, Delivery, Acquisition, Construction
    if any(k in text_lower for k in ["production", "delivery", "procurement", "acquisition", "supply", "purchase"]):
        return "Procurement"

    # Fallback to LLM's choice if valid, otherwise Other
    if pt_llm in ["Procurement", "Training", "MRO/Support", "RDT&E", "Upgrade", "Other Service"]:
        return pt_llm

    return "Other Service"


def calculate_value_certainty(text: str, value_str: str) -> str:
    """
    Determines if value is 'Confirmed' or 'Estimated'.
    """
    text_lower = text.lower()
    
    # 1. Keywords indicating estimation/limit
    if any(k in text_lower for k in ["ceiling", "maximum", "estimated", "potential", "not to exceed", "idiq", "bpa", "blanket purchase agreement"]):
        return "Estimated"

    # 2. Shared/Split Value Indicators
    # If multiple companies share a pot, the value is Estimated for that specific row
    if any(k in text_lower for k in ["shared", "competing", "each awarded", "multiple award", "aggregate"]):
        return "Estimated"

    # 3. If no value was extracted, it can't be confirmed
    if not value_str or value_str == "0.000":
        return "Estimated"

    return "Confirmed"


def calculate_usd_value(val_million_str: str, currency: str) -> str:
    """
    Populates 'Value (USD$ Million)'.
    Assuming 1:1 if already USD, otherwise leaves blank or requires conversion logic.
    """
    try:
        val = float(val_million_str)
        if val == 0:
            return ""
            
        # If Currency is USD, copy the value
        if "USD" in currency.upper() or "$" in currency:
            return f"{val:.3f}"
            
        # (Optional) Add simplistic conversion if needed, e.g., GBP -> 1.25 * Val
        # For now, we return extraction.
        return f"{val:.3f}"
    except:
        return ""

## AGENTS / TOOLS (Chunk-wise extraction + Merge)

In [28]:
## AGENTS / TOOLS (Chunk-wise extraction + Merge)

# Stage 1: SOURCING EXTRACTOR

class SourcingInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    url: str = Field(description="Source URL of the contract announcement/news.")
    date: str = Field(description="Contract date in Excel (string).")

@tool("sourcing_extractor")
def sourcing_extractor(paragraph: str, url: str, date: str):
    """
    Stage 1: SOURCING EXTRACTOR

    Purpose:
    - Creates the base skeleton row (stable fields).

    Output columns created:
    - Description of Contract
    - Additional Notes (Internal Only)
    - Source Link(s)
    - Contract Date
    - Reported Date (By SGA)

    Important:
    - These fields remain SAME even after splits.
    - Every split row inherits these values.
    """
    reported_date = datetime.datetime.now().strftime("%Y-%m-%d")

    notes = "Standard extraction."
    if "modification" in str(paragraph).lower():
        notes = "Contract Modification."
    if "multiple award" in str(paragraph).lower():
        notes = "Multiple award contract detected (non-supplier split)."

    return {
        "Description of Contract": paragraph,
        "Additional Notes (Internal Only)": notes,
        "Source Link(s)": url,
        "Contract Date": date,
        "Reported Date (By SGA)": reported_date
    }


In [29]:
## # Stage 2: GEOGRAPHY EXTRACTOR (Chunk -> Merge)

class GeographyInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")

@tool("geography_extractor")
def geography_extractor(paragraph: str):
    """
    Stage 2: GEOGRAPHY EXTRACTOR
    Updates Supplier Region based on the extracted Supplier Country.
    """
    paragraph = str(paragraph).strip()
    if not paragraph:
        return {}

    # --- RAG Optimization ---
    best_hit, best_score, best_meta = rag_best_hit(paragraph, top_k=3)
    
    # Defaults
    cust_country = "Unknown"
    cust_op = "Unknown"
    supp_country = "Unknown"

    if best_hit and best_score >= 0.78:
        # Trust KB if high confidence
        cust_country = safe_str(best_meta.get("Customer Country", ""))
        cust_op = normalize_customer_operator(safe_str(best_meta.get("Customer Operator", "")))
        supp_country = safe_str(best_meta.get("Supplier Country", ""))
    else:
        # LLM Extraction
        chunks = chunk_text(paragraph, chunk_size=1800, overlap=250)
        outputs = []
        for ch in chunks:
            try:
                raw = call_llm_json(GEOGRAPHY_PROMPT, ch, max_tokens=250)
                outputs.append(raw)
            except Exception:
                continue
        
        cust_country = pick_best_non_empty([o.get("Customer Country") for o in outputs]) or "Unknown"
        cust_op = normalize_customer_operator(pick_best_non_empty([o.get("Customer Operator") for o in outputs]))
        supp_country = pick_best_non_empty([o.get("Supplier Country") for o in outputs]) or "Unknown"

    # Normalize Countries
    cust_country = normalize_country_name(cust_country)
    supp_country = normalize_country_name(supp_country)

    # Calculate Regions using the SHARED logic
    cust_region = get_region_for_country(cust_country)
    supp_region = get_region_for_country(supp_country) # Same list as Customer Region

    return {
        "Customer Region": cust_region,
        "Customer Country": cust_country,
        "Customer Operator": cust_op,
        "Supplier Region": supp_region,   # Extracted here
        "Supplier Country": supp_country,
        "Domestic Content": "Indigenous" if cust_country == supp_country else "Imported"
    }


In [30]:
# Stage 3: SYSTEM CLASSIFIER (Chunk -> Merge)

class SystemInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    item_focus: str = Field(description="Specific line-item focus (from Stage5), e.g. 'All up round containers'")

@tool("system_classifier")
def system_classifier(paragraph: str, item_focus: str = ""):
    """
    Stage 3: RAG-Enhanced Classifier (Accuracy Mode)
    """
    paragraph = str(paragraph).strip()
    
    # --- 1. RETRIEVE KNOWLEDGE (The "Memory") ---
    # We query the KB for the specific item to see how we handled it before.
    search_query = item_focus if len(item_focus) > 3 else paragraph
    hits = retriever.retrieve(search_query, top_k=3)
    
    # Build the Context String
    rag_context = "No direct historical match found."
    if hits:
        rag_context = "### HISTORICAL PRECEDENTS (GOLDEN EXAMPLES):\n"
        for i, h in enumerate(hits, 1):
            m = h['meta']
            # We show the LLM the 'Correct Answer' from the past
            rag_context += (
                f"[{i}] Description: \"{m.get('original_text', '')[:120]}...\"\n"
                f"    -> Classified As: {m.get('Market Segment', '?')} / {m.get('System Name (Specific)', '?')}\n"
            )

    # --- 2. GET KEYWORD HINTS (The "Rule Book") ---
    rule_guidance = apply_rule_book(paragraph)
    rule_str = ""
    if rule_guidance:
        rule_str = f"Trace detected keywords: {json.dumps(rule_guidance)}"

    # --- 3. EXECUTE PROMPT ---
    # We inject the RAG context directly into the User Prompt so it's 'top of mind'
    formatted_system_prompt = SYSTEM_CLASSIFIER_PROMPT.format(
        taxonomy_reference=TAXONOMY_STR,
        rule_book_overrides=rule_str
    )

    user_prompt = f"""
    TARGET CONTRACT TEXT: "{paragraph}"
    SPECIFIC ITEM TO CLASSIFY: "{item_focus if item_focus else 'Main Deliverable'}"
    
    {rag_context}
    
    TASK:
    1. Check the 'HISTORICAL PRECEDENTS' above. If the item matches, align with that classification.
    2. If no precedent, analyze the 'TARGET CONTRACT TEXT' functionally.
    3. Extract the JSON.
    """

    try:
        # Call LLM
        res = call_llm_json(formatted_system_prompt, user_prompt, max_tokens=600)
        
        # --- 4. VALIDATION SAFETY NET ---
        # Ensure Piloting is logically consistent (LLMs sometimes struggle with 'Crewed' vs 'N/A')
        designators = extract_designators(paragraph)
        sys_type = res.get("System Type (General)", "")
        
        # If the LLM was unsure about piloting, run the strict rule checker
        if res.get("System Piloting") in ["Unknown", "N/A", None] and sys_type:
            res["System Piloting"] = detect_piloting_strict(paragraph, designators, sys_type)
            
        return res

    except Exception as e:
        return {"Confidence": "Low", "Error": str(e)}

In [31]:
# Stage 4: CONTRACT EXTRACTOR (Chunk -> Merge + Strict Supplier)

class ContractInfoInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    contract_date: str = Field(description="Contract date as string.")

def clean_llm_number(num_str):
    """Converts '$12.5M' or '12,000,000' to float."""
    if not num_str: return 0.0
    try:
        clean = str(num_str).upper().replace("$", "").replace("MILLION", "").replace("M", "").replace(",", "").strip()
        # Handle 'Billion' conversion if necessary
        multiplier = 1000 if "BILLION" in str(num_str).upper() or "B" in str(num_str).upper() else 1
        clean = clean.replace("BILLION", "").replace("B", "")
        return float(clean) * multiplier
    except:
        return 0.0

def calculate_mro_duration_sop(contract_date_str, completion_date_text, program_type):
    """
    SOP Calculation: Subtract completion date from contract date. 
    Capture difference in whole months (Round Down).
    Applicable ONLY for MRO/Support.
    """
    if program_type != "MRO/Support":
        return "Not Applicable"
    
    if not contract_date_str or not completion_date_text or completion_date_text.lower() in ["not applicable", "n/a", "unknown"]:
        return "Not Applicable"

    try:
        start_date = pd.to_datetime(contract_date_str)
        # Fuzzy parse the completion text (e.g., "September 2024")
        end_date = parser.parse(str(completion_date_text), fuzzy=True)
        
        # Calculate difference
        diff = relativedelta(end_date, start_date)
        total_months = (diff.years * 12) + diff.months
        
        # SOP: "For fractional values, apply the round-down approach"
        # relativedelta typically gives whole integers, but if days are involved:
        # If end day < start day, relativedelta usually handles this by reducing month count.
        # We ensure it's at least 0.
        return str(max(0, int(total_months)))
    except:
        return "Not Applicable"

@tool("contract_extractor")
def contract_extractor(paragraph: str, contract_date: str):
    """
    Stage 4: CONTRACT EXTRACTOR
    """
    paragraph = str(paragraph).strip()

    # 1. Prepare Prompt
    formatted_prompt = CONTRACT_EXTRACTOR_PROMPT.format(
        program_type_enum=str(PROGRAM_TYPE_ALLOWED)
    )
    user_message = f"CONTRACT TEXT: {paragraph}\nDATE: {contract_date}"

    # 2. Call LLM
    try:
        # Assuming call_llm_json is defined/imported
        llm_data = call_llm_json(formatted_prompt, user_message, max_tokens=400)
    except Exception as e:
        print(f"LLM Error: {e}")
        llm_data = {}

    # 3. --- SUPPLIER MATCHING LOGIC (The Fix) ---
    llm_supplier_raw = llm_data.get("extracted_supplier", "Unknown")
    
    # Run the raw LLM output through your fuzzy matcher
    final_supplier = get_best_supplier_match(llm_supplier_raw)

    # 4. Financial Processing
    raw_val = llm_data.get("value_million", 0)
    val_float = clean_llm_number(raw_val)
    val_final_str = f"{val_float:.3f}"
    
    # Currency Handling
    currency_code = llm_data.get("currency_code", "USD$")
    if "USD" in currency_code.upper() and "$" not in currency_code:
        currency_code = "USD$"

    if "USD" in currency_code:
        usd_final_str = val_final_str
    else:
        usd_final_str = f"{clean_llm_number(llm_data.get('value_usd_million', 0)):.3f}"

    # 5. Other Fields
    final_program_type = normalize_program_type_improved(
        llm_data.get("program_type", "Unknown"), 
        paragraph
    )

    mro_months = calculate_mro_duration_sop(
        contract_date, 
        llm_data.get("completion_date_text", ""), 
        final_program_type
    )

    return {
        "Supplier Name": final_supplier,  # This now contains the clean JSON match or "Multiple"
        "Supplier Name Evidence": "LLM + Fuzzy Match",
        "Program Type": final_program_type,
        "Value (Million)": val_final_str,
        "Value (USD$ Million)": usd_final_str,
        "Currency": currency_code,
        "Value Certainty": llm_data.get("value_certainty", "Estimated"),
        "Value Note (If Any)": llm_data.get("value_note", ""),
        "G2G/B2G": llm_data.get("g2g_b2g", "B2G"),
        "Expected MRO Contract Duration (Months)": mro_months,
        "Quantity": llm_data.get("quantity", "Not Applicable")
    }

In [32]:
# Stage 5: SPLITTER AGENT (FULL PARAGRAPH ALWAYS)

class SplitterInput(BaseModel):
    paragraph: str = Field(description="Full contract paragraph/description text.")
    base_row: dict = Field(description="Extracted row after Stage1-4.")

@tool("splitter_agent")
def splitter_agent(paragraph: str, base_row: dict):
  """
  Stage 5: SPLITTER AGENT
  Purpose:
    - Applies deterministic split logic to generate multiple output rows
      when paragraph has explicit multi allocations.
  Supported splits:
    - Operator/Quantity split ("212 for the Navy", "187 for the Air Force")
    - FMS country split (only for rows marked as G2G)
    - Multi-value note (does not split, only notes)

  IMPORTANT:
    - Supplier split is REMOVED to prevent wrong supplier explosions.
  """
  try:
      rows = split_rows_engine(base_row, paragraph)
      for r in rows:
          r.setdefault("Split Flag", "No")
          r.setdefault("Split Reason", "")
      return {"rows": rows}
  except Exception as e:
      base_row["Split Flag"] = "Error"
      base_row["Split Reason"] = f"Split failed: {str(e)}"
      return {"rows": [base_row]}


In [33]:
# Stage 6: QUALITY VALIDATOR AGENT

class QAInput(BaseModel):
    paragraph: str = Field(description="Original paragraph for reference.")
    rows: list = Field(description="Final split rows list output from Stage5.")

@tool("quality_validator")
def quality_validator(paragraph: str, rows: list):
    """
    Stage 6: QUALITY VALIDATOR (Hybrid: Rule-Based + KB Check)
    """
    text = str(paragraph).lower()
    validated_rows = []

    for r in rows:
        flags = []
        fixes = []

        # --- Rule Based Checks ---
        supplier = str(r.get("Supplier Name", "")).strip()
        if supplier.lower() in ["unknown", "", "multiple"]:
             # If text clearly has an award, flag it
             if "awarded" in text:
                 flags.append("Supplier Unknown")
                 fixes.append("Check LLM Fallback")

        # --- KB/Taxonomy Validation (NEW) ---
        sys_name = str(r.get("System Name (General)", "")).strip()
        
        # If we have a system name, verify it against KB context
        if sys_name and sys_name.lower() != "unknown":
            # Quick check: does this system appear in our RAG hits?
            # We assume 'retriever' is globally available as per notebook scope
            hits = retriever.retrieve(sys_name, top_k=1)
            if hits:
                top_score = hits[0]['score']
                # If score is very low, the system might be hallucinated or poorly named
                if top_score < 0.35: 
                    flags.append(f"System Name '{sys_name}' has low KB confidence ({top_score:.2f})")
                    fixes.append("Verify system name against standard taxonomy")

        # --- Final Status ---
        qa_status = "PASS" if not flags else "FAIL"
        
        rr = r.copy()
        rr["QA Status"] = qa_status
        rr["QA Flags"] = " | ".join(flags) if flags else "None"
        rr["QA Fix Suggestion"] = " | ".join(fixes) if fixes else "None"
        validated_rows.append(rr)

    return {"rows": validated_rows}

In [34]:
# Stage 7: LLM VALIDATOR (FAIL ONLY) - Chunk Aware
class LLMValidateInput(BaseModel):
    paragraph: str = Field(description="Original paragraph text.")
    row: dict = Field(description="One FAIL row to validate/correct.")

@tool("llm_fail_row_validator")
def llm_fail_row_validator(paragraph: str, row: dict):
    """
    Stage 7: VALIDATOR FIX (Linked to prompts.py)
    """
    # LINK THE PROMPT
    formatted_fix_prompt = VALIDATOR_FIX_PROMPT.format(
        failed_row_json=json.dumps(row, indent=2),
        program_type_enum=str(PROGRAM_TYPE_ALLOWED)
    )

    user_prompt = f"ORIGINAL TEXT CHUNK: {paragraph[:2000]}"

    try:
        fix = call_llm_json(formatted_fix_prompt, user_prompt, max_tokens=350)
        
        # Merge Fixes
        corrected = row.copy()
        if fix.get("Supplier Name") and fix["Supplier Name"] != "Unknown":
            corrected["Supplier Name"] = fix["Supplier Name"]
        if fix.get("Program Type"):
            corrected["Program Type"] = fix["Program Type"]
        if fix.get("Value (Million)"):
            corrected["Value (Million)"] = fix["Value (Million)"]
            
        corrected["LLM QA Fix Summary"] = fix.get("Fix Summary", "Auto-fixed by Agent")
        return {"row": corrected}

    except Exception:
        return {"row": row}

In [35]:
class AgentState(TypedDict):
    input_text: str
    input_date: str
    input_url: str

    final_data: dict
    final_rows: list
    validated_rows: list
    final_rows_post_llm: list

    messages: Annotated[List[AnyMessage], add_messages]


def stage_1_sourcing(state: AgentState):
    res = sourcing_extractor.invoke({
        "paragraph": state["input_text"],
        "url": state["input_url"],
        "date": state["input_date"],
    })
    new_data = state.get("final_data", {}).copy()
    new_data.update(res)
    return {"final_data": new_data}


def stage_2_geography(state: AgentState):
    res = geography_extractor.invoke({"paragraph": state["input_text"]})
    new_data = state.get("final_data", {}).copy()
    new_data.update(res)
    return {"final_data": new_data}


def stage_3_system(state: AgentState):
    """
    NEW Stage3 Node:
    Runs system classification PER split row using item_focus.
    """

    paragraph = state["input_text"]
    rows = state.get("final_rows", [])

    updated_rows = []

    for r in rows:
        item_focus = str(r.get("System Name (Specific)", "")).strip()

        res = system_classifier.invoke({
            "paragraph": paragraph,
            "item_focus": item_focus
        })

        rr = r.copy()
        rr.update(res)   # merge system labels into the row
        updated_rows.append(rr)

    return {"final_rows": updated_rows}


def stage_4_contract(state: AgentState):
    res = contract_extractor.invoke({
        "paragraph": state["input_text"],
        "contract_date": state["input_date"]
    })
    new_data = state.get("final_data", {}).copy()
    new_data.update(res)
    return {"final_data": new_data}


def stage_5_split(state: AgentState):
    """
    Stage5 Node: SplitterEngine

     Ensures output is ALWAYS stored in `final_rows`
    so Stage3 SystemClassifierRAG can loop through them.
    """

    base_row = state.get("final_data", {}) or {}
    paragraph = state.get("input_text", "")

    try:
        res = splitter_agent.invoke({
            "paragraph": paragraph,
            "base_row": base_row
        })

        rows = res.get("rows", None)

        # Hard safety fallback
        if not rows or not isinstance(rows, list):
            rows = [base_row]

        return {"final_rows": rows}

    except Exception as e:
        # Never crash pipeline due to split failure
        fallback = base_row.copy()
        fallback["Split Flag"] = "Error"
        fallback["Split Reason"] = f"SplitterEngine failed: {str(e)}"
        return {"final_rows": [fallback]}


def stage_6_quality_validator(state: AgentState):
    res = quality_validator.invoke({
        "paragraph": state["input_text"],
        "rows": state["final_rows"]
    })
    return {"validated_rows": res.get("rows", state["final_rows"])}


def stage_7_llm_fix_fail_rows(state: AgentState):
    paragraph = state["input_text"]
    validated_rows = state.get("validated_rows", [])

    fixed_rows = []
    for r in validated_rows:
        if r.get("QA Status") == "FAIL":
            fix_res = llm_fail_row_validator.invoke({"paragraph": paragraph, "row": r})
            fixed_rows.append(fix_res.get("row", r))
        else:
            fixed_rows.append(r)

    return {"final_rows_post_llm": fixed_rows}

def node_system_classifier_rag(state: AgentState):
    """
    Stage3 Node (AFTER SplitterEngine)

    Runs system classification per split-row using:
       item_focus = row["System Name (Specific)"] or row["System Name (General)"]

    Updates each row with Market Segment / System Type / System Name / Evidence / Reason
    """

    paragraph = state["input_text"]
    rows = state.get("final_rows", [])

    # If split engine didn't create rows, fallback to single final_data row
    if not rows:
        base = state.get("final_data", {})
        rows = [base] if base else []

    updated_rows = []

    for r in rows:
        item_focus = str(r.get("System Name (Specific)", "")).strip()
        if not item_focus:
            item_focus = str(r.get("System Name (General)", "")).strip()

        # invoke your Stage3 tool
        sys_res = system_classifier.invoke({
            "paragraph": paragraph,
            "item_focus": item_focus
        })

        rr = r.copy()
        rr.update(sys_res)
        updated_rows.append(rr)

    return {"final_rows": updated_rows}


In [36]:
workflow = StateGraph(AgentState)

workflow.add_node("SourcingExtractor", stage_1_sourcing)
workflow.add_node("GeographyExtractor", stage_2_geography)
workflow.add_node("ContractExtractor", stage_4_contract)
workflow.add_node("SplitterEngine", stage_5_split)
workflow.add_node("SystemClassifierRAG", node_system_classifier_rag)  # 
workflow.add_node("QualityValidator", stage_6_quality_validator)
workflow.add_node("LLMFailRowFixer", stage_7_llm_fix_fail_rows)

workflow.add_edge(START, "SourcingExtractor")
workflow.add_edge("SourcingExtractor", "GeographyExtractor")
workflow.add_edge("GeographyExtractor", "ContractExtractor")
workflow.add_edge("ContractExtractor", "SplitterEngine")   # 
workflow.add_edge("SplitterEngine", "SystemClassifierRAG") # 
workflow.add_edge("SystemClassifierRAG", "QualityValidator")
workflow.add_edge("QualityValidator", "LLMFailRowFixer")
workflow.add_edge("LLMFailRowFixer", END)
app = workflow.compile()



app = workflow.compile()

In [37]:
from IPython.display import Image, display
try:
  display(Image(app.get_graph().draw_mermaid_png()))
except Exception as e:
  print(e)

Failed to reach https://mermaid.ink API while trying to render your graph after 1 retries. To resolve this issue:
1. Check your internet connection and try again
2. Try with higher retry settings: `draw_mermaid_png(..., max_retries=5, retry_delay=2.0)`
3. Use the Pyppeteer rendering method which will render your graph locally in a browser: `draw_mermaid_png(..., draw_method=MermaidDrawMethod.PYPPETEER)`


In [38]:
# 10) GRAPH EXPORT (OFFLINE SAFE)

def export_workflow_mermaid(app_obj, out_file="workflow.mmd"):
    mmd = app_obj.get_graph().draw_mermaid()
    with open(out_file, "w", encoding="utf-8") as f:
        f.write(mmd)
    print(f"Workflow Mermaid saved locally: {out_file}")
    return out_file

In [39]:
if __name__ == "__main__":

    print(f"\n Loading Input File: {INPUT_EXCEL_PATH}")
    
    # Define Output Path as Excel
    OUTPUT_EXCEL_PATH = "Processed_Defense_Data.xlsx"

    export_workflow_mermaid(app, out_file="workflow.mmd")

    try:
        df_input = pd.read_excel(INPUT_EXCEL_PATH)
        
        # Basic validation
        required_cols = ["Source URL", "Contract Date", "Contract Description"]
        # Allow loose matching or strip whitespace from columns if needed
        df_input.columns = [c.strip() for c in df_input.columns]
        
        if not all(col in df_input.columns for col in required_cols):
             raise ValueError(f"Excel file must contain columns: {required_cols}")

        print(f"Processing {len(df_input)} rows...")
        results = []

        for index, row in df_input.iterrows():
            print(f"\n Row {index + 1}/{len(df_input)}")

            desc = str(row["Contract Description"]) if pd.notna(row["Contract Description"]) else ""
            c_date = str(row["Contract Date"]) if pd.notna(row["Contract Date"]) else str(datetime.date.today())
            c_url = str(row["Source URL"]) if pd.notna(row["Source URL"]) else ""

            initial_state: AgentState = {
                "input_text": desc,
                "input_date": c_date,
                "input_url": c_url,
                "final_data": {},
                "final_rows": [],
                "validated_rows": [],
                "final_rows_post_llm": [],
                "messages": []
            }

            output_state = app.invoke(initial_state)

            # Hierarchy of fallback for getting rows
            rows = output_state.get("final_rows_post_llm", [])
            if not rows:
                rows = output_state.get("validated_rows", [])
            if not rows:
                rows = output_state.get("final_rows", [])
            if not rows:
                rows = [output_state.get("final_data", {})]

            results.extend(rows)

        df_final = pd.DataFrame(results)

        FINAL_COLUMNS = [
            "Customer Region", "Customer Country", "Customer Operator",
            "Supplier Region", "Supplier Country", "Domestic Content",

            "Split Flag", "Split Reason",

            "Market Segment", "Market Segment Evidence", "Market Segment Reason",
            "System Type (General)", "System Type (General) Evidence", "System Type (General) Reason",
            "System Type (Specific)", "System Type (Specific) Evidence", "System Type (Specific) Reason",
            "System Name (General)", "System Name (General) Evidence", "System Name (General) Reason",
            "System Name (Specific)", "System Name (Specific) Evidence", "System Name (Specific) Reason",
            "System Piloting", "System Piloting Evidence", "System Piloting Reason",
            "Confidence",

            "Supplier Name", "Supplier Name Evidence",
            "Program Type", "Expected MRO Contract Duration (Months)",
            "Quantity", "Value Certainty", "Value (Million)", "Currency",
            "Value (USD$ Million)", "Value Note (If Any)", "G2G/B2G",
            "Signing Month", "Signing Year",

            "QA Status", "QA Flags", "QA Fix Suggestion",
            "LLM QA Fix Summary",

            "Description of Contract",
            "Additional Notes (Internal Only)",
            "Source Link(s)",
            "Contract Date",
            "Reported Date (By SGA)"
        ]

        # Reindex ensures all columns exist, filling missing ones with empty string
        df_final = df_final.reindex(columns=FINAL_COLUMNS, fill_value="")

        # 1. Save as Excel
        print(f"Saving to Excel: {OUTPUT_EXCEL_PATH}")
        df_final.to_excel(OUTPUT_EXCEL_PATH, index=False, engine='openpyxl')
        
        # 2. Apply Formatting/Highlighting
        print("Applying Highlighting...")
        highlight_evidence_reason_columns(OUTPUT_EXCEL_PATH)

        print("\nProcessing Complete!")
        #print(df_final[["Supplier Name", "Value (USD$ Million)", "Customer Operator", "Supplier Region"]].head(3).to_string(index=False))

    except Exception as e:
        print(f"\n ERROR: {e}")
        import traceback
        traceback.print_exc()


 Loading Input File: C:\Users\mukeshkr\Agentic-AI-Defense-Data-Extraction\data\source_file.xlsx
Workflow Mermaid saved locally: workflow.mmd
Processing 9 rows...

 Row 1/9

 Row 2/9

 Row 3/9

 Row 4/9

 Row 5/9

 Row 6/9

 Row 7/9

 Row 8/9

 Row 9/9
Saving to Excel: Processed_Defense_Data.xlsx
Applying Highlighting...
Evidence + Reason columns highlighted successfully.

Processing Complete!


In [None]:
# ==========================================================
# DEFENSE AGENTIC PIPELINE ‚Äî 7 AGENTS (FINAL)
# ==========================================================

import os, json, pickle, datetime, getpass
from typing import TypedDict, List, Annotated
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer

# ---------------- LangGraph ----------------
from langchain_core.tools import tool
from langchain_core.messages import AnyMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

# ---------------- LLM ----------------
from openai import OpenAI

# ---------------- Excel ----------------
from openpyxl import load_workbook
from openpyxl.styles import PatternFill

# ==========================================================
# CONFIG
# ==========================================================

INPUT_EXCEL = "source_file.xlsx"
OUTPUT_EXCEL = "Processed_Defense_Data.xlsx"
KB_DIR = "system_kb_store"

RAG_STRONG = 0.78
RAG_MEDIUM = 0.70

# ==========================================================
# PROMPT PLACEHOLDERS (PASTE YOUR REAL PROMPTS)
# ==========================================================

GEOGRAPHY_PROMPT = """<<< GEOGRAPHY PROMPT HERE >>>"""
SYSTEM_CLASSIFIER_PROMPT = """<<< SYSTEM CLASSIFIER PROMPT HERE >>>"""
CONTRACT_EXTRACTOR_PROMPT = """<<< CONTRACT EXTRACTOR PROMPT HERE >>>"""
VALIDATOR_FIX_PROMPT = """<<< OPTIONAL VALIDATOR PROMPT HERE >>>"""

# ==========================================================
# LLM CLIENT
# ==========================================================

if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter LLM Foundry API Key: ")

if "OPENROUTER_API_KEY" not in os.environ:
    os.environ["OPENROUTER_API_KEY"] = getpass.getpass("Enter OpenRouter API Key: ")

# ===== LLM FOUNDRY (PRIMARY ‚Äì INTERNAL) =====
llm_foundry_client = OpenAI(
    api_key=f'{os.environ["LLMFOUNDRY_TOKEN"]}:agentic',
    base_url="https://llmfoundry.straive.com/openai/v1/"
)

FOUNDRY_MODEL = "gpt-4o-mini"


# ===== OPENROUTER (SECONDARY / FALLBACK / COMPARISON) =====
openrouter_client = OpenAI(
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1"
)

OPENROUTER_MODEL = "openai/gpt-4o-mini"

def call_llm(
    prompt: str,
    backend: str = "foundry",   # "foundry" | "openrouter"
    max_tokens: int = 500
):
    """
    Unified LLM call wrapper.
    You decide backend per agent / per confidence / per column.
    """

    if backend == "openrouter":
        client = openrouter_client
        model = OPENROUTER_MODEL
    else:
        client = llm_foundry_client
        model = FOUNDRY_MODEL

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=max_tokens,
        response_format={"type": "json_object"},
    )

    return json.loads(response.choices[0].message.content)
# ==========================================================
# KB RETRIEVER
# ==========================================================

class KBRetriever:
    def __init__(self, kb_dir):
        self.index = faiss.read_index(os.path.join(kb_dir, "system_kb.faiss"))
        with open(os.path.join(kb_dir, "system_kb_meta.pkl"), "rb") as f:
            self.meta = pickle.load(f)
        self.embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    def best_hit(self, text):
        emb = self.embedder.encode([text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(emb, 1)
        if idxs[0][0] < 0:
            return {}, 0.0, None
        return self.meta[idxs[0][0]], float(scores[0][0]), idxs[0][0]

retriever = KBRetriever(KB_DIR)

# ==========================================================
# MODE / CONFIDENCE
# ==========================================================

def mode_from_score(score):
    if score >= RAG_STRONG:
        return "KB_ONLY"
    if score >= RAG_MEDIUM:
        return "KB_GUIDED"
    return "LLM_ONLY"

def confidence_from_mode(mode):
    return {
        "KB_ONLY": "High (KB)",
        "KB_GUIDED": "Medium (KB+LLM)",
        "LLM_ONLY": "Low (LLM)"
    }.get(mode, "Unknown")

# ==========================================================
# AGENT STATE
# ==========================================================

class AgentState(TypedDict):
    text: str
    date: str
    url: str

    kb_meta: dict
    kb_score: float
    kb_mode: str
    kb_row_id: int | None

    row: dict
    rows: list

    messages: Annotated[List[AnyMessage], add_messages]

# ==========================================================
# AGENT 1 ‚Äî KB ROUTER
# ==========================================================

def kb_router_agent(state: AgentState):
    meta, score, row_id = retriever.best_hit(state["text"])
    return {
        "kb_meta": meta,
        "kb_score": score,
        "kb_mode": mode_from_score(score),
        "kb_row_id": row_id
    }

# ==========================================================
# AGENT 2 ‚Äî SOURCING
# ==========================================================

def sourcing_agent(state: AgentState):
    return {
        "row": {
            "Description of Contract": state["text"],
            "Contract Date": state["date"],
            "Source Link(s)": state["url"],
            "Reported Date (By SGA)": datetime.date.today().isoformat()
        }
    }

# ==========================================================
# AGENT 3 ‚Äî GEOGRAPHY
# ==========================================================

def geography_agent(state: AgentState):
    row = state["row"].copy()

    if state["kb_mode"] != "LLM_ONLY":
        for k in ["Customer Country", "Customer Region",
                  "Customer Operator", "Supplier Country",
                  "Supplier Region", "Domestic Content"]:
            if state["kb_meta"].get(k):
                row[k] = state["kb_meta"][k]
    else:
        llm = call_llm(GEOGRAPHY_PROMPT + "\n" + state["text"],backend="foundry" if state["kb_mode"] != "LLM_ONLY" else "openrouter")

    return {"row": row}

# ==========================================================
# AGENT 4 ‚Äî SYSTEM CLASSIFIER
# ==========================================================

def system_classifier_agent(state: AgentState):
    row = state["row"].copy()

    if state["kb_mode"] != "LLM_ONLY":
        for k in ["Market Segment", "System Type (General)",
                  "System Type (Specific)", "System Name (General)",
                  "System Name (Specific)", "System Piloting"]:
            if state["kb_meta"].get(k):
                row[k] = state["kb_meta"][k]
    else:
        llm = call_llm(SYSTEM_CLASSIFIER_PROMPT + "\n" + state["text"],backend="foundry" if state["kb_mode"] != "LLM_ONLY" else "openrouter")


    return {"row": row}

# ==========================================================
# AGENT 5 ‚Äî CONTRACT EXTRACTOR
# ==========================================================

def contract_extractor_agent(state: AgentState):
    row = state["row"].copy()

    if state["kb_mode"] != "LLM_ONLY":
        for k in ["Supplier Name", "Program Type", "Value (Million)",
                  "Value (USD$ Million)", "Currency",
                  "Value Certainty", "Quantity", "G2G/B2G"]:
            if state["kb_meta"].get(k):
                row[k] = state["kb_meta"][k]
    else:
        llm = call_llm(CONTRACT_EXTRACTOR_PROMPT + "\n" + state["text"],backend="openrouter")

    return {"rows": [row]}

# ==========================================================
# AGENT 6 ‚Äî EVALUATION (ALL COLUMNS)
# ==========================================================

def evaluation_agent(state: AgentState):
    evaluated = []

    for r in state["rows"]:
        row = r.copy()

        for col in list(row.keys()):
            row[f"{col} Source"] = (
                "KB" if state["kb_mode"] == "KB_ONLY"
                else "KB+LLM" if state["kb_mode"] == "KB_GUIDED"
                else "LLM"
            )
            row[f"{col} Confidence"] = confidence_from_mode(state["kb_mode"])
            row[f"{col} Reason"] = f"Mode={state['kb_mode']} KBscore={state['kb_score']:.2f}"

        row["Accuracy Score"] = 90 if state["kb_mode"] == "KB_ONLY" else 70 if state["kb_mode"] == "KB_GUIDED" else 50
        row["Evaluation Status"] = (
            "HIGH CONFIDENCE" if row["Accuracy Score"] >= 85 else
            "MEDIUM CONFIDENCE" if row["Accuracy Score"] >= 65 else
            "LOW CONFIDENCE"
        )

        evaluated.append(row)

    return {"rows": evaluated}

# ==========================================================
# AGENT 7 ‚Äî EXCEL FORMATTER
# ==========================================================

def excel_formatter_agent(state: AgentState):
    df = pd.DataFrame(state["rows"])
    df.to_excel(OUTPUT_EXCEL, index=False)

    wb = load_workbook(OUTPUT_EXCEL)
    ws = wb.active
    headers = [c.value for c in ws[1]]

    if "Accuracy Score" in headers:
        idx = headers.index("Accuracy Score") + 1
        green = PatternFill("solid", fgColor="C6EFCE")
        yellow = PatternFill("solid", fgColor="FFEB9C")
        red = PatternFill("solid", fgColor="F4CCCC")

        for r in range(2, ws.max_row + 1):
            cell = ws.cell(row=r, column=idx)
            try:
                v = int(cell.value)
            except:
                continue
            cell.fill = green if v >= 85 else yellow if v >= 65 else red

    wb.save(OUTPUT_EXCEL)
    return {}

# ==========================================================
# LANGGRAPH
# ==========================================================

graph = StateGraph(AgentState)

graph.add_node("KBRouter", kb_router_agent)
graph.add_node("Sourcing", sourcing_agent)
graph.add_node("Geography", geography_agent)
graph.add_node("System", system_classifier_agent)
graph.add_node("Contract", contract_extractor_agent)
graph.add_node("Evaluation", evaluation_agent)
graph.add_node("ExcelFormatter", excel_formatter_agent)

graph.add_edge(START, "KBRouter")
graph.add_edge("KBRouter", "Sourcing")
graph.add_edge("Sourcing", "Geography")
graph.add_edge("Geography", "System")
graph.add_edge("System", "Contract")
graph.add_edge("Contract", "Evaluation")
graph.add_edge("Evaluation", "ExcelFormatter")
graph.add_edge("ExcelFormatter", END)

app = graph.compile()

# ==========================================================
# MAIN
# ==========================================================

if __name__ == "__main__":

    df = pd.read_excel(INPUT_EXCEL)

    for _, r in df.iterrows():
        state = {
            "text": str(r["Contract Description"]),
            "date": str(r["Contract Date"]),
            "url": str(r["Source URL"]),
            "row": {},
            "rows": [],
            "messages": []
        }

        app.invoke(state)

    print("‚úÖ 7-AGENT DEFENSE PIPELINE EXECUTED SUCCESSFULLY")


In [None]:
# ==========================================================
# DEFENSE AGENTIC PIPELINE ‚Äî 7 AGENTS (FINAL)
# ==========================================================

import os, json, pickle, datetime, getpass
from typing import TypedDict, List, Annotated
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from IPython.display import Image, display

# ---------------- LangGraph ----------------
from langchain_core.tools import tool
from langchain_core.messages import AnyMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

# ---------------- LLM ----------------
from openai import OpenAI

# ---------------- Excel ----------------
from openpyxl import load_workbook
from openpyxl.styles import PatternFill

# ==========================================================
# CONFIG
# ==========================================================

INPUT_EXCEL = "/content/sample_data.xlsx"
OUTPUT_EXCEL = "Processed_Defense_Data.xlsx"
KB_DIR = "/content/system_kb_store"

RAG_STRONG = 0.78
RAG_MEDIUM = 0.70

# ==========================================================
# PROMPT PLACEHOLDERS (PASTE YOUR REAL PROMPTS)
# ==========================================================

GEOGRAPHY_PROMPT = """
You are a Defense Geography Analyst. 
Extract the Customer Country, Customer Operator, and Supplier Country from the text.

STRICT RULES:
1. **Customer Country**: 
   - Identify the government/nation PAYING for or RECEIVING the goods.
   - If "Foreign Military Sales (FMS)" is mentioned, look for the specific country name (e.g., "FMS to Japan").
   - Do NOT assume the "Work Location" is the Customer. (e.g., Work in Alabama for a contract supporting the UK -> Customer is UK).

2. **Customer Operator**:
   - Extract the specific service branch (e.g., "Navy", "Air Force", "Army", "Coast Guard", "Marines").
   - If a specific foreign military branch is named (e.g., "Royal Australian Air Force"), extract that.

3. **Supplier Country**:
   - Identify the country where the Supplier Company's headquarters is located.

Return JSON ONLY:
{{
  "Customer Country": "...",
  "Customer Operator": "...",
  "Supplier Country": "..."
}}
"""

# --- STAGE 3: SYSTEM CLASSIFIER ---
SYSTEM_CLASSIFIER_PROMPT = """
You are a Senior Defense System Classification Analyst.

1. **REFERENCE TAXONOMY**:
{taxonomy_reference}

2. **RULE BOOK OVERRIDES**:
{rule_book_overrides}

3. **TASK**:
   - Classify the system described in the contract into **Market Segment**, **System Type (General)**, and **System Name**.
   - **CRITICAL**: If "ITEM_FOCUS" is provided, classify THAT specific item. If empty, classify the main system in the text.
   - Use the "RAG Examples" provided to guide your choice if the text is similar.

4. **CLASSIFICATION RULES**:
   - **Generic IT/Enterprise Software**: If the contract is for generic office software (e.g., Microsoft 365, DoD ESI), cloud services, or non-tactical IT, classify Market Segment as **"Unknown"** or **"Not Applicable"**.
   - **Air vs Navy**: If the system is an Aircraft (e.g., P-8, E-2D, F-35), Market Segment is **"Air Platforms"**, even if the customer is the Navy.
   - **Ship/Submarine**: Market Segment is **"Naval Platforms"**.

5. **SYSTEM NAME EXTRACTION**:
   - **System Name (General)**: The **Host Platform** or **Class** (e.g., "E-2D Advanced Hawkeye", "Arleigh Burke-class", "Los Angeles-class").
   - **System Name (Specific)**: The **Specific Subject** of the contract.
     - If it's a specific ship/aircraft instance: Extract the name/hull number (e.g., "USS Pinckney (DDG-91)", "USS Hartford (SSN-768)", "USNS Robert Ballard (T-AGS 67)").
     - If it's a service/mod description: Extract the description (e.g., "Extend Services and Adds Hours...", "Depot Modernization Period").
     - If it's a component: Extract the component name.

6. **OUTPUT RULES**:
   - Return ONLY a FLAT JSON object.
   - Evidence must be copied EXACTLY from the text.
   - If evidence is not present, output "Not Found".

Return JSON:
{{
  "Market Segment": "...",
  "Market Segment Evidence": "...",
  "Market Segment Reason": "...",
  
  "System Type (General)": "...",
  "System Type (General) Evidence": "...",
  "System Type (General) Reason": "...",

  "System Type (Specific)": "...",
  "System Type (Specific) Evidence": "...",
  "System Type (Specific) Reason": "...",

  "System Name (General)": "...",
  "System Name (General) Evidence": "...",
  "System Name (General) Reason": "...",

  "System Name (Specific)": "...",
  "System Name (Specific) Evidence": "...",
  "System Name (Specific) Reason": "...",

  "Confidence": "High/Medium/Low"
}}
"""

CONTRACT_EXTRACTOR_PROMPT = """
You are a Defense Contract Financial Analyst.

1. **TASK**: Extract supplier, program type, financial certainty, FMS status, completion date, and currency.
2. **PROGRAM TYPE ENUM**:
   {program_type_enum}

3. **STRICT RULES**:
   - **Supplier Name**: Extract the **Clean Entity Name**. Include the **Major Division** if specified (e.g., "General Dynamics Electric Boat", "Northrop Grumman Aerospace"). Do not include legal suffixes like "Corp", "Inc", "L.P." unless part of the brand.
   - **Program Type**:
     - **MRO/Support**: Includes "depot modernization", "maintenance", "overhaul", "repair", "sustainment", "logistics support".
     - **Procurement**: Includes "production", "manufacture", "delivery" of new hardware.
     - **RDT&E**: Research, development, prototyping.
   - **Value Certainty**: 
     - "Confirmed" for definite contracts/mods.
     - "Estimated" for IDIQ ceilings, "potential value", or "maximum value".
   - **G2G/B2G**: "G2G" ONLY if "Foreign Military Sales" (FMS) is explicitly mentioned. Otherwise "B2G".
   - **Value Note**: Capture notes about IDIQs, options, or ceilings.

Return JSON ONLY:
{{
  "program_type": "...",
  "currency_code": "...",
  "value_certainty": "...",
  "completion_date_text": "...",
  "g2g_b2g": "...",
  "value_note": "...",
  "extracted_supplier": "..."
}}
"""

# ==========================================================
# LLM CLIENT
# ==========================================================

if "LLMFOUNDRY_TOKEN" not in os.environ:
    os.environ["LLMFOUNDRY_TOKEN"] = getpass.getpass("Enter LLM Foundry API Key: ")

if "OPENROUTER_API_KEY" not in os.environ:
    os.environ["OPENROUTER_API_KEY"] = getpass.getpass("Enter OpenRouter API Key: ")

# ===== LLM FOUNDRY (PRIMARY ‚Äì INTERNAL) =====
llm_foundry_client = OpenAI(
    api_key=f'{os.environ["LLMFOUNDRY_TOKEN"]}:agentic',
    base_url="https://llmfoundry.straive.com/openai/v1/"
)

FOUNDRY_MODEL = "gpt-4o-mini"


# ===== OPENROUTER (SECONDARY / FALLBACK / COMPARISON) =====
openrouter_client = OpenAI(
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1"
)

OPENROUTER_MODEL = "openai/gpt-4o-mini"

def call_llm(
    prompt: str,
    backend: str = "foundry",   # "foundry" | "openrouter"
    max_tokens: int = 500
):
    """
    Unified LLM call wrapper.
    You decide backend per agent / per confidence / per column.
    """

    if backend == "openrouter":
        client = openrouter_client
        model = OPENROUTER_MODEL
    else:
        client = llm_foundry_client
        model = FOUNDRY_MODEL

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=max_tokens,
        response_format={"type": "json_object"},
    )

    return json.loads(response.choices[0].message.content)
# ==========================================================
# KB RETRIEVER
# ==========================================================

class KBRetriever:
    def __init__(self, kb_dir):
        self.index = faiss.read_index(os.path.join(kb_dir, "system_kb.faiss"))
        with open(os.path.join(kb_dir, "system_kb_meta.pkl"), "rb") as f:
            self.meta = pickle.load(f)
        self.embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    def best_hit(self, text):
        emb = self.embedder.encode([text], normalize_embeddings=True).astype("float32")
        scores, idxs = self.index.search(emb, 1)
        if idxs[0][0] < 0:
            return {}, 0.0, None
        return self.meta[idxs[0][0]], float(scores[0][0]), idxs[0][0]

retriever = KBRetriever(KB_DIR)

# ==========================================================
# MODE / CONFIDENCE
# ==========================================================

def mode_from_score(score):
    if score >= RAG_STRONG:
        return "KB_ONLY"
    if score >= RAG_MEDIUM:
        return "KB_GUIDED"
    return "LLM_ONLY"

def confidence_from_mode(mode):
    return {
        "KB_ONLY": "High (KB)",
        "KB_GUIDED": "Medium (KB+LLM)",
        "LLM_ONLY": "Low (LLM)"
    }.get(mode, "Unknown")

# ==========================================================
# AGENT STATE
# ==========================================================

class AgentState(TypedDict):
    text: str
    date: str
    url: str

    kb_meta: dict
    kb_score: float
    kb_mode: str
    kb_row_id: int | None

    row: dict
    rows: list

    messages: Annotated[List[AnyMessage], add_messages]

# ==========================================================
# AGENT 1 ‚Äî KB ROUTER
# ==========================================================

def kb_router_agent(state: AgentState):
    meta, score, row_id = retriever.best_hit(state["text"])
    return {
        "kb_meta": meta,
        "kb_score": score,
        "kb_mode": mode_from_score(score),
        "kb_row_id": row_id
    }

# ==========================================================
# AGENT 2 ‚Äî SOURCING
# ==========================================================

def sourcing_agent(state: AgentState):
    return {
        "row": {
            "Description of Contract": state["text"],
            "Contract Date": state["date"],
            "Source Link(s)": state["url"],
            "Reported Date (By SGA)": datetime.date.today().isoformat()
        }
    }

# ==========================================================
# AGENT 3 ‚Äî GEOGRAPHY
# ==========================================================

def geography_agent(state: AgentState):
    row = state["row"].copy()

    if state["kb_mode"] != "LLM_ONLY":
        for k in ["Customer Country", "Customer Region",
                  "Customer Operator", "Supplier Country",
                  "Supplier Region", "Domestic Content"]:
            if state["kb_meta"].get(k):
                row[k] = state["kb_meta"][k]
    else:
        llm = call_llm(GEOGRAPHY_PROMPT + "\n" + state["text"],backend="foundry" if state["kb_mode"] != "LLM_ONLY" else "openrouter")

    return {"row": row}

# ==========================================================
# AGENT 4 ‚Äî SYSTEM CLASSIFIER
# ==========================================================

def system_classifier_agent(state: AgentState):
    row = state["row"].copy()

    if state["kb_mode"] != "LLM_ONLY":
        for k in ["Market Segment", "System Type (General)",
                  "System Type (Specific)", "System Name (General)",
                  "System Name (Specific)", "System Piloting"]:
            if state["kb_meta"].get(k):
                row[k] = state["kb_meta"][k]
    else:
        llm = call_llm(SYSTEM_CLASSIFIER_PROMPT + "\n" + state["text"],backend="foundry" if state["kb_mode"] != "LLM_ONLY" else "openrouter")


    return {"row": row}

# ==========================================================
# AGENT 5 ‚Äî CONTRACT EXTRACTOR
# ==========================================================

def contract_extractor_agent(state: AgentState):
    row = state["row"].copy()

    if state["kb_mode"] != "LLM_ONLY":
        for k in ["Supplier Name", "Program Type", "Value (Million)",
                  "Value (USD$ Million)", "Currency",
                  "Value Certainty", "Quantity", "G2G/B2G"]:
            if state["kb_meta"].get(k):
                row[k] = state["kb_meta"][k]
    else:
        llm = call_llm(CONTRACT_EXTRACTOR_PROMPT + "\n" + state["text"],backend="openrouter")

    return {"rows": [row]}

# ==========================================================
# AGENT 6 ‚Äî EVALUATION (ALL COLUMNS)
# ==========================================================

def evaluation_agent(state: AgentState):
    evaluated = []

    for r in state["rows"]:
        row = r.copy()

        for col in list(row.keys()):
            row[f"{col} Source"] = (
                "KB" if state["kb_mode"] == "KB_ONLY"
                else "KB+LLM" if state["kb_mode"] == "KB_GUIDED"
                else "LLM"
            )
            row[f"{col} Confidence"] = confidence_from_mode(state["kb_mode"])
            row[f"{col} Reason"] = f"Mode={state['kb_mode']} KBscore={state['kb_score']:.2f}"

        row["Accuracy Score"] = 90 if state["kb_mode"] == "KB_ONLY" else 70 if state["kb_mode"] == "KB_GUIDED" else 50
        row["Evaluation Status"] = (
            "HIGH CONFIDENCE" if row["Accuracy Score"] >= 85 else
            "MEDIUM CONFIDENCE" if row["Accuracy Score"] >= 65 else
            "LOW CONFIDENCE"
        )

        evaluated.append(row)

    return {"rows": evaluated}

# ==========================================================
# AGENT 7 ‚Äî EXCEL FORMATTER
# ==========================================================

def excel_formatter_agent(state: AgentState):
    df = pd.DataFrame(state["rows"])
    df.to_excel(OUTPUT_EXCEL, index=False)

    wb = load_workbook(OUTPUT_EXCEL)
    ws = wb.active
    headers = [c.value for c in ws[1]]

    if "Accuracy Score" in headers:
        idx = headers.index("Accuracy Score") + 1
        green = PatternFill("solid", fgColor="C6EFCE")
        yellow = PatternFill("solid", fgColor="FFEB9C")
        red = PatternFill("solid", fgColor="F4CCCC")

        for r in range(2, ws.max_row + 1):
            cell = ws.cell(row=r, column=idx)
            try:
                v = int(cell.value)
            except:
                continue
            cell.fill = green if v >= 85 else yellow if v >= 65 else red

    wb.save(OUTPUT_EXCEL)
    return {}

# ==========================================================
# LANGGRAPH
# ==========================================================

graph = StateGraph(AgentState)

graph.add_node("KBRouter", kb_router_agent)
graph.add_node("Sourcing", sourcing_agent)
graph.add_node("Geography", geography_agent)
graph.add_node("System", system_classifier_agent)
graph.add_node("Contract", contract_extractor_agent)
graph.add_node("Evaluation", evaluation_agent)
graph.add_node("ExcelFormatter", excel_formatter_agent)

graph.add_edge(START, "KBRouter")
graph.add_edge("KBRouter", "Sourcing")
graph.add_edge("Sourcing", "Geography")
graph.add_edge("Geography", "System")
graph.add_edge("System", "Contract")
graph.add_edge("Contract", "Evaluation")
graph.add_edge("Evaluation", "ExcelFormatter")
graph.add_edge("ExcelFormatter", END)

app = graph.compile()
display(Image(app.get_graph().draw_mermaid_png()))
# ==========================================================
# MAIN
# ==========================================================

if __name__ == "__main__":

    df = pd.read_excel(INPUT_EXCEL)
    all_rows = []

    for i, r in df.iterrows():
        print(f"Processing row {i + 1}/{len(df)}")

        state = {
            "text": str(r["Contract Description"]),
            "date": str(r["Contract Date"]),
            "url": str(r["Source URL"]),
            "row": {},
            "rows": [],
            "messages": []
        }

        output_state = app.invoke(state)

        if "rows" in output_state and isinstance(output_state["rows"], list):
            all_rows.extend(output_state["rows"])
        elif "row" in output_state and isinstance(output_state["row"], dict):
            all_rows.append(output_state["row"])

    # ==========================
    # ENSURE ALL COLUMNS
    # ==========================
    all_columns = set()
    for row in all_rows:
        all_columns.update(row.keys())

    df_out = pd.DataFrame(all_rows)
    df_out = df_out.reindex(columns=sorted(all_columns), fill_value="")

    # ==========================
    # SAVE EXCEL
    # ==========================
    print(f"\nSaving output to: {OUTPUT_EXCEL}")
    df_out.to_excel(OUTPUT_EXCEL, index=False)

    # ==========================
    # CONDITIONAL FORMATTING
    # ==========================
    wb = load_workbook(OUTPUT_EXCEL)
    ws = wb.active

    headers = [c.value for c in ws[1]]
    if "Accuracy Score" in headers:
        idx = headers.index("Accuracy Score") + 1
        green = PatternFill("solid", fgColor="C6EFCE")
        yellow = PatternFill("solid", fgColor="FFEB9C")
        red = PatternFill("solid", fgColor="F4CCCC")

        for r in range(2, ws.max_row + 1):
            cell = ws.cell(row=r, column=idx)
            try:
                v = int(cell.value)
            except:
                continue
            cell.fill = green if v >= 85 else yellow if v >= 65 else red

    wb.save(OUTPUT_EXCEL)

    print("‚úÖ ALL COLUMNS SAVED + EXCEL GENERATED SUCCESSFULLY")
