In [1]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

In [2]:
model = nn.Dense(features=5)

In [3]:
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,))
params = model.init(key2, x)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'params': {'bias': (5,), 'kernel': (10, 5)}}

In [4]:
n_samples = 20
x_dim = 10
y_dim = 5

key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [5]:
@jax.jit
def mse(params, x_batched, y_batched):
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [6]:
learning_rate = 0.3
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.023639796
Loss step 0:  35.343876
Loss step 10:  0.51505077
Loss step 20:  0.11404524
Loss step 30:  0.039395202
Loss step 40:  0.01994018
Loss step 50:  0.014217627
Loss step 60:  0.012428728
Loss step 70:  0.011851473
Loss step 80:  0.011662121
Loss step 90:  0.011599523
Loss step 100:  0.01157873


In [7]:
params = model.init(key2, x)

def rollout(carry, x):
    params, step = carry
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)

    def print_loss_fn():
        jax.debug.print('Loss step {i}: {loss}', i=step, loss=loss_val)
        return loss_val
    
    loss_val = jax.lax.cond(step % 10 == 0, lambda _: print_loss_fn(), lambda _: loss_val, operand=None)
    return (params, step+1), loss_val

_, loss_vals = jax.lax.scan(rollout, (params, 0), None, 101)

Loss step 0: 35.343875885009766
Loss step 10: 0.5150507688522339
Loss step 20: 0.11404523998498917
Loss step 30: 0.039395201951265335
Loss step 40: 0.019940180703997612
Loss step 50: 0.014217627234756947
Loss step 60: 0.012428727932274342
Loss step 70: 0.011851472780108452
Loss step 80: 0.011662120930850506
Loss step 90: 0.011599523015320301
Loss step 100: 0.011578730307519436


In [8]:
params = model.init(key2, x)

def rollout(carry, x):
    (params, step, loss_history) = carry
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)

    def store_loss_fn():
        return loss_history.at[step // 10].set(loss_val)
    
    loss_history = jax.lax.cond(step % 10 == 0, lambda _: store_loss_fn(), lambda _: loss_history, operand=None)
    return (params, step+1, loss_history), loss_val
#                                            f         init                                 xs    length
(_, _, final_loss_history), _ = jax.lax.scan(rollout, (params, 0, jnp.zeros((101 // 10,))), None, 101)

In [9]:
final_loss_history

Array([3.5343876e+01, 5.1505077e-01, 1.1404524e-01, 3.9395202e-02,
       1.9940181e-02, 1.4217627e-02, 1.2428728e-02, 1.1851473e-02,
       1.1662121e-02, 1.1599523e-02], dtype=float32)