# Fine-tuning Clay Foundation Model for Land Cover Segmentation

Welcome to Tutorial 3! In this hands-on session, you'll learn how to fine-tune the Clay foundation model for land cover segmentation using the Chesapeake Bay dataset.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/developmentseed/igarss25tutorial/blob/main/tut3a_EOFM_finetune.ipynb)

## Learning Objectives
By the end of this tutorial, you will:
- Understand what foundation models are and why they're powerful for Earth observation
- Learn how to fine-tune a pre-trained model for semantic segmentation
- Apply transfer learning techniques to land cover classification
- Work with real aerial imagery (NAIP) and ground truth labels
- Evaluate model performance on geospatial data

## What You'll Build
You'll build a land cover segmentation model that assigns every pixel of an aerial image to a land cover class (water, tree canopy, impervious surfaces, etc.).

## Background for Different Audiences

### For GIS Professionals 📍
- **Foundation models** are like having a universal "base map" that understands Earth's features
- **Segmentation** is similar to creating detailed land use polygons, but at the pixel level
- Think of this as automated land cover classification that can reduce the need for manual digitization
- The output is similar to creating a detailed land use/land cover (LULC) raster

### For Data Analysts 📊  
- We're using **transfer learning** - starting with a model already trained on lots of Earth imagery
- **Fine-tuning** means adapting this pre-trained model to our specific classification task
- This is like taking a general-purpose tool and customizing it for your specific needs
- The model learns patterns in pixel values to predict land cover categories

### For ML Engineers 🤖
- Clay is a **Vision Transformer (ViT)** trained on massive Earth observation datasets
- We're doing **semantic segmentation** - predicting a class for every pixel
- The architecture uses a **frozen encoder** (Clay) + **trainable segmentation head**
- We'll use PyTorch Lightning for training orchestration

## Dataset Overview
The **Chesapeake Bay Land Cover dataset** contains:
- **High-resolution aerial imagery** (NAIP - National Agriculture Imagery Program)
- **7 land cover classes**: Water, Tree Canopy, Low Vegetation, Barren, Impervious (Roads), Impervious (Other), No Data
- **Pixel-level annotations** for supervised learning
- **Real-world complexity** with mixed land uses and seasonal variations

## How the Clay Segmentation Architecture Works

The `Segmentor` class combines two key components:

### 1. **Frozen Clay Encoder** 🧊
- Pre-trained on millions of Earth observation images
- Extracts rich feature representations from input imagery  
- **Frozen** = weights don't change during fine-tuning (saves compute!)
- Acts like a "universal feature extractor" for Earth imagery

### 2. **Trainable Segmentation Head** 🎯  
- Takes Clay's feature maps and upsamples them to original image size
- Uses **convolution + pixel shuffle** operations for efficient upsampling
- **Only this part gets trained** - much faster than training from scratch!

**Key Parameters:**
- `num_classes (int)`: Number of land cover classes to predict (7 for Chesapeake)
- `ckpt_path (str)`: Path to the pre-trained Clay model weights

**Why This Approach Works:**
- ✅ **Faster training**: Only train the small segmentation head
- ✅ **Less data needed**: Clay already understands Earth imagery patterns  
- ✅ **Better performance**: Pre-trained features often transfer well, especially when labeled data is limited
- ✅ **Cost effective**: Requires fewer computational resources
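
The frozen-encoder-plus-trainable-head pattern described above can be sketched in plain PyTorch. This is a toy illustration, not Clay's `Segmentor`: `ToyEncoder`, `ToyHead`, the patch size of 8, the four input bands, and the embedding width of 32 are all assumptions chosen to keep the example small.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Stand-in for a pre-trained encoder: maps an image to a coarse feature map."""

    def __init__(self, in_ch=4, dim=32, patch=8):
        super().__init__()
        # A strided convolution embeds non-overlapping patches.
        self.embed = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        return self.embed(x)  # (B, dim, H / patch, W / patch)


class ToyHead(nn.Module):
    """Convolution + pixel shuffle: upsample features to the input resolution."""

    def __init__(self, dim=32, num_classes=7, patch=8):
        super().__init__()
        # Produce num_classes * patch**2 channels, then fold them into space.
        self.conv = nn.Conv2d(dim, num_classes * patch**2, kernel_size=3, padding=1)
        self.shuffle = nn.PixelShuffle(patch)

    def forward(self, feats):
        return self.shuffle(self.conv(feats))  # (B, num_classes, H, W)


encoder, head = ToyEncoder(), ToyHead()

# Freeze the encoder so that only the head can receive gradient updates.
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

x = torch.randn(2, 4, 224, 224)  # batch of two random four-band chips
with torch.no_grad():
    feats = encoder(x)
logits = head(feats)
print(logits.shape)
```

Here `logits` has shape `(2, 7, 224, 224)`: one score per class for every input pixel. Taking `logits.argmax(dim=1)` would give the predicted class map.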

## About the Chesapeake Bay Dataset 🦀

We'll use the **Chesapeake Bay Land Cover dataset** - a high-quality dataset perfect for learning land cover segmentation.

### Dataset Citation
If you use this dataset in your work, please cite:
> Robinson C, Hou L, Malkin K, Soobitsky R, Czawlytko J, Dilkina B, Jojic N.  
> Large Scale High-Resolution Land Cover Mapping with Multi-Resolution Data.  
> Proceedings of the 2019 Conference on Computer Vision and Pattern Recognition (CVPR 2019).

### Why This Dataset is Great for Learning:
- **High Resolution**: 1-meter pixel resolution aerial imagery
- **Multiple Regions**: Covers diverse landscapes in the Chesapeake Bay area
- **Expert Annotations**: Ground truth labels created by domain experts
- **Real-world Complexity**: Mixed land uses, seasonal variations, and edge cases
- **Well-documented**: Extensively used in research with known baselines

### Land Cover Classes (7 total):
1. **Water** 💧 - Rivers, lakes, bays, coastal areas
2. **Tree Canopy/Forest** 🌳 - Dense forest areas, large trees
3. **Low Vegetation/Fields** 🌱 - Grass, crops, shrubs, sparse vegetation  
4. **Barren Land** 🏔️ - Exposed soil, construction sites, beaches
5. **Impervious (Roads)** 🛣️ - Paved roads, highways, parking lots
6. **Impervious (Other)** 🏢 - Buildings, rooftops, other built structures
7. **No Data** ⬜ - Areas with missing or invalid data

More information: [Chesapeake Bay Dataset](https://lila.science/datasets/chesapeakelandcover)
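
When working with class labels, it helps to look at how pixels are distributed across classes, since land cover data is usually imbalanced. The sketch below assumes labels have already been remapped to the indices 0-6 in the order listed above; that ordering is an assumption for illustration, so check the datamodule for the encoding it actually uses.

```python
import numpy as np

# Assumed index order (0-6), following the list above; verify the real
# label encoding in the dataset documentation and the datamodule.
CLASS_NAMES = [
    "Water",
    "Tree Canopy",
    "Low Vegetation",
    "Barren",
    "Impervious (Roads)",
    "Impervious (Other)",
    "No Data",
]


def class_fractions(label, num_classes=len(CLASS_NAMES)):
    """Return the fraction of pixels assigned to each class index."""
    counts = np.bincount(label.ravel(), minlength=num_classes)
    return counts / counts.sum()


# A tiny made-up 4x4 label chip.
label = np.array(
    [
        [0, 0, 1, 1],
        [0, 2, 1, 1],
        [4, 4, 5, 5],
        [4, 4, 5, 6],
    ]
)

fractions = class_fractions(label)
for name, frac in zip(CLASS_NAMES, fractions):
    print(f"{name:20s} {frac:.1%}")
```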

## 🚀 Setup and Installation

We'll install all required packages for fine-tuning the Clay model. This notebook is optimized for **Google Colab** but works in any Jupyter environment.

### What Each Package Does:
- **torch**: PyTorch deep learning framework
- **lightning**: PyTorch Lightning for training orchestration  
- **segmentation_models_pytorch**: Pre-built segmentation architectures
- **rasterio**: Reading/writing geospatial raster data (GeoTIFF files)
- **s5cmd**: Fast, parallel S3 data transfers

### Installation Options

**Option 1: All at once (recommended for Colab)**
```bash
pip install torch lightning segmentation_models_pytorch rasterio s5cmd
```

**Option 2: Individual packages (if you encounter conflicts)**
```bash
pip install torch
pip install lightning  
pip install segmentation_models_pytorch
pip install rasterio
pip install s5cmd
```

**For Conda users (local environments):**
```bash
mamba env create --file environment.yml
mamba activate claymodel
```

Let's install everything we need:

### 📦 Install Required Packages

Run this cell to install all dependencies. This may take 2-3 minutes in Colab.

In [None]:
# Install packages (this may take a few minutes)
!pip install torch lightning segmentation_models_pytorch rasterio s5cmd -q
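
If you want to confirm the installation before moving on, a small check like the following can help. It is a minimal sketch: `missing_modules` is a hypothetical helper, and the module names are assumed to match the package names used in the install command.

```python
import importlib.util
import shutil


def missing_modules(names):
    """Return the top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


required = ["torch", "lightning", "segmentation_models_pytorch", "rasterio"]
missing = missing_modules(required)
print("Missing Python modules:", missing or "none")

# s5cmd is used as a command-line tool, so look for it on PATH instead.
print("s5cmd on PATH:", shutil.which("s5cmd") is not None)
```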

### 📂 Clone the Clay Model Repository

We need the Clay model code for training. A shallow clone (`--depth=1`) fetches only the latest commit, which keeps the download small:

In [None]:
# Clone the Clay model repository
!git clone --depth=1 https://github.com/clay-foundation/model.git

In [None]:
# Navigate to the model directory and check contents
%cd model
!ls -la

### 🐍 Add Clay Model to Python Path

This makes the Clay model modules available for import:

In [None]:
# Add the claymodel directory to Python path so we can import modules
import sys
sys.path.append("./claymodel")

# Import key modules we'll use for training
from claymodel.finetune.segment.chesapeake_datamodule import ChesapeakeDataModule
from claymodel.finetune.segment.chesapeake_model import ChesapeakeSegmentor

print("✅ Clay model modules imported successfully!")

## 📥 Download Training Data

We'll download a subset of the Chesapeake Bay dataset for training. The full dataset is ~100GB, so we're using a small sample for this tutorial.

### What We're Downloading:
- **`*_lc.tif`**: Land cover label images (ground truth)
- **`*_naip-new.tif`**: NAIP aerial imagery (input images)
- **Training data**: From New York region, 2013
- **Validation data**: Separate set for evaluating model performance

### About s5cmd:
`s5cmd` is a high-performance tool for transferring data to and from AWS S3. Because it copies files in parallel, it is typically much faster than the standard AWS CLI when a download contains many files.

**Note**: Download may take 5-10 minutes depending on your internet connection.

In [None]:
# Create directory structure for our data
!mkdir -p data/cvpr/files/train data/cvpr/files/val

# Download training data (subset from NY region)
print("📥 Downloading training data...")
!s5cmd \
    --no-sign-request \
    cp \
    --include "m_42076*_lc.tif" \
    --include "m_42076*_naip-new.tif" \
    "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-train_tiles/*" \
    data/cvpr/files/train/

print("✅ Training data downloaded!")

In [None]:
# Download validation data (complete validation set)
print("📥 Downloading validation data...")
!s5cmd \
    --no-sign-request \
    cp \
    --include "*_lc.tif" \
    --include "*_naip-new.tif" \
    "s3://us-west-2.opendata.source.coop/agentmorris/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover/ny_1m_2013_extended-debuffered-val_tiles/*" \
    data/cvpr/files/val/

print("✅ Validation data downloaded!")

### ✅ Verify Downloaded Data

Let's check what we downloaded:

In [None]:
# Check what files we downloaded
import os

print("📊 Validation data files:")
!ls data/cvpr/files/val | head -10

# Shell commands (!) cannot be used inside an f-string, so count files in Python
print(f"\n📈 Total files in validation: {len(os.listdir('data/cvpr/files/val'))} files")
print(f"📈 Total files in training: {len(os.listdir('data/cvpr/files/train'))} files")
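
Because every image needs a matching label, it can be worth checking that the two file patterns pair up. The sketch below is one way to do that; `unpaired_tiles` is a hypothetical helper and the file names are made up to follow the download patterns.

```python
from fnmatch import fnmatch

IMAGE_SUFFIX = "_naip-new.tif"
LABEL_SUFFIX = "_lc.tif"


def unpaired_tiles(filenames):
    """Return tile IDs that have an image without a label, or vice versa."""
    images = {f[: -len(IMAGE_SUFFIX)] for f in filenames if f.endswith(IMAGE_SUFFIX)}
    labels = {f[: -len(LABEL_SUFFIX)] for f in filenames if f.endswith(LABEL_SUFFIX)}
    return sorted(images ^ labels)  # symmetric difference


# Made-up file names that follow the download patterns.
names = [
    "m_4207601_ne_18_1_naip-new.tif",
    "m_4207601_ne_18_1_lc.tif",
    "m_4207602_nw_18_1_naip-new.tif",
]
unpaired = unpaired_tiles(names)
matched = [n for n in names if fnmatch(n, "m_42076*_lc.tif")]
print("Unpaired tiles:", unpaired)
print("Files matching the label pattern:", matched)
```

On the real data you could pass `os.listdir("data/cvpr/files/train")` to `unpaired_tiles`; an empty list means every tile has both files.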

## 🔄 Data Preprocessing

The downloaded GeoTIFF files are large (typically 1000x1000 pixels or more). For efficient training, we need to:

1. **Split into smaller chips**: Break large images into 224x224 pixel tiles
2. **Organize directory structure**: Separate images and labels into proper folders  
3. **Create train/val splits**: Ensure no data leakage between training and validation

### Why 224x224 chips?
- **Memory efficiency**: Fits in GPU memory for training
- **Standard size**: Common input size for vision models
- **Balanced coverage**: Good trade-off between context and computational efficiency

### What the preprocessing script does:
- Reads large GeoTIFF files 
- Splits them into 224x224 pixel chips
- Saves chips as individual image files
- Maintains spatial alignment between imagery and labels
- Creates proper directory structure for PyTorch Lightning

**Note**: This step may take 5-10 minutes to process all the data.

In [None]:
# Clean up any existing processed data to ensure fresh start
!rm -rf data/cvpr/ny/
print("🧹 Cleaned up existing processed data")

### 🔧 Run Data Preprocessing

This converts the large GeoTIFF files into training-ready 224x224 image chips:

In [None]:
# Run the preprocessing script
# Args: input_dir output_dir chip_size
print("🔄 Processing data into 224x224 chips...")
!python claymodel/finetune/segment/preprocess_data.py data/cvpr/files data/cvpr/ny 224
print("✅ Data preprocessing complete!")
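
To make the chipping step concrete, here is a minimal NumPy sketch of splitting a large array into non-overlapping 224x224 chips. It is not the code in `preprocess_data.py`: `split_into_chips` is a hypothetical helper, and dropping the partial chips at the right and bottom edges is an assumption of this sketch.

```python
import numpy as np


def split_into_chips(array, chip_size=224):
    """Yield (row, col, chip) for every full chip in a (bands, H, W) array."""
    _, height, width = array.shape
    for row in range(0, height - chip_size + 1, chip_size):
        for col in range(0, width - chip_size + 1, chip_size):
            yield row, col, array[:, row:row + chip_size, col:col + chip_size]


# Stand-ins for a four-band image and its single-band label raster.
image = np.zeros((4, 1000, 1000), dtype=np.uint8)
label = np.zeros((1, 1000, 1000), dtype=np.uint8)

image_chips = list(split_into_chips(image))
label_chips = list(split_into_chips(label))
print(len(image_chips), image_chips[0][2].shape, label_chips[0][2].shape)
```

Using the same offsets for the image and the label is what keeps each chip aligned with its annotation.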

### 📊 Check Processed Data

Let's verify our preprocessing worked correctly:

In [None]:
# Check the directory structure and count files
!echo "📁 Directory structure:"
!ls -la data/cvpr/ny/

!echo -e "\n📊 Data counts:"
!echo "Validation labels: $(ls data/cvpr/ny/val/labels | wc -l) files"  
!echo "Validation chips: $(ls data/cvpr/ny/val/chips | wc -l) files"
!echo "Training labels: $(ls data/cvpr/ny/train/labels | wc -l) files"
!echo "Training chips: $(ls data/cvpr/ny/train/chips | wc -l) files"

## 🏗️ Download Pre-trained Clay Model

Now we need the **pre-trained Clay foundation model**. Think of this as downloading a "universal Earth imagery expert" that already understands features like vegetation, water, and built structures.

### About the Clay Model:
- **Version 1.5**: Latest stable version at the time of writing
- **Size**: ~400MB (this is normal for foundation models!)
- **Training**: Trained on millions of satellite/aerial images
- **Format**: PyTorch Lightning checkpoint (.ckpt file)

### What Makes Clay Special:
- 🌍 **Global coverage**: Trained on imagery from around the world
- 🛰️ **Multi-sensor**: Works with different satellite/aerial platforms  
- 🎯 **Transfer learning ready**: Designed to be fine-tuned for specific tasks
- ⚡ **Efficient**: Optimized for both training and inference

In [None]:
# Create checkpoints directory and download Clay model
!mkdir -p checkpoints

print("⬇️ Downloading Clay v1.5 model (this may take a few minutes)...")
!wget -O checkpoints/clay-v1.5.ckpt https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt

print("✅ Clay model downloaded successfully!")

### ✅ Verify Model Download

Let's check the downloaded model file:

In [None]:
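# Verify the model was downloaded correctly
import os

!ls -lh checkpoints/

# Shell commands (!) cannot be used inside an f-string, so read the size in Python
size_mb = os.path.getsize("checkpoints/clay-v1.5.ckpt") / 1e6
print(f"✅ Clay model size: {size_mb:.0f} MB")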

## ⚙️ Training Configuration

Before training, let's examine the configuration file that controls all the training parameters. Understanding these settings helps you adapt the model for your own projects.

### What's in the Config File:
- **Data paths**: Where to find training/validation data
- **Model settings**: Architecture choices and hyperparameters  
- **Training params**: Learning rate, batch size, number of epochs
- **Hardware settings**: GPU usage, mixed precision training
- **Logging**: Where to save results and checkpoints

In [None]:
# Let's look at the training configuration
print("📋 Training Configuration:")
!cat configs/segment_chesapeake.yaml

print("\n💡 Key Settings Explained:")
print("- lr: 1e-5 (learning rate - step size of each weight update)")
print("- batch_size: 16 (number of image chips processed together)")
print("- max_epochs: 50 (maximum number of full passes over the training data)")
print("- precision: bf16-mixed (faster training with minimal accuracy loss)")
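
The `bf16-mixed` setting depends on hardware support for bfloat16, which not every GPU provides. A hedged way to choose a precision string is sketched below; `pick_precision` is a hypothetical helper, and the fallback order is a suggestion rather than anything prescribed by Clay.

```python
def pick_precision(cuda_available, bf16_supported):
    """Choose a Lightning precision string from two hardware facts."""
    if cuda_available and bf16_supported:
        return "bf16-mixed"  # bfloat16 autocast
    if cuda_available:
        return "16-mixed"    # float16 autocast with loss scaling
    return "32-true"         # plain float32, e.g. on CPU


try:
    import torch

    cuda = torch.cuda.is_available()
    bf16 = cuda and torch.cuda.is_bf16_supported()
except ImportError:  # keeps the sketch runnable without PyTorch installed
    cuda = bf16 = False

print(pick_precision(cuda, bf16))
```

The returned string can be passed as the `precision` argument of the Lightning `Trainer`.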

### 📝 Understanding the Configuration

The config file uses YAML format - a human-readable way to specify settings. Here's what each section does:

- **data**: Paths to training and validation data
- **model**: Architecture and learning parameters  
- **trainer**: Hardware settings and training duration
- **callbacks**: When to save models and how to monitor progress
- **logger**: Where to save training logs and metrics
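
To see how such sections look once loaded, the sketch below parses a small YAML fragment with PyYAML. The fragment is illustrative only: its keys and values are copied from the cells in this notebook, not from `configs/segment_chesapeake.yaml`, so the real file may be organized differently.

```python
import yaml  # PyYAML

# Illustrative fragment only; not the contents of the real config file.
EXAMPLE = """
data:
  train_chip_dir: data/cvpr/ny/train/chips/
  batch_size: 16
model:
  num_classes: 7
  lr: 1.0e-5
trainer:
  max_epochs: 1
  precision: bf16-mixed
"""

config = yaml.safe_load(EXAMPLE)
for section, settings in config.items():
    print(section, "->", sorted(settings))
```

Each top-level key becomes a dictionary of settings, which is why the file reads as one block per component.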

## 🚀 Model Training Setup

Now we'll set up the training pipeline using PyTorch Lightning. This approach separates data handling from model training, making the code cleaner and more maintainable.

### Training Components:

1. **DataModule**: Handles loading and preprocessing of images
2. **Model**: The Clay encoder + segmentation head  
3. **Trainer**: Orchestrates the training process

### Key Benefits of This Approach:
- ✅ **Reproducible**: Same setup works across different environments
- ✅ **Scalable**: Easy to train on single GPU or multiple GPUs  
- ✅ **Maintainable**: Clean separation of concerns
- ✅ **Flexible**: Easy to modify individual components

### 📊 Initialize Data Module

The DataModule handles all data operations - loading images, applying transforms, creating batches:

In [None]:
# Initialize the data module with our processed data
print("📊 Setting up data module...")

dm = ChesapeakeDataModule(
    train_chip_dir="data/cvpr/ny/train/chips/",      # Training images
    train_label_dir="data/cvpr/ny/train/labels/",    # Training labels  
    val_chip_dir="data/cvpr/ny/val/chips/",          # Validation images
    val_label_dir="data/cvpr/ny/val/labels/",        # Validation labels
    metadata_path="configs/metadata.yaml",           # Data normalization info
    batch_size=16,                                   # Images per training batch
    num_workers=8,                                   # Parallel data loading processes  
    platform="naip",                                 # Image type (NAIP aerial imagery)
)

# Prepare the data loaders
dm.setup()
print("✅ Data module ready!")
print(f"📈 Training batches: {len(dm.train_dataloader())}")
print(f"📊 Validation batches: {len(dm.val_dataloader())}")
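
The `metadata_path` argument points to normalization information. A common way to apply such statistics is per-band standardization, sketched below with NumPy. The mean and standard deviation values are placeholders, and `normalize` is a hypothetical helper; the datamodule's actual transform may differ.

```python
import numpy as np


def normalize(chip, mean, std):
    """Standardize a (bands, H, W) chip with per-band mean and std."""
    mean = np.asarray(mean, dtype=np.float32)[:, None, None]
    std = np.asarray(std, dtype=np.float32)[:, None, None]
    return (chip.astype(np.float32) - mean) / std


# Placeholder statistics for four bands; the real values live in
# configs/metadata.yaml.
mean = [100.0, 110.0, 90.0, 120.0]
std = [50.0, 40.0, 45.0, 60.0]

chip = np.full((4, 2, 2), 150, dtype=np.uint8)
normalized = normalize(chip, mean, std)
print(normalized[:, 0, 0])
```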

### 🤖 Initialize the Model

Now we create our segmentation model - Clay encoder + segmentation head:

In [None]:
# Initialize the segmentation model
print("🤖 Setting up segmentation model...")

model = ChesapeakeSegmentor(
    num_classes=7,                              # 7 land cover classes
    ckpt_path="checkpoints/clay-v1.5.ckpt",    # Pre-trained Clay model
    lr=1e-5,                                    # Learning rate (conservative for fine-tuning)
    wd=0.05,                                    # Weight decay (regularization)
    b1=0.9,                                     # Adam optimizer beta1  
    b2=0.95,                                    # Adam optimizer beta2
)

print("✅ Model initialized!")
print("🧊 Clay encoder: FROZEN (saves compute)")
print("🎯 Segmentation head: TRAINABLE (learns land cover patterns)")

### ⚡ Setup the Trainer

The Trainer handles the training loop, GPU usage, and checkpointing:

In [None]:
# Import the Trainer
from lightning import Trainer

print("⚡ Setting up trainer...")

In [None]:
# Configure the trainer for our training session
trainer = Trainer(
    accelerator="auto",                    # Automatically detect GPU/CPU
    devices=1,                            # Use 1 device (GPU if available)  
    num_nodes=1,                          # Single machine training
    precision="bf16-mixed",               # Mixed precision (faster training)
    log_every_n_steps=5,                  # Log metrics every 5 training steps
    max_epochs=1,                         # Train for 1 epoch (demo purposes)
    accumulate_grad_batches=1,            # No gradient accumulation
    default_root_dir="checkpoints/segment", # Where to save checkpoints
    fast_dev_run=False,                   # Full training (not debugging mode)
    num_sanity_val_steps=0,               # Skip validation sanity check
)

print("✅ Trainer configured!")
print(f"🎯 Will train for {trainer.max_epochs} epoch(s)")
print(f"💾 Checkpoints saved to: {trainer.default_root_dir}")
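
Several of these settings interact: batch size, device count, and gradient accumulation together determine how many optimizer steps one epoch contains. The arithmetic can be sketched as follows; `steps_per_epoch` is a hypothetical helper, the chip count is made up, and the result assumes the last partial batch is kept.

```python
import math


def steps_per_epoch(num_chips, batch_size, devices=1, num_nodes=1,
                    accumulate_grad_batches=1):
    """Return (batches per device, optimizer steps, effective batch size)."""
    world_size = devices * num_nodes
    batches = math.ceil(num_chips / (batch_size * world_size))
    optimizer_steps = math.ceil(batches / accumulate_grad_batches)
    effective_batch = batch_size * world_size * accumulate_grad_batches
    return batches, optimizer_steps, effective_batch


# 1,000 chips is a made-up number; len(dm.train_dataloader()) gives the
# real batch count for your data.
print(steps_per_epoch(1000, batch_size=16))
print(steps_per_epoch(1000, batch_size=16, accumulate_grad_batches=4))
```

With the settings in this notebook (one device, no accumulation), the effective batch size equals `batch_size`.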

### 🏁 Start Training!

Everything is set up - let's train the model! This will:

1. **Load batches** of images and labels
2. **Forward pass**: Run images through Clay encoder + segmentation head  
3. **Compute loss**: Compare predictions to ground truth labels
4. **Backward pass**: Calculate gradients for the segmentation head
5. **Update weights**: Improve the segmentation head parameters
6. **Validate**: Measure performance on the validation set
7. **Save checkpoint**: Store the trained model

**Expected time**: ~5-10 minutes for 1 epoch (depending on hardware)

**What to watch for**:
- Training loss should decrease over time
- Validation metrics should improve
- No out-of-memory errors
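
Steps 1 to 5 above can be sketched as a single manual training step in plain PyTorch. This is a toy, not what `trainer.fit` runs internally: the two convolutions stand in for the Clay encoder and the segmentation head, the data is random, and using AdamW is an assumption (the hyperparameters mirror the model cell).

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins: a frozen "encoder" and a trainable 1x1-conv "head".
encoder = nn.Conv2d(4, 8, kernel_size=3, padding=1)
head = nn.Conv2d(8, 7, kernel_size=1)
for p in encoder.parameters():
    p.requires_grad = False

# Only the head's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(
    head.parameters(), lr=1e-5, weight_decay=0.05, betas=(0.9, 0.95)
)
loss_fn = nn.CrossEntropyLoss()

images = torch.randn(2, 4, 32, 32)         # 1. load a batch (random here)
labels = torch.randint(0, 7, (2, 32, 32))  #    one class index per pixel

encoder_before = encoder.weight.clone()
head_before = head.weight.clone()

logits = head(encoder(images))             # 2. forward pass
loss = loss_fn(logits, labels)             # 3. compute loss
optimizer.zero_grad()
loss.backward()                            # 4. gradients reach only the head
optimizer.step()                           # 5. update the head's weights

print(f"loss={loss.item():.3f}")
print("encoder changed:", not torch.equal(encoder.weight, encoder_before))
print("head changed:", not torch.equal(head.weight, head_before))
```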

In [None]:
# Start the training process!
print("🚀 Starting training...")
print("📊 Watch the progress below:")

trainer.fit(model, dm)

print("\n🎉 Training complete!")
print("📁 Check the checkpoints directory for your trained model")
print("➡️ Next: Run the inference notebook to see predictions!")
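
When you evaluate predictions, intersection-over-union (IoU) per class is a common segmentation metric. The NumPy sketch below shows one way to compute it from a confusion matrix; `per_class_iou` is a hypothetical helper, and the metrics logged by the Clay training code may be defined differently.

```python
import numpy as np


def per_class_iou(pred, target, num_classes, ignore_index=None):
    """Intersection-over-union per class; NaN where a class never appears."""
    pred, target = pred.ravel(), target.ravel()
    if ignore_index is not None:
        keep = target != ignore_index
        pred, target = pred[keep], target[keep]
    # Confusion matrix: rows are true classes, columns are predictions.
    cm = np.bincount(target * num_classes + pred, minlength=num_classes**2)
    cm = cm.reshape(num_classes, num_classes)
    intersection = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
    return np.where(union > 0, intersection / np.maximum(union, 1), np.nan)


# Tiny made-up example with three classes.
target = np.array([[0, 0, 1], [1, 2, 2]])
pred = np.array([[0, 1, 1], [1, 2, 0]])

iou = per_class_iou(pred, target, num_classes=3)
mean_iou = np.nanmean(iou)
print(iou, mean_iou)
```

Passing `ignore_index` for the "No Data" class keeps unlabeled pixels from counting against the model.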