Skip to content

Commit

Permalink
[sparse] globally change nnz->nse
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 1, 2021
1 parent 5978797 commit 76f9e6f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 92 deletions.
140 changes: 70 additions & 70 deletions jax/experimental/sparse/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
# TODO: possibly make these utilities into primitives, targeting
# csr2coo/coo2csr/SPDDMM
@functools.partial(jit, static_argnums=1)
def _csr_to_coo(indptr, nnz):
return jnp.cumsum(jnp.zeros_like(indptr, shape=nnz).at[indptr].add(1)) - 1
def _csr_to_coo(indptr, nse):
return jnp.cumsum(jnp.zeros_like(indptr, shape=nse).at[indptr].add(1)) - 1

@functools.partial(jit, static_argnums=1)
def _coo_to_csr(row, nrows):
Expand All @@ -88,8 +88,8 @@ def csr_todense(data, indices, indptr, *, shape):
"""Convert CSR-format sparse matrix to a dense matrix.
Args:
data : array of shape ``(nnz,)``.
indices : array of shape ``(nnz,)``
data : array of shape ``(nse,)``.
indices : array of shape ``(nse,)``
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
shape : length-2 tuple representing the matrix shape
Expand Down Expand Up @@ -125,33 +125,33 @@ def _csr_todense_gpu_translation_rule(c, data, indices, indptr, *, shape):
csr_fromdense_p = core.Primitive('csr_fromdense')
csr_fromdense_p.multiple_results = True

def csr_fromdense(mat, *, nnz, index_dtype=np.int32):
def csr_fromdense(mat, *, nse, index_dtype=np.int32):
"""Create CSR-format sparse matrix from a dense matrix.
Args:
mat : array to be converted to CSR.
nnz : number of nonzero entries in ``mat``
nse : number of specified entries in ``mat``
index_dtype : dtype of sparse indices
Returns:
data : array of shape ``(nnz,)`` and dtype ``mat.dtype``.
indices : array of shape ``(nnz,)`` and dtype ``index_dtype``
data : array of shape ``(nse,)`` and dtype ``mat.dtype``.
indices : array of shape ``(nse,)`` and dtype ``index_dtype``
indptr : array of shape ``(mat.shape[0] + 1,)`` and dtype ``index_dtype``
"""
mat = jnp.asarray(mat)
nnz = core.concrete_or_error(operator.index, nnz, "nnz argument of csr_fromdense()")
return csr_fromdense_p.bind(mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
nse = core.concrete_or_error(operator.index, nse, "nse argument of csr_fromdense()")
return csr_fromdense_p.bind(mat, nse=nse, index_dtype=np.dtype(index_dtype))

@csr_fromdense_p.def_impl
def _csr_fromdense_impl(mat, *, nnz, index_dtype):
def _csr_fromdense_impl(mat, *, nse, index_dtype):
mat = jnp.asarray(mat)
assert mat.ndim == 2
m = mat.shape[0]

row, col = jnp.nonzero(mat, size=nnz)
row, col = jnp.nonzero(mat, size=nse)
data = mat[row, col]

true_nonzeros = jnp.arange(nnz) < (mat != 0).sum()
true_nonzeros = jnp.arange(nse) < (mat != 0).sum()
data = jnp.where(true_nonzeros, data, 0)
row = jnp.where(true_nonzeros, row, m)
indices = col.astype(index_dtype)
Expand All @@ -160,15 +160,15 @@ def _csr_fromdense_impl(mat, *, nnz, index_dtype):
return data, indices, indptr

@csr_fromdense_p.def_abstract_eval
def _csr_fromdense_abstract_eval(mat, *, nnz, index_dtype):
data = core.ShapedArray((nnz,), mat.dtype)
indices = core.ShapedArray((nnz,), index_dtype)
def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
data = core.ShapedArray((nse,), mat.dtype)
indices = core.ShapedArray((nse,), index_dtype)
indptr = core.ShapedArray((mat.shape[0] + 1,), index_dtype)
return data, indices, indptr

def _csr_fromdense_gpu_translation_rule(c, mat, *, nnz, index_dtype):
def _csr_fromdense_gpu_translation_rule(c, mat, *, nse, index_dtype):
data, indices, indptr = cusparse.csr_fromdense(
c, mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
c, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return xops.Tuple(c, [data, indices, indptr])

xla.translations[csr_fromdense_p] = xla.lower_fun(
Expand All @@ -186,8 +186,8 @@ def csr_matvec(data, indices, indptr, v, *, shape, transpose=False):
"""Product of CSR sparse matrix and a dense vector.
Args:
data : array of shape ``(nnz,)``.
indices : array of shape ``(nnz,)``
data : array of shape ``(nse,)``.
indices : array of shape ``(nse,)``
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
v : array of shape ``(shape[0] if transpose else shape[1],)``
and dtype ``data.dtype``
Expand Down Expand Up @@ -237,8 +237,8 @@ def csr_matmat(data, indices, indptr, B, *, shape, transpose=False):
"""Product of CSR sparse matrix and a dense matrix.
Args:
data : array of shape ``(nnz,)``.
indices : array of shape ``(nnz,)``
data : array of shape ``(nse,)``.
indices : array of shape ``(nse,)``
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
dtype ``data.dtype``
Expand Down Expand Up @@ -288,9 +288,9 @@ def coo_todense(data, row, col, *, shape):
"""Convert CSR-format sparse matrix to a dense matrix.
Args:
data : array of shape ``(nnz,)``.
row : array of shape ``(nnz,)``
col : array of shape ``(nnz,)`` and dtype ``row.dtype``
data : array of shape ``(nse,)``.
row : array of shape ``(nse,)``
col : array of shape ``(nse,)`` and dtype ``row.dtype``
shape : length-2 tuple representing the matrix shape
Returns:
Expand Down Expand Up @@ -337,52 +337,52 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):
coo_fromdense_p = core.Primitive('coo_fromdense')
coo_fromdense_p.multiple_results = True

def coo_fromdense(mat, *, nnz, index_dtype=jnp.int32):
def coo_fromdense(mat, *, nse, index_dtype=jnp.int32):
"""Create COO-format sparse matrix from a dense matrix.
Args:
mat : array to be converted to COO.
nnz : number of nonzero entries in ``mat``
nse : number of specified entries in ``mat``
index_dtype : dtype of sparse indices
Returns:
data : array of shape ``(nnz,)`` and dtype ``mat.dtype``
row : array of shape ``(nnz,)`` and dtype ``index_dtype``
col : array of shape ``(nnz,)`` and dtype ``index_dtype``
data : array of shape ``(nse,)`` and dtype ``mat.dtype``
row : array of shape ``(nse,)`` and dtype ``index_dtype``
col : array of shape ``(nse,)`` and dtype ``index_dtype``
"""
mat = jnp.asarray(mat)
nnz = core.concrete_or_error(operator.index, nnz, "nnz argument of coo_fromdense()")
return coo_fromdense_p.bind(mat, nnz=nnz, index_dtype=index_dtype)
nse = core.concrete_or_error(operator.index, nse, "nse argument of coo_fromdense()")
return coo_fromdense_p.bind(mat, nse=nse, index_dtype=index_dtype)

@coo_fromdense_p.def_impl
def _coo_fromdense_impl(mat, *, nnz, index_dtype):
def _coo_fromdense_impl(mat, *, nse, index_dtype):
mat = jnp.asarray(mat)
assert mat.ndim == 2

row, col = jnp.nonzero(mat, size=nnz)
row, col = jnp.nonzero(mat, size=nse)
data = mat[row, col]

true_nonzeros = jnp.arange(nnz) < (mat != 0).sum()
true_nonzeros = jnp.arange(nse) < (mat != 0).sum()
data = jnp.where(true_nonzeros, data, 0)

return data, row.astype(index_dtype), col.astype(index_dtype)

@coo_fromdense_p.def_abstract_eval
def _coo_fromdense_abstract_eval(mat, *, nnz, index_dtype):
data = core.ShapedArray((nnz,), mat.dtype)
row = col = core.ShapedArray((nnz,), index_dtype)
def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
data = core.ShapedArray((nse,), mat.dtype)
row = col = core.ShapedArray((nse,), index_dtype)
return data, row, col

def _coo_fromdense_gpu_translation_rule(c, mat, *, nnz, index_dtype):
def _coo_fromdense_gpu_translation_rule(c, mat, *, nse, index_dtype):
data, row, col = cusparse.coo_fromdense(
c, mat, nnz=nnz, index_dtype=np.dtype(index_dtype))
c, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
return xops.Tuple(c, [data, row, col])

def _coo_fromdense_jvp(primals, tangents, *, nnz, index_dtype):
def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
M, = primals
Mdot, = tangents

primals_out = coo_fromdense(M, nnz=nnz, index_dtype=index_dtype)
primals_out = coo_fromdense(M, nse=nse, index_dtype=index_dtype)
data, row, col = primals_out

if type(Mdot) is ad.Zero:
Expand All @@ -394,9 +394,9 @@ def _coo_fromdense_jvp(primals, tangents, *, nnz, index_dtype):

return primals_out, tangents_out

def _coo_fromdense_transpose(ct, M, *, nnz, index_dtype):
def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
data, row, col = ct
assert len(data) == nnz
assert len(data) == nse
assert row.dtype == col.dtype == index_dtype
if isinstance(row, ad.Zero) or isinstance(col, ad.Zero):
raise ValueError("Cannot transpose with respect to sparse indices")
Expand All @@ -421,9 +421,9 @@ def coo_matvec(data, row, col, v, *, shape, transpose=False):
"""Product of COO sparse matrix and a dense vector.
Args:
data : array of shape ``(nnz,)``.
row : array of shape ``(nnz,)``
col : array of shape ``(nnz,)`` and dtype ``row.dtype``
data : array of shape ``(nse,)``.
row : array of shape ``(nse,)``
col : array of shape ``(nse,)`` and dtype ``row.dtype``
v : array of shape ``(shape[0] if transpose else shape[1],)`` and
dtype ``data.dtype``
shape : length-2 tuple representing the matrix shape
Expand Down Expand Up @@ -492,9 +492,9 @@ def coo_matmat(data, row, col, B, *, shape, transpose=False):
"""Product of COO sparse matrix and a dense matrix.
Args:
data : array of shape ``(nnz,)``.
row : array of shape ``(nnz,)``
col : array of shape ``(nnz,)`` and dtype ``row.dtype``
data : array of shape ``(nse,)``.
row : array of shape ``(nse,)``
col : array of shape ``(nse,)`` and dtype ``row.dtype``
B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
dtype ``data.dtype``
shape : length-2 tuple representing the matrix shape
Expand Down Expand Up @@ -1216,7 +1216,7 @@ class JAXSparse:
"""Base class for high-level JAX sparse objects."""
data: jnp.ndarray
shape: Tuple[int, int]
nnz: property
nse: property
dtype: property

@property
Expand All @@ -1227,7 +1227,7 @@ def __init__(self, args, *, shape):
self.shape = shape

def __repr__(self):
repr_ = f"{self.__class__.__name__}({self.dtype}{list(self.shape)}, nnz={self.nnz})"
repr_ = f"{self.__class__.__name__}({self.dtype}{list(self.shape)}, nse={self.nse})"
if isinstance(self.data, core.Tracer):
repr_ = f"{type(self.data).__name__}[{repr_}]"
return repr_
Expand Down Expand Up @@ -1270,18 +1270,18 @@ class CSR(JAXSparse):
data: jnp.ndarray
indices: jnp.ndarray
indptr: jnp.ndarray
nnz = property(lambda self: self.data.size)
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)

def __init__(self, args, *, shape):
self.data, self.indices, self.indptr = map(jnp.asarray, args)
super().__init__(args, shape=shape)

@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32):
if nnz is None:
nnz = (mat != 0).sum()
return cls(csr_fromdense(mat, nnz=nnz, index_dtype=index_dtype), shape=mat.shape)
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
if nse is None:
nse = (mat != 0).sum()
return cls(csr_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)

@api.jit
def todense(self):
Expand Down Expand Up @@ -1309,18 +1309,18 @@ class CSC(JAXSparse):
data: jnp.ndarray
indices: jnp.ndarray
indptr: jnp.ndarray
nnz = property(lambda self: self.data.size)
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)

def __init__(self, args, *, shape):
self.data, self.indices, self.indptr = map(jnp.asarray, args)
super().__init__(args, shape=shape)

@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32):
if nnz is None:
nnz = (mat != 0).sum()
return cls(csr_fromdense(mat.T, nnz=nnz, index_dtype=index_dtype), shape=mat.shape)
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
if nse is None:
nse = (mat != 0).sum()
return cls(csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype), shape=mat.shape)

@api.jit
def todense(self):
Expand Down Expand Up @@ -1348,18 +1348,18 @@ class COO(JAXSparse):
data: jnp.ndarray
row: jnp.ndarray
col: jnp.ndarray
nnz = property(lambda self: self.data.size)
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)

def __init__(self, args, *, shape):
self.data, self.row, self.col = map(jnp.asarray, args)
super().__init__(args, shape=shape)

@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32):
if nnz is None:
nnz = (mat != 0).sum()
return cls(coo_fromdense(mat, nnz=nnz, index_dtype=index_dtype), shape=mat.shape)
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
if nse is None:
nse = (mat != 0).sum()
return cls(coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)

@api.jit
def todense(self):
Expand Down Expand Up @@ -1390,7 +1390,7 @@ class BCOO(JAXSparse):
"""Experimental BCOO matrix implemented in JAX; API subject to change."""
data: jnp.ndarray
indices: jnp.ndarray
nnz = property(lambda self: self.data.size)
nse = property(lambda self: self.data.size)
dtype = property(lambda self: self.data.dtype)
n_batch = property(lambda self: self.indices.ndim - 2)
n_sparse = property(lambda self: self.indices.shape[-2])
Expand All @@ -1406,8 +1406,8 @@ def __init__(self, args, *, shape):
super().__init__(args, shape=shape)

@classmethod
def fromdense(cls, mat, *, nnz=None, index_dtype=np.int32, n_dense=0, n_batch=0):
return cls(bcoo_fromdense(mat, nse=nnz, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch), shape=mat.shape)
def fromdense(cls, mat, *, nse=None, index_dtype=np.int32, n_dense=0, n_batch=0):
return cls(bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype, n_dense=n_dense, n_batch=n_batch), shape=mat.shape)

@api.jit
def todense(self):
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/sparse/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
>>> mat_sparse = BCOO.fromdense(mat)
>>> mat_sparse
BCOO(float32[5, 5], nnz=8)
BCOO(float32[5, 5], nse=8)
>>> sparsify(f)(mat_sparse, vec)
DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
Expand Down

0 comments on commit 76f9e6f

Please sign in to comment.