In [2]:
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


In [3]:
import jax
jax.__version__

'0.4.31'

In [4]:
# Print the platform name of the default JAX backend (e.g. "cpu", "gpu" or "tpu").
# jax.default_backend() is the public API for this; jax.lib.xla_bridge is an
# internal module whose get_backend() accessor is deprecated.
print(jax.default_backend())

RuntimeError: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 352554. Not attempting to load libtpu.so in this process. (set JAX_PLATFORMS='' to automatically choose an available backend)

## Array properties (formerly DeviceArray)

In [None]:
import numpy as np
import jax.numpy as jnp

In [None]:
np.array([1, 42, 31337])

array([    1,    42, 31337])

In [None]:
jnp.array([1, 42, 31337])

Array([    1,    42, 31337], dtype=int32)

In [None]:
np.sum([1, 42, 31337])

31380

In [None]:
# Unlike np.sum in the previous cell, jnp.sum does not accept a plain Python list.
# The TypeError is caught and printed so the cell demonstrates the message without aborting.
try:
  jnp.sum([1, 42, 31337])
except TypeError as e:
  print(e)

sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


In [None]:
jnp.sum(jnp.array([1, 42, 31337]))

Array(31380, dtype=int32)

In [None]:
arr = jnp.array([1, 42, 31337])

In [None]:
arr.ndim

1

In [None]:
arr.shape

(3,)

In [None]:
arr.dtype

dtype('int32')

In [None]:
arr.size

3

In [None]:
arr.nbytes

12

## Devices

In [None]:
import jax

In [None]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [None]:
jax.devices('cpu')

[CpuDevice(id=0)]

In [None]:
jax.device_count('tpu')

4

In [None]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [None]:
arr = jnp.array([1, 42, 31337])

In [None]:
arr.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}

In [None]:
arr_cpu = jax.device_put(arr, jax.devices('cpu')[0])

In [None]:
arr_cpu.devices()

{CpuDevice(id=0)}

In [None]:
arr.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}

In [None]:
arr_host = jax.device_get(arr)

In [None]:
type(arr_host)

numpy.ndarray

In [None]:
arr_host

array([    1,    42, 31337], dtype=int32)

In [None]:
# `arr` was created without an explicit device, `arr_cpu` was placed on the CPU with
# jax.device_put; in the recorded output this mixed addition succeeds.
arr + arr_cpu

Array([    2,    84, 62674], dtype=int32)

In [None]:
arr_tpu = jax.device_put(arr, jax.devices('tpu')[0])

In [None]:
arr_tpu.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}

In [None]:
# Both operands were explicitly placed with jax.device_put, one on TPU and one on CPU.
# Adding them raises a ValueError about incompatible devices; catch and print it.
try:
  arr_tpu + arr_cpu
except ValueError as e:
  print(e)

Received incompatible devices for jitted computation. Got argument x of add with shape int32[3] and device ids [0] on platform TPU and argument y of add with shape int32[3] and device ids [0] on platform CPU


## Asynchronous dispatch

In [None]:
import jax

In [None]:
# 1000x1000 matrix holding the integers 0 .. 999_999 in row-major order.
a = jnp.arange(1_000_000).reshape(1000, 1000)

In [None]:
a.shape

(1000, 1000)

In [None]:
a.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}

In [None]:
# NOTE(review): this is the first jnp.dot call on `a`, so the measured time presumably
# includes one-off compilation; keep that in mind when comparing with the timings below.
%time x = jnp.dot(a,a)

CPU times: user 671 ms, sys: 0 ns, total: 671 ms
Wall time: 661 ms


In [None]:
%time x = jnp.dot(a,a).block_until_ready()

CPU times: user 810 μs, sys: 0 ns, total: 810 μs
Wall time: 455 μs


In [None]:
%time x = np.asarray(jnp.dot(a,a))

CPU times: user 3.94 ms, sys: 0 ns, total: 3.94 ms
Wall time: 3.23 ms


In [None]:
a_cpu = jax.device_put(a, jax.devices('cpu')[0])

In [None]:
a_cpu.devices()

{CpuDevice(id=0)}

In [None]:
%time x = jnp.dot(a_cpu,a_cpu).block_until_ready()

CPU times: user 447 ms, sys: 18.9 ms, total: 466 ms
Wall time: 18.5 ms


## Immutability

In [None]:
import numpy as np
import jax.numpy as jnp

In [None]:
# The same ten integers (0..9), once as a NumPy array and once as a JAX array.
a_np = np.arange(10)
a_jnp = jnp.arange(10)

In [None]:
a_jnp

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [None]:
a_np[5], a_jnp[5]

(5, Array(5, dtype=int32))

In [None]:
a_np[5] = 100

In [None]:
a_np[5]

100

In [None]:
try:
  a_jnp[5] = 100
except TypeError as e:
  print(e)

'<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html


In [None]:
a_jnp = a_jnp.at[5].set(100)

In [None]:
a_jnp[5]

Array(100, dtype=int32)

In [None]:
a_jnp[42]

Array(9, dtype=int32)

In [None]:
a_jnp.at[42].get()

Array(9, dtype=int32)

In [None]:
a_jnp.at[42].get(mode='clip')

Array(9, dtype=int32)

In [None]:
a_jnp.at[42].get(mode='drop')

Array(-2147483648, dtype=int32)

In [None]:
a_jnp.at[42].get(mode='fill', fill_value=-1)

Array(-1, dtype=int32)

In [None]:
# Index 42 is out of bounds for this 10-element array. With the default mode the
# update is silently dropped: the array displayed below is unchanged.
a_jnp = a_jnp.at[42].set(100)
a_jnp

Array([  0,   1,   2,   3,   4, 100,   6,   7,   8,   9], dtype=int32)

In [None]:
# With mode='clip' the out-of-bounds index is clamped to the last valid position,
# so the final element (index 9) is the one that gets set to 100.
a_jnp = a_jnp.at[42].set(100, mode='clip')
a_jnp

Array([  0,   1,   2,   3,   4, 100,   6,   7,   8, 100], dtype=int32)

## Working with float64

In [None]:
# this only works on startup!
import jax
jax.config.update("jax_enable_x64", True)

In [None]:
import jax
jax.config.update("jax_enable_x64", True)

In [None]:
import jax
import jax.numpy as jnp

In [None]:
# Requires jax_enable_x64 to have been switched on (see the config cell above).
# NOTE(review): the original note said this does not work on the TPU backend, but in the
# recorded run the array was created as float64 on a TPU device. TPU float64 support may
# still be limited, so fall back to CPU or GPU if this fails.
x = jnp.array(range(10), dtype=jnp.float64)
x.dtype

dtype('float64')

In [None]:
x.devices()

{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)}

In [None]:
xc = jax.device_put(x, jax.devices('cpu')[0])

In [None]:
xc.devices()

{CpuDevice(id=0)}

In [None]:
xc.dtype

dtype('float64')

In [None]:
xb16 = jnp.array(range(10), dtype=jnp.bfloat16)
xb16.dtype

dtype(bfloat16)

In [None]:
xb16.nbytes

20

In [None]:
x16 = jnp.array(range(10), dtype=jnp.float16)
x16.dtype

dtype('float16')

In [None]:
x16.nbytes

20

In [None]:
xb16+x16

Array([ 0.,  2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], dtype=float32)

In [None]:
xb16+xb16

Array([0, 2, 4, 6, 8, 10, 12, 14, 16, 18], dtype=bfloat16)

## jax.numpy & jax.lax

In [None]:
jax.config.update("jax_enable_x64", False)

In [None]:
import jax.numpy as jnp
from jax import lax
from jax import random

In [None]:
jnp.add(42, 42.0)

Array(84., dtype=float32, weak_type=True)

In [None]:
jnp.add(42.0, 42.0)

Array(84., dtype=float32, weak_type=True)

In [None]:
# lax.add does not promote mixed operand types the way jnp.add does. Depending on the
# JAX version the mismatch surfaces either as a TypeError or as a ValueError raised while
# lowering (as in the recorded run), so both are caught to show the message without
# aborting the notebook.
try:
  lax.add(42, 42.0)
except (TypeError, ValueError) as e:
  print(e)

ValueError: Cannot lower jaxpr with verifier errors:
	op requires the same element type for all operands and results
		at loc("jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_352554/1506868229.py":2:0) at callsite("run_code"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/IPython/core/async_helpers.py":128:0) at callsite("_run_cell"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/jetjiang/JAX-in-Action/.venv/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))))
Define JAX_DUMP_IR_TO to dump the module.

In [1]:
# Same addition with the integer operand explicitly converted to float32.
# Needs `lax` and `jnp` from this section's import cell: the NameError recorded below
# comes from running this cell first in a fresh kernel (In [1]) before those imports.
lax.add(jnp.float32(42), 42.0)

NameError: name 'lax' is not defined

In [None]:
def random_augmentation(image, augmentations, rng_key):
    """Apply one randomly chosen augmentation to ``image``.

    Args:
        image: Input passed unchanged to the selected augmentation.
        augmentations: Sequence of callables, each taking the image as its
            single argument.
        rng_key: JAX PRNG key used to draw the index of the augmentation.

    Returns:
        The result of calling the selected augmentation on ``image``.
    """
    num_choices = len(augmentations)
    # Scalar (shape ()) random integer in [0, num_choices).
    choice = random.randint(rng_key, (), 0, num_choices)
    # lax.switch calls the branch at position `choice` with `image` as operand.
    return lax.switch(choice, augmentations, image)

In [None]:
# Stand-in "augmentations": each one just adds a distinct constant, so the output
# reveals which function was applied.

def add_noise_func(x):
  """Placeholder augmentation: returns x + 10."""
  return x + 10


def horizontal_flip_func(x):
  """Placeholder augmentation: returns x + 1."""
  return x + 1


def rotate_func(x):
  """Placeholder augmentation: returns x + 2."""
  return x + 2


def adjust_colors_func(x):
  """Placeholder augmentation: returns x + 3."""
  return x + 3


# Candidate functions, in the order they are indexed when one is selected.
augmentations = [add_noise_func, horizontal_flip_func, rotate_func, adjust_colors_func]


In [None]:
# Dummy one-dimensional "image": the integers 0..99.
image = jnp.arange(100)

In [None]:
random_augmentation(image, augmentations, random.PRNGKey(211))