diff --git a/mmselfsup/datasets/transforms/formatting.py b/mmselfsup/datasets/transforms/formatting.py index 3ceedf3a7..6e24d799f 100644 --- a/mmselfsup/datasets/transforms/formatting.py +++ b/mmselfsup/datasets/transforms/formatting.py @@ -56,9 +56,8 @@ def transform(self, Returns: Dict: - - - 'inputs' (List[torch.Tensor]): The forward data of models. - - 'data_samples' (SelfSupDataSample): The annotation info of the + - ``inputs`` (List[torch.Tensor]): The forward data of models. + - ``data_samples`` (SelfSupDataSample): The annotation info of the forward data. """ packed_results = dict() @@ -68,9 +67,19 @@ def transform(self, if not isinstance(img, List): img = [img] for i, img_ in enumerate(img): - if len(img_.shape) < 3: - img_ = np.expand_dims(img_, -1) - img_ = np.ascontiguousarray(img_.transpose(2, 0, 1)) + # to handle the single channel image + img_ = np.expand_dims(img_, -1) \ + if len(img_.shape) == 2 else img_ + + if len(img_.shape) == 3: + img_ = np.ascontiguousarray(img_.transpose(2, 0, 1)) + elif len(img_.shape) == 5: + # for video data with the shape (B, C, T, H, W) + img_ = img_ + else: + raise ValueError( + 'img should be 2, 3 or 5 dimensional, ' + f'instead of {len(img_.shape)} dimensional.') img[i] = to_tensor(img_) packed_results['inputs'] = img diff --git a/mmselfsup/models/target_generators/hog_generator.py b/mmselfsup/models/target_generators/hog_generator.py index 53d8515f7..a6ccabaf3 100644 --- a/mmselfsup/models/target_generators/hog_generator.py +++ b/mmselfsup/models/target_generators/hog_generator.py @@ -35,8 +35,8 @@ def __init__(self, self.pool = pool self.pi = math.pi weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) - weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1) - weight_y = weight_x.transpose(2, 3) + weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous() + weight_y = weight_x.transpose(2, 3).contiguous() self.register_buffer('weight_x', weight_x) self.register_buffer('weight_y', weight_y) diff --git a/mmselfsup/models/utils/__init__.py b/mmselfsup/models/utils/__init__.py index 1d6b7f4ba..e4bc6463c 100644 --- a/mmselfsup/models/utils/__init__.py +++ b/mmselfsup/models/utils/__init__.py @@ -4,7 +4,7 @@ RelativeLocDataPreprocessor, RotationPredDataPreprocessor, SelfSupDataPreprocessor, - TwoNormDataPreprocessor) + TwoNormDataPreprocessor, VideoDataPreprocessor) from .ema import CosineEMA from .extractor import Extractor from .gather_layer import GatherLayer @@ -29,6 +29,6 @@ 'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor', 'RotationPredDataPreprocessor', 'CAEDataPreprocessor', 'ResLayerExtraNorm', - 'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor', - 'PromptTransformerEncoderLayer', 'build_clip_model' + 'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor', 'build_clip_model', + 'PromptTransformerEncoderLayer', 'VideoDataPreprocessor' ] diff --git a/mmselfsup/models/utils/data_preprocessor.py b/mmselfsup/models/utils/data_preprocessor.py index 4901cd3df..19366be6b 100644 --- a/mmselfsup/models/utils/data_preprocessor.py +++ b/mmselfsup/models/utils/data_preprocessor.py @@ -2,7 +2,7 @@ from typing import List, Optional, Sequence, Tuple, Union import torch -from mmengine.model import ImgDataPreprocessor +from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor from mmselfsup.registry import MODELS @@ -290,3 +290,112 @@ def forward( ] return batch_inputs, batch_data_samples + + +@MODELS.register_module() +class 
VideoDataPreprocessor(BaseDataPreprocessor):
+    """Video pre-processor for operations such as normalization and
+    BGR-to-RGB conversion.
+
+    Compared with the :class:`mmaction.ActionDataPreprocessor`, this module
+    treats each item in ``inputs`` of the input data as a list, instead of
+    a single torch.Tensor.
+
+    Args:
+        mean (Sequence[float or int], optional): The pixel mean of channels
+            of images or stacked optical flow. Defaults to None.
+        std (Sequence[float or int], optional): The pixel standard deviation
+            of channels of images or stacked optical flow. Defaults to None.
+        pad_size_divisor (int): The size of padded image should be
+            divisible by ``pad_size_divisor``. Defaults to 1.
+        pad_value (float or int): The padded pixel value. Defaults to 0.
+        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
+            Defaults to False.
+        format_shape (str): Format shape of input data.
+            Defaults to ``'NCHW'``.
+    """
+
+    def __init__(self,
+                 mean: Optional[Sequence[Union[float, int]]] = None,
+                 std: Optional[Sequence[Union[float, int]]] = None,
+                 pad_size_divisor: int = 1,
+                 pad_value: Union[float, int] = 0,
+                 bgr_to_rgb: bool = False,
+                 format_shape: str = 'NCHW') -> None:
+        super().__init__()
+        self.pad_size_divisor = pad_size_divisor
+        self.pad_value = pad_value
+        self.bgr_to_rgb = bgr_to_rgb
+        self.format_shape = format_shape
+
+        if mean is not None:
+            assert std is not None, 'To enable the normalization in ' \
+                                    'preprocessing, please specify both ' \
+                                    '`mean` and `std`.'
+            # Enable the normalization in preprocessing.
+            self._enable_normalize = True
+            if self.format_shape == 'NCHW':
+                normalizer_shape = (-1, 1, 1)
+            elif self.format_shape == 'NCTHW':
+                normalizer_shape = (-1, 1, 1, 1)
+            else:
+                raise ValueError(f'Invalid format shape: {format_shape}')
+
+            self.register_buffer(
+                'mean',
+                torch.tensor(mean, dtype=torch.float32).view(normalizer_shape),
+                False)
+            self.register_buffer(
+                'std',
+                torch.tensor(std, dtype=torch.float32).view(normalizer_shape),
+                False)
+        else:
+            self._enable_normalize = False
+
+    def forward(
+        self,
+        data: dict,
+        training: bool = False
+    ) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Performs normalization, padding and BGR-to-RGB conversion based on
+        ``BaseDataPreprocessor``.
+
+        Args:
+            data (dict): Data sampled from the dataloader.
+            training (bool): Whether to enable training time augmentation. If
+                subclasses override this method, they can perform different
+                preprocessing strategies for training and testing based on the
+                value of ``training``.
+
+        Returns:
+            Tuple[List[torch.Tensor], Optional[list]]: Data in the same format
+                as the model input.
+ """ + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + # ------ To RGB ------ + if self.bgr_to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = [ + batch_input[..., [2, 1, 0], :, :] + for batch_input in batch_inputs + ] + elif self.format_shape == 'NCTHW': + batch_inputs = [ + batch_input[..., [2, 1, 0], :, :, :] + for batch_input in batch_inputs + ] + else: + raise ValueError(f'Invalid format shape: {self.format_shape}') + + # -- Normalization --- + if self._enable_normalize: + batch_inputs = [(batch_input - self.mean) / self.std + for batch_input in batch_inputs] + else: + batch_inputs = [ + batch_input.to(torch.float32) for batch_input in batch_inputs + ] + + return batch_inputs, batch_data_samples diff --git a/projects/maskfeat_video/README.md b/projects/maskfeat_video/README.md new file mode 100644 index 000000000..ab2850df2 --- /dev/null +++ b/projects/maskfeat_video/README.md @@ -0,0 +1,275 @@ +# MaskFeat Pre-training with Video + +- [MaskFeat Pre-training with Video](#maskfeat-pre-training-with-video) + - [Description](#description) + - [Usage](#usage) + - [Setup Environment](#setup-environment) + - [Data Preparation](#data-preparation) + - [Pre-training Commands](#pre-training-commands) + - [On Local Single GPU](#on-local-single-gpu) + - [On Multiple GPUs](#on-multiple-gpus) + - [On Multiple GPUs with Slurm](#on-multiple-gpus-with-slurm) + - [Downstream Tasks Commands](#downstream-tasks-commands) + - [On Multiple GPUs](#on-multiple-gpus-1) + - [On Multiple GPUs with Slurm](#on-multiple-gpus-with-slurm-1) + - [Results](#results) + - [Citation](#citation) + - [Checklist](#checklist) + +## Description + + + +Author: @fangyixiao18 + +This is the implementation of **MaskFeat** with video dataset, like Kinetics400. + +## Usage + + + +### Setup Environment + +Requirements: + +- MMSelfSup >= 1.0.0rc6 +- MMAction2 >= 1.0.0rc3 + +Please refer to [Get Started](https://mmselfsup.readthedocs.io/en/1.x/get_started.html) documentation of MMSelfSup to finish installation. + +Besides, to process the video data, we apply transforms in MMAction2. The instruction to install MMAction2 can be found in [Get Started documentation](https://mmaction2.readthedocs.io/en/1.x/get_started.html). + +### Data Preparation + +You can refer to the [documentation](https://mmaction2.readthedocs.io/en/1.x/user_guides/2_data_prepare.html) in MMAction2. + +### Pre-training Commands + +At first, you need to add the current folder to `PYTHONPATH`, so that Python can find your model files. In `projects/maskfeat_video/` root directory, please run command below to add it. 
+ +```shell +export PYTHONPATH=`pwd`:$PYTHONPATH +``` + +Then run the following commands to train the model: + +#### On Local Single GPU + +```bash +# train with mim +mim train mmselfsup ${CONFIG} --work-dir ${WORK_DIR} + +# a specific command example +mim train mmselfsup configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py \ + --work-dir work_dirs/selfsup/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400/ + +# train with scripts +python tools/train.py configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py \ + --work-dir work_dirs/selfsup/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400/ +``` + +#### On Multiple GPUs + +```bash +# train with mim +# a specific command examples, 8 GPUs here +mim train mmselfsup configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py \ + --work-dir work_dirs/selfsup/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400/ \ + --launcher pytorch --gpus 8 + +# train with scripts +bash tools/dist_train.sh configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py 8 +``` + +Note: + +- CONFIG: the config files under the directory `configs/` +- WORK_DIR: the working directory to save configs, logs, and checkpoints + +#### On Multiple GPUs with Slurm + +```bash +# train with mim +mim train mmselfsup configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py \ + --work-dir work_dirs/selfsup/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/ \ + --launcher slurm --gpus 16 --gpus-per-node 8 \ + --partition ${PARTITION} + +# train with scripts +GPUS_PER_NODE=8 GPUS=16 bash tools/slurm_train.sh ${PARTITION} maskfeat-video \ + configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py \ + --work-dir work_dirs/selfsup/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/ +``` + +Note: + +- CONFIG: the config files under the directory `configs/` +- WORK_DIR: the working directory to save configs, logs, and checkpoints +- PARTITION: the slurm partition you are using + +### Downstream Tasks Commands + +To evaluate the **MaskFeat MViT** pretrained with MMSelfSup, we recommend to run MMAction2: + +#### On Multiple GPUs + +```bash +# command example for train +mim train mmaction2 ${CONFIG} \ + --work-dir ${WORK_DIR} \ + --launcher pytorch -gpus 8 \ + --cfg-options model.backbone.init_cfg.type=Pretrained \ + model.backbone.init_cfg.checkpoint=${CHECKPOINT} \ + model.backbone.init_cfg.prefix="backbone." \ + ${PY_ARGS} + [optional args] + +mim train mmaction2 configs/mvit-small_ft-8xb8-coslr-100e_k400.py \ + --work-dir work_dirs/benchmarks/maskfeat/training_maskfeat-mvit-k400/ \ + --launcher pytorch -gpus 8 \ + --cfg-options model.backbone.init_cfg.type=Pretrained \ + model.backbone.init_cfg.checkpoint=https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400_20230131-87d60b6f.pth \ + model.backbone.init_cfg.prefix="backbone." 
\ + $PY_ARGS + +# command example for test +mim test mmaction2 configs/mvit-small_ft-8xb16-coslr-100e_k400.py \ + --checkpoint https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/mvit-small_ft-8xb16-coslr-100e_k400/mvit-small_ft-8xb16-coslr-100e_k400_20230131-5e8303f5.pth \ + --work-dir work_dirs/benchmarks/maskfeat/maskfeat-mvit-k400/test/ \ + --launcher pytorch --gpus 8 +``` + +#### On Multiple GPUs with Slurm + +```bash +mim train mmaction2 ${CONFIG} \ + --work-dir ${WORK_DIR} \ + --launcher slurm --gpus 8 --gpus-per-node 8 \ + --partition ${PARTITION} \ + --cfg-options model.backbone.init_cfg.type=Pretrained \ + model.backbone.init_cfg.checkpoint=$CHECKPOINT \ + model.backbone.init_cfg.prefix="backbone." \ + $PY_ARGS + +mim test mmaction2 ${CONFIG} \ + --checkpoint https://download.openmmlab.com/mmselfsup/1.x/maskfeat/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400/mvit-small_ft-8xb16-coslr-100e_k400/mvit-small_ft-8xb16-coslr-100e_k400_20230131-5e8303f5.pth + --work-dir ${WORK_DIR} \ + --launcher slurm --gpus 8 --gpus-per-node 8 \ + --partition ${PARTITION} \ + $PY_ARGS +``` + +Note: + +- CONFIG: the config files under the directory `configs/` +- WORK_DIR: the working directory to save configs, logs, and checkpoints +- PARTITION: the slurm partition you are using +- CHECKPOINT: the pretrained checkpoint of MMSelfSup saved in working directory, like `$WORK_DIR/epoch_300.pth` +- PY_ARGS: other optional args + +## Results + + + +The Fine-tuning results are based on Kinetics400(K400) dataset. + +Due to the version of K400 dataset, our pretraining, fine-tuning and the final test results are based on MMAction2 version, which is a little different from PySlowFast version. + + + + + + + + + + + + + + + + + + + + + + + + +
+| Algorithm | Backbone   | Epoch | Batch Size | Fine-tuning | Pretrain Links         | Fine-tuning Links      |
+| :-------: | :--------: | :---: | :--------: | :---------: | :--------------------: | :--------------------: |
+| MaskFeat  | MViT-small | 300   | 512        | 81.8        | config \| model \| log | config \| model \| log |
+ +Remarks: + +- We converted the pretrained model from PySlowFast and run fine-tuning with MMAction2, based on MMAction2 version of K400, we got `81.5` test accuracy. The pretrained model from MMSelfSup got `81.8`, as provided above. +- We also tested our model on [other version](https://github.com/facebookresearch/video-nonlocal-net/blob/main/DATASET.md) of K400, we got `82.1` test accuracy. +- Some other details can be found in [MMAction2 MViT page](https://github.com/open-mmlab/mmaction2/tree/dev-1.x/configs/recognition/mvit). + +## Citation + +```bibtex +@InProceedings{wei2022masked, + author = {Wei, Chen and Fan, Haoqi and Xie, Saining and Wu, Chao-Yuan and Yuille, Alan and Feichtenhofer, Christoph}, + title = {Masked Feature Prediction for Self-Supervised Visual Pre-Training}, + booktitle = {CVPR}, + year = {2022}, +} +``` + +## Checklist + +Here is a checklist illustrating a usual development workflow of a successful project, and also serves as an overview of this project's progress. + + + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + + + + - [x] Basic docstrings & proper citation + + + + - [x] Inference correctness + + + + - [x] A full README + + + +- [x] Milestone 2: Indicates a successful model implementation. + + - [x] Training-time correctness + + + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + + + + - [ ] Unit tests + + + + - [ ] Code polishing + + + + - [ ] `metafile.yml` and `README.md` + + + +- [ ] Refactor and Move your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/maskfeat_video/configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py b/projects/maskfeat_video/configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py new file mode 100644 index 000000000..aacdb94dd --- /dev/null +++ b/projects/maskfeat_video/configs/maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py @@ -0,0 +1,102 @@ +_base_ = 'mmselfsup::selfsup/_base_/default_runtime.py' + +custom_imports = dict(imports=['models'], allow_failed_imports=False) + +model = dict( + type='VideoMaskFeat', + data_preprocessor=dict( + type='VideoDataPreprocessor', + mean=[114.75, 114.75, 114.75], + std=[57.375, 57.375, 57.375], + format_shape='NCTHW'), + backbone=dict( + type='MaskFeatMViT', + arch='maskfeat-small', + drop_path_rate=0.0, + dim_mul_in_attention=False), + neck=dict( + type='LinearNeck', + in_channels=768, + out_channels=108, + with_avg_pool=False, + init_cfg=dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)), + head=dict( + type='MaskFeatPretrainHead', + loss=dict(type='PixelReconstructionLoss', criterion='L2')), + target_generator=dict( + type='HOGGenerator3d', nbins=9, pool=8, gaussian_window=16)) + +# dataset settings +dataset_type = 'mmaction.VideoDataset' +data_root = 'data/kinetics400/videos_train' +ann_file_train = 'data/Kinetics400/kinetics400_train_list_videos.txt' + +train_pipeline = [ + dict(type='mmaction.DecordInit'), + dict( + type='mmaction.SampleFrames', + clip_len=16, + frame_interval=4, + num_clips=1), + dict(type='mmaction.DecordDecode'), + dict(type='mmaction.Resize', scale=(-1, 256)), + dict(type='mmaction.RandomResizedCrop', area_range=(0.5, 1.0)), + dict(type='mmaction.Resize', scale=(224, 224), keep_ratio=False), + dict(type='mmaction.Flip', flip_ratio=0.5), + dict(type='mmaction.FormatShape', input_format='NCTHW'), + dict( + type='MaskFeatMaskGenerator3D', + input_size=(8, 7, 7), + num_masking_patches=157, + 
min_num_patches=9, + max_num_patches=49), + dict(type='PackSelfSupInputs', key='imgs', algorithm_keys=['mask']) +] + +train_dataloader = dict( + batch_size=32, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) + +optim_wrapper = dict( + type='AmpOptimWrapper', + loss_scale='dynamic', + optimizer=dict( + type='AdamW', lr=8e-4 * 2, betas=(0.9, 0.999), weight_decay=0.05), + clip_grad=dict(max_norm=0.02), + paramwise_cfg=dict( + bias_decay_mult=0., + norm_decay_mult=0., + custom_keys={ + 'pos_embed': dict(decay_mult=0.), + 'cls_token': dict(decay_mult=0.) + })) + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=290, + eta_min=1e-6, + by_epoch=True, + begin=10, + end=300, + convert_to_iter_based=True) +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300) +default_hooks = dict( + checkpoint=dict(interval=1, max_keep_ckpts=2), logger=dict(interval=100)) diff --git a/projects/maskfeat_video/configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py b/projects/maskfeat_video/configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py new file mode 100644 index 000000000..3f26e56e9 --- /dev/null +++ b/projects/maskfeat_video/configs/maskfeat_mvit-small_8xb32-amp-coslr-300e_k400.py @@ -0,0 +1,5 @@ +_base_ = './maskfeat_mvit-small_16xb32-amp-coslr-300e_k400.py' + +optim_wrapper = dict( + optimizer=dict( + type='AdamW', lr=8e-4, betas=(0.9, 0.999), weight_decay=0.05)) diff --git a/projects/maskfeat_video/configs/mvit-small_ft-8xb16-coslr-100e_k400.py b/projects/maskfeat_video/configs/mvit-small_ft-8xb16-coslr-100e_k400.py new file mode 100644 index 000000000..367e4baf5 --- /dev/null +++ b/projects/maskfeat_video/configs/mvit-small_ft-8xb16-coslr-100e_k400.py @@ -0,0 +1,157 @@ +_base_ = [ + 'mmaction::_base_/models/mvit_small.py', + 'mmaction::_base_/default_runtime.py' +] + +model = dict( + backbone=dict( + drop_path_rate=0.1, + dim_mul_in_attention=False, + pretrained=None, + pretrained_type='maskfeat', + ), + data_preprocessor=dict( + type='ActionDataPreprocessor', + mean=[114.75, 114.75, 114.75], + std=[57.375, 57.375, 57.375], + blending=dict( + type='RandomBatchAugment', + augments=[ + dict(type='MixupBlending', alpha=0.8, num_classes=400), + dict(type='CutmixBlending', alpha=1, num_classes=400) + ]), + format_shape='NCTHW'), + cls_head=dict(dropout_ratio=0., init_scale=0.001)) + +# dataset settings +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' + +train_pipeline = [ + dict(type='DecordInit'), + dict(type='SampleFrames', clip_len=16, frame_interval=4, num_clips=1), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='PytorchVideoWrapper', op='RandAugment', magnitude=7), + dict(type='RandomResizedCrop'), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='RandomErasing', erase_prob=0.25, mode='rand'), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='PackActionInputs') +] 
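+# The val pipeline samples a single center-cropped clip per video; the test
+# pipeline samples 10 clips for multi-view testing.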
+val_pipeline = [ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=16, + frame_interval=4, + num_clips=1, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='PackActionInputs') +] +test_pipeline = [ + dict(type='DecordInit'), + dict( + type='SampleFrames', + clip_len=16, + frame_interval=4, + num_clips=10, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 224)), + dict(type='CenterCrop', crop_size=224), + dict(type='FormatShape', input_format='NCTHW'), + dict(type='PackActionInputs') +] + +repeat_sample = 2 +train_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='repeat_pseudo_collate'), + dataset=dict( + type='RepeatAugDataset', + num_repeats=repeat_sample, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=16, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) +test_dataloader = dict( + batch_size=1, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +val_evaluator = dict(type='AccMetric') +test_evaluator = val_evaluator + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=100, val_begin=1, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +base_lr = 9.6e-3 +optim_wrapper = dict( + optimizer=dict( + type='AdamW', lr=base_lr, betas=(0.9, 0.999), weight_decay=0.05), + constructor='LearningRateDecayOptimizerConstructor', + paramwise_cfg={ + 'decay_rate': 0.75, + 'decay_type': 'layer_wise', + 'num_layers': 16 + }, + clip_grad=dict(max_norm=5, norm_type=2)) + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1 / 600, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=80, + eta_min_ratio=1 / 600, + by_epoch=True, + begin=20, + end=100, + convert_to_iter_based=True) +] + +default_hooks = dict( + checkpoint=dict(interval=3, max_keep_ckpts=20), logger=dict(interval=100)) + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (64 samples per GPU) / repeat_sample. 
+auto_scale_lr = dict(enable=True, base_batch_size=512 // repeat_sample) diff --git a/projects/maskfeat_video/models/__init__.py b/projects/maskfeat_video/models/__init__.py new file mode 100644 index 000000000..96e5f913a --- /dev/null +++ b/projects/maskfeat_video/models/__init__.py @@ -0,0 +1,9 @@ +from .hog_generator_3d import HOGGenerator3d +from .maskfeat import VideoMaskFeat +from .maskfeat_mvit import MaskFeatMViT +from .transforms import MaskFeatMaskGenerator3D + +__all__ = [ + 'HOGGenerator3d', 'VideoMaskFeat', 'MaskFeatMViT', + 'MaskFeatMaskGenerator3D' +] diff --git a/projects/maskfeat_video/models/hog_generator_3d.py b/projects/maskfeat_video/models/hog_generator_3d.py new file mode 100644 index 000000000..df6073b9b --- /dev/null +++ b/projects/maskfeat_video/models/hog_generator_3d.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmselfsup.models.target_generators import HOGGenerator +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class HOGGenerator3d(HOGGenerator): + """Generate HOG feature for videos. + + This module is used in MaskFeat to generate HOG feature. + Here is the link of `HOG wikipedia + `_. + + Args: + nbins (int): Number of bin. Defaults to 9. + pool (float): Number of cell. Defaults to 8. + gaussian_window (int): Size of gaussian kernel. Defaults to 16. + """ + + def __init__(self, + nbins: int = 9, + pool: int = 8, + gaussian_window: int = 16) -> None: + super().__init__( + nbins=nbins, pool=pool, gaussian_window=gaussian_window) + + def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor: + """Reshape HOG Features for output.""" + hog_feat = hog_feat.flatten(1, 2) + self.unfold_size = hog_feat.shape[-1] // 14 + hog_feat = hog_feat.permute(0, 2, 3, 1) + hog_feat = hog_feat.unfold(1, self.unfold_size, + self.unfold_size).unfold( + 2, self.unfold_size, self.unfold_size) + hog_feat = hog_feat.flatten(3).view(self.B, self.T, 14, 14, -1) + hog_feat = hog_feat.flatten(1, 3) # B N C + return hog_feat diff --git a/projects/maskfeat_video/models/maskfeat.py b/projects/maskfeat_video/models/maskfeat.py new file mode 100644 index 000000000..3e9fb3792 --- /dev/null +++ b/projects/maskfeat_video/models/maskfeat.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch +import torch.nn.functional as F + +from mmselfsup.models import BaseModel +from mmselfsup.registry import MODELS +from mmselfsup.structures import SelfSupDataSample + + +@MODELS.register_module() +class VideoMaskFeat(BaseModel): + """MaskFeat. + + Implementation of `Masked Feature Prediction for Self-Supervised Visual + Pre-Training `_. + """ + + def loss(self, inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[SelfSupDataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. 
+ """ + mask = torch.stack( + [data_sample.mask.value for data_sample in data_samples]) + mask = mask.to(torch.bool) + + video = inputs[0] + video = video.view((-1, ) + video.shape[2:]) # B, C, T, H, W + latent = self.backbone(video, mask) + B, L, C = latent[0].shape + pred = self.neck([latent[0].view(B * L, C)]) + pred = pred[0].view(B, L, -1) + + # generate hog target + video = video[:, :, ::self.backbone.patch_stride[0], :, :] + video = video.transpose(1, 2) # B, T, C, H, W + self.target_generator.B = video.size(0) + self.target_generator.T = video.size(1) + video = video.flatten(0, 1) # B*T, C, H, W + hog = self.target_generator(video) + + mask = self._get_output_mask(mask) + loss = self.head(pred, hog, mask) + losses = dict(loss=loss) + return losses + + def _get_output_mask(self, mask: torch.Tensor) -> torch.Tensor: + size = self.backbone.out_patch_resolution[-1][-1] + output_mask = F.interpolate(mask.float(), size=size) + return output_mask diff --git a/projects/maskfeat_video/models/maskfeat_mvit.py b/projects/maskfeat_video/models/maskfeat_mvit.py new file mode 100644 index 000000000..b255665ad --- /dev/null +++ b/projects/maskfeat_video/models/maskfeat_mvit.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmaction.models import MViT +from mmaction.models.backbones.mvit import resize_pos_embed + +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class MaskFeatMViT(MViT): + + arch_zoo = { + 'maskfeat-small': { + 'embed_dims': 96, + 'num_layers': 16, + 'num_heads': 1, + 'downscale_indices': [1, 3], + 'dim_mul_indices': [1, 3, 14] + }, + 'maskfeat-large': { + 'embed_dims': 144, + 'num_layers': 48, + 'num_heads': 2, + 'downscale_indices': [2, 8], + 'dim_mul_indices': [2, 8, 44] + }, + } + + def __init__( + self, + arch: str = 'base', + spatial_size: int = 224, + temporal_size: int = 16, + in_channels: int = 3, + out_scales: Union[int, Sequence[int]] = -1, + drop_path_rate: float = 0, + use_abs_pos_embed: bool = False, + interpolate_mode: str = 'trilinear', + pool_kernel: tuple = (3, 3, 3), + dim_mul: int = 2, + head_mul: int = 2, + adaptive_kv_stride: tuple = (1, 8, 8), + rel_pos_embed: bool = True, + residual_pooling: bool = True, + dim_mul_in_attention: bool = True, + with_cls_token: bool = True, + output_cls_token: bool = True, + rel_pos_zero_init: bool = False, + mlp_ratio: float = 4, + qkv_bias: bool = True, + norm_cfg: dict = dict(type='LN', eps=1e-6), + patch_cfg: dict = dict( + kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3)), + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict(type='TruncNormal', layer=['Conv2d', 'Conv3d'], std=0.02), + dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.02), + ] + ) -> None: + super().__init__( + arch=arch, + spatial_size=spatial_size, + temporal_size=temporal_size, + in_channels=in_channels, + out_scales=out_scales, + drop_path_rate=drop_path_rate, + use_abs_pos_embed=use_abs_pos_embed, + interpolate_mode=interpolate_mode, + pool_kernel=pool_kernel, + dim_mul=dim_mul, + head_mul=head_mul, + adaptive_kv_stride=adaptive_kv_stride, + rel_pos_embed=rel_pos_embed, + residual_pooling=residual_pooling, + dim_mul_in_attention=dim_mul_in_attention, + with_cls_token=with_cls_token, + output_cls_token=output_cls_token, + rel_pos_zero_init=rel_pos_zero_init, + mlp_ratio=mlp_ratio, + 
qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + patch_cfg=patch_cfg, + init_cfg=init_cfg) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.patch_stride = patch_cfg['stride'] + + def init_weights(self) -> None: + """Initialize mask token and cls token.""" + super().init_weights() + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + nn.init.trunc_normal_(self.cls_token, std=.02) + nn.init.trunc_normal_(self.mask_token, std=.02) + + def forward(self, x: torch.Tensor, + mask: torch.Tensor) -> Tuple[torch.Tensor]: + + x, patch_resolution = self.patch_embed(x) + B, L, C = x.shape + T, H, W = patch_resolution + + mask_tokens = self.mask_token.expand(B, L, -1) + mask = F.interpolate(mask.float(), size=(H, W)) + mask = mask.flatten(1).unsqueeze(-1) + x = x * (1 - mask) + mask_tokens * mask + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + + # if not self.with_cls_token: + # # Remove class token for transformer encoder input + # x = x[:, 1:] + + outs = [] + self.out_patch_resolution = [] + for i, block in enumerate(self.blocks): + x, patch_resolution = block(x, patch_resolution) + + if i in self.stage_indices: + stage_index = self.stage_indices[i] + if stage_index in self.out_scales: + self.out_patch_resolution.append(patch_resolution) + x = getattr(self, f'norm{stage_index}')(x) + if not self.output_cls_token: + out = x[:, 1:] + else: + out = x + outs.append(out) + + return tuple(outs) diff --git a/projects/maskfeat_video/models/transforms.py b/projects/maskfeat_video/models/transforms.py new file mode 100644 index 000000000..d269b96cf --- /dev/null +++ b/projects/maskfeat_video/models/transforms.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import random +from typing import Optional, Tuple + +import numpy as np +from mmcv.transforms.base import BaseTransform + +from mmselfsup.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class MaskFeatMaskGenerator3D(BaseTransform): + """Generate mask for video. + + Added Keys: + + - mask + + This module is borrowed from + https://github.com/facebookresearch/SlowFast/blob/main/slowfast/datasets/transform.py + + Args: + input_size (int): The size of input video. + num_masking_patches (int): The number of patches to be masked. + min_num_patches (int): The minimum number of patches to be masked + in the process of generating mask. Defaults to 4. + max_num_patches (int, optional): The maximum number of patches to be + masked in the process of generating mask. Defaults to None. + min_aspect (float): The minimum aspect ratio of mask blocks. Defaults + to 0.3. + min_aspect (float, optional): The minimum aspect ratio of mask blocks. + Defaults to None. 
+ """ + + def __init__(self, + input_size: int, + num_masking_patches: int, + min_num_patches: int = 4, + max_num_patches: Optional[int] = None, + min_aspect: float = 0.3, + max_aspect: Optional[float] = None) -> None: + + self.temporal, self.height, self.width = input_size + self.num_masking_patches = num_masking_patches + self.min_num_patches = min_num_patches + self.max_num_patches = ( + num_masking_patches + if max_num_patches is None else max_num_patches) + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def get_shape(self) -> Tuple[int, int, int]: + """Get the shape of mask. + + Returns: + Tuple[int, int, int]: The shape of mask. + """ + return self.temporal, self.height, self.width + + def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int: + """Generate mask recursively. + + Args: + mask (np.ndarray): The mask to be generated. + max_mask_patches (int): The maximum number of patches to be masked. + + Returns: + int: The number of patches masked. + """ + delta = 0 + for _ in range(100): + target_area = random.uniform(self.min_num_patches, + self.max_num_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + t = random.randint(1, self.temporal) # ! + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + front = random.randint(0, self.temporal - t) + + num_masked = mask[front:front + t, top:top + h, + left:left + w].sum() + # Overlap + if 0 < h * w * t - num_masked <= max_mask_patches: + for i in range(front, front + t): + for j in range(top, top + h): + for k in range(left, left + w): + if mask[i, j, k] == 0: + mask[i, j, k] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def transform(self, results: dict) -> dict: + """Method to generate random block mask. + + Args: + results (dict): Result dict from previous pipeline. + + Returns: + dict: Result dict with added key ``mask``. 
+ """ + mask = np.zeros(shape=self.get_shape(), dtype=np.int) + mask_count = 0 + while mask_count < self.num_masking_patches: + max_mask_patches = self.num_masking_patches - mask_count + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + results.update({'mask': mask}) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(temporal={self.temporal}, ' + repr_str += f'height={self.height}, ' + repr_str += f'width={self.width}, ' + repr_str += f'num_masking_patches={self.num_masking_patches}, ' + repr_str += f'min_num_patches={self.min_num_patches}, ' + repr_str += f'max_num_patches={self.max_num_patches}, ' + repr_str += f'log_aspect_ratio={self.log_aspect_ratio})' + return repr_str diff --git a/projects/maskfeat_video/tools/dist_train.sh b/projects/maskfeat_video/tools/dist_train.sh new file mode 100644 index 000000000..3fca7641d --- /dev/null +++ b/projects/maskfeat_video/tools/dist_train.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --launcher pytorch ${@:3} diff --git a/projects/maskfeat_video/tools/slurm_train.sh b/projects/maskfeat_video/tools/slurm_train.sh new file mode 100644 index 000000000..ac36d5082 --- /dev/null +++ b/projects/maskfeat_video/tools/slurm_train.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:4} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} diff --git a/projects/maskfeat_video/tools/train.py b/projects/maskfeat_video/tools/train.py new file mode 100644 index 000000000..ef0d3127c --- /dev/null +++ b/projects/maskfeat_video/tools/train.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + +from mmselfsup.utils import register_all_modules + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume', + nargs='?', + type=str, + const='auto', + help='If specify checkpint path, resume from it, while if not ' + 'specify, try to auto resume from the latest checkpoint ' + 'in the work directory.') + parser.add_argument( + '--amp', + action='store_true', + help='enable automatic-mixed-precision training') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. 
If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + args = parse_args() + + # register all modules in mmselfsup into the registries + # do not init the default scope here because it will be init in the runner + register_all_modules(init_default_scope=False) + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + work_type = args.config.split('/')[1] + cfg.work_dir = osp.join('./work_dirs', work_type, + osp.splitext(osp.basename(args.config))[0]) + + # enable automatic-mixed-precision training + if args.amp is True: + optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper') + assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \ + '`--amp` is not supported custom optimizer wrapper type ' \ + f'`{optim_wrapper}.' + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.setdefault('loss_scale', 'dynamic') + + # resume training + if args.resume == 'auto': + cfg.resume = True + cfg.load_from = None + elif args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start training + runner.train() + + +if __name__ == '__main__': + main() diff --git a/tests/test_datasets/test_transforms/test_formmatting.py b/tests/test_datasets/test_transforms/test_formmatting.py index 751edc4d6..bc3fca10d 100644 --- a/tests/test_datasets/test_transforms/test_formmatting.py +++ b/tests/test_datasets/test_transforms/test_formmatting.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import numpy as np +import pytest import torch from mmselfsup.datasets.transforms import PackSelfSupInputs @@ -32,6 +33,19 @@ def test_pack_selfsup_inputs(): assert list(results['inputs'][0].shape) == [1, 8, 8] assert results['data_samples'].gt_label.value == torch.tensor([1]) + # video data + transform = PackSelfSupInputs(key='img', algorithm_keys=['gt_label']) + results = {'img': np.ones((4, 3, 8, 8, 8)), 'gt_label': 1} + results = transform(results) + assert list(results['inputs'][0].shape) == [4, 3, 8, 8, 8] + assert results['data_samples'].gt_label.value == torch.tensor([1]) + + # dimension check + with pytest.raises(ValueError): + transform = PackSelfSupInputs(key='img', algorithm_keys=['gt_label']) + results = {'img': np.ones((3, 8, 8, 8)), 'gt_label': 1} + results = transform(results) + # img is a list transform = PackSelfSupInputs(key='img', algorithm_keys=['gt_label']) results = {'img': [np.ones((8, 8))], 'gt_label': 1} diff --git a/tests/test_models/test_utils/test_data_preprocessor.py b/tests/test_models/test_utils/test_data_preprocessor.py index bf34a6aa7..94fdbc79f 100644 --- a/tests/test_models/test_utils/test_data_preprocessor.py +++ b/tests/test_models/test_utils/test_data_preprocessor.py @@ -3,7 +3,8 @@ import torch from mmselfsup.models.utils import (SelfSupDataPreprocessor, - TwoNormDataPreprocessor) + TwoNormDataPreprocessor, + VideoDataPreprocessor) from mmselfsup.structures import SelfSupDataSample @@ -66,3 +67,46 @@ def test_two_norm_data_preprocessor(): fake_batches, fake_samples = data_preprocessor(fake_data) assert len(fake_batches) == 2 assert len(fake_samples) == 4 + + +def test_video_data_preprocessor(): + data_preprocessor = VideoDataPreprocessor( + mean=[114.75, 114.75, 114.75], + std=[57.375, 57.375, 57.375], + format_shape='NCTHW') + fake_data = { + 'inputs': [torch.randn((2, 3, 4, 224, 224))], + 'data_sample': [SelfSupDataSample(), + SelfSupDataSample()] + } + fake_batches, fake_samples = data_preprocessor(fake_data) + assert len(fake_batches) == 1 + assert len(fake_samples) == 2 + + data_preprocessor = VideoDataPreprocessor( + mean=[114.75, 114.75, 114.75], + std=[57.375, 57.375, 57.375], + bgr_to_rgb=True, + format_shape='NCTHW') + fake_data = { + 'inputs': [torch.randn((2, 3, 4, 224, 224))], + 'data_sample': [SelfSupDataSample(), + SelfSupDataSample()] + } + fake_batches, fake_samples = data_preprocessor(fake_data) + assert len(fake_batches) == 1 + assert len(fake_samples) == 2 + + data_preprocessor = VideoDataPreprocessor( + mean=[114.75, 114.75, 114.75], + std=[57.375, 57.375, 57.375], + bgr_to_rgb=True, + format_shape='NCHW') + fake_data = { + 'inputs': [torch.randn((2, 3, 224, 224))], + 'data_sample': [SelfSupDataSample(), + SelfSupDataSample()] + } + fake_batches, fake_samples = data_preprocessor(fake_data) + assert len(fake_batches) == 1 + assert len(fake_samples) == 2
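A minimal sketch of how the new `VideoDataPreprocessor` behaves in the `NCTHW` case, mirroring the unit test above. The constant-valued clip and the zero-output check are illustrative only, assuming the class is exported from `mmselfsup.models.utils` as registered in this diff:

```python
import torch

from mmselfsup.models.utils import VideoDataPreprocessor
from mmselfsup.structures import SelfSupDataSample

preprocessor = VideoDataPreprocessor(
    mean=[114.75, 114.75, 114.75],
    std=[57.375, 57.375, 57.375],
    bgr_to_rgb=True,
    format_shape='NCTHW')

# One batched clip with shape (N, C, T, H, W), filled with the mean value.
clip = torch.full((2, 3, 4, 8, 8), 114.75)
data = {
    'inputs': [clip],
    'data_sample': [SelfSupDataSample(), SelfSupDataSample()]
}

batch_inputs, data_samples = preprocessor(data)

# For NCTHW, mean/std are registered with shape (-1, 1, 1, 1), so a clip
# equal to the channel means normalizes to all zeros.
assert torch.allclose(batch_inputs[0], torch.zeros_like(clip))
```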