JAX Metal: Random Number Generation Performance Issue #31286

@AOS55

Description

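JAX Metal is much slower than CUDA at random number generation. For the same array size (50,000 x 252 = 12,600,000 elements), Metal (Apple M1 Max) takes about 26.3 seconds while CUDA (NVIDIA GeForce RTX 5090) takes about 0.5 seconds, a gap of roughly 52x. Note that the two environments run different JAX versions: 0.4.34 on the Mac and 0.7.1 on the CUDA machine (see System info).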
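For reference, the gap can be recomputed from the figures under Results. Both the duration ratio and the throughput ratio come out at roughly 52x. The snippet below only does arithmetic on the reported numbers.

```python
# Figures copied from the Results section of this report.
cuda_seconds = 0.503
metal_seconds = 26.301
cuda_elems_per_sec = 25_035_568
metal_elems_per_sec = 479_075

# How many times longer Metal takes, and how many times fewer elements/sec it produces.
duration_ratio = metal_seconds / cuda_seconds
throughput_ratio = cuda_elems_per_sec / metal_elems_per_sec

print(f"Duration ratio:   {duration_ratio:.1f}x")
print(f"Throughput ratio: {throughput_ratio:.1f}x")
```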

Reproduction Code

import time

import jax
import jax.numpy as jnp
from jax import random


def benchmark():
    # 50,000 x 252 = 12,600,000 standard-normal draws in total
    n_simulations, n_steps = 50000, 252

    key = random.PRNGKey(42)
    start_time = time.perf_counter()
    random_array = random.normal(key, (n_simulations, n_steps))
    # float() copies the scalar back to the host, which blocks until the
    # asynchronously dispatched computation has actually finished.
    result_sum = float(jnp.sum(random_array))
    duration = time.perf_counter() - start_time

    elements_per_sec = (n_simulations * n_steps) / duration
    print(f"Backend: {jax.default_backend()}")
    print(f"Duration: {duration:.3f}s")
    print(f"Performance: {elements_per_sec:,.0f} elements/sec")
    print(f"Array sum: {result_sum:.2f}")


benchmark()
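The timed region above starts before the first call to `random.normal`, which JAX traces and compiles on first use. The reported durations therefore mix one-time compilation with execution. A minimal sketch, assuming the goal is to separate the two, jit-compiles the draw-and-sum step, times the first call on its own, and then times already-compiled calls with `block_until_ready()`. The `make_sampler` helper and the `repeats` parameter are illustrative additions, not part of the original report.

```python
import time

import jax
import jax.numpy as jnp
from jax import random


def make_sampler(shape):
    # Compile draw + sum together, matching the work timed in the original script.
    @jax.jit
    def sample_and_sum(key):
        return jnp.sum(random.normal(key, shape))

    return sample_and_sum


def benchmark(n_simulations=50_000, n_steps=252, repeats=3):
    shape = (n_simulations, n_steps)
    sample_and_sum = make_sampler(shape)
    key = random.PRNGKey(42)

    # First call: tracing + XLA compilation + one execution.
    t0 = time.perf_counter()
    sample_and_sum(key).block_until_ready()
    first_call = time.perf_counter() - t0

    # Later calls reuse the compiled executable; each uses a fresh key.
    keys = random.split(key, repeats)
    timings = []
    for i in range(repeats):
        t0 = time.perf_counter()
        sample_and_sum(keys[i]).block_until_ready()
        timings.append(time.perf_counter() - t0)
    steady = min(timings)

    n = n_simulations * n_steps
    print(f"Backend: {jax.default_backend()}")
    print(f"First call (compile + run): {first_call:.3f}s")
    print(f"Steady state (best of {repeats}): {steady:.3f}s "
          f"({n / steady:,.0f} elements/sec)")
    return first_call, steady


if __name__ == "__main__":
    benchmark()
```

If the first-call figure accounts for most of the 26 seconds on Metal, the problem would lie more in compilation than in the generated kernel; comparing both numbers across backends should make that visible.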

Results

NVIDIA

TESTING GPU
JAX Backend: gpu
JAX Devices: [CudaDevice(id=0)]
Array size: 50,000 x 252 = 12,600,000 elements
Duration: 0.503 seconds
Performance: 25,035,568 elements/second
Performance: 99,347 simulations/second
Array sum: -795.07

SUMMARY
Backend: gpu
Total time: 0.503s
Elements/sec: 25,035,568

Mac

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1755949008.631433  342310 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
I0000 00:00:1755949008.640727  342310 service.cc:145] XLA service 0x6000036d8900 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1755949008.640735  342310 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1755949008.642089  342310 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1755949008.642098  342310 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
TESTING METAL
JAX Backend: METAL
JAX Devices: [METAL(id=0)]
Array size: 50,000 x 252 = 12,600,000 elements
Duration: 26.301 seconds
Performance: 479,075 elements/second
Performance: 1,901 simulations/second
Array sum: -1500.49
SUMMARY
Backend: METAL
Total time: 26.301s
Elements/sec: 479,075
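Since the report concerns random number generation specifically, one way to narrow it down is to check whether the slowdown follows JAX's default Threefry-based PRNG. The sketch below times a compiled `random.normal` draw under two PRNG implementations that `jax.random.key` accepts through its `impl` argument: `"threefry2x32"` (the default) and `"rbg"` (whose bit generation uses XLA's RngBitGenerator). This is an untested diagnostic idea, not a confirmed fix. Whether jax-metal supports or accelerates `rbg` is an assumption, and `rbg` produces a different random stream, so sums will not match the threefry results. `time_normal` is a hypothetical helper name.

```python
import time

import jax
import jax.numpy as jnp
from jax import random


def time_normal(impl, shape=(50_000, 252), repeats=3):
    """Best-of-`repeats` wall time for one compiled normal draw plus sum."""
    # New-style typed key; `impl` selects the PRNG implementation by name.
    key = random.key(42, impl=impl)
    draw_and_sum = jax.jit(lambda k: jnp.sum(random.normal(k, shape)))

    # Warm-up call so compilation is excluded from the timed loop.
    draw_and_sum(key).block_until_ready()

    keys = random.split(key, repeats)
    best = float("inf")
    for i in range(repeats):
        t0 = time.perf_counter()
        draw_and_sum(keys[i]).block_until_ready()
        best = min(best, time.perf_counter() - t0)
    return best


if __name__ == "__main__":
    print(f"Backend: {jax.default_backend()}")
    for impl in ("threefry2x32", "rbg"):
        print(f"{impl:>12}: {time_normal(impl):.4f}s")
```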
I0000 00:00:1755949035.153338  342310 mps_client.h:209] MetalClient destroyed.

System info (python version, jaxlib version, accelerator, etc.)

NVIDIA Environment

jax:    0.7.1
jaxlib: 0.7.1
numpy:  2.1.2
python: 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
device info: NVIDIA GeForce RTX 5090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='c8538dfe6140', release='6.11.0-26-generic', version='#26~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Apr 17 19:20:47 UTC 2', machine='x86_64')

$ nvidia-smi
Sat Aug 23 11:43:39 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.153.02             Driver Version: 570.153.02     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5090        On  |   00000000:0C:00.0 Off |                  N/A |
|  0%   28C    P8             21W /  575W |     535MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A           10769      C   python                                  526MiB |
+-----------------------------------------------------------------------------------------+

Mac M1 Max Environment

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1755949524.899129  350993 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0000 00:00:1755949524.909090  350993 service.cc:145] XLA service 0x600002ca4200 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1755949524.909103  350993 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1755949524.910408  350993 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1755949524.910416  350993 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.26.4
python: 3.12.3 | packaged by Anaconda, Inc. | (main, May  6 2024, 14:46:42) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='QuessyMBP', release='24.6.0', version='Darwin Kernel Version 24.6.0: Mon Jul 14 11:30:29 PDT 2025; root:xnu-11417.140.69~1/RELEASE_ARM64_T6000', machine='arm64')
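The version blocks above follow the layout printed by `jax.print_environment_info()`. Whether the reporter generated them this way is an assumption. The snippet below is a minimal way to capture the same report as a string, for example to attach it to a follow-up comment after rerunning on another machine.

```python
import jax

# Same report that jax.print_environment_info() prints (jax/jaxlib/numpy/python
# versions, devices, process count, platform), returned as a string instead.
info = jax.print_environment_info(return_string=True)
print(info)
```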
