In [2]:
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

In [5]:
jnp.arange(0.0, 5.0)  # explicit float start/stop: [0., 1., 2., 3., 4.] in the default float dtype

Array([0., 1., 2., 3., 4.], dtype=float32)

In [6]:
jnp.asarray(jnp.arange(0.0, 5.0))  # asarray on an existing JAX Array: same values and dtype

Array([0., 1., 2., 3., 4.], dtype=float32)

In [11]:
class MyTree:
  """Naive pytree container that eagerly converts its single child `a`.

  This is the problematic pattern: `__init__` always calls `jnp.asarray`,
  so it fails when JAX rebuilds the tree with non-array placeholder values
  (see the errors demonstrated below).
  """

  def __init__(self, a):
    self.a = jnp.asarray(a)  # rejects non-numeric values such as object()

# flatten: one child (`a`) and no aux data; unflatten: rebuild via __init__.
register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

# The failure is the point of this cell; catch and show it so that
# Restart & Run All continues to the solutions further down.
try:
  jax.vmap(lambda x: x)(tree)  # Error because object() is passed to `MyTree`.
except TypeError as err:
  print(f"TypeError: {err}")

TypeError: Value '<object object at 0x70435428bb20>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

In [12]:
# Expected to fail with the naive MyTree: jacobian's result nests a MyTree
# inside a MyTree, and the recorded error also shows an `object()` placeholder
# reaching `jnp.asarray`. Catch it so Restart & Run All keeps going.
try:
  jax.jacobian(lambda x: x)(tree)
except TypeError as err:
  print(f"TypeError: {err}")

  return array(a, dtype=dtype, copy=bool(copy), order=order)  # type: ignore


TypeError: Value '<object object at 0x70435428beb0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

### Potential solution 1:

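* The `__init__` and `__new__` methods of custom pytree classes should generally avoid any array conversion or other input validation. Otherwise they must anticipate and handle the special values JAX may pass when it rebuilds a pytree: `object()` placeholders, `None`, or instances of the class itself. For example: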

In [16]:
class MyTree:
  """Pytree container whose `__init__` tolerates placeholder values.

  Only real inputs are converted with `jnp.asarray`. `object()` sentinels,
  `None`, and nested `MyTree` instances are stored unchanged. These are
  values JAX may pass to the constructor when it unflattens the tree
  (e.g. inside `jax.vmap` or `jax.jacobian`).
  """

  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

# Redefining the class creates a new type object, so it must be registered
# again: the earlier registration only applies to the previous `MyTree`.
register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))
jax.vmap(lambda x: x)(tree).a  # now succeeds

Array([0., 1., 2., 3., 4.], dtype=float32)

### Potential solution 2:

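Structure your custom `tree_unflatten` function so that it avoids calling `__init__`, for example by creating the instance with `object.__new__` and setting its attributes directly. Then pass it to `register_pytree_node` as the unflatten function. If you choose this route, make sure that your `tree_unflatten` function stays in sync with `__init__` if and when the code is updated. Example: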

In [None]:
def tree_unflatten(aux_data, children):
  """Rebuild a MyTree from its flattened children without calling `__init__`.

  Intended to be passed as the unflatten function to `register_pytree_node`.
  Because `__init__` is bypassed, this must set exactly the attributes that
  `__init__` would set; keep the two in sync when either changes.

  Args:
    aux_data: Static auxiliary data produced by flattening (unused here).
    children: Sequence holding the single child value for `a`.

  Returns:
    A new `MyTree` whose `a` attribute is the sole child, stored as-is.
  """
  del aux_data  # Unused in this class.
  obj = object.__new__(MyTree)
  # The value comes from `children`; unpacking also checks there is exactly one.
  (obj.a,) = children
  return obj