In [17]:
import numpy as np
import jax
import jax.numpy as jnp
import time
from jax import vmap, jit

In [2]:
ar1 = np.arange(150)
ar2 = np.arange(100,250)
result = np.dot(ar1,ar2)
ar1.shape, ar2.shape, result

((150,), (150,), 2231275)

In [3]:
array1 = jnp.stack([jnp.arange(150) for i in range(100)]) # 100 is the batch dimension!!
array2 = jnp.stack([jnp.arange(100, 250) for i in range(100)])
array1.shape, array2.shape

((100, 150), (100, 150))

In [4]:
# Naive way
start = time.time()

output = []
for i in range(100):
    output.append(jnp.dot(array1[i], array2[i]))

output = jnp.stack(output)
print(output)
print('Output shape: ', output.shape)
time_taken = time.time() - start
print('Time take in secs: ', time_taken)

[2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275]
Output shape:  (100,)
Time take in secs:  0.8910415172576904


In [5]:
vmap(jnp.dot)

In [6]:
start = time.time()

output = vmap(jnp.dot)(array1, array2)
print(output)
print('Output shape: ', output.shape)
time_taken = time.time() - start
print('Time take in secs: ', time_taken)

[2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275]
Output shape:  (100,)
Time take in secs:  0.10920429229736328


In [7]:
ouput = vmap(jnp.dot, in_axes = (0, 0))(array1, array2) # 'in_axes' explicitly specifies the batch dimensions
print(output)
print('Output shape: ', output.shape)

[2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275]
Output shape:  (100,)


In [8]:
array1 = jnp.arange(150)
array1.shape, array2.shape # array1 is a single array and array2 has a batch of 100 arrays

((150,), (100, 150))

In [9]:
output = vmap(jnp.dot, in_axes = (None, 0))(array1, array2) # The first array is a vector, it doesn't have a batch, therefore 'None'
print(output)
print('Output shape: ', output.shape)

[2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275 2231275
 2231275]
Output shape:  (100,)


In [10]:
key = jax.random.PRNGKey(0)

W = jax.random.normal(key, (64, 100), dtype=jnp.float32) # input dimension=100, 64 neurons (output size of the linear layer)
batch_x = jax.random.normal(key, (16, 100), dtype=jnp.float32) # batch dimension=16, input dimension=100

W.shape, batch_x.shape

((64, 100), (16, 100))

In [11]:
# Standard linear layer without bias and activation function for a single data point (not batch)
def layer(x):
    # (64, 100) . (100, ) -> (64, )
    return jnp.dot(W, x)

In [12]:
layer(batch_x) # gives error, because the layer() accepts a single data point, not batch of data points

TypeError: dot_general requires contracting dimensions to have the same shape, got (100,) and (16,).

In [13]:
print(layer(batch_x[0]))
layer(batch_x[0]).shape

[ -1.3889995  -20.139408   -15.25461     12.268584   -11.33385
  22.630579     0.6938026  -13.827614    11.879179    -3.962661
  18.831707   -14.518444   -10.260715   -12.685417     2.5124693
  -4.255941    -1.3663092    6.9495125   -7.8258133   -8.293367
  -6.7460346  -29.767746    -4.768341    14.712051    -1.9340603
   6.222947    13.89996    -11.409643    -3.27421     -2.1721942
  10.826935    -2.5647302   -0.46695042 -11.210756    -7.741742
 -22.293253     5.421151     1.3914757    3.3206859   -8.409931
   2.869808     7.1217394    3.5472736   -4.937554    -1.475796
  -4.0422435   -8.101667     0.17466402  -3.5307515   -8.768582
  14.79269      0.30482996  20.986172    -0.58729076   6.27522
 -20.083494     5.8386555  -13.792967   -10.024259     3.3196595
  15.8581       5.4580092   -6.9915285   27.747955  ]


(64,)

In [14]:
# Note that this cannot be jitted, because we rely on the content of the input
def naive_batched_layer(batch_x):
    outputs = []
    for row in batch_x:
        outputs.append(layer(row))
    return jnp.stack(outputs)

In [15]:
print('Naive batching')

%timeit naive_batched_layer(batch_x)

Naive batching
3.36 ms ± 66.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
@jit
def manual_batched_layer(batch_x):
    # (16, 100) . (100, 64) -> (16, 64)
    return jnp.dot(batch_x, W.T)

In [19]:
print('Manual batching')

%timeit manual_batched_layer(batch_x).block_until_ready()

Manual batching
145 µs ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [20]:
@jit
def vmap_batched_layer(batch_x):
    return vmap(layer)(batch_x)

In [22]:
print('Auto-vectorized batching')

%timeit vmap_batched_layer(batch_x).block_until_ready()

Auto-vectorized batching
106 µs ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [23]:
def layer_with_weights(W, x):
    # (64, 100) . (100, ) -> (64, )
    return jnp.dot(W, x)

In [24]:
@jit
def vmap_batched_layer_with_weights(W, batch_x):
    return vmap(layer_with_weights, in_axes=(None, 0))(W, batch_x)

In [25]:
print('Auto-vectorized batching')

%timeit vmap_batched_layer_with_weights(W, batch_x).block_until_ready()

Auto-vectorized batching
116 µs ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
