Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Register torchvision transforms into mmcls #1265

Merged
merged 18 commits into from Apr 13, 2023
86 changes: 84 additions & 2 deletions docs/en/api/data_process.rst
Expand Up @@ -61,8 +61,8 @@ Loading and Formatting
LoadImageFromFile
PackInputs
PackMultiTaskInputs
ToNumpy
ToPIL
PILToNumpy
NumpyToPIL
Transpose
Collect

Expand Down Expand Up @@ -147,6 +147,88 @@ Transform Wrapper

.. module:: mmpretrain.models.utils.data_preprocessor


TorchVision Transforms
^^^^^^^^^^^^^^^^^^^^^^

We also provide all the transforms in TorchVision. You can use them like in the following examples:

**1. Use some TorchVision Augs Surrounded by NumpyToPIL and PILToNumpy (Recommendation)**

Add TorchVision Augs surrounded by ``dict(type='NumpyToPIL', to_rgb=True),`` and ``dict(type='PILToNumpy', to_bgr=True),``

.. code:: python

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='NumpyToPIL', to_rgb=True), # from BGR in cv2 to RGB in PIL
dict(type='torchvision/RandomResizedCrop',size=176),
dict(type='PILToNumpy', to_bgr=True), # from RGB in PIL to BGR in cv2
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]

data_preprocessor = dict(
num_classes=1000,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True, # from BGR in cv2 to RGB in PIL
)


**2. Use TorchVision Augs and ToTensor&Normalize**

Make sure the ``img`` has been converted from the BGR-NumPy format to the PIL format before it is processed by TorchVision transforms.

.. code:: python

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='NumpyToPIL', to_rgb=True), # from BGR in cv2 to RGB in PIL
dict(
type='torchvision/RandomResizedCrop',
size=176,
interpolation='bilinear'), # accept str format interpolation mode
dict(type='torchvision/RandomHorizontalFlip', p=0.5),
dict(
type='torchvision/TrivialAugmentWide',
interpolation='bilinear'),
dict(type='torchvision/PILToTensor'),
dict(type='torchvision/ConvertImageDtype', dtype=torch.float),
dict(
type='torchvision/Normalize',
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
),
dict(type='torchvision/RandomErasing', p=0.1),
dict(type='PackInputs'),
]

data_preprocessor = dict(num_classes=1000, mean=None, std=None, to_rgb=False) # Normalize in dataset pipeline


**3. Use TorchVision Augs Except ToTensor&Normalize**

.. code:: python

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='NumpyToPIL', to_rgb=True), # from BGR in cv2 to RGB in PIL
dict(type='torchvision/RandomResizedCrop', size=176, interpolation='bilinear'),
dict(type='torchvision/RandomHorizontalFlip', p=0.5),
dict(type='torchvision/TrivialAugmentWide', interpolation='bilinear'),
dict(type='PackInputs'),
]

# here the Normalize params is for the RGB format
data_preprocessor = dict(
num_classes=1000,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=False,
)


Data Preprocessors
------------------

Expand Down
6 changes: 3 additions & 3 deletions mmpretrain/datasets/transforms/__init__.py
Expand Up @@ -8,8 +8,8 @@
Equalize, GaussianBlur, Invert, Posterize,
RandAugment, Rotate, Sharpness, Shear, Solarize,
SolarizeAdd, Translate)
from .formatting import (Collect, PackInputs, PackMultiTaskInputs, ToNumpy,
ToPIL, Transpose)
from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
PILToNumpy, Transpose)
from .processing import (Albumentations, BEiTMaskGenerator, ColorJitter,
EfficientNetCenterCrop, EfficientNetRandomCrop,
Lighting, RandomCrop, RandomErasing,
Expand All @@ -21,7 +21,7 @@
TRANSFORMS.register_module(module=t)

__all__ = [
'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'RandomCrop',
'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop',
'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert',
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
Expand Down
54 changes: 35 additions & 19 deletions mmpretrain/datasets/transforms/formatting.py
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from collections.abc import Sequence

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
Expand Down Expand Up @@ -256,55 +257,70 @@ def __repr__(self):
f'(keys={self.keys}, order={self.order})'


@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL'))
class NumpyToPIL(BaseTransform):
    """Convert the image from OpenCV format to :obj:`PIL.Image.Image`.

    **Required Keys:**

    - ``img``

    **Modified Keys:**

    - ``img``

    Args:
        to_rgb (bool): Whether to convert img to RGB before building the
            PIL image. Defaults to False.
    """

    def __init__(self, to_rgb: bool = False) -> None:
        self.to_rgb = to_rgb

    def transform(self, results: dict) -> dict:
        """Method to convert images to :obj:`PIL.Image.Image`."""
        img = results['img']
        # cv2 loads images in BGR order; PIL expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img

        results['img'] = Image.fromarray(img)
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(to_rgb={self.to_rgb})'

@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy'))
class PILToNumpy(BaseTransform):
    """Convert img to :obj:`numpy.ndarray`.

    **Required Keys:**

    - ``img``

    **Modified Keys:**

    - ``img``

    Args:
        to_bgr (bool): Whether to convert img to BGR after converting to
            a numpy array. Defaults to False.
        dtype (str, optional): The dtype of the converted numpy array.
            Defaults to None.
    """

    def __init__(self, to_bgr: bool = False, dtype=None) -> None:
        self.to_bgr = to_bgr
        self.dtype = dtype

    def transform(self, results: dict) -> dict:
        """Method to convert img to :obj:`numpy.ndarray`."""
        img = np.array(results['img'], dtype=self.dtype)
        # PIL images are RGB; convert back to cv2's BGR convention.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img

        results['img'] = img
        return results

    def __repr__(self) -> str:
        return self.__class__.__name__ + \
            f'(to_bgr={self.to_bgr}, dtype={self.dtype})'


@TRANSFORMS.register_module()
Expand Down
91 changes: 91 additions & 0 deletions mmpretrain/datasets/transforms/processing.py
Expand Up @@ -2,14 +2,19 @@
import inspect
import math
import numbers
import re
import traceback
from enum import EnumMeta
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import mmengine
import numpy as np
import torchvision
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from torchvision.transforms.transforms import InterpolationMode

from mmpretrain.registry import TRANSFORMS

Expand All @@ -19,6 +24,92 @@
albumentations = None


def _str_to_torch_dtype(t: str):
"""mapping str format dtype to torch.dtype."""
import torch # noqa: F401,F403
return eval(f'torch.{t}')


def _interpolation_modes_from_str(t: str):
    """Map a str format interpolation name to ``InterpolationMode``.

    The lookup is case-insensitive. Raises ``KeyError`` for unknown names.
    """
    t = t.lower()
    inverse_modes_mapping = {
        'nearest': InterpolationMode.NEAREST,
        'bilinear': InterpolationMode.BILINEAR,
        'bicubic': InterpolationMode.BICUBIC,
        'box': InterpolationMode.BOX,
        'hamming': InterpolationMode.HAMMING,
        # Keep the misspelled key for backward compatibility with configs
        # written against the old (typo'd) implementation.
        'hammimg': InterpolationMode.HAMMING,
        'lanczos': InterpolationMode.LANCZOS,
    }
    return inverse_modes_mapping[t]


def _warpper_vision_transform_cls(vision_transform_cls, new_name):
"""build a transform warpper class for specific torchvison.transform to
handle the different input type between torchvison.transforms with
mmcls.datasets.transforms."""

def new_init(self, *args, **kwargs):
if 'interpolation' in kwargs and isinstance(kwargs['interpolation'],
str):
kwargs['interpolation'] = _interpolation_modes_from_str(
kwargs['interpolation'])
if 'dtype' in kwargs and isinstance(kwargs['dtype'], str):
kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype'])

try:
self.t = vision_transform_cls(*args, **kwargs)
except TypeError as e:
traceback.print_exc()
raise TypeError(
f'Error when init the {vision_transform_cls}, please '
f'check the argmemnts of {args} and {kwargs}. \n{e}')

def new_call(self, input):
try:
input['img'] = self.t(input['img'])
except Exception as e:
traceback.print_exc()
raise Exception('Error when processing of transform(`torhcvison/'
f'{vision_transform_cls.__name__}`). \n{e}')
return input

def new_str(self):
return str(self.t)

new_transforms_cls = type(
new_name, (),
dict(__init__=new_init, __call__=new_call, __str__=new_str))
return new_transforms_cls


def register_vision_transforms() -> List[str]:
    """Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS``
    registry.

    Each eligible torchvision transform class is wrapped (see
    ``_warpper_vision_transform_cls``) and registered under the name
    ``torchvision/<ClassName>``.

    Returns:
        List[str]: A list of registered transforms' name.
    """
    vision_transforms = []
    for module_name in dir(torchvision.transforms):
        if not re.match('[A-Z]', module_name):
            # public transform classes start with a capital letter
            continue
        _transform = getattr(torchvision.transforms, module_name)
        # Skip enums (e.g. InterpolationMode) and non-class attributes.
        if inspect.isclass(_transform) and callable(
                _transform) and not isinstance(_transform, (EnumMeta)):
            new_cls = _warpper_vision_transform_cls(
                _transform, f'TorchVision{module_name}')
            TRANSFORMS.register_module(
                module=new_cls, name=f'torchvision/{module_name}')
            vision_transforms.append(f'torchvision/{module_name}')
    return vision_transforms


# Register every torchvision transform into TRANSFORMS at import time via
# the wrapper; keep the list of registered names for reference/export.
VISION_TRANSFORMS = register_vision_transforms()


@TRANSFORMS.register_module()
class RandomCrop(BaseTransform):
"""Crop the given Image at a random location.
Expand Down
13 changes: 12 additions & 1 deletion mmpretrain/models/backbones/res2net.py
Expand Up @@ -143,6 +143,8 @@ class Res2Layer(Sequential):
Default: dict(type='BN')
scales (int): Scales used in Res2Net. Default: 4
base_width (int): Basic width of each scale. Default: 26
drop_path_rate (float or np.ndarray): stochastic depth rate.
Default: 0.
"""

def __init__(self,
Expand All @@ -156,9 +158,16 @@ def __init__(self,
norm_cfg=dict(type='BN'),
scales=4,
base_width=26,
drop_path_rate=0.0,
**kwargs):
self.block = block

if isinstance(drop_path_rate, float):
drop_path_rate = [drop_path_rate] * num_blocks

assert len(drop_path_rate
) == num_blocks, 'Please check the length of drop_path_rate'

downsample = None
if stride != 1 or in_channels != out_channels:
if avg_down:
Expand Down Expand Up @@ -201,9 +210,10 @@ def __init__(self,
scales=scales,
base_width=base_width,
stage_type='stage',
drop_path_rate=drop_path_rate[0],
**kwargs))
in_channels = out_channels
for _ in range(1, num_blocks):
for i in range(1, num_blocks):
layers.append(
block(
in_channels=in_channels,
Expand All @@ -213,6 +223,7 @@ def __init__(self,
norm_cfg=norm_cfg,
scales=scales,
base_width=base_width,
drop_path_rate=drop_path_rate[i],
**kwargs))
super(Res2Layer, self).__init__(*layers)

Expand Down