In [5]:
import tensorflow as tf
import tensorflow_datasets as tfds
import math, time
from tensorflow.keras import layers, models, losses, optimizers

# Enable mixed precision (faster on GPUs like V100/A100)
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# ------------------ ResNet18 building blocks ------------------
class BasicBlock(tf.keras.Model):
    expansion = 1
    def __init__(self, filters, stride=1):
        super().__init__()
        self.conv1 = layers.Conv2D(filters, 3, strides=stride, padding="same", use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.ReLU()
        self.conv2 = layers.Conv2D(filters, 3, strides=1, padding="same", use_bias=False)
        self.bn2 = layers.BatchNormalization()
        # Shortcut (skip connection):
        # If stride=1 → input and output shapes match → identity mapping
        # If stride=2 → use 1x1 conv to downsample input to correct shape
        if stride != 1:
            self.shortcut = models.Sequential([
                layers.Conv2D(filters, 1, strides=stride, use_bias=False),
                layers.BatchNormalization()
            ])
        else:
            self.shortcut = lambda x: x

    def call(self, x, training=False):
        out = self.relu(self.bn1(self.conv1(x), training=training))
        out = self.bn2(self.conv2(out), training=training)
        out += self.shortcut(x)
        out = self.relu(out)
        return out

def make_layer(filters, blocks, stride=1):
    layers_ = [BasicBlock(filters, stride)]
    for _ in range(1, blocks):
        layers_.append(BasicBlock(filters, 1))
    return models.Sequential(layers_)


 # Input: CIFAR-10 images (32,32,3).

    # Stem: Conv → BN → ReLU.

    # 4 stages: 64, 128, 256, 512 filters.

    # Downsamples via stride=2 at each stage.

    # Global average pool → Dense layer → logits for 10 classes
def ResNet18(num_classes=10):
    inputs = layers.Input(shape=(32,32,3))
    x = layers.Conv2D(64, 3, strides=1, padding="same", use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = make_layer(64, 2, stride=1)(x)
    x = make_layer(128, 2, stride=2)(x)
    x = make_layer(256, 2, stride=2)(x)
    x = make_layer(512, 2, stride=2)(x)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(num_classes)(x)

    return models.Model(inputs, x)


# ------------------ Data pipeline ------------------
def preprocess_train(sample):
#   Normalize images to [0,1].
# One-hot encode labels for CategoricalCrossentropy
    image = tf.cast(sample['image'], tf.float32) / 255.0
    label = tf.one_hot(sample['label'], depth=10)

    # Pad + crop + flip
    image = tf.image.resize_with_crop_or_pad(image, 40, 40)
    image = tf.image.random_crop(image, [32,32,3])
    image = tf.image.random_flip_left_right(image)

    # Simple extra augmentation
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, 0.9, 1.1)

    return image, label

def preprocess_test(sample):
    image = tf.cast(sample['image'], tf.float32) / 255.0
    label = tf.one_hot(sample['label'], depth=10)
    return image, label

def get_datasets(batch_size=512):
    train_ds = tfds.load("cifar10", split="train", as_supervised=False)
    test_ds  = tfds.load("cifar10", split="test",  as_supervised=False)

    train_ds = (train_ds
                .shuffle(50000)
                .map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
                .batch(batch_size)
                .prefetch(tf.data.AUTOTUNE))
    test_ds = (test_ds
                .map(preprocess_test, num_parallel_calls=tf.data.AUTOTUNE)
                .batch(batch_size)
                .prefetch(tf.data.AUTOTUNE))
    return train_ds, test_ds

# ------------------ Cosine LR schedule with warmup ------------------
class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
#   Warmup for first few epochs → LR linearly increases from 0 to base.

# Then follows a cosine decay schedule → gradually lowers LR to near 0.

# This is common in DAWNBench-style fast training (improves convergence).
    def __init__(self, base_lr, steps, warmup_epochs=5):
        super().__init__()
        self.base_lr = base_lr
        self.steps = steps
        self.warmup_steps = warmup_epochs * steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        lr = 0.5 * self.base_lr * (1 + tf.cos(math.pi * step / self.steps))
        if self.warmup_steps > 0:
            warmup_lr = self.base_lr * step / tf.cast(self.warmup_steps, tf.float32)
            lr = tf.where(step < self.warmup_steps, warmup_lr, lr)
        return lr

# ------------------ EMA callback ------------------
# decay=0.999: means each update keeps 99.9% of old EMA weights, 0.1% of the new.

# ema_weights: a list to store the smoothed weights.
# At the very end of training:

# Replace model’s actual weights with the EMA-smoothed weights.

# So evaluation (validation/test) uses the stable version.
class EMA(tf.keras.callbacks.Callback):
    def __init__(self, decay=0.999):
        super().__init__()
        self.decay = decay
        self.ema_weights = None

    def on_train_begin(self, logs=None):
        self.ema_weights = [w.numpy() for w in self.model.weights]

    def on_batch_end(self, batch, logs=None):
        for i, w in enumerate(self.model.weights):
            self.ema_weights[i] = self.decay * self.ema_weights[i] + (1. - self.decay) * w.numpy()

    def on_train_end(self, logs=None):
        # Assign EMA weights back for evaluation
        for i, w in enumerate(self.model.weights):
            w.assign(self.ema_weights[i])

# ------------------ Training ------------------
def train_cifar10(epochs=200, batch_size=512, base_lr=0.4):
    train_ds, test_ds = get_datasets(batch_size)

    model = ResNet18(num_classes=10)

    steps_per_epoch = 50000 // batch_size
    lr_schedule = WarmupCosine(base_lr, steps=steps_per_epoch*epochs)

    optimizer = optimizers.AdamW(learning_rate=lr_schedule, weight_decay=5e-4)

    loss_fn = losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1)

    model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=["accuracy"]
    )

    history = model.fit(train_ds,
                        validation_data=test_ds,
                        epochs=epochs,
                        callbacks=[EMA(decay=0.999)])

    test_loss, test_acc = model.evaluate(test_ds)
    print("Final Test Accuracy (EMA weights):", test_acc)

    return model, history

if __name__ == "__main__":
    model, history = train_cifar10(epochs=150, batch_size=512)
    print("Best Val Accuracy:", max(history.history['val_accuracy']))


Epoch 1/150
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 347ms/step - accuracy: 0.2640 - loss: 2.1383 - val_accuracy: 0.1009 - val_loss: 2.6377
Epoch 2/150
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m52s[0m 196ms/step - accuracy: 0.5445 - loss: 1.4908 - val_accuracy: 0.1000 - val_loss: 3.4299
Epoch 3/150
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 191ms/step - accuracy: 0.6542 - loss: 1.2775 - val_accuracy: 0.2506 - val_loss: 3.0155
Epoch 4/150
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 197ms/step - accuracy: 0.7305 - loss: 1.1287 - val_accuracy: 0.4599 - val_loss: 2.0079
Epoch 5/150
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 190ms/step - accuracy: 0.7758 - loss: 1.0287 - val_accuracy: 0.3612 - val_loss: 2.3058
Epoch 6/150
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 197ms/step - accuracy: 0.8096 - loss: 0.9552 - val_accuracy: 0.4530 - val_loss: 1.8292
Epoch 7/150
[1m