# Comparing Workflow Performance on Different Backends 

As of now, JAX has no stable Apple GPU support: without Apple's experimental `jax-metal` plugin, computations fall back to the CPU by default. TensorFlow and PyTorch both have Metal backend support for Apple GPUs, allowing better utilization of the M2 chip’s GPU. The tests below time a large matrix multiplication in TensorFlow on the CPU and the GPU, and in JAX on the experimental Metal backend.

In [1]:
# Local Setup on an M2 Mac

# pip install tensorflow-macos
# pip install tensorflow-metal 

In [2]:
import tensorflow as tf

# Check TensorFlow version
print(f"TensorFlow version: {tf.__version__}")

# Verify if TensorFlow detects the GPU (Metal backend)
print("Is GPU available:", tf.config.list_physical_devices('GPU'))

TensorFlow version: 2.16.2
Is GPU available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


The memory growth setting allows TensorFlow to:

- Allocate only the memory it needs, instead of reserving all GPU memory upfront.
- Dynamically increase memory usage as the workload grows.

In [3]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Set memory growth for GPU")
    except RuntimeError as e:
        print(e)

Set memory growth for GPU
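To confirm the setting took effect, the flag can be read back per device. This is a minimal sketch, assuming TensorFlow is installed; `memory_growth_status` is a hypothetical helper name, and the import guard only exists so the snippet also runs where TensorFlow is absent.

```python
try:
    import tensorflow as tf
except ImportError:  # lets the sketch run where TensorFlow is absent
    tf = None


def memory_growth_status():
    """Map each visible GPU name to its memory-growth flag (hypothetical helper)."""
    if tf is None:
        return {}
    return {
        gpu.name: tf.config.experimental.get_memory_growth(gpu)
        for gpu in tf.config.list_physical_devices("GPU")
    }


print(memory_growth_status())
```

Memory growth has to be set before the GPU is first used, which is why the cell above catches `RuntimeError`.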


# TensorFlow CPU vs. GPU Test

**A Note on Matrix Size Impact** 

Larger matrices, such as the [10000, 10000] ones used below, benefit from the GPU, while smaller ones, such as [1000, 1000], may favor the CPU because the fixed cost of dispatching work to the GPU outweighs the speedup.

In [4]:
import tensorflow as tf
import time

# Generate random matrices
a = tf.random.normal([10000, 10000])
b = tf.random.normal([10000, 10000])

# Perform on CPU
with tf.device('/CPU:0'):
    start = time.time()
    c = tf.matmul(a, b)
    print("CPU time:", time.time() - start)

# Perform on GPU
with tf.device('/GPU:0'):
    start = time.time()
    c = tf.matmul(a, b)
    print("GPU time:", time.time() - start)

CPU time: 4.150007963180542
GPU time: 0.025270938873291016
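A [10000, 10000] matmul is about 2 × 10^12 floating-point operations, so 0.025 s would imply roughly 80 TFLOP/s, far more than an M2 GPU can sustain. That suggests the GPU figure may reflect only the time to dispatch the operation, not to finish it. The sketch below is one way to get steadier numbers, under some assumptions: `benchmark` is a hypothetical helper, `.numpy()` is used to force the result to be ready (which also adds a copy to host memory), and the sizes are arbitrary examples.

```python
import statistics
import time


def benchmark(fn, warmup=1, repeats=5):
    """Return the median wall-clock time of fn() in seconds.

    fn must block until its result is ready; otherwise only the
    dispatch time is measured.
    """
    for _ in range(warmup):
        fn()  # warm-up runs absorb one-time setup costs
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


try:
    import tensorflow as tf
except ImportError:  # lets the sketch run where TensorFlow is absent
    tf = None

if tf is not None:
    has_gpu = bool(tf.config.list_physical_devices("GPU"))
    for n in (1000, 4000):  # example sizes; add 10000 for the full test
        for device in ("/CPU:0", "/GPU:0"):
            if device == "/GPU:0" and not has_gpu:
                continue
            with tf.device(device):
                # Create the inputs on the device being measured.
                a = tf.random.normal([n, n])
                b = tf.random.normal([n, n])
                # .numpy() waits for the result before the timer stops.
                t = benchmark(lambda: tf.matmul(a, b).numpy())
            print(f"n={n} {device}: {t:.4f} s")
```

Sweeping `n` this way also shows where the GPU starts to pay off relative to the CPU.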


# JAX Metal Test (Experimental)

JAX dispatches work asynchronously: a call such as `jnp.dot` returns before the computation has finished, so `.block_until_ready()` is needed to time the full computation. The first call to an operation also includes XLA compilation, so a single timed run overstates the steady-state cost.

[Apple Dev Docs](https://developer.apple.com/metal/jax/)

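As of Nov 2024, execution on this backend is far slower than the TensorFlow GPU run: about 2.2 s versus the 0.025 s reported above for the same matrix sizes.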

In [5]:
import jax
import jax.numpy as jnp
import jax.random as random
import time

print(f"JAX version: {jax.__version__}")

# Check the device JAX is using
print(f"JAX is using device: {jax.devices()[0]}")

# Generate random matrices using jax.random
key = random.PRNGKey(0)  # Create a random key
key_a, key_b = random.split(key)  # Separate subkeys so a and b are not identical
a = random.normal(key_a, shape=(10000, 10000))
b = random.normal(key_b, shape=(10000, 10000))

# Perform matrix multiplication on JAX (using Metal GPU if available)
start = time.time()
c = jnp.dot(a, b).block_until_ready()  # Ensure computation finishes
print("JAX computation time:", time.time() - start)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


JAX version: 0.4.35
Metal device set to: Apple M2
JAX is using device: METAL:0


I0000 00:00:1731468564.655418  545268 service.cc:145] XLA service 0x600000524400 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731468564.655426  545268 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1731468564.656324  545268 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1731468564.656332  545268 mps_client.cc:384] XLA backend will use up to 15579430912 bytes on device 0 for SimpleAllocator.


JAX computation time: 2.1954009532928467
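The single timing above includes any one-time compilation and, because the inputs are generated asynchronously, possibly the random-number generation as well. This is a minimal sketch of separating the first call from later ones, assuming JAX is installed; `time_call` is a hypothetical helper and `n = 2000` is an arbitrary smaller size.

```python
import time

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # lets the sketch run where JAX is absent
    jax = None


def time_call(fn, *args):
    """Time one call, waiting for the result if it supports block_until_ready."""
    start = time.perf_counter()
    out = fn(*args)
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    return time.perf_counter() - start


if jax is not None:
    n = 2000  # example size, smaller than the 10000 used above
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    # Wait for the inputs so their generation is not counted in the timing.
    a = jax.random.normal(key_a, (n, n)).block_until_ready()
    b = jax.random.normal(key_b, (n, n)).block_until_ready()

    first = time_call(jnp.dot, a, b)  # includes one-time compilation
    later = min(time_call(jnp.dot, a, b) for _ in range(5))
    print(f"backend: {jax.default_backend()}")
    print(f"first call: {first:.4f} s, best of 5 later calls: {later:.4f} s")
```

Running the same script in an environment without the Metal plugin would give a JAX CPU figure to set beside the TensorFlow CPU result.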


# Conclusion

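- TensorFlow’s Metal backend delivers excellent GPU performance on M2 Macs: the GPU matmul was reported at about 0.025 s, versus 4.15 s on the CPU.
- JAX currently lacks full support for Metal: on the experimental backend the same multiplication took about 2.2 s, making it less competitive on Apple hardware.
- JAX was not timed on the CPU in these runs, so they do not show how TensorFlow and JAX compare there.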