In [5]:
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node_class
import jax.numpy as jnp

In [7]:
class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)

In [8]:
def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print(f"{structured=}\n  {flat=}\n  {tree=}\n  {unflattened=}")

In [None]:
from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

show_example(RegisteredSpecial(1., 2.))

In [9]:
@register_pytree_node_class
class RegisteredSpecial2(Special):
  def __repr__(self):
    return "RegisteredSpecial2(x={}, y={})".format(self.x, self.y)

  # 여기에 classmethod를 사용하면 flatten시에 class 변수가 flatten되므로 그러면 안된다.
  def tree_flatten(self):
    children = (self.x, self.y)
    aux_data = None
    return (children, aux_data)

  # instance에 종속된 값이 없고 앞으로도 그럴 일이 없기 때문에 classmethod사용
  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

show_example(RegisteredSpecial2(1., 2.))

structured=RegisteredSpecial2(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *]))
  unflattened=RegisteredSpecial2(x=1.0, y=2.0)


In [10]:
@register_pytree_node_class
class RegisteredSpecial3(Special):
  def __repr__(self):
    return "RegisteredSpecial3(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    children = (self.x, self.y)
    aux_data = None
    return (children, aux_data)

  # classmethod를 사용하지 않으면 어떻게 될까?
  def tree_unflatten(self, aux_data, children):
    return self(*children)

show_example(RegisteredSpecial2(1., 2.))

structured=RegisteredSpecial2(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(RegisteredSpecial2[None], [*, *]))
  unflattened=RegisteredSpecial2(x=1.0, y=2.0)
