<a href="https://colab.research.google.com/github/jecampagne/JaxTutos/blob/main/JAX_static_traced_var_func.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import jax
from jax import lax, grad, jit, vmap, jacrev
from jax.tree_util import register_pytree_node_class

import jax.numpy as jnp
import numpy as np

In [3]:
import jaxlib.xla_extension as xla_ext

# You could use the presets
option = xla_ext.HloPrintOptions.short_parsable()
#option = xla_ext.HloPrintOptions.canonical()
#option = xla_ext.HloPrintOptions.fingerprint()


# Topic: pure numpy & JAX

A question sometimes arises after playing with JAX: can one use pure NumPy inside a JAX function? And why does one sometimes NEED to use pure NumPy, even in a jittable JAX function? Let us look at some use cases.

One thing stated in the JAX doc [Thinking in JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html) concerns JIT:   

    "Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time"

This constraint became much more visible with a JAX change made after version 0.2 (at the time of writing, the JAX version is >= 0.4.6).

Key Concept:
    
    Just as values can be either static or traced, operations can be static or traced. 
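A minimal sketch of this idea, using `jax.make_jaxpr` to display what JAX actually records during tracing. The two helpers below are hypothetical (not part of the original notebook): both multiply `x` by its number of elements, once with `jnp.prod` (a traced operation, which should appear in the jaxpr as a `reduce_prod` equation) and once with `np.prod` (a static operation, evaluated in Python at trace time, so only its result, the number 12, is recorded).

```python
import jax
import jax.numpy as jnp
import numpy as np

def scale_by_size_traced(x):
    # jnp.prod is a traced operation: it is staged out into the jaxpr
    return x * jnp.prod(jnp.array(x.shape))

def scale_by_size_static(x):
    # np.prod runs in plain Python at trace time: only its result is recorded
    return x * int(np.prod(x.shape))

x = jnp.ones((3, 4))
print(jax.make_jaxpr(scale_by_size_traced)(x))
print(jax.make_jaxpr(scale_by_size_static)(x))
```

Comparing the two printed jaxprs, only the first one contains a reduction; in the second one the size has already become a constant.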





# A toy example

Let us consider this NumPy code:

In [4]:
x = np.arange(0,12.,1).reshape((3, 4))
print("in= ",x)
x = x.reshape(np.prod(x.shape))
print("out= ",x, " sum:", x.sum())

in=  [[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
out=  [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11.]  sum: 66.0


It is then natural to consider this JAX translation:

In [5]:
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return jnp.sum(x.reshape((size,)))

In [6]:
x = jnp.arange(0,12.,1).reshape((3, 4))
ex1(x)



Array(66., dtype=float32)

OK, all good. Let's jit this function... :)

In [7]:
jax.jit(ex1)(x) # a crash my god!

TypeError: ignored

The crash comes from `reshape`: "***Shapes must be 1D sequences of concrete values of integer type***". Declaring `x` as static is not a solution for us, since we want, for instance, to compute gradients or apply `vmap` with respect to `x`.

In fact, `x.shape` is evaluated statically, but `jnp.prod` is traced, as the following code shows.

In [8]:
def ex1_debug(x):
  size = jnp.prod(jnp.array(x.shape))
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")   # not traced
  print(f"jnp.array(x.shape).prod() = {size}")
  return jnp.sum(x.reshape((size,)))


In [9]:
jax.jit(ex1_debug)(x)  # still crashes, but look at the prints

x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (3, 4)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>


TypeError: ignored

In [10]:
@jit
def ex1_sol(x):
  size = np.prod(np.array(x.shape))     # <= notice: pure NumPy is used here
  return jnp.sum(x.reshape((size,)))

In [11]:
x = jnp.arange(0,12.,1).reshape((3, 4))
ex1_sol(x)

Array(66., dtype=float32)

In [12]:
grad(ex1_sol)(x)

Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)
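As a side note, a sketch of two other spellings that also work under `jit` without going through NumPy: `x.size` is a plain Python `int` (like `x.shape`, it is static), and `reshape(-1)` lets JAX infer the flattened length from the static shape. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def ex1_size(x):
    # x.size is a static Python int, so reshape receives a concrete shape
    return jnp.sum(x.reshape((x.size,)))

@jax.jit
def ex1_ravel(x):
    # reshape(-1): the flattened length is inferred from the static shape
    return jnp.sum(x.reshape(-1))

x = jnp.arange(0, 12., 1).reshape((3, 4))
print(ex1_size(x), ex1_ravel(x))
print(jax.grad(ex1_size)(x))
```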

In [16]:
# keep in mind that vmap must be applied to objects of the same shape
vmap(ex1_sol)(jnp.array([x,jax.random.normal(jax.random.PRNGKey(10), shape=(12,)).reshape((3, 4))]))

Array([66.     , -6.51921], dtype=float32)
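Why does the NumPy-computed size still work under `vmap`? Inside `vmap` the function is traced once with the batch axis stripped away, so `x.shape` refers to a single example. A small check, assuming a batch of two `(3, 4)` arrays; the list `seen_shapes` and the function name are hypothetical instrumentation:

```python
import jax
import jax.numpy as jnp
import numpy as np

seen_shapes = []  # records the shapes observed at trace time

def ex1_sol_verbose(x):
    seen_shapes.append(x.shape)              # static: recorded while tracing
    size = int(np.prod(np.array(x.shape)))   # per-example size, computed in NumPy
    return jnp.sum(x.reshape((size,)))

batch = jnp.stack([jnp.arange(12.).reshape((3, 4)), jnp.ones((3, 4))])
out = jax.vmap(ex1_sol_verbose)(batch)
print(seen_shapes, out)
```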

# Another toy example

In [17]:
def exo2(x):
  n = x.shape[0]        # static: OK
  print("x: ",x, type(x))
  print("n: ",n)
  y = jnp.arange(0.,n)   # hence OK
  return jnp.sum(x*y)

In [18]:
x = jax.random.normal(jax.random.PRNGKey(10), shape=(10,))
exo2(x)

x:  [ 0.4676388  -0.26684377  1.1685165  -0.48255268 -1.4658067   0.408608
 -0.12636082 -0.16064268  0.876103   -0.23290116] <class 'jaxlib.xla_extension.ArrayImpl'>
n:  10


Array(-0.16760588, dtype=float32)

In [19]:
jax.jit(exo2)(x)

x:  Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=1/0)> <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
n:  10


Array(-0.16760588, dtype=float32)

In [20]:
jax.jit(jax.grad(exo2))(x)

x:  Traced<ShapedArray(float32[10])>with<JVPTrace(level=3/0)> with
  primal = Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=1/0)>
  tangent = Traced<ShapedArray(float32[10])>with<JaxprTrace(level=2/0)> with
    pval = (ShapedArray(float32[10]), None)
    recipe = LambdaBinding() <class 'jax._src.interpreters.ad.JVPTracer'>
n:  10


Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)

In [21]:
def exo2_modif(x,n):
  print("x: ",x, type(x))
  print("n: ",n)
  y = jnp.arange(0.,n)   # no longer OK: n is now traced
  return jnp.sum(x*y)

In [22]:
jax.jit(exo2_modif)(x, x.shape[0])  # Crash !!!!

x:  Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=1/0)> <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
n:  Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


ConcretizationTypeError: ignored

In [23]:
def exo2_modif(x,n):
  print("x: ",x, type(x))
  print("n: ",n)
  y = jnp.arange(0.,n)   
  return jnp.sum(x*y)

In [24]:
jax.jit(exo2_modif, static_argnums=(1,))(x, x.shape[0]) # n is now static, so it is no longer traced

x:  Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=1/0)> <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
n:  10


Array(-0.16760588, dtype=float32)

In [25]:
jax.jit(jax.grad(exo2_modif), static_argnums=(1,))(x, x.shape[0]) 

x:  Traced<ShapedArray(float32[10])>with<JVPTrace(level=3/0)> with
  primal = Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=1/0)>
  tangent = Traced<ShapedArray(float32[10])>with<JaxprTrace(level=2/0)> with
    pval = (ShapedArray(float32[10]), None)
    recipe = LambdaBinding() <class 'jax._src.interpreters.ad.JVPTracer'>
n:  10


Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)
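Static arguments come with a cost worth knowing: the jitted function is retraced and recompiled for every new value of a static argument. A sketch using `static_argnames` (the keyword-based counterpart of `static_argnums`) and a hypothetical trace counter; unlike `exo2_modif`, it slices `x[:n]` so that different values of `n` can be tried on the same input:

```python
import jax
import jax.numpy as jnp
from functools import partial

trace_count = 0  # incremented only when JAX (re)traces the function

@partial(jax.jit, static_argnames=("n",))
def weighted_sum(x, n):
    global trace_count
    trace_count += 1            # Python side effect: runs at trace time only
    y = jnp.arange(0., n)       # n is static, so arange gets a concrete length
    return jnp.sum(x[:n] * y)   # static slice keeps shapes consistent

x = jnp.ones(10)
weighted_sum(x, n=10)   # first call with n=10: traced and compiled
weighted_sum(x, n=10)   # same static value and shapes: cached, no retrace
weighted_sum(x, n=5)    # new static value: retraced and recompiled
print(trace_count)
```

If a "static" argument takes many different values, this recompilation can dominate the run time.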

# A more advanced example

The above exercises could be considered toy examples, but this kind of problem can cause crashes that are not so evident to debug if one is not aware of it.



In [26]:
def ex_convolve(in1,in2):

  axes = range(in1.ndim)

  s1 = in1.shape
  s2 = in2.shape
  shape = [s1[i] + s2[i] -1 for i in range(in1.ndim)]

  print("shape",shape)
  
  fshape = 2**jnp.ceil(jnp.log2(jnp.array(shape))).astype('uint16')  #<-------

  print("fshape",fshape)

  sp1 = jnp.fft.fftn(in1, fshape, axes=axes)                         #<-------
  sp2 = jnp.fft.fftn(in2, fshape, axes=axes)
  conv = jnp.fft.irfftn(sp1 * sp2, fshape, axes=axes)
  conv = conv[tuple(map(slice, shape))]

  def _centered(arr, newshape):
    # Return the center newshape portion of the array.
    newshape = np.asarray(newshape)
    currshape = np.array(arr.shape)
    startind = (currshape - newshape) // 2
    endind = startind + newshape
    myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
    return arr[tuple(myslice)]

  return _centered(conv, s1)

In [27]:
key, key1, key2= jax.random.split(jax.random.PRNGKey(1),3)
x1 = jax.random.normal(key1, shape=(10,10))
x2 = jax.random.normal(key2, shape=(10,10))
res = ex_convolve(x1,x2)


shape [19, 19]
fshape [32 32]


In [28]:
res1= jax.jit(ex_convolve)(x1,x2) # crash

shape [19, 19]
fshape Traced<ShapedArray(uint16[2])>with<DynamicJaxprTrace(level=1/0)>


TracerIntegerConversionError: ignored

The problem comes from `fshape`: it is traced, whereas the FFT requires a static shape ("Shape = Sequence[int]"). The solution is to compute `fshape` with pure NumPy functions.

In [29]:
def ex_convolve_sol(in1,in2):

  axes = range(in1.ndim)

  s1 = in1.shape
  s2 = in2.shape
  shape = [s1[i] + s2[i] -1 for i in range(in1.ndim)]

  print("shape",shape)
  
  fshape = 2**np.ceil(np.log2(np.array(shape))).astype('uint16') # OK now: static NumPy computation

  print("fshape",fshape)

  sp1 = jnp.fft.fftn(in1, fshape, axes=axes)
  sp2 = jnp.fft.fftn(in2, fshape, axes=axes)
  conv = jnp.fft.irfftn(sp1 * sp2, fshape, axes=axes)
  conv = conv[tuple(map(slice, shape))]

  def _centered(arr, newshape):
    # Return the center newshape portion of the array.
    newshape = np.asarray(newshape)
    currshape = np.array(arr.shape)
    startind = (currshape - newshape) // 2
    endind = startind + newshape
    myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
    return arr[tuple(myslice)]

  return _centered(conv, s1)

In [30]:
res1= jax.jit(ex_convolve_sol)(x1,x2)

shape [19, 19]
fshape [32 32]
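Plain Python integer arithmetic is just as static as NumPy. Below is a 1D sketch of the same idea (full-mode convolution only, without the centering step) that computes the padded FFT length with `int.bit_length` instead of `np.log2`, and uses `jnp.fft.rfft`/`irfft` rather than `fftn`/`irfftn`. The names `next_pow2` and `fft_convolve_full_1d` are hypothetical; the result is checked against `np.convolve`.

```python
import jax
import jax.numpy as jnp
import numpy as np

def next_pow2(n: int) -> int:
    # Smallest power of two >= n (for n >= 1), in plain Python: hence static
    return 1 << (n - 1).bit_length()

@jax.jit
def fft_convolve_full_1d(a, b):
    n = a.shape[0] + b.shape[0] - 1     # static Python int
    nfft = next_pow2(n)                 # static Python int, e.g. 19 -> 32
    spec = jnp.fft.rfft(a, n=nfft) * jnp.fft.rfft(b, n=nfft)
    return jnp.fft.irfft(spec, n=nfft)[:n]   # static slice back to the full length

rng = np.random.default_rng(0)
a = rng.normal(size=10).astype(np.float32)
b = rng.normal(size=10).astype(np.float32)
res = fft_convolve_full_1d(a, b)
print(res.shape, np.allclose(res, np.convolve(a, b), atol=1e-4))
```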


# Takeaway message

- Properties of arrays like shapes, sizes, and dtypes are always static within JAX, so they can be used to define other static quantities

- Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

- Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.

- Use numpy for operations that you want to be static; use jax.numpy for operations that you want to be traced.


- This post is in fact the result of an exchange I had with JAX developers (end of March 2023). The JAX doc ["How to think in JAX"](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#how-to-think-in-jax) was completed following this exchange. As the JAX API may change, this doc may change as well. 
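To illustrate the first and third takeaways, a sketch with a hypothetical helper: because `x.dtype` and `x.shape` are static, a plain Python `if` on the dtype and a division by `x.shape[-1]` are resolved at trace time, while the arithmetic on the values is traced. An integer input and a float input simply lead to two separate traces.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize_by_last_dim(x):
    # x.dtype is static: this Python `if` is resolved at trace time
    if jnp.issubdtype(x.dtype, jnp.integer):
        x = x.astype(jnp.float32)
    # x.shape[-1] is a static Python int; the division itself is traced
    return x / x.shape[-1]

print(normalize_by_last_dim(jnp.arange(4)))      # integer input
print(normalize_by_last_dim(jnp.ones((2, 2))))   # float input
```

Note that branching on the *values* of `x` in the same way would fail under `jit`; only the array properties are static.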




# Epilog

In [35]:
# a curiosity due to np.sum behavior
def ex0(x):
  print("x:",x)
  size = np.prod(np.array(x.shape))
  return np.sum(x.reshape((size,)))
print(jit(ex0)(np.arange(0,12.,1).reshape((3, 4))))
print(grad(ex0)(np.arange(0,12.,1).reshape((3, 4))))

x: Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
66.0
x: Traced<ConcreteArray([[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.]])
  tangent = Traced<ShapedArray(float32[3,4])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,4]), None)
    recipe = LambdaBinding()
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]


As you probably noticed, the `ex0` code relies only on pure NumPy functions, so one may ask: how is it possible to trace `x` and make JIT and autodiff work?

The reason is subtle: `np.sum` first checks whether its argument has a `.sum()` method and, if so, calls it. JAX tracers do have a `sum` method, which is why everything works (likewise, `x.reshape` here is the tracer's own method). Now, if you apply `np.exp` on top of `np.sum`, there will be a crash... There is no miracle.  
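A small sketch of this mechanism. `HasSum` is a hypothetical toy class that only defines a `.sum()` method: `np.sum` delegates to it. A NumPy ufunc such as `np.exp` does no such delegation; it tries to convert the tracer into a NumPy array, which JAX refuses under `jit` (the exact exception type printed may depend on the JAX version).

```python
import numpy as np
import jax
import jax.numpy as jnp

class HasSum:
    # hypothetical toy class: it only defines a .sum() method
    def sum(self, axis=None, dtype=None, out=None, **kwargs):
        return "HasSum.sum was called"

print(np.sum(HasSum()))  # np.sum delegates to the object's own .sum()

def ex0_exp(x):
    return np.exp(np.sum(x))  # np.sum works on a tracer, np.exp does not

def np_exp_fails_under_jit() -> bool:
    try:
        jax.jit(ex0_exp)(jnp.arange(12.0))
    except Exception as err:
        print("np.exp failed:", type(err).__name__)
        return True
    return False

print(np_exp_fails_under_jit())
```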