# Multi-Task Skin Lesion Diagnostics
This notebook sketches a PyTorch implementation of the architecture described in *Multi-Task Classification and Segmentation for Explicable Capsule Endoscopy Diagnostics*, adapted here to ISIC skin-lesion images. The model shares a ResNet-50 encoder across tasks and uses separate heads for image-level lesion classification and pixel-level lesion segmentation, coupled by cross-fusion modules.

In [None]:
# Core libraries, torch modules, and torchvision utilities used throughout the notebook
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

import random

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import ResNet50_Weights, resnet50
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF

# Seed every RNG the notebook draws from so augmentations and weight initialisation are
# reproducible: torch (weight init, torchvision random transforms, dropout), Python's
# `random` (the joint image/mask augmentations in apply_segmentation_transforms), and NumPy.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# Model configurations

In [None]:
# Configuration
@dataclass
class Config:
    """Central hyper-parameters mirroring the settings reported in the paper.

    Path fields may be given as ``str`` or ``Path``; ``__post_init__`` normalises them
    to ``Path``. When ``mixed_dataset_root`` is omitted it defaults to
    ``project_root / mixed_dataset_name``.

    Raises:
        FileNotFoundError: from ``__post_init__`` if the mixed dataset directory does not exist.
    """
    project_root: Path = Path("/kaggle/input/skin-cancer-multitask-learning-mixed-data")
    mixed_dataset_name: str = "multitasked_dataset_mixed"
    image_size: Tuple[int, int] = (256, 256)  # input resolution for both tasks
    batch_size: int = 8  # classification mini-batch size
    segmentation_batch_size: int = 4  # segmentation mini-batch size (decoder is heavier)
    base_learning_rate: float = 2e-4
    weight_decay: float = 1e-4
    max_epochs: int = 20
    alpha: float = 0.1  # fusion-penalty weight α used in the paper for regularisation
    classification_loss_weight: float = 0.5  # balance between classification and segmentation losses
    segmentation_loss_weight: float = 0.5
    class_names: Tuple[str, ...] = ("MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC")
    ignore_index: int = 255  # optional label to mask out invalid pixels
    min_learning_rate: float = 1e-5
    scheduler_period: int = 10  # cosine-annealing period (epochs)
    mixed_precision: bool = True  # enable AMP when CUDA is available
    segmentation_positive_weight: float = 1.5  # kept for potential focal weighting experiments
    grad_clip_norm: Optional[float] = 5.0  # guard against exploding gradients
    imagenet_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    imagenet_std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
    segmentation_suffix: str = "_segmentation"
    mixed_dataset_root: Optional[Path] = None

    def __post_init__(self) -> None:
        # Coerce to Path so string inputs (e.g. from argv/env) support `/` joins and `.exists()`.
        self.project_root = Path(self.project_root)
        if self.mixed_dataset_root is None:
            self.mixed_dataset_root = self.project_root / self.mixed_dataset_name
        else:
            self.mixed_dataset_root = Path(self.mixed_dataset_root)
        if not self.mixed_dataset_root.exists():
            raise FileNotFoundError(
                f"Expected mixed dataset directory at {self.mixed_dataset_root}, but it was not found."
            )

    def manifest_path(self, subset: str) -> Path:
        """Return ``<mixed_dataset_root>/<subset>/manifest.csv``.

        Raises:
            FileNotFoundError: if that manifest file does not exist.
        """
        path = self.mixed_dataset_root / subset / "manifest.csv"
        if not path.exists():
            raise FileNotFoundError(f"Manifest for subset '{subset}' not found at {path}")
        return path

    def resolve_mixed_path(self, relative_path: str) -> Path:
        """Resolve a manifest path (stored relative to the mixed dataset root) to an absolute Path.

        Raises:
            FileNotFoundError: if the resolved path does not exist.
        """
        abs_path = self.mixed_dataset_root / Path(relative_path)
        if not abs_path.exists():
            raise FileNotFoundError(f"Resolved path does not exist: {abs_path}")
        return abs_path
    
# Global configuration (validates that the dataset directory exists) and compute device.
cfg = Config()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Image utilities

In [None]:
# Data transforms
def build_classification_transform(cfg: Config, train: bool) -> transforms.Compose:
    """Return the preprocessing pipeline for the image-level classifier.

    Training: random resized crop (80-100% area), horizontal/vertical flips and mild
    colour jitter. Evaluation: deterministic resize to ``cfg.image_size`` followed by a
    centre crop. Both pipelines finish with tensor conversion and ImageNet normalisation.
    """
    # Shared tail: PIL image -> float tensor in [0, 1] -> ImageNet-normalised tensor
    to_normalised_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg.imagenet_mean, std=cfg.imagenet_std),
    ]
    if not train:
        return transforms.Compose(
            [
                transforms.Resize(cfg.image_size, interpolation=InterpolationMode.BILINEAR),
                transforms.CenterCrop(cfg.image_size),
                *to_normalised_tensor,
            ]
        )
    return transforms.Compose(
        [
            # Random crops and flips mimic the appearance variability described in the paper
            transforms.RandomResizedCrop(cfg.image_size, scale=(0.8, 1.0), interpolation=InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
            *to_normalised_tensor,
        ]
    )


def apply_segmentation_transforms(
    image: Image.Image, mask: Image.Image, cfg: Config, train: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Jointly augment an image/mask pair and return ``(normalised image, integer mask)``.

    In training mode the same random horizontal flip (p=0.5), vertical flip (p=0.2) and
    rotation within ±15° (p=0.3) are applied to both inputs. Both are then resized to
    ``cfg.image_size``. The mask is binarised (any non-zero pixel -> 1) and returned as an
    int64 class-index map for pixel-wise cross-entropy.
    """
    rgb = image.convert("RGB")
    gray = mask.convert("L")

    if train:
        # One independent draw per flip, applied identically to image and mask
        for probability, flip in ((0.5, TF.hflip), (0.2, TF.vflip)):
            if random.random() < probability:
                rgb, gray = flip(rgb), flip(gray)
        if random.random() < 0.3:
            angle = random.uniform(-15.0, 15.0)
            rgb = TF.rotate(rgb, angle, interpolation=InterpolationMode.BILINEAR)
            # NEAREST never blends neighbouring labels, so the mask stays label-valued
            gray = TF.rotate(gray, angle, interpolation=InterpolationMode.NEAREST)

    rgb = TF.resize(rgb, cfg.image_size, interpolation=InterpolationMode.BILINEAR)
    gray = TF.resize(gray, cfg.image_size, interpolation=InterpolationMode.NEAREST)

    image_tensor = TF.normalize(TF.to_tensor(rgb), mean=cfg.imagenet_mean, std=cfg.imagenet_std)
    # 0 = background, 1 = lesion
    is_lesion = np.array(gray, dtype=np.uint8) > 0
    mask_tensor = torch.from_numpy(is_lesion.astype(np.int64))
    return image_tensor, mask_tensor

# Dataset 

In [None]:
# Dataset definitions
class ISICClassificationDataset(Dataset):
    """Loads image-level labels used for the global diagnostic head.

    Rows come from the ``paired`` and ``classification_only`` manifests of the mixed
    dataset, filtered to ``split``. Each manifest needs ``split`` and ``image`` columns
    plus one column per name in ``cfg.class_names``. The labels are presumably one-hot;
    the integer label is taken as the arg-max of that row.

    Raises:
        KeyError: if a manifest lacks a ``split`` column.
        RuntimeError: if no rows match ``split`` or label columns are missing.
        FileNotFoundError: if a manifest or a referenced image is missing.
    """

    def __init__(self, cfg: Config, split: str):
        self.cfg = cfg
        self.split = split
        self.transform = build_classification_transform(self.cfg, self.split == "train")
        label_columns = list(cfg.class_names)
        manifests: List[pd.DataFrame] = []

        for subset in ("paired", "classification_only"):
            manifest_path = cfg.manifest_path(subset)
            subset_df = pd.read_csv(manifest_path)
            if "split" not in subset_df.columns:
                raise KeyError(f"Manifest at {manifest_path} is missing a 'split' column")
            subset_df = subset_df[subset_df["split"] == split].copy()
            if subset_df.empty:
                continue
            # resolve_mixed_path validates that every referenced file exists up front
            subset_df["image_abs"] = subset_df["image"].apply(lambda rel: str(cfg.resolve_mixed_path(rel)))
            subset_df["image_id"] = subset_df["image"].apply(lambda rel: Path(rel).stem)
            manifests.append(subset_df)

        if not manifests:
            raise RuntimeError(f"No classification samples found for split '{split}' in the mixed dataset")

        self.metadata = pd.concat(manifests, ignore_index=True)
        missing_cols = [col for col in label_columns if col not in self.metadata.columns]
        if missing_cols:
            raise RuntimeError(
                "Classification manifest is missing expected label columns: " + ", ".join(missing_cols)
            )
        self.label_vectors = self.metadata[label_columns].values.astype(np.float32)
        self.image_paths = [Path(path) for path in self.metadata["image_abs"]]
        self.image_ids = self.metadata["image_id"].tolist()

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """Return the transformed image, its integer label, the label row and the image id."""
        image_path = self.image_paths[index]
        if not image_path.exists():
            raise FileNotFoundError(f"Classification image not found at {image_path}")
        # Context manager releases the file handle immediately. Converting to RGB (as the
        # segmentation path does) guarantees 3 channels for the ImageNet Normalize and the
        # ResNet stem, even for grayscale, RGBA or palette files.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        tensor = self.transform(image)
        label_vector = self.label_vectors[index]
        label = torch.tensor(int(label_vector.argmax()), dtype=torch.long)
        return {
            "image": tensor,
            "label": label,
            "label_one_hot": torch.from_numpy(label_vector),
            "image_id": self.image_ids[index],
        }


class ISICSegmentationDataset(Dataset):
    """Provides pixel-level annotations consumed by the decoder/fusion branches.

    Rows come from the ``paired`` and ``segmentation_only`` manifests of the mixed
    dataset, filtered to ``split``. Each manifest needs ``split``, ``image`` and
    ``mask`` columns holding paths relative to the mixed dataset root.

    Raises:
        KeyError: if a manifest lacks a ``split`` column.
        RuntimeError: if no rows match ``split``.
        FileNotFoundError: if a manifest, image or mask file is missing.
    """

    def __init__(self, cfg: Config, split: str):
        self.cfg = cfg
        self.split = split
        manifests: List[pd.DataFrame] = []

        for subset in ("paired", "segmentation_only"):
            manifest_path = cfg.manifest_path(subset)
            subset_df = pd.read_csv(manifest_path)
            if "split" not in subset_df.columns:
                raise KeyError(f"Manifest at {manifest_path} is missing a 'split' column")
            subset_df = subset_df[subset_df["split"] == split].copy()
            if subset_df.empty:
                continue
            # resolve_mixed_path validates that every referenced file exists up front
            subset_df["image_abs"] = subset_df["image"].apply(lambda rel: str(cfg.resolve_mixed_path(rel)))
            subset_df["mask_abs"] = subset_df["mask"].apply(lambda rel: str(cfg.resolve_mixed_path(rel)))
            subset_df["image_id"] = subset_df["image"].apply(lambda rel: Path(rel).stem)
            manifests.append(subset_df)

        if not manifests:
            raise RuntimeError(f"No segmentation samples found for split '{split}' in the mixed dataset")

        self.metadata = pd.concat(manifests, ignore_index=True)
        self.image_paths = [Path(path) for path in self.metadata["image_abs"]]
        self.mask_paths = [Path(path) for path in self.metadata["mask_abs"]]
        self.image_ids = self.metadata["image_id"].tolist()

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """Return the jointly transformed image, its binary class-index mask and the image id."""
        image_path = self.image_paths[index]
        mask_path = self.mask_paths[index]
        if not image_path.exists():
            raise FileNotFoundError(f"Segmentation image not found at {image_path}")
        if not mask_path.exists():
            raise FileNotFoundError(f"Segmentation mask not found at {mask_path}")
        # Decode inside a context manager so both file handles are closed right away.
        # These are the same RGB/L conversions apply_segmentation_transforms performs,
        # so the transformed output is unchanged.
        with Image.open(image_path) as img, Image.open(mask_path) as msk:
            image = img.convert("RGB")
            mask = msk.convert("L")
        image_tensor, mask_tensor = apply_segmentation_transforms(
            image, mask, self.cfg, train=self.split == "train"
        )
        return {
            "image": image_tensor,
            "mask": mask_tensor,
            "image_id": self.image_ids[index],
        }

In [None]:
# DataModule-style helpers
def create_datasets(cfg: Config) -> Tuple[Dict[str, Dataset], Dict[str, Dataset]]:
    """Build the train/val/test datasets for both tasks.

    Returns:
        ``(classification, segmentation)``, each mapping split name -> dataset.
    """
    splits = ("train", "val", "test")
    classification: Dict[str, Dataset] = {}
    for split in splits:
        classification[split] = ISICClassificationDataset(cfg, split)
    segmentation: Dict[str, Dataset] = {}
    for split in splits:
        segmentation[split] = ISICSegmentationDataset(cfg, split)
    return classification, segmentation


def create_dataloaders(
    cfg: Config,
    classification: Dict[str, Dataset],
    segmentation: Dict[str, Dataset],
) -> Tuple[Dict[str, DataLoader], Dict[str, DataLoader]]:
    """Wrap datasets in PyTorch dataloaders with paper-inspired batching.

    Classification loaders use ``cfg.batch_size`` and segmentation loaders use
    ``cfg.segmentation_batch_size``. Only the ``"train"`` split is shuffled and drops
    its ragged last batch; every other split keeps all samples in order.

    Returns:
        ``(classification_loaders, segmentation_loaders)``, keyed by the same split names
        as the input dicts.
    """
    # The training/eval loops copy batches with `.to(device, non_blocking=True)`; that copy
    # only overlaps with compute when the host tensors live in page-locked memory.
    pin_memory = torch.cuda.is_available()

    def _wrap(datasets: Dict[str, Dataset], batch_size: int) -> Dict[str, DataLoader]:
        """Build one DataLoader per split with the train-only shuffle/drop_last policy."""
        return {
            split: DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=split == "train",
                drop_last=split == "train",
                pin_memory=pin_memory,
            )
            for split, dataset in datasets.items()
        }

    classification_loaders = _wrap(classification, cfg.batch_size)
    segmentation_loaders = _wrap(segmentation, cfg.segmentation_batch_size)
    return classification_loaders, segmentation_loaders

# Build every split for both tasks, then wrap each dataset in a DataLoader.
# All four results are dicts keyed by split name: {"train": ..., "val": ..., "test": ...}.
classification, segmentation = create_datasets(cfg)
classification_loader, segmentation_loader = create_dataloaders(
    cfg,
    classification,
    segmentation,
)
# Confirmation only; per-split sizes are printed by the sanity-check cell.
print("Loaders created succesfully")

## Dataset sanity check

In [None]:
# Quick dataset sanity check: rebuild every split, print per-split sizes (compare against
# the counts expected for this dataset), then load one training sample per task.
classification_datasets, segmentation_datasets = create_datasets(cfg)
for _split_map in (classification_datasets, segmentation_datasets):
    print({name: len(ds) for name, ds in _split_map.items()})
sample_cls = classification_datasets["train"][0]
sample_seg = segmentation_datasets["train"][0]
print("Classification sample:", sample_cls["image"].shape, sample_cls["label_one_hot"], sample_cls["image_id"])
print("Segmentation sample:", sample_seg["image"].shape, sample_seg["mask"].shape, sample_seg["image_id"])

# Model definition

In [None]:
# Multi-task model skeleton with ResNet-50 backbone, dual heads, and stacked cross-fusion modules
class ConvBNReLU(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU, the basic unit of the decoder.

    Padding is ``kernel_size // 2``, so odd kernels with stride 1 keep the spatial size
    and skip connections stay aligned.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=False,  # the BatchNorm that follows supplies the shift, so a bias is redundant
        )
        layers = [conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.block(x)
        return out


class ClassificationBranch(nn.Module):
    """MLP classifier: four hidden Linear layers (each ReLU + dropout 0.3) and a linear logits head.

    Hidden layers are run one at a time through ``activate_layer`` so the caller can
    interleave other operations (such as cross-fusion) between them.

    Raises:
        ValueError: if ``hidden_dims`` does not have exactly four entries.
    """

    def __init__(self, in_features: int, hidden_dims: Tuple[int, int, int, int], num_classes: int):
        super().__init__()
        if len(hidden_dims) != 4:
            raise ValueError("hidden_dims must contain four entries for the extended branch")
        self.hidden_dims = hidden_dims  # read by the parent model to size its fusion modules
        widths = (in_features,) + hidden_dims
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )
        self.dropout = nn.Dropout(p=0.3)
        self.head = nn.Linear(hidden_dims[-1], num_classes)

    def activate_layer(self, x: torch.Tensor, idx: int) -> torch.Tensor:
        """Apply hidden layer ``idx`` (0-based), then ReLU and dropout.

        Raises:
            IndexError: if ``idx`` is not a valid layer index.
        """
        if not 0 <= idx < len(self.layers):
            raise IndexError("Layer index out of range for classification branch")
        return self.dropout(F.relu(self.layers[idx](x)))

    def final_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Map the last hidden representation to unnormalised class logits."""
        logits = self.head(x)
        return logits


class SegmentationBranch(nn.Module):
    """U-Net-style decoder over five encoder stages, plus two refinement blocks and a 1x1 head.

    ``decode`` returns a 64-channel map at the resolution of the shallowest skip feature.
    ``refine1``, ``refine2`` and ``predict`` are invoked separately by the parent model,
    so cross-fusion can be applied between them.
    """

    def __init__(self, encoder_channels: Tuple[int, int, int, int, int], out_channels: int):
        super().__init__()
        stem_ch, ch1, ch2, ch3, deepest_ch = encoder_channels
        # Top-down path: squeeze the deepest map, then repeatedly upsample, concatenate the
        # matching encoder skip and fuse with a 3x3 ConvBNReLU.
        self.reduce = ConvBNReLU(deepest_ch, 512, kernel_size=1)
        self.up3 = ConvBNReLU(512 + ch3, 256)
        self.up2 = ConvBNReLU(256 + ch2, 128)
        self.up1 = ConvBNReLU(128 + ch1, 96)
        self.up0 = ConvBNReLU(96 + stem_ch, 64)
        # Shape-preserving 64 -> 64 blocks applied after decoding (fusion happens around them)
        self.refine1 = ConvBNReLU(64, 64)
        self.refine2 = ConvBNReLU(64, 64)
        # One logit per segmentation class per pixel
        self.prediction = nn.Conv2d(64, out_channels, kernel_size=1)

    def decode(self, features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Run the whole top-down decoder and return the high-resolution 64-channel feature map."""
        stem, s1, s2, s3, deepest = features
        x = self.reduce(deepest)
        for skip, fuse in ((s3, self.up3), (s2, self.up2), (s1, self.up1), (stem, self.up0)):
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            x = fuse(torch.cat([x, skip], dim=1))
        return x

    def predict(self, features: torch.Tensor) -> torch.Tensor:
        """Project refined features to per-pixel class logits (softmax over dim 1 gives probabilities)."""
        logits = self.prediction(features)
        return logits


class CrossFusionModule(nn.Module):
    """Bidirectional cross-fusion: couples classification (global) and segmentation (dense) features.

    A single bias-free 1x1 convolution ``M`` (``self.transform``) maps segmentation
    channels to classification channels. Its transpose maps the other way, so both
    fusion directions share one weight matrix.
    """

    def __init__(self, cls_channels: int, seg_channels: int):
        super().__init__()
        # Shared transformation matrix M from the paper (Eq. 7) implemented as a 1×1 convolution
        self.transform = nn.Conv2d(seg_channels, cls_channels, kernel_size=1, bias=False)
        self.pool = nn.AdaptiveAvgPool2d(1)  # used in the segmentation→classification path

    def forward(
        self,
        cls_feature: torch.Tensor,  # shape (batch, cls_channels, 1, 1)
        seg_feature: torch.Tensor,  # shape (batch, seg_channels, H, W)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(fused_cls, fused_seg)``, with the same shapes as the respective inputs."""
        # X'_cls = X_cls + Pool(M(X_seg)) — Eq. (7a) converts dense activations to a global descriptor
        seg_to_cls = self.pool(self.transform(seg_feature))
        fused_cls = cls_feature + seg_to_cls

        # X'_seg = X_seg + M^T(Pad(X_cls)) — Eq. (7b) injects classification context into the decoder.
        # Pad(X_cls) is constant over H×W and M^T is a 1×1 (pixel-wise) map. Applying M^T once
        # to the (batch, C, 1, 1) vector and broadcasting therefore gives exactly the tiled
        # result, without materialising an H×W-sized copy and repeating the same matmul per pixel.
        cls_to_seg = F.conv_transpose2d(cls_feature, self.transform.weight)  # (batch, seg_channels, 1, 1)
        fused_seg = seg_feature + cls_to_seg
        return fused_cls, fused_seg


class MultiTask(nn.Module):
    """Multi-task architecture with shared ResNet-50 encoder and stacked cross-fusion heads.

    Args:
        num_classes: Number of image-level classes predicted by the classification head.
        num_segmentation_classes: Channels of the per-pixel segmentation logits
            (default 2, used here as background/lesion).
        trainable_backbone_layers: How many backbone stages, counted backwards from
            ``layer4``, keep ``requires_grad=True`` (see ``_set_trainable_layers``).
        use_pretrained: Load torchvision's default pretrained ResNet-50 weights.
    """

    def __init__(
        self,
        num_classes: int,
        num_segmentation_classes: int = 2,
        trainable_backbone_layers: int = 3,
        use_pretrained: bool = True,
    ):
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if use_pretrained else None
        backbone = resnet50(weights=weights)

        # Expose ResNet-50 feature stages for skip connections and fine-tuning control.
        # Only these stages are kept; the backbone's own avgpool/fc head is not registered.
        self.stem = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.pool0 = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4
        self._set_trainable_layers(trainable_backbone_layers)

        encoder_channels = (64, 256, 512, 1024, 2048)  # channel counts produced by each ResNet stage
        classification_hidden_dims = (1024, 512, 256, 128)  # width of the four FC layers
        self.classifier_pool = nn.AdaptiveAvgPool2d(1)  # compress encoder output before the MLP
        self.classifier_branch = ClassificationBranch(
            in_features=encoder_channels[-1], hidden_dims=classification_hidden_dims, num_classes=num_classes
        )
        self.segmentation_branch = SegmentationBranch(
            encoder_channels=encoder_channels, out_channels=num_segmentation_classes
        )

        # Two cross-fusion modules placed between consecutive layers of each branch per revised spec.
        # seg_channels=64 matches the decoder output / refinement width; cls_channels match the
        # classifier hidden widths after hidden layer 1 (index 0) and hidden layer 3 (index 2).
        self.cross_fusion_primary = CrossFusionModule(
            cls_channels=self.classifier_branch.hidden_dims[0], seg_channels=64
        )
        self.cross_fusion_secondary = CrossFusionModule(
            cls_channels=self.classifier_branch.hidden_dims[2], seg_channels=64
        )

    def _set_trainable_layers(self, trainable_backbone_layers: int) -> None:
        """Freeze early ResNet blocks to control how much of the backbone is fine-tuned.

        The last ``trainable_backbone_layers`` entries of
        ``[stem, pool0, layer1, layer2, layer3, layer4]`` keep ``requires_grad=True``;
        all earlier entries are frozen.

        NOTE(review): ``pool0`` is a max-pool with no parameters, so a value of 5 trains the
        same parameters as 4. Setting ``requires_grad=False`` also does not stop BatchNorm
        running statistics in frozen stages from updating while the model is in ``train()``
        mode — confirm this is intended.

        Raises:
            ValueError: if ``trainable_backbone_layers`` is outside ``1..len(stages)``.
        """
        stages = [
            self.stem,
            self.pool0,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
        ]
        if trainable_backbone_layers < 1 or trainable_backbone_layers > len(stages):
            raise ValueError(
                f"trainable_backbone_layers must be between 1 and {len(stages)}; got {trainable_backbone_layers}"
            )
        trainable_modules = set(stages[-trainable_backbone_layers:])
        for module in stages:
            for param in module.parameters():
                param.requires_grad = module in trainable_modules

    def forward(
        self,
        x: torch.Tensor,
        task: Optional[str] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run shared encoder, two-stage cross-fusion, and emit requested task predictions.

        Both branches and both fusion stages are always computed, because each branch
        consumes the other's features. ``task`` only selects which entries are returned:

        * ``"classification"`` -> ``{"classification": logits, "fusion_reg": scalar}``
        * ``"segmentation"`` -> ``{"segmentation": per-pixel logits resized to x's H×W}``
        * ``None`` (or any other falsy value) -> all three entries.

        ``fusion_reg`` is the sum of four MSE terms, each measuring how far a fused
        feature moved away from its pre-fusion input.

        Raises:
            ValueError: for any other truthy ``task`` value.
        """
        if task and task not in {"classification", "segmentation"}:
            raise ValueError("task must be 'classification', 'segmentation', or None")

        # Shared ResNet-50 encoder produces multi-scale features for both tasks
        c0 = self.stem(x)
        p0 = self.pool0(c0)
        c1 = self.layer1(p0)
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)
        encoder_features = (c0, c1, c2, c3, c4)

        # First branch layers
        seg_stage0 = self.segmentation_branch.decode(encoder_features)  # segmentation layer 1
        cls_vector0 = self.classifier_pool(c4).flatten(1)  # base embedding for the classifier
        cls_hidden1 = self.classifier_branch.activate_layer(cls_vector0, idx=0)  # classification layer 1
        # View the vector as a (batch, C, 1, 1) map so the fusion module can treat it like a 1×1 feature map
        cls_hidden1_map = cls_hidden1.view(cls_hidden1.size(0), cls_hidden1.size(1), 1, 1)

        fusion_penalties: List[torch.Tensor] = []

        # Cross-fusion between first and second layers
        cls_fused1, seg_fused1 = self.cross_fusion_primary(cls_hidden1_map, seg_stage0)
        # Penalise how much fusion changed each feature (mean squared difference fused vs. input)
        fusion_penalties.append(F.mse_loss(cls_fused1, cls_hidden1_map, reduction="mean"))
        fusion_penalties.append(F.mse_loss(seg_fused1, seg_stage0, reduction="mean"))

        cls_stage1 = cls_fused1.flatten(1)
        cls_hidden2 = self.classifier_branch.activate_layer(cls_stage1, idx=1)  # classification layer 2
        seg_stage1 = self.segmentation_branch.refine1(seg_fused1)  # segmentation layer 2

        # Progress to third layers of both branches
        cls_hidden3 = self.classifier_branch.activate_layer(cls_hidden2, idx=2)  # classification layer 3
        cls_hidden3_map = cls_hidden3.view(cls_hidden3.size(0), cls_hidden3.size(1), 1, 1)
        seg_stage2 = self.segmentation_branch.refine2(seg_stage1)  # segmentation layer 3

        # Cross-fusion between third and final layers
        cls_fused2, seg_fused2 = self.cross_fusion_secondary(cls_hidden3_map, seg_stage2)
        fusion_penalties.append(F.mse_loss(cls_fused2, cls_hidden3_map, reduction="mean"))
        fusion_penalties.append(F.mse_loss(seg_fused2, seg_stage2, reduction="mean"))

        cls_stage3 = cls_fused2.flatten(1)
        cls_hidden4 = self.classifier_branch.activate_layer(cls_stage3, idx=3)  # classification layer 4
        cls_logits = self.classifier_branch.final_logits(cls_hidden4)

        seg_logits = self.segmentation_branch.predict(seg_fused2)  # segmentation final layer
        # Upsample logits back to the input resolution so they align with the ground-truth mask
        seg_logits = F.interpolate(seg_logits, size=x.shape[-2:], mode="bilinear", align_corners=False)

        fusion_reg = torch.stack(fusion_penalties).sum()

        outputs: Dict[str, torch.Tensor] = {}
        requested = {task} if task else {"classification", "segmentation"}
        if "classification" in requested:
            outputs["classification"] = cls_logits
            outputs["fusion_reg"] = fusion_reg  # returned only alongside classification logits
        if "segmentation" in requested:
            outputs["segmentation"] = seg_logits
        return outputs


In [None]:
# Losses and metrics
def build_loss_functions(cfg: Config, device: torch.device) -> Tuple[nn.Module, nn.Module]:
    """Return ``(classification_loss, segmentation_loss)``, both cross-entropy.

    The segmentation loss is pixel-wise and skips pixels labelled ``cfg.ignore_index``.
    ``device`` is accepted for interface compatibility but is not used here.
    """
    return (
        nn.CrossEntropyLoss(),
        nn.CrossEntropyLoss(ignore_index=cfg.ignore_index),
    )


def classification_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max logit equals the label (0.0 for an empty batch)."""
    hits = logits.argmax(dim=1).eq(labels)
    n = labels.numel()
    return hits.sum().item() / (n if n > 0 else 1)


def dice_score(
    logits: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    eps: float = 1e-6,
    ignore_index: Optional[int] = None,
) -> float:
    """Compute the mean per-image Dice overlap for the lesion class.

    Args:
        logits: Segmentation logits; softmax is taken over dim 1 and channel 1 is the lesion.
        targets: Class-index map with the same batch/spatial dims (0 = background, 1 = lesion).
        threshold: Lesion probability above which a pixel counts as predicted lesion.
        eps: Smoothing term; an image with empty prediction and empty target scores 1.0.
        ignore_index: Optional label (e.g. the segmentation loss's ``ignore_index``) whose
            pixels are excluded from both prediction and target. Dice is then measured over
            the same pixels the loss uses, and an ignore value such as 255 cannot inflate
            the overlap. ``None`` (default) keeps every pixel.

    Returns:
        Dice averaged over the batch, as a Python float.
    """
    probabilities = torch.softmax(logits, dim=1)[:, 1, ...]  # lesion channel
    preds = (probabilities > threshold).float()
    targets = targets.float()
    if ignore_index is not None:
        valid = (targets != ignore_index).float()
        preds = preds * valid
        targets = targets * valid
    intersection = (preds * targets).sum(dim=(1, 2))
    union = preds.sum(dim=(1, 2)) + targets.sum(dim=(1, 2))
    dice = (2 * intersection + eps) / (union + eps)
    return dice.mean().item()

In [None]:
# Training utilities
def create_optimizer(model: nn.Module, cfg: Config) -> Tuple[torch.optim.Optimizer, CosineAnnealingLR]:
    """Return an AdamW optimizer over the still-trainable parameters and its cosine LR schedule.

    Frozen parameters (``requires_grad=False``) are excluded. The schedule anneals from
    ``cfg.base_learning_rate`` to ``cfg.min_learning_rate`` over ``cfg.scheduler_period`` steps.
    """
    trainable = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.AdamW(trainable, lr=cfg.base_learning_rate, weight_decay=cfg.weight_decay)
    return optimizer, CosineAnnealingLR(optimizer, T_max=cfg.scheduler_period, eta_min=cfg.min_learning_rate)


def _next_batch(
    iterator: Iterator[Dict[str, torch.Tensor]], loader: DataLoader
) -> Tuple[Dict[str, torch.Tensor], Iterator[Dict[str, torch.Tensor]]]:
    """Fetch the next batch and restart the iterator when the dataloader is exhausted."""
    try:
        batch = next(iterator)
    except StopIteration:
        iterator = iter(loader)
        batch = next(iterator)
    return batch, iterator


def train_one_epoch(
    model: nn.Module,
    classification_loader: DataLoader,
    segmentation_loader: DataLoader,
    classification_loss_fn: nn.Module,
    segmentation_loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    cfg: Config,
    device: torch.device,
    scaler: Optional[GradScaler] = None,
) -> Dict[str, float]:
    """Train the network for a single epoch while logging both task metrics.

    Each step draws one classification batch and one segmentation batch. It combines
    ``cfg.classification_loss_weight * (CE + cfg.alpha * fusion_reg)`` with
    ``cfg.segmentation_loss_weight * seg_CE`` and takes a single optimizer step.
    The epoch lasts ``max(len(classification_loader), len(segmentation_loader))`` steps;
    the shorter loader restarts when exhausted.

    Args:
        scaler: Optional AMP gradient scaler. Mixed precision is used only when the scaler
            is enabled and ``device`` is CUDA; otherwise a plain fp32 step is taken.

    Returns:
        Per-step averages of the total loss, each task loss, classification accuracy and
        lesion Dice. All values are 0.0 when both loaders are empty.
    """
    model.train()
    # AMP applies only on CUDA; a scaler passed for a CPU run falls back to fp32 steps.
    use_amp = scaler is not None and scaler.is_enabled() and device.type == "cuda"
    clip_norm = cfg.grad_clip_norm if cfg.grad_clip_norm and cfg.grad_clip_norm > 0 else None
    classification_iter = iter(classification_loader)
    segmentation_iter = iter(segmentation_loader)
    max_steps = max(len(classification_loader), len(segmentation_loader))
    total_loss = 0.0
    total_cls_loss = 0.0
    total_seg_loss = 0.0
    total_accuracy = 0.0
    total_dice = 0.0
    for _ in range(max_steps):
        optimizer.zero_grad(set_to_none=True)
        classification_batch, classification_iter = _next_batch(classification_iter, classification_loader)
        segmentation_batch, segmentation_iter = _next_batch(segmentation_iter, segmentation_loader)
        classification_images = classification_batch["image"].to(device, non_blocking=True)
        classification_labels = classification_batch["label"].to(device, non_blocking=True)
        segmentation_images = segmentation_batch["image"].to(device, non_blocking=True)
        segmentation_masks = segmentation_batch["mask"].to(device, non_blocking=True).long()
        # torch.autocast supersedes the deprecated torch.cuda.amp.autocast (same fp16 default on CUDA)
        amp_context = torch.autocast(device_type="cuda") if use_amp else nullcontext()
        with amp_context:
            # Classification branch: cross-entropy plus α-weighted fusion regularizer
            classification_result = model(classification_images, task="classification")
            cls_logits = classification_result["classification"]
            fusion_reg = classification_result.get("fusion_reg")
            if fusion_reg is None:
                fusion_reg = cls_logits.new_zeros(())
            cls_loss = classification_loss_fn(cls_logits, classification_labels) + cfg.alpha * fusion_reg
            # Segmentation branch: pixel-wise cross-entropy between logits and integer mask
            segmentation_result = model(segmentation_images, task="segmentation")
            seg_logits = segmentation_result["segmentation"]
            seg_loss = segmentation_loss_fn(seg_logits, segmentation_masks)
            loss = cfg.classification_loss_weight * cls_loss + cfg.segmentation_loss_weight * seg_loss
        if use_amp:
            scaler.scale(loss).backward()
            if clip_norm is not None:
                # Unscale first so the clipping threshold applies to the true gradient norm
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if clip_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            optimizer.step()
        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_seg_loss += seg_loss.item()
        total_accuracy += classification_accuracy(cls_logits.detach(), classification_labels)
        total_dice += dice_score(seg_logits.detach(), segmentation_masks)
    # Guard against a zero-step epoch (both loaders empty) instead of dividing by zero
    steps = float(max(max_steps, 1))
    return {
        "loss": total_loss / steps,
        "classification_loss": total_cls_loss / steps,
        "segmentation_loss": total_seg_loss / steps,
        "classification_accuracy": total_accuracy / steps,
        "segmentation_dice": total_dice / steps,
    }


@torch.no_grad()
def evaluate(
    model: nn.Module,
    classification_loader: DataLoader,
    segmentation_loader: DataLoader,
    classification_loss_fn: nn.Module,
    segmentation_loss_fn: nn.Module,
    cfg: Config,
    device: torch.device,
) -> Dict[str, float]:
    """Run validation/testing without gradient tracking.

    Every metric is a per-sample average: each batch's mean value is weighted by its
    batch size, so the smaller final batch of a ``drop_last=False`` loader is not
    over-weighted. This assumes the loss functions use mean reduction, as
    ``nn.CrossEntropyLoss`` does by default. The classification loss includes
    ``cfg.alpha * fusion_reg``, as in training. Metrics of an empty loader are 0.0.
    """
    model.eval()
    cls_loss_sum = 0.0
    cls_correct = 0.0
    cls_samples = 0
    for batch in classification_loader:
        images = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)
        classification_result = model(images, task="classification")
        cls_logits = classification_result["classification"]
        fusion_reg = classification_result.get("fusion_reg")
        if fusion_reg is None:
            fusion_reg = cls_logits.new_zeros(())
        cls_loss = classification_loss_fn(cls_logits, labels) + cfg.alpha * fusion_reg
        batch_size = labels.numel()
        cls_loss_sum += cls_loss.item() * batch_size
        # classification_accuracy returns a per-batch fraction; scale back to a hit count
        cls_correct += classification_accuracy(cls_logits, labels) * batch_size
        cls_samples += batch_size
    seg_loss_sum = 0.0
    seg_dice_sum = 0.0
    seg_samples = 0
    for batch in segmentation_loader:
        images = batch["image"].to(device, non_blocking=True)
        masks = batch["mask"].to(device, non_blocking=True).long()
        segmentation_result = model(images, task="segmentation")
        seg_logits = segmentation_result["segmentation"]
        batch_size = masks.size(0)
        # Per-pixel CE is averaged within the batch; weighting by image count approximates
        # the dataset mean (exact unless ignore_index pixels vary across images)
        seg_loss_sum += segmentation_loss_fn(seg_logits, masks).item() * batch_size
        # dice_score is the per-image mean over the batch, so this recovers a per-image sum
        seg_dice_sum += dice_score(seg_logits, masks) * batch_size
        seg_samples += batch_size
    cls_denominator = max(cls_samples, 1)
    seg_denominator = max(seg_samples, 1)
    return {
        "classification_loss": cls_loss_sum / cls_denominator,
        "classification_accuracy": cls_correct / cls_denominator,
        "segmentation_loss": seg_loss_sum / seg_denominator,
        "segmentation_dice": seg_dice_sum / seg_denominator,
    }

In [None]:
# High-level training loop
def fit(
    cfg: Config,
    device: torch.device,
    output_dir: Optional[Path] = None,
    resume_from: Optional[Path] = None,
    save_checkpoints: bool = True,
    restore_best: bool = True,
) -> Dict[str, object]:
    """Train the multi-task model and return it together with its training history.

    Each epoch does three things in order:

    1. Runs ``train_one_epoch`` on the ``"train"`` loaders.
    2. Scores the ``"val"`` loaders with ``evaluate``.
    3. Steps the LR scheduler.

    The epoch with the highest validation segmentation Dice is treated as the best.

    Args:
        cfg: Experiment configuration.
        device: Device the model is trained on. AMP is used only when
            ``cfg.mixed_precision`` is set and ``device`` is CUDA.
        output_dir: Directory receiving one checkpoint per epoch. It is created
            if missing.
        resume_from: Optional checkpoint written by a previous call. Model,
            optimizer, scheduler (and scaler, when present) states are restored
            from it.
        save_checkpoints: Write per-epoch checkpoints to ``output_dir`` when True.
        restore_best: When True (default), the returned model carries the weights
            of the best validation epoch of this run instead of the last epoch's.
            Checkpoints do not store a separate best snapshot. After a resume,
            "best" is therefore chosen among the epochs run in this call, with
            the first of them kept when none beats the resumed ``best_val_dice``.

    Returns:
        Dict with the following keys:

        - ``history``: ``{"train": [...], "val": [...]}`` per-epoch metric dicts.
        - ``model``: the model, in eval mode.
        - ``classification_loaders`` and ``segmentation_loaders``.
        - ``best_val_dice``.
    """
    import copy  # local import: only used to snapshot non-tensor state-dict entries

    classification_datasets, segmentation_datasets = create_datasets(cfg)
    classification_loaders, segmentation_loaders = create_dataloaders(
        cfg, classification_datasets, segmentation_datasets
    )
    model = MultiTask(num_classes=len(cfg.class_names)).to(device)
    classification_loss_fn, segmentation_loss_fn = build_loss_functions(cfg, device)
    optimizer, scheduler = create_optimizer(model, cfg)
    use_amp = cfg.mixed_precision and device.type == "cuda"
    scaler = GradScaler(enabled=use_amp) if use_amp else None
    start_epoch = 0
    best_val_dice = 0.0
    # CPU snapshot of the weights of the best validation epoch seen in this call,
    # plus the epoch index and Dice that snapshot was taken at.
    best_state: Optional[Dict[str, object]] = None
    best_state_epoch = -1
    best_state_dice = 0.0
    history: Dict[str, List[Dict[str, float]]] = {"train": [], "val": []}
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
    if resume_from:
        # NOTE(review): checkpoints below pickle cfg.__dict__, and cfg holds a
        # pathlib.Path (project_root). PyTorch >= 2.6 defaults torch.load to
        # weights_only=True, which may reject such objects. Confirm the installed
        # version; pass weights_only=False only for trusted files.
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        if scaler and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])
        # "epoch" stores the 0-based index of the last finished epoch. A checkpoint
        # without it is treated as having finished none, so epoch 0 is not skipped.
        start_epoch = checkpoint.get("epoch", -1) + 1
        best_val_dice = checkpoint.get("best_val_dice", best_val_dice)
        print(f"Resumed training from {resume_from} at epoch {start_epoch}")
    for epoch in range(start_epoch, cfg.max_epochs):
        train_metrics = train_one_epoch(
            model,
            classification_loaders["train"],
            segmentation_loaders["train"],
            classification_loss_fn,
            segmentation_loss_fn,
            optimizer,
            cfg,
            device,
            scaler,
        )
        val_metrics = evaluate(
            model,
            classification_loaders["val"],
            segmentation_loaders["val"],
            classification_loss_fn,
            segmentation_loss_fn,
            cfg,
            device,
        )
        scheduler.step()
        history["train"].append(train_metrics)
        history["val"].append(val_metrics)
        print(
            f"Epoch {epoch + 1}/{cfg.max_epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Val Acc: {val_metrics['classification_accuracy']:.4f} | "
            f"Val Dice: {val_metrics['segmentation_dice']:.4f}"
        )
        current_val_dice = val_metrics["segmentation_dice"]
        improved = current_val_dice > best_val_dice
        if improved:
            best_val_dice = current_val_dice
        if improved or best_state is None:
            # Copy to CPU so the snapshot survives further optimizer steps without
            # holding a second set of weights on the accelerator.
            best_state = {
                name: value.detach().to("cpu", copy=True)
                if isinstance(value, torch.Tensor)
                else copy.deepcopy(value)
                for name, value in model.state_dict().items()
            }
            best_state_epoch = epoch
            best_state_dice = current_val_dice
        if save_checkpoints and output_dir:
            checkpoint = {
                "epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "cfg": cfg.__dict__,
                "train_metrics": train_metrics,
                "val_metrics": val_metrics,
                "best_val_dice": best_val_dice,
            }
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            torch.save(checkpoint, output_dir / f"multitask_resnet50_epoch_{epoch + 1:03d}.pth")
    if restore_best and best_state is not None:
        # Hand back the best validation weights rather than the last epoch's,
        # since callers treat the returned model as the best model.
        model.load_state_dict(best_state)
        print(
            f"Restored weights from epoch {best_state_epoch + 1} "
            f"(Val Dice: {best_state_dice:.4f})"
        )
    model.eval()
    return {
        "history": history,
        "model": model,
        "classification_loaders": classification_loaders,
        "segmentation_loaders": segmentation_loaders,
        "best_val_dice": best_val_dice,
    }

In [None]:
# Train the model, then keep the returned model and per-epoch history for the cells below.
results = fit(
    cfg,
    device,
    output_dir=Path("artifacts/multitask"),
    resume_from=None,
    save_checkpoints=True,
)
best_model = results["model"]
history = results["history"]

In [None]:
# Optional architecture summary. torchinfo is a third-party package that is not
# always installed, so it is imported only when the summary is actually requested.
SHOW_MODEL_SUMMARY = False

model = MultiTask(num_classes=len(cfg.class_names))
if SHOW_MODEL_SUMMARY:
    from torchinfo import summary

    # verbose=0 + explicit print: a call inside an `if` is not auto-displayed by
    # the notebook, and this avoids a double print when run as a script.
    print(
        summary(
            model,
            input_size=(1, 3, *cfg.image_size),
            col_names=("input_size", "output_size", "num_params", "trainable"),
            verbose=0,
        )
    )

In [None]:
# Score the model returned by fit() on the held-out test split of both tasks.
if "best_model" not in globals():
    raise RuntimeError("best_model is undefined. Run the training cell that returns best_model before inference.")
model = best_model.to(device).eval()
classification_datasets, segmentation_datasets = create_datasets(cfg)
classification_loaders, segmentation_loaders = create_dataloaders(
    cfg, classification_datasets, segmentation_datasets
)
classification_loss_fn, segmentation_loss_fn = build_loss_functions(cfg, device)
test_metrics = evaluate(
    model=model,
    classification_loader=classification_loaders["test"],
    segmentation_loader=segmentation_loaders["test"],
    classification_loss_fn=classification_loss_fn,
    segmentation_loss_fn=segmentation_loss_fn,
    cfg=cfg,
    device=device,
)
report_lines = ["Test metrics:"] + [
    f"  {metric_name}: {metric_value:.4f}" for metric_name, metric_value in test_metrics.items()
]
print("\n".join(report_lines))

In [None]:
# Write the model's weights to disk as a CPU-resident state dict, then move the
# model back to the working device so later cells can keep using it.
if "best_model" not in globals():
    raise RuntimeError("best_model is undefined. Run the training cell that returns best_model before saving.")
save_path = Path("multitask_resnet50_best.pth")
model_to_save = best_model.to("cpu")
model_to_save.eval()
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model_to_save.state_dict(), save_path)
print(f"Saved best model weights to {save_path}")
best_model = model_to_save.to(device)

In [None]:
# Visualise model predictions on a random test sample
import matplotlib.pyplot as plt
import numpy as np
import random

if "best_model" not in globals():
    raise RuntimeError("Run training to populate best_model before visualisation.")
model = best_model.to(device).eval()

# Ensure test datasets/loaders are available
if "classification_datasets" not in globals() or "segmentation_datasets" not in globals():
    classification_datasets, segmentation_datasets = create_datasets(cfg)
if "classification_loaders" not in globals() or "segmentation_loaders" not in globals():
    classification_loaders, segmentation_loaders = create_dataloaders(
        cfg, classification_datasets, segmentation_datasets
    )

test_segmentation_dataset = segmentation_datasets["test"]
sample_idx = random.randrange(len(test_segmentation_dataset))
sample = test_segmentation_dataset[sample_idx]
image = sample["image"].unsqueeze(0).to(device)
mask_gt = sample["mask"].numpy()
image_id = sample["image_id"]

with torch.no_grad():
    # Query each head through the same task-specific forward calls used by
    # train_one_epoch/evaluate, so the figure matches the reported metrics.
    cls_logits = model(image, task="classification")["classification"]
    seg_logits = model(image, task="segmentation")["segmentation"]
    cls_probs = torch.softmax(cls_logits, dim=1)[0].cpu().numpy()
    pred_class_idx = int(cls_probs.argmax())
    pred_class_conf = float(cls_probs[pred_class_idx])
    pred_class_label = cfg.class_names[pred_class_idx]
    # Assumes channel 1 is the lesion (foreground) class of the per-pixel logits — confirm.
    seg_probs = torch.softmax(seg_logits, dim=1)[0, 1].cpu().numpy()

# Try to recover the ground-truth classification label for the same image.
# image_paths entries may be bare file names, full paths or Path objects, so match
# on the file name / stem instead of requiring an exact "<image_id>.jpg" entry.
ground_truth_label = "N/A"
test_cls_dataset = classification_datasets["test"]
filename_candidate = f"{image_id}.jpg"
cls_index = next(
    (
        idx
        for idx, entry in enumerate(test_cls_dataset.image_paths)
        if Path(entry).name == filename_candidate or Path(entry).stem == str(image_id)
    ),
    None,
)
if cls_index is not None:
    label_vector = test_cls_dataset.label_vectors[cls_index]
    gt_idx = int(label_vector.argmax())
    ground_truth_label = cfg.class_names[gt_idx]

# Denormalise image tensor for visualisation
mean = np.array(cfg.imagenet_mean)
std = np.array(cfg.imagenet_std)
image_np = sample["image"].permute(1, 2, 0).cpu().numpy()
image_np = (image_np * std) + mean
image_np = np.clip(image_np, 0.0, 1.0)

fig, axes = plt.subplots(1, 3, figsize=(14, 4))
axes[0].imshow(image_np)
axes[0].set_title(f"Input image: {image_id}")
axes[0].axis("off")

axes[1].imshow(mask_gt, cmap="gray")
axes[1].set_title("Ground-truth mask")
axes[1].axis("off")

axes[2].imshow(image_np)
axes[2].imshow(seg_probs, cmap="viridis", alpha=0.5)
axes[2].set_title("Predicted mask overlay")
axes[2].axis("off")

plt.suptitle(
    f"Predicted label: {pred_class_label} (p={pred_class_conf:.2f}) | Ground truth: {ground_truth_label}"
)
plt.tight_layout()
plt.show()