# Setup


In [1]:
# Reload edited modules (e.g. the local `isic` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
params = {
    "epochs": 8,
    "batch_size": 128,
    "learning_rate": 0.001,
    "class_power": 0.9,  # in [0, 1]: exponent applied to the neg/pos ratio to get the positive class weight (0 removes class weighting, 1 uses the raw ratio)
    "focal_power": 2,  # focusing parameter (gamma) in focal loss (default 2)
    "image_size": 128,  # target image size; images are resized to fit a square whose sides are this length
    "threshold": 0.5,  # probability threshold for positive classification
    "seed": 42,  # rng seed for reproducibility
    # model architecture; each cnn layer is (out_channels, kernel_size, pool)
    "cnn_layers": [
        (16, 5, True),  # 128×128×3 → 64×64×16 (5×5 kernel, pool)
        (32, 3, True),  # 64×64×16 → 32×32×32 (3×3 kernel, pool)
        (64, 3, True),  # 32×32×32 → 16×16×64 (3×3 kernel, pool)
        (64, 3, True),  # 16×16×64 → 8×8×64 (3×3 kernel, pool)
        (32, 3, True),  # 8×8×64 → 4×4×32 (final: 512 features)
    ],
    "metadata_layer_dims": [8, 16, 32],  # accepts metadata tensor from dataloader
    "fusion_layer_dims": [256, 128, 64, 8],  # fuse encoded image & metadata
}

epochs = params["epochs"]
batch_size = params["batch_size"]
lr = params["learning_rate"]
class_power = params["class_power"]
focal_power = params["focal_power"]

img_size = params["image_size"], params["image_size"]
image_shape = params["image_size"], params["image_size"], 3
threshold = params["threshold"]
seed = params["seed"]

# model architecture params
cnn_layers = params["cnn_layers"]
metadata_layer_dims = params["metadata_layer_dims"]
fusion_layer_dims = params["fusion_layer_dims"]

# Fail fast on out-of-range settings rather than deep inside training.
assert 0 <= class_power <= 1, "class_power must lie in [0, 1]"
assert 0 <= threshold <= 1, "threshold must be a probability in [0, 1]"

In [3]:
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device

device(type='cuda')

In [4]:
import torch

# Dedicated, independently seeded RNG that data-loading code can be given for
# reproducible shuffling, plus a seed for torch's global RNG.
generator = torch.Generator()
generator.manual_seed(seed)
torch.manual_seed(seed)

In [5]:
import torch

# Trade some float32 matmul precision for speed: per the PyTorch docs, "high"
# permits faster internal formats (e.g. TF32) where the backend supports them.
torch.set_float32_matmul_precision("high")

# Dataset


In [6]:
from datasets import load_dataset
from isic.dataset import ImageEncoder, MetadataEncoder, collate_batch

# Training split, restricted to the image, three metadata fields and the label.
ds = load_dataset("mrbrobot/isic-2024", split="train").select_columns(
    ["image", "age_approx", "sex", "anatom_site_general", "target"]
)

len(ds)

401059

In [7]:
# encode metadata
# The dataset is put in Arrow format so that the batched `map` passes the
# encoder Arrow batches of 1000 rows.
# NOTE(review): the encoder is fitted on every row, including those that the
# train/test split further down holds out for validation. Consider fitting on
# the training split only to avoid leaking validation statistics.
metadata_encoder = MetadataEncoder().fit(ds)
ds = ds.with_format("arrow")
ds = ds.map(
    metadata_encoder,
    batched=True,
    batch_size=1000,
    desc="Encoding metadata columns",
)

# encode images (lazily, on access)
# `with_transform` installs a custom format that replaces any format set
# earlier, so setting e.g. a "torch" format just before it has no effect. The
# transform is applied only to "image"; the other columns are passed through
# (output_all_columns=True) for the DataLoader's collate_fn to handle.
image_encoder = ImageEncoder(image_size=img_size)
ds = ds.with_transform(image_encoder, columns=["image"], output_all_columns=True)

# Model Definition


In [8]:
from isic.models import FusionModel

# Fusion model: a CNN image branch (cnn_layers), a metadata branch
# (metadata_layer_dims) and fusion layers (fusion_layer_dims), all set in params.
model = FusionModel(
    fusion_layer_dims=fusion_layer_dims,
    metadata_layer_dims=metadata_layer_dims,
    cnn_layers=cnn_layers,
    image_shape=image_shape,
)
model = model.to(device)

model

FusionModel(
  (image_stack): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1), padding=same)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): SiLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): SiLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14

In [9]:
# Despite the name, this counts only trainable parameters (requires_grad=True).
total_params = sum(
    param.numel() for param in filter(lambda p: p.requires_grad, model.parameters())
)

total_params

262961

# Training


In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Report where the model lives and its parameter counts (all vs. trainable).
model_params = list(model.parameters())
n_params = sum(p.numel() for p in model_params)
n_trainable = sum(p.numel() for p in model_params if p.requires_grad)

print(f"Model device: {model_params[0].device}")
print(f"Total parameters: {n_params:,}")
print(f"Trainable parameters: {n_trainable:,}")

Model device: cuda:0
Total parameters: 262,961
Trainable parameters: 262,961


## Class imbalance measurement & handling


In [11]:
# [benign, malignant] counts, read from the label column instead of being
# hard-coded from EDA (which reported [400666, 393]), so they always describe
# the data actually loaded. Only "target" is selected, so no images are
# converted.
label_counts = ds.select_columns(["target"]).to_pandas()["target"].value_counts()
class_counts = [int(label_counts.get(0, 0)), int(label_counts.get(1, 0))]
assert class_counts[1] > 0, "no malignant samples; imbalance ratio is undefined"

print("Class Distribution:")
print(f"Benign: {class_counts[0]:,} samples")
print(f"Malignant: {class_counts[1]:,} samples")
print(f"Imbalance ratio: {class_counts[0] / class_counts[1]:.1f}:1")

Class Distribution:
Benign: 400,666 samples
Malignant: 393 samples
Imbalance ratio: 1019.5:1


In [12]:
# Only the label column is needed. Converting the whole dataset would copy
# every stored image into pandas just to count labels.
df = ds.select_columns(["target"]).to_pandas()
neg_count = (df["target"] == 0).sum()
pos_count = (df["target"] == 1).sum()
assert pos_count > 0, "no positive samples in `target`; pos_weight would be undefined"
# Raw negative:positive ratio; it is raised to `class_power` before use.
pos_weight = neg_count / pos_count

print(f"Positive weight: {pos_weight:.1f}")
print(f"Scaled positive class weight: {pos_weight**class_power:.1f}")

Positive weight: 1019.5
Scaled positive class weight: 510.0


In [13]:
from isic.loss import WeightedFocalLoss

# Temper the raw neg/pos ratio by `class_power`, place it on the training
# device, and build the focal loss with focusing parameter `focal_power`.
scaled_pos_weight = torch.tensor([pos_weight**class_power], device=device)
criterion = WeightedFocalLoss(gamma=focal_power, pos_weight=scaled_pos_weight)

In [14]:
from torch.utils.data import DataLoader

# 80/20 random split, seeded for reproducibility.
# NOTE(review): the split is not stratified. With only a few hundred positive
# samples, the number of positives in validation depends on the seed. Consider
# `stratify_by_column="target"`, which requires a ClassLabel feature.
split = ds.train_test_split(test_size=0.2, seed=seed)
train_ds, val_ds = split["train"], split["test"]

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_batch,
    # Use the seeded generator from the setup cells so the shuffle order is
    # reproducible regardless of how much global RNG state was consumed earlier.
    generator=generator,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_batch,
)

print(f"Batches per epoch - Train: {len(train_loader)}, Val: {len(val_loader)}")

Batches per epoch - Train: 2507, Val: 627


In [None]:
import trackio

# Start a trackio run in the "fusion" project, recording the full `params`
# dict as the run configuration.
trackio.init(
    project="fusion",
    config=params,
    embed=False,
);

* Trackio project initialized: fusion
* Trackio metrics logged to: /home/vscode/.cache/huggingface/trackio
* View dashboard by running in your terminal:
[1m[38;5;208mtrackio show --project "fusion"[0m
* or by running in Python: trackio.show(project="fusion")
* Created new run: dainty-sunset-0


: 

In [None]:
from rich.console import Console
from isic.training import train, validate

# Console used for the per-epoch rule; it is also handed to train()/validate()
# for their progress output.
console = Console(force_jupyter=False)

for epoch in range(epochs):
    epoch_title = f"[bold]Epoch {epoch + 1}/{epochs}"
    console.rule(epoch_title)

    # training pass over train_loader (return value unused)
    _ = train(
        model, train_loader, criterion, optimizer, device, threshold, console=console
    )

    # validation pass over val_loader (return value unused)
    _ = validate(model, val_loader, criterion, device, threshold, console=console)

[92m────────────────────────────────── [0m[1mEpoch [0m[1;36m1[0m[1m/[0m[1;36m8[0m[92m ───────────────────────────────────[0m
[2KTraining [91m━━━━━━[0m[91m╸[0m[90m━━━[0m [35m 67%[0m (1673/2507) Loss: [36m0.4905[0m Prec: [32m0.001[0m Rec: [33m0.539[0m [33m0:22:10[0m

In [None]:
# Close the trackio run opened by trackio.init() earlier in the notebook.
trackio.finish()