# Vision Transformer

In [None]:
import tensorflow as tf
import tensorflow.keras.layers as layers
import keras
import numpy as np
import matplotlib.pyplot as plt

@keras.saving.register_keras_serializable(package="vit_model", name="MultiHeadAttention")
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with three Dense layers,
    split into ``num_heads`` heads of size ``embed_dim // num_heads``, attended
    per head, then the heads are concatenated and projected back to
    ``embed_dim`` by a final Dense layer.

    Args:
        embed_dim: Size of the last (feature) axis of the input and output.
            Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        index: Integer suffix used to name the sub-layers (``query_{index}``,
            ``key_{index}``, ``value_{index}``, ``out_{index}``) so several
            instances get distinct sub-layer names.
        **kwargs: Standard Keras layer arguments (``name``, ``trainable``,
            ``dtype``, ...). Accepting them lets the default
            ``Layer.from_config`` rebuild the layer from ``get_config()``,
            which includes those keys.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads=8, index=0, **kwargs):
        super(MultiHeadSelfAttention, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}"
            )
        self.index = index
        self.projection_dim = embed_dim // num_heads
        self.query_dense = layers.Dense(embed_dim, name=f'query_{index}')
        self.key_dense = layers.Dense(embed_dim, name=f"key_{index}")
        self.value_dense = layers.Dense(embed_dim, name=f'value_{index}')
        self.combine_heads = layers.Dense(embed_dim, name=f'out_{index}')

    def attention(self, query, key, value):
        """Compute softmax(query @ key^T / sqrt(d_k)) @ value.

        Returns:
            ``(output, weights)`` where ``weights`` is the softmax over the last
            (key) axis of the scaled scores.
        """
        score = tf.matmul(query, key, transpose_b=True)
        # Cast to the scores' dtype (rather than a hard-coded float32) so the
        # division also works when the layer computes in float16/bfloat16.
        dim_key = tf.cast(tf.shape(key)[-1], score.dtype)
        scaled_score = score / tf.math.sqrt(dim_key)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        """Split the feature axis into heads.

        Reshapes to ``(batch, -1, num_heads, projection_dim)`` and transposes to
        ``(batch, num_heads, -1, projection_dim)`` so attention runs per head.
        """
        x = tf.reshape(
            x, (batch_size, -1, self.num_heads, self.projection_dim)
        )
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def get_config(self):
        """Return the base layer config plus the constructor arguments."""
        config = super().get_config()
        config.update({"embed_dim": self.embed_dim,
                       "num_heads": self.num_heads,
                       "index": self.index,
                       })
        return config

    def call(self, inputs, training=False):
        """Apply self-attention to ``inputs``.

        Assumes ``inputs`` is (batch, seq_len, embed_dim) — TODO confirm for
        other callers; the final reshape produces that layout. ``training`` is
        accepted for API compatibility and not used (no dropout here).
        """
        batch_size = tf.shape(inputs)[0]
        query = self.separate_heads(self.query_dense(inputs), batch_size)
        key = self.separate_heads(self.key_dense(inputs), batch_size)
        value = self.separate_heads(self.value_dense(inputs), batch_size)

        attention, weights = self.attention(query, key, value)
        # (batch, heads, seq, proj) -> (batch, seq, heads, proj) -> (batch, seq, embed_dim)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(
            attention, (batch_size, -1, self.embed_dim)
        )
        output = self.combine_heads(concat_attention)
        return output

@keras.saving.register_keras_serializable(package="vit_model", name="TransformerBlock")
class TransformerBlock(tf.keras.layers.Layer):
    """Pre-norm Transformer encoder block.

    Computes ``x + Dropout(MHSA(LayerNorm(x)))`` followed by
    ``y + MLP(LayerNorm(y))``, where the MLP is
    Dense(mlp_dim, gelu) -> Dropout -> Dense(embed_dim) -> Dropout.

    Args:
        embed_dim: Token feature size (input and output).
        num_heads: Number of attention heads; must divide ``embed_dim``.
        mlp_dim: Hidden size of the MLP.
        dropout: Dropout rate used after attention and inside the MLP.
        index: Position of the block; used in the default layer name
            ``transformer_block_{index}`` and in sub-layer names.
        **kwargs: Standard Keras layer arguments (e.g. ``name``).
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1, index=0, **kwargs):
        # Name the block after its position unless a name is supplied (e.g. by
        # from_config); passing it to the base constructor is the standard way
        # to name a Keras layer.
        if kwargs.get("name") is None:
            kwargs["name"] = f"transformer_block_{index}"
        super(TransformerBlock, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.index = index
        self.att = MultiHeadSelfAttention(embed_dim, num_heads, self.index)
        self.mlp = tf.keras.Sequential(
            [
                layers.Dense(mlp_dim, activation=keras.activations.gelu, name=f'enc_{self.index}_d1'),
                layers.Dropout(dropout),
                layers.Dense(embed_dim, name=f'enc_{index}_d2'),
                # This Dropout is the only dropout applied to the MLP output.
                layers.Dropout(dropout),
            ],
            name='mlp'
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)

    def get_config(self):
        """Return the base layer config plus the constructor arguments.

        The attention sub-layer is not stored: it is rebuilt from
        ``embed_dim``/``num_heads``/``index`` by ``__init__``.
        """
        config = super().get_config()
        config.update({"mlp_dim": self.mlp_dim,
                       "embed_dim": self.embed_dim,
                       "num_heads": self.num_heads,
                       "dropout": self.dropout,
                       "index": self.index,
                       })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the block from ``get_config()`` output."""
        config = dict(config)  # do not mutate the caller's dict
        config.pop('trainable', None)
        config.pop('dtype', None)
        # Configs saved by earlier versions of this class also carried the
        # attention layer; __init__ rebuilds it, so drop it if present.
        config.pop('att', None)
        return cls(**config)

    def call(self, inputs, training=None):
        """Apply the block.

        Args:
            inputs: Token embeddings; presumably (batch, seq_len, embed_dim) —
                the attention layer reshapes on that assumption.
            training: Enables dropout when True.

        Returns:
            A tensor with the same shape as ``inputs`` (both residual branches
            end in ``embed_dim`` features).
        """
        # Attention sub-block (pre-norm) with residual connection.
        attn_output = self.att(self.layernorm1(inputs))
        attn_output = self.dropout1(attn_output, training=training)
        out1 = attn_output + inputs

        # MLP sub-block (pre-norm) with residual connection. ``self.mlp``
        # already ends in Dropout, so no second dropout is applied here.
        mlp_output = self.mlp(self.layernorm2(out1), training=training)
        return mlp_output + out1

@keras.saving.register_keras_serializable(package="vit_model", name="ViT")
class ViT(keras.Model):
    """Vision Transformer image model.

    Images are cut into non-overlapping ``patch_size`` x ``patch_size`` patches
    by a Conv2D whose kernel size and stride both equal ``patch_size``. The
    patches are flattened into a token sequence and prefixed with a learnable
    class token. Learnable position embeddings are added, and the result is
    passed through ``depth`` TransformerBlocks.

    With ``include_head=True``, the class-token output goes through
    ``mlp_head`` (LayerNorm -> Dense(mlp_dim, tanh) -> Dense(num_classes)).
    That yields ``num_classes`` unnormalized scores (logits). Otherwise, the
    full encoded token sequence is returned.

    Args:
        image_size: Height and width of the square input images.
        patch_size: Patch side length; the grid has ``image_size // patch_size``
            patches per side.
        depth: Number of TransformerBlocks.
        num_classes: Output size of the classification head.
        hidden_dim: Token embedding size used throughout the encoder.
        num_heads: Attention heads per block; must divide ``hidden_dim``.
        mlp_dim: Hidden size of each block's MLP and of the head's tanh layer.
        dropout: Dropout rate inside the TransformerBlocks.
        name: Model name (also used by ``build_functional``).
        include_head: Whether to apply the classification head.
        channels: Number of input image channels.
        **kwargs: Extra ``keras.Model`` constructor arguments.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        depth,
        num_classes,
        hidden_dim,
        num_heads,
        mlp_dim,
        dropout=0.1,
        name="ViT",
        include_head=True,
        channels=3,
        **kwargs
    ):
        # Pass the name to the base constructor rather than assigning
        # ``self.name`` afterwards.
        super(ViT, self).__init__(name=name, **kwargs)
        self.image_size = image_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.depth = depth
        self.num_classes = num_classes
        self.mlp_dim = mlp_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout
        self.input_shape = (image_size, image_size, channels)
        self.include_head = include_head
        self.mlp_head = None
        self.channels = channels

        # Patch embedding: one hidden_dim vector per non-overlapping patch,
        # then flatten the patch grid into a sequence.
        self.conv_proj = layers.Conv2D(filters=self.hidden_dim,
                                       kernel_size=self.patch_size,
                                       strides=self.patch_size,
                                       name="patches_encoder")
        self.patch_flatten = layers.Reshape((self.num_patches,
                                             self.hidden_dim),
                                            name="patches_flatten")

        # Concatenates the class token in front of the patch tokens (axis 1).
        self.add_cls_emb = layers.Concatenate(name='concat_pos_embed', axis=1)

        # Create the embeddings with add_weight so Keras tracks them as
        # trainable weights; raw tf.Variable attributes are not tracked by
        # Keras 3 layers, so they would be left out of training and summary().
        # RandomNormal(0, 1) matches the previous tf.random.normal init.
        self.class_emb = self.add_weight(
            name="class_emb",
            shape=(1, 1, self.hidden_dim),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
            dtype="float32",
        )

        self.pos_emb = self.add_weight(
            name="pos_emb",
            shape=(1, self.num_patches + 1, self.hidden_dim),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
            trainable=True,
            dtype="float32",
        )

        self.transformer_encoder = [
            TransformerBlock(self.hidden_dim, self.num_heads, self.mlp_dim, self.dropout_rate, index)
            for index in range(self.depth)
        ]

        if self.include_head:
            # Classification head applied to the class-token output.
            self.mlp_head = tf.keras.Sequential(
                [
                    layers.LayerNormalization(epsilon=1e-6, name='mlp_head_ln'),
                    layers.Dense(self.mlp_dim, activation=tf.keras.activations.tanh, name='mlp_head_d1'),
                    layers.Dense(self.num_classes, name='classifier'),
                ],
                name='mlp_head'
            )

    def build_functional(self):
        """Wrap this model in a functional ``keras.Model`` with a fixed Input.

        Tracing the model on a symbolic ``Input`` of shape
        ``(image_size, image_size, channels)`` gives ``summary()`` concrete
        per-layer output shapes.
        """
        input_layer = layers.Input(shape=(self.image_size, self.image_size, self.channels), name='Input')
        output_layer = self(input_layer)
        return keras.Model(inputs=input_layer, outputs=output_layer, name=self.name)

    def get_config(self):
        """Return the base model config plus all constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "image_size": self.image_size,
                "patch_size": self.patch_size,
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "depth": self.depth,
                "mlp_dim": self.mlp_dim,
                "num_classes": self.num_classes,
                "dropout": self.dropout_rate,
                "name": self.name,
                "include_head": self.include_head,
                "channels": self.channels
            }
        )

        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the model from ``get_config()`` output.

        ``trainable``/``dtype`` are dropped when present; the base
        ``get_config`` does not always emit both keys, so pop with a default.
        """
        config = dict(config)  # do not mutate the caller's dict
        config.pop('trainable', None)
        config.pop('dtype', None)
        return cls(**config)

    def call(self, inputs, training=False):
        """Run the model.

        Args:
            inputs: Image batch; assumes channels-last
                (batch, image_size, image_size, channels) — the Conv2D default.
            training: Enables dropout inside the TransformerBlocks.

        Returns:
            Class logits (batch, num_classes) when ``include_head`` is True,
            otherwise the encoded tokens (batch, num_patches + 1, hidden_dim).
        """
        batch_size = tf.shape(inputs)[0]
        x = self.conv_proj(inputs)
        x = self.patch_flatten(x)

        # Prepend one copy of the class token per example.
        cls_emb = tf.broadcast_to(
            self.class_emb, [batch_size, 1, self.hidden_dim]
        )
        x = self.add_cls_emb([cls_emb, x])

        # Add positional embeddings (broadcast over the batch axis).
        x = x + self.pos_emb

        for block in self.transformer_encoder:
            x = block(x, training=training)

        # Classify from the class-token position.
        if self.include_head:
            x = self.mlp_head(x[:, 0])
        return x

#### Create ViT-B/16 and check that its parameter count matches the model from the original paper

In [None]:
# ViT-B/16 hyper-parameters. Every key is a ViT constructor argument, so the
# dict can be unpacked straight into the constructor (channels keeps its
# default of 3).
config = dict(
    image_size=224,
    patch_size=16,
    depth=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    num_classes=1000,
    dropout=0.1,
    name="ViT-B16",
    include_head=True,
)

vit_model = ViT(**config)

#### Build as a functional model to get the layer output shapes

In [None]:
# Swap in the functional wrapper (tracing on a symbolic Input gives summary()
# concrete per-layer output shapes), keep it as ``vit_model``, and print it.
(vit_model := vit_model.build_functional()).summary()

# Train a small ViT model on MNIST

In [None]:
import tensorflow_datasets as tfds
from tensorflow.keras.optimizers import AdamW

# ------------------------------------------------------------------------------------------------
# Load dataset and prepare model
# ------------------------------------------------------------------------------------------------
ds, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
ds_train, ds_test = ds['train'], ds['test']

# Side length images are resized to; also used as the ViT's image_size.
image_size = 28


def preprocess_image(image, label):
    """Resize to ``image_size`` x ``image_size``, scale pixels and cast the label.

    ``tf.image.resize`` returns float32; dividing by 255 maps 0-255 pixel
    values into [0, 1]. The label is cast to an integer class id, the form
    SparseCategoricalCrossentropy is designed for.
    """
    image = tf.image.resize(image, (image_size, image_size))
    image = tf.cast(image, tf.float32) / 255.0
    return image, tf.cast(label, tf.int64)


ds_train = ds_train.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)


batch_size = 128
# Shuffle individual examples *before* batching: shuffling after batch() only
# permutes whole batches, so every epoch would see the same fixed batches.
ds_train = ds_train.shuffle(10000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size).prefetch(tf.data.AUTOTUNE)

model = ViT(
    image_size=image_size,
    patch_size=4,
    depth=6,
    num_classes=10,
    hidden_dim=64,
    num_heads=4,
    mlp_dim=128,
    dropout=0.1,
    name="ViT_MNIST",
    include_head=True,
    channels=1
).build_functional()

# summary() prints the table itself and returns None.
model.summary()
# ------------------------------------------------------------------------------------------------
# Prepare training arguments
# ------------------------------------------------------------------------------------------------
num_epochs = 150
warmup_epochs = 10
peak_learning_rate = 1e-3
final_learning_rate = 1e-5
steps_per_epoch = len(ds_train)

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Linear warmup from final_learning_rate up to peak_learning_rate over
# warmup_epochs, then cosine decay back down to final_learning_rate.
# CosineDecay warms up *from* initial_learning_rate *to* warmup_target, then
# decays from warmup_target towards alpha * warmup_target, so the peak must be
# passed as warmup_target (not as initial_learning_rate).
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=final_learning_rate,
    decay_steps=steps_per_epoch * (num_epochs - warmup_epochs),
    alpha=final_learning_rate / peak_learning_rate,
    warmup_target=peak_learning_rate,
    warmup_steps=steps_per_epoch * warmup_epochs,
)
optimizer = AdamW(learning_rate=lr_schedule, weight_decay=5e-5)

model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
# ------------------------------------------------------------------------------------------------
# Train the model
# ------------------------------------------------------------------------------------------------
# True: quick end-to-end check on a single batch for one epoch.
# False: full run over the whole dataset for num_epochs with the schedule above.
SMOKE_TEST = True
if SMOKE_TEST:
    fit_train, fit_val, fit_epochs = ds_train.take(1), ds_test.take(1), 1
else:
    fit_train, fit_val, fit_epochs = ds_train, ds_test, num_epochs
history = model.fit(fit_train, validation_data=fit_val, epochs=fit_epochs, callbacks=[early_stopping])

## Check the training metrics

In [None]:
def plot_training_history(history):
    """Draw train vs. validation curves for accuracy (left) and loss (right).

    Args:
        history: The object returned by ``model.fit``. Its ``history`` dict must
            hold 'accuracy', 'val_accuracy', 'loss' and 'val_loss' series.
    """
    curves = history.history
    # (train key, validation key, title, y-axis label, legend location)
    panels = (
        ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy', 'upper left'),
        ('loss', 'val_loss', 'Model loss', 'Loss', 'upper right'),
    )

    plt.figure(figsize=(12, 6))
    for position, (train_key, val_key, title, y_label, legend_loc) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key])
        plt.plot(curves[val_key])
        plt.title(title)
        plt.ylabel(y_label)
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc=legend_loc)

    plt.tight_layout()
    plt.show()

plot_training_history(history)

## Try the model on some random images from the test set

In [None]:
def random_predictions(model, dataset, ds_info, num_images=8):
    """Show randomly chosen images with their true and predicted labels.

    One batch is picked at random from ``dataset``, and ``num_images`` distinct
    examples are sampled from it. A title is green when the prediction matches
    the label and red otherwise.

    Args:
        model: Keras model mapping an image batch to per-class scores (argmax
            is taken over the last axis).
        dataset: Batched dataset of ``(images, labels)`` whose number of batches
            is known (``len(dataset)`` must work).
        ds_info: TFDS info object; ``ds_info.features['label'].names`` supplies
            the class names.
        num_images: Number of images to show; capped at the size of the drawn
            batch. Images are laid out 4 per row.
    """
    class_names = ds_info.features['label'].names

    # Pick a random batch, then distinct random examples inside it. (Using a
    # random batch index as an example offset would only ever reach the first
    # len(dataset) examples and show consecutive images.)
    batch_index = np.random.randint(len(dataset))
    images, labels = next(iter(dataset.skip(batch_index).take(1)))
    images, labels = images.numpy(), labels.numpy()
    sample = np.random.choice(len(images), size=min(num_images, len(images)), replace=False)
    images, labels = images[sample], labels[sample]

    # argmax over logits equals argmax over softmax probabilities.
    predicted_classes = np.argmax(model.predict(images), axis=-1)

    num_cols = 4
    num_rows = -(-len(images) // num_cols)  # ceiling division
    plt.figure(figsize=(15, 5 * num_rows))
    for i, (image, label, predicted) in enumerate(zip(images, labels, predicted_classes)):
        plt.subplot(num_rows, num_cols, i + 1)
        # Drop a trailing single-channel axis so grayscale images render with
        # the gray colormap (matplotlib ignores cmap for RGB/RGBA images).
        if image.shape[-1] == 1:
            image = image[..., 0]
        plt.imshow((image * 255).astype(np.uint8), cmap='gray')
        true_class = int(label)
        color = 'green' if predicted == true_class else 'red'
        plt.title(f'True: {class_names[true_class]}\nPredicted: {class_names[predicted]}', color=color)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

random_predictions(model, ds_test, ds_info, num_images=8)

In [None]:
# return_dict=True returns metrics by name ('accuracy' as compiled above), so
# this does not depend on the metric's position in the result list.
test_accuracy = model.evaluate(ds_test, return_dict=True)['accuracy']
# Format with fixed precision: round(x, 2) * 100 can print float artifacts
# such as 99.00000000000001.
print(f'Test accuracy: {test_accuracy * 100:.2f} %')