In [None]:
import os
import json
import sys
from typing import Dict, Optional, List

from dotenv import load_dotenv

# LangChain core & community
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Azure OpenAI (chat + embeddings)
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings

# Vector store
from langchain.vectorstores import FAISS


# =========================
# Environment & Models
# =========================
# Required env vars (set these before running)
# AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
# AZURE_OPENAI_API_KEY=...
# AZURE_OPENAI_MODEL_NAME=<your-deployment-name>   e.g. "gpt-4o" or "gpt-4o-mini"
# AZURE_EMBEDDING_ENDPOINT=https://<your-resource>.openai.azure.com/
# AZURE_EMBEDDING_API_KEY=...
# AZURE_EMBEDDING_MODEL_NAME=<your-embeddings-deployment> e.g. "text-embedding-3-large"

load_dotenv()

def fail_if_missing(var_name: str):
    """Return the value of environment variable *var_name*, or abort the process.

    A variable that is unset or set to the empty string counts as missing:
    an error line is written to stderr and the interpreter exits with
    status 1 (``SystemExit``).
    """
    value = os.environ.get(var_name)
    if value:
        return value

    # Configuration problems are fatal: report on stderr and stop immediately.
    print(f"[CONFIG ERROR] Missing environment variable: {var_name}", file=sys.stderr)
    sys.exit(1)

# Validate required env vars early and keep the returned values, so every
# client below is configured from exactly the settings that were validated.
_OPENAI_ENDPOINT = fail_if_missing("AZURE_OPENAI_ENDPOINT")
_OPENAI_API_KEY = fail_if_missing("AZURE_OPENAI_API_KEY")
_OPENAI_MODEL_NAME = fail_if_missing("AZURE_OPENAI_MODEL_NAME")
_EMBEDDING_ENDPOINT = fail_if_missing("AZURE_EMBEDDING_ENDPOINT")
_EMBEDDING_API_KEY = fail_if_missing("AZURE_EMBEDDING_API_KEY")
_EMBEDDING_MODEL_NAME = fail_if_missing("AZURE_EMBEDDING_MODEL_NAME")

# Settings shared by both chat clients; only the temperature differs.
_CHAT_MODEL_KWARGS = {
    "azure_endpoint": _OPENAI_ENDPOINT,
    "api_key": _OPENAI_API_KEY,
    "api_version": "2023-07-01-preview",
    "model": _OPENAI_MODEL_NAME,
}

# Chat models
chat_model = AzureChatOpenAI(temperature=0, **_CHAT_MODEL_KWARGS)
chat_model_lowtemp = AzureChatOpenAI(temperature=0.1, **_CHAT_MODEL_KWARGS)

# Embeddings
# The key is passed explicitly: without it the client falls back to the
# AZURE_OPENAI_API_KEY environment variable, which is the wrong credential
# whenever the embedding deployment lives on a different Azure resource.
embeddings = AzureOpenAIEmbeddings(
    model=_EMBEDDING_MODEL_NAME,
    azure_endpoint=_EMBEDDING_ENDPOINT,
    openai_api_key=_EMBEDDING_API_KEY,
    openai_api_version="2024-02-01",
)

# =========================
# Configurable paths per regulator
# =========================
# Maps a lowercase regulator short code to the directory containing its PDFs.
# Keys must be lowercase: detect_regulator() lowercases and strips the code it
# returns, and main() only loads a corpus when that code is a key of this map.
# PDFs are discovered recursively ("**/*.pdf") under each directory.
# Update these paths to match your local folders.
REGULATOR_DIR_MAP: Dict[str, str] = {
    "faa": "./_regulations/FAA_USA",
    "easa": "./_regulations/EASA_EU",
    "dgca": "./_regulations/DGCA_IN",
    # Add more as needed:
    # "tc": "./_regulations/TransportCanada",
    # "caac": "./_regulations/CAAC_China",
}

# Cache of FAISS vectorstores keyed by regulator code, filled by get_retriever()
# so each corpus is embedded at most once per process. Entries are never
# evicted anywhere in this file.
VECTORSTORE_CACHE: Dict[str, FAISS] = {}

# =========================
# 1) Regulator detection (LLM)
# =========================
# System prompt used by detect_regulator(). It asks the model to answer with a
# single JSON object holding exactly the four keys that detect_regulator()
# reads back: "coordinates", "country_name", "aviation_regulator" and
# "regulator_code". The prompt text is runtime behaviour -- edit with care.
DETECT_SYSTEM = SystemMessage(
    content="""You are an aviation assistant.
From the user input:
1) Extract latitude and longitude if present.
2) Identify the country where the coordinates are located (do your best from context).
3) Provide that country's civil aviation regulatory authority (official/most common name).
4) Suggest a lowercase regulator search code (no .pdf), e.g., faa, easa, dgca.
Respond ONLY in JSON:
{
  "coordinates": "lat,lon" or "INVALID",
  "country_name": "Country" or "Unknown",
  "aviation_regulator": "Authority" or "Unknown",
  "regulator_code": "shortcode" or "unknown"
}
"""
)

def detect_regulator(user_text: str) -> Dict[str, str]:
    """Extract coordinates, country, regulator name and regulator code from *user_text*.

    The low-temperature chat model is prompted with ``DETECT_SYSTEM`` and its
    reply is parsed as JSON. Replies wrapped in Markdown code fences or
    surrounded by stray prose are tolerated: the outermost ``{...}`` span is
    parsed when the raw reply is not valid JSON on its own.

    Returns:
        A dict with exactly the keys ``coordinates``, ``country_name``,
        ``aviation_regulator`` and ``regulator_code``. Any field that is
        missing, empty or not a string is replaced by its default
        (``"INVALID"`` / ``"Unknown"`` / ``"Unknown"`` / ``"unknown"``).
        ``regulator_code`` is always lowercase and stripped.

    Detection is best-effort: any failure (model call, parsing, unexpected
    payload shape) is reported on stderr and the all-defaults dict is returned
    instead of raising.
    """
    defaults = {
        "coordinates": "INVALID",
        "country_name": "Unknown",
        "aviation_regulator": "Unknown",
        "regulator_code": "unknown",
    }

    try:
        res = chat_model_lowtemp.invoke([DETECT_SYSTEM, HumanMessage(content=user_text)])
        raw = res.content.strip()
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            # Models frequently wrap JSON in ```json fences or add a sentence
            # around it; retry on the outermost {...} span before giving up.
            start, end = raw.find("{"), raw.rfind("}")
            if start == -1 or end <= start:
                raise
            data = json.loads(raw[start:end + 1])
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    except Exception as exc:
        # Top-level boundary for a best-effort feature: never crash the chat
        # loop, but do not hide the reason either.
        print(f"[Detect] Regulator detection failed: {exc}", file=sys.stderr)
        return dict(defaults)

    result: Dict[str, str] = {}
    for key, default in defaults.items():
        value = data.get(key)
        cleaned = value.strip() if isinstance(value, str) else ""
        result[key] = cleaned or default

    # Codes are matched case-sensitively against lowercase map keys.
    result["regulator_code"] = result["regulator_code"].lower()
    return result

# =========================
# 2) Build / get retriever for a regulator
# =========================
def build_vectorstore_for_regulator(reg_code: str) -> Optional[FAISS]:
    """Create a FAISS index from the PDFs configured for *reg_code*.

    The directory is looked up in ``REGULATOR_DIR_MAP``; every PDF below it
    (recursively) is loaded with ``PyPDFLoader``, split into 1000-character
    chunks with a 100-character overlap, and embedded with the module-level
    ``embeddings`` client.

    Returns:
        The new vectorstore, or ``None`` when the code has no configured
        directory, the directory does not exist, or it yields no documents.
        Each ``None`` case writes an ``[INFO]`` line to stderr.
    """
    pdf_dir = REGULATOR_DIR_MAP.get(reg_code)

    if not pdf_dir:
        print(f"[INFO] No folder configured for regulator '{reg_code}'.", file=sys.stderr)
        return None
    if not os.path.isdir(pdf_dir):
        print(f"[INFO] Regulator folder not found: {pdf_dir}", file=sys.stderr)
        return None

    print(f"[LOAD] Loading PDFs from: {pdf_dir}")
    pages = DirectoryLoader(
        pdf_dir,
        glob="**/*.pdf",
        loader_cls=PyPDFLoader,
        show_progress=True,
    ).load()

    if not pages:
        print(f"[INFO] No PDF documents found in: {pdf_dir}", file=sys.stderr)
        return None

    # Chunk first so each embedded passage stays small enough to be a useful
    # retrieval unit.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(pages)
    return FAISS.from_documents(chunks, embeddings)

def get_retriever(reg_code: str):
    """Return a top-4 retriever for *reg_code*, or ``None`` if no corpus can be built.

    The vectorstore is taken from ``VECTORSTORE_CACHE`` when present;
    otherwise it is built with ``build_vectorstore_for_regulator`` and cached.
    Failed builds are not cached, so a later call tries again.
    """
    store = VECTORSTORE_CACHE.get(reg_code)

    if store is None:
        store = build_vectorstore_for_regulator(reg_code)
        if store is None:
            return None
        VECTORSTORE_CACHE[reg_code] = store

    return store.as_retriever(search_kwargs={"k": 4})

# =========================
# 3) Context retrieval helper
# =========================
def retrieve_context(query: str, retriever, k: int = 4) -> str:
    """Fetch relevant chunks for *query* and join them into one context string.

    Args:
        query: Text to search the regulation corpus for.
        retriever: Object exposing ``invoke(query)`` (current LangChain
            retriever interface) or the legacy ``get_relevant_documents(query)``.
            ``None`` means no corpus is available.
        k: Maximum number of chunks to include in the result.

    Returns:
        The chunks, each prefixed with a ``[Source: ... | Page: ...]`` header
        and separated by blank lines; ``""`` when *retriever* is ``None`` or
        nothing is found.
    """
    if retriever is None:
        return ""

    # Prefer the Runnable-style ``invoke``; ``get_relevant_documents`` is
    # deprecated in LangChain and is kept only as a fallback for older
    # retriever objects.
    fetch = getattr(retriever, "invoke", None) or retriever.get_relevant_documents
    docs = fetch(query) or []

    context_parts: List[str] = [
        f"[Source: {d.metadata.get('source', 'unknown.pdf')} | "
        f"Page: {d.metadata.get('page', 'NA')}]\n{d.page_content}"
        for d in docs[:k]
    ]
    return "\n\n".join(context_parts)

# =========================
# 4) JSON TLOF assistant system prompt (your rules)
# =========================
# System prompt for the TLOF assistant. main() seeds the conversation memory
# with this message, so it applies to every turn. It defines the output
# contract the rest of the script relies on: strictly valid JSON, a full
# "TLOF" configuration for new requests, and a minimal nested patch for
# updates. The prompt text (including the worked examples) is runtime
# behaviour -- edit with care.
TLOF_SYSTEM = SystemMessage(
    content="""You are a civil aviation layout assistant that helps engineers, architects, and designers generate 3D-ready TLOF (Touchdown and Lift-Off Area) configurations from plain English descriptions, producing TLOF configurations in JSON format.
You understand and convert natural language into a structured JSON format for heliport design, including parameters like geometry, markings, lighting, safety areas, and landing markers.
You respect strict validation rules for shape types, colors, marker types, dimensions, and category activations.
Use default values where input is missing or unclear, and validate all parameters with explanations.

IMPORTANT:
- Prefer authoritative context in the provided "Regulation Context" when available (FAA/EASA/DGCA etc.). If a regulation specifies constraints (e.g., TLOF minimums), incorporate them.
- If context conflicts with the user's request, follow the regulation and explain via fields/notes if needed.

Instructions for Handling Updates:
- When the user requests an update to an existing TLOF configuration (e.g., 'update net height to 11'), respond with a minimal JSON object that includes only the updated field(s) within the same nested structure as the original (e.g., {"TLOF": [{"dimensions": {"netHeight": 11}}]}).
- Do not regenerate or include unchanged fields unless explicitly requested.
- Ensure the JSON structure remains compatible with the full TLOF configuration, preserving the TLOF array and dimensions object hierarchy.
- For new TLOF configurations, provide the full JSON structure with all required fields, using defaults for unspecified parameters.

Validation Rules:
- Validate shape types (e.g., 'Rectangle', 'Circle'), colors (e.g., 'white', 'blue'), marker types (e.g., 'V', 'H'), and dimensions.
- Use default values for missing parameters (e.g., transparency: 0, baseHeight: 0, markingColor: 'white').
- Ensure numerical values (e.g., diameter, netHeight) are positive and within realistic bounds.

Output Requirements:
- Output strictly valid JSON and nothing else.
- If you must add brief explanations, include them in a dedicated "notes" field inside the JSON.



**Example Input for New Configuration**:
Generate a rectangular TLOF for a tiltrotor aircraft with 30m x 40m dimensions, elevation 5m, rotation 15 degrees, and 0.6 transparency. Location is [139.6917, 35.6895]. Add a 'V' landing marker in blue, scaled to 8, rotated to 90 degrees.

**Example Output for New Configuration**:
{
  "TLOF": [
    {
      "position": [
        139.6917,
        35.6895
      ],
      "dimensions": {
        "unit": "m",
        "aircraft": "tiltrotor",
        "diameter": 1.0,
        "isVisible": true,
        "layerName": "Praveen_TLOF",
        "shapeType": "Rectangle",
        "scaleCategory": false,
        "textureScaleU": 1,
        "textureScaleV": 1,
        "safetyNetScaleU": 1,
        "safetyNetScaleV": 1,
        "aircraftCategory": false,
        "sides": 3,
        "width": 30,
        "length": 40,
        "height": 0.2,
        "rotation": 15,
        "transparency": 0.6,
        "baseHeight": 5,
        "markingsCategory": true,
        "markingType": "solid",
        "markingColor": "white",
        "markingThickness": 0.5,
        "dashDistance": 1,
        "dashLength": 1,
        "landingMarkerCategory": true,
        "landingMarker": "V",
        "markerScale": 8,
        "markerThickness": 0.02,
        "markerRotation": 90,
        "markerColor": "blue",
        "letterThickness": 0.5,
        "tdpcCategory": false,
        "tdpcType": "circle",
        "tdpcScale": 5,
        "tdpcThickness": 0.5,
        "tdpcRotation": 0,
        "tdpcExtrusion": 0.02,
        "tdpcColor": "white",
        "lightCategory": true,
        "lightColor": "white",
        "lightScale": 1,
        "lightDistance": 1,
        "lightRadius": 0.3,
        "lightHeight": 0.2,
        "safetyAreaCategory": false,
        "safetyAreaType": "multiplier",
        "dValue": 10,
        "multiplier": 1.5,
        "offsetDistance": 3,
        "safetyNetCategory": false,
        "curveAngle": 45,
        "netHeight": 15,
        "safetyNetTransparency": 0.5
      }
    }
  ]
}

**Example Input for Update**:
update net height to 11

**Example Output for Update**:
{
  "TLOF": [
    {
      "dimensions": {
        "netHeight": 11
      }
    }
  ]
}

"""
   
)

# =========================
# 5) JSON repair helper (ensures valid JSON if the model returns stray text)
# =========================
def _is_valid_json(candidate) -> bool:
    """Return True when *candidate* parses with ``json.loads``."""
    try:
        json.loads(candidate)
    except (TypeError, ValueError):
        return False
    return True


def _first_valid_json(text: str) -> Optional[str]:
    """Return the first trimmed variant of *text* that parses as JSON, else ``None``.

    Variants are tried in order: the whitespace-stripped text, the text with
    Markdown code fences (and an optional language tag line) removed, and the
    outermost ``{...}`` / ``[...]`` spans, earliest opening bracket first.
    """
    stripped = text.strip()
    candidates = [stripped]

    if stripped.startswith("```"):
        unfenced = stripped.strip("`")
        first_line, newline, rest = unfenced.partition("\n")
        # A first line that does not open a JSON value is a language tag ("json").
        if newline and not first_line.strip().startswith(("{", "[")):
            unfenced = rest
        candidates.append(unfenced.strip())

    spans = []
    for opener, closer in (("{", "}"), ("[", "]")):
        start, end = stripped.find(opener), stripped.rfind(closer)
        if 0 <= start < end:
            spans.append((start, stripped[start:end + 1]))
    candidates.extend(span for _, span in sorted(spans))

    for candidate in candidates:
        if _is_valid_json(candidate):
            return candidate
    return None


def ensure_valid_json(text: str) -> str:
    """Return *text* as a valid JSON string, repairing it if necessary.

    Repair is attempted locally first (whitespace, code fences, prose around
    the JSON value). Only if that fails is the chat model asked to reformat
    the content, and its reply goes through the same local repair.

    Returns:
        *text* unchanged when it already parses; otherwise the repaired JSON text.

    Raises:
        json.JSONDecodeError: if even the model's reformatted reply cannot be
            parsed. Errors raised by the model call itself propagate unchanged.
    """
    if _is_valid_json(text):
        return text

    repaired = _first_valid_json(text)
    if repaired is not None:
        return repaired

    # Last resort: ask the chat model to reformat as valid JSON only.
    fix_msgs = [
        SystemMessage(content="Return the following content as strictly valid JSON. Do not add any text outside JSON."),
        HumanMessage(content=text)
    ]
    fixed = chat_model.invoke(fix_msgs).content

    repaired = _first_valid_json(fixed)
    if repaired is None:
        # Re-parse so the caller receives the underlying JSONDecodeError.
        json.loads(fixed)
        return fixed
    return repaired

# =========================
# 6) Interactive chat with memory + RAG
# =========================
def _compose_user_message(user_input: str, regulation_context: str) -> str:
    """Build the human-message text: the user's input followed by the regulation context section."""
    if regulation_context:
        return (
            "User Input:\n"
            f"{user_input}\n"
            "\n"
            "Regulation Context (authoritative excerpts):\n"
            f"{regulation_context}\n"
        )
    return (
        "User Input:\n"
        f"{user_input}\n"
        "\n"
        "Regulation Context:\n"
        "(None found. Use defaults and general rules, but prefer known standards "
        "if recalled. Do not hallucinate exact clause numbers.)\n"
    )


def main():
    """Run the interactive TLOF assistant: regulator detection, RAG and chat memory.

    Each turn:
      1. detects the regulator from the user's text and, when the code is a
         key of ``REGULATOR_DIR_MAP`` and differs from the active one, switches
         to that regulator's retriever;
      2. retrieves regulation context for the request (best-effort);
      3. sends the accumulated conversation to the chat model;
      4. coerces the reply to valid JSON, stores it in memory and prints it.

    The loop ends on ``exit`` / ``quit`` / ``q``, on end of input, or on
    Ctrl-C at the prompt. Blank lines are ignored. A failed retrieval degrades
    to "no context"; a failed model call is reported on stderr and the turn is
    discarded so the conversation memory stays consistent.
    """
    print("🚁 Welcome! I'm your TLOF assistant. Type 'exit' to quit.")
    print("Tip: include coordinates and desired dimensions (e.g., 'TLOF near 33.9428,-118.4108, need 30m x 40m').\n")

    # Memory of the conversation (LLM-style), seeded with the system prompt.
    memory_messages: List = [TLOF_SYSTEM]

    # Active regulator + retriever; both persist across turns until a
    # different configured regulator is detected.
    active_reg_code: Optional[str] = None
    retriever = None

    while True:
        try:
            user_input = input("User: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Leave quietly instead of dumping a traceback at the prompt.
            print()
            break

        if user_input.lower() in ("exit", "quit", "q"):
            break
        if not user_input:
            # Nothing to do; avoid spending two model calls on an empty line.
            continue

        # 1) Detect regulator for this message.
        det = detect_regulator(user_input)
        reg_code = det.get("regulator_code", "unknown")
        if not isinstance(reg_code, str):
            # Guard the dict-membership tests below against odd model payloads.
            reg_code = "unknown"
        country = det.get("country_name", "Unknown")
        regulator = det.get("aviation_regulator", "Unknown")

        print(f"\n[Detect] Country: {country} | Regulator: {regulator} | Code: {reg_code} | Coordinates: {det.get('coordinates')}\n")

        # 2) Switch retriever when a different, configured regulator is detected.
        if reg_code in REGULATOR_DIR_MAP:
            if reg_code != active_reg_code:
                retriever = get_retriever(reg_code)
                # Recorded even when loading failed, so the same missing
                # corpus is not rebuilt on every following turn.
                active_reg_code = reg_code
                if retriever is not None:
                    print(f"[RAG] Loaded regulation corpus for '{active_reg_code}' from {REGULATOR_DIR_MAP[active_reg_code]}\n")
                else:
                    print(f"[RAG] Corpus for '{reg_code}' could not be loaded. Proceeding without retrieval context.\n")
        elif retriever is not None:
            # Typical for follow-ups ("update net height to 11") that carry no location.
            print(f"[RAG] No corpus configured for code '{reg_code}'; continuing with the active '{active_reg_code}' corpus.\n")
        else:
            print("[RAG] No specific corpus configured for this regulator code. Proceeding without retrieval context.\n")

        # 3) Retrieve regulation context for the current request (best-effort).
        try:
            regulation_context = retrieve_context(user_input, retriever)
        except Exception as exc:
            print(f"[RAG] Retrieval failed ({exc}). Proceeding without retrieval context.\n", file=sys.stderr)
            regulation_context = ""

        # 4) Append the user turn (with injected context) to memory.
        memory_messages.append(
            HumanMessage(content=_compose_user_message(user_input, regulation_context))
        )

        # 5) Invoke the model.
        try:
            raw_text = chat_model.invoke(memory_messages).content
        except Exception as exc:
            # Drop the unanswered turn so memory never holds two consecutive
            # human messages, then let the user retry.
            memory_messages.pop()
            print(f"[ERROR] Model call failed: {exc}\n", file=sys.stderr)
            continue

        # 6) Ensure valid JSON output only.
        try:
            json_text = ensure_valid_json(raw_text)
        except Exception as e:
            # If we still fail, fall back to a minimal error JSON.
            json_text = json.dumps({
                "error": "Failed to produce valid JSON",
                "details": str(e)[:300]
            })

        # Keep assistant message in memory (as displayed).
        memory_messages.append(AIMessage(content=json_text))

        # 7) Display final JSON.
        print("\nAssistant (JSON):\n", json_text, "\n")


if __name__ == "__main__":
    main()


🚁 Welcome! I'm your TLOF assistant. Type 'exit' to quit.
Tip: include coordinates and desired dimensions (e.g., 'TLOF near 33.9428,-118.4108, need 30m x 40m').



User:  hi



[Detect] Country: Unknown | Regulator: Unknown | Code: unknown | Coordinates: INVALID

[RAG] No specific corpus configured for this regulator code. Proceeding without retrieval context.


Assistant (JSON):
 {
  "notes": "No actionable request provided. Please specify details for a TLOF configuration or update."
} 



KeyboardInterrupt: Interrupted by user