Subclassing DeviceArray #4269
The short answer is that JAX doesn't support subclassing its array objects at present, and I'm not sure if it ever will. Subclassing in general can be pretty error-prone unless a class was designed from the start for subclassing; I wouldn't recommend it even with NumPy arrays. Could you be a bit more specific about the use cases you are trying to solve with subclassing? Would it suffice to just add some helper functions for working with your specific type of arrays? @jakevdp is working on adding support for sparse arrays in JAX, which sounds a little similar to this.
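A minimal sketch of the helper-function approach, assuming square blocks with a static Python-int block size. `get_block` and `set_block` are hypothetical names, not JAX APIs; they operate on plain arrays, so they work under `jit` with no new type at all:

```python
import jax
import jax.numpy as jnp
from jax import lax


def get_block(x, i, j, block_size):
    """Return block (i, j) of a 2-D array split into square blocks.

    Hypothetical helper: block_size must be a static int; i and j may be traced.
    Note that lax.dynamic_slice clamps out-of-range start indices.
    """
    start = (i * block_size, j * block_size)
    return lax.dynamic_slice(x, start, (block_size, block_size))


def set_block(x, i, j, block_size, value):
    """Return a copy of x with block (i, j) replaced by value (x is not mutated)."""
    start = (i * block_size, j * block_size)
    return lax.dynamic_update_slice(x, value, start)


# block_size determines the output shape, so it is marked static for jit.
get_block_jit = jax.jit(get_block, static_argnums=3)

x = jnp.arange(16.0).reshape(4, 4)
print(get_block_jit(x, 1, 0, 2))
```

Keeping the data as an ordinary array means every existing `jax.numpy` function still applies; only the block bookkeeping lives in the helpers.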
@lukepfister Thanks for the kind words! As usual, most of my comment is agreeing with @shoyer, though hopefully adding some extra details. The main issue is that subclassing DeviceArray won't make the subclass compatible with JAX. To make a new type compatible with JAX, you basically have to register a bunch of handlers to teach JAX what abstract values are associated with the new type (…). So the system is extensible, but not by subclassing. (And I agree with @shoyer in general that inheritance is often an anti-pattern.) Jake is indeed working on adding new types to JAX, and his code will probably serve as a nice example of how to do it (plus he fixed some bugs that had crept in because we haven't used this extensibility much). But in almost all cases you don't need to add new types to JAX; you can just represent what you need as a composite data type, perhaps registered as a pytree. I suspect that will work well here, and subclassing would prove unnecessary. In the NumPy (not JAX) case, what do you get from subclassing np.ndarray?
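A minimal sketch of the composite-type-as-pytree route, assuming a JAX version that provides `jax.tree_util.register_pytree_node_class`. `BlockMatrix` is a hypothetical name. The array goes in the pytree leaves, so `jit` and `grad` trace it, and the block size goes in the auxiliary data, so it stays a static Python int:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class BlockMatrix:
    """Hypothetical composite type: one contiguous array plus a static block size."""

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def tree_flatten(self):
        # Leaves (traced by JAX) first, then static auxiliary data.
        return (self.data,), self.block_size

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

    def block(self, i, j):
        # i, j and block_size are Python ints here, so this is static slicing.
        k = self.block_size
        return self.data[i * k:(i + 1) * k, j * k:(j + 1) * k]


@jax.jit
def trace_of_first_block(m):
    return jnp.trace(m.block(0, 0))


m = BlockMatrix(jnp.arange(16.0).reshape(4, 4), block_size=2)
print(trace_of_first_block(m))

# grad returns a BlockMatrix holding the gradient with respect to m.data.
g = jax.grad(trace_of_first_block)(m)
print(g.data)
```

Because `block_size` lives in the aux data, passing a different block size makes `jit` retrace and recompile, while new array contents of the same shape reuse the compiled function.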
As a reference example that could be helpful, here's a kind of block matrix class we used in a now-deleted file called lapax.py when we used to roll our own linear algebra routines:

```python
import numpy as onp
from jax import lax

# Note: _make_infix_op and _matrix_transpose come from the full lapax.py module
# and are not reproduced here; they must be defined before this class body runs.


class LapaxMatrix(object):
  """A matrix model using LAX functions and tweaked index rules from Numpy."""
  __slots__ = ["ndarray", "bs", "shape"]

  def __init__(self, ndarray, block_size=1):
    self.ndarray = ndarray
    self.bs = block_size
    # Shape measured in blocks; a partial trailing block counts as one block.
    self.shape = tuple(onp.floor_divide(ndarray.shape, block_size)
                       + (onp.mod(ndarray.shape, block_size) > 0))

  def __getitem__(self, idx):
    return LapaxMatrix(_matrix_take(self.ndarray, idx, self.bs), block_size=1)

  def __setitem__(self, idx, val):
    # Rebinds self.ndarray to an updated copy; the old array is not mutated.
    self.ndarray = _matrix_put(self.ndarray, idx, val.ndarray, self.bs)

  def bview(self, block_size):
    """Return a view of the same data with a different block size."""
    return LapaxMatrix(self.ndarray, block_size=block_size)

  __add__ = _make_infix_op(lax.add)
  __sub__ = _make_infix_op(lax.sub)
  __mul__ = _make_infix_op(lax.batch_matmul)
  __div__ = _make_infix_op(lax.div)
  __truediv__ = _make_infix_op(lax.div)
  T = property(_make_infix_op(_matrix_transpose))


# Utility functions for block access of ndarrays

def _canonical_idx(shape, idx_elt, axis, block_size=1):
  """Canonicalize the indexer `idx_elt` to a slice.

  Returns (slice, reversed): a negative-step index becomes a positive-step
  slice plus reversed=True, telling the caller to flip that axis.
  """
  k = block_size
  block_dim = shape[axis] // k + bool(shape[axis] % k)
  if isinstance(idx_elt, int):
    idx_elt = idx_elt % block_dim
    idx_elt = slice(idx_elt, idx_elt + 1, 1)
  indices = tuple(onp.arange(block_dim)[idx_elt])
  if not indices:
    return slice(0, 0, 1), False  # sliced to size zero
  start, stop_inclusive = indices[0], indices[-1]
  step = 1 if idx_elt.step is None else idx_elt.step
  if k != 1 and step != 1:
    raise TypeError("Non-unit step supported only with block_size=1")
  if step > 0:
    end = min(k * (stop_inclusive + step), shape[axis])
    return slice(k * start, end, step), False
  else:
    end = min(k * (start - step), shape[axis])
    return slice(k * stop_inclusive, end, -step), True


def _matrix_put(ndarray, idx, val, block_size=1):
  """Similar to numpy.put using LAX operations."""
  idx_i, idx_j = idx
  sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
  slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)
  if not sli.step == slj.step == 1:
    raise TypeError("Non-unit step not supported in assignment.")
  if row_rev or col_rev:
    # Flip only the row/column axes, never the leading batch axes.
    rev_dims = [val.ndim - 2 + a for a, r in enumerate((row_rev, col_rev)) if r]
    val = lax.rev(val, rev_dims)
  start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
  return lax.dynamic_update_slice(ndarray, val, start_indices)


def _matrix_take(ndarray, idx, block_size=1):
  """Similar to numpy.take using LAX operations."""
  idx_i, idx_j = idx
  sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
  slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)
  start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
  limit_indices = list(ndarray.shape[:-2]) + [sli.stop, slj.stop]
  strides = [1] * (ndarray.ndim - 2) + [sli.step, slj.step]
  out = lax.slice(ndarray, start_indices, limit_indices, strides)
  if row_rev or col_rev:
    # Flip only the row/column axes, never the leading batch axes.
    rev_dims = [out.ndim - 2 + a for a, r in enumerate((row_rev, col_rev)) if r]
    out = lax.rev(out, rev_dims)
  return out
```

I encourage you to take a look at the full module for context (it's quite short!). It implements blocked Cholesky and blocked triangular solves using that LapaxMatrix class.
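To see what `_matrix_take` boils down to, here is a small self-contained sketch of the two LAX calls it makes, using hand-computed slice bounds rather than the original helpers. For a 4x4 array with `block_size=2`, `_canonical_idx` maps the integer indices `(1, 0)` to `slice(2, 4, 1)` and `slice(0, 2, 1)`, which gives a single `lax.slice` call. A negative-step index instead comes back as a positive-step slice with a "reversed" flag, which `_matrix_take` handles with `lax.rev`:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(4, 4)

# Block (1, 0) with block_size=2: rows 2..3, columns 0..1, unit strides.
block = lax.slice(x, start_indices=(2, 0), limit_indices=(4, 2), strides=(1, 1))

# What a reversed row index would add on top: flip axis 0 of the result.
rows_reversed = lax.rev(block, dimensions=(0,))

print(block)
print(rows_reversed)
```

Both calls have static shapes, which is why the class restricts non-unit steps and resolves all indexing to slices up front.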
By the way, that LapaxMatrix class supports in-place updating, but that's dangerous in general because aliasing semantics won't be preserved across …
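A minimal sketch of the purely functional alternative, assuming a JAX version with the `.at[...]` indexed-update property. `set_top_left_block` is a hypothetical name. The update returns a new array and leaves its input untouched, so no caller can observe a difference between aliased and non-aliased buffers:

```python
import jax
import jax.numpy as jnp


@jax.jit
def set_top_left_block(x, value):
    # Functional update: returns a new array; x itself is never modified.
    return x.at[0:2, 0:2].set(value)


x = jnp.zeros((4, 4))
y = set_top_left_block(x, jnp.ones((2, 2)))
print(x.sum(), y.sum())
```

Under `jit`, XLA may still perform such an update in place when it can prove the input is dead, but that is an optimization the compiler owns rather than a semantic the user relies on.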
Thanks for this!
Sure. I'm working in a situation where we frequently have objects that are best thought of as block arrays with heterogeneous dimensions. A common example is taking horizontal and vertical finite differences of an image; we take an image of size … Subclassing … Importantly, by subclassing …
Now, is there a way to register this type such that …? Thanks again for your help.
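A sketch of how the finite-difference case could be represented without subclassing, assuming the blocks are simply kept in a tuple (tuples are already pytrees, so no registration is needed). The function names are hypothetical. For an `(N, M)` image the horizontal and vertical differences have shapes `(N, M - 1)` and `(N - 1, M)`, so they cannot be stacked into one rectangular array, but `jax.tree_util.tree_map` applies the same operation to every block:

```python
import jax
import jax.numpy as jnp


def finite_differences(img):
    """Return (horizontal, vertical) forward differences of a 2-D image."""
    dx = jnp.diff(img, axis=1)  # shape (N, M - 1)
    dy = jnp.diff(img, axis=0)  # shape (N - 1, M)
    return dx, dy


@jax.jit
def squared_gradient_norm(img):
    diffs = finite_differences(img)
    # tree_map applies the same elementwise op to each block of the tuple.
    squared = jax.tree_util.tree_map(jnp.square, diffs)
    return sum(jnp.sum(b) for b in jax.tree_util.tree_leaves(squared))


img = jnp.arange(12.0).reshape(3, 4)
dx, dy = finite_differences(img)
print(dx.shape, dy.shape, squared_gradient_norm(img))
```

The same tuple can be passed through `jax.grad` or returned from a jitted function unchanged, which covers much of what the subclass was meant to provide.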
@lukepfister Were you able to figure out a way to register your custom array type with …?
@KeAWang Nope! We wound up wrapping jax.numpy in a new module that either called …
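The comment above is cut off, so the exact design of that wrapper module isn't recoverable here. One possible shape for such a wrapper, assuming it dispatches on whether the argument is a single array or a tuple of blocks, is sketched below; `_blockwise`, `absolute`, and `exp` are hypothetical names, not part of `jax.numpy`:

```python
import jax.numpy as jnp


def _blockwise(fn):
    """Wrap a jax.numpy function so it also accepts a tuple of arrays.

    Hypothetical sketch: a plain array is passed straight to fn; a tuple is
    treated as a block array and fn is applied to each block separately.
    """
    def wrapped(x, *args, **kwargs):
        if isinstance(x, tuple):
            return tuple(fn(block, *args, **kwargs) for block in x)
        return fn(x, *args, **kwargs)
    wrapped.__name__ = getattr(fn, "__name__", "wrapped")
    return wrapped


# The wrapper module would re-export wrapped versions of the functions it needs.
absolute = _blockwise(jnp.absolute)
exp = _blockwise(jnp.exp)

blocks = (jnp.array([-1.0, 2.0]), jnp.array([[-3.0], [4.0]]))
print(absolute(blocks))
print(absolute(jnp.array([-5.0])))
```

This only covers elementwise functions of a single argument; reductions or binary operations across blocks would each need their own dispatch rule.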
I'm in the process of converting a NumPy "block matrix" class to work in JAX. The block matrix class is a subclass of `np.ndarray` and essentially contains a single contiguous array with some extra indexing capabilities to access various blocks. To make this work in JAX, I thought I could subclass `DeviceArray` and add the extra attributes / methods. I'm having a hard time doing so, though. Two questions:
… `DeviceArray` or `lax.ndarray`?
… `DeviceArray`? I saw issue "add attribute to devicearray?" #2464, but the issue remains open and I get the same error as the author.
Thanks for your work on the library; it is a joy to use!