# In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import rasterio
import rasterio.plot
from rasterio.transform import rowcol
import pyproj
from shapely.geometry import Point, box
from shapely.ops import transform as shapely_transform
import json
import os
import time
from datetime import datetime
import random
import math
import warnings

# --- Configuration ---
# Folder holding the precomputed maps/metadata; the output CSV is written here too.
OUTPUT_FOLDER = r"C:\work\ETS_mirza\distributed_montreal"

# File names, resolved relative to OUTPUT_FOLDER.
METADATA_FILENAME = "probability_map_metadata.json"
PROB_MAP_FILENAME = "conditional_probability_map.npy"
AVG_METRICS_MAP_FILENAME = "average_thermal_metrics.npz"
OUTPUT_CSV_FILENAME = "phase2_estimation_output_mc.csv"

# Absolute input paths.
LAND_COVER_PATH = r"C:\work\ETS_mirza\distributed_montreal\dataset\landcover-2020-classification.tif"
REAL_TIME_CSV_PATH = r"C:\work\ETS_mirza\distributed_montreal\real_time.csv"  # assumed input file

# --- Parameters ---
# Minimum P(Thermal | Context) at the sampled cell before its metrics are reported.
PROBABILITY_THRESHOLD = 0.5
# Battery percentage below which the UAV is considered to need energy (soaring).
ENERGY_NEED_THRESHOLD = 80
# Intended number of random samples per estimation cycle.
# NOTE(review): the estimation loop currently draws a single weighted sample.
MC_SAMPLES = 100
# Half-width, in metres, of the square search area centred on the UAV.
AOI_SIZE_Meters = 500

# --- Load Metadata and Maps ---
# Every load failure below aborts with exit status 1 so that a calling shell or
# scheduler sees the run as failed (a bare exit() would report success).
METADATA_PATH = os.path.join(OUTPUT_FOLDER, METADATA_FILENAME)
PROB_MAP_PATH = os.path.join(OUTPUT_FOLDER, PROB_MAP_FILENAME)
AVG_METRICS_MAP_PATH = os.path.join(OUTPUT_FOLDER, AVG_METRICS_MAP_FILENAME)

print(f"Loading metadata from: {METADATA_PATH}")
try:
    with open(METADATA_PATH, 'r', encoding='utf-8') as f:
        metadata = json.load(f)
    grid_crs_str = metadata['grid_crs']
    grid_height = metadata['grid_dimensions']['height']
    grid_width = metadata['grid_dimensions']['width']
    # The affine transform is stored in GDAL order (c, a, b, f, d, e).
    grid_affine = rasterio.Affine.from_gdal(*metadata['grid_affine_transform_gdal'])
    grid_resolution = metadata['grid_resolution']
    map_shape = tuple(metadata['probability_map_shape'])
    avg_metrics_shape = tuple(metadata['average_metrics_map_shape'])
    season_map = metadata['context_mappings']['season']
    tod_map = metadata['context_mappings']['time_of_day']
    # JSON object keys are always strings; land-cover codes are looked up as ints.
    lc_code_to_idx_map = {
        int(k): v for k, v in metadata['context_mappings']['land_cover_code_to_index'].items()
    }
    season_map_inv = {v: k for k, v in season_map.items()}
    tod_map_inv = {v: k for k, v in tod_map.items()}
    lc_idx_to_code_map = {v: k for k, v in lc_code_to_idx_map.items()}
    # Land-cover code -> display name, reconstructed by hand (not read from the metadata).
    # NOTE(review): the trailing "Unknown": 0 entry is inverted relative to the others
    # (name -> code); kept unchanged in case other code relies on it.
    land_cover_names_map = {
        1: "Needleleaf Forest", 2: "Taiga Forest", 5: "Broadleaf Forest", 6: "Mixed Forest",
        8: "Shrubland", 10: "Grassland", 11: "Shrubland-Lichen-Moss", 12: "Grassland-Lichen-Moss",
        13: "Barren-Lichen-Moss", 14: "Wetland", 15: "Cropland", 16: "Barren lands",
        17: "Urban", 18: "Water", 19: "Snow/Ice", 0: "NoData", "Unknown": 0
    }
except Exception as e:
    print(f"Error loading or processing metadata JSON file: {e}")
    raise SystemExit(1)

print(f"Loading probability map from: {PROB_MAP_PATH}")
try:
    prob_map = np.load(PROB_MAP_PATH)
    if tuple(map_shape) != prob_map.shape:
        print(f"Warning: Prob map shape {prob_map.shape} differs from metadata {map_shape}")
except Exception as e:
    print(f"Error loading probability map NumPy file: {e}")
    raise SystemExit(1)

print(f"Loading average metrics map from: {AVG_METRICS_MAP_PATH}")
try:
    # np.load on an .npz returns a lazily-reading NpzFile that keeps the archive
    # open and re-reads an array on every ['key'] access. Materialize all arrays
    # once into a plain dict and close the archive.
    with np.load(AVG_METRICS_MAP_PATH) as npz:
        avg_metrics_data = {name: npz[name] for name in npz.files}
    if tuple(avg_metrics_shape) != avg_metrics_data['avg_lift'].shape:  # Check one metric shape
        print("Warning: Avg metrics map shape differs from metadata")
except Exception as e:
    print(f"Error loading average metrics NPZ file: {e}")
    raise SystemExit(1)

print(f"Loading Land Cover Raster: {LAND_COVER_PATH}")
lc_raster = None
try:
    lc_raster = rasterio.open(LAND_COVER_PATH)
    source_crs_raster = lc_raster.crs
    # Compare CRS objects rather than their string forms, so that equivalent
    # spellings (e.g. an EPSG code vs. its WKT) are not reported as a mismatch.
    grid_crs = rasterio.crs.CRS.from_user_input(grid_crs_str)
    crs_mismatch = source_crs_raster is None or source_crs_raster != grid_crs
except Exception as e:
    print(f"Error opening land cover file: {e}")
    if lc_raster is not None:
        lc_raster.close()
    raise SystemExit(1)
if crs_mismatch:
    print(f"ERROR: Land Cover CRS {source_crs_raster} does not match grid CRS {grid_crs_str} from metadata!")
    lc_raster.close()
    raise SystemExit(1)

# --- Load Real-Time Flight Data ---
print(f"Loading flight data from: {REAL_TIME_CSV_PATH}")
try:
    flight_df = pd.read_csv(REAL_TIME_CSV_PATH, parse_dates=['timestamp_utc'])
    flight_df.sort_values(by='timestamp_utc', inplace=True)
    # Build point geometries so the positions can be reprojected with to_crs.
    flight_gdf = gpd.GeoDataFrame(
        flight_df,
        geometry=gpd.points_from_xy(flight_df.longitude, flight_df.latitude),
        crs="EPSG:4326",  # Assume WGS84 input
    )
    # Transform flight path coordinates to the grid's CRS so they can be indexed via grid_affine.
    print(f"Transforming flight path coordinates to {grid_crs_str}...")
    flight_gdf = flight_gdf.to_crs(grid_crs_str)
    flight_gdf['easting_grid'] = flight_gdf.geometry.x
    flight_gdf['northing_grid'] = flight_gdf.geometry.y
    print(f"Loaded and transformed {len(flight_gdf)} flight data points.")
except Exception as e:
    print(f"Error loading or processing real-time CSV file: {e}")
    raise SystemExit(1)


# --- Helper Functions ---
def get_season_idx(month, *, season_mapping=None):
    """Return the season index for a calendar month, or None outside modelled seasons.

    June-August map to ``"Summer"`` and May/September to ``"Spring/Fall"``;
    every other month returns None.

    Args:
        month: Calendar month number (1-12).
        season_mapping: Season name -> index mapping. Defaults to the
            module-level ``season_map`` loaded from the metadata file.

    Raises:
        KeyError: if the mapping lacks the selected season name.
    """
    if month in (6, 7, 8):
        name = "Summer"
    elif month in (5, 9):
        name = "Spring/Fall"
    else:
        return None
    mapping = season_map if season_mapping is None else season_mapping
    return mapping[name]


def get_tod_idx(hour, *, tod_mapping=None):
    """Return the time-of-day index for an hour, or None outside 10:00-18:00.

    [10, 12) -> ``"Morning"``, [12, 16) -> ``"Afternoon"`` and
    [16, 18) -> ``"Late Afternoon"``.

    Args:
        hour: Hour of day (0-23).
        tod_mapping: Time-of-day name -> index mapping. Defaults to the
            module-level ``tod_map`` loaded from the metadata file (this
            previously referenced the undefined name ``time_of_day_map``).

    Raises:
        KeyError: if the mapping lacks the selected time-of-day name.
    """
    if 10 <= hour < 12:
        name = "Morning"
    elif 12 <= hour < 16:
        name = "Afternoon"
    elif 16 <= hour < 18:
        name = "Late Afternoon"
    else:
        return None
    mapping = tod_map if tod_mapping is None else tod_mapping
    return mapping[name]

# --- Phase 2 Simulation Loop ---
# For every flight-log row: derive the (season, time-of-day) context, draw one
# candidate cell from the AOI around the UAV with probability proportional to
# the conditional thermal probability map, and report that cell's average
# thermal metrics when the UAV needs energy and the cell's probability exceeds
# PROBABILITY_THRESHOLD.
# NOTE(review): MC_SAMPLES is not used; a single weighted sample is drawn per row.

# Output column -> name of the corresponding array in avg_metrics_data.
AVG_METRIC_COLUMNS = {
    "avg_lift_mps": "avg_lift",
    "avg_radius_m": "avg_radius",
    "avg_height_m_agl": "avg_height",
    "avg_duration_min": "avg_duration",
}
_avg_metric_arrays = {}


def _avg_metric_array(name):
    """Return ``avg_metrics_data[name]``, fetching it at most once.

    If avg_metrics_data is the NpzFile returned by ``np.load``, each
    ``[name]`` access reads the array from the archive again. The result is
    therefore cached here instead of being re-read for every detection.
    """
    if name not in _avg_metric_arrays:
        _avg_metric_arrays[name] = avg_metrics_data[name]
    return _avg_metric_arrays[name]


def _sample_lc_code(x, y):
    """Return the land-cover code at grid-CRS point (x, y); 0 for NoData or on any read failure."""
    try:
        code = next(lc_raster.sample([(x, y)], indexes=1))[0]
    except Exception:  # IndexError outside raster coverage, or other read errors: best effort
        return 0
    return 0 if code == lc_raster.nodata else code


def _local_probabilities(row_min, row_max, col_min, col_max, season_idx, tod_idx):
    """Build the AOI probability grid for the inclusive window [row_min..row_max] x [col_min..col_max].

    Returns ``(probs, lc_index_by_cell)``:

    * ``probs[r, c]`` (float32, AOI-local indices) is
      ``prob_map[row, col, season_idx, tod_idx, lc_idx]``, where ``lc_idx``
      comes from the land cover sampled at the cell centre. It is 0 when the
      land-cover code has no index, the lookup is out of bounds, or the
      stored value is not finite.
    * ``lc_index_by_cell`` maps ``(r, c)`` to the land-cover index used for
      every cell that received a finite probability.
    """
    probs = np.zeros((row_max - row_min + 1, col_max - col_min + 1), dtype=np.float32)
    lc_index_by_cell = {}
    # One raster sample per cell; may be slow for large AOIs.
    for r_local, r_global in enumerate(range(row_min, row_max + 1)):
        for c_local, c_global in enumerate(range(col_min, col_max + 1)):
            # Land cover is sampled at the grid-cell centre (approximation).
            cell_x, cell_y = grid_affine * (c_global + 0.5, r_global + 0.5)
            lc_idx = lc_code_to_idx_map.get(_sample_lc_code(cell_x, cell_y))
            if lc_idx is None:
                continue
            try:
                prob = prob_map[r_global, c_global, season_idx, tod_idx, lc_idx]
            except IndexError:
                continue
            # A NaN would poison the AOI sum and make every cell unsampleable; treat it as 0.
            if np.isfinite(prob):
                probs[r_local, c_local] = prob
                lc_index_by_cell[(r_local, c_local)] = lc_idx
    return probs, lc_index_by_cell


def _lookup_avg_metrics(indices):
    """Return {output column: average metric at *indices*}, with NaN values replaced by None."""
    metrics = {}
    for column, array_name in AVG_METRIC_COLUMNS.items():
        value = _avg_metric_array(array_name)[indices]
        metrics[column] = None if np.isnan(value) else value
    return metrics


def _estimate_step(row):
    """Run one estimation cycle for a single flight-log row and return its result record.

    Rows that cannot be evaluated yield a short record with
    ``thermal_detected=False`` plus an ``error`` or ``comment`` entry.
    """
    current_time = row['timestamp_utc']
    current_easting = row['easting_grid']
    current_northing = row['northing_grid']
    current_battery = row['battery_percent']  # Assumed column name

    # NOTE(review): month/hour come from the UTC timestamp. If the time-of-day
    # bins of the probability map are in local time (Montreal is UTC-5/UTC-4),
    # convert before binning -- confirm against the map builder.
    season_idx = get_season_idx(current_time.month)
    tod_idx = get_tod_idx(current_time.hour)
    if season_idx is None or tod_idx is None:
        # Outside the seasons/hours for which thermals are modelled.
        return {"timestamp_utc": current_time, "thermal_detected": False}

    # Square AOI of half-width AOI_SIZE_Meters: map its top-left and
    # bottom-right corners to grid (row, col) indices via the grid affine.
    try:
        row_min_aoi, col_min_aoi = rasterio.transform.rowcol(
            grid_affine, current_easting - AOI_SIZE_Meters, current_northing + AOI_SIZE_Meters)
        row_max_aoi, col_max_aoi = rasterio.transform.rowcol(
            grid_affine, current_easting + AOI_SIZE_Meters, current_northing - AOI_SIZE_Meters)
    except IndexError:  # UAV outside grid defined by metadata
        return {"timestamp_utc": current_time, "thermal_detected": False, "error": "Outside defined grid"}
    except Exception as e:
        return {"timestamp_utc": current_time, "thermal_detected": False, "error": f"AOI Index Error: {e}"}

    # Clamp to the grid; an empty window means the AOI does not overlap it.
    row_min_aoi = max(0, row_min_aoi)
    col_min_aoi = max(0, col_min_aoi)
    row_max_aoi = min(grid_height - 1, row_max_aoi)
    col_max_aoi = min(grid_width - 1, col_max_aoi)
    if row_min_aoi > row_max_aoi or col_min_aoi > col_max_aoi:
        return {"timestamp_utc": current_time, "thermal_detected": False, "error": "Invalid AOI"}

    local_probs, lc_index_by_cell = _local_probabilities(
        row_min_aoi, row_max_aoi, col_min_aoi, col_max_aoi, season_idx, tod_idx)

    prob_sum = float(local_probs.sum(dtype=np.float64))
    if prob_sum <= 1e-9:
        # All probabilities zero/tiny: nothing meaningful to sample.
        return {"timestamp_utc": current_time, "thermal_detected": False, "comment": "Zero probability in AOI"}

    # Biased Monte Carlo sampling: draw ONE AOI cell with probability
    # proportional to its P(thermal | context). The weights are laid out in
    # row-major flattened AOI order, so the drawn flat index is mapped back
    # with unravel_index. Float64 keeps the weights summing to 1 within
    # np.random.choice's tolerance.
    sampling_weights = local_probs.ravel().astype(np.float64) / prob_sum
    try:
        chosen_flat_idx = np.random.choice(sampling_weights.size, p=sampling_weights)
    except ValueError as e:
        return {"timestamp_utc": current_time, "thermal_detected": False, "error": f"MC Sample Error: {e}"}
    except Exception as e:
        return {"timestamp_utc": current_time, "thermal_detected": False, "error": f"Unexpected MC Error: {e}"}
    r_chosen_local, c_chosen_local = (int(i) for i in np.unravel_index(chosen_flat_idx, local_probs.shape))

    r_chosen_global = row_min_aoi + r_chosen_local
    c_chosen_global = col_min_aoi + c_chosen_local
    # Use the centre of the chosen cell as the estimated thermal location.
    est_thermal_x, est_thermal_y = grid_affine * (c_chosen_global + 0.5, r_chosen_global + 0.5)
    # The chosen cell has a positive weight, so it has a recorded land-cover
    # index; reuse it instead of re-sampling the raster at the same point.
    final_lc_idx = lc_index_by_cell[(r_chosen_local, c_chosen_local)]
    final_indices = (r_chosen_global, c_chosen_global, season_idx, tod_idx, final_lc_idx)
    final_prob = prob_map[final_indices]

    # Decision logic
    needs_energy = current_battery < ENERGY_NEED_THRESHOLD
    thermal_detected = False
    avg_metrics = {}
    if needs_energy and final_prob > PROBABILITY_THRESHOLD:
        try:
            avg_metrics = _lookup_avg_metrics(final_indices)
            thermal_detected = True
        except IndexError:
            avg_metrics = {"error": "Metrics lookup failed"}
        except Exception as e:
            avg_metrics = {"error": f"Metrics lookup error: {e}"}

    result_row = {
        "timestamp_utc": current_time,
        "uav_lat": row['latitude'],
        "uav_lon": row['longitude'],
        "uav_alt_msl": row['altitude_msl'],  # Assumed column name
        "uav_battery_percent": current_battery,
        "needs_energy": needs_energy,
        "est_thermal_prob": final_prob,
        "prob_threshold": PROBABILITY_THRESHOLD,
        "thermal_detected": thermal_detected,
        "est_thermal_easting": est_thermal_x if thermal_detected else None,
        "est_thermal_northing": est_thermal_y if thermal_detected else None,
    }
    # Metrics on success, the lookup error on failure (empty when no lookup was attempted).
    result_row.update(avg_metrics)
    return result_row


print("Starting Phase 2: Real-time estimation simulation...")
start_process_time = time.time()
results = []

with warnings.catch_warnings():
    # Scoped filter, restored on exit; warnings.resetwarnings() would instead
    # wipe every process-wide filter.
    warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)
    try:
        for _, flight_row in flight_gdf.iterrows():
            results.append(_estimate_step(flight_row))
    finally:
        lc_raster.close()  # release the land-cover raster even if the loop fails

print(f"Finished processing {len(flight_df)} flight data points.")
end_process_time = time.time()
print(f"Processing took {end_process_time - start_process_time:.2f} seconds.")

# --- Save Output CSV ---
if results:
    output_df = pd.DataFrame(results)
    output_path = os.path.join(OUTPUT_FOLDER, OUTPUT_CSV_FILENAME)
    print(f"Saving estimation results to: {output_path}")
    try:
        output_df.to_csv(output_path, index=False, float_format='%.6g')
        print("Save successful.")
    except Exception as e:
        print(f"Error saving output CSV: {e}")
else:
    print("No results generated.")

# --- Cleanup ---
# The raster was closed and the warning filter restored right after the loop.
print("Script finished.")