In [1]:
import jax

# Functional Programming Foundations in JAX

In the *core_jax* notebook we introduced pure functions and array immutability in JAX. Now we'll dive a bit deeper into these concepts and introduce functional state management in JAX.

**Immutability** means that once an object is created, it cannot be modified. In JAX, arrays are immutable: instead of changing an array in place, we create a new array with the desired changes.
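A short sketch of what immutability looks like in practice, using a small toy array chosen for illustration. In-place item assignment on a JAX array raises a `TypeError`, and the functional alternative `x.at[idx].set(value)` returns a new array while leaving the original untouched.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# In-place assignment is not allowed on JAX arrays
assignment_failed = False
try:
    x[0] = 10
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)

# Functional update: returns a NEW array, `x` itself is not modified
y = x.at[0].set(10)
print(x)
print(y)
```

The same `.at[...]` interface also provides `.add`, `.multiply`, and similar methods, all of which return new arrays.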

**Managing State Functionally**: state is managed explicitly rather than implicitly. In JAX this is done by passing the state as an argument to functions and returning the new state as part of the function's output.
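One way to see the "state in, state out" pattern is `jax.lax.scan`, which repeatedly calls a function of the form `f(carry, x) -> (new_carry, y)` and threads the carry (the state) through each step. The running-sum function and input values below are hypothetical, chosen only to illustrate the pattern.

```python
import jax
import jax.numpy as jnp

def running_sum_step(total, x):
    # State in: `total`. State out: the new total. Nothing is mutated.
    new_total = total + x
    return new_total, new_total  # (new carry, per-step output)

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final_total, partial_sums = jax.lax.scan(running_sum_step, jnp.array(0.0), xs)
print(final_total)   # final state
print(partial_sums)  # state after each step
```

Because `running_sum_step` only reads its arguments and returns its results, `scan` can compile the whole loop without tracking any hidden state.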


In [3]:
def normalizer(S, norm_factor):
  normalized_S = S / norm_factor
  return normalized_S

# Initial state S
S = 1337
# Apply an update: the new state is returned and rebound to S
S = normalizer(S, 10)
S

133.7
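Passing `norm_factor` explicitly matters once functions are compiled with `jax.jit`. A sketch of the difference follows. The names `normalize_implicit` and `normalize_explicit` are hypothetical. A jitted function that reads a global captures that global's value when it is traced, and later calls with same-shaped inputs reuse the cached trace, so changing the global has no effect.

```python
import jax
import jax.numpy as jnp

norm_factor = 10.0

@jax.jit
def normalize_implicit(S):
    # Reads the global `norm_factor`: its value is baked in at trace time
    return S / norm_factor

@jax.jit
def normalize_explicit(S, norm_factor):
    # State is passed in explicitly, so every call sees the current value
    return S / norm_factor

S = jnp.array(1337.0)
r1 = normalize_implicit(S)                # uses norm_factor = 10.0
norm_factor = 100.0
r2 = normalize_implicit(S)                # cached trace: still divides by 10.0
r3 = normalize_explicit(S, norm_factor)   # divides by 100.0
print(r1, r2, r3)
```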

# Combining Immutability and State Management
In JAX, we often need to manage more complex state, and understanding how to do so is **crucial**. A typical example is managing the parameters of a simple ML model.

In [5]:
import jax.numpy as jnp

def model(params, x):
  w, b = params
  return jnp.dot(x, w) + b # y = X.W + b

# Initialize the parameters (this is the state, like S in the previous example)
params = (jnp.array([1.0, 2.0]), jnp.array([3.0])) #(w, b)

def update_params(params, grads, lr):
  w, b = params
  dw, db = grads
  new_w = w - lr * dw
  new_b = b - lr * db
  return (new_w, new_b)

# Example gradients (arbitrary values, same structure and shapes as params)
grads = (jnp.array([0.1, 0.2]), jnp.array([0.4]))

# Update the parameters (i.e., compute the new state)
new_params = update_params(params, grads, lr=0.1)
new_params

(Array([0.99, 1.98], dtype=float32), Array([2.96], dtype=float32))
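The gradients above were arbitrary. A minimal sketch of where they would come from in practice: `jax.value_and_grad` differentiates a loss with respect to `params` and returns gradients with the same `(w, b)` structure, which can be fed straight into the update. The mean-squared-error loss, the toy data `x`/`y`, the learning rate `0.01`, and the names `loss_fn`/`train_step` are all assumptions for illustration. `jax.tree_util.tree_map` is used here as a variant of `update_params` that works for any nesting of parameters, not just a two-element tuple.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def loss_fn(params, x, y):
    # Mean squared error between predictions and targets (assumed loss)
    return jnp.mean((model(params, x) - y) ** 2)

def update_params(params, grads, lr):
    # Same update as before, applied leaf by leaf to any pytree of params
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

@jax.jit
def train_step(params, x, y):
    # State in: params. State out: new params (plus the current loss)
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    return update_params(params, grads, lr=0.01), loss

# Hypothetical toy data generated from w = [2, -1], b = 0.5
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.dot(x, jnp.array([2.0, -1.0])) + 0.5

params = (jnp.array([1.0, 2.0]), jnp.array([3.0]))
losses = []
for step in range(200):
    params, loss = train_step(params, x, y)  # rebind to the new state
    losses.append(float(loss))

print(losses[0], losses[-1])
```

Each iteration follows the same pattern as `normalizer`: the old state goes in, a new state comes out, and we rebind the name to it.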