In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Replacing simple for-loops with `vmap`

The first JAX thing we will look at is the `vmap` function. What does `vmap` do? From the [JAX docs on `vmap`](https://jax.readthedocs.io/en/latest/jax.html#jax.vmap):

> Vectorizing map. Creates a function which maps fun over argument axes.

What does that mean? Well, let's take a look at a few classic examples.
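One way to read that definition: for a function `f` written for a single slice of an array, `vmap(f)` behaves like looping over the leading axis, calling `f` on each slice, and stacking the results. Here is a minimal sketch that checks this. The function `f` is hypothetical and chosen only for illustration.

```python
import jax.numpy as np
from jax import vmap

def f(x):
    # Hypothetical per-slice function: x is a single vector of shape (3,)
    return np.sum(x ** 2) + x[0]

xs = np.arange(12.0).reshape(4, 3)  # 4 slices along the leading axis

# Loop version: call f on each slice, then stack the results
looped = np.stack([f(x) for x in xs])

# vmap version: the same computation, with no explicit Python loop
vmapped = vmap(f)(xs)

print(looped.shape, vmapped.shape)
```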

## Mapping an elementwise function over an array's leading axis

The first example is mapping a function over an array axis. The simplest, if somewhat trivial, case is applying a function elementwise. Say we have 10,000 uniformly spaced numbers from 0 to 1 in an array:

In [None]:
import jax.numpy as np
from jax import vmap
from time import time

arr = np.linspace(0, 1, 10000)
arr

If we wanted to apply an exponential transform on every element, the "dumb", pure Python way to do so is to write a for-loop:

In [None]:
start = time()
new_arr = []
for element in arr:
    new_arr.append(np.exp(element))
new_arr = np.array(new_arr)
end = time()
print(f"{end - start:.2f} seconds")
new_arr

Because `np.exp` already operates on individual elements (just like its NumPy counterpart, which is a `ufunc`), we can call `np.exp` on `arr` directly:

In [None]:
start = time()
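# JAX dispatches work asynchronously, so wait for the result before stopping the timer
new_arr = np.exp(arr).block_until_ready()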
end = time()
print(f"{end - start:.4f} seconds")
new_arr

As you can see, this is much faster.

This, incidentally, is equivalent to using `vmap` to map the function across all elements in the array:

In [None]:
start = time()
# Wait for the asynchronous computation to finish before stopping the timer
new_arr = vmap(np.exp)(arr).block_until_ready()
end = time()
print(f"{end - start:.4f} seconds")
new_arr

It may be a bit slower here (likely because `vmap`, used without `jit`, traces `np.exp` again on every call), but one thing we gain from using `vmap` is the ability to write the function as if the leading (first) axis of each argument did not exist. To see that, we're going to look at another example.
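Before we do, one aside on speed. If the extra time from `vmap` matters, a common pattern is to wrap the vmapped function in `jax.jit`, so that it is traced and compiled once and then reused. This is a minimal sketch, assuming the same `arr` as above (recreated here so the cell runs on its own); the first call includes compilation, so only the second call is timed, and the exact timing will depend on your hardware.

```python
from time import time

import jax.numpy as np
from jax import jit, vmap

arr = np.linspace(0, 1, 10000)

# Wrap the vmapped function in jit so it is traced and compiled once
fast_exp = jit(vmap(np.exp))

# First call: includes tracing and compilation time
fast_exp(arr).block_until_ready()

# Second call: reuses the compiled function
start = time()
new_arr = fast_exp(arr).block_until_ready()
end = time()
print(f"{end - start:.4f} seconds")
```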

## Mapping a row-wise function across an array's leading axis

In this example, let's say we have a matrix of values that we measured in an experiment. We measured `n_samples` samples and collected `3` unique properties for each, giving us a matrix of shape `(n_samples, 3)`. If we needed the sum of the properties for each sample (i.e. the row-wise sum), we could do the following in pure NumPy:

In [None]:
def row_sum(data):
    """Given one dataset, calculate row-wise sum of data."""
    return np.sum(data, axis=1)

data = np.array([
    [1, 3, 1,],
    [3, 5, 1,],
    [1, 2, 5,],
    [7, 1, 3,],
    [11, 2, 3,],
])

start = time()
# Wait for the asynchronous computation to finish before stopping the timer
result = row_sum(data).block_until_ready()
end = time()
print(f"{end - start:.4f} seconds")
result

This gives us the correct answer... but we had to worry about the `axis` argument, which is a bit irritating. Instead, we could first transform `np.sum` into a vmapped function that is mapped across the leading axis of `data`, so that `np.sum` only ever sees one row at a time:

In [None]:
def row_sum_one_data(data):
    """Given one dataset, calculate row-wise sum of data."""
    return vmap(np.sum)(data)

start = time()
# Wait for the asynchronous computation to finish before stopping the timer
result = row_sum_one_data(data).block_until_ready()
end = time()
print(f"{end - start:.4f} seconds")
result

This gives us the exact same result. While the syntax takes some time to get used to, it expresses more explicitly and clearly the idea that _we don't want to sum over the leading axis; we only want to map over it_.
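The benefit grows when the per-row function is more involved than a plain sum. As a sketch, here is a hypothetical helper `normalize_row` (not part of the original example) that divides a row by its own sum. Written for one row, it needs no axis arguments at all; the axis-based equivalent has to get both `axis=1` and `keepdims=True` right for broadcasting to work.

```python
import jax.numpy as np
from jax import vmap

data = np.array([
    [1, 3, 1],
    [3, 5, 1],
    [1, 2, 5],
    [7, 1, 3],
    [11, 2, 3],
], dtype=np.float32)

def normalize_row(row):
    # Hypothetical helper written for ONE row of shape (3,): no axis arguments needed
    return row / np.sum(row)

# Map over the leading axis: each row is normalized independently
normalized = vmap(normalize_row)(data)

# Axis-based equivalent: we must remember axis=1 AND keepdims=True to broadcast correctly
normalized_axis = data / np.sum(data, axis=1, keepdims=True)
```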

Now, let's say we had multiple datasets for which we wanted to calculate the row-wise sum. How would we do this in pure NumPy?

Well, let's first create this dataset.

In [None]:
data2 = np.array([
    [1, 3, 7,],
    [3, 5, 11,],
    [3, 2, 5,],
    [7, 5, 3,],
    [11, 5, 3,],
])

combined_data = np.stack([data, data2])  # stack along a new leading axis
combined_data.shape

The shape tells us that we have 2 stacks of data, each with 5 rows and 3 columns.

Since we want row-wise summation but also want to preserve the 2 stacks of data, we now have to worry about which axis to collapse:

In [None]:
np.sum(combined_data, axis=2)

This works, but we now have a "magic number" (the `2` in `axis=2`) in our program. We can eliminate it by instead `vmap`-ing `row_sum_one_data` across the `combined_data` array:

In [None]:
def row_sum_all_data(data):
    return vmap(row_sum_one_data)(data)
    
row_sum_all_data(combined_data)

And voilà, just like that, the magic number is gone from our program, and the hierarchical structure of our functions is more explicit:

- The elementary function, `np.sum`, operates on a per-row basis.
- We map the elementary function across all rows of a single dataset, giving us a higher-order function that calculates row-wise summation for a single dataset, `row_sum_one_data`.
- We then map `row_sum_one_data` across all of the datasets that have been stacked together in a single 3D array, giving us `row_sum_all_data`.
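Because each level only wraps the one below it, adding another level of nesting is just one more `vmap`. The sketch below assumes a hypothetical 4D array of shape `(groups, datasets, rows, columns)` and a hypothetical `row_sum_all_groups` function; neither is part of the original example. It then checks the result against the axis-based version.

```python
import jax.numpy as np
from jax import random, vmap

def row_sum_one_data(data):
    # Level 1: sum each row of a single (rows, cols) dataset
    return vmap(np.sum)(data)

def row_sum_all_data(data):
    # Level 2: apply level 1 to every dataset in a stack
    return vmap(row_sum_one_data)(data)

def row_sum_all_groups(data):
    # Level 3 (hypothetical extension): apply level 2 to every group of stacks
    return vmap(row_sum_all_data)(data)

key = random.PRNGKey(0)
groups = random.normal(key, shape=(4, 2, 5, 3))  # 4 groups x 2 datasets x 5 rows x 3 cols

result = row_sum_all_groups(groups)
print(result.shape)  # (4, 2, 5)
```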

## Mapping a function over two arrays simultaneously

Let's look at another example. Say we are given two arrays and want to multiply them elementwise. For example:

In [None]:
a1 = np.array([1, 2, 3, 4,])
a2 = np.array([2, 3, 4, 5,])

As the NumPy-idiomatic option, we could do:

In [None]:
a1 * a2

Another option is to define a function called `multiply` that takes two scalars and returns their product, and then apply it to each pair of elements from a `zip` of the two arrays. This is the _extremely_ naive way of handling the problem:

In [None]:
result = []

def multiply(a, b):
    return a * b

for val1, val2 in zip(a1, a2):
    result.append(multiply(val1, val2))
np.array(result)

On the other hand, if we treat `multiply` as the elementary operation, we can `vmap` it over both arrays at once, which pairs up their elements along the leading axis:

In [None]:
vmap(multiply)(a1, a2)

As usual, we never had to think about the leading axis of either array inside `multiply`. Once again, we broke the problem down into its elementary components and then used `vmap` to build _out_ the program to do what we wanted it to do. (This general pattern will keep showing up!)

In general, `vmap`-ing over the _leading_ array axis is the idiomatic thing to do with JAX. It's possible to `vmap` over other axes using the `in_axes` argument, but the default is axis 0. The implication is that we are nudged towards writing programs that, at their core, begin with an "elementary" function that operates "elementwise", where an "element" is not necessarily a single array entry but is defined by the problem (a row, a matrix, a pair of scalars, and so on). We then progressively `vmap` these functions outwards over array data structures.
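For completeness, here is a short sketch of the non-default cases. `jax.vmap` takes an `in_axes` argument (default `0`) that says which axis of each argument to map over; `None` means that argument is passed unchanged to every call. The data and `multiply` below mirror the earlier examples.

```python
import jax.numpy as np
from jax import vmap

data = np.array([
    [1, 3, 1],
    [3, 5, 1],
    [1, 2, 5],
    [7, 1, 3],
    [11, 2, 3],
])

# in_axes=1: map over columns instead of rows, so each call to np.sum sees one column
col_sums = vmap(np.sum, in_axes=1)(data)

def multiply(a, b):
    return a * b

a1 = np.array([1, 2, 3, 4])

# in_axes=(0, None): map over a1's leading axis, but pass 10 unchanged to every call
scaled = vmap(multiply, in_axes=(0, None))(a1, 10)

print(col_sums)  # [23 13 13]
print(scaled)    # [10 20 30 40]
```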

### Example 1: `vmap`-ing a dot product over square matrices

Let's try getting some practice with the following exercises.

The first one is to `vmap` a dot product of a square matrix against itself across a stack of square matrices.

An example square matrix called `sq_matrix` is provided for you to jog your memory on how dot products work if you need to.

In [None]:
from jax import random

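key = random.PRNGKey(42)
# Split the key so that `data` and `sq_matrix` get independent random numbers
key, data_key, matrix_key = random.split(key, 3)
data = random.normal(data_key, shape=(11, 5, 5))
sq_matrix = random.normal(matrix_key, shape=(5, 5))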

vmap(np.dot)(data, data).shape
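To check the answer, we can compare it against NumPy-style batched operations that treat the leading axis as a batch axis, here `np.matmul` and `np.einsum`. The sketch also shows why `vmap` is needed in the first place: calling `np.dot` directly on the 3D arrays is not a batched product. For N-D inputs it contracts the last axis of the first argument with the second-to-last axis of the second, producing shape `(11, 5, 11, 5)`.

```python
import jax.numpy as np
from jax import random, vmap

key = random.PRNGKey(42)
data = random.normal(key, shape=(11, 5, 5))

# vmap version: np.dot only ever sees a single pair of (5, 5) matrices
batched = vmap(np.dot)(data, data)

# Equivalent batched forms that treat the leading axis as a batch axis
via_matmul = np.matmul(data, data)
via_einsum = np.einsum("bij,bjk->bik", data, data)

# Calling np.dot directly on the 3D arrays is NOT a batched product
not_batched = np.dot(data, data)
print(batched.shape, not_batched.shape)  # (11, 5, 5) (11, 5, 11, 5)
```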

### Example 2: Constructing a more complex program

We're going to try our hand at constructing a program that first calculates the cumulative product vector of each row in a dataset, then sums each of those vectors (i.e. across the columns) to get one number per row, and finally applies this same operation across all datasets stacked together. This one is a bit more challenging!

To help you along here, the shape of the data are such:

- There are 11 stacks of data.
- Each stack of data has 31 rows, and 7 columns.

The result of this program should still have 11 stacks and 31 rows, but instead of 7 columns, each row is reduced to a single value: the sum of that row's cumulative product vector. The final shape is therefore `(11, 31)`.

To get this answer right, no magic numbers are allowed (e.g. for accessing particular axes). At least two `vmap`s are necessary here.

In [None]:
data = random.normal(key, shape=(11, 31, 7))

def row_wise_cumprod(row):
    return np.cumprod(row)

def dataset_wise_sum_cumprod(data):
    row_cumprods = vmap(row_wise_cumprod)(data)
    return vmap(np.sum)(row_cumprods)

vmap(dataset_wise_sum_cumprod)(data).shape
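To sanity-check the solution, we can compare it against a version that uses explicit axis arguments. That version reaches the same result but relies on exactly the kind of magic number (`axis=-1`) the exercise asks us to avoid. This is a self-contained sketch that repeats the functions above.

```python
import jax.numpy as np
from jax import random, vmap

key = random.PRNGKey(42)
data = random.normal(key, shape=(11, 31, 7))

def row_wise_cumprod(row):
    # Cumulative product of a single row of shape (7,)
    return np.cumprod(row)

def dataset_wise_sum_cumprod(data):
    # For one (31, 7) dataset: cumprod each row, then sum each cumprod vector
    row_cumprods = vmap(row_wise_cumprod)(data)
    return vmap(np.sum)(row_cumprods)

result = vmap(dataset_wise_sum_cumprod)(data)

# Axis-based equivalent, using the magic number the exercise asks us to avoid
expected = np.sum(np.cumprod(data, axis=-1), axis=-1)
print(result.shape)  # (11, 31)
```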