# SageMaker Semantic Segmentation Exercise

This notebook demonstrates Amazon SageMaker's **Semantic Segmentation** algorithm for pixel-level image classification.

## What You'll Learn
1. How to prepare pixel-level annotations
2. How to configure and understand semantic segmentation hyperparameters
3. How to train a segmentation model
4. How to interpret and evaluate segmentation masks

## What is Semantic Segmentation?

Semantic Segmentation classifies **every pixel** in an image, assigning each pixel to a class. Unlike object detection, which provides bounding boxes, segmentation provides exact object shapes and boundaries.

**Key Difference from Other CV Tasks:**

| Task | Output | Granularity | Example Output |
|------|--------|-------------|----------------|
| Classification | Single label | Image-level | "cat" |
| Object Detection | Bounding boxes + labels | Object-level | Box around cat |
| **Semantic Segmentation** | Pixel mask | Pixel-level | Every cat pixel labeled |
| Instance Segmentation | Pixel mask per instance | Pixel + Instance | Cat1 pixels, Cat2 pixels |

**Note on Instance vs Semantic:**
- **Semantic**: All cats labeled as "cat" (same color)
- **Instance**: Each cat has unique ID (different colors per cat)

## Use Cases

| Industry | Application |
|----------|-------------|
| Autonomous Driving | Road/lane detection, pedestrian segmentation, obstacle identification |
| Medical Imaging | Tumor segmentation, organ identification, cell counting |
| Satellite/Aerial Imagery | Land use classification, building detection, flood mapping |
| Robotics | Scene understanding, navigation, manipulation |
| Fashion/Retail | Clothing segmentation, virtual try-on, background removal |
| Agriculture | Crop health analysis, weed detection, yield estimation |

---

## ⚠️ Important: Training Cost Warning

<div style="background-color: #090907ff; border: 1px solid #ffc107; border-radius: 5px; padding: 15px; margin: 10px 0;">

### GPU Requirements and Costs

**Semantic Segmentation training requires GPU instances.** This algorithm processes entire images at pixel-level resolution, making it computationally intensive.

| Instance Type | GPU | GPU Memory | On-Demand Price* |
|---------------|-----|------------|------------------|
| ml.p2.xlarge | 1x K80 | 12 GB | ~$1.26/hour |
| ml.p3.2xlarge | 1x V100 | 16 GB | ~$3.83/hour |
| ml.p3.8xlarge | 4x V100 | 64 GB | ~$14.69/hour |
| ml.g4dn.xlarge | 1x T4 | 16 GB | ~$0.74/hour |
| ml.g4dn.2xlarge | 1x T4 | 16 GB | ~$1.05/hour |
| ml.g5.xlarge | 1x A10G | 24 GB | ~$1.41/hour |

*Prices are approximate for us-west-2 and subject to change. Check [AWS SageMaker Pricing](https://aws.amazon.com/sagemaker/pricing/) for current rates.

### Cost Estimation Example

Training a typical semantic segmentation model:
- **30 epochs** with **5,000 images** (512x512): ~3-5 hours on ml.p3.2xlarge
- **Estimated cost**: $11.49 - $19.15 for training
- Semantic segmentation is **more expensive** than image classification due to pixel-level processing

### Cost-Saving Recommendations

1. **Use Spot Instances**: Can save up to 70% - add `use_spot_instances=True` and a `max_wait` limit (at least as long as `max_run`) to the Estimator
2. **Start with ml.g4dn.xlarge**: Most cost-effective GPU option (~$0.74/hour)
3. **Use transfer learning**: Set `use_pretrained_model=True` - critical for this task
4. **Reduce crop_size**: Start with `crop_size=240` instead of 480 for faster iteration
5. **Use FCN over DeepLab**: FCN is faster, use DeepLab only when you need maximum accuracy
6. **Start with fewer epochs**: Use 10 epochs to validate setup before full training
7. **Use ResNet-50 backbone**: Faster than ResNet-101, often sufficient accuracy

</div>
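
The cost estimate above is simple arithmetic: training hours multiplied by the hourly rate. The sketch below reproduces it and shows the best-case effect of the quoted 70% Spot discount. The helper name `estimate_training_cost` is hypothetical, and the prices are the approximate figures from the table, not live rates.

```python
def estimate_training_cost(hours_low, hours_high, hourly_rate, spot_discount=0.0):
    """Return (low, high) estimated cost in USD for a training run."""
    factor = 1.0 - spot_discount
    return (round(hours_low * hourly_rate * factor, 2),
            round(hours_high * hourly_rate * factor, 2))

# Approximate figures from the tables above: 3-5 hours on ml.p3.2xlarge
on_demand = estimate_training_cost(3, 5, 3.83)
# Best case only: actual Spot savings vary by region and over time
spot_best_case = estimate_training_cost(3, 5, 3.83, spot_discount=0.70)

print(f"On-Demand: ${on_demand[0]:.2f} - ${on_demand[1]:.2f}")
print(f"Spot (best case): ${spot_best_case[0]:.2f} - ${spot_best_case[1]:.2f}")
```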

## Step 1: Setup and Imports

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.image_uris import retrieve
from sagemaker.estimator import Estimator
import numpy as np
import json
import os
from datetime import datetime
from dotenv import load_dotenv
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from collections import defaultdict

# Load environment variables from .env file
load_dotenv()

# Configure AWS session from environment variables
aws_profile = os.getenv('AWS_PROFILE')
aws_region = os.getenv('AWS_REGION', 'us-west-2')
sagemaker_role = os.getenv('SAGEMAKER_ROLE_ARN')

if aws_profile:
    boto3.setup_default_session(profile_name=aws_profile, region_name=aws_region)
else:
    boto3.setup_default_session(region_name=aws_region)

# SageMaker session and role
sagemaker_session = sagemaker.Session()

if sagemaker_role:
    role = sagemaker_role
else:
    role = get_execution_role()

region = sagemaker_session.boto_region_name

print(f"AWS Profile: {aws_profile or 'default'}")
print(f"SageMaker Role: {role}")
print(f"Region: {region}")
print(f"SageMaker SDK Version: {sagemaker.__version__}")

In [None]:
# Configuration
BUCKET_NAME = sagemaker_session.default_bucket()
PREFIX = "semantic-segmentation"

print(f"S3 Bucket: {BUCKET_NAME}")
print(f"S3 Prefix: {PREFIX}")

## Step 2: Understand Data Format

SageMaker Semantic Segmentation requires a specific data format with paired images and annotation masks.

### Directory Structure
```
train/
  image001.jpg
  image002.jpg
train_annotation/
  image001.png  # Grayscale mask, pixel values = class IDs
  image002.png
validation/
  image001.jpg
validation_annotation/
  image001.png
```
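
Each of the four directories corresponds to one input channel with the same name. As a minimal sketch, the S3 locations can be derived from a bucket and prefix as follows; the helper `build_channel_uris` and the bucket name are hypothetical and used only for illustration.

```python
def build_channel_uris(bucket, prefix):
    """Map each channel name to its S3 prefix (one channel per directory)."""
    channels = ["train", "train_annotation", "validation", "validation_annotation"]
    return {name: f"s3://{bucket}/{prefix}/{name}/" for name in channels}

# "my-example-bucket" is a placeholder, not a real bucket
channel_uris = build_channel_uris("my-example-bucket", "semantic-segmentation")
for name, uri in channel_uris.items():
    print(f"{name:22s} -> {uri}")
```

A dictionary like this is the shape of input that an Estimator's `fit` call takes, typically with each URI wrapped in `sagemaker.inputs.TrainingInput`.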

### Critical Requirements

1. **Matching filenames**: `train/image001.jpg` must have `train_annotation/image001.png`
2. **Same dimensions**: Annotation mask must be EXACT same size as input image
3. **PNG format for annotations**: Must be PNG (lossless compression preserves class IDs)
4. **Grayscale annotations**: Single channel where pixel value = class ID
5. **Class ID range**: 0 to (num_classes - 1)
6. **No lossy re-encoding**: Never save or convert annotations through a lossy format such as JPEG; compression artifacts alter pixel values and corrupt class IDs

### Annotation Format Details

| Aspect | Requirement |
|--------|-------------|
| Format | PNG (8-bit grayscale) |
| Channels | 1 (grayscale) |
| Pixel values | 0 to num_classes-1 |
| Background | Typically class 0 |
| Dimensions | Must match input image exactly |
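
These requirements can be checked before uploading data. The sketch below validates a mask that has already been loaded into a NumPy array; the function name `validate_annotation` is an assumption for illustration, and loading the PNG (for example with Pillow) is left out to keep the example dependency-free.

```python
import numpy as np

def validate_annotation(mask, image_shape, num_classes):
    """Return a list of problems found in an annotation mask (empty list = valid)."""
    problems = []
    if mask.ndim != 2:
        problems.append(f"expected 1 channel (2D array), got shape {mask.shape}")
    if mask.dtype != np.uint8:
        problems.append(f"expected dtype uint8, got {mask.dtype}")
    if mask.shape[:2] != tuple(image_shape[:2]):
        problems.append(f"mask size {mask.shape[:2]} != image size {tuple(image_shape[:2])}")
    if mask.size and int(mask.max()) >= num_classes:
        problems.append(f"class ID {int(mask.max())} exceeds num_classes - 1 = {num_classes - 1}")
    return problems

good = np.zeros((256, 256), dtype=np.uint8)
good[100:, :] = 4
bad = np.full((128, 256, 3), 7, dtype=np.uint8)  # wrong size, 3 channels, ID out of range

good_problems = validate_annotation(good, (256, 256, 3), num_classes=5)
bad_problems = validate_annotation(bad, (256, 256, 3), num_classes=5)
print("good mask problems:", good_problems)
for problem in bad_problems:
    print("bad mask:", problem)
```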

### Annotation Example

For a scene with 5 classes:

```
Class 0: Background (pixel value = 0)
Class 1: Road (pixel value = 1)
Class 2: Building (pixel value = 2)
Class 3: Vegetation (pixel value = 3)
Class 4: Sky (pixel value = 4)
```

Opened in an ordinary image viewer, the annotation PNG looks almost entirely black, because the values 0-4 sit at the dark end of the 0-255 grayscale range. When visualized with a colormap, each class gets a distinct color.

**Common Mistake**: Using RGB colors in annotations. The algorithm expects grayscale class IDs, not RGB colors!
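
If a labeling tool exported color-coded masks, they can be converted to class IDs before training. The following is a sketch under stated assumptions: the `PALETTE` colors and the helper `rgb_to_class_ids` are hypothetical, and the real color-to-class mapping must come from your labeling tool.

```python
import numpy as np

# Hypothetical palette: RGB color used by a labeling tool -> class ID
PALETTE = {
    (0, 0, 0): 0,        # background
    (128, 64, 128): 1,   # road
    (70, 70, 70): 2,     # building
    (107, 142, 35): 3,   # vegetation
    (70, 130, 180): 4,   # sky
}

def rgb_to_class_ids(rgb_mask, palette):
    """Convert an (H, W, 3) color-coded mask to an (H, W) uint8 class-ID mask."""
    class_mask = np.zeros(rgb_mask.shape[:2], dtype=np.uint8)
    matched = np.zeros(rgb_mask.shape[:2], dtype=bool)
    for color, class_id in palette.items():
        hit = np.all(rgb_mask == np.array(color, dtype=rgb_mask.dtype), axis=-1)
        class_mask[hit] = class_id
        matched |= hit
    if not matched.all():
        raise ValueError(f"{(~matched).sum()} pixels use a color missing from the palette")
    return class_mask

rgb = np.zeros((4, 4, 3), dtype=np.uint8)
rgb[:2] = (70, 130, 180)        # top half: sky
rgb[2:, 1:3] = (128, 64, 128)   # bottom center: road
ids = rgb_to_class_ids(rgb, PALETTE)
print(ids)
```

Raising an error on unknown colors is a deliberate choice here: silently mapping them to background would hide labeling mistakes.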

## Step 3: Synthetic Data - Limitations and Purpose

<div style="background-color: #030405ff; border: 1px solid #0c5460; border-radius: 5px; padding: 15px; margin: 10px 0;">

### ⚠️ Important: Why We Can't Truly Simulate Semantic Segmentation

Like other deep learning vision tasks, Semantic Segmentation requires **real images** with actual visual features.

**Why synthetic data doesn't work for training:**
1. **Pixel-level features matter**: The model learns to recognize boundaries, textures, and context at every pixel
2. **Random shapes don't generalize**: A model trained on synthetic rectangles won't segment real buildings
3. **Context is crucial**: Real scenes have natural spatial relationships (sky above buildings, road at bottom)

**What we CAN demonstrate:**
- ✅ Annotation format (grayscale PNG with class IDs)
- ✅ Data structure and pairing
- ✅ Evaluation metrics (IoU, Dice, pixel accuracy)
- ✅ Output visualization and interpretation
- ✅ Hyperparameter configuration

**For actual training, you need:**
- Real images with pixel-level annotations
- Public datasets: Cityscapes, PASCAL VOC, ADE20K, COCO-Stuff
- SageMaker Ground Truth for custom labeling (supports semantic segmentation)

</div>

In [None]:
def generate_synthetic_segmentation_mask(height=256, width=256, num_classes=5, seed=None):
    """
    Generate a synthetic segmentation mask with random shapes.
    
    NOTE: This is for FORMAT DEMONSTRATION ONLY.
    Real training requires actual images with corresponding annotations.
    
    Args:
        height: Image height in pixels
        width: Image width in pixels
        num_classes: Number of segmentation classes (currently unused; the scene always uses the 5 classes in class_names)
        seed: Random seed for reproducibility
    
    Returns:
        mask: 2D numpy array with class IDs (dtype=uint8)
        class_names: List of class names
    """
    if seed is not None:
        np.random.seed(seed)
    
    class_names = ['background', 'road', 'building', 'vegetation', 'sky']
    
    # Start with background
    mask = np.zeros((height, width), dtype=np.uint8)
    
    # Add sky (top portion)
    sky_height = np.random.randint(height // 4, height // 2)
    mask[:sky_height, :] = 4  # sky class
    
    # Add road (bottom center)
    road_width = width // 3
    road_start = (width - road_width) // 2
    mask[height * 2 // 3:, road_start:road_start + road_width] = 1  # road class
    
    # Add random buildings
    num_buildings = np.random.randint(2, 5)
    for _ in range(num_buildings):
        bw = np.random.randint(30, 80)
        bh = np.random.randint(50, 120)
        bx = np.random.randint(0, width - bw)
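        # Clamp the upper bound: a tall building under a low sky line would otherwise give an empty range
        by = np.random.randint(sky_height, max(sky_height + 1, height - bh - 20))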
        mask[by:by + bh, bx:bx + bw] = 2  # building class
    
    # Add vegetation patches
    num_veg = np.random.randint(3, 7)
    for _ in range(num_veg):
        vw = np.random.randint(20, 60)
        vh = np.random.randint(20, 50)
        vx = np.random.randint(0, width - vw)
        vy = np.random.randint(sky_height, height - vh)
        mask[vy:vy + vh, vx:vx + vw] = 3  # vegetation class
    
    return mask, class_names

# Generate sample mask
sample_mask, class_names = generate_synthetic_segmentation_mask(seed=42)

print(f"Mask shape: {sample_mask.shape}")
print(f"Mask dtype: {sample_mask.dtype} (must be uint8 for PNG)")
print(f"Classes: {class_names}")
print(f"Unique values in mask: {np.unique(sample_mask)}")
print(f"Value range: {sample_mask.min()} to {sample_mask.max()}")

In [None]:
def visualize_segmentation(mask, class_names, title="Segmentation Mask"):
    """
    Visualize segmentation mask with colored classes.
    
    Shows both the raw annotation format and a human-readable colored version.
    """
    num_classes = len(class_names)
    colors = plt.cm.tab10(np.linspace(0, 1, num_classes))
    
    # Create colored mask for visualization
    colored_mask = np.zeros((*mask.shape, 3))
    for class_id in range(num_classes):
        colored_mask[mask == class_id] = colors[class_id][:3]
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Raw annotation (grayscale - what's actually saved)
    im1 = axes[0].imshow(mask, cmap='gray', vmin=0, vmax=num_classes-1)
    axes[0].set_title("Raw Annotation (Grayscale PNG)\nPixel values = Class IDs")
    axes[0].axis('off')
    plt.colorbar(im1, ax=axes[0], label='Class ID', ticks=range(num_classes))
    
    # Colored visualization (for human understanding)
    axes[1].imshow(colored_mask)
    axes[1].set_title(f"{title}\n(Colored for visualization)")
    axes[1].axis('off')
    
    # Legend
    patches = [mpatches.Patch(color=colors[i], label=f"{i}: {class_names[i]}") 
               for i in range(num_classes)]
    axes[1].legend(handles=patches, loc='upper right', fontsize=10)
    
    plt.tight_layout()
    plt.show()

visualize_segmentation(sample_mask, class_names, "Synthetic Urban Scene Segmentation")

In [None]:
def analyze_class_distribution(mask, class_names):
    """
    Analyze pixel distribution across classes.
    
    Important for understanding class imbalance - common issue in segmentation.
    """
    total_pixels = mask.size
    class_pixels = {}
    
    for i, name in enumerate(class_names):
        count = (mask == i).sum()
        class_pixels[name] = {
            'count': count,
            'percentage': count / total_pixels * 100
        }
    
    return class_pixels

# Analyze distribution
distribution = analyze_class_distribution(sample_mask, class_names)

print("Class Distribution (Pixel Count):")
print("=" * 50)
for name, data in distribution.items():
    bar = '#' * int(data['percentage'] / 2)
    print(f"{name:12s}: {data['count']:6d} pixels ({data['percentage']:5.1f}%) {bar}")

# Check for class imbalance
percentages = [d['percentage'] for d in distribution.values()]
imbalance_ratio = max(percentages) / max(min(percentages), 0.1)
print(f"\nClass imbalance ratio: {imbalance_ratio:.1f}x")
if imbalance_ratio > 10:
    print("⚠️  Significant class imbalance! Consider using weighted loss or class balancing.")

In [None]:
# Visualize class distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = plt.cm.tab10(np.linspace(0, 1, len(class_names)))

# Bar chart
percentages = [distribution[name]['percentage'] for name in class_names]
bars = axes[0].bar(class_names, percentages, color=colors)
axes[0].set_ylabel('Percentage of Image')
axes[0].set_title('Class Distribution (Pixel Percentage)')
axes[0].tick_params(axis='x', rotation=45)

# Add percentage labels
for bar, pct in zip(bars, percentages):
    axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                f'{pct:.1f}%', ha='center', va='bottom', fontsize=10)

# Pie chart
axes[1].pie(percentages, labels=class_names, colors=colors, autopct='%1.1f%%',
           startangle=90)
axes[1].set_title('Class Distribution (Pie Chart)')

plt.tight_layout()
plt.show()

---

## Step 4: Training Configuration and Hyperparameters

### Understanding Semantic Segmentation Hyperparameters

SageMaker's Semantic Segmentation algorithm has specific hyperparameters for architecture selection and training configuration.

### Core Required Parameters

**num_classes** (Required)
- Total number of segmentation classes INCLUDING background
- Must match the maximum class ID + 1 in your annotations
- Example: Classes 0-4 → `num_classes=5`

**num_training_samples** (Required)
- Total number of training images
- Used for learning rate scheduling
- Must match your actual dataset size

### Architecture Parameters

**algorithm**
- The segmentation architecture to use
- Options: `fcn`, `psp`, `deeplab`

| Algorithm | Full Name | Description | Best For |
|-----------|-----------|-------------|----------|
| `fcn` | Fully Convolutional Network | Simple encoder-decoder | Fast inference, basic tasks |
| `psp` | Pyramid Scene Parsing | Multi-scale pooling | Scene parsing, global context |
| `deeplab` | DeepLabV3 | Atrous/dilated convolutions | **Maximum accuracy**, fine boundaries |

- Default: `fcn`
- Recommendation: Use `deeplab` for best accuracy, `fcn` for speed

**backbone**
- The encoder network that extracts features
- Options: `resnet-50`, `resnet-101`

| Backbone | Layers | Speed | Accuracy | Memory |
|----------|--------|-------|----------|--------|
| `resnet-50` | 50 | Faster | Good | ~4GB |
| `resnet-101` | 101 | Slower | Better | ~6GB |

- Default: `resnet-50`
- Recommendation: Start with `resnet-50`, upgrade if needed

**use_pretrained_model**
- Whether to initialize backbone with ImageNet pretrained weights
- `True`: **Highly recommended** - critical for segmentation
- `False`: Train from scratch (needs much more data)
- Default: `True`
- Note: Pretrained weights help the model understand basic visual features

### Training Parameters

**epochs**
- Number of complete passes through training data
- Segmentation often needs more epochs than classification
- Typical range: 30-100 depending on dataset size
- Default: `10`

**mini_batch_size**
- Number of images per batch
- **Memory intensive**: Segmentation processes full images
- Reduce if you get OOM (Out of Memory) errors
- Typical range: 4-16 depending on crop_size and GPU
- Default: `16`
- Rule of thumb: crop_size=480 → batch_size 4-8; crop_size=240 → batch_size 8-16

**learning_rate**
- Initial learning rate
- Lower for fine-tuning pretrained models
- Typical range: 0.0001 - 0.01
- Default: `0.0001`

**lr_scheduler**
- Learning rate decay strategy
- Options: `poly`, `step`, `cosine`

| Scheduler | Description | Formula |
|-----------|-------------|--------|
| `poly` | Polynomial decay | LR × (1 - iter/max_iter)^power |
| `step` | Step decay at specific epochs | LR × factor at each step |
| `cosine` | Cosine annealing | Smooth decay following cosine |

- Default: `poly` (recommended for segmentation)

**lr_scheduler_step**
- For `step` scheduler: epochs at which to reduce LR
- Format: comma-separated (e.g., `"10,20,30"`)
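
To make the three schedules concrete, here is a minimal sketch of each decay rule evaluated per epoch. The values `power=0.9` and `factor=0.1` are assumptions chosen for illustration, and the built-in algorithm's internal implementation may differ in detail (for example, by updating per iteration rather than per epoch).

```python
import math

def poly_lr(base_lr, step, max_steps, power=0.9):
    """Polynomial decay: LR x (1 - iter/max_iter)^power."""
    return base_lr * (1 - step / max_steps) ** power

def step_lr(base_lr, epoch, step_epochs, factor=0.1):
    """Multiply the LR by `factor` once for every step epoch already reached."""
    passed = sum(1 for e in step_epochs if epoch >= e)
    return base_lr * factor ** passed

def cosine_lr(base_lr, step, max_steps):
    """Cosine annealing from base_lr down to 0."""
    return 0.5 * base_lr * (1 + math.cos(math.pi * step / max_steps))

# Parse the comma-separated format used by lr_scheduler_step
step_epochs = [int(e) for e in "10,20,30".split(",")]

for epoch in (0, 10, 15, 20, 29):
    print(f"epoch {epoch:2d}: poly={poly_lr(0.001, epoch, 30):.6f} "
          f"step={step_lr(0.001, epoch, step_epochs):.6f} "
          f"cosine={cosine_lr(0.001, epoch, 30):.6f}")
```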

**crop_size**
- Size of random crops during training
- Images are randomly cropped to this size for training
- Larger = more context, but slower and more memory
- Must be smaller than smallest image dimension
- Typical values: 240, 320, 480, 512
- Default: `240`
- **Important**: Larger crop_size requires reducing batch_size
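
The trade-off between `crop_size` and `mini_batch_size` follows from the number of pixels processed per step. The sketch below uses pixels per batch as a rough proxy for activation memory; actual memory use also depends on the algorithm and backbone, and both helper names are hypothetical.

```python
def pixels_per_batch(crop_size, batch_size):
    """Pixels processed per training step (a rough proxy for activation memory)."""
    return crop_size * crop_size * batch_size

def batch_size_for_same_budget(ref_crop, ref_batch, new_crop):
    """Largest batch size that keeps pixels-per-batch within the reference budget."""
    budget = pixels_per_batch(ref_crop, ref_batch)
    return max(1, budget // (new_crop * new_crop))

# Reference point: crop_size=240 with mini_batch_size=16
print(pixels_per_batch(240, 16))
# Doubling crop_size quadruples the pixel count, so the batch shrinks to a quarter
print(batch_size_for_same_budget(240, 16, 480))
print(batch_size_for_same_budget(240, 16, 512))
```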

### Optimizer Parameters

**optimizer**
- Optimization algorithm
- Options: `sgd`, `adam`, `adagrad`, `nag`, `rmsprop`
- Default: `sgd` (recommended with momentum)

**momentum**
- Momentum for SGD optimizer
- Typical value: 0.9
- Default: `0.9`

**weight_decay**
- L2 regularization
- Helps prevent overfitting
- Typical range: 0.0001 - 0.001
- Default: `0.0001`

### Multi-GPU and Validation Parameters

**syncbn**
- Synchronized Batch Normalization: batch-norm statistics are computed over all samples across GPUs
- Useful for multi-GPU training with small batch sizes
- Default: `False`

**validation_mini_batch_size**
- Batch size used for validation
- Set to `1` to score validation on whole images without cropping
- Default: `16`

In [None]:
# Get Semantic Segmentation container image
semantic_segmentation_image = retrieve(
    framework='semantic-segmentation',
    region=region,
    version='1'
)

print(f"Semantic Segmentation Image URI: {semantic_segmentation_image}")

In [None]:
# Complete hyperparameter configuration with explanations
hyperparameters = {
    # === REQUIRED PARAMETERS ===
    "num_classes": 5,                    # Number of segmentation classes (including background)
    "num_training_samples": 1000,        # Total training images
    
    # === ARCHITECTURE ===
    "algorithm": "deeplab",              # fcn, psp, or deeplab (deeplab for best accuracy)
    "backbone": "resnet-50",             # Feature extractor backbone
    "use_pretrained_model": "True",      # Transfer learning (critical for segmentation)
    
    # === TRAINING PARAMETERS ===
    "epochs": 30,                        # Training epochs
    "mini_batch_size": 8,                # Batch size (reduce if OOM)
    "learning_rate": 0.001,              # Initial learning rate
    "lr_scheduler": "poly",              # Learning rate scheduler
    
    # === OPTIMIZER ===
    "optimizer": "sgd",                  # Optimizer algorithm
    "momentum": 0.9,                     # SGD momentum
    "weight_decay": 0.0001,              # L2 regularization
    
    # === CROP SIZE ===
    "crop_size": 480,                    # Training crop size (pixels)
}

print("Semantic Segmentation Hyperparameters:")
print("=" * 55)
for key, value in hyperparameters.items():
    print(f"  {key}: {value}")

In [None]:
# Example Estimator Configuration
# NOTE: Do NOT run training without actual image data!

print("""
═══════════════════════════════════════════════════════════════════════════════
                    EXAMPLE ESTIMATOR CONFIGURATION
═══════════════════════════════════════════════════════════════════════════════

⚠️  WARNING: Running this training job will incur GPU costs!
    Semantic segmentation is compute-intensive.
    Estimated cost: $10-30 depending on dataset and epochs.

# Standard training (On-Demand)
semantic_segmentation_estimator = Estimator(
    image_uri=semantic_segmentation_image,
    role=role,
    instance_count=1,
    instance_type='ml.p3.2xlarge',  # GPU required!
    output_path=f's3://{BUCKET_NAME}/{PREFIX}/output',
    sagemaker_session=sagemaker_session,
    base_job_name='semantic-segmentation',
    max_run=3600 * 6,  # 6 hour max runtime
)

# Cost-saving alternative with Spot Instances (up to 70% savings)
semantic_segmentation_estimator_spot = Estimator(
    image_uri=semantic_segmentation_image,
    role=role,
    instance_count=1,
    instance_type='ml.g4dn.xlarge',  # Most cost-effective GPU
    output_path=f's3://{BUCKET_NAME}/{PREFIX}/output',
    sagemaker_session=sagemaker_session,
    base_job_name='semantic-segmentation-spot',
    use_spot_instances=True,         # Enable Spot pricing
    max_wait=3600 * 8,               # Max time to wait for spot capacity
    max_run=3600 * 6,                # Max training time
)

# Set hyperparameters
semantic_segmentation_estimator.set_hyperparameters(**hyperparameters)

# Data channels configuration
# train: s3://bucket/prefix/train/  (JPG images)
# train_annotation: s3://bucket/prefix/train_annotation/  (PNG masks)
# validation: s3://bucket/prefix/validation/  (JPG images)
# validation_annotation: s3://bucket/prefix/validation_annotation/  (PNG masks)

""")

---

## Step 5: Understanding Model Output

The model outputs a **grayscale PNG image** where each pixel value is the predicted class ID.

**Output Format:**
- Same dimensions as input image
- 8-bit grayscale PNG
- Pixel values range from 0 to (num_classes - 1)
- Each pixel independently classified
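
One way to picture "each pixel independently classified": a segmentation network typically produces one score per class for every pixel, and the mask keeps the highest-scoring class at each position. The sketch below illustrates that step with random scores; it is not the SageMaker container's inference code, and the array shapes are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, height, width = 5, 4, 6

# Hypothetical per-pixel class scores, shape (num_classes, H, W)
scores = rng.random((num_classes, height, width))

# Each pixel is assigned the class with the highest score
mask = scores.argmax(axis=0).astype(np.uint8)

print("mask shape:", mask.shape)
print("value range:", mask.min(), "to", mask.max())
```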

In [None]:
def simulate_prediction(ground_truth, class_names, noise_level=0.05):
    """
    Simulate model prediction with controlled noise.
    
    In reality, errors tend to occur at boundaries between classes.
    
    Args:
        ground_truth: Ground truth mask
        class_names: List of class names
        noise_level: Fraction of pixels to randomly misclassify
    
    Returns:
        Simulated prediction mask
    """
    prediction = ground_truth.copy()
    
    # Add random misclassifications
    noise_mask = np.random.random(ground_truth.shape) < noise_level
    random_classes = np.random.randint(0, len(class_names), ground_truth.shape)
    prediction[noise_mask] = random_classes[noise_mask]
    
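    # Add boundary noise (more realistic - errors at edges)
    from scipy import ndimage
    gt_float = ground_truth.astype(float)
    # Sobel along both axes so horizontal and vertical class boundaries are both detected
    edges = (ndimage.sobel(gt_float, axis=0) != 0) | (ndimage.sobel(gt_float, axis=1) != 0)
    edge_noise = np.random.random(ground_truth.shape) < 0.15  # Higher error at edges
    boundary_noise = edges & edge_noise
    prediction[boundary_noise] = random_classes[boundary_noise]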
    
    return prediction

# Generate prediction
try:
    prediction = simulate_prediction(sample_mask, class_names)
except ImportError:
    # Fallback if scipy not available
    prediction = sample_mask.copy()
    noise_mask = np.random.random(sample_mask.shape) < 0.05
    random_classes = np.random.randint(0, len(class_names), sample_mask.shape)
    prediction[noise_mask] = random_classes[noise_mask]

print(f"Ground truth shape: {sample_mask.shape}")
print(f"Prediction shape: {prediction.shape}")
print(f"Matching pixels: {(sample_mask == prediction).sum()} / {sample_mask.size}")

In [None]:
def visualize_prediction_comparison(ground_truth, prediction, class_names):
    """
    Compare ground truth and prediction with error visualization.
    """
    num_classes = len(class_names)
    colors = plt.cm.tab10(np.linspace(0, 1, num_classes))
    
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    
    # Ground truth (colored)
    colored_gt = np.zeros((*ground_truth.shape, 3))
    for class_id in range(num_classes):
        colored_gt[ground_truth == class_id] = colors[class_id][:3]
    axes[0].imshow(colored_gt)
    axes[0].set_title("Ground Truth")
    axes[0].axis('off')
    
    # Prediction (colored)
    colored_pred = np.zeros((*prediction.shape, 3))
    for class_id in range(num_classes):
        colored_pred[prediction == class_id] = colors[class_id][:3]
    axes[1].imshow(colored_pred)
    axes[1].set_title("Prediction")
    axes[1].axis('off')
    
    # Error map (binary)
    error_map = (ground_truth != prediction).astype(float)
    axes[2].imshow(error_map, cmap='Reds')
    error_count = error_map.sum()
    error_pct = error_count / error_map.size * 100
    axes[2].set_title(f"Errors (Red)\n{error_count:.0f} pixels ({error_pct:.1f}%)")
    axes[2].axis('off')
    
    # Correct pixels (green) vs errors (red)
    comparison = np.zeros((*ground_truth.shape, 3))
    comparison[ground_truth == prediction] = [0, 0.7, 0]  # Green for correct
    comparison[ground_truth != prediction] = [0.9, 0, 0]  # Red for errors
    axes[3].imshow(comparison)
    axes[3].set_title(f"Accuracy Map\nGreen=Correct, Red=Error")
    axes[3].axis('off')
    
    # Legend
    patches = [mpatches.Patch(color=colors[i], label=class_names[i]) 
               for i in range(num_classes)]
    fig.legend(handles=patches, loc='lower center', ncol=num_classes, fontsize=10)
    
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15)
    plt.show()

visualize_prediction_comparison(sample_mask, prediction, class_names)

---

## Step 6: Evaluation Metrics Deep Dive

Semantic segmentation uses pixel-level metrics. The primary metric is **Mean Intersection over Union (mIoU)**.

### Key Metrics

| Metric | Description | Formula | Range |
|--------|-------------|---------|-------|
| **Pixel Accuracy** | % pixels correctly classified | Correct / Total | 0-1 |
| **Mean Accuracy** | Average per-class accuracy | Mean(class accuracies) | 0-1 |
| **IoU (per class)** | Intersection / Union | TP / (TP+FP+FN) | 0-1 |
| **mIoU** | Mean IoU across classes | Mean(class IoUs) | 0-1 |
| **Dice Score** | Similar to IoU, used in medical | 2×TP / (2×TP+FP+FN) | 0-1 |
| **Frequency Weighted IoU** | IoU weighted by class frequency | Weighted mean | 0-1 |

### Intersection over Union (IoU) Explained

IoU measures the overlap between prediction and ground truth for a single class:

```
IoU = Area of Overlap / Area of Union
    = True Positives / (True Positives + False Positives + False Negatives)
```
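
As a worked example of the formula, the sketch below computes IoU and Dice from pixel counts for a single class. The counts are made up for illustration. It also checks the identity Dice = 2·IoU / (1 + IoU), which follows from the two definitions.

```python
def iou_from_counts(tp, fp, fn):
    """IoU = TP / (TP + FP + FN)."""
    return tp / (tp + fp + fn)

def dice_from_counts(tp, fp, fn):
    """Dice = 2*TP / (2*TP + FP + FN)."""
    return 2 * tp / (2 * tp + fp + fn)

tp, fp, fn = 80, 10, 10  # hypothetical pixel counts for one class
iou = iou_from_counts(tp, fp, fn)
dice = dice_from_counts(tp, fp, fn)

print(f"IoU  = {iou:.3f}")
print(f"Dice = {dice:.3f}")
# Dice is always >= IoU and the two are monotonically related
print("identity holds:", abs(dice - 2 * iou / (1 + iou)) < 1e-12)
```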

**Interpretation:**
- IoU = 1.0: Perfect segmentation
- IoU = 0.5: Acceptable for many applications
- IoU = 0.0: Complete miss (no overlap)

**Why IoU over Pixel Accuracy?**
- Pixel accuracy can be misleading with class imbalance
- Example: 90% background → predicting all background gives 90% accuracy but 0 IoU for other classes
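
The imbalance example can be reproduced in a few lines. This sketch builds a hypothetical 10x10 mask that is 90% background and scores a degenerate prediction that labels everything as background.

```python
import numpy as np

# Hypothetical ground truth: 90% background (0), 10% object (1)
ground_truth = np.zeros((10, 10), dtype=np.uint8)
ground_truth[:, 9] = 1

# A degenerate "model" that predicts background everywhere
prediction = np.zeros_like(ground_truth)

def class_iou(gt, pred, class_id):
    """IoU for one class; assumes the class appears in gt or pred."""
    inter = np.logical_and(gt == class_id, pred == class_id).sum()
    union = np.logical_or(gt == class_id, pred == class_id).sum()
    return inter / union

pixel_accuracy = (ground_truth == prediction).mean()
iou_background = class_iou(ground_truth, prediction, 0)
iou_object = class_iou(ground_truth, prediction, 1)
mean_iou = (iou_background + iou_object) / 2

print(f"Pixel accuracy: {pixel_accuracy:.2f}")
print(f"IoU background: {iou_background:.2f}")
print(f"IoU object:     {iou_object:.2f}")
print(f"mIoU:           {mean_iou:.2f}")
```

Pixel accuracy looks high, while mIoU exposes that the object class was missed entirely.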

In [None]:
def calculate_iou(ground_truth, prediction, class_id):
    """
    Calculate Intersection over Union for a single class.
    
    Args:
        ground_truth: Ground truth mask
        prediction: Predicted mask
        class_id: Class ID to calculate IoU for
    
    Returns:
        IoU score (0-1)
    """
    gt_mask = ground_truth == class_id
    pred_mask = prediction == class_id
    
    intersection = (gt_mask & pred_mask).sum()
    union = (gt_mask | pred_mask).sum()
    
    if union == 0:
        return float('nan')  # Class not present in either
    
    return intersection / union


def calculate_dice(ground_truth, prediction, class_id):
    """
    Calculate Dice coefficient (F1 score) for a single class.
    
    Dice = 2 * |A ∩ B| / (|A| + |B|)
    
    Commonly used in medical image segmentation.
    """
    gt_mask = ground_truth == class_id
    pred_mask = prediction == class_id
    
    intersection = (gt_mask & pred_mask).sum()
    total = gt_mask.sum() + pred_mask.sum()
    
    if total == 0:
        return float('nan')
    
    return 2 * intersection / total


def calculate_segmentation_metrics(ground_truth, prediction, num_classes, class_names):
    """
    Calculate comprehensive segmentation metrics.
    
    Returns:
        Dictionary with all metrics
    """
    # Pixel accuracy
    correct = (ground_truth == prediction).sum()
    total = ground_truth.size
    pixel_accuracy = correct / total
    
    # Per-class metrics
    iou_per_class = []
    dice_per_class = []
    accuracy_per_class = []
    class_pixels = []
    
    for class_id in range(num_classes):
        iou = calculate_iou(ground_truth, prediction, class_id)
        dice = calculate_dice(ground_truth, prediction, class_id)
        
        # Per-class accuracy (recall)
        gt_mask = ground_truth == class_id
        if gt_mask.sum() > 0:
            class_acc = ((prediction == class_id) & gt_mask).sum() / gt_mask.sum()
        else:
            class_acc = float('nan')
        
        iou_per_class.append(iou)
        dice_per_class.append(dice)
        accuracy_per_class.append(class_acc)
        class_pixels.append(gt_mask.sum())
    
    # Mean metrics (excluding NaN classes)
    valid_ious = [x for x in iou_per_class if not np.isnan(x)]
    valid_dices = [x for x in dice_per_class if not np.isnan(x)]
    valid_accs = [x for x in accuracy_per_class if not np.isnan(x)]
    
    mean_iou = np.mean(valid_ious) if valid_ious else 0
    mean_dice = np.mean(valid_dices) if valid_dices else 0
    mean_accuracy = np.mean(valid_accs) if valid_accs else 0
    
    # Frequency weighted IoU
    total_pixels = sum(class_pixels)
    freq_weighted_iou = sum(
        (pixels / total_pixels) * iou 
        for pixels, iou in zip(class_pixels, iou_per_class) 
        if not np.isnan(iou) and pixels > 0
    )
    
    return {
        'pixel_accuracy': pixel_accuracy,
        'mean_accuracy': mean_accuracy,
        'mean_iou': mean_iou,
        'mean_dice': mean_dice,
        'freq_weighted_iou': freq_weighted_iou,
        'iou_per_class': dict(zip(class_names, iou_per_class)),
        'dice_per_class': dict(zip(class_names, dice_per_class)),
        'accuracy_per_class': dict(zip(class_names, accuracy_per_class)),
    }

# Calculate metrics
metrics = calculate_segmentation_metrics(sample_mask, prediction, len(class_names), class_names)

print("Segmentation Evaluation Metrics:")
print("=" * 55)
print(f"\n  Overall Metrics:")
print(f"    Pixel Accuracy:      {metrics['pixel_accuracy']:.4f} ({metrics['pixel_accuracy']*100:.1f}%)")
print(f"    Mean Accuracy:       {metrics['mean_accuracy']:.4f}")
print(f"    Mean IoU (mIoU):     {metrics['mean_iou']:.4f}")
print(f"    Mean Dice:           {metrics['mean_dice']:.4f}")
print(f"    Freq. Weighted IoU:  {metrics['freq_weighted_iou']:.4f}")

In [None]:
# Display per-class metrics in a formatted table
print("\nPer-Class Metrics:")
print("=" * 65)
print(f"{'Class':>15s} {'IoU':>10s} {'Dice':>10s} {'Accuracy':>10s}")
print("-" * 65)

for name in class_names:
    iou = metrics['iou_per_class'][name]
    dice = metrics['dice_per_class'][name]
    acc = metrics['accuracy_per_class'][name]
    
    iou_str = f"{iou:.4f}" if not np.isnan(iou) else "N/A"
    dice_str = f"{dice:.4f}" if not np.isnan(dice) else "N/A"
    acc_str = f"{acc:.4f}" if not np.isnan(acc) else "N/A"
    
    print(f"{name:>15s} {iou_str:>10s} {dice_str:>10s} {acc_str:>10s}")

print("-" * 65)
print(f"{'Mean':>15s} {metrics['mean_iou']:>10.4f} {metrics['mean_dice']:>10.4f} {metrics['mean_accuracy']:>10.4f}")

In [None]:
# Visualize per-class IoU
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = plt.cm.tab10(np.linspace(0, 1, len(class_names)))

# IoU bar chart
ious = [metrics['iou_per_class'][name] for name in class_names]
ious_clean = [x if not np.isnan(x) else 0 for x in ious]
bars = axes[0].barh(class_names, ious_clean, color=colors)
axes[0].set_xlabel('IoU Score')
axes[0].set_title(f'Per-Class IoU (mIoU = {metrics["mean_iou"]:.4f})')
axes[0].set_xlim(0, 1)
axes[0].axvline(x=metrics['mean_iou'], color='red', linestyle='--', linewidth=2, label=f'mIoU')
axes[0].legend()

# Add value labels
for bar, iou in zip(bars, ious):
    label = f'{iou:.3f}' if not np.isnan(iou) else 'N/A'
    axes[0].text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2,
                label, va='center')

# Dice bar chart
dices = [metrics['dice_per_class'][name] for name in class_names]
dices_clean = [x if not np.isnan(x) else 0 for x in dices]
bars2 = axes[1].barh(class_names, dices_clean, color=colors)
axes[1].set_xlabel('Dice Score')
axes[1].set_title(f'Per-Class Dice (Mean = {metrics["mean_dice"]:.4f})')
axes[1].set_xlim(0, 1)
axes[1].axvline(x=metrics['mean_dice'], color='red', linestyle='--', linewidth=2, label='Mean Dice')
axes[1].legend()

for bar, dice in zip(bars2, dices):
    label = f'{dice:.3f}' if not np.isnan(dice) else 'N/A'
    axes[1].text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2,
                label, va='center')

plt.tight_layout()
plt.show()

### Confusion Matrix for Segmentation

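A pixel-level confusion matrix counts, for each ground-truth class, how many of its pixels were assigned to each predicted class. Large off-diagonal entries show which classes get confused with each other.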

In [None]:
def compute_confusion_matrix(ground_truth, prediction, num_classes):
    """
    Compute pixel-level confusion matrix.
    
    Row = Ground Truth class
    Column = Predicted class
    Values = Number of pixels
    """
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    
    for gt_class in range(num_classes):
        for pred_class in range(num_classes):
            cm[gt_class, pred_class] = ((ground_truth == gt_class) & 
                                        (prediction == pred_class)).sum()
    
    return cm


def plot_confusion_matrix(cm, class_names, normalize=False, title="Confusion Matrix"):
    """
    Plot confusion matrix as heatmap.
    """
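    if normalize:
        # Normalize by row (ground truth). Passing `out` keeps rows with no
        # ground-truth pixels at 0 instead of leaving them uninitialized.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm_display = np.divide(cm, row_sums,
                               out=np.zeros(cm.shape, dtype=float),
                               where=row_sums != 0)
    else:
        cm_display = cm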
    
    fig, ax = plt.subplots(figsize=(10, 8))
    
    im = ax.imshow(cm_display, interpolation='nearest', cmap='Blues')
    ax.figure.colorbar(im, ax=ax)
    
    ax.set(xticks=np.arange(len(class_names)),
           yticks=np.arange(len(class_names)),
           xticklabels=class_names,
           yticklabels=class_names,
           xlabel='Predicted Class',
           ylabel='Ground Truth Class',
           title=title)
    
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
    
    # Add text annotations
    thresh = cm_display.max() / 2
    for i in range(len(class_names)):
        for j in range(len(class_names)):
            value = cm_display[i, j]
            if normalize:
                text = f'{value:.2f}'
            else:
                if value > 1000:
                    text = f'{value/1000:.1f}K'
                else:
                    text = f'{value}'
            ax.text(j, i, text,
                   ha='center', va='center',
                   color='white' if cm_display[i, j] > thresh else 'black',
                   fontsize=10)
    
    plt.tight_layout()
    plt.show()

# Compute and plot confusion matrix
cm = compute_confusion_matrix(sample_mask, prediction, len(class_names))

print("Pixel-Level Confusion Matrix:")
plot_confusion_matrix(cm, class_names, normalize=False, 
                     title="Confusion Matrix (Pixel Counts)")

In [None]:
# Row-normalized confusion matrix: each non-empty row sums to 1 and the diagonal is per-class recall
print("Normalized Confusion Matrix (Per-class Recall):")
plot_confusion_matrix(cm, class_names, normalize=True,
                     title="Confusion Matrix (Normalized by Ground Truth)")
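
The same pixel counts also yield per-class precision and IoU, not only recall. The cell below is a minimal sketch on a toy 3-class example. The helper names `confusion_matrix_fast` and `metrics_from_confusion_matrix` are hypothetical, and the `np.bincount` construction is an alternative to the nested loop used earlier. It assumes every mask value lies in `[0, num_classes)`.

```python
import numpy as np

def confusion_matrix_fast(ground_truth, prediction, num_classes):
    """Vectorized pixel-level confusion matrix (rows = truth, cols = prediction).

    Assumes all values are valid class IDs in [0, num_classes).
    """
    gt = np.asarray(ground_truth).ravel().astype(np.int64)
    pred = np.asarray(prediction).ravel().astype(np.int64)
    # Encode each (truth, prediction) pair as one index, then count occurrences.
    pair_index = gt * num_classes + pred
    counts = np.bincount(pair_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def metrics_from_confusion_matrix(cm):
    """Per-class recall, precision, and IoU; NaN where a class has no pixels."""
    cm = cm.astype(float)
    tp = np.diag(cm)
    gt_total = cm.sum(axis=1)    # pixels that truly belong to each class
    pred_total = cm.sum(axis=0)  # pixels predicted as each class
    union = gt_total + pred_total - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        recall = np.where(gt_total > 0, tp / gt_total, np.nan)
        precision = np.where(pred_total > 0, tp / pred_total, np.nan)
        iou = np.where(union > 0, tp / union, np.nan)
    return recall, precision, iou

# Toy example: 2x3 masks with classes 0, 1, 2
toy_truth = np.array([[0, 0, 1], [1, 2, 2]])
toy_pred = np.array([[0, 1, 1], [1, 2, 0]])
toy_cm = confusion_matrix_fast(toy_truth, toy_pred, num_classes=3)
toy_recall, toy_precision, toy_iou = metrics_from_confusion_matrix(toy_cm)

print(toy_cm)
print("Recall per class:   ", np.round(toy_recall, 3))
print("Precision per class:", np.round(toy_precision, 3))
print("IoU per class:      ", np.round(toy_iou, 3))
```

Reading IoU from the matrix makes the relationship explicit: a class's IoU is its diagonal entry divided by the sum of its row and column minus that diagonal entry.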

---

## Step 7: CloudWatch Training Metrics

During training, SageMaker Semantic Segmentation emits these metrics to CloudWatch:

| Metric | Description | Desired Trend |
|--------|-------------|-------------|
| `train:loss` | Training loss | Decreasing |
| `validation:loss` | Validation loss | Decreasing |
| `train:mIoU` | Training mean IoU | Increasing |
| `validation:mIOU` | Validation mean IoU | **Primary metric** - Increasing |
| `train:pixacc` | Training pixel accuracy | Increasing |
| `validation:pixel_accuracy` | Validation pixel accuracy | Increasing |
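
Once a job has finished, the last reported value of each metric is also available in the `FinalMetricDataList` field of the `DescribeTrainingJob` response. The sketch below pulls those values into a dictionary. The helper name `final_metrics_from_description` and the job name are hypothetical, and the response shown is a hand-written stand-in with placeholder values so the cell runs without AWS credentials.

```python
def final_metrics_from_description(description):
    """Map metric name -> last reported value from a DescribeTrainingJob response."""
    return {
        entry["MetricName"]: entry["Value"]
        for entry in description.get("FinalMetricDataList", [])
    }

# With AWS credentials and a finished job, the response would come from:
#   import boto3
#   description = boto3.client("sagemaker").describe_training_job(
#       TrainingJobName="my-segmentation-job")  # hypothetical job name
#
# Stand-in response with placeholder values (not real training results):
example_description = {
    "TrainingJobName": "my-segmentation-job",
    "FinalMetricDataList": [
        {"MetricName": "train:loss", "Value": 0.42},
        {"MetricName": "validation:loss", "Value": 0.55},
    ],
}

final_metrics = final_metrics_from_description(example_description)
for metric_name, value in sorted(final_metrics.items()):
    print(f"{metric_name:>20s}: {value:.4f}")
```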

### What to Watch For

**Healthy Training:**
- Loss decreasing steadily
- mIoU increasing on both training and validation
- Small gap between training and validation metrics

**Overfitting Signs:**
- Training mIoU keeps improving while validation mIoU plateaus or decreases
- Large gap between training and validation metrics

**Underfitting Signs:**
- Both training and validation metrics are poor and improve very slowly
- Consider: a larger model, longer training, or a retuned learning rate
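
These signs can be checked programmatically once the per-epoch mIoU values are available as lists. The following is a rough sketch, not a standard diagnostic: the function name `diagnose_curves` is hypothetical, and the `gap_threshold` and `patience` defaults are arbitrary assumptions that should be tuned per project.

```python
def diagnose_curves(train_miou, val_miou, gap_threshold=0.10, patience=3):
    """Summarize train/validation mIoU curves (higher is better).

    Flags possible overfitting when the final train/validation gap exceeds
    `gap_threshold` AND validation mIoU has not improved for `patience` epochs.
    """
    best_index = max(range(len(val_miou)), key=lambda i: val_miou[i])
    final_gap = train_miou[-1] - val_miou[-1]
    epochs_since_best = len(val_miou) - 1 - best_index
    return {
        "best_epoch": best_index + 1,  # 1-based epoch number
        "best_val_miou": val_miou[best_index],
        "final_gap": final_gap,
        "possible_overfitting": final_gap > gap_threshold
                                and epochs_since_best >= patience,
    }

# Illustrative curves (made-up numbers, not from a real training job)
healthy = diagnose_curves([0.30, 0.50, 0.60, 0.65],
                          [0.28, 0.47, 0.57, 0.62])
overfit = diagnose_curves([0.30, 0.50, 0.70, 0.80, 0.90, 0.95],
                          [0.28, 0.45, 0.55, 0.54, 0.52, 0.50])

print("Healthy run:", healthy)
print("Overfit run:", overfit)
```

In the second run, validation mIoU peaks at epoch 3 while training mIoU keeps climbing, so the widening gap is flagged.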

In [None]:
# Simulate training metrics over epochs (synthetic curves for illustration, not output from a real training job)
np.random.seed(42)
epochs = 30

# Simulate healthy training curves for segmentation
# Segmentation typically has slower convergence than classification

# Loss curves
base_loss = 1.5
train_loss = [max(0.1, base_loss * np.exp(-0.08 * e) + np.random.normal(0, 0.02)) for e in range(epochs)]
val_loss = [max(0.15, base_loss * np.exp(-0.06 * e) + 0.05 + np.random.normal(0, 0.03)) for e in range(epochs)]

# mIoU curves
base_miou = 0.2
train_miou = [min(0.85, base_miou + 0.02 * e + np.random.normal(0, 0.01)) for e in range(epochs)]
val_miou = [min(0.75, base_miou - 0.02 + 0.018 * e + np.random.normal(0, 0.015)) for e in range(epochs)]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss plot
axes[0].plot(range(1, epochs + 1), train_loss, 'b-', label='Training Loss', linewidth=2)
axes[0].plot(range(1, epochs + 1), val_loss, 'r--', label='Validation Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Progress: Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# mIoU plot
axes[1].plot(range(1, epochs + 1), train_miou, 'b-', label='Training mIoU', linewidth=2)
axes[1].plot(range(1, epochs + 1), val_miou, 'r--', label='Validation mIoU', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Mean IoU')
axes[1].set_title('Training Progress: Mean IoU')
axes[1].set_ylim(0, 1)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final Training mIoU: {train_miou[-1]:.4f}")
print(f"Final Validation mIoU: {val_miou[-1]:.4f}")
print(f"Final Training Loss: {train_loss[-1]:.4f}")
print(f"Final Validation Loss: {val_loss[-1]:.4f}")

---

## Summary

In this exercise, you learned:

### 1. Data Format
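- **Images**: JPG files in the `train/` and `validation/` folders
- **Annotations**: Single-channel PNG files in `train_annotation/` and `validation_annotation/`, where each pixel value is a class ID
- **Matching**: An image and its annotation share the same base filename (for example, `0001.jpg` and `0001.png`)
- **Dimensions**: Each annotation must match its image size exactly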

### 2. Architecture Options

| Algorithm | Description | Speed | Accuracy |
|-----------|-------------|-------|----------|
| FCN | Fully Convolutional | Fast | Good |
| PSP | Pyramid Scene Parsing | Medium | Better |
| DeepLab | Atrous Convolutions | Slower | Best |

### 3. Key Hyperparameters

| Category | Parameters |
|----------|------------|
| Architecture | `algorithm`, `backbone`, `use_pretrained_model` |
| Training | `epochs`, `mini_batch_size`, `learning_rate`, `crop_size` |
| Optimizer | `optimizer`, `momentum`, `weight_decay` |
| Scheduler | `lr_scheduler` |

### 4. Output Format
- Grayscale PNG mask
- Same dimensions as input image
- Pixel values = predicted class IDs (0 to num_classes-1)
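
Because the mask stores class IDs directly, summarizing a prediction is a counting exercise. The sketch below computes the fraction of pixels assigned to each class. The function name `class_coverage` and the class names are hypothetical, and the toy array stands in for a mask that would normally be loaded from the returned PNG.

```python
import numpy as np

def class_coverage(mask, class_names):
    """Fraction of pixels assigned to each class in a class-ID mask."""
    mask = np.asarray(mask)
    counts = np.bincount(mask.ravel(), minlength=len(class_names))
    if len(counts) > len(class_names):
        raise ValueError("mask contains a class ID outside the expected range")
    return {name: counts[i] / mask.size for i, name in enumerate(class_names)}

# Toy 2x4 mask standing in for a decoded PNG (values are class IDs)
toy_mask = np.array([[0, 0, 1, 1],
                     [0, 2, 2, 1]], dtype=np.uint8)
coverage = class_coverage(toy_mask, ["background", "road", "car"])

for name, fraction in coverage.items():
    print(f"{name:>12s}: {fraction:.1%}")
```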

### 5. Evaluation Metrics

| Metric | Description | Primary? |
|--------|-------------|----------|
| mIoU | Mean Intersection over Union | **Yes** |
| Pixel Accuracy | Fraction of pixels classified correctly | No (misleading with class imbalance) |
| Dice Score | 2×TP / (2×TP + FP + FN) | Common in medical imaging |
| Per-class IoU | IoU for each class | Detailed analysis |
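
A small worked example shows why pixel accuracy can mislead under class imbalance. The toy data below is an assumption chosen for illustration: a 100-pixel image with 5 object pixels and a model that predicts background everywhere.

```python
import numpy as np

# 100-pixel image: 95 background pixels (class 0), 5 object pixels (class 1)
truth = np.zeros(100, dtype=np.int64)
truth[:5] = 1
pred = np.zeros(100, dtype=np.int64)  # model predicts background everywhere

pixel_accuracy = (pred == truth).mean()

ious, dices = [], []
for c in (0, 1):
    tp = np.sum((pred == c) & (truth == c))
    fp = np.sum((pred == c) & (truth != c))
    fn = np.sum((pred != c) & (truth == c))
    ious.append(tp / (tp + fp + fn))
    dices.append(2 * tp / (2 * tp + fp + fn))

mean_iou = float(np.mean(ious))
mean_dice = float(np.mean(dices))

print(f"Pixel accuracy: {pixel_accuracy:.3f}")  # 0.950
print(f"Mean IoU:       {mean_iou:.3f}")        # 0.475
print(f"Mean Dice:      {mean_dice:.3f}")

# For a single class, Dice and IoU are related by Dice = 2*IoU / (1 + IoU)
assert np.isclose(dices[0], 2 * ious[0] / (1 + ious[0]))
```

The model never finds the object, yet pixel accuracy is 95%. Mean IoU drops to 0.475 because the object class contributes an IoU of 0.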

### Instance Requirements

| Task | Instance Types | Notes |
|------|----------------|-------|
| Training | ml.g4dn.xlarge, ml.p3.2xlarge | **GPU required**, memory intensive |
| Inference | ml.c5.xlarge (CPU), ml.g4dn.xlarge (GPU) | GPU for real-time |

### Cost Considerations
- Training costs: $10-30+ depending on dataset and settings
- **More expensive than classification** due to pixel-level processing
- Use Spot Instances for up to 70% savings
- Reduce `crop_size` and use the `fcn` algorithm for faster iteration
- Use the ResNet-50 backbone unless you need maximum accuracy

### Next Steps
1. Obtain real pixel-annotated data (Cityscapes, PASCAL VOC, ADE20K)
2. Use SageMaker Ground Truth for custom segmentation labeling
3. Experiment with different algorithms (FCN vs PSP vs DeepLab)
4. Monitor validation mIoU during training; it is the primary metric
5. Address class imbalance if needed (for example, with a class-weighted loss)
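
As a starting point for the class-imbalance step, per-class weights can be derived from pixel frequencies in the annotation masks. This is a sketch of one common heuristic (inverse frequency, rescaled so the weights of present classes average 1). The function name `inverse_frequency_weights` is hypothetical, and whether and how such weights can be passed to a given training setup is not assumed here.

```python
import numpy as np

def inverse_frequency_weights(masks, num_classes):
    """Return (pixel frequency, weight) per class from class-ID masks.

    Classes that never appear get weight 0. Assumes all mask values are
    valid class IDs in [0, num_classes).
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for mask in masks:
        counts += np.bincount(np.asarray(mask).ravel(), minlength=num_classes)
    freq = counts / counts.sum()

    weights = np.zeros(num_classes)
    present = freq > 0
    weights[present] = 1.0 / freq[present]
    weights[present] /= weights[present].mean()  # rescale to mean 1
    return freq, weights

# Toy annotation: 8 background pixels (class 0), 2 object pixels (class 1)
toy_masks = [np.array([[0, 0, 0, 0, 1],
                       [0, 0, 0, 0, 1]])]
freq, weights = inverse_frequency_weights(toy_masks, num_classes=2)

print("Pixel frequency:", freq)
print("Class weights:  ", weights)
```

With an 80/20 split, the rarer class receives four times the weight of the dominant one.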

## Resources

- [SageMaker Semantic Segmentation Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/semantic-segmentation.html)
- [Semantic Segmentation Hyperparameters](https://docs.aws.amazon.com/sagemaker/latest/dg/segmentation-hyperparameters.html)
- [Cityscapes Dataset](https://www.cityscapes-dataset.com/) - Urban scene segmentation
- [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/) - Multi-class segmentation
- [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/) - Scene parsing dataset
- [SageMaker Ground Truth](https://docs.aws.amazon.com/sagemaker/latest/dg/sms.html) - Semantic segmentation labeling
- [AWS Pricing Calculator](https://calculator.aws/) - Estimate training costs