In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.tree_util import register_pytree_node

### Using structured parameters

In [26]:
w = jnp.array([0.5, 0.4, 0.6, 0.8, 0.9], dtype=jnp.float32)
b = jnp.array([-1,-3,-6,-19,-2], dtype=jnp.float32)
x = jnp.array([1,2,3,4,5], dtype=jnp.float32)

model = {
    "w": w,
    "b": b
}

In [14]:
def _linear(x, model):
    return jnp.sum(jnp.dot(x, model["w"]) + model["b"])

linear = jit(vmap(_linear, in_axes=(0, {"w": None, "b": None}))) # in axis matches the structure definition

In [27]:
X = jnp.array([
                [1,2,3,4,5],
                [6,7,8,9,10],
                [11,12,13,14,15],
              ], dtype=jnp.float32)

In [28]:
linear(X, model)

DeviceArray([ 23.     , 103.00001, 183.     ], dtype=float32)

### Using structured parameters (batch based on internal structure)

In [32]:
linear_vectorize_on_w = jit(vmap(_linear, in_axes=(None, {"w": None, "b": 0})))

In [33]:
x = jnp.array([1,2,3,4,5], dtype=jnp.float32)

B = jnp.array([
                [-1,-3,-6,-19,-2],
                [0, 0, 0, 0, 0],
              ], dtype=jnp.float32)

In [34]:
linear_vectorize_on_w(x, {"w": w, "b": B})

DeviceArray([23., 54.], dtype=float32)

### Using custom classes (structures)

In [36]:
class LinearModel:
    
    def __init__(self, w, b):
        self.w = w
        self.b = b
        
    def __repr__(self):
        return "LinearModel(w={}, b={})".format(self.w, self.b)
        
        
def linear_flatten(model):
      children = (model.w, model.b)
      metadata = None
      return (children, metadata)


def linear_unflatten(metadata, children):
    (w, b) = children
    return LinearModel(w, b)

register_pytree_node(
    LinearModel,
    linear_flatten,    # tell JAX what are the children nodes
    linear_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)


In [37]:
model = LinearModel(
    jnp.array([0.5, 0.4, 0.6, 0.8, 0.9], dtype=jnp.float32),
    jnp.array([-1,-3,-6,-19,-2], dtype=jnp.float32)
)

In [38]:
def _linear(x, model):
    return jnp.sum(jnp.dot(x,model.w) + model.b)

linear = jit(vmap(_linear, in_axes=(None, LinearModel(None, 0))))

In [39]:
linear(x, LinearModel(w, B))

DeviceArray([23., 54.], dtype=float32)