# JAX API structure

<img src="https://chefsmandala.com/wp-content/uploads/2018/03/Onion-Red.jpg" width="640">

- Layers, from high-level to low-level: NumPy API (`jax.numpy`) <-> `jax.lax` <-> XLA
- The `lax` API is stricter and more powerful than the NumPy API
- `lax` is a Python wrapper around XLA

In [7]:
import jax.numpy as jnp
import numpy as np
from jax import jit, grad, vmap
from jax import random
import matplotlib.pyplot as plt

# JAX's low-level API
# (the name "lax" is an anagram of "XLA"; it is less clear where the name "JAX" comes from)
from jax import lax

In [8]:
# Example 1: lax is stricter

print(jnp.add(1, 1.0))  # jax.numpy API implicitly promotes mixed types
print(lax.add(1, 1.0))  # jax.lax API requires explicit type promotion

2.0


TypeError: ignored

In [9]:
# Example 2: lax is more powerful (but as a tradeoff less user-friendly)

x = jnp.array([1, 2, 1])
y = jnp.ones(10)

# NumPy API
result1 = jnp.convolve(x, y)


# lax API
def lax_convolve_full(a, v):
    """Return the 'full' 1-D convolution of ``a`` and ``v`` computed with lax.

    ``lax.conv_general_dilated`` computes a cross-correlation: the kernel is
    not flipped. To reproduce ``np.convolve(a, v)``, the kernel is reversed
    first. Both operands are reshaped to rank-3 ``(1, 1, width)`` arrays,
    i.e. a batch of one with a single feature/channel.

    Args:
        a: 1-D input (array or sequence).
        v: 1-D kernel (array or sequence).

    Returns:
        Array of shape ``(1, 1, len(a) + len(v) - 1)``. Its ``[0, 0]`` slice
        holds the same values as ``np.convolve(a, v)``.
    """
    a = jnp.asarray(a)
    v = jnp.asarray(v)
    # lax does no implicit type promotion, so cast both operands explicitly
    # to their common (promoted) dtype.
    dtype = jnp.result_type(a, v)
    lhs = a.reshape(1, 1, len(a)).astype(dtype)
    rhs = jnp.flip(v).reshape(1, 1, len(v)).astype(dtype)
    pad = len(v) - 1  # equivalent of mode='full' in NumPy
    return lax.conv_general_dilated(
        lhs,
        rhs,
        window_strides=(1,),
        padding=[(pad, pad)],
    )


result2 = lax_convolve_full(x, y)

print(result1)
print(result2[0, 0])
assert np.allclose(result1, result2[0, 0], atol=1e-6)

# XLA: https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution

[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]
[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]
