In [1]:
# a matmul in jax using pmap across GPUs
import jax.numpy as jnp
import jax
import time
import numpy as np

# Enumerate the accelerators once and reuse the same list for reporting and
# for placement, so the device order is guaranteed to be consistent.
devices = jax.devices()
print(devices)
DEVICES = len(devices)
BS = 32          # number of (N x N) matrices held by each device
N = 4096         # side length of every square matrix
d = np.float16   # dtype of both operands

# Host-side staging buffers. nA carries one (BS, N, N) shard per device;
# nB is a single (1, 1, N, N) operand that every device receives a copy of.
# NOTE(review): both operands are all zeros. Zero-filled inputs may not be
# representative of a real workload (GPU clocks/power can depend on the data),
# so consider random inputs before trusting absolute numbers -- TODO confirm.
nA = np.zeros((DEVICES, BS, N, N), d)
nB = np.zeros((1, 1, N, N), d)

# One shard per device along a new leading axis:
#   A -> (DEVICES, BS, N, N), B -> (DEVICES, 1, 1, N, N)
A = jax.device_put_sharded(list(nA), devices)
B = jax.device_put_sharded([nB] * DEVICES, devices)
A.shape, B.shape

2023-11-28 13:40:14.428794: E external/xla/xla/stream_executor/plugin_registry.cc:90] Invalid plugin kind specified: DNN


[rocm(id=0), rocm(id=1), rocm(id=2), rocm(id=3), rocm(id=4), rocm(id=5)]
[rocm(id=0), rocm(id=1), rocm(id=2), rocm(id=3), rocm(id=4), rocm(id=5)]


((6, 32, 4096, 4096), (6, 1, 1, 4096, 4096))

In [2]:
# Work performed by one pmapped call: an (N x N) @ (N x N) product costs
# 2*N^3 operations (multiply + add), repeated BS times on each of DEVICES devices.
OPS = 2 * N * N * N * BS * DEVICES
print(f"TOPS {OPS*1e-12}")


def matmul(A, B):
    """Return ``A @ B`` via ``jnp.matmul``, requesting a float32 result type.

    Alternative kept for experimentation: wrap the returned product in
    ``jax.nn.relu(...)``.
    """
    return jnp.matmul(A, B, preferred_element_type=jnp.float32)


# pmap maps `matmul` over the leading axis of its arguments.
lowered = jax.pmap(matmul)

TOPS 26.388279066624


In [3]:
# Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX), per device.
PEAK_TFLOPS_PER_DEVICE = 123
MAX_TFLOPS = PEAK_TFLOPS_PER_DEVICE * DEVICES
N_ITERS = 10

# Each pass times one full pmapped call, blocking until the result is ready so
# that asynchronous dispatch is not mistaken for completed work.
# NOTE(review): the first pass is a large outlier in the recorded output,
# presumably because it includes first-call compilation -- confirm, and
# discard it when summarising.
for _ in range(N_ITERS):
    t0 = time.perf_counter()
    C = lowered(A, B).block_until_ready()
    et = time.perf_counter() - t0
    tflops = OPS * 1e-12 / et
    print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")

time 1654.82 ms, TFLOPS  15.95, MFU 2.16% out shape (6, 1, 32, 4096, 4096) dtype float32
time 37.22 ms, TFLOPS 709.07, MFU 96.08% out shape (6, 1, 32, 4096, 4096) dtype float32
time 36.30 ms, TFLOPS 727.03, MFU 98.51% out shape (6, 1, 32, 4096, 4096) dtype float32
time 35.58 ms, TFLOPS 741.60, MFU 100.49% out shape (6, 1, 32, 4096, 4096) dtype float32
time 34.87 ms, TFLOPS 756.70, MFU 102.53% out shape (6, 1, 32, 4096, 4096) dtype float32
time 34.37 ms, TFLOPS 767.70, MFU 104.02% out shape (6, 1, 32, 4096, 4096) dtype float32
time 34.05 ms, TFLOPS 775.05, MFU 105.02% out shape (6, 1, 32, 4096, 4096) dtype float32
time 34.97 ms, TFLOPS 754.66, MFU 102.26% out shape (6, 1, 32, 4096, 4096) dtype float32
time 35.85 ms, TFLOPS 735.97, MFU 99.73% out shape (6, 1, 32, 4096, 4096) dtype float32
time 36.08 ms, TFLOPS 731.35, MFU 99.10% out shape (6, 1, 32, 4096, 4096) dtype float32
