# Setting up environments

In [1]:
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
import flax
from flax import linen as nn

# Linear regression with Flax

In [2]:
model = nn.Dense(features=5)

## Model parameters & initialization

In [24]:
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10, ))
# model.init(PRNGKey, dummy data)
params = model.init(key2, x)
jax.tree_util.tree_map(lambda x: x.shape, params)

{'params': {'bias': (5,), 'kernel': (10, 5)}}

In [25]:
output, params = model.init_with_output(key2, x)
print(output.shape)
print(jax.tree_util.tree_map(lambda x: x.shape, params))

(5,)
{'params': {'bias': (5,), 'kernel': (10, 5)}}


In [26]:
model.apply(params, x)

Array([-1.4309565 ,  0.68059814, -0.47063   , -0.00443757,  0.61720204],      dtype=float32)

## Gradient descent

In [27]:
n_samples = 20
x_dim = 10
y_dim = 5

key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim, ))
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
print(type(true_params))

key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = (jnp.dot(x_samples, W) + b 
             + 0.1 * random.normal(key_noise, (n_samples, y_dim)))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

<class 'flax.core.frozen_dict.FrozenDict'>
x shape: (20, 10) ; y shape: (20, 5)


In [28]:
@jax.jit
def mse(params, x_batched, y_batched):
    def squared_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y-pred, y-pred) / 2.0
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched))

In [29]:
learning_rate = 0.3
print('Loss for "true W, b:', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    return params

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}:', loss_val)

Loss for "true W, b: 0.023639798
Loss step 0: 35.717175
Loss step 10: 0.52654123
Loss step 20: 0.11472414
Loss step 30: 0.037997153
Loss step 40: 0.019133301
Loss step 50: 0.013885747
Loss step 60: 0.012306851
Loss step 70: 0.011808872
Loss step 80: 0.011647595
Loss step 90: 0.011594624
Loss step 100: 0.011577098


## Optimizing with Optax

In [3]:
import optax
learning_rate = 0.3
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
params = model.init(random.PRNGKey(0), x_samples)

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: {loss_val}')

NameError: name 'params' is not defined

## Serializing the result

In [34]:
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Bytes output:', bytes_output)
print('Dict output:', dict_output)

Bytes output: b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14\xd44\xb9\xbf\x82\xa5\x01\xc0\x1f{\x05@=[\x9c?\xb6E\x80\xbf\xa6kernel\xc7\xd6\x01\x93\x92\n\x05\xa7float32\xc4\xc8I\x8f\x80?\x88\x82=>\n\xc71=p\xdam\xbf4j\xb3>\x86\xcc\xdc?\xa9Y~?\xdd\xd0\x95?W\xc2\x8d?^~\xd9\xbd\x90\xad\x98\xbfT8\x92>\x1f\xa9\xb5?n\x9e\xfd=\x16\x1d\xa8\xbf\x86{\x98\xbf\xf5\x88?\xbe\x8e\\$=\x8cg\xa9?l\x1f\xad=X\xa5\x0c>\x1f\xfd\xae?\xaf\xd6\xa8\xbfe\xd4\t?ax\x0f\xc0\x96\x8c\x0f?\xc8\xd0P?\x96:\xa6>\x0f\xc7\t?\xc8Qg?F\xcf\xc3\xbeT\t\xdf?)\x08\x8a?\xfd\x10\x01\xbf\xf6\x14m?\xbe\xa1w?\xe8)\xa8\xbfr)\xab>&NP?\xef\xbe\x99\xbfI=\x82?\xf8-\x1f\xbf\xad\\\x89??n\xeb\xbf\xe4\x1c\xeb\xbe\x92T&\xbf\xa40\xec>\x99j\x91\xbf\xda?/\xbfR\xc2)>'
Dict output: {'params': {'bias': Array([-1.4469247, -2.0257268,  2.0856397,  1.2215344, -1.0021274],      dtype=float32), 'kernel': Array([[ 1.0043727 ,  0.18506825,  0.04340271, -0.92911434,  0.35041964],
       [ 1.7249916 ,  0.9935556 ,  1.1704365 ,  1.1074933 , 

In [38]:
# params here is a template
serialization.from_bytes(params, bytes_output)

{'params': {'bias': array([-1.4469247, -2.0257268,  2.0856397,  1.2215344, -1.0021274],
        dtype=float32),
  'kernel': array([[ 1.0043727 ,  0.18506825,  0.04340271, -0.92911434,  0.35041964],
         [ 1.7249916 ,  0.9935556 ,  1.1704365 ,  1.1074933 , -0.10619806],
         [-1.1927967 ,  0.285586  ,  1.4192237 ,  0.12383734, -1.3133876 ],
         [-1.1912696 , -0.18704589,  0.04012733,  1.3234725 ,  0.08453259],
         [ 0.13734949,  1.3670996 , -1.3190516 ,  0.53839713, -2.2417223 ],
         [ 0.5607389 ,  0.81568575,  0.32466573,  0.53819364,  0.90359163],
         [-0.38244075,  1.7424722 ,  1.078374  , -0.5041655 ,  0.9261011 ],
         [ 0.96731174, -1.3137789 ,  0.33430058,  0.81369245, -1.2011393 ],
         [ 1.0174953 , -0.6217952 ,  1.0731407 , -1.839302  , -0.4592048 ],
         [-0.64972794,  0.4613086 , -1.1360656 , -0.68456805,  0.16578034]],
        dtype=float32)}}

# Defining your own model

## Module basics