In [None]:
import numpy as np
import grain.python as grain
import jax
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import multivariate_t
from scipy.optimize import linear_sum_assignment
from tqdm import trange
from onlineEM.core import em_config
from onlineEM.models import HDgmm, HDstm, HDlm

# Set plot style for professional appearance
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_palette("husl")

# Seed NumPy's global RNG once so every synthetic dataset below is reproducible
# under "Restart Kernel -> Run All". The generators draw from np.random.*, and
# scipy's multivariate_t.rvs falls back to the same global RandomState when no
# random_state is passed.
SEED = 42
np.random.seed(SEED)

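# High-Dimensional Mixture Models Tutorial 🚀

This tutorial fits high-dimensional (HD) Gaussian, Student-t and Laplace mixture models with online EM on synthetic 5-dimensional data. It then compares the recovered mixing weights, means and covariance eigenvalue spectra with the ground truth.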

## 1. Gaussian Mixture Models

In [None]:
# ==================== UTILITY FUNCTIONS ====================


def generate_mixture_of_gaussians(num_samples: int, means: list, covariances: list, weights: list) -> np.ndarray:
    """Sample ``num_samples`` points from a finite mixture of multivariate normals.

    Every row first receives a component label drawn with probabilities
    ``weights``. All rows that share a label are then filled by a single
    ``np.random.multivariate_normal`` draw using that component's mean and
    covariance. The labels themselves are not returned.

    Returns:
        Float array of shape (num_samples, len(means[0])).
    """
    labels = np.random.choice(len(means), size=num_samples, p=weights)
    dim = len(means[0])
    out = np.zeros((num_samples, dim))

    for comp in range(len(means)):
        mask = labels == comp
        count = np.sum(mask)
        if not count:
            # Component never chosen: nothing to draw.
            continue
        out[mask] = np.random.multivariate_normal(means[comp], covariances[comp], count)
    return out


def create_dataloader(X_array: np.ndarray, config, seed: int = 42, shuffle: bool = True) -> grain.DataLoader:
    """Wrap an in-memory array in a batched grain DataLoader.

    Args:
        X_array: Samples indexed along the first axis.
        config: Object providing ``num_epochs`` and ``batch_size``.
        seed: Seed for the index sampler (default 42, the previous hard-coded value).
        shuffle: Whether the index sampler shuffles record order (default True).

    Returns:
        A ``grain.DataLoader`` that yields batches of ``config.batch_size``
        samples for ``config.num_epochs`` epochs. An incomplete final batch is
        dropped.
    """

    class SimpleDataSource(grain.RandomAccessDataSource):
        """Random-access view over the rows of an in-memory array."""

        def __init__(self, data: np.ndarray):
            self.data = data

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, index: int) -> np.ndarray:
            return self.data[index]

    data_source = SimpleDataSource(X_array)
    sampler = grain.IndexSampler(
        num_records=len(data_source),
        num_epochs=config.num_epochs,
        shuffle=shuffle,
        seed=seed,
    )
    # drop_remainder=True keeps every batch the same shape, which avoids
    # re-tracing jax.jit-compiled training steps on a short last batch.
    transforms = [grain.Batch(batch_size=config.batch_size, drop_remainder=True)]
    # worker_count=0: load in the main process, without worker subprocesses.
    return grain.DataLoader(data_source=data_source, sampler=sampler, operations=transforms, worker_count=0)

In [None]:
# Configuration and synthetic data setup
N_COMPONENTS = 3
NUM_FEATURES = 5

config = em_config(
    n_components=N_COMPONENTS,
    num_features=NUM_FEATURES,
    num_epochs=30,
    batch_size=256,
    n_first=20000,
    # No `reduction` is passed, so em_config's default is used. Presumably this
    # lets the model pick the per-component reductions; they are printed from
    # the config returned by training.
)

# Component parameters: structured variance pattern for clear eigenvalue structure
mean1 = np.array([2.0, 4.0, 1.0, 0.5, 0.2])  # Scaled up for better separation
covariance1 = np.diag([10.0, 8.0, 0.1, 0.1, 0.1])  # Two large eigenvalues, three small
weight1 = 0.4

mean2 = np.array([6.0, 2.0, 5.0, 1.0, 0.8])
covariance2 = np.diag([8.0, 6.0, 4.0, 0.25, 0.25])  # Three large eigenvalues, two small
weight2 = 0.3

mean3 = np.array([0.0, 0.0, 0.0, 0.0, 0.0])
covariance3 = np.diag([3.0, 2.0, 0.08, 0.08, 0.08])  # Two large eigenvalues, three small
weight3 = 0.3

means = [mean1, mean2, mean3]
covariances = [covariance1, covariance2, covariance3]
weights = [weight1, weight2, weight3]

# Sanity-check the hand-written ground truth before any data is generated.
assert len(means) == len(covariances) == len(weights) == N_COMPONENTS, "need one mean/covariance/weight per component"
assert np.isclose(sum(weights), 1.0), "mixture weights must sum to 1"
assert all(m.shape == (NUM_FEATURES,) for m in means), "each mean must have num_features entries"
assert all(c.shape == (NUM_FEATURES, NUM_FEATURES) for c in covariances), "each covariance must be num_features x num_features"
assert all(
    np.allclose(c, c.T) and np.all(np.linalg.eigvalsh(c) > 0) for c in covariances
), "covariances must be symmetric positive definite"

print(f"📊 Mixture components configured with total weight: {sum(weights):.1f}")
print("🔧 Using automatic reduction detection with clear eigenvalue structure")

In [None]:
# Draw the Gaussian-mixture training set and wrap it in a batched loader
num_samples = 100_000
X_array = generate_mixture_of_gaussians(
    num_samples=num_samples, means=means, covariances=covariances, weights=weights
)
dataloader = create_dataloader(X_array, config)

print(f"🎲 Generated {num_samples:,} samples from Gaussian mixture")

In [None]:
# ==================== VISUALIZATION FUNCTIONS ====================


def plot_data_overview(X_array: np.ndarray, title: str = "Dataset Visualization", max_points: int = 5000):
    """Visualise the first three dimensions of a dataset and print summary statistics.

    Top row: pairwise scatter plots of dimensions (0, 1), (0, 2) and (1, 2) for
    the first ``max_points`` rows, coloured by row index. Bottom row: density
    histograms of dimensions 0-2 over all rows. The per-dimension mean and
    variance of the full array are printed afterwards.

    Args:
        X_array: Array of shape (n_samples, n_features) with n_features >= 3.
        title: Figure super-title.
        max_points: Upper bound on the number of rows drawn in each scatter plot.
    """
    X_array = np.asarray(X_array)
    # Clip to the rows that actually exist so the colour array always matches the
    # number of plotted points. A fixed 5000 breaks on smaller datasets.
    n_show = min(max_points, len(X_array))
    shown = X_array[:n_show]
    row_order = np.arange(n_show)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(title, fontsize=16, fontweight="bold")

    # Pairwise scatter plots
    for ax, (dim1, dim2) in zip(axes[0], [(0, 1), (0, 2), (1, 2)]):
        ax.scatter(shown[:, dim1], shown[:, dim2], alpha=0.6, s=1, c=row_order, cmap="viridis")
        ax.set_xlabel(f"Dimension {dim1}")
        ax.set_ylabel(f"Dimension {dim2}")
        ax.set_title(f"Dims {dim1} vs {dim2}")
        ax.grid(True, alpha=0.3)

    # Marginal distribution plots
    palette = sns.color_palette()
    for i, ax in enumerate(axes[1]):
        ax.hist(X_array[:, i], bins=50, alpha=0.7, density=True, color=palette[i])
        ax.set_xlabel(f"Dimension {i}")
        ax.set_ylabel("Density")
        ax.set_title(f"Distribution of Dim {i}")
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Summary statistics
    print("📊 Dataset Summary:")
    print(f"Mean: {np.mean(X_array, axis=0)}")
    print(f"Variance: {np.var(X_array, axis=0)}")


# Visualize Gaussian mixture data
plot_data_overview(X_array, "Gaussian Mixture Dataset")

In [None]:
# ==================== TRAINING FUNCTIONS ====================


def train_hd_model(model, dataloader, config):
    """Fit an HD mixture model with online EM: initialise, burn in, then train.

    Args:
        model: HD model object providing ``init``, ``burnin`` and ``update``
            (``HDgmm``, ``HDstm`` and ``HDlm`` are used in this notebook).
        dataloader: Re-iterable source of batches. ``iter(dataloader)`` is
            called once per phase.
        config: ``em_config`` providing ``n_first``, ``batch_size``,
            ``num_features`` and ``num_epochs``.

    Returns:
        ``(em_params, em_stats, updated_config)``. ``updated_config`` is the
        configuration returned by ``model.init`` and is used for every step.
    """

    def schedule(k):
        # Step size (1 - 10e-10) * (k + 1) ** -0.6, decreasing in the step index k.
        # NOTE(review): 10e-10 == 1e-9; confirm 1e-10 was not intended.
        return (1 - 10e-10) * (k + 1) ** (-6 / 10)

    # Initialise on roughly the first n_first samples. Always take at least one
    # batch: with n_first < batch_size the list would be empty and
    # np.concatenate would raise "need at least one array to concatenate".
    n_init_batches = max(1, config.n_first // config.batch_size)
    init_iter = iter(dataloader)
    X = np.concatenate([next(init_iter) for _ in range(n_init_batches)])
    updated_config, em_params, em_stats = model.init(X, config)

    # Compile the step functions with the UPDATED config returned by init.
    @jax.jit
    def burnin_step(batch, step, params, stats):
        return model.burnin(batch, step, params, stats, updated_config, schedule)

    @jax.jit
    def train_step(batch, step, params, stats):
        return model.update(batch, step, params, stats, updated_config, schedule)

    # NOTE(review): each phase calls iter(dataloader) again. With a seeded
    # sampler, burn-in and training presumably restart from the same first
    # batches already used for initialisation — confirm this reuse is intended.

    # Burn-in: only em_stats is updated; em_params stays as returned by init.
    burnin_iter = iter(dataloader)
    for step in trange(2 * config.num_features, desc="Burn-in", colour="green"):
        em_stats = burnin_step(next(burnin_iter), step, em_params, em_stats)

    # Training: num_epochs outer steps of num_features batches each.
    # NOTE(review): the same outer-loop `step` is passed for num_features
    # consecutive batches — confirm this is the intended step indexing.
    train_iter = iter(dataloader)
    for step in trange(config.num_epochs, desc="Training", colour="blue"):
        for _ in range(config.num_features):
            em_params, em_stats = train_step(next(train_iter), step, em_params, em_stats)

    return em_params, em_stats, updated_config


# Train HD-GMM model
print("🏗️ Training HD-GMM model...")
model_hdgmm = HDgmm()
params_hdgmm, stats_hdgmm, config_hdgmm = train_hd_model(model_hdgmm, dataloader, config)
print(f"✅ HD-GMM training completed with reductions: {config_hdgmm.reduction}")

In [None]:
# ==================== EVALUATION FUNCTIONS ====================


def find_best_permutation(true_means, pred_means):
    """Match predicted mixture components to ground-truth components.

    Builds the matrix of Euclidean distances between every true mean and every
    predicted mean (first ``len(true_means)`` of each). It then solves the
    linear assignment problem on it with ``scipy.optimize.linear_sum_assignment``.

    Returns:
        Integer array ``perm`` where predicted component ``perm[i]`` is the one
        matched to true component ``i``.
    """
    n = len(true_means)
    cost = np.array(
        [[np.linalg.norm(true_means[row] - pred_means[col]) for col in range(n)] for row in range(n)]
    )
    _, matched_cols = linear_sum_assignment(cost)
    return matched_cols


def permute_parameters(params, permutation):
    """Reorder the per-component fields of an HD parameter container.

    ``pi``, ``mu`` and ``b`` (and ``nu`` when the container has it) are
    fancy-indexed with ``permutation``. ``A`` and ``D_tilde`` are rebuilt as
    Python lists in the permuted order. Returns a new object of the same type as
    ``params``.
    """
    fields = {
        "pi": params.pi[permutation],
        "mu": params.mu[permutation],
        "A": [params.A[k] for k in permutation],
        "b": params.b[permutation],
        "D_tilde": [params.D_tilde[k] for k in permutation],
    }
    if hasattr(params, "nu"):
        # Student-t models additionally carry per-component degrees of freedom.
        fields["nu"] = params.nu[permutation]
    return type(params)(**fields)


def evaluate_model_performance(true_means, true_weights, true_covariances, pred_params, model_name="HD Model"):
    """Align fitted components to the ground truth and report recovery errors.

    Args:
        true_means: Ground-truth component means.
        true_weights: Ground-truth mixing weights.
        true_covariances: Ground-truth covariances (not used in the computation;
            kept so all evaluation calls share one signature).
        pred_params: Fitted HD parameters; ``mu`` and ``pi`` are compared.
        model_name: Label used in the printed report.

    Returns:
        ``(aligned_params, mean_errors)``: the fitted parameters permuted into
        ground-truth order, and the per-component L2 distance between the true
        and fitted means.
    """
    # Put the fitted components into ground-truth order before comparing anything.
    perm = find_best_permutation(true_means, pred_params.mu)
    aligned_params = permute_parameters(pred_params, perm)

    mean_errors = []
    for k, target_mu in enumerate(true_means):
        mean_errors.append(np.linalg.norm(aligned_params.mu[k] - target_mu))
    weight_error = np.linalg.norm(np.array(true_weights) - aligned_params.pi)

    report = (
        f"\n📊 {model_name} Performance:",
        f"Mean reconstruction error: {np.mean(mean_errors):.4f} ± {np.std(mean_errors):.4f}",
        f"Weight reconstruction error: {weight_error:.4f}",
        f"Optimal permutation: {perm}",
    )
    for line in report:
        print(line)

    return aligned_params, mean_errors


# Evaluate HD-GMM
aligned_params_hdgmm, errors_hdgmm = evaluate_model_performance(means, weights, covariances, params_hdgmm, "HD-GMM")

In [None]:
def plot_model_comparison(
    true_means, true_weights, true_covariances, aligned_params, mean_errors, model_name, model_color
):
    """Plot a fitted HD model against the ground-truth mixture.

    Row 1: mixing weights (true vs fitted), per-component mean L2 errors, and a
    summary panel (mean error, weight error, average number of leading
    eigenvalues). Row 2: true vs fitted mean vector per component. Row 3: true
    vs fitted covariance eigenvalue spectrum per component (log scale).

    Args:
        true_means: Ground-truth mean vectors, one per component.
        true_weights: Ground-truth mixing weights.
        true_covariances: Ground-truth covariance matrices (symmetric).
        aligned_params: Fitted parameters already permuted into ground-truth
            order; ``pi``, ``mu``, ``A`` and ``b`` are read.
        mean_errors: Per-component L2 errors between true and fitted means.
        model_name: Label used in titles and legends.
        model_color: Matplotlib colour for the fitted-model series.
    """
    n_components = len(true_means)
    n_dims = len(true_means[0])
    # Row 1 always needs three panels; rows 2-3 need one panel per component.
    n_cols = max(3, n_components)
    short_labels = [f"Comp {k + 1}" for k in range(n_components)]
    long_labels = [f"Component {k + 1}" for k in range(n_components)]
    comp_pos = np.arange(n_components)

    def annotate_bars(ax, bars, values):
        # Write each bar's height just above it (offset = 1% of the largest value).
        offset = np.max(values) * 0.01
        for bar in bars:
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.0,
                height + offset,
                f"{height:.3f}",
                ha="center",
                va="bottom",
                fontsize=9,
            )

    def hd_spectrum(comp):
        # Fitted spectrum: the leading eigenvalues A[comp] plus b[comp] repeated
        # for the remaining n_dims - len(A[comp]) directions, sorted descending.
        leading = np.asarray(aligned_params.A[comp])
        tail = np.full(n_dims - len(leading), aligned_params.b[comp])
        return np.sort(np.concatenate([leading, tail]))[::-1]

    fig, axes = plt.subplots(3, n_cols, figsize=(6 * n_cols, 15), squeeze=False)
    fig.suptitle(f"{model_name}: Ground Truth vs Predictions", fontsize=16, fontweight="bold")

    # Row 1: Mixing weights comparison
    ax = axes[0, 0]
    width = 0.35
    ax.bar(comp_pos - width / 2, true_weights, width, label="True", color="gold", alpha=0.8)
    ax.bar(comp_pos + width / 2, aligned_params.pi, width, label=model_name, color=model_color, alpha=0.8)
    ax.set_xlabel("Components")
    ax.set_ylabel("Mixing Weights")
    ax.set_title("Mixing Weights Comparison")
    ax.set_xticks(comp_pos)
    ax.set_xticklabels(short_labels)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Row 1: Mean reconstruction errors
    ax = axes[0, 1]
    bars = ax.bar(comp_pos, mean_errors, color=model_color, alpha=0.8)
    ax.set_xlabel("Components")
    ax.set_ylabel("L2 Error")
    ax.set_title("Mean Vector Reconstruction Error")
    ax.set_xticks(comp_pos)
    ax.set_xticklabels(short_labels)
    ax.grid(True, alpha=0.3)
    annotate_bars(ax, bars, mean_errors)

    # Row 1: Overall summary. "Avg Reduction" is the mean number of leading
    # eigenvalues kept per component, i.e. the average length of the fitted A.
    ax = axes[0, 2]
    metrics = ["Mean Error", "Weight Error", "Avg Reduction"]
    values = [
        np.mean(mean_errors),
        np.linalg.norm(np.array(true_weights) - aligned_params.pi),
        np.mean([len(aligned_params.A[k]) for k in range(n_components)]),
    ]
    bars = ax.bar(metrics, values, color=[model_color, model_color, "lightblue"], alpha=0.8)
    ax.set_ylabel("Value")
    ax.set_title("Overall Model Performance")
    ax.grid(True, alpha=0.3)
    annotate_bars(ax, bars, values)

    dims = np.arange(n_dims)
    for comp in range(n_components):
        # Row 2: Mean vector comparison
        ax = axes[1, comp]
        ax.plot(dims, true_means[comp], "o-", color="gold", linewidth=3, markersize=8, label="True")
        ax.plot(dims, aligned_params.mu[comp], "s-", color=model_color, linewidth=2, markersize=6, label=model_name)
        ax.set_xlabel("Dimension")
        ax.set_ylabel("Mean Value")
        ax.set_title(f"{long_labels[comp]} Mean Vector")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xticks(dims)

        # Row 3: Eigenvalue spectra. Covariances are symmetric, so eigvalsh
        # applies and always returns real values (eigvals may return complex).
        ax = axes[2, comp]
        true_eigenvals = np.sort(np.linalg.eigvalsh(true_covariances[comp]))[::-1]
        eig_idx = np.arange(len(true_eigenvals))
        ax.semilogy(eig_idx, true_eigenvals, "o-", color="gold", linewidth=3, markersize=8, label="True")
        ax.semilogy(eig_idx, hd_spectrum(comp), "s-", color=model_color, linewidth=2, markersize=6, label=model_name)
        ax.set_xlabel("Eigenvalue Index")
        ax.set_ylabel("Eigenvalue (log scale)")
        ax.set_title(f"{long_labels[comp]} Eigenvalue Spectrum")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xticks(eig_idx)

    # Hide panels that have nothing to show (only happens when n_components != 3).
    for col in range(3, n_cols):
        axes[0, col].set_visible(False)
    for col in range(n_components, n_cols):
        axes[1, col].set_visible(False)
        axes[2, col].set_visible(False)

    plt.tight_layout()
    plt.show()


# Visualize HD-GMM results
plot_model_comparison(means, weights, covariances, aligned_params_hdgmm, errors_hdgmm, "HD-GMM", "lightcoral")

## 2. Student-t Mixture Models 📊

In [None]:
def generate_mixture_of_student(
    num_samples: int, means: list, covariances: list, nus: list, weights: list
) -> np.ndarray:
    """Sample ``num_samples`` points from a mixture of multivariate Student-t laws.

    Component ``k`` uses location ``means[k]``, shape matrix ``covariances[k]``
    and ``nus[k]`` degrees of freedom. It is drawn with
    ``scipy.stats.multivariate_t.rvs``. For nu > 2 a Student-t with shape
    matrix S has covariance nu / (nu - 2) * S, so ``covariances`` holds shape
    matrices rather than the component covariances. Component labels are drawn
    with probabilities ``weights`` and are not returned.

    Returns:
        Float array of shape (num_samples, len(means[0])).
    """
    n_comp = len(means)
    labels = np.random.choice(n_comp, size=num_samples, p=weights)
    out = np.zeros((num_samples, len(means[0])))

    for comp in range(n_comp):
        mask = labels == comp
        count = np.sum(mask)
        if not count:
            continue
        loc, shape, df = means[comp], covariances[comp], nus[comp]
        out[mask] = multivariate_t.rvs(loc=loc, shape=shape, df=df, size=count)
    return out


# Student-t mixture: same means, covariances (used as shape matrices) and weights
# as the Gaussian case, plus nu = 5 degrees of freedom for every component.
nus = [5.0] * 3  # Moderate heavy tails
X_array_student = generate_mixture_of_student(
    num_samples=num_samples, means=means, covariances=covariances, nus=nus, weights=weights
)
dataloader_student = create_dataloader(X_array_student, config)

print(f"🎲 Generated {num_samples:,} samples from Student-t mixture with ν={nus}")

In [None]:
# Same panels as the Gaussian overview, now for the heavier-tailed Student-t data
plot_data_overview(X_array_student, title="Student-t Mixture Dataset")

In [None]:
# Train HD-STM model on the Student-t data
print("🏗️ Training HD-STM model...")
model_hdstm = HDstm()
params_hdstm, stats_hdstm, config_hdstm = train_hd_model(
    model=model_hdstm, dataloader=dataloader_student, config=config
)
print(f"✅ HD-STM training completed with reductions: {config_hdstm.reduction}")

In [None]:
# Evaluate HD-STM, then check how well the degrees of freedom were recovered
aligned_params_hdstm, errors_hdstm = evaluate_model_performance(
    means, weights, covariances, params_hdstm, model_name="HD-STM"
)

print("\n🎓 Degrees of Freedom Analysis:")
nu_errors = [abs(nu_hat - nu_true) for nu_hat, nu_true in zip(aligned_params_hdstm.nu, nus)]
print(f"True ν: {nus}")
print(f"Predicted ν: {aligned_params_hdstm.nu}")
print(f"ν reconstruction error: {np.mean(nu_errors):.4f} ± {np.std(nu_errors):.4f}")

In [None]:
def plot_stm_comparison(
    true_means, true_weights, true_covariances, true_nus, aligned_params, mean_errors, model_name, model_color
):
    """Plot a fitted HD Student-t mixture against the ground truth, including ν.

    Row 1: mixing weights, per-component mean L2 errors, true vs fitted degrees
    of freedom. Row 2: true-vs-fitted ν scatter with the identity line,
    per-component absolute ν error, and a summary of mean / weight / ν errors.
    Row 3: true vs fitted mean vector per component. Row 4: true vs fitted
    eigenvalue spectrum per component (log scale).

    Args:
        true_means: Ground-truth mean vectors, one per component.
        true_weights: Ground-truth mixing weights.
        true_covariances: Ground-truth shape/covariance matrices (symmetric).
        true_nus: Ground-truth degrees of freedom, one per component.
        aligned_params: Fitted parameters already permuted into ground-truth
            order; ``pi``, ``mu``, ``nu``, ``A`` and ``b`` are read.
        mean_errors: Per-component L2 errors between true and fitted means.
        model_name: Label used in titles and legends.
        model_color: Matplotlib colour for the fitted-model series.
    """
    n_components = len(true_means)
    n_dims = len(true_means[0])
    # Rows 1-2 always need three panels; rows 3-4 need one panel per component.
    n_cols = max(3, n_components)
    short_labels = [f"Comp {k + 1}" for k in range(n_components)]
    long_labels = [f"Component {k + 1}" for k in range(n_components)]
    comp_pos = np.arange(n_components)
    width = 0.35

    true_nu_arr = np.asarray(true_nus, dtype=float)
    pred_nu_arr = np.asarray(aligned_params.nu, dtype=float)
    nu_errors = np.abs(pred_nu_arr - true_nu_arr)

    def annotate_bars(ax, bars, values):
        # Write each bar's height just above it (offset = 1% of the largest value).
        offset = np.max(values) * 0.01
        for bar in bars:
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.0,
                height + offset,
                f"{height:.3f}",
                ha="center",
                va="bottom",
                fontsize=9,
            )

    def paired_bars(ax, true_vals, pred_vals, true_label, pred_label, ylabel, title):
        # Side-by-side true (gold) vs fitted bars for each component.
        ax.bar(comp_pos - width / 2, true_vals, width, label=true_label, color="gold", alpha=0.8)
        ax.bar(comp_pos + width / 2, pred_vals, width, label=pred_label, color=model_color, alpha=0.8)
        ax.set_xlabel("Components")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(comp_pos)
        ax.set_xticklabels(short_labels)
        ax.legend()
        ax.grid(True, alpha=0.3)

    def error_bars(ax, errors, ylabel, title):
        # One labelled bar per component.
        bars = ax.bar(comp_pos, errors, color=model_color, alpha=0.8)
        ax.set_xlabel("Components")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(comp_pos)
        ax.set_xticklabels(short_labels)
        ax.grid(True, alpha=0.3)
        annotate_bars(ax, bars, errors)

    def hd_spectrum(comp):
        # Fitted spectrum: the leading eigenvalues A[comp] plus b[comp] repeated
        # for the remaining n_dims - len(A[comp]) directions, sorted descending.
        leading = np.asarray(aligned_params.A[comp])
        tail = np.full(n_dims - len(leading), aligned_params.b[comp])
        return np.sort(np.concatenate([leading, tail]))[::-1]

    fig, axes = plt.subplots(4, n_cols, figsize=(6 * n_cols, 20), squeeze=False)
    fig.suptitle(f"{model_name}: Ground Truth vs Predictions", fontsize=16, fontweight="bold")

    # Row 1: Mixing weights, mean errors, and degrees of freedom
    paired_bars(
        axes[0, 0], true_weights, aligned_params.pi, "True", model_name, "Mixing Weights", "Mixing Weights Comparison"
    )
    error_bars(axes[0, 1], mean_errors, "L2 Error", "Mean Vector Reconstruction Error")
    paired_bars(
        axes[0, 2],
        true_nus,
        aligned_params.nu,
        "True ν",
        "Predicted ν",
        "Degrees of Freedom (ν)",
        "Degrees of Freedom Comparison",
    )

    # Row 2: True vs Predicted ν scatter with the identity line for reference
    ax = axes[1, 0]
    ax.scatter(true_nu_arr, pred_nu_arr, s=100, alpha=0.7, color=model_color, edgecolors="black")
    nu_min = min(true_nu_arr.min(), pred_nu_arr.min()) - 0.5
    nu_max = max(true_nu_arr.max(), pred_nu_arr.max()) + 0.5
    ax.plot([nu_min, nu_max], [nu_min, nu_max], "r--", alpha=0.7, label="Perfect Prediction")
    ax.set_xlabel("True ν")
    ax.set_ylabel("Predicted ν")
    ax.set_title("True vs Predicted Degrees of Freedom")
    ax.legend()
    ax.grid(True, alpha=0.3)
    for k, (true_nu, pred_nu) in enumerate(zip(true_nu_arr, pred_nu_arr)):
        ax.annotate(f"C{k + 1}", (true_nu, pred_nu), xytext=(5, 5), textcoords="offset points")

    # Row 2: ν error per component
    error_bars(axes[1, 1], nu_errors, "Absolute Error in ν", "Degrees of Freedom Reconstruction Error")

    # Row 2: Overall STM performance summary
    ax = axes[1, 2]
    metrics = ["Mean Error", "Weight Error", "ν Error"]
    values = [
        np.mean(mean_errors),
        np.linalg.norm(np.array(true_weights) - aligned_params.pi),
        np.mean(nu_errors),
    ]
    bars = ax.bar(metrics, values, color=[model_color, model_color, "lightgreen"], alpha=0.8)
    ax.set_ylabel("Error Value")
    ax.set_title("STM Performance Summary")
    ax.grid(True, alpha=0.3)
    annotate_bars(ax, bars, values)

    dims = np.arange(n_dims)
    for comp in range(n_components):
        # Row 3: Mean vector comparison
        ax = axes[2, comp]
        ax.plot(dims, true_means[comp], "o-", color="gold", linewidth=3, markersize=8, label="True")
        ax.plot(dims, aligned_params.mu[comp], "^-", color=model_color, linewidth=2, markersize=6, label=model_name)
        ax.set_xlabel("Dimension")
        ax.set_ylabel("Mean Value")
        ax.set_title(f"{long_labels[comp]} Mean Vector")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xticks(dims)

        # Row 4: Eigenvalue spectra. The matrices are symmetric, so eigvalsh
        # applies and always returns real values (eigvals may return complex).
        ax = axes[3, comp]
        true_eigenvals = np.sort(np.linalg.eigvalsh(true_covariances[comp]))[::-1]
        eig_idx = np.arange(len(true_eigenvals))
        ax.semilogy(eig_idx, true_eigenvals, "o-", color="gold", linewidth=3, markersize=8, label="True")
        ax.semilogy(eig_idx, hd_spectrum(comp), "^-", color=model_color, linewidth=2, markersize=6, label=model_name)
        ax.set_xlabel("Eigenvalue Index")
        ax.set_ylabel("Eigenvalue (log scale)")
        ax.set_title(f"{long_labels[comp]} Eigenvalue Spectrum")
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xticks(eig_idx)

    # Hide panels that have nothing to show (only happens when n_components != 3).
    for col in range(3, n_cols):
        axes[0, col].set_visible(False)
        axes[1, col].set_visible(False)
    for col in range(n_components, n_cols):
        axes[2, col].set_visible(False)
        axes[3, col].set_visible(False)

    plt.tight_layout()
    plt.show()


# Visualize HD-STM results with degrees of freedom analysis
plot_stm_comparison(means, weights, covariances, nus, aligned_params_hdstm, errors_hdstm, "HD-STM", "darkorange")

## 3. Laplace Mixture Models 🔺

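Multivariate Laplace distributions have a sharp peak at the mean and heavier-than-Gaussian (exponential) tails, which makes them useful for modeling sparse or outlier-prone data. The samples below use the Gaussian scale-mixture representation $X \mid W \sim \mathcal{N}(\mu, W\Sigma)$ with $W \sim \mathrm{Exp}(1)$.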

In [None]:
def generate_mixture_of_multivariate_laplace(
    num_samples: int, means: list, covariances: list, weights: list
) -> np.ndarray:
    """Sample from a mixture of symmetric multivariate Laplace distributions.

    Uses the Gaussian scale-mixture representation

        W ~ Exp(1),  Z ~ N(0, Sigma),  X = mu + sqrt(W) * Z,

    which is the same as X | W ~ N(mu, W * Sigma). Because E[W] = 1, ``Sigma``
    is also the covariance of each component.

    Args:
        num_samples: Total number of rows to draw.
        means: Component locations, each of length d.
        covariances: Component d x d covariance matrices.
        weights: Mixing probabilities (must sum to 1).

    Returns:
        Float array of shape (num_samples, d). Component labels are not returned.
    """
    num_components = len(means)
    d = len(means[0])
    component_indices = np.random.choice(num_components, size=num_samples, p=weights)
    samples = np.zeros((num_samples, d))

    for i in range(num_components):
        indices = component_indices == i
        n_i = int(np.sum(indices))
        if n_i == 0:
            continue
        # One vectorised Gaussian draw per component instead of one
        # multivariate_normal call (each with its own matrix factorisation) per
        # sample: X = mu + sqrt(W) * Z has the same law as N(mu, W * Sigma).
        W = np.random.exponential(1.0, size=n_i)
        Z = np.random.multivariate_normal(np.zeros(d), covariances[i], size=n_i)
        samples[indices] = np.asarray(means[i]) + np.sqrt(W)[:, None] * Z
    return samples


# Laplace mixture: reuse the Gaussian covariances and weights, but push the
# means five times further from the origin for clearer separation.
means_laplace = [5 * component_mean for component_mean in means]
X_array_laplace = generate_mixture_of_multivariate_laplace(
    num_samples, means_laplace, covariances, weights
)
dataloader_laplace = create_dataloader(X_array_laplace, config)

print(f"🎲 Generated {num_samples:,} samples from Laplace mixture")

In [None]:
# Same panels as before, for the Laplace-mixture dataset
plot_data_overview(X_array_laplace, title="Laplace Mixture Dataset")

In [None]:
# Train HD-LM model on the Laplace data
print("🏗️ Training HD-LM model...")
model_hdlm = HDlm()
params_hdlm, stats_hdlm, config_hdlm = train_hd_model(
    model=model_hdlm, dataloader=dataloader_laplace, config=config
)
print(f"✅ HD-LM training completed with reductions: {config_hdlm.reduction}")

In [None]:
# Evaluate HD-LM against the scaled Laplace means (covariances are unchanged)
aligned_params_hdlm, errors_hdlm = evaluate_model_performance(
    means_laplace, weights, covariances, params_hdlm, model_name="HD-LM"
)

In [None]:
# Visualize HD-LM results against the scaled Laplace ground truth
plot_model_comparison(
    means_laplace,
    weights,
    covariances,
    aligned_params_hdlm,
    errors_hdlm,
    model_name="HD-LM",
    model_color="mediumseagreen",
)

## 4. Comparative Analysis 📈

In [None]:
# ==================== FINAL COMPARISON ====================

models = ["HD-GMM", "HD-STM", "HD-LM"]
colors = ["lightcoral", "darkorange", "mediumseagreen"]
markers = ["s", "^", "D"]
params_list = [aligned_params_hdgmm, aligned_params_hdstm, aligned_params_hdlm]
n_components = len(means)
n_features = len(means[0])

# Performance summary table
mean_errors = [np.mean(errors_hdgmm), np.mean(errors_hdstm), np.mean(errors_hdlm)]
weight_errors = [np.linalg.norm(np.array(weights) - fitted.pi) for fitted in params_list]

print("🏆 FINAL PERFORMANCE COMPARISON")
print("=" * 50)
for model, mean_err, weight_err in zip(models, mean_errors, weight_errors):
    print(f"{model:6} | Mean Error: {mean_err:.4f} | Weight Error: {weight_err:.4f}")

# Add degrees of freedom analysis for HD-STM
nu_errors = [abs(aligned_params_hdstm.nu[i] - nus[i]) for i in range(len(nus))]
print(f"\nHD-STM degrees of freedom error: {np.mean(nu_errors):.4f}")


def _bar_per_model(ax, labels, values, bar_colors, fmt):
    """Draw one coloured bar per label and write its value (``fmt``) just above it."""
    bars = ax.bar(labels, values, color=bar_colors, alpha=0.8)
    ax.grid(True, alpha=0.3)
    offset = max(values) * 0.01
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height + offset,
            fmt.format(height),
            ha="center",
            va="bottom",
            fontsize=9,
        )


def _hd_eigen_spectrum(fitted, comp, dim):
    """Descending length-``dim`` spectrum: A[comp] followed by b[comp] repeated."""
    leading = np.asarray(fitted.A[comp])
    tail = np.full(dim - len(leading), fitted.b[comp])
    return np.sort(np.concatenate([leading, tail]))[::-1]


def _hd_covariance_param_count(fitted, d):
    """Count the free covariance parameters of a fitted HD model.

    For each component with d_k = len(A[k]) leading eigenvalues this counts:
    d_k leading eigenvalues, 1 tail eigenvalue b_k, and
    d * d_k - d_k * (d_k + 1) / 2 parameters for the d_k orthonormal leading
    eigenvectors. This is the HDDC [a_kj b_k Q_k d_k] count. Counting only the
    eigenvalues would ignore the orientation and overstate the compression.
    """
    total = 0
    for leading in fitted.A:
        d_k = len(leading)
        total += d_k + 1 + d * d_k - d_k * (d_k + 1) // 2
    return total


# Comprehensive visualization comparison
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle("High-Dimensional Mixture Models: Comparative Performance", fontsize=16, fontweight="bold")

# Mean reconstruction errors
ax = axes[0, 0]
_bar_per_model(ax, models, mean_errors, colors, "{:.4f}")
ax.set_ylabel("Mean L2 Error")
ax.set_title("Mean Vector Reconstruction Error")

# Weight reconstruction errors
ax = axes[0, 1]
_bar_per_model(ax, models, weight_errors, colors, "{:.4f}")
ax.set_ylabel("Weight L2 Error")
ax.set_title("Mixing Weight Reconstruction Error")

# Eigenvalue reconstruction quality (Component 1 example). All three datasets
# were generated from the same `covariances` list (as shape matrices for the
# Student-t data), so a single true spectrum serves every model. The matrix is
# symmetric, so eigvalsh applies and returns real eigenvalues.
ax = axes[0, 2]
true_eigenvals = np.sort(np.linalg.eigvalsh(covariances[0]))[::-1]
x_pos = np.arange(len(true_eigenvals))
ax.semilogy(x_pos, true_eigenvals, "o-", color="gold", linewidth=3, markersize=8, label="True")
for fitted, model, color, marker in zip(params_list, models, colors, markers):
    hd_eigenvals = _hd_eigen_spectrum(fitted, 0, len(true_eigenvals))
    ax.semilogy(x_pos, hd_eigenvals, f"{marker}-", color=color, linewidth=2, markersize=6, label=model)
ax.set_xlabel("Eigenvalue Index")
ax.set_ylabel("Eigenvalue (log scale)")
ax.set_title("Component 1 Eigenvalue Comparison")
ax.legend()
ax.grid(True, alpha=0.3)

# Reduction analysis (from the configs returned by training)
ax = axes[1, 0]
reductions = [np.mean(cfg.reduction) for cfg in (config_hdgmm, config_hdstm, config_hdlm)]
_bar_per_model(ax, models, reductions, colors, "{:.1f}")
ax.set_ylabel("Average Reduction")
ax.set_title("Average Dimensionality Reduction")

# Degrees of freedom for HD-STM
ax = axes[1, 1]
nu_summary = [np.mean(nus), np.mean(aligned_params_hdstm.nu)]
ax.bar(["True ν", "HD-STM ν"], nu_summary, color=["gold", "darkorange"], alpha=0.8)
ax.set_ylabel("Degrees of Freedom (ν)")
ax.set_title("HD-STM: Degrees of Freedom")
ax.grid(True, alpha=0.3)
for i, height in enumerate(nu_summary):
    ax.text(i, height + 0.1, f"{height:.2f}", ha="center", va="bottom", fontsize=9)

# Compression efficiency: HD covariance parameters relative to full covariances
ax = axes[1, 2]
full_cov_params = n_components * n_features * (n_features + 1) // 2
total_params_full = [full_cov_params for _ in models]
total_params_hd = [_hd_covariance_param_count(fitted, n_features) for fitted in params_list]
compression_ratios = [hd / full for hd, full in zip(total_params_hd, total_params_full)]
_bar_per_model(ax, models, compression_ratios, colors, "{:.2f}")
ax.set_ylabel("Parameter Compression Ratio")
ax.set_title("Model Compression Efficiency")

plt.tight_layout()
plt.show()