In [1]:
# https://hsf-training.github.io/hsf-training-ml-gpu-webpage/02-whichgpu/index.html, refer to https://github.com/anderskm/gputil/blob/master/GPUtil/GPUtil.py
import GPUtil
device_count = len(GPUtil.getGPUs())
print("available gpu from GPUtil : {}".format(device_count))

import torch
device_count = torch.cuda.device_count()
print("available gpu from torch : {}".format(device_count))

available gpu from GPUtil : 2
available gpu from torch : 2


In [3]:
# refer to https://github.com/google/jax
# refer to quick start, https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
print("grad_loss: {}".format(grad_loss))
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
print("perex_grads: {}".format(perex_grads))

grad_loss: <function api_boundary.<locals>.reraise_with_filtered_traceback at 0x7fbf8e152f70>
perex_grads: <function api_boundary.<locals>.reraise_with_filtered_traceback at 0x7fbf8e1625e0>


In [5]:
# Transformations
# Automatic differentiation with grad
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print("grad_tanh(1.0): {}".format(grad_tanh(1.0)))   # Evaluate it at x = 1.0
# prints 0.4199743
print("grad(grad(grad(tanh)))(1.0): {}".format(grad(grad(grad(tanh)))(1.0)))

grad_tanh(1.0): 0.41997429728507996
grad(grad(grad(tanh)))(1.0): 0.621626615524292


In [6]:
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)

1.0
-1.0


In [7]:
# Compilation with jit
import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)

The slowest run took 59.39 times longer than the fastest. This could mean that an intermediate result is being cached.
6.31 ms ± 7.96 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
The slowest run took 6.12 times longer than the fastest. This could mean that an intermediate result is being cached.
12.5 ms ± 11.1 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [19]:
# Auto-vectorization with vmap
def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)        # inputs to the next layer
  return outputs                           # no activation on last layer

In [20]:
import numpy as np

params = np.random.rand(100)
input_batch  = 16;
print("param.ndim: {}, input_batch: {}".format(params.ndim, input_batch))

param.ndim: 1, input_batch: 16


In [21]:
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

TypeError: 'int' object is not iterable

In [22]:
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

ValueError: vmap got arg 0 of rank 0 but axis to be mapped 0. The tree of ranks is:
(0,)

In [24]:
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

NameError: name 'inputs' is not defined

In [25]:
# SPMD programming with pmap
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), device_count)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]

[1.1608742 1.2301513]


In [13]:
# In addition to expressing pure maps, you can use fast collective communication operations between devices:
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(device_count)))
# prints [0.         0.16666667 0.33333334 0.5       ]

[0. 1.]


In [23]:
import jax
jax.devices()

[GpuDevice(id=0), GpuDevice(id=1)]

In [25]:
# It all composes, so you're free to differentiate through parallel computations:
from jax import grad

# https://github.com/google/jax/discussions/4198
@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]

ValueError: compiling computation that requires 25000000 logical devices, but only 2 XLA devices are available (num_replicas=25000000, num_partitions=1)