<a href="https://colab.research.google.com/github/jecampagne/JaxTutos/blob/main/JAX_JIT_in_class.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import scipy as sc


import jax
import jax.numpy as jnp
import jax.scipy as jsc

from jax import grad, jit, vmap
from jax import jacfwd, jacrev, hessian
#from jax.ops import index, index_update
jax.config.update("jax_enable_x64", True)

from functools import partial
from typing import Union, Dict, Callable, Optional, Tuple, Any

In [2]:
jax.__version__

'0.4.26'

In [3]:
def u_print(idx:int, *args)->int:
    print(f"{idx}):",*args)
    idx +=1
    return idx


# Topic: JIT and Class methods

This notebook is a bit technical, but it deals with a real problem for anyone coming from the object-oriented world (e.g. C++ or Python).

It grew out of a JAX discussion thread that I started, which has since become part of the documentation ([see FAQ](https://jax.readthedocs.io/en/latest/faq.html#strategy-1-jit-compiled-helper-function)).

In [5]:
class A():
    def __init__(self, a: float):
        print("New A")
        self.a = a                   # a variable that will be "static" once the obj is created
        self.b = None                # a variable ("dynamic") updated by f

    def f(self, var: float) -> None:
        self.b = self.a * var

######

idp=0  # (just an index for printing)

objA = A(2.0)
idp = u_print(idp,objA.a, objA.b)
objA.f(10.)
idp = u_print(idp,objA.a, objA.b)
objA.f(11.)
idp = u_print(idp,objA.a, objA.b)


objA = A(3.0)
idp = u_print(idp,objA.a, objA.b)
objA.f(20.)
idp = u_print(idp,objA.a, objA.b)

New A
0): 2.0 None
1): 2.0 20.0
2): 2.0 22.0
New A
3): 3.0 None
4): 3.0 60.0


# And now use jit on f

In [6]:
class A():
    def __init__(self, a: float):
        self.a = a                   # a variable that will be "static" once the obj is created
        self.b = None                # a variable ("dynamic") updated by f

    @jit
    def f(self, var: float) -> None:
        self.b = self.a * var


idp=0

objA = A(2.0)
idp = u_print(idp,objA.a, objA.b)
objA.f(10.)


0): 2.0 None


TypeError: Cannot interpret value of type <class '__main__.A'> as an abstract array; it does not have a dtype attribute

The problem is that the first argument of the function is **self**, which has type `A`, and JAX does not know what to do with it: a jitted function expects arrays (or pytrees of arrays) as arguments, and `A` is neither.

There are several basic strategies for this case; we discuss them below.

# 1st strategy: external function ('helper')

In [19]:
jax.clear_caches()  # clear all compilations

@jit
def _f(a, var):
    print("compile...")
    res = a * var
    return res

class A():
    def __init__(self, a):
        self.a = a
        self.b = None

    def f(self, var):
        self.b = _f(self.a,var)



idp=0

print("New Obj A (1)")
objA = A(2.0)
idp = u_print(idp, objA.a, objA.b)
objA.f(10.0)
idp = u_print(idp, objA.a, objA.b)
objA.f(11.0)
idp = u_print(idp, objA.a, objA.b)


print("New Obj A (2)")
objA = A(3.0)
idp = u_print(idp, objA.a, objA.b)
objA.f(20.)
idp = u_print(idp, objA.a, objA.b)
objA.f(21.)
idp = u_print(idp, objA.a, objA.b)


print("New Obj A (3)")
objA = A(4.0)
idp = u_print(idp, objA.a, objA.b)
objA.f(20.)
idp = u_print(idp, objA.a, objA.b)

print("f is called with an array not a float => compile")
objA.f(jnp.array([20.]))
idp = u_print(idp, objA.a, objA.b)

objA.f(jnp.array([21.]))
idp = u_print(idp, objA.a, objA.b)

objA.f(21.)
idp = u_print(idp, objA.a, objA.b)

print("change internal objA.a")

objA.a = 400.  # change
objA.f(21.)
idp = u_print(idp, objA.a, objA.b)


New Obj A (1)
0): 2.0 None
compile...
1): 2.0 20.0
2): 2.0 22.0
New Obj A (2)
3): 3.0 None
4): 3.0 60.0
5): 3.0 63.0
New Obj A (3)
6): 4.0 None
7): 4.0 80.0
f is called with an array not a float => compile
compile...
8): 4.0 [80.]
9): 4.0 [84.]
10): 4.0 84.0
change internal objA.a
11): 400.0 8400.0


## Review of the helper method

- This is a simple implementation, and we don't have to instruct JAX how to handle the `A` class.

- `_f` is compiled only when the types/shapes of its arguments change (a float, then a shape-`(1,)` array); new values of `a`, even set directly with `objA.a = 400.`, are taken into account without recompiling.

- Writing one helper per jitted method is then a matter of taste. Encapsulation at the file (module) level at least keeps the helpers next to the definition of `A`.

- Open question: how do we handle the case where `f` needs another function (either a method of `A` or a function external to `A`)?
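One possible answer to the last question, sketched under the assumption that the extra function can be written as a pure, module-level function: pass it to the jitted helper as a static argument. Python functions are hashable, so `jit` keeps one compiled version per distinct callable. The names `_square` and `post` are illustrative, not part of the original notebook.

```python
from functools import partial

import jax


def _square(x):
    # Plain helper (not jitted itself): it is simply inlined when it is
    # traced inside the jitted function below.
    return x * x


@partial(jax.jit, static_argnums=(2,))
def _f(a, var, post):
    # 'post' is a Python callable marked as static: functions are hashable,
    # so jit keeps one compiled version per distinct callable.
    return post(a * var)


class A:
    def __init__(self, a):
        self.a = a
        self.b = None

    def f(self, var):
        # The method stays a thin, non-jitted wrapper around the helper.
        self.b = _f(self.a, var, _square)


objA = A(2.0)
objA.f(3.0)
print(objA.b)  # (2.0 * 3.0) ** 2
```

If `post` were instead a bound method such as `self.g` reading attributes of `self`, those attributes would be baked into the compiled code at tracing time, which brings back the pitfall discussed in the next strategy.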

# 2nd strategy: self as static


It is a classic recipe, often suggested in forums.
But...

In [20]:
jax.clear_caches()  # clear all compilations

class A():
    def __init__(self, a: float):
        print("new A")
        self.a = a
        self.b = None

    @partial(jit, static_argnums=(0,))   # "self" is static
    def f(self, var: float) -> None:
        print("compile...")
        self.b = self.a * var  # <- this is the culprit of the later crash in "g"

    def g(self):
        print("g...:",self.b)
        return self.b * self.b

idp=0

objA = A(2.0)
idp = u_print(idp, objA.a, objA.b)
objA.f(10.)
idp = u_print(idp, objA.a, objA.b)
res = objA.g()   # will cause a crash, but g is not the culprit
print(res)


new A
0): 2.0 None
compile...
1): 2.0 Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
g...: Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float64[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was f at <ipython-input-20-c8714ebe4d23>:9 traced for jit.
------------------------------
The leaked intermediate value was created on line <ipython-input-20-c8714ebe4d23>:12 (f). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3257 (run_cell_async)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3473 (run_ast_nodes)
/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py:3553 (run_code)
<ipython-input-20-c8714ebe4d23>:22 (<cell line: 22>)
<ipython-input-20-c8714ebe4d23>:12 (f)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

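The code crashes and the error seems to point to the `g` function, BUT the problem actually comes from the jitted `f` with `self` static: JIT traces `f` with an abstract `Traced<ShapedArray>` in place of `var`, so `self.b = self.a * var` stores that tracer in the object (see output `1)` above). The tracer then escapes the jitted scope and only blows up later, when `g` uses it.
N.B. the exact content of the traceback may change between JAX versions.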

So
```
Do not use `self.var =` in a jitted function with self as static.
```
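The same leak can be reproduced without any class. In this minimal sketch, a global list plays the role of `self.b` (the names `leaked` and `double` are illustrative). It shows that what gets stored during tracing is a tracer, not a number, and that the Python side effect runs only while tracing, not on later cached calls.

```python
import jax

leaked = []  # stands in for the object attribute 'self.b'


@jax.jit
def double(x):
    y = 2.0 * x
    leaked.append(y)  # side effect: stores the tracer, not a value
    return y          # the correct way: return the result


out = double(3.0)
print(type(leaked[0]).__name__)  # a tracer class, not a concrete array
print(out)                        # the returned value is a concrete array

double(4.0)                       # cache hit: the Python body is not re-run
print(len(leaked))                # still 1
```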

In [21]:
class A():
    def __init__(self, a):
        print("New obj A")
        self.a = a
        self.b = None

    def set_a(self,x):
        self.a = x
        print("new a:",self.a)

    def set_b(self,x):                  # (*) called from outside
        self.b = x

    @partial(jit, static_argnums=(0,))   #  "self" is static
    def f(self, var):
        print("compile...")
        return self.a * var             # see (*)

    def g(self):
        print("g...:",self.b)
        return self.b * self.b


idp=0

objA = A(2.0)
idp = u_print(idp, objA.a, objA.b)
b  = objA.f(10.)
objA.set_b(b)      #<----------------------
idp = u_print(idp, objA.a, objA.b)
res = objA.g()
print(res)


New obj A
0): 2.0 None
compile...
1): 2.0 20.0
g...: 20.0
400.0


In [25]:
objA.set_a(100.)   # <------ change "a"
idp = u_print(idp, objA.a, objA.b)
objA.f(10.)   # expected 100. * 10. = 1000.; if "a" is unchanged for "f": 2.0 * 10. = 20.

new a: 100.0
5): 100.0 20.0


Array(20., dtype=float64, weak_type=True)

## We have to be careful:
- the "static" character of "a" is a by-product of JIT's analysis of the code: the value of `a` seen at the first call is baked into the compiled code. This happens silently, without any warning, and can therefore mislead the user.

What is going on behind the scenes? `static_argnums` relies on the object's `__hash__` (and `__eq__`) to decide whether a static argument has changed between two calls, and the default `__hash__` of a user-defined class is based on the object's identity: it does not take the instance attributes into account.

This means that on the second call, JAX has no way of knowing that the instance attributes have changed, so it reuses the code compiled during the previous call, with the old value of `a`.
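A plain-Python sketch (no JAX needed; the class name `Plain` is illustrative) showing that the default `__hash__` and `__eq__` of a user-defined class ignore the attributes, so mutating one does not change what `jit` sees:

```python
class Plain:
    def __init__(self, a):
        self.a = a


obj = Plain(2.0)
h_before = hash(obj)
obj.a = 100.0                 # mutate the attribute
h_after = hash(obj)

print(h_before == h_after)    # True: the default hash is identity-based
print(obj == Plain(100.0))    # False: the default __eq__ is identity too
```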

For this reason, if you mark instances of your own classes as static, it is important to define appropriate `__hash__` and `__eq__` methods for the class. For example:

In [28]:
class A():
    def __init__(self, a):
        print("new A")
        self.a = a
        self.b = None

    def set_a(self,x):
        self.a = x
        print("new a:",self.a)


    def set_b(self,x):
        self.b = x

    @partial(jit, static_argnums=0)   #"self" is static
    def f(self, var):
        print("compile...")
        return self.a * var

    # custom hash method
    def __hash__(self):
        return hash((self.a,self.b))

    def __eq__(self, other):   # __eq__ is needed as well
        return (isinstance(other, A) and
            (self.a, self.b) == (other.a, other.b))



But this is not the end of the story...

In [29]:
idp=0

objA = A(2.0)
idp = u_print(idp, objA.a, objA.b)

b = objA.f(10.)         #  what is the type of 'b' ??? (*)
print("b:", b, type(b))
objA.set_b(b)
idp = u_print(idp, objA.a, objA.b)

objA.f(11.)

new A
0): 2.0 None
compile...
b: 20.0 <class 'jaxlib.xla_extension.ArrayImpl'>
1): 2.0 20.0


ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class '__main__.A'>, <__main__.A object at 0x79424d9cba00>. The error was:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
  File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
  File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
  File "<ipython-input-29-ac13534385f8>", line 11, in <cell line: 11>
  File "<ipython-input-28-7e9f3e1bd40e>", line 22, in __hash__
TypeError: unhashable type: 'jaxlib.xla_extension.DeviceList'


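The problem: the `b` returned by the jitted `f` is a JAX array (`ArrayImpl`, previously called `DeviceArray`), which is unhashable. Once it is stored in `self.b`, our `__hash__` (which hashes `self.b`) fails, so `self` can no longer be used as a static argument!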
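A quick check, as a sketch: JAX arrays refuse to be hashed (like NumPy arrays), whereas converting a scalar array to a Python `float` gives a hashable value. Such a conversion is only an option for scalars, and it forces a transfer from the device to the host.

```python
import jax.numpy as jnp

b = jnp.asarray(20.0)

try:
    hash(b)
    array_hashable = True
except TypeError:
    array_hashable = False

print(array_hashable)                 # False: JAX arrays are unhashable
print(hash(float(b)) == hash(20.0))   # True: a Python float is hashable
```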

In [30]:
class A():
    def __init__(self, a):
        print("Nouveau A")
        self.a = a
        self.b = None

    def set_a(self,x):
        self.a = x
        print("new a:",self.a)

    def set_b(self,x):
        self.b = x

    @partial(jit, static_argnums=0)
    def f(self, var):
        print("compile...")
        return self.a * var

    def __hash__(self):
        return hash(self.a)                   # do not include "b" (a JAX array is unhashable)

    def __eq__(self, other):
        return (isinstance(other, A) and
            self.a == other.a)                 # do not include "b"


In [31]:
idp=0

objA = A(2.0)
idp = u_print(idp, objA.a, objA.b)

b = objA.f(10.)         #  what is the type of 'b' ??? (*)
print("b:", b, type(b))
objA.set_b(b)
idp = u_print(idp, objA.a, objA.b)

objA.set_b(objA.f(11.))
idp = u_print(idp, objA.a, objA.b)


objA = A(3.0)
idp = u_print(idp, objA.a, objA.b)
objA.set_b(objA.f(20.))
idp = u_print(idp, objA.a, objA.b)

Nouveau A
0): 2.0 None
compile...
b: 20.0 <class 'jaxlib.xla_extension.ArrayImpl'>
1): 2.0 20.0
2): 2.0 22.0
Nouveau A
3): 3.0 None
compile...
4): 3.0 60.0


In [32]:
objA.set_a(40.)
objA.set_b(objA.f(11.))
idp = u_print(idp, objA.a, objA.b)   # 40.*11. = 440


new a: 40.0
compile...
5): 40.0 440.0


## Review self as static:

- We have to carefully define `__hash__` and `__eq__` (see the Python [doc](https://docs.python.org/3/reference/datamodel.html#object.__hash__)), keeping in mind that JAX arrays (`DeviceArray`) are not hashable and therefore cannot be part of the hash

- `a` can be changed deliberately (with or without a setter), and `b` is correctly updated; each new value of `a` triggers a recompilation

- So it can be done, but with care, and there is a simpler solution
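A sketch to make the recompilations visible without `print`: a global counter is incremented in the Python body of `f`, which only runs when JAX traces (i.e. compiles) the function. It assumes, as in the cells above, that `__hash__`/`__eq__` depend only on `a`; the counter `n_traces` is illustrative. Two distinct objects with equal `a` share the same compiled code.

```python
from functools import partial

import jax

n_traces = 0  # counts how many times f is traced (i.e. compiled)


class A:
    def __init__(self, a):
        self.a = a
        self.b = None

    @partial(jax.jit, static_argnums=0)
    def f(self, var):
        global n_traces
        n_traces += 1          # Python side effect: executed only while tracing
        return self.a * var

    def __hash__(self):
        return hash(self.a)    # "b" deliberately left out

    def __eq__(self, other):
        return isinstance(other, A) and self.a == other.a


obj1 = A(2.0)
obj1.f(10.0)
obj1.f(11.0)                   # same static self, same input type: cached
obj2 = A(3.0)
obj2.f(10.0)                   # new value of "a": retrace
obj3 = A(2.0)
r = obj3.f(10.0)               # equal to obj1 under __eq__: cache hit
print(n_traces, r)
```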


# 3rd strategy: PyTree

In [33]:
class A():
    def __init__(self, a, b=None):   # new signature
        #print("New A")
        self.a = a
        self.b = b


    def set_a(self,x):
        self.a = x
        print("new a:",self.a)

    def set_b(self, b):
        self.b = b

    @jit                              # <------ self is no longer static!
    def f(self, var):
        print("compile...")
        return self.a * var

    def g(self):
        return self.b * self.b

    #### PyTree methods....
    def tree_flatten(self):
        children = (self.b,)         # dynamic values
        aux_data = {'a': self.a}     # static values
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (b,) = children              # unpack the one-element tuple
        a = aux_data['a']
        return cls(a=a, b=b)         # must match the signature of __init__

# register explicitly, or use the decorator (as in the Cosmology class)
from jax import tree_util
tree_util.register_pytree_node(A,
                               A.tree_flatten,
                               A.tree_unflatten)

In [34]:
idp=0

objA = A(2.0)
idp = u_print(idp, objA.a, objA.b)

b = objA.f(10.)         #  what is the type of 'b' ??? (*)
print("b:", b, type(b))
objA.set_b(b)
idp = u_print(idp, objA.a, objA.b)

objA.set_b(objA.f(11.))
idp = u_print(idp, objA.a, objA.b)


objA = A(3.0)
idp = u_print(idp, objA.a, objA.b)
objA.set_b(objA.f(20.))
idp = u_print(idp, objA.a, objA.b)
objA.set_b(objA.f(30.))
idp = u_print(idp, objA.a, objA.b)


objA.set_a(30.)
objA.set_b(objA.f(30.))
idp = u_print(idp, objA.a, objA.b)



objA = A(3.0)
idp = u_print(idp, objA.a, objA.b)

objA.set_b(objA.f(20.))
idp = u_print(idp, objA.a, objA.b)

objA.set_b(objA.f(30.))
idp = u_print(idp, objA.a, objA.b)

print(">>>> try an array for var")

objA.set_b(objA.f(jnp.array([30.])))
idp = u_print(idp, objA.a, objA.b)

objA.set_b(objA.f(jnp.array([40.])))
idp = u_print(idp, objA.a, objA.b)

print("g: ", objA.g())

0): 2.0 None
compile...
b: 20.0 <class 'jaxlib.xla_extension.ArrayImpl'>
1): 2.0 20.0
compile...
2): 2.0 22.0
3): 3.0 None
compile...
4): 3.0 60.0
compile...
5): 3.0 90.0
new a: 30.0
compile...
6): 30.0 900.0
7): 3.0 None
8): 3.0 60.0
9): 3.0 90.0
>>>> try an array for var
compile...
10): 3.0 [90.]
compile...
11): 3.0 [120.]
g:  [14400.]
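To see what the registration does, independently of `jit`, one can flatten and unflatten an instance by hand with `jax.tree_util`. This sketch re-declares a minimal `A` (only its pytree part) and unpacks the one-element `children` tuple with `(b,) = children`. Note that `b=None` yields no leaf at all (`None` is an empty pytree), so the tree structure changes once `b` holds an array: this is why the first call after `set_b` triggers a new compilation in the output above.

```python
import jax.numpy as jnp
from jax import tree_util


class A:
    def __init__(self, a, b=None):
        self.a = a
        self.b = b

    def tree_flatten(self):
        children = (self.b,)        # dynamic values (leaves)
        aux_data = {'a': self.a}    # static values (part of the tree structure)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (b,) = children             # unpack the one-element tuple
        return cls(a=aux_data['a'], b=b)


tree_util.register_pytree_node(A, A.tree_flatten, A.tree_unflatten)

obj = A(2.0, jnp.array([1.0, 2.0]))
leaves, treedef = tree_util.tree_flatten(obj)
rebuilt = tree_util.tree_unflatten(treedef, leaves)

print(len(leaves))                      # 1: only "b" is a leaf
print(rebuilt.a, rebuilt.b)             # "a" comes back from aux_data
print(tree_util.tree_leaves(A(2.0)))    # []: b=None contributes no leaf
```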


## A possible variation

In [35]:
class A():
    def __init__(self, a, b=None):
        self.a = a
        self.b = b

    @jit
    def f(self, var):
        print("compile...")
        new_b = self.a * var
        return A(self.a, new_b)       # <--------- new object


# registration with lambda functions (here both a and b are dynamic children)
tree_util.register_pytree_node(A,
                               lambda x: ((x.a,x.b), None),
                               lambda _, x: A(a=x[0],b=x[1])
                              )

In [36]:
idp=0

objA = A(2.0)
idp = u_print(idp, objA.a, objA.b)    #0

objA = objA.f(10.)
idp = u_print(idp,objA.a, objA.b)     #1

objA = objA.f(11.)
idp = u_print(idp, objA.a, objA.b)    #2

####

objA = A(3.0)
idp = u_print(idp, objA.a, objA.b)     #3
objA = objA.f(20.)
idp = u_print(idp, objA.a, objA.b)     #4
objA= objA.f(30.)
idp = u_print(idp, objA.a, objA.b)     #5


objA.a = 30.  # change
objA= objA.f(30.)
idp = u_print(idp, objA.a, objA.b)     #6

####

objA = A(3.0)
idp = u_print(idp, objA.a, objA.b)     #7

objA= objA.f(20.)
idp = u_print(idp, objA.a, objA.b)     #8

objA= objA.f(30.)
idp = u_print(idp, objA.a, objA.b)     #9

print(">>>> try an array for var")

objA= objA.f(jnp.array([30.]))
idp = u_print(idp, objA.a, objA.b)     #10

objA= objA.f(jnp.array([40.]))
idp = u_print(idp, objA.a, objA.b)     #11

objA= objA.f(jnp.array([50.]))
idp = u_print(idp, objA.a, objA.b)     #12


objA.a = 40. ### new change
objA= objA.f(jnp.array([50.]))
idp = u_print(idp, objA.a, objA.b)     #13


0): 2.0 None
compile...
1): 2.0 20.0
compile...
2): 2.0 22.0
3): 3.0 None
4): 3.0 60.0
5): 3.0 90.0
6): 30.0 900.0
7): 3.0 None
8): 3.0 60.0
9): 3.0 90.0
>>>> try an array for var
compile...
10): 3.0 [90.]
compile...
11): 3.0 [120.]
12): 3.0 [150.]
13): 40.0 [2000.]


## Review PyTree

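- `tree_flatten` and `tree_unflatten` must be set up with care: decide which attributes are dynamic (children) and which are static (aux data)
- a change of `a` is correctly propagated: when `a` is static aux data, each new value triggers a recompilation (`set_a(30.)` above); when it is a dynamic child, as in the variation, it does not
- implementation choice: return a new object each time, or use a setter to update the internal variable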

In [37]:
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class B():
    def __init__(self, p, v):
        self.p = p
        self.v = v
        print("init:",p,v)

    @jit
    def step(self, dt):
        print("compile...")
        a = -9.8
        new_v = self.v + a * dt
        new_p = self.p + new_v * dt
        return B(new_p, new_v)

    #### PyTree methods....
    def tree_flatten(self):
        print("tree_flatten")
        children = (self.p,self.v)         # dynamic values
        aux_data = {}     # static values
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        print("tree_unflatten")
        b = children
        return cls(p=b[0], v=b[1] ) # should be the signature of __init__



In [38]:
b = B(jnp.array([0., 0.5]), jnp.array([1., 1.5]))

init: [0.  0.5] [1.  1.5]


In [39]:
for i in range(5):
    b = b.step(0.01)


tree_flatten
tree_flatten
tree_flatten
tree_unflatten
init: Traced<ShapedArray(float64[2])>with<DynamicJaxprTrace(level=1/0)> Traced<ShapedArray(float64[2])>with<DynamicJaxprTrace(level=1/0)>
compile...
init: Traced<ShapedArray(float64[2])>with<DynamicJaxprTrace(level=1/0)> Traced<ShapedArray(float64[2])>with<DynamicJaxprTrace(level=1/0)>
tree_flatten
tree_flatten
tree_unflatten
init: [0.00902 0.51402] [0.902 1.402]
tree_unflatten
init: [0.00902 0.51402] [0.902 1.402]
tree_flatten
tree_flatten
tree_unflatten
init: [0.01706 0.52706] [0.804 1.304]
tree_flatten
tree_unflatten
init: [0.02412 0.53912] [0.706 1.206]
tree_flatten
tree_unflatten
init: [0.0302 0.5502] [0.608 1.108]
tree_flatten
tree_unflatten
init: [0.0353 0.5603] [0.51 1.01]


In [40]:
! pip install -q equinox


In [41]:
import equinox as eqx

In [42]:
class D(eqx.Module):
    p: float
    v: float

    def __init__(self, p, v):  # the init
        self.p = p
        self.v = v

    def __call__(self, dt):    # the forward call
        a = -9.8
        new_v = self.v + a * dt
        new_p = self.p + new_v * dt
        return D(new_p, new_v)

In [43]:
d = D(np.array([0, 0.5]), np.array([1, 1.5]))
for i in range(5):
    d = d(0.01)


In [44]:
d.p, d.v

(array([0.0353, 0.5603]), array([0.51, 1.01]))

# Takeaway message (JIT in a class):

- several strategies: helper function, `self` as static, PyTree (possibly via a library such as Equinox)

- do not write `self.<var> = ...` inside a jitted function: it is a side effect that leaks a tracer. Recall that JIT analyzes the code by tracing it with `Traced<ShapedArray>` values

- with `self` static, define `__hash__` and `__eq__`, and remember that JAX arrays (`DeviceArray`) are unhashable

- be aware of when JIT triggers a (re)compilation: new static values, new argument shapes/types, new pytree structure

- See also the doc [to-jit-or-not-to-jit](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html?highlight=Traced%3CShapedArray%3E#to-jit-or-not-to-jit)

- Ask yourself: "Do I really need OO encapsulation (i.e. a class)?"