diff --git a/docs/en/api/transforms.rst b/docs/en/api/transforms.rst index a7714cc42..dbc745a33 100644 --- a/docs/en/api/transforms.rst +++ b/docs/en/api/transforms.rst @@ -34,7 +34,6 @@ TextDet Transforms :template: classtemplate.rst BoundedScaleAspectJitter - FixInvalidPolygon RandomFlip SourceImagePad ShortScaleAspectJitter @@ -50,6 +49,10 @@ TextRecog Transforms :nosignatures: :template: classtemplate.rst + TextRecogGeneralAug + CropHeight + ImageContentJitter + ReversePixels PyramidRescale PadToWidth RescaleToHeight @@ -66,6 +69,8 @@ OCR Transforms RandomCrop RandomRotate Resize + FixInvalidPolygon + RemoveIgnored diff --git a/docs/zh_cn/api/transforms.rst b/docs/zh_cn/api/transforms.rst index a7714cc42..dbc745a33 100644 --- a/docs/zh_cn/api/transforms.rst +++ b/docs/zh_cn/api/transforms.rst @@ -34,7 +34,6 @@ TextDet Transforms :template: classtemplate.rst BoundedScaleAspectJitter - FixInvalidPolygon RandomFlip SourceImagePad ShortScaleAspectJitter @@ -50,6 +49,10 @@ TextRecog Transforms :nosignatures: :template: classtemplate.rst + TextRecogGeneralAug + CropHeight + ImageContentJitter + ReversePixels PyramidRescale PadToWidth RescaleToHeight @@ -66,6 +69,8 @@ OCR Transforms RandomCrop RandomRotate Resize + FixInvalidPolygon + RemoveIgnored diff --git a/mmocr/datasets/transforms/__init__.py b/mmocr/datasets/transforms/__init__.py index a1e51950a..696b2ab53 100644 --- a/mmocr/datasets/transforms/__init__.py +++ b/mmocr/datasets/transforms/__init__.py @@ -9,7 +9,9 @@ from .textdet_transforms import (BoundedScaleAspectJitter, RandomFlip, ShortScaleAspectJitter, SourceImagePad, TextDetRandomCrop, TextDetRandomCropFlip) -from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight +from .textrecog_transforms import (CropHeight, ImageContentJitter, PadToWidth, + PyramidRescale, RescaleToHeight, + ReversePixels, TextRecogGeneralAug) from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper __all__ = [ @@ -20,5 +22,6 @@ 'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter', 'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR', 'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile', - 'LoadImageFromNDArray', 'RemoveIgnored', 'ConditionApply' + 'LoadImageFromNDArray', 'CropHeight', 'TextRecogGeneralAug', + 'ImageContentJitter', 'ReversePixels', 'RemoveIgnored', 'ConditionApply' ] diff --git a/mmocr/datasets/transforms/textrecog_transforms.py b/mmocr/datasets/transforms/textrecog_transforms.py index ce42056ed..abb094c31 100644 --- a/mmocr/datasets/transforms/textrecog_transforms.py +++ b/mmocr/datasets/transforms/textrecog_transforms.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import math -from typing import Dict, Optional, Tuple +import random +from typing import Dict, List, Optional, Tuple import cv2 import mmcv @@ -251,3 +252,473 @@ def __repr__(self) -> str: repr_str += f'(width={self.width}, ' repr_str += f'pad_cfg={self.pad_cfg})' return repr_str + + +@TRANSFORMS.register_module() +class TextRecogGeneralAug(BaseTransform): + """A general geometric augmentation tool for text images in the CVPR 2020 + paper "Learn to Augment: Joint Data Augmentation and Network Optimization + for Text Recognition". It applies distortion, stretching, and perspective + transforms to an image. + + This implementation is adapted from + https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py # noqa + + TODO: Split this transform into three transforms. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + """ # noqa + + def transform(self, results: Dict) -> Dict: + """Call function to pad images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + h, w = results['img'].shape[:2] + if h >= 20 and w >= 20: + results['img'] = self.tia_distort(results['img'], + random.randint(3, 6)) + results['img'] = self.tia_stretch(results['img'], + random.randint(3, 6)) + h, w = results['img'].shape[:2] + if h >= 5 and w >= 5: + results['img'] = self.tia_perspective(results['img']) + results['img_shape'] = results['img'].shape[:2] + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += '()' + return repr_str + + def tia_distort(self, img: np.ndarray, segment: int = 4) -> np.ndarray: + """Image distortion. + + Args: + img (np.ndarray): The image. + segment (int): The number of segments to divide the image along + the width. Defaults to 4. + """ + img_h, img_w = img.shape[:2] + + cut = img_w // segment + thresh = cut // 3 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)]) + dst_pts.append( + [img_w - np.random.randint(thresh), + np.random.randint(thresh)]) + dst_pts.append([ + img_w - np.random.randint(thresh), + img_h - np.random.randint(thresh) + ]) + dst_pts.append( + [np.random.randint(thresh), img_h - np.random.randint(thresh)]) + + half_thresh = thresh * 0.5 + + for cut_idx in np.arange(1, segment, 1): + src_pts.append([cut * cut_idx, 0]) + src_pts.append([cut * cut_idx, img_h]) + dst_pts.append([ + cut * cut_idx + np.random.randint(thresh) - half_thresh, + np.random.randint(thresh) - half_thresh + ]) + dst_pts.append([ + cut * cut_idx + np.random.randint(thresh) - half_thresh, + img_h + np.random.randint(thresh) - half_thresh + ]) + + dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h) + + return dst + + def tia_stretch(self, img: np.ndarray, segment: int = 4) -> np.ndarray: + """Image stretching. + + Args: + img (np.ndarray): The image. + segment (int): The number of segments to divide the image along + the width. Defaults to 4. + """ + img_h, img_w = img.shape[:2] + + cut = img_w // segment + thresh = cut * 4 // 5 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([0, 0]) + dst_pts.append([img_w, 0]) + dst_pts.append([img_w, img_h]) + dst_pts.append([0, img_h]) + + half_thresh = thresh * 0.5 + + for cut_idx in np.arange(1, segment, 1): + move = np.random.randint(thresh) - half_thresh + src_pts.append([cut * cut_idx, 0]) + src_pts.append([cut * cut_idx, img_h]) + dst_pts.append([cut * cut_idx + move, 0]) + dst_pts.append([cut * cut_idx + move, img_h]) + + dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h) + + return dst + + def tia_perspective(self, img: np.ndarray) -> np.ndarray: + """Image perspective transformation. + + Args: + img (np.ndarray): The image. + segment (int): The number of segments to divide the image along + the width. Defaults to 4. + """ + img_h, img_w = img.shape[:2] + + thresh = img_h // 2 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([0, np.random.randint(thresh)]) + dst_pts.append([img_w, np.random.randint(thresh)]) + dst_pts.append([img_w, img_h - np.random.randint(thresh)]) + dst_pts.append([0, img_h - np.random.randint(thresh)]) + + dst = self.warp_mls(img, src_pts, dst_pts, img_w, img_h) + + return dst + + def warp_mls(self, + src: np.ndarray, + src_pts: List[int], + dst_pts: List[int], + dst_w: int, + dst_h: int, + trans_ratio: float = 1.) -> np.ndarray: + """Warp the image.""" + rdx, rdy = self._calc_delta(dst_w, dst_h, src_pts, dst_pts, 100) + return self._gen_img(src, rdx, rdy, dst_w, dst_h, 100, trans_ratio) + + def _calc_delta(self, dst_w: int, dst_h: int, src_pts: List[int], + dst_pts: List[int], + grid_size: int) -> Tuple[np.ndarray, np.ndarray]: + """Compute delta.""" + + pt_count = len(dst_pts) + rdx = np.zeros((dst_h, dst_w)) + rdy = np.zeros((dst_h, dst_w)) + w = np.zeros(pt_count, dtype=np.float32) + + if pt_count < 2: + return + + i = 0 + while True: + if dst_w <= i < dst_w + grid_size - 1: + i = dst_w - 1 + elif i >= dst_w: + break + + j = 0 + while True: + if dst_h <= j < dst_h + grid_size - 1: + j = dst_h - 1 + elif j >= dst_h: + break + + sw = 0 + swp = np.zeros(2, dtype=np.float32) + swq = np.zeros(2, dtype=np.float32) + new_pt = np.zeros(2, dtype=np.float32) + cur_pt = np.array([i, j], dtype=np.float32) + + k = 0 + for k in range(pt_count): + if i == dst_pts[k][0] and j == dst_pts[k][1]: + break + + w[k] = 1. / ((i - dst_pts[k][0]) * (i - dst_pts[k][0]) + + (j - dst_pts[k][1]) * (j - dst_pts[k][1])) + + sw += w[k] + swp = swp + w[k] * np.array(dst_pts[k]) + swq = swq + w[k] * np.array(src_pts[k]) + + if k == pt_count - 1: + pstar = 1 / sw * swp + qstar = 1 / sw * swq + + miu_s = 0 + for k in range(pt_count): + if i == dst_pts[k][0] and j == dst_pts[k][1]: + continue + pt_i = dst_pts[k] - pstar + miu_s += w[k] * np.sum(pt_i * pt_i) + + cur_pt -= pstar + cur_pt_j = np.array([-cur_pt[1], cur_pt[0]]) + + for k in range(pt_count): + if i == dst_pts[k][0] and j == dst_pts[k][1]: + continue + + pt_i = dst_pts[k] - pstar + pt_j = np.array([-pt_i[1], pt_i[0]]) + + tmp_pt = np.zeros(2, dtype=np.float32) + tmp_pt[0] = ( + np.sum(pt_i * cur_pt) * src_pts[k][0] - + np.sum(pt_j * cur_pt) * src_pts[k][1]) + tmp_pt[1] = (-np.sum(pt_i * cur_pt_j) * src_pts[k][0] + + np.sum(pt_j * cur_pt_j) * src_pts[k][1]) + tmp_pt *= (w[k] / miu_s) + new_pt += tmp_pt + + new_pt += qstar + else: + new_pt = src_pts[k] + + rdx[j, i] = new_pt[0] - i + rdy[j, i] = new_pt[1] - j + + j += grid_size + i += grid_size + return rdx, rdy + + def _gen_img(self, src: np.ndarray, rdx: np.ndarray, rdy: np.ndarray, + dst_w: int, dst_h: int, grid_size: int, + trans_ratio: float) -> np.ndarray: + """Generate the image based on delta.""" + + src_h, src_w = src.shape[:2] + dst = np.zeros_like(src, dtype=np.float32) + + for i in np.arange(0, dst_h, grid_size): + for j in np.arange(0, dst_w, grid_size): + ni = i + grid_size + nj = j + grid_size + w = h = grid_size + if ni >= dst_h: + ni = dst_h - 1 + h = ni - i + 1 + if nj >= dst_w: + nj = dst_w - 1 + w = nj - j + 1 + + di = np.reshape(np.arange(h), (-1, 1)) + dj = np.reshape(np.arange(w), (1, -1)) + delta_x = self._bilinear_interp(di / h, dj / w, rdx[i, j], + rdx[i, nj], rdx[ni, j], + rdx[ni, nj]) + delta_y = self._bilinear_interp(di / h, dj / w, rdy[i, j], + rdy[i, nj], rdy[ni, j], + rdy[ni, nj]) + nx = j + dj + delta_x * trans_ratio + ny = i + di + delta_y * trans_ratio + nx = np.clip(nx, 0, src_w - 1) + ny = np.clip(ny, 0, src_h - 1) + nxi = np.array(np.floor(nx), dtype=np.int32) + nyi = np.array(np.floor(ny), dtype=np.int32) + nxi1 = np.array(np.ceil(nx), dtype=np.int32) + nyi1 = np.array(np.ceil(ny), dtype=np.int32) + + if len(src.shape) == 3: + x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3)) + y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3)) + else: + x = ny - nyi + y = nx - nxi + dst[i:i + h, + j:j + w] = self._bilinear_interp(x, y, src[nyi, nxi], + src[nyi, nxi1], + src[nyi1, nxi], src[nyi1, + nxi1]) + + dst = np.clip(dst, 0, 255) + dst = np.array(dst, dtype=np.uint8) + + return dst + + @staticmethod + def _bilinear_interp(x, y, v11, v12, v21, v22): + """Bilinear interpolation. + + TODO: Docs for args and put it into utils. + """ + return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 * + (1 - y) + v22 * y) * x + + +@TRANSFORMS.register_module() +class CropHeight(BaseTransform): + """Randomly crop the image's height, either from top or bottom. + + Adapted from + https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa + + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + + Args: + crop_min (int): Minimum pixel(s) to crop. Defaults to 1. + crop_max (int): Maximum pixel(s) to crop. Defaults to 8. + """ + + def __init__( + self, + min_pixels: int = 1, + max_pixels: int = 8, + ) -> None: + super().__init__() + assert max_pixels >= min_pixels + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + @cache_randomness + def get_random_vars(self): + """Get all the random values used in this transform.""" + crop_pixels = int(random.randint(self.min_pixels, self.max_pixels)) + crop_top = random.randint(0, 1) + return crop_pixels, crop_top + + def transform(self, results: Dict) -> Dict: + """Transform function to crop images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Cropped results. + """ + h = results['img'].shape[0] + crop_pixels, crop_top = self.get_random_vars() + crop_pixels = min(crop_pixels, h - 1) + img = results['img'].copy() + if crop_top: + img = img[crop_pixels:h, :, :] + else: + img = img[0:h - crop_pixels, :, :] + results['img_shape'] = img.shape[:2] + results['img'] = img + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(min_pixels = {self.min_pixels}, ' + repr_str += f'max_pixels = {self.max_pixels})' + return repr_str + + +@TRANSFORMS.register_module() +class ImageContentJitter(BaseTransform): + """Jitter the image contents. + + Adapted from + https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa + + Required Keys: + + - img + + Modified Keys: + + - img + """ + + def transform(self, results: Dict, jitter_ratio: float = 0.01) -> Dict: + """Transform function to jitter images. + + Args: + results (dict): Result dict from loading pipeline. + jitter_ratio (float): Controls the strength of jittering. + Defaults to 0.01. + + Returns: + dict: Jittered results. + """ + h, w = results['img'].shape[:2] + img = results['img'].copy() + if h > 10 and w > 10: + thres = min(h, w) + jitter_range = int(random.random() * thres * 0.01) + for i in range(jitter_range): + img[i:, i:, :] = img[:h - i, :w - i, :] + results['img'] = img + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += '()' + return repr_str + + +@TRANSFORMS.register_module() +class ReversePixels(BaseTransform): + """Reverse image pixels. + + Adapted from + https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/ppocr/data/imaug/rec_img_aug.py # noqa + + Required Keys: + + - img + + Modified Keys: + + - img + """ + + def transform(self, results: Dict) -> Dict: + """Transform function to reverse image pixels. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Reversed results. + """ + results['img'] = 255. - results['img'].copy() + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += '()' + return repr_str diff --git a/tests/test_datasets/test_transforms/test_textrecog_transforms.py b/tests/test_datasets/test_transforms/test_textrecog_transforms.py index c68277766..b602bc6c5 100644 --- a/tests/test_datasets/test_transforms/test_textrecog_transforms.py +++ b/tests/test_datasets/test_transforms/test_textrecog_transforms.py @@ -3,9 +3,12 @@ import unittest import numpy as np +from parameterized import parameterized -from mmocr.datasets.transforms import (PadToWidth, PyramidRescale, - RescaleToHeight) +from mmocr.datasets.transforms import (CropHeight, ImageContentJitter, + PadToWidth, PyramidRescale, + RescaleToHeight, ReversePixels, + TextRecogGeneralAug) class TestPadToWidth(unittest.TestCase): @@ -125,3 +128,85 @@ def test_repr(self): 'min_width=None, max_width=None, ' 'width_divisor=1, ' "resize_cfg={'type': 'Resize', 'scale': 0})")) + + +class TestTextRecogGeneralAug(unittest.TestCase): + + def setUp(self) -> None: + self.transform = TextRecogGeneralAug() + + @parameterized.expand([(np.random.random((3, 3, 3)), ), + (np.random.random((10, 10, 3)), ), + (np.random.random((30, 30, 3)), )]) + def test_transform(self, img): + data_info = dict(img=img) + results = self.transform(copy.deepcopy(data_info)) + self.assertEqual(results['img'].shape[:2], results['img_shape']) + + def test_repr(self): + repr_str = self.transform.__repr__() + self.assertEqual(repr_str, 'TextRecogGeneralAug()') + + +class TestCropHeight(unittest.TestCase): + + def setUp(self) -> None: + self.data_info = dict(img=np.random.random((20, 20, 3))) + + @parameterized.expand([ + (3, 3), + (5, 10), + ]) + def test_transform(self, min_pixels, max_pixels): + self.transform = CropHeight( + min_pixels=min_pixels, max_pixels=max_pixels) + results = self.transform(copy.deepcopy(self.data_info)) + self.assertEqual(results['img'].shape[:2], results['img_shape']) + h_diff = self.data_info['img'].shape[0] - results['img_shape'][0] + self.assertGreaterEqual(h_diff, min_pixels) + self.assertLessEqual(h_diff, max_pixels) + + def test_invalid(self): + with self.assertRaises(AssertionError): + self.transform = CropHeight(min_pixels=10, max_pixels=9) + + def test_repr(self): + transform = CropHeight(min_pixels=2, max_pixels=10) + repr_str = transform.__repr__() + self.assertEqual(repr_str, 'CropHeight(min_pixels = 2, ' + 'max_pixels = 10)') + + +class TestImageContentJitter(unittest.TestCase): + + def setUp(self) -> None: + self.transform = ImageContentJitter() + + @parameterized.expand([(np.random.random((3, 3, 3)), ), + (np.random.random((10, 10, 3)), ), + (np.random.random((30, 30, 3)), )]) + def test_transform(self, img): + data_info = dict(img=img) + self.transform(copy.deepcopy(data_info)) + + def test_repr(self): + repr_str = self.transform.__repr__() + self.assertEqual(repr_str, 'ImageContentJitter()') + + +class TestReversePixels(unittest.TestCase): + + def setUp(self) -> None: + self.transform = ReversePixels() + + @parameterized.expand([(np.random.random((3, 3, 3)), ), + (np.random.random((10, 10, 3)), ), + (np.random.random((30, 30, 3)), )]) + def test_transform(self, img): + data_info = dict(img=img) + results = self.transform(copy.deepcopy(data_info)) + self.assertTrue(np.array_equal(results['img'], 255. - img)) + + def test_repr(self): + repr_str = self.transform.__repr__() + self.assertEqual(repr_str, 'ReversePixels()')