Commit
[Feature] support MaskFeat pre-training with video (#678)
* add maskfeat video configs

* add maskfeat video

* update configs

* update readme

* update results and links

* update configs

* update configs

* update links

* refine readme

* refine doc

* add ut

* refine readme

* update ut

* update readme

* update docs

* update config

* update train scripts in readme

* add mmaction2 page link
fangyixiao18 committed Feb 7, 2023
1 parent e43925c commit 0f634de
Showing 18 changed files with 1,253 additions and 13 deletions.
21 changes: 15 additions & 6 deletions mmselfsup/datasets/transforms/formatting.py
@@ -56,9 +56,8 @@ def transform(self,
         Returns:
             Dict:
-            - 'inputs' (List[torch.Tensor]): The forward data of models.
-            - 'data_samples' (SelfSupDataSample): The annotation info of the
+            - ``inputs`` (List[torch.Tensor]): The forward data of models.
+            - ``data_samples`` (SelfSupDataSample): The annotation info of
                 the forward data.
         """
         packed_results = dict()
@@ -68,9 +67,19 @@
         if not isinstance(img, List):
             img = [img]
         for i, img_ in enumerate(img):
-            if len(img_.shape) < 3:
-                img_ = np.expand_dims(img_, -1)
-            img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
+            # to handle the single channel image
+            img_ = np.expand_dims(img_, -1) \
+                if len(img_.shape) == 2 else img_
+
+            if len(img_.shape) == 3:
+                img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
+            elif len(img_.shape) == 5:
+                # for video data with the shape (B, C, T, H, W)
+                img_ = img_
+            else:
+                raise ValueError(
+                    'img should be 2, 3 or 5 dimensional, '
+                    f'instead of {len(img_.shape)} dimensional.')
             img[i] = to_tensor(img_)
         packed_results['inputs'] = img
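This branching is the core of the change: 2-D grayscale images gain a channel axis, 3-D images are converted to channel-first layout, and 5-D video clips pass through untouched. Below is a minimal, self-contained sketch of the same logic; the `pack_inputs` helper and the toy arrays are illustrative, not part of the commit.

```python
# Minimal sketch of the dimension handling above (helper name and toy
# arrays are illustrative, not part of the commit).
import numpy as np
import torch


def pack_inputs(imgs):
    """Mimic the branching in the transform above."""
    packed = []
    for img_ in imgs:
        # Single-channel image: (H, W) -> (H, W, 1).
        img_ = np.expand_dims(img_, -1) if img_.ndim == 2 else img_
        if img_.ndim == 3:
            # Image: (H, W, C) -> contiguous (C, H, W).
            img_ = np.ascontiguousarray(img_.transpose(2, 0, 1))
        elif img_.ndim == 5:
            # Video clips arrive already packed, e.g. (N, C, T, H, W);
            # they pass through unchanged.
            pass
        else:
            raise ValueError(
                f'img should be 2, 3 or 5 dimensional, got {img_.ndim}.')
        packed.append(torch.from_numpy(img_))
    return packed


gray = np.zeros((224, 224), dtype=np.float32)              # 2-D
rgb = np.zeros((224, 224, 3), dtype=np.float32)            # 3-D
video = np.zeros((2, 3, 16, 224, 224), dtype=np.float32)   # 5-D
print([tuple(t.shape) for t in pack_inputs([gray, rgb, video])])
# [(1, 224, 224), (3, 224, 224), (2, 3, 16, 224, 224)]
```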

4 changes: 2 additions & 2 deletions mmselfsup/models/target_generators/hog_generator.py
@@ -35,8 +35,8 @@ def __init__(self,
         self.pool = pool
         self.pi = math.pi
         weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
-        weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1)
-        weight_y = weight_x.transpose(2, 3)
+        weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous()
+        weight_y = weight_x.transpose(2, 3).contiguous()
         self.register_buffer('weight_x', weight_x)
         self.register_buffer('weight_y', weight_y)
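The `.contiguous()` calls matter because `transpose` returns a strided view; materializing the buffers gives them a plain memory layout before registration. A hedged sketch of how such Sobel-style kernels are typically applied for HOG gradients follows; the depthwise `F.conv2d` with `groups=3` mirrors the usual per-channel pattern and is an illustration, not necessarily this file's exact forward pass.

```python
# Hedged sketch: applying Sobel-style buffers like weight_x/weight_y
# per RGB channel. The groups=3 depthwise conv is an illustration.
import math

import torch
import torch.nn.functional as F

weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous()
# transpose() returns a non-contiguous view; .contiguous() materializes
# the y-kernel so the registered buffer has a plain memory layout.
weight_y = weight_x.transpose(2, 3).contiguous()

img = torch.randn(1, 3, 224, 224)                   # (N, C, H, W)
gx = F.conv2d(img, weight_x, padding=1, groups=3)   # horizontal gradient
gy = F.conv2d(img, weight_y, padding=1, groups=3)   # vertical gradient
magnitude = torch.sqrt(gx**2 + gy**2)               # per-pixel gradient norm
orientation = torch.atan2(gy, gx) % math.pi         # angle in [0, pi)
print(magnitude.shape, orientation.shape)           # both (1, 3, 224, 224)
```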

6 changes: 3 additions & 3 deletions mmselfsup/models/utils/__init__.py
@@ -4,7 +4,7 @@
                                 RelativeLocDataPreprocessor,
                                 RotationPredDataPreprocessor,
                                 SelfSupDataPreprocessor,
-                                TwoNormDataPreprocessor)
+                                TwoNormDataPreprocessor, VideoDataPreprocessor)
 from .ema import CosineEMA
 from .extractor import Extractor
 from .gather_layer import GatherLayer
@@ -29,6 +29,6 @@
     'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'CosineEMA',
     'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
     'RotationPredDataPreprocessor', 'CAEDataPreprocessor', 'ResLayerExtraNorm',
-    'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor',
-    'PromptTransformerEncoderLayer', 'build_clip_model'
+    'NormEMAVectorQuantizer', 'TwoNormDataPreprocessor', 'build_clip_model',
+    'PromptTransformerEncoderLayer', 'VideoDataPreprocessor'
 ]
111 changes: 110 additions & 1 deletion mmselfsup/models/utils/data_preprocessor.py
@@ -2,7 +2,7 @@
 from typing import List, Optional, Sequence, Tuple, Union

 import torch
-from mmengine.model import ImgDataPreprocessor
+from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor

 from mmselfsup.registry import MODELS
@@ -290,3 +290,112 @@ def forward(
        ]

        return batch_inputs, batch_data_samples


@MODELS.register_module()
class VideoDataPreprocessor(BaseDataPreprocessor):
    """Video data pre-processor for operations such as normalization and
    BGR-to-RGB conversion.

    Compared with the :class:`mmaction.ActionDataPreprocessor`, this module
    treats each item in ``inputs`` of the input data as a list, instead of
    a torch.Tensor.

    Args:
        mean (Sequence[float or int], optional): The pixel mean of channels
            of images or stacked optical flow. Defaults to None.
        std (Sequence[float or int], optional): The pixel standard deviation
            of channels of images or stacked optical flow. Defaults to None.
        pad_size_divisor (int): The size of the padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (float or int): The padded pixel value. Defaults to 0.
        bgr_to_rgb (bool): Whether to convert images from BGR to RGB.
            Defaults to False.
        format_shape (str): Format shape of the input data.
            Defaults to ``'NCHW'``.
    """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 bgr_to_rgb: bool = False,
                 format_shape: str = 'NCHW') -> None:
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.bgr_to_rgb = bgr_to_rgb
        self.format_shape = format_shape

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                                    'preprocessing, please specify both ' \
                                    '`mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            if self.format_shape == 'NCHW':
                normalizer_shape = (-1, 1, 1)
            elif self.format_shape == 'NCTHW':
                normalizer_shape = (-1, 1, 1, 1)
            else:
                raise ValueError(f'Invalid format shape: {format_shape}')

            self.register_buffer(
                'mean',
                torch.tensor(mean, dtype=torch.float32).view(normalizer_shape),
                False)
            self.register_buffer(
                'std',
                torch.tensor(std, dtype=torch.float32).view(normalizer_shape),
                False)
        else:
            self._enable_normalize = False

    def forward(
            self,
            data: dict,
            training: bool = False
    ) -> Tuple[List[torch.Tensor], Optional[list]]:
        """Performs normalization, padding and bgr2rgb conversion based on
        ``BaseDataPreprocessor``.

        Args:
            data (dict): Data sampled from the dataloader.
            training (bool): Whether to enable training-time augmentation. If
                subclasses override this method, they can perform different
                preprocessing strategies for training and testing based on the
                value of ``training``.

        Returns:
            Tuple[List[torch.Tensor], Optional[list]]: Data in the same format
            as the model input.
        """
        data = [val for _, val in data.items()]
        batch_inputs, batch_data_samples = self.cast_data(data)

        # ------ To RGB ------
        if self.bgr_to_rgb:
            if self.format_shape == 'NCHW':
                batch_inputs = [
                    batch_input[..., [2, 1, 0], :, :]
                    for batch_input in batch_inputs
                ]
            elif self.format_shape == 'NCTHW':
                batch_inputs = [
                    batch_input[..., [2, 1, 0], :, :, :]
                    for batch_input in batch_inputs
                ]
            else:
                raise ValueError(f'Invalid format shape: {self.format_shape}')

        # -- Normalization ---
        if self._enable_normalize:
            batch_inputs = [(batch_input - self.mean) / self.std
                            for batch_input in batch_inputs]
        else:
            batch_inputs = [
                batch_input.to(torch.float32)
                for batch_input in batch_inputs
            ]

        return batch_inputs, batch_data_samples
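For orientation, here is a hedged usage sketch of the new preprocessor; it is not part of the diff, and the mean/std values and data layout are illustrative assumptions.

```python
# Hedged usage sketch (not part of the diff): mean/std values and the
# data layout are illustrative assumptions.
import torch

from mmselfsup.models.utils import VideoDataPreprocessor

preprocessor = VideoDataPreprocessor(
    mean=[114.75, 114.75, 114.75],
    std=[57.375, 57.375, 57.375],
    bgr_to_rgb=True,
    format_shape='NCTHW')

data = {
    # a single view: 2 clips x 3 channels x 16 frames x 224 x 224
    'inputs': [torch.randint(0, 256, (2, 3, 16, 224, 224)).float()],
    'data_samples': None,
}
batch_inputs, batch_data_samples = preprocessor(data)
print(batch_inputs[0].shape)  # torch.Size([2, 3, 16, 224, 224])
```

Because of the `@MODELS.register_module()` decorator, the same object can also be built from a config dict via `MODELS.build(dict(type='VideoDataPreprocessor', ...))`, which is presumably how the MaskFeat video configs in this commit reference it.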
