<a href="https://colab.research.google.com/github/kbrezinski/JAX-Practice/blob/main/jax_low_level.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import jax.numpy as jnp
import numpy as np

# special transformations
from jax import grad, jit, vmap, pmap
# JAX low level APIs
from jax import lax, make_jaxpr, random, device_put

In [None]:
# numpy <- lax <= XLA
jnp.add(1, 1.) # handles mixed types
lax.add(1, 1.0) # error, mixed types

In [None]:
# test out two implementations, 1 with jnp and 1 with lax
x = jnp.array([1, 2, 1])
y = jnp.ones(10)

res1 = jnp.convolve(x, y)
res2 = lax.conv_general_dilated(
    x.reshape(1,1,3).astype(float),
    y.reshape(1,1,10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)]
)

res2[0][0] ## use [0][0] to return batched result

DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

In [None]:
# in Jax random numbers are stateful, need to be passed in as args
k = random.PRNGKey(2021)

In [None]:
# small differences in time taken for normal and jitted versions
def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

norm_compiled = jit(norm)
X = random.normal(k, (10000, 100), dtype=jnp.float32)

In [None]:
# test an example of a jitted and non-jitted fn to see what jit allows
def get_negative(x):
  return x[x < 0]

x = random.normal(k, (10,), dtype=jnp.float32)

jit(get_negative)(x) # returns error, different potential sizes confuses jit
# printing in a function sometimes is lost due to tracing, cachine mechanism causes side effects

DeviceArray([-1.906434  , -0.04492167, -0.5956922 , -0.3326311 ,
             -2.6711135 ], dtype=float32)

In [8]:
# another example of failure
@jit
def f(x, neg):
  return -x if neg else x

# f(1, True) # cannot depend on neg

from functools import partial
@partial(jit, static_argnums=(1,))
def f(x, neg):
  print(x)
  return -x if neg else x

f(1, True) # works because True is set to false
f(1, False) # works because Works, and caches again since static arg changes

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


DeviceArray(1, dtype=int32, weak_type=True)

In [10]:
# use numpy prod instead of jnz
@jit
def f(x):
  return x.reshape((np.prod(x.shape),))
  #return x.reshape(jnp.array(x.shape).prod()) # doesn't work

x = jnp.ones((2,3))
f(x)

DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)

In [18]:
# pure function can still be stateful within a function
def pure_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 ==0 else 'odd'] += x

  return state

jit(pure_internal_state)(5.)

{'even': DeviceArray(25., dtype=float32, weak_type=True),
 'odd': DeviceArray(25., dtype=float32, weak_type=True)}

In [24]:
# no use of iterators as they are stateful
arr = jnp.arange(10)
lax.fori_loop(0, 10, lambda i,x: x + arr[i], 0) # smart version of for loop in XLA

iterator = iter(range(10))
lax.fori_loop(0, 10, lambda i,x: x + next(iterator),0) # doesn't work since iterators have internal state

DeviceArray(0, dtype=int32, weak_type=True)

In [25]:
jnp.arange(10).at[11].add(23) # doesnt return exception where index 11 does not exist
jnp.arange(10)[11] # index 11 does not exist, still reuturns 9th element

DeviceArray(9, dtype=int32)

In [30]:
## demonstrate how jax works for lists
print(np.sum([1, 2, 3]))

# doesnt work since jax needs numpy array
try:
  jnp.sum([1,2,3])
except TypeError as e:
  print(f"{e}")

# jax implementation
def permissive_sum(x):
  return jnp.sum(jnp.array(x))

# inefficient since each element of the list is passed element by element
x = list(range(10))
jit(permissive_sum)(x)

6
sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


DeviceArray(45, dtype=int32)