<a href="https://colab.research.google.com/github/maktaurus/ML-Work/blob/main/Vision_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer

In [None]:
import tensorflow as tf
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.utils import load_img, img_to_array, array_to_img
from tensorboard.plugins.hparams import api as hp
from tensorflow.keras import layers
from tensorflow.keras.metrics import Accuracy,TruePositives,TrueNegatives,FalsePositives,FalseNegatives,Recall,Precision,AUC
import zipfile
import tensorflow_hub as hub

from tensorflow.keras import layers

In [None]:
# Input preprocessing pipeline (deterministic, despite the "augment" name):
#   1. resize every image to 224x224
#   2. rescale pixel values as x * (1/127.5) - 1, i.e. [0, 255] -> [-1, 1]
augment_layer = tf.keras.Sequential()
augment_layer.add(layers.Resizing(height=224,width=224))
augment_layer.add(layers.Rescaling(scale=1./127.5,offset=-1))

In [None]:
# Hyperparameters
img_size = 224      # side length of the square input image, in pixels
patch_size = 16     # side length of each square patch, in pixels
# Patches per image: (224 // 16) per side, squared = 14 * 14 = 196.
num_patches = (img_size // patch_size) * (img_size // patch_size)
# Width of every patch embedding.
# NOTE(review): 786 looks like a transposition of 768 (= 16 * 16 * 3, the
# flattened patch length) - confirm before changing; the model runs either way.
projection_dims = 786
num_heads = 4       # attention heads per MultiHeadAttention layer
# Widths of the two Dense layers in each Transformer block's feed-forward part;
# the last one equals projection_dims so the residual addition lines up.
transformer_units = [2 * projection_dims, projection_dims]
transformer_layer = 8   # number of stacked Transformer blocks
mlp_units = [2048, 1024]   # Dense widths of the classification head
batch_size = 1      # batch size handed to the Patches layer

In [None]:
def mlp(x,mlp_units,dropout):
  """Run `x` through a stack of Dense(GELU) -> Dropout layer pairs.

  Args:
    x: input tensor.
    mlp_units: iterable of Dense widths; one Dense + Dropout pair is created
      per entry, in order.
    dropout: rate passed to every Dropout layer.

  Returns:
    The output of the last Dropout layer, or `x` itself when `mlp_units` is
    empty.

  Note:
    Fresh layers are created on every call, so separate calls never share
    weights.
  """
  hidden = x
  for width in mlp_units:
    dense = layers.Dense(width,activation=tf.nn.gelu)
    hidden = layers.Dropout(dropout)(dense(hidden))
  return hidden

In [None]:
class Patches(layers.Layer):
  """Split a batch of images into flattened, non-overlapping square patches.

  Args:
    path_size: side length, in pixels, of each square patch. The name is a
      historical misspelling of "patch_size"; it is kept so existing callers
      keep working.
    batch_size: accepted and stored for backward compatibility only. The batch
      dimension is read from the input at call time, so the layer works for
      any batch size.
    **kwargs: forwarded to `layers.Layer` (e.g. `name`).
  """
  def __init__(self,path_size,batch_size=None,**kwargs):
    super().__init__(**kwargs)
    # Use the constructor argument. The previous code assigned the notebook
    # global `patch_size` here and silently ignored what the caller passed.
    self.patch_size = path_size
    self.batch_size = batch_size

  def call(self,x):
    """Extract patches from `x`.

    Args:
      x: 4-D image batch (batch, height, width, channels), as required by
        `tf.image.extract_patches`.

    Returns:
      Tensor of shape (batch, number of patches, values per patch), where each
      patch is flattened into the last axis.
    """
    # stride == size gives non-overlapping patches; "VALID" drops any partial
    # patch at the right/bottom edge.
    patches = tf.image.extract_patches(x,
                                       sizes = [1,self.patch_size,self.patch_size,1],
                                       strides = [1,self.patch_size,self.patch_size,1],
                                       rates = [1,1,1,1],
                                       padding = "VALID")
    # Merge the patch grid (rows, cols) into one sequence axis. The batch size
    # is taken from the tensor itself: reshaping with a fixed batch size would
    # fold several images into a single sequence whenever the real batch
    # differs from it.
    patches = tf.reshape(patches, [tf.shape(patches)[0],-1,patches.shape[-1]])
    return patches

In [None]:
# Smoke test for Patches: load one image at 224x224, add a leading batch axis,
# split it into patches and display the resulting shape
# (batch, number of patches, values per patch).
# NOTE(review): the absolute path presumably points at a file uploaded to the
# author's Colab session - confirm it exists before running.
img = load_img("/content/bb.jpg",target_size=(224,224))
pp = Patches(patch_size,batch_size)(tf.expand_dims(img,axis=0))
pp.shape

TensorShape([1, 196, 768])

In [None]:
class PatchEncoder(layers.Layer):
  """Project flattened patches and add a learned position embedding.

  Args:
    projection_dims: output width of the patch projection and of the position
      embedding.
    num_patches: number of patch positions that get an embedding.
  """
  def __init__(self,projection_dims,num_patches):
    super().__init__()
    self.num_patches = num_patches
    # Linear map from the flattened patch to the embedding width.
    self.projection = layers.Dense(units=projection_dims)
    # One learned vector per patch index.
    self.pos_emb = layers.Embedding(input_dim=num_patches,output_dim=projection_dims)

  def call(self,x):
    """Return the projected patches plus their position embeddings."""
    # Indices 0 .. num_patches-1; their embeddings are added to the projected
    # patches by broadcasting.
    patch_ids = tf.range(self.num_patches)
    return self.projection(x) + self.pos_emb(patch_ids)

In [None]:
# Smoke test for PatchEncoder: encode the patches `pp` produced by the Patches
# check; the bare `pe` on the last line makes the notebook display the result.
pe = PatchEncoder(projection_dims,num_patches)(pp)
pe

In [None]:
def vit_classifier(num_classes=120):
    """Build the Vision Transformer classification model.

    Pipeline: resize/rescale -> patch extraction -> patch projection + position
    embedding -> `transformer_layer` Transformer blocks (LayerNorm before each
    sub-block, residual connection after it) -> flatten -> MLP head -> softmax.

    Reads the notebook-level settings `patch_size`, `batch_size`,
    `projection_dims`, `num_patches`, `num_heads`, `transformer_units`,
    `transformer_layer` and `mlp_units`, and the `augment_layer` pipeline.

    Args:
      num_classes: number of output classes. Defaults to 120, the value that
        was previously hard-coded.

    Returns:
      An uncompiled `tf.keras.Model` that maps (224, 224, 3) images to
      `num_classes` softmax probabilities.
    """
    # `inputs` rather than `input`, to avoid shadowing the Python builtin.
    inputs = layers.Input(shape=(224,224,3))
    augment = augment_layer(inputs)
    patches = Patches(patch_size,batch_size)(augment)
    encoder = PatchEncoder(projection_dims,num_patches)(patches)

    for _ in range(transformer_layer):
      # Self-attention sub-block: normalise, attend, add the residual.
      x1 = layers.LayerNormalization(epsilon=1e-6)(encoder)
      atten = layers.MultiHeadAttention(num_heads=num_heads,key_dim=projection_dims)(x1,x1)

      x2 = layers.Add()([atten,encoder])

      # Feed-forward sub-block: normalise, MLP, add the residual.
      x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
      x3 = mlp(x3,transformer_units,dropout=0.1)

      encoder = layers.Add()([x3,x2])

    # Classification head: flatten every patch embedding into one vector.
    repre = layers.LayerNormalization(epsilon=1e-6)(encoder)
    repre = layers.Flatten()(repre)
    repre = layers.Dropout(0.5)(repre)

    repre = mlp(repre,mlp_units,dropout=0.1)

    # Softmax output: these are class probabilities, not raw logits, so train
    # with a loss that expects probabilities (from_logits=False).
    probabilities = layers.Dense(num_classes,activation="softmax")(repre)

    model = tf.keras.Model(inputs,probabilities)

    return model


In [None]:
# Build the model with its default settings and print the layer-by-layer summary.
vit_model = vit_classifier()
vit_model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_4 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 sequential (Sequential)     (None, 224, 224, 3)          0         ['input_4[0][0]']             
                                                                                                  
 patches_4 (Patches)         (1, None, 768)               0         ['sequential[3][0]']          
                                                                                                  
 patch_encoder_4 (PatchEnco  (1, 196, 786)                758490    ['patches_4[0][0]']           
 der)                                                                                       