# ViViT with/without Token Learner

This notebook explores the effect of adding a **Token Learner** module to **ViViT**.

The training data come from **MedMNIST 3D**, a collection of 3D medical image datasets with different numbers of classes. The runs below use OrganMNIST3D and NoduleMNIST3D. The model was tested with patch sizes 8 and 16, and the TokenLearner module was placed in the middle of the network (halfway through the transformer blocks). The AdamW optimizer (Adam with decoupled weight decay) was used for regularization, and the learning rate was reduced on plateau.
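As a quick sanity check on the setup, the sketch below computes how many tubelets a `Conv3D` with `kernel_size = strides = patch_size` and `padding="valid"` produces from a 28×28×28 volume. It is plain arithmetic that assumes the `(28, 28, 28, 1)` input shape set in the hyperparameters. The helper name `tubelet_grid` is hypothetical.

```python
def tubelet_grid(volume_size: int, patch_size: int) -> tuple[int, int]:
    """Return (tubelets per axis, total tubelets) for a cubic volume.

    With padding="valid", a strided conv only places the kernel where it fits
    entirely inside the input, so the per-axis count is a floor division.
    """
    per_axis = (volume_size - patch_size) // patch_size + 1
    return per_axis, per_axis ** 3


for p in (8, 16):
    per_axis, total = tubelet_grid(28, p)
    covered = per_axis * p  # voxels per axis that fall inside some tubelet
    print(f"patch {p:>2}: {per_axis} per axis, {total} tubelets, {covered}/28 voxels covered per axis")
```

With patch size 8 the volume becomes 27 tubelets, while with patch size 16 it collapses to a single tubelet covering the first 16 voxels along each axis, so the remaining 12 voxels per axis are never seen by the model.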

### The Result

<p align="center">
 <img src="./ViViT/img/vivit_tl_nodulemnist3d.png" height="200px" width="500px">
</p>

<p align="center">
 <img src="./ViViT/img/vivit_tl_organmnist3d.png" height="200px" width="500px">
</p>

Across these runs, the model with TokenLearner performed better overall than the baseline model, with higher validation accuracy and lower validation loss over the epochs. It also showed no signs of overfitting, even though it trained in less time. These results suggest that, on these datasets, TokenLearner helps the model learn faster without a significant risk of overfitting.

All of the result graphs are displayed on [TensorBoard](https://tensorboard.dev/experiment/nYVP58K4Q1GEuWLbkWBFow/). 

### References
1. Paper: [TokenLearner: Adaptive Space-Time Tokenization for Videos](https://proceedings.neurips.cc/paper/2021/file/6a30e32e56fce5cf381895dfe6ca7b6f-Paper.pdf)
1. Paper: [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
1. Code: https://keras.io/examples/vision/token_learner/
1. Code: https://keras.io/examples/vision/vivit/
1. Blog: https://ai.googleblog.com/2021/12/improving-vision-transformer-efficiency.html

## Settings & Downloads

In [1]:
# The tensorflow package already includes GPU support; the separate tensorflow-gpu package is deprecated.
!pip install -q -U tensorflow
!pip install -q -U tensorboard
!pip install -q tensorflow-addons  # for the AdamW optimizer
!pip install -q medmnist

In [2]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [3]:
import os

# TensorFlow's C++ log level must be set before TensorFlow is imported to take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import io
import logging
import medmnist
import numpy as np
from datetime import datetime

import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
from tensorflow.keras import layers


# Set seeds for reproducibility
SEED = 2023
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"  # request deterministic cuDNN kernels
keras.utils.set_random_seed(SEED)  # seeds Python, NumPy and TensorFlow


# Use a single device: the first GPU if one is available, otherwise the CPU
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    DEVICE_NAME = "/gpu:0"
    print("Currently running on GPU")
    print("Num GPUs Available: ", len(gpus))
else:
    DEVICE_NAME = "/cpu:0"
    print("Currently running on CPU")


# Suppress Python logging at WARNING level and below
logging.disable(logging.WARNING)

Currently running on GPU
Num GPUs Available:  1


### Hyperparameters

In [4]:
# DATA
# https://github.com/MedMNIST/MedMNIST/blob/main/medmnist/dataset.py
DATASET_INFOS = (("organmnist3d", 11), 
                 ("nodulemnist3d", 2),
                 ("adrenalmnist3d", 2),
                 ("fracturemnist3d", 3),
                 ("vesselmnist3d", 2),
                 ("synapsemnist3d", 2))  # (dataset_name, num_classes)
INPUT_SHAPE = (28, 28, 28, 1)
BATCH_SIZE = 64
AUTO = tf.data.AUTOTUNE

# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 60  # upper bound; EarlyStopping may end training earlier

# TUBELET EMBEDDING & TOKENLEARNER
PATCH_SIZES = [8, 16]

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8
MLP_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]

### Load and prepare the MedMNIST 3D dataset



In [5]:
def download_and_prepare_dataset(data_info: dict):
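    """Download a MedMNIST 3D dataset and return its train/valid/test splits.

    Arguments:
        data_info (dict): Dataset metadata from `medmnist.INFO`; its `url` and
            `MD5` entries are used to download and verify the `.npz` file.
    """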
    data_path = keras.utils.get_file(
        origin=data_info["url"], md5_hash=data_info["MD5"]
    )

    with np.load(data_path) as data:
        # Get videos
        train_videos = data["train_images"]
        valid_videos = data["val_images"]
        test_videos = data["test_images"]

        # Get labels
        train_labels = data["train_labels"].flatten()
        valid_labels = data["val_labels"].flatten()
        test_labels = data["test_labels"].flatten()

    return (
        (train_videos, train_labels),
        (valid_videos, valid_labels),
        (test_videos, test_labels),
    )
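The function relies on the MedMNIST `.npz` archive exposing six arrays named `train_images`, `val_images`, `test_images` and the matching `*_labels`. The sketch below builds a tiny fake archive in memory with that layout (random data with made-up split sizes, purely illustrative) and reads it back the same way. It assumes the labels are stored as `(N, 1)` column vectors, which is what the `.flatten()` calls suggest.

```python
import io

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical split sizes; the real MedMNIST splits are much larger.
sizes = {"train": 6, "val": 2, "test": 3}
buffer = io.BytesIO()
np.savez(
    buffer,
    **{f"{split}_images": rng.integers(0, 256, size=(n, 28, 28, 28), dtype=np.uint8) for split, n in sizes.items()},
    **{f"{split}_labels": rng.integers(0, 11, size=(n, 1)) for split, n in sizes.items()},
)
buffer.seek(0)

# Same access pattern as download_and_prepare_dataset.
with np.load(buffer) as data:
    train_videos = data["train_images"]
    train_labels = data["train_labels"].flatten()  # (N, 1) -> (N,)

print(train_videos.shape, train_labels.shape)  # (6, 28, 28, 28) (6,)
```

Flattening matters because `sparse_categorical_crossentropy` and the sparse accuracy metrics expect one integer class index per sample.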

In [6]:
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frames tensors and parse the labels."""
    # Preprocess images: convert uint8 voxels to float32 in [0, 1] and add a
    # channel axis, since Conv3D layers expect (D, H, W, C) inputs
    frames = tf.image.convert_image_dtype(frames[..., tf.newaxis], tf.float32)
    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label


def prepare_dataloader(
        videos: np.ndarray,
        labels: np.ndarray,
        loader_type: str = "train",
        batch_size: int = BATCH_SIZE,
    ):
    """Utility function to prepare the dataloader."""
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == "train":
        dataset = dataset.shuffle(batch_size * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataloader


def get_dataloaders(dataset_name):
    # Get the metadata of the dataset
    info = medmnist.INFO[dataset_name]

    # Get the dataset
    prepared_dataset = download_and_prepare_dataset(info)
    (train_videos, train_labels) = prepared_dataset[0]
    (valid_videos, valid_labels) = prepared_dataset[1]
    (test_videos, test_labels) = prepared_dataset[2]

    # Prepare DataLoaders
    trainloader = prepare_dataloader(train_videos, train_labels, "train")
    validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
    testloader = prepare_dataloader(test_videos, test_labels, "test")

    return (trainloader, validloader, testloader)
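Two details of this pipeline can be checked without TensorFlow. First, for `uint8` input, `tf.image.convert_image_dtype` rescales values into `[0, 1]` (it multiplies by 1/255), and `preprocess` also appends a channel axis for `Conv3D`. Second, per the tf.data documentation, `Dataset.shuffle(buffer_size)` fills a buffer of `buffer_size` elements, samples one at random, and refills that slot with the next element, so with a buffer of `BATCH_SIZE * 2 = 128` the shuffle is only local. The sketch below models both in NumPy and plain Python; `preprocess_np` and `buffered_shuffle` are hypothetical helpers written for illustration, not tf.data internals.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

import numpy as np

T = TypeVar("T")


def preprocess_np(frames: np.ndarray, label) -> tuple:
    """NumPy analogue of `preprocess`: scale uint8 voxels to [0, 1], add a channel axis."""
    scaled = frames.astype(np.float32) / np.iinfo(frames.dtype).max  # max is 255 for uint8
    return scaled[..., np.newaxis], np.float32(label)


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Simplified model of a buffer-based shuffle (buffer_size must be >= 1)."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]  # emit a random buffered element...
        buffer[idx] = item  # ...and put the incoming element in its slot
    rng.shuffle(buffer)  # input exhausted: drain the rest in random order
    yield from buffer


frames, label = preprocess_np(np.full((28, 28, 28), 255, dtype=np.uint8), 3)
print(frames.shape, frames.max(), label)  # (28, 28, 28, 1) 1.0 3.0

order = list(buffered_shuffle(range(1000), buffer_size=128))
# Element j of the output can only come from the first j + 128 inputs.
print(max(v - j for j, v in enumerate(order)))
```

Because an element near the end of the data can never land in the first batches, the tf.data documentation notes that a perfect shuffle needs a buffer at least as large as the dataset.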

## Model


In [7]:
class TubeletEmbedding(layers.Layer):
    def __init__(self, embed_dim, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.projection = layers.Conv3D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):
        projected_patches = self.projection(videos)
        flattened_patches = self.flatten(projected_patches)
        return flattened_patches
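Because `kernel_size` equals `strides` and padding is `"valid"`, the `Conv3D` in `TubeletEmbedding` sees each non-overlapping `p×p×p` block exactly once. It is therefore equivalent to cutting the volume into tubelets, flattening each one, and applying one shared linear projection. The NumPy sketch below shows that view, with random weights standing in for the learned kernel; `tubelet_embed_np` is a hypothetical name and this is an illustration, not the Keras implementation.

```python
import numpy as np


def tubelet_embed_np(videos: np.ndarray, kernel: np.ndarray, bias: np.ndarray, patch_size: int) -> np.ndarray:
    """Patchify-and-project view of a Conv3D with kernel_size == strides == patch_size.

    videos: (B, D, H, W, C), kernel: (p, p, p, C, E), bias: (E,) -> (B, num_tubelets, E)
    """
    b, d, h, w, c = videos.shape
    p = patch_size
    nd, nh, nw = d // p, h // p, w // p
    # padding="valid": trailing voxels that do not fill a whole tubelet are dropped
    x = videos[:, : nd * p, : nh * p, : nw * p, :]
    # Split every spatial axis into (tubelet index, offset inside the tubelet) ...
    x = x.reshape(b, nd, p, nh, p, nw, p, c)
    # ... move the tubelet indices to the front and flatten each tubelet into a vector
    x = x.transpose(0, 1, 3, 5, 2, 4, 6, 7).reshape(b, nd * nh * nw, p * p * p * c)
    # One shared linear projection for all tubelets (the conv kernel, flattened)
    return x @ kernel.reshape(p * p * p * c, -1) + bias


rng = np.random.default_rng(0)
p, embed_dim = 8, 4  # small embed_dim for illustration; the notebook uses 128
videos = rng.normal(size=(2, 28, 28, 28, 1))
kernel = rng.normal(size=(p, p, p, 1, embed_dim))
bias = rng.normal(size=embed_dim)

tokens = tubelet_embed_np(videos, kernel, bias, p)
print(tokens.shape)  # (2, 27, 4)
```

The tubelet order is row-major over (depth, height, width), which matches what `layers.Reshape((-1, embed_dim))` does to the `(B, 3, 3, 3, E)` conv output.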

In [8]:
class PositionalEncoder(layers.Layer):
    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        _, num_tokens, _ = input_shape
        self.position_embedding = layers.Embedding(
            input_dim=num_tokens, output_dim=self.embed_dim
        )
        self.positions = tf.range(start=0, limit=num_tokens, delta=1)

    def call(self, encoded_tokens):
        # Encode the positions and add it to the encoded tokens
        encoded_positions = self.position_embedding(self.positions)
        encoded_tokens = encoded_tokens + encoded_positions
        return encoded_tokens

### TokenLearner Module

In [9]:
def token_learner(inputs, num_tokens):
    # Layer normalize the inputs.
    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs)  # (B, T, H, W, C)

    # Applying Conv3D => Reshape => Permute
    # The reshape and permute is done to help with the next steps of
    # multiplication and Global Average Pooling.
    attention_maps = keras.Sequential(
        [
            # 3 layers of conv with gelu activation as suggested
            # in the paper.
            layers.Conv3D(
                filters=num_tokens,
                kernel_size=(3, 3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            layers.Conv3D(
                filters=num_tokens,
                kernel_size=(3, 3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            layers.Conv3D(
                filters=num_tokens,
                kernel_size=(3, 3, 3),
                activation=tf.nn.gelu,
                padding="same",
                use_bias=False,
            ),
            # This conv layer will generate the attention maps
            layers.Conv3D(
                filters=num_tokens,
                kernel_size=(3, 3, 3),
                activation="sigmoid",  # Note sigmoid for [0, 1] output
                padding="same",
                use_bias=False,
            ),
            # Reshape and Permute
            layers.Reshape((-1, num_tokens)),  # (B, T*H*W, num_of_tokens)
            layers.Permute((2, 1)),
        ]
    )(
        x
    )  # (B, num_of_tokens, T*H*W)

    # Reshape the input to align it with the output of the conv block.
    num_filters = inputs.shape[-1]
    inputs = layers.Reshape((1, -1, num_filters))(inputs)  # inputs == (B, 1, T*H*W, C)

    # Element-Wise multiplication of the attention maps and the inputs
    attended_inputs = (
        attention_maps[..., tf.newaxis] * inputs
    )  # (B, num_tokens, T*H*W, C)

    # Global average pooling the element wise multiplication result.
    outputs = tf.reduce_mean(attended_inputs, axis=2)  # (B, num_tokens, C)
    return outputs
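Numerically, the last two steps compute, for every learned token $i$, a weighted average of the input features over all $S = T \cdot H \cdot W$ positions, with the sigmoid attention maps as weights: $\text{out}_{b,i,c} = \frac{1}{S}\sum_{s=1}^{S} \alpha_{b,i,s}\, x_{b,s,c}$. The NumPy sketch below uses random arrays in place of the conv stack's output (shapes chosen to match patch size 8) and checks that the broadcast-multiply-then-mean form used above equals a single batched `einsum`, which does not need the 4-D `(B, num_tokens, S, C)` intermediate.

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, C, N = 2, 27, 128, 9  # batch, positions (3x3x3 tubelets), channels, learned tokens

x = rng.normal(size=(B, S, C))  # flattened inputs
attention = 1 / (1 + np.exp(-rng.normal(size=(B, N, S))))  # sigmoid maps in (0, 1)

# Form used in token_learner: (B, N, S, 1) * (B, 1, S, C), then mean over S.
pooled_broadcast = (attention[..., np.newaxis] * x[:, np.newaxis]).mean(axis=2)

# Same result as a batched matrix product divided by S.
pooled_einsum = np.einsum("bns,bsc->bnc", attention, x) / S

print(pooled_broadcast.shape)  # (2, 9, 128)
```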

### Transformer Block

In [10]:
def mlp(x, dropout_rate, hidden_units):
    # Iterate over the hidden units and add Dense => Dropout.
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(encoded_patches):
    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
    )(x1, x1)

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)

    # Skip connection 2.
    encoded_patches = layers.Add()([x4, x2])
    return encoded_patches
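The block uses the pre-norm layout: each sub-layer is applied to a layer-normalized copy of its input and added back through a residual connection. Written out, with $z$ the token matrix entering the block (notation chosen here for illustration):

$$
\begin{aligned}
z' &= z + \operatorname{MHSA}\big(\operatorname{LN}(z)\big), \\
z'' &= z' + \operatorname{MLP}\big(\operatorname{LN}(z')\big),
\end{aligned}
$$

where MLP is the stack built by `mlp` (Dense layers of 256 and 128 units, each with GELU and dropout). Both residual additions are shape-compatible because `MultiHeadAttention` projects back to the query's feature size by default and the last MLP layer has `PROJECTION_DIM = 128` units.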

### ViViT model


In [11]:
def create_vivit_classifier(
    num_tokens,
    num_classes,
    tubelet_embedder,
    positional_encoder,
    input_shape=INPUT_SHAPE,
    transformer_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    embed_dim=PROJECTION_DIM,
    layer_norm_eps=LAYER_NORM_EPS,
    use_token_learner=True,
):
    # Get the input layer
    inputs = layers.Input(shape=input_shape)
    # Create patches (tubelets).
    patches = tubelet_embedder(inputs)  # (B, num_patches, embed_dim)
    # Encode patches.
    encoded_patches = positional_encoder(patches)  # (B, num_patches, embed_dim)

    for i in range(transformer_layers):
        # Add a Transformer block.
        encoded_patches = transformer(encoded_patches)

        # Add the TokenLearner module in the middle (halfway point) of the architecture.
        if use_token_learner and i == transformer_layers // 2:
            _, thw, c = encoded_patches.shape
            n = round(thw ** (1 / 3))  # tubelets per axis; round() guards against float error
            encoded_patches = layers.Reshape((n, n, n, c))(encoded_patches)  # (B, n, n, n, C)
            encoded_patches = token_learner(encoded_patches, num_tokens)  # (B, num_tokens, C)

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(units=num_classes, activation="softmax")(representation)

    # Create Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
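To see what the TokenLearner insertion does to the sequence length, the sketch below traces the number of tokens entering each of the `NUM_LAYERS = 8` blocks. It assumes the same conventions as the code: the tubelet grid is `28 // patch_size` per axis, TokenLearner runs after block index `transformer_layers // 2` (index 4, the fifth block), and it outputs `(28 // patch_size) ** 2` tokens, the value `run_experiment` passes as `num_tokens`. It uses `round` for the cube root because truncating a floating-point cube root with `int` is fragile: in IEEE double arithmetic `64 ** (1/3)` evaluates to slightly less than 4. The helpers `cube_root` and `trace_sequence_lengths` are hypothetical.

```python
def cube_root(n: int) -> int:
    """Exact integer cube root; fails loudly if n is not a perfect cube."""
    r = round(n ** (1 / 3))
    assert r ** 3 == n, f"{n} is not a perfect cube"
    return r


def trace_sequence_lengths(volume_size=28, patch_size=8, num_layers=8, use_token_learner=True):
    """Return the number of tokens entering each transformer block."""
    per_axis = volume_size // patch_size
    seq_len = per_axis ** 3  # tubelets from TubeletEmbedding
    learned_tokens = per_axis ** 2  # num_tokens as computed in run_experiment
    lengths = []
    for i in range(num_layers):
        lengths.append(seq_len)
        if use_token_learner and i == num_layers // 2:
            assert cube_root(seq_len) == per_axis  # the (n, n, n) reshape must be exact
            seq_len = learned_tokens
    return lengths


print(trace_sequence_lengths(patch_size=8))   # [27, 27, 27, 27, 27, 9, 9, 9]
print(trace_sequence_lengths(patch_size=16))  # [1, 1, 1, 1, 1, 1, 1, 1]
```

At patch size 8, TokenLearner shortens the sequence from 27 to 9 tokens for the last three blocks. At patch size 16 there is only one token to begin with, so TokenLearner cannot reduce the sequence further.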

## Train

### Training Utility

In [12]:
def run_experiment(dataset_info, 
                   patch_size,
                   use_token_learner=True):
  
    with tf.device(DEVICE_NAME):
        # Load dataset
        dataset_name, num_classes = dataset_info
        trainloader, validloader, testloader = get_dataloaders(dataset_name)
        
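        # Number of tokens TokenLearner outputs (only used when use_token_learner=True).
        # For patch size 8 this is 3**2 = 9, down from the 3**3 = 27 tubelets.
        num_learned_tokens = (INPUT_SHAPE[0] // patch_size) ** 2

        # Initialize model
        model = create_vivit_classifier(
            num_tokens=num_learned_tokens,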
            num_classes=num_classes,
            tubelet_embedder=TubeletEmbedding(
                embed_dim=PROJECTION_DIM, patch_size=patch_size
            ),
            positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
            use_token_learner=use_token_learner,
        )

        # Define AdamW optimizer for regularization
        optimizer = tfa.optimizers.AdamW(
            learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
        )
        model.compile(
            optimizer=optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=[
                keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
            ],
        )

        # Stop early if val_loss has not improved for 15 epochs (monitoring starts at
        # epoch 20), since the datasets do not all need the full 60 epochs.
        earlystop_callback = keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=15,
            start_from_epoch=20
        )

        # Define reduce learning rate on plateau
        reducelr_callback = keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.5,
            patience=5, min_lr=1e-6
        )

        # Define checkpoint callbacks
        checkpoint_filepath = "/tmp/checkpoint"
        checkpoint_callback = keras.callbacks.ModelCheckpoint(
            checkpoint_filepath,
            monitor="val_accuracy",
            save_best_only=True,
            save_weights_only=True,
        )

        # Define tensorboard callbacks
        log_dir = f"logs/fit/{dataset_name}/tl:{use_token_learner}/p:{patch_size}/" \
                  + datetime.now().strftime("%Y%m%d-%H%M%S")
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

        # Train the model.
        _ = model.fit(
            trainloader,
            epochs=EPOCHS,
            validation_data=validloader,
            callbacks=[tensorboard_callback, 
                       checkpoint_callback, 
                       earlystop_callback,
                       reducelr_callback],
        )

        model.load_weights(checkpoint_filepath)
        _, accuracy, top_5_accuracy = model.evaluate(testloader)
        print("TRAIN RESULT : ")
        print(f"{dataset_name}/tl:{use_token_learner}/p:{patch_size}")
        print(f"    Test accuracy: {round(accuracy * 100, 2)}%")
        print(f"    Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
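The two plateau-based callbacks interact: `ReduceLROnPlateau` halves the learning rate after 5 epochs without a `val_loss` improvement (never going below `1e-6`), while `EarlyStopping` stops only after 15 such epochs and does not start monitoring until epoch 20. The pure-Python sketch below is a simplified model of the learning-rate rule that makes the schedule concrete. It ignores the Keras `min_delta` and `cooldown` options, and `simulate_reduce_lr` is a hypothetical name.

```python
def simulate_reduce_lr(val_losses, lr=1e-4, factor=0.5, patience=5, min_lr=1e-6):
    """Learning rate in effect after each epoch (simplified ReduceLROnPlateau, mode='min')."""
    best, wait, history = float("inf"), 0, []
    for loss in val_losses:
        if loss < best:  # improvement: remember it and reset the counter
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)  # reduce, clipped at min_lr
                wait = 0
        history.append(lr)
    return history


# A validation loss that stops improving after the first epoch and stays flat.
lrs = simulate_reduce_lr([1.0] * 60)
print(sorted(set(lrs), reverse=True))
```

Starting from `1e-4`, seven halvings are needed to hit the `1e-6` floor, so on a flat loss the learning rate bottoms out after 35 non-improving epochs.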

In [None]:
# Clear out prior logging data.
!rm -rf logs/fit

# Train 2 datasets x 2 patch sizes x 2 variants (with/without TokenLearner) = 8 models.
for dataset_info in DATASET_INFOS[:2]:
    for patch_size in PATCH_SIZES:
        run_experiment(dataset_info, patch_size, use_token_learner=False)
        run_experiment(dataset_info, patch_size, use_token_learner=True)

Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Epoch 23/60
Epoch 24/60
Epoch 25/60
Epoch 26/60
Epoch 27/60
Epoch 28/60
Epoch 29/60
Epoch 30/60
Epoch 31/60
Epoch 32/60
Epoch 33/60
Epoch 34/60
Epoch 35/60
Epoch 36/60
Epoch 37/60
Epoch 38/60
Epoch 39/60
Epoch 40/60
Epoch 41/60
Epoch 42/60
Epoch 43/60
Epoch 44/60
Epoch 45/60
Epoch 46/60
Epoch 47/60
Epoch 48/60
Epoch 49/60
Epoch 50/60
Epoch 51/60
TRAIN RESULT : 
organmnist3d/tl:False/p:8
    Test accuracy: 71.48%
    Test top 5 accuracy: 97.38%
Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Epoch 23/60
Epoch 24/60
Epoch 25/60
Epoch 
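The nested training loops enumerate a 2 × 2 × 2 grid. An equivalent sketch with `itertools.product` lists the configurations in the same order, together with the TensorBoard run directories they map to (timestamps omitted; the dataset names are copied from `DATASET_INFOS[:2]`):

```python
from itertools import product

datasets = [("organmnist3d", 11), ("nodulemnist3d", 2)]  # DATASET_INFOS[:2]
patch_sizes = [8, 16]
token_learner_flags = [False, True]  # baseline first, then TokenLearner

runs = [
    (name, patch_size, use_tl)
    for (name, _), patch_size, use_tl in product(datasets, patch_sizes, token_learner_flags)
]
for name, patch_size, use_tl in runs:
    print(f"logs/fit/{name}/tl:{use_tl}/p:{patch_size}/")
```

`product` varies its last argument fastest, which reproduces the loop order above: each patch size is trained without and then with TokenLearner before moving on.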

## View & Upload on TensorBoard

In [None]:
# View TensorBoard
%tensorboard --logdir logs/fit

In [None]:
# Upload an experiment:
!tensorboard dev upload --logdir logs \
    --name "ViViT with/without Token Learner" \
    --description "Comparison between ViViT with and without Token Learner."

Upload started and will continue reading any new data as it's added to the logdir.

To stop uploading, press Ctrl-C.

New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/nYVP58K4Q1GEuWLbkWBFow/

[1m[2023-01-18T11:08:25][0m Started scanning logdir.
[1m[2023-01-18T11:09:44][0m Total uploaded: 3393 scalars, 52053 tensors (36.5 MB), 8 binary objects (9.9 MB)
