Importing required libraries and modules

In [2]:
import time
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import random
from jax import jit
from jax import vmap

Taking two one-dimensional arrays to compute their dot product

In [3]:
array1 = jnp.arange(150)

array1

DeviceArray([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,
              12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
              24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
              36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
              48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
              60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
              72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
              84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
              96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,
             108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
             120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
             132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
             144, 145, 146, 147, 148, 149], dtype=int32)

In [4]:
array2 = jnp.arange(100, 250)

array2

DeviceArray([100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
             112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
             124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135,
             136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
             148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
             160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171,
             172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
             184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195,
             196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
             208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
             220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
             232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243,
             244, 245, 246, 247, 248, 249], dtype=int32)

Performing dot product

In [5]:
output = jnp.dot(array1, array2)

output 

DeviceArray(2231275, dtype=int32)
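As a sanity check, the result can be derived by hand. With a_i = i and b_i = i + 100 for i = 0, ..., 149, the dot product is Σ i² + 100 · Σ i = 1,113,775 + 1,117,500 = 2,231,275. The sketch below compares that closed form with `jnp.dot`; the variable names are illustrative only.

```python
import jax.numpy as jnp

a = jnp.arange(150)        # 0, 1, ..., 149
b = jnp.arange(100, 250)   # 100, 101, ..., 249

n = 149
sum_i = n * (n + 1) // 2                   # sum of i for i = 0..149
sum_i_sq = n * (n + 1) * (2 * n + 1) // 6  # sum of i**2 for i = 0..149
closed_form = sum_i_sq + 100 * sum_i       # since a_i * b_i = i**2 + 100 * i

result = int(jnp.dot(a, b))
print(result, closed_form)  # both are 2231275
```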

Stacking 100 copies of each array to create a batch of vectors

In [6]:
array1 = jnp.stack([jnp.arange(150) for _ in range(100)])
array2 = jnp.stack([jnp.arange(100, 250) for _ in range(100)])

array1.shape, array2.shape

((100, 150), (100, 150))

Computing the dot product of each pair of vectors in the batch with a Python loop

In [7]:
start = time.time()

output = []
for i in range(100):
    output.append(jnp.dot(array1[i], array2[i]))

output = jnp.stack(output)
print(output)
print('Output shape: ', output.shape)

time_taken = time.time() - start

print('Time taken in secs', time_taken)

[2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275]
Output shape:  (100,)
Time taken in secs 0.9632058143615723


# Introduction to `vmap`

The loop-and-stack approach above works in JAX, but there is a transformation built for exactly this kind of batching, and it is one of the most useful in the library. As I mentioned earlier, it is one of my favorite transformations in JAX: `vmap`.

## What is `vmap`?
`vmap` is another transformation, like `jit`. It takes a function, together with the input and output axes to map over, and returns a vectorized version of that function. The syntax is `vmap(function, in_axes, out_axes, ...)`; by default, `in_axes=0` and `out_axes=0`.
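To make the signature concrete, here is a small sketch that is not part of the original notebook; `row_sum` and `double` are illustrative toy functions written for a single vector. It shows how `in_axes` picks the batch axis of the input and `out_axes` picks where the batch axis lands in the output.

```python
import jax.numpy as jnp
from jax import vmap

def row_sum(x):
  # Written for a single 1-D vector; returns a scalar.
  return jnp.sum(x)

def double(x):
  # Maps a vector to a vector of the same length.
  return 2 * x

xs = jnp.arange(12.0).reshape(3, 4)  # shape (3, 4)

# in_axes=0 (the default): treat the 3 rows as the batch.
per_row = vmap(row_sum, in_axes=0)(xs)
print(per_row)  # -> 6, 22, 38

# in_axes=1: treat the 4 columns as the batch instead.
per_col = vmap(row_sum, in_axes=1)(xs)
print(per_col)  # -> 12, 15, 18, 21

# out_axes=1: stack the per-row results along axis 1 of the output.
doubled = vmap(double, in_axes=0, out_axes=1)(xs)
print(doubled.shape)  # (4, 3)
```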


When you transform a function using `vmap`, it returns a function that is a vectorized version of the original function. Let's see it in action.

In [8]:
vmap(jnp.dot)

<function jax._src.numpy.lax_numpy.dot>
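Conceptually, `vmap(f)(xs)` computes the same result as applying `f` to each slice along the batch axis and stacking the outputs; the difference is that `vmap` builds one batched computation instead of running a Python loop. A quick check of that equivalence for `jnp.dot`, using small illustrative inputs `xs` and `ys`:

```python
import jax.numpy as jnp
from jax import vmap

xs = jnp.arange(12).reshape(3, 4)
ys = jnp.arange(12, 24).reshape(3, 4)

# Python loop: one jnp.dot call per pair of rows, then stack.
looped = jnp.stack([jnp.dot(x, y) for x, y in zip(xs, ys)])

# vmap: a single call to the vectorized function on the whole batch.
batched_dot = vmap(jnp.dot)
vectorized = batched_dot(xs, ys)

print(looped)
print(vectorized)
print(bool(jnp.array_equal(looped, vectorized)))
```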

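Using `vmap` is much faster than looping: in the run below it takes roughly a tenth of the time, because `vmap` executes one batched computation instead of dispatching 100 separate `jnp.dot` calls.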

In [9]:
start = time.time()

output = vmap(jnp.dot)(array1, array2)

print(output)
print(output.shape)

time_taken = time.time() - start

print('Time taken in secs', time_taken)

[2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275]
(100,)
Time taken in secs 0.09523797035217285
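The timings above come from a single call measured with `time.time()`, so they include one-off costs (tracing, dispatch) and the `print` calls. JAX also dispatches work asynchronously, so a timer can stop before the result is actually computed. A somewhat fairer sketch warms each version up once and calls `block_until_ready()` before reading the clock; `loop_dot` and `time_it` are illustrative helpers, and the numbers will vary by hardware, so none are shown here.

```python
import time
import jax.numpy as jnp
from jax import vmap

array1 = jnp.stack([jnp.arange(150) for _ in range(100)])
array2 = jnp.stack([jnp.arange(100, 250) for _ in range(100)])

def loop_dot(a, b):
  # One jnp.dot dispatch per row, then a stack.
  return jnp.stack([jnp.dot(a[i], b[i]) for i in range(a.shape[0])])

vmap_dot = vmap(jnp.dot)

def time_it(fn, *args, repeats=5):
  fn(*args).block_until_ready()  # warm-up call, excluded from the timing
  start = time.perf_counter()
  for _ in range(repeats):
    fn(*args).block_until_ready()  # wait for the asynchronous result
  return (time.perf_counter() - start) / repeats

loop_out = loop_dot(array1, array2)
vmap_out = vmap_dot(array1, array2)

loop_time = time_it(loop_dot, array1, array2)
vmap_time = time_it(vmap_dot, array1, array2)
print(f'loop: {loop_time:.6f} s per call')
print(f'vmap: {vmap_time:.6f} s per call')
```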


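In `vmap(function, in_axes, out_axes, ...)`, `function` is the function you want to vectorize. `in_axes` specifies which axis of each input holds the batch dimension; it can be a single integer applied to every input or a tuple with one entry per argument. Similarly, `out_axes` specifies where the batch dimension is placed in the output. The cell below passes `in_axes=(0, 0)` explicitly, which is the same as the default used above.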

In [10]:
output = vmap(jnp.dot, in_axes=(0, 0))(array1, array2)

print(output)
print(output.shape)

[2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275]
(100,)
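`in_axes` does not have to be 0. If the batch were stored along the columns instead, for example as arrays of shape `(150, 100)`, passing `in_axes=(1, 1)` maps over axis 1 of both inputs. A sketch using transposed copies, named `cols1` and `cols2` here for illustration:

```python
import jax.numpy as jnp
from jax import vmap

rows1 = jnp.stack([jnp.arange(150) for _ in range(100)])       # (100, 150)
rows2 = jnp.stack([jnp.arange(100, 250) for _ in range(100)])  # (100, 150)

# Same data, but each vector is now a column: shape (150, 100).
cols1, cols2 = rows1.T, rows2.T

# Map over axis 1 (the columns) of both inputs.
out = vmap(jnp.dot, in_axes=(1, 1))(cols1, cols2)
print(out.shape)  # (100,)
```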


**Note:** Not every argument needs a batch dimension. For example, we can take one vector and compute its dot product with each vector in a batch. For an input without a batch dimension, pass `None` at that argument's position in `in_axes`. Let's look at an example to make this clear.

In [11]:
array1 = jnp.arange(150)

array1.shape, array2.shape

((150,), (100, 150))

In [12]:
output = vmap(jnp.dot, in_axes=(None, 0))(array1, array2)

print(output)
print(output.shape)

[2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275]
(100,)
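With `in_axes=(None, 0)`, the unbatched vector is reused for every row of the batch, so the result matches an ordinary matrix-vector product. A sketch checking this, with illustrative names `v` and `batch`:

```python
import jax.numpy as jnp
from jax import vmap

v = jnp.arange(150)                                            # (150,), no batch axis
batch = jnp.stack([jnp.arange(100, 250) for _ in range(100)])  # (100, 150)

# None: do not map over v; 0: map over the rows of batch.
mapped = vmap(jnp.dot, in_axes=(None, 0))(v, batch)

# Equivalent matrix-vector product: each row of batch dotted with v.
direct = batch @ v

print(mapped.shape, bool(jnp.array_equal(mapped, direct)))
```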


# Batched input to a linear layer

- `W`: weights of a linear layer, shape `(64, 100)`
- `batch_x`: a batch of 16 inputs to the linear layer, shape `(16, 100)`

In [13]:
key = jax.random.PRNGKey(0)

W = jax.random.normal(key, (64, 100), dtype=jnp.float32)
batch_x = jax.random.normal(key, (16, 100), dtype=jnp.float32)

W.shape, batch_x.shape

((64, 100), (16, 100))
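The cell above draws both `W` and `batch_x` from the same `key`. JAX's random functions are deterministic given a key, so reusing a key is generally discouraged: the two draws are not guaranteed to be independent. The idiomatic pattern is to split the key first. A sketch (it would produce different numbers from the outputs recorded in this notebook; `key_w` and `key_x` are illustrative names):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Split into two subkeys, one per random draw.
key_w, key_x = jax.random.split(key)

W = jax.random.normal(key_w, (64, 100), dtype=jnp.float32)
batch_x = jax.random.normal(key_x, (16, 100), dtype=jnp.float32)

print(W.shape, batch_x.shape)  # (64, 100) (16, 100)
```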

In [14]:
def layer(x):
  # (64, 100) . (100, ) -> (64, )
  return jnp.dot(W, x)

In [15]:
layer(batch_x)

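TypeError

The call fails because `layer` is written for a single vector: `jnp.dot(W, batch_x)` tries to contract the last axis of `W` (size 100) with the first axis of `batch_x` (size 16), and these sizes do not match. Passing a single row works: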

In [16]:
layer(batch_x[0])

DeviceArray([ -1.3889999 , -20.13941   , -15.254608  ,  12.268584  ,
             -11.33385   ,  22.630579  ,   0.6938019 , -13.827613  ,
              11.879179  ,  -3.9626598 ,  18.831705  , -14.518444  ,
             -10.2607155 , -12.685415  ,   2.5124695 ,  -4.255941  ,
              -1.3663094 ,   6.949514  ,  -7.8258133 ,  -8.293367  ,
              -6.7460346 , -29.767748  ,  -4.768342  ,  14.712051  ,
              -1.9340608 ,   6.222945  ,  13.89996   , -11.409643  ,
              -3.2742107 ,  -2.172195  ,  10.826933  ,  -2.5647306 ,
              -0.46695018, -11.210756  ,  -7.7417426 , -22.293255  ,
               5.421152  ,   1.3914765 ,   3.3206863 ,  -8.409932  ,
               2.8698087 ,   7.1217403 ,   3.547274  ,  -4.9375544 ,
              -1.4757957 ,  -4.042242  ,  -8.101669  ,   0.17466497,
              -3.5307512 ,  -8.768582  ,  14.792691  ,   0.30482912,
              20.986172  ,  -0.58729076,   6.2752194 , -20.083494  ,
               5.8386536 , -13.792

In [17]:
# A plain Python loop: `layer` is called once per row and the results are stacked.
# (Wrapping this in `jit` would unroll the loop during tracing; it is left un-jitted here.)

def naive_batched_layer(batch_x):
  outputs = []
  for row in batch_x:
    outputs.append(layer(row))
  
  return jnp.stack(outputs)

In [18]:
print('Naive batching')

%timeit naive_batched_layer(batch_x)

Naive batching
The slowest run took 48.71 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 4.14 ms per loop


In [21]:
@jit
def manual_batched_layer(batch_x):
  # (16, 100) . (100, 64) -> (16, 64)
  return jnp.dot(batch_x, W.T)

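Re-run the cell below 2-3 times. The first run is slow because it includes JIT compilation; later runs reuse the compiled function and are much faster. This is also why `%timeit` warns that the slowest run took thousands of times longer than the fastest.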

In [None]:
print('Manual batching')

%timeit manual_batched_layer(batch_x).block_until_ready()

Manual batching
The slowest run took 3154.06 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 8.86 µs per loop
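Most of that first-run cost is compilation. One way to separate compilation from execution is JAX's ahead-of-time API: `jax.jit(f).lower(*args)` traces the function for the given argument shapes, and `.compile()` compiles it, after which the compiled object can be called directly. A sketch, assuming a JAX version recent enough to provide this API (the random inputs here are regenerated and will not match the notebook's):

```python
import time
import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (64, 100))
batch_x = jax.random.normal(key_x, (16, 100))

def manual_batched_layer(batch_x):
  # (16, 100) . (100, 64) -> (16, 64)
  return jnp.dot(batch_x, W.T)

start = time.perf_counter()
compiled = jax.jit(manual_batched_layer).lower(batch_x).compile()  # trace + compile only
compile_time = time.perf_counter() - start

start = time.perf_counter()
out = compiled(batch_x).block_until_ready()  # execution only
run_time = time.perf_counter() - start

print(f'compile: {compile_time:.4f} s, run: {run_time:.6f} s')
print(out.shape)  # (16, 64)
```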


In [None]:
@jit
def vmap_batched_layer(batch_x):
  return vmap(layer)(batch_x)

In [None]:
print('Auto-vectorized batching')

%timeit vmap_batched_layer(batch_x).block_until_ready()

Auto-vectorized batching
The slowest run took 1390.45 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 18.9 µs per loop
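Why is `vmap` close to the hand-written version? `vmap` does not run a loop; it rewrites the per-example computation into a batched one. `jax.make_jaxpr` prints the traced program, and for this layer it should contain a batched `dot_general` rather than 16 separate products. A sketch (the exact jaxpr text varies between JAX versions, and the inputs are regenerated here):

```python
import jax
import jax.numpy as jnp
from jax import vmap

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (64, 100))
batch_x = jax.random.normal(key_x, (16, 100))

def layer(x):
  # (64, 100) . (100,) -> (64,)
  return jnp.dot(W, x)

def manual_batched_layer(batch_x):
  # (16, 100) . (100, 64) -> (16, 64)
  return jnp.dot(batch_x, W.T)

jaxpr = jax.make_jaxpr(vmap(layer))(batch_x)
print(jaxpr)  # look for a single batched dot_general and no loop primitive

# Both approaches compute the same (16, 64) result, up to float rounding.
same = bool(jnp.allclose(vmap(layer)(batch_x), manual_batched_layer(batch_x), atol=1e-3))
print(same)
```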


In [None]:
def layer_with_weights(W, x):
  # (64, 100) . (100, ) -> (64, )
  return jnp.dot(W, x)

By default, `vmap` maps over axis 0 of *every* argument. For `layer_with_weights` that would mean mapping over the 64 rows of `W` as well as the 16 rows of `batch_x`, and `vmap` raises an error because the two mapped axes have different sizes. Passing `in_axes=(None, 0)` tells `vmap` to share `W` across the batch and map only over axis 0 of `batch_x`.

In [None]:
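@jit
def vmap_batched_layer_with_weights(W, batch_x):
  # in_axes=(None, 0): W is shared across the batch, batch_x is mapped over axis 0.
  return vmap(layer_with_weights, in_axes=(None, 0))(W, batch_x)

# Without in_axes, vmap would also map over the 64 rows of W and fail,
# because that does not match the 16 rows of batch_x:
# def vmap_batched_layer_with_weights(W, batch_x):
#   return vmap(layer_with_weights)(W, batch_x)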

In [None]:
print('Auto-vectorized batching')

%timeit vmap_batched_layer_with_weights(W, batch_x).block_until_ready()

Auto-vectorized batching
The slowest run took 1476.31 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 19.1 µs per loop
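As a further sketch combining `in_axes` and `out_axes`, the same function can place the batch dimension last, giving a `(64, 16)` result that should match `W @ batch_x.T`. It also shows that leaving out `in_axes` fails, as described above; `failed_without_in_axes` is an illustrative flag, and the inputs are regenerated rather than taken from the notebook.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (64, 100))
batch_x = jax.random.normal(key_x, (16, 100))

def layer_with_weights(W, x):
  # (64, 100) . (100,) -> (64,)
  return jnp.dot(W, x)

# W is shared (None), batch_x is mapped over axis 0, batch goes to output axis 1.
batched = jit(vmap(layer_with_weights, in_axes=(None, 0), out_axes=1))
out = batched(W, batch_x)
print(out.shape)  # (64, 16)
print(bool(jnp.allclose(out, W @ batch_x.T, atol=1e-3)))

# Default in_axes=0 also maps over W's 64 rows, which clashes with batch_x's 16.
failed_without_in_axes = False
try:
  vmap(layer_with_weights)(W, batch_x)
except ValueError as err:
  failed_without_in_axes = True
  print('Without in_axes:', type(err).__name__)
```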
