In [None]:
!pip install gdown
!gdown --fuzzy https://drive.google.com/file/d/1hy8a2y5_iLSh1CiqMjeY68i4Xyraw9Im/view?usp=sharing
!unzip semantic_mapping_masks.zip

In [None]:
kaggle=True
!pip install segmentation_models_pytorch torchmetrics
if kaggle:
  from kaggle_secrets import UserSecretsClient
  user_secrets = UserSecretsClient()
  secret_value_0 = user_secrets.get_secret("wandb_api")
else:
  !pip install onnx
  from google.colab import userdata
  secret_value_0 = userdata.get("wandb_api")


import wandb
wandb.login(key = secret_value_0)

In [None]:
import os
import random
import time

import numpy as np
import torch
from PIL import Image
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF

import wandb
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss, FocalLoss
from torchmetrics.classification import MulticlassJaccardIndex, MulticlassF1Score, MulticlassDiceScore

#####  SET IF YOU WANT TO USE SAM MASKS #########
# Defined first: the mask-directory selection below reads it (it used to be
# assigned after its first use, which raised NameError on a fresh kernel).
use_sam = False

BASE_DIR = "."
TRAIN_IMG_DIR = os.path.join(BASE_DIR, "train", "images")
VAL_IMG_DIR = os.path.join(BASE_DIR, "val", "images")
TEST_IMG_DIR = os.path.join(BASE_DIR, "test", "images")
# Masks come from "<split>/masks-sam" when use_sam is set, otherwise "<split>/masks".
MASK_SUBDIR = "masks-sam" if use_sam else "masks"
TRAIN_MASK_DIR = os.path.join(BASE_DIR, "train", MASK_SUBDIR)
VAL_MASK_DIR = os.path.join(BASE_DIR, "val", MASK_SUBDIR)
TEST_MASK_DIR = os.path.join(BASE_DIR, "test", MASK_SUBDIR)

# Training parameters
# (width, height) as passed to PIL's resize; tensors therefore come out as H=608, W=960.
IMAGE_SIZE = (960,608)
NUM_CLASSES = 9
BATCH_SIZE = 4
NUM_EPOCHS = 12
LEARNING_RATE = 1e-4
SEED = 1234
# One of "FPN", "DL3PLUS", "SEGFORMER", "LINKNET"; any other value builds a Unet.
model_type_smp = "UNET"
ENCODER_NAME = "tu-mobilenetv4_conv_small"
ENCODER_WEIGHTS = "imagenet"
WANDB_PROJECT = "seg_sem_veh_pub"

APPENDIX_COMMENT = "-new_VER"
WANDB_RUN_NAME = model_type_smp + ENCODER_NAME + APPENDIX_COMMENT

# Seed Python, NumPy and PyTorch RNGs; on CUDA also force deterministic cuDNN
# kernels (at some speed cost) so runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Class index -> RGB colour, used when rendering index masks for visualisation.
NEW_COLOR_MAP = dict(enumerate([
    (0, 0, 0),
    (70, 130, 180),
    (70, 70, 70),
    (152, 251, 152),
    (128, 64, 128),
    (107, 142, 35),
    (220, 220, 0),
    (0, 0, 142),
    (255, 0, 0),
]))

# Per-class weights (class index -> weight) for the weighted CrossEntropyLoss.
# Each mask source has its own set — presumably derived from that source's
# class frequencies; TODO confirm how they were computed.
CLASS_WEIGHTS = dict(enumerate(
    [
        0.06560487, 0.05168178, 0.04068273,
        0.05864576, 0.03198082, 0.07781676,
        0.2369693, 0.08697386, 0.34964411,
    ]
    if use_sam
    else [
        0.62686078, 0.01697029, 0.0129806,
        0.05128192, 0.00910945, 0.02012851,
        0.08470032, 0.02736542, 0.15060272,
    ]
))

def index_to_color_mask(idx_mask):
    """Render a 2-D array of class indices as an RGB PIL image.

    Each pixel whose value is a key of NEW_COLOR_MAP is painted with that
    class's colour; any other value stays black (0, 0, 0).
    """
    height, width = idx_mask.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for class_idx in NEW_COLOR_MAP:
        rgb[idx_mask == class_idx] = NEW_COLOR_MAP[class_idx]
    return Image.fromarray(rgb)

class SegmentationTransforms:
    """Paired image/mask preprocessing for semantic segmentation.

    Training mode applies, in this order: a random horizontal flip to image and
    mask, colour jitter to the image only, and a random rotation in [-15, 15]
    degrees to both (bilinear for the image, nearest for the mask so class
    indices are never blended). The image is then converted to a tensor and
    normalised with the standard ImageNet channel mean/std; the mask becomes an
    int64 tensor.
    """

    def __init__(self, is_train=False):
        self.is_train = is_train
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        # Built once; ColorJitter still samples fresh random parameters on every call.
        self.color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                                   saturation=0.1, hue=0.1)

    def __call__(self, image, mask=None):
        """Transform ``image`` (and ``mask`` when given).

        Args:
            image: PIL image or numpy array (converted via ``Image.fromarray``).
            mask: optional PIL image or numpy array of class indices (cast to uint8).

        Returns:
            ``(image_tensor, mask_tensor)`` if a mask was given, else ``image_tensor``.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        has_mask = mask is not None
        if has_mask and isinstance(mask, np.ndarray):
            mask = Image.fromarray(mask.astype(np.uint8))

        if self.is_train:
            # Random horizontal flip, applied identically to image and mask.
            if random.random() > 0.5:
                image = TF.hflip(image)
                if has_mask:
                    mask = TF.hflip(mask)

            # Photometric jitter only touches the image.
            image = self.color_jitter(image)

            # Same small random rotation for both; NEAREST keeps mask labels discrete.
            angle = random.uniform(-15, 15)
            image = TF.rotate(image, angle, interpolation=transforms.InterpolationMode.BILINEAR)
            if has_mask:
                mask = TF.rotate(mask, angle, interpolation=transforms.InterpolationMode.NEAREST)

        image_tensor = self.normalize(TF.to_tensor(image))
        if not has_mask:
            return image_tensor
        return image_tensor, torch.from_numpy(np.array(mask)).long()

class SegmentationDataset(Dataset):
    """Image/mask pairs for semantic segmentation.

    Images and masks are matched by position after sorting each directory's
    .png/.jpg/.jpeg files by path. The two directories must therefore hold the
    same number of files, with names that sort in the same order (for example
    identical stems). Each item is ``(image_tensor, mask_tensor, image_basename)``.
    """

    def __init__(self, img_dir, mask_dir, size, is_train=False):
        """
        Args:
            img_dir: directory containing the input images.
            mask_dir: directory containing one mask per input image.
            size: (width, height) passed to PIL ``resize`` for both image and mask.
            is_train: enable the random augmentations of SegmentationTransforms.

        Raises:
            ValueError: if the number of images and masks differ. Positional
                pairing would then silently match images with the wrong masks.
        """
        self.imgs = self._list_images(img_dir)
        self.masks = self._list_images(mask_dir)
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f"{img_dir} has {len(self.imgs)} images but {mask_dir} has "
                f"{len(self.masks)} masks; they are paired by sorted position"
            )
        self.size = size
        self.tfm = SegmentationTransforms(is_train=is_train)

    @staticmethod
    def _list_images(directory):
        """Return the sorted full paths of .png/.jpg/.jpeg files in ``directory``."""
        return sorted(
            os.path.join(directory, f) for f in os.listdir(directory)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, i):
        # Context managers release the file handles as soon as the pixels are decoded
        # (matters with several DataLoader workers).
        with Image.open(self.imgs[i]) as img_file:
            img = img_file.convert("RGB").resize(self.size)
        with Image.open(self.masks[i]) as mask_file:
            # Nearest-neighbour resampling keeps mask values as exact class indices.
            mask = mask_file.resize(self.size, Image.NEAREST)

        img, mask_idx = self.tfm(img, mask)

        return img, mask_idx, os.path.basename(self.imgs[i])

# One dataset per split; only the training split gets random augmentation.
train_ds = SegmentationDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, IMAGE_SIZE, is_train=True)
val_ds = SegmentationDataset(VAL_IMG_DIR, VAL_MASK_DIR, IMAGE_SIZE, is_train=False)
test_ds = SegmentationDataset(TEST_IMG_DIR, TEST_MASK_DIR, IMAGE_SIZE, is_train=False)

# Create dataloaders (identical settings; only the training loader shuffles).
train_loader, val_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, shuffle=do_shuffle, num_workers=4, pin_memory=True)
    for split_ds, do_shuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model_type_smp -> name of the segmentation_models_pytorch class to build. Any
# unlisted value falls back to Unet. The class is looked up by name with getattr,
# so only the selected architecture has to exist in the installed smp version.
_SMP_ARCHITECTURES = {
    "FPN": "FPN",
    "DL3PLUS": "DeepLabV3Plus",
    "SEGFORMER": "Segformer",
    "LINKNET": "Linknet",
}
_arch_cls = getattr(smp, _SMP_ARCHITECTURES.get(model_type_smp, "Unet"))
model = _arch_cls(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=NUM_CLASSES,
).to(device)

# Set up loss functions, optimizer, and metrics
# CE class weights as a tensor in class-index order 0..NUM_CLASSES-1.
weights = torch.tensor([CLASS_WEIGHTS[i] for i in range(NUM_CLASSES)], dtype=torch.float32, device=device)
ce_loss_fn = CrossEntropyLoss(weight=weights)
# The train/val/test loops add these three losses with equal weight.
dice_loss_fn = DiceLoss(mode="multiclass")
focal_loss_fn = FocalLoss(mode="multiclass", alpha=0.25)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the LR once the value passed to scheduler.step() (the mean validation loss)
# has not improved for more than 2 consecutive epochs. `verbose` is deprecated in
# torch.optim.lr_scheduler (since PyTorch 2.2) and passing it triggers a warning,
# so it is omitted. Read the current LR from optimizer.param_groups[0]["lr"].
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
# Stateful torchmetrics objects. The loops call them per batch and average the
# returned per-batch values.
iou_metric = MulticlassJaccardIndex(num_classes=NUM_CLASSES).to(device)
fscore_metric = MulticlassF1Score(num_classes=NUM_CLASSES).to(device)
# NOTE(review): confirm the installed torchmetrics exports MulticlassDiceScore from
# torchmetrics.classification (imported at the top of this cell).
dice_score_metric = MulticlassDiceScore(num_classes=NUM_CLASSES).to(device)


In [None]:
# Show which device ("cuda" or "cpu") the model and batches are placed on.
print(str(device))

In [None]:
from tqdm.auto import tqdm

# Initialize Weights & Biases
wandb.init(
    project=WANDB_PROJECT,
    name=WANDB_RUN_NAME,
    entity="qbizm",
    config={
        "encoder": ENCODER_NAME,
        "learning_rate": LEARNING_RATE,
        "batch_size": BATCH_SIZE,
        "num_epochs": NUM_EPOCHS,
        "num_classes": NUM_CLASSES,
        "image_size": IMAGE_SIZE,
        # Settings that change the model or the training data. Recording them makes
        # runs filterable in W&B without parsing the run name (use_sam and the seed
        # were not recorded anywhere before).
        "model_type": model_type_smp,
        "encoder_weights": ENCODER_WEIGHTS,
        "use_sam_masks": use_sam,
        "seed": SEED,
    }
)

# Define metrics for Weights & Biases
# Batch-level and epoch-level series for train and val each get their own
# x-axis counter, so they are plotted against independent steps.
wandb.define_metric("train_step")
wandb.define_metric("val_step")
wandb.define_metric("train_epoch")
wandb.define_metric("val_epoch")
wandb.define_metric("train/batch_*", step_metric="train_step")
wandb.define_metric("val/batch_*", step_metric="val_step")
wandb.define_metric("train/epoch_*", step_metric="train_epoch")
wandb.define_metric("val/epoch_*", step_metric="val_epoch")

# Select images for visualization
# NOTE(review): viz_indices / viz_fns are not used later in this notebook. The
# random.sample call is kept because it advances Python's `random` state, which
# the later test-set visualisation sampling draws from. k is capped so a
# validation split with fewer than 5 images does not raise ValueError.
viz_indices = random.sample(range(len(val_ds)), k=min(5, len(val_ds)))
viz_fns = [os.path.basename(val_ds.imgs[i]) for i in viz_indices]
# Global counters used as W&B x-axes by the training loop.
train_step = 0
val_step = 0
train_epoch = 0
val_epoch = 0
best_val = float("inf")  # lowest summed validation loss seen so far
single_logged = False    # becomes True once the single-image inference time is logged

# Training loop
# Each epoch: train on train_loader, validate on val_loader, log batch- and
# epoch-level metrics to W&B, save a PyTorch + ONNX checkpoint when the summed
# validation loss improves, then step the plateau scheduler.
for epoch in range(NUM_EPOCHS):
    model.train()
    # Running sums over batches; divided by len(train_loader) for the epoch means.
    train_loss = train_ce = train_dice = train_focal = 0.0
    train_iou_total = train_f_total = 0.0

    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]", leave=False)
    for imgs_b, msks_b, _ in train_pbar:
        imgs_b, msks_b = imgs_b.to(device), msks_b.to(device)
        optimizer.zero_grad()
        out = model(imgs_b)
        # Total loss = class-weighted CE + Dice loss + Focal loss (equal weights).
        ce = ce_loss_fn(out, msks_b)
        dloss = dice_loss_fn(out, msks_b)
        floss = focal_loss_fn(out, msks_b)
        loss = ce + dloss + floss
        loss.backward()
        optimizer.step()
        # Hard predictions (argmax over the class channel, dim 1) for the metrics.
        # These are the per-batch values returned by the metric objects. The epoch
        # figures below are means of these per-batch values, not dataset-level scores.
        preds = out.argmax(1)
        biou = iou_metric(preds, msks_b).item()
        bf = fscore_metric(preds, msks_b).item()

        train_loss += loss.item()
        train_ce += ce.item()
        train_dice += dloss.item()
        train_focal += floss.item()
        train_iou_total += biou
        train_f_total += bf

        train_pbar.set_postfix(loss=f"{loss.item():.4f}", iou=f"{biou:.4f}")

        train_step += 1
        # "batch_dice" is logged as 1 - DiceLoss (a score, higher is better).
        wandb.log({
            "train_step": train_step,
            "train/batch_loss": loss.item(),
            "train/batch_ce": ce.item(),
            "train/batch_dice": 1.0 - dloss.item(),
            "train/batch_focal": floss.item(),
            "train/batch_iou": biou,
            "train/batch_fscore": bf
        })

    train_epoch += 1
    wandb.log({
        "train_epoch": train_epoch,
        "train/epoch_loss": train_loss / len(train_loader),
        "train/epoch_ce": train_ce / len(train_loader),
        "train/epoch_dice": 1.0 - (train_dice / len(train_loader)),
        "train/epoch_focal": train_focal / len(train_loader),
        "train/epoch_iou": train_iou_total / len(train_loader),
        "train/epoch_fscore": train_f_total / len(train_loader)
    })

    # Validation
    model.eval()
    val_loss = val_ce = val_dice = val_focal = 0.0
    val_iou_total = val_f_total = 0.0

    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Valid]", leave=False)
        for imgs_b, msks_b, fns in val_pbar:
            imgs_b, msks_b = imgs_b.to(device), msks_b.to(device)
            # Once per run (first validation batch of the first epoch): time a
            # batch-of-one forward pass. Synchronize on CUDA so the queued GPU work
            # is included in the measurement.
            if not single_logged:
                t0 = time.time()
                _ = model(imgs_b[0].unsqueeze(0))
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                wandb.log({"val/single_inference_time": time.time() - t0})
                single_logged = True

            out = model(imgs_b)
            ce = ce_loss_fn(out, msks_b)
            dloss = dice_loss_fn(out, msks_b)
            floss = focal_loss_fn(out, msks_b)
            loss = ce + dloss + floss
            preds = out.argmax(1)
            biou = iou_metric(preds, msks_b).item()
            bf = fscore_metric(preds, msks_b).item()

            val_loss += loss.item()
            val_ce += ce.item()
            val_dice += dloss.item()
            val_focal += floss.item()
            val_iou_total += biou
            val_f_total += bf

            val_pbar.set_postfix(loss=f"{loss.item():.4f}", iou=f"{biou:.4f}")

            val_step += 1
            wandb.log({
                "val_step": val_step,
                "val/batch_loss": loss.item(),
                "val/batch_ce": ce.item(),
                "val/batch_dice": 1.0 - dloss.item(),
                "val/batch_focal": floss.item(),
                "val/batch_iou": biou,
                "val/batch_fscore": bf
            })

    val_epoch += 1
    wandb.log({
        "val_epoch": val_epoch,
        "val/epoch_loss": val_loss / len(val_loader),
        "val/epoch_ce": val_ce / len(val_loader),
        "val/epoch_dice": 1.0 - (val_dice / len(val_loader)),
        "val/epoch_focal": val_focal / len(val_loader),
        "val/epoch_iou": val_iou_total / len(val_loader),
        "val/epoch_fscore": val_f_total / len(val_loader)
    })

    # Save best model
    # val_loss and best_val are sums over batches (not means). The comparison is
    # still consistent because len(val_loader) does not change between epochs.
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), "best_model.pth")
        wandb.save("best_model.pth")
        # IMAGE_SIZE is (width, height) for PIL, so the NCHW dummy input is
        # (1, 3, IMAGE_SIZE[1], IMAGE_SIZE[0]).
        # NOTE(review): onnx is only pip-installed in the non-Kaggle branch of the
        # setup cell; confirm torch.onnx.export works in the Kaggle environment.
        dummy_input = torch.randn(1, 3, IMAGE_SIZE[1], IMAGE_SIZE[0], device=device)
        torch.onnx.export(model, dummy_input, "best_model.onnx", input_names=["input"], output_names=["output"], opset_version=11)
        wandb.save("best_model.onnx")

    # The plateau scheduler monitors the mean validation loss.
    scheduler.step(val_loss / len(val_loader))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Drop references to the last batch's tensors before the next epoch.
    del imgs_b, msks_b, out, preds

# ─── Test set evaluation ──────────────────────────────────────────────────────
# Score the checkpoint saved as best_model.pth (lowest summed validation loss),
# not whatever weights the final epoch left in `model`. That way the logged test
# metrics describe the uploaded artifact. best_val is still +inf only if no
# checkpoint was written during this run; the in-memory weights are then used.
# Set EVAL_BEST_CHECKPOINT = False to always evaluate the final-epoch weights.
EVAL_BEST_CHECKPOINT = True
if EVAL_BEST_CHECKPOINT and best_val < float("inf"):
    model.load_state_dict(torch.load("best_model.pth", map_location=device, weights_only=True))

model.eval()
# Running sums over batches; averaged by the number of batches below.
test_loss = test_ce = test_focal = 0.0
test_iou = test_f1 = test_dice = 0.0

with torch.no_grad():
    test_pbar = tqdm(test_loader, desc="Evaluating on test set")
    for imgs_b, msks_b, _ in test_pbar:
        imgs_b, msks_b = imgs_b.to(device), msks_b.to(device)
        out    = model(imgs_b)
        ce     = ce_loss_fn(out, msks_b)
        dloss  = dice_loss_fn(out, msks_b)
        floss  = focal_loss_fn(out, msks_b)
        loss   = ce + dloss + floss
        preds  = out.argmax(1)

        biou   = iou_metric(preds, msks_b).item()
        bf     = fscore_metric(preds, msks_b).item()
        bdice  = dice_score_metric(preds, msks_b).item()

        test_loss  += loss.item()
        test_ce    += ce.item()
        test_focal += floss.item()
        test_iou   += biou
        test_f1    += bf
        test_dice  += bdice

        test_pbar.set_postfix(loss=f"{loss.item():.4f}", iou=f"{biou:.4f}", dice=f"{bdice:.4f}")

# Compute averages (mean of per-batch values, matching the train/val logging).
n = len(test_loader)
test_metrics = {
    "loss":   test_loss / n,
    "ce":     test_ce / n,
    "focal":  test_focal / n,
    "iou":    test_iou / n,
    "fscore": test_f1 / n,
    "dice":   test_dice / n,
}
# Logged as "test/<name>" ...
wandb.log({f"test/{name}": value for name, value in test_metrics.items()})

# Summary
# ... and mirrored into the run summary as "test_<name>".
wandb.run.summary.update({f"test_{name}": value for name, value in test_metrics.items()})


# Generate a few visualizations from test set
# Log input / ground truth / prediction for up to 5 random test images. k is capped
# so a test split with fewer than 5 images does not make random.sample raise.
viz_test_indices = random.sample(range(len(test_ds)), k=min(5, len(test_ds)))
viz_test_fns = [os.path.basename(test_ds.imgs[i]) for i in viz_test_indices]

with torch.no_grad():
    for i in viz_test_indices:
        img, mask, fn = test_ds[i]
        # Only the image needs the device; the mask is rendered on the CPU directly.
        pred = model(img.unsqueeze(0).to(device)).argmax(1)

        # Undo the Normalize(mean, std) from SegmentationTransforms for display.
        img_np = img.permute(1, 2, 0).numpy() * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)

        gt_col = np.array(index_to_color_mask(mask.numpy())) / 255.0
        pd_col = np.array(index_to_color_mask(pred[0].cpu().numpy())) / 255.0

        wandb.log({
            f"test_img/{fn}": wandb.Image(img_np),
            f"test_gt/{fn}": wandb.Image(gt_col),
            f"test_pred/{fn}": wandb.Image(pd_col)
        })

# Release cached GPU memory. Guarded like the per-epoch cleanup because
# reset_peak_memory_stats() raises on CPU-only runs, which previously aborted the
# notebook before the summary was written and wandb.finish() ran.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# best_val and the final-epoch losses are sums over batches; dividing by the batch
# count gives per-batch means, matching the logged epoch metrics.
wandb.run.summary["best_val_loss"] = best_val / len(val_loader)
wandb.run.summary["final_train_loss"] = train_loss / len(train_loader)
wandb.run.summary["final_val_loss"] = val_loss / len(val_loader)
wandb.finish()