In [None]:
from pathlib import Path

import numpy as np
from scipy.special import digamma, entr
from scipy.stats import beta
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


def bald(a, b):
    """Return log10 of the BALD score for a Beta-Bernoulli model.

    The score is the mutual information I[y; w] = H[y] - E_w[H[y | w]]
    (in nats) for y ~ Bernoulli(w), w ~ Beta(a, b). It is evaluated
    elementwise, so ``a`` and ``b`` may be scalars or broadcastable arrays.

    Parameters
    ----------
    a, b : float or array_like
        Beta pseudo-counts.

    Returns
    -------
    float or ndarray
        ``log10(I[y; w])``. This is ``-inf`` where the mutual information
        is exactly zero, e.g. a degenerate boundary with ``a == 0`` or
        ``b == 0``.
    """
    mu = a / (a + b)
    # Predictive entropy H[y] of Bernoulli(mu). entr(x) = -x*log(x) with
    # entr(0) = 0, so mu in {0, 1} yields 0 instead of 0 * -inf = nan.
    H_y = entr(mu) + entr(1 - mu)
    # Expected conditional entropy E_w[H[y | w]], using the Beta identity
    # E[w log w] = mu * (digamma(a + 1) - digamma(a + b + 1)) and its
    # mirror image for (1 - w).
    H_y_given_w = digamma(a + b + 1) - mu * digamma(a + 1) - (1 - mu) * digamma(b + 1)
    return np.log10(H_y - H_y_given_w)


def margin(a, b, r=0.4):
    """Return the Beta(a, b) density at ``r`` relative to its density at the mode.

    Parameters
    ----------
    a, b : float or array_like
        Beta parameters. The expression (a - 1) / (a + b - 2) is the interior
        mode only when a > 1 and b > 1.
    r : float, default 0.4
        Point at which the density is compared against the mode.

    Returns
    -------
    float or ndarray
        pdf(r) / pdf(mode). For a, b > 1 this lies in (0, 1].
    """
    peak = (a - 1) / (a + b - 2)
    # The ratio is taken in log space, where the Beta normalizing
    # constant cancels and very small densities do not underflow.
    log_pdf_r = beta.logpdf(r, a, b)
    log_pdf_peak = beta.logpdf(peak, a, b)
    return np.exp(log_pdf_r - log_pdf_peak)


def anom(a, b):
    """Return the mean ``a / (a + b)`` of a Beta(a, b) distribution, elementwise."""
    return a / (a + b)


# Grid of Beta(alpha, beta) parameters on which each score is contoured.
a = 1 + 10 ** np.linspace(-1, 2, 100)
b = 1 + 10 ** np.linspace(-1, 2, 100)
A, B = np.meshgrid(a, b)
# Starting pseudo-counts (red star) and the numbers d of extra counts
# whose reachable parameter sets are overlaid on each panel.
a0, b0 = 15, 10
delta_nu = [30, 60, 90]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, fn in zip(axes, [bald, margin, anom]):
    cs = ax.contourf(A, B, fn(A, B), levels=20, cmap="viridis")
    fig.colorbar(cs, ax=ax)

    ax.plot(a0, b0, "r*", markersize=10)
    for d in delta_nu:
        # Dotted: the full line alpha + beta = a0 + b0 + d.
        mu = np.linspace(0, 1, 100)
        ss = a0 + b0 + d
        a, b = ss * mu, ss * (1 - mu)
        ax.plot(a, b, "r:")

        # Solid: points (a0 + k, b0 + d - k) for k = 0..d, i.e. every way of
        # adding d counts to (a0, b0).
        a = a0 + np.arange(d + 1)
        b = b0 + np.arange(d + 1)[::-1]
        ax.plot(a, b, "r-", linewidth=3)

    ax.set_title(fn.__name__)
    ax.set_xlabel("$\\alpha$")
    ax.set_ylabel("$\\beta$")
    ax.set_xlim(1, A.max())
    ax.set_ylim(1, B.max())
    ax.set_aspect("equal", adjustable="box")

    # Inset: the score along the line for the first delta_nu only, plotted
    # against the Beta mean mu. Everything is drawn on `axins` explicitly
    # rather than relying on which axes pyplot currently considers active.
    axins = inset_axes(ax, width="50%", height="20%", loc="upper right")
    d = delta_nu[0]
    mu = np.linspace(0, 1, 100)
    ss = a0 + b0 + d
    a, b = ss * mu, ss * (1 - mu)
    axins.plot(mu, fn(a, b), "r:")

    a = a0 + np.arange(d + 1)
    b = b0 + np.arange(d + 1)[::-1]
    mu = a / (a + b)
    axins.plot(mu, fn(a, b), "r-", linewidth=3)

    axins.set_xticks([])
    axins.set_yticks([])
    axins.set_xlabel("$\\mu$")

fig.tight_layout()
# savefig does not create missing directories, so make sure figures/ exists.
out_path = Path("figures") / "bald_margin_anom.pdf"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()