# Tokenization comparison in ViViT
## Uniform Frame Sampling vs. Tubelet Embedding

Video and volumetric datasets (videos, 3D images, etc.) have an extra axis compared with 2D images, so the choice of tokenization method affects how efficient the model can be. This notebook compares **Uniform Frame Sampling** and **Tubelet Embedding**, both introduced in the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691).

### Settings 
The experiment uses 6 datasets from **medmnist3D**. Every sample has 28 time steps, a height of 28 pixels, a width of 28 pixels, and 1 channel. All hyperparameters were fixed except for the learning rate.

### The Result
Overall, Tubelet Embedding outperformed Uniform Frame Sampling, and each training run also took less time with Tubelet Embedding. The **full results of this notebook** are on **[TensorBoard.dev](https://tensorboard.dev/experiment/PKs2SEeNQLO68B4tB7U8tg/)**.

### Limitations
Since the experiment does not cover variations in the data (image size, number of time steps, type of data) or in the hyperparameters of the ViViT architecture, _these results do not generalize to the tokenization methods as a whole._ Still, given the limited resources, this simple setting and its results can serve as an intuitive baseline for those who are curious :-)

### References
1. Datasets: https://medmnist.com/
1. Paper: https://arxiv.org/abs/2103.15691
1. Code: https://keras.io/examples/vision/vivit/


## Install & Load packages

In [1]:
!pip install -q tensorflow-addons
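# Note: `tensorflow-gpu` is a retired PyPI package and fails to build (see the output below);
# the regular `tensorflow` package already includes GPU support.
!pip install -q tensorflow-gpu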
!pip install -q -U tensorboard
!pip install -q medmnist

  Preparing metadata (setup.py) ... done
  error: subprocess-exited-with-error

  × python setup.py bdist_wheel did not run successfully.
  │ exit code: 1
  ╰─> See above for output.

  note: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for tensorflow-gpu (setup.py) ... error
  ERROR: Failed building wheel for tensorflow-gpu
  error: subprocess-exited-with-error

  × Running setup.py install for tensorflow-gpu did not run s

In [2]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [3]:
import os
import io
import time
import imageio
import medmnist
import ipywidgets
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
from tensorflow.keras import layers

import matplotlib.pyplot as plt
from datetime import datetime

# Setting seed for reproducibility
SEED = 2023
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
keras.utils.set_random_seed(SEED)

In [4]:
# Check device
from tensorflow.python.client import device_lib 
print(device_lib.list_local_devices())

if tf.config.list_physical_devices('GPU'):
    DEVICE_NAME = "/gpu:0"
    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
else:  
    DEVICE_NAME = "/cpu:0"

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 3709101813517062318
xla_global_id: -1
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 13912375296
locality {
  bus_id: 1
  links {
  }
}
incarnation: 2596457658420606362
physical_device_desc: "device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5"
xla_global_id: 416903419
]
Num GPUs Available:  1


In [5]:
# Refrain from verbose logging
import logging, os
logging.disable(logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

### Hyperparameters

In [6]:
# DATA
# source : https://github.com/MedMNIST/MedMNIST/blob/main/medmnist/dataset.py
DATASET_INFOS = (("organmnist3d", 11), 
                 ("nodulemnist3d", 2),
                 ("adrenalmnist3d", 2),
                 ("fracturemnist3d", 3),
                 ("vesselmnist3d", 2),
                 ("synapsemnist3d", 2))  # (dataset_name, num_classes)
BATCH_SIZE = 64
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (28, 28, 28, 1)

# OPTIMIZER
LEARNING_RATES = [1e-3, 1e-4, 1e-5]
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 60

# EMBEDDING
PATCH_SIZE = 8
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE) ** 2  # 2D patches per frame: 3 * 3 = 9

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8
MLP_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]
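
As a quick sanity check on these settings, the sketch below counts how many tokens each tokenizer yields for a 28×28×28 input with patch size 8. It assumes non-overlapping patches with "valid" padding (so the 4 leftover voxels per axis are dropped), and that Uniform Frame Sampling takes every 8th frame starting at frame 8, as the `UniformFrameSampling` layer defined later in this notebook does. The function names are illustrative and not part of the notebook.

```python
def tubelet_token_count(shape, patch_size):
    """Tokens from non-overlapping cubic tubelets ("valid" padding)."""
    t, h, w = shape
    return (t // patch_size) * (h // patch_size) * (w // patch_size)


def frame_sampling_token_count(shape, patch_size):
    """Tokens from 2D patches on frames patch_size, 2 * patch_size, ..."""
    t, h, w = shape
    sampled_frames = len(range(patch_size, t, patch_size))
    return sampled_frames * (h // patch_size) * (w // patch_size)


tubelet_tokens = tubelet_token_count((28, 28, 28), 8)
sampling_tokens = frame_sampling_token_count((28, 28, 28), 8)
print(tubelet_tokens, sampling_tokens)  # 27 27
```

Both tokenizers produce 27 tokens here, so the transformer sees the same sequence length in either case. The difference is that each tubelet token summarizes 8 consecutive frames, while each sampled-frame token sees a single frame.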

## Data Loading & Preparing

In [7]:
def download_and_prepare_dataset(data_info: dict):
    """Utility function to download the dataset.

    Arguments:
        data_info (dict): Dataset metadata.
    """
    data_path = keras.utils.get_file(
        origin=data_info["url"], md5_hash=data_info["MD5"]
    )

    with np.load(data_path) as data:
        # Get videos
        train_videos = data["train_images"]
        valid_videos = data["val_images"]
        test_videos = data["test_images"]

        # Get labels
        train_labels = data["train_labels"].flatten()
        valid_labels = data["val_labels"].flatten()
        test_labels = data["test_labels"].flatten()

    return (
        (train_videos, train_labels),
        (valid_videos, valid_labels),
        (test_videos, test_labels),
    )
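
The function above expects the downloaded `.npz` archive to contain the six arrays it reads. The sketch below writes a tiny synthetic archive with the same keys and loads it back the same way. The array sizes, the `uint8` dtype, the `(n, 1)` label shape, and the file name `toy_mnist3d.npz` are assumptions made for illustration.

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)


def fake_split(n):
    """Return n random 28x28x28 volumes and labels shaped (n, 1)."""
    images = rng.integers(0, 256, size=(n, 28, 28, 28), dtype=np.uint8)
    labels = rng.integers(0, 2, size=(n, 1), dtype=np.uint8)
    return images, labels


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_mnist3d.npz")
    train_x, train_y = fake_split(4)
    val_x, val_y = fake_split(2)
    test_x, test_y = fake_split(2)
    np.savez(path, train_images=train_x, train_labels=train_y,
             val_images=val_x, val_labels=val_y,
             test_images=test_x, test_labels=test_y)

    # Read the arrays back the same way the notebook does
    with np.load(path) as data:
        train_videos = data["train_images"]
        train_labels = data["train_labels"].flatten()  # (n, 1) -> (n,)

print(train_videos.shape, train_labels.shape)  # (4, 28, 28, 28) (4,)
```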

In [8]:
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frames tensors and parse the labels."""
    # Preprocess images
    frames = tf.image.convert_image_dtype(
        frames[
            ..., tf.newaxis
        ],  # The new axis is to help for further processing with Conv3D layers
        tf.float32,
    )
    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label


def prepare_dataloader(
        videos: np.ndarray,
        labels: np.ndarray,
        loader_type: str = "train",
        batch_size: int = BATCH_SIZE,
    ):
    """Utility function to prepare the dataloader."""
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == "train":
        # Use the `batch_size` argument (not the global) to size the shuffle buffer
        dataset = dataset.shuffle(batch_size * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataloader


def get_dataloaders(dataset_name):
    # Get the metadata of the dataset
    info = medmnist.INFO[dataset_name]

    # Get the dataset
    prepared_dataset = download_and_prepare_dataset(info)
    (train_videos, train_labels) = prepared_dataset[0]
    (valid_videos, valid_labels) = prepared_dataset[1]
    (test_videos, test_labels) = prepared_dataset[2]

    # Prepare DataLoaders
    trainloader = prepare_dataloader(train_videos, train_labels, "train")
    validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
    testloader = prepare_dataloader(test_videos, test_labels, "test")

    return (trainloader, validloader, testloader)
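
To make the preprocessing step concrete, here is a NumPy-only sketch of what `preprocess` does to a single sample, assuming the volumes are stored as `uint8`: `tf.image.convert_image_dtype` rescales integer inputs to `[0, 1]` when converting to a float type, and the trailing axis adds the channel dimension. `preprocess_np` is a hypothetical stand-in, not a function from the notebook.

```python
import numpy as np


def preprocess_np(frames, label):
    """NumPy stand-in for `preprocess`, assuming uint8 input volumes."""
    # Add the channel axis: (T, H, W) -> (T, H, W, 1)
    frames = frames[..., np.newaxis]
    # Rescale uint8 intensities from [0, 255] to [0, 1]
    frames = frames.astype(np.float32) / 255.0
    return frames, np.float32(label)


volume = np.full((28, 28, 28), 255, dtype=np.uint8)
frames, label = preprocess_np(volume, 1)
print(frames.shape, frames.dtype, frames.max(), label)
```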

## Define Structures
### Define Embedding Methods

In [9]:
# Define Uniform Frame Sampling, an extension of the 2D patch embedding method to video.
class UniformFrameSampling(layers.Layer):
    def __init__(self, embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.projection2d = layers.Conv2D(
            filters=embed_dim,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="valid"
        )
        self.concat = layers.Concatenate(axis=1)
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):
        t = videos.shape[1]
        projected_patches = []
        # Sample n_t frames, one every `patch_size` frames (frame 0 is skipped)
        for frame in range(self.patch_size, t, self.patch_size):
            patch = self.projection2d(videos[:, frame])  # (B, n_h, n_w, embed_dim)
            # Add a time axis so the sampled frames can be concatenated along it
            patch = tf.expand_dims(patch, axis=1)  # (B, 1, n_h, n_w, embed_dim)
            projected_patches.append(patch)
        projected_patches = self.concat(projected_patches)  # (B, n_t, n_h, n_w, embed_dim)
        flattened_patches = self.flatten(projected_patches)  # (B, num_patches, embed_dim)
        return flattened_patches
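
The shape bookkeeping in `UniformFrameSampling` can be checked without TensorFlow. The NumPy sketch below picks the same frames as the loop above and cuts each one into non-overlapping 8×8 patches. It omits the learned `Conv2D` projection, so each token is a raw flattened patch of 64 values instead of a 128-dimensional embedding. `sample_frame_patches` is an illustrative helper, not part of the notebook.

```python
import numpy as np


def sample_frame_patches(video, patch_size):
    """Return flattened 2D patches from uniformly sampled frames.

    video: array of shape (T, H, W); returns (num_tokens, patch_size ** 2).
    """
    t, h, w = video.shape
    n_h, n_w = h // patch_size, w // patch_size
    tokens = []
    for frame in range(patch_size, t, patch_size):
        # Crop the border that does not fill a whole patch ("valid" padding)
        image = video[frame, : n_h * patch_size, : n_w * patch_size]
        # (n_h, p, n_w, p) -> (n_h, n_w, p, p) -> (n_h * n_w, p * p)
        patches = image.reshape(n_h, patch_size, n_w, patch_size)
        patches = patches.transpose(0, 2, 1, 3).reshape(n_h * n_w, -1)
        tokens.append(patches)
    return np.concatenate(tokens, axis=0)


video = np.arange(28 * 28 * 28, dtype=np.float32).reshape(28, 28, 28)
frame_tokens = sample_frame_patches(video, patch_size=8)
print(frame_tokens.shape)  # (27, 64)
```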

In [10]:
# Define Tubelet Embedding, a spatio-temporal embedding for video.
class TubeletEmbedding(layers.Layer):
    def __init__(self, embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.projection3d = layers.Conv3D(
            filters=embed_dim,
            kernel_size=(patch_size, patch_size, patch_size),
            strides=(patch_size, patch_size, patch_size),
            padding="valid"
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):  # videos.shape = (B, T, H, W, C)
        projected_patches = self.projection3d(videos)  # (B, n_t, n_h, n_w, embed_dim)
        flattened_patches = self.flatten(projected_patches)  # (B, num_patches, embed_dim)
        return flattened_patches
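
Because the `Conv3D` kernel size equals its stride, the layer amounts to one shared linear projection applied to each non-overlapping 8×8×8 tubelet. The NumPy sketch below shows that view. The random weight matrix only stands in for the learned kernel, the bias is omitted, and `extract_tubelets` is an illustrative name.

```python
import numpy as np


def extract_tubelets(video, patch_size):
    """Split a (T, H, W) volume into flattened non-overlapping tubelets."""
    p = patch_size
    n_t, n_h, n_w = (dim // p for dim in video.shape)
    # Drop the border that does not fill a whole tubelet ("valid" padding)
    cropped = video[: n_t * p, : n_h * p, : n_w * p]
    blocks = cropped.reshape(n_t, p, n_h, p, n_w, p)
    # Bring the three block indices to the front, then flatten each block
    blocks = blocks.transpose(0, 2, 4, 1, 3, 5)
    return blocks.reshape(n_t * n_h * n_w, p ** 3)


rng = np.random.default_rng(0)
video = rng.random((28, 28, 28), dtype=np.float32)
tubelets = extract_tubelets(video, patch_size=8)

# Stand-in for the learned projection to a 128-dimensional embedding
weights = rng.standard_normal((8 ** 3, 128)).astype(np.float32)
tubelet_tokens = tubelets @ weights
print(tubelets.shape, tubelet_tokens.shape)  # (27, 512) (27, 128)
```

Each tubelet token is built from 512 input values, compared with 64 for a single-frame 8×8 patch, which is where the temporal context of this tokenizer comes from.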

In [11]:
# Positional encoding for both embedding methods.
class PositionalEncoder(layers.Layer):
    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        _, num_tokens, _ = input_shape
        self.position_embedding = layers.Embedding(
            input_dim=num_tokens, output_dim=self.embed_dim
        )
        self.positions = tf.range(start=0, limit=num_tokens, delta=1)

    def call(self, encoded_tokens):
        # Encode the positions and add it to the encoded tokens
        encoded_positions = self.position_embedding(self.positions)
        encoded_tokens = encoded_tokens + encoded_positions
        return encoded_tokens
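
The learned positional encoding is a lookup table with one row per token position, added to every sequence in the batch by broadcasting. A NumPy sketch follows, with a random table standing in for the trained `Embedding` weights and shapes taken from this notebook's settings (27 tokens, 128 dimensions); the batch size of 4 is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_tokens, embed_dim = 4, 27, 128

# Stand-in for the trainable embedding table: one row per position
position_table = rng.standard_normal((num_tokens, embed_dim))
positions = np.arange(num_tokens)

tokens = rng.standard_normal((batch_size, num_tokens, embed_dim))
# (B, N, D) + (N, D): the same positional vectors are added to every sample
encoded = tokens + position_table[positions]
print(encoded.shape)  # (4, 27, 128)
```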

### Define Model

In [12]:
def create_vivit_classifier(
    embedder,
    positional_encoder,
    num_classes,
    input_shape=INPUT_SHAPE,
    transformer_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    embed_dim=PROJECTION_DIM,
    layer_norm_eps=LAYER_NORM_EPS,
):
    # Get the input layer
    inputs = layers.Input(shape=input_shape)
    # Create patches.
    patches = embedder(inputs)
    # Encode patches.
    encoded_patches = positional_encoder(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization and MHSA
        x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=0.1
        )(x1, x1)

        # Skip connection
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer Normalization and MLP
        x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
        x3 = keras.Sequential(
            [
                layers.Dense(units=embed_dim * 4, activation=tf.nn.gelu),
                layers.Dense(units=embed_dim, activation=tf.nn.gelu),
            ]
        )(x3)

        # Skip connection
        encoded_patches = layers.Add()([x3, x2])

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(units=num_classes, activation="softmax")(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
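
The loop body above is a pre-norm Transformer block: layer normalization comes before the attention and MLP sub-layers, and each sub-layer output is added back to its input. The NumPy sketch below reproduces only that data flow for one sequence of 27 tokens. It is a simplification under several assumptions: a single attention head instead of 8, no dropout, no biases, no learned layer-norm scale or shift, random weights, and the tanh approximation of GELU.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """Normalize over the last axis (no learned scale or shift)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gelu(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def self_attention(x, w_q, w_k, w_v, w_o):
    """Single-head scaled dot-product self-attention on (N, D) tokens."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ v) @ w_o


def transformer_block(tokens, params):
    """Pre-norm block: x + Attn(LN(x)), then x + MLP(LN(x))."""
    x2 = tokens + self_attention(layer_norm(tokens), *params["attn"])
    w1, w2 = params["mlp"]
    # As in the notebook, GELU follows both dense layers of the MLP
    x3 = gelu(gelu(layer_norm(x2) @ w1) @ w2)
    return x2 + x3


rng = np.random.default_rng(0)
dim = 128
params = {
    "attn": [rng.standard_normal((dim, dim)) * 0.02 for _ in range(4)],
    "mlp": [rng.standard_normal((dim, dim * 4)) * 0.02,
            rng.standard_normal((dim * 4, dim)) * 0.02],
}
tokens = rng.standard_normal((27, dim))
block_output = transformer_block(tokens, params)
pooled = block_output.mean(axis=0)  # global average pooling over tokens
print(block_output.shape, pooled.shape)  # (27, 128) (128,)
```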

## Train
Training metrics are logged to TensorBoard.

In [15]:
def run_experiment(dataset_info, learning_rate, use_tubelet_embedding=True):
    with tf.device(DEVICE_NAME):
        # Define dataset
        dataset_name, num_classes = dataset_info
        trainloader, validloader, testloader = get_dataloaders(dataset_name)

        # Define embedder
        if use_tubelet_embedding:
            embedder = TubeletEmbedding(
                embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
            )
        else:
            embedder = UniformFrameSampling(
                embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
            )

        model = create_vivit_classifier(
            embedder,
            positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
            num_classes=num_classes
        )
        optimizer = tfa.optimizers.AdamW(
            learning_rate=learning_rate, weight_decay=WEIGHT_DECAY
        )
        model.compile(
            optimizer=optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=[
                keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
            ],
        )
        start = time.time()

        # Define checkpoint callbacks
        checkpoint_filepath = "/tmp/checkpoint"
        checkpoint_callback = keras.callbacks.ModelCheckpoint(
            checkpoint_filepath,
            monitor="val_accuracy",
            save_best_only=True,
            save_weights_only=True,
        )

        # Define tensorboard callbacks
        embedder_name = "tubelet" if use_tubelet_embedding else "sampling"
        log_dir = f"logs/fit/{dataset_name}/{embedder_name}/lr:{learning_rate}/" \
                  + datetime.now().strftime("%Y%m%d-%H%M%S")
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

        # Train the model
        _ = model.fit(
            trainloader,
            epochs=EPOCHS,
            validation_data=validloader,
            callbacks=[tensorboard_callback, checkpoint_callback],
        )

        model.load_weights(checkpoint_filepath)
        _, accuracy, top_5_accuracy = model.evaluate(testloader)
        print(f"*** Data: {dataset_name}, Embedder: {embedder_name}, LR: {learning_rate} ***")
        print(f"Test accuracy: {round(accuracy * 100, 2)}%")
        print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
        print(f"Time elapsed: {round(time.time() - start, 4)}s")

In [16]:
# Clear out prior logging data.
!rm -rf logs/fit

for dataset_info in DATASET_INFOS:
    for learning_rate in LEARNING_RATES:
        run_experiment(dataset_info, learning_rate, use_tubelet_embedding=False)
        run_experiment(dataset_info, learning_rate, use_tubelet_embedding=True)

Downloading data from https://zenodo.org/record/6496656/files/organmnist3d.npz?download=1
Epoch 1/60
...
Epoch 60/60
*** Data: organmnist3d, Embedder: sampling, LR: 0.001 ***
Test accuracy: 59.84%
Test top 5 accuracy: 91.64%
Time elapsed: 158.6115s
Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
E
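
The nested loop above launches one run per dataset, learning rate, and tokenizer. As a sketch, the grid can be enumerated up front to see how many runs that is and which TensorBoard log directory prefix each one writes to (the timestamp suffix is left out here). The variable names are illustrative.

```python
from itertools import product

dataset_names = ["organmnist3d", "nodulemnist3d", "adrenalmnist3d",
                 "fracturemnist3d", "vesselmnist3d", "synapsemnist3d"]
learning_rates = [1e-3, 1e-4, 1e-5]
embedder_names = ["sampling", "tubelet"]

# Same ordering as the nested loop: dataset, then learning rate, then embedder
runs = list(product(dataset_names, learning_rates, embedder_names))
log_prefixes = [f"logs/fit/{name}/{embedder}/lr:{lr}/" for name, lr, embedder in runs]

print(len(runs))        # 36
print(log_prefixes[0])  # logs/fit/organmnist3d/sampling/lr:0.001/
```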

## Check the result
The full results are available on [TensorBoard.dev](https://tensorboard.dev/experiment/PKs2SEeNQLO68B4tB7U8tg/).

In [None]:
# # Show tensorboard extension
# %tensorboard --logdir logs/fit

In [None]:
# Upload the experiment to TensorBoard.dev
!tensorboard dev upload --logdir logs \
    --name "Tokenization comparison in ViViT" \
    --description "Simple comparison between Uniform Frame Sampling and Tubelet Embedding"


New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/PKs2SEeNQLO