In [19]:
import jax

# Report the installed JAX version, the backend JAX picked by default,
# and the list of devices JAX can see in this runtime.
print("JAX version:", jax.__version__)
print("Backend:", jax.default_backend())
print("Devices:", jax.devices())

JAX version: 0.8.0
Backend: gpu
Devices: [CudaDevice(id=0)]


In [2]:
# Shell out to nvidia-smi to show the GPU, driver/CUDA versions and current memory use.
!nvidia-smi

Wed Oct 29 15:59:42 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   43C    P0             27W /   70W |     110MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [9]:
import jax.numpy as jnp
import numpyro
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler

# Alias for numpyro's distributions module.
tfd = numpyro.distributions

# Load the breast-cancer dataset, standardize the features with StandardScaler,
# and hand features and labels to JAX as float32 arrays.
dataset = load_breast_cancer()
scaler = StandardScaler()
X = jnp.asarray(scaler.fit_transform(dataset.data).astype("float32"))
y = jnp.asarray(dataset.target.astype("float32"))

# Number of feature columns (last axis of X).
n_features = X.shape[-1]
print("Dataset shape:", X.shape, y.shape)

Dataset shape: (569, 30) (569,)


In [10]:
def joint_log_prob(x, y, tau, lamb, beta):
    """Joint log-density of a logistic regression with scaled coefficients.

    The effective regression weights are ``tau * lamb * beta``.

    Args:
        x: Design matrix; its last axis is contracted with the weights.
        y: Binary responses scored under a Bernoulli likelihood.
        tau: Global scale, scored under a Gamma(0.5, 0.5) prior.
        lamb: Local scale(s), scored under a Gamma(0.5, 0.5) prior.
        beta: Raw coefficients, scored under a Normal(0, 1) prior.

    Returns:
        The sum of the three prior terms and the Bernoulli log-likelihood.
    """
    scale_prior = tfd.Gamma(0.5, 0.5)

    # Prior contributions.
    log_prior_tau = scale_prior.log_prob(tau)
    log_prior_lamb = scale_prior.log_prob(lamb).sum()
    log_prior_beta = tfd.Normal(0.0, 1.0).log_prob(beta).sum()

    # Likelihood contribution: logits are linear in the scaled weights.
    weights = tau * lamb * beta
    log_likelihood = tfd.Bernoulli(logits=x @ weights).log_prob(y).sum()

    return log_prior_tau + log_prior_lamb + log_prior_beta + log_likelihood

In [15]:
from jax import random

# Fixed seed so the draw below is reproducible.
key = random.key(0)
# One coefficient per feature column, drawn uniformly from [0, 1).
# Sized from n_features instead of a literal so it cannot drift from X.
beta = random.uniform(key, (n_features,), minval=0.0, maxval=1.0)

print(beta)

[0.947667   0.9785799  0.33229148 0.46866846 0.5698887  0.16550303
 0.3101946  0.68948054 0.74676657 0.17101455 0.9853538  0.02528262
 0.6400418  0.56269085 0.8992138  0.93453753 0.8341402  0.7256162
 0.5098531  0.02765214 0.03148878 0.9580188  0.5188192  0.79221416
 0.5522419  0.6113529  0.8931755  0.75499094 0.21164179 0.22934973]


In [16]:
# Smoke-test the joint density with tau = 1 and lamb = 1. lamb is passed as a
# scalar here, so it broadcasts against beta inside joint_log_prob.
joint_log_prob(X, y, tau=1.0, lamb=1.0, beta=beta)

Array(-4869.8623, dtype=float32)

In [20]:
def unconstrained_joint_log_prob(x, y, z):
    """Joint log-density as a function of one flat, unconstrained vector.

    ``z`` packs, in order, ``log(tau)`` (1 entry), ``log(lamb)`` (``d``
    entries) and ``beta`` (``d`` entries), where ``d = x.shape[-1]``. The
    positive parameters are recovered with ``exp`` and the log-Jacobian of
    that transform is added to the density.

    Args:
        x: Design matrix; ``x.shape[-1]`` fixes the number of features ``d``.
        y: Responses, forwarded unchanged to ``joint_log_prob``.
        z: 1-D array of length ``1 + 2 * d``.

    Returns:
        ``joint_log_prob(x, y, tau, lamb, beta)`` plus the log-determinant
        of the Jacobian of the ``exp`` transform.

    Raises:
        ValueError: If ``z`` does not have shape ``(1 + 2 * d,)``. Without
            this check a short ``z`` can yield a length-1 ``beta`` that
            broadcasts silently and produces a wrong density.
    """
    ndims = x.shape[-1]
    expected_size = 1 + 2 * ndims
    # Shapes are static under JAX tracing, so this check is safe inside
    # jit / grad transformations.
    if jnp.shape(z) != (expected_size,):
        raise ValueError(
            f"z must have shape ({expected_size},) for {ndims} features, "
            f"got {jnp.shape(z)}"
        )

    unc_tau, unc_lamb, beta = jnp.split(z, [1, 1 + ndims])
    unc_tau = unc_tau.reshape([])  # length-1 slice -> scalar
    tau = jnp.exp(unc_tau)
    lamb = jnp.exp(unc_lamb)
    # d/du exp(u) = exp(u), so log|det J| is the sum of the unconstrained values.
    ldj = unc_tau + unc_lamb.sum()
    return joint_log_prob(x, y, tau, lamb, beta) + ldj

def target_log_prob(z):
    """Unconstrained joint log-density evaluated on the notebook's X and y."""
    return unconstrained_joint_log_prob(X, y, z)

In [21]:
# One callable that returns both the log-density and its gradient w.r.t. z.
target_log_prob_and_grad = jax.value_and_grad(target_log_prob)

# Flat parameter vector: one tau, then n_features lamb, then n_features beta.
dim = 1 + 2 * n_features
z_init = jnp.zeros(dim)

# Evaluate at the all-zeros point as a sanity check.
logp, grad = target_log_prob_and_grad(z_init)
print("Initial log-density:", float(logp))
print("Gradient L2 norm:", float(jnp.linalg.norm(grad)))

Initial log-density: -465.95599365234375
Gradient L2 norm: 803.6372680664062
