Skip to content

Commit

Permalink
[sparse] update sparse object repr to show tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 15, 2021
1 parent 096f77d commit 591507a
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion jax/experimental/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ def bcoo_reduce_sum(data, indices, *, shape, axes):
# Sparse objects (APIs subject to change)
class JAXSparse:
"""Base class for high-level JAX sparse objects."""
data: jnp.ndarray
shape: Tuple[int, int]
nnz: property
dtype: property
Expand All @@ -1048,7 +1049,10 @@ def __init__(self, args, *, shape):
self.shape = shape

def __repr__(self):
return f"{self.__class__.__name__}({self.dtype}{list(self.shape)}, nnz={self.nnz})"
repr_ = f"{self.__class__.__name__}({self.dtype}{list(self.shape)}, nnz={self.nnz})"
if isinstance(self.data, core.Tracer):
repr_ = f"{type(self.data).__name__}[{repr_}]"
return repr_

def tree_flatten(self):
raise NotImplementedError("tree_flatten")
Expand Down

0 comments on commit 591507a

Please sign in to comment.