# Notebook setup

In [27]:
import os, random, json
from pathlib import Path
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets
from sklearn.metrics import classification_report, confusion_matrix
from collections import Counter
import timm
from timm.data import resolve_model_data_config, create_transform
from tqdm import tqdm

In [28]:
# Reproducibility
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

DATA_DIR = Path("/Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/Data/images_1000")
OUT_DIR  = Path("/Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models")
OUT_DIR.mkdir(exist_ok=True)

print("DATA_DIR:", DATA_DIR)
assert DATA_DIR.exists(), "DATA_DIR not found!"

DATA_DIR: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/Data/images_1000


Device & speed knobs

In [18]:
# Device wählen (MPS = Apple GPU)
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# Speed/Training Knobs
BATCH = 32 if device in ("mps", "cuda") else 8
NUMW  = 0        # macOS → safer
PIN   = (device == "cuda")

EPOCHS_HEAD = 40   # Kopf trainieren
EPOCHS_FT   = 6    # Fine-Tuning
PATIENCE    = 5    # Early Stopping

Device: mps


## Dataset (1,000 imgs / 10 classes)

In [19]:
# Dataset laden, nur um Klassen zu sehen
tmp_ds = datasets.ImageFolder(DATA_DIR, transform=None)
classes = tmp_ds.classes
num_classes = len(classes)
print("Classes:", classes, f"({num_classes})")

# Anzahl Bilder pro Klasse checken
counts = Counter([lbl for _, lbl in tmp_ds.samples])
print("Per-class counts:", counts)

# Stratified Split Funktion
def stratified_split(samples, val_ratio=0.15, test_ratio=0.15, seed=SEED):
    y = np.array([c for _, c in samples])
    idxs = np.arange(len(y))
    train_idx, val_idx, test_idx = [], [], []
    rng = np.random.default_rng(seed)
    for c in np.unique(y):
        c_idx = idxs[y == c]
        rng.shuffle(c_idx)
        n = len(c_idx)
        n_test = int(round(n * test_ratio))
        n_val  = int(round(n * val_ratio))
        n_train = n - n_val - n_test
        train_idx += list(c_idx[:n_train])
        val_idx   += list(c_idx[n_train:n_train+n_val])
        test_idx  += list(c_idx[n_train+n_val:])
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = stratified_split(tmp_ds.samples, 0.15, 0.15)
print("Num train/val/test:", len(train_idx), len(val_idx), len(test_idx))


Classes: ['Abstract_Expressionism', 'Cubism', 'Expressionism', 'Impressionism', 'Neoclassicism', 'Post-Impressionism', 'Realism', 'Romanticism', 'Surrealism', 'Symbolism'] (10)
Per-class counts: Counter({0: 100, 1: 100, 2: 100, 3: 100, 4: 100, 5: 100, 6: 100, 7: 100, 8: 100, 9: 100})
Num train/val/test: 700 150 150


# Model + Transforms

In [None]:
MODEL_NAME = "resnet18.a1_in1k"
model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=num_classes)

cfg = resolve_model_data_config(model)
tfm_train = create_transform(**cfg, is_training=True)
tfm_val   = create_transform(**cfg, is_training=False)

# Datasets mit Transforms
full_train = datasets.ImageFolder(DATA_DIR, transform=tfm_train)
full_val   = datasets.ImageFolder(DATA_DIR, transform=tfm_val)
full_test  = datasets.ImageFolder(DATA_DIR, transform=tfm_val)

train_ds = Subset(full_train, train_idx)
val_ds   = Subset(full_val,   val_idx)
test_ds  = Subset(full_test,  test_idx)

print("Num train/val/test:", len(train_ds), len(val_ds), len(test_ds))


Num train/val/test: 700 150 150


# Loaders + imbalance handling

In [None]:
# Class weights (handle imbalance even if slight)
y_train = [full_train.samples[i][1] for i in train_idx]

class_counts = Counter(y_train)
weights = torch.tensor([1.0 / class_counts[c] for c in range(num_classes)], dtype=torch.float)

# Either use weighted loss...
use_weighted_loss = True

# ...or oversampling sampler (commented out by default)
use_sampler = False
if use_sampler:
    sample_weights = [1.0 / class_counts[y] for y in y_train]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(y_train), replacement=True)

train_loader = DataLoader(
    train_ds, batch_size=BATCH, shuffle=(not use_sampler),
    sampler=(sampler if use_sampler else None),
    num_workers=NUMW, pin_memory=PIN
)
val_loader = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=NUMW, pin_memory=PIN)
test_loader = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=NUMW, pin_memory=PIN)


# Training loop

In [22]:
for p in model.parameters():
    p.requires_grad = False

head = model.get_classifier()          # timm helper
if isinstance(head, nn.Module):
    for p in head.parameters():
        p.requires_grad = True

model = model.to(device)
criterion = nn.CrossEntropyLoss(weight=weights.to(device) if use_weighted_loss else None)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_HEAD)

best_val = float("inf")
best_path = OUT_DIR / f"{MODEL_NAME.replace('/','_')}_best.pt"
bad = 0

for epoch in range(1, EPOCHS_HEAD+1):
    model.train()
    tr_loss, tr_correct, n = 0.0, 0, 0
    for x, y in tqdm(train_loader, leave=False):
        x, y = x.to(device), torch.as_tensor(y, device=device)
        optimizer.zero_grad()
        # Mixed precision only on CUDA; on MPS it’s still quirky → skip by default
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        tr_loss += loss.item() * y.size(0)
        tr_correct += (logits.argmax(1) == y).sum().item()
        n += y.size(0)
    tr_acc = tr_correct / n

    model.eval()
    val_loss, val_correct, m = 0.0, 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), torch.as_tensor(y, device=device)
            logits = model(x)
            loss = criterion(logits, y)
            val_loss += loss.item() * y.size(0)
            val_correct += (logits.argmax(1) == y).sum().item()
            m += y.size(0)
    val_acc = val_correct / m
    val_loss /= m
    scheduler.step()

    print(f"[HEAD {epoch}/{EPOCHS_HEAD}] train_loss={tr_loss/n:.4f} acc={tr_acc:.3f} | val_loss={val_loss:.4f} acc={val_acc:.3f}")

    if val_loss < best_val - 1e-4:
        best_val, bad = val_loss, 0
        torch.save({"state_dict": model.state_dict(), "classes": classes}, best_path)
        print("  ↳ saved:", best_path)
    else:
        bad += 1
        if bad >= PATIENCE:
            print("Early stopping (head).")
            break


                                               

[HEAD 1/40] train_loss=2.2873 acc=0.163 | val_loss=2.1936 acc=0.313
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 2/40] train_loss=2.1842 acc=0.234 | val_loss=2.1193 acc=0.340
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 3/40] train_loss=2.1117 acc=0.281 | val_loss=2.0623 acc=0.367
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 4/40] train_loss=2.0384 acc=0.363 | val_loss=2.0156 acc=0.393
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 5/40] train_loss=1.9753 acc=0.386 | val_loss=1.9638 acc=0.387
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 6/40] train_loss=1.9315 acc=0.389 | val_loss=1.9306 acc=0.373
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 7/40] train_loss=1.8882 acc=0.391 | val_loss=1.8995 acc=0.420
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 8/40] train_loss=1.8470 acc=0.420 | val_loss=1.8683 acc=0.387
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 9/40] train_loss=1.8184 acc=0.430 | val_loss=1.8453 acc=0.420
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 10/40] train_loss=1.8116 acc=0.427 | val_loss=1.8256 acc=0.393
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 11/40] train_loss=1.7765 acc=0.444 | val_loss=1.8115 acc=0.413
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 12/40] train_loss=1.7511 acc=0.430 | val_loss=1.7959 acc=0.420
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 13/40] train_loss=1.7191 acc=0.444 | val_loss=1.7791 acc=0.433
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 14/40] train_loss=1.7048 acc=0.453 | val_loss=1.7643 acc=0.427
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 15/40] train_loss=1.6929 acc=0.456 | val_loss=1.7623 acc=0.453
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 16/40] train_loss=1.6700 acc=0.479 | val_loss=1.7387 acc=0.433
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 17/40] train_loss=1.6593 acc=0.473 | val_loss=1.7435 acc=0.420


                                               

[HEAD 18/40] train_loss=1.6646 acc=0.466 | val_loss=1.7288 acc=0.427
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 19/40] train_loss=1.6248 acc=0.481 | val_loss=1.7210 acc=0.427
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 20/40] train_loss=1.6308 acc=0.471 | val_loss=1.7167 acc=0.433
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 21/40] train_loss=1.6168 acc=0.490 | val_loss=1.7097 acc=0.413
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 22/40] train_loss=1.6200 acc=0.493 | val_loss=1.7155 acc=0.427


                                               

[HEAD 23/40] train_loss=1.6355 acc=0.473 | val_loss=1.7000 acc=0.447
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 24/40] train_loss=1.6072 acc=0.476 | val_loss=1.6969 acc=0.447
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 25/40] train_loss=1.5828 acc=0.501 | val_loss=1.6915 acc=0.433
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 26/40] train_loss=1.5797 acc=0.486 | val_loss=1.6945 acc=0.440


                                               

[HEAD 27/40] train_loss=1.5722 acc=0.511 | val_loss=1.6853 acc=0.427
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 28/40] train_loss=1.5866 acc=0.514 | val_loss=1.6846 acc=0.427
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 29/40] train_loss=1.5701 acc=0.507 | val_loss=1.6766 acc=0.427
  ↳ saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_best.pt


                                               

[HEAD 30/40] train_loss=1.5718 acc=0.491 | val_loss=1.6887 acc=0.447


                                               

[HEAD 31/40] train_loss=1.5660 acc=0.493 | val_loss=1.6807 acc=0.453


                                               

[HEAD 32/40] train_loss=1.5567 acc=0.500 | val_loss=1.6832 acc=0.433


                                               

[HEAD 33/40] train_loss=1.5839 acc=0.501 | val_loss=1.6877 acc=0.440


                                               

[HEAD 34/40] train_loss=1.5521 acc=0.533 | val_loss=1.6848 acc=0.427
Early stopping (head).


# Fine-tune deeper layers (low LR)

In [23]:
# (Optional) open last blocks for FT; for VGG16 you can open more layers:
for p in model.parameters():
    p.requires_grad = False

head = model.get_classifier()          # timm helper
if isinstance(head, nn.Module):
    for p in head.parameters():
        p.requires_grad = True

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5, weight_decay=1e-4)
FT_EPOCHS = EPOCHS_FT

for e in range(1, FT_EPOCHS+1):
    model.train()
    for x, y in tqdm(train_loader, leave=False):
        x, y = x.to(device), torch.as_tensor(y, device=device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    print(f"[FT {e}/{FT_EPOCHS}] done")

                                               

[FT 1/6] done


                                               

[FT 2/6] done


                                               

[FT 3/6] done


                                               

[FT 4/6] done


                                               

[FT 5/6] done


                                               

[FT 6/6] done




# Save & evaluate

In [24]:
final_path = OUT_DIR / f"{MODEL_NAME.replace('/','_')}_final.pt"
torch.save({"state_dict": model.state_dict(), "classes": classes}, final_path)
print("Saved:", final_path)

# Test
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for x, y in tqdm(test_loader, leave=False):
        x = x.to(device)
        logits = model(x).cpu()
        y_true += list(y.numpy())
        y_pred += list(logits.argmax(1).numpy())

print("\nClassification report (test):")
print(classification_report(y_true, y_pred, target_names=classes, digits=3))
print("Confusion matrix (rows=true, cols=pred):")
print(confusion_matrix(y_true, y_pred))


Saved: /Users/fabianschwientek/code/gwen-m97/inspiart/inspiart/models/resnet18.a1_in1k_final.pt


                                             


Classification report (test):
                        precision    recall  f1-score   support

Abstract_Expressionism      0.421     0.533     0.471        15
                Cubism      0.533     0.533     0.533        15
         Expressionism      0.286     0.133     0.182        15
         Impressionism      0.450     0.600     0.514        15
         Neoclassicism      0.417     0.667     0.513        15
    Post-Impressionism      0.421     0.533     0.471        15
               Realism      0.077     0.067     0.071        15
           Romanticism      0.250     0.200     0.222        15
            Surrealism      0.462     0.400     0.429        15
             Symbolism      0.500     0.267     0.348        15

              accuracy                          0.393       150
             macro avg      0.382     0.393     0.375       150
          weighted avg      0.382     0.393     0.375       150

Confusion matrix (rows=true, cols=pred):
[[ 8  0  2  1  0  2  0  0  0 

