diff --git a/examples/vision/ipynb/shiftvit.ipynb b/examples/vision/ipynb/shiftvit.ipynb
new file mode 100644
index 0000000000..c4e0180dcb
--- /dev/null
+++ b/examples/vision/ipynb/shiftvit.ipynb
@@ -0,0 +1,1101 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# A Vision Transformer without Attention\n",
+ "\n",
+ "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
\n",
+ "**Date created:** 2022/02/24
\n",
+ "**Last modified:** 2022/03/01
\n",
+ "**Description:** A minimal implementation of ShiftViT."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "[Vision Transformers](https://arxiv.org/abs/2010.11929) (ViTs) have sparked a wave of\n",
+ "research at the intersection of Transformers and Computer Vision (CV).\n",
+ "\n",
+ "ViTs can simultaneously model long- and short-range dependencies, thanks to\n",
+ "the Multi-Head Self-Attention mechanism in the Transformer block. Many researchers believe\n",
+ "that the success of ViTs are purely due to the attention layer, and they seldom\n",
+ "think about other parts of the ViT model.\n",
+ "\n",
+ "In the academic paper\n",
+ "[When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism](https://arxiv.org/abs/2201.10801)\n",
+ "the authors propose to demystify the success of ViTs with the introduction of a **NO\n",
+ "PARAMETER** operation in place of the attention operation. They swap the attention\n",
+ "operation with a shifting operation.\n",
+ "\n",
+ "In this example, we minimally implement the paper with close alignement to the author's\n",
+ "[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).\n",
+ "\n",
+ "This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can\n",
+ "be installed using the following command:\n",
+ "\n",
+ "```shell\n",
+ "pip install -qq -U tensorflow-addons\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Setup and imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "from tensorflow import keras\n",
+ "from tensorflow.keras import layers\n",
+ "\n",
+ "import tensorflow_addons as tfa\n",
+ "\n",
+ "# Setting seed for reproducibiltiy\n",
+ "SEED = 42\n",
+ "keras.utils.set_random_seed(SEED)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Hyperparameters\n",
+ "\n",
+ "These are the hyperparameters that we have chosen for the experiment.\n",
+ "Please feel free to tune them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class Config(object):\n",
+ " # DATA\n",
+ " batch_size = 256\n",
+ " buffer_size = batch_size * 2\n",
+ " input_shape = (32, 32, 3)\n",
+ " num_classes = 10\n",
+ "\n",
+ " # AUGMENTATION\n",
+ " image_size = 48\n",
+ "\n",
+ " # ARCHITECTURE\n",
+ " patch_size = 4\n",
+ " projected_dim = 96\n",
+ " num_shift_blocks_per_stages = [2, 4, 8, 2]\n",
+ " epsilon = 1e-5\n",
+ " stochastic_depth_rate = 0.2\n",
+ " mlp_dropout_rate = 0.2\n",
+ " num_div = 12\n",
+ " shift_pixel = 1\n",
+ " mlp_expand_ratio = 2\n",
+ "\n",
+ " # OPTIMIZER\n",
+ " lr_start = 1e-5\n",
+ " lr_max = 1e-3\n",
+ " weight_decay = 1e-4\n",
+ "\n",
+ " # TRAINING\n",
+ " epochs = 100\n",
+ "\n",
+ "\n",
+ "config = Config()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Load the CIFAR-10 dataset\n",
+ "\n",
+ "We use the CIFAR-10 dataset for our experiments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n",
+ "(x_train, y_train), (x_val, y_val) = (\n",
+ " (x_train[:40000], y_train[:40000]),\n",
+ " (x_train[40000:], y_train[40000:]),\n",
+ ")\n",
+ "print(f\"Training samples: {len(x_train)}\")\n",
+ "print(f\"Validation samples: {len(x_val)}\")\n",
+ "print(f\"Testing samples: {len(x_test)}\")\n",
+ "\n",
+ "AUTO = tf.data.AUTOTUNE\n",
+ "train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
+ "train_ds = train_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO)\n",
+ "\n",
+ "val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))\n",
+ "val_ds = val_ds.batch(config.batch_size).prefetch(AUTO)\n",
+ "\n",
+ "test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
+ "test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Data Augmentation\n",
+ "\n",
+ "The augmentation pipeline consists of:\n",
+ "\n",
+ "- Rescaling\n",
+ "- Resizing\n",
+ "- Random cropping\n",
+ "- Random horizontal flipping\n",
+ "\n",
+ "_Note_: The image data augmentation layers do not apply\n",
+ "data transformations at inference time. This means that\n",
+ "when these layers are called with `training=False` they\n",
+ "behave differently. Refer to the\n",
+ "[documentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/)\n",
+ "for more details."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def get_augmentation_model():\n",
+ " \"\"\"Build the data augmentation model.\"\"\"\n",
+ " data_augmentation = keras.Sequential(\n",
+ " [\n",
+ " layers.Resizing(config.input_shape[0] + 20, config.input_shape[0] + 20),\n",
+ " layers.RandomCrop(config.image_size, config.image_size),\n",
+ " layers.RandomFlip(\"horizontal\"),\n",
+ " layers.Rescaling(1 / 255.0),\n",
+ " ]\n",
+ " )\n",
+ " return data_augmentation\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## The ShiftViT architecture\n",
+ "\n",
+ "In this section, we build the architecture proposed in\n",
+ "[the ShiftViT paper](https://arxiv.org/abs/2201.10801).\n",
+ "\n",
+ "|  |\n",
+ "| :--: |\n",
+ "| Figure 1: The entire architecutre of ShiftViT.\n",
+ "[Source](https://arxiv.org/abs/2201.10801) |\n",
+ "\n",
+ "The architecture as shown in Fig. 1, is inspired by\n",
+ "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030).\n",
+ "Here the authors propose a modular architecture with 4 stages. Each stage works on its\n",
+ "own spatial size, creating a hierarchical architecture.\n",
+ "\n",
+ "An input image of size `HxWx3` is split into non-overlapping patches of size `4x4`.\n",
+ "This is done via the patchify layer which results in individual tokens of feature size `48`\n",
+ "(`4x4x3`). Each stage comprises two parts:\n",
+ "\n",
+ "1. Embedding Generation\n",
+ "2. Stacked Shift Blocks\n",
+ "\n",
+ "We discuss the stages and the modules in detail in what follows.\n",
+ "\n",
+ "_Note_: Compared to the [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py)\n",
+ "we restructure some key components to better fit the Keras API."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### The ShiftViT Block\n",
+ "\n",
+ "|  |\n",
+ "| :--: |\n",
+ "| Figure 2: From the Model to a Shift Block. |\n",
+ "\n",
+ "Each stage in the ShiftViT architecture comprises of a Shift Block as shown in Fig 2.\n",
+ "\n",
+ "|  |\n",
+ "| :--: |\n",
+ "| Figure 3: The Shift ViT Block. [Source](https://arxiv.org/abs/2201.10801) |\n",
+ "\n",
+ "The Shift Block as shown in Fig. 3, comprises of the following:\n",
+ "\n",
+ "1. Shift Operation\n",
+ "2. Linear Normalization\n",
+ "3. MLP Layer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "#### The MLP block\n",
+ "\n",
+ "The MLP block is intended to be a stack of densely-connected layers.s"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class MLP(layers.Layer):\n",
+ " \"\"\"Get the MLP layer for each shift block.\n",
+ "\n",
+ " Args:\n",
+ " mlp_expand_ratio (int): The ratio with which the first feature map is expanded.\n",
+ " mlp_dropout_rate (float): The rate for dropout.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs):\n",
+ " super().__init__(**kwargs)\n",
+ " self.mlp_expand_ratio = mlp_expand_ratio\n",
+ " self.mlp_dropout_rate = mlp_dropout_rate\n",
+ "\n",
+ " def build(self, input_shape):\n",
+ " input_channels = input_shape[-1]\n",
+ " initial_filters = int(self.mlp_expand_ratio * input_channels)\n",
+ "\n",
+ " self.mlp = keras.Sequential(\n",
+ " [\n",
+ " layers.Dense(units=initial_filters, activation=tf.nn.gelu,),\n",
+ " layers.Dropout(rate=self.mlp_dropout_rate),\n",
+ " layers.Dense(units=input_channels),\n",
+ " layers.Dropout(rate=self.mlp_dropout_rate),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " def call(self, x):\n",
+ " x = self.mlp(x)\n",
+ " return x\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "#### The DropPath layer\n",
+ "\n",
+ "Stochastic depth is a regularization technique that randomly drops a set of\n",
+ "layers. During inference, the layers are kept as they are. It is very\n",
+ "similar to Dropout, but it operates on a block of layers rather\n",
+ "than on individual nodes present inside a layer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class DropPath(layers.Layer):\n",
+ " \"\"\"Drop Path also known as the Stochastic Depth layer.\n",
+ "\n",
+ " Refernece:\n",
+ " - https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization\n",
+ " - github.com:rwightman/pytorch-image-models\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, drop_path_prob, **kwargs):\n",
+ " super().__init__(**kwargs)\n",
+ " self.drop_path_prob = drop_path_prob\n",
+ "\n",
+ " def call(self, x, training=False):\n",
+ " if training:\n",
+ " keep_prob = 1 - self.drop_path_prob\n",
+ " shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)\n",
+ " random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)\n",
+ " random_tensor = tf.floor(random_tensor)\n",
+ " return (x / keep_prob) * random_tensor\n",
+ " return x\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "#### Block\n",
+ "\n",
+ "The most important operation in this paper is the **shift opperation**. In this section,\n",
+ "we describe the shift operation and compare it with its original implementation provided\n",
+ "by the authors.\n",
+ "\n",
+ "A generic feature map is assumed to have the shape `[N, H, W, C]`. Here we choose a\n",
+ "`num_div` parameter that decides the division size of the channels. The first 4 divisions\n",
+ "are shifted (1 pixel) in the left, right, up, and down direction. The remaining splits\n",
+ "are kept as is. After partial shifting the shifted channels are padded and the overflown\n",
+ "pixels are chopped off. This completes the partial shifting operation.\n",
+ "\n",
+ "In the original implementation, the code is approximately:\n",
+ "\n",
+ "```python\n",
+ "out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:] # shift left\n",
+ "out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1] # shift right\n",
+ "out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :] # shift up\n",
+ "out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :] # shift down\n",
+ "\n",
+ "out[:, g * 4:, :, :] = x[:, g * 4:, :, :] # no shift\n",
+ "```\n",
+ "\n",
+ "In TensorFlow it would be infeasible for us to assign shifted channels to a tensor in the\n",
+ "middle of the training process. This is why we have resorted to the following procedure:\n",
+ "\n",
+ "1. Split the channels with the `num_div` parameter.\n",
+ "2. Select each of the first four spilts and shift and pad them in the respective\n",
+ "directions.\n",
+ "3. After shifting and padding, we concatenate the channel back.\n",
+ "\n",
+ "|  |\n",
+ "| :--: |\n",
+ "| Figure 4: The TensorFlow style shifting |\n",
+ "\n",
+ "The entire procedure is explained in the Fig. 4."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class ShiftViTBlock(layers.Layer):\n",
+ " \"\"\"A unit ShiftViT Block\n",
+ "\n",
+ " Args:\n",
+ " shift_pixel (int): The number of pixels to shift. Default to 1.\n",
+ " mlp_expand_ratio (int): The ratio with which MLP features are\n",
+ " expanded. Default to 2.\n",
+ " mlp_dropout_rate (float): The dropout rate used in MLP.\n",
+ " num_div (int): The number of divisions of the feature map's channel.\n",
+ " Totally, 4/num_div of channels will be shifted. Defaults to 12.\n",
+ " epsilon (float): Epsilon constant.\n",
+ " drop_path_prob (float): The drop probability for drop path.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " epsilon,\n",
+ " drop_path_prob,\n",
+ " mlp_dropout_rate,\n",
+ " num_div=12,\n",
+ " shift_pixel=1,\n",
+ " mlp_expand_ratio=2,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " super().__init__(**kwargs)\n",
+ " self.shift_pixel = shift_pixel\n",
+ " self.mlp_expand_ratio = mlp_expand_ratio\n",
+ " self.mlp_dropout_rate = mlp_dropout_rate\n",
+ " self.num_div = num_div\n",
+ " self.epsilon = epsilon\n",
+ " self.drop_path_prob = drop_path_prob\n",
+ "\n",
+ " def build(self, input_shape):\n",
+ " self.H = input_shape[1]\n",
+ " self.W = input_shape[2]\n",
+ " self.C = input_shape[3]\n",
+ " self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)\n",
+ " self.drop_path = (\n",
+ " DropPath(drop_path_prob=self.drop_path_prob)\n",
+ " if self.drop_path_prob > 0.0\n",
+ " else layers.Activation(\"linear\")\n",
+ " )\n",
+ " self.mlp = MLP(\n",
+ " mlp_expand_ratio=self.mlp_expand_ratio,\n",
+ " mlp_dropout_rate=self.mlp_dropout_rate,\n",
+ " )\n",
+ "\n",
+ " def get_shift_pad(self, x, mode):\n",
+ " \"\"\"Shifts the channels according to the mode chosen.\"\"\"\n",
+ " if mode == \"left\":\n",
+ " offset_height = 0\n",
+ " offset_width = 0\n",
+ " target_height = 0\n",
+ " target_width = self.shift_pixel\n",
+ " elif mode == \"right\":\n",
+ " offset_height = 0\n",
+ " offset_width = self.shift_pixel\n",
+ " target_height = 0\n",
+ " target_width = self.shift_pixel\n",
+ " elif mode == \"up\":\n",
+ " offset_height = 0\n",
+ " offset_width = 0\n",
+ " target_height = self.shift_pixel\n",
+ " target_width = 0\n",
+ " else:\n",
+ " offset_height = self.shift_pixel\n",
+ " offset_width = 0\n",
+ " target_height = self.shift_pixel\n",
+ " target_width = 0\n",
+ " crop = tf.image.crop_to_bounding_box(\n",
+ " x,\n",
+ " offset_height=offset_height,\n",
+ " offset_width=offset_width,\n",
+ " target_height=self.H - target_height,\n",
+ " target_width=self.W - target_width,\n",
+ " )\n",
+ " shift_pad = tf.image.pad_to_bounding_box(\n",
+ " crop,\n",
+ " offset_height=offset_height,\n",
+ " offset_width=offset_width,\n",
+ " target_height=self.H,\n",
+ " target_width=self.W,\n",
+ " )\n",
+ " return shift_pad\n",
+ "\n",
+ " def call(self, x, training=False):\n",
+ " # Split the feature maps\n",
+ " x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)\n",
+ "\n",
+ " # Shift the feature maps\n",
+ " x_splits[0] = self.get_shift_pad(x_splits[0], mode=\"left\")\n",
+ " x_splits[1] = self.get_shift_pad(x_splits[1], mode=\"right\")\n",
+ " x_splits[2] = self.get_shift_pad(x_splits[2], mode=\"up\")\n",
+ " x_splits[3] = self.get_shift_pad(x_splits[3], mode=\"down\")\n",
+ "\n",
+ " # Concatenate the shifted and unshifted feature maps\n",
+ " x = tf.concat(x_splits, axis=-1)\n",
+ "\n",
+ " # Add the residual connection\n",
+ " shortcut = x\n",
+ " x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)\n",
+ " return x\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### The ShiftViT blocks\n",
+ "\n",
+ "|  |\n",
+ "| :--: |\n",
+ "| Figure 5: Shift Blocks in the architecture. [Source](https://arxiv.org/abs/2201.10801) |\n",
+ "\n",
+ "Each stage of the architecture has shift blocks as shown in Fig.5. Each of these blocks\n",
+ "contain a variable number of stacked ShiftViT block (as built in the earlier section).\n",
+ "\n",
+ "Shift blocks are followed by a PatchMerging layer that scales down feature inputs. The\n",
+ "PatchMerging layer helps in the pyramidal structure of the model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "#### The PatchMerging layer\n",
+ "\n",
+ "This layer merges the two adjacent tokens. This layer helps in scaling the features down\n",
+ "spatially and increasing the features up channel wise. We use a Conv2D layer to merge the\n",
+ "patches."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class PatchMerging(layers.Layer):\n",
+ " \"\"\"The Patch Merging layer.\n",
+ "\n",
+ " Args:\n",
+ " epsilon (float): The epsilon constant.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, epsilon, **kwargs):\n",
+ " super().__init__(**kwargs)\n",
+ " self.epsilon = epsilon\n",
+ "\n",
+ " def build(self, input_shape):\n",
+ " filters = 2 * input_shape[-1]\n",
+ " self.reduction = layers.Conv2D(\n",
+ " filters=filters, kernel_size=2, strides=2, padding=\"same\", use_bias=False\n",
+ " )\n",
+ " self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)\n",
+ "\n",
+ " def call(self, x):\n",
+ " # Apply the patch merging algorithm on the feature maps\n",
+ " x = self.layer_norm(x)\n",
+ " x = self.reduction(x)\n",
+ " return x\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "#### Stacked Shift Blocks\n",
+ "\n",
+ "Each stage will have a variable number of stacked ShiftViT Blocks, as suggested in\n",
+ "the paper. This is a generic layer that will contain the stacked shift vit blocks\n",
+ "with the patch merging layer as well. Combining the two operations (shift ViT\n",
+ "block and patch merging) is a design choice we picked for better code reusability."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Note: This layer will have a different depth of stacking\n",
+ "# for different stages on the model.\n",
+ "class StackedShiftBlocks(layers.Layer):\n",
+ " \"\"\"The layer containing stacked ShiftViTBlocks.\n",
+ "\n",
+ " Args:\n",
+ " epsilon (float): The epsilon constant.\n",
+ " mlp_dropout_rate (float): The dropout rate used in the MLP block.\n",
+ " num_shift_blocks (int): The number of shift vit blocks for this stage.\n",
+ " stochastic_depth_rate (float): The maximum drop path rate chosen.\n",
+ " is_merge (boolean): A flag that determines the use of the Patch Merge\n",
+ " layer after the shift vit blocks.\n",
+ " num_div (int): The division of channels of the feature map. Defaults to 12.\n",
+ " shift_pixel (int): The number of pixels to shift. Defaults to 1.\n",
+ " mlp_expand_ratio (int): The ratio with which the initial dense layer of\n",
+ " the MLP is expanded Defaults to 2.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " epsilon,\n",
+ " mlp_dropout_rate,\n",
+ " num_shift_blocks,\n",
+ " stochastic_depth_rate,\n",
+ " is_merge,\n",
+ " num_div=12,\n",
+ " shift_pixel=1,\n",
+ " mlp_expand_ratio=2,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " super().__init__(**kwargs)\n",
+ " self.epsilon = epsilon\n",
+ " self.mlp_dropout_rate = mlp_dropout_rate\n",
+ " self.num_shift_blocks = num_shift_blocks\n",
+ " self.stochastic_depth_rate = stochastic_depth_rate\n",
+ " self.is_merge = is_merge\n",
+ " self.num_div = num_div\n",
+ " self.shift_pixel = shift_pixel\n",
+ " self.mlp_expand_ratio = mlp_expand_ratio\n",
+ "\n",
+ " def build(self, input_shapes):\n",
+ " # Calculate stochastic depth probabilities.\n",
+ " # Reference: https://keras.io/examples/vision/cct/#the-final-cct-model\n",
+ " dpr = [\n",
+ " x\n",
+ " for x in np.linspace(\n",
+ " start=0, stop=self.stochastic_depth_rate, num=self.num_shift_blocks\n",
+ " )\n",
+ " ]\n",
+ "\n",
+ " # Build the shift blocks as a list of ShiftViT Blocks\n",
+ " self.shift_blocks = list()\n",
+ " for num in range(self.num_shift_blocks):\n",
+ " self.shift_blocks.append(\n",
+ " ShiftViTBlock(\n",
+ " num_div=self.num_div,\n",
+ " epsilon=self.epsilon,\n",
+ " drop_path_prob=dpr[num],\n",
+ " mlp_dropout_rate=self.mlp_dropout_rate,\n",
+ " shift_pixel=self.shift_pixel,\n",
+ " mlp_expand_ratio=self.mlp_expand_ratio,\n",
+ " )\n",
+ " )\n",
+ " if self.is_merge:\n",
+ " self.patch_merge = PatchMerging(epsilon=self.epsilon)\n",
+ "\n",
+ " def call(self, x, training=False):\n",
+ " for shift_block in self.shift_blocks:\n",
+ " x = shift_block(x, training=training)\n",
+ " if self.is_merge:\n",
+ " x = self.patch_merge(x)\n",
+ " return x\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## The ShiftViT model\n",
+ "\n",
+ "Build the ShiftViT custom model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class ShiftViTModel(keras.Model):\n",
+ " \"\"\"The ShiftViT Model.\n",
+ "\n",
+ " Args:\n",
+ " data_augmentation (keras.Model): A data augmentation model.\n",
+ " projected_dim (int): The dimension to which the patches of the image are\n",
+ " projected.\n",
+ " patch_size (int): The patch size of the images.\n",
+ " num_shift_blocks_per_stages (list[int]): A list of all the number of shit\n",
+ " blocks per stage.\n",
+ " epsilon (float): The epsilon constant.\n",
+ " mlp_dropout_rate (float): The dropout rate used in the MLP block.\n",
+ " stochastic_depth_rate (float): The maximum drop rate probability.\n",
+ " num_div (int): The number of divisions of the channesl of the feature\n",
+ " map. Defaults to 12.\n",
+ " shift_pixel (int): The number of pixel to shift. Default to 1.\n",
+ " mlp_expand_ratio (int): The ratio with which the initial mlp dense layer\n",
+ " is expanded to. Defaults to 2.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " data_augmentation,\n",
+ " projected_dim,\n",
+ " patch_size,\n",
+ " num_shift_blocks_per_stages,\n",
+ " epsilon,\n",
+ " mlp_dropout_rate,\n",
+ " stochastic_depth_rate,\n",
+ " num_div=12,\n",
+ " shift_pixel=1,\n",
+ " mlp_expand_ratio=2,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " super().__init__(**kwargs)\n",
+ " self.data_augmentation = data_augmentation\n",
+ " self.patch_projection = layers.Conv2D(\n",
+ " filters=projected_dim,\n",
+ " kernel_size=patch_size,\n",
+ " strides=patch_size,\n",
+ " padding=\"same\",\n",
+ " )\n",
+ " self.stages = list()\n",
+ " for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):\n",
+ " if index == len(num_shift_blocks_per_stages) - 1:\n",
+ " # This is the last stage, do not use the patch merge here.\n",
+ " is_merge = False\n",
+ " else:\n",
+ " is_merge = True\n",
+ " # Build the stages.\n",
+ " self.stages.append(\n",
+ " StackedShiftBlocks(\n",
+ " epsilon=epsilon,\n",
+ " mlp_dropout_rate=mlp_dropout_rate,\n",
+ " num_shift_blocks=num_shift_blocks,\n",
+ " stochastic_depth_rate=stochastic_depth_rate,\n",
+ " is_merge=is_merge,\n",
+ " num_div=num_div,\n",
+ " shift_pixel=shift_pixel,\n",
+ " mlp_expand_ratio=mlp_expand_ratio,\n",
+ " )\n",
+ " )\n",
+ " self.global_avg_pool = layers.GlobalAveragePooling2D()\n",
+ "\n",
+ " def get_config(self):\n",
+ " config = super().get_config()\n",
+ " config.update(\n",
+ " {\n",
+ " \"data_augmentation\": self.data_augmentation,\n",
+ " \"patch_projection\": self.patch_projection,\n",
+ " \"stages\": self.stages,\n",
+ " \"global_avg_pool\": self.global_avg_pool,\n",
+ " }\n",
+ " )\n",
+ " return config\n",
+ "\n",
+ " def _calculate_loss(self, data, training=False):\n",
+ " (images, labels) = data\n",
+ "\n",
+ " # Augment the images\n",
+ " augmented_images = self.data_augmentation(images, training=training)\n",
+ "\n",
+ " # Create patches and project the pathces.\n",
+ " projected_patches = self.patch_projection(augmented_images)\n",
+ "\n",
+ " # Pass through the stages\n",
+ " x = projected_patches\n",
+ " for stage in self.stages:\n",
+ " x = stage(x, training=training)\n",
+ "\n",
+ " # Get the logits.\n",
+ " logits = self.global_avg_pool(x)\n",
+ "\n",
+ " # Calculate the loss and return it.\n",
+ " total_loss = self.compiled_loss(labels, logits)\n",
+ " return total_loss, labels, logits\n",
+ "\n",
+ " def train_step(self, inputs):\n",
+ " with tf.GradientTape() as tape:\n",
+ " total_loss, labels, logits = self._calculate_loss(\n",
+ " data=inputs, training=True\n",
+ " )\n",
+ "\n",
+ " # Apply gradients.\n",
+ " train_vars = [\n",
+ " self.data_augmentation.trainable_variables,\n",
+ " self.patch_projection.trainable_variables,\n",
+ " self.global_avg_pool.trainable_variables,\n",
+ " ]\n",
+ " train_vars = train_vars + [stage.trainable_variables for stage in self.stages]\n",
+ "\n",
+ " # Optimize the gradients.\n",
+ " grads = tape.gradient(total_loss, train_vars)\n",
+ " trainable_variable_list = []\n",
+ " for (grad, var) in zip(grads, train_vars):\n",
+ " for g, v in zip(grad, var):\n",
+ " trainable_variable_list.append((g, v))\n",
+ " self.optimizer.apply_gradients(trainable_variable_list)\n",
+ "\n",
+ " # Update the metrics\n",
+ " self.compiled_metrics.update_state(labels, logits)\n",
+ " return {m.name: m.result() for m in self.metrics}\n",
+ "\n",
+ " def test_step(self, data):\n",
+ " _, labels, logits = self._calculate_loss(data=data, training=False)\n",
+ "\n",
+ " # Update the metrics\n",
+ " self.compiled_metrics.update_state(labels, logits)\n",
+ " return {m.name: m.result() for m in self.metrics}\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Instantiate the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model = ShiftViTModel(\n",
+ " data_augmentation=get_augmentation_model(),\n",
+ " projected_dim=config.projected_dim,\n",
+ " patch_size=config.patch_size,\n",
+ " num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,\n",
+ " epsilon=config.epsilon,\n",
+ " mlp_dropout_rate=config.mlp_dropout_rate,\n",
+ " stochastic_depth_rate=config.stochastic_depth_rate,\n",
+ " num_div=config.num_div,\n",
+ " shift_pixel=config.shift_pixel,\n",
+ " mlp_expand_ratio=config.mlp_expand_ratio,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Learning rate schedule\n",
+ "\n",
+ "In many experiments, we want to warm up the model with a slowly increasing learning rate\n",
+ "and then cool down the model with a slowly decaying learning rate. In the warmup cosine\n",
+ "decay, the learning rate linearly increases for the warmup steps and then decays with a\n",
+ "cosine decay."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):\n",
+ " \"\"\"A LearningRateSchedule that uses a warmup cosine decay schedule.\"\"\"\n",
+ "\n",
+ " def __init__(self, lr_start, lr_max, warmup_steps, total_steps):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " lr_start: The initial learning rate\n",
+ " lr_max: The maximum learning rate to which lr should increase to in\n",
+ " the warmup steps\n",
+ " warmup_steps: The number of steps for which the model warms up\n",
+ " total_steps: The total number of steps for the model training\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " self.lr_start = lr_start\n",
+ " self.lr_max = lr_max\n",
+ " self.warmup_steps = warmup_steps\n",
+ " self.total_steps = total_steps\n",
+ " self.pi = tf.constant(np.pi)\n",
+ "\n",
+ " def __call__(self, step):\n",
+ " # Check whether the total number of steps is larger than the warmup\n",
+ " # steps. If not, then throw a value error.\n",
+ " if self.total_steps < self.warmup_steps:\n",
+ " raise ValueError(\n",
+ " f\"Total number of steps {self.total_steps} must be\"\n",
+ " + f\"larger or equal to warmup steps {self.warmup_steps}.\"\n",
+ " )\n",
+ "\n",
+ " # `cos_annealed_lr` is a graph that increases to 1 from the initial\n",
+ " # step to the warmup step. After that this graph decays to -1 at the\n",
+ " # final step mark.\n",
+ " cos_annealed_lr = tf.cos(\n",
+ " self.pi\n",
+ " * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
+ " / tf.cast(self.total_steps - self.warmup_steps, tf.float32)\n",
+ " )\n",
+ "\n",
+ " # Shift the mean of the `cos_annealed_lr` graph to 1. Now the grpah goes\n",
+ " # from 0 to 2. Normalize the graph with 0.5 so that now it goes from 0\n",
+ " # to 1. With the normalized graph we scale it with `lr_max` such that\n",
+ " # it goes from 0 to `lr_max`\n",
+ " learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)\n",
+ "\n",
+ " # Check whether warmup_steps is more than 0.\n",
+ " if self.warmup_steps > 0:\n",
+ " # Check whether lr_max is larger that lr_start. If not, throw a value\n",
+ " # error.\n",
+ " if self.lr_max < self.lr_start:\n",
+ " raise ValueError(\n",
+ " f\"lr_start {self.lr_start} must be smaller or\"\n",
+ " + f\"equal to lr_max {self.lr_max}.\"\n",
+ " )\n",
+ "\n",
+ " # Calculate the slope with which the learning rate should increase\n",
+ " # in the warumup schedule. The formula for slope is m = ((b-a)/steps)\n",
+ " slope = (self.lr_max - self.lr_start) / self.warmup_steps\n",
+ "\n",
+ " # With the formula for a straight line (y = mx+c) build the warmup\n",
+ " # schedule\n",
+ " warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start\n",
+ "\n",
+ " # When the current step is lesser that warmup steps, get the line\n",
+ " # graph. When the current step is greater than the warmup steps, get\n",
+ " # the scaled cos graph.\n",
+ " learning_rate = tf.where(\n",
+ " step < self.warmup_steps, warmup_rate, learning_rate\n",
+ " )\n",
+ "\n",
+ " # When the current step is more that the total steps, return 0 else return\n",
+ " # the calculated graph.\n",
+ " return tf.where(\n",
+ " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
+ " )\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Compile and train the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Get the total number of steps for training.\n",
+ "total_steps = int((len(x_train) / config.batch_size) * config.epochs)\n",
+ "\n",
+ "# Calculate the number of steps for warmup.\n",
+ "warmup_epoch_percentage = 0.15\n",
+ "warmup_steps = int(total_steps * warmup_epoch_percentage)\n",
+ "\n",
+ "# Initialize the warmupcosine schedule.\n",
+ "scheduled_lrs = WarmUpCosine(\n",
+ " lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps,\n",
+ ")\n",
+ "\n",
+ "# Get the optimizer.\n",
+ "optimizer = tfa.optimizers.AdamW(\n",
+ " learning_rate=scheduled_lrs, weight_decay=config.weight_decay\n",
+ ")\n",
+ "\n",
+ "# Compile and pretrain the model.\n",
+ "model.compile(\n",
+ " optimizer=optimizer,\n",
+ " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
+ " metrics=[\n",
+ " keras.metrics.SparseCategoricalAccuracy(name=\"accuracy\"),\n",
+ " keras.metrics.SparseTopKCategoricalAccuracy(5, name=\"top-5-accuracy\"),\n",
+ " ],\n",
+ ")\n",
+ "\n",
+ "# Train the model\n",
+ "history = model.fit(\n",
+ " train_ds,\n",
+ " epochs=config.epochs,\n",
+ " validation_data=val_ds,\n",
+ " callbacks=[\n",
+ " keras.callbacks.EarlyStopping(monitor=\"val_accuracy\", patience=5, mode=\"auto\",)\n",
+ " ],\n",
+ ")\n",
+ "\n",
+ "# Evaluate the model with the test dataset.\n",
+ "print(\"TESTING\")\n",
+ "loss, acc_top1, acc_top5 = model.evaluate(test_ds)\n",
+ "print(f\"Loss: {loss:0.2f}\")\n",
+ "print(f\"Top 1 test accuracy: {acc_top1*100:0.2f}%\")\n",
+ "print(f\"Top 5 test accuracy: {acc_top5*100:0.2f}%\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Conclusion\n",
+ "\n",
+ "The most impactful contribution of the paper is not the novel architecture, but\n",
+ "the idea that hierarchical ViTs trained with no attention can perform quite well. This\n",
+ "opens up the question of how essential attention is to the performance of ViTs.\n",
+ "\n",
+ "For curious minds, we would suggest reading the\n",
+ "[ConvNexT](https://arxiv.org/abs/2201.03545) paper which attends more to the training\n",
+ "paradigms and architectural details of ViTs rather than providing a novel architecture\n",
+ "based on attention.\n",
+ "\n",
+ "Acknowledgements:\n",
+ "\n",
+ "- We would like to thank [PyImageSearch](https://pyimagesearch.com) for providing us with\n",
+ "resources that helped in the completion of this project.\n",
+ "- We would like to thank [JarvisLabs.ai](https://jarvislabs.ai/) for providing with the\n",
+ "GPU credits.\n",
+ "- We would like to thank [Manim Community](https://www.manim.community/) for the manim\n",
+ "library.\n",
+ "- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for\n",
+ "helping us with the Learning Rate Schedule."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "shiftvit",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/vision/md/shiftvit.md b/examples/vision/md/shiftvit.md
new file mode 100644
index 0000000000..f717a55546
--- /dev/null
+++ b/examples/vision/md/shiftvit.md
@@ -0,0 +1,1004 @@
+# A Vision Transformer without Attention
+
+**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+**Date created:** 2022/02/24
+**Last modified:** 2022/03/01
+**Description:** A minimal implementation of ShiftViT.
+
+
+
[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/shiftvit.ipynb) •
[**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/vision/shiftvit.py)
+
+
+
+---
+## Introduction
+
+[Vision Transformers](https://arxiv.org/abs/2010.11929) (ViTs) have sparked a wave of
+research at the intersection of Transformers and Computer Vision (CV).
+
+ViTs can simultaneously model long- and short-range dependencies, thanks to
+the Multi-Head Self-Attention mechanism in the Transformer block. Many researchers believe
+that the success of ViTs are purely due to the attention layer, and they seldom
+think about other parts of the ViT model.
+
+In the academic paper
+[When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism](https://arxiv.org/abs/2201.10801)
+the authors propose to demystify the success of ViTs with the introduction of a **NO
+PARAMETER** operation in place of the attention operation. They swap the attention
+operation with a shifting operation.
+
+In this example, we minimally implement the paper with close alignement to the author's
+[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
+
+This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can
+be installed using the following command:
+
+```shell
+pip install -qq -U tensorflow-addons
+```
+
+---
+## Setup and imports
+
+
+```python
+import numpy as np
+import matplotlib.pyplot as plt
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+
+import tensorflow_addons as tfa
+
+# Setting seed for reproducibiltiy
+SEED = 42
+keras.utils.set_random_seed(SEED)
+```
+
+
+```
+/usr/local/lib/python3.8/dist-packages/tensorflow_addons/utils/ensure_tf_install.py:53: UserWarning: Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.3.0 and strictly below 2.6.0 (nightly versions are not supported).
+ The versions of TensorFlow you are currently using is 2.8.0 and is not supported.
+Some things might work, some things might not.
+If you were to encounter a bug, do not file an issue.
+If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version.
+You can find the compatibility matrix in TensorFlow Addon's readme:
+https://github.com/tensorflow/addons
+ warnings.warn(
+
+```
+
+---
+## Hyperparameters
+
+These are the hyperparameters that we have chosen for the experiment.
+Please feel free to tune them.
+
+
+```python
+
+class Config(object):
+ # DATA
+ batch_size = 256
+ buffer_size = batch_size * 2
+ input_shape = (32, 32, 3)
+ num_classes = 10
+
+ # AUGMENTATION
+ image_size = 48
+
+ # ARCHITECTURE
+ patch_size = 4
+ projected_dim = 96
+ num_shift_blocks_per_stages = [2, 4, 8, 2]
+ epsilon = 1e-5
+ stochastic_depth_rate = 0.2
+ mlp_dropout_rate = 0.2
+ num_div = 12
+ shift_pixel = 1
+ mlp_expand_ratio = 2
+
+ # OPTIMIZER
+ lr_start = 1e-5
+ lr_max = 1e-3
+ weight_decay = 1e-4
+
+ # TRAINING
+ epochs = 100
+
+
+config = Config()
+```
+
+---
+## Load the CIFAR-10 dataset
+
+We use the CIFAR-10 dataset for our experiments.
+
+
+```python
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_val, y_val) = (
+ (x_train[:40000], y_train[:40000]),
+ (x_train[40000:], y_train[40000:]),
+)
+print(f"Training samples: {len(x_train)}")
+print(f"Validation samples: {len(x_val)}")
+print(f"Testing samples: {len(x_test)}")
+
+AUTO = tf.data.AUTOTUNE
+train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_ds = train_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO)
+
+val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+val_ds = val_ds.batch(config.batch_size).prefetch(AUTO)
+
+test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
+```
+
+
+```
+Training samples: 40000
+Validation samples: 10000
+Testing samples: 10000
+
+2022-03-01 03:10:21.342684: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
+To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
+2022-03-01 03:10:21.850844: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38420 MB memory: -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:61:00.0, compute capability: 8.0
+
+```
+
+---
+## Data Augmentation
+
+The augmentation pipeline consists of:
+
+- Rescaling
+- Resizing
+- Random cropping
+- Random horizontal flipping
+
+_Note_: The image data augmentation layers do not apply
+data transformations at inference time. This means that
+when these layers are called with `training=False` they
+behave differently. Refer to the
+[documentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/)
+for more details.
+
+
+```python
+
+def get_augmentation_model():
+ """Build the data augmentation model."""
+ data_augmentation = keras.Sequential(
+ [
+ layers.Resizing(config.input_shape[0] + 20, config.input_shape[0] + 20),
+ layers.RandomCrop(config.image_size, config.image_size),
+ layers.RandomFlip("horizontal"),
+ layers.Rescaling(1 / 255.0),
+ ]
+ )
+ return data_augmentation
+
+```
+
+---
+## The ShiftViT architecture
+
+In this section, we build the architecture proposed in
+[the ShiftViT paper](https://arxiv.org/abs/2201.10801).
+
+|  |
+| :--: |
+| Figure 1: The entire architecutre of ShiftViT.
+[Source](https://arxiv.org/abs/2201.10801) |
+
+The architecture as shown in Fig. 1, is inspired by
+[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030).
+Here the authors propose a modular architecture with 4 stages. Each stage works on its
+own spatial size, creating a hierarchical architecture.
+
+An input image of size `HxWx3` is split into non-overlapping patches of size `4x4`.
+This is done via the patchify layer which results in individual tokens of feature size `48`
+(`4x4x3`). Each stage comprises two parts:
+
+1. Embedding Generation
+2. Stacked Shift Blocks
+
+We discuss the stages and the modules in detail in what follows.
+
+_Note_: Compared to the [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py)
+we restructure some key components to better fit the Keras API.
+
+### The ShiftViT Block
+
+|  |
+| :--: |
+| Figure 2: From the Model to a Shift Block. |
+
+Each stage in the ShiftViT architecture comprises of a Shift Block as shown in Fig 2.
+
+|  |
+| :--: |
+| Figure 3: The Shift ViT Block. [Source](https://arxiv.org/abs/2201.10801) |
+
+The Shift Block as shown in Fig. 3, comprises of the following:
+
+1. Shift Operation
+2. Linear Normalization
+3. MLP Layer
+
+#### The MLP block
+
+The MLP block is intended to be a stack of densely-connected layers.s
+
+
+```python
+
+class MLP(layers.Layer):
+ """Get the MLP layer for each shift block.
+
+ Args:
+ mlp_expand_ratio (int): The ratio with which the first feature map is expanded.
+ mlp_dropout_rate (float): The rate for dropout.
+ """
+
+ def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs):
+ super().__init__(**kwargs)
+ self.mlp_expand_ratio = mlp_expand_ratio
+ self.mlp_dropout_rate = mlp_dropout_rate
+
+ def build(self, input_shape):
+ input_channels = input_shape[-1]
+ initial_filters = int(self.mlp_expand_ratio * input_channels)
+
+ self.mlp = keras.Sequential(
+ [
+ layers.Dense(units=initial_filters, activation=tf.nn.gelu,),
+ layers.Dropout(rate=self.mlp_dropout_rate),
+ layers.Dense(units=input_channels),
+ layers.Dropout(rate=self.mlp_dropout_rate),
+ ]
+ )
+
+ def call(self, x):
+ x = self.mlp(x)
+ return x
+
+```
+
+#### The DropPath layer
+
+Stochastic depth is a regularization technique that randomly drops a set of
+layers. During inference, the layers are kept as they are. It is very
+similar to Dropout, but it operates on a block of layers rather
+than on individual nodes present inside a layer.
+
+
+```python
+
+class DropPath(layers.Layer):
+ """Drop Path also known as the Stochastic Depth layer.
+
+ Refernece:
+ - https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization
+ - github.com:rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path_prob, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path_prob = drop_path_prob
+
+ def call(self, x, training=False):
+ if training:
+ keep_prob = 1 - self.drop_path_prob
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+```
+
+#### Block
+
+The most important operation in this paper is the **shift opperation**. In this section,
+we describe the shift operation and compare it with its original implementation provided
+by the authors.
+
+A generic feature map is assumed to have the shape `[N, H, W, C]`. Here we choose a
+`num_div` parameter that decides the division size of the channels. The first 4 divisions
+are shifted (1 pixel) in the left, right, up, and down direction. The remaining splits
+are kept as is. After partial shifting the shifted channels are padded and the overflown
+pixels are chopped off. This completes the partial shifting operation.
+
+In the original implementation, the code is approximately:
+
+```python
+out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:] # shift left
+out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1] # shift right
+out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :] # shift up
+out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :] # shift down
+
+out[:, g * 4:, :, :] = x[:, g * 4:, :, :] # no shift
+```
+
+In TensorFlow it would be infeasible for us to assign shifted channels to a tensor in the
+middle of the training process. This is why we have resorted to the following procedure:
+
+1. Split the channels with the `num_div` parameter.
+2. Select each of the first four spilts and shift and pad them in the respective
+directions.
+3. After shifting and padding, we concatenate the channel back.
+
+|  |
+| :--: |
+| Figure 4: The TensorFlow style shifting |
+
+The entire procedure is explained in the Fig. 4.
+
+
+```python
+
+class ShiftViTBlock(layers.Layer):
+ """A unit ShiftViT Block
+
+ Args:
+ shift_pixel (int): The number of pixels to shift. Default to 1.
+ mlp_expand_ratio (int): The ratio with which MLP features are
+ expanded. Default to 2.
+ mlp_dropout_rate (float): The dropout rate used in MLP.
+ num_div (int): The number of divisions of the feature map's channel.
+ Totally, 4/num_div of channels will be shifted. Defaults to 12.
+ epsilon (float): Epsilon constant.
+ drop_path_prob (float): The drop probability for drop path.
+ """
+
+ def __init__(
+ self,
+ epsilon,
+ drop_path_prob,
+ mlp_dropout_rate,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.shift_pixel = shift_pixel
+ self.mlp_expand_ratio = mlp_expand_ratio
+ self.mlp_dropout_rate = mlp_dropout_rate
+ self.num_div = num_div
+ self.epsilon = epsilon
+ self.drop_path_prob = drop_path_prob
+
+ def build(self, input_shape):
+ self.H = input_shape[1]
+ self.W = input_shape[2]
+ self.C = input_shape[3]
+ self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
+ self.drop_path = (
+ DropPath(drop_path_prob=self.drop_path_prob)
+ if self.drop_path_prob > 0.0
+ else layers.Activation("linear")
+ )
+ self.mlp = MLP(
+ mlp_expand_ratio=self.mlp_expand_ratio,
+ mlp_dropout_rate=self.mlp_dropout_rate,
+ )
+
+ def get_shift_pad(self, x, mode):
+ """Shifts the channels according to the mode chosen."""
+ if mode == "left":
+ offset_height = 0
+ offset_width = 0
+ target_height = 0
+ target_width = self.shift_pixel
+ elif mode == "right":
+ offset_height = 0
+ offset_width = self.shift_pixel
+ target_height = 0
+ target_width = self.shift_pixel
+ elif mode == "up":
+ offset_height = 0
+ offset_width = 0
+ target_height = self.shift_pixel
+ target_width = 0
+ else:
+ offset_height = self.shift_pixel
+ offset_width = 0
+ target_height = self.shift_pixel
+ target_width = 0
+ crop = tf.image.crop_to_bounding_box(
+ x,
+ offset_height=offset_height,
+ offset_width=offset_width,
+ target_height=self.H - target_height,
+ target_width=self.W - target_width,
+ )
+ shift_pad = tf.image.pad_to_bounding_box(
+ crop,
+ offset_height=offset_height,
+ offset_width=offset_width,
+ target_height=self.H,
+ target_width=self.W,
+ )
+ return shift_pad
+
+ def call(self, x, training=False):
+ # Split the feature maps
+ x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)
+
+ # Shift the feature maps
+ x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
+ x_splits[1] = self.get_shift_pad(x_splits[1], mode="right")
+ x_splits[2] = self.get_shift_pad(x_splits[2], mode="up")
+ x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")
+
+ # Concatenate the shifted and unshifted feature maps
+ x = tf.concat(x_splits, axis=-1)
+
+ # Add the residual connection
+ shortcut = x
+ x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)
+ return x
+
+```
+
+### The ShiftViT blocks
+
+|  |
+| :--: |
+| Figure 5: Shift Blocks in the architecture. [Source](https://arxiv.org/abs/2201.10801) |
+
+Each stage of the architecture has shift blocks as shown in Fig.5. Each of these blocks
+contain a variable number of stacked ShiftViT block (as built in the earlier section).
+
+Shift blocks are followed by a PatchMerging layer that scales down feature inputs. The
+PatchMerging layer helps in the pyramidal structure of the model.
+
+#### The PatchMerging layer
+
+This layer merges the two adjacent tokens. This layer helps in scaling the features down
+spatially and increasing the features up channel wise. We use a Conv2D layer to merge the
+patches.
+
+
+```python
+
+class PatchMerging(layers.Layer):
+ """The Patch Merging layer.
+
+ Args:
+ epsilon (float): The epsilon constant.
+ """
+
+ def __init__(self, epsilon, **kwargs):
+ super().__init__(**kwargs)
+ self.epsilon = epsilon
+
+ def build(self, input_shape):
+ filters = 2 * input_shape[-1]
+ self.reduction = layers.Conv2D(
+ filters=filters, kernel_size=2, strides=2, padding="same", use_bias=False
+ )
+ self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
+
+ def call(self, x):
+ # Apply the patch merging algorithm on the feature maps
+ x = self.layer_norm(x)
+ x = self.reduction(x)
+ return x
+
+```
+
+#### Stacked Shift Blocks
+
+Each stage will have a variable number of stacked ShiftViT Blocks, as suggested in
+the paper. This is a generic layer that will contain the stacked shift vit blocks
+with the patch merging layer as well. Combining the two operations (shift ViT
+block and patch merging) is a design choice we picked for better code reusability.
+
+
+```python
+# Note: This layer will have a different depth of stacking
+# for different stages on the model.
+class StackedShiftBlocks(layers.Layer):
+ """The layer containing stacked ShiftViTBlocks.
+
+ Args:
+ epsilon (float): The epsilon constant.
+ mlp_dropout_rate (float): The dropout rate used in the MLP block.
+ num_shift_blocks (int): The number of shift vit blocks for this stage.
+ stochastic_depth_rate (float): The maximum drop path rate chosen.
+ is_merge (boolean): A flag that determines the use of the Patch Merge
+ layer after the shift vit blocks.
+ num_div (int): The division of channels of the feature map. Defaults to 12.
+ shift_pixel (int): The number of pixels to shift. Defaults to 1.
+ mlp_expand_ratio (int): The ratio with which the initial dense layer of
+ the MLP is expanded Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ epsilon,
+ mlp_dropout_rate,
+ num_shift_blocks,
+ stochastic_depth_rate,
+ is_merge,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.epsilon = epsilon
+ self.mlp_dropout_rate = mlp_dropout_rate
+ self.num_shift_blocks = num_shift_blocks
+ self.stochastic_depth_rate = stochastic_depth_rate
+ self.is_merge = is_merge
+ self.num_div = num_div
+ self.shift_pixel = shift_pixel
+ self.mlp_expand_ratio = mlp_expand_ratio
+
+ def build(self, input_shapes):
+ # Calculate stochastic depth probabilities.
+ # Reference: https://keras.io/examples/vision/cct/#the-final-cct-model
+ dpr = [
+ x
+ for x in np.linspace(
+ start=0, stop=self.stochastic_depth_rate, num=self.num_shift_blocks
+ )
+ ]
+
+ # Build the shift blocks as a list of ShiftViT Blocks
+ self.shift_blocks = list()
+ for num in range(self.num_shift_blocks):
+ self.shift_blocks.append(
+ ShiftViTBlock(
+ num_div=self.num_div,
+ epsilon=self.epsilon,
+ drop_path_prob=dpr[num],
+ mlp_dropout_rate=self.mlp_dropout_rate,
+ shift_pixel=self.shift_pixel,
+ mlp_expand_ratio=self.mlp_expand_ratio,
+ )
+ )
+ if self.is_merge:
+ self.patch_merge = PatchMerging(epsilon=self.epsilon)
+
+ def call(self, x, training=False):
+ for shift_block in self.shift_blocks:
+ x = shift_block(x, training=training)
+ if self.is_merge:
+ x = self.patch_merge(x)
+ return x
+
+```
+
+---
+## The ShiftViT model
+
+Build the ShiftViT custom model.
+
+
+```python
+
+class ShiftViTModel(keras.Model):
+ """The ShiftViT Model.
+
+ Args:
+ data_augmentation (keras.Model): A data augmentation model.
+ projected_dim (int): The dimension to which the patches of the image are
+ projected.
+ patch_size (int): The patch size of the images.
+ num_shift_blocks_per_stages (list[int]): A list of all the number of shit
+ blocks per stage.
+ epsilon (float): The epsilon constant.
+ mlp_dropout_rate (float): The dropout rate used in the MLP block.
+ stochastic_depth_rate (float): The maximum drop rate probability.
+ num_div (int): The number of divisions of the channesl of the feature
+ map. Defaults to 12.
+ shift_pixel (int): The number of pixel to shift. Default to 1.
+ mlp_expand_ratio (int): The ratio with which the initial mlp dense layer
+ is expanded to. Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ data_augmentation,
+ projected_dim,
+ patch_size,
+ num_shift_blocks_per_stages,
+ epsilon,
+ mlp_dropout_rate,
+ stochastic_depth_rate,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.data_augmentation = data_augmentation
+ self.patch_projection = layers.Conv2D(
+ filters=projected_dim,
+ kernel_size=patch_size,
+ strides=patch_size,
+ padding="same",
+ )
+ self.stages = list()
+ for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
+ if index == len(num_shift_blocks_per_stages) - 1:
+ # This is the last stage, do not use the patch merge here.
+ is_merge = False
+ else:
+ is_merge = True
+ # Build the stages.
+ self.stages.append(
+ StackedShiftBlocks(
+ epsilon=epsilon,
+ mlp_dropout_rate=mlp_dropout_rate,
+ num_shift_blocks=num_shift_blocks,
+ stochastic_depth_rate=stochastic_depth_rate,
+ is_merge=is_merge,
+ num_div=num_div,
+ shift_pixel=shift_pixel,
+ mlp_expand_ratio=mlp_expand_ratio,
+ )
+ )
+ self.global_avg_pool = layers.GlobalAveragePooling2D()
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "data_augmentation": self.data_augmentation,
+ "patch_projection": self.patch_projection,
+ "stages": self.stages,
+ "global_avg_pool": self.global_avg_pool,
+ }
+ )
+ return config
+
+ def _calculate_loss(self, data, training=False):
+ (images, labels) = data
+
+ # Augment the images
+ augmented_images = self.data_augmentation(images, training=training)
+
+ # Create patches and project the pathces.
+ projected_patches = self.patch_projection(augmented_images)
+
+ # Pass through the stages
+ x = projected_patches
+ for stage in self.stages:
+ x = stage(x, training=training)
+
+ # Get the logits.
+ logits = self.global_avg_pool(x)
+
+ # Calculate the loss and return it.
+ total_loss = self.compiled_loss(labels, logits)
+ return total_loss, labels, logits
+
+ def train_step(self, inputs):
+ with tf.GradientTape() as tape:
+ total_loss, labels, logits = self._calculate_loss(
+ data=inputs, training=True
+ )
+
+ # Apply gradients.
+ train_vars = [
+ self.data_augmentation.trainable_variables,
+ self.patch_projection.trainable_variables,
+ self.global_avg_pool.trainable_variables,
+ ]
+ train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
+
+ # Optimize the gradients.
+ grads = tape.gradient(total_loss, train_vars)
+ trainable_variable_list = []
+ for (grad, var) in zip(grads, train_vars):
+ for g, v in zip(grad, var):
+ trainable_variable_list.append((g, v))
+ self.optimizer.apply_gradients(trainable_variable_list)
+
+ # Update the metrics
+ self.compiled_metrics.update_state(labels, logits)
+ return {m.name: m.result() for m in self.metrics}
+
+ def test_step(self, data):
+ _, labels, logits = self._calculate_loss(data=data, training=False)
+
+ # Update the metrics
+ self.compiled_metrics.update_state(labels, logits)
+ return {m.name: m.result() for m in self.metrics}
+
+```
+
+---
+## Instantiate the model
+
+
+```python
+model = ShiftViTModel(
+ data_augmentation=get_augmentation_model(),
+ projected_dim=config.projected_dim,
+ patch_size=config.patch_size,
+ num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
+ epsilon=config.epsilon,
+ mlp_dropout_rate=config.mlp_dropout_rate,
+ stochastic_depth_rate=config.stochastic_depth_rate,
+ num_div=config.num_div,
+ shift_pixel=config.shift_pixel,
+ mlp_expand_ratio=config.mlp_expand_ratio,
+)
+```
+
+---
+## Learning rate schedule
+
+In many experiments, we want to warm up the model with a slowly increasing learning rate
+and then cool down the model with a slowly decaying learning rate. In the warmup cosine
+decay, the learning rate linearly increases for the warmup steps and then decays with a
+cosine decay.
+
+
+```python
+
+class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
+ """A LearningRateSchedule that uses a warmup cosine decay schedule."""
+
+ def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
+ """
+ Args:
+ lr_start: The initial learning rate
+ lr_max: The maximum learning rate to which lr should increase to in
+ the warmup steps
+ warmup_steps: The number of steps for which the model warms up
+ total_steps: The total number of steps for the model training
+ """
+ super().__init__()
+ self.lr_start = lr_start
+ self.lr_max = lr_max
+ self.warmup_steps = warmup_steps
+ self.total_steps = total_steps
+ self.pi = tf.constant(np.pi)
+
+ def __call__(self, step):
+ # Check whether the total number of steps is larger than the warmup
+ # steps. If not, then throw a value error.
+ if self.total_steps < self.warmup_steps:
+ raise ValueError(
+ f"Total number of steps {self.total_steps} must be"
+ + f"larger or equal to warmup steps {self.warmup_steps}."
+ )
+
+ # `cos_annealed_lr` is a graph that increases to 1 from the initial
+ # step to the warmup step. After that this graph decays to -1 at the
+ # final step mark.
+ cos_annealed_lr = tf.cos(
+ self.pi
+ * (tf.cast(step, tf.float32) - self.warmup_steps)
+ / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
+ )
+
+ # Shift the `cos_annealed_lr` graph so that it goes from 0 to 2, scale
+ # it by 0.5 so that it goes from 0 to 1, and then scale it by `lr_max`
+ # so that the learning rate goes from 0 to `lr_max`.
+ learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)
+
+ # Check whether warmup_steps is more than 0.
+ if self.warmup_steps > 0:
+ # Check whether lr_max is larger than lr_start. If not, throw a value
+ # error.
+ if self.lr_max < self.lr_start:
+ raise ValueError(
+ f"lr_start {self.lr_start} must be smaller or"
+ + f"equal to lr_max {self.lr_max}."
+ )
+
+ # Calculate the slope with which the learning rate should increase
+ # in the warmup schedule. The formula for slope is m = ((b-a)/steps).
+ slope = (self.lr_max - self.lr_start) / self.warmup_steps
+
+ # With the formula for a straight line (y = mx+c) build the warmup
+ # schedule
+ warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
+
+ # When the current step is less than the warmup steps, use the line
+ # graph. When the current step is greater than the warmup steps, use
+ # the scaled cosine graph.
+ learning_rate = tf.where(
+ step < self.warmup_steps, warmup_rate, learning_rate
+ )
+
+ # When the current step is more than the total steps, return 0, else
+ # return the calculated graph.
+ return tf.where(
+ step > self.total_steps, 0.0, learning_rate, name="learning_rate"
+ )
+
+```
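+
+To sanity-check the schedule before using it, we can sample it over a small toy
+horizon. This is a minimal sketch: `demo_schedule` and the step counts below are
+illustrative and are not the values used for training later.
+
+```python
+demo_schedule = WarmUpCosine(lr_start=1e-5, lr_max=1e-3, warmup_steps=15, total_steps=100)
+lrs = [demo_schedule(step).numpy() for step in range(100)]
+plt.plot(lrs)
+plt.xlabel("Step")
+plt.ylabel("Learning rate")
+plt.show()
+```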
+
+---
+## Compile and train the model
+
+
+```python
+# Get the total number of steps for training.
+total_steps = int((len(x_train) / config.batch_size) * config.epochs)
+
+# Calculate the number of steps for warmup.
+warmup_epoch_percentage = 0.15
+warmup_steps = int(total_steps * warmup_epoch_percentage)
+
+# Initialize the warmup cosine schedule.
+scheduled_lrs = WarmUpCosine(
+ lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps,
+)
+
+# Get the optimizer.
+optimizer = tfa.optimizers.AdamW(
+ learning_rate=scheduled_lrs, weight_decay=config.weight_decay
+)
+
+# Compile and pretrain the model.
+model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+)
+
+# Train the model
+history = model.fit(
+ train_ds,
+ epochs=config.epochs,
+ validation_data=val_ds,
+ callbacks=[
+ keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5, mode="auto",)
+ ],
+)
+
+# Evaluate the model with the test dataset.
+print("TESTING")
+loss, acc_top1, acc_top5 = model.evaluate(test_ds)
+print(f"Loss: {loss:0.2f}")
+print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
+print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
+```
+
+
+```
+Epoch 1/100
+
+2022-03-01 03:10:41.373231: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8202
+2022-03-01 03:10:43.145958: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
+
+157/157 [==============================] - 34s 84ms/step - loss: 3.2975 - accuracy: 0.1084 - top-5-accuracy: 0.4806 - val_loss: 2.1575 - val_accuracy: 0.2017 - val_top-5-accuracy: 0.7184
+Epoch 2/100
+157/157 [==============================] - 11s 67ms/step - loss: 2.1727 - accuracy: 0.2289 - top-5-accuracy: 0.7516 - val_loss: 1.8819 - val_accuracy: 0.3182 - val_top-5-accuracy: 0.8386
+Epoch 3/100
+157/157 [==============================] - 10s 67ms/step - loss: 1.8169 - accuracy: 0.3426 - top-5-accuracy: 0.8592 - val_loss: 1.6174 - val_accuracy: 0.4053 - val_top-5-accuracy: 0.8934
+Epoch 4/100
+157/157 [==============================] - 10s 67ms/step - loss: 1.6215 - accuracy: 0.4092 - top-5-accuracy: 0.8983 - val_loss: 1.4239 - val_accuracy: 0.4903 - val_top-5-accuracy: 0.9216
+Epoch 5/100
+157/157 [==============================] - 10s 66ms/step - loss: 1.5081 - accuracy: 0.4571 - top-5-accuracy: 0.9148 - val_loss: 1.3359 - val_accuracy: 0.5161 - val_top-5-accuracy: 0.9369
+Epoch 6/100
+157/157 [==============================] - 11s 68ms/step - loss: 1.4282 - accuracy: 0.4868 - top-5-accuracy: 0.9249 - val_loss: 1.2929 - val_accuracy: 0.5347 - val_top-5-accuracy: 0.9404
+Epoch 7/100
+157/157 [==============================] - 10s 66ms/step - loss: 1.3465 - accuracy: 0.5181 - top-5-accuracy: 0.9362 - val_loss: 1.2653 - val_accuracy: 0.5497 - val_top-5-accuracy: 0.9449
+Epoch 8/100
+157/157 [==============================] - 10s 67ms/step - loss: 1.2907 - accuracy: 0.5400 - top-5-accuracy: 0.9416 - val_loss: 1.1919 - val_accuracy: 0.5753 - val_top-5-accuracy: 0.9515
+Epoch 9/100
+157/157 [==============================] - 11s 67ms/step - loss: 1.2247 - accuracy: 0.5644 - top-5-accuracy: 0.9480 - val_loss: 1.1741 - val_accuracy: 0.5742 - val_top-5-accuracy: 0.9563
+Epoch 10/100
+157/157 [==============================] - 11s 67ms/step - loss: 1.1983 - accuracy: 0.5760 - top-5-accuracy: 0.9505 - val_loss: 1.4545 - val_accuracy: 0.4804 - val_top-5-accuracy: 0.9198
+Epoch 11/100
+157/157 [==============================] - 10s 66ms/step - loss: 1.2002 - accuracy: 0.5766 - top-5-accuracy: 0.9510 - val_loss: 1.1129 - val_accuracy: 0.6055 - val_top-5-accuracy: 0.9593
+Epoch 12/100
+157/157 [==============================] - 10s 66ms/step - loss: 1.1309 - accuracy: 0.5990 - top-5-accuracy: 0.9575 - val_loss: 1.0369 - val_accuracy: 0.6341 - val_top-5-accuracy: 0.9638
+Epoch 13/100
+157/157 [==============================] - 10s 66ms/step - loss: 1.0786 - accuracy: 0.6204 - top-5-accuracy: 0.9613 - val_loss: 1.0802 - val_accuracy: 0.6193 - val_top-5-accuracy: 0.9594
+Epoch 14/100
+157/157 [==============================] - 10s 65ms/step - loss: 1.0438 - accuracy: 0.6330 - top-5-accuracy: 0.9640 - val_loss: 0.9584 - val_accuracy: 0.6596 - val_top-5-accuracy: 0.9713
+Epoch 15/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.9957 - accuracy: 0.6496 - top-5-accuracy: 0.9684 - val_loss: 0.9530 - val_accuracy: 0.6636 - val_top-5-accuracy: 0.9712
+Epoch 16/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.9710 - accuracy: 0.6599 - top-5-accuracy: 0.9696 - val_loss: 0.8856 - val_accuracy: 0.6863 - val_top-5-accuracy: 0.9756
+Epoch 17/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.9316 - accuracy: 0.6706 - top-5-accuracy: 0.9721 - val_loss: 0.9919 - val_accuracy: 0.6480 - val_top-5-accuracy: 0.9671
+Epoch 18/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.8899 - accuracy: 0.6884 - top-5-accuracy: 0.9763 - val_loss: 0.8753 - val_accuracy: 0.6949 - val_top-5-accuracy: 0.9752
+Epoch 19/100
+157/157 [==============================] - 10s 64ms/step - loss: 0.8529 - accuracy: 0.6979 - top-5-accuracy: 0.9772 - val_loss: 0.8793 - val_accuracy: 0.6943 - val_top-5-accuracy: 0.9754
+Epoch 20/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.8509 - accuracy: 0.7009 - top-5-accuracy: 0.9783 - val_loss: 0.8183 - val_accuracy: 0.7174 - val_top-5-accuracy: 0.9763
+Epoch 21/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.8087 - accuracy: 0.7143 - top-5-accuracy: 0.9809 - val_loss: 0.7885 - val_accuracy: 0.7276 - val_top-5-accuracy: 0.9769
+Epoch 22/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.8004 - accuracy: 0.7192 - top-5-accuracy: 0.9811 - val_loss: 0.7601 - val_accuracy: 0.7371 - val_top-5-accuracy: 0.9805
+Epoch 23/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.7665 - accuracy: 0.7304 - top-5-accuracy: 0.9816 - val_loss: 0.7564 - val_accuracy: 0.7412 - val_top-5-accuracy: 0.9808
+Epoch 24/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.7599 - accuracy: 0.7344 - top-5-accuracy: 0.9832 - val_loss: 0.7475 - val_accuracy: 0.7389 - val_top-5-accuracy: 0.9822
+Epoch 25/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.7398 - accuracy: 0.7427 - top-5-accuracy: 0.9833 - val_loss: 0.7211 - val_accuracy: 0.7504 - val_top-5-accuracy: 0.9829
+Epoch 26/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.7114 - accuracy: 0.7500 - top-5-accuracy: 0.9857 - val_loss: 0.7385 - val_accuracy: 0.7462 - val_top-5-accuracy: 0.9822
+Epoch 27/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.6954 - accuracy: 0.7577 - top-5-accuracy: 0.9851 - val_loss: 0.7477 - val_accuracy: 0.7402 - val_top-5-accuracy: 0.9802
+Epoch 28/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.6807 - accuracy: 0.7588 - top-5-accuracy: 0.9871 - val_loss: 0.7275 - val_accuracy: 0.7536 - val_top-5-accuracy: 0.9822
+Epoch 29/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.6719 - accuracy: 0.7648 - top-5-accuracy: 0.9876 - val_loss: 0.7261 - val_accuracy: 0.7487 - val_top-5-accuracy: 0.9815
+Epoch 30/100
+157/157 [==============================] - 10s 65ms/step - loss: 0.6578 - accuracy: 0.7696 - top-5-accuracy: 0.9871 - val_loss: 0.6932 - val_accuracy: 0.7641 - val_top-5-accuracy: 0.9833
+Epoch 31/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.6489 - accuracy: 0.7740 - top-5-accuracy: 0.9877 - val_loss: 0.7400 - val_accuracy: 0.7486 - val_top-5-accuracy: 0.9820
+Epoch 32/100
+157/157 [==============================] - 10s 65ms/step - loss: 0.6290 - accuracy: 0.7812 - top-5-accuracy: 0.9895 - val_loss: 0.6954 - val_accuracy: 0.7628 - val_top-5-accuracy: 0.9847
+Epoch 33/100
+157/157 [==============================] - 10s 67ms/step - loss: 0.6194 - accuracy: 0.7826 - top-5-accuracy: 0.9894 - val_loss: 0.6913 - val_accuracy: 0.7619 - val_top-5-accuracy: 0.9842
+Epoch 34/100
+157/157 [==============================] - 10s 65ms/step - loss: 0.5917 - accuracy: 0.7930 - top-5-accuracy: 0.9902 - val_loss: 0.6879 - val_accuracy: 0.7715 - val_top-5-accuracy: 0.9831
+Epoch 35/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.5878 - accuracy: 0.7916 - top-5-accuracy: 0.9907 - val_loss: 0.6759 - val_accuracy: 0.7720 - val_top-5-accuracy: 0.9849
+Epoch 36/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.5713 - accuracy: 0.8004 - top-5-accuracy: 0.9913 - val_loss: 0.6920 - val_accuracy: 0.7657 - val_top-5-accuracy: 0.9841
+Epoch 37/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.5590 - accuracy: 0.8040 - top-5-accuracy: 0.9913 - val_loss: 0.6790 - val_accuracy: 0.7718 - val_top-5-accuracy: 0.9831
+Epoch 38/100
+157/157 [==============================] - 11s 67ms/step - loss: 0.5445 - accuracy: 0.8114 - top-5-accuracy: 0.9926 - val_loss: 0.6756 - val_accuracy: 0.7720 - val_top-5-accuracy: 0.9852
+Epoch 39/100
+157/157 [==============================] - 11s 67ms/step - loss: 0.5292 - accuracy: 0.8155 - top-5-accuracy: 0.9930 - val_loss: 0.6578 - val_accuracy: 0.7807 - val_top-5-accuracy: 0.9845
+Epoch 40/100
+157/157 [==============================] - 11s 68ms/step - loss: 0.5169 - accuracy: 0.8181 - top-5-accuracy: 0.9926 - val_loss: 0.6582 - val_accuracy: 0.7795 - val_top-5-accuracy: 0.9849
+Epoch 41/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.5108 - accuracy: 0.8217 - top-5-accuracy: 0.9937 - val_loss: 0.6344 - val_accuracy: 0.7846 - val_top-5-accuracy: 0.9855
+Epoch 42/100
+157/157 [==============================] - 10s 65ms/step - loss: 0.5056 - accuracy: 0.8220 - top-5-accuracy: 0.9936 - val_loss: 0.6723 - val_accuracy: 0.7744 - val_top-5-accuracy: 0.9851
+Epoch 43/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.4824 - accuracy: 0.8317 - top-5-accuracy: 0.9943 - val_loss: 0.6800 - val_accuracy: 0.7771 - val_top-5-accuracy: 0.9834
+Epoch 44/100
+157/157 [==============================] - 10s 67ms/step - loss: 0.4719 - accuracy: 0.8339 - top-5-accuracy: 0.9938 - val_loss: 0.6742 - val_accuracy: 0.7785 - val_top-5-accuracy: 0.9840
+Epoch 45/100
+157/157 [==============================] - 10s 65ms/step - loss: 0.4605 - accuracy: 0.8379 - top-5-accuracy: 0.9953 - val_loss: 0.6732 - val_accuracy: 0.7781 - val_top-5-accuracy: 0.9841
+Epoch 46/100
+157/157 [==============================] - 10s 66ms/step - loss: 0.4608 - accuracy: 0.8390 - top-5-accuracy: 0.9947 - val_loss: 0.6547 - val_accuracy: 0.7846 - val_top-5-accuracy: 0.9852
+TESTING
+40/40 [==============================] - 1s 22ms/step - loss: 0.6801 - accuracy: 0.7720 - top-5-accuracy: 0.9864
+Loss: 0.68
+Top 1 test accuracy: 77.20%
+Top 5 test accuracy: 98.64%
+
+```
+
+---
+## Conclusion
+
+The most impactful contribution of the paper is not the novel architecture, but
+the idea that hierarchical ViTs trained with no attention can perform quite well. This
+opens up the question of how essential attention is to the performance of ViTs.
+
+For curious minds, we would suggest reading the
+[ConvNeXt](https://arxiv.org/abs/2201.03545) paper, which focuses on the training
+paradigms and architectural details of ViTs rather than on a novel
+attention-based architecture.
+
+Acknowledgements:
+
+- We would like to thank [PyImageSearch](https://pyimagesearch.com) for providing us with
+resources that helped in the completion of this project.
+- We would like to thank [JarvisLabs.ai](https://jarvislabs.ai/) for providing the
+GPU credits.
+- We would like to thank [Manim Community](https://www.manim.community/) for the manim
+library.
+- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for
+helping us with the Learning Rate Schedule.
diff --git a/examples/vision/shiftvit.py b/examples/vision/shiftvit.py
new file mode 100644
index 0000000000..69434cae35
--- /dev/null
+++ b/examples/vision/shiftvit.py
@@ -0,0 +1,854 @@
+"""
+Title: A Vision Transformer without Attention
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+Date created: 2022/02/24
+Last modified: 2022/03/01
+Description: A minimal implementation of ShiftViT.
+"""
+"""
+## Introduction
+
+[Vision Transformers](https://arxiv.org/abs/2010.11929) (ViTs) have sparked a wave of
+research at the intersection of Transformers and Computer Vision (CV).
+
+ViTs can simultaneously model long- and short-range dependencies, thanks to
+the Multi-Head Self-Attention mechanism in the Transformer block. Many researchers believe
+that the success of ViTs is purely due to the attention layer, and they seldom
+think about other parts of the ViT model.
+
+In the academic paper
+[When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism](https://arxiv.org/abs/2201.10801)
+the authors propose to demystify the success of ViTs with the introduction of a **NO
+PARAMETER** operation in place of the attention operation. They swap the attention
+operation with a shifting operation.
+
+In this example, we minimally implement the paper with close alignment to the authors'
+[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
+
+This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can
+be installed using the following command:
+
+```shell
+pip install -qq -U tensorflow-addons
+```
+"""
+
+"""
+## Setup and imports
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+
+import tensorflow_addons as tfa
+
+# Setting seed for reproducibility
+SEED = 42
+keras.utils.set_random_seed(SEED)
+
+"""
+## Hyperparameters
+
+These are the hyperparameters that we have chosen for the experiment.
+Please feel free to tune them.
+"""
+
+
+class Config(object):
+ # DATA
+ batch_size = 256
+ buffer_size = batch_size * 2
+ input_shape = (32, 32, 3)
+ num_classes = 10
+
+ # AUGMENTATION
+ image_size = 48
+
+ # ARCHITECTURE
+ patch_size = 4
+ projected_dim = 96
+ num_shift_blocks_per_stages = [2, 4, 8, 2]
+ epsilon = 1e-5
+ stochastic_depth_rate = 0.2
+ mlp_dropout_rate = 0.2
+ num_div = 12
+ shift_pixel = 1
+ mlp_expand_ratio = 2
+
+ # OPTIMIZER
+ lr_start = 1e-5
+ lr_max = 1e-3
+ weight_decay = 1e-4
+
+ # TRAINING
+ epochs = 100
+
+
+config = Config()
+
+"""
+## Load the CIFAR-10 dataset
+
+We use the CIFAR-10 dataset for our experiments.
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_val, y_val) = (
+ (x_train[:40000], y_train[:40000]),
+ (x_train[40000:], y_train[40000:]),
+)
+print(f"Training samples: {len(x_train)}")
+print(f"Validation samples: {len(x_val)}")
+print(f"Testing samples: {len(x_test)}")
+
+AUTO = tf.data.AUTOTUNE
+train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_ds = train_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO)
+
+val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+val_ds = val_ds.batch(config.batch_size).prefetch(AUTO)
+
+test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
+
+"""
+## Data Augmentation
+
+The augmentation pipeline consists of:
+
+- Rescaling
+- Resizing
+- Random cropping
+- Random horizontal flipping
+
+_Note_: The image data augmentation layers do not apply
+data transformations at inference time. This means that
+when these layers are called with `training=False` they
+behave differently. Refer to the
+[documentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/)
+for more details.
+"""
+
+
+def get_augmentation_model():
+ """Build the data augmentation model."""
+ data_augmentation = keras.Sequential(
+ [
+ layers.Resizing(config.input_shape[0] + 20, config.input_shape[0] + 20),
+ layers.RandomCrop(config.image_size, config.image_size),
+ layers.RandomFlip("horizontal"),
+ layers.Rescaling(1 / 255.0),
+ ]
+ )
+ return data_augmentation
+
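+"""
+As a quick illustration of the note above (the names in this sketch are ours), the
+random transformations are only active when the pipeline is called with
+`training=True`:
+
+```python
+demo_augmentation = get_augmentation_model()
+sample_images = x_train[:1].astype("float32")
+
+# `training=True`: random crop and random flip are applied.
+augmented = demo_augmentation(sample_images, training=True)
+
+# `training=False`: `RandomCrop` center-crops, `RandomFlip` is a no-op, and only
+# the resizing and rescaling take effect.
+not_augmented = demo_augmentation(sample_images, training=False)
+```
+"""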
+
+"""
+## The ShiftViT architecture
+
+In this section, we build the architecture proposed in
+[the ShiftViT paper](https://arxiv.org/abs/2201.10801).
+
+|  |
+| :--: |
+| Figure 1: The entire architecture of ShiftViT.
+[Source](https://arxiv.org/abs/2201.10801) |
+
+The architecture, as shown in Fig. 1, is inspired by
+[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030).
+Here the authors propose a modular architecture with 4 stages. Each stage works on its
+own spatial size, creating a hierarchical architecture.
+
+An input image of size `HxWx3` is split into non-overlapping patches of size `4x4`.
+This is done via the patchify layer which results in individual tokens of feature size `48`
+(`4x4x3`). Each stage comprises two parts:
+
+1. Embedding Generation
+2. Stacked Shift Blocks
+
+We discuss the stages and the modules in detail in what follows.
+
+_Note_: Compared to the [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py)
+we restructure some key components to better fit the Keras API.
+"""
+
+"""
+### The ShiftViT Block
+
+|  |
+| :--: |
+| Figure 2: From the Model to a Shift Block. |
+
+Each stage in the ShiftViT architecture comprises a Shift Block, as shown in Fig. 2.
+
+|  |
+| :--: |
+| Figure 3: The Shift ViT Block. [Source](https://arxiv.org/abs/2201.10801) |
+
+The Shift Block, as shown in Fig. 3, comprises the following:
+
+1. Shift Operation
+2. Layer Normalization
+3. MLP Layer
+"""
+
+"""
+#### The MLP block
+
+The MLP block is intended to be a stack of densely-connected layers.
+"""
+
+
+class MLP(layers.Layer):
+ """Get the MLP layer for each shift block.
+
+ Args:
+ mlp_expand_ratio (int): The ratio with which the first feature map is expanded.
+ mlp_dropout_rate (float): The rate for dropout.
+ """
+
+ def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs):
+ super().__init__(**kwargs)
+ self.mlp_expand_ratio = mlp_expand_ratio
+ self.mlp_dropout_rate = mlp_dropout_rate
+
+ def build(self, input_shape):
+ input_channels = input_shape[-1]
+ initial_filters = int(self.mlp_expand_ratio * input_channels)
+
+ self.mlp = keras.Sequential(
+ [
+ layers.Dense(units=initial_filters, activation=tf.nn.gelu,),
+ layers.Dropout(rate=self.mlp_dropout_rate),
+ layers.Dense(units=input_channels),
+ layers.Dropout(rate=self.mlp_dropout_rate),
+ ]
+ )
+
+ def call(self, x):
+ x = self.mlp(x)
+ return x
+
+
+"""
+#### The DropPath layer
+
+Stochastic depth is a regularization technique that randomly drops a set of
+layers. During inference, the layers are kept as they are. It is very
+similar to Dropout, but it operates on a block of layers rather
+than on individual nodes present inside a layer.
+"""
+
+
+class DropPath(layers.Layer):
+ """Drop Path also known as the Stochastic Depth layer.
+
+ References:
+ - https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization
+ - https://github.com/rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path_prob, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path_prob = drop_path_prob
+
+ def call(self, x, training=False):
+ if training:
+ keep_prob = 1 - self.drop_path_prob
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
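+
+"""
+A small illustrative check of the layer (a sketch, not part of the model): at
+inference the input passes through unchanged, while during training entire samples
+are either zeroed out or scaled up by `1 / keep_prob`.
+
+```python
+drop_path = DropPath(drop_path_prob=0.5)
+x = tf.ones((4, 2, 2, 3))
+
+# At inference the layer is the identity function.
+assert tf.reduce_all(drop_path(x, training=False) == x)
+
+# During training each sample is either dropped (all zeros) or scaled by
+# 1 / keep_prob (here 2.0).
+y = drop_path(x, training=True)
+```
+"""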
+
+"""
+#### Block
+
+The most important operation in this paper is the **shift operation**. In this section,
+we describe the shift operation and compare it with its original implementation provided
+by the authors.
+
+A generic feature map is assumed to have the shape `[N, H, W, C]`. Here we choose a
+`num_div` parameter that decides the number of channel divisions. The first 4 divisions
+are shifted (by 1 pixel) in the left, right, up, and down directions. The remaining
+divisions are kept as is. After the partial shift, the shifted channels are zero-padded
+and the overflowing pixels are chopped off. This completes the partial shifting operation.
+
+In the original implementation, the code is approximately:
+
+```python
+out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:] # shift left
+out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1] # shift right
+out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :] # shift up
+out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :] # shift down
+
+out[:, g * 4:, :, :] = x[:, g * 4:, :, :] # no shift
+```
+
+TensorFlow tensors do not support item assignment, so we cannot replicate this
+in-place update during training. We therefore resort to the following procedure:
+
+1. Split the channels with the `num_div` parameter.
+2. Select each of the first four splits and shift and pad them in their respective
+directions.
+3. After shifting and padding, concatenate the channels back.
+
+|  |
+| :--: |
+| Figure 4: The TensorFlow style shifting |
+
+The entire procedure is illustrated in Fig. 4.
+"""
+
+
+class ShiftViTBlock(layers.Layer):
+ """A unit ShiftViT Block
+
+ Args:
+ shift_pixel (int): The number of pixels to shift. Defaults to 1.
+ mlp_expand_ratio (int): The ratio with which MLP features are
+ expanded. Defaults to 2.
+ mlp_dropout_rate (float): The dropout rate used in the MLP.
+ num_div (int): The number of divisions of the feature map's channels.
+ In total, 4/num_div of the channels will be shifted. Defaults to 12.
+ epsilon (float): Epsilon constant.
+ drop_path_prob (float): The drop probability for drop path.
+ """
+
+ def __init__(
+ self,
+ epsilon,
+ drop_path_prob,
+ mlp_dropout_rate,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.shift_pixel = shift_pixel
+ self.mlp_expand_ratio = mlp_expand_ratio
+ self.mlp_dropout_rate = mlp_dropout_rate
+ self.num_div = num_div
+ self.epsilon = epsilon
+ self.drop_path_prob = drop_path_prob
+
+ def build(self, input_shape):
+ self.H = input_shape[1]
+ self.W = input_shape[2]
+ self.C = input_shape[3]
+ self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
+ self.drop_path = (
+ DropPath(drop_path_prob=self.drop_path_prob)
+ if self.drop_path_prob > 0.0
+ else layers.Activation("linear")
+ )
+ self.mlp = MLP(
+ mlp_expand_ratio=self.mlp_expand_ratio,
+ mlp_dropout_rate=self.mlp_dropout_rate,
+ )
+
+ def get_shift_pad(self, x, mode):
+ """Shifts the channels according to the mode chosen."""
+ if mode == "left":
+ offset_height = 0
+ offset_width = 0
+ target_height = 0
+ target_width = self.shift_pixel
+ elif mode == "right":
+ offset_height = 0
+ offset_width = self.shift_pixel
+ target_height = 0
+ target_width = self.shift_pixel
+ elif mode == "up":
+ offset_height = 0
+ offset_width = 0
+ target_height = self.shift_pixel
+ target_width = 0
+ else:
+ offset_height = self.shift_pixel
+ offset_width = 0
+ target_height = self.shift_pixel
+ target_width = 0
+ crop = tf.image.crop_to_bounding_box(
+ x,
+ offset_height=offset_height,
+ offset_width=offset_width,
+ target_height=self.H - target_height,
+ target_width=self.W - target_width,
+ )
+ shift_pad = tf.image.pad_to_bounding_box(
+ crop,
+ offset_height=offset_height,
+ offset_width=offset_width,
+ target_height=self.H,
+ target_width=self.W,
+ )
+ return shift_pad
+
+ def call(self, x, training=False):
+ # Split the feature maps
+ x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)
+
+ # Shift the feature maps
+ x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
+ x_splits[1] = self.get_shift_pad(x_splits[1], mode="right")
+ x_splits[2] = self.get_shift_pad(x_splits[2], mode="up")
+ x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")
+
+ # Concatenate the shifted and unshifted feature maps
+ x = tf.concat(x_splits, axis=-1)
+
+ # Add the residual connection
+ shortcut = x
+ x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)
+ return x
+
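+
+"""
+A quick smoke test of the block (the values below are illustrative): shifting,
+normalization, and the MLP all preserve the shape of the input feature map.
+
+```python
+shift_block = ShiftViTBlock(epsilon=1e-5, drop_path_prob=0.1, mlp_dropout_rate=0.2)
+features = tf.ones((1, 12, 12, 96))
+print(shift_block(features).shape)  # (1, 12, 12, 96)
+```
+"""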
+
+"""
+### The ShiftViT blocks
+
+|  |
+| :--: |
+| Figure 5: Shift Blocks in the architecture. [Source](https://arxiv.org/abs/2201.10801) |
+
+Each stage of the architecture has shift blocks, as shown in Fig. 5. Each of these
+blocks contains a variable number of stacked ShiftViT blocks (as built in the earlier
+section).
+
+The shift blocks are followed by a PatchMerging layer that scales down the feature
+inputs. The PatchMerging layer gives the model its pyramidal structure.
+"""
+
+"""
+#### The PatchMerging layer
+
+This layer merges adjacent tokens: it scales the features down spatially while
+expanding them channel-wise. We use a strided Conv2D layer to merge the patches.
+"""
+
+
+class PatchMerging(layers.Layer):
+ """The Patch Merging layer.
+
+ Args:
+ epsilon (float): The epsilon constant.
+ """
+
+ def __init__(self, epsilon, **kwargs):
+ super().__init__(**kwargs)
+ self.epsilon = epsilon
+
+ def build(self, input_shape):
+ filters = 2 * input_shape[-1]
+ self.reduction = layers.Conv2D(
+ filters=filters, kernel_size=2, strides=2, padding="same", use_bias=False
+ )
+ self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
+
+ def call(self, x):
+ # Apply the patch merging algorithm on the feature maps
+ x = self.layer_norm(x)
+ x = self.reduction(x)
+ return x
+
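+
+"""
+A minimal shape check (with illustrative values): the layer halves the spatial
+resolution and doubles the number of channels.
+
+```python
+patch_merge = PatchMerging(epsilon=1e-5)
+print(patch_merge(tf.ones((1, 12, 12, 96))).shape)  # (1, 6, 6, 192)
+```
+"""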
+
+"""
+#### Stacked Shift Blocks
+
+Each stage will have a variable number of stacked ShiftViT blocks, as suggested in
+the paper. This is a generic layer that contains the stacked ShiftViT blocks
+together with the patch merging layer. Combining the two operations (ShiftViT
+block and patch merging) is a design choice we made for better code reusability.
+"""
+
+# Note: This layer will have a different depth of stacking
+# for different stages of the model.
+class StackedShiftBlocks(layers.Layer):
+ """The layer containing stacked ShiftViTBlocks.
+
+ Args:
+ epsilon (float): The epsilon constant.
+ mlp_dropout_rate (float): The dropout rate used in the MLP block.
+ num_shift_blocks (int): The number of shift vit blocks for this stage.
+ stochastic_depth_rate (float): The maximum drop path rate chosen.
+ is_merge (boolean): A flag that determines the use of the Patch Merge
+ layer after the shift vit blocks.
+ num_div (int): The division of channels of the feature map. Defaults to 12.
+ shift_pixel (int): The number of pixels to shift. Defaults to 1.
+ mlp_expand_ratio (int): The ratio by which the initial dense layer of
+ the MLP is expanded. Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ epsilon,
+ mlp_dropout_rate,
+ num_shift_blocks,
+ stochastic_depth_rate,
+ is_merge,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.epsilon = epsilon
+ self.mlp_dropout_rate = mlp_dropout_rate
+ self.num_shift_blocks = num_shift_blocks
+ self.stochastic_depth_rate = stochastic_depth_rate
+ self.is_merge = is_merge
+ self.num_div = num_div
+ self.shift_pixel = shift_pixel
+ self.mlp_expand_ratio = mlp_expand_ratio
+
+ def build(self, input_shapes):
+ # Calculate stochastic depth probabilities.
+ # Reference: https://keras.io/examples/vision/cct/#the-final-cct-model
+ dpr = [
+ x
+ for x in np.linspace(
+ start=0, stop=self.stochastic_depth_rate, num=self.num_shift_blocks
+ )
+ ]
+
+ # Build the shift blocks as a list of ShiftViT Blocks
+ self.shift_blocks = list()
+ for num in range(self.num_shift_blocks):
+ self.shift_blocks.append(
+ ShiftViTBlock(
+ num_div=self.num_div,
+ epsilon=self.epsilon,
+ drop_path_prob=dpr[num],
+ mlp_dropout_rate=self.mlp_dropout_rate,
+ shift_pixel=self.shift_pixel,
+ mlp_expand_ratio=self.mlp_expand_ratio,
+ )
+ )
+ if self.is_merge:
+ self.patch_merge = PatchMerging(epsilon=self.epsilon)
+
+ def call(self, x, training=False):
+ for shift_block in self.shift_blocks:
+ x = shift_block(x, training=training)
+ if self.is_merge:
+ x = self.patch_merge(x)
+ return x
+
+
+"""
+## The ShiftViT model
+
+Build the ShiftViT custom model.
+"""
+
+
+class ShiftViTModel(keras.Model):
+ """The ShiftViT Model.
+
+ Args:
+ data_augmentation (keras.Model): A data augmentation model.
+ projected_dim (int): The dimension to which the patches of the image are
+ projected.
+ patch_size (int): The patch size of the images.
+ num_shift_blocks_per_stages (list[int]): A list of the number of shift
+ blocks per stage.
+ epsilon (float): The epsilon constant.
+ mlp_dropout_rate (float): The dropout rate used in the MLP block.
+ stochastic_depth_rate (float): The maximum drop rate probability.
+ num_div (int): The number of divisions of the channels of the feature
+ map. Defaults to 12.
+ shift_pixel (int): The number of pixels to shift. Defaults to 1.
+ mlp_expand_ratio (int): The ratio by which the initial dense layer of
+ the MLP is expanded. Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ data_augmentation,
+ projected_dim,
+ patch_size,
+ num_shift_blocks_per_stages,
+ epsilon,
+ mlp_dropout_rate,
+ stochastic_depth_rate,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.data_augmentation = data_augmentation
+ self.patch_projection = layers.Conv2D(
+ filters=projected_dim,
+ kernel_size=patch_size,
+ strides=patch_size,
+ padding="same",
+ )
+ self.stages = list()
+ for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
+ if index == len(num_shift_blocks_per_stages) - 1:
+ # This is the last stage, do not use the patch merge here.
+ is_merge = False
+ else:
+ is_merge = True
+ # Build the stages.
+ self.stages.append(
+ StackedShiftBlocks(
+ epsilon=epsilon,
+ mlp_dropout_rate=mlp_dropout_rate,
+ num_shift_blocks=num_shift_blocks,
+ stochastic_depth_rate=stochastic_depth_rate,
+ is_merge=is_merge,
+ num_div=num_div,
+ shift_pixel=shift_pixel,
+ mlp_expand_ratio=mlp_expand_ratio,
+ )
+ )
+ self.global_avg_pool = layers.GlobalAveragePooling2D()
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "data_augmentation": self.data_augmentation,
+ "patch_projection": self.patch_projection,
+ "stages": self.stages,
+ "global_avg_pool": self.global_avg_pool,
+ }
+ )
+ return config
+
+ def _calculate_loss(self, data, training=False):
+ (images, labels) = data
+
+ # Augment the images
+ augmented_images = self.data_augmentation(images, training=training)
+
+ # Create patches and project the patches.
+ projected_patches = self.patch_projection(augmented_images)
+
+ # Pass through the stages
+ x = projected_patches
+ for stage in self.stages:
+ x = stage(x, training=training)
+
+ # Get the logits.
+ logits = self.global_avg_pool(x)
+
+ # Calculate the loss and return it.
+ total_loss = self.compiled_loss(labels, logits)
+ return total_loss, labels, logits
+
+ def train_step(self, inputs):
+ with tf.GradientTape() as tape:
+ total_loss, labels, logits = self._calculate_loss(
+ data=inputs, training=True
+ )
+
+ # Collect the trainable variables of each component of the model.
+ train_vars = [
+ self.data_augmentation.trainable_variables,
+ self.patch_projection.trainable_variables,
+ self.global_avg_pool.trainable_variables,
+ ]
+ train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
+
+ # Get the gradients with respect to each list of variables and
+ # flatten the (gradient, variable) pairs before applying them.
+ grads = tape.gradient(total_loss, train_vars)
+ trainable_variable_list = []
+ for (grad, var) in zip(grads, train_vars):
+ for g, v in zip(grad, var):
+ trainable_variable_list.append((g, v))
+ self.optimizer.apply_gradients(trainable_variable_list)
+
+ # Update the metrics
+ self.compiled_metrics.update_state(labels, logits)
+ return {m.name: m.result() for m in self.metrics}
+
+ def test_step(self, data):
+ _, labels, logits = self._calculate_loss(data=data, training=False)
+
+ # Update the metrics
+ self.compiled_metrics.update_state(labels, logits)
+ return {m.name: m.result() for m in self.metrics}
+
+
+"""
+## Instantiate the model
+"""
+
+model = ShiftViTModel(
+ data_augmentation=get_augmentation_model(),
+ projected_dim=config.projected_dim,
+ patch_size=config.patch_size,
+ num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
+ epsilon=config.epsilon,
+ mlp_dropout_rate=config.mlp_dropout_rate,
+ stochastic_depth_rate=config.stochastic_depth_rate,
+ num_div=config.num_div,
+ shift_pixel=config.shift_pixel,
+ mlp_expand_ratio=config.mlp_expand_ratio,
+)
+
+"""
+## Learning rate schedule
+
+In many experiments, we want to warm up the model with a slowly increasing learning rate
+and then cool down the model with a slowly decaying learning rate. In the warmup cosine
+decay, the learning rate linearly increases for the warmup steps and then decays with a
+cosine decay.
+"""
+
+
+class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
+ """A LearningRateSchedule that uses a warmup cosine decay schedule."""
+
+ def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
+ """
+ Args:
+ lr_start: The initial learning rate
+ lr_max: The maximum learning rate to which lr should increase to in
+ the warmup steps
+ warmup_steps: The number of steps for which the model warms up
+ total_steps: The total number of steps for the model training
+ """
+ super().__init__()
+ self.lr_start = lr_start
+ self.lr_max = lr_max
+ self.warmup_steps = warmup_steps
+ self.total_steps = total_steps
+ self.pi = tf.constant(np.pi)
+
+ def __call__(self, step):
+ # Check whether the total number of steps is larger than the warmup
+ # steps. If not, then throw a value error.
+ if self.total_steps < self.warmup_steps:
+ raise ValueError(
+ f"Total number of steps {self.total_steps} must be"
+ + f"larger or equal to warmup steps {self.warmup_steps}."
+ )
+
+ # `cos_annealed_lr` is a graph that increases to 1 from the initial
+ # step to the warmup step. After that this graph decays to -1 at the
+ # final step mark.
+ cos_annealed_lr = tf.cos(
+ self.pi
+ * (tf.cast(step, tf.float32) - self.warmup_steps)
+ / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
+ )
+
+ # Shift the `cos_annealed_lr` graph so that it goes from 0 to 2, scale
+ # it by 0.5 so that it goes from 0 to 1, and then scale it by `lr_max`
+ # so that the learning rate goes from 0 to `lr_max`.
+ learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)
+
+ # Check whether warmup_steps is more than 0.
+ if self.warmup_steps > 0:
+ # Check whether lr_max is larger than lr_start. If not, throw a value
+ # error.
+ if self.lr_max < self.lr_start:
+ raise ValueError(
+ f"lr_start {self.lr_start} must be smaller or"
+ + f"equal to lr_max {self.lr_max}."
+ )
+
+ # Calculate the slope with which the learning rate should increase
+ # in the warmup schedule. The formula for slope is m = ((b-a)/steps).
+ slope = (self.lr_max - self.lr_start) / self.warmup_steps
+
+ # With the formula for a straight line (y = mx+c) build the warmup
+ # schedule
+ warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
+
+ # When the current step is less than the warmup steps, use the line
+ # graph. When the current step is greater than the warmup steps, use
+ # the scaled cosine graph.
+ learning_rate = tf.where(
+ step < self.warmup_steps, warmup_rate, learning_rate
+ )
+
+ # When the current step is more than the total steps, return 0, else
+ # return the calculated graph.
+ return tf.where(
+ step > self.total_steps, 0.0, learning_rate, name="learning_rate"
+ )
+
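+"""
+To sanity-check the schedule before using it, we can sample it over a small toy
+horizon. This is a minimal sketch: `demo_schedule` and the step counts below are
+illustrative and are not the values used for training later.
+
+```python
+demo_schedule = WarmUpCosine(lr_start=1e-5, lr_max=1e-3, warmup_steps=15, total_steps=100)
+lrs = [demo_schedule(step).numpy() for step in range(100)]
+plt.plot(lrs)
+plt.xlabel("Step")
+plt.ylabel("Learning rate")
+plt.show()
+```
+"""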
+
+"""
+## Compile and train the model
+"""
+
+# Get the total number of steps for training.
+total_steps = int((len(x_train) / config.batch_size) * config.epochs)
+
+# Calculate the number of steps for warmup.
+warmup_epoch_percentage = 0.15
+warmup_steps = int(total_steps * warmup_epoch_percentage)
+
+# Initialize the warmup cosine schedule.
+scheduled_lrs = WarmUpCosine(
+ lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps,
+)
+
+# Get the optimizer.
+optimizer = tfa.optimizers.AdamW(
+ learning_rate=scheduled_lrs, weight_decay=config.weight_decay
+)
+
+# Compile and pretrain the model.
+model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+)
+
+# Train the model
+history = model.fit(
+ train_ds,
+ epochs=config.epochs,
+ validation_data=val_ds,
+ callbacks=[
+ keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5, mode="auto",)
+ ],
+)
+
+# Evaluate the model with the test dataset.
+print("TESTING")
+loss, acc_top1, acc_top5 = model.evaluate(test_ds)
+print(f"Loss: {loss:0.2f}")
+print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
+print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
+
+"""
+## Conclusion
+
+The most impactful contribution of the paper is not the novel architecture, but
+the idea that hierarchical ViTs trained with no attention can perform quite well. This
+opens up the question of how essential attention is to the performance of ViTs.
+
+For curious minds, we would suggest reading the
+[ConvNeXt](https://arxiv.org/abs/2201.03545) paper, which focuses on the training
+paradigms and architectural details of ViTs rather than on a novel
+attention-based architecture.
+
+Acknowledgements:
+
+- We would like to thank [PyImageSearch](https://pyimagesearch.com) for providing us with
+resources that helped in the completion of this project.
+- We would like to thank [JarvisLabs.ai](https://jarvislabs.ai/) for providing the
+GPU credits.
+- We would like to thank [Manim Community](https://www.manim.community/) for the manim
+library.
+- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for
+helping us with the Learning Rate Schedule.
+"""