# Some Drawbacks

## In-Place Updates

In [6]:
import jax.numpy as jnp
import numpy as np
from jax import jit, grad, vmap
from jax import random
from jax import make_jaxpr
import matplotlib.pyplot as plt

In [7]:
# JAX arrays are immutable: `.at[...]` returns an updated copy and leaves the
# source array untouched.
jax_array = jnp.zeros((3, 3), dtype=jnp.float32)
updated_array = jax_array.at[1].set(1.0)  # row 1 <- 1.0 (same as .at[1, :])

print("original array unchanged:\n", jax_array)
print("updated array:\n", updated_array)

original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


In [8]:
# The expressiveness of NumPy indexing is still there: `.at[...]` accepts the
# same slice syntax (strides, open-ended ranges) as ordinary indexing.

jax_array = jnp.ones((5, 6))
every_other_row, last_three_cols = slice(None, None, 2), slice(3, None)
new_jax_array = jax_array.at[every_other_row, last_three_cols].add(7.)

print("original array:")
print(jax_array)
print("new array post-addition:")
print(new_jax_array)

original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]


## Out-of-Bounds Indexing

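JAX is accelerator-agnostic, and raising an error from code running on an accelerator is difficult or impossible. For that reason, JAX does not raise an exception for out-of-bounds indexing. Instead, it falls back to a non-error behavior, much as invalid floating-point arithmetic produces NaNs rather than raising an exception.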

In [10]:
# NumPy behavior: indexing past the end of an array raises IndexError.
# Catch that specific exception (not a blanket `Exception`) so the cell
# demonstrates exactly the failure mode being contrasted with JAX below.

try:
    np.arange(10)[11]
except IndexError as e:
    print(f"Exception {e}")

Exception index 11 is out of bounds for axis 0 with size 10


In [11]:
# JAX behavior -- no exception is raised for out-of-bounds indices:
#   1) an *update* at an out-of-bounds index is silently dropped
#   2) a *retrieval* clamps the index into range (here 11 -> 9, the last element)
# Treat this default behavior as undefined and don't rely on it; `.at[...]`
# accepts a `mode` argument when explicit out-of-bounds semantics are needed.

dropped_update = jnp.arange(10).at[11].add(23)  # 1) array comes back unchanged
clamped_read = jnp.arange(10)[11]               # 2) reads the last element

print(dropped_update)
print(clamped_read)

[0 1 2 3 4 5 6 7 8 9]
9


## Non-Array Inputs

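- Unlike NumPy, JAX functions reject Python lists as inputs. This is by design, for performance reasons: as the jaxpr below shows, a traced list is unpacked element by element, so every element becomes a separate input.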

In [12]:
# NumPy: a plain Python list is accepted and converted to an array implicitly.

numpy_total = np.sum([1, 2, 3])
print(numpy_total)

6


In [13]:
# JAX: the same call raises TypeError -- a list has to be wrapped explicitly
# (e.g. jnp.array([1, 2, 3])) before it is passed to a jnp function.

list_input = [1, 2, 3]
try:
    jnp.sum(list_input)
except TypeError as err:
    message = f"TypeError: {err}"
    print(message)

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


In [15]:
# Why does JAX refuse lists? Inspect the jaxpr of a function that accepts one.
# When a list is passed to a traced function, every element becomes a separate
# scalar input, so a length-N list yields N inputs, N broadcast ops and a
# concatenate just to rebuild the array before the actual sum runs.

def permissive_sum(x):
    """Sum the elements of `x`, converting it to a JAX array first.

    Accepts anything `jnp.array` accepts (including Python lists), which is
    exactly the convenience `jnp.sum` deliberately does not provide.

    Args:
        x: array-like input, e.g. a list of numbers.

    Returns:
        A 0-d JAX array holding the sum of all elements.
    """
    return jnp.sum(jnp.array(x))

x = list(range(10))
print(make_jaxpr(permissive_sum)(x))

{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] a
    l:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] b
    m:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] c
    n:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] d
    o:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
    p:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
    q:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] g
    r:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] h
    s:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] i
    t:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] j
    u:i32[10] = concatenate[dimension=0] k l m n o p q r s t
    v:i32[10] = convert_element_type[new_dtype=int32 weak_type=False] u
    w:i32[] = reduce_sum[axes=(0,)] v
  in (w,) }
