diff --git a/mmselfsup/datasets/transforms/formatting.py b/mmselfsup/datasets/transforms/formatting.py
index 3ceedf3a7..6e24d799f 100644
--- a/mmselfsup/datasets/transforms/formatting.py
+++ b/mmselfsup/datasets/transforms/formatting.py
@@ -56,9 +56,8 @@ def transform(self,
Returns:
Dict:
-
- - 'inputs' (List[torch.Tensor]): The forward data of models.
- - 'data_samples' (SelfSupDataSample): The annotation info of the
+ - ``inputs`` (List[torch.Tensor]): The forward data of models.
+ - ``data_samples`` (SelfSupDataSample): The annotation info of
the forward data.
"""
packed_results = dict()
@@ -68,9 +67,19 @@ def transform(self,
if not isinstance(img, List):
img = [img]
for i, img_ in enumerate(img):
- if len(img_.shape) < 3:
- img_ = np.expand_dims(img_, -1)
- img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
+            # to handle single-channel images
+            if len(img_.shape) == 2:
+                img_ = np.expand_dims(img_, -1)
+
+ if len(img_.shape) == 3:
+ img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
+ elif len(img_.shape) == 5:
+                # video data is already in (B, C, T, H, W) format; keep as is
+                pass
+ else:
+ raise ValueError(
+ 'img should be 2, 3 or 5 dimensional, '
+ f'instead of {len(img_.shape)} dimensional.')
img[i] = to_tensor(img_)
packed_results['inputs'] = img
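To illustrate the new behavior, here is a minimal sketch (assuming only NumPy; `pack_like` is a hypothetical stand-in for the packing step above, not part of the patch):

```python
import numpy as np

def pack_like(img_: np.ndarray) -> np.ndarray:
    """Mirror of the new dimension handling in PackSelfSupInputs.transform."""
    if img_.ndim == 2:            # single-channel image: (H, W) -> (H, W, 1)
        img_ = np.expand_dims(img_, -1)
    if img_.ndim == 3:            # image: (H, W, C) -> contiguous (C, H, W)
        return np.ascontiguousarray(img_.transpose(2, 0, 1))
    elif img_.ndim == 5:          # video clip: (B, C, T, H, W) passes through
        return img_
    raise ValueError('img should be 2, 3 or 5 dimensional, '
                     f'instead of {img_.ndim} dimensional.')

assert pack_like(np.ones((8, 8))).shape == (1, 8, 8)
assert pack_like(np.ones((4, 3, 8, 8, 8))).shape == (4, 3, 8, 8, 8)
```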
diff --git a/mmselfsup/models/target_generators/hog_generator.py b/mmselfsup/models/target_generators/hog_generator.py
index 53d8515f7..a6ccabaf3 100644
--- a/mmselfsup/models/target_generators/hog_generator.py
+++ b/mmselfsup/models/target_generators/hog_generator.py
@@ -35,8 +35,8 @@ def __init__(self,
self.pool = pool
self.pi = math.pi
weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
- weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1)
- weight_y = weight_x.transpose(2, 3)
+ weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous()
+ weight_y = weight_x.transpose(2, 3).contiguous()
self.register_buffer('weight_x', weight_x)
self.register_buffer('weight_y', weight_y)
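The added `.contiguous()` calls matter because `transpose` only swaps strides and returns a view; a rough check of this, assuming only stock PyTorch:

```python
import torch

# horizontal Sobel kernel, repeated per input channel as in HOGGenerator
weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1)
weight_y = weight_x.transpose(2, 3)  # vertical Sobel, as a strided view

assert not weight_y.is_contiguous()           # transpose returns a view
assert weight_y.contiguous().is_contiguous()  # materialized in row-major order
```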
diff --git a/mmselfsup/models/utils/__init__.py b/mmselfsup/models/utils/__init__.py
index 1d6b7f4ba..e4bc6463c 100644
--- a/mmselfsup/models/utils/__init__.py
+++ b/mmselfsup/models/utils/__init__.py
@@ -4,7 +4,7 @@
RelativeLocDataPreprocessor,
RotationPredDataPreprocessor,
SelfSupDataPreprocessor,
- TwoNormDataPreprocessor)
+ TwoNormDataPreprocessor, VideoDataPreprocessor)
from .ema import CosineEMA
from .extractor import Extractor
from .gather_layer import GatherLayer
@@ -29,6 +29,6 @@
'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'CosineEMA',
'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
'RotationPredDataPreprocessor', 'CAEDataPreprocessor', 'ResLayerExtraNorm',
- 'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor',
- 'PromptTransformerEncoderLayer', 'build_clip_model'
+ 'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor', 'build_clip_model',
+ 'PromptTransformerEncoderLayer', 'VideoDataPreprocessor'
]
diff --git a/mmselfsup/models/utils/data_preprocessor.py b/mmselfsup/models/utils/data_preprocessor.py
index 4901cd3df..19366be6b 100644
--- a/mmselfsup/models/utils/data_preprocessor.py
+++ b/mmselfsup/models/utils/data_preprocessor.py
@@ -2,7 +2,7 @@
from typing import List, Optional, Sequence, Tuple, Union
import torch
-from mmengine.model import ImgDataPreprocessor
+from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
from mmselfsup.registry import MODELS
@@ -290,3 +290,112 @@ def forward(
]
return batch_inputs, batch_data_samples
+
+
+@MODELS.register_module()
+class VideoDataPreprocessor(BaseDataPreprocessor):
+    """Video pre-processor for operations such as normalization and BGR-to-RGB
+    conversion.
+
+    Compared with :class:`mmaction.ActionDataPreprocessor`, this module treats
+    ``inputs`` of the input data as a list of tensors, instead of a single
+    ``torch.Tensor``.
+
+ Args:
+        mean (Sequence[float or int], optional): The pixel mean of channels
+ of images or stacked optical flow. Defaults to None.
+ std (Sequence[float or int], optional): The pixel standard deviation
+ of channels of images or stacked optical flow. Defaults to None.
+ pad_size_divisor (int): The size of padded image should be
+ divisible by ``pad_size_divisor``. Defaults to 1.
+ pad_value (float or int): The padded pixel value. Defaults to 0.
+ bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
+ Defaults to False.
+ format_shape (str): Format shape of input data.
+ Defaults to ``'NCHW'``.
+ """
+
+ def __init__(self,
+ mean: Optional[Sequence[Union[float, int]]] = None,
+ std: Optional[Sequence[Union[float, int]]] = None,
+ pad_size_divisor: int = 1,
+ pad_value: Union[float, int] = 0,
+ bgr_to_rgb: bool = False,
+ format_shape: str = 'NCHW') -> None:
+ super().__init__()
+ self.pad_size_divisor = pad_size_divisor
+ self.pad_value = pad_value
+ self.bgr_to_rgb = bgr_to_rgb
+ self.format_shape = format_shape
+
+ if mean is not None:
+ assert std is not None, 'To enable the normalization in ' \
+ 'preprocessing, please specify both ' \
+ '`mean` and `std`.'
+ # Enable the normalization in preprocessing.
+ self._enable_normalize = True
+ if self.format_shape == 'NCHW':
+ normalizer_shape = (-1, 1, 1)
+ elif self.format_shape == 'NCTHW':
+ normalizer_shape = (-1, 1, 1, 1)
+ else:
+ raise ValueError(f'Invalid format shape: {format_shape}')
+
+ self.register_buffer(
+ 'mean',
+ torch.tensor(mean, dtype=torch.float32).view(normalizer_shape),
+ False)
+ self.register_buffer(
+ 'std',
+ torch.tensor(std, dtype=torch.float32).view(normalizer_shape),
+ False)
+ else:
+ self._enable_normalize = False
+
+ def forward(
+ self,
+ data: dict,
+ training: bool = False
+ ) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Performs normalization, padding and BGR-to-RGB conversion based on
+ ``BaseDataPreprocessor``.
+
+ Args:
+ data (dict): data sampled from dataloader.
+ training (bool): Whether to enable training time augmentation. If
+ subclasses override this method, they can perform different
+ preprocessing strategies for training and testing based on the
+ value of ``training``.
+
+        Returns:
+ Tuple[List[torch.Tensor], Optional[list]]: Data in the same format
+ as the model input.
+ """
+
+ data = [val for _, val in data.items()]
+ batch_inputs, batch_data_samples = self.cast_data(data)
+
+ # ------ To RGB ------
+ if self.bgr_to_rgb:
+ if self.format_shape == 'NCHW':
+ batch_inputs = [
+ batch_input[..., [2, 1, 0], :, :]
+ for batch_input in batch_inputs
+ ]
+ elif self.format_shape == 'NCTHW':
+ batch_inputs = [
+ batch_input[..., [2, 1, 0], :, :, :]
+ for batch_input in batch_inputs
+ ]
+ else:
+ raise ValueError(f'Invalid format shape: {self.format_shape}')
+
+ # -- Normalization ---
+ if self._enable_normalize:
+ batch_inputs = [(batch_input - self.mean) / self.std
+ for batch_input in batch_inputs]
+ else:
+ batch_inputs = [
+ batch_input.to(torch.float32) for batch_input in batch_inputs
+ ]
+
+ return batch_inputs, batch_data_samples
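A short usage sketch of the new preprocessor, mirroring the unit test added at the end of this patch:

```python
import torch
from mmselfsup.models.utils import VideoDataPreprocessor
from mmselfsup.structures import SelfSupDataSample

preprocessor = VideoDataPreprocessor(
    mean=[114.75, 114.75, 114.75],
    std=[57.375, 57.375, 57.375],
    bgr_to_rgb=True,
    format_shape='NCTHW')

# each item in `inputs` is one batched view of shape (B, C, T, H, W)
data = {
    'inputs': [torch.randn(2, 3, 4, 224, 224)],
    'data_sample': [SelfSupDataSample(), SelfSupDataSample()],
}
batch_inputs, batch_data_samples = preprocessor(data)
assert batch_inputs[0].shape == (2, 3, 4, 224, 224)  # normalized float32
```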
diff --git a/projects/maskfeat_video/README.md b/projects/maskfeat_video/README.md
new file mode 100644
index 000000000..ab2850df2
--- /dev/null
+++ b/projects/maskfeat_video/README.md
@@ -0,0 +1,275 @@
+# MaskFeat Pre-training with Video
+
+- [MaskFeat Pre-training with Video](#maskfeat-pre-training-with-video)
+ - [Description](#description)
+ - [Usage](#usage)
+ - [Setup Environment](#setup-environment)
+ - [Data Preparation](#data-preparation)
+ - [Pre-training Commands](#pre-training-commands)
+ - [On Local Single GPU](#on-local-single-gpu)
+ - [On Multiple GPUs](#on-multiple-gpus)
+ - [On Multiple GPUs with Slurm](#on-multiple-gpus-with-slurm)
+ - [Downstream Tasks Commands](#downstream-tasks-commands)
+ - [On Multiple GPUs](#on-multiple-gpus-1)
+ - [On Multiple GPUs with Slurm](#on-multiple-gpus-with-slurm-1)
+ - [Results](#results)
+ - [Citation](#citation)
+ - [Checklist](#checklist)
+
+## Description
+
+
+
+Author: @fangyixiao18
+
+This is the implementation of **MaskFeat** with video datasets, such as Kinetics400.
+
+## Usage
+
+
+
+### Setup Environment
+
+Requirements:
+
+- MMSelfSup >= 1.0.0rc6
+- MMAction2 >= 1.0.0rc3
+
+Please refer to [Get Started](https://mmselfsup.readthedocs.io/en/1.x/get_started.html) documentation of MMSelfSup to finish installation.
+
+Besides, to process the video data, we apply the transforms implemented in MMAction2. The instructions to install MMAction2 can be found in the [Get Started documentation](https://mmaction2.readthedocs.io/en/1.x/get_started.html).
+
+### Data Preparation
+
+You can refer to the [documentation](https://mmaction2.readthedocs.io/en/1.x/user_guides/2_data_prepare.html) in MMAction2.
+
+### Pre-training Commands
+
+First, you need to add the current folder to `PYTHONPATH`, so that Python can find your model files. From the `projects/maskfeat_video/` root directory, please run the command below to add it.
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+Then run the following commands to train the model:
+
+#### On Local Single GPU
+
+```bash
+# train with mim
+mim train mmselfsup ${CONFIG} --work-dir ${WORK_DIR}
+
+# a specific command example
+mim train mmselfsup configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py \
+ --work-dir work_dirs/selfsup/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400/
+
+# train with scripts
+python tools/train.py configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py \
+ --work-dir work_dirs/selfsup/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400/
+```
+
+#### On Multiple GPUs
+
+```bash
+# train with mim
+# a specific command example, using 8 GPUs here
+mim train mmselfsup configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py \
+ --work-dir work_dirs/selfsup/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400/ \
+ --launcher pytorch --gpus 8
+
+# train with scripts
+bash tools/dist_train.sh configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py 8
+```
+
+Note:
+
+- CONFIG: the config files under the directory `configs/`
+- WORK_DIR: the working directory to save configs, logs, and checkpoints
+
+#### On Multiple GPUs with Slurm
+
+```bash
+# train with mim
+mim train mmselfsup configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py \
+ --work-dir work_dirs/selfsup/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/ \
+ --launcher slurm --gpus 16 --gpus-per-node 8 \
+ --partition ${PARTITION}
+
+# train with scripts
+GPUS_PER_NODE=8 GPUS=16 bash tools/slurm_train.sh ${PARTITION} maskfeat-video \
+ configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py \
+ --work-dir work_dirs/selfsup/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/
+```
+
+Note:
+
+- CONFIG: the config files under the directory `configs/`
+- WORK_DIR: the working directory to save configs, logs, and checkpoints
+- PARTITION: the slurm partition you are using
+
+### Downstream Tasks Commands
+
+To evaluate the **MaskFeat MViT** pre-trained with MMSelfSup, we recommend running the downstream tasks with MMAction2:
+
+#### On Multiple GPUs
+
+```bash
+# command example for training
+mim train mmaction2 ${CONFIG} \
+    --work-dir ${WORK_DIR} \
+    --launcher pytorch --gpus 8 \
+ --cfg-options model.backbone.init_cfg.type=Pretrained \
+ model.backbone.init_cfg.checkpoint=${CHECKPOINT} \
+ model.backbone.init_cfg.prefix="backbone." \
+    ${PY_ARGS}
+
+mim train mmaction2 configs/mvit-small_ft-8xb8-coslr-100e_k400.py \
+ --work-dir work_dirs/benchmarks/maskfeat/training_maskfeat-mvit-k400/ \
+    --launcher pytorch --gpus 8 \
+ --cfg-options model.backbone.init_cfg.type=Pretrained \
+ model.backbone.init_cfg.checkpoint=https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400_20230131-87d60b6f.pth \
+ model.backbone.init_cfg.prefix="backbone." \
+ $PY_ARGS
+
+# command example for testing
+mim test mmaction2 configs/mvit-small_ft-8xb16-coslr-100e_k400.py \
+ --checkpoint https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/mvit-small_ft-8xb16-coslr-100e_k400/mvit-small_ft-8xb16-coslr-100e_k400_20230131-5e8303f5.pth \
+ --work-dir work_dirs/benchmarks/maskfeat/maskfeat-mvit-k400/test/ \
+ --launcher pytorch --gpus 8
+```
+
+#### On Multiple GPUs with Slurm
+
+```bash
+mim train mmaction2 ${CONFIG} \
+ --work-dir ${WORK_DIR} \
+ --launcher slurm --gpus 8 --gpus-per-node 8 \
+ --partition ${PARTITION} \
+ --cfg-options model.backbone.init_cfg.type=Pretrained \
+ model.backbone.init_cfg.checkpoint=$CHECKPOINT \
+ model.backbone.init_cfg.prefix="backbone." \
+ $PY_ARGS
+
+mim test mmaction2 ${CONFIG} \
+    --checkpoint https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/mvit-small_ft-8xb16-coslr-100e_k400/mvit-small_ft-8xb16-coslr-100e_k400_20230131-5e8303f5.pth \
+ --work-dir ${WORK_DIR} \
+ --launcher slurm --gpus 8 --gpus-per-node 8 \
+ --partition ${PARTITION} \
+ $PY_ARGS
+```
+
+Note:
+
+- CONFIG: the config files under the directory `configs/`
+- WORK_DIR: the working directory to save configs, logs, and checkpoints
+- PARTITION: the slurm partition you are using
+- CHECKPOINT: the pretrained checkpoint of MMSelfSup saved in working directory, like `$WORK_DIR/epoch_300.pth`
+- PY_ARGS: other optional args
+
+## Results
+
+
+
+The fine-tuning results are based on the Kinetics400 (K400) dataset.
+
+Due to the version of the K400 dataset, our pre-training, fine-tuning and final test results are based on the MMAction2 version, which differs slightly from the PySlowFast version.
+
+| Algorithm | Backbone | Epoch | Batch Size | Fine-tuning | Pretrain Links | Fine-tuning Links |
+| :-------: | :--------: | :---: | :--------: | :---------: | :--------------------: | :---------------------: |
+| MaskFeat | MViT-small | 300 | 512 | 81.8 | config \| model \| log | config \| model \| log |
+
+Remarks:
+
+- We converted the pre-trained model from PySlowFast and ran fine-tuning with MMAction2; based on the MMAction2 version of K400, we got `81.5` test accuracy. The model pre-trained with MMSelfSup got `81.8`, as provided above.
+- We also tested our model on [another version](https://github.com/facebookresearch/video-nonlocal-net/blob/main/DATASET.md) of K400 and got `82.1` test accuracy.
+- Some other details can be found in [MMAction2 MViT page](https://github.com/open-mmlab/mmaction2/tree/dev-1.x/configs/recognition/mvit).
+
+## Citation
+
+```bibtex
+@InProceedings{wei2022masked,
+ author = {Wei, Chen and Fan, Haoqi and Xie, Saining and Wu, Chao-Yuan and Yuille, Alan and Feichtenhofer, Christoph},
+ title = {Masked Feature Prediction for Self-Supervised Visual Pre-Training},
+ booktitle = {CVPR},
+ year = {2022},
+}
+```
+
+## Checklist
+
+Here is a checklist illustrating the usual development workflow of a successful project; it also serves as an overview of this project's progress.
+
+
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+
+ - [x] Finish the code
+
+
+
+ - [x] Basic docstrings & proper citation
+
+
+
+ - [x] Inference correctness
+
+
+
+ - [x] A full README
+
+
+
+- [x] Milestone 2: Indicates a successful model implementation.
+
+ - [x] Training-time correctness
+
+
+
+- [ ] Milestone 3: Good to be a part of our core package!
+
+ - [ ] Type hints and docstrings
+
+
+
+ - [ ] Unit tests
+
+
+
+ - [ ] Code polishing
+
+
+
+ - [ ] `metafile.yml` and `README.md`
+
+
+
+- [ ] Refactor and move your modules into the core package following the codebase's file hierarchy structure.
diff --git a/projects/maskfeat_video/configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py b/projects/maskfeat_video/configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py
new file mode 100644
index 000000000..aacdb94dd
--- /dev/null
+++ b/projects/maskfeat_video/configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py
@@ -0,0 +1,102 @@
+_base_ = 'mmselfsup::selfsup/_base_/default_runtime.py'
+
+custom_imports = dict(imports=['models'], allow_failed_imports=False)
+
+model = dict(
+ type='VideoMaskFeat',
+ data_preprocessor=dict(
+ type='VideoDataPreprocessor',
+ mean=[114.75, 114.75, 114.75],
+ std=[57.375, 57.375, 57.375],
+ format_shape='NCTHW'),
+ backbone=dict(
+ type='MaskFeatMViT',
+ arch='maskfeat-small',
+ drop_path_rate=0.0,
+ dim_mul_in_attention=False),
+ neck=dict(
+ type='LinearNeck',
+ in_channels=768,
+ out_channels=108,
+ with_avg_pool=False,
+ init_cfg=dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)),
+ head=dict(
+ type='MaskFeatPretrainHead',
+ loss=dict(type='PixelReconstructionLoss', criterion='L2')),
+ target_generator=dict(
+ type='HOGGenerator3d', nbins=9, pool=8, gaussian_window=16))
+
+# dataset settings
+dataset_type = 'mmaction.VideoDataset'
+data_root = 'data/kinetics400/videos_train'
+ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
+
+train_pipeline = [
+ dict(type='mmaction.DecordInit'),
+ dict(
+ type='mmaction.SampleFrames',
+ clip_len=16,
+ frame_interval=4,
+ num_clips=1),
+ dict(type='mmaction.DecordDecode'),
+ dict(type='mmaction.Resize', scale=(-1, 256)),
+ dict(type='mmaction.RandomResizedCrop', area_range=(0.5, 1.0)),
+ dict(type='mmaction.Resize', scale=(224, 224), keep_ratio=False),
+ dict(type='mmaction.Flip', flip_ratio=0.5),
+ dict(type='mmaction.FormatShape', input_format='NCTHW'),
+ dict(
+ type='MaskFeatMaskGenerator3D',
+ input_size=(8, 7, 7),
+ num_masking_patches=157,
+ min_num_patches=9,
+ max_num_patches=49),
+ dict(type='PackSelfSupInputs', key='imgs', algorithm_keys=['mask'])
+]
+
+train_dataloader = dict(
+ batch_size=32,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ collate_fn=dict(type='default_collate'),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_train,
+ data_prefix=dict(video=data_root),
+ pipeline=train_pipeline))
+
+optim_wrapper = dict(
+ type='AmpOptimWrapper',
+ loss_scale='dynamic',
+ optimizer=dict(
+ type='AdamW', lr=8e-4 * 2, betas=(0.9, 0.999), weight_decay=0.05),
+ clip_grad=dict(max_norm=0.02),
+ paramwise_cfg=dict(
+ bias_decay_mult=0.,
+ norm_decay_mult=0.,
+ custom_keys={
+ 'pos_embed': dict(decay_mult=0.),
+ 'cls_token': dict(decay_mult=0.)
+ }))
+
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-4,
+ by_epoch=True,
+ begin=0,
+ end=10,
+ convert_to_iter_based=True),
+ dict(
+ type='CosineAnnealingLR',
+ T_max=290,
+ eta_min=1e-6,
+ by_epoch=True,
+ begin=10,
+ end=300,
+ convert_to_iter_based=True)
+]
+
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
+default_hooks = dict(
+ checkpoint=dict(interval=1, max_keep_ckpts=2), logger=dict(interval=100))
diff --git a/projects/maskfeat_video/configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py b/projects/maskfeat_video/configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py
new file mode 100644
index 000000000..3f26e56e9
--- /dev/null
+++ b/projects/maskfeat_video/configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py
@@ -0,0 +1,5 @@
+_base_ = './maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py'
+
+optim_wrapper = dict(
+ optimizer=dict(
+ type='AdamW', lr=8e-4, betas=(0.9, 0.999), weight_decay=0.05))
diff --git a/projects/maskfeat_video/configs/mvit-small_ft-8xb16-coslr-100e_k400.py b/projects/maskfeat_video/configs/mvit-small_ft-8xb16-coslr-100e_k400.py
new file mode 100644
index 000000000..367e4baf5
--- /dev/null
+++ b/projects/maskfeat_video/configs/mvit-small_ft-8xb16-coslr-100e_k400.py
@@ -0,0 +1,157 @@
+_base_ = [
+ 'mmaction::_base_/models/mvit_small.py',
+ 'mmaction::_base_/default_runtime.py'
+]
+
+model = dict(
+ backbone=dict(
+ drop_path_rate=0.1,
+ dim_mul_in_attention=False,
+ pretrained=None,
+ pretrained_type='maskfeat',
+ ),
+ data_preprocessor=dict(
+ type='ActionDataPreprocessor',
+ mean=[114.75, 114.75, 114.75],
+ std=[57.375, 57.375, 57.375],
+ blending=dict(
+ type='RandomBatchAugment',
+ augments=[
+ dict(type='MixupBlending', alpha=0.8, num_classes=400),
+ dict(type='CutmixBlending', alpha=1, num_classes=400)
+ ]),
+ format_shape='NCTHW'),
+ cls_head=dict(dropout_ratio=0., init_scale=0.001))
+
+# dataset settings
+dataset_type = 'VideoDataset'
+data_root = 'data/kinetics400/videos_train'
+data_root_val = 'data/kinetics400/videos_val'
+ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
+ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
+ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'
+
+train_pipeline = [
+ dict(type='DecordInit'),
+ dict(type='SampleFrames', clip_len=16, frame_interval=4, num_clips=1),
+ dict(type='DecordDecode'),
+ dict(type='Resize', scale=(-1, 256)),
+ dict(type='PytorchVideoWrapper', op='RandAugment', magnitude=7),
+ dict(type='RandomResizedCrop'),
+ dict(type='Resize', scale=(224, 224), keep_ratio=False),
+ dict(type='Flip', flip_ratio=0.5),
+ dict(type='RandomErasing', erase_prob=0.25, mode='rand'),
+ dict(type='FormatShape', input_format='NCTHW'),
+ dict(type='PackActionInputs')
+]
+val_pipeline = [
+ dict(type='DecordInit'),
+ dict(
+ type='SampleFrames',
+ clip_len=16,
+ frame_interval=4,
+ num_clips=1,
+ test_mode=True),
+ dict(type='DecordDecode'),
+ dict(type='Resize', scale=(-1, 256)),
+ dict(type='CenterCrop', crop_size=224),
+ dict(type='FormatShape', input_format='NCTHW'),
+ dict(type='PackActionInputs')
+]
+test_pipeline = [
+ dict(type='DecordInit'),
+ dict(
+ type='SampleFrames',
+ clip_len=16,
+ frame_interval=4,
+ num_clips=10,
+ test_mode=True),
+ dict(type='DecordDecode'),
+ dict(type='Resize', scale=(-1, 224)),
+ dict(type='CenterCrop', crop_size=224),
+ dict(type='FormatShape', input_format='NCTHW'),
+ dict(type='PackActionInputs')
+]
+
+repeat_sample = 2
+train_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ collate_fn=dict(type='repeat_pseudo_collate'),
+ dataset=dict(
+ type='RepeatAugDataset',
+ num_repeats=repeat_sample,
+ ann_file=ann_file_train,
+ data_prefix=dict(video=data_root),
+ pipeline=train_pipeline))
+val_dataloader = dict(
+ batch_size=16,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_val,
+ data_prefix=dict(video=data_root_val),
+ pipeline=val_pipeline,
+ test_mode=True))
+test_dataloader = dict(
+ batch_size=1,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ ann_file=ann_file_test,
+ data_prefix=dict(video=data_root_val),
+ pipeline=test_pipeline,
+ test_mode=True))
+
+val_evaluator = dict(type='AccMetric')
+test_evaluator = val_evaluator
+
+train_cfg = dict(
+ type='EpochBasedTrainLoop', max_epochs=100, val_begin=1, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+base_lr = 9.6e-3
+optim_wrapper = dict(
+ optimizer=dict(
+ type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05),
+ constructor='LearningRateDecayOptimizerConstructor',
+ paramwise_cfg={
+ 'decay_rate': 0.75,
+ 'decay_type': 'layer_wise',
+ 'num_layers': 16
+ },
+ clip_grad=dict(max_norm=5, norm_type=2))
+
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1 / 600,
+ by_epoch=True,
+ begin=0,
+ end=20,
+ convert_to_iter_based=True),
+ dict(
+ type='CosineAnnealingLR',
+ T_max=80,
+ eta_min_ratio=1 / 600,
+ by_epoch=True,
+ begin=20,
+ end=100,
+ convert_to_iter_based=True)
+]
+
+default_hooks = dict(
+ checkpoint=dict(interval=3, max_keep_ckpts=20), logger=dict(interval=100))
+
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+# or not by default.
+# - `base_batch_size` = (8 GPUs) x (64 samples per GPU) / repeat_sample.
+auto_scale_lr = dict(enable=True, base_batch_size=512 // repeat_sample)
diff --git a/projects/maskfeat_video/models/__init__.py b/projects/maskfeat_video/models/__init__.py
new file mode 100644
index 000000000..96e5f913a
--- /dev/null
+++ b/projects/maskfeat_video/models/__init__.py
@@ -0,0 +1,9 @@
+from .hog_generator_3d import HOGGenerator3d
+from .maskfeat import VideoMaskFeat
+from .maskfeat_mvit import MaskFeatMViT
+from .transforms import MaskFeatMaskGenerator3D
+
+__all__ = [
+ 'HOGGenerator3d', 'VideoMaskFeat', 'MaskFeatMViT',
+ 'MaskFeatMaskGenerator3D'
+]
diff --git a/projects/maskfeat_video/models/hog_generator_3d.py b/projects/maskfeat_video/models/hog_generator_3d.py
new file mode 100644
index 000000000..df6073b9b
--- /dev/null
+++ b/projects/maskfeat_video/models/hog_generator_3d.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmselfsup.models.target_generators import HOGGenerator
+from mmselfsup.registry import MODELS
+
+
+@MODELS.register_module()
+class HOGGenerator3d(HOGGenerator):
+ """Generate HOG feature for videos.
+
+ This module is used in MaskFeat to generate HOG feature.
+    Here is the link of `HOG wikipedia
+    <https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.
+
+ Args:
+        nbins (int): Number of bins. Defaults to 9.
+        pool (int): Number of cells. Defaults to 8.
+        gaussian_window (int): Size of Gaussian kernel. Defaults to 16.
+ """
+
+ def __init__(self,
+ nbins: int = 9,
+ pool: int = 8,
+ gaussian_window: int = 16) -> None:
+ super().__init__(
+ nbins=nbins, pool=pool, gaussian_window=gaussian_window)
+
+ def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
+ """Reshape HOG Features for output."""
+ hog_feat = hog_feat.flatten(1, 2)
+ self.unfold_size = hog_feat.shape[-1] // 14
+ hog_feat = hog_feat.permute(0, 2, 3, 1)
+ hog_feat = hog_feat.unfold(1, self.unfold_size,
+ self.unfold_size).unfold(
+ 2, self.unfold_size, self.unfold_size)
+ hog_feat = hog_feat.flatten(3).view(self.B, self.T, 14, 14, -1)
+ hog_feat = hog_feat.flatten(1, 3) # B N C
+ return hog_feat
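To make `_reshape` easier to follow, here is a pure-tensor sketch of the shape flow, assuming the sizes implied by the config in this project (224x224 frames, `pool=8`, `nbins=9`, so the parent `HOGGenerator` yields a `(B*T, 3, 9, 28, 28)` feature):

```python
import torch

B, T = 2, 8                                            # clips, temporally-strided frames
hog = torch.randn(B * T, 3, 9, 28, 28).flatten(1, 2)   # (B*T, 27, 28, 28)

unfold_size = hog.shape[-1] // 14                      # 2
hog = hog.permute(0, 2, 3, 1)                          # (B*T, 28, 28, 27)
hog = hog.unfold(1, unfold_size, unfold_size)          # (B*T, 14, 28, 27, 2)
hog = hog.unfold(2, unfold_size, unfold_size)          # (B*T, 14, 14, 27, 2, 2)
hog = hog.flatten(3).view(B, T, 14, 14, -1)            # (B, T, 14, 14, 108)
hog = hog.flatten(1, 3)                                # (B, T*14*14, 108)
assert hog.shape == (2, 8 * 196, 108)                  # 108 matches the neck's out_channels
```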
diff --git a/projects/maskfeat_video/models/maskfeat.py b/projects/maskfeat_video/models/maskfeat.py
new file mode 100644
index 000000000..3e9fb3792
--- /dev/null
+++ b/projects/maskfeat_video/models/maskfeat.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List
+
+import torch
+import torch.nn.functional as F
+
+from mmselfsup.models import BaseModel
+from mmselfsup.registry import MODELS
+from mmselfsup.structures import SelfSupDataSample
+
+
+@MODELS.register_module()
+class VideoMaskFeat(BaseModel):
+ """MaskFeat.
+
+ Implementation of `Masked Feature Prediction for Self-Supervised Visual
+    Pre-Training <https://arxiv.org/abs/2112.09133>`_.
+ """
+
+ def loss(self, inputs: List[torch.Tensor],
+ data_samples: List[SelfSupDataSample],
+ **kwargs) -> Dict[str, torch.Tensor]:
+ """The forward function in training.
+
+ Args:
+ inputs (List[torch.Tensor]): The input images.
+ data_samples (List[SelfSupDataSample]): All elements required
+ during the forward function.
+
+ Returns:
+ Dict[str, torch.Tensor]: A dictionary of loss components.
+ """
+ mask = torch.stack(
+ [data_sample.mask.value for data_sample in data_samples])
+ mask = mask.to(torch.bool)
+
+ video = inputs[0]
+ video = video.view((-1, ) + video.shape[2:]) # B, C, T, H, W
+ latent = self.backbone(video, mask)
+ B, L, C = latent[0].shape
+ pred = self.neck([latent[0].view(B * L, C)])
+ pred = pred[0].view(B, L, -1)
+
+ # generate hog target
+ video = video[:, :, ::self.backbone.patch_stride[0], :, :]
+ video = video.transpose(1, 2) # B, T, C, H, W
+ self.target_generator.B = video.size(0)
+ self.target_generator.T = video.size(1)
+ video = video.flatten(0, 1) # B*T, C, H, W
+ hog = self.target_generator(video)
+
+ mask = self._get_output_mask(mask)
+ loss = self.head(pred, hog, mask)
+ losses = dict(loss=loss)
+ return losses
+
+ def _get_output_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ size = self.backbone.out_patch_resolution[-1][-1]
+ output_mask = F.interpolate(mask.float(), size=size)
+ return output_mask
diff --git a/projects/maskfeat_video/models/maskfeat_mvit.py b/projects/maskfeat_video/models/maskfeat_mvit.py
new file mode 100644
index 000000000..b255665ad
--- /dev/null
+++ b/projects/maskfeat_video/models/maskfeat_mvit.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmaction.models import MViT
+from mmaction.models.backbones.mvit import resize_pos_embed
+
+from mmselfsup.registry import MODELS
+
+
+@MODELS.register_module()
+class MaskFeatMViT(MViT):
+
+ arch_zoo = {
+ 'maskfeat-small': {
+ 'embed_dims': 96,
+ 'num_layers': 16,
+ 'num_heads': 1,
+ 'downscale_indices': [1, 3],
+ 'dim_mul_indices': [1, 3, 14]
+ },
+ 'maskfeat-large': {
+ 'embed_dims': 144,
+ 'num_layers': 48,
+ 'num_heads': 2,
+ 'downscale_indices': [2, 8],
+ 'dim_mul_indices': [2, 8, 44]
+ },
+ }
+
+ def __init__(
+ self,
+ arch: str = 'base',
+ spatial_size: int = 224,
+ temporal_size: int = 16,
+ in_channels: int = 3,
+ out_scales: Union[int, Sequence[int]] = -1,
+ drop_path_rate: float = 0,
+ use_abs_pos_embed: bool = False,
+ interpolate_mode: str = 'trilinear',
+ pool_kernel: tuple = (3, 3, 3),
+ dim_mul: int = 2,
+ head_mul: int = 2,
+ adaptive_kv_stride: tuple = (1, 8, 8),
+ rel_pos_embed: bool = True,
+ residual_pooling: bool = True,
+ dim_mul_in_attention: bool = True,
+ with_cls_token: bool = True,
+ output_cls_token: bool = True,
+ rel_pos_zero_init: bool = False,
+ mlp_ratio: float = 4,
+ qkv_bias: bool = True,
+ norm_cfg: dict = dict(type='LN', eps=1e-6),
+ patch_cfg: dict = dict(
+ kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3)),
+ init_cfg: Optional[Union[dict, List[dict]]] = [
+ dict(type='TruncNormal', layer=['Conv2d', 'Conv3d'], std=0.02),
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.02),
+ ]
+ ) -> None:
+ super().__init__(
+ arch=arch,
+ spatial_size=spatial_size,
+ temporal_size=temporal_size,
+ in_channels=in_channels,
+ out_scales=out_scales,
+ drop_path_rate=drop_path_rate,
+ use_abs_pos_embed=use_abs_pos_embed,
+ interpolate_mode=interpolate_mode,
+ pool_kernel=pool_kernel,
+ dim_mul=dim_mul,
+ head_mul=head_mul,
+ adaptive_kv_stride=adaptive_kv_stride,
+ rel_pos_embed=rel_pos_embed,
+ residual_pooling=residual_pooling,
+ dim_mul_in_attention=dim_mul_in_attention,
+ with_cls_token=with_cls_token,
+ output_cls_token=output_cls_token,
+ rel_pos_zero_init=rel_pos_zero_init,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ norm_cfg=norm_cfg,
+ patch_cfg=patch_cfg,
+ init_cfg=init_cfg)
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+ self.patch_stride = patch_cfg['stride']
+
+ def init_weights(self) -> None:
+ """Initialize mask token and cls token."""
+ super().init_weights()
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ # Suppress default init if use pretrained model.
+ return
+
+ nn.init.trunc_normal_(self.cls_token, std=.02)
+ nn.init.trunc_normal_(self.mask_token, std=.02)
+
+ def forward(self, x: torch.Tensor,
+ mask: torch.Tensor) -> Tuple[torch.Tensor]:
+
+ x, patch_resolution = self.patch_embed(x)
+ B, L, C = x.shape
+ T, H, W = patch_resolution
+
+ mask_tokens = self.mask_token.expand(B, L, -1)
+ mask = F.interpolate(mask.float(), size=(H, W))
+ mask = mask.flatten(1).unsqueeze(-1)
+ x = x * (1 - mask) + mask_tokens * mask
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if self.use_abs_pos_embed:
+ x = x + resize_pos_embed(
+ self.pos_embed,
+ self.patch_resolution,
+ patch_resolution,
+ mode=self.interpolate_mode,
+ num_extra_tokens=self.num_extra_tokens)
+
+ # if not self.with_cls_token:
+ # # Remove class token for transformer encoder input
+ # x = x[:, 1:]
+
+ outs = []
+ self.out_patch_resolution = []
+ for i, block in enumerate(self.blocks):
+ x, patch_resolution = block(x, patch_resolution)
+
+ if i in self.stage_indices:
+ stage_index = self.stage_indices[i]
+ if stage_index in self.out_scales:
+ self.out_patch_resolution.append(patch_resolution)
+ x = getattr(self, f'norm{stage_index}')(x)
+ if not self.output_cls_token:
+ out = x[:, 1:]
+ else:
+ out = x
+ outs.append(out)
+
+ return tuple(outs)
diff --git a/projects/maskfeat_video/models/transforms.py b/projects/maskfeat_video/models/transforms.py
new file mode 100644
index 000000000..d269b96cf
--- /dev/null
+++ b/projects/maskfeat_video/models/transforms.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import random
+from typing import Optional, Tuple
+
+import numpy as np
+from mmcv.transforms.base import BaseTransform
+
+from mmselfsup.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class MaskFeatMaskGenerator3D(BaseTransform):
+ """Generate mask for video.
+
+ Added Keys:
+
+ - mask
+
+ This module is borrowed from
+ https://github.com/facebookresearch/SlowFast/blob/main/slowfast/datasets/transform.py
+
+ Args:
+        input_size (tuple): The size (T, H, W) of the mask grid.
+ num_masking_patches (int): The number of patches to be masked.
+ min_num_patches (int): The minimum number of patches to be masked
+ in the process of generating mask. Defaults to 4.
+ max_num_patches (int, optional): The maximum number of patches to be
+ masked in the process of generating mask. Defaults to None.
+ min_aspect (float): The minimum aspect ratio of mask blocks. Defaults
+ to 0.3.
+        max_aspect (float, optional): The maximum aspect ratio of mask blocks.
+            Defaults to None.
+ """
+
+ def __init__(self,
+                 input_size: tuple,
+ num_masking_patches: int,
+ min_num_patches: int = 4,
+ max_num_patches: Optional[int] = None,
+ min_aspect: float = 0.3,
+ max_aspect: Optional[float] = None) -> None:
+
+ self.temporal, self.height, self.width = input_size
+ self.num_masking_patches = num_masking_patches
+ self.min_num_patches = min_num_patches
+ self.max_num_patches = (
+ num_masking_patches
+ if max_num_patches is None else max_num_patches)
+ max_aspect = max_aspect or 1 / min_aspect
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+
+ def get_shape(self) -> Tuple[int, int, int]:
+ """Get the shape of mask.
+
+ Returns:
+ Tuple[int, int, int]: The shape of mask.
+ """
+ return self.temporal, self.height, self.width
+
+ def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int:
+        """Try to place one random masked block (at most 100 attempts).
+
+ Args:
+ mask (np.ndarray): The mask to be generated.
+ max_mask_patches (int): The maximum number of patches to be masked.
+
+ Returns:
+ int: The number of patches masked.
+ """
+ delta = 0
+ for _ in range(100):
+ target_area = random.uniform(self.min_num_patches,
+ self.max_num_patches)
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+                t = random.randint(1, self.temporal)  # temporal extent of block
+ if w < self.width and h < self.height:
+ top = random.randint(0, self.height - h)
+ left = random.randint(0, self.width - w)
+ front = random.randint(0, self.temporal - t)
+
+ num_masked = mask[front:front + t, top:top + h,
+ left:left + w].sum()
+ # Overlap
+ if 0 < h * w * t - num_masked <= max_mask_patches:
+ for i in range(front, front + t):
+ for j in range(top, top + h):
+ for k in range(left, left + w):
+ if mask[i, j, k] == 0:
+ mask[i, j, k] = 1
+ delta += 1
+
+ if delta > 0:
+ break
+ return delta
+
+ def transform(self, results: dict) -> dict:
+ """Method to generate random block mask.
+
+ Args:
+ results (dict): Result dict from previous pipeline.
+
+ Returns:
+ dict: Result dict with added key ``mask``.
+ """
+        mask = np.zeros(shape=self.get_shape(), dtype=int)
+ mask_count = 0
+ while mask_count < self.num_masking_patches:
+ max_mask_patches = self.num_masking_patches - mask_count
+ delta = self._mask(mask, max_mask_patches)
+ if delta == 0:
+ break
+ else:
+ mask_count += delta
+
+ results.update({'mask': mask})
+ return results
+
+ def __repr__(self) -> str:
+ repr_str = self.__class__.__name__
+ repr_str += f'(temporal={self.temporal}, '
+ repr_str += f'height={self.height}, '
+ repr_str += f'width={self.width}, '
+ repr_str += f'num_masking_patches={self.num_masking_patches}, '
+ repr_str += f'min_num_patches={self.min_num_patches}, '
+ repr_str += f'max_num_patches={self.max_num_patches}, '
+ repr_str += f'log_aspect_ratio={self.log_aspect_ratio})'
+ return repr_str
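A quick usage sketch of the mask generator, assuming the project's `models` package is on `PYTHONPATH` (sizes taken from the pre-training config in this project):

```python
from models.transforms import MaskFeatMaskGenerator3D

generator = MaskFeatMaskGenerator3D(
    input_size=(8, 7, 7),        # (T, H, W) patch grid after the MViT stem
    num_masking_patches=157,     # ~40% of the 8 * 7 * 7 = 392 patches
    min_num_patches=9,
    max_num_patches=49)

results = generator({})          # BaseTransform.__call__ delegates to transform()
mask = results['mask']
assert mask.shape == (8, 7, 7)
# block placement stops once the budget is reached; it may fall slightly short
assert mask.sum() <= 157
```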
diff --git a/projects/maskfeat_video/tools/dist_train.sh b/projects/maskfeat_video/tools/dist_train.sh
new file mode 100644
index 000000000..3fca7641d
--- /dev/null
+++ b/projects/maskfeat_video/tools/dist_train.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+
+CONFIG=$1
+GPUS=$2
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+python -m torch.distributed.launch \
+ --nnodes=$NNODES \
+ --node_rank=$NODE_RANK \
+ --master_addr=$MASTER_ADDR \
+ --nproc_per_node=$GPUS \
+ --master_port=$PORT \
+ $(dirname "$0")/train.py \
+ $CONFIG \
+ --launcher pytorch ${@:3}
diff --git a/projects/maskfeat_video/tools/slurm_train.sh b/projects/maskfeat_video/tools/slurm_train.sh
new file mode 100644
index 000000000..ac36d5082
--- /dev/null
+++ b/projects/maskfeat_video/tools/slurm_train.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+
+set -x
+
+PARTITION=$1
+JOB_NAME=$2
+CONFIG=$3
+GPUS=${GPUS:-8}
+GPUS_PER_NODE=${GPUS_PER_NODE:-8}
+CPUS_PER_TASK=${CPUS_PER_TASK:-5}
+SRUN_ARGS=${SRUN_ARGS:-""}
+PY_ARGS=${@:4}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ ${SRUN_ARGS} \
+ python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
diff --git a/projects/maskfeat_video/tools/train.py b/projects/maskfeat_video/tools/train.py
new file mode 100644
index 000000000..ef0d3127c
--- /dev/null
+++ b/projects/maskfeat_video/tools/train.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+import os.path as osp
+
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
+
+from mmselfsup.utils import register_all_modules
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train a model')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument('--work-dir', help='the dir to save logs and models')
+ parser.add_argument(
+ '--resume',
+ nargs='?',
+ type=str,
+ const='auto',
+        help='If a checkpoint path is specified, resume from it; otherwise, '
+        'try to auto-resume from the latest checkpoint '
+        'in the work directory.')
+ parser.add_argument(
+ '--amp',
+ action='store_true',
+ help='enable automatic-mixed-precision training')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. If the value to '
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+ 'Note that the quotation marks are necessary and that no white space '
+ 'is allowed.')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ # register all modules in mmselfsup into the registries
+ # do not init the default scope here because it will be init in the runner
+ register_all_modules(init_default_scope=False)
+
+ # load config
+ cfg = Config.fromfile(args.config)
+ cfg.launcher = args.launcher
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+
+ # work_dir is determined in this priority: CLI > segment in file > filename
+ if args.work_dir is not None:
+ # update configs according to CLI args if args.work_dir is not None
+ cfg.work_dir = args.work_dir
+ elif cfg.get('work_dir', None) is None:
+ # use config filename as default work_dir if cfg.work_dir is None
+ work_type = args.config.split('/')[1]
+ cfg.work_dir = osp.join('./work_dirs', work_type,
+ osp.splitext(osp.basename(args.config))[0])
+
+ # enable automatic-mixed-precision training
+ if args.amp is True:
+ optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
+ assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
+            '`--amp` is not supported for custom optimizer wrapper type ' \
+            f'`{optim_wrapper}`.'
+ cfg.optim_wrapper.type = 'AmpOptimWrapper'
+ cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')
+
+ # resume training
+ if args.resume == 'auto':
+ cfg.resume = True
+ cfg.load_from = None
+ elif args.resume is not None:
+ cfg.resume = True
+ cfg.load_from = args.resume
+
+ # build the runner from config
+ runner = Runner.from_cfg(cfg)
+
+ # start training
+ runner.train()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/test_datasets/test_transforms/test_formmatting.py b/tests/test_datasets/test_transforms/test_formmatting.py
index 751edc4d6..bc3fca10d 100644
--- a/tests/test_datasets/test_transforms/test_formmatting.py
+++ b/tests/test_datasets/test_transforms/test_formmatting.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
+import pytest
import torch
from mmselfsup.datasets.transforms import PackSelfSupInputs
@@ -32,6 +33,19 @@ def test_pack_selfsup_inputs():
assert list(results['inputs'][0].shape) == [1, 8, 8]
assert results['data_samples'].gt_label.value == torch.tensor([1])
+ # video data
+ transform = PackSelfSupInputs(key='img', algorithm_keys=['gt_label'])
+ results = {'img': np.ones((4, 3, 8, 8, 8)), 'gt_label': 1}
+ results = transform(results)
+ assert list(results['inputs'][0].shape) == [4, 3, 8, 8, 8]
+ assert results['data_samples'].gt_label.value == torch.tensor([1])
+
+ # dimension check
+ with pytest.raises(ValueError):
+ transform = PackSelfSupInputs(key='img', algorithm_keys=['gt_label'])
+ results = {'img': np.ones((3, 8, 8, 8)), 'gt_label': 1}
+ results = transform(results)
+
# img is a list
transform = PackSelfSupInputs(key='img', algorithm_keys=['gt_label'])
results = {'img': [np.ones((8, 8))], 'gt_label': 1}
diff --git a/tests/test_models/test_utils/test_data_preprocessor.py b/tests/test_models/test_utils/test_data_preprocessor.py
index bf34a6aa7..94fdbc79f 100644
--- a/tests/test_models/test_utils/test_data_preprocessor.py
+++ b/tests/test_models/test_utils/test_data_preprocessor.py
@@ -3,7 +3,8 @@
import torch
from mmselfsup.models.utils import (SelfSupDataPreprocessor,
- TwoNormDataPreprocessor)
+ TwoNormDataPreprocessor,
+ VideoDataPreprocessor)
from mmselfsup.structures import SelfSupDataSample
@@ -66,3 +67,46 @@ def test_two_norm_data_preprocessor():
fake_batches, fake_samples = data_preprocessor(fake_data)
assert len(fake_batches) == 2
assert len(fake_samples) == 4
+
+
+def test_video_data_preprocessor():
+ data_preprocessor = VideoDataPreprocessor(
+ mean=[114.75, 114.75, 114.75],
+ std=[57.375, 57.375, 57.375],
+ format_shape='NCTHW')
+ fake_data = {
+ 'inputs': [torch.randn((2, 3, 4, 224, 224))],
+ 'data_sample': [SelfSupDataSample(),
+ SelfSupDataSample()]
+ }
+ fake_batches, fake_samples = data_preprocessor(fake_data)
+ assert len(fake_batches) == 1
+ assert len(fake_samples) == 2
+
+ data_preprocessor = VideoDataPreprocessor(
+ mean=[114.75, 114.75, 114.75],
+ std=[57.375, 57.375, 57.375],
+ bgr_to_rgb=True,
+ format_shape='NCTHW')
+ fake_data = {
+ 'inputs': [torch.randn((2, 3, 4, 224, 224))],
+ 'data_sample': [SelfSupDataSample(),
+ SelfSupDataSample()]
+ }
+ fake_batches, fake_samples = data_preprocessor(fake_data)
+ assert len(fake_batches) == 1
+ assert len(fake_samples) == 2
+
+ data_preprocessor = VideoDataPreprocessor(
+ mean=[114.75, 114.75, 114.75],
+ std=[57.375, 57.375, 57.375],
+ bgr_to_rgb=True,
+ format_shape='NCHW')
+ fake_data = {
+ 'inputs': [torch.randn((2, 3, 224, 224))],
+ 'data_sample': [SelfSupDataSample(),
+ SelfSupDataSample()]
+ }
+ fake_batches, fake_samples = data_preprocessor(fake_data)
+ assert len(fake_batches) == 1
+ assert len(fake_samples) == 2