# Tutorial: Getting Started with JAX for Machine Learning

Having seen JAX's ability to differentiate and vectorise functions (`01_jax.ipynb`), we might reasonably consider taking gradients of functions such as "the loss function of a model" and vectorising operations such as "compute the parameter gradients for every sample in the data set".
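Here is a minimal sketch of both ideas together. It assumes a toy linear model whose parameters live in a plain dict. The names `params`, `predict` and `per_sample_loss` are illustrative and do not come from any library. `jax.grad` gives the gradient of one sample's loss, and `jax.vmap` maps that over every sample:

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Hypothetical toy model y = W x + b, with parameters held in a dict (a pytree)
key_w, key_b, key_x = jr.split(jr.PRNGKey(0), 3)
params = {
    "W": jr.normal(key_w, (3, 2)),
    "b": jr.normal(key_b, (3,)),
}

def predict(params, x):
    # x is a single sample of shape (2,)
    return params["W"] @ x + params["b"]

def per_sample_loss(params, x, y):
    # Squared error for one (x, y) pair
    return jnp.sum((predict(params, x) - y) ** 2)

# Toy data set of 5 samples
xs = jr.normal(key_x, (5, 2))
ys = jnp.zeros((5, 3))

# Differentiate w.r.t. params (argument 0); map over the leading axis of xs and ys only
per_sample_grads = jax.vmap(jax.grad(per_sample_loss), in_axes=(None, 0, 0))(params, xs, ys)

print(per_sample_grads["W"].shape)  # (5, 3, 2): one gradient of W per sample
print(per_sample_grads["b"].shape)  # (5, 3)
```

Averaging `per_sample_grads` over the first axis recovers the gradient of the mean loss. Keeping the per-sample gradients separate is useful when each one needs individual treatment, for example per-sample clipping.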

### Optional: Model training in pure JAX

In [6]:
# NotImplementedError
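Training in pure JAX needs little beyond `jax.grad`, `jax.jit` and `jax.tree_util.tree_map`. The following is a minimal sketch, not the tutorial's own implementation, and it rests on several assumptions:

- a toy linear model stored as a dict;
- noise-free synthetic data generated from hypothetical "true" weights;
- plain gradient descent with a hand-picked learning rate.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Noise-free synthetic regression data from hypothetical "true" parameters
key_x, key_w, key_init = jr.split(jr.PRNGKey(0), 3)
true_W = jr.normal(key_w, (3, 2))
true_b = jnp.array([1.0, -2.0, 0.5])
xs = jr.normal(key_x, (128, 2))
ys = xs @ true_W.T + true_b  # shape (128, 3)

# Model parameters as a pytree (a dict of arrays)
params = {"W": jr.normal(key_init, (3, 2)), "b": jnp.zeros(3)}
LEARNING_RATE = 0.1  # hand-picked for this toy problem

def predict(params, x):
    # Linear model for a single sample x of shape (2,)
    return params["W"] @ x + params["b"]

def loss_fn(params, xs, ys):
    # Mean squared error, vectorising predict over the batch axis
    preds = jax.vmap(predict, in_axes=(None, 0))(params, xs)
    return jnp.mean((preds - ys) ** 2)

@jax.jit
def sgd_step(params, xs, ys):
    # Loss and gradients with respect to params in one pass
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    # Gradient descent applied leaf-by-leaf across the parameter pytree
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

for step in range(500):
    params, loss = sgd_step(params, xs, ys)

print(loss)  # near zero, since the data contain no noise
```

The parameters are an ordinary pytree, so `tree_map` applies the same update rule to every leaf. Libraries such as equinox and optax package up exactly this kind of bookkeeping.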

### Model training with JAX, equinox and optax

Model training with JAX is an obvious enough idea that several libraries emerged in 2019-20 to provide abstractions such as network layers and utility functions for parameter updates. The most prominent of these are [Flax](https://flax.readthedocs.io/en/latest/) (Google Research) and [Haiku](https://dm-haiku.readthedocs.io/en/latest/index.html) (Google DeepMind), neither of which I find very satisfactory.

For getting started, I recommend the newer [equinox](https://docs.kidger.site/equinox/) library. Here is a linear layer implemented as an equinox `Module`, ready for training:

In [5]:
import jax
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx

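batch_size, in_size, out_size = 32, 2, 3
# Placeholder data: a batch of inputs and targets (all zeros)
x = jnp.zeros((batch_size, in_size))
y = jnp.zeros((batch_size, out_size))

# An eqx.Module is a dataclass-like pytree: its array fields are the
# leaves that JAX transformations such as jax.grad operate on
class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, seed):
        # Separate PRNG keys for the weight and bias initialisations
        wkey, bkey = jr.split(jr.PRNGKey(seed))
        self.weight = jr.normal(wkey, (out_size, in_size))
        self.bias = jr.normal(bkey, (out_size,))

    def __call__(self, x):
        # Acts on a single sample x of shape (in_size,); batch with jax.vmap
        return self.weight @ x + self.bias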

model = Linear(in_size, out_size, seed=404)


In [None]:
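# jax.grad differentiates with respect to the first argument (the model),
# so this decorated function returns gradients rather than the loss value
@jax.jit
@jax.grad
def loss_fn(model, x, y):
    # Apply the per-sample model across the batch dimension
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

# grads is a Linear instance whose weight and bias fields hold the gradients
grads = loss_fn(model, x, y)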