# Create an animation of histogram pruning

(Warning: the code below is messy.)

In [None]:
from functools import partial
from tqdm.notebook import tqdm
from matplotlib.animation import FuncAnimation
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from IPython.display import display

from bayeshist.bayeshist import _prune_histogram, _bayes_factor_test


# Render animations as an HTML5 video when they are displayed in the notebook.
plt.rcParams["animation.html"] = "html5"

# Initial fine-grained histogram: 40 edges (39 bins) on [-4, 4], counted
# separately for the samples with test_y == 0 and test_y == 1.
bin_edges = np.linspace(-4, 4, 40)
neg_samples, _ = np.histogram(test_x[test_y == 0], bins=bin_edges)
pos_samples, _ = np.histogram(test_x[test_y == 1], bins=bin_edges)

# `prior_params` is added to the (positive, negative) counts when animate()
# builds its Beta / Beta-binomial distributions.
pruning_threshold = 2
prior_params = (1, 1000)
test = partial(_bayes_factor_test, threshold=pruning_threshold)
pruner = _prune_histogram(bin_edges, pos_samples, neg_samples, test, prior_params, yield_steps=True)

# Keep every non-tuple item the pruner yields; animate() indexes these by key.
# NOTE(review): the discarded tuple is presumably the final pruning result -
# confirm against _prune_histogram.
states = [state for state in pruner if not isinstance(state, tuple)]

fig = plt.figure(figsize=(9, 6))
# Upper limit of the count axis: 100x the tallest initial bin.
ylim = 1e2 * max(pos_samples.max(), neg_samples.max())

# Progress bar without a known total; animate() advances it once per call.
pbar = tqdm()
# animate() shows each state in `num_steps` logical frames, and maps the first
# `frame_cutoff` raw frames to logical frames at 10 raw frames each, so the
# first `speedup_after` states play in slow motion.
speedup_after = 3
num_steps = 5
frame_cutoff = 10 * speedup_after * num_steps


def animate(frameno):
    """Draw one frame of the histogram-pruning animation onto the global ``fig``.

    The raw frame number is first mapped to a logical frame: the first
    ``frame_cutoff`` raw frames advance one logical frame every 10 raw frames
    (slow motion), later raw frames advance one logical frame each. Every
    pruning state occupies ``num_steps`` consecutive logical frames, and the
    position within that group (``step``) controls how many annotations are
    revealed:

    * ``step > 0``: inset with the event rate distributions of the two
      neighbouring bins and of their merged bin,
    * ``step > 1``: data log likelihoods under the separate and merged models,
    * ``step > 2``: comparison of the test value with ``pruning_threshold``,
    * ``step > 3``: the resulting merge / don't merge decision.

    Logical frames past the last state show the final histogram.

    Reads the module-level names ``fig``, ``pbar``, ``states``, ``bin_edges``,
    ``test_x``, ``test_y``, ``prior_params``, ``pruning_threshold``,
    ``num_steps``, ``frame_cutoff`` and ``ylim``.

    Parameters
    ----------
    frameno : int
        Raw frame index as passed by ``FuncAnimation``.
    """
    pbar.update(1)

    # Slow-motion section first, then real time; the offset makes the logical
    # frame sequence continue without a jump at `frame_cutoff`.
    if frameno < frame_cutoff:
        frameno = frameno // 10
    else:
        frameno = frameno - (frame_cutoff - frame_cutoff // 10)

    state_idx, step = frameno // num_steps, frameno % num_steps

    prior_alpha, prior_beta = prior_params

    # Rebuild the whole figure for every frame.
    fig.clear()
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

    # Transparent overlay sharing the same rectangle; it carries the event
    # rate (probability) scale on the right-hand side.
    axprob = fig.add_axes([0.1, 0.1, 0.8, 0.8], frameon=False)
    axprob.grid(False)
    axprob.yaxis.tick_right()
    axprob.yaxis.set_label_position("right")
    axprob.set(
        xlim=(bin_edges[0], bin_edges[-1]),
        xticks=[],
        ylim=(1e-4, 1e1),
        yscale="log"
    )
    axprob.set_ylabel("Event rate", y=0.25)
    axprob.set_yticks([1e-4, 1e-3, 1e-2])
    axprob.tick_params(which="both", right=False)

    is_final_state = state_idx >= len(states)

    if is_final_state:
        # Past the last recorded state: hold the last state without annotations.
        state = states[-1]
        step = 0
    else:
        state = states[state_idx]

    i = state["i"]
    bins = state["bins"]
    bin_centers = 0.5 * (bins[1:] + bins[:-1])

    if is_final_state:
        ax.hist(test_x[test_y == 0], log=True, alpha=0.6, bins=bins, facecolor="C0", label="y = 0")
        ax.hist(test_x[test_y == 1], log=True, alpha=0.6, bins=bins, facecolor="C1", label="y = 1")
        fig.text(0.5, 0.85, "Final Bayesian histogram", ha="center", weight="bold")
    else:
        # Faded full histogram ...
        ax.hist(test_x[test_y == 0], log=True, alpha=0.4, bins=bins, facecolor="C0")
        ax.hist(test_x[test_y == 1], log=True, alpha=0.4, bins=bins, facecolor="C1")

        # ... with the two bins starting at index i drawn again, more opaque.
        ax.hist(test_x[test_y == 0], log=True, alpha=0.8, bins=bins[i:i+3], facecolor="C0", label="y = 0")
        ax.hist(test_x[test_y == 1], log=True, alpha=0.8, bins=bins[i:i+3], facecolor="C1", label="y = 1")

    # Per-bin Beta distribution of the event rate and its 1%-99% interval.
    event_dist = scipy.stats.beta(state["pos_samples"] + prior_alpha, state["neg_samples"] + prior_beta)
    ci_low, ci_high = event_dist.ppf(0.01), event_dist.ppf(0.99)

    # background boxes
    errorboxes = [
        Rectangle((x1, y1), x2 - x1, y2 - y1)
        for x1, x2, y1, y2
        in zip(bins[:-1], bins[1:], ci_low, ci_high)
    ]

    pc = PatchCollection(errorboxes, facecolor="0.2", alpha=0.2)
    axprob.add_collection(pc)

    # median indicator
    axprob.hlines(event_dist.median(), bins[:-1], bins[1:], colors="0.2", label="p(y = 1)")

    # box edges: drawn on the probability axis like the boxes they outline.
    # On the count axis (y-limits starting at 0.5) these sub-1 values would be
    # clipped and never visible.
    axprob.hlines(ci_low, bins[:-1], bins[1:], colors="0.2", alpha=0.8, linewidth=1)
    axprob.hlines(ci_high, bins[:-1], bins[1:], colors="0.2", alpha=0.8, linewidth=1)

    fig.legend(loc="upper center", ncol=3, frameon=False)

    if step > 0:
        axdist1 = fig.add_axes([0.16, 0.72, 0.2, 0.1])
        axdist1.axis("off")
        dist_x = np.logspace(-5, 0, 100)

        # These counts are reused by the `step > 1` block below.
        alpha_1, beta_1 = state["samples_1"]
        alpha_2, beta_2 = state["samples_2"]
        alpha_comb, beta_comb = alpha_1 + alpha_2, beta_1 + beta_2

        with np.errstate(divide='ignore'):
            axdist1.plot(dist_x, scipy.stats.beta(alpha_1 + prior_alpha, beta_1 + prior_beta).pdf(dist_x), c="0.2", label="original")
            axdist1.plot(dist_x, scipy.stats.beta(alpha_2 + prior_alpha, beta_2 + prior_beta).pdf(dist_x), c="0.2")
            axdist1.plot(dist_x, scipy.stats.beta(alpha_comb + prior_alpha, beta_comb + prior_beta).pdf(dist_x), c="coral", label="merged")

        axdist1.text(0.5, -0.05, "p(y = 1)", transform=axdist1.transAxes, va="top", ha="center", color="0.2")
        axdist1.set_xscale("log")
        axdist1.set_title("Event rate distributions", weight="bold")
        axdist1.legend(loc="upper right", frameon=False, labelcolor="linecolor", handlelength=0)

    if step > 1:
        # Beta-binomial log likelihood of each bin's counts under its own
        # parameters (dark text) and under the merged parameters (coral text).
        p_1 = scipy.stats.betabinom(alpha_1 + beta_1, alpha_1 + prior_alpha, beta_1 + prior_beta).logpmf(alpha_1)
        ax.text(bin_centers[i], beta_1, f"{p_1:.1f}", ha="center", va="bottom", fontsize=9)

        p_2 = scipy.stats.betabinom(alpha_2 + beta_2, alpha_2 + prior_alpha, beta_2 + prior_beta).logpmf(alpha_2)
        ax.text(bin_centers[i+1], beta_2, f"{p_2:.1f}", ha="center", fontsize=9)

        p_c1 = scipy.stats.betabinom(alpha_1 + beta_1, alpha_comb + prior_alpha, beta_comb + prior_beta).logpmf(alpha_1)
        ax.text(bin_centers[i], beta_1 * 2, f"{p_c1:.1f}", color="coral", ha="center", fontsize=9)

        p_c2 = scipy.stats.betabinom(alpha_2 + beta_2, alpha_comb + prior_alpha, beta_comb + prior_beta).logpmf(alpha_2)
        ax.text(bin_centers[i+1], beta_2 * 2, f"{p_c2:.1f}", color="coral", ha="center", fontsize=9)

        ax.text(bins[i+1], max(beta_1, beta_2) * 4, "Data log likelihood", ha="center", weight="bold")

    if step > 2:
        # The comparison is made on the raw test value but displayed in log space.
        compsign = "$>$" if state["test_value"] > pruning_threshold else "$\\ngtr$" 
        fig.text(0.8, 0.85, "Log likelihood $\\Delta$", ha="center", va="top", weight="bold")
        fig.text(0.8, 0.82, f"{np.log(state['test_value']):.2f} {compsign} log({pruning_threshold})", ha="center", va="top")

    if step > 3:
        merge_text = "merge" if state["reverse_split"] else "don't merge"
        ax.annotate("", xy=(0.8, 0.78), xytext=(0.8, 0.7), arrowprops=dict(arrowstyle="<-", color="black"), xycoords="figure fraction", textcoords="figure fraction")
        fig.text(0.8, 0.7, f"{merge_text}", ha="center", va="top", weight="bold")

    ax.set(
        xlabel="x",
        ylabel="Count",
        xlim=(bin_edges[0], bin_edges[-1]),
        ylim=(0.5, ylim),
    )
    
# Draw a single frame (raw frame 100) onto `fig`.
# NOTE(review): presumably a quick visual check before rendering everything.
animate(100)

In [None]:
# Number of raw frames to render. animate() maps a raw frame f >= frame_cutoff
# to logical frame f - (frame_cutoff - frame_cutoff // 10), so this covers
# len(states) * num_steps logical frames plus 20 extra ones. The extra frames
# fall past the last state, where animate() shows the final histogram.
num_frames = frame_cutoff - frame_cutoff // 10 + len(states) * num_steps + 20

# Rebind the global `pbar` that animate() updates to a bar with a known total.
with tqdm(total=num_frames) as pbar:
    # interval is the delay between frames in milliseconds.
    anim = FuncAnimation(fig, animate, frames=num_frames, interval=100)
    display(anim)

In [None]:
# Write the animation to an MP4 file (path is relative to the working directory).
anim.save("bayes-pruning.mp4")