# PyTorch

In [1]:
import torch

# Record the installed PyTorch version next to the benchmark numbers.
print(f"{torch.__version__}")

2.2.2


In [2]:
import torch

def create_torch_tensors(device, shape=(10000, 10000), dtype=torch.float32, seed=None):
    """Create two random tensors on ``device`` for an elementwise benchmark.

    The values are drawn on the CPU with ``torch.rand`` (uniform on [0, 1))
    and then moved to ``device``. Because sampling always happens on the
    CPU, a given ``seed`` yields the same values for every target device.

    Args:
        device: Target device, e.g. ``torch.device("cpu")`` or
            ``torch.device("mps")`` (a device string also works).
        shape: Shape of each tensor. Defaults to the original 10000 x 10000
            benchmark size.
        dtype: Floating-point dtype of both tensors.
        seed: Optional seed for a dedicated CPU generator. ``None`` keeps the
            previous behavior of drawing from PyTorch's global RNG.

    Returns:
        A tuple ``(x, y)`` of tensors with the requested shape and dtype,
        located on ``device``.
    """
    # A private generator makes runs reproducible without touching global RNG state.
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)

    x = torch.rand(shape, dtype=dtype, generator=generator)
    y = torch.rand(shape, dtype=dtype, generator=generator)

    return x.to(device), y.to(device)

In [3]:
# CPU baseline: build both benchmark operands in host memory.
device = torch.device("cpu")
(x, y) = create_torch_tensors(device=device)

In [4]:
%%timeit
# CPU baseline: elementwise product of x and y; the result is discarded.
x * y

25.4 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
# Apple-GPU (MPS backend) operands, for comparison with the CPU baseline.
device = torch.device("mps")
(x, y) = create_torch_tensors(device=device)

In [7]:
%%timeit
x * y

6.86 ms ± 108 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# TensorFlow

In [8]:
import tensorflow

# Record the installed TensorFlow version next to the benchmark numbers.
print(f"{tensorflow.__version__}")



2.16.1


In [9]:
import tensorflow as tf

def create_tf_tensors(shape=(10000, 10000), dtype=None, device=None):
    """Create two random uniform tensors for an elementwise benchmark.

    Args:
        shape: Shape of each tensor. Defaults to the original 10000 x 10000
            benchmark size.
        dtype: Floating-point dtype of both tensors. ``None`` means
            ``tf.float32``, the original dtype.
        device: Optional device string such as ``"/CPU:0"`` or ``"/GPU:0"``.
            When given, the tensors are created inside ``tf.device(device)``.
            A benchmark pinned to that device then should not also time
            cross-device copies. ``None`` keeps TensorFlow's default
            placement, which is the original behavior.

    Returns:
        A tuple ``(x, y)`` of tensors sampled uniformly from [0, 1).
    """
    import contextlib

    # Resolve the default lazily so defining this function does not require `tf`.
    if dtype is None:
        dtype = tf.float32

    placement = contextlib.nullcontext() if device is None else tf.device(device)
    with placement:
        x = tf.random.uniform(shape, dtype=dtype)
        y = tf.random.uniform(shape, dtype=dtype)

    return x, y

# Operands use TensorFlow's default placement (no tf.device scope).
(x, y) = create_tf_tensors()

In [10]:
%%timeit

# Pin the multiply to the CPU.
# NOTE(review): x and y were created without a tf.device scope, so they may
# live on the GPU. If so, each loop also times copying them to the CPU.
# Confirm with x.device, or create the operands under tf.device("/CPU:0").
with tf.device("/CPU:0"):
    x * y

29.4 ms ± 3.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit

# Pin the multiply to the first visible GPU.
# NOTE(review): GPU work may complete asynchronously. If so, this measures
# dispatch rather than completion. Confirm, e.g. by comparing with a variant
# that forces the result to the host via .numpy().
with tf.device("/GPU:0"):
    x * y

6.79 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# JAX

In [13]:
import jax

# Record the installed JAX version next to the benchmark numbers.
print(f"{jax.__version__}")

0.4.26


In [15]:
import jax
import jax.numpy as jnp


def create_jax_tensors(shape=(10000, 10000), dtype=jnp.float32, seed=0):
    """Create two random uniform arrays for an elementwise benchmark.

    Args:
        shape: Shape of each array. Defaults to the original 10000 x 10000
            benchmark size.
        dtype: Floating-point dtype of both arrays.
        seed: Base PRNG seed. ``x`` is drawn with ``PRNGKey(seed)`` and ``y``
            with ``PRNGKey(seed + 1)``, so the default reproduces the original
            keys 0 and 1.

    Returns:
        A tuple ``(x, y)`` of arrays sampled uniformly from [0, 1).
    """
    x = jax.random.uniform(jax.random.PRNGKey(seed), shape, dtype=dtype)
    y = jax.random.uniform(jax.random.PRNGKey(seed + 1), shape, dtype=dtype)

    return x, y


# Operands are placed on JAX's default device.
(x, y) = create_jax_tensors()

objc[74949]: Class MetalStreamWrapper is implemented in both /Users/jcolamendy/python/tutorials/ml-tutorials/venv/lib/python3.9/site-packages/tensorflow-plugins/libmetal_plugin.dylib (0x3ae43d638) and /Users/jcolamendy/python/tutorials/ml-tutorials/venv/lib/python3.9/site-packages/jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib (0x46aee51f8). One of the two will be used. Which one is undefined.
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1716489511.188432 3855892 service.cc:145] XLA service 0x600001750300 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716489511.188455 3855892 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716489511.190299 3855892 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716489511.190312 3855892 mps_client.cc:384] XLA backend will use up to 8105328640 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M1 Pro


In [17]:
%%timeit
x * y

6.78 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
