# Setup


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
params = {
    "epochs": 8,
    "batch_size": 128,
    "learning_rate": 0.001,
    "class_power": 0.9,  # between 0 and 1 (inclusive), scales positive class weight (0 removes class weighting, 1 leaves class ratio unchanged)
    "focal_power": 2,  # focusing parameter (gamma) in focal loss (default 2)
    "image_size": 128,  # target image size, images resized to fit in square where sides or of this length
    "threshold": 0.5,  # probability threshold for positive classification
    "seed": 42,  # rng seed for reproducibility
    # model architecture
    "image_layer_dims": [128, 64, 32],  # accepts flattened image
    "metadata_layer_dims": [8, 16, 32],  # accepts metatadata tensor from dataloader
    "fusion_layer_dims": [16, 8],  # accepts concat of encoded image & metadata
}

epochs = params["epochs"]
batch_size = params["batch_size"]
lr = params["learning_rate"]
class_power = params["class_power"]
focal_power = params["focal_power"]

img_size = params["image_size"], params["image_size"]
image_shape = params["image_size"], params["image_size"], 3
threshold = params["threshold"]
seed = params["seed"]

# model architecture params
image_layer_dims = params["image_layer_dims"]
metadata_layer_dims = params["metadata_layer_dims"]
fusion_layer_dims = params["fusion_layer_dims"]

In [3]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cuda')

In [4]:
import torch

torch.manual_seed(seed)
generator = torch.Generator().manual_seed(seed)

In [5]:
import torch

torch.set_float32_matmul_precision("high")

# Dataset


In [6]:
from datasets import load_dataset
from isic.dataset import ImageEncoder, MetadataEncoder, collate_batch

ds = load_dataset("mrbrobot/isic-2024", split="train")
ds = ds.select_columns(["image", "age_approx", "sex", "anatom_site_general", "target"])

len(ds)

401059

In [7]:
# encode metadata
metadata_encoder = MetadataEncoder().fit(ds)
ds = ds.with_format("arrow")
ds = ds.map(
    metadata_encoder,
    batched=True,
    batch_size=1000,
    desc="Encoding metadata columns",
)

# encode images
image_encoder = ImageEncoder(image_size=img_size)
ds = ds.with_format("torch")
ds = ds.with_transform(image_encoder, columns=["image"], output_all_columns=True)

# Model Definition


In [8]:
from isic.models import FusionMLPModel

model = FusionMLPModel(
    image_shape=image_shape,
    image_layer_dims=image_layer_dims,
    metadata_layer_dims=metadata_layer_dims,
    fusion_layer_dims=fusion_layer_dims,
).to(device)

model

FusionMLPModel(
  (image_stack): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): MLP(
      (network): Sequential(
        (0): Linear(in_features=49152, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=64, bias=True)
        (3): ReLU()
        (4): Linear(in_features=64, out_features=32, bias=True)
        (5): ReLU()
      )
    )
  )
  (metadata_stack): MLP(
    (network): Sequential(
      (0): Linear(in_features=8, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=32, bias=True)
      (3): ReLU()
    )
  )
  (fusion_head): Sequential(
    (0): MLP(
      (network): Sequential(
        (0): Linear(in_features=64, out_features=16, bias=True)
        (1): ReLU()
        (2): Linear(in_features=16, out_features=8, bias=True)
        (3): ReLU()
      )
    )
    (1): Linear(in_features=8, out_features=1, bias=True)
  )
)

In [9]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params

6303793

# Training


In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

print(f"Model device: {next(model.parameters()).device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(
    f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}"
)

Model device: cuda:0
Total parameters: 6,303,793
Trainable parameters: 6,303,793


Class imbalance measurement & handling


In [11]:
class_counts = [400666, 393]  # [benign, malignant] from EDA

print("Class Distribution:")
print(f"Benign: {class_counts[0]:,} samples")
print(f"Malignant: {class_counts[1]:,} samples")
print(f"Imbalance ratio: {class_counts[0] / class_counts[1]:.1f}:1")

Class Distribution:
Benign: 400,666 samples
Malignant: 393 samples
Imbalance ratio: 1019.5:1


In [12]:
df = ds.to_pandas()
neg_count = (df["target"] == 0).sum()
pos_count = (df["target"] == 1).sum()
pos_weight = neg_count / pos_count

print(f"Positive weight: {pos_weight:.1f}")
print(f"Scaled positive class weight: {pos_weight**class_power:.1f}")

Positive weight: 1019.5
Scaled positive class weight: 510.0


In [13]:
from isic.loss import WeightedFocalLoss

scaled_pos_weight = torch.tensor([pos_weight**class_power], device=device)
criterion = WeightedFocalLoss(pos_weight=scaled_pos_weight, gamma=focal_power)

In [14]:
from torch.utils.data import DataLoader

split = ds.train_test_split(test_size=0.2, seed=seed)
train_ds, val_ds = split["train"], split["test"]

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_batch,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_batch,
)

print(f"Batches per epoch - Train: {len(train_loader)}, Val: {len(val_loader)}")

Batches per epoch - Train: 2507, Val: 627


In [15]:
import trackio
from isic.training import train, validate

trackio.init(project="mlp", config=params, embed=False)

for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    print("-" * 60)

    # train
    train_metrics = train(model, train_loader, criterion, optimizer, device, threshold)

    # validate
    val_metrics, confusion_mat = validate(
        model, val_loader, criterion, device, threshold
    )

trackio.finish()

* Trackio project initialized: mlp
* Trackio metrics logged to: /home/vscode/.cache/huggingface/trackio
* View dashboard by running in your terminal:
[1m[38;5;208mtrackio show --project "mlp"[0m
* or by running in Python: trackio.show(project="mlp")
* Created new run: brave-forest-1

Epoch 1/8
------------------------------------------------------------




Batch   0/2507: Loss: 0.1737 | Precision: 0.000 | Recall: 0.000
Batch 100/2507: Loss: 2.9075 | Precision: 0.001 | Recall: 0.286
Batch 200/2507: Loss: 0.1519 | Precision: 0.001 | Recall: 0.357
Batch 300/2507: Loss: 0.1552 | Precision: 0.001 | Recall: 0.244
Batch 400/2507: Loss: 0.1337 | Precision: 0.001 | Recall: 0.200
Batch 500/2507: Loss: 3.3835 | Precision: 0.001 | Recall: 0.179
Batch 600/2507: Loss: 0.1652 | Precision: 0.001 | Recall: 0.149
Batch 700/2507: Loss: 0.1651 | Precision: 0.001 | Recall: 0.123
Batch 800/2507: Loss: 0.1632 | Precision: 0.001 | Recall: 0.135
Batch 900/2507: Loss: 0.1645 | Precision: 0.001 | Recall: 0.134
Batch 1000/2507: Loss: 0.1591 | Precision: 0.002 | Recall: 0.137
Batch 1100/2507: Loss: 0.1569 | Precision: 0.002 | Recall: 0.145
Batch 1200/2507: Loss: 3.2494 | Precision: 0.002 | Recall: 0.138
Batch 1300/2507: Loss: 0.1524 | Precision: 0.002 | Recall: 0.138
Batch 1400/2507: Loss: 0.1553 | Precision: 0.002 | Recall: 0.138
Batch 1500/2507: Loss: 0.1585 | Pre


Epoch 2/8
------------------------------------------------------------




Batch   0/2507: Loss: 0.1408 | Precision: 0.000 | Recall: 0.000
Batch 100/2507: Loss: 0.1347 | Precision: 0.006 | Recall: 0.375
Batch 200/2507: Loss: 0.1555 | Precision: 0.005 | Recall: 0.200
Batch 300/2507: Loss: 0.1406 | Precision: 0.004 | Recall: 0.154
Batch 400/2507: Loss: 0.1371 | Precision: 0.005 | Recall: 0.188
Batch 500/2507: Loss: 0.1321 | Precision: 0.005 | Recall: 0.196
Batch 600/2507: Loss: 0.1207 | Precision: 0.004 | Recall: 0.182
Batch 700/2507: Loss: 3.2922 | Precision: 0.004 | Recall: 0.162
Batch 800/2507: Loss: 0.1354 | Precision: 0.005 | Recall: 0.191
Batch 900/2507: Loss: 0.1439 | Precision: 0.005 | Recall: 0.198
Batch 1000/2507: Loss: 0.1235 | Precision: 0.005 | Recall: 0.186
Batch 1100/2507: Loss: 0.1325 | Precision: 0.005 | Recall: 0.181
Batch 1200/2507: Loss: 0.1298 | Precision: 0.005 | Recall: 0.172
Batch 1300/2507: Loss: 0.1362 | Precision: 0.005 | Recall: 0.182
Batch 1400/2507: Loss: 0.1697 | Precision: 0.005 | Recall: 0.181
Batch 1500/2507: Loss: 0.1355 | Pre




Epoch 3/8
------------------------------------------------------------
Batch   0/2507: Loss: 0.1416 | Precision: 0.000 | Recall: 0.000
Batch 100/2507: Loss: 0.1738 | Precision: 0.003 | Recall: 0.250
Batch 200/2507: Loss: 0.1300 | Precision: 0.003 | Recall: 0.227
Batch 300/2507: Loss: 0.1283 | Precision: 0.004 | Recall: 0.233
Batch 400/2507: Loss: 0.1225 | Precision: 0.004 | Recall: 0.220
Batch 500/2507: Loss: 0.1381 | Precision: 0.005 | Recall: 0.245
Batch 600/2507: Loss: 0.1406 | Precision: 0.004 | Recall: 0.200
Batch 700/2507: Loss: 0.1452 | Precision: 0.004 | Recall: 0.188
Batch 800/2507: Loss: 0.1451 | Precision: 0.005 | Recall: 0.217
Batch 900/2507: Loss: 0.1397 | Precision: 0.005 | Recall: 0.221
Batch 1000/2507: Loss: 3.3340 | Precision: 0.006 | Recall: 0.226
Batch 1100/2507: Loss: 0.1275 | Precision: 0.006 | Recall: 0.217
Batch 1200/2507: Loss: 0.1528 | Precision: 0.005 | Recall: 0.204
Batch 1300/2507: Loss: 0.1394 | Precision: 0.005 | Recall: 0.199
Batch 1400/2507: Loss: 0.171




Epoch 4/8
------------------------------------------------------------
Batch   0/2507: Loss: 0.1621 | Precision: 0.000 | Recall: 0.000
Batch 100/2507: Loss: 0.1790 | Precision: 0.009 | Recall: 0.357
Batch 200/2507: Loss: 0.1658 | Precision: 0.008 | Recall: 0.348
Batch 300/2507: Loss: 0.1547 | Precision: 0.009 | Recall: 0.316
Batch 400/2507: Loss: 0.1425 | Precision: 0.008 | Recall: 0.318
Batch 500/2507: Loss: 0.1126 | Precision: 0.007 | Recall: 0.255
Batch 600/2507: Loss: 0.1159 | Precision: 0.006 | Recall: 0.209
Batch 700/2507: Loss: 3.1672 | Precision: 0.007 | Recall: 0.214
Batch 800/2507: Loss: 0.1923 | Precision: 0.007 | Recall: 0.213
Batch 900/2507: Loss: 0.1407 | Precision: 0.006 | Recall: 0.192
Batch 1000/2507: Loss: 0.1449 | Precision: 0.006 | Recall: 0.183
Batch 1100/2507: Loss: 3.1733 | Precision: 0.006 | Recall: 0.199
Batch 1200/2507: Loss: 0.1556 | Precision: 0.007 | Recall: 0.197
Batch 1300/2507: Loss: 0.1616 | Precision: 0.006 | Recall: 0.194
Batch 1400/2507: Loss: 0.161




Epoch 5/8
------------------------------------------------------------
Batch   0/2507: Loss: 0.1616 | Precision: 0.000 | Recall: 0.000
Batch 100/2507: Loss: 0.1676 | Precision: 0.002 | Recall: 0.077
Batch 200/2507: Loss: 0.1457 | Precision: 0.003 | Recall: 0.080
Batch 300/2507: Loss: 0.1460 | Precision: 0.003 | Recall: 0.108
Batch 400/2507: Loss: 0.1478 | Precision: 0.006 | Recall: 0.196
Batch 500/2507: Loss: 0.1629 | Precision: 0.007 | Recall: 0.222
Batch 600/2507: Loss: 0.1257 | Precision: 0.006 | Recall: 0.208
Batch 700/2507: Loss: 0.1303 | Precision: 0.006 | Recall: 0.176
Batch 800/2507: Loss: 0.1367 | Precision: 0.006 | Recall: 0.177
Batch 900/2507: Loss: 0.1227 | Precision: 0.005 | Recall: 0.183
Batch 1000/2507: Loss: 0.1656 | Precision: 0.006 | Recall: 0.202
Batch 1100/2507: Loss: 0.1479 | Precision: 0.006 | Recall: 0.189
Batch 1200/2507: Loss: 0.1558 | Precision: 0.006 | Recall: 0.199
Batch 1300/2507: Loss: 0.1464 | Precision: 0.006 | Recall: 0.203
Batch 1400/2507: Loss: 0.136




Epoch 6/8
------------------------------------------------------------
Batch   0/2507: Loss: 0.1847 | Precision: 0.000 | Recall: 0.000
Batch 100/2507: Loss: 0.1637 | Precision: 0.005 | Recall: 0.125
Batch 200/2507: Loss: 0.1311 | Precision: 0.002 | Recall: 0.100
Batch 300/2507: Loss: 0.1346 | Precision: 0.005 | Recall: 0.171
Batch 400/2507: Loss: 0.1402 | Precision: 0.004 | Recall: 0.149
Batch 500/2507: Loss: 0.1373 | Precision: 0.005 | Recall: 0.164
Batch 600/2507: Loss: 0.1478 | Precision: 0.004 | Recall: 0.133
Batch 700/2507: Loss: 0.1341 | Precision: 0.004 | Recall: 0.148
Batch 800/2507: Loss: 3.5549 | Precision: 0.004 | Recall: 0.132
Batch 900/2507: Loss: 0.1352 | Precision: 0.005 | Recall: 0.165
Batch 1000/2507: Loss: 0.1367 | Precision: 0.004 | Recall: 0.150
Batch 1100/2507: Loss: 0.1149 | Precision: 0.004 | Recall: 0.148
Batch 1200/2507: Loss: 0.1423 | Precision: 0.005 | Recall: 0.163
Batch 1300/2507: Loss: 0.1323 | Precision: 0.005 | Recall: 0.169
Batch 1400/2507: Loss: 0.160


Epoch 7/8
------------------------------------------------------------
Batch   0/2507: Loss: 3.2681 | Precision: 0.000 | Recall: 0.000
Batch 100/2507: Loss: 0.1724 | Precision: 0.000 | Recall: 0.000
Batch 200/2507: Loss: 3.2049 | Precision: 0.000 | Recall: 0.000
Batch 300/2507: Loss: 0.1591 | Precision: 0.004 | Recall: 0.139
Batch 400/2507: Loss: 0.1439 | Precision: 0.003 | Recall: 0.104
Batch 500/2507: Loss: 0.1308 | Precision: 0.004 | Recall: 0.150
Batch 600/2507: Loss: 0.1368 | Precision: 0.004 | Recall: 0.128
Batch 700/2507: Loss: 0.1919 | Precision: 0.005 | Recall: 0.160
Batch 800/2507: Loss: 0.1625 | Precision: 0.005 | Recall: 0.167
Batch 900/2507: Loss: 1.1556 | Precision: 0.006 | Recall: 0.186
Batch 1000/2507: Loss: 0.1617 | Precision: 0.005 | Recall: 0.175
Batch 1100/2507: Loss: 0.1321 | Precision: 0.006 | Recall: 0.180
Batch 1200/2507: Loss: 0.1596 | Precision: 0.006 | Recall: 0.180
Batch 1300/2507: Loss: 0.1490 | Precision: 0.006 | Recall: 0.193
Batch 1400/2507: Loss: 0.119




Epoch 8/8
------------------------------------------------------------
Batch   0/2507: Loss: 0.1599 | Precision: 0.000 | Recall: 0.000
Batch 100/2507: Loss: 3.2465 | Precision: 0.005 | Recall: 0.167
Batch 200/2507: Loss: 0.1903 | Precision: 0.007 | Recall: 0.214
Batch 300/2507: Loss: 0.1639 | Precision: 0.007 | Recall: 0.282
Batch 400/2507: Loss: 0.1581 | Precision: 0.007 | Recall: 0.250
Batch 500/2507: Loss: 0.1408 | Precision: 0.006 | Recall: 0.238
Batch 600/2507: Loss: 0.1767 | Precision: 0.007 | Recall: 0.225
Batch 700/2507: Loss: 0.1431 | Precision: 0.006 | Recall: 0.215
Batch 800/2507: Loss: 0.1574 | Precision: 0.006 | Recall: 0.202
Batch 900/2507: Loss: 0.1465 | Precision: 0.006 | Recall: 0.211
Batch 1000/2507: Loss: 0.1551 | Precision: 0.006 | Recall: 0.208
Batch 1100/2507: Loss: 0.1468 | Precision: 0.007 | Recall: 0.208
Batch 1200/2507: Loss: 0.1424 | Precision: 0.006 | Recall: 0.210
Batch 1300/2507: Loss: 0.1640 | Precision: 0.007 | Recall: 0.209
Batch 1400/2507: Loss: 0.120

* Run finished. Uploading logs to Trackio (please wait...)
