In [1]:
================================================================================
WETLAND PERMANENCE CLASSIFICATION - FINAL CORRECTED VERSION
================================================================================
Brookings County, South Dakota - 2024
Method: NDWI Coefficient of Variation Classification
Author: [Image Bhattarai]
Date: November 2024


- Fixed CDL codes (using only 111, 190, 195)
- Added edge buffering (2-pixel erosion) to reduce mixed pixel effects
- Improved patch counting (4-connectivity)
- Added data validation and quality checks
- Exports: Classification GeoTIFF, CV GeoTIFF, Excel workbook
- Enhanced error handling and reporting
================================================================================
"""

import numpy as np
import rasterio
from rasterio.plot import show
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import seaborn as sns
from scipy.ndimage import label, find_objects, binary_erosion
import pandas as pd
import warnings
import os
from datetime import datetime
from rasterio.warp import reproject, Resampling

warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 300

print("="*80)
print("WETLAND PERMANENCE CLASSIFICATION - FINAL CORRECTED VERSION")
print("Brookings County, South Dakota - 2024")
print("="*80)
print(f"Analysis started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# ============================================================================
# CONFIGURATION
# ============================================================================

base_path = "/content/drive/MyDrive/GIS-project/Colab_8Class_Unsupervised_Outputs/01_Mosaicked_Indices/"
cdl_path = "/content/drive/MyDrive/cdl_2024_brookings.tif"
output_path = "/content/drive/MyDrive/Wetland_Report_Final_Corrected/"

os.makedirs(output_path, exist_ok=True)
print(f"\n‚úì Output directory: {output_path}")

# Analysis parameters
LOW_CV_THRESHOLD = 20
HIGH_CV_THRESHOLD = 50
MIN_PATCH_SIZE_HA = 0.01
EDGE_BUFFER_PIXELS = 2  # 2-pixel erosion to remove edge effects

print(f"\nAnalysis Parameters:")
print(f"  CV Thresholds: <{LOW_CV_THRESHOLD}% (Permanent), {LOW_CV_THRESHOLD}-{HIGH_CV_THRESHOLD}% (Vegetated), >{HIGH_CV_THRESHOLD}% (Seasonal)")
print(f"  Minimum patch size: {MIN_PATCH_SIZE_HA} ha")
print(f"  Edge buffer: {EDGE_BUFFER_PIXELS} pixels")

seasons = ['Spring', 'Summer', 'Fall']
ndwi_files = {
    'Spring': f"{base_path}NDWI_Spring_Full_10m.tif",
    'Summer': f"{base_path}NDWI_Summer_Full_10m.tif",
    'Fall': f"{base_path}NDWI_Fall_Full_10m.tif"
}

class_names = {
    1: 'Permanent Wetland',
    2: 'Vegetated Wetland',
    3: 'Seasonal Wetland'
}

class_names_full = {
    1: 'Permanent Wetland (Open Water)',
    2: 'Vegetated Wetland (Emergent)',
    3: 'Seasonal Wetland (Ephemeral)'
}

# Visualization colors
colors_map = ['#4169E1', '#FFA500', '#FF4500']  # Blue, Orange, Red
colors_pie = ['#4169E1', '#FFA500', '#FF4500']
colors_lines = ['#4169E1', '#FFA500', '#FF4500']
markers = ['o', '^', 's']

props = dict(boxstyle='round', facecolor='wheat', alpha=0.8, edgecolor='black', linewidth=2)

# ============================================================================
# STEP 1: LOAD AND VALIDATE DATA
# ============================================================================
print("\n" + "="*80)
print("STEP 1: LOADING AND VALIDATING DATA")
print("="*80)

data = {'NDWI': {}}
metadata = None

# Load NDWI data and get target grid
try:
    with rasterio.open(ndwi_files['Spring']) as src:
        target_transform = src.transform
        target_width = src.width
        target_height = src.height
        target_crs = src.crs
        metadata = src.meta.copy()
        pixel_size = abs(src.transform[0])

    print(f"\nTarget Grid (from NDWI):")
    print(f"  Dimensions: {target_width} x {target_height} pixels")
    print(f"  Resolution: {pixel_size:.1f}m")
    print(f"  CRS: {target_crs}")

except Exception as e:
    print(f"‚ùå ERROR loading NDWI reference file: {e}")
    raise

# Load all NDWI layers
print(f"\nLoading NDWI data for {len(seasons)} seasons...")
for season, file_path in ndwi_files.items():
    try:
        with rasterio.open(file_path) as src:
            ndwi_data = src.read(1).astype(np.float32)
            ndwi_data[ndwi_data == src.nodata] = np.nan
            data['NDWI'][season] = ndwi_data

            # Validate NDWI range
            valid_data = ndwi_data[~np.isnan(ndwi_data)]
            min_val, max_val = np.min(valid_data), np.max(valid_data)
            print(f"  {season:8s}: Range [{min_val:+.3f}, {max_val:+.3f}] | Valid pixels: {len(valid_data):,}")

            if min_val < -1.5 or max_val > 1.5:
                print(f"    ‚ö†Ô∏è  WARNING: NDWI values outside expected range [-1, +1]")

    except Exception as e:
        print(f"‚ùå ERROR loading {season} NDWI: {e}")
        raise

print("‚úì All NDWI data loaded successfully")

# Load CDL
print(f"\nLoading CDL 2024...")
try:
    with rasterio.open(cdl_path) as src:
        cdl_30m = src.read(1)
        cdl_transform = src.transform
        cdl_crs = src.crs
        cdl_res = abs(src.transform[0])

    print(f"  Dimensions: {cdl_30m.shape[1]} x {cdl_30m.shape[0]} pixels")
    print(f"  Resolution: {cdl_res:.1f}m")

    # CRITICAL: Check what CDL codes are actually present
    unique_codes = np.unique(cdl_30m)
    print(f"\n  CDL codes present in file: {len(unique_codes)} unique values")

    # Check for wetland codes
    MODERN_WETLAND_CODES = {
        111: 'Open Water',
        190: 'Woody Wetlands',
        195: 'Herbaceous Wetlands'
    }

    DEPRECATED_CODES = {
        83: 'Water (deprecated)',
        87: 'Wetlands (deprecated)'
    }

    print(f"\n  Wetland code analysis:")
    wetland_found = False
    for code, name in MODERN_WETLAND_CODES.items():
        count = np.sum(cdl_30m == code)
        if count > 0:
            print(f"    ‚úì Code {code} ({name}): {count:,} pixels")
            wetland_found = True
        else:
            print(f"    ‚úó Code {code} ({name}): NOT FOUND")

    for code, name in DEPRECATED_CODES.items():
        count = np.sum(cdl_30m == code)
        if count > 0:
            print(f"    ‚ö†Ô∏è  Code {code} ({name}): {count:,} pixels (SHOULD NOT BE USED)")

    if not wetland_found:
        print(f"\n    ‚ùå ERROR: No wetland codes found in CDL file!")
        print(f"    First 20 unique codes in file: {unique_codes[:20]}")
        raise ValueError("No wetland pixels detected in CDL")

    print("‚úì CDL loaded successfully")

except Exception as e:
    print(f"‚ùå ERROR loading CDL: {e}")
    raise

# ============================================================================
# STEP 2: CREATE AND ALIGN WETLAND MASK (FIXED)
# ============================================================================
print("\n" + "="*80)
print("STEP 2: CREATING AND ALIGNING WETLAND MASK")
print("="*80)

# Use ONLY modern CDL codes (111, 190, 195)
WETLAND_CODES = {
    111: 'Open Water',
    190: 'Woody Wetlands',
    195: 'Herbaceous Wetlands'
}

print(f"Creating wetland mask using CDL codes: {list(WETLAND_CODES.keys())}")

# Create mask at 30m resolution
wetland_mask_30m = np.isin(cdl_30m, list(WETLAND_CODES.keys()))
wetland_pixels_30m = np.sum(wetland_mask_30m)
area_30m = wetland_pixels_30m * (cdl_res**2) / 10000  # Convert to ha

print(f"\nWetland mask at 30m resolution:")
print(f"  Pixels: {wetland_pixels_30m:,}")
print(f"  Area: {area_30m:.2f} ha")

# Reproject to 10m grid
print(f"\nReprojecting from 30m to 10m resolution...")
wetland_mask_10m = np.zeros((target_height, target_width), dtype=np.uint8)

reproject(
    source=wetland_mask_30m.astype(np.uint8),
    destination=wetland_mask_10m,
    src_transform=cdl_transform,
    src_crs=cdl_crs,
    dst_transform=target_transform,
    dst_crs=target_crs,
    resampling=Resampling.nearest,
    dst_nodata=0
)

wetland_mask_10m = wetland_mask_10m.astype(bool)
pixel_area_ha = (pixel_size**2) / 10000  # 10m x 10m = 100 m¬≤ = 0.01 ha

print(f"‚úì Reprojection complete")
print(f"\nWetland mask at 10m resolution:")
print(f"  Pixels: {np.sum(wetland_mask_10m):,}")
print(f"  Area: {np.sum(wetland_mask_10m) * pixel_area_ha:.2f} ha")
print(f"  Pixel size: {pixel_area_ha:.4f} ha")

# Apply edge buffering to reduce mixed pixel effects
print(f"\nApplying {EDGE_BUFFER_PIXELS}-pixel erosion to remove edge effects...")
wetland_mask_original = wetland_mask_10m.copy()
structure = np.ones((3, 3))
wetland_mask_buffered = binary_erosion(wetland_mask_10m, structure=structure, iterations=EDGE_BUFFER_PIXELS)

pixels_removed = np.sum(wetland_mask_original) - np.sum(wetland_mask_buffered)
area_removed = pixels_removed * pixel_area_ha

print(f"  Pixels removed (edge effects): {pixels_removed:,}")
print(f"  Area removed: {area_removed:.2f} ha ({100*pixels_removed/np.sum(wetland_mask_original):.1f}%)")
print(f"  Final core wetland area: {np.sum(wetland_mask_buffered) * pixel_area_ha:.2f} ha")

# Use buffered mask for analysis
wetland_mask = wetland_mask_buffered
total_wetland_pixels = np.sum(wetland_mask)
total_wetland_area = total_wetland_pixels * pixel_area_ha

print(f"\n‚úì Final wetland mask created")
print(f"  Total pixels: {total_wetland_pixels:,}")
print(f"  Total area: {total_wetland_area:.2f} ha")

# --- Figure 1: Wetland Mask ---
print("\nüìä Generating Figure 1: Wetland Mask...")
fig, ax = plt.subplots(figsize=(12, 10))
wetland_display = wetland_mask.astype(float)
wetland_display[~wetland_mask] = np.nan
ax.imshow(wetland_display, cmap='Blues', interpolation='nearest')
ax.set_title('Figure 1: Wetland Mask - Brookings County, SD\n(USDA CDL 2024, 2-pixel edge buffer applied)',
             fontsize=16, fontweight='bold', pad=20)
ax.axis('off')
textstr = (f'Total Area: {total_wetland_area:.2f} ha\n'
           f'Resolution: 10m\n'
           f'CDL Codes: 111, 190, 195\n'
           f'Edge Buffer: {EDGE_BUFFER_PIXELS} pixels')
ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=12,
        verticalalignment='top', bbox=props, family='monospace')
plt.tight_layout()
plt.savefig(f'{output_path}Figure1_Wetland_Mask.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print("‚úì Saved: Figure1_Wetland_Mask.png")

# --- Export Wetland Mask as GeoTIFF ---
print("\nüíæ Exporting Wetland Mask as GeoTIFF...")
mask_meta = metadata.copy()
mask_meta.update({'dtype': 'uint8', 'nodata': 0, 'count': 1})
mask_export = wetland_mask.astype(np.uint8)
# Set non-wetland pixels to nodata (0)
mask_export[~wetland_mask] = 0

with rasterio.open(f'{output_path}Wetland_Mask.tif', 'w', **mask_meta) as dst:
    dst.write(mask_export, 1)
print("‚úì Saved: Wetland_Mask.tif")


# ============================================================================
# STEP 3: CALCULATE COEFFICIENT OF VARIATION (CV)
# ============================================================================
print("\n" + "="*80)
print("STEP 3: CALCULATING COEFFICIENT OF VARIATION")
print("="*80)

# Stack seasonal NDWI
ndwi_stack = np.stack([data['NDWI'][s] for s in seasons], axis=0)

# Calculate statistics
mean_ndwi = np.nanmean(ndwi_stack, axis=0)
std_ndwi = np.nanstd(ndwi_stack, axis=0)

# CV = (std / |mean|) * 100
# Use absolute value to handle negative NDWI (Gao index)
with np.errstate(divide='ignore', invalid='ignore'):
    cv_map = np.divide(std_ndwi, np.abs(mean_ndwi),
                       out=np.full_like(mean_ndwi, np.nan),
                       where=(mean_ndwi != 0)) * 100

# Apply wetland mask
cv_wetlands = cv_map.copy()
cv_wetlands[~wetland_mask] = np.nan

# Calculate CV statistics
cv_valid = cv_wetlands[~np.isnan(cv_wetlands)]
print(f"\nCV Statistics (wetland pixels only):")
print(f"  Valid pixels: {len(cv_valid):,}")
print(f"  Min CV: {np.min(cv_valid):.1f}%")
print(f"  Max CV: {np.max(cv_valid):.1f}%")
print(f"  Mean CV: {np.mean(cv_valid):.1f}%")
print(f"  Median CV: {np.median(cv_valid):.1f}%")
print(f"  25th percentile: {np.percentile(cv_valid, 25):.1f}%")
print(f"  75th percentile: {np.percentile(cv_valid, 75):.1f}%")

print("‚úì CV calculated successfully")

# --- Figure 2: CV Map ---
print("\nüìä Generating Figure 2: CV Map...")
fig, ax = plt.subplots(figsize=(12, 10))
cv_colors = ['#006837', '#a6d96a', '#ffffbf', '#fdae61', '#d7191c']
cmap_cv = mcolors.LinearSegmentedColormap.from_list('cv_cmap', cv_colors, N=256)
im = ax.imshow(cv_wetlands, cmap=cmap_cv, vmin=0, vmax=100, interpolation='bilinear')
cbar = plt.colorbar(im, ax=ax, fraction=0.03, pad=0.04, extend='max')
cbar.set_label('Coefficient of Variation (%)', fontsize=14, fontweight='bold')
cbar.ax.hlines([LOW_CV_THRESHOLD, HIGH_CV_THRESHOLD], 0, 1, colors='black', linewidth=2, linestyles='--')
ax.set_title('Figure 2: NDWI Coefficient of Variation (CV) Map\nBrookings County, SD',
             fontsize=14, fontweight='bold', pad=20)
ax.axis('off')
textstr = (f'CV Statistics:\n'
           f'Mean: {np.mean(cv_valid):.1f}%\n'
           f'Median: {np.median(cv_valid):.1f}%\n'
           f'Range: [{np.min(cv_valid):.1f}%, {np.max(cv_valid):.1f}%]\n'
           f'Valid Pixels: {len(cv_valid):,}')
ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=10,
        verticalalignment='top', bbox=props, family='monospace')
plt.tight_layout()
plt.savefig(f'{output_path}Figure2_CV_Map.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print("‚úì Saved: Figure2_CV_Map.png")

# Export CV as GeoTIFF
print("\nüíæ Exporting CV map as GeoTIFF...")
cv_meta = metadata.copy()
cv_meta.update({'dtype': 'float32', 'nodata': -9999, 'count': 1})
cv_export = cv_map.copy()
cv_export[np.isnan(cv_export)] = -9999
with rasterio.open(f'{output_path}CV_Map.tif', 'w', **cv_meta) as dst:
    dst.write(cv_export.astype('float32'), 1)
print("‚úì Saved: CV_Map.tif")

# ============================================================================
# STEP 4: CLASSIFY WETLAND PERMANENCE
# ============================================================================
print("\n" + "="*80)
print("STEP 4: CLASSIFYING WETLAND PERMANENCE")
print("="*80)

# Apply thresholds
classification = np.full_like(cv_wetlands, np.nan)
classification[cv_wetlands < LOW_CV_THRESHOLD] = 1  # Permanent
classification[(cv_wetlands >= LOW_CV_THRESHOLD) & (cv_wetlands < HIGH_CV_THRESHOLD)] = 2 # Vegetated
classification[cv_wetlands >= HIGH_CV_THRESHOLD] = 3      # Seasonal

# Initial classification counts
for class_val in [1, 2, 3]:
    count = np.sum(classification == class_val)
    print(f"  {class_names[class_val]:20s}: {count:,} pixels ({count * pixel_area_ha:.2f} ha)")

print("\nüßπ Cleaning classification (removing patches < {MIN_PATCH_SIZE_HA} ha)...")
classification_clean = classification.copy()

# Use 4-connectivity for more conservative patch counting
structure_4conn = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])

for class_val in [1, 2, 3]:
    class_mask = (classification == class_val)
    labeled, num_features = label(class_mask, structure=structure_4conn)

    patches_removed = 0
    area_removed = 0

    for patch_id in range(1, num_features + 1):
        patch_pixels = np.sum(labeled == patch_id)
        patch_area = patch_pixels * pixel_area_ha

        if patch_area < MIN_PATCH_SIZE_HA:
            classification_clean[labeled == patch_id] = np.nan
            patches_removed += 1
            area_removed += patch_area

    print(f"  {class_names[class_val]:20s}: Removed {patches_removed:,} patches ({area_removed:.2f} ha)")

print("‚úì Classification complete")

# Final classification counts
total_classified_pixels = 0
print("\nFinal classification:")
for class_val in [1, 2, 3]:
    count = np.sum(classification_clean == class_val)
    total_classified_pixels += count
    area = count * pixel_area_ha
    print(f"  {class_names[class_val]:20s}: {count:,} pixels ({area:.2f} ha)")

print(f"  {'TOTAL':20s}: {total_classified_pixels:,} pixels ({total_classified_pixels * pixel_area_ha:.2f} ha)")

# --- Figure 3: Classification Map ---
print("\nüìä Generating Figure 3: Classification Map...")
fig, ax = plt.subplots(figsize=(12, 10))

class_display = np.full(classification_clean.shape + (3,), 0.95)
class_display[classification_clean == 1] = mcolors.to_rgb(colors_map[0])
class_display[classification_clean == 2] = mcolors.to_rgb(colors_map[1])
class_display[classification_clean == 3] = mcolors.to_rgb(colors_map[2])

ax.imshow(class_display, interpolation='nearest')
ax.set_title('Figure 3: Wetland Permanence Classification Map\nBrookings County, SD (NDWI CV Method)',
             fontsize=14, fontweight='bold', pad=20)
ax.axis('off')

legend_elements = [
    Patch(facecolor=colors_map[0], edgecolor='black',
          label=f'{class_names_full[1]} (CV < {LOW_CV_THRESHOLD}%)'),
    Patch(facecolor=colors_map[1], edgecolor='black',
          label=f'{class_names_full[2]} (CV {LOW_CV_THRESHOLD}-{HIGH_CV_THRESHOLD}%)'),
    Patch(facecolor=colors_map[2], edgecolor='black',
          label=f'{class_names_full[3]} (CV > {HIGH_CV_THRESHOLD}%)')
]
ax.legend(handles=legend_elements, loc='lower right', fontsize=11, framealpha=0.95,
         edgecolor='black', fancybox=True)

plt.tight_layout()
plt.savefig(f'{output_path}Figure3_Classification_Map.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print("‚úì Saved: Figure3_Classification_Map.png")

# Export classification as GeoTIFF
print("\nüíæ Exporting classification as GeoTIFF...")
class_meta = metadata.copy()
class_meta.update({'dtype': 'uint8', 'nodata': 0, 'count': 1})
class_export = classification_clean.copy()
class_export[np.isnan(class_export)] = 0
with rasterio.open(f'{output_path}Wetland_Classification.tif', 'w', **class_meta) as dst:
    dst.write(class_export.astype('uint8'), 1)
    # Write class names as metadata
    dst.update_tags(1,
                   class_1='Permanent Wetland',
                   class_2='Vegetated Wetland',
                   class_3='Seasonal Wetland')
print("‚úì Saved: Wetland_Classification.tif")

# ============================================================================
# STEP 5: CALCULATE COMPREHENSIVE STATISTICS
# ============================================================================
print("\n" + "="*80)
print("STEP 5: CALCULATING STATISTICS")
print("="*80)

# Area-based statistics
area_stats = {}
for class_val in [1, 2, 3]:
    count = np.sum(classification_clean == class_val)
    area_ha = count * pixel_area_ha
    pct = (count / total_classified_pixels) * 100
    area_stats[class_val] = {'count': count, 'area_ha': area_ha, 'pct': pct}

# Patch-based statistics (using 4-connectivity)
patch_stats = {}
for class_val in [1, 2, 3]:
    class_mask = (classification_clean == class_val)
    labeled_array, num_patches = label(class_mask, structure=structure_4conn)

    patch_sizes = []
    for patch_id in range(1, num_patches + 1):
        patch_pixels = np.sum(labeled_array == patch_id)
        patch_area_ha = patch_pixels * pixel_area_ha
        if patch_area_ha >= MIN_PATCH_SIZE_HA:
            patch_sizes.append(patch_area_ha)

    patch_stats[class_val] = {
        'num_patches': len(patch_sizes),
        'sizes': patch_sizes,
        'mean_size': np.mean(patch_sizes) if patch_sizes else 0,
        'median_size': np.median(patch_sizes) if patch_sizes else 0,
        'std_size': np.std(patch_sizes) if patch_sizes else 0,
        'min_size': np.min(patch_sizes) if patch_sizes else 0,
        'max_size': np.max(patch_sizes) if patch_sizes else 0
    }

    print(f"\n{class_names[class_val]}:")
    print(f"  Number of patches: {len(patch_sizes):,}")
    print(f"  Mean size: {patch_stats[class_val]['mean_size']:.4f} ha")
    print(f"  Median size: {patch_stats[class_val]['median_size']:.4f} ha")
    print(f"  Size range: [{patch_stats[class_val]['min_size']:.4f}, {patch_stats[class_val]['max_size']:.2f}] ha")

total_patches = sum([patch_stats[i]['num_patches'] for i in [1, 2, 3]])
print(f"\n‚úì Total patches: {total_patches:,}")

# CV statistics by class
cv_stats_by_class = {}
for class_val in [1, 2, 3]:
    class_cv = cv_wetlands[classification_clean == class_val]
    class_cv = class_cv[~np.isnan(class_cv)]

    cv_stats_by_class[class_val] = {
        'mean': np.mean(class_cv),
        'median': np.median(class_cv),
        'std': np.std(class_cv),
        'min': np.min(class_cv),
        'max': np.max(class_cv),
        'q25': np.percentile(class_cv, 25),
        'q75': np.percentile(class_cv, 75)
    }

# --- Figure 4: Pie Charts ---
print("\nüìä Generating Figure 4: Pie Charts...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))

# Area distribution
sizes_area = [area_stats[i]['area_ha'] for i in [1, 2, 3]]
labels_area = [f"{class_names[i]}\n{area_stats[i]['pct']:.1f}%" for i in [1, 2, 3]]
ax1.pie(sizes_area, labels=labels_area, colors=colors_pie, startangle=90,
       autopct='', textprops={'fontsize': 12, 'weight': 'bold'})
ax1.set_title(f'Wetland Area Distribution\nTotal: {total_classified_pixels * pixel_area_ha:.2f} ha',
             fontsize=16, fontweight='bold', pad=20)

# Patch distribution
sizes_patch = [patch_stats[i]['num_patches'] for i in [1, 2, 3]]
labels_patch = [f"{class_names[i]}\n{patch_stats[i]['num_patches']/total_patches*100:.1f}%" for i in [1, 2, 3]]
ax2.pie(sizes_patch, labels=labels_patch, colors=colors_pie, startangle=90,
       autopct='', textprops={'fontsize': 12, 'weight': 'bold'})
ax2.set_title(f'Wetland Patch Count Distribution\nTotal: {total_patches:,} patches',
             fontsize=16, fontweight='bold', pad=20)

plt.tight_layout()
plt.savefig(f'{output_path}Figure4_Pie_Charts.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print("‚úì Saved: Figure4_Pie_Charts.png")

# ============================================================================
# STEP 6: SEASONAL FINGERPRINT VALIDATION
# ============================================================================
print("\n" + "="*80)
print("STEP 6: SEASONAL FINGERPRINT VALIDATION")
print("="*80)

fingerprints = {}
for class_val in [1, 2, 3]:
    class_mask = (classification_clean == class_val)
    fingerprints[class_val] = {
        'spring': np.nanmean(data['NDWI']['Spring'][class_mask]),
        'summer': np.nanmean(data['NDWI']['Summer'][class_mask]),
        'fall': np.nanmean(data['NDWI']['Fall'][class_mask])
    }

    fp = fingerprints[class_val]
    values = [fp['spring'], fp['summer'], fp['fall']]
    range_val = max(values) - min(values)
    print(f"\n{class_names[class_val]}:")
    print(f"  Spring: {fp['spring']:+.4f}")
    print(f"  Summer: {fp['summer']:+.4f}")
    print(f"  Fall: {fp['fall']:+.4f}")
    print(f"  Range: {range_val:.4f}")

# --- Figure 7: Seasonal Fingerprint ---
print("\nüìä Generating Figure 7: Seasonal Fingerprint...")
fig, ax = plt.subplots(figsize=(12, 8))

season_positions = [0, 1, 2]
season_labels = ['Spring', 'Summer', 'Fall']

for idx, class_val in enumerate([1, 2, 3]):
    fp = fingerprints[class_val]
    values = [fp['spring'], fp['summer'], fp['fall']]
    ax.plot(season_positions, values, color=colors_lines[idx], marker=markers[idx],
            markersize=12, linewidth=3, label=class_names_full[class_val])

ax.set_xlabel('Season', fontsize=14, fontweight='bold')
ax.set_ylabel('Mean NDWI (Gao)', fontsize=14, fontweight='bold')
ax.set_title('Figure 7: Seasonal NDWI Fingerprints by Class\n(Validation of CV Thresholds)',
             fontsize=16, fontweight='bold', pad=20)
ax.axhline(0, color='black', linestyle='--', linewidth=1, label='Water/Vegetation Boundary (0.0)')
ax.set_xticks(season_positions)
ax.set_xticklabels(season_labels, fontsize=12)
ax.legend(fontsize=12, loc='best', framealpha=0.95, edgecolor='black')
ax.grid(alpha=0.3, linestyle='--')

# Add interpretation text
interp_text = ('Permanent (Blue): Stably Negative (Open Water)\n'
               'Vegetated (Orange): Stably Positive (Vegetation)\n'
               'Seasonal (Red): Fluctuating (Ephemeral)')
ax.text(0.02, 0.98, interp_text, transform=ax.transAxes, fontsize=10,
        verticalalignment='top', bbox=props)

plt.tight_layout()
plt.savefig(f'{output_path}Figure7_Seasonal_Fingerprint.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print("‚úì Saved: Figure7_Seasonal_Fingerprint.png")

# ============================================================================
# STEP 7: ADDITIONAL ANALYSIS FIGURES
# ============================================================================
print("\n" + "="*80)
print("STEP 7: GENERATING ADDITIONAL ANALYSIS FIGURES")
print("="*80)

# --- Figure 6: Comprehensive Analysis (4 subplots) ---
print("\nüìä Generating Figure 6: Comprehensive Analysis...")
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Subplot A: CV Distribution Histogram
ax = axes[0, 0]
cv_all = cv_wetlands[~np.isnan(cv_wetlands)]
# Filter extreme values for better visualization
cv_all_filtered = cv_all[cv_all <= 150]  # Remove extreme outliers

ax.hist(cv_all_filtered, bins=80, color='steelblue', alpha=0.7, edgecolor='black', linewidth=0.5)
ax.axvline(LOW_CV_THRESHOLD, color='green', linestyle='--', linewidth=2.5,
          label=f'Permanent Threshold (CV = {LOW_CV_THRESHOLD}%)')
ax.axvline(HIGH_CV_THRESHOLD, color='red', linestyle='--', linewidth=2.5,
          label=f'Seasonal Threshold (CV = {HIGH_CV_THRESHOLD}%)')

ax.set_xlabel('Coefficient of Variation (%)', fontsize=12, fontweight='bold')
ax.set_ylabel('Frequency (pixels)', fontsize=12, fontweight='bold')
ax.set_title('(A) CV Distribution Across All Wetlands', fontsize=13, fontweight='bold', pad=10)
ax.legend(fontsize=10, loc='upper right')
ax.grid(alpha=0.3, linestyle='--')
ax.set_xlim(0, 150)  # Cap at 150% for better viz

# Add statistics text
n_extreme = np.sum(cv_all > 150)
if n_extreme > 0:
    ax.text(0.98, 0.5, f'Note: {n_extreme:,} pixels\nwith CV > 150%\n(not shown)',
           transform=ax.transAxes, fontsize=9, va='center', ha='right',
           bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7))


# Subplot B: CV by Class (Boxplot)
ax = axes[0, 1]

# Prepare data with FILTERING of extreme values
cv_by_class = []
class_labels = []
for class_val in [1, 2, 3]:
    class_cv = cv_wetlands[classification_clean == class_val]
    class_cv = class_cv[~np.isnan(class_cv)]
    # FILTER: Cap CV at 150% to avoid extreme outliers
    class_cv_filtered = class_cv[class_cv <= 150]
    cv_by_class.append(class_cv_filtered)
    class_labels.append(class_names[class_val])

    # Report filtering
    n_filtered = len(class_cv) - len(class_cv_filtered)
    if n_filtered > 0:
        print(f"  {class_names[class_val]}: Filtered {n_filtered} extreme CV values (>{150}%)")


# Create boxplot with better styling
bp = ax.boxplot(cv_by_class, labels=class_labels,
                patch_artist=True,
                showmeans=True,
                showfliers=True,  # Show outliers but they're now capped at 150%
                widths=0.6,
                medianprops=dict(color='red', linewidth=2),
                meanprops=dict(marker='D', markerfacecolor='green', markersize=8),
                flierprops=dict(marker='o', markerfacecolor='gray', markersize=4, alpha=0.5))

# Color the boxes
box_colors = ['lightblue', 'lightyellow', 'lightcoral']
for patch, color in zip(bp['boxes'], box_colors):
    patch.set_facecolor(color)
    patch.set_edgecolor('black')
    patch.set_linewidth(1.5)


ax.set_ylabel('Coefficient of Variation (%)', fontsize=12, fontweight='bold')
ax.set_title('(B) CV Distribution by Class (Filtered)', fontsize=13, fontweight='bold', pad=10)
ax.grid(alpha=0.3, axis='y', linestyle='--')
ax.set_ylim(0, 150)  # Set consistent y-axis

# Add legend for mean/median
from matplotlib.lines import Line2D
legend_elements = [
    Line2D([0], [0], color='red', linewidth=2, label='Median'),
    Line2D([0], [0], marker='D', color='w', markerfacecolor='green',
           markersize=8, label='Mean')
]
ax.legend(handles=legend_elements, loc='upper right', fontsize=9)


# Subplot C: Area vs Count Comparison
ax = axes[1, 0]

categories = ['Permanent', 'Vegetated', 'Seasonal']
area_pct = [area_stats[i]['pct'] for i in [1, 2, 3]]
patch_pct = [patch_stats[i]['num_patches']/total_patches*100 for i in [1, 2, 3]]

x = np.arange(len(categories))
width = 0.38

bars1 = ax.bar(x - width/2, area_pct, width, label='% of Total Area',
              color='steelblue', edgecolor='black', linewidth=1.5)
bars2 = ax.bar(x + width/2, patch_pct, width, label='% of Total Patches',
              color='coral', edgecolor='black', linewidth=1.5)

ax.set_ylabel('Percentage (%)', fontsize=12, fontweight='bold')
ax.set_title('(C) Area vs Patch Count Distribution', fontsize=13, fontweight='bold', pad=10)
ax.set_xticks(x)
ax.set_xticklabels(categories, fontsize=11, fontweight='bold')
ax.legend(fontsize=10, loc='upper right')
ax.grid(alpha=0.3, axis='y', linestyle='--')
ax.set_ylim(0, 65)

# Add value labels with better positioning
for i, (a, p) in enumerate(zip(area_pct, patch_pct)):
    ax.text(i - width/2, a + 1.5, f'{a:.1f}%', ha='center', va='bottom',
           fontsize=10, fontweight='bold')
    ax.text(i + width/2, p + 1.5, f'{p:.1f}%', ha='center', va='bottom',
           fontsize=10, fontweight='bold')


# Subplot D: Patch Size Distribution (Log scale)
ax = axes[1, 1]

all_sizes = []
all_labels = []
for class_val in [1, 2, 3]:
    sizes = patch_stats[class_val]['sizes']
    all_sizes.extend(np.log10(np.array(sizes) + 0.001))
    all_labels.extend([class_names[class_val]] * len(sizes))

sizes_df = pd.DataFrame({'Log10_Area': all_sizes, 'Class': all_labels})

# Plot histograms with better binning
for class_val, color in zip([1, 2, 3], colors_map):
    data = sizes_df[sizes_df['Class'] == class_names[class_val]]['Log10_Area']
    ax.hist(data, bins=45, alpha=0.6, label=class_names[class_val],
           color=color, edgecolor='black', linewidth=0.5)

ax.set_xlabel('Log‚ÇÅ‚ÇÄ(Patch Area in ha)', fontsize=12, fontweight='bold')
ax.set_ylabel('Frequency (Number of Patches)', fontsize=12, fontweight='bold')
ax.set_title('(D) Patch Size Distribution by Class', fontsize=13, fontweight='bold', pad=10)
ax.legend(fontsize=10, loc='upper right')
ax.grid(alpha=0.3, linestyle='--')

# Add reference lines for key sizes
ax.axvline(np.log10(0.01), color='gray', linestyle=':', linewidth=1, alpha=0.7)
ax.axvline(np.log10(0.1), color='gray', linestyle=':', linewidth=1, alpha=0.7)
ax.axvline(np.log10(1.0), color='gray', linestyle=':', linewidth=1, alpha=0.7)

# Add size labels
size_labels = ['0.01 ha', '0.1 ha', '1 ha']
size_positions = [np.log10(0.01), np.log10(0.1), np.log10(1.0)]
for label, pos in zip(size_labels, size_positions):
    ax.text(pos, ax.get_ylim()[1]*0.95, label, rotation=90,
           va='top', ha='right', fontsize=8, alpha=0.7)


plt.tight_layout()
plt.savefig(f'{output_path}Figure6_Comprehensive_Analysis.png',
           dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print("‚úì Saved: Figure6_Comprehensive_Analysis.png")

# --- Figure 9: Bimodal Distribution ---
print("\nüìä Generating Figure 9: Bimodal Size Distribution...")
all_sizes = []
for class_val in [1, 2, 3]:
    all_sizes.extend(patch_stats[class_val]['sizes'])

all_sizes_log = np.log10(all_sizes)
all_sizes_log = all_sizes_log[np.isfinite(all_sizes_log)]

fig, ax = plt.subplots(figsize=(12, 7))
counts, bins, patches = ax.hist(all_sizes_log, bins=50, color='#3182bd', alpha=0.8, edgecolor='black')
ax.axvline(x=np.log10(0.03), color='red', linestyle='--', linewidth=2.5,
          label='Peak 1: Fragmented Systems (~0.03 ha)')
ax.axvline(x=np.log10(0.5), color='green', linestyle='--', linewidth=2.5,
          label='Peak 2: Contiguous Systems (~0.5 ha)')
ax.set_title('Figure 9: Bimodal Distribution of Patch Sizes (All Classes)',
             fontsize=14, fontweight='bold', pad=20)
ax.set_xlabel('Wetland Size (Log10 ha)', fontsize=12, fontweight='bold')
ax.set_ylabel('Frequency (Number of Patches)', fontsize=12, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(axis='y', alpha=0.4)

# Add text annotations
peak1_count = np.sum((all_sizes_log >= np.log10(0.01)) & (all_sizes_log < np.log10(0.1)))
peak2_count = np.sum((all_sizes_log >= np.log10(0.3)) & (all_sizes_log < np.log10(2)))
textstr = f'Small wetlands (<0.1 ha): {peak1_count:,}\nLarge wetlands (>0.3 ha): {peak2_count:,}'
ax.text(0.98, 0.98, textstr, transform=ax.transAxes, fontsize=10,
        verticalalignment='top', horizontalalignment='right', bbox=props, family='monospace')

plt.tight_layout()
plt.savefig(f'{output_path}Figure9_Bimodal_Distribution.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print("‚úì Saved: Figure9_Bimodal_Distribution.png")


# ============================================================================
# STEP 8: CREATE STATISTICS TABLE AND EXPORT
# ============================================================================
print("\n" + "="*80)
print("STEP 8: CREATING STATISTICS TABLE AND EXCEL EXPORT")
print("="*80)

# Main statistics table
stats_table = []
for class_val in [1, 2, 3]:
    if class_val == 1:
        cv_thresh_str = f"CV < {LOW_CV_THRESHOLD}%"
    elif class_val == 2:
        cv_thresh_str = f"{LOW_CV_THRESHOLD}%-{HIGH_CV_THRESHOLD}%"
    else:
        cv_thresh_str = f"CV > {HIGH_CV_THRESHOLD}%"

    stats_table.append({
        'Class': class_names[class_val],
        'CV Threshold': cv_thresh_str,
        'Area (ha)': round(area_stats[class_val]['area_ha'], 2),
        'Area (%)': round(area_stats[class_val]['pct'], 1),
        'Patches': patch_stats[class_val]['num_patches'],
        'Patch (%)': round(patch_stats[class_val]['num_patches']/total_patches*100, 1),
        'Mean Size (ha)': round(patch_stats[class_val]['mean_size'], 4),
        'Median Size (ha)': round(patch_stats[class_val]['median_size'], 4)
    })

df_stats = pd.DataFrame(stats_table)

print("\n" + "="*80)
print("FINAL WETLAND CLASSIFICATION STATISTICS")
print("="*80)
print(df_stats.to_string(index=False))
print("="*80)

# Export to Excel with multiple sheets
excel_path = f'{output_path}Wetland_Statistics_Complete.xlsx'
print(f"\nüíæ Creating Excel workbook: {excel_path}")

with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Sheet 1: Main Statistics
    df_stats.to_excel(writer, sheet_name='Classification Statistics', index=False)

    # Sheet 2: Seasonal Fingerprints
    fingerprint_data = []
    for season in ['spring', 'summer', 'fall']:
        row = {'Season': season.capitalize()}
        for class_val in [1, 2, 3]:
            row[class_names[class_val]] = round(fingerprints[class_val][season], 4)
        fingerprint_data.append(row)
    df_fingerprint = pd.DataFrame(fingerprint_data)
    df_fingerprint.to_excel(writer, sheet_name='Seasonal Fingerprints', index=False)

    # Sheet 3: CV Statistics by Class
    cv_stats_data = []
    for class_val in [1, 2, 3]:
        cv_stats_data.append({
            'Class': class_names[class_val],
            'Mean CV': round(cv_stats_by_class[class_val]['mean'], 2),
            'Median CV': round(cv_stats_by_class[class_val]['median'], 2),
            'Std Dev CV': round(cv_stats_by_class[class_val]['std'], 2),
            'Min CV': round(cv_stats_by_class[class_val]['min'], 2),
            'Max CV': round(cv_stats_by_class[class_val]['max'], 2),
            'Q25 CV': round(cv_stats_by_class[class_val]['q25'], 2),
            'Q75 CV': round(cv_stats_by_class[class_val]['q75'], 2)
        })
    df_cv_stats = pd.DataFrame(cv_stats_data)
    df_cv_stats.to_excel(writer, sheet_name='CV Statistics', index=False)

    # Sheet 4: Patch Size Statistics
    patch_size_data = []
    for class_val in [1, 2, 3]:
        patch_size_data.append({
            'Class': class_names[class_val],
            'Number of Patches': patch_stats[class_val]['num_patches'],
            'Mean Size (ha)': round(patch_stats[class_val]['mean_size'], 4),
            'Median Size (ha)': round(patch_stats[class_val]['median_size'], 4),
            'Std Dev (ha)': round(patch_stats[class_val]['std_size'], 4),
            'Min Size (ha)': round(patch_stats[class_val]['min_size'], 4),
            'Max Size (ha)': round(patch_stats[class_val]['max_size'], 4)
        })
    df_patch_size = pd.DataFrame(patch_size_data)
    df_patch_size.to_excel(writer, sheet_name='Patch Size Statistics', index=False)

    # Sheet 5: Metadata
    metadata_info = pd.DataFrame({
        'Parameter': [
            'Analysis Date',
            'Study Area',
            'NDWI Type',
            'Resolution',
            'Seasons Analyzed',
            'Total Wetland Area (ha)',
            'Total Classified Area (ha)',
            'Total Patches',
            'Permanent Threshold',
            'Seasonal Threshold',
            'CDL Codes Used',
            'CDL Year',
            'Edge Buffer (pixels)',
            'Minimum Patch Size (ha)',
            'Patch Connectivity'
        ],
        'Value': [
            datetime.now().strftime('%Y-%m-%d'),
            'Brookings County, South Dakota',
            'Gao NDWI (NIR-SWIR)',
            '10m',
            'Spring, Summer, Fall 2024',
            f'{total_wetland_area:.2f}',
            f'{total_classified_pixels * pixel_area_ha:.2f}',
            total_patches,
            f'CV < {LOW_CV_THRESHOLD}%',
            f'CV > {HIGH_CV_THRESHOLD}%',
            '111 (Open Water), 190 (Woody Wetlands), 195 (Herbaceous Wetlands)',
            '2024',
            EDGE_BUFFER_PIXELS,
            MIN_PATCH_SIZE_HA,
            '4-connectivity'
        ]
    })
    metadata_info.to_excel(writer, sheet_name='Metadata', index=False)

print("‚úì Excel workbook created with 5 sheets:")
print("  - Classification Statistics")
print("  - Seasonal Fingerprints")
print("  - CV Statistics")
print("  - Patch Size Statistics")
print("  - Metadata")

# Save CSV version too
csv_path = f'{output_path}Wetland_Statistics_Summary.csv'
df_stats.to_csv(csv_path, index=False)
print(f"‚úì CSV summary saved: {csv_path}")

# --- Generate Statistics Table Figure ---
print("\nüìä Generating Statistics Table Figure...")
fig, ax = plt.subplots(figsize=(16, 5))
ax.axis('tight')
ax.axis('off')

table_data = [df_stats.columns.tolist()] + df_stats.values.tolist()
table = ax.table(cellText=table_data, cellLoc='center', loc='center',
                colWidths=[0.20, 0.14, 0.10, 0.09, 0.10, 0.09, 0.14, 0.14])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2.2)

# Style header row
for i in range(len(df_stats.columns)):
    table[(0, i)].set_facecolor('#08519c')
    table[(0, i)].set_text_props(weight='bold', color='white')

# Style data rows
colors_rows = ['#c6dbef', '#fee0d2', '#fcbba1']
for i in range(1, len(table_data)):
    for j in range(len(df_stats.columns)):
        table[(i, j)].set_facecolor(colors_rows[i-1])
        table[(i, j)].set_text_props(weight='bold')

ax.set_title('Wetland Classification Statistics Summary\nBrookings County, SD - 2024',
             fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()
plt.savefig(f'{output_path}Figure_Statistics_Table.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig)
print("‚úì Saved: Figure_Statistics_Table.png")

# ============================================================================
# STEP 9: FINAL SUMMARY AND VALIDATION
# ============================================================================
print("\n" + "="*80)
print("FINAL ANALYSIS SUMMARY")
print("="*80)

print("\nüìä KEY FINDINGS:")
print(f"  Total wetland area analyzed: {total_wetland_area:.2f} ha")
print(f"  Total classified area: {total_classified_pixels * pixel_area_ha:.2f} ha")
print(f"  Total number of patches: {total_patches:,}")
print(f"\n  Dominant class by AREA: {class_names[max(area_stats, key=lambda x: area_stats[x]['pct'])]} ({max([area_stats[i]['pct'] for i in [1,2,3]]):.1f}%)")
print(f"  Dominant class by COUNT: {class_names[max(patch_stats, key=lambda x: patch_stats[x]['num_patches'])]} ({max([patch_stats[i]['num_patches']/total_patches*100 for i in [1,2,3]]):.1f}%)")

print("\n‚úÖ DATA QUALITY CHECKS:")
# Check 1: Total percentages
total_area_pct = sum([area_stats[i]['pct'] for i in [1, 2, 3]])
total_patch_pct = sum([patch_stats[i]['num_patches']/total_patches*100 for i in [1, 2, 3]])
print(f"  Area percentages sum: {total_area_pct:.1f}% {'‚úì' if abs(total_area_pct - 100) < 0.5 else '‚ö†Ô∏è'}")
print(f"  Patch percentages sum: {total_patch_pct:.1f}% {'‚úì' if abs(total_patch_pct - 100) < 0.5 else '‚ö†Ô∏è'}")

# Check 2: NDWI fingerprint validation
for class_val in [1, 2, 3]:
    fp = fingerprints[class_val]
    values = [fp['spring'], fp['summer'], fp['fall']]
    range_val = max(values) - min(values)
    mean_val = np.mean(values)

    if class_val == 1:  # Permanent should be negative and stable
        check = mean_val < -0.1 and range_val < 0.1
        print(f"  {class_names[class_val]:20s}: Mean NDWI = {mean_val:+.3f}, Range = {range_val:.3f} {'‚úì' if check else '‚ö†Ô∏è'}")
    elif class_val == 3:  # Seasonal should fluctuate
        check = range_val > 0.2
        print(f"  {class_names[class_val]:20s}: Mean NDWI = {mean_val:+.3f}, Range = {range_val:.3f} {'‚úì' if check else '‚ö†Ô∏è'}")
    else:  # Vegetated
        check = mean_val > 0
        print(f"  {class_names[class_val]:20s}: Mean NDWI = {mean_val:+.3f}, Range = {range_val:.3f} {'‚úì' if check else '‚ö†Ô∏è'}")

print("\nüìÅ OUTPUT FILES GENERATED:")
output_files = [
    "Figure1_Wetland_Mask.png",
    "Figure2_CV_Map.png",
    "Figure3_Classification_Map.png",
    "Figure4_Pie_Charts.png",
    "Figure6_Comprehensive_Analysis.png",
    "Figure7_Seasonal_Fingerprint.png",
    "Figure9_Bimodal_Distribution.png",
    "Figure_Statistics_Table.png",
    "Wetland_Mask.tif", # Added Wetland Mask GeoTIFF
    "CV_Map.tif",
    "Wetland_Classification.tif",
    "Wetland_Statistics_Complete.xlsx",
    "Wetland_Statistics_Summary.csv"
]

for filename in output_files:
    filepath = f"{output_path}{filename}"
    if os.path.exists(filepath):
        size_mb = os.path.getsize(filepath) / (1024 * 1024)
        print(f"  ‚úì {filename:40s} ({size_mb:.2f} MB)")
    else:
        print(f"  ‚úó {filename:40s} (NOT FOUND)")

print("\n" + "="*80)
print("‚úÖ ANALYSIS COMPLETE!")
print("="*80)
print(f"\nAll outputs saved to: {output_path}")
print(f"Analysis completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("\nüéâ Your wetland classification is ready for report writing!")
print("="*80)

SyntaxError: unterminated triple-quoted string literal (detected at line 1061) (ipython-input-3157952848.py, line 17)