Stateful Computations in JAX
https://jax.readthedocs.io/en/latest/jax-101/07-state.html

In [None]:
import jax
import jax.numpy as jnp

In [None]:
class Counter:
    """A counter that keeps its running value as mutable instance state.

    ``count`` mutates ``self.n`` in place. This is a Python-side effect that
    ``jax.jit`` does not track: it only runs while a function is traced.
    """

    def __init__(self):
        # Running count, starting from zero.
        self.n: int = 0

    def count(self) -> int:
        """Advance the stored count by one and return the new value."""
        updated = self.n + 1
        self.n = updated
        return updated

    def reset(self) -> None:
        """Put the stored count back to zero."""
        self.n = 0


counter = Counter()
for _step in range(4):
    # Plain Python keeps the mutation between calls, so this prints 1..4.
    print(f'{counter.count()}')


In [None]:
# Deliberately incorrect example: jitting a method that mutates `self`.
# jax.jit traces `counter.count` on the first call. The Python-side
# `self.n += 1` runs only during that trace, and the returned value (1) is
# baked into the compiled function as a constant. Every call therefore
# prints 1, and `counter.n` is not incremented again after tracing.
counter = Counter()
compiled_count = jax.jit(counter.count)

for _ in range(4):
    print(f'{compiled_count()}')

In [None]:
from typing import Tuple

# Counter with explicit state: the state is passed in and returned, so
# `count` is a pure function that jax.jit can compile correctly.
CounterState = int


class CounterWithState:
    """Counter without hidden state; callers thread the state through."""

    def count(self, n: CounterState) -> Tuple[int, CounterState]:
        """Return ``(value, new_state)``; both are ``n + 1``."""
        nxt = n + 1
        return nxt, nxt

    def reset(self) -> CounterState:
        """Return the initial state, zero."""
        return 0


counter_with_state = CounterWithState()
compiled_count = jax.jit(counter_with_state.count)

# Begin from the initial state and feed each returned state back in.
state = counter_with_state.reset()
for _ in range(4):
    value, state = compiled_count(state)
    print(f'{value}')

In [None]:
from typing import NamedTuple

class Params(NamedTuple):
    """Parameters of the linear model ``forward(x) = x . W + b``."""

    W: jnp.ndarray  # weight
    b: jnp.ndarray  # bias


def init(rng) -> Params:
    """Draw initial ``W`` and ``b`` from a standard normal distribution.

    Args:
      rng: JAX PRNG key; it is split so ``W`` and ``b`` get independent keys.

    Returns:
      ``Params`` whose ``W`` and ``b`` are scalars (shape ``()``).
    """
    W_key, b_key = jax.random.split(rng)
    W = jax.random.normal(W_key, ())
    b = jax.random.normal(b_key, ())
    return Params(W, b)


def forward(params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """Apply the linear model ``jnp.dot(x, W) + b``.

    With the scalar ``W`` produced by ``init``, ``jnp.dot`` reduces to an
    elementwise scaling of ``x``.
    """
    return jnp.dot(x, params.W) + params.b


def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Return the mean squared error between ``forward(params, x)`` and ``y``."""
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def update(
    params: Params, x: jnp.ndarray, y: jnp.ndarray, lr: float = 5e-3
) -> Params:
    """Take one gradient-descent step on ``loss``.

    Args:
      params: current model parameters.
      x: model inputs.
      y: regression targets.
      lr: learning rate.

    Returns:
      New ``Params`` equal to ``params - lr * grad(loss)``, leaf by leaf.
    """
    grads = jax.grad(loss)(params, x, y)
    # jax.tree_multimap has been removed from JAX; tree_util.tree_map takes
    # several pytrees of matching structure and maps over them together.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

In [None]:
import matplotlib.pyplot as plt

rng = jax.random.PRNGKey(42)
# Split into three independent keys and give `init` its own key. Passing the
# already-split `rng` to `init` would make it re-derive exactly `x_key` and
# `noise_key`, because `init` splits its key into two. JAX keys must not be
# reused after splitting.
x_key, noise_key, init_key = jax.random.split(rng, 3)

# Synthetic dataset: ys = 2 * xs - 1 plus Gaussian noise with std 0.1.
Ws = 2
bs = -1
xs = jax.random.normal(x_key, (128, 1))
noise = 0.1*jax.random.normal(noise_key, (128, 1))
ys = Ws*xs + bs + noise

# Fit with 1000 jitted update steps, threading the parameters explicitly.
params = init(init_key)
for _ in range(1000):
    params = update(params, xs, ys)

plt.scatter(xs, ys, label='dataset')
plt.scatter(xs, forward(params, xs), label='model')
plt.legend()
plt.show()