In [27]:
import jax.numpy as jnp
import jax.nn as nn
import jax
import itertools

# NOTE(review): this cell's saved execution count is higher than those of the
# cells below it; Restart & Run All to confirm the recorded outputs.

# Network sizes: 2-D inputs -> one hidden layer -> single output.
input_dim = 2
hidden_dim = 32
output_dim = 1

# Regular grid over [-pi, pi] x [-pi, pi] with `num_samples` points per axis.
num_samples = 10
samples = jnp.linspace(-jnp.pi, jnp.pi, num=num_samples)

# Cartesian product of `samples` with itself, built with array ops instead of a
# Python-level itertools.product over individual device scalars.
# Row k = i * num_samples + j holds (samples[i], samples[j]) -- the same order
# itertools.product(samples, samples) yields.
grid_0, grid_1 = jnp.meshgrid(samples, samples, indexing="ij")
x_values = jnp.stack([grid_0.ravel(), grid_1.ravel()], axis=-1)
assert x_values.shape == (num_samples * num_samples, input_dim)

initializer = nn.initializers.glorot_normal()

# A separate fixed PRNG key per tensor keeps every draw reproducible.
params = {
    "weights_0": initializer(jax.random.key(1), (input_dim, hidden_dim), jnp.float32),
    "weights_1": initializer(jax.random.key(3), (hidden_dim, output_dim), jnp.float32),
    # NOTE(review): Glorot init is designed for weight matrices; biases are more
    # commonly zero-initialized (nn.initializers.zeros). Kept as-is so the
    # parameter values match the recorded run.
    "bias_0": initializer(jax.random.key(2), (1, hidden_dim), jnp.float32),
    "bias_1": initializer(jax.random.key(4), (1, output_dim), jnp.float32),
}
x_values.shape

(100, 2)

In [25]:
def forward(params, x):
    """Two-layer MLP forward pass: sigmoid hidden layer, affine output layer.

    Args:
        params: mapping with keys "weights_0", "bias_0", "weights_1", "bias_1".
            Presumably shaped (input_dim, hidden_dim), (1, hidden_dim),
            (hidden_dim, output_dim), (1, output_dim) -- TODO confirm for other callers.
        x: input batch; assumes shape (batch, input_dim) -- TODO confirm.

    Returns:
        sigmoid(x @ weights_0 + bias_0) @ weights_1 + bias_1
        (no activation is applied to the output).
    """
    w_hidden, b_hidden = params["weights_0"], params["bias_0"]
    w_out, b_out = params["weights_1"], params["bias_1"]

    # Hidden layer: affine map followed by a logistic sigmoid.
    activations = nn.sigmoid(x @ w_hidden + b_hidden)
    # Output layer: affine map only.
    return activations @ w_out + b_out

%timeit forward(params, x_values).block_until_ready()

40.4 μs ± 298 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [26]:
jit_forward = jax.jit(forward)

jit_forward(params, x_values).block_until_ready()

%timeit jit_forward(params, x_values).block_until_ready()

21.5 μs ± 1.91 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
