In [None]:
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp

import jeig

In [None]:
# Benchmark configuration: every (batch size, matrix size) pair is timed
# `repeats` times for each eigendecomposition backend.
batch_size = [1, 16]
matrix_size = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
repeats = 3
backends = [
    jeig.EigBackend.JAX,
    jeig.EigBackend.MAGMA,
    jeig.EigBackend.NUMPY,
    jeig.EigBackend.SCIPY,
    jeig.EigBackend.TORCH,
]


def _make_eig_fn(backend):
    """Return a jitted function that calls `jeig.eig` with a fixed `backend`.

    Building the lambda inside a factory binds `backend` as a function
    argument. A lambda created directly in a loop would look up the loop
    variable when it is traced, not when it is defined, so every jitted
    function would depend on whatever `backend` happens to be at call time.
    """
    return jax.jit(lambda x: jeig.eig(x, backend=backend))


# One jitted eig function per backend.
fns = {backend: _make_eig_fn(backend) for backend in backends}

# results[backend][i, j, r] is the elapsed time in seconds for
# batch_size[i], matrix_size[j] and repeat r.
results = {
    backend: onp.zeros((len(batch_size), len(matrix_size), repeats))
    for backend in backends
}

# The base key is loop-invariant; each repeat folds its index into it, so every
# backend sees the same input for a given (shape, repeat).
base_key = jax.random.PRNGKey(0)

for i, bs in enumerate(batch_size):
    for j, ms in enumerate(matrix_size):
        shape = (bs, ms, ms)
        for backend in backends:
            for repeat in range(repeats):
                key = jax.random.fold_in(base_key, repeat)
                # JAX dispatches work asynchronously: wait for the input to be
                # materialized so its generation is not counted in the timing.
                matrix = jax.block_until_ready(jax.random.normal(key, shape))
                # NOTE: the first call for a new input shape may also include
                # jit trace/compile time, so the first repeat can be slower.
                t0 = time.perf_counter()
                jax.block_until_ready(fns[backend](matrix))
                results[backend][i, j, repeat] = time.perf_counter() - t0

In [None]:
# One log-log panel per batch size, showing the best (minimum over repeats)
# elapsed time of each backend as a function of matrix size.
n_panels = len(batch_size)
# squeeze=False always returns a 2-D array of axes, so indexing also works
# when there is only a single batch size.
fig, axs = plt.subplots(1, n_panels, figsize=(4.5 * n_panels, 3), squeeze=False)
axs = axs[0]

for i, ax in enumerate(axs):
    for backend in backends:
        best_time = onp.amin(results[backend][i, :, :], axis=-1)
        ax.loglog(matrix_size, best_time, "o-", label=backend)
    ax.set_title(f"batch_size={batch_size[i]}")
    ax.set_xlabel("Matrix size")

# A single legend, anchored to the right-most panel, serves all panels.
axs[-1].legend(bbox_to_anchor=(1, 1))
axs[0].set_ylabel("Elapsed time (s)")
plt.tight_layout()