In [1]:
import math
from functools import partial
from typing import Literal

import pyscenarios
from pyscenarios._sobol._kernel_numba import sobol_kernel as sobol_numba
from pyscenarios._sobol._kernel_numpy import sobol_kernel as sobol_numpy, _calc_c as _c_numpy
from pyscenarios._sobol._vmatrix import V

import numpy as np
import jax
import jax.lax
import jax.numpy as jnp
import numba
import scipy

# 64-bit mode: without it JAX canonicalises float64/int64 to 32-bit types, so
# e.g. the ``jnp.float64`` output of ``sobol_jax`` would not match the
# NumPy/Numba references under the ``strict=True`` (dtype-checking)
# assertions used throughout this notebook.
jax.config.update("jax_enable_x64", True)

In [2]:
# Benchmark problem size.
SAMPLES = (1 << 17) - 1  # 131_071 points
DIMENSIONS = 500  # Sobol dimensions, and width of the benchmark covariance matrix
CHUNKS = (1 << 13, -1)  # ``chunks=`` for the Dask runs: 8_192 rows per chunk

In [3]:
def get_jax_backend() -> Literal["cpu", "gpu", "tpu"]:
    """Return the platform on which new arrays are created with ``device=None``.

    This honours the ``jax_default_device`` config option. That option may be
    unset (``None``), a :class:`jax.Device`, or a platform/backend name string
    such as ``"cpu"``. When it is unset, the answer is
    :func:`jax.default_backend`.

    Passing a non-None ``device=`` does not work as of JAX 0.6.0:
    https://github.com/jax-ml/jax/issues/26000

    .. warning::

       When called from inside a ``@jax.jit`` function, this relies on the fact
       that ``jax.jit`` re-traces the Python code when the default device
       changes. That is an implementation detail and may break without warning
       in a future version of JAX.
    """
    device = jax.config.jax_default_device
    if device is None:
        return jax.default_backend()
    if isinstance(device, str):
        # A platform/backend name (e.g. "cpu", "gpu", "cuda"). Normalise it to
        # the platform of the first matching device, so that "cuda" -> "gpu".
        # ``device.platform`` would raise AttributeError on a plain string.
        return jax.devices(device)[0].platform
    return device.platform


@jax.jit
def f():
    """Return 1 if traced for the CPU backend, else 2, as a 0-d array.

    Used by the assertions below to check that ``jax.jit`` re-traces when the
    default device changes.
    """
    on_cpu = get_jax_backend() == "cpu"
    return jnp.asarray(1) if on_cpu else jnp.asarray(2)


@jax.jit
def g(x):
    """Return ``x + 1`` if traced for the CPU backend, else ``x + 2``."""
    offset = 2
    if get_jax_backend() == "cpu":
        offset = 1
    return x + offset

f.clear_cache()
g.clear_cache()

cpu = jax.devices("cpu")[0]
gpu = jax.devices("gpu")[0]

x_cpu = jnp.asarray(0, device=cpu)
x_gpu = jnp.asarray(0, device=gpu)

# Each case: (jax_default_device, value traced into f/g, platform of f()).
# With no explicit default device JAX's default backend is used; the last case
# assumes that is the GPU on the machine this notebook targets.
for default_device, want, f_platform in (
    (cpu, 1, "cpu"),
    (gpu, 2, "gpu"),
    (None, 2, "gpu"),
):
    jax.config.update("jax_default_device", default_device)
    assert f() == want
    assert f().device.platform == f_platform
    # g's output stays on its input's device, while the traced constant
    # follows the default device.
    for x_in, x_platform in ((x_cpu, "cpu"), (x_gpu, "gpu")):
        assert g(x_in) == want
        assert g(x_in).device.platform == x_platform

In [4]:
def _c_aapi(samples: int, *, device=None, xp):
    """c[i] = index from the right of the first zero bit of sample index i"""
    # c[i] = index from the right of the first zero bit of sample index i
    samples_rng = xp.arange(samples, device=device)
    c_max = int(math.log(samples, 2))
    out = xp.full(samples, c_max, device=device)
    for c in range(c_max + 1, -1, -1):
        mask = samples_rng & (1 << c) == 0
        out = xp.where(mask, c, out)
    return out


# jit-compiled JAX flavour of _c_aapi; samples and device must be static since
# they fix the output shape (via arange) and placement.
_c_aapi_jax = jax.jit(
    partial(_c_aapi, xp=jnp),
    static_argnames=("samples", "device"),
)

# Reference values from the pyscenarios NumPy kernel.
expect = _c_numpy(SAMPLES)
np.testing.assert_array_equal(_c_aapi(SAMPLES, xp=np), expect, strict=True)

for device in (cpu, gpu):
    jax.config.update("jax_default_device", device)
    result = _c_aapi_jax(SAMPLES)
    assert result.device == device
    np.testing.assert_array_equal(result, expect, strict=True)

In [5]:
# BENCHMARK
# Sobol ``c`` vector for SAMPLES points: NumPy versus JAX on CPU and GPU.

print("NumPy")
%timeit _c_aapi(SAMPLES, xp=np)

# Clear the jit cache so that the first JAX timing for each device includes
# tracing and compilation.
_c_aapi_jax.clear_cache()
for device in (cpu, gpu):
    print(f"\nJAX {device.platform}")
    jax.config.update("jax_default_device", device)
    # Single cold run (includes compilation for this default device).
    %timeit -n 1 -r 1 _c_aapi_jax(SAMPLES).block_until_ready()
    # Warm runs. block_until_ready() waits for JAX's asynchronous dispatch so
    # that the computation itself is timed.
    %timeit _c_aapi_jax(SAMPLES).block_until_ready()

NumPy
2.91 ms ± 11.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

JAX cpu
50.4 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
75.3 μs ± 4.63 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

JAX gpu
92.8 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
41.6 μs ± 306 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
@partial(jax.jit, static_argnames=("samples", "dimensions", "d0"))
def sobol_jax(samples: int, dimensions: int, d0: int = 0):
    """Sobol sequence in JAX, using the pyscenarios direction numbers ``V``.

    Parameters
    ----------
    samples:
        Number of points (rows of the output). Static under jit.
    dimensions:
        Number of Sobol dimensions (columns of the output). Static.
    d0:
        Index of the first row of ``V`` to use, i.e. the output covers
        dimensions ``d0 .. d0 + dimensions - 1``. Static.

    Returns
    -------
    float64 array of shape ``(samples, dimensions)``: the integer Sobol states
    divided by ``2**32``.

    Raises
    ------
    ValueError
        If the requested dimension range does not fit in ``V``. Slicing alone
        would silently return fewer columns than requested.
    """
    V_all = jnp.asarray(V)
    n_available = V_all.shape[0]
    if d0 < 0 or dimensions < 0 or d0 + dimensions > n_available:
        raise ValueError(
            f"Sobol dimensions [{d0}, {d0 + dimensions}) are out of range; "
            f"the direction-number table V has {n_available} dimensions"
        )

    c = _c_aapi(samples, xp=jnp)
    # Rows of VT are indexed by bit position c, columns by dimension.
    VT = V_all[d0:d0 + dimensions, :].T
    # Row i holds the direction numbers for bit c[i]. A running XOR over the
    # rows then yields the Sobol states (Gray-code construction).
    states = jnp.take(VT, c, axis=0)

    # FIXME fairly slow on CPU and very slow on GPU
    # https://github.com/jax-ml/jax/issues/28097
    # states = jnp.bitwise_xor.accumulate(states, axis=0)
    # The speedups quoted below are presumably relative to that accumulate.
    if get_jax_backend() == "cpu":
        # 1.2x faster on CPU and 1.3x faster on GPU
        states = jax.lax.fori_loop(
            1,
            samples,
            lambda i, x: x.at[i, :].set(x[i - 1, :] ^ x[i, :]),
            states,
        )
    else:  # gpu, tpu
        # 220x faster on GPU, but 2.5x slower on CPU
        states = jax.lax.associative_scan(jnp.bitwise_xor, states)

    return jnp.astype(states, jnp.float64) / 2**32


# Small cross-check against the pyscenarios Numba kernel on each device.
expect = sobol_numba(1023, 4, 0, 0)

for device in [cpu, gpu]:
    jax.config.update("jax_default_device", device)
    jax_points = sobol_jax(1023, 4)
    assert jax_points.device == device
    np.testing.assert_allclose(jax_points, expect, strict=True)

In [7]:
# BENCHMARK
# Full Sobol generation (SAMPLES x DIMENSIONS): pyscenarios Numba, Dask+Numba
# and NumPy kernels versus sobol_jax on CPU and GPU.

print("Numba")
%timeit pyscenarios.sobol((SAMPLES, DIMENSIONS))

print("\nDask+Numba")
%timeit pyscenarios.sobol((SAMPLES, DIMENSIONS), chunks=CHUNKS).compute()

print("\nNumPy")
%timeit sobol_numpy(SAMPLES, DIMENSIONS, 0, 0)

# Clear the jit cache so that the first JAX timing for each device includes
# tracing and compilation.
sobol_jax.clear_cache()
for device in (cpu, gpu):
    print(f"\nJAX {device.platform}")
    jax.config.update("jax_default_device", device)
    # Single cold run (includes compilation for this default device).
    %timeit -n 1 -r 1 sobol_jax(SAMPLES, DIMENSIONS).block_until_ready()
    # Warm runs. block_until_ready() waits for JAX's asynchronous dispatch so
    # that the computation itself is timed.
    %timeit sobol_jax(SAMPLES, DIMENSIONS).block_until_ready()

Numba
95.7 ms ± 832 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Dask+Numba
5.65 s ± 90.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

NumPy
645 ms ± 26.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX cpu
256 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
194 ms ± 2.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

JAX gpu
5.83 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
5.52 ms ± 122 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
def generate_positive_definite_cov_matrix(n: int, rng=None) -> np.ndarray:
    """Generate a random ``n x n`` positive definite matrix with unit diagonal.

    Parameters
    ----------
    n:
        Width of the matrix.
    rng:
        Optional seed or ``numpy.random.Generator`` (anything accepted by
        ``numpy.random.default_rng``) for reproducible output. When None,
        NumPy's legacy global RNG is used, which ``np.random.seed`` can seed.

    Returns
    -------
    Exactly symmetric float64 array of shape ``(n, n)`` whose diagonal is 1
    up to rounding.
    """
    if rng is None:
        A = np.random.uniform(-1, 1, size=(n, n))
    else:
        A = np.random.default_rng(rng).uniform(-1, 1, size=(n, n))
    # Symmetrize; off-diagonal entries stay within [-1, 1].
    A = (A + A.T) / 2
    # Adding n * I makes the matrix diagonally dominant, and hence (almost
    # surely) positive definite. Each diagonal entry is >= n - 1, while the
    # n - 1 off-diagonal entries of a row sum to at most n - 1 in absolute value.
    A += n * np.eye(n)
    # Rescale to a unit diagonal: cov[i, j] = A[i, j] / sqrt(A[i, i] * A[j, j]).
    # Scaling elementwise by outer(d, d) is O(n**2), unlike two dense matmuls
    # with a diagonal matrix, and it keeps the result exactly symmetric.
    d = 1 / np.sqrt(np.diag(A))
    return A * np.outer(d, d)

# Seed NumPy's global RNG, which generate_positive_definite_cov_matrix draws
# from by default. This makes the matrix below, and the correctness checks that
# use it, reproducible under Restart & Run All.
np.random.seed(0)
cov = generate_positive_definite_cov_matrix(3)
cov

array([[ 1.        , -0.09502161, -0.107505  ],
       [-0.09502161,  1.        , -0.09512254],
       [-0.107505  , -0.09512254,  1.        ]])

In [9]:
import os
import sys
from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp
import scipy.special

# Number of CPUs this process may run on, used as the upper bound for the
# worker thread pools. os.process_cpu_count() / os.cpu_count() return None when
# undetermined; fall back to 1 so that min(CPU_COUNT, ...) never compares None
# with an int.
if sys.version_info >= (3, 13):
    CPU_COUNT = os.process_cpu_count() or 1
elif hasattr(os, "sched_getaffinity"):
    # Same source process_cpu_count() uses where available: honours the CPU
    # affinity mask (e.g. taskset), not just the machine-wide CPU count.
    CPU_COUNT = len(os.sched_getaffinity(0)) or 1
else:
    CPU_COUNT = os.cpu_count() or 1


def apply_parallel_numpy_ufunc(ufunc, *args):
    """Evaluate a NumPy/SciPy ufunc from JAX code, multithreaded on the host.

    The inputs are broadcast together and cast to their common JAX result
    dtype. The ufunc is then evaluated on the host through
    :func:`jax.pure_callback`: the flattened inputs are split into contiguous
    chunks, each processed by a thread pool worker.

    Parameters
    ----------
    ufunc:
        Elementwise function taking ``len(args)`` array arguments and an
        ``out=`` keyword, e.g. ``scipy.special.betainc``.
    *args:
        Broadcastable array-likes or (possibly traced) JAX arrays.

    Returns
    -------
    JAX array with the broadcast shape of ``args`` and their common dtype.

    Raises
    ------
    Any exception raised by ``ufunc`` in a worker thread is re-raised inside
    the callback, so JAX reports it as a runtime error instead of returning
    uninitialised data.
    """
    shape = jnp.broadcast_shapes(*(jnp.shape(x) for x in args))
    dtype = jnp.result_type(*args)
    args = tuple(jnp.astype(x, dtype, copy=False) for x in args)
    out_type = jax.ShapeDtypeStruct(shape, dtype)

    def callback(*host_args):
        # Broadcast on the host and flatten, so every input lines up element by
        # element with the flat output buffer.
        flat_args = [
            np.reshape(a, -1)
            for a in np.broadcast_arrays(*(np.asarray(x) for x in host_args))
        ]
        out = np.empty_like(flat_args[0])

        # At least ~10_000 elements per thread, at most one thread per CPU.
        n_chunks = max(1, min(CPU_COUNT, out.size // 10_000))
        # Contiguous chunks (views into the flat buffers), so each thread writes
        # its own block of memory. Interleaved ``i::n`` slices would put
        # neighbouring elements in different threads (false sharing).
        arg_chunks = [np.array_split(a, n_chunks) for a in flat_args]
        out_chunks = np.array_split(out, n_chunks)

        with ThreadPoolExecutor(n_chunks) as pool:
            futures = [
                pool.submit(
                    ufunc, *(chunks[i] for chunks in arg_chunks), out=out_chunks[i]
                )
                for i in range(n_chunks)
            ]
        # Propagate worker exceptions. Discarding the futures would silently
        # return the uninitialised ``out`` buffer.
        for future in futures:
            future.result()

        return np.reshape(out, shape)

    return jax.pure_callback(callback, out_type, *args)


# jax.scipy.special.betainc exists, but it's very slow
# https://github.com/jax-ml/jax/issues/28547
def betainc(*args):
    """``scipy.special.betainc`` evaluated on the host via a JAX callback."""
    return apply_parallel_numpy_ufunc(scipy.special.betainc, *args)


# jax.scipy.special.gammaincinv does not exist
# https://github.com/jax-ml/jax/issues/5350
def gammaincinv(*args):
    """``scipy.special.gammaincinv`` evaluated on the host via a JAX callback."""
    return apply_parallel_numpy_ufunc(scipy.special.gammaincinv, *args)


# jax.scipy.special.stdtr does not exist
def stdtr(df, t):
    """Student's t cumulative distribution function with ``df`` degrees of freedom.

    Uses the regularized incomplete beta function: with
    ``x = df / (t**2 + df)``, the mass below ``-|t|`` is ``I_x(df/2, 1/2) / 2``.
    """
    half_tail = betainc(df / 2, 0.5, df / (t ** 2 + df)) / 2
    return jnp.where(t < 0, half_tail, 1 - half_tail)


# Check the host-callback special functions against SciPy on every device.
x = np.linspace(0, 1, 101)

for func, scipy_func, caller in (
    (betainc, scipy.special.betainc, lambda fn, xs: fn(4.5, 0.5, xs)),
    (gammaincinv, scipy.special.gammaincinv, lambda fn, xs: fn(4.5, xs)),
    (stdtr, scipy.special.stdtr, lambda fn, xs: fn(9., xs)),
):
    expect = caller(scipy_func, x)
    for device in (cpu, gpu):
        jax.config.update("jax_default_device", device)
        np.testing.assert_allclose(caller(jax.jit(func), x), expect, strict=True)

In [10]:
@partial(jax.jit, static_argnames=("samples", "seed"))
def copula_jax(cov, df, samples, seed):
    """Gaussian (``df=None``) or Student t copula driven by a Sobol sequence.

    Parameters
    ----------
    cov:
        Square covariance matrix; its Cholesky factor correlates the normal
        draws.
    df:
        None for a Gaussian copula, otherwise the degrees of freedom of the
        t copula.
    samples:
        Number of scenarios (rows of the output). Static under jit.
    seed:
        First Sobol dimension to use (``d0`` of ``sobol_jax``). Static.

    Returns
    -------
    Array of shape ``(samples, cov.shape[0])``. For the Gaussian copula these
    are the correlated normal draws. For the t copula they are mapped back to
    standard-normal marginals.
    """
    n_dims = cov.shape[0]
    chol = jnp.linalg.cholesky(cov)
    is_gaussian = df is None

    # The t copula draws one extra Sobol dimension, used for the chi-squared
    # mixing variable.
    sobol = sobol_jax(samples, n_dims if is_gaussian else n_dims + 1, d0=seed)
    normals = jax.scipy.stats.norm.ppf(sobol if is_gaussian else sobol[:, :-1])
    correlated = (chol @ normals.T).T
    if is_gaussian:
        return correlated

    # 2 * gammaincinv(df / 2, r) is the same as scipy.stats.chi2.ppf(r, df).
    chi2 = 2 * gammaincinv(df / 2, sobol[:, -1:])
    t_draws = jnp.sqrt(df / chi2) * correlated
    # Map the t-distributed draws to N(0, 1): norm.ppf(t.cdf(z, df)).
    return jax.scipy.stats.norm.ppf(stdtr(df, t_draws))


expect_g = pyscenarios.gaussian_copula(cov, samples=10, rng="Sobol")
expect_t = pyscenarios.t_copula(cov, df=9, samples=10, rng="Sobol")

# Both copulas on both devices; the pyscenarios defaults correspond to seed=0.
for device in (cpu, gpu):
    jax.config.update("jax_default_device", device)
    for df, expect in ((None, expect_g), (9, expect_t)):
        actual = copula_jax(cov, df=df, samples=10, seed=0)
        assert actual.device == device
        np.testing.assert_allclose(actual, expect, strict=True)

In [12]:
# BENCHMARK

cov = generate_positive_definite_cov_matrix(DIMENSIONS)
copula_jax.clear_cache()

for label, df, np_func in (
    ("Gaussian Copula", None, pyscenarios.gaussian_copula),
    ("T Copula", 9, partial(pyscenarios.t_copula, df=9)),
):
    print(f"Numba {label}")
    %timeit -n 1 -r 1 np_func(cov, samples=SAMPLES, rng="Sobol")

    print(f"\nDask+Numba {label}")
    %timeit -n 1 -r 1 np_func(cov, samples=SAMPLES, rng="Sobol", chunks=CHUNKS).compute()

    for device in (cpu, gpu):
        print(f"\nJAX {label} {device.platform}")
        jax.config.update("jax_default_device", device)
        %timeit -n 1 -r 1 copula_jax(cov, df, SAMPLES, DIMENSIONS).block_until_ready()
        %timeit -n 1 -r 1 copula_jax(cov, df, SAMPLES, DIMENSIONS).block_until_ready()

    print()

Numba Gaussian Copula
5.57 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

Dask+Numba Gaussian Copula
8.39 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

JAX Gaussian Copula cpu
533 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
487 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

JAX Gaussian Copula gpu
4.58 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
217 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

Numba T Copula
29.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

Dask+Numba T Copula
5.28 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

JAX T Copula cpu
2.05 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
1.9 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

JAX T Copula gpu
4.06 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
1.66 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

