# Imports

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import io
import imageio
import medmnist
import ipywidgets
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Single seed shared by the whole notebook so that runs are repeatable.
SEED = 42
# NOTE(review): presumably requests deterministic cuDNN kernels; confirm it is
# still needed given that CUDA_VISIBLE_DEVICES is set to "-1" above.
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
# Seed the random number generators through Keras.
keras.utils.set_random_seed(SEED)

# Hyperparameters

In [14]:
# DATA
DATASET_NAME = 'organmnist3d'
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
# (depth, height, width, channels) of one volume as fed to the model.
INPUT_SHAPE = (28, 28, 28, 1)
NUM_CLASSES = 11

# OPTIMIZER
LEARNING_RATE = 1e-4
# NOTE(review): WEIGHT_DECAY is declared here but the training cell builds a
# plain Adam optimizer and never passes it on.
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 60

# TUBELET EMBEDDING
PATCH_SIZE = (8, 8, 8)
# Tubelets are cut along all three spatial axes with stride == patch size and
# 'VALID' padding, so the token count is the product of the per-axis counts
# (3 * 3 * 3 = 27 for the values above), not the square of one axis.
NUM_PATCHES = (
    (INPUT_SHAPE[0] // PATCH_SIZE[0])
    * (INPUT_SHAPE[1] // PATCH_SIZE[1])
    * (INPUT_SHAPE[2] // PATCH_SIZE[2])
)

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8

# Dataset

In [3]:
# function for downloading the dataset
def download_and_prepare_dataset(data_info: dict):
    """Download the dataset archive and return its three splits.

    Args:
        data_info: Metadata dict; its 'url' entry is the download location and
            its 'MD5' entry is the checksum handed to `keras.utils.get_file`.

    Returns:
        A 3-tuple ``((train_videos, train_labels), (valid_videos, valid_labels),
        (test_videos, test_labels))``. Labels are flattened to 1-D arrays.
    """
    archive_path = keras.utils.get_file(
        origin=data_info['url'],
        md5_hash=data_info['MD5'],
    )

    # Pull every array out while the archive is still open.
    with np.load(archive_path) as archive:
        train_split = (archive['train_images'], archive['train_labels'].flatten())
        valid_split = (archive['val_images'], archive['val_labels'].flatten())
        test_split = (archive['test_images'], archive['test_labels'].flatten())

    return train_split, valid_split, test_split

In [15]:
# Look up the dataset's metadata entry (download URL, checksum, label names).
info = medmnist.INFO[DATASET_NAME]

# Download the archive and unpack all three splits in one assignment.
prepared_dataset = download_and_prepare_dataset(info)
(
    (train_videos, train_labels),
    (valid_videos, valid_labels),
    (test_videos, test_labels),
) = prepared_dataset

# tf.data pipeline

In [4]:
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Prepare one (volume, label) pair for the model.

    Appends a trailing channel axis to `frames`, converts it to float32 with
    `tf.image.convert_image_dtype`, and casts `label` to float32.

    Returns:
        Tuple ``(frames, label)`` of float32 tensors.
    """
    # NOTE(review): assumes integer-typed frames, which convert_image_dtype
    # rescales while converting -- confirm against the dataset dtype.
    frames_with_channel = frames[..., tf.newaxis]
    converted_frames = tf.image.convert_image_dtype(frames_with_channel, tf.float32)
    return converted_frames, tf.cast(label, tf.float32)

def prepare_dataloader(
    videos: np.ndarray,
    labels: np.ndarray,
    loader_type: str = 'train',
    batch_size: int = BATCH_SIZE,
):
    """Build a batched, prefetched `tf.data` pipeline over (video, label) pairs.

    Args:
        videos: Array of samples; sliced along its first axis.
        labels: Array of labels aligned with `videos`.
        loader_type: Only the value 'train' enables shuffling; any other value
            keeps the original sample order.
        batch_size: Number of samples per batch. Also sizes the shuffle buffer
            (twice the batch size).

    Returns:
        A `tf.data.Dataset` yielding preprocessed ``(frames, label)`` batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == 'train':
        # Size the buffer from the caller's batch_size, not the global
        # BATCH_SIZE, so a custom batch size is honoured consistently.
        dataset = dataset.shuffle(batch_size * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    return dataloader

In [16]:
# Build one loader per split; only the 'train' loader is shuffled.
trainloader, validloader, testloader = (
    prepare_dataloader(split_videos, split_labels, split_name)
    for split_videos, split_labels, split_name in (
        (train_videos, train_labels, 'train'),
        (valid_videos, valid_labels, 'valid'),
        (test_videos, test_labels, 'test'),
    )
)

# Tubelet Embedding

In [5]:
class TubeletEmbedding(layers.Layer):
    """Turn a volume into a sequence of embedded tubelet tokens.

    A `Conv3D` whose kernel size and stride both equal `patch_size` projects
    each non-overlapping tubelet to `embed_dim` channels; the result is then
    reshaped to ``(-1, embed_dim)`` per sample.
    """

    def __init__(self, embed_dim, patch_size, **kwargs):
        super().__init__(**kwargs)
        # Kept on the instance so get_config can serialise them.
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.projection = layers.Conv3D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding='VALID',
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):
        """Project `videos` into tubelets and flatten them into tokens."""
        return self.flatten(self.projection(videos))

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        return {
            **super().get_config(),
            'embed_dim': self.embed_dim,
            'patch_size': self.patch_size,
        }

# Positional Embedding

In [6]:
class PositionalEncoder(layers.Layer):
    """Add a learned position embedding to every token of a sequence."""

    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        """Create the embedding table once the token count is known."""
        # Input is expected to be rank 3; only the middle dimension is used.
        _, token_count, _ = input_shape
        self.position_embedding = layers.Embedding(
            input_dim=token_count,
            output_dim=self.embed_dim,
        )
        # Indices 0 .. token_count - 1, looked up on every call.
        self.positions = tf.range(start=0, limit=token_count, delta=1)

    def call(self, encoded_tokens):
        """Return `encoded_tokens` with the position embeddings added."""
        return encoded_tokens + self.position_embedding(self.positions)

    def get_config(self):
        """Return the base layer config extended with `embed_dim`."""
        return {**super().get_config(), 'embed_dim': self.embed_dim}

# ViViT Transformer
This section implements the **spatio-temporal attention** variant of the ViViT transformer.

In [7]:
def create_vivit_classifier(
    tubelet_embedder,
    positional_encoder,
    input_shape=INPUT_SHAPE,
    transformer_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    embed_dim=PROJECTION_DIM,
    layer_norm_eps=LAYER_NORM_EPS,
    num_classes=NUM_CLASSES,
):
    """Assemble the ViViT classification model.

    Args:
        tubelet_embedder: Layer that maps the input volume to a token sequence.
        positional_encoder: Layer that adds position information to the tokens.
        input_shape: Shape of one input sample (without the batch axis).
        transformer_layers: Number of pre-norm Transformer blocks to stack.
        num_heads: Attention heads per block; each head uses a key dimension
            of ``embed_dim // num_heads``.
        embed_dim: Token width; the MLP expands to ``4 * embed_dim`` and
            projects back to ``embed_dim``.
        layer_norm_eps: Epsilon applied to every LayerNormalization in the
            model, inside the blocks as well as before pooling.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled `keras.Model` mapping inputs to class probabilities.
    """
    # Get the input layer
    inputs = layers.Input(shape=input_shape)
    # Create patches
    patches = tubelet_embedder(inputs)
    # Encode patches
    encoded_patches = positional_encoder(patches)

    # Create multiple layers of the Transformer block
    for _ in range(transformer_layers):
        # Layer normalization and MHSA. Uses the layer_norm_eps argument; it
        # was previously hard-coded to 1e-6, silently ignoring the parameter.
        x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=0.1
        )(x1, x1)

        # Skip connection
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer Normalization and MLP
        x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
        x3 = keras.Sequential([
            layers.Dense(units=embed_dim * 4, activation=tf.nn.gelu),
            layers.Dense(units=embed_dim, activation=tf.nn.gelu)
        ])(x3)

        # Skip connection
        encoded_patches = layers.Add()([x3, x2])

    # Layer normalization and Global average pooling over the token axis
    representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # classify outputs
    outputs = layers.Dense(units=num_classes, activation='softmax')(representation)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

# Train

In [8]:
def run_experiment():
    """Build, train and evaluate the ViViT classifier.

    Trains on `trainloader` for `EPOCHS` epochs while validating on
    `validloader`, then evaluates on `testloader` and prints the test accuracy
    and top-5 accuracy.

    Returns:
        The trained `keras.Model`.
    """
    # Initialize the model from freshly created embedding layers.
    tubelet_embedder = TubeletEmbedding(embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE)
    positional_encoder = PositionalEncoder(embed_dim=PROJECTION_DIM)
    model = create_vivit_classifier(
        tubelet_embedder=tubelet_embedder,
        positional_encoder=positional_encoder,
    )

    # Compile with the optimizer, the loss function and the tracked metrics.
    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    tracked_metrics = [
        keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name='top-5-accuracy'),
    ]
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=tracked_metrics,
    )

    # Train; the returned History object is not needed here.
    model.fit(trainloader, epochs=EPOCHS, validation_data=validloader)

    # evaluate() yields the loss followed by the metrics in compile order.
    _, accuracy, top_5_accuracy = model.evaluate(testloader)
    print(f'Test accuracy: {round(accuracy*100, 2)}%')
    print(f'Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%')

    return model

In [11]:
# Train and evaluate the classifier; run_experiment returns the trained model.
model = run_experiment()

Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Epoch 23/60
Epoch 24/60
Epoch 25/60
Epoch 26/60
Epoch 27/60
Epoch 28/60
Epoch 29/60
Epoch 30/60
Epoch 31/60
Epoch 32/60
Epoch 33/60
Epoch 34/60
Epoch 35/60
Epoch 36/60
Epoch 37/60
Epoch 38/60
Epoch 39/60
Epoch 40/60
Epoch 41/60
Epoch 42/60


Epoch 43/60
Epoch 44/60
Epoch 45/60
Epoch 46/60
Epoch 47/60
Epoch 48/60
Epoch 49/60
Epoch 50/60
Epoch 51/60
Epoch 52/60
Epoch 53/60
Epoch 54/60
Epoch 55/60
Epoch 56/60
Epoch 57/60
Epoch 58/60
Epoch 59/60
Epoch 60/60
Test accuracy: 80.49%
Test top 5 accuracy: 98.2%


In [12]:
# Persist the trained model to an HDF5 file. The model contains the custom
# TubeletEmbedding and PositionalEncoder layers, which must be supplied via
# `custom_objects` when the file is loaded again.
model.save('saved_models/test_ViViT.h5', save_format='h5')

  layer_config = serialize_layer_fn(layer)


In [9]:
from tensorflow.keras.models import load_model

In [11]:
# Restore the saved model; the custom layer classes have to be mapped by name
# so that they can be rebuilt from their stored configs.
loaded_model = load_model(
    'saved_models/test_ViViT.h5',
    custom_objects={
        "TubeletEmbedding": TubeletEmbedding,
        "PositionalEncoder": PositionalEncoder,
    },
)

In [12]:
# Print the layer-by-layer summary of the restored model as a sanity check.
loaded_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 28, 28, 28,  0           []                               
                                 1)]                                                              
                                                                                                  
 tubelet_embedding (TubeletEmbe  (None, 27, 128)     65664       ['input_1[0][0]']                
 dding)                                                                                           
                                                                                                  
 positional_encoder (Positional  (None, 27, 128)     3456        ['tubelet_embedding[0][0]']      
 Encoder)                                                                                     

# Inference

In [17]:
NUM_SAMPLES_VIZ = 25

# Visualise (at most) the first NUM_SAMPLES_VIZ samples of the first test batch.
testsamples, labels = next(iter(testloader))
testsamples, labels = testsamples[:NUM_SAMPLES_VIZ], labels[:NUM_SAMPLES_VIZ]

# Run one batched forward pass instead of calling predict() once per sample:
# the per-sample calls repeat predict's setup cost and print one progress
# line each. Row k of `probabilities` holds the class scores of sample k.
probabilities = loaded_model.predict(testsamples)
preds = list(np.argmax(probabilities, axis=1))

# Labels were cast to float32 by the input pipeline; convert them back to
# integers so they can be used as class-name keys.
ground_truths = list(labels.numpy().astype('int'))

# Encode every volume as an in-memory GIF, one frame per slice along its
# first axis.
videos = []
for testsample in testsamples:
    with io.BytesIO() as gif:
        imageio.mimsave(gif, (testsample.numpy() * 255).astype('uint8'), 'GIF', fps=5)
        videos.append(gif.getvalue())
    
def make_box_for_grid(image_widget, fit):
    '''
    Build a VBox that stacks a caption above an image widget.

    The caption is `fit` wrapped in single quotes, or the text "None" when
    `fit` is None.

    Source: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
    '''
    # Caption text: quote real values, spell out a missing one.
    fit_str = str(fit) if fit is None else f"'{fit}'"
    caption = ipywidgets.HTML(value=fit_str)

    # Box that holds the image widget.
    image_box = ipywidgets.widgets.Box()
    image_box.children = [image_widget]

    # Stack caption and image vertically, centred.
    container = ipywidgets.widgets.VBox()
    container.layout.align_items = 'center'
    container.children = [caption, image_box]
    return container

# Build one captioned box per collected sample. Iterating over the collected
# lists (rather than range(NUM_SAMPLES_VIZ)) keeps this cell working when the
# first test batch holds fewer samples than NUM_SAMPLES_VIZ.
boxes = []
for video_gif, true_label, predicted_label in zip(videos, ground_truths, preds):
    ib = ipywidgets.widgets.Image(value=video_gif, width=100, height=100)
    # info['label'] is keyed by the class index as a string.
    true_class = info['label'][str(true_label)]
    pred_class = info['label'][str(predicted_label)]
    caption = f'T: {true_class} | P: {pred_class}'

    boxes.append(make_box_for_grid(ib, caption))

# Last expression of the cell: rendered as a 5-column grid.
ipywidgets.widgets.GridBox(
    boxes, layout=ipywidgets.widgets.Layout(grid_template_columns='repeat(5, 200px)')
)

GridBox(children=(VBox(children=(HTML(value="'T: pancreas | P: pancreas'"), Box(children=(Image(value=b'GIF89a…