<a href="https://colab.research.google.com/github/jecampagne/JaxTutos/blob/main/JAX_PyTree_initialisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
from jax import lax, grad, jit, vmap
from jax.tree_util import register_pytree_node_class

import jax.numpy as jnp

# Topic: more on PyTrees

 - When are tree_flatten/tree_unflatten triggered?
 - An initialisation that crashes: why, and a solution
 - A more advanced example with different kinds of variable usage

## When tree_flatten/tree_unflatten are triggered <=> stateful initialisation

In [5]:
@register_pytree_node_class
class State:
  def __init__(self, n: int):
    self.n = n
    self.stateful()
    jax.debug.print('   n = {} # init', self.n)  # prints the runtime value, not the traced Array placeholder

  def stateful(self):
    self.n += 100

  def tree_flatten(self):
    jax.debug.print('   n = {} # flatten', self.n + 10)
    return (self.n,), {}

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    n, = children
    jax.debug.print('   n = {} # unflatten', n + 1)
    return cls(n)  # runs `__init__` and in turn `stateful`

As a reminder, `jax.lax.scan` is roughly equivalent to
```python
def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
```
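A minimal runnable illustration of these semantics, assuming a running sum as the carry (the names `step`, `final` and `ys` are only illustrative): the carry is threaded through the iterations and each per-step output is stacked into `ys`.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
  # the new carry is the running sum; it is also emitted as the per-step output
  carry = carry + x
  return carry, carry

# scan over xs = [1, 2, 3, 4], starting from a carry of 0
final, ys = lax.scan(step, jnp.array(0), jnp.arange(1, 5))
print(final, ys)
```

Here `final` holds the last carry and `ys` the stacked per-step outputs.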

In [6]:
def body_fun(carry, i):
  # unflatten
  # - __init__
  jax.debug.print('---')
  jax.debug.print('{}: n = {} # body', i, carry.n)
  # flatten
  return carry, carry.n   # carry.n is the "y"

s = State(0)   # this is the carry
xs = jnp.arange(3, dtype=jnp.int16)
print("xs:", xs)
print('xxx')
_, y = lax.scan(body_fun, s, xs)   # scan will first compile body_fun
print("y: ",y)

   n = 100 # init
xs: [0 1 2]
xxx
   n = 110 # flatten
   n = 110 # flatten
   n = 101 # unflatten
   n = 200 # init
---
0: n = 200 # body
   n = 210 # flatten
   n = 201 # unflatten
   n = 300 # init
---
1: n = 300 # body
   n = 310 # flatten
   n = 301 # unflatten
   n = 400 # init
---
2: n = 400 # body
   n = 410 # flatten
   n = 401 # unflatten
   n = 500 # init
y:  [200 300 400]


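The output shows two things. The values stored in `y` are the ones seen inside the body (200, 300, 400), not the 100 obtained at construction. And every `tree_unflatten` calls `__init__`, which runs `stateful` again, so the state of the object is silently changed (+100) before the body runs, and again at each iteration of the scan.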
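One possible way to avoid this, sketched here under the assumption that the stateful step should only run when the user constructs the object, is to have `tree_unflatten` bypass `__init__` with `object.__new__`. The class name `SafeState` is hypothetical and not part of the original notebook.

```python
import jax.numpy as jnp
from jax import lax
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class SafeState:
  def __init__(self, n):
    # stateful step: runs only when the user builds the object
    self.n = n + 100

  def tree_flatten(self):
    return (self.n,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # rebuild without calling __init__, so the state is not modified again
    obj = object.__new__(cls)
    (obj.n,) = children
    return obj

def body_fun(carry, i):
  return carry, carry.n

s = SafeState(0)                       # n is 100 after construction
final, ys = lax.scan(body_fun, s, jnp.arange(3))
print(ys)
```

With this variant the carry is no longer altered by the flatten/unflatten round trips that `scan` performs.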

## Initialisation of a PyTree

### Let us start with a crash

In [20]:
class BaseObj():
    def __init__(self, *, gsparams=None, **params):
      self._params = params      # Dictionary containing all traced parameters
      self._gsparams = gsparams  # Non-traced static parameters
      jax.debug.print('Init BaseObj') 


    @property
    def gsparams(self):
        """A `GSParams` object that sets various parameters relevant for speed/accuracy trade-offs."""
        return self._gsparams

    @property
    def params(self):
        """A Dictionary object containing all parameters of the internal represention of this object."""
        return self._params


    def tree_flatten(self):
        """This function flattens the BaseObj into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        jax.debug.print('tree_flatten') 
        # Define the children nodes of the PyTree that need tracing
        children = (self.params,)
        # Define auxiliary static data that doesn’t need to be traced
        aux_data = {"gsparams": self.gsparams}
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        jax.debug.print('tree_unflatten') 
        return cls(**(children[0]), **aux_data)

@register_pytree_node_class
class T(BaseObj):
  def __init__(self, a, gsparams=None):
    super().__init__(a=a, gsparams=gsparams)

    jax.debug.print("Init T: a={}, type {}",a, type(a)) 
    self._m = a*2     # <-- problem when this line is active


  @property
  def a(self):
    return self.params["a"]
  
  def f(self,x):
    return self.a*x + self._m

If I instantiate a batch of `T` objects and apply a `vmap`:

In [18]:
avals = jnp.pi * jnp.array([1.,2.])
ts = T(avals)

Init T: a=[3.1415927 6.2831855], type <class 'jaxlib.xla_extension.Array'>
Init BaseObj


In [19]:
jax.vmap(lambda t: t.f(1.))(ts)

tree_flatten
tree_unflatten
Init T: a=<object object at 0x7fe510245e20>, type <class 'object'>


TypeError: ignored

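As the output shows, inside `vmap` the object is rebuilt through `tree_unflatten` with an `a` that is a bare `object` placeholder (in other transformations it can be a traced value), not a concrete array. Computing `a*2` in `__init__` therefore fails with a `TypeError`. Note that the crash does not depend on where the `super().__init__` call is placed.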
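The failure can be reproduced without `vmap`, assuming (as the `type <class 'object'>` output above suggests) that the tree may be rebuilt with placeholder leaves. The sketch below does that rebuild by hand; the class `Fragile` is a hypothetical, stripped-down stand-in for `T`.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten

@register_pytree_node_class
class Fragile:
  def __init__(self, a):
    self.a = a
    self.m = a * 2        # arithmetic in __init__: fails on a non-array leaf

  def tree_flatten(self):
    return (self.a,), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)  # re-runs __init__

leaves, treedef = tree_flatten(Fragile(jnp.array([1., 2.])))

# rebuilding with real leaves works
rebuilt = tree_unflatten(treedef, leaves)

# rebuilding with placeholder leaves hits the arithmetic in __init__
try:
  tree_unflatten(treedef, [object()] * treedef.num_leaves)
  crashed = False
except TypeError:
  crashed = True
print("crashed:", crashed)
```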

### One simple solution in that case

In [21]:
@register_pytree_node_class
class T(BaseObj):
  def __init__(self, a, gsparams=None):
    super().__init__(a=a, gsparams=gsparams)

    jax.debug.print("Init T: a={}",a) 


  @property
  def _m(self):                #<-----
    return self.a * 2

  @property
  def a(self):
    return self.params["a"]
  
  @jit
  def f(self,x):
    return self.a*x + self._m

In [22]:
avals = jnp.pi * jnp.array([1.,2.])
ts = T(avals)

Init BaseObj
Init T: a=[3.1415927 6.2831855]


In [23]:
jax.vmap(lambda t: t.f(1.))(ts)

tree_flatten
tree_unflatten
Init BaseObj
Init T: a=<object object at 0x7fe510245f60>
tree_flatten
tree_unflatten
Init BaseObj
Init T: a=3.1415927410125732
Init T: a=6.2831854820251465
tree_flatten
tree_flatten
tree_unflatten
Init BaseObj
Init T: a=3.1415927410125732
Init T: a=6.2831854820251465


Array([ 9.424778, 18.849556], dtype=float32)

In [24]:
jax.vmap(lambda t: t.f(2.))(ts)

tree_flatten
tree_unflatten
Init BaseObj
Init T: a=<object object at 0x7fe4ffa20940>
tree_flatten
tree_unflatten
Init BaseObj
Init T: a=3.1415927410125732
Init T: a=6.2831854820251465
tree_flatten
tree_flatten
tree_unflatten
Init BaseObj
Init T: a=3.1415927410125732
Init T: a=6.2831854820251465


Array([12.566371, 25.132742], dtype=float32)

It works as expected, but it may not be convenient if the computation of `_m` in the above example is more expensive or more involved: the property is re-evaluated at each access, and we do not want to repeat the computation if, for instance, another function needs it.
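One alternative, given here only as a sketch, is to compute the derived value once in a dedicated constructor and store it as an additional traced child, keeping `__init__` as plain storage. The names `TCached` and `create` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class TCached:
  def __init__(self, a, m):
    # plain storage only: safe whatever the kind of leaf passed in
    self.a = a
    self.m = m

  @classmethod
  def create(cls, a):
    # user-facing constructor: the derived value is computed once here
    a = jnp.asarray(a)
    return cls(a, a * 2)

  def tree_flatten(self):
    return (self.a, self.m), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

  def f(self, x):
    return self.a * x + self.m

ts = TCached.create(jnp.pi * jnp.array([1., 2.]))
out = jax.vmap(lambda t: t.f(1.))(ts)
print(out)
```

The trade-off is that `m` is now carried through every transformation as a leaf, and the user must go through `create` rather than `__init__`.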

## A more complete example with different kinds of variables

The classes below are purely academic: they show the initialisation and use of several kinds of variables that you may encounter in your own use case.

In [32]:
from functools import partial

from typing import Optional
import jax.tree_util as tu
from jax import Array, jit, lax, vmap
from jax.typing import ArrayLike

import numpy as np
import numpy.typing as npt
rng = np.random.default_rng(0)


### A Numpy class

In [47]:
class ExprNp:
  def __init__(self, name: str, steps: int, const: npt.NDArray,
               *, start: int = 0):
    self.name = name  # metadata
    self.steps = steps  # constant, static
    self.const = const  # constant, array
    # state, initialized in __init__, modified throughout
    self._start = start
    self._state = np.zeros_like(const, dtype=np.float32)
    # value, depends on constant, initialized in __init__, never modified again
    self._y = self._init_y(const)
    # state, depends on state, initialized in __init__, modified throughout
    self._sum = np.zeros(steps)

  @property
  def state(self):
    return self._state, self._sum

  @staticmethod
  def _init_y(c: npt.NDArray):
    return np.sin(c)

  def _step(self, i):
    self._state += self._y
    self._sum[i] += 1

  def main(self, steps: Optional[int] = None, *, start: Optional[int] = None):
    steps = steps or self.steps
    start = start or self._start
    end = start + steps
    for i in range(start, end):
      self._step(i)
    self._start = end
    return self

### The JAX class

In [169]:
a = None or jnp.array([0.]) 
print(a)

[0.]


In [155]:
@register_pytree_node_class
class ExprJax:
  def __init__(self, name: str, steps: int, const: Array, start: Array,
               _state: Array, _sum: Array):
    #
    # __init__ only stores values; users should build objects with `ExprJax.init`
    #

    # init in __init__, not changed after
    self.name = name    # metadata here a string
    self.steps = steps  # constant, static
    self.const = const  # constant, array
    # state, initialized in __init__, modified throughout
    self._start = start 
    self._state = _state
    # state, depends on state, initialized in __init__, modified throughout
    self._sum = _sum

    # self._y plays the same role as self._m in the T class above (see the @property below)

  # user initialisation
  @classmethod
  def init(cls, name: str, steps: int, const: ArrayLike, *, start: int = 0):
    _state = jnp.zeros_like(const, dtype=jnp.float32)
    _start = jnp.full(_state.shape[:-1], start)
    _sum = jnp.zeros(_state.shape[:-1] + (steps,))
    return cls(name, steps, jnp.array(const), _start, _state=_state, _sum=_sum)

  @property
  def state(self):
    return self._state, self._sum

  def tree_flatten(self):
    traced = (self.const, self._start, self._state, self._sum) #traced Arrays
    aux_data = (self.name, self.steps)
    # print("tree_flatten: ", aux_data, " | ", traced)  # uncomment to see which objects are traced
    return traced, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, traced):
    # print("tree_unflatten: ", aux_data, " | ", traced)  # uncomment to see which objects are traced
    return cls(*aux_data, *traced)

  @property
  def _y(self):
    return jnp.sin(self.const)

  def _step(self, i: Array):
    self._state += self._y
    self._sum = self._sum.at[i].add(1)

  @partial(jit, static_argnums=(1,))   # disable jit if you want to debug with plain `print`
  def main(self, steps: Optional[int] = None, *, start: Array = jnp.array(0)):
    assert self._state.ndim == 1
    steps = steps or self.steps
    start = start or self._start

    def body(i: int, self: ExprJax):
      jax.debug.print("body i:{}",i) 
      self._step(start + i)
      return self

    # We could do better with scan, but let us keep it simple.
    # The JIT is triggered implicitly in cascade: body, then _step
    self = lax.fori_loop(0, steps, body, self)    
    self._start = start + steps

    return self
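To check which attributes end up as traced leaves and which as static auxiliary data, one can flatten an instance by hand. The sketch below uses a hypothetical reduced class `Mini` with the same split as above: arrays as children, name and step count as `aux_data`.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten

@register_pytree_node_class
class Mini:
  def __init__(self, name, steps, const):
    self.name = name     # metadata (static)
    self.steps = steps   # constant (static)
    self.const = const   # array (traced)

  def tree_flatten(self):
    # children are traced; aux_data is static and must be hashable
    return (self.const,), (self.name, self.steps)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*aux_data, *children)

leaves, treedef = tree_flatten(Mini('sim', 3, jnp.zeros(5)))
print(leaves)    # only the array appears as a leaf
print(treedef)   # the static data lives in the tree definition

rebuilt = tree_unflatten(treedef, leaves)
```

Because `aux_data` is part of the tree definition, a jitted function is retraced when it changes, whereas changing the values of the leaves does not trigger a retrace.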

In [156]:
N=10
nsteps= 3
starters = np.arange(N).reshape(2, -1)
starters

array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])

In [157]:
res_np = [ExprNp('sim_np', N, s).main(nsteps).state for s in starters]
res_np

[(array([ 0.        ,  2.5244129 ,  2.7278922 ,  0.42336002, -2.2704074 ],
        dtype=float32), array([1., 1., 1., 0., 0., 0., 0., 0., 0., 0.])),
 (array([-2.8767729 , -0.83824646,  1.9709598 ,  2.9680748 ,  1.2363554 ],
        dtype=float32), array([1., 1., 1., 0., 0., 0., 0., 0., 0., 0.]))]

In [158]:
#just to see what is computed
y = np.sin(starters)*nsteps
y

array([[ 0.        ,  2.52441295,  2.72789228,  0.42336002, -2.27040749],
       [-2.87677282, -0.83824649,  1.9709598 ,  2.96807474,  1.23635546]])

In [159]:
[s for s in starters]

[array([0, 1, 2, 3, 4]), array([5, 6, 7, 8, 9])]

In [160]:
# as a list comprehension
res_jax = [ExprJax.init('sim_jax',N, s).main(nsteps).state for s in starters] 
res_jax

body i:0
body i:1
body i:2
body i:0
body i:1
body i:2


[(Array([ 0.       ,  2.5244129,  2.7278922,  0.42336  , -2.2704074],      dtype=float32),
  Array([1., 1., 1., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 (Array([-2.8767729 , -0.83824646,  1.9709598 ,  2.9680748 ,  1.2363555 ],      dtype=float32),
  Array([1., 1., 1., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))]

In [161]:
# as a vmap on ExprJax objects
sim_jax = ExprJax.init('sim_jax', N, starters)
print("sim_jax: ", sim_jax)
res_mapped = jit(vmap(partial(ExprJax.main, steps=nsteps)))(sim_jax)
res_mapped.state

sim_jax:  <__main__.ExprJax object at 0x7fe4fef20fa0>
body i:0
body i:1
body i:2


(Array([[ 0.        ,  2.5244129 ,  2.7278922 ,  0.42336   , -2.2704074 ],
        [-2.8767729 , -0.83824646,  1.9709598 ,  2.9680748 ,  1.2363555 ]],      dtype=float32),
 Array([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32))

In [162]:
# double the number of steps to show the results and compare with the code further down
sim_jax = ExprJax.init('sim_jax', N, starters)
res_mapped = jit(vmap(partial(ExprJax.main, steps=2*nsteps)))(sim_jax)
res_mapped.state

body i:0
body i:1
body i:2
body i:3
body i:4
body i:5


(Array([[ 0.        ,  5.0488253 ,  5.4557843 ,  0.84672004, -4.540815  ],
        [-5.7535458 , -1.6764929 ,  3.9419198 ,  5.9361496 ,  2.4727108 ]],      dtype=float32),
 Array([[1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]], dtype=float32))

In [163]:
run3 = jit(vmap(partial(ExprJax.main, steps=3)))    #3 steps
run6 = jit(vmap(partial(ExprJax.main, steps=6)))    #6 steps

sim_jax33 = run3(run3(ExprJax.init('sim_jax', N, starters)))  # apply the 3-step function twice
sim_jax6 = run6(ExprJax.init('sim_jax', N, starters))         # apply the 6-step function once


body i:0
body i:1
body i:2
body i:0
body i:1
body i:2
body i:0
body i:1
body i:2
body i:3
body i:4
body i:5


In [164]:
sim_jax33.state

(Array([[ 0.        ,  5.0488253 ,  5.4557843 ,  0.84672004, -4.540815  ],
        [-5.7535458 , -1.6764929 ,  3.9419198 ,  5.9361496 ,  2.4727108 ]],      dtype=float32),
 Array([[1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]], dtype=float32))

In [165]:
sim_jax6.state

(Array([[ 0.        ,  5.0488253 ,  5.4557843 ,  0.84672004, -4.540815  ],
        [-5.7535458 , -1.6764929 ,  3.9419198 ,  5.9361496 ,  2.4727108 ]],      dtype=float32),
 Array([[1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.]], dtype=float32))

# Takeaway

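- A simple NumPy-style initialisation, with state computed inside `__init__`, cannot be used as is in a JAX PyTree, because `tree_unflatten` re-runs `__init__`, possibly with traced or placeholder values. Reminder: see the JAX_JIT_in_class notebook for why PyTrees should be considered.
- The initialisation should be adapted to how each kind of variable is used later: static metadata goes in `aux_data`, arrays go in the traced children, and derived values are computed either lazily (property) or in a dedicated constructor.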
