# Wetland Mapping with Foundation Models

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/wetland_foundation_model.ipynb)

This notebook demonstrates how to train and deploy wetland mapping models using Satlas Aerial foundation models, NAIP imagery, and National Wetlands Inventory (NWI) data.

## Key Features

- **Foundation Model Backbone**: Uses Allen AI's Satlas Aerial (Swin-v2-Base pre-trained on NAIP aerial imagery)
- **Automated Data Pipeline**: Downloads NAIP imagery and NWI data via existing GeoAI infrastructure
- **Multi-Class Wetland Detection**: Classifies wetlands into 6 categories
- **Scalable Training**: PyTorch Lightning with progressive unfreezing
- **Large Image Inference**: Continental-scale mapping capability

## Install packages

In [1]:
# Uncomment to install required packages
# %pip install geoai-py lightning leafmap satlaspretrain-models

## Import libraries

In [2]:
import os
import geoai
import leafmap
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

print(f"GeoAI version: {geoai.__version__}")

GeoAI version: 0.27.0


## Step 1: Explore Available Foundation Models

First, let's explore the Satlas Aerial foundation model used as the default backbone:

In [3]:
# Check Satlas Aerial backbone availability
import satlaspretrain_models

print("Satlas Aerial backbone: Swin-v2-Base pre-trained on NAIP imagery")
print("Source: https://huggingface.co/allenai/satlas-pretrain")
print()
print("Model Details:")
print("  - Architecture: Swin-v2-Base transformer")
print("  - Pre-training: 26.5 million NAIP aerial images")
print("  - Developed by: Allen AI Institute")
print("  - Input channels: RGB (3 channels)")
print("  - Spatial resolution: 0.6m per pixel")
print(
    "\nNote: Prithvi is still available as an alternative backbone with backbone='prithvi'"
)

Satlas Aerial backbone: Swin-v2-Base pre-trained on NAIP imagery
Source: https://huggingface.co/allenai/satlas-pretrain

Model Details:
  - Architecture: Swin-v2-Base transformer
  - Pre-training: 26.5 million NAIP aerial images
  - Developed by: Allen AI Institute
  - Input channels: RGB (3 channels)
  - Spatial resolution: 0.6m per pixel

Note: Prithvi is still available as an alternative backbone with backbone='prithvi'


## Step 2: Define Study Region

Let's select a region with diverse wetland types. We'll use North Dakota's prairie pothole region:

In [4]:
# Define study region - North Dakota prairie pothole area
study_bbox = (-99.3, 46.8, -99.0, 47.1)  # (min_lon, min_lat, max_lon, max_lat)
year = 2020

print(f"Study region: {study_bbox}")
print(f"NAIP year: {year}")

# Create interactive map to visualize the region
m = leafmap.Map(center=[46.95, -99.15], zoom=10)
m.add_basemap("Esri.WorldImagery")

# Add study area boundary
bbox_gdf = leafmap.bbox_to_gdf(study_bbox)
m.add_gdf(
    bbox_gdf,
    layer_name="Study Area",
    style={"color": "red", "weight": 3, "fillOpacity": 0},
)

# Add NWI layer for context
m.add_basemap("FWS NWI Wetlands", opacity=0.6)

print("\nInteractive map showing study region:")
print("- Red outline: Study area boundary")
print("- Blue areas: Existing NWI wetland data")
m

Study region: (-99.3, 46.8, -99.0, 47.1)
NAIP year: 2020

Interactive map showing study region:
- Red outline: Study area boundary
- Blue areas: Existing NWI wetland data


Map(center=[46.95, -99.15], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoom_‚Ä¶

## Step 3: Explore Wetland Classes

Our foundation model will classify wetlands into 6 main categories:

In [5]:
# Get wetland class definitions
wetland_classes = geoai.get_wetland_classes()

print("Wetland Classification System:")
print("=============================")
for class_name, class_id in wetland_classes.items():
    descriptions = {
        'background': 'Non-wetland areas (uplands, agriculture, urban)',
        'freshwater_emergent': 'Palustrine emergent wetlands (cattails, sedges)',
        'freshwater_forested': 'Palustrine forested wetlands (swamps, wet forests)',
        'freshwater_pond': 'Open water and pond habitats',
        'estuarine': 'Estuarine and marine wetlands (salt marshes)',
        'other_wetland': 'Other wetland types (scrub-shrub, etc.)',
    }

    desc = descriptions.get(class_name, 'Other wetland type')
    print(f"  {class_id}: {class_name.replace('_', ' ').title()}")
    print(f"     {desc}")
    print()

print(f"Total classes: {len(wetland_classes)}")

Wetland Classification System:
  0: Background
     Non-wetland areas (uplands, agriculture, urban)

  1: Freshwater Emergent
     Palustrine emergent wetlands (cattails, sedges)

  2: Freshwater Forested
     Palustrine forested wetlands (swamps, wet forests)

  3: Freshwater Pond
     Open water and pond habitats

  4: Estuarine
     Estuarine and marine wetlands (salt marshes)

  5: Other Wetland
     Other wetland types (scrub-shrub, etc.)

Total classes: 6


## Step 4: Create Training Dataset

Now let's create a complete wetland training dataset by downloading NAIP imagery and NWI data:

In [6]:
# Create training dataset using the convenience function
print("Creating wetland training dataset...")
print("This will:")
print("1. Download NAIP imagery from Planetary Computer")
print("2. Fetch NWI wetland data")
print("3. Create training tiles with wetland masks")
print()

try:
    dataset_stats = geoai.create_wetland_dataset(
        bbox=study_bbox,
        output_dir="wetland_training_data",
        year=year,
        max_naip_items=5,  # Download up to 15 NAIP tiles
        tile_size=512,  # 512x512 pixel training tiles
        min_wetland_pixels=50,  # Minimum wetland pixels to include tile
    )

    print("\nDataset Creation Results:")
    print("========================")
    print(f"NAIP files downloaded: {dataset_stats['naip_files_downloaded']}")
    print(f"Wetland features found: {dataset_stats['wetland_features_found']}")
    print(f"Total training tiles: {dataset_stats['total_tiles']}")
    print(f"Wetland tiles: {dataset_stats['wetland_tiles']}")
    print(f"Files processed: {dataset_stats['files_processed']}")
    print(f"Dataset directory: {dataset_stats['dataset_dir']}")

    dataset_created = dataset_stats['wetland_tiles'] > 0

except Exception as e:
    print(f"Error creating dataset: {e}")
    dataset_created = False
    dataset_stats = {}

Creating wetland training dataset...
This will:
1. Download NAIP imagery from Planetary Computer
2. Fetch NWI wetland data
3. Create training tiles with wetland masks



Found 5 NAIP items.


Skipping existing file: wetland_data_cache/naip/m_4709964_sw_14_060_20200831.tif
Skipping existing file: wetland_data_cache/naip/m_4709964_se_14_060_20200831.tif
Skipping existing file: wetland_data_cache/naip/m_4709964_nw_14_060_20200831.tif
Skipping existing file: wetland_data_cache/naip/m_4709964_ne_14_060_20200831.tif
Skipping existing file: wetland_data_cache/naip/m_4709963_se_14_060_20200831.tif



Dataset Creation Results:
NAIP files downloaded: 5
Wetland features found: 2000
Total training tiles: 2465
Wetland tiles: 1923
Files processed: 5
Dataset directory: wetland_training_data


## Step 5: Visualize Training Data

Let's examine some of the training tiles we created:

In [7]:
if dataset_created:
    # List some training files
    images_dir = Path("wetland_training_data/images")
    masks_dir = Path("wetland_training_data/masks")

    image_files = list(images_dir.glob("*.tif"))[:5]  # First 5 files

    if image_files:
        print(
            f"Sample training files (showing first 5 of {len(list(images_dir.glob('*.tif')))}):"
        )
        for img_file in image_files:
            mask_file = masks_dir / img_file.name
            print(f"  Image: {img_file.name}")
            print(f"  Mask:  {mask_file.name}")
            print()

        # Show file sizes
        total_size_mb = sum(f.stat().st_size for f in images_dir.glob("*.tif")) / (
            1024**2
        )
        print(f"Total dataset size: {total_size_mb:.1f} MB")
    else:
        print("No training files found.")
else:
    print("Skipping visualization - no training data was created.")
    print("Try a different region with more wetland coverage.")

Sample training files (showing first 5 of 3765):
  Image: m_4709964_sw_14_060_20200831_0_7680.tif
  Mask:  m_4709964_sw_14_060_20200831_0_7680.tif

  Image: m_4709964_sw_14_060_20200831_256_6400.tif
  Mask:  m_4709964_sw_14_060_20200831_256_6400.tif

  Image: m_4709964_sw_14_060_20200831_256_6912.tif
  Mask:  m_4709964_sw_14_060_20200831_256_6912.tif

  Image: m_4709964_sw_14_060_20200831_512_3840.tif
  Mask:  m_4709964_sw_14_060_20200831_512_3840.tif

  Image: m_4709964_sw_14_060_20200831_512_5632.tif
  Mask:  m_4709964_sw_14_060_20200831_512_5632.tif

Total dataset size: 3656.5 MB


## Step 6: Examine Foundation Model Architecture

Let's understand the wetland foundation model using Satlas Aerial backbone that we're about to train:

In [None]:
# Create a sample model to examine architecture
print("=== Wetland Foundation Model Architecture ===")
print()

try:
    sample_model = geoai.WetlandSatlasModel(
        num_wetland_classes=6, freeze_backbone_epochs=2
    )

    # Show model info
    print(f"Foundation Backbone: Satlas Aerial (Swin-v2-Base)")
    print(f"  - Pre-trained on 26.5M NAIP aerial images")
    print(f"  - Architecture: Hierarchical Vision Transformer")
    print(f"  - Input: RGB imagery (3 channels)")
    print(f"  - Spatial understanding of aerial imagery patterns")

    print(f"\nTask-Specific Components:")
    print(f"  - Feature Pyramid Network (FPN): Multi-scale feature fusion")
    print(f"  - Segmentation head: Convolutional layers for pixel-wise classification")
    print(f"  - Output classes: {sample_model.num_classes}")
    print(f"  - Freeze backbone: First {sample_model.freeze_backbone_epochs} epochs")

    # Count parameters
    total_params = sum(p.numel() for p in sample_model.parameters())
    trainable_params = sum(
        p.numel() for p in sample_model.parameters() if p.requires_grad
    )

    print(f"\nModel Parameters:")
    print(f"  Total: {total_params:,}")
    print(f"  Initially trainable: {trainable_params:,}")
    print(f"  Initially frozen: {total_params - trainable_params:,}")

    print(f"\nTraining Strategy:")
    print(f"  1. Freeze Satlas backbone (preserve foundation knowledge)")
    print(f"  2. Train only FPN and segmentation head")
    print(f"  3. Progressively unfreeze backbone for fine-tuning")

    model_architecture_ready = True

except Exception as e:
    print(f"Error examining model: {e}")
    print("This may be due to missing dependencies or model files.")
    model_architecture_ready = False

## Step 7: Train the Wetland Foundation Model

Now let's train our model (if we have sufficient training data):

In [None]:
# Train model if we have sufficient data
min_required_tiles = 20  # Minimum for demo

if dataset_created and dataset_stats.get('wetland_tiles', 0) >= min_required_tiles:
    print("=== Training Wetland Foundation Model ===")
    print(f"Training with {dataset_stats['wetland_tiles']} wetland tiles")
    print()

    try:
        training_results = geoai.train_wetland_model(
            dataset_dir="wetland_training_data",
            output_dir="wetland_model_output",
            backbone="satlas",
            batch_size=4,
            max_epochs=30,
            learning_rate=1e-4,
            val_split=0.2,
            freeze_backbone_epochs=2,
        )

        print("\nüéâ Training Completed Successfully!")
        print(f"Best model: {training_results['best_model_path']}")
        print(f"Checkpoint: {training_results['checkpoint_path']}")
        print(f"Output directory: {training_results['output_dir']}")

        model_trained = True

    except Exception as e:
        print(f"Training failed: {e}")
        print(
            "This may be due to insufficient GPU memory, missing dependencies, or data issues."
        )
        model_trained = False
        training_results = {}

else:
    print(f"Insufficient training data for model training.")
    print(f"Found: {dataset_stats.get('wetland_tiles', 0)} tiles")
    print(f"Required: {min_required_tiles} tiles")
    print()
    print("For production training, you would:")
    print("1. Use a larger study region or multiple regions")
    print("2. Download 50+ NAIP tiles")
    print("3. Train for 50-100 epochs")
    print("4. Use larger batch sizes on GPU")

    model_trained = False
    training_results = {}

## Step 8: Resume Training (Optional)

If training was interrupted or you want to continue for more epochs, you can resume from the last checkpoint without starting over:

In [None]:
# Resume training from last checkpoint (increase max_epochs to train longer)
# Uncomment below to resume:

training_results = geoai.train_wetland_model(
    dataset_dir="wetland_training_data",
    output_dir="wetland_model_output",
    backbone="satlas",
    batch_size=4,
    max_epochs=30,  # Increase to train further
    learning_rate=1e-4,
    val_split=0.2,
    freeze_backbone_epochs=2,
    resume_from="last",  # Auto-detect last checkpoint
)

model_trained = True
# You can also resume from a specific checkpoint:
# training_results = geoai.train_wetland_model(
#     ...,
#     resume_from="wetland_model_output/checkpoints/wetland-satlas-epoch=03-val_loss=0.373.ckpt",
# )

## Step 9: Model Inference Demo

If training was successful, let's run inference on a new image:

In [None]:
if model_trained and dataset_stats.get('naip_files_downloaded', 0) > 0:
    print("=== Running Wetland Prediction ===")

    # Get list of downloaded NAIP files for testing
    naip_cache_dir = Path("wetland_data_cache/naip")
    naip_files = list(naip_cache_dir.glob("*.tif"))

    if naip_files:
        # Use first NAIP file as test case
        test_raster = str(naip_files[0])
        output_prediction = "wetland_prediction_demo.tif"

        print(f"Test raster: {test_raster}")
        print(f"Output: {output_prediction}")

        try:
            # Run inference
            prediction_result = geoai.predict_wetlands_large_image(
                model_path=training_results['best_model_path'],
                input_raster=test_raster,
                output_path=output_prediction,
                tile_size=512,
                overlap=64,
            )

            print(f"\n‚úÖ Prediction completed: {prediction_result}")
            inference_successful = True

        except Exception as e:
            print(f"Inference failed: {e}")
            inference_successful = False
    else:
        print("No NAIP files found for testing")
        inference_successful = False

else:
    print("Skipping inference - model not trained or no test data available")
    inference_successful = False

## Step 10: Visualize Results

Let's create an interactive map showing the wetland predictions:

In [None]:
if inference_successful and Path("wetland_prediction_demo.tif").exists():
    print("=== Visualizing Wetland Predictions ===")

    try:
        # Create visualization map
        naip_file = str(list(Path("wetland_data_cache/naip").glob("*.tif"))[0])

        viz_map = geoai.visualize_wetland_predictions(
            prediction_path="wetland_prediction_demo.tif",
            naip_path=naip_file,
            center=[46.95, -99.15],
        )

        print("Interactive map with wetland predictions:")
        print("- Background: NAIP imagery (NIR-Red-Green false color)")
        print("- Overlay: Wetland classification predictions")
        print("- Legend: Color-coded wetland classes")

        viz_map

    except Exception as e:
        print(f"Visualization failed: {e}")
        print("This may be due to missing leafmap or file access issues.")

else:
    print("Skipping visualization - no prediction results available")
    print()
    print("In a complete workflow, you would see:")
    print("üó∫Ô∏è  Interactive map with NAIP imagery background")
    print("üé®  Color-coded wetland classifications overlay")
    print("üìä  Legend showing the 6 wetland classes")
    print("üîç  Ability to zoom and explore predictions")

## Step 11: Production Scaling Tips

For deploying this as a production wetland mapping system:

In [None]:
print("=== Production Deployment Guide ===")
print()

print("üî¨ **Research & Development:**")
print("  ‚Ä¢ Multi-region training: 5,000-10,000+ tiles across diverse ecoregions")
print("  ‚Ä¢ Temporal analysis: Include multiple seasons/years of NAIP imagery")
print("  ‚Ä¢ Cross-validation: Geographic holdout regions for robust evaluation")
print("  ‚Ä¢ Field validation: GPS ground truth data for accuracy assessment")
print()

print("‚ö° **Model Optimization:**")
print("  ‚Ä¢ Backbone options: Satlas Aerial (default) or Prithvi for comparison")
print("  ‚Ä¢ Architecture improvements: Multi-scale fusion, attention mechanisms")
print("  ‚Ä¢ Training optimization: Mixed precision, gradient accumulation, DDP")
print("  ‚Ä¢ Ensemble methods: Combine multiple models for robustness")
print()

print("üåç **Scale & Deployment:**")
print("  ‚Ä¢ Cloud infrastructure: AWS/GCP with GPU clusters")
print("  ‚Ä¢ Model serving: TensorRT optimization, ONNX conversion")
print("  ‚Ä¢ API deployment: FastAPI + Docker for scalable inference")
print("  ‚Ä¢ Integration: Google Earth Engine for global monitoring")
print()

print("üìä **Monitoring & Validation:**")
print("  ‚Ä¢ Continuous validation: Compare with new field surveys")
print("  ‚Ä¢ Change detection: Monitor wetland loss/gain over time")
print("  ‚Ä¢ Model drift: Retrain periodically with new data")
print("  ‚Ä¢ Uncertainty quantification: Provide confidence scores")
print()

print("üîó **Ecosystem Integration:**")
print("  ‚Ä¢ GeoAI ecosystem: Seamless integration with geemap, leafmap")
print("  ‚Ä¢ Open science: Model sharing via HuggingFace Hub")
print("  ‚Ä¢ Standards: STAC-compliant metadata and outputs")
print("  ‚Ä¢ Community: Collaborate with wetland scientists and managers")

# Show example production code
print("\nüìù **Example Production Code:**")
print("```python")
print("# Continental-scale wetland mapping")
print("import geoai")
print("")
print("# Train on multiple regions")
print("regions = [")
print("    (-99, 46, -97, 48),  # North Dakota")
print("    (-84, 25, -80, 27),  # Florida Everglades")
print("    (-73, 40, -70, 42),  # Northeast coast")
print("]")
print("")
print("for region in regions:")
print("    geoai.create_wetland_dataset(region, f'training_{i}')")
print("")
print("# Train production model")
print("geoai.train_wetland_model(")
print("    'combined_training_data',")
print("    backbone='satlas',")
print("    max_epochs=100,")
print("    batch_size=16")
print(")")
print("```")

## Summary

üéâ **What we've demonstrated:**

1. **Foundation Model Pipeline**: Complete wetland mapping system using Satlas Aerial
2. **Automated Data Access**: NAIP imagery + NWI data via existing GeoAI infrastructure
3. **Multi-Class Classification**: 6 wetland categories with class-aware training
4. **Transfer Learning**: Leverage 26.5M NAIP aerial imagery pre-training
5. **Scalable Architecture**: PyTorch Lightning with progressive unfreezing
6. **Production Ready**: Large-scale inference and visualization capabilities

üöÄ **Key advantages over traditional approaches:**

- **Rich Representation**: Foundation models understand complex Earth patterns
- **Data Efficiency**: Less training data needed due to pre-training
- **Generalization**: Better performance across different regions/seasons
- **Scalability**: Cloud-native architecture for continental mapping
- **Integration**: Seamless with existing GeoAI ecosystem

This represents the **future of environmental monitoring** - combining the power of foundation models with domain expertise for accurate, scalable wetland mapping.

## Next Steps

1. **Expand training data**: Collect samples from multiple ecoregions
2. **Optimize model**: Try different backbone options and architectures
3. **Field validation**: Compare predictions with ground truth surveys
4. **Scale deployment**: Set up cloud infrastructure for production use
5. **Contribute**: Share trained models and improvements with the community

## Learn More

- **GeoAI Documentation**: [https://opengeoai.org](https://opengeos.org)
- **Satlas Aerial Foundation Model**: [https://huggingface.co/allenai/satlas-pretrain](https://huggingface.co/allenai/satlas-pretrain)
- **Prithvi Foundation Model** (alternative): [https://huggingface.co/ibm-nasa-geospatial](https://huggingface.co/ibm-nasa-geospatial)
- **NAIP Imagery**: [https://planetarycomputer.microsoft.com/dataset/naip](https://planetarycomputer.microsoft.com/dataset/naip)
- **National Wetlands Inventory**: [https://www.fws.gov/program/national-wetlands-inventory](https://www.fws.gov/program/national-wetlands-inventory)