# JAX basics

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit # program transformations

from utils import (
    sigmoid,
    sigmoid_sum,
    sigmoid_deriv,
)

## Autograd

In [None]:
# Autograd demo: draw standard-normal inputs and build the gradient function.
num_samples = 1_000  # same integer as int(1e03)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, shape=(num_samples,))

# grad() returns a new function computing the gradient of sigmoid_sum with
# respect to its first argument (sigmoid_sum must return a scalar).
sigmoid_sum_grad = grad(sigmoid_sum)

In [None]:
autodiff_grad = sigmoid_sum_grad(x)
# `sigmoid_deriv` comes from utils; presumably the closed-form derivative of
# sigmoid applied element-wise — verify against utils.
exact_grad = sigmoid_deriv(x)

# Autodiff and the hand-written derivative should agree up to round-off.
print(jnp.allclose(autodiff_grad, exact_grad))

# Use jnp reductions instead of Python's builtin max(): the builtin iterates
# the JAX array element by element (one device op per element), whereas
# jnp.max/jnp.abs run as single vectorized operations.
max_abs_diff = jnp.max(jnp.abs(autodiff_grad - exact_grad))
print(max_abs_diff)

## Just-in-time compilation

In [None]:
# JIT demo inputs: a fresh batch of standard-normal samples.
num_samples = 1_000  # same integer as int(1e03)
key = jax.random.PRNGKey(2024)
x = jax.random.normal(key, shape=(num_samples,))

# jit() traces sigmoid on its first call and compiles it with XLA; later calls
# with the same input shape/dtype reuse the compiled version.
sigmoid_jit = jit(sigmoid)

In [None]:
# Baseline timing of the un-jitted sigmoid. JAX dispatches work asynchronously,
# so block_until_ready() is needed for the timing to include the actual compute.
%timeit sigmoid(x).block_until_ready() # avoid asynchronous execution when timing

In [None]:
%timeit sigmoid_jit(x).block_until_ready()

## Good to know

### Index clamping

In [None]:
num_elem = 10
arange = jnp.arange(num_elem)  # 0, 1, ..., num_elem - 1

last_idx = num_elem - 1
past_end_idx = last_idx + 11  # == num_elem + 10, well beyond the end

print(arange[last_idx])  # last element
# NumPy would raise IndexError here; JAX clamps the out-of-bounds index to the
# last valid position, so this prints the same value as the line above.
print(arange[past_end_idx])

### In-place updates (functional `.at[]` syntax; JAX arrays are immutable)

In [None]:
matrix = jnp.eye(2)

# JAX arrays are immutable, so item assignment such as
#     matrix[0, 0] = 2
# raises an error. The .at[...] property instead returns an updated copy and
# leaves `matrix` untouched.
new_matrix = (
    matrix
    .at[0, 0].set(2)  # copy with the top-left entry replaced by 2
    .at[:, 1].add(5)  # then add 5 to every entry of column 1
)

print(new_matrix)

### Random numbers

In [None]:
key = jax.random.PRNGKey(0) # create an explicit PRNG state

# JAX guidance: a key that has been passed to split() should not be used again.
# Split once, keep one half as the (still unused) `key`, and derive the batch
# of per-iteration keys from the other half.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, 10) # create multiple independent states

In [None]:
num_elem = 10
sample_shape = (num_elem,)

# Sampling twice with the *same* key is deterministic: both draws are identical.
x1 = jax.random.normal(key, sample_shape)  # pseudo-random numbers
x2 = jax.random.normal(key, sample_shape)  # identical to x1

print((x1 == x2).all())

In [None]:
# Each key from split() gives an independent stream, so every draw differs.
# Distinct loop names avoid silently rebinding the parent `key` and clobbering
# `x` (the benchmark input used by the JIT timing cells) with a scalar.
for subkey in keys:
    sample = jax.random.normal(subkey)  # scalar draw (default shape=())
    print(sample)