From: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html


In [16]:
import jax
import jax.numpy as jnp

import numpy as np

In [2]:
# Report the installed JAX version; print's default " " separator supplies the space.
print("Using jax:", jax.__version__)

Using jax: 0.6.1


In [7]:
a = jnp.zeros((2, 5), dtype=jnp.float32)
print(f"{a=}")

b = jnp.arange(6)
print(f"{b=}")
print(f"{b.__class__=} {b.dtype=} {b.device=}")


a=Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
b=Array([0, 1, 2, 3, 4, 5], dtype=int32)
b.__class__=<class 'jaxlib._jax.ArrayImpl'> b.dtype=dtype('int32') b.device=CudaDevice(id=0)


In [11]:
# Copy `b` to the host; the result is a plain numpy.ndarray (see printed class/device).
b_cpu = jax.device_get(b)
print(f"{b_cpu=}, {b_cpu.__class__=} {b_cpu.dtype=} {b_cpu.device=}")

# Put the NumPy array back on a device. No device is passed, so JAX picks its
# default device (CudaDevice(id=0) in the recorded output) — the name `b_gpu`
# presumes a GPU is present; NOTE(review): on a CPU-only install this would be a CPU array.
b_gpu = jax.device_put(b_cpu)
print(f"{b_gpu=}, {b_gpu.__class__=} {b_gpu.dtype=} {b_gpu.device=}")

# Mixing a NumPy array with a JAX array works; the recorded result is a JAX Array.
b_cpu + b_gpu

b_cpu=array([0, 1, 2, 3, 4, 5], dtype=int32), b_cpu.__class__=<class 'numpy.ndarray'> b_cpu.dtype=dtype('int32') b_cpu.device='cpu'
b_gpu=Array([0, 1, 2, 3, 4, 5], dtype=int32), b_gpu.__class__=<class 'jaxlib._jax.ArrayImpl'> b_gpu.dtype=dtype('int32') b_gpu.device=CudaDevice(id=0)


Array([ 0,  2,  4,  6,  8, 10], dtype=int32)

In [12]:
# List the devices available to JAX (a single CUDA GPU in the recorded output).
jax.devices()

[CudaDevice(id=0)]

In [14]:
# JAX arrays cannot be assigned in place (no `b[0] = 1`); `.at[idx].set(value)`
# returns an updated copy and leaves `b` unchanged, as the printed output shows.
b_new = b.at[0].set(1)
print(f"{b=}, {b_new=}")

b=Array([0, 1, 2, 3, 4, 5], dtype=int32), b_new=Array([1, 1, 2, 3, 4, 5], dtype=int32)


In [20]:
# Pseudo-random number generation.
# JAX's PRNG is stateless: the key alone determines the sample, so sampling
# twice with the same key returns the same number.
rng = jax.random.key(42)  # new-style typed key; the legacy raw-key form is jax.random.PRNGKey(42)
jax_random_number_1, jax_random_number_2 = jax.random.normal(rng), jax.random.normal(rng)
print(f"{jax_random_number_1=}, {jax_random_number_2=}")

# NumPy generators are stateful: each draw advances the internal state, so the
# two samples differ. Seeding keeps the notebook reproducible under Restart & Run All.
np_rng = np.random.default_rng(42)
np_random_number_1, np_random_number_2 = np_rng.normal(size=2)
print(f"{np_random_number_1=}, {np_random_number_2=}")

# To get a different number each time we sample in JAX, split the key into fresh
# subkeys (and keep the first output as the new carry key).
rng, subkey1, subkey2 = jax.random.split(rng, num=3)
jax_random_number_3 = jax.random.normal(subkey1, shape=(1,))
jax_random_number_4 = jax.random.normal(subkey2, shape=(1,))
print(f"{jax_random_number_3=}, {jax_random_number_4=}")

jax_random_number_1=Array(-0.02830462, dtype=float32), jax_random_number_2=Array(-0.02830462, dtype=float32)
np_random_number_1=np.float64(0.14305665507357707), np_random_number_2=np.float64(2.31487907533281)
jax_random_number_3=Array([0.60576403], dtype=float32), jax_random_number_4=Array([0.4323065], dtype=float32)


In [21]:
# Tracing a Python function into a jaxpr (JAX's intermediate representation).
def simple_graph(x):
    """Compute mean((x + 2) ** 2 + 3) as a chain of elementwise steps.

    Args:
        x: array supporting ``+``, ``**`` and ``.mean()``.

    Returns:
        The scalar mean of the transformed values.
    """
    shifted = x + 2
    squared = shifted ** 2
    offset = squared + 3
    return offset.mean()

input_array = jnp.arange(10, dtype=jnp.float32)
print(f"{input_array=}, output: {simple_graph(input_array)=}")
# make_jaxpr traces simple_graph with abstract values shaped like input_array.
jaxpr = jax.make_jaxpr(simple_graph)(input_array)
print(f"{jaxpr=}")

input_array=Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32), output: simple_graph(input_array)=Array(53.5, dtype=float32)
jaxpr={ lambda ; a:f32[10]. let
    b:f32[10] = add a 2.0:f32[]
    c:f32[10] = integer_pow[y=2] b
    d:f32[10] = add c 3.0:f32[]
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[] = div e 10.0:f32[]
  in (f,) }


In [24]:
global_list = []

def norm(x):
    """Return ``jnp.linalg.norm(x)`` (the L2 norm for a 1-D vector) and record ``x``.

    Appending to ``global_list`` is a deliberate side effect. It shows that JAX
    transformations capture only the array computation. When ``norm`` is traced,
    e.g. by ``jax.make_jaxpr``, the append runs in Python during tracing and does
    not appear in the resulting jaxpr.
    """
    # No `global` statement needed: the list is mutated in place, never rebound.
    global_list.append(x)
    return jnp.linalg.norm(x)
# Side effects such as the append inside `norm` do execute, but only in Python
# while JAX traces the function. They are not part of the traced computation,
# so the jaxpr below contains only the norm arithmetic (mul, reduce_sum, sqrt).
n_recorded_before = len(global_list)
jaxpr_norm = jax.make_jaxpr(norm)(input_array)
print(f"{jaxpr_norm=}")
# Tracing ran the Python body once, so exactly one item was appended.
assert len(global_list) == n_recorded_before + 1

jaxpr_norm={ lambda ; a:f32[10]. let
    b:f32[] = pjit[
      name=norm
      jaxpr={ lambda ; a:f32[10]. let
          c:f32[10] = mul a a
          d:f32[] = reduce_sum[axes=(0,)] c
          b:f32[] = sqrt d
        in (b,) }
    ] a
  in (b,) }


In [25]:
# Automatic differentiation: jax.grad(f) returns a new function that evaluates
# the gradient of the scalar-valued f with respect to its first argument.
# For simple_graph, y = mean((x + 2)**2 + 3), so dy/dx_i = 2 * (x_i + 2) / n.

grad_function = jax.grad(simple_graph)
gradients = grad_function(input_array)
print(f"Gradients of {simple_graph.__name__} at {input_array=}: {gradients=}")
# Cross-check autodiff against the hand-derived gradient above.
np.testing.assert_allclose(gradients, 2 * (input_array + 2) / input_array.size, rtol=1e-6)

Gradients of simple_graph at input_array=Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32): gradients=Array([0.4      , 0.6      , 0.8      , 1.       , 1.2      , 1.4      ,
       1.6      , 1.8000001, 2.       , 2.2      ], dtype=float32)


In [26]:
# Inspect the jaxpr of the gradient function. In the recorded output it replays
# the forward pass (b..g) and keeps the local derivative 2*(x+2) as `e`. It
# discards the forward result (`_`) and computes the 1/10 factor from the mean
# (`h`). It broadcasts that factor to shape (10,) (`i`) and multiplies by `e` to
# produce the gradient `j`.
jax.make_jaxpr(grad_function)(input_array)

{ lambda ; a:f32[10]. let
    b:f32[10] = add a 2.0:f32[]
    c:f32[10] = integer_pow[y=2] b
    d:f32[10] = integer_pow[y=1] b
    e:f32[10] = mul 2.0:f32[] d
    f:f32[10] = add c 3.0:f32[]
    g:f32[] = reduce_sum[axes=(0,)] f
    _:f32[] = div g 10.0:f32[]
    h:f32[] = div 1.0:f32[] 10.0:f32[]
    i:f32[10] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10,)
      sharding=None
    ] h
    j:f32[10] = mul i e
  in (j,) }

In [None]:
# TODO: understand the differentiation