In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import seaborn as sns

from thesis_utils.plotting import set_plotting, get_default_figsize, save_figure

set_plotting()

In [None]:
# Prior support [-5, 5]; scipy's uniform is parameterised by its lower edge
# (loc) and its width (scale).
bounds = np.array([-5, 5])
prior_u = stats.uniform(loc=bounds[0], scale=bounds[1] - bounds[0])

In [None]:
# Likelihood in the parameter itself: a standard normal.
l_dist = stats.norm(loc=0.0, scale=1.0)
# The same likelihood expressed in the radius r: a chi distribution with one
# degree of freedom.
l_r_dist = stats.chi(1)

In [None]:
def likelihood(r: np.ndarray) -> np.ndarray:
    """Evaluate the radial likelihood density at radius ``r``.

    Delegates to the ``pdf`` of the module-level ``l_r_dist`` distribution.
    """
    density = l_r_dist.pdf(r)
    return density

Define a function that maps a prior volume to a radius. It uses the `interval` method of scipy distributions and only works for distributions that are symmetric about zero.

In [None]:
def prior_vol_to_radius(X: np.ndarray, dist: stats.rv_continuous):
    """Convert prior volume to radius for a distribution symmetric about zero.

    The radius is the absolute value of the endpoints of ``dist.interval(X)``,
    i.e. the interval around the median of ``dist`` that contains a fraction
    ``X`` of the probability mass.

    Parameters
    ----------
    X : float or array_like
        Prior volume(s); passed unchanged to ``dist.interval``.
    dist : scipy.stats distribution
        Any object exposing ``interval`` (for example a frozen distribution).

    Returns
    -------
    numpy scalar or numpy.ndarray
        The radius for each entry of ``X`` (a scalar when ``X`` is a scalar).

    Raises
    ------
    AssertionError
        If the magnitudes of the two interval endpoints differ by more than
        1e-6, meaning ``dist`` is not symmetric about zero.
    """
    lower, upper = np.abs(dist.interval(X))
    # np.allclose, unlike the builtin all(), accepts the 0-d result produced
    # by a scalar X, and it treats matching infinite endpoints as equal
    # (inf - inf would otherwise give NaN and fail the check).
    assert np.allclose(lower, upper, rtol=0.0, atol=1e-6), (
        "interval endpoints are not symmetric about zero"
    )
    return lower

In [None]:
# Regular 500 x 500 grid over the prior bounds, plus a flattened view with one
# (x, y) point per row.
x_vec = np.linspace(bounds[0], bounds[1], 500)
grid = np.array(np.meshgrid(x_vec, x_vec))
grid_flat = np.column_stack((grid[0].ravel(), grid[1].ravel()))

In [None]:
# Euclidean distance of every grid point from the origin, then the likelihood
# evaluated at that radius.
r_grid = np.linalg.norm(grid_flat, ord=2, axis=1)
likelihood_surface = likelihood(r_grid)

In [None]:
# Prior volumes at which likelihood levels are drawn, with the matching radii
# and likelihood values.
X_shade = [0.9, 0.60, 0.30, 0.1]
r_shade = prior_vol_to_radius(X_shade, prior_u)
l_shade = likelihood(r_shade)

# One colour value per level, evenly spaced between c_min and c_max.
c_min, c_max = 0.25, 0.75
c_values = np.linspace(c_min, c_max, len(X_shade), endpoint=True)

# Start from zero everywhere and stamp each level's colour value onto the
# points at or above that level's likelihood; later levels overwrite earlier
# ones.
contour_colours = np.zeros_like(likelihood_surface)
for shade, threshold in zip(c_values, l_shade):
    contour_colours[likelihood_surface >= threshold] = shade


In [None]:
# Likelihood as a function of prior volume, sampled from X = 1 down to X = 0.
X_vec = np.linspace(1.0, 0.0, num=1_000)
r_u = prior_vol_to_radius(X_vec, prior_u)
l_r = likelihood(r_u)

In [None]:
# Sequential "Blues" colormap shared by both panels of the figure.
# (A cubehelix palette was previously assigned here and then immediately
# overwritten; that dead assignment has been dropped.)
cmap = sns.color_palette("Blues", as_cmap=True)

In [None]:
# Bare expression: shows the selected colormap as this cell's output.
cmap

In [None]:
# Two-panel schematic: (left) the shaded likelihood levels on the 2-D grid,
# (right) the likelihood as a function of prior volume X with the matching
# regions shaded in the same colormap.

# Scale the default size and set the height equal to the width.
# NOTE(review): assumes get_default_figsize() returns a mutable array-like
# that supports scalar multiplication (e.g. a numpy array) - confirm.
figsize = get_default_figsize() * 0.8
figsize[1] = figsize[0]

fig, axs = plt.subplots(1, 2, figsize=figsize)

# --- Left panel: filled contours of the per-level colour values -------------
axs[0].contourf(
    grid[0],
    grid[1],
    contour_colours.reshape(len(grid[0]), -1),
    cmap=cmap,
)

# Label each level at its radius along the 45-degree diagonal.
angle = np.pi / 4
for i, r in enumerate(r_shade):
    axs[0].text(
        r * np.cos(angle),
        r * np.sin(angle),
        r"$\mathcal{L}_{" + str(i) + r"}$",
    )

axs[0].set_box_aspect(1)
axs[0].set_xticks([])
axs[0].set_yticks([])

# --- Right panel: likelihood versus prior volume -----------------------------
axs[1].plot(X_vec, l_r, color=cmap(1.0))
axs[1].set_yscale("log")
axs[1].set_xlim([0, 1])
axs[1].set_box_aspect(1)

colours = cmap(c_values)

# Fill under the whole curve with the colormap's lowest colour, then overlay
# one fill per level covering the part of the curve with X below that level.
axs[1].fill_between(X_vec, l_r, color=cmap(0.0))

for i, (X_i, L_i) in enumerate(zip(X_shade, l_shade)):
    # X_vec is descending, so this is the first sample strictly below X_i.
    cutoff = np.argmax(X_vec < X_i)
    axs[1].fill_between(
        X_vec[cutoff:],
        l_r[cutoff:],
        color=colours[i],
    )
    axs[1].text(
        X_i,
        L_i,
        r"$\mathcal{L}_{" + str(i) + r"}$",
    )

axs[1].set_xlabel(r"$X$")
axs[1].set_ylabel(r"$\bar{\mathcal{L}}(X)$")

# Derive one tick label per shaded prior volume from X_shade itself so the
# labels stay in step with the ticks; all labels are strings and subscripts
# are braced so multi-digit indices render correctly.
x_tick_labels = ["0", *(f"$X_{{{i}}}$" for i in range(len(X_shade))), "1"]
axs[1].set_xticks([0, *X_shade, 1])
axs[1].set_xticklabels(x_tick_labels)
axs[1].set_yticks([])
axs[1].tick_params(axis="both", which="both", length=0)

save_figure(fig, "nest_plot_exact")

plt.show()