# Image classification with Swin Transformers

**Author:** [Rishit Dagli](https://twitter.com/rishit_dagli)<br>
**Date created:** 2021/09/08<br>
**Last modified:** 2021/09/08<br>
**Description:** Image classification using Swin Transformers, a general-purpose backbone for computer vision.

This example implements [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
by Liu et al. for image classification, and demonstrates it on the
[CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).

Swin Transformer (**S**hifted **Win**dow Transformer) can serve as a general-purpose backbone
for computer vision. Swin Transformer is a hierarchical Transformer whose
representations are computed with _shifted windows_. The shifted window scheme
brings greater efficiency by limiting self-attention computation to
non-overlapping local windows while also allowing for cross-window connections.
This architecture has the flexibility to model information at various scales and has
a linear computational complexity with respect to image size.

This example requires TensorFlow 2.5 or higher, as well as TensorFlow Addons,
which can be installed using the following commands:

# 新段落

## Setup

In [50]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
import os
from sklearn import model_selection
from PIL import Image
from tensorflow.keras.utils import Sequence


In [2]:
batch_size = 100
num_classes = 2
# # root = "C:/Users/wuwul/Desktop/f/f"
# from keras.preprocessing.image import ImageDataGenerator
# def read_image(image_name):
#     im = Image.open(image_name)
# # #     .convert('L')
#     data = np.array(im)
#     return data[:,:,:3]

# # Generator = ImageDataGenerator()
# # train_data = Generator.flow_from_directory(train_root, (100, 100), batch_size=batch_size)
# # test_data = Generator.flow_from_directory(test_root, (100, 100), batch_size=batch_size)
# # print(train_data)
# images = []
# labels = []
# test = os.listdir("C:/Users/wuwul/Desktop/f/f")
# print(test)
# for testpath in test:
#     for fn in os.listdir(os.path.join("C:/Users/wuwul/Desktop/f/f",testpath)):
#         if fn.endswith('.png'):
#             fd = os.path.join("C:/Users/wuwul/Desktop/f/f",testpath,fn)
# # #             print(fd)
#             images.append(read_image(fd))
#             labels.append(testpath)
# X = np.array(images)
# Y = np.array(list(map(int,labels)))


# X_train,X_test,Y_train,Y_test = model_selection.train_test_split(X,Y,test_size = 0.3,random_state = 0)
# X_train = X_train.astype(np.float32)
# Y_train = Y_train.astype(np.float32)

    # #optional
# print(train_data[0][0][0].shape)
# # total 4317 data below to 5 clasess
# print(len(train_data)) #4317/batch size
# print(len(train_data[0])) #2, 1st image, 2nd is label
# #print(train_data[0])
# print(len(train_data[0][0])) #1st batch of 10 data
# print(len(train_data[0][0][0])) #the image, the vertical
# print(len(train_data[0][0][0][0])) #the image, the horizontal
# print(len(train_data[0][0][0][0][0])) #the image, RGB



In [3]:
train_root = "C:/Users/wuwul/Desktop/Face Mask Dataset 2/Train"
test_root = "C:/Users/wuwul/Desktop/Face Mask Dataset 2/Test"

In [4]:
from keras.preprocessing.image import ImageDataGenerator

Generator = ImageDataGenerator()
train_data = Generator.flow_from_directory(train_root, (224, 224), batch_size=100,classes = ['withmask','withoutmask'],seed=0)
test_data =  Generator.flow_from_directory(test_root, (224, 224), batch_size=100,classes = ['withmask','withoutmask'],seed=0)

Found 10991 images belonging to 2 classes.
Found 1395 images belonging to 2 classes.


In [5]:
m = train_data.next()

In [6]:
m[0]

array([[[[217., 213., 222.],
         [208., 199., 204.],
         [204., 190., 199.],
         ...,
         [236., 222., 236.],
         [236., 222., 236.],
         [236., 222., 236.]],

        [[217., 213., 222.],
         [213., 199., 208.],
         [204., 190., 199.],
         ...,
         [236., 222., 236.],
         [236., 222., 236.],
         [236., 222., 236.]],

        [[217., 208., 217.],
         [213., 199., 204.],
         [204., 185., 190.],
         ...,
         [236., 217., 236.],
         [231., 217., 236.],
         [231., 217., 231.]],

        ...,

        [[148., 111.,  83.],
         [148., 111.,  83.],
         [148., 111.,  83.],
         ...,
         [ 69.,  55.,  69.],
         [ 69.,  51.,  64.],
         [ 69.,  51.,  64.]],

        [[148., 111.,  83.],
         [148., 111.,  83.],
         [143., 106.,  83.],
         ...,
         [ 74.,  55.,  69.],
         [ 69.,  51.,  64.],
         [ 69.,  51.,  64.]],

        [[139., 102.,  74.],
       

## Prepare the data

We load the CIFAR-100 dataset through `tf.keras.datasets`,
normalize the images, and convert the integer labels to one-hot encoded vectors.

In [40]:
num_classes = 2
input_shape = (224,224,3)

# # # (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
# X_train, X_test = X_train / 255.0, X_test / 255.0
# Y_train = keras.utils.to_categorical(Y_train, num_classes)
# Y_test = keras.utils.to_categorical(Y_test, num_classes)
# # print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
# # print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

# plt.figure(figsize=(10, 10))
# for i in range(25):
#     plt.subplot(5, 5, i + 1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(X_train[i])
# plt.show()

In [41]:
train_data[2]

(array([[[[183., 186., 184.],
          [184., 187., 184.],
          [188., 191., 188.],
          ...,
          [180., 174., 172.],
          [173., 169., 165.],
          [172., 167., 164.]],
 
         [[183., 186., 184.],
          [184., 186., 184.],
          [187., 190., 187.],
          ...,
          [178., 174., 170.],
          [173., 168., 165.],
          [172., 168., 164.]],
 
         [[183., 186., 184.],
          [183., 186., 183.],
          [182., 185., 182.],
          ...,
          [168., 164., 160.],
          [173., 167., 165.],
          [173., 167., 165.]],
 
         ...,
 
         [[109., 138., 128.],
          [112., 141., 130.],
          [132., 162., 149.],
          ...,
          [ 61.,  61.,  61.],
          [120., 123., 121.],
          [130., 133., 131.]],
 
         [[125., 150., 142.],
          [128., 153., 145.],
          [145., 170., 160.],
          ...,
          [ 79.,  82.,  77.],
          [121., 121., 116.],
          [127., 128., 123.

In [36]:
# X_train.shape

## Configure the hyperparameters

A key parameter to pick is the `patch_size`, the size of the input patches.
In order to use each pixel as an individual input, you can set `patch_size` to `(1, 1)`.
Below, we take inspiration from the original paper settings
for training on ImageNet-1K, keeping most of the original settings for this example.

In [10]:
# Y_train.shape

In [11]:
num_classes = 2
input_shape = (224,224,3)

patch_size = (2, 2)  # 2-by-2 sized patches
dropout_rate = 0.03  # Dropout rate
num_heads = 8  # Attention heads
embed_dim = 64  # Embedding dimension
num_mlp = 256  # MLP layer size
qkv_bias = True  # Convert embedded patches to query, key, and values with a learnable additive value
window_size = 2  # Size of attention window
shift_size = 1  # Size of shifting window
image_dimension = 224  # Initial image size

num_patch_x = input_shape[0] // patch_size[0]
num_patch_y = input_shape[1] // patch_size[1]

learning_rate = 1e-3

num_epochs = 4
validation_split = 0.1
weight_decay = 0.0001
label_smoothing = 0.1

## Helper functions

We create two helper functions to help us get a sequence of
patches from the image, merge patches, and apply dropout.

In [12]:

def window_partition(x, window_size):
    _, height, width, channels = x.shape
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = tf.reshape(
        x, shape=(-1, patch_num_y, window_size, patch_num_x, window_size, channels)
    )
    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
    windows = tf.reshape(x, shape=(-1, window_size, window_size, channels))
    return windows


def window_reverse(windows, window_size, height, width, channels):
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = tf.reshape(
        windows,
        shape=(-1, patch_num_y, patch_num_x, window_size, window_size, channels),
    )
    x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))
    x = tf.reshape(x, shape=(-1, height, width, channels))
    return x


class DropPath(layers.Layer):
    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x):
        input_shape = tf.shape(x)
        batch_size = input_shape[0]
        rank = x.shape.rank
        shape = (batch_size,) + (1,) * (rank - 1)
        random_tensor = (1 - self.drop_prob) + tf.random.uniform(shape, dtype=x.dtype)
        path_mask = tf.floor(random_tensor)
        output = tf.math.divide(x, 1 - self.drop_prob) * path_mask
        return output


## Window based multi-head self-attention

Usually Transformers perform global self-attention, where the relationships between
a token and all other tokens are computed. The global computation leads to quadratic
complexity with respect to the number of tokens. Here, as the [original paper](https://arxiv.org/abs/2103.14030)
suggests, we compute self-attention within local windows, in a non-overlapping manner.
Global self-attention leads to quadratic computational complexity in the number of patches,
whereas window-based self-attention leads to linear complexity and is easily scalable.

In [13]:

class WindowAttention(layers.Layer):
    def __init__(
        self, dim, window_size, num_heads, qkv_bias=True, dropout_rate=0.0, **kwargs
    ):
        super(WindowAttention, self).__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

    def build(self, input_shape):
        num_window_elements = (2 * self.window_size[0] - 1) * (
            2 * self.window_size[1] - 1
        )
        self.relative_position_bias_table = self.add_weight(
            shape=(num_window_elements, self.num_heads),
            initializer=tf.initializers.Zeros(),
            trainable=True,
        )
        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
        coords = np.stack(coords_matrix)
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)

        self.relative_position_index = tf.Variable(
            initial_value=tf.convert_to_tensor(relative_position_index), trainable=False
        )

    def call(self, x, mask=None):
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
        x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        attn = q @ k

        num_window_elements = self.window_size[0] * self.window_size[1]
        relative_position_index_flat = tf.reshape(
            self.relative_position_index, shape=(-1,)
        )
        relative_position_bias = tf.gather(
            self.relative_position_bias_table, relative_position_index_flat
        )
        relative_position_bias = tf.reshape(
            relative_position_bias, shape=(num_window_elements, num_window_elements, -1)
        )
        relative_position_bias = tf.transpose(relative_position_bias, perm=(2, 0, 1))
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.get_shape()[0]
            mask_float = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32
            )
            attn = (
                tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
                + mask_float
            )
            attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))
            attn = keras.activations.softmax(attn, axis=-1)
        else:
            attn = keras.activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        return x_qkv


## The complete Swin Transformer model

Finally, we put together the complete Swin Transformer by replacing the standard multi-head
attention (MHA) with shifted windows attention. As suggested in the
original paper, we create a model comprising of a shifted window-based MHA
layer, followed by a 2-layer MLP with GELU nonlinearity in between, applying
`LayerNormalization` before each MSA layer and each MLP, and a residual
connection after each of these layers.

Notice that we only create a simple MLP with 2 Dense and
2 Dropout layers. Often you will see models using ResNet-50 as the MLP which is
quite standard in the literature. However in this paper the authors use a
2-layer MLP with GELU nonlinearity in between.

In [14]:

class SwinTransformer(layers.Layer):
    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super(SwinTransformer, self).__init__(**kwargs)

        self.dim = dim  # number of input dimensions
        self.num_patch = num_patch  # number of embedded patches
        self.num_heads = num_heads  # number of attention heads
        self.window_size = window_size  # size of window
        self.shift_size = shift_size  # size of window shift
        self.num_mlp = num_mlp  # number of MLP nodes

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = DropPath(dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

    def build(self, input_shape):
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = tf.convert_to_tensor(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size * self.window_size]
            )
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)

    def call(self, x):
        height, width = self.num_patch
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        x = x_skip + x
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = x_skip + x
        return x


## Model training and evaluation

### Extract and embed patches

We first create 3 layers to help us extract, embed and merge patches from the
images on top of which we will later use the Swin Transformer class we built.

In [15]:

class PatchExtract(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super(PatchExtract, self).__init__(**kwargs)
        self.patch_size_x = patch_size[0]
        self.patch_size_y = patch_size[0]

    def call(self, images):
        batch_size = tf.shape(images)[0]
        print(batch_size)
        patches = tf.image.extract_patches(
            images=images,
            sizes=(1, self.patch_size_x, self.patch_size_y, 1),
            strides=(1, self.patch_size_x, self.patch_size_y, 1),
            rates=(1, 1, 1, 1),
            padding="VALID",
        )
        patch_dim = patches.shape[-1]
        patch_num = patches.shape[1]
        return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))


class PatchEmbedding(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super(PatchEmbedding, self).__init__(**kwargs)
        self.num_patch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        pos = tf.range(start=0, limit=self.num_patch, delta=1)
        return self.proj(patch) + self.pos_embed(pos)


class PatchMerging(tf.keras.layers.Layer):
    def __init__(self, num_patch, embed_dim):
        super(PatchMerging, self).__init__()
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)

    def call(self, x):
        height, width = self.num_patch
        _, _, C = x.get_shape().as_list()
        x = tf.reshape(x, shape=(-1, height, width, C))
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = tf.concat((x0, x1, x2, x3), axis=-1)
        x = tf.reshape(x, shape=(-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x)


### Build the model

We put together the Swin Transformer model.

In [16]:
input = layers.Input(input_shape)
x = layers.RandomCrop(image_dimension, image_dimension)(input)
x = layers.RandomFlip("horizontal")(x)
x = PatchExtract(patch_size)(x)
x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(x)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=0,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=shift_size,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)
x = layers.GlobalAveragePooling1D()(x)
output = layers.Dense(2, activation="softmax")(x)

Tensor("patch_extract/strided_slice:0", shape=(), dtype=int32)


### Train on CIFAR-100

We train the model on CIFAR-100. Here, we only train the model
for 40 epochs to keep the training time short in this example.
In practice, you should train for 150 epochs to reach convergence.

In [46]:
from keras import backend as K

In [47]:
K.set_floatx('float64')

In [48]:
input_data = np.random.rand(1,10,10,1)
input_data

array([[[[0.75195356],
         [0.06166533],
         [0.74482367],
         [0.94627593],
         [0.60355955],
         [0.28757994],
         [0.67236922],
         [0.71204879],
         [0.65645029],
         [0.14693032]],

        [[0.97347557],
         [0.95538345],
         [0.42462553],
         [0.59363733],
         [0.03962857],
         [0.98863433],
         [0.81874499],
         [0.63650234],
         [0.76108474],
         [0.18802993]],

        [[0.30765459],
         [0.24639364],
         [0.59604912],
         [0.09190485],
         [0.89560999],
         [0.46227594],
         [0.44481236],
         [0.1047113 ],
         [0.68490404],
         [0.81689399]],

        [[0.62955017],
         [0.24202261],
         [0.78542072],
         [0.14567943],
         [0.82727633],
         [0.58070537],
         [0.28937281],
         [0.5132434 ],
         [0.6288514 ],
         [0.2585895 ]],

        [[0.84690656],
         [0.42125371],
         [0.89233746],
   

In [49]:
tf.convert_to_tensor(np.array([read_image(fd).tolist(),read_image(fd).tolist()]))

NameError: name 'read_image' is not defined

In [21]:
from keras.models import Sequential

In [22]:
input.shape

TensorShape([None, 224, 224, 3])

In [43]:
train_data[0]

(array([[[[217., 213., 222.],
          [208., 199., 204.],
          [204., 190., 199.],
          ...,
          [236., 222., 236.],
          [236., 222., 236.],
          [236., 222., 236.]],
 
         [[217., 213., 222.],
          [213., 199., 208.],
          [204., 190., 199.],
          ...,
          [236., 222., 236.],
          [236., 222., 236.],
          [236., 222., 236.]],
 
         [[217., 208., 217.],
          [213., 199., 204.],
          [204., 185., 190.],
          ...,
          [236., 217., 236.],
          [231., 217., 236.],
          [231., 217., 231.]],
 
         ...,
 
         [[148., 111.,  83.],
          [148., 111.,  83.],
          [148., 111.,  83.],
          ...,
          [ 69.,  55.,  69.],
          [ 69.,  51.,  64.],
          [ 69.,  51.,  64.]],
 
         [[148., 111.,  83.],
          [148., 111.,  83.],
          [143., 106.,  83.],
          ...,
          [ 74.,  55.,  69.],
          [ 69.,  51.,  64.],
          [ 69.,  51.,  64.

In [23]:
input

<KerasTensor: shape=(None, 224, 224, 3) dtype=float32 (created by layer 'input_1')>

In [24]:
output

<KerasTensor: shape=(None, 2) dtype=float32 (created by layer 'dense_10')>

In [27]:

model = keras.Model(input, output)
model = Sequential()

In [28]:

model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

# history = model.fit(
#     train_
# #     batch_size=10,
# #     epochs=num_epochs,
# #     validation_split=validation_split,
# )


In [44]:
model.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________


In [45]:
model.fit?

In [30]:
x,y = next(train_data)

In [31]:
y.shape

(100, 2)

In [32]:
x.shape

(100, 224, 224, 3)

In [33]:
model.fit(x,y, epochs=10)

Epoch 1/10


ValueError: in user code:

    File "C:\Users\wuwul\python\lib\site-packages\keras\engine\training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "C:\Users\wuwul\python\lib\site-packages\keras\engine\training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\wuwul\python\lib\site-packages\keras\engine\training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "C:\Users\wuwul\python\lib\site-packages\keras\engine\training.py", line 860, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "C:\Users\wuwul\python\lib\site-packages\keras\engine\training.py", line 918, in compute_loss
        return self.compiled_loss(
    File "C:\Users\wuwul\python\lib\site-packages\keras\engine\compile_utils.py", line 201, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "C:\Users\wuwul\python\lib\site-packages\keras\losses.py", line 141, in __call__
        losses = call_fn(y_true, y_pred)
    File "C:\Users\wuwul\python\lib\site-packages\keras\losses.py", line 245, in call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "C:\Users\wuwul\python\lib\site-packages\keras\losses.py", line 1789, in categorical_crossentropy
        return backend.categorical_crossentropy(
    File "C:\Users\wuwul\python\lib\site-packages\keras\backend.py", line 5083, in categorical_crossentropy
        target.shape.assert_is_compatible_with(output.shape)

    ValueError: Shapes (None, 2) and (None, 224, 224, 3) are incompatible


Let's visualize the training progress of the model.

In [None]:
plt.plot(history.history["loss"], label="train_loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid()
plt.show()

Let's display the final results of the training on CIFAR-100.

In [None]:
loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {round(loss, 2)}")
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

The Swin Transformer model we just trained has just 152K parameters, and it gets
us to ~75% test top-5 accuracy within just 40 epochs without any signs of overfitting
as well as seen in above graph. This means we can train this network for longer
(perhaps with a bit more regularization) and obtain even better performance.
This performance can further be improved by additional techniques like cosine
decay learning rate schedule, other data augmentation techniques. While experimenting,
I tried training the model for 150 epochs with a slightly higher dropout and greater
embedding dimensions which pushes the performance to ~72% test accuracy on CIFAR-100
as you can see in the screenshot.

![Results of training for longer](https://i.imgur.com/9vnQesZ.png)

The authors present a top-1 accuracy of 87.3% on ImageNet. The authors also present
a number of experiments to study how input sizes, optimizers etc. affect the final
performance of this model. The authors further present using this model for object detection,
semantic segmentation and instance segmentation as well and report competitive results
for these. You are strongly advised to also check out the
[original paper](https://arxiv.org/abs/2103.14030).

This example takes inspiration from the official
[PyTorch](https://github.com/microsoft/Swin-Transformer) and
[TensorFlow](https://github.com/VcampSoldiers/Swin-Transformer-Tensorflow) implementations.