# Image classification with vision transformer using a convolutional stem
Taken from an example implementation of ViT by Khalid Salama

Source: https://keras.io/examples/vision/image_classification_with_vision_transformer/

---

Computer Vision Project WS2021/22

By Maria R. Lily Djami

This notebook shows the usage of the ViTc model we implemented for this project, as well as the ViT model. The models and the training method are defined in the same notebook so that it can easily be uploaded to platforms such as Kaggle and run on GPUs, since we do not otherwise have access to machines with a GPU.


In [None]:
import json
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

In [None]:
# Make sure that a GPU is visible to TensorFlow before any model is built.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
# Stop right away when the list of GPUs is empty.
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
# Memory growth is switched on for the first listed GPU only.
# NOTE(review): `config` only stores the return value of this call and is not
# used anywhere else in this notebook -- presumably it is None; confirm.
config = tf.config.experimental.set_memory_growth(physical_devices[0], True)

## Dataset and Hyperparameters

In [None]:
# Load dataset
# CIFAR-10 is fetched through Keras, already split into a train and a test part.
num_classes = 10  # number of logits produced by the final Dense layer of both classifiers
input_shape = (32, 32, 3)  # shape given to layers.Input in both model builders

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Print the array shapes as a quick sanity check of what was loaded.
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

In [None]:
# Hyperparameter definitions (read as module-level globals by the cells below).
learning_rate = 0.001  # passed to the AdamW optimizer in run_experiment
weight_decay = 0.0001  # passed to the AdamW optimizer in run_experiment
batch_size = 256  # passed to model.fit in run_experiment
num_epochs = 10  # passed to model.fit in run_experiment
# Images are resized to image_size x image_size inside data_augmentation.
# The original example uses 72, but this is too small to be used with the 18GF stem.
image_size = 72
projection_dim = 64  # feature size of each token entering the transformer blocks
# NOTE(review): both model builders assign their own local num_heads based on
# the requested size, so this module-level value is shadowed there.
num_heads = 3
# Hidden sizes of the MLP inside every transformer block.
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  

# Number of transformer blocks (create_vitc_classifier builds one fewer).
transformer_layers = 12
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier

In [None]:
# Preprocessing / augmentation pipeline shared by both classifiers.
# The layers are listed in the order in which they are applied.
_augmentation_steps = [
    layers.Normalization(),
    layers.Resizing(image_size, image_size),
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(factor=0.02),
    layers.RandomZoom(height_factor=0.2, width_factor=0.2),
]
data_augmentation = keras.Sequential(_augmentation_steps, name="data_augmentation")

# Compute the mean and the variance of the training data for normalization.
# The Normalization layer is the first entry of the pipeline.
data_augmentation.layers[0].adapt(x_train)

## Implement multilayer perceptron (MLP)

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Stack one GELU Dense layer followed by Dropout per entry of ``hidden_units``.

    Args:
        x: Input tensor the first Dense layer is applied to.
        hidden_units: Iterable with the number of units of each Dense layer,
            in the order the layers are applied.
        dropout_rate: Rate given to every Dropout layer.

    Returns:
        The output of the last Dropout layer, or ``x`` unchanged when
        ``hidden_units`` is empty.
    """
    out = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(hidden)
    return out

## Build the ViTp model

This is the original vision transformer model. Implementation is taken from https://keras.io/examples/vision/image_classification_with_vision_transformer/. 

Adjustments were made so that the size of the ViTp model can be easily adjusted using a function parameter, similar to how it is with the ViTc model.

In [None]:
patch_size = 6  # Side length of the square patches to be extracted from the (resized) input images
num_patches = (image_size // patch_size) ** 2  # patches per image: patches per side, squared

# The patchify operation is implemented as two layers:
# the Patches layer and the PatchEncoder layer.

class Patches(layers.Layer):
    """Cut a batch of images into non-overlapping square patches.

    Args:
        patch_size: Side length of each square patch; it is used both as the
            window size and as the stride, so the patches do not overlap.
        **kwargs: Standard Keras layer arguments (for example ``name``),
            forwarded unchanged to ``layers.Layer``.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the patches of ``images``, one flattened patch per row.

        The result is reshaped to ``[batch, -1, patch_dims]``, where
        ``patch_dims`` is the size of the last axis produced by
        ``tf.image.extract_patches``.
        """
        # Dynamic batch dimension; named n_images so it is not confused with
        # the module-level batch_size hyperparameter.
        n_images = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        # Collapse the patch grid into a single sequence axis.
        return tf.reshape(patches, [n_images, -1, patch_dims])

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
    
class PatchEncoder(layers.Layer):
    """Project flattened patches linearly and add a learned position embedding.

    Args:
        num_patches: Number of patches per image; also the number of entries
            of the position embedding table.
        projection_dim: Output size of both the Dense projection and the
            position embedding.
        **kwargs: Standard Keras layer arguments (for example ``name``),
            forwarded unchanged to ``layers.Layer``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept only so that get_config can report it.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return ``projection(patch)`` plus the embedding of each patch index."""
        # One index per patch position: 0 .. num_patches - 1.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config
    
def create_vit_classifier(size="1GF"):
    """Build the vision transformer with the patchify stem (ViTp).

    The model reads the module-level hyperparameters ``input_shape``,
    ``patch_size``, ``num_patches``, ``projection_dim``, ``transformer_layers``,
    ``transformer_units``, ``mlp_head_units`` and ``num_classes``, and applies
    ``data_augmentation`` to its input.

    Args:
        size: One of ``"1GF"``, ``"4GF"`` or ``"18GF"``. In this implementation
            it only selects the number of attention heads (3, 6 or 12).

    Returns:
        An uncompiled ``keras.Model`` mapping images to ``num_classes`` logits.

    Raises:
        ValueError: If ``size`` is not one of the supported values.
    """
    # Validate first so that an invalid size fails loudly before any layer is
    # created, instead of handing a non-model value back to the caller.
    heads_by_size = {"18GF": 12, "4GF": 6, "1GF": 3}
    if size not in heads_by_size:
        raise ValueError(
            f"Invalid size {size!r}; expected one of {sorted(heads_by_size)}"
        )
    # Local on purpose: takes precedence over the module-level num_heads.
    num_heads = heads_by_size[size]

    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)

    ### START OF STEM ###
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    ### END OF STEM ###

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer (self-attention: query == value).
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Normalize, then flatten all tokens of an image into one feature vector.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs (raw logits, no softmax).
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)

## Build the ViTc model

The ViTc model, introduced in "Early Convolutions Help Transformers See Better" (Xiao et al., 2021), replaces the patchify stem with convolutional layers. The model below implements the 1GF, 4GF, and 18GF variants of the ViTc model. They are currently hardcoded into the model.

In [None]:
# Settings of each hard-coded stem variant:
# (number of attention heads, output channels of each stride-2 convolution).
_VITC_STEM_SPECS = {
    "18GF": (12, (64, 128, 128, 256, 256, 512)),
    "4GF": (6, (48, 96, 192, 384)),
    "1GF": (3, (24, 48, 96, 192)),
}


def create_vitc_classifier(stem='1GF', kernel=3):
    """Build the vision transformer with a convolutional stem (ViTc).

    The stem is a stack of Conv2D -> BatchNormalization -> ReLU groups followed
    by a 1x1 convolution that maps to ``projection_dim`` channels. The model
    reads the module-level hyperparameters ``input_shape``, ``projection_dim``,
    ``transformer_layers``, ``transformer_units``, ``mlp_head_units`` and
    ``num_classes``, and applies ``data_augmentation`` to its input.

    Args:
        stem: One of ``"1GF"``, ``"4GF"`` or ``"18GF"``; selects the channel
            widths of the stem and the number of attention heads.
        kernel: Kernel size of every stride-2 convolution in the stem.

    Returns:
        An uncompiled ``keras.Model`` mapping images to ``num_classes`` logits.

    Raises:
        ValueError: If ``stem`` is not one of the supported values.
    """
    # Validate first so that an invalid stem fails loudly before any layer is
    # created, instead of handing a non-model value back to the caller.
    if stem not in _VITC_STEM_SPECS:
        raise ValueError(
            f"Invalid stem design given! Got {stem!r}; "
            f"expected one of {sorted(_VITC_STEM_SPECS)}"
        )
    # Local on purpose: takes precedence over the module-level num_heads.
    num_heads, stem_channels = _VITC_STEM_SPECS[stem]

    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)

    ####################################################
    #   Here is the stem of the vision transformer.    #
    #   For ViTc, we want to use convolutional layers  #
    #   instead of patchify to get the encoded image   #
    ####################################################

    cnn_stem = keras.Sequential()
    # NOTE(review): Conv2D is called without a padding argument (presumably the
    # Keras default, "valid"). With image_size=72 and kernel=3 the six stride-2
    # convolutions of the 18GF stem would shrink 72 -> 35 -> 17 -> 8 -> 3 -> 1
    # and leave nothing for the sixth one -- confirm the intended strides/padding.
    for channels in stem_channels:
        cnn_stem.add(layers.Conv2D(channels, kernel, strides=(2, 2)))
        cnn_stem.add(layers.BatchNormalization())
        cnn_stem.add(layers.ReLU())

    # 1x1 convolution that maps the stem output to the transformer width.
    cnn_stem.add(layers.Conv2D(projection_dim, 1, strides=(1, 1)))

    # NOTE(review): the stem output is passed on without a reshape and without
    # a position embedding, unlike the PatchEncoder path of
    # create_vit_classifier -- confirm this is intended.
    encoded_cnn = cnn_stem(augmented)

    ######################################################
    #   The code block below is the transformer block.   #
    #   This remains the same between ViTp and ViTc.     #
    ######################################################

    # Create multiple layers of the Transformer block. One block fewer than
    # transformer_layers is built here (presumably to offset the cost of the
    # convolutional stem -- confirm).
    for _ in range(transformer_layers - 1):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_cnn)
        # Create a multi-head attention layer (self-attention: query == value).
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_cnn])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_cnn = layers.Add()([x3, x2])

    # Normalize, then flatten the whole encoded image into one feature vector.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_cnn)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs (raw logits, no softmax).
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)

## Compile, train, and evaluate the model

In [None]:
def run_experiment(model, checkpoint_filepath="tmp"):
    """Compile, train and evaluate ``model`` on the module-level CIFAR-10 arrays.

    Training uses ``x_train``/``y_train`` with ``validation_split=0.1`` and the
    module-level ``batch_size``, ``num_epochs``, ``learning_rate`` and
    ``weight_decay``. The weights of the epoch with the best ``val_accuracy``
    are written to ``checkpoint_filepath`` and loaded again before the model is
    evaluated on ``x_test``/``y_test``; both test metrics are printed.

    Args:
        model: Uncompiled Keras model that outputs logits.
        checkpoint_filepath: Where the best weights are stored. This path may
            need to be changed; sometimes errors occur if the full path is not
            given. Defaults to ``"tmp"``.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        # The classifiers end in a Dense layer without activation, hence logits.
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Keep only the weights of the best epoch, judged by validation accuracy.
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    # Evaluate the best checkpoint rather than the weights of the last epoch.
    model.load_weights(checkpoint_filepath)
    # evaluate returns the loss followed by the two compiled metrics.
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


## Run Experiments

Run the experiment with ViTp and ViTc. The training results (training accuracy and loss, as well as validation accuracy and loss) are saved to a JSON file, which can be loaded again later and plotted using matplotlib.

In [None]:
# Train and evaluate the ViTc model with the 1GF stem, then store its training curves.
vitc_classifier = create_vitc_classifier(stem='1GF')
history = run_experiment(vitc_classifier)
# A context manager makes sure the JSON file is flushed and closed right away.
with open("vitc_1gf_cifar10_10epochs.json", "w") as history_file:
    json.dump(history.history, history_file)

In [None]:
# Train and evaluate the ViTp model (1GF size), then store its training curves.
vit_classifier = create_vit_classifier(size='1GF')
history = run_experiment(vit_classifier)
# A context manager makes sure the JSON file is flushed and closed right away.
with open("vitp_1gf_cifar10_10epochs.json", "w") as history_file:
    json.dump(history.history, history_file)