In [None]:
import shutil

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import numpy.random as random
import pylab as plt
# Switch JAX to 64-bit floats/ints (JAX defaults to 32-bit); done in the first
# cell so every array created below is double precision.
jax.config.update("jax_enable_x64", True)

In [None]:
# set globals and defaults

N = 16       # number of simulated data points drawn by make_data
OMEGA = 1.   # angular frequency of the sinusoid model (used by design_matrix)

# LaTeX text rendering needs external `latex` and `dvipng` executables; when
# they are missing every figure raises at draw time. Fall back to matplotlib's
# built-in mathtext (which handles the simple `$...$` labels used here) so the
# notebook still runs end to end on machines without a TeX install.
USE_TEX = all(shutil.which(cmd) is not None for cmd in ("latex", "dvipng"))
plt.rcParams.update({
    "text.usetex": USE_TEX,
    "figure.figsize": (4, 4),
})

In [None]:
# make fake data

def design_matrix(xs, omega=None):
    """
    Build the linear design matrix for a sinusoid at angular frequency `omega`.

    ## inputs:
    - xs:     sample locations; expected to be a 1-d array of length n
    - omega:  angular frequency; defaults to the global OMEGA (read at call time)

    ## outputs:
    - array of shape (n, 2) whose columns are cos(omega * xs) and sin(omega * xs),
      so that `design_matrix(xs) @ [a, b]` is a cos(omega x) + b sin(omega x).
    """
    if omega is None:
        omega = OMEGA
    phase = omega * xs
    return jnp.vstack([jnp.cos(phase), jnp.sin(phase)]).T

def make_data(seed=17, n=None, sigma=1.2, xmax=4. * jnp.pi):
    """
    Simulate noisy observations of a sinusoid with random linear amplitudes.

    ## inputs:
    - seed:   seed for `numpy.random.default_rng`
    - n:      number of data points; defaults to the global N
    - sigma:  per-point Gaussian noise standard deviation (must be > 0)
    - xmax:   xs are drawn uniformly on [0, xmax) and then sorted

    ## outputs:
    - true_amps:  length-2 numpy array of (cos, sin) amplitudes, drawn from N(0, 1)
    - xs:         sorted sample locations, shape (n,)
    - ys:         design_matrix(xs) @ true_amps plus Gaussian noise, shape (n,)
    - ws:         inverse variances 1 / sigma**2, shape (n,)

    ## notes:
    - Random draws happen in a fixed order (amplitudes, xs, noise), so a given
      seed reproduces the same data set; the defaults reproduce the original
      hard-coded values.

    ## bugs:
    - The model frequency still comes from the global OMEGA via design_matrix.
    """
    if n is None:
        n = N
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    rng = random.default_rng(seed)
    true_amps = rng.normal(size=2)
    xs = jnp.sort(xmax * rng.uniform(size=n))
    ws = jnp.full_like(xs, 1. / sigma ** 2)
    ys = design_matrix(xs) @ true_amps + rng.normal(size=n) / jnp.sqrt(ws)
    return true_amps, xs, ys, ws

# Draw the fixed-seed fake data set used by every cell below; the weights
# returned by make_data are inverse variances, hence the name `ivars`.
truth, xs, ys, ivars = make_data()
assert xs.shape == ys.shape == ivars.shape == (N,), "unexpected data shapes"
assert jnp.all(jnp.diff(xs) >= 0), "xs should be sorted"
print(truth, xs.shape, ys.shape)

In [None]:
# two parameter forms; conversions

def amps2pars(amps):
    """
    Convert linear amplitudes (a, b) to amplitude/phase parameters (A, phi).

    ## inputs:
    - amps:  length-2 sequence or array (a, b): the cos and sin coefficients

    ## outputs:
    - jnp array [A, phi] with A = sqrt(a^2 + b^2) and phi = arctan2(b, a);
      note phi lies in [-pi, pi], not the [0, 2 pi) range used by `philim`.
    """
    cos_coef, sin_coef = amps
    amplitude = jnp.sqrt(cos_coef ** 2 + sin_coef ** 2)
    phase = jnp.arctan2(sin_coef, cos_coef)
    return jnp.array([amplitude, phase])

def pars2amps(pars):
    """
    Inverse of `amps2pars`: convert (A, phi) to linear amplitudes (a, b).

    ## inputs:
    - pars:  length-2 sequence or array (A, phi)

    ## outputs:
    - jnp array [A cos(phi), A sin(phi)]
    """
    amplitude, phase = pars
    return jnp.array([amplitude * trig(phase) for trig in (jnp.cos, jnp.sin)])

# Convert the true amplitudes to (A, phi) and check that converting back
# reproduces them, so the two parameterizations are known to agree.
true_pars = amps2pars(truth)
roundtrip_amps = pars2amps(true_pars)
assert jnp.allclose(roundtrip_amps, truth), "amps2pars/pars2amps round trip failed"
print(truth, true_pars, roundtrip_amps)

In [None]:
# Plot the simulated data with 1-sigma error bars, plus the true model written
# both ways: linear amplitudes (solid) and amplitude/phase (dashed). The two
# curves should coincide exactly.
plt.errorbar(xs, ys, yerr=1. / jnp.sqrt(ivars), linestyle="none", marker="o", color="k",
             label="data")
plotxs = jnp.linspace(0., 4. * jnp.pi, 1000)
plt.plot(plotxs, design_matrix(plotxs) @ truth, "b-", lw=0.5, alpha=0.5,
         label="truth (amplitudes)")
plt.plot(plotxs, true_pars[0] * jnp.cos(OMEGA * plotxs - true_pars[1]), "b--", alpha=0.5,
         label="truth (amplitude, phase)")
ax = plt.gca()
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_title("data and true model")
ax.legend(fontsize="small");

In [None]:
# make probability functions

def log_gaussian(resids, ivars):
    """
    Log-density of independent Gaussian residuals, up to an additive constant.

    ## inputs:
    - resids:  residuals (x - mu)
    - ivars:   diagonal elements of a presumed-diagonal *inverse* covariance matrix

    ## outputs:
    - scalar 0.5 * sum(log(ivars) - ivars * resids**2). The normalization term
      -0.5 * n * log(2 pi) is omitted, which is harmless wherever only
      differences of log-likelihoods matter.
    """
    log_det_terms = jnp.log(ivars)
    chi2_terms = ivars * resids ** 2
    return 0.5 * jnp.sum(log_det_terms - chi2_terms)

def log_likelihood(pars, xs, ys, ivars):
    """
    Gaussian log-likelihood (up to an additive constant) of the data under the
    sinusoid model with amplitude/phase parameters `pars` = (A, phi).

    ## inputs:
    - pars:   (A, phi), converted to linear amplitudes via pars2amps
    - xs, ys: data locations and values
    - ivars:  per-point inverse variances
    """
    model_ys = design_matrix(xs) @ pars2amps(pars)
    residuals = ys - model_ys
    return log_gaussian(residuals, ivars)

In [None]:
# make likelihood image

def function_image(xlim, nx, ylim, ny, funky):
    """
    Evaluate a scalar function of two arguments at the centres of a regular grid.

    ## inputs:
    - xlim, ylim:  (min, max) ranges along x and y
    - nx, ny:      number of grid cells along x and y
    - funky:       callable f(x, y) -> scalar; evaluated with jax.vmap, so it
                   must be vmap-compatible

    ## outputs:
    - xvec:   cell centres along x (nx of them)
    - yvec:   cell centres along y (ny of them)
    - image:  shape (ny, nx) with image[j, i] = funky(xvec[i], yvec[j])
    """
    def _cell_centres(lim, count):
        # Start half a cell in; the half-cell margin before `lim[1]` keeps
        # arange from producing an extra point due to rounding.
        step = (lim[1] - lim[0]) / count
        return jnp.arange(lim[0] + 0.5 * step, lim[1], step)

    xvec = _cell_centres(xlim, nx)
    yvec = _cell_centres(ylim, ny)
    grid_x, grid_y = jnp.meshgrid(xvec, yvec)
    values = jax.vmap(funky)(grid_x.flatten(), grid_y.flatten())
    return xvec, yvec, values.reshape(grid_x.shape)

# (min, max) ranges for the likelihood grids; the phase is periodic (it only
# enters through cos/sin), so one full period covers every phase.
amplim = (0., 2.)
philim = (0., 2. * jnp.pi)


def foo(a, p):
    """Log-likelihood of the fake data at amplitude `a` and phase `p` (scalar arguments, as function_image passes them)."""
    return log_likelihood((a, p), xs, ys, ivars)


# Coarse smoke test of the grid evaluation before the high-resolution plot.
ampvec, phivec, test = function_image(amplim, 16, philim, 8, foo)
assert test.shape == (len(phivec), len(ampvec)), "unexpected likelihood-grid shape"
assert jnp.all(jnp.isfinite(test)), "non-finite log-likelihood on the grid"
print(jnp.min(test), jnp.max(test))

In [None]:
# make 2-d plot

def plot_llf_image(xlim, nx, ylim, ny, llf_function, point=None, truepoint=None):
    """
    Image a log-likelihood function on a grid and draw its Delta-ln-L = -0.5 contour.

    ## inputs:
    - xlim, ylim:    (min, max) ranges of the two parameters (any length-2 sequence)
    - nx, ny:        number of grid cells along x and y
    - llf_function:  callable f(x, y) -> scalar log-likelihood, evaluated on the
                     grid via `function_image`
    - point:         optional (x, y) to mark with a red cross
    - truepoint:     optional (x, y) to mark with a thin blue cross

    ## outputs:
    - (ax, xvec, yvec, lls): the current axes, the grid cell centres, and the
      raw (not max-subtracted) log-likelihood image, shape (ny, nx)

    ## notes:
    - The image shows exp(lls - max(lls)), the likelihood relative to its grid
      maximum, so values lie in [0, 1] and exp() cannot overflow.
    """
    xvec, yvec, lls = function_image(xlim, nx, ylim, ny, llf_function)
    dlls = lls - jnp.max(lls)
    # Build the extent explicitly: `xlim + ylim` only concatenates when both
    # limits are tuples (or both lists); arrays would be added elementwise and
    # a tuple + list mix raises TypeError.
    extent = [xlim[0], xlim[1], ylim[0], ylim[1]]
    plt.imshow(jnp.exp(dlls), interpolation="nearest", origin="lower",
               extent=extent,
               vmin=0, vmax=1, cmap="gray_r", aspect="auto")
    plt.contour(xvec, yvec, dlls, origin="lower",
                levels=[-0.5, ], colors="r", linestyles="solid", linewidths=0.5, alpha=0.9)
    if point is not None:
        plt.scatter([point[0], ], [point[1], ], marker="x", c="r",
                    s=20., alpha=0.9)
    if truepoint is not None:
        plt.scatter([truepoint[0], ], [truepoint[1], ], marker="x", c="b",
                    s=20., linewidths=0.5, alpha=0.5)
    return plt.gca(), xvec, yvec, lls

# High-resolution likelihood image over amplitude (x) and phase (y).
ax, ampvec, phivec, lls = plot_llf_image(amplim, 300, philim, 256, foo)
ax.set_xlabel("amplitude $A$")
# Raw string: in a normal literal "\p" is an invalid escape sequence
# (SyntaxWarning on Python 3.12+, DeprecationWarning before).
ax.set_ylabel(r"phase $\phi$")
ax.set_title("likelihood")
print(lls.shape)

In [None]:
# make 1-d plot

# Marginalize over phase: average the likelihood over the phi grid (axis 0 of
# `lls`), computed stably in log space with logsumexp. The phi grid is
# uniformly spaced, so every phase value gets equal weight.
lls_marginal = logsumexp(lls, axis=0) - jnp.log(len(phivec))
# Rescale so the peak is exactly 1 after exponentiation (this also cancels the
# -log(len(phivec)) term above).
lls_marginal = lls_marginal - jnp.max(lls_marginal)
ax = plt.gca()
ax.step(ampvec, jnp.exp(lls_marginal), where="mid", color="k")
ax.set(ylim=(-0.1, 1.1),
       xlabel="amplitude $A$",
       ylabel="likelihood (relative to maximum)",
       title="marginal likelihood")
print(lls_marginal.shape)