In [2]:
import os
# Pin GPU numbering to PCI bus order and expose only GPUs 1, 2 and 3 to this
# process. Both variables are set before `import jax` so they are already in
# place when JAX initialises its backend.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '1,2,3',
})


# On a multi-host GPU cluster or TPU pod slice this cell would be run in parallel on every host.
import jax
# jax.distributed.initialize()  # Left disabled here; on GPU it needs explicit arguments (see the JAX multi-process docs).
jax.device_count()  # total number of accelerator devices in the cluster (3 in this run)

3

In [3]:
jax.local_device_count()  # number of accelerator devices attached to this host (also 3 in this run, matching jax.device_count())

3

In [4]:
# psum adds up x across every device mapped along axis 'i'. Each device holds a
# single 1.0, so every entry of the result equals the number of mapped devices.
xs = jax.numpy.ones(jax.local_device_count())
jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# With the 3 devices of this run: Array([3., 3., 3.], dtype=float32)

Array([3., 3., 3.], dtype=float32)

In [1]:
import jax
import jax.numpy as jnp

x = jnp.arange(5)  # input signal: 0, 1, 2, 3, 4
w = jnp.array([2., 3., 4.])  # weights applied to each 3-element window of x

def convolve(x, w):
  """Slide a 3-element window over `x` and take its dot product with `w`.

  One value is produced per interior index of `x` (the first and last
  positions are skipped), so the result holds `max(len(x) - 2, 0)` entries.
  """
  windows = (x[center - 1:center + 2] for center in range(1, len(x) - 1))
  return jnp.array([jnp.dot(window, w) for window in windows])

convolve(x, w)  # expected values: 11., 20., 29.

Array([11., 20., 29.], dtype=float32)

In [1]:
import os
# Pin GPU numbering to PCI bus order and expose only GPUs 1, 2 and 3 to this
# process. Both variables are set before `import jax` so they are already in
# place when JAX initialises its backend.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '1,2,3',
})

import jax

import numpy as np
import jax.numpy as jnp

import time

In [2]:

size = 5  # length of the 1-D input; also reused as the row length of the batched inputs built further down
x = jnp.arange(size)
w = jnp.array([2., 3., 4.])  # weights applied to each 3-element window of x


def convolve(x, w):
  """Slide a 3-element window over `x` and take its dot product with `w`.

  One value is produced per interior index of `x` (the first and last
  positions are skipped), so the result holds `max(len(x) - 2, 0)` entries.
  """
  windows = (x[center - 1:center + 2] for center in range(1, len(x) - 1))
  return jnp.array([jnp.dot(window, w) for window in windows])

# Number of timed repetitions used by the benchmark loops below.
# NOTE(review): this name shadows the built-in `iter()` for the rest of the session.
iter = 2

In [3]:
# Build the jitted function once and reuse it on every pass.
jitted_convolve = jax.jit(convolve)
for _ in range(iter):
  t0 = time.time()
  # JAX dispatches work asynchronously, so wait for the result before reading
  # the clock; otherwise only the dispatch overhead may be measured.
  a = jitted_convolve(x, w).block_until_ready()
  print(time.time() - t0)


1.5886805057525635
[  11.   20.   29.   38.   47.   56.   65.   74.   83.   92.  101.  110.
  119.  128.  137.  146.  155.  164.  173.  182.  191.  200.  209.  218.
  227.  236.  245.  254.  263.  272.  281.  290.  299.  308.  317.  326.
  335.  344.  353.  362.  371.  380.  389.  398.  407.  416.  425.  434.
  443.  452.  461.  470.  479.  488.  497.  506.  515.  524.  533.  542.
  551.  560.  569.  578.  587.  596.  605.  614.  623.  632.  641.  650.
  659.  668.  677.  686.  695.  704.  713.  722.  731.  740.  749.  758.
  767.  776.  785.  794.  803.  812.  821.  830.  839.  848.  857.  866.
  875.  884.  893.  902.  911.  920.  929.  938.  947.  956.  965.  974.
  983.  992. 1001. 1010. 1019. 1028. 1037. 1046. 1055. 1064. 1073. 1082.
 1091. 1100. 1109. 1118. 1127. 1136. 1145. 1154. 1163. 1172. 1181. 1190.
 1199. 1208. 1217. 1226. 1235. 1244. 1253. 1262. 1271. 1280. 1289. 1298.
 1307. 1316. 1325. 1334. 1343. 1352. 1361. 1370. 1379. 1388. 1397. 1406.
 1415. 1424. 1433. 1442. 1451. 1

In [None]:
# Build one input row and one copy of the weights per local device.
n_devices = jax.local_device_count()
flat_inputs = np.arange(size * n_devices)
xs = flat_inputs.reshape(-1, size)
ws = np.stack([w for _ in range(n_devices)])


In [None]:
# Build the batched function once and reuse it on every pass.
batched_convolve = jax.vmap(convolve)
for _ in range(iter):
  t0 = time.time()
  # Block until the computation has finished, as the pmap benchmark does, so
  # the measurement covers the full run rather than just the async dispatch.
  batched_convolve(xs, ws).block_until_ready()
  print(time.time() - t0)


In [None]:
# Wrap convolve once; xs and ws carry one row per local device.
parallel_convolve = jax.pmap(convolve)
for _ in range(iter):
  t0 = time.time()
  a = parallel_convolve(xs, ws)
  a.block_until_ready()  # wait for completion so the whole run is timed
  elapsed = time.time() - t0
  print(elapsed)
