In [None]:
#fetch the two dataset archives with curl and unpack them into sibling folders
#NOTE(review): the image URL ends in dl=1 while the mask URL carries dl=0 plus a stray d=1;
#if colormask.zip turns out not to be a valid zip, changing that link to dl=1 is the likely fix — TODO confirm
!curl -L -o rescuenet.zip "https://www.dropbox.com/scl/fo/ntgeyhxe2mzd2wuh7he7x/AFIchlfjVO_7MzPcNc1ZOHE/RescueNet?rlkey=6vxiaqve9gp6vzvzh3t5mz0vv&subfolder_nav_tracking=1&st=cpmz72mg&dl=1"  #download the main rescuenet image archive directly into the notebook workspace
!curl -L -o colormask.zip "https://www.dropbox.com/scl/fo/ntgeyhxe2mzd2wuh7he7x/AK7z2KSL2Df2igYzGrrHlYs/ColorMasks-RescueNet?dl=0&rlkey=6vxiaqve9gp6vzvzh3t5mz0vv&subfolder_nav_tracking=1&d=1"  #download the color-mask archive that pairs with the images
!unzip rescuenet.zip -d rescuenettrain/  #unzip images into a dedicated folder so paths are predictable
!unzip colormask.zip -d rescuenetmask/   #unzip masks into a parallel folder keeping train/test structure


In [None]:
#core libs
#these are the standard python libraries and cv tools i use for file io and image handling
import os, glob, math, random  #filesystem ops, pattern search, math helpers, rng
import numpy as np             #array math
import cv2                     #opencv for image resizing and padding
from PIL import Image          #reliable png/jpg reading

#torch & utils
#pytorch core plus dataloaders and a progress bar for visibility
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
from tqdm.auto import tqdm

#external libs used later
#albumentations handles all of my image/mask augmentations
import albumentations as A
!pip install segmentation_models_pytorch  #install segmentation models so i can use deeplabv3+ quickly
import segmentation_models_pytorch as smp
from google.colab import drive  #used to persist checkpoints to drive

#pytorch perf/precision knobs
#register numpy's array reconstructor as a safe global for torch.load, allow tf32 paths,
#and prefer higher matmul precision when available
#numpy 2.x exposes the reconstructor under np._core, older releases under np.core;
#torch.serialization.add_safe_globals only exists on newer torch, so both lookups are guarded
_np_core = getattr(np, "_core", None) or getattr(np, "core")
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([_np_core.multiarray._reconstruct])  #handle numpy arrays in torch.load safety list
torch.backends.cuda.matmul.allow_tf32 = True  #use tf32 on ampere+ for speed without big accuracy loss
torch.backends.cudnn.allow_tf32 = True        #same idea for cudnn convolutions
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")  #pytorch 2.x matmul tuning; skipped when the api is absent

#shared constants used throughout
IGNORE_INDEX = 255  #mask value to ignore when computing loss/metrics


In [None]:
#fast color(rgb)->class-id lut
#i convert 3-channel color masks into integer class ids using a 256^3 lookup table
def build_color_lut(color_map, ignore_index=IGNORE_INDEX):
    """Build a dense colour -> class-id lookup table.

    color_map:    dict of {(r, g, b): class_id}
    ignore_index: value stored for every colour that is not a key of color_map
    returns:      uint8 array of shape (256, 256, 256), indexed as lut[r, g, b]
    """
    table = np.full((256,) * 3, ignore_index, dtype=np.uint8)  #unknown colours default to ignore
    for rgb, class_id in color_map.items():
        red, green, blue = rgb
        table[red, green, blue] = class_id
    return table

def rgbmask_to_ids(mask_rgb: np.ndarray, lut: np.ndarray) -> np.ndarray:
    """Translate an H×W×3 colour mask into an H×W class-id map via `lut[r, g, b]`."""
    height_width = mask_rgb.shape[:2]                            #remember h×w for the final reshape
    pixels = mask_rgb.reshape(-1, 3)                             #one row per pixel: (r, g, b)
    red, green, blue = pixels[:, 0], pixels[:, 1], pixels[:, 2]
    class_ids = lut[red, green, blue]                            #single vectorised table lookup
    return class_ids.reshape(height_width)

class RescueNetSegDataset(Dataset):
    """
    Pairs every image in `image_dir` with its colour mask and yields (image, class-id mask) tensors.

    image_dir:    JPG/PNG images (e.g., .../train-org-img)
    mask_dir:     color masks named like <stem>_lab.png
    augment:      Albumentations Compose applied jointly to image and mask, or None
    normalize:    Albumentations A.Normalize(mean,std) applied to the image after `augment`, OR None;
                  when None, an image that is not already float32 is scaled to [0, 1]
    color_map:    required dict {(r, g, b): class_id}; colours not listed map to IGNORE_INDEX
    return_paths: when True, __getitem__ additionally returns (image_path, mask_path)

    Raises RuntimeError at construction time if any image has no matching mask.
    """
    #this dataset pairs each image with its color mask, applies augs, converts to tensors
    def __init__(self, image_dir, mask_dir, augment=None, normalize=None,
                 color_map=None, return_paths=False):
        assert color_map is not None, "Provide COLOR_MAP"  #i require a color map for lut conversion
        #allow jpg/jpeg/png
        exts = ("*.jpg","*.jpeg","*.png")
        self.image_paths = sorted(sum([glob.glob(os.path.join(image_dir, e)) for e in exts], []))

        #map masks by stem (strip '_lab')
        mask_files = glob.glob(os.path.join(mask_dir, "*.png"))
        mask_dict = {os.path.basename(m).replace("_lab.png",""): m for m in mask_files}

        self.mask_paths = []
        for ip in self.image_paths:
            stem = os.path.splitext(os.path.basename(ip))[0]
            mp = mask_dict.get(stem)
            if mp is None:
                raise RuntimeError(f"No mask found for image: {ip}")  #hard fail so data issues are caught early
            self.mask_paths.append(mp)

        self.augment = augment
        self.normalize = normalize
        self.return_paths = return_paths
        self.lut = build_color_lut(color_map, IGNORE_INDEX)  #precompute lut once per dataset

    def __len__(self): return len(self.image_paths)  #standard dataset length

    def __getitem__(self, idx):
        """Load sample `idx`; returns (x, y) or (x, y, (image_path, mask_path)).

        x is a C×H×W image tensor, y an H×W int64 tensor of class ids.
        """
        ip, mp = self.image_paths[idx], self.mask_paths[idx]

        img = np.array(Image.open(ip).convert("RGB"))        #HWC uint8
        msk_rgb = np.array(Image.open(mp).convert("RGB"))    #HWC uint8

        #ensure same size before aug
        if msk_rgb.shape[:2] != img.shape[:2]:
            msk_rgb = cv2.resize(msk_rgb, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)  #preserve labels

        ids = rgbmask_to_ids(msk_rgb, self.lut)              #HxW uint8, IGNORE_INDEX where the colour is unknown
        ids = ids.astype(np.int64, copy=False)

        if self.augment is not None:
            out = self.augment(image=img, mask=ids)          #apply joint augs
            img, ids = out["image"], out["mask"]
            ids = ids.astype(np.int64, copy=False)           #re-assert int64 in case the pipeline changed the dtype

        #--- normalization guard ---
        if self.normalize is not None:
            #use only this Normalize (float32, mean/std)
            img = self.normalize(image=img)["image"]
        elif img.dtype != np.float32:
            #anything that is not float32 yet is rescaled to [0, 1];
            #a float32 image is left untouched so it is not rescaled twice
            img = img.astype(np.float32) / 255.0

        x = torch.from_numpy(img.transpose(2,0,1))  #C,H,W float32
        y = torch.from_numpy(ids)                   #H,W int64

        if self.return_paths:
            return x, y, (ip, mp)
        return x, y

    def __repr__(self):
        #an empty dataset (e.g. wrong image_dir) must still be printable instead of raising IndexError
        if not self.image_paths:
            return "RescueNetSegDataset(n=0, img0=None)"
        return f"RescueNetSegDataset(n={len(self)}, img0='{os.path.basename(self.image_paths[0])}')"  #quick sanity view


In [None]:
!pip install -q albumentations segmentation_models_pytorch  #ensure correct versions are present for colab runtime

#this custom pad transform enforces a minimum canvas size with constant borders for both image and mask
class PadIfNeededConst(A.DualTransform):
    """
    Pad image and mask up to at least (min_height, min_width) with a CONSTANT border.

    position:
      - 'center'       : padding is split between both sides (the extra pixel goes bottom/right)
      - 'bottom_right' : (default) content stays anchored top-left, padding is added bottom/right
      - 'top_left'     : currently handled exactly like 'bottom_right'
      any other value is also treated like 'bottom_right'
    value:      fill for image padding (tuple or int)
    mask_value: fill for mask padding (int), e.g. IGNORE_INDEX
    always_apply / p: `always_apply=True` is translated to p=1.0

    Raises NotImplementedError for any border_mode other than cv2.BORDER_CONSTANT.
    """
    def __init__(self, min_height, min_width,
                 border_mode=cv2.BORDER_CONSTANT,
                 value=0, mask_value=255, position='bottom_right',
                 always_apply=False, p=1.0):
        #the base-class signature changed across albumentations releases
        #((always_apply, p) -> (p, always_apply) -> (p) only), so passing both positionally is unsafe;
        #fold always_apply into p and hand over p by keyword, which every signature accepts
        if always_apply:
            p = 1.0
        super().__init__(p=p)
        self.min_height = int(min_height)
        self.min_width  = int(min_width)
        self.border_mode = border_mode
        self.value = value
        self.mask_value = int(mask_value)
        self.position = position
        if self.border_mode != cv2.BORDER_CONSTANT:
            raise NotImplementedError("PadIfNeededConst currently supports only BORDER_CONSTANT.")  #keep scope simple

    def _pad_amounts(self, h, w):
        """Return (top, bottom, left, right) padding needed to reach the minimum size."""
        dh = max(0, self.min_height - h)
        dw = max(0, self.min_width  - w)
        if dh == 0 and dw == 0:
            return 0, 0, 0, 0
        if self.position == 'center':
            top    = dh // 2
            bottom = dh - top
            left   = dw // 2
            right  = dw - left
        else:  #'bottom_right' and 'top_left' both keep the content anchored top-left
            top, left = 0, 0
            bottom, right = dh, dw
        return top, bottom, left, right

    def apply(self, img, **params):
        """Pad the image with `value`; returns the input unchanged when it is already large enough."""
        h, w = img.shape[:2]
        top, bottom, left, right = self._pad_amounts(h, w)
        if top == bottom == left == right == 0:
            return img
        #handle scalar vs 3-tuple image value
        val = self.value
        if img.ndim == 2 and isinstance(val, (tuple, list)):
            val = val[0] if len(val) else 0
        return cv2.copyMakeBorder(img, top, bottom, left, right,
                                  self.border_mode, value=val)

    def apply_to_mask(self, mask, **params):
        """Pad the mask with `mask_value` using the same amounts as the image."""
        h, w = mask.shape[:2]
        top, bottom, left, right = self._pad_amounts(h, w)
        if top == bottom == left == right == 0:
            return mask
        return cv2.copyMakeBorder(mask, top, bottom, left, right,
                                  self.border_mode, value=self.mask_value)

    def get_transform_init_args_names(self):
        return ("min_height", "min_width", "border_mode", "value", "mask_value", "position")


In [None]:
#==== device & perf flags ====
#select gpu if available and enable cudnn autotune; tf32 flags were set earlier already
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  #let cudnn pick fastest conv algorithms

#amp: prefer bf16 on newer nvidia; else fall back to fp16 with grad scaler
#AMP_DTYPE stays None whenever bf16 is not usable (this includes cpu-only runs);
#the training cell turns None into fp16 on cuda and fp32 otherwise
USE_BF16 = True
AMP_DTYPE = torch.bfloat16 if (USE_BF16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else None
SCALER_ENABLED = (device.type == "cuda" and AMP_DTYPE is None)  #scaler is only needed for fp16

#====hyperparams ====
#these are the core knobs for training; i tuned around gpu limits and validation stability
BATCH        = 16    #images per dataloader batch
ACCUM_STEPS  = 1     #batches accumulated per optimizer step
EPOCHS       = 40
LR           = 1e-4  #peak learning rate reached after warmup
WD           = 1e-4  #adamw weight decay
GRAD_CLIP_N  = 1.0   #max gradient norm passed to clip_grad_norm_

#colab drive (kept)
#mount google drive so i can save best checkpoints across sessions
#force_remount=True remounts even if a previous cell run already mounted the drive
drive.mount('/content/drive', force_remount=True)
BASE = "/content/drive/MyDrive/college_data"
os.makedirs(BASE, exist_ok=True)

#==== project globals ====
#dataset roots and the color map used for mask decoding are defined here;
#IGNORE_INDEX, RescueNetSegDataset, PadIfNeededConst, build_color_lut, rgbmask_to_ids come from earlier cells
ROOT_MASK = "rescuenetmask"
ROOT_IMG  = "rescuenettrain"

COLOR_MAP = {
    (0, 0, 0): 0,  # background
    (61, 230, 250): 1,  # water
    (180, 120, 120): 2, # building-no-damage
    (235, 255, 7): 3,   # building-medium-damage
    (255, 184, 6): 4,   # building-major-damage
    (255, 0, 0): 5,     # building-total-destruction
    (255, 0, 245): 6,   # vehicle
    (140, 140, 140): 7, # road-clear
    (160, 150, 20): 8,  # road-blocked
    (4, 250, 7): 9,     # tree
    (0, 101, 140): 10,  # pool
}
NUM_CLASSES = max(COLOR_MAP.values()) + 1  #ids are assumed to be dense 0..max
LUT = build_color_lut(COLOR_MAP, IGNORE_INDEX)  #prebuild lut once

#==== transforms ====
#pull imagenet preprocessing stats for the chosen encoder
enc = "resnet50"
pp = smp.encoders.get_preprocessing_params(enc, pretrained="imagenet")
mean, std = pp["mean"], pp["std"]

#version-robust ShiftScaleRotate (remove unsupported args like value/mask_value)
#this gives spatial variety while keeping labels valid via constant borders
#NOTE(review): no mask fill is passed, so mask pixels exposed by the warp get the library default
#rather than IGNORE_INDEX — TODO confirm what that default is for the installed albumentations
aug_scale = A.ShiftScaleRotate(
    shift_limit=0.02, scale_limit=0.5, rotate_limit=10,
    border_mode=cv2.BORDER_CONSTANT, p=0.35
)

#version-robust CoarseDropout
#this simulates occlusions/missing pixels and encourages robustness;
#dropped image pixels are filled with 0 and the matching mask pixels with IGNORE_INDEX
#the keyword names differ between albumentations releases, so choose the set that the installed
#class actually declares instead of guessing with try/except
import inspect

_cd_params = inspect.signature(A.CoarseDropout.__init__).parameters
if "num_holes_range" in _cd_params:
    #range-based api; newer releases name the fills fill/fill_mask, earlier ones fill_value/mask_fill_value
    if "fill_mask" in _cd_params:
        _cd_fill = {"fill": 0, "fill_mask": IGNORE_INDEX}
    else:
        _cd_fill = {"fill_value": 0, "mask_fill_value": IGNORE_INDEX}
    coarse = A.CoarseDropout(
        num_holes_range=(1, 2),
        hole_height_range=(24, 48),   #ints are pixel sizes
        hole_width_range=(24, 48),
        p=0.15, **_cd_fill
    )
else:
    #legacy min_/max_ api
    coarse = A.CoarseDropout(
        min_holes=1, max_holes=2,
        min_height=24, max_height=48,
        min_width=24,  max_width=48,
        fill_value=0, mask_fill_value=IGNORE_INDEX, p=0.15
    )

#helper used as the last geometric step: hands albumentations outputs to torch as C-contiguous arrays
#(presumably needed because earlier transforms may return non-contiguous views — TODO confirm)
def _contig(x, **kwargs): return np.ascontiguousarray(x)  #avoid copies later in torch.from_numpy
make_contig = A.Lambda(image=_contig, mask=_contig)

CROP = 768   #if you try 832, reduce batch to fit vram

#final train pipeline: resize→pad→crop with objects→color/motion augs→normalize
#order matters: geometric steps run on uint8 image + id mask, Normalize runs last and only touches the image
#Normalize is part of the pipeline itself, which is why the datasets are built with normalize=None
train_tfms = A.Compose([
    A.LongestMaxSize(max_size=1024),
    A.HorizontalFlip(p=0.30),
    A.VerticalFlip(p=0.30),
    A.RandomRotate90(p=0.40),
    aug_scale,
    A.SmallestMaxSize(max_size=CROP),
    #pad to a 1024 canvas; padded mask pixels become IGNORE_INDEX
    PadIfNeededConst(min_height=1024, min_width=1024,
                     border_mode=cv2.BORDER_CONSTANT, value=(0,0,0), mask_value=IGNORE_INDEX),
    #crop a CROP×CROP window, with IGNORE_INDEX excluded from what counts as non-empty
    A.CropNonEmptyMaskIfExists(CROP, CROP, ignore_values=[IGNORE_INDEX], p=1.0),
    A.ColorJitter(brightness=0.12, contrast=0.12, saturation=0.12, hue=0.02, p=0.33),
    A.MotionBlur(blur_limit=(3,7), p=0.10),
    coarse,
    make_contig,
    A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
])

#validation pipeline: center crop with padding and same normalization
#no random transforms are listed here
val_tfms = A.Compose([
    A.SmallestMaxSize(max_size=CROP),
    PadIfNeededConst(CROP, CROP, value=(0,0,0), mask_value=IGNORE_INDEX, position='center'),
    A.CenterCrop(CROP, CROP),
    make_contig,
    A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
])

print("created transforms")  #quick sanity check


In [None]:
#==== datasets ====
#build train/val datasets with shared color map and transform pipelines
#normalize=None because both pipelines already end with A.Normalize
train_ds_full = RescueNetSegDataset(
    image_dir=f"{ROOT_IMG}/train/train-org-img",
    mask_dir=f"{ROOT_MASK}/ColorMasks-TrainSet",
    color_map=COLOR_MAP,
    augment=train_tfms,
    normalize=None,
)
#NOTE(review): the "validation" set is built from the test split (test-org-img / ColorMasks-TestSet),
#so any checkpoint selected on it is selected on test data — confirm this is intended
val_ds_full = RescueNetSegDataset(
    image_dir=f"{ROOT_IMG}/test/test-org-img",
    mask_dir=f"{ROOT_MASK}/ColorMasks-TestSet",
    color_map=COLOR_MAP,
    augment=val_tfms,
    normalize=None,
)
#aliases: no subsetting is applied, the full datasets are used as-is
train_ds = train_ds_full
val_ds   = val_ds_full
print("loaded datasets")
print("train size:", len(train_ds))

#==== CE class weights from global pixel counts ====
#compute per-class weights so cross-entropy doesn’t ignore rare classes
def mask_paths_for(subset_or_ds):
    """Return the mask file paths of a dataset as a new list.

    A torch `Subset` yields only the paths of its selected indices, in index order;
    any other object must expose a `mask_paths` sequence.
    """
    if not isinstance(subset_or_ds, Subset):
        return list(subset_or_ds.mask_paths)
    parent = subset_or_ds.dataset
    return [parent.mask_paths[position] for position in subset_or_ds.indices]

train_mask_paths = mask_paths_for(train_ds)

#per-class pixel counts over the training masks, cached on disk because the scan is slow
#the cache is only trusted when it loads cleanly and has one entry per class;
#anything else triggers a recount instead of being silently accepted (or silently swallowed)
COUNTS_CACHE = "train_counts.npy"
counts = None
if os.path.exists(COUNTS_CACHE):
    try:
        cached_counts = np.load(COUNTS_CACHE)
        if cached_counts.shape == (NUM_CLASSES,):
            counts = cached_counts.astype(np.int64, copy=False)
        else:
            print(f"Ignoring {COUNTS_CACHE}: shape {cached_counts.shape} does not match NUM_CLASSES={NUM_CLASSES}")
    except (OSError, ValueError) as e:
        print(f"Ignoring unreadable {COUNTS_CACHE}: {e}")

if counts is None:
    counts = np.zeros(NUM_CLASSES, dtype=np.int64)
    for mp in tqdm(train_mask_paths, desc="Counting pixels (full train)"):
        m = np.array(Image.open(mp).convert("RGB"))
        ids = rgbmask_to_ids(m, LUT)
        ids = ids[ids != IGNORE_INDEX]
        if ids.size:
            counts += np.bincount(ids.ravel(), minlength=NUM_CLASSES)
    np.save(COUNTS_CACHE, counts)

#inverse sqrt frequency; classes that never occur keep weight 0 here
#(computed only where counts > 0 so no divide-by-zero warning is raised)
w = np.zeros(NUM_CLASSES, dtype=np.float64)
seen = counts > 0
w[seen] = 1.0 / np.sqrt(counts[seen].astype(np.float64))
nz = w[w > 0]
if nz.size: w /= np.median(nz)  #normalize by median to keep scales reasonable
print("CE weights before clip:", w.tolist())
w = np.clip(w, 0.25, 2.5)  #cap extremes so training remains stable (absent classes end up at 0.25)
print("CE weights after clip:", w.tolist())
class_weights = torch.tensor(w, dtype=torch.float32, device=device)  #move to device for loss


In [None]:
#==== rarity-aware image sampler ====
#this sampler increases chance of picking images that contain rare classes and rare areas
def compute_image_weights(
    dataset, lut, num_classes, ignore_index=IGNORE_INDEX,
    gamma=1.1, lambda_mix=0.9, cap_min=0.25, cap_max=6.0, show_progress=True
):
    """
    Per-image sampling weights that favour images containing rare classes.

    dataset:       object exposing `mask_paths` (colour mask files, one per image)
    lut:           colour -> class-id table as built by build_color_lut
    num_classes:   number of class ids
    ignore_index:  id excluded from all statistics
    gamma:         exponent on the global pixel count in the rarity term
    lambda_mix:    blend factor, w = (1 - lambda_mix) * presence + lambda_mix * rarity
    cap_min/max:   clamp applied to the blended weight (relative to an average weight of 1)
    show_progress: wrap the mask scan in tqdm

    Returns (weights, global_pixel_counts): a float64 torch tensor of length len(dataset.mask_paths)
    that sums to 1, and an int64 numpy array of per-class pixel counts.

    Both terms are rescaled to mean 1 before they are blended. The raw rarity term is an inverse
    power of the global pixel count and the raw presence term a sum of inverse square roots of
    image counts, so for any realistically sized dataset both sit far below cap_min; clamping
    those raw values makes every weight equal and the sampler uniform.
    Weights cached from a run that blended the raw values need to be regenerated.
    """
    it = dataset.mask_paths
    if show_progress: it = tqdm(it, desc="Scanning masks (histograms + presence)")

    #one class histogram per image
    per_img_hist = []
    for mp in it:
        m = np.array(Image.open(mp).convert("RGB"))
        ids = rgbmask_to_ids(m, lut)
        ids_v = ids[ids != ignore_index]
        per_img_hist.append(np.bincount(ids_v.ravel(), minlength=num_classes).astype(np.int64))

    if not per_img_hist:
        #nothing to weight; keep the return types stable
        return torch.zeros(0, dtype=torch.double), np.zeros(num_classes, dtype=np.int64)

    hists = np.stack(per_img_hist)                      #(n_images, num_classes)
    global_pixel_counts = hists.sum(axis=0)
    present = hists > 0                                 #class c occurs in image i
    has_pixels = present.any(axis=1)                    #False for images with no valid pixel at all

    #presence term: sum over present classes of 1/sqrt(#images containing the class)
    class_img_freq = np.maximum(present.sum(axis=0), 1)
    presence_weights = (present / np.sqrt(class_img_freq.astype(np.float64))).sum(axis=1)

    #rarity term: class mix of the image weighted by 1/(global pixel count ** gamma)
    g = np.maximum(global_pixel_counts.astype(np.float64), 1.0)
    rarity = 1.0 / (g ** gamma)
    totals = hists.sum(axis=1).astype(np.float64)
    mix = hists / np.maximum(totals, 1.0)[:, None]
    rarity_weights = mix @ rarity

    def _unit_mean(term):
        #rescale so images with valid pixels average 1; images without any get the neutral weight 1
        ref = term[has_pixels].mean() if has_pixels.any() else 0.0
        scaled = term / ref if ref > 0 else np.ones_like(term)
        scaled[~has_pixels] = 1.0
        return scaled

    #blend + clamp + normalize
    w = (1.0 - lambda_mix) * _unit_mean(presence_weights) + lambda_mix * _unit_mean(rarity_weights)
    w = np.clip(w, cap_min, cap_max)
    total = w.sum()
    w = w / total if total > 0 else np.full_like(w, 1.0 / w.size)
    return torch.as_tensor(w, dtype=torch.double), global_pixel_counts

#per-image weights are cached because the mask scan is slow
#the cache is only reused when it loads cleanly and has exactly one weight per training image
SAMPLER_CACHE = "img_weights_sampler_2.pt"
img_weights, counts_recalc = None, None
if os.path.exists(SAMPLER_CACHE):
    try:
        #NOTE: weights_only=False unpickles arbitrary objects; only load cache files written by this notebook
        img_weights, counts_recalc = torch.load(SAMPLER_CACHE, weights_only=False)  #reuse if cached
    except Exception as e:
        print(f"Ignoring unreadable {SAMPLER_CACHE}: {e}")
        img_weights, counts_recalc = None, None
    if img_weights is not None and len(img_weights) != len(train_ds_full):
        print(f"Ignoring stale {SAMPLER_CACHE}: {len(img_weights)} weights for {len(train_ds_full)} images")
        img_weights, counts_recalc = None, None

if img_weights is None:
    img_weights, counts_recalc = compute_image_weights(
        train_ds_full, LUT, NUM_CLASSES,
        ignore_index=IGNORE_INDEX,
        gamma=1.1, lambda_mix=0.9,
        cap_min=0.25, cap_max=6.0, show_progress=True,
    )
    torch.save((img_weights, counts_recalc), SAMPLER_CACHE)

#draw one epoch worth of indices with replacement, proportionally to the image weights
sampler = WeightedRandomSampler(
    weights=img_weights,
    num_samples=len(train_ds_full),
    replacement=True
)

def safe_collate(batch):
    """Stack a list of (image, mask) tensor pairs into two batch tensors.

    Every tensor is made contiguous and cloned first, so the stacked batch never shares
    storage with the per-sample tensors.
    """
    def _own_copy(t):
        return t.contiguous().clone()

    images, targets = zip(*batch)
    image_copies = list(map(_own_copy, images))
    target_copies = list(map(_own_copy, targets))
    return torch.stack(image_copies, 0), torch.stack(target_copies, 0)

#worker processes: cap at the number of cpu cores so small runtimes are not oversubscribed
#set NUM_WORKERS = 0 to load data in the main process
NUM_WORKERS = min(12, os.cpu_count() or 1)
#persistent_workers / prefetch_factor are only valid together with worker processes
_worker_kwargs = dict(persistent_workers=True, prefetch_factor=4) if NUM_WORKERS > 0 else {}

#training draws indices from the rarity-aware sampler (so no shuffle flag)
train_loader = DataLoader(
    train_ds_full, batch_size=BATCH,
    num_workers=NUM_WORKERS, pin_memory=True,
    sampler=sampler, collate_fn=safe_collate, **_worker_kwargs
)
#validation iterates the dataset once, in order
val_loader = DataLoader(
    val_ds, batch_size=max(1, BATCH), shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True,
    collate_fn=safe_collate, **_worker_kwargs
)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")  #sanity check on iteration sizes


In [None]:
#==== model + optimizer ====
#build deeplabv3+ with the chosen encoder, optionally warm-start from prior weights
model = smp.DeepLabV3Plus(
    encoder_name=enc,
    encoder_weights="imagenet",
    in_channels=3,
    classes=NUM_CLASSES,
).to(device)

#warm start (optional)
#any failure (unreadable file, shape mismatch, ...) only skips the warm start
prev_path = "bestprevdeeplabv3+.pth"
if os.path.exists(prev_path):
    try:
        state = torch.load(prev_path, map_location=device)
        load_result = model.load_state_dict(state, strict=False)  #strict false so minor head/aux diffs don’t break load
        print(f"Loaded existing weights from {prev_path} (strict=False).")
        #strict=False also "succeeds" when nothing matched, so report what was left out
        n_missing, n_unexpected = len(load_result.missing_keys), len(load_result.unexpected_keys)
        if n_missing or n_unexpected:
            print(f"  warm start mismatch: {n_missing} model keys not in checkpoint, "
                  f"{n_unexpected} checkpoint keys unused")
    except Exception as e:
        print("Warm start skipped:", e)

opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)  #adamw has stable weight decay

#==== cosine schedule with 1-epoch warmup (per-step) ====
#the schedule is indexed by optimizer step, not by batch
#the training loop also steps on the last (possibly partial) accumulation group of an epoch,
#so the number of optimizer steps per epoch is the ceiling, not the floor, of batches / ACCUM_STEPS
steps_per_epoch = max(1, math.ceil(len(train_loader) / max(1, ACCUM_STEPS)))
total_steps     = EPOCHS * steps_per_epoch
warmup_steps    = steps_per_epoch * 1

def lr_lambda(step):
    """LR multiplier for optimizer step `step`: linear warmup over one epoch, then cosine decay to 0."""
    if step < warmup_steps:
        return float(step + 1) / float(warmup_steps)  #linear warmup
    progress = (step - warmup_steps) / max(1, (total_steps - warmup_steps))
    progress = min(1.0, progress)  #never run past the end of the cosine, where the lr would rise again
    return 0.5 * (1.0 + math.cos(math.pi * progress))  #cosine decay

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

#==== losses: CE + Lovasz + per-class focal CE + OHEM (top-k CE) ====
#final hybrid loss that balances overall accuracy, boundaries, hard pixels, and rare classes
ce = torch.nn.CrossEntropyLoss(
    weight=class_weights, ignore_index=IGNORE_INDEX, label_smoothing=0.02
)
#ignore_index must be passed explicitly: without it the lovasz term treats every IGNORE_INDEX pixel
#(padding, dropout holes, unmapped colours) as a negative for all classes
lovasz = smp.losses.LovaszLoss(mode="multiclass", per_image=False, ignore_index=IGNORE_INDEX)

#per-class alpha from CE weights, bump rare class (8) a bit more
alpha_vec = (class_weights / class_weights.sum()).to(device).clone()
alpha_vec[8] *= 1.50   #slightly emphasize road-blocked in my setup
alpha_vec = alpha_vec / alpha_vec.sum()  #renormalize so alphas sum to 1 again

#per-class gammas: stronger on rare
gamma_map = torch.full((NUM_CLASSES,), 1.5, device=device)
gamma_map[8] = 2.4     #focus even more on specific difficult class

def focal_ce_loss_per_class_gamma(logits, y, alpha, gamma_map, ignore_index=IGNORE_INDEX):
    """Focal cross-entropy with a per-class alpha and a per-class focusing exponent.

    logits:    B×C×H×W raw scores
    y:         B×H×W integer targets; pixels equal to ignore_index contribute nothing
    alpha:     length-C tensor of class weights
    gamma_map: length-C tensor of focusing exponents
    returns:   scalar, sum of per-pixel focal terms divided by the number of valid pixels (at least 1)
    """
    batch, n_cls, _, _ = logits.shape
    log_probs = F.log_softmax(logits, dim=1)
    probs = torch.exp(log_probs)

    target = y.view(batch, -1)
    keep = target != ignore_index
    #ignored pixels are pointed at class 0 so gather stays in range; they are zeroed out below
    safe_target = target.masked_fill(~keep, 0)
    gather_idx = safe_target.unsqueeze(1)

    pt = probs.view(batch, n_cls, -1).gather(1, gather_idx).squeeze(1)         #prob of the true class per pixel
    logpt = log_probs.view(batch, n_cls, -1).gather(1, gather_idx).squeeze(1)  #log-prob of the true class per pixel

    modulator = (1.0 - pt).pow(gamma_map[safe_target])   #class-specific focusing
    per_pixel = -(alpha[safe_target] * modulator * logpt)
    per_pixel = per_pixel * keep
    return per_pixel.sum() / keep.sum().clamp_min(1)

def topk_ce(logits, y, k=0.2):
    """
    Online hard example mining: mean class-weighted CE over the hardest valid pixels.

    k: fraction of valid (non-IGNORE_INDEX) pixels to keep, e.g. 0.2 = the 20% with the largest loss;
       at least one pixel is always kept. Returns a zero tensor when no pixel is valid.
    """
    pixel_losses = F.cross_entropy(
        logits, y, weight=class_weights, ignore_index=IGNORE_INDEX, reduction='none'
    )
    candidates = pixel_losses[y != IGNORE_INDEX]
    n_valid = candidates.numel()
    if n_valid == 0:
        return torch.tensor(0.0, device=logits.device)
    n_keep = max(1, int(k * n_valid))
    hardest = torch.topk(candidates, n_keep).values
    return hardest.mean()

def loss_fn(logits, y):
    """Hybrid training loss: 0.35·CE + 0.30·Lovasz + 0.20·focal CE + 0.15·top-k CE (weights sum to 1.0)."""
    weighted_ce = 0.35 * ce(logits, y)
    weighted_lovasz = 0.30 * lovasz(logits, y)
    weighted_focal = 0.20 * focal_ce_loss_per_class_gamma(
        logits, y, alpha=alpha_vec, gamma_map=gamma_map, ignore_index=IGNORE_INDEX
    )
    weighted_ohem = 0.15 * topk_ce(logits, y, k=0.2)  #0.15..0.25 works; i use 0.2 here
    #this mix matched my validation behavior best
    return weighted_ce + weighted_lovasz + weighted_focal + weighted_ohem

#==== amp scaler: only for fp16, not bf16 ====
#SCALER_ENABLED is True only on cuda without bf16; when it is False the training loop
#calls loss.backward() and opt.step() directly and never touches the scaler
scaler = torch.amp.GradScaler(enabled=SCALER_ENABLED)
print("got losses / optimizer / AMP / scheduler")  #checkpoint log


In [None]:
#==== ema wrapper (validate & save with ema weights) ====
#ema keeps a smoothed copy of weights that usually validates more stably than raw steps
class EMA:
    """Exponential moving average of a model's trainable parameters.

    shadow: {param_name: averaged tensor}, initialised from the model at construction
    backup: {param_name: raw tensor} held between apply_shadow() and restore()

    Only parameters with requires_grad=True are tracked; buffers are not averaged.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.data.clone()

    @torch.no_grad()
    def update(self, model):
        """shadow <- decay * shadow + (1 - decay) * param, for every trainable parameter."""
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if name not in self.shadow:
                #parameter became trainable after construction: start its average at the current value
                self.shadow[name] = p.data.clone()
                continue
            self.shadow[name].mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)

    def apply_shadow(self, model):
        """Swap the averaged weights into the model, keeping the raw weights in `backup`.

        Calling this again before restore() keeps the first backup, so the raw training
        weights are never overwritten by averaged ones.
        """
        for name, p in model.named_parameters():
            if p.requires_grad and name in self.shadow:
                if name not in self.backup:
                    self.backup[name] = p.data.clone()
                p.data.copy_(self.shadow[name])

    def restore(self, model):
        """Put the raw weights saved by apply_shadow() back and clear the backup."""
        for name, p in model.named_parameters():
            if p.requires_grad and name in self.backup:
                p.data.copy_(self.backup[name])
        self.backup = {}

#NOTE(review): the training cell builds its own tracker through EMAHook and never reads this one,
#while this instance still holds a full copy of the trainable parameters — confirm it is needed
ema = EMA(model, decay=0.999)  #create the ema tracker

#==== metrics ====
#helper to average iou while ignoring background class
def average_miou_ex_background(per_class_ious, bg_id=0):
    """Mean IoU in percent over all classes except `bg_id`, skipping NaN entries.

    Returns NaN when there are fewer than two classes.
    """
    scores = np.array(per_class_ious, dtype=np.float64)
    if scores.size <= 1:
        return float("nan")
    foreground = np.delete(scores, bg_id)   #drop the background entry
    mean_iou = np.nanmean(foreground)
    return float(mean_iou * 100.0)

@torch.no_grad()
def miou_on_loader_fast(model, loader, n_classes, device,
                        ignore_index=None, use_amp=True, show_progress=True, tta_hflip=True):
    """
    Evaluate a segmentation model over `loader` through a confusion matrix.

    model:         module mapping B×3×H×W images to B×n_classes×H×W logits; it is put in eval mode
                   and left in eval mode on return
    loader:        iterable of (images, labels) batches
    n_classes:     number of classes
    device:        torch.device or device string (e.g. "cuda", "cuda:0", "cpu")
    ignore_index:  label value excluded from the statistics, or None
    use_amp:       run the forward passes under autocast
    show_progress: show a tqdm bar with a running mIoU (best effort)
    tta_hflip:     when True, average logits over scales 0.85/1.0/1.15, each with a horizontal flip;
                   when False, a single forward pass at the original scale

    Returns (per_class_iou, mean_iou, per_class_accuracy, overall_pixel_accuracy);
    classes that never occur in labels or predictions are reported as NaN.
    """
    def _forward_resized(m, imgs, scale):
        #rescale, pad bottom/right to a multiple of 32, run the model, then undo padding and rescaling
        H, W = imgs.shape[-2], imgs.shape[-1]
        if scale != 1.0:
            imgs = F.interpolate(imgs, scale_factor=scale, mode="bilinear", align_corners=False)
        pad_h = (32 - imgs.shape[-2] % 32) % 32
        pad_w = (32 - imgs.shape[-1] % 32) % 32
        imgs_p = F.pad(imgs, (0, pad_w, 0, pad_h)) if (pad_h or pad_w) else imgs
        out = m(imgs_p)
        out = out[..., :imgs.shape[-2], :imgs.shape[-1]]
        return F.interpolate(out, size=(H, W), mode="bilinear", align_corners=False)

    def _tta_logits(imgs):
        #average the logits over all scales (and the flipped view of each scale when tta_hflip is on)
        total = 0
        for s in scales:
            l = _forward_resized(model, imgs, s)
            if tta_hflip:
                lf = _forward_resized(model, torch.flip(imgs, dims=[-1]), s)
                lf = torch.flip(lf, dims=[-1])
                l = 0.5 * (l + lf)
            total = total + l
        return total / len(scales)

    model.eval()
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)  #confusion matrix, rows = truth, cols = prediction

    iterator = loader
    pbar = None
    if show_progress:
        #progress display is best effort: fall back to the bare loader if the bar cannot be created
        try:
            iterator = tqdm(loader, desc="Evaluating", total=len(loader), leave=True)
            pbar = iterator
        except Exception:
            pass

    #autocast needs the bare device type ("cuda"/"cpu"), whether a string or a torch.device came in
    dev_type = torch.device(device).type
    if not use_amp:
        autocast_dtype = torch.float32
    elif dev_type == "cuda":
        autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        autocast_dtype = torch.bfloat16  #bf16 rather than fp16 when autocasting on cpu
    scales = [0.85, 1.0, 1.15] if tta_hflip else [1.0]

    try:
        for images, labels in iterator:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            if use_amp:
                with torch.autocast(device_type=dev_type, dtype=autocast_dtype):
                    logits = _tta_logits(images)
            else:
                logits = _tta_logits(images)

            preds = torch.argmax(logits, dim=1)  #class per pixel

            yt = labels.reshape(-1)
            yp = preds.reshape(-1)

            if ignore_index is not None:
                keep = yt != ignore_index
                yt = yt[keep]; yp = yp[keep]

            #drop anything outside [0, n_classes) so bincount cannot index out of the matrix
            keep = (yt >= 0) & (yt < n_classes) & (yp >= 0) & (yp < n_classes)
            yt = yt[keep].to(torch.int64); yp = yp[keep].to(torch.int64)
            if yt.numel() == 0:
                continue

            idx = yt * n_classes + yp
            counts = torch.bincount(idx, minlength=n_classes * n_classes)
            cm += counts.cpu().numpy().reshape(n_classes, n_classes)

            if pbar is not None:
                tp_live = np.diag(cm).astype(np.float64)
                fn_live = cm.sum(axis=1) - np.diag(cm)
                fp_live = cm.sum(axis=0) - np.diag(cm)
                denom_live = tp_live + fn_live + fp_live
                with np.errstate(divide="ignore", invalid="ignore"):
                    ious_live = np.where(denom_live > 0, tp_live / denom_live, np.nan)
                pbar.set_postfix(mIoU=np.nanmean(ious_live))

    finally:
        if pbar is not None:
            pbar.close()

    #==== final metrics from confusion matrix ====
    tp = np.diag(cm).astype(np.float64)
    fn = cm.sum(axis=1) - np.diag(cm)
    fp = cm.sum(axis=0) - np.diag(cm)
    denom_iou = tp + fn + fp

    with np.errstate(divide="ignore", invalid="ignore"):
        ious = np.where(denom_iou > 0, tp / denom_iou, np.nan)

    miou = float(np.nanmean(ious))

    #per-class pixel accuracy: tp/(tp+fn)
    denom_acc = tp + fn
    with np.errstate(divide="ignore", invalid="ignore"):
        per_class_acc = np.where(denom_acc > 0, tp / denom_acc, np.nan)

    #overall pixel accuracy across all valid classes
    total_pixels = cm.sum()
    overall_acc = float(tp.sum() / total_pixels) if total_pixels > 0 else float("nan")

    return ious.tolist(), miou, per_class_acc.tolist(), overall_acc


In [None]:
#===== training loop =====
#main training: amp + ema + cosine schedule + hybrid loss; saves best ema checkpoint
best_miou, best_path = -1.0, "best_deeplabv3p_ema.pth"
print("beginning the actual loop")
#autocast dtype: bf16 when available, else fp16 on cuda; fp32 on cpu where autocast is disabled anyway
AMP_CAST_DTYPE = AMP_DTYPE or (torch.float16 if device.type == "cuda" else torch.float32)

class EMAHook:
    """Thin holder around an EMA tracker for the model being trained."""
    def __init__(self, model, decay=0.999):
        self.ema = EMA(model, decay)

ema_hook = EMAHook(model, decay=0.999)
version = 13  #run tag used in the drive checkpoint filename

global_step = 0
for epoch in range(1, EPOCHS+1):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch:02d} [train]", ncols=100, leave=False)
    model.train()
    running = 0.0
    opt.zero_grad(set_to_none=True)

    for step, (x, y) in enumerate(pbar, 1):
        x = x.to(device, non_blocking=True)
        y = y.to(device, dtype=torch.long, non_blocking=True)

        #amp
        with torch.amp.autocast(device_type=device.type, dtype=AMP_CAST_DTYPE, enabled=(device.type == "cuda")):
            logits = model(x)
            loss = loss_fn(logits, y)

        if SCALER_ENABLED:
            scaler.scale(loss/ACCUM_STEPS).backward()  #scale for fp16 stability
        else:
            (loss/ACCUM_STEPS).backward()

        #step every ACCUM_STEPS batches and on the last batch so leftover gradients are not dropped
        need_step = (step % ACCUM_STEPS == 0) or (step == len(train_loader))
        if need_step:
            #grad clip (gradients must be unscaled first when the scaler is active)
            if SCALER_ENABLED:
                scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=GRAD_CLIP_N)

            if SCALER_ENABLED:
                scaler.step(opt); scaler.update()
            else:
                opt.step()
            opt.zero_grad(set_to_none=True)
            ema_hook.ema.update(model)  #update ema shadow after optimizer step
            sched.step()
            global_step += 1

        loss_value = loss.item()  #single host sync per batch
        running += loss_value * x.size(0)
        pbar.set_postfix(loss=f"{loss_value:.3f}", lr=f"{sched.get_last_lr()[0]:.2e}")  #live feedback

    train_loss = running / len(train_loader.dataset)

    #==== validate (and checkpoint) with ema weights ====
    #both checkpoint files are written while the ema weights are swapped in, so the file
    #named "..._ema_..." on drive really contains ema weights; the raw training weights are
    #restored in `finally` even if evaluation or saving raises
    improved = False
    ema_hook.ema.apply_shadow(model)
    try:
        per_class_ious, val_miou, per_class_acc, overall_acc = miou_on_loader_fast(
            model, val_loader, NUM_CLASSES, device, ignore_index=IGNORE_INDEX, tta_hflip=True
        )
        improved = val_miou > best_miou
        if improved:
            best_miou = val_miou
            torch.save(model.state_dict(), best_path)
            torch.save(model.state_dict(), f"{BASE}/v{version}_deeplabv3p_ema_e{epoch}miou{val_miou:.3f}.pth")
    finally:
        ema_hook.ema.restore(model)
    avg_miou_pct = float(np.nanmean(np.array(per_class_ious[1:], dtype=np.float64)) * 100.0)  #exclude background

    #nice prints
    print("Per-class IoUs:", per_class_ious)
    print("Per-class Accuracies:", [None if np.isnan(a) else float(a) for a in per_class_acc])
    print(f"Overall Pixel Acc: {overall_acc*100:.2f}%")
    print(f"Avg mIoU (no background): {avg_miou_pct:.2f}%")
    print(f"Epoch {epoch:02d} | train_loss {train_loss:.4f} | val_mIoU (EMA+TTA) {val_miou:.3f}")

    if improved:
        print(f"  ↳ saved best EMA weights to {best_path}")
        print("saved to drive")

#gpu vram info
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()
    print(f"GPU VRAM used: {(total-free)/1e9:.2f} GB / {total/1e9:.2f} GB  ({torch.cuda.get_device_name(0)})")

print(f"Done. Best EMA+TTA val mIoU: {best_miou:.3f}")  #final summary so i can compare runs
