Skip to content

Commit

Permalink
Add random_gamma to the augmentation functions
Browse files Browse the repository at this point in the history
random_gamma is just a wrapper for `adjust_gamma` where the value of the
gamma parameter is sampled uniformly in the given range, similarly to
what other augmentation functions already do in the library.
  • Loading branch information
alonfnt committed Oct 4, 2022
1 parent cb1e16b commit 49a4a6a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 0 deletions.
1 change: 1 addition & 0 deletions dm_pix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
random_crop = augment.random_crop
random_flip_left_right = augment.random_flip_left_right
random_flip_up_down = augment.random_flip_up_down
random_gamma = augment.random_gamma
random_hue = augment.random_hue
random_saturation = augment.random_saturation
rotate = augment.rotate
Expand Down
12 changes: 12 additions & 0 deletions dm_pix/_src/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,18 @@ def random_brightness(
delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta)
return adjust_brightness(image, delta)

def random_gamma(
key: chex.PRNGKey,
image: chex.Array,
min_gamma: chex.Numeric,
max_gamma: chex.Numeric,
*,
gain: chex.Numeric = 1,
assume_in_bounds: bool = False,
) -> chex.Array:
"""`adjust_gamma(...)` with random gamma in [min_gamma, max_gamma)`."""
gamma = jax.random.uniform(key, (), minval=min_gamma, maxval=max_gamma)
return adjust_gamma(image, gamma, gain=gain, assume_in_bounds=assume_in_bounds)

def random_hue(
key: chex.PRNGKey,
Expand Down
6 changes: 6 additions & 0 deletions dm_pix/_src/augment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def test_adjust_gamma(self, images_list):
jax_fn=augment.adjust_gamma,
reference_fn=tf.image.adjust_gamma,
gamma=(0.5, 1.5))
key = jax.random.PRNGKey(0)
self._test_fn_with_random_arg(
images_list,
jax_fn=functools.partial(augment.random_gamma, key, min_gamma=1),
reference_fn=None,
max_gamma=(1.5, 1.9))

@parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE),
("out_of_range", _RAND_FLOATS_OUT_OF_RANGE))
Expand Down
6 changes: 6 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Augmentations
random_crop
random_flip_left_right
random_flip_up_down
random_gamma
random_hue
random_saturation
rotate
Expand Down Expand Up @@ -101,6 +102,11 @@ random_flip_up_down

.. autofunction:: random_flip_up_down

random_gamma
~~~~~~~~~~~~~~~~~~~

.. autofunction:: random_gamma

random_hue
~~~~~~~~~~

Expand Down

0 comments on commit 49a4a6a

Please sign in to comment.