diff --git a/docs/zh_cn/advanced_guides/datasets.md b/docs/zh_cn/advanced_guides/datasets.md index 06a75e54bd..29062e73f3 100644 --- a/docs/zh_cn/advanced_guides/datasets.md +++ b/docs/zh_cn/advanced_guides/datasets.md @@ -1,4 +1,4 @@ -# 数据集 +# 数据集 在 MMSegmentation 算法库中, 所有 Dataset 类的功能有两个: 加载[预处理](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/user_guides/2_dataset_prepare.md) 之后的数据集的信息, 和将数据送入[数据集变换流水线](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/datasets/basesegdataset.py#L141) 中, 进行[数据变换操作](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/zh_cn/advanced_guides/transforms.md). 加载的数据集信息包括两类: 元信息 (meta information), 数据集本身的信息, 例如数据集总共的类别, 和它们对应调色盘信息: 数据信息 (data information) 是指每组数据中图片和对应标签的路径. 下文中介绍了 MMSegmentation 1.x 中数据集的常用接口, 和 mmseg 数据集基类中数据信息加载与修改数据集类别的逻辑, 以及数据集与数据变换流水线 (pipeline) 的关系. diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index 8ae2574afe..64bddf8081 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -18,9 +18,10 @@ from .pascal_context import PascalContextDataset, PascalContextDataset59 from .potsdam import PotsdamDataset from .stare import STAREDataset +# yapf: disable from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, BioMedicalGaussianBlur, BioMedicalGaussianNoise, - GenerateEdge, LoadAnnotations, + BioMedicalRandomGamma, GenerateEdge, LoadAnnotations, LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalImageFromFile, LoadImageFromNDArray, PackSegInputs, PhotoMetricDistortion, RandomCrop, @@ -30,7 +31,6 @@ from .voc import PascalVOCDataset # yapf: enable - __all__ = [ 'BaseSegDataset', 'BioMedical3DRandomCrop', 'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', @@ -44,5 +44,6 @@ 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge', - 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur' + 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedicalRandomGamma' ] diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py index 4ea3d81c98..e584377368 100644 --- a/mmseg/datasets/transforms/__init__.py +++ b/mmseg/datasets/transforms/__init__.py @@ -6,8 +6,9 @@ # yapf: disable from .transforms import (CLAHE, AdjustGamma, BioMedical3DRandomCrop, BioMedicalGaussianBlur, BioMedicalGaussianNoise, - GenerateEdge, PhotoMetricDistortion, RandomCrop, - RandomCutOut, RandomMosaic, RandomRotate, Rerange, + BioMedicalRandomGamma, GenerateEdge, + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomMosaic, RandomRotate, Rerange, ResizeShortestEdge, ResizeToMultiple, RGB2Gray, SegRescale) @@ -18,5 +19,6 @@ 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', - 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur' + 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedicalRandomGamma' ] diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 94e2643473..5d25e12641 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -1686,3 +1686,122 @@ def __repr__(self): repr_str += 'different_sigma_per_axis='\ f'{self.different_sigma_per_axis})' return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalRandomGamma(BaseTransform): + """Using random gamma correction to process the biomedical image. + + Modified from + https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501 + With licence: Apache 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + - img + + Args: + prob (float): The probability to perform this transform. Default: 0.5. + gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2). + invert_image (bool): Whether invert the image before applying gamma + augmentation. Default: False. + per_channel (bool): Whether perform the transform each channel + individually. Default: False + retain_stats (bool): Gamma transformation will alter the mean and std + of the data in the patch. If retain_stats=True, the data will be + transformed to match the mean and standard deviation before gamma + augmentation. Default: False. + """ + + def __init__(self, + prob: float = 0.5, + gamma_range: Tuple[float] = (0.5, 2), + invert_image: bool = False, + per_channel: bool = False, + retain_stats: bool = False): + assert 0 <= prob and prob <= 1 + assert isinstance(gamma_range, tuple) and len(gamma_range) == 2 + assert isinstance(invert_image, bool) + assert isinstance(per_channel, bool) + assert isinstance(retain_stats, bool) + self.prob = prob + self.gamma_range = gamma_range + self.invert_image = invert_image + self.per_channel = per_channel + self.retain_stats = retain_stats + + @cache_randomness + def _do_gamma(self): + """Whether do adjust gamma for image.""" + return np.random.rand() < self.prob + + def _adjust_gamma(self, img: np.array): + """Gamma adjustment for image. + + Args: + img (np.array): Input image before gamma adjust. + + Returns: + np.arrays: Image after gamma adjust. + """ + + if self.invert_image: + img = -img + + def _do_adjust(img): + if retain_stats_here: + img_mean = img.mean() + img_std = img.std() + if np.random.random() < 0.5 and self.gamma_range[0] < 1: + gamma = np.random.uniform(self.gamma_range[0], 1) + else: + gamma = np.random.uniform( + max(self.gamma_range[0], 1), self.gamma_range[1]) + img_min = img.min() + img_range = img.max() - img_min # range + img = np.power(((img - img_min) / float(img_range + 1e-7)), + gamma) * img_range + img_min + if retain_stats_here: + img = img - img.mean() + img = img / (img.std() + 1e-8) * img_std + img = img + img_mean + return img + + if not self.per_channel: + retain_stats_here = self.retain_stats + img = _do_adjust(img) + else: + for c in range(img.shape[0]): + img[c] = _do_adjust(img[c]) + if self.invert_image: + img = -img + return img + + def transform(self, results: dict) -> dict: + """Call function to perform random gamma correction + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with random gamma correction performed. + """ + do_gamma = self._do_gamma() + + if do_gamma: + results['img'] = self._adjust_gamma(results['img']) + else: + pass + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'gamma_range={self.gamma_range},' + repr_str += f'invert_image={self.invert_image},' + repr_str += f'per_channel={self.per_channel},' + repr_str += f'retain_stats={self.retain_stats}' + return repr_str diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py index 397d1c0da6..c218ebd666 100644 --- a/tests/test_datasets/test_transform.py +++ b/tests/test_datasets/test_transform.py @@ -8,7 +8,8 @@ from PIL import Image from mmseg.datasets.transforms import * # noqa -from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop +from mmseg.datasets.transforms import (LoadBiomedicalImageFromFile, + PhotoMetricDistortion, RandomCrop) from mmseg.registry import TRANSFORMS from mmseg.utils import register_all_modules @@ -886,3 +887,67 @@ def test_biomedical_gaussian_blur(): # the max value in the smoothed image should be less than the original one assert original_img.max() >= results['img'].max() assert original_img.min() <= results['img'].min() + + +def test_BioMedicalRandomGamma(): + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', prob=-1, gamma_range=(0.7, 1.2)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', prob=1.2, gamma_range=(0.7, 1.2)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', prob=1.0, gamma_range=(0.7)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 0.2, 0.3)) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 2), + invert_image=1) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 2), + per_channel=1) + TRANSFORMS.build(transform) + + with pytest.raises(AssertionError): + transform = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 2), + retain_stats=1) + TRANSFORMS.build(transform) + + test_img = 'tests/data/biomedical.nii.gz' + results = dict(img_path=test_img) + transform = LoadBiomedicalImageFromFile() + results = transform(copy.deepcopy(results)) + origin_img = results['img'] + transform2 = dict( + type='BioMedicalRandomGamma', + prob=1.0, + gamma_range=(0.7, 2), + ) + transform2 = TRANSFORMS.build(transform2) + results = transform2(results) + transformed_img = results['img'] + assert origin_img.shape == transformed_img.shape