# Solution: Gaussian 2-Clusters (Problem 2)

## Problem Setup

We investigate how to make **categorization decisions** for two categories, where each is defined as a Gaussian distribution.

**Generative Process:**

Data are generated by:
1. First picking category $c \in \{1, 2\}$ according to prior probability $\theta$
2. Then generating datum $x$ from the corresponding category's likelihood

$$c_n | \theta \sim \text{Bernoulli}(\theta)$$
$$x_n | \mu_{c(n)}, \sigma_{c(n)}^2 \overset{iid}{\sim} \mathcal{N}(\mu_{c(n)}, \sigma_{c(n)}^2)$$

where:
- $P(c_n = 1) = \theta$ (prior probability of category 1)
- $P(c_n = 2) = 1 - \theta$ (prior probability of category 2)

For all problems, assume $\mu_1 = -1$ and $\mu_2 = 1$.
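
The two-step generative process above can be sketched in plain NumPy. This is a minimal illustration, not the notebook's GenJAX model: the function name `sample_two_gaussians`, its default arguments, and the seed are assumptions made for this example.

```python
import numpy as np

def sample_two_gaussians(n, theta, mu=(-1.0, 1.0), var=(1.0, 1.0), seed=0):
    """Draw n (category, datum) pairs from the two-step generative process."""
    rng = np.random.default_rng(seed)
    # Step 1: category 1 with probability theta, otherwise category 2
    c = np.where(rng.random(n) < theta, 1, 2)
    # Step 2: draw x from the Gaussian that belongs to the sampled category
    mu_c = np.where(c == 1, mu[0], mu[1])
    sd_c = np.sqrt(np.where(c == 1, var[0], var[1]))
    x = rng.normal(mu_c, sd_c)
    return c, x

c, x = sample_two_gaussians(10_000, theta=0.75)
print(f"fraction in category 1: {np.mean(c == 1):.3f}")  # theory: 0.75
print(f"sample mean of x: {x.mean():.3f}")               # theory: 0.75*(-1) + 0.25*1 = -0.5
```

The printed values should land close to the theoretical ones, with some sampling noise.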

---

## 📚 Reviewing Key Concepts

This problem connects several important concepts from earlier chapters:

**From [Tutorial 1, Chapter 3 - Probability as Counting](../../content/intro/03_prob_count.md)**:
- **Random variables**: Category c is a discrete random variable
- **Weighted counting**: When probabilities aren't equal (θ ≠ 0.5)

**From [Tutorial 1, Chapter 4 - Conditional Probability](../../content/intro/04_conditional.md)**:
- **Conditional probability** P(c|x): What category given observed x?
- **Marginal vs Joint**: We'll compute marginal p(x) by summing over categories
- **Law of Total Probability**: p(x) = Σ_c p(x|c)P(c)

**From [Tutorial 1, Chapter 5 - Bayes' Theorem](../../content/intro/05_bayes.md)**:
- **Bayes' rule**: P(H|E) = P(E|H)P(H) / P(E)
- Here: P(c|x) = p(x|c)P(c) / p(x)
- **Critical concept**: Prior belief updated by evidence

**From [Tutorial 2, Chapter 1 - Mystery Bentos](../../content/intro2/01_mystery_bentos.md)**:
- **Discrete mixtures**: Chibany's 70% tonkatsu, 30% hamburger
- **Expected value**: E[X] = θ·value₁ + (1-θ)·value₂
- **Now extending with continuous distributions!**

**From [Tutorial 2, Chapter 3 - Gaussian Distribution](../../content/intro2/03_gaussian.md)**:
- **Gaussian PDF**: N(x; μ, σ²) describes the bell curve
- **Properties**: Mean μ centers the distribution, variance σ² controls spread
- Each category has its own Gaussian distribution

**What's new in this assignment:**
- **Latent (hidden) variables**: Category c is unknown, must infer from x
- **Categorization**: Using Bayes' rule to classify observations
- **Mixture distributions**: Combining multiple Gaussians with weights
- **Decision boundaries**: Where does P(c=1|x) = 0.5?

In [None]:
# Import packages
import jax
import jax.numpy as jnp
import jax.random as random
import jax.lax as lax
from genjax import gen, bernoulli, normal
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Configure matplotlib
plt.style.use('seaborn-v0_8-whitegrid')
%matplotlib inline

# Set random seed
np.random.seed(42)
key = random.PRNGKey(42)

## Helper Functions

In [None]:
def update_datum_c1(mu_1, mu_2, sigma_1_squared, sigma_2_squared, theta, x_min, x_max):
    """
    Compute posterior probability of category 1 and marginal distribution.
    
    Args:
        mu_1, mu_2: Means of categories 1 and 2
        sigma_1_squared, sigma_2_squared: Variances of categories 1 and 2
        theta: Prior probability of category 1
        x_min, x_max: Range for plotting
    
    Returns:
        posterior_c1: P(c=1|x) over x_range
        marginal: p(x) over x_range
    """
    # Define x range
    x_range = np.linspace(x_min, x_max, 1000)
    
    # Compute likelihoods
    likelihood_1 = norm.pdf(x_range, mu_1, np.sqrt(sigma_1_squared))
    likelihood_2 = norm.pdf(x_range, mu_2, np.sqrt(sigma_2_squared))
    
    # Compute posterior probabilities using Bayes' rule
    posterior_c1 = (theta * likelihood_1) / (
        theta * likelihood_1 + (1 - theta) * likelihood_2
    )
    
    # Compute marginal distribution (mixture)
    marginal = theta * likelihood_1 + (1 - theta) * likelihood_2
    
    return posterior_c1, marginal

---

## Problem 2(a): Derivation - Categorization

Using **Bayes' rule**, derive the probability of a single datum being in category 1: $P(c_1=1|x_1)$.

### Solution

Using Bayes' rule:

$$P(c_1 = 1 | x_1) = \frac{P(x_1 | c_1 = 1) P(c_1 = 1)}{P(x_1)}$$

**Step 1: Prior probability**
$$P(c_1 = 1) = \theta$$

**Step 2: Likelihood**
$$P(x_1 | c_1 = 1) = \mathcal{N}(x_1; \mu_1, \sigma_1^2)$$

where:
$$\mathcal{N}(x; \mu, \sigma^2) = \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)$$

**Step 3: Marginal likelihood** (Law of Total Probability)

$$P(x_1) = \sum_{c} P(x_1 | c) P(c)$$
$$= P(x_1 | c_1 = 1) P(c_1 = 1) + P(x_1 | c_1 = 2) P(c_1 = 2)$$
$$= \theta \, \mathcal{N}(x_1; \mu_1, \sigma_1^2) + (1-\theta) \, \mathcal{N}(x_1; \mu_2, \sigma_2^2)$$

**Step 4: Posterior probability**

Substituting into Bayes' rule:

$$\boxed{P(c_1=1|x_1) = \frac{\theta \, \mathcal{N}(x_1; \mu_1, \sigma_1^2)}{\theta \, \mathcal{N}(x_1; \mu_1, \sigma_1^2) + (1-\theta) \, \mathcal{N}(x_1; \mu_2, \sigma_2^2)}}$$

### Interpretation

The posterior probability $P(c=1|x)$ is **category 1's share of the prior-weighted likelihood**.

- Numerator: How likely is $x$ under category 1, weighted by prior belief in category 1
- Denominator: Total likelihood of $x$ across both categories (marginal)
- This normalizes the weighted likelihood to be a valid probability (sums to 1 over categories)
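
The boxed formula can also be evaluated through the log posterior odds, which avoids a possible `0/0` when both likelihoods underflow far from the means. The sketch below is an alternative to the direct ratio, not the notebook's helper: the name `posterior_c1` and its defaults are assumptions for illustration.

```python
import numpy as np
from scipy.stats import norm
from scipy.special import expit  # logistic sigmoid

def posterior_c1(x, theta, mu_1=-1.0, mu_2=1.0, var_1=1.0, var_2=1.0):
    """P(c=1 | x) computed as a sigmoid of the log posterior odds."""
    log_odds = (
        np.log(theta) - np.log1p(-theta)            # log prior odds
        + norm.logpdf(x, mu_1, np.sqrt(var_1))      # log-likelihood under c = 1
        - norm.logpdf(x, mu_2, np.sqrt(var_2))      # log-likelihood under c = 2
    )
    return expit(log_odds)

# At x = 0 with equal variances the two likelihoods are equal,
# so the posterior reduces to the prior theta.
print(posterior_c1(0.0, theta=0.5))
print(posterior_c1(0.0, theta=0.75))
# Far from both means the result stays finite (tiny, but not NaN).
print(posterior_c1(40.0, theta=0.5))
```

Algebraically this is the same quantity as the boxed expression: dividing numerator and denominator by the category-2 term gives $1 / (1 + e^{-\text{log odds}})$.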

---

## Problem 2(b): Categorization

Calculate and plot the probability of being in category 1 for:
1. $\theta = 0.5$ and $\theta = 0.75$ with $\sigma_1^2 = \sigma_2^2 = 1$
2. $\theta = 0.5$ when $\sigma_1^2 = 0.5$ and $\sigma_2^2 = 2$
3. $\theta = 0.75$ when $\sigma_1^2 = 0.5$ and $\sigma_2^2 = 2$

**Question**: Describe the effect of changing the prior and the variance on categorization decisions.

In [None]:
# Fixed means
mu_1 = -1.0
mu_2 = 1.0

# Configuration sets
theta_values = [0.5, 0.75, 0.25]
configs = [
    (1, 1),      # σ₁² = σ₂² = 1 (equal variances)
    (0.5, 2),    # σ₁² = 0.5, σ₂² = 2 (c1 more precise)
    (2, 0.5)     # σ₁² = 2, σ₂² = 0.5 (c2 more precise)
]

# Colors and linestyles
colors = {0.5: "blue", 0.75: "orange", 0.25: "#980025"}
colors_sigma = {1: "steelblue", 0.5: "red", 2: "green"}
linestyles = {1: "-", 0.5: "--", 2: ":"}

# Plot range
x_min = -6
x_max = 6
x_range = np.linspace(x_min, x_max, 1000)

# Create figure
fig, axes = plt.subplots(2, 1, figsize=(12, 10))

# Plot 1: Posterior P(c=1|x)
for j, (sigma_1_sq, sigma_2_sq) in enumerate(configs):
    for theta in theta_values:
        posterior_c1, _ = update_datum_c1(mu_1, mu_2, sigma_1_sq, sigma_2_sq, theta, x_min, x_max)
        
        axes[0].plot(
            x_range,
            posterior_c1,
            label=f"θ={theta}, σ²₁={sigma_1_sq}, σ²₂={sigma_2_sq}",
            color=colors[theta],
            linestyle=linestyles[sigma_1_sq],
            linewidth=2
        )

# Decision boundary
axes[0].axhline(0.5, color='black', linestyle=':', linewidth=1, alpha=0.5, label='Decision boundary (0.5)')
axes[0].axvline(0, color='gray', linestyle='--', linewidth=1, alpha=0.3)
axes[0].set_title('Posterior Probability $P(c=1|x)$', fontsize=14)
axes[0].set_xlabel('$x$', fontsize=12)
axes[0].set_ylabel('Probability', fontsize=12)
axes[0].set_ylim([-0.05, 1.05])
axes[0].legend(fontsize=9, ncol=2)
axes[0].grid(True, alpha=0.3)

# Plot 2: Likelihoods
for sigma_1_sq, sigma_2_sq in configs:
    lik_1 = norm.pdf(x_range, mu_1, np.sqrt(sigma_1_sq))
    lik_2 = norm.pdf(x_range, mu_2, np.sqrt(sigma_2_sq))
    
    axes[1].plot(
        x_range,
        lik_1,
        label=f"L(c=1|x), σ²₁={sigma_1_sq}, σ²₂={sigma_2_sq}",
        color=colors_sigma[sigma_1_sq],
        linestyle="-",
        linewidth=2
    )
    
    axes[1].plot(
        x_range,
        lik_2,
        label=f"L(c=2|x), σ²₁={sigma_1_sq}, σ²₂={sigma_2_sq}",
        color=colors_sigma[sigma_1_sq],
        linestyle="--",
        linewidth=2
    )

axes[1].axvline(mu_1, color='blue', linestyle=':', linewidth=1, alpha=0.5)
axes[1].axvline(mu_2, color='red', linestyle=':', linewidth=1, alpha=0.5)
axes[1].set_title('Likelihood Distributions', fontsize=14)
axes[1].set_xlabel('$x$', fontsize=12)
axes[1].set_ylabel('Density', fontsize=12)
axes[1].legend(fontsize=9, ncol=2)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Interpretation

**Effect of Prior ($\theta$) on Categorization:**

1. **Equal prior** ($\theta = 0.5$):
   - When variances are equal, decision boundary is at $x = 0$ (midpoint between $\mu_1 = -1$ and $\mu_2 = 1$)
   - The posterior probability depends purely on the relative likelihood values
   - Symmetric behavior around midpoint

2. **Unequal prior** ($\theta = 0.75$ favors c=1):
   - Decision boundary **shifts toward category 2** (the less favored category)
   - This makes sense: we need stronger evidence (higher likelihood) to overcome prior belief
   - The "75% line" crosses 0.5 at a positive $x$ value

3. **Prior acts as a "magnifier"**:
   - It doesn't change the shape of likelihood distributions
   - It amplifies or diminishes the contribution of each component to posterior
   - Higher $\theta$ → category 1 gets more weight in Bayes' rule calculation
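
For equal variances ($\sigma_1^2 = \sigma_2^2 = \sigma^2$) the size of the shift can be worked out directly. The boundary $x^*$ is where the two prior-weighted likelihoods are equal (the symbol $x^*$ is introduced here for convenience):

$$\theta \, \mathcal{N}(x^*; \mu_1, \sigma^2) = (1-\theta) \, \mathcal{N}(x^*; \mu_2, \sigma^2)$$

Taking logs, the normalizing constants cancel:

$$\ln\frac{\theta}{1-\theta} = \frac{(x^* - \mu_1)^2 - (x^* - \mu_2)^2}{2\sigma^2} = \frac{(\mu_2 - \mu_1)\,(2x^* - \mu_1 - \mu_2)}{2\sigma^2}$$

Solving for $x^*$:

$$x^* = \frac{\mu_1 + \mu_2}{2} + \frac{\sigma^2}{\mu_2 - \mu_1} \ln\frac{\theta}{1-\theta}$$

With $\mu_1 = -1$, $\mu_2 = 1$, $\sigma^2 = 1$ this is $x^* = \tfrac{1}{2}\ln\frac{\theta}{1-\theta}$: $0$ for $\theta = 0.5$, about $0.549$ for $\theta = 0.75$, and about $-0.549$ for $\theta = 0.25$.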

**Effect of Likelihood Variance on Categorization:**

1. **Equal variances** ($\sigma_1^2 = \sigma_2^2 = 1$):
   - Decision boundary determined by distance from means and prior
   - Posterior transitions smoothly between 0 and 1
   - In the tails the closer mean wins: $P(c=1|x) \to 1$ as $x \to -\infty$ and $P(c=1|x) \to 0$ as $x \to +\infty$

2. **Unequal variances** ($\sigma_1^2 \neq \sigma_2^2$):
   - **Component with larger variance dominates the tails!**
   - Why? Larger variance → flatter distribution → longer tails
   - Creates asymmetric categorization behavior: the posterior can cross 0.5 **twice**, so the lower-variance category wins only on an interval around its mean
   
3. **"Needle effect"**:
   - Component with smaller variance has higher peak (taller, narrower)
   - Near its mean, it can "punch through" and dominate categorization
   - This creates sharp transitions in posterior probability

**Relative degree of dispersion** (variance ratio) determines:
- Shape of posterior probability curve
- Whether categorization is balanced or dominated by one component
- Location and sharpness of decision boundary

**Combined Effects:**

When **both** prior and variances are unequal:
- Prior shifts the baseline favoritism
- Variance ratio determines how sharply the posterior transitions
- The "needle" can be enhanced (same direction as prior) or diminished (opposite direction)

**Decision Boundary Analysis:**
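- Equal prior + equal variance → boundary at midpoint (0)
- Unequal prior + equal variance → boundary shifts away from favored category
- Equal prior + unequal variance → two boundaries: the lower-variance category wins on an interval around its mean, the higher-variance category wins in both tails
- Unequal prior + unequal variance → the prior widens or narrows that interval (it can vanish if the prior strongly favors the higher-variance category)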
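
The boundaries can also be computed in closed form instead of being read off a grid. Setting the log posterior odds to zero gives a quadratic in $x$ (linear when the variances are equal). The sketch below is an illustration: the function name `decision_boundaries` is an assumption, and it does not handle the degenerate case of equal means with equal variances.

```python
import numpy as np

def decision_boundaries(theta, mu_1, mu_2, var_1, var_2):
    """Real x values where P(c=1|x) = 0.5, i.e. roots of a*x^2 + b*x + c = 0."""
    # Coefficients of the log posterior odds written as a polynomial in x
    a = 0.5 * (1.0 / var_2 - 1.0 / var_1)
    b = mu_1 / var_1 - mu_2 / var_2
    c = (np.log(theta / (1.0 - theta)) + 0.5 * np.log(var_2 / var_1)
         - mu_1**2 / (2.0 * var_1) + mu_2**2 / (2.0 * var_2))
    if np.isclose(a, 0.0):
        # Equal variances: the log odds are linear, so there is one boundary
        return np.array([-c / b])
    disc = b**2 - 4.0 * a * c
    if disc < 0:
        return np.array([])  # no crossing: one category wins for every x
    roots = (-b + np.array([-1.0, 1.0]) * np.sqrt(disc)) / (2.0 * a)
    return np.sort(roots)

for theta, var_1, var_2 in [(0.5, 1, 1), (0.75, 1, 1), (0.5, 0.5, 2), (0.75, 0.5, 2)]:
    print(theta, var_1, var_2, decision_boundaries(theta, -1.0, 1.0, var_1, var_2))
```

For $\theta = 0.5$, $\sigma_1^2 = 0.5$, $\sigma_2^2 = 2$ the roots are approximately $-3.31$ and $-0.02$: category 1 is the more probable category only between them.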

In [None]:
# Find and print decision boundaries
print("\n📊 Decision Boundary Analysis:\n")
print(f"{'Configuration':<40} {'Decision Boundary (x where P(c=1|x)=0.5)':<45}")
print("-" * 85)

for sigma_1_sq, sigma_2_sq in configs:
    for theta in theta_values:
        posterior_c1, _ = update_datum_c1(mu_1, mu_2, sigma_1_sq, sigma_2_sq, theta, x_min, x_max)
        
        # Find every sign change of P(c=1|x) - 0.5 on the grid.
        # Unequal variances can give two crossings, which a single argmin would miss.
        diff = posterior_c1 - 0.5
        cross_idx = np.where(np.sign(diff[:-1]) != np.sign(diff[1:]))[0]
        # Report the midpoint of each grid interval that contains a crossing
        crossings = 0.5 * (x_range[cross_idx] + x_range[cross_idx + 1])
        boundary_str = ", ".join(f"{b:.3f}" for b in crossings) if len(crossings) else "none in range"
        
        config_str = f"θ={theta}, σ²₁={sigma_1_sq}, σ²₂={sigma_2_sq}"
        print(f"{config_str:<40} {boundary_str:<45}")

print("\n" + "="*85)
print("\n✅ Key Observations:")
print("  • θ=0.5, equal variances → boundary at 0.0 (midpoint)")
print("  • θ>0.5, equal variances → boundary shifts RIGHT (toward less favored category)")
print("  • θ<0.5, equal variances → boundary shifts LEFT (toward less favored category)")
print("  • Unequal variances → two boundaries; the wider category wins in both tails")
print("  • Combined effects can reinforce or counteract each other")

---

## Problem 2(c): Derivation - Prediction

Using Bayes' rule and the **Law of Total Probability**, derive the probability of a data point $p(x)$ according to this model (without any given data).

### Solution

Using the **Law of Total Probability**:

$$p(x_1) = \sum_{c} p(x_1 | c) \, P(c)$$

**📖 Recall from [Tutorial 1, Chapter 4](../../content/intro/04_conditional.md):**

The **Law of Total Probability** (also called the **sum rule** or **marginalization**):
- To find P(A), sum over all possibilities of another variable B
- $P(A) = \sum_b P(A, B=b) = \sum_b P(A|B=b) \cdot P(B=b)$
- This "marginalizes out" the variable B

**Step 1: Expand for two categories**

$$p(x_1) = p(x_1 | c_1 = 1) P(c_1 = 1) + p(x_1 | c_1 = 2) P(c_1 = 2)$$

**Step 2: Substitute terms**

Prior probabilities:
$$P(c_1 = 1) = \theta, \quad P(c_1 = 2) = 1 - \theta$$

Likelihoods:
$$p(x_1 | c_1 = 1) = \mathcal{N}(x_1; \mu_1, \sigma_1^2)$$
$$p(x_1 | c_1 = 2) = \mathcal{N}(x_1; \mu_2, \sigma_2^2)$$

**Step 3: Final result**

$$\boxed{p(x_1) = \theta \, \mathcal{N}(x_1; \mu_1, \sigma_1^2) + (1-\theta) \, \mathcal{N}(x_1; \mu_2, \sigma_2^2)}$$

### Interpretation

The marginal distribution $p(x)$ is a **weighted mixture** of the two component Gaussians:
- Weights are the prior probabilities ($\theta$ and $1-\theta$)
- Each component contributes according to its weight
- This is called a **Gaussian Mixture Model** (GMM)

**Connection to [Tutorial 2, Chapter 1](../../content/intro2/01_mystery_bentos.md)**:
- Remember Chibany's discrete mixture: 70% × 500g + 30% × 350g = 455g
- Here we have the **continuous analog**: weighted sum of PDFs, not values!
- Instead of discrete outcomes, we have continuous distributions

The marginal can have:
- **Two peaks** (bimodal) if components are well-separated
- **One peak** (unimodal) if components overlap heavily
- **Plateau** if components have similar height and overlap
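
As a quick sanity check, the mixture should integrate to 1 and have mean $\theta\mu_1 + (1-\theta)\mu_2$. The sketch below verifies both with a simple Riemann sum; the name `mixture_pdf`, the grid limits, and the chosen parameters are assumptions for illustration.

```python
import numpy as np
from scipy.stats import norm

def mixture_pdf(x, theta, mu_1=-1.0, mu_2=1.0, var_1=1.0, var_2=1.0):
    """Marginal p(x): prior-weighted sum of the two component densities."""
    return (theta * norm.pdf(x, mu_1, np.sqrt(var_1))
            + (1.0 - theta) * norm.pdf(x, mu_2, np.sqrt(var_2)))

# A wide, fine grid so that the tails outside it are negligible
x = np.linspace(-15.0, 15.0, 30_001)
dx = x[1] - x[0]
p = mixture_pdf(x, theta=0.75, var_1=0.5, var_2=2.0)

total_mass = np.sum(p) * dx   # should be close to 1
mean_x = np.sum(x * p) * dx   # should be close to 0.75*(-1) + 0.25*1 = -0.5
print(f"integral of p(x): {total_mass:.6f}")
print(f"mean of x:        {mean_x:.6f}")
```

Note that the mean depends only on $\theta$ and the component means, not on the variances.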

---

## Problem 2(d): Prediction

Plot $p(x_1)$ for:
1. $\theta = 0.5$ and $\theta = 0.75$ with $\sigma_1^2 = \sigma_2^2 = 1$
2. $\theta = 0.5$ when $\sigma_1^2 = 0.5$ and $\sigma_2^2 = 2$
3. $\theta = 0.75$ when $\sigma_1^2 = 0.5$ and $\sigma_2^2 = 2$

**Question**: How do the prior and variance affect $p(x_1)$?

In [None]:
# Configuration sets (matching part b)
theta_values = [0.5, 0.75]
configs = [(1, 1), (0.5, 2)]

# Colors
colors = {0.5: "blue", 0.75: "orange"}

# Create figure
fig, axes = plt.subplots(2, 1, figsize=(12, 10))

# Plot range
x_min = -6
x_max = 6
x_range = np.linspace(x_min, x_max, 1000)

for j, (sigma_1_sq, sigma_2_sq) in enumerate(configs):
    for theta in theta_values:
        posterior_c1, marginal = update_datum_c1(mu_1, mu_2, sigma_1_sq, sigma_2_sq, theta, x_min, x_max)
        
        # Plot posterior (for comparison)
        axes[0].plot(
            x_range,
            posterior_c1,
            label=f"P(c=1|x), θ={theta}, σ²₁={sigma_1_sq}, σ²₂={sigma_2_sq}",
            color=colors[theta],
            linestyle="-" if sigma_1_sq == 1 else "--",
            linewidth=2
        )
        
        # Plot marginal (predictive)
        axes[1].plot(
            x_range,
            marginal,
            label=f"p(x), θ={theta}, σ²₁={sigma_1_sq}, σ²₂={sigma_2_sq}",
            color=colors[theta],
            linestyle="-" if sigma_1_sq == 1 else "--",
            linewidth=2
        )

# Configure posterior plot
axes[0].axhline(0.5, color='black', linestyle=':', linewidth=1, alpha=0.3)
axes[0].set_title('Posterior Probability $P(c=1|x)$', fontsize=14)
axes[0].set_xlabel('$x$', fontsize=12)
axes[0].set_ylabel('Probability', fontsize=12)
axes[0].set_ylim([-0.05, 1.05])
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Configure marginal plot
axes[1].axvline(mu_1, color='blue', linestyle=':', linewidth=1, alpha=0.3, label='μ₁=-1')
axes[1].axvline(mu_2, color='red', linestyle=':', linewidth=1, alpha=0.3, label='μ₂=1')
axes[1].set_title('Marginal Distribution $p(x)$ (Mixture)', fontsize=14)
axes[1].set_xlabel('$x$', fontsize=12)
axes[1].set_ylabel('Density', fontsize=12)
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Interpretation

**Effect of Prior ($\theta$) on Marginal Distribution:**

The marginal distribution is:
$$p(x) = \theta \, \mathcal{N}(x; \mu_1, \sigma_1^2) + (1-\theta) \, \mathcal{N}(x; \mu_2, \sigma_2^2)$$

1. **Prior as a "magnifier"**:
   - $\theta$ controls how much each component contributes to the mixture
   - Higher $\theta$ → component 1 gets larger weight → higher peak near $\mu_1$
   - Lower $\theta$ → component 2 gets larger weight → higher peak near $\mu_2$

2. **Envelope constraint**:
   - At every $x$, the marginal lies between the two component densities, because it is a weighted average of them
   - Extreme case: $\theta = 1$ → $p(x) = \mathcal{N}(x; \mu_1, \sigma_1^2)$ (pure component 1)
   - Extreme case: $\theta = 0$ → $p(x) = \mathcal{N}(x; \mu_2, \sigma_2^2)$ (pure component 2)

3. **Shape modulation**:
   - Prior doesn't change the **shape** of individual components
   - It only changes their **relative heights** in the mixture
   - The mixture shape depends on overlap and relative weights

**Effect of Variance on Marginal Distribution:**

1. **Equal variances** ($\sigma_1^2 = \sigma_2^2 = 1$):
   - With $\theta = 0.5$: Creates a **plateau** (flat top merged from two diminished peaks), because the means are exactly $2\sigma$ apart
   - With $\theta \neq 0.5$: Asymmetric mixture, with the peak pulled toward the heavier component
   - Smooth transition between component regions

2. **Unequal variances** ($\sigma_1^2 = 0.5, \sigma_2^2 = 2$):
   - Component with **smaller variance** creates a sharp spike (high, narrow peak)
   - Component with **larger variance** creates a wide bump (low, wide peak)
   - With the means only 2 apart, the wide bump appears as a **shoulder** to the right of the spike rather than as a separate second peak

3. **Combined effect**:
   - Equal prior + equal variance → plateau or symmetric bimodal
   - Unequal prior + equal variance → asymmetric mixture, one side heavier
   - Equal prior + unequal variance → asymmetric mixture, narrow peak taller
   - Unequal prior + unequal variance → complex asymmetry, depends on alignment

**Modality Analysis:**

The number of peaks (modes) depends on:
- **Separation** of means ($|\mu_1 - \mu_2|$)
- **Variance ratio** ($\sigma_1^2 / \sigma_2^2$)
- **Mixing weights** ($\theta$ vs $1-\theta$)

In our examples:
- Blue solid (θ=0.5, equal var): **Unimodal plateau** (components merge)
- Blue dashed (θ=0.5, unequal var): **Unimodal with a shoulder** (narrow peak near μ₁, wide component forms a shoulder on the right)
- Orange solid (θ=0.75, equal var): **Asymmetric unimodal** (one side heavier)
- Orange dashed (θ=0.75, unequal var): **Asymmetric unimodal** (taller narrow peak, fainter shoulder)
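
The plateau in the equal-prior, equal-variance case can be checked with a short calculation. As a sketch, write the means as $\pm m$ and the common variance as $\sigma^2$ (symbols introduced here for convenience):

$$p(x) = \tfrac{1}{2}\,\mathcal{N}(x; -m, \sigma^2) + \tfrac{1}{2}\,\mathcal{N}(x; m, \sigma^2)$$

By symmetry $p'(0) = 0$. Using $\frac{d^2}{dx^2}\mathcal{N}(x; \mu, \sigma^2) = \mathcal{N}(x; \mu, \sigma^2)\left[\frac{(x-\mu)^2}{\sigma^4} - \frac{1}{\sigma^2}\right]$, the curvature at the midpoint is

$$p''(0) = \mathcal{N}(0; m, \sigma^2)\,\frac{m^2 - \sigma^2}{\sigma^4}$$

So the midpoint is a maximum (one peak) when $m < \sigma$, a local minimum between two peaks when $m > \sigma$, and flat to second order when $m = \sigma$. Our setting has $m = 1$ and $\sigma = 1$, which is exactly the borderline case and explains the flat top.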

In [None]:
# Analyze modality
from scipy.signal import find_peaks

print("\n📊 Marginal Distribution Analysis:\n")
print(f"{'Configuration':<40} {'Modality':<20} {'Peak Location(s)':<30}")
print("-" * 90)

for sigma_1_sq, sigma_2_sq in configs:
    for theta in theta_values:
        _, marginal = update_datum_c1(mu_1, mu_2, sigma_1_sq, sigma_2_sq, theta, x_min, x_max)
        
        # Find peaks (local maxima) above 0.01 and at least 50 grid points apart
        peaks, _ = find_peaks(marginal, height=0.01, distance=50)
        peak_locations = x_range[peaks]
        
        # Determine modality
        if len(peaks) == 0:
            modality = "No clear peak"
        elif len(peaks) == 1:
            modality = "Unimodal"
        elif len(peaks) == 2:
            modality = "Bimodal"
        else:
            modality = f"{len(peaks)} peaks"
        
        peak_str = ", ".join([f"{p:.2f}" for p in peak_locations]) if len(peaks) > 0 else "N/A"
        
        config_str = f"θ={theta}, σ²₁={sigma_1_sq}, σ²₂={sigma_2_sq}"
        print(f"{config_str:<40} {modality:<20} {peak_str:<30}")

print("\n" + "="*90)
print("\n✅ Conclusions:")
print("  • Prior θ acts as weight/magnifier for each component")
print("  • Equal variances + equal prior → plateau (merged peaks)")
print("  • Unequal variances → narrow peak plus a shoulder from the wide component (compare with the table above)")
print("  • Smaller variance component has taller, sharper peak")
print("  • Prior doesn't change component shapes, only their relative heights")

---

## GenJAX Implementation: Mixture Model

Let's implement and simulate from the Gaussian mixture model using GenJAX!

In [None]:
@gen
def gaussian_mixture_model(theta, mu_1, mu_2, sigma_1, sigma_2):
    """
    GenJAX generative model for Gaussian mixture.
    
    Args:
        theta: Prior probability of category 1
        mu_1, mu_2: Means of categories 1 and 2
        sigma_1, sigma_2: Standard deviations of categories 1 and 2
    """
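    # Sample category (Bernoulli with parameter theta)
    # NOTE (assumption to verify): some GenJAX releases parameterize `bernoulli`
    # by logits and provide `flip(p)` for a probability. If yours does, use
    # `flip(theta)` here so that P(c = 1) = theta.
    c = bernoulli(theta) @ "category"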
    
    # Sample from both distributions (GenJAX needs @ operator at top level)
    x1 = normal(mu_1, sigma_1) @ "observation_1"
    x2 = normal(mu_2, sigma_2) @ "observation_2"
    
    # Select which observation to use based on category
    x = lax.cond(c == 1, lambda _: x1, lambda _: x2, None)
    
    return x, c

# Simulate from mixture model
print("🔬 GenJAX Simulation: Gaussian Mixture Model\n")

# Parameters
theta = 0.7
sigma_1 = 1.0
sigma_2 = 1.0
n_samples = 2000

print(f"Parameters: θ={theta}, μ₁={mu_1}, μ₂={mu_2}, σ₁={sigma_1}, σ₂={sigma_2}")
print(f"Generating {n_samples} samples...\n")

# Generate samples
key = random.PRNGKey(42)
observations = []
categories = []

for _ in range(n_samples):
    key, subkey = random.split(key)
    trace = gaussian_mixture_model.simulate(subkey, (theta, mu_1, mu_2, sigma_1, sigma_2))
    x, c = trace.get_retval()
    observations.append(float(x))
    categories.append(int(c))

observations = np.array(observations)
categories = np.array(categories)

# Separate by category
obs_c1 = observations[categories == 1]
obs_c2 = observations[categories == 0]

print(f"Results:")
print(f"  Category 1: {len(obs_c1)} samples ({len(obs_c1)/n_samples*100:.1f}%) [expected: {theta*100:.1f}%]")
print(f"  Category 2: {len(obs_c2)} samples ({len(obs_c2)/n_samples*100:.1f}%) [expected: {(1-theta)*100:.1f}%]")
print(f"  Overall mean: {np.mean(observations):.2f} [expected: {theta*mu_1 + (1-theta)*mu_2:.2f}]")
print(f"  Category 1 mean: {np.mean(obs_c1):.2f} [expected: {mu_1:.2f}]")
print(f"  Category 2 mean: {np.mean(obs_c2):.2f} [expected: {mu_2:.2f}]")

In [None]:
# Visualize GenJAX simulation results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: All observations (marginal distribution)
axes[0].hist(observations, bins=60, density=True, alpha=0.7, color='purple', 
            edgecolor='black', label='Simulated data')

# Overlay theoretical marginal
x_plot = np.linspace(-5, 5, 1000)
theoretical_marginal = theta * norm.pdf(x_plot, mu_1, sigma_1) + \
                      (1-theta) * norm.pdf(x_plot, mu_2, sigma_2)
axes[0].plot(x_plot, theoretical_marginal, 'r-', linewidth=2, 
            label='Theoretical p(x)')

axes[0].set_xlabel('x', fontsize=12)
axes[0].set_ylabel('Density', fontsize=12)
axes[0].set_title('Marginal Distribution: Simulated vs Theoretical', fontsize=13)
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Plot 2: Observations by true category
axes[1].hist(obs_c1, bins=30, density=True, alpha=0.6, color='blue', 
            edgecolor='black', label=f'Category 1 (n={len(obs_c1)})')
axes[1].hist(obs_c2, bins=30, density=True, alpha=0.6, color='red', 
            edgecolor='black', label=f'Category 2 (n={len(obs_c2)})')

# Overlay theoretical likelihoods
axes[1].plot(x_plot, norm.pdf(x_plot, mu_1, sigma_1), 'b-', linewidth=2, 
            alpha=0.7, label='Theoretical L(c=1|x)')
axes[1].plot(x_plot, norm.pdf(x_plot, mu_2, sigma_2), 'r-', linewidth=2, 
            alpha=0.7, label='Theoretical L(c=2|x)')

axes[1].axvline(mu_1, color='blue', linestyle='--', linewidth=1, alpha=0.5)
axes[1].axvline(mu_2, color='red', linestyle='--', linewidth=1, alpha=0.5)

axes[1].set_xlabel('x', fontsize=12)
axes[1].set_ylabel('Density', fontsize=12)
axes[1].set_title('Observations by True Category', fontsize=13)
axes[1].legend(fontsize=9)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n✅ GenJAX simulation matches theoretical predictions!")

---

## Summary

### Key Insights from Problem 2:

1. **Categorization (Posterior Probability)**:
   - Computed using Bayes' rule: $P(c=1|x) = \frac{\theta \mathcal{N}(x;\mu_1,\sigma_1^2)}{\theta \mathcal{N}(x;\mu_1,\sigma_1^2) + (1-\theta)\mathcal{N}(x;\mu_2,\sigma_2^2)}$
   - Prior $\theta$ acts as a weight/magnifier for each component
   - Decision boundary shifts based on prior and variance ratios
   - Component with larger variance dominates tails

2. **Effect of Prior ($\theta$)**:
   - Shifts decision boundary toward less favored category
   - Amplifies contribution of favored component in mixture
   - Doesn't change likelihood shapes, only their relative importance

3. **Effect of Variance Ratio**:
   - Equal variances → symmetric behavior around midpoint
   - Unequal variances → asymmetric categorization
   - Smaller variance → taller, sharper peak ("needle effect")
   - Larger variance → wider, flatter distribution (dominates tails)

4. **Marginal Distribution (Mixture)**:
   - Weighted sum: $p(x) = \theta \mathcal{N}(x;\mu_1,\sigma_1^2) + (1-\theta)\mathcal{N}(x;\mu_2,\sigma_2^2)$
   - Can be unimodal, bimodal, or plateau depending on parameters
   - Prior controls relative heights of component peaks
   - Variance ratio determines peak sharpness

5. **Mixture Model Behavior**:
   - Mixture density lies between the two component densities at every $x$
   - Separation of means, variance ratio, and mixing weights determine modality
   - GenJAX simulations verify theoretical predictions

### Connection to Bayesian Learning:

This problem introduces **latent variable models**:
- Observed: $x$ (continuous)
- Latent (hidden): $c$ (discrete category)
- Inference: Given $x$, infer $c$ (categorization)
- Generation: Given $c$, generate $x$ (sampling)
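
Generation and inference can be combined in one small experiment: sample labelled data, categorize each datum by its posterior, and compare the accuracy with theory. This is a minimal sketch with assumed settings ($\theta = 0.5$, unit variances, a fixed seed, plain NumPy rather than GenJAX). In that setting the boundary is at $x = 0$ and each mean is one standard deviation away, so the expected accuracy is $\Phi(1) \approx 0.841$.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
n, theta, mu_1, mu_2, sigma = 50_000, 0.5, -1.0, 1.0, 1.0

# Generation: sample the latent category, then the observation
is_c1 = rng.random(n) < theta
x = rng.normal(np.where(is_c1, mu_1, mu_2), sigma)

# Inference: posterior of category 1, then pick the more probable category
lik_1 = norm.pdf(x, mu_1, sigma)
lik_2 = norm.pdf(x, mu_2, sigma)
post_c1 = theta * lik_1 / (theta * lik_1 + (1 - theta) * lik_2)
pred_c1 = post_c1 > 0.5

accuracy = np.mean(pred_c1 == is_c1)
expected = norm.cdf(1.0)  # probability that a datum falls on its own side of 0
print(f"empirical accuracy:   {accuracy:.3f}")
print(f"theoretical accuracy: {expected:.3f}")
```

Even with the true parameters known, about 16% of the data are miscategorized here because the two Gaussians overlap.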

This framework extends to:
- Gaussian Mixture Models (GMM) with unknown $\mu_1, \mu_2, \sigma_1^2, \sigma_2^2$
- Expectation-Maximization (EM) algorithm for learning parameters
- Clustering and unsupervised learning

**Next steps**: Extend to unknown parameters and learn from data!