In [3]:
import jax
import jax.numpy as jnp # type: ignore
import numpy as np # type: ignore

from jax import grad, jit, vmap, pmap # type: ignore
import matplotlib.pyplot import plt


In [5]:
simple_pytree = [
    [9, 5, None],
    [3, (2, 8)],
    {"p": 78},
    {"q": {"r": 5, "s": {"t": 100, "u": 200}}},
    4
]

for elem in simple_pytree:
    leaves = jax.tree_util.tree_leaves(elem)
    print(f"{elem} contains {len(leaves)} leaves ... {leaves}")

[9, 5, None] contains 2 leaves ... [9, 5]
[3, (2, 8)] contains 3 leaves ... [3, 2, 8]
{'p': 78} contains 1 leaves ... [78]
{'q': {'r': 5, 's': {'t': 100, 'u': 200}}} contains 3 leaves ... [5, 100, 200]
4 contains 1 leaves ... [4]


In [7]:
print(jax.tree_util.tree_map(lambda x: x**2, simple_pytree))

[[81, 25, None], [9, (4, 64)], {'p': 6084}, {'q': {'r': 25, 's': {'t': 10000, 'u': 40000}}}, 16]


In [13]:
# For the above, we can use multiple different pytrees but they must have the same structure in order to be operated together.

In [18]:
def init_mlp_params(layer_widths):
    params = []

    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        params.append(
            dict(weight=np.random.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
                bias=np.ones(shape=(n_out,))
            )
        )
    
    return params

params = init_mlp_params([1, 128, 128, 1])
jax.tree_util.tree_map(lambda x: x.shape, params)

[{'bias': (128,), 'weight': (1, 128)},
 {'bias': (128,), 'weight': (128, 128)},
 {'bias': (1,), 'weight': (128, 1)}]

In [19]:
def forward(params, x):
    *hidden, last = params

    for layer in hidden:
        x  = jax.nn.relu(jnp.dot(x, layer["weight"]) + layer["bias"])

    return jax.dot(x, last["weight"]) + last["bias"]

def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

lr = 0.0001

@jit
def update(params, x, y):

    grads = jax.grad(loss_fn)(params, x, y)

    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

In [17]:
xs = np.random.normal(size=(128, 1))
ys = xs ** 2

num_epochs = 5000
for _ in range(num_epochs):
    params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.scatter(xs, forward(params, xs), label="Model predictions")
plt.legend();

KeyError: 'bias'