In [3]:
import math
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from tensorflow.keras import layers
import os
import cv2
import platform
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from IPython.display import clear_output

# # Setting seed for reproducibility
# SEED = 42
# keras.utils.set_random_seed(SEED)

In [4]:
import math
import os
import numpy as np
import pandas as pd
from PIL import Image


# Label files as (train, test) pairs, keyed by the number of classes.
class_label = {
    3: ('train_COVIDx9A.txt', 'test_COVIDx9A.txt'),
    2: ('train.txt', 'test.txt'),
}

# Default data/label folder for every supported platform.
platform_path = {
    'Local': './data/',
    'Kaggle': '/kaggle/input/covidx-cxr2/',
    'Colab': './content/',
}

# Column names assigned to the (header-less) label files.
label_colnames = [
    'patient_id',
    'filename',
    'class',
    'data_source',
]

# Integer encoding applied to the textual class labels.
label_encode = {
    'positive': 1,
    'negative': 0,
    'COVID-19': 2,
    'normal': 0,
}


class DataLoader():
    """
    The DataLoader class handles loading image data from respective folders regardless of platform while performing
    actions such as bootstrapping and sampling.

    Attributes:
        platform (str): Name of the platform, options are ["Local", "Kaggle", "Colab"]
        n_classes (int): Number of classification classes, options are [2, 3]
        data_dir (str): Folder directory containing the images, must contain the ./train/ and ./test/ folder
        txt_dir (str): Folder directory containing the label files such as "train_COVIDx9A.txt"; when not given it
            falls back to the platform's default folder
        img_size (int): The final output image size after cropping and resizing
        combined (bool): Determines whether to combine original data's train and test sets for custom test split
        channels (str): Determines whether to output RGB or Greyscale (L) images
    """
    def __init__(self, platform: str = 'Local', n_classes: int = 2, data_dir: str = None, txt_dir: str = None,
                 img_size: int = 224, combined: bool = True, channels: str = 'RGB'):
        """
        Validate the configuration, resolve the folders and print a summary.

        Raises:
            AssertionError: If platform, n_classes or channels is not one of the supported options
            FileNotFoundError: If one of the label files for n_classes is missing from txt_dir
        """
        # data validation checks
        assert platform in platform_path, 'Platforms must be in ["Local", "Kaggle", "Colab"]'
        self.platform = platform

        assert n_classes in [2, 3], 'n_classes must be in [2, 3]'
        self.n_classes = n_classes

        # if directories not specified, take the platform default
        self.data_dir = platform_path.get(platform) if data_dir is None else data_dir
        self.txt_dir = platform_path.get(platform) if txt_dir is None else txt_dir

        self.img_size = img_size
        self.combined = combined

        assert channels in ['L', 'RGB'], "channels must be in ['L', 'RGB']"
        self.channels = channels

        # both the train and the test label file must be present in txt_dir
        for label_file in class_label.get(self.n_classes):
            if not os.path.exists(os.path.join(self.txt_dir, label_file)):
                raise FileNotFoundError(label_file + ' not found in ' + self.txt_dir)

        # print a summary message
        print(f'Platform: {self.platform}\nNum Classes: {self.n_classes}')
        print(f'Data Folder: {self.data_dir}\nLabel Folder: {self.txt_dir}')
        print(f'Image size: {self.img_size}\nCombined: {self.combined}\nImage Channels: {self.channels}')


    def __crop_image(self, img: Image.Image) -> Image.Image:
        """
        Crop and resize image to a square
        The image is cropped at the centre using the shorter length.
        It's then resized to the proposed dimensions.

        Args:
            img (Image): a PIL.Image object to be resized

        Returns:
            image_final (Image): a PIL.Image object that's resized
        """
        width, height = img.size
        side = min(width, height)

        # Offset of the crop window along the longer axis. When the surplus is odd the window cannot be perfectly
        # centred, so the offset is rounded up (the extra pixel is dropped from the left/top).
        offset = (abs(width - height) + 1) // 2

        if width >= height:
            box = (offset, 0, offset + side, side)
        else:
            box = (0, offset, side, offset + side)

        # crop image into a square, then resize to desired shape
        img_cropped = img.crop(box)
        img_final = img_cropped.resize((self.img_size, self.img_size))

        return img_final


    def __load_metadata(self):
        """
        Loads the label txt files, return them as separate or combined pd.DataFrame

        Returns:
            combined_df (pd.DataFrame): when self.combined is True, one DF with the train rows followed by the
                test rows (index reset)
            (train_df, test_df) (tuple of pd.DataFrame): when self.combined is False, the two DFs separately

            Every DF has the columns in label_colnames plus a 'filepath' column pointing into the train/ or test/
            sub-folder of data_dir.
        """
        # name of the label files
        train_filename, test_filename = class_label.get(self.n_classes)

        frames = []
        for subfolder, label_file in (('train/', train_filename), ('test/', test_filename)):
            # label files are space separated and have no header row
            df = pd.read_csv(os.path.join(self.txt_dir, label_file), header=None, sep=' ')
            df.columns = label_colnames
            # assign file path for each image
            df['filepath'] = [os.path.join(self.data_dir, subfolder, name) for name in df['filename']]
            frames.append(df)

        train_df, test_df = frames

        # return combined or separate based on argument
        if self.combined:
            return pd.concat([train_df, test_df], ignore_index=True)
        return train_df, test_df


    def __bootstrap_sample(self, label_df: pd.DataFrame, rand_state: int,
                           sample_size: int, sample_percent: float) -> tuple:
        """
        Creates a bootstrapped training sample from a given population.
        Patients with multiple images have one image randomly selected with the remainder dropped from the data.
        The data is sampled with replacement with an equal number of positive/negative images.
        The validation set is drawn (without replacement) from the images that were not picked for training,
        with as many negatives as positives.

        Arguments:
            label_df (pd.DataFrame): DF of image metadata
            rand_state (int): Random state for Pandas sampling
            sample_size (int): Total number of images in the training set
            sample_percent (float): Percentage of training images vs total images, if specified, overrides sample_size

        Returns:
            df_train (pd.DataFrame): DF of images metadata for training after bootstrapping
            df_val (pd.DataFrame): DF of images metadata for validation after bootstrapping

        Raises:
            AssertionError: If sample_percent is None and sample_size exceeds the number of rows in label_df
        """
        # defaults to sample_size, unless sample_percent is specified
        if sample_percent is not None:
            sample_size = math.floor(label_df.shape[0] * sample_percent)
        else:
            assert sample_size <= label_df.shape[0], f'sample_size must be no larger than {label_df.shape[0]}'

        sample_half = math.floor(sample_size/2)

        # randomly pick one image from a patient and drop the rest from the sample population
        patient_img_count = label_df[['patient_id', 'filename']].groupby(['patient_id']).count().reset_index()
        duplicate_patients = patient_img_count['patient_id'][patient_img_count['filename'] > 1]

        # One pass over the multi-image patients instead of re-filtering and re-concatenating the whole frame for
        # every patient. Each patient's draw still uses rand_state, and the sampled rows are appended after the
        # single-image patients in patient_id group order.
        is_duplicate = label_df['patient_id'].isin(duplicate_patients)
        sampled_rows = [rows.sample(n=1, random_state=rand_state)
                        for _, rows in label_df[is_duplicate].groupby('patient_id')]
        label_df = pd.concat([label_df[~is_duplicate]] + sampled_rows, axis=0)

        # sample with replacement for positive/negative classes to create training data
        df_train_positive = label_df[label_df['class'] == 'positive'].sample(n=sample_half,
                                                                             replace=True, random_state=rand_state)
        df_train_negative = label_df[label_df['class'] == 'negative'].sample(n=sample_half,
                                                                             replace=True, random_state=rand_state)

        df_train = pd.concat([df_train_positive, df_train_negative], axis=0)
        # everything that never made it into the training sample is available for validation
        df_remaining = label_df.loc[~label_df['filename'].isin(df_train['filename']), :]

        # remaining has lots of negative samples so we sample without replacement to match positive samples
        remaining_positive = df_remaining.loc[df_remaining['class'] == 'positive', :]
        remaining_negative = df_remaining.loc[df_remaining['class'] == 'negative', :]
        df_val_size = min(df_train_positive.shape[0], remaining_positive.shape[0])

        df_val_positive = remaining_positive.sample(n=df_val_size, replace=False, random_state=rand_state)
        df_val_negative = remaining_negative.sample(n=df_val_positive.shape[0], replace=False,
                                                    random_state=rand_state)

        df_val = pd.concat([df_val_positive, df_val_negative], axis=0)

        # shuffle datasets before returning
        df_train = df_train.sample(frac=1, random_state=rand_state).reset_index(drop=True)
        df_val = df_val.sample(frac=1, random_state=rand_state).reset_index(drop=True)

        return df_train, df_val


    def __test_split(self, test_percent: float) -> tuple:
        """
        Split out the holdout test set from the whole data
        This is only for when the original data is combined, otherwise the original data defined the holdout test set
        Patients are assigned to the holdout set as a whole, so no patient has images on both sides of the split.
        The shuffle uses a fixed random state, so the split is the same for a given test_percent.

        Arguments:
            test_percent (float): Percentage of holdout images vs total images

        Returns:
            df_train_val (pd.DataFrame): DF of images metadata for training and validation
            df_test (pd.DataFrame): DF of images metadata for holdout testing

        Raises:
            AssertionError: If test_percent is outside (0, 1.00]
        """
        assert test_percent > 0 and test_percent <= 1.00, 'test_percent must be in range (0, 1.00]'

        # load in combined data
        label_df = self.__load_metadata()
        test_size = math.floor(label_df.shape[0] * test_percent)

        # obtain the proportion of positive class in the data
        positive_percent = label_df['class'].value_counts(normalize=True)['positive']
        # number of positive images in the holdout set, keeping the class proportion of the whole data
        positive_size = math.floor(test_size * positive_percent)

        positive_rows = label_df.loc[label_df['class'] == 'positive', :]

        # DF of positive patients and their positive image count, then random shuffle to be sampled
        positive_img_count = positive_rows[['patient_id', 'filename']].groupby(['patient_id']).count().reset_index()
        positive_img_count = positive_img_count.sample(frac=1, random_state=50)

        # sample positive patients until total number of images exceeds the required number
        positive_count = 0
        positive_patients = []

        for patient_id, img_count in zip(positive_img_count['patient_id'], positive_img_count['filename']):
            positive_count += img_count
            positive_patients.append(patient_id)
            if positive_count > positive_size:
                break

        # repeat the process above for patients with negative images, however, remove any patients already picked
        # this is to avoid cases where someone has images in both class, so they don't get picked twice
        negative_rows = label_df.loc[(label_df['class'] == 'negative') &
                                     (~label_df['patient_id'].isin(positive_patients)), :]

        negative_img_count = negative_rows[['patient_id', 'filename']].groupby(['patient_id']).count().reset_index()
        negative_img_count = negative_img_count.sample(frac=1, random_state=50)

        negative_count = 0
        negative_patients = []

        for patient_id, img_count in zip(negative_img_count['patient_id'], negative_img_count['filename']):
            negative_count += img_count
            negative_patients.append(patient_id)
            if negative_count + positive_count > test_size:
                break

        # combine all patients picked and split their images off from the rest.
        # A boolean mask is used rather than concat + drop_duplicates(keep=False), which would also have silently
        # dropped any rows that happen to be duplicated inside the train/val portion.
        test_patients = positive_patients + negative_patients
        in_test = label_df['patient_id'].isin(test_patients)

        df_test = label_df[in_test]
        df_train_val = label_df[~in_test]

        return df_train_val, df_test


    def __train_test_split(self, rand_state: int, bootstrap: bool,
                           sample_size: int, sample_percent: float, test_percent: float) -> tuple:
        """
        Split to train, val, test based on user choices

        Arguments:
            rand_state (int): Random state that will determine the images selected
            bootstrap (bool): Determines whether to bootstrap the data or not (currently not used: the training
                data is always bootstrapped)
            sample_size (int): Total number of images in the training set
            sample_percent (float): Percentage of training images vs total images, if specified, overrides sample_size
            test_percent (float): Percentage of holdout images vs total images, only used when self.combined is True

        Returns:
            df_train (pd.DataFrame): DF of images metadata for training
            df_val (pd.DataFrame): DF of images metadata for validation
            df_test (pd.DataFrame): DF of images metadata for holdout testing
        """
        if self.combined:
            df_train_val, df_test = self.__test_split(test_percent=test_percent)
        else:
            # __load_metadata takes no arguments; with self.combined False it returns (train_df, test_df)
            df_train_val, df_test = self.__load_metadata()

        df_train, df_val = self.__bootstrap_sample(label_df=df_train_val, rand_state=rand_state,
                                                   sample_size=sample_size, sample_percent=sample_percent)

        return df_train, df_val, df_test


    def __load_images(self, image_df: pd.DataFrame) -> np.ndarray:
        """
        Loads in the images based on the input image metadata

        Arguments:
            image_df (pd.DataFrame): DF of images metadata, which includes their file paths

        Returns:
            images (np.array): A (no of images) x (img_size^2 x channels) array containing images specified in image_df
        """
        images = []

        for file_path in image_df['filepath']:
            # the context manager closes the underlying file as soon as the pixels have been converted
            with Image.open(file_path) as img:
                img_final = self.__crop_image(img.convert(self.channels))

            images.append(np.asarray(img_final).flatten())

        return np.array(images)


    def __load_labels(self, image_df: pd.DataFrame) -> np.ndarray:
        """
        Loads in the image labels based on the input image metadata

        Arguments:
            image_df (pd.DataFrame): DF of images metadata, which includes the 'class' column

        Returns:
            labels (np.array): A (no of images x 1) array containing image labels encoded through label_encode
        """
        return np.array([label_encode.get(label) for label in image_df['class']])


    def load_train_val(self, rand_state: int, bootstrap: bool = True,
                       sample_size: int = 20000, sample_percent: float = None,
                       test_percent: float = 0.1) -> tuple:
        """
        Loads in the training and validation images based on the input image metadata

        Arguments:
            rand_state (int): Random state that will determine the images selected
            bootstrap (bool): Determines whether to bootstrap the data or not (currently not used)
            sample_size (int): Total number of images in the training set
            sample_percent (float): Percentage of training images vs total images, if specified, overrides sample_size
            test_percent (float): Percentage of holdout images that is set aside before sampling. Must be the same
                value that is passed to load_test, otherwise holdout images can end up in the training data

        Returns:
            X_train (np.array): A (no of images) x (img_size^2 x channels) array containing training images
            X_val (np.array): A (no of images) x (img_size^2 x channels) array containing validation images
            Y_train (np.array): A (no of images x 1) array containing training image labels
            Y_val (np.array): A (no of images x 1) array containing validation image labels

        Raises:
            AssertionError: If sample_percent is given and is outside (0, 1.00]
        """
        if sample_percent is not None:
            assert sample_percent > 0 and sample_percent <= 1.00, 'sample_percent must be in range (0, 1.00]'

        df_train, df_val, _ = self.__train_test_split(rand_state=rand_state, bootstrap=bootstrap,
                                                      sample_size=sample_size, sample_percent=sample_percent,
                                                      test_percent=test_percent)

        X_train = self.__load_images(df_train)
        X_val = self.__load_images(df_val)

        Y_train = self.__load_labels(df_train)
        Y_val = self.__load_labels(df_val)

        return X_train, X_val, Y_train, Y_val


    def load_test(self, test_percent: float = 0.1) -> tuple:
        """
        Loads in the holdout test images.
        The images loaded are always fixed given the same test_percent

        Arguments:
            test_percent (float): Percentage of holdout images vs total images

        Returns:
            X_test (np.array): A (no of images) x (img_size^2 x channels) array containing holdout test images
            Y_test (np.array): A (no of images x 1) array containing holdout test image labels

        Raises:
            AssertionError: If test_percent is outside (0, 1.00]
        """
        assert test_percent > 0 and test_percent <= 1.00, 'test_percent must be in range (0, 1.00]'

        # Only the holdout rows are needed here, so the split is taken directly instead of also bootstrapping a
        # training sample that would be thrown away.
        if self.combined:
            _, df_test = self.__test_split(test_percent=test_percent)
        else:
            _, df_test = self.__load_metadata()

        X_test = self.__load_images(df_test)
        Y_test = self.__load_labels(df_test)

        return X_test, Y_test

In [5]:
# Loader for the two-class label files stored under the Kaggle input folder.
image_loader = DataLoader(
    platform='Kaggle',                      # one of "Local", "Kaggle", "Colab"
    n_classes=2,                            # 2 or 3; only 2 works for now
    data_dir='/kaggle/input/covidx-cxr2/',  # folder that holds train/ and test/
    txt_dir='/kaggle/input/covidx-cxr2/',   # folder that holds the label .txt files
    img_size=224,
    combined=True,
    channels='RGB',
)

In [6]:
import time

# Sample and load the training / validation images, reporting the shapes
# of both image arrays and the elapsed wall-clock time in seconds.
t = time.time()
X_train, X_val, Y_train, Y_val = image_loader.load_train_val(
    rand_state=42,
    sample_size=25000,
)
print(X_train.shape)
print(X_val.shape)
print(time.time() - t)

In [7]:
# Load the hold-out test images (10% of the data, the loader's default).
X_test, Y_test = image_loader.load_test(test_percent=0.1)
print(X_test.shape)

In [8]:
# The loader returns one flattened row per image; restore the
# (images, height, width, channels) layout for the training set.
X_train = np.reshape(np.asarray(X_train), (-1, 224, 224, 3))
Y_train = np.asarray(Y_train)
print(X_train.shape)
print(Y_train.shape)

In [9]:
# Restore the (images, height, width, channels) layout for the validation set.
X_val = np.reshape(np.asarray(X_val), (-1, 224, 224, 3))
Y_val = np.asarray(Y_val)
print(X_val.shape)
print(Y_val.shape)

In [10]:
# Restore the (images, height, width, channels) layout for the hold-out test set.
X_test = np.reshape(np.asarray(X_test), (-1, 224, 224, 3))
Y_test = np.asarray(Y_test)
print(X_test.shape)
print(Y_test.shape)

In [11]:
# Number of output logits. The loader above is configured with n_classes=2
# (labels 0/1); the previous value of 100 produced 98 logits that can never be
# a correct answer and let argmax predictions fall outside the label set.
# Keep this in sync with the DataLoader's n_classes.
NUM_CLASSES = 2
# Shape of one image as produced by the loader cells: 224 x 224, 3 channels.
INPUT_SHAPE = (224, 224, 3)

# DATA
BUFFER_SIZE = 512
BATCH_SIZE = 256

# AUGMENTATION
# Images are resized to IMAGE_SIZE x IMAGE_SIZE inside the model and then cut
# into non-overlapping PATCH_SIZE x PATCH_SIZE patches.
IMAGE_SIZE = 72
PATCH_SIZE = 6
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# OPTIMIZER
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001

# TRAINING
EPOCHS = 50

# ARCHITECTURE
LAYER_NORM_EPS = 1e-6
TRANSFORMER_LAYERS = 8
PROJECTION_DIM = 64
NUM_HEADS = 4
# Widths of the two dense layers inside every transformer block's MLP.
TRANSFORMER_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]
# Widths of the dense layers in the classification head.
MLP_HEAD_UNITS = [2048, 1024]

In [12]:
# Preprocessing / augmentation pipeline that is applied inside the model:
# normalise, resize to the working resolution, then random flip/rotate/zoom.
data_augmentation = keras.Sequential(name="data_augmentation")
data_augmentation.add(layers.Normalization())
data_augmentation.add(layers.Resizing(IMAGE_SIZE, IMAGE_SIZE))
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(factor=0.02))
data_augmentation.add(layers.RandomZoom(height_factor=0.2, width_factor=0.2))

# The first layer is the Normalization layer: fit its mean and variance
# on the training images.
data_augmentation.layers[0].adapt(X_train)

In [13]:
class ShiftedPatchTokenization(layers.Layer):
    """Turn images into patch tokens, optionally with diagonally shifted copies.

    With ``vanilla=False`` the input is concatenated (along the channel axis)
    with four copies of itself shifted by half a patch in each diagonal
    direction; the patches are then layer-normalised and linearly projected.
    With ``vanilla=True`` the plain image is patchified and projected.
    """

    def __init__(
        self,
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        num_patches=NUM_PATCHES,
        projection_dim=PROJECTION_DIM,
        vanilla=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # When set, fall back to the plain (non-shifted) patch extractor.
        self.vanilla = vanilla
        self.image_size = image_size
        self.patch_size = patch_size
        self.half_patch = patch_size // 2
        self.flatten_patches = layers.Reshape((num_patches, -1))
        self.projection = layers.Dense(units=projection_dim)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)

    def crop_shift_pad(self, images, mode):
        """Return ``images`` shifted diagonally by half a patch, zero padded.

        ``mode`` is one of "left-up", "left-down", "right-up"; any other value
        is treated as "right-down".
        """
        half = self.half_patch
        # mode -> (crop_height, crop_width, shift_height, shift_width)
        offsets_by_mode = {
            "left-up": (half, half, 0, 0),
            "left-down": (0, half, half, 0),
            "right-up": (half, 0, 0, half),
        }
        crop_height, crop_width, shift_height, shift_width = offsets_by_mode.get(
            mode, (0, 0, half, half)
        )

        # Cut away half a patch on one side of each axis ...
        reduced_size = self.image_size - half
        cropped = tf.image.crop_to_bounding_box(
            images,
            offset_height=crop_height,
            offset_width=crop_width,
            target_height=reduced_size,
            target_width=reduced_size,
        )
        # ... and pad back to the full size at the shifted position.
        return tf.image.pad_to_bounding_box(
            cropped,
            offset_height=shift_height,
            offset_width=shift_width,
            target_height=self.image_size,
            target_width=self.image_size,
        )

    def call(self, images):
        """Return ``(tokens, patches)`` for a batch of images."""
        if not self.vanilla:
            # Stack the four shifted copies behind the original image.
            shifted = [
                self.crop_shift_pad(images, mode=direction)
                for direction in ("left-up", "left-down", "right-up", "right-down")
            ]
            images = tf.concat([images, *shifted], axis=-1)

        # Cut into non-overlapping patches and flatten each patch.
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        flat_patches = self.flatten_patches(patches)

        if self.vanilla:
            # Plain linear projection of the flat patches.
            tokens = self.projection(flat_patches)
        else:
            # Layer-normalise before the linear projection.
            tokens = self.projection(self.layer_norm(flat_patches))
        return (tokens, patches)

In [14]:
# Pick one training image at random and resize it to the model's
# working resolution.
image = X_train[np.random.choice(range(X_train.shape[0]))]
resized_image = tf.image.resize(
    tf.convert_to_tensor([image]),
    size=(IMAGE_SIZE, IMAGE_SIZE),
)

# Shifted Patch Tokenization: the layer shifts the image diagonally and
# extracts patches from the channel-wise concatenation of all five images.
token, patch = ShiftedPatchTokenization(vanilla=False)(resized_image / 255.0)
token, patch = token[0], patch[0]
n = patch.shape[0]

# Plot one n x n grid of patches per image; every patch carries the five
# images stacked along its channel axis, three channels each.
shifted_images = ["ORIGINAL", "LEFT-UP", "LEFT-DOWN", "RIGHT-UP", "RIGHT-DOWN"]
for index, name in enumerate(shifted_images):
    print(name)
    plt.figure(figsize=(4, 4))
    count = 1
    for row in range(n):
        for col in range(n):
            plt.subplot(n, n, count)
            count += 1
            image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
            first_channel = 3 * index
            plt.imshow(image[..., first_channel : first_channel + 3])
            plt.axis("off")
    plt.show()

In [15]:
class PatchEncoder(layers.Layer):
    """Add a learned position embedding to every patch token."""

    def __init__(self, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # One embedding vector per patch position.
        self.position_embedding = layers.Embedding(
            input_dim=num_patches,
            output_dim=projection_dim,
        )
        # Position indices 0 .. num_patches - 1.
        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

    def call(self, encoded_patches):
        """Return ``encoded_patches`` with the position embeddings added."""
        return encoded_patches + self.position_embedding(self.positions)

In [16]:
class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
    """Multi-head attention whose query is scaled by a trainable temperature."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Trainable temperature, initialised to sqrt(key_dim).
        initial_tau = math.sqrt(float(self._key_dim))
        self.tau = tf.Variable(initial_tau, trainable=True)

    def _compute_attention(self, query, key, value, attention_mask=None, training=None):
        """Return ``(attention_output, attention_scores)``.

        The query is divided by ``self.tau``, the scores are soft-maxed under
        ``attention_mask``, and dropout is applied to the scores that are used
        to combine the values (the returned scores are the pre-dropout ones).
        """
        query = tf.multiply(query, 1.0 / self.tau)
        raw_scores = tf.einsum(self._dot_product_equation, key, query)
        attention_scores = self._masked_softmax(raw_scores, attention_mask)
        dropped_scores = self._dropout_layer(attention_scores, training=training)
        attention_output = tf.einsum(self._combine_equation, dropped_scores, value)
        return attention_output, attention_scores

In [17]:
def mlp(x, hidden_units, dropout_rate):
    """Stack a GELU dense layer followed by dropout for every entry of ``hidden_units``."""
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(hidden)
    return x


# Attention mask with zeros on the diagonal and ones elsewhere, wrapped in a
# leading axis of size one and stored as int8.
diag_attn_mask = tf.cast([1 - tf.eye(NUM_PATCHES)], dtype=tf.int8)

In [18]:
def create_vit_classifier(vanilla=False):
    """Build the ViT classifier as a Keras functional model.

    Args:
        vanilla: When False (default) the model uses Shifted Patch
            Tokenization and ``MultiHeadAttentionLSA`` with the diagonal
            attention mask; when True it uses plain patch extraction and the
            stock ``MultiHeadAttention`` layer.

    Returns:
        A ``keras.Model`` mapping images of shape ``INPUT_SHAPE`` to
        ``NUM_CLASSES`` logits (no softmax is applied).
    """
    inputs = layers.Input(shape=INPUT_SHAPE)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    (tokens, _) = ShiftedPatchTokenization(vanilla=vanilla)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder()(tokens)

    # Create multiple layers of the Transformer block.
    for _ in range(TRANSFORMER_LAYERS):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
        # Create a multi-head attention layer.
        if not vanilla:
            attention_output = MultiHeadAttentionLSA(
                num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
            )(x1, x1, attention_mask=diag_attn_mask)
        else:
            attention_output = layers.MultiHeadAttention(
                num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
            )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=TRANSFORMER_UNITS, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Normalise and flatten all patch tokens into a single vector per image
    # (the flattened size is the number of tokens times the token width).
    representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=MLP_HEAD_UNITS, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(NUM_CLASSES)(features)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

In [19]:
# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule: linear warm-up followed by cosine decay.

    The rate rises linearly from ``warmup_learning_rate`` to
    ``learning_rate_base`` over ``warmup_steps`` steps, then follows half a
    cosine wave, and is 0.0 for every step greater than ``total_steps``.
    """

    def __init__(self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps):
        super().__init__()

        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        """Return the learning rate for ``step``.

        Raises:
            ValueError: If ``total_steps`` is smaller than ``warmup_steps``, or
                if warm-up is enabled and ``learning_rate_base`` is smaller
                than ``warmup_learning_rate``.
        """
        if self.total_steps < self.warmup_steps:
            raise ValueError("Total_steps must be larger or equal to warmup_steps.")

        step_as_float = tf.cast(step, tf.float32)
        decay_span = float(self.total_steps - self.warmup_steps)

        # Cosine decay, measured from the end of the warm-up phase.
        cosine = tf.cos(self.pi * (step_as_float - self.warmup_steps) / decay_span)
        learning_rate = 0.5 * self.learning_rate_base * (1 + cosine)

        if self.warmup_steps > 0:
            if self.learning_rate_base < self.warmup_learning_rate:
                raise ValueError(
                    "Learning_rate_base must be larger or equal to "
                    "warmup_learning_rate."
                )
            # Linear ramp that replaces the cosine value during warm-up.
            rise = self.learning_rate_base - self.warmup_learning_rate
            slope = rise / self.warmup_steps
            warmup_rate = slope * step_as_float + self.warmup_learning_rate
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )

        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )


def run_experiment(model):
    """Compile, train and evaluate ``model``.

    Trains on ``X_train``/``Y_train`` with AdamW under a warm-up + cosine
    learning-rate schedule, validates on ``X_val``/``Y_val`` and prints the
    accuracy and top-5 accuracy on ``X_test``/``Y_test``.

    Args:
        model: An un-compiled Keras model that outputs logits.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Keras runs a final, partial batch as its own step, so round up.
    steps_per_epoch = math.ceil(len(X_train) / BATCH_SIZE)
    total_steps = steps_per_epoch * EPOCHS
    warmup_epoch_percentage = 0.10
    warmup_steps = int(total_steps * warmup_epoch_percentage)
    scheduled_lrs = WarmUpCosine(
        learning_rate_base=LEARNING_RATE,
        total_steps=total_steps,
        warmup_learning_rate=0.0,
        warmup_steps=warmup_steps,
    )

    # Hand the schedule to the optimizer; previously it was built but the
    # constant LEARNING_RATE was used, so warm-up and decay never took effect.
    optimizer = tfa.optimizers.AdamW(
        learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Validate on the dedicated validation split. The training set is sampled
    # with replacement, so carving validation data out of it (validation_split)
    # would put copies of training images into the validation data.
    history = model.fit(
        x=X_train,
        y=Y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, Y_val),
    )
    _, accuracy, top_5_accuracy = model.evaluate(X_test, Y_test, batch_size=BATCH_SIZE)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


# # Run experiments with the vanilla ViT
# vit = create_vit_classifier(vanilla=True)
# history = run_experiment(vit)

# Train and evaluate the ViT variant that uses Shifted Patch Tokenization
# and Locality Self Attention.
vit_sl = create_vit_classifier(vanilla=False)
history = run_experiment(vit_sl)

from sklearn.metrics import confusion_matrix
# The predicted class of every test image is the index of its largest logit.
test_predictions = vit_sl.predict(X_test)

confusion = confusion_matrix(Y_test, test_predictions.argmax(axis=1))

import seaborn as sns
import matplotlib.pyplot as plt
# Draw the confusion matrix as an annotated heat map.
ax = plt.subplot()
sns.heatmap(confusion, annot=True, fmt='g', ax=ax)
ax.set_title('Confusion Matrix of VT without Vanilla');

