### Key concepts

``jax.Array`` is the default array implementation in JAX, but we usually create arrays through JAX API functions rather than calling its constructor directly. ``jax.numpy`` provides almost all the familiar array construction functions.

In [34]:
import jax 
import jax.numpy as jnp 

x = jnp.arange(5)
isinstance(x, jax.Array)

True
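As a brief illustration of the construction functions mentioned above, the following sketch builds a few arrays with ``jax.numpy``. The variable names are illustrative only, and the default dtypes assume JAX's standard 32-bit configuration.

```python
import jax
import jax.numpy as jnp

# A few NumPy-style constructors; each returns a jax.Array.
zeros = jnp.zeros((2, 3))                 # 2x3 array of zeros
ones = jnp.ones(4, dtype=jnp.int32)       # explicit integer dtype
grid = jnp.linspace(0.0, 1.0, 5)          # 5 evenly spaced points in [0, 1]
eye = jnp.eye(3)                          # 3x3 identity matrix
from_list = jnp.array([1.0, 2.0, 3.0])    # from a Python list

for arr in (zeros, ones, grid, eye, from_list):
    print(type(arr).__name__, arr.shape, arr.dtype, isinstance(arr, jax.Array))
```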

In [35]:
# Inspect which device(s) hold the array
x.devices()

{CpuDevice(id=0)}

In parallel programming, an array may be __sharded__ across multiple devices; its layout can be inspected via the ``sharding`` attribute.

In [36]:
x.sharding

SingleDeviceSharding(device=CpuDevice(id=0))
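To make the idea of sharding more concrete, here is a minimal sketch that places an array across all available devices using ``jax.sharding.NamedSharding``. The mesh axis name ``"data"`` and the array length are arbitrary choices for illustration; on a single-CPU setup like the one above, the mesh contains just one device and the array is not actually split.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())            # all visible devices (one on a plain CPU setup)
mesh = Mesh(devices, axis_names=("data",))   # 1-D device mesh; "data" is an arbitrary axis name
sharding = NamedSharding(mesh, P("data"))    # split array axis 0 along the "data" mesh axis

# Choose a length that is a multiple of the device count so it divides evenly.
x = jnp.arange(4 * len(devices))
x_sharded = jax.device_put(x, sharding)

print(x_sharded.sharding)
print("number of shards:", len(x_sharded.addressable_shards))
```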

__Transformations__ accept a function as an argument, and return a new transformed function.

In [37]:
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.5))

1.5749999


In [38]:
# Alternatively, apply jax.jit as a decorator

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

__Tracers__ are used as stand-ins for JAX arrays so that JAX can determine the sequence of operations performed by a Python function. A tracer carries the array's shape and dtype but not its values.

In [39]:
@jax.jit 
def f(x):
    print(x)
    return x+1

x = jnp.arange(5)
result = f(x)

Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>
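The following sketch illustrates that the Python body of a jitted function only runs while JAX is tracing it, and that tracing happens once per input shape and dtype. The list ``seen`` and the function ``add_one`` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

seen = []  # records what the Python body observes each time it actually runs

@jax.jit
def add_one(x):
    # This line executes only during tracing, where x is a tracer.
    seen.append((x.shape, str(x.dtype)))
    return x + 1

add_one(jnp.arange(5))       # first call: traced for int32[5]
add_one(jnp.arange(5) * 2)   # same shape and dtype: cached code is reused, no tracing
add_one(jnp.arange(3.0))     # new shape and dtype: traced again

print(len(seen))  # 2
print(seen)
```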


A __jaxpr__ (JAX expression) is the intermediate representation of a computation that JAX generates by tracing; it is then lowered and forwarded to XLA for compilation and execution.

In [40]:
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [41]:
x = jnp.arange(5)
jax.make_jaxpr(selu)(x) 

{ lambda ; a:i32[5]. let
    b:bool[5] = gt a 0
    c:f32[5] = convert_element_type[new_dtype=float32 weak_type=False] a
    d:f32[5] = exp c
    e:f32[5] = mul 1.6699999570846558 d
    f:f32[5] = sub e 1.6699999570846558
    g:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; h:bool[5] i:i32[5] j:f32[5]. let
          k:f32[5] = convert_element_type[new_dtype=float32 weak_type=False] i
          l:f32[5] = select_n h j k
        in (l,) }
    ] b a f
    m:f32[5] = mul 1.0499999523162842 g
  in (m,) }

JAX is designed to work with __pure__ functions. A pure function is deterministic (it always produces the same output for the same input) and has no side effects.

A side effect occurs when a function:
- modifies a variable outside its local scope
- modifies a mutable object passed as an argument
- performs I/O operations (printing, overwriting a file, etc.)
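To illustrate why purity matters under ``jax.jit``, the sketch below uses a function that reads external mutable state. The names ``config``, ``impure_add`` and ``pure_add`` are hypothetical; the point is that the external value is captured at trace time, so later changes to it are not seen by the cached computation.

```python
import jax

config = {"offset": 0.0}  # mutable state living outside the function

def impure_add(x):
    return x + config["offset"]   # depends on external state: not pure

impure_add_jit = jax.jit(impure_add)

first = impure_add_jit(4.0)    # traced now; offset 0.0 is baked into the compiled code
config["offset"] = 10.0        # this change is invisible to the cached computation
second = impure_add_jit(5.0)   # same input type, so the cached code runs

def pure_add(x, offset):
    return x + offset          # every input is passed explicitly

third = jax.jit(pure_add)(5.0, config["offset"])

print(first, second, third)  # 4.0 5.0 15.0
```

Passing the value as an explicit argument, as in ``pure_add``, keeps the function pure and gives the expected result.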

In [42]:
def log2_with_print(x):

    print("printed x:", x)
    ln_x = jnp.log(x)
    return ln_x / jnp.log(2.0)

print(jax.make_jaxpr(log2_with_print)(3.0))

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


### JIT Compilation

In [13]:
import time 
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

4.19 ms ± 179 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
#JIT compilation
selu_jit = jax.jit(selu)
%timeit selu_jit(x).block_until_ready()

759 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
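Outside IPython, where ``%timeit`` is not available, a comparison like the one above could be sketched with the standard-library ``timeit`` module. The repetition count of 20 is an arbitrary choice, and the explicit warm-up call is there because the first call to a jitted function includes tracing and compilation. No timings are shown since they depend on the hardware.

```python
import timeit

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
x = jnp.arange(1_000_000)

# Warm-up: the first call traces and compiles, so keep it out of the measurement.
selu_jit(x).block_until_ready()

n_runs = 20  # arbitrary number of repetitions
t_eager = timeit.timeit(lambda: selu(x).block_until_ready(), number=n_runs) / n_runs
t_jit = timeit.timeit(lambda: selu_jit(x).block_until_ready(), number=n_runs) / n_runs

print(f"eager: {t_eager * 1e3:.3f} ms per call")
print(f"jit:   {t_jit * 1e3:.3f} ms per call")
```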


We cannot apply JIT everywhere. It fails when a function:
- has control flow that depends on a runtime value
- produces arrays whose shape depends on a runtime value (shapes must be static, not dynamic)
- contains non-JAX operations (for example, NumPy calls) applied to traced values

In [30]:
# def f(x):
#   if x > 0:  # Depends on runtime value of x
#     return x
#   else:
#     return 2 * x
# jax.jit(f)(10)  # Raises TracerBoolConversionError


# def g(x, n):
#   i = 0
#   while i < n:
#     i += 1
#   return x + i

# jax.jit(g)(10, 20)  # Raises TracerBoolConversionError


# def f(x):
#   return jnp.arange(x)  # Shape of output depends on runtime value of x
# jax.jit(f)(10)  # Raises ConcretizationTypeError


import numpy as np
# def f(x):
#   return np.sin(x)  # Non-JAX operation
# jax.jit(f)(10)  # Raises TracerArrayConversionError

To work around this, we can use JAX's special __control flow operators__ (such as ``jax.lax.cond`` and ``jax.lax.while_loop``), or JIT-compile only part of the function.

In [23]:
# While loop conditioned on i and n, with a jitted body

@jax.jit
def loop_body(prev_i):
    return prev_i + 1

def g_inner_jitted(x, n):
    i = 0
    while i < n:
        i = loop_body(i)
    return x + i
g_inner_jitted(10, 20)

Array(30, dtype=int32, weak_type=True)
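For comparison, the control flow operator route could look like the sketch below, which rewrites the value-dependent ``if`` and ``while`` using ``jax.lax.cond`` and ``jax.lax.while_loop`` so that the whole function can be jitted. The names ``f_cond`` and ``g_while`` are hypothetical.

```python
import jax

@jax.jit
def f_cond(x):
    # Both branches are traced; the choice between them is made at run time.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 2 * v, x)

@jax.jit
def g_while(x, n):
    # The loop state is the counter i, starting at 0.
    i = jax.lax.while_loop(
        lambda i: i < n,   # continue while the counter is below n
        lambda i: i + 1,   # loop body: increment the counter
        0,
    )
    return x + i

print(f_cond(10))       # 10
print(f_cond(-3))       # -6
print(g_while(10, 20))  # 30
```

Unlike the partially jitted version, the loop here runs inside the compiled computation rather than in Python.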

Another way is to __mark arguments as static__ by specifying ``static_argnums`` or ``static_argnames``. JAX then treats those arguments as compile-time constants.

In [32]:
# Assumes f (the if/else version) and g from the previous cell are defined (uncommented)
f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))

g_jitted_correct = jax.jit(g, static_argnames=['n'])
print(g_jitted_correct(10, 20))

10
30


To specify such arguments when using ``jax.jit`` as a decorator, a common pattern is to use ``functools.partial()``.

In [33]:
from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i
print(g_jit_decorated(10, 20))

30


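If we define ``f = jax.jit(g)``, the first call to ``f`` compiles it and caches the resulting code; subsequent calls with the same input shapes and dtypes reuse the cached code. If we specify ``static_argnums``, the cached code is reused only for the same values of the static arguments; any other value triggers a new compilation. The cache also relies on the hash of the function, so calling ``jax.jit`` on a freshly created ``partial`` or ``lambda`` inside a loop forces a recompilation on every iteration.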

In [43]:
from functools import partial

def unjitted_loop_body(prev_i):
    return prev_i + 1

def g_inner_jitted_partial(x, n):
    i = 0
    while i < n:
        # partial returns a new function object with a different hash on every iteration
        i = jax.jit(partial(unjitted_loop_body))(i)
    return x + i

def g_inner_jitted_lambda(x, n):
    i = 0
    while i < n:
        # lambda creates a new function object with a different hash on every iteration
        i = jax.jit(lambda x: unjitted_loop_body(x))(i)
    return x + i

def g_inner_jitted_normal(x, n):
    i = 0
    while i < n:
        # Same function object every time, so JAX can find the cached version
        i = jax.jit(unjitted_loop_body)(i)
    return x + i

print("JIT called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("JIT called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()   

print("JIT called in a loop with normal:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()

JIT called in a loop with partials:
192 ms ± 8.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
JIT called in a loop with lambdas:
190 ms ± 2.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
JIT called in a loop with normal:
1.22 ms ± 23 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
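The effect of static arguments on the cache can be observed by recording when tracing happens. In this sketch, ``trace_log`` and ``add_n`` are hypothetical names; the Python-level ``append`` runs only while JAX traces the function, so the list grows once per compilation.

```python
from functools import partial

import jax

trace_log = []  # appended to only while JAX traces the function

@partial(jax.jit, static_argnames=["n"])
def add_n(x, n):
    trace_log.append(n)   # Python side effect: runs at trace time only
    return x + n

add_n(1.0, n=2)   # first call: traces and compiles for n=2
add_n(3.0, n=2)   # same static value and input type: cached code is reused
add_n(1.0, n=5)   # new static value: traces and compiles again

print(trace_log)  # [2, 5]
```

Static arguments that take many distinct values therefore cause many compilations, which is worth keeping in mind when choosing what to mark as static.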
