<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/MCP_AAI_LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import asyncio
import re
import json
import os
import time
import nest_asyncio # Import nest_asyncio

# Apply nest_asyncio to allow nested asyncio.run() calls
nest_asyncio.apply()

# Import Google Generative AI library
import google.generativeai as genai

# --- Configure Gemini API ---
# The API key is looked up in this order:
#   1. Google Colab's secret store (google.colab.userdata) under the name 'GEMINI'.
#   2. The 'GEMINI_API_KEY' environment variable. This is used for local runs and
#      also whenever step 1 fails (not in Colab, secret missing, access denied).
# Never paste a real key into source that may be shared or committed.
GOOGLE_API_KEY = None
try:
    # This import only succeeds inside Google Colab.
    from google.colab import userdata
    GOOGLE_API_KEY = userdata.get('GEMINI')
except ImportError:
    print("google.colab.userdata not found. Attempting to get API key from environment variable 'GEMINI_API_KEY'.")
except Exception as e:
    # e.g. the secret is not defined or notebook access to it was not granted;
    # fall through to the environment-variable lookup below instead of giving up.
    print(f"Error getting API key: {e}. Please ensure 'GEMINI' is set in userdata or 'GEMINI_API_KEY' in environment variables.")

if not GOOGLE_API_KEY:
    GOOGLE_API_KEY = os.environ.get('GEMINI_API_KEY')

if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
    print("Gemini API configured successfully.")
else:
    print("WARNING: Gemini API key not found. LLM calls will fail. Please set your API key.")

# Initialize the Gemini model.
# Ensure your key has access to 'gemini-1.5-pro-latest'. On failure gemini_model
# is None, and call_llm_for_decision() refuses to run.
try:
    gemini_model = genai.GenerativeModel('gemini-1.5-pro-latest')
    print("Gemini 1.5 Pro model initialized.")
except Exception as e:
    print(f"Error initializing Gemini model: {e}. Make sure the API key is valid and the model is accessible.")
    gemini_model = None  # Set to None if model cannot be initialized


# --- Simulated MCP Servers (Tools) ---
# These functions represent external services or data sources that an MCP server would expose.
# In a real scenario, these would involve actual API calls or database queries.

class SimulatedMCPServers:
    """
    A class to simulate various MCP servers providing tools for flight planning.

    Each tool is an ``async`` method that prints what it is doing, sleeps to
    mimic network/computation latency, and then returns canned data.

    Args:
        latency_scale: Multiplier applied to every simulated delay. The
            default ``1.0`` keeps the original timings; ``0`` makes every
            tool return without waiting (useful for tests and quick demos).
    """

    # Class-level default so instances built without __init__ still have a scale.
    latency_scale: float = 1.0

    # Canned forecasts keyed by city name; 'Default' covers unknown cities.
    _WEATHER_DATA = {
        'Montreal': 'Partly cloudy with a chance of light rain, 20°C.',
        'Shanghai': 'Clear skies, 28°C, light winds.',
        'New York': 'Sunny with a gentle breeze, 22°C.',
        'London': 'Overcast with occasional drizzle, 18°C.',
        'Toronto': 'Sunny, 21°C.',
        'Tokyo': 'Partly cloudy, 26°C, humid.',
        'Default': 'Sunny with a gentle breeze, 25°C.'
    }

    # Canned performance figures keyed by aircraft type; 'default' covers unknown types.
    _PERFORMANCE_DATA = {
        'cargo': {
            'max_range': '15,000 km',
            'cruise_speed': '900 km/h',
            'fuel_efficiency': 'High for long-haul'
        },
        'passenger': {
            'max_range': '12,000 km',
            'cruise_speed': '920 km/h',
            'fuel_efficiency': 'Medium'
        },
        'default': {
            'max_range': 'unknown',
            'cruise_speed': 'unknown',
            'fuel_efficiency': 'unknown'
        }
    }

    def __init__(self, latency_scale: float = 1.0):
        self.latency_scale = latency_scale

    async def _simulate_delay(self, seconds: float) -> None:
        """Sleep ``seconds`` multiplied by ``latency_scale`` to mimic a remote call."""
        await asyncio.sleep(seconds * self.latency_scale)

    async def get_weather_forecast(self, location: str, date: str) -> str:
        """
        Simulates fetching weather forecast for a given location and date.

        The city lookup ignores case and surrounding whitespace, so "montreal"
        matches "Montreal". ``date`` is only echoed in the log line. Unknown
        cities receive the 'Default' forecast.
        """
        print(f"MCP Server: Fetching weather for {location} on {date}...")
        await self._simulate_delay(0.5)  # Simulate network delay
        wanted = location.strip().casefold()
        for city, forecast in self._WEATHER_DATA.items():
            if city.casefold() == wanted:
                return forecast
        return self._WEATHER_DATA['Default']

    async def get_conflict_zones(self) -> list[str]:
        """
        Simulates fetching current geopolitical conflict zones.
        """
        print("MCP Server: Checking for conflict zones...")
        await self._simulate_delay(0.4)  # Simulate network delay
        return ["Eastern Europe (active)", "Middle East (sporadic)"]

    async def calculate_flight_route(self, origin: str, destination: str, weather_info: str,
                                     conflict_zones: list[str], optimization_goal: str) -> dict:
        """
        Simulates calculating a flight route based on parameters.

        Args:
            origin: Departure city.
            destination: Arrival city.
            weather_info: Free-text weather summary copied into ``notes``.
            conflict_zones: Zones to avoid; an empty list means none.
            optimization_goal: 'speed' selects the faster, thirstier profile;
                any other value selects the fuel-saving profile.

        Returns:
            A dict with the keys 'path', 'estimated_time', 'fuel_required',
            'notes' and 'layovers'.
        """
        print(f"MCP Server: Calculating route from {origin} to {destination} for {optimization_goal}...")
        await self._simulate_delay(0.8)  # Simulate complex calculation
        is_speed = optimization_goal == 'speed'
        if conflict_zones:
            avoided_zones = f"Avoiding: {', '.join(conflict_zones)}"
        else:
            avoided_zones = 'No conflict zones to avoid'
        route_details = {
            'path': f"{origin} -> (Optimized via {optimization_goal}) -> {destination}",
            'estimated_time': '12 hours 30 minutes' if is_speed else '13 hours 45 minutes',
            'fuel_required': '120,000 kg' if is_speed else '105,000 kg',
            # rstrip('.') prevents a doubled period when the forecast already ends with one.
            'notes': f"Weather: {weather_info.rstrip('.')}. {avoided_zones}.",
            'layovers': 'None required for direct flight.'
        }
        # Add some variation for Toronto to Tokyo (matched case-insensitively).
        if origin.strip().casefold() == 'toronto' and destination.strip().casefold() == 'tokyo':
            route_details['estimated_time'] = '11 hours 15 minutes' if is_speed else '12 hours 30 minutes'
            route_details['fuel_required'] = '110,000 kg' if is_speed else '95,000 kg'

        return route_details

    async def get_aircraft_performance(self, aircraft_type: str) -> dict:
        """
        Simulates fetching aircraft performance data.

        The type lookup ignores case and surrounding whitespace. Unknown types
        get an all-'unknown' record. A fresh dict is returned each call, so
        callers may modify it freely.
        """
        print(f"MCP Server: Retrieving performance data for {aircraft_type} aircraft...")
        await self._simulate_delay(0.3)  # Simulate database lookup
        record = self._PERFORMANCE_DATA.get(aircraft_type.strip().casefold(),
                                            self._PERFORMANCE_DATA['default'])
        # Copy so callers cannot mutate the shared class-level table.
        return dict(record)

# --- LLM Decision-Making using Gemini 1.5 Pro ---
async def call_llm_for_decision(prompt: str) -> dict:
    """
    Calls the Gemini 1.5 Pro LLM to interpret a user prompt and suggest tool calls.

    The LLM is prompted to return a structured JSON response. The schema is
    embedded in the prompt text and also passed to the SDK as
    ``response_schema`` with ``response_mime_type="application/json"``.

    Args:
        prompt: The user's natural-language flight-planning request.

    Returns:
        The parsed JSON object: a dict containing at least the keys
        'parsed_intent' and 'suggested_tool_calls'.

    Raises:
        RuntimeError: If the module-level ``gemini_model`` is not initialized,
            if the API call or JSON parsing fails (the original exception is
            chained as ``__cause__``), or if the parsed value is not an object
            with the required top-level keys.
    """
    if not gemini_model:
        raise RuntimeError("Gemini model not initialized. Cannot make LLM call.")

    print(f"LLM: Sending prompt to Gemini 1.5 Pro: '{prompt}'...")

    # Define the expected JSON schema for the LLM's response
    response_schema = {
        "type": "OBJECT",
        "properties": {
            "parsed_intent": {
                "type": "OBJECT",
                "properties": {
                    "origin": {"type": "STRING"},
                    "destination": {"type": "STRING"},
                    "optimization_goal": {"type": "STRING", "enum": ["speed", "fuel efficiency", "balanced"]},
                    "aircraft_type": {"type": "STRING", "enum": ["cargo", "passenger"]},
                    "avoid_conflict_zones": {"type": "BOOLEAN"}
                },
                "required": ["origin", "destination", "optimization_goal", "aircraft_type", "avoid_conflict_zones"]
            },
            "suggested_tool_calls": {
                "type": "ARRAY",
                "items": {
                    "type": "OBJECT",
                    "properties": {
                        "tool": {"type": "STRING", "enum": [
                            "get_conflict_zones",
                            "get_weather_forecast",
                            "calculate_flight_route",
                            "get_aircraft_performance"
                        ]},
                        "params": {  # Define params properties explicitly
                            "type": "OBJECT",
                            "properties": {
                                "location": {"type": "STRING"},
                                "date": {"type": "STRING"},
                                "origin": {"type": "STRING"},
                                "destination": {"type": "STRING"},
                                "weather_info": {"type": "STRING"},
                                "conflict_zones": {
                                    "type": "ARRAY",
                                    "items": {"type": "STRING"}
                                },
                                "optimization_goal": {"type": "STRING", "enum": ["speed", "fuel efficiency", "balanced"]},
                                "aircraft_type": {"type": "STRING", "enum": ["cargo", "passenger"]}
                            },
                            # No "required" for params as they vary per tool call
                        }
                    },
                    "required": ["tool", "params"]
                }
            }
        },
        "required": ["parsed_intent", "suggested_tool_calls"]
    }

    # Construct the prompt for the LLM
    llm_prompt = f"""
    You are an AI flight planning assistant. Your task is to interpret a user's request
    and output a structured JSON object containing their parsed intent and a sequence of
    suggested tool calls to fulfill the request.

    Available tools and their parameters:
    - get_weather_forecast(location: str, date: str) -> Fetches weather forecast.
    - get_conflict_zones() -> Fetches current geopolitical conflict zones.
    - calculate_flight_route(origin: str, destination: str, weather_info: str, conflict_zones: list[str], optimization_goal: str) -> Calculates flight route.
    - get_aircraft_performance(aircraft_type: str) -> Fetches aircraft performance data.

    User Request: "{prompt}"

    Please provide the output in JSON format, adhering to the following schema:
    {json.dumps(response_schema, indent=2)}
    """

    try:
        # generate_content is blocking; run it in a worker thread so the event loop stays free.
        response = await asyncio.to_thread(
            gemini_model.generate_content,
            llm_prompt,
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json",
                response_schema=response_schema
            )
        )
        # response.text joins the text of all parts of the first candidate, and
        # raises ValueError if the response holds no usable part (e.g. it was blocked).
        llm_output = json.loads(response.text)
    except Exception as e:
        print(f"Error calling Gemini LLM or parsing response: {e}")
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"LLM failed to generate valid structured response: {e}") from e

    print("LLM Response (JSON):", json.dumps(llm_output, indent=2))

    # Reject structurally invalid output here rather than failing later with a bare KeyError.
    if not isinstance(llm_output, dict) or any(key not in llm_output for key in response_schema["required"]):
        raise RuntimeError(
            "LLM failed to generate valid structured response: "
            f"expected a JSON object with keys {response_schema['required']}"
        )
    return llm_output

# --- Agent Orchestration ---
async def plan_flight_agent(prompt: str):
    """
    Simulates an AI agent orchestrating LLM decisions and MCP tool calls
    to plan a flight.

    Steps:
      1. Ask the LLM (call_llm_for_decision) to parse the request and
         suggest tool calls.
      2. Run the suggested data-gathering tools (conflict zones, weather,
         aircraft performance). Suggested 'calculate_flight_route' calls are
         ignored because the agent makes that call itself in step 4.
      3. Fill any gaps the LLM left: fetch conflict zones if the intent asks
         to avoid them, and fetch weather for origin/destination if missing.
      4. Calculate the route with the gathered data and print the plan.

    This is the top-level boundary for one request: any exception is printed
    rather than propagated.

    Args:
        prompt: The user's natural-language flight-planning request.
    """
    print("\n--- Initiating Flight Planning Process ---")
    mcp_servers = SimulatedMCPServers()

    try:
        # Step 1: LLM interprets the user's prompt using Gemini 1.5 Pro
        print("Agent: Sending prompt to Gemini 1.5 Pro for interpretation...")
        llm_decision = await call_llm_for_decision(prompt)
        intent = llm_decision['parsed_intent']
        print(f"LLM interpreted intent: From {intent['origin']} to {intent['destination']}, "
              f"optimize for {intent['optimization_goal']}, aircraft: {intent['aircraft_type']}, "
              f"avoid conflicts: {intent['avoid_conflict_zones']}.")

        # Normalized keys used to decide whether a forecast is for origin or destination.
        origin_key = intent['origin'].strip().casefold()
        destination_key = intent['destination'].strip().casefold()

        current_weather_origin = ''
        current_weather_destination = ''
        detected_conflict_zones = []
        # Tracked separately because an empty zone list is a valid answer.
        conflict_zones_fetched = False
        aircraft_performance_data = {}

        # Step 2: Agent orchestrates tool calls based on LLM's suggestions (via MCP)
        print("Agent: Orchestrating tool calls via simulated MCP servers...")

        for call in llm_decision.get('suggested_tool_calls') or []:
            if not call:  # This should ideally not happen if LLM adheres to schema
                continue

            tool_name = call.get('tool')
            params = call.get('params') or {}

            if tool_name == 'get_conflict_zones':
                print(f"Agent: Calling MCP server for '{tool_name}'...")
                detected_conflict_zones = await mcp_servers.get_conflict_zones()
                conflict_zones_fetched = True
                print(f"MCP Server Response (Conflict Zones): {', '.join(detected_conflict_zones)}")
            elif tool_name == 'get_weather_forecast':
                location = params.get('location')
                date = params.get('date') or 'tomorrow'  # Default date if LLM omits it or sends ""
                if location:
                    print(f"Agent: Calling MCP server for '{tool_name}' for {location}...")
                    weather_info = await mcp_servers.get_weather_forecast(location, date)
                    # Heuristic to assign weather to origin/destination based on parsed intent
                    location_key = location.strip().casefold()
                    if location_key == origin_key:
                        current_weather_origin = weather_info
                        print(f"MCP Server Response (Weather at Origin): {current_weather_origin}")
                    elif location_key == destination_key:
                        current_weather_destination = weather_info
                        print(f"MCP Server Response (Weather at Destination): {current_weather_destination}")
            elif tool_name == 'get_aircraft_performance':
                aircraft_type = params.get('aircraft_type')
                if aircraft_type:
                    print(f"Agent: Calling MCP server for '{tool_name}' for {aircraft_type}...")
                    aircraft_performance_data = await mcp_servers.get_aircraft_performance(aircraft_type)
                    print(f"MCP Server Response (Aircraft Performance): Max Range: {aircraft_performance_data.get('max_range')}, Cruise Speed: {aircraft_performance_data.get('cruise_speed')}")
            # The 'calculate_flight_route' tool call is handled explicitly after gathering all data
            # to ensure all necessary parameters are available.

        # Step 3: The LLM may omit prerequisite calls. Fill those gaps so the route
        # reflects the parsed intent instead of silently using missing data.
        if intent['avoid_conflict_zones'] and not conflict_zones_fetched:
            print("Agent: LLM did not request 'get_conflict_zones'; fetching it to honour the avoidance request...")
            detected_conflict_zones = await mcp_servers.get_conflict_zones()
            print(f"MCP Server Response (Conflict Zones): {', '.join(detected_conflict_zones)}")
        if not current_weather_origin:
            print(f"Agent: No origin forecast gathered; fetching weather for {intent['origin']}...")
            current_weather_origin = await mcp_servers.get_weather_forecast(intent['origin'], 'tomorrow')
            print(f"MCP Server Response (Weather at Origin): {current_weather_origin}")
        if not current_weather_destination:
            print(f"Agent: No destination forecast gathered; fetching weather for {intent['destination']}...")
            current_weather_destination = await mcp_servers.get_weather_forecast(intent['destination'], 'tomorrow')
            print(f"MCP Server Response (Weather at Destination): {current_weather_destination}")

        # Step 4: Agent calls the final route calculation tool with gathered data
        # This step is explicitly managed by the agent after all prerequisite data is collected.
        print("Agent: All necessary context gathered. Calling 'calculate_flight_route'...")
        combined_weather = f"Origin: {current_weather_origin}, Destination: {current_weather_destination}"
        final_route = await mcp_servers.calculate_flight_route(
            intent['origin'],
            intent['destination'],
            combined_weather,
            detected_conflict_zones if intent['avoid_conflict_zones'] else [],
            intent['optimization_goal']
        )
        print("MCP Server Response (Flight Route Calculation Complete).")

        # Present the final plan
        print("\n--- Flight Plan Generated Successfully! ---")
        print("Flight Path:", final_route['path'])
        print("Estimated Time:", final_route['estimated_time'])
        print("Fuel Required:", final_route['fuel_required'])
        print("Notes:", final_route['notes'])
        print("Layovers:", final_route['layovers'])

    except Exception as e:
        # Top-level boundary for one planning request: report and continue.
        print(f"\n--- ERROR during flight planning: {e} ---")

# --- Example Usage ---
if __name__ == "__main__":
    # nest_asyncio.apply() at the top of the file lets asyncio.run() work even
    # when an event loop is already running (e.g. inside Jupyter/Colab).
    async def run_all_flight_plans():
        """Run the demo requests in order, with a separator line between them."""
        demo_requests = (
            "Plan a cargo flight from Montreal to Shanghai, optimize for speed, and avoid any conflict zones.",
            "I need a passenger flight from New York to London, focusing on fuel efficiency.",
            "Plan a cargo flight from Toronto to Tokyo.",
        )
        for position, request in enumerate(demo_requests):
            if position > 0:
                print("\n" + "=" * 50 + "\n")
            await plan_flight_agent(request)

    asyncio.run(run_all_flight_plans())

Gemini API configured successfully.
Gemini 1.5 Pro model initialized.

--- Initiating Flight Planning Process ---
Agent: Sending prompt to Gemini 1.5 Pro for interpretation...
LLM: Sending prompt to Gemini 1.5 Pro: 'Plan a cargo flight from Montreal to Shanghai, optimize for speed, and avoid any conflict zones.'...
LLM Response (JSON): {
  "parsed_intent": {
    "aircraft_type": "cargo",
    "avoid_conflict_zones": true,
    "destination": "Shanghai",
    "optimization_goal": "speed",
    "origin": "Montreal"
  },
  "suggested_tool_calls": [
    {
      "params": {},
      "tool": "get_conflict_zones"
    },
    {
      "params": {
        "date": "today",
        "location": "Montreal"
      },
      "tool": "get_weather_forecast"
    },
    {
      "params": {
        "date": "today",
        "location": "Shanghai"
      },
      "tool": "get_weather_forecast"
    },
    {
      "params": {
        "origin": "Montreal",
        "weather_info": "weather forecasts for Montreal and Shan