Skip to content

Commit

Permalink
[sparse][x64] prevent unnecessary dtype promotion in sparse impls
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 9, 2022
1 parent 859883c commit ca01d1b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def _bcoo_set_nse(mat, nse):
data = data.at[(*(slice(None) for i in range(mat.n_batch)), slice(mat.nse))].set(mat.data)
indices = jnp.zeros_like(mat.indices, shape=(*mat.indices.shape[:-2], nse, mat.indices.shape[-1]))
indices = indices.at[..., :mat.nse, :].set(mat.indices)
indices = indices.at[..., mat.nse:, :].set(jnp.array(mat.shape[mat.n_batch:mat.n_batch + mat.n_sparse]))
indices = indices.at[..., mat.nse:, :].set(jnp.array(mat.shape[mat.n_batch:mat.n_batch + mat.n_sparse],
dtype=indices.dtype))
return BCOO((data, indices), shape=mat.shape,
indices_sorted=mat.indices_sorted,
unique_indices=mat.unique_indices)
Expand Down Expand Up @@ -1738,7 +1739,8 @@ def bcoo_concatenate(operands, *, dimension):
new_indices = lax.concatenate([op.indices for op in operands], dimension=dimension)
new_data = lax.concatenate([op.data for op in operands], dimension=dimension)
elif dimension < n_batch + n_sparse: # Concatenation along sparse axes
offsets = np.cumsum([0] + [op.shape[dimension] for op in operands[:-1]])
offsets = np.cumsum([0] + [op.shape[dimension] for op in operands[:-1]],
dtype=operands[0].indices.dtype)
new_data = lax.concatenate([op.data for op in operands], dimension=n_batch)
new_indices = lax.concatenate([op.indices.at[..., dimension - n_batch].add(offset)
for op, offset in safe_zip(operands, offsets)],
Expand Down Expand Up @@ -1803,7 +1805,7 @@ def bcoo_reshape(mat, *, new_sizes, dimensions):
new_indices = jnp.concatenate([col[..., None] for col in new_index_cols], axis=-1)
with jax.numpy_rank_promotion('allow'):
oob_indices = (indices >= jnp.array(mat.shape[mat.n_batch:])).any(-1)
new_indices = new_indices.at[oob_indices].set(jnp.array(sparse_sizes))
new_indices = new_indices.at[oob_indices].set(jnp.array(sparse_sizes, dtype=new_indices.dtype))

return BCOO((data, new_indices), shape=new_sizes)

Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/sparse/csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
# TODO(jakevdp): this can be done more efficiently.
row = lax.sub(idx, lax.cond(k >= 0, lambda: zero, lambda: k))
indptr = jnp.zeros(N + 1, dtype=index_dtype).at[1:].set(
jnp.cumsum(jnp.bincount(row, length=N)))
jnp.cumsum(jnp.bincount(row, length=N).astype(index_dtype)))
return cls((data, indices, indptr), shape=(N, M))

def todense(self):
Expand Down Expand Up @@ -296,7 +296,7 @@ def _csr_fromdense_impl(mat, *, nse, index_dtype):
row = jnp.where(true_nonzeros, row, m)
indices = col.astype(index_dtype)
indptr = jnp.zeros(m + 1, dtype=index_dtype).at[1:].set(
jnp.cumsum(jnp.bincount(row, length=m)))
jnp.cumsum(jnp.bincount(row, length=m).astype(index_dtype)))
return data, indices, indptr

@csr_fromdense_p.def_abstract_eval
Expand Down

0 comments on commit ca01d1b

Please sign in to comment.