<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/MCP_LLM_AGENT_FP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers
!pip install -q unsloth

In [2]:
import json
import re # Import regex
from datetime import datetime, timedelta

# --- Import Unsloth and related libraries ---
import torch
from unsloth import FastLanguageModel
from transformers import AutoTokenizer

# Define the model loading parameters
# These three values are forwarded to FastLanguageModel.from_pretrained() in
# DeepSeekLLM.__init__ as max_seq_length, dtype and load_in_4bit respectively.
MAX_SEQ_LENGTH = 2048
DTYPE = None  # NOTE(review): None presumably lets Unsloth pick the dtype itself - confirm
LOAD_IN_4BIT = True

# --- 1. Model Context Protocol (MCP) - Simulated (Unchanged) ---
class MCPConnector:
    """
    Simulates the Model Context Protocol (MCP) for interacting with external services.

    Every simulated tool returns a dict with a "status" key. Successful calls
    carry their payload under "data"; failed calls carry a human-readable
    "message". Tools are reached through :meth:`call_tool` by name.
    """

    # Canned weather reports keyed by upper-case airport code.
    _WEATHER_REPORTS = {
        "CYUL": "CYUL 250600Z 18005KT 9999 FEW030 15/10 Q1015 NOSIG",  # Montreal
        "KLAX": "KLAX 250600Z 27010KT 10SM CLR 20/12 A2992 RMK AO2 SLP135 T02000120",  # Los Angeles
        "KJFK": "KJFK 250600Z 24008KT 10SM BKN050 18/14 A2998",  # JFK
        "KSAN": "KSAN 250600Z 27005KT 10SM CLR 21/15 A2990",  # San Diego
        "KORD": "KORD 250600Z 22010KT 10SM FEW040 22/16 A2987",  # Chicago O'Hare
    }

    # Canned NOTAM text keyed by upper-case airport code.
    _NOTAMS = {
        "CYUL": "CYUL NOTAM: RWY 06R/24L CLOSED UNTIL 260000Z. Caution birds.",
    }

    # Fuel burn per nautical mile keyed by lower-case aircraft type.
    _FUEL_PER_NM = {
        "boeing787": 0.35,
        "a320": 0.4,
        "cessna172": 0.1,
    }
    _DEFAULT_FUEL_PER_NM = 0.5
    _CROSSWIND_FUEL_FACTOR = 1.2

    # Pre-computed routes keyed by (departure, destination) in upper case.
    # Waypoints are stored as tuples so the shared table cannot be mutated.
    _KNOWN_ROUTES = {
        ("CYUL", "KLAX"): {
            "route_waypoints": ("CYUL", "ORD", "DEN", "LAS", "KLAX"),
            "estimated_distance_nm": 2100,
            "estimated_flight_time_hrs": 4.5,
        },
        ("KJFK", "KSAN"): {
            "route_waypoints": ("KJFK", "DFW", "PHX", "KSAN"),
            "estimated_distance_nm": 2400,
            "estimated_flight_time_hrs": 5.0,
        },
    }

    def __init__(self):
        # Maps the public tool name to the bound method that simulates it.
        self.available_tools = {
            "get_current_weather": self._simulated_weather_api,
            "get_notams_for_airport": self._simulated_notam_api,
            "calculate_fuel_burn": self._simulated_fuel_calculator,
            "get_airspace_restrictions": self._simulated_airspace_api,
            "optimize_route": self._simulated_route_optimizer
        }

    def _simulated_weather_api(self, airport_code: str):
        """Simulates fetching weather for an airport.

        Returns a success dict with the canned report, or an error dict when
        the airport is not in the simulated data set.
        """
        print(f"MCP: Calling simulated weather API for {airport_code}...")
        report = self._WEATHER_REPORTS.get(airport_code.upper())
        if report is None:
            return {"status": "error", "message": "Weather data not found for this airport."}
        return {"status": "success", "data": report}

    def _simulated_notam_api(self, airport_code: str):
        """Simulates fetching NOTAMs (Notices to Airmen) for an airport.

        Always succeeds; airports without canned NOTAMs get a
        "No significant NOTAMs" message that echoes the code as given.
        """
        print(f"MCP: Calling simulated NOTAM API for {airport_code}...")
        notam = self._NOTAMS.get(airport_code.upper())
        if notam is None:
            notam = f"No significant NOTAMs for {airport_code}."
        return {"status": "success", "data": notam}

    def _simulated_fuel_calculator(self, aircraft_type: str, distance_nm: float, conditions: str):
        """Simulates calculating fuel burn based on aircraft, distance, and conditions.

        Unknown aircraft types use a default burn rate. The rate is raised by
        20% when ``conditions`` mentions "heavy crosswind" (any letter case).
        """
        print(f"MCP: Calling simulated fuel calculator for {aircraft_type}, {distance_nm}nm...")
        fuel_per_nm = self._FUEL_PER_NM.get(aircraft_type.lower(), self._DEFAULT_FUEL_PER_NM)

        # Case-insensitive so the caller's capitalisation cannot hide the condition.
        if "heavy crosswind" in conditions.lower():
            fuel_per_nm *= self._CROSSWIND_FUEL_FACTOR
        estimated_fuel_kg = distance_nm * fuel_per_nm
        return {"status": "success", "data": {"estimated_fuel_kg": round(estimated_fuel_kg, 2)}}

    def _simulated_airspace_api(self, departure: str, destination: str, flight_time_utc: datetime):
        """Simulates checking airspace restrictions along a route.

        Reports a TFR when the destination string contains "restricted";
        otherwise reports no major restrictions.
        """
        print(f"MCP: Calling simulated Airspace API for {departure} to {destination} at {flight_time_utc}...")
        if "restricted" in destination.lower():
            return {"status": "success", "data": {"restrictions": "Temporary Flight Restriction (TFR) near destination."}}
        return {"status": "success", "data": {"restrictions": "No major restrictions."}}

    def _simulated_route_optimizer(self, departure: str, destination: str, current_conditions: dict):
        """Simulates a complex route optimization engine.

        Known city pairs return a canned multi-waypoint route; any other pair
        returns a direct route with generic distance and time estimates.
        ``current_conditions`` is accepted for interface parity but not used
        by the simulation.
        """
        print(f"MCP: Calling simulated Route Optimizer for {departure} to {destination}...")
        # Normalise case so lookups behave like the weather and NOTAM tools.
        known_route = self._KNOWN_ROUTES.get((departure.upper(), destination.upper()))
        if known_route is not None:
            # Hand out a fresh list so callers cannot mutate the shared table.
            route = {**known_route, "route_waypoints": list(known_route["route_waypoints"])}
            return {"status": "success", "data": route}
        return {"status": "success", "data": {"route_waypoints": [departure, destination], "estimated_distance_nm": 1000, "estimated_flight_time_hrs": 2.0}}

    def call_tool(self, tool_name: str, **kwargs):
        """Dispatches calls to simulated tools based on the MCP.

        Args:
            tool_name: Key into ``self.available_tools``.
            **kwargs: Forwarded unchanged to the selected tool.

        Returns:
            The tool's result dict, or an error dict when the tool is unknown.
        """
        tool = self.available_tools.get(tool_name)
        if tool is None:
            return {"status": "error", "message": f"Tool '{tool_name}' not available via MCP."}
        return tool(**kwargs)

# --- 2. Large Language Model (LLM) - Concrete Implementation with Unsloth ---

class DeepSeekLLM:
    """
    Implements the LLM's capabilities using the unsloth/DeepSeek-R1-Distill-Llama-8B model.

    The class has two jobs: turning a free-text flight request into a small
    structured dict (:meth:`process_query`) and turning a completed flight plan
    into a natural-language briefing (:meth:`generate_briefing`). The parsing
    logic lives in static/class helpers so it does not depend on the loaded model.
    """

    # An airport code as accepted from the model's JSON: 3-4 upper-case letters/digits.
    _AIRPORT_CODE_RE = re.compile(r"[A-Z0-9]{3,4}")
    # "from XXXX to YYYY" in the user's own words. Anchoring on "from ... to"
    # keeps ordinary words ("Plan", "know") from being mistaken for codes.
    _QUERY_ROUTE_RE = re.compile(
        r"\bfrom\s+([A-Z0-9]{3,4})\s+to\s+([A-Z0-9]{3,4})\b", re.IGNORECASE
    )
    # "using a Boeing787" / "with an A320" in the user's own words.
    _QUERY_AIRCRAFT_RE = re.compile(
        r"\b(?:using|with)\s+an?\s+([a-zA-Z0-9]+)", re.IGNORECASE
    )

    # The annotation is quoted so defining this class does not require
    # MCPConnector to be resolvable at definition time.
    def __init__(self, mcp_connector: "MCPConnector"):
        """Load the model and tokenizer and prepare an empty flight plan template.

        Args:
            mcp_connector: Connector kept on the instance as ``self.mcp``.
        """
        self.mcp = mcp_connector

        print("\nLLM: Loading unsloth/DeepSeek-R1-Distill-Llama-8B...")
        self.llm_model, self.llm_tokenizer = FastLanguageModel.from_pretrained(
            model_name="unsloth/DeepSeek-R1-Distill-Llama-8B",
            max_seq_length=MAX_SEQ_LENGTH,
            dtype=DTYPE,
            load_in_4bit=LOAD_IN_4BIT,
            # token=TOKEN
        )
        print("LLM loaded successfully!")

        self.flight_plan_template = {
            "departure_airport": "",
            "destination_airport": "",
            "aircraft_type": "Boeing787", # Default aircraft to Boeing787
            "departure_time_utc": "",
            "estimated_flight_time": "",
            "route_waypoints": [],
            "estimated_fuel_burn_kg": "",
            "weather_briefing": {},
            "notam_briefing": {},
            "airspace_briefing": {},
            "notes": []
        }

    def _generate_text(self, prompt: str, max_new_tokens: int = 100, temperature: float = 0.7) -> str:
        """Helper to generate text from the loaded LLM.

        Args:
            prompt: Text fed to the model.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature (sampling is always enabled).

        Returns:
            Only the newly generated text, stripped of surrounding whitespace.
        """
        inputs = self.llm_tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
        outputs = self.llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=self.llm_tokenizer.eos_token_id
        )
        # Slice off the prompt tokens so only the continuation is decoded.
        generated_text = self.llm_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return generated_text.strip()

    @staticmethod
    def _extract_json_object(text: str):
        """Return the first JSON object embedded in ``text``, or None.

        Each "{" is tried as the start of a JSON document. Unlike a greedy
        ``{.*}`` regex, this copes with output that repeats the JSON or
        keeps talking (and opening braces) after it.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text[start:])
            except ValueError:  # json.JSONDecodeError is a ValueError
                candidate = None
            if isinstance(candidate, dict):
                return candidate
            start = text.find("{", start + 1)
        return None

    @classmethod
    def _valid_airport_code(cls, value):
        """Return ``value`` stripped if it looks like an airport code, else None.

        Only upper-case codes are accepted so placeholders the model may emit
        ("Unknown", "None", "N/A") are rejected rather than planned as airports.
        """
        if not isinstance(value, str):
            return None
        code = value.strip()
        return code if cls._AIRPORT_CODE_RE.fullmatch(code) else None

    @classmethod
    def _parse_flight_request(cls, query: str, raw_llm_output: str) -> dict:
        """Combine the model's output and the original query into a parsed request.

        Sources are consulted in this order:
          1. a JSON object in ``raw_llm_output`` with valid airport codes;
          2. an explicit "from XXXX to YYYY" phrase in ``query``;
          3. loose ``"departure_airport": "XXXX"`` fragments in malformed output.

        Returns:
            ``{"action": "plan_flight", "departure_airport", "destination_airport",
            "aircraft_type"}`` when both airports are known, otherwise
            ``{"action": "clarify", "message": ...}``.
        """
        llm_fields = cls._extract_json_object(raw_llm_output)
        if llm_fields is None:
            print("No valid JSON object found in LLM output. Attempting basic keyword extraction as fallback.")
            llm_fields = {}
        else:
            print(f"Successfully extracted JSON: {llm_fields}")

        departure = cls._valid_airport_code(llm_fields.get("departure_airport"))
        destination = cls._valid_airport_code(llm_fields.get("destination_airport"))

        if not (departure and destination):
            route_match = cls._QUERY_ROUTE_RE.search(query)
            if route_match:
                departure = route_match.group(1).upper()
                destination = route_match.group(2).upper()
            else:
                # The JSON may be malformed yet still contain the fields.
                dep_match = re.search(r'departure_airport":\s*"([A-Z0-9]+)"', raw_llm_output)
                dest_match = re.search(r'destination_airport":\s*"([A-Z0-9]+)"', raw_llm_output)
                if dep_match:
                    departure = dep_match.group(1)
                if dest_match:
                    destination = dest_match.group(1)

        aircraft = llm_fields.get("aircraft_type")
        if isinstance(aircraft, str) and aircraft.strip():
            aircraft = aircraft.strip()
        else:
            aircraft_match = (
                re.search(r'aircraft_type":\s*"([^"]+)"', raw_llm_output)
                or cls._QUERY_AIRCRAFT_RE.search(query)
            )
            # If aircraft is not found by LLM or query, default to Boeing787
            aircraft = aircraft_match.group(1) if aircraft_match else "Boeing787"

        if departure and destination:
            return {
                "action": "plan_flight",
                "departure_airport": departure,
                "destination_airport": destination,
                "aircraft_type": aircraft,
            }
        return {"action": "clarify", "message": "Missing key flight details (departure/destination)."}

    def process_query(self, query: str):
        """
        Uses the LLM to understand the user query and identify flight parameters.

        Args:
            query: The user's free-text flight planning request.

        Returns:
            A dict whose "action" is "plan_flight" (with "departure_airport",
            "destination_airport" and "aircraft_type") or "clarify" (with a
            "message"). Any unexpected failure is reported as "clarify" rather
            than raised, so the caller always gets a usable dict.
        """
        print(f"\nLLM: Processing query using DeepSeek-R1-Distill-Llama-8B: '{query}'")

        # ICAO location indicators are four characters long, and the keys are
        # named explicitly because the parser looks for exactly these.
        prompt = f"""You are an AI assistant specialized in parsing flight planning requests.
        Identify the departure airport (4-letter ICAO code), destination airport (4-letter ICAO code), and aircraft type from the following user query.
        If the aircraft type is not specified, assume it is a Boeing787.
        If any other information is missing, state what is missing.

        Query: "{query}"

        Respond with a single JSON object whose keys are "departure_airport", "destination_airport" and "aircraft_type".
        **IMPORTANT:** Your response MUST contain ONLY the JSON output, and nothing else.
        """

        try:
            # Generate text with low temperature to encourage more direct output
            raw_llm_output = self._generate_text(prompt, max_new_tokens=150, temperature=0.01)
            print(f"LLM Raw Output (attempting JSON extraction): {raw_llm_output}")

            parsed_output = self._parse_flight_request(query, raw_llm_output)

            # Final check and update the flight_plan_template
            if parsed_output.get("action") == "plan_flight":
                self.flight_plan_template["departure_airport"] = parsed_output.get("departure_airport", "").upper()
                self.flight_plan_template["destination_airport"] = parsed_output.get("destination_airport", "").upper()
                self.flight_plan_template["aircraft_type"] = parsed_output.get("aircraft_type", "Boeing787")

            return parsed_output

        except Exception as e:
            # Deliberate top-level boundary: a model failure must not crash the agent.
            print(f"LLM: An unexpected error occurred during LLM processing: {e}")
            return {"action": "clarify", "message": "An internal error occurred during LLM processing."}

    @staticmethod
    def _build_briefing_prompt(briefing_data: dict) -> str:
        """Build the briefing prompt, naming the aircraft actually in the plan."""
        aircraft = briefing_data.get("aircraft_type") or "Boeing787"
        return f"""Based on the following flight plan data, generate a concise and professional flight briefing for a pilot.
        Highlight key information like route, estimated times, fuel, and any important weather or NOTAM alerts.
        The aircraft for this flight is a {aircraft}.

        Flight Plan Data:
        {json.dumps(briefing_data, indent=2)}

        Flight Briefing:
        """

    def generate_briefing(self, briefing_data: dict) -> str:
        """
        Uses the LLM to synthesize information into a natural language briefing.

        Args:
            briefing_data: JSON-serialisable flight plan dict.

        Returns:
            The generated briefing text.
        """
        print("\nLLM: Generating final briefing using DeepSeek-R1-Distill-Llama-8B...")
        prompt = self._build_briefing_prompt(briefing_data)
        briefing_text = self._generate_text(prompt, max_new_tokens=500, temperature=0.7)

        return briefing_text

# --- 3. Agentic AI - The Orchestrator (Unchanged) ---

class FlightPlanningAgent:
    """
    An Agentic AI designed for flight planning. It uses an LLM for reasoning
    and an MCPConnector for interacting with external tools.

    Attributes:
        llm: Object providing ``process_query`` and ``generate_briefing``.
        mcp: Object providing ``call_tool``.
        flight_plan: The plan currently being built (reset on every request).
        current_conditions: Scratch data shared between planning steps.
    """
    # Annotations are quoted so the class can be defined without the sibling
    # classes being resolvable at definition time.
    def __init__(self, llm: "DeepSeekLLM", mcp: "MCPConnector"):
        self.llm = llm
        self.mcp = mcp
        self.flight_plan = {}
        self.current_conditions = {}

    @staticmethod
    def _runway_notam_notes(notam_briefing: dict) -> list:
        """Return one note per airport whose NOTAM text mentions a runway ("RWY").

        Args:
            notam_briefing: Mapping of airport code to NOTAM text (or None).
        """
        return [
            f"Check NOTAMs for {airport}: {text}"
            for airport, text in notam_briefing.items()
            if isinstance(text, str) and "RWY" in text
        ]

    def initiate_flight_plan(self, query: str):
        """Starts the flight planning process based on a user query.

        The query is parsed by the LLM. When it yields a complete
        "plan_flight" response the planning steps run; otherwise the
        clarification message is printed and nothing else happens.
        """
        llm_response = self.llm.process_query(query)

        if llm_response.get("action") != "plan_flight":
            print(f"Agent: LLM needs clarification: {llm_response.get('message', 'No details provided.')}")
            return

        departure = llm_response.get("departure_airport")
        destination = llm_response.get("destination_airport")
        if not departure or not destination:
            # A "plan_flight" answer without airports cannot be planned.
            print("Agent: LLM needs clarification: Missing key flight details (departure/destination).")
            return

        self.flight_plan = {
            "departure_airport": departure,
            "destination_airport": destination,
            "aircraft_type": llm_response.get("aircraft_type", "Boeing787"),
            "departure_time_utc": "",
            "estimated_flight_time": "",
            "route_waypoints": [],
            "estimated_fuel_burn_kg": "",
            "weather_briefing": {},
            "notam_briefing": {},
            "airspace_briefing": {},
            "notes": []
        }
        print(f"Agent: Initializing plan for {self.flight_plan['departure_airport']} to {self.flight_plan['destination_airport']} with {self.flight_plan['aircraft_type']}.")
        self._execute_planning_steps()

    def _execute_planning_steps(self):
        """
        Orchestrates the sequence of actions to generate a flight plan.

        Steps: weather, NOTAMs, route optimization, fuel burn, airspace
        restrictions, then the LLM-generated briefing (printed). A failing
        step adds a note to the plan instead of aborting the whole run.
        """
        # datetime.utcnow() is deprecated; an aware timestamp formats identically below.
        from datetime import timezone

        print("\nAgent: Executing planning steps...")
        departure = self.flight_plan["departure_airport"]
        destination = self.flight_plan["destination_airport"]
        aircraft = self.flight_plan["aircraft_type"]
        notes = self.flight_plan["notes"]
        current_utc_time = datetime.now(timezone.utc)
        self.flight_plan["departure_time_utc"] = current_utc_time.strftime("%Y-%m-%d %H:%M UTC")

        # Step 1: Get Weather Briefing
        print("Agent: Fetching weather data...")
        weather_dep_res = self.mcp.call_tool("get_current_weather", airport_code=departure)
        weather_dest_res = self.mcp.call_tool("get_current_weather", airport_code=destination)

        self.flight_plan["weather_briefing"] = {
            departure: weather_dep_res.get('data'),
            destination: weather_dest_res.get('data')
        }
        # Surface failed lookups instead of silently briefing with "None".
        for airport, weather_res in ((departure, weather_dep_res), (destination, weather_dest_res)):
            if weather_res.get("status") != "success":
                weather_note = f"Weather data unavailable for {airport}. Obtain a weather briefing manually."
                if weather_note not in notes:
                    notes.append(weather_note)

        self.current_conditions['weather'] = f"Departure: {weather_dep_res.get('data')}, Destination: {weather_dest_res.get('data')}"
        dep_weather_text = (weather_dep_res.get('data') or '').lower()
        dest_weather_text = (weather_dest_res.get('data') or '').lower()
        if "crosswind" in dep_weather_text or "crosswind" in dest_weather_text:
            self.current_conditions['weather'] += " (potential heavy crosswind)"

        # Step 2: Get NOTAMs
        print("Agent: Fetching NOTAMs...")
        notam_dep_res = self.mcp.call_tool("get_notams_for_airport", airport_code=departure)
        notam_dest_res = self.mcp.call_tool("get_notams_for_airport", airport_code=destination)

        self.flight_plan["notam_briefing"] = {
            departure: notam_dep_res.get('data'),
            destination: notam_dest_res.get('data')
        }
        # Runway NOTAMs matter at both ends of the flight, not just departure.
        notes.extend(self._runway_notam_notes(self.flight_plan["notam_briefing"]))

        # Step 3: Optimize Route
        print("Agent: Optimizing route...")
        route_res = self.mcp.call_tool("optimize_route", departure=departure, destination=destination, current_conditions=self.current_conditions)
        if route_res.get("status") == "success":
            self.flight_plan["route_waypoints"] = route_res["data"]["route_waypoints"]
            self.flight_plan["estimated_flight_time"] = f"{route_res['data']['estimated_flight_time_hrs']} hours"
            estimated_distance = route_res["data"]["estimated_distance_nm"]
            print(f"Agent: Route optimized: {self.flight_plan['route_waypoints']}")
        else:
            print(f"Agent: Route optimization failed: {route_res.get('message')}")
            notes.append("Route optimization failed. Manual routing required.")
            estimated_distance = 1000 # Fallback for fuel calculation

        # Step 4: Calculate Fuel Burn
        print("Agent: Calculating fuel burn...")
        fuel_res = self.mcp.call_tool("calculate_fuel_burn",
                                       aircraft_type=aircraft,
                                       distance_nm=estimated_distance,
                                       conditions=self.current_conditions.get('weather', ''))
        if fuel_res.get("status") == "success":
            self.flight_plan["estimated_fuel_burn_kg"] = fuel_res["data"]["estimated_fuel_kg"]
        else:
            print(f"Agent: Fuel calculation failed: {fuel_res.get('message')}")
            notes.append("Fuel calculation failed. Verify fuel requirements manually.")

        # Step 5: Check Airspace Restrictions (future flight time)
        flight_start_time_utc = current_utc_time + timedelta(hours=1)
        print("Agent: Checking airspace restrictions...")
        airspace_res = self.mcp.call_tool("get_airspace_restrictions",
                                         departure=departure,
                                         destination=destination,
                                         flight_time_utc=flight_start_time_utc)
        if airspace_res.get("status") == "success":
            self.flight_plan["airspace_briefing"] = airspace_res["data"]
            if "TFR" in airspace_res["data"].get("restrictions", ""):
                notes.append(f"Airspace Alert: {airspace_res['data']['restrictions']}")
        else:
            print(f"Agent: Airspace check failed: {airspace_res.get('message')}")
            notes.append("Airspace check failed. Verify all restrictions manually.")

        print("\nAgent: All planning steps completed.")
        final_briefing = self.llm.generate_briefing(self.flight_plan)
        print(final_briefing)


In [4]:
# --- Orchestration ---
if __name__ == "__main__":
    # The model is loaded for GPU use; say so up front when no GPU is visible.
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        print("Warning: CUDA not available. Model might run on CPU, which will be very slow.")

    print("Initializing Flight Planning AI Agent (Concept) with DeepSeek-R1-Distill-Llama-8B and Boeing 787 focus...")
    # One connector is shared by the LLM wrapper and the agent.
    mcp_connector = MCPConnector()
    llm_brain = DeepSeekLLM(mcp_connector)
    flight_agent = FlightPlanningAgent(llm_brain, mcp_connector)

    # --- Test Cases ---
    # The queries keep their own module-level names so they stay available
    # after this cell has run.
    user_query_1 = "Plan a flight from CYUL to KLAX using a Boeing787."
    user_query_2 = "I need to plan a flight from KJFK to KSAN."
    user_query_3 = "Plan a flight from CYUL to KORD with an A320."
    user_query_4 = "Just plan a flight for me."
    user_query_5 = "Can you help me prepare for a long-haul flight tomorrow from KLAX to EGLL (London Heathrow) with the new Boeing 787 Dreamliner? I want to know about fuel and a good route."

    # Each scenario is a (banner printed first, query handed to the agent) pair.
    scenarios = (
        ("\n--- Test Case 1: Flight from CYUL to KLAX (explicitly B787) ---", user_query_1),
        ("\n--- Test Case 2: Flight from KJFK to KSAN (no aircraft specified, should default to B787) ---", user_query_2),
        ("\n--- Test Case 3: Flight from CYUL to KORD (A320 specified) ---", user_query_3),
        ("\n--- Test Case 4: Clarification needed (no departure/destination) ---", user_query_4),
        ("\n--- Test Case 5: Complex query with specific instructions for LLM to handle ---", user_query_5),
    )

    divider = "\n" + "="*50 + "\n"
    final_position = len(scenarios) - 1
    for position, (banner, request) in enumerate(scenarios):
        print(banner)
        flight_agent.initiate_flight_plan(request)
        # The divider separates scenarios; none is printed after the last one.
        if position != final_position:
            print(divider)

Initializing Flight Planning AI Agent (Concept) with DeepSeek-R1-Distill-Llama-8B and Boeing 787 focus...

LLM: Loading unsloth/DeepSeek-R1-Distill-Llama-8B...
==((====))==  Unsloth 2025.6.5: Fast Llama patching. Transformers: 4.52.4.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
LLM loaded successfully!

--- Test Case 1: Flight from CYUL to KLAX (explicitly B787) ---

LLM: Processing query using DeepSeek-R1-Distill-Llama-8B: 'Plan a flight from CYUL to KLAX using a Boeing787.'
LLM Raw Output (attempting JSON extraction): So, the output should be:
         {"departure_airport": "CYUL", "destination_airport": "KLAX", "aircraft_type": "Boeing787"}

        So, the 