# Jax 101 Exercises
---
This notebook accompanies the `Jax 101` [blog post](link). Make sure to read it before diving into these exercises. The post does not cover everything you need to complete these problems, so refer back to the [Jax documentation](https://jax.readthedocs.io/en/latest/index.html) if you get stuck.

Make use of Jax's debugging tools such as `jax.debug.print`: a regular `print` inside a compiled function runs only at trace time and shows abstract tracers instead of array values. For timing, use Python's built-in `timeit` module. Inside a Jupyter notebook, the `%timeit` magic makes this easy:
```
%timeit myfunction(x, y)
```
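A minimal sketch of the difference, using a hypothetical function `scaled_sum`: under `jax.jit`, the plain `print` fires only while the function is being traced and shows a tracer, whereas `jax.debug.print` is staged into the compiled computation and prints the concrete value on every call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # Runs only while tracing; prints an abstract tracer, not the values
    print("print:", x)
    total = jnp.sum(x) * 2.0
    # Staged into the compiled function; prints the concrete value at runtime
    jax.debug.print("debug.print: total = {t}", t=total)
    return total

result = scaled_sum(jnp.arange(4.0))
```

Calling `scaled_sum` a second time with the same input shape reuses the cached compilation, so only the `jax.debug.print` line appears again.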

Good luck!

In [23]:
import jax
from jax import random
import jax.numpy as jnp
import imageio
from tqdm import tqdm
from io import BytesIO
import matplotlib.pyplot as plt

%matplotlib inline

## Beginner Level

### Exercise 1: Array Creation

- **Objective:** Learn basic array creation techniques.
- **Tasks:**
  - Create an array of values that range from 0 to 100.
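  - Create a matrix of zeros of size (100 x 1000) with dtype `jnp.float64`. Jax uses 32-bit types by default, so enable 64-bit support first with `jax.config.update("jax_enable_x64", True)`.
  - Use `jnp.polyval` to evaluate a polynomial over an array of x values, then visualize the result using `matplotlib`.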
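`jnp.polyval(p, x)` follows NumPy's convention: the coefficients in `p` are ordered from the highest degree down to the constant term. A small sketch with a hypothetical polynomial $x^2 - 1$ (not the one you need to choose for the exercise):

```python
import jax.numpy as jnp

# Coefficients from highest to lowest degree: 1*x^2 + 0*x - 1
coeffs = jnp.array([1.0, 0.0, -1.0])
xs = jnp.linspace(-2.0, 2.0, 5)   # [-2, -1, 0, 1, 2]
ys = jnp.polyval(coeffs, xs)      # [3, 0, -1, 0, 3]
# xs and ys can then be passed straight to plt.plot(xs, ys)
```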

### Exercise 2: Higher Order Grads

- **Objective:** Learn about `grad`.
- **Tasks:**
  - Implement the function defined below.
  - Create an array of values that range from -10 to +10.
  - Using `grad`, calculate the 1st, 2nd, 3rd, and 4th order derivatives of the function.
    - Hint: You can nest `grad` calls like `grad(grad(f))`
  - Plot the resulting arrays on a single figure.

$$
f(x) = x^3 - 3x^2 + 2x
$$
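`grad` only differentiates functions with a scalar output, so to evaluate a derivative at every point of an array you can combine it with `jax.vmap`. The sketch below uses `jnp.sin` as a stand-in for the exercise's $f$, because its derivatives are known in closed form and easy to check.

```python
import jax
import jax.numpy as jnp

xs = jnp.linspace(-3.0, 3.0, 7)

d1 = jax.grad(jnp.sin)            # d/dx sin(x)   = cos(x)
d2 = jax.grad(jax.grad(jnp.sin))  # d2/dx2 sin(x) = -sin(x)

# vmap maps each scalar-to-scalar derivative over every element of xs
d1_vals = jax.vmap(d1)(xs)
d2_vals = jax.vmap(d2)(xs)
```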

In [30]:
def f(x):
    ### TODO ###
    raise NotImplementedError

### Exercise 3: Random Numbers

- **Objective:** Learn about Jax's random module.
- **Tasks:**
  - Use `random.key` to create a key and then create an array of normally distributed numbers.
  - What happens if you try to reuse the key?
  - Bonus points: Implement a Monte Carlo simulation to estimate the value of π.
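A short sketch of key handling: Jax's random functions are deterministic given a key, so reusing a key reproduces the same numbers, and `random.split` is how you derive fresh keys.

```python
from jax import random
import jax.numpy as jnp

key = random.key(0)

# Reusing the same key reproduces exactly the same "random" numbers
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Splitting derives new keys; use each subkey once and keep `key` for future splits
key, subkey = random.split(key)
c = random.normal(subkey, (3,))
```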

### Exercise 4: Jit

- **Objective:** Understand and apply Just-In-Time (JIT) compilation for performance optimization.
- **Tasks:**
  - Implement the function described below.
  - Use `jax.jit` to optimize the function.
  - Measure and compare the execution time before and after applying `jax.jit`.
    - Hint: Jax dispatches computations asynchronously, so call `block_until_ready()` on the result to get accurate timings.
  - Experiment with `static_argnums` and `donate_argnums` to understand their impact on performance.

$$
f(x) = x^2 + 2x + 1
$$

**Best Practice**: Apply `vmap` first and then `jit` (i.e. `jax.jit(jax.vmap(f))`) so that you compile the vectorized version of the function.
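A sketch of the pieces involved, using hypothetical functions `g` and `repeat_square` rather than the exercise's $f$. It compiles `jax.vmap(g)` with `jax.jit`, warms it up so compilation is excluded from the measurement, and waits on the result with `block_until_ready()`. The second half shows `static_argnums`: `n` drives a Python loop, so it must be a concrete value at trace time rather than a traced array.

```python
import time
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical scalar function used only for illustration
    return jnp.sin(x) * jnp.cos(x)

# Vectorize first, then compile the vectorized function
g_fast = jax.jit(jax.vmap(g))

xs = jnp.linspace(0.0, 1.0, 1_000_000)
g_fast(xs).block_until_ready()  # first call traces and compiles

start = time.perf_counter()
out = g_fast(xs).block_until_ready()  # wait for the async result before stopping the clock
elapsed = time.perf_counter() - start

def repeat_square(x, n):
    # Python control flow on n: jit needs n to be static (hashable, not traced)
    for _ in range(n):
        x = x * x
    return x

repeat_square_jit = jax.jit(repeat_square, static_argnums=1)
y = repeat_square_jit(jnp.float32(2.0), 3)  # ((2^2)^2)^2 = 256
```

In a notebook, `%timeit g_fast(xs).block_until_ready()` does the warm-up and repetition for you. Each new value of a static argument triggers a recompilation, so keep them to a small set of values.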

In [None]:
def f(x):
    ### TODO ###
    raise NotImplementedError

## Intermediate Level

### Exercise 5: Vectorization with `vmap`
- **Objective:** Learn how to efficiently vectorize operations using `jax.vmap`.
- **Tasks:**
  - Implement a function that computes the dot product of two vectors.
  - Use `jax.vmap` to apply this function across a batch of vectors.
  - Extend the function to compute the matrix-vector product for a batch of matrices and vectors.
  - Compare the performance of the vectorized version with a loop-based implementation.
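By default `jax.vmap` maps over axis 0 of every argument. `in_axes` lets you pick a different axis per argument, or `None` to pass an argument unchanged to every call. A sketch with a hypothetical `distance_to` function, where only the first argument is batched:

```python
import jax
import jax.numpy as jnp

def distance_to(v, center):
    # Euclidean distance between one vector v and a center, both of shape (d,)
    return jnp.sqrt(jnp.sum((v - center) ** 2))

vs = jnp.array([[3.0, 4.0], [0.0, 0.0], [6.0, 8.0]])  # batch of 3 vectors
center = jnp.zeros(2)                                  # shared by every item

# Map over axis 0 of vs, but broadcast the same center to every call
dists = jax.vmap(distance_to, in_axes=(0, None))(vs, center)  # [5, 0, 10]
```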

In [165]:
def f(x, y):
    ### TODO ###
    raise NotImplementedError

In [None]:
# x, y = ...  # batches of vectors with shape (batch_size, dim)
# f_vmapped = ...

In [None]:
%timeit [f(i, j).block_until_ready() for i, j in zip(x, y)]
%timeit f_vmapped(x, y).block_until_ready()

### Exercise 6: Working with Custom Gradients
- **Objective:** Implement custom gradients for non-standard operations.
- **Tasks:**
  - Implement the relu function.
  - Define a custom gradient for this function using `jax.custom_jvp` or `jax.custom_vjp`.
  - Plot the results and use Jax's built-in `grad` and `jax.nn.relu` to confirm the outputs match your implementation.
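The mechanics of `jax.custom_jvp` can be seen on a function other than relu. This sketch is adapted from the numerically stable `log1pexp` example in Jax's custom-derivatives documentation: the default derivative of $\log(1 + e^x)$ produces `nan` for large $x$, while the hand-written rule stays finite.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x), which stays finite when e^x overflows
    y_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return y, y_dot

slope_large = jax.grad(log1pexp)(100.0)  # 1.0 instead of nan
slope_zero = jax.grad(log1pexp)(0.0)     # 0.5
```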

In [251]:
def relu(x):
    ### TODO ###
    raise NotImplementedError

In [None]:
xs = jnp.linspace(-10, 10, 2_000)
ys_relu = relu(xs)
# ys_grad = ...

plt.plot(xs, ys_relu, label="relu")
plt.plot(xs, ys_grad, label="grad(relu)")
plt.grid(True)
plt.legend();

## Advanced Level

### Exercise 7: Neural Networks with Jax
- **Objective:** Build and train a simple neural network from scratch using Jax.
- **Tasks:**
  - Implement a basic feedforward neural network for a classification task.
  - Use `jax.numpy` for all operations, and manually implement forward and backward passes.
  - Implement training using stochastic gradient descent.
  - Experiment with different activation functions and regularization techniques.
  - Try jitting the train loop and see if training time goes down.
  - Think about why we have to pass in `params` to each function. Why not access it in the global scope?
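One common way to structure a training step, sketched on a hypothetical one-layer linear model (your network, loss, and learning rate will differ): `jax.value_and_grad` returns the loss together with gradients that have the same dictionary structure as `params`, and `jax.tree_util.tree_map` applies the update to every leaf.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-layer model; not the network the exercise asks for
params = {"W": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}

def loss_fn(params, x, y):
    preds = x @ params["W"] + params["b"]
    return jnp.mean((preds - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr):
    # One call returns both the loss and gradients shaped like params
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Apply the same update rule to every leaf of the params pytree
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[1.0], [2.0]])
initial_loss = loss_fn(params, x, y)
for _ in range(100):
    params, loss = sgd_step(params, x, y, 0.01)
```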

In [None]:
## Create data for binary classification task
# Parameters for the random data
n = 50
num_classes = 2

# Generate random data points
key = random.key(0)
xkey0, xkey1 = random.split(key, 2)

mean_class_0 = jnp.array([-2, -2])  # mean for class 0
mean_class_1 = jnp.array([2, 2])  # mean for class 1
std_dev = 1.0  # standard deviation for both classes

# Generate data for class 0
X_class_0 = random.normal(xkey0, (n, 2)) * std_dev + mean_class_0
y_class_0 = jnp.zeros(n)  # label 0 for class 0

# Generate data for class 1
X_class_1 = random.normal(xkey1, (n, 2)) * std_dev + mean_class_1
y_class_1 = jnp.ones(n)  # label 1 for class 1

# Stack both classes into a single dataset; y gets shape (2n, 1)
X = jnp.vstack((X_class_0, X_class_1))
y = jnp.expand_dims(jnp.hstack((y_class_0, y_class_1)), -1)
print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")

# Visualize the generated data
plt.figure(figsize=(8, 6))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral, edgecolors='k')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Randomly Generated Data for Binary Classification")
plt.show()

# Show the first few rows of the generated data
X[:5], y[:5]

In [14]:
# Build network weights
### TODO ###
# params = {"W1": ...}

In [15]:
# Train via gradient descent
def predict(params, x):
    ### TODO ###
    raise NotImplementedError

def loss_fn(params, x, y):
    ### TODO ###
    raise NotImplementedError

steps = 2000
lr = 0.01
losses = jnp.empty((steps,))  # Save loss values here; Jax arrays are immutable, so use losses.at[i].set(...)
predictions = []
for i in range(steps):
    preds = predict(params, X)
    ### TODO ###
    # Save predictions for visualization later (optional)
    ### TODO ###
    # loss_value, grad_value = ...
    ### TODO ###
    # update weights...

In [None]:
plt.plot(losses, label="train loss")
plt.grid()
plt.legend();

In [None]:
# See predictions made by trained model
preds = predict(params, X)
plt.figure(figsize=(8, 6))
plt.scatter(preds[:, 0], preds[:, 1], c=y, cmap=plt.cm.Spectral, edgecolors='k')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Final Predictions")
plt.show()

In [None]:
# Create Gif of predictions over time (optional)
frames = []

for i in tqdm(range(len(predictions))):
    # Create a plot for each frame
    plt.figure(figsize=(8, 6))
    plt.scatter(predictions[i][:, 0], predictions[i][:, 1], c=y, cmap=plt.cm.Spectral, edgecolors='k')
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.title(f"Predictions at step {i+1}/{len(predictions)}")
    
    # Save the plot to a BytesIO object
    buf = BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    
    # Add this frame to the list
    frames.append(imageio.v3.imread(buf))

# Save the frames as a GIF
imageio.mimsave('ex7_sgd.gif', frames, fps=2)  # You can adjust fps for speed

In [None]:
# If you want to display the GIF in a Jupyter notebook (optional):
from IPython.display import Image
Image(filename='ex7_sgd.gif')

### Exercise 8: Parallelism with `pmap`
- **Objective:** Leverage data parallelism to scale computations across multiple devices.
- **Tasks:**
  - Implement a function that performs a computation across a large array. It should be expensive enough to show a difference between a jitted and a non-jitted version.
  - Use `jax.pmap` to parallelize this operation across multiple devices (e.g., GPUs).
  - Measure the speedup obtained with `pmap` compared to single-device execution.
  - Experiment with different batch sizes and data partitioning strategies.
  - Try doing this in [Google Colab](https://colab.research.google.com) for free and easy access to multiple accelerator devices.
  
Note: `jax.pmap` compiles the function, so wrapping it in `jit()` is unnecessary.
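A minimal sketch of the data layout `pmap` expects: the leading axis of the input is split across devices, so its size must not exceed `jax.local_device_count()`. The workload `heavy` is a hypothetical placeholder. On a CPU-only machine there is usually a single device, so this runs but shows no speedup.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# One row per device: the leading axis is the one pmap maps over
x = jnp.arange(n_devices * 1024, dtype=jnp.float32).reshape(n_devices, 1024)

@jax.pmap
def heavy(chunk):
    # Hypothetical per-device workload; each device receives one row of x
    return jnp.sin(chunk) ** 2 + jnp.cos(chunk) ** 2

out = heavy(x)  # shape (n_devices, 1024), every entry close to 1.0
```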

### Exercise 9: JaxPR
- **Objective:** Dive into Jax's internal representation by manipulating JaxPR.
- **Tasks:**
  - Define a function and obtain its jaxpr using `jax.make_jaxpr`.
  - Analyze the jaxpr to understand its computation graph. Compare it with the one from the blog post.
  - Modify the function and observe how the jaxpr changes.
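A sketch of the workflow with hypothetical functions `g` and `g2`: `jax.make_jaxpr(fn)` returns a function that, when called with example arguments, traces `fn` and returns the recorded program. The exact printed format varies between Jax versions, so compare the primitive names rather than the full text.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical example function
    return jnp.sin(x) * 2.0 + x

def g2(x):
    # Modified version: the extra cos adds a new primitive to the jaxpr
    return jnp.sin(x) * jnp.cos(x) + x

jaxpr = jax.make_jaxpr(g)(1.0)
jaxpr2 = jax.make_jaxpr(g2)(1.0)
print(jaxpr)
print(jaxpr2)
```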