<a href="https://colab.research.google.com/github/kbrezinski/JAX-Practice/blob/main/jax_low_level.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as jnp
import numpy as np

# special transformations
from jax import grad, jit, vmap, pmap
# JAX low level APIs
from jax import lax, make_jaxpr, random, device_put

In [None]:
# numpy <- lax <= XLA
jnp.add(1, 1.) # handles mixed types
lax.add(1, 1.0) # error, mixed types

In [9]:
# test out two implementations, 1 with jnp and 1 with lax
x = jnp.array([1, 2, 1])
y = jnp.ones(10)

res1 = jnp.convolve(x, y)
res2 = lax.conv_general_dilated(
    x.reshape(1,1,3).astype(float),
    y.reshape(1,1,10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)]
)

res2[0][0] ## use [0][0] to return batched result

DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

In [12]:
# in Jax random numbers are stateful, need to be passed in as args
k = random.PRNGKey(2021)

In [16]:
# small differences in time taken for normal and jitted versions
def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

norm_compiled = jit(norm)
X = random.normal(k, (10000, 100), dtype=jnp.float32)

In [18]:
# test an example of a jitted and non-jitted fn to see what jit allows
def get_negative(x):
  return x[x < 0]

x = random.normal(k, (10,), dtype=jnp.float32)

jit(get_negative)(x) # returns error, differnent potential sizes confuses jit

DeviceArray([-1.906434  , -0.04492167, -0.5956922 , -0.3326311 ,
             -2.6711135 ], dtype=float32)