🌍 POST-DEFORESTATION LAND USE PREDICTION SYSTEM
===========================================

Author: Robert Masolele, Wageningen University

Date: 2025

Version: 1.0

OVERVIEW:
---------
This notebook provides an interactive tool for predicting land use following
deforestation using region-specific deep learning models trained on Sentinel-1,
Sentinel-2, and location encoding data.

Before you start, have your Google Earth Engine (GEE) Cloud project ID at hand: you will need it in the next cell (STEP 1) to initialize GEE.

FEATURES:
---------
All cells except the last one only define functions. Run all the cells, then follow the steps below using the icons on the right of the last cell's output. I advise drawing a small area (less than 10 km²) to get results quickly. If you get an error, check STEP 12 (TROUBLESHOOTING); otherwise, create an issue on my GitHub page.

🖼️ Draw or upload a Region of Interest (ROI) on an interactive map

🧠 Automatically selects the AI model based on location (Africa, Southeast Asia, Latin America)

🛰️ Downloads and preprocesses Sentinel-1 + Sentinel-2 + elevation + indices

🌾 Predicts land use categories over deforested areas only, using ONNX models

🗺️ Click to visualise a side-by-side map of RGB imagery and the follow-up land use prediction

📤 Export predictions as GeoTIFF for GIS analysis. Please follow the steps. Cheers! 😍

SUPPORTED REGIONS:
------------------
1. AFRICA: 17 input bands, 25 output classes
2. SOUTHEAST ASIA: 15 input bands, 24 output classes
3. LATIN AMERICA: 15 input bands, 22 output classes

WORKFLOW:
---------
1. Draw ROI on map → 2. Download data → 3. Run prediction → 4. Visualize results



# ## 1. INSTALLATION AND SETUP

In [None]:
## Run this cell once to install all required packages.

# Install required packages
!pip install earthengine-api geemap rasterio numpy matplotlib ipywidgets onnxruntime requests folium pyproj tqdm huggingface_hub -q

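# Authenticate and initialize Earth Engine
import ee

GEE_PROJECT_ID = 'ENTER YOUR GEE PROJECT ID'  # replace with your Google Cloud project ID

try:
    ee.Initialize(project=GEE_PROJECT_ID)
except Exception:
    print("🔐 Earth Engine authentication required...")
    ee.Authenticate()
    ee.Initialize(project=GEE_PROJECT_ID)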

print("‚úÖ Installation complete!")

# ## 2. IMPORT LIBRARIES

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import rasterio
from rasterio.windows import Window
from rasterio.plot import reshape_as_image
import cv2
import math
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Earth Engine and Geemap
import ee
import geemap
from ipywidgets import widgets, HBox, VBox, Layout, Button, Output, Dropdown
from IPython.display import display, clear_output, HTML, FileLink

# ONNX Runtime
import onnxruntime as ort

# Hugging Face
import requests
from huggingface_hub import hf_hub_download

# Coordinate transformation
from pyproj import Transformer

print("‚úÖ Libraries imported successfully!")

‚úÖ Libraries imported successfully!


# ## 3. CONFIGURATION AND CONSTANTS

In [None]:
# Region configurations
REGION_CONFIGS = {
    'Africa': {
        'model_name': 'best_weights_att_unet_lagtime_5_Fused3_2023_totalLoss6V1_without_loss_sentAfrica6.onnx',
        'input_bands': 17,
        'output_classes': 25,
        'classes': [
            'Background', 'OLSCP', 'Pasture', 'Mining',
            'OSSCP', 'Roads', 'Forest', 'Plantation_forest',
            'Coffee', 'Build_up', 'Water', 'Oil_palm', 'Rubber', 'Cocoa',
            'Avocado', 'Soy', 'Sugar', 'Maize', 'Banana', 'Pineapple',
            'Rice', 'Wood_logging', 'Cashew', 'Tea', 'Others'
        ],
        'bbox': ee.Geometry.Rectangle([-20.0, -35.0, 55.0, 40.0]),
        'color_map': plt.cm.tab20
    },
    'Southeast Asia': {
        'model_name': 'best_weights_att_unet_lagtime_5_Fused3_2023_totalLoss6V1_without_loss_sent_Southeast_Asia23.onnx',
        'input_bands': 15,
        'output_classes': 24,
        'classes': [
            'Background', 'OLSCP', 'Pasture', 'Mining',
            'OSSCP', 'Roads', 'Forest', 'Plantation_forest',
            'Coffee', 'Build_up', 'Water', 'Oil_palm', 'Rubber', 'Cocoa',
            'Clove', 'Soy', 'Sugar', 'Maize', 'Banana', 'Pineapple',
            'Rice', 'Wood_logging', 'Cashew', 'Tea'
        ],
        'bbox': ee.Geometry.Rectangle([55.0, -10.0, 150.0, 60.0]),
        'color_map': plt.cm.tab20
    },
    'Latin America': {
        'model_name': 'best_weights_att_unet_lagtime_5_Fused3_2023_totalLoss6V1_without_loss_sent_Latin_America56.onnx',
        'input_bands': 15,
        'output_classes': 22,
        'classes': [
            'Background', 'OLSCP', 'Pasture', 'Mining',
            'OSSCP', 'Roads', 'Forest', 'Plantation_forest',
            'Coffee', 'Build_up', 'Water', 'Oil_palm', 'Rubber', 'Cocoa',
            'Avocado', 'Soy', 'Sugar', 'Maize', 'Banana', 'Pineapple',
            'Rice', 'Wood_logging'
        ],
        'bbox': ee.Geometry.Rectangle([-95.0, -55.0, -30.0, 20.0]),
        'color_map': plt.cm.tab20
    }
}

# Model parameters
PATCH_SIZE = 64

# Sentinel-2 bands (Harmonized Surface Reflectance)
S2_BANDS = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B11', 'B12']
S1_BANDS = ['VV', 'VH']

# Dates for compositing (adjust based on your needs)
START_DATE = '2024-01-01'
END_DATE = '2024-12-31'

# Output directories
os.makedirs('models', exist_ok=True)
os.makedirs('downloads', exist_ok=True)
os.makedirs('predictions', exist_ok=True)

print("‚úÖ Configuration loaded!")

‚úÖ Configuration loaded!
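The `bbox` rectangles above are what the tool uses to pick a regional model. A minimal sketch of one possible rule, assuming the ROI centroid is tested against each rectangle in dictionary order (`REGION_BBOXES` and `select_region` are hypothetical names; the notebook's own selection code lives in a later cell and may differ):

```python
# Bounding boxes copied from REGION_CONFIGS as (west, south, east, north) in degrees.
# REGION_BBOXES and select_region are hypothetical names used for illustration.
REGION_BBOXES = {
    'Africa': (-20.0, -35.0, 55.0, 40.0),
    'Southeast Asia': (55.0, -10.0, 150.0, 60.0),
    'Latin America': (-95.0, -55.0, -30.0, 20.0),
}


def select_region(lon, lat, bboxes=REGION_BBOXES):
    """Return the first region whose bounding box contains (lon, lat), or None."""
    for name, (west, south, east, north) in bboxes.items():
        if west <= lon <= east and south <= lat <= north:
            return name
    return None


print(select_region(-60.0, -5.0))  # Amazon basin -> Latin America
print(select_region(25.0, 0.0))    # Congo basin -> Africa
print(select_region(0.0, 50.0))    # Western Europe, outside all boxes -> None
```

With an `ee.Geometry` ROI, the centroid could be obtained with `roi.centroid(1).coordinates().getInfo()`, which returns `[lon, lat]`. Points on the shared 55° E edge match Africa first because of the dictionary order.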


# ## 4. HELPER FUNCTIONS

In [None]:
def load_onnx_model(region_name, cache_dir="models"):
    """
    Load ONNX model for a given region from Hugging Face.

    Args:
        region_name (str): Region name ('Africa', 'Southeast Asia', 'Latin America')
        cache_dir (str): Directory to cache downloaded models

    Returns:
        tuple: (ort.InferenceSession, dict) Loaded ONNX model session and config
    """
    print(f"üì• Loading model for region: {region_name}")

    if region_name not in REGION_CONFIGS:
        raise ValueError(f"Unknown region: {region_name}. Choose from {list(REGION_CONFIGS.keys())}")

    config = REGION_CONFIGS[region_name].copy()
    filename = config.get('model_name')

    if not filename:
        raise ValueError(f"No model filename specified for region: {region_name}")

    print(f"Model filename: {filename}")

    # Ensure cache directory exists
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, filename)

    # Try multiple sources for the model
    model_loaded = False

    # Source 1: Check if already downloaded
    if os.path.exists(model_path):
        print(f"‚úì Found model locally at: {model_path}")
        model_loaded = True

    # Source 2: Try Hugging Face (local_dir places the file at cache_dir/filename,
    # so the local check above finds it on the next run instead of re-downloading)
    if not model_loaded:
        try:
            print("Attempting to download from Hugging Face...")
            model_path = hf_hub_download(
                repo_id="Masolele/deforestwatch-models",
                filename=filename,
                repo_type="dataset",
                local_dir=cache_dir
            )
            print(f"✓ Downloaded from Hugging Face to: {model_path}")
            model_loaded = True
        except Exception as e:
            print(f"Could not download from Hugging Face: {e}")

    # Source 3: Try direct URL
    if not model_loaded:
        try:
            print(f"Attempting direct download...")
            url = f"https://huggingface.co/datasets/Masolele/deforestwatch-models/resolve/main/{filename}"
            response = requests.get(url, stream=True, timeout=30)
            if response.status_code == 200:
                with open(model_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                print(f"‚úì Downloaded from direct URL to: {model_path}")
                model_loaded = True
        except Exception as e:
            print(f"Direct download failed: {e}")

    if not model_loaded:
        raise FileNotFoundError(
            f"Could not find or download model {filename}. "
            f"Please download it manually from: "
            f"https://huggingface.co/datasets/Masolele/deforestwatch-models/tree/main"
        )

    # Create ONNX Runtime session
    print(f"Creating ONNX Runtime session...")
    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    # Try to use GPU if available
    providers = []
    if 'CUDAExecutionProvider' in ort.get_available_providers():
        providers.append('CUDAExecutionProvider')
        print("‚úì CUDA available for GPU acceleration")
    providers.append('CPUExecutionProvider')

    try:
        model_session = ort.InferenceSession(
            model_path,
            session_options,
            providers=providers
        )
    except Exception as e:
        print(f"Error creating session: {e}")
        print("Falling back to CPU only...")
        model_session = ort.InferenceSession(
            model_path,
            session_options,
            providers=['CPUExecutionProvider']
        )

    # Get model info
    input_info = model_session.get_inputs()[0]
    output_info = model_session.get_outputs()[0]

    print(f"‚úÖ Model loaded successfully!")
    print(f"   Input shape: {input_info.shape}")
    print(f"   Output shape: {output_info.shape}")

    # Update config with model info
    config['actual_input_shape'] = input_info.shape
    config['actual_output_shape'] = output_info.shape
    config['model_path'] = model_path

    return model_session, config
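`load_onnx_model` records `actual_input_shape`, but ONNX Runtime reports dynamic axes as strings (for example a symbolic batch dimension) or `None`. A small sketch of a sanity check against `config['input_bands']`, assuming a channels-last `(N, H, W, C)` input layout (an assumption; `input_band_count` and `check_input_bands` are hypothetical helpers):

```python
def input_band_count(input_shape):
    """Return the channel count of a channels-last (N, H, W, C) ONNX shape.

    ONNX Runtime reports dynamic axes as strings or None, so only an integer
    last dimension is treated as known; otherwise None is returned.
    """
    last = input_shape[-1]
    return last if isinstance(last, int) else None


def check_input_bands(input_shape, expected_bands):
    """Raise ValueError if a known channel count disagrees with the region config."""
    bands = input_band_count(input_shape)
    if bands is not None and bands != expected_bands:
        raise ValueError(f"Model expects {bands} input bands, config says {expected_bands}")
    return bands


print(input_band_count(['batch', 64, 64, 17]))   # -> 17
print(input_band_count(['batch', 64, 64, 'c']))  # -> None (dynamic)
```

For example, `check_input_bands(config['actual_input_shape'], config['input_bands'])` right after loading would flag a model/config mismatch before any data is downloaded.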

# ## 5. PREPROCESSING FUNCTIONS
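The SAR, elevation and coordinate normalizations in the next cell are linear min-max rescalings after clipping, while `norm_optical` applies a log transform followed by a sigmoid. A minimal sketch reproducing both for single values (`linear_rescale` and `log_sigmoid_band` are hypothetical names; the constants `1.7417` and `2.0233` are the first row of `NORM_PERCENTILES`, i.e. band B2):

```python
import numpy as np


def linear_rescale(x, lo, hi):
    """Clip x to [lo, hi] and map it linearly onto [0, 1] (as normalise_vv/vh/altitude do)."""
    x = np.clip(x, lo, hi)
    return (x - lo) / (hi - lo)


def log_sigmoid_band(reflectance, offset, scale):
    """Single-band version of norm_optical: sigmoid(5 * z - 1),
    with z = (log(0.005 * r + 1) - offset) / scale."""
    z = (np.log(reflectance * 0.005 + 1) - offset) / scale
    e = np.exp(z * 5 - 1)
    return e / (e + 1)


print(linear_rescale(np.array([-12.5]), -25, 0))       # VV of -12.5 dB -> [0.5]
print(linear_rescale(np.array([-17.5]), -30, -5))      # VH of -17.5 dB -> [0.5]
print(linear_rescale(np.array([1000.0]), -400, 8000))  # 1000 m elevation -> about 0.167
# B2 reflectance (0-10000 scale): the output rises monotonically inside (0, 1)
print(log_sigmoid_band(np.array([0.0, 500.0, 3000.0]), 1.7417, 2.0233))
```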

In [None]:
# Normalization functions
def normalise_vv(raster):
    """Normalize VV band (-25 to 0 dB)"""
    raster = np.clip(raster, -25, 0)
    return (raster + 25) / 25

def normalise_vh(raster):
    """Normalize VH band (-30 to -5 dB)"""
    raster = np.clip(raster, -30, -5)
    return (raster + 30) / 25

def normalise_longitude(raster):
    """Normalize longitude values (-180 to 180)"""
    raster = np.clip(raster, -180, 180)
    return (raster + 180) / 360

def normalise_latitude(raster):
    """Normalize latitude values (-60 to 60)"""
    raster = np.clip(raster, -60, 60)
    return (raster + 60) / 120

def normalise_altitude(raster):
    """Normalize elevation values (-400 to 8000 m)"""
    raster = np.clip(raster, -400, 8000)
    return (raster + 400) / 8400

def normalise_ndre(raster):
    """Normalize NDRE values (-1 to 1)"""
    raster = np.clip(raster, -1, 1)
    return (raster + 1) / 2

def normalise_evi(raster):
    """Normalize EVI values (-1 to 1)"""
    raster = np.clip(raster, -1, 1)
    return (raster + 1) / 2

def normalise_ndvi(raster):
    """Normalize NDVI values (-1 to 1)"""
    raster = np.clip(raster, -1, 1)
    return (raster + 1) / 2

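def norm_optical(image):
    """
    Normalize the nine Sentinel-2 bands with a log-sigmoid transform:
    z = (log(0.005 * x + 1) - offset) / scale, using the per-band (offset, scale)
    rows of NORM_PERCENTILES, then sigmoid(5 * z - 1) maps the result to (0, 1).
    """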
    NORM_PERCENTILES = np.array([
        [1.7417, 2.0233], [1.7261, 2.0389], [1.6798, 2.1796],
        [2.3829, 2.7578], [1.7417, 2.0233], [1.7417, 2.0233],
        [1.7417, 2.0233], [1.7417, 2.0233], [1.7417, 2.0233]
    ])

    image = np.log(image * 0.005 + 1)
    image = (image - NORM_PERCENTILES[:, 0]) / NORM_PERCENTILES[:, 1]
    image = np.exp(image * 5 - 1)
    image = image / (image + 1)
    return image

def extract_lat_lon(image_path):
    """
    Extract latitude and longitude values for each pixel.

    Args:
        image_path (str): Path to the georeferenced image

    Returns:
        tuple: (latitudes, longitudes) arrays
    """
    try:
        with rasterio.open(image_path) as src:
            transform = src.transform
            crs = src.crs
            height, width = src.shape

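            rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
            xs, ys = rasterio.transform.xy(transform, rows, cols)
            # Reshape to (height, width): depending on the rasterio version, xy() may return flat lists
            xs = np.array(xs).reshape(height, width)
            ys = np.array(ys).reshape(height, width)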

            if crs.to_string() != "EPSG:4326":
                transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
                longitudes, latitudes = transformer.transform(xs, ys)
            else:
                longitudes, latitudes = xs, ys

            return latitudes, longitudes
    except Exception as e:
        print(f"Warning: Could not extract lat/lon from {image_path}: {e}")
        # Return dummy coordinates
        height, width = 100, 100
        lon = np.linspace(-180, 180, width)
        lat = np.linspace(-90, 90, height)
        lon_grid, lat_grid = np.meshgrid(lon, lat)
        return lat_grid, lon_grid

def preprocess_africa(x_img, image_path=None, transform=None, crs=None):
    """
    Preprocess image for Africa model (17 bands).

    Args:
        x_img: Input image array (H, W, C)
        image_path: Path to image file (optional)
        transform: Geotransform (optional)
        crs: Coordinate reference system (optional)

    Returns:
        np.ndarray: Preprocessed image (H, W, 17)
    """
    # Extract optical bands (0-8)
    optical = x_img[:, :, :9].astype(np.float32)
    optical = np.where(optical < 0, 0, optical)
    optical_norm = norm_optical(optical)

    # SAR bands (9-10)
    vv = normalise_vv(x_img[:, :, 9].astype(np.float32))
    vh = normalise_vh(x_img[:, :, 10].astype(np.float32))

    # Elevation (11)
    alt = normalise_altitude(x_img[:, :, 11].astype(np.float32))

    # Coordinates (computed from the geotransform; a transform and CRS are required)
    if transform is None or crs is None:
        raise ValueError("preprocess_africa needs `transform` and `crs` to build the lon/lat bands")
    height, width = x_img.shape[:2]
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    xs, ys = rasterio.transform.xy(transform, rows, cols)

    # Convert to 2D arrays
    xs = np.array(xs).reshape(height, width)
    ys = np.array(ys).reshape(height, width)

    if crs.to_string() != "EPSG:4326":
        transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
        longitudes, latitudes = transformer.transform(xs, ys)
    else:
        longitudes, latitudes = xs, ys

    lon_norm = normalise_longitude(longitudes)
    lat_norm = normalise_latitude(latitudes)

    # Vegetation indices
    red_edge1 = optical[:, :, 3]
    nir = optical[:, :, 6]
    red = optical[:, :, 2]
    blue = optical[:, :, 0]

    ndvi = np.where((nir + red) == 0, 0, (nir - red) / (nir + red))
    evi = np.where((nir + red) == 0, 0, 2.5 * ((nir - red) / (nir + 6 * red - 7.5 * blue + 1)))
    ndre = np.where((nir + red_edge1) == 0, 0, (nir - red_edge1) / (nir + red_edge1))

    evi_norm = normalise_evi(evi)
    ndre_norm = normalise_ndre(ndre)

    # Ensure every single-band layer is 3D (H, W, 1) before concatenation
    h, w = x_img.shape[:2]
    ndvi = ndvi.reshape(h, w, 1)
    evi_norm = evi_norm.reshape(h, w, 1)
    ndre_norm = ndre_norm.reshape(h, w, 1)
    vv = vv.reshape(h, w, 1)
    vh = vh.reshape(h, w, 1)
    alt = alt.reshape(h, w, 1)
    lon_norm = np.asarray(lon_norm).reshape(h, w, 1)
    lat_norm = np.asarray(lat_norm).reshape(h, w, 1)

    # Concatenate all bands
    image = np.concatenate([
        optical_norm,      # 9
        ndvi,              # 1
        ndre_norm,         # 1
        evi_norm,          # 1
        vv,                # 1
        vh,                # 1
        alt,               # 1
        lon_norm,          # 1
        lat_norm           # 1
    ], axis=2)  # Total: 17 bands

    # Ensure correct number of bands
    if image.shape[2] != 17:
        if image.shape[2] < 17:
            padding = np.zeros((image.shape[0], image.shape[1], 17 - image.shape[2]))
            image = np.concatenate([image, padding], axis=2)
        else:
            image = image[:, :, :17]

    return np.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0)

def preprocess_latin_america(x_img, image_path=None, transform=None, crs=None):
    """
    Preprocess image for Latin America model (15 bands).
    """
    # Extract optical bands
    optical = x_img[:, :, :9].astype(np.float32)
    optical = np.where(optical < 0, 0, optical)
    optical_norm = norm_optical(optical)

    # SAR bands
    vv = normalise_vv(x_img[:, :, 9].astype(np.float32))
    vh = normalise_vh(x_img[:, :, 10].astype(np.float32))

    # Elevation
    alt = normalise_altitude(x_img[:, :, 11].astype(np.float32))

    # Coordinates (computed from the geotransform; a transform and CRS are required)
    if transform is None or crs is None:
        raise ValueError("preprocess_latin_america needs `transform` and `crs` to build the lon/lat bands")
    height, width = x_img.shape[:2]
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    xs, ys = rasterio.transform.xy(transform, rows, cols)

    # Convert to 2D arrays
    xs = np.array(xs).reshape(height, width)
    ys = np.array(ys).reshape(height, width)

    if crs.to_string() != "EPSG:4326":
        transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
        longitudes, latitudes = transformer.transform(xs, ys)
    else:
        longitudes, latitudes = xs, ys

    lon_norm = normalise_longitude(longitudes)
    lat_norm = normalise_latitude(latitudes)

    # NDVI only
    nir = optical[:, :, 6]
    red = optical[:, :, 2]
    ndvi = np.where((nir + red) == 0, 0, (nir - red) / (nir + red))
    #ndvi = normalise_ndvi(ndvi)

    # Ensure 3D shape
    ndvi = ndvi[:, :, np.newaxis] if ndvi.ndim == 2 else ndvi
    vv = vv[:, :, np.newaxis] if vv.ndim == 2 else vv
    vh = vh[:, :, np.newaxis] if vh.ndim == 2 else vh
    alt = alt[:, :, np.newaxis] if alt.ndim == 2 else alt
    lon_norm = lon_norm[:, :, np.newaxis] if lon_norm.ndim == 2 else lon_norm
    lat_norm = lat_norm[:, :, np.newaxis] if lat_norm.ndim == 2 else lat_norm

    # Concatenate
    image = np.concatenate([
        optical_norm,  # 9
        ndvi,          # 1
        vv,            # 1
        vh,            # 1
        alt,           # 1
        lon_norm,      # 1
        lat_norm       # 1
    ], axis=2)  # Total: 15 bands

    if image.shape[2] != 15:
        if image.shape[2] < 15:
            padding = np.zeros((image.shape[0], image.shape[1], 15 - image.shape[2]))
            image = np.concatenate([image, padding], axis=2)
        else:
            image = image[:, :, :15]

    return np.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0)

def preprocess_southeast_asia(x_img, image_path=None, transform=None, crs=None):
    """Preprocess image for Southeast Asia model (same as Latin America)."""
    return preprocess_latin_america(x_img, image_path, transform, crs)

# Add functions to REGION_CONFIGS
REGION_CONFIGS['Africa']['preprocess_function'] = preprocess_africa
REGION_CONFIGS['Southeast Asia']['preprocess_function'] = preprocess_southeast_asia
REGION_CONFIGS['Latin America']['preprocess_function'] = preprocess_latin_america

print("‚úÖ Preprocessing functions defined!")

‚úÖ Preprocessing functions defined!
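NDVI and NDRE above are normalized differences computed with `np.where`, which still performs the division for every pixel and relies on the globally suppressed warnings. A sketch of an equivalent variant (`safe_normalized_difference` is a hypothetical helper, not used by the notebook) that skips zero denominators entirely; it returns the same values, 0 where `a + b == 0`. EVI has a different denominator, so it is not covered by this helper.

```python
import numpy as np


def safe_normalized_difference(a, b):
    """(a - b) / (a + b), returning 0 where a + b == 0 without performing that division."""
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    num = a - b
    den = a + b
    # `where` skips the masked divisions; `out` supplies the 0 for those pixels
    return np.divide(num, den, out=np.zeros_like(num), where=den != 0)


nir = np.array([3000.0, 0.0, 2500.0])
red = np.array([1000.0, 0.0, 2500.0])
red_edge1 = np.array([1500.0, 0.0, 500.0])
print(safe_normalized_difference(nir, red))        # NDVI: 0.5, 0 (guarded), 0
print(safe_normalized_difference(nir, red_edge1))  # NDRE: about 0.333, 0 (guarded), about 0.667
```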


# ## 6. HELPER FUNCTIONS FOR SENTINEL-1
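`COPERNICUS/S1_GRD_FLOAT` stores backscatter on a linear scale. `lin_to_db` converts it with 10·log10, and `maskAngGT30`/`maskAngLT452` keep only incidence angles strictly between 30.63993° and 45.23993° to remove border noise. The same two rules for single values, as a pure-Python sketch (`linear_to_db`, `angle_is_valid`, `ANGLE_MIN` and `ANGLE_MAX` are hypothetical names):

```python
import math

# Incidence-angle window kept by maskAngGT30 and maskAngLT452 (degrees, exclusive)
ANGLE_MIN, ANGLE_MAX = 30.63993, 45.23993


def linear_to_db(sigma0_linear):
    """Convert linear backscatter to decibels: 10 * log10(x)."""
    return 10.0 * math.log10(sigma0_linear)


def angle_is_valid(angle_deg):
    """True if an incidence angle survives both border-noise masks."""
    return ANGLE_MIN < angle_deg < ANGLE_MAX


print(linear_to_db(0.1))    # about -10 dB
print(linear_to_db(0.01))   # about -20 dB
print(angle_is_valid(38.0), angle_is_valid(30.0))  # True False
```

For reference, the VV clipping range of `normalise_vv` (-25 to 0 dB) corresponds to linear values from about 0.0032 to 1.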

In [None]:
# Helper functions
def preproc_s1(s1_collection):
    """
    Preprocesses an S1 image collection with slope correction and edge masking

    Parameters
    ----------
    s1_collection : ee.ImageCollection
        An S1 image collection on float/amplitude format (not dB)

    Returns
    -------
    s1_collection : ee.ImageCollection
        The slope-corrected and edge-masked S1 image collection, coverted to dB scaling

    """
    # Do the slope correction
    s1_collection = slope_correction(s1_collection)

    # Mask the edge noise
    s1_collection = s1_collection.map(maskAngGT30)
    s1_collection = s1_collection.map(maskAngLT452)

    # Convert to dB
    s1_collection = s1_collection.map(lin_to_db)

    return ee.ImageCollection(s1_collection)

# Code below is adapted from adugnag/gee_s1_ard

def slope_correction(collection,
                     TERRAIN_FLATTENING_MODEL = 'VOLUME',
                     DEM = ee.Image('USGS/SRTMGL1_003'),
                     TERRAIN_FLATTENING_ADDITIONAL_LAYOVER_SHADOW_BUFFER = 0):
    """
    Apply radiometric terrain normalization (slope correction) to an S1 collection.

    Parameters
    ----------
    collection : ee.ImageCollection
        Sentinel-1 collection in linear scale, including the 'angle' band
    TERRAIN_FLATTENING_MODEL : str
        The radiometric terrain normalization model, either 'VOLUME' or 'DIRECT'
    DEM : ee.Image
        The DEM to be used
    TERRAIN_FLATTENING_ADDITIONAL_LAYOVER_SHADOW_BUFFER : int
        Additional buffer distance (in metres) around passive layover and shadow areas
    Returns
    -------
    ee.ImageCollection
        An image collection where radiometric terrain normalization is
        applied to each image
    """

    ninetyRad = ee.Image.constant(90).multiply(math.pi/180)

    def _volumetric_model_SCF(theta_iRad, alpha_rRad):
        """
        Parameters
        ----------
        theta_iRad : ee.Image
            The scene incidence angle
        alpha_rRad : ee.Image
            Slope steepness in range
        Returns
        -------
        ee.Image
            Applies the volume model in the radiometric terrain normalization
        """

        # Volume model
        nominator = (ninetyRad.subtract(theta_iRad).add(alpha_rRad)).tan()
        denominator = (ninetyRad.subtract(theta_iRad)).tan()
        return nominator.divide(denominator)

    def _direct_model_SCF(theta_iRad, alpha_rRad, alpha_azRad):
        """
        Parameters
        ----------
        theta_iRad : ee.Image
            The scene incidence angle
        alpha_rRad : ee.Image
            Slope steepness in range
        alpha_azRad : ee.Image
            Slope steepness in azimuth
        Returns
        -------
        ee.Image
            Applies the direct model in the radiometric terrain normalization
        """
        # Surface model
        nominator = (ninetyRad.subtract(theta_iRad)).cos()
        denominator = alpha_azRad.cos().multiply((ninetyRad.subtract(theta_iRad).add(alpha_rRad)).cos())
        return nominator.divide(denominator)

    def _erode(image, distance):
        """

        Parameters
        ----------
        image : ee.Image
            Image to apply the erode function to
        distance : integer
            The distance to apply the buffer
        Returns
        -------
        ee.Image
            An image that is masked to compensate for passive layover
            and shadow depending on the given distance
        """
        # buffer function (thanks Noel)

        d = (image.Not().unmask(1).fastDistanceTransform(30).sqrt()
             .multiply(ee.Image.pixelArea().sqrt()))

        return image.updateMask(d.gt(distance))

    def _masking(alpha_rRad, theta_iRad, buffer):
        """
        Parameters
        ----------
        alpha_rRad : ee.Image
            Slope steepness in range
        theta_iRad : ee.Image
            The scene incidence angle
        buffer : int
            Additional buffer distance (in metres) applied to the mask
        Returns
        -------
        ee.Image
            A mask that is 0 in layover and shadow areas, optionally
            eroded by the given buffer distance
        """
        # calculate masks
        # layover, where slope > radar viewing angle
        layover = alpha_rRad.lt(theta_iRad).rename('layover')
        # shadow
        shadow = alpha_rRad.gt(ee.Image.constant(-1)
                        .multiply(ninetyRad.subtract(theta_iRad))).rename('shadow')
        # combine layover and shadow
        mask = layover.And(shadow)
        # add buffer to final mask
        if (buffer > 0):
            mask = _erode(mask, buffer)
        return mask.rename('no_data_mask')

    def _correct(image):
        """

        Parameters
        ----------
        image : ee.Image
            Image to apply the radiometric terrain normalization to
        Returns
        -------
        ee.Image
            Radiometrically terrain corrected image
        """

        bandNames = image.bandNames()

        geom = image.geometry()
        proj = image.select(1).projection()

        elevation = DEM.resample('bilinear').reproject(proj,None, 10).clip(geom)

        # calculate the look direction
        heading = ee.Terrain.aspect(image.select('angle')).reduceRegion(ee.Reducer.mean(), image.geometry(), 1000)


        #in case of null values for heading replace with 0
        heading = ee.Dictionary(heading).combine({'aspect': 0}, False).get('aspect')

        heading = ee.Algorithms.If(
            ee.Number(heading).gt(180),
            ee.Number(heading).subtract(360),
            ee.Number(heading)
        )

        # the numbering follows the article chapters
        # 2.1.1 Radar geometry
        theta_iRad = image.select('angle').multiply(math.pi/180)
        phi_iRad = ee.Image.constant(heading).multiply(math.pi/180)

        # 2.1.2 Terrain geometry
        alpha_sRad = ee.Terrain.slope(elevation).select('slope').multiply(math.pi / 180)

        aspect = ee.Terrain.aspect(elevation).select('aspect').clip(geom)

        aspect_minus = aspect.updateMask(aspect.gt(180)).subtract(360)

        phi_sRad = aspect.updateMask(aspect.lte(180))\
            .unmask()\
            .add(aspect_minus.unmask())\
            .multiply(-1)\
            .multiply(math.pi / 180)

        #elevation = DEM.reproject(proj,None, 10).clip(geom)

        # 2.1.3 Model geometry
        # reduce to 3 angle
        phi_rRad = phi_iRad.subtract(phi_sRad)

        # slope steepness in range (eq. 2)
        alpha_rRad = (alpha_sRad.tan().multiply(phi_rRad.cos())).atan()

        # slope steepness in azimuth (eq 3)
        alpha_azRad = (alpha_sRad.tan().multiply(phi_rRad.sin())).atan()

        # 2.2
        # Gamma_nought
        gamma0 = image.divide(theta_iRad.cos())

        if (TERRAIN_FLATTENING_MODEL == 'VOLUME'):
            # Volumetric Model
            scf = _volumetric_model_SCF(theta_iRad, alpha_rRad)

        if (TERRAIN_FLATTENING_MODEL == 'DIRECT'):
            scf = _direct_model_SCF(theta_iRad, alpha_rRad, alpha_azRad)

        # apply model for Gamma0
        gamma0_flat = gamma0.multiply(scf)

        # get Layover/Shadow mask
        mask = _masking(alpha_rRad, theta_iRad, TERRAIN_FLATTENING_ADDITIONAL_LAYOVER_SHADOW_BUFFER)
        output = gamma0_flat.mask(mask).rename(bandNames).copyProperties(image)
        output = ee.Image(output).addBands(image.select('angle'), None, True)

        return output.set('system:time_start', image.get('system:time_start'))
    return collection.map(_correct)


def maskAngLT452(image):
    """
    mask out angles >= 45.23993
    Parameters
    ----------
    image : ee.Image
        image to apply the border noise masking
    Returns
    -------
    ee.Image
        Masked image
    """
    ang = image.select(['angle'])
    return image.updateMask(ang.lt(45.23993)).set('system:time_start', image.get('system:time_start'))


def maskAngGT30(image):
    """
    mask out angles <= 30.63993
    Parameters
    ----------
    image : ee.Image
        image to apply the border noise masking
    Returns
    -------
    ee.Image
        Masked image
    """

    ang = image.select(['angle'])
    return image.updateMask(ang.gt(30.63993)).set('system:time_start', image.get('system:time_start'))


def lin_to_db(image):
    """
    Convert backscatter from linear to dB.
    Parameters
    ----------
    image : ee.Image
        Image to convert
    Returns
    -------
    ee.Image
        output image
    """
    bandNames = image.bandNames().remove('angle')
    db = ee.Image.constant(10).multiply(image.select(bandNames).log10()).rename(bandNames)
    return image.addBands(db, None, True)
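For a single pixel, the terrain-flattening factors in `slope_correction` reduce to simple trigonometry: `gamma0 = sigma0 / cos(θi)` is multiplied by the volume-model factor tan(90° - θi + αr) / tan(90° - θi) or the direct-model factor cos(90° - θi) / (cos(αaz) · cos(90° - θi + αr)). A sketch in plain Python with angles in degrees (`volume_scf`, `direct_scf` and `flattened_gamma0` are hypothetical names):

```python
import math

NINETY = math.pi / 2  # 90 degrees in radians


def volume_scf(theta_i_deg, alpha_r_deg):
    """Volume model: tan(90° - θi + αr) / tan(90° - θi)."""
    t, a_r = math.radians(theta_i_deg), math.radians(alpha_r_deg)
    return math.tan(NINETY - t + a_r) / math.tan(NINETY - t)


def direct_scf(theta_i_deg, alpha_r_deg, alpha_az_deg):
    """Direct (surface) model: cos(90° - θi) / (cos(αaz) * cos(90° - θi + αr))."""
    t = math.radians(theta_i_deg)
    a_r, a_az = math.radians(alpha_r_deg), math.radians(alpha_az_deg)
    return math.cos(NINETY - t) / (math.cos(a_az) * math.cos(NINETY - t + a_r))


def flattened_gamma0(sigma0_linear, theta_i_deg, scf):
    """gamma0 = sigma0 / cos(θi), then multiplied by the slope correction factor."""
    return sigma0_linear / math.cos(math.radians(theta_i_deg)) * scf


print(volume_scf(38.0, 0.0), direct_scf(38.0, 0.0, 0.0))  # flat terrain -> 1.0 1.0
print(volume_scf(38.0, 10.0))  # tan(62°) / tan(52°), about 1.47
print(flattened_gamma0(0.05, 38.0, volume_scf(38.0, 10.0)))
```

On flat terrain (αr = αaz = 0) both factors are exactly 1, so gamma0 is left unchanged by the flattening step.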

# ## 7. EARTH ENGINE DATA DOWNLOAD
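In the next cell, `get_sentinel2_composite` masks every Sentinel-2 scene where the Cloud Score+ `cs_cdf` score is below `CLEAR_THRESHOLD` (0.40) and then takes a per-pixel median. A NumPy analogue of that idea for one band, with NaN playing the role of Earth Engine's mask (a sketch, not the Earth Engine implementation; `masked_median_composite` is a hypothetical name):

```python
import warnings

import numpy as np


def masked_median_composite(stack, cs_cdf, threshold=0.40):
    """Per-pixel median over time that ignores observations with cs_cdf < threshold.

    stack:  (T, H, W) time series of one Sentinel-2 band
    cs_cdf: (T, H, W) Cloud Score+ clear-sky scores in [0, 1]
    Pixels without any clear observation come out as NaN.
    """
    masked = np.where(cs_cdf >= threshold, stack.astype(np.float32), np.nan)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)  # all-NaN pixels
        return np.nanmedian(masked, axis=0)


# Three dates, one row, two pixels; the second pixel is cloudy on every date
stack = np.array([[[500, 800]], [[5000, 900]], [[600, 7000]]])
cs_cdf = np.array([[[0.9, 0.1]], [[0.2, 0.1]], [[0.8, 0.1]]])
# First pixel: 550 (median of the clear values 500 and 600); second pixel: NaN
print(masked_median_composite(stack, cs_cdf))
```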

In [None]:
# Cloud Score+ quality band and clear-sky threshold
QA_BAND = 'cs_cdf'
CLEAR_THRESHOLD = 0.40

def clearMask(img):
    """Resample to the B2 (10 m) grid and mask pixels with cs_cdf below CLEAR_THRESHOLD."""
    img = img.toFloat().resample('bilinear').reproject(img.select('B2').projection())
    return img.updateMask(img.select(QA_BAND).gte(CLEAR_THRESHOLD))

def get_sentinel2_composite(bbox, start_date, end_date):
    """Get cloud-free Sentinel-2 composite."""
    s2 = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
    csPlus = ee.ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED').filterDate(start_date, end_date)


    s2_masked = s2.filterBounds(bbox).filterDate(start_date, end_date).linkCollection(csPlus, [QA_BAND])\
    .map(clearMask).median().select(S2_BANDS).clip(bbox)

    return s2_masked



def get_sentinel1_composite(bbox, start_date, end_date):
    """Get Sentinel-1 composite."""
    s1 = ee.ImageCollection('COPERNICUS/S1_GRD_FLOAT')
    s1_filtered = s1.filterBounds(bbox)\
                    .filterDate(start_date, end_date)\
                    .filter(ee.Filter.eq('instrumentMode', 'IW'))

    # Preprocess the image collection
    s1_preproc = preproc_s1(s1_filtered)

    # Select the relevant bands
    s1_preproc = s1_preproc.select('VV', 'VH')

    # Create a median composite
    s1_composite = s1_preproc.median()

    return s1_composite.clip(bbox)


def get_elevation_data(bbox):
    """Get elevation data."""
    #elevation = ee.Image('USGS/SRTMGL1_003').select('elevation')
    elevation = ee.ImageCollection('COPERNICUS/DEM/GLO30').select('DEM').mosaic()
    return elevation.clip(bbox)

def get_forest_loss_data(bbox):
    """Get Hansen forest loss data."""
    loss_dataset = ee.Image('UMD/hansen/global_forest_change_2023_v1_11')
    return loss_dataset.select(['loss']).clip(bbox)

def download_region_data(bbox, region_name, start_date=START_DATE, end_date=END_DATE,
                         scale=10, output_dir='downloads'):
    """
    Download all required data for a region.

    Returns:
        str: Path to downloaded GeoTIFF
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"üì• Downloading data for {region_name}...")

    # Get all data layers
    s2 = get_sentinel2_composite(bbox, start_date, end_date)
    s1 = get_sentinel1_composite(bbox, start_date, end_date)
    elevation = get_elevation_data(bbox)
    loss = get_forest_loss_data(bbox)

    # Add longitude and latitude bands
    lonlat = ee.Image.pixelLonLat()

    # Stack all bands
    image = s2.addBands(s1)\
              .addBands(elevation)\
              .addBands(lonlat.select('longitude'))\
              .addBands(lonlat.select('latitude'))\
              .addBands(loss)\
              .int16()

    # Export
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_path = os.path.join(output_dir, f'{region_name.replace(" ", "_")}_{timestamp}.tif')

    print(f"Exporting to {output_path}...")
    geemap.ee_export_image(
        image,
        filename=output_path,
        scale=scale,
        region=bbox,
        file_per_band=False
    )

    print(f"✅ Data downloaded to {output_path}")
    return output_path

print("✅ Earth Engine functions defined!")


✅ Earth Engine functions defined!
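The advice to keep ROIs small comes from how the stack is downloaded: `geemap.ee_export_image` requests the image through Earth Engine's direct-download URL, which refuses requests above a fixed size. The sketch below estimates the uncompressed size of the int16 stack before exporting. It is a rough check under stated assumptions: the 50,331,648-byte limit is the value commonly reported in Earth Engine's "Total request size" error (not something this notebook defines), the 15-band count is assumed, and the helper names are hypothetical. `download_data` exports the ROI buffered by 500 m and expanded to its bounding box, so the exported area is larger than the drawn polygon.

```python
import math

# Assumed Earth Engine direct-download limit in bytes; check the current EE docs.
ASSUMED_DOWNLOAD_LIMIT = 50_331_648


def buffered_square_area_km2(side_km, buffer_m=500):
    """Approximate bounding-box area of a square ROI after buffering (hypothetical helper).

    Buffering by buffer_m and taking the bounds adds buffer_m on every side.
    """
    side = side_km + 2 * buffer_m / 1000.0
    return side * side


def estimate_download_bytes(area_km2, n_bands=15, scale_m=10, bytes_per_band=2):
    """Uncompressed size of an int16 stack covering area_km2 at scale_m resolution."""
    n_pixels = math.ceil(area_km2 * 1e6 / (scale_m * scale_m))
    return n_pixels * n_bands * bytes_per_band


def fits_download_limit(area_km2, **kwargs):
    """True if the estimated request stays under the assumed download limit."""
    return estimate_download_bytes(area_km2, **kwargs) <= ASSUMED_DOWNLOAD_LIMIT


area = buffered_square_area_km2(3.0)  # 3 km x 3 km ROI -> 4 km x 4 km box
print(area, estimate_download_bytes(area), fits_download_limit(area))  # 16.0 4800000 True
```

Under these assumptions a box of roughly 160 km² already approaches the limit, so the "<10 km²" recommendation leaves a comfortable margin.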

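Because `download_region_data` casts the whole stack with `.int16()`, the longitude and latitude bands lose everything after the decimal point, so at 10 m resolution thousands of neighboring pixels share the same value. If finer coordinates are needed (for example for a location encoding), they can be recomputed from the GeoTIFF's affine transform, which `predict_land_use` already passes to the preprocessing function. Whether the notebook's own preprocessing does this is not shown here; the sketch below is a hypothetical helper that assumes a north-up raster in a geographic CRS (EPSG:4326) and rasterio's `Affine(a, b, c, d, e, f)` coefficient order.

```python
import numpy as np


def pixel_center_lonlat(transform, height, width):
    """Longitude/latitude of every pixel center from affine coefficients.

    transform: (a, b, c, d, e, f) in rasterio's Affine order, i.e.
        x = a * col + b * row + c
        y = d * col + e * row + f
    Assumes a geographic CRS (EPSG:4326), so x is longitude and y is latitude.
    """
    a, b, c, d, e, f = transform
    rows, cols = np.mgrid[0:height, 0:width]
    # Add 0.5 to sample the center of each pixel rather than its top-left corner
    lon = a * (cols + 0.5) + b * (rows + 0.5) + c
    lat = d * (cols + 0.5) + e * (rows + 0.5) + f
    return lon, lat


# About 10 m pixels (0.0001 degrees) near the equator
lon, lat = pixel_center_lonlat((0.0001, 0.0, 36.82, 0.0, -0.0001, -1.29), 2, 3)
print(lon.astype(np.int16)[0])  # [36 36 36]: the int16 cast drops sub-degree detail
```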

# ## 8. PREDICTION FUNCTIONS

In [None]:
def pad_image_for_tiling(image, patch_size, padding_mode='reflect'):
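    """
    Pad an image so its height and width are multiples of patch_size.

    The padding is split as evenly as possible between top/bottom and left/right.

    Args:
        image: Input image (H, W, C)
        patch_size: Size of patches
        padding_mode: NumPy padding mode passed to np.pad (default 'reflect')

    Returns:
        tuple: (padded_image, (pad_top, pad_left))
    """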
    h, w, c = image.shape

    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size

    if pad_h > 0 or pad_w > 0:
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        padded_image = np.pad(
            image,
            ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
            mode=padding_mode
        )
        return padded_image, (pad_top, pad_left)

    return image, (0, 0)

def predict_tiled_onnx(model_session, image, patch_size=64):
    """
    Predict on large image using tiling with ONNX model.

    Args:
        model_session: ONNX Runtime session
        image: Input image (H, W, C)
        patch_size: Size of patches

    Returns:
        np.ndarray: Predictions (H, W, num_classes)
    """
    h, w, c = image.shape

    # Get model input/output info
    input_name = model_session.get_inputs()[0].name
    output_name = model_session.get_outputs()[0].name
    input_shape = model_session.get_inputs()[0].shape
    output_shape = model_session.get_outputs()[0].shape

    # Treat the model as channels-first (N, C, H, W) only if its second input
    # dimension equals the number of image bands. Dynamic axes are reported as
    # strings or None, so they never match and the model is treated as NHWC.
    transpose_needed = (len(input_shape) == 4 and input_shape[1] == c)

    # Number of classes: channel axis of the output (axis 1 for NCHW, last for NHWC)
    if len(output_shape) == 4 and transpose_needed:
        num_classes = output_shape[1]
    else:
        num_classes = output_shape[-1]

    # Initialize output
    predictions = np.zeros((h, w, num_classes), dtype=np.float32)

    # Slide the window in steps of patch_size // 8 (8 px for 64-px patches,
    # i.e. 87.5% overlap); where windows overlap, later ones overwrite earlier ones
    step = patch_size // 8

    for i in range(0, h, step):
        for j in range(0, w, step):
            # Calculate actual patch boundaries
            i_start = i
            i_end = min(i + patch_size, h)
            j_start = j
            j_end = min(j + patch_size, w)

            # Get actual patch size (might be smaller at edges)
            actual_patch_height = i_end - i_start
            actual_patch_width = j_end - j_start

            patch = image[i_start:i_end, j_start:j_end, :]

            # Pad if needed for model input
            if patch.shape[0] < patch_size or patch.shape[1] < patch_size:
                pad_h = patch_size - patch.shape[0] if patch.shape[0] < patch_size else 0
                pad_w = patch_size - patch.shape[1] if patch.shape[1] < patch_size else 0
                patch = np.pad(patch, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')



            # Add batch dimension
            patch_batch = np.expand_dims(patch, axis=0).astype(np.float32)

            # Transpose if needed
            if transpose_needed:
                patch_batch = np.transpose(patch_batch, (0, 3, 1, 2))

            # Run inference
            pred_patch = model_session.run(
                [output_name],
                {input_name: patch_batch}
            )[0]

            # Remove batch dimension
            if pred_patch.shape[0] == 1:
                pred_patch = pred_patch[0]

            # Transpose back if needed
            if transpose_needed and len(pred_patch.shape) == 3:
                pred_patch = np.transpose(pred_patch, (1, 2, 0))
            elif transpose_needed and len(pred_patch.shape) == 4:
                pred_patch = pred_patch[0]
                pred_patch = np.transpose(pred_patch, (1, 2, 0))

            # Crop edge predictions back to the unpadded patch size
            if actual_patch_height < patch_size or actual_patch_width < patch_size:
                pred_patch = pred_patch[:actual_patch_height, :actual_patch_width, :]

            # Write into the output array; this overwrites any earlier overlapping window
            predictions[i_start:i_end, j_start:j_end, :] = pred_patch

    return predictions

def predict_land_use(model_session, config, image_path, output_dir='predictions'):
    """
    Run follow-up land use prediction on downloaded image.

    Returns:
        tuple: (prediction_path, confidence_path, class_names)
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"🧠 Starting prediction for {config.get('region_name', 'Unknown')}")

    # Read image and convert from (bands, height, width) to (height, width, bands)
    with rasterio.open(image_path) as src:
        x_img = src.read()
        print(f"Original shape from rasterio (bands, height, width): {x_img.shape}")
        x_img = reshape_as_image(x_img)
        print("Last band min/max:", x_img[:, :, -1].min(), x_img[:, :, -1].max())
        print(f"Reshaped image shape (height, width, bands): {x_img.shape}")
        transform = src.transform
        crs = src.crs
        profile = src.profile.copy()

    print(f"Input image shape: {x_img.shape}")

    # The last band of the stack is the Hansen loss mask (1 = loss, 0 = no loss)
    loss = x_img[:, :, -1]
    print("Loss unique values:", np.unique(loss))
    if x_img.shape[2] >= 15:
        x_img_input = x_img[:, :, :14]  # Use the first 14 bands as model input
    else:
        x_img_input = x_img

    # Preprocess
    preprocess_func = config.get('preprocess_function', preprocess_africa)
    x_img_processed = preprocess_func(x_img_input, image_path=image_path, transform=transform, crs=crs)
    print(f"x_img_processed image shape: {x_img_processed.shape}")

    # Pad for tiling
    padded_image, pad_coords = pad_image_for_tiling(x_img_processed, PATCH_SIZE)

    print(f"padded_image image shape: {padded_image.shape}")

    # Predict
    print("Running inference...")
    predictions = predict_tiled_onnx(model_session, padded_image, PATCH_SIZE)

    # Remove padding
    if pad_coords[0] > 0 or pad_coords[1] > 0:
        predictions = predictions[pad_coords[0]:pad_coords[0]+x_img.shape[0],
                                 pad_coords[1]:pad_coords[1]+x_img.shape[1], :]

    # Final prediction
    final_prediction = np.argmax(predictions, axis=2).astype(np.uint8)
    print("Loss unique values2:", np.unique(loss))
    final_prediction = np.where(loss == 0, 0, final_prediction)  # Mask non-loss areas

    print(f"final_prediction image shape: {final_prediction.shape}")

    # Confidence: the highest class score per pixel, scaled to 0-100.
    # This assumes the model outputs softmax probabilities in [0, 1].
    confidence = np.max(predictions, axis=2)

    # Check for -inf, +inf or NaN values before cleaning
    print(f"Confidence stats before cleaning: min={confidence.min():.6f}, max={confidence.max():.6f}")
    print(f"Has -inf: {np.any(np.isneginf(confidence))}")
    print(f"Has +inf: {np.any(np.isposinf(confidence))}")
    print(f"Has nan: {np.any(np.isnan(confidence))}")

    # Replace NaN and +/-inf with 0
    confidence = np.nan_to_num(confidence, nan=0.0, posinf=0.0, neginf=0.0)

    # Scale to percent, clip to 0-100 and convert to uint8 (values are truncated)
    confidence = np.clip(confidence * 100, 0, 100).astype(np.uint8)
    confidence = np.where(loss == 0, 0, confidence)  # Mask non-loss areas

    print(f"Confidence stats after cleaning: min={confidence.min()}, max={confidence.max()}")
    print(f"confidence image shape: {confidence.shape}")

    # Save results
    base_name = os.path.splitext(os.path.basename(image_path))[0]

    # Save prediction as a single-band uint8 GeoTIFF (0 = nodata / no forest loss)
    pred_path = os.path.join(output_dir, f'{base_name}_landuse.tif')
    profile.update(dtype='uint8', count=1, compress='lzw', nodata=0)

    with rasterio.open(pred_path, 'w', **profile) as dst:
        dst.write(final_prediction, 1)  # Write the 2-D class array as band 1
        # Store class names and model metadata as GeoTIFF tags
        dst.update_tags(
            classes=','.join(config.get('classes', [])),
            region=config.get('region_name', 'Unknown'),
            model=config.get('model_name', 'Unknown')
        )

    # Save confidence
    conf_path = os.path.join(output_dir, f'{base_name}_confidence.tif')
    with rasterio.open(conf_path, 'w', **profile) as dst:
        dst.write(confidence, 1)

    print(f"✅ Predictions saved to {pred_path} and {conf_path}")

    return pred_path, conf_path, config.get('classes', [])

print("✅ Prediction functions defined!")

✅ Prediction functions defined!
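`predict_tiled_onnx` advances its window by `patch_size // 8` pixels, so with 64-pixel patches consecutive windows overlap by 56 pixels (87.5 %), and the number of ONNX calls grows with the square of `patch_size / step`. The hypothetical helper below simply counts the windows produced by the function's two `range(0, h, step)` loops; the 320 × 320 padded image is an illustrative size.

```python
def count_windows(height, width, step):
    """Windows visited by `for i in range(0, h, step): for j in range(0, w, step)`."""
    return len(range(0, height, step)) * len(range(0, width, step))


patch_size = 64
h = w = 320  # e.g. a padded 3.2 km x 3.2 km area at 10 m
print(count_windows(h, w, patch_size // 8))  # 1600 calls (87.5 % overlap)
print(count_windows(h, w, patch_size // 2))  # 100 calls (50 % overlap)
print(count_windows(h, w, patch_size))       # 25 calls (no overlap)
```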

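Where windows overlap, `predict_tiled_onnx` overwrites earlier scores, so each pixel's class comes from the last window that covered it. A common alternative, which the notebook does not use, is to sum the scores of every window covering a pixel and divide by how many windows covered it. The sketch below shows that averaging scheme with a toy per-pixel function in place of an ONNX session so it runs without model files; `predict_tiled_average` and `predict_patch` are hypothetical names, with `predict_patch` mapping a `(patch_size, patch_size, C)` array to `(patch_size, patch_size, num_classes)` scores.

```python
import numpy as np


def predict_tiled_average(predict_patch, image, patch_size=64, step=32, num_classes=3):
    """Sliding-window prediction that averages overlapping windows (hypothetical sketch)."""
    h, w, _ = image.shape
    scores = np.zeros((h, w, num_classes), dtype=np.float32)
    counts = np.zeros((h, w, 1), dtype=np.float32)
    for i in range(0, h, step):
        for j in range(0, w, step):
            patch = image[i:i + patch_size, j:j + patch_size, :]
            ph, pw = patch.shape[:2]
            # Pad edge patches to the full size the model expects
            if ph < patch_size or pw < patch_size:
                patch = np.pad(patch, ((0, patch_size - ph), (0, patch_size - pw), (0, 0)),
                               mode='reflect')
            # Crop the padded part away before accumulating
            pred = predict_patch(patch)[:ph, :pw, :]
            scores[i:i + ph, j:j + pw, :] += pred
            counts[i:i + ph, j:j + pw, :] += 1
    # With step <= patch_size every pixel is covered at least once
    return scores / counts


def toy_model(patch):
    # Stand-in for an ONNX session: a per-pixel function of the input bands
    return patch * 2.0


rng = np.random.default_rng(0)
image = rng.random((100, 90, 3)).astype(np.float32)
averaged = predict_tiled_average(toy_model, image, patch_size=64, step=32, num_classes=3)
print(np.allclose(averaged, image * 2.0))  # True
```

Because the toy model works pixel by pixel, the averaged output matches the model applied to the whole image; with a real CNN, averaging tends to soften seams between windows at the cost of one extra counts array.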

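The confidence layer is the largest class score at each pixel, written as a percentage. It reads as a probability only if the model ends in a softmax, which is assumed here. The `uint8` cast truncates (74.9 becomes 74), and pixels outside the loss mask become 0, which the GeoTIFF declares as nodata. A compact version of those steps, with the hypothetical name `confidence_percent`:

```python
import numpy as np


def confidence_percent(scores, loss):
    """Max class score per pixel as 0-100 uint8, zeroed where loss == 0.

    scores: (H, W, num_classes) array, assumed to hold softmax probabilities.
    loss:   (H, W) Hansen loss band (1 = loss, 0 = no loss).
    """
    conf = np.max(scores, axis=2)
    # NaN and +/-inf become 0, as in predict_land_use
    conf = np.nan_to_num(conf, nan=0.0, posinf=0.0, neginf=0.0)
    conf = np.clip(conf * 100, 0, 100).astype(np.uint8)
    return np.where(loss == 0, 0, conf).astype(np.uint8)


scores = np.array([[[0.75, 0.25, 0.0], [0.5, 0.5, 0.0], [np.nan, np.nan, np.nan]]])
loss = np.array([[1, 0, 1]])
print(confidence_percent(scores, loss))  # [[75  0  0]]
```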
# ## 9. VISUALIZATION FUNCTIONS

In [None]:
def create_legend(classes, color_map):
    """
    Create matplotlib legend for land use classes.

    Args:
        classes: List of class names
        color_map: Matplotlib colormap

    Returns:
        matplotlib.figure.Figure: Legend figure
    """
    n_classes = len(classes)

    # Lay the swatches out row by row in a 3-row grid
    n_cols = int(np.ceil(n_classes / 3))
    fig, ax = plt.subplots(figsize=(15, 3))

    ax.axis('off')

    colors = [color_map(i / max(1, n_classes - 1)) for i in range(n_classes)]

    for i, (class_name, color) in enumerate(zip(classes, colors)):
        row = i // n_cols
        col = i % n_cols

        x = col * 3.5
        y = 2 - row  # 3 rows: 2,1,0

        ax.add_patch(plt.Rectangle((x, y), 0.3, 0.5,
                                   facecolor=color, edgecolor='black'))
        ax.text(x + 0.4, y + 0.25, f'{i}: {class_name}',
                fontsize=9, va='center')

    ax.set_xlim(0, n_cols * 3.5)
    ax.set_ylim(-2.4, 3)
    ax.set_title('Follow-up Land Use Classes', fontsize=12, fontweight='bold', pad=20)

    # Abbreviations (below legend)
    ax.text(
        0,
        -1.2,
        'Abbreviations:',
        fontsize=10,
        fontweight='bold',
        ha='left',
        va='top'
    )

    ax.text(
        0,
        -1.7,
        'OLSCP – Other large-scale cropland',
        fontsize=9,
        ha='left',
        va='top'
    )

    ax.text(
        0,
        -2.1,
        'OSSCP – Other small-scale cropland',
        fontsize=9,
        ha='left',
        va='top'
    )


    plt.tight_layout()
    return fig

def visualize_results(image_path, prediction_path, confidence_path, classes, region_name):
    """
    Create comprehensive visualization of results.
    """
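    fig = plt.figure(figsize=(16, 9))

    # Read display bands 4, 3, 2 of the downloaded stack (1-based indices).
    # Which Sentinel-2 bands these are depends on the band order used in
    # get_sentinel2_composite; the panel is titled as a false-color composite.
    with rasterio.open(image_path) as src:
        rgb = src.read([4, 3, 2])
        print(f"rgb shape as read (bands, height, width): {rgb.shape}")
        rgb = np.moveaxis(rgb, 0, -1)
        print(f"rgb shape after moveaxis: {rgb.shape}")
        # Fixed display stretch: divide by 3000 and clip to [0, 1]
        rgb = np.clip(rgb / 3000, 0, 1)
        print(f"rgb shape after clip: {rgb.shape}")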

    # Read prediction (read(1) returns a 2-D array)
    with rasterio.open(prediction_path) as src:
        prediction = src.read(1)
        print(f"Prediction shape after read(1): {prediction.shape}")
        print(f"Prediction dtype: {prediction.dtype}")
        print(f"Prediction min/max: {prediction.min()}, {prediction.max()}")


    # Read confidence
    with rasterio.open(confidence_path) as src:
        confidence = src.read(1)

    # Get loss mask: the last band of the stack, as in predict_land_use
    with rasterio.open(image_path) as src:
        if src.count >= 15:
            loss = src.read(src.count)  # read(index) already returns a 2-D array
            print(f"loss shape after src read: {loss.shape}")
            print(f"loss min/max: {loss.min()}, {loss.max()}")
        else:
            loss = np.zeros_like(prediction)
            print(f"loss shape zeros like prediction: {loss.shape}")

    print(f"loss shape: {loss.shape}")

    # Mask for deforested areas
    deforested_mask = (loss > 0)
    print(f"deforested_mask shape: {deforested_mask.shape}")

    # 1. RGB with deforestation highlight
    ax1 = plt.subplot(2, 3, 1)
    rgb_highlighted = rgb.copy()
    assert rgb_highlighted.ndim == 3
    assert deforested_mask.shape == rgb_highlighted.shape[:2]
    mask3 = deforested_mask[:, :, None]
    rgb_highlighted = np.where(
        mask3,
        np.array([1.0, 0.0, 0.0]),
        rgb_highlighted
    )
    ax1.imshow(rgb_highlighted)
    ax1.set_title('Satellite Image - False color', fontsize=11, fontweight='bold')
    ax1.axis('off')

    # 2. Follow-up land use prediction (already masked to loss areas in predict_land_use)
    ax2 = plt.subplot(2, 3, 2)

    print(f"rgb_highlighted shape: {rgb_highlighted.shape}")
    print(f"deforested_mask shape: {deforested_mask.shape}")

    color_map = REGION_CONFIGS[region_name]['color_map']
    im2 = ax2.imshow(prediction, cmap=color_map, vmin=0, vmax=max(1, len(classes)-1))
    ax2.set_title('Follow-up Land Use Prediction', fontsize=11, fontweight='bold')
    ax2.axis('off')
    plt.colorbar(im2, ax=ax2, fraction=0.026, pad=0.04)

    # 3. Prediction confidence (already masked to loss areas in predict_land_use)
    ax3 = plt.subplot(2, 3, 3)
    im3 = ax3.imshow(confidence, cmap='gnuplot2', vmin=0, vmax=100)
    ax3.set_title('Prediction Confidence (%)', fontsize=11, fontweight='bold')
    ax3.axis('off')
    plt.colorbar(im3, ax=ax3, fraction=0.026, pad=0.04)

    # 4. Class distribution
    ax4 = plt.subplot(2, 3, 4)
    unique_classes, counts = np.unique(prediction[prediction > 0], return_counts=True)
    if len(unique_classes) > 0:
        colors = [color_map(cls / max(1, len(classes)-1)) for cls in unique_classes]
        ax4.bar(range(len(unique_classes)), counts, color=colors, edgecolor='black')
        ax4.set_xticks(range(len(unique_classes)))
        ax4.set_xticklabels([str(int(cls)) for cls in unique_classes])
        ax4.set_xlabel('Class ID')
        ax4.set_ylabel('Pixel Count')
        ax4.set_title('Class Distribution', fontsize=11, fontweight='bold')
        ax4.grid(alpha=0.3)

    # 5. Confidence histogram
    ax5 = plt.subplot(2, 3, 5)
    ax5.hist(confidence[confidence > 0].flatten(), bins=20, edgecolor='black', alpha=0.7)
    ax5.set_xlabel('Confidence (%)')
    ax5.set_ylabel('Frequency')
    ax5.set_title('Confidence Distribution', fontsize=11, fontweight='bold')
    ax5.grid(alpha=0.3)

    # 6. Region info
    ax6 = plt.subplot(2, 3, 6)
    ax6.axis('off')
    valid_conf = confidence[confidence > 0]
    mean_conf = valid_conf.mean() if valid_conf.size else 0.0  # avoid NaN when nothing was predicted
    info_text = f"""
    Region: {region_name}
    Total Pixels: {prediction.size:,}
    Deforested Pixels: {np.count_nonzero(prediction):,}
    Mean Confidence: {mean_conf:.1f}%
    """
    ax6.text(0.1, 0.5, info_text, fontsize=10, va='center', linespacing=1.5)
    ax6.set_title('Statistics', fontsize=11, fontweight='bold')

    plt.suptitle(f'Monitoring Land Use Following Deforestation - {region_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

print("✅ Visualization functions defined!")

✅ Visualization functions defined!
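`create_legend` fills a three-row grid row by row: class `i` goes to row `i // n_cols` and column `i % n_cols`, where `n_cols = ceil(n_classes / 3)`; each column is 3.5 axis units wide and the rows sit at y = 2, 1 and 0. The hypothetical helper below reproduces that arithmetic without matplotlib, which makes it easy to check where, for example, Africa's 25 classes land.

```python
import math


def legend_positions(n_classes, n_rows=3, col_width=3.5):
    """(x, y) of each legend swatch, mirroring the arithmetic in create_legend."""
    n_cols = math.ceil(n_classes / n_rows)
    positions = []
    for i in range(n_classes):
        row, col = divmod(i, n_cols)
        positions.append((col * col_width, (n_rows - 1) - row))
    return positions


pos = legend_positions(25)  # Africa: 25 classes -> 9 columns
print(pos[0], pos[8], pos[9], pos[24])  # (0.0, 2) (28.0, 2) (0.0, 1) (21.0, 0)
```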

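The class-distribution panel counts pixels. At the 10 m export scale one pixel covers 100 m² (0.01 ha), so counts convert directly into area. The helper below is a hypothetical addition, not part of the notebook; like the bar chart, it skips class 0, which the prediction GeoTIFF uses as nodata.

```python
import numpy as np


def class_areas_ha(prediction, scale_m=10):
    """Area in hectares of each non-zero class in a 2-D class raster (hypothetical helper)."""
    classes, counts = np.unique(prediction[prediction > 0], return_counts=True)
    # Convert to Python floats and multiply before dividing to keep results exact
    return {int(c): float(n) * scale_m * scale_m / 10_000 for c, n in zip(classes, counts)}


pred = np.array([[0, 1, 1], [2, 2, 2]], dtype=np.uint8)
print(class_areas_ha(pred))  # {1: 0.02, 2: 0.03}
```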

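`visualize_results` divides the three display bands by 3000 and clips to [0, 1]. That fixed stretch suits Sentinel-2 surface reflectance stored as integers scaled by 10,000 (an assumption about the composite), and can look dark or washed out over unusually bright or dark scenes. A per-band percentile stretch is a common alternative; the sketch below is not the notebook's method, and the 2nd and 98th percentiles are an arbitrary default.

```python
import numpy as np


def percentile_stretch(rgb, low=2, high=98):
    """Stretch each band of an (H, W, 3) array to [0, 1] between two percentiles."""
    out = np.empty(rgb.shape, dtype=np.float32)
    for b in range(rgb.shape[2]):
        band = rgb[:, :, b].astype(np.float32)
        lo, hi = np.percentile(band, [low, high])
        # Guard against flat bands (hi == lo) to avoid division by zero
        scale = hi - lo if hi > lo else 1.0
        out[:, :, b] = np.clip((band - lo) / scale, 0, 1)
    return out


rng = np.random.default_rng(1)
stack = rng.integers(0, 5000, size=(50, 40, 3)).astype(np.int16)
stretched = percentile_stretch(stack)
print(stretched.shape, stretched.min(), stretched.max())
```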
# ## 10. INTERACTIVE INTERFACE

In [None]:
class DeforestationPredictor:
    """Interactive deforestation prediction tool with ONNX models"""

    def __init__(self):
        # Initialize the Earth Engine map with a high-resolution satellite
        # (Google HYBRID) basemap
        self.map = geemap.Map(
            center=[0, 0],
            zoom=3,
            height=600,
            basemap='HYBRID',
            layout=Layout(width='70%')
        )

        # ===============================
        # ADD BASEMAP AND FOREST LOSS
        # ===============================

        self.map.addLayer(
            geemap.basemaps['Esri.WorldStreetMap'],
            {},
            'Road map',
            shown=False
        )


        # --- GLOBAL HANSEN FOREST LOSS (NO CLIP) ---
        hansen = ee.Image('UMD/hansen/global_forest_change_2023_v1_11')
        self.loss_img = hansen.select('loss')
        self.lossyear_img = hansen.select('lossyear')


        self.map.addLayer(
            self.lossyear_img,
            {
                'min': 1,
                'max': 23,
                'palette': [
                    '#ffffcc','#ffeda0','#fed976','#feb24c',
                    '#fd8d3c','#fc4e2a','#e31a1c','#bd0026','#800026'
                ]
            },
            'Forest loss year (global)',
            shown=True
        )


        lossyear_legend = {
            '2001': '#ffffcc',
            '2005': '#fed976',
            '2010': '#fd8d3c',
            '2015': '#e31a1c',
            '2023': '#800026'
        }

        # Initialize state variables
        self.roi = None
        self.region_name = None
        self.region_config = {}
        self.model_session = None
        self.data_path = None
        self.prediction_path = None
        self.confidence_path = None
        self.class_names = None

        # Create output widget for logs and visualizations
        self.output = Output(layout=Layout(width='100%'))

        # Create control widgets
        self._create_widgets()

        # Arrange layout
        self._create_layout()

        # Initialize drawing control. _setup_draw_control() calls
        # clear_controls(), so the legend and layer control are added afterwards.
        self._setup_draw_control()
        self.map.add_legend(
            title='Forest loss year (Hansen)',
            legend_dict=lossyear_legend,
            position='bottomright'
        )
        self.map.addLayerControl(position='topright')

    def _setup_draw_control(self):
        """Setup draw control"""
        self.map.clear_controls()
        self.map.add_draw_control()

        def handle_draw(target, action, geo_json):
            if action == 'created' and geo_json:
                self._process_drawn_geometry(geo_json)

        if hasattr(self.map.draw_control, 'on_draw'):
            self.map.draw_control.on_draw(handle_draw)



    def _process_drawn_geometry(self, geo_json):
        """Process drawn geometry"""
        try:
            self.roi = ee.Geometry(geo_json['geometry'])

            # Detect region
            self.region_name = None
            for region_name, config in REGION_CONFIGS.items():
                bbox = config['bbox']
                if bbox.contains(self.roi.centroid()).getInfo():
                    self.region_name = region_name
                    break

            if self.region_name:
                self.region_config = REGION_CONFIGS[self.region_name].copy()
                self.region_dropdown.value = self.region_name

                area_sqkm = self.roi.area().divide(1e6).getInfo()

                self.region_label.value = f"""
                <h3>🌍 Region: <span style='color:green'>{self.region_name}</span></h3>
                <p>Model: {self.region_config.get('model_name', 'Unknown')}</p>
                <p>Classes: {self.region_config.get('output_classes', 0)}</p>
                """

                self.roi_label.value = f"""
                <h3>🗺️ ROI: <span style='color:green'>{area_sqkm:.1f} km²</span></h3>
                """

                self.download_btn.disabled = False
                self.predict_btn.disabled = True

                self.status.value = f"""
                <h4>Status: <span style='color:green'>✓ ROI drawn in {self.region_name}</span></h4>
                <p>Ready to download satellite data</p>
                """

                with self.output:
                    clear_output()
                    print(f"✓ ROI drawn in {self.region_name}")
                    print(f"  Area: {area_sqkm:.1f} km²")
            else:
                self.region_label.value = """
                <h3>🌍 Region: <span style='color:orange'>Outside supported regions</span></h3>
                """
                self.download_btn.disabled = True

                with self.output:
                    clear_output()
                    print("⚠️ Please draw within Africa, Southeast Asia, or Latin America")

        except Exception as e:
            with self.output:
                clear_output()
                print(f"❌ Error: {str(e)}")

    def _create_widgets(self):
        """Create all interactive widgets"""
        # Region label
        self.region_label = widgets.HTML(
            value="<h3>🌍 Region: <span style='color:red'>Not selected</span></h3>",
            layout=Layout(width='100%')
        )

        # ROI info label
        self.roi_label = widgets.HTML(
            value="<h3>🗺️ ROI: <span style='color:red'>Not drawn</span></h3>",
            layout=Layout(width='100%')
        )

        # Download button
        self.download_btn = widgets.Button(
            description="📥 Download Satellite Data",
            button_style='primary',
            disabled=True,
            icon='download',
            layout=Layout(width='250px', height='50px')
        )
        self.download_btn.on_click(self.download_data)

        # Predict button
        self.predict_btn = widgets.Button(
            description="🧠 Run Follow-up Land Use Prediction",
            button_style='success',
            disabled=True,
            icon='cogs',
            layout=Layout(width='250px', height='50px')
        )
        self.predict_btn.on_click(self.run_prediction)

        # Clear button
        self.clear_btn = widgets.Button(
            description="🗑️ Clear All",
            button_style='warning',
            icon='trash',
            layout=Layout(width='250px', height='50px')
        )
        self.clear_btn.on_click(self.clear_all)

        # Region dropdown
        self.region_dropdown = widgets.Dropdown(
            options=['Africa', 'Southeast Asia', 'Latin America'],
            value=None,
            description='Select Region:',
            disabled=False,
            layout=Layout(width='300px')
        )
        self.region_dropdown.observe(self.on_region_change, names='value')

        # ===============================
        # FOREST LOSS YEAR SLIDER
        # ===============================

        self.lossyear_slider = widgets.IntSlider(
            value=23,
            min=1,
            max=23,
            step=1,
            description='Loss year:',
            continuous_update=False,
            layout=Layout(width='95%')
        )

        self.lossyear_slider.observe(self.update_lossyear_layer, names='value')

        # Progress bar
        self.progress = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='Progress:',
            bar_style='info',
            style={'description_width': 'initial'},
            layout=Layout(width='95%')
        )

        # Status label
        self.status = widgets.HTML(
            value="<h4>Status: <span style='color:blue'>Ready - Draw a region on the map</span></h4>",
            layout=Layout(width='100%')
        )

        # Results display
        self.results_label = widgets.HTML(
            value="<h3>📊 Results</h3>",
            layout=Layout(width='100%')
        )

        # Visualization button
        self.viz_btn = widgets.Button(
            description="📈 Visualize Results",
            button_style='info',
            disabled=True,
            icon='chart-bar',
            layout=Layout(width='200px', height='40px')
        )
        self.viz_btn.on_click(self.visualize_results)

        # Download results button
        self.download_results_btn = widgets.Button(
            description="💾 Download GeoTIFFs",
            button_style='info',
            disabled=True,
            icon='file-download',
            layout=Layout(width='200px', height='40px')
        )
        self.download_results_btn.on_click(self.download_results)

        # Instructions
        self.instructions = widgets.Accordion(children=[
            widgets.HTML("""
            <div style="padding: 10px; font-size: 14px;">
            <h4>📖 HOW TO USE:</h4>
            <ol>
                <li><strong>Draw a polygon</strong> on the map</li>
                <li>The system detects the region automatically</li>
                <li>Click <strong>"Download Satellite Data"</strong></li>
                <li>Click <strong>"Run Follow-up Land Use Prediction"</strong></li>
                <li>View results and download GeoTIFFs</li>
            </ol>
            <h4>📊 OUTPUT:</h4>
            <ul>
                <li><strong>Land Use Map</strong>: Predicted classes (GeoTIFF)</li>
                <li><strong>Confidence Map</strong>: Prediction confidence (GeoTIFF)</li>
                <li><strong>Visualizations</strong>: RGB, predictions, statistics</li>
            </ul>
            <h4>⚠️ NOTES:</h4>
            <ul>
                <li>Predictions are made only for deforested (Hansen forest loss) areas</li>
                <li>Each region has its own model</li>
                <li>Processing time: 1-5 minutes</li>
                <li>Requires internet access to download data and models</li>
            </ul>
            </div>
            """)
        ])
        self.instructions.set_title(0, '📚 Instructions')
        self.instructions.selected_index = None

    def _create_layout(self):
        """Arrange widgets in layout"""
        button_row1 = HBox([
            self.download_btn,
            self.predict_btn,
            self.clear_btn
        ], layout=Layout(justify_content='center', margin='10px 0'))

        button_row2 = HBox([
            self.viz_btn,
            self.download_results_btn
        ], layout=Layout(justify_content='center', margin='10px 0'))

        control_panel = VBox([
            self.instructions,
            self.region_label,
            self.roi_label,
            self.region_dropdown,
            self.lossyear_slider,
            button_row1,
            self.progress,
            self.status,
            self.results_label,
            button_row2
        ], layout=Layout(width='30%', padding='10px'))

        self.main_layout = VBox([
            HBox([
                self.map,
                control_panel
            ], layout=Layout(width='100%')),

            self.output   # ← logs and visualizations appear BELOW the map
        ], layout=Layout(width='100%'))

    def on_region_change(self, change):
        """Handle manual region selection"""
        if change['new']:
            self.region_name = change['new']
            self.region_config = REGION_CONFIGS.get(self.region_name, {}).copy()

            self.region_label.value = f"""
            <h3>🌍 Region: <span style='color:green'>{self.region_name} (Manual)</span></h3>
            """

            if self.roi:
                self.download_btn.disabled = False

            with self.output:
                clear_output()
                print(f"✓ Region set to: {self.region_name}")


    def update_lossyear_layer(self, change):
        """Show forest loss for the year selected on the slider (1 = 2001 ... 23 = 2023)"""
        year = change['new']

        # Keep only pixels whose Hansen 'lossyear' code matches the selected year
        filtered = self.lossyear_img.updateMask(
            self.lossyear_img.eq(year)
        )

        vis = {
            'min': 1,
            'max': 23,
            'palette': ['red']
        }

        # Keep the first three layers (basemap + base layers), which also removes
        # any previously added single-year layer, then add the selected year
        self.map.layers = self.map.layers[:3]
        self.map.addLayer(
            filtered,
            vis,
            f'Forest loss {2000 + year}',
            shown=True
        )


    def download_data(self, b):
        """Download satellite data for ROI"""
        with self.output:
            clear_output()

            if not self.roi:
                print("❌ Please draw a region on the map first")
                return

            if not self.region_name:
                print("❌ Could not detect region. Please select manually.")
                return

            try:
                self.progress.value = 0
                self.status.value = "<h4>Status: <span style='color:orange'>Downloading satellite data...</span></h4>"

                print(f"📥 Downloading data for {self.region_name}...")

                bbox = self.roi.buffer(500).bounds()
                self.progress.value = 30

                self.data_path = download_region_data(
                    bbox,
                    self.region_name,
                    scale=10,
                    output_dir='downloads'
                )

                self.progress.value = 60
                print("✓ Data downloaded")

                print(f"🧠 Loading {self.region_name} model...")
                self.model_session, loaded_config = load_onnx_model(self.region_name)
                self.region_config.update(loaded_config)

                self.progress.value = 100
                self.predict_btn.disabled = False

                self.status.value = f"""
                <h4>Status: <span style='color:green'>✓ Data and model loaded</span></h4>
                """

                print("✅ Ready to run prediction")

            except Exception as e:
                self.status.value = f"<h4>Status: <span style='color:red'>Error: {str(e)[:100]}</span></h4>"
                print(f"❌ Error: {str(e)}")

    def run_prediction(self, b):
        """Run follow-up land use prediction"""
        with self.output:
            clear_output()

            if not self.data_path:
                print("❌ Please download data first")
                return

            try:
                self.progress.value = 0
                self.status.value = "<h4>Status: <span style='color:orange'>Running prediction...</span></h4>"

                print(f"🧠 Running prediction for {self.region_name}...")
                self.progress.value = 20

                self.prediction_path, self.confidence_path, self.class_names = predict_land_use(
                    self.model_session,
                    self.region_config,
                    self.data_path
                )

                self.progress.value = 80

                with rasterio.open(self.prediction_path) as src:
                    pred_data = src.read(1)
                    #pred_data = reshape_as_image(pred_data)
                    unique_classes = np.unique(pred_data[pred_data > 0])

                self.progress.value = 100
                self.viz_btn.disabled = False
                self.download_results_btn.disabled = False

                self.results_label.value = f"""
                <h3>üìä Results - {self.region_name}</h3>
                <p><strong>Classes found:</strong> {len(unique_classes)}</p>
                <p><strong>Land use map:</strong> {os.path.basename(self.prediction_path)}</p>
                <p><strong>Confidence map:</strong> {os.path.basename(self.confidence_path)}</p>
                """

                self.status.value = """
                <h4>Status: <span style='color:green'>✓ Prediction complete!</span></h4>
                """

                print("✅ Prediction complete!")
                print("📁 Output files saved in 'predictions' directory")

            except Exception as e:
                self.status.value = f"<h4>Status: <span style='color:red'>Error: {str(e)[:100]}</span></h4>"
                print(f"❌ Error: {str(e)}")

    def visualize_results(self, b):
        """Visualize prediction results"""
        with self.output:
            clear_output()

            if not self.prediction_path:
                print("❌ Please run prediction first")
                return

            try:
                print("üìà Generating visualizations...")
                visualize_results(
                    self.data_path,
                    self.prediction_path,
                    self.confidence_path,
                    self.class_names,
                    self.region_name
                )

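                if self.class_names:
                    legend_fig = create_legend(self.class_names, self.region_config.get('color_map', plt.cm.tab20))
                    # tight_layout() acts on the current figure, so it must run before plt.show();
                    # called afterwards it creates and lays out a new, empty figure instead.
                    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
                    plt.show()

                print("✅ Visualizations generated!")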

            except Exception as e:
                print(f"❌ Error: {str(e)}")

    def download_results(self, b):
        """Provide download links for results"""
        with self.output:
            clear_output()

            if not self.prediction_path:
                print("❌ No results to download")
                return

            try:
                print("üíæ Download GeoTIFF files:")
                print("=" * 50)

                from IPython.display import FileLink, display

                print("\n1. Land Use Classification:")
                display(FileLink(self.prediction_path, result_html_prefix="🗺️ "))

                print("\n2. Prediction Confidence:")
                display(FileLink(self.confidence_path, result_html_prefix="📊 "))

                print("\n3. Input Satellite Data:")
                display(FileLink(self.data_path, result_html_prefix="🛰️ "))

                print("\n✅ Files ready for download")

            except Exception as e:
                print(f"❌ Error: {str(e)}")

    def clear_all(self, b):
        """Clear all selections and reset"""
        with self.output:
            clear_output()
            print("üóëÔ∏è Clearing all data...")

        self.roi = None
        self.region_name = None
        self.region_config = {}
        self.model_session = None
        self.data_path = None
        self.prediction_path = None
        self.confidence_path = None
        self.class_names = None

        self.map.clear_draw()

        self.region_label.value = "<h3>üåç Region: <span style='color:red'>Not selected</span></h3>"
        self.roi_label.value = "<h3>üó∫Ô∏è ROI: <span style='color:red'>Not drawn</span></h3>"
        self.region_dropdown.value = None
        self.results_label.value = "<h3>üìä Results</h3>"

        self.download_btn.disabled = True
        self.predict_btn.disabled = True
        self.viz_btn.disabled = True
        self.download_results_btn.disabled = True

        self.progress.value = 0
        self.status.value = "<h4>Status: <span style='color:blue'>Ready - Draw a region on the map</span></h4>"

        self._setup_draw_control()

        print("✅ All cleared! Ready for new analysis.")

    def display(self):
        """Display the interface"""
        display(self.main_layout)


print("✅ Interactive interface defined!")

✅ Interactive interface defined!
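The `run_prediction` handler above counts "Classes found" as the distinct non-zero values in band 1 of the prediction GeoTIFF, so 0 is treated as "no prediction". A minimal sketch of extending that count into per-class pixel counts and approximate areas, which is useful when reporting class statistics. It assumes square 10 m pixels (matching `scale=10` in `download_data`) and that 0 is the only no-data value; `class_area_summary` is a hypothetical helper, not a function defined in this notebook.

```python
import numpy as np


def class_area_summary(pred, pixel_size_m=10.0):
    """Return {class_id: (pixel_count, area_km2)} for a 2-D array of class IDs.

    Hypothetical helper: mirrors the `pred_data > 0` filter used in
    `run_prediction`, so pixels with value 0 are ignored. Assumes square
    pixels of `pixel_size_m` metres (10 m, matching scale=10 in download_data).
    """
    pred = np.asarray(pred)
    valid = pred[pred > 0]
    # Sorted unique class IDs together with how many pixels carry each one
    ids, counts = np.unique(valid, return_counts=True)
    km2_per_pixel = pixel_size_m ** 2 / 1e6
    return {int(c): (int(n), float(n) * km2_per_pixel) for c, n in zip(ids, counts)}


# Tiny synthetic example: two classes (3 and 5) plus no-data zeros
demo = np.array([[0, 3, 3],
                 [5, 3, 0]], dtype=np.uint8)
for class_id, (n_pixels, area_km2) in class_area_summary(demo).items():
    print(f"class {class_id}: {n_pixels} px, {area_km2:.4f} km²")
```

In the notebook, the array would come from `src.read(1)` on `predictor.prediction_path`, the same way `run_prediction` reads it. Mapping the IDs to `class_names` depends on how `predict_land_use` encodes classes, which is not shown here.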


# ## 11. LAUNCH THE APPLICATION

In [None]:

print("""
================================================================================
🌍 POST-DEFORESTATION LAND USE PREDICTION SYSTEM
================================================================================
Version: 1.0
Author: Robert Masolele, Wageningen University
Date: 2025

SUPPORTED REGIONS:
- Africa (25 land use classes)
- Southeast Asia (24 land use classes)
- Latin America (22 land use classes)

INSTRUCTIONS:
1. Draw a polygon on the map within a supported region
2. Click "Download Satellite Data" to fetch Sentinel-1/2 data
3. Click "Run Follow-up Land Use Prediction" to classify land use in deforested areas
4. Visualize results and download GeoTIFF files

OUTPUT:
- Follow-up land use classification map (GeoTIFF)
- Prediction confidence map (GeoTIFF)
- Interactive visualizations
- Class legend and statistics
================================================================================
""")

# Create and display the predictor
predictor = DeforestationPredictor()
predictor.display()


# ## 12. TROUBLESHOOTING

================================================================================
🔧 TROUBLESHOOTING GUIDE
================================================================================

COMMON ISSUES:

1. "Earth Engine not authenticated"
   - Run the authentication cell (Section 1)
   - Follow the prompts to authorize

2. "Model not found" error
   - Check internet connection
   - Manual download: https://huggingface.co/datasets/Masolele/deforestwatch-models
   - Place .onnx files in 'models/' directory

3. Slow download/prediction
   - Use smaller ROI (under 100 km² for testing)
   - 10m resolution is sufficient for most applications

4. No results in prediction
   - The area might not have deforestation (loss = 0)
   - Try a different location with known deforestation

5. Visualization errors
   - Ensure prediction ran successfully first
   - Check if output files exist in 'predictions/' directory

GETTING HELP:
- Check console output for error messages
- Ensure all packages are installed (Section 1)
- Models require ~50MB each (download once)
================================================================================
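Issues 2 and 3 can often be caught before any button is clicked. The sketch below is a hypothetical pre-flight check, not part of the notebook: it estimates the size of the ROI after the 500 m buffer that `download_data` applies (`self.roi.buffer(500).bounds()`) and looks for `.onnx` files in the local `models/` folder. It assumes the ROI is given as a `(west, south, east, north)` box in degrees and uses a spherical Earth, so the area is an estimate good enough for a size warning, not a precise measurement. The 100 km² threshold is taken from tip 3.

```python
import math
from pathlib import Path

# Mean Earth radius in metres; a spherical model is enough for a rough size check
EARTH_RADIUS_M = 6_371_008.8


def buffered_bbox(west, south, east, north, buffer_m=500.0):
    """Grow a lon/lat box (degrees) by roughly `buffer_m` metres on every side.

    Approximates roi.buffer(500).bounds() from download_data.
    """
    dlat = math.degrees(buffer_m / EARTH_RADIUS_M)
    # Use the latitude farthest from the equator so the longitude pad is never too small
    max_abs_lat = max(abs(south), abs(north))
    dlon = math.degrees(buffer_m / (EARTH_RADIUS_M * math.cos(math.radians(max_abs_lat))))
    return west - dlon, south - dlat, east + dlon, north + dlat


def bbox_area_km2(west, south, east, north):
    """Area of a lon/lat rectangle on a sphere: R^2 * dlon * (sin(north) - sin(south))."""
    dlon = math.radians(east - west)
    band = math.sin(math.radians(north)) - math.sin(math.radians(south))
    return EARTH_RADIUS_M ** 2 * dlon * band / 1e6


def preflight(bbox, models_dir="models", max_km2=100.0, buffer_m=500.0):
    """Return a list of warnings for an ROI box; an empty list means no obvious problem."""
    warnings = []
    area = bbox_area_km2(*buffered_bbox(*bbox, buffer_m=buffer_m))
    if area > max_km2:
        warnings.append(f"Buffered ROI is about {area:.1f} km², above {max_km2:g} km²; expect slow downloads.")
    # Path.glob on a missing directory simply yields nothing
    if not any(Path(models_dir).glob("*.onnx")):
        warnings.append(f"No .onnx files found in '{models_dir}/'.")
    return warnings


print(preflight((0.0, 0.0, 0.5, 0.5), models_dir="models"))
```

Padding longitude with the latitude farthest from the equator errs towards a slightly larger box, which is the safe direction for a size warning. Earth Engine buffers the drawn polygon itself before taking its bounds, so its box can differ slightly from this estimate.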
