# HERMES 0.3

In [None]:
#import signal
#os.kill(os.getpid(), signal.SIGKILL)
#!rm -rf /content/*

In [None]:
!pip install albumentations tqdm segmentation-models-pytorch torchvision
!pip install --upgrade "docutils>=0.20,<0.22"
!pip install awscli -q
!pip install rasterio

## IMPORTS AND DEPENDENCIES

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.amp import GradScaler, autocast
from torch.nn.parallel import DataParallel

#Computer Vision and image processing
import cv2
import numpy as np
from PIL import Image

#Torchvision for pretrained models and transforms
import torchvision
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_resnet101

#Data augmentation library
import albumentations as A
from albumentations.pytorch import ToTensorV2

#Utilities
import os
import glob
import time
import zipfile
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import json
import pickle
import pandas as pd
import geopandas as gpd
import rasterio as rio
from tqdm import tqdm
from itertools import cycle
from functools import lru_cache
from collections import defaultdict
from datetime import timedelta
import psutil

#Optional: Weights & Biases for experiment tracking
#try:
  #import wandb
  #WANDB_AVAILABLE = True
#except ImportError:
  #WANDB_AVAILABLE = False
  #print("Warning: Weights & Biases not installed. Please install with `pip install wandb`")

## Config


In [None]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU and CUDA generators, and switches cuDNN to deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: the notebook's import cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU (safe to call when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

#Configuration and Parameters

class Config:
    """Central hyper-parameters and filesystem paths for the notebook.

    Every setting is a plain class attribute read as ``Config.<name>``.
    Executing the class body also creates the checkpoint and results folders.
    """

    # ---- Model architecture ----
    backbone = 'resnet101'
    num_classes_flood = 2
    num_classes_damage = 4

    # ---- Optimisation ----
    batch_size = 32
    accumulation_steps = 6
    num_epochs = 20
    learning_rate = 8e-5
    weight_decay = 1e-4

    # ---- Data loading ----
    num_workers = 12
    pin_memory = True
    prefetch_factor = 6
    persistent_workers = True
    spacenet_cache_size = 500

    # ---- Input resolution (pixels) ----
    img_height = 512
    img_width = 512

    # ---- Weights of the individual loss terms ----
    ce_weight = 0.3
    dice_weight = 0.4
    focal_weight = 0.3

    # ---- Per-task weights for multi-task learning ----
    flood_task_weight = 0.6
    damage_task_weight = 0.4

    # ---- Filesystem layout (edit to match your environment) ----
    Data_Root = '/content'
    Checkpoint_Dir = '/content/checkpoints'
    Results_Dir = '/content/results'

    # Runs once, when the class body executes: make sure output folders exist.
    os.makedirs(Checkpoint_Dir, exist_ok=True)
    os.makedirs(Results_Dir, exist_ok=True)

In [None]:
# Pull the FloodNet dataset (~12 GB) from Dropbox (linked from GitHub) because of its size

In [None]:
#!wget "https://www.dropbox.com/scl/fo/k33qdif15ns2qv2jdxvhx/ANGaa8iPRhvlrvcKXjnmNRc?rlkey=ao2493wzl1cltonowjdbrnp7f&e=3&st=6lg4ncwc&dl=1"

In [None]:
# import zipfile

# zip_path = '/content/ANGaa8iPRhvlrvcKXjnmNRc?rlkey=ao2493wzl1cltonowjdbrnp7f&e=3&st=6lg4ncwc&dl=1'
# extract_path = '/content/FloodNet'

# with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#   zip_ref.extractall(extract_path)

#   print("Extraction complete.")
#   print(os.listdir('/content/FloodNet'))

In [None]:
#Defining dataset pathways
# FloodNet-Supervised v1.0 split directories on Google Drive (presumably
# mounted at /content/drive). Each split has an original-image folder
# (*-org-img) and a label folder (*-label-img).
_FLOODNET_ROOT = '/content/drive/MyDrive/G.E.M.S./FloodNet/FloodNet-Supervised_v1.0'

FloodNet_train_img_dir = f'{_FLOODNET_ROOT}/train/train-org-img'
FloodNet_train_mask_dir = f'{_FLOODNET_ROOT}/train/train-label-img'

FloodNet_val_img_dir = f'{_FLOODNET_ROOT}/val/val-org-img'
FloodNet_val_mask_dir = f'{_FLOODNET_ROOT}/val/val-label-img'

FloodNet_test_img_dir = f'{_FLOODNET_ROOT}/test/test-org-img'
FloodNet_test_mask_dir = f'{_FLOODNET_ROOT}/test/test-label-img'

In [None]:
# os.remove("/content/ANGaa8iPRhvlrvcKXjnmNRc?rlkey=ao2493wzl1cltonowjdbrnp7f&e=3&st=6lg4ncwc&dl=1")
# print("Zip File Deleted")

In [None]:
#with zipfile.ZipFile("/content/drive/MyDrive/G.E.M.S./.zip", "r") as zip_ref:
  #zip_ref.extractall("/content/SN8")

#print("Extraction complete.")

In [None]:
# SpaceNet8 image/mask directories. They live under the same
# '/content/drive/MyDrive/G.E.M.S.' root as every other dataset path in this
# notebook (including `sn8_path`); a '/drive/...' prefix without '/content'
# would not match that root.
SpaceNet8_train_img_dir = '/content/drive/MyDrive/G.E.M.S./sn8/images/train'
SpaceNet8_train_mask_dir = '/content/drive/MyDrive/G.E.M.S./sn8/masks/train'

SpaceNet8_val_img_dir = '/content/drive/MyDrive/G.E.M.S./sn8/images/val'
SpaceNet8_val_mask_dir = '/content/drive/MyDrive/G.E.M.S./sn8/masks/val'

In [None]:
#!pip install gdown
# import gdown

In [None]:
#!gdown --fuzzy "https://drive.google.com/file/d/1iRkEX9LQ8Hi-38QMyaReFJ8wXDDyYJAg/view?usp=sharing" -O RescueNet.zip

In [None]:
# import zipfile

# with zipfile.ZipFile("/content/drive/MyDrive/G.E.M.S./RescueNet.zip", "r") as zip_ref:
#      zip_ref.extractall("/content/RescueNet")
# \
# print("Extraction complete.")

In [None]:
# Sanity check: print the contents of the shared project folder on Drive.
# os.listdir raises FileNotFoundError if the folder is not reachable
# (e.g. Drive not mounted at /content/drive).
print(os.listdir('/content/drive/MyDrive/G.E.M.S.'))
#print(os.listdir('/content/RescueNet'))

In [None]:
# RescueNet split directories on Google Drive: each split has an
# original-image folder (*-org-img) and a label folder (*-label-img).
_RESCUENET_ROOT = '/content/drive/MyDrive/G.E.M.S./RescueNet'

RescueNet_train_img_dir = f'{_RESCUENET_ROOT}/train/train-org-img'
RescueNet_train_mask_dir = f'{_RESCUENET_ROOT}/train/train-label-img'

RescueNet_val_img_dir = f'{_RESCUENET_ROOT}/val/val-org-img'
RescueNet_val_mask_dir = f'{_RESCUENET_ROOT}/val/val-label-img'

RescueNet_test_img_dir = f'{_RESCUENET_ROOT}/test/test-org-img'
RescueNet_test_mask_dir = f'{_RESCUENET_ROOT}/test/test-label-img'

In [None]:
#os.remove("/content/RescueNet.zip")
#print("Files Deleted")

In [None]:
# Root of the SpaceNet8 data on Drive, plus the CSV manifests
# (read by SpaceNet8Dataset) for the train and val splits.
sn8_path = "/content/drive/MyDrive/G.E.M.S./sn8"
sn8_train_manifest = sn8_path + "/manifests/train_manifest.csv"
sn8_val_manifest = sn8_path + "/manifests/val_manifest.csv"

## Loading Datasets

In [None]:
class FloodNetDataset(Dataset):
    """
    Dataset class for FloodNet flood detection data.

    Masks are colour-coded label images. Each pixel's RGB value is mapped to a
    FloodNet class index through ``FLOODNET_COLORS`` (unlisted colours fall
    back to class 0). With ``binary_flood=True`` (default) the class map is
    reduced to a binary mask (0: non-flooded, 1: flooded, see
    ``FLOOD_CLASSES``); otherwise the 10-class index map is returned.

    Expected file structure:
    - Images: original disaster images in ``img_dir``
    - Masks: ``<image stem>_*.png`` files in ``mask_dir``

    Images without a matching mask are excluded (with a warning) at
    construction time.
    """
    FLOODNET_COLORS = {
        (0,0,0):0, #Unlabeled
        (128,0,0):1, #Building-flooded
        (0,128,0):2, # Buildings-non-flooded
        (128,128,0):3, #Road-flooded
        (0,0,128):4, #Road-non-flooded
        (128,0,128):5, #Water
        (0,128,128):6, #Tree
        (128,128,128):7,#Vehicle
        (64,0,0):8, #Pool
        (192,0,0):9, #Grass

    }

    # Classes counted as "flooded" by create_flood_mask:
    # Building-flooded (1), Road-flooded (3), Water (5).
    FLOOD_CLASSES = (1, 3, 5)

    def __init__(self, img_dir, mask_dir, transform=None, binary_flood=True):
        """
        Args:
            img_dir: Directory containing input images
            mask_dir: Directory containing colour-coded segmentation masks
            transform: Optional Albumentations pipeline applied to (image, mask)
            binary_flood: Return a binary flood mask instead of class indices
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.binary_flood = binary_flood

        # Keep only images with a matching mask, and remember that mask so
        # __getitem__ does not have to glob the directory for every sample.
        all_imgs = sorted(os.listdir(img_dir))
        self.img_names = []
        self._mask_paths = {}
        for img_name in all_imgs:
            mask_path = self._find_mask(img_name)
            if mask_path is None:
                print(f"Warning: No mask found for {img_name}, excluding from dataset")
                continue
            self.img_names.append(img_name)
            self._mask_paths[img_name] = mask_path

        print(f"FloodNet dataset: {len(all_imgs)} images, {len(self.img_names)} with masks")

    def _find_mask(self, img_name):
        """Return the first (sorted) mask matching ``<stem>_*.png``, or None."""
        base_name = img_name.rsplit(".", 1)[0]
        # Escape so characters like '[' in file names are matched literally.
        pattern = os.path.join(glob.escape(self.mask_dir), glob.escape(base_name) + "_*.png")
        matches = sorted(glob.glob(pattern))  # sorted -> deterministic choice
        return matches[0] if matches else None

    def rgb_to_class(self, mask_rgb):
        """Convert an (H, W, 3) RGB mask to an (H, W) uint8 class-index map."""
        h, w = mask_rgb.shape[:2]
        mask_class = np.zeros((h, w), dtype=np.uint8)

        for rgb, class_id in self.FLOODNET_COLORS.items():
            # Find pixels matching this RGB value
            matches = np.all(mask_rgb == rgb, axis=-1)
            mask_class[matches] = class_id

        return mask_class

    def create_flood_mask(self, class_mask):
        """Return a uint8 mask: 1 where the class is in FLOOD_CLASSES, else 0."""
        return np.isin(class_mask, self.FLOOD_CLASSES).astype(np.uint8)

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load an (image, mask) pair; the mask is always returned as a long tensor.

        Raises:
            FileNotFoundError: if the mask is missing or a file cannot be read.
        """
        img_name = self.img_names[idx]

        # Fall back to a fresh lookup for names added to img_names externally.
        mask_path = self._mask_paths.get(img_name) or self._find_mask(img_name)
        if mask_path is None:
            raise FileNotFoundError(f"No matching mask for: {img_name}")

        img_path = os.path.join(self.img_dir, img_name)

        # Load image (OpenCV loads as BGR, convert to RGB)
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Failed to load image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load the colour-coded mask (3-channel BGR) and convert to RGB so it
        # can be looked up in FLOODNET_COLORS.
        mask_rgb = cv2.imread(mask_path)
        if mask_rgb is None:
            raise FileNotFoundError(f"Failed to load mask: {mask_path}")
        mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_BGR2RGB)

        class_mask = self.rgb_to_class(mask_rgb)
        mask = self.create_flood_mask(class_mask) if self.binary_flood else class_mask

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # Without a transform the mask is still a NumPy array (which has no
        # .long()); as_tensor handles both arrays and tensors.
        return image, torch.as_tensor(mask).long()

In [None]:
class RescueNetDataset(Dataset):
    """
    Dataset class for RescueNet damage assessment data.

    Expected file structure:
    - Images: original disaster images in ``img_dir``
    - Masks: single-channel label images ``<image stem>_lab*.png`` in
      ``mask_dir``; pixel values are clipped into 0-3 before use.

    Images without a matching mask are excluded (with a warning) at
    construction time rather than failing in the middle of an epoch.

    NOTE(review): the 0-3 clip assumes masks hold only four damage levels.
    If the label images use more class ids, every id >= 3 collapses into 3.
    Confirm against the RescueNet label definition.
    """
    def __init__(self, img_dir, mask_dir, transform=None):
        """
        Args:
            img_dir: Directory containing input images
            mask_dir: Directory containing segmentation masks
            transform: Optional Albumentations pipeline applied to (image, mask)
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform

        all_imgs = sorted(os.listdir(img_dir))
        self.img_names = []
        self._mask_paths = {}
        for img_name in all_imgs:
            mask_path = self._find_mask(img_name)
            if mask_path is None:
                print(f"Warning: No mask found for {img_name}, excluding from dataset")
                continue
            self.img_names.append(img_name)
            self._mask_paths[img_name] = mask_path

        print(f"RescueNet dataset: {len(all_imgs)} images, {len(self.img_names)} with masks")

    def _find_mask(self, img_name):
        """Return the first (sorted) mask matching ``<stem>_lab*.png``, or None."""
        base_name = img_name.rsplit(".", 1)[0]
        # Escape so characters like '[' in file names are matched literally.
        pattern = os.path.join(glob.escape(self.mask_dir), glob.escape(base_name) + "_lab*.png")
        matches = sorted(glob.glob(pattern))  # sorted -> deterministic choice
        return matches[0] if matches else None

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load an (image, mask) pair; the mask is always returned as a long tensor.

        Raises:
            FileNotFoundError: if the mask is missing or a file cannot be read.
        """
        img_name = self.img_names[idx]

        # Fall back to a fresh lookup for names added to img_names externally.
        mask_path = self._mask_paths.get(img_name) or self._find_mask(img_name)
        if mask_path is None:
            raise FileNotFoundError(f"No matching mask for: {img_name}")

        img_path = os.path.join(self.img_dir, img_name)

        # Load image (OpenCV loads as BGR, convert to RGB)
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Failed to load image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Load mask as a single-channel label image
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Failed to load mask: {mask_path}")

        # Ensure mask values are in valid range (0-3)
        mask = np.clip(mask, 0, 3).astype(np.uint8)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # Without a transform the mask is still a NumPy array (which has no
        # .long()); as_tensor handles both arrays and tensors.
        return image, torch.as_tensor(mask).long()

In [None]:
class SpaceNet8Dataset(Dataset):
    """
    SpaceNet8 flood-segmentation dataset driven by a CSV manifest.

    The manifest must provide ``post_path`` (post-event raster; bands 1-3 are
    read as the image) and ``mask_path`` (single-band raster). Masks are
    binarised (``mask > 0`` -> 1) for flood detection.

    Every sample goes through the same conversion, whether it comes from the
    preload cache or is read from disk:
      * with ``augment``: whatever the Albumentations pipeline returns for an
        HWC uint8 image and the binary int64 mask;
      * without ``augment``: a CHW float tensor in [0, 1] and a long mask tensor.

    Args:
        manifest_csv: Path to the manifest CSV.
        augment: Optional Albumentations pipeline.
        cache_size: ``maxsize`` of the per-instance LRU cache of decoded files.
        disaster_focus: Sort rows by mask file size, largest first.
        preload_critical_samples: When CUDA is available, decode the first 100
            rows up front and keep them in memory.
    """
    def __init__(self, manifest_csv, augment=None,
                 cache_size=500, disaster_focus=True, preload_critical_samples=True):
        self.df = pd.read_csv(manifest_csv)
        self.augment = augment
        self.disaster_focus = disaster_focus
        self.preload_critical_samples = preload_critical_samples

        # Per-instance LRU cache around the raw raster reader.
        self._load_file = lru_cache(maxsize=cache_size)(self._load_file_uncached)

        if disaster_focus:
            self._prioritize_disaster_samples()

        # idx -> (image CHW uint8, raw mask) held in host memory. The attribute
        # name is kept for compatibility. Samples must not be CUDA tensors:
        # a Dataset returning CUDA tensors breaks pin_memory, fails in forked
        # DataLoader workers and cannot be collated with CPU samples.
        self.gpu_preloaded = {}
        if preload_critical_samples and torch.cuda.is_available():
            self._preload_disaster_samples()

        print(f"Disaster-focused SpaceNet8: {len(self.df)} samples")
        print(f"Cache size: {cache_size}, preloaded: {len(self.gpu_preloaded)}")

    def _prioritize_disaster_samples(self):
        """Sort rows by mask file size, largest first (unreadable -> size 0).

        Mask file size is a cheap heuristic for how much labelled content a
        tile holds. The stable sort keeps manifest order among equal sizes.
        """
        def _mask_size(path):
            try:
                return os.path.getsize(path)
            except (OSError, TypeError, ValueError):
                # Missing file, or a non-path value such as NaN in the CSV.
                return 0

        self.df["mask_size"] = [_mask_size(p) for p in self.df["mask_path"]]
        self.df = self.df.sort_values(by='mask_size', ascending=False, kind='stable').reset_index(drop=True)
        print(f"Prioritized {len(self.df)} samples by disaster content likelihood")

    def _preload_disaster_samples(self, preload_count=100):
        """Decode the first ``preload_count`` rows and keep the raw arrays in memory."""
        count = min(preload_count, len(self.df))
        print(f"Preloading top {count} disaster samples into memory...")

        for idx in range(count):
            row = self.df.iloc[idx]
            try:
                self.gpu_preloaded[idx] = self._load_file_uncached(row["post_path"], row["mask_path"])
            except Exception as e:
                print(f"Failed to preload sample {idx}: {e}")

        print(f"Preloaded {len(self.gpu_preloaded)} disaster-rich samples")

    def _load_file_uncached(self, post_path, mask_path):
        """Read (image CHW uint8, mask int64) from disk.

        Best effort: on any read error, log it and return zero arrays of shape
        (3, 512, 512) / (512, 512) so one bad file does not stop training.
        """
        try:
            with rio.Env(
                GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR",
                GDAL_CACHEMAX=2048,
                CPL_VSIL_CURL_ALLOWED_EXTENSIONS='.tif,.tiff',
                GDAL_NUM_THREADS='ALL_CPUS'
            ):
                with rio.open(post_path) as src:
                    img = src.read([1, 2, 3])  # (C, H, W)
                with rio.open(mask_path) as src:
                    mask = src.read(1)
            return img.astype(np.uint8), mask.astype(np.int64)

        except Exception as e:
            print(f"Failed to load file {os.path.basename(post_path)}: {e}")
            return (np.zeros((3, 512, 512), dtype=np.uint8),
                    np.zeros((512, 512), dtype=np.int64))

    def _to_sample(self, img_chw, mask):
        """Turn a raw (CHW uint8 image, mask) pair into the returned sample."""
        binary_flood_mask = (mask > 0).astype(np.int64)

        if self.augment:
            # Albumentations expects HWC images.
            augmented = self.augment(image=np.moveaxis(img_chw, 0, -1), mask=binary_flood_mask)
            return augmented["image"], augmented["mask"]

        img_tensor = torch.from_numpy(np.ascontiguousarray(img_chw)).float() / 255.0
        mask_tensor = torch.from_numpy(binary_flood_mask).long()
        return img_tensor, mask_tensor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if idx in self.gpu_preloaded:
            img, mask = self.gpu_preloaded[idx]
        else:
            row = self.df.iloc[idx]
            img, mask = self._load_file(row["post_path"], row["mask_path"])
        return self._to_sample(img, mask)


class SpaceNet8CompatWrapper(Dataset):
    """Adapt a SpaceNet8 dataset to FloodNet's binary flood-mask convention.

    Items are passed through unchanged except for the mask: every value above
    zero becomes 1 (flood) and zero stays 0 (no flood), returned as a long
    tensor.
    """

    def __init__(self, spacenet_dataset):
        self.spacenet_dataset = spacenet_dataset

    def __len__(self):
        return len(self.spacenet_dataset)

    def __getitem__(self, idx):
        image, raw_mask = self.spacenet_dataset[idx]
        # NOTE(review): every non-zero label is treated as flood. If the masks
        # encode building damage levels (0=no building, 1=no damage, ...,
        # 4=destroyed) rather than flood extent, any building counts as
        # flooded. Confirm the mask encoding.
        return image, raw_mask.gt(0).long()

## Data Augmentation Pipeline

In [None]:
def get_training_augmentation():
    """
    Build the Albumentations pipeline applied to training samples.

    Stages, in order: resize to (Config.img_height, Config.img_width),
    geometric transforms, photometric jitter, synthetic weather, noise/blur,
    ImageNet normalisation, and conversion to a PyTorch tensor.
    """
    resize = [A.Resize(Config.img_height, Config.img_width)]

    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),  # less common, but aerial views have no fixed "up"
        A.RandomRotate90(p=0.5),
        # Small shifts / scales / rotations with reflected borders.
        A.ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.1,
            rotate_limit=15,
            border_mode=cv2.BORDER_REFLECT,
            p=0.5
        ),
    ]

    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.3),
    ]

    # Weather effects meant to mimic disaster-time imaging conditions.
    weather = [A.RandomRain(p=0.1), A.RandomFog(p=0.1)]

    degradation = [
        A.GaussNoise(noise_scale_factor=0.1, p=0.2),
        A.GaussianBlur(blur_limit=(3, 7), p=0.2),
    ]

    to_model_input = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet stats
        ToTensorV2(),
    ]

    return A.Compose(resize + geometric + photometric + weather + degradation + to_model_input)

In [None]:
def get_validation_augmentation():
    """
    Build the deterministic evaluation pipeline.

    It resizes to (Config.img_height, Config.img_width), normalises with
    ImageNet statistics and converts to a PyTorch tensor. No random
    augmentation is applied.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        A.Resize(Config.img_height, Config.img_width),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

In [None]:
def get_test_augmentation():
    """
    Build the test-time pipeline.

    It is deliberately identical to the validation pipeline (resize,
    normalise, tensor conversion).
    """
    pipeline = get_validation_augmentation()
    return pipeline

## Attention Modules

In [None]:
class ChannelAttention(nn.Module):
    """
    Channel attention module (CBAM-style) that re-weights feature channels.

    Average- and max-pooled channel descriptors pass through a shared
    bottleneck MLP. Their outputs are summed and squashed by a single
    sigmoid, so every channel weight lies in (0, 1):

        attention = sigmoid(MLP(avg_pool(x)) + MLP(max_pool(x)))

    Args:
        channels: Number of input channels C.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``max(1, channels // reduction)``, so small C still gets a
            non-empty bottleneck.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden = max(1, channels // reduction)
        # Shared MLP without a final activation: the sigmoid is applied once,
        # after both paths are summed. Parameter names (fc.0 / fc.2) are
        # unchanged, so existing checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()

        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))

        # A sigmoid applied after the sum keeps weights in (0, 1). Summing two
        # per-path sigmoids would give weights in (0, 2).
        attention = self.sigmoid(avg_out + max_out)
        return x * attention.view(b, c, 1, 1)

In [None]:
class SpatialAttention(nn.Module):
    """
    Spatial attention: re-weights every spatial location of the input.

    The channel-wise mean and max maps are stacked into a 2-channel tensor,
    convolved down to one channel and passed through a sigmoid. The input is
    then multiplied by the resulting (B, 1, H, W) map.

    Args:
        kernel_size: Size of the square convolution kernel. Padding is
            ``kernel_size // 2``, which preserves H and W for odd kernels.
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        weights = self.sigmoid(self.conv(torch.cat((mean_map, max_map), dim=1)))
        return x * weights

## Multi-Task Model

In [None]:
class EnhancedDisasterModel(nn.Module):
    """
    Multi-task segmentation model for disaster assessment.

    Architecture:
    - Pretrained DeepLabV3 backbone + ASPP (ResNet101, or ResNet50 for any
      other ``backbone`` value) shared by both tasks.
    - Channel then spatial attention on the 256-channel ASPP output.
    - Two task-specific branches (flood, damage), each 256 -> 128 channels.
    - Optional gated cross-task feature fusion (``enable_fusion``).
    - 1x1 classifiers, bilinearly upsampled to the input resolution.
    """
    def __init__(self, num_classes_flood=2, num_classes_damage=4, backbone='resnet101'):
        super().__init__()

        # `weights=` replaces the deprecated `pretrained=True` flag and loads
        # the same default COCO/VOC checkpoint.
        if backbone == 'resnet101':
            from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights
            base_model = deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)
            print("Using ResNet101 backbone")
        else:
            from torchvision.models.segmentation import DeepLabV3_ResNet50_Weights
            base_model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)
            print("Using ResNet50 backbone")

        # Keep only the feature extractor and the ASPP module (256 channels out).
        self.backbone = base_model.backbone
        self.aspp = base_model.classifier[0]

        # Attention mechanisms for feature refinement
        self.channel_attention = ChannelAttention(256)
        self.spatial_attention = SpatialAttention()

        # Task-specific branches (identical structure, independent weights).
        self.flood_branch = self._make_task_branch()
        self.damage_branch = self._make_task_branch()

        # Classification heads
        self.flood_classifier = nn.Conv2d(128, num_classes_flood, 1)
        self.damage_classifier = nn.Conv2d(128, num_classes_damage, 1)

        # Cross-task fusion: produces a 256-channel gate in (0, 1). The first
        # 128 channels gate damage -> flood, the last 128 gate flood -> damage.
        self.enable_fusion = True
        if self.enable_fusion:
            self.fusion_conv = nn.Sequential(
                nn.Conv2d(256, 128, 1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, 1, bias=False),
                nn.BatchNorm2d(256),
                nn.Sigmoid()
            )

    @staticmethod
    def _make_task_branch():
        """Build one 256 -> 128 channel task-specific refinement branch."""
        return nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),

            nn.Conv2d(512, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.2),

            nn.Conv2d(256, 128, 3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, return_features=False):
        """
        Forward pass through the network.

        Args:
            x: Input tensor [B, 3, H, W]
            return_features: Also return the (post-fusion) task features

        Returns:
            flood_out: Flood logits [B, num_classes_flood, H, W]
            damage_out: Damage logits [B, num_classes_damage, H, W]
            flood_features, damage_features: only when ``return_features`` is
                True; 128-channel maps at backbone resolution.
        """
        input_shape = x.shape[-2:]

        # High-level backbone features -> multi-scale context via ASPP.
        x = self.backbone(x)['out']
        x = self.aspp(x)

        x = self.channel_attention(x)
        x = self.spatial_attention(x)

        flood_features = self.flood_branch(x)
        damage_features = self.damage_branch(x)

        if self.enable_fusion:
            fusion_weights = self.fusion_conv(torch.cat([flood_features, damage_features], dim=1))
            split = flood_features.shape[1]
            flood_gate, damage_gate = fusion_weights[:, :split], fusion_weights[:, split:]
            # Both updates read the *pre-fusion* features. Updating the damage
            # features from already-fused flood features would feed damage
            # features back into the damage path.
            flood_features, damage_features = (
                flood_features + flood_gate * damage_features,
                damage_features + damage_gate * flood_features,
            )

        flood_out = self.flood_classifier(flood_features)
        damage_out = self.damage_classifier(damage_features)

        # Upsample logits to the input resolution.
        flood_out = F.interpolate(flood_out, size=input_shape, mode='bilinear', align_corners=False)
        damage_out = F.interpolate(damage_out, size=input_shape, mode='bilinear', align_corners=False)

        if return_features:
            return flood_out, damage_out, flood_features, damage_features

        return flood_out, damage_out

## Custom Loss Functions

In [None]:
class DisasterFocusedLoss(nn.Module):
    """
    Weighted sum of cross-entropy, soft Dice and focal loss for segmentation.

        total = ce_weight * CE + dice_weight * (1 - mean Dice) + focal_weight * Focal

    Args:
        ce_weight, dice_weight, focal_weight: Mixing weights of the three terms.
        class_weights: Optional per-class weights, used by the CE term only.
        focal_alpha: Scalar multiplier of the focal term.
        focal_gamma: Focusing exponent of the focal term.

    ``forward(pred, target)`` takes logits ``pred`` of shape (B, C, H, W) and
    integer class targets of shape (B, H, W). It returns
    ``(total_loss, {'ce': float, 'dice': float, 'focal': float})``.
    """

    def __init__(self, ce_weight=0.3, dice_weight=0.4, focal_weight=0.3,
                 class_weights=None, focal_alpha=0.25, focal_gamma=2.0):
        super().__init__()
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, pred, target):
        terms = {
            'ce': self.ce_loss(pred, target),
            'dice': self._dice_loss_disaster(pred, target),   # overlap / boundary quality
            'focal': self._focal_loss_disaster(pred, target),  # down-weights easy pixels
        }
        total_loss = (
            self.ce_weight * terms['ce']
            + self.dice_weight * terms['dice']
            + self.focal_weight * terms['focal']
        )
        return total_loss, {name: value.item() for name, value in terms.items()}

    def _dice_loss_disaster(self, pred, target):
        """Return 1 - mean soft Dice over all (sample, class) pairs."""
        num_classes = pred.shape[1]
        probs = F.softmax(pred, dim=1)
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()

        batch = probs.shape[0]
        probs_flat = probs.reshape(batch, num_classes, -1)
        target_flat = one_hot.reshape(batch, num_classes, -1)

        intersection = (probs_flat * target_flat).sum(dim=2)
        union = probs_flat.sum(dim=2) + target_flat.sum(dim=2)

        smooth = 1e-6  # keeps classes absent from both pred and target finite
        dice = (2 * intersection + smooth) / (union + smooth)
        return 1 - dice.mean()

    def _focal_loss_disaster(self, pred, target):
        """Return the mean of alpha * (1 - p_t)^gamma * CE over all pixels."""
        per_pixel_ce = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-per_pixel_ce)  # probability assigned to the true class
        modulating = (1 - pt) ** self.focal_gamma
        return (self.focal_alpha * modulating * per_pixel_ce).mean()

## Evaluation Metrics

In [None]:
class MetricCalculator:
    """
    Calculate various segmentation metrics
    """
    @staticmethod
    def calculate_iou(pred, target, num_classes):
        """
        Calculate Intersection over Union per class
        """
        ious = []
        pred = pred.view(-1)
        target = target.view(-1)

        for cls in range(num_classes):
            pred_inds = pred == cls
            target_inds = target == cls

            intersection = (pred_inds & target_inds).sum().item()
            union = (pred_inds | target_inds).sum().item()

            if union == 0:
                ious.append(float('nan'))
            else:
                ious.append(intersection / union)

        return ious

    @staticmethod
    def calculate_dice(pred, target, num_classes):
        """
        Calculate Dice coefficient per class
        """
        dices = []
        pred = pred.view(-1)
        target = target.view(-1)

        for cls in range(num_classes):
            pred_inds = pred == cls
            target_inds = target == cls

            intersection = (pred_inds & target_inds).sum().item()
            pred_sum = pred_inds.sum().item()
            target_sum = target_inds.sum().item()

            if pred_sum + target_sum == 0:
                dices.append(float('nan'))
            else:
                dices.append(2 * intersection / (pred_sum + target_sum))

        return dices

    @staticmethod
    def calculate_pixel_accuracy(pred, target):
        """
        Calculate overall pixel accuracy
        """
        correct = (pred == target).sum().item()
        total = target.numel()
        return correct / total


## Pretraining the Flood Model on SpaceNet8

In [None]:
# Create a separate pretraining function
def pretrain_flood_on_spacenet(model, spacenet_loader, num_epochs=10):
    """
    Pre-train the flood path of ``model`` before multi-task training.

    Only the backbone (lr 1e-5), ``flood_branch`` and ``flood_classifier``
    (lr 5e-5) are updated, using plain cross-entropy on the flood logits.

    Args:
        model: Module returning ``(flood_logits, damage_logits)`` and exposing
            ``backbone``, ``flood_branch`` and ``flood_classifier``.
        spacenet_loader: Iterable of ``(images, masks)`` batches; masks hold
            flood class indices.
        num_epochs: Number of passes over ``spacenet_loader``.

    Returns:
        List with the mean training loss of each epoch (NaN for an empty
        loader).
    """
    print("Pretraining flood detection on SpaceNet8...")

    # Batches are moved to whatever device the model's parameters live on.
    device = next(model.parameters()).device

    # Only optimize flood-related parameters
    pretrain_optimizer = torch.optim.Adam([
        {'params': model.backbone.parameters(), 'lr': 1e-5},
        {'params': model.flood_branch.parameters(), 'lr': 5e-5},
        {'params': model.flood_classifier.parameters(), 'lr': 5e-5}
    ])

    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss, num_batches = 0.0, 0
        for images, masks in tqdm(spacenet_loader, desc=f"Pretrain epoch {epoch + 1}/{num_epochs}"):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True).long()  # cross_entropy needs int64 targets

            flood_out, _ = model(images)
            loss = F.cross_entropy(flood_out, masks)

            # Clear gradients on the whole model, not only the optimised groups.
            # Modules outside the optimizer (e.g. ASPP, attention, and in
            # EnhancedDisasterModel the damage branch via feature fusion) also
            # receive gradients and would otherwise keep accumulating them.
            model.zero_grad(set_to_none=True)
            loss.backward()
            pretrain_optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        mean_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(mean_loss)
        print(f"Pretrain epoch {epoch + 1}/{num_epochs}: mean loss {mean_loss:.4f}")

    return epoch_losses


## Optimized Trainer

In [None]:
class OptimizedTrainer:
    def __init__(self, model, config, device='cuda'):
        self.model = model
        self.config = config
        self.device = device

        # A100 optimizations
        self._enable_optimizations()

        # Mixed precision training
        self.scaler = GradScaler(init_scale=2**16, growth_interval=100)

        # Optimizer with task-specific learning rates
        self.optimizer = self._create_disaster_optimizer()

        # Loss functions
        self.flood_loss_fn = self._create_disaster_loss('flood')
        self.damage_loss_fn = self._create_disaster_loss('damage')

        # Metrics tracking
        self.metric_calculator = MetricCalculator()
        self.disaster_metrics = defaultdict(list)
        self.best_disaster_iou = {'flood': 0.0, 'damage': 0.0}

        # Training history tracking
        self.train_losses = {'flood': [], 'damage': [], 'total': []}
        self.val_metrics = {'flood_iou': [], 'damage_iou': []}
        self.best_flood_iou = 0.0
        self.best_damage_iou = 0.0
        self.best_combined_score = 0.0

        print("Disaster mapping trainer initialized for A100")

    def _enable_optimizations(self):
        """Enable A100-specific optimizations"""
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_per_process_memory_fraction(0.9)
        print("A100 optimizations enabled")

    def _create_disaster_optimizer(self):
        """Create optimizer with disaster-focused learning rates"""
        param_groups = [
            {
                'params': [p for n, p in self.model.backbone.named_parameters()],
                'lr': self.config.learning_rate * 0.1,
                'weight_decay': self.config.weight_decay
            },
            {
                'params': [p for n, p in self.model.aspp.named_parameters()],
                'lr': self.config.learning_rate * 0.3,
                'weight_decay': self.config.weight_decay * 0.5
            },
            {
                'params': list(self.model.flood_branch.parameters()) +
                          list(self.model.flood_classifier.parameters()),
                'lr': self.config.learning_rate * 1.2,
                'weight_decay': self.config.weight_decay * 0.8
            },
            {
                'params': list(self.model.damage_branch.parameters()) +
                          list(self.model.damage_classifier.parameters()),
                'lr': self.config.learning_rate * 1.5,
                'weight_decay': self.config.weight_decay * 0.6
            }
        ]

        return torch.optim.AdamW(
            param_groups,
            eps=1e-8,
            betas=(0.9, 0.999),
            amsgrad=True
        )

    def _create_disaster_loss(self, task):
        """Create disaster-focused loss function"""
        if task == 'flood':
            class_weights = torch.tensor([0.2, 3.5]).cuda()
            alpha = 0.8
        else:
            class_weights = torch.tensor([0.3, 1.5, 2.2, 2.8]).cuda()
            alpha = 0.6

        return DisasterFocusedLoss(
            ce_weight=0.3,
            dice_weight=0.4,
            focal_weight=0.3,
            class_weights=class_weights,
            focal_alpha=alpha
        )

    def train_epoch(self, flood_loader, damage_loader, epoch):
        """A100-optimized training epoch for disaster mapping"""
        self.model.train()

        # Use cycle iterators for continuous iteration
        flood_cycle = cycle(flood_loader)
        damage_cycle = cycle(damage_loader)

        # Determine number of iterations
        num_iterations = max(len(flood_loader), len(damage_loader))
        epoch_metrics = defaultdict(float)

        print(f"Disaster mapping training epoch {epoch}: {num_iterations} iterations")

        # Initialize gradient accumulation
        self.optimizer.zero_grad()

        for i in range(num_iterations):
            # Process both tasks simultaneously
            metrics = self._simultaneous_disaster_step(
                next(flood_cycle), next(damage_cycle), i
            )

            # Accumulate metrics
            for key, value in metrics.items():
                epoch_metrics[key] += value

            # Gradient update
            if (i + 1) % self.config.accumulation_steps == 0:
                self._disaster_gradient_update()

            # Progress reporting
            if i % 100 == 0 and i > 0:
                flood_loss = epoch_metrics['flood_loss'] / (i + 1)
                damage_loss = epoch_metrics['damage_loss'] / (i + 1)
                memory_gb = torch.cuda.max_memory_allocated() / 1e9

                print(f"Step {i}: Flood Loss: {flood_loss:.4f}, "
                      f"Damage Loss: {damage_loss:.4f}, GPU: {memory_gb:.1f}GB")

        # Final gradient update if needed
        if num_iterations % self.config.accumulation_steps != 0:
            self._disaster_gradient_update()

        # Calculate epoch averages
        epoch_flood_loss = epoch_metrics['flood_loss'] / num_iterations
        epoch_damage_loss = epoch_metrics['damage_loss'] / num_iterations
        epoch_total_loss = epoch_flood_loss + epoch_damage_loss

        # Store epoch statistics
        self.train_losses['flood'].append(epoch_flood_loss)
        self.train_losses['damage'].append(epoch_damage_loss)
        self.train_losses['total'].append(epoch_total_loss)

        return epoch_flood_loss, epoch_damage_loss

    def _simultaneous_disaster_step(self, flood_batch, damage_batch, step_idx):
        """Process both disaster tasks simultaneously on A100"""
        flood_imgs, flood_masks = flood_batch
        damage_imgs, damage_masks = damage_batch

        # Move to GPU with non-blocking transfer
        flood_imgs = flood_imgs.to(self.device, non_blocking=True)
        flood_masks = flood_masks.to(self.device, non_blocking=True)
        damage_imgs = damage_imgs.to(self.device, non_blocking=True)
        damage_masks = damage_masks.to(self.device, non_blocking=True)

        metrics = {}

        with autocast(device_type=self.device.type):
            # Process flood task
            flood_out_1, _ = self.model(flood_imgs)
            flood_loss, flood_components = self.flood_loss_fn(flood_out_1, flood_masks)

            # Process damage task
            _, damage_out_2 = self.model(damage_imgs)
            damage_loss, damage_components = self.damage_loss_fn(damage_out_2, damage_masks)

            # Combined loss with task weighting
            total_loss = (
                self.config.flood_task_weight * flood_loss +
                self.config.damage_task_weight * damage_loss
            ) / self.config.accumulation_steps

        # Backward pass
        self.scaler.scale(total_loss).backward()

        # Collect metrics
        metrics.update({
            'flood_loss': flood_loss.item(),
            'damage_loss': damage_loss.item(),
            'total_loss': total_loss.item() * self.config.accumulation_steps,
            'flood_dice': flood_components.get('dice', 0),
            'damage_dice': damage_components.get('dice', 0),
        })

        return metrics

    def _disaster_gradient_update(self):
        """Gradient update optimized for disaster mapping"""
        # Unscale gradients
        self.scaler.unscale_(self.optimizer)

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.5)

        # Step optimizer
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def validate(self, loader, task, num_classes):
        """Validate model on specific disaster task"""
        self.model.eval()

        total_loss = 0
        all_ious = []
        all_dice = []
        all_accuracy = []

        # Disable gradient computation for efficiency
        with torch.no_grad():
            for images, masks in tqdm(loader, desc=f'Validating {task}'):
                images = images.to(self.device, non_blocking=True)
                masks = masks.to(self.device, non_blocking=True)

                # Forward pass with mixed precision
                with autocast(device_type=self.device.type):
                    flood_out, damage_out = self.model(images)

                    # Select appropriate output
                    if task == 'flood':
                        output = flood_out
                        loss, _ = self.flood_loss_fn(output, masks)
                    else:
                        output = damage_out
                        loss, _ = self.damage_loss_fn(output, masks)

                total_loss += loss.item()

                # Get predictions
                preds = torch.argmax(output, dim=1)

                # Calculate metrics for each sample in batch
                for i in range(preds.shape[0]):
                    # IoU per class
                    ious = self.metric_calculator.calculate_iou(
                        preds[i], masks[i], num_classes
                    )
                    all_ious.append(ious)

                    # Dice coefficient
                    dice = self.metric_calculator.calculate_dice(
                        preds[i], masks[i], num_classes
                    )
                    all_dice.append(dice)

                    # Pixel accuracy
                    accuracy = self.metric_calculator.calculate_pixel_accuracy(
                        preds[i], masks[i]
                    )
                    all_accuracy.append(accuracy)

        # Calculate average metrics
        avg_loss = total_loss / len(loader)

        # Calculate per-class and mean IoU
        all_ious = np.array(all_ious)
        class_ious = np.nanmean(all_ious, axis=0)
        mean_iou = np.nanmean(class_ious)

        # Calculate mean dice and accuracy
        mean_dice = np.nanmean(all_dice)
        mean_accuracy = np.mean(all_accuracy)

        # Print validation results
        print(f"\n{task.upper()} Validation Results:")
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Mean IoU: {mean_iou:.4f}")
        print(f"Mean Dice: {mean_dice:.4f}")
        print(f"Pixel Accuracy: {mean_accuracy:.4f}")

        # Per-class IoU
        for i, iou in enumerate(class_ious):
            class_name = self._get_class_name(task, i)
            print(f"  {class_name} IoU: {iou:.4f}")

        return avg_loss, mean_iou, class_ious

    def _get_class_name(self, task, class_idx):
        """Get human-readable class names for disaster mapping"""
        if task == 'flood':
            return ['Non-Flooded', 'Flooded'][class_idx]
        else:
            return ['No Damage', 'Minor', 'Major', 'Destroyed'][class_idx]

    def test(self, loader, task, num_classes):
        """Test model - same as validate for disaster mapping"""
        return self.validate(loader, task, num_classes)

    def save_checkpoint(self, epoch, flood_metrics, damage_metrics, is_best=False):
        """Save training checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'train_losses': self.train_losses,
            'val_metrics': self.val_metrics,
            'flood_metrics': flood_metrics,
            'damage_metrics': damage_metrics,
            'best_flood_iou': self.best_flood_iou,
            'best_damage_iou': self.best_damage_iou,
            'config': self.config.__dict__
        }

        # Save regular checkpoint
        checkpoint_path = os.path.join(
            self.config.Checkpoint_Dir,
            f'checkpoint_epoch_{epoch:03d}.pth'
        )
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved checkpoint: {checkpoint_path}")

        # Save best model
        if is_best:
            best_path = os.path.join(
                self.config.Checkpoint_Dir,
                'best_model.pth'
            )
            torch.save(checkpoint, best_path)
            print(f"Saved best model: {best_path}")

        # Keep only last 5 checkpoints to save space
        self._cleanup_old_checkpoints(keep_last=5)

    def load_checkpoint(self, checkpoint_path):
        """Load checkpoint for resuming training"""
        print(f"Loading checkpoint from {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.train_losses = checkpoint.get('train_losses', {'flood': [], 'damage': [], 'total': []})
        self.val_metrics = checkpoint.get('val_metrics', {'flood_iou': [], 'damage_iou': []})
        self.best_flood_iou = checkpoint.get('best_flood_iou', 0.0)
        self.best_damage_iou = checkpoint.get('best_damage_iou', 0.0)

        print(f"Resumed from epoch {checkpoint['epoch']}")
        print(f"Best Flood IoU: {self.best_flood_iou:.4f}")
        print(f"Best Damage IoU: {self.best_damage_iou:.4f}")

        return checkpoint['epoch']

    def _cleanup_old_checkpoints(self, keep_last=5):
        """Remove old checkpoints to save disk space"""
        checkpoint_files = sorted([
            f for f in os.listdir(self.config.Checkpoint_Dir)
            if f.startswith('checkpoint_epoch_') and f.endswith('.pth')
        ])

        if len(checkpoint_files) > keep_last:
            for f in checkpoint_files[:-keep_last]:
                os.remove(os.path.join(self.config.Checkpoint_Dir, f))

    def plot_training_history(self):
        """Generate training history plots"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Loss plot
        ax = axes[0, 0]
        epochs = range(1, len(self.train_losses['total']) + 1)
        ax.plot(epochs, self.train_losses['flood'], 'b-', label='Flood Loss')
        ax.plot(epochs, self.train_losses['damage'], 'r-', label='Damage Loss')
        ax.plot(epochs, self.train_losses['total'], 'g--', label='Total Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training Loss History')
        ax.legend()
        ax.grid(True)

        # IoU plot
        ax = axes[0, 1]
        if self.val_metrics['flood_iou']:
            epochs_val = range(1, len(self.val_metrics['flood_iou']) + 1)
            ax.plot(epochs_val, self.val_metrics['flood_iou'], 'b-o', label='Flood IoU')
            ax.plot(epochs_val, self.val_metrics['damage_iou'], 'r-o', label='Damage IoU')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('IoU')
            ax.set_title('Validation IoU History')
            ax.legend()
            ax.grid(True)

        # Learning rate plot
        ax = axes[1, 0]
        lrs = [group['lr'] for group in self.optimizer.param_groups]
        ax.plot([lrs[0]] * len(epochs), 'b-', label='Backbone LR')
        ax.plot([lrs[-1]] * len(epochs), 'r-', label='Head LR')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Learning Rate')
        ax.set_title('Learning Rate Schedule')
        ax.legend()
        ax.set_yscale('log')
        ax.grid(True)

        # Summary text
        ax = axes[1, 1]
        ax.axis('off')

        # Fix the string formatting
        final_flood_iou = self.val_metrics['flood_iou'][-1] if self.val_metrics['flood_iou'] else 'N/A'
        final_damage_iou = self.val_metrics['damage_iou'][-1] if self.val_metrics['damage_iou'] else 'N/A'

        summary_text = f"""
        Training Summary:

        Total Epochs: {len(self.train_losses['total'])}
        Best Flood IoU: {self.best_flood_iou:.4f}
        Best Damage IoU: {self.best_damage_iou:.4f}

        Final Training Loss:
        - Flood: {self.train_losses['flood'][-1]:.4f if self.train_losses['flood'] else 'N/A'}
        - Damage: {self.train_losses['damage'][-1]:.4f if self.train_losses['damage'] else 'N/A'}

        Final Validation IoU:
        - Flood: {final_flood_iou}
        - Damage: {final_damage_iou}
        """
        ax.text(0.1, 0.5, summary_text, fontsize=12,
              verticalalignment='center', fontfamily='monospace')

        plt.tight_layout()

        # Save plot
        plot_path = os.path.join(self.config.Results_Dir, 'training_history.png')
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"Training history plot saved to {plot_path}")

## Main Training Loop

In [None]:
def _build_loader(dataset, config, batch_size, shuffle):
    """
    Create a DataLoader with the worker/pinning settings from ``config``.

    ``prefetch_factor`` and ``persistent_workers`` are only valid with worker
    processes (DataLoader raises ValueError otherwise), so they are passed
    only when ``config.num_workers > 0``.
    """
    worker_kwargs = {}
    if config.num_workers > 0:
        worker_kwargs = {
            'prefetch_factor': config.prefetch_factor,
            'persistent_workers': config.persistent_workers,
        }
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
        **worker_kwargs,
    )


def _config_snapshot(config):
    """
    Collect the public, non-callable attributes of ``config``.

    Works for a config instance or class; ``config.__dict__`` would be empty
    for an instance of a class-attribute config and a non-serialisable
    ``mappingproxy`` for the class itself.
    """
    return {
        name: getattr(config, name)
        for name in dir(config)
        if not name.startswith('_') and not callable(getattr(config, name))
    }


def train_complete_model(config):
    """
    Complete training pipeline.

    Builds the FloodNet(+SpaceNet8) flood datasets and the RescueNet damage
    datasets, trains the multi-task model with validation and checkpointing
    every 2 epochs (and on the final epoch), evaluates both test splits, and
    writes ``final_results.json`` to ``config.Results_Dir``.

    Args:
        config: Configuration object (see ``Config``).

    Returns:
        tuple: ``(model, trainer)`` after training and test evaluation.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    print("\nCreating datasets...")

    # Augmentation pipelines
    train_transform = get_training_augmentation()
    val_transform = get_validation_augmentation()

    # Flood (binary) and damage datasets
    flood_train_dataset = FloodNetDataset(FloodNet_train_img_dir, FloodNet_train_mask_dir, train_transform, binary_flood=True)
    flood_val_dataset = FloodNetDataset(FloodNet_val_img_dir, FloodNet_val_mask_dir, val_transform, binary_flood=True)
    flood_test_dataset = FloodNetDataset(FloodNet_test_img_dir, FloodNet_test_mask_dir, val_transform, binary_flood=True)

    damage_train_dataset = RescueNetDataset(RescueNet_train_img_dir, RescueNet_train_mask_dir, train_transform)
    damage_val_dataset = RescueNetDataset(RescueNet_val_img_dir, RescueNet_val_mask_dir, val_transform)
    damage_test_dataset = RescueNetDataset(RescueNet_test_img_dir, RescueNet_test_mask_dir, val_transform)

    space_train_dataset = SpaceNet8Dataset(sn8_train_manifest, augment=train_transform, cache_size=config.spacenet_cache_size)
    space_val_dataset = SpaceNet8Dataset(sn8_val_manifest, augment=val_transform, cache_size=config.spacenet_cache_size)

    space_train_dataset = SpaceNet8CompatWrapper(space_train_dataset)
    space_val_dataset = SpaceNet8CompatWrapper(space_val_dataset)

    # Concatenate FloodNet & SpaceNet8 datasets for the flood task
    flood_train_combined = ConcatDataset([flood_train_dataset, space_train_dataset])
    flood_val_combined = ConcatDataset([flood_val_dataset, space_val_dataset])

    print("Creating dataloaders...")
    eval_batch_size = config.batch_size * 2  # no gradients -> larger batches fit

    flood_train_loader = _build_loader(flood_train_combined, config, config.batch_size, shuffle=True)
    flood_val_loader = _build_loader(flood_val_combined, config, eval_batch_size, shuffle=False)
    flood_test_loader = _build_loader(flood_test_dataset, config, eval_batch_size, shuffle=False)

    damage_train_loader = _build_loader(damage_train_dataset, config, config.batch_size, shuffle=True)
    damage_val_loader = _build_loader(damage_val_dataset, config, eval_batch_size, shuffle=False)
    damage_test_loader = _build_loader(damage_test_dataset, config, eval_batch_size, shuffle=False)

    # Initialize model
    print("\nInitializing model...")
    model = EnhancedDisasterModel(
        num_classes_flood=config.num_classes_flood,
        num_classes_damage=config.num_classes_damage,
        backbone=config.backbone
    )
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    trainer = OptimizedTrainer(model, config, device)

    # Optional: resume from the best checkpoint
    start_epoch = 0
    best_checkpoint_path = os.path.join(config.Checkpoint_Dir, 'best_model.pth')
    if os.path.exists(best_checkpoint_path):
        response = input("Found existing checkpoint. Resume training? (y/n): ")
        if response.strip().lower() == 'y':
            start_epoch = trainer.load_checkpoint(best_checkpoint_path)

    print(f"\nStarting training from epoch {start_epoch + 1}...")
    print(f"Total epochs: {config.num_epochs}")
    print(f"Batch size: {config.batch_size}")
    print(f"Effective batch size: {config.batch_size * config.accumulation_steps}")

    # Seed from the (possibly restored) best IoUs so a worse post-resume
    # validation cannot overwrite best_model.pth; 0.0 on a fresh run.
    best_combined_score = (trainer.best_flood_iou + trainer.best_damage_iou) / 2

    for epoch in range(start_epoch + 1, config.num_epochs + 1):
        print(f"\n{'='*60}")
        print(f"EPOCH {epoch}/{config.num_epochs}")
        print('='*60)

        flood_loss, damage_loss = trainer.train_epoch(
            flood_train_loader, damage_train_loader, epoch
        )

        # Validation (every 2 epochs to save time, and on the last epoch)
        if epoch % 2 == 0 or epoch == config.num_epochs:
            print("\nRunning validation...")

            flood_val_loss, flood_iou, flood_class_ious = trainer.validate(
                flood_val_loader, 'flood', config.num_classes_flood
            )
            damage_val_loss, damage_iou, damage_class_ious = trainer.validate(
                damage_val_loader, 'damage', config.num_classes_damage
            )

            trainer.val_metrics['flood_iou'].append(flood_iou)
            trainer.val_metrics['damage_iou'].append(damage_iou)

            combined_score = (flood_iou + damage_iou) / 2
            is_best = combined_score > best_combined_score

            if is_best:
                best_combined_score = combined_score
                trainer.best_flood_iou = flood_iou
                trainer.best_damage_iou = damage_iou

            flood_metrics = {'iou': flood_iou, 'loss': flood_val_loss, 'class_ious': flood_class_ious}
            damage_metrics = {'iou': damage_iou, 'loss': damage_val_loss, 'class_ious': damage_class_ious}

            trainer.save_checkpoint(epoch, flood_metrics, damage_metrics, is_best)

        if epoch % 5 == 0:
            trainer.plot_training_history()

    # Final evaluation on test set
    print("\n" + "="*60)
    print("FINAL EVALUATION ON TEST SET")
    print("="*60)

    print("\nFlood Detection Test Results:")
    flood_test_loss, flood_test_iou, flood_test_class_ious = trainer.test(
        flood_test_loader, 'flood', config.num_classes_flood
    )

    print("\nDamage Assessment Test Results:")
    damage_test_loss, damage_test_iou, damage_test_class_ious = trainer.test(
        damage_test_loader, 'damage', config.num_classes_damage
    )

    final_results = {
        'flood_test': {
            'mean_iou': float(flood_test_iou),
            'class_ious': flood_test_class_ious.tolist(),
            'loss': float(flood_test_loss)
        },
        'damage_test': {
            'mean_iou': float(damage_test_iou),
            'class_ious': damage_test_class_ious.tolist(),
            'loss': float(damage_test_loss)
        },
        'training_config': _config_snapshot(config)
    }

    os.makedirs(config.Results_Dir, exist_ok=True)
    with open(os.path.join(config.Results_Dir, 'final_results.json'), 'w') as f:
        # default=str: config values such as paths or devices are recorded
        # as text rather than aborting the dump after a full training run.
        json.dump(final_results, f, indent=2, default=str)

    print("\nTraining complete!")
    print(f"Best Flood IoU: {trainer.best_flood_iou:.4f}")
    print(f"Best Damage IoU: {trainer.best_damage_iou:.4f}")
    print(f"Results saved to {config.Results_Dir}")

    return model, trainer


## Duration Estimation

In [None]:
# class TrainingTimeEstimator:
#     """Estimates training time for disaster mapping model on A100"""

#     def __init__(self, config):
#         self.config = config
#         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#         # A100 performance benchmarks (empirically derived)
#         self.a100_benchmarks = {
#             'resnet50_fps': 85,   # Images per second for ResNet50 backbone
#             'resnet101_fps': 65,  # Images per second for ResNet101 backbone
#             'memory_overhead': 0.3,  # 30% overhead for caching/misc
#             'validation_factor': 0.2,  # Validation takes 20% of training time
#             'io_overhead': 0.15,     # 15% overhead for data loading
#         }

#     def estimate_training_time(self, dataset_sizes, sample_batch=True):
#         """
#         Estimate total training time for disaster mapping pipeline

#         Args:
#             dataset_sizes: Dict with 'flood_train', 'flood_val', 'damage_train', 'damage_val'
#             sample_batch: Whether to run a sample batch for accurate timing

#         Returns:
#             Dict with detailed time estimates
#         """
#         print("=" * 60)
#         print("HERMES 0.3 - A100 Training Time Estimation")
#         print("=" * 60)

#         # System checks
#         self._check_system_resources()

#         # Dataset analysis
#         total_samples = sum(dataset_sizes.values())
#         train_samples = dataset_sizes.get('flood_train', 0) + dataset_sizes.get('damage_train', 0)
#         val_samples = dataset_sizes.get('flood_val', 0) + dataset_sizes.get('damage_val', 0)

#         print(f"\nDataset Analysis:")
#         print(f"  Total training samples: {train_samples:,}")
#         print(f"  Total validation samples: {val_samples:,}")
#         print(f"  Total samples: {total_samples:,}")

#         # Model complexity analysis
#         model_complexity = self._analyze_model_complexity()

#         # Batch processing estimates
#         batch_estimates = self._estimate_batch_processing()

#         # Optional: Run sample batch for accuracy
#         if sample_batch and torch.cuda.is_available():
#             print(f"\nRunning sample batch for accurate timing...")
#             sample_timing = self._run_sample_batch()
#             # Use sample timing to calibrate estimates
#             batch_estimates['time_per_batch'] = sample_timing['time_per_batch']
#             batch_estimates['memory_usage'] = sample_timing['peak_memory']

#         # Calculate epoch timing
#         epoch_timing = self._calculate_epoch_timing(dataset_sizes, batch_estimates)

#         # Calculate total training time
#         total_timing = self._calculate_total_timing(epoch_timing)

#         # Generate detailed report
#         report = self._generate_timing_report(
#             dataset_sizes, model_complexity, batch_estimates,
#             epoch_timing, total_timing
#         )

#         return report

#     def _check_system_resources(self):
#         """Check system resources and warn about potential issues"""
#         print(f"\nSystem Resource Check:")

#         # GPU check
#         if torch.cuda.is_available():
#             gpu_name = torch.cuda.get_device_name(0)
#             gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
#             print(f"  GPU: {gpu_name}")
#             print(f"  GPU Memory: {gpu_memory:.1f} GB")

#             if 'A100' not in gpu_name:
#                 print(f"  ⚠️  WARNING: Not using A100, estimates may be inaccurate")

#             if gpu_memory < 70:
#                 print(f"  ⚠️  WARNING: Limited GPU memory, may need smaller batch sizes")

#         else:
#             print(f"  ❌ ERROR: No CUDA GPU available")
#             return False

#         # CPU and RAM check
#         cpu_count = psutil.cpu_count()
#         ram_gb = psutil.virtual_memory().total / 1e9
#         print(f"  CPU Cores: {cpu_count}")
#         print(f"  RAM: {ram_gb:.1f} GB")

#         if cpu_count < 8:
#             print(f"  ⚠️  WARNING: Limited CPU cores, data loading may be slow")

#         if ram_gb < 32:
#             print(f"  ⚠️  WARNING: Limited RAM, may impact data caching")

#         return True

#     def _analyze_model_complexity(self):
#         """Analyze model complexity for timing estimates"""
#         print(f"\nModel Complexity Analysis:")

#         # Create a minimal model to count parameters
#         try:
#             model = EnhancedDisasterModel(
#                 num_classes_flood=self.config.num_classes_flood,
#                 num_classes_damage=self.config.num_classes_damage,
#                 backbone=self.config.backbone
#             )

#             total_params = sum(p.numel() for p in model.parameters())
#             trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

#             print(f"  Backbone: {self.config.backbone}")
#             print(f"  Total parameters: {total_params:,}")
#             print(f"  Trainable parameters: {trainable_params:,}")

#             # Memory estimate
#             param_memory = total_params * 4 / 1e9  # 4 bytes per float32
#             activation_memory = self._estimate_activation_memory()
#             total_memory = param_memory + activation_memory

#             print(f"  Parameter memory: {param_memory:.2f} GB")
#             print(f"  Activation memory: {activation_memory:.2f} GB")
#             print(f"  Total model memory: {total_memory:.2f} GB")

#             del model  # Free memory
#             torch.cuda.empty_cache()

#             return {
#                 'total_params': total_params,
#                 'backbone': self.config.backbone,
#                 'memory_gb': total_memory,
#                 'complexity_factor': self._get_complexity_factor()
#             }

#         except Exception as e:
#             print(f"  ⚠️  Could not analyze model: {e}")
#             return {
#                 'total_params': 60_000_000,  # Rough estimate
#                 'backbone': self.config.backbone,
#                 'memory_gb': 8.0,
#                 'complexity_factor': 1.2
#             }

#     def _estimate_activation_memory(self):
#         """Estimate activation memory based on config"""
#         batch_size = self.config.batch_size
#         img_size = self.config.img_height * self.config.img_width

#         # Rough estimate based on typical CNN activations
#         activation_memory = (
#             batch_size * img_size * 3 * 4 +  # Input
#             batch_size * img_size * 256 * 4 +  # Feature maps
#             batch_size * img_size * 64 * 4     # Output layers
#         ) / 1e9  # Convert to GB

#         return activation_memory

#     def _get_complexity_factor(self):
#         """Get complexity factor based on backbone"""
#         factors = {
#             'resnet50': 1.0,
#             'resnet101': 1.6,
#             'resnet152': 2.2
#         }
#         return factors.get(self.config.backbone, 1.2)

#     def _estimate_batch_processing(self):
#         """Estimate batch processing time"""
#         print(f"\nBatch Processing Analysis:")

#         # Base throughput from A100 benchmarks
#         if 'resnet101' in self.config.backbone:
#             base_fps = self.a100_benchmarks['resnet101_fps']
#         else:
#             base_fps = self.a100_benchmarks['resnet50_fps']

#         # Adjust for multi-task learning
#         multi_task_factor = 1.8  # Processing both flood and damage tasks
#         effective_fps = base_fps / multi_task_factor

#         # Adjust for image size
#         size_factor = (self.config.img_height * self.config.img_width) / (512 * 512)
#         effective_fps = effective_fps / size_factor

#         # Adjust for batch size
#         batch_efficiency = min(1.0, self.config.batch_size / 16)  # Optimal around batch size 16
#         effective_fps = effective_fps * batch_efficiency

#         time_per_batch = self.config.batch_size / effective_fps

#         print(f"  Base FPS ({self.config.backbone}): {base_fps}")
#         print(f"  Multi-task factor: {multi_task_factor:.1f}x")
#         print(f"  Image size factor: {size_factor:.1f}x")
#         print(f"  Effective FPS: {effective_fps:.1f}")
#         print(f"  Time per batch: {time_per_batch:.3f} seconds")

#         return {
#             'effective_fps': effective_fps,
#             'time_per_batch': time_per_batch,
#             'estimated_memory': self._estimate_batch_memory()
#         }

#     def _estimate_batch_memory(self):
#         """Estimate memory usage per batch"""
#         batch_size = self.config.batch_size
#         img_pixels = self.config.img_height * self.config.img_width * 3

#         # Input batch + gradients + activations
#         batch_memory = (
#             batch_size * img_pixels * 4 * 3  # Input, gradients, activations
#         ) / 1e9

#         return batch_memory

#     def _run_sample_batch(self):
#         """Run a sample batch to get accurate timing"""
#         try:
#             # Create minimal model
#             model = EnhancedDisasterModel(
#                 num_classes_flood=self.config.num_classes_flood,
#                 num_classes_damage=self.config.num_classes_damage,
#                 backbone=self.config.backbone
#             ).cuda()

#             model.train()
#             optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#             scaler = torch.cuda.amp.GradScaler()

#             # Create sample data
#             batch_size = self.config.batch_size
#             sample_images = torch.randn(
#                 batch_size, 3, self.config.img_height, self.config.img_width
#             ).cuda()
#             sample_masks = torch.randint(
#                 0, 2, (batch_size, self.config.img_height, self.config.img_width)
#             ).cuda()

#             # Warm up
#             for _ in range(3):
#                 with torch.cuda.amp.autocast():
#                     flood_out, damage_out = model(sample_images)
#                     loss = torch.nn.functional.cross_entropy(flood_out, sample_masks)

#                 scaler.scale(loss).backward()
#                 scaler.step(optimizer)
#                 scaler.update()
#                 optimizer.zero_grad()

#             torch.cuda.synchronize()

#             # Actual timing
#             num_runs = 10
#             start_time = time.time()
#             start_memory = torch.cuda.memory_allocated()

#             for _ in range(num_runs):
#                 with torch.cuda.amp.autocast():
#                     flood_out, damage_out = model(sample_images)
#                     flood_loss = torch.nn.functional.cross_entropy(flood_out, sample_masks)
#                     damage_loss = torch.nn.functional.cross_entropy(damage_out, sample_masks)
#                     total_loss = flood_loss + damage_loss

#                 scaler.scale(total_loss).backward()
#                 scaler.step(optimizer)
#                 scaler.update()
#                 optimizer.zero_grad()

#             torch.cuda.synchronize()
#             end_time = time.time()
#             peak_memory = torch.cuda.max_memory_allocated()

#             # Calculate metrics
#             total_time = end_time - start_time
#             time_per_batch = total_time / num_runs
#             peak_memory_gb = peak_memory / 1e9

#             print(f"    Sample batch timing: {time_per_batch:.3f} seconds")
#             print(f"    Peak GPU memory: {peak_memory_gb:.2f} GB")

#             # Cleanup
#             del model, optimizer, scaler, sample_images, sample_masks
#             torch.cuda.empty_cache()

#             return {
#                 'time_per_batch': time_per_batch,
#                 'peak_memory': peak_memory_gb
#             }

#         except Exception as e:
#             print(f"    ⚠️ Sample batch failed: {e}")
#             return {
#                 'time_per_batch': 0.5,  # Fallback estimate
#                 'peak_memory': 12.0
#             }

#     def _calculate_epoch_timing(self, dataset_sizes, batch_estimates):
#         """Calculate timing for one epoch"""
#         print(f"\nEpoch Timing Calculation:")

#         # Training steps per epoch
#         flood_batches = np.ceil(dataset_sizes.get('flood_train', 0) / self.config.batch_size)
#         damage_batches = np.ceil(dataset_sizes.get('damage_train', 0) / self.config.batch_size)

#         # Use simultaneous processing (both tasks per iteration)
#         steps_per_epoch = max(flood_batches, damage_batches)

#         # Training time
#         train_time_per_epoch = steps_per_epoch * batch_estimates['time_per_batch']

#         # Add gradient accumulation overhead
#         if hasattr(self.config, 'accumulation_steps') and self.config.accumulation_steps > 1:
#             accumulation_overhead = train_time_per_epoch * 0.1  # 10% overhead
#             train_time_per_epoch += accumulation_overhead

#         # Validation time (every few epochs)
#         val_flood_batches = np.ceil(dataset_sizes.get('flood_val', 0) / (self.config.batch_size * 2))
#         val_damage_batches = np.ceil(dataset_sizes.get('damage_val', 0) / (self.config.batch_size * 2))
#         val_steps = val_flood_batches + val_damage_batches
#         val_time = val_steps * batch_estimates['time_per_batch'] * 0.5  # Validation is faster

#         print(f"  Training steps per epoch: {int(steps_per_epoch):,}")
#         print(f"  Training time per epoch: {timedelta(seconds=int(train_time_per_epoch))}")
#         print(f"  Validation steps: {int(val_steps):,}")
#         print(f"  Validation time: {timedelta(seconds=int(val_time))}")

#         return {
#             'train_time': train_time_per_epoch,
#             'val_time': val_time,
#             'steps_per_epoch': steps_per_epoch,
#             'total_epoch_time': train_time_per_epoch + (val_time / 2)  # Val every 2 epochs
#         }

#     def _calculate_total_timing(self, epoch_timing):
#         """Calculate total training time"""
#         print(f"\nTotal Training Time Calculation:")

#         num_epochs = self.config.num_epochs
#         epoch_time = epoch_timing['total_epoch_time']

#         # Base training time
#         base_training_time = num_epochs * epoch_time

#         # Add overhead factors
#         io_overhead = base_training_time * self.a100_benchmarks['io_overhead']
#         setup_overhead = 300  # 5 minutes for setup/initialization
#         checkpoint_overhead = num_epochs * 30  # 30 seconds per checkpoint

#         total_time = base_training_time + io_overhead + setup_overhead + checkpoint_overhead

#         print(f"  Base training time: {timedelta(seconds=int(base_training_time))}")
#         print(f"  I/O overhead: {timedelta(seconds=int(io_overhead))}")
#         print(f"  Setup/checkpoint overhead: {timedelta(seconds=int(setup_overhead + checkpoint_overhead))}")
#         print(f"  Total estimated time: {timedelta(seconds=int(total_time))}")

#         # Cost estimation (rough)
#         hours = total_time / 3600
#         estimated_cost = hours * 2.5  # ~$2.5/hour for Colab Pro A100

#         print(f"  Estimated Colab Pro cost: ${estimated_cost:.2f}")

#         return {
#             'total_seconds': total_time,
#             'total_hours': hours,
#             'estimated_cost': estimated_cost,
#             'breakdown': {
#                 'training': base_training_time,
#                 'io_overhead': io_overhead,
#                 'setup_checkpoint': setup_overhead + checkpoint_overhead
#             }
#         }

#     def _generate_timing_report(self, dataset_sizes, model_complexity,
#                                batch_estimates, epoch_timing, total_timing):
#         """Generate comprehensive timing report"""

#         report = {
#             'summary': {
#                 'total_time_hours': total_timing['total_hours'],
#                 'total_time_formatted': str(timedelta(seconds=int(total_timing['total_seconds']))),
#                 'estimated_cost_usd': total_timing['estimated_cost'],
#                 'epochs': self.config.num_epochs,
#                 'batch_size': self.config.batch_size
#             },
#             'dataset_info': dataset_sizes,
#             'model_info': model_complexity,
#             'performance': {
#                 'time_per_batch_seconds': batch_estimates['time_per_batch'],
#                 'effective_fps': batch_estimates['effective_fps'],
#                 'steps_per_epoch': epoch_timing['steps_per_epoch'],
#                 'epoch_time_minutes': epoch_timing['total_epoch_time'] / 60
#             },
#             'resources': {
#                 'estimated_gpu_memory_gb': batch_estimates.get('estimated_memory', 0) + model_complexity['memory_gb'],
#                 'gpu_utilization_percent': 85  # Estimated
#             },
#             'time_breakdown': total_timing['breakdown'],
#             'recommendations': self._generate_recommendations(total_timing, batch_estimates)
#         }

#         return report

#     def _generate_recommendations(self, total_timing, batch_estimates):
#         """Generate recommendations based on estimates"""
#         recommendations = []

#         if total_timing['total_hours'] > 8:
#             recommendations.append("⚠️ Training will take >8 hours. Consider reducing epochs or using larger batch size.")

#         if total_timing['estimated_cost'] > 50:
#             recommendations.append("💰 High cost estimate (>${:.0f}). Consider optimizing hyperparameters.".format(total_timing['estimated_cost']))

#         if batch_estimates['time_per_batch'] > 1.0:
#             recommendations.append("🐌 Slow batch processing. Consider reducing image size or model complexity.")

#         if self.config.batch_size < 16:
#             recommendations.append("📈 Small batch size. A100 can handle larger batches for better efficiency.")

#         if not hasattr(self.config, 'spacenet_cache_size') or self.config.spacenet_cache_size < 200:
#             recommendations.append("🔄 Increase cache size to 200-500 for A100 to reduce I/O overhead.")

#         if not recommendations:
#             recommendations.append("✅ Configuration looks well-optimized for A100!")

#         return recommendations


# def estimate_training_time_before_run(config):
#     """
#     Main function to estimate training time before starting training
#     Place this before your main training loop
#     """

#     # Initialize estimator
#     estimator = TrainingTimeEstimator(config)

#     # Define dataset sizes (you'll need to update these with actual sizes)
#     # You can get these from your dataset objects
#     dataset_sizes = {
#         'flood_train': 1445,  # Update with actual FloodNet training size
#         'flood_val': 450,     # Update with actual FloodNet validation size
#         'damage_train': 4500, # Update with actual RescueNet training size
#         'damage_val': 1000,   # Update with actual RescueNet validation size
#     }

#     # Get time estimates
#     report = estimator.estimate_training_time(
#         dataset_sizes=dataset_sizes,
#         sample_batch=True  # Set to False to skip sample batch (faster but less accurate)
#     )

#     # Display summary
#     print(f"\n" + "="*60)
#     print("TRAINING TIME ESTIMATE SUMMARY")
#     print("="*60)
#     print(f"Estimated Training Time: {report['summary']['total_time_formatted']}")
#     print(f"Estimated Cost: ${report['summary']['estimated_cost_usd']:.2f}")
#     print(f"Time per Epoch: {report['performance']['epoch_time_minutes']:.1f} minutes")
#     print(f"GPU Memory Usage: {report['resources']['estimated_gpu_memory_gb']:.1f} GB")

#     print(f"\n📋 RECOMMENDATIONS:")
#     for rec in report['recommendations']:
#         print(f"  {rec}")

#     # Ask user for confirmation
#     print(f"\n" + "="*60)
#     response = input("Continue with training? (y/n): ").lower().strip()

#     if response != 'y':
#         print("Training cancelled by user.")
#         return False, report

#     print("Proceeding with training...")
#     return True, report

In [None]:
# if __name__ == "__main__":
#     config = Config()

#     # NEW: Estimate training time first
#     should_continue, time_report = estimate_training_time_before_run(config)

#     if should_continue:
#         model, trainer = train_complete_model(config)  # Your existing code
#     else:
#         print("Training cancelled. Adjust config and try again.")

## Initialize and Train Model

In [None]:
if __name__ == '__main__':
  # Entry point for the training run. In a Jupyter/Colab kernel `__name__` is
  # '__main__', so this body executes every time the cell is run.
  # The time-estimate gate in the previous cell is commented out, so training
  # starts immediately without a confirmation prompt.

  # Build the run configuration from the `Config` class defined in the config
  # cell (backbone, class counts, batch size, epochs, ...).
  config = Config()

  # Run the full training pipeline. The result is unpacked as a 2-tuple.
  # NOTE(review): presumably (trained model, trainer object); confirm against
  # train_complete_model's definition.
  # `config`, `model` and `trainer` are bound at module level, so they remain
  # available to later notebook cells.
  model, trainer = train_complete_model(config)

Using device: cuda
GPU: NVIDIA A100-SXM4-80GB
Memory: 85.2 GB

Creating datasets...
FloodNet dataset: 1445 images, 1445 with masks
FloodNet dataset: 450 images, 450 with masks
FloodNet dataset: 448 images, 448 with masks
Prioritized 158 samples by disaster content likelihood
Preloading top 100 disaster samples to A100 GPU memory...


## Model Demonstration

In [None]:
# def test_on_new_image_with_tiff(model_path, image_path):
#     import torch
#     import cv2
#     import numpy as np
#     import matplotlib.pyplot as plt
#     from PIL import Image
#     import os

#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#     try:
#         # Load model
#         checkpoint = torch.load(model_path, map_location=device, weights_only=False)

#         model = EnhancedDisasterModel(
#             num_classes_flood=2,
#             num_classes_damage=4,
#             backbone='resnet101'
#         )
#         model.load_state_dict(checkpoint['model_state_dict'])
#         model.to(device)
#         model.eval()

#         # Check file extension and load accordingly
#         file_ext = os.path.splitext(image_path)[1].lower()
#         print(f"  → Processing {file_ext} file: {os.path.basename(image_path)}")

#         if file_ext in ['.tif', '.tiff']:
#             # Use PIL for TIFF files (better TIFF support)
#             print("  → Using PIL for TIFF loading...")
#             try:
#                 image_pil = Image.open(image_path)

#                 # Handle different TIFF modes
#                 if image_pil.mode == 'RGB':
#                     image = np.array(image_pil)
#                 elif image_pil.mode == 'RGBA':
#                     # Convert RGBA to RGB
#                     image = np.array(image_pil.convert('RGB'))
#                 elif image_pil.mode in ['L', 'P']:
#                     # Convert grayscale or palette to RGB
#                     image = np.array(image_pil.convert('RGB'))
#                 elif image_pil.mode == 'I' or image_pil.mode == 'F':
#                     # Handle 32-bit integer or float images
#                     arr = np.array(image_pil)
#                     # Normalize to 0-255 range
#                     arr = ((arr - arr.min()) / (arr.max() - arr.min()) * 255).astype(np.uint8)
#                     image = np.stack([arr, arr, arr], axis=-1)  # Convert to RGB
#                 else:
#                     print(f"  → Converting from mode {image_pil.mode} to RGB")
#                     image = np.array(image_pil.convert('RGB'))

#                 print(f"  → TIFF image loaded: {image.shape}, dtype: {image.dtype}")

#             except Exception as e:
#                 print(f"  ✗ PIL failed, trying OpenCV for TIFF: {e}")
#                 # Fallback to OpenCV
#                 image = cv2.imread(image_path, cv2.IMREAD_COLOR)
#                 if image is None:
#                     raise ValueError(f"Could not load TIFF image from {image_path}")
#                 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

#         else:
#             # Use OpenCV for standard formats (PNG, JPG, JPEG)
#             print("  → Using OpenCV for standard image loading...")
#             image = cv2.imread(image_path)
#             if image is None:
#                 raise ValueError(f"Could not load image from {image_path}")
#             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

#         print(f"  → Final image shape: {image.shape}, dtype: {image.dtype}")

#         # Ensure image is in correct format (0-255, uint8)
#         if image.dtype != np.uint8:
#             if image.max() <= 1.0:
#                 # Image is in 0-1 range, convert to 0-255
#                 image = (image * 255).astype(np.uint8)
#             else:
#                 # Image might be in different range, normalize
#                 image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)

#         # Apply transforms
#         transform = get_validation_augmentation()
#         input_tensor = transform(image=image)['image']
#         input_tensor = input_tensor.unsqueeze(0).to(device)

#         # Make predictions
#         with torch.no_grad():
#             flood_out, damage_out = model(input_tensor)

#         flood_pred = torch.argmax(flood_out, dim=1)[0].cpu().numpy()
#         damage_pred = torch.argmax(damage_out, dim=1)[0].cpu().numpy()

#         # Visualize results
#         fig, axes = plt.subplots(1, 3, figsize=(15, 5))

#         axes[0].imshow(image)
#         axes[0].set_title(f'Original Image ({file_ext})')
#         axes[0].axis('off')

#         axes[1].imshow(flood_pred, cmap='Blues')
#         axes[1].set_title('Flood Prediction')
#         axes[1].axis('off')

#         axes[2].imshow(damage_pred, cmap='Reds')
#         axes[2].set_title('Damage Prediction')
#         axes[2].axis('off')

#         plt.tight_layout()
#         plt.show()

#         return flood_pred, damage_pred

#     except Exception as e:
#         print(f"  ✗ ERROR processing {image_path}: {str(e)}")
#         return None, None


# import os
# import random

# def test_on_images_with_tiff(model_path, folder_path, num_img=10):
#     """
#     Test the disaster model on multiple images including TIFF files.

#     Supported formats: PNG, JPG, JPEG, TIF, TIFF
#     """
#     try:
#         # Updated file extension list to include TIFF
#         image_extensions = (".png", ".jpg", ".jpeg", ".tif", ".tiff")
#         image_files = [f for f in os.listdir(folder_path)
#                       if f.lower().endswith(image_extensions)]

#         if not image_files:
#             print(f"No supported image files found in {folder_path}")
#             print(f"Looking for: {image_extensions}")
#             all_files = os.listdir(folder_path)[:10]  # Show first 10 files
#             print(f"Files in folder: {all_files}")
#             return []

#         print(f"Found {len(image_files)} supported images")

#         # Group by file type for info
#         file_types = {}
#         for f in image_files:
#             ext = os.path.splitext(f)[1].lower()
#             file_types[ext] = file_types.get(ext, 0) + 1

#         print(f"File type breakdown: {file_types}")

#         # Sample random images
#         sample_files = random.sample(image_files, min(num_img, len(image_files)))
#         print(f"Processing {len(sample_files)} images...")

#         results = []
#         successful_predictions = 0

#         for i, file_name in enumerate(sample_files, 1):
#             file_path = os.path.join(folder_path, file_name)
#             print(f"\n[{i}/{len(sample_files)}] Processing: {file_name}")

#             flood_pred, damage_pred = test_on_new_image_with_tiff(model_path, file_path)

#             if flood_pred is not None and damage_pred is not None:
#                 results.append({
#                     "file": file_name,
#                     "flood_mask": flood_pred,
#                     "damage_mask": damage_pred
#                 })
#                 successful_predictions += 1
#                 print(f"  ✓ Success!")
#             else:
#                 print(f"  ✗ Failed")

#         print(f"\nCompleted: {successful_predictions}/{len(sample_files)} images processed successfully")
#         return results

#     except Exception as e:
#         print(f"Error in test_on_images_with_tiff: {str(e)}")
#         return []


# # Quick check function to see what file types are in your folder
# def check_file_types(folder_path):
#     """Check what file types are present in the folder"""
#     print(f"=== File Analysis for: {folder_path} ===")

#     if not os.path.exists(folder_path):
#         print(f"Folder does not exist!")
#         return

#     all_files = os.listdir(folder_path)
#     print(f"Total files: {len(all_files)}")

#     # Group by extension
#     extensions = {}
#     for f in all_files:
#         ext = os.path.splitext(f)[1].lower()
#         if not ext:
#             ext = "(no extension)"
#         extensions[ext] = extensions.get(ext, 0) + 1

#     print(f"\nFile extensions found:")
#     for ext, count in sorted(extensions.items()):
#         print(f"  {ext}: {count} files")

#     # Check for TIFF specifically
#     tiff_files = [f for f in all_files if f.lower().endswith(('.tif', '.tiff'))]
#     if tiff_files:
#         print(f"\nTIFF files found: {len(tiff_files)}")
#         print(f"First few TIFF files: {tiff_files[:5]}")
#     else:
#         print(f"\nNo TIFF files found.")


# # Usage example
# if __name__ == "__main__":
#     model_path = "/content/checkpoints/best_model.pth"
#     folder_path = "/content/drive/MyDrive/G.E.M.S./FloodNet/FloodNet-Supervised_v1.0/test/test-label-img/"

#     # First check what file types you have
#     check_file_types(folder_path)

#     # Then run with TIFF support
#     if os.path.exists(model_path) and os.path.exists(folder_path):
#         results = test_on_images_with_tiff(model_path, folder_path, num_img=10)
#         print(f"Processing complete. Got results for {len(results)} images.")
#     else:
#         print("Missing model or folder path")

In [None]:
#!pip install nbconvert
#!jupyter nbconvert --to html Hermes.0.3.6.ipynb

In [None]:
# print(f"Flood pixels: {(flood_mask == 1).sum()}")
# print(f"Damage distribution: {np.bincount(damage_mask.flatten())}")