-
Notifications
You must be signed in to change notification settings - Fork 400
/
augmentation_primitives.py
371 lines (280 loc) · 11.7 KB
/
augmentation_primitives.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Helper functions to perform augmentations on a :class:`PIL.Image.Image`.
Augmentations that take an intensity value are normalized on a scale of 0-10,
where 10 is the strongest intensity and the intended maximum value.
Adapted from
`AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty
<https://github.com/google-research/augmix/blob/master/augmentations.py>`_.
Attributes:
AugmentationFn ((PIL.Image.Image, float) -> PIL.Image.Image):
The type annotation for describing an augmentation function.
Each augmentation takes a :class:`PIL.Image.Image` and an intensity level in the range ``[0, 10]``,
and returns an augmented image.
augmentation_sets (Dict[str, List[AugmentationFn]]): The collection of all augmentations.
This dictionary has the following entries:
* ``augmentation_sets["safe"]`` contains augmentations that do not overlap with
ImageNet-C/CIFAR10-C test sets.
* ``augmentation_sets["original"]`` contains augmentations that use the original
implementations of enhancing color, contrast, brightness, and sharpness.
* ``augmentation_sets["all"]`` contains all augmentations.
"""
from typing import Callable
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
# Signature shared by every augmentation in this module: it takes an image and an
# intensity level and returns an augmented image.
AugmentationFn = Callable[[Image.Image, float], Image.Image]
# Public names of this module. The ``_``-prefixed sampling/scaling helpers defined
# below are not listed here.
__all__ = [
    'AugmentationFn',
    'autocontrast',
    'equalize',
    'posterize',
    'rotate',
    'solarize',
    'shear_x',
    'shear_y',
    'translate_x',
    'translate_y',
    'color',
    'color_original',
    'contrast',
    'contrast_original',
    'brightness',
    'brightness_original',
    'sharpness',
    'sharpness_original',
    'augmentation_sets',
]
def _int_parameter(level: float, maxval: float):
    """Scale ``maxval`` by ``level / 10`` and truncate the result to an ``int``.

    Args:
        level (float): Operation strength, nominally in ``[0, 10]``.
        maxval (float): The value produced when ``level`` is 10.

    Returns:
        int: ``level * maxval / 10``, truncated toward zero by :class:`int`.
    """
    product = level * maxval
    return int(product / 10)
def _float_parameter(level: float, maxval: float):
    """Scale ``maxval`` by ``level / 10`` and return the result as a ``float``.

    Args:
        level (float): Operation strength, nominally in ``[0, 10]``.
        maxval (float): The value produced when ``level`` is 10.

    Returns:
        float: ``level * maxval / 10`` computed in floating point.
    """
    scaled = float(level) * maxval
    return scaled / 10.
def _sample_level(n: float):
    """Draw one value uniformly between ``0.1`` and ``n`` from NumPy's global RNG."""
    lower_bound = 0.1
    return np.random.uniform(lower_bound, n)
def _symmetric_sample(level: float):
    """Sample an enhancement factor that is equally likely to fall below or above 1.

    With probability ~1/2 the value is drawn uniformly from ``[1, level]``; otherwise it
    is drawn uniformly from ``[1 - 0.09 * level, 1]``. For ``level == 10`` this covers
    ``[0.1, 10]`` with median 1. Intended for transforms whose intensity ranges from 0
    upward and for which an intensity of 1 means no change.

    Args:
        level (float): Upper bound of the sampled range, nominally in ``[0, 10]``.

    Returns:
        float: The sampled factor.
    """
    pick_upper_half = np.random.uniform() > 0.5
    if pick_upper_half:
        return np.random.uniform(1, level)
    lower_edge = 1 - 0.09 * level
    return np.random.uniform(lower_edge, 1)
def autocontrast(pil_img: Image.Image, level: float = 0.0):
    """Apply :func:`PIL.ImageOps.autocontrast` to an image.

    .. seealso:: :func:`PIL.ImageOps.autocontrast`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): Ignored; accepted only so the signature matches ``AugmentationFn``.
    """
    _ = level  # intentionally unused
    return ImageOps.autocontrast(pil_img)
def equalize(pil_img: Image.Image, level: float):
    """Apply :func:`PIL.ImageOps.equalize` to an image.

    .. seealso:: :func:`PIL.ImageOps.equalize`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): Ignored; accepted only so the signature matches ``AugmentationFn``.
    """
    _ = level  # intentionally unused
    return ImageOps.equalize(pil_img)
def posterize(pil_img: Image.Image, level: float):
    """Posterize an image with a randomly chosen bit count.

    A value ``k = int(s * 4 / 10)`` is computed from ``s`` drawn uniformly between
    ``0.1`` and ``level``, and ``4 - k`` is passed as the ``bits`` argument of
    :func:`PIL.ImageOps.posterize`. A larger ``level`` therefore tends to pass fewer bits.

    .. seealso:: :func:`PIL.ImageOps.posterize`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    reduction = _int_parameter(_sample_level(level), 4)
    bits = 4 - reduction
    return ImageOps.posterize(pil_img, bits)
def rotate(pil_img: Image.Image, level: float):
    """Rotate an image by a random signed integer angle.

    The angle magnitude is ``int(s * 30 / 10)`` degrees, where ``s`` is drawn uniformly
    between ``0.1`` and ``level``. The sign is negated with probability ~1/2.
    Bilinear resampling is used.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    magnitude = _int_parameter(_sample_level(level), 30)
    sign = -1 if np.random.uniform() > 0.5 else 1
    return pil_img.rotate(sign * magnitude, resample=Image.BILINEAR)
def solarize(pil_img: Image.Image, level: float):
    """Solarize an image with a randomly chosen threshold.

    The threshold passed to :func:`PIL.ImageOps.solarize` is ``256 - int(s * 256 / 10)``,
    where ``s`` is drawn uniformly between ``0.1`` and ``level``. A larger ``level``
    therefore tends to give a lower threshold.

    .. seealso:: :func:`PIL.ImageOps.solarize`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    threshold = 256 - _int_parameter(_sample_level(level), 256)
    return ImageOps.solarize(pil_img, threshold)
def shear_x(pil_img: Image.Image, level: float):
    """Shear an image horizontally by a random signed amount.

    The shear coefficient is ``s * 0.3 / 10`` with ``s`` drawn uniformly between ``0.1``
    and ``level``, negated with probability ~1/2. It is applied as the ``b`` entry of
    the affine tuple ``(1, b, 0, 0, 1, 0)`` passed to :meth:`PIL.Image.Image.transform`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    shear = _float_parameter(_sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        shear *= -1
    affine = (1, shear, 0, 0, 1, 0)
    return pil_img.transform(pil_img.size, Image.AFFINE, affine, resample=Image.BILINEAR)
def shear_y(pil_img: Image.Image, level: float):
    """Shear an image vertically by a random signed amount.

    The shear coefficient is ``s * 0.3 / 10`` with ``s`` drawn uniformly between ``0.1``
    and ``level``, negated with probability ~1/2. It is applied as the ``d`` entry of
    the affine tuple ``(1, 0, 0, d, 1, 0)`` passed to :meth:`PIL.Image.Image.transform`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    shear = _float_parameter(_sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        shear *= -1
    affine = (1, 0, 0, shear, 1, 0)
    return pil_img.transform(pil_img.size, Image.AFFINE, affine, resample=Image.BILINEAR)
def translate_x(pil_img: Image.Image, level: float):
    """Translate an image horizontally by a random signed integer offset.

    The offset magnitude is ``int(s * (pil_img.size[0] / 3) / 10)``, where ``s`` is drawn
    uniformly between ``0.1`` and ``level``. ``pil_img.size[0]`` is the image width, so
    at ``level == 10`` the offset approaches one third of the width. The sign is negated
    with probability ~1/2. The offset is applied as the ``c`` entry of the affine tuple
    ``(1, 0, c, 0, 1, 0)`` passed to :meth:`PIL.Image.Image.transform`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    offset = _int_parameter(_sample_level(level), pil_img.size[0] / 3)
    # np.random.uniform() matches the sign draw used by every other augmentation here;
    # with default bounds it yields the same value as np.random.random().
    if np.random.uniform() > 0.5:
        offset = -offset
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, offset, 0, 1, 0), resample=Image.BILINEAR)
def translate_y(pil_img: Image.Image, level: float):
    """Translate an image vertically by a random signed integer offset.

    The offset magnitude is ``int(s * (pil_img.size[1] / 3) / 10)``, where ``s`` is drawn
    uniformly between ``0.1`` and ``level``. ``pil_img.size[1]`` is the image height, so
    at ``level == 10`` the offset approaches one third of the height. The sign is negated
    with probability ~1/2. The offset is applied as the ``f`` entry of the affine tuple
    ``(1, 0, 0, 0, 1, f)`` passed to :meth:`PIL.Image.Image.transform`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    offset = _int_parameter(_sample_level(level), pil_img.size[1] / 3)
    # np.random.uniform() matches the sign draw used by every other augmentation here;
    # with default bounds it yields the same value as np.random.random().
    if np.random.uniform() > 0.5:
        offset = -offset
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, 0, 1, offset), resample=Image.BILINEAR)
# The following augmentations overlap with corruptions in the ImageNet-C/CIFAR10-C test
# sets. Their original implementations also use an intensity sampling scheme whose
# minimum is 0.118 and whose maximum is 0.18 * intensity + 0.1, which ranges from
# 0.28 (intensity = 1) to 1.9 (intensity = 10). These augmentations have opposite
# effects depending on whether the enhancement factor is < 1 or > 1 (a factor of 1
# leaves the image unchanged), so the original sampling scheme does not make sense
# here. Accordingly, it was replaced with _symmetric_sample() above.
def color(pil_img: Image.Image, level: float):
    """Adjust the color balance of an image by a symmetrically sampled factor.

    The factor comes from ``_symmetric_sample(level)``, so it is equally likely to fall
    below or above 1.

    .. seealso:: :class:`PIL.ImageEnhance.Color`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    factor = _symmetric_sample(level)
    enhancer = ImageEnhance.Color(pil_img)
    return enhancer.enhance(factor)
def color_original(pil_img: Image.Image, level: float):
    """Enhance color on an image, following the
    corruptions in the ImageNet-C/CIFAR10-C test sets.

    The enhancement factor is ``0.18 * s + 0.1``, where ``s`` is drawn uniformly between
    ``0.1`` and ``level`` (the original AugMix sampling scheme).

    .. seealso:: :class:`PIL.ImageEnhance.Color`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should
            be in ``[0, 10]``.
    """
    factor = _float_parameter(_sample_level(level), 1.8) + 0.1
    return ImageEnhance.Color(pil_img).enhance(factor)
def contrast(pil_img: Image.Image, level: float):
    """Adjust the contrast of an image by a symmetrically sampled factor.

    The factor comes from ``_symmetric_sample(level)``, so it is equally likely to fall
    below or above 1.

    .. seealso:: :class:`PIL.ImageEnhance.Contrast`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    factor = _symmetric_sample(level)
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)
def contrast_original(pil_img: Image.Image, level: float):
    """Adjust contrast using the original AugMix sampling scheme.

    This variant follows the corruptions in the ImageNet-C/CIFAR10-C test sets. The
    enhancement factor is ``0.18 * s + 0.1``, where ``s`` is drawn uniformly between
    ``0.1`` and ``level``.

    .. seealso:: :class:`PIL.ImageEnhance.Contrast`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    sampled = _sample_level(level)
    factor = 0.1 + _float_parameter(sampled, 1.8)
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)
def brightness(pil_img: Image.Image, level: float):
    """Enhance brightness on an image.

    The enhancement factor is drawn with ``_symmetric_sample(level)``. Factors above 1
    brighten the image, factors below 1 darken it, and 1 leaves it unchanged. Brightening
    is damped by shrinking the factor's distance above 1 by 25%, so a sampled increase
    always stays an increase.

    .. seealso:: :class:`PIL.ImageEnhance.Brightness`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be
            in ``[0, 10]``.
    """
    factor = _symmetric_sample(level)
    # Reduce intensity of brightness increases. Scaling the factor itself by 0.75
    # would map factors in (1, 4/3) below 1 and turn those brightenings into
    # darkenings, so only the excess over 1 is scaled.
    if factor > 1:
        factor = 1 + (factor - 1) * .75
    return ImageEnhance.Brightness(pil_img).enhance(factor)
def brightness_original(pil_img: Image.Image, level: float):
    """Adjust brightness using the original AugMix sampling scheme.

    This variant follows the corruptions in the ImageNet-C/CIFAR10-C test sets. The
    enhancement factor is ``0.18 * s + 0.1``, where ``s`` is drawn uniformly between
    ``0.1`` and ``level``.

    .. seealso:: :class:`PIL.ImageEnhance.Brightness`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    sampled = _sample_level(level)
    factor = 0.1 + _float_parameter(sampled, 1.8)
    enhancer = ImageEnhance.Brightness(pil_img)
    return enhancer.enhance(factor)
def sharpness(pil_img: Image.Image, level: float):
    """Adjust the sharpness of an image by a symmetrically sampled factor.

    The factor comes from ``_symmetric_sample(level)``, so it is equally likely to fall
    below or above 1.

    .. seealso:: :class:`PIL.ImageEnhance.Sharpness`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    factor = _symmetric_sample(level)
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)
def sharpness_original(pil_img: Image.Image, level: float):
    """Adjust sharpness using the original AugMix sampling scheme.

    This variant follows the corruptions in the ImageNet-C/CIFAR10-C test sets. The
    enhancement factor is ``0.18 * s + 0.1``, where ``s`` is drawn uniformly between
    ``0.1`` and ``level``.

    .. seealso:: :class:`PIL.ImageEnhance.Sharpness`.

    Args:
        pil_img (PIL.Image.Image): The image.
        level (float): The intensity, which should be in ``[0, 10]``.
    """
    sampled = _sample_level(level)
    factor = 0.1 + _float_parameter(sampled, 1.8)
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)
# Named collections of augmentations; each value is a fresh list of AugmentationFn.
augmentation_sets = dict(
    # Every augmentation; color/contrast/brightness/sharpness use the _symmetric_sample() variants.
    all=[
        autocontrast, equalize, posterize, rotate, solarize,
        shear_x, shear_y, translate_x, translate_y,
        color, contrast, brightness, sharpness,
    ],
    # Only the augmentations that don't overlap with the ImageNet-C/CIFAR10-C test sets.
    safe=[
        autocontrast, equalize, posterize, rotate, solarize,
        shear_x, shear_y, translate_x, translate_y,
    ],
    # Like ``all``, but with the original implementations of color, contrast, brightness, and sharpness.
    original=[
        autocontrast, equalize, posterize, rotate, solarize,
        shear_x, shear_y, translate_x, translate_y,
        color_original, contrast_original, brightness_original, sharpness_original,
    ],
)