# HERMES 0.3

In [1]:
#import signal
#os.kill(os.getpid(), signal.SIGKILL)
#if "COLAB_GPU" in os.environ or "COLAB_BACKEND_VERSION" in os.environ:
#    os.system("rm -rf /content/*")

In [2]:
!pip install albumentations tqdm segmentation-models-pytorch torchvision
!pip install --upgrade "docutils>=0.20,<0.22"
!pip install awscli -q
!pip install rasterio

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: segmentation-models-pytorch
Successfully installed segmentation-models-pytorch-0.5.0
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m129.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m149.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m570.5/570.5 kB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.7/85.7 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all th

## IMPORTS AND DEPENDENCIES

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DataParallel

#Computer Vision and image processing
import cv2
import numpy as np
from PIL import Image

#Torchvision for pretrained models and transforms
import torchvision
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_resnet101
#Data augmentation library
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

#Utilities
import gc
import os
import glob
import time
import zipfile
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import random
import seaborn as sns
from datetime import datetime
import json
import pickle
import pandas as pd
import geopandas as gpd
import rasterio as rio
from tqdm.auto import tqdm
from itertools import cycle
from functools import lru_cache
from collections import defaultdict
from datetime import timedelta
import psutil

os.environ.update({
    "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
    "GDAL_CACHEMAX": "2048",
    "CPL_VSIL_CURL_ALLOWED_EXTENSIONS": ".tif,.tiff"
})

#Optional: Weights & Biases for experiment tracking
#try:
  #import wandb
  #WANDB_AVAILABLE = True
#except ImportError:
  #WANDB_AVAILABLE = False
  #print("Warning: Weights & Biases not installed. Please install with `pip install wandb`")

## Config


In [4]:
#Set random seeds
def set_seed(seed=42):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True
  torch.backends.cudnn.benchmark = True

def prepare_model(model, config=None):
    """Prepare model for optimized training."""
    model = model.to(memory_format=torch.channels_last)

    if config and getattr(config, "compile_model", False) and hasattr(torch, "compile"):
        try:
            model = torch.compile(model)
            print("✓ Model compiled via torch.compile()")
        except Exception as e:
            print(f"⚠️ Skipping compile due to: {e}")
    else:
        print("ℹ️ Skipping torch.compile (disabled or unsupported).")

    return model

set_seed(42)

#Configuration and Parameters

#Optimized for A100
class Config:
  #model architecture
  backbone = 'resnet101' # May later upgrade to resnet152 or Vision Transformer
  num_classes_flood = 2
  num_classes_damage = 4

  mixed_precision = True # Essential for A100

  #training parameters
  batch_size = 32
  accumulation_steps = 4
  effective_batch_size = batch_size * accumulation_steps
  num_epochs = 50
  learning_rate = 1e-4
  weight_decay = 1e-4
  gradient_clip = 1.0
  warmup_epochs = 2


  # loss weights
  ce_weight = 0.3
  dice_weight = 0.4
  focal_weight = 0.3

  # task weights for multi-tasking learning
  flood_task_weight = 0.6
  damage_task_weight = 0.4

  #Memory Management
  gradient_checkpointing = False
  empty_cache_freq = 50
  compile_model = True

  #Multi-Scale Training
  multi_scale_training = True
  scale_range = (0.75, 1.25)


  # Create directories
  def __init__(self, base_dir="/content", use_tf32=True):
    #Core device and optimization setting
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.use_tf32 = use_tf32

    #data loading
    self.gradient_checkpoint = True
    self.num_workers = 0
    self.pin_memory = False
    self.prefetch_factor = 4
    self.persistent_workers = True
    self.spacenet_cache_size=50
    self.preload_count = 10

    # image dimensions
    self.tile_size = 512

    #Directory structure
    base = Path(base_dir)
    self.data_root = base / "data"
    self.checkpoint_dir = base / "checkpoints"
    self.cache_dir = base / "cache"
    self.results_dir = base / "results"


    #Directory creation
    for d in [self.data_root, self.checkpoint_dir, self.cache_dir, self.results_dir]:
        d.mkdir(exist_ok=True, parents=True)



    # Enable A100-specific optimizations
    if torch.cuda.is_available():
      # Enable TF32 for A100
      torch.backends.cuda.matmul.allow_tf32 = self.use_tf32
      torch.backends.cudnn.allow_tf32 = self.use_tf32

      # Set cudnn benchmarking for A100
      torch.backends.cudnn.benchmark = True
      torch.backends.cudnn.enabled = True

      # Check if we actually have an A100
      gpu_name = torch.cuda.get_device_name(0)
      if 'A100' in gpu_name:
          print(f"✓ A100 detected: {gpu_name}")
          print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
      else:
          print(f"⚠ Warning: GPU is {gpu_name}, not A100. Adjusting config...")
          self.batch_size = 8
          self.tile_size = 512
    else:
      print("⚠ Warning: No GPU detected. Adjusting config...")

    print(f"Directories initialized in {base_dir}")
    print(f"Device: {self.device}")

In [5]:
#Mapping Datasets

# Create destination directories
!mkdir -p /content/data/FloodNet
!mkdir -p /content/data/RescueNet
!mkdir -p /content/data/SpaceNet

# Copy datasets from Google Drive to Colab local SSD (with progress bar)
!rsync -ah --info=progress2 "/content/drive/MyDrive/G.E.M.S./FloodNet/" "/content/data/FloodNet/"
!rsync -ah --info=progress2 "/content/drive/MyDrive/G.E.M.S./RescueNet/" "/content/data/RescueNet/"
!rsync -ah --info=progress2 "/content/drive/MyDrive/G.E.M.S./sn8/" "/content/data/SpaceNet/"

         13.06G 100%    2.43MB/s    1:25:30 (xfr#7031, to-chk=0/7046)
         23.73G 100%    3.47MB/s    1:48:37 (xfr#8989, to-chk=0/8999)
          1.68G 100%    6.20MB/s    0:04:19 (xfr#610, to-chk=0/618)


In [6]:
#Pulling FloodNet dataset from GitHub DropBox due to file size at ~12GB

In [7]:
#!wget "https://www.dropbox.com/scl/fo/k33qdif15ns2qv2jdxvhx/ANGaa8iPRhvlrvcKXjnmNRc?rlkey=ao2493wzl1cltonowjdbrnp7f&e=3&st=6lg4ncwc&dl=1"

In [8]:
# import zipfile

# zip_path = '/content/ANGaa8iPRhvlrvcKXjnmNRc?rlkey=ao2493wzl1cltonowjdbrnp7f&e=3&st=6lg4ncwc&dl=1'
# extract_path = '/content/FloodNet'

# with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#   zip_ref.extractall(extract_path)

#   print("Extraction complete.")
#   print(os.listdir('/content/FloodNet'))

In [9]:
#Defining dataset pathways
FloodNet_train_img_dir = '/content/data/FloodNet/FloodNet-Supervised_v1.0/train/train-org-img'
FloodNet_train_mask_dir = '/content/data/FloodNet/FloodNet-Supervised_v1.0/train/train-label-img'


FloodNet_val_img_dir ='/content/data/FloodNet/FloodNet-Supervised_v1.0/val/val-org-img'
FloodNet_val_mask_dir = '/content/data/FloodNet/FloodNet-Supervised_v1.0/val/val-label-img'

FloodNet_test_img_dir = '/content/data/FloodNet/FloodNet-Supervised_v1.0/test/test-org-img'
FloodNet_test_mask_dir = '/content/data/FloodNet/FloodNet-Supervised_v1.0/test/test-label-img'

In [10]:
# os.remove("/content/ANGaa8iPRhvlrvcKXjnmNRc?rlkey=ao2493wzl1cltonowjdbrnp7f&e=3&st=6lg4ncwc&dl=1")
# print("Zip File Deleted")

In [11]:
#with zipfile.ZipFile("/content/drive/MyDrive/G.E.M.S./.zip", "r") as zip_ref:
  #zip_ref.extractall("/content/SN8")

#print("Extraction complete.")

In [12]:
SpaceNet8_train_img_dir = '/content/data/SpaceNet/images/train'
SpaceNet8_train_mask_dir = '/content/data/SpaceNet/masks/train'

SpaceNet8_val_img_dir = '/content/data/SpaceNet/images/val'
SpaceNet8_val_mask_dir = '/content/data/SpaceNet/masks/val'

In [13]:
#!pip install gdown
# import gdown

In [14]:
#!gdown --fuzzy "https://drive.google.com/file/d/1iRkEX9LQ8Hi-38QMyaReFJ8wXDDyYJAg/view?usp=sharing" -O RescueNet.zip

In [15]:
# import zipfile

# with zipfile.ZipFile("/content/drive/MyDrive/G.E.M.S./RescueNet.zip", "r") as zip_ref:
#      zip_ref.extractall("/content/RescueNet")
# \
# print("Extraction complete.")

In [16]:
print(os.listdir('/content/data'))
#print(os.listdir('/content/RescueNet'))

['FloodNet', 'RescueNet', 'SpaceNet']


In [17]:
RescueNet_train_img_dir = '/content/data/RescueNet/train/train-org-img'
RescueNet_train_mask_dir = '/content/data/RescueNet/train/train-label-img'

RescueNet_val_img_dir = '/content/data/RescueNet/val/val-org-img'
RescueNet_val_mask_dir = '/content/data/RescueNet/val/val-label-img'

RescueNet_test_img_dir = '/content/data/RescueNet/test/test-org-img'
RescueNet_test_mask_dir = '/content/data/RescueNet/test/test-label-img'

In [18]:
#os.remove("/content/RescueNet.zip")
#print("Files Deleted")

In [19]:
sn8_path = "/content/data/SpaceNet"
sn8_train_manifest = "/content/data/SpaceNet/manifests/train_manifest.csv"
sn8_val_manifest = "/content/data/SpaceNet/manifests/val_manifest.csv"

## Loading Datasets

In [20]:
class FloodNetDataset(Dataset):
    """
    Dataset class for FloodNet flood detection data

    Expected file structure:
    - Images: original disaster images
    - Masks: binary masks (0: non-flooded, 1: flooded)
    """
    FLOODNET_COLORS = {
        (0,0,0):0, #Unlabeled
        (128,0,0):1, #Building-flooded
        (0,128,0):2, # Buildings-non-flooded
        (128,128,0):3, #Road-flooded
        (0,0,128):4, #Road-non-flooded
        (128,0,128):5, #Water
        (0,128,128):6, #Tree
        (128,128,128):7,#Vehicle
        (64,0,0):8, #Pool
        (192,0,0):9, #Grass

    }


    def __init__(self, img_dir, mask_dir, transform=None, binary_flood=True):
        """
        Args:
            img_dir: Directory containing input images
            mask_dir: Directory containing segmentation masks
            transform: Albumentations transform pipeline
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_names = sorted(os.listdir(img_dir))
        self.transform = transform
        self.binary_flood= binary_flood

        # Filter images that have corresponding masks
        all_imgs = sorted(os.listdir(img_dir))
        self.img_names = []

        for img_name in all_imgs:
          base_name = img_name.rsplit(".", 1)[0]
          pattern = os.path.join(mask_dir, base_name + "_*.png")
          if glob.glob(pattern):
            self.img_names.append(img_name)
          else:
            print(f"Warning: No mask found for {img_name}, excluding from dataset")

        print(f"FloodNet dataset: {len(all_imgs)} images, {len(self.img_names)} with masks")


    def rgb_to_class(self, mask_rgb):
        """Convert RGB mask to class indices"""
        h, w = mask_rgb.shape[:2]
        mask_class = np.zeros((h,w), dtype=np.uint8)

        for rgb, class_id in self.FLOODNET_COLORS.items():
            #Find pixels matching this RGB value
            matches = np.all(mask_rgb == rgb, axis=-1)
            mask_class[matches] = class_id

        return mask_class

    def create_flood_mask(self, class_mask):
      """Create binary flood/non-_flood mask from classes"""
      flood_mask = np.zeros_like(class_mask, dtype=np.uint8)
      flooded_classes = [1,3,5] # Building-flooded


      for class_id in flooded_classes:
        flood_mask[class_mask == class_id] = 1

      return flood_mask

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load image and mask pair with error handling"""
        img_name = self.img_names[idx]

        # Extract base filename without extension
        base_name = img_name.rsplit(".", 1)[0]

        # Find matching mask (FloodNet naming pattern)
        pattern = os.path.join(self.mask_dir, base_name + "_*.png")
        matching_masks = glob.glob(pattern)

        if len(matching_masks) == 0:
            raise FileNotFoundError(f"No matching mask for: {img_name}")

        mask_path = matching_masks[0]
        img_path = os.path.join(self.img_dir, img_name)

        # Load image (OpenCV loads as BGR, convert to RGB)
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Failed to load image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load mask as grayscale
        mask_rgb = cv2.imread(mask_path)
        if mask_rgb is None:
            raise FileNotFoundError(f"Failed to load mask: {mask_path}")
        mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_BGR2RGB)

        # Convert RGB mask to class indices
        class_mask = self.rgb_to_class(mask_rgb)

        # Create appropriate output mask
        if self.binary_flood:
          mask = self.create_flood_mask(class_mask)
        else:
          mask = class_mask


        # Apply augmentations
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask.long()

In [21]:
class RescueNetDataset(Dataset):
    """
    Dataset class for RescueNet damage assessment data

    Expected file structure:
    - Images: original disaster images
    - Masks: multi-class masks (0-3 for damage levels)
    """
    def __init__(self, img_dir, mask_dir, transform=None):
        """
        Args:
            img_dir: Directory containing input images
            mask_dir: Directory containing segmentation masks
            transform: Albumentations transform pipeline
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_names = sorted(os.listdir(img_dir))
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load image and mask pair with error handling"""
        img_name = self.img_names[idx]

        # Extract base filename without extension
        base_name = img_name.rsplit(".", 1)[0]

        # Find matching mask (RescueNet naming pattern)
        pattern = os.path.join(self.mask_dir, base_name + "_lab*.png")
        matching_masks = glob.glob(pattern)

        if len(matching_masks) == 0:
            raise FileNotFoundError(f"No matching mask for: {img_name}")

        mask_path = matching_masks[0]
        img_path = os.path.join(self.img_dir, img_name)

        # Load image
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Failed to load image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Failed to load mask: {mask_path}")

        # Ensure mask values are in valid range (0-3)
        mask = np.clip(mask, 0, 3).astype(np.uint8)

        # Apply augmentations
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask.long()

In [22]:
class SpaceNet8Dataset(Dataset):
    def __init__(self, manifest_csv, augment=None,
                 cache_size=50, disaster_focus=True,
                 preload_critical_samples=True, config=None):
        self.df = pd.read_csv(manifest_csv)
        self.augment = augment
        self.disaster_focus = disaster_focus
        self.preload_critical_samples = preload_critical_samples
        self.config = config

        # Quick LRU cache for frequently accessed files
        from functools import lru_cache
        self._load_file = lru_cache(maxsize=cache_size)(self._load_file_uncached)

        if disaster_focus:
            self._prioritize_disaster_samples()  # Added underscore

        self.gpu_preloaded = {}
        if preload_critical_samples and torch.cuda.is_available():
            self._preload_disaster_samples(self.config.preload_count)  # Added underscore

        print(f"Disaster-focused SpaceNet8: {len(self.df)} samples")
        print(f"Cache size: {cache_size}, GPU preloaded: {len(self.gpu_preloaded)}")

    def _prioritize_disaster_samples(self):
        """Prioritize samples likely to contain flood/damage for training efficiency"""
        file_sizes = []
        for _, row in self.df.iterrows():
            try:
                mask_size = os.path.getsize(row["mask_path"]) if os.path.exists(row["mask_path"]) else 0
                file_sizes.append(mask_size)
            except:
                file_sizes.append(0)

        self.df["mask_size"] = file_sizes
        self.df = self.df.sort_values(by='mask_size', ascending=False).reset_index(drop=True)
        print(f"Prioritized {len(self.df)} samples by disaster content likelihood")

    def _preload_disaster_samples(self, preload_count):
      """Preload disaster-rich samples to GPU for faster training"""
      print(f"Preloading top {preload_count} disaster samples to A100 GPU memory...")

      # ADDED: Check available GPU memory
      if torch.cuda.is_available():
        available_mem = torch.cuda.get_device_properties(0).total_memory * 0.3  # Use max 30% for preload
        estimated_sample_size = self.config.tile_size * self.config.tile_size * 3 * 4  # RGB float32
        max_preload = int(available_mem / estimated_sample_size)
        preload_count = min(preload_count, max_preload, len(self.df))
        print(f"  Adjusted preload count to {preload_count} based on available memory")

      for idx in range(preload_count):
        row = self.df.iloc[idx]
        try:
            img, mask = self._load_file_uncached(row["post_path"], row["mask_path"])

            img_tensor = torch.from_numpy(img).float() / 255.0
            mask_tensor = torch.from_numpy(mask).long()

            self.gpu_preloaded[idx] = (
                img_tensor.cuda(non_blocking=True),
                mask_tensor.cuda(non_blocking=True)
            )
        except RuntimeError as e:  # ADDED: Handle OOM during preload
            if "out of memory" in str(e).lower():
                print(f"  OOM during preload at sample {idx}, stopping preload")
                torch.cuda.empty_cache()
                break
            raise e
        except Exception as e:
            print(f"Failed to preload sample {idx}: {e}")

      print(f"Successfully preloaded {len(self.gpu_preloaded)} disaster-rich samples to GPU")

    def _load_file_uncached(self, post_path, mask_path):
        """Load and cache raw file for optimization"""
        try:
            with rio.Env(
                GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
                GDAL_CACHEMAX=2048,
                CPL_VSIL_CURL_ALLOWED_EXTENSIONS='.tif,.tiff',  # Fixed comma
                GDAL_NUM_THREADS='ALL_CPUS'
            ):
                with rio.open(post_path) as src:
                    # (C,H,W) order - FIXED INDENTATION
                    img = src.read([1, 2, 3])
                with rio.open(mask_path) as src:
                    mask = src.read(1)
            return img.astype(np.uint8), mask.astype(np.int64)  # Fixed dtype

        except Exception as e:
            print(f"Failed to load file {os.path.basename(post_path)}: {e}")
            return (np.zeros((3, self.config.tile_size, self.config.tile_size), dtype=np.uint8),
                    np.zeros((self.config.tile_size, self.config.tile_size), dtype=np.int64))  # Fixed dtype

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Check GPU preloaded samples first (fastest)
        if idx in self.gpu_preloaded:
            img_tensor, mask_tensor = self.gpu_preloaded[idx]

            if self.augment:
                img_np = img_tensor.cpu().numpy().transpose(1, 2, 0)  # Fixed comma
                mask_np = mask_tensor.cpu().numpy()
                augmented = self.augment(image=img_np, mask=mask_np)
                return augmented["image"], augmented["mask"]

            return img_tensor, mask_tensor

        # If not in cache, load from disk
        row = self.df.iloc[idx]
        img, mask = self._load_file(row["post_path"], row["mask_path"])

        # Convert to HWC for Albumentations
        img = np.moveaxis(img, 0, -1)
        binary_flood_mask = (mask > 0).astype(np.int64)

        # Apply augmentation if given
        if self.augment:
            augmented = self.augment(image=img, mask=binary_flood_mask)  # Fixed mask variable
            return augmented["image"], augmented["mask"]

        # Back to CHW + normalize
        else:
            img_tensor = torch.from_numpy(np.moveaxis(img, -1, 0)).float() / 255.0
            mask_tensor = torch.from_numpy(binary_flood_mask).long()
            return img_tensor, mask_tensor

    def get_raw_mask(self, idx):
        """Get raw mask without augmentation for weight computation"""
        row = self.df.iloc[idx]
        try:
            with rio.open(row["mask_path"]) as src:
                mask = src.read(1)
            return mask.astype(np.int64)
        except Exception as e:
            print(f"Failed to load mask {idx}: {e}")
            return np.zeros((self.config.tile_size, self.config.tile_size), dtype=np.int64)

class SpaceNet8CompatWrapper(Dataset):
    """Wrapper to make SpaceNet8 compatible with FloodNet"""
    def __init__(self, spacenet_dataset):
        self.spacenet_dataset = spacenet_dataset

    def __len__(self):
        return len(self.spacenet_dataset)

    def __getitem__(self, idx):
        img_tensor, mask_tensor = self.spacenet_dataset[idx]
        # Ensure mask is binary for flood detection (0=no flood, 1=flood)
        # SpaceNet8 uses: 0=no building, 1=no damage, 2=minor, 3=major, 4=destroyed
        # Convert to binary: 0=no flood, 1=flood (any building damage indicates potential flooding)
        binary_mask = (mask_tensor > 0).long()

        return img_tensor, binary_mask



In [23]:
def compute_class_weights(dataset, num_classes):
    """Fast vectorized class weight computation - bypasses augmentation"""
    import rasterio as rio

    class_counts = torch.zeros(num_classes, dtype=torch.float64)

    print("Fast class weight computation (bypassing augmentation)...")

    # Direct mask access without augmentation
    if hasattr(dataset, 'df'):  # For SpaceNet8Dataset
        for idx in tqdm(range(len(dataset)), desc="Analyzing masks"):
            row = dataset.df.iloc[idx]
            try:
                # Read mask directly - skip augmentation pipeline
                with rio.open(row["mask_path"]) as src:
                    mask = src.read(1)
                # Vectorized histogram
                mask_flat = torch.from_numpy(mask.flatten()).long()
                counts = torch.bincount(mask_flat, minlength=num_classes)
                class_counts += counts
            except Exception as e:
                continue
    else:  # For other dataset types
        # Fallback to original but optimized
        for idx in tqdm(range(len(dataset)), desc="Analyzing masks"):
            try:
                # Get raw mask without augmentation
                if hasattr(dataset, 'get_raw_mask'):
                    mask = dataset.get_raw_mask(idx)
                else:
                    # Temporarily disable augmentation
                    orig_augment = dataset.augment
                    dataset.augment = None
                    _, mask = dataset[idx]
                    dataset.augment = orig_augment

                if isinstance(mask, torch.Tensor):
                    mask_flat = mask.flatten()
                else:
                    mask_flat = torch.from_numpy(mask.flatten()).long()

                counts = torch.bincount(mask_flat, minlength=num_classes)
                class_counts += counts
            except Exception as e:
                continue

    # Compute weights
    total = class_counts.sum()
    weights = total / (num_classes * class_counts.clamp(min=1))
    weights = weights / weights.sum() * num_classes  # Normalize

    return weights.float()

## Data Augmentation Pipeline

In [24]:
def get_training_augmentation(config, training=True):
    """
    Training augmentation pipeline with various geometric and color transforms
    """
    return A.Compose([
        # Resize to target dimensions
        A.Resize(config.tile_size, config.tile_size),

        # Geometric transforms
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),  # Less common but useful for aerial images
        A.RandomRotate90(p=0.5),

        # Slight rotations and shifts
        A.ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.1,
            rotate_limit=15,
            border_mode=cv2.BORDER_REFLECT,
            p=0.5
        ),

        # Color augmentations
        A.RandomBrightnessContrast(
            brightness_limit=0.2,
            contrast_limit=0.2,
            p=0.5
        ),
        A.ColorJitter(
            brightness=0.1,
            contrast=0.1,
            saturation=0.1,
            hue=0.05,
            p=0.3
        ),

        # Weather effects (useful for disaster scenarios)
        A.RandomRain(p=0.1),
        A.RandomFog(p=0.1),

        # Noise and blur
        A.GaussNoise(per_channel=True, p=0.2),
        A.GaussianBlur(blur_limit=(3, 7), p=0.2),

        # Normalize with ImageNet statistics
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225)
        ),

        # Convert to PyTorch tensor
        ToTensorV2()
    ])

In [25]:
def get_validation_augmentation(config):
    """
    Validation augmentation pipeline (only essential transforms)
    """
    return A.Compose([
        A.Resize(config.tile_size, config.tile_size),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225)
        ),
        ToTensorV2()
    ])

In [26]:
def get_test_augmentation():
    """
    Test augmentation pipeline (same as validation)
    """
    return get_validation_augmentation()

## Attention models

In [27]:
class ChannelAttention(nn.Module):
    """
    Channel attention module to focus on important feature channels
    Squeeze-and-Excitation style attention mechanism
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()

        # Average pooling path
        avg_out = self.fc(self.avg_pool(x).view(b, c))

        # Max pooling path
        max_out = self.fc(self.max_pool(x).view(b, c))

        # Combine and apply attention
        attention = avg_out + max_out
        return x * attention.view(b, c, 1, 1)

In [28]:
class SpatialAttention(nn.Module):
    """
    Spatial attention module to focus on important spatial regions
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel-wise average and max
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)

        # Concatenate and convolve
        x_cat = torch.cat([avg_out, max_out], dim=1)
        attention = self.sigmoid(self.conv(x_cat))

        return x * attention

# Multi-Task Model

In [29]:
class EnhancedDisasterModel(nn.Module):
    """
    Enhanced multi-task segmentation model for disaster assessment

    Features:
    - Shared backbone for feature extraction
    - Task-specific heads with attention mechanisms
    - Feature fusion for cross-task learning
    - Deep supervision options
    """
    def __init__(self, num_classes_flood=2, num_classes_damage=4, backbone='resnet101', config=None):
        super().__init__()
        self.config = config

        self.memory_format = torch.channels_last if torch.cuda.is_available() else torch.contiguous_format

        # Initialize backbone based on configuration
        if backbone == 'resnet101':
            base_model = deeplabv3_resnet101(pretrained=True)
            print("Using ResNet101 backbone")
        else:
            base_model = deeplabv3_resnet50(pretrained=True)
            print("Using ResNet50 backbone")

        # Extract backbone and ASPP module
        self.backbone = base_model.backbone
        self.aspp = base_model.classifier[0]

        # Attention mechanisms for feature refinement
        self.channel_attention = ChannelAttention(256)
        self.spatial_attention = SpatialAttention()

        # Task-specific feature extraction branches
        # Flood detection branch
        self.flood_branch = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),

            nn.Conv2d(512, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),

            nn.Conv2d(256, 128, 3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        # Damage assessment branch
        self.damage_branch = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),

            nn.Conv2d(512, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),

            nn.Conv2d(256, 128, 3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        # Classification heads
        self.flood_classifier = nn.Conv2d(128, num_classes_flood, 1)
        self.damage_classifier = nn.Conv2d(128, num_classes_damage, 1)

        # Optional: Cross-task feature fusion
        self.enable_fusion = True
        if self.enable_fusion:
            self.fusion_conv = nn.Sequential(
                nn.Conv2d(256, 128, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, 1, bias=False),
                nn.BatchNorm2d(256),
                nn.Sigmoid()
            )
        if hasattr(config, 'gradient_checkpoint') and config.gradient_checkpoint:
            if hasattr(self.backbone, 'gradient_checkpointing_enable'):
                self.backbone.gradient_checkpointing_enable()
                print("Gradient checkpointing enabled - trading compute for memory")


    def forward(self, x, return_features=False):
        """
        Forward pass through the network

        Args:
            x: Input tensor [B, 3, H, W]
            return_features: Whether to return intermediate features

        Returns:
            flood_out: Flood segmentation output [B, 2, H, W]
            damage_out: Damage segmentation output [B, 4, H, W]
            features (optional): Intermediate features for visualization
        """
        # Store input shape for upsampling
        input_shape = x.shape[-2:]

        # Extract multi-level features from backbone
        features = self.backbone(x)

        # Get high-level features
        x = features['out']

        # Apply ASPP for multi-scale context
        x = self.aspp(x)

        # Apply attention mechanisms
        x = self.channel_attention(x)
        x = self.spatial_attention(x)

        # Task-specific processing
        flood_features = self.flood_branch(x)
        damage_features = self.damage_branch(x)

        # Optional cross-task feature fusion
        if self.enable_fusion:
            # Concatenate task features
            combined = torch.cat([flood_features, damage_features], dim=1)

            # Generate fusion weights
            fusion_weights = self.fusion_conv(combined)

            flood_channels = flood_features.shape[1]
            damage_channels = damage_features.shape[1]

            # Apply fusion
            flood_features = flood_features + fusion_weights[:, :flood_channels] * damage_features
            damage_features = damage_features + fusion_weights[:, flood_channels:flood_channels+damage_channels] * flood_features

        # Generate predictions
        flood_out = self.flood_classifier(flood_features)
        damage_out = self.damage_classifier(damage_features)

        # Upsample to original resolution
        flood_out = F.interpolate(
            flood_out, size=input_shape,
            mode='bilinear', align_corners=False
        )
        damage_out = F.interpolate(
            damage_out, size=input_shape,
            mode='bilinear', align_corners=False
        )

        if return_features:
            return flood_out, damage_out, flood_features, damage_features

        return flood_out, damage_out

## Custom Loss Functions

In [30]:
class DisasterFocusedLoss(nn.Module):
    """Loss function optimized for disaster mapping class imbalances"""

    def __init__(self, ce_weight=0.3, dice_weight=0.4, focal_weight=0.3,
                 class_weights=None, focal_alpha=0.25, focal_gamma=2.0):
        super().__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, pred, target):
        # Cross-entropy with disaster class weighting
        ce = self.ce_loss(pred, target)

        # Dice loss - critical for disaster boundary accuracy
        dice = self._dice_loss_disaster(pred, target)

        # Focal loss - handle severe class imbalance in disaster data
        focal = self._focal_loss_disaster(pred, target)

        total_loss = (
            self.ce_weight * ce +
            self.dice_weight * dice +
            self.focal_weight * focal
        )

        return total_loss, {
            'ce': ce.item(),
            'dice': dice.item(),
            'focal': focal.item()
        }

    def _dice_loss_disaster(self, pred, target):
        """Dice loss optimized for disaster mapping accuracy"""
        pred_soft = F.softmax(pred, dim=1)
        target_one_hot = F.one_hot(target, pred.shape[1]).permute(0, 3, 1, 2).float()

        # Flatten for efficient computation
        pred_flat = pred_soft.reshape(pred_soft.shape[0], pred_soft.shape[1], -1)
        target_flat = target_one_hot.reshape(target_one_hot.shape[0], target_one_hot.shape[1], -1)

        intersection = (pred_flat * target_flat).sum(dim=2)
        union = pred_flat.sum(dim=2) + target_flat.sum(dim=2)

        # Smooth dice with small epsilon for disaster data stability
        dice = (2 * intersection + 1e-6) / (union + 1e-6)

        return 1 - dice.mean()

    def _focal_loss_disaster(self, pred, target):
        """Focal loss with disaster-specific parameters"""
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)

        # Disaster-tuned focal loss
        focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss

        return focal_loss.mean()

dice_loss_flood = smp.losses.DiceLoss(mode='binary')

bce_loss_flood = nn.BCEWithLogitsLoss()

def criterion_flood(pred, target):
  return dice_loss_flood(pred, target) + bce_loss_flood(pred, target)

criterion_damage = DisasterFocusedLoss(
    ce_weight=0.3,
    dice_weight=0.4,
    focal_weight=0.3,
    focal_alpha=0.25,
    focal_gamma=2.0
)

def joint_loss(preds, targets):
  """
  preds: dict with {"flood": flood_out, "damage": damage_out}
  targets: dict with {"flood": flood_mask, "damage": damage_mask}
  """
  # Flood: binary dice + BCE
  loss_flood = criterion_flood(preds["flood"], targets["flood"])

  # Damage: multi-class dice + focal loss
  loss_damage, components = criterion_damage(preds["damage"], targets["damage"])

  total_loss = 0.5 * loss_flood + 0.5 * loss_damage
  return total_loss, {"flood": loss_flood.item(), **components}



## Evaluation Metrics


In [31]:
class MetricCalculator:
    """
    Optimized vectorized segmentation metrics calculator
    """
    @staticmethod
    def calculate_iou(pred, target, num_classes):
        """
        Vectorized IoU calculation - much faster
        """
        if pred.dim() > 1:
            pred = pred.view(-1)
            target = target.view(-1)

        # Vectorized computation for all classes at once
        ious = []
        for cls in range(num_classes):
            pred_mask = (pred == cls)
            target_mask = (target == cls)

            intersection = (pred_mask & target_mask).sum().float()
            union = (pred_mask | target_mask).sum().float()

            if union == 0:
                ious.append(float('nan'))
            else:
                ious.append((intersection / union).item())

        return ious

    @staticmethod
    def calculate_metrics_batch(preds, targets, num_classes):
        """
        Calculate IoU and Dice for entire batch at once - NEW METHOD
        """
        batch_size = preds.shape[0]
        device = preds.device

        # Flatten spatial dimensions
        preds_flat = preds.view(batch_size, -1)
        targets_flat = targets.view(batch_size, -1)

        # Preallocate results
        ious = torch.zeros(batch_size, num_classes, device=device)
        dices = torch.zeros(batch_size, num_classes, device=device)

        for cls in range(num_classes):
            pred_mask = (preds_flat == cls).float()
            target_mask = (targets_flat == cls).float()

            intersection = (pred_mask * target_mask).sum(dim=1)
            union = ((pred_mask + target_mask) > 0).float().sum(dim=1)
            pred_sum = pred_mask.sum(dim=1)
            target_sum = target_mask.sum(dim=1)

            # IoU
            valid_union = union > 0
            ious[valid_union, cls] = intersection[valid_union] / union[valid_union]
            ious[~valid_union, cls] = float('nan')

            # Dice
            dice_denom = pred_sum + target_sum
            valid_dice = dice_denom > 0
            dices[valid_dice, cls] = 2 * intersection[valid_dice] / dice_denom[valid_dice]
            dices[~valid_dice, cls] = float('nan')

        return ious, dices

    @staticmethod
    def calculate_dice(pred, target, num_classes):
        """
        Vectorized Dice calculation
        """
        if pred.dim() > 1:
            pred = pred.view(-1)
            target = target.view(-1)

        dices = []
        for cls in range(num_classes):
            pred_mask = (pred == cls).float()
            target_mask = (target == cls).float()

            intersection = (pred_mask * target_mask).sum()
            dice_denom = pred_mask.sum() + target_mask.sum()

            if dice_denom == 0:
                dices.append(float('nan'))
            else:
                dices.append((2 * intersection / dice_denom).item())

        return dices

    @staticmethod
    def calculate_pixel_accuracy(pred, target):
        """
        Calculate overall pixel accuracy - optimized
        """
        return (pred == target).float().mean().item()

## Pretraining Flood model with SpaceNet

In [32]:
# Create a separate pretraining function
def pretrain_flood_on_spacenet(model, spacenet_loader, num_epochs=10):
    print("Pretraining flood detection on SpaceNet8...")

    # Only optimize flood-related parameters
    pretrain_optimizer = torch.optim.Adam([
        {'params': model.backbone.parameters(), 'lr': 1e-5},
        {'params': model.flood_branch.parameters(), 'lr': 5e-5},
        {'params': model.flood_classifier.parameters(), 'lr': 5e-5}
    ])

    for epoch in range(num_epochs):
      for images, masks in tqdm(spacenet_loader):
        flood_out, _ = model(images)
        loss = F.cross_entropy(flood_out, masks)

        pretrain_optimizer.zero_grad()
        loss.backward()
        pretrain_optimizer.step()


In [None]:
# This should be BEFORE creating the trainer, not during training
if not class_weights_computed:  # Add a flag to ensure it's only done once
    print("Computing class weights (this may take a minute)...")
    flood_class_weights = compute_class_weights(flood_train_dataset, num_classes=2)
    damage_class_weights = compute_class_weights(damage_train_dataset, num_classes=4)
    class_weights_computed = True
    print("Class weights computed!")

## Optimized Trainer

In [33]:
class OptimizedTrainer:
    def __init__(self, model, config, device='cuda'):
        self.model = model
        self.config = config
        self.device = config.device

        # A100 optimizations
        self._enable_optimizations()

        # Mixed precision training
        self.scaler = torch.amp.GradScaler('cuda', init_scale=2**10, growth_interval=100)

        # Optimizer with task-specific learning rates
        self.optimizer = self._create_disaster_optimizer()

        # Loss functions
        self.flood_loss_fn = self._create_disaster_loss('flood')
        self.damage_loss_fn = self._create_disaster_loss('damage')

        # Metrics tracking
        self.metric_calculator = MetricCalculator()
        self.disaster_metrics = defaultdict(list)
        self.best_disaster_iou = {'flood': 0.0, 'damage': 0.0}

        # Training history tracking
        self.train_losses = {'flood': [], 'damage': [], 'total': []}
        self.val_metrics = {'flood_iou': [], 'damage_iou': []}
        self.best_flood_iou = 0.0
        self.best_damage_iou = 0.0
        self.best_combined_score = 0.0

        print("Disaster mapping trainer initialized for A100")

    def _enable_optimizations(self):
        """Enable A100-specific optimizations"""
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_per_process_memory_fraction(0.8)
        print("A100 optimizations enabled")

    def _create_disaster_optimizer(self):
        """Create optimizer with disaster-focused learning rates"""
        param_groups = [
            {
                'params': [p for n, p in self.model.backbone.named_parameters()],
                'lr': self.config.learning_rate * 0.1,
                'weight_decay': self.config.weight_decay
            },
            {
                'params': [p for n, p in self.model.aspp.named_parameters()],
                'lr': self.config.learning_rate * 0.3,
                'weight_decay': self.config.weight_decay * 0.5
            },
            {
                'params': list(self.model.flood_branch.parameters()) +
                          list(self.model.flood_classifier.parameters()),
                'lr': self.config.learning_rate * 1.2,
                'weight_decay': self.config.weight_decay * 0.8
            },
            {
                'params': list(self.model.damage_branch.parameters()) +
                          list(self.model.damage_classifier.parameters()),
                'lr': self.config.learning_rate * 1.5,
                'weight_decay': self.config.weight_decay * 0.6
            }
        ]

        return torch.optim.AdamW(
            param_groups,
            eps=1e-8,
            betas=(0.9, 0.999),
            amsgrad=True
        )

    def _create_disaster_loss(self, task):
        """Create disaster-focused loss function"""
        if task == 'flood':
            class_weights = torch.tensor([0.2, 3.5]).cuda()
            alpha = 0.8
        else:
            class_weights = torch.tensor([0.3, 1.5, 2.2, 2.8]).cuda()
            alpha = 0.6

        return DisasterFocusedLoss(
            ce_weight=0.3,
            dice_weight=0.4,
            focal_weight=0.3,
            class_weights=class_weights,
            focal_alpha=alpha
        )

    def train_epoch(self, flood_loader, damage_loader, epoch):
      """A100-optimized training epoch for disaster mapping"""
      import time
      self.model.train()

      flood_cycle = cycle(flood_loader)
      damage_cycle = cycle(damage_loader)

      num_iterations = max(len(flood_loader), len(damage_loader))
      epoch_metrics = defaultdict(float)

      print(f"Disaster mapping training epoch {epoch}: {num_iterations} iterations")

      # Initialize gradient accumulation
      self.optimizer.zero_grad(set_to_none=True)
      pbar = tqdm(range(num_iterations), desc=f"Epoch {epoch+1}", leave=False)

      times = {'data_load': [], 'forward':[], 'total':[]}

      for i in pbar:
          print(f"DEBUG: Starting iteration {i}")
          iter_start = time.time()
          try:
              print(f"DEBUG: Loading flood batch...")
              data_start = time.time()
              flood_batch = next(flood_cycle)
              print(f"DEBUG: Flood batch loaded in {time.time()-data_start:.2f}s")

              print(f"DEBUG: Loading damage batch...")
              damage_batch = next(damage_cycle)
              print(f"DEBUG: Damage batch loaded")

              times['data_load'].append(time.time() - data_start)

              forward_start = time.time()
              metrics = self._simultaneous_disaster_step(flood_batch, damage_batch, i)

              times['forward'].append(time.time() - forward_start)

              # Accumulate metrics
              for key, value in metrics.items():
                  epoch_metrics[key] += value

              # Accumulation boundary -> optimizer step
              if (i + 1) % self.config.accumulation_steps == 0:
                  self._disaster_gradient_update()

              times['total'].append(time.time() - iter_start)

            # Progress reporting
              if (i + 1) % 20 == 0:
                  if len(times['total']) >= 20:
                      avg_data = sum(times['data_load'][-20:]) / 20
                      avg_forward = sum(times['forward'][-20:]) / 20
                      avg_total = sum(times['total'][-20:]) / 20

                      print(f"\n⏱️  Timing (last 20 iters):")
                      print(f"   Data load: {avg_data:.3f}s ({avg_data/avg_total*100:.1f}%)")
                      print(f"   Forward+backward: {avg_forward:.3f}s ({avg_forward/avg_total*100:.1f}%)")
                      print(f"   Total: {avg_total:.3f}s/iter")

                  flood_loss = epoch_metrics['flood_loss'] / (i + 1)
                  damage_loss = epoch_metrics['damage_loss'] / (i + 1)
                  mem_gb = (torch.cuda.max_memory_allocated() / 1e9) if torch.cuda.is_available() else 0
                  pbar.set_postfix({
                  'flood_loss': f"{flood_loss:.4f}",
                  'damage_loss': f"{damage_loss:.4f}",
                  'mem': f"{mem_gb:.1f}GB"
                  })

                  if (i + 1) % 50 == 0:
                      torch.cuda.empty_cache()

          except RuntimeError as e:
            # Robust OOM handling (skip batch, keep training)
            if "out of memory" in str(e).lower():
                print(f"⚠️ GPU OOM at iter {i}, skipping...")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                # Clear any partial grads from the failed step
                self.optimizer.zero_grad(set_to_none=True)
                continue
            # Re-raise non-OOM errors
            raise

      # Flush remaining accumulated grads if loop ended mid-accumulation
      if num_iterations % self.config.accumulation_steps != 0:
          self._disaster_gradient_update()

      # Epoch averages
      epoch_flood_loss = epoch_metrics['flood_loss'] / max(1, num_iterations)
      epoch_damage_loss = epoch_metrics['damage_loss'] / max(1, num_iterations)
      epoch_total_loss = epoch_flood_loss + epoch_damage_loss

      # Record history
      self.train_losses['flood'].append(epoch_flood_loss)
      self.train_losses['damage'].append(epoch_damage_loss)
      self.train_losses['total'].append(epoch_total_loss)

      return epoch_flood_loss, epoch_damage_loss, epoch_total_loss

    def _simultaneous_disaster_step(self, flood_batch, damage_batch, step_idx):
        """Process both disaster tasks simultaneously on A100"""
        flood_imgs, flood_masks = flood_batch
        damage_imgs, damage_masks = damage_batch

        # Move to GPU with non-blocking transfer
        flood_imgs = flood_imgs.to(self.device, non_blocking=True)
        flood_masks = flood_masks.to(self.device, non_blocking=True)
        damage_imgs = damage_imgs.to(self.device, non_blocking=True)
        damage_masks = damage_masks.to(self.device, non_blocking=True)

        metrics = {}

        with torch.amp.autocast('cuda', enabled=True):
            # Process flood task
            flood_out_1, _ = self.model(flood_imgs)
            flood_loss, flood_components = self.flood_loss_fn(flood_out_1, flood_masks)

            # Process damage task
            _, damage_out_2 = self.model(damage_imgs)
            damage_loss, damage_components = self.damage_loss_fn(damage_out_2, damage_masks)

            # Combined loss with task weighting
            total_loss = (
                self.config.flood_task_weight * flood_loss +
                self.config.damage_task_weight * damage_loss
            ) / self.config.accumulation_steps

        # Backward pass
        self.scaler.scale(total_loss).backward()

        # Collect metrics
        metrics.update({
            'flood_loss': flood_loss.item(),
            'damage_loss': damage_loss.item(),
            'total_loss': total_loss.item() * self.config.accumulation_steps,
            'flood_dice': flood_components.get('dice', 0),
            'damage_dice': damage_components.get('dice', 0),
        })

        return metrics

    def _disaster_gradient_update(self):
        """Gradient update optimized for disaster mapping"""
        # Unscale gradients
        self.scaler.unscale_(self.optimizer)

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)

        # Step optimizer
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def validate(self, loader, task, num_classes):
        """Optimized validation with batch-level metric computation"""
        self.model.eval()

        total_loss = 0
        all_ious = []
        all_dice = []
        all_accuracies = []

        with torch.no_grad():
            for images, masks in tqdm(loader, desc=f'Validating {task}'):
                images = images.to(self.device, non_blocking=True)
                masks = masks.to(self.device, non_blocking=True)

                with torch.amp.autocast('cuda', enabled=True):
                    flood_out, damage_out = self.model(images)

                    if task == 'flood':
                        output = flood_out
                        loss, _ = self.flood_loss_fn(output, masks)
                    else:
                        output = damage_out
                        loss, _ = self.damage_loss_fn(output, masks)

            total_loss += loss.item() * images.size(0)  # Weight by batch size

            # Get predictions
            preds = torch.argmax(output, dim=1)

            # BATCH-LEVEL METRICS - Much faster!
            batch_ious, batch_dice = self.metric_calculator.calculate_metrics_batch(
            preds, masks, num_classes
            )

            #Pixel accuracy for batch
            batch_acc = (preds == masks).float().mean(dim=(1, 2))  # Per-sample accuracy

            # Store results (keep on GPU until end)
            all_ious.append(batch_ious)
            all_dice.append(batch_dice)
            all_accuracies.append(batch_acc)

        # Aggregate all metrics at once (more efficient)
        all_ious = torch.cat(all_ious, dim=0)
        all_dice = torch.cat(all_dice, dim=0)
        all_accuracies = torch.cat(all_accuracies, dim=0)

        # Move to CPU and compute statistics
        all_ious_cpu = all_ious.cpu().numpy()
        all_dice_cpu = all_dice.cpu().numpy()

          # Calculate averages
        avg_loss = total_loss / len(loader.dataset)
        class_ious = np.nanmean(all_ious_cpu, axis=0)
        class_dice = np.nanmean(all_dice_cpu, axis=0)
        mean_iou = np.nanmean(class_ious)
        mean_dice = np.nanmean(class_dice)
        mean_accuracy = all_accuracies.mean().item()

        # Print validation results
        print(f"\n{task.upper()} Validation Results:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Mean IoU: {mean_iou:.4f}")
        print(f"Mean Dice: {mean_dice:.4f}")
        print(f"Pixel Accuracy: {mean_accuracy:.4f}")

        # Per-class metrics
        for i in range(num_classes):
            class_name = self._get_class_name(task, i)
        print(f"  {class_name} - IoU: {class_ious[i]:.4f}, Dice: {class_dice[i]:.4f}")

        # Update validation metrics history
        if task == 'flood':
            self.val_metrics['flood_iou'].append(mean_iou)
        else:
            self.val_metrics['damage_iou'].append(mean_iou)

        return avg_loss, mean_iou, class_ious

    def _get_class_name(self, task, class_idx):
        """Get human-readable class names for disaster mapping"""
        if task == 'flood':
            return ['Non-Flooded', 'Flooded'][class_idx]
        else:
            return ['No Damage', 'Minor', 'Major', 'Destroyed'][class_idx]

    def test(self, loader, task, num_classes):
        """Test model - same as validate for disaster mapping"""
        return self.validate(loader, task, num_classes)

class CheckpointManager:
  "Save and Resume checkpoints"

  def __init__(self, model, config, checkpoint_dir):
    self.model = model
    self.config = config
    self.device = config.device
    self.checkpoint_dir = Path(checkpoint_dir)
    self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    self.train_losses = {'flood': [], 'damage': [], 'total': []}
    self.val_metrics = {'flood_iou': [], 'damage_iou': []}
    self.best_flood_iou = 0.0
    self.best_damage_iou = 0.0
    self.optimizer = None


  def save_checkpoint(
      self,model, optimizer, scaler, epoch,
      flood_metrics, damage_metrics, is_best=False,
      train_losses=None, val_metrics=None,
      best_flood_iou=None, best_damage_iou=None):

    if train_losses is not None:
      self.train_losses = train_losses
    if val_metrics is not None:
      self.val_metrics = val_metrics
    if best_flood_iou is not None:
      self.best_flood_iou = best_flood_iou
    if best_damage_iou is not None:
      self.best_damage_iou = best_damage_iou
    self.optimizer = optimizer  # Store for plotting

    checkpoint = {
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'scaler_state_dict': scaler.state_dict(),
      'train_losses': self.train_losses,
      'val_metrics': self.val_metrics,
      'flood_metrics': flood_metrics,
      'damage_metrics': damage_metrics,
      'best_flood_iou': self.best_flood_iou,
      'best_damage_iou': self.best_damage_iou,
      'config': self.config.__dict__
    }

    # Save regular checkpoint
    checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
    torch.save(checkpoint, checkpoint_path)
    print(f"Saved checkpoint: {checkpoint_path}")

    # Save best model
    if is_best:
      best_path = self.checkpoint_dir / f'best_model.pt'
      torch.save(checkpoint, best_path)
      print(f"Saved best model: {best_path}")

  def load_checkpoint(self, model, optimizer, scaler, checkpoint_name='best_model.pt'):
    """Load checkpoint and restore model, optimizer, and scaler state."""
    checkpoint_path = self.checkpoint_dir / checkpoint_name
    print(f"Loading checkpoint from {checkpoint_path}")

    # Ensure proper device type
    device = torch.device(self.device if isinstance(self.device, str) else self.device)
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Restore weights
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])

    # Restore training history
    self.train_losses = checkpoint.get('train_losses', {'flood': [], 'damage': [], 'total': []})
    self.val_metrics = checkpoint.get('val_metrics', {'flood_iou': [], 'damage_iou': []})
    self.best_flood_iou = checkpoint.get('best_flood_iou', 0.0)
    self.best_damage_iou = checkpoint.get('best_damage_iou', 0.0)

    epoch = checkpoint.get('epoch', 0)
    print(f"Resumed from epoch {epoch}")
    print(f"Best Flood IoU: {self.best_flood_iou:.4f}")
    print(f"Best Damage IoU: {self.best_damage_iou:.4f}")

    return epoch

  def _cleanup_old_checkpoints(self):
      """Keep only recent checkpoints"""
      checkpoints = sorted(self.checkpoint_dir.glob('checkpoint_epoch_*.pt'))
      for old_checkpoint in checkpoints[:-3]:
          old_checkpoint.unlink()

  def plot_training_history(self):
      """Generate training history plots"""
      fig, axes = plt.subplots(2, 2, figsize=(15, 12))

      # Loss plot
      ax = axes[0, 0]
      epochs = range(1, len(self.train_losses['total']) + 1)
      ax.plot(epochs, self.train_losses['flood'], 'b-', label='Flood Loss')
      ax.plot(epochs, self.train_losses['damage'], 'r-', label='Damage Loss')
      ax.plot(epochs, self.train_losses['total'], 'g--', label='Total Loss')
      ax.set_xlabel('Epoch')
      ax.set_ylabel('Loss')
      ax.set_title('Training Loss History')
      ax.legend()
      ax.grid(True)

      # IoU plot
      ax = axes[0, 1]
      if self.val_metrics['flood_iou']:
          epochs_val = range(1, len(self.val_metrics['flood_iou']) + 1)
          ax.plot(epochs_val, self.val_metrics['flood_iou'], 'b-o', label='Flood IoU')
          ax.plot(epochs_val, self.val_metrics['damage_iou'], 'r-o', label='Damage IoU')
          ax.set_xlabel('Epoch')
          ax.set_ylabel('IoU')
          ax.set_title('Validation IoU History')
          ax.legend()
          ax.grid(True)

      # Learning rate plot
      ax = axes[1, 0]
      lrs = [group['lr'] for group in self.optimizer.param_groups]
      ax.plot([lrs[0]] * len(epochs), 'b-', label='Backbone LR')
      ax.plot([lrs[-1]] * len(epochs), 'r-', label='Head LR')
      ax.set_xlabel('Epoch')
      ax.set_ylabel('Learning Rate')
      ax.set_title('Learning Rate Schedule')
      ax.legend()
      ax.set_yscale('log')
      ax.grid(True)

      # Summary text
      ax = axes[1, 1]
      ax.axis('off')

      # Fix the string formatting
      final_flood_iou = f"{self.val_metrics['flood_iou'][-1]:.4f}" if self.val_metrics['flood_iou'] else "N/A"
      final_damage_iou = f"{self.val_metrics['damage_iou'][-1]:.4f}" if self.val_metrics['damage_iou'] else "N/A"
      final_flood_loss = f"{self.train_losses['flood'][-1]:.4f}" if self.train_losses['flood'] else "N/A"
      final_damage_loss = f"{self.train_losses['damage'][-1]:.4f}" if self.train_losses['damage'] else "N/A"

      summary_text = f"""
      Training Summary:

      Total Epochs: {len(self.train_losses['total'])}
      Best Flood IoU: {self.best_flood_iou:.4f}
      Best Damage IoU: {self.best_damage_iou:.4f}

      Final Training Loss:
      - Flood: {final_flood_loss}
      - Damage: {final_damage_loss}

      Final Validation IoU:
      - Flood: {final_flood_iou}
      - Damage: {final_damage_iou}
      """
      ax.text(0.1, 0.5, summary_text, fontsize=12,
              verticalalignment='center', fontfamily='monospace')

      plt.tight_layout()

      # Save plot
      plot_path = os.path.join(self.config.results_dir, 'training_history.png')
      plt.savefig(plot_path, dpi=150, bbox_inches='tight')
      plt.close()

      print(f"Training history plot saved to {plot_path}")

## Main Training Loop

In [34]:
def make_safe_dataloader(
    dataset,
    config,
    batch_size,
    shuffle
):
  """Create Data Loader that handles 0 num_workers without crashing"""
  num_workers = int(config.num_workers)
  kwargs = dict(
      batch_size=batch_size,
      shuffle=shuffle,
      num_workers= num_workers,
      pin_memory=config.pin_memory
  )
  if num_workers > 0:
      kwargs.update({
      "prefetch_factor" : config.prefetch_factor,
      "persistent_workers" : config.persistent_workers
      })
  return DataLoader(dataset, **kwargs)

@torch.no_grad()
def validate_config(config):
    """Validate configuration before training"""
    print("\nValidating configuration...")

    # Check CUDA
    assert torch.cuda.is_available(), "CUDA not available - A100 required"

    # Check directories
    assert config.data_root.exists(), f"Data root not found: {config.data_root}"

    # Check dataset paths
    paths_to_check = [
        ("/content/data/FloodNet", "FloodNet"),
        ("/content/data/RescueNet", "RescueNet"),
        ("/content/data/SpaceNet", "SpaceNet8"),
    ]

    for path, name in paths_to_check:
        if not os.path.exists(path):
            print(f"⚠ Warning: {name} not found at {path}")

    # Check manifest files
    if not os.path.exists(sn8_train_manifest):
        print(f"⚠ Warning: SpaceNet8 train manifest not found at {sn8_train_manifest}")
    if not os.path.exists(sn8_val_manifest):
        print(f"⚠ Warning: SpaceNet8 val manifest not found at {sn8_val_manifest}")

    # Check GPU memory
    if torch.cuda.is_available():
        total_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        required_mem = 32  # Minimum GB for training
        assert total_mem >= required_mem, f"Need {required_mem}GB GPU, have {total_mem:.1f}GB"

    print("✓ Configuration validated")

def compute_or_load_class_weights(dataset, cache_dir, dataset_name):
    """Compute class weights or load from cache to save 2 hours"""
    cache_file = Path(cache_dir) / f"{dataset_name}_class_weights.pkl"

    # Try to load from cache
    if cache_file.exists():
        try:
            with open(cache_file, 'rb') as f:
                data = pickle.load(f)
            print(f"✅ Loaded cached {dataset_name} weights: {data['weights']}")
            return data['weights']
        except Exception as e:
            print(f"⚠️  Cache load failed ({e}), recomputing...")

    # Get num_classes from dataset
    # Try to infer from dataset attributes
    if hasattr(dataset, 'num_classes'):
        num_classes = dataset.num_classes
    elif hasattr(dataset, 'datasets'):  # For ConcatDataset
        num_classes = dataset.datasets[0].num_classes if hasattr(dataset.datasets[0], 'num_classes') else 2
    else:
        # Default guess based on name
        num_classes = 2 if 'flood' in dataset_name.lower() else 4

    # Compute fresh (takes ~1-2 hours)
    print(f"🔄 Computing {dataset_name} class weights (this will take a while)...")
    weights = compute_class_weights(dataset, num_classes)

    # Save to cache
    try:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_file, 'wb') as f:
            pickle.dump({
                'weights': weights,
                'dataset_name': dataset_name,
                'num_classes': num_classes,
                'num_samples': len(dataset),
                'computed_at': str(datetime.now())
            }, f)
        print(f"✅ Cached to {cache_file}")
    except Exception as e:
        print(f"⚠️  Failed to cache: {e}")

    return weights

def compute_or_load_class_weights(dataset, cache_dir, dataset_name):
    """Compute class weights or load from cache to save 2 hours"""
    cache_file = Path(cache_dir) / f"{dataset_name}_class_weights.pkl"

    # Try to load from cache
    if cache_file.exists():
        try:
            with open(cache_file, 'rb') as f:
                data = pickle.load(f)
            print(f"✅ Loaded cached {dataset_name} weights: {data['weights']}")
            return data['weights']
        except Exception as e:
            print(f"⚠️  Cache load failed ({e}), recomputing...")

    # Get num_classes from dataset
    # Try to infer from dataset attributes
    if hasattr(dataset, 'num_classes'):
        num_classes = dataset.num_classes
    elif hasattr(dataset, 'datasets'):  # For ConcatDataset
        num_classes = dataset.datasets[0].num_classes if hasattr(dataset.datasets[0], 'num_classes') else 2
    else:
        # Default guess based on name
        num_classes = 2 if 'flood' in dataset_name.lower() else 4

    # Compute fresh (takes ~1-2 hours)
    print(f"🔄 Computing {dataset_name} class weights (this will take a while)...")
    weights = compute_class_weights(dataset, num_classes)

    # Save to cache
    try:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_file, 'wb') as f:
            pickle.dump({
                'weights': weights,
                'dataset_name': dataset_name,
                'num_classes': num_classes,
                'num_samples': len(dataset),
                'computed_at': str(datetime.now())
            }, f)
        print(f"✅ Cached to {cache_file}")
    except Exception as e:
        print(f"⚠️  Failed to cache: {e}")

    return weights


def train_complete_model(config):
    """
    Complete training pipeline
    """
    validate_config(config)
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Create datasets
    print("\nCreating datasets...")


    # Create augmentation pipelines
    train_transform = get_training_augmentation(config)
    val_transform = get_validation_augmentation(config)

    # Create datasets
    flood_train_dataset = FloodNetDataset(
        FloodNet_train_img_dir, FloodNet_train_mask_dir, train_transform, binary_flood=True
    )

    flood_val_dataset = FloodNetDataset(
        FloodNet_val_img_dir, FloodNet_val_mask_dir, val_transform, binary_flood=True
    )

    flood_test_dataset = FloodNetDataset(
        FloodNet_test_img_dir, FloodNet_test_mask_dir, val_transform, binary_flood=True
    )

    damage_train_dataset = RescueNetDataset(
        RescueNet_train_img_dir, RescueNet_train_mask_dir, train_transform
    )

    damage_val_dataset = RescueNetDataset(
        RescueNet_val_img_dir, RescueNet_val_mask_dir, val_transform
    )

    damage_test_dataset = RescueNetDataset(
        RescueNet_test_img_dir, RescueNet_test_mask_dir, val_transform
    )
    # Compute class weights for balanced training
    print("\nComputing class weights...")
    flood_class_weights = compute_or_load_class_weights(flood_train_dataset, config.cache_dir, 'floodnet_train')
    damage_class_weights = compute_or_load_class_weights(damage_train_dataset, config.cache_dir, 'rescuenet_train')
    print(f"Flood class weights: {flood_class_weights}")
    print(f"Damage class weights: {damage_class_weights}")

    # Move weights to GPU
    if torch.cuda.is_available():
        flood_class_weights = flood_class_weights.cuda()
        damage_class_weights = damage_class_weights.cuda()

    space_train_dataset = SpaceNet8Dataset(
        sn8_train_manifest, augment=train_transform,
        cache_size=config.spacenet_cache_size, config=config
    )

    space_val_dataset = SpaceNet8Dataset(
        sn8_val_manifest, augment=val_transform,
        cache_size=config.spacenet_cache_size, config=config
    )

    space_train_dataset = SpaceNet8CompatWrapper(space_train_dataset)

    space_val_dataset = SpaceNet8CompatWrapper(space_val_dataset)


    # Concatenate FloodNet & SpaceNet 8  datasets
    flood_train_combined = ConcatDataset([flood_train_dataset, space_train_dataset])
    flood_val_combined = ConcatDataset([flood_val_dataset, space_val_dataset])


    # Create dataloaders
    print("Creating dataloaders...")

    flood_train_loader = make_safe_dataloader(
        flood_train_combined,
        config,
        config.batch_size,
        True
    )

    flood_val_loader = make_safe_dataloader(
        flood_val_combined,
        config,
        config.batch_size * 2,
        False
    )

    flood_test_loader = make_safe_dataloader(
        flood_test_dataset,
        config,
        config.batch_size * 2,
        False
    )

    damage_train_loader = make_safe_dataloader(
        damage_train_dataset,
        config,
        config.batch_size,
        True
    )

    damage_val_loader = make_safe_dataloader(
        damage_val_dataset,
        config,
        config.batch_size * 2,
        False
    )

    damage_test_loader = make_safe_dataloader(
        damage_test_dataset,
        config,
        config.batch_size * 2,
        False
    )

    val_loader = {
        "flood": flood_val_loader,
        "damage": damage_val_loader
    }

    test_loader = {
        "flood": flood_test_loader,
        "damage": damage_test_loader
    }

    # Initialize model
    print("\nInitializing model...")
    model = EnhancedDisasterModel(
        num_classes_flood=config.num_classes_flood,
        num_classes_damage=config.num_classes_damage,
        backbone=config.backbone,
        config=config
    )
    model = model.to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Initialize trainer
    trainer = OptimizedTrainer(model, config, device)

    ckpt_manager = CheckpointManager(model, config, config.checkpoint_dir)

    # Optional: Resume from checkpoint
    best_combined_score = 0.0
    ckpt_path = os.path.join(config.checkpoint_dir, 'best_model.pt')
    start_epoch = 0

    if os.path.exists(ckpt_path):
        if getattr(config, "auto_resume", False):
            # Auto-resume without prompting
            print(f"Found checkpoint at {ckpt_path}. Auto-resuming...")
            start_epoch = ckpt_manager.load_checkpoint(trainer.model, trainer.optimizer,  trainer.scaler, 'best_model.pt')
            if getattr(trainer, "best_flood_iou", None) is not None:
                best_combined_score = (trainer.best_flood_iou + trainer.best_damage_iou) / 2.0
        else:
        # Manual prompt
            response = input("Found existing checkpoint. Resume training? (y/n): ")
            if response.lower() == 'y':
                start_epoch = ckpt_manager.load_checkpoint(trainer.model, trainer.optimizer, trainer.scaler, 'best_model.pt')
                if getattr(trainer, "best_flood_iou", None) is not None:
                    best_combined_score = (trainer.best_flood_iou + trainer.best_damage_iou) / 2.0
                print(f"Resuming from epoch {start_epoch}")
            else:
                print("Starting from scratch.")

    # Training loop
    print(f"\nStarting training from epoch {start_epoch}...")
    print(f"Total epochs: {config.num_epochs}")
    print(f"Batch size: {config.batch_size} (effective {config.batch_size * max(1, getattr(config, 'accumulation_steps', 1))})")


    for epoch in range(start_epoch, config.num_epochs):
      # Train with bars (Trainer shows per-batch bar)
      flood_loss, damage_loss, total_loss = trainer.train_epoch(flood_train_loader, damage_train_loader, epoch)

      torch.cuda.empty_cache()

      # Validate (you already use tqdm inside validate)
      flood_val_loss, flood_iou, flood_class_ious = trainer.validate(flood_val_loader, 'flood',  config.num_classes_flood)
      damage_val_loss, damage_iou, damage_class_ious = trainer.validate(damage_val_loader, 'damage', config.num_classes_damage)

      torch.cuda.empty_cache()

      trainer.val_metrics['flood_iou'].append(flood_iou)
      trainer.val_metrics['damage_iou'].append(damage_iou)

      combined_score = (flood_iou + damage_iou) / 2.0
      is_best = combined_score > best_combined_score
      if is_best:
        best_combined_score = combined_score
        trainer.best_flood_iou  = flood_iou
        trainer.best_damage_iou = damage_iou

      flood_metrics  = {'iou': flood_iou,  'loss': flood_val_loss,  'class_ious': getattr(flood_class_ious, "tolist", lambda: flood_class_ious)()}
      damage_metrics = {'iou': damage_iou, 'loss': damage_val_loss, 'class_ious': getattr(damage_class_ious, "tolist", lambda: damage_class_ious)()}

      ckpt_manager.save_checkpoint(
          model=trainer.model,
          optimizer=trainer.optimizer,
          scaler=trainer.scaler,
          epoch=epoch,
          flood_metrics=flood_metrics,
          damage_metrics=damage_metrics,
          is_best=is_best,
          train_losses=trainer.train_losses,
          val_metrics=trainer.val_metrics,
          best_flood_iou=trainer.best_flood_iou,
          best_damage_iou=trainer.best_damage_iou
      )

      ckpt_manager._cleanup_old_checkpoints()

      if (epoch + 1) % 5 == 0:
        ckpt_manager.plot_training_history()

    # Final evaluation on test set
    print("\n" + "="*60)
    print("FINAL EVALUATION ON TEST SET")
    print("="*60)

    # Test flood detection
    print("\nFlood Detection Test Results:")
    flood_test_loss, flood_test_iou, flood_test_class_ious = trainer.test(
        flood_test_loader, 'flood', config.num_classes_flood
    )

    # Test damage assessment
    print("\nDamage Assessment Test Results:")
    damage_test_loss, damage_test_iou, damage_test_class_ious = trainer.test(
        damage_test_loader, 'damage', config.num_classes_damage
    )

    # Save final results
    final_results = {
        'flood_test': {
            'mean_iou': flood_test_iou,
            'class_ious': flood_test_class_ious.tolist(),
            'loss': flood_test_loss
        },
        'damage_test': {
            'mean_iou': damage_test_iou,
            'class_ious': damage_test_class_ious.tolist(),
            'loss': damage_test_loss
        },
        'training_config': config.__dict__
    }

    import json
    with open(os.path.join(config.results_dir, 'final_results.json'), 'w') as f:
        json.dump(final_results, f, indent=2)

    print("\nTraining complete!")
    print(f"Best Flood IoU: {trainer.best_flood_iou:.4f}")
    print(f"Best Damage IoU: {trainer.best_damage_iou:.4f}")
    print(f"Results saved to {config.results_dir}")

    return model, trainer


Usage Monitor

In [35]:
class GPUMonitor:
  """Monitor GPU memory usage and prevent OOM"""

  @staticmethod
  def get_memory_info():
    """Get current GPU memory usage"""
    if torch.cuda.is_available():
      allocated = torch.cuda.memory_allocated()/ 1e9
      cached = torch.cuda.memory_reserved() / 1e9
      total = torch.cuda.get_device_properties(0).total_memory
      return {
          'allocated' : allocated,
          "cached" : cached,
          "total" : total,
          "free" : total - allocated,
          "usage_percent" : (allocated / total ) * 100
      }
    return None

  @staticmethod
  def print_memory_stats():
    """Print current GPU memory usage"""
    info = GPUMonitor.get_memory_info()
    if info:
      print(f"GPU Memory Usage: {info['allocated']:.1f}/{info['total']:.1f} GB"
      f"({info['usage_percent']:.1f}%) - Free: {info['free']:.1f} GB")

  @staticmethod
  def clear_cache_if_needed(threshold=.9):
    """Clear cache if memory usage exceeds threshold"""
    info = GPUMonitor.get_memory_info()
    if info and info ['usage_percent'] > threshold * 100:
      torch.cuda.empty_cache()
      gc.collect()
      print(f"Cleared GPU cache at {info['usage_percent']:.1f}% usage")


## Diagnostic tools

In [36]:
"""
import os
import subprocess

print("="*70)
print("DATA LOCATION CHECK")
print("="*70)

# Check if data directories exist
data_paths = [
    "/content/data/FloodNet",
    "/content/data/RescueNet",
    "/content/data/SpaceNet"
]

for path in data_paths:
    if os.path.exists(path):
        # Check if it's a symlink (bad - points to Drive)
        if os.path.islink(path):
            print(f"⚠️  {path}")
            print(f"    -> SYMLINK to {os.readlink(path)} (SLOW!)")
        else:
            # Check size and file count
            result = subprocess.run(['du', '-sh', path], capture_output=True, text=True)
            size = result.stdout.split()[0] if result.returncode == 0 else "unknown"

            file_count = subprocess.run(['find', path, '-type', 'f', '|', 'wc', '-l'],
                                       shell=True, capture_output=True, text=True)
            files = file_count.stdout.strip() if file_count.returncode == 0 else "unknown"

            print(f"✅ {path}")
            print(f"    Size: {size}, Files: {files}")
            print(f"    -> LOCAL (FAST)")
    else:
        print(f"❌ {path} - NOT FOUND!")

# Check where /content/data is mounted
print("\n" + "="*70)
print("FILESYSTEM CHECK")
print("="*70)
result = subprocess.run(['df', '-h', '/content'], capture_output=True, text=True)
print(result.stdout)

# If you see 'drive' or 'fuse' in filesystem, data is on Google Drive (slow!)
# If you see 'overlay' or '/dev/sda', data is on local disk (fast!)
"""

'\nimport os\nimport subprocess\n\nprint("="*70)\nprint("DATA LOCATION CHECK")\nprint("="*70)\n\n# Check if data directories exist\ndata_paths = [\n    "/content/data/FloodNet",\n    "/content/data/RescueNet",\n    "/content/data/SpaceNet"\n]\n\nfor path in data_paths:\n    if os.path.exists(path):\n        # Check if it\'s a symlink (bad - points to Drive)\n        if os.path.islink(path):\n            print(f"⚠️  {path}")\n            print(f"    -> SYMLINK to {os.readlink(path)} (SLOW!)")\n        else:\n            # Check size and file count\n            result = subprocess.run([\'du\', \'-sh\', path], capture_output=True, text=True)\n            size = result.stdout.split()[0] if result.returncode == 0 else "unknown"\n\n            file_count = subprocess.run([\'find\', path, \'-type\', \'f\', \'|\', \'wc\', \'-l\'],\n                                       shell=True, capture_output=True, text=True)\n            files = file_count.stdout.strip() if file_count.returncode == 0 else

In [37]:
#"""
#import os
#import subprocess

# Quick check
#print("Does data exist?")
#print(f"FloodNet: {os.path.exists('/content/data/FloodNet')}")
#print(f"RescueNet: {os.path.exists('/content/data/RescueNet')}")
#print(f"SpaceNet: {os.path.exists('/content/data/SpaceNet')}")

# Check if symlinks
#print("\nAre they symlinks?")
#print(f"FloodNet: {os.path.islink('/content/data/FloodNet')}")

# Check filesystem
#print("\nFilesystem:")
#result = subprocess.run(['df', '-h', '/content/data'], capture_output=True, text=True)
#print(result.stdout)
#"""

## Initiate Model

In [38]:
if __name__ == '__main__':
  config = Config()
  model, trainer = train_complete_model(config)

✓ A100 detected: NVIDIA A100-SXM4-80GB
  Memory: 85.2 GB
Directories initialized in /content
Device: cuda

Validating configuration...
✓ Configuration validated
Using device: cuda
GPU: NVIDIA A100-SXM4-80GB
Memory: 85.2 GB

Creating datasets...


  original_init(self, **validated_kwargs)


FloodNet dataset: 1445 images, 1445 with masks
FloodNet dataset: 450 images, 450 with masks
FloodNet dataset: 448 images, 448 with masks

Computing class weights...
🔄 Computing floodnet_train class weights (this will take a while)...
Fast class weight computation (bypassing augmentation)...


Analyzing masks:   0%|          | 0/1445 [00:00<?, ?it/s]

✅ Cached to /content/cache/floodnet_train_class_weights.pkl
🔄 Computing rescuenet_train class weights (this will take a while)...
Fast class weight computation (bypassing augmentation)...


Analyzing masks:   0%|          | 0/3595 [00:00<?, ?it/s]

✅ Cached to /content/cache/rescuenet_train_class_weights.pkl
Flood class weights: tensor([nan, nan])
Damage class weights: tensor([nan, nan, nan, nan])
Prioritized 158 samples by disaster content likelihood
Preloading top 10 disaster samples to A100 GPU memory...
  Adjusted preload count to 10 based on available memory
Successfully preloaded 10 disaster-rich samples to GPU
Disaster-focused SpaceNet8: 158 samples
Cache size: 50, GPU preloaded: 10
Prioritized 44 samples by disaster content likelihood
Preloading top 10 disaster samples to A100 GPU memory...
  Adjusted preload count to 10 based on available memory
Successfully preloaded 10 disaster-rich samples to GPU
Disaster-focused SpaceNet8: 44 samples
Cache size: 50, GPU preloaded: 10
Creating dataloaders...

Initializing model...




Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth


100%|██████████| 233M/233M [00:01<00:00, 198MB/s]


Using ResNet101 backbone
Total parameters: 63,422,632
Trainable parameters: 63,422,632
A100 optimizations enabled
Disaster mapping trainer initialized for A100

Starting training from epoch 0...
Total epochs: 50
Batch size: 32 (effective 128)
Disaster mapping training epoch 0: 113 iterations


Epoch 1:   0%|          | 0/113 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
!nvidia-smi --query-gpu=utilization.gpu,utilization.memory,memory.used,memory.total --format=csv

##Model Demonstration

In [None]:
# def test_on_new_image_with_tiff(model_path, image_path):
#     import torch
#     import cv2
#     import numpy as np
#     import matplotlib.pyplot as plt
#     from PIL import Image
#     import os

#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#     try:
#         # Load model
#         checkpoint = torch.load(model_path, map_location=device, weights_only=False)

#         model = EnhancedDisasterModel(
#             num_classes_flood=2,
#             num_classes_damage=4,
#             backbone='resnet101'
#         )
#         model.load_state_dict(checkpoint['model_state_dict'])
#         model.to(device)
#         model.eval()

#         # Check file extension and load accordingly
#         file_ext = os.path.splitext(image_path)[1].lower()
#         print(f"  → Processing {file_ext} file: {os.path.basename(image_path)}")

#         if file_ext in ['.tif', '.tiff']:
#             # Use PIL for TIFF files (better TIFF support)
#             print("  → Using PIL for TIFF loading...")
#             try:
#                 image_pil = Image.open(image_path)

#                 # Handle different TIFF modes
#                 if image_pil.mode == 'RGB':
#                     image = np.array(image_pil)
#                 elif image_pil.mode == 'RGBA':
#                     # Convert RGBA to RGB
#                     image = np.array(image_pil.convert('RGB'))
#                 elif image_pil.mode in ['L', 'P']:
#                     # Convert grayscale or palette to RGB
#                     image = np.array(image_pil.convert('RGB'))
#                 elif image_pil.mode == 'I' or image_pil.mode == 'F':
#                     # Handle 32-bit integer or float images
#                     arr = np.array(image_pil)
#                     # Normalize to 0-255 range
#                     arr = ((arr - arr.min()) / (arr.max() - arr.min()) * 255).astype(np.uint8)
#                     image = np.stack([arr, arr, arr], axis=-1)  # Convert to RGB
#                 else:
#                     print(f"  → Converting from mode {image_pil.mode} to RGB")
#                     image = np.array(image_pil.convert('RGB'))

#                 print(f"  → TIFF image loaded: {image.shape}, dtype: {image.dtype}")

#             except Exception as e:
#                 print(f"  ✗ PIL failed, trying OpenCV for TIFF: {e}")
#                 # Fallback to OpenCV
#                 image = cv2.imread(image_path, cv2.IMREAD_COLOR)
#                 if image is None:
#                     raise ValueError(f"Could not load TIFF image from {image_path}")
#                 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

#         else:
#             # Use OpenCV for standard formats (PNG, JPG, JPEG)
#             print("  → Using OpenCV for standard image loading...")
#             image = cv2.imread(image_path)
#             if image is None:
#                 raise ValueError(f"Could not load image from {image_path}")
#             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

#         print(f"  → Final image shape: {image.shape}, dtype: {image.dtype}")

#         # Ensure image is in correct format (0-255, uint8)
#         if image.dtype != np.uint8:
#             if image.max() <= 1.0:
#                 # Image is in 0-1 range, convert to 0-255
#                 image = (image * 255).astype(np.uint8)
#             else:
#                 # Image might be in different range, normalize
#                 image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)

#         # Apply transforms
#         transform = get_validation_augmentation()
#         input_tensor = transform(image=image)['image']
#         input_tensor = input_tensor.unsqueeze(0).to(device)

#         # Make predictions
#         with torch.no_grad():
#             flood_out, damage_out = model(input_tensor)

#         flood_pred = torch.argmax(flood_out, dim=1)[0].cpu().numpy()
#         damage_pred = torch.argmax(damage_out, dim=1)[0].cpu().numpy()

#         # Visualize results
#         fig, axes = plt.subplots(1, 3, figsize=(15, 5))

#         axes[0].imshow(image)
#         axes[0].set_title(f'Original Image ({file_ext})')
#         axes[0].axis('off')

#         axes[1].imshow(flood_pred, cmap='Blues')
#         axes[1].set_title('Flood Prediction')
#         axes[1].axis('off')

#         axes[2].imshow(damage_pred, cmap='Reds')
#         axes[2].set_title('Damage Prediction')
#         axes[2].axis('off')

#         plt.tight_layout()
#         plt.show()

#         return flood_pred, damage_pred

#     except Exception as e:
#         print(f"  ✗ ERROR processing {image_path}: {str(e)}")
#         return None, None


# import os
# import random

# def test_on_images_with_tiff(model_path, folder_path, num_img=10):
#     """
#     Test the disaster model on multiple images including TIFF files.

#     Supported formats: PNG, JPG, JPEG, TIF, TIFF
#     """
#     try:
#         # Updated file extension list to include TIFF
#         image_extensions = (".png", ".jpg", ".jpeg", ".tif", ".tiff")
#         image_files = [f for f in os.listdir(folder_path)
#                       if f.lower().endswith(image_extensions)]

#         if not image_files:
#             print(f"No supported image files found in {folder_path}")
#             print(f"Looking for: {image_extensions}")
#             all_files = os.listdir(folder_path)[:10]  # Show first 10 files
#             print(f"Files in folder: {all_files}")
#             return []

#         print(f"Found {len(image_files)} supported images")

#         # Group by file type for info
#         file_types = {}
#         for f in image_files:
#             ext = os.path.splitext(f)[1].lower()
#             file_types[ext] = file_types.get(ext, 0) + 1

#         print(f"File type breakdown: {file_types}")

#         # Sample random images
#         sample_files = random.sample(image_files, min(num_img, len(image_files)))
#         print(f"Processing {len(sample_files)} images...")

#         results = []
#         successful_predictions = 0

#         for i, file_name in enumerate(sample_files, 1):
#             file_path = os.path.join(folder_path, file_name)
#             print(f"\n[{i}/{len(sample_files)}] Processing: {file_name}")

#             flood_pred, damage_pred = test_on_new_image_with_tiff(model_path, file_path)

#             if flood_pred is not None and damage_pred is not None:
#                 results.append({
#                     "file": file_name,
#                     "flood_mask": flood_pred,
#                     "damage_mask": damage_pred
#                 })
#                 successful_predictions += 1
#                 print(f"  ✓ Success!")
#             else:
#                 print(f"  ✗ Failed")

#         print(f"\nCompleted: {successful_predictions}/{len(sample_files)} images processed successfully")
#         return results

#     except Exception as e:
#         print(f"Error in test_on_images_with_tiff: {str(e)}")
#         return []


# # Quick check function to see what file types are in your folder
# def check_file_types(folder_path):
#     """Check what file types are present in the folder"""
#     print(f"=== File Analysis for: {folder_path} ===")

#     if not os.path.exists(folder_path):
#         print(f"Folder does not exist!")
#         return

#     all_files = os.listdir(folder_path)
#     print(f"Total files: {len(all_files)}")

#     # Group by extension
#     extensions = {}
#     for f in all_files:
#         ext = os.path.splitext(f)[1].lower()
#         if not ext:
#             ext = "(no extension)"
#         extensions[ext] = extensions.get(ext, 0) + 1

#     print(f"\nFile extensions found:")
#     for ext, count in sorted(extensions.items()):
#         print(f"  {ext}: {count} files")

#     # Check for TIFF specifically
#     tiff_files = [f for f in all_files if f.lower().endswith(('.tif', '.tiff'))]
#     if tiff_files:
#         print(f"\nTIFF files found: {len(tiff_files)}")
#         print(f"First few TIFF files: {tiff_files[:5]}")
#     else:
#         print(f"\nNo TIFF files found.")


# # Usage example
# if __name__ == "__main__":
#     model_path = "/content/checkpoints/best_model.pth"
#     folder_path = "/content/drive/MyDrive/G.E.M.S./FloodNet/FloodNet-Supervised_v1.0/test/test-label-img/"

#     # First check what file types you have
#     check_file_types(folder_path)

#     # Then run with TIFF support
#     if os.path.exists(model_path) and os.path.exists(folder_path):
#         results = test_on_images_with_tiff(model_path, folder_path, num_img=10)
#         print(f"Processing complete. Got results for {len(results)} images.")
#     else:
#         print("Missing model or folder path")

In [None]:
#!pip install nbconvert
#!jupyter nbconvert --to html Hermes.0.3.6.ipynb

In [None]:
# print(f"Flood pixels: {(flood_mask == 1).sum()}")
# print(f"Damage distribution: {np.bincount(damage_mask.flatten())}")