# Flat Likelihood, Flat Prior
*by* **David W. Hogg** (NYU) (MPIA) (Flatiron)

## Goals:
Demonstrate that you can have a peak in your posterior, or in your marginalized likelihood, even when your likelihood has no peak and your priors are flat.

## To-do:
- Construct marginalized likelihood.
- Make publication-worthy plots.

In [None]:
import numpy as np
import pylab as plt
import emcee
import corner
# Shared, fixed-seed random generator used by the cells below.
RNG = np.random.default_rng(seed=17)

In [None]:
N = 12  # number of data points (rows of the design matrix)
P = 8   # number of linear parameters (columns of the design matrix)
Q = 2   # inner dimension of the design-matrix factorization (its max rank)

In [None]:
def design_matrix(n, p, q, rng=None):
    """Draw a random ``(n, p)`` design matrix of rank at most ``q``.

    The matrix is the product of an ``(n, q)`` and a ``(q, p)`` standard-normal
    matrix. When ``q < p`` the null space of the result therefore has
    dimension at least ``p - q``. Along those directions in parameter space
    the prediction ``X @ a`` does not change at all.

    Parameters
    ----------
    n, p, q : int
        Number of rows, number of columns, and inner dimension.
        Must satisfy ``q <= p < n``.
    rng : numpy.random.Generator, optional
        Source of randomness. Defaults to the module-level ``RNG``. It is
        looked up at call time, so re-seeding ``RNG`` takes effect without
        re-running this definition.

    Returns
    -------
    numpy.ndarray, shape (n, p)

    Raises
    ------
    ValueError
        If ``q <= p < n`` does not hold.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not q <= p < n:
        raise ValueError(f"need q <= p < n, got n={n}, p={p}, q={q}")
    if rng is None:
        rng = RNG
    return rng.normal(size=(n, q)) @ rng.normal(size=(q, p))

def log_like(a, X, ys, ivars):
    """Gaussian log-likelihood of ``ys`` under the linear model ``X @ a``.

    Returns ``-0.5 * sum_i ivars_i * (ys_i - (X @ a)_i)**2``, i.e. minus half
    the chi-squared, dropping the normalization constant.
    """
    model = X @ a
    resid = ys - model
    chi2 = np.dot(resid, ivars * resid)
    return -0.5 * chi2

def wls(X, ys, ivars):
    """Weighted least-squares estimate of ``a`` in ``ys ~ X @ a``.

    Minimizes ``sum_i ivars_i * (ys_i - (X @ a)_i)**2``.

    The problem is solved with ``np.linalg.lstsq`` on the sqrt(ivars)-rescaled
    system, not by solving the normal equations directly. The normal matrix
    ``X.T @ diag(ivars) @ X`` is singular whenever ``X`` is rank-deficient,
    for example a design matrix built as the product of two thinner matrices.
    In that case ``np.linalg.solve`` either raises ``LinAlgError`` or returns
    values dominated by round-off. ``lstsq`` instead returns the minimum-norm
    minimizer. When ``X`` has full column rank it agrees with the
    normal-equation solution.

    Parameters
    ----------
    X : numpy.ndarray, shape (N, P)
    ys : numpy.ndarray, shape (N,)
    ivars : numpy.ndarray, shape (N,)
        Non-negative inverse variances, used as weights.

    Returns
    -------
    numpy.ndarray, shape (P,)
    """
    sqrt_w = np.sqrt(ivars)
    a, *_ = np.linalg.lstsq(sqrt_w[:, None] * X, sqrt_w * ys, rcond=None)
    return a

def log_profile_like(ai, i, X, ys, ivars):
    """Profile log-likelihood with parameter ``i`` held fixed at ``ai``.

    The procedure has three steps:

    1. Subtract the contribution of column ``i`` (at the fixed value ``ai``)
       from the data.
    2. Fit the remaining parameters to what is left by weighted least squares.
    3. Evaluate the full log-likelihood at the recombined parameter vector.
    """
    fixed_column = X[:, i]
    free_columns = np.delete(X, i, axis=1)
    target = ys - fixed_column * ai
    best_free = wls(free_columns, target, ivars)
    return log_like(np.insert(best_free, i, ai), X, ys, ivars)

# Flat (box) prior bounds, one [lower, upper] row per parameter.
# Every parameter lies in [-5, 5] except the first, which gets [-10, 10].
PRIOR_LIMITS = np.tile([-5., 5.], (P, 1))
PRIOR_LIMITS[0] = -10., 10.
def log_prior(a, limits=None):
    """Log of an unnormalized flat (box) prior.

    Parameters
    ----------
    a : array_like, shape (P,)
        Parameter vector.
    limits : array_like, shape (P, 2), optional
        Per-parameter ``[lower, upper]`` bounds. Defaults to the module-level
        ``PRIOR_LIMITS``.

    Returns
    -------
    float
        ``0.`` if every component lies in its closed interval, otherwise
        ``-np.inf``. NaN components count as outside the box.
    """
    if limits is None:
        limits = PRIOR_LIMITS
    limits = np.asarray(limits)
    a = np.asarray(a)
    # Phrased as "inside" so that NaN (for which every comparison is False)
    # is rejected, not silently accepted.
    inside = (a >= limits[:, 0]) & (a <= limits[:, 1])
    # np.Inf was removed in NumPy 2.0; np.inf is the canonical spelling.
    return 0. if np.all(inside) else -np.inf

def log_post(a, X, ys, ivars):
    """Unnormalized log posterior: flat box prior plus Gaussian log-likelihood.

    Returns ``-np.inf`` for parameters outside the prior box. The likelihood
    is not evaluated in that case, so such points get zero posterior
    probability.
    """
    lnpi = log_prior(a)
    if not np.isfinite(lnpi):
        # np.Inf was removed in NumPy 2.0; np.inf is the canonical spelling.
        return -np.inf
    return lnpi + log_like(a, X, ys, ivars)

def mh_mcmc_step(log_post, a, proposal, X, ys, ivars, rng=None):
    """Take one Metropolis step with an isotropic Gaussian proposal.

    Because the proposal is symmetric, the acceptance test uses only the
    difference of log target densities.

    Parameters
    ----------
    log_post : callable
        ``log_post(a, X, ys, ivars)`` returning the log target density.
    a : numpy.ndarray
        Current position.
    proposal : float or array_like
        Proposal standard deviation(s), broadcast against ``a``.
    X, ys, ivars
        Passed through unchanged to ``log_post``.
    rng : numpy.random.Generator, optional
        Defaults to the module-level ``RNG``. It is looked up at call time, so
        re-seeding ``RNG`` takes effect without re-running this definition.

    Returns
    -------
    numpy.ndarray
        The proposed point if accepted, otherwise ``a`` itself.
    """
    if rng is None:
        rng = RNG
    a_new = a + proposal * rng.normal(size=a.shape)
    lr = np.log(rng.uniform())
    if log_post(a_new, X, ys, ivars) - log_post(a, X, ys, ivars) > lr:
        return a_new
    return a

In [None]:
# Random low-rank design matrix for the linear model.
X = design_matrix(n=N, p=P, q=Q)
print(X.shape)  # expected: (N, P)

In [None]:
# Simulate data: true parameters, noiseless predictions, then homoscedastic
# Gaussian noise with inverse variance 100 (sigma = 0.1).
a_true = RNG.normal(size=P)
ys_true = X @ a_true
ivars = np.full_like(ys_true, 100.)
noise = RNG.normal(size=N) / np.sqrt(ivars)
ys = ys_true + noise

In [None]:
# Best-fit parameters, then the condition number of the weighted normal
# matrix X^T diag(ivars) X (large values signal a near-singular problem).
a_hat = wls(X, ys, ivars)
np.linalg.cond(X.T @ (X * ivars[:, None]))

In [None]:
# Data with 1-sigma error bars (black) against the best-fit predictions (open red).
indices = np.arange(N)
plt.errorbar(indices, ys, yerr=1/np.sqrt(ivars), color="k", fmt="o")
plt.errorbar(indices, X @ a_hat, color="r", fmt="o", mfc="none")

In [None]:
# Profile likelihood of each parameter on a grid of trial values.
# For each trial value of a_i, the other parameters are set to their
# weighted-least-squares best fit (see log_profile_like). Each curve is
# then rescaled so that its maximum is 1.
ays = np.arange(-5.0, 5.001, 0.25)
lpls = np.zeros_like(ays)
for i in range(P):
    for k, ai in enumerate(ays):
        lpls[k] = log_profile_like(ai, i, X, ys, ivars)
    profile_like_ratios = np.exp(lpls - np.max(lpls))
    plt.plot(ays, profile_like_ratios)
# The axis limits and the zero reference line do not depend on i. Setting
# them once avoids stacking the alpha=0.5 line P times into a near-opaque line.
plt.xlim(-5., 5.)
plt.ylim(-0.1, 1.1)
plt.axhline(0., color="k", lw=0.5, alpha=0.5)

In [None]:
# Ensemble MCMC of the posterior. Walkers start in a tight ball around the
# origin (inside the prior box); a burn-in phase is run and discarded.
# NOTE(review): emcee presumably draws from its own generator rather than
# RNG, so chains may not be bit-reproducible from the seed above; confirm.
nwalkers = 100
nsample = 10000
nburn = 1000
p0 = 0.01 * RNG.normal(size=(nwalkers, P))
sampler = emcee.EnsembleSampler(nwalkers, P, log_post, args=[X, ys, ivars])
state = sampler.run_mcmc(p0, nburn)
sampler.reset()  # drop the burn-in samples, keep the walker positions
state = sampler.run_mcmc(state, nsample)

In [None]:
# Corner plot of all posterior samples (walkers flattened together),
# with each axis spanning that parameter's prior box.
a_samples = sampler.get_chain(flat=True)
labels = ("theta",) + ("alpha",) * (P - 1)
figure = corner.corner(a_samples, range=PRIOR_LIMITS, labels=labels)