In [24]:
import os, random
import numpy as np
import torch
import torch.nn as nn
from copy import deepcopy
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score, classification_report

In [18]:
DATA_DIR = "/Users/christinsilos/Desktop/Deep Learning Practice/train"
IMG_SIZE = 224
BATCH_SIZE = 32
VAL_SPLIT = 0.2
SEED = 123
EPOCHS = 15
LR = 3e-4
WEIGHT_DECAY = 1e-4

In [19]:
# -----------------------
# Reproducibility
# -----------------------
def seed_all(seed=123):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_all(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [20]:
# -----------------------
# Transforms (ResNet expects ImageNet-style normalization)
# -----------------------
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [21]:
# -----------------------
# Dataset + split
# -----------------------
full_ds = datasets.ImageFolder(DATA_DIR, transform=train_tfms)
class_names = full_ds.classes
num_classes = len(class_names)
assert num_classes == 5, f"Expected 5 classes, found {num_classes}"
print("Classes:", class_names)

n = len(full_ds)
n_val = int(VAL_SPLIT * n)
n_train = n - n_val

train_ds, val_ds = random_split(
    full_ds,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SEED)
)

# Important: swap validation transforms (no augmentation)
val_ds.dataset.transform = val_tfms


Classes: ['acne', 'eksim', 'herpes', 'panu', 'rosacea']


In [22]:

# -----------------------
# Imbalance handling: WeightedRandomSampler for TRAIN subset
# -----------------------
train_targets = [full_ds.targets[i] for i in train_ds.indices]
class_counts = np.bincount(train_targets, minlength=num_classes)
print("Train class counts:", dict(zip(class_names, class_counts.tolist())))

class_inv_freq = 1.0 / np.clip(class_counts, 1, None)
sample_weights = [class_inv_freq[t] for t in train_targets]
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler,
                          num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=4, pin_memory=True)

Train class counts: {'acne': 242, 'eksim': 233, 'herpes': 249, 'panu': 229, 'rosacea': 243}


In [25]:
# -----------------------
# Model: ResNet18 pretrained + new classification head
# -----------------------
# torchvision >= 0.13 uses "weights=..."
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/christinsilos/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████████████████████████████████| 44.7M/44.7M [00:01<00:00, 23.7MB/s]


In [26]:
# -----------------------
# Loss: class-weighted CrossEntropy (based on TRAIN subset)
# (You can disable this if using sampler feels sufficient)
# -----------------------
train_counts = class_counts.astype(np.float32)
ce_weights = (train_counts.sum() / np.clip(train_counts, 1, None))
ce_weights = ce_weights / ce_weights.mean()  # normalize scale
ce_weights = torch.tensor(ce_weights, dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=ce_weights)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)


In [27]:
# -----------------------
# Eval: macro F1
# -----------------------
@torch.no_grad()
def evaluate_f1(model, loader, device):
    model.eval()
    all_preds, all_true = [], []

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        logits = model(x)
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_true.append(y.cpu().numpy())

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_true)

    macro_f1 = f1_score(y_true, y_pred, average="macro")
    per_class_f1 = f1_score(y_true, y_pred, average=None)
    return macro_f1, per_class_f1, y_true, y_pred


In [28]:
# -----------------------
# Training loop: pick best epoch by val macro F1
# -----------------------
best_f1 = -1.0
best_state = None

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0

    for x, y in train_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)

    train_loss = running_loss / len(train_loader.dataset)

    val_macro_f1, val_per_class_f1, _, _ = evaluate_f1(model, val_loader, device)

    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_macro_f1={val_macro_f1:.4f}")
    print("  per-class F1:", np.round(val_per_class_f1, 4))

    if val_macro_f1 > best_f1:
        best_f1 = val_macro_f1
        best_state = deepcopy(model.state_dict())

# Load best checkpoint
model.load_state_dict(best_state)
print("\nBest val macro F1:", best_f1)

# Final report on validation set
val_macro_f1, _, y_true, y_pred = evaluate_f1(model, val_loader, device)
print(classification_report(y_true, y_pred, target_names=class_names, digits=4))

Epoch 01 | train_loss=0.5283 | val_macro_f1=0.7680
  per-class F1: [0.7586 0.5275 0.844  0.9065 0.8036]
Epoch 02 | train_loss=0.2232 | val_macro_f1=0.8604
  per-class F1: [0.7967 0.8358 0.8713 0.9032 0.8947]
Epoch 03 | train_loss=0.0846 | val_macro_f1=0.8911
  per-class F1: [0.9256 0.8281 0.8545 0.9624 0.8846]
Epoch 04 | train_loss=0.0407 | val_macro_f1=0.8839
  per-class F1: [0.9217 0.8148 0.8119 0.9635 0.9074]
Epoch 05 | train_loss=0.0523 | val_macro_f1=0.8926
  per-class F1: [0.84   0.8872 0.8868 0.9429 0.906 ]
Epoch 06 | train_loss=0.1071 | val_macro_f1=0.8762
  per-class F1: [0.8908 0.7863 0.8929 0.9565 0.8545]
Epoch 07 | train_loss=0.1194 | val_macro_f1=0.8882
  per-class F1: [0.8908 0.8154 0.9072 0.9781 0.8496]
Epoch 08 | train_loss=0.0482 | val_macro_f1=0.8999
  per-class F1: [0.9649 0.7965 0.8644 0.9645 0.9091]
Epoch 09 | train_loss=0.0168 | val_macro_f1=0.9315
  per-class F1: [0.9464 0.896  0.9216 0.9784 0.9153]
Epoch 10 | train_loss=0.0190 | val_macro_f1=0.9075
  per-class F