[sparse] implement sparse rule for lax.concatenate_p
jakevdp committed Apr 28, 2022
1 parent 0b47036 commit 2d9af38
Showing 5 changed files with 149 additions and 1 deletion.
1 change: 1 addition & 0 deletions jax/experimental/sparse/__init__.py
@@ -189,6 +189,7 @@
from jax.experimental.sparse.bcoo import (
    bcoo_add_batch_dim as bcoo_add_batch_dim,
    bcoo_broadcast_in_dim as bcoo_broadcast_in_dim,
    bcoo_concatenate as bcoo_concatenate,
    bcoo_dot_general as bcoo_dot_general,
    bcoo_dot_general_p as bcoo_dot_general_p,
    bcoo_dot_general_sampled as bcoo_dot_general_sampled,
92 changes: 91 additions & 1 deletion jax/experimental/sparse/bcoo.py
@@ -20,6 +20,7 @@

import numpy as np

import jax
from jax import core
from jax import lax
from jax import tree_util
@@ -79,6 +80,25 @@ def _bcoo_nse(mat, n_batch=0, n_dense=0):
  mask = mask.sum(list(range(n_batch, mask.ndim)))
  return mask.max()

def _bcoo_set_nse(mat, nse):
  """Return a copy of `mat` with the specified nse.

  Note that if nse < mat.nse, this will potentially discard data.
  """
  nse = operator.index(nse)
  assert nse >= 0
  if mat.nse == nse:
    return mat
  if nse <= mat.nse:
    data = mat.data[(*(slice(None) for i in range(mat.n_batch)), slice(nse))]
    indices = mat.indices[..., :nse, :]
  else:
    data = jnp.zeros_like(mat.data, shape=(*mat.data.shape[:mat.n_batch], nse, *mat.data.shape[mat.n_batch + 1:]))
    data = data.at[(*(slice(None) for i in range(mat.n_batch)), slice(mat.nse))].set(mat.data)
    indices = jnp.zeros_like(mat.indices, shape=(*mat.indices.shape[:-2], nse, mat.indices.shape[-1]))
    indices = indices.at[..., :mat.nse, :].set(mat.indices)
    indices = indices.at[..., mat.nse:, :].set(jnp.array(mat.shape[mat.n_batch:mat.n_batch + mat.n_sparse]))
  return BCOO((data, indices), shape=mat.shape, indices_sorted=mat._indices_sorted)
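
To make the padding behavior concrete, a small usage sketch (an editor's illustration, not part of the commit; it assumes module scope inside bcoo.py, where _bcoo_set_nse is visible). Growing nse pads `data` with zeros and fills the extra index rows with the sparse shape itself, deliberately out-of-bounds entries that `todense` drops, so the logical value is unchanged:

import jax.numpy as jnp
from jax.experimental import sparse

mat = sparse.BCOO.fromdense(jnp.array([[1., 0., 2.]]))  # nse == 2
padded = _bcoo_set_nse(mat, 4)  # pad with zeros up to nse == 4
assert padded.nse == 4
assert (padded.todense() == mat.todense()).all()  # value is unchanged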

# TODO(jakevdp) this can be problematic when used with autodiff; see
# https://github.com/google/jax/issues/10163. Should this be a primitive?
# Alternatively, maybe roll this into bcoo_sum_duplicates as an optional argument.
@@ -102,6 +122,7 @@ def _unbatch_bcoo(data, indices, shape):
  data = jnp.broadcast_to(data, shape[:n_batch] + data.shape[n_batch:])
  indices = jnp.broadcast_to(indices, shape[:n_batch] + indices.shape[n_batch:])
  batch_indices = jnp.mgrid[tuple(slice(None, d) for d in indices.shape[:n_batch + 1])][:-1]
  batch_indices = batch_indices.astype(indices.dtype)
  batch_indices = batch_indices.reshape(n_batch, -1).T
  data = data.reshape(np.prod(data.shape[:n_batch + 1]), *data.shape[n_batch + 1:])
  indices = indices.reshape(np.prod(indices.shape[:n_batch + 1]), *indices.shape[n_batch + 1:])
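
For intuition, a standalone sketch of how the `jnp.mgrid` expression above builds one batch-coordinate row per (batch..., nse) position (an editor's illustration with hypothetical shapes, not part of the commit):

import jax.numpy as jnp

n_batch = 2
indices_shape = (2, 3, 5, 1)  # hypothetical: (batch0, batch1, nse, n_sparse)
grid = jnp.mgrid[tuple(slice(None, d) for d in indices_shape[:n_batch + 1])]
batch_indices = grid[:-1]  # keep only the batch coordinates, drop the nse axis
batch_indices = batch_indices.reshape(n_batch, -1).T
print(batch_indices.shape)  # (2 * 3 * 5, 2): one (b0, b1) pair per nse slot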
@@ -1510,7 +1531,6 @@ def _bcoo_broadcast_in_dim(data, indices, *, spinfo, shape, broadcast_dimensions
  if np.prod(spinfo.shape[props.n_batch: props.n_batch + props.n_sparse]) != np.prod(shape[new_n_batch:new_n_batch + new_n_sparse]):
    raise NotImplementedError("Adding sparse dimensions with lengths != 1")
  nse = props.nse

  # batch & dense dimensions
  new_data = lax.broadcast_in_dim(data,
      shape=(*shape[:new_n_batch], nse, *shape[new_n_batch + new_n_sparse:]),
@@ -1525,6 +1545,76 @@

  return new_data, new_indices
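
For reference, the public wrapper around this helper can be exercised as follows (a hedged sketch by the editor, not part of the diff; it assumes the bcoo_broadcast_in_dim export shown in __init__.py above):

import jax.numpy as jnp
from jax.experimental import sparse

v = sparse.BCOO.fromdense(jnp.array([1., 0., 2.]))
m = sparse.bcoo_broadcast_in_dim(v, shape=(1, 3), broadcast_dimensions=(1,))
print(m.shape)  # (1, 3); only length-1 sparse dimensions can be added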

def bcoo_concatenate(operands, *, dimension):
  """Sparse implementation of :func:`jax.lax.concatenate`.

  Args:
    operands : Sequence of BCOO arrays to concatenate. The arrays must have equal
      shapes, except in the `dimension` axis. Additionally, the arrays must have
      equivalent batch, sparse, and dense dimensions.
    dimension : Positive integer specifying the dimension along which to concatenate
      the arrays. The dimension must be among the batch or sparse dimensions of the
      input; concatenation along dense dimensions is not supported.

  Returns:
    A BCOO array containing the concatenation of the inputs.
  """
  dimension = operator.index(dimension)
  if not all(isinstance(op, BCOO) for op in operands):
    raise ValueError("bcoo_concatenate: expected operands to be a sequence of BCOO arrays. "
                     f"Got {operands}")
  # Validate inputs using lax.concatenate abstract evaluation.
  out_aval = jax.eval_shape(
      functools.partial(lax.concatenate, dimension=dimension),
      [core.ShapedArray(op.shape, op.dtype) for op in operands])
  if len(set(op.n_dense for op in operands)) > 1:
    raise ValueError("bcoo_concatenate requires inputs to have matching dense dimensions.")

  n_batches = set(op.n_batch for op in operands)
  # Correct for the common case, where op[None, :] adds a single batch dimension and we
  # need to align it in order to match the others & concatenate.
  if len(n_batches) != 1 and max(n_batches) == 1:
    if all(op.shape[0] == 1 for op in operands if op.n_batch == 0):
      operands = [bcoo_add_batch_dim(op) if op.n_batch == 0 else op for op in operands]
    elif all(op.shape[0] == 1 for op in operands if op.n_batch == 1):
      operands = [op._unbatch() if op.n_batch == 1 else op for op in operands]
    n_batches = set(op.n_batch for op in operands)

  if len(n_batches) != 1:
    raise ValueError("bcoo_concatenate requires inputs to have matching batch dimensions.")

  n_batch, n_sparse = operands[0].n_batch, operands[0].n_sparse

  index_batches = [op.indices.shape[:n_batch] for op in operands]
  data_batches = [op.data.shape[:n_batch] for op in operands]
  if dimension < n_batch:
    index_batches = [s[:dimension] + s[dimension + 1:] for s in index_batches]
    data_batches = [s[:dimension] + s[dimension + 1:] for s in data_batches]
  if not (len(set(index_batches)) == len(set(data_batches)) == 1):
    raise NotImplementedError("concatenation of arrays with broadcasted batch indices")

  if dimension < n_batch:  # Concatenation along batch axes
    # Ensure nse of operands match.
    nses = set(op.nse for op in operands)
    if len(nses) != 1:
      nse = max(nses)
      operands = [_bcoo_set_nse(op, nse) for op in operands]
    new_indices = lax.concatenate([op.indices for op in operands], dimension=dimension)
    new_data = lax.concatenate([op.data for op in operands], dimension=dimension)
  elif dimension < n_batch + n_sparse:  # Concatenation along sparse axes
    offsets = np.cumsum([0] + [op.shape[dimension] for op in operands[:-1]])
    new_data = lax.concatenate([op.data for op in operands], dimension=n_batch)
    new_indices = lax.concatenate([op.indices.at[..., dimension - n_batch].add(offset)
                                   for op, offset in safe_zip(operands, offsets)],
                                  dimension=n_batch)
  else:  # Concatenation along dense axes
    # TODO(jakevdp) should we implement this? In general it results in a wasteful
    # representation because we cannot assume that the indices match.
    raise NotImplementedError("Concatenation along dense dimensions.")

  return BCOO((new_data, new_indices), shape=out_aval.shape)
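
A quick usage sketch (an editor's illustration, not part of the diff). Concatenating along a sparse axis takes the `elif` branch above: each operand's sparse indices are shifted by the cumulative row offset before being joined:

import jax.numpy as jnp
from jax.experimental import sparse

x = sparse.BCOO.fromdense(jnp.array([[1., 0.], [0., 2.]]))
y = sparse.BCOO.fromdense(jnp.array([[0., 3.]]))
z = sparse.bcoo_concatenate([x, y], dimension=0)  # y's row index is offset by 2
print(z.shape)      # (3, 2)
print(z.todense())  # [[1. 0.] [0. 2.] [0. 3.]]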


def _tuple_replace(tup, ind, val):
  return tuple(val if i == ind else t for i, t in enumerate(tup))

7 changes: 7 additions & 0 deletions jax/experimental/sparse/transform.py
@@ -586,6 +586,13 @@ def _broadcast_in_dim_sparse(spenv, *spvalues, shape, broadcast_dimensions):

sparse_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_sparse

def _concatenate_sparse(spenv, *spvalues, dimension):
  operands = spvalues_to_arrays(spenv, spvalues)
  result = sparse.bcoo_concatenate(operands, dimension=dimension)
  return arrays_to_spvalues(spenv, (result,))

sparse_rules[lax.concatenate_p] = _concatenate_sparse
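
With the rule registered, sparsify can now trace the lax.concatenate_p primitive emitted by the jnp wrappers; a minimal sketch (editor's addition, not part of the diff):

import jax.numpy as jnp
from jax.experimental import sparse

f = sparse.sparsify(lambda a, b: jnp.concatenate([a, b], axis=0))
x = sparse.BCOO.fromdense(jnp.eye(2))
out = f(x, x)  # result stays sparse: a BCOO array of shape (4, 2)
print(out.todense())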

def _squeeze_sparse(spenv, *spvalues, dimensions):
  arr, = spvalues
  dimensions = tuple(canonicalize_axis(dim, arr.ndim) for dim in dimensions)
20 changes: 20 additions & 0 deletions tests/sparse_test.py
@@ -1878,6 +1878,26 @@ def test_bcoo_broadcast_in_dim(self, shape, dtype, n_batch, n_dense):
    self.assertArraysEqual(xsp[:, :, None].todense(), x[:, :, None])
    self.assertArraysEqual(xsp[:, None, :, None].todense(), x[:, None, :, None])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_n_batch={}_n_dense={}_dimension={}".format(
          jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense, dimension),
       "shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense, "dimension": dimension}
      for shape in [(3,), (3, 5), (3, 5, 4)]
      for dtype in all_dtypes
      for n_batch in range(len(shape) + 1)
      for n_dense in range(len(shape) + 1 - n_batch)
      for dimension in range(len(shape) - n_dense)))  # Concatenation of dense dimensions not implemented.
  def test_bcoo_concatenate(self, shape, dtype, n_batch, n_dense, dimension):
    rng = rand_sparse(self.rng())
    operands_dense = [rng(shape, dtype) for i in range(3)]
    operands_sparse = [sparse.BCOO.fromdense(op, n_batch=n_batch, n_dense=n_dense)
                       for op in operands_dense]

    mat_dense = lax.concatenate(operands_dense, dimension=dimension)
    mat_sparse = sparse.bcoo_concatenate(operands_sparse, dimension=dimension)

    self.assertArraysEqual(mat_sparse.todense(), mat_dense)

  def test_bcoo_vmap_shape(self, shape=(2, 3, 4, 5), dtype=np.float32):
    # This test checks that BCOO shape metadata interacts correctly with vmap.
    rng = rand_sparse(self.rng())
30 changes: 30 additions & 0 deletions tests/sparsify_test.py
@@ -283,6 +283,36 @@ def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):

    self.assertAllClose(result_sparse, result_dense)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_shapes={shapes}_func={func}_nbatch={n_batch}",
       "shapes": shapes, "func": func, "n_batch": n_batch}
      for shapes, func, n_batch in [
          ([(4,), (4,)], "concatenate", 0),
          ([(4,), (4,)], "stack", 0),
          ([(4,), (4,)], "hstack", 0),
          ([(4,), (4,)], "vstack", 0),
          ([(4,), (4,)], "concatenate", 1),
          ([(4,), (4,)], "stack", 1),
          ([(4,), (4,)], "hstack", 1),
          ([(4,), (4,)], "vstack", 1),
          ([(2, 4), (2, 4)], "stack", 0),
          ([(2, 4), (3, 4)], "vstack", 0),
          ([(2, 4), (2, 5)], "hstack", 0),
          ([(2, 4), (3, 4)], "vstack", 1),
          ([(2, 4), (2, 5)], "hstack", 1),
          ([(2, 4), (3, 4)], "vstack", 2),
          ([(2, 4), (2, 5)], "hstack", 2),
          ([(2, 4), (4,), (3, 4)], "vstack", 0),
          ([(1, 4), (4,), (1, 4)], "vstack", 0),
      ]))
  def testSparseConcatenate(self, shapes, func, n_batch):
    f = self.sparsify(getattr(jnp, func))
    rng = jtu.rand_some_zero(self.rng())
    arrs = [rng(shape, 'int32') for shape in shapes]
    sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
    self.assertArraysEqual(f(arrs), f(sparrs).todense())


  def testSparseWhileLoop(self):
    def cond_fun(params):
      i, A = params
