# Sentinel-2 DSWE Monthly Generator
This code applies the Dynamic Surface Water Extent (DSWE) algorithm to Sentinel-2 remote sensing data (blue, green, red, NIR, SWIR1, and SWIR2 bands) to produce monthly surface water inundation extent maps for a given study area. It first builds monthly composites from the available Sentinel-2 data, then applies the algorithm and exports each product as a Google Earth Engine asset.

DSWE Methodology: Jones, J.W., 2019. Improved Automated Detection of Subpixel-Scale Inundation—Revised Dynamic Surface Water Extent (DSWE) Partial Surface Water Tests. Remote Sensing 11, 374. https://doi.org/10.3390/rs11040374

Sentinel-2: European Space Agency (ESA). (2023). Sentinel-2 imagery. Copernicus Open Access Hub. Retrieved from https://scihub.copernicus.eu/

Google Earth Engine: Gorelick, N., Hancher, M., Dixon, M., Ilyushchenko, S., Thau, D., Moore, R., 2017. Google Earth Engine: Planetary-scale geospatial analysis for everyone. Remote Sensing of Environment, Big Remotely Sensed Data: tools, applications and experiences 202, 18–27. https://doi.org/10.1016/j.rse.2017.06.031

Author: James (Huck) Rees, PhD Student, UC Santa Barbara Geography

Date: March 10th, 2024

## Import packages and initialize GEE

In [1]:
import ee
import geopandas as gpd
import os
import logging
from datetime import datetime, timedelta
import calendar
from dateutil.relativedelta import relativedelta

import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks, argrelextrema
import matplotlib.pyplot as plt

# Initialize
ee.Initialize()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

## Initialize functions for generating monthly Sentinel-2 composites
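
The gap-filling strategy used in this section can be illustrated without Earth Engine. The sketch below is a pure-Python analogy, not the notebook's implementation: a composite is modelled as a list in which `None` marks a masked pixel, and the hypothetical helpers `coverage` and `fill_gaps` stand in for what `calculate_coverage` and `ee.Image.unmask` do in the cell that follows. The pixel values and the 0.8 threshold are made up for illustration.

```python
def coverage(pixels):
    """Fraction of pixels that hold a value (None marks a masked pixel)."""
    if not pixels:
        return 0.0
    return sum(p is not None for p in pixels) / len(pixels)


def fill_gaps(current, fallback):
    """Keep every valid pixel of `current`; use `fallback` only in the gaps."""
    return [c if c is not None else f for c, f in zip(current, fallback)]


# Toy composites: the base month, then the ±1 and ±2 month windows.
base = [120, None, None, 340, None]
windows = [
    [125, 210, None, 335, None],  # ±1 month
    [130, 215, 400, 330, None],   # ±2 months
]

composite, qc_value = base, 0
for expansion, wider in enumerate(windows, start=1):
    if coverage(composite) >= 0.8:  # illustrative threshold
        break
    composite = fill_gaps(composite, wider)
    qc_value = expansion

print(composite, qc_value, coverage(composite))
```

Pixels that were valid in the base month keep their base-month value; the wider windows only contribute where earlier attempts left a gap, and `qc_value` records the widest window that was needed.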

In [2]:
def load_roi(shapefile_path):
    """Load ROI from a shapefile and return as an EE Geometry."""
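    gdf = gpd.read_file(shapefile_path)
    if gdf.crs is not None:
        gdf = gdf.to_crs(epsg=4326)  # Earth Engine expects lon/lat (WGS84)
    # ee.Geometry accepts any GeoJSON geometry, so MultiPolygon ROIs also work
    return ee.Geometry(gdf.unary_union.__geo_interface__)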


def mask_clouds_sentinel(image):
    """Mask clouds using Sentinel-2 Scene Classification Layer (SCL) and AOT."""
    scl = image.select("SCL")
    aot = image.select("AOT").multiply(0.001)

    # Mask out SCL classes 3 (cloud shadow), 8 (cloud medium probability),
    # 9 (cloud high probability), and 10 (thin cirrus)
    cloud_free = scl.neq(3).And(scl.neq(8)).And(scl.neq(9)).And(scl.neq(10))
    clean_image = image.updateMask(cloud_free).updateMask(aot.lt(0.3))

    return clean_image

def create_monthly_composite_with_gap_filling(year, month, roi, 
                                               max_expansion=2, 
                                               coverage_threshold=0.95):
    """
    Creates monthly Sentinel-2 composite with iterative temporal expansion.
    
    If the base month has insufficient coverage (due to clouds/shadows), this function
    iteratively expands the temporal window to ±1 month, then ±2 months, filling only
    the gaps left by the previous attempt.
    
    Parameters:
    -----------
    year : int
        Year for composite
    month : int
        Month for composite (1-12)
    roi : ee.Geometry
        Region of interest
    max_expansion : int, optional (default=2)
        Maximum temporal expansion in months (±N months)
    coverage_threshold : float, optional (default=0.95)
        Minimum required coverage (0.0 to 1.0)
        0.95 = 95% of ROI must have valid pixels
    
    Returns:
    --------
    tuple : (composite, qc_value, actual_coverage)
        composite : ee.Image - Final median composite with all bands
        qc_value : int - Expansion level used (0=base month only, 1=±1 month, 2=±2 months;
                   -1 if the base month has no images, in which case composite is None)
        actual_coverage : float - Final coverage as fraction (0.0 to 1.0)
    """
    
    # Get date range for base month. filterDate() treats the end date as
    # exclusive, so the window ends on the first day of the following month.
    start_date = datetime(year, month, 1)
    end_date = start_date + relativedelta(months=1)
    
    # Convert to GEE date strings
    start_str = start_date.strftime('%Y-%m-%d')
    end_str = end_date.strftime('%Y-%m-%d')
    
    # Step 1: Try base month composite
    logging.info(f"  Attempting base month composite ({year}-{month:02d})...")
    
    collection = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")\
        .filterBounds(roi)\
        .filterDate(start_str, end_str)\
        .map(mask_clouds_sentinel)
    
    # Check if we have any images
    count = collection.size().getInfo()
    if count == 0:
        logging.warning(f"  No Sentinel-2 images available for {year}-{month:02d}")
        # Return empty result
        return None, -1, 0.0
    
    # Create base composite
    base_composite = collection.median().clip(roi)
    base_coverage = calculate_coverage(base_composite, roi)
    
    logging.info(f"  Base month coverage: {base_coverage*100:.1f}%")
    
    # If coverage is sufficient, return base composite with QC=0
    if base_coverage >= coverage_threshold:
        return base_composite, 0, base_coverage
    
    # Step 2: Iteratively expand temporal window
    final_composite = base_composite
    final_coverage = base_coverage
    qc_value = 0
    
    for expansion in range(1, max_expansion + 1):
        logging.info(f"  Coverage insufficient. Expanding to ±{expansion} month(s)...")
        
        # Calculate expanded date range
        expanded_start = start_date - relativedelta(months=expansion)
        expanded_end = end_date + relativedelta(months=expansion)
        
        expanded_start_str = expanded_start.strftime('%Y-%m-%d')
        expanded_end_str = expanded_end.strftime('%Y-%m-%d')
        
        # Get expanded collection
        expanded_collection = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")\
            .filterBounds(roi)\
            .filterDate(expanded_start_str, expanded_end_str)\
            .map(mask_clouds_sentinel)
        
        expanded_count = expanded_collection.size().getInfo()
        if expanded_count == 0:
            logging.warning(f"  No additional images in ±{expansion} month window")
            continue
        
        # Create composite from expanded window
        expanded_composite = expanded_collection.median().clip(roi)
        
        # Fill gaps: unmask() keeps every valid pixel of the current composite
        # and takes values from the expanded composite only where it is masked
        final_composite = final_composite.unmask(expanded_composite)
        
        # Calculate new coverage
        final_coverage = calculate_coverage(final_composite, roi)
        qc_value = expansion
        
        logging.info(f"  Coverage after ±{expansion} month expansion: {final_coverage*100:.1f}%")
        
        # If we've reached threshold, stop
        if final_coverage >= coverage_threshold:
            break
    
    # Final check
    if final_coverage < coverage_threshold:
        logging.warning(f"  Final coverage ({final_coverage*100:.1f}%) still below threshold "
                       f"({coverage_threshold*100:.0f}%) after {max_expansion} month expansion")
    
    return final_composite, qc_value, final_coverage

def calculate_coverage(image, roi):
    """
    Calculate the fraction of valid (non-masked) pixels in an image over a region.
    
    Uses a coarser scale and single band to avoid computation timeouts.
    
    Parameters:
    -----------
    image : ee.Image
        Image to calculate coverage for
    roi : ee.Geometry
        Region of interest
    
    Returns:
    --------
    float : Coverage fraction (0.0 to 1.0)
        1.0 = 100% coverage (all pixels valid)
        0.0 = 0% coverage (all pixels masked)
    """
    
    # Select just one band to check coverage (Red band - B4)
    # This is much faster than checking all bands
    single_band = image.select('B4')
    
    # Create a binary mask: 1 where data exists, 0 where masked
    valid_pixels = single_band.mask()
    
    # Count valid pixels at coarser scale to speed up computation
    # 30m scale is sufficient for coverage estimation
    valid_count = valid_pixels.reduceRegion(
        reducer=ee.Reducer.sum(),
        geometry=roi,
        scale=30,  # Coarser scale for faster computation
        maxPixels=1e13,
        tileScale=4  # Helps with large computations
    ).getInfo()
    
    # Count total pixels in ROI
    total_pixels = ee.Image.constant(1).clip(roi).reduceRegion(
        reducer=ee.Reducer.count(),
        geometry=roi,
        scale=30,  # Match the scale used above
        maxPixels=1e13,
        tileScale=4
    ).getInfo()
    
    # Extract values
    valid_pixel_count = valid_count.get('B4', 0)
    total_pixel_count = total_pixels.get('constant', 0)
    
    # Calculate coverage
    coverage = valid_pixel_count / total_pixel_count if total_pixel_count > 0 else 0.0
    
    return coverage
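
Because `filterDate` excludes its end date, monthly windows are easiest to express as half-open ranges that end on the first day of the following month. The sketch below shows one way to build such windows for the base month and its ±N month expansions using only the standard library. The helpers `shift_month` and `month_window` are hypothetical and not part of the notebook.

```python
from datetime import date


def shift_month(year, month, delta):
    """Return (year, month) moved by `delta` months (delta may be negative)."""
    index = year * 12 + (month - 1) + delta
    return index // 12, index % 12 + 1


def month_window(year, month, expansion=0):
    """Half-open [start, end) window covering the month ± `expansion` months."""
    start = date(*shift_month(year, month, -expansion), 1)
    end = date(*shift_month(year, month, expansion + 1), 1)
    return start.isoformat(), end.isoformat()


for expansion in range(3):
    print(expansion, month_window(2024, 1, expansion))
```

With this convention the last day of every month is included, and year boundaries (for example December to January) need no special handling.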

## Initialize functions for water masking
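
The dynamic threshold in this section rests on locating the valley between the wet and dry modes of a histogram. The following is a simplified sketch of that idea, not the notebook's function: it swaps the Gaussian filter for a moving average and SciPy's peak finder for a plain neighbour comparison so that it runs with the standard library alone. The helper names and the histogram values are invented for illustration.

```python
def moving_average(values, radius=1):
    """Smooth a sequence with a centred window, shrinking it at the edges."""
    smoothed = []
    for i in range(len(values)):
        window = values[max(0, i - radius): i + radius + 1]
        smoothed.append(sum(window) / len(window))
    return smoothed


def local_maxima(values):
    """Indices of interior points strictly higher than both neighbours."""
    return [i for i in range(1, len(values) - 1)
            if values[i - 1] < values[i] > values[i + 1]]


def trough_between_modes(bucket_means, counts, radius=1):
    """Return the bucket mean at the lowest smoothed count between two modes.

    Returns None when fewer than two peaks are found, the case in which a
    fallback such as a fixed percentile would be needed instead.
    """
    smooth = moving_average(counts, radius)
    peaks = local_maxima(smooth)
    if len(peaks) < 2:
        return None
    wet, dry = peaks[0], peaks[1]  # the two leftmost peaks
    valley = min(range(wet, dry + 1), key=lambda i: smooth[i])
    return bucket_means[valley]


# Made-up SWIR2 histogram (DN): a wet mode near 500 and a dry mode near 1700.
bucket_means = [100, 300, 500, 700, 900, 1100, 1300, 1500, 1700, 1900, 2100]
counts = [2, 12, 40, 18, 6, 2, 8, 25, 50, 30, 10]
print(trough_between_modes(bucket_means, counts))
```

For this toy histogram the valley falls in the 1100 DN bucket, midway between the two modes. A unimodal histogram returns `None`.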

In [3]:
def find_bimodal_trough(histogram_data, band_name='SWIR1', smoothing_sigma=2):
    """
    Find the trough (local minimum) between two peaks in a bimodal distribution.
    
    This implements the concept from Inman & Lyons (2020) of finding the natural 
    boundary between wet and dry pixels in the SWIR reflectance histogram. When 
    you plot SWIR values for a wetland image, you typically see two "humps" 
    (peaks): one for water/wet areas (low reflectance) and one for dry land 
    (high reflectance). The valley (trough) between these peaks represents the 
    natural separation point.
    
    Parameters:
    -----------
    histogram_data : dict
        GEE histogram output with structure: 
        {'BandName': {'bucketMeans': [...], 'histogram': [...]}}
    band_name : str
        Key of the band in histogram_data (e.g. 'B12' for Sentinel-2 SWIR2)
    smoothing_sigma : float
        Gaussian smoothing parameter to reduce noise in the histogram
        Higher values = smoother curve but may miss subtle features
        
    Returns:
    --------
    float : The reflectance value at the trough (threshold)
    dict : Diagnostic information about the distribution
    """
    
    # Extract histogram data
    band_data = histogram_data[band_name]
    means = np.array(band_data['bucketMeans'])  # Reflectance values (x-axis)
    counts = np.array(band_data['histogram'])    # Pixel counts (y-axis)
    
    # Smooth the histogram to reduce noise
    # Think of this like drawing a smooth curve through scattered points
    counts_smooth = gaussian_filter1d(counts, sigma=smoothing_sigma)
    
    # Find peaks (the two "humps" in the histogram)
    # prominence ensures we only find significant peaks, not small bumps
    peaks, peak_properties = find_peaks(
        counts_smooth, 
        prominence=np.max(counts_smooth) * 0.1  # Prominence must be at least 10% of the tallest smoothed bin
    )
    
    # If we don't find two clear peaks, fall back to percentile method
    if len(peaks) < 2:
        # Use the 30th percentile as a conservative wet/dry boundary
        cumsum = np.cumsum(counts)
        total = cumsum[-1]
        threshold_idx = np.where(cumsum >= total * 0.30)[0][0]
        threshold = means[threshold_idx]
        
        diagnostics = {
            'method': 'percentile_fallback',
            'threshold': threshold,
            'peaks_found': len(peaks),
            'wet_mode': None,
            'dry_mode': None,
            'reason': 'Bimodal structure not clear - using 30th percentile'
        }
        
        return threshold, diagnostics
    
    # Sort peaks by reflectance value (left to right on histogram)
    peak_indices = peaks[np.argsort(means[peaks])]
    
    # The first peak (leftmost) = wet mode (low reflectance)
    # The second peak from the left = dry mode (higher reflectance)
    # (at least two peaks are guaranteed by the check above)
    wet_peak_idx = peak_indices[0]
    dry_peak_idx = peak_indices[1]
    
    # Find the lowest point (trough) between the two peaks
    search_range = counts_smooth[wet_peak_idx:dry_peak_idx+1]
    local_minima = argrelextrema(search_range, np.less)[0]
    
    if len(local_minima) > 0:
        # Take the deepest minimum (lowest point in the valley)
        trough_local_idx = local_minima[np.argmin(search_range[local_minima])]
        trough_idx = wet_peak_idx + trough_local_idx
        threshold = means[trough_idx]
    else:
        # Fallback: use midpoint between the two peaks
        threshold = (means[wet_peak_idx] + means[dry_peak_idx]) / 2
    
    # Package diagnostic information
    diagnostics = {
        'method': 'bimodal_trough',
        'threshold': threshold,
        'wet_mode': means[wet_peak_idx],           # Reflectance of wet peak
        'wet_mode_count': int(counts[wet_peak_idx]),  # Height of wet peak
        'dry_mode': means[dry_peak_idx],           # Reflectance of dry peak
        'dry_mode_count': int(counts[dry_peak_idx]),  # Height of dry peak
        'peaks_found': len(peaks),
        'trough_position': (threshold - means[wet_peak_idx]) / (means[dry_peak_idx] - means[wet_peak_idx])  # 0-1 scale
    }
    
    return threshold, diagnostics

def calculate_dynamic_swir2_threshold(image, roi, min_swir2=400, max_swir2=1500,
                                       save_plot=True, output_dir=None, 
                                       year=None, month=None):
    """
    Calculate dynamic SWIR2 threshold for a given image based on its histogram.
    
    This function analyzes the distribution of SWIR2 reflectance values across
    your study area and finds the natural separation between wet and dry pixels.
    The threshold is scene-specific and adapts to seasonal flooding conditions.
    
    This simplified version returns only SWIR2 threshold for use in Test 6
    (vegetated inundation enhancement).
    
    ADAPTED FOR SENTINEL-2: Thresholds are in DN units of band B12
    (surface reflectance x 10,000), not in reflectance units.
    
    Parameters:
    -----------
    image : ee.Image
        The Sentinel-2 composite image
    roi : ee.Geometry
        Region of interest (your study area boundary)
    min_swir2, max_swir2 : float
        Safety constraints on SWIR2 threshold (DN units)
        Default range: 400 to 1500 DN
    save_plot : bool, optional (default=True)
        Whether to save histogram plot with threshold
    output_dir : str, optional (default=None)
        Directory to save plots. If None, saves to current directory
    year : int, optional
        Year for plot filename
    month : int, optional
        Month for plot filename
        
    Returns:
    --------
    float : The calculated SWIR2 threshold value (DN units)
    """
    
    # Extract SWIR2 band (B12 for Sentinel-2)
    swir2 = image.select(['B12'])
    
    # Get histogram from Google Earth Engine
    hist_dict = swir2.reduceRegion(
        reducer=ee.Reducer.histogram(maxBuckets=100),
        geometry=roi,
        scale=10,
        maxPixels=1e13
    ).getInfo()
    
    # Prepare histogram data for analysis
    swir2_hist = {'B12': hist_dict['B12']}
    
    # Find the trough (natural boundary) in histogram
    swir2_threshold, swir2_diag = find_bimodal_trough(swir2_hist, 'B12')
    
    # Apply safety constraints to prevent unreasonable values
    swir2_threshold_clipped = np.clip(swir2_threshold, min_swir2, max_swir2)
    
    # Create plot if requested
    if save_plot:
        
        # Extract histogram data
        means = np.array(hist_dict['B12']['bucketMeans'])
        counts = np.array(hist_dict['B12']['histogram'])
        
        # Create figure
        plt.figure(figsize=(10, 6))
        plt.bar(means, counts, width=(means[1] - means[0]) * 0.8, 
                color='steelblue', alpha=0.7, edgecolor='black', linewidth=0.5)
        
        # Add threshold line
        plt.axvline(swir2_threshold_clipped, color='red', linestyle='--', 
                   linewidth=2, label=f'Threshold = {swir2_threshold_clipped:.0f} DN')
        
        # Add wet and dry mode lines if available
        if swir2_diag.get('wet_mode') is not None:
            plt.axvline(swir2_diag['wet_mode'], color='blue', linestyle=':', 
                       linewidth=1.5, alpha=0.7, label=f'Wet Mode = {swir2_diag["wet_mode"]:.0f} DN')
        if swir2_diag.get('dry_mode') is not None:
            plt.axvline(swir2_diag['dry_mode'], color='orange', linestyle=':', 
                       linewidth=1.5, alpha=0.7, label=f'Dry Mode = {swir2_diag["dry_mode"]:.0f} DN')
        
        # Labels and title
        plt.xlabel('SWIR2 DN Value (Band 12)', fontsize=12, fontweight='bold')
        plt.ylabel('Pixel Count', fontsize=12, fontweight='bold')
        
        if year and month:
            plt.title(f'SWIR2 Histogram with Dynamic Threshold (DN Units)\n{year}-{month:02d}', 
                     fontsize=14, fontweight='bold')
        else:
            plt.title('SWIR2 Histogram with Dynamic Threshold (DN Units)', 
                     fontsize=14, fontweight='bold')
        
        plt.legend(loc='upper right', fontsize=10)
        plt.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()
        
        # Determine output directory
        if output_dir is None:
            output_dir = os.getcwd()
        else:
            os.makedirs(output_dir, exist_ok=True)
        
        # Create filename
        if year and month:
            base_filename = f'SWIR2_threshold_DN_{year}_{month:02d}'
        else:
            base_filename = 'SWIR2_threshold_DN'
        
        # Save as PNG and JPEG
        png_path = os.path.join(output_dir, f'{base_filename}.png')
        jpeg_path = os.path.join(output_dir, f'{base_filename}.jpeg')
        
        plt.savefig(png_path, dpi=300, bbox_inches='tight')
        plt.savefig(jpeg_path, dpi=300, bbox_inches='tight', format='jpeg')
        plt.close()
        
        print("Plots saved:")
        print(f"  PNG:  {png_path}")
        print(f"  JPEG: {jpeg_path}")
    
    return float(swir2_threshold_clipped)

def morphological_filter(dswe_image, size_threshold=150, max_class_threshold=2, 
                         roi=None, return_diagnostics=True):
    """
    Remove isolated blobs of low-confidence water classifications that are completely
    surrounded by dry pixels. Preserves any blob containing high-confidence water pixels
    (class > 2) regardless of size, ensuring the main floodplain "megablob" is retained.
    
    ADAPTED FOR SENTINEL-2: Uses scale=10m and adjusted parameters for equivalent area coverage.
    
    A blob is removed if:
    1. It contains no more than size_threshold pixels, AND
    2. All pixels in the blob are class 1 or 2 (max value <= max_class_threshold)
    
    Parameters:
    -----------
    dswe_image : ee.Image
        DSWE classification image (0=no water, 1=low, 2=partial, 3=moderate, 4=high)
    size_threshold : int, optional (default=150)
        Maximum blob size (in pixels) eligible for removal
        At 10m resolution: 150 pixels = 1.5 hectares
        Note: the Landsat version's ~4.5 ha (at 30m) corresponds to 450 pixels at 10m
    max_class_threshold : int, optional (default=2)
        Maximum DSWE class value - blobs with ANY pixel > this are always preserved
        Default of 2 means blobs containing class 3 or 4 are kept regardless of size
    roi : ee.Geometry, optional (default=None)
        Region of interest for diagnostic calculations
        If None, diagnostics cannot be calculated
    return_diagnostics : bool, optional (default=True)
        Whether to return diagnostic information about filtering
        Requires roi to be specified
        
    Returns:
    --------
    ee.Image or tuple:
        If return_diagnostics=False: filtered DSWE image
        If return_diagnostics=True: (filtered_dswe, diagnostics_dict)
        
    Diagnostics dict contains:
        - pixels_removed: int
        - area_removed_km2: float
        - class_1_pixels_removed: int
        - class_2_pixels_removed: int
        - class_3_pixels_removed: int (should be 0)
        - class_4_pixels_removed: int (should be 0)
        - size_threshold_used: int
        - max_class_threshold_used: int
        - percent_water_removed: float
    """
    
    # Step 1: Create binary mask of any water (classes 1-4)
    water_mask = dswe_image.gt(0)
    
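    # Step 2: Label connected components
    # Use 8-connectivity (diagonal neighbors connect) to avoid fragmenting natural wetlands
    # maxSize is the largest object (in pixels) that gets a label; it must be <= 1024.
    # Larger blobs are left unlabeled (masked), so they are never selected for removal.
    # Keep maxSize >= size_threshold, otherwise eligible blobs would go unlabeled.
    labeled = water_mask.connectedComponents(
        connectedness=ee.Kernel.square(1),  # 8-connectivity
        maxSize=256  # Max labeled object size in pixels (not a tile size)
    )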
    
    # Step 3: Add labels band to DSWE image for connected components reduction
    # reduceConnectedComponents needs the labels band on the image being reduced
    dswe_with_labels = dswe_image.addBands(labeled.select('labels'))
    
    # Step 4: Calculate statistics for each blob
    # reduceConnectedComponents maps the blob-level statistic back to every pixel in that blob
    blob_max = dswe_with_labels.reduceConnectedComponents(
        reducer=ee.Reducer.max(),
        labelBand='labels'
    )
    
    blob_count = dswe_with_labels.reduceConnectedComponents(
        reducer=ee.Reducer.count(),
        labelBand='labels'
    )
    
    # Step 5: Create removal mask
    # Identify pixels belonging to blobs that should be removed
    # Note: Band name is 'DSWE' (uppercase) after reduceConnectedComponents
    is_small = blob_count.select('DSWE').lte(size_threshold)
    is_low_confidence = blob_max.select('DSWE').lte(max_class_threshold)
    removal_mask = is_small.And(is_low_confidence)
    
    # Step 6: Apply filter
    # Set pixels in removable blobs to 0 (no water)
    # All other pixels remain unchanged
    filtered_dswe = dswe_image.where(removal_mask, 0)
    
    # Step 7: Calculate diagnostics if requested
    diagnostics = None
    if return_diagnostics and roi is not None:
        try:
            # Count total pixels changed
            changed_pixels = dswe_image.neq(filtered_dswe).And(dswe_image.mask())
            
            # Calculate statistics
            original_stats = dswe_image.gt(0).reduceRegion(
                reducer=ee.Reducer.sum(),
                geometry=roi,
                scale=10,  # Sentinel-2 resolution
                maxPixels=1e13
            ).getInfo()
            
            filtered_stats = filtered_dswe.gt(0).reduceRegion(
                reducer=ee.Reducer.sum(),
                geometry=roi,
                scale=10,  # Sentinel-2 resolution
                maxPixels=1e13
            ).getInfo()
            
            original_water_pixels = original_stats.get('DSWE', 0)
            filtered_water_pixels = filtered_stats.get('DSWE', 0)
            pixels_removed = original_water_pixels - filtered_water_pixels
            area_removed_km2 = pixels_removed * 0.0001  # 10m pixels = 0.0001 km²
            
            # Count pixels by original class that were removed
            class_1_removed = dswe_image.eq(1).And(changed_pixels).reduceRegion(
                reducer=ee.Reducer.sum(),
                geometry=roi,
                scale=10,
                maxPixels=1e13
            ).getInfo().get('DSWE', 0)
            
            class_2_removed = dswe_image.eq(2).And(changed_pixels).reduceRegion(
                reducer=ee.Reducer.sum(),
                geometry=roi,
                scale=10,
                maxPixels=1e13
            ).getInfo().get('DSWE', 0)
            
            # These should always be zero if filter works correctly
            class_3_removed = dswe_image.eq(3).And(changed_pixels).reduceRegion(
                reducer=ee.Reducer.sum(),
                geometry=roi,
                scale=10,
                maxPixels=1e13
            ).getInfo().get('DSWE', 0)
            
            class_4_removed = dswe_image.eq(4).And(changed_pixels).reduceRegion(
                reducer=ee.Reducer.sum(),
                geometry=roi,
                scale=10,
                maxPixels=1e13
            ).getInfo().get('DSWE', 0)
            
            diagnostics = {
                'pixels_removed': int(pixels_removed) if pixels_removed else 0,
                'area_removed_km2': round(area_removed_km2, 2) if area_removed_km2 else 0.0,
                'class_1_pixels_removed': int(class_1_removed) if class_1_removed else 0,
                'class_2_pixels_removed': int(class_2_removed) if class_2_removed else 0,
                'class_3_pixels_removed': int(class_3_removed) if class_3_removed else 0,
                'class_4_pixels_removed': int(class_4_removed) if class_4_removed else 0,
                'size_threshold_used': size_threshold,
                'max_class_threshold_used': max_class_threshold,
                'percent_water_removed': round(100 * pixels_removed / original_water_pixels, 2) if original_water_pixels > 0 else 0.0
            }
            
        except Exception as e:
            logging.warning(f"Could not calculate diagnostics: {e}")
            diagnostics = {
                'pixels_removed': 0,
                'area_removed_km2': 0.0,
                'class_1_pixels_removed': 0,
                'class_2_pixels_removed': 0,
                'class_3_pixels_removed': 0,
                'class_4_pixels_removed': 0,
                'size_threshold_used': size_threshold,
                'max_class_threshold_used': max_class_threshold,
                'percent_water_removed': 0.0,
                'error': str(e)
            }
    
    # Add metadata to filtered image
    filtered_dswe = filtered_dswe.set({
        'morphological_filter_applied': True,
        'blob_size_threshold': size_threshold,
        'blob_max_class_threshold': max_class_threshold,
        'sensor': 'Sentinel-2',
        'resolution_m': 10
    })
    
    if return_diagnostics:
        return filtered_dswe, diagnostics
    else:
        return filtered_dswe

def apply_dswe(image):
    """Apply DSWE classification to a Sentinel-2 image."""
    blue = image.select('B2')
    green = image.select('B3')
    red = image.select('B4')
    nir = image.select('B8')
    swir1 = image.select('B11')
    swir2 = image.select('B12')

    mndwi = green.subtract(swir1).divide(green.add(swir1)).rename("MNDWI")
    ndvi = nir.subtract(red).divide(nir.add(red)).rename("NDVI")
    mbsrv = green.add(red).rename("MBSRV")
    mbsrn = nir.add(swir1).rename("MBSRN")
    awesh = blue.add(green.multiply(2.5)).subtract(mbsrn.multiply(1.5)).subtract(swir2.multiply(0.25)).rename("AWESH")

    t1 = mndwi.gt(0.124)
    t2 = mbsrv.gt(mbsrn)
    t3 = awesh.gt(0)
    t4 = (mndwi.gt(-0.44)).And(swir1.lt(900)).And(nir.lt(1500)).And(ndvi.lt(0.7))
    # Test 5 in Jones (2019) thresholds the blue band, not green
    t5 = (mndwi.gt(-0.5)).And(blue.lt(1000)).And(swir1.lt(3000)).And(swir2.lt(1000)).And(nir.lt(2500))

    dswe = (t1.multiply(1)
            .add(t2.multiply(10))
            .add(t3.multiply(100))
            .add(t4.multiply(1000))
            .add(t5.multiply(10000)))

    no_water = dswe.eq(0).Or(dswe.eq(1)).Or(dswe.eq(10)).Or(dswe.eq(100)).Or(dswe.eq(1000))
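    # 11011 was missing, so pixels passing tests 1, 2, 4 and 5 fell through to class 0
    high_conf_water = dswe.eq(1111).Or(dswe.eq(10111)).Or(dswe.eq(11011)).Or(dswe.eq(11101))\
        .Or(dswe.eq(11110)).Or(dswe.eq(11111))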
    moderate_conf_water = dswe.eq(111).Or(dswe.eq(1011)).Or(dswe.eq(1101)).Or(dswe.eq(1110))\
        .Or(dswe.eq(10011)).Or(dswe.eq(10101)).Or(dswe.eq(10110)).Or(dswe.eq(11001))\
        .Or(dswe.eq(11010)).Or(dswe.eq(11100))
    potential_wetland = dswe.eq(11000)
    low_conf_water = dswe.eq(11).Or(dswe.eq(101)).Or(dswe.eq(110))\
        .Or(dswe.eq(1001)).Or(dswe.eq(1010)).Or(dswe.eq(1100))\
        .Or(dswe.eq(10000)).Or(dswe.eq(10001)).Or(dswe.eq(10010)).Or(dswe.eq(10100))

    dswe_final = (no_water.multiply(0)
                  .add(high_conf_water.multiply(4))
                  .add(moderate_conf_water.multiply(3))
                  .add(potential_wetland.multiply(2))
                  .add(low_conf_water.multiply(1))
                  .rename("DSWE"))

    return dswe_final

def Dswe_with_Test6(image, roi, min_swir2=400, max_swir2=1500, 
                     save_plot=True, output_dir=None, year=None, month=None):
    """
    Calculate DSWE classification with Test 6 enhancement for vegetated inundation.
    
    This function applies the standard DSWE algorithm, then upgrades class confidence
    for pixels that also pass Test 6 (SWIR2 < dynamic threshold). This enhancement
    is designed to better capture water beneath dense vegetation (e.g., papyrus swamps)
    where traditional spectral indices may fail but SWIR2 still indicates moisture.
    
    ADAPTED FOR SENTINEL-2: Uses DN values from B12 band instead of reflectance.
    
    Upgrade logic:
    - Class 0 (No Water) + Test 6 pass → Class 1 (Low Water)
    - Class 1 (Low Water) + Test 6 pass → Class 2 (Partial Wetland)
    - Class 2 (Partial Wetland) + Test 6 pass → Class 3 (Moderate Water)
    - Class 3 (Moderate Water) + Test 6 pass → Class 4 (High Water)
    - Class 4 (High Water) → Remains Class 4 (no change)
    
    Parameters:
    -----------
    image : ee.Image
        Sentinel-2 composite with standard bands
    roi : ee.Geometry
        Region of interest for threshold calculation
    min_swir2, max_swir2 : float
        Safety constraints on SWIR2 threshold (DN units)
        Default range: 400 to 1500 DN
    save_plot : bool, optional (default=True)
        Whether to save SWIR2 histogram plot
    output_dir : str, optional (default=None)
        Directory to save plots
    year : int, optional
        Year for metadata and plot filename
    month : int, optional
        Month for metadata and plot filename
        
    Returns:
    --------
    tuple : (upgraded_classification, original_classification, swir2_threshold)
        - upgraded_classification: ee.Image with Test 6 upgrades applied
        - original_classification: ee.Image with standard DSWE (for comparison)
        - swir2_threshold: float, the calculated SWIR2 threshold value (DN)
    """
    
    # Step 1: Run standard DSWE algorithm (using Sentinel-2 version)
    original_dswe = apply_dswe(image)
    
    # Step 2: Calculate dynamic SWIR2 threshold
    swir2_threshold = calculate_dynamic_swir2_threshold(
        image, roi, 
        min_swir2=min_swir2, 
        max_swir2=max_swir2,
        save_plot=save_plot,
        output_dir=output_dir,
        year=year,
        month=month
    )
    
    # Step 3: Create Test 6 (SWIR2 < threshold)
    swir2 = image.select(['B12'])
    test6 = swir2.lt(swir2_threshold)
    
    # Step 4: Apply upgrade logic
    # Start with original classification
    upgraded_dswe = original_dswe
    
    # Upgrade class 0 → 1 if Test 6 passes
    upgraded_dswe = upgraded_dswe.where(
        original_dswe.eq(0).And(test6), 
        1
    )
    
    # Upgrade class 1 → 2 if Test 6 passes
    upgraded_dswe = upgraded_dswe.where(
        original_dswe.eq(1).And(test6), 
        2
    )
    
    # Upgrade class 2 → 3 if Test 6 passes
    upgraded_dswe = upgraded_dswe.where(
        original_dswe.eq(2).And(test6), 
        3
    )
    
    # Upgrade class 3 → 4 if Test 6 passes
    upgraded_dswe = upgraded_dswe.where(
        original_dswe.eq(3).And(test6), 
        4
    )
    
    # Class 4 remains unchanged (no .where() operation needed)
    
    # Step 5: Add metadata to both images
    metadata = {
        'swir2_threshold_DN': swir2_threshold,
        'test6_applied': True,
        'algorithm': 'DSWE_with_Test6_Sentinel2',
        'sensor': 'Sentinel-2',
        'resolution_m': 10
    }
    
    if year is not None:
        metadata['year'] = year
    if month is not None:
        metadata['month'] = month
    
    upgraded_dswe = upgraded_dswe.set(metadata).rename(['DSWE'])
    original_dswe = original_dswe.set({
        'swir2_threshold_DN': swir2_threshold,
        'test6_applied': False,
        'algorithm': 'DSWE_standard_Sentinel2',
        'sensor': 'Sentinel-2',
        'resolution_m': 10
    }).rename(['DSWE'])
    
    return upgraded_dswe, original_dswe, swir2_threshold
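
The class assignment in `apply_dswe` and the Test 6 upgrade can be checked outside Earth Engine. The sketch below mirrors the five-digit test codes with plain Python sets and confirms that all 32 possible outcomes map to a class. It is an illustration under stated assumptions, not the notebook's code: the names `CLASS_BY_CODE`, `dswe_class`, and `upgrade_with_test6` are hypothetical, and code 11011 is placed with the other four-test outcomes in the high-confidence class.

```python
from itertools import product

# Test-result codes: digit k (counted from the right) is 1 when test k passed.
NO_WATER = {0, 1, 10, 100, 1000}
HIGH = {1111, 10111, 11011, 11101, 11110, 11111}
MODERATE = {111, 1011, 1101, 1110, 10011, 10101, 10110, 11001, 11010, 11100}
POTENTIAL_WETLAND = {11000}
LOW = {11, 101, 110, 1001, 1010, 1100, 10000, 10001, 10010, 10100}

# Class values follow the notebook's ordering: 0=no water ... 4=high confidence.
CLASS_BY_CODE = {}
for value, codes in ((0, NO_WATER), (1, LOW), (2, POTENTIAL_WETLAND),
                     (3, MODERATE), (4, HIGH)):
    for code in codes:
        CLASS_BY_CODE[code] = value


def dswe_class(tests):
    """Class (0-4) for a tuple of five test outcomes, ordered test 1 to 5."""
    code = sum(int(bool(t)) * 10 ** k for k, t in enumerate(tests))
    return CLASS_BY_CODE[code]


def upgrade_with_test6(dswe_value, test6_passed):
    """Raise the class by one step when Test 6 passes, capped at class 4."""
    return min(dswe_value + 1, 4) if test6_passed else dswe_value


# Every one of the 2**5 outcomes should map to exactly one class.
all_codes = {sum(b * 10 ** k for k, b in enumerate(bits))
             for bits in product((0, 1), repeat=5)}
print(all_codes == set(CLASS_BY_CODE))
print(dswe_class((1, 1, 0, 0, 0)), upgrade_with_test6(1, True))
```

A completeness check of this kind is a cheap way to catch a test combination that is missing from the lookup, since such a pixel would otherwise be silently classified as no water.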

## Implement functions for exporting

In [4]:
def export_to_asset(image, year, month, asset_folder, roi, 
                    swir2_threshold=None, qc_value=None, 
                    morpho_diagnostics=None, size_threshold=150, 
                    max_class_threshold=2):
    """
    Export DSWE composite to GEE asset with comprehensive metadata.
    
    Parameters:
    -----------
    image : ee.Image
        DSWE classification image to export
    year : int
        Year of the data
    month : int
        Month of the data (1-12)
    asset_folder : str
        GEE asset folder path
    roi : ee.Geometry
        Region of interest for export
    swir2_threshold : float, optional
        SWIR2 threshold value (DN) used in Test 6
    qc_value : int, optional
        Quality control value (0=base month, 1=±1 month, 2=±2 months)
    morpho_diagnostics : dict, optional
        Dictionary containing morphological filter diagnostics
    size_threshold : int, optional (default=150)
        Blob size threshold used in morphological filter
    max_class_threshold : int, optional (default=2)
        Max class threshold used in morphological filter
    """
    
    asset_id = f"{asset_folder}/DSWE_{year}_{month:02d}"
    
    # Check if asset already exists
    try:
        ee.data.getAsset(asset_id)
        logging.info(f"Skipping {asset_id}, already exists.")
        return
    except ee.EEException:
        pass  # Asset lookup failed (typically: it does not exist), continue with export
    
    # Build comprehensive metadata dictionary
    metadata = {
        'year': year,
        'month': month,
        'algorithm': 'DSWE_with_Test6_Sentinel2',
        'sensor': 'Sentinel-2',
        'resolution_m': 10,
        'test6_applied': swir2_threshold is not None,
        'morphological_filter_applied': morpho_diagnostics is not None,
        'processing_date': datetime.now().isoformat()
    }
    
    # Add SWIR2 threshold if provided
    if swir2_threshold is not None:
        metadata['swir2_threshold_DN'] = float(swir2_threshold)
    
    # Add QC value if provided
    if qc_value is not None:
        metadata['qc_temporal_expansion'] = int(qc_value)
        metadata['qc_description'] = f"0=base month, 1=±1 month, 2=±2 months (value={qc_value})"
    
    # Add morphological filter parameters
    metadata['blob_size_threshold'] = int(size_threshold)
    metadata['blob_max_class_threshold'] = int(max_class_threshold)
    
    # Add morphological filter diagnostics if provided
    if morpho_diagnostics is not None:
        metadata['pixels_removed'] = int(morpho_diagnostics.get('pixels_removed', 0))
        metadata['area_removed_km2'] = float(morpho_diagnostics.get('area_removed_km2', 0.0))
        metadata['class_1_pixels_removed'] = int(morpho_diagnostics.get('class_1_pixels_removed', 0))
        metadata['class_2_pixels_removed'] = int(morpho_diagnostics.get('class_2_pixels_removed', 0))
        metadata['class_3_pixels_removed'] = int(morpho_diagnostics.get('class_3_pixels_removed', 0))
        metadata['class_4_pixels_removed'] = int(morpho_diagnostics.get('class_4_pixels_removed', 0))
        metadata['percent_water_removed'] = float(morpho_diagnostics.get('percent_water_removed', 0.0))
    
    # Set metadata on image
    image_with_metadata = image.set(metadata)
    
    # Export to asset
    task = ee.batch.Export.image.toAsset(
        image=image_with_metadata.select(["DSWE"]),
        description=f"DSWE_{year}_{month:02d}",
        assetId=asset_id,
        scale=10,
        region=roi,
        maxPixels=1e13
    )
    task.start()
    logging.info(f"Exporting {asset_id}...")
    
def export_qc_raster(qc_image, year, month, asset_folder, roi):
    """
    Export quality control raster showing temporal expansion used.
    
    Parameters:
    -----------
    qc_image : ee.Image
        QC raster with values 0, 1, or 2
    year : int
        Year of the data
    month : int
        Month of the data (1-12)
    asset_folder : str
        GEE asset folder path for QC products
    roi : ee.Geometry
        Region of interest for export
    """
    
    asset_id = f"{asset_folder}/QC_{year}_{month:02d}"
    
    # Check if asset already exists
    try:
        ee.data.getAsset(asset_id)
        logging.info(f"Skipping {asset_id}, already exists.")
        return
    except ee.EEException:
        pass  # Asset lookup failed (typically: it does not exist), continue with export
    
    # Add metadata
    qc_with_metadata = qc_image.set({
        'year': year,
        'month': month,
        'product_type': 'Quality_Control',
        'description': 'Temporal expansion used: 0=base month, 1=±1 month, 2=±2 months',
        'sensor': 'Sentinel-2',
        'processing_date': datetime.now().isoformat()
    })
    
    # Export
    task = ee.batch.Export.image.toAsset(
        image=qc_with_metadata,
        description=f"QC_{year}_{month:02d}",
        assetId=asset_id,
        scale=10,
        region=roi,
        maxPixels=1e13
    )
    task.start()
    logging.info(f"Exporting {asset_id}...")
    
def export_composite_image(image, year, month, composite_folder, roi, qc_value=None):
    """
    Export the median composite image (RGB bands only) to GEE asset.
    
    Parameters:
    -----------
    image : ee.Image
        Sentinel-2 composite image
    year : int
        Year of the data
    month : int
        Month of the data (1-12)
    composite_folder : str
        GEE asset folder path for composites
    roi : ee.Geometry
        Region of interest for export
    qc_value : int, optional
        Quality control value indicating temporal expansion used
    """
    
    asset_id = f"{composite_folder}/Composite_{year}_{month:02d}"
    
    # Check if asset already exists
    try:
        ee.data.getAsset(asset_id)
        logging.info(f"Skipping {asset_id}, already exists.")
        return
    except ee.EEException:
        pass  # Asset lookup failed (typically: it does not exist), continue with export
    
    # Select RGB bands (Red, Green, Blue)
    rgb_image = image.select(['B4', 'B3', 'B2'])
    
    # Build metadata
    metadata = {
        'year': year,
        'month': month,
        'product_type': 'RGB_Composite',
        'bands': 'B4,B3,B2 (Red,Green,Blue)',
        'sensor': 'Sentinel-2',
        'resolution_m': 10,
        'processing_date': datetime.now().isoformat()
    }
    
    # Add QC value if provided
    if qc_value is not None:
        metadata['qc_temporal_expansion'] = int(qc_value)
        metadata['composite_date_range'] = f"±{qc_value} months from base"
    
    # Set metadata
    rgb_with_metadata = rgb_image.set(metadata)
    
    # Export
    task = ee.batch.Export.image.toAsset(
        image=rgb_with_metadata,
        description=f"Composite_RGB_{year}_{month:02d}",
        assetId=asset_id,
        scale=10,
        region=roi,
        maxPixels=1e13
    )
    task.start()
    logging.info(f"Exporting RGB composite to {asset_id}...")
    
def process_monthly_dswe(start_date, end_date, roi, 
                         mask_asset_folder, 
                         composite_asset_folder,
                         qc_asset_folder,
                         output_dir='./swir2_histograms_DN',
                         min_swir2_dn=400,
                         max_swir2_dn=1500,
                         morpho_size_threshold=150,
                         morpho_class_threshold=2,
                         max_temporal_expansion=2,
                         coverage_threshold=0.95,
                         save_histograms=True):
    """
    Generate and export enhanced DSWE products for each month with Test 6 and morphological filtering.
    
    Parameters:
    -----------
    start_date : datetime
        Start date for processing
    end_date : datetime
        End date for processing
    roi : ee.Geometry
        Region of interest
    mask_asset_folder : str
        GEE asset folder for DSWE products
    composite_asset_folder : str
        GEE asset folder for RGB composites
    qc_asset_folder : str
        GEE asset folder for QC rasters
    output_dir : str, optional (default='./swir2_histograms_DN')
        Directory to save SWIR2 histogram plots
    min_swir2_dn : float, optional (default=400)
        Minimum SWIR2 threshold (DN)
    max_swir2_dn : float, optional (default=1500)
        Maximum SWIR2 threshold (DN)
    morpho_size_threshold : int, optional (default=150)
        Blob size threshold for morphological filter (pixels)
    morpho_class_threshold : int, optional (default=2)
        Max class threshold for morphological filter
    max_temporal_expansion : int, optional (default=2)
        Maximum temporal expansion for gap filling (±N months)
    coverage_threshold : float, optional (default=0.95)
        Minimum coverage threshold (0.0 to 1.0)
    save_histograms : bool, optional (default=True)
        Whether to save SWIR2 histogram plots
    """
    
    # Create output directory for histograms
    if save_histograms:
        os.makedirs(output_dir, exist_ok=True)
        logging.info(f"SWIR2 histograms will be saved to: {output_dir}")
    
    current_date = start_date
    
    while current_date <= end_date:
        year = current_date.year
        month = current_date.month
        
        logging.info(f"\n{'='*60}")
        logging.info(f"Processing {year}-{month:02d}")
        logging.info(f"{'='*60}")
        
        try:
            # ================================================================
            # Step 1: Create composite with gap filling
            # ================================================================
            logging.info("Step 1: Creating composite with gap filling...")
            
            composite, qc_value, coverage = create_monthly_composite_with_gap_filling(
                year=year,
                month=month,
                roi=roi,
                max_expansion=max_temporal_expansion,
                coverage_threshold=coverage_threshold
            )
            
            # Check if composite was created
            if composite is None:
                logging.error(f"  No data available for {year}-{month:02d}. Skipping.")
                current_date = (current_date.replace(day=28) + timedelta(days=4)).replace(day=1)
                continue
            
            logging.info(f"  ✓ Composite created successfully")
            logging.info(f"    QC Value: {qc_value} (0=base month, 1=±1 month, 2=±2 months)")
            logging.info(f"    Final Coverage: {coverage*100:.1f}%")
            
            # Warning if coverage is still below threshold
            if coverage < coverage_threshold:
                logging.warning(f"  ⚠ Coverage ({coverage*100:.1f}%) below threshold "
                              f"({coverage_threshold*100:.0f}%) after {max_temporal_expansion} month expansion")
            
            # ================================================================
            # Step 2: Calculate dynamic SWIR2 threshold
            # ================================================================
            logging.info("Step 2: Calculating dynamic SWIR2 threshold...")
            
            swir2_threshold = calculate_dynamic_swir2_threshold(
                composite, 
                roi,
                min_swir2=min_swir2_dn,
                max_swir2=max_swir2_dn,
                save_plot=save_histograms,
                output_dir=output_dir,
                year=year,
                month=month
            )
            
            logging.info(f"  ✓ SWIR2 Threshold: {swir2_threshold:.0f} DN")
            
            # ================================================================
            # Step 3: Apply DSWE with Test 6
            # ================================================================
            logging.info("Step 3: Applying DSWE with Test 6...")
            
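            # Dswe_with_Test6 recomputes the SWIR2 threshold internally, so its
            # third return value duplicates Step 2 and is ignored here.
            upgraded_dswe, original_dswe, _ = Dswe_with_Test6(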
                composite,
                roi,
                min_swir2=min_swir2_dn,
                max_swir2=max_swir2_dn,
                save_plot=False,  # Already saved in Step 2
                output_dir=output_dir,
                year=year,
                month=month
            )
            
            logging.info(f"  ✓ DSWE classification complete (with Test 6 enhancement)")
            
            # ================================================================
            # Step 4: Apply morphological filter with diagnostics
            # ================================================================
            logging.info("Step 4: Applying morphological filter...")
            
            filtered_dswe, diagnostics = morphological_filter(
                upgraded_dswe,
                size_threshold=morpho_size_threshold,
                max_class_threshold=morpho_class_threshold,
                roi=roi,
                return_diagnostics=True
            )
            
            logging.info(f"  ✓ Morphological filter applied")
            
            # ================================================================
            # Step 5: Log diagnostics
            # ================================================================
            logging.info("Step 5: Filter Diagnostics:")
            logging.info(f"  Pixels removed: {diagnostics['pixels_removed']:,}")
            logging.info(f"  Area removed: {diagnostics['area_removed_km2']:.2f} km²")
            logging.info(f"  Percent water removed: {diagnostics['percent_water_removed']:.2f}%")
            logging.info(f"  Class breakdown:")
            logging.info(f"    Class 1 removed: {diagnostics['class_1_pixels_removed']:,} pixels")
            logging.info(f"    Class 2 removed: {diagnostics['class_2_pixels_removed']:,} pixels")
            logging.info(f"    Class 3 removed: {diagnostics['class_3_pixels_removed']:,} pixels (should be 0)")
            logging.info(f"    Class 4 removed: {diagnostics['class_4_pixels_removed']:,} pixels (should be 0)")
            
            # Sanity check: warn if high-confidence pixels were removed
            if diagnostics['class_3_pixels_removed'] > 0 or diagnostics['class_4_pixels_removed'] > 0:
                logging.warning(f"  ⚠ High-confidence water pixels were removed! Check filter parameters.")
            
            # ================================================================
            # Step 6: Create QC raster
            # ================================================================
            logging.info("Step 6: Creating QC raster...")
            
            qc_raster = ee.Image.constant(qc_value).clip(roi).rename('QC').toInt8()
            
            logging.info(f"  ✓ QC raster created (value={qc_value})")
            
            # ================================================================
            # Step 7: Export all products
            # ================================================================
            logging.info("Step 7: Exporting products to GEE...")
            
            # Export DSWE product
            export_to_asset(
                filtered_dswe, 
                year, 
                month, 
                mask_asset_folder, 
                roi,
                swir2_threshold=swir2_threshold,
                qc_value=qc_value,
                morpho_diagnostics=diagnostics,
                size_threshold=morpho_size_threshold,
                max_class_threshold=morpho_class_threshold
            )
            
            # Export QC raster
            export_qc_raster(
                qc_raster, 
                year, 
                month, 
                qc_asset_folder, 
                roi
            )
            
            # Export RGB composite
            export_composite_image(
                composite, 
                year, 
                month, 
                composite_asset_folder, 
                roi,
                qc_value=qc_value
            )
            
            logging.info(f"  ✓ All exports initiated successfully")
            logging.info(f"✓ Processing complete for {year}-{month:02d}\n")
            
        except Exception as e:
            logging.error(f"✗ Failed to process {year}-{month:02d}: {str(e)}")
            logging.exception("Full traceback:")
        
        # Move to next month
        current_date = (current_date.replace(day=28) + timedelta(days=4)).replace(day=1)
    
    logging.info(f"\n{'='*60}")
    logging.info("Processing complete for all months")
    logging.info(f"{'='*60}\n")
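
The month-stepping expression used twice in `process_monthly_dswe` is compact enough to deserve a short illustration. Day 28 exists in every month, and adding four days to it always lands in the following month, so snapping back to day 1 yields the first day of the next month. The sketch below wraps that expression in a generator; `iter_months` is a hypothetical name, not a function from this notebook.

```python
from datetime import datetime, timedelta


def iter_months(start_date, end_date):
    """Yield (year, month) for each month visited between the two dates.

    Hypothetical helper that mirrors the loop in process_monthly_dswe.
    """
    current = start_date
    while current <= end_date:
        yield current.year, current.month
        # Jump to day 28, step 4 days into the next month, snap to day 1.
        current = (current.replace(day=28) + timedelta(days=4)).replace(day=1)


# A range that crosses a year boundary.
months = list(iter_months(datetime(2024, 11, 1), datetime(2025, 2, 15)))
# The single-month range used in the run cell of this notebook.
single = list(iter_months(datetime(2024, 8, 1), datetime(2024, 8, 31)))

print(months)
print(single)
```

One possible benefit of a generator like this is that the loop body no longer needs to advance the date itself, so an early `continue` cannot skip the increment.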

## Input pathnames and parameters, then run

In [11]:
# Parameters
start_date = datetime(2024, 8, 1)
end_date = datetime(2024, 8, 31)
study_area_path = r"C:\Users\huckr\Desktop\UCSB\Okavango\Data\StudyAreas\Delta_UCB\Delta_UCB_WGS84.shp"

# Asset folders (create these in GEE first!)
mask_asset_folder = "projects/ee-okavango/assets/water_masks/monthly_DSWE_Sent2_10m_v3/DSWE_Products"
composite_asset_folder = "projects/ee-okavango/assets/water_masks/monthly_DSWE_Sent2_10m_v3/Source_S2_Compositesxx"
qc_asset_folder = "projects/ee-okavango/assets/water_masks/monthly_DSWE_Sent2_10m_v3/QC_Rastersxx"

# Load ROI
roi = load_roi(study_area_path)

# Run processing
process_monthly_dswe(
    start_date=start_date,
    end_date=end_date,
    roi=roi,
    mask_asset_folder=mask_asset_folder,
    composite_asset_folder=composite_asset_folder,
    qc_asset_folder=qc_asset_folder,
    output_dir=r'D:\Okavango\Data\Water_Masks\Sentinel2\S2_SWIR2_histograms',
    min_swir2_dn=400,
    max_swir2_dn=1500,
    morpho_size_threshold=150,
    morpho_class_threshold=2,
    max_temporal_expansion=2,
    coverage_threshold=0.95,
    save_histograms=False
)


2025-12-03 10:55:47,124 - INFO - 
2025-12-03 10:55:47,125 - INFO - Processing 2024-08
2025-12-03 10:55:47,127 - INFO - Step 1: Creating composite with gap filling...
2025-12-03 10:55:47,129 - INFO -   Attempting base month composite (2024-08)...
2025-12-03 10:56:06,185 - INFO -   Base month coverage: 100.0%
2025-12-03 10:56:06,187 - INFO -   ✓ Composite created successfully
2025-12-03 10:56:06,189 - INFO -     QC Value: 0 (0=base month, 1=±1 month, 2=±2 months)
2025-12-03 10:56:06,193 - INFO -     Final Coverage: 100.0%
2025-12-03 10:56:06,194 - INFO - Step 2: Calculating dynamic SWIR2 threshold...
2025-12-03 10:59:23,765 - INFO -   ✓ SWIR2 Threshold: 1500 DN
2025-12-03 10:59:23,767 - INFO - Step 3: Applying DSWE with Test 6...
2025-12-03 10:59:24,327 - INFO -   ✓ DSWE classification complete (with Test 6 enhancement)
2025-12-03 10:59:24,329 - INFO - Step 4: Applying morphological filter...
2025-12-03 11:07:21,972 - INFO -   ✓ Morphological filter applied
2025-12-03 11:07:21,975 - INFO
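
The timestamps in the log above show where the run spends its time. As a minimal sketch, the snippet below measures the gap between each `Step N:` line and the log line that follows it. The six log lines are copied from the output above; `parse_timestamp` and `step_durations` are hypothetical helpers, and the parsing assumes the `%(asctime)s - %(levelname)s - %(message)s` format configured at the top of this notebook.

```python
from datetime import datetime

# Lines copied from the run log above.
LOG_LINES = """\
2025-12-03 10:56:06,194 - INFO - Step 2: Calculating dynamic SWIR2 threshold...
2025-12-03 10:59:23,765 - INFO -   ✓ SWIR2 Threshold: 1500 DN
2025-12-03 10:59:23,767 - INFO - Step 3: Applying DSWE with Test 6...
2025-12-03 10:59:24,327 - INFO -   ✓ DSWE classification complete (with Test 6 enhancement)
2025-12-03 10:59:24,329 - INFO - Step 4: Applying morphological filter...
2025-12-03 11:07:21,972 - INFO -   ✓ Morphological filter applied
""".splitlines()


def parse_timestamp(line):
    """Parse the leading 'YYYY-MM-DD HH:MM:SS,mmm' stamp written by logging."""
    return datetime.strptime(line[:23], "%Y-%m-%d %H:%M:%S,%f")


def step_durations(lines):
    """Return {step label: seconds until the next log line} for 'Step N:' lines."""
    durations = {}
    for current, following in zip(lines, lines[1:]):
        message = current.split(" - ", 2)[2]
        if message.startswith("Step "):
            label = message.split(":", 1)[0]
            elapsed = parse_timestamp(following) - parse_timestamp(current)
            durations[label] = elapsed.total_seconds()
    return durations


durations = step_durations(LOG_LINES)
for label, seconds in durations.items():
    print(f"{label}: {seconds:.1f} s")
```

For this run, the threshold calculation (Step 2) takes a little over three minutes and the morphological filter (Step 4) close to eight minutes, while Step 3 returns in under a second.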

## Monitor GEE tasks

In [15]:
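# Get the list of recent GEE tasks (all states, not only running ones)
task_list = ee.batch.Task.list()

# Print task statuses; call status() once per task because each call queries the server
for task in task_list:
    status = task.status()
    print(f"Task: {status['description']}, Status: {status['state']}")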

Task: Composite_RGB_2024_08, Status: FAILED
Task: QC_2024_08, Status: FAILED
Task: DSWE_2024_08, Status: RUNNING
Task: Composite_RGB_2024_08, Status: FAILED
Task: QC_2024_08, Status: FAILED
Task: DSWE_2024_08, Status: CANCELLED
Task: Composite_RGB_2024_08, Status: FAILED


KeyboardInterrupt: 
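
When the same month is submitted several times, the flat task list becomes hard to read. The sketch below groups states by task description. It works on plain dictionaries with the `description` and `state` keys printed by the cell above, so it runs without an Earth Engine session; the sample entries reproduce the seven lines of output shown, and `summarize_task_states` is a hypothetical helper.

```python
from collections import Counter, defaultdict


def summarize_task_states(statuses):
    """Count task states per description from status-like dictionaries."""
    summary = defaultdict(Counter)
    for status in statuses:
        summary[status["description"]][status["state"]] += 1
    return dict(summary)


# Sample entries reproducing the task output printed above.
sample_statuses = [
    {"description": "Composite_RGB_2024_08", "state": "FAILED"},
    {"description": "QC_2024_08", "state": "FAILED"},
    {"description": "DSWE_2024_08", "state": "RUNNING"},
    {"description": "Composite_RGB_2024_08", "state": "FAILED"},
    {"description": "QC_2024_08", "state": "FAILED"},
    {"description": "DSWE_2024_08", "state": "CANCELLED"},
    {"description": "Composite_RGB_2024_08", "state": "FAILED"},
]

summary = summarize_task_states(sample_statuses)
overall = Counter(status["state"] for status in sample_statuses)

for description, states in summary.items():
    print(description, dict(states))
print("Overall:", dict(overall))
```

With a live session, the input could be built as `[task.status() for task in ee.batch.Task.list()]`. The status dictionary of a failed task typically also includes an `error_message` entry, which is the first place to look when exports such as the composite and QC tasks above report `FAILED`.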