# In [4]:
from os import name
# %% Imports
import ee
import ipyleaflet
from ipyleaflet import Map, Marker, Icon, TileLayer, LayersControl, Polyline
from ipyleaflet import  Polygon as iPolygon
from shapely.geometry import Polygon as ShapelyPolygon, Point
import ipywidgets as widgets
from ipywidgets import Button, VBox, HBox, HTML, FloatSlider, IntSlider, Dropdown, Layout, Output
import random
import numpy as np
import zipfile
import geopy.distance
from cryptography.fernet import Fernet
import math
import io
import requests  # For future API integrations if needed

# %% Drone and Mission Configuration
# Camera/airframe specifications per supported drone model.
# Units: max_altitude in m, sensor_width and focal_length in mm,
# image_width / image_height in pixels.
# NOTE(review): 24 mm for the Mavic 3 looks like a 35 mm-equivalent focal
# length while 13.2 mm is a physical sensor width; the GSD formula needs both
# to be physical values — confirm these specs.
drone_config = {
    "DJI Mavic 3": dict(
        max_altitude=500,
        sensor_width=13.2,
        focal_length=24,
        image_width=4000,
        image_height=3000,
    ),
    "Parrot Bluegrass": dict(
        max_altitude=150,
        sensor_width=8.8,
        focal_length=16,
        image_width=3000,
        image_height=2000,
    ),
}

# Mission presets: cruise altitude in metres and fractional cruise overlap.
# Insertion order is the order shown in the "Mission Preset" dropdown.
mission_presets = {
    "Plant Counting": dict(alt=15, overlap=0.8),
    "Pest Detection": dict(alt=30, overlap=0.6),
    "Bird's Eye View": dict(alt=100, overlap=0.3),
}

# %% Helper Functions
def update_drone_limits(change):
    """
    Cap the altitude sliders when a different drone model is selected.

    Observer for the "Drone Model" dropdown. The new ceiling is the smaller
    of the 120 m legal limit and the drone's ``max_altitude`` from
    ``drone_config``. Both altitude sliders get that ``max``, and their
    current values are pulled down to it if necessary.

    Args:
        change: traitlets change notification; ``change.new`` is the selected
            drone model name (a key of ``drone_config``).
    """
    drone = drone_config[change.new]
    # flight_params_box.children is [0] header HTML, [1] cruise altitude
    # slider, [2] descent altitude slider, ... — index 0 is not a slider.
    cruise_slider = flight_params_box.children[1]
    descent_slider = flight_params_box.children[2]

    # Set max altitude to minimum of legal limit (120m) and drone capability
    max_alt = min(120, drone['max_altitude'])

    cruise_slider.max = max_alt
    descent_slider.max = max_alt

    # Ensure current values are within new limits
    cruise_slider.value = min(cruise_slider.value, max_alt)
    descent_slider.value = min(descent_slider.value, max_alt)


def calculate_photo_coverage(drone_model, altitude):
    """Return the ``(width, height)`` ground footprint of one photo, in metres.

    The ground sample distance (metres per pixel) is
    ``sensor_width * altitude / (focal_length * image_width)``, using the
    specs of ``drone_config[drone_model]``. It is then scaled by the image
    size in pixels.
    """
    specs = drone_config[drone_model]
    metres_per_pixel = (specs["sensor_width"] * altitude) / (
        specs["focal_length"] * specs["image_width"]
    )
    return tuple(
        metres_per_pixel * specs[pixel_key]
        for pixel_key in ("image_width", "image_height")
    )

def calculate_photo_spacing(coverage_width, overlap):
    """Return the distance between consecutive photo centres.

    ``overlap`` is the fraction of ``coverage_width`` shared by neighbouring
    photos, so each new photo advances by the non-overlapping share.
    """
    advancing_fraction = 1 - overlap
    return coverage_width * advancing_fraction

def calculate_field_area(field_coords):
    """
    Return the area of the field polygon in square metres.

    Earth Engine's ``Geometry.area()`` is tried first. If that call fails
    (e.g. not initialised or offline), a planar approximation is used
    instead: the polygon's area in square degrees is scaled to m² with
    111 km per degree, and longitude is shrunk by cos(centroid latitude).

    Args:
        field_coords: sequence of (lat, lon) vertices in drawing order.

    Returns:
        Area in m². Returns 0 when fewer than three vertices are given or
        the polygon is not valid (e.g. self-intersecting).
    """
    # Shapely cannot build a polygon ring from fewer than three vertices.
    if len(field_coords) < 3:
        return 0

    lonlat = [(lon, lat) for lat, lon in field_coords]
    poly = ShapelyPolygon(lonlat)
    if not poly.is_valid:
        return 0

    # Use Earth Engine for accurate calculation if available
    try:
        ee_polygon = ee.Geometry.Polygon([[lon, lat] for lon, lat in lonlat])
        return ee_polygon.area().getInfo()
    except Exception:
        # Fallback: poly.area is the planar (shoelace) area in square degrees.
        return poly.area * (111000 ** 2) * math.cos(math.radians(poly.centroid.y))



from shapely.geometry import Polygon, LineString

from shapely.geometry import Polygon, LineString, Point

def _scanline_segments(geom):
    """Return the LineString pieces of a Shapely intersection result.

    A scan line can meet the field as a LineString, a MultiLineString
    (concave field), a GeometryCollection (a segment plus a grazed vertex),
    or only Point(s). Point-like parts contain no flight segment and are
    dropped.
    """
    if geom.is_empty:
        return []
    if geom.geom_type == 'LineString':
        return [geom]
    if geom.geom_type in ('MultiLineString', 'GeometryCollection'):
        pieces = []
        for part in geom.geoms:
            pieces.extend(_scanline_segments(part))
        return pieces
    return []


def generate_lawnmower_waypoints(field_coords, cruise_alt, horizontal_overlap,
                                 vertical_overlap, drone_model=None):
    """
    Generate a serpentine ("lawnmower") cruise pattern clipped to the field.

    Parallel passes run along the longer side of the field's bounding box,
    measured in metres, and alternate direction from pass to pass. Every
    piece of field a pass crosses contributes exactly one start and one end
    waypoint. The result is therefore always consecutive (start, end) pairs,
    which is how ``create_dji_compatible_kml`` consumes it.

    Args:
        field_coords: sequence of (lat, lon) polygon vertices.
        cruise_alt: cruise altitude in metres; also sets the photo footprint.
        horizontal_overlap: fractional side overlap between adjacent passes,
            in [0, 1).
        vertical_overlap: fractional forward overlap along a pass. Kept for
            interface compatibility. Waypoints are only placed at pass ends,
            so it does not affect the result.
        drone_model: key of ``drone_config``; defaults to the model selected
            in the "Drone Model" dropdown.

    Returns:
        list of dicts with 'lat', 'lon', 'alt' (= cruise_alt) and
        'type' == 'cruise'.

    Raises:
        ValueError: if the resulting pass spacing is not positive
            (e.g. ``horizontal_overlap >= 1``).
    """
    if drone_model is None:
        drone_model = flight_params_box.children[5].value

    coverage_width, _ = calculate_photo_coverage(drone_model, cruise_alt)
    # Cross-track distance between neighbouring passes, in metres.
    line_spacing_m = coverage_width * (1 - horizontal_overlap)
    if line_spacing_m <= 0:
        raise ValueError(
            f"Pass spacing must be positive (got {line_spacing_m} m); "
            "check cruise_alt and horizontal_overlap (< 1)."
        )

    # Convert field coordinates to Shapely Polygon (order: (lon, lat))
    field_polygon = Polygon([(lon, lat) for lat, lon in field_coords])
    min_lon, min_lat, max_lon, max_lat = field_polygon.bounds

    # Equirectangular approximation (~111.32 km per degree of latitude).
    m_per_deg_lat = 111320.0
    m_per_deg_lon = 111320.0 * math.cos(math.radians((min_lat + max_lat) / 2.0))

    # Fly along the longer side (fewer passes). Compare in metres: away from
    # the equator a degree of longitude is shorter than a degree of latitude.
    is_horizontal = ((max_lon - min_lon) * m_per_deg_lon
                     > (max_lat - min_lat) * m_per_deg_lat)

    if is_horizontal:
        # East-west passes, stacked by latitude.
        sweep_min, sweep_max = min_lat, max_lat
        line_spacing_deg = line_spacing_m / m_per_deg_lat
        along = 0  # index of longitude in a Shapely (x, y) coordinate
    else:
        # North-south passes, stacked by longitude.
        sweep_min, sweep_max = min_lon, max_lon
        line_spacing_deg = line_spacing_m / m_per_deg_lon
        along = 1  # index of latitude in a Shapely (x, y) coordinate
    num_lines = int((sweep_max - sweep_min) / line_spacing_deg) + 1

    waypoints = []
    for i in range(num_lines):
        offset = sweep_min + i * line_spacing_deg
        if is_horizontal:
            scan_line = LineString([(min_lon, offset), (max_lon, offset)])
        else:
            scan_line = LineString([(offset, min_lat), (offset, max_lat)])

        # Even passes go W->E (or S->N); odd passes come back.
        forward = (i % 2 == 0)
        segments = []
        for segment in _scanline_segments(field_polygon.intersection(scan_line)):
            coords = list(segment.coords)
            start, end = coords[0], coords[-1]
            # Shapely does not guarantee the direction of intersection output,
            # so orient each piece explicitly along this pass.
            if (start[along] > end[along]) == forward:
                start, end = end, start
            segments.append((start, end))
        # Visit the pieces in the order the pass crosses them.
        segments.sort(key=lambda pair: pair[0][along], reverse=not forward)

        for start, end in segments:
            for lon, lat in (start, end):
                waypoints.append({
                    'lat': lat,
                    'lon': lon,
                    'alt': cruise_alt,
                    'type': 'cruise'
                })

    return waypoints

def generate_descent_waypoints(field_coords, descent_alt, target_coverage_percentage,
                               drone_model=None):
    """
    Spread low-altitude photo waypoints over the field on a regular grid.

    The photo count is chosen so that the photos' combined footprint is
    ``target_coverage_percentage`` of the field area. Photos are placed at
    the centres of a rows x columns grid laid over the field's bounding box.
    The grid's aspect ratio follows the box measured in metres. Cells whose
    centre is outside the field are skipped, so irregular fields may get
    fewer photos than the target.

    Args:
        field_coords: sequence of (lat, lon) polygon vertices.
        descent_alt: photo altitude in metres.
        target_coverage_percentage: fraction of the field to image
            (0-1, e.g. 0.15), despite the name.
        drone_model: key of ``drone_config``; defaults to the model selected
            in the "Drone Model" dropdown.

    Returns:
        list of dicts with 'lat', 'lon', 'alt' (= descent_alt) and
        'type' == 'descent'. The list holds at most the target number of
        photos and is empty when the field is too small for a single photo.
    """
    if drone_model is None:
        drone_model = flight_params_box.children[5].value

    # Footprint of one photo at descent altitude.
    coverage_width, coverage_height = calculate_photo_coverage(drone_model, descent_alt)
    image_area = coverage_width * coverage_height
    if image_area <= 0:
        return []

    target_area = calculate_field_area(field_coords) * target_coverage_percentage
    num_images = int(target_area / image_area)
    if num_images <= 0:
        # Field too small (or invalid) for even one photo at this coverage.
        return []

    field_polygon = Polygon([(lon, lat) for lat, lon in field_coords])
    min_lon, min_lat, max_lon, max_lat = field_polygon.bounds

    # Bounding-box size in metres (equirectangular approximation).
    avg_lat = (min_lat + max_lat) / 2.0
    width_m = (max_lon - min_lon) * 111320.0 * math.cos(math.radians(avg_lat))
    height_m = (max_lat - min_lat) * 111320.0
    if width_m <= 0 or height_m <= 0:
        return []
    is_horizontal = width_m > height_m

    # Choose a grid whose cell count is at least num_images and whose shape
    # roughly matches the field's aspect ratio.
    if is_horizontal:
        num_cols = max(1, int(math.sqrt(num_images * width_m / height_m)))
        num_rows = num_images // num_cols + 1
    else:
        num_rows = max(1, int(math.sqrt(num_images * height_m / width_m)))
        num_cols = num_images // num_rows + 1

    lon_step = (max_lon - min_lon) / num_cols
    lat_step = (max_lat - min_lat) / num_rows
    if is_horizontal:
        cells = ((col, row) for col in range(num_cols) for row in range(num_rows))
    else:
        cells = ((col, row) for row in range(num_rows) for col in range(num_cols))

    waypoints = []
    for col, row in cells:
        if len(waypoints) >= num_images:
            break
        # Cell centres: points on the bounding-box edge can never lie strictly
        # inside the field, so edge-aligned grid points would always be wasted.
        point = Point(min_lon + (col + 0.5) * lon_step,
                      min_lat + (row + 0.5) * lat_step)
        if field_polygon.contains(point):
            waypoints.append({
                'lat': point.y,
                'lon': point.x,
                'alt': descent_alt,
                'type': 'descent'
            })

    return waypoints





def validate_altitude(alt, drone_model):
    """Clamp ``alt`` to the drone's own ceiling and the 120 m legal limit.

    Returns the smallest of ``alt``, the drone's ``max_altitude`` from
    ``drone_config``, and 120.
    """
    ceilings = (drone_config[drone_model]["max_altitude"], 120)
    return min(alt, *ceilings)

def create_kmz(kml_content, kmz_filename):
    """Write ``kml_content`` into a new KMZ archive at ``kmz_filename``.

    A KMZ is a ZIP file. The KML is stored deflate-compressed under the entry
    name ``doc.kml``, and any existing file at the path is overwritten.
    """
    archive = zipfile.ZipFile(kmz_filename, mode='w', compression=zipfile.ZIP_DEFLATED)
    with archive:
        archive.writestr('doc.kml', kml_content)


def estimate_flight_time(waypoints, speed=8):
    """Return the estimated flight time in hours.

    The path length in km (from ``calculate_path_length``) is divided by
    ``speed``, given in km/h.
    """
    distance_km = calculate_path_length(waypoints)
    hours = distance_km / speed
    return hours

def calculate_path_length(waypoints):
    """Return the total length (km) of the path through ``waypoints`` in order.

    Each pair of consecutive waypoints is joined by a great-circle leg
    measured with ``haversine``. Fewer than two waypoints give 0.
    """
    total_distance = 0
    for prev_wp, next_wp in zip(waypoints, waypoints[1:]):
        total_distance += haversine(
            (prev_wp['lat'], prev_wp['lon']),
            (next_wp['lat'], next_wp['lon']),
        )
    return total_distance

def haversine(coord1, coord2):
    """
    Return the great-circle distance in km between two (lat, lon) points.

    Uses the haversine formula on a sphere of radius 6371 km. Coordinates
    are in decimal degrees.
    """
    lat1, lon1 = coord1
    lat2, lon2 = coord2
    R = 6371  # Radius of Earth in km
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (math.sin(dlat / 2) ** 2
         + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2))
         * math.sin(dlon / 2) ** 2)
    # For (near-)antipodal points, rounding can push sqrt(a) a hair above 1,
    # and asin would then raise "math domain error"; clamp to the valid range.
    c = 2 * math.asin(min(1.0, math.sqrt(a)))
    return R * c

def calculate_image_coverage(altitudes, drone_model):
    """
    Return the ground area (m²) covered by a single image at each altitude.

    The footprint comes from ``calculate_photo_coverage``, so the GSD formula
    is defined in one place only.

    Args:
        altitudes: iterable of altitudes in metres.
        drone_model: key of ``drone_config``.

    Returns:
        dict mapping each altitude to the image footprint area in m².
    """
    coverage = {}
    for alt in altitudes:
        width_m, height_m = calculate_photo_coverage(drone_model, alt)
        coverage[alt] = width_m * height_m
    return coverage

def show_flight_path(waypoints, color='blue'):
    """
    Draw ``waypoints`` as one connected polyline on the global map ``m``.

    Args:
        waypoints: dicts with 'lat' and 'lon', in flight order.
        color: line colour for the ipyleaflet Polyline (drawn with weight 3).
    """
    m.add_layer(
        Polyline(
            locations=[(point['lat'], point['lon']) for point in waypoints],
            color=color,
            weight=3,
        )
    )

def encrypt_kmz(filename):
    """Encrypt the file at ``filename`` with a freshly generated Fernet key.

    Returns:
        ``(ciphertext_bytes, key)``. The key is generated anew on every call
        and is not stored here, so the caller must keep it in order to decrypt.
    """
    key = Fernet.generate_key()
    with open(filename, "rb") as kmz_file:
        plaintext = kmz_file.read()
    return Fernet(key).encrypt(plaintext), key

def show_waypoints(waypoints, m):
    """
    Add one colour-coded marker with a text popup per waypoint to map ``m``.

    Cruise waypoints are blue, descent waypoints green, and any other type
    red (labelled "Waypoint").
    """
    marker_styles = {
        'cruise': ('blue', 'Cruise'),
        'descent': ('green', 'Descent'),
    }
    for wp in waypoints:
        color, label = marker_styles.get(wp['type'], ('red', 'Waypoint'))
        dot_icon = Icon(
            icon_url=f"http://maps.google.com/mapfiles/ms/icons/{color}-dot.png",
            icon_size=[32, 32],
            icon_anchor=[16, 32],
        )
        marker = Marker(location=(wp['lat'], wp['lon']), icon=dot_icon)
        marker.popup = HTML(value=label)
        m.add_layer(marker)


def create_dji_compatible_kml(field_coords, cruise_waypoints, descent_waypoints, cruise_alt,
                              horizontal_overlap, drone_model=None):
    """
    Build a KML document with DJI WPML extensions for the planned mission.

    Cruise waypoints are consumed two at a time as the (start, end) pair of
    one pass; a trailing unpaired cruise waypoint is ignored. Each pass
    becomes two point Placemarks plus a LineString Placemark, whose action
    group triggers ``takePhoto`` every ``distance_interval`` metres. Each
    descent waypoint becomes a point Placemark with a ``reachPoint`` photo
    action. Point Placemarks are numbered (``name`` / ``wpml:index``)
    consecutively: cruise points first, then descent points.

    Args:
        field_coords: field polygon vertices (not used in the output).
        cruise_waypoints: waypoint dicts with 'lat', 'lon', 'alt'.
        descent_waypoints: waypoint dicts with 'lat', 'lon', 'alt'.
        cruise_alt: cruise altitude in metres, used for the photo footprint.
        horizontal_overlap: fractional overlap used for the photo-trigger
            distance along each pass.
        drone_model: key of ``drone_config``; defaults to the model selected
            in the "Drone Model" dropdown.

    Returns:
        The KML document as a string.
    """
    if drone_model is None:
        drone_model = flight_params_box.children[5].value

    # Calculate photo coverage at cruise altitude
    coverage_width, _coverage_height = calculate_photo_coverage(drone_model, cruise_alt)
    # NOTE(review): the trigger distance is derived from the footprint *width*
    # and the side overlap. If the camera's long side is across-track, the
    # along-track spacing should instead come from the footprint height and
    # a forward overlap — confirm.
    distance_interval = calculate_photo_spacing(coverage_width, horizontal_overlap)

    # Pair consecutive cruise waypoints into (start, end) passes.
    cruise_lines = [
        (cruise_waypoints[i], cruise_waypoints[i + 1])
        for i in range(0, len(cruise_waypoints) - 1, 2)
    ]

    # KML styles are referenced by fragment ("#id") in <styleUrl>.
    # NOTE(review): droneEnumValue 77 / payloadEnumValue 68 are hard-coded
    # regardless of drone_model — presumably one specific DJI model; confirm.
    kml_template = """<?xml version="1.0" encoding="UTF-8"?>
<kml xmlns="http://www.opengis.net/kml/2.2" xmlns:wpml="http://www.dji.com/wpmz/1.0.4">
  <Document>
    <!-- Define styles for different waypoint types -->
    <Style id="cruise_style">
      <IconStyle>
        <color>ff0000ff</color>  <!--icon for cruise waypoints -->
        <scale>0.7</scale>
        <Icon>
          <href>http://maps.google.com/mapfiles/kml/pushpin/ylw-pushpin.png</href>
        </Icon>
      </IconStyle>
      <LineStyle>
        <color>ff0000ff</color>
      </LineStyle>
    </Style>
    <Style id="descent_style">
      <IconStyle>
        <color>ff00ff00</color>  <!-- icon for descent waypoints -->
        <scale>0.5</scale>
        <Icon>
          <href>http://maps.google.com/mapfiles/kml/pal2/icon58.png</href>
        </Icon>
      </IconStyle>
      <LineStyle>
        <color>ff00ff00</color>
      </LineStyle>
    </Style>

    <wpml:missionConfig>
      <wpml:flyToWaylineMode>safely</wpml:flyToWaylineMode>
      <wpml:finishAction>goHome</wpml:finishAction>
      <wpml:exitOnRCLost>executeLostAction</wpml:exitOnRCLost>
      <wpml:executeRCLostAction>goBack</wpml:executeRCLostAction>
      <wpml:takeOffSecurityHeight>20</wpml:takeOffSecurityHeight>
      <wpml:globalTransitionalSpeed>15</wpml:globalTransitionalSpeed>
      <wpml:droneInfo>
        <wpml:droneEnumValue>77</wpml:droneEnumValue>
        <wpml:droneSubEnumValue>2</wpml:droneSubEnumValue>
      </wpml:droneInfo>
      <wpml:payloadInfo>
        <wpml:payloadEnumValue>68</wpml:payloadEnumValue>
        <wpml:payloadSubEnumValue>0</wpml:payloadSubEnumValue>
        <wpml:payloadPositionIndex>0</wpml:payloadPositionIndex>
      </wpml:payloadInfo>
    </wpml:missionConfig>

    <Folder>
      <!-- Cruise Waypoints -->
      {cruise_waypoints}
      <!-- Descent Waypoints -->
      {descent_waypoints}
    </Folder>
  </Document>
</kml>"""

    waypoint_template = """
      <Placemark>
      <name>{name}</name>
        <styleUrl>{style}</styleUrl>
        <Point>
          <coordinates>{lon},{lat}</coordinates>
        </Point>
        <wpml:index>{index}</wpml:index>
        <wpml:height>{alt}</wpml:height>
        <wpml:waypointSpeed>15</wpml:waypointSpeed>
        <wpml:waypointHeadingParam>
          <wpml:waypointHeadingMode>followWayline</wpml:waypointHeadingMode>
          <wpml:waypointHeadingAngle>0</wpml:waypointHeadingAngle>
        </wpml:waypointHeadingParam>
        <wpml:waypointTurnParam>
          <wpml:waypointTurnMode>toPointAndStopWithDiscontinuityCurvature</wpml:waypointTurnMode>
          <wpml:waypointTurnDampingDist>0</wpml:waypointTurnDampingDist>
        </wpml:waypointTurnParam>
        <wpml:gimbalPitchAngle>-90</wpml:gimbalPitchAngle>
        <wpml:useStraightLine>1</wpml:useStraightLine>
        {action_group}
        <wpml:isRisky>0</wpml:isRisky>
      </Placemark>"""

    line_template = """
      <Placemark>
        <styleUrl>#cruise_style</styleUrl>
        <LineString>
          <coordinates>
            {start_lon},{start_lat} {end_lon},{end_lat}
          </coordinates>
        </LineString>
        <wpml:actionGroup>
          <wpml:actionGroupId>{index}</wpml:actionGroupId>
          <wpml:actionGroupStartIndex>{start_index}</wpml:actionGroupStartIndex>
          <wpml:actionGroupEndIndex>{end_index}</wpml:actionGroupEndIndex>
          <wpml:actionGroupMode>sequence</wpml:actionGroupMode>
          <wpml:actionTrigger>
            <wpml:actionTriggerType>distance</wpml:actionTriggerType>
            <wpml:distanceInterval>{distance_interval}</wpml:distanceInterval>
          </wpml:actionTrigger>
          <wpml:action>
            <wpml:actionId>0</wpml:actionId>
            <wpml:actionActuatorFunc>takePhoto</wpml:actionActuatorFunc>
            <wpml:actionActuatorFuncParam>
              <wpml:fileSuffix>photo_sequence</wpml:fileSuffix>
              <wpml:payloadPositionIndex>0</wpml:payloadPositionIndex>
              <wpml:useGlobalPayloadLensIndex>0</wpml:useGlobalPayloadLensIndex>
            </wpml:actionActuatorFuncParam>
          </wpml:action>
        </wpml:actionGroup>
      </Placemark>"""

    action_template = """
        <wpml:actionGroup>
          <wpml:actionGroupId>{index}</wpml:actionGroupId>
          <wpml:actionGroupStartIndex>{index}</wpml:actionGroupStartIndex>
          <wpml:actionGroupEndIndex>{index}</wpml:actionGroupEndIndex>
          <wpml:actionGroupMode>sequence</wpml:actionGroupMode>
          <wpml:actionTrigger>
            <wpml:actionTriggerType>reachPoint</wpml:actionTriggerType>
          </wpml:actionTrigger>
          <wpml:action>
            <wpml:actionId>0</wpml:actionId>
            <wpml:actionActuatorFunc>takePhoto</wpml:actionActuatorFunc>
            <wpml:actionActuatorFuncParam>
              <wpml:fileSuffix>photo_{index}</wpml:fileSuffix>
              <wpml:payloadPositionIndex>0</wpml:payloadPositionIndex>
              <wpml:useGlobalPayloadLensIndex>0</wpml:useGlobalPayloadLensIndex>
            </wpml:actionActuatorFuncParam>
          </wpml:action>
        </wpml:actionGroup>"""

    index = 0  # running wpml:index over all point Placemarks

    # Process cruise lines: start point, end point, then the photo-trigger line.
    cruise_placemarks = []
    for line_id, (start_wp, end_wp) in enumerate(cruise_lines):
        for wp in (start_wp, end_wp):
            cruise_placemarks.append(waypoint_template.format(
                name=f"{index}",
                style="#cruise_style",
                index=index,
                lat=wp['lat'],
                lon=wp['lon'],
                alt=wp['alt'],
                action_group=""
            ))
            index += 1

        cruise_placemarks.append(line_template.format(
            start_lon=start_wp['lon'],
            start_lat=start_wp['lat'],
            end_lon=end_wp['lon'],
            end_lat=end_wp['lat'],
            index=line_id,
            start_index=index - 2,
            end_index=index - 1,
            distance_interval=distance_interval
        ))

    # Process descent waypoints, each with its own reach-point photo action.
    descent_placemarks = []
    for waypoint in descent_waypoints:
        descent_placemarks.append(waypoint_template.format(
            name=f"{index}",
            style="#descent_style",
            index=index,
            lat=waypoint['lat'],
            lon=waypoint['lon'],
            alt=waypoint['alt'],
            action_group=action_template.format(index=index)
        ))
        index += 1

    return kml_template.format(
        cruise_waypoints="\n".join(cruise_placemarks),
        descent_waypoints="\n".join(descent_placemarks)
    )


# %% Map Setup and Polygon Drawing
map_center = [50.306199002283115, -112.01492786407472]
m = Map(center=map_center, zoom=16)
m.default_style = {'cursor': 'crosshair'}
ee.Authenticate()
ee.Initialize(project='haider-mirai')

# Add Sentinel-2 imagery (median composite of 2023).
# COPERNICUS/S2 is deprecated. From processing baseline 04.00 (Jan 2022) its
# values carry a +1000 offset, which washes out the 0-3000 stretch below for
# 2023 scenes. S2_HARMONIZED shifts newer scenes back to the original range.
sentinel = (ee.ImageCollection("COPERNICUS/S2_HARMONIZED")
            .filterDate('2023-01-01', '2023-12-31')
            .median())
sentinel_rgb = sentinel.select(['B4', 'B3', 'B2'])
map_id_dict = sentinel_rgb.getMapId({'min': 0, 'max': 3000, 'bands': ['B4', 'B3', 'B2']})
tile_url = map_id_dict['tile_fetcher'].url_format
tile_layer = TileLayer(url=tile_url, name='Sentinel-2', attribution='Google Earth Engine')
m.add_layer(tile_layer)
m.add_control(LayersControl())

# Global variables for user clicks (stored as (lat, lon)) and drawing state
user_clicks = []
collecting = True  # map clicks add vertices only while this is True

# Polygon layer with 30% opacity
polygon_layer = iPolygon(locations=[], color="red", fill_color="red", fill_opacity=0.3)
m.add_layer(polygon_layer)

def handle_map_click(**kwargs):
    """
    Map interaction callback that records a clicked vertex of the field.

    Only ``type == 'click'`` events that carry ``coordinates`` (lat, lon) are
    used, and only while ``collecting`` is True. A click equal to the
    previous vertex is ignored. Each accepted vertex gets a pushpin marker
    and triggers a redraw of the polygon layer. Once there are two or more
    points the polygon is closed back to the first vertex. A progress line
    is printed for each accepted click.
    """
    global collecting, user_clicks
    if not collecting:
        return

    coords = kwargs.get('coordinates')  # (lat, lon)
    if kwargs.get('type') != 'click' or not coords:
        return
    if user_clicks and coords == user_clicks[-1]:
        return  # duplicate of the last vertex

    user_clicks.append(coords)
    pin_icon = Icon(
        icon_url="http://maps.google.com/mapfiles/kml/pushpin/ylw-pushpin.png",
        icon_size=[32, 32],
        icon_anchor=[16, 32]
    )
    m.add_layer(Marker(location=coords, icon=pin_icon))

    polygon_layer.locations = (
        user_clicks + [user_clicks[0]] if len(user_clicks) > 1 else user_clicks
    )
    print(f"Clicked: {coords} (Total: {len(user_clicks)})")

# Route every map interaction to the polygon-drawing handler.
m.on_interaction(handle_map_click)

# %% Control Buttons and Status
done_button = Button(button_style='success', description="Done")
reset_button = Button(description="Reset")
status_html = HTML(value="<i>Click on the map to add points. Then press Done.</i>")

# Field area in km²; set by finalize_polygon, None until a field is accepted.
total_field_area = None

def finalize_polygon(btn):
    """
    "Done" button handler: validate the drawn field and report its area.

    At least three vertices and a valid (non-self-intersecting) shape are
    required. On failure an error is shown and point collection stays
    enabled, so the user can keep clicking or press Reset. On success, point
    collection stops, ``total_field_area`` is set in km², and the
    "Generate KMZ" button is enabled.

    Args:
        btn: the clicked Button (unused).
    """
    global collecting, total_field_area
    if len(user_clicks) < 3:
        status_html.value = "<b style='color:red'>At least 3 points required!</b>"
        return

    # Field coordinates in (lat, lon)
    field_coords = user_clicks

    # For Shapely, convert to (lon, lat)
    poly = ShapelyPolygon([(lon, lat) for lat, lon in field_coords])
    if not poly.is_valid:
        status_html.value = "<b style='color:red'>Invalid polygon shape!</b>"
        return

    collecting = False

    # Calculate area using Earth Engine for accuracy
    area_note = ""
    try:
        ee_polygon = ee.Geometry.Polygon([[lon, lat] for lat, lon in field_coords])
        area_m2 = ee_polygon.area().getInfo()
        total_field_area = area_m2 / 1e6  # Convert to km²
    except Exception:
        # Rough planar fallback: square degrees -> km² (~111 km per degree),
        # shrinking longitude by cos(latitude) at the field's centroid.
        total_field_area = poly.area * 111**2 * math.cos(math.radians(poly.centroid.y))
        area_note = "<br><i>Earth Engine error: Using rough area estimate.</i>"

    # The fallback note is appended here so it is not overwritten.
    status_html.value = (f"<b>Field Area: {total_field_area:.2f} km²</b><br>"
                         f"Configure flight parameters below" + area_note)
    generate_btn.disabled = False

def reset_polygon(btn):
    """
    "Reset" button handler that clears the drawn field and starts over.

    Re-enables point collection, empties the stored vertices and the polygon
    layer, and removes every Marker layer from the map. It then shows a reset
    message and disables the "Generate KMZ" button.

    Args:
        btn: the clicked Button (unused).
    """
    global collecting, user_clicks
    collecting = True
    user_clicks.clear()
    polygon_layer.locations = []

    markers = [layer for layer in m.layers if isinstance(layer, Marker)]
    for marker in markers:
        m.remove_layer(marker)

    status_html.value = "<i>Polygon reset. Click on the map to add new points.</i>"
    generate_btn.disabled = True

# Hook up the drawing buttons and lay them out side by side.
reset_button.on_click(reset_polygon)
done_button.on_click(finalize_polygon)
buttons_box = HBox(children=[done_button, reset_button])

# %% Flight Planning Parameters and Additional Widgets
generate_btn = Button(
    description="Generate KMZ",
    button_style='warning',
    icon='rocket',
    disabled=True  # enabled by finalize_polygon once a valid field is drawn
)
output = Output()


# Child order is relied on elsewhere by index:
# [0] header, [1] cruise altitude, [2] descent altitude, [3] cruise overlap,
# [4] descent overlap, [5] drone model, [6] mission preset.
# The cruise slider tops out at the 120 m legal ceiling that
# update_drone_limits and validate_altitude also enforce, so the slider never
# offers an altitude that would be silently clamped.
flight_params_box = VBox([
    widgets.HTML("<b>Flight Planning Parameters:</b>"),
    FloatSlider(value=100.0, min=20.0, max=120.0, step=5.0,
                description='Cruise Alt (m):', style={'description_width': 'initial'}),
    FloatSlider(value=20.0, min=5.0, max=100.0, step=5.0,
                description='Descent Alt (m):', style={'description_width': 'initial'}),
    FloatSlider(value=0.75, min=0.0, max=1.0, step=0.05,
                description='Cruise Overlap:', style={'description_width': 'initial'}),
    FloatSlider(value=0.15, min=0.0, max=1.0, step=0.05,
                description='Descent Overlap:', style={'description_width': 'initial'}),
    Dropdown(
        options=["DJI Mavic 3", "Parrot Bluegrass"],
        value="DJI Mavic 3",
        description="Drone Model:"
    ),
    Dropdown(options=list(mission_presets.keys()), value="Plant Counting", description="Mission Preset:")
], layout=Layout(margin='20px 0 0 0'))
def update_preset(change):
    """
    Apply the selected mission preset's altitude and overlap to the cruise sliders.

    Observer for the mission-preset dropdown (``flight_params_box.children[6]``).

    Args:
        change: Traitlets change notification; ``change.new`` is the newly
            selected key of ``mission_presets``.
    """
    preset = mission_presets[change.new]
    # children[0] is the header HTML widget, so the sliders start at index 1:
    # [1] cruise altitude, [2] descent altitude, [3] cruise overlap
    # (the same indices generate_flight_plan reads).
    cruise_alt_slider = flight_params_box.children[1]
    cruise_overlap_slider = flight_params_box.children[3]

    cruise_alt_slider.value = preset['alt']
    cruise_overlap_slider.value = preset['overlap']

# Look up the two dropdowns by their position in flight_params_box:
# index 5 is the drone model, index 6 the mission preset.
drone_dropdown = flight_params_box.children[5]
mission_preset_dropdown = flight_params_box.children[6]
# React to selection changes: a new drone re-limits the altitude sliders,
# a new preset pushes its values into the sliders.
drone_dropdown.observe(update_drone_limits, names='value')
mission_preset_dropdown.observe(update_preset, names='value')


def generate_flight_plan(btn):
    """Generate the KMZ flight plan for the drawn field and report coverage metrics.

    Click handler for ``generate_btn``. It reads the flight parameters from
    ``flight_params_box`` by index and clamps both altitudes with
    ``validate_altitude``. It then builds cruise (lawnmower) and descent
    waypoints, writes them to ``flight_plan.kmz`` and appends area, coverage
    and flight-time figures to ``status_html``. The waypoints are also drawn
    on the map ``m``.

    Polygon-validation errors are printed into ``output``. Errors raised while
    planning are caught and shown in red in ``status_html``.

    Args:
        btn: The clicked button (unused; required by ``Button.on_click``).
    """
    with output:
        output.clear_output()
        if len(user_clicks) < 3:
            print("Error: No valid polygon drawn!")
            return

        field_coords = user_clicks

        # flight_params_box.children: [0] header HTML, [1] cruise alt,
        # [2] descent alt, [3] cruise overlap, [4] descent overlap,
        # [5] drone model, [6] mission preset.
        params = flight_params_box.children
        drone_model = params[5].value
        cruise_alt = validate_altitude(params[1].value, drone_model)
        descent_alt = validate_altitude(params[2].value, drone_model)
        cruise_horizontal_overlap = params[3].value
        cruise_vertical_overlap = 0.65  # Vertical overlap is fixed at 65%.
        # Fraction of the field the descent photos should cover; also used
        # for the reference-area metric below.
        # NOTE(review): the "Descent Overlap" slider (params[4]) and the
        # mission preset (params[6]) are not consumed by this plan. Confirm
        # whether the fixed 0.15 here is meant to come from params[4].
        target_coverage_percentage = 0.15

        try:
            # Cruise pass: lawnmower pattern with the requested overlaps.
            cruise_waypoints = generate_lawnmower_waypoints(
                field_coords, cruise_alt, cruise_horizontal_overlap, cruise_vertical_overlap)
            # Descent pass: points targeting a fraction of the field.
            descent_waypoints = generate_descent_waypoints(
                field_coords, descent_alt, target_coverage_percentage)

            # Build the DJI-compatible KML and package it as KMZ.
            kml_content = create_dji_compatible_kml(
                field_coords, cruise_waypoints, descent_waypoints,
                cruise_alt, cruise_horizontal_overlap)
            kmz_filename = "flight_plan.kmz"
            create_kmz(kml_content, kmz_filename)

            status_html.value += (
                f"<br><b style='color:green'>KMZ generated successfully!</b><br>"
                f"Saved as: <code>{kmz_filename}</code>"
            )

            field_area = calculate_field_area(field_coords) / 1e6  # Convert to km²
            status_html.value += f"<br>Total Field Area: {field_area:.2f} km²"

            # calculate_photo_coverage returns a pair whose product is the
            # per-photo ground area. Compute it once per altitude.
            cruise_footprint = calculate_photo_coverage(drone_model, cruise_alt)
            descent_footprint = calculate_photo_coverage(drone_model, descent_alt)
            cruise_coverage = cruise_footprint[0] * cruise_footprint[1]
            descent_coverage_per_photo = descent_footprint[0] * descent_footprint[1]
            num_descent_points = len(descent_waypoints)
            total_descent_coverage = num_descent_points * descent_coverage_per_photo / 1e6  # km²

            status_html.value += f"<br>Cruise Photo Coverage: {cruise_coverage/1e6:.4f} km² per photo"
            status_html.value += f"<br>Descent Photo Coverage: {descent_coverage_per_photo/1e6:.4f} km² per photo"
            status_html.value += f"<br>Total Descent Coverage: {total_descent_coverage:.4f} km²"

            # Reference area the descent pass is targeting (15% of the field).
            target_area = field_area * target_coverage_percentage
            status_html.value += f"<br>{target_coverage_percentage:.0%} of Field Area: {target_area:.2f} km²"

            status_html.value += f"<br>Number of Descent Points Generated: {num_descent_points}"

            show_waypoints(cruise_waypoints, m)
            show_waypoints(descent_waypoints, m)

            flight_time = estimate_flight_time(cruise_waypoints + descent_waypoints, speed=8)
            status_html.value += f"<br>Estimated Flight Time: {flight_time:.2f} hours"

        except Exception as e:
            # UI boundary: surface any planning failure to the user instead of
            # letting the widget callback fail silently.
            status_html.value += f"<br><b style='color:red'>Error: {e}</b>"


generate_btn.on_click(generate_flight_plan)

# %% Assemble Full Interface and Display
# Top-to-bottom layout: map, polygon buttons, status line, flight parameters,
# KMZ button, and the output area used by generate_flight_plan.
full_interface = VBox(children=[m, buttons_box, status_html,
                                flight_params_box, generate_btn, output])
display(full_interface)

VBox(children=(Map(center=[50.306199002283115, -112.01492786407472], controls=(ZoomControl(options=['position'…

Clicked: [50.30752833806075, -112.01658010482788] (Total: 1)
Clicked: [50.305637106915896, -112.01969146728517] (Total: 2)
Clicked: [50.30539041904784, -112.01226711273195] (Total: 3)
Clicked: [50.30836429307495, -112.01460599899293] (Total: 1)
Clicked: [50.308282068643216, -112.01814651489259] (Total: 2)
Clicked: [50.30722684248275, -112.01891899108888] (Total: 3)
Clicked: [50.30729536437286, -112.00997114181519] (Total: 4)
