From d5c09d3d0d4f612411b0c1d99511c6fd51ea2871 Mon Sep 17 00:00:00 2001 From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com> Date: Wed, 22 Feb 2023 11:04:28 +0800 Subject: [PATCH 01/18] [Enhance] Add stochastic depth decay rule in resnet. (#1363) * add stochastic depth decay rule to drop path rate * add default value * update * pass ut * update * pass ut * remove np --- mmpretrain/models/backbones/res2net.py | 13 ++++++++++++- mmpretrain/models/backbones/resnet.py | 26 ++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/mmpretrain/models/backbones/res2net.py b/mmpretrain/models/backbones/res2net.py index 332931043db..6e9bb6df37a 100644 --- a/mmpretrain/models/backbones/res2net.py +++ b/mmpretrain/models/backbones/res2net.py @@ -143,6 +143,8 @@ class Res2Layer(Sequential): Default: dict(type='BN') scales (int): Scales used in Res2Net. Default: 4 base_width (int): Basic width of each scale. Default: 26 + drop_path_rate (float or np.ndarray): stochastic depth rate. + Default: 0. """ def __init__(self, @@ -156,9 +158,16 @@ def __init__(self, norm_cfg=dict(type='BN'), scales=4, base_width=26, + drop_path_rate=0.0, **kwargs): self.block = block + if isinstance(drop_path_rate, float): + drop_path_rate = [drop_path_rate] * num_blocks + + assert len(drop_path_rate + ) == num_blocks, 'Please check the length of drop_path_rate' + downsample = None if stride != 1 or in_channels != out_channels: if avg_down: @@ -201,9 +210,10 @@ def __init__(self, scales=scales, base_width=base_width, stage_type='stage', + drop_path_rate=drop_path_rate[0], **kwargs)) in_channels = out_channels - for _ in range(1, num_blocks): + for i in range(1, num_blocks): layers.append( block( in_channels=in_channels, @@ -213,6 +223,7 @@ def __init__(self, norm_cfg=norm_cfg, scales=scales, base_width=base_width, + drop_path_rate=drop_path_rate[i], **kwargs)) super(Res2Layer, self).__init__(*layers) diff --git a/mmpretrain/models/backbones/resnet.py b/mmpretrain/models/backbones/resnet.py index 704d49559c9..4ef626a85ec 100644 --- a/mmpretrain/models/backbones/resnet.py +++ b/mmpretrain/models/backbones/resnet.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. - +import torch import torch.nn as nn import torch.utils.checkpoint as cp from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, @@ -334,6 +334,8 @@ class ResLayer(nn.Sequential): layer. Default: None norm_cfg (dict): dictionary to construct and config norm layer. Default: dict(type='BN') + drop_path_rate (float or list): stochastic depth rate. + Default: 0. 
""" def __init__(self, @@ -346,10 +348,17 @@ def __init__(self, avg_down=False, conv_cfg=None, norm_cfg=dict(type='BN'), + drop_path_rate=0.0, **kwargs): self.block = block self.expansion = get_expansion(block, expansion) + if isinstance(drop_path_rate, float): + drop_path_rate = [drop_path_rate] * num_blocks + + assert len(drop_path_rate + ) == num_blocks, 'Please check the length of drop_path_rate' + downsample = None if stride != 1 or in_channels != out_channels: downsample = [] @@ -384,6 +393,7 @@ def __init__(self, downsample=downsample, conv_cfg=conv_cfg, norm_cfg=norm_cfg, + drop_path_rate=drop_path_rate[0], **kwargs)) in_channels = out_channels for i in range(1, num_blocks): @@ -395,6 +405,7 @@ def __init__(self, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, + drop_path_rate=drop_path_rate[i], **kwargs)) super(ResLayer, self).__init__(*layers) @@ -518,6 +529,16 @@ def __init__(self, self.res_layers = [] _in_channels = stem_channels _out_channels = base_channels * self.expansion + + # stochastic depth decay rule + total_depth = sum(stage_blocks) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + # net_num_blocks = sum(stage_blocks) + # dpr = np.linspace(0, drop_path_rate, net_num_blocks) + # block_id = 0 + for i, num_blocks in enumerate(self.stage_blocks): stride = strides[i] dilation = dilations[i] @@ -534,9 +555,10 @@ def __init__(self, with_cp=with_cp, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - drop_path_rate=drop_path_rate) + drop_path_rate=dpr[:num_blocks]) _in_channels = _out_channels _out_channels *= 2 + dpr = dpr[num_blocks:] layer_name = f'layer{i + 1}' self.add_module(layer_name, res_layer) self.res_layers.append(layer_name) From aafa4b517ded4f50223df04582c6cd88c99f1ba7 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 18:46:24 +0800 Subject: [PATCH 02/18] rebase --- mmpretrain/datasets/transforms/processing.py | 92 ++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py index 62eb9e65f12..95dad126b5c 100644 --- a/mmpretrain/datasets/transforms/processing.py +++ b/mmpretrain/datasets/transforms/processing.py @@ -2,6 +2,9 @@ import inspect import math import numbers +import re +import traceback +from enum import EnumMeta from numbers import Number from typing import Dict, List, Optional, Sequence, Tuple, Union @@ -10,6 +13,7 @@ import numpy as np from mmcv.transforms import BaseTransform from mmcv.transforms.utils import cache_randomness +from torchvision.transforms.transforms import InterpolationMode from mmpretrain.registry import TRANSFORMS @@ -19,6 +23,94 @@ albumentations = None +def _str_to_torch_dtype(t: str): + import torch # noqa: F401,F403 + return eval(f'torch.{t}') + + +def _interpolation_modes_from_str(t: str) -> InterpolationMode: + t = t.lower() + inverse_modes_mapping = { + 'nearest': InterpolationMode.NEAREST, + 'bilinear': InterpolationMode.BILINEAR, + 'bicubic': InterpolationMode.BICUBIC, + 'box': InterpolationMode.BOX, + 'hammimg': InterpolationMode.HAMMING, + 'lanczos': InterpolationMode.LANCZOS, + } + return inverse_modes_mapping[t] + + +def _warpper_vision_transform_cls(vision_transform_cls, new_name): + """build a transform warpper class for specific torchvison.transform to + handle the different input type between torchvison.transforms with + mmcls.datasets.transforms.""" + + def new_init(self, *args, **kwargs): + if 'interpolation' in kwargs and 
isinstance(kwargs['interpolation'], + str): + kwargs['interpolation'] = _interpolation_modes_from_str( + kwargs['interpolation']) + if 'dtype' in kwargs and isinstance(kwargs['dtype'], str): + kwargs['dtype'] = _str_to_torch_dtype(kwargs['dtype']) + + try: + self.t = vision_transform_cls(*args, **kwargs) + except TypeError as e: + traceback.print_exc() + raise TypeError( + f'Error when init the {vision_transform_cls}, please ' + f'check the argmemnts of {args} and {kwargs}. \n{e}') + + def new_call(self, input): + try: + input['img'] = self.t(input['img']) + except Exception as e: + traceback.print_exc() + raise Exception('Error when processing of transform(`torhcvison/' + f'{vision_transform_cls.__name__}`). \n{e}') + return input + + def new_str(self): + return str(self.t) + + new_transforms_cls = type( + new_name, (), + dict(__init__=new_init, __call__=new_call, __str__=new_str)) + return new_transforms_cls + + +def register_vision_transforms() -> List[str]: + """Register transforms in ``torchvision.transforms`` to the ``TRANSFORMS`` + registry. + + Returns: + List[str]: A list of registered transforms' name. + """ + try: + import torchvision.transforms + except ImportError: + raise ImportError('please install ``torchvision``.') + + vision_transforms = [] + for module_name in dir(torchvision.transforms): + if not re.match('[A-Z]', module_name): + # must startswith a capital letter + continue + _transform = getattr(torchvision.transforms, module_name) + if inspect.isclass(_transform) and callable( + _transform) and not isinstance(_transform, (EnumMeta)): + new_cls = _warpper_vision_transform_cls( + _transform, f'TorchVison{module_name}') + TRANSFORMS.register_module( + module=new_cls, name=f'torchvision/{module_name}') + vision_transforms.append(f'torchvision/{module_name}') + return vision_transforms + + +VISION_transforms = register_vision_transforms() + + @TRANSFORMS.register_module() class RandomCrop(BaseTransform): """Crop the given Image at a random location. From 938b5db665da956f221223dd4f548aa7a8d3e536 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Wed, 14 Dec 2022 13:23:48 +0800 Subject: [PATCH 03/18] update ToPIL and ToNumpy --- mmpretrain/datasets/transforms/formatting.py | 40 ++++++++++++++------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py index e6a06800818..d054541c07b 100644 --- a/mmpretrain/datasets/transforms/formatting.py +++ b/mmpretrain/datasets/transforms/formatting.py @@ -2,6 +2,7 @@ from collections import defaultdict from collections.abc import Sequence +import cv2 import numpy as np import torch import torchvision.transforms.functional as F @@ -262,49 +263,64 @@ class ToPIL(BaseTransform): **Required Keys:** - - img + - ``*img**`` **Modified Keys:** - - img + - ``*img**`` + + Args: + to_rgb (bool): Whether to convert img to rgb. Defaults to False. """ + def __init__(self, to_rgb: bool = False): + self.to_rgb = to_rgb + def transform(self, results): """Method to convert images to :obj:`PIL.Image.Image`.""" - results['img'] = Image.fromarray(results['img']) + img = results['img'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img + + results['img'] = Image.fromarray(img) return results + def __repr__(self): + return self.__class__.__name__ + f'(to_rgb={self.to_rgb})' + @TRANSFORMS.register_module() class ToNumpy(BaseTransform): - """Convert object to :obj:`numpy.ndarray`. 
+ """Convert img to :obj:`numpy.ndarray`. **Required Keys:** - - ``*keys**`` + - ``*img**`` **Modified Keys:** - - ``*keys**`` + - ``*img**`` Args: + to_rgb (bool): Whether to convert img to rgb. Defaults to False. dtype (str, optional): The dtype of the converted numpy array. Defaults to None. """ - def __init__(self, keys, dtype=None): - self.keys = keys + def __init__(self, to_rgb: bool = False, dtype=None): + self.to_rgb = to_rgb self.dtype = dtype def transform(self, results): - """Method to convert object to :obj:`numpy.ndarray`.""" - for key in self.keys: - results[key] = np.array(results[key], dtype=self.dtype) + """Method to convert img to :obj:`numpy.ndarray`.""" + img = np.array(results['img'], dtype=self.dtype) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img + + results['img'] = img return results def __repr__(self): return self.__class__.__name__ + \ - f'(keys={self.keys}, dtype={self.dtype})' + f'(to_rgb={self.to_rgb}, dtype={self.dtype})' @TRANSFORMS.register_module() From bb873ec0aa2d359080b0f70f7e7417aafa1d74a1 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 18:49:50 +0800 Subject: [PATCH 04/18] rebase --- mmpretrain/datasets/transforms/formatting.py | 4 +- mmpretrain/datasets/transforms/processing.py | 40 +++++++++++-------- .../test_transforms/test_formatting.py | 36 ++++++++++++----- .../test_transforms/test_processing.py | 9 ++++- 4 files changed, 61 insertions(+), 28 deletions(-) diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py index d054541c07b..4504d5d96c5 100644 --- a/mmpretrain/datasets/transforms/formatting.py +++ b/mmpretrain/datasets/transforms/formatting.py @@ -257,7 +257,7 @@ def __repr__(self): f'(keys={self.keys}, order={self.order})' -@TRANSFORMS.register_module() +@TRANSFORMS.register_module(('ImgToPIL', 'ToPIL')) class ToPIL(BaseTransform): """Convert the image from OpenCV format to :obj:`PIL.Image.Image`. @@ -288,7 +288,7 @@ def __repr__(self): return self.__class__.__name__ + f'(to_rgb={self.to_rgb})' -@TRANSFORMS.register_module() +@TRANSFORMS.register_module(('ImgToNumpy', 'ToNumpy')) class ToNumpy(BaseTransform): """Convert img to :obj:`numpy.ndarray`. 
diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py index 95dad126b5c..35b871ba556 100644 --- a/mmpretrain/datasets/transforms/processing.py +++ b/mmpretrain/datasets/transforms/processing.py @@ -11,9 +11,10 @@ import mmcv import mmengine import numpy as np +import torchvision from mmcv.transforms import BaseTransform from mmcv.transforms.utils import cache_randomness -from torchvision.transforms.transforms import InterpolationMode +from mmengine.utils import digit_version from mmpretrain.registry import TRANSFORMS @@ -28,16 +29,28 @@ def _str_to_torch_dtype(t: str): return eval(f'torch.{t}') -def _interpolation_modes_from_str(t: str) -> InterpolationMode: +def _interpolation_modes_from_str(t: str): t = t.lower() - inverse_modes_mapping = { - 'nearest': InterpolationMode.NEAREST, - 'bilinear': InterpolationMode.BILINEAR, - 'bicubic': InterpolationMode.BICUBIC, - 'box': InterpolationMode.BOX, - 'hammimg': InterpolationMode.HAMMING, - 'lanczos': InterpolationMode.LANCZOS, - } + if digit_version(torchvision.__version__) >= digit_version('0.8.0'): + from torchvision.transforms.transforms import InterpolationMode + inverse_modes_mapping = { + 'nearest': InterpolationMode.NEAREST, + 'bilinear': InterpolationMode.BILINEAR, + 'bicubic': InterpolationMode.BICUBIC, + 'box': InterpolationMode.BOX, + 'hammimg': InterpolationMode.HAMMING, + 'lanczos': InterpolationMode.LANCZOS, + } + else: + from PIL import Image + inverse_modes_mapping = { + 'nearest': Image.NEAREST, + 'bilinear': Image.BILINEAR, + 'bicubic': Image.BICUBIC, + 'box': Image.BOX, + 'hammimg': Image.HAMMING, + 'lanczos': Image.LANCZOS, + } return inverse_modes_mapping[t] @@ -87,11 +100,6 @@ def register_vision_transforms() -> List[str]: Returns: List[str]: A list of registered transforms' name. 
""" - try: - import torchvision.transforms - except ImportError: - raise ImportError('please install ``torchvision``.') - vision_transforms = [] for module_name in dir(torchvision.transforms): if not re.match('[A-Z]', module_name): @@ -108,7 +116,7 @@ def register_vision_transforms() -> List[str]: return vision_transforms -VISION_transforms = register_vision_transforms() +VISION_TRANSFORMS = register_vision_transforms() @TRANSFORMS.register_module() diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py index 3fe255aff2d..96df57c1541 100644 --- a/tests/test_datasets/test_transforms/test_formatting.py +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -104,28 +104,46 @@ def test_transform(self): results = transform(copy.deepcopy(data)) self.assertIsInstance(results['img'], Image.Image) + cfg = dict(type='ToPIL', to_rgb=True) + transform = TRANSFORMS.build(cfg) + + data = {'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')} + + results = transform(copy.deepcopy(data)) + self.assertIsInstance(results['img'], Image.Image) + np.equal(np.array(results['img']), data['img'][:, :, ::-1]) + + def test_repr(self): + cfg = dict(type='ToPIL', to_rgb=True) + transform = TRANSFORMS.build(cfg) + self.assertEqual(repr(transform), 'ToPIL(to_rgb=True)') + class TestToNumpy(unittest.TestCase): def test_transform(self): img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg') data = { - 'tensor': torch.tensor([1, 2, 3]), - 'Image': Image.open(img_path), + 'img': Image.open(img_path), } - cfg = dict(type='ToNumpy', keys=['tensor', 'Image'], dtype='uint8') + cfg = dict(type='ToNumpy') + transform = TRANSFORMS.build(cfg) + results = transform(copy.deepcopy(data)) + self.assertIsInstance(results['img'], np.ndarray) + self.assertEqual(results['img'].dtype, 'uint8') + + cfg = dict(type='ToNumpy', to_rgb=True) transform = TRANSFORMS.build(cfg) results = transform(copy.deepcopy(data)) - self.assertIsInstance(results['tensor'], np.ndarray) - self.assertEqual(results['tensor'].dtype, 'uint8') - self.assertIsInstance(results['Image'], np.ndarray) - self.assertEqual(results['Image'].dtype, 'uint8') + self.assertIsInstance(results['img'], np.ndarray) + self.assertEqual(results['img'].dtype, 'uint8') + np.equal(results['img'], np.array(data['img'])[:, :, ::-1]) def test_repr(self): - cfg = dict(type='ToNumpy', keys=['img'], dtype='uint8') + cfg = dict(type='ToNumpy', to_rgb=True) transform = TRANSFORMS.build(cfg) - self.assertEqual(repr(transform), "ToNumpy(keys=['img'], dtype=uint8)") + self.assertEqual(repr(transform), 'ToNumpy(to_rgb=True, dtype=None)') class TestCollect(unittest.TestCase): diff --git a/tests/test_datasets/test_transforms/test_processing.py b/tests/test_datasets/test_transforms/test_processing.py index 5dddf6eb884..6b42d3f35e1 100644 --- a/tests/test_datasets/test_transforms/test_processing.py +++ b/tests/test_datasets/test_transforms/test_processing.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import copy import math +import os.path as osp import random from unittest import TestCase from unittest.mock import ANY, call, patch @@ -8,15 +9,21 @@ import mmengine import numpy as np import pytest +import torch +import torchvision +from mmcv.transforms import Compose +from mmengine.utils import digit_version +from PIL import Image +from torchvision import transforms from mmpretrain.registry import TRANSFORMS +from mmpretrain.datasets.transforms.processing import VISION_TRANSFORMS try: import albumentations except ImportError: albumentations = None - def construct_toy_data(): img = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uint8) From ff2dbb09d7fd052f714a37958c4544050d8146c9 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 18:51:34 +0800 Subject: [PATCH 05/18] rebase --- mmpretrain/datasets/transforms/processing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py index 35b871ba556..f6af6b30499 100644 --- a/mmpretrain/datasets/transforms/processing.py +++ b/mmpretrain/datasets/transforms/processing.py @@ -25,13 +25,15 @@ def _str_to_torch_dtype(t: str): + """mapping str format dtype to torch.dtype.""" import torch # noqa: F401,F403 return eval(f'torch.{t}') def _interpolation_modes_from_str(t: str): + """mapping str format to Interpolation.""" t = t.lower() - if digit_version(torchvision.__version__) >= digit_version('0.8.0'): + if digit_version(torchvision.__version__) > digit_version('0.8.0'): from torchvision.transforms.transforms import InterpolationMode inverse_modes_mapping = { 'nearest': InterpolationMode.NEAREST, From c431cbc0f44ac468d47840e573f9e7b73def0ee0 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 18:52:19 +0800 Subject: [PATCH 06/18] rebase --- mmcls/datasets/transforms/__init__.py | 21 ++++++++++++++ mmpretrain/datasets/transforms/formatting.py | 29 ++++++++++++------- .../test_transforms/test_formatting.py | 5 ++-- 3 files changed, 43 insertions(+), 12 deletions(-) create mode 100644 mmcls/datasets/transforms/__init__.py diff --git a/mmcls/datasets/transforms/__init__.py b/mmcls/datasets/transforms/__init__.py new file mode 100644 index 00000000000..6d1f552946b --- /dev/null +++ b/mmcls/datasets/transforms/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform, + Brightness, ColorTransform, Contrast, Cutout, + Equalize, Invert, Posterize, RandAugment, Rotate, + Sharpness, Shear, Solarize, SolarizeAdd, Translate) +from .formatting import (Collect, NumpyToPIL, PackClsInputs, PILToNumpy, + Transpose) +from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop, + EfficientNetRandomCrop, Lighting, RandomCrop, + RandomErasing, RandomResizedCrop, ResizeEdge) + +__all__ = [ + 'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop', + 'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert', + 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', + 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', + 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', + 'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop', + 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform', + 'PackMultiTaskInputs' +] diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py index 4504d5d96c5..410e0e3283c 100644 --- a/mmpretrain/datasets/transforms/formatting.py +++ b/mmpretrain/datasets/transforms/formatting.py @@ -1,6 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import defaultdict from collections.abc import Sequence +<<<<<<< HEAD:mmpretrain/datasets/transforms/formatting.py +======= +from functools import partial +from typing import Dict +>>>>>>> b4e24a1a... fix review:mmcls/datasets/transforms/formatting.py import cv2 import numpy as np @@ -102,6 +107,7 @@ def __init__(self, self.algorithm_keys = algorithm_keys self.meta_keys = meta_keys +<<<<<<< HEAD:mmpretrain/datasets/transforms/formatting.py @staticmethod def format_input(input_): if isinstance(input_, list): @@ -128,6 +134,9 @@ def format_input(input_): return input_ def transform(self, results: dict) -> dict: +======= + def transform(self, results: Dict) -> Dict: +>>>>>>> b4e24a1a... fix review:mmcls/datasets/transforms/formatting.py """Method to pack the input data.""" packed_results = dict() if self.input_key in results: @@ -257,8 +266,8 @@ def __repr__(self): f'(keys={self.keys}, order={self.order})' -@TRANSFORMS.register_module(('ImgToPIL', 'ToPIL')) -class ToPIL(BaseTransform): +@TRANSFORMS.register_module(('NumpyToPIL', 'ToPIL')) +class NumpyToPIL(BaseTransform): """Convert the image from OpenCV format to :obj:`PIL.Image.Image`. **Required Keys:** @@ -273,10 +282,10 @@ class ToPIL(BaseTransform): to_rgb (bool): Whether to convert img to rgb. Defaults to False. """ - def __init__(self, to_rgb: bool = False): + def __init__(self, to_rgb: bool = False) -> None: self.to_rgb = to_rgb - def transform(self, results): + def transform(self, results: Dict) -> Dict: """Method to convert images to :obj:`PIL.Image.Image`.""" img = results['img'] img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img @@ -284,12 +293,12 @@ def transform(self, results): results['img'] = Image.fromarray(img) return results - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ + f'(to_rgb={self.to_rgb})' -@TRANSFORMS.register_module(('ImgToNumpy', 'ToNumpy')) -class ToNumpy(BaseTransform): +@TRANSFORMS.register_module(('PILToNumpy', 'ToNumpy')) +class PILToNumpy(BaseTransform): """Convert img to :obj:`numpy.ndarray`. **Required Keys:** @@ -306,11 +315,11 @@ class ToNumpy(BaseTransform): Defaults to None. 
""" - def __init__(self, to_rgb: bool = False, dtype=None): + def __init__(self, to_rgb: bool = False, dtype=None) -> None: self.to_rgb = to_rgb self.dtype = dtype - def transform(self, results): + def transform(self, results: Dict) -> Dict: """Method to convert img to :obj:`numpy.ndarray`.""" img = np.array(results['img'], dtype=self.dtype) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img @@ -318,7 +327,7 @@ def transform(self, results): results['img'] = img return results - def __repr__(self): + def __repr__(self) -> str: return self.__class__.__name__ + \ f'(to_rgb={self.to_rgb}, dtype={self.dtype})' diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py index 96df57c1541..cca27fdbdc4 100644 --- a/tests/test_datasets/test_transforms/test_formatting.py +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -116,7 +116,7 @@ def test_transform(self): def test_repr(self): cfg = dict(type='ToPIL', to_rgb=True) transform = TRANSFORMS.build(cfg) - self.assertEqual(repr(transform), 'ToPIL(to_rgb=True)') + self.assertEqual(repr(transform), 'NumpyToPIL(to_rgb=True)') class TestToNumpy(unittest.TestCase): @@ -143,7 +143,8 @@ def test_transform(self): def test_repr(self): cfg = dict(type='ToNumpy', to_rgb=True) transform = TRANSFORMS.build(cfg) - self.assertEqual(repr(transform), 'ToNumpy(to_rgb=True, dtype=None)') + self.assertEqual( + repr(transform), 'PILToNumpy(to_rgb=True, dtype=None)') class TestCollect(unittest.TestCase): From 8b9cf976024621f5a67afa7012ac08443ee6424f Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 18:52:50 +0800 Subject: [PATCH 07/18] rebase --- mmcls/datasets/transforms/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmcls/datasets/transforms/__init__.py b/mmcls/datasets/transforms/__init__.py index 6d1f552946b..193995a9664 100644 --- a/mmcls/datasets/transforms/__init__.py +++ b/mmcls/datasets/transforms/__init__.py @@ -3,8 +3,8 @@ Brightness, ColorTransform, Contrast, Cutout, Equalize, Invert, Posterize, RandAugment, Rotate, Sharpness, Shear, Solarize, SolarizeAdd, Translate) -from .formatting import (Collect, NumpyToPIL, PackClsInputs, PILToNumpy, - Transpose) +from .formatting import (Collect, NumpyToPIL, PackClsInputs, + PackMultiTaskInputs, PILToNumpy, Transpose) from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop, EfficientNetRandomCrop, Lighting, RandomCrop, RandomErasing, RandomResizedCrop, ResizeEdge) From 809e63c12f81bf2587a935db5057ecb6238b7054 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 14 Mar 2023 13:08:44 +0800 Subject: [PATCH 08/18] add readme --- mmpretrain/datasets/transforms/processing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py index f6af6b30499..776bc5d472c 100644 --- a/mmpretrain/datasets/transforms/processing.py +++ b/mmpretrain/datasets/transforms/processing.py @@ -118,6 +118,7 @@ def register_vision_transforms() -> List[str]: return vision_transforms +# register all the transforms in torchvision by using a transform wrapper VISION_TRANSFORMS = register_vision_transforms() From 67f04364894dd8f5f2d4cfdbddb6c247d51ba887 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Fri, 17 Mar 2023 10:50:47 +0800 Subject: [PATCH 09/18] fix review suggestions --- 
 mmpretrain/datasets/transforms/formatting.py  | 16 ++++++++--------
 .../test_transforms/test_formatting.py        |  6 +++---
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py
index 410e0e3283c..e8f6450c366 100644
--- a/mmpretrain/datasets/transforms/formatting.py
+++ b/mmpretrain/datasets/transforms/formatting.py
@@ -279,13 +279,13 @@ class NumpyToPIL(BaseTransform):
     - ``*img**``

     Args:
-        to_rgb (bool): Whether to convert img to rgb. Defaults to False.
+        to_rgb (bool): Whether to convert the img to RGB. Defaults to False.
     """

     def __init__(self, to_rgb: bool = False) -> None:
         self.to_rgb = to_rgb

-    def transform(self, results: Dict) -> Dict:
+    def transform(self, results: dict) -> dict:
         """Method to convert images to :obj:`PIL.Image.Image`."""
         img = results['img']
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img
@@ -310,26 +310,26 @@ class PILToNumpy(BaseTransform):
     - ``*img**``

     Args:
-        to_rgb (bool): Whether to convert img to rgb. Defaults to False.
+        to_bgr (bool): Whether to convert the img to BGR. Defaults to False.
         dtype (str, optional): The dtype of the converted numpy array.
             Defaults to None.
     """

-    def __init__(self, to_rgb: bool = False, dtype=None) -> None:
-        self.to_rgb = to_rgb
+    def __init__(self, to_bgr: bool = False, dtype=None) -> None:
+        self.to_bgr = to_bgr
         self.dtype = dtype

-    def transform(self, results: Dict) -> Dict:
+    def transform(self, results: dict) -> dict:
         """Method to convert img to :obj:`numpy.ndarray`."""
         img = np.array(results['img'], dtype=self.dtype)
-        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.to_rgb else img
+        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if self.to_bgr else img

         results['img'] = img
         return results

     def __repr__(self) -> str:
         return self.__class__.__name__ + \
-            f'(to_rgb={self.to_rgb}, dtype={self.dtype})'
+            f'(to_bgr={self.to_bgr}, dtype={self.dtype})'


 @TRANSFORMS.register_module()
diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py
index cca27fdbdc4..e515c6d33e5 100644
--- a/tests/test_datasets/test_transforms/test_formatting.py
+++ b/tests/test_datasets/test_transforms/test_formatting.py
@@ -133,7 +133,7 @@ def test_transform(self):
         self.assertIsInstance(results['img'], np.ndarray)
         self.assertEqual(results['img'].dtype, 'uint8')

-        cfg = dict(type='ToNumpy', to_rgb=True)
+        cfg = dict(type='ToNumpy', to_bgr=True)
         transform = TRANSFORMS.build(cfg)
         results = transform(copy.deepcopy(data))
         self.assertIsInstance(results['img'], np.ndarray)
         self.assertEqual(results['img'].dtype, 'uint8')
         assert (results['img'] == np.array(data['img'])[:, :, ::-1]).all()

     def test_repr(self):
-        cfg = dict(type='ToNumpy', to_rgb=True)
+        cfg = dict(type='ToNumpy', to_bgr=True)
         transform = TRANSFORMS.build(cfg)
         self.assertEqual(
-            repr(transform), 'PILToNumpy(to_rgb=True, dtype=None)')
+            repr(transform), 'PILToNumpy(to_bgr=True, dtype=None)')


 class TestCollect(unittest.TestCase):

From 0b818d66a86f322466b78235e812b25775d6c34b Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Tue, 11 Apr 2023 18:55:39 +0800
Subject: [PATCH 10/18] rebase

---
 .../test_transforms/test_processing.py | 85 +++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/tests/test_datasets/test_transforms/test_processing.py b/tests/test_datasets/test_transforms/test_processing.py
index 6b42d3f35e1..f053d457d4e 100644
--- a/tests/test_datasets/test_transforms/test_processing.py
+++ b/tests/test_datasets/test_transforms/test_processing.py @@ -871,3 +871,88 @@ def test_repr(self): repr(transform), 'BEiTMaskGenerator(height=14, width=14, ' 'num_patches=196, num_masking_patches=75, min_num_patches=16, ' f'max_num_patches=75, log_aspect_ratio={log_aspect_ratio})') + +class TestVisionTransformWrapper(TestCase): + + def test_register(self): + for t in VISION_TRANSFORMS: + self.assertIn('torchvision/', t) + self.assertIn(t, TRANSFORMS) + + def test_transform(self): + img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg') + data = {'img': Image.open(img_path)} + + # test normal transform + vision_trans = transforms.RandomResizedCrop(224) + vision_transformed_img = vision_trans(data['img']) + mmcls_trans = TRANSFORMS.build( + dict(type='torchvision/RandomResizedCrop', size=224)) + mmcls_transformed_img = mmcls_trans(data)['img'] + np.equal( + np.array(vision_transformed_img), np.array(mmcls_transformed_img)) + + # test convert type dtype + data = {'img': torch.randn(3, 224, 224)} + vision_trans = transforms.ConvertImageDtype(torch.float) + vision_transformed_img = vision_trans(data['img']) + mmcls_trans = TRANSFORMS.build( + dict(type='torchvision/ConvertImageDtype', dtype='float')) + mmcls_transformed_img = mmcls_trans(data)['img'] + np.equal( + np.array(vision_transformed_img), np.array(mmcls_transformed_img)) + + # test transform with interpolation + data = {'img': Image.open(img_path)} + if digit_version(torchvision.__version__) > digit_version('0.8.0'): + from torchvision.transforms import InterpolationMode + interpolation_t = InterpolationMode.NEAREST + else: + interpolation_t = Image.NEAREST + vision_trans = transforms.Resize(224, interpolation_t) + vision_transformed_img = vision_trans(data['img']) + mmcls_trans = TRANSFORMS.build( + dict(type='torchvision/Resize', size=224, interpolation='nearest')) + mmcls_transformed_img = mmcls_trans(data)['img'] + np.equal( + np.array(vision_transformed_img), np.array(mmcls_transformed_img)) + + # test compose transforms + data = {'img': Image.open(img_path)} + vision_trans = transforms.Compose([ + transforms.Resize(176), + transforms.RandomHorizontalFlip(), + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + vision_transformed_img = vision_trans(data['img']) + + pipeline_cfg = [ + dict(type='LoadImageFromFile'), + dict(type='NumpyToPIL', to_rgb=True), + dict(type='torchvision/Resize', size=176), + dict(type='torchvision/RandomHorizontalFlip'), + dict(type='torchvision/PILToTensor'), + dict(type='torchvision/ConvertImageDtype', dtype='float'), + dict( + type='torchvision/Normalize', + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + ) + ] + pipeline = [TRANSFORMS.build(t) for t in pipeline_cfg] + mmcls_trans = Compose(transforms=pipeline) + mmcls_data = {'img_path': img_path} + mmcls_transformed_img = mmcls_trans(mmcls_data)['img'] + np.equal( + np.array(vision_transformed_img), np.array(mmcls_transformed_img)) + + def test_repr(self): + vision_trans = transforms.RandomResizedCrop(224) + mmcls_trans = TRANSFORMS.build( + dict(type='torchvision/RandomResizedCrop', size=224)) + + self.assertEqual(str(vision_trans), str(mmcls_trans)) + From 970c85e16d3fe1aeff62f54e5d22b27cd02c85ce Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 19:08:22 +0800 Subject: [PATCH 11/18] fix conflicts --- mmcls/datasets/transforms/__init__.py | 21 
-------------------- mmpretrain/datasets/transforms/__init__.py | 7 ++++--- mmpretrain/datasets/transforms/formatting.py | 9 --------- 3 files changed, 4 insertions(+), 33 deletions(-) delete mode 100644 mmcls/datasets/transforms/__init__.py diff --git a/mmcls/datasets/transforms/__init__.py b/mmcls/datasets/transforms/__init__.py deleted file mode 100644 index 193995a9664..00000000000 --- a/mmcls/datasets/transforms/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .auto_augment import (AutoAugment, AutoContrast, BaseAugTransform, - Brightness, ColorTransform, Contrast, Cutout, - Equalize, Invert, Posterize, RandAugment, Rotate, - Sharpness, Shear, Solarize, SolarizeAdd, Translate) -from .formatting import (Collect, NumpyToPIL, PackClsInputs, - PackMultiTaskInputs, PILToNumpy, Transpose) -from .processing import (Albumentations, ColorJitter, EfficientNetCenterCrop, - EfficientNetRandomCrop, Lighting, RandomCrop, - RandomErasing, RandomResizedCrop, ResizeEdge) - -__all__ = [ - 'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop', - 'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert', - 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', - 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', - 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', - 'PackClsInputs', 'Albumentations', 'EfficientNetRandomCrop', - 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform', - 'PackMultiTaskInputs' -] diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py index cc2df47e8ce..474a9699f91 100644 --- a/mmpretrain/datasets/transforms/__init__.py +++ b/mmpretrain/datasets/transforms/__init__.py @@ -8,8 +8,8 @@ Equalize, GaussianBlur, Invert, Posterize, RandAugment, Rotate, Sharpness, Shear, Solarize, SolarizeAdd, Translate) -from .formatting import (Collect, PackInputs, PackMultiTaskInputs, ToNumpy, - ToPIL, Transpose) +from .formatting import (Collect, PackInputs, PackMultiTaskInputs, PILToNumpy, + NumpyToPIL, Transpose) from .processing import (Albumentations, BEiTMaskGenerator, ColorJitter, EfficientNetCenterCrop, EfficientNetRandomCrop, Lighting, RandomCrop, RandomErasing, @@ -30,5 +30,6 @@ 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform', 'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator', 'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize', - 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView' + 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView', + 'PILToNumpy', 'NumpyToPIL' ] diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py index e8f6450c366..69da1299a4e 100644 --- a/mmpretrain/datasets/transforms/formatting.py +++ b/mmpretrain/datasets/transforms/formatting.py @@ -1,11 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import defaultdict from collections.abc import Sequence -<<<<<<< HEAD:mmpretrain/datasets/transforms/formatting.py -======= -from functools import partial -from typing import Dict ->>>>>>> b4e24a1a... 
fix review:mmcls/datasets/transforms/formatting.py import cv2 import numpy as np @@ -107,7 +102,6 @@ def __init__(self, self.algorithm_keys = algorithm_keys self.meta_keys = meta_keys -<<<<<<< HEAD:mmpretrain/datasets/transforms/formatting.py @staticmethod def format_input(input_): if isinstance(input_, list): @@ -134,9 +128,6 @@ def format_input(input_): return input_ def transform(self, results: dict) -> dict: -======= - def transform(self, results: Dict) -> Dict: ->>>>>>> b4e24a1a... fix review:mmcls/datasets/transforms/formatting.py """Method to pack the input data.""" packed_results = dict() if self.input_key in results: From d457a67e8d5033ba46625d49114d917223d44066 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 19:17:26 +0800 Subject: [PATCH 12/18] fix conflicts --- docs/en/api/data_process.rst | 4 ++-- mmpretrain/datasets/transforms/__init__.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/en/api/data_process.rst b/docs/en/api/data_process.rst index 8c0159099ca..376700c1070 100644 --- a/docs/en/api/data_process.rst +++ b/docs/en/api/data_process.rst @@ -61,8 +61,8 @@ Loading and Formatting LoadImageFromFile PackInputs PackMultiTaskInputs - ToNumpy - ToPIL + PILToNumpy + NumpyToPIL Transpose Collect diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py index 474a9699f91..e173b64ba8d 100644 --- a/mmpretrain/datasets/transforms/__init__.py +++ b/mmpretrain/datasets/transforms/__init__.py @@ -21,7 +21,7 @@ TRANSFORMS.register_module(module=t) __all__ = [ - 'ToPIL', 'ToNumpy', 'Transpose', 'Collect', 'RandomCrop', + 'NumpyToPIL', 'PILToNumpy', 'Transpose', 'Collect', 'RandomCrop', 'RandomResizedCrop', 'Shear', 'Translate', 'Rotate', 'Invert', 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', @@ -30,6 +30,5 @@ 'EfficientNetCenterCrop', 'ResizeEdge', 'BaseAugTransform', 'PackMultiTaskInputs', 'GaussianBlur', 'BEiTMaskGenerator', 'SimMIMMaskGenerator', 'CenterCrop', 'LoadImageFromFile', 'Normalize', - 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView', - 'PILToNumpy', 'NumpyToPIL' + 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView' ] From 4a510c25d5725cf817f4d23ce6fa5316d4694c52 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 19:28:42 +0800 Subject: [PATCH 13/18] fix lint --- mmpretrain/datasets/transforms/__init__.py | 4 ++-- tests/test_datasets/test_transforms/test_processing.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py index e173b64ba8d..583303cc206 100644 --- a/mmpretrain/datasets/transforms/__init__.py +++ b/mmpretrain/datasets/transforms/__init__.py @@ -8,8 +8,8 @@ Equalize, GaussianBlur, Invert, Posterize, RandAugment, Rotate, Sharpness, Shear, Solarize, SolarizeAdd, Translate) -from .formatting import (Collect, PackInputs, PackMultiTaskInputs, PILToNumpy, - NumpyToPIL, Transpose) +from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs, + PILToNumpy, Transpose) from .processing import (Albumentations, BEiTMaskGenerator, ColorJitter, EfficientNetCenterCrop, EfficientNetRandomCrop, Lighting, RandomCrop, RandomErasing, diff --git a/tests/test_datasets/test_transforms/test_processing.py 
b/tests/test_datasets/test_transforms/test_processing.py index f053d457d4e..164a9213ab1 100644 --- a/tests/test_datasets/test_transforms/test_processing.py +++ b/tests/test_datasets/test_transforms/test_processing.py @@ -16,14 +16,15 @@ from PIL import Image from torchvision import transforms -from mmpretrain.registry import TRANSFORMS from mmpretrain.datasets.transforms.processing import VISION_TRANSFORMS +from mmpretrain.registry import TRANSFORMS try: import albumentations except ImportError: albumentations = None + def construct_toy_data(): img = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uint8) @@ -872,6 +873,7 @@ def test_repr(self): 'num_patches=196, num_masking_patches=75, min_num_patches=16, ' f'max_num_patches=75, log_aspect_ratio={log_aspect_ratio})') + class TestVisionTransformWrapper(TestCase): def test_register(self): @@ -955,4 +957,3 @@ def test_repr(self): dict(type='torchvision/RandomResizedCrop', size=224)) self.assertEqual(str(vision_trans), str(mmcls_trans)) - From fd45d501c652d4a845b78740411280c9f0da47a6 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 19:34:12 +0800 Subject: [PATCH 14/18] remove comments --- mmpretrain/models/backbones/resnet.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mmpretrain/models/backbones/resnet.py b/mmpretrain/models/backbones/resnet.py index 4ef626a85ec..e4df601db56 100644 --- a/mmpretrain/models/backbones/resnet.py +++ b/mmpretrain/models/backbones/resnet.py @@ -535,9 +535,6 @@ def __init__(self, dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, total_depth) ] - # net_num_blocks = sum(stage_blocks) - # dpr = np.linspace(0, drop_path_rate, net_num_blocks) - # block_id = 0 for i, num_blocks in enumerate(self.stage_blocks): stride = strides[i] From 9ba80680dcd8610d4b7b8b00cded8a35a1b67847 Mon Sep 17 00:00:00 2001 From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com> Date: Tue, 11 Apr 2023 19:49:44 +0800 Subject: [PATCH 15/18] remove useless code --- mmpretrain/datasets/transforms/processing.py | 30 ++++++-------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py index 776bc5d472c..9e41ed1db02 100644 --- a/mmpretrain/datasets/transforms/processing.py +++ b/mmpretrain/datasets/transforms/processing.py @@ -14,7 +14,7 @@ import torchvision from mmcv.transforms import BaseTransform from mmcv.transforms.utils import cache_randomness -from mmengine.utils import digit_version +from torchvision.transforms.transforms import InterpolationMode from mmpretrain.registry import TRANSFORMS @@ -33,26 +33,14 @@ def _str_to_torch_dtype(t: str): def _interpolation_modes_from_str(t: str): """mapping str format to Interpolation.""" t = t.lower() - if digit_version(torchvision.__version__) > digit_version('0.8.0'): - from torchvision.transforms.transforms import InterpolationMode - inverse_modes_mapping = { - 'nearest': InterpolationMode.NEAREST, - 'bilinear': InterpolationMode.BILINEAR, - 'bicubic': InterpolationMode.BICUBIC, - 'box': InterpolationMode.BOX, - 'hammimg': InterpolationMode.HAMMING, - 'lanczos': InterpolationMode.LANCZOS, - } - else: - from PIL import Image - inverse_modes_mapping = { - 'nearest': Image.NEAREST, - 'bilinear': Image.BILINEAR, - 'bicubic': Image.BICUBIC, - 'box': Image.BOX, - 'hammimg': Image.HAMMING, - 'lanczos': Image.LANCZOS, - } + inverse_modes_mapping = { + 'nearest': InterpolationMode.NEAREST, + 'bilinear': 
InterpolationMode.BILINEAR,
+        'bicubic': InterpolationMode.BICUBIC,
+        'box': InterpolationMode.BOX,
+        'hamming': InterpolationMode.HAMMING,
+        'lanczos': InterpolationMode.LANCZOS,
+    }
     return inverse_modes_mapping[t]

From 1ae38896d896a7fc16bedb8a037f1830db789801 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Wed, 12 Apr 2023 17:01:25 +0800
Subject: [PATCH 16/18] update docstring

---
 mmpretrain/datasets/transforms/formatting.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mmpretrain/datasets/transforms/formatting.py b/mmpretrain/datasets/transforms/formatting.py
index 69da1299a4e..30480b7d99f 100644
--- a/mmpretrain/datasets/transforms/formatting.py
+++ b/mmpretrain/datasets/transforms/formatting.py
@@ -263,11 +263,11 @@ class NumpyToPIL(BaseTransform):

     **Required Keys:**

-    - ``*img**``
+    - ``img``

     **Modified Keys:**

-    - ``*img**``
+    - ``img``

     Args:
         to_rgb (bool): Whether to convert the img to RGB. Defaults to False.
@@ -294,11 +294,11 @@ class PILToNumpy(BaseTransform):

     **Required Keys:**

-    - ``*img**``
+    - ``img``

     **Modified Keys:**

-    - ``*img**``
+    - ``img``

     Args:
         to_bgr (bool): Whether to convert the img to BGR. Defaults to False.
         dtype (str, optional): The dtype of the converted numpy array.
             Defaults to None.

From 029a5db9c0b2cdd0ddd05728936639e3fe3673e4 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Wed, 12 Apr 2023 17:28:23 +0800
Subject: [PATCH 17/18] update doc API

---
 docs/en/api/data_process.rst | 82 ++++++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/docs/en/api/data_process.rst b/docs/en/api/data_process.rst
index 376700c1070..b757b6100f6 100644
--- a/docs/en/api/data_process.rst
+++ b/docs/en/api/data_process.rst
@@ -147,6 +147,88 @@ Transform Wrapper

 .. module:: mmpretrain.models.utils.data_preprocessor

+
+TorchVision Transforms
+^^^^^^^^^^^^^^^^^^^^^^
+
+We also provides all the transforms in TorchVision. You can use them like following examples:
+
+**1. Use some TorchVision Augs Surrounded by NumpyToPIL and PILToNumpy (Recommended)**
+
+Add TorchVision Augs surrounded by ``dict(type='NumpyToPIL', to_rgb=True),`` and ``dict(type='PILToNumpy', to_bgr=True),``
+
+.. code:: python
+
+    train_pipeline = [
+        dict(type='LoadImageFromFile'),
+        dict(type='NumpyToPIL', to_rgb=True),   # from BGR in cv2 to RGB in PIL
+        dict(type='torchvision/RandomResizedCrop', size=176),
+        dict(type='PILToNumpy', to_bgr=True),   # from RGB in PIL to BGR in cv2
+        dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+        dict(type='PackInputs'),
+    ]
+
+    data_preprocessor = dict(
+        num_classes=1000,
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        to_rgb=True,   # from BGR in cv2 to RGB in PIL
+    )
+
+
+**2. Use TorchVision Augs and ToTensor&Normalize**
+
+Make sure have converted to RGB-Numpy format before processing by TorchVision Augs.
+
+.. code:: python
+
+    train_pipeline = [
+        dict(type='LoadImageFromFile'),
+        dict(type='NumpyToPIL', to_rgb=True),   # from BGR in cv2 to RGB in PIL
+        dict(
+            type='torchvision/RandomResizedCrop',
+            size=176,
+            interpolation='bilinear'),   # accept str format interpolation mode
+        dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+        dict(
+            type='torchvision/TrivialAugmentWide',
+            interpolation='bilinear'),
+        dict(type='torchvision/PILToTensor'),
+        dict(type='torchvision/ConvertImageDtype', dtype=torch.float),
+        dict(
+            type='torchvision/Normalize',
+            mean=(0.485, 0.456, 0.406),
+            std=(0.229, 0.224, 0.225),
+        ),
+        dict(type='torchvision/RandomErasing', p=0.1),
+        dict(type='PackInputs'),
+    ]
+
+    data_preprocessor = dict(num_classes=1000, mean=None, std=None, to_rgb=False)  # Normalize in dataset pipeline
+
+
+**3. USe TorchVision Augs Except ToTensor&Normalize**
+
+.. code:: python
+
+    train_pipeline = [
+        dict(type='LoadImageFromFile'),
+        dict(type='NumpyToPIL', to_rgb=True),   # from BGR in cv2 to RGB in PIL
+        dict(type='torchvision/RandomResizedCrop', size=176, interpolation='bilinear'),
+        dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+        dict(type='torchvision/TrivialAugmentWide', interpolation='bilinear'),
+        dict(type='PackInputs'),
+    ]
+
+    # here the Normalize params are for the RGB format
+    data_preprocessor = dict(
+        num_classes=1000,
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        to_rgb=False,
+    )
+
+
 Data Preprocessors
 ------------------

From e9f53be7c2f2d096393c522826e51e24dfead88f Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Wed, 12 Apr 2023 18:07:01 +0800
Subject: [PATCH 18/18] update doc

---
 docs/en/api/data_process.rst | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/en/api/data_process.rst b/docs/en/api/data_process.rst
index b757b6100f6..af0f6e54ec2 100644
--- a/docs/en/api/data_process.rst
+++ b/docs/en/api/data_process.rst
@@ -151,7 +151,7 @@ Transform Wrapper
 TorchVision Transforms
 ^^^^^^^^^^^^^^^^^^^^^^

-We also provides all the transforms in TorchVision. You can use them like following examples:
+We also provide all the transforms in TorchVision. You can use them like the following examples:

 **1. Use some TorchVision Augs Surrounded by NumpyToPIL and PILToNumpy (Recommended)**

@@ -178,7 +178,7 @@ Add TorchVision Augs surrounded by ``dict(type='NumpyToPIL', to_rgb=True),`` and

 **2. Use TorchVision Augs and ToTensor&Normalize**

-Make sure have converted to RGB-Numpy format before processing by TorchVision Augs.
+Make sure the 'img' has been converted to PIL format from BGR-Numpy format before being processed by TorchVision Augs.

 .. code:: python

@@ -207,7 +207,7 @@

-**3. USe TorchVision Augs Except ToTensor&Normalize**
+**3. Use TorchVision Augs Except ToTensor&Normalize**

 .. code:: python
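End state of the series: every class in ``torchvision.transforms`` is reachable through the ``TRANSFORMS`` registry under a ``torchvision/`` prefix, with string ``interpolation``/``dtype`` arguments converted automatically by the generated wrapper. A minimal sketch of direct registry use, assuming an mmpretrain checkout with these 18 patches applied and torchvision installed (the image content is dummy data):

.. code:: python

    import numpy as np
    from PIL import Image

    from mmpretrain.registry import TRANSFORMS

    # Build a wrapped torchvision transform by its registered name; the
    # wrapper turns interpolation='bilinear' into the InterpolationMode
    # enum before instantiating torchvision's Resize.
    resize = TRANSFORMS.build(
        dict(type='torchvision/Resize', size=176, interpolation='bilinear'))

    # Wrapped transforms take and return the mmpretrain results dict,
    # transforming only the 'img' entry (a PIL image here).
    results = {'img': Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))}
    results = resize(results)
    print(results['img'].size)  # (176, 176): the short edge is resized to 176

Keeping the string-to-enum conversion inside the generated ``__init__`` is what lets these transforms be declared in plain config dicts, mirroring the pipeline examples in the docs above.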