Working with Pytrees

https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html

In [None]:
import jax
import numpy as np
import jax.numpy as jnp

In [None]:
# Pytrees show up throughout JAX code, for example in:
# - model parameters and their descriptions
# - datasets
# - RL states/observations

example_pytree = [
    jnp.array(-1),
    [1, 2, 3],
    {'foo': [1, 2, 3], 'bar': {'bar': 'foo'}},
]

# Flatten the tree into its leaves. Dict entries are visited in sorted-key
# order, so the 'bar' subtree comes before 'foo'. jax.tree_util.tree_leaves is
# the stable spelling; the top-level jax.tree_leaves alias is deprecated and
# removed in recent JAX releases.
leaves = jax.tree_util.tree_leaves(example_pytree)
leaves

In [None]:
from typing import List

def init_mlp(layers: List[int], key=None) -> List[dict]:
    """Initialise the parameters of a fully connected MLP.

    Args:
        layers: Layer widths, input first and output last; e.g.
            ``[1, 128, 128, 1]`` builds three dense layers.
        key: Optional ``jax.random.PRNGKey``. Defaults to ``PRNGKey(0)`` so the
            result is reproducible when no key is supplied.

    Returns:
        One dict per layer with ``'W'`` of shape ``(n_in, n_out)`` and ``'b'``
        of shape ``(n_out,)``. Weights are standard normal scaled by
        ``sqrt(2 / n_in)`` (He-style); biases start at one. Fewer than two
        widths yields an empty list.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    # One independent subkey per layer: reusing a single key would make every
    # layer with the same shape start from identical weights. max(..., 1)
    # keeps split() happy when there are no layers to build.
    layer_keys = jax.random.split(key, num=max(len(layers) - 1, 1))
    params = []
    for layer_key, n_in, n_out in zip(layer_keys, layers[:-1], layers[1:]):
        print(f'n_in {n_in}, n_out {n_out}')
        params.append({
            'W': jnp.sqrt(2 / n_in)
                 * jax.random.normal(layer_key, shape=(n_in, n_out)),
            'b': jnp.ones(shape=(n_out,)),
        })
    return params

# Build a 1 -> 128 -> 128 -> 1 network and display its parameter pytree.
params = init_mlp(layers=[1, 128, 128, 1])
params

In [None]:
# Map over every leaf (array) of params and report its shape; the result has
# the same list-of-dicts structure as params. jax.tree_util.tree_map is the
# stable spelling; the top-level jax.tree_map alias is deprecated and removed
# in recent JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, params)

In [None]:
# A 2x3 float matrix and a length-3 float vector.
a = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
b = jnp.array([1, 2, 3], dtype=jnp.float32)
# The @ operator and jnp.dot agree element-wise for a matrix-vector product.
a @ b == jnp.dot(a, b)

In [None]:
def forward(params: List[dict], x: jnp.ndarray) -> jnp.ndarray:
    """Run the MLP described by ``params`` on a batch ``x``.

    ``params`` is a list of ``{'W': ..., 'b': ...}`` layer dicts. Every layer
    but the last computes ``relu(x @ W + b)``; the last layer stays linear.
    """
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(jnp.dot(x, layer['W']) + layer['b'])
    # No activation function in last layer
    return jnp.dot(x, last['W']) + last['b']


def loss(params: List[dict], x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Mean squared error between ``forward(params, x)`` and targets ``y``."""
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def update(params: List[dict], x: jnp.ndarray, y: jnp.ndarray,
           lr=1e-4) -> List[dict]:
    """Take one plain gradient-descent step on ``loss``; return new params."""
    grads = jax.grad(loss)(params, x, y)
    # grads has exactly the same pytree structure as params, so the two trees
    # can be combined leaf-by-leaf. tree_map accepts several trees (the old
    # jax.tree_multimap alias no longer exists in JAX).
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

In [None]:
import matplotlib.pyplot as plt

# Toy regression problem: targets are y = x**2 at 256 random normal inputs.
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (256, 1))
ys = jnp.square(xs)

# Start from a freshly initialised 1 -> 128 -> 128 -> 1 MLP.
params = init_mlp(layers=[1, 128, 128, 1])

# Full-batch gradient descent: 1000 steps, each one using all 256 samples.
for _ in range(1000):
    params = update(params, xs, ys)

# Overlay the fitted function on the ground-truth targets.
plt.scatter(xs, ys, label='ground truth')
plt.scatter(xs, forward(params, xs), label='model prediction')
plt.legend()


In [None]:
from typing import Tuple, Iterable, Any

class Person:
    """Plain record of a person's name, age and handedness."""

    def __init__(self, name: str, age: int, left_handed: bool):
        self.name, self.age, self.left_handed = name, age, left_handed


# define flatten and un-flatten ops
def flatten_Person(person: Person) -> Tuple[Iterable[Any], str]:
    """Split a Person into ``(children, aux_data)`` for JAX.

    ``age`` and ``left_handed`` are the children (leaves JAX maps over);
    ``name`` travels as static auxiliary data and is never a leaf.
    """
    children = [person.age, person.left_handed]
    return children, person.name


def unflatten_Person(aux_data: str, flat_contents: Iterable[Any]) -> Person:
    """Inverse of ``flatten_Person``: rebuild a Person from name + leaves."""
    name = aux_data
    return Person(name, *flat_contents)


# Teach JAX's tree utilities how to take a Person apart and reassemble it.
jax.tree_util.register_pytree_node(Person, flatten_Person, unflatten_Person)

pytree_with_people = [
    'foo',
    Person('bob', 20, False),
    Person('joe', 30, True),
]

# 'foo' is a leaf; each registered Person contributes its age and left_handed
# leaves, while its name stays in the tree structure as auxiliary data.
# jax.tree_util.tree_leaves replaces the deprecated/removed jax.tree_leaves.
jax.tree_util.tree_leaves(pytree_with_people)