# 01 – Green Roof Detection (U-Net Semantic Segmentation)

This notebook walks through:
1. Dataset preparation (GeoTIFF patches + binary labels)
2. Model building (U-Net with pre-trained ResNet34 encoder)
3. Training with cross-entropy loss and cosine-annealing LR schedule
4. Inference on a full orthophoto tile
5. Visual inspection of predictions

In [None]:
import os
import random
import sys
sys.path.insert(0, '..')

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch

from src.detection.dataset import GreenRoofDataset
from src.detection.model import build_unet
from src.detection.train import train_model
from src.detection.predict import predict_raster

## 1  Paths – adapt to your local data layout

In [None]:
# All locations are relative paths, so they resolve against the kernel's
# current working directory.

# --- Training inputs ---
# GeoTIFF patches (RGB or RGB+NIR)
IMAGE_DIR = '../data/raw/images'
# Binary label GeoTIFFs (same filenames as the image patches)
LABEL_DIR = '../data/labels'

# --- Training outputs ---
# Where checkpoints are saved
OUTPUT_DIR = '../data/results/models'

# --- Inference ---
# Full orthophoto used as inference input
ORTHO_PATH = '../data/raw/berlin_ortho_2025.tif'
# Output prediction raster
PRED_PATH = '../data/results/prediction.tif'

## 2  Explore the dataset

In [None]:
import os

# Preview a few (image, mask) pairs so the data can be sanity-checked before
# training. When IMAGE_DIR is missing or empty the cell only prints a notice,
# which keeps "Run All" working on a machine without the data.
N_PREVIEW = 3  # upper bound on the number of sample rows drawn

if os.path.isdir(IMAGE_DIR) and os.listdir(IMAGE_DIR):
    dataset = GreenRoofDataset(
        image_dir=IMAGE_DIR,
        label_dir=LABEL_DIR,
        image_size=(256, 256),
        bands=[0, 1, 2],  # RGB
    )
    print(f'Dataset size: {len(dataset)} patches')

    # Clamp to the dataset size so a dataset with fewer than N_PREVIEW
    # patches does not fail on dataset[i].
    n_show = min(N_PREVIEW, len(dataset))
    if n_show > 0:
        # squeeze=False keeps `axes` two-dimensional even for a single row,
        # so the row-wise loop below works for any n_show >= 1.
        fig, axes = plt.subplots(
            n_show, 2, figsize=(8, 4 * n_show), squeeze=False
        )
        for i, ax_row in enumerate(axes):
            img, mask = dataset[i]
            # permute moves the first axis last, as imshow expects.
            # NOTE(review): assumes img values are already in a displayable
            # range (imshow clips floats outside [0, 1]) — confirm against
            # GreenRoofDataset's normalisation.
            ax_row[0].imshow(img.permute(1, 2, 0).numpy())
            ax_row[0].set_title(f'Image patch {i}')
            ax_row[0].axis('off')
            ax_row[1].imshow(mask.numpy(), cmap='Greens', vmin=0, vmax=1)
            ax_row[1].set_title(f'Label mask {i}')
            ax_row[1].axis('off')
        plt.tight_layout()
        plt.show()
    else:
        print('Dataset is empty – nothing to visualise.')
else:
    print('Data not yet available – skipping visualisation.')

## 3  Build the U-Net model

In [None]:
# Instantiate the segmentation network via the project's build_unet helper,
# using a ResNet34 encoder with ImageNet weights, 3 input channels and
# 2 output classes.
model = build_unet(
    encoder_name='resnet34', encoder_weights='imagenet',
    in_channels=3, num_classes=2,
)
print(model)

# Count only the parameters that will receive gradient updates.
trainable_sizes = [
    param.numel() for param in model.parameters() if param.requires_grad
]
n_params = sum(trainable_sizes)
print(f'\nTrainable parameters: {n_params:,}')

## 4  Train the model

In [None]:
# Train the U-Net and plot the per-epoch curves returned by train_model.
# The cell is skipped (with a message) when IMAGE_DIR is missing or empty.
SEED = 42  # fixed seed so repeated runs start from the same RNG state

if os.path.isdir(IMAGE_DIR) and os.listdir(IMAGE_DIR):
    # Seed every global RNG available here before training starts.
    # NOTE(review): train_model's internals are not visible from this
    # notebook; this only makes the run repeatable to the extent that it
    # draws from these global generators — confirm in src/detection/train.py.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    history = train_model(
        image_dir=IMAGE_DIR,
        label_dir=LABEL_DIR,
        output_dir=OUTPUT_DIR,
        encoder_name='resnet34',
        encoder_weights='imagenet',
        in_channels=3,
        num_classes=2,
        image_size=(256, 256),
        batch_size=8,
        num_epochs=30,
        learning_rate=1e-4,
        val_split=0.15,
    )

    # Plot training curves: losses on the left, validation IoU on the right.
    # `history` is indexed by 'train_loss', 'val_loss' and 'val_iou'.
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].plot(history['train_loss'], label='Train')
    axes[0].plot(history['val_loss'],   label='Validation')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Cross-Entropy Loss')
    axes[0].set_title('Loss curves')
    axes[0].legend()

    axes[1].plot(history['val_iou'], color='green')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Mean IoU')
    axes[1].set_title('Validation IoU')

    plt.tight_layout()
    plt.show()
else:
    print('Data not available – skipping training.')

## 5  Run inference on the full orthophoto

In [None]:
# Run predict_raster over the full orthophoto with the saved checkpoint, then
# show the RGB image next to the prediction raster. The cell is skipped (with
# a message) unless both the orthophoto and the checkpoint file exist.
CHECKPOINT = f'{OUTPUT_DIR}/best_model.pth'

if os.path.isfile(ORTHO_PATH) and os.path.isfile(CHECKPOINT):
    pred_path = predict_raster(
        image_path=ORTHO_PATH,
        checkpoint_path=CHECKPOINT,
        output_path=PRED_PATH,
        encoder_name='resnet34',
        in_channels=3,
        num_classes=2,
        patch_size=256,
        stride=128,
        batch_size=4,
    )
    print(f'Prediction saved to: {pred_path}')

    # Visualise prediction
    with rasterio.open(pred_path) as src:
        pred_data = src.read(1)

    with rasterio.open(ORTHO_PATH) as src:
        # Bands 1-3 come back band-first; move the band axis last for imshow.
        rgb = src.read([1, 2, 3]).transpose(1, 2, 0)

    # Stretch to 0-255 for display. An all-zero tile has peak == 0; dividing
    # by it would produce NaNs and an undefined uint8 cast, so show plain
    # black instead.
    peak = rgb.max()
    if peak > 0:
        rgb = (rgb / peak * 255).astype('uint8')
    else:
        rgb = np.zeros(rgb.shape, dtype='uint8')

    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    axes[0].imshow(rgb)
    axes[0].set_title('RGB Orthophoto')
    axes[0].axis('off')
    axes[1].imshow(pred_data, cmap='Greens', vmin=0, vmax=1)
    axes[1].set_title('Green Roof Prediction')
    axes[1].axis('off')
    plt.tight_layout()
    plt.show()
else:
    print('Orthophoto or checkpoint not found – skipping inference.')