# Solution: Gaussian Bayesian Update (Problem 1)

## Problem Setup

In class, we derived the posterior and predictive distributions for a Gaussian-Gaussian model:

**Generative Process:**
$$\mu \sim \mathcal{N}(\mu_0, \sigma_0^2)$$
$$x_1, \ldots, x_N | \mu, \sigma_x^2 \overset{iid}{\sim} \mathcal{N}(\mu, \sigma_x^2)$$

**Posterior Distribution:**
$$\mu | x_1, \ldots, x_N \sim \mathcal{N}\left( \frac{\mu_0 \sigma_0^{-2} + \sigma_x^{-2} \sum_{n=1}^N x_n}{\sigma_0^{-2} + N \sigma_x^{-2}}, \left[ \sigma_0^{-2} + N \sigma_x^{-2} \right]^{-1} \right)$$

**Predictive Distribution:**
$$x_{N+1} | x_1, \ldots, x_N \sim \mathcal{N}\left( \frac{\mu_0 \sigma_0^{-2} + \sigma_x^{-2} \sum_{n=1}^N x_n}{\sigma_0^{-2} + N \sigma_x^{-2}}, \left[ \sigma_0^{-2} + N \sigma_x^{-2} \right]^{-1} + \sigma_x^2 \right)$$
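
The extra $\sigma_x^2$ in the predictive variance can be seen by writing the next observation as the unknown mean plus independent noise. A short derivation sketch, using $\mu_N$ and $\sigma_N^2$ as shorthand for the posterior mean and variance above:

$$x_{N+1} = \mu + \epsilon, \qquad \mu \mid x_{1:N} \sim \mathcal{N}(\mu_N, \sigma_N^2), \qquad \epsilon \sim \mathcal{N}(0, \sigma_x^2) \ \text{independent of } \mu$$

$$\mathbb{E}[x_{N+1} \mid x_{1:N}] = \mu_N, \qquad \mathrm{Var}[x_{N+1} \mid x_{1:N}] = \sigma_N^2 + \sigma_x^2$$

Since a sum of independent Gaussians is itself Gaussian, the predictive distribution is Gaussian with this mean and variance.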

For this problem, use $\mu_0 = 0$ and $\sigma_0^2 = 1$.

We will explore how the number of data points and variance of the likelihood affect the posterior and predictive distributions.

---

## 📚 Reviewing Key Concepts

This problem builds on concepts from earlier chapters:

**From [Tutorial 1, Chapter 5 - Bayes' Theorem](../../content/intro/05_bayes.md)**:
- Remember Bayes' rule: **Posterior ∝ Likelihood × Prior**
- We're applying it to continuous distributions here!
- $p(\mu|data) = \frac{p(data|\mu) \cdot p(\mu)}{p(data)}$

**From [Tutorial 2, Chapter 3 - Gaussian Distribution](../../content/intro2/03_gaussian.md)**:
- The Gaussian (Normal) distribution $\mathcal{N}(\mu, \sigma^2)$ and its bell-curve shape
- The 68-95-99.7 rule for standard deviations
- Why Gaussians appear everywhere (Central Limit Theorem)

**From [Tutorial 2, Chapter 4 - Bayesian Learning](../../content/intro2/04_bayesian_learning.md)**:
- **Conjugate priors**: Gaussian prior + Gaussian likelihood = Gaussian posterior
- **Precision-weighted averaging**: Posterior mean balances prior and data
- **Sequential learning**: Update one observation at a time
- **Predictive distribution**: Combines posterior uncertainty + data variance
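
The sequential-learning point can be checked numerically. The sketch below is not part of the notebook's own helpers: `gaussian_update` is a hypothetical function written for this illustration, and the data are the five values used later in Problem 1(c). It updates one observation at a time, feeding each posterior back in as the next prior, and compares the result with a single batch update.

```python
def gaussian_update(prior_mean, prior_var, xs, noise_var):
    """Posterior mean and variance of mu after observing xs (known noise_var)."""
    precision = 1.0 / prior_var + len(xs) / noise_var
    mean = (prior_mean / prior_var + sum(xs) / noise_var) / precision
    return mean, 1.0 / precision

observations = [2.1, 2.5, 1.4, 2.2, 1.8]   # illustrative data
noise_var = 0.25                            # assumed known likelihood variance

# Batch: condition on all observations at once
batch_mean, batch_var = gaussian_update(0.0, 1.0, observations, noise_var)

# Sequential: each posterior becomes the prior for the next observation
seq_mean, seq_var = 0.0, 1.0
for x in observations:
    seq_mean, seq_var = gaussian_update(seq_mean, seq_var, [x], noise_var)

print(f"batch:      mean={batch_mean:.4f}, var={batch_var:.4f}")
print(f"sequential: mean={seq_mean:.4f}, var={seq_var:.4f}")
```

Both routes should agree up to floating-point rounding, because precisions add and the order in which observations arrive does not matter in this model.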

**What's new in this assignment:**
- **Systematic exploration**: How do $\sigma_x^2$ and $N$ affect learning?
- **Visual intuition**: See the precision-weighting in action
- **Verification**: Compare analytical formulas with GenJAX simulations

In [None]:
# Import packages
import jax
import jax.numpy as jnp
import jax.random as random
from genjax import gen, normal
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Configure matplotlib
plt.style.use('seaborn-v0_8-whitegrid')
%matplotlib inline

# Set random seed
np.random.seed(42)
key = random.PRNGKey(42)

## Helper Functions

In [None]:
def update_posterior(mu_0, sigma_0_squared, x_s, sigma_x_squared, x_min, x_max):
    """
    Analytical Bayesian update for Gaussian-Gaussian conjugate prior.
    
    Args:
        mu_0: Prior mean
        sigma_0_squared: Prior variance
        x_s: List or array of observations
        sigma_x_squared: Likelihood variance (known)
        x_min: Minimum x value for plotting
        x_max: Maximum x value for plotting
    
    Returns:
        posterior_mu: Posterior mean
        posterior_pdf: Posterior PDF values
        predictive_mu: Predictive mean
        predictive_pdf: Predictive PDF values
    """
    n = len(x_s)
    
    # Posterior parameters
    posterior_mu = (mu_0 / sigma_0_squared + sum(x_s) / sigma_x_squared) / \
                   (1 / sigma_0_squared + n / sigma_x_squared)
    posterior_sigma_squared = 1 / (1 / sigma_0_squared + n / sigma_x_squared)
    posterior_sigma = np.sqrt(posterior_sigma_squared)
    
    # Predictive parameters
    predictive_mu = posterior_mu
    predictive_sigma_squared = posterior_sigma_squared + sigma_x_squared
    predictive_sigma = np.sqrt(predictive_sigma_squared)
    
    # Compute PDFs for plotting
    x_range = np.linspace(x_min, x_max, 1000)
    posterior_pdf = norm.pdf(x_range, posterior_mu, posterior_sigma)
    predictive_pdf = norm.pdf(x_range, predictive_mu, predictive_sigma)
    
    return posterior_mu, posterior_pdf, predictive_mu, predictive_pdf

## GenJAX Implementation

In [None]:
@gen
def gaussian_learning_model(observations, mu_0, sigma_0, sigma_x):
    """
    GenJAX generative model for Gaussian learning.
    
    Args:
        observations: Observed data points
        mu_0: Prior mean
        sigma_0: Prior standard deviation
        sigma_x: Likelihood standard deviation (known)
    """
    # Prior on unknown mean
    mu = normal(mu_0, sigma_0) @ "mu"
    
    # Generate observations
    for i in range(len(observations)):
        x = normal(mu, sigma_x) @ f"obs_{i}"
    
    return mu

@gen
def posterior_predictive(posterior_mu, posterior_sigma, sigma_x):
    """
    Sample from posterior predictive distribution.
    
    Args:
        posterior_mu: Posterior mean for mu
        posterior_sigma: Posterior std dev for mu
        sigma_x: Likelihood standard deviation
    """
    # Sample mu from posterior
    mu = normal(posterior_mu, posterior_sigma) @ "mu"
    
    # Sample new observation
    x_new = normal(mu, sigma_x) @ "x_new"
    
    return x_new

---

## Problem 1(a): Prior Distribution

Plot the prior distribution to provide a baseline.

In [None]:
# Prior parameters
mu_0 = 0
sigma_0 = 1

# Axis range
x = np.linspace(-8, 8, 1000)

# Density
y = norm.pdf(x, mu_0, sigma_0)

# Plot prior
plt.figure(figsize=(10, 6))
plt.plot(x, y, label=r'$\mathcal{N}(\mu_0=0, \sigma_0^2=1)$', color='red', linewidth=2)
plt.axvline(mu_0, color='red', linestyle='--', linewidth=1, alpha=0.5, label=f'Prior mean: {mu_0}')

# Mark 68-95-99.7 regions
plt.axvline(mu_0 - sigma_0, color='gray', linestyle=':', linewidth=1, alpha=0.5)
plt.axvline(mu_0 + sigma_0, color='gray', linestyle=':', linewidth=1, alpha=0.5, label='±1σ (68%)')
plt.axvline(mu_0 - 2*sigma_0, color='gray', linestyle=':', linewidth=1, alpha=0.3)
plt.axvline(mu_0 + 2*sigma_0, color='gray', linestyle=':', linewidth=1, alpha=0.3, label='±2σ (95%)')

plt.xlabel(r'$\mu$', fontsize=14)
plt.ylabel('Density', fontsize=14)
plt.title('Prior Distribution', fontsize=16)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\n📊 Interpretation:")
print(f"  The prior distribution is centered at μ = {mu_0}, with standard deviation σ = {sigma_0}.")
print(f"  About 68% of the prior mass is between {mu_0 - sigma_0} and {mu_0 + sigma_0}.")
print(f"  About 95% of the prior mass is between {mu_0 - 2*sigma_0} and {mu_0 + 2*sigma_0}.")
print("  The density is negligible beyond ±3σ.")
print("\n  📖 Recall the 68-95-99.7 rule from Chapter 3:")
print("     The rule holds for every Gaussian; for the standard Gaussian N(0,1), σ = 1, so the bounds are simply ±1, ±2, ±3.")
print("     [Review: Tutorial 2, Chapter 3 - Gaussian Distribution]")

---

## Problem 1(b): One Datum Update

Calculate and plot the posterior and predictive distributions after observing $x_1 = 2$ for:
- $\sigma_x^2 = 0.25$ (small variance, precise measurements)
- $\sigma_x^2 = 4$ (large variance, noisy measurements)

**Question**: How does changing the variance of the likelihood affect the distributions?

In [None]:
# Prior parameters
mu_0 = 0
sigma_0_squared = 1

# Observation
x_1 = 2

# Likelihood variances to compare
sigma_x_squared_values = [0.25, 4]

# Plot range
x_min = -8
x_max = 8
x_range = np.linspace(x_min, x_max, 1000)

# Create figure
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for i, sigma_x_squared in enumerate(sigma_x_squared_values):
    # Update
    posterior_mu, posterior_pdf, predictive_mu, predictive_pdf = update_posterior(
        mu_0, sigma_0_squared, [x_1], sigma_x_squared, x_min, x_max
    )
    
    # Posterior distribution
    axes[i, 0].plot(x_range, posterior_pdf, label=f'Posterior (σ²_x={sigma_x_squared})',
                   color='blue', linewidth=2)
    axes[i, 0].axvline(posterior_mu, color='blue', linestyle='--', linewidth=1.5, 
                      label=f'Posterior mean = {posterior_mu:.2f}')
    axes[i, 0].axvline(x_1, color='red', linestyle=':', linewidth=1.5, 
                      label=f'Observation: {x_1}')
    
    # Add prior for comparison
    prior_pdf = norm.pdf(x_range, mu_0, np.sqrt(sigma_0_squared))
    axes[i, 0].plot(x_range, prior_pdf, 'k--', linewidth=1.5, alpha=0.5, label='Prior')
    
    axes[i, 0].set_title(f'Posterior Distribution (σ²_x = {sigma_x_squared})', fontsize=13)
    axes[i, 0].set_xlabel('μ', fontsize=12)
    axes[i, 0].set_ylabel('Density', fontsize=12)
    axes[i, 0].legend(fontsize=10)
    axes[i, 0].grid(True, alpha=0.3)
    
    # Predictive distribution
    axes[i, 1].plot(x_range, predictive_pdf, label=f'Predictive (σ²_x={sigma_x_squared})',
                   color='orange', linewidth=2)
    axes[i, 1].axvline(predictive_mu, color='orange', linestyle='--', linewidth=1.5, 
                      label=f'Predictive mean = {predictive_mu:.2f}')
    axes[i, 1].axvline(x_1, color='red', linestyle=':', linewidth=1.5, 
                      label=f'Observation: {x_1}')
    
    axes[i, 1].set_title(f'Predictive Distribution (σ²_x = {sigma_x_squared})', fontsize=13)
    axes[i, 1].set_xlabel('x', fontsize=12)
    axes[i, 1].set_ylabel('Density', fontsize=12)
    axes[i, 1].legend(fontsize=10)
    axes[i, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Interpretation

**Effect of Likelihood Variance:**

According to the update formula for posterior distributions:

$$\sigma_N^2 = \left[ \sigma_0^{-2} + N \sigma_x^{-2} \right]^{-1}$$

A **smaller likelihood variance** ($\sigma_x^2$) results in:
- **A sharper posterior**: The posterior variance is smaller, so the posterior concentrates more tightly around its peak
- **Stronger data influence**: The posterior mean moves closer to the observed data
- This is because the posterior mean is a precision-weighted average:
  $$\mu_N = \frac{\text{precision}_{\text{prior}} \times \mu_0 + \text{precision}_{\text{data}} \times \bar{x}}{\text{precision}_{\text{prior}} + \text{precision}_{\text{data}}}$$
  
Where precision = $1/\text{variance}$. Smaller $\sigma_x^2$ means higher data precision, so the data gets more weight.

**Key Observations:**
- With $\sigma_x^2 = 0.25$ (precise data): Posterior is narrow and close to $x_1 = 2$
- With $\sigma_x^2 = 4$ (noisy data): Posterior is wider and stays closer to prior mean (0)
- The predictive distribution is always more dispersed than the posterior (adds $\sigma_x^2$)
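
To make the precision weighting concrete, the weight placed on the data can be computed directly. This is a minimal sketch: `data_weight` is a hypothetical helper introduced here (it is not one of the notebook's functions), and the inputs are the prior and observation from the problem statement.

```python
def data_weight(prior_var, noise_var, n=1):
    """Fraction of the posterior mean contributed by the sample mean."""
    prior_precision = 1.0 / prior_var
    data_precision = n / noise_var
    return data_precision / (prior_precision + data_precision)

prior_mean, prior_var, obs = 0.0, 1.0, 2.0   # mu_0, sigma_0^2, x_1 from the problem

for noise_var in (0.25, 4.0):
    w = data_weight(prior_var, noise_var)
    # Posterior mean as a convex combination of prior mean and observation
    post_mean = (1 - w) * prior_mean + w * obs
    print(f"sigma_x^2 = {noise_var}: data weight = {w:.2f}, posterior mean = {post_mean:.2f}")
```

With $\sigma_x^2 = 0.25$ the observation carries 80% of the weight (posterior mean 1.60); with $\sigma_x^2 = 4$ it carries only 20% (posterior mean 0.40).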

In [None]:
# Numerical comparison
print("\n📊 Numerical Comparison:\n")
print(f"{'Quantity':<30} {'σ²_x = 0.25':<20} {'σ²_x = 4':<20}")
print("-" * 70)

# Collect each quantity for both likelihood variances, then print one row per quantity
rows = {'Posterior mean:': [], 'Posterior variance:': [], 'Predictive variance:': []}
for sigma_x_squared in sigma_x_squared_values:
    post_mu, _, _, _ = update_posterior(mu_0, sigma_0_squared, [x_1], sigma_x_squared, x_min, x_max)

    # Variances for a single observation (N = 1)
    post_var = 1 / (1/sigma_0_squared + 1/sigma_x_squared)
    pred_var = post_var + sigma_x_squared

    rows['Posterior mean:'].append(post_mu)
    rows['Posterior variance:'].append(post_var)
    rows['Predictive variance:'].append(pred_var)

for name, (small_var_value, large_var_value) in rows.items():
    print(f"{name:<30} {small_var_value:<20.2f} {large_var_value:<20.2f}")

# Final summary
print("\n" + "="*70)
print("\n✅ Conclusion:")
print("  • Smaller σ²_x → Posterior closer to data, more concentrated")
print("  • Larger σ²_x → Posterior closer to prior, more dispersed")
print("  • Predictive always has larger variance than posterior")

---

## Problem 1(c): Multiple Data Update

Calculate and plot the posterior and predictive distributions given:
$$(x_1, \ldots, x_5) = (2.1, 2.5, 1.4, 2.2, 1.8)$$

for $\sigma_x^2 = 0.25$ and $\sigma_x^2 = 4$.

**Question**: How does this compare to the single observation case? Note that the average is 2.0 in both cases.

In [None]:
# Prior parameters
mu_0 = 0
sigma_0_squared = 1

# Observations
x_values = np.array([2.1, 2.5, 1.4, 2.2, 1.8])
N = len(x_values)
sample_mean = np.mean(x_values)

print(f"Observations: {x_values}")
print(f"Sample size: N = {N}")
print(f"Sample mean: {sample_mean:.2f}")
print(f"\nNote: In part (b), we had 1 observation at x₁ = 2.0")
print(f"      In part (c), we have 5 observations with mean = 2.0")
print(f"      Both have the same average value!\n")

# Likelihood variances
sigma_x_squared_values = [0.25, 4]

# Create figure
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for i, sigma_x_squared in enumerate(sigma_x_squared_values):
    # Update with 1 observation (from part b)
    posterior_mu_1, posterior_pdf_1, predictive_mu_1, predictive_pdf_1 = update_posterior(
        mu_0, sigma_0_squared, [x_1], sigma_x_squared, x_min, x_max
    )
    
    # Update with 5 observations (part c)
    posterior_mu, posterior_pdf, predictive_mu, predictive_pdf = update_posterior(
        mu_0, sigma_0_squared, x_values, sigma_x_squared, x_min, x_max
    )
    
    # Posterior distribution comparison
    axes[i, 0].plot(x_range, posterior_pdf, label=f'Posterior (N=5, σ²_x={sigma_x_squared})',
                   color='blue', linewidth=2)
    axes[i, 0].plot(x_range, posterior_pdf_1, label=f'Posterior (N=1, σ²_x={sigma_x_squared})',
                   color='blue', linewidth=2, alpha=0.3, linestyle='--')
    
    axes[i, 0].axvline(posterior_mu, color='blue', linestyle='--', linewidth=1.5, alpha=0.7)
    axes[i, 0].axvline(posterior_mu_1, color='blue', linestyle='--', linewidth=1, alpha=0.3)
    axes[i, 0].axvline(sample_mean, color='red', linestyle=':', linewidth=2, 
                      label=f'Sample mean: {sample_mean:.2f}')
    
    axes[i, 0].set_title(f'Posterior Distribution (σ²_x = {sigma_x_squared})', fontsize=13)
    axes[i, 0].set_xlabel('μ', fontsize=12)
    axes[i, 0].set_ylabel('Density', fontsize=12)
    axes[i, 0].legend(fontsize=9)
    axes[i, 0].grid(True, alpha=0.3)
    
    # Predictive distribution comparison
    axes[i, 1].plot(x_range, predictive_pdf, label=f'Predictive (N=5, σ²_x={sigma_x_squared})',
                   color='orange', linewidth=2)
    axes[i, 1].plot(x_range, predictive_pdf_1, label=f'Predictive (N=1, σ²_x={sigma_x_squared})',
                   color='orange', linewidth=2, alpha=0.3, linestyle='--')
    
    axes[i, 1].axvline(predictive_mu, color='orange', linestyle='--', linewidth=1.5, alpha=0.7)
    axes[i, 1].axvline(predictive_mu_1, color='orange', linestyle='--', linewidth=1, alpha=0.3)
    axes[i, 1].axvline(sample_mean, color='red', linestyle=':', linewidth=2, 
                      label=f'Sample mean: {sample_mean:.2f}')
    
    axes[i, 1].set_title(f'Predictive Distribution (σ²_x = {sigma_x_squared})', fontsize=13)
    axes[i, 1].set_xlabel('x', fontsize=12)
    axes[i, 1].set_ylabel('Density', fontsize=12)
    axes[i, 1].legend(fontsize=9)
    axes[i, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Interpretation

**Comparison: 1 observation vs 5 observations (both with mean = 2.0)**

The key insight is that the posterior variance depends on **both** the likelihood variance AND the number of observations:

$$\sigma_N^2 = \left[ \sigma_0^{-2} + \frac{N}{\sigma_x^2} \right]^{-1}$$

**Effect of increasing N (given fixed $\sigma_x^2$):**

1. **Posterior becomes more concentrated**: More observations → smaller $\sigma_N^2$
2. **Posterior mean moves closer to sample mean**: Higher data precision
3. **Predictive distribution also becomes more concentrated**: Smaller $\sigma_N^2$ component

**Why the difference between N=1 and N=5?**

Even though both have the same mean (2.0), the **effective precision** of the data is different:
- N=1: Data precision = $1/\sigma_x^2$
- N=5: Data precision = $5/\sigma_x^2$ (5× higher!)

**Specific Observations:**
- With $\sigma_x^2 = 0.25$: Both posteriors are narrow, but N=5 is much sharper
- With $\sigma_x^2 = 4$: N=1 posterior stays close to prior; N=5 shifts significantly toward data
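- The posterior means are **not** the same, even though the sample mean is: the higher data precision at N=5 pulls the posterior mean closer to 2.0 (about 1.90 vs 1.60 for $\sigma_x^2 = 0.25$; about 1.11 vs 0.40 for $\sigma_x^2 = 4$)
- The **confidence** (inverse of variance) is also much higher with N=5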
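
The numbers behind these observations follow directly from the posterior formulas. A worked calculation with $\mu_0 = 0$, $\sigma_0^2 = 1$, $x_1 = 2$ for the single-observation case and $\sum_{n=1}^5 x_n = 10$ for the five-observation case:

$$\sigma_x^2 = 0.25: \quad \mu_1 = \frac{2/0.25}{1 + 1/0.25} = \frac{8}{5} = 1.6, \qquad \mu_5 = \frac{10/0.25}{1 + 5/0.25} = \frac{40}{21} \approx 1.905$$

$$\sigma_x^2 = 4: \quad \mu_1 = \frac{2/4}{1 + 1/4} = \frac{0.5}{1.25} = 0.4, \qquad \mu_5 = \frac{10/4}{1 + 5/4} = \frac{2.5}{2.25} \approx 1.111$$

The corresponding posterior variances are $\sigma_1^2 = 1/5 = 0.2$ and $\sigma_5^2 = 1/21 \approx 0.048$ for $\sigma_x^2 = 0.25$, and $\sigma_1^2 = 1/1.25 = 0.8$ and $\sigma_5^2 = 1/2.25 \approx 0.444$ for $\sigma_x^2 = 4$.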

In [None]:
# Numerical comparison
print("\n📊 Numerical Comparison: N=1 vs N=5\n")
print(f"{'Quantity':<35} {'N=1, σ²_x=0.25':<18} {'N=5, σ²_x=0.25':<18} {'N=1, σ²_x=4':<18} {'N=5, σ²_x=4':<18}")
print("-" * 110)

results = {}
for n_obs in [1, 5]:
    obs = [x_1] if n_obs == 1 else x_values
    for sigma_x_sq in sigma_x_squared_values:
        post_mu, _, pred_mu, _ = update_posterior(mu_0, sigma_0_squared, obs, sigma_x_sq, x_min, x_max)
        post_var = 1 / (1/sigma_0_squared + n_obs/sigma_x_sq)
        pred_var = post_var + sigma_x_sq
        results[(n_obs, sigma_x_sq)] = {
            'post_mu': post_mu,
            'post_var': post_var,
            'post_std': np.sqrt(post_var),
            'pred_var': pred_var,
            'pred_std': np.sqrt(pred_var)
        }

print(f"{'Posterior mean:':<35} {results[(1,0.25)]['post_mu']:<18.3f} {results[(5,0.25)]['post_mu']:<18.3f} "
      f"{results[(1,4)]['post_mu']:<18.3f} {results[(5,4)]['post_mu']:<18.3f}")
print(f"{'Posterior std dev:':<35} {results[(1,0.25)]['post_std']:<18.3f} {results[(5,0.25)]['post_std']:<18.3f} "
      f"{results[(1,4)]['post_std']:<18.3f} {results[(5,4)]['post_std']:<18.3f}")
print(f"{'Predictive std dev:':<35} {results[(1,0.25)]['pred_std']:<18.3f} {results[(5,0.25)]['pred_std']:<18.3f} "
      f"{results[(1,4)]['pred_std']:<18.3f} {results[(5,4)]['pred_std']:<18.3f}")

print("\n" + "="*110)
print("\n✅ Key Findings:")
print(f"  • Posterior mean moves closer to the sample mean ({sample_mean:.1f}) with N=5 than with N=1, even though the sample mean is the same")
print(f"  • Posterior std dev DECREASES with more observations (N↑ → uncertainty↓)")
print(f"  • With σ²_x=0.25 (precise data): N=5 gives much sharper posterior than N=1")
print(f"  • With σ²_x=4 (noisy data): Effect on posterior width is less dramatic but still significant")
print(f"  • More data = more confidence, even if the sample mean stays the same!")

---

## GenJAX Verification

Let's verify our analytical results using GenJAX simulations!
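
As a cross-check that does not depend on GenJAX, the same two-step sampling (draw $\mu$ from the posterior, then draw a new observation given $\mu$) can be sketched with Python's standard library. The sample count and seed below are arbitrary choices for illustration, and the module is imported as `py_random` so it does not shadow the notebook's `jax.random` alias.

```python
import random as py_random
import statistics

# Analytical posterior for the N=5, sigma_x^2 = 0.25 case (mu_0 = 0, sigma_0^2 = 1)
data = [2.1, 2.5, 1.4, 2.2, 1.8]
noise_var = 0.25
post_precision = 1.0 + len(data) / noise_var
post_mean = (sum(data) / noise_var) / post_precision
post_var = 1.0 / post_precision

# Ancestral sampling: draw mu from the posterior, then x_new given mu
rng = py_random.Random(0)       # fixed seed so the run is repeatable
draws = []
for _ in range(20000):
    mu_draw = rng.gauss(post_mean, post_var ** 0.5)
    draws.append(rng.gauss(mu_draw, noise_var ** 0.5))

mc_mean = statistics.fmean(draws)
mc_var = statistics.pvariance(draws)
print(f"Monte Carlo: mean={mc_mean:.3f}, var={mc_var:.3f}")
print(f"Analytical:  mean={post_mean:.3f}, var={post_var + noise_var:.3f}")
```

The Monte Carlo estimates should land close to the analytical predictive mean and variance, with the gap shrinking as the number of draws grows.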

In [None]:
# Verify with GenJAX simulation
print("🔬 GenJAX Verification: Posterior Predictive Sampling\n")

# Use the N=5, σ²_x=0.25 case
sigma_x_squared = 0.25
sigma_x = np.sqrt(sigma_x_squared)

# Analytical results
post_mu_analytical, _, pred_mu_analytical, _ = update_posterior(
    mu_0, sigma_0_squared, x_values, sigma_x_squared, x_min, x_max
)
post_var_analytical = 1 / (1/sigma_0_squared + N/sigma_x_squared)
post_std_analytical = np.sqrt(post_var_analytical)
pred_var_analytical = post_var_analytical + sigma_x_squared
pred_std_analytical = np.sqrt(pred_var_analytical)

print(f"Analytical results (N={N}, σ²_x={sigma_x_squared}):")
print(f"  Posterior: N({post_mu_analytical:.3f}, {post_var_analytical:.3f})")
print(f"  Predictive: N({pred_mu_analytical:.3f}, {pred_var_analytical:.3f})")
print()

# GenJAX simulation
key = random.PRNGKey(42)
n_samples = 5000

predictions = []
for _ in range(n_samples):
    key, subkey = random.split(key)
    trace = posterior_predictive.simulate(subkey, (post_mu_analytical, post_std_analytical, sigma_x))
    predictions.append(float(trace.get_retval()))

predictions = np.array(predictions)

print(f"GenJAX simulation results ({n_samples} samples):")
print(f"  Predictive mean: {np.mean(predictions):.3f} (analytical: {pred_mu_analytical:.3f})")
print(f"  Predictive std: {np.std(predictions):.3f} (analytical: {pred_std_analytical:.3f})")
print()

# Plot comparison
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Histogram of samples
ax.hist(predictions, bins=50, density=True, alpha=0.6, color='skyblue', 
        edgecolor='black', label='GenJAX samples')

# Analytical PDF
x_plot = np.linspace(-2, 6, 1000)
analytical_pdf = norm.pdf(x_plot, pred_mu_analytical, pred_std_analytical)
ax.plot(x_plot, analytical_pdf, 'r-', linewidth=2, 
        label=f'Analytical: N({pred_mu_analytical:.2f}, {pred_var_analytical:.2f})')

ax.axvline(np.mean(predictions), color='blue', linestyle='--', linewidth=1.5, 
          label=f'Sample mean: {np.mean(predictions):.2f}')
ax.axvline(pred_mu_analytical, color='red', linestyle='--', linewidth=1.5, alpha=0.5,
          label=f'Analytical mean: {pred_mu_analytical:.2f}')

ax.set_xlabel('x (predicted observation)', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.set_title('GenJAX Posterior Predictive vs Analytical', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\n✅ GenJAX simulation matches analytical results!")

---

## Summary

### Key Insights from Problem 1:

1. **Effect of Likelihood Variance ($\sigma_x^2$)**:
   - Smaller variance → data more influential → posterior closer to data
   - Larger variance → prior more influential → posterior closer to prior
   - This is captured by the **precision-weighted average** formula

2. **Effect of Number of Observations (N)**:
   - More observations → higher effective data precision ($N/\sigma_x^2$)
   - Posterior becomes more concentrated (smaller variance)
   - Posterior mean converges to sample mean as N → ∞

3. **Predictive Distribution**:
   - Always more dispersed than posterior (adds $\sigma_x^2$)
   - Accounts for both parameter uncertainty AND data variability
   - Mean is same as posterior mean

4. **Precision Interpretation**:
   - Prior precision: $1/\sigma_0^2$
   - Data precision: $N/\sigma_x^2$
   - Posterior precision: sum of the two
   - Higher precision = more certainty
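
The convergence of the posterior mean to the sample mean can be illustrated by holding the sample mean fixed and letting the number of observations grow. This is a sketch under stated assumptions: `posterior_mean` is a hypothetical helper defined here, and the larger values of `n` are invented for illustration rather than taken from the problem.

```python
def posterior_mean(n, sample_mean, noise_var, prior_mean=0.0, prior_var=1.0):
    """Posterior mean of mu after n observations with the given sample mean."""
    prior_precision = 1.0 / prior_var
    data_precision = n / noise_var
    return (prior_precision * prior_mean + data_precision * sample_mean) / (
        prior_precision + data_precision
    )

# Hold the sample mean at 2.0 and the noise variance at 4 while n grows
for n in (1, 5, 50, 500, 5000):
    print(f"n = {n:>4}: posterior mean = {posterior_mean(n, 2.0, 4.0):.4f}")
```

Even with noisy data ($\sigma_x^2 = 4$), the prior's pull toward 0 fades as the data precision $N/\sigma_x^2$ comes to dominate the prior precision.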

### Mathematical Framework:

**Posterior**:
$$\mu_N = \frac{\frac{1}{\sigma_0^2} \mu_0 + \frac{N}{\sigma_x^2} \bar{x}}{\frac{1}{\sigma_0^2} + \frac{N}{\sigma_x^2}}, \quad \sigma_N^2 = \frac{1}{\frac{1}{\sigma_0^2} + \frac{N}{\sigma_x^2}}$$

**Predictive**:
$$x_{N+1} \sim \mathcal{N}(\mu_N, \sigma_N^2 + \sigma_x^2)$$

This elegant framework allows us to:
- Update beliefs sequentially as data arrives
- Balance prior knowledge with observed data
- Quantify uncertainty about future observations