In [1]:
import jax
from typing import Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

In [2]:
model = nn.Dense(features=5)

In [3]:
model

Dense(
    # attributes
    features = 5
    use_bias = True
    dtype = None
    param_dtype = float32
    precision = None
    kernel_init = init
    bias_init = zeros
    dot_general = None
    dot_general_cls = None
)

In [4]:
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,)) # Dummy input data
params = model.init(key2, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes

{'params': {'bias': (5,), 'kernel': (10, 5)}}

In [5]:
model.apply(params, x)

Array([-1.3721197 ,  0.61131513,  0.6442838 ,  2.2192965 , -1.1271117 ],      dtype=float32)

In [6]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [7]:
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y-pred, y-pred) / 2.0
    # Vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [8]:
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    return params

for i in range(101):
    # Perform one gradient update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.023639798
Loss step 0:  35.343876
Loss step 10:  0.51505065
Loss step 20:  0.114045195
Loss step 30:  0.039395172
Loss step 40:  0.01994014
Loss step 50:  0.01421761
Loss step 60:  0.012428714
Loss step 70:  0.011851465
Loss step 80:  0.011662135
Loss step 90:  0.011599515
Loss step 100:  0.011578727


In [9]:
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

In [10]:
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)

Loss step 0:  0.01157765
Loss step 10:  0.2614
Loss step 20:  0.07683634
Loss step 30:  0.036484905
Loss step 40:  0.022029908
Loss step 50:  0.016185088
Loss step 60:  0.01299824
Loss step 70:  0.012026562
Loss step 80:  0.011765008
Loss step 90:  0.0116460025
Loss step 100:  0.0115856035


In [11]:
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': Array([-1.4555818 , -2.0278208 ,  2.0790894 ,  1.2186159 , -0.99807364],      dtype=float32), 'kernel': Array([[ 1.0098742 ,  0.18935174,  0.04456387, -0.9280217 ,  0.3478275 ],
       [ 1.7298428 ,  0.9879267 ,  1.1640524 ,  1.1006035 , -0.10651099],
       [-1.2029499 ,  0.28633362,  1.4155992 ,  0.11869171, -1.3141348 ],
       [-1.1941439 , -0.18958223,  0.03415466,  1.3169502 ,  0.08061308],
       [ 0.13851956,  1.3712972 , -1.3187258 ,  0.53151757, -2.2405186 ],
       [ 0.56294185,  0.8122366 ,  0.317533  ,  0.5345454 ,  0.9049949 ],
       [-0.37925696,  1.7410626 ,  1.0790431 , -0.5039805 ,  0.92827296],
       [ 0.97064775, -1.3153123 ,  0.33682504,  0.8099361 , -1.2018671 ],
       [ 1.0194423 , -0.62024164,  1.0818756 , -1.8389823 , -0.45808458],
       [-0.6436587 ,  0.45668206, -1.1329095 , -0.6853872 ,  0.1683104 ]],      dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14\x81P\xba\xbf\xd1\xc

In [12]:
serialization.from_bytes(params, bytes_output)

{'params': {'bias': array([-1.4555818 , -2.0278208 ,  2.0790894 ,  1.2186159 , -0.99807364],
        dtype=float32),
  'kernel': array([[ 1.0098742 ,  0.18935174,  0.04456387, -0.9280217 ,  0.3478275 ],
         [ 1.7298428 ,  0.9879267 ,  1.1640524 ,  1.1006035 , -0.10651099],
         [-1.2029499 ,  0.28633362,  1.4155992 ,  0.11869171, -1.3141348 ],
         [-1.1941439 , -0.18958223,  0.03415466,  1.3169502 ,  0.08061308],
         [ 0.13851956,  1.3712972 , -1.3187258 ,  0.53151757, -2.2405186 ],
         [ 0.56294185,  0.8122366 ,  0.317533  ,  0.5345454 ,  0.9049949 ],
         [-0.37925696,  1.7410626 ,  1.0790431 , -0.5039805 ,  0.92827296],
         [ 0.97064775, -1.3153123 ,  0.33682504,  0.8099361 , -1.2018671 ],
         [ 1.0194423 , -0.62024164,  1.0818756 , -1.8389823 , -0.45808458],
         [-0.6436587 ,  0.45668206, -1.1329095 , -0.6853872 ,  0.1683104 ]],
        dtype=float32)}}

In [13]:
class ExplicitMLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        # we automatically know what to do with lists, dicts of submodules
        self.layers = [nn.Dense(feat) for feat in self.features]
        # for single submodules, we would just write:
        # self.layer1 = nn.Dense(feat1)

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723787 -0.00810345 -0.0255093   0.02151708 -0.01261237]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


In [14]:
try:
    y = model(x) # Returns an error
except AttributeError as e:
    print(e)

"ExplicitMLP" object has no attribute "layers". If "layers" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.


In [15]:
class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
        # providing a name is optional though!
        # the default autonames would be "Dense_0", "Dense_1", ...
        return x

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723787 -0.00810345 -0.0255093   0.02151708 -0.01261237]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


In [16]:
class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs):
        kernel = self.param('kernel',
                            self.kernel_init, # Initialization function
                            (inputs.shape[-1], self.features))  # shape info.
        y = jnp.dot(inputs, kernel)
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias
        return y

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

initialized parameters:
 {'params': {'kernel': Array([[ 0.61506   , -0.22728713,  0.60547006],
       [-0.2961799 ,  1.1232013 , -0.879759  ],
       [-0.35162622,  0.38064912,  0.68932474],
       [-0.1151355 ,  0.04567899, -1.091212  ]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}}
output:
 [[-0.029962    1.102088   -0.6660265 ]
 [-0.31092793  0.6323942  -0.5367881 ]
 [ 0.0142401   0.9424717  -0.6356147 ]
 [ 0.36818963  0.35865188 -0.00459227]]


In [17]:
class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        # easy pattern to detect if we're initializing via empty variable tree
        is_initialized = self.has_variable('batch_stats', 'mean')
        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s),
                                x.shape[1:])
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

        return x - ra_mean.value + bias


key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)

initialized variables:
 {'batch_stats': {'mean': Array([0., 0., 0., 0., 0.], dtype=float32)}, 'params': {'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}


In [18]:
for val in [1.0, 2.0, 3.0]:
    x = val * jnp.ones((10,5))
    y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
    old_state, params = flax.core.pop(variables, 'params')
    variables = flax.core.freeze({'params': params, **updated_state})
    print('updated state:\n', updated_state) # Shows only the mutable part

updated state:
 {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32)}}


In [19]:
from functools import partial

@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):

    def loss(params):
        y, updated_state = apply_fn({'params': params, **state},
                                    x, mutable=list(state.keys()))
        _l = ((x - y) ** 2).sum()
        return _l, updated_state

    (_l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return opt_state, params, state

x = jnp.ones((10,5))
variables = model.init(random.key(0), x)
state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(3):
    opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
    print('Updated state: ', state)

Updated state:  {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}
Updated state:  {'batch_stats': {'mean': Array([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32)}}
Updated state:  {'batch_stats': {'mean': Array([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32)}}
