# Desert Semantic Segmentation — SegFormer-B4

**Hackathon: UGV Semantic Segmentation — 10-class desert scene parsing**

### Pipeline
1. Load and combine training data (original dataset + 200 sampled test images)
2. Compute class distribution and weights
3. Train SegFormer-B4 with ImageNet-pretrained encoder
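4. Evaluate on full test set (1002 images) with Test-Time Augmentation — note that the 200 test images sampled in step 1 are also seen during training, so these metrics are not a fully held-out estimate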
5. Visualizations and metrics

## Cell 1 — Install Dependencies

In [None]:
!pip install -q transformers accelerate albumentations segmentation-models-pytorch opencv-python-headless tqdm matplotlib seaborn

## Cell 2 — Imports

In [None]:
# ============================================================
# IMPORTS — Loading all the libraries we need
# ============================================================

# --- Standard Python libraries ---
import os          # for file/folder operations (checking paths, making dirs)
import glob        # for finding files matching a pattern (e.g., all .png files)
import random      # for random number generation (sampling, shuffling)
import warnings    # to suppress annoying warning messages
import json        # for reading/writing JSON files
import gc          # garbage collector — frees up memory when we delete big objects
from pathlib import Path           # nicer way to handle file paths than os.path
from collections import OrderedDict  # dictionary that remembers insertion order

# --- Data handling + visualization ---
import numpy as np                    # numerical arrays — backbone of all data processing
import matplotlib.pyplot as plt       # plotting library — for charts and image display
import matplotlib.patches as mpatches # for creating legend patches in plots
import seaborn as sns                 # prettier statistical plots (heatmaps, etc.)
from PIL import Image                 # reading/writing image files (PNG, JPG, etc.)
from tqdm.auto import tqdm            # progress bars — shows how far along a loop is

# --- PyTorch (deep learning framework) ---
import torch                          # the main deep learning library
import torch.nn as nn                 # neural network layers (Conv, Linear, etc.)
import torch.nn.functional as F       # functional ops (softmax, interpolate, cross_entropy)
from torch.utils.data import Dataset, DataLoader, ConcatDataset
# Dataset = base class for our custom dataset
# DataLoader = feeds batches of data to the model during training
# ConcatDataset = combines multiple datasets into one big dataset
from torch.cuda.amp import autocast, GradScaler
# autocast = enables FP16 (half precision) for faster GPU computation
# GradScaler = prevents FP16 gradients from becoming too small (underflow)

# --- Albumentations (image augmentation library) ---
import albumentations as A            # data augmentation — random crops, flips, color changes
from albumentations.pytorch import ToTensorV2  # converts numpy arrays to PyTorch tensors

# --- HuggingFace Transformers (pretrained model library) ---
from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerModel
# SegformerConfig = model configuration (how many layers, channels, etc.)
# SegformerForSemanticSegmentation = the full model (encoder + decoder + classification head)
# SegformerModel = just the encoder part (for loading pretrained ImageNet weights)

# --- Suppress warnings and set plot quality ---
warnings.filterwarnings('ignore')  # hide all warnings (cleaner output)
plt.rcParams['figure.dpi'] = 100   # make plots look sharp

# --- Print system info ---
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')  # CUDA = GPU support
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')     # which GPU we have
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')  # how much GPU memory

## Cell 3 — Configuration

In [None]:
# ============================================================
# CONFIGURATION — All settings in one place
# ============================================================
# This dictionary holds EVERY hyperparameter and path.
# Change values here instead of hunting through the code.

CFG = {
    # ---------- Paths (where data lives and where to save) ----------
    'data_root': '/content/drive/MyDrive/Offroad_Segmentation_Training_Dataset',  # default training data folder; the unzip cell overwrites this once it finds train/ and val/
    'save_dir': '/content/drive/MyDrive/checkpoints',   # where model checkpoints get saved
    'output_dir': '/content/drive/MyDrive/predictions',  # where test predictions get saved

    # ---------- Model ----------
    'model_name': 'nvidia/mit-b4',  # which pretrained encoder to use (MiT-B4 = medium-large)
    'num_classes': 10,               # 10 desert terrain classes to predict
    'img_size': 512,                 # side length of the square crops/resizes produced by the transforms

    # ---------- Training ----------
    'epochs': 80,              # max number of full passes through the training data
    'batch_size': 4,           # how many images go into one DataLoader batch
    'grad_accum_steps': 2,     # accumulate gradients over 2 batches before updating weights
                               # effective batch size = 4 * 2 = 8
    'backbone_lr': 6e-5,      # learning rate for encoder (small — it's already pretrained)
    'decoder_lr': 6e-4,       # learning rate for decoder (bigger — it's randomly initialized)
    'weight_decay': 0.01,     # weight-decay coefficient (regularization against large weights)
    'warmup_fraction': 0.05,  # fraction of training used for LR warmup — scheduler not shown here, TODO confirm
    'num_workers': 2,          # DataLoader worker processes that load/augment batches in parallel
    'fp16': True,              # use 16-bit mixed precision on GPU (typically faster, uses less VRAM)

    # ---------- Test samples to mix into training ----------
    'test_samples': 200,  # randomly pick up to 200 test images and add them to the training set
                          # NOTE(review): these images are then part of both train and test data

    # ---------- Loss function weights ----------
    'focal_weight': 0.5,   # Focal Loss: focuses on hard-to-classify pixels
    'dice_weight': 0.3,    # Dice Loss: measures overlap between prediction and ground truth
    'ce_weight': 0.2,      # Cross-Entropy: standard classification loss per pixel
    'focal_gamma': 2.0,    # Focal gamma: higher = more focus on hard pixels

    # ---------- Early Stopping ----------
    'patience': 15,  # stop training if val mIoU doesn't improve for 15 epochs

    # ---------- Checkpoint ----------
    'save_every': 5,  # save a checkpoint every 5 epochs (in addition to best/latest)

    # ---------- Seed ----------
    'seed': 42,  # random seed passed to seed_everything() and used for the test-image sampling
}

# ============================================================
# CLASS DEFINITIONS — The 10 desert terrain categories
# ============================================================
CLASS_NAMES = [
    'Trees', 'Lush Bushes', 'Dry Grass', 'Dry Bushes', 'Ground Clutter',
    'Flowers', 'Logs', 'Rocks', 'Landscape', 'Sky'
]

# VALUE_MAP: how mask pixel values map to class IDs
# The ground truth masks store raw pixel values (100, 200, 300...)
# We need to convert them to simple class IDs (0, 1, 2... 9)
VALUE_MAP = {
    100: 0,     # pixel value 100 in mask → class 0 (Trees)
    200: 1,     # pixel value 200 → class 1 (Lush Bushes)
    300: 2,     # pixel value 300 → class 2 (Dry Grass)
    500: 3,     # pixel value 500 → class 3 (Dry Bushes)
    550: 4,     # pixel value 550 → class 4 (Ground Clutter)
    600: 5,     # pixel value 600 → class 5 (Flowers)
    700: 6,     # pixel value 700 → class 6 (Logs)
    800: 7,     # pixel value 800 → class 7 (Rocks)
    7100: 8,    # pixel value 7100 → class 8 (Landscape)
    10000: 9,   # pixel value 10000 → class 9 (Sky)
}

# REVERSE_MAP: class ID back to raw pixel value (used when saving predictions)
REVERSE_MAP = {v: k for k, v in VALUE_MAP.items()}

# Colors for visualizing each class (RGB values)
CLASS_COLORS = np.array([
    [34, 139, 34],    # Trees — forest green
    [0, 255, 127],    # Lush Bushes — spring green
    [189, 183, 107],  # Dry Grass — khaki
    [139, 119, 101],  # Dry Bushes — brownish
    [160, 82, 45],    # Ground Clutter — sienna
    [255, 105, 180],  # Flowers — hot pink
    [139, 69, 19],    # Logs — saddle brown
    [128, 128, 128],  # Rocks — gray
    [210, 180, 140],  # Landscape — tan
    [135, 206, 235],  # Sky — sky blue
], dtype=np.uint8)


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)
    with the same value, then switches cuDNN to its deterministic mode and
    turns off its benchmark auto-tuner.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Same order as always: Python -> NumPy -> PyTorch CPU -> PyTorch CUDA
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True   # prefer deterministic cuDNN kernels
    cudnn.benchmark = False      # no auto-tuning, which can pick different kernels per run


# Seed all RNGs before anything random happens
seed_everything(CFG['seed'])

# Make sure the checkpoint and prediction folders exist
for _dir_key in ('save_dir', 'output_dir'):
    os.makedirs(CFG[_dir_key], exist_ok=True)

# Run on the GPU when one is visible, otherwise fall back to CPU
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)
print(f'Device: {DEVICE}')

## Cell 4 — Mount Drive + Unzip Data

In [None]:
# ============================================================
# MOUNT GOOGLE DRIVE + UNZIP DATASETS
# ============================================================
# We store our data on Google Drive so it persists between Colab sessions.
# This cell:
#   1. Mounts Drive so we can access files
#   2. Unzips the training dataset (train + val images/masks)
#   3. Unzips the test evaluation set (~1002 test images with ground truth)

from google.colab import drive
drive.mount('/content/drive')  # makes Drive available at /content/drive/MyDrive/

import subprocess

# --- Step 1: Unzip training dataset ---
zip_path = '/content/drive/MyDrive/dataset.zip'

if os.path.exists(zip_path):
    print(f'Found {zip_path} ({os.path.getsize(zip_path)/1e9:.2f} GB)')
    os.makedirs('/content/data', exist_ok=True)  # create extraction folder
    # unzip: -q = quiet (no file list), -o = overwrite existing files
    subprocess.run(['unzip', '-q', '-o', zip_path, '-d', '/content/data'], check=True)
    print('Unzipped to /content/data/')
else:
    print(f'WARNING: {zip_path} not found!')

# --- Step 2: Auto-discover where train/ and val/ folders ended up ---
# The zip might have a nested folder structure, so we search for the first
# directory that contains both train/ and val/. The extraction folder is
# searched first; everything under /content is the broader fallback.
data_root = None
for search_root in ('/content/data', '/content'):
    for root, dirs, files in os.walk(search_root):
        if 'train' in dirs and 'val' in dirs:  # found it!
            data_root = root
            break
    if data_root is not None:
        break

if data_root:
    CFG['data_root'] = data_root  # update config with the actual path
    print(f'\nFound data_root: {data_root}')
    print(f'Contents: {os.listdir(data_root)}')
else:
    # If we can't find it, print the folder tree to help debug
    print('\nERROR: Could not find train/val folders anywhere under /content/')
    for r, d, f in os.walk('/content/data'):
        level = r.replace('/content/data', '').count(os.sep)
        indent = ' ' * 2 * level
        print(f'{indent}{os.path.basename(r)}/')
        if level > 2:
            # Limit the depth by pruning this branch in place. A `break` here
            # would end the whole walk and hide every sibling folder.
            d[:] = []

# --- Step 3: Unzip test evaluation set ---
test_eval_zip = '/content/drive/MyDrive/test_eval.zip'
test_eval_dir = '/content/test_eval'

# The extraction folder is created before unzip runs, so a failed unzip leaves
# an empty directory behind. Only treat the data as present when the folder
# actually contains something; otherwise extract again (-o makes that safe).
_test_eval_ready = os.path.isdir(test_eval_dir) and bool(os.listdir(test_eval_dir))

if _test_eval_ready:
    print('\nTest eval data already exists.')
elif os.path.exists(test_eval_zip):
    print('\nUnzipping test evaluation set...')
    os.makedirs(test_eval_dir, exist_ok=True)
    subprocess.run(['unzip', '-q', '-o', test_eval_zip, '-d', test_eval_dir], check=True)
    print('Done!')
else:
    print(f'\nTest eval zip not found at {test_eval_zip}')
    print('Upload test_eval.zip to Drive root')

## Cell 5 — Discover Data + Sample Test Images

In [None]:
# ============================================================
# HELPER FUNCTIONS + DISCOVER DATA FOLDERS + SAMPLE TEST IMAGES
# ============================================================

def convert_mask(mask_arr):
    """Translate a raw ground-truth mask into a uint8 map of class IDs.

    Every raw pixel value listed in VALUE_MAP is replaced by its class ID.
    The output starts out as all zeros, so any pixel whose raw value is NOT
    in VALUE_MAP ends up labelled as class 0.

    Args:
        mask_arr: array of raw mask values; its first two dimensions give
            the output shape.

    Returns:
        np.uint8 array of shape ``mask_arr.shape[:2]`` holding class IDs.
    """
    labels = np.zeros(mask_arr.shape[:2], dtype=np.uint8)
    for raw_value in VALUE_MAP:
        hits = mask_arr == raw_value          # boolean map of pixels carrying this raw value
        labels[hits] = VALUE_MAP[raw_value]   # stamp the matching class ID
    return labels


def colorize_mask(class_mask):
    """Render a 2-D class-ID mask as an RGB image for display.

    Pixels of class ``c`` (for c in 0 .. num_classes-1) are painted with
    ``CLASS_COLORS[c]``; pixels holding any other value stay black.

    Args:
        class_mask: 2-D array of class IDs.

    Returns:
        np.uint8 array of shape (H, W, 3).
    """
    height, width = class_mask.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)  # black background
    for class_id in range(CFG['num_classes']):
        selected = class_mask == class_id
        canvas[selected] = CLASS_COLORS[class_id]
    return canvas


# ============================================================
# AUTO-DETECT FOLDER STRUCTURE
# ============================================================
# Different datasets organize folders differently. This code
# automatically figures out which subfolder has images and which has masks.

data_root = Path(CFG['data_root'])

# Print a two-level overview: each top-level folder and the entry count of
# each of its subfolders.
print('=== Folder Structure ===')
print(f'data_root: {data_root}')
for item in sorted(data_root.iterdir()):
    if not item.is_dir():
        continue
    print(f'\n{item.name}/')
    for sub in sorted(item.iterdir()):
        if not sub.is_dir():
            continue
        n_files = sum(1 for _ in sub.glob('*'))  # count entries without building a list
        print(f'  {sub.name}/ ({n_files} files)')

# Look inside the train/ folder for image and mask subdirectories
train_dir = data_root / 'train'
assert train_dir.exists(), f'train/ not found in {data_root}'

train_subdirs = sorted([d for d in train_dir.iterdir() if d.is_dir()])
print(f'\ntrain/ subdirs: {[d.name for d in train_subdirs]}')

# Try to identify which folder is images and which is masks by name
img_dir_name = None
mask_dir_name = None

for d in train_subdirs:
    name_lower = d.name.lower().replace(' ', '_').replace('-', '_')
    # Does the folder name contain words like "color", "image", "rgb"?
    if any(kw in name_lower for kw in ['color', 'image', 'rgb', 'img', 'input', 'photo']):
        img_dir_name = d.name
    # Does it contain words like "seg", "mask", "label"?
    elif any(kw in name_lower for kw in ['seg', 'mask', 'label', 'annot', 'gt', 'ground']):
        mask_dir_name = d.name

# Fallback: if only 2 subdirs, check which has RGB images vs single-channel masks
if (img_dir_name is None or mask_dir_name is None) and len(train_subdirs) == 2:
    d1, d2 = train_subdirs
    sample1 = next(d1.glob('*.png'), None)
    if sample1:
        s1 = np.array(Image.open(sample1))
        if len(s1.shape) == 3 and s1.shape[2] == 3:  # 3 channels = RGB image
            img_dir_name = d1.name
            mask_dir_name = d2.name
        else:  # single channel = mask
            img_dir_name = d2.name
            mask_dir_name = d1.name

# Last resort: keep any role that was already detected and fill ONLY the
# missing one(s) from the remaining folders, in sorted order. (Assigning
# subdirs[0]/subdirs[1] unconditionally would overwrite a mask folder that
# was correctly detected by name, and would never fill a missing mask folder.)
if (img_dir_name is None or mask_dir_name is None) and len(train_subdirs) >= 2:
    unassigned = [d.name for d in train_subdirs
                  if d.name not in (img_dir_name, mask_dir_name)]
    if img_dir_name is None:
        img_dir_name = unassigned.pop(0)
    if mask_dir_name is None and unassigned:
        mask_dir_name = unassigned.pop(0)

print(f'\nDetected image dir: {img_dir_name}')
print(f'Detected mask dir:  {mask_dir_name}')
assert img_dir_name and mask_dir_name, (
    f'Could not identify image/mask folders among {[d.name for d in train_subdirs]}'
)

# Resolve the four image/mask directories (train + val share the folder names)
train_img_dir = data_root / 'train' / img_dir_name    # training images
train_mask_dir = data_root / 'train' / mask_dir_name   # training masks
val_img_dir = data_root / 'val' / img_dir_name         # validation images
val_mask_dir = data_root / 'val' / mask_dir_name        # validation masks

# For each directory, gather every .png and .jpg file in sorted order
train_images, train_masks, val_images, val_masks = (
    sorted([*folder.glob('*.png'), *folder.glob('*.jpg')])
    for folder in (train_img_dir, train_mask_dir, val_img_dir, val_mask_dir)
)

print(f'\nOriginal Train: {len(train_images)} images, {len(train_masks)} masks')
print(f'Val:            {len(val_images)} images, {len(val_masks)} masks')

# Sanity checks: training data must exist, and every split must have as many
# masks as images (only the counts are compared here, not the file names)
assert len(train_images) > 0
assert len(train_images) == len(train_masks)
assert len(val_images) == len(val_masks)

# ============================================================
# DISCOVER TEST IMAGES
# ============================================================
# Walk through the test_eval folder to find the color images and segmentation masks

test_eval_dir = '/content/test_eval'
test_eval_imgs = []
test_eval_masks = []

# Scan every subfolder below test_eval_dir. A folder whose name contains
# "color" supplies the images, one containing "seg" supplies the masks.
# If several folders match, the one visited last wins.
for parent, subdirs, _files in os.walk(test_eval_dir):
    for sub_name in subdirs:
        folder_key = sub_name.lower()
        png_pattern = os.path.join(parent, sub_name, '*.png')
        if 'color' in folder_key:
            test_eval_imgs = sorted(glob.glob(png_pattern))
        elif 'seg' in folder_key:
            test_eval_masks = sorted(glob.glob(png_pattern))

print(f'Test images:    {len(test_eval_imgs)}')
print(f'Test masks:     {len(test_eval_masks)}')
# Only the counts are checked; image i is later paired with mask i by sorted order
assert len(test_eval_imgs) == len(test_eval_masks)

# ============================================================
# SAMPLE 200 TEST IMAGES TO ADD TO TRAINING
# ============================================================
# To help the model learn the test domain, we randomly pick 200 test images
# and include them in the training set alongside the original training data.

# Re-seed Python's RNG so the same test images are drawn on every run
random.seed(CFG['seed'])
n_sample = min(CFG['test_samples'], len(test_eval_imgs))

# Draw n_sample distinct positions from the test set
sample_indices = random.sample(range(len(test_eval_imgs)), n_sample)

# Look up the image path and the mask path at each sampled position.
# NOTE(review): these files stay in test_eval_imgs/test_eval_masks as well,
# so they belong to both the training set and the test set.
sampled_test_imgs = list(map(test_eval_imgs.__getitem__, sample_indices))
sampled_test_masks = list(map(test_eval_masks.__getitem__, sample_indices))

print(f'\nSampled {n_sample} test images to add to training set')
print(f'Combined training set: {len(train_images) + n_sample} images')

## Cell 6 — Class Distribution + Class Weights

In [None]:
# ============================================================
# CLASS DISTRIBUTION — Count how many pixels each class has
# ============================================================
# Why? Some classes (like Sky, Landscape) have TONS of pixels, while
# rare classes (Flowers, Logs) have very few. If we don't account for this,
# the model will just predict the common classes and ignore rare ones.
# So we compute "class weights" — rare classes get higher weight in the loss.

print('Computing class distribution across combined training set...')
class_pixel_counts = np.zeros(CFG['num_classes'], dtype=np.int64)  # counter for each class

# Loop through ALL training masks (original + sampled test) and count pixels per class
all_train_masks = list(train_masks) + sampled_test_masks
for mp in tqdm(all_train_masks, desc='Counting pixels'):
    mask_raw = np.array(Image.open(mp))    # load the raw mask image
    mask_cls = convert_mask(mask_raw)       # convert raw values to class IDs 0-9
    # One histogram pass over the mask instead of one full comparison per class.
    # minlength pads missing classes with 0; the slice keeps the length fixed.
    per_class = np.bincount(mask_cls.ravel(), minlength=CFG['num_classes'])
    class_pixel_counts += per_class[:CFG['num_classes']]

# Calculate frequency (what fraction of all pixels belong to each class)
class_freq = class_pixel_counts / class_pixel_counts.sum()

# Print the distribution
print(f'\n{"Class":<20} {"Pixels":>12} {"Freq":>10}')
print('-' * 44)
for i in range(CFG['num_classes']):
    print(f'{CLASS_NAMES[i]:<20} {class_pixel_counts[i]:>12,} {class_freq[i]:>10.4f}')

# ============================================================
# COMPUTE CLASS WEIGHTS
# ============================================================
# Inverse frequency weighting: rare classes get HIGHER weight
# so the loss function penalizes mistakes on rare classes more heavily.
# We clamp weights between 0.5 and 10.0 to prevent extreme values.

# Inverse frequency (the epsilon keeps an absent class from dividing by zero)
_inverse_freq = np.divide(1.0, class_freq + 1e-6)
# Rescale so the weights sum to num_classes ...
_scaled = _inverse_freq / _inverse_freq.sum() * CFG['num_classes']
# ... then clamp to [0.5, 10.0]. After clamping the sum is generally no longer num_classes.
class_weights = _scaled.clip(0.5, 10.0)

print(f'\nClass weights (higher = rarer class, gets more attention):')
for i, _weight in enumerate(class_weights):
    print(f'  {CLASS_NAMES[i]:<20}: {_weight:.4f}')

# ============================================================
# BAR CHART — Visualize the class distribution
# ============================================================
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))
# One matplotlib colour (0-1 floats) per class, taken from the 0-255 palette
colors = list(CLASS_COLORS[:CFG['num_classes']] / 255.0)

# Left panel: share of all pixels; right panel: raw pixel counts
_panels = (
    (ax1, class_freq * 100, '% of total pixels', 'Class Frequency (%)'),
    (ax2, class_pixel_counts, 'Pixel count', 'Absolute Pixel Counts'),
)
for _ax, _values, _xlabel, _title in _panels:
    _ax.barh(CLASS_NAMES, _values, color=colors)
    _ax.set_xlabel(_xlabel)
    _ax.set_title(_title)

plt.suptitle(f'Combined Training Set ({len(all_train_masks)} images)', fontweight='bold')
plt.tight_layout()
plt.show()

## Cell 7 — Dataset + Augmentations + DataLoader

In [None]:
# ============================================================
# DATA AUGMENTATION — Random transforms applied to training images
# ============================================================
# Augmentations make the model more robust by showing it different
# variations of each image (cropped, flipped, color-shifted, blurred...).
# The model never sees the exact same image twice — this prevents overfitting.

def get_train_transforms(img_size):
    """Build the randomized augmentation pipeline used for training samples.

    Args:
        img_size: side length of the square crop produced by the first step.

    Returns:
        An ``A.Compose`` applying, in order: random resized crop, horizontal
        flip, colour jitter, Gaussian blur, brightness/contrast shift, CLAHE,
        grid distortion, Gaussian noise, normalization, tensor conversion.
    """
    pipeline = []

    # --- Geometry: crop 50%-100% of the image to img_size x img_size, then maybe mirror ---
    pipeline.append(A.RandomResizedCrop(size=(img_size, img_size), scale=(0.5, 1.0), p=1.0))
    pipeline.append(A.HorizontalFlip(p=0.5))

    # --- Photometric changes, each applied with its own probability ---
    pipeline.append(A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.5))
    pipeline.append(A.GaussianBlur(blur_limit=(3, 7), p=0.3))
    pipeline.append(A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.4))
    pipeline.append(A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.3))  # adaptive histogram equalization

    # --- Grid warp followed by pixel noise ---
    pipeline.append(A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3))
    pipeline.append(A.GaussNoise(p=0.2))

    # --- Always last: normalize with the ImageNet mean/std, then convert to a tensor ---
    pipeline.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    pipeline.append(ToTensorV2())

    return A.Compose(pipeline)


def get_val_transforms(img_size):
    """Build the deterministic preprocessing used for validation and test.

    There is no random augmentation here: every image is resized to
    img_size x img_size, normalized with the same mean/std as the training
    pipeline, and converted to a tensor.

    Args:
        img_size: target height and width after resizing.

    Returns:
        An ``A.Compose`` with the three steps above.
    """
    resize = A.Resize(height=img_size, width=img_size)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    to_tensor = ToTensorV2()
    return A.Compose([resize, normalize, to_tensor])


# ============================================================
# DATASET CLASSES — How we load images and masks
# ============================================================

class DesertSegDataset(Dataset):
    """Image/mask dataset backed by a pair of directories.

    Used for the original training and validation sets, where images live in
    one folder and masks in another. Both ``*.png`` and ``*.jpg`` images are
    picked up, matching the discovery code that counts both extensions (the
    previous PNG-only glob silently dropped JPG images).

    Args:
        img_dir: folder containing the input images.
        mask_dir: folder containing the ground-truth masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries.
    """

    # File patterns treated as input images
    IMAGE_PATTERNS = ('*.png', '*.jpg')

    def __init__(self, img_dir, mask_dir, transform=None):
        img_dir = Path(img_dir)
        # Collect every supported image and keep a stable, sorted order
        self.img_paths = sorted(
            path for pattern in self.IMAGE_PATTERNS for path in img_dir.glob(pattern)
        )
        self.mask_dir = Path(mask_dir)
        self.transform = transform  # augmentation pipeline to apply

    def __len__(self):
        return len(self.img_paths)  # how many images we have

    def _mask_path_for(self, img_path):
        """Return the mask path that belongs to ``img_path``."""
        same_name = self.mask_dir / img_path.name
        if same_name.exists():
            return same_name
        # VALUE_MAP uses raw values above 255, which do not fit an 8-bit JPG,
        # so a .jpg image is expected to have its mask stored as <stem>.png.
        png_mask = self.mask_dir / f'{img_path.stem}.png'
        # If neither exists, return the same-name path so opening it raises
        # the usual FileNotFoundError for the expected location.
        return png_mask if png_mask.exists() else same_name

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        With a transform, both are whatever the transform returns (the mask is
        cast with ``.long()``); without one, they are NumPy arrays.
        """
        img_path = self.img_paths[idx]
        mask_path = self._mask_path_for(img_path)

        # Load image as RGB numpy array (H, W, 3)
        image = np.array(Image.open(img_path).convert('RGB'))
        # Load mask and convert raw values to class IDs
        mask_raw = np.array(Image.open(mask_path))
        mask = convert_mask(mask_raw)

        # Apply augmentations (crop, flip, normalize, etc.)
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask'].long()   # class IDs as long integers

        return image, mask  # the DataLoader collects these into batches


class PathListDataset(Dataset):
    """Image/mask dataset backed by two parallel lists of file paths.

    Used for the sampled test images, where the paths come from lists rather
    than from a directory pair. ``img_paths[i]`` is paired with
    ``mask_paths[i]``.

    Args:
        img_paths: sequence of image file paths.
        mask_paths: sequence of mask file paths, same length and order.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries.

    Raises:
        ValueError: if the two sequences differ in length.
    """

    def __init__(self, img_paths, mask_paths, transform=None):
        # Pairing is purely positional, so a length mismatch means at least one
        # image has no mask (or vice versa). Fail here instead of mid-epoch.
        if len(img_paths) != len(mask_paths):
            raise ValueError(
                f'img_paths and mask_paths must have the same length, '
                f'got {len(img_paths)} and {len(mask_paths)}'
            )
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``."""
        # Same loading logic as the directory-based dataset
        image = np.array(Image.open(self.img_paths[idx]).convert('RGB'))
        mask_raw = np.array(Image.open(self.mask_paths[idx]))
        mask = convert_mask(mask_raw)  # raw mask values -> class IDs

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask'].long()

        return image, mask


# ============================================================
# BUILD COMBINED TRAINING SET
# ============================================================
# We merge the original training data with 200 sampled test images
# into one big dataset. ConcatDataset just chains them together.

# One shared augmentation pipeline for everything that is trained on
train_transform = get_train_transforms(CFG['img_size'])

# Training data = original train/ folder + the sampled test images, chained together
original_train_ds = DesertSegDataset(train_img_dir, train_mask_dir, transform=train_transform)
sampled_test_ds = PathListDataset(sampled_test_imgs, sampled_test_masks, transform=train_transform)
combined_train_ds = ConcatDataset([original_train_ds, sampled_test_ds])

# Validation data uses the deterministic resize + normalize pipeline only
val_dataset = DesertSegDataset(val_img_dir, val_mask_dir, transform=get_val_transforms(CFG['img_size']))

# ============================================================
# DATA LOADERS — Feed batches to the model
# ============================================================
# Settings common to both loaders: batch size, number of loader workers,
# and pinned host memory for the batches.
_shared_loader_kwargs = dict(
    batch_size=CFG['batch_size'],
    num_workers=CFG['num_workers'],
    pin_memory=True,
)

# Training: reshuffle every epoch and drop the incomplete final batch
train_loader = DataLoader(combined_train_ds, shuffle=True, drop_last=True, **_shared_loader_kwargs)
# Validation: fixed order, every sample kept
val_loader = DataLoader(val_dataset, shuffle=False, **_shared_loader_kwargs)

print(f'Original train:  {len(original_train_ds)} images')
print(f'Sampled test:    {len(sampled_test_ds)} images')
print(f'Combined train:  {len(combined_train_ds)} images  ({len(train_loader)} batches)')
print(f'Validation:      {len(val_dataset)} images  ({len(val_loader)} batches)')

## Cell 8 — Model: Build SegFormer-B4

In [None]:
# ============================================================
# BUILD THE MODEL — SegFormer-B4 with pretrained encoder
# ============================================================

def build_model():
    """Create the SegFormer segmentation model with a pretrained encoder.

    Steps:
      1. Read the architecture config for ``CFG['model_name']`` and set the
         number of output labels to ``CFG['num_classes']``.
      2. Instantiate the full segmentation model from that config (all
         weights freshly initialized at this point).
      3. Load the pretrained checkpoint as a bare ``SegformerModel`` and copy
         its weights into ``model.segformer``; the decode head keeps its
         fresh initialization.
      4. Move the model to ``DEVICE`` and print parameter counts.

    Returns:
        The ``SegformerForSemanticSegmentation`` instance on ``DEVICE``.
    """
    checkpoint = CFG['model_name']

    # Architecture definition with our own label count
    seg_config = SegformerConfig.from_pretrained(checkpoint)
    seg_config.num_labels = CFG['num_classes']

    # Full network: encoder + decode head, randomly initialized
    model = SegformerForSemanticSegmentation(seg_config)

    # Overwrite only the encoder weights with the pretrained ones
    encoder_donor = SegformerModel.from_pretrained(checkpoint)
    model.segformer.load_state_dict(encoder_donor.state_dict())
    del encoder_donor  # the copy is no longer needed

    model = model.to(DEVICE)

    # Tally all parameters and the subset that will receive gradients
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f'Model: SegFormer-B4 (encoder: {CFG["model_name"]})')
    print(f'Total parameters:     {total_params:>12,}')
    print(f'Trainable parameters: {trainable_params:>12,}')

    return model


# Build it!
model = build_model()

## Cell 9 — Loss + Metrics

In [None]:
# ============================================================
# LOSS FUNCTIONS — How we measure "how wrong" the model's predictions are
# ============================================================
# We use THREE different loss functions combined together.
# Each one measures error differently, so combining them works better than any single one.

class FocalLoss(nn.Module):
    """Focal Loss — focuses on HARD pixels that the model gets wrong.

    Plain cross-entropy treats every pixel equally. Focal loss multiplies the
    per-pixel cross-entropy by (1 - pt)^gamma, where pt is the probability the
    model assigned to the correct class, so confident/easy pixels contribute
    little and hard pixels dominate.

    Args:
        weight: optional 1-D tensor of per-class weights (rare classes get
            more weight). Stored as a non-persistent buffer so that
            .to() / .cuda() move it together with the module.
        gamma: focusing parameter; gamma=0 reduces to (weighted) cross-entropy.
        reduction: 'mean' averages over all pixels, 'sum' adds them up, any
            other value returns the unreduced per-pixel loss map.
    """

    def __init__(self, weight=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma          # focusing parameter
        self.reduction = reduction
        if weight is not None:
            weight = torch.as_tensor(weight)
        # Buffer (not a plain attribute) so device/dtype moves apply to it;
        # non-persistent so the module's state_dict stays unchanged.
        self.register_buffer('weight', weight, persistent=False)

    def forward(self, logits, targets):
        """Return the focal loss for logits (N, C, ...) and class-ID targets (N, ...)."""
        # Step 1: UNWEIGHTED cross-entropy equals -log(pt), so exp(-ce) is the
        # true probability of the correct class. (Using the class-weighted CE
        # here would make pt depend on the class weight and distort the
        # focusing term.)
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        # Step 2: modulating factor — small for easy pixels, close to 1 for hard ones
        modulator = (1 - pt) ** self.gamma

        # Step 3: apply class weights to the CE term only
        if self.weight is not None:
            ce_loss = F.cross_entropy(
                logits, targets,
                weight=self.weight.to(logits.device),
                reduction='none')

        focal_loss = modulator * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()  # average over all pixels
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed directly from logits.

    The Dice score measures how much the predicted mask and the true mask
    overlap: 1.0 means identical, 0.0 means disjoint. The loss is
    1 - (mean Dice over classes), so lower is better. Every class contributes
    equally to the mean regardless of how many pixels it covers, which helps
    with imbalanced classes.

    Args:
        num_classes: number of classes (size of the one-hot encoding).
        smooth: constant added to numerator and denominator so the ratio is
            defined even when a class has no pixels.
    """

    def __init__(self, num_classes=10, smooth=1.0):
        super().__init__()
        self.smooth = smooth
        self.num_classes = num_classes

    def forward(self, logits, targets):
        """Return the scalar Dice loss for logits (B, C, H, W) and targets (B, H, W)."""
        # Per-pixel class probabilities
        probs = logits.softmax(dim=1)

        # Class IDs -> one-hot, moved from (B, H, W, C) to (B, C, H, W)
        one_hot = F.one_hot(targets, self.num_classes)
        one_hot = one_hot.permute(0, 3, 1, 2).float()

        # Reduce over batch and spatial axes; the class axis is kept
        reduce_over = (0, 2, 3)
        overlap = (probs * one_hot).sum(reduce_over)
        total_area = probs.sum(reduce_over) + one_hot.sum(reduce_over)

        # Smoothed per-class Dice score, then averaged across classes
        per_class_dice = (2.0 * overlap + self.smooth) / (total_area + self.smooth)
        return 1.0 - per_class_dice.mean()


class CombinedLoss(nn.Module):
    """Combined Loss = w_focal*Focal + w_dice*Dice + w_ce*CrossEntropy
    (defaults: 0.5 / 0.3 / 0.2).

    Why combine?
    - Focal Loss: focuses on hard pixels, helps with difficult boundaries
    - Dice Loss: handles class imbalance well, ensures small classes get attention
    - CrossEntropy: stable baseline loss, good gradients for learning

    Args:
        class_weights: per-class weights (sequence, NumPy array or tensor),
            used by both the focal and the cross-entropy terms.
        num_classes: number of classes, forwarded to DiceLoss.
        focal_gamma: focusing parameter forwarded to FocalLoss.
        w_focal, w_dice, w_ce: contribution of each term to the total.
    """

    def __init__(self, class_weights, num_classes=10, focal_gamma=2.0,
                 w_focal=0.5, w_dice=0.3, w_ce=0.2):
        super().__init__()
        # as_tensor accepts lists/arrays/tensors without the copy-construct
        # warning torch.tensor() raises for tensor inputs; clone() keeps the
        # module independent of the caller's data.
        weight_tensor = torch.as_tensor(
            class_weights, dtype=torch.float32).detach().clone()
        self.focal = FocalLoss(weight=weight_tensor, gamma=focal_gamma)
        self.dice = DiceLoss(num_classes=num_classes)
        self.ce = nn.CrossEntropyLoss(weight=weight_tensor)  # standard CE with class weights
        # How much each loss contributes to the total
        self.w_focal = w_focal
        self.w_dice = w_dice
        self.w_ce = w_ce

    def _apply(self, fn, *args, **kwargs):
        """Keep the focal-loss class weights in sync on every device/dtype move.

        .to(), .cuda(), .cpu(), .half() etc. all go through _apply.
        nn.CrossEntropyLoss stores its weight as a buffer, so it has already
        been converted by the time super()._apply returns; re-sharing that
        tensor covers a FocalLoss that keeps its weight as a plain attribute.
        """
        module = super()._apply(fn, *args, **kwargs)
        self.focal.weight = self.ce.weight
        return module

    def forward(self, logits, targets):
        """Return the weighted sum of the focal, Dice and cross-entropy losses."""
        return (self.w_focal * self.focal(logits, targets) +
                self.w_dice * self.dice(logits, targets) +
                self.w_ce * self.ce(logits, targets))


# ============================================================
# METRICS — How we measure model performance
# ============================================================

class SegmentationMetrics:
    """Tracks predictions vs ground truth using a confusion matrix.

    A confusion matrix is a num_classes x num_classes grid where:
    - Row i = pixels that truly belong to class i
    - Column j = pixels the model predicted as class j
    - Perfect model: only diagonal has values (all predictions match truth)

    From this matrix we compute IoU, Dice, and Pixel Accuracy.

    Classes that appear in neither the ground truth nor the predictions have
    an undefined score: they are reported as NaN per class and are excluded
    from the mean, so they do not drag mIoU / mDice towards zero.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Initialize an empty confusion matrix
        self.confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    def reset(self):
        """Clear the confusion matrix (call this before each validation epoch)."""
        self.confusion_matrix = np.zeros(
            (self.num_classes, self.num_classes), dtype=np.int64)

    def update(self, preds, targets):
        """Add a batch of predictions and targets to the confusion matrix.

        Args:
            preds: iterable of per-image arrays of predicted class IDs.
            targets: iterable of per-image arrays of true class IDs; pixels
                whose ID is outside [0, num_classes) are ignored.
        """
        n = self.num_classes
        for p, t in zip(preds, targets):
            # Widen to int64 first: with narrow dtypes such as uint8 the
            # product true_class * num_classes can overflow and wrap around.
            p = np.asarray(p).astype(np.int64, copy=False)
            t = np.asarray(t).astype(np.int64, copy=False)
            # Only count pixels with valid class IDs
            mask = (t >= 0) & (t < n)
            # bincount on (true_class * num_classes + pred_class) fills the
            # whole confusion matrix in one vectorized operation
            self.confusion_matrix += np.bincount(
                t[mask] * n + p[mask],
                minlength=n ** 2
            ).reshape(n, n)

    @staticmethod
    def _ratio(numerator, denominator):
        """Element-wise numerator / denominator, NaN where the denominator is 0."""
        out = np.full(numerator.shape, np.nan, dtype=np.float64)
        np.divide(numerator, denominator, out=out, where=denominator > 0)
        return out

    @staticmethod
    def _mean_defined(values):
        """Mean over the non-NaN entries; 0.0 when every entry is NaN."""
        defined = ~np.isnan(values)
        return values[defined].mean() if defined.any() else 0.0

    def get_iou(self):
        """Compute IoU (Intersection over Union) per class and mean.

        IoU = overlap / (prediction_area + truth_area - overlap)

        Returns:
            (per_class_iou, mean_iou): per-class array (NaN for classes with
            an empty union) and the mean over the defined classes.
        """
        intersection = np.diag(self.confusion_matrix)  # correctly predicted pixels per class
        union = (self.confusion_matrix.sum(axis=1) +    # all true pixels per class
                 self.confusion_matrix.sum(axis=0) -     # all predicted pixels per class
                 intersection)                            # subtract overlap (counted twice)
        iou = self._ratio(intersection, union)
        return iou, self._mean_defined(iou)

    def get_dice(self):
        """Compute Dice score per class and mean.

        Dice = 2 * overlap / (prediction_area + truth_area)

        Returns:
            (per_class_dice, mean_dice), with the same NaN convention as get_iou.
        """
        intersection = np.diag(self.confusion_matrix)
        areas = (self.confusion_matrix.sum(axis=1) +
                 self.confusion_matrix.sum(axis=0))
        dice = self._ratio(2 * intersection, areas)
        return dice, self._mean_defined(dice)

    def get_pixel_accuracy(self):
        """What fraction of ALL counted pixels were classified correctly?

        Simple but can be misleading if one class dominates.
        Returns 0.0 when no pixels have been accumulated yet.
        """
        correct = np.diag(self.confusion_matrix).sum()  # sum of diagonal = correct pixels
        total = self.confusion_matrix.sum()               # total pixels
        return correct / total if total > 0 else 0.0

    def print_report(self, class_names):
        """Print a table of IoU and Dice for each class; return the mean IoU."""
        iou, miou = self.get_iou()
        dice, mdice = self.get_dice()
        acc = self.get_pixel_accuracy()

        print(f'\n{"Class":<20} {"IoU":>8} {"Dice":>8}')
        print('-' * 38)
        for i in range(self.num_classes):
            print(f'{class_names[i]:<20} {iou[i]:>8.4f} {dice[i]:>8.4f}')
        print('-' * 38)
        print(f'{"Mean":<20} {miou:>8.4f} {mdice:>8.4f}')
        print(f'Pixel Accuracy: {acc:.4f}')
        return miou


# ============================================================
# BUILD THE TRAINING CRITERION
# ============================================================
# The per-class weights computed earlier feed the focal and cross-entropy
# terms; the mixing coefficients and focal gamma all come from CFG.
criterion = CombinedLoss(
    w_ce=CFG['ce_weight'],
    w_dice=CFG['dice_weight'],
    w_focal=CFG['focal_weight'],
    focal_gamma=CFG['focal_gamma'],
    num_classes=CFG['num_classes'],
    class_weights=class_weights,
)
# Put the loss (and its class-weight tensors) on the training device
criterion = criterion.to(DEVICE)

print('Loss: 0.5*Focal + 0.3*Dice + 0.2*WeightedCE')

## Cell 10 — Train/Val/Inference Functions

In [None]:
# ============================================================
# TRAINING + VALIDATION + INFERENCE FUNCTIONS
# ============================================================

def train_one_epoch(model, loader, criterion, optimizer, scheduler, scaler,
                    grad_accum_steps, device, fp16=True):
    """Train the model for one full pass through the training data.

    Args:
        model: network called as model(pixel_values=images); its output must
            expose a .logits tensor.
        loader: iterable of (images, masks) batches that supports len().
        criterion: loss called as criterion(logits, masks).
        optimizer: optimizer updated through the GradScaler.
        scheduler: LR scheduler stepped once per optimizer update, or None.
        scaler: GradScaler used for mixed-precision loss scaling.
        grad_accum_steps: number of batches whose gradients are accumulated
            before each optimizer update.
        device: device the batches are moved to.
        fp16: whether to run the forward pass under autocast.

    Returns:
        The average (un-divided) training loss per batch; 0.0 for an empty loader.
    """

    model.train()       # put model in training mode
    running_loss = 0.0  # accumulate loss across all batches
    num_batches = len(loader)
    optimizer.zero_grad()  # clear any leftover gradients from previous epoch

    pbar = tqdm(loader, desc='Train', leave=False)  # progress bar
    for step, (images, masks) in enumerate(pbar):
        # Move data to the training device
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass (mixed precision when fp16 is enabled)
        with autocast(enabled=fp16):
            outputs = model(pixel_values=images)

            # Upsample the logits to the mask resolution before computing the loss
            logits = F.interpolate(
                outputs.logits, size=masks.shape[-2:],
                mode='bilinear', align_corners=False
            )

            # Divide by grad_accum_steps so the accumulated gradient is an
            # average over the accumulation window rather than a sum
            loss = criterion(logits, masks) / grad_accum_steps

        # Backward pass — the scaler handles FP16 gradient scaling
        scaler.scale(loss).backward()

        # Update weights every grad_accum_steps batches, AND on the final
        # batch: otherwise, when len(loader) is not a multiple of
        # grad_accum_steps, the trailing gradients would never be applied and
        # would be thrown away by the next epoch's zero_grad().
        is_last_batch = (step + 1) == num_batches
        if (step + 1) % grad_accum_steps == 0 or is_last_batch:
            scaler.unscale_(optimizer)  # undo loss scaling before clipping

            # Clip gradients — prevents exploding gradients (training stability)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)  # update model weights
            scaler.update()         # update the FP16 loss scaler
            optimizer.zero_grad()   # clear gradients for next accumulation cycle

            # Step the learning rate scheduler (once per optimizer update)
            if scheduler is not None:
                scheduler.step()

        # Track running loss (re-multiplied to undo the accumulation divide)
        running_loss += loss.item() * grad_accum_steps
        pbar.set_postfix({'loss': f'{running_loss / (step + 1):.4f}'})

    # max(1, ...) avoids a ZeroDivisionError on an empty loader
    return running_loss / max(1, num_batches)


@torch.no_grad()  # inference only — no gradients are tracked
def validate(model, loader, criterion, metrics, device, fp16=True):
    """Run one evaluation pass over `loader`.

    The metrics object is reset first, then every batch is pushed through the
    model, its loss is accumulated and its argmax predictions are added to the
    metrics' confusion matrix.

    Returns:
        (val_loss, mean IoU, mean Dice, pixel accuracy, per-class IoU), where
        val_loss is the average loss per batch.
    """
    model.eval()
    metrics.reset()
    total_loss = 0.0

    for images, masks in tqdm(loader, desc='Val', leave=False):
        images, masks = images.to(device), masks.to(device)

        with autocast(enabled=fp16):
            # Model logits are resized to the mask resolution before scoring
            upsampled = F.interpolate(
                model(pixel_values=images).logits,
                size=masks.shape[-2:],
                mode='bilinear', align_corners=False,
            )
            batch_loss = criterion(upsampled, masks)

        total_loss += batch_loss.item()

        # Highest-scoring class per pixel goes into the confusion matrix
        batch_preds = upsampled.argmax(dim=1).cpu().numpy()
        metrics.update(batch_preds, masks.cpu().numpy())

    mean_loss = total_loss / len(loader)
    per_class_iou, mean_iou = metrics.get_iou()
    _, mean_dice = metrics.get_dice()
    pixel_acc = metrics.get_pixel_accuracy()

    return mean_loss, mean_iou, mean_dice, pixel_acc, per_class_iou


@torch.no_grad()
def predict_single(model, image_np, img_size, device):
    """Predict a segmentation mask for one image with a single forward pass (no TTA).

    Args:
        model: network called as model(pixel_values=...) returning .logits
        image_np: raw numpy image (H, W, 3) — NOT normalized
        img_size: side length passed to the validation transforms and used
            for the output resolution
        device: device the input tensor is moved to
    Returns:
        pred: class ID mask (img_size, img_size)
        probs: softmax probabilities (num_classes, img_size, img_size)
    """
    # Validation preprocessing, then add the batch dimension
    preprocess = get_val_transforms(img_size)
    batch = preprocess(image=image_np)['image'].unsqueeze(0).to(device)

    model.eval()
    with autocast(enabled=CFG['fp16']):
        raw_logits = model(pixel_values=batch).logits

    # Bring the logits to the requested output resolution
    logits = F.interpolate(raw_logits, size=(img_size, img_size),
                           mode='bilinear', align_corners=False)

    probs = logits.softmax(dim=1).squeeze().cpu().numpy()   # per-class confidence
    pred = logits.argmax(dim=1).squeeze().cpu().numpy()     # winning class per pixel
    return pred, probs


@torch.no_grad()
def tta_predict(model, image_np, img_size, device, scales=(0.75, 1.0, 1.25),
                flips=(False, True)):
    """Test-Time Augmentation (TTA) — run inference multiple times with
    different augmentations and AVERAGE the results.

    For each combination of scale and flip:
      1. Resize image to (scale * img_size)
      2. Optionally flip horizontally
      3. Run model and resize the logits back to img_size
      4. Convert to softmax probabilities
      5. Un-flip if needed
      6. Accumulate

    Finally, average all accumulated probabilities and take argmax.

    Args:
        model: network called as model(pixel_values=...) returning .logits
        image_np: raw numpy image (H, W, 3) — NOT normalized
        img_size: side length of the returned prediction
        device: device the input tensors are moved to
        scales: iterable of scale factors applied to img_size
        flips: iterable of booleans; True adds a horizontally flipped pass

    Returns:
        pred: uint8 class-ID mask (img_size, img_size)
        avg_probs: float32 averaged probabilities (num_classes, img_size, img_size)

    Raises:
        ValueError: if `scales` or `flips` is empty (nothing to average).
    """
    # Materialize once: supports any iterable and lets us validate up front
    scales = list(scales)
    flips = list(flips)
    if not scales or not flips:
        raise ValueError('tta_predict needs at least one scale and one flip setting')

    model.eval()
    # Accumulator for softmax probabilities across all augmented versions
    accum = np.zeros((CFG['num_classes'], img_size, img_size), dtype=np.float32)
    count = 0  # how many versions we've accumulated

    for scale in scales:
        sh, sw = int(img_size * scale), int(img_size * scale)  # scaled dimensions
        for flip in flips:
            # Build transform for this specific scale + flip combo
            tfm = A.Compose([
                A.Resize(height=sh, width=sw),
                A.HorizontalFlip(p=1.0 if flip else 0.0),
                A.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),   # ImageNet statistics
                ToTensorV2(),
            ])
            aug = tfm(image=image_np)
            tensor = aug['image'].unsqueeze(0).to(device)  # (1, 3, sh, sw)

            # Run the model
            with autocast(enabled=CFG['fp16']):
                out = model(pixel_values=tensor)

            # Work in float32 from here on: the autocast output may be half
            # precision, which would lose accuracy in the softmax and the
            # averaging across passes.
            logits = F.interpolate(
                out.logits.float(), size=(img_size, img_size),
                mode='bilinear', align_corners=False
            )
            # Index the batch axis explicitly instead of a blanket squeeze()
            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

            # If we flipped the image, flip the probabilities back
            # (a reversed view is enough — it is only read by the += below)
            if flip:
                probs = probs[:, :, ::-1]

            accum += probs
            count += 1

    # Average across all augmented versions
    avg_probs = accum / count
    # Final prediction = class with highest average probability
    pred = np.argmax(avg_probs, axis=0).astype(np.uint8)
    return pred, avg_probs


# Notebook progress marker: printed once the training, validation and
# inference helpers in this cell have been defined.
print('Training/validation/inference functions defined.')

## Cell 11 — Training

In [None]:
# ============================================================
# TRAINING LOOP — The main training process
# ============================================================
# This is where the model actually LEARNS. Each epoch:
#   1. Train on all training images (forward pass → loss → backward pass → update weights)
#   2. Validate on held-out val set (check if we're improving)
#   3. Save checkpoints (so we don't lose progress if Colab crashes)

# Make sure the checkpoint directory exists before anything is written to it
os.makedirs(CFG['save_dir'], exist_ok=True)

# --- Check for existing checkpoint to resume from ---
# If we already trained before and saved a checkpoint, load it and continue
ckpt_path = os.path.join(CFG['save_dir'], 'best_model.pth')

# History tracks all metrics across epochs (for plotting later)
history = {
    'train_loss': [], 'val_loss': [],
    'val_miou': [], 'val_mdice': [], 'val_acc': [],
    'lr': [],
}
start_epoch = 0  # which epoch to start from

if os.path.exists(ckpt_path):
    # Load saved checkpoint: model weights + training history.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints that this notebook wrote itself.
    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])  # restore model weights
    # Restore training history so plots show the full curve
    saved_hist = ckpt.get('history', {})
    for k in history:
        history[k] = list(saved_hist.get(k, []))
    start_epoch = ckpt.get('epoch', len(history['train_loss']))
    print(f'Resumed from checkpoint: epoch {start_epoch}, val mIoU = {ckpt.get("miou", 0):.4f}')
else:
    print('No checkpoint found — training from scratch')

# ============================================================
# OPTIMIZER — Controls how weights get updated
# ============================================================
# We use DIFFERENTIAL learning rates:
#   - Backbone (encoder): small LR because it's already pretrained
#   - Decoder: bigger LR because it was randomly initialized

# Separate parameters into backbone vs decoder groups
backbone_params = []
decoder_params = []
for name, param in model.named_parameters():
    if 'decode_head' in name:
        decoder_params.append(param)  # decoder layers
    else:
        backbone_params.append(param)  # encoder (backbone) layers

# If we're resuming from a late epoch, use smaller LR (fine-tuning phase)
if start_epoch >= 40:
    backbone_lr = 1e-6   # very small — just minor adjustments
    decoder_lr = 1e-5
else:
    backbone_lr = CFG['backbone_lr']
    decoder_lr = CFG['decoder_lr']

# AdamW optimizer: Adam with decoupled weight decay
optimizer = torch.optim.AdamW([
    {'params': backbone_params, 'lr': backbone_lr},   # low LR for pretrained encoder
    {'params': decoder_params, 'lr': decoder_lr},      # higher LR for decoder
], weight_decay=CFG['weight_decay'])

print(f'Backbone params: {sum(p.numel() for p in backbone_params):,} (lr={backbone_lr})')
print(f'Decoder params:  {sum(p.numel() for p in decoder_params):,} (lr={decoder_lr})')

# ============================================================
# LEARNING RATE SCHEDULER — Changes LR during training
# ============================================================
# Schedule: warmup (LR ramps up) → cosine decay (LR gradually decreases)
# Why? Starting with high LR can destabilize training. Ending with low LR
# helps the model converge to a better minimum.

remaining_epochs = 20  # how many more epochs to train
total_epochs = start_epoch + remaining_epochs

# One scheduler step per optimizer update; with grad_accum_steps=1 (used
# below) that is one step per batch.
rem_total_steps = len(train_loader) * remaining_epochs
rem_warmup = int(rem_total_steps * 0.1)  # 10% of steps are warmup

def lr_lambda(step):
    """Compute the LR multiplier for the current step.
    - During warmup: linearly increase from 0 to 1
    - After warmup: cosine decay from 1, floored at 0.05"""
    if step < rem_warmup:
        return float(step) / float(max(1, rem_warmup))  # linear warmup
    # Cosine decay: smoothly decreases LR following a cosine curve
    progress = float(step - rem_warmup) / float(max(1, rem_total_steps - rem_warmup))
    return max(0.05, 0.5 * (1.0 + np.cos(np.pi * progress)))  # min 5% of original LR

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# GradScaler for FP16 training — scales gradients to prevent underflow
scaler = GradScaler(enabled=CFG['fp16'])

# Metrics tracker for validation
metrics = SegmentationMetrics(CFG['num_classes'])

# Track the best validation mIoU (for saving the best checkpoint)
best_val_miou = max(history['val_miou']) if history['val_miou'] else 0.0

# ============================================================
# THE ACTUAL TRAINING LOOP
# ============================================================
print(f'\n{"=" * 70}')
print(f'Training epochs {start_epoch+1}–{total_epochs} on {len(combined_train_ds)} images')
print(f'(original: {len(original_train_ds)} + test samples: {len(sampled_test_ds)})')
print(f'{"=" * 70}')

for ep in range(1, remaining_epochs + 1):
    disp = start_epoch + ep  # display epoch number (accounting for resumed training)
    print(f'\nEpoch {disp}/{total_epochs}')

    # --- TRAIN for one epoch ---
    train_loss = train_one_epoch(
        model, train_loader, criterion, optimizer, scheduler,
        scaler, grad_accum_steps=1, device=DEVICE, fp16=CFG['fp16']
    )

    # --- VALIDATE on val set (check if model improved) ---
    # Pass the same fp16 setting as training; without it validate() falls
    # back to its default (fp16=True) even when CFG['fp16'] is False.
    val_loss, miou, mdice, acc, per_class_iou = validate(
        model, val_loader, criterion, metrics, DEVICE, fp16=CFG['fp16']
    )

    # Get current learning rate (param group 0 = backbone) for logging
    current_lr = optimizer.param_groups[0]['lr']

    # Record everything in history (for plotting later)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_miou'].append(miou)
    history['val_mdice'].append(mdice)
    history['val_acc'].append(acc)
    history['lr'].append(current_lr)

    # Print epoch summary
    print(f'  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
    print(f'  Val mIoU: {miou:.4f} | mDice: {mdice:.4f} | Acc: {acc:.4f} | LR: {current_lr:.2e}')

    # Print per-class IoU (abbreviated names)
    iou_str = ' | '.join([f'{CLASS_NAMES[i][:4]}:{per_class_iou[i]:.3f}'
                          for i in range(CFG['num_classes'])])
    print(f'  IoU: {iou_str}')

    # The same payload is written to every checkpoint file for this epoch
    ckpt_payload = {
        'epoch': disp,
        'model_state_dict': model.state_dict(),  # the trained weights
        'miou': miou,
        'history': history,
    }

    # --- SAVE BEST MODEL (if this epoch has the highest val mIoU so far) ---
    if miou > best_val_miou:
        best_val_miou = miou
        torch.save(ckpt_payload, os.path.join(CFG['save_dir'], 'best_model.pth'))
        print(f'  >>> New best val mIoU: {best_val_miou:.4f} — saved!')

    # --- SAVE PERIODIC CHECKPOINT (every CFG['save_every'] epochs) ---
    if disp % CFG['save_every'] == 0:
        torch.save(ckpt_payload,
                   os.path.join(CFG['save_dir'], f'checkpoint_epoch{disp}.pth'))

    # --- ALWAYS SAVE LATEST (in case Colab crashes, we can resume) ---
    torch.save(ckpt_payload, os.path.join(CFG['save_dir'], 'latest_model.pth'))

# ============================================================
# TRAINING COMPLETE — Print summary
# ============================================================
print(f'\n{"=" * 70}')
print(f'Training complete! Total epochs: {total_epochs}')
print(f'Best val mIoU: {best_val_miou:.4f}')
print(f'Final val mIoU: {history["val_miou"][-1]:.4f}')
print(f'{"=" * 70}')

## Cell 12 — Training Curves

In [None]:
# ============================================================
# TRAINING CURVES — Plot how loss and metrics changed over training
# ============================================================
# These plots help us understand:
#   - Is the model still learning? (loss should decrease)
#   - Is it overfitting? (train loss drops but val loss rises)
#   - When was the best epoch? (highest val mIoU)

# --- Load history from checkpoints ---
# We stitch together history from phase 1 (original training) and phase 2 (fine-tuning)
ckpt1_path = os.path.join(CFG['save_dir'], 'best_model.pth')
ckpt2_path = os.path.join(CFG['save_dir'], 'latest_model_ft.pth')


def _load_ckpt_history(path, key):
    """Return the history dict stored under `key` in the checkpoint at `path`.

    A missing file (e.g. no fine-tuning phase was run) yields an empty dict
    instead of crashing, so the remaining phase can still be plotted.
    """
    if not os.path.exists(path):
        print(f'Note: {path} not found — skipping its history')
        return {}
    # NOTE(security): weights_only=False unpickles arbitrary objects — only
    # load checkpoints produced by this notebook.
    return torch.load(path, map_location='cpu', weights_only=False).get(key) or {}


def _series(hist, *keys):
    """Return a copy of the first non-empty list stored under any of `keys`."""
    for key in keys:
        values = hist.get(key)
        if values:
            return list(values)
    return []


h1 = _load_ckpt_history(ckpt1_path, 'history')
h2 = _load_ckpt_history(ckpt2_path, 'ft_history')

# Combine both phases into one continuous history.
# Phase-1 metrics are looked up under the key names the training cell saves
# ('val_miou', 'val_mdice', 'val_acc') first, then under the older names
# ('miou', 'mdice', 'pixel_acc') so previously written checkpoints still load.
history = {
    'train_loss': _series(h1, 'train_loss') + _series(h2, 'train_loss'),
    'val_loss':   _series(h1, 'val_loss')   + _series(h2, 'val_loss'),
    'val_miou':   _series(h1, 'val_miou', 'miou') + _series(h2, 'val_miou'),
    'lr':         _series(h1, 'lr')         + _series(h2, 'lr'),
}

# Some metrics only exist in phase 1 — pad phase 2 with last known value
n_phase2_points = len(h2.get('val_miou', []))

h1_mdice = _series(h1, 'val_mdice', 'mdice')
if h1_mdice:
    last_mdice = h1_mdice[-1]
    history['val_mdice'] = h1_mdice + [last_mdice] * n_phase2_points

h1_acc = _series(h1, 'val_acc', 'pixel_acc')
if h1_acc:
    last_acc = h1_acc[-1]
    history['val_acc'] = h1_acc + [last_acc] * n_phase2_points

h1_pci = _series(h1, 'per_class_iou')
if h1_pci:
    last_pci = h1_pci[-1]
    history['per_class_iou'] = h1_pci + [last_pci] * n_phase2_points

n_epochs = len(history['train_loss'])
phase1_epochs = len(h1.get('train_loss', []))
phase2_epochs = len(h2.get('train_loss', []))

# Fail with a clear message instead of an IndexError in the prints below
if n_epochs == 0 or not history['val_loss'] or not history['val_miou']:
    raise RuntimeError(
        f'No training history found in {ckpt1_path} or {ckpt2_path} — '
        'run the training cell first.')

print(f'Combined history: {n_epochs} epochs ({phase1_epochs} + {phase2_epochs})')
print(f'Train loss: {history["train_loss"][0]:.4f} → {history["train_loss"][-1]:.4f}')
print(f'Val loss:   {history["val_loss"][0]:.4f} → {history["val_loss"][-1]:.4f}')
print(f'Val mIoU:   {history["val_miou"][0]:.4f} → {history["val_miou"][-1]:.4f} (best: {max(history["val_miou"]):.4f})')

# ======================== MAIN 2x2 PLOT ========================
# Reads `history` and `n_epochs` built above; every series plotted against
# epochs_range must therefore contain exactly n_epochs values.
# Each figure is written to CFG['save_dir'] and then shown inline.
epochs_range = range(1, n_epochs + 1)  # 1-based epoch numbers for the x-axis

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top-left: Training loss (should decrease over time)
axes[0, 0].plot(epochs_range, history['train_loss'], 'b-', linewidth=1.2)
axes[0, 0].set_title('Training Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].grid(True, alpha=0.3)

# Top-right: Validation loss (should decrease; if it rises = overfitting)
axes[0, 1].plot(epochs_range, history['val_loss'], 'r-', linewidth=1.2)
axes[0, 1].set_title('Validation Loss')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].grid(True, alpha=0.3)

# Bottom-left: Validation mIoU (THE key metric — higher is better)
axes[1, 0].plot(epochs_range, history['val_miou'], 'g-', linewidth=1.2)
best_ep = int(np.argmax(history['val_miou'])) + 1  # epoch with best mIoU (+1: argmax is 0-based)
best_miou_val = max(history['val_miou'])
# Dashed horizontal line at the best value; its label feeds the legend
axes[1, 0].axhline(y=best_miou_val, color='gray', linestyle='--', alpha=0.5,
                    label=f'Best: {best_miou_val:.4f} (epoch {best_ep})')
axes[1, 0].scatter([best_ep], [best_miou_val], color='red', s=60, zorder=5)  # mark best point
axes[1, 0].set_title('Validation mIoU')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('mIoU')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Bottom-right: Learning rate schedule (shows warmup → cosine decay)
axes[1, 1].plot(epochs_range, history['lr'], 'purple', linewidth=1.2)
axes[1, 1].set_title('Learning Rate (Backbone)')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('LR')
axes[1, 1].set_yscale('log')  # log scale since LR spans orders of magnitude
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle(f'SegFormer-B4 Training — {n_epochs} Epochs', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(CFG['save_dir'], 'training_curves.png'), dpi=150)
plt.show()

# ======================== mDice + Pixel Accuracy (bonus plots) ========================
# Drawn only when BOTH optional series were added to `history` above
if 'val_mdice' in history and 'val_acc' in history:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Mean Dice score over epochs (red dot = best epoch)
    ax1.plot(epochs_range, history['val_mdice'], 'orange', linewidth=1.2)
    best_dice_ep = int(np.argmax(history['val_mdice'])) + 1
    ax1.scatter([best_dice_ep], [max(history['val_mdice'])], color='red', s=60, zorder=5)
    ax1.set_title(f'Validation mDice (best: {max(history["val_mdice"]):.4f})')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('mDice')
    ax1.grid(True, alpha=0.3)

    # Pixel accuracy over epochs (red dot = best epoch)
    ax2.plot(epochs_range, history['val_acc'], 'teal', linewidth=1.2)
    best_acc_ep = int(np.argmax(history['val_acc'])) + 1
    ax2.scatter([best_acc_ep], [max(history['val_acc'])], color='red', s=60, zorder=5)
    ax2.set_title(f'Validation Pixel Accuracy (best: {max(history["val_acc"]):.4f})')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(CFG['save_dir'], 'training_curves_extra.png'), dpi=150)
    plt.show()

# ======================== Per-Class IoU over epochs ========================
# Shows how each class's IoU evolves — helps spot which classes are hardest
# Drawn only when a per-class IoU series was added to `history` above
if 'per_class_iou' in history:
    fig, ax = plt.subplots(figsize=(14, 6))
    pci = np.array(history['per_class_iou'])  # indexed below as [epoch, class]
    for c in range(CFG['num_classes']):
        # assumes CLASS_COLORS[c] is a NumPy array of 0-255 RGB values, so the
        # division yields matplotlib's 0-1 range — TODO confirm against its definition
        ax.plot(epochs_range, pci[:, c], color=CLASS_COLORS[c] / 255.0,
                linewidth=1.5, label=CLASS_NAMES[c])
    # Overlay the mean IoU as a dashed black reference line
    ax.plot(epochs_range, history['val_miou'], 'k--', linewidth=2, label='Mean IoU')
    ax.set_title('Per-Class IoU Over Training')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('IoU')
    ax.set_ylim(0, 1)
    ax.legend(loc='lower right', fontsize=8, ncol=2)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(CFG['save_dir'], 'per_class_iou_over_epochs.png'), dpi=150)
    plt.show()

# ======================== Train vs Val Loss (overfitting check) ========================
# If the gap between train and val loss grows = model is overfitting
fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(epochs_range, history['train_loss'], 'b-', linewidth=1.2, label='Train Loss')
ax.plot(epochs_range, history['val_loss'], 'r-', linewidth=1.2, label='Val Loss')
ax.fill_between(epochs_range, history['train_loss'], history['val_loss'],
                alpha=0.1, color='gray')  # shaded gap between train and val
ax.set_title('Train vs Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(CFG['save_dir'], 'train_vs_val_loss.png'), dpi=150)
plt.show()

print(f'\nAll training curves saved to {CFG["save_dir"]}')

## Cell 13 — Test Set Evaluation with TTA

In [None]:
# ============================================================
# TEST SET EVALUATION — setup: pick weights, metrics, output folders
# ============================================================
# Final evaluation of the model on the test images, using TTA for the
# predictions.

# --- Choose which checkpoint to evaluate ---
# Preference order: the fine-tuned "latest" checkpoint, then the best
# phase-1 checkpoint; with neither on disk the in-memory weights are used.
latest_path = os.path.join(CFG['save_dir'], 'latest_model_ft.pth')
best_path = os.path.join(CFG['save_dir'], 'best_model.pth')

if not (os.path.exists(latest_path) or os.path.exists(best_path)):
    print('Using current model weights')
else:
    ckpt_eval = torch.load(
        latest_path if os.path.exists(latest_path) else best_path,
        map_location=DEVICE, weights_only=False)
    model.load_state_dict(ckpt_eval['model_state_dict'])
    print(f'Loaded checkpoint (epoch {ckpt_eval.get("epoch", "?")}, val mIoU={ckpt_eval.get("miou", 0):.4f})')

# --- Fresh confusion matrix for the test run ---
print(f'\nEvaluating on {len(test_eval_imgs)} test images with TTA...')
test_metrics = SegmentationMetrics(CFG['num_classes'])

# Output folders for the predictions: raw uint16 masks and colorful RGB masks
os.makedirs(os.path.join(CFG['output_dir'], 'raw'), exist_ok=True)
os.makedirs(os.path.join(CFG['output_dir'], 'colored'), exist_ok=True)

for img_p, mask_p in tqdm(zip(test_eval_imgs, test_eval_masks),
                          total=len(test_eval_imgs), desc='Test TTA'):
    # Load the test image
    img = np.array(Image.open(img_p).convert('RGB'))

    # Run TTA prediction (3 scales x 2 flips = 6 forward passes)
    pred, _ = tta_predict(model, img, CFG['img_size'], DEVICE)

    # Resize prediction back to ORIGINAL image size (model works at 512x512)
    orig_h, orig_w = img.shape[:2]
    pred_resized = np.array(Image.fromarray(pred).resize(
        (orig_w, orig_h), Image.NEAREST))  # NEAREST = no interpolation for class IDs

    # --- Save raw prediction mask (uint16, same format as ground truth) ---
    raw_mask = np.zeros_like(pred_resized, dtype=np.uint16)
    for class_id, raw_val in REVERSE_MAP.items():
        raw_mask[pred_resized == class_id] = raw_val  # convert class IDs back to raw values
    Image.fromarray(raw_mask).save(
        os.path.join(CFG['output_dir'], 'raw', os.path.basename(img_p)))

    # --- Save colored prediction (RGB, for visual inspection) ---
    Image.fromarray(colorize_mask(pred_resized)).save(
        os.path.join(CFG['output_dir'], 'colored', os.path.basename(img_p)))

    # --- Evaluate against ground truth ---
    gt_raw = np.array(Image.open(mask_p))
    gt = convert_mask(gt_raw)
    test_metrics.update(pred_resized[np.newaxis], gt[np.newaxis])  # add batch dim

# ============================================================
# PRINT TEST RESULTS
# ============================================================
print('\n' + '=' * 60)
print('TEST SET EVALUATION (TTA)')
print('=' * 60)

test_iou, test_miou = test_metrics.get_iou()
test_dice, test_mdice = test_metrics.get_dice()
test_acc = test_metrics.get_pixel_accuracy()

# Per-class breakdown
print(f'\n{"Class":<20} {"IoU":>8} {"Dice":>8}')
print('-' * 38)
for i in range(CFG['num_classes']):
    print(f'{CLASS_NAMES[i]:<20} {test_iou[i]:>8.4f} {test_dice[i]:>8.4f}')
print('-' * 38)
print(f'{"Mean":<20} {test_miou:>8.4f} {test_mdice:>8.4f}')
print(f'\nPixel Accuracy: {test_acc:.4f}')
print(f'Mean IoU:       {test_miou:.4f} ({test_miou*100:.2f}%)')
print(f'Mean Dice:      {test_mdice:.4f}')
print(f'\nPredictions saved to {CFG["output_dir"]}')

## Cell 14 — Visualizations

In [None]:
# ============================================================
# VISUALIZATIONS — test-set charts and sample predictions
# ============================================================

# ======================== 1. Per-Class IoU Bar Chart ========================
# One bar per class, height = that class's test IoU, so strong and weak
# classes can be compared at a glance. The mean IoU goes in the title.

# Bar colours: each CLASS_COLORS entry scaled by 1/255
# (assumes CLASS_COLORS holds 0-255 RGB — TODO confirm)
colors = [CLASS_COLORS[i] / 255.0 for i in range(CFG['num_classes'])]

fig, ax = plt.subplots(figsize=(14, 6))
bars = ax.bar(CLASS_NAMES, test_iou, color=colors, edgecolor='black', linewidth=0.5)
ax.set_title(f'Per-Class IoU on Test Set (mIoU = {test_miou:.4f})')
ax.set_ylabel('IoU')
ax.set_ylim(0, 1)
ax.grid(True, alpha=0.3, axis='y')  # horizontal grid lines only

# Print each IoU value just above its bar, centred horizontally
for bar, val in zip(bars, test_iou):
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() + 0.02
    ax.text(label_x, label_y, f'{val:.3f}', ha='center', va='bottom', fontsize=9)

plt.xticks(rotation=30, ha='right')  # slant the class names
plt.tight_layout()
bar_png = os.path.join(CFG['save_dir'], 'test_per_class_iou.png')
plt.savefig(bar_png, dpi=150)
plt.show()

# ======================== 2. Confusion Matrix Heatmap ========================
# Heatmap of the row-normalised test confusion matrix.
# The axes are labelled True (rows) / Predicted (columns); this assumes
# SegmentationMetrics stores the true class on axis 0 — TODO confirm.
# Under that convention the diagonal holds correct predictions and a large
# off-diagonal cell (row=A, col=B) means class A is often predicted as B.

cm = test_metrics.confusion_matrix.astype(np.float32)

# Divide every row by its own total so each row holds fractions (0-1) of
# that row's pixels; the 1e-6 guards against division by zero on empty rows.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = cm / (row_totals + 1e-6)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm_norm,
    annot=True,
    fmt='.2f',
    cmap='Blues',
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
    ax=ax,
)
ax.set_title(f'Test Set Confusion Matrix (mIoU = {test_miou:.4f})')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')

plt.tight_layout()
cm_png = os.path.join(CFG['save_dir'], 'test_confusion_matrix.png')
plt.savefig(cm_png, dpi=150)
plt.show()

# ======================== 3. Sample Predictions ========================
# One row per randomly sampled test image, three columns:
#   input image | prediction (colourised) | ground truth (colourised).
# Predictions in this grid come from predict_single, not tta_predict.

n_vis = min(10, len(test_eval_imgs))  # at most 10 samples, fewer if the test set is smaller
vis_idx = random.sample(range(len(test_eval_imgs)), n_vis)  # distinct random indices

# squeeze=False keeps `axes` two-dimensional even when n_vis == 1. With the
# default squeeze=True a single row comes back as a 1-D array and the
# axes[i, col] indexing below raises IndexError.
fig, axes = plt.subplots(n_vis, 3, figsize=(15, 4 * n_vis), squeeze=False)
for i, idx in enumerate(vis_idx):
    # Load image and run a prediction; the second return value is not used here
    img = np.array(Image.open(test_eval_imgs[idx]).convert('RGB'))
    pred, _ = predict_single(model, img, CFG['img_size'], DEVICE)

    # Load the ground truth, convert it with convert_mask, and resize it to
    # img_size x img_size with NEAREST so the values stay valid class IDs
    # (presumably to match predict_single's output size — TODO confirm).
    gt_raw = np.array(Image.open(test_eval_masks[idx]))
    gt = convert_mask(gt_raw)
    gt_r = np.array(Image.fromarray(gt).resize(
        (CFG['img_size'], CFG['img_size']), Image.NEAREST))

    # Column 1: original input image, titled with its file name
    axes[i, 0].imshow(img)
    axes[i, 0].set_title(os.path.basename(test_eval_imgs[idx]), fontsize=9)
    axes[i, 0].axis('off')

    # Column 2: model's prediction (colorized)
    axes[i, 1].imshow(colorize_mask(pred))
    axes[i, 1].set_title('Prediction', fontsize=9)
    axes[i, 1].axis('off')

    # Column 3: ground truth (colorized)
    axes[i, 2].imshow(colorize_mask(gt_r))
    axes[i, 2].set_title('Ground Truth', fontsize=9)
    axes[i, 2].axis('off')

# Shared colour legend at the bottom of the figure: one patch per class
patches = [mpatches.Patch(color=CLASS_COLORS[c] / 255.0, label=CLASS_NAMES[c])
           for c in range(CFG['num_classes'])]
fig.legend(handles=patches, loc='lower center', ncol=5, fontsize=9)
plt.tight_layout()
plt.subplots_adjust(bottom=0.04)  # bottom margin set after tight_layout for the legend
plt.savefig(os.path.join(CFG['save_dir'], 'test_visualizations.png'), dpi=150, bbox_inches='tight')
plt.show()

## Cell 15 — Summary

In [None]:
# ============================================================
# FINAL SUMMARY — Everything in one place
# ============================================================
# Prints, in order: model/training config, headline validation and test
# metrics, the per-class test table, and what the saved checkpoints record.

# Validation mIoU history: best value, the 1-based epoch it occurred at, and
# the value from the last recorded epoch.
val_curve = history['val_miou']
n_total_epochs = len(history['train_loss'])
best_val_ep = int(np.argmax(val_curve)) + 1  # argmax is 0-based; epochs are reported 1-based
best_val = max(val_curve)
final_val = val_curve[-1]

banner = '=' * 70
print(banner)
print('FINAL SUMMARY — Desert Segmentation (SegFormer-B4)')
print(banner)
print()

# --- Model and training config ---
print(f'Model:            SegFormer-B4 ({CFG["model_name"]})')
print(f'Image size:       {CFG["img_size"]}x{CFG["img_size"]}')
print(f'Total epochs:     {n_total_epochs}')
print(f'Loss:             0.5*Focal + 0.3*Dice + 0.2*WeightedCE')
print()

# --- Key metrics (all shown as percentages) ---
print(f'{"Metric":<25} {"Value":>10}')
print('-' * 37)
print(f'{"Best val mIoU":<25} {best_val*100:>9.2f}%  (epoch {best_val_ep})')
print(f'{"Final val mIoU":<25} {final_val*100:>9.2f}%')
print(f'{"Test mIoU (TTA)":<25} {test_miou*100:>9.2f}%')
print(f'{"Test mDice (TTA)":<25} {test_mdice*100:>9.2f}%')
print(f'{"Test Pixel Acc":<25} {test_acc*100:>9.2f}%')
# These two rows appear only when the history tracked the metric
if 'val_mdice' in history:
    print(f'{"Best val mDice":<25} {max(history["val_mdice"])*100:>9.2f}%')
if 'val_acc' in history:
    print(f'{"Best val Pixel Acc":<25} {max(history["val_acc"])*100:>9.2f}%')
print()

# --- Per-class test results ---
table_rule = '-' * 43
print(f'{"Class":<20} {"Test IoU":>10} {"Test Dice":>11}')
print(table_rule)
for i in range(CFG['num_classes']):
    print(f'{CLASS_NAMES[i]:<20} {test_iou[i]*100:>9.2f}% {test_dice[i]*100:>10.2f}%')
print(table_rule)
print(f'{"Mean":<20} {test_miou*100:>9.2f}% {test_mdice*100:>10.2f}%')
print()

# --- Saved checkpoints ---
# Each checkpoint that exists is loaded onto the CPU only to read its
# recorded epoch and val mIoU; missing files are reported as such.
print('Checkpoints saved to Google Drive:')
for name in ['best_model.pth', 'latest_model_ft.pth']:
    path = os.path.join(CFG['save_dir'], name)
    if not os.path.exists(path):
        print(f'  [{name}] not found')
        continue
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # load checkpoints from a trusted source.
    ck = torch.load(path, map_location='cpu', weights_only=False)
    print(f'  [{name}] epoch {ck.get("epoch","?")}, val mIoU = {ck.get("miou",0):.4f}')

print(f'\nPredictions saved to: {CFG["output_dir"]}')