Skip to content

Commit

Permalink
[sparse] fix batched serialization of BCOO
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 22, 2021
1 parent f885366 commit e18183e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
30 changes: 29 additions & 1 deletion jax/experimental/sparse/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,10 @@ def tree_flatten(self):
return (self.data, self.row, self.col), {"shape": self.shape}


def _is_dummy(*args):
return all(type(arg) is object for arg in args) or all(arg is None for arg in args)


@tree_util.register_pytree_node_class
class BCOO(JAXSparse):
"""Experimental BCOO matrix implemented in JAX; API subject to change."""
Expand All @@ -1227,6 +1231,10 @@ class BCOO(JAXSparse):
n_dense = property(lambda self: self.data.ndim - 1 - self.n_batch)
shape = Tuple[int, ...]

@property
def _sparse_shape(self):
return tuple(self.shape[self.indices.ndim - 2:][:self.indices.shape[-2]])

def __init__(self, args, *, shape):
self.data, self.indices = args
super().__init__(args, shape=shape)
Expand Down Expand Up @@ -1271,4 +1279,24 @@ def transpose(self):
return BCOO((self.data, self.indices[::-1]), shape=self.shape[::-1])

def tree_flatten(self):
return (self.data, self.indices), {"shape": self.shape}
children = (self.data, self.indices)
# pytree sometimes creates dummy objects & we need to handle that.
sparse_shape = self.shape if _is_dummy(*children) else self._sparse_shape
# We serialize the sparse shape only to support batching.
return children, {"sparse_shape": sparse_shape}

@classmethod
def tree_unflatten(cls, aux_data, children):
data, indices = children
sparse_shape = aux_data["sparse_shape"]
# pytree sometimes creates dummy objects & we need to handle that.
if _is_dummy(data, indices):
shape = sparse_shape
else:
assert len(sparse_shape) == indices.shape[-2]
n_batch = indices.ndim - 2
shape = (
tuple(np.maximum(data.shape[:n_batch], indices.shape[:n_batch]))
+ tuple(sparse_shape)
+ tuple(data.shape[n_batch + 1:]))
return cls(children, shape=shape)
14 changes: 14 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from functools import partial
import itertools
from jax._src.api import vmap
import unittest

from absl.testing import absltest
Expand Down Expand Up @@ -838,6 +839,19 @@ def test_bcoo_matmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype):
self.assertAllClose(out1, out2, rtol=tol)
self.assertAllClose(out1, out3, rtol=tol)

def test_bcoo_vmap_shape(self, shape=(2, 3, 4, 5), dtype=np.float32):
# This test checks that BCOO shape metadata interacts correctly with vmap.
rng = rand_sparse(self.rng())
M = rng(shape, dtype)

def make_bcoo(M):
return sparse.BCOO.fromdense(M, nnz=np.prod(M.shape[:-1], dtype=int), n_dense=1)

for _ in range(3):
make_bcoo = vmap(make_bcoo)
Msp = make_bcoo(M)
self.assertEqual(Msp.shape, M.shape)
self.assertArraysEqual(Msp.todense(), M)

class SparseGradTest(jtu.JaxTestCase):
def test_sparse_grad(self):
Expand Down

0 comments on commit e18183e

Please sign in to comment.