In [1]:
import jax
jax.config.update('jax_cuda_visible_devices', '0')

import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
from bllarse.losses.bayesian import IBProbit, IBPolyaGamma, MultinomialPolyaGamma


def generate_data(key, n_samples, n_features, n_classes):
    """Draw a synthetic multinomial logistic regression dataset.

    A weight matrix and bias vector are sampled jointly from a standard
    normal, inputs are sampled from a standard normal, and each label is
    drawn from the categorical distribution defined by the resulting logits.

    Args:
        key: JAX PRNG key; it is split sequentially for parameters, inputs
            and labels.
        n_samples: Number of rows in the design matrix.
        n_features: Number of input features per sample.
        n_classes: Number of output classes.

    Returns:
        A tuple ``(X, y, probs, (W, b))`` where ``X`` has shape
        ``(n_samples, n_features)``, ``y`` holds the sampled class indices
        cast to ``float32``, ``probs`` is the softmax of the logits with
        shape ``(n_samples, n_classes)``, and ``(W, b)`` are the generating
        weights ``(n_features, n_classes)`` and bias ``(n_classes,)``.
    """
    # Parameters: the last row of the joint draw is the bias, the rest are weights.
    key, params_key = jr.split(key)
    theta = jr.normal(params_key, (n_features + 1, n_classes))
    W, b = theta[:-1], theta[-1]

    # Inputs.
    key, inputs_key = jr.split(key)
    X = jr.normal(inputs_key, (n_samples, n_features))

    logits = X @ W + b

    # Labels are sampled from the logits directly.
    key, labels_key = jr.split(key)
    labels = jr.categorical(labels_key, logits=logits, axis=-1)

    class_probs = jax.nn.softmax(logits, axis=-1)
    return X, labels.astype(jnp.float32), class_probs, (W, b)


In [2]:
n_samples = 1024
n_features = 768
n_classes_range = jnp.arange(10, 31, 20)
key = jr.PRNGKey(0)

for n_classes in n_classes_range:
    X, y, *_ = generate_data(key, n_samples, n_features, n_classes)
    
    ib_probit = IBProbit(n_features, n_classes, key=key)
    ib_pg = IBPolyaGamma(n_features, n_classes, key=key)
    mpg = MultinomialPolyaGamma(n_features, n_classes, key=key)

    func = eqx.filter_jit(ib_probit.update)
    func(X, y, num_iters=128)
    print('num_classes =', n_classes, "; IBProbit")
    %timeit jax.block_until_ready(func(X, y, num_iters=128).eta)

    func = eqx.filter_jit(ib_pg.update)
    func(X, y)
    print('num_classes =', n_classes, "; IBPG")
    %timeit jax.block_until_ready(func(X, y).mu)

    func = eqx.filter_jit(mpg.update)
    func(X, y)
    print('num_classes =', n_classes, "; MPG")
    %timeit jax.block_until_ready( func(X, y).mu )

num_classes = 10 ; IBProbit
6.12 ms ± 15.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
num_classes = 10 ; IBPG
3.89 s ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
num_classes = 10 ; MPG
3.64 s ± 6.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
num_classes = 30 ; IBProbit
6.93 ms ± 38.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
num_classes = 30 ; IBPG
9.57 s ± 6.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
num_classes = 30 ; MPG
9.28 s ± 909 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [2]:
n_samples = 1024
n_features = 768
n_classes_range = jnp.arange(10, 101, 20)
key = jr.PRNGKey(0)

for n_classes in n_classes_range:
    X, y, *_ = generate_data(key, n_samples, n_features, n_classes)
    
    ib_probit = IBProbit(n_features, n_classes, key=key)

    func = eqx.filter_jit(ib_probit.update)
    func(X, y, num_iters=128)
    print('num_classes =', n_classes, "; IBProbit")
    %timeit jax.block_until_ready(func(X, y, num_iters=128).eta)

num_classes = 10 ; IBProbit
6.12 ms ± 17.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
num_classes = 30 ; IBProbit
6.52 ms ± 17.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
num_classes = 50 ; IBProbit
7.07 ms ± 12.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
num_classes = 70 ; IBProbit
9.85 ms ± 35.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
num_classes = 90 ; IBProbit
10.2 ms ± 41.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
