# In [None]:
import os
import math
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

SEED = 42  # base seed shared by every random component below
BATCH_SIZE = 128  # examples per optimizer step
EPOCHS = 50  # upper bound on epochs; EarlyStopping may end training sooner
NUM_CLASSES = 10  # CIFAR-10 label count
IMG_SIZE = 32  # input height and width in pixels

# Seed Python's `random`, NumPy and TensorFlow in one call so that weight
# initialisation and dropout are reproducible as well; without this, SEED only
# reaches the shuffle buffer and the augmentation layers. (GPU kernels may
# still introduce some run-to-run nondeterminism.)
keras.utils.set_random_seed(SEED)

# Best-effort switch to the "mixed_float16" policy when at least one GPU is
# visible. Nothing happens on CPU-only machines, and any failure while enabling
# it is printed rather than raised, leaving the global policy untouched.
try:
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        from tensorflow.keras import mixed_precision

        mixed_precision.set_global_policy("mixed_float16")
        print("✅ Mixed precision enabled.")
except Exception as err:  # intentionally broad: mixed precision is optional
    print("Mixed precision not enabled:", err)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Hold out the last VAL_SPLIT training examples as a validation set. Each
# right-hand side is evaluated before rebinding, so both slices come from the
# full original training arrays.
VAL_SPLIT = 5000
x_train, x_val = x_train[:-VAL_SPLIT], x_train[-VAL_SPLIT:]
y_train, y_val = y_train[:-VAL_SPLIT], y_train[-VAL_SPLIT:]

AUTOTUNE = tf.data.AUTOTUNE  # let tf.data choose parallelism / prefetch depth

def preprocess(img, label):
    """Turn one raw (image, label) example into model-ready tensors.

    The image is cast to float32 and divided by 255. The label loses its
    trailing length-1 axis, so a label of shape (1,) becomes a scalar.
    """
    scaled = tf.cast(img, tf.float32) / 255.0
    scalar_label = tf.squeeze(label, axis=-1)
    return scaled, scalar_label

def make_dataset(x, y, training=False):
    """Build a batched, prefetched ``tf.data`` pipeline from in-memory arrays.

    Args:
        x: Images, one example per leading index.
        y: Labels aligned with ``x``.
        training: If True, examples are reshuffled every epoch (buffer of
            10000, seeded with SEED) and the trailing partial batch is dropped,
            so every step sees exactly BATCH_SIZE examples. If False, the
            original order is kept and the last, possibly smaller, batch is
            retained.

    Returns:
        A ``tf.data.Dataset`` of ``(images, labels)`` batches produced by
        ``preprocess``.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((x, y))
    if training:
        pipeline = pipeline.shuffle(
            buffer_size=10000, seed=SEED, reshuffle_each_iteration=True
        )
    pipeline = pipeline.map(preprocess, num_parallel_calls=AUTOTUNE)
    return pipeline.batch(BATCH_SIZE, drop_remainder=training).prefetch(AUTOTUNE)

# One pipeline per split. Only the training pipeline shuffles and drops the
# trailing partial batch; validation and test use the default training=False.
train_ds = make_dataset(x_train, y_train, training=True)
val_ds = make_dataset(x_val, y_val)
test_ds = make_dataset(x_test, y_test)

# Random augmentation stage used as the first part of the model.
# Each random layer gets its own seed (derived from SEED). Layers constructed
# with the same seed start from the same RNG state, so their flip, shift and
# zoom draws would be correlated rather than independent.
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal", seed=SEED),
        layers.RandomTranslation(0.05, 0.05, seed=SEED + 1),
        layers.RandomZoom(0.1, seed=SEED + 2),
    ],
    name="augment",
)

def conv_bn_act(x, filters, kernel_size, stride=1):
    """Apply Conv2D -> BatchNormalization -> ReLU to the tensor ``x``.

    The convolution uses "same" padding and no bias term; the following
    BatchNormalization provides the learned offset.
    """
    conv = layers.Conv2D(
        filters, kernel_size, strides=stride, padding="same", use_bias=False
    )
    y = conv(x)
    y = layers.BatchNormalization()(y)
    return layers.Activation("relu")(y)

def residual_block(x, filters, stride=1):
    """Two-convolution residual block.

    Main path: 3x3 conv (using ``stride``) + BN + ReLU, then 3x3 conv + BN.
    The skip path is the input itself, unless the channel count differs from
    ``filters`` or ``stride`` != 1. In that case it is projected with a
    strided 1x1 conv + BN so both paths have matching shapes. The two paths
    are summed and passed through ReLU.
    """
    residual = conv_bn_act(x, filters, 3, stride)
    residual = layers.Conv2D(filters, 3, padding="same", use_bias=False)(residual)
    residual = layers.BatchNormalization()(residual)

    needs_projection = stride != 1 or x.shape[-1] != filters
    skip = x
    if needs_projection:
        skip = layers.Conv2D(
            filters, 1, strides=stride, padding="same", use_bias=False
        )(x)
        skip = layers.BatchNormalization()(skip)

    merged = layers.Add()([residual, skip])
    return layers.Activation("relu")(merged)

def build_resnet():
    """Assemble the CIFAR-10 ResNet as a functional ``keras.Model``.

    Layout: augmentation -> 64-filter conv stem -> three stages of two residual
    blocks (64, 128, 256 filters; the first block of the 128 and 256 stages
    downsamples with stride 2) -> global average pooling -> dropout(0.2) ->
    softmax over NUM_CLASSES. Under the "mixed_float16" policy the classifier
    layer is forced to float32 so the softmax runs in full precision.
    """
    inputs = layers.Input((IMG_SIZE, IMG_SIZE, 3))
    x = conv_bn_act(data_augmentation(inputs), 64, 3)

    for filters, first_stride in ((64, 1), (128, 2), (256, 2)):
        x = residual_block(x, filters, stride=first_stride)
        x = residual_block(x, filters)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)

    uses_fp16 = tf.keras.mixed_precision.global_policy().name == "mixed_float16"
    head_dtype = tf.float32 if uses_fp16 else None
    outputs = layers.Dense(NUM_CLASSES, activation="softmax", dtype=head_dtype)(x)
    return keras.Model(inputs, outputs, name="ResNet_CIFAR10")

model = build_resnet()
model.summary()

# Cosine-decay the learning rate from initial_lr over the whole training run.
# The training pipeline batches with drop_remainder=True, so one epoch is
# floor(len(x_train) / BATCH_SIZE) optimizer steps. Rounding up instead would
# make decay_steps exceed the real step count, and the schedule would never
# finish annealing.
initial_lr = 1e-3
steps_per_epoch = len(x_train) // BATCH_SIZE
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=initial_lr, decay_steps=EPOCHS * steps_per_epoch
)

optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

# Prefer label smoothing (0.05). Not every TF/Keras release accepts the
# `label_smoothing` keyword for SparseCategoricalCrossentropy; when the
# constructor rejects it with TypeError, fall back to the unsmoothed loss.
try:
    loss = keras.losses.SparseCategoricalCrossentropy(label_smoothing=0.05)
except TypeError:
    print("⚠️ label_smoothing not supported for SparseCategoricalCrossentropy in this TF version.")
    print("Falling back to standard SparseCategoricalCrossentropy.")
    loss = keras.losses.SparseCategoricalCrossentropy()
else:
    print("Using SparseCategoricalCrossentropy with label smoothing 0.05")

model.compile(loss=loss, optimizer=optimizer, metrics=["accuracy"])

# Save the model whenever val_accuracy improves.
ckpt = keras.callbacks.ModelCheckpoint(
    "cifar10_resnet_best.keras", monitor="val_accuracy", save_best_only=True, mode="max", verbose=1
)
# Stop once val_accuracy has not improved for 10 epochs (restore_best_weights=True).
early = keras.callbacks.EarlyStopping(
    monitor="val_accuracy", patience=10, mode="max", restore_best_weights=True, verbose=1
)
# ReduceLROnPlateau is deliberately not used. The optimizer's learning rate is
# a CosineDecay LearningRateSchedule, which that callback cannot modify:
# depending on the Keras version, assigning a new LR either raises TypeError
# when the callback fires or is overwritten by the schedule on the next step.
# The cosine schedule already anneals the learning rate.

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[ckpt, early],
    verbose=1,
)

# Score the in-memory model on the held-out test split, then write it to disk.
# NOTE(review): whether these are the best-epoch weights depends on
# EarlyStopping having restored them (it may not if training ran all EPOCHS);
# confirm for the Keras version in use.
test_loss, test_acc = model.evaluate(x=test_ds, verbose=1)
print(f"Test accuracy: {test_acc:.4f}")
model.save(filepath="cifar10_resnet_final.keras")
