
Subclassing DeviceArray #4269

Open · lukepfister opened this issue Sep 11, 2020 · 7 comments
Labels: question (Questions for the JAX team)

lukepfister (Contributor) commented Sep 11, 2020

I'm in the process of converting a NumPy "block matrix" class to work in JAX. The block matrix class is a subclass of np.ndarray and essentially contains a single contiguous array with some extra indexing capabilities to access various blocks.

To make this work in JAX, I thought I could subclass DeviceArray and add the extra attributes/methods. I'm having a hard time doing so, though.

Two questions:

  1. Should I be subclassing DeviceArray or lax.ndarray?
  2. Are there any working examples of subclassing DeviceArray? I saw #2464 ("add attribute to devicearray?"), but that issue remains open and I get the same error as the author.

Thanks for your work on the library-- it is a joy to use!

shoyer (Member) commented Sep 12, 2020

The short answer is that JAX doesn't support subclassing its array objects at present, and I'm not sure if it ever will. Subclassing in general can be pretty error-prone unless a class was designed from the start for subclassing. I wouldn't recommend it even with NumPy arrays.

Could you be a bit more specific about the use-cases you are trying to solve with subclassing? Would it suffice to just add some helper functions for working with your specific type of arrays?

@jakevdp is working on adding support for sparse arrays in JAX, which sounds a little similar to this.

mattjj (Member) commented Sep 12, 2020

@lukepfister thanks for the kind words!

As usual, most of my comment is agreeing with @shoyer, though hopefully adding some extra details.

The main issue is that subclassing DeviceArray won't make the subclass compatible with JAX. To make a new type compatible with JAX, you basically have to register a bunch of handlers to teach JAX what abstract values are associated with the new type (core.pytype_aval_mappings and xla.pytype_aval_mappings) and what its XLA representation is (xla.device_put_handlers and xla.result_handlers for the data-level mapping, xla.xla_shape_handlers for the type-level mapping, plus xla.canonicalize_dtype_handlers). You can see how these tables are indexed on types, which is closer to multiple dispatch than OOP.

So the system is extensible, but not by subclassing. (And I agree with @shoyer in general that inheritance is often an anti-pattern.) Jake is indeed working on adding new types to JAX, and his code will probably serve as a nice example of how to do it (plus he fixed some bugs that had crept in because we haven't used this extensibility much).

But in almost all cases you don't need to add new types to JAX, and can just represent what you need as a composite data type, perhaps registered as a pytree. I suspect that will work well here, and subclassing would prove unnecessary. In the numpy (not JAX) case, what do you get from subclassing np.ndarray?
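
To make that concrete, here's a minimal sketch of the pytree approach. The BlockArray class and its fields are hypothetical stand-ins, not code from this thread:

```python
# A minimal sketch of representing a composite type as a registered
# pytree. BlockArray and its fields are hypothetical stand-ins.
import jax
import jax.numpy as jnp
from jax import tree_util

class BlockArray:
  def __init__(self, data, block_shapes):
    self.data = data                  # one flat jnp array (traced)
    self.block_shapes = block_shapes  # static, hashable metadata

def _flatten(ba):
  # Children are traced/transformed; aux_data must be static metadata.
  return (ba.data,), ba.block_shapes

def _unflatten(aux_data, children):
  return BlockArray(children[0], aux_data)

tree_util.register_pytree_node(BlockArray, _flatten, _unflatten)

# BlockArray now flows through JAX transformations as a pytree:
@jax.jit
def scale(ba):
  return BlockArray(2.0 * ba.data, ba.block_shapes)

ba = BlockArray(jnp.arange(6.0), block_shapes=((2, 3), (3, 1)))
print(scale(ba).data)  # [ 0.  2.  4.  6.  8. 10.]
```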

mattjj (Member) commented Sep 12, 2020

As a reference example that could be helpful, here's the kind of block-matrix class we used in a now-deleted file called lapax.py, back when we rolled our own linear algebra routines:

import numpy as onp
from jax import lax

# Note: _make_infix_op and _matrix_transpose are helper functions
# defined elsewhere in lapax.py; they are not reproduced here.


class LapaxMatrix(object):
  """A matrix model using LAX functions and tweaked index rules from Numpy."""
  __slots__ = ["ndarray", "bs", "shape"]

  def __init__(self, ndarray, block_size=1):
    self.ndarray = ndarray
    self.bs = block_size
    self.shape = tuple(onp.floor_divide(ndarray.shape, block_size)
                       + (onp.mod(ndarray.shape, block_size) > 0))

  def __getitem__(self, idx):
    return LapaxMatrix(_matrix_take(self.ndarray, idx, self.bs), block_size=1)

  def __setitem__(self, idx, val):
    self.ndarray = _matrix_put(self.ndarray, idx, val.ndarray, self.bs)

  def bview(self, block_size):
    return LapaxMatrix(self.ndarray, block_size=block_size)

  __add__ = _make_infix_op(lax.add)
  __sub__ = _make_infix_op(lax.sub)
  __mul__ = _make_infix_op(lax.batch_matmul)
  __div__ = _make_infix_op(lax.div)
  __truediv__ = _make_infix_op(lax.div)
  T = property(_make_infix_op(_matrix_transpose))


# Utility functions for block access of ndarrays


def _canonical_idx(shape, idx_elt, axis, block_size=1):
  """Canonicalize the indexer `idx_elt` to a slice."""
  k = block_size
  block_dim = shape[axis] // k + bool(shape[axis] % k)
  if isinstance(idx_elt, int):
    idx_elt = idx_elt % block_dim
    idx_elt = slice(idx_elt, idx_elt + 1, 1)
  indices = tuple(onp.arange(block_dim)[idx_elt])
  if not indices:
    return slice(0, 0, 1), False  # sliced to size zero
  start, stop_inclusive = indices[0], indices[-1]
  step = 1 if idx_elt.step is None else idx_elt.step
  if k != 1 and step != 1:
    raise TypeError("Non-unit step supported only with block_size=1")
  if step > 0:
    end = min(k * (stop_inclusive + step), shape[axis])
    return slice(k * start, end, step), False
  else:
    end = min(k * (start - step), shape[axis])
    return slice(k * stop_inclusive, end, -step), True


def _matrix_put(ndarray, idx, val, block_size=1):
  """Similar to numpy.put using LAX operations."""
  idx_i, idx_j = idx
  sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
  slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)
  if not sli.step == slj.step == 1:
    raise TypeError("Non-unit step not supported in assignment.")

  if row_rev or col_rev:
    val = lax.rev(val, *onp.where([row_rev, col_rev]))

  start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
  return lax.dynamic_update_slice(ndarray, val, start_indices)


def _matrix_take(ndarray, idx, block_size=1):
  """Similar to numpy.take using LAX operations."""
  idx_i, idx_j = idx
  sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
  slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)

  start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
  limit_indices = list(ndarray.shape[:-2]) + [sli.stop, slj.stop]
  strides = [1] * (ndarray.ndim - 2) + [sli.step, slj.step]
  out = lax.slice(ndarray, start_indices, limit_indices, strides)

  if row_rev or col_rev:
    out = lax.rev(out, *onp.where([row_rev, col_rev]))
  return out

I encourage you to take a look at the full module for context (it's quite short!). It implements blocked Cholesky and blocked triangular solves using that LapaxMatrix class.

mattjj added the question (Questions for the JAX team) label on Sep 12, 2020
mattjj self-assigned this on Sep 12, 2020
mattjj (Member) commented Sep 12, 2020

By the way, that LapaxMatrix class supports in-place updating, but that's dangerous in general because aliasing semantics won't be preserved across jit boundaries (that's cryptic but I can unpack it if really needed). We used it ourselves internally without exposing it to users.
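
To unpack that a little with an illustrative sketch (not code from lapax.py): JAX array updates are functional, so an "in-place" update produces a new value, and a Python-level alias never sees the change the way a NumPy view would:

```python
# Illustrative sketch: JAX "in-place" updates are functional, so
# aliases do not observe updates the way shared NumPy buffers would.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  y = x                  # with NumPy, mutating y would also change x
  y = y.at[0].set(0.0)   # in JAX this returns a *new* array
  return x, y

x_out, y_out = f(jnp.arange(1.0, 4.0))
print(x_out)  # [1. 2. 3.] -- unchanged; no aliasing is preserved
print(y_out)  # [0. 2. 3.] -- a fresh array with the update applied
```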

lukepfister (Contributor, issue author) commented

Thanks for this!

> Could you be a bit more specific about the use-cases you are trying to solve with subclassing?

> In the numpy (not JAX) case, what do you get from subclassing np.ndarray?

Sure. I'm working in a situation where we frequently have objects that are best thought of as block arrays with heterogeneous dimensions. A common example is taking horizontal and vertical finite differences of an image: we take an image of size (n, n) and get back images of size (n, n-1) and (n-1, n). We'd like to keep these together. One option is to unravel everything and treat the output as a single vector of length 2n(n-1), but it is handy to think of them as something like a block vector with block sizes (n, n-1) and (n-1, n).
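
A small illustration of that use case in plain jax.numpy (names are illustrative only):

```python
# Illustrative sketch of the heterogeneous block-vector use case.
import jax.numpy as jnp

def finite_diffs(img):            # img: (n, n)
  dh = img[:, 1:] - img[:, :-1]   # horizontal differences, (n, n-1)
  dv = img[1:, :] - img[:-1, :]   # vertical differences, (n-1, n)
  return dh, dv                   # two blocks with different shapes

img = jnp.ones((4, 4))
dh, dv = finite_diffs(img)
flat = jnp.concatenate([dh.ravel(), dv.ravel()])
print(flat.shape)  # (24,) == 2 * n * (n - 1) for n == 4
```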

Subclassing ndarray gives us the best of both worlds: we have a little bit of extra indexing logic (similar to the LapaxMatrix example) that treats a vector as a block vector, but we can transparently think of it as a single flattened vector.

Importantly, by subclassing ndarray and dealing with __array_wrap__, we can use our Block Arrays anywhere that an ndarray goes. For example, np.linalg.norm() and np.sin() work as expected.
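
For readers unfamiliar with that NumPy pattern, here is a minimal sketch (the BlockArray name and metadata attribute are hypothetical): the default ndarray.__array_wrap__ re-wraps ufunc outputs as the subclass, and __array_finalize__ carries the metadata along.

```python
# Minimal sketch of the NumPy subclassing pattern described above:
# ufunc outputs (np.sin, etc.) come back as the subclass, with
# metadata propagated via __array_finalize__.
import numpy as np

class BlockArray(np.ndarray):
  def __new__(cls, data, block_shapes=None):
    obj = np.asarray(data).view(cls)
    obj.block_shapes = block_shapes
    return obj

  def __array_finalize__(self, obj):
    # Called on views and ufunc outputs; copy the metadata across.
    if obj is not None:
      self.block_shapes = getattr(obj, "block_shapes", None)

x = BlockArray(np.ones(6), block_shapes=((2, 3), (3, 1)))
y = np.sin(x)                        # still a BlockArray
print(type(y).__name__, y.block_shapes)
```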

> But in almost all cases you don't need to add new types to JAX, and can just represent what you need as a composite data type, perhaps registered as a pytree.

Thanks! I've got the hang of this now. I registered my type as a pytree and can convert between my BlockArray and DeviceBlockArray classes.

Now, is there a way to register this type so that jax.numpy functions can accept it as an input? I see that this is an issue shared with the LapaxMatrix class. For example, if A = LapaxMatrix(...), can I make jax.numpy.sin(A) compute jax.numpy.sin(A.ndarray)?

Thanks again for your help.
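
One possibility worth checking here: some later JAX versions look for an experimental __jax_array__ method when coercing function inputs. A sketch, assuming the running JAX version supports that protocol (verify before relying on it):

```python
# Hedged sketch: relies on JAX's experimental __jax_array__ protocol,
# which some versions use when coercing inputs to arrays. Whether
# jnp functions accept this depends on the JAX version in use.
import jax.numpy as jnp

class Wrapped:
  def __init__(self, data):
    self.ndarray = data

  def __jax_array__(self):
    # Called when JAX coerces this object to an array.
    return self.ndarray

A = Wrapped(jnp.arange(3.0))
print(jnp.sin(A))  # equivalent to jnp.sin(A.ndarray), if supported
```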

KeAWang commented Dec 15, 2022

@lukepfister Were you able to figure out a way to register your custom array type with jax.numpy operations?

lukepfister (Contributor, issue author) commented

@KeAWang Nope! We wound up wrapping jax.numpy in a new module that either called jnp functions or did some custom handling. It is messy. See https://github.com/lanl/scico/tree/main/scico/numpy.
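
For reference, a minimal sketch of that wrapping approach (BlockArray and the helper names are hypothetical stand-ins; see the scico.numpy link above for the real implementation):

```python
# Minimal sketch of wrapping jax.numpy: unwrap custom arguments,
# call the underlying jnp function, and return its result.
# BlockArray, _unwrap, and _wrap_jnp are hypothetical stand-ins.
import functools
import jax.numpy as jnp

class BlockArray:
  def __init__(self, data):
    self.data = data

def _unwrap(x):
  return x.data if isinstance(x, BlockArray) else x

def _wrap_jnp(fn):
  @functools.wraps(fn)
  def wrapped(*args, **kwargs):
    args = tuple(_unwrap(a) for a in args)
    kwargs = {k: _unwrap(v) for k, v in kwargs.items()}
    return fn(*args, **kwargs)
  return wrapped

# Re-export wrapped versions from the new module's namespace.
sin = _wrap_jnp(jnp.sin)

print(sin(BlockArray(jnp.arange(3.0))))  # same as jnp.sin on the flat data
```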
