In [None]:
import os, random, json
import numpy as np

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import tensorflow as tf
import keras_hub
from sklearn.metrics import classification_report, confusion_matrix
from types import SimpleNamespace

In [None]:
# Run configuration: every knob a reader may want to tune, in one place.
args = SimpleNamespace(**{
    # --- data ---
    "data_dir": "data/artbench-10",      # must contain train/ and test/ class folders
    "img_size": (224, 224),              # (height, width) images are resized to
    "batch_size": 128,
    "val_split": 0.2,                    # fraction of train/ held out for validation
    "seed": 42,
    # --- optimisation ---
    "epochs": 200,                       # total epoch budget, warmup included
    "lr_warmup": 5e-4,                   # learning rate for the warmup phase
    "lr_finetune": 3e-5,                 # learning rate for the fine-tuning phase
    # --- model / output ---
    "preset": "resnet_v2_50_imagenet",   # keras_hub preset name
    "results_dir": "results",
    "warmup_epochs": 5,                  # epochs trained with the backbone frozen
})

In [None]:
# Suffix output file names with the SLURM job id (when running under SLURM) so
# concurrent cluster jobs do not overwrite each other's artifacts.
slurm_id = os.environ.get("SLURM_JOB_ID", "")
tag = "" if not slurm_id else f"_{slurm_id}"

# Folder for plots, CSV logs and reports; re-running the cell is harmless.
os.makedirs(args.results_dir, exist_ok=True)

# Seed every RNG in play. tf.keras.utils.set_random_seed also seeds Python's
# `random` and NumPy; the explicit calls keep each library's seed visible.
random.seed(args.seed)
np.random.seed(args.seed)
tf.keras.utils.set_random_seed(args.seed)

In [None]:
def try_get_preset_preprocessor(preset: str):
    """Best-effort lookup of a standalone image preprocessor for a keras_hub preset.

    Strategies are tried in order, each only if the attribute it needs exists:
      1. ``keras_hub.layers.PresetPreprocessor.from_preset(preset)``
      2. ``keras_hub.models.ImageClassifier.preprocessor_from_preset(preset)``
    A strategy that raises is reported (instead of being silently swallowed) and the
    next one is tried; a strategy that yields ``None`` also falls through.

    NOTE(review): confirm these attribute names exist in the installed keras_hub
    version; if neither does, this always returns None and the dataset pipeline uses
    its manual mean/std normalization instead.

    Args:
        preset: keras_hub preset name, e.g. "resnet_v2_50_imagenet".

    Returns:
        The preprocessor object, or None if no strategy produced one.
    """
    def _from_layers():
        # Imported lazily: the submodule may not exist in every keras_hub version.
        from keras_hub import layers as kh_layers
        if hasattr(kh_layers, "PresetPreprocessor"):
            return kh_layers.PresetPreprocessor.from_preset(preset)
        return None

    def _from_classifier():
        if hasattr(keras_hub.models.ImageClassifier, "preprocessor_from_preset"):
            return keras_hub.models.ImageClassifier.preprocessor_from_preset(preset)
        return None

    for strategy in (_from_layers, _from_classifier):
        try:
            preproc = strategy()
        except Exception as exc:  # best-effort: report the reason, try the next one
            print(f"Preset preprocessor lookup via {strategy.__name__} failed: "
                  f"{type(exc).__name__}: {exc}")
            continue
        if preproc is not None:
            return preproc
    return None

preproc_layer = try_get_preset_preprocessor(args.preset)
# None means build_datasets falls back to manual ImageNet mean/std normalization.
print("Using preset preprocessor:", preproc_layer is not None)

In [None]:
def build_datasets(data_dir, img_size, batch_size, val_split, seed, preproc_layer=None):
    """Build normalized, prefetched train/validation/test ``tf.data`` pipelines.

    ``data_dir`` must contain ``train/`` and ``test/`` directories with one
    sub-folder per class. The validation subset is carved out of ``train/``.

    Both training subsets are produced by a single ``subset="both"`` call, so they
    come from the same seeded file index and are disjoint. (Two separate calls with
    different ``shuffle`` flags index the files differently, because Keras only
    shuffles the file list when ``shuffle=True``. The subsets then overlap, and the
    validation set ends up being the tail of the sorted file list.)

    Args:
        data_dir: root directory holding ``train/`` and ``test/``.
        img_size: (height, width) every image is resized to.
        batch_size: batch size for all three datasets.
        val_split: fraction of ``train/`` reserved for validation.
        seed: seed for the train/validation split and training shuffle.
        preproc_layer: optional callable applied to float32 image batches. When
            None, images are scaled to [0, 1] and standardized with the ImageNet
            channel mean/std.

    Returns:
        (train_ds, val_ds, test_ds, class_names, num_classes)

    Raises:
        ValueError: if the class folders of ``test/`` differ from ``train/``.
    """
    AUTOTUNE = tf.data.AUTOTUNE
    image_size = tuple(img_size)

    def _load(split, **kwargs):
        # Shared loader settings; split-specific options come in via kwargs.
        return tf.keras.utils.image_dataset_from_directory(
            os.path.join(data_dir, split),
            image_size=image_size,
            batch_size=batch_size,
            label_mode="int",
            **kwargs,
        )

    train_ds, val_ds = _load(
        "train",
        validation_split=val_split,
        subset="both",
        seed=seed,
        shuffle=True,
    )
    test_ds = _load("test", shuffle=False)

    class_names = train_ds.class_names
    # Integer labels are assigned per sorted folder name; make sure both splits agree.
    if list(test_ds.class_names) != list(class_names):
        raise ValueError(
            f"test/ classes {test_ds.class_names} do not match train/ classes {class_names}"
        )
    num_classes = len(class_names)

    if preproc_layer is not None:
        def _normalize(images):
            return preproc_layer(images, training=False)
    else:
        mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
        std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)

        def _normalize(images):
            return (images / 255.0 - mean) / std

    def _map(images, labels):
        return _normalize(tf.cast(images, tf.float32)), labels

    def _finalize(ds):
        return ds.map(_map, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

    return _finalize(train_ds), _finalize(val_ds), _finalize(test_ds), class_names, num_classes

# Materialize the three splits plus the label vocabulary used everywhere below.
train_ds, val_ds, test_ds, class_names, num_classes = build_datasets(
    data_dir=args.data_dir,
    img_size=args.img_size,
    batch_size=args.batch_size,
    val_split=args.val_split,
    seed=args.seed,
    preproc_layer=preproc_layer,
)
print(f"Classes ({num_classes}):", class_names)

In [None]:
def plot_history(history, out_png):
    """Save side-by-side loss and accuracy learning curves to ``out_png``.

    Args:
        history: object exposing a ``.history`` dict of per-epoch metric lists, such
            as the value returned by ``Model.fit`` or by ``merge_histories``.
        out_png: path of the image file to write.

    Training series are drawn in blue and validation series in red. A series absent
    from ``history.history`` (e.g. ``val_*`` when no validation data was used) is
    skipped instead of raising ``KeyError``. Each series gets its own epoch axis.
    """
    hist = history.history

    fig, axs = plt.subplots(1, 2, figsize=(10, 2.25))
    fig.subplots_adjust(wspace=0.3)

    panels = (
        ("loss", ("loss", "val_loss")),
        ("accuracy", ("accuracy", "val_accuracy")),
    )
    for ax, (ylabel, keys) in zip(axs, panels):
        ax.grid(True)
        ax.set_ylabel(ylabel)
        ax.set_xlabel("epoch")
        for key, style in zip(keys, ("b-", "r-")):
            values = hist.get(key)
            if values is None:
                continue
            epochs = np.arange(1, len(values) + 1)
            ax.plot(epochs, values, style, label=key)
        if ax.get_lines():  # legend() warns when nothing was plotted
            ax.legend(loc="best")

    fig.savefig(out_png, bbox_inches="tight", dpi=150)
    plt.close(fig)


def plot_confusion_matrix(cm, classes, out_png, normalize=True):
    """Render a confusion matrix as an annotated heat-map and save it to ``out_png``.

    Args:
        cm: square matrix with rows = true label and columns = predicted label.
        classes: tick labels, one per row/column of ``cm``.
        out_png: path of the image file to write.
        normalize: if True, divide each row by its total so cells show the fraction
            of each true class; the small epsilon keeps all-zero rows at zero.
    """
    cm = np.asarray(cm)
    if normalize:
        cm = cm.astype("float") / (cm.sum(axis=1, keepdims=True) + 1e-12)
    # Integer format only for integer cells, so a float matrix with normalize=False
    # does not crash format(..., "d").
    fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".2f"

    fig, ax = plt.subplots(figsize=(5.2, 4.5))
    image = ax.imshow(cm, interpolation="nearest")
    ax.set_title("Confusion matrix" + (" (normalized)" if normalize else ""))
    fig.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    # White text on the darker half of the colormap, black on the lighter half.
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center",
                    va="center",
                    color="white" if cm[i, j] > thresh else "black",
                    fontsize=7)

    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    fig.tight_layout()
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    plt.close(fig)


def merge_histories(H_list):
    """Concatenate several Keras ``History``-like objects into one, epoch-aligned.

    A metric missing from some phase (for example a ``learning_rate`` entry that
    only some callbacks add) is padded with NaN for that phase's epochs. Every
    merged series therefore has exactly one entry per epoch across all phases and
    lines up with the ``loss`` series. Keys keep their first-seen order.

    Args:
        H_list: iterable of objects exposing a ``.history`` dict of per-epoch lists.

    Returns:
        An object whose ``.history`` attribute is the merged dict.
    """
    nan = float("nan")
    merged = {}
    epochs_so_far = 0
    for H in H_list:
        phase = H.history
        phase_epochs = max((len(v) for v in phase.values()), default=0)
        # Keys first seen in this phase get NaN for all earlier epochs.
        for key in phase:
            if key not in merged:
                merged[key] = [nan] * epochs_so_far
        # Extend every known key; pad keys absent (or short) in this phase.
        for key, series in merged.items():
            values = list(phase.get(key, ()))
            series.extend(values + [nan] * (phase_epochs - len(values)))
        epochs_so_far += phase_epochs

    class _Hist:
        def __init__(self, hist):
            self.history = hist
    return _Hist(merged)

In [None]:
# Pretrained backbone from the preset plus a fresh softmax head for our classes.
# preprocessor=None: keras_hub task models apply an attached preprocessor to inputs
# inside fit()/evaluate()/predict(). The tf.data pipeline above already resizes and
# normalizes the images, so keeping the preset's preprocessor would normalize twice.
M = keras_hub.models.ImageClassifier.from_preset(
    args.preset,
    num_classes=num_classes,
    activation="softmax",
    dropout=0.5,
    preprocessor=None,
)

# Metric objects shared by both training phases (names match the log keys used later).
acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
top5_metric = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top5")

# Accuracy must increase, so state mode="max" explicitly rather than relying on
# the "auto" name heuristic.
reduce_cb = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_accuracy", mode="max", factor=0.4, patience=6, verbose=1
)
early_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy", mode="max", patience=12, restore_best_weights=True, verbose=1
)
csv_cb = tf.keras.callbacks.CSVLogger(
    os.path.join(args.results_dir, f"history{tag}.csv")
)

In [None]:
# Phase 1 (warmup): train only the new head while the pretrained backbone is frozen.
# At least one warmup epoch, never more than the total epoch budget.
warmup_epochs = max(1, min(args.warmup_epochs, args.epochs))
print(f"Warmup epochs (frozen backbone): {warmup_epochs}")

M.backbone.trainable = False
# Warmup uses its own learning rate (previously lr_finetune was used here by mistake).
optimizer_warmup = tf.keras.optimizers.Adam(learning_rate=args.lr_warmup)
M.compile(
    optimizer=optimizer_warmup,
    loss="sparse_categorical_crossentropy",
    metrics=[acc_metric, top5_metric],
)
# Separate CSV file for the warmup phase. A CSVLogger created with the default
# append=False reopens its file in write mode at the start of every fit(), so
# sharing csv_cb with the fine-tune phase would erase the warmup rows.
warmup_csv_cb = tf.keras.callbacks.CSVLogger(
    os.path.join(args.results_dir, f"history_warmup{tag}.csv")
)
H1 = M.fit(
    train_ds,
    validation_data=val_ds,
    epochs=warmup_epochs,
    callbacks=[warmup_csv_cb],
    verbose=0
)

# Phase 2 (fine-tune): unfreeze the backbone and recompile so the change takes effect.
print(f"Fine-tune epochs (unfrozen backbone): {args.epochs - warmup_epochs}")
M.backbone.trainable = True
# args defines lr_finetune (there is no args.lr, which raised AttributeError).
optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr_finetune)
M.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=[acc_metric, top5_metric],
)
H2 = M.fit(
    train_ds,
    validation_data=val_ds,
    initial_epoch=warmup_epochs,  # continue epoch numbering after the warmup
    epochs=args.epochs,
    callbacks=[reduce_cb, early_cb, csv_cb],
    verbose=0
)

H = merge_histories([H1, H2])

In [None]:
# Final metrics on the training data and on the held-out test split.
train_metrics = M.evaluate(train_ds, return_dict=True, verbose=0)
test_metrics = M.evaluate(test_ds, return_dict=True, verbose=0)

acc_train, acc_test = (float(m["accuracy"]) for m in (train_metrics, test_metrics))
top5_test = test_metrics.get("top5")  # None if the top-5 metric was not compiled in

print(f"Train accuracy : {acc_train:.2%}")
print(f"Test accuracy  : {acc_test:.2%}")
if top5_test is not None:
    print(f"Test top-5     : {top5_test:.2%}")

# Learning curves for both training phases.
out_png = os.path.join(args.results_dir, f"history{tag}.png")
plot_history(H, out_png)
print(f"Saved: {out_png}")

In [None]:
# Ground-truth labels in dataset order. test_ds is built with shuffle=False, so this
# order matches the order of the predictions below.
y_true = np.concatenate([labels.numpy() for _, labels in test_ds], axis=0)

# Predicted class = index of the highest softmax score.
y_pred = np.argmax(M.predict(test_ds, verbose=0), axis=1)

# Confusion matrix over every class index, rendered row-normalized.
cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
cm_png = os.path.join(args.results_dir, f"confusion_matrix{tag}.png")
plot_confusion_matrix(cm, class_names, cm_png, normalize=True)
print(f"Saved: {cm_png}")

# Per-class accuracy = diagonal / row total; classes with no test samples get 0.
per_class_total = cm.sum(axis=1)
per_class_correct = np.diag(cm)
per_class_acc = np.zeros_like(per_class_correct, dtype=float)
has_samples = per_class_total != 0
per_class_acc[has_samples] = per_class_correct[has_samples] / per_class_total[has_samples]

In [None]:
# Write the per-class report, headline metrics, per-class accuracy and run metadata
# to a single text file.
report = classification_report(
    y_true,
    y_pred,
    # Pin the label set so target_names always lines up, even if a class never occurs
    # in y_true or y_pred (otherwise sklearn raises a size-mismatch ValueError).
    labels=list(range(num_classes)),
    target_names=class_names,
    digits=3,
    zero_division=0,  # undefined precision/recall reported as 0 without warnings
)
rep_path = os.path.join(args.results_dir, f"classification_report{tag}.txt")
with open(rep_path, "w", encoding="utf-8") as f:
    f.write(report + "\n")
    f.write(f"\nTrain accuracy: {acc_train:.4f}\n")
    f.write(f"Test accuracy : {acc_test:.4f}\n")
    if top5_test is not None:
        f.write(f"Test top5     : {top5_test:.4f}\n")

    f.write("\nPer-class accuracy:\n")
    for cname, acc, tot in zip(class_names, per_class_acc, per_class_total):
        f.write(f"  {cname:>20s}: {acc*100:6.2f}%  (n={int(tot)})\n")

    meta = {
        "preset": args.preset,
        "epochs": int(args.epochs),
        "warmup_epochs": int(warmup_epochs),
        "batch_size": int(args.batch_size),
        # args has no `lr` attribute (reading it raised AttributeError). Record both
        # phase learning rates; "lr" keeps the fine-tune rate for existing readers.
        "lr": float(args.lr_finetune),
        "lr_warmup": float(args.lr_warmup),
        "lr_finetune": float(args.lr_finetune),
        "img_size": list(args.img_size),
        "seed": int(args.seed),
        "used_preset_preprocessor": bool(preproc_layer is not None),
    }
    f.write("\nMeta:\n")
    f.write(json.dumps(meta, indent=2) + "\n")
print(f"Saved: {rep_path}")