Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add probability param for random flip functions. #38

Merged
merged 9 commits into from May 20, 2022
66 changes: 56 additions & 10 deletions dm_pix/_src/augment.py
Expand Up @@ -160,7 +160,7 @@ def flip_left_right(
Assumes that the image is either ...HWC or ...CHW and flips the W axis.

Args:
image: an RGB image, given as a float tensor in [0, 1].
image: a JAX array which represents an image.
pabloduque0 marked this conversation as resolved.
Show resolved Hide resolved
channel_axis: the index of the channel axis.

Returns:
Expand All @@ -183,7 +183,7 @@ def flip_up_down(
Assumes that the image is either ...HWC or ...CHW, and flips the H axis.

Args:
image: an RGB image, given as a float tensor in [0, 1].
image: a JAX array which represents an image.
pabloduque0 marked this conversation as resolved.
Show resolved Hide resolved
channel_axis: the index of the channel axis.

Returns:
Expand Down Expand Up @@ -295,16 +295,62 @@ def solarize(image: chex.Array, threshold: chex.Numeric) -> chex.Array:
return jnp.where(image < threshold, image, 1. - image)


def random_flip_left_right(key: chex.PRNGKey, image: chex.Array) -> chex.Array:
"""50% chance of `flip_left_right(...)` otherwise returns image unchanged."""
coin_flip = jax.random.bernoulli(key)
return jax.lax.cond(coin_flip, flip_left_right, lambda x: x, image)
def random_flip_left_right(
key: chex.PRNGKey,
image: chex.Array,
*,
probability: chex.Numeric = .5,
pabloduque0 marked this conversation as resolved.
Show resolved Hide resolved
) -> chex.Array:
"""Applies `flip_left_right(...)` with a given probability.

If the number drawn from a uniform [0, 1) falls below the given probability
it applies `flip_left_right(...)` to the image, otherwise returns original image.

Args:
key: a JAX RNG key.
image: a JAX array which represents an image. Assumes that the image is
pabloduque0 marked this conversation as resolved.
Show resolved Hide resolved
either ...HWC or ...CHW.
probability: the probability of applying flip_left_right transform. Must be
a value in [0, 1].

Returns:
A left-right flipped image if condition is met, otherwise original image.
"""
drawn_value = jax.random.uniform(key)
return jax.lax.cond(
drawn_value < probability,
flip_left_right,
lambda x: x,
image)


def random_flip_up_down(
pabloduque0 marked this conversation as resolved.
Show resolved Hide resolved
key: chex.PRNGKey,
image: chex.Array,
*,
probability: chex.Numeric = .5,
pabloduque0 marked this conversation as resolved.
Show resolved Hide resolved
) -> chex.Array:
"""Applies `flip_up_down(...)` with a given probability.

If the number drawn from a uniform [0, 1) falls below the given probability
it applies `flip_up_down(...)` to the image, otherwise returns original image.

Args:
key: a JAX RNG key.
image: a JAX array which represents an image. Assumes that the image is
pabloduque0 marked this conversation as resolved.
Show resolved Hide resolved
either ...HWC or ...CHW.
probability: the probability of applying flip_up_down transform. Must be a
value in [0, 1].

def random_flip_up_down(key: chex.PRNGKey, image: chex.Array) -> chex.Array:
"""50% chance of `flip_up_down(...)` otherwise returns image unchanged."""
coin_flip = jax.random.bernoulli(key)
return jax.lax.cond(coin_flip, flip_up_down, lambda x: x, image)
Returns:
An up-down flipped image if condition is met, otherwise original image.
"""
drawn_value = jax.random.uniform(key)
return jax.lax.cond(
drawn_value < probability,
flip_up_down,
lambda x: x,
image)


def random_brightness(
Expand Down
15 changes: 13 additions & 2 deletions dm_pix/_src/augment_test.py
Expand Up @@ -166,6 +166,16 @@ def test_flip(self, images_list):
images_list,
jax_fn=functools.partial(augment.random_flip_up_down, key),
tf_fn=None)
self._test_fn_with_random_arg(
images_list,
jax_fn=functools.partial(augment.random_flip_left_right, key),
tf_fn=None,
probability=(0., 1.))
self._test_fn_with_random_arg(
images_list,
jax_fn=functools.partial(augment.random_flip_up_down, key),
tf_fn=None,
probability=(0., 1.))

@parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE),
("out_of_range", _RAND_FLOATS_OUT_OF_RANGE))
Expand Down Expand Up @@ -221,7 +231,8 @@ def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range):
]
fn_vmap = jax.vmap(jax_fn)
outputs_vmaped = list(
fn_vmap(np.stack(images_list, axis=0), np.stack(arguments, axis=0)))
fn_vmap(np.stack(images_list, axis=0),
**{kw_name: np.stack(arguments, axis=0)}))
assert len(images_list) == len(outputs_vmaped)
assert len(images_list) == len(arguments)
for image_rgb, argument, adjusted_vmap in zip(images_list, arguments,
Expand All @@ -248,7 +259,7 @@ def _test_fn_with_random_arg(self, images_list, jax_fn, tf_fn, **kw_range):
jax_fn_jitted = jax.jit(jax_fn)
for image_rgb in images_list:
argument = np.random.uniform(random_min, random_max, size=())
adjusted_jax = jax_fn(image_rgb, argument)
adjusted_jax = jax_fn(image_rgb, **{kw_name: argument})
adjusted_jit = jax_fn_jitted(image_rgb, **{kw_name: argument})
self.assertAllCloseTolerant(adjusted_jax, adjusted_jit)

Expand Down