In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3'

import sys
sys.path.append('/home/yongha/meent')


# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
import jax
# jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
jax.device_count()  # total number of accelerator devices in the cluster

3

In [2]:
from jax_smi import initialise_tracking
initialise_tracking()
# some computation...

In [3]:
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0)]

In [18]:
import numpy as np
import jax.numpy as jnp

size1 = 30000
size2 = 2+101

x = np.arange(size1)
w = np.arange(2, size2, 1)
# w = np.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(size2//2, len(x)-size2//2):
    output.append(jnp.dot(x[i-size2//2:i+size2//2-1], w))
  return jnp.array(output)

# convolve(x, w)

In [19]:
n_devices = jax.local_device_count()
xs = np.arange(size1 * n_devices).reshape(-1, size1)
ws = np.stack([w] * n_devices)

In [8]:
%time jax.vmap(jax.jit(convolve))(xs, ws)

CPU times: user 7min 3s, sys: 5.38 s, total: 7min 8s
Wall time: 5min 11s


Array([[   348450,    353702,    358954, ...,  52316990,  52322242,
         52327494],
       [ 52868450,  52873702,  52878954, ..., 104836990, 104842242,
        104847494],
       [105388450, 105393702, 105398954, ..., 157356990, 157362242,
        157367494]], dtype=int32)

In [9]:
%timeit jax.vmap(jax.jit(convolve))(xs, ws)

5.6 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
%time jax.vmap(convolve)(xs, ws)

CPU times: user 3.32 s, sys: 415 ms, total: 3.73 s
Wall time: 2.76 s


Array([[  348450,   353702,   358954, ...,  5048990,  5054242,  5059494],
       [ 5600450,  5605702,  5610954, ..., 10300990, 10306242, 10311494],
       [10852450, 10857702, 10862954, ..., 15552990, 15558242, 15563494]],      dtype=int32)

In [19]:
%timeit jax.vmap(convolve)(xs, ws)

2.58 s ± 20.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
%time jax.pmap(convolve)(xs, ws)

CPU times: user 6min 2s, sys: 8.53 s, total: 6min 10s
Wall time: 4min 17s


Array([[   348450,    353702,    358954, ...,  52316990,  52322242,
         52327494],
       [ 52868450,  52873702,  52878954, ..., 104836990, 104842242,
        104847494],
       [105388450, 105393702, 105398954, ..., 157356990, 157362242,
        157367494]], dtype=int32)

In [7]:
%timeit jax.pmap(convolve)(xs, ws)


6.32 ms ± 364 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
%time jax.vmap(jax.jit(convolve))(xs, ws)
%time jax.pmap(convolve)(xs, ws)

%timeit jax.vmap(jax.jit(convolve))(xs, ws)
%timeit jax.pmap(convolve)(xs, ws)


CPU times: user 18.3 ms, sys: 3.89 ms, total: 22.2 ms
Wall time: 19.4 ms
CPU times: user 34.1 ms, sys: 43 µs, total: 34.2 ms
Wall time: 13 ms
5.27 ms ± 364 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.14 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
%time jax.vmap(jax.jit(convolve))(xs, ws)
%time jax.pmap(convolve)(xs, ws)

%timeit jax.vmap(jax.jit(convolve))(xs, ws)
%timeit jax.pmap(convolve)(xs, ws)


2023-01-28 21:01:34.568619: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_convolve] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-01-28 21:01:46.953251: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2m12.380533367s

********************************
[Compiling module jit_convolve] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


CPU times: user 1h 17min 5s, sys: 56.1 s, total: 1h 18min 1s
Wall time: 1h 12min 1s


2023-01-28 22:11:19.419497: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 2m12.614540417s

********************************
[Compiling module pmap_convolve] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


CPU times: user 1h 14min 24s, sys: 41.7 s, total: 1h 15min 6s
Wall time: 1h 9min 33s
14.5 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
22.3 ms ± 983 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [22]:
%time jax.vmap(jax.jit(convolve))(xs, ws)
%time jax.pmap(convolve)(xs, ws)

%timeit jax.vmap(jax.jit(convolve))(xs, ws)
%timeit jax.pmap(convolve)(xs, ws)


CPU times: user 37.1 ms, sys: 4.01 ms, total: 41.1 ms
Wall time: 38.1 ms
CPU times: user 85.4 ms, sys: 7.97 ms, total: 93.3 ms
Wall time: 41.3 ms
16.1 ms ± 1.21 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
24.3 ms ± 1.81 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
