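# Element-wise multiplication benchmark: CPU vs. GPU

Times `x * y` on two 10000 × 10000 `float32` random tensors with PyTorch, TensorFlow and JAX, on CPU and GPU backends. GPU backends execute asynchronously, so a GPU timing is only meaningful when the timed cell waits for the result to be computed.

## PyTorch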


In [1]:
import torch


def create_torch_tensors(device, shape=(10000, 10000), dtype=torch.float32):
    """Create two uniform-random tensors on ``device`` for benchmarking.

    Values come from ``torch.rand`` (uniform on [0, 1)). No seed is set, so the
    contents differ between calls.

    Args:
        device: Target device, either a ``torch.device`` or a device string
            such as ``"cpu"`` or ``"mps"``.
        shape: Shape of each tensor. Defaults to (10000, 10000).
        dtype: Floating-point dtype. Defaults to ``torch.float32``.

    Returns:
        A pair ``(x, y)`` of tensors with the given shape and dtype on ``device``.
    """
    # Sample directly on the target device. This avoids allocating a host tensor
    # for each input and then copying it host-to-device.
    x = torch.rand(shape, dtype=dtype, device=device)
    y = torch.rand(shape, dtype=dtype, device=device)
    return x, y

In [2]:
# Baseline run: both operands are allocated in host memory.
device = torch.device("cpu")
x, y = create_torch_tensors(device=device)

In [3]:
%%timeit
x * y

20.5 ms ± 499 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
# GPU run through PyTorch's "mps" (Metal Performance Shaders) backend.
device = torch.device("mps")
x, y = create_torch_tensors(device=device)

In [5]:
%%timeit
x * y

128 µs ± 53.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## TensorFlow


In [6]:
import tensorflow as tf


def create_tf_tensors(shape=(10000, 10000), dtype=tf.float32, device=None):
    """Create two uniform-random tensors for benchmarking.

    Values come from ``tf.random.uniform`` with its default [0, 1) range. No
    seed is set.

    Args:
        shape: Shape of each tensor. Defaults to (10000, 10000).
        dtype: Floating-point dtype. Defaults to ``tf.float32``.
        device: Optional device string such as ``"/CPU:0"`` or ``"/GPU:0"``.
            When None, no device scope is opened and TensorFlow chooses the
            placement.

    Returns:
        A pair ``(x, y)`` of tensors with the given shape and dtype.
    """

    def _sample():
        # One independent draw per operand.
        return (
            tf.random.uniform(shape, dtype=dtype),
            tf.random.uniform(shape, dtype=dtype),
        )

    if device is None:
        return _sample()
    # Pin creation to `device`. A later `with tf.device(...)` timing on the same
    # device then does not also have to move the inputs across devices.
    with tf.device(device):
        return _sample()


# Created without a device scope, so TensorFlow picks the placement.
# Pass device="/CPU:0" or "/GPU:0" to pin the inputs to the device being timed.
x, y = create_tf_tensors()

In [7]:
%%timeit

with tf.device("/CPU:0"):
    x * y

19.6 ms ± 450 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
%%timeit

with tf.device("/GPU:0"):
    x * y

3.34 ms ± 54.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## JAX


In [9]:
import os

# Backend for JAX to use. This must be set before `jax` is first imported in
# this kernel. Set it to None to let JAX choose its default backend (e.g. a
# GPU) instead of restricting it to the CPU.
JAX_PLATFORM = "cpu"

if JAX_PLATFORM is None:
    # os.environ persists for the life of the kernel process. Clear any value
    # left over from an earlier run so the toggle actually takes effect.
    os.environ.pop("JAX_PLATFORMS", None)
else:
    os.environ["JAX_PLATFORMS"] = JAX_PLATFORM

In [12]:
import os

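# os.environ["JAX_PLATFORMS"] = "cpu"
# NOTE: commenting the line above out does not undo an earlier run of the cell
# that sets JAX_PLATFORMS="cpu". os.environ persists for the life of the kernel
# process, so JAX stays restricted to the CPU until the kernel is restarted or
# the variable is removed before `jax` is imported.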

In [13]:
import jax
import jax.numpy as jnp


def create_jax_tensors(shape=(10000, 10000), dtype=jnp.float32):
    """Create two uniform-random arrays for benchmarking.

    The fixed PRNG keys 0 and 1 make the contents reproducible across runs.

    Args:
        shape: Shape of each array. Defaults to (10000, 10000).
        dtype: Floating-point dtype. Defaults to ``jnp.float32``.

    Returns:
        A pair ``(x, y)`` of arrays with values in [0, 1).
    """
    x = jax.random.uniform(jax.random.PRNGKey(0), shape, dtype=dtype)
    y = jax.random.uniform(jax.random.PRNGKey(1), shape, dtype=dtype)
    # JAX dispatches work asynchronously. Wait until both arrays are
    # materialized so their generation cost does not leak into the first timed
    # iteration of a benchmark.
    x.block_until_ready()
    y.block_until_ready()
    return x, y


# Benchmark inputs are built once, outside the timed cell.
(x, y) = create_jax_tensors()

In [14]:
%%timeit
x * y

3.4 ms ± 54.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
