In [1]:
import math
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import quadax
from scipy.special import beta, betainc

# Double precision is required: the comparisons below use rtol down to 1e-11.
jax.config.update("jax_enable_x64", True)
cpu = jax.devices("cpu")[0]
# Prefer the first GPU, but fall back to the CPU when no GPU backend is
# available (jax.devices raises RuntimeError in that case) so the notebook
# still runs end to end on CPU-only machines.
try:
    gpu = jax.devices("gpu")[0]
except RuntimeError:
    gpu = cpu
jax.config.update("jax_default_device", gpu)

In [None]:
# Shape parameters of the regularized incomplete beta I_x(a, b). With b < 1
# the integrand factor (1 - t)**(b - 1) is singular at t = 1.
a = 4.5
b = 0.5
# Evaluation points, uniform on [0, 1). The generator is seeded so that a fresh
# "Restart & Run All" reproduces every assertion, timing and plot; lower
# N_POINTS (e.g. to 1000) for quicker iteration.
SEED = 0
N_POINTS = 65535
rng = np.random.default_rng(SEED)
x_np = rng.random((N_POINTS, 2))
x_jnp = jnp.asarray(x_np)
x_np[:5]

In [None]:
# Reference values: SciPy's regularized incomplete beta I_x(a, b) evaluated
# elementwise on the NumPy copy of the points. The JAX/quadax variants in this
# notebook are checked against this array.
baseline = betainc(a, b, x_np)
baseline

In [None]:
# SciPy reference value of the complete beta function B(a, b).
beta(a, b)

In [None]:
# JAX's B(a, b); the quadrature-based variants divide their integrals by this value.
jax.scipy.special.beta(a, b)

In [None]:
# Signed error of JAX's built-in betainc against the SciPy reference, reordered
# by increasing x so a plot shows how the error evolves across (0, 1).
deltas = jnp.ravel(jax.scipy.special.betainc(a, b, x_jnp) - jnp.asarray(baseline))[
    jnp.argsort(jnp.ravel(x_jnp))
]
jnp.max(jnp.abs(deltas))

In [None]:
# Signed error of jax.scipy.special.betainc vs SciPy. `deltas` is ordered by
# increasing x, so index the series by the sorted points to get a real x axis
# instead of the sort rank.
ax = pd.Series(np.asarray(deltas), index=np.sort(x_np.reshape(-1))).plot()
ax.set(xlabel="x", ylabel="JAX - SciPy", title=f"jax betainc error, a={a}, b={b}");

In [None]:
@jax.jit
def quadax_betainc(a, b, x):
    """Regularized incomplete beta I_x(a, b) via one adaptive quadrature per element.

    Each element of ``x`` is integrated independently with ``quadax.quadgk`` over
    ``(0, x)`` (vectorized with ``jax.vmap``) and divided by ``B(a, b)``.

    Args:
        a, b: Shape parameters (scalars).
        x: Array of evaluation points, any shape.

    Returns:
        Array with the shape of ``x``. Exact 0 and 1 are returned unchanged;
        points outside [0, 1] map to NaN.
    """
    shape = x.shape
    flat = jnp.reshape(x, (-1, ))
    # Out-of-range points are replaced by NaN below anyway; clipping only keeps
    # invalid intervals (past the singularity at t = 1, or negative t) out of
    # the adaptive quadrature, matching the other variants in this notebook.
    upper = jnp.clip(flat, 0, 1)
    beta_ = jax.scipy.special.beta(a, b)
    res = jax.vmap(
        lambda xi: quadax.quadgk(
            lambda t, am1, bm1: t**am1 * (1 - t)**bm1,
            interval=(0, xi),
            args=(a - 1, b - 1),
        )[0]
    )(upper) / beta_
    res = jnp.where((flat == 0) | (flat == 1), flat, res)
    res = jnp.where((flat < 0) | (flat > 1), jnp.nan, res)
    return jnp.reshape(res, shape)

# Use the notebook's a, b and the precomputed SciPy baseline instead of repeated
# literals, so changing the parameters once keeps every check consistent.
# JAX's built-in betainc agrees with SciPy to 1e-10 relative ...
np.testing.assert_allclose(
    jax.scipy.special.betainc(a, b, x_jnp),
    baseline,
    rtol=1e-10,
)
# ... and the per-element adaptive quadrature to 1e-11.
np.testing.assert_allclose(
    quadax_betainc(a, b, x_jnp),
    baseline,
    rtol=1e-11,
)

In [None]:
%timeit -r 1 betainc(4.5, 0.5, x_np)
%timeit -r 1 jax.scipy.special.betainc(4.5, 0.5, x_jnp).block_until_ready()
%timeit -r 1 quadax_betainc(4.5, 0.5, x_jnp).block_until_ready()

In [None]:
@partial(jax.jit, static_argnames="ncols")
def to_sorted_rect(x, ncols):
    """Lay ``x`` out as an ``(nrows, ncols)`` matrix and sort every row.

    ``x`` is flattened; if its size is not a multiple of ``ncols`` the tail is
    padded with the first element so the matrix is full. ``ncols`` is static
    because it fixes the output shape.

    Returns:
        ``(rows_sorted, order)`` where ``order`` is the per-row argsort, so
        ``from_sorted_rect(rows_sorted, order, x.shape)`` restores ``x``.
    """
    flat = jnp.reshape(x, (-1, ))
    nrows = int(math.ceil(flat.size / ncols))
    n_pad = nrows * ncols - flat.size
    if n_pad > 0:
        filler = jnp.full((n_pad, ), flat[0], dtype=flat.dtype)
        flat = jnp.concatenate([flat, filler])
    grid = flat.reshape((nrows, ncols))
    order = jnp.argsort(grid, axis=1)
    return jnp.take_along_axis(grid, order, axis=1), order

@partial(jax.jit, static_argnames="shape")
def from_sorted_rect(x, idx, shape):
    """Undo ``to_sorted_rect``: unsort each row, drop the padding, restore ``shape``.

    Args:
        x: ``(nrows, ncols)`` values laid out in row-sorted order.
        idx: The per-row argsort returned by ``to_sorted_rect``.
        shape: Target shape (static, hashable); its size is the number of
            leading flat entries kept.
    """
    n_keep = math.prod(shape)
    inverse = jnp.argsort(idx, axis=1)
    unsorted = jnp.take_along_axis(x, inverse, axis=1)
    return jnp.ravel(unsorted)[:n_keep].reshape(shape)

# Round trip on the real data: every row comes out ascending, and un-sorting
# restores the original array exactly (shape and dtype included).
y, idx = to_sorted_rect(x_jnp, 50)
assert (y[:, :-1] <= y[:, 1:]).all()
x2 = from_sorted_rect(y, idx, x_jnp.shape)
np.testing.assert_array_equal(x_jnp, x2, strict=True)

In [None]:
@jax.jit
def quadax_betainc_breakpoints(a, b, x):
    """Regularized incomplete beta via adaptive quadrature with breakpoints.

    The flattened ``x`` is split into sorted rows of ``N_BREAKPOINTS`` values
    (``to_sorted_rect``). Each row, clipped to [0, 1] and prefixed with 0, is
    passed to ``quadax.quadgk`` as its interval, so every sorted value is a
    breakpoint. The per-row results are read from the quadrature's
    ``info["s_arr"]``, un-sorted, and divided by ``B(a, b)``.

    Returns:
        Array with the shape of ``x``. Exact 0 and 1 are returned unchanged;
        points outside [0, 1] map to NaN.
    """
    N_BREAKPOINTS = 50
    scratch, idx = to_sorted_rect(x, N_BREAKPOINTS)
    scratch = jnp.clip(scratch, 0, 1)
    # Prepend 0 so each row is the breakpoint list [0, x_(1), ..., x_(N)].
    scratch = jnp.concat(
        [jnp.zeros((scratch.shape[0], 1), dtype=x.dtype), scratch],
        axis=1,
    )

    # NOTE(review): the result is read from quadax's full_output
    # ``info["s_arr"]``. This assumes it holds the integral from 0 up to each
    # breakpoint, one entry per breakpoint and in breakpoint order. That is a
    # quadax internal detail -- TODO confirm against the installed version.
    y = jax.vmap(lambda interval:
        quadax.quadgk(
            lambda t, am1, bm1: t**am1 * (1 - t)**bm1,
            interval=interval,
            args=(a - 1, b - 1),
            full_output=True,
        )[1].info["s_arr"]
    )(scratch)

    y = from_sorted_rect(y, idx, x.shape)
    # Normalize *before* overriding the endpoints; otherwise x == 1 would
    # come out as 1 / B(a, b) instead of exactly 1.
    beta_ = jax.scipy.special.beta(a, b)
    y = y / beta_
    y = jnp.where((x == 0) | (x == 1), x, y)
    y = jnp.where((x < 0) | (x > 1), jnp.nan, y)
    return y

# Evaluate the breakpoint variant on the full sample; `actual` is kept for the
# mismatch inspection cells that follow.
actual = quadax_betainc_breakpoints(a, b, x_jnp)
np.testing.assert_allclose(actual, baseline, rtol=1e-10)

In [None]:
# Points where `actual` deviates from SciPy by more than 1e-6 relative.
# `min()` is only taken when there is at least one such point: NumPy's min of
# an empty array raises ValueError.
mask = np.abs(np.asarray(actual) - baseline) > 1e-6 * np.abs(baseline)
mismatched_x = x_np[mask]
mismatched_x.size, (mismatched_x.min() if mismatched_x.size else None)

In [None]:
# x values of the points flagged in `mask` (relative error above 1e-6).
x_np[mask]

In [None]:
# Computed values at the flagged points.
actual[mask]

In [None]:
# SciPy reference values at the same flagged points.
baseline[mask]

In [None]:
# Time the breakpoint variant; block_until_ready() waits for the device work.
%timeit -r 1 quadax_betainc_breakpoints(a, b, x_jnp).block_until_ready()

In [None]:
def to_sorted(x):
    """Flatten ``x`` and sort it.

    Returns:
        ``(sorted_values, order)`` with ``sorted_values == ravel(x)[order]``;
        pass ``order`` to ``from_sorted`` to undo the sort.
    """
    flat = jnp.ravel(x)
    order = jnp.argsort(flat)
    return flat[order], order

def from_sorted(x, idx, shape):
    """Undo ``to_sorted``: put sorted values back at their original positions.

    ``idx`` is the permutation returned by ``to_sorted``; its argsort is the
    inverse permutation. The result is reshaped to ``shape``.
    """
    restore = jnp.argsort(idx)
    original_order = jnp.take(x, restore)
    return original_order.reshape(shape)

# Round trip on the real data: values come out ascending, and un-sorting
# restores the original array exactly (shape and dtype included).
y, idx = to_sorted(x_jnp)
assert (y[:-1] <= y[1:]).all()
x2 = from_sorted(y, idx, x_jnp.shape)
np.testing.assert_array_equal(x_jnp, x2, strict=True)

In [None]:
@jax.jit
def quadax_betainc_cumulative(a, b, x):
    """Regularized incomplete beta from a cumulative sum of short integrals.

    All points are sorted. The integral between each pair of consecutive
    points (starting from 0) is computed with ``quadax.quadgk``, and the prefix
    sums give the integral from 0 to every point, divided by ``B(a, b)``.

    Returns:
        Array with the shape of ``x``. Exact 0 and 1 are returned unchanged;
        points outside [0, 1] map to NaN, as in the other variants.
    """
    # Clip before sorting: a negative point sorts first, and its NaN
    # sub-integral would otherwise be propagated by cumsum to every larger x.
    stop, idx = to_sorted(jnp.clip(x, 0, 1))
    start = jnp.roll(stop, 1).at[0].set(0)
    am1 = a - 1
    bm1 = b - 1

    @jax.vmap
    def partial_integral(start, stop):
        return quadax.quadgk(
            lambda t: t**am1 * (1 - t)**bm1,
            interval=(start, stop),
        )[0]

    cums = jnp.cumsum(partial_integral(start, stop))
    beta_ = jax.scipy.special.beta(a, b)
    res = from_sorted(cums / beta_, idx, x.shape)
    res = jnp.where((x == 0) | (x == 1), x, res)
    return jnp.where((x < 0) | (x > 1), jnp.nan, res)


# The cumulative variant must agree with SciPy to 1e-10 relative tolerance.
np.testing.assert_allclose(
    quadax_betainc_cumulative(a, b, x_jnp),
    baseline,
    rtol=1e-10,
)

In [None]:
# Time the cumulative variant.
%timeit -r 1 quadax_betainc_cumulative(a, b, x_jnp).block_until_ready()

In [None]:
@jax.jit
def quadax_betainc_vec(a, b, x):
    """Regularized incomplete beta via a single vector-valued quadrature.

    Substituting t -> t * x maps every integral onto the fixed interval (0, 1):
    I_x(a, b) = x / B(a, b) * int_0^1 (t x)^(a-1) (1 - t x)^(b-1) dt,
    so one ``quadax.quadgk`` call integrates all points at once.

    Returns:
        ``(values, info)``. ``values`` has the shape of ``x``, with exact 0 and 1
        returned unchanged and points outside [0, 1] mapped to NaN. ``info`` is
        quadgk's second return value.
    """
    # Integrate on the clipped points (the original computed `xc` but passed
    # the raw `x`). Out-of-range results are replaced by NaN below anyway, and
    # NaN in one element of a vector-valued integrand can presumably spoil the
    # combined error estimate that drives adaptation for all points.
    xc = jnp.clip(x, 0, 1)
    y, info = quadax.quadgk(
        lambda t, x, am1, bm1: (t * x)**am1 * (1 - t * x)**bm1 * x,
        interval=(0, 1),
        args=(xc, a - 1, b - 1),
        order=15,
    )
    beta_ = jax.scipy.special.beta(a, b)
    y = y / beta_
    y = jnp.where((x == 0) | (x == 1), x, y)
    return jnp.where((x < 0) | (x > 1), jnp.nan, y), info

# This variant returns (values, quadgk info): check the values against SciPy
# and print the quadrature diagnostics.
actual, info = quadax_betainc_vec(a, b, x_jnp)
np.testing.assert_allclose(actual, baseline, rtol=1e-10)
print(info)

In [None]:
# Time the vector-valued variant; [0] selects the values from (values, info).
%timeit -r 1 quadax_betainc_vec(a, b, x_jnp)[0].block_until_ready()

In [None]:
def to_rect(x, ncols):
    """Flatten ``x`` into an ``(nrows, ncols)`` matrix without reordering.

    ``nrows = ceil(size / ncols)``; any missing cells at the end are filled
    with the first element of ``x``. Undo with ``from_rect``.
    """
    flat = jnp.reshape(x, (-1, ))
    nrows = int(math.ceil(flat.size / ncols))
    n_pad = nrows * ncols - flat.size
    if n_pad > 0:
        flat = jnp.concatenate(
            [flat, jnp.full((n_pad, ), flat[0], dtype=flat.dtype)]
        )
    return flat.reshape((nrows, ncols))

def from_rect(x, shape):
    """Inverse of ``to_rect``: flatten, drop the padding tail, restore ``shape``."""
    n_keep = math.prod(shape)
    return jnp.ravel(x)[:n_keep].reshape(shape)

# Round trip through the unsorted rectangle helpers must be lossless
# (values, shape and dtype).
y = to_rect(x_jnp, 50)
x2 = from_rect(y, x_jnp.shape)
np.testing.assert_array_equal(x_jnp, x2, strict=True)

In [None]:
@jax.jit
def quadax_betainc_vec_chunked(a, b, x):
    """Regularized incomplete beta via vector-valued quadrature over sorted chunks.

    The flattened, clipped ``x`` is split into sorted rows of ``CHUNK_SIZE``
    points (``to_sorted_rect``). Each row is integrated by one vector-valued
    ``quadax.quadgk`` call over (0, 1), using the substitution t -> t * x, and
    the rows are processed with ``jax.vmap``.

    Returns:
        Array with the shape of ``x``. Exact 0 and 1 are returned unchanged;
        points outside [0, 1] map to NaN.
    """
    CHUNK_SIZE = 32
    # Clip so out-of-range points (set to NaN below) don't inject NaN into a
    # chunk's shared vector-valued quadrature.
    scratch, idx = to_sorted_rect(jnp.clip(x, 0, 1), CHUNK_SIZE)

    am1 = a - 1
    bm1 = b - 1
    y = jax.vmap(lambda row: quadax.quadgk(
        lambda t, xs, am1, bm1: (t * xs)**am1 * (1 - t * xs)**bm1 * xs,
        interval=(0, 1),
        args=(row, am1, bm1),
    )[0])(scratch)
    y = from_sorted_rect(y, idx, x.shape)
    # Normalize *before* overriding the endpoints; otherwise x == 1 would
    # come out as 1 / B(a, b) instead of exactly 1.
    beta_ = jax.scipy.special.beta(a, b)
    y = y / beta_
    y = jnp.where((x == 0) | (x == 1), x, y)
    return jnp.where((x < 0) | (x > 1), jnp.nan, y)

# Chunked vector-valued variant against the SciPy reference.
actual = quadax_betainc_vec_chunked(a, b, x_jnp)
np.testing.assert_allclose(actual, baseline, rtol=1e-10)

In [None]:
# Time the chunked vector-valued variant.
%timeit -r 1 quadax_betainc_vec_chunked(a, b, x_jnp).block_until_ready()

In [None]:
@partial(jax.jit, static_argnames="n")
def simpson_betainc(a, b, x, n):
    """Regularized incomplete beta with a fixed-grid composite Simpson rule.

    Uses the substitution t -> t * x, so every point is integrated over [0, 1]
    on the same grid of ``2**n - 1`` equally spaced nodes. ``n`` is static
    because it fixes the grid size.

    Returns:
        Array with the broadcast shape of ``a``, ``b`` and ``x``, divided by
        ``B(a, b)``. Exact 0 and 1 are returned unchanged; points outside
        [0, 1] map to NaN, as in the quadrature variants.
    """
    beta_ = jax.scipy.special.beta(a, b)
    ndim = len(jnp.broadcast_shapes(jnp.shape(a), jnp.shape(b), jnp.shape(x)))
    t = jnp.linspace(0, 1, 2**n - 1)
    # Put the grid on a new leading axis so it broadcasts against a, b and x.
    t = jnp.reshape(t, (-1, *(1, ) * ndim))
    y = (t * x)**(a - 1) * (1 - t * x)**(b - 1) * x
    y = quadax.simpson(y, x=t, axis=0)
    y = y / beta_
    # The grid includes t = 1, where (1 - t*x)**(b - 1) is infinite for x == 1
    # when b < 1, so pin the endpoints explicitly.
    y = jnp.where((x == 0) | (x == 1), x, y)
    return jnp.where((x < 0) | (x > 1), jnp.nan, y)

# Signed error vs SciPy for grids of 2**n - 1 nodes (keyed by 2**n), each
# reordered by increasing x via `idx`.
idx = jnp.argsort(jnp.ravel(x_jnp))
yi = {}
for n in range(10, 13):
    y = simpson_betainc(a, b, x_jnp, n=n)
    deltas = jnp.ravel(y - baseline)[idx]
    yi[2**n] = deltas

In [None]:
# Time the finest fixed grid used above (2**12 - 1 nodes).
%timeit -r 1 simpson_betainc(a, b, x_jnp, n=12).block_until_ready()

In [None]:
# One column of signed errors per grid size (keyed by 2**n, as in `yi`),
# indexed by the sorted evaluation points.
df = pd.DataFrame({key: np.asarray(err) for key, err in yi.items()})
df.columns.name = "2**n"
df.index = np.sort(x_np.reshape(-1))
# The errors are signed: plot magnitudes, otherwise the log axis silently
# drops every negative error.
ax = df.abs().plot(logy=True)
ax.set(xlabel="x", ylabel="|Simpson - SciPy|", title="Fixed-grid Simpson betainc error");

In [None]:
# Worst absolute error per grid size; the entries of `df` are signed, so a
# plain max would miss large negative errors.
df.abs().max(axis=0)

In [None]:
@partial(jax.jit, static_argnames="n")
def flatsimpson_betainc(a, b, x, n):
    """Regularized incomplete beta from one cumulative Simpson integration.

    Builds a single 1-D grid of ``2**n`` equally spaced nodes on
    ``[0, max(x)]`` plus every (clipped) evaluation point. It integrates
    t**(a-1) * (1-t)**(b-1) cumulatively along that merged, de-duplicated grid
    with ``quadax.cumulative_simpson`` and reads the running integral at each
    evaluation point. ``n`` is static because it fixes the grid size.

    Returns:
        Array with the shape of ``x``, divided by ``B(a, b)``. Points outside
        [0, 1] are clipped rather than mapped to NaN.
    """
    beta_ = jax.scipy.special.beta(a, b)
    xflat = jnp.reshape(x, (-1, ))
    xflat = jnp.clip(xflat, 0, 1)
    t = jnp.linspace(0, jnp.max(xflat), 2**n)
    # Grid nodes first, then the evaluation points.
    tmix = jnp.concat([t, xflat])
    # Sort and de-duplicate the merged nodes; `revidx` maps every entry of the
    # concatenation to its position in the unique array. With a fixed `size`,
    # the unused tail is filled with NaN.
    # NOTE(review): assumes cumulative_simpson's values at real nodes are not
    # affected by the trailing NaN nodes -- confirm against quadax's scheme.
    tmix, revidx = jnp.unique_inverse(tmix, size=tmix.size, fill_value=jnp.nan)
    # NOTE(review): for b < 1 a node at exactly 1 (x == 1, or x > 1 after
    # clipping) makes this infinite.
    ymix = tmix**(a - 1) * (1 - tmix)**(b - 1)
    # NOTE(review): if quadax.cumulative_simpson mirrors
    # scipy.integrate.cumulative_simpson, then with the default initial=None
    # the output is one element shorter than the input. In that case yimix[k]
    # is the integral up to node k+1, and indexing with `revidx` below is off
    # by one (consider initial=0). TODO verify.
    yimix = quadax.cumulative_simpson(ymix, x=tmix)
    # Skip the 2**n grid entries; the rest correspond to xflat in order.
    yi = jnp.take(yimix, revidx[2**n:])
    yi = jnp.reshape(yi, x.shape)
    return yi / beta_

# Same error study for the merged-grid Simpson variant, with 2**n grid nodes
# for n = 10..17; `idx` is x_jnp's flattened sort order from the earlier study.
yi = {}
for n in range(10, 18):
    y = flatsimpson_betainc(a, b, x_jnp, n=n)
    deltas = jnp.ravel(y - baseline)[idx]
    yi[2**n] = deltas

In [None]:
# One column of signed errors per grid size (keyed by 2**n, as in `yi`),
# indexed by the sorted evaluation points.
df = pd.DataFrame({key: np.asarray(err) for key, err in yi.items()})
df.columns.name = "2**n"
df.index = np.sort(x_np.reshape(-1))
# The errors are signed: plot magnitudes, otherwise the log axis silently
# drops every negative error.
ax = df.abs().plot(logy=True)
ax.set(xlabel="x", ylabel="|Simpson - SciPy|", title="Merged-grid Simpson betainc error");

In [None]:
# Worst absolute error per grid size; the entries of `df` are signed, so a
# plain max would miss large negative errors.
df.abs().max(axis=0)