In [2]:
import jax
import flax
import optax
import os

from jax import numpy as jnp
from flax import linen as nn

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".20"

In [3]:
from typing import Callable, Sequence

@jax.jit
def relu(x):
    return jnp.maximum(x, 0)

class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs, *args, **kwargs):
        initialized = self.has_variable("stats", "counter")
        kernel = self.param("kernel", self.kernel_init, (inputs.shape[-1], self.features))
        bias = self.param("bias", self.bias_init, (self.features,))
        counter = self.variable("stats", "counter", lambda: 0)
        outputs = jnp.dot(inputs, kernel) + bias

        if initialized:
            counter.value += 1

        return outputs

class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs, *args, **kwargs):
        x = inputs
        last = len(self.features) - 1
        for i, n in enumerate(self.features):
            x = SimpleDense(n)(x)
            if i != last:
                x = relu(x)
        return x


In [4]:
n_batch = 20
x_dim = 10
y_dim = 3

key = jax.random.PRNGKey(0)
key_w, key_b, key_x, key_noise = jax.random.split(key, 4)

# Setup true line
true_w = jax.random.normal(key_w, shape=(x_dim, y_dim))
true_b = jax.random.normal(key_b, shape=(y_dim,))

# Prepare training samples
x_samples = jax.random.normal(key_x, shape=(n_batch, x_dim))
y_samples = jnp.dot(x_samples, true_w) + true_b + 0.1 * jax.random.normal(key_noise, true_b.shape)

print(x_samples.shape, y_samples.shape) # (20, 10) (20, 3)

(20, 10) (20, 3)


In [5]:
key_init, key_infer = jax.random.split(key_noise)

model = SimpleMLP([16, 8, y_dim])
optim = optax.adam(0.01)

@jax.jit
def train_step(params, mstate, ostate):

    @jax.jit
    def mse(params, x_batch, y_batch):
        y_pred, updated_state = model.apply({ "params": params, **mstate}, x_batch, mutable=mstate.keys())
        y_diff = jnp.square(y_pred - y_batch)
        y_losses = jnp.sum(y_diff, axis=1) / 2  # example loss
        return jnp.mean(y_losses, axis=0), updated_state

    (loss, mstate), grad = jax.value_and_grad(mse, has_aux=True)(params, x_samples, y_samples)

    updates, ostate = optim.update(grad, ostate)
    params = optax.apply_updates(params, updates)
    return params, mstate, ostate, loss


In [6]:
variables = model.init(key_infer, x_samples)
mstate, params = variables.pop("params")
ostate = optim.init(params)

for epoch in range(1000):
    params, mstate, ostate, loss = train_step(params, mstate, ostate)
    if (epoch + 1) % 100 == 0:
        print(loss)

0.22833319
0.08406925
0.0697316
0.053109266
0.016352348
0.00046114367
5.5995006e-06
6.684547e-08
5.688082e-10
3.8080554e-12


In [7]:
jax.tree_map(jnp.shape, params)

FrozenDict({
    SimpleDense_0: {
        bias: (16,),
        kernel: (10, 16),
    },
    SimpleDense_1: {
        bias: (8,),
        kernel: (16, 8),
    },
    SimpleDense_2: {
        bias: (3,),
        kernel: (8, 3),
    },
})

In [8]:
mstate

FrozenDict({
    stats: {
        SimpleDense_0: {
            counter: Array(1000, dtype=int32, weak_type=True),
        },
        SimpleDense_1: {
            counter: Array(1000, dtype=int32, weak_type=True),
        },
        SimpleDense_2: {
            counter: Array(1000, dtype=int32, weak_type=True),
        },
    },
})