# Runtime comparison for various implementations of `approximation_error()`


In [1]:
# Load IPython's autoreload extension so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before each execution.
%autoreload 2

In [2]:
from pathlib import Path

import cupy as cp
import jax
import jax.numpy as jnp
import numpy as np

from mat_dnf.jax import losses as jlosses
from mat_dnf.numpy import losses

# JAX uses 32-bit floats/ints by default; this switches on 64-bit support.
# NOTE(review): presumably so the JAX runs keep the precision of the NumPy
# arrays they are compared against — confirm.
jax.config.update("jax_enable_x64", True)

In [3]:
# Choose which stored sample to benchmark and build the path to its directory.
sample_dir = "sample_0"
array_dir = Path("../tests/resources/approximation_error", sample_dir)
# The path is relative to the working directory, so stop early if it cannot be found.
assert array_dir.exists()

### NumPy


In [4]:
# Only the first entry of each walk is needed: (directory, sub-directory names, file names).
input_dir, _, input_filenames = next(Path(array_dir, "input").walk())
output_dir, _, output_filenames = next(array_dir.walk())

# Key each loaded array by its file name without the extension, so the input
# dictionary can later be unpacked as keyword arguments.
input_arrays = dict(
    (Path(name).stem, np.load(input_dir / name)) for name in input_filenames
)
output_arrays = dict(
    (Path(name).stem, np.load(output_dir / name)) for name in output_filenames
)

In [None]:
%timeit er_k_th, c_th, d_k_th = losses.approximation_error(**input_arrays)

### CuPy


In [6]:
# Reuse the host arrays that were already read from disk for the NumPy run
# instead of walking the directories and loading every file a second time
# (the JAX cells derive their inputs from `input_arrays` in the same way).
#
# The two split arguments are passed as plain Python scalars, taken straight
# from the NumPy values; this avoids copying them to the GPU only to pull them
# back with `.item()`. Every other array is copied to the current CuPy device.
cupy_input_arrays = {
    k: v.item() if k in ("split_c", "split_d_k") else cp.asarray(v)
    for k, v in input_arrays.items()
}
cupy_output_arrays = {k: cp.asarray(v) for k, v in output_arrays.items()}

In [None]:
%timeit er_k_th, c_th, d_k_th = losses.approximation_error(**cupy_input_arrays)

### JAX (CPU)


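(The timing deliberately includes compilation, since that is the realistic use case. Note that `%%timeit` runs the cell many times, so if `approximation_error` is JIT-compiled, the compilation is presumably paid only on the first run and amortised over the rest.)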


In [None]:
# Take the first device reported by the CPU backend.
# NOTE(review): the original comment read "ALL cpu" — presumably JAX exposes
# all host cores through this single device; confirm.
cpu_device = jax.devices(backend="cpu")[0]
cpu_device

In [9]:
# Build the keyword arguments for the JAX implementation, pinned to the CPU.
#
# * Arrays are handed to `jax.device_put` as NumPy arrays so they go straight
#   to `cpu_device`. Wrapping them in `jnp.array` first would materialise each
#   one on the default device (the GPU, when one is present) before copying it
#   back to the CPU.
# * The two split arguments are passed as plain Python scalars, read directly
#   from the NumPy values rather than via a temporary device array.
input_jax_arrays_cpu = {
    k: v.item()
    if k in ("split_c", "split_d_k")
    else jax.device_put(v, device=cpu_device)
    for k, v in input_arrays.items()
}

In [None]:
%%timeit
er_k_th, c_th, d_k_th = jlosses.approximation_error(**input_jax_arrays_cpu)
er_k_th.block_until_ready()
c_th.block_until_ready()
d_k_th.block_until_ready()

### JAX (GPU)


Note: this machine has only one GPU; use `jax.sharding` to spread the work across multiple GPUs.


In [None]:
# Take the first device reported by the GPU backend.
gpu_device = jax.devices(backend="gpu")[0]
gpu_device

In [12]:
# Build the keyword arguments for the JAX implementation, pinned to the GPU:
# the two split arguments become Python ints, every other entry becomes a JAX
# array placed on `gpu_device`.
input_jax_arrays_gpu = {
    name: (
        int(value)
        if name in ("split_c", "split_d_k")
        else jax.device_put(jnp.array(value), device=gpu_device)
    )
    for name, value in input_arrays.items()
}

In [None]:
%%timeit
er_k_th, c_th, d_k_th = jlosses.approximation_error(**input_jax_arrays_gpu)
er_k_th.block_until_ready()
c_th.block_until_ready()
d_k_th.block_until_ready()