# Multi-GPU programming with JAX

In [1]:
import jax
import numpy as np
from timers import cpu_timer

In [2]:
gpus = jax.devices('gpu')

In [3]:
for i, d in enumerate(gpus):
    print(f"Device {i}: {d}")

Device 0: cuda:0
Device 1: cuda:1
Device 2: cuda:2
Device 3: cuda:3


In [4]:
jax.local_device_count()

4

In [5]:
mesh = jax.sharding.Mesh(jax.devices(), axis_names=['x'])

In [53]:
x = jax.random.normal(jax.random.key(0), (16000, 16000))

In [54]:
y = jax.device_put(x, jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', None)))

In [55]:
y.sum(axis=0)

Array([  52.915867,   46.157303,  -26.998585, ...,  190.85838 ,
         85.26398 , -303.00476 ], dtype=float32)

In [56]:
jax.debug.visualize_array_sharding(y)

In [57]:
@jax.jit
def my_complex_sin(z):
    y = jax.numpy.sin(z) + jax.numpy.cos(z) + z ** 2

    return y.sum(axis=0)

In [58]:
with cpu_timer(True):
    z = my_complex_sin(x).block_until_ready()

Elapsed time: 109.63940399960848 ms


In [59]:
with cpu_timer(True):
    z2 = my_complex_sin(y).block_until_ready()

Elapsed time: 72.65475099848118 ms


In [60]:
z2.sharding

NamedSharding(mesh=Mesh('x': 4, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=device)

In [82]:
with cpu_timer(True):
    z = y ** 2
    for i in range(20):
        z += y @ y @ y @ y @ y @ y @ y
    z.sum().block_until_ready()
print(z.sum())

Elapsed time: 1181.931570994493 ms
5.414268e+18


In [83]:
with cpu_timer(True):
    z = x ** 2
    for i in range(20):
        z += x @ x @ x @ x @ x @ x @ x
print(z.sum())

Elapsed time: 2564.825961002498 ms
5.4146616e+18


(Array(80721552., dtype=float32), Array(81084800., dtype=float32))