# Likelihood and uncertainty
A notebook to illustrate the relationships between the likelihood function and different uncertainty estimates.

## Authors:
**David W. Hogg** (NYU)

## License:
- Copyright 2025 the author. All code is licensed for re-use under the open-source *MIT License*.

## Notes:
- Some overlap with `basic_inference_example.ipynb`.

## To-do:
- Make all plots consistent across all notebooks, so they are publication-ready.

## Bugs:
- Various things hard-coded.
- Plots of uncertainty contours and LLF images have repeated code.

In [None]:
import jax
import jax.numpy as jnp
import numpy.random as random
import pylab as plt
from matplotlib import rcParams
import scipy.optimize as op

In [None]:
# Turn on the "jax_enable_x64" flag so JAX arrays can be 64-bit instead of the 32-bit default.
jax.config.update("jax_enable_x64", True)

In [None]:
# Default figure size (width, height) for every plot in this notebook: square figures.
rcParams['figure.figsize'] = [4.0, 4.0]

In [None]:
# set default global stuff (apologies)

# N: number of fake data points (the size of every random draw in make_fake_data).
N = 7
# p: number of model parameters; sets the shape of the parameter vector and covariance.
p = 2
# One (lower, upper) row per parameter -- presumably in (amplitude, phase) order;
# NOTE(review): not used anywhere else in the visible code.
prior_bounds = jnp.array(
    [
        [1.0, 1.5 * jnp.pi],
        [0.0, 2.0 * jnp.pi],
    ]
)
assert prior_bounds.shape == (p, 2)
# Angular frequency of the sinusoid, read as a global by the model functions.
true_omega = 2.13 # hard-coded global magic variable

In [None]:
# make fake data

def expectation(ts, pars):
    """
    Evaluate the sinusoidal model `amp * cos(true_omega * ts - phi)`.

    `pars` must unpack to `(amp, phi)`. The angular frequency is not a
    parameter; it is read from the module-level `true_omega`.
    """
    amplitude, phase = pars
    angle = true_omega * ts - phase
    return amplitude * jnp.cos(angle)

def make_fake_data(seed=17):
    """
    Generate a reproducible noisy data set of `N` points around `expectation`.

    Returns the tuple `(ts, ys, ivars, truepars)`: the sorted times, the noisy
    data values, the inverse variances used to scale the noise, and the
    parameter vector the data were generated from.
    """
    rng = random.default_rng(seed)
    # NOTE: the order of the three rng draws fixes the data set for a given seed.
    times = jnp.sort(7. * rng.uniform(size=N))
    inverse_variances = 0.25 * jnp.array(1. + 1. * rng.uniform(size=N)) # magic
    truepars = jnp.array([2.59, 2.0344]) # magic
    # Gaussian noise with standard deviation 1 / sqrt(inverse variance) per point.
    noise = rng.normal(size=N) / jnp.sqrt(inverse_variances)
    values = expectation(times, truepars) + noise
    return times, values, inverse_variances, truepars

# Generate the one data set (default seed) that every later cell reads as globals.
ts, ys, ivars, true_pars = make_fake_data()
# Sanity check: array shapes and the parameters the data were generated from.
print(ts.shape, ys.shape, true_pars)

In [None]:
# plot data

def plot(ts, ys, ivars, true_pars, ml_pars, samples, title):
    """
    Plot the data with error bars and overlay whichever model curves are given.

    Error bars are `1 / sqrt(ivars)`. Each of `samples` (an iterable of
    parameter vectors), `true_pars`, and `ml_pars` may be `None`, in which case
    that overlay is skipped. Curves are drawn on a fixed grid over [0, 7].
    """
    sigmas = 1./jnp.sqrt(ivars)
    plt.errorbar(ts, ys, yerr=sigmas, fmt="ko")
    grid = jnp.linspace(0., 7., 1000)

    def draw_curve(pars, fmt, lw, alpha):
        # One model curve evaluated on the dense plotting grid.
        plt.plot(grid, expectation(grid, pars), fmt, lw=lw, alpha=alpha)

    # Thin red: samples; thin blue: truth; thick red: maximum likelihood.
    for sample in (samples if samples is not None else ()):
        draw_curve(sample, "r-", 1, 0.45)
    if true_pars is not None:
        draw_curve(true_pars, "b-", 1, 0.45)
    if ml_pars is not None:
        draw_curve(ml_pars, "r-", 2, 0.9)

    plt.xlabel("time")
    plt.ylabel("data value")
    plt.title(title)

# Show the data with only the true model curve (no ML fit, no samples yet).
plot(ts, ys, ivars, true_pars, None, None, "data and true expectation")

In [None]:
# define likelihood in terms of phase

def negative_log_likelihood(pars, ts, ys, ivars):
    """
    Return half the inverse-variance-weighted sum of squared residuals.

    The residuals are `ys - expectation(ts, pars)`, weighted by `ivars`.
    """
    residuals = ys - expectation(ts, pars)
    chi_squared = jnp.sum(ivars * residuals ** 2)
    return 0.5 * chi_squared

In [None]:
# Minimize the negative log-likelihood numerically, starting at the true parameters.
res = op.minimize(negative_log_likelihood, true_pars, args=(ts, ys, ivars))
print(res)
# NaN placeholders stay in place if the optimizer does not report success.
ml_pars = jnp.zeros(p) + jnp.nan
ml_pars_covar = jnp.zeros((p, p)) + jnp.nan
if res.success:
    # Convert to a JAX array so ml_pars has the same type on both branches.
    ml_pars = jnp.asarray(res.x)
    ml_pars_covar = res.hess_inv
# Wrap the phase into [-pi, pi] via arctan2(sin, cos).
# JAX arrays are immutable, so `ml_pars[1] = ...` would raise a TypeError on the
# NaN placeholder when the fit fails; the functional `.at[].set()` update works
# in both cases.
wrapped_phase = jnp.arctan2(jnp.sin(ml_pars[1]), jnp.cos(ml_pars[1])) # angle zero issues
ml_pars = ml_pars.at[1].set(wrapped_phase)
print(ml_pars)

In [None]:
# Same plot as before, now with the maximum-likelihood curve overlaid (thick red line).
plot(ts, ys, ivars, true_pars, ml_pars, None, "maximum-likelihood estimate")

In [None]:
# define functions in terms of amplitudes

def design_matrix(ts):
    """
    Build the linear-model design matrix: a cosine column and a sine column.

    Both columns are evaluated at the module-level frequency `true_omega`;
    the rows follow the entries of `ts`.
    """
    angles = true_omega * ts
    basis_rows = jnp.vstack([jnp.cos(angles), jnp.sin(angles)])
    return jnp.transpose(basis_rows)

def ml_amplitudes(ts, ys, ivars):
    """
    Solve the weighted least-squares normal equations for the two amplitudes.

    Solves `(X^T W X) amps = X^T W ys` with `W = diag(ivars)` and
    `X = design_matrix(ts)`.
    """
    X = design_matrix(ts)
    weighted_X = ivars[:, None] * X
    lhs = X.T @ weighted_X
    rhs = X.T @ (ivars * ys)
    return jnp.linalg.solve(lhs, rhs)

def amps_to_pars(amps):
    """
    Convert (cosine amplitude, sine amplitude) to (amplitude, phase).

    The amplitude is the Euclidean norm of the pair and the phase is
    `arctan2(b, a)`.
    """
    cos_amp, sin_amp = amps
    amplitude = jnp.sqrt(cos_amp ** 2 + sin_amp ** 2)
    phase = jnp.arctan2(sin_amp, cos_amp)
    return jnp.array([amplitude, phase])
    
def pars_to_amps(pars):
    """
    Convert (amplitude, phase) to (cosine amplitude, sine amplitude).

    Returns `[A * cos(phi), A * sin(phi)]`.
    """
    amplitude, phase = pars
    cos_amp = amplitude * jnp.cos(phase)
    sin_amp = amplitude * jnp.sin(phase)
    return jnp.array([cos_amp, sin_amp])

In [None]:
# check that everyone is cool

# Closed-form linear solve for the (cosine, sine) amplitudes.
ml_amps = ml_amplitudes(ts, ys, ivars)
# Convert to (amplitude, phase) so it can be compared with the numerical optimum.
schml_pars = amps_to_pars(ml_amps)
# The printed boolean says whether the two maximum-likelihood estimates agree.
print(schml_pars, jnp.allclose(ml_pars, schml_pars))

In [None]:
# take second derivatives of the two formulations of the likelihood

def negative_log_likelihood_pars(pars):
    """Negative log-likelihood as a function of (amplitude, phase) only.

    The data `ts`, `ys`, `ivars` are taken from module scope, so the function
    has the single-argument form that `jax.hessian` differentiates.
    """
    value = negative_log_likelihood(pars, ts, ys, ivars)
    return value

def negative_log_likelihood_amps(amps):
    """Negative log-likelihood as a function of (cosine, sine) amplitudes only.

    The amplitudes are mapped through `amps_to_pars` first; the data `ts`,
    `ys`, `ivars` are taken from module scope.
    """
    pars = amps_to_pars(amps)
    return negative_log_likelihood(pars, ts, ys, ivars)

# Second derivatives of the negative log-likelihood at the maximum-likelihood
# point, in each parameterization (stored under the `Cinv_` names).
Cinv_pars = jax.hessian(negative_log_likelihood_pars)(ml_pars)
print(Cinv_pars)
Cinv_amps = jax.hessian(negative_log_likelihood_amps)(ml_amps)
print(Cinv_amps)
# Check the amplitude-space Hessian against the closed form X^T diag(ivars) X.
X = design_matrix(ts)
print(
    jnp.allclose(
        X.T @ (ivars[:, None] * X),
        Cinv_amps,
    )
)

In [None]:
# invert hessians

# Invert each Hessian to get the matrix used for the uncertainty ellipses below.
C_pars = jnp.linalg.inv(Cinv_pars)
C_amps = jnp.linalg.inv(Cinv_amps)
print(C_pars)
print(C_amps)

In [None]:
# make a likelihood image

# Grid limits and 301-point axes in amplitude and phase.
amplim = (0., 8.)
philim = (0., 2. * jnp.pi)
ampvec = jnp.linspace(*amplim, 301)
phivec = jnp.linspace(*philim, 301)
# meshgrid puts amplitude along axis 1 and phase along axis 0.
amps, phis = jnp.meshgrid(ampvec, phivec)

def foo(a, p):
    # Log-likelihood at one (amplitude, phase) point; `p` here is the phase and
    # only shadows the global parameter count inside this function.
    return -negative_log_likelihood((a, p), ts, ys, ivars)

# Evaluate on every grid point at once, then restore the 2-D grid shape.
lls = jax.vmap(foo)(amps.ravel(), phis.ravel()).reshape(amps.shape)
print(jnp.min(lls), jnp.max(lls))

In [None]:
# code to draw ellipses

def matrix_sqrt(C):
    """
    Return a square-root factor `S` of the symmetric matrix `C`.

    The factor satisfies `S.T @ S == C` (up to round-off), so `S.T` maps the
    unit circle/sphere onto the ellipse/ellipsoid described by `C`.

    `C` must be symmetric (checked with an assert). Negative eigenvalues
    produce NaNs because of the square root.
    """
    assert jnp.allclose(C, C.T)
    w, v = jnp.linalg.eigh(C)
    # eigh returns C = v @ diag(w) @ v.T with the eigenvectors in the *columns*
    # of v, so the rows to scale by sqrt(w) are those of v.T. Scaling the rows of
    # v itself only gives a valid factor when v happens to be symmetric.
    return jnp.sqrt(w)[:, None] * v.T

def draw_ellipse(ax, center, C, **kwargs):
    """
    Draw the closed curve for the 2x2 matrix `C`, centred on `center`, on `ax`.

    The curve is the unit circle (101 points) mapped through
    `matrix_sqrt(C).T` and shifted by `center`. Extra keyword arguments are
    forwarded to `ax.plot`. Returns `ax`.
    """
    assert C.shape == (2, 2)
    angles = jnp.linspace(0., 2. * jnp.pi, 101)
    unit_circle = jnp.vstack([jnp.cos(angles), jnp.sin(angles)])
    offsets = matrix_sqrt(C).T @ unit_circle
    xy = center[:, None] + offsets
    ax.plot(xy[0], xy[1], **kwargs)
    return ax

In [None]:
def plot_llf_image(xvec, yvec, lls, point=None, truepoint=None):
    """
    Show the likelihood as a grayscale image with one contour overlaid.

    `lls` holds log-likelihood values on the grid spanned by `xvec` and `yvec`.
    The image shows `exp(lls - max(lls))`, so the peak is 1. The red contour is
    where the log-likelihood is 0.5 below its maximum. `point` is marked with a
    red cross and `truepoint` with a faint blue cross, when given.
    Returns the current axes.
    """
    peak = jnp.max(lls)
    relative_lls = lls - peak
    extent = [min(xvec), max(xvec), min(yvec), max(yvec)]
    plt.imshow(
        jnp.exp(relative_lls),
        interpolation="nearest",
        origin="lower",
        extent=extent,
        vmin=0,
        vmax=1,
        cmap="gray_r",
        aspect="auto",
    )
    plt.contour(
        xvec,
        yvec,
        relative_lls,
        origin="lower",
        levels=[-0.5],
        colors="r",
        linestyles="solid",
        linewidths=0.5,
        alpha=0.9,
    )
    if point is not None:
        plt.scatter([point[0]], [point[1]], marker="x", c="r", s=20., alpha=0.9)
    if truepoint is not None:
        plt.scatter([truepoint[0]], [truepoint[1]], marker="x", c="b",
                    s=20., linewidths=0.5, alpha=0.5)
    return plt.gca()

# Likelihood image in (amplitude, phase), with the ML point and the truth marked.
ax = plot_llf_image(ampvec, phivec, lls, point=ml_pars, truepoint=true_pars)
ax.set_xlabel("amplitude $A$")
# Raw string: "\p" is an invalid escape sequence in a normal string literal
# (a SyntaxWarning on recent Pythons); the label text itself is unchanged.
ax.set_ylabel(r"phase $\phi$")
ax.set_title("likelihood")

In [None]:
# make another likelihood image

# One symmetric 301-point axis, used for both the cosine and sine amplitudes.
ayvec = jnp.linspace(-amplim[1], amplim[1], 301)
ays, bees = jnp.meshgrid(ayvec, ayvec)

def foo(a, b):
    # Log-likelihood at one (cosine amplitude, sine amplitude) point.
    return -negative_log_likelihood(amps_to_pars((a, b)), ts, ys, ivars)

# Evaluate on every grid point at once, then restore the 2-D grid shape.
lls2 = jax.vmap(foo)(ays.ravel(), bees.ravel()).reshape(ays.shape)
print(jnp.min(lls2), jnp.max(lls2))

In [None]:
# Likelihood image in (a, b); the truth is converted into amplitude space first.
ax = plot_llf_image(
    ayvec,
    ayvec,
    lls2,
    point=ml_amps,
    truepoint=pars_to_amps(true_pars),
)
ax.set_xlabel("cosine amplitude $a$")
ax.set_ylabel("sine amplitude $b$")
ax.set_title("likelihood")

In [None]:
# Zoom in on the uncertainty region in (amplitude, phase) and compare it with
# the second-derivative (Hessian) approximation.
mlls = jnp.max(lls)
# Contour where the log-likelihood is 0.5 below its maximum.
plt.contour(ampvec, phivec, lls - mlls, origin="lower",
            levels=[-0.5,], colors="k", linestyles="solid", linewidths=0.5, alpha=0.9)
# Degenerate line at the origin (outside the plotted limits), drawn in the
# contour's style and carrying the legend label for it.
plt.plot(jnp.zeros(2), jnp.zeros(2), "k-", linewidth=0.5, alpha=0.9, label="1-sigma region")
draw_ellipse(plt.gca(), ml_pars, C_pars, color="k", linestyle="dashed", linewidth=1.0, alpha=0.9,
             label="2nd-derivative approximation $C$")
plt.scatter([ml_pars[0], ], [ml_pars[1], ], marker="x", c="k",
            s=20., alpha=0.9)
# NOTE(review): `tiny` is not used anywhere in the visible code.
tiny = 0.03 * jnp.sqrt(C_pars[1, 1])
# "Wrong" amplitude uncertainty: inverse square root of a diagonal element of the
# inverse covariance (thick gray bar).
badsigma = 1. / jnp.sqrt(Cinv_pars[0, 0])
plt.plot([ml_pars[0] - badsigma, ml_pars[0] + badsigma], [ml_pars[1], ml_pars[1]], "k", lw=4.0, alpha=0.23,
         label=r"$\pm([C^{-1}]_{0,\!0})^{-1/2}$ (wrong)")
# Amplitude uncertainty from the covariance itself: square root of its diagonal element.
goodsigma = jnp.sqrt(C_pars[0, 0])
plt.plot([ml_pars[0] - goodsigma, ml_pars[0] + goodsigma], [ml_pars[1], ml_pars[1]], "k", lw=1.0, alpha=0.9,
         label=r"$\pm([C]_{0,\!0})^{1/2}$")
plt.xlim(2, 6)
plt.ylim(1, 4)
plt.xlabel("amplitude $A$")
# Raw string: "\p" is an invalid escape sequence in a normal string literal
# (a SyntaxWarning on recent Pythons); the label text itself is unchanged.
plt.ylabel(r"phase $\phi$")
plt.title("zoom in on uncertainty contour")
plt.legend()

In [None]:
# Zoom in on the uncertainty region in (a, b): the likelihood contour (solid)
# against the ellipse from the inverted Hessian (dashed).
mlls2 = jnp.max(lls2)
plt.contour(
    ayvec,
    ayvec,
    lls2 - mlls2,
    origin="lower",
    levels=[-0.5],
    colors="r",
    linestyles="solid",
    linewidths=0.5,
    alpha=0.9,
)
draw_ellipse(
    plt.gca(),
    ml_amps,
    C_amps,
    color="r",
    linestyle="dashed",
    linewidth=1.5,
    alpha=0.45,
)
# Mark the maximum-likelihood amplitudes.
plt.scatter([ml_amps[0]], [ml_amps[1]], marker="x", c="r", s=20., alpha=0.9)
plt.xlim(-4, 1)
plt.ylim(1, 6)
plt.xlabel("cosine amplitude $a$")
plt.ylabel("sine amplitude $b$")
plt.title("zoom in on uncertainty contour")