# Week 3 Lab: CNN Land-Cover Classification

**Objective**: Build, train, and interpret a convolutional neural network for land-cover classification using Sentinel-2 imagery.

**Learning Goals**:
- Implement a CNN architecture in PyTorch
- Prepare geospatial data for deep learning
- Train a model with proper monitoring
- Evaluate performance with appropriate metrics
- Interpret learned features through visualization

---

## Part 1: Setup and Imports

In [None]:
# Standard library imports
import os
import json
from datetime import datetime
from pathlib import Path

# Data handling
import numpy as np
import pandas as pd
import rasterio
from rasterio.features import rasterize
import geopandas as gpd

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap

# Metrics
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, classification_report
)

# Google Earth Engine
import ee
import geemap

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Fix the NumPy and PyTorch random generators so runs are repeatable
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

has_cuda = torch.cuda.is_available()
if has_cuda:
    torch.cuda.manual_seed(SEED)

# Report the runtime environment
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# Connect to Earth Engine; a one-off `earthengine authenticate` must have been run first
try:
    ee.Initialize()
except Exception as e:
    print(f"Error initializing Earth Engine: {e}")
    print("Run: earthengine authenticate")
else:
    print("Earth Engine initialized successfully")

In [None]:
# Folder layout used by every later step (mkdir is idempotent)
dirs = [
    'data/raw',
    'data/processed',
    'data/labels',
    'models',
    'figures',
    'reports',
    'config',
]

for folder in map(Path, dirs):
    folder.mkdir(parents=True, exist_ok=True)

print("Directory structure created")

## Part 2: Configuration

Define all hyperparameters and settings in one place for easy modification.

In [None]:
# Configuration dictionary: every tunable setting for the lab lives here.
config = {
    # Data parameters
    'aoi': {  # Replace with your Area of Interest
        'name': 'Los_Lagos',  # or 'Central_Chile' or 'Lake_Llanquihue'
        'bounds': [-73.5, -42.0, -72.5, -41.0],  # [west, south, east, north]
    },
    # Austral summer (Dec-Feb). The AOI is in the southern hemisphere, where
    # June-August is winter and persistently cloudy. The end date is
    # exclusive in ee.ImageCollection.filterDate.
    'date_range': ['2023-12-01', '2024-03-01'],
    'cloud_threshold': 20,  # Maximum cloud cover percentage

    # Class definitions
    'classes': {
        0: 'Forest',
        1: 'Agriculture',
        2: 'Parcels',
        3: 'Water'
    },
    'num_classes': 4,

    # Sentinel-2 bands to use
    'bands': ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12'],
    'num_bands': 10,

    # Patch extraction
    'patch_size': 128,  # Pixels
    'stride': 64,  # Overlap for more training examples

    # Data splitting
    'train_split': 0.7,
    'val_split': 0.3,

    # Model architecture
    'model': {
        'name': 'SimpleCNN',
        'channels': [32, 64, 128],  # Channels in each conv block
        'dropout': 0.5
    },

    # Training parameters
    'training': {
        'batch_size': 32,
        'num_epochs': 50,
        'learning_rate': 0.001,
        'optimizer': 'Adam',
        'weight_decay': 0.01,
        'scheduler': 'ReduceLROnPlateau',
        'scheduler_patience': 5,
        'early_stopping_patience': 10
    },

    # Device
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Paths
    'paths': {
        'labels': 'data/labels/training_polygons.geojson',
        'processed_data': 'data/processed',
        'model_save': 'models/week3_best_model.pth',
        'config_save': 'config/week3_config.json'
    }
}

# Cheap consistency checks: the counts must agree with the lists they describe
if config['num_classes'] != len(config['classes']):
    raise ValueError("config['num_classes'] does not match config['classes']")
if config['num_bands'] != len(config['bands']):
    raise ValueError("config['num_bands'] does not match config['bands']")

# Save configuration (create the folder so this cell also works on its own)
Path(config['paths']['config_save']).parent.mkdir(parents=True, exist_ok=True)
with open(config['paths']['config_save'], 'w') as f:
    json.dump(config, f, indent=2)

print("Configuration:")
print(json.dumps(config, indent=2))

## Part 3: Data Collection and Preparation

### Step 3.1: Define Area of Interest

In [None]:
# Rectangle from the configured [west, south, east, north] bounds
aoi_bounds = config['aoi']['bounds']
aoi = ee.Geometry.Rectangle(aoi_bounds)

# Interactive map with the AOI outlined in red
red_outline = {'color': 'red'}
Map = geemap.Map()
Map.centerObject(aoi, zoom=10)
Map.addLayer(aoi, red_outline, 'AOI')
Map

### Step 3.2: Load Training Labels

**Note**: You need to create training labels first using QGIS. See the Detailed Study Guide for instructions.

Your GeoJSON should have a 'class' field with values 0-3 corresponding to your land-cover classes.

In [None]:
# Load training polygons
labels_path = config['paths']['labels']

if os.path.exists(labels_path):
    training_labels = gpd.read_file(labels_path)

    # Patch extraction treats centroid x/y as lon/lat, so work in WGS84
    if training_labels.crs is not None and training_labels.crs.to_epsg() != 4326:
        training_labels = training_labels.to_crs(epsg=4326)

    print(f"Loaded {len(training_labels)} training polygons")
    print(f"\nClass distribution:")
    print(training_labels['class'].value_counts().sort_index())

    # Visualize on map. ee.FeatureCollection(<string>) expects an Earth Engine
    # asset ID, not a local file path, so convert the GeoDataFrame instead.
    Map.addLayer(geemap.gdf_to_ee(training_labels), {'color': 'blue'}, 'Training Labels')
else:
    print(f"Training labels not found at {labels_path}")
    print("Please create training labels in QGIS first.")
    print("See Week 3 Detailed Study Guide, Section 'Collecting Training Data'")

### Step 3.3: Collect Sentinel-2 Imagery

In [None]:
def get_sentinel2_composite(aoi, date_range, cloud_threshold, bands):
    """
    Build a cloud-masked median Sentinel-2 surface-reflectance composite.

    Args:
        aoi: Earth Engine geometry used to filter scenes and clip the result
        date_range: [start_date, end_date] as strings (end date exclusive)
        cloud_threshold: Maximum scene-level CLOUDY_PIXEL_PERCENTAGE
        bands: List of band names to include

    Returns:
        Earth Engine image containing only `bands`
    """
    # Load Sentinel-2 Surface Reflectance collection
    s2 = (
        ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
        .filterBounds(aoi)
        .filterDate(date_range[0], date_range[1])
        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', cloud_threshold))
    )

    def mask_clouds(image):
        # QA60: bit 10 = opaque clouds, bit 11 = cirrus. QA60 is reported as
        # unpopulated for part of the archive (roughly 2022-01 to 2024-02),
        # so missing values are treated as clear and the SCL band below does
        # the actual masking in that period.
        qa = image.select('QA60').unmask(0)
        qa_clear = qa.bitwiseAnd(1 << 10).eq(0).And(qa.bitwiseAnd(1 << 11).eq(0))

        # Scene Classification: 3 = cloud shadow, 8 = cloud (medium prob.),
        # 9 = cloud (high prob.), 10 = thin cirrus
        scl = image.select('SCL')
        scl_clear = scl.neq(3).And(scl.neq(8)).And(scl.neq(9)).And(scl.neq(10))

        return image.updateMask(qa_clear.And(scl_clear)).select(bands)

    # Apply cloud masking and create median composite
    composite = s2.map(mask_clouds).median().clip(aoi)

    return composite

# Median composite for the configured AOI and season
s2_composite = get_sentinel2_composite(
    aoi=aoi,
    date_range=config['date_range'],
    cloud_threshold=config['cloud_threshold'],
    bands=config['bands'],
)

# True-colour display (B4/B3/B2) with a 0-3000 stretch
vis_params = dict(bands=['B4', 'B3', 'B2'], min=0, max=3000, gamma=1.4)
Map.addLayer(s2_composite, vis_params, 'Sentinel-2 Composite')
Map

### Step 3.4: Extract Patches and Labels

This is a simplified version. For production use, consider using TorchGeo or custom data loaders.

In [None]:
def extract_patches_from_labels(image, labels_gdf, patch_size, class_field='class',
                                scale=10, bands=None):
    """
    Extract image patches centered on training label polygons.

    For each polygon:
    1. Take its centroid (lon/lat in WGS84)
    2. Download a patch_size x patch_size pixel window around it at `scale`
    3. Assign the polygon's class label to the patch

    Args:
        image: Earth Engine image
        labels_gdf: GeoDataFrame with training polygons (reprojected to
            EPSG:4326 if it carries a different CRS)
        patch_size: Size of patches in pixels
        class_field: Field name containing class labels
        scale: Pixel size in metres for the download. It is set explicitly
            because a median composite has no meaningful native projection
            (Earth Engine reports a 1-degree default). 10 m matches the
            finest Sentinel-2 bands.
        bands: Band names to download; defaults to all bands of `image`

    Returns:
        patches: array (N, num_bands, patch_size, patch_size)
        labels: array (N,) of class labels
    """
    patches = []
    labels = []

    if bands is None:
        bands = image.bandNames().getInfo()

    # Centroid x/y are used as lon/lat below, so polygons must be in WGS84
    if labels_gdf.crs is not None and labels_gdf.crs.to_epsg() != 4326:
        labels_gdf = labels_gdf.to_crs(epsg=4326)

    print(f"Extracting patches at {scale}m resolution...")

    # Request one extra pixel on every side and centre-crop afterwards, so a
    # grid misalignment cannot leave the patch one pixel short.
    half_size_m = (patch_size / 2 + 1) * scale
    metres_per_degree = 111320.0  # approximate length of one degree of latitude
    total = len(labels_gdf)

    for count, (idx, row) in enumerate(labels_gdf.iterrows(), start=1):
        centroid = row.geometry.centroid
        lon, lat = centroid.x, centroid.y

        # The region is defined in degrees, so convert the metre half-width
        # (longitude degrees shrink with cos(latitude)).
        half_lat = half_size_m / metres_per_degree
        half_lon = half_size_m / (metres_per_degree * np.cos(np.radians(lat)))
        region = ee.Geometry.Rectangle(
            [lon - half_lon, lat - half_lat, lon + half_lon, lat + half_lat]
        )

        try:
            patch_array = geemap.ee_to_numpy(
                image,
                region=region,
                scale=scale,
                bands=bands
            )
        except Exception as e:
            print(f"Error extracting patch {idx}: {e}")
            continue

        if patch_array is not None and patch_array.ndim == 3:
            # geemap returns (rows, cols, bands); the model expects (bands, rows, cols)
            patch_array = np.transpose(patch_array, (2, 0, 1))
            height, width = patch_array.shape[1:]
            if height >= patch_size and width >= patch_size:
                top = (height - patch_size) // 2
                left = (width - patch_size) // 2
                patches.append(
                    patch_array[:, top:top + patch_size, left:left + patch_size]
                )
                labels.append(row[class_field])

        if count % 10 == 0:
            print(f"Processed {count}/{total} polygons")

    print(f"\nSuccessfully extracted {len(patches)} patches")
    if not patches:
        # Keep a 4-D (0, C, H, W) shape for an empty result
        return (np.empty((0, len(bands), patch_size, patch_size), dtype=np.float32),
                np.array(labels))
    return np.array(patches), np.array(labels)

# Extract patches (only possible once the label file exists)
if not os.path.exists(labels_path):
    print("Skipping patch extraction - no training labels found")
else:
    patches, patch_labels = extract_patches_from_labels(
        s2_composite,
        training_labels,
        config['patch_size']
    )

    print(f"\nPatches shape: {patches.shape}")
    print(f"Labels shape: {patch_labels.shape}")
    print(f"\nClass distribution in patches:")
    class_ids, class_counts = np.unique(patch_labels, return_counts=True)
    for class_id, n_patches in zip(class_ids, class_counts):
        print(f"  {config['classes'][class_id]}: {n_patches}")

### Step 3.5: Normalize Data

In [None]:
def normalize_patches(patches):
    """
    Standardise each band to zero mean and unit variance.

    Statistics are pooled over samples and pixels (axes N, H, W), so there
    is one mean and one std per channel.

    Args:
        patches: (N, C, H, W) array

    Returns:
        normalized: (N, C, H, W) array
        stats: {'mean': [...], 'std': [...]} with one float per band
    """
    reduce_axes = (0, 2, 3)
    band_mean = patches.mean(axis=reduce_axes, keepdims=True)
    band_std = patches.std(axis=reduce_axes, keepdims=True)

    # The small epsilon guards against constant bands (std == 0)
    normalized = (patches - band_mean) / (band_std + 1e-8)

    stats = {
        'mean': band_mean.ravel().tolist(),
        'std': band_std.ravel().tolist(),
    }
    return normalized, stats

if 'patches' in locals():
    patches_normalized, normalization_stats = normalize_patches(patches)

    print("Normalization statistics:")
    band_means, band_stds = normalization_stats['mean'], normalization_stats['std']
    for i, band in enumerate(config['bands']):
        print(f"  {band}: mean={band_means[i]:.2f}, std={band_stds[i]:.2f}")

    # Persist the statistics so inference can apply the same scaling
    with open('config/normalization_stats.json', 'w') as f:
        json.dump(normalization_stats, f, indent=2)

### Step 3.6: Train/Validation Split (random; not yet spatially blocked)

In [None]:
def spatial_train_test_split(patches, labels, train_ratio=0.7, random_state=42):
    """
    Randomly split patches into training and validation sets.

    Despite the name, this is a plain random split. Samples are shuffled
    without regard to location, so patches from neighbouring or overlapping
    areas can end up on both sides and inflate validation scores. For a true
    spatial split, group patches into spatial blocks or clusters first and
    assign whole groups to one side.

    A private RandomState seeded with `random_state` drives the shuffle. The
    split is reproducible and identical to seeding the global NumPy RNG with
    the same value, but NumPy's global random state is left untouched.

    Args:
        patches: (N, C, H, W) array
        labels: (N,) array
        train_ratio: Fraction for training
        random_state: Random seed

    Returns:
        train_patches, train_labels, val_patches, val_labels
    """
    rng = np.random.RandomState(random_state)

    # Random permutation
    indices = rng.permutation(len(patches))
    split_idx = int(len(patches) * train_ratio)

    train_idx, val_idx = indices[:split_idx], indices[split_idx:]

    return (
        patches[train_idx], labels[train_idx],
        patches[val_idx], labels[val_idx]
    )

if 'patches_normalized' in locals():
    train_patches, train_labels, val_patches, val_labels = spatial_train_test_split(
        patches_normalized,
        patch_labels,
        train_ratio=config['train_split'],
        random_state=SEED
    )

    print(f"Training set: {len(train_patches)} patches")
    print(f"Validation set: {len(val_patches)} patches")

    # Same per-class report for both splits
    split_reports = (
        ("\nTraining class distribution:", train_labels),
        ("\nValidation class distribution:", val_labels),
    )
    for heading, split_labels in split_reports:
        print(heading)
        unique, counts = np.unique(split_labels, return_counts=True)
        for cls, count in zip(unique, counts):
            share = 100 * count / len(split_labels)
            print(f"  {config['classes'][cls]}: {count} ({share:.1f}%)")

    # Save processed data
    arrays_to_save = (
        ('data/processed/train_patches.npy', train_patches),
        ('data/processed/train_labels.npy', train_labels),
        ('data/processed/val_patches.npy', val_patches),
        ('data/processed/val_labels.npy', val_labels),
    )
    for out_path, array in arrays_to_save:
        np.save(out_path, array)

    print("\nProcessed data saved to data/processed/")

### Step 3.7: Visualize Example Patches

In [None]:
def visualize_patches(patches, labels, class_names, num_examples=4):
    """
    Plot random example patches, one row per class.

    Args:
        patches: (N, C, H, W) array; channels 2, 1, 0 are shown as R, G, B
            (B4, B3, B2 with the configured band order)
        labels: (N,) array of class ids
        class_names: mapping class id -> display name; ids 0..len-1 are plotted
        num_examples: number of columns (examples shown per class)

    Saves the figure to figures/week3_example_patches.png and shows it.
    """
    num_classes = len(class_names)
    # squeeze=False keeps `axes` 2-D even with a single row or column
    fig, axes = plt.subplots(num_classes, num_examples,
                             figsize=(3 * num_examples, 3 * num_classes),
                             squeeze=False)

    for class_id in range(num_classes):
        # Get indices for this class
        class_indices = np.where(labels == class_id)[0]

        # Select random examples (all of them if there are too few)
        if len(class_indices) >= num_examples:
            selected = np.random.choice(class_indices, num_examples, replace=False)
        else:
            selected = class_indices

        for i in range(num_examples):
            ax = axes[class_id, i]
            ax.axis('off')
            if i >= len(selected):
                continue  # not enough examples of this class: leave the cell blank

            patch = patches[selected[i]]

            # Create RGB visualization (bands 2, 1, 0 = R, G, B)
            rgb = np.stack([patch[2], patch[1], patch[0]], axis=-1)

            # Min-max stretch to [0, 1] for display
            rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)

            ax.imshow(rgb)
            ax.set_title(f"{class_names[class_id]}")

    plt.tight_layout()
    plt.savefig('figures/week3_example_patches.png', dpi=150, bbox_inches='tight')
    plt.show()

# Preview four random training patches per class
if 'train_patches' in locals():
    visualize_patches(patches=train_patches, labels=train_labels,
                      class_names=config['classes'], num_examples=4)

## Part 4: Model Architecture

### Step 4.1: Define SimpleCNN Architecture

In [None]:
class SimpleCNN(nn.Module):
    """
    Simple CNN for land-cover classification.

    Architecture:
    - 3 convolutional blocks (conv -> batch norm -> ReLU -> max pool)
    - Global average pooling
    - Dropout
    - Final classification layer

    On every forward pass the post-ReLU activations of block 3 (before its
    max pool) are stored in ``self.activations``. When they require grad, a
    hook also stores their gradient in ``self.gradients`` during backward,
    which is what Grad-CAM reads.

    Args:
        num_bands: number of input channels
        num_classes: number of output logits
        channels: output channels of the three conv blocks (indexable, 3 items)
        dropout: dropout probability applied before the final linear layer
    """
    def __init__(self, num_bands=10, num_classes=4, channels=(32, 64, 128), dropout=0.5):
        # Immutable default: a list default would be shared by all instances
        super().__init__()

        # Block 1
        self.conv1 = nn.Conv2d(num_bands, channels[0], kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Block 2
        self.conv2 = nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels[1])
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Block 3
        self.conv3 = nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(channels[2])
        self.relu3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Global average pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Classifier
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(channels[2], num_classes)

        # For storing activations (used in Grad-CAM)
        self.gradients = None
        self.activations = None

    def forward(self, x):
        # Block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        # Block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        # Block 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        # Store activations for Grad-CAM; the hook only exists when autograd
        # is tracking this tensor (e.g. not under torch.no_grad()).
        if x.requires_grad:
            x.register_hook(self.save_gradient)
        self.activations = x

        x = self.pool3(x)

        # Global pooling and classification
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc(x)

        return x

    def save_gradient(self, grad):
        """Hook to save gradients for Grad-CAM."""
        self.gradients = grad

    def get_activations_gradient(self):
        """Return the gradient captured by the hook on the last backward pass."""
        return self.gradients

    def get_activations(self):
        """Return the block-3 activations from the last forward pass."""
        return self.activations

# Build the network from the configuration
model_cfg = config['model']
model = SimpleCNN(
    num_bands=config['num_bands'],
    num_classes=config['num_classes'],
    channels=model_cfg['channels'],
    dropout=model_cfg['dropout'],
)

# Place parameters on the configured device
device = torch.device(config['device'])
model = model.to(device)

print(model)
print(f"\nModel moved to: {device}")

# Parameter counts (trainable = requires_grad)
model_params = list(model.parameters())
total_params = sum(p.numel() for p in model_params)
trainable_params = sum(p.numel() for p in model_params if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

### Step 4.2: Test Model with Dummy Input

In [None]:
# Smoke-test the forward pass with random data.
# Run in eval mode without autograd: in train mode, BatchNorm would fold the
# statistics of this random batch into its running mean/var before training.
was_training = model.training
model.eval()
with torch.no_grad():
    dummy_input = torch.randn(2, config['num_bands'], config['patch_size'], config['patch_size']).to(device)
    dummy_output = model(dummy_input)
model.train(was_training)  # restore the previous mode

print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {dummy_output.shape}")
print(f"Output (logits): {dummy_output}")

## Part 5: Training Setup

### Step 5.1: Create PyTorch Datasets and DataLoaders

In [None]:
class LandCoverDataset(Dataset):
    """In-memory PyTorch dataset pairing land-cover patches with class ids.

    Args:
        patches: (N, C, H, W) numpy array, stored as a float32 tensor
        labels: (N,) numpy array, stored as an int64 tensor
    """

    def __init__(self, patches, labels):
        self.patches = torch.tensor(patches, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return self.patches.shape[0]

    def __getitem__(self, idx):
        image = self.patches[idx]
        target = self.labels[idx]
        return image, target

# Wrap the splits in datasets/loaders (only if the split step produced data)
if 'train_patches' in locals():
    train_dataset = LandCoverDataset(train_patches, train_labels)
    val_dataset = LandCoverDataset(val_patches, val_labels)

    # Both loaders share batch size, worker count and pinning; only the
    # shuffling differs (training order changes each epoch, validation does not).
    loader_options = {
        'batch_size': config['training']['batch_size'],
        'num_workers': 2,
        'pin_memory': torch.cuda.is_available(),
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

    # Pull one batch to confirm tensor shapes
    sample_batch = next(iter(train_loader))
    sample_images, sample_targets = sample_batch
    print(f"\nSample batch shapes:")
    print(f"  Images: {sample_images.shape}")
    print(f"  Labels: {sample_targets.shape}")

### Step 5.2: Define Loss Function and Optimizer

In [None]:
# Loss function: cross-entropy over the model's raw logits
criterion = nn.CrossEntropyLoss()

# Optimizer (any name other than 'Adam' / 'AdamW' falls back to SGD with momentum)
if config['training']['optimizer'] == 'Adam':
    optimizer = optim.Adam(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training']['weight_decay']
    )
elif config['training']['optimizer'] == 'AdamW':
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training']['weight_decay']
    )
else:
    optimizer = optim.SGD(
        model.parameters(),
        lr=config['training']['learning_rate'],
        momentum=0.9,
        weight_decay=config['training']['weight_decay']
    )

# Learning rate scheduler
if config['training']['scheduler'] == 'ReduceLROnPlateau':
    # Halve the LR after `scheduler_patience` epochs without the monitored
    # value (validation loss in the training loop) decreasing.
    # `verbose=True` is deliberately not passed: the argument is deprecated
    # since PyTorch 2.2 and slated for removal. Log
    # optimizer.param_groups[0]['lr'] to track LR changes instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=config['training']['scheduler_patience'],
    )
else:
    # Any other value: decay the LR 10x every 10 epochs
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=10,
        gamma=0.1
    )

print(f"Loss function: {criterion}")
print(f"Optimizer: {optimizer}")
print(f"Scheduler: {scheduler}")

## Part 6: Training Loop

### Step 6.1: Define Training and Validation Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        (mean loss per sample, accuracy in percent) over the whole epoch.
        The loss weighting assumes ``criterion`` returns a per-batch mean.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()

        # Re-weight the batch loss by batch size so the epoch mean is per sample
        loss_sum += loss.item() * batch_x.size(0)
        batch_pred = logits.max(dim=1).indices
        n_seen += batch_y.size(0)
        n_correct += batch_pred.eq(batch_y).sum().item()

    return loss_sum / n_seen, 100. * n_correct / n_seen

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Returns:
        (mean loss per sample, accuracy in percent, predicted ids, true ids).
        The last two are 1-D numpy arrays in loader order.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            logits = model(batch_x)
            loss = criterion(logits, batch_y)
            batch_pred = logits.max(dim=1).indices

            loss_sum += loss.item() * batch_x.size(0)
            n_seen += batch_y.size(0)
            n_correct += batch_pred.eq(batch_y).sum().item()

            pred_chunks.extend(batch_pred.cpu().numpy())
            label_chunks.extend(batch_y.cpu().numpy())

    mean_loss = loss_sum / n_seen
    accuracy_pct = 100. * n_correct / n_seen
    return mean_loss, accuracy_pct, np.array(pred_chunks), np.array(label_chunks)

print("Training and validation functions defined")

### Step 6.2: Run Training Loop

In [None]:
# Per-epoch metric history (keys in the order the metrics are produced)
history = {name: [] for name in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}

# Checkpointing / early-stopping state
best_val_loss = float('inf')
best_val_acc = 0.0
epochs_without_improvement = 0
early_stopping_patience = config['training']['early_stopping_patience']

num_epochs = config['training']['num_epochs']
separator = "=" * 80

print(f"Starting training for {num_epochs} epochs...\n")
print(separator)

if 'train_loader' not in locals():
    print("Skipping training - no data loaded")
else:
    # ReduceLROnPlateau must be given the monitored value; other schedulers just step
    plateau_scheduler = config['training']['scheduler'] == 'ReduceLROnPlateau'

    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_preds, val_true = validate(model, val_loader, criterion, device)

        if plateau_scheduler:
            scheduler.step(val_loss)
        else:
            scheduler.step()

        epoch_metrics = {
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
        }
        for name, value in epoch_metrics.items():
            history[name].append(value)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Checkpoint whenever validation loss strictly improves
        if val_loss < best_val_loss:
            best_val_loss, best_val_acc = val_loss, val_acc
            epochs_without_improvement = 0
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                    'config': config,
                },
                config['paths']['model_save'],
            )
            print(f"  → New best model saved! (Val Loss: {val_loss:.4f})")
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= early_stopping_patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            print(f"No improvement for {early_stopping_patience} epochs")
            break

        print(separator)

    print(f"\nTraining complete!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")

    # Persist the curves for the report
    with open('reports/week3_training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

### Step 6.3: Visualize Training Curves

In [None]:
if 'history' in locals() and len(history['train_loss']) > 0:
    # Epoch numbers start at 1 to match the training log ("Epoch 1/N")
    epoch_axis = range(1, len(history['train_loss']) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Loss curves
    ax1.plot(epoch_axis, history['train_loss'], label='Train Loss', linewidth=2)
    ax1.plot(epoch_axis, history['val_loss'], label='Val Loss', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)

    # Accuracy curves
    ax2.plot(epoch_axis, history['train_acc'], label='Train Accuracy', linewidth=2)
    ax2.plot(epoch_axis, history['val_acc'], label='Val Accuracy', linewidth=2)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=11)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('figures/week3_training_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("Training curves saved to figures/week3_training_curves.png")

## Part 7: Model Evaluation

### Step 7.1: Load Best Model

In [None]:
# Load best model checkpoint
if os.path.exists(config['paths']['model_save']):
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one
    checkpoint = torch.load(config['paths']['model_save'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
    print(f"Validation loss: {checkpoint['val_loss']:.4f}")
    print(f"Validation accuracy: {checkpoint['val_acc']:.2f}%")
else:
    print("No saved model found")

### Step 7.2: Compute Detailed Metrics

In [None]:
if 'val_loader' in locals():
    # Get predictions on validation set
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_probs.extend(probs.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Score every configured class in class-id order. Without `labels=`,
    # sklearn only reports classes that occur in all_labels or all_preds. A
    # class missing from this split would then shorten the per-class arrays,
    # and they would no longer line up with the class names below.
    class_ids = list(range(config['num_classes']))

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels, all_preds, labels=class_ids, average=None, zero_division=0
    )

    # Create metrics dataframe
    metrics_df = pd.DataFrame({
        'Class': [config['classes'][i] for i in class_ids],
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1,
        'Support': support
    })

    print("\nPer-Class Metrics:")
    print(metrics_df.to_string(index=False))

    # Save metrics
    metrics_df.to_csv('reports/week3_metrics.csv', index=False)
    print("\nMetrics saved to reports/week3_metrics.csv")

    # Overall metrics (macro F1 is averaged over all configured classes)
    print(f"\nOverall Accuracy: {accuracy*100:.2f}%")
    print(f"Macro-averaged F1: {f1.mean():.3f}")
    print(f"Weighted-averaged F1: {np.average(f1, weights=support):.3f}")

### Step 7.3: Confusion Matrix

In [None]:
if 'all_preds' in locals():
    class_ids = list(range(config['num_classes']))
    class_names = [config['classes'][i] for i in class_ids]

    # labels= fixes the matrix to num_classes x num_classes in class-id order,
    # so it always lines up with the tick labels even if a class is absent
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    # Normalize by row (true labels); rows with no samples stay 0 instead of NaN
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm, row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals > 0
    )

    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Absolute counts
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax1,
                xticklabels=class_names,
                yticklabels=class_names)
    ax1.set_xlabel('Predicted', fontsize=12)
    ax1.set_ylabel('Actual', fontsize=12)
    ax1.set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')

    # Normalized
    sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues', ax=ax2,
                xticklabels=class_names,
                yticklabels=class_names)
    ax2.set_xlabel('Predicted', fontsize=12)
    ax2.set_ylabel('Actual', fontsize=12)
    ax2.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.savefig('figures/week3_confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("Confusion matrix saved to figures/week3_confusion_matrix.png")

### Step 7.4: Visualize Predictions

In [None]:
def visualize_predictions(model, val_dataset, device, num_examples=12):
    """
    Plot model predictions for randomly chosen validation samples.

    Each panel shows the sample as an RGB composite (channels 2, 1, 0). The
    title is green when the prediction is correct and red otherwise, and
    includes the softmax confidence of the predicted class. Class names come
    from the global ``config['classes']``.

    Args:
        model: trained classifier returning logits
        val_dataset: dataset yielding (image (C, H, W) tensor, label tensor)
        device: device the model lives on
        num_examples: number of panels; capped at len(val_dataset)

    Saves the figure to figures/week3_predictions.png and shows it.
    """
    model.eval()

    # Sampling without replacement cannot draw more items than exist
    num_examples = min(num_examples, len(val_dataset))
    if num_examples == 0:
        print("No validation samples to visualize")
        return

    # Select random examples
    indices = np.random.choice(len(val_dataset), num_examples, replace=False)

    # Grid with up to 4 columns and as many rows as needed (3x4 for 12)
    n_cols = min(4, num_examples)
    n_rows = int(np.ceil(num_examples / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows),
                             squeeze=False)
    axes = axes.flatten()
    for unused_ax in axes[num_examples:]:
        unused_ax.axis('off')

    with torch.no_grad():
        for i, idx in enumerate(indices):
            image, true_label = val_dataset[idx]
            # Labels come back as 0-d tensors. Tensors hash by identity, so
            # convert to int before using them as dict keys.
            true_label = int(true_label)

            # Predict
            output = model(image.unsqueeze(0).to(device))
            prob = torch.softmax(output, dim=1)
            pred_label = output.argmax(1).item()
            confidence = prob[0, pred_label].item()

            # (C, H, W) tensor -> (H, W, 3) array with channels 2, 1, 0 as R, G, B
            rgb = image[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()
            rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)

            # Plot
            axes[i].imshow(rgb)

            # Title with true and predicted labels
            true_class = config['classes'][true_label]
            pred_class = config['classes'][pred_label]

            if true_label == pred_label:
                color = 'green'
                title = f"✓ {pred_class}\n({confidence:.2f})"
            else:
                color = 'red'
                title = f"✗ True: {true_class}\nPred: {pred_class} ({confidence:.2f})"

            axes[i].set_title(title, fontsize=10, color=color, fontweight='bold')
            axes[i].axis('off')

    plt.tight_layout()
    plt.savefig('figures/week3_predictions.png', dpi=150, bbox_inches='tight')
    plt.show()

# Show a 3x4 grid of validation predictions
if 'val_dataset' in locals():
    visualize_predictions(model=model, val_dataset=val_dataset,
                          device=device, num_examples=12)
    print("Prediction visualizations saved to figures/week3_predictions.png")

## Part 8: Model Interpretation

### Step 8.1: Visualize Learned Filters

In [None]:
def visualize_conv_filters(model, layer_name='conv1', num_filters=16):
    """
    Plot the learned kernels of one convolutional layer as grayscale images.

    Each kernel is averaged over its input channels to give a single 2-D
    image. The figure is saved to figures/week3_filters_<layer_name>.png and
    shown.

    Args:
        model: Module that exposes the layer as attribute ``layer_name``.
        layer_name: Name of a conv layer whose ``weight`` is indexed here as
            (out_channels, in_channels, kH, kW). This assumes an nn.Conv2d.
        num_filters: Maximum number of filters to show. Capped at the layer's
            number of output channels. Panels are laid out four per row.
    """
    layer = getattr(model, layer_name)

    # Detach from autograd and move to host memory for plotting.
    kernels = layer.weight.detach().cpu().numpy()

    n_show = min(num_filters, kernels.shape[0])
    if n_show < 1:
        print(f"No filters to visualize for {layer_name}.")
        return

    # Size the grid from n_show (4 columns, 2.5 in per panel) instead of a
    # fixed 4x4 layout. The old layout raised IndexError for num_filters > 16
    # and left framed empty panels when there were fewer filters.
    n_cols = min(4, n_show)
    n_rows = (n_show + n_cols - 1) // n_cols
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False
    )
    axes = axes.flatten()

    for i in range(n_show):
        # Average across input channels
        axes[i].imshow(kernels[i].mean(axis=0), cmap='gray')
        axes[i].set_title(f'Filter {i}', fontsize=10)

    # Turn off ticks/frames everywhere, including unused trailing cells.
    for ax in axes:
        ax.axis('off')

    plt.suptitle(f'Learned Filters from {layer_name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'figures/week3_filters_{layer_name}.png', dpi=150, bbox_inches='tight')
    plt.show()

# Visualize filters from the first conv layer. This cell is guarded like the
# other plotting cells, so it does not raise NameError when no model has been
# defined in this session.
if 'model' in locals():
    visualize_conv_filters(model, 'conv1', num_filters=16)
    print("Filter visualizations saved to figures/week3_filters_conv1.png")

### Step 8.2: Grad-CAM Activation Maps

In [None]:
def generate_gradcam(model, image, target_class, device):
    """
    Generate a Grad-CAM activation map for a single image.

    The model must provide two methods:
    - ``get_activations()`` returns the feature maps from the most recent
      forward pass;
    - ``get_activations_gradient()`` returns their gradients from the most
      recent backward pass.
    Both are indexed here as (N, K, h, w). Presumably they are backed by
    hooks defined on the model (verify against the model class).

    Args:
        model: CNN exposing the two accessor methods above.
        image: Input image tensor of shape (C, H, W). A batched
            (1, C, H, W) tensor is also accepted.
        target_class: Index of the class whose logit is explained
            (int or 0-dim integer tensor).
        device: Device to run on.

    Returns:
        cam: NumPy activation map of shape (H, W), scaled to [0, 1].
    """
    model.eval()

    if image.dim() == 3:
        image = image.unsqueeze(0)
    # Work on a detached copy so requires_grad is never flipped on the
    # caller's tensor. Input gradients keep the graph alive even if every
    # model parameter is frozen.
    image = image.detach().to(device).requires_grad_(True)
    target_class = int(target_class)

    # Backprop is required, so force grad mode on. Callers such as
    # prediction loops may be running under torch.no_grad().
    with torch.enable_grad():
        output = model(image)
        model.zero_grad()
        output[0, target_class].backward()

    # Get gradients and activations recorded by the model
    gradients = model.get_activations_gradient()
    activations = model.get_activations()

    # Channel weights = spatially averaged gradients
    weights = gradients.mean(dim=(2, 3), keepdim=True)
    cam = (weights * activations).sum(dim=1, keepdim=True)

    # Keep positive evidence only, then scale to [0, 1]
    cam = torch.relu(cam)
    cam = cam / (cam.max() + 1e-8)

    # Resize to the input's spatial size
    cam = torch.nn.functional.interpolate(
        cam,
        size=tuple(image.shape[-2:]),
        mode='bilinear',
        align_corners=False
    )

    # Clear parameter gradients accumulated by this explanation pass so they
    # cannot leak into a later optimizer step.
    model.zero_grad()

    return cam[0, 0].detach().cpu().numpy()

def visualize_gradcam_examples(model, val_dataset, device, num_examples=12):
    """
    Show validation images beside their Grad-CAM overlays, one row per class.

    For every class, up to ``num_examples // num_classes`` (at least 1)
    images of that class are sampled without replacement. Each sample uses
    two panels: the RGB rendering, and the same rendering overlaid with the
    Grad-CAM map for its label. The figure is saved to
    figures/week3_gradcam_examples.png and shown.

    Args:
        model: CNN accepted by ``generate_gradcam``.
        val_dataset: Dataset returning (image, label) pairs. It also needs a
            ``labels`` attribute convertible to a NumPy array. Images are
            presumably (C, H, W) tensors with at least three channels.
        device: Device to run the model on.
        num_examples: Total example budget, split evenly across classes. The
            default of 12 with 4 classes gives 3 examples per class.

    Reads ``config['num_classes']`` and ``config['classes']`` from the global
    config.
    """
    num_classes = config['num_classes']
    if num_classes < 1:
        print("No classes configured; nothing to visualize.")
        return
    per_class = max(1, num_examples // num_classes)

    # Sample examples per class. A class with fewer than per_class examples
    # contributes what it has; previously such classes were silently skipped.
    all_labels = np.asarray(val_dataset.labels)
    indices_per_class = {}
    for cls in range(num_classes):
        class_indices = np.where(all_labels == cls)[0]
        if len(class_indices) > 0:
            indices_per_class[cls] = np.random.choice(
                class_indices, min(per_class, len(class_indices)), replace=False
            )

    # Layout: one row per class, two panels per example. squeeze=False keeps
    # axes 2-D even with a single row. The grid is sized from the data rather
    # than a fixed 4x6 grid, which raised IndexError for more than 4 classes.
    fig, axes = plt.subplots(
        num_classes,
        2 * per_class,
        figsize=(6 * per_class, 3 * num_classes),
        squeeze=False,
    )

    for cls, class_sample in indices_per_class.items():
        for i, idx in enumerate(class_sample):
            image, label = val_dataset[idx]

            # Generate Grad-CAM
            cam = generate_gradcam(model, image, label, device)

            # Channels 2, 1, 0 are displayed as R, G, B (TODO confirm band
            # order). Build the array with torch ops: the old
            # np.stack(...).numpy() raised AttributeError.
            rgb = image[[2, 1, 0]].permute(1, 2, 0).detach().cpu().numpy()
            rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)

            ax_rgb = axes[cls, 2 * i]
            ax_cam = axes[cls, 2 * i + 1]

            # Plot original
            ax_rgb.imshow(rgb)
            ax_rgb.set_title(f"{config['classes'][cls]}", fontsize=10)

            # Plot with Grad-CAM overlay
            ax_cam.imshow(rgb)
            ax_cam.imshow(cam, cmap='jet', alpha=0.5)
            ax_cam.set_title('Grad-CAM', fontsize=10)

    # Hide ticks/frames on every panel, including cells left unused by
    # classes with few (or no) validation examples.
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('figures/week3_gradcam_examples.png', dpi=150, bbox_inches='tight')
    plt.show()

# Only draw Grad-CAM examples when a validation split exists in the session.
if 'val_dataset' in locals():
    visualize_gradcam_examples(
        model=model,
        val_dataset=val_dataset,
        device=device,
        num_examples=12,
    )
    print("Grad-CAM visualizations saved to figures/week3_gradcam_examples.png")

## Part 9: Summary and Next Steps

### Summary of Results

In [None]:
print("=" * 80)
print("WEEK 3 LAB SUMMARY")
print("=" * 80)

# Each section is guarded because the cells that define these names may not
# have run in this session.
if 'best_val_acc' in locals():
    print("\n✓ Model Training Complete")
    print(f"  - Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"  - Best Validation Loss: {best_val_loss:.4f}")
    print(f"  - Total Parameters: {total_params:,}")

if 'metrics_df' in locals():
    print("\n✓ Evaluation Metrics Computed")
    print(f"  - Overall Accuracy: {accuracy*100:.2f}%")
    print(f"  - Macro F1-Score: {f1.mean():.3f}")
    print("  - Per-class metrics saved to reports/week3_metrics.csv")

# List each expected figure and flag any file that is missing. The
# prediction and Grad-CAM cells are skipped when no `val_dataset` is
# defined, so those figures may not exist.
summary_figures = [
    ("Training curves", "figures/week3_training_curves.png"),
    ("Confusion matrix", "figures/week3_confusion_matrix.png"),
    ("Predictions", "figures/week3_predictions.png"),
    ("Learned filters", "figures/week3_filters_conv1.png"),
    ("Grad-CAM", "figures/week3_gradcam_examples.png"),
]
print("\n✓ Visualizations Generated")
for fig_label, fig_path in summary_figures:
    missing_note = "" if Path(fig_path).exists() else " (not found)"
    print(f"  - {fig_label}: {fig_path}{missing_note}")

print("\n✓ Model and Configuration Saved")
print(f"  - Model checkpoint: {config['paths']['model_save']}")
print(f"  - Configuration: {config['paths']['config_save']}")

print("\n" + "=" * 80)
print("Next Steps:")
print("=" * 80)
print("1. Write model interpretation memo (reports/Week3_Model_Interpretation.md)")
print("2. Complete Ethics Thread reflection (reports/Week3_Ethics_Thread.md)")
print("3. Write weekly reflection (reports/Week3_Reflection.md)")
print("4. Update GitHub repository README")
print("5. Commit all materials to GitHub")
print("=" * 80)

---

## Congratulations!

You've successfully built, trained, and interpreted your first CNN for land-cover classification. This forms the foundation for Week 4's transfer learning and data fusion experiments.

**Key Achievements:**
- ✅ Prepared geospatial training data with spatial splitting
- ✅ Implemented a CNN architecture in PyTorch
- ✅ Trained a model with proper monitoring and early stopping
- ✅ Evaluated performance with multiple metrics
- ✅ Interpreted learned features through visualization
- ✅ Documented a reproducible training pipeline

**Remember to complete:**
1. Model interpretation memo (500-750 words)
2. Ethics Thread reflection (450-600 words)
3. Weekly reflection (300-400 words)
4. GitHub repository update

See you in Week 4 for transfer learning and multi-source fusion!