In [343]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk /kaggle/input and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [344]:
import tensorflow as tf
# Show which TensorFlow release this kernel is running.
print(tf.__version__)

In [345]:
# Use the %pip magic so the package is installed into the environment of the
# running kernel; a bare `pip install` only works through IPython's automagic.
%pip install -qq -U tensorflow-addons

In [346]:
import tensorflow_addons as tfa

In [347]:
import numpy as np 
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers

In [348]:
# Hyperparameters
class ConfigureHyperParameters:
    """Central collection of the notebook's tunable settings.

    Every value is a class-level attribute, so it can be read either from the
    class itself or from any instance.
    """

    # --- Data ---
    batch_size = 256
    buffer_size = batch_size * 2  # shuffle buffer holds two batches' worth of samples
    input_shape = (32, 32, 3)  # 32*32 images with 3 channels
    num_classes = 10  # 10 different categories

    # --- Augmentation ---
    image_size = 48  # side length handed to RandomCrop

    # --- Architecture ---
    patch_size = 4
    projected_dim = 96
    num_shift_blocks_per_stages = [2, 4, 8, 2]  # one entry per stage
    epsilon = 1e-5  # LayerNormalization epsilon
    stochastic_depth_rate = 0.2
    mlp_dropout_rate = 0.2
    num_div = 12
    shift_pixel = 1
    mlp_expand_ratio = 2

    # --- Optimizer ---
    lr_start = 1e-5
    lr_max = 1e-3
    weight_decay = 1e-4

    # --- Training ---
    epochs = 100

In [349]:
# Shared hyperparameter object; later cells read their settings from `config`.
config=ConfigureHyperParameters()

In [350]:
# Load CIFAR-10 through Keras and report the size of each split.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print(len(x_train), len(x_test), sep='\n')

In [351]:
# Hold out every sample after the first 40,000 as the validation split.
# This cell rebinds x_train / y_train, so executing it a second time would
# silently leave the validation split empty -- fail loudly instead.
num_train_samples = 40000
assert len(x_train) > num_train_samples, (
    "x_train has already been truncated; re-run the data-loading cell before this one"
)
(x_train, y_train), (x_val, y_val) = (
    (x_train[:num_train_samples], y_train[:num_train_samples]),
    (x_train[num_train_samples:], y_train[num_train_samples:]),
)

In [352]:
# Sizes of the training and validation splits, one per line.
print(len(x_train), len(x_val), sep='\n')

In [353]:
auto = tf.data.AUTOTUNE

# Training split: shuffled, batched and prefetched.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(config.buffer_size)
    .batch(config.batch_size)
    .prefetch(auto)
)

# Validation split: batched and prefetched, no shuffling.
validation_dataset = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .batch(config.batch_size)
    .prefetch(auto)
)

# Test split: batched and prefetched, no shuffling.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(config.batch_size)
    .prefetch(auto)
)

In [354]:
def data_augmentation():
    """Build the preprocessing / augmentation pipeline as a Keras Sequential model.

    Steps, in order: resize to a square 20 pixels larger than
    ``config.input_shape[0]``, random-crop to ``config.image_size``, random
    horizontal flip, and rescale pixel values by 1/255.

    Returns:
        A new ``keras.Sequential`` instance containing the four layers.
    """
    enlarged_side = config.input_shape[0] + 20
    crop_side = config.image_size

    pipeline = [
        layers.Resizing(enlarged_side, enlarged_side),
        layers.RandomCrop(crop_side, crop_side),
        layers.RandomFlip("horizontal"),
        layers.Rescaling(1/255.0),  # 255.0 - max pixel value
    ]
    return keras.Sequential(pipeline)

In [355]:
class MLP(layers.Layer):
    """Two-layer feed-forward network applied inside each shift block.

    Args:
        mlp_expand_ratio: Multiplier applied to the input's last dimension to
            get the width of the hidden Dense layer.
        mlp_dropout_rate: Dropout rate used after each Dense layer.
    """

    def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.mlp_expand_ratio = mlp_expand_ratio
        self.mlp_dropout_rate = mlp_dropout_rate

    def build(self, input_shape):
        # Width of the incoming features (last axis); the block maps back to it.
        output_units = input_shape[-1]
        hidden_units = int(self.mlp_expand_ratio * output_units)

        # Dense(GELU) -> Dropout -> Dense -> Dropout
        stack = []
        stack.append(layers.Dense(units=hidden_units, activation=tf.nn.gelu))
        stack.append(layers.Dropout(rate=self.mlp_dropout_rate))
        stack.append(layers.Dense(units=output_units))
        stack.append(layers.Dropout(rate=self.mlp_dropout_rate))
        self.mlp = keras.Sequential(stack)

    def call(self, x):
        return self.mlp(x)

In [356]:
class DropPath(layers.Layer):
    """Randomly zero the whole input of individual batch samples during training.

    A 0/1 mask with one entry per sample (broadcast over the remaining axes) is
    drawn; samples that are kept are rescaled by ``1 / (1 - drop_path_prob)``.
    Outside training the input is returned untouched.

    Args:
        drop_path_prob: Probability of dropping a sample's input.
    """

    def __init__(self, drop_path_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_prob = drop_path_prob

    def call(self, x, training=False):
        if training:
            survival_prob = 1 - self.drop_path_prob
            # One mask value per sample: shape (batch, 1, 1, ...).
            batch = tf.shape(x)[0]
            rank = len(tf.shape(x))
            mask_shape = (batch,) + (1,) * (rank - 1)
            # floor(survival_prob + U[0, 1)) is 1 with probability survival_prob, else 0.
            keep_mask = tf.floor(survival_prob + tf.random.uniform(mask_shape, 0, 1))
            return (x / survival_prob) * keep_mask
        return x

In [357]:
class shiftViTBlock(layers.Layer):
    """Single ShiftViT block: partial channel shift, then a residual LayerNorm + MLP branch.

    Args:
        epsilon: Epsilon passed to the LayerNormalization layer.
        drop_path_prob: Drop probability for the DropPath applied to the MLP
            branch; when it is not greater than 0 a linear Activation is used.
        mlp_dropout_rate: Dropout rate forwarded to the MLP.
        num_div: Divisor used to derive the number of channel groups
            (``C // num_div``) the input is split into.
        shift_pixel: Number of pixels each shifted group is moved by.
        mlp_expand_ratio: Hidden-width multiplier forwarded to the MLP.
    """

    def __init__(self, epsilon, drop_path_prob, mlp_dropout_rate, num_div=12, shift_pixel=1, mlp_expand_ratio=2, **kwargs):
        super().__init__(**kwargs)
        self.shift_pixel = shift_pixel
        self.mlp_expand_ratio = mlp_expand_ratio
        self.mlp_dropout_rate = mlp_dropout_rate
        self.num_div = num_div
        self.epsilon = epsilon
        self.drop_path_prob = drop_path_prob

    def build(self, input_shape):
        # Input is indexed as (batch, height, width, channels).
        self.H = input_shape[1]
        self.W = input_shape[2]
        self.C = input_shape[3]
        self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
        self.drop_path = (
            DropPath(drop_path_prob=self.drop_path_prob)
            if self.drop_path_prob > 0.0
            else layers.Activation('linear')
        )
        self.mlp = MLP(mlp_expand_ratio=self.mlp_expand_ratio, mlp_dropout_rate=self.mlp_dropout_rate)

    def get_shift_pad(self, x, mode):
        """Move the content of ``x`` by ``shift_pixel`` pixels in direction ``mode``.

        The window that stays inside the frame after the move is cropped out of
        the input and then pasted at its new position; the vacated border is
        filled with zeros by the padding op. The crop offset and the paste
        offset are therefore complementary: cropping and pasting at the same
        offset would leave the content in place and only blank one border.

        Args:
            x: Feature map indexed as (batch, height, width, channels).
            mode: ``'left'``, ``'right'`` or ``'up'``; any other value is
                treated as ``'down'``.

        Returns:
            A tensor with the same height and width as ``x``.
        """
        shift = self.shift_pixel
        # (crop_top, crop_left): where the kept window starts in the input.
        # (pad_top, pad_left): where that window is placed in the output.
        if mode == 'left':
            crop_top, crop_left, pad_top, pad_left = 0, shift, 0, 0
        elif mode == 'right':
            crop_top, crop_left, pad_top, pad_left = 0, 0, 0, shift
        elif mode == 'up':
            crop_top, crop_left, pad_top, pad_left = shift, 0, 0, 0
        else:
            crop_top, crop_left, pad_top, pad_left = 0, 0, shift, 0

        # The kept window is `shift` pixels smaller along the axis being moved.
        kept_height = self.H - (crop_top + pad_top)
        kept_width = self.W - (crop_left + pad_left)

        crop = tf.image.crop_to_bounding_box(
            x,
            offset_height=crop_top,
            offset_width=crop_left,
            target_height=kept_height,
            target_width=kept_width,
        )
        shift_pad = tf.image.pad_to_bounding_box(
            crop,
            offset_height=pad_top,
            offset_width=pad_left,
            target_height=self.H,
            target_width=self.W,
        )
        return shift_pad

    def call(self, x, training=False):
        # An integer `num_or_size_splits` is a *count*: the channel axis is cut
        # into C // num_div equal groups. At least four groups are needed below.
        x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)

        # Shift the first four groups, one per direction; the rest pass through.
        x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
        x_splits[1] = self.get_shift_pad(x_splits[1], mode='right')
        x_splits[2] = self.get_shift_pad(x_splits[2], mode='up')
        x_splits[3] = self.get_shift_pad(x_splits[3], mode='down')

        x = tf.concat(x_splits, axis=-1)

        # Residual connection around LayerNorm -> MLP -> DropPath.
        shortcut = x
        x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)
        return x

In [358]:
class PatchMerging(layers.Layer):
    """Normalize, then downsample with a stride-2 convolution that doubles the channels.

    Args:
        epsilon: Epsilon passed to the LayerNormalization layer.
    """

    def __init__(self, epsilon, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Output has twice as many channels as the input's last dimension.
        doubled_channels = 2 * input_shape[-1]
        self.reduction = layers.Conv2D(
            filters=doubled_channels,
            kernel_size=2,
            strides=2,
            padding="same",
            use_bias=False,
        )
        self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)

    def call(self, x):
        normalized = self.layer_norm(x)
        return self.reduction(normalized)

In [359]:
class StackedShiftBlocks(layers.Layer):
    """One stage of the network: a run of shift blocks, optionally followed by patch merging.

    Args:
        epsilon: LayerNormalization epsilon forwarded to the sub-layers.
        mlp_dropout_rate: Dropout rate forwarded to every shift block.
        num_shift_blocks: Number of shift blocks in this stage.
        stochastic_depth_rate: Largest drop-path probability in the stage; the
            per-block values are spaced linearly from 0 up to this value.
        is_merge: Whether a PatchMerging layer is applied after the blocks.
        num_div: Forwarded to every shift block.
        shift_pixel: Forwarded to every shift block.
        mlp_expand_ratio: Forwarded to every shift block.
    """

    def __init__(self, epsilon, mlp_dropout_rate, num_shift_blocks, stochastic_depth_rate, is_merge, num_div=12, shift_pixel=1, mlp_expand_ratio=2, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.mlp_dropout_rate = mlp_dropout_rate
        self.num_shift_blocks = num_shift_blocks
        self.stochastic_depth_rate = stochastic_depth_rate
        self.is_merge = is_merge
        self.num_div = num_div
        self.shift_pixel = shift_pixel
        self.mlp_expand_ratio = mlp_expand_ratio

    def build(self, input_shapes):
        # One drop-path probability per block, linearly spaced over the stage.
        drop_rates = np.linspace(
            start=0,
            stop=self.stochastic_depth_rate,
            num=self.num_shift_blocks,
        )
        self.shift_blocks = [
            shiftViTBlock(
                num_div=self.num_div,
                epsilon=self.epsilon,
                drop_path_prob=rate,
                mlp_dropout_rate=self.mlp_dropout_rate,
                shift_pixel=self.shift_pixel,
                mlp_expand_ratio=self.mlp_expand_ratio,
            )
            for rate in drop_rates
        ]
        if self.is_merge:
            self.patch_merge = PatchMerging(epsilon=self.epsilon)

    def call(self, x, training=False):
        for block in self.shift_blocks:
            x = block(x, training=training)
        return self.patch_merge(x) if self.is_merge else x
            

In [360]:
class ShiftViTModel(keras.Model):
    """ShiftViT classifier with custom train / test steps.

    Pipeline: augmentation -> patch projection (strided Conv2D) -> stages of
    shift blocks -> global average pooling -> Dense classification head.

    Args:
        data_augmentation: Layer/model applied to the raw images first.
        projected_dim: Number of filters of the patch-projection convolution.
        patch_size: Kernel size and stride of the patch-projection convolution.
        num_shift_blocks_per_stages: Number of shift blocks for each stage; every
            stage except the last one ends with patch merging.
        epsilon: LayerNormalization epsilon forwarded to the stages.
        mlp_dropout_rate: Dropout rate forwarded to the stages.
        stochastic_depth_rate: Drop-path rate forwarded to the stages.
        num_div: Forwarded to the stages.
        shift_pixel: Forwarded to the stages.
        mlp_expand_ratio: Forwarded to the stages.
        num_classes: Number of output logits. Defaults to
            ``config.num_classes`` from the notebook's hyperparameter object.
    """

    def __init__(self, data_augmentation, projected_dim, patch_size, num_shift_blocks_per_stages, epsilon, mlp_dropout_rate, stochastic_depth_rate, num_div=12, shift_pixel=1, mlp_expand_ratio=2, num_classes=None, **kwargs):
        super().__init__(**kwargs)
        if num_classes is None:
            # Fall back to the notebook-level hyperparameters.
            num_classes = config.num_classes

        self.data_augmentation = data_augmentation
        self.patch_projection = layers.Conv2D(filters=projected_dim, kernel_size=patch_size, strides=patch_size, padding='same')

        self.stages = list()
        last_index = len(num_shift_blocks_per_stages) - 1
        for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
            # Every stage but the last one downsamples with patch merging.
            is_merge = index != last_index
            self.stages.append(
                StackedShiftBlocks(
                    epsilon=epsilon,
                    mlp_dropout_rate=mlp_dropout_rate,
                    num_shift_blocks=num_shift_blocks,
                    stochastic_depth_rate=stochastic_depth_rate,
                    is_merge=is_merge,
                    num_div=num_div,
                    shift_pixel=shift_pixel,
                    mlp_expand_ratio=mlp_expand_ratio,
                )
            )
        self.global_avg_pool = layers.GlobalAveragePooling2D()
        # Classification head: without it the pooled feature vector itself
        # would be treated as the logits, regardless of the number of classes.
        self.classifier = layers.Dense(num_classes)

    def get_config(self):
        # NOTE(review): the values stored here are layer objects, not plain
        # constructor arguments, so this config is unlikely to round-trip
        # through JSON serialization -- confirm before relying on model saving.
        base_config = super().get_config()
        base_config.update(
            {
                'data_augmentation': self.data_augmentation,
                'patch_projection': self.patch_projection,
                'stages': self.stages,
                'global_avg_pool': self.global_avg_pool,
                'classifier': self.classifier,
            }
        )
        return base_config

    def call(self, images, training=False):
        """Map a batch of raw images to class logits."""
        x = self.data_augmentation(images, training=training)
        x = self.patch_projection(x)
        for stage in self.stages:
            x = stage(x, training=training)
        x = self.global_avg_pool(x)
        return self.classifier(x)

    def _calculate_loss(self, data, training=False):
        """Run the forward pass on ``(images, labels)`` and evaluate the compiled loss.

        Returns:
            Tuple ``(total_loss, labels, logits)``.
        """
        (images, labels) = data
        logits = self(images, training=training)
        total_loss = self.compiled_loss(labels, logits)
        return total_loss, labels, logits

    def train_step(self, inputs):
        with tf.GradientTape() as tape:
            total_loss, labels, logits = self._calculate_loss(data=inputs, training=True)

        # All sub-layers are attributes of the model, so Keras already tracks
        # their weights (including the classification head).
        train_vars = self.trainable_variables
        grads = tape.gradient(total_loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))

        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        _, labels, logits = self._calculate_loss(data=data, training=False)
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}
            

In [361]:
# Instantiate the ShiftViT model from the shared hyperparameters.
model = ShiftViTModel(
    data_augmentation=data_augmentation(),
    # patch embedding
    patch_size=config.patch_size,
    projected_dim=config.projected_dim,
    # stage layout
    num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
    num_div=config.num_div,
    shift_pixel=config.shift_pixel,
    mlp_expand_ratio=config.mlp_expand_ratio,
    # regularisation / normalisation
    epsilon=config.epsilon,
    mlp_dropout_rate=config.mlp_dropout_rate,
    stochastic_depth_rate=config.stochastic_depth_rate,
)

In [362]:
# learning rate scheduler

class LearningRateScheduler(keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up from ``lr_start`` to ``lr_max`` followed by cosine decay.

    For ``step < warmup_steps`` the rate rises linearly; afterwards it follows
    ``0.5 * lr_max * (1 + cos(pi * progress))`` and is forced to 0 once
    ``step > total_steps``.

    Args:
        lr_start: Learning rate at step 0 (only used when ``warmup_steps > 0``).
        lr_max: Peak learning rate reached at the end of warm-up.
        warmup_steps: Number of warm-up steps.
        total_steps: Total number of training steps.

    Raises:
        ValueError: If ``total_steps < warmup_steps``, or if warm-up is enabled
            and ``lr_max < lr_start``.
    """

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        super().__init__()
        # Validate once at construction time instead of on every call.
        if total_steps < warmup_steps:
            raise ValueError(
                f"total_steps ({total_steps}) must be greater than or equal to "
                f"warmup_steps ({warmup_steps})"
            )
        if warmup_steps > 0 and lr_max < lr_start:
            raise ValueError(
                f"lr_start ({lr_start}) must not be greater than lr_max ({lr_max})"
            )
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        step_as_float = tf.cast(step, tf.float32)

        # Clamp to 1 so that warmup_steps == total_steps does not divide by zero.
        decay_steps = max(self.total_steps - self.warmup_steps, 1)
        cos_lr = tf.cos(
            self.pi
            * (step_as_float - self.warmup_steps)
            / tf.cast(decay_steps, tf.float32)
        )
        learning_rate = 0.5 * self.lr_max * (1 + cos_lr)

        if self.warmup_steps > 0:
            slope = (self.lr_max - self.lr_start) / self.warmup_steps
            warmup_rate = slope * step_as_float + self.lr_start
            learning_rate = tf.where(step < self.warmup_steps, warmup_rate, learning_rate)
        return tf.where(step > self.total_steps, 0.0, learning_rate, name='learning_rate')

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {
            'lr_start': self.lr_start,
            'lr_max': self.lr_max,
            'warmup_steps': self.warmup_steps,
            'total_steps': self.total_steps,
        }

In [363]:
# Build the learning-rate schedule and optimizer, then compile the model.

# The last batch of an epoch may be partial, so round the per-epoch step count
# up; truncating would make the schedule end (and the rate drop to 0) before
# training does.
steps_per_epoch = int(np.ceil(len(x_train) / config.batch_size))
total_steps = steps_per_epoch * config.epochs

warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)

# Take the learning-rate bounds from the shared hyperparameters rather than
# repeating the literals here.
scheduled_lrs = LearningRateScheduler(
    lr_start=config.lr_start,
    lr_max=config.lr_max,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
)

optimizer = tfa.optimizers.AdamW(learning_rate=scheduled_lrs, weight_decay=config.weight_decay)

model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name='top-5-accuracy'),
    ],
)

In [None]:
# Train, stopping early when validation accuracy has not improved for 5 epochs.
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=config.epochs,
    callbacks=[
        keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, mode='auto'),
    ],
)