diff --git a/image/randaugment/augmentation_transforms.py b/image/randaugment/augmentation_transforms.py index 0c85955..58ed346 100644 --- a/image/randaugment/augmentation_transforms.py +++ b/image/randaugment/augmentation_transforms.py @@ -49,6 +49,11 @@ def get_mean_and_std(): return means, stds +def _width_height_from_img_shape(img_shape): + """`img_shape` in autoaugment is (height, width).""" + return (img_shape[1], img_shape[0]) + + def random_flip(x): """Flip the input x horizontally with 50% probability.""" if np.random.rand(1)[0] > 0.5: @@ -315,7 +320,7 @@ def _shear_x_impl(pil_img, level, img_shape): if random.random() > 0.5: level = -level return pil_img.transform( - (img_shape[0], img_shape[1]), + _width_height_from_img_shape(img_shape), Image.AFFINE, (1, level, 0, 0, 1, 0)) @@ -341,7 +346,7 @@ def _shear_y_impl(pil_img, level, img_shape): if random.random() > 0.5: level = -level return pil_img.transform( - (img_shape[0], img_shape[1]), + _width_height_from_img_shape(img_shape), Image.AFFINE, (1, 0, 0, level, 1, 0)) @@ -367,7 +372,7 @@ def _translate_x_impl(pil_img, level, img_shape): if random.random() > 0.5: level = -level return pil_img.transform( - (img_shape[0], img_shape[1]), + _width_height_from_img_shape(img_shape), Image.AFFINE, (1, 0, level, 0, 1, 0)) @@ -393,7 +398,7 @@ def _translate_y_impl(pil_img, level, img_shape): if random.random() > 0.5: level = -level return pil_img.transform( - (img_shape[0], img_shape[1]), + _width_height_from_img_shape(img_shape), Image.AFFINE, (1, 0, 0, 0, 1, level))