In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
from ipywidgets import widgets, HBox, VBox
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import numpy as np
from scipy.stats import beta

# Render figures at 300 dpi, both when displayed and when saved
plt.rcParams.update({'figure.dpi': 300, 'savefig.dpi': 300})

# Output widget that hosts the plot; plot_marble_posterior redraws inside it
output = widgets.Output()

# Define the plotting function
def _peak_normalized_beta_pdf(x, a, b):
    """Return the Beta(a, b) density on ``x`` rescaled so its largest value is 1.

    Rescaling lets curves with very different peak heights share the same
    0-1.1 y-axis. The 1e-10 floor avoids dividing by zero if the density is
    (numerically) zero everywhere on the grid.
    """
    density = beta.pdf(x, a, b)
    return density / max(density.max(), 1e-10)


def plot_marble_posterior(blue_draws, red_draws, prior_strength, show_prior,
                          show_credible_interval, show_midline, *,
                          credible_level=0.95):
    """Redraw the Beta posterior for the marble-drawing example in ``output``.

    The prior on θ (the proportion of blue marbles) is the symmetric
    Beta(prior_strength, prior_strength). Observing ``blue_draws`` blue and
    ``red_draws`` red marbles gives the conjugate posterior
    Beta(prior_strength + blue_draws, prior_strength + red_draws).
    Curves are rescaled to a peak of 1, so the y-axis shows relative density.

    Parameters
    ----------
    blue_draws, red_draws : int
        Number of blue / red marbles observed.
    prior_strength : float
        Shared shape parameter of the symmetric Beta prior (must be > 0).
    show_prior : bool
        Overlay the rescaled prior as a dashed line.
    show_credible_interval : bool
        Shade the equal-tailed credible interval of the posterior. Only drawn
        once at least one marble has been observed.
    show_midline : bool
        Draw a dashed vertical reference line at θ = 0.5.
    credible_level : float, optional
        Probability mass of the shaded interval, strictly between 0 and 1.
        Defaults to 0.95.

    Raises
    ------
    ValueError
        If ``credible_level`` is not strictly between 0 and 1.
    """
    if not 0.0 < credible_level < 1.0:
        raise ValueError(
            f"credible_level must be strictly between 0 and 1, got {credible_level!r}"
        )

    with output:
        # wait=True defers clearing until new output arrives, so the old plot
        # stays visible during the redraw
        clear_output(wait=True)
        plt.figure(figsize=(10, 6))

        # Candidate proportions of blue marbles. Stop short of exactly 0 and 1:
        # the Beta density is unbounded there when a shape parameter is < 1.
        x = np.linspace(0.001, 0.999, 1000)

        prior_alpha = prior_beta = prior_strength
        # Conjugate update: each blue adds to alpha, each red adds to beta
        posterior_alpha = prior_alpha + blue_draws
        posterior_beta = prior_beta + red_draws
        has_data = blue_draws + red_draws > 0

        if show_prior:
            plt.plot(x, _peak_normalized_beta_pdf(x, prior_alpha, prior_beta),
                     'b--', label='Prior belief', alpha=0.6)

        plt.plot(x, _peak_normalized_beta_pdf(x, posterior_alpha, posterior_beta),
                 'r-', label='Posterior belief', linewidth=2)

        if has_data and show_credible_interval:
            ci_low, ci_high = beta.interval(credible_level, posterior_alpha, posterior_beta)
            # '.4g' renders 0.95 as "95" and 0.975 as "97.5"
            plt.axvspan(ci_low, ci_high, alpha=0.2, color='gray',
                        label=f'{100 * credible_level:.4g}% Credible interval: '
                              f'({ci_low:.2f}, {ci_high:.2f})')

        if has_data:
            prob_less_than_half = beta.cdf(0.5, posterior_alpha, posterior_beta)
            # Empty, marker-less series: contributes a text-only legend entry
            plt.plot([], [], ' ', label=f'P(θ < 0.5) = {prob_less_than_half:.3f}')

        if show_midline:
            plt.axvline(x=0.5, color='gray', linestyle='--')

        plt.xlabel("Proportion of Blue Marbles (θ)")
        plt.ylabel("Relative Probability Density")
        plt.title(f"Posterior Distribution After Drawing {blue_draws} Blue and {red_draws} Red Marbles")
        plt.legend()
        # Small headroom above the rescaled peak of 1
        plt.ylim(0, 1.1)
        plt.grid(alpha=0.3)

        plt.tight_layout()
        plt.show()
        
# Sliders for the observed draws and the prior; each spans 800px
blue_slider = widgets.IntSlider(
    value=3, min=0, max=100, step=1,
    description='Blue draws:',
    layout=widgets.Layout(width='800px'),
)
red_slider = widgets.IntSlider(
    value=3, min=0, max=100, step=1,
    description='Red draws:',
    layout=widgets.Layout(width='800px'),
)
# Passed to plot_marble_posterior as the shared shape parameter of the prior
prior_slider = widgets.FloatSlider(
    value=1.0, min=0.5, max=3.0, step=0.01,
    description='Prior strength:',
    layout=widgets.Layout(width='800px'),
)

# Toggles for the optional plot overlays (all start unchecked)
prior_checkbox = widgets.Checkbox(description='Show prior', value=False)
credible_checkbox = widgets.Checkbox(description='Show credible interval', value=False)
midline_checkbox = widgets.Checkbox(description='Show line at θ=0.5', value=False)

# Triggers a redraw with the current control values
button = widgets.Button(description="Update Plot")

# Click handler for the "Update Plot" button
def on_button_clicked(b):
    """Redraw the posterior plot from the controls' current values.

    ``b`` is the button passed in by ``Button.on_click``; it is not used.
    """
    settings = (
        blue_slider.value,
        red_slider.value,
        prior_slider.value,
        prior_checkbox.value,
        credible_checkbox.value,
        midline_checkbox.value,
    )
    plot_marble_posterior(*settings)

# Redraw whenever the button is clicked
button.on_click(on_button_clicked)

# Stack the controls vertically, with the three display toggles on one row
controls = VBox([
    blue_slider,
    red_slider,
    prior_slider,
    HBox([prior_checkbox, credible_checkbox, midline_checkbox]),
    button,
])

# Show the controls above the plot area
display(controls, output)

# Draw the initial plot from the widgets' starting values, so it always
# matches the controls instead of duplicating their defaults here
on_button_clicked(None)

VBox(children=(IntSlider(value=3, description='Blue draws:', layout=Layout(width='800px')), IntSlider(value=3,…

Output()