# SageMaker Semantic Segmentation Exercise

This notebook demonstrates Amazon SageMaker's **Semantic Segmentation** algorithm for pixel-level image classification.

## What You'll Learn
1. How to prepare pixel-level annotations
2. How to train a segmentation model
3. How to interpret segmentation masks

## What is Semantic Segmentation?

Semantic Segmentation classifies **every pixel** in an image, assigning each pixel to a class. Unlike object detection, which provides bounding boxes, segmentation recovers the exact shape of each region.

**Key Difference from Other CV Tasks:**

| Task | Output | Granularity |
|------|--------|-------------|
| Classification | Single label | Image-level |
| Object Detection | Bounding boxes + labels | Object-level |
| Semantic Segmentation | Pixel mask | Pixel-level |
| Instance Segmentation | Pixel mask per object | Pixel + Instance level |
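
To make the granularity difference concrete, the sketch below derives the bounding box a detector would report from a pixel mask and compares the two areas. It is NumPy only; `mask_to_bbox` is a hypothetical helper written for this illustration, not part of SageMaker.

```python
import numpy as np

def mask_to_bbox(mask: np.ndarray, class_id: int):
    """Return (row_min, row_max, col_min, col_max) of pixels labeled class_id, inclusive."""
    rows, cols = np.nonzero(mask == class_id)
    if rows.size == 0:
        return None  # class not present in the mask
    return rows.min(), rows.max(), cols.min(), cols.max()

# A 5x5 mask whose "object" (class 1) is a thin diagonal line
mask = np.eye(5, dtype=np.uint8)
r0, r1, c0, c1 = mask_to_bbox(mask, class_id=1)
box_area = (r1 - r0 + 1) * (c1 - c0 + 1)
object_area = int((mask == 1).sum())
print(f"bbox covers {box_area} pixels, object covers {object_area}")
# bbox covers 25 pixels, object covers 5
```

For thin or irregular shapes the box is mostly background, which is exactly the information a pixel mask preserves.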

## Use Cases

| Industry | Application |
|----------|-------------|
| Autonomous Driving | Road/lane detection, obstacle segmentation |
| Medical Imaging | Tumor segmentation, organ identification |
| Satellite Imagery | Land use classification, building detection |
| Robotics | Scene understanding, navigation |
| Fashion | Clothing segmentation, virtual try-on |

---

## Step 1: Setup and Imports

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.image_uris import retrieve
from sagemaker.estimator import Estimator
import numpy as np
import json
import os
from datetime import datetime
from dotenv import load_dotenv
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Load environment variables from .env file
load_dotenv()

# Configure AWS session from environment variables
aws_profile = os.getenv('AWS_PROFILE')
aws_region = os.getenv('AWS_REGION', 'us-west-2')
sagemaker_role = os.getenv('SAGEMAKER_ROLE_ARN')

if aws_profile:
    boto3.setup_default_session(profile_name=aws_profile, region_name=aws_region)
else:
    boto3.setup_default_session(region_name=aws_region)

# SageMaker session and role
sagemaker_session = sagemaker.Session()

if sagemaker_role:
    role = sagemaker_role
else:
    role = get_execution_role()

region = sagemaker_session.boto_region_name

print(f"AWS Profile: {aws_profile or 'default'}")
print(f"SageMaker Role: {role}")
print(f"Region: {region}")
print(f"SageMaker SDK Version: {sagemaker.__version__}")

In [None]:
# Configuration
BUCKET_NAME = sagemaker_session.default_bucket()
PREFIX = "semantic-segmentation"

print(f"S3 Bucket: {BUCKET_NAME}")
print(f"S3 Prefix: {PREFIX}")

## Step 2: Understand Data Format

The SageMaker Semantic Segmentation algorithm expects the following data layout and annotation format:

### Directory Structure
```
train/
  image001.jpg
  image002.jpg
train_annotation/
  image001.png  # Grayscale mask, pixel values = class IDs
  image002.png
validation/
validation_annotation/
```
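
The layout relies on filename matching: each image in `train/` is paired with the annotation in `train_annotation/` that has the same base name. A minimal pre-upload check might look like the sketch below. It assumes you already have the object keys as strings (listing them from S3 or disk is not shown), and `pair_images_with_annotations` is a hypothetical helper.

```python
from pathlib import PurePosixPath

def pair_images_with_annotations(image_keys, annotation_keys):
    """Match each image to the PNG annotation with the same base name.

    Returns (pairs, missing): pairs is a list of (image_key, annotation_key),
    missing lists images that have no annotation.
    """
    # Index annotations by stem, e.g. "image001" for "train_annotation/image001.png"
    ann_by_stem = {
        PurePosixPath(k).stem: k
        for k in annotation_keys
        if PurePosixPath(k).suffix.lower() == ".png"
    }
    pairs, missing = [], []
    for key in sorted(image_keys):
        stem = PurePosixPath(key).stem
        if stem in ann_by_stem:
            pairs.append((key, ann_by_stem[stem]))
        else:
            missing.append(key)
    return pairs, missing

pairs, missing = pair_images_with_annotations(
    ["train/image001.jpg", "train/image002.jpg", "train/image003.jpg"],
    ["train_annotation/image001.png", "train_annotation/image002.png"],
)
print(pairs)
print(missing)  # ['train/image003.jpg']
```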

### Annotation Format
- **PNG images** (uncompressed)
- **Grayscale** where each pixel value is the class ID
- Same dimensions as input image
- Class IDs from 0 to (num_classes - 1)
- Background is typically class 0
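
Before uploading, it can help to check each mask against these rules locally. The sketch below works on already-decoded arrays (reading the JPG/PNG files, for example with Pillow, is left out). `validate_annotation` is a hypothetical helper, and the `uint8` check is an assumption based on class IDs being stored as 8-bit grayscale values.

```python
import numpy as np

def validate_annotation(mask: np.ndarray, image_shape, num_classes: int) -> list:
    """Return a list of problems found in one annotation mask (empty list = OK)."""
    problems = []
    # Annotations are single-channel: one class ID per pixel
    if mask.ndim != 2:
        problems.append(f"expected a single-channel mask, got shape {mask.shape}")
    # Assumption: class IDs are stored as 8-bit values
    if mask.dtype != np.uint8:
        problems.append(f"expected uint8 pixels, got {mask.dtype}")
    # Mask height/width must match the image it annotates
    if mask.shape[:2] != tuple(image_shape[:2]):
        problems.append(f"mask is {mask.shape[:2]} but image is {tuple(image_shape[:2])}")
    # Class IDs must lie in 0 .. num_classes - 1
    bad_ids = np.unique(mask[mask >= num_classes])
    if bad_ids.size:
        problems.append(f"class IDs out of range: {bad_ids.tolist()}")
    return problems

# Stand-ins for a decoded JPG and its decoded PNG annotation
image = np.zeros((256, 256, 3), dtype=np.uint8)
good = (np.arange(256 * 256) % 5).reshape(256, 256).astype(np.uint8)
bad = good.copy()
bad[0, 0] = 7  # class ID outside 0..4

print(validate_annotation(good, image.shape, num_classes=5))  # []
print(validate_annotation(bad, image.shape, num_classes=5))   # ['class IDs out of range: [7]']
```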

In [None]:
def generate_synthetic_segmentation_mask(height=256, width=256, num_classes=5, seed=None):
    """
    Generate a synthetic street-scene segmentation mask with random shapes.
    
    The scene always uses the five classes below, so num_classes must be 5.
    The random shape sizes assume the default 256x256 mask.
    
    Returns:
        mask: 2D uint8 numpy array of class IDs (0 to num_classes - 1)
        class_names: List of class names, indexed by class ID
    """
    class_names = ['background', 'road', 'building', 'vegetation', 'sky']
    if num_classes != len(class_names):
        raise ValueError(f"This generator draws exactly {len(class_names)} classes, got num_classes={num_classes}")
    
    if seed is not None:
        np.random.seed(seed)
    
    # Start with background
    mask = np.zeros((height, width), dtype=np.uint8)
    
    # Add sky (top portion)
    sky_height = np.random.randint(height // 4, height // 2)
    mask[:sky_height, :] = 4  # sky class
    
    # Add road (bottom center)
    road_width = width // 3
    road_start = (width - road_width) // 2
    mask[height * 2 // 3:, road_start:road_start + road_width] = 1  # road class
    
    # Add random buildings below the sky line. Cap the building height so it fits
    # between the sky and a 20-pixel bottom margin; otherwise randint(low, high)
    # below can receive low >= high and raise ValueError.
    num_buildings = np.random.randint(2, 5)
    for _ in range(num_buildings):
        bw = np.random.randint(30, 80)
        bh = np.random.randint(50, min(120, height - sky_height - 20))
        bx = np.random.randint(0, width - bw)
        by = np.random.randint(sky_height, height - bh - 20)
        mask[by:by + bh, bx:bx + bw] = 2  # building class
    
    # Add vegetation patches
    num_veg = np.random.randint(3, 7)
    for _ in range(num_veg):
        vw = np.random.randint(20, 60)
        vh = np.random.randint(20, 50)
        vx = np.random.randint(0, width - vw)
        vy = np.random.randint(sky_height, height - vh)
        mask[vy:vy + vh, vx:vx + vw] = 3  # vegetation class
    
    return mask, class_names

# Generate sample mask
sample_mask, class_names = generate_synthetic_segmentation_mask(seed=42)

print(f"Mask shape: {sample_mask.shape}")
print(f"Classes: {class_names}")
print(f"Unique values in mask: {np.unique(sample_mask)}")

In [None]:
def visualize_segmentation(mask, class_names, title="Segmentation Mask"):
    """
    Visualize segmentation mask with colored classes.
    """
    num_classes = len(class_names)
    colors = plt.cm.tab10(np.linspace(0, 1, num_classes))
    
    # Create colored mask
    colored_mask = np.zeros((*mask.shape, 3))
    for class_id in range(num_classes):
        colored_mask[mask == class_id] = colors[class_id][:3]
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Grayscale mask (actual annotation format)
    axes[0].imshow(mask, cmap='gray', vmin=0, vmax=num_classes-1)
    axes[0].set_title("Annotation (Grayscale - pixel values = class IDs)")
    axes[0].axis('off')
    
    # Colored visualization
    axes[1].imshow(colored_mask)
    axes[1].set_title(title)
    axes[1].axis('off')
    
    # Legend
    patches = [mpatches.Patch(color=colors[i], label=class_names[i]) 
               for i in range(num_classes)]
    axes[1].legend(handles=patches, loc='upper right', fontsize=10)
    
    plt.tight_layout()
    plt.show()

visualize_segmentation(sample_mask, class_names, "Synthetic Scene Segmentation")

## Step 3: Training Configuration

### Architecture Options

| Algorithm | Description | Best For |
|-----------|-------------|----------|
| FCN | Fully Convolutional Network | Fast inference |
| PSP | Pyramid Scene Parsing | Multi-scale features |
| DeepLabV3 | Atrous convolutions | High accuracy |

### Backbone Networks

| Backbone | Depth | Speed | Accuracy |
|----------|-------|-------|----------|
| ResNet-50 | 50 layers | Faster | Good |
| ResNet-101 | 101 layers | Slower | Better |

### Key Hyperparameters

| Parameter | Description | Default |
|-----------|-------------|---------|
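| `num_classes` | Number of segmentation classes (including background) | Required |
| `num_training_samples` | Number of images in the training channel (used to set up the learning-rate scheduler) | Required |
| `algorithm` | fcn, psp, deeplab | fcn |
| `backbone` | resnet-50, resnet-101 | resnet-50 |
| `use_pretrained_model` | Use pretrained backbone | True |
| `epochs` | Training epochs | 10 |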
| `learning_rate` | Initial learning rate | 0.001 |
| `mini_batch_size` | Batch size | 16 |
| `crop_size` | Training crop size | 240 |
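
A small helper can catch typos in these values before a training job is launched. The sketch below is hypothetical (`build_hyperparameters` is not a SageMaker function); it checks the `algorithm` and `backbone` options from the table and converts every value to a string, since the CreateTrainingJob API takes hyperparameters as a string-to-string map. Set `num_training_samples` to the number of images in your `train` channel.

```python
VALID_ALGORITHMS = {"fcn", "psp", "deeplab"}
VALID_BACKBONES = {"resnet-50", "resnet-101"}

def build_hyperparameters(num_classes, num_training_samples, algorithm="fcn",
                          backbone="resnet-50", **overrides):
    """Assemble a hyperparameter dict, rejecting unknown algorithm/backbone names."""
    if algorithm not in VALID_ALGORITHMS:
        raise ValueError(f"algorithm must be one of {sorted(VALID_ALGORITHMS)}, got {algorithm!r}")
    if backbone not in VALID_BACKBONES:
        raise ValueError(f"backbone must be one of {sorted(VALID_BACKBONES)}, got {backbone!r}")
    params = {
        "num_classes": num_classes,
        "num_training_samples": num_training_samples,
        "algorithm": algorithm,
        "backbone": backbone,
        **overrides,  # e.g. epochs, crop_size, mini_batch_size
    }
    # Hyperparameters reach the training container as strings
    return {name: str(value) for name, value in params.items()}

hp = build_hyperparameters(num_classes=5, num_training_samples=1000,
                           algorithm="deeplab", epochs=30, crop_size=480)
print(hp)
```

The resulting dict can then be passed to an estimator with `estimator.set_hyperparameters(**hp)`.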

In [None]:
# Get Semantic Segmentation container image
semantic_segmentation_image = retrieve(
    framework='semantic-segmentation',
    region=region,
    version='1'
)

print(f"Semantic Segmentation Image URI: {semantic_segmentation_image}")

In [None]:
# Reference configuration only: this cell prints the settings; it does not create an estimator or launch a training job
print("""
Semantic Segmentation Estimator Configuration:
===============================================

semantic_segmentation_estimator = Estimator(
    image_uri=semantic_segmentation_image,
    role=role,
    instance_count=1,
    instance_type='ml.p3.2xlarge',  # GPU required
    output_path=f's3://{BUCKET_NAME}/{PREFIX}/output',
    sagemaker_session=sagemaker_session,
    base_job_name='semantic-segmentation'
)

hyperparameters = {
    "num_classes": 5,
    "algorithm": "deeplab",         # fcn, psp, or deeplab
    "backbone": "resnet-50",
    "use_pretrained_model": "True",
    "epochs": 30,
    "learning_rate": 0.001,
    "lr_scheduler": "poly",
    "mini_batch_size": 16,
    "optimizer": "sgd",
    "momentum": 0.9,
    "weight_decay": 0.0001,
    "crop_size": 480,
    "num_training_samples": 1000,
}

Data channels:
- train: Training images (JPG)
- train_annotation: Training masks (PNG, grayscale)
- validation: Validation images
- validation_annotation: Validation masks
""")

## Step 4: Understanding Model Output

The model outputs a grayscale PNG image where each pixel value is the predicted class ID.
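
A class-ID mask is the per-pixel argmax over class scores. If you work with per-class scores instead of the PNG (for example from your own model or a probability-style output), the sketch below recovers the mask. The `(num_classes, H, W)` layout is an assumption, and `scores_to_mask` is a hypothetical helper.

```python
import numpy as np

def scores_to_mask(scores: np.ndarray) -> np.ndarray:
    """Convert per-pixel class scores shaped (num_classes, H, W) to a uint8 class-ID mask."""
    return np.argmax(scores, axis=0).astype(np.uint8)

# 3 classes over a 2x2 image
scores = np.array([
    [[0.7, 0.1], [0.2, 0.3]],   # class 0
    [[0.2, 0.8], [0.1, 0.3]],   # class 1
    [[0.1, 0.1], [0.7, 0.4]],   # class 2
])
print(scores_to_mask(scores))
```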

In [None]:
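def simulate_prediction(ground_truth, num_classes, noise_rate=0.05, seed=None):
    """
    Simulate a model prediction by corrupting the ground truth with random labels.
    
    Each pixel is resampled with probability noise_rate. A resampled pixel can land
    on its original class, so the actual error rate is about
    noise_rate * (num_classes - 1) / num_classes (4% for 5 classes).
    """
    rng = np.random.default_rng(seed)
    prediction = ground_truth.copy()
    
    # Pick the pixels to corrupt and give each a uniformly random class ID
    noise_mask = rng.random(ground_truth.shape) < noise_rate
    random_classes = rng.integers(0, num_classes, size=ground_truth.shape, dtype=ground_truth.dtype)
    prediction[noise_mask] = random_classes[noise_mask]
    
    return prediction

# Generate a reproducible simulated prediction
prediction = simulate_prediction(sample_mask, len(class_names), seed=0)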

# Visualize comparison
num_classes = len(class_names)
colors = plt.cm.tab10(np.linspace(0, 1, num_classes))

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Ground truth
colored_gt = np.zeros((*sample_mask.shape, 3))
for class_id in range(num_classes):
    colored_gt[sample_mask == class_id] = colors[class_id][:3]
axes[0].imshow(colored_gt)
axes[0].set_title("Ground Truth")
axes[0].axis('off')

# Prediction
colored_pred = np.zeros((*prediction.shape, 3))
for class_id in range(num_classes):
    colored_pred[prediction == class_id] = colors[class_id][:3]
axes[1].imshow(colored_pred)
axes[1].set_title("Prediction")
axes[1].axis('off')

# Error map
error_map = (sample_mask != prediction).astype(float)
axes[2].imshow(error_map, cmap='Reds')
axes[2].set_title(f"Errors (red pixels: {error_map.sum():.0f})")
axes[2].axis('off')

# Legend
patches = [mpatches.Patch(color=colors[i], label=class_names[i]) 
           for i in range(num_classes)]
fig.legend(handles=patches, loc='lower center', ncol=num_classes, fontsize=10)

plt.tight_layout()
plt.subplots_adjust(bottom=0.15)
plt.show()

## Step 5: Evaluation Metrics

### Mean Intersection over Union (mIoU)
Primary metric for semantic segmentation:
```
IoU(class) = True Positives / (True Positives + False Positives + False Negatives)
mIoU = Average IoU across classes (classes absent from both ground truth and prediction are usually excluded)
```

### Pixel Accuracy
```
Pixel Accuracy = Correctly classified pixels / Total pixels
```
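
Both formulas can be computed for all classes at once from a confusion matrix, which also accumulates easily over many images (sum the per-image matrices, then compute the metrics once). The sketch below is an alternative to the per-class loop in the next cell; `confusion_matrix` and `metrics_from_confusion` are hypothetical helper names.

```python
import numpy as np

def confusion_matrix(gt: np.ndarray, pred: np.ndarray, num_classes: int) -> np.ndarray:
    """cm[i, j] = number of pixels whose true class is i and predicted class is j."""
    idx = gt.astype(np.int64).ravel() * num_classes + pred.astype(np.int64).ravel()
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def metrics_from_confusion(cm: np.ndarray):
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp   # predicted as class c, actually another class
    fn = cm.sum(axis=1) - tp   # actually class c, predicted as another class
    denom = tp + fp + fn
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = np.where(denom > 0, tp / denom, np.nan)  # NaN: class absent from both masks
    return tp.sum() / cm.sum(), iou, np.nanmean(iou)

gt = np.array([[0, 0, 1], [1, 1, 2]])
pred = np.array([[0, 1, 1], [1, 2, 2]])
cm = confusion_matrix(gt, pred, num_classes=3)
acc, iou, miou = metrics_from_confusion(cm)
print(f"pixel accuracy {acc:.3f}, IoU {iou}, mIoU {miou:.3f}")
# pixel accuracy 0.667, IoU [0.5 0.5 0.5], mIoU 0.500
```

Worked by hand: 4 of 6 pixels match, and for class 1, TP = 2, FP = 1, FN = 1, so IoU = 2 / (2 + 1 + 1) = 0.5.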

In [None]:
def calculate_segmentation_metrics(ground_truth, prediction, num_classes):
    """
    Calculate segmentation metrics.
    
    A class that appears in neither the ground truth nor the prediction has an
    undefined IoU (0/0); it is reported as NaN and left out of the mean.
    
    Returns:
        dict with pixel_accuracy, mean_iou, and per-class IoU
    """
    # Pixel accuracy
    correct = (ground_truth == prediction).sum()
    total = ground_truth.size
    pixel_accuracy = correct / total
    
    # Per-class IoU
    iou_per_class = []
    for class_id in range(num_classes):
        gt_mask = ground_truth == class_id
        pred_mask = prediction == class_id
        
        intersection = (gt_mask & pred_mask).sum()
        union = (gt_mask | pred_mask).sum()
        
        if union > 0:
            iou = intersection / union
        else:
            iou = np.nan  # class absent from both masks
        
        iou_per_class.append(iou)
    
    # nanmean averages only the classes whose IoU is defined
    mean_iou = np.nanmean(iou_per_class)
    
    return {
        'pixel_accuracy': pixel_accuracy,
        'mean_iou': mean_iou,
        'iou_per_class': iou_per_class
    }

# Calculate metrics
metrics = calculate_segmentation_metrics(sample_mask, prediction, len(class_names))

print("Segmentation Metrics:")
print("=" * 40)
print(f"Pixel Accuracy: {metrics['pixel_accuracy']:.4f}")
print(f"Mean IoU (mIoU): {metrics['mean_iou']:.4f}")
print("\nPer-class IoU:")
for name, iou in zip(class_names, metrics['iou_per_class']):
    print(f"  {name}: {iou:.4f}")

In [None]:
# Visualize per-class IoU
fig, ax = plt.subplots(figsize=(10, 5))

colors_bar = plt.cm.tab10(np.linspace(0, 1, len(class_names)))
bars = ax.barh(class_names, metrics['iou_per_class'], color=colors_bar)

ax.set_xlabel('IoU Score')
ax.set_title(f'Per-Class IoU (mIoU = {metrics["mean_iou"]:.4f})')
ax.set_xlim(0, 1)

# Add value labels
for bar, iou in zip(bars, metrics['iou_per_class']):
    ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2,
           f'{iou:.4f}', va='center')

plt.tight_layout()
plt.show()

---

## Summary

In this exercise, you learned:

1. **Data Format**:
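   - Images: JPG files in the train/ and validation/ channels
   - Annotations: Grayscale PNG (pixel value = class ID) in train_annotation/ and validation_annotation/
   - Each annotation has the same base filename as its image (e.g. image001.jpg and image001.png)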

2. **Architecture Options**:
   - FCN: Fast, basic segmentation
   - PSP: Pyramid pooling for multi-scale
   - DeepLabV3: Atrous convolutions for accuracy

3. **Key Hyperparameters**:
   - `algorithm`: fcn, psp, deeplab
   - `backbone`: resnet-50, resnet-101
   - `crop_size`: Training patch size

4. **Output Format**:
   - Grayscale PNG mask
   - Same dimensions as input
   - Pixel values are class IDs

5. **Evaluation Metrics**:
   - Mean IoU (mIoU): Primary metric
   - Pixel Accuracy: Overall accuracy
   - Per-class IoU: Class-specific performance

### Instance Requirements

| Task | Instance Types |
|------|----------------|
| Training (GPU only) | ml.p2.xlarge, ml.p3.2xlarge, ml.g4dn.xlarge |
| Inference | ml.c5.xlarge (CPU) or GPU instances |

### Next Steps

- Prepare real image data with pixel annotations
- Use SageMaker Ground Truth for semantic segmentation labeling
- Experiment with different algorithms (FCN vs PSP vs DeepLab)
- Try different crop sizes for your image resolution