In [15]:
#https://github.com/deepmind/optax/blob/master/examples/flax_example.py

from flax import linen as nn
import jax
import jax.numpy as jnp
import optax

In [22]:
# Training hyper-parameters.
learning_rate = 0.01     # Step size intended for the optimiser.
n_training_steps = 100   # Number of iterations of the training loop.


[ 0.06141227  0.00766806 -0.05698649 ...  0.01272237  0.01838974
   0.02176958]

## =====================================================================
# Seed the root PRNG key and derive two sub-keys from it.
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng)

# The model is a single dense (linear) layer with 5 output features.
model = nn.Dense(features=5)

# Initialise the layer's parameters: `rng1` draws a dummy length-10 input,
# and `rng2` is the key handed to `model.init`.
dummy_input = jax.random.normal(rng1, (10,))
params = model.init(rng2, dummy_input)


print(params)

## =====================================================================

# Set problem dimensions.
nsamples = 20
xdim = 10
ydim = 5

# Use a dedicated key for data generation and split it once into one sub-key
# per random draw. JAX keys must be consumed at most once: reusing a key
# (or sampling from a key and also splitting it) yields correlated values,
# so none of the keys used for parameter initialisation is touched here.
data_rng = jax.random.PRNGKey(1)
kweight, kbias, ksample, knoise = jax.random.split(data_rng, 4)

# Generate random ground truth w and b.
w = jax.random.normal(kweight, (xdim, ydim))
b = jax.random.normal(kbias, (ydim,))

# Generate samples with additional noise:
#   y = x @ w + b + 0.1 * eps,  eps ~ N(0, 1).
x_samples = jax.random.normal(ksample, (nsamples, xdim))
y_samples = jnp.dot(x_samples, w) + b
y_samples += 0.1 * jax.random.normal(knoise, (nsamples, ydim))


print(x_samples.shape)
print(y_samples.shape)


# Build a jitted loss function bound to one fixed batch of data.
def make_mse_func(x_batched, y_batched):
  """Return a jitted function mapping `params` to the batch-averaged loss.

  The per-example loss is half the squared Euclidean norm of the residual
  `y - model.apply(params, x)`; the returned function averages it over the
  leading axis of the batch. Uses the module-level `model`.

  Args:
    x_batched: inputs; `jax.vmap` maps over the leading axis.
    y_batched: targets; `jax.vmap` maps over the leading axis.

  Returns:
    A `jax.jit`-compiled callable `mse(params)`.
  """
  def mse(params):
    # Loss for a single (x, y) pair: 0.5 * ||y - f(x)||^2.
    def half_squared_residual(x, y):
      residual = y - model.apply(params, x)
      return jnp.inner(residual, residual) / 2.0

    # Evaluate every example in the batch, then average across the batch.
    per_example = jax.vmap(half_squared_residual, in_axes=(0, 0))(
        x_batched, y_batched)
    return jnp.mean(per_example, axis=0)

  return jax.jit(mse)  # Compile the closure once with `jit`.


FrozenDict({
    params: {
        kernel: DeviceArray([[ 2.35571519e-01, -1.71652585e-01, -4.45728786e-02,
                      -4.68043566e-01,  4.54595268e-01],
                     [-6.87736452e-01,  3.67835373e-01, -1.79262087e-01,
                       1.29276231e-01, -2.42580160e-01],
                     [ 2.02303097e-01, -2.49465615e-01,  2.74955630e-01,
                       4.73488361e-01, -1.98002517e-01],
                     [ 2.74478316e-01, -1.21369645e-01, -2.25361675e-01,
                      -4.78193641e-01, -9.63979885e-02],
                     [-6.19886033e-02, -1.72743499e-01,  2.96945305e-04,
                      -7.17593372e-01,  2.00894207e-01],
                     [-5.60321152e-01,  3.27208370e-01,  1.06281497e-01,
                       1.28758654e-01,  1.16973236e-01],
                     [ 1.82218999e-01,  1.11444063e-01, -1.62924141e-01,
                       3.24953087e-02, -1.67053342e-01],
                     [ 4.31294113e-01,  2.08004564e-01,

In [18]:
# Instantiate the sampled loss.
loss = make_mse_func(x_samples, y_samples)

# Build the Adam optimiser from the configured `learning_rate` rather than a
# hard-coded literal, so changing the hyper-parameter actually takes effect.
optimizer = optax.adam(learning_rate=learning_rate)

# Create optimiser state.
opt_state = optimizer.init(params)
# Function returning both the loss value and its gradient w.r.t. the params.
loss_grad_fn = jax.value_and_grad(loss)


# Show the parameters before training.
print(params)
# Minimise the loss.
for step in range(n_training_steps):
    # Compute the loss and its gradient at the current parameters.
    loss_val, grads = loss_grad_fn(params)
    # Update the optimiser state, create an update to the params.
    updates, opt_state = optimizer.update(grads, opt_state)
    # Update the parameters.
    params = optax.apply_updates(params, updates)
    print(f'Loss[{step}] = {loss_val}')

FrozenDict({
    params: {
        bias: DeviceArray([-0.87172574, -0.86985385,  0.8467341 ,  0.84649915,
                     -0.39169806], dtype=float32),
        kernel: DeviceArray([[ 0.94558173, -0.16723742,  0.3746021 , -1.0775216 ,
                       0.25383565],
                     [ 0.22913468,  0.47971997,  0.62425894,  0.72347826,
                       0.20757319],
                     [-0.6437151 , -0.1304319 ,  1.0882291 ,  0.37703192,
                      -0.9357848 ],
                     [-0.57490134, -0.71605146,  0.22252326,  0.37048048,
                      -0.5051023 ],
                     [-0.09404071,  0.6873026 , -0.84188306,  0.10151887,
                      -0.69316685],
                     [ 0.17136317, -0.22676271, -0.26385623,  0.10000193,
                       0.6823034 ],
                     [-0.48451614,  0.8493618 ,  0.6192771 , -0.74398625,
                       0.70105404],
                     [ 0.9236116 , -0.6056143 ,  0.874353  ,  0.7