[sparse] add bcoo_transpose primitive
jakevdp committed Jun 23, 2021
1 parent 2897776 commit c358da4
Showing 3 changed files with 189 additions and 15 deletions.
2 changes: 2 additions & 0 deletions jax/experimental/sparse/__init__.py
@@ -25,6 +25,8 @@
bcoo_rdot_general,
bcoo_todense,
bcoo_todense_p,
bcoo_transpose,
bcoo_transpose_p,
coo_fromdense,
coo_fromdense_p,
coo_matmat,
135 changes: 120 additions & 15 deletions jax/experimental/sparse/ops.py
@@ -32,7 +32,7 @@
import functools
import operator

from typing import Any, Tuple
from typing import Any, Sequence, Tuple

from jax import api
from jax import core
@@ -603,11 +603,18 @@ def _validate_bcoo(data, indices, shape):
def _compatible(shape1, shape2):
return all(s1 in (1, s2) for s1, s2 in safe_zip(shape1, shape2))

assert _compatible(data.shape[:n_batch], shape[:n_batch])
assert data.shape[-(n_dense + 1):] == (nse,) + shape[n_batch + n_sparse:]

assert _compatible(indices.shape[:n_batch], shape[:n_batch])
assert indices.shape[n_batch:] == (n_sparse, nse)
if not _compatible(data.shape[:n_batch], shape[:n_batch]):
raise ValueError("data batch dimensions not compatible for "
f"data.shape={data.shape}, shape={shape}")
if data.shape[-(n_dense + 1):] != (nse,) + shape[n_batch + n_sparse:]:
raise ValueError(f"Invalid data.shape={data.shape} for "
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")
if not _compatible(indices.shape[:n_batch], shape[:n_batch]):
raise ValueError("indices batch dimensions not compatible for "
f"indices.shape={indices.shape}, shape={shape}")
if indices.shape[n_batch:] != (n_sparse, nse):
raise ValueError(f"Invalid indices.shape={indices.shape} for "
f"nse={nse}, n_batch={n_batch}, n_dense={n_dense}")

return n_batch, n_sparse, n_dense
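# For example, with shape=(2, 3, 4, 5), n_batch=1, n_sparse=2, n_dense=1 and nse=7,
# the checks above expect data.shape == (2, 7, 5) and indices.shape == (2, 2, 7);
# batch dimensions of size 1 are also accepted, since they broadcast against the
# array's batch shape.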

@@ -833,6 +840,98 @@ def _bcoo_extract_batching_rule(batched_args, batch_dims):
xla.translations[bcoo_extract_p] = xla.lower_fun(
_bcoo_extract_impl, multiple_results=False)

#----------------------------------------------------------------------
# bcoo_transpose
# transpose of a BCOO array

bcoo_transpose_p = core.Primitive('bcoo_transpose')
bcoo_transpose_p.multiple_results = True

def bcoo_transpose(data, indices, *, permutation, shape):
return bcoo_transpose_p.bind(data, indices, permutation=permutation, shape=shape)

def _validate_permutation(data, indices, permutation, shape):
if not isinstance(permutation, (tuple, list, np.ndarray)):
raise TypeError(f"transpose permutation must be a tuple/list/ndarray, got {type(permutation)}.")
if tuple(sorted(permutation)) != tuple(range(len(shape))):
raise TypeError("transpose permutation isn't a permutation of operand dimensions, "
f"got permutation {permutation} for shape {shape}.")
n_batch, n_sparse, n_dense = _validate_bcoo(data, indices, shape)
batch_perm = permutation[:n_batch]
sparse_perm = [p - n_batch for p in permutation[n_batch: n_batch + n_sparse]]
dense_perm = [p - n_sparse - n_batch for p in permutation[n_batch + n_sparse:]]
if n_batch and tuple(sorted(batch_perm)) != tuple(range(n_batch)):
raise NotImplementedError("transpose permutation cannot permute batch axes with non-batch axes; "
f"got permutation {permutation}, with n_batch={n_batch}.")
if n_dense and tuple(sorted(dense_perm)) != tuple(range(n_dense)):
raise NotImplementedError("transpose permutation cannot permute dense axes with non-dense axes; "
f"got permutation {permutation}, with n_dense={n_dense}.")
return batch_perm, sparse_perm, dense_perm
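# For example, with shape=(2, 3, 4, 5), n_batch=1, n_sparse=2, n_dense=1, the
# permutation (0, 2, 1, 3) splits into batch_perm=(0,), sparse_perm=[1, 0] and
# dense_perm=[0]: batch and dense axes stay within their own groups, so only the
# sparse axes are actually reordered.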

@bcoo_transpose_p.def_impl
def _bcoo_transpose_impl(data, indices, *, permutation: Sequence[int], shape: Tuple[int, ...]):
batch_perm, sparse_perm, dense_perm = _validate_permutation(data, indices, permutation, shape)
n_batch = len(batch_perm)
indices = indices[..., sparse_perm, :].transpose(*batch_perm, n_batch, n_batch + 1)
data = data.transpose(*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm))
return data, indices

@bcoo_transpose_p.def_abstract_eval
def _bcoo_transpose_abstract_eval(data, indices, *, permutation: Sequence[int], shape: Tuple[int, ...]):
batch_perm, _, dense_perm = _validate_permutation(data, indices, permutation, shape)
n_batch = len(batch_perm)
indices_shape = np.array(indices.shape)[[*batch_perm, n_batch, n_batch + 1]]
data_shape = np.array(data.shape)[[*batch_perm, n_batch, *(d + n_batch + 1 for d in dense_perm)]]
return core.ShapedArray(data_shape, data.dtype), core.ShapedArray(indices_shape, indices.dtype)

def _bcoo_transpose_jvp(primals, tangents, *, permutation, shape):
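  # Indices are integer positions, so their tangent is symbolically zero; only the
  # data tangent is transposed, reusing the primal indices unchanged.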
data, indices = primals
data_dot, _ = tangents
primals_out = bcoo_transpose(data, indices, permutation=permutation, shape=shape)
data_dot_out, _ = bcoo_transpose(data_dot, indices, permutation=permutation, shape=shape)
return primals_out, (data_dot_out, ad.Zero.from_value(indices))

def _bcoo_transpose_transpose(ct, data, indices, *, permutation, shape):
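  # bcoo_transpose is linear in `data`, so its transpose applies the inverse
  # permutation to the data cotangent; the integer indices carry no cotangent.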
data_ct, indices_ct = ct
assert isinstance(indices_ct, ad.Zero)
if ad.is_undefined_primal(indices):
raise ValueError("Cannot transpose with respect to sparse indices")
assert data_ct.dtype == data.aval.dtype
ct_shape = tuple(shape[p] for p in permutation)
rev_permutation = np.argsort(permutation)
# TODO(jakevdp) avoid dummy indices?
dummy_indices = jnp.zeros([1 for i in range(indices.ndim - 2)] + list(indices.shape[-2:]), dtype=int)
data_trans, _ = bcoo_transpose(data_ct, dummy_indices, permutation=rev_permutation, shape=ct_shape)
return data_trans, indices_ct

def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation, shape):
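  # Give any unbatched operand a length-1 leading axis (it broadcasts against the
  # batch size), transpose with the batch axis prepended to the permutation, and
  # strip the added axis from the corresponding output.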
data, indices = batched_args
batch_dims = list(batch_dims)
batch_size = max(0 if dim is None else arg.shape[dim]
for arg, dim in zip(batched_args, batch_dims))
if batch_dims[0] is None:
data = data[None]
else:
assert batch_dims[0] == 0
if batch_dims[1] is None:
indices = indices[None]
else:
assert batch_dims[1] == 0
batched_shape = (batch_size, *shape)
batched_permutation = (0, *(p + 1 for p in permutation))
data, indices = bcoo_transpose(data, indices, permutation=batched_permutation, shape=batched_shape)
if batch_dims[0] is None:
data = data[0]
if batch_dims[1] is None:
indices = indices[0]
return (data, indices), batch_dims

ad.primitive_jvps[bcoo_transpose_p] = _bcoo_transpose_jvp
ad.primitive_transposes[bcoo_transpose_p] = _bcoo_transpose_transpose
batching.primitive_batchers[bcoo_transpose_p] = _bcoo_transpose_batch_rule
xla.translations[bcoo_transpose_p] = xla.lower_fun(
_bcoo_transpose_impl, multiple_results=True)
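
# A minimal usage sketch, assuming the bcoo_fromdense and bcoo_todense helpers from
# this module: the transpose acts on the BCOO representation directly, without
# densifying.
#
#   M = jnp.arange(6.0).reshape(2, 3)
#   data, indices = bcoo_fromdense(M, n_batch=0, n_dense=0)
#   data_T, indices_T = bcoo_transpose(data, indices, permutation=(1, 0), shape=(2, 3))
#   # bcoo_todense(data_T, indices_T, shape=(3, 2)) reconstructs M.T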

#----------------------------------------------------------------------
# bcoo_dot_general
# (batched) general dot product of a BCOO sparse ND array and a dense ND array,
@@ -1080,7 +1179,7 @@ def matvec(self, v):
def matmat(self, B):
raise NotImplementedError("matmat")

def transpose(self):
def transpose(self, axes=None):
raise NotImplementedError()

@property
@@ -1130,7 +1229,8 @@ def matvec(self, v):
def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape)

def transpose(self):
def transpose(self, axes=None):
assert axes is None
return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])

def tree_flatten(self):
@@ -1168,7 +1268,8 @@ def matvec(self, v):
def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)

def transpose(self):
def transpose(self, axes=None):
assert axes is None
return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])

def tree_flatten(self):
@@ -1206,7 +1307,8 @@ def matvec(self, v):
def matmat(self, B):
return coo_matmat(self.data, self.row, self.col, B, shape=self.shape)

def transpose(self):
def transpose(self, axes=None):
assert axes is None
return COO((self.data, self.col, self.row), shape=self.shape[::-1])

def tree_flatten(self):
@@ -1271,10 +1373,11 @@ def __rmatmul__(self, other):
rhs_shape=self.shape,
dimension_numbers=(([other.ndim - 1], [0]), ([], [])))

def transpose(self):
if self.n_batch or self.n_dense:
raise NotImplementedError("BCOO transpose with batch or dense dimensions")
return BCOO((self.data, self.indices[::-1]), shape=self.shape[::-1])
def transpose(self, axes=None):
axes = np.arange(self.ndim)[::-1] if axes is None else axes
data_T, indices_T = bcoo_transpose(self.data, self.indices, shape=self.shape, permutation=axes)
shape_T = [self.shape[i] for i in axes]
return BCOO((data_T, indices_T), shape=shape_T)
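  # Illustrative usage sketch: transpose now accepts an arbitrary axes permutation,
  # mirroring numpy.ndarray.transpose, e.g.
  #   Msp = BCOO.fromdense(jnp.arange(24.0).reshape(2, 3, 4), n_batch=0, n_dense=0)
  #   Msp_T = Msp.transpose((2, 0, 1))  # shape (4, 2, 3); Msp_T.todense() matches the dense transpose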

def tree_flatten(self):
children = (self.data, self.indices)
@@ -1291,7 +1394,9 @@ def tree_unflatten(cls, aux_data, children):
if _is_dummy(data, indices):
shape = sparse_shape
else:
assert len(sparse_shape) == indices.shape[-2]
if np.ndim(indices) < 2 or len(sparse_shape) != np.shape(indices)[-2]:
raise ValueError(f"Invalid sparse representation: got indices.shape={np.shape(indices)}, "
f"data.shape={np.shape(data)}, sparse_shape={sparse_shape}")
n_batch = indices.ndim - 2
shape = (
tuple(np.maximum(data.shape[:n_batch], indices.shape[:n_batch]))
67 changes: 67 additions & 0 deletions tests/sparse_test.py
@@ -501,6 +501,73 @@ def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
self.assertEqual(j1.shape, data.shape + M.shape)
self.assertEqual(hess.shape, data.shape + 2 * M.shape)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_transpose(self, shape, dtype, n_batch, n_dense):
n_sparse = len(shape) - n_batch - n_dense
rng = self.rng()
sprng = rand_sparse(rng)
M = sprng(shape, dtype)
data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

permutation = np.concatenate([
rng.permutation(range(n_batch)),
rng.permutation(range(n_batch, n_batch + n_sparse)),
rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

M_T = M.transpose(permutation)
trans = partial(sparse.bcoo_transpose, shape=shape, permutation=permutation)
self.assertArraysEqual(M_T, sparse.bcoo_todense(*trans(data, indices), shape=M_T.shape))
self.assertArraysEqual(M_T, sparse.bcoo_todense(*jit(trans)(data, indices), shape=M_T.shape))

# test batched
def trans(M):
return M.transpose([p - n_batch for p in permutation[n_batch:]])
for _ in range(n_batch):
trans = jax.vmap(trans)
Msp = sparse.BCOO.fromdense(M, n_batch=n_batch, n_dense=n_dense)
self.assertArraysEqual(trans(M), trans(Msp).todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
"shape": shape, "dtype": dtype, "n_batch": n_batch, "n_dense": n_dense}
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for dtype in jtu.dtypes.floating
for n_batch in range(len(shape) + 1)
for n_dense in range(len(shape) + 1 - n_batch)))
def test_bcoo_transpose_ad(self, shape, dtype, n_batch, n_dense):
n_sparse = len(shape) - n_batch - n_dense
rng = self.rng()
sprng = rand_sparse(self.rng())

M = sprng(shape, dtype)
data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

permutation = np.concatenate([
rng.permutation(range(n_batch)),
rng.permutation(range(n_batch, n_batch + n_sparse)),
rng.permutation(range(n_batch + n_sparse, len(shape)))]).astype(int)

def f_sparse(data):
return sparse.bcoo_transpose(data, indices, shape=shape, permutation=permutation)[0]

jf_sparse = jax.jacfwd(f_sparse)(data)
jr_sparse = jax.jacrev(f_sparse)(data)

tol = {}
if jtu.device_under_test() == "tpu":
tol = {np.float32: 5E-3}

# TODO(jakevdp) also test against dense version?
self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_nbatch={}_ndense={}".format(
jtu.format_shape_dtype_string(shape, dtype), n_batch, n_dense),
