# Image classification with a Vision Transformer, with / without a convolutional stem
Source: https://keras.io/examples/vision/image_classification_with_vision_transformer/#build-the-vit-model






In [None]:
import json
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

In [None]:
# Require at least one GPU and let TensorFlow grow its memory use on the first
# one on demand instead of reserving it all up front.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert physical_devices, "Not enough GPU hardware devices available"
# set_memory_growth's result is kept in `config` only to mirror the original cell.
config = tf.config.experimental.set_memory_growth(
    device=physical_devices[0],
    enable=True,
)

In [None]:
import pickle
# Load the mini-ImageNet training cache and split it per class into train/test.
#
# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only load caches that come from a trusted source.
path = '/kaggle/input/miniimagenet/mini-imagenet-cache-train.pkl'
with open(path, 'rb') as f:
    data = pickle.load(f)

num_classes = 64
input_shape = (84, 84, 3)

# The split below assumes the cache stores images class by class, with
# `samples_per_class` consecutive images per class (label = index // 600).
samples_per_class = 600
train_per_class = 540
test_per_class = samples_per_class - train_per_class

x = np.array(data['image_data'])
del data

# View the images as (class, sample, ...) so that the per-class split is two
# slices rather than 64 successive concatenations that each re-copy the
# growing result.
x = x[:num_classes * samples_per_class]
x = x.reshape((num_classes, samples_per_class) + x.shape[1:])
# Slicing the sample axis is non-contiguous, so these reshapes copy the data
# and `x` can be released afterwards.
x_train = x[:, :train_per_class].reshape((-1,) + x.shape[2:])
x_test = x[:, train_per_class:].reshape((-1,) + x.shape[2:])
del x

# Integer class ids, repeated once per image, as column vectors:
# (34560, 1) for training and (3840, 1) for testing.
class_ids = np.arange(num_classes)
y_train = np.repeat(class_ids, train_per_class).reshape(-1, 1)
y_test = np.repeat(class_ids, test_per_class).reshape(-1, 1)
del class_ids

In [None]:
# --- Optimisation ---
learning_rate = 1e-3
weight_decay = 1e-4  # only referenced by the commented-out AdamW optimizer
batch_size = 120
num_epochs = 100  # 50, 200

# --- Input ---
image_size = 224  # images are resized to image_size x image_size

# --- Transformer encoder ---
projection_dim = 64
num_heads = 3
transformer_layers = 12
# Hidden sizes of the MLP inside each transformer block.
transformer_units = [projection_dim * 2, projection_dim]

# --- Classification head ---
# Size of the dense layers of the final classifier.
mlp_head_units = [2048, 1024]

## Use data augmentation

In [None]:
# Preprocessing / augmentation pipeline that both classifiers apply to their
# inputs. Layers run in the order they are added.
data_augmentation = keras.Sequential(name="data_augmentation")
data_augmentation.add(layers.Normalization())
data_augmentation.add(layers.Resizing(image_size, image_size))
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(factor=0.02))
data_augmentation.add(layers.RandomZoom(height_factor=0.2, width_factor=0.2))

# The Normalization layer (first in the pipeline) needs the mean and variance
# of the training data before it can be used.
data_augmentation.layers[0].adapt(x_train)

## Implement multilayer perceptron (MLP)

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Args:
        x: Input tensor.
        hidden_units: Iterable with the width of each Dense layer, in order.
        dropout_rate: Rate passed to the Dropout layer after every Dense layer.

    Returns:
        The output tensor of the last Dropout layer (or `x` unchanged when
        `hidden_units` is empty).
    """
    out = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(hidden)
    return out

## Build the ViTC model


In [None]:
# Output channels of the successive 3x3, stride-2 convolutions of each stem.
_STEM_FILTERS = {
    '18GF': (64, 128, 128, 256, 256, 512),
    '4GF': (48, 96, 192, 384),
    '1GF': (24, 48, 96, 192),
}


def _build_conv_stem(filters):
    """Build the convolutional stem for the given per-stage channel counts.

    Each stage is Conv2D(3x3, stride 2) -> BatchNormalization -> ReLU; a final
    1x1 convolution projects the features to `projection_dim` channels.
    """
    cnn_stem = keras.Sequential()
    for n_filters in filters:
        cnn_stem.add(layers.Conv2D(n_filters, 3, strides=(2, 2)))
        cnn_stem.add(layers.BatchNormalization())
        cnn_stem.add(layers.ReLU())
    cnn_stem.add(layers.Conv2D(projection_dim, 1, strides=(1, 1)))
    return cnn_stem


def create_vitc_classifier(stem='1GF'):
    """Build the ViT classifier whose patch embedding is a convolutional stem.

    The input goes through the module-level `data_augmentation` pipeline, the
    convolutional stem selected by `stem`, `transformer_layers - 1` pre-norm
    transformer blocks, and finally a flatten + MLP classification head.

    Args:
        stem: Stem variant, one of '1GF', '4GF' or '18GF'.

    Returns:
        A `keras.Model` mapping images of shape `input_shape` to `num_classes`
        logits. For an unknown `stem` a message is printed and -1 is returned
        (kept for backward compatibility with existing callers).
    """
    # Validate before any layer is created so a typo does not build half a model.
    filters = _STEM_FILTERS.get(stem)
    if filters is None:
        print("Invalid stem design given!")
        return -1

    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)

    # A convolutional stem replaces the patch extraction + linear projection.
    # Its output stays a 4D feature map and is passed to the attention layers
    # as is, without being flattened into a token sequence.
    encoded_cnn = _build_conv_stem(filters)(augmented)

    # Transformer blocks; one fewer than create_vit_classifier uses
    # (NOTE(review): presumably to offset the cost of the stem; confirm).
    for _ in range(transformer_layers - 1):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_cnn)
        # Multi-head self-attention (query and value are the same tensor).
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_cnn])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_cnn = layers.Add()([x3, x2])

    # Flatten every position and channel into one feature vector per image.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_cnn)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)

## Build the ViT model


In [None]:
patch_size = 16  # Side length of the square patches to be extracted from the input images
# Patches per image: a square grid with image_size // patch_size patches per side.
num_patches = (image_size // patch_size) ** 2

class Patches(layers.Layer):
    """Layer that cuts images into non-overlapping square patches.

    Args:
        patch_size: Side length of each patch; also used as the stride, so
            patches do not overlap.
        **kwargs: Standard `Layer` keyword arguments (e.g. `name`).
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the flattened patches of `images`.

        Assumes `images` is a 4D (batch, height, width, channels) tensor, as
        `tf.image.extract_patches` expects. The result has shape
        (batch, number of patches, values per patch).
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Last axis holds every value of one patch; merge the patch grid
        # axes into a single sequence axis.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
    
class PatchEncoder(layers.Layer):
    """Layer that linearly projects patches and adds a learned position embedding.

    Args:
        num_patches: Number of patches per image; sets the size of the
            position-embedding table.
        projection_dim: Output size of both the projection and the embedding.
        **kwargs: Standard `Layer` keyword arguments (e.g. `name`).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so that get_config can reproduce the constructor call.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return `projection(patch)` plus the embedding of positions 0..num_patches-1."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config
    
def create_vit_classifier(size="1GF"):
    """Build the patch-based ViT classifier.

    The input goes through the module-level `data_augmentation` pipeline, is
    split into `patch_size` patches, encoded with `PatchEncoder`, processed by
    `transformer_layers` pre-norm transformer blocks, and classified by a
    flatten + MLP head.

    Args:
        size: Model size, one of '1GF', '4GF' or '18GF'; it selects the number
            of attention heads (3, 6 or 12 respectively).

    Returns:
        A `keras.Model` mapping images of shape `input_shape` to `num_classes`
        logits. For an unknown `size` a message is printed and -1 is returned
        (kept for backward compatibility with existing callers).
    """
    # Validate before any layer is created so a typo does not build half a model.
    heads_by_size = {"18GF": 12, "4GF": 6, "1GF": 3}
    if size not in heads_by_size:
        print("error!")
        return -1
    attention_heads = heads_by_size[size]

    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head self-attention (query and value are the same tensor).
        attention_output = layers.MultiHeadAttention(
            num_heads=attention_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Flatten every patch and channel into one feature vector per image.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)

## Compile, train, and evaluate the model

In [None]:
def run_experiment(model):
    """Compile, train and evaluate `model` on the module-level dataset.

    Trains with SGD (momentum 0.9) for `num_epochs` epochs, validating on the
    test split. The weights with the best validation accuracy are
    checkpointed, reloaded after training, and used for the final evaluation,
    whose top-1 and top-5 accuracies are printed.

    Args:
        model: An uncompiled Keras model producing `num_classes` logits.

    Returns:
        The `History` object returned by `model.fit`.
    """
    # An AdamW optimizer (tfa.optimizers.AdamW with `weight_decay`) was the
    # earlier choice here; plain SGD with momentum is what is used now.
    model.compile(
        optimizer=tf.optimizers.SGD(learning_rate=learning_rate, momentum=0.9),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Keep only the weights of the epoch with the best validation accuracy.
    best_weights_path = "/tmp/checkpoint"
    keep_best = keras.callbacks.ModelCheckpoint(
        filepath=best_weights_path,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    training_log = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_data=(x_test, y_test),
        callbacks=[keep_best],
    )

    # Report the metrics of the best checkpoint, not of the last epoch.
    model.load_weights(best_weights_path)
    _loss, top_1, top_5 = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(top_1 * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5 * 100, 2)}%")

    return training_log


Training is run in the cells below. To train a different model size (1GF, 4GF or 18GF), change the size string in the corresponding cell.

In [None]:
# Stem variant to train: '1GF', '4GF' or '18GF'.
vitc_stem = '1GF'
vitc_classifier = create_vitc_classifier(stem=vitc_stem)
history = run_experiment(vitc_classifier)
# Name the results file after the settings actually used so it cannot go stale
# when the stem or the epoch count is changed.
with open(f"vitc_{vitc_stem.lower()}_mini_imagenet_{num_epochs}epochs.json", "w") as f:
    json.dump(history.history, f)

In [None]:
# Model size to train: '1GF', '4GF' or '18GF'.
vit_size = '1GF'
vit_classifier = create_vit_classifier(size=vit_size)
history = run_experiment(vit_classifier)
# Name the results file after the settings actually used so it cannot go stale
# when the size or the epoch count is changed.
with open(f"vitp_{vit_size.lower()}_mini_imagenet_{num_epochs}epochs.json", "w") as f:
    json.dump(history.history, f)