In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gamma
import seaborn as sns

def plot_poisson_gamma_update(prior_alpha, prior_beta, data_sum, n_obs, true_theta=5):
    """
    Visualize how a Gamma prior on a Poisson rate updates after seeing data.

    Draws a two-panel figure: the left panel shows the prior density alone,
    the right panel overlays the prior and the posterior. Both panels mark
    ``true_theta`` with a dashed vertical line. The figure is returned and
    not shown; the caller decides when to call ``plt.show()``.

    Parameters:
    -----------
    prior_alpha : float
        Shape parameter of prior Gamma distribution (must be > 0)
    prior_beta : float
        Rate parameter of prior Gamma distribution (must be > 0)
    data_sum : float
        Sum of observed counts (must be >= 0)
    n_obs : int
        Number of observations (must be >= 0)
    true_theta : float
        True rate parameter (for reference, must be > 0). It also fixes the
        plotted range, which is ``[0, 2 * true_theta]``.

    Returns:
    --------
    fig : matplotlib figure
        The figure holding both panels.

    Raises:
    -------
    ValueError
        If any argument is outside the range stated above (including NaN).
    """

    # Validate before touching matplotlib so a bad call never leaves a
    # half-built figure behind. The `not x > 0` form also rejects NaN.
    if not prior_alpha > 0:
        raise ValueError(f"prior_alpha must be positive, got {prior_alpha!r}")
    if not prior_beta > 0:
        raise ValueError(f"prior_beta must be positive, got {prior_beta!r}")
    if not data_sum >= 0:
        raise ValueError(f"data_sum must be non-negative, got {data_sum!r}")
    if not n_obs >= 0:
        raise ValueError(f"n_obs must be non-negative, got {n_obs!r}")
    if not true_theta > 0:
        raise ValueError(f"true_theta must be positive, got {true_theta!r}")

    # Conjugate update: the shape absorbs the total count, the rate absorbs
    # the number of observations.
    post_alpha = prior_alpha + data_sum
    post_beta = prior_beta + n_obs

    # Create theta values for plotting
    theta = np.linspace(0, true_theta * 2, 1000)

    # scipy's gamma takes a scale parameter, i.e. the reciprocal of the rate.
    prior = gamma.pdf(theta, prior_alpha, scale=1/prior_beta)
    posterior = gamma.pdf(theta, post_alpha, scale=1/post_beta)

    prior_label = f'Prior: Gamma({prior_alpha},{prior_beta})'

    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: prior only
    ax1.plot(theta, prior, 'b-', label=prior_label)
    ax1.axvline(true_theta, color='r', linestyle='--', label='True θ')
    ax1.set_title('Prior Distribution')
    ax1.set_xlabel('θ')
    ax1.set_ylabel('Density')
    ax1.legend()

    # Right panel: prior and posterior on the same axes
    ax2.plot(theta, prior, 'b-', label=prior_label)
    ax2.plot(theta, posterior, 'g-', label=f'Posterior: Gamma({post_alpha},{post_beta})')
    ax2.axvline(true_theta, color='r', linestyle='--', label='True θ')
    ax2.set_title(f'Prior and Posterior after {n_obs} observations\nTotal count = {data_sum}')
    ax2.set_xlabel('θ')
    ax2.set_ylabel('Density')
    ax2.legend()

    plt.tight_layout()
    return fig

# Run the same observations (5 draws, total count 25, true rate 5) through
# two different priors so their influence on the posterior can be compared.
# fig1: informative prior Gamma(10, 2); fig2: weak prior Gamma(2, 1).
fig1, fig2 = (
    plot_poisson_gamma_update(
        prior_alpha=shape,
        prior_beta=rate,
        data_sum=25,
        n_obs=5,
        true_theta=5,
    )
    for shape, rate in ((10, 2), (2, 1))
)

# Display both figures once they have been built.
plt.show()