In [1]:
import jax

# Functional Programming Foundations in JAX

In the *core_jax* notebook we introduced pure functions and array immutability in JAX. Here we look at these concepts in more depth and introduce functional state management in JAX.

**Immutability**: once an object is created, it cannot be modified. JAX arrays are immutable, so instead of changing an array in place we create a new array with the desired changes.
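
As a minimal sketch of what this looks like in practice (the array values are arbitrary and chosen only for illustration): in-place item assignment on a JAX array raises a `TypeError`, while the `.at[...].set(...)` syntax returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# In-place item assignment is rejected for JAX arrays.
try:
    x[0] = 10.0
    in_place_failed = False
except TypeError:
    in_place_failed = True

# The functional alternative returns a new array; x itself is not modified.
y = x.at[0].set(10.0)

print(in_place_failed)
print(x)
print(y)
```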

**Managing State Functionally**: State is managed explicitly rather than implicitly. In JAX this is done by passing state as arguments to functions and returning the new state as part of the function's output.


In [2]:
def normalizer(S, norm_factor):
  normalized_S = S / norm_factor
  return normalized_S

#Initial state S
S = 1337
#Apply updates
S = normalizer(S, 10)
S

133.7

# Combining Immutability and State Management
In JAX, we often need to manage complex state, and understanding how to do so is **crucial**. For example, consider managing the parameters of a simple ML model.

In [3]:
import jax.numpy as jnp

def model(params, x):
  w, b = params
  return jnp.dot(x, w) + b # y = X.W + b

# Initialize the parameters (the initial state, like S in the previous example)
params = (jnp.array([1.0, 2.0]), jnp.array([3.0])) #(w, b)

def update_params(params, grads, lr):
  w, b = params
  dw, db = grads
  new_w = w - lr * dw
  new_b = b - lr * db
  return (new_w, new_b)

# Example of gradients (arbitrary)
grads = (jnp.array([0.1, 0.2]), jnp.array(0.4))

# Now we update the parameters (updating the state)
new_params = update_params(params, grads, lr=0.1)
new_params

(Array([0.99, 1.98], dtype=float32), Array([2.96], dtype=float32))
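
The same update can be written without unpacking the tuple by hand. The sketch below assumes the parameters and gradients share the same nested structure (a pytree) and uses `jax.tree_util.tree_map` to apply the update leaf by leaf. The name `update_params_tree` is hypothetical and the gradient values are arbitrary.

```python
import jax
import jax.numpy as jnp

def update_params_tree(params, grads, lr):
    # Apply p - lr * g to every matching pair of leaves and return a new pytree.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = (jnp.array([1.0, 2.0]), jnp.array([3.0]))  # (w, b)
grads = (jnp.array([0.1, 0.2]), jnp.array([0.4]))   # (dw, db), arbitrary values

new_params = update_params_tree(params, grads, lr=0.1)
print(new_params)

# The original state is unchanged; only the returned value carries the update.
print(params)
```

Because the function never looks at the structure itself, it should keep working if the model later grows more parameters, as long as the gradients mirror that structure.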

# Function Transformations
Function transformations are at the heart of JAX: a transformation takes a function and returns a new function with enhanced capabilities (function -> transformation -> new_function).
In this section we cover four key transformations:
- **grad**
- **jit**
- **vmap**
- **pmap**


In [4]:
#----- grad: Automatic differentiation

def f(x):
  return 3.0 * x**2 + 2.0 * x + 15.0

# Differentiating by hand with respect to x:
#   df/dx = (3 * 2) x + 2 = 6x + 2
#   at x = 1.0: df/dx = 6.0 + 2.0 = 8.0

# Now use automatic differentiation (jax.grad) and verify

df_dx = jax.grad(f)

print(df_dx(1.0))

8.0


In [5]:
# For the 2nd derivative, simply apply grad twice: grad(grad(f))
# By hand, d2f/dx2 = 6.0 is a constant, so the 3rd derivative should be 0

d2f_dx2 = jax.grad(jax.grad(f))
print(d2f_dx2(1.0))

# 3rd derivative: d/dx(d2f/dx2), i.e. grad(grad(grad(f)))
d3f_dx3 = jax.grad(d2f_dx2)
print(d3f_dx3(1.0)) #--> 0.0

6.0
0.0
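
To connect `grad` back to the state-management example, here is a hedged sketch of differentiating a loss with respect to a parameter tuple. The loss function `loss_fn` and the data `x`, `y` are assumptions made for illustration; they do not appear in the cells above. `jax.value_and_grad` returns both the loss value and the gradients, and the gradients have the same structure as `params`.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def loss_fn(params, x, y):
    # Squared error, reduced to a scalar so that it can be differentiated by grad.
    return jnp.sum((model(params, x) - y) ** 2)

params = (jnp.array([1.0, 2.0]), jnp.array([3.0]))  # (w, b)
x = jnp.array([1.0, 2.0])  # a single illustrative input
y = jnp.array([6.0])       # an illustrative target

# By default, gradients are taken with respect to the first argument (params).
loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
dw, db = grads

print(loss_value)
print(dw, db)
```

By hand: the prediction is 1*1 + 2*2 + 3 = 8, the error is 8 - 6 = 2, so the loss is 4, dL/dw = 2 * 2 * x = [4, 8] and dL/db = 4.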


## JIT
In the introductory notebook we used JIT. It is important to understand how JIT handles control flow, in particular control flow that depends on runtime values (dynamic control flow).

In [6]:
@jax.jit
def static_fn(x):
  return x**2

static_fn(1)

Array(1, dtype=int32, weak_type=True)

In [7]:
@jax.jit
def dynamic_fn(x, cond):
  if cond:
    return x**2
  else:
    return x**3
dynamic_fn(1, True)  # --> Raises an error: cond is traced, so Python's `if` cannot read its value

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function dynamic_fn at <ipython-input-7-47260f41b174>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument cond.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [8]:
# To make this control flow work under jit, mark 'cond' as a static argument
# by wrapping jax.jit with functools.partial and passing static_argnames

from functools import partial
@partial(jax.jit, static_argnames=['cond'])
def dynamic_fn(x, cond):
  if cond:
    return x**2
  else:
    return x**3
dynamic_fn(1, True) # Now it works just fine

Array(1, dtype=int32, weak_type=True)
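
One consequence of static arguments, sketched below under the assumption that every call uses the same argument types and shapes: the function is traced again for each new value of the static argument. The names `trace_log` and `power_fn` are hypothetical. The `append` is a Python side effect, so it runs only while JAX traces the function and not on cached calls.

```python
from functools import partial

import jax

trace_log = []  # records each time the Python body is actually traced

@partial(jax.jit, static_argnames=["cond"])
def power_fn(x, cond):
    trace_log.append(cond)  # executed during tracing only
    return x**2 if cond else x**3

a = power_fn(2.0, cond=True)   # first call: traces with cond=True
b = power_fn(3.0, cond=True)   # same static value and input type: reuses the cache
c = power_fn(2.0, cond=False)  # new static value: traces again

print(a, b, c)
print(trace_log)
```

Static arguments therefore fit best when the argument takes only a few distinct values, since each new value costs a fresh compilation.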

## vmap: Automatic Vectorization
**vmap** (vectorizing map) automatically vectorizes a function, allowing it to operate on batches of inputs without explicit loops.


In [9]:
# Let's say we want to vectorize this function
def predict(params, x):
    w, b = params
    return jnp.dot(x, w) + b

# Vectorize over a batch of inputs: params are shared (None), x is mapped along axis 0
batch_predict = jax.vmap(predict, in_axes=(None, 0))

# Apply to a batch
params = (jnp.array([1.0, 2.0]), jnp.array(0.0))
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # Batch of 2 examples




# Instead of looping over the rows of batch_x and appending each result,
# the vectorized function takes the whole batch at once and maps over
# its leading axis, so no explicit Python loop is needed.

print(batch_predict(params, batch_x))


[ 5. 11.]
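
As a quick check of the claim that `vmap` replaces the explicit loop, the sketch below computes the same batch both ways and compares the results. The variable names `looped` and `vectorized` are introduced here for illustration only.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return jnp.dot(x, w) + b

params = (jnp.array([1.0, 2.0]), jnp.array(0.0))
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Explicit Python loop: one call per example, then stack the results.
looped = jnp.stack([predict(params, x) for x in batch_x])

# vmap: a single call that maps over the leading axis of batch_x.
vectorized = jax.vmap(predict, in_axes=(None, 0))(params, batch_x)

print(looped)
print(vectorized)
print(bool(jnp.allclose(looped, vectorized)))
```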


## pmap: Parallel Device Computation
**pmap** distributes a computation across multiple devices: it splits the inputs along their leading axis, one slice per device, and runs the computation on all devices in parallel.

In [10]:
def f(x):
  return jnp.exp(x**2 + 1)

parallel_f = jax.pmap(f)
devices = jax.devices()
x = jnp.arange(8).reshape(len(devices), -1)  # One row per device; 8 must be divisible by len(devices)

print(x)

[[0 1 2 3 4 5 6 7]]
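
The cell above builds `parallel_f` and the sharded input but does not call the function. A minimal sketch of the call follows, assuming the leading axis of the input equals the number of local devices (on a single-device machine that is 1, so the result matches a plain call to `f`). The input values and the chunk size of 4 are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(x**2 + 1)

n_devices = jax.local_device_count()

# One row per device; each device receives a chunk of 4 values.
x = jnp.linspace(0.0, 1.0, n_devices * 4).reshape(n_devices, 4)

# pmap runs f on each row in parallel, one row per device.
y = jax.pmap(f)(x)

print(y.shape)
print(bool(jnp.allclose(y, f(x), rtol=1e-4)))
```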
