In [None]:
# Install dependencies
# %pip (rather than !pip) installs into the notebook's own Python environment.
# NOTE(review): versions come from requirements.txt — pin them there for reproducible runs.
# The next cell restarts Python so the newly installed packages can be imported.
%pip install -r requirements.txt


In [None]:
# Restart Python to ensure clean environment
# Databricks-only utility: restarts the notebook's Python process so that packages
# installed by the %pip cell above are importable. Any variables defined earlier
# are lost, so this must run before every other cell.
dbutils.library.restartPython()


In [None]:
# Initialize Databricks LLM connection
from databricks_langchain import ChatDatabricks
from databricks.sdk import WorkspaceClient
import os

# Workspace client authenticated from the notebook's ambient Databricks context.
w = WorkspaceClient()

# Publish the workspace URL and a freshly minted personal access token through
# the environment for clients that read DATABRICKS_HOST / DATABRICKS_TOKEN.
# lifetime_seconds=1200 -> the token expires 20 minutes after this cell runs;
# re-run the cell if later model-serving calls begin failing authentication.
os.environ["DATABRICKS_HOST"] = w.config.host
os.environ["DATABRICKS_TOKEN"] = w.tokens.create(
    comment="for model serving",
    lifetime_seconds=1200,
).token_value

# Chat model handle on the Llama 4 Maverick serving endpoint.
llm = ChatDatabricks(
    endpoint="databricks-llama-4-maverick",
)


In [None]:
# Import required libraries and define Pydantic models
from pydantic import BaseModel

# Structured-output schema for comms_agent: with_structured_output() asks the LLM
# to return exactly this one string field.
# NOTE(review): documented with comments only — a class docstring or Field
# description would become part of the JSON schema sent to the model.
class CommsAnalysisResponse(BaseModel):
    # Free-text analysis of the pilot message (intent, urgency, key information).
    commsAnalysis: str


In [None]:
# Define ATC Tools with enhanced functionality
from langchain.agents import initialize_agent, Tool
from langchain.tools import tool
from langchain.agents.agent_types import AgentType
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
import sqlite3
import os

def _run_readonly_query(db_path: str, sql: str, error_prefix: str) -> str:
    """Execute one SQL statement against an existing SQLite file, read-only.

    Args:
        db_path: Path (relative to the working directory) of the SQLite file.
        sql: SQL text generated by the tool-selecting LLM — treated as untrusted.
        error_prefix: Text placed before the exception message on failure.

    Returns:
        One ``str(dict)`` per result row (column name -> value) joined by
        newlines, ``"No results."`` when the statement yields no rows or no
        result columns, or ``f"{error_prefix}{exc}"`` if anything fails.
    """
    from contextlib import closing

    try:
        # mode=ro: LLM-written SQL must not be able to modify the data, and a
        # missing/misspelled file errors out instead of being silently created
        # as an empty database.
        # closing(): sqlite3's own context manager only commits/rolls back; this
        # guarantees the connection is released even when execute() raises.
        with closing(sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)) as conn:
            cursor = conn.execute(sql)
            rows = cursor.fetchall()
            description = cursor.description
    except Exception as e:
        return f"{error_prefix}{e}"

    # description is None for statements that produce no result columns.
    if description is None:
        return "No results."
    col_names = [column[0] for column in description]
    return "\n".join(str(dict(zip(col_names, row))) for row in rows) or "No results."


@tool
def query_flight_schedule(sql: str) -> str:
    """Run a SQL query on the flight_status.db file."""
    return _run_readonly_query("flight_status.db", sql, "SQL Error: ")

@tool
def query_geotracking(sql: str) -> str:
    """Run a SQL query on the geo_tracking.db file."""
    return _run_readonly_query("geo_tracking.db", sql, "SQL Error: ")

@tool
def query_weather(sql: str) -> str:
    """Run a SQL query on the weather.db file."""
    return _run_readonly_query("weather.db", sql, "SQL Error in Weather Tool: ")
    
@tool
def comms_agent(message: str) -> CommsAnalysisResponse:
    """
    Analyze pilot communication and return LLM analysis.
    Returns simple JSON: {"commsAnalysis": "analysis text"}
    """
    # NOTE(review): the docstring above is also the @tool description, so it is
    # left unchanged. With a Pydantic schema, with_structured_output() returns a
    # CommsAnalysisResponse instance (attribute access), not a JSON string/dict.
    from databricks_langchain import ChatDatabricks
    
    # Initialize Databricks LLM
    # A new client is created on every call, on a different endpoint
    # (Llama 3.3 70B Instruct) than the Maverick model used elsewhere.
    llm = ChatDatabricks(
        endpoint="databricks-meta-llama-3-3-70b-instruct",
    )
    
    # Simple analysis prompt
    # The raw pilot message is interpolated verbatim into the prompt.
    analysis_prompt = f"""
    You are an Air Traffic Controller analyzing pilot communication. 
    Provide a brief analysis of this message including intent, urgency, and any key information extracted.

    PILOT MESSAGE: "{message}"

    Provide a concise analysis in 1-2 sentences:
    """
    
    # Constrain the reply to the single-field CommsAnalysisResponse schema.
    llm = llm.with_structured_output(CommsAnalysisResponse)
    return llm.invoke(analysis_prompt)


In [None]:
# Create tool list for the agent
def _invoke_with_single_arg(structured_tool, arg_name: str):
    """Adapt a ``@tool``-decorated StructuredTool to the single-input ``Tool`` API.

    ``Tool`` hands its lone input value positionally to ``func``. Forwarding it
    through ``structured_tool.invoke({arg_name: value})`` uses the supported
    Runnable entry point instead of calling the tool object directly (which
    goes through the deprecated ``BaseTool.__call__``).

    Args:
        structured_tool: The decorated tool to call.
        arg_name: Name of that tool's single parameter (e.g. ``"sql"``).

    Returns:
        A one-argument callable returning whatever the wrapped tool returns.
    """
    return lambda value: structured_tool.invoke({arg_name: value})


# Named tools the graph dispatches to by name (see execute_selected_tools).
tools = [
    Tool(
        name="ScheduleTrackerTool",
        func=_invoke_with_single_arg(query_flight_schedule, "sql"),
        description="Use this tool to query scheduled flights and detect conflicts, delays, or tight arrival overlaps. Accepts SQL input."
    ),
    Tool(
        name="GeoTrackerTool",
        func=_invoke_with_single_arg(query_geotracking, "sql"),
        description="Use this tool to query geospatial data about flight phases and deviations from expected routes."
    ),
    Tool(
        name="WeatherTrackerTool",
        func=_invoke_with_single_arg(query_weather, "sql"),
        description="Use this tool to query weather_by_flight table to get wind, visibility, storm/fog info, and help determine flight risk."
    ),
    Tool(
        name="CommsAnalysisTool",
        func=_invoke_with_single_arg(comms_agent, "message"),
        description="Use this tool to analyze pilot communication and provide a brief analysis of intent, urgency, and key information extracted."
    )
]


In [None]:
# Enhanced State Management following LangGraph best practices
from typing import Dict, Any, List, Optional, Annotated
from typing_extensions import TypedDict
from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.graph.message import add_messages
from langgraph.checkpoint.base import BaseCheckpointSaver
import json

class ATCState(TypedDict):
    """
    State management for tool-based ATC workflow - Following LangGraph best practices
    Focus on flight_id for all database operations based on actual schemas

    Graph nodes return partial dicts of these keys. Only ``messages`` declares a
    reducer (``add_messages``, which appends); every other key is replaced by the
    latest value a node returns for it.
    """
    # Input and identification - Using proper add_messages function for LangGraph
    messages: Annotated[List[BaseMessage], add_messages]
    # Despite the name, holds the flight_id parsed from the generated SQL
    # (None in the initial state; "Unknown" when the selector finds none).
    pilot_callsign: Optional[str]  # Actually stores flight_id for database queries
    # Content of the latest message, as read by the tool-selection node.
    pilot_request: str
    
    # Tool outputs
    # [{"tool_name": str, "args": dict}, ...] as chosen by the LLM.
    tool_invocations: List[Dict[str, Any]]
    # Tool name -> tool output, or {"error": msg} when a call failed.
    tool_results: Dict[str, Any]
    
    # Final response
    # Radio reply text produced by the response-generation node.
    atc_response: Optional[str]
    # Set to a constant by the response node; not derived from tool results.
    confidence_score: float
    # Short list (at most three) of recommended follow-up actions.
    next_actions: List[str]


In [None]:
# Enhanced LLM Tool Selection with Flight Number Focus
import re

# Matches flight_id = 'XYZ' in generated SQL, with or without spaces around '='.
_FLIGHT_ID_SQL_PATTERN = re.compile(r"flight_id\s*=\s*'([^']+)'", re.IGNORECASE)


def _parse_tool_invocations(raw_text: str) -> List[Dict[str, Any]]:
    """Extract the JSON list of tool invocations from an LLM reply.

    The model is asked for a bare JSON array, but replies often arrive inside
    Markdown fences or after a sentence of preamble, so decoding is attempted
    from each ``[`` in turn until one parses.

    Args:
        raw_text: The LLM's response content.

    Returns:
        Only entries that are dicts with a string ``tool_name``, so callers can
        index ``inv["tool_name"]`` safely.

    Raises:
        ValueError: if no JSON array can be decoded from ``raw_text``.
    """
    decoder = json.JSONDecoder()
    start = raw_text.find("[")
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(raw_text[start:])
        except json.JSONDecodeError:
            start = raw_text.find("[", start + 1)
            continue
        # Decoding from '[' can only yield a list.
        return [
            inv for inv in parsed
            if isinstance(inv, dict) and isinstance(inv.get("tool_name"), str)
        ]
    raise ValueError(f"No JSON array of tool calls in LLM output: {raw_text[:200]!r}")


def _extract_flight_id(tool_invocations: List[Dict[str, Any]]) -> str:
    """Return the first flight_id literal found in any invocation's SQL, else "Unknown".

    Only used for display/prompting; the LLM already embeds the id in its queries.
    Invocations whose ``args`` is not a dict with a string ``sql`` are ignored.
    """
    for inv in tool_invocations:
        args = inv.get("args")
        sql = args.get("sql") if isinstance(args, dict) else None
        if isinstance(sql, str):
            match = _FLIGHT_ID_SQL_PATTERN.search(sql)
            if match:
                return match.group(1)
    return "Unknown"


def llm_tool_selection(state: ATCState) -> Dict[str, Any]:
    """
    LLM-driven tool selection - Uses Databricks LLM to intelligently select which tools to call
    Focus on flight number/ID extraction and database queries

    Reads the last message in ``state["messages"]``. Asks the LLM for a JSON list
    of ``{"tool_name", "args"}`` invocations and guarantees CommsAnalysisTool is
    among them. On any LLM or parsing failure, falls back to CommsAnalysisTool only.

    Returns:
        Partial state: ``pilot_callsign`` (flight_id from the SQL, or "Unknown"),
        ``pilot_request`` and ``tool_invocations``. With no messages, only
        ``pilot_request`` and an empty ``tool_invocations`` are returned.
    """
    print("🤖 LLM Tool Selector: Analyzing pilot request...")
    
    messages = state.get("messages", [])
    if not messages:
        return {
            "pilot_request": "No request received",
            "tool_invocations": []
        }
    
    # Extract pilot request
    pilot_request = messages[-1].content
    print(f"📨 Pilot Request: {pilot_request}")
    
    tool_selection_prompt = f"""
    You are an Air Traffic Control (ATC) agent. Your task is to analyze a pilot's communication and decide which tools to use to gather information. You have access to several tools that query databases using SQL.

    CRITICAL: All database queries are based on FLIGHT_ID. Extract the flight number from the pilot's request and use it as flight_id in your SQL queries.

    Available Tools and their ACTUAL schemas:
    - ScheduleTrackerTool(sql: str): Query flight schedules from database. Table has columns including gate and other flight schedule information. Use flight_id to query.
    - GeoTrackerTool(sql: str): Query aircraft positions from 'geotracking' table. Primary key is flight_id. Use flight_id to query current position.
    - WeatherTrackerTool(sql: str): Query weather for flights. Schema: flight_id, lat, lon, wind_speed_kt, visibility_km, precip_mm, storm, fog, temperature_c. Use flight_id to query.
    - CommsAnalysisTool(message: str): Analyze pilot message for intent and urgency. Always call this tool.

    Pilot Request: "{pilot_request}"

    Steps to follow:
    1. Extract the flight_id from the pilot's request (e.g., "flight number UN002" → "UN002")
    2. Generate SQL queries using flight_id column
    3. Always include CommsAnalysisTool to analyze the communication
    4. Call relevant tools based on the request type

    Example for "ATC flight number UN002 we seem to me out of the expected trajectory can you suggest the correction ?":
    [
        {{"tool_name": "CommsAnalysisTool", "args": {{"message": "ATC flight number UN002 we seem to me out of the expected trajectory can you suggest the correction ?"}}}},
        {{"tool_name": "GeoTrackerTool", "args": {{"sql": "SELECT * FROM geotracking WHERE flight_id = 'UN002'"}}}},
        {{"tool_name": "ScheduleTrackerTool", "args": {{"sql": "SELECT * FROM flights WHERE flight_id = 'UN002'"}}}},
        {{"tool_name": "WeatherTrackerTool", "args": {{"sql": "SELECT * FROM weather WHERE flight_id = 'UN002'"}}}}
    ]

    Example for "Delta 123 requesting weather conditions":
    [
        {{"tool_name": "CommsAnalysisTool", "args": {{"message": "Delta 123 requesting weather conditions"}}}},
        {{"tool_name": "WeatherTrackerTool", "args": {{"sql": "SELECT flight_id, wind_speed_kt, visibility_km, precip_mm, storm, fog, temperature_c FROM weather WHERE flight_id = 'DL123'"}}}},
        {{"tool_name": "GeoTrackerTool", "args": {{"sql": "SELECT * FROM geotracking WHERE flight_id = 'DL123'"}}}}
    ]

    Example for "American 789 engine failure mayday":
    [
        {{"tool_name": "CommsAnalysisTool", "args": {{"message": "American 789 engine failure mayday"}}}},
        {{"tool_name": "GeoTrackerTool", "args": {{"sql": "SELECT * FROM geotracking WHERE flight_id = 'AA789'"}}}},
        {{"tool_name": "ScheduleTrackerTool", "args": {{"sql": "SELECT * FROM flights WHERE flight_id = 'AA789'"}}}},
        {{"tool_name": "WeatherTrackerTool", "args": {{"sql": "SELECT * FROM weather WHERE flight_id = 'AA789'"}}}}
    ]

    Example for "United 456 requesting IFR clearance":
    [
        {{"tool_name": "CommsAnalysisTool", "args": {{"message": "United 456 requesting IFR clearance"}}}},
        {{"tool_name": "ScheduleTrackerTool", "args": {{"sql": "SELECT * FROM flights WHERE flight_id = 'UA456'"}}}},
        {{"tool_name": "GeoTrackerTool", "args": {{"sql": "SELECT * FROM geotracking WHERE flight_id = 'UA456'"}}}},
        {{"tool_name": "WeatherTrackerTool", "args": {{"sql": "SELECT * FROM weather WHERE flight_id = 'UA456'"}}}}
    ]

    Now, for the given pilot request, extract the flight_id and generate the appropriate tool calls.
    Response (JSON list of tool calls only):
    """
    
    comms_invocation = {"tool_name": "CommsAnalysisTool", "args": {"message": pilot_request}}
    try:
        # Client construction is inside the try so an endpoint/config failure
        # also takes the fallback path instead of aborting the graph.
        llm = ChatDatabricks(endpoint="databricks-llama-4-maverick")
        llm_response = llm.invoke(tool_selection_prompt)
        tool_invocations = _parse_tool_invocations(llm_response.content)
        
        # Always ensure CommsAnalysisTool is included
        if not any(inv["tool_name"] == "CommsAnalysisTool" for inv in tool_invocations):
            tool_invocations.insert(0, comms_invocation)
            
    except Exception as e:
        print(f"⚠️ LLM tool selection failed: {e}, using fallback logic")
        tool_invocations = [comms_invocation]
    
    print(f"🔧 LLM Selected Tool Invocations: {json.dumps(tool_invocations, indent=2)}")
    
    return {
        "pilot_callsign": _extract_flight_id(tool_invocations),  # Actually flight_id now
        "pilot_request": pilot_request,
        "tool_invocations": tool_invocations
    }


In [None]:
# Enhanced Tool Execution with Comprehensive Logging
def execute_selected_tools(state: ATCState) -> Dict[str, Any]:
    """
    Execute the tools selected by the LLM with detailed logging

    Looks up each invocation's ``tool_name`` in the module-level ``tools`` list
    and calls ``Tool.invoke(args)``. Malformed invocations (not a dict, missing
    or non-string name, empty args) are skipped. Unknown tools and tool
    exceptions are recorded as ``{"error": msg}`` rather than raised, so one bad
    call cannot abort the graph.

    Returns:
        ``{"tool_results": {...}}`` keyed by tool name. When the same tool is
        invoked more than once, the first result keeps the plain name and later
        ones are stored as ``"<name> #2"``, ``"<name> #3"``… rather than
        overwriting it.
    """
    print("🔧 Executing LLM-selected tools...")
    
    tool_invocations = state.get("tool_invocations", [])
    
    tool_results: Dict[str, Any] = {}
    if not tool_invocations:
        print("No tools to execute.")
        return {"tool_results": tool_results}

    # Name -> tool index; setdefault keeps the first tool registered under a
    # name, matching a first-match search of the list.
    tools_by_name: Dict[str, Any] = {}
    for candidate in tools:
        tools_by_name.setdefault(candidate.name, candidate)

    print("--- TOOL EXECUTION START ---")
    for invocation in tool_invocations:
        is_dict = isinstance(invocation, dict)
        tool_name = invocation.get("tool_name") if is_dict else None
        args = invocation.get("args") if is_dict else None

        if not isinstance(tool_name, str) or not tool_name or not args:
            print(f"⚠️ Invalid tool invocation: {invocation}")
            continue

        # Keep earlier results when the LLM calls the same tool repeatedly.
        result_key = tool_name
        repeat = 2
        while result_key in tool_results:
            result_key = f"{tool_name} #{repeat}"
            repeat += 1

        tool_to_run = tools_by_name.get(tool_name)
        if tool_to_run is None:
            error_msg = f"⚠️ Tool '{tool_name}' not found."
            print(error_msg)
            tool_results[result_key] = {"error": error_msg}
            continue

        try:
            result = tool_to_run.invoke(args)
        except Exception as e:
            error_msg = f"❌ Error executing {tool_name}: {e}"
            print(error_msg)
            tool_results[result_key] = {"error": error_msg}
            continue

        tool_results[result_key] = result
        print(f"✅ {tool_name} executed successfully.")
        print(f"   - Args: {args}")
        print(f"   - Output: {result}")
    
    print("--- TOOL EXECUTION END ---")
    return {"tool_results": tool_results}


In [None]:
# Enhanced Response Generation with Professional ATC Phraseology
def llm_generate_response(state: ATCState) -> Dict[str, Any]:
    """
    Generate final ATC response using Databricks LLM based on tool results

    Builds a prompt from the pilot request, flight id and rendered tool results.
    Asks the LLM for a radio reply and falls back to ``generate_fallback_response``
    if the call fails or returns nothing.

    Returns:
        Partial state: the reply appended to ``messages`` as an AIMessage,
        ``atc_response``, a constant ``confidence_score`` of 0.95, and
        ``next_actions`` from ``determine_next_actions``.
    """
    print("🤖 Generating LLM-based ATC response...")
    
    # The key is present with value None in the initial state, so .get's default
    # alone would never apply; use "or" so the generic callsign is used instead.
    flight_id = state.get("pilot_callsign") or "Aircraft"  # Actually flight_id
    request = state.get("pilot_request", "")
    tool_results = state.get("tool_results", {})
    
    # Render tool outputs for the prompt. Pydantic results (e.g. the comms
    # analysis) become plain dicts first; default=str keeps json.dumps from
    # raising on values it cannot serialize, since this runs outside the try.
    tool_context = ""
    for tool_name, result in tool_results.items():
        if hasattr(result, "model_dump"):
            result = result.model_dump()
        if isinstance(result, dict):
            rendered = json.dumps(result, indent=2, default=str)
        else:
            rendered = str(result)
        tool_context += f"\n{tool_name.upper()} RESULTS:\n{rendered}\n"
    
    # Create professional ATC response prompt
    response_prompt = f"""
    You are a professional Air Traffic Controller. Based on the pilot's request and the tool results, generate an appropriate ATC response using standard aviation phraseology.

    PILOT REQUEST: "{request}"
    FLIGHT ID: {flight_id}

    TOOL RESULTS:{tool_context}

    INSTRUCTIONS:
    1. Use standard ATC phraseology and terminology
    2. Be clear, concise, and professional
    3. Include relevant information from the tool results from flight database queries
    4. For trajectory corrections: Use GeoTracker current position and ScheduleTracker planned route
    5. For weather: Use WeatherTrackerTool data (wind_speed_kt, visibility_km, precip_mm, storm, fog, temperature_c)
    6. For emergencies: Prioritize safety, provide immediate assistance
    7. For clearances: Include all necessary details (heading, altitude, squawk, frequency)
    8. Address the pilot by their flight ID from the database results
    9. Use actual database results to provide specific, accurate information

    Generate ONLY the ATC radio response (no explanations):
    """
    
    try:
        # Call Databricks LLM for response generation
        llm = ChatDatabricks(endpoint="databricks-llama-4-maverick")
        llm_response = llm.invoke(response_prompt)
        atc_response = llm_response.content.strip()
        
        # Strip one pair of wrapping quotes; the length check stops a lone '"'
        # from being reduced to an empty reply.
        if len(atc_response) >= 2 and atc_response[0] == '"' and atc_response[-1] == '"':
            atc_response = atc_response[1:-1]
        if not atc_response:
            raise ValueError("LLM returned an empty response")
            
    except Exception as e:
        print(f"⚠️ LLM response generation failed: {e}, using fallback")
        atc_response = generate_fallback_response(flight_id, request, tool_results)
    
    # Generate next actions
    next_actions = determine_next_actions(tool_results)
    
    # Add the AI response to messages following LangGraph best practices
    ai_message = AIMessage(content=atc_response)
    
    return {
        "messages": [ai_message],
        "atc_response": atc_response,
        "confidence_score": 0.95,
        "next_actions": next_actions
    }

def generate_fallback_response(flight_id: str, request: str, tool_results: Dict[str, Any]) -> str:
    """Return a canned ATC reply for use when the LLM cannot produce one.

    Keyword groups are tried in priority order (emergency first) against the
    lower-cased request. The first group with any keyword present selects the
    reply, which is prefixed with ``flight_id``. When nothing matches, the pilot
    is invited to go ahead with the request. ``tool_results`` is accepted to
    match the other response helpers but is not consulted.
    """
    lowered = request.lower()
    canned_replies = (
        (("mayday", "emergency"),
         "roger emergency. Squawk 7700. Turn heading 090, descend and maintain 3000 feet. Emergency services are standing by."),
        (("clearance",),
         "cleared as filed. Squawk 1234. Contact departure 121.9."),
        (("weather",),
         "current conditions: Wind 270 at 8 knots, visibility 10 miles, few clouds at 3000."),
        (("trajectory", "correction"),
         "turn left heading 270, maintain current altitude. Advise when established on new heading."),
        (("taxi",),
         "taxi to runway 16R via taxiway Alpha, hold short of runway."),
    )
    reply = next(
        (text for keywords, text in canned_replies
         if any(word in lowered for word in keywords)),
        "SkyLink Navigator. Go ahead with your request.",
    )
    return f"{flight_id}, {reply}"

def determine_next_actions(tool_results: Dict[str, Any]) -> List[str]:
    """Determine recommended next actions based on tool results.

    If the CommsAnalysisTool analysis text mentions "emergency" or "mayday"
    (case-insensitive), emergency actions are listed first. Two general actions
    always follow. The list is then truncated to three entries.

    Args:
        tool_results: Tool name -> output. The ``"CommsAnalysisTool"`` entry may
            be a result object exposing a ``commsAnalysis`` attribute, a dict with
            a ``"commsAnalysis"`` key, an ``{"error": ...}`` dict, or absent.

    Returns:
        At most three action strings.
    """
    actions: List[str] = []

    comms_data = tool_results.get("CommsAnalysisTool")
    # The tool normally yields a CommsAnalysisResponse model (attribute access);
    # error entries and hand-built results are plain dicts. Checking only for
    # dicts meant real analyses were never inspected.
    if isinstance(comms_data, dict):
        comms_analysis = comms_data.get("commsAnalysis")
    else:
        comms_analysis = getattr(comms_data, "commsAnalysis", None)
    analysis_text = str(comms_analysis or "").lower()

    # Check for emergency
    if "emergency" in analysis_text or "mayday" in analysis_text:
        actions.extend([
            "Monitor emergency frequency",
            "Coordinate with emergency services",
            "Clear airspace as needed"
        ])
    
    # General actions
    actions.extend([
        "Monitor pilot readback confirmation",
        "Update flight progress strip"
    ])
    
    return actions[:3]  # Limit to 3 actions


In [None]:
# Enhanced SkyLink Navigator Class with LLM-Driven Workflow
class SkyLinkNavigator:
    """
    Main ATC Agent with Databricks LLM-driven tool selection and response generation

    Wraps a compiled LangGraph pipeline:
    llm_tool_selector -> execute_tools -> llm_response_generator -> END.
    """
    
    def __init__(self):
        self.graph = None
        self.tools = tools
        # NOTE(review): tool_node and llm are not used by the workflow below (the
        # graph nodes call the tools and ChatDatabricks themselves); they are
        # kept for any external code reading these attributes.
        self.tool_node = ToolNode(self.tools)
        
        # Initialize Databricks LLM for intelligent tool selection and response generation
        self.llm = ChatDatabricks(
            endpoint="databricks-llama-4-maverick",
        )
        
        self._build_atc_workflow()
    
    def _build_atc_workflow(self):
        """
        Build LLM-driven ATC workflow: Input → LLM Tool Selection → Tool Execution → LLM Response → End
        """
        workflow = StateGraph(ATCState)

        # Main workflow nodes
        workflow.add_node("llm_tool_selector", llm_tool_selection)
        workflow.add_node("execute_tools", execute_selected_tools)
        workflow.add_node("llm_response_generator", llm_generate_response)
        
        # Set entry point
        workflow.set_entry_point("llm_tool_selector")
        
        # LLM-driven workflow
        workflow.add_edge("llm_tool_selector", "execute_tools")
        workflow.add_edge("execute_tools", "llm_response_generator")
        workflow.add_edge("llm_response_generator", END)
        
        self.graph = workflow.compile()
        print("✅ LLM-Driven ATC Workflow compiled successfully")

    @staticmethod
    def _initial_state(pilot_input: str) -> Dict[str, Any]:
        """Return a fresh ATCState-shaped dict seeded with the pilot's message."""
        return {
            "messages": [HumanMessage(content=pilot_input)],
            "pilot_callsign": None,
            "pilot_request": "",
            "tool_invocations": [],
            "tool_results": {},
            "atc_response": None,
            "confidence_score": 0.0,
            "next_actions": []
        }
    
    async def process_pilot_communication(self, pilot_input: str) -> Dict[str, Any]:
        """
        Main entry point for processing pilot communications

        Runs the compiled graph asynchronously. The returned dict has the same
        keys on success and failure (pilot_input, atc_response, callsign,
        tools_used, confidence, next_actions, tool_results). On failure it also
        has an "error" key, and atc_response is a generic request to repeat.
        """
        print(f"\n🎙️ SkyLink Navigator: Processing '{pilot_input}'")
        
        initial_state = self._initial_state(pilot_input)
        
        try:
            result = await self.graph.ainvoke(initial_state)
        except Exception as e:
            print(f"❌ Error: {e}")
            return {
                "pilot_input": pilot_input,
                "atc_response": "SkyLink Navigator technical difficulties. Please repeat request.",
                "callsign": None,
                "tools_used": [],
                "confidence": 0.0,
                "next_actions": [],
                "tool_results": {},
                "error": str(e)
            }

        # Skip malformed invocation entries rather than letting a KeyError here
        # discard a response the graph already produced.
        tools_used = [
            inv["tool_name"]
            for inv in result.get("tool_invocations", [])
            if isinstance(inv, dict) and "tool_name" in inv
        ]
        return {
            "pilot_input": pilot_input,
            "atc_response": result.get("atc_response"),
            "callsign": result.get("pilot_callsign"),
            "tools_used": tools_used,
            "confidence": result.get("confidence_score", 0.0),
            "next_actions": result.get("next_actions", []),
            "tool_results": result.get("tool_results", {})
        }


In [None]:
# Initialize the Enhanced Navigator
# Constructing the navigator compiles the LangGraph workflow (printing a
# confirmation) and creates a ChatDatabricks client; later cells use this
# module-level `navigator`.
navigator = SkyLinkNavigator()


In [None]:
# Visualize the Enhanced Workflow
from IPython.display import Image, display

# Render the compiled graph as a Mermaid diagram PNG and show it inline.
# NOTE(review): by default LangChain renders Mermaid via a remote web service,
# so this cell presumably needs network access — confirm in restricted workspaces.
workflow_graph = navigator.graph.get_graph()
workflow_png_bytes = workflow_graph.draw_mermaid_png()
display(Image(workflow_png_bytes))


In [None]:
# Test the Enhanced System with Realistic ATC Scenario
test_input = "SkyLink, Delta 123 requesting IFR clearance to Seattle"

# Create proper initial state
initial_state = {
    "messages": [HumanMessage(content=test_input)],
    "pilot_callsign": None,
    "pilot_request": "",
    "tool_invocations": [],
    "tool_results": {},
    "atc_response": None,
    "confidence_score": 0.0,
    "next_actions": []
}

# Invoke the graph with proper state
result = navigator.graph.invoke(initial_state)

# Skip malformed invocation entries (non-dict or missing "tool_name") instead of
# raising KeyError after the graph has already produced a response.
test_tools_used = [
    inv["tool_name"]
    for inv in result.get("tool_invocations", [])
    if isinstance(inv, dict) and "tool_name" in inv
]
print(f"🎙️ Pilot: {test_input}")
print(f"📡 ATC: {result.get('atc_response')}")
print(f"🔧 Tools Used: {test_tools_used}")
print(f"📋 Next Actions: {result.get('next_actions', [])}")


In [None]:
# Enhanced Interactive Testing with Detailed Logging
def stream_graph_updates(user_input: str):
    """Stream per-node graph updates for one pilot message, printing highlights.

    Runs the module-level ``navigator`` graph in streaming mode. For each node
    update it prints the node name and, when present and non-empty, the ATC
    response, the tools selected, and the next actions.

    Args:
        user_input: The pilot's transmission text.
    """
    # Create proper initial state following LangGraph best practices
    initial_state = {
        "messages": [HumanMessage(content=user_input)],
        "pilot_callsign": None,
        "pilot_request": "",
        "tool_invocations": [],
        "tool_results": {},
        "atc_response": None,
        "confidence_score": 0.0,
        "next_actions": []
    }
    
    print(f"🎙️ Pilot: {user_input}")
    
    for event in navigator.graph.stream(initial_state):
        for node_name, value in event.items():
            print(f"📍 Node: {node_name}")
            # An update that is not a state dict has nothing else to report.
            if not isinstance(value, dict):
                continue
            
            # Print ATC response when available
            if value.get("atc_response"):
                print(f"📡 ATC: {value['atc_response']}")
            
            # Print tools used; malformed entries are skipped and names are
            # str()-ed so join() cannot fail on LLM-produced values.
            invocations = value.get("tool_invocations")
            if invocations:
                tools_used = [
                    str(inv["tool_name"])
                    for inv in invocations
                    if isinstance(inv, dict) and "tool_name" in inv
                ]
                print(f"🔧 Tools Used: {', '.join(tools_used)}")
            
            # Print next actions
            if value.get("next_actions"):
                print(f"📋 Next Actions: {', '.join(value['next_actions'])}")
    
    print("-" * 60)

# Interactive testing loop
def run_interactive_test():
    """Run interactive ATC simulation with enhanced logging.

    Reads pilot transmissions from stdin until the user types quit/exit/q
    (case-insensitive, surrounding whitespace ignored), stdin reaches EOF, or
    input is interrupted. Each non-blank line is passed to
    ``stream_graph_updates`` with surrounding whitespace removed.
    """
    print("🛫 SkyLink Navigator Enhanced Interactive Test")
    print("Enter pilot communications (or 'quit' to exit)")
    print("Examples:")
    print("  - 'Delta 123 requesting IFR clearance to Seattle'")
    print("  - 'Mayday Mayday, United 456, engine failure'")
    print("  - 'American 789 requesting weather conditions'")
    print("-" * 60)
    
    while True:
        try:
            user_input = input("Pilot: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupt ends the session instead of raising.
            print("✈️ Goodbye!")
            break

        if user_input.lower() in ("quit", "exit", "q"):
            print("✈️ Goodbye!")
            break
        
        if user_input:
            stream_graph_updates(user_input)

# Uncomment the line below to run interactive testing
# run_interactive_test()


In [None]:
# Comprehensive ATC Scenario Testing
def test_atc_scenarios(scenarios: Optional[List[Dict[str, str]]] = None):
    """Run ATC scenarios through the compiled graph and print each outcome.

    Args:
        scenarios: ``{"name": ..., "input": ...}`` dicts to run, in order.
            Defaults to a built-in set of seven representative requests.

    Each scenario gets a fresh initial state. A failing scenario prints its
    error and the loop continues with the next one.
    """
    if scenarios is None:
        scenarios = [
            {
                "name": "Trajectory Correction Request",
                "input": "ATC flight number UN002 we seem to me out of the expected trajectory can you suggest the correction ?"
            },
            {
                "name": "IFR Clearance Request",
                "input": "SkyLink, Delta 123 requesting IFR clearance to Seattle"
            },
            {
                "name": "Emergency Declaration", 
                "input": "Mayday Mayday, United 456, engine failure, requesting immediate assistance"
            },
            {
                "name": "Weather Request",
                "input": "American 789, requesting current weather conditions"
            },
            {
                "name": "Taxi Clearance",
                "input": "Southwest 321, ready to taxi, requesting clearance to runway"
            },
            {
                "name": "Traffic Advisory",
                "input": "Cessna N123AB, requesting traffic advisory on final approach"
            },
            {
                "name": "Landing Request",
                "input": "JetBlue 567, requesting landing clearance runway 16R"
            }
        ]
    
    print("🛫 Testing Enhanced ATC System with LLM-Driven Tools...")
    print("=" * 70)
    
    for scenario in scenarios:
        print(f"\n📻 Scenario: {scenario['name']}")
        print(f"Input: {scenario['input']}")
        print("-" * 50)
        
        # Create proper state for each test
        initial_state = {
            "messages": [HumanMessage(content=scenario['input'])],
            "pilot_callsign": None,
            "pilot_request": "",
            "tool_invocations": [],
            "tool_results": {},
            "atc_response": None,
            "confidence_score": 0.0,
            "next_actions": []
        }
        
        try:
            result = navigator.graph.invoke(initial_state)

            # Skip malformed invocation entries so a good response is still reported.
            tools_used = [
                str(inv["tool_name"])
                for inv in result.get("tool_invocations", [])
                if isinstance(inv, dict) and "tool_name" in inv
            ]
            
            print(f"📡 ATC Response: {result.get('atc_response')}")
            print(f"🔧 Tools Used: {', '.join(tools_used)}")
            print(f"📊 Confidence: {result.get('confidence_score', 0)*100:.0f}%")
            
            if result.get('next_actions'):
                print("📋 Next Actions:")
                for action in result.get('next_actions', []):
                    print(f"   • {action}")
                    
        except Exception as e:
            print(f"❌ Error: {e}")
        
        print("=" * 70)

# Run comprehensive testing
# NOTE: every scenario runs end-to-end against the live serving endpoints
# (tool selection, tool calls including the comms LLM, and response generation),
# so a full "Run All" makes many LLM calls and can be slow.
test_atc_scenarios()
