Skip to content

Commit

Permalink
[Cleanup] Make the error message more informative in `third_party/py/…
Browse files Browse the repository at this point in the history
…dm_pix/_src/depth_and_space.py`.

PiperOrigin-RevId: 404195610
  • Loading branch information
PIXDev authored and Copybara-Service committed Oct 19, 2021
1 parent 77058f0 commit dfb692a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions dm_pix/_src/depth_and_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array:

height, width, depth = inputs.shape
if depth % (block_size**2) != 0:
raise ValueError('Number of channels must be divisible by block_size ** 2.')
raise ValueError(
f'Number of channels {depth} must be divisible by block_size ** 2 {block_size**2}.'
)
new_depth = depth // (block_size**2)
outputs = jnp.reshape(inputs,
[height, width, block_size, block_size, new_depth])
Expand Down Expand Up @@ -67,9 +69,11 @@ def space_to_depth(inputs: chex.Array, block_size: int) -> chex.Array:

height, width, depth = inputs.shape
if height % block_size != 0:
raise ValueError('Height must be divisible by block size.')
raise ValueError(
f'Height {height} must be divisible by block size {block_size}.')
if width % block_size != 0:
raise ValueError('Width must be divisible by block size.')
raise ValueError(
f'Width {width} must be divisible by block size {block_size}.')
new_depth = depth * (block_size**2)
new_height = height // block_size
new_width = width // block_size
Expand Down

0 comments on commit dfb692a

Please sign in to comment.