In [1]:
from flax import linen as nn

model = nn.Dense(features=5)

In [2]:
from jax import random
import jax.numpy as jnp
import jax

key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,)) # Dummy input data
params = model.init(key2, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes

2024-07-15 00:17:03.018309: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


{'params': {'bias': (5,), 'kernel': (10, 5)}}

In [3]:
model.apply(params, x)

Array([-1.3721199 ,  0.611315  ,  0.64428365,  2.2192967 , -1.1271119 ],      dtype=float32)

In [4]:
import flax

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [5]:
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [6]:
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.023639796
Loss step 0:  35.343876
Loss step 10:  0.51434684
Loss step 20:  0.11384165
Loss step 30:  0.039326724
Loss step 40:  0.019916201
Loss step 50:  0.014209116
Loss step 60:  0.012425651
Loss step 70:  0.011850391
Loss step 80:  0.011661771
Loss step 90:  0.011599408
Loss step 100:  0.011578708


In [7]:
import optax


tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

In [8]:
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.011577626
Loss step 10:  0.26143175
Loss step 20:  0.07674778
Loss step 30:  0.0364394
Loss step 40:  0.022012014
Loss step 50:  0.016178384
Loss step 60:  0.013002939
Loss step 70:  0.012026127
Loss step 80:  0.011764488
Loss step 90:  0.011646035
Loss step 100:  0.011585518


In [9]:
from flax import serialization

bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': Array([-1.4555764, -2.027799 ,  2.0790977,  1.2186146, -0.9980985],      dtype=float32), 'kernel': Array([[ 1.0098803 ,  0.18934338,  0.04455002, -0.9280221 ,  0.34784055],
       [ 1.7298455 ,  0.987937  ,  1.1640465 ,  1.1006078 , -0.10653927],
       [-1.2029461 ,  0.28635207,  1.4155982 ,  0.11870942, -1.3141488 ],
       [-1.1941484 , -0.1895852 ,  0.0341387 ,  1.3169428 ,  0.08060375],
       [ 0.13852431,  1.3713043 , -1.3187188 ,  0.53152657, -2.2404993 ],
       [ 0.5629401 ,  0.8122313 ,  0.3175202 ,  0.5345511 ,  0.9050041 ],
       [-0.37926012,  1.7410393 ,  1.0790291 , -0.5039834 ,  0.92830706],
       [ 0.9706488 , -1.3153405 ,  0.33681518,  0.8099343 , -1.2018454 ],
       [ 1.0194312 , -0.6202478 ,  1.0818834 , -1.838974  , -0.45804858],
       [-0.6436537 ,  0.45666704, -1.1329136 , -0.6853865 ,  0.16828986]],      dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14TP\xba\xbfu\xc7\x01\xc0\x

In [10]:
serialization.from_bytes(params, bytes_output)

{'params': {'bias': array([-1.4555764, -2.027799 ,  2.0790977,  1.2186146, -0.9980985],
        dtype=float32),
  'kernel': array([[ 1.0098803 ,  0.18934338,  0.04455002, -0.9280221 ,  0.34784055],
         [ 1.7298455 ,  0.987937  ,  1.1640465 ,  1.1006078 , -0.10653927],
         [-1.2029461 ,  0.28635207,  1.4155982 ,  0.11870942, -1.3141488 ],
         [-1.1941484 , -0.1895852 ,  0.0341387 ,  1.3169428 ,  0.08060375],
         [ 0.13852431,  1.3713043 , -1.3187188 ,  0.53152657, -2.2404993 ],
         [ 0.5629401 ,  0.8122313 ,  0.3175202 ,  0.5345511 ,  0.9050041 ],
         [-0.37926012,  1.7410393 ,  1.0790291 , -0.5039834 ,  0.92830706],
         [ 0.9706488 , -1.3153405 ,  0.33681518,  0.8099343 , -1.2018454 ],
         [ 1.0194312 , -0.6202478 ,  1.0818834 , -1.838974  , -0.45804858],
         [-0.6436537 ,  0.45666704, -1.1329136 , -0.6853865 ,  0.16828986]],
        dtype=float32)}}

In [11]:
from typing import Sequence

class ExplicitMLP(nn.Module):
  features: Sequence[int]

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723789 -0.00810346 -0.02550935  0.02151712 -0.01261239]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


In [12]:
class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723789 -0.00810346 -0.02550935  0.02151712 -0.01261239]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


In [13]:
from typing import Callable

class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    y = jnp.dot(inputs, kernel)
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

initialized parameters:
 {'params': {'kernel': Array([[ 0.61506   , -0.22728713,  0.6054702 ],
       [-0.29617992,  1.1232015 , -0.879759  ],
       [-0.35162625,  0.38064915,  0.68932486],
       [-0.1151355 ,  0.04567895, -1.0912124 ]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}}
output:
 [[-0.02996206  1.1020882  -0.66602665]
 [-0.31092796  0.63239425 -0.5367882 ]
 [ 0.01424006  0.9424719  -0.63561475]
 [ 0.36818963  0.358652   -0.0045922 ]]


In [14]:
class BiasAdderWithRunningMean(nn.Module):
  decay: float = 0.99

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree
    is_initialized = self.has_variable('batch_stats', 'mean')
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s),
                            x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

    return x - ra_mean.value + bias


key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)

initialized variables:
 {'batch_stats': {'mean': Array([0., 0., 0., 0., 0.], dtype=float32)}, 'params': {'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}


In [15]:
for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10,5))
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  old_state, params = flax.core.pop(variables, 'params')
  variables = flax.core.freeze({'params': params, **updated_state})
  print('updated state:\n', updated_state) # Shows only the mutable part

updated state:
 {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32)}}


In [16]:
from functools import partial

@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):

  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum()
    return l, updated_state

  (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return opt_state, params, state

x = jnp.ones((10,5))
variables = model.init(random.key(0), x)
state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(3):
  opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
  print('Updated state: ', state)

Updated state:  {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}
Updated state:  {'batch_stats': {'mean': Array([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32)}}
Updated state:  {'batch_stats': {'mean': Array([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32)}}
