## PyTorch


In [1]:
import torch


def create_torch_tensors(device, shape=(10000, 10000), dtype=torch.float32):
    """Create two uniform-random tensors on ``device`` for the benchmark.

    The tensors are sampled directly on ``device`` instead of on the CPU
    followed by ``.to(device)``. This skips a host-side temporary and a
    host-to-device copy for each operand.

    Args:
        device: ``torch.device`` (or device string such as ``"cpu"``) on
            which both tensors are allocated.
        shape: Shape of each tensor; defaults to 10000 x 10000.
        dtype: Floating-point dtype of each tensor; defaults to float32.

    Returns:
        A tuple ``(x, y)`` of independently sampled tensors with values
        in [0, 1).
    """
    x = torch.rand(shape, dtype=dtype, device=device)
    y = torch.rand(shape, dtype=dtype, device=device)
    return x, y

In [2]:
# CPU baseline: both operands are allocated in host memory.
device = torch.device("cpu")
x, y = create_torch_tensors(device=device)

In [3]:
%%timeit
# Elementwise product of the two CPU operands.
torch.mul(x, y)

18.9 ms ± 245 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
# Choose an accelerator: Apple-silicon GPU (MPS) first, then an NVIDIA GPU
# (CUDA). Fall back to the CPU so the notebook still runs end to end; the
# chosen device is displayed so a CPU fallback is visible.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
x, y = create_torch_tensors(device)
device

In [5]:
%%timeit
x * y
# MPS/CUDA kernels are queued asynchronously. Wait for them to finish so the
# measurement covers the multiply itself, not just its dispatch.
if device.type == "mps":
    torch.mps.synchronize()
elif device.type == "cuda":
    torch.cuda.synchronize()

5.29 ms ± 9.32 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## TensorFlow


In [6]:
import tensorflow as tf


def create_tf_tensors(shape=(10000, 10000), dtype=tf.float32, device=None):
    """Create two uniform-random tensors for the TensorFlow benchmark.

    Args:
        shape: Shape of each tensor; defaults to 10000 x 10000.
        dtype: Floating-point dtype of each tensor; defaults to float32.
        device: Optional TensorFlow device string (e.g. ``"/CPU:0"``). When
            None, TensorFlow's default placement is used, as before.

    Returns:
        A tuple ``(x, y)`` of independently sampled tensors with values
        in [0, 1).
    """

    def _sample():
        # One draw per operand, sharing shape and dtype.
        return (
            tf.random.uniform(shape, dtype=dtype),
            tf.random.uniform(shape, dtype=dtype),
        )

    if device is None:
        return _sample()
    with tf.device(device):
        return _sample()


x, y = create_tf_tensors()

In [7]:
%%timeit

# NOTE(review): x and y were created under TensorFlow's default placement
# (presumably the GPU here). This CPU-pinned op may therefore also time a
# device-to-host copy of both inputs on every loop. Confirm before comparing.
with tf.device("/CPU:0"):
    tf.multiply(x, y)

21.5 ms ± 561 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
%%timeit

# Same elementwise product, pinned to the first GPU.
with tf.device("/GPU:0"):
    tf.multiply(x, y)

5.29 ms ± 5.97 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## JAX


In [9]:
import jax
import jax.numpy as jnp

In [10]:
# Arrays are placed on JAX's default device (the Metal GPU in this run).
def create_jax_tensors(shape=(10000, 10000), dtype=jnp.float32, seed=0):
    """Create two uniform-random arrays for the JAX benchmark.

    Args:
        shape: Shape of each array; defaults to 10000 x 10000.
        dtype: Floating-point dtype of each array; defaults to float32.
        seed: Base PRNG seed. ``x`` is drawn with ``PRNGKey(seed)`` and ``y``
            with ``PRNGKey(seed + 1)``, so the default reproduces the
            original ``PRNGKey(0)`` / ``PRNGKey(1)`` pair exactly.

    Returns:
        A tuple ``(x, y)`` of arrays with values in [0, 1).
    """
    x = jax.random.uniform(jax.random.PRNGKey(seed), shape, dtype=dtype)
    y = jax.random.uniform(jax.random.PRNGKey(seed + 1), shape, dtype=dtype)
    return x, y


# NOTE(review): the objc warning recorded for this cell shows the TensorFlow
# and JAX Metal plugins both loaded in one kernel. objc reports that which
# MetalStreamWrapper gets used is undefined. Consider benchmarking each
# framework in a separate kernel.
[x, y] = create_jax_tensors()

objc[15775]: Class MetalStreamWrapper is implemented in both /Users/jkozik/projects/apple-silicon/env/lib/python3.8/site-packages/tensorflow-plugins/libmetal_plugin.dylib (0x3761dd638) and /Users/jkozik/projects/apple-silicon/env/lib/python3.8/site-packages/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib (0x490679228). One of the two will be used. Which one is undefined.


Metal device set to: Apple M4 Pro

systemMemory: 24.00 GB
maxCacheSize: 8.00 GB



In [11]:
%%timeit
# JAX dispatches asynchronously; block so the timing includes the computation.
(x * y).block_until_ready()

5.35 ms ± 35.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# Force JAX onto the CPU for comparison.
cpu_device = jax.devices("cpu")[0]

# Generate on the CPU, then commit the arrays to it. Arrays created under
# `jax.default_device` are uncommitted, and JAX runs computations on
# uncommitted data on the default device in effect at call time (the GPU
# once this context exits). `device_put` with an explicit device commits
# them, so later `x * y` runs on the CPU. The existing `create_jax_tensors`
# is reused rather than redefined.
with jax.default_device(cpu_device):
    x, y = create_jax_tensors()
x, y = jax.device_put((x, y), cpu_device)

In [13]:
%%timeit
# JAX dispatches asynchronously on the CPU backend too; wait for the result.
(x * y).block_until_ready()

61.1 ms ± 570 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
