Skip to content

Commit

Permalink
[sparse] display n_batch & n_dense in BCOO repr
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 28, 2022
1 parent cdd1167 commit 0bb37e2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
19 changes: 19 additions & 0 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,25 @@ def __init__(self, args, *, shape, indices_sorted=False):
self._indices_sorted = indices_sorted
super().__init__(args, shape=shape)

def __repr__(self):
name = self.__class__.__name__
try:
nse = self.nse
n_batch = self.n_batch
n_dense = self.n_dense
dtype = self.dtype
shape = list(self.shape)
except:
repr_ = f"{name}(<invalid>)"
else:
extra = f", nse={nse}"
if n_batch: extra += f", n_batch={n_batch}"
if n_dense: extra += f", n_dense={n_dense}"
repr_ = f"{name}({dtype}{shape}{extra})"
if isinstance(self.data, core.Tracer):
repr_ = f"{type(self.data).__name__}[{repr_}]"
return repr_

@classmethod
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0, n_batch=0):
"""Create a BCOO array from a (dense) :class:`DeviceArray`."""
Expand Down
25 changes: 19 additions & 6 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,25 @@ def test_coo_matmul_ad(self, shape, dtype, bshape):


class BCOOTest(jtu.JaxTestCase):

def test_repr(self):
x = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
self.assertEqual(repr(x), "BCOO(float32[5], nse=4)")

y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1)
self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=3, n_batch=1)")

y = sparse.BCOO.fromdense(jnp.arange(6, dtype='float32').reshape(2, 3), n_batch=1, n_dense=1)
self.assertEqual(repr(y), "BCOO(float32[2, 3], nse=1, n_batch=1, n_dense=1)")

M_invalid = sparse.BCOO(([], []), shape=(100,))
self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")

@jit
def f(x):
self.assertEqual(repr(x), "DynamicJaxprTracer[BCOO(float32[5], nse=4)]")
f(x)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
Expand Down Expand Up @@ -2007,12 +2026,6 @@ def f(X, y):


class SparseObjectTest(jtu.JaxTestCase):
def test_repr(self):
M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
self.assertEqual(repr(M), "BCOO(float32[5], nse=4)")

M_invalid = sparse.BCOO(([], []), shape=(100,))
self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")

@parameterized.named_parameters(
{"testcase_name": "_{}{}".format(cls.__name__, shape), "cls": cls, "shape": shape}
Expand Down

0 comments on commit 0bb37e2

Please sign in to comment.