# DeepLabv3-ResNet50 Baseline
Build and train a semantic segmentation model using DeepLabv3 with a ResNet-50 backbone on the dermoscopy dataset.

In [None]:
import os
import time
import copy
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.cuda.amp as amp
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.models.segmentation import (
    deeplabv3_resnet50,
    DeepLabV3_ResNet50_Weights,
 )
from torchvision.models import ResNet50_Weights

# Seed every RNG this notebook draws from so that runs are repeatable.
# torch drives weight initialisation and DataLoader shuffling, while the
# training augmentations in SegmentationTransform use Python's `random`
# module, which was previously left unseeded.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

In [None]:
# Configuration
# The dataset root defaults to the original author's local checkout. Set the
# SEGMENTATION_DATA_ROOT environment variable to point the notebook at another
# location without editing this cell.
data_root = Path(
    os.environ.get(
        "SEGMENTATION_DATA_ROOT",
        "/Users/enricotazzer/Desktop/multi-task-learning-for-classification-and-segmentation-of-skin-lesions/dataset/segmentation",
    )
).expanduser()
train_img_dir = data_root / "train" / "input"
train_mask_dir = data_root / "train" / "ground_truth"
val_img_dir = data_root / "val" / "input"
val_mask_dir = data_root / "val" / "ground_truth"
test_img_dir = data_root / "test" / "input"
test_mask_dir = data_root / "test" / "ground_truth"

# Only the train/val splits are mandatory; the test split is optional and is
# checked for separately when the datasets are built.
required_dirs = [train_img_dir, train_mask_dir, val_img_dir, val_mask_dir]
for path in required_dirs:
    if not path.exists():
        raise FileNotFoundError(f"Missing required directory: {path}")

classes = ["background", "lesion"]
num_classes = len(classes)
# Label value that the loss and the metric helpers skip.
ignore_index = 255
image_size = (256, 256)

batch_size = 4
num_epochs = 25
learning_rate = 5e-4
weight_decay = 1e-4

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Training on {device}")

Training on mps


In [None]:
# Data pipeline
# Take the normalisation statistics from the pretrained weights' metadata when
# it provides them; otherwise fall back to the literal defaults below.
weights = DeepLabV3_ResNet50_Weights.DEFAULT
normalize_mean = tuple(map(float, weights.meta.get("mean", (0.485, 0.456, 0.406))))
normalize_std = tuple(map(float, weights.meta.get("std", (0.229, 0.224, 0.225))))

class SegmentationTransform:
    """Joint preprocessing for an (image, mask) pair.

    In training mode the pair is randomly flipped and rotated with identical
    geometry. In every mode both are resized to ``image_size``; the image is
    converted to a normalised tensor and the mask is binarised (any value
    greater than zero becomes 1) into an ``int64`` tensor.
    """

    def __init__(self, is_train: bool, image_size: Tuple[int, int], mean: Tuple[float, ...], std: Tuple[float, ...]):
        self.is_train, self.image_size = is_train, image_size
        self.mean, self.std = mean, std

    def _augment(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
        """Apply the random training augmentations to image and mask together."""
        # Horizontal flip, probability 0.5.
        if random.random() < 0.5:
            image, mask = F.hflip(image), F.hflip(mask)
        # Vertical flip, probability 0.5.
        if random.random() < 0.5:
            image, mask = F.vflip(image), F.vflip(mask)
        # Rotation by up to +/-15 degrees, probability 0.3. The mask uses
        # nearest-neighbour sampling so no intermediate label values appear.
        if random.random() < 0.3:
            degrees = random.uniform(-15.0, 15.0)
            image = F.rotate(image, degrees, interpolation=transforms.InterpolationMode.BILINEAR)
            mask = F.rotate(mask, degrees, interpolation=transforms.InterpolationMode.NEAREST)
        return image, mask

    def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(normalised image tensor, binary int64 mask tensor)``."""
        if self.is_train:
            image, mask = self._augment(image, mask)

        resized_image = F.resize(image, self.image_size, interpolation=transforms.InterpolationMode.BILINEAR)
        resized_mask = F.resize(mask, self.image_size, interpolation=transforms.InterpolationMode.NEAREST)

        image_tensor = F.normalize(F.to_tensor(resized_image), mean=self.mean, std=self.std)

        raw_labels = torch.from_numpy(np.array(resized_mask, dtype=np.int64))
        # Any non-zero mask value counts as foreground.
        binary_mask = (raw_labels > 0).to(torch.int64)
        return image_tensor, binary_mask

class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs stored in two parallel directories.

    An image ``<stem>.<ext>`` is paired with the mask named
    ``<stem>_segmentation.<mask ext>``. The pairing is resolved once, at
    construction time.

    Raises:
        FileNotFoundError: if either directory is missing, or an image has no
            matching mask file.
        RuntimeError: if no image/mask pair is found.
    """

    def __init__(self, image_dir: Path, mask_dir: Path, transform: SegmentationTransform):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.transform = transform
        # Suffixes are compared lower-cased; mask extensions are tried in order.
        self.image_exts = {".jpg", ".jpeg", ".png", ".bmp"}
        self.mask_exts = [".png", ".jpg", ".jpeg", ".bmp"]

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory does not exist: {self.image_dir}")
        if not self.mask_dir.exists():
            raise FileNotFoundError(f"Mask directory does not exist: {self.mask_dir}")

        self.samples = self._gather_samples()
        if not self.samples:
            raise RuntimeError(f"No image/mask pairs found in {self.image_dir} and {self.mask_dir}")

    def _gather_samples(self) -> List[Tuple[Path, Path]]:
        """Return the (image path, mask path) pairs in sorted image order."""
        return [
            (entry, self._find_corresponding_mask(entry))
            for entry in sorted(self.image_dir.iterdir())
            if entry.suffix.lower() in self.image_exts
        ]

    def _find_corresponding_mask(self, image_path: Path) -> Path:
        """Locate the mask of ``image_path``, trying each mask extension in turn."""
        mask_stem = f"{image_path.stem}_segmentation"
        existing = (
            self.mask_dir / f"{mask_stem}{ext}"
            for ext in self.mask_exts
            if (self.mask_dir / f"{mask_stem}{ext}").exists()
        )
        found = next(existing, None)
        if found is None:
            raise FileNotFoundError(f"Mask for {image_path.name} not found in {self.mask_dir}")
        return found

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load pair ``index``, convert to RGB / greyscale and apply the transform."""
        image_path, mask_path = self.samples[index]
        # convert() returns a new image, so it stays usable after the file closes.
        with Image.open(image_path) as image_file:
            rgb_image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            grey_mask = mask_file.convert("L")
        image_tensor, mask_tensor = self.transform(rgb_image, grey_mask)
        return image_tensor, mask_tensor

# Training uses the augmenting transform; validation and test share the
# deterministic one.
train_transform = SegmentationTransform(is_train=True, image_size=image_size, mean=normalize_mean, std=normalize_std)
eval_transform = SegmentationTransform(is_train=False, image_size=image_size, mean=normalize_mean, std=normalize_std)

train_dataset = SegmentationDataset(train_img_dir, train_mask_dir, transform=train_transform)
val_dataset = SegmentationDataset(val_img_dir, val_mask_dir, transform=eval_transform)
# The test split is optional: it is built only when both of its folders exist.
if test_img_dir.exists() and test_mask_dir.exists():
    test_dataset = SegmentationDataset(test_img_dir, test_mask_dir, transform=eval_transform)
else:
    test_dataset = None

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = None
if test_dataset is not None:
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Train samples: {len(train_dataset)} | Val samples: {len(val_dataset)}")
if test_dataset is not None:
    print(f"Test samples: {len(test_dataset)}")

Train samples: 2594 | Val samples: 66
Test samples: 1000


In [None]:
# Model setup
model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)

# Replace the last 1x1 convolution of each head so it emits `num_classes`
# logits. Layers are created first and initialised afterwards, in this order,
# so the seeded random stream is consumed in a fixed sequence.
has_aux_head = model.aux_classifier is not None
model.classifier[-1] = nn.Conv2d(256, num_classes, kernel_size=1)
if has_aux_head:
    model.aux_classifier[-1] = nn.Conv2d(256, num_classes, kernel_size=1)

main_head = model.classifier[-1]
nn.init.xavier_uniform_(main_head.weight)
nn.init.zeros_(main_head.bias)
if has_aux_head:
    aux_head = model.aux_classifier[-1]
    if getattr(aux_head, "bias", None) is not None:
        nn.init.xavier_uniform_(aux_head.weight)
        nn.init.zeros_(aux_head.bias)

# Freeze the backbone stem and its first two stages; everything else trains.
freeze_prefixes = ("conv1", "bn1", "layer1", "layer2")
for name, param in model.backbone.named_parameters():
    if name.partition(".")[0] in freeze_prefixes:
        param.requires_grad_(False)

model.to(device)

# Mixed precision (and therefore gradient scaling) is enabled on CUDA only.
use_amp = device.type == "cuda"
scaler = amp.GradScaler(enabled=use_amp)

criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
trainable_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.AdamW(trainable_parameters, lr=learning_rate, weight_decay=weight_decay)
# Cosine decay from `learning_rate` down to 10% of it over `num_epochs` steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1)

trainable_params = sum(tensor.numel() for tensor in trainable_parameters)
print(f"Trainable parameters: {trainable_params:,}")

Trainable parameters: 39,633,986


In [None]:
# Training utilities
def update_confusion_matrix(confmat: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Accumulate a batch of predictions into a confusion matrix.

    Rows index the ground-truth class and columns the predicted class.

    Args:
        confmat: ``(num_classes, num_classes)`` running matrix; updated in place.
        preds: predicted class indices, same shape as ``targets``.
        targets: ground-truth class indices; entries equal to the module-level
            ``ignore_index`` are skipped.
        num_classes: number of classes.

    Returns:
        The same ``confmat`` tensor, after accumulation.
    """
    # Keep only entries that map to a real cell of the matrix. Besides the
    # ignore label, any label or prediction outside [0, num_classes) is dropped:
    # it would push the flattened index past num_classes**2, make bincount
    # return extra bins and break the reshape below.
    valid = (targets != ignore_index) & (targets >= 0) & (targets < num_classes)
    valid &= (preds >= 0) & (preds < num_classes)
    preds = preds[valid]
    targets = targets[valid]
    if preds.numel() == 0:
        return confmat
    # Flatten (target, pred) into one index so a single bincount fills every cell.
    indices = targets.to(torch.int64) * num_classes + preds.to(torch.int64)
    counts = torch.bincount(indices, minlength=num_classes ** 2)
    confmat += counts.reshape(num_classes, num_classes).to(confmat.dtype)
    return confmat

def compute_iou(confmat: torch.Tensor) -> torch.Tensor:
    """Return the per-class IoU for a confusion matrix.

    Rows are treated as ground truth and columns as predictions. The union is
    clamped to at least 1, so a class that appears in neither gets an IoU of 0
    instead of a division by zero.
    """
    true_positive = torch.diag(confmat)
    per_class_truth = confmat.sum(dim=1)
    per_class_pred = confmat.sum(dim=0)
    denominator = (per_class_truth + per_class_pred - true_positive).clamp(min=1.0)
    return true_positive / denominator

def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: Optional[amp.GradScaler] = None,
    use_amp: bool = False,
) -> Tuple[float, float, float]:
    """Run one optimisation pass over ``dataloader``.

    Batches are moved to the module-level ``device``. Gradients are clipped to
    a global norm of 1.0 before every optimiser step. When ``use_amp`` is set
    and a ``scaler`` is supplied, the backward pass and the step go through
    the gradient scaler.

    Returns:
        ``(loss, pixel_accuracy, seconds)``: the batch losses weighted by batch
        size and divided by the dataset length, the accuracy over pixels whose
        label is not ``ignore_index``, and the elapsed wall-clock time.
    """
    model.train()
    scaled_step = scaler is not None and use_amp
    loss_sum = 0.0
    n_correct = 0
    n_valid = 0
    started_at = time.time()

    for images, masks in dataloader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with amp.autocast(enabled=use_amp):
            outputs = model(images)["out"]
            loss = criterion(outputs, masks)

        if scaled_step:
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true gradients.
            scaler.unscale_(optimizer)
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        if scaled_step:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()

        loss_sum += loss.item() * images.size(0)
        # Accuracy is measured on the logits produced before this step.
        predictions = outputs.argmax(dim=1)
        keep = masks != ignore_index
        n_correct += ((predictions == masks) & keep).sum().item()
        n_valid += keep.sum().item()

    seconds = time.time() - started_at
    mean_loss = loss_sum / max(len(dataloader.dataset), 1)
    accuracy = n_correct / max(n_valid, 1)
    return mean_loss, accuracy, seconds

@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    use_amp: bool = False,
) -> Tuple[float, float, float, torch.Tensor]:
    """Score ``model`` on ``dataloader`` without computing gradients.

    Returns:
        ``(loss, pixel_accuracy, mean_iou, per_class_iou)``. The loss is the
        batch losses weighted by batch size and divided by the dataset length.
        Accuracy covers pixels whose label is not ``ignore_index``. When there
        is more than one class, ``mean_iou`` averages the per-class IoU over
        classes 1 and up, leaving out class 0.
    """
    model.eval()
    # The confusion matrix is accumulated on the CPU in float64.
    confmat = torch.zeros((num_classes, num_classes), dtype=torch.float64)
    loss_sum = 0.0
    n_correct = 0
    n_valid = 0

    for images, masks in dataloader:
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        with amp.autocast(enabled=use_amp):
            outputs = model(images)["out"]
            loss = criterion(outputs, masks)
        loss_sum += loss.item() * images.size(0)

        predictions = outputs.argmax(dim=1)
        keep = masks != ignore_index
        n_correct += ((predictions == masks) & keep).sum().item()
        n_valid += keep.sum().item()

        confmat = update_confusion_matrix(confmat, predictions.cpu(), masks.cpu(), num_classes)

    mean_loss = loss_sum / max(len(dataloader.dataset), 1)
    accuracy = n_correct / max(n_valid, 1)
    per_class_iou = compute_iou(confmat)
    scored_classes = per_class_iou[1:] if num_classes > 1 else per_class_iou
    mean_iou = scored_classes.mean().item()
    return mean_loss, accuracy, mean_iou, per_class_iou

In [None]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    num_epochs: int = 25,
    checkpoint_dir: str = "artifacts/segmentation",
    scaler: Optional[amp.GradScaler] = None,
    use_amp: bool = False,
) -> Tuple[nn.Module, Dict[str, List[float]]]:
    """Train ``model`` for ``num_epochs`` epochs, keeping the best weights.

    Each epoch runs ``train_one_epoch`` followed by ``evaluate`` on the
    validation loader, then steps the scheduler if one is given. Whenever the
    validation mIoU strictly exceeds the best value so far (starting at 0.0),
    the weights are snapshotted and written to
    ``<checkpoint_dir>/deeplabv3_resnet50_best.pt`` together with the epoch,
    class names, image size and per-class IoU.

    Args:
        model: network to optimise; the best weights are loaded back into it.
        train_loader: batches used for optimisation.
        val_loader: batches used for model selection.
        criterion: loss function.
        optimizer: optimiser updating the model parameters.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.
        checkpoint_dir: directory for the best checkpoint; created if missing.
        scaler: optional gradient scaler passed through to the training step.
        use_amp: whether to run forward passes under autocast.

    Returns:
        ``(model, history)``, where ``history`` maps metric names to one entry
        per epoch. ``"val_per_class_iou"`` stores a list of per-class values
        per epoch rather than a single float.
    """
    best_state = copy.deepcopy(model.state_dict())
    best_miou = 0.0
    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_pixel_acc": [],
        "val_loss": [],
        "val_pixel_acc": [],
        "val_mIoU": [],
        "val_per_class_iou": [],
    }

    checkpoint_path = Path(checkpoint_dir)
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    best_ckpt = checkpoint_path / "deeplabv3_resnet50_best.pt"

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print("-" * 20)

        train_loss, train_acc, train_time = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler=scaler, use_amp=use_amp
        )
        val_loss, val_acc, val_miou, val_per_class_iou = evaluate(
            model, val_loader, criterion, use_amp=use_amp
        )
        if scheduler is not None:
            scheduler.step()

        history["train_loss"].append(train_loss)
        history["train_pixel_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_pixel_acc"].append(val_acc)
        history["val_mIoU"].append(val_miou)
        history["val_per_class_iou"].append(val_per_class_iou.tolist())

        print(
            f"train loss: {train_loss:.4f} | pixel acc: {train_acc:.4f} | time: {train_time:.1f}s"
        )
        print(
            f"val   loss: {val_loss:.4f} | pixel acc: {val_acc:.4f} | mIoU: {val_miou:.4f}"
        )

        if val_miou > best_miou:
            best_miou = val_miou
            # Deep-copy so later optimiser steps cannot alter the snapshot.
            best_state = copy.deepcopy(model.state_dict())
            torch.save({
                "model_state_dict": best_state,
                "val_mIoU": best_miou,
                "epoch": epoch + 1,
                "classes": classes,
                "image_size": image_size,
                "per_class_iou": val_per_class_iou.tolist(),
            }, best_ckpt)
            # The check mark was previously mojibake (UTF-8 read as cp1252).
            print(f"\n✅ Saved new best checkpoint to {best_ckpt}\n")

    print(f"Best validation mIoU: {best_miou:.4f}")
    # Return the model with its best-scoring weights, not the last epoch's.
    model.load_state_dict(best_state)
    return model, history

In [None]:
# Train the model
if __name__ == "__main__":
    # Refuse to start if either split came out empty.
    if not (len(train_dataset) and len(val_dataset)):
        raise RuntimeError("Training/validation datasets are empty. Check the data directory structure.")

    training_kwargs = dict(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=num_epochs,
        scaler=scaler,
        use_amp=use_amp,
    )
    trained_model, history = train_model(**training_kwargs)

Epoch 1/50
--------------------


Traceback (most recent call last):
  File [35m"<string>"[0m, line [35m1[0m, in [35m<module>[0m
    from multiprocessing.spawn import spawn_main; [31mspawn_main[0m[1;31m(tracker_fd=77, pipe_handle=91)[0m
                                                  [31m~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m122[0m, in [35mspawn_main[0m
    exitcode = _main(fd, parent_sentinel)
  File [35m"/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m132[0m, in [35m_main[0m
    self = reduction.pickle.load(from_parent)
[1;35mAttributeError[0m: [35mCan't get attribute 'SegmentationDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>[0m


KeyboardInterrupt: 

In [None]:
# Learning curves
# Plots loss, pixel accuracy and validation mIoU against the epoch number,
# provided the training cell has already produced `history`.
if 'history' not in locals():
    print("Run the training cell first to generate history.")
else:
    import matplotlib.pyplot as plt

    epochs = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 3, figsize=(18, 4))

    # The first two panels each compare a training curve with a validation one.
    paired_panels = (
        (axes[0], 'train_loss', 'val_loss', 'Loss', 'Cross-Entropy'),
        (axes[1], 'train_pixel_acc', 'val_pixel_acc', 'Pixel Accuracy', 'Accuracy'),
    )
    for panel, train_key, val_key, title, y_label in paired_panels:
        panel.plot(epochs, history[train_key], label='Train')
        panel.plot(epochs, history[val_key], label='Val')
        panel.set_title(title)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.legend()

    # mIoU is only computed on the validation split, so it gets a single curve.
    miou_panel = axes[2]
    miou_panel.plot(epochs, history['val_mIoU'], label='Val mIoU', color='tab:green')
    miou_panel.set_title('Validation mIoU')
    miou_panel.set_xlabel('Epoch')
    miou_panel.set_ylabel('mIoU')
    miou_panel.legend()

    plt.tight_layout()

In [None]:
# Validation metrics
# Re-scores the returned (best) model on the validation split and prints the
# aggregate numbers followed by one IoU line per class.
if 'trained_model' not in locals():
    print("Train the model before running this cell.")
else:
    val_loss, val_acc, val_miou, val_per_class = evaluate(trained_model, val_loader, criterion, use_amp=use_amp)
    print(f"Validation loss: {val_loss:.4f}")
    print(f"Validation pixel acc : {val_acc:.4f}")
    print(f"Validation mIoU      : {val_miou:.4f}")
    val_iou_by_class = zip(classes, val_per_class.tolist())
    for cls_name, cls_iou in val_iou_by_class:
        print(f"  IoU({cls_name}): {cls_iou:.4f}")

In [None]:
# Test set evaluation
# Runs only when a test loader was built and a trained model is available.
if test_loader is None or 'trained_model' not in locals():
    print("No test set detected or model has not been trained yet.")
else:
    test_loss, test_acc, test_miou, test_per_class = evaluate(trained_model, test_loader, criterion, use_amp=use_amp)
    print(f"Test loss      : {test_loss:.4f}")
    print(f"Test pixel acc : {test_acc:.4f}")
    print(f"Test mIoU      : {test_miou:.4f}")
    test_iou_by_class = zip(classes, test_per_class.tolist())
    for cls_name, cls_iou in test_iou_by_class:
        print(f"  IoU({cls_name}): {cls_iou:.4f}")

In [None]:
# Qualitative inspection
def colorize_mask(mask: torch.Tensor) -> np.ndarray:
    """Map a label mask to an RGB ``uint8`` image (class 0 black, class 1 red).

    Labels outside the palette are clipped into it, so negative values render
    as class 0 and anything above 1 renders as class 1.
    """
    colors = np.array([[0, 0, 0], [255, 0, 0]], dtype=np.uint8)
    labels = mask.cpu().numpy().astype(np.int64)
    highest_label = colors.shape[0] - 1
    return colors[labels.clip(0, highest_label)]

# Show input / ground truth / prediction side by side for up to three samples
# of the first validation batch.
if 'trained_model' not in locals():
    print("Train the model to visualize qualitative predictions.")
else:
    trained_model.eval()
    sample_batch = next(iter(val_loader))
    images, masks = sample_batch
    with torch.no_grad():
        outputs = trained_model(images.to(device))["out"]
        preds = outputs.argmax(dim=1).cpu()

    import matplotlib.pyplot as plt

    # Per-channel statistics shaped (C, 1, 1) to undo the input normalisation.
    std_tensor = torch.tensor(normalize_std, dtype=images.dtype).view(-1, 1, 1)
    mean_tensor = torch.tensor(normalize_mean, dtype=images.dtype).view(-1, 1, 1)

    batch_size_vis = min(3, images.size(0))
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    fig, axes = plt.subplots(batch_size_vis, 3, figsize=(12, 4 * batch_size_vis), squeeze=False)
    for idx in range(batch_size_vis):
        img = images[idx].cpu() * std_tensor + mean_tensor
        img = img.permute(1, 2, 0).clamp(0, 1).numpy()
        row_panels = (
            (img, 'Input'),
            (colorize_mask(masks[idx]), 'Ground Truth'),
            (colorize_mask(preds[idx]), 'Prediction'),
        )
        for col, (picture, title) in enumerate(row_panels):
            axes[idx, col].imshow(picture)
            axes[idx, col].set_title(title)
            axes[idx, col].axis('off')

    plt.tight_layout()