From: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html


In [None]:
import jax
import jax.numpy as jnp

import numpy as np

In [None]:
# Report which JAX release this notebook is running against.
print("Using jax:", jax.__version__)

In [None]:
a = jnp.zeros((2, 5), dtype=jnp.float32)
print(f"{a=}")

b = jnp.arange(6)
print(f"{b=}")
print(f"{b.__class__=} {b.dtype=} {b.device=}")


In [None]:
# Copy the array to host memory: jax.device_get returns a NumPy ndarray.
b_cpu = jax.device_get(b)
# ndarray.device only exists on NumPy >= 2.0 (where it is always "cpu");
# fall back to "cpu" so this cell does not raise AttributeError on NumPy 1.x.
b_cpu_device = getattr(b_cpu, "device", "cpu")
print(f"{b_cpu=}, {b_cpu.__class__=} {b_cpu.dtype=} {b_cpu_device=}")

# Move it back onto a JAX device. With no explicit device, jax.device_put uses
# JAX's default device -- a GPU only if one is available, despite the name.
b_gpu = jax.device_put(b_cpu)
print(f"{b_gpu=}, {b_gpu.__class__=} {b_gpu.dtype=} {b_gpu.device=}")

# JAX handles the mixed NumPy + JAX addition and returns a JAX array.
b_cpu + b_gpu

In [None]:
# Show the devices of JAX's default backend (backend=None selects the default).
jax.devices(None)

In [None]:
# JAX arrays are immutable: `.at[idx].set(v)` builds an updated copy and
# leaves the original `b` unchanged, which the print below demonstrates.
b_new = b.at[(0,)].set(1)
print(f"{b=}, {b_new=}")

In [None]:
# Pseudo-random number generation.
# JAX has no hidden global RNG state: randomness is driven by an explicit key.
# jax.random.key(42) creates a new-style typed key; the legacy raw-uint32
# counterpart would be jax.random.PRNGKey(42).
rng = jax.random.key(42)
# Sampling twice with the *same* key deliberately returns the same number.
jax_random_number_1, jax_random_number_2 = jax.random.normal(rng), jax.random.normal(rng)
print(f"{jax_random_number_1=}, {jax_random_number_2=}")

# NumPy, by contrast, draws from a global stateful generator, so two draws differ.
np_random_number_1, np_random_number_2 = np.random.normal(size=2)
print(f"{np_random_number_1=}, {np_random_number_2=}")

# For a different random number on every draw, split the key into fresh
# subkeys and consume each subkey only once.
rng, subkey1, subkey2 = jax.random.split(rng, num=3)
jax_random_number_3 = jax.random.normal(subkey1, shape=(1,))
jax_random_number_4 = jax.random.normal(subkey2, shape=(1,))
print(f"{jax_random_number_3=}, {jax_random_number_4=}")

In [None]:
# Function transformations: a plain Python function that JAX can trace into a jaxpr.
def simple_graph(x):
    """Return the mean over all elements of ``(x + 2) ** 2 + 3``."""
    shifted = x + 2
    squared = shifted ** 2
    return (squared + 3).mean()

# Run simple_graph on a concrete float32 vector, then print the jaxpr that
# jax.make_jaxpr records while tracing it on an input of that shape/dtype.
input_array = jnp.arange(0, 10, dtype=jnp.float32)
print(f"{input_array=}, output: {simple_graph(input_array)=}")
jaxpr = jax.make_jaxpr(simple_graph)(input_array)
print(f"{jaxpr=}")

In [None]:
# Module-level list used to show that Python side effects are not part of a
# traced JAX computation.
global_list = []


def norm(x):
    """Record ``x`` in ``global_list`` and return ``jnp.linalg.norm(x)``.

    ``jnp.linalg.norm`` with its default ``ord`` gives the 2-norm for vectors
    and the Frobenius norm for matrices.

    Under a JAX transformation such as ``jax.make_jaxpr``, the append runs only
    while the function is being traced. ``global_list`` then receives a tracer
    rather than a concrete array, and the append does not appear in the jaxpr.
    """
    # No ``global`` statement needed: the list is mutated in place, never rebound.
    global_list.append(x)
    return jnp.linalg.norm(x)
# Watch out for the global side effect: make_jaxpr executes norm's Python body
# once while tracing, so global_list gains one entry (a tracer, not the array).
# The appended value is never part of the recorded jaxpr; only the array ops are.
jaxpr_norm = jax.make_jaxpr(norm)(input_array)
print(f"{jaxpr_norm=}")

In [None]:
# Automatic differentiation: jax.grad(f) returns a new function that computes
# the gradient of f with respect to its first argument. f must return a
# scalar, which simple_graph does via .mean().
# For simple_graph the analytic gradient is 2 * (x + 2) / x.size.
grad_function = jax.grad(simple_graph)
gradients = grad_function(input_array)
print(f"Gradients of {simple_graph.__name__} at {input_array=}: {gradients=}")

In [None]:
# Display the jaxpr of the gradient computation, built inline from simple_graph.
jax.make_jaxpr(jax.grad(simple_graph))(input_array)

In [None]:
# jax.jit returns a wrapper that traces and XLA-compiles simple_graph on its
# first call (and again for each new input shape/dtype).
jitted_function = jax.jit(simple_graph)
rng, normal_rng = jax.random.split(rng)
large_input = jax.random.normal(normal_rng, shape=(1000,))
# Warm-up call so the %%timeit cells measure execution rather than compilation.
# JAX dispatches work asynchronously; block_until_ready() ensures the warm-up
# run has fully finished before any timing starts.
_ = jitted_function(large_input).block_until_ready()

In [None]:
%%timeit 
# Baseline: simple_graph without jit, so JAX executes its ops one at a time.
# block_until_ready() waits for the asynchronously dispatched result, so the
# timing covers the computation itself and not just its dispatch.
simple_graph(large_input).block_until_ready()

In [None]:
%%timeit
# Same computation through the XLA-compiled wrapper. Compilation should
# already have happened in the earlier warm-up call, so this times execution only.
jitted_function(large_input).block_until_ready()