In [1]:
from sklearn.datasets import make_regression
import equinox as eqx
from jaxtyping import Float, Array
import jax

from jax.extend.backend import get_backend
from rich.pretty import pprint

In [2]:
a = get_backend().platform

In [3]:
pprint(a)

In [4]:
data = make_regression(random_state=10, n_samples=int(1e5), n_features=int(1e3))
x, y = data

In [5]:
x.shape, y.shape

((100000, 1000), (100000,))

In [6]:
class Linear(eqx.Module):
    stack: list

    def __init__(self, rng):
        key1, key2, key3 = jax.random.split(rng, 3)
        self.stack = [
            eqx.nn.Linear(in_features=1000, out_features=64, key=key1),
            eqx.nn.Linear(in_features=64, out_features=64, key=key2),
            eqx.nn.Linear(in_features=64, out_features=1, key=key3),
        ]

    def __call__(self, x: Float[Array, "1 100"]) -> Float[Array, "1"]:  # noqa: F722
        for layer in self.stack:
            x = jax.nn.relu(layer(x))
        return x

In [7]:
rng = jax.random.PRNGKey(10)
model = Linear(rng)
compiled_model = jax.jit(model)

In [8]:
print(model)

Linear(
  stack=[
    Linear(
      weight=f32[64,1000],
      bias=f32[64],
      in_features=1000,
      out_features=64,
      use_bias=True
    ),
    Linear(
      weight=f32[64,64],
      bias=f32[64],
      in_features=64,
      out_features=64,
      use_bias=True
    ),
    Linear(
      weight=f32[1,64],
      bias=f32[1],
      in_features=64,
      out_features=1,
      use_bias=True
    )
  ]
)


In [9]:
def loss(model, x, y):
    pred = jax.vmap(model)(x)
    return pred - jax.numpy.expand_dims(y, 1) ** 2


compiled_loss = jax.jit(loss)

In [10]:
# import timeit

# start = timeit.default_timer()

# loss(model, x, y).mean()

# end = timeit.default_timer()

# end - start

In [11]:
# import timeit

# start = timeit.default_timer()
# grads = jax.grad(loss, argnums=[0])(model, x, y)

# end = timeit.default_timer()

# end - start

In [12]:
compiled_loss(model, x, y).block_until_ready()

%timeit compiled_loss(model, x, y).block_until_ready()

149 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
%timeit loss(model, x, y).block_until_ready()

154 ms ± 6.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

198 μs ± 876 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [15]:
selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

62.9 μs ± 3.27 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [16]:
def func(x, a, b):
    return x - jnp.log(x + (1 / a)) - b**2 + jnp.log(b - a)

In [17]:
(
    func(10.0, -2.0, 8.0),
    jax.grad(func)(10.0, -2.0, 8.0),
    jax.grad(jax.grad(func))(10.0, -2.0, 8.0),
)

(Array(-53.948708, dtype=float32, weak_type=True),
 Array(0.8947368, dtype=float32, weak_type=True),
 Array(0.01108033, dtype=float32, weak_type=True))

In [18]:
(
    func(10.0, -2.0, 8.0),
    jax.grad(func, argnums=1)(10.0, -2.0, 8.0),
    jax.grad(jax.grad(func, argnums=1), argnums=0)(10.0, -2.0, 8.0),
)

(Array(-53.948708, dtype=float32, weak_type=True),
 Array(-0.07368422, dtype=float32, weak_type=True),
 Array(-0.00277008, dtype=float32, weak_type=True))

In [19]:
value, grad = jax.value_and_grad(jax.grad(func))(10.0, -2.0, 8.0)
value, grad

(Array(0.8947368, dtype=float32, weak_type=True),
 Array(0.01108033, dtype=float32, weak_type=True))

In [20]:
def func(params):
    a = params["a"]
    b = params["b"]
    x = params["x"]
    return x - jnp.log(x + (1 / a)) - b**2 + jnp.log(b - a)

In [21]:
params = {"x": 10.0, "a": -2.0, "b": 8.0}
func(params), jax.grad(func)(params)

(Array(-53.948708, dtype=float32, weak_type=True),
 {'a': Array(-0.07368422, dtype=float32, weak_type=True),
  'b': Array(-15.9, dtype=float32, weak_type=True),
  'x': Array(0.8947368, dtype=float32, weak_type=True)})