# Advanced Vectorization in `Jax`

In the previous guide on `Jax`, we introduced the basic building blocks (multi-dimensional arrays) as well as a new paradigm for writing code without loops or if/else expressions. Here, we build on this paradigm with more advanced (yet common) uses of `Jax` in ML.

Let's get started by importing some libraries.

```python
import jax
import jax.numpy as jnp
import chex
```

**Acknowledgement.** Parts of this tutorial have been adapted from [this NumPy tutorial](https://numpy.org/doc/stable/user/quickstart.html).

## Indexing with Boolean Arrays

Whereas *slicing* is typically used to extract contiguous chunks of an array (i.e., elements that sit next to each other in the original array), *indexing* with arrays lets us extract non-contiguous parts. How? This is easier to show than to explain.

You may remember from the introductory guide to `Jax` that applying a comparison operation (such as `==` or `<`) to an array returns a boolean array of the same shape. Here's an example:

```python
a = jnp.arange(10)
print('a       =', a)
print('is even =', a % 2 == 0)
```

```
a       = [0 1 2 3 4 5 6 7 8 9]
is even = [ True False  True False  True False  True False  True False]
```
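
Boolean arrays can also be combined element-wise. The sketch below (the mask names `is_even` and `is_big` are just illustrative) shows the three operators involved; the `~` operator in particular reappears later in this guide.

```python
import jax.numpy as jnp

a = jnp.arange(10)
is_even = a % 2 == 0     # True at 0, 2, 4, 6, 8
is_big = a > 4           # True at 5, 6, 7, 8, 9

# Combine masks element-wise with & (and), | (or), and ~ (not).
# Python's `and`/`or`/`not` would try to turn a whole array into a single
# True/False and raise an error. Since & and | bind more tightly than
# comparisons, keep the parentheses: (a % 2 == 0) & (a > 4).
both = is_even & is_big      # True at 6, 8
either = is_even | is_big    # True at 0, 2, 4, 5, 6, 7, 8, 9
odd = ~is_even               # True at 1, 3, 5, 7, 9

print(both)
print(either)
print(odd)
```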


We can then use such a boolean array as an index to select only the elements at which it is `True`:

```python
b = a[a % 2 == 0]
print('Even elements from a =', b)
```

```
Even elements from a = [0 2 4 6 8]
```
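
One `Jax`-specific caveat is worth knowing here: the shape of `a[a % 2 == 0]` depends on the *values* in `a`, so boolean-mask indexing works on concrete arrays but not inside a `jax.jit`-compiled function, where output shapes must be known before the values are. The sketch below (function names are ours, for illustration) shows the failure and two jit-friendly alternatives: keeping the original shape with `jnp.where`, or fixing the output size up front with the `size` argument of `jnp.nonzero`.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(10)

@jax.jit
def evens_bad(x):
    # The output shape would depend on x's values, which jit cannot know.
    return x[x % 2 == 0]

try:
    evens_bad(a)
except jax.errors.NonConcreteBooleanIndexError as err:
    print('jit failed:', type(err).__name__)

@jax.jit
def evens_masked(x):
    # Keep x's shape and replace the odd entries with 0 instead of dropping them.
    return jnp.where(x % 2 == 0, x, 0)

@jax.jit
def evens_padded(x):
    # Request exactly x.shape[0] // 2 indices (a static Python int); if fewer
    # elements matched, the remaining slots would be filled with index 0.
    (idx,) = jnp.nonzero(x % 2 == 0, size=x.shape[0] // 2, fill_value=0)
    return x[idx]

print(evens_masked(a))  # [0 0 2 0 4 0 6 0 8 0]
print(evens_padded(a))  # [0 2 4 6 8]
```

The padded version only works because we assumed we know an upper bound on the number of matches; `jnp.where` needs no such assumption but keeps the unwanted slots.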


Boolean arrays also make it easy to write functions with branching, piecewise behavior without any if/else. For example, the function,
\begin{align*}
f(x) &= \begin{cases}
x^2 & \text{if $x$ is even} \\
x^3 & \text{if $x$ is odd} \\
\end{cases}
\end{align*}
can be implemented as follows:

```python
def f(x):
    is_even = (x % 2 == 0)
    return is_even * (x ** 2.0) + ~is_even * (x ** 3.0)

print(f(jnp.arange(10)))
```

```
[  0.   1.   4.  27.  16. 125.  36. 343.  64. 729.]
```


In the above, we first create a boolean array `is_even` that is `True` exactly at the positions where $x$ is even. When we multiply `is_even` by a numeric array, the booleans are implicitly promoted to that array's type: `True` becomes 1 and `False` becomes 0 (here 1.0 and 0.0, since `x ** 2.0` is a float array). As a result, `is_even * (x ** 2.0)` is an array whose elements equal $x^2$ where $x$ is even and 0 elsewhere. Conversely, `~is_even * (x ** 3.0)` is 0 where $x$ is even and $x^3$ elsewhere (`~` flips each boolean value, and it binds more tightly than `*`). Adding the two arrays gives the result we were looking for.

While this example is instructive, the implicit casting makes the logic harder to follow. A cleaner way to implement the same function is with `jnp.where`:

```python
def f(x):
    return jnp.where(x % 2 == 0, x ** 2.0, x ** 3.0)

print(f(jnp.arange(10)))
```

```
[  0.   1.   4.  27.  16. 125.  36. 343.  64. 729.]
```


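Here, `jnp.where(condition, a, b)` builds a new array element by element, taking the value from `a` wherever `condition` is `True` and from `b` wherever it is `False`. In our case, the condition `x % 2 == 0` (is even) selects $x^2$ for even elements and $x^3$ for odd ones.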
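
`jnp.where` chooses between two options. For piecewise functions with more branches, `jnp.select` takes a list of conditions and a matching list of choices and, for each element, uses the first condition that holds. The three-way function `g` below is a made-up example, not something from the exercises.

```python
import jax.numpy as jnp

def g(x):
    # Hypothetical piecewise function: x**2 if x % 3 == 0,
    # x**3 if x % 3 == 1, and -x if x % 3 == 2.
    conditions = [x % 3 == 0, x % 3 == 1, x % 3 == 2]
    choices = [x ** 2, x ** 3, -x]
    # `default` fills positions where no condition holds (none, in this case).
    return jnp.select(conditions, choices, default=0)

print(g(jnp.arange(7)))  # elementwise: 0, 1, -2, 9, 64, -5, 36
```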

````{admonition} Exercise
As before, please solve the following using `Jax` library calls only (no loops, no if/else!):

**Part 1:** Write a function that takes an integer $N$, a coordinate $(x, y)$, and a radius $r$. The function should return an integer array of shape $(N, N)$ in which element $(i, j)$ is 1 if it lies within radius $r$ of $(x, y)$, i.e. if $(x - i)^2 + (y - j)^2 \leq r^2$, and 0 otherwise. For example, for $N = 10$, $x = 5$, $y = 5$, and $r = 2$, the function should return:
```
Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
       [0, 0, 0, 1, 1, 1, 1, 1, 0, 0],
       [0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
```
Similarly, for $N = 10$, $x = 5$, $y = 0$, and $r = 2$, the function should return:
```
Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)
```

**Part 2:** Write a function that, given an integer $N \geq 2$, returns an integer array of shape $(N, N)$ with a checkerboard pattern of 0s and 1s. For example, for $N = 5$, the function should return:
```
Array([[1, 0, 1, 0, 1],
       [0, 1, 0, 1, 0],
       [1, 0, 1, 0, 1],
       [0, 1, 0, 1, 0],
       [1, 0, 1, 0, 1]], dtype=int32)
```
````

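```python
def indexing_q1(N, x, y, r):
    # rows[i, j] == i and cols[i, j] == j for every position in the N x N grid.
    rows = jnp.tile(jnp.arange(N)[..., None], N)
    cols = rows.T
    # 1 where (i, j) lies within distance r of (x, y), 0 elsewhere.
    return ((rows - x) ** 2 + (cols - y) ** 2 <= r ** 2).astype(jnp.int32)

def indexing_q2(N):
    # Same row/column index grids as above.
    rows = jnp.tile(jnp.arange(N)[..., None], N)
    cols = rows.T
    # A checkerboard is 1 where i + j is even (so the top-left corner is 1)
    # and 0 where it is odd; this works for both even and odd N.
    return ((rows + cols) % 2 == 0).astype(jnp.int32)
```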

## Indexing with Integer Arrays

Whereas boolean indexing selects elements from an array based on a boolean condition, we can also select elements from an array by using arrays of integer indices.
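
For instance, indexing with an integer array picks out the elements at the listed positions. The sketch below (array names are illustrative) also shows why this form of indexing, unlike boolean masks, works inside `jax.jit`: the output shape is determined by the index array's shape, not by any values.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(10) ** 2            # squares 0, 1, 4, ..., 81
idx = jnp.array([1, 3, 3, 8])      # positions may repeat and come in any order
print(a[idx])                      # values 1, 9, 9, 64

# The result takes the shape of the index array, not of `a`.
print(a[jnp.array([[0, 1], [2, 3]])].shape)   # (2, 2)

# For a 2-D array, pass one integer array per axis to pick (row, col) pairs.
m = jnp.arange(12).reshape(3, 4)
picked = m[jnp.array([0, 2]), jnp.array([1, 3])]
print(picked)                      # m[0, 1] and m[2, 3]: values 1 and 11

# Integer-array indexing compiles fine under jit.
take = jax.jit(lambda x, i: x[i])
print(take(a, idx))                # same values as a[idx]
```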

## Broadcasting

## Fast Mapping with `vmap`

## Catching Bugs Early with `chex`