## Learning to Resize Images for Vision Transformer


This notebook shows how to train a vision transformer on higher-resolution images. Transformer-based models often do not fit into GPU memory at high input resolutions. In such cases, we can place a **Trainable Resizer** in front of the transformer and train the image resizer and the recognition model jointly.

[**Learning to Resize Images for Computer Vision Tasks** - Google Research](https://arxiv.org/pdf/2103.09950v1.pdf). For a given image resolution and model, this work asks how best to resize the image to a target resolution. Off-the-shelf resizers such as bilinear and bicubic interpolation are the default in most machine learning software, but they may limit the on-task performance of the trained model. The paper shows that the typical linear resizer can be replaced with a **Learned Resizer**. The proposed learnable resizer block is shown below.

![image resizer](https://user-images.githubusercontent.com/17668390/138250657-29995830-b903-447f-8729-09b72b90ab3c.png)

The paper shows that the proposed resizer improves classification models. The authors added the resizer to classification models such as `DenseNet` and `InceptionNet`. In this way, we can feed in a larger image and the resizer will downsample it appropriately for the actual model.

![rtee](https://user-images.githubusercontent.com/17668390/138254072-f87daa13-12cc-4c6a-9145-a567f644cb12.png)

[**Vision Transformer** - Google Research](https://arxiv.org/pdf/2010.11929.pdf). Transformer models are computationally expensive, which in practice limits the input size to roughly `224` or `384`. The idea here is to use the **resizer mechanism** as the input stage of the **vision transformer**, so that we can feed in a sufficiently large image for **joint learning**. The overall model architecture is:

![Presentation2](https://user-images.githubusercontent.com/17668390/138256285-c24f98db-ce35-4877-8741-221fd57d895e.jpg)
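
To make the cost argument concrete, here is a small back-of-the-envelope sketch. It assumes a patch size of 16 (as suggested by the `s16` in the ViT handle used later) and ignores the class token; the numbers are only meant to show how quickly self-attention grows with resolution.

```python
# Token count and attention-matrix size for a ViT with 16x16 patches.
# Assumption: patch size 16, class token ignored for simplicity.
PATCH = 16

def num_patches(side, patch=PATCH):
    """Number of non-overlapping patches for a square image of `side` pixels."""
    return (side // patch) ** 2

for side in (224, 384, 512):
    n = num_patches(side)
    # Self-attention compares every token with every other token: n * n entries.
    print(f"{side}x{side}: {n} tokens, {n * n:,} attention entries per head")
```

Going from `224` to `512` multiplies the attention matrix by roughly 27 under these assumptions, which is why we downsample with the resizer before the transformer instead of feeding `512` directly.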

**References**

- [Robin Smits](https://www.kaggle.com/rsmits/effnet-b2-feature-models-catboost#SET-TPU-/-GPU) - For general training pipelines. Great work.
- [Learnable-Image-Resizing](https://github.com/sayakpaul/Learnable-Image-Resizing) - For resizer building blocks.
- [TensorFlow-HUB](https://github.com/sayakpaul/ViT-jax2tf) - For ViT.

In [1]:
import numpy as np 
import pandas as pd 

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf 
from tensorflow.keras import Input, Model, Sequential, layers
import tensorflow_hub as hub

In [2]:
INP_SIZE      = (512, 512) # Input size of the Image Resizer Module (IRM)
TARGET_SIZE   = (224, 224) # Output size of IRM and Input size of the Vision Transformer 
INTERPOLATION = "bilinear"

**tf.data**

In [3]:
Q = 30
feature_folds = 10

batch_size = 12
epochs  = 10
seed  = 123123
verbose = 1
lr  = 0.005

DATA_DIR = '../input/petfinder-pawpularity-score/'
TRAIN_DIR = DATA_DIR + 'train/'
TEST_DIR = DATA_DIR + 'test/'

# SetAutoTune
AUTOTUNE = tf.data.experimental.AUTOTUNE  

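def build_augmenter(is_labelled):
    # Light flip and photometric augmentations; images are float32 in [0, 255] here.
    def augment(img):
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        img = tf.image.random_saturation(img, 0.95, 1.05)
        # Note: the delta is added directly to the float pixels, so 0.05 is a
        # very small shift on a 0-255 scale.
        img = tf.image.random_brightness(img, 0.05)
        img = tf.image.random_contrast(img, 0.95, 1.05)
        img = tf.image.random_hue(img, 0.05)
        return img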
    def augment_with_labels(img, label):
        return augment(img), label
    return augment_with_labels if is_labelled else augment

def build_decoder(is_labelled):
    def decode(path):
        file_bytes = tf.io.read_file(path)
        img = tf.image.decode_jpeg(file_bytes, channels = 3)
        # Bilinear resize to the resizer's input size; the result is float32 in [0, 255].
        img = tf.image.resize(img, INP_SIZE)
        return img
    def decode_with_labels(path, label):
        return decode(path), label
    return decode_with_labels if is_labelled else decode

def create_dataset(df, batch_size = 32, is_labelled = False, 
                   augment = False, repeat = False, shuffle = False):
    decode_fn    = build_decoder(is_labelled)
    augmenter_fn = build_augmenter(is_labelled)
    
    # Create Dataset
    if is_labelled:
        dataset = tf.data.Dataset.from_tensor_slices((df['Id'].values, df['target_value'].values))
    else:
        dataset = tf.data.Dataset.from_tensor_slices((df['Id'].values))
    dataset = dataset.map(decode_fn, num_parallel_calls = AUTOTUNE)
    dataset = dataset.map(augmenter_fn, num_parallel_calls = AUTOTUNE) if augment else dataset
    dataset = dataset.repeat() if repeat else dataset
    dataset = dataset.shuffle(1024, reshuffle_each_iteration = True) if shuffle else dataset
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset
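
One thing worth estimating for this pipeline is the memory held by the shuffle buffer, because shuffling happens after decoding and therefore stores full images. The sketch below is rough arithmetic only; it assumes float32 images (which is what `tf.image.resize` returns with bilinear interpolation) and ignores any `tf.data` overhead.

```python
# Rough memory footprint of the shuffle buffer in create_dataset.
# Assumptions: float32 pixels, no tf.data bookkeeping overhead.
INP_SIZE = (512, 512)
CHANNELS = 3
BYTES_PER_FLOAT32 = 4
SHUFFLE_BUFFER = 1024

bytes_per_image = INP_SIZE[0] * INP_SIZE[1] * CHANNELS * BYTES_PER_FLOAT32
buffer_gib = SHUFFLE_BUFFER * bytes_per_image / 2**30

print(f"one image: {bytes_per_image / 2**20:.1f} MiB")
print(f"shuffle buffer of {SHUFFLE_BUFFER}: {buffer_gib:.1f} GiB")
```

If host memory is tight, one option is to shuffle the file paths before the decode step, or to use a smaller buffer; both are trade-offs rather than requirements.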

In [4]:
# Load Train Data
train_df = pd.read_csv(f'{DATA_DIR}train.csv')
train_df['Id'] = train_df['Id'].apply(lambda x: f'{TRAIN_DIR}{x}.jpg')

# Set a specific label to be able to perform stratification
train_df['stratify_label'] = pd.qcut(train_df['Pawpularity'], q = Q, labels = range(Q))

# Label value to be used for feature model 'classification' training.
train_df['target_value'] = train_df['Pawpularity'] / 100.

# Summary
print(f'train_df: {train_df.shape}')
train_df.head()
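
The two derived columns can be illustrated on synthetic data. This is a minimal sketch with a made-up frame (one row per score from 1 to 100) and a hypothetical bin count `Q_DEMO = 4` instead of the `Q = 30` used above; it is not the competition data.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for the Pawpularity column: one sample per score 1..100.
demo = pd.DataFrame({"Pawpularity": np.arange(1, 101)})

Q_DEMO = 4  # hypothetical bin count, smaller than the notebook's Q = 30
demo["stratify_label"] = pd.qcut(demo["Pawpularity"], q=Q_DEMO, labels=range(Q_DEMO))
demo["target_value"] = demo["Pawpularity"] / 100.0

# Quantile bins hold the same number of rows each for this uniform toy data.
bin_sizes = demo["stratify_label"].value_counts().sort_index().tolist()
print(bin_sizes)

# Targets now lie in (0, 1], which suits a sigmoid output.
print(demo["target_value"].min(), demo["target_value"].max())
```

The quantile bins are only used to keep the score distribution similar across folds; the model itself is trained on `target_value`.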

# Learning to Resize

![image resizer](https://user-images.githubusercontent.com/17668390/138250657-29995830-b903-447f-8729-09b72b90ab3c.png)

In [5]:
# ref: https://keras.io/examples/vision/learnable_resizer/
def residual_block(x, filters=16):
    # `filters` must match the channel count of `x` so the shortcut can be added.
    shortcut = x

    def conv_bn_leaky(inputs, filters, kernel_size, strides):
        x = layers.Conv2D(filters, kernel_size, strides=strides, 
                          use_bias=False, padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)
        return x 
    
    def conv_bn(inputs, filters, kernel_size, strides):
        x = layers.Conv2D(filters, kernel_size, strides=strides, padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        return x 
    
    x = conv_bn_leaky(x, filters, 3, 1)
    x = conv_bn(x, filters, 3, 1)
    x = layers.add([shortcut, x])
    return x


def get_learnable_resizer(filters=16, num_res_blocks=1, interpolation=INTERPOLATION):
    inputs = layers.Input(shape=[None, None, 3])

    # First, perform naive resizing.
    naive_resize = layers.Resizing(
        *TARGET_SIZE, interpolation=interpolation
    )(inputs)

    # First convolution block without batch normalization.
    x = layers.Conv2D(filters=filters, kernel_size=7, strides=1, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    # Second convolution block with batch normalization.
    x = layers.Conv2D(filters=filters, kernel_size=1, strides=1, padding='same')(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.BatchNormalization()(x)

    # Intermediate resizing as a bottleneck.
    bottleneck = layers.Resizing(
        *TARGET_SIZE, interpolation=interpolation
    )(x)
    
    # Residual passes.
    x = residual_block(bottleneck)
    for i in range(1, num_res_blocks):
        x = residual_block(x)
        
    # Projection.
    x = layers.Conv2D(
        filters=filters, kernel_size=3, strides=1, padding="same", use_bias=False
    )(x)
    x = layers.BatchNormalization()(x)

    # Skip connection.
    x = layers.Add()([bottleneck, x])

    # Final resized image.
    x = layers.Conv2D(filters=3, kernel_size=7, strides=1, padding="same")(x)
    final_resize = layers.Add()([naive_resize, x])
    return Model(inputs, final_resize, name="learnable_resizer")

learnable_resizer = get_learnable_resizer(num_res_blocks=3)
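
To get a feel for how small the resizer is, we can count its parameters by hand. The sketch below mirrors the layer sequence defined above with `filters=16` and `num_res_blocks=3`; the helper names `conv_params` and `bn_params` are introduced here for illustration only. The result should agree with `learnable_resizer.count_params()`, but treat it as a sanity check rather than a reference.

```python
def conv_params(kernel, c_in, c_out, bias=True):
    """Parameter count of a Conv2D layer with a square kernel."""
    return kernel * kernel * c_in * c_out + (c_out if bias else 0)

def bn_params(channels):
    """BatchNormalization: gamma, beta, moving mean, moving variance."""
    return 4 * channels

F = 16        # filters
N_BLOCKS = 3  # num_res_blocks used when building the resizer

# 7x7 conv, then 1x1 conv followed by batch normalization.
stem = conv_params(7, 3, F) + conv_params(1, F, F) + bn_params(F)
# One residual block: conv (no bias) + BN, then conv (with bias) + BN.
res_block = (conv_params(3, F, F, bias=False) + bn_params(F)
             + conv_params(3, F, F) + bn_params(F))
# Projection conv (no bias) + BN, then the final 7x7 conv back to 3 channels.
projection = conv_params(3, F, F, bias=False) + bn_params(F)
head = conv_params(7, F, 3)

total = stem + N_BLOCKS * res_block + projection + head
print(total)
```

The `Resizing`, `LeakyReLU`, and `Add` layers carry no weights, so the whole resizer comes to roughly 22 thousand parameters, which is negligible next to the transformer it feeds.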

**Check**

Let's check how a raw image is transformed by the resizer block in its initial (untrained) state.

In [6]:
training_dataset = create_dataset(train_df,
                                  batch_size  = batch_size, 
                                  is_labelled = True, 
                                  augment = True,
                                  repeat  = True, 
                                  shuffle = True)
sample_images, _ = next(iter(training_dataset))

In [7]:
import matplotlib.pyplot as plt 

plt.figure(figsize=(16, 10))
for i, image in enumerate(sample_images[:6]):
    image = image / 255

    ax = plt.subplot(3, 4, 2 * i + 1)
    plt.title("Input Image")
    plt.imshow(image.numpy().squeeze())
    plt.axis("off")

    ax = plt.subplot(3, 4, 2 * i + 2)
    resized_image = learnable_resizer(image[None, ...])
    plt.title("Resized Image")
    plt.imshow(resized_image.numpy().squeeze())
    plt.axis("off")
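
At initialization the resizer output is the bilinear image plus the output of randomly initialized convolutions, so some values can fall outside `[0, 1]`. `imshow` clips float RGB data to that range when drawing. A minimal sketch of clipping explicitly before display, using a made-up array in place of a real resizer output:

```python
import numpy as np

# Hypothetical resizer output: values drift outside [0, 1].
rng = np.random.default_rng(0)
resized = rng.normal(loc=0.5, scale=0.4, size=(224, 224, 3)).astype("float32")

# Clip explicitly so the displayed range is a deliberate choice.
displayable = np.clip(resized, 0.0, 1.0)

print(resized.min() < 0.0, resized.max() > 1.0)
print(displayable.min(), displayable.max())
```

This affects visualization only; the model itself receives the unclipped resizer output.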

# Learned Resizer + Vision Transformer 

In [8]:
handle="https://tfhub.dev/sayakpaul/vit_s16_fe/1"

def get_model(plot_modal, print_summary, with_compile):
    hub_layer = hub.KerasLayer(handle, trainable=True)
    backbone = Sequential(
        [
            layers.InputLayer((TARGET_SIZE[0], TARGET_SIZE[1], 3)),
            hub_layer
        ], name='vit'
    )
    inputs = layers.Input((INP_SIZE[0], INP_SIZE[1], 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)
    x = learnable_resizer(x)
    outputs = backbone(x)
    
    tail = Sequential(
            [
                layers.Dropout(0.2),
                layers.BatchNormalization(),
                layers.Dense(1, activation = 'sigmoid')
            ], name='head'
        )
    
    model = Model(inputs, tail(outputs))
    
    if plot_modal:
        display(tf.keras.utils.plot_model(model, show_shapes=True, 
                                          show_layer_names=True,  expand_nested=True))
    if print_summary:
        model.summary()  # summary() prints directly and returns None
        
    if with_compile:
        model.compile(
            optimizer = tf.keras.optimizers.Adam(learning_rate = lr),  
            loss = tf.keras.losses.BinaryCrossentropy(), 
            metrics = [tf.keras.metrics.RootMeanSquaredError('rmse')])  
    return model 
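
The model is compiled with binary cross-entropy on targets scaled to `[0, 1]`, while RMSE is tracked as the metric. The following NumPy sketch (with made-up numbers) illustrates two points: for a soft target, the cross-entropy is lowest when the prediction equals the target, and the RMSE is reported on the 0-1 scale, so it needs a factor of 100 to be read in Pawpularity units.

```python
import numpy as np

def bce(y, p, eps=1e-7):
    """Binary cross-entropy for a target y and prediction(s) p."""
    p = np.clip(p, eps, 1.0 - eps)
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

y_true = 0.38  # hypothetical Pawpularity of 38, scaled by 1/100
candidates = np.linspace(0.01, 0.99, 99)
best = candidates[np.argmin(bce(y_true, candidates))]
print(round(float(best), 2))  # candidate prediction with the lowest loss

# RMSE on the 0-1 scale; multiply by 100 for score units.
y = np.array([0.38, 0.70, 0.12])
p = np.array([0.45, 0.55, 0.20])
rmse = np.sqrt(np.mean((y - p) ** 2))
print(round(float(rmse) * 100, 2))
```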

In [9]:
get_model(plot_modal=True, print_summary=True, with_compile=False)

In [10]:
from tensorflow.keras import losses, optimizers, metrics
from tensorflow.keras import callbacks

def model_callback(fold):
    ckpt = tf.keras.callbacks.ModelCheckpoint(f'feature_model_{fold}.h5',
                                              verbose = 1, 
                                              monitor = 'val_rmse',
                                              mode = 'min', 
                                              save_weights_only = True,
                                              save_best_only = True)
    
    return [ckpt]
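
With `monitor = 'val_rmse'`, `mode = 'min'`, and `save_best_only = True`, the checkpoint is written only when the validation RMSE improves. The plain-Python sketch below imitates that rule on a hypothetical validation curve; `epochs_that_save` is an illustrative helper, not part of Keras.

```python
def epochs_that_save(val_rmse_per_epoch):
    """Return the 0-based epochs at which a 'save best only, mode=min' rule saves."""
    best = float("inf")
    saved = []
    for epoch, value in enumerate(val_rmse_per_epoch):
        if value < best:  # strictly better than anything seen so far
            best = value
            saved.append(epoch)
    return saved

# Hypothetical validation curve, not taken from a real run.
curve = [0.210, 0.198, 0.201, 0.195, 0.195, 0.199]
print(epochs_that_save(curve))
```

A tie with the current best does not trigger a new save, so the file on disk corresponds to the first epoch that reached the best score.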

In [None]:
import gc
from sklearn.model_selection import StratifiedKFold

# OOF RMSE Placeholder
all_val_rmse = []
kfold = StratifiedKFold(n_splits = feature_folds, 
                        shuffle = True, random_state = seed)
for fold, (train_index, val_index) in enumerate(kfold.split(train_df.index,
                                                            train_df['stratify_label'])):
    # Only the first fold is trained in this run.
    if fold == 0:
        print(f'\nFold {fold}\n')
        # Pre model.fit cleanup
        tf.keras.backend.clear_session()
        gc.collect()

        # Create Model
        model = get_model(plot_modal    = False, 
                          print_summary = False,
                          with_compile  = True)
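        # Workaround: give every weight a unique name, since duplicate names
        # can break saving weights in the HDF5 format.
        for i in range(len(model.weights)):
            model.weights[i]._handle_name = model.weights[i].name + str(i)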
    
        # Create TF Datasets
        trn = train_df.iloc[train_index]
        val = train_df.iloc[val_index]
        training_dataset = create_dataset(trn, 
                                          batch_size  = batch_size, 
                                          is_labelled = True, 
                                          augment     = True, 
                                          repeat      = True, 
                                          shuffle     = True)
        validation_dataset = create_dataset(val, 
                                            batch_size  = batch_size, 
                                            is_labelled = True,
                                            augment     = False, 
                                            repeat      = True,
                                            shuffle     = False)
        # Fit Model
        history = model.fit(training_dataset,
                            epochs = epochs,
                            steps_per_epoch  = trn.shape[0] // batch_size,
                            validation_steps = val.shape[0] // batch_size,
                            callbacks = model_callback(fold),
                            validation_data = validation_dataset,
                            verbose = 2)   

        # Validation Information (RMSE is on the 0-1 target scale; multiply by 100 for Pawpularity units)
        best_val_rmse = min(history.history['val_rmse'])
        all_val_rmse.append(best_val_rmse)
        print(f'\nValidation RMSE: {best_val_rmse}\n')
        del model 
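
To see what the stratified split and the step counts look like, here is a small sketch on synthetic labels. It assumes 5 bins of 40 samples each as a stand-in for `stratify_label`, and reuses the batch size of 12; the data is made up.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Synthetic labels: 5 bins with 40 samples each (stand-in for stratify_label).
labels = np.repeat(np.arange(5), 40)
X = np.zeros((len(labels), 1))  # features are irrelevant for the split itself

kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=123123)
train_idx, val_idx = next(iter(kfold.split(X, labels)))

# Every bin contributes the same share to the validation fold.
val_counts = np.bincount(labels[val_idx]).tolist()
print(val_counts)

# Floor division: a trailing partial batch is not counted as a step.
BATCH = 12
steps_per_epoch = len(train_idx) // BATCH
validation_steps = len(val_idx) // BATCH
print(steps_per_epoch, validation_steps)
```

Because both datasets repeat, the samples left over by the floor division are not lost; they roll over into the next epoch's batches.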