Pytrees: worked examples from the JAX tutorial "Working with pytrees"
https://docs.jax.dev/en/latest/working-with-pytrees.html

In [1]:
import jax
import jax.numpy as jnp


# Sample pytrees used by the leaf-counting cell below; one tree per entry.
example_trees = [
    [1, "a", object()],               # flat list: int, str and a bare object
    (1, (2, 3), ()),                  # nested tuples, one of them empty
    [1, {"k1": 2, "k2": (3, 4)}, 5],  # list -> dict -> tuple nesting
    dict(a=2, b=(2, 3)),              # dict with a tuple value
    jnp.array([1, 2, 3]),             # a JAX array on its own
]

In [2]:
# Flatten every example tree and report its leaves next to the tree itself.
for pytree in example_trees:
    leaves = jax.tree.leaves(pytree)
    # `!r` takes repr() of the tree before the left-aligned, width-45 padding;
    # `{leaves=}` prints the variable name together with its value.
    print(f"{pytree!r:<45} has {len(leaves)} {leaves=}")


[1, 'a', <object object at 0x7e08940f7ff0>]   has 3 leaves=[1, 'a', <object object at 0x7e08940f7ff0>]
(1, (2, 3), ())                               has 3 leaves=[1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves=[1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves=[2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves=[Array([1, 2, 3], dtype=int32)]


In [3]:
# Ragged nested lists: jax.tree.map applies the function to every leaf and
# hands back a result with the same nesting.
list_of_lists = [
    [1, 2, 3],
    [3, 4],
    [1, 2, 3, 4, 5],
]

jax.tree.map(lambda leaf: leaf * 2, list_of_lists)

[[2, 4, 6], [6, 8], [2, 4, 6, 8, 10]]

In [4]:
import numpy as np


def init_mlp_params(layer_widths, rng=None):
    """Build the parameters of a fully connected network as a list of dicts.

    Args:
        layer_widths: Sequence of layer sizes, input width first. Every
            consecutive pair ``(n_in, n_out)`` produces one layer; with
            fewer than two widths the result is an empty list.
        rng: Optional ``numpy.random.Generator`` used to draw the weights.
            Pass a seeded generator (e.g. ``np.random.default_rng(0)``) to
            get reproducible parameters. When ``None`` (the default) the
            weights are drawn from NumPy's global random state through
            ``np.random.randn``, exactly as before this parameter existed.

    Returns:
        A list with one dict per layer. Each dict holds ``"weights"``, a
        standard-normal array of shape ``(n_in, n_out)`` scaled by
        ``sqrt(2 / n_in)``, and ``"biases"``, an array of ones of shape
        ``(n_out,)``.
    """
    if rng is None:
        # Legacy path: keeps results identical for callers that rely on
        # np.random.seed(...) or on the previous unseeded behaviour.
        draw_normal = np.random.randn
    else:
        def draw_normal(n_in, n_out):
            # Generator.standard_normal takes the shape as a single tuple.
            return rng.standard_normal((n_in, n_out))

    return [
        dict(
            weights=draw_normal(n_in, n_out) * np.sqrt(2 / n_in),
            biases=np.ones(shape=(n_out,)),
        )
        for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])
    ]


# Build a 1 -> 128 -> 128 -> 1 network, then map each parameter array to its shape.
params = init_mlp_params(layer_widths=[1, 128, 128, 1])

jax.tree.map(lambda leaf: leaf.shape, params)

[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]