
VideoMAE


arXiv · keras-2.12 · Open In Colab · Hugging Face

Video masked autoencoders (VideoMAE) are data-efficient learners for self-supervised video pre-training (SSVP). Inspired by the recent ImageMAE, VideoMAE uses customized video tube masking with an extremely high masking ratio. This simple design makes video reconstruction a more challenging self-supervision task, which leads the model to extract more effective video representations during pre-training. Some highlights of VideoMAE:

  • Masked Video Modeling for Video Pre-Training
  • A Simple, Efficient and Strong Baseline in SSVP
  • High performance, but NO extra data required
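Tube masking picks one random set of spatial patches and masks that same set at every temporal position, so each masked patch becomes a "tube" through time. The NumPy sketch below illustrates the idea under stated assumptions: the function name `tube_mask` is hypothetical, the token grid `(8, 14, 14)` assumes 16 frames, a tubelet size of 2, 224x224 input and 16x16 patches, and the per-frame masked count is rounded down. It is not this repository's `TubeMaskingGenerator`.

```python
import numpy as np

def tube_mask(window_size, mask_ratio, seed=None):
    """Return a flat boolean mask (True = masked) that is shared across time.

    window_size: (t, h, w) token grid, e.g. (8, 14, 14) for 16 frames with
    tubelet size 2, 224x224 input and 16x16 patches (assumed layout).
    """
    t, h, w = window_size
    patches_per_frame = h * w
    num_masked = int(mask_ratio * patches_per_frame)  # rounded down
    rng = np.random.default_rng(seed)
    # Choose one random spatial mask for a single temporal slot ...
    frame_mask = np.zeros(patches_per_frame, dtype=bool)
    chosen = rng.choice(patches_per_frame, size=num_masked, replace=False)
    frame_mask[chosen] = True
    # ... and repeat it for every temporal slot, which forms the "tubes".
    return np.tile(frame_mask, t)

mask = tube_mask((8, 14, 14), mask_ratio=0.9, seed=0)
print(mask.shape, int(mask.sum()))  # (1568,) 1408
```

Because the spatial pattern never changes over time, a masked patch cannot simply be copied from a neighbouring frame, which is what makes reconstruction harder than with independent per-frame masking.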

This is an unofficial Keras implementation of the model from VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training. The official PyTorch implementation can be found here.

News

  • [27-10-2023]: Code of Video-FocalNet in Keras becomes available.
  • [15-10-2023]: Code of UniFormerV2 (UFV2) in Keras becomes available.
  • [06-10-2023]: Code of Video Swin Transformer in Keras becomes available.
  • [24-10-2023]: Kinetics-400 test dataset is available on Kaggle, link.
  • [9-10-2023]: TensorFlow SavedModel (format) checkpoints, link.
  • [6-10-2023]: VideoMAE integrated into Hugging Face Spaces.
  • [4-10-2023]: VideoMAE checkpoints for SSV2 and UCF101 become available, link.
  • [3-10-2023]: VideoMAE checkpoints on Kinetics-400 become available, link.
  • [29-9-2023]: Multi-GPU and TPU-VM fine-tuning are supported.
  • [27-9-2023]: Code of VideoMAE in Keras becomes available.

Install

git clone https://github.com/innat/VideoMAE.git
cd VideoMAE
pip install -e . 

Usage

Several VideoMAE model variants are available (small, base, large, and huge), with checkpoints trained on specific benchmark datasets (Kinetics-400, SSV2, and UCF101). See the release and model zoo pages for details.
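The size variants differ mainly in transformer width and depth. As a back-of-the-envelope check, the sketch below assumes the standard ViT encoder configurations (the widths and depths are an assumption about this port, not read from its code) and counts about 12·d² weights per transformer block: 4·d² for the query, key, value and output projections plus 8·d² for an MLP with expansion ratio 4. Embeddings, norms, biases and the classifier head are ignored.

```python
# Standard ViT encoder configurations (assumed): name -> (embed_dim, depth)
VIT_CONFIGS = {
    "ViT-S": (384, 12),
    "ViT-B": (768, 12),
    "ViT-L": (1024, 24),
    "ViT-H": (1280, 32),
}

def approx_encoder_params(embed_dim: int, depth: int) -> float:
    """Rough encoder size in millions of parameters.

    Per block: 4 * d^2 for the attention projections plus 8 * d^2 for a
    two-layer MLP with hidden size 4 * d. Biases, LayerNorms, patch and
    position embeddings and the classification head are ignored.
    """
    return 12 * embed_dim ** 2 * depth / 1e6

for name, (dim, depth) in VIT_CONFIGS.items():
    print(f"{name}: ~{approx_encoder_params(dim, depth):.0f}M")
```

The estimates (about 21, 85, 302 and 629 million) land close to the fine-tuned parameter counts listed in the model zoo tables (22, 87, 304 and 632), which suggests the transformer blocks dominate the model size.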

Pre-trained Masked Autoencoder

Only inference is provided for the pre-trained VideoMAE models. With a trained checkpoint, the input sample can be reconstructed even at a high mask ratio. For an end-to-end workflow, check the reconstruction.ipynb notebook. Some highlights:

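>>> import tensorflow as tf
>>> from videomae import VideoMAE_ViTS16PT

# pre-trained self-supervised model
# (load the checkpoint that matches the chosen backbone size)
>>> model = VideoMAE_ViTS16PT(img_size=224, patch_size=16)
>>> model.load_weights('TFVideoMAE_B_K400_16x224_PT.h5')

# tube masking over a (16 // 2) x (224 // 16) x (224 // 16) = 8 x 14 x 14 token grid
# (16 frames, tubelet size 2, 16x16 patches); TubeMaskingGenerator, read_video and
# frame_sampling are helpers assumed to be in scope (see reconstruction.ipynb)
>>> window_size = (16 // 2, 224 // 16, 224 // 16)
>>> tube_mask = TubeMaskingGenerator(
...     input_size=window_size,
...     mask_ratio=0.80
... )
>>> make_bool = tube_mask()
>>> bool_masked_pos_tf = tf.constant(make_bool, dtype=tf.int32)
>>> bool_masked_pos_tf = tf.expand_dims(bool_masked_pos_tf, axis=0)  # add batch axis
>>> bool_masked_pos_tf = tf.cast(bool_masked_pos_tf, tf.bool)

# running
>>> container = read_video('sample.mp4')
>>> frames = frame_sampling(container, num_frames=16)
>>> pred_tf = model(frames, bool_masked_pos_tf)
>>> pred_tf.shape
TensorShape([1, 1176, 1536])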

Reconstruction results on a sample from SSV2 with mask_ratio=0.8.
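The shape of the reconstruction output can be checked with simple token arithmetic. The sketch assumes a tubelet size of 2 (consistent with the last dimension, 1536 = 2 x 16 x 16 x 3 pixel values per token), that the number of masked patches per frame is rounded down with `int(...)`, and that the decoder predicts only the masked tokens; none of these details are read from this repository's code, and `token_layout` is a hypothetical helper.

```python
def token_layout(num_frames=16, img_size=224, patch_size=16,
                 tubelet_size=2, channels=3, mask_ratio=0.75):
    """Return (total tokens, masked tokens, values per token) for one clip."""
    t = num_frames // tubelet_size               # temporal token slots
    per_frame = (img_size // patch_size) ** 2    # spatial patches per slot
    total = t * per_frame
    masked = t * int(mask_ratio * per_frame)     # same count in every slot (tubes)
    values_per_token = tubelet_size * patch_size * patch_size * channels
    return total, masked, values_per_token

print(token_layout(mask_ratio=0.75))  # (1568, 1176, 1536)
print(token_layout(mask_ratio=0.80))  # (1568, 1248, 1536)
```

Under these assumptions the second dimension of the prediction equals the number of masked tokens, so it changes with `mask_ratio` while the last dimension stays fixed.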

Fine Tuned Model

With a fine-tuned VideoMAE checkpoint, you can evaluate the model on the benchmark datasets or retrain it on a custom dataset. For an end-to-end workflow, check the retraining.ipynb notebook; it supports both multi-GPU and TPU-VM retraining and evaluation. Some highlights:

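>>> import numpy as np
>>> import tensorflow as tf
>>> from videomae import VideoMAE_ViTS16FT

# load a fine-tuned Kinetics-400 checkpoint with model.load_weights(...) before inference;
# read_video, frame_sampling and label_map_inv (class index -> label) are assumed in scope
>>> model = VideoMAE_ViTS16FT(img_size=224, patch_size=16, num_classes=400)
>>> container = read_video('sample.mp4')
>>> frames = frame_sampling(container, num_frames=16)
>>> y = model(frames)
>>> y.shape
TensorShape([1, 400])

# top-5 class probabilities
>>> probabilities = tf.nn.softmax(y)
>>> probabilities = probabilities.numpy().squeeze(0)
>>> confidences = {
...     label_map_inv[i]: float(probabilities[i])
...     for i in np.argsort(probabilities)[::-1][:5]
... }
>>> confidences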

Classification results on a sample from Kinetics-400.

Top-5 predictions:
{
'playing_cello': 0.6552159786224365,
'snowkiting': 0.0018940207082778215,
'deadlifting': 0.0018381892004981637,
'playing_guitar': 0.001778001431375742,
'playing_recorder': 0.0017528659664094448
}
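The post-processing above (softmax, then sorting class indices by probability) does not depend on TensorFlow. Below is a minimal NumPy sketch that returns only the top-k entries; the function name `top_k`, the toy label list and the logits are hypothetical and only illustrate the computation.

```python
import numpy as np

def top_k(logits, labels, k=5):
    """Return the k most likely {label: probability} entries for one clip.

    logits: 1-D array of raw class scores (e.g. 400 for Kinetics-400).
    labels: sequence mapping class index -> label name.
    """
    z = logits - logits.max()              # shift for numerical stability
    probs = np.exp(z) / np.exp(z).sum()    # softmax
    order = np.argsort(probs)[::-1][:k]    # indices by descending probability
    return {labels[i]: float(probs[i]) for i in order}

labels = ["playing_cello", "snowkiting", "deadlifting"]  # hypothetical toy labels
print(top_k(np.array([2.0, 0.5, 0.1]), labels, k=2))
```

Subtracting the maximum logit before exponentiating leaves the softmax unchanged but avoids overflow for large scores.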

Model Zoo

The pre-trained and fine-tuned models are listed in MODEL_ZOO.md. Some highlights follow.

Kinetics-400

For Kinetics-400, VideoMAE is pre-trained for about 1600 epochs without any extra data. The following checkpoints are available in both TensorFlow SavedModel and h5 formats.

Backbone  #Frame  Top-1 (%)  Top-5 (%)  Params [FT] (M)  Params [PT] (M)  FLOPs
ViT-S     16x5x3  79.0       93.8       22               24               57G
ViT-B     16x5x3  81.5       95.1       87               94               181G
ViT-L     16x5x3  85.2       96.8       304              343              -
ViT-H     16x5x3  86.6       97.1       632              ?                -

?: The official pre-trained ViT-H backbone of VideoMAE has a weight issue; see MCG-NJU/VideoMAE#89 for details. FLOPs are reported for the fine-tuned encoder models (FT) only.
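The #Frame column follows the usual frames x clips x crops notation: 16x5x3 means 16-frame clips sampled at 5 temporal positions, each evaluated on 3 spatial crops, i.e. 15 views per video. A common way to merge the views is to average their softmax scores; the sketch below assumes that scheme, which may differ from the exact aggregation behind the reported numbers. The function name `aggregate_views` and the random logits are placeholders.

```python
import numpy as np

def aggregate_views(view_logits):
    """Average softmax scores over test views; return (predicted class, scores).

    view_logits: array of shape (num_clips * num_crops, num_classes).
    """
    z = view_logits - view_logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)  # per-view softmax
    video_scores = probs.mean(axis=0)                           # average over views
    return int(video_scores.argmax()), video_scores

num_clips, num_crops, num_classes = 5, 3, 400   # 16x5x3 protocol, Kinetics-400
rng = np.random.default_rng(0)
logits = rng.normal(size=(num_clips * num_crops, num_classes))  # placeholder outputs
pred, scores = aggregate_views(logits)
print(pred, scores.shape)
```

Averaging probabilities rather than picking the single most confident view makes the video-level prediction less sensitive to one badly sampled clip or crop.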

Something-Something V2

For SSv2, VideoMAE is pre-trained for about 2400 epochs without any extra data.

Backbone  #Frame  Top-1 (%)  Top-5 (%)  Params [FT] (M)  Params [PT] (M)  FLOPs
ViT-S     16x2x3  66.8       90.3       22               24               57G
ViT-B     16x2x3  70.8       92.4       86               94               181G

UCF101

For UCF101, VideoMAE is pre-trained for about 3200 epochs without any extra data.

Backbone  #Frame  Top-1 (%)  Top-5 (%)  Params [FT] (M)  Params [PT] (M)  FLOPs
ViT-B     16x5x3  91.3       98.5       86               94               181G

Visualization

Some reconstructed video samples produced by VideoMAE with different mask ratios.

Kinetics-400 test set, mask ratios: 0.8, 0.8, 0.9, 0.9
SSv2 test set, mask ratios: 0.9, 0.9
UCF101 test set, mask ratios: 0.8, 0.9
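Reconstruction visualizations typically show the masked input next to the model output. The NumPy sketch below produces the masked-input view for one frame, assuming a 224 x 224 frame, 16 x 16 patches, and a boolean patch mask of shape (14, 14), such as one temporal slice of a tube mask. The function name `apply_patch_mask`, the grey fill value and the random frame and mask are hypothetical; this is not the repository's plotting code.

```python
import numpy as np

def apply_patch_mask(frame, patch_mask, patch_size=16, fill=0.5):
    """Grey out the masked patches of one frame.

    frame: (H, W, C) float array; patch_mask: (H // p, W // p) bool, True = masked.
    """
    # Expand each patch flag into a patch_size x patch_size block of pixels.
    pixel_mask = np.repeat(np.repeat(patch_mask, patch_size, axis=0),
                           patch_size, axis=1)
    out = frame.copy()
    out[pixel_mask] = fill  # the scalar broadcasts over the channel axis
    return out

rng = np.random.default_rng(0)
frame = rng.random((224, 224, 3))          # placeholder frame in [0, 1)
patch_mask = rng.random((14, 14)) < 0.9    # roughly 90% of patches masked
masked = apply_patch_mask(frame, patch_mask)
print(masked.shape)  # (224, 224, 3)
```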

TODO

  • Custom fine-tuning code.
  • Publish on TF-Hub.
  • Support Keras V3 to enable multi-framework backends.

Citation

If you use this videomae implementation in your research, please cite it using the metadata from our CITATION.cff file.

@inproceedings{tong2022videomae,
  title={Video{MAE}: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training},
  author={Zhan Tong and Yibing Song and Jue Wang and Limin Wang},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}