An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale: https://arxiv.org/pdf/2010.11929.pdf


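Self-attention in one line: outputs = sum(values * pairwise_scores(query, keys)), i.e. $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(QK^\top / \sqrt{d_k}\right) V$.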

In [1]:
import tensorflow as tf
from tensorflow.keras import layers

class MultiHeadSelfAttention(layers.Layer):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values of width ``d_model``,
    split into ``num_heads`` heads of width ``d_model // num_heads``, attended
    per head, re-joined and passed through a final linear layer.

    Args:
        num_heads: Number of attention heads.
        d_model: Width of the projections and of the output.

    Raises:
        ValueError: If ``d_model`` is not a multiple of ``num_heads``.
    """

    def __init__(self, num_heads, d_model):
        super().__init__()
        # Without this check the head split/merge reshapes would silently
        # regroup elements instead of failing.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.query_dense = layers.Dense(units=d_model)
        self.key_dense = layers.Dense(units=d_model)
        self.value_dense = layers.Dense(units=d_model)

        # Output projection applied after the heads are concatenated.
        self.dense = layers.Dense(units=d_model)

    def attention(self, query, key, value):
        """Return ``(output, weights)`` of scaled dot-product attention.

        ``weights`` is the softmax over the last axis of
        ``query @ key^T / sqrt(depth)``; ``output`` is ``weights @ value``.
        """
        scores = tf.matmul(query, key, transpose_b=True)
        # Cast the scale to the scores' dtype so non-float32 compute dtypes
        # (e.g. mixed precision) do not hit a dtype mismatch.
        scale = tf.math.sqrt(tf.cast(self.depth, dtype=scores.dtype))
        weights = tf.nn.softmax(scores / scale, axis=-1)

        output = tf.matmul(weights, value)
        return output, weights

    def split_heads(self, inputs, batch_size):
        """Reshape to (batch, seq, heads, depth), then move heads before seq."""
        inputs = tf.reshape(
            inputs, shape=(batch_size, -1, self.num_heads, self.depth)
        )
        return tf.transpose(inputs, perm=[0, 2, 1, 3])

    def call(self, inputs):
        """Apply self-attention to ``inputs`` and return the projected result.

        The attention weights are computed but not returned.
        """
        # Linear projections
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)

        # Split into heads
        batch_size = tf.shape(query)[0]
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

        output, _ = self.attention(query, key, value)

        # Undo split_heads: heads back next to depth, then merge them.
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        output = tf.reshape(output, shape=(batch_size, -1, self.d_model))

        return self.dense(output)


In [2]:
from tensorflow.keras import layers

class TransformerBlock(layers.Layer):
    """One encoder block: self-attention and a feed-forward network.

    Each sub-layer is followed by dropout, a residual connection and layer
    normalisation (normalisation is applied after the residual add).

    Args:
        d_model: Width of the token features.
        num_heads: Number of attention heads.
        dropout: Dropout rate used after both sub-layers.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.att = MultiHeadSelfAttention(num_heads, d_model)
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)

        # Position-wise feed-forward network: expand to 4*d_model, then
        # project back to d_model.
        self.ffn = tf.keras.Sequential([
            layers.Dense(4 * d_model, activation="relu"),
            layers.Dense(d_model),
        ])
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, inputs, training=None):
        """Run the block on ``inputs``.

        Args:
            inputs: Tensor whose last dimension is ``d_model``.
            training: Forwarded to the dropout layers. Defaults to ``None``
                so the block can be called without naming the flag.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        attn_output = self.att(inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

The GELU Paper: https://arxiv.org/pdf/1606.08415v3.pdf

`tf.math.erf` computes the Gauss error function: https://www.tensorflow.org/api_docs/python/tf/math/erf

https://www.wikiwand.com/en/Error_function





In [3]:
def gelu(x):
    """Gaussian Error Linear Unit: ``x * Phi(x)``.

    ``Phi`` is the standard normal cumulative distribution function (CDF),
    evaluated through the error function.

    Args:
        x: Floating-point tensor (or Python float).

    Returns:
        Tensor with the same shape as ``x``.
    """
    # A Python scalar adopts x's dtype. The float32 tensor produced by
    # tf.sqrt(2.0) would raise a dtype mismatch for float16/float64 inputs.
    sqrt_two = 2.0 ** 0.5
    cdf = 0.5 * (1.0 + tf.math.erf(x / sqrt_two))
    return x * cdf

In [4]:
import tensorflow as tf
from tensorflow.keras import Model

class VisionTransformer(Model):
    """Vision Transformer image classifier.

    Images are resized to ``image_size`` x ``image_size``, cut into
    non-overlapping ``patch_size`` x ``patch_size`` patches, linearly projected
    to ``d_model`` features, given a learned position embedding, run through
    ``num_layers`` ``TransformerBlock`` layers, and classified from the
    flattened token sequence.

    Args:
        num_layers: Number of stacked ``TransformerBlock`` layers.
        d_model: Feature width of every patch token.
        num_heads: Attention heads per block.
        image_size: Side length images are resized to before patching.
        patch_size: Side length of each square patch.
        dropout: Dropout rate forwarded to every block.
        num_classes: Number of units in the softmax output layer.

    Raises:
        ValueError: If ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, num_layers, d_model, num_heads, image_size, patch_size,
                 dropout=0.1, num_classes=10):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by "
                f"patch_size ({patch_size})"
            )
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # Flattened patches have width patch_size*patch_size*channels; the
        # encoder blocks need d_model, so project first.
        self.patch_projection = layers.Dense(units=d_model)
        # Learned embedding per patch index; self-attention alone carries no
        # information about where a patch sits in the image.
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches, output_dim=d_model
        )

        self.encoder_layers = [
            TransformerBlock(d_model, num_heads, dropout)
            for _ in range(num_layers)
        ]
        self.layernorm = layers.LayerNormalization(epsilon=1e-6)
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(units=64, activation=gelu)
        self.classifier = layers.Dense(units=num_classes, activation="softmax")

    def extract_patches(self, images):
        """Cut ``images`` into flattened, non-overlapping patches.

        Args:
            images: 4-D image batch whose height and width both equal
                ``image_size`` and whose channel count is statically known.

        Returns:
            Tensor of shape ``(batch, num_patches, patch_dim)``, where
            ``patch_dim`` is the last dimension produced by
            ``tf.image.extract_patches``.
        """
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Keep one row per image: merge only the patch-grid axes, never the
        # batch axis, so the encoder sees a (batch, tokens, features) tensor.
        patch_dim = patches.shape[-1]
        return tf.reshape(patches, shape=(-1, self.num_patches, patch_dim))

    def call(self, x, training=None):
        """Return class probabilities of shape ``(batch, num_classes)``.

        Args:
            x: 4-D image batch; resized to ``image_size`` before patching.
            training: Forwarded to every encoder block.
        """
        x = tf.image.resize(x, size=[self.image_size, self.image_size])
        x = self.extract_patches(x)

        x = self.patch_projection(x)
        positions = tf.range(self.num_patches)
        x = x + self.position_embedding(positions)

        for layer in self.encoder_layers:
            x = layer(x, training=training)

        x = self.layernorm(x)
        x = self.flatten(x)
        x = self.dense(x)
        return self.classifier(x)


In [5]:
# Hyperparameters consumed when the VisionTransformer is constructed.
img_height = 32  # passed to the model as image_size
patch_size = 4
num_layers = 8
d_model = 64
num_heads = 4

In [6]:
# Build the classifier from the hyperparameters defined in the previous cell.
model = VisionTransformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    image_size=img_height,
    patch_size=patch_size,
    dropout=0.1,
)

In [7]:
# The classifier head already ends in a softmax, so the loss receives
# probabilities, not logits; from_logits=True would apply softmax twice.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer=tf.optimizers.Adam(learning_rate=0.01),
    metrics=["accuracy"],
)

In [8]:
# NOTE(review): without parentheses this only evaluates to the bound method
# (hence the `<bound method Model.summary ...>` cell output) and prints no
# layer table. It presumably should be `model.summary()`, called after the
# model has been built or run on a batch — confirm before changing.
model.summary

<bound method Model.summary of <__main__.VisionTransformer object at 0x7feefd5a0700>>