In [1]:
# a matmul in jax using pmap across 8 GPUs
import jax.numpy as jnp
import jax
import time
import numpy as np

from jax.sharding import PositionalSharding
sharding = PositionalSharding(jax.devices())
DEVICES = 8
BS = 32
N = 4096
d = np.float16
nA = np.zeros((DEVICES, BS, N, N), d)
nB = np.zeros((DEVICES, 1, N, N), d)
print(jax.devices())
A = jax.device_put_sharded([nA[i] for i in range(DEVICES)], jax.devices())
B = jax.device_put_sharded([nB[i] for i in range(DEVICES)], jax.devices())
A.shape, B.shape

[cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4), cuda(id=5), cuda(id=6), cuda(id=7)]


((8, 32, 4096, 4096), (8, 1, 4096, 4096))

In [2]:
OPS = DEVICES*BS*N*N*N*2
print(f"TOPS {OPS*1e-12}")
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float16)
lowered = jax.pmap(matmul)

TOPS 35.184372088832


In [3]:
MAX_TFLOPS = 142*DEVICES  # Peak FP16 Tensor TFLOPS with FP16 Acc
for i in range(10):
    st = time.perf_counter()
    C = lowered(A,B).block_until_ready()
    et = time.perf_counter()-st
    tflops = (OPS*1e-12)/et
    print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")

time 4511.29 ms, TFLOPS   7.80, MFU 0.69% out shape (8, 32, 4096, 4096) dtype float16
time 33.10 ms, TFLOPS 1062.90, MFU 93.57% out shape (8, 32, 4096, 4096) dtype float16
time 32.89 ms, TFLOPS 1069.81, MFU 94.17% out shape (8, 32, 4096, 4096) dtype float16
time 32.81 ms, TFLOPS 1072.43, MFU 94.40% out shape (8, 32, 4096, 4096) dtype float16
time 32.95 ms, TFLOPS 1067.93, MFU 94.01% out shape (8, 32, 4096, 4096) dtype float16
time 32.89 ms, TFLOPS 1069.82, MFU 94.17% out shape (8, 32, 4096, 4096) dtype float16
time 32.84 ms, TFLOPS 1071.48, MFU 94.32% out shape (8, 32, 4096, 4096) dtype float16
time 32.76 ms, TFLOPS 1073.86, MFU 94.53% out shape (8, 32, 4096, 4096) dtype float16
time 32.72 ms, TFLOPS 1075.39, MFU 94.66% out shape (8, 32, 4096, 4096) dtype float16
time 32.93 ms, TFLOPS 1068.51, MFU 94.06% out shape (8, 32, 4096, 4096) dtype float16
