### Vision Transformer

In [1]:
# --- Patch geometry ---------------------------------------------------------
IMAGE_SIZE = 256  # side length used to derive the patch count below
PATCH_SIZE = 16   # side length of one square patch
# (256 // 16) ** 2 = 16 ** 2 = 256: 16 patches along each side, 256 in total.
num_patche = (IMAGE_SIZE // PATCH_SIZE) ** 2

# --- Encoder hyperparameters ------------------------------------------------
projection_dim = 64  # output dimension of the linear (projection) layer
num_heads = 4        # number of heads in the multi-head attention layer
# Two widths derived from projection_dim: [128, 64].
# NOTE(review): presumably the per-block MLP sizes; confirm in the model cell.
transformer_units = [2 * projection_dim, projection_dim]
transformer_layers = 8  # number of stacked transformer layers

# --- Classification head ----------------------------------------------------
mlp_head_units = [2048, 1024]  # unit counts of the MLP head

Set up patch creation as a layer

In [3]:
import tensorflow as tf

In [4]:
class Patches(tf.keras.layers.Layer):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length of each square patch. It is also used as the
            stride, so neighbouring patches do not overlap.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``),
            forwarded to ``tf.keras.layers.Layer``. Accepting them is what
            lets ``from_config`` rebuild the layer when a model is reloaded.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Extract patches and flatten the patch grid into one sequence axis.

        Args:
            images: 4-D image tensor, as required by
                ``tf.image.extract_patches``; the first axis is the batch.

        Returns:
            Tensor of shape ``(batch_size, -1, patch_dims)``: one row per
            patch, where ``patch_dims`` is the last dimension produced by
            ``tf.image.extract_patches``.
        """
        # Dynamic batch size, so the reshape below works for any batch.
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            # Stride equals the patch size, so patches tile without overlap.
            strides=[1, self.patch_size, self.patch_size, 1],
            # A rate of 1 samples adjacent pixels (no dilation).
            rates=[1, 1, 1, 1],
            # "VALID" adds no padding.
            padding="VALID",
        )
        # Static last dimension; it must be known when the graph is built,
        # otherwise the reshape below receives None.
        patch_dims = patches.shape[-1]
        # Collapse the 2-D grid of patches into a single sequence axis.
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

In [6]:
class PatchEncoder(tf.keras.layers.Layer):
    """Linearly project patches and add a learned position embedding.

    Args:
        num_patches: Number of patch positions; this is the size of the
            position-embedding table and the length of the index range
            built in ``call``.
        projection_dim: Output width of both the dense projection and the
            position embedding.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``),
            forwarded to ``tf.keras.layers.Layer``. Accepting them is what
            lets ``from_config`` rebuild the layer when a model is reloaded.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so that get_config() can report it.
        self.projection_dim = projection_dim
        self.projection = tf.keras.layers.Dense(units=projection_dim)
        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Encode patches as ``projection(patch) + position_embedding``.

        Args:
            patch: Tensor of flattened patches.
                NOTE(review): assumed to be (batch, num_patches, patch_dims),
                e.g. the output of a patch-extraction layer; confirm at the
                call site.

        Returns:
            The projected patches with the position embeddings added.
        """
        # Position indices 0 .. num_patches - 1, one per patch slot.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The embedding has no batch axis, so the addition relies on
        # broadcasting against the projected patches.
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config