In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!git clone https://github.com/ell-hosse/Contextual-Penalty-Loss.git
%cd Contextual-Penalty-Loss

fatal: destination path 'Contextual-Penalty-Loss' already exists and is not an empty directory.
/content/Contextual-Penalty-Loss


In [3]:
# If your CPL repo is on GitHub, first pip install it here (uncomment & edit):
# !pip install -U pip
# !pip install git+https://github.com/<you>/<contextual-penalty-loss>.git

import os, math, json
from pathlib import Path
from typing import Tuple, Dict, List

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

# ==== Paths (edit to your folders) ====
train_images_dir = "/content/drive/MyDrive/SemantiCAM/bdd100k_images_10k/10k/train"
train_masks_dir = "/content/drive/MyDrive/SemantiCAM/bdd100k_seg_maps/color_labels/train"
val_images_dir = "/content/drive/MyDrive/SemantiCAM/bdd100k_images_10k/10k/val"
val_masks_dir = "/content/drive/MyDrive/SemantiCAM/bdd100k_seg_maps/color_labels/val"

# ==== Hyperparams ====
IMAGE_SIZE = 512
BATCH_SIZE = 128
WORKERS = 4
EPOCHS = 30
BASE_CH = 64
LR = 3e-4
USE_AMP = True

# CPL weights
W_CPL = 1.0
W_CE = 0.2  # small CE blend helps early stability

OUT_DIR = "/content/drive/MyDrive/CPLoss/saved_model/"

IGNORE_INDEX = 255

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)


Device: cuda


In [4]:
# BDD100K classes
BDD_CLASSES = [
    "road","sidewalk","building","wall","fence","pole","traffic light","traffic sign",
    "vegetation","terrain","sky","person","rider","car","truck","bus","train","motorcycle","bicycle"
]
NUM_CLASSES = len(BDD_CLASSES)

# Common color palette for color-label masks
BDD_COLORS = [
    (128,64,128), (244, 35,232), ( 70, 70, 70), (102,102,156), (190,153,153),
    (153,153,153), (250,170, 30), (220,220,  0), (107,142, 35), (152,251,152),
    ( 70,130,180), (220, 20, 60), (255,  0,  0), (  0,  0,142), (  0,  0, 70),
    (  0, 60,100), (  0, 80,100), (  0,  0,230), (119, 11, 32),
]
COLOR2ID = {c:i for i,c in enumerate(BDD_COLORS)}

# --- Contextual groups to encode similarity (domain knowledge) ---
GROUPS = {
    "ground": ["road","sidewalk","terrain"],
    "construction": ["building","wall","fence"],
    "object": ["pole","traffic light","traffic sign"],
    "nature": ["vegetation","terrain"],  # overlaps with ground by nature; that's OK
    "sky": ["sky"],
    "human": ["person","rider"],
    "vehicle": ["car","truck","bus","train","motorcycle","bicycle"]
}
# Cross-group relatedness (handcrafted priors)
RELATED = {
    ("ground","construction"): 0.45,
    ("ground","vehicle"): 0.35,
    ("human","vehicle"): 0.30,
    ("construction","object"): 0.40,
    ("nature","ground"): 0.40,
}
WITHIN = 0.82 # within-group similarity
LOW = 0.08 # unrelated baseline

def build_similarity_matrix(classes: List[str]) -> torch.Tensor:
    idx = {c:i for i,c in enumerate(classes)}
    C = len(classes)
    S = torch.full((C,C), LOW, dtype=torch.float32)
    # diag
    for i in range(C):
        S[i,i] = 1.0
    # within-group
    for members in GROUPS.values():
        ids = [idx[c] for c in members if c in idx]
        for i in ids:
            for j in ids:
                if i != j:
                    S[i,j] = max(float(S[i,j]), WITHIN)
    # cross-group relatedness
    for (g1,g2), val in RELATED.items():
        ids1 = [idx[c] for c in GROUPS[g1] if c in idx]
        ids2 = [idx[c] for c in GROUPS[g2] if c in idx]
        for i in ids1:
            for j in ids2:
                S[i,j] = max(float(S[i,j]), val)
                S[j,i] = max(float(S[j,i]), val)
    # ensure symmetry
    S = 0.5*(S + S.T)
    return S

S_matrix = build_similarity_matrix(BDD_CLASSES).to(device)
print("Similarity matrix built with context. Shape:", tuple(S_matrix.shape))


Similarity matrix built with context. Shape: (19, 19)


In [5]:
try:
    from cploss import CPLoss  # your repo
    print("Imported CPLoss from installed package.")
except Exception as e:
    print("[WARN] Could not import cploss; using minimal inline CPL.")
    class CPLoss(nn.Module):
        def __init__(self, S: torch.Tensor, w_cpl=1.0, w_ce=0.2,
                     ignore_index=255, reduction='mean', from_logits=True, eps=1e-8):
            super().__init__()
            self.register_buffer("S", S.clamp(0,1))
            C = S.size(0)
            eye = torch.eye(C, device=S.device, dtype=S.dtype)
            self.S = torch.maximum(self.S, eye)
            self.D = (1.0 - self.S)
            self.w_cpl, self.w_ce = float(w_cpl), float(w_ce)
            self.ignore_index, self.reduction = ignore_index, reduction
            self.from_logits, self.eps = from_logits, eps
            self.ce = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)

        def forward(self, logits, target):
            if self.from_logits:
                probs = F.softmax(logits.float(), dim=1)
            else:
                probs = (logits.float() + self.eps) / (logits.float().sum(dim=1, keepdim=True) + self.eps)

            N, C, H, W = probs.shape
            probs_flat = probs.permute(0,2,3,1).reshape(-1, C)
            target_flat = target.view(-1).to(logits.device)
            valid = (target_flat != self.ignore_index)

            if valid.any():
                t_idx = target_flat[valid]
                D_y = self.D.to(logits.device)[t_idx] # (M,C)
                p_val = probs_flat[valid] # (M,C)
                cpl = (p_val * D_y).sum(dim=1).mean()
            else:
                cpl = logits.sum() * 0.0 # keeps graph/device

            out = self.w_cpl * cpl
            if self.w_ce > 0:
                out = out + self.w_ce * self.ce(logits, target)
            return out


Imported CPLoss from installed package.


  @torch.cuda.amp.autocast(enabled=False)


In [6]:
class BDD10KSeg(Dataset):
    def __init__(self, img_dir, mask_dir, size: int = 512):
        self.img_paths = sorted([str(p) for p in Path(img_dir).glob("*.jpg")])
        self.mask_paths = sorted([str(p) for p in Path(mask_dir).glob("*.png")])
        assert len(self.img_paths) == len(self.mask_paths) and len(self.img_paths)>0, \
            "No data found or count mismatch between images and masks."
        self.to_tensor = transforms.Compose([
            transforms.Resize((size,size), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
        ])
        self.resize_mask = transforms.Resize((size,size), interpolation=transforms.InterpolationMode.NEAREST)

    def __len__(self): return len(self.img_paths)

    def _rgb_mask_to_ids(self, mask_rgb: np.ndarray) -> np.ndarray:
        H, W, _ = mask_rgb.shape
        ids = np.full((H,W), IGNORE_INDEX, dtype=np.uint8)
        mask_flat = mask_rgb.reshape(-1,3)
        ids_flat = ids.reshape(-1)
        table = np.array(BDD_COLORS, dtype=np.uint8)
        for cid, color in enumerate(table):
            hits = np.all(mask_flat == color, axis=1)
            ids_flat[hits] = cid
        return ids

    def __getitem__(self, i):
        img = Image.open(self.img_paths[i]).convert("RGB")
        m = Image.open(self.mask_paths[i])
        if m.mode == "P" or (m.mode == "L" and np.array(m).ndim == 2):
            mask_np = np.array(self.resize_mask(m), dtype=np.uint8)
        else:
            m = self.resize_mask(m.convert("RGB"))
            mask_np = self._rgb_mask_to_ids(np.array(m, dtype=np.uint8))
        img_t = self.to_tensor(img)
        mask_t = torch.from_numpy(mask_np).long()
        return img_t, mask_t

# Quick smoke test (first sample)
_dataset_test = BDD10KSeg(train_images_dir, train_masks_dir, size=IMAGE_SIZE)
print("Train samples:", len(_dataset_test))
x0, y0 = _dataset_test[0]
print("Sample shapes:", x0.shape, y0.shape, "unique labels (sample):", torch.unique(y0)[:10])


Train samples: 7000
Sample shapes: torch.Size([3, 512, 512]) torch.Size([512, 512]) unique labels (sample): tensor([ 0,  2,  4,  5,  6,  7,  8, 10, 11, 13])


In [7]:
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.block(x)

class UNet(nn.Module):
    def __init__(self, in_ch=3, num_classes=NUM_CLASSES, base=BASE_CH):
        super().__init__()
        self.d1 = DoubleConv(in_ch, base)
        self.d2 = DoubleConv(base, base*2)
        self.d3 = DoubleConv(base*2, base*4)
        self.d4 = DoubleConv(base*4, base*8)
        self.b  = DoubleConv(base*8, base*16)

        self.u4 = DoubleConv(base*16 + base*8, base*8)
        self.u3 = DoubleConv(base*8  + base*4, base*4)
        self.u2 = DoubleConv(base*4  + base*2, base*2)
        self.u1 = DoubleConv(base*2  + base,   base)

        self.pool = nn.MaxPool2d(2)
        self.out  = nn.Conv2d(base, num_classes, 1)

    def forward(self, x):
        c1 = self.d1(x)
        c2 = self.d2(self.pool(c1))
        c3 = self.d3(self.pool(c2))
        c4 = self.d4(self.pool(c3))
        cb = self.b(self.pool(c4))

        u4 = F.interpolate(cb, scale_factor=2, mode="bilinear", align_corners=False)
        u4 = self.u4(torch.cat([u4, c4], dim=1))
        u3 = F.interpolate(u4, scale_factor=2, mode="bilinear", align_corners=False)
        u3 = self.u3(torch.cat([u3, c3], dim=1))
        u2 = F.interpolate(u3, scale_factor=2, mode="bilinear", align_corners=False)
        u2 = self.u2(torch.cat([u2, c2], dim=1))
        u1 = F.interpolate(u2, scale_factor=2, mode="bilinear", align_corners=False)
        u1 = self.u1(torch.cat([u1, c1], dim=1))
        return self.out(u1)

model = UNet().to(device)
sum(p.numel() for p in model.parameters())/1e6


31.386003

In [8]:
def fast_hist(true, pred, num_classes):
    k = (true >= 0) & (true < num_classes)
    return torch.bincount(
        (true[k] * num_classes + pred[k]).to(torch.int64),
        minlength=num_classes**2,
    ).reshape(num_classes, num_classes)

train_set = BDD10KSeg(train_images_dir, train_masks_dir, size=IMAGE_SIZE)
val_set = BDD10KSeg(val_images_dir,   val_masks_dir,   size=IMAGE_SIZE)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=WORKERS, pin_memory=True)
val_loader = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=WORKERS, pin_memory=True)

criterion = CPLoss(S=S_matrix, alpha=0.2,
                   ignore_index=IGNORE_INDEX, reduction='mean').to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

print("Loaders ready. Batches:", len(train_loader), len(val_loader))


Loaders ready. Batches: 110 16


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


In [9]:
def validate(model, loader):
    model.eval()
    total_loss = 0.0
    hist = torch.zeros(NUM_CLASSES, NUM_CLASSES, dtype=torch.int64)
    with torch.no_grad():
        pbar = tqdm(loader, total=len(loader), desc="Validate", leave=False)
        for imgs, masks in pbar:
            imgs, masks = imgs.to(device), masks.to(device)
            logits = model(imgs)
            loss = criterion(logits, masks)
            total_loss += float(loss) * imgs.size(0)

            preds = logits.argmax(1)
            for t, p in zip(masks, preds):
                valid = (t != IGNORE_INDEX)
                hist += fast_hist(t[valid].view(-1), p[valid].view(-1), NUM_CLASSES).cpu()

            pbar.set_postfix({"batch_loss": f"{float(loss):.4f}"})
    # mIoU
    hist = hist.numpy()
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-10)
    miou = float(np.nanmean(iu))
    avg_loss = total_loss / len(loader.dataset)
    return avg_loss, miou

best_miou = -1.0
best_path = Path(OUT_DIR) / "best_model.pt"

for epoch in range(1, EPOCHS+1):
    model.train()
    running = 0.0
    pbar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=True)
    for imgs, masks in pbar:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            logits = model(imgs)
            loss = criterion(logits, masks)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running += float(loss) * imgs.size(0)
        pbar.set_postfix({"batch_loss": f"{float(loss):.4f}"})

    train_loss = running / len(train_loader.dataset)
    val_loss, val_miou = validate(model, val_loader)

    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | mIoU={val_miou:.4f}")

    # Save best by mIoU
    if val_miou > best_miou:
        best_miou = val_miou
        torch.save({
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "best_miou": best_miou,
            "classes": BDD_CLASSES
        }, best_path)
        print(f"New best mIoU {best_miou:.4f}. Saved -> {best_path}")


Epoch 001/30:   0%|          | 0/110 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=USE_AMP):


OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 596.12 MiB is free. Process 89058 has 14.15 GiB memory in use. Of the allocated memory 13.03 GiB is allocated by PyTorch, and 1.01 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)