Delayed initialisation of static fields #3

Closed
lucagrementieri opened this issue May 30, 2022 · 4 comments


@lucagrementieri

lucagrementieri commented May 30, 2022

First of all, thank you for the amazing library!
I have recently discovered jax_dataclasses and I have decided to port my messy JAX functional code to a more organised object-oriented code based on jax_dataclasses.

In my application, some quantities derived from the attributes of the dataclass are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and I would like to know if there is a workaround.

Here is a simple example, where the attribute _sum is a derived static field that depends on the constant value of the array a.

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc

@jdc.pytree_dataclass()
class PyTreeDataclass:
    a: jnp.ndarray
    _sum: int = jdc.static_field(init=False, repr=False)

    def __post_init__(self):
        object.__setattr__(self, "_sum", self.a.sum().item())

def print_pytree(obj):
    print(obj._sum)

obj = PyTreeDataclass(jnp.arange(4))
print_pytree(obj)
jax.jit(print_pytree)(obj)

The non-jitted version works, but when print_pytree is jitted I get the following error.

File "jax_dataclasses_issue.py", line 14, in __post_init__
    object.__setattr__(self, "_sum", self.a.sum().item())
AttributeError: 'bool' object has no attribute 'sum'

Is there a way to compute, in __post_init__, the value of static fields that are not initialized in __init__ and that depend on jnp.ndarray attributes of the dataclass?

@brentyi
Owner

brentyi commented May 30, 2022

Thanks for the details + example!

In general, having a static field depend on the value of a dynamic array is unlikely to be supported. For function transforms to work, it needs to be possible to instantiate each pytree type using only abstract tracer arrays; the error you're seeing is because the behavior of .sum().item() is not well-defined on tracers.

While for your particular case the expected behavior is pretty clear, a more concrete example for why this should be avoided might be:

@jax.jit
def zeros(a: jnp.ndarray) -> jnp.ndarray:
    container = PyTreeDataclass(a)
    return jnp.zeros(container._sum)   # `container._sum` should be an int

Note that a will be an abstract tracer during JIT compilation, and it won't make sense to compute _sum. The ultimate implication is that it becomes impossible to instantiate a PyTreeDataclass in any JIT-compiled function, which seems overly limiting. The whole zeros() function is of course also problematic because the shape of the output array depends on values in the input array, which isn't a pattern supported in XLA's static graph-based design.
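To make the tracer point concrete, here is a minimal plain-JAX sketch (no jax_dataclasses involved). The helper name concrete_sum is hypothetical, and the exact exception raised under jit depends on the JAX version and on where the tracer is encountered, so it may differ from the AttributeError reported above:

```python
import jax
import jax.numpy as jnp

def concrete_sum(a):
    # .item() needs a concrete value, which an abstract tracer cannot provide.
    return a.sum().item()

# Eager call: `a` is a concrete array, so this returns a Python int.
eager_result = concrete_sum(jnp.arange(4))
print(eager_result)

# Jitted call: `a` is a tracer while the function is being traced.
try:
    jax.jit(concrete_sum)(jnp.arange(4))
    failed_under_jit = False
except Exception as err:  # typically a concretization error in recent JAX
    failed_under_jit = True
    print(type(err).__name__)
```

The same function succeeds eagerly and fails under tracing, which is why a __post_init__ that calls .sum().item() cannot run inside a JIT-compiled function.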

For workarounds: the best I can think of off the top of my head is to pass the value of _sum as a separate input, or to move its computation out of the __post_init__, both of which I imagine you've thought of. I may have more thoughts if you can share more specifics on the application.
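As a rough illustration of the "separate input" idea in plain JAX (not the jax_dataclasses API), the value can be computed eagerly and then marked static for jit. The function name zeros and the argument name total are assumptions made for this sketch:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="total")
def zeros(a, total):
    # `total` is a hashable Python int treated as a compile-time constant,
    # so it is safe to use as an array shape.
    return jnp.zeros(total, dtype=a.dtype)

a = jnp.arange(4)
total = int(a.sum())  # computed outside of jit, on a concrete array
out = zeros(a, total)
print(out.shape)
```

The trade-off is the one raised in this thread: nothing enforces that total stays consistent with a, and each new value of total triggers a recompilation.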

@lucagrementieri
Author

lucagrementieri commented May 30, 2022

Surely I could pass it as a separate input, but I would like to avoid it to ensure the consistency of the data.

Moving the computation out of __post_init__ could be a good workaround for my use case, but I would like to keep it inside the dataclass, for example as a method. My problem is that if a jitted function needs the method, then its result is traced and I run into the same problem.

Could you please explain how you would move the computation out of the __post_init__ while keeping it inside the dataclass?

@brentyi
Owner

brentyi commented May 30, 2022

It doesn't seem like there are any perfect options here, but two possible suggestions:

  1. If _sum needs to be computed automatically, one option is to make a helper that does just that. This should run:
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc

@jdc.pytree_dataclass()
class PyTreeDataclass:
    a: jnp.ndarray
    _sum: int = jdc.static_field()

    @staticmethod
    def setup(a: jnp.ndarray) -> "PyTreeDataclass":
        # Note that `a` must be a concrete array. This method cannot be called
        # inside a JIT-compiled function.
        return PyTreeDataclass(a, _sum=a.sum().item())

def print_pytree(obj):
    print(obj._sum)

obj = PyTreeDataclass.setup(jnp.arange(4))
print_pytree(obj)
jax.jit(print_pytree)(obj)
  2. Or, perhaps a itself could be static? Arrays aren't hashable, so a tuple might work:
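from typing import Tuple

import jax_dataclasses as jdc

@jdc.pytree_dataclass
class PyTreeDataclass:
    a: Tuple[int, ...] = jdc.static_field()

    @property
    def _sum(self) -> int:
        # `a` is a static tuple of Python ints, so this is never traced.
        return sum(self.a)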

or you could do something hacky to wrap the array and make it hashable:

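import dataclasses

import jax_dataclasses as jdc
import numpy as onp

@dataclasses.dataclass
class HashableArrayWrapper:
    inner: onp.ndarray   # Needs to be initialized with a regular numpy array.

    def __hash__(self) -> int:
        return hash(str(self.inner.tolist()))

    def __eq__(self, other: object) -> bool:
        # Compare by content; the generated dataclass __eq__ is ambiguous for arrays.
        return isinstance(other, HashableArrayWrapper) and onp.array_equal(self.inner, other.inner)

@jdc.pytree_dataclass
class PyTreeDataclass:
    a: HashableArrayWrapper = jdc.static_field()

    @property
    def _sum(self) -> int:
        return self.a.inner.sum().item()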
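The hashability requirement behind the tuple variant of option 2 can be seen with plain jax.jit, independent of jax_dataclasses. This is only a sketch under that assumption; the function name zeros_from_values is hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def zeros_from_values(values):
    # `values` is an ordinary Python tuple here, not a tracer,
    # so sum() yields a concrete int that can be used as a shape.
    return jnp.zeros(sum(values))

out = zeros_from_values((0, 1, 2, 3))
print(out.shape)

# Static arguments must be hashable; JAX arrays are not.
try:
    hash(jnp.arange(4))
    array_is_hashable = True
except TypeError:
    array_is_hashable = False
print(array_is_hashable)
```

Because static values are part of the compilation cache key, each distinct tuple compiles a new version of the function, so this pattern fits best when the static data changes rarely.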

@lucagrementieri
Author

Thank you very much! They are very good options!
