In [1]:
# https://roberttlange.com/posts/2020/03/blog-post-10/

import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random

# Generate key which is used to generate random numbers
key = random.PRNGKey(1)

2024-02-17 09:14:19.640799: W external/xla/xla/service/gpu/nvptx_compiler.cc:744] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


# Matrix multiplication

In [2]:
# Generate a random matrix
x = random.uniform(key, (1000, 1000))
# Compare running times of 3 different matrix multiplications
%time y = onp.dot(x, x)
%time y = np.dot(x, x)
%time y = np.dot(x, x).block_until_ready()

CPU times: user 62.9 ms, sys: 1.78 ms, total: 64.7 ms
Wall time: 7.31 ms
CPU times: user 172 ms, sys: 476 ms, total: 648 ms
Wall time: 41.2 ms
CPU times: user 790 µs, sys: 1.33 ms, total: 2.12 ms
Wall time: 139 µs


# jit

In [3]:
def ReLU(x):
    """ Rectified Linear Unit (ReLU) activation function """
    return np.maximum(0, x)

jit_ReLU = jit(ReLU)

In [4]:
%time out = ReLU(x).block_until_ready()
# Call jitted version to compile for evaluation time!
%time jit_ReLU(x).block_until_ready()
%time out = jit_ReLU(x).block_until_ready()

CPU times: user 19.6 ms, sys: 6.27 ms, total: 25.8 ms
Wall time: 31.6 ms
CPU times: user 9.58 ms, sys: 2.46 ms, total: 12 ms
Wall time: 18.9 ms
CPU times: user 111 µs, sys: 0 ns, total: 111 µs
Wall time: 175 µs


# grad

In [6]:
def FiniteDiffGrad(x):
    """ Compute the finite difference derivative approx for the ReLU"""
    return np.array((ReLU(x + 1e-3) - ReLU(x - 1e-3)) / (2 * 1e-3))

# Compare the Jax gradient with a finite difference approximation
print("Jax Grad: ", jit(grad(jit(ReLU)))(2.))
print("FD Gradient:", FiniteDiffGrad(2.))

Jax Grad:  1.0
FD Gradient: 0.99998707


# vmap

In [35]:
batch_dim = 32
feature_dim = 100
hidden_dim = 512

# Generate a batch of vectors to process
X = random.normal(key, (batch_dim, feature_dim))

# Generate Gaussian weights and biases
params = [random.normal(key, (hidden_dim, feature_dim)),
          random.normal(key, (hidden_dim, ))]

def relu_layer(params, x):
    """ Simple ReLu layer for single sample """
    return ReLU(np.dot(params[0], x) + params[1])

def batch_version_relu_layer(params, x):
    """ Error prone batch version """
    return ReLU(np.dot(X, params[0].T) + params[1])

def vmap_relu_layer(params, x):
    """ vmap version of the ReLU layer """
    return jit(vmap(relu_layer, in_axes=(None, 0), out_axes=0))

%time ut = np.stack([relu_layer(params, X[i, :]) for i in range(X.shape[0])])
%time out = batch_version_relu_layer(params, X)
out

CPU times: user 13.6 ms, sys: 3.01 ms, total: 16.6 ms
Wall time: 11.5 ms
CPU times: user 230 µs, sys: 174 µs, total: 404 µs
Wall time: 297 µs


Array([[ 0.49068165,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.7020273 ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.8156922 ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  7.3970523 ],
       ...,
       [ 6.127812  ,  0.        ,  0.        , ...,  6.7456055 ,
         0.        , 11.019837  ],
       [ 0.        ,  0.        , 21.238281  , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  6.715443  ,  0.        , ...,  0.        ,
         0.        , 10.9453125 ]], dtype=float32)

In [40]:
%time out = vmap_relu_layer(params, X)
print(out(params, X))

CPU times: user 78 µs, sys: 58 µs, total: 136 µs
Wall time: 141 µs
[[ 0.49068165  0.          0.         ...  0.          0.
   0.7020273 ]
 [ 0.          0.          0.         ...  0.          0.
   0.8156922 ]
 [ 0.          0.          0.         ...  0.          0.
   7.3970523 ]
 ...
 [ 6.127812    0.          0.         ...  6.7456055   0.
  11.019837  ]
 [ 0.          0.         21.238281   ...  0.          0.
   0.        ]
 [ 0.          6.715443    0.         ...  0.          0.
  10.9453125 ]]


In [23]:
import jax.numpy as jnp
A, B, C, D = 2, 3, 4, 5
x = jnp.ones((A, B))
y = jnp.ones((B, C))
z = jnp.ones((C, D))
def foo(tree_arg):
  x, (y, z) = tree_arg
  return jnp.dot(x, jnp.dot(y, z))
tree = (x, (y, z))
print(foo(tree))

[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]


In [24]:
from jax import vmap
K = 6  # batch size
x = jnp.ones((K, A, B))  # batch axis in different locations
y = jnp.ones((B, K, C))
z = jnp.ones((C, D, K))
tree = (x, (y, z))
vfoo = vmap(foo, in_axes=((0, (1, 2)),))
print(vfoo(tree).shape)

(6, 2, 5)
