Skip to content

Commit

Permalink
Merge pull request #17069 from jakevdp:unpackbits-count
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555695640
  • Loading branch information
jax authors committed Aug 10, 2023
2 parents 5349ea6 + 4df5805 commit 60c3fdf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
7 changes: 6 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3836,7 +3836,12 @@ def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
axis = 0
a = swapaxes(a, axis, -1)
unpacked = ((a[..., None] & expand_dims(bits, tuple(range(a.ndim)))) > 0).astype('uint8')
unpacked = unpacked.reshape(unpacked.shape[:-2] + (-1,))[..., :count]
unpacked = unpacked.reshape(unpacked.shape[:-2] + (-1,))
if count is not None:
if count > unpacked.shape[-1]:
unpacked = pad(unpacked, [(0, 0)] * (unpacked.ndim - 1) + [(0, count - unpacked.shape[-1])])
else:
unpacked = unpacked[..., :count]
return swapaxes(unpacked, axis, -1)


Expand Down
4 changes: 2 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3834,8 +3834,8 @@ def testPackbits(self, shape, dtype, axis, bitorder):
def testUnpackbits(self, shape, dtype, axis, bitorder, count):
rng = jtu.rand_int(self.rng(), 0, 256)
args_maker = lambda: [rng(shape, dtype)]
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder)
np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder)
jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder, count=count)
np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder, count=count)
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

Expand Down

0 comments on commit 60c3fdf

Please sign in to comment.