In [None]:
"""
--------------------------------
Disaster segmentation (FloodNet)

Objectives:
1. Load preprocessed data
2. Build U-Net with pretrained encoder
3. Train with Focal Loss (handle class imbalance)
4. Evaluate with IoU & Dice
5. Visualize predictions
"""

# ============================================================
# 1. Setup & Imports
# ------------------------------------------------------------
# Core, numerical, and plotting utilities + progress bars.
# ============================================================
import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

# Segmentation Models
import segmentation_models_pytorch as smp

# Add src to path
sys.path.append('../src')

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's built-in ``random``, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), then configures cuDNN for
    repeatable kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import keeps this helper self-contained; the notebook's import
    # cell does not bring in `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable

    # Ask cuDNN for deterministic algorithms and disable the auto-tuner,
    # which may otherwise pick different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# ============================================================
# Device selection
# ------------------------------------------------------------
# Choose GPU if available, otherwise CPU; print device info.
# ============================================================
# Prefer the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f" Device: {device}")

# Report the name and total memory of GPU 0 (useful when sizing batches).
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

 Device: cuda
   GPU: NVIDIA GeForce RTX 4050 Laptop GPU
   Memory: 6.4 GB


In [2]:
# ============================================================
# 2. Configuration
# ------------------------------------------------------------
# Project paths, model/training hyperparameters, early-stopping patience,
# and the FloodNet class names.
# ============================================================

class Config:
    """Central configuration: filesystem layout, model and training settings.

    Every value is a class attribute and is read as ``Config.NAME``.

    The project root defaults to the original local checkout. To run the
    notebook on another machine, set the ``DISASTER_SEG_BASE_DIR``
    environment variable before this cell executes; all other paths are
    derived from ``BASE_DIR``.
    """

    # Paths
    BASE_DIR = Path(os.environ.get(
        "DISASTER_SEG_BASE_DIR",
        r"D:\Projects\Image Segmentation for Disaster Resilience\Disaster-segmentation",
    ))  # project root (overridable via environment variable)
    DATA_DIR = BASE_DIR / "data" / "raw" / "FloodNet"  # raw FloodNet data

    TRAIN_IMAGES = DATA_DIR / "train" / "train-org-img"    # train RGBs
    TRAIN_MASKS  = DATA_DIR / "train" / "train-label-img"  # train masks
    VAL_IMAGES   = DATA_DIR / "val" / "val-org-img"        # val RGBs
    VAL_MASKS    = DATA_DIR / "val" / "val-label-img"      # val masks

    CHECKPOINT_DIR = BASE_DIR / "models" / "checkpoints"  # where checkpoints are saved
    LOG_DIR        = BASE_DIR / "logs"                    # training logs
    RESULTS_DIR    = BASE_DIR / "results"                 # visualizations & metrics

    # Model
    ENCODER_NAME    = "resnet34"     # encoder backbone for U-Net
    ENCODER_WEIGHTS = "imagenet"     # pretrained weights
    NUM_CLASSES     = 10             # segmentation classes; equals len(CLASS_NAMES)

    # Training
    IMG_SIZE = (256, 256) # (H, W) input size for the network
    BATCH_SIZE = 8
    NUM_WORKERS = 4
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4

    # Early stopping
    PATIENCE = 10

    # Class names (list index is the class id)
    CLASS_NAMES = [
        "Background", "Building-flooded", "Building-non-flooded",
        "Road-flooded", "Road-non-flooded", "Water",
        "Tree", "Vehicle", "Pool", "Grass"
    ]

# Create every output directory declared in Config up front, so later
# cells can write checkpoints, logs, and results without each having to
# create its own folder. mkdir(..., exist_ok=True) makes re-running safe.
Config.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
Config.LOG_DIR.mkdir(parents=True, exist_ok=True)
Config.RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Echo the key settings so the run is self-describing in the saved output
print(f"Base Directory: {Config.BASE_DIR}")
print(f"Image Size: {Config.IMG_SIZE}")
print(f"Batch Size: {Config.BATCH_SIZE}")
print(f"Epochs: {Config.EPOCHS}")


Base Directory: D:\Projects\Image Segmentation for Disaster Resilience\Disaster-segmentation
Image Size: (256, 256)
Batch Size: 8
Epochs: 50


In [None]:
# ============================================================
# 3. Dataset and DataLoaders
# ------------------------------------------------------------
# Load dataset module if available; otherwise define an inline
# FloodNetDataset + simple augmentations and DataLoader setup.
# ============================================================

# Import from our modules (or define inline if not found)
try:
    # Preferred path: the project's own dataset module, found through the
    # '../src' entry appended to sys.path in the setup cell.
    from data.dataset import FloodNetDataset, get_train_transform, get_val_transform
    print("Dataset module loaded")
except ImportError:
    # Fallback path: the remainder of this except-branch defines
    # same-named replacements inline so the cells below keep working.
    print("Using inline dataset definition")

    # Minimal inline dependencies
    # (Dataset is already imported in the setup cell; the re-import here
    # is redundant but harmless.)
    import cv2
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from torch.utils.data import Dataset
    
    class FloodNetDataset(Dataset):
        """
        Inline fallback Dataset for FloodNet.

        Each item is an ``(image, mask)`` pair:
        - the image is read with OpenCV, converted BGR -> RGB and resized;
        - the mask is read as a single-channel image (class ids) and resized
          with nearest-neighbour interpolation so ids are never blended;
        - if ``transform`` is given it is called as
          ``transform(image=..., mask=...)`` and its ``'image'`` / ``'mask'``
          entries are returned; otherwise the image becomes a float CHW
          tensor scaled to [0, 1].
        Tensor masks are always returned with dtype ``torch.long``.

        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the label masks.
            transform: Optional callable with the albumentations call
                convention described above.
            img_size: Target size as ``(H, W)``, the same convention used by
                ``Config.IMG_SIZE``.

        Raises:
            OSError: from ``__getitem__`` when an image, or a mask file that
                exists, cannot be decoded.
        """

        # File extensions (lower-case) treated as input images.
        IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

        def __init__(self, image_dir, mask_dir, transform=None, img_size=(256, 256)):
            self.image_dir = Path(image_dir)
            self.mask_dir = Path(mask_dir)
            self.transform = transform
            self.img_size = img_size
            self.images = sorted(
                p for p in self.image_dir.glob("*")
                if p.suffix.lower() in self.IMAGE_SUFFIXES
            )
            print(f"Loaded {len(self.images)} images")

            # Images without a mask fall back to an all-zero mask in
            # __getitem__; say so once here so the fallback is not silent.
            missing = sum(1 for p in self.images if self._get_mask_path(p) is None)
            if missing:
                print(f"Warning: {missing} images have no mask; all-zero masks will be used")

        @property
        def _dsize(self):
            """Target size in OpenCV order: ``cv2.resize`` wants (width, height)."""
            height, width = self.img_size
            return (width, height)

        def __len__(self):
            return len(self.images)

        def _get_mask_path(self, image_path):
            """
            Heuristic mask lookup:
            - Try same stem with .png
            - Try stem + '_lab'.png
            Returns Path or None.
            """
            img_stem = image_path.stem
            for pattern in [f"{img_stem}.png", f"{img_stem}_lab.png"]:
                mask_path = self.mask_dir / pattern
                if mask_path.exists():
                    return mask_path
            return None

        def __getitem__(self, idx):
            # Load image (BGR -> RGB). cv2.imread returns None rather than
            # raising when a file cannot be decoded, so check explicitly.
            img_path = self.images[idx]
            image = cv2.imread(str(img_path))
            if image is None:
                raise OSError(f"Could not read image: {img_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Load mask (fallback: zeros if no mask file was found)
            mask_path = self._get_mask_path(img_path)
            if mask_path is None:
                mask = np.zeros(image.shape[:2], dtype=np.uint8)
            else:
                mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    raise OSError(f"Could not read mask: {mask_path}")

            # Resize to target size (nearest for masks)
            dsize = self._dsize
            image = cv2.resize(image, dsize)
            mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)

            # Apply augmentations or convert to tensors if none provided
            if self.transform:
                augmented = self.transform(image=image, mask=mask)
                image, mask = augmented['image'], augmented['mask']
                # Keep the mask dtype consistent with the no-transform path.
                if isinstance(mask, torch.Tensor):
                    mask = mask.long()
            else:
                image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
                mask = torch.from_numpy(mask).long()

            return image, mask

    # -------------------------
    # Augmentations (train / val)
    # -------------------------
    def get_train_transform(img_size=(256, 256)):
        """Build the albumentations pipeline used for training samples.

        Random flips, 90-degree rotations, and a shift/scale/rotate come
        first, then a brightness/contrast jitter, then normalization and
        ToTensorV2.
        ``img_size`` is accepted but not referenced in this function.
        """
        geometric = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.3),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=30, p=0.5),
        ]
        photometric = [A.RandomBrightnessContrast(p=0.3)]
        to_model_input = [
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
        return A.Compose(geometric + photometric + to_model_input)

    def get_val_transform(img_size=(256, 256)):
        """Build the validation pipeline: normalization and ToTensorV2 only.

        No random augmentation is applied. ``img_size`` is accepted but not
        referenced in this function.
        """
        normalize = A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        )
        to_tensor = ToTensorV2()
        return A.Compose([normalize, to_tensor])

# -------------------------
# Instantiate datasets
# -------------------------
# Datasets: the train split gets the augmenting pipeline, the val split
# the validation pipeline; both use Config.IMG_SIZE.
train_dataset = FloodNetDataset(
    Config.TRAIN_IMAGES,
    Config.TRAIN_MASKS,
    transform=get_train_transform(Config.IMG_SIZE),
    img_size=Config.IMG_SIZE,
)
val_dataset = FloodNetDataset(
    Config.VAL_IMAGES,
    Config.VAL_MASKS,
    transform=get_val_transform(Config.IMG_SIZE),
    img_size=Config.IMG_SIZE,
)

# DataLoaders: identical options for both splits except shuffling, which
# is enabled for training only.
_loader_kwargs = dict(
    batch_size=Config.BATCH_SIZE,
    num_workers=Config.NUM_WORKERS,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# Quick summary of split sizes
print("\nDataset Summary:")
print(f"   Training: {len(train_dataset)} images ({len(train_loader)} batches)")
print(f"   Validation: {len(val_dataset)} images ({len(val_loader)} batches)")

# Sanity check: pull one training batch and report shapes, dtypes, and
# which class ids appear in its masks.
sample_img, sample_mask = next(iter(train_loader))
print("\nSample batch:")
print(f"   Images: {sample_img.shape}, dtype: {sample_img.dtype}")
print(f"   Masks: {sample_mask.shape}, dtype: {sample_mask.dtype}")
print(f"   Mask classes: {torch.unique(sample_mask).tolist()}")

Dataset module loaded
Loaded 1445 images from D:\Projects\Image Segmentation for Disaster Resilience\Disaster-segmentation\data\raw\FloodNet\train\train-org-img
Loaded 450 images from D:\Projects\Image Segmentation for Disaster Resilience\Disaster-segmentation\data\raw\FloodNet\val\val-org-img

Dataset Summary:
   Training: 1445 images (181 batches)
   Validation: 450 images (57 batches)
