# Jax 101 Exercise Solutions
---
This notebook accompanies the `Jax 101` [blog post](link) and the `exercises.ipynb` notebook. Make sure to read that post and complete the exercises before looking at this notebook.

In [1]:
import jax
from jax import random, grad, vmap
from jax import jit
import jax.numpy as jnp
import imageio
from tqdm import tqdm
from io import BytesIO
import matplotlib.pyplot as plt

%matplotlib inline

## Beginner Level

### Exercise 1: Array Creation

- **Objective:** Learn basic array creation techniques.
- **Tasks:**
  - Create an array of values that range from 0 to 100.
  - Create a matrix of zeros of size (100 x 1000) with dtype `jnp.float64`.
  - Use `jnp.polyval` to create an array of values. Visualize the result using `matplotlib`.

In [4]:
jax.config.update("jax_enable_x64", True)

In [None]:
arr = jnp.arange(101)  # 0 through 100 inclusive
arr2 = jnp.zeros((100, 1000), dtype=jnp.float64)
print(arr[:10])
print(arr.dtype, arr2.dtype)

In [None]:
xs = jnp.linspace(-5, 5, 100)
# Coefficients run from the highest power down: -3x^3 + 4x^2 - 2x + 7
a = jnp.polyval(jnp.array([-3, 4, -2, 7]), xs)
plt.plot(xs, a);
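
`jnp.polyval` reads its coefficients from the highest power down, so `[-3, 4, -2, 7]` stands for $-3x^3 + 4x^2 - 2x + 7$. As a quick sanity check, the sketch below compares it against a hand-written Horner evaluation. The helper `horner` and the names `coeffs` and `xs_check` are made up for illustration.

```python
import jax.numpy as jnp

coeffs = jnp.array([-3.0, 4.0, -2.0, 7.0])

def horner(coeffs, x):
    # Hypothetical helper: evaluate the polynomial with Horner's rule
    result = jnp.zeros_like(x)
    for c in coeffs:
        result = result * x + c
    return result

xs_check = jnp.linspace(-5.0, 5.0, 11)
value_at_2 = jnp.polyval(coeffs, jnp.array(2.0))  # -3*8 + 4*4 - 2*2 + 7 = -5
matches = jnp.allclose(jnp.polyval(coeffs, xs_check), horner(coeffs, xs_check))
print(value_at_2, bool(matches))
```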

### Exercise 2: Higher Order Grads

- **Objective:** Learn about `grad`.
- **Tasks:**
  - Write the function defined below.
  - Create an array of values that range from 0 to 100.
  - Using `grad` calculate the 1st, 2nd, 3rd, and 4th order gradients of the function
    - Hint: You can nest `grad` calls like `grad(grad(f))`
  - Plot the resulting arrays on a single figure

$
f(x) = x^3 - 3x^2 + 2x
$

In [9]:
def f(x):
    return x**3 - 3*x**2 + 2*x

In [None]:
x = jnp.arange(7.0)

print(f"original:           {x} --> {f(x)}")
print(f"grad:               {x} --> {vmap(grad(f))(x)}")
print(f"2nd order grad:     {x} --> {vmap(grad(grad(f)))(x)}")
print(f"3rd order grad:     {x} --> {vmap(grad(grad(grad(f))))(x)}")
print(f"4th order grad:     {x} --> {vmap(grad(grad(grad(grad(f)))))(x)}")

In [None]:
xs = jnp.linspace(-10, 10, 2_000)

plt.plot(xs, f(xs), label="f(x)")
plt.plot(xs, vmap(grad(f))(xs), label="1st order grad")
plt.plot(xs, vmap(grad(grad(f)))(xs), label="2nd order grad")
plt.plot(xs, vmap(grad(grad(grad(f))))(xs), label="3rd order grad")
plt.plot(xs, vmap(grad(grad(grad(grad(f)))))(xs), label="4th order grad")
plt.grid(True)
plt.legend();
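
Because $f$ is a cubic, its derivatives have simple closed forms: $f'(x) = 3x^2 - 6x + 2$, $f''(x) = 6x - 6$, $f'''(x) = 6$ and $f''''(x) = 0$. The sketch below checks the nested `grad` calls against them. The helper `nth_grad` and the name `x_check` are hypothetical and not part of Jax.

```python
import jax.numpy as jnp
from jax import grad, vmap

def f(x):
    return x**3 - 3*x**2 + 2*x

def nth_grad(fn, n):
    # Hypothetical helper: apply `grad` n times
    for _ in range(n):
        fn = grad(fn)
    return fn

x_check = jnp.arange(7.0)
closed_forms = [
    3*x_check**2 - 6*x_check + 2,   # f'(x)
    6*x_check - 6,                  # f''(x)
    jnp.full_like(x_check, 6.0),    # f'''(x)
    jnp.zeros_like(x_check),        # f''''(x)
]
results = []
for n, expected in enumerate(closed_forms, start=1):
    computed = vmap(nth_grad(f, n))(x_check)
    results.append(bool(jnp.allclose(computed, expected)))
print(results)
```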

### Exercise 3: Random Numbers

- **Objective:** Learn about Jax's random module.
- **Tasks:**
  - Use `random.key` to create a key and then create an array of normally distributed numbers.
  - What happens if you try to reuse the key?
  - Bonus points: Implement a Monte Carlo simulation to estimate the value of π.

In [None]:
key = random.key(0)
arr = random.normal(key, (10,))  # shape must be a tuple
arr_reused = random.normal(key, (10,))  # same key, so the same "random" numbers
print(arr)
print(arr_reused)
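
Reusing a key reproduces the same numbers, so the usual pattern is to split the key and use each subkey once. A minimal sketch, with illustrative variable names:

```python
import jax.numpy as jnp
from jax import random

key = random.key(0)
same_a = random.normal(key, (5,))
same_b = random.normal(key, (5,))        # reused key: identical draw

# Split into a new carry-over key plus two subkeys, each used once
key, subkey1, subkey2 = random.split(key, 3)
fresh_a = random.normal(subkey1, (5,))
fresh_b = random.normal(subkey2, (5,))   # different subkeys: different draws

print(bool(jnp.array_equal(same_a, same_b)))
print(bool(jnp.array_equal(fresh_a, fresh_b)))
```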

In [None]:
key = random.key(0)
arr = random.uniform(key, (10_000, 2), minval=0, maxval=1)

invals = (jnp.linalg.norm(arr, axis=1) < 1).sum()
pi = invals / arr.shape[0] * 4
print(f"Estimated value of pi: {pi:.7f}")

### Exercise 4: Jit

- **Objective:** Understand and apply Just-In-Time (JIT) compilation for performance optimization.
- **Tasks:**
  - Implement the function described below.
  - Use `jax.jit` to optimize the function.
  - Measure and compare the execution time before and after applying `jax.jit`.
    - Hint: For accurate timing make sure to use `block_until_ready`.
  - Experiment with `static_argnums` and `donate_argnums` to understand their impact on performance.

$
f(x) = x^2 + 2x + 1
$

**Best Practice**: Apply `vmap` first and then `jit`, so that the compiler optimizes the vectorized version of the function.
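
A minimal sketch of that ordering, assuming a function written for a single scalar (the names `f_scalar`, `f_batched` and `xs_demo` are illustrative):

```python
import jax.numpy as jnp
from jax import jit, vmap

def f_scalar(x):
    # Written as if x were a single scalar
    return x**2 + 2*x + 1

# Vectorize first, then compile the batched function as a whole
f_batched = jit(vmap(f_scalar))

xs_demo = jnp.linspace(-1.0, 1.0, 5)
print(f_batched(xs_demo))
```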

In [18]:
def f(x):
    return x**2 + 2*x + 1

In [19]:
f_jitted = jit(f)

In [20]:
key = random.key(0)

arr = random.normal(key, (1_000_000, 100))

_ = f_jitted(arr) # compile function

In [None]:
%timeit f(arr).block_until_ready()
%timeit f_jitted(arr).block_until_ready()

In [26]:
# static argnums
def f(x, y):
    return x**2 + x**0.5 - y**3

With `static_argnums=(0,)`, the compiled code is cached per value of the 0th argument: calling again with the same value reuses the cached code, while a new value triggers recompilation.
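
One way to observe this is to rely on the fact that Python side effects only run while Jax traces a function. The sketch below, with the made-up names `power` and `trace_log`, records every trace:

```python
import jax.numpy as jnp
from jax import jit

trace_log = []  # filled only while Jax traces `power`

def power(n, x):
    trace_log.append(n)  # Python side effect: runs at trace time only
    return x ** n

power_jit = jit(power, static_argnums=(0,))

x_demo = jnp.arange(4.0)
power_jit(2, x_demo)  # first call with n=2: trace and compile
power_jit(2, x_demo)  # same static value: cached code, no new trace
power_jit(3, x_demo)  # new static value: traced and compiled again
print(trace_log)
```

The log holds one entry per distinct static value rather than one per call.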

In [27]:
f_jitted = jit(f)
f_staticargnum_jitted = jit(f, static_argnums=(0,))

In [None]:
# Create data and compile function
key = random.key(0)
xkey, ykey = random.split(key)  # separate keys so x and y are independent draws

x = random.normal(xkey, (1_000_000, 100))
y = random.normal(ykey, (1_000_000, 100))
xnew = 2.0
ynew = jnp.array(3.0)


_ = f_jitted(x, y) # compile function
_ = f_staticargnum_jitted(xnew, ynew)

In [None]:
%timeit -n 10 f(x, y).block_until_ready()
%timeit -n 10 f_jitted(x, y).block_until_ready()

## **Intermediate Level**

### Exercise 4: Vectorization with `vmap`
- **Objective:** Learn how to efficiently vectorize operations using `jax.vmap`.
- **Tasks:**
  - Implement a function that computes the dot product of two vectors.
  - Use `jax.vmap` to apply this function across a batch of vectors.
  - Extend the function to compute the matrix-vector product for a batch of matrices and vectors.
  - Compare the performance of the vectorized version with a loop-based implementation.
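
For the matrix-vector task, one possible approach is to write the product for a single matrix and vector and let `vmap` add the batch axis. The shapes and names below are assumptions chosen for illustration.

```python
import jax.numpy as jnp
from jax import random, vmap

def matvec(A, v):
    # Single (3, 4) matrix times a single (4,) vector
    return jnp.dot(A, v)

Akey, vkey = random.split(random.key(0))
A_batch = random.normal(Akey, (16, 3, 4))
v_batch = random.normal(vkey, (16, 4))

# Map over the leading axis of both arguments
batched = vmap(matvec)(A_batch, v_batch)

# in_axes=(None, 0): share one matrix across the whole batch of vectors
shared = vmap(matvec, in_axes=(None, 0))(A_batch[0], v_batch)

print(batched.shape, shared.shape)
```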

In [30]:
def f(x, y):
    return jnp.dot(x, y)

In [None]:
x = jnp.arange(10)
y = jnp.arange(10) + 5
f(x, y)

In [32]:
f_vmapped = vmap(f)

In [33]:
# Create data for timing test
key = random.key(0)
xkey, ykey = random.split(key)
x = random.normal(xkey, (10_000, 100))
y = random.normal(ykey, (10_000, 100))

In [None]:
%timeit [f(i, j).block_until_ready() for i, j in zip(x, y)]
%timeit f_vmapped(x, y).block_until_ready()

### Exercise 5: Working with Custom Gradients
- **Objective:** Implement custom gradients for non-standard operations.
- **Tasks:**
  - Implement the relu function.
  - Define a custom gradient for this function using `jax.custom_jvp` or `jax.custom_vjp`.
  - Implement and test a function that uses this custom gradient in an optimization problem.
  - Plot the results and use Jax's built-in `grad` and `jax.nn.relu` to confirm the outputs match your implementation.

In [35]:
@jax.custom_vjp
def relu(x):
    return jnp.maximum(0, x)

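def relu_fwd(x):
    # Return the output and save the input x as the residual for the backward pass
    return relu(x), x

def relu_bwd(res, g):
    x = res  # residual saved by relu_fwd
    grad_x = jnp.where(x > 0., 1., 0.) * g
    return (grad_x,)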

relu.defvjp(relu_fwd, relu_bwd)
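
The same gradient can also be expressed with `jax.custom_jvp`. The sketch below is one possible version (the names `relu_jvp` and `relu_jvp_rule` are made up) and compares its gradient with that of the built-in `jax.nn.relu`:

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap

@jax.custom_jvp
def relu_jvp(x):
    return jnp.maximum(0.0, x)

@relu_jvp.defjvp
def relu_jvp_rule(primals, tangents):
    (x,), (t,) = primals, tangents
    # Pass the tangent through where x > 0 and zero it elsewhere
    return relu_jvp(x), jnp.where(x > 0, t, 0.0)

xs_check = jnp.linspace(-3.0, 3.0, 7)
custom_grads = vmap(grad(relu_jvp))(xs_check)
builtin_grads = vmap(grad(jax.nn.relu))(xs_check)
print(bool(jnp.allclose(custom_grads, builtin_grads)))
```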

In [None]:
xs = jnp.linspace(-10, 10, 2_000)

_, ax = plt.subplots()
ax.plot(xs, relu(xs), label="relu")
ax.plot(xs, vmap(grad(relu))(xs), label="grad(relu)")
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.grid(True)
ax.legend()

## **Advanced Level**

### Exercise 7: Neural Networks with Jax
- **Objective:** Build and train a simple neural network from scratch using Jax.
- **Tasks:**
  - Implement a basic feedforward neural network for a classification task.
  - Use `jax.numpy` for all operations, and manually implement forward and backward passes.
  - Implement training using stochastic gradient descent.
  - Experiment with different activation functions and regularization techniques.

In [None]:
## Create data for binary classification task
# Parameters for the random data
n = 50
num_classes = 2

# Generate random data points
key = random.key(0)
xkey0, xkey1 = random.split(key, 2)

mean_class_0 = jnp.array([-2, -2])  # mean for class 0
mean_class_1 = jnp.array([2, 2])  # mean for class 1
std_dev = 1.0  # standard deviation for both classes

# Generate data for class 0
X_class_0 = random.normal(xkey0, (n, 2)) * std_dev + mean_class_0
y_class_0 = jnp.zeros(n)  # label 0 for class 0

# Generate data for class 1
X_class_1 = random.normal(xkey1, (n, 2)) * std_dev + mean_class_1
y_class_1 = jnp.ones(n)  # label 1 for class 1

# Stack both classes into a single dataset
X = jnp.vstack((X_class_0, X_class_1))
y = jnp.expand_dims(jnp.hstack((y_class_0, y_class_1)), -1)
print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")

# Visualize the generated data
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral, edgecolors='k')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Randomly Generated Data for Binary Classification")
plt.show()

# Show the first few rows of the generated data
X[:5], y[:5]

In [40]:
# Build network weights
# Use a fresh key: splitting `key` again would reproduce the data keys xkey0 and xkey1
Wkey1, Wkey2 = random.split(random.key(1), 2)
W1 = random.normal(Wkey1, (n*2, 32)) * 1e-2
b1 = jnp.ones((32,1))

W2 = random.normal(Wkey2, (32, n*2)) * 1e-2
b2 = jnp.ones((n*2,1))

params = {"W1": W1, "b1": b1, "W2": W2, "b2": b2}

In [41]:
# Train via gradient descent
def predict(params, x):
    x_h = jax.nn.relu(params["W1"].T @ x + params["b1"])
    preds = jax.nn.sigmoid(params["W2"].T @ x_h + params["b2"])
    return preds

# @jit
def loss_fn(params, x, y):
    preds = predict(params, x)
    return jnp.square(preds - y).mean()

steps = 2000
lr = 0.01
losses = jnp.empty((steps,))
predictions = []
for i in range(steps):
    preds = predict(params, X)
    if (i % 10) == 0:
        predictions.append(preds)
    # calculate loss + grad
    loss_value, grad_value = jax.value_and_grad(loss_fn)(params, X, y)
    losses = losses.at[i].set(loss_value)
    # update weights
    for param in params:
        params[param] += -grad_value[param]*lr
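
For comparison, a more conventional layout treats each row of the input as one sample, so the weight shapes depend on the number of features rather than the number of samples, and the network emits one probability per sample. The following is a minimal sketch under those assumptions. All names (`mlp_init`, `mlp_predict`, `mlp_loss`, `gd_step`), the hidden width, the learning rate and the toy data are illustrative choices, and the update is full-batch gradient descent rather than true SGD.

```python
import jax
import jax.numpy as jnp
from jax import random

def mlp_init(key, n_features=2, n_hidden=32):
    k1, k2 = random.split(key)
    return {
        "W1": random.normal(k1, (n_features, n_hidden)) * 1e-1,
        "b1": jnp.zeros((n_hidden,)),
        "W2": random.normal(k2, (n_hidden, 1)) * 1e-1,
        "b2": jnp.zeros((1,)),
    }

def mlp_predict(params, x):
    # x has shape (n_samples, n_features); output has shape (n_samples, 1)
    h = jax.nn.relu(x @ params["W1"] + params["b1"])
    return jax.nn.sigmoid(h @ params["W2"] + params["b2"])

def mlp_loss(params, x, y):
    return jnp.square(mlp_predict(params, x) - y).mean()

@jax.jit
def gd_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(mlp_loss)(params, x, y)
    # Apply the same update rule to every leaf of the parameter dict
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Toy data: two Gaussian blobs, 50 points each (assumed for this sketch)
data_key, param_key = random.split(random.key(0))
k0, k1 = random.split(data_key)
X_demo = jnp.vstack((random.normal(k0, (50, 2)) - 2.0,
                     random.normal(k1, (50, 2)) + 2.0))
y_demo = jnp.concatenate((jnp.zeros(50), jnp.ones(50)))[:, None]

mlp_params = mlp_init(param_key)
first_loss = mlp_loss(mlp_params, X_demo, y_demo)
for _ in range(200):
    mlp_params, last_loss = gd_step(mlp_params, X_demo, y_demo, 0.1)
print(float(first_loss), float(last_loss))
```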

In [None]:
plt.plot(losses, label="train loss")
plt.grid()
plt.legend();

In [None]:
# See predictions made by trained model
preds = predict(params, X)
plt.figure(figsize=(8, 6))
plt.scatter(preds[:, 0], preds[:, 1], c=y, cmap=plt.cm.Spectral, edgecolors='k')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Final Predictions")
plt.show()

In [None]:
# Create Gif of predictions over time (optional)
frames = []

for i in tqdm(range(len(predictions))):
    # Create a plot for each frame
    plt.figure(figsize=(8, 6))
    plt.scatter(predictions[i][:, 0], predictions[i][:, 1], c=y, cmap=plt.cm.Spectral, edgecolors='k')
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.title(f"Predictions at training step {i * 10}")  # snapshots are taken every 10 steps
    
    # Save the plot to a BytesIO object
    buf = BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    
    # Add this frame to the list
    frames.append(imageio.v3.imread(buf))

# Save the frames as a GIF
imageio.mimsave('ex7_sgd.gif', frames, fps=2)  # You can adjust fps for speed

In [None]:
# If you want to display the GIF in a Jupyter notebook (optional):
from IPython.display import Image
Image(filename='ex7_sgd.gif')

### Exercise 8: Parallelism with `pmap`
- **Objective:** Leverage data parallelism to scale computations across multiple devices.
- **Tasks:**
  - Implement a function that performs a computation over a large array. It should be expensive enough that the jitted and non-jitted versions differ noticeably in run time.
  - Use `jax.pmap` to parallelize this operation across multiple devices (e.g., GPUs).
  - Measure the speedup obtained with `pmap` compared to single-device execution.
  - Experiment with different batch sizes and data partitioning strategies.
  - Try doing this in [Google Colab](https://colab.research.google.com) for free and easy access to multiple accelerator devices.
  
Note: `jax.pmap` compiles the function so `jit()` is unnecessary.

In [None]:
jax.devices()

In [None]:
jax.local_device_count()

In [65]:
def f(x, y):
    return jnp.sum(5*x - y**2 + 2)

In [None]:
key = random.key(0)
key1, key2 = random.split(key, 2)

x1 = random.normal(key1, (8, 100_000))
x2 = random.normal(key2, (8, 100_000))
f(x1, x2).shape

In [71]:
## Run this on a machine with at least 8 devices: `pmap` maps the leading axis (size 8) over devices,
## so it fails on a single-device setup such as a default CPU backend.
# jitted_f = jax.pmap(f)
# _ = jitted_f(x1, x2)

In [None]:
# %timeit f(x1, x2).block_until_ready()
# %timeit jitted_f(x1, x2).block_until_ready()
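
To try `pmap` on whatever hardware is available, one option is to size the leading axis from `jax.local_device_count()` instead of hard-coding 8. The sketch below does that. The names `f_shard`, `x_sh` and `y_sh` are illustrative, and with a single device it runs without showing any speedup.

```python
import jax
import jax.numpy as jnp
from jax import random

def f_shard(x, y):
    # Runs once per device on that device's slice of the data
    return jnp.sum(5*x - y**2 + 2)

n_dev = jax.local_device_count()
k1, k2 = random.split(random.key(0))
x_sh = random.normal(k1, (n_dev, 100_000))  # leading axis = number of devices
y_sh = random.normal(k2, (n_dev, 100_000))

per_device = jax.pmap(f_shard)(x_sh, y_sh)  # one partial sum per device
total = per_device.sum()
print(per_device.shape, total)
```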

### Exercise 9: JaxPR
- **Objective:** Dive into Jax's internal representation by manipulating JaxPR.
- **Tasks:**
  - Define a function and obtain its jaxpr using `jax.make_jaxpr`.
  - Analyze the jaxpr to understand its computation graph. Compare it with the one from the blog post.
  - Modify the function and observe how the jaxpr changes.

In [72]:
def f(x):
    x = x + 3
    x = x ** 2
    y = 21
    return jnp.sum(x + y)

In [73]:
x = jnp.arange(10)

In [None]:
jax.make_jaxpr(f)(x)
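
Besides printing it, the jaxpr can be inspected programmatically. The sketch below lists the primitives that make up the traced function, which makes it easier to see what changes after the function is modified. The exact primitive names can vary between Jax versions, so treat the printed list as informative rather than fixed.

```python
import jax
import jax.numpy as jnp

def f(x):
    x = x + 3
    x = x ** 2
    y = 21
    return jnp.sum(x + y)

closed_jaxpr = jax.make_jaxpr(f)(jnp.arange(10))

# Each equation applies one primitive to its input variables
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
print(len(closed_jaxpr.jaxpr.invars), closed_jaxpr.out_avals)
```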