In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import random

In [2]:
# Show the installed JAX version so the recorded outputs below can be tied to a specific release.
jax.__version__

'0.4.26'

# Similarities Between NumPy and JAX Arrays

In [3]:
# Build a NumPy array from Python ints and let NumPy pick the dtype
# (int64 in the recorded output).
array_np = np.array([1, 2, 3, 4, 5])
print(array_np.dtype)
array_np

int64


array([1, 2, 3, 4, 5])

In [4]:
# The same constructor call through jax.numpy; the recorded output shows
# int32 here, where NumPy chose int64 for the same input.
array_jax = jnp.array([1, 2, 3, 4, 5])
print(array_jax.dtype)
array_jax

int32


Array([1, 2, 3, 4, 5], dtype=int32)

In [5]:
# Request int32 explicitly so the NumPy array uses the same dtype as the JAX array.
array_np = np.array([1, 2, 3, 4, 5], dtype=np.int32)
print(array_np.dtype)
array_np

int32


array([1, 2, 3, 4, 5], dtype=int32)

In [6]:
# jnp.array accepts the same dtype keyword as np.array; int32 is spelled out
# here to mirror the explicit-dtype NumPy example.
array_jax = jnp.array([1, 2, 3, 4, 5], dtype=jnp.int32)
array_jax

Array([1, 2, 3, 4, 5], dtype=int32)

In [7]:
# Report the concrete Python class behind each array object.
type_report = (
    ('Type of Numpy Array :', type(array_np)),
    ('Type of JAX Array :', type(array_jax)),
)
for label, array_type in type_report:
    print(label, array_type)

Type of Numpy Array : <class 'numpy.ndarray'>
Type of JAX Array : <class 'jaxlib.xla_extension.ArrayImpl'>


In [8]:
# Side note:
# `ArrayImpl` (the class name printed above) is the concrete implementation
# behind JAX arrays, but it lives in jaxlib's private namespace, so code should
# not depend on its import path. For type checks, use the public `jax.Array`
# type instead; it needs nothing beyond the `jax` import at the top.
x = jnp.array([1, 2, 3, 4, 5])
print(isinstance(x, jax.Array))  # This will print True

True


In [9]:
# arange has the same call signature in both libraries; only the default
# integer dtype of the result differs (see the displayed tuple).
array_np, array_jax = np.arange(10), jnp.arange(10)

(array_np, array_np.dtype, array_jax)

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 dtype('int64'),
 Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

In [10]:
# Ten evenly spaced points from 1 to 10 in each library. The displayed tuple
# lets the reader compare the float precision of the two results.
array_np = np.linspace(1, 10, num=10)
array_jax = jnp.linspace(1, 10, num=10)

(array_np, array_np.dtype, array_jax)

(array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]),
 dtype('float64'),
 Array([ 1.       ,  2.       ,  3.       ,  4.       ,  5.       ,
         6.       ,  7.0000005,  8.       ,  9.       , 10.       ],      dtype=float32))

In [11]:
# Reductions are exposed as methods with the same names on both array types.
reductions = (
    ('Sum of numpy array elements :', array_np.sum()),
    ('Sum of jax array elements :', array_jax.sum()),
    ('Mean of numpy array elements :', array_np.mean()),
    ('Mean of jax array elements :', array_jax.mean()),
)
for label, value in reductions:
    print(label, value)

Sum of numpy array elements : 55.0
Sum of jax array elements : 55.0
Mean of numpy array elements : 5.5
Mean of jax array elements : 5.5


In [12]:
# A 2x3 NumPy matrix and its 3x2 transpose.
array_np = np.array([
    [1, 2, 3],
    [4, 5, 6],
])
array_np_transposed = np.transpose(array_np)

print('Numpy array:\n', array_np)
print('Transpose of numpy array:\n', array_np_transposed)

Numpy array:
 [[1 2 3]
 [4 5 6]]
Transpose of numpy array:
 [[1 4]
 [2 5]
 [3 6]]


In [13]:
# The same 2x3 matrix and transpose, this time through jax.numpy.
array_jax = jnp.array([
    [1, 2, 3],
    [4, 5, 6],
])
array_jax_transposed = jnp.transpose(array_jax)

print('JAX array:\n', array_jax)
print('Transpose of JAX array:\n', array_jax_transposed)

JAX array:
 [[1 2 3]
 [4 5 6]]
Transpose of JAX array:
 [[1 4]
 [2 5]
 [3 6]]


In [14]:
# reshape has the same signature in both libraries; -1 asks the library to
# infer that dimension from the total number of elements.
print('Original shape of numpy array:', array_np.shape)
print('Original shape of JAX array:', array_jax.shape)

array_np_reshaped = array_np.reshape(1, -1)
array_jax_reshaped = array_jax.reshape(1, -1)

# Report each reshaped result under its own label so the NumPy and JAX
# shapes can be compared directly.
print('Reshaped shape of numpy array:', array_np_reshaped.shape)
print('Reshaped shape of JAX array:', array_jax_reshaped.shape)

Original shape of numpy array: (2, 3)
Original shape of JAX array: (2, 3)
Reshaped shape of numpy array: (1, 6)
Reshaped numpy array: (1, 6)


# Differences Between NumPy and JAX Arrays

In [15]:
# np.sum() accepts a plain Python list and returns the sum of its elements.
np.sum([2, 3, 4, 6])

15

In [16]:
# jnp.sum() rejects a plain Python list, unlike np.sum(). The error is caught
# and printed so the notebook still runs top to bottom while showing the failure.
try:
    jnp.sum([2, 3, 4, 6])
except TypeError as err:
    print('TypeError:', err)

# To sum a list with JAX, convert it explicitly first:
#     jnp.sum(jnp.array([2, 3, 4, 6]))

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

### JAX arrays are immutable

In [17]:
# Matching int32 ranges in both libraries, used by the mutation examples that follow.
array_np, array_jax = (
    np.arange(10, dtype=np.int32),
    jnp.arange(10, dtype=jnp.int32),
)

(array_np, array_jax)

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32),
 Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

In [18]:
# NumPy arrays are mutable: in-place item assignment works.
array_np[4] = 22222

# JAX arrays are immutable: the same assignment raises TypeError. The error is
# caught and printed so the notebook keeps running and the NumPy result below
# is still displayed.
try:
    array_jax[4] = 22222
except TypeError as err:
    print('TypeError:', err)

array_np

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

## How to modify JAX arrays

In [21]:
# .at[4].set(...) returns a new array with index 4 replaced; array_jax itself
# is not modified.
array_jax_modified = array_jax.at[4].set(22222)
array_jax_modified

Array([    0,     1,     2,     3, 22222,     5,     6,     7,     8,
           9], dtype=int32)

In [24]:
# Each .at[...] update builds a fresh array; the original array_jax stays unchanged,
# which is why both results start from the same value at index 5.
added_at_5 = array_jax.at[5].add(10)
doubled_at_5 = array_jax.at[5].mul(2)

print(added_at_5)
print(doubled_at_5)

[ 0  1  2  3  4 15  6  7  8  9]
[ 0  1  2  3  4 10  6  7  8  9]


## Asynchronous dispatch and JAX array speed-up

In [25]:
# Two large square float32 matrices for the matmul timing comparison below.
MATRIX_SHAPE = (10_000, 10_000)

# NumPy: a seeded Generator makes the data reproducible across runs, and
# sampling float32 directly avoids building a float64 array of the same shape
# only to cast it afterwards.
array_np = np.random.default_rng(0).standard_normal(size=MATRIX_SHAPE, dtype=np.float32)

# JAX: random numbers are always derived from an explicit PRNG key.
array_jax = jax.random.normal(jax.random.PRNGKey(0), shape=MATRIX_SHAPE, dtype=jnp.float32)

print('Shape of numpy array: ', array_np.shape)
print('Shape of JAX array: ', array_jax.shape)

Shape of numpy array:  (10000, 10000)
Shape of JAX array:  (10000, 10000)


In [26]:
# Baseline: time one matrix product of the large float32 NumPy array with itself.
# Only this single expression is timed; the result is discarded.
%time np.matmul(array_np, array_np)

print('Completed Numpy Operation')

CPU times: user 27.9 s, sys: 612 ms, total: 28.5 s
Wall time: 16.3 s
Completed Numpy Operation


In [32]:
# JAX dispatches work asynchronously, so the result has to be materialised for
# the timing to cover the computation. Converting it to a NumPy array forces that wait.
# NOTE(review): the speed-up over NumPy presumably reflects an accelerator backend;
# confirm with jax.devices().
%time np.asarray(jnp.matmul(array_jax, array_jax))

print('Completed JAX Operation')
# Caveat: the measured time also includes the overhead of converting the JAX result into a NumPy array.

CPU times: user 74.2 ms, sys: 68.1 ms, total: 142 ms
Wall time: 678 ms
Completed JAX Operation


In [33]:
# Another way to time JAX operations (rather than converting into NumPy arrays):
# block_until_ready() waits for the asynchronously dispatched computation to finish,
# so %time measures the matmul without the NumPy conversion. As the last expression
# in the cell, the resulting array is also displayed.
%time jnp.matmul(array_jax, array_jax).block_until_ready()

CPU times: user 3.29 ms, sys: 2.95 ms, total: 6.24 ms
Wall time: 489 ms


Array([[ -90.649605 ,   13.972294 ,  -95.644356 , ...,   23.926107 ,
         133.53967  ,   53.144093 ],
       [  44.89995  ,  -33.371754 ,   94.96752  , ..., -100.38466  ,
         -56.9396   , -217.22105  ],
       [ -45.115227 , -185.57512  , -189.55267  , ..., -213.17244  ,
          11.189074 ,   18.810091 ],
       ...,
       [  57.29252  ,   89.313675 ,   96.97262  , ...,  -26.702463 ,
          32.321266 ,  159.42375  ],
       [  92.361015 ,   29.874603 ,  -63.730263 , ...,   41.16846  ,
         154.73875  ,   85.76176  ],
       [   7.0634017,   81.06211  ,  145.52782  , ...,   73.70803  ,
         -43.59744  ,  -10.067663 ]], dtype=float32)