# Tutorial 3: Handling Complexity with Mixture Models

What happens when your data doesn't come from a single, simple process? What if it could have come from one of two (or more) different models? A single `LinearModel` cannot capture that structure.

This is where `lsbi.MixtureModel` comes in. It allows you to model data as a mixture of several linear components, and infer which component is most likely to have generated the data.

We will cover:
1. The problem: generating bimodal data that can't be fit by a single line
2. Failure of a single `LinearModel`
3. The solution: `MixtureModel` with multiple components
4. Inference and analysis with `MixtureModel`
5. Understanding the mixture model components
6. Monte Carlo estimates and information gain
7. Model selection: when to use mixture models

## 1. The Problem: A Bimodal Dataset

Let's generate data from two different lines and mix them together to create a challenging dataset that no single linear model can fit well.

In [None]:
# Generate bimodal data
import numpy as np
import matplotlib.pyplot as plt
from lsbi import LinearModel, MixtureModel

np.random.seed(123)

# Model 1: Positive slope line
theta_1 = np.array([2.0, 1.0])   # slope, intercept
x_1 = np.linspace(-2, 2, 50)
y_1 = x_1 * theta_1[0] + theta_1[1] + np.random.normal(0, 0.4, size=x_1.shape)

# Model 2: Negative slope line  
theta_2 = np.array([-2.0, -1.0]) # slope, intercept
x_2 = np.linspace(-2, 2, 50)
y_2 = x_2 * theta_2[0] + theta_2[1] + np.random.normal(0, 0.4, size=x_2.shape)

# Combine into a single dataset - this creates an "X" pattern
combined_x = np.concatenate([x_1, x_2])
combined_y = np.concatenate([y_1, y_2])

print(f"Dataset size: {len(combined_x)} points")
print(f"True parameters - Component 1: {theta_1}")
print(f"True parameters - Component 2: {theta_2}")

plt.figure(figsize=(10, 6))
plt.scatter(x_1, y_1, alpha=0.6, label='Component 1 Data', color='blue')
plt.scatter(x_2, y_2, alpha=0.6, label='Component 2 Data', color='red')
plt.plot(x_1, x_1 * theta_1[0] + theta_1[1], 'b-', lw=2, label='True Line 1')
plt.plot(x_2, x_2 * theta_2[0] + theta_2[1], 'r-', lw=2, label='True Line 2')
plt.title("A Bimodal Dataset: Two Linear Processes")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

## 2. Failure of a Single `LinearModel`

Let's see what happens when we try to fit this "X"-shaped data with a single straight line:

In [None]:
# Attempting to fit with a single model
M_single = np.vstack([combined_x, np.ones_like(combined_x)]).T
model_single = LinearModel(M=M_single, n=2, d=len(combined_x))  # Using default priors

posterior_single = model_single.posterior(combined_y)

print("Single model fit:")
print(f"Estimated parameters: {np.round(posterior_single.mean, 2)}")
print(f"Parameter uncertainty (std): {np.round(np.sqrt(np.diag(posterior_single.cov)), 2)}")

plt.figure(figsize=(10, 6))
plt.scatter(combined_x, combined_y, alpha=0.6, label='Combined Data', color='gray')
plt.plot(combined_x, M_single @ posterior_single.mean, 'orange', lw=3, label='Single Model Fit')

# Show the true generating lines for reference
plt.plot(x_1, x_1 * theta_1[0] + theta_1[1], 'b--', lw=2, alpha=0.7, label='True Line 1')
plt.plot(x_2, x_2 * theta_2[0] + theta_2[1], 'r--', lw=2, alpha=0.7, label='True Line 2')

plt.title("Failure of a Single Linear Model")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

# Calculate the evidence (log probability of the data)
evidence_single = model_single.evidence()
log_evidence_single = evidence_single.logpdf(combined_y)
print(f"\nSingle model log-evidence: {log_evidence_single:.2f}")

As expected, the single model produces a poor compromise fit that doesn't represent either underlying process well. The orange line tries to average between the two trends, resulting in a nearly flat line that fits neither component.
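
The near-flat result follows from the conjugate Gaussian update: at every `x` the two generating lines average to zero, so the pooled data carry almost no net slope or intercept. The sketch below reproduces that with plain NumPy rather than `lsbi`. It assumes a unit prior covariance and unit noise covariance, which may not match the library defaults, and it uses its own random seed.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same setup as the tutorial: two lines with opposite slopes and intercepts
x = np.linspace(-2, 2, 50)
y_1 = 2.0 * x + 1.0 + rng.normal(0, 0.4, size=x.shape)
y_2 = -2.0 * x - 1.0 + rng.normal(0, 0.4, size=x.shape)
pooled_x = np.concatenate([x, x])
pooled_y = np.concatenate([y_1, y_2])

# Design matrix with columns [x, 1], so theta = [slope, intercept]
M = np.column_stack([pooled_x, np.ones_like(pooled_x)])

# Assumption: prior theta ~ N(0, I) and noise covariance C = I
Sigma_inv = np.eye(2)
C_inv = np.eye(len(pooled_y))

# Conjugate update with zero prior mean:
#   posterior precision = prior precision + M^T C^-1 M
#   posterior mean      = posterior covariance @ M^T C^-1 y
post_cov = np.linalg.inv(Sigma_inv + M.T @ C_inv @ M)
post_mean = post_cov @ (M.T @ C_inv @ pooled_y)

print("Pooled posterior mean [slope, intercept]:", np.round(post_mean, 3))
```

Both entries of `post_mean` come out close to zero, which is the compromise line seen in the plot.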

## 3. The Solution: `MixtureModel`

We will now model this as a mixture of two `LinearModel`s. This requires "stacking" the parameters for each component into arrays with a leading dimension equal to the number of components (in our case, 2).

**Key Concept**: A `MixtureModel` with `k` components needs:
- `mu`: shape `(k, n)` 
- `Sigma`: shape `(k, n, n)`
- `M`: shape `(k, d, n)`
- `C`: shape `(k, d, d)`
- `logw`: shape `(k,)` - log prior weights

In [None]:
# Building the MixtureModel
n_components = 2
n_params = 2
n_data = len(combined_x)

print(f"Setting up MixtureModel with {n_components} components")
print(f"Problem size: {n_params} parameters, {n_data} data points")

# Priors for the two components, stacked along a new axis
# We'll use broad priors centered around zero for both components
mu_mix = np.zeros((n_components, n_params))
print(f"mu_mix shape: {mu_mix.shape}")

# Same prior covariance for both components
Sigma_mix = np.tile(np.eye(n_params) * 10, (n_components, 1, 1))
print(f"Sigma_mix shape: {Sigma_mix.shape}")

# The M matrix is the same for both components in this case
M_mix = np.tile(M_single, (n_components, 1, 1))
print(f"M_mix shape: {M_mix.shape}")

# The data covariance is also the same for both components
C_mix = np.tile(np.eye(n_data) * 0.4**2, (n_components, 1, 1))
print(f"C_mix shape: {C_mix.shape}")

# Prior weights: we assume each component is equally likely a priori
logw = np.log([0.5, 0.5])
print(f"logw shape: {logw.shape}")
print(f"Prior probabilities: {np.exp(logw)}")

# Instantiate the MixtureModel
mixture_model = MixtureModel(M=M_mix, mu=mu_mix, Sigma=Sigma_mix, C=C_mix, logw=logw)
print(f"\n✓ MixtureModel created with {mixture_model.k} components")

**Common Pitfall Alert**: Getting the shapes right is crucial. The most common errors are:
- Forgetting the leading component dimension
- Using the wrong axis for stacking
- Inconsistent shapes between different arrays
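
One way to catch these mistakes early is to check the stacked shapes before building the model. The helper below is a hypothetical sketch, not part of `lsbi`. It checks only the fully stacked layout listed in the Key Concept box and makes no assumption about whether the library also accepts broadcastable shapes.

```python
import numpy as np

def check_mixture_shapes(M, mu, Sigma, C, logw):
    """Hypothetical helper (not part of lsbi): list any shape mismatches."""
    k = logw.shape[0]          # number of components
    d, n = M.shape[-2:]        # data points, parameters
    expected = {
        "M": (k, d, n),
        "mu": (k, n),
        "Sigma": (k, n, n),
        "C": (k, d, d),
        "logw": (k,),
    }
    actual = {"M": M.shape, "mu": mu.shape, "Sigma": Sigma.shape,
              "C": C.shape, "logw": logw.shape}
    return [f"{name}: expected {expected[name]}, got {actual[name]}"
            for name in expected if actual[name] != expected[name]]

k, n, d = 2, 2, 100
x = np.linspace(-2, 2, d)
M_one = np.column_stack([x, np.ones(d)])           # one component: (d, n)

good = dict(
    M=np.tile(M_one, (k, 1, 1)),                   # (k, d, n)
    mu=np.zeros((k, n)),                           # (k, n)
    Sigma=np.tile(np.eye(n), (k, 1, 1)),           # (k, n, n)
    C=np.tile(np.eye(d), (k, 1, 1)),               # (k, d, d)
    logw=np.log(np.full(k, 1.0 / k)),              # (k,)
)
problems_good = check_mixture_shapes(**good)

# Pitfall: forgetting the leading component dimension on mu
bad = dict(good, mu=np.zeros(n))
problems_bad = check_mixture_shapes(**bad)

print("stacked arrays:", problems_good)
print("unstacked mu:  ", problems_bad)
```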

## 4. Inference and Analysis with `MixtureModel`

Now we perform inference and inspect the results. The posterior from a `MixtureModel` is itself a mixture: it holds one Gaussian posterior per component, together with the updated log-weights `logw`. Exponentiating `logw` gives the posterior probability of each component given the data.
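
To make the weight update concrete, here is a minimal NumPy sketch. It assumes the standard linear-Gaussian mixture in which each component describes the whole data vector, so that a component's evidence is `N(D | M mu, C + M Sigma M^T)` and its posterior weight is proportional to prior weight times evidence. The function names are hypothetical and are not `lsbi` calls.

```python
import numpy as np

def gaussian_logpdf(y, mean, cov):
    """Log density of N(mean, cov) at a single vector y."""
    diff = y - mean
    _, logdet = np.linalg.slogdet(cov)
    quad = diff @ np.linalg.solve(cov, diff)
    return -0.5 * (len(y) * np.log(2 * np.pi) + logdet + quad)

def posterior_log_weights(y, M, mu, Sigma, C, logw):
    """Hypothetical helper: add each component's log-evidence, then normalise."""
    log_ev = np.array([
        gaussian_logpdf(y, M[j] @ mu[j], C[j] + M[j] @ Sigma[j] @ M[j].T)
        for j in range(len(logw))
    ])
    unnorm = logw + log_ev
    return unnorm - np.logaddexp.reduce(unnorm)

rng = np.random.default_rng(1)
x = np.linspace(-2, 2, 30)
M_one = np.column_stack([x, np.ones_like(x)])
y = 2.0 * x + 1.0 + rng.normal(0, 0.4, size=x.shape)  # data from the rising line only

k = 2
M = np.tile(M_one, (k, 1, 1))
C = np.tile(0.4**2 * np.eye(len(x)), (k, 1, 1))
Sigma = np.tile(np.eye(2), (k, 1, 1))
logw = np.log([0.5, 0.5])

# Components with different prior means: the data can tell them apart
mu_distinct = np.array([[2.0, 1.0], [-2.0, -1.0]])
w_distinct = np.exp(posterior_log_weights(y, M, mu_distinct, Sigma, C, logw))

# Components with identical settings: equal evidences, weights stay at the prior
mu_same = np.zeros((k, 2))
w_same = np.exp(posterior_log_weights(y, M, mu_same, Sigma, C, logw))

print("distinct prior means:", np.round(w_distinct, 4))
print("identical components:", np.round(w_same, 4))
```

Under that assumption, components that share the same `M`, `mu`, `Sigma`, and `C` receive the same evidence and the same posterior, so it is worth checking that the fitted components actually differ before interpreting them as separate trends.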

In [None]:
# Inference and Visualization
print("Computing mixture model posterior...")
mixture_posterior = mixture_model.posterior(combined_y)

# The posterior mean is now an array of shape (n_components, n_params)
posterior_means = mixture_posterior.mean
posterior_covs = mixture_posterior.cov

print("\n=== Mixture Model Results ===")
print(f"Posterior means for each component:")
for i in range(n_components):
    print(f"  Component {i+1}: {np.round(posterior_means[i], 2)}")
    
print(f"\nTrue parameters were:")
print(f"  Component 1: {theta_1}")
print(f"  Component 2: {theta_2}")

# The posterior weights tell us the probability that each component generated the data
posterior_weights = np.exp(mixture_posterior.logw)
print(f"\nComponent weights:")
print(f"  Prior weights:     {np.round(np.exp(logw), 3)}")
print(f"  Posterior weights: {np.round(posterior_weights, 3)}")

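# Calculate mixture model evidence
evidence_mixture = mixture_model.evidence()
log_evidence_mixture = evidence_mixture.logpdf(combined_y)
# Note: the single model uses the library defaults while the mixture sets Sigma and C
# explicitly, so this comparison also reflects those prior and noise choices.
print(f"\nModel comparison:")
print(f"  Single model log-evidence:  {log_evidence_single:.2f}")
print(f"  Mixture model log-evidence: {log_evidence_mixture:.2f}")
# Report the ratio in log space: np.exp of a large difference overflows to inf
print(f"  Log evidence ratio (mixture - single): {log_evidence_mixture - log_evidence_single:.2f}")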

In [None]:
# Visualize the mixture model fits
plt.figure(figsize=(12, 8))

# Plot the data
plt.scatter(combined_x, combined_y, alpha=0.6, label='Combined Data', color='gray', s=30)

# Plot the mixture model component fits
colors = ['blue', 'red']
for i in range(n_components):
    line_y = M_single @ posterior_means[i]
    plt.plot(combined_x, line_y, color=colors[i], lw=3, 
             label=f'Component {i+1} Fit (w={posterior_weights[i]:.2f})')
    
    # Show posterior uncertainty by overlaying lines drawn from this component
    component_posterior = mixture_posterior[i]
    samples = component_posterior.rvs(size=20)
    for sample in samples:
        plt.plot(combined_x, M_single @ sample, color=colors[i], alpha=0.1, lw=1)

# Compare with the failed single model
plt.plot(combined_x, M_single @ posterior_single.mean, 
         'orange', lw=2, linestyle='--', alpha=0.8, label='Single Model (Failed)')

# Show true generating lines
plt.plot(x_1, x_1 * theta_1[0] + theta_1[1], 'b:', lw=2, alpha=0.8, label='True Line 1')
plt.plot(x_2, x_2 * theta_2[0] + theta_2[1], 'r:', lw=2, alpha=0.8, label='True Line 2')

plt.title("Successful Fit with MixtureModel")
plt.xlabel("x")
plt.ylabel("y")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

# Report what the fit found rather than assuming success
components_differ = not np.allclose(posterior_means[0], posterior_means[1])
print(f"\nComponents have distinct posterior means: {components_differ}")
print(f"Posterior weights: ({posterior_weights[0]:.2f}, {posterior_weights[1]:.2f})")

## 5. Understanding Mixture Model Components

Let's dive deeper into what the mixture model learned by examining each component individually:

In [None]:
# Analyze individual components
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Component parameter distributions
theta_range = np.linspace(-4, 4, 200)
colors = ['blue', 'red']

for i in range(n_components):
    component_post = mixture_posterior[i]
    
    # Slope distribution
    slope_dist = component_post.marginalise([1])  # Remove intercept, keep slope
    axes[0, i].plot(theta_range, slope_dist.pdf(theta_range[:, None]).flatten(), 
                   color=colors[i], lw=2, label=f'Component {i+1} Slope')
    axes[0, i].axvline([theta_1, theta_2][i][0], color='black', linestyle='--', 
                      label=f'True Slope ({[theta_1, theta_2][i][0]})')
    axes[0, i].set_title(f'Component {i+1}: Slope Distribution')
    axes[0, i].set_xlabel('Slope')
    axes[0, i].set_ylabel('Density')
    axes[0, i].legend()
    axes[0, i].grid(True, alpha=0.3)
    
    # Intercept distribution
    intercept_dist = component_post.marginalise([0])  # Remove slope, keep intercept
    axes[1, i].plot(theta_range, intercept_dist.pdf(theta_range[:, None]).flatten(), 
                   color=colors[i], lw=2, label=f'Component {i+1} Intercept')
    axes[1, i].axvline([theta_1, theta_2][i][1], color='black', linestyle='--', 
                      label=f'True Intercept ({[theta_1, theta_2][i][1]})')
    axes[1, i].set_title(f'Component {i+1}: Intercept Distribution')
    axes[1, i].set_xlabel('Intercept')
    axes[1, i].set_ylabel('Density')
    axes[1, i].legend()
    axes[1, i].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print quantitative comparison
print("=== Quantitative Assessment ===")
for i in range(n_components):
    true_params = [theta_1, theta_2][i]
    estimated_params = posterior_means[i]
    param_std = np.sqrt(np.diag(posterior_covs[i]))
    
    print(f"\nComponent {i+1}:")
    print(f"  True parameters:      [{true_params[0]:5.2f}, {true_params[1]:5.2f}]")
    print(f"  Estimated parameters: [{estimated_params[0]:5.2f}, {estimated_params[1]:5.2f}]")
    print(f"  Parameter std errors: [{param_std[0]:5.2f}, {param_std[1]:5.2f}]")
    
    # Calculate how many standard deviations away the estimate is
    z_scores = np.abs(estimated_params - true_params) / param_std
    print(f"  |Z-scores|:           [{z_scores[0]:5.2f}, {z_scores[1]:5.2f}]")
    if all(z_scores < 2):
        print(f"  ✓ Excellent fit (< 2σ error)")
    elif all(z_scores < 3):
        print(f"  ✓ Good fit (< 3σ error)")
    else:
        print(f"  ⚠ Poor fit (> 3σ error)")
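
The marginal densities and z-scores in this section rest on a simple property of Gaussians: marginalising keeps the matching entries of the mean and the matching block of the covariance. The sketch below shows this with NumPy alone. The `marginal` helper is hypothetical and takes the indices to *keep*, whereas the tutorial's `marginalise` call takes the indices to remove. The posterior numbers are made up for illustration.

```python
import numpy as np

def marginal(mean, cov, keep):
    """Hypothetical helper: marginal of N(mean, cov) over the indices in `keep`."""
    keep = np.atleast_1d(keep)
    return mean[keep], cov[np.ix_(keep, keep)]

def normal_pdf(t, mean, var):
    """Density of a one-dimensional Gaussian."""
    return np.exp(-0.5 * (t - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

# Illustrative posterior over [slope, intercept]; these numbers are made up
post_mean = np.array([1.9, 1.1])
post_cov = np.array([[0.010, 0.002],
                     [0.002, 0.020]])

# Keep the slope (index 0); the intercept is integrated out
slope_mean, slope_cov = marginal(post_mean, post_cov, keep=[0])

theta_range = np.linspace(-4, 4, 2001)
slope_density = normal_pdf(theta_range, slope_mean[0], slope_cov[0, 0])

# z-score of a reference slope of 2.0 against the marginal standard deviation
z = abs(slope_mean[0] - 2.0) / np.sqrt(slope_cov[0, 0])
print(f"marginal slope: mean={slope_mean[0]:.2f}, std={np.sqrt(slope_cov[0, 0]):.2f}, |z|={z:.2f}")
```

The off-diagonal covariance does not enter a one-dimensional marginal, which is why the z-scores above use only `np.diag` of the posterior covariance.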

## 6. Advanced Topics: Monte Carlo Estimates and Information Gain

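For a `LinearModel`, the KL divergence between posterior and prior has a closed form. The KL divergence between two Gaussian mixtures does not, so `MixtureModel.dkl` relies on a Monte Carlo estimate controlled by the sample count `n`. Let's explore this: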

In [None]:
# Monte Carlo estimates for mixture models
print("=== Information Gain Analysis ===")

# For a single LinearModel, KL divergence is computed analytically
try:
    kl_single = model_single.dkl(combined_y)
    print(f"Single model KL divergence: {kl_single:.3f} nats (analytical)")
except Exception as e:
    print(f"Single model KL error: {e}")

# For MixtureModel, we need Monte Carlo sampling
try:
    # This will fail - mixture models require n > 0
    kl_mixture = mixture_model.dkl(combined_y)
    print(f"Mixture model KL divergence: {kl_mixture:.3f} nats")
except Exception as e:
    print(f"Mixture model KL error (expected): {str(e)}")

# Use Monte Carlo estimation
n_samples = 5000
print(f"\nUsing Monte Carlo with {n_samples} samples...")
kl_mixture_mc = mixture_model.dkl(combined_y, n=n_samples)
print(f"Mixture model KL divergence: {kl_mixture_mc:.3f} nats (Monte Carlo)")

# Compare information gains
print(f"\nInformation Gain Comparison:")
print(f"  Single model:  {kl_single:.3f} nats ({kl_single/np.log(2):.2f} bits)")
print(f"  Mixture model: {kl_mixture_mc:.3f} nats ({kl_mixture_mc/np.log(2):.2f} bits)")
print(f"  Additional information from mixture: {(kl_mixture_mc-kl_single)/np.log(2):.2f} bits")

# Show convergence of Monte Carlo estimate
sample_sizes = [100, 500, 1000, 2000, 5000, 10000]
kl_estimates = []

print(f"\nMonte Carlo convergence:")
for n in sample_sizes:
    kl_est = mixture_model.dkl(combined_y, n=n)
    kl_estimates.append(kl_est)
    print(f"  n={n:5d}: KL = {kl_est:.4f} nats")

# Plot convergence
plt.figure(figsize=(8, 5))
plt.semilogx(sample_sizes, kl_estimates, 'o-', lw=2, markersize=6)
plt.axhline(kl_estimates[-1], color='red', linestyle='--', alpha=0.7, 
           label=f'Largest-n estimate: {kl_estimates[-1]:.3f}')
plt.xlabel('Number of Monte Carlo Samples')
plt.ylabel('KL Divergence (nats)')
plt.title('Monte Carlo Convergence for Mixture Model KL Divergence')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()
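
To see what a Monte Carlo KL estimate does, it helps to try it where the exact answer is known. The sketch below estimates KL(P || Q) between two single Gaussians as the sample mean of `log p - log q` under P and compares it with the closed form. It uses NumPy only, the function names are hypothetical, and the distributions are made up for illustration. It is not how `lsbi` implements `dkl`.

```python
import numpy as np

def gaussian_logpdf(x, mean, cov):
    """Log density of N(mean, cov) at each row of x (shape (n_samples, dim))."""
    diff = x - mean
    _, logdet = np.linalg.slogdet(cov)
    quad = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=-1)
    return -0.5 * (len(mean) * np.log(2 * np.pi) + logdet + quad)

def kl_analytic(mean_p, cov_p, mean_q, cov_q):
    """Closed-form KL(P || Q) between two Gaussians, in nats."""
    diff = mean_q - mean_p
    cov_q_inv = np.linalg.inv(cov_q)
    _, logdet_p = np.linalg.slogdet(cov_p)
    _, logdet_q = np.linalg.slogdet(cov_q)
    return 0.5 * (np.trace(cov_q_inv @ cov_p) + diff @ cov_q_inv @ diff
                  - len(mean_p) + logdet_q - logdet_p)

def kl_monte_carlo(mean_p, cov_p, mean_q, cov_q, n, rng):
    """Estimate KL(P || Q) as the sample mean of log p - log q under P."""
    samples = rng.multivariate_normal(mean_p, cov_p, size=n)
    log_ratio = (gaussian_logpdf(samples, mean_p, cov_p)
                 - gaussian_logpdf(samples, mean_q, cov_q))
    return log_ratio.mean()

# Illustrative "posterior" P and "prior" Q; the numbers are made up
mean_p, cov_p = np.array([2.0, 1.0]), 0.05 * np.eye(2)
mean_q, cov_q = np.zeros(2), 10.0 * np.eye(2)

rng = np.random.default_rng(0)
kl_exact = kl_analytic(mean_p, cov_p, mean_q, cov_q)
kl_mc = kl_monte_carlo(mean_p, cov_p, mean_q, cov_q, n=20000, rng=rng)
print(f"analytic: {kl_exact:.4f} nats, Monte Carlo (n=20000): {kl_mc:.4f} nats")
```

The estimator's standard error shrinks roughly as one over the square root of `n`, so repeated runs at small `n` scatter noticeably and a single curve of estimates against `n` should be read with that noise in mind.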

## 7. Model Selection: When to Use Mixture Models

Let's examine when mixture models are justified using model evidence:

In [None]:
# Model selection using evidence
print("=== Model Selection Analysis ===")

# Compare different numbers of components
component_counts = [1, 2, 3]
log_evidences = []
model_names = []

for k in component_counts:
    if k == 1:
        # Single component (our original LinearModel)
        evidence = model_single.evidence().logpdf(combined_y)
        model_names.append("Single LinearModel")
    else:
        # Multiple components
        # Create a k-component mixture model
        mu_k = np.zeros((k, n_params))
        Sigma_k = np.tile(np.eye(n_params) * 10, (k, 1, 1))
        M_k = np.tile(M_single, (k, 1, 1))
        C_k = np.tile(np.eye(n_data) * 0.4**2, (k, 1, 1))
        logw_k = np.log(np.ones(k) / k)  # Equal weights
        
        model_k = MixtureModel(M=M_k, mu=mu_k, Sigma=Sigma_k, C=C_k, logw=logw_k)
        evidence = model_k.evidence().logpdf(combined_y)
        model_names.append(f"{k}-Component Mixture")
    
    log_evidences.append(evidence)

# Display results
print("Model Evidence Comparison:")
best_idx = np.argmax(log_evidences)
for i, (name, log_ev) in enumerate(zip(model_names, log_evidences)):
    relative_evidence = np.exp(log_ev - log_evidences[best_idx])
    marker = "👑" if i == best_idx else "  "
    print(f"{marker} {name:20s}: log-evidence = {log_ev:8.2f}, relative = {relative_evidence:.2e}")

print(f"\n🏆 Best model: {model_names[best_idx]}")

# Bayes factor interpretation (work in log space: np.exp of a large difference overflows)
log_bayes_factor = log_evidences[1] - log_evidences[0]  # 2-component vs 1-component
print(f"\nLog Bayes factor (2-component vs 1-component): {log_bayes_factor:.2f}")
if log_bayes_factor > np.log(100):
    strength = "decisive"
elif log_bayes_factor > np.log(10):
    strength = "strong"
elif log_bayes_factor > np.log(3):
    strength = "moderate"
else:
    strength = "weak"
print(f"Evidence strength: {strength} support for mixture model")

# Visualize model comparison
plt.figure(figsize=(8, 5))
bars = plt.bar(model_names, log_evidences, 
               color=['orange', 'green', 'purple'], alpha=0.7)
bars[best_idx].set_color('gold')
bars[best_idx].set_edgecolor('black')
bars[best_idx].set_linewidth(2)

plt.ylabel('Log Evidence')
plt.title('Model Evidence Comparison')
plt.xticks(rotation=45)
plt.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()
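
Log-evidences for a hundred data points can differ by hundreds of nats, so it is safer to compare models without exponentiating the raw differences. The sketch below turns a list of log-evidences into posterior model probabilities under equal prior odds and applies the same 3 / 10 / 100 thresholds on the log scale. The function names are hypothetical and the log-evidence values are made up for illustration.

```python
import numpy as np

def model_probabilities(log_evidences):
    """Posterior model probabilities under equal prior odds, computed in log space."""
    log_evidences = np.asarray(log_evidences, dtype=float)
    return np.exp(log_evidences - np.logaddexp.reduce(log_evidences))

def evidence_strength(log_bayes_factor):
    """Apply the thresholds 3, 10, and 100 to a log Bayes factor."""
    for threshold, label in [(100, "decisive"), (10, "strong"), (3, "moderate")]:
        if log_bayes_factor > np.log(threshold):
            return label
    return "weak"

# Made-up log-evidences for three candidate models (illustration only)
log_evidences = [-1500.0, -90.0, -90.7]

probs = model_probabilities(log_evidences)
log_bf_2_vs_1 = log_evidences[1] - log_evidences[0]
log_bf_2_vs_3 = log_evidences[1] - log_evidences[2]

print("model probabilities:", np.round(probs, 3))
print("2 vs 1:", evidence_strength(log_bf_2_vs_1))
print("2 vs 3:", evidence_strength(log_bf_2_vs_3))
```

Here `np.exp(log_bf_2_vs_1)` would overflow to `inf`, while the log-scale comparison stays finite and gives the same verdict.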

## Summary

In this tutorial, we explored `lsbi`'s `MixtureModel` for handling complex, multi-modal data:

### Key Takeaways:

1. **When to Use Mixture Models**: When your data might come from multiple different linear processes, a single `LinearModel` will produce poor, compromise fits.

2. **Shape Requirements**: `MixtureModel` requires stacking parameters along the component dimension:
   - `mu`: `(k, n)` for k components and n parameters
   - `Sigma`: `(k, n, n)` for parameter covariances
   - `M`: `(k, d, n)` and `C`: `(k, d, d)` for d data points
   - `logw`: `(k,)` for the log prior weights

3. **Interpretation**: The posterior provides:
   - Individual parameter estimates for each component
   - Updated component weights showing which components best explain the data
   - Uncertainty quantification for all estimates

4. **Monte Carlo Requirements**: Unlike `LinearModel`, mixture models need Monte Carlo sampling for some calculations (like KL divergence).

5. **Model Selection**: Use evidence comparison to determine the optimal number of components. More components aren't always better due to the complexity penalty built into Bayesian model comparison.

### Best Practices:

- **Start Simple**: Try a single `LinearModel` first
- **Use Evidence**: Let the data tell you how many components are needed
- **Check Convergence**: Use sufficient Monte Carlo samples for stable estimates
- **Visualize Results**: Always plot your mixture components to understand what the model learned

### When Mixture Models Excel:

- **Multi-modal data**: Data that clearly comes from different processes
- **Regime changes**: Time series or experiments with distinct phases
- **Population studies**: Different subgroups with different linear relationships
- **Robust regression**: Handling outliers by treating them as a separate component

The mixture model framework in `lsbi` provides a powerful tool for capturing complexity while maintaining the analytical tractability and performance benefits of linear Bayesian inference!