Delayed initialisation of static fields #3
Thanks for the details + example! In general, having a static field depend on the value of a dynamic array is unlikely to be supported. For function transforms to work, it needs to be possible to instantiate each pytree type using only abstract tracer arrays; the error you're seeing is because the behavior of …

While for your particular case the expected behavior is pretty clear, a more concrete example of why this should be avoided might be:
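The underlying restriction can be reproduced without any dataclass. This is a minimal sketch that assumes only `jax` and `jax.numpy`, and the function name `zeros_of_sum` is hypothetical. Deriving a Python `int` from an array works in an eager call. Under `jax.jit`, though, the array is an abstract tracer that carries only a shape and dtype, so the conversion fails while tracing.

```python
import jax
import jax.numpy as jnp


def zeros_of_sum(a):
    # int(...) needs a concrete value. That is fine for an ordinary array,
    # but under jax.jit `a` is an abstract tracer carrying only shape/dtype.
    n = int(a.sum())
    return jnp.zeros(n)


a = jnp.arange(4)
print(zeros_of_sum(a).shape)  # (6,): eager call, `a` is concrete

try:
    jax.jit(zeros_of_sum)(a)
    jit_failed = False
except (jax.errors.ConcretizationTypeError,
        jax.errors.TracerIntegerConversionError):
    # Raised while tracing: the sum of an abstract array has no known value.
    jit_failed = True

print(jit_failed)  # True
```

A static field computed from `a` hits the same wall. When the pytree is rebuilt from tracers, there is no concrete sum to store.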
Note that …

For workarounds: the best off the top of my head is just to pass the value of …
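Here is a minimal sketch of the separate-input workaround. It assumes the derived value is computed once from a concrete array and then passed to the jitted function as a static argument via `static_argnums`. The function name `zeros_like_sum` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def zeros_like_sum(a, n):
    # `n` is a static (hashable Python) argument, so it is concrete at trace
    # time and can determine an output shape. `a` stays a traced array.
    return jnp.zeros(n, dtype=a.dtype)


a = jnp.arange(4)
n = int(a.sum())  # computed once, outside of jit, from a concrete array
out = zeros_like_sum(a, n)
print(out.shape)  # (6,)
```

Because `n` is static, each distinct value triggers a new trace and compilation. That is usually acceptable when the value changes rarely.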
Surely I could pass it as a separate input, but I would like to avoid that to keep the data consistent. Moving the computation out of …

Could you please specify how you would move the computation out of the …?
It doesn't seem like there are any perfect options here, but two possible suggestions:
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc


@jdc.pytree_dataclass
class PyTreeDataclass:
    a: jnp.ndarray
    _sum: int = jdc.static_field()

    @staticmethod
    def setup(a: jnp.ndarray) -> "PyTreeDataclass":
        # Note that `a` must be a concrete array. This method cannot be called
        # inside a JIT-compiled function.
        return PyTreeDataclass(a, _sum=a.sum().item())


def print_pytree(obj):
    print(obj._sum)


obj = PyTreeDataclass.setup(jnp.arange(4))
print_pytree(obj)
jax.jit(print_pytree)(obj)
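For comparison, the same `setup` pattern can be sketched with plain `jax.tree_util` instead of `jax_dataclasses`. The class `ArrayWithSum` and its methods are hypothetical. The constructor never derives the static value. Instead, `setup` computes it once, and flattening carries it through as auxiliary data. As a result, rebuilding the object from tracers never needs a concrete array.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ArrayWithSum:
    """Plain-JAX analogue of the `setup` pattern (hypothetical class)."""

    def __init__(self, a, total: int):
        # No derivation here: `total` is passed in, so the constructor also
        # works when `a` is an abstract tracer.
        self.a = a
        self.total = total

    @classmethod
    def setup(cls, a):
        # Must be called with a concrete array, outside of jax.jit.
        return cls(a, int(a.sum()))

    def tree_flatten(self):
        # `a` is a dynamic leaf; `total` is static (hashable) auxiliary data.
        return (self.a,), self.total

    @classmethod
    def tree_unflatten(cls, total, children):
        # Restore the stored static value instead of recomputing it.
        return cls(*children, total)


@jax.jit
def filled(obj):
    # `obj.total` is a plain Python int here, so it can define a shape.
    return jnp.zeros(obj.total) + obj.a.sum()


obj = ArrayWithSum.setup(jnp.arange(4))
print(filled(obj).shape)  # (6,)
```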
from typing import Tuple

@jdc.pytree_dataclass
class PyTreeDataclass:
    a: Tuple[int, ...] = jdc.static_field()

    @property
    def _sum(self) -> int:
        return sum(self.a)

or you could do something hacky to wrap the array and make it hashable:

import dataclasses
import numpy as onp

@dataclasses.dataclass
class HashableArrayWrapper:
    inner: onp.ndarray  # Needs to be initialized with a regular numpy array.

    def __hash__(self) -> int:
        return hash(str(self.inner.tolist()))

    def __eq__(self, other) -> bool:
        # Compare by value (the generated __eq__ can raise on multi-element arrays).
        return isinstance(other, HashableArrayWrapper) and bool(onp.array_equal(self.inner, other.inner))

@jdc.pytree_dataclass
class PyTreeDataclass:
    a: HashableArrayWrapper = jdc.static_field()

    @property
    def _sum(self) -> int:
        return self.a.inner.sum().item()
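The `jax.jit` documentation asks that static arguments be hashable, with both `__hash__` and `__eq__` implemented, because JAX uses them to decide whether a cached trace can be reused. The sketch below passes the wrapper directly as a static argument to `jax.jit` instead of going through `jax_dataclasses`. The wrapper is repeated so that the block stands alone, with a value-based `__eq__` so that equal-valued wrappers compare equal. The counter `trace_count` and the function `zeros_of_sum` are hypothetical. The counter is a Python side effect, so it only runs while JAX traces the function.

```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp
import numpy as onp


@dataclasses.dataclass
class HashableArrayWrapper:
    inner: onp.ndarray  # regular (concrete) numpy array

    def __hash__(self) -> int:
        return hash(str(self.inner.tolist()))

    def __eq__(self, other) -> bool:
        # Value-based equality, so a new but equal wrapper hits the cache.
        return isinstance(other, HashableArrayWrapper) and bool(
            onp.array_equal(self.inner, other.inner)
        )


trace_count = 0


@partial(jax.jit, static_argnums=0)
def zeros_of_sum(w: HashableArrayWrapper):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    # `w.inner` is a concrete numpy array, so its sum can define a shape.
    return jnp.zeros(int(w.inner.sum()))


zeros_of_sum(HashableArrayWrapper(onp.arange(4)))
zeros_of_sum(HashableArrayWrapper(onp.arange(4)))  # equal value: cache hit
zeros_of_sum(HashableArrayWrapper(onp.arange(5)))  # new value: retraced
print(trace_count)  # 2
```

The flip side is that every distinct array value produces its own trace. This approach therefore only makes sense for arrays that are small and rarely change.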
Thank you very much! They are very good options!
First of all, thank you for the amazing library!
I have recently discovered `jax_dataclasses` and decided to port my messy functional JAX code to more organised, object-oriented code based on `jax_dataclasses`.

In my application, some quantities derived from the attributes of the dataclass are static values used to determine the shape of tensors during JIT compilation. I would like to include them as attributes of the dataclass, but I'm getting an error and I would like to know if there is a workaround.
Here is a simple example, where the attribute `_sum` is a derived static field that depends on the constant value of the array `a`. The non-jitted version works, but when `print_pytree` is jitted I get the following error.

Is there a way to compute, in `__post_init__`, the value of static fields that are not initialized in `__init__` and that depend on `jnp.ndarray` attributes of the dataclass?