# Model Inference and Results Visualization 

Welcome to the **inference and visualization** part of Tutorial 3! Here you'll see your fine-tuned Clay model in action, making predictions on real satellite imagery.

## What You'll Learn
- How to load a trained segmentation model for inference
- Techniques for visualizing model predictions
- How to interpret land cover segmentation results
- Methods for comparing predictions with ground truth
- Best practices for model evaluation

## What We'll Do
1. **Load the trained model** from the previous notebook
2. **Prepare validation data** for testing
3. **Run inference** to generate predictions  
4. **Visualize results** with color-coded land cover maps
5. **Compare predictions** with ground truth labels

## Key Concepts

### For GIS Professionals 📍
- **Inference**: Using your trained model to classify new imagery
- **Visualization**: Creating interpretable land cover maps from model outputs
- Think of this as automated feature extraction and classification
- Results can be exported as GeoTIFF files for use in GIS software

### For Data Analysts 📊
- **Model evaluation**: Assessing how well our model performs
- **Visual validation**: Checking predictions against known ground truth
- **Pattern recognition**: Understanding what the model learned vs. missed
- **Quality assessment**: Identifying areas for model improvement
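Visual checks are easier to trust when paired with a number. Below is a minimal NumPy sketch of two standard segmentation metrics, pixel accuracy and per-class intersection-over-union (IoU), both computed from a confusion matrix. It assumes integer class maps like the `labels` and `preds` arrays returned by `post_process` later in this notebook. It also treats class 6 (No Data in this tutorial's legend) as ignored. The helper names are hypothetical and not part of claymodel.

```python
import numpy as np

def confusion_matrix(labels, preds, num_classes, ignore_index=None):
    """cm[t, p] = number of pixels with true class t predicted as class p."""
    labels = np.asarray(labels).ravel().astype(np.int64)
    preds = np.asarray(preds).ravel().astype(np.int64)
    if ignore_index is not None:
        keep = labels != ignore_index  # drop pixels without a valid ground truth
        labels, preds = labels[keep], preds[keep]
    # Encode each (true, predicted) pair as one index, then count occurrences
    counts = np.bincount(labels * num_classes + preds, minlength=num_classes**2)
    return counts.reshape(num_classes, num_classes)

def pixel_accuracy(cm):
    """Share of counted pixels whose predicted class matches the ground truth."""
    return np.trace(cm) / cm.sum()

def per_class_iou(cm):
    """IoU_k = TP / (TP + FP + FN); NaN for classes absent from both maps."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as k, truly another class
    fn = cm.sum(axis=1) - tp  # truly k, predicted as another class
    denom = tp + fp + fn
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(denom > 0, tp / denom, np.nan)

# Tiny hypothetical example: 7 classes, class 6 (No Data) ignored
labels = np.array([[0, 0, 1], [1, 2, 6]])
preds = np.array([[0, 1, 1], [1, 2, 2]])
cm = confusion_matrix(labels, preds, num_classes=7, ignore_index=6)
print(pixel_accuracy(cm))             # 0.8
print(np.nanmean(per_class_iou(cm)))  # mean IoU over classes that occur
```

On the real batch, the same call is `confusion_matrix(labels, preds, num_classes=7, ignore_index=6)` once the post-processing cell has run.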

### For ML Engineers 🤖
- **Inference pipeline**: Loading checkpoints and running forward passes
- **Post-processing**: Converting logits to class predictions
- **Batch processing**: Efficient handling of multiple images
- **Model interpretation**: Understanding model behavior through visualization
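For segmentation, post-processing comes down to an argmax over the class axis. The minimal NumPy sketch below uses made-up logits, and the `logits_to_classes` and `softmax` helpers are illustrative rather than part of claymodel. It shows the shape change from (B, C, H, W) to (B, H, W). It also shows that applying a softmax first never changes the winning class. Real model outputs would first be converted with `outputs.detach().cpu().numpy()`.

```python
import numpy as np

def logits_to_classes(logits):
    """(B, C, H, W) raw scores -> (B, H, W) integer class IDs via argmax over C."""
    return np.asarray(logits).argmax(axis=1)

def softmax(logits, axis=1):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical logits: 1 image, 3 classes, 1x2 pixels
logits = np.array([[[[2.0, -1.0]],    # class 0 scores
                    [[0.5, 3.0]],     # class 1 scores
                    [[-2.0, 0.0]]]])  # class 2 scores
classes = logits_to_classes(logits)
print(classes)  # [[[0 1]]]

# Softmax preserves the ordering of scores, so the argmax is unchanged
print((softmax(logits).argmax(axis=1) == classes).all())  # True
```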

## 📂 Environment Setup

**Important**: This notebook assumes you've completed the training notebook first. If you haven't, please run `tut3_EOFM_finetune.ipynb` before proceeding.

Let's make sure we're in the correct directory and have access to our trained model:

In [None]:
# Navigate to the model directory (if not already there)
%cd model/

# Verify we have the necessary files
print("📁 Current directory contents:")
!ls -la

print("\n🔍 Checking for trained model...")
!ls -la checkpoints/segment/lightning_logs/*/checkpoints/ 2>/dev/null || echo "❌ No trained model found - please run the training notebook first!"

## 📚 Import Required Libraries

Let's import all the tools we need for inference and visualization:

In [None]:
# Core Python libraries
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# PyTorch for deep learning
import torch
import torch.nn.functional as F

# Additional utilities  
from einops import rearrange

# Add claymodel to path and import our custom modules
sys.path.append("./claymodel")
from claymodel.finetune.segment.chesapeake_datamodule import ChesapeakeDataModule
from claymodel.finetune.segment.chesapeake_model import ChesapeakeSegmentor

print("✅ All libraries imported successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🎮 GPU available: {'Yes' if torch.cuda.is_available() else 'No (using CPU)'}")

## ⚙️ Configuration and File Paths

Let's define all the paths and parameters we'll need. These should match what you used in the training notebook:

In [None]:
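# File paths and configuration
# The version/epoch/step in this path differ between training runs; point it at the
# checkpoint your own run produced (see the `ls` output in the setup cell)
CHESAPEAKE_CHECKPOINT_PATH = "checkpoints/segment/lightning_logs/version_0/checkpoints/epoch=0-step=63.ckpt"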
CLAY_CHECKPOINT_PATH = "checkpoints/clay-v1.5.ckpt"
METADATA_PATH = "configs/metadata.yaml"

# Data directories
TRAIN_CHIP_DIR = "data/cvpr/ny/train/chips/"
TRAIN_LABEL_DIR = "data/cvpr/ny/train/labels/"
VAL_CHIP_DIR = "data/cvpr/ny/val/chips/"
VAL_LABEL_DIR = "data/cvpr/ny/val/labels/"

# Data loading parameters
BATCH_SIZE = 32          # Chips per batch; the figure below shows the first 24
NUM_WORKERS = 1          # Single worker to avoid issues in Colab
PLATFORM = "naip"        # NAIP aerial imagery platform

print("📋 Configuration loaded:")
print(f"   🎯 Model checkpoint: {CHESAPEAKE_CHECKPOINT_PATH}")
print(f"   🧠 Clay model: {CLAY_CHECKPOINT_PATH}")
print(f"   📊 Batch size: {BATCH_SIZE}")
print(f"   📷 Platform: {PLATFORM}")
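The hard-coded checkpoint path only matches one particular training run. The minimal sketch below assumes the default Lightning directory layout that the setup cell lists (`lightning_logs/version_*/checkpoints/*.ckpt`, relative to `model/`). It picks the most recently modified checkpoint instead. `find_latest_checkpoint` is a hypothetical helper, not part of claymodel.

```python
from pathlib import Path

def find_latest_checkpoint(log_root="checkpoints/segment/lightning_logs"):
    """Return the most recently modified .ckpt under <log_root>/*/checkpoints/, or None."""
    candidates = sorted(
        Path(log_root).glob("*/checkpoints/*.ckpt"),
        key=lambda p: p.stat().st_mtime,  # oldest first, newest last
    )
    return str(candidates[-1]) if candidates else None
```

One way to use it is `CHESAPEAKE_CHECKPOINT_PATH = find_latest_checkpoint() or CHESAPEAKE_CHECKPOINT_PATH`, which keeps the hard-coded path as a fallback when no checkpoint is found.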

## 🔧 Helper Functions

Let's define functions to handle model loading, data preparation, inference, and visualization. Breaking these into functions makes the code more organized and reusable:

In [None]:
def get_model(chesapeake_checkpoint_path, clay_checkpoint_path, metadata_path):
    """
    Load the trained segmentation model from checkpoint.
    
    Args:
        chesapeake_checkpoint_path: Path to our trained model
        clay_checkpoint_path: Path to the Clay foundation model  
        metadata_path: Path to data normalization metadata
        
    Returns:
        model: Loaded model in evaluation mode
    """
    print("🤖 Loading trained model...")
    
    model = ChesapeakeSegmentor.load_from_checkpoint(
        checkpoint_path=chesapeake_checkpoint_path,
        metadata_path=metadata_path,
        ckpt_path=clay_checkpoint_path,
    )
    
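    # Set to evaluation mode (disables dropout; batch norm uses its running statistics)
    model.eval()

    # Use the GPU when available; the inference cell moves the data to this device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    print(f"✅ Model loaded successfully on {device}!")
    return model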

In [None]:
def get_data(train_chip_dir, train_label_dir, val_chip_dir, val_label_dir, 
             metadata_path, batch_size, num_workers, platform):
    """
    Set up data loading for inference.
    
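    Args:
        train_chip_dir, train_label_dir: Training chips/labels (required by the data module)
        val_chip_dir, val_label_dir: Validation chips/labels used for inference
        metadata_path: Path to the band normalization metadata
        batch_size, num_workers: DataLoader settings
        platform: Sensor platform key (e.g. "naip")

    Returns:
        batch: One batch of validation data (dict of tensors)
        metadata: Sensor metadata, including per-band normalization statistics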
    """
    print("📊 Setting up data loader...")
    
    # Create data module (same as training, but we only need validation data)
    dm = ChesapeakeDataModule(
        train_chip_dir=train_chip_dir,
        train_label_dir=train_label_dir,  
        val_chip_dir=val_chip_dir,
        val_label_dir=val_label_dir,
        metadata_path=metadata_path,
        batch_size=batch_size,
        num_workers=num_workers,
        platform=platform,
    )
    
    # Setup the data module
    dm.setup(stage="fit")
    
    # Get one batch of validation data for visualization
    val_dl = iter(dm.val_dataloader())
    batch = next(val_dl)
    
    print(f"✅ Data loaded - batch contains {batch['pixels'].shape[0]} images")
    print(f"📏 Image shape: {list(batch['pixels'].shape[1:])}")
    
    return batch, dm.metadata

In [None]:
def run_prediction(model, batch):
    """
    Run inference on a batch of images.
    
    Args:
        model: Trained segmentation model
        batch: Batch of input images
        
    Returns:
        outputs: Raw class scores (logits), shape (batch, num_classes, height, width)
    """
    print("🔮 Running inference...")
    
    # Disable gradient computation for faster inference
    with torch.no_grad():
        # Forward pass through the model
        outputs = model(batch)
    
    # Upsample logits to the input chip size (256x256 for these chips);
    # the model outputs lower-resolution maps that must be resized
    outputs = F.interpolate(
        outputs,
        size=batch["pixels"].shape[-2:],  # (height, width) of the input chips
        mode="bilinear",                  # smooth interpolation
        align_corners=False,              # PyTorch default
    )
    
    print(f"✅ Inference complete - predictions shape: {list(outputs.shape)}")
    return outputs
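`run_prediction` returns raw class scores, so a per-pixel confidence needs a softmax first. The minimal NumPy sketch below assumes the logits have already been converted to NumPy (e.g. `outputs.detach().cpu().numpy()`). It returns the highest class probability for each pixel along with the winning class. `confidence_map` and the 0.6 threshold are illustrative choices, not part of the tutorial.

```python
import numpy as np

def confidence_map(logits, axis=1):
    """Return (max softmax probability, argmax class) per pixel for (B, C, H, W) logits."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=axis, keepdims=True)  # avoid overflow in exp
    probs = np.exp(shifted)
    probs /= probs.sum(axis=axis, keepdims=True)
    return probs.max(axis=axis), probs.argmax(axis=axis)

# Hypothetical logits: 1 image, 2 classes, 1x2 pixels
logits = np.array([[[[4.0, 0.1]],
                    [[0.0, 0.0]]]])
conf, cls = confidence_map(logits)
low_confidence = conf < 0.6  # hypothetical threshold for "worth a second look"
print(cls)             # [[[0 0]]]
print(low_confidence)  # [[[False  True]]]
```

Showing `conf[i]` with `plt.imshow` next to the prediction maps is one way to see where the model is unsure.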

In [None]:
def denormalize_images(normalized_images, means, stds):
    """
    Convert normalized images back to viewable format.
    
    Before reaching the model, each band is standardized (subtract the band mean,
    divide by the band standard deviation). For visualization we reverse that.
    
    Args:
        normalized_images: Normalized image tensors
        means: Mean values used for normalization
        stds: Standard deviation values used for normalization
        
    Returns:
        denormalized_images: Images in 0-255 range for display
    """
    means = np.array(means).reshape(1, -1, 1, 1)
    stds = np.array(stds).reshape(1, -1, 1, 1)
    
    # Reverse normalization: multiply by std, then add mean
    denormalized_images = normalized_images * stds + means
    
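    # Clip to the valid 0-255 range before casting; astype(np.uint8) alone would
    # wrap or corrupt out-of-range values instead of saturating them
    return np.clip(denormalized_images, 0, 255).astype(np.uint8)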


def post_process(batch, outputs, metadata):
    """
    Convert model outputs and inputs into visualization-ready format.
    
    Args:
        batch: Original batch of data
        outputs: Model logits, shape (batch, num_classes, height, width)
        metadata: Data normalization info
        
    Returns:
        images: RGB images ready for display
        labels: Ground truth segmentation maps
        preds: Predicted segmentation maps
    """
    print("🔄 Post-processing results...")
    
    # Convert per-pixel class scores (logits) to class IDs:
    # argmax picks the highest-scoring class for each pixel
    preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
    
    # Extract ground truth labels
    labels = batch["label"].detach().cpu().numpy()
    
    # Extract normalized pixel values
    pixels = batch["pixels"].detach().cpu().numpy()
    
    # Get normalization parameters for this platform (NAIP)
    means = list(metadata["naip"].bands.mean.values())
    stds = list(metadata["naip"].bands.std.values())
    
    # Denormalize images for display
    norm_pixels = denormalize_images(pixels, means, stds)
    
    # Rearrange from (batch, channels, height, width) to (batch, height, width, channels)
    # This is the format matplotlib expects for RGB images
    images = rearrange(norm_pixels[:, :3, :, :], "b c h w -> b h w c")
    
    print("✅ Post-processing complete")
    print(f"📊 Processed {len(images)} images")
    
    return images, labels, preds
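The display conversion can be checked in isolation. The minimal NumPy sketch below uses hypothetical band statistics; the real means and standard deviations come from `metadata.yaml`. It assumes the first three bands are red, green, and blue, as the `[:, :3]` slice in `post_process` implies. It demonstrates three things:

- Standardizing and then undoing it recovers the original pixels.
- Clipping pins out-of-range values at 0 or 255.
- `transpose(0, 2, 3, 1)` is the NumPy equivalent of the einops pattern `"b c h w -> b h w c"`.

`standardize` and `to_display_rgb` are illustrative helpers.

```python
import numpy as np

def standardize(images, means, stds):
    """Per-band z-score for (B, C, H, W) arrays: (x - mean) / std."""
    means = np.asarray(means, dtype=np.float64).reshape(1, -1, 1, 1)
    stds = np.asarray(stds, dtype=np.float64).reshape(1, -1, 1, 1)
    return (images - means) / stds

def to_display_rgb(normalized, means, stds):
    """Undo standardization, clip to 0-255, keep bands 0-2, reorder to (B, H, W, C)."""
    means = np.asarray(means, dtype=np.float64).reshape(1, -1, 1, 1)
    stds = np.asarray(stds, dtype=np.float64).reshape(1, -1, 1, 1)
    restored = normalized * stds + means
    # Round, then clip before casting so out-of-range values saturate
    restored = np.clip(np.rint(restored), 0, 255).astype(np.uint8)
    # Same as einops "b c h w -> b h w c" applied to the first three bands
    return restored[:, :3].transpose(0, 2, 3, 1)

# Hypothetical 4-band (R, G, B, NIR) statistics
means = [110.0, 115.0, 100.0, 130.0]
stds = [45.0, 40.0, 38.0, 50.0]
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(2, 4, 8, 8)).astype(np.uint8)

rgb = to_display_rgb(standardize(raw.astype(np.float64), means, stds), means, stds)
print(rgb.shape)  # (2, 8, 8, 3)
```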

In [None]:
def plot_predictions(images, labels, preds):
    """
    Create a comprehensive visualization of results.
    
    Shows original images, ground truth labels, and model predictions
    in an easy-to-compare grid format.
    
    Args:
        images: RGB aerial images
        labels: Ground truth segmentation maps  
        preds: Model predicted segmentation maps
    """
    print("🎨 Creating visualization...")
    
    # Define colors for each land cover class
    # These colors are chosen to be intuitive and visually distinct
    colors = [
        (0/255, 0/255, 255/255, 1),         # Deep Blue for water 💧
        (34/255, 139/255, 34/255, 1),       # Forest Green for tree canopy 🌳
        (154/255, 205/255, 50/255, 1),      # Yellow Green for low vegetation 🌱
        (210/255, 180/255, 140/255, 1),     # Tan for barren land 🏔️
        (169/255, 169/255, 169/255, 1),     # Dark Gray for impervious (other) 🏢
        (105/255, 105/255, 105/255, 1),     # Dim Gray for impervious (road) 🛣️
        (255/255, 255/255, 255/255, 1),     # White for no data ⬜
    ]
    cmap = ListedColormap(colors)
    
    # 3 stacked panels (image / label / prediction) for each of up to 24 images,
    # 8 images per row -> 3 blocks x 3 panels = 9 rows of axes
    fig, axes = plt.subplots(9, 8, figsize=(16, 18))
    for ax in axes.flat:
        ax.axis("off")  # hide every axis; plot_data re-enables the ones it fills
    fig.suptitle('Land Cover Segmentation Results', fontsize=16, fontweight='bold')

    # Fill each triplet: image, ground truth, prediction
    plot_data(axes, images, row_offset=0, title="Image")
    plot_data(axes, labels, row_offset=1, title="Ground Truth", cmap=cmap, vmin=0, vmax=6)
    plot_data(axes, preds, row_offset=2, title="Prediction", cmap=cmap, vmin=0, vmax=6)

    # Add a legend explaining the color scheme
    add_legend(fig, cmap)

    # Leave room at the bottom for the legend and at the top for the title
    plt.tight_layout(rect=[0, 0.04, 1, 0.97])
    plt.show()
    
    print("✅ Visualization complete!")


def plot_data(axes, data, row_offset, title=None, cmap=None, vmin=None, vmax=None):
    """Fill one panel (image, label, or prediction) of each image's triplet in the grid."""
    for i, item in enumerate(data[:24]):  # first 24 images: 3 blocks of 8 columns
        row = row_offset + (i // 8) * 3
        col = i % 8
        ax = axes[row, col]

        ax.axis("on")  # undo any earlier axis("off") for panels that get data
        ax.imshow(item, cmap=cmap, vmin=vmin, vmax=vmax)

        # Hide ticks rather than calling axis("off"), which would also hide
        # the y-label used as the row title
        ax.set_xticks([])
        ax.set_yticks([])

        # Label the first panel of each row
        if title and col == 0:
            ax.set_ylabel(title, fontsize=12, fontweight='bold')


def add_legend(fig, cmap):
    """Add a color legend explaining the land cover classes."""
    import matplotlib.patches as mpatches

    # Plain-text names: emoji often render as empty boxes with matplotlib's default font
    class_names = [
        "Water",
        "Tree Canopy",
        "Low Vegetation",
        "Barren Land",
        "Impervious (Other)",
        "Impervious (Roads)",
        "No Data",
    ]

    # One legend patch per class, using the same colors as the colormap
    patches = [mpatches.Patch(color=cmap.colors[i], label=class_names[i])
               for i in range(len(class_names))]

    # Place the legend along the bottom edge of the figure
    fig.legend(handles=patches, loc='lower center', bbox_to_anchor=(0.5, 0.0),
               ncol=4, fontsize=10)
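The indexing rule in `plot_data` determines how many subplot rows the figure needs. The small sketch below reproduces that arithmetic so the layout is explicit; the helper names are hypothetical. Each image owns a vertical triplet of panels, so every block of 8 images takes 3 rows. With 24 images that comes to 9 rows (rows 0 to 8), and any extra rows in the figure stay empty.

```python
import math

def grid_shape(n_images, per_row=8, panels_per_image=3):
    """(rows, cols) needed when each image gets a stacked image/label/prediction triplet."""
    block_rows = math.ceil(n_images / per_row)
    return block_rows * panels_per_image, per_row

def panel_position(i, row_offset, per_row=8, panels_per_image=3):
    """(row, col) of panel `row_offset` (0=image, 1=label, 2=prediction) for image i."""
    return row_offset + (i // per_row) * panels_per_image, i % per_row

print(grid_shape(24))         # (9, 8)
print(panel_position(23, 2))  # (8, 7): prediction panel of the last image
```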

## 🚀 Run the Complete Inference Pipeline

Now let's put it all together! We'll load the model, prepare data, run inference, and visualize results:

In [None]:
# Load the trained model
model = get_model(CHESAPEAKE_CHECKPOINT_PATH, CLAY_CHECKPOINT_PATH, METADATA_PATH)

In [None]:
# Get validation data for testing
batch, metadata = get_data(
    TRAIN_CHIP_DIR,
    TRAIN_LABEL_DIR,
    VAL_CHIP_DIR,
    VAL_LABEL_DIR,
    METADATA_PATH,
    BATCH_SIZE,
    NUM_WORKERS,
    PLATFORM,
)

# Move every tensor in the batch to the device holding the model's weights
device = next(model.parameters()).device
batch = {k: v.to(device) for k, v in batch.items()}
print(f"📱 Using device: {device}")

In [None]:
# Run inference on the batch
outputs = run_prediction(model, batch)

In [None]:
# Post-process results for visualization
images, labels, preds = post_process(batch, outputs, metadata)

In [None]:
# Create the final visualization
plot_predictions(images, labels, preds)

print("\n🎉 Inference and visualization complete!")
print("\n🔍 What to Look For:")
print("   • How well does the model identify water bodies?")
print("   • Are forest areas correctly classified?") 
print("   • Does the model distinguish between different types of impervious surfaces?")
print("   • Where does the model struggle or make mistakes?")
print("\n💡 Next Steps:")
print("   • Try running on more batches to see consistency")
print("   • Consider additional training epochs for better performance")
print("   • Experiment with different learning rates or data augmentation")
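A quick quantitative companion to the questions above is to compare how much of the scene each class covers in the ground truth versus the prediction. A class that is consistently over- or under-represented points to systematic confusion. The minimal sketch below assumes `labels` and `preds` are the integer class maps from `post_process`, with the seven class IDs following the legend order. `class_shares` and `compare_shares` are hypothetical helpers.

```python
import numpy as np

CLASS_NAMES = ["Water", "Tree Canopy", "Low Vegetation", "Barren Land",
               "Impervious (Other)", "Impervious (Roads)", "No Data"]

def class_shares(class_map, num_classes=7):
    """Fraction of pixels assigned to each class ID in an integer class map."""
    counts = np.bincount(np.asarray(class_map).ravel(), minlength=num_classes)
    return counts / counts.sum()

def compare_shares(labels, preds, names=CLASS_NAMES):
    """Print ground-truth vs. predicted pixel share per class and return both arrays."""
    truth = class_shares(labels, len(names))
    pred = class_shares(preds, len(names))
    for name, t, p in zip(names, truth, pred):
        print(f"{name:<20} truth {t:6.1%}  pred {p:6.1%}  diff {p - t:+6.1%}")
    return truth, pred
```

To check consistency across batches, concatenate the class maps from several batches with `np.concatenate` before calling `compare_shares(labels, preds)`.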