![image](https://user-images.githubusercontent.com/17668390/149625554-b9c7074a-2137-49d5-8726-a3fbfa3f9a4c.gif)

<div align="center">
    Figure: <strong>Grad-CAM</strong> of Hybrid-EfficientNet-Swin Transformer. Left (<i>Input</i>), Middle (<i>EfficientNet</i>), Right (<i>Swin Transformer</i>).
</div>

<table class="tfo-notebook-buttons" align="center">
    
  <td>
    <a target="_blank" href="https://colab.research.google.com/drive/1usxq9yhZthyapAnzFfFObQ7RjXPicAop?usp=sharing"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
    
  <td>
    <a target="_blank" href="https://github.com/innat/HybridModel-GradCAM"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
    
   <td>
    <a target="_blank" href="https://deepnote.com/workspace/mohammed-innat-36e929bc-ce23-4d95-9ddc-a9c6662eb7d6/project/Notebooks-32e94ef4-8ce0-4cc4-8042-49862519f3f2/%2F%5BDeepnote%5D_HENetSwinT.ipynb"><img src="https://user-images.githubusercontent.com/17668390/176064308-845cc64d-cd84-44fa-a491-4a53759b19d4.png" />Run in Deepnote</a>
  </td>
    
   <td>
    <a target="_blank" href="https://www.kaggle.com/code/ipythonx/tf-hybrid-efficientnet-swin-transformer-gradcam"><img src="https://user-images.githubusercontent.com/17668390/176064379-9bcc7836-dcff-42b2-bcca-feff28f22c70.png" />Run in Kaggle</a>
   </td>
    
   <td>
    <a target="_blank" href="https://huggingface.co/spaces/innat/HybridModel-GradCAM"><img src="https://user-images.githubusercontent.com/17668390/176064420-46cbf547-0d17-4438-a791-d23e17eff5a9.png" />Try on Gradio</a>
   </td>
    
</table>

# Introduction

Convolutional Neural Networks (CNNs) have been the de facto model for visual data up to this point. However, Vision Transformers have proven to be an efficient alternative to CNNs, and recent research has shown that transformer models can perform similarly to, if not better than, CNNs on vision tasks. In this code example, we inspect the visual interpretations of the **CNN** and **Transformer** blocks of a hybrid model (EfficientNet + Swin Transformer) with the **GradCAM** technique. The results suggest that the transformer blocks refine feature activations globally across the relevant object, whereas the CNN operates more locally (see the figure above). Note that the approach shown here is highly experimental; the workflow could likely be developed into a more meaningful modeling approach.

**Data Pipelines**. To keep things simple, we use the [tf_flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers) dataset, a multi-class classification problem. For image augmentation, we use **vectorized** implementations of **CutMix** and **MixUp**, derived from [KerasCV](https://keras.io/keras_cv/). We also use the [Jax](https://jax.readthedocs.io/en/latest/) library to write image augmentation layers and use them in the `tf.data` API; `Jax` code can indeed be wrapped into a `keras` layer. Finally, we have replaced `numpy` code with `jax.numpy` wherever possible.


- [**Code Style: Keras.io.**](https://bit.ly/3Oe9zHY) ✔️
- **JIT Compilation** ✔️
- **Mixed Precision** ✔️
- **Gradient Accumulation** ✔️
- **Label Smoothing** ✔️
- **TensorFlow Lite Conversion** ✔️


**Note**: If the accelerator is set to **GPU**, the notebook trains the model first and then runs inference. If the accelerator is set to **CPU**, only inference is performed, using a trained weight file. This notebook has been tested with `Jax 0.3.13` and `TensorFlow 2.6.4` on Kaggle, `TF 2.8` on Colab, and `TF 2.9` on Deepnote.

In [None]:
!pip install gdown -q
!pip install jax==0.3.13 jaxlib==0.3.10 -q

In [None]:
import os
import random
import warnings

import cv2
import gdown
from functools import partial

warnings.simplefilter(action="ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import numpy as np
import pandas as pd
from matplotlib import cm
from numpy.random import rand
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend
from tensorflow.keras import layers

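physical_devices = tf.config.list_physical_devices("GPU")
try:
    # GPU only: grow memory on demand, enable XLA JIT and mixed precision.
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
    tf.config.optimizer.set_jit(True)
    keras.mixed_precision.set_global_policy("mixed_float16")
except (IndexError, RuntimeError):
    # IndexError: no GPU found; RuntimeError: devices already initialized.
    pass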

seed = 1337
tf.random.set_seed(seed)

**Utils**

In [None]:
def get_model_weight(model_id):
    """Get the trained weights."""
    if not os.path.exists("model.h5"):
        model_weight = gdown.download(id=model_id, quiet=False)
    else:
        model_weight = "model.h5"
    return model_weight


def get_model_history(history_id):
    """Get the history / log file."""
    if not os.path.exists("history.csv"):
        history_file = gdown.download(id=history_id, quiet=False)
    else:
        history_file = "history.csv"
    return history_file


def make_plot(tfdata, take_batch=1, title=True, figsize=(20, 20)):
    """ref: https://gist.github.com/innat/4dc4080cfdf5cf20ef0fc93d3623ca9b"""

    font = {
        "family": "serif",
        "color": "darkred",
        "weight": "normal",
        "size": 15,
    }

    for images, labels in tfdata.take(take_batch):
        plt.figure(figsize=figsize)
        xy = int(np.ceil(images.shape[0] * 0.5))

        for i in range(images.shape[0]):
            plt.subplot(xy, xy, i + 1)
            plt.imshow(tf.cast(images[i], dtype=tf.uint8))
            if title:
                plt.title(tcls_names[tf.argmax(labels[i], axis=-1)], fontdict=font)
            plt.axis("off")

        # Lay out and show each batch's figure separately.
        plt.tight_layout()
        plt.show()


# Acquiring Data

In [None]:
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.jpg')))
print('Total Samples: ', image_count)

**Parameter Settings**

In [None]:
class Parameters:
    # data level
    image_size = 384
    batch_size = 8
    num_grad_accumulation = 8
    label_smooth=0.05
    class_number = 5
    val_split = 0.2 
    verbosity = 2
    autotune = tf.data.AUTOTUNE
    
    # hparams
    epochs = 20
    lr_sched = 'cosine_restart' # other options: exponential, cosine, linear, constant
    lr_base  = 0.016
    lr_min   = 0
    lr_decay_epoch  = 2.4
    lr_warmup_epoch = 5
    lr_decay_factor = 0.97
    
    scaled_lr = lr_base * (batch_size / 256.0)
    scaled_lr_min = lr_min * (batch_size / 256.0)
    num_validation_sample = int(image_count * val_split)
    num_training_sample = image_count - num_validation_sample
    train_step = int(np.ceil(num_training_sample / float(batch_size)))
    total_steps = train_step * epochs

params = Parameters()
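A quick check of the quantities derived in `Parameters`, as a minimal sketch. It assumes the archive contains 3,670 images (the size of tf_flowers; the cell above prints the actual `image_count`). The effective batch size under gradient accumulation and the matching alternative learning rate are illustrative only: `Parameters` itself scales `lr_base` by the per-step `batch_size`.

```python
import math

# Values copied from Parameters; image_count = 3670 is an assumption
# (the tf_flowers size), replace it with the printed image_count.
image_count = 3670
batch_size = 8
num_grad_accumulation = 8
val_split = 0.2
epochs = 20
lr_base = 0.016

num_validation_sample = int(image_count * val_split)        # 734 images
num_training_sample = image_count - num_validation_sample   # 2936 images
train_step = math.ceil(num_training_sample / batch_size)     # 367 mini-batches per epoch
total_steps = train_step * epochs                            # 7340 mini-batches overall

# Linear scaling rule as written in Parameters (per-step batch size).
scaled_lr = lr_base * (batch_size / 256.0)                   # 0.0005

# Hypothetical alternative: scale by the effective batch after accumulation.
effective_batch = batch_size * num_grad_accumulation         # 64
scaled_lr_effective = lr_base * (effective_batch / 256.0)    # 0.004

print(train_step, total_steps, scaled_lr, scaled_lr_effective)
```

Which batch size the scaling rule should use is a design choice; the sketch only makes the two options explicit.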

# Data Loading

In [None]:
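# Both subsets use the same seed and validation_split, so the training and
# validation sets are complementary and do not overlap.
train_set = keras.utils.image_dataset_from_directory(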
    data_dir,
    validation_split=params.val_split,
    subset="training",
    label_mode='categorical',
    seed=params.image_size,
    image_size=(params.image_size, params.image_size),
    batch_size=params.batch_size,
)

val_set = keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=params.val_split,
    subset="validation",
    label_mode='categorical',
    seed=params.image_size,
    image_size=(params.image_size, params.image_size),
    batch_size=params.batch_size,
)

tcls_names, vcls_names = train_set.class_names, val_set.class_names
tcls_names, vcls_names

**Visualize Raw Samples**

In [None]:
make_plot(train_set, take_batch=1, title=True) 
make_plot(val_set, take_batch=1, title=True) 

# Advanced Image Augmentation

As mentioned, we use **CutMix** and **MixUp**, written in `tf.keras`. Alongside them, we use two simple `Jax`-coded augmentation layers, converted to work in the `tf.data` API. However, the [keras built-in augmentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/) layers currently have no **probability** parameter to control how often they are applied. So, we wrap each augmentation layer in a class that applies the transformation at random with a given probability.

In [None]:
class RandomApply(layers.Layer):
    """RandomApply will randomly apply the transformation layer
    based on the given probability.
    
    Ref. https://stackoverflow.com/a/72558994/9215780
    """

    def __init__(self, layer, probability, **kwargs):
        super().__init__(**kwargs)
        self.layer = layer
        self.probability = probability

    def call(self, inputs, training=True):
        apply_layer = tf.random.uniform([]) < self.probability
        outputs = tf.cond(
            pred=tf.logical_and(apply_layer, training),
            true_fn=lambda: self.layer(inputs),
            false_fn=lambda: inputs,
        )
        return outputs

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "layer": layers.serialize(self.layer),
                "probability": self.probability,
            }
        )
        return config
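`RandomApply` draws a single `tf.random.uniform([])` value per call. Because it runs on already-batched data here, the wrapped layer is either called on the whole batch or skipped for the whole batch. If a per-image decision is preferred, one option is a Boolean mask over the batch dimension. The helper `random_apply_per_sample` below is hypothetical (not part of Keras or of this notebook) and assumes the transform preserves the batch's shape and dtype.

```python
import tensorflow as tf


def random_apply_per_sample(images, transform, probability, seed=None):
    """Hypothetical helper: apply `transform` to each image independently.

    `transform` maps a (B, H, W, C) batch to a tensor of the same shape and
    dtype. It runs on the whole batch; the mask then selects, per image,
    either the transformed or the original pixels.
    """
    batch_size = tf.shape(images)[0]
    # One Bernoulli decision per image instead of one per batch.
    apply = tf.random.uniform([batch_size], seed=seed) < probability
    apply = tf.reshape(apply, [-1, 1, 1, 1])  # broadcast over H, W, C
    return tf.where(apply, transform(images), images)


images = tf.random.uniform([8, 32, 32, 3])
# Roughly half of the images are flipped, independently of each other.
augmented = random_apply_per_sample(images, tf.image.flip_left_right, 0.5)
```

The trade-off is cost: the transform is computed for every image even when most of them are discarded by the mask.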

## MixUp

Implemented in `tf.keras`.
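For reference, MixUp mixes each image with a partner from the same batch and weights both the pixels and the one-hot labels by the same coefficient. In the notation below, $\pi$ is the random permutation stored in `permutation_order`, $\lambda_i$ is `lambda_sample`, and $\alpha$ is the layer's `alpha` argument (default `0.2`):

$$
\lambda_i \sim \mathrm{Beta}(\alpha, \alpha), \qquad
\tilde{x}_i = \lambda_i\, x_i + (1 - \lambda_i)\, x_{\pi(i)}, \qquad
\tilde{y}_i = \lambda_i\, y_i + (1 - \lambda_i)\, y_{\pi(i)}
$$

The Beta draw is built from two Gamma draws: for independent $X, Y \sim \mathrm{Gamma}(\alpha, 1)$ (concentration $\alpha$, unit rate), $X / (X + Y) \sim \mathrm{Beta}(\alpha, \alpha)$.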

In [None]:
class MixUp(layers.Layer):
    """Original implementation: https://github.com/keras-team/keras-cv.
    The original implementation provides more interfaces for applying MixUp to
    various CV tasks, e.g. object detection. It also performs many useful
    validation checks.

    Derived and modified for simpler usages: M.Innat.
    Ref. https://gist.github.com/innat/0ee2b6155d663aac2617fe596e1d8d49
    """

    def __init__(self, alpha=0.2, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.seed = seed

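    @staticmethod
    def _sample_from_beta(alpha, beta, shape):
        # Beta(alpha, beta) = X / (X + Y) with X ~ Gamma(alpha), Y ~ Gamma(beta).
        # alpha and beta must be the gamma concentrations, not the rates.
        sample_alpha = tf.random.gamma(shape, alpha=alpha)
        sample_beta = tf.random.gamma(shape, alpha=beta)
        return sample_alpha / (sample_alpha + sample_beta)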

    def _mixup_samples(self, images):
        batch_size = tf.shape(images)[0]
        permutation_order = tf.random.shuffle(tf.range(0, batch_size), seed=self.seed)

        lambda_sample = MixUp._sample_from_beta(self.alpha, self.alpha, (batch_size,))
        lambda_sample = tf.reshape(lambda_sample, [-1, 1, 1, 1])

        mixup_images = tf.gather(images, permutation_order)
        images = lambda_sample * images + (1.0 - lambda_sample) * mixup_images

        return images, tf.squeeze(lambda_sample), permutation_order

    def _mixup_labels(self, labels, lambda_sample, permutation_order):
        labels_for_mixup = tf.gather(labels, permutation_order)

        lambda_sample = tf.reshape(lambda_sample, [-1, 1])
        labels = lambda_sample * labels + (1.0 - lambda_sample) * labels_for_mixup

        return labels

    def call(self, batch_inputs):
        bs_images = tf.cast(batch_inputs[0], dtype=tf.float32)  
        bs_labels = tf.cast(batch_inputs[1], dtype=tf.float32)  

        mixup_images, lambda_sample, permutation_order = self._mixup_samples(bs_images)
        mixup_labels = self._mixup_labels(bs_labels, lambda_sample, permutation_order)

        return [mixup_images, mixup_labels]

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "alpha": self.alpha,
                "seed": self.seed,
            }
        )
        return config
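The gamma-ratio construction only yields $\mathrm{Beta}(\alpha, \alpha)$ when $\alpha$ is the gamma *concentration*. If the concentration is fixed at `1.0` and $\alpha$ is passed as the rate instead (as `tf.random.gamma(shape, 1.0, beta=alpha)` does), the rate cancels in the ratio and every draw is $\mathrm{Uniform}(0, 1)$, whatever $\alpha$ is. A minimal NumPy sketch makes the difference visible through the variance, which is $1/(4(2\alpha + 1))$ for $\mathrm{Beta}(\alpha, \alpha)$ and $1/12$ for the uniform. NumPy parameterizes the gamma by shape and scale, so a rate $r$ becomes `scale=1/r`; the helper names are hypothetical.

```python
import numpy as np


def beta_via_gamma(alpha, size, rng):
    # Correct: alpha is the gamma shape (concentration), scale fixed at 1.
    x = rng.gamma(shape=alpha, scale=1.0, size=size)
    y = rng.gamma(shape=alpha, scale=1.0, size=size)
    return x / (x + y)


def ratio_with_alpha_as_rate(alpha, size, rng):
    # Shape fixed at 1 and alpha used as the rate: the common rate cancels
    # in x / (x + y), so the result is Beta(1, 1) = Uniform(0, 1).
    x = rng.gamma(shape=1.0, scale=1.0 / alpha, size=size)
    y = rng.gamma(shape=1.0, scale=1.0 / alpha, size=size)
    return x / (x + y)


rng = np.random.default_rng(0)
alpha = 0.2
n = 200_000
beta_var = beta_via_gamma(alpha, n, rng).var()
uniform_var = ratio_with_alpha_as_rate(alpha, n, rng).var()
print(beta_var, 1 / (4 * (2 * alpha + 1)))  # both close to 0.179
print(uniform_var, 1 / 12)                   # both close to 0.083
```

With a small $\alpha$ such as `0.2`, a true Beta draw concentrates near 0 and 1, so most mixed images stay close to one of the two originals; the uniform draw mixes them much more evenly.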


In [None]:
temp_ds = train_set.map(
    lambda x, y: MixUp()([x, y]), num_parallel_calls=params.autotune
)
make_plot(temp_ds, take_batch=1, title=False) 

## CutMix

Implemented in `tf.keras`.

In [None]:
class CutMix(layers.Layer):
    """Original implementation: https://github.com/keras-team/keras-cv.
    The original implementation provides more interfaces for applying CutMix to
    various CV tasks, e.g. object detection. It also performs many useful
    validation checks.
    
    Derived and modified for simpler usages: M.Innat.
    Ref. https://gist.github.com/innat/0524ee77de17f0601f0dee69aa52c713
    """

    def __init__(self, alpha=1.0, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.seed = seed

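    @staticmethod
    def _sample_from_beta(alpha, beta, shape):
        # Beta(alpha, beta) = X / (X + Y) with X ~ Gamma(alpha), Y ~ Gamma(beta).
        # alpha and beta must be the gamma concentrations, not the rates.
        sample_alpha = tf.random.gamma(shape, alpha=alpha)
        sample_beta = tf.random.gamma(shape, alpha=beta)
        return sample_alpha / (sample_alpha + sample_beta)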

    def _cutmix_labels(self, labels, lambda_sample, permutation_order):
        cutout_labels = tf.gather(labels, permutation_order)

        lambda_sample = tf.reshape(lambda_sample, [-1, 1])
        labels = lambda_sample * labels + (1.0 - lambda_sample) * cutout_labels
        return labels

    def _cutmix_samples(self, images):
        input_shape = tf.shape(images)
        batch_size, image_height, image_width = (
            input_shape[0],
            input_shape[1],
            input_shape[2],
        )

        permutation_order = tf.random.shuffle(tf.range(0, batch_size), seed=self.seed)
        lambda_sample = CutMix._sample_from_beta(self.alpha, self.alpha, (batch_size,))

        ratio = tf.math.sqrt(1 - lambda_sample)

        cut_height = tf.cast(
            ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32
        )
        cut_width = tf.cast(
            ratio * tf.cast(image_width, dtype=tf.float32), dtype=tf.int32
        )

        random_center_height = tf.random.uniform(
            shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32
        )
        random_center_width = tf.random.uniform(
            shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32
        )

        bounding_box_area = cut_height * cut_width
        lambda_sample = 1.0 - bounding_box_area / (image_height * image_width)
        lambda_sample = tf.cast(lambda_sample, dtype=tf.float32)

        images = self.fill_rectangle(
            images,
            random_center_width,
            random_center_height,
            cut_width,
            cut_height,
            tf.gather(images, permutation_order),
        )

        return images, lambda_sample, permutation_order

    def call(self, batch_inputs, training=None):
        bs_images = tf.cast(batch_inputs[0], dtype=tf.float32)  
        bs_labels = tf.cast(batch_inputs[1], dtype=tf.float32)  

        cutmix_images, lambda_sample, permutation_order = self._cutmix_samples(
            bs_images
        )
        cutmix_labels = self._cutmix_labels(bs_labels, lambda_sample, permutation_order)

        return [cutmix_images, cutmix_labels]

    def fill_rectangle(
        self, images, centers_x, centers_y, widths, heights, fill_values
    ):
        images_shape = tf.shape(images)
        images_height = images_shape[1]
        images_width = images_shape[2]

        xywh = tf.stack([centers_x, centers_y, widths, heights], axis=1)
        xywh = tf.cast(xywh, tf.float32)
        corners = self.convert_format(xywh)
        mask_shape = (images_width, images_height)

        is_rectangle = self.corners_to_mask(corners, mask_shape)
        is_rectangle = tf.expand_dims(is_rectangle, -1)
        images = tf.where(is_rectangle, fill_values, images)
        return images

    def convert_format(self, boxes):
        boxes = tf.cast(boxes, dtype=tf.float32)
        x, y, width, height, rest = tf.split(boxes, [1, 1, 1, 1, -1], axis=-1)
        results = tf.concat(
            [
                x - width / 2.0,
                y - height / 2.0,
                x + width / 2.0,
                y + height / 2.0,
                rest,
            ],
            axis=-1,
        )
        return results

    def _axis_mask(self, starts, ends, mask_len):
        # index range of axis
        batch_size = tf.shape(starts)[0]
        axis_indices = tf.range(mask_len, dtype=starts.dtype)
        axis_indices = tf.expand_dims(axis_indices, 0)
        axis_indices = tf.tile(axis_indices, [batch_size, 1])

        # mask of index bounds
        axis_mask = tf.greater_equal(axis_indices, starts) & tf.less(axis_indices, ends)
        return axis_mask

    def corners_to_mask(self, bounding_boxes, mask_shape):
        mask_width, mask_height = mask_shape
        x0, y0, x1, y1 = tf.split(bounding_boxes, [1, 1, 1, 1], axis=-1)

        w_mask = self._axis_mask(x0, x1, mask_width)
        h_mask = self._axis_mask(y0, y1, mask_height)

        w_mask = tf.expand_dims(w_mask, axis=1)
        h_mask = tf.expand_dims(h_mask, axis=2)
        masks = tf.logical_and(w_mask, h_mask)
        return masks
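`_cutmix_samples` sets the label weight from the full box area, `1 - cut_height * cut_width / (H * W)`. Because the box center is sampled anywhere in the image, part of the box can fall outside it, and `corners_to_mask` then pastes only the part that lies inside. The minimal NumPy sketch below counts the pixels actually pasted, using the same `start <= i < end` rule as `_axis_mask`; `clipped_box_area` is a hypothetical helper, and using its result for the label weight would be a modification, not what the layer currently does.

```python
import numpy as np


def clipped_box_area(center_x, center_y, width, height, image_width, image_height):
    """Pixels covered by a box after clipping to the image.

    An index i is inside the box when start <= i < end, mirroring _axis_mask.
    """
    cols = np.arange(image_width)
    rows = np.arange(image_height)
    in_x = (cols >= center_x - width / 2.0) & (cols < center_x + width / 2.0)
    in_y = (rows >= center_y - height / 2.0) & (rows < center_y + height / 2.0)
    return int(in_x.sum()) * int(in_y.sum())


H = W = 384
cut = 100
# Box centred in the image: fully inside, 100 x 100 pixels are pasted.
print(clipped_box_area(192, 192, cut, cut, W, H))  # 10000
# Box centred on the top-left corner: only a 50 x 50 quarter is pasted.
print(clipped_box_area(0, 0, cut, cut, W, H))      # 2500

# Label weight of the original image based on the area actually pasted.
lam = 1.0 - clipped_box_area(0, 0, cut, cut, W, H) / (H * W)
```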

In [None]:
temp_ds = train_set.map(
    lambda x, y: CutMix()([x, y]), num_parallel_calls=params.autotune
)
make_plot(temp_ds, take_batch=1, title=False) 

In [None]:
class RandomMixUpCutMix(layers.Layer):
    """Randomly apply either MixUp or CutMix to a batch."""

    def __init__(self, switch_prob=0.50, **kwargs):
        super().__init__(**kwargs)
        # Probability of applying MixUp; CutMix is applied otherwise.
        self.switch_prob = switch_prob
        self.mixup = MixUp()
        self.cutmix = CutMix()

    def call(self, batch_inputs):
        return tf.cond(
            tf.less(
                tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
                tf.cast(self.switch_prob, tf.float32),
            ),
            lambda: self.mixup(batch_inputs),
            lambda: self.cutmix(batch_inputs),
        )

    def get_config(self):
        config = {
            "switch_prob": self.switch_prob,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

## Channel Shuffle (jax2tf)

Implemented in `jax`.

In [None]:
import jax
from jax import jit
from jax import random
from jax import numpy as jnp
from jax.experimental import jax2tf

In [None]:
class RandomChannelShuffle(layers.Layer):
    """Shuffle channels of an input image.

    Ref. https://gist.github.com/innat/35ab35329e2ca890a17556384056334b
    """
    def __init__(self, groups=3, **kwargs):
        super().__init__(**kwargs)
        self.groups = groups

    @partial(jit, static_argnums=0)
    def _jax_channel_shuffling(self, images, seed):
        batch_size, height, width, num_channels = images.shape

        if not num_channels % self.groups == 0:
            raise ValueError(
                "The number of input channels should be "
                "divisible by the number of groups. "
                f"Received: channels={num_channels}, groups={self.groups}"
            )

        channels_per_group = num_channels // self.groups

        images = images.reshape(-1, height, width, self.groups, channels_per_group)
        images = images.transpose([3, 1, 2, 4, 0])
        # The seed is a traced input, so each call can draw a new permutation
        # (np.random inside a jitted function would run only once, at trace time).
        key = random.PRNGKey(seed)
        images = random.permutation(key=key, x=images, axis=0)
        images = images.transpose([4, 1, 2, 3, 0])
        images = images.reshape(-1, height, width, num_channels)
        return images

    def call(self, images, training=True):
        if training:
            # Draw the seed in TensorFlow so that it changes on every batch.
            seed = tf.random.uniform([], maxval=2**31 - 1, dtype=tf.int32)
            return jax2tf.convert(
                self._jax_channel_shuffling,
                polymorphic_shapes=["batch, ...", None],
            )(images, seed)
        else:
            return images

    def get_config(self):
        config = {
            "groups": self.groups,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
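When a JAX function is `jit`-compiled, or converted with `jax2tf` and traced by `tf.data`, plain NumPy calls inside it run only once, while the function is being traced. A key built from `np.random.randint(...)` inside such a function therefore becomes a constant, and every later call reuses the same permutation. The minimal sketch below (hypothetical helpers `permute_trace_time_key` and `permute_with_key`) contrasts this with passing the key as a traced argument, which is the usual JAX pattern for fresh randomness on every call.

```python
import numpy as np
import jax
from jax import numpy as jnp
from jax import random


@jax.jit
def permute_trace_time_key(x):
    # np.random runs only while JAX traces this function, so the key is
    # baked into the compiled computation as a constant.
    key = random.PRNGKey(np.random.randint(50))
    return random.permutation(key, x)


@jax.jit
def permute_with_key(key, x):
    # The key is a traced argument, so each call can use a different one.
    return random.permutation(key, x)


x = jnp.arange(10)
a = permute_trace_time_key(x)
b = permute_trace_time_key(x)  # compiled-cache hit: identical to `a`

keys = random.split(random.PRNGKey(0), 5)
outs = [permute_with_key(k, x) for k in keys]  # generally all different
```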

In [None]:
temp_ds = train_set.map(
    lambda x, y: (RandomChannelShuffle()(x), y), num_parallel_calls=params.autotune
)
make_plot(temp_ds, take_batch=1, title=False) 

## GrayScaling (jax2tf)

Implemented in `jax`.

In [None]:
class RandomGrayscale(layers.Layer):
    """Grayscale is a preprocessing layer that transforms
    RGB images to Grayscale images.

    Ref. https://gist.github.com/innat/4e89725ccdcd763e0a6ba19216fd60bf
    """

    def __init__(self, output_channel=1, prob=1, **kwargs):
        super().__init__(**kwargs)
        self.output_channel = self._check_input_params(output_channel)

    def _check_input_params(self, output_channels):
        if output_channels not in [1, 3]:
            raise ValueError(
                "Received invalid argument output_channels. "
                f"output_channels must be 1 or 3. Got {output_channels}"
            )
        return output_channels

    @partial(jit, static_argnums=0)
    def _jax_gray_scale(self, images):
        rgb_weights = jnp.array([0.2989, 0.5870, 0.1140], dtype=images.dtype)
        grayscale = (rgb_weights * images).sum(axis=-1)

        if self.output_channel == 1:
            grayscale = jnp.expand_dims(grayscale, axis=-1)
            return grayscale
        elif self.output_channel == 3:
            return jnp.stack([grayscale] * 3, axis=-1)
        else:
            raise ValueError("Unsupported value for `output_channels`.")

    def call(self, images, training=True):
        if training:
            return jax2tf.convert(
                self._jax_gray_scale, polymorphic_shapes=("batch, ...")
            )(images)
        else:
            return images

    def get_config(self):
        config = {
            "output_channel": self.output_channel,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


In [None]:
temp_ds = train_set.map(
    lambda x, y: (RandomGrayscale(output_channel=3)(x), y), num_parallel_calls=params.autotune
)
make_plot(temp_ds, take_batch=1, title=False) 

**Combine All Augmentation**

In [None]:
jax_to_keras_augment = keras.Sequential(
    [
        RandomApply(RandomGrayscale(output_channel=3), probability=0.2),
        RandomApply(RandomChannelShuffle(), probability=0.5),
    ],
    name="jax2keras_augment",
)


tf_to_keras_augment = keras.Sequential(
    [
        RandomApply(layers.RandomFlip("horizontal"), probability=0.5),
        RandomApply(layers.RandomZoom(0.2, 0.3), probability=0.2),
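        RandomApply(
            # factor is a fraction of 2*pi: (0.2, 0.3) rotates the image
            # counter-clockwise by 72 to 108 degrees.
            layers.RandomRotation((0.2, 0.3), fill_mode="reflect"), probability=0.8
        ),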
    ],
    name="tf2keras_augment",
)


In [None]:
# for train set : augmentation
keras_aug = keras.Sequential(
    [
        layers.Resizing(height=params.image_size, width=params.image_size),
        jax_to_keras_augment,
        tf_to_keras_augment,
    ],
    name="keras_augment",
)

# `train_set` is already batched, so this shuffles the order of whole batches.
train_ds = train_set.shuffle(10 * params.batch_size)
train_ds = train_ds.map(
    lambda x, y: (keras_aug(x), y), num_parallel_calls=params.autotune
)
train_ds = train_ds.map(
    lambda x, y: RandomMixUpCutMix()([x, y]), num_parallel_calls=params.autotune
)


**Visualize Augmented Training Set**

In [None]:
make_plot(train_ds, take_batch=5, title=False) 

**Visualize Validation Set**

In [None]:
make_plot(val_set, take_batch=1, title=False) 

In [None]:
# Overlaps data preprocessing and model execution while training.
# It often improves latency and throughput, at the cost of using 
# additional memory to store prefetched elements.
train_ds = train_ds.prefetch(buffer_size=params.autotune)
val_ds = val_set.prefetch(buffer_size=params.autotune)

# Modeling

We use the *EfficientNet B0* model and feed the output of its intermediate layer `block6a_expand_activation` into the *Swin Transformer blocks*. Since the input size of *EfficientNet B0* is set to `384`, this layer outputs a `24 × 24` feature map, which becomes the input of the **swin-vit** (see the diagram below). Both the **2D CNN output** and the **2D Swin Transformer output** are retrieved to compute the **GradCAM**.

![](https://user-images.githubusercontent.com/17668390/173227753-36cbf5a5-ee56-4202-bbb2-6020a388b188.png)

<div align="center">
    <i>Figure: Hybrid-EfficientNet-Swin Transformer.</i>
</div>
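To confirm the `24 × 24` size quoted above, the backbone can be built without pretrained weights and the layer's output shape inspected. This is a minimal sketch assuming `tf.keras.applications.EfficientNetB0`; the expected shape reflects EfficientNet-B0's stride of 16 at this point (384 / 16 = 24) and block 6a's expansion of 112 input channels by a factor of 6 (672 channels).

```python
from tensorflow import keras

# weights=None skips the ImageNet download; only the architecture is needed
# to inspect shapes.
backbone = keras.applications.EfficientNetB0(
    include_top=False, weights=None, input_shape=(384, 384, 3)
)
feature_map = backbone.get_layer("block6a_expand_activation").output
print(feature_map.shape)  # (None, 24, 24, 672)
```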

In [None]:
patch_size      = (2,2)   # 2-by-2 sized patches
dropout_rate    = 0.5     # Dropout rate
num_heads       = 8       # Attention heads
embed_dim       = 64      # Embedding dimension
num_mlp         = 128     # MLP layer size
qkv_bias        = True    # Use a bias in the query/key/value projection
window_size     = 2       # Size of attention window
shift_size      = 1       # Size of shifting window
image_dimension = 24      # Initial image size / Input size of the transformer model 

num_patch_x = image_dimension // patch_size[0]
num_patch_y = image_dimension // patch_size[1]
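These hyperparameters fix the token and window counts seen by the transformer blocks. A small arithmetic sketch in plain Python; `merged_grid` and `merged_dim` are hypothetical names and assume a `PatchMerging` layer with `embed_dim=64` is applied after the Swin blocks.

```python
# Values from the hyperparameter cell above.
image_dimension = 24
patch_size = (2, 2)
window_size = 2
embed_dim = 64
num_heads = 8

num_patch_x = image_dimension // patch_size[0]    # 12
num_patch_y = image_dimension // patch_size[1]    # 12
num_tokens = num_patch_x * num_patch_y            # 144 tokens per image
tokens_per_window = window_size * window_size     # 4 tokens attend to each other
num_windows = (num_patch_x // window_size) * (num_patch_y // window_size)  # 36
head_dim = embed_dim // num_heads                 # 8 channels per head
scale = head_dim ** -0.5                          # query scaling, about 0.354

# Assumed PatchMerging step: 2x2 neighbours merged, channels projected to 2C.
merged_grid = (num_patch_x // 2, num_patch_y // 2)  # (6, 6)
merged_dim = 2 * embed_dim                          # 128
```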

In [None]:
def window_partition(x, window_size):
    _, height, width, channels = x.shape
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = tf.reshape(
        x, shape=(-1, patch_num_y, window_size, patch_num_x, window_size, channels)
    )
    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
    windows = tf.reshape(x, shape=(-1, window_size, window_size, channels))
    return windows


def window_reverse(windows, window_size, height, width, channels):
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = tf.reshape(
        windows,
        shape=(-1, patch_num_y, patch_num_x, window_size, window_size, channels),
    )
    x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))
    x = tf.reshape(x, shape=(-1, height, width, channels))
    return x


class DropPath(layers.Layer):
    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.drop_prob = drop_prob

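    def call(self, inputs, training=None):
        if self.drop_prob == 0.0 or not training:
            return inputs
        else:
            batch_size = tf.shape(inputs)[0]
            keep_prob = 1 - self.drop_prob
            # One keep/drop decision per sample, broadcast over the other axes.
            # Use the static rank: len() is undefined for symbolic tensors.
            path_mask_shape = (batch_size,) + (1,) * (inputs.shape.rank - 1)
            path_mask = backend.random_bernoulli(
                path_mask_shape, p=keep_prob, dtype=inputs.dtype
            )
            # Rescale kept paths so the expected output equals the input,
            # keeping the input dtype (e.g. float16 under mixed precision).
            outputs = tf.math.divide(inputs, tf.cast(keep_prob, inputs.dtype)) * path_mask
            return outputs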

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "drop_prob": self.drop_prob,
            }
        )
        return config
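`window_partition` and `window_reverse` are pure reshape/transpose operations, so one should exactly undo the other. The NumPy mirrors below (hypothetical names `window_partition_np` and `window_reverse_np`, following the same axis order as the TensorFlow functions) make the layout easy to check: windows are enumerated row by row, and each one is a contiguous `window_size × window_size` tile of the token grid.

```python
import numpy as np


def window_partition_np(x, window_size):
    # Mirror of window_partition: (B, H, W, C) -> (B * nH * nW, ws, ws, C).
    b, h, w, c = x.shape
    x = x.reshape(b, h // window_size, window_size, w // window_size, window_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, window_size, window_size, c)


def window_reverse_np(windows, window_size, h, w, c):
    # Mirror of window_reverse: undoes window_partition_np.
    x = windows.reshape(-1, h // window_size, w // window_size, window_size, window_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, h, w, c)


# A 12 x 12 patch grid with 64 channels, as configured above, batch of 2.
x = np.arange(2 * 12 * 12 * 64).reshape(2, 12, 12, 64)
windows = window_partition_np(x, 2)
print(windows.shape)  # (72, 2, 2, 64): 36 windows per image
restored = window_reverse_np(windows, 2, 12, 12, 64)
```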

In [None]:
class PatchExtract(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size_x = patch_size[0]
        self.patch_size_y = patch_size[1]

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # `sizes` and `strides` are ordered (1, rows, cols, 1).
        patches = tf.image.extract_patches(
            images=images,
            sizes=(1, self.patch_size_y, self.patch_size_x, 1),
            strides=(1, self.patch_size_y, self.patch_size_x, 1),
            rates=(1, 1, 1, 1),
            padding="VALID",
        )
        patch_dim = patches.shape[-1]
        num_patches = patches.shape[1] * patches.shape[2]
        return tf.reshape(patches, (batch_size, num_patches, patch_dim))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": (self.patch_size_x, self.patch_size_y),
            }
        )
        return config


class PatchEmbedding(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        # Linear projection of each flattened patch plus a learned position embedding.
        pos = tf.range(start=0, limit=self.num_patch, delta=1)
        return self.proj(patch) + self.pos_embed(pos)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_patch": self.num_patch,
                "embed_dim": self.embed_dim,
            }
        )
        return config


class PatchMerging(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)

    def call(self, x):
        height, width = self.num_patch
        _, _, C = x.get_shape().as_list()
        x = tf.reshape(x, shape=(-1, height, width, C))
        feat_maps = x  # (B, H, W, C) token map before merging, also returned

        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = tf.concat((x0, x1, x2, x3), axis=-1)
        x = tf.reshape(x, shape=(-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x), feat_maps

    def get_config(self):
        config = super().get_config()
        config.update({"num_patch": self.num_patch, "embed_dim": self.embed_dim})
        return config
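`PatchMerging` halves the token grid in each direction by concatenating every 2 × 2 neighbourhood along the channel axis (giving `4 * C` channels) before projecting to `2 * embed_dim`. A NumPy mirror of that reshaping, without the `Dense` projection, shows which tokens end up together. `patch_merging_np` is a hypothetical name used only for illustration, and the `12 × 12 × 64` input matches the patch grid and `embed_dim` set earlier.

```python
import numpy as np


def patch_merging_np(tokens, height, width):
    # Mirror of PatchMerging.call without the final Dense projection:
    # (B, H*W, C) -> (B, H/2 * W/2, 4C).
    b, _, c = tokens.shape
    x = tokens.reshape(b, height, width, c)
    x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 block
    x1 = x[:, 1::2, 0::2, :]  # bottom-left
    x2 = x[:, 0::2, 1::2, :]  # top-right
    x3 = x[:, 1::2, 1::2, :]  # bottom-right
    x = np.concatenate([x0, x1, x2, x3], axis=-1)
    return x.reshape(b, (height // 2) * (width // 2), 4 * c)


tokens = np.arange(2 * 144 * 64, dtype=np.float32).reshape(2, 144, 64)
merged = patch_merging_np(tokens, 12, 12)
print(merged.shape)  # (2, 36, 256); the Dense layer then maps 256 -> 128
```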

In [None]:
class WindowAttention(layers.Layer):
    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        dropout_rate=0.0,
        return_attention_scores=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.return_attention_scores = return_attention_scores
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

    def build(self, input_shape):
        self.relative_position_bias_table = self.add_weight(
            shape=(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                self.num_heads,
            ),
            initializer="zeros",
            trainable=True,
            name="relative_position_bias_table",
        )

        self.relative_position_index = self.get_relative_position_index(
            self.window_size[0], self.window_size[1]
        )
        super().build(input_shape)

    def get_relative_position_index(self, window_height, window_width):
        x_x, y_y = tf.meshgrid(range(window_height), range(window_width))
        coords = tf.stack([y_y, x_x], axis=0)
        coords_flatten = tf.reshape(coords, [2, -1])

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])

        x_x = (relative_coords[:, :, 0] + window_height - 1) * (2 * window_width - 1)
        y_y = relative_coords[:, :, 1] + window_width - 1
        relative_coords = tf.stack([x_x, y_y], axis=-1)

        return tf.reduce_sum(relative_coords, axis=-1)

    def call(self, x, mask=None):
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
        x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        attn = q @ k

        relative_position_bias = tf.gather(
            self.relative_position_bias_table,
            self.relative_position_index,
            axis=0,
        )
        relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.get_shape()[0]
            mask_float = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32
            )
            attn = (
                tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
                + mask_float
            )
            attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))
            attn = tf.nn.softmax(attn, axis=-1)
        else:
            attn = tf.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)

        if self.return_attention_scores:
            return x_qkv, attn
        else:
            return x_qkv

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "dim": self.dim,
                "window_size": self.window_size,
                "num_heads": self.num_heads,
                "scale": self.scale,
            }
        )
        return config
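
The relative-position indexing above is easier to follow on a tiny window. The NumPy sketch below is an illustration only, and the helper name `relative_position_index` is hypothetical. It follows the standard Swin recipe for a 2×2 window: compute pairwise (query minus key) coordinate offsets, shift them so they start at 0, and flatten each (dy, dx) pair into one row index of a bias table with `(2*Wh - 1) * (2*Ww - 1)` rows.

```python
import numpy as np


def relative_position_index(window_height, window_width):
    """Map every (query, key) token pair in a window to a row of the bias table."""
    # Grid coordinates of each token, shape (2, Wh, Ww) -> (2, Wh * Ww).
    coords = np.stack(
        np.meshgrid(np.arange(window_height), np.arange(window_width), indexing="ij")
    )
    coords_flatten = coords.reshape(2, -1)

    # Pairwise offsets (query - key), shape (N, N, 2).
    relative = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative = relative.transpose(1, 2, 0)

    # Shift offsets to start at 0, then flatten (dy, dx) into a single index.
    relative[:, :, 0] += window_height - 1
    relative[:, :, 1] += window_width - 1
    relative[:, :, 0] *= 2 * window_width - 1
    return relative.sum(axis=-1)


index = relative_position_index(2, 2)
print(index.shape)     # (4, 4): one entry per (query, key) pair
print(index.max())     # 8: the bias table has (2*2 - 1) * (2*2 - 1) = 9 rows
print(np.diag(index))  # each token paired with itself uses the same "zero offset" row
```

Because a pair (i, j) has the opposite offset of (j, i), the two indices always sum to twice the "zero offset" row, which is a quick sanity check on the table layout.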

In [None]:
class SwinTransformer(layers.Layer):
    def __init__(
        self, 
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super(SwinTransformer, self).__init__(**kwargs)

        self.dim = dim
        self.num_patch = num_patch
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_mlp = num_mlp

        # If the patch grid is smaller than the window, shrink the window to
        # the grid and disable shifting. This must happen before
        # WindowAttention is created so that its relative position table
        # matches the window size actually used in `call`.
        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = (
            DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

    def build(self, input_shape):
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            # NumPy array (mutable), so the region labels can be assigned in place
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = tf.convert_to_tensor(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size * self.window_size]
            )
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)

    def call(self, x):
        height, width = self.num_patch
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        return x
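
The attention mask built in `build()` is easiest to see on a toy grid. The NumPy sketch below is an illustration, not the model code, and the helper names `window_partition_np` and `shifted_window_mask` are hypothetical. It repeats the same steps for a 4×4 patch grid with `window_size=2` and `shift_size=1`: label the regions created by the cyclic shift, partition the label map into windows, and set every token pair with different labels to `-100` so that the softmax effectively ignores it.

```python
import numpy as np


def window_partition_np(x, window_size):
    """Split (B, H, W, C) into (num_windows * B, window_size, window_size, C)."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // window_size, window_size, w // window_size, window_size, c)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, c)


def shifted_window_mask(height, width, window_size, shift_size):
    # Label each region produced by the cyclic shift with a distinct integer.
    img_mask = np.zeros((1, height, width, 1))
    slices = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    count = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = count
            count += 1

    # Flatten the label map window by window: (num_windows, window_size**2).
    windows = window_partition_np(img_mask, window_size).reshape(-1, window_size**2)

    # Token pairs from different regions get -100 (blocked); same region gets 0.
    diff = windows[:, None, :] - windows[:, :, None]
    return np.where(diff != 0, -100.0, 0.0)


mask = shifted_window_mask(height=4, width=4, window_size=2, shift_size=1)
print(mask.shape)  # (4, 4, 4): 4 windows with 4 tokens each
print(mask[0])     # top-left window lies in a single region, so nothing is masked
print(mask[3])     # bottom-right window mixes 4 regions: only the diagonal stays 0
```

Only windows that straddle the wrapped-around border after `tf.roll` receive blocked entries, which is why the unshifted blocks (`shift_size == 0`) can skip the mask entirely.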

# Hybrid-EfficientNet-Swin-Transformer

In [None]:
# Base CNN model (EfficientNetB0 backbone)
base = keras.applications.EfficientNetB0(
    include_top=False,
    weights='imagenet',
    input_tensor=keras.Input((params.image_size, params.image_size, 3)),
)


In [None]:
class HybridModel(keras.Model):
    def __init__(self, model_name, **kwargs):
        super().__init__(name=model_name, **kwargs)

        # CNN backbone with two outputs: a mid-level feature map (the input to
        # the transformer blocks) and the final feature map (used for Grad-CAM)
        self.multi_output_cnn = keras.Model(
            [base.inputs],
            [base.get_layer("block6a_expand_activation").output, base.output],
            name="efficientnet",
        )

        # base model's (cnn model) head
        self.conv_head = keras.Sequential(
            [
                layers.GlobalAveragePooling2D(),
                layers.AlphaDropout(0.5),
            ],
            name="conv_head",
        )

        # Swin Transformer components
        self.patch_extract = PatchExtract(patch_size)
        self.patch_embedds = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)
        self.patch_merging = PatchMerging(
            (num_patch_x, num_patch_y), embed_dim=embed_dim
        )

        # container for the Swin blocks
        self.swin_sequences = keras.Sequential(name="swin_blocks")
        for i in range(shift_size):
            self.swin_sequences.add(
                SwinTransformer(
                    dim=embed_dim,
                    num_patch=(num_patch_x, num_patch_y),
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=i,
                    num_mlp=num_mlp,
                    qkv_bias=qkv_bias,
                    dropout_rate=dropout_rate,
                )
            )

        # swin block's head
        self.swin_head = keras.Sequential(
            [
                layers.GlobalAveragePooling1D(),
                layers.AlphaDropout(0.5),
                layers.BatchNormalization(),
            ],
            name="swin_head",
        )

        # classifier
        self.classifier = layers.Dense(
            params.class_number, activation=None, dtype="float32"
        )

        # build the graph
        self.build_graph()

    def forward_cnn(self, inputs):
        # CNN model.
        return self.multi_output_cnn(inputs)

    def forward_transformer(self, inputs):
        # Transformer model.
        x = self.patch_extract(inputs)
        x = self.patch_embedds(x)
        x = self.swin_sequences(tf.cast(x, dtype=tf.float32))
        x, swin_gcam_top = self.patch_merging(x)
        return x, swin_gcam_top

    def call(self, inputs, training=None, **kwargs):
        cnn_mid_layer, cnn_gcam_top = self.forward_cnn(inputs)
        transformer_output, transformer_gcam_top = self.forward_transformer(
            cnn_mid_layer
        )

        transformer_output = self.swin_head(transformer_output)
        cnn_output = self.conv_head(cnn_gcam_top)
        logits = self.classifier(tf.concat([transformer_output, cnn_output], axis=-1))

        if training:
            return logits
        else:
            return logits, cnn_gcam_top, transformer_gcam_top

    def build_graph(self):
        x = keras.Input(shape=(params.image_size, params.image_size, 3))
        return keras.Model(inputs=[x], outputs=self.call(x))


## Implement Gradient Accumulation

We also implement the gradient accumulation technique in our model-building pipeline. Transformer-based models are usually computationally expensive, which limits the batch size that fits in memory. To work around this, we split a large batch into smaller mini-batches that are run sequentially while their results are accumulated. Gradient accumulation computes the loss and gradients after each mini-batch but, instead of updating the model parameters right away, it accumulates the gradients over consecutive mini-batches and applies them only once. This lets us train with less memory while approximating the effect of a large batch size. For example, running gradient accumulation with **8** steps and a batch size of **8** images serves nearly the same purpose as running with a batch size of **64** images.


<img src="https://miro.medium.com/max/1050/1*rJIH9gPhctTLCk5G5iQ_oA.png" width="500" height="500" />

[source.](https://towardsdatascience.com/what-is-gradient-accumulation-in-deep-learning-ec034122cfa)
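
The claim that accumulating 8 mini-batches of 8 approximates one batch of 64 can be checked numerically. The NumPy sketch below uses a toy linear-regression loss; the data, sizes, and the helper name `mse_grad` are illustrative assumptions. It shows that, for a loss averaged over samples, the mean of the 8 mini-batch gradients equals the gradient of the full 64-sample batch. Note that the `GradientAccumulation` class below sums the gradients instead of averaging them. With Adam, that constant factor is largely absorbed by the per-parameter normalization, but with plain SGD it would behave like a learning rate multiplied by `n_gradients`.

```python
import numpy as np


def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / len(x)


gen = np.random.default_rng(0)
x = gen.normal(size=(64, 3))  # one "large" batch of 64 samples
y = gen.normal(size=(64,))
w = gen.normal(size=(3,))

# Gradient of the full batch of 64.
full_grad = mse_grad(w, x, y)

# Gradient accumulation: 8 mini-batches of 8, summed as in train_step ...
n_gradients, mini_batch = 8, 8
accumulated = np.zeros_like(w)
for step in range(n_gradients):
    sl = slice(step * mini_batch, (step + 1) * mini_batch)
    accumulated += mse_grad(w, x[sl], y[sl])

# ... and averaged over the number of accumulation steps.
averaged = accumulated / n_gradients
print(np.allclose(averaged, full_grad))  # True
```

Layers that depend on batch statistics (such as `BatchNormalization`) still see only mini-batches of 8, which is one reason the equivalence is approximate rather than exact in the real model.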

In [None]:
class GradientAccumulation(HybridModel):
    """Subclass of the model class that overrides `train_step` to
    implement gradient accumulation.
    
    Ref: https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c
    """

    def __init__(self, n_gradients, **kwargs):
        super().__init__(**kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.trainable_variables
        ]

    def train_step(self, data):
        # track accumulation step update
        self.n_acum_step.assign_add(1)

        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Calculate batch gradients
        gradients = tape.gradient(loss, self.trainable_variables)

        # Accumulate batch gradients
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i])

        # If n_acum_step reaches n_gradients, apply the accumulated gradients
        # to update the variables; otherwise, do nothing.
        tf.cond(
            tf.equal(self.n_acum_step, self.n_gradients),
            self.apply_accu_gradients,
            lambda: None,
        )

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        # Update weights
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables)
        )

        # reset accumulation step
        self.n_acum_step.assign(0)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(
                tf.zeros_like(self.trainable_variables[i], dtype=tf.float32)
            )

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_pred, base_gcam_top, swin_gcam_top = self(x, training=False)

        # Updates the metrics tracking the loss
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_pred)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
    

In [None]:
def get_model(plot_summary=False, plot_graph=False):
    keras.backend.clear_session()
    model = GradientAccumulation(
        n_gradients=params.num_grad_accumulation, model_name="HybridModel"
    )

    if plot_summary:
        display(
            model.build_graph().summary()
        )

    if plot_graph:
        display(
            keras.utils.plot_model(
                model.build_graph(),
                show_shapes=True,
                show_layer_names=True,
                expand_nested=False,
            )
        )

    # compile
    model.compile(
        loss=losses.CategoricalCrossentropy(
            label_smoothing=params.label_smooth, from_logits=True
        ),
        optimizer=optimizers.Adam(learning_rate, amsgrad=True),
        metrics=["accuracy"],
    )
    return model

**Training Setup**

In [None]:
from tensorflow.keras import losses
from tensorflow.keras import metrics
from tensorflow.keras import callbacks
from tensorflow.keras import optimizers

ckp = callbacks.ModelCheckpoint(
    "model.h5",
    monitor="val_accuracy",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode="max",
)
log = callbacks.CSVLogger("history.csv", separator=",", append=False)

In [None]:
class WarmupLearningRateSchedule(optimizers.schedules.LearningRateSchedule):
    """Learning rate schedule that supports several decay strategies
    (exponential, cosine, linear, constant, cosine restarts) with linear warmup.
    
    Ref. https://gist.github.com/innat/69e8f3500c2418c69b150a0a651f31dc
    """

    def __init__(
        self,
        initial_lr,
        steps_per_epoch=None,
        lr_decay_type="exponential",
        decay_factor=0.97,
        decay_epochs=2.4,
        total_steps=None,
        warmup_epochs=5,
        minimal_lr=0, 
        **kwargs
    ):
        super().__init__(**kwargs)
        self.initial_lr = initial_lr
        self.steps_per_epoch = steps_per_epoch
        self.lr_decay_type = lr_decay_type
        self.decay_factor = decay_factor
        self.decay_epochs = decay_epochs
        self.total_steps = total_steps
        self.warmup_epochs = warmup_epochs
        self.minimal_lr = minimal_lr

    def __call__(self, step):
        if self.lr_decay_type == "exponential":
            assert self.steps_per_epoch is not None
            decay_steps = self.steps_per_epoch * self.decay_epochs
            lr = optimizers.schedules.ExponentialDecay(
                self.initial_lr, decay_steps, self.decay_factor, staircase=True
            )(step)
            
        elif self.lr_decay_type == "cosine":
            assert self.total_steps is not None
            lr = (
                0.5
                * self.initial_lr
                * (1 + tf.cos(np.pi * tf.cast(step, tf.float32) / self.total_steps))
            )

        elif self.lr_decay_type == "linear":
            assert self.total_steps is not None
            lr = (1.0 - tf.cast(step, tf.float32) / self.total_steps) * self.initial_lr

        elif self.lr_decay_type == "constant":
            lr = self.initial_lr

        elif self.lr_decay_type == "cosine_restart":
            decay_steps = self.steps_per_epoch * self.decay_epochs
            lr = tf.keras.experimental.CosineDecayRestarts(
                self.initial_lr, decay_steps
            )(step)
        else:
            assert False, "Unknown lr_decay_type : %s" % self.lr_decay_type

        if self.minimal_lr:
            lr = tf.math.maximum(lr, self.minimal_lr)

        if self.warmup_epochs:
            warmup_steps = int(self.warmup_epochs * self.steps_per_epoch)
            warmup_lr = (
                self.initial_lr
                * tf.cast(step, tf.float32)
                / tf.cast(warmup_steps, tf.float32)
            )
            lr = tf.cond(step < warmup_steps, lambda: warmup_lr, lambda: lr)

        return lr

    def get_config(self):
        return {
            "initial_lr": self.initial_lr,
            "steps_per_epoch": self.steps_per_epoch,
            "lr_decay_type": self.lr_decay_type,
            "decay_factor": self.decay_factor,
            "decay_epochs": self.decay_epochs,
            "total_steps": self.total_steps,
            "warmup_epochs": self.warmup_epochs,
            "minimal_lr": self.minimal_lr,
        }


In [None]:
learning_rate = WarmupLearningRateSchedule(
    params.scaled_lr,
    steps_per_epoch=params.train_step,
    decay_epochs=params.lr_decay_epoch,
    warmup_epochs=params.lr_warmup_epoch,
    decay_factor=params.lr_decay_factor,
    lr_decay_type=params.lr_sched,
    total_steps=params.total_steps,
    minimal_lr=params.scaled_lr_min,
)

rng = [i for i in range(params.total_steps)]
lr_y = [learning_rate(x) for x in rng]
plt.figure(figsize=(10, 4))
plt.plot(rng, lr_y)
plt.xlabel("Iteration", size=14)
plt.ylabel("Learning Rate", size=14)
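
To see what the plot encodes, here is a plain-Python sketch of the `"cosine"` branch combined with linear warmup, written to mirror the arithmetic in `WarmupLearningRateSchedule.__call__`. The function name `warmup_cosine_lr` and the example numbers are assumptions for illustration. One detail worth noticing: the cosine term is computed from step 0, not from the end of warmup, so when warmup finishes the learning rate drops to a value below `initial_lr` instead of starting the decay from `initial_lr`.

```python
import math


def warmup_cosine_lr(step, initial_lr, steps_per_epoch, total_steps,
                     warmup_epochs=5, minimal_lr=0.0):
    """Pure-Python mirror of the 'cosine' schedule with linear warmup."""
    # Cosine decay from initial_lr at step 0 down to 0 at total_steps.
    lr = 0.5 * initial_lr * (1 + math.cos(math.pi * step / total_steps))
    # Optional floor, applied before warmup as in the class above.
    if minimal_lr:
        lr = max(lr, minimal_lr)
    # Linear warmup overrides the decayed value for the first warmup_steps.
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    if warmup_epochs and step < warmup_steps:
        lr = initial_lr * step / warmup_steps
    return lr


for step in (0, 250, 499, 500, 1000, 2000):
    lr = warmup_cosine_lr(step, initial_lr=1e-3, steps_per_epoch=100, total_steps=2000)
    print(step, round(lr, 6))
```

With these numbers, warmup ends at step 500, where the schedule switches from about `0.998 * initial_lr` to `0.5 * initial_lr * (1 + cos(pi / 4))`, roughly `0.85 * initial_lr`.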

In [None]:
# training
if physical_devices:
    model = get_model(plot_summary=True, plot_graph=False)
    history = model.fit(
        train_ds,
        epochs=params.epochs,
        callbacks=[ckp, log],
        validation_data=val_ds,
        verbose=params.verbosity,
    ).history
    
    model.load_weights("./model.h5")
    display(pd.DataFrame(history).tail(5))
else:
    keras.mixed_precision.set_global_policy("float32")
    model = get_model(plot_summary=True, plot_graph=False)
    model(tf.ones((1, params.image_size, params.image_size, 3)))[0].shape

    # get trained weight and history file
    weight = get_model_weight(model_id="1y6tseN0194T6d-4iIh5wo7RL9ttQERe0")
    model.load_weights(weight)

    history_csv = get_model_history(history_id="1J6QgHUqtYL0mIWC2h0K6HeIHcollfRUe")
    history = pd.read_csv(history_csv)
    display(history.tail())


**Learning Curve**

In [None]:
# Plotting
plt.figure(figsize=(20, 10))
# `history` is a dict after fresh training and a DataFrame when read from CSV;
# counting the "loss" entries gives the number of epochs in both cases.
epochs = range(len(history["loss"]))
plt.plot(epochs, history["loss"], "-o", label="train_loss")
plt.plot(epochs, history["val_loss"], "-o", label="val_loss")
plt.plot(epochs, history["accuracy"], "-o", label="train_accuracy")
plt.plot(epochs, history["val_accuracy"], "-o", label="val_accuracy")
plt.title("Training Loss and Accuracy", fontdict={'fontsize':20})
plt.xlabel(f"Epochs ({len(epochs)})", fontsize=20)
plt.ylabel("Loss/Accuracy", fontsize=20)
plt.legend(loc="best", prop={"size": 20})
plt.tight_layout()
plt.show()

# Grad-CAM : Hybrid-EfficientNet-Swin-Transformer

In [None]:
def plot_stuff(inputs, features_a, features_b):
    plt.figure(figsize=(25, 25))
    
    plt.subplot(1, 3, 1)
    plt.axis('off')
    plt.imshow(tf.squeeze(inputs/255, axis=0))
    plt.title('Input')
    
    plt.subplot(1, 3, 2)
    plt.axis('off')
    plt.imshow(features_a)
    plt.title('CNN')
    
    plt.subplot(1, 3, 3)
    plt.axis('off')
    plt.imshow(features_b)
    plt.title('Hybrid-CNN-Transformer')
    plt.show()

# ref: https://keras.io/examples/vision/grad_cam/
def get_img_array(img):
    array = keras.utils.img_to_array(img)
    array = np.expand_dims(array, axis=0)
    return array

# ref: https://keras.io/examples/vision/grad_cam/
def make_gradcam_heatmap(img_array, grad_model, pred_index=None):
    with tf.GradientTape(persistent=True) as tape:
        preds, base_top, swin_top = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, base_top)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    base_top = base_top[0]
    heatmap_a = base_top @ pooled_grads[..., tf.newaxis]
    heatmap_a = tf.squeeze(heatmap_a)
    heatmap_a = tf.maximum(heatmap_a, 0) / tf.math.reduce_max(heatmap_a)
    heatmap_a = heatmap_a.numpy()

    grads = tape.gradient(class_channel, swin_top)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    swin_top = swin_top[0]
    heatmap_b = swin_top @ pooled_grads[..., tf.newaxis]
    heatmap_b = tf.squeeze(heatmap_b)
    heatmap_b = tf.maximum(heatmap_b, 0) / tf.math.reduce_max(heatmap_b)
    heatmap_b = heatmap_b.numpy()
    return heatmap_a, heatmap_b
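
The heatmap computation above follows the standard Grad-CAM recipe: average the gradients of the class score over the spatial axes to get one weight per channel, take the weighted sum of the feature-map channels, keep only the positive part, and normalize to [0, 1]. The NumPy sketch below applies those steps to a hypothetical 7×7×4 feature map with hand-made gradients, so the arithmetic can be inspected without a model. The helper name `gradcam_from_arrays` is illustrative.

```python
import numpy as np


def gradcam_from_arrays(feature_map, grads):
    """Grad-CAM heatmap from a (H, W, C) feature map and its (H, W, C) gradients."""
    # One importance weight per channel: the spatially averaged gradient.
    channel_weights = grads.mean(axis=(0, 1))  # shape (C,)
    # Weighted sum over channels gives an (H, W) class activation map.
    heatmap = feature_map @ channel_weights
    # ReLU keeps regions that increase the class score, then scale to [0, 1].
    heatmap = np.maximum(heatmap, 0)
    return heatmap / heatmap.max()


gen = np.random.default_rng(0)
features = gen.random((7, 7, 4))
grads = np.zeros((7, 7, 4))
grads[..., 0] = 1.0    # only channel 0 pushes the class score up
grads[..., 1] = -0.5   # channel 1 pushes it down

heatmap = gradcam_from_arrays(features, grads)
print(heatmap.shape, heatmap.min(), heatmap.max())
```

In `make_gradcam_heatmap`, the same steps run twice on the two feature maps returned by the model (`base_top` from the CNN and `swin_top` from the Swin blocks), which is what allows the side-by-side comparison.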

In [None]:
# Get val images
img_arrays = next(iter(val_ds))[0]
print(img_arrays.shape)

# Compute and display Grad-CAM heatmaps for the first three images
for img_array in img_arrays[:3]:
    # Generate class activation heatmap
    img_array = get_img_array(img_array)
    cnn_heatmap, swin_heatmap = make_gradcam_heatmap(img_array, model)
    print(cnn_heatmap.shape, cnn_heatmap.max(), cnn_heatmap.min())
    print(swin_heatmap.shape, swin_heatmap.max(), swin_heatmap.min())

    # Display heatmap
    plot_stuff(img_array, cnn_heatmap, swin_heatmap)

In [None]:
# ref: https://keras.io/examples/vision/grad_cam/
def save_and_display_gradcam(
    img,
    heatmap,
    target=None,
    pred=None,
    cam_path="cam.jpg",
    cmap="jet",  # inferno, viridis
    alpha=0.6,
    plot=None,
):
    # Rescale heatmap to a range 0-255
    heatmap = np.uint8(255 * heatmap)

    # Colorize the heatmap with the chosen colormap (jet by default)
    jet = plt.get_cmap(cmap)

    # Use RGB values of the colormap
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Create an image with RGB colorized heatmap
    jet_heatmap = keras.utils.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[0], img.shape[1]))
    jet_heatmap = keras.utils.img_to_array(jet_heatmap)

    # Superimpose the heatmap on original image
    superimposed_img = img + jet_heatmap * alpha
    superimposed_img = keras.utils.array_to_img(superimposed_img)
    return superimposed_img

In [None]:
samples, labels = next(iter(val_ds.shuffle(params.batch_size)))

for sample, label in zip(samples, labels):
    # preparing
    img_array = sample[tf.newaxis, ...]

    # get heatmaps
    heatmap_a, heatmap_b = make_gradcam_heatmap(img_array, model)

    # overlay each heatmap on the input sample
    overlay_a = save_and_display_gradcam(sample, heatmap_a)
    overlay_b = save_and_display_gradcam(sample, heatmap_b)

    # plotting
    plot_stuff(img_array, overlay_a, overlay_b)

# Model Saving and Reloading

Previously, we saved the model's weights using the `callbacks.ModelCheckpoint` API. We can also save the entire model in the TensorFlow `SavedModel` format (the recommended option). A `SavedModel` contains a complete TensorFlow program, including trained parameters (i.e., [tf.Variables](https://www.tensorflow.org/api_docs/python/tf/Variable)) and computation. It does not require the original model-building code to run, which makes it useful for sharing or deploying with [TFLite](https://www.tensorflow.org/lite), [TensorFlow.js](https://www.tensorflow.org/js/), [TensorFlow Serving](https://www.tensorflow.org/tfx/tutorials/serving/rest_simple), or [TensorFlow Hub](https://www.tensorflow.org/hub).

In [None]:
# Calling `save('my_model')` creates a SavedModel folder `my_model`.
model.save("saved_model")


# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model(
    "saved_model",
    custom_objects={"WarmupLearningRateSchedule": WarmupLearningRateSchedule},
)


In [None]:
# Let's check: weight matching
assert len(model.weights) == len(reconstructed_model.weights)
for a, b in zip(model.weights, reconstructed_model.weights):
    np.testing.assert_allclose(a.numpy(), b.numpy())

    
# Let's check: inference matching
test_input = tf.random.normal(
    [1, params.image_size, params.image_size, 3], 0, 1, tf.float32
)
tf.nest.map_structure(
    np.testing.assert_allclose,
    model.predict(test_input),
    reconstructed_model.predict(test_input),
)


# TensorFlow Lite Conversion

This **HybridSwinTransformer** model can be converted into the [TensorFlow Lite](https://www.tensorflow.org/lite/guide) format for mobile and edge devices. The conversion code is shown below. The converter can also apply optimizations, such as the default post-training quantization enabled by `tf.lite.Optimize.DEFAULT`, which can reduce model size and inference latency. You can read more about it [here](https://www.tensorflow.org/lite/performance/model_optimization).

In [None]:
from tensorflow import lite

# Wrap the Keras model, enable default optimizations,
# allow select TF ops, and convert.
converter = lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    lite.OpsSet.TFLITE_BUILTINS,
    lite.OpsSet.SELECT_TF_OPS,
]
converter.allow_custom_ops = True
tflite_model = converter.convert()

# Save the model.
with open("model.tflite", "wb") as write_tflite:
    write_tflite.write(tflite_model)

In [None]:
!ls ./ -a

# [Info] TF-Hub Publications

After saving the model in the **`SavedModel`** format, we can start the publication process on [TF-Hub](https://www.tensorflow.org/hub) if we want. It is basically a **three**-step process, [documented here](https://www.tensorflow.org/hub/publish#overview_of_the_publishing_process) in great detail:

- [Export Model](https://www.tensorflow.org/hub/exporting_tf2_saved_model)
- [Write Documentation](https://www.tensorflow.org/hub/writing_documentation)
- [Send PR](https://www.tensorflow.org/hub/contribute_a_model)


Note that when exporting the model, you might want to exclude the optimizer, as shown below, so that you won't need to pass `custom_objects` when reloading the model later.

```python

model.save('saved_model', include_optimizer=False)
reconstructed_model = keras.models.load_model(
    'saved_model'
)

```

# Conclusion

In this experiment, we inspected the visual attributes of a hybrid model with the Grad-CAM technique. In our results, the transformer blocks appear to have more potential to explain the model's decision-making process through strong visual maps, and these findings are promising. We also observed that the transformer blocks tend to refine the CNN's mid-level feature maps more globally.

The visual maps created by the transformer blocks, as shown here, are encouraging. However, they are not robust; they sometimes fail to capture relevant features. We could rethink the integration strategy to develop a more effective hybrid model, and further training on larger datasets may help this hybrid model generalize better. You can try this model in [Hugging Face Spaces](https://huggingface.co/spaces/innat/HybridModel-GradCAM).

### References

- [Swin Transformer](https://arxiv.org/pdf/2103.14030.pdf)
- [VcampSoldiers/Swin-Transformer-Tensorflow](https://github.com/VcampSoldiers/Swin-Transformer-Tensorflow)