In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def r():
    """Draw a single sample uniformly from the interval between -5 and 5."""
    return np.random.uniform(low=-5, high=5)

# Number of random rectangles, i.e. the length of the vector returned by R.
N_REGIONS = 10

# Rectangle k is stored as ((x1_min, x1_max), (x2_min, x2_max)), each bound drawn
# uniformly from [-5, 5]. The bounds are sampled once, when this cell runs, so
# every call to R tests against the same rectangles. Each (min, max) pair is
# sorted: unsorted draws would leave most rectangles empty (min > max).
R_LIMS = {
    k: tuple(
        tuple(sorted(np.random.uniform(-5, 5, size=2).tolist()))
        for _axis in range(2)
    )
    for k in range(N_REGIONS)
}


def R(x):
    """Return the rectangle-membership indicator vector of a 2-D point.

    Parameters
    ----------
    x : sequence of two numbers
        The point ``(x1, x2)``.

    Returns
    -------
    numpy.ndarray
        Float array of length ``len(R_LIMS)``; entry ``k`` is ``1.0`` when the
        point lies inside rectangle ``k`` (bounds inclusive) and ``0.0``
        otherwise.
    """
    x1, x2 = x
    return np.array([
        x1_min <= x1 <= x1_max and x2_min <= x2 <= x2_max
        for (x1_min, x1_max), (x2_min, x2_max) in R_LIMS.values()
    ]).astype(float)


# 100 labelled samples of the form ((x1, x2), label); both coordinates come
# from r() and every label is fixed to 1.
xl = [
    (tuple(r() for _axis in range(2)), 1)
    for _sample in range(100)
]

In [None]:
# Evaluate the indicator features R on a regular 100 x 100 grid over [-5, 5]^2.
grid_axis = np.linspace(-5, 5, 100)
x, y = np.meshgrid(grid_axis, grid_axis)
X = np.stack([x, y], axis=-1)
Z = np.array([R(point) for point in X.reshape(-1, 2)]).reshape(100, 100, -1)
n = Z.shape[-1]  # length of the vector returned by R; reused by later cells
Z = Z.sum(axis=-1)  # number of rectangles covering each grid point

plt.figure(figsize=(4, 4))
plt.contourf(x, y, Z, levels=20)
# The loop variables must not be called x, y: that would rebind the meshgrid
# arrays above and leave different objects under those names after the cell runs.
for point, _label in xl:
    plt.plot(*point, 'ro')
plt.show()

In [None]:
# Set the JAX_PLATFORMS environment variable to "cpu". Keep this comment on its
# own line: text after the value on the %env line would become part of the value.
# NOTE(review): presumably this has to run before jax is imported -- confirm.
%env JAX_PLATFORMS=cpu
import jax.numpy as jnp
import jax
from jax import config
# Turn on JAX's "jax_enable_x64" option (64-bit types instead of the 32-bit default).
config.update("jax_enable_x64", True)

In [None]:
# Mean and standard deviation used by log_prior.
mu0, sigma0 = -1, 5

# Stack label * R(point) for every sample (one row per sample), then transpose
# so that each column of wl belongs to one sample.
wl = jnp.transpose(
    jnp.array([label * R(point) for point, label in xl]).astype(float)
)

# Starting point for the optimisers: one zero per feature.
x0 = jnp.zeros(n)


def f(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))``, applied elementwise."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator


def log_prior(x):
    """Gaussian log-prior of ``x``, without the normalisation constant.

    Every component shares the module-level mean ``mu0`` and standard deviation
    ``sigma0``. The squared standardised residuals are summed over the last
    axis, so leading axes of ``x`` are treated as a batch.
    """
    standardized = (x - mu0) / sigma0
    return -0.5 * jnp.sum(standardized ** 2, axis=-1)


def log_likelihood(x):
    """Sum of log-sigmoid terms of ``x @ wl`` over the last axis.

    Parameters
    ----------
    x : array
        Parameter vector (or a batch of them along leading axes); its last axis
        is contracted with the rows of the module-level matrix ``wl``.

    Returns
    -------
    array
        ``sum(log(sigmoid(x @ wl)), axis=-1)``.
    """
    z = x @ wl
    # log(sigmoid(z)) = -log(1 + exp(-z)) = -logaddexp(0, -z). Evaluating it this
    # way stays finite for very negative z, where 1 / (1 + exp(-z)) underflows
    # to 0 and taking its log would give -inf.
    return -jnp.sum(jnp.logaddexp(0.0, -z), axis=-1)


def log_posterior(x):
    """Unnormalised log-posterior: ``log_likelihood(x)`` plus ``log_prior(x)``."""
    likelihood_term = log_likelihood(x)
    prior_term = log_prior(x)
    return likelihood_term + prior_term


# Sanity check: evaluate each building block at the starting point x0,
# printing one result per line.
print(
    *(fn(x0) for fn in (f, log_likelihood, log_prior, log_posterior)),
    sep="\n",
)

In [None]:
def plot_corner(
    n,
    log_pdf,
    xs=np.linspace(-10, 10, 10),
    start=None,
    map=None,
    transform_x=False,
    transform_y=False,
    title="",
):
    """Draw pairwise maximum-profile panels of ``log_pdf`` on an ``n``-D grid.

    ``log_pdf`` is evaluated on the full Cartesian grid ``xs x ... x xs``
    (``n`` factors). For every pair ``i < j`` one panel shows the maximum of
    those values over all remaining dimensions, with parameter ``j`` on the
    horizontal axis and parameter ``i`` on the vertical axis. Panels fill the
    upper triangle of an ``(n-1) x (n-1)`` subplot layout.

    The grid holds ``len(xs) ** n`` points of ``n`` coordinates each, so memory
    and time grow exponentially with ``n``.

    Parameters
    ----------
    n : int
        Number of parameters (grid dimensions).
    log_pdf : callable
        Receives the grid as an array whose last axis has length ``n`` and must
        return one value per grid point (last axis reduced).
    xs : 1-D array
        Grid coordinates, shared by every dimension.
    start, map : array, optional
        Points to mark in every panel (``start`` as a black x, ``map`` as a
        red x). Must support ``.take``. ``map`` shadows the builtin; the name
        is kept so existing keyword calls keep working.
    transform_x : bool
        If true, the axis coordinates and both markers are passed through the
        logistic sigmoid for display. ``log_pdf`` is still evaluated on the
        untransformed grid.
    transform_y : bool
        If true, plot ``exp`` of the values returned by ``log_pdf``.
    title : str
        Figure-level title.
    """
    # indexing="ij" keeps grid axis k aligned with parameter k. The default
    # "xy" swaps the first two axes, which puts the wrong parameter on the
    # panels involving parameter 0 or 1 while the markers below are still
    # drawn in (parameter j, parameter i) order.
    Theta = jnp.stack(jnp.meshgrid(*[xs] * n, indexing="ij"), axis=-1)
    Z = log_pdf(Theta)
    if transform_y:
        Z = jnp.exp(Z)

    if transform_x:
        xs = 1 / (1 + jnp.exp(-xs))
        map = 1 / (1 + jnp.exp(-map)) if map is not None else None
        start = 1 / (1 + jnp.exp(-start)) if start is not None else None

    plt.figure(figsize=(12, 10))
    plt.suptitle(title)
    for i in range(n):
        for j in range(i + 1, n):
            # Reduce over every axis except i and j; the result has rows
            # indexed by parameter i and columns indexed by parameter j.
            other = tuple(a for a in range(n) if a not in (i, j))
            # 1-based subplot index of row i, column j - 1.
            plt.subplot(n - 1, n - 1, i * (n - 1) + j)
            plt.contourf(xs, xs, Z.max(other), levels=50)
            if map is not None:
                plt.plot(map.take(j), map.take(i), "rx")
            if start is not None:
                plt.plot(start.take(j), start.take(i), "kx")

In [None]:
from scipy import optimize
def target(x):
    """Objective handed to the optimisers: the negative log-posterior at ``x``."""
    return -log_posterior(x)


# Derivatives of the objective via JAX automatic differentiation.
jac = jax.jacobian(target)
hes = jax.hessian(target)

# Objective, gradient and Hessian at the starting point, one per line.
print(*(fn(x0) for fn in (target, jac, hes)), sep="\n")

sol = optimize.minimize(
    fun=target,
    x0=x0,
    method='Newton-CG',
    jac=jac,
    hess=hes,
)
sol

In [None]:
# Both figures share the same markers (scipy solution and starting point) and
# the same display transforms; only the density and the title differ.
corner_kwargs = dict(
    map=sol.x,
    start=x0,
    transform_y=True,
    transform_x=True,
)
plot_corner(n, log_prior, title="log Prior", **corner_kwargs)
plot_corner(n, log_posterior, title="log Posterior", **corner_kwargs)

In [None]:
import optimistix as optx

def _objective(x, _args):
    """Adapter for optimistix, which calls the objective with two arguments;
    only the first one is forwarded to ``target``."""
    return target(x)


solver = optx.BFGS(atol=1e-5, rtol=1e-5)
# Note: this rebinds `sol`, which previously held the scipy result.
sol = optx.minimise(_objective, solver, x0)
sol.value

In [None]:
# Posterior corner plot with the optimistix solution (red x) and the starting
# point x0 (black x) marked in every panel.
plot_corner(
    n,
    log_posterior,
    start=x0,
    map=sol.value,
    transform_x=True,
    transform_y=True,
    title="log Posterior",
)