In [1]:
! pip show opencv-python

Name: opencv-python
Version: 4.6.0.66
Summary: Wrapper package for OpenCV python bindings.
Home-page: https://github.com/skvark/opencv-python
Author: None
Author-email: None
License: MIT
Location: /anaconda/envs/azureml_py38/lib/python3.8/site-packages
Requires: numpy, numpy
Required-by: 


In [2]:
import tensorflow as tf
# tf.test.gpu_device_name()

## Introduction

Videos are sequences of images. Let's assume you have an image
representation model (CNN, ViT, etc.) and a sequence model
(RNN, LSTM, etc.) at hand. We ask you to tweak the model for video
classification. The simplest approach would be to apply the image
model to individual frames, use the sequence model to learn
sequences of image features, then apply a classification head on
the learned sequence representation.
The Keras example
[Video Classification with a CNN-RNN Architecture](https://keras.io/examples/vision/video_classification/)
explains this approach in detail. Alernatively, you can also
build a hybrid Transformer-based model for video classification as shown in the Keras example
[Video Classification with Transformers](https://keras.io/examples/vision/video_transformers/).

In this example, we minimally implement
[ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
by Arnab et al., a **pure Transformer-based** model
for video classification. The authors propose a novel embedding scheme
and a number of Transformer variants to model video clips. We implement
the embedding scheme and one of the variants of the Transformer
architecture, for simplicity.

This example requires TensorFlow 2.6 or higher, and the `medmnist`
package, which can be installed by running the code cell below.

## Imports

In [3]:
import os
import io
import cv2
import imageio
import medmnist
#import ipywidgets
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Setting seed for reproducibility
SEED = 42
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
keras.utils.set_random_seed(SEED)

## Hyperparameters

The hyperparameters are chosen via hyperparameter
search. You can learn more about the process in the "conclusion" section.

In [4]:
import os
import io
import cv2
import imageio
import medmnist
#import ipywidgets
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# DATA
DATASET_NAME = "RWF-2000"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (28, 28, 28, 1)
NUM_CLASSES = 2
train_ratio = 0.7

# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# TRAINING
EPOCHS = 25

# TUBELET EMBEDDING
PATCH_SIZE = (8, 8, 8)
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 2

# ViViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8

In [5]:
ROOT_PATH = '../datasets/RWF-2000/RWF-2000/train'

Fight_PATH = ROOT_PATH + '/Fight'
Non_Fight_PATH = ROOT_PATH + '/NonFight'

OUT_PATH = '../datasets/RWF-2000/RWF-2000/Extracted frames'

In [6]:
def extract_frames(directory, dimensions=(INPUT_SHAPE[1], INPUT_SHAPE[2]), packet_length = INPUT_SHAPE[0], save_dir_path = None):
    """ 
    Extract packets from video directory

    Parameters
    -----------
    directory : str
    A directory on the disk that contains the videos 

    dimensions: Tuple of shape 2 , optional
    The desired frame dimensions to be stored
    default = (50,50)

    packet_length: int , optional
    Packet length , Default = 15

    save_dir : str , optional
    The path to which the axtracted data should be saved

    Returns
    -------
    ndarray 
    The extracted packets from each video stacked together 
    """
    data = []

    for video_name in os.listdir(directory):
        video = cv2.VideoCapture(directory + '/' + video_name)
        packet = []
        i = 0
        while video.isOpened():
            ret, frame = video.read()

            if not ret: # no more frames
                break

            del ret

    #         if i % 2 == 0: # capture one in every 2 frames
    #             i += 1
    #             continue

            # capturing the frame

            frame = cv2.resize(frame, dimensions, interpolation = cv2.INTER_AREA)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            packet.append(frame)
            del frame

            if len(packet) == packet_length: 
                '''
                consecutive packets share 14 of the 15 frames to generate more data
                '''
                # packet itself is not normalized
                stacked = np.array(packet) # convert to numpy and normalize packet
                data.append(stacked.copy()) # this .copy() is not in the original code 
                packet.pop(0) 

            i += 1

        video.release()
        cv2.destroyAllWindows()

    del packet
    # TODO: read the docs for the next 2 lines
    data = np.stack(data, axis= 0)
    data = np.moveaxis(data, 1 ,-1)

    # save to disk
    if(save_dir_path is not None):
        np.save(f'{save_dir_path}/data', data)

    return data

In [7]:
fight_frames = np.load('../datasets/RWF-2000/RWF-2000/Extracted frames/violence.npy')

In [8]:
non_fight_frames = np.load('../datasets/RWF-2000/RWF-2000/Extracted frames/non-violence.npy')

In [7]:
# fight_frames = extract_frames(Fight_PATH)

In [8]:
fight_frames.shape

(98400, 28, 28, 28)

In [9]:
# non_fight_frames = extract_frames(Non_Fight_PATH)

In [10]:
non_fight_frames.shape

(98400, 28, 28, 28)

In [11]:
# np.save('../datasets/RWF-2000/RWF-2000/Extracted frames/non-violence', non_fight_frames)

In [12]:
# np.save('../datasets/RWF-2000/RWF-2000/Extracted frames/violence', fight_frames)

In [13]:
n_train_samples = int(train_ratio*fight_frames.shape[0])
train_nd_fight_videos, train_nd_fight_labels = fight_frames[:n_train_samples], np.full((n_train_samples,), 0)

n_valid_samples = fight_frames.shape[0] - n_train_samples
valid_nd_fight_videos, valid_nd_fight_labels = fight_frames[n_train_samples:], np.full((n_valid_samples,), 0)

In [14]:
train_nd_fight_videos.shape, train_nd_fight_labels.shape

((68880, 28, 28, 28), (68880,))

In [15]:
valid_nd_fight_videos.shape, valid_nd_fight_labels.shape

((29520, 28, 28, 28), (29520,))

In [16]:
n_train_samples = int(train_ratio*non_fight_frames.shape[0])
train_nd_non_fight_videos, train_nd_non_fight_labels = non_fight_frames[:n_train_samples], np.full((n_train_samples,), 1)

n_valid_samples = non_fight_frames.shape[0] - n_train_samples
valid_nd_non_fight_videos, valid_nd_non_fight_labels = non_fight_frames[n_train_samples:], np.full((n_valid_samples,), 1)

In [17]:
train_nd_non_fight_videos.shape, train_nd_non_fight_labels.shape

((68880, 28, 28, 28), (68880,))

In [18]:
valid_nd_non_fight_videos.shape, valid_nd_non_fight_labels.shape

((29520, 28, 28, 28), (29520,))

In [19]:
train_nd_videos = np.concatenate((train_nd_fight_videos, train_nd_non_fight_videos), axis = 0)
valid_nd_videos = np.concatenate((valid_nd_fight_videos, valid_nd_non_fight_videos), axis = 0)

train_nd_labels = np.concatenate((train_nd_fight_labels, train_nd_non_fight_labels), axis = 0)
valid_nd_labels = np.concatenate((valid_nd_fight_labels, valid_nd_non_fight_labels), axis = 0)

In [20]:
train_nd_videos.shape, train_nd_labels.shape, valid_nd_videos.shape, valid_nd_labels.shape

((137760, 28, 28, 28), (137760,), (59040, 28, 28, 28), (59040,))

### `tf.data` pipeline

In [21]:

@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frames tensors and parse the labels."""
    # Preprocess images
    frames = tf.image.convert_image_dtype(
        frames[
            ..., tf.newaxis
        ],  # The new axis is to help for further processing with Conv3D layers
        tf.float32,
    )
    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label


def prepare_dataloader(
    videos: np.ndarray,
    labels: np.ndarray,
    loader_type: str = "train",
    batch_size: int = BATCH_SIZE,
):
    """Utility function to prepare the dataloader."""
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == "train":
        dataset = dataset.shuffle(BATCH_SIZE * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataloader


trainloader = prepare_dataloader(train_nd_videos, train_nd_labels, "train")
validloader = prepare_dataloader(valid_nd_videos, valid_nd_labels, "valid")
# testloader = prepare_dataloader(test_videos, test_labels, "test")

2022-11-18 17:01:38.217950: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2022-11-18 17:01:38.218032: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (gradinstance1): /proc/driver/nvidia/version does not exist
2022-11-18 17:01:38.238894: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-18 17:01:43.444163: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 3024107520 exceeds 10% of free system memory.
2022-11-18 17:01:47.471813: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 1296046080 exceeds 10% of free system memory.


## Tubelet Embedding

In ViTs, an image is divided into patches, which are then spatially
flattened, a process known as tokenization. For a video, one can
repeat this process for individual frames. **Uniform frame sampling**
as suggested by the authors is a tokenization scheme in which we
sample frames from the video clip and perform simple ViT tokenization.

| ![uniform frame sampling](https://i.imgur.com/aaPyLPX.png) |
| :--: |
| Uniform Frame Sampling [Source](https://arxiv.org/abs/2103.15691) |

**Tubelet Embedding** is different in terms of capturing temporal
information from the video.
First, we extract volumes from the video -- these volumes contain
patches of the frame and the temporal information as well. The volumes
are then flattened to build video tokens.

| ![tubelet embedding](https://i.imgur.com/9G7QTfV.png) |
| :--: |
| Tubelet Embedding [Source](https://arxiv.org/abs/2103.15691) |

In [22]:
class TubeletEmbedding(layers.Layer):
    def __init__(self, embed_dim, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.projection = layers.Conv3D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="VALID",
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))

    def call(self, videos):
        projected_patches = self.projection(videos)
        flattened_patches = self.flatten(projected_patches)
        return flattened_patches


## Positional Embedding

This layer adds positional information to the encoded video tokens.

In [23]:
class PositionalEncoder(layers.Layer):
    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim

    def build(self, input_shape):
        _, num_tokens, _ = input_shape
        self.position_embedding = layers.Embedding(
            input_dim=num_tokens, output_dim=self.embed_dim
        )
        self.positions = tf.range(start=0, limit=num_tokens, delta=1)

    def call(self, encoded_tokens):
        # Encode the positions and add it to the encoded tokens
        encoded_positions = self.position_embedding(self.positions)
        encoded_tokens = encoded_tokens + encoded_positions
        return encoded_tokens

## Video Vision Transformer

The authors suggest 4 variants of Vision Transformer:

- Spatio-temporal attention
- Factorized encoder
- Factorized self-attention
- Factorized dot-product attention

In this example, we will implement the **Spatio-temporal attention**
model for simplicity. The following code snippet is heavily inspired from
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
One can also refer to the
[official repository of ViViT](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit)
which contains all the variants, implemented in JAX.

In [24]:
def create_vivit_classifier(
    tubelet_embedder,
    positional_encoder,
    input_shape=INPUT_SHAPE,
    transformer_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    embed_dim=PROJECTION_DIM,
    layer_norm_eps=LAYER_NORM_EPS,
    num_classes=NUM_CLASSES,
):
    # Get the input layer
    inputs = layers.Input(shape=input_shape)
    # Create patches.
    patches = tubelet_embedder(inputs)
    # Encode patches.
    encoded_patches = positional_encoder(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization and MHSA
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=0.1
        )(x1, x1)

        # Skip connection
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer Normalization and MLP
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = keras.Sequential(
            [
                layers.Dense(units=embed_dim * 4, activation=tf.nn.gelu),
                layers.Dense(units=embed_dim, activation=tf.nn.gelu),
            ]
        )(x3)

        # Skip connection
        encoded_patches = layers.Add()([x3, x2])

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(units=num_classes, activation="softmax")(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


In [27]:
model_checkpoint = ModelCheckpoint('../checkpoints/RWF-2000',
                monitor='val_loss',
                save_best_only=True)

early_stopping = EarlyStopping(monitor="val_loss",
              min_delta=0,
              patience=2,
              restore_best_weights=True)
callbacks = [model_checkpoint, early_stopping]

## Train

In [28]:

def run_experiment():
    mirrored_strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
    with mirrored_strategy.scope():
    
        # Initialize model
        model = create_vivit_classifier(
            tubelet_embedder=TubeletEmbedding(
                embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
            ),
            positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
        )

        # Compile the model with the optimizer, loss function
        # and the metrics.
        optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
        model.compile(
            optimizer=optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
        )

        # Train the model.

        _ = model.fit(trainloader, epochs=EPOCHS, validation_data=validloader)

    #     _, accuracy, top_5_accuracy = model.evaluate(testloader)
    #     print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    #     print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return model


model = run_experiment()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


2022-11-18 17:04:54.988149: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_UINT8
      type: DT_INT64
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 137760
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 28
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRO

Epoch 1/25

2022-11-18 17:14:59.740907: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_UINT8
      type: DT_INT64
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 59040
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:5"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 28
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PROD

Epoch 2/25
   1/4305 [..............................] - ETA: 11:26 - loss: 11.5677 - accuracy: 0.0000e+00

2022-11-18 17:16:34.599218: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




2022-11-18 17:26:27.615690: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 3/25
   1/4305 [..............................] - ETA: 12:50 - loss: 12.6368 - accuracy: 0.0000e+00

2022-11-18 17:27:58.572073: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




2022-11-18 17:37:40.307121: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 4/25
   1/4305 [..............................] - ETA: 11:54 - loss: 9.5706 - accuracy: 0.0000e+00

2022-11-18 17:39:17.339260: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




2022-11-18 17:48:58.267669: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 5/25
   1/4305 [..............................] - ETA: 12:02 - loss: 10.1596 - accuracy: 0.0000e+00

2022-11-18 17:50:28.344977: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




2022-11-18 18:00:01.548582: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 6/25
   1/4305 [..............................] - ETA: 11:08 - loss: 9.8571 - accuracy: 0.0000e+00

2022-11-18 18:01:31.150858: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




2022-11-18 18:11:02.796236: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 7/25
   1/4305 [..............................] - ETA: 11:37 - loss: 9.1498 - accuracy: 0.0000e+00

2022-11-18 18:12:34.107283: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




KeyboardInterrupt: 

## Inference

In [None]:
# NUM_SAMPLES_VIZ = 25
# testsamples, labels = next(iter(testloader))
# testsamples, labels = testsamples[:NUM_SAMPLES_VIZ], labels[:NUM_SAMPLES_VIZ]

# ground_truths = []
# preds = []
# videos = []

# for i, (testsample, label) in enumerate(zip(testsamples, labels)):
#     # Generate gif
#     with io.BytesIO() as gif:
#         imageio.mimsave(gif, (testsample.numpy() * 255).astype("uint8"), "GIF", fps=5)
#         videos.append(gif.getvalue())

#     # Get model prediction
#     output = model.predict(tf.expand_dims(testsample, axis=0))[0]
#     pred = np.argmax(output, axis=0)

#     ground_truths.append(label.numpy().astype("int"))
#     preds.append(pred)


# def make_box_for_grid(image_widget, fit):
#     """Make a VBox to hold caption/image for demonstrating option_fit values.

#     Source: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
#     """
#     # Make the caption
#     if fit is not None:
#         fit_str = "'{}'".format(fit)
#     else:
#         fit_str = str(fit)

#     h = ipywidgets.HTML(value="" + str(fit_str) + "")

#     # Make the green box with the image widget inside it
#     boxb = ipywidgets.widgets.Box()
#     boxb.children = [image_widget]

#     # Compose into a vertical box
#     vb = ipywidgets.widgets.VBox()
#     vb.layout.align_items = "center"
#     vb.children = [h, boxb]
#     return vb


# boxes = []
# for i in range(NUM_SAMPLES_VIZ):
#     ib = ipywidgets.widgets.Image(value=videos[i], width=100, height=100)
#     true_class = info["label"][str(ground_truths[i])]
#     pred_class = info["label"][str(preds[i])]
#     caption = f"T: {true_class} | P: {pred_class}"

#     boxes.append(make_box_for_grid(ib, caption))

# ipywidgets.widgets.GridBox(
#     boxes, layout=ipywidgets.widgets.Layout(grid_template_columns="repeat(5, 200px)")
# )

## Final thoughts

With a vanilla implementation, we achieve ~79-80% Top-1 accuracy on the
test dataset.

The hyperparameters used in this tutorial were finalized by running a
hyperparameter search using
[W&B Sweeps](https://docs.wandb.ai/guides/sweeps).
You can find out our sweeps result
[here](https://wandb.ai/minimal-implementations/vivit/sweeps/66fp0lhz)
and our quick analysis of the results
[here](https://wandb.ai/minimal-implementations/vivit/reports/Hyperparameter-Tuning-Analysis--VmlldzoxNDEwNzcx).

For further improvement, you could look into the following:

- Using data augmentation for videos.
- Using a better regularization scheme for training.
- Apply different variants of the transformer model as in the paper.

We would like to thank [Anurag Arnab](https://anuragarnab.github.io/)
(first author of ViViT) for helpful discussion. We are grateful to
[Weights and Biases](https://wandb.ai/site) program for helping with
GPU credits.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/video-vision-transformer) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/video-vision-transformer-CT).