In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Replacing simple for-loops with `vmap`

The first JAX thing we will look at is the `vmap` function.
What does `vmap` do?
From the [JAX docs on `vmap`][docs]:

[docs]: https://jax.readthedocs.io/en/latest/jax.html#jax.vmap

> Vectorizing map. Creates a function which maps fun over argument axes.

Basically the idea here is to take a function
and apply it to every "element" along a particular array axis.
Let's take a look at a few examples to see how to use it.
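
To make the connection to for-loops concrete, here is a minimal sketch that compares an explicit Python loop with its `vmap`ped equivalent. The function `square_plus_one` and the array `xs` are made up purely for illustration.

```python
import jax.numpy as np  # same alias as the rest of this notebook
from jax import vmap


def square_plus_one(x):
    """Hypothetical elementary function that operates on a single scalar."""
    return x ** 2 + 1


xs = np.arange(5.0)

# Explicit for-loop: apply the function to each element, then stack the results.
looped = np.stack([square_plus_one(x) for x in xs])

# vmap: build a function that maps over the leading axis, then call it once.
vmapped = vmap(square_plus_one)(xs)

assert np.allclose(looped, vmapped)
```

Both versions give the same numbers; the `vmap` version simply leaves the looping to JAX.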

## Mapping a function over a 1D array

The trivial example we will use here is to map a function over a 1D array.
Firstly, let's get an array of numbers, say, between 0 and 1:

In [2]:
import jax.numpy as np
from jax import vmap
from time import time

arr = np.linspace(0, 1, 10000)
arr



DeviceArray([0.0000000e+00, 1.0001000e-04, 2.0002000e-04, ...,
             9.9979997e-01, 9.9989998e-01, 1.0000000e+00], dtype=float32)

The shape of `arr` is `(10000,)`, which means it is one-dimensional.
`vmap`ping a function across the first dimension here
means that we will map it across every element in `arr`.
Let's apply the `sin` function, for example.

In [17]:
vmap(np.sin)(arr)

DeviceArray([0.0000000e+00, 1.0001000e-04, 2.0002000e-04, ...,
             8.4136289e-01, 8.4141695e-01, 8.4147096e-01], dtype=float32)

Now, at this point you may be thinking
that this is equivalent to applying a NumPy universal function (`ufunc`)
elementwise across an array.
Your intuition is correct!
The example above is the trivial one.
Let's see a more complicated situation in which we might want to use `vmap`.
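
As a quick sanity check of that intuition, the sketch below compares the `vmap`ped call against the plain elementwise call on the same array. It only uses names that already appear in this notebook.

```python
import jax.numpy as np  # same alias as the rest of this notebook
from jax import vmap

arr = np.linspace(0, 1, 10000)

# Plain elementwise application, ufunc-style.
elementwise = np.sin(arr)

# The same function mapped over the leading (and only) axis.
mapped = vmap(np.sin)(arr)

assert mapped.shape == elementwise.shape
assert np.allclose(mapped, elementwise)
```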

## Mapping a function across a 2D matrix

In this example let's say we have a matrix of values
that we measured in an experiment.
There were `n_samples` measured, and `3` unique properties that we collected,
thereby giving us a matrix of shape `(n_samples, 3)`.
Just for the sake of pedagogy,
let's say that we wanted to take the row-wise sum.
Although it's possible to use `numpy.sum()` and specify an axis to collapse,
let's see how we can use JAX to accomplish the same.

In [19]:
data = np.array([
    [1, 3, 1,],
    [3, 5, 1,],
    [1, 2, 5,],
    [7, 1, 3,],
    [11, 2, 3,],
])

vmap(np.sum, in_axes=0)(data)

DeviceArray([ 5,  9,  8, 11, 16], dtype=int32)

Here, `in_axes=0` tells `vmap` which axis to map over:
`np.sum` is applied to each slice along axis 0, that is, to each row, giving a row-wise sum.
It's also possible to do a column-wise sum by mapping over axis 1:

In [20]:
vmap(np.sum, in_axes=1)(data)

DeviceArray([23, 13, 13], dtype=int32)
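
To make the meaning of `in_axes` concrete, the sketch below spells out both calls as explicit loops over slices. Note that the axis given to `in_axes` is the one that survives in the output, whereas the `axis` argument of `np.sum` names the one that is collapsed.

```python
import jax.numpy as np  # same alias as the rest of this notebook
from jax import vmap

data = np.array([
    [1, 3, 1],
    [3, 5, 1],
    [1, 2, 5],
    [7, 1, 3],
    [11, 2, 3],
])

# in_axes=0: np.sum is applied to each slice data[i, :], i.e. each row.
row_sums = vmap(np.sum, in_axes=0)(data)
row_loop = np.stack([np.sum(data[i, :]) for i in range(data.shape[0])])
assert np.array_equal(row_sums, row_loop)

# in_axes=1: np.sum is applied to each slice data[:, j], i.e. each column.
col_sums = vmap(np.sum, in_axes=1)(data)
col_loop = np.stack([np.sum(data[:, j]) for j in range(data.shape[1])])
assert np.array_equal(col_sums, col_loop)

# The same results expressed with np.sum's own `axis` argument.
assert np.array_equal(row_sums, np.sum(data, axis=1))
assert np.array_equal(col_sums, np.sum(data, axis=0))
```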

By default, JAX's `vmap` maps over axis 0, which is kind of handy:
by whatever accidental convention from history,
we usually set the 0th axis as the axis along which
we order our samples/observations that constitute our dataset.
Expressed in JAX code:

In [21]:

vmap(np.sum)(data)

DeviceArray([ 5,  9,  8, 11, 16], dtype=int32)

## Mapping a function over a 3D array

Dealing with high dimensional arrays
is where `vmap` really begins to shine for us.
Let's say we had a second dataset of the same shape,
such that now we had a 3D array:

In [22]:
data2 = np.array([
    [1, 3, 7,],
    [3, 5, 11,],
    [3, 2, 5,],
    [7, 5, 3,],
    [11, 5, 3,],
])

# Stack the two datasets along a new leading axis.
combined_data = np.stack([data, data2])
combined_data.shape

(2, 5, 3)

Our shapes tell us that we have 2 stacks of data, each with 5 rows and 3 columns.

Let's now say we wanted to do a row-wise sum.
With vanilla NumPy, we'd have to specify with a magic number
the exact axis we wanted to collapse in order to accomplish the sum:

In [31]:
np.sum(combined_data, axis=2)

DeviceArray([[ 5,  9,  8, 11, 16],
             [11, 19, 10, 15, 19]], dtype=int32)

With JAX + some informatively-named functions,
we can avoid the magic number trap instead.

Firstly, we will set up a function
that sums over rows in one dataset:

In [32]:
def row_sum(data: np.ndarray):
    """Perform a row-wise sum over one dataset.
    
    We assume that data is an array of 2 dimensions,
    such that the first dim is the sample dimension,
    and the second dim is <insert something informative here>.
    """
    return vmap(np.sum)(data)

Now, we can map `row_sum` over all datasets in the 3D cube
to yield the summary statistics that we desire:
row-wise summations (5 rows each) over 2 datasets
to give us a final dataset of 2 rows and 5 columns.

In [33]:
vmap(row_sum)(combined_data)

DeviceArray([[ 5,  9,  8, 11, 16],
             [11, 19, 10, 15, 19]], dtype=int32)

And voilà, just like that, magic numbers were removed from our program,
and the hierarchical structure of our functions is a bit more explicit:

- The elementary function, `np.sum`, operates on a per-row basis.
- We map the elementary function across all rows of a single dataset, 
giving us a higher-order function that calculates 
row-wise summation for a single dataset, `row_sum`.
- We then map the `row_sum` function across all of the datasets 
that have been stacked together in a single 3D array.
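
The three levels listed above can be checked against the magic-number version in a few lines. This is a self-contained sketch that rebuilds the two datasets from this section. The 3D array is assembled with `np.stack`, which is assumed here to be an acceptable way to get the `(2, 5, 3)` shape.

```python
import jax.numpy as np  # same alias as the rest of this notebook
from jax import vmap

data = np.array([[1, 3, 1], [3, 5, 1], [1, 2, 5], [7, 1, 3], [11, 2, 3]])
data2 = np.array([[1, 3, 7], [3, 5, 11], [3, 2, 5], [7, 5, 3], [11, 5, 3]])
combined_data = np.stack([data, data2])  # shape (2, 5, 3)


def row_sum(dataset):
    """Sum each row of a single 2D dataset."""
    return vmap(np.sum)(dataset)


# Map row_sum over the leading "dataset" axis.
nested = vmap(row_sum)(combined_data)

# The equivalent call that names the collapsed axis explicitly.
reference = np.sum(combined_data, axis=2)

assert nested.shape == (2, 5)
assert np.array_equal(nested, reference)
```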

## Mapping functions with two arguments

Let's see how we can handle functions that accept two arguments using `vmap`.
Say we have a function that takes in two vectors `a` and `b`,
and multiplies `a` by the transpose of `b`
to get a 2D array:

In [38]:
a = np.array([1, 2, 3, 4,])
b = np.array([2, 3, 4, 5, 6])

def grid_multiply(a, b):
    """Grid multiply vectors a and b.

    Both `a` and `b` are assumed to be vectors with only one dimension.
    Reshaping happens inside the function."""
    return a * np.reshape(b, (-1, 1))


grid_multiply(a, b)

DeviceArray([[ 2,  4,  6,  8],
             [ 3,  6,  9, 12],
             [ 4,  8, 12, 16],
             [ 5, 10, 15, 20],
             [ 6, 12, 18, 24]], dtype=int32)

What if for each of our inputs `a` and `b` we had multiple vectors
rather than a single vector each?
`vmap` to the rescue again!

If your function has two array arguments,
then by default `vmap` will map the function
over the leading (0th) axis on both.

In [40]:
# Let's simulate a matrix of inputs by vstacking `a` and `b`.
a_mat = np.vstack([a, a*2, a*3])
b_mat = np.vstack([b, b*0.1, b*0.2])

vmap(grid_multiply)(a_mat, b_mat)

DeviceArray([[[ 2.       ,  4.       ,  6.       ,  8.       ],
              [ 3.       ,  6.       ,  9.       , 12.       ],
              [ 4.       ,  8.       , 12.       , 16.       ],
              [ 5.       , 10.       , 15.       , 20.       ],
              [ 6.       , 12.       , 18.       , 24.       ]],

             [[ 0.4      ,  0.8      ,  1.2      ,  1.6      ],
              [ 0.6      ,  1.2      ,  1.8000001,  2.4      ],
              [ 0.8      ,  1.6      ,  2.4      ,  3.2      ],
              [ 1.       ,  2.       ,  3.       ,  4.       ],
              [ 1.2      ,  2.4      ,  3.6000001,  4.8      ]],

             [[ 1.2      ,  2.4      ,  3.6000001,  4.8      ],
              [ 1.8000001,  3.6000001,  5.4      ,  7.2000003],
              [ 2.4      ,  4.8      ,  7.2000003,  9.6      ],
              [ 3.       ,  6.       ,  9.       , 12.       ],
              [ 3.6000001,  7.2000003, 10.8      , 14.400001 ]]], dtype=float32)

Wonderful!
We have an array of the correct shape here!
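
To see what "mapping over the leading axis on both" means, the sketch below writes out the equivalent for-loop: the i-th row of `a_mat` is paired with the i-th row of `b_mat`, and the per-pair results are stacked. It repeats the definitions from this section so that it can run on its own.

```python
import jax.numpy as np  # same alias as the rest of this notebook
from jax import vmap


def grid_multiply(a, b):
    """Grid multiply two 1D vectors a and b."""
    return a * np.reshape(b, (-1, 1))


a = np.array([1, 2, 3, 4])
b = np.array([2, 3, 4, 5, 6])
a_mat = np.vstack([a, a * 2, a * 3])      # shape (3, 4)
b_mat = np.vstack([b, b * 0.1, b * 0.2])  # shape (3, 5)

# Explicit loop: pair up rows of a_mat and b_mat, one pair at a time.
looped = np.stack([grid_multiply(a_i, b_i) for a_i, b_i in zip(a_mat, b_mat)])

# vmap does the same pairing along axis 0 of both arguments.
vmapped = vmap(grid_multiply)(a_mat, b_mat)

assert vmapped.shape == (3, 5, 4)
assert np.allclose(looped, vmapped)
```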

What if we wanted to fix one of the arguments in place,
such as fixing `a` to be one array instead, while mapping over `b`?
One way to do this is to partially evaluate the function using `functools.partial`.

In [42]:
from functools import partial 

# We do not specify kwarg `a=a` because `a` is positionally 1st in line
grid_mul_a = partial(grid_multiply, a)

vmap(grid_mul_a)(b_mat)

DeviceArray([[[ 2.        ,  4.        ,  6.        ,  8.        ],
              [ 3.        ,  6.        ,  9.        , 12.        ],
              [ 4.        ,  8.        , 12.        , 16.        ],
              [ 5.        , 10.        , 15.        , 20.        ],
              [ 6.        , 12.        , 18.        , 24.        ]],

             [[ 0.2       ,  0.4       ,  0.6       ,  0.8       ],
              [ 0.3       ,  0.6       ,  0.90000004,  1.2       ],
              [ 0.4       ,  0.8       ,  1.2       ,  1.6       ],
              [ 0.5       ,  1.        ,  1.5       ,  2.        ],
              [ 0.6       ,  1.2       ,  1.8000001 ,  2.4       ]],

             [[ 0.4       ,  0.8       ,  1.2       ,  1.6       ],
              [ 0.6       ,  1.2       ,  1.8000001 ,  2.4       ],
              [ 0.8       ,  1.6       ,  2.4       ,  3.2       ],
              [ 1.        ,  2.        ,  3.        ,  4.        ],
              [ 1.2       ,  2.4       ,  3.6000001 ,  4.8       ]]], dtype=float32)

The same can be done in fixing `b`:

In [43]:
# We have to specify kwarg `b=b` because `b` is positionally 2nd in line
grid_mul_b = partial(grid_multiply, b=b)

vmap(grid_mul_b)(a_mat)

DeviceArray([[[ 2,  4,  6,  8],
              [ 3,  6,  9, 12],
              [ 4,  8, 12, 16],
              [ 5, 10, 15, 20],
              [ 6, 12, 18, 24]],

             [[ 4,  8, 12, 16],
              [ 6, 12, 18, 24],
              [ 8, 16, 24, 32],
              [10, 20, 30, 40],
              [12, 24, 36, 48]],

             [[ 6, 12, 18, 24],
              [ 9, 18, 27, 36],
              [12, 24, 36, 48],
              [15, 30, 45, 60],
              [18, 36, 54, 72]]], dtype=int32)
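
An alternative to `functools.partial` is `vmap`'s own `in_axes` argument, which accepts one entry per positional argument; `None` marks an argument that should be passed through unmapped. The sketch below repeats this section's definitions and checks that both approaches agree. The names `fixed_a` and `fixed_b` are made up for illustration.

```python
from functools import partial

import jax.numpy as np  # same alias as the rest of this notebook
from jax import vmap


def grid_multiply(a, b):
    """Grid multiply two 1D vectors a and b."""
    return a * np.reshape(b, (-1, 1))


a = np.array([1, 2, 3, 4])
b = np.array([2, 3, 4, 5, 6])
a_mat = np.vstack([a, a * 2, a * 3])
b_mat = np.vstack([b, b * 0.1, b * 0.2])

# Hold `a` fixed (None) and map over the rows of `b_mat` (axis 0).
fixed_a = vmap(grid_multiply, in_axes=(None, 0))(a, b_mat)
assert np.allclose(fixed_a, vmap(partial(grid_multiply, a))(b_mat))

# Map over the rows of `a_mat` (axis 0) and hold `b` fixed (None).
fixed_b = vmap(grid_multiply, in_axes=(0, None))(a_mat, b)
assert np.allclose(fixed_b, vmap(partial(grid_multiply, b=b))(a_mat))
```

Which of the two reads better is largely a matter of taste: `partial` keeps the mapped function's signature small, while `in_axes` keeps all arguments visible at the call site.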

As before, we do not need to handle the leading array axis of each array ourselves.
Once again, we also broke down the problem into its elementary components,
and then leveraged `vmap` to build _out_ the program
to do what we wanted it to do.
(This general pattern will show up again!)

In general, `vmap`-ing over the _leading_ array axis 
is the idiomatic thing to do with JAX. 
It's possible to `vmap` over other axes, but those are not the defaults.
The implication is that we are nudged towards writing programs that
at their core begin with an elementary function that operates
on every entry along one axis in an array.
We then progressively `vmap` them outwards on array data structures.

## Exercises

Let's go to some exercises to flex your newly-found `vmap` muscles!
Everything you need to know you have picked up above;
all that's left is getting practice creatively combining them together.
Use the puzzles below, which are ordered in increasing complexity,
to challenge your skillsets!

### Exercise 1: `vmap` a dot product over matrices

Dot products are ubiquitous in the deep learning world,
and most models are nothing more than fancy chains of dot products.
In this exercise, your task is to use `vmap` to express a dot product
between two stacks of matrices.
Your implementation should match the vanilla Python+NumPy equivalent provided:

In [53]:
from jax import random

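key = random.PRNGKey(42)
# Split the key so that the two matrices are drawn from independent random streams.
key1, key2 = random.split(key)
mat1 = random.normal(key1, shape=(11, 5, 3))
mat2 = random.normal(key2, shape=(11, 3, 7))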


def numpy_equivalent(mat1, mat2):
    result = []
    for m1, m2 in zip(mat1, mat2):
        result.append(np.dot(m1, m2))
    return np.stack(result)


result = numpy_equivalent(mat1, mat2)
result.shape 


## Your solution below!

(11, 5, 7)

### Exercise 2: Chained `vmap`s

We're going to try our hand at constructing a slightly more complex program.
This program takes in one dataset of three dimensions,
`(n_datasets, n_rows, n_columns)`.
The program first calculates
the cumulative product along each row in a dataset,
then sums up each row's cumulative product to give one number per row,
and finally applies this same operation across all datasets stacked together.
This one is a bit more challenging!

To help you along here, the shape of the data is as follows:

- There are 11 stacks of data.
- Each stack of data has 31 rows, and 7 columns.

The result of this program should still have 11 stacks and 31 rows,
but each row is now summarized by a single number:
the sum of the cumulative product across that row's columns.

To get this answer right,
no magic numbers are allowed (e.g. for accessing particular axes).
At least two `vmap`s are necessary here.

In [59]:
data = random.normal(key, shape=(11, 31, 7))


def ex2_numpy_equivalent(data):
    result = []
    for d in data:
        cp = np.cumprod(d, axis=-1)
        s = np.sum(cp, axis=1)
        result.append(s)
    return np.stack(result)

ex2_numpy_equivalent(data).shape


# Your answer below:

# def row_wise_cumprod(row):
#     return np.cumprod(row)

# def dataset_wise_sum_cumprod(data):
#     row_cumprods = vmap(row_wise_cumprod)(data)
#     return vmap(np.sum)(row_cumprods)

# vmap(dataset_wise_sum_cumprod)(data).shape

(11, 31)

### Exercise 3: Double for-loops

This one is a favourite of mine,
and took me an afternoon of on-and-off thinking to reason about clearly.
Graphs, a.k.a. networks, consist of nodes and edges,
and each node can be represented by a vector of information (i.e. node features).
Stacking all the nodes' vectors together
gives us a _node feature matrix_.
In graph attention networks, one step requires us to concatenate, pairwise,
every node's feature vector with every node's feature vector (including its own).
For example, if every node had a length `n_features` feature vector,
then concatenating two nodes' vectors together
should give us a length `2 * n_features` vector.
Doing this pairwise across all nodes in a graph
would give us an `(n_nodes, n_nodes, 2 * n_features)` array.

Your challenge below is to write the vmapped version of the following:

In [64]:
num_nodes = 13
num_feats = 17
node_feats = random.normal(key, shape=(num_nodes, num_feats))

def ex3_numpy_equivalent(node_feats):
    result = []
    for node1 in node_feats:
        node1_concats = []
        for node2 in node_feats:
            cc = np.concatenate([node1, node2])
            node1_concats.append(cc)
        result.append(np.stack(node1_concats))

    return np.stack(result)

ex3_numpy_equivalent(node_feats).shape

## Your solution below

(13, 13, 34)