# Multi-Task Skin Lesion Diagnostic
This notebook sketches a PyTorch implementation of the architecture described in *Multi-Task Classification and Segmentation for Explicable Capsule Endoscopy Diagnostics*. The model shares an encoder across tasks and uses separate heads for frame-level classification and pixel-level lesion segmentation.

In [92]:
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

import random

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import ResNet50_Weights, resnet50
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF

torch.manual_seed(42)


<torch._C.Generator at 0x1155eb3d0>

# Model configurations

In [93]:
# Configuration
@dataclass
class Config:
    """Hyper-parameters and dataset locations shared by every cell of the notebook.

    All attributes are dataclass fields and can therefore be overridden through
    the constructor, e.g. ``Config(batch_size=16, csv_path="labels.csv")``.
    """

    # Root used by the segmentation directory helpers below.
    # NOTE: the default is evaluated once, when the class is defined, not per instance.
    project_root: Path = Path.cwd()
    # Target (height, width) passed to the resize / crop transforms.
    image_size: Tuple[int, int] = (256, 256)
    # Batch sizes of the classification and segmentation loaders respectively.
    batch_size: int = 8
    segmentation_batch_size: int = 4
    # AdamW settings.
    base_learning_rate: float = 3e-4
    weight_decay: float = 1e-4
    max_epochs: int = 40
    # Multipliers applied to each task loss before they are summed.
    classification_loss_weight: float = 0.4
    segmentation_loss_weight: float = 1.0
    # Label columns read from the classification CSV, in class-index order.
    class_names: Tuple[str, ...] = ("MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC")
    # Not referenced by the code visible in this notebook.
    ignore_index: int = 255
    # Cosine schedule: floor learning rate and T_max (in scheduler steps).
    min_learning_rate: float = 1e-6
    scheduler_period: int = 10
    # Mixed precision is only enabled when the device is CUDA.
    mixed_precision: bool = True
    # pos_weight of the binary segmentation loss.
    segmentation_positive_weight: float = 1.5
    # Max gradient norm; None or a non-positive value disables clipping.
    grad_clip_norm: Optional[float] = 5.0
    imagenet_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    imagenet_std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
    # Appended to an image stem to form the mask file name.
    segmentation_suffix: str = "_segmentation"
    # Annotated so that it is a real dataclass field (overridable, part of repr/eq/__dict__).
    csv_path: str = '/Users/enricotazzer/Desktop/multi-task-learning-for-classification-and-segmentation-of-skin-lesions/mutlitask_dataset/segmentation_labels.csv'

    def segmentation_input_dir(self, split: str) -> Path:
        """Return ``<project_root>/dataset/segmentation/<split>/input``."""
        return self.project_root / "dataset" / "segmentation" / split / "input"

    def segmentation_mask_dir(self, split: str) -> Path:
        """Return ``<project_root>/dataset/segmentation/<split>/ground_truth``."""
        return self.project_root / "dataset" / "segmentation" / split / "ground_truth"

    def classification_labels(self, split: str) -> pd.DataFrame:
        """Load the label CSV and keep the rows that belong to ``split``.

        A row is kept when its ``image`` value contains ``split`` as a
        substring (pandas treats the pattern as a regular expression by
        default). The returned frame has a fresh 0..n-1 index and may be empty.
        """
        df = pd.read_csv(self.csv_path)
        mask = df["image"].astype(str).str.contains(split, na=False)
        return df.loc[mask].reset_index(drop=True)
    
# Default configuration instance used by the cells below.
cfg = Config()
# Use the GPU when CUDA reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


# Image utilities

In [94]:
# Data transforms
def build_classification_transform(cfg: Config, train: bool) -> transforms.Compose:
    """Build the image pipeline used by the classification dataset.

    Training applies a random resized crop, random flips and a mild colour
    jitter; evaluation applies a resize followed by a centre crop of the same
    size. Both variants end with tensor conversion and normalisation using the
    mean / std stored in ``cfg``.
    """
    bilinear = InterpolationMode.BILINEAR
    if train:
        geometry_and_colour = (
            transforms.RandomResizedCrop(cfg.image_size, scale=(0.8, 1.0), interpolation=bilinear),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        )
    else:
        geometry_and_colour = (
            transforms.Resize(cfg.image_size, interpolation=bilinear),
            transforms.CenterCrop(cfg.image_size),
        )
    # Shared tail: PIL image -> float tensor -> normalised tensor.
    tensor_tail = (
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg.imagenet_mean, std=cfg.imagenet_std),
    )
    return transforms.Compose([*geometry_and_colour, *tensor_tail])


def apply_segmentation_transforms(
    image: Image.Image, mask: Image.Image, cfg: Config, train: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Transform an image and its mask with identical spatial operations.

    The image is converted to RGB and the mask to single-channel greyscale.
    When ``train`` is true, a horizontal flip (p=0.5), a vertical flip (p=0.2)
    and a rotation by a random angle in [-15, 15] degrees (p=0.3) are each
    applied to both inputs. Both are then resized to ``cfg.image_size``; the
    mask always uses nearest-neighbour interpolation so no new values appear.

    Returns:
        ``(image_tensor, mask_tensor)``: the normalised image tensor and a
        float mask with a leading channel axis whose entries are 1.0 where the
        mask pixel is non-zero and 0.0 elsewhere.
    """
    rgb = image.convert("RGB")
    grey = mask.convert("L")

    if train:
        # The random draws happen in a fixed order: hflip, vflip, rotation.
        if random.random() < 0.5:
            rgb, grey = TF.hflip(rgb), TF.hflip(grey)
        if random.random() < 0.2:
            rgb, grey = TF.vflip(rgb), TF.vflip(grey)
        if random.random() < 0.3:
            degrees = random.uniform(-15.0, 15.0)
            rgb = TF.rotate(rgb, degrees, interpolation=InterpolationMode.BILINEAR)
            grey = TF.rotate(grey, degrees, interpolation=InterpolationMode.NEAREST)

    rgb = TF.resize(rgb, cfg.image_size, interpolation=InterpolationMode.BILINEAR)
    grey = TF.resize(grey, cfg.image_size, interpolation=InterpolationMode.NEAREST)

    image_tensor = TF.to_tensor(rgb)
    image_tensor = TF.normalize(image_tensor, mean=cfg.imagenet_mean, std=cfg.imagenet_std)

    # Any non-zero mask value counts as foreground.
    foreground = torch.from_numpy(np.array(grey, dtype=np.uint8)) > 0
    mask_tensor = foreground.float().unsqueeze(0)
    return image_tensor, mask_tensor

# Dataset 

In [103]:
# Dataset definitions
class ISICClassificationDataset(Dataset):
    """Image-level classification samples for one split of the label CSV.

    Rows are selected with ``cfg.classification_labels(split)``; the columns
    named in ``cfg.class_names`` are read as a float32 label vector per image.
    """

    def __init__(self, cfg: Config, split: str):
        """Load the metadata of ``split`` and build the matching transform.

        Raises:
            FileNotFoundError: ``cfg.csv_path`` does not exist.
            RuntimeError: the CSV contains no row for ``split``.
        """
        self.cfg = cfg
        self.split = split
        # Augmentations are only enabled for the training split.
        self.transform = build_classification_transform(self.cfg, self.split == 'train')
        self.csv_path = cfg.csv_path
        if not Path(self.csv_path).exists():
            raise FileNotFoundError(f"Missing classification CSV at {self.csv_path}")
        self.metadata = cfg.classification_labels(split)
        if self.metadata.empty:
            # The file itself may be non-empty; only this split has no rows.
            raise RuntimeError(
                f"Classification CSV at {self.csv_path} has no rows for split '{split}'"
            )
        self.image_paths = self.metadata["image"].tolist()
        self.label_vectors = self.metadata.loc[:, list(cfg.class_names)].values.astype(np.float32)

    def __len__(self) -> int:
        """Return the number of images in this split."""
        return len(self.image_paths)

    def _resolve_image_path(self, image_id: str) -> Path:
        """Map a CSV ``image`` value to an existing file.

        The id is resolved below ``mutlitask_dataset/segmentation`` relative to
        the current working directory.

        Raises:
            FileNotFoundError: no file exists at the resolved location.
        """
        path = Path(f'mutlitask_dataset/segmentation/{image_id}')
        if path.exists():
            return path
        # Report the single path that was tried, not the whole list of ids.
        raise FileNotFoundError(f"Could not locate an image file for id '{image_id}' at {path}")

    def __getitem__(self, index: int) -> Dict[str, object]:
        """Return the transformed image and its labels.

        Keys: ``image`` (transformed tensor), ``label`` (class index taken as
        the arg-max of the label vector), ``label_one_hot`` (the float32 label
        vector) and ``image_id`` (the CSV ``image`` string).
        """
        image_id = self.image_paths[index]
        image_path = self._resolve_image_path(image_id)
        # Force three channels so the 3-channel normalisation always applies,
        # and release the file handle as soon as the pixels are decoded.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")

        tensor = self.transform(image)
        label_vector = self.label_vectors[index]
        label = torch.tensor(int(label_vector.argmax()), dtype=torch.long)
        return {
            "image": tensor,
            "label": label,
            "label_one_hot": torch.from_numpy(label_vector),
            "image_id": image_id,
        }


class ISICSegmentationDataset(Dataset):
    """Image / mask pairs for one split of the segmentation data.

    Images are the ``*.jpg`` files of ``cfg.segmentation_input_dir(split)``;
    the mask of ``<stem>.jpg`` is ``<stem><cfg.segmentation_suffix>.png`` inside
    ``cfg.segmentation_mask_dir(split)``.
    """

    def __init__(self, cfg: Config, split: str):
        """Index the images of ``split``; raise ``RuntimeError`` if none are found."""
        self.cfg = cfg
        self.split = split
        self.images_dir = cfg.segmentation_input_dir(split)
        self.masks_dir = cfg.segmentation_mask_dir(split)
        # Sorted so that indices are stable across runs.
        self.image_paths = sorted(self.images_dir.glob("*.jpg"))
        if len(self.image_paths) == 0:
            raise RuntimeError(f"No segmentation images found under {self.images_dir}")

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.image_paths)

    def _mask_path(self, image_path: Path) -> Path:
        """Return the expected mask location for ``image_path``."""
        suffix = self.cfg.segmentation_suffix
        return self.masks_dir / f"{image_path.stem}{suffix}.png"

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """Load, jointly transform and return one image / mask pair.

        Raises:
            FileNotFoundError: the mask file for the image does not exist.
        """
        source = self.image_paths[index]
        target = self._mask_path(source)
        if not target.exists():
            raise FileNotFoundError(f"Missing mask for {source.name} at {target}")
        # Random augmentation is only active for the training split.
        is_train = self.split == "train"
        image_tensor, mask_tensor = apply_segmentation_transforms(
            Image.open(source), Image.open(target), self.cfg, train=is_train
        )
        sample = {"image": image_tensor, "mask": mask_tensor, "image_id": source.stem}
        return sample

In [104]:
# DataModule-style helpers
def create_datasets(cfg: Config) -> Tuple[Dict[str, Dataset], Dict[str, Dataset]]:
    """Instantiate the classification and segmentation datasets of every split.

    Returns:
        ``(classification, segmentation)``: two dicts keyed by ``"train"``,
        ``"val"`` and ``"test"``.
    """
    split_names = ("train", "val", "test")
    # All classification datasets are built before any segmentation dataset.
    classification: Dict[str, Dataset] = {}
    for name in split_names:
        classification[name] = ISICClassificationDataset(cfg, name)
    segmentation: Dict[str, Dataset] = {}
    for name in split_names:
        segmentation[name] = ISICSegmentationDataset(cfg, name)
    return classification, segmentation


def create_dataloaders(
    cfg: Config,
    classification: Dict[str, Dataset],
    segmentation: Dict[str, Dataset],
) -> Tuple[Dict[str, DataLoader], Dict[str, DataLoader]]:
    """Wrap every dataset in a ``DataLoader`` keyed by its split name.

    Only the ``"train"`` split is shuffled and has its incomplete final batch
    dropped. Classification loaders use ``cfg.batch_size``; segmentation
    loaders use ``cfg.segmentation_batch_size``.
    """

    def _wrap(datasets: Dict[str, Dataset], batch_size: int) -> Dict[str, DataLoader]:
        loaders: Dict[str, DataLoader] = {}
        for split, dataset in datasets.items():
            is_train = split == "train"
            loaders[split] = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=is_train,
                drop_last=is_train,
            )
        return loaders

    classification_loaders = _wrap(classification, cfg.batch_size)
    segmentation_loaders = _wrap(segmentation, cfg.segmentation_batch_size)
    return classification_loaders, segmentation_loaders

# Build the per-split datasets, then one DataLoader per dataset.
classification, segmentation = create_datasets(cfg)
classification_loader, segmentation_loader = create_dataloaders(cfg, classification, segmentation) 
# Despite the singular names, each of the two variables is a dict of the form
# {"train": train_loader, "val": val_loader, "test": test_loader}.
print("Loaders created succesfully")

Loaders created succesfully


## Dataset sanity check

In [105]:
# Quick dataset sanity check
# Re-create the datasets and print the number of samples per split for each task.
classification_datasets, segmentation_datasets = create_datasets(cfg)
print({split: len(ds) for split, ds in classification_datasets.items()})
print({split: len(ds) for split, ds in segmentation_datasets.items()})
# Fetch the first training sample of each task to check tensor shapes and ids.
sample_cls = classification_datasets["train"][0]
sample_seg = segmentation_datasets["train"][0]
print("Classification sample:", sample_cls["image"].shape, sample_cls["label_one_hot"], sample_cls["image_id"])
print("Segmentation sample:", sample_seg["image"].shape, sample_seg["mask"].shape, sample_seg["image_id"])

{'train': 2594, 'val': 100, 'test': 1000}
{'train': 2594, 'val': 100, 'test': 1000}
Classification sample: torch.Size([3, 256, 256]) tensor([0., 1., 0., 0., 0., 0., 0.]) train/input/ISIC_0000000.jpg
Segmentation sample: torch.Size([3, 256, 256]) torch.Size([1, 256, 256]) ISIC_0000000


# Model definition

In [None]:
# Model components
class ConvBNReLU(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and ReLU."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
        super().__init__()
        # The convolution has no bias term; the batch norm that follows has its own.
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        normalisation = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.block = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, normalisation and activation in that order."""
        activated = self.block(x)
        return activated


class SegmentationDecoder(nn.Module):
    """Top-down decoder that fuses five encoder feature maps into mask logits.

    ``channels`` lists the channel counts of the five feature maps from
    shallowest to deepest. The deepest map is projected to 512 channels; each
    following stage upsamples to the next shallower map, concatenates it and
    applies a ``ConvBNReLU`` (256, 128, 96, then 64 channels). A 1x1
    convolution produces ``num_classes`` output channels.
    """

    def __init__(self, channels: Tuple[int, int, int, int, int], num_classes: int):
        super().__init__()
        ch0, ch1, ch2, ch3, ch4 = channels
        self.block4 = ConvBNReLU(ch4, 512, kernel_size=1, padding=0)
        self.block3 = ConvBNReLU(512 + ch3, 256)
        self.block2 = ConvBNReLU(256 + ch2, 128)
        self.block1 = ConvBNReLU(128 + ch1, 96)
        self.block0 = ConvBNReLU(96 + ch0, 64)
        self.head = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, features: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Decode ``features`` (shallowest first) into logits at the shallowest map's size."""
        f0, f1, f2, f3, f4 = features
        x = self.block4(f4)
        # Walk from the second-deepest skip connection up to the shallowest.
        fusion_steps = (
            (f3, self.block3),
            (f2, self.block2),
            (f1, self.block1),
            (f0, self.block0),
        )
        for skip, fuse in fusion_steps:
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            x = fuse(torch.cat([x, skip], dim=1))
        return self.head(x)


class ClassificationHead(nn.Module):
    """Global average pooling, dropout (p=0.3) and a linear layer."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.3)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one logit per class for every sample in ``x``."""
        # Pool each channel to a single value, then drop the 1x1 spatial axes.
        pooled = torch.flatten(self.pool(x), start_dim=1)
        logits = self.fc(self.dropout(pooled))
        return logits


class MultiTaskResNet50(nn.Module):
    """ResNet-50 encoder shared by a classification head and a segmentation decoder.

    Args:
        num_classes: number of classification logits.
        num_segmentation_classes: number of channels of the segmentation logits.
        trainable_backbone_layers: how many of the five encoder stages (stem,
            layer1..layer4), counted from the deepest, keep ``requires_grad``.
        use_pretrained: load ``ResNet50_Weights.DEFAULT`` when true.
    """

    def __init__(
        self,
        num_classes: int,
        num_segmentation_classes: int = 1,
        trainable_backbone_layers: int = 2,
        use_pretrained: bool = True,
    ):
        super().__init__()
        encoder = resnet50(weights=ResNet50_Weights.DEFAULT if use_pretrained else None)
        # Re-register the encoder pieces on this module so the heads can tap them.
        self.initial = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
        self.maxpool = encoder.maxpool
        self.layer1 = encoder.layer1
        self.layer2 = encoder.layer2
        self.layer3 = encoder.layer3
        self.layer4 = encoder.layer4
        self._set_trainable_layers(trainable_backbone_layers)
        self.classifier = ClassificationHead(in_channels=2048, num_classes=num_classes)
        decoder_channels = (64, 256, 512, 1024, 2048)
        self.segmentation_decoder = SegmentationDecoder(
            channels=decoder_channels, num_classes=num_segmentation_classes
        )

    def _set_trainable_layers(self, trainable_backbone_layers: int) -> None:
        """Freeze all encoder stages except the last ``trainable_backbone_layers``.

        Raises:
            ValueError: the count is outside 1..5.
        """
        below_minimum = trainable_backbone_layers < 1
        above_maximum = trainable_backbone_layers > 5
        if below_minimum or above_maximum:
            raise ValueError("trainable_backbone_layers must be between 1 and 5 for ResNet-50")
        stages = [self.initial, self.layer1, self.layer2, self.layer3, self.layer4]
        trainable_ids = {id(stage) for stage in stages[-trainable_backbone_layers:]}
        for stage in stages:
            keep_gradients = id(stage) in trainable_ids
            for parameter in stage.parameters():
                parameter.requires_grad = keep_gradients

    def forward(self, x: torch.Tensor, task: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """Run the encoder and the requested head(s).

        Args:
            x: input batch; its last two dimensions give the size of the
                returned segmentation logits.
            task: ``"classification"``, ``"segmentation"`` or a falsy value to
                run both heads.

        Returns:
            Dict with a ``"classification"`` and/or ``"segmentation"`` entry.

        Raises:
            ValueError: ``task`` is truthy but not one of the two task names.
        """
        known_tasks = {"classification", "segmentation"}
        if task and task not in known_tasks:
            raise ValueError("task must be 'classification', 'segmentation', or None")
        output_size = x.shape[-2:]

        # The shallowest decoder feature is the max-pooled stem output.
        pyramid = [self.maxpool(self.initial(x))]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            pyramid.append(stage(pyramid[-1]))

        run_everything = not task
        outputs: Dict[str, torch.Tensor] = {}
        if run_everything or task == "classification":
            outputs["classification"] = self.classifier(pyramid[-1])
        if run_everything or task == "segmentation":
            coarse_logits = self.segmentation_decoder(tuple(pyramid))
            # Bring the decoder output back to the input resolution.
            outputs["segmentation"] = F.interpolate(
                coarse_logits, size=output_size, mode="bilinear", align_corners=False
            )
        return outputs

In [None]:
# Losses and metrics
def build_loss_functions(cfg: Config, device: torch.device) -> Tuple[nn.Module, nn.Module]:
    """Create the two task losses.

    Returns:
        ``(classification_loss, segmentation_loss)``: a cross-entropy loss and
        a binary cross-entropy-with-logits loss whose positive weight is
        ``cfg.segmentation_positive_weight`` (held on ``device``).
    """
    cross_entropy = nn.CrossEntropyLoss()
    positive_weight = torch.tensor([cfg.segmentation_positive_weight], device=device)
    weighted_bce = nn.BCEWithLogitsLoss(pos_weight=positive_weight)
    return cross_entropy, weighted_bce


def classification_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max over dim 1 equals the label.

    Returns 0.0 when ``labels`` is empty.
    """
    hits = logits.argmax(dim=1).eq(labels).sum().item()
    # Avoid dividing by zero on an empty batch.
    denominator = labels.numel() or 1
    return hits / denominator


def dice_score(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5, eps: float = 1e-6) -> float:
    """Return the mean per-sample Dice coefficient of thresholded predictions.

    Logits go through a sigmoid and are binarised with ``threshold``. Sums are
    taken over dimensions 1-3, so both tensors must be 4-D. ``eps`` is added to
    numerator and denominator, which makes an empty prediction on an empty
    target score 1.
    """
    reduce_dims = (1, 2, 3)
    binary_prediction = torch.sigmoid(logits).gt(threshold).float()
    overlap = torch.sum(binary_prediction * targets, dim=reduce_dims)
    # Sum of both areas (the Dice denominator), not a set union.
    combined_area = torch.sum(binary_prediction, dim=reduce_dims) + torch.sum(targets, dim=reduce_dims)
    per_sample = (2 * overlap + eps) / (combined_area + eps)
    return per_sample.mean().item()

In [None]:
# Training utilities
def create_optimizer(model: nn.Module, cfg: Config) -> Tuple[torch.optim.Optimizer, CosineAnnealingLR]:
    """Create AdamW over the trainable parameters plus a cosine LR schedule.

    Parameters with ``requires_grad == False`` are not handed to the optimizer.
    The schedule uses ``cfg.scheduler_period`` as ``T_max`` and
    ``cfg.min_learning_rate`` as the floor.
    """
    trainable = list(filter(lambda parameter: parameter.requires_grad, model.parameters()))
    optimizer = torch.optim.AdamW(
        trainable,
        lr=cfg.base_learning_rate,
        weight_decay=cfg.weight_decay,
    )
    schedule = CosineAnnealingLR(optimizer, T_max=cfg.scheduler_period, eta_min=cfg.min_learning_rate)
    return optimizer, schedule


def _next_batch(
    iterator: Iterator[Dict[str, torch.Tensor]], loader: DataLoader
) -> Tuple[Dict[str, torch.Tensor], Iterator[Dict[str, torch.Tensor]]]:
    try:
        batch = next(iterator)
    except StopIteration:
        iterator = iter(loader)
        batch = next(iterator)
    return batch, iterator


def train_one_epoch(
    model: nn.Module,
    classification_loader: DataLoader,
    segmentation_loader: DataLoader,
    classification_loss_fn: nn.Module,
    segmentation_loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    cfg: Config,
    device: torch.device,
    scaler: Optional[GradScaler] = None,
) -> Dict[str, float]:
    """Train for one epoch on both tasks jointly.

    Every step draws one classification batch and one segmentation batch,
    runs each through the model with the matching ``task``, and optimises
    ``cfg.classification_loss_weight * cls_loss + cfg.segmentation_loss_weight * seg_loss``.
    The epoch lasts as many steps as the longer loader; the shorter loader is
    restarted whenever it runs out.

    Args:
        scaler: gradient scaler; mixed precision is used only when a scaler is
            given and it is enabled.

    Returns:
        Per-step averages under the keys ``loss``, ``classification_loss``,
        ``segmentation_loss``, ``classification_accuracy`` and
        ``segmentation_dice``.
    """
    model.train()
    use_amp = scaler is not None and scaler.is_enabled()
    clip_gradients = bool(cfg.grad_clip_norm and cfg.grad_clip_norm > 0)
    classification_iter = iter(classification_loader)
    segmentation_iter = iter(segmentation_loader)
    max_steps = max(len(classification_loader), len(segmentation_loader))
    total_loss = 0.0
    total_cls_loss = 0.0
    total_seg_loss = 0.0
    total_accuracy = 0.0
    total_dice = 0.0
    for _ in range(max_steps):
        optimizer.zero_grad(set_to_none=True)
        classification_batch, classification_iter = _next_batch(classification_iter, classification_loader)
        segmentation_batch, segmentation_iter = _next_batch(segmentation_iter, segmentation_loader)
        classification_images = classification_batch["image"].to(device, non_blocking=True)
        classification_labels = classification_batch["label"].to(device, non_blocking=True)
        segmentation_images = segmentation_batch["image"].to(device, non_blocking=True)
        segmentation_masks = segmentation_batch["mask"].to(device, non_blocking=True)
        # torch.autocast is the device-agnostic context manager; the
        # torch.cuda.amp.autocast class does not accept ``device_type``.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            classification_outputs = model(classification_images, task="classification")["classification"]
            cls_loss = classification_loss_fn(classification_outputs, classification_labels)
            segmentation_outputs = model(segmentation_images, task="segmentation")["segmentation"]
            seg_loss = segmentation_loss_fn(segmentation_outputs, segmentation_masks)
            loss = cfg.classification_loss_weight * cls_loss + cfg.segmentation_loss_weight * seg_loss
        if use_amp:
            scaler.scale(loss).backward()
            if clip_gradients:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if clip_gradients:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
            optimizer.step()
        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_seg_loss += seg_loss.item()
        total_accuracy += classification_accuracy(classification_outputs.detach(), classification_labels)
        total_dice += dice_score(segmentation_outputs.detach(), segmentation_masks)
    # Guard against a zero-length epoch (both loaders empty).
    steps = float(max(max_steps, 1))
    return {
        "loss": total_loss / steps,
        "classification_loss": total_cls_loss / steps,
        "segmentation_loss": total_seg_loss / steps,
        "classification_accuracy": total_accuracy / steps,
        "segmentation_dice": total_dice / steps,
    }


@torch.no_grad()
def evaluate(
    model: nn.Module,
    classification_loader: DataLoader,
    segmentation_loader: DataLoader,
    classification_loss_fn: nn.Module,
    segmentation_loss_fn: nn.Module,
    cfg: Config,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate both tasks with gradients disabled and the model in eval mode.

    Each loader is iterated once. Per-batch metrics are weighted by the batch
    size before averaging, so a smaller final batch does not skew the result.
    ``cfg`` is accepted for signature symmetry and is not read.

    Returns:
        Averages under the keys ``classification_loss``,
        ``classification_accuracy``, ``segmentation_loss`` and
        ``segmentation_dice`` (0.0 for a task whose loader is empty).
    """
    model.eval()
    cls_loss_total = 0.0
    cls_acc_total = 0.0
    cls_samples = 0
    for batch in classification_loader:
        images = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)
        batch_size = images.size(0)
        outputs = model(images, task="classification")["classification"]
        cls_loss_total += classification_loss_fn(outputs, labels).item() * batch_size
        cls_acc_total += classification_accuracy(outputs, labels) * batch_size
        cls_samples += batch_size
    seg_loss_total = 0.0
    seg_dice_total = 0.0
    seg_samples = 0
    for batch in segmentation_loader:
        images = batch["image"].to(device, non_blocking=True)
        masks = batch["mask"].to(device, non_blocking=True)
        batch_size = images.size(0)
        outputs = model(images, task="segmentation")["segmentation"]
        seg_loss_total += segmentation_loss_fn(outputs, masks).item() * batch_size
        seg_dice_total += dice_score(outputs, masks) * batch_size
        seg_samples += batch_size
    # Avoid dividing by zero when a loader yields nothing.
    cls_samples = max(cls_samples, 1)
    seg_samples = max(seg_samples, 1)
    return {
        "classification_loss": cls_loss_total / cls_samples,
        "classification_accuracy": cls_acc_total / cls_samples,
        "segmentation_loss": seg_loss_total / seg_samples,
        "segmentation_dice": seg_dice_total / seg_samples,
    }

In [None]:
# High-level training loop
def fit(
    cfg: Config,
    device: torch.device,
    output_dir: Optional[Path] = None,
    resume_from: Optional[Path] = None,
    save_checkpoints: bool = True,
) -> Dict[str, object]:
    """Build data, model and optimiser, then train for ``cfg.max_epochs`` epochs.

    Args:
        cfg: run configuration.
        device: device the model and batches are moved to.
        output_dir: directory for per-epoch checkpoints; created if missing.
            No checkpoint is written when it is ``None``.
        resume_from: checkpoint written by a previous call; restores model,
            optimiser, scheduler, scaler (if present), epoch counter and best
            validation Dice.
        save_checkpoints: set to ``False`` to skip writing checkpoints.

    Returns:
        Dict with ``history`` (``{"train": [...], "val": [...]}`` of per-epoch
        metric dicts), ``model`` (in eval mode), ``classification_loaders``,
        ``segmentation_loaders`` and ``best_val_dice``.
    """
    classification_datasets, segmentation_datasets = create_datasets(cfg)
    classification_loaders, segmentation_loaders = create_dataloaders(
        cfg, classification_datasets, segmentation_datasets
    )
    model = MultiTaskResNet50(num_classes=len(cfg.class_names)).to(device)
    classification_loss_fn, segmentation_loss_fn = build_loss_functions(cfg, device)
    optimizer, scheduler = create_optimizer(model, cfg)
    # Mixed precision is only used on CUDA devices.
    use_amp = cfg.mixed_precision and device.type == "cuda"
    scaler = GradScaler(enabled=use_amp) if use_amp else None
    start_epoch = 0
    best_val_dice = 0.0
    history: Dict[str, List[Dict[str, float]]] = {"train": [], "val": []}
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
    if resume_from:
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        if scaler and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])
        # The stored epoch is the last completed one, so continue with the next.
        start_epoch = checkpoint.get("epoch", 0) + 1
        best_val_dice = checkpoint.get("best_val_dice", best_val_dice)
        print(f"Resumed training from {resume_from} at epoch {start_epoch}")
    for epoch in range(start_epoch, cfg.max_epochs):
        train_metrics = train_one_epoch(
            model,
            classification_loaders["train"],
            segmentation_loaders["train"],
            classification_loss_fn,
            segmentation_loss_fn,
            optimizer,
            cfg,
            device,
            scaler,
        )
        val_metrics = evaluate(
            model,
            classification_loaders["val"],
            segmentation_loaders["val"],
            classification_loss_fn,
            segmentation_loss_fn,
            cfg,
            device,
        )
        scheduler.step()
        history["train"].append(train_metrics)
        history["val"].append(val_metrics)
        print(
            f"Epoch {epoch + 1}/{cfg.max_epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Val Acc: {val_metrics['classification_accuracy']:.4f} | "
            f"Val Dice: {val_metrics['segmentation_dice']:.4f}"
        )
        current_val_dice = val_metrics["segmentation_dice"]
        if save_checkpoints and output_dir:
            # Store a copy of the config with Path values as strings: a
            # pathlib object inside the checkpoint cannot be read back by
            # torch.load when weights_only loading is in effect.
            cfg_snapshot = {
                key: str(value) if isinstance(value, Path) else value
                for key, value in cfg.__dict__.items()
            }
            checkpoint = {
                "epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "cfg": cfg_snapshot,
                "train_metrics": train_metrics,
                "val_metrics": val_metrics,
                "best_val_dice": max(best_val_dice, current_val_dice),
            }
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            torch.save(checkpoint, output_dir / f"multitask_resnet50_epoch_{epoch + 1:03d}.pth")
        if current_val_dice > best_val_dice:
            best_val_dice = current_val_dice
    model.eval()
    return {
        "history": history,
        "model": model,
        "classification_loaders": classification_loaders,
        "segmentation_loaders": segmentation_loaders,
        "best_val_dice": best_val_dice,
    }

In [None]:
# Example training call (disabled by default)
# results = fit(
#     cfg,
#     device,
#     output_dir=Path("artifacts/multitask"),
#     resume_from=None,
#     save_checkpoints=True,
# )
# best_model = results["model"]
# history = results["history"]