## Generating random initializations for a neural network

In [None]:
LAYER_SIZES = [200*200*3, 2048, 1024, 2]
PARAM_SCALE = 0.01

In [None]:
import jax
import jax.numpy as jnp
from jax import random

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
  """Draw scaled Gaussian weights and biases for a single dense layer.

  Args:
    m: size of the layer input (second dimension of the weight matrix).
    n: size of the layer output (first dimension of the weight matrix).
    key: PRNG key; split into one subkey for the weights and one for the biases.
    scale: factor the standard-normal samples are multiplied by.

  Returns:
    A `(weights, biases)` tuple with shapes `(n, m)` and `(n,)`.
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (n, m))
  biases = scale * random.normal(bias_key, (n,))
  return weights, biases

def init_network_params(sizes, key=random.PRNGKey(0), scale=0.01):
  """Create parameters for a fully-connected network, one dense layer per consecutive pair of sizes.

  Args:
    sizes: layer widths, input width first and output width last.
    key: PRNG key, split into one subkey per layer.
    scale: multiplier forwarded to `random_layer_params`.

  Returns:
    A list holding the `random_layer_params` result of every layer, in order.
  """
  num_layers = len(sizes) - 1
  layer_keys = random.split(key, num_layers)
  params = []
  for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
    params.append(random_layer_params(fan_in, fan_out, layer_key, scale))
  return params



In [None]:
key = random.PRNGKey(42)

In [None]:
params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)

In [None]:
# Show the index and the weight/bias shapes of every layer.
for i, (w, b) in enumerate(params):
  print(i, w.shape, b.shape)

0 (2048, 120000) (2048,)
1 (1024, 2048) (1024,)
2 (2, 1024) (2,)


In [None]:
shapes = jax.tree_util.tree_map(lambda p: p.shape, params)

In [None]:
for i,shape in enumerate(shapes):
  print(i, shape)

0 ((2048, 120000), (2048,))
1 ((1024, 2048), (1024,))
2 ((2, 1024), (2,))


In [None]:
jax.tree_util.tree_leaves(params)

[Array([[-0.01233545,  0.00821559, -0.00978333, ..., -0.00762261,
          0.01153923, -0.00699568],
        [ 0.00595356, -0.00548696,  0.00862382, ..., -0.00660049,
         -0.01387328,  0.00377337],
        [-0.00239754, -0.01310708,  0.01655764, ...,  0.00812733,
         -0.0122619 ,  0.0073874 ],
        ...,
        [ 0.01216634, -0.01617356, -0.0034067 , ...,  0.00477375,
         -0.00057253,  0.00784415],
        [-0.01213108, -0.00440847, -0.02979285, ...,  0.00520762,
         -0.0157708 ,  0.00563094],
        [-0.00177029, -0.00257568,  0.01720736, ...,  0.00065184,
         -0.00535367, -0.00308625]], dtype=float32),
 Array([ 0.00885022,  0.0077482 , -0.00685802, ..., -0.0108596 ,
        -0.02526451,  0.00504387], dtype=float32),
 Array([[ 0.00833544, -0.00892102,  0.01756026, ...,  0.00956038,
          0.02225046, -0.00698299],
        [-0.00055303, -0.00299709,  0.0233378 , ..., -0.00282513,
          0.00495245, -0.00278503],
        [-0.0128897 , -0.01443312,  0.

## Leaves and nodes

In [None]:
import numpy as np
import jax.numpy as jnp
import collections 

In [None]:
Point = collections.namedtuple('Point', ['x', 'y'])

In [None]:
example_pytree = [
    {
        'a': [1, 2, 3], 
        'b': jnp.array([1, 2, 3]),
        'c': np.array([1, 2, 3])
    },
    [42, [44, 46], None],
    31337,
    (50, (60, 70)),
    Point(640, 480),
    collections.OrderedDict([('a', 100), ('b', 200)]),
    'some string'
]

In [None]:
jax.tree_util.tree_leaves(example_pytree)

[1,
 2,
 3,
 Array([1, 2, 3], dtype=int32),
 array([1, 2, 3]),
 42,
 44,
 46,
 31337,
 50,
 60,
 70,
 640,
 480,
 100,
 200,
 'some string']

## Back to the MLP example from Chapter 2

In [None]:
import tensorflow as tf
# Uncomment the next line to hide GPUs from TensorFlow so that it does not grab all GPU memory.
#tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# as_supervised=True gives us the (image, label) as a tuple instead of a dict
data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True, 
                       with_info=True)

data_train = data['train']
data_test  = data['test']

Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /tmp/tfds/mnist/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.


In [None]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS 
NUM_LABELS = info.features['label'].num_classes

In [None]:
def preprocess(img, label):
  """Cast the image to float32 and divide it by 255; the label is returned unchanged."""
  scaled_img = tf.cast(img, tf.float32) / 255.0
  return scaled_img, label

train_data = tfds.as_numpy(data_train.map(preprocess).batch(32).prefetch(1))
test_data  = tfds.as_numpy(data_test.map(preprocess).batch(32).prefetch(1))

In [None]:
# Input and output widths are taken from the dataset constants (NUM_PIXELS,
# NUM_LABELS) so the network stays in sync with the data; only the hidden
# width (512) is chosen here.
LAYER_SIZES = [NUM_PIXELS, 512, NUM_LABELS]
PARAM_SCALE = 0.01

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [None]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes.

  Args:
    sizes: layer widths, input width first and output width last.
    key: PRNG key, split into one subkey per layer.
    scale: factor the standard-normal samples are multiplied by.

  Returns:
    A list of `(weights, biases)` tuples, one per layer, with shapes
    `(sizes[i+1], sizes[i])` and `(sizes[i+1],)`.
  """

  def random_layer_params(m, n, layer_key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(layer_key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  # len(sizes) widths describe len(sizes)-1 dense layers, so that is how many keys are needed.
  keys = random.split(key, len(sizes) - 1)
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

In [None]:
init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [None]:
def predict(params, image):
  """Compute the logits of the network for one example.

  Args:
    params: sequence of `(weights, biases)` pairs, one per layer.
    image: input vector for a single example.

  Returns:
    The output of the last layer; no activation is applied to it.
  """
  hidden_layers, output_layer = params[:-1], params[-1]

  x = image
  for weights, biases in hidden_layers:
    # Every layer except the last is followed by a swish activation.
    x = swish(jnp.dot(weights, x) + biases)

  out_w, out_b = output_layer
  return jnp.dot(out_w, x) + out_b

In [None]:
batched_predict = vmap(predict, in_axes=(None, 0))

In [None]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

In [None]:
def loss(params, images, targets):
  """Categorical cross entropy loss function.

  Args:
    params: network parameters, passed through to `batched_predict`.
    images: batch of inputs, passed through to `batched_predict`.
    targets: array with the same shape as the logits (one-hot labels in this notebook).

  Returns:
    The negated mean of `targets * log_softmax(logits)`, taken over every entry
    of the array.
  """
  logits = batched_predict(params, images)
  # Normalize each example over its own class axis. Without `axis`, logsumexp
  # reduces over the whole batch, so the log-probabilities of one example
  # would depend on every other example in the batch.
  log_preds = logits - logsumexp(logits, axis=-1, keepdims=True)
  return -jnp.mean(targets*log_preds)

@jax.jit
def update(params, x, y, epoch_number):
  """Perform one gradient-descent step with an exponentially decayed learning rate.

  Args:
    params: list of `(weights, biases)` pairs.
    x: batch of inputs passed to `loss`.
    y: batch of targets passed to `loss`.
    epoch_number: value used to decay the learning rate.

  Returns:
    `(new_params, loss_value)`: the updated list of `(weights, biases)` pairs
    and the loss computed on the parameters before the step.
  """
  # Diagnostic output: shapes of the parameter tree and of the gradient tree.
  param_shapes = jax.tree_util.tree_map(lambda p: p.shape, params)
  print(f"Params shapes: {param_shapes}")

  loss_and_grad = value_and_grad(loss)
  loss_value, grads = loss_and_grad(params, x, y)

  grad_shapes = jax.tree_util.tree_map(lambda p: p.shape, grads)
  print(f"Grads shapes: {grad_shapes}")

  decay_exponent = epoch_number / DECAY_STEPS
  lr = INIT_LR * DECAY_RATE ** decay_exponent

  new_params = []
  for (w, b), (dw, db) in zip(params, grads):
    new_params.append((w - lr * dw, b - lr * db))
  return new_params, loss_value

In [None]:
x, y = next(iter(train_data))

In [None]:
x = jnp.reshape(x, (len(x), NUM_PIXELS))
y = one_hot(y, NUM_LABELS)

In [None]:
params, loss_value = update(init_params, x, y, 0)

Params shapes: [((512, 784), (512,)), ((10, 512), (10,))]
Grads shapes: [((512, 784), (512,)), ((10, 512), (10,))]


## Working with pytrees

Flatten/unflatten

In [None]:
params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)

In [None]:
scaled_params = jax.tree_util.tree_map(lambda p: 10*p, params)

In [None]:
some_pytree = [
    [1,1,1],
    [
        [10,10,10], [20, 20]
    ]
]

In [None]:
jax.tree_util.tree_map(lambda p: p+1, some_pytree)

[[2, 2, 2], [[11, 11, 11], [21, 21]]]

In [None]:
leaves, struct = jax.tree_util.tree_flatten(some_pytree)

In [None]:
leaves

[1, 1, 1, 10, 10, 10, 20, 20]

In [None]:
struct

PyTreeDef([[*, *, *], [[*, *, *], [*, *]]])

In [None]:
# Materialize the result as a list: a bare `map` object is a one-shot iterator,
# so unflattening from it a second time would find it already exhausted.
updated_leaves = [leaf + 1 for leaf in leaves]

In [None]:
jax.tree_util.tree_unflatten(struct, updated_leaves)

[[2, 2, 2], [[11, 11, 11], [21, 21]]]

Flatten/unflatten using a 1D array

In [None]:
from jax.flatten_util import ravel_pytree

In [None]:
leaves, unflatten_func = ravel_pytree(some_pytree)

In [None]:
leaves

Array([ 1,  1,  1, 10, 10, 10, 20, 20], dtype=int32)

In [None]:
unflatten_func

<function jax._src.flatten_util.ravel_pytree.<locals>.<lambda>(flat)>

In [None]:
unflatten_func(leaves)

[[Array(1, dtype=int32), Array(1, dtype=int32), Array(1, dtype=int32)],
 [[Array(10, dtype=int32), Array(10, dtype=int32), Array(10, dtype=int32)],
  [Array(20, dtype=int32), Array(20, dtype=int32)]]]

Reducing a tree

In [None]:
jax.tree_util.tree_reduce(lambda acc,value: acc+value, some_pytree, initializer=0)

73

Transposing a pytree

In [None]:
import math
from collections import namedtuple
# Use the `namedtuple` imported just above, so this cell does not depend on the
# `collections` import from an earlier section.
Point = namedtuple('Point', ['x', 'y'])

In [None]:
points = [
    Point(0.0, 0.0),
    Point(3.0, 0.0),
    Point(0.0, 4.0)
]

In [None]:
def rotate_point(p, theta):
  """Rotate the point `p` about the origin by the angle `theta` (in radians).

  Args:
    p: object with `x` and `y` attributes.
    theta: rotation angle passed to `math.cos` / `math.sin`.

  Returns:
    A new `Point` with the rotated coordinates.
  """
  cos_t, sin_t = math.cos(theta), math.sin(theta)
  return Point(p.x * cos_t - p.y * sin_t,
               p.x * sin_t + p.y * cos_t)

In [None]:
rotate_point(points[1], math.pi)

Point(x=-3.0, y=3.6739403974420594e-16)

In [None]:
jax.vmap(rotate_point, in_axes=(0, None))(points, math.pi)

ValueError: ignored

In [None]:
jax.vmap(rotate_point, in_axes=(0, None))(jnp.array(points), math.pi)

AttributeError: ignored

In [None]:
points

[Point(x=0.0, y=0.0), Point(x=3.0, y=0.0), Point(x=0.0, y=4.0)]

In [None]:
jax.tree_util.tree_structure(points)

PyTreeDef([CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *])])

In [None]:
jax.tree_util.tree_structure(points[0])

PyTreeDef(CustomNode(namedtuple[Point], [*, *]))

In [None]:
jax.tree_util.tree_transpose(
  outer_treedef = jax.tree_util.tree_structure(points),
  inner_treedef = jax.tree_util.tree_structure(points[0]),
  pytree_to_transpose=points
)

TypeError: ignored

In [None]:
jax.tree_util.tree_structure([0 for p in points])

PyTreeDef([*, *, *])

In [None]:
points_t = jax.tree_util.tree_transpose(
  outer_treedef = jax.tree_util.tree_structure([0 for p in points]),
  inner_treedef = jax.tree_util.tree_structure(points[0]),
  pytree_to_transpose=points
)

In [None]:
points_t

Point(x=[0.0, 3.0, 0.0], y=[0.0, 0.0, 4.0])

In [None]:
jax.vmap(rotate_point, in_axes=(0, None))(points_t, math.pi)

ValueError: ignored

In [None]:
jax.tree_util.tree_leaves(points)

[0.0, 0.0, 3.0, 0.0, 0.0, 4.0]

In [None]:
jax.tree_util.tree_leaves(points_t)

[0.0, 3.0, 0.0, 0.0, 0.0, 4.0]

In [None]:
points_t_a = jax.tree_util.tree_map(lambda p: Point(jnp.array(p.x),jnp.array(p.y)) , points_t)

AttributeError: ignored

In [None]:
points_t_array = Point(jnp.array(points_t.x),jnp.array(points_t.y))

In [None]:
points_t_array

Point(x=Array([0., 3., 0.], dtype=float32), y=Array([0., 0., 4.], dtype=float32))

In [None]:
jax.vmap(rotate_point, in_axes=(0, None))(points_t_array, math.pi)

Point(x=Array([-0.0000000e+00, -3.0000000e+00, -4.8985874e-16], dtype=float32), y=Array([ 0.0000000e+00,  3.6739406e-16, -4.0000000e+00], dtype=float32))

## Custom nodes

In [None]:
class Layer:
  """Container for a layer's name, weights `w` and biases `b`."""

  def __init__(self, name, w, b):
    self.w = w
    self.b = b
    # Store the caller-supplied name rather than a fixed placeholder string.
    self.name = name


In [None]:
h1 = Layer('hidden1', jnp.zeros((100,20)), jnp.zeros((20,)))



In [None]:
pt = [
    jnp.ones(50),
    h1
]

In [None]:
jax.tree_util.tree_leaves(pt)

[Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32),
 <__main__.Layer at 0x7fc87db04490>]

In [None]:
# Expected to fail: `Layer` is not registered as a pytree node yet, so the Layer
# object itself is treated as a leaf and `x*10` is applied to it directly.
# `jax.tree_util.tree_map` matches the rest of this notebook; the top-level
# `jax.tree_map` alias is deprecated in newer JAX releases.
jax.tree_util.tree_map(lambda x: x*10, pt)

TypeError: ignored

In [None]:
def flatten_layer(container):
  """Split a Layer into its pytree children and its auxiliary data.

  Returns:
    `([w, b], name)`: the two arrays as children, and the name as auxiliary
    data that is carried alongside the children.
  """
  children = [container.w, container.b]
  return children, container.name

def unflatten_layer(aux_data, flat_contents):
  """Rebuild a Layer from its auxiliary data (the name) and its flattened children."""
  name = aux_data
  children = tuple(flat_contents)
  return Layer(name, *children)

In [None]:
jax.tree_util.register_pytree_node(
    Layer, flatten_layer, unflatten_layer)

In [None]:
jax.tree_util.tree_leaves(pt)

[Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32),
 Array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.], dtype=float32)]

In [None]:
# `jax.tree_util.tree_map` matches the rest of this notebook; the top-level
# `jax.tree_map` alias is deprecated in newer JAX releases.
pt2 = jax.tree_util.tree_map(lambda x: x+1, pt)

In [None]:
jax.tree_util.tree_leaves(pt2)

[Array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],      dtype=float32),
 Array([[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]], dtype=float32),
 Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1.], dtype=float32)]