In [None]:
import os
import tensorflow as tf
import tensorflow.keras as keras
# import keras
from keras import layers
# from keras import ops

import numpy as np
import matplotlib.pyplot as plt

num_classes = 100
input_shape = (32, 32, 3)

# Seed Python's `random`, NumPy and TensorFlow so augmentation, weight
# initialisation and the randomly chosen sample image are repeatable across
# kernel restarts (some GPU kernels may still be nondeterministic).
seed = 42
keras.utils.set_random_seed(seed)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Hidden widths of the MLP inside each transformer block
transformer_layers = 8
mlp_head_units = [
    2048,
    1024,
]  # Size of the dense layers of the final classifier
dropout_rate = 0.1

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
# NOTE(review): this adapts on all of x_train, including any rows later held
# out via `validation_split` in `fit` — confirm that small leakage is acceptable.
data_augmentation.layers[0].adapt(x_train)



def mlp(x, hidden_units, dropout_rate):
    """Apply one GELU ``Dense`` layer followed by ``Dropout`` per entry of ``hidden_units``.

    Args:
        x: input tensor (symbolic or eager) fed to the first Dense layer.
        hidden_units: iterable of layer widths, applied in order.
        dropout_rate: dropout rate used after every Dense layer.

    Returns:
        The output of the last Dropout layer, or ``x`` unchanged when
        ``hidden_units`` is empty.
    """
    out = x
    for width in hidden_units:
        dense = layers.Dense(width, activation=keras.activations.gelu)
        drop = layers.Dropout(dropout_rate)
        out = drop(dense(out))
    return out


class Patches(layers.Layer):
    """Cut a batch of images into flattened, non-overlapping square patches.

    ``call`` expects a 4-D batch (batch, height, width, channels) and returns
    (batch, num_patches, patch_size * patch_size * channels). Because the
    extraction uses ``padding="VALID"``, border pixels that do not fill a
    whole patch are dropped.
    """

    def __init__(self, patch_size, **kwargs):
        # Forward base-layer kwargs (name, trainable, dtype, ...) so that
        # `Patches.from_config(layer.get_config())` can rebuild the layer.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # The channel count must be static so the flattened patch width is known
        # to downstream Dense layers.
        channels = images.shape[-1]
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images, sizes=window, strides=window, rates=[1, 1, 1, 1], padding="VALID"
        )

        # extract_patches yields (batch, grid_h, grid_w, patch_area * channels).
        # Use the static grid size when known, otherwise the runtime size, so
        # inputs with undefined height/width still work.
        grid_h, grid_w = patches.shape[1], patches.shape[2]
        if grid_h is None or grid_w is None:
            runtime_shape = tf.shape(patches)
            grid_h, grid_w = runtime_shape[1], runtime_shape[2]

        return tf.reshape(
            patches,
            (batch_size, grid_h * grid_w, self.patch_size * self.patch_size * channels),
        )

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


In [None]:
# Visual sanity check: show one random training image, then the grid of
# patch_size x patch_size tiles the `Patches` layer cuts it into after the
# image is resized to image_size x image_size.
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(len(x_train))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

# `Patches` operates on batches, so give the single image a leading batch axis.
batch_of_one = image[np.newaxis, ...]
resized_image = tf.image.resize(batch_of_one, size=(image_size, image_size))
patches = Patches(patch_size)(resized_image)

num_tiles = patches.shape[1]
summary_lines = [
    f"Image size: {image_size} X {image_size}",
    f"Patch size: {patch_size} X {patch_size}",
    f"Patches per image: {num_tiles}",
    f"Elements per patch: {patches.shape[-1]}",
]
print("\n".join(summary_lines))

# Lay the tiles out on a square grid (num_tiles is a perfect square here).
grid_side = int(np.sqrt(num_tiles))
plt.figure(figsize=(4, 4))
for tile_index, tile in enumerate(patches[0], start=1):
    plt.subplot(grid_side, grid_side, tile_index)
    tile_pixels = tf.reshape(tile, (patch_size, patch_size, 3)).numpy()
    plt.imshow(tile_pixels.astype("uint8"))
    plt.axis("off")

In [None]:

class PatchEncoder(layers.Layer):
    """Project patches, prepend a learnable [CLS] token and add position embeddings.

    Args:
        num_patches: number of patches per image, not counting the [CLS] token.
        projection_dim: width of each token embedding.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).

    ``call`` returns a tensor of shape (batch, num_patches + 1, projection_dim)
    whose first token along axis 1 is the [CLS] token.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments as given so get_config round-trips.
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        # One extra position for the [CLS] token prepended in `call`.
        self.sequence_length = num_patches + 1
        # Register the token via add_weight so it is a tracked trainable
        # weight of the layer, and is therefore optimised and saved with the
        # model. It keeps the original N(0, 1) initialisation.
        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, projection_dim),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
        )
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=self.sequence_length, output_dim=projection_dim
        )

    def call(self, patch):
        # (1, sequence_length) position ids, broadcast over the batch below.
        positions = tf.range(start=0, limit=self.sequence_length, delta=1)[tf.newaxis, :]
        batch_size = tf.shape(patch)[0]
        cls_tokens = tf.repeat(self.cls_token, batch_size, axis=0)
        projected_patches = self.projection(patch)
        tokens = tf.concat([cls_tokens, projected_patches], axis=1)
        return tokens + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config

class My_MultiHeadAttention(layers.Layer):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: width of the input and output embeddings; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.
        dropout_rate: dropout rate applied to the output projection. It is
            active only in training mode.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.heads = num_heads  # Number of attention heads
        self.projection_dim = embed_dim // num_heads  # Per-head width
        self.dropout_rate = dropout_rate

        # Dense layers for the query, key and value projections, in that order.
        self.qkv = [layers.Dense(embed_dim) for _ in range(3)]
        # Dense layer that mixes the concatenated heads back to embed_dim.
        self.out = layers.Dense(embed_dim)
        # A Keras Dropout layer, unlike tf.nn.dropout, is the identity when
        # training is False (evaluate / predict).
        self.out_dropout = layers.Dropout(dropout_rate)

    def split_heads(self, vector, n):
        """Reshape (batch, seq_len, embed_dim) to (batch, num_heads, seq_len, projection_dim).

        ``n`` is the batch size, passed in as a scalar tensor.
        """
        # (batch, seq_len, num_heads, projection_dim)
        vector = tf.reshape(vector, (n, -1, self.heads, self.projection_dim))
        # -> (batch, num_heads, seq_len, projection_dim)
        return tf.transpose(vector, perm=[0, 2, 1, 3])

    def compute_attention(self, q, k, v):
        """Compute softmax(q k^T / sqrt(d_k)) v independently per head."""
        score = tf.matmul(q, k, transpose_b=True)
        # Cast the scale to the score dtype so this also works under
        # non-float32 (e.g. mixed-precision) computation.
        depth = tf.cast(tf.shape(k)[-1], score.dtype)
        weights = tf.nn.softmax(score / tf.math.sqrt(depth), axis=-1)
        return tf.matmul(weights, v)

    def call(self, inputs, training=None):
        n = tf.shape(inputs)[0]  # Batch size
        # Project inputs to q, k, v, then split each into heads.
        q, k, v = [dense(inputs) for dense in self.qkv]
        q, k, v = [self.split_heads(x, n) for x in (q, k, v)]

        attention = self.compute_attention(q, k, v)

        # (batch, heads, seq, proj) -> (batch, seq, heads, proj) -> (batch, seq, embed_dim).
        # The last dim is static, so the following Dense layer can be built.
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention, (n, -1, self.embed_dim))

        output = self.out(concat_attention)
        return self.out_dropout(output, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "num_heads": self.heads,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config


ViT model

In [None]:
from tensorflow import math, matmul, reshape, shape, transpose, cast, float32
from tensorflow.keras.layers import Dense, Layer
from keras.backend import softmax


def create_vit_classifier():
    """Build the (uncompiled) Vision Transformer classifier.

    The architecture is read from the notebook's module-level configuration:
    input_shape, data_augmentation, patch_size, num_patches, projection_dim,
    transformer_layers, num_heads, transformer_units, dropout_rate,
    mlp_head_units and num_classes.

    Returns:
        A ``keras.Model`` mapping images to ``num_classes`` raw logits. The
        final Dense layer has no activation.
    """
    head_dropout_rate = 0.5  # heavier dropout for the classifier head

    inputs = keras.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches; a [CLS] token is prepended at index 0.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Pre-norm Transformer blocks.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head self-attention.
        attention_output = My_MultiHeadAttention(
            embed_dim=projection_dim, num_heads=num_heads, dropout_rate=dropout_rate
        )(x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=dropout_rate)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Classify from the [CLS] token's final embedding: (batch, projection_dim).
    encoded_cls = encoded_patches[:, 0]
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_cls)
    representation = layers.Dropout(head_dropout_rate)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=head_dropout_rate)
    # Classify outputs (logits; the loss applies softmax).
    logits = layers.Dense(num_classes)(features)
    return keras.Model(inputs=inputs, outputs=logits)


Training the model

In [None]:
def plot_history(history, item):
    """Plot the per-epoch training curve of ``item`` and, if recorded, its validation curve.

    Args:
        history: the ``History`` object returned by ``Model.fit``.
        item: metric key in ``history.history``, e.g. ``"loss"`` or ``"accuracy"``.
    """
    fig, ax = plt.subplots()
    ax.plot(history.history[item], label=item)
    val_key = "val_" + item
    # "val_*" keys only exist when fit() received validation data; skip the
    # curve instead of raising KeyError.
    if val_key in history.history:
        ax.plot(history.history[val_key], label=val_key)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(item)
    ax.set_title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    ax.legend()
    ax.grid()
    plt.show()

vit_classifier = create_vit_classifier()

optimizer = keras.optimizers.AdamW(
    learning_rate=learning_rate, weight_decay=weight_decay
)

vit_classifier.compile(
    optimizer=optimizer,
    # The model's last layer emits raw logits, hence from_logits=True.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Keep only the weights from the epoch with the best validation accuracy.
checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor="val_accuracy",
    save_best_only=True,
    save_weights_only=True,
)

history = vit_classifier.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=0.1,
    callbacks=[checkpoint_callback],
)
plot_history(history, "loss")
plot_history(history, "accuracy")

# Score the best checkpoint, not the last epoch, on the untouched test split.
vit_classifier.load_weights(checkpoint_filepath)
_, test_accuracy = vit_classifier.evaluate(x_test, y_test, batch_size=batch_size)
print(f"Test accuracy: {test_accuracy * 100:.2f}%")