# SPD Manifold: Robust Covariance Matrix Estimation

This notebook demonstrates robust covariance matrix estimation on the Symmetric Positive Definite (SPD) manifold. We compare standard maximum likelihood estimation with manifold-based robust estimation that is resilient to outliers.

## Applications
- **Computer vision**: Robust covariance descriptors for image classification
- **Finance**: Portfolio optimization with heavy-tailed return distributions  
- **Signal processing**: Covariance matrix estimation in the presence of noise
- **Machine learning**: Robust Gaussian mixture model parameter estimation

## Mathematical Background
The SPD manifold P(n) = {X ∈ R^{n×n} : X = X^T, X ≻ 0} equipped with the affine-invariant Riemannian metric provides a natural framework for covariance matrix estimation that respects the geometric structure of positive definite matrices.
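Under the affine-invariant metric, the geodesic distance between SPD matrices A and B has the closed form d(A, B) = ‖log(A^{-1/2} B A^{-1/2})‖_F = (Σᵢ log² λᵢ)^{1/2}, where λᵢ are the eigenvalues of A⁻¹B. The plain-JAX sketch below evaluates that formula. It illustrates the math rather than reproducing riemannax's implementation, and `affine_invariant_distance` and `_sym_apply` are hypothetical helper names. The example also checks the property that gives the metric its name: d(GAGᵀ, GBGᵀ) = d(A, B) for any invertible G.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def _sym_apply(X, fn):
    """Apply a scalar function to a symmetric matrix through its eigendecomposition."""
    w, V = jnp.linalg.eigh(X)
    return (V * fn(w)) @ V.T  # scales eigenvector column j by fn(w_j)


def affine_invariant_distance(A, B):
    """d(A, B) = || log(A^{-1/2} B A^{-1/2}) ||_F for SPD matrices A and B."""
    A_inv_sqrt = _sym_apply(A, lambda w: 1.0 / jnp.sqrt(w))
    M = A_inv_sqrt @ B @ A_inv_sqrt
    M = 0.5 * (M + M.T)  # remove round-off asymmetry before eigvalsh
    return jnp.sqrt(jnp.sum(jnp.log(jnp.linalg.eigvalsh(M)) ** 2))


A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
B = jnp.array([[1.0, 0.2], [0.2, 3.0]])
G = jnp.array([[1.0, 2.0], [0.0, 1.5]])  # any invertible matrix

print(affine_invariant_distance(A, B))
print(affine_invariant_distance(G @ A @ G.T, G @ B @ G.T))  # same value: affine invariance
```

If `spd.dist` in riemannax is built on this metric, it should report the same values; that is an assumption, not something this notebook confirms.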

## Setup and Imports

In [None]:
import jax

# Enable 64-bit precision for better numerical stability.
# Set this at startup, before any JAX arrays are created (including by imported libraries).
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import riemannax as rx

## Data Generation with Outliers

We generate multivariate data with a known covariance structure and add outliers to test robustness.

In [None]:
def generate_multivariate_data_with_outliers(key, n_samples=200, n_features=4, outlier_ratio=0.1):
    """Generate zero-mean Gaussian data contaminated with outliers for robust estimation testing.

    The ground-truth covariance is a fixed 4x4 matrix, so `n_features` must be 4.
    """
    keys = jax.random.split(key, 4)

    # True covariance structure with correlation
    true_cov = jnp.array([
        [1.0, 0.5, 0.2, 0.1],
        [0.5, 1.0, 0.3, 0.0],
        [0.2, 0.3, 1.0, 0.4],
        [0.1, 0.0, 0.4, 1.0]
    ])
    if n_features != true_cov.shape[0]:
        raise ValueError(f"n_features must be {true_cov.shape[0]} for this covariance")

    # Clean samples from N(0, true_cov)
    n_clean = int(round(n_samples * (1 - outlier_ratio)))
    clean_data = jax.random.multivariate_normal(keys[0], jnp.zeros(n_features), true_cov, (n_clean,))

    # Outliers: a wider, contaminating isotropic Gaussian N(0, outlier_scale**2 * I)
    # (a scaled Gaussian, not a heavy-tailed distribution)
    n_outliers = n_samples - n_clean
    outlier_scale = 3.0
    outlier_data = outlier_scale * jax.random.multivariate_normal(
        keys[1], jnp.zeros(n_features), jnp.eye(n_features), (n_outliers,)
    )

    # Combine data
    data = jnp.vstack([clean_data, outlier_data])

    # Shuffle the data
    perm = jax.random.permutation(keys[2], n_samples)
    data = data[perm]

    return data, true_cov

# Generate test data
key = jax.random.key(42)
outlier_ratio = 0.15
data, true_cov = generate_multivariate_data_with_outliers(key, n_samples=200, n_features=4, outlier_ratio=outlier_ratio)

print(f"Generated data shape: {data.shape}")
print(f"Outlier ratio: {outlier_ratio:.0%}")
print("\nTrue covariance matrix:")
print(true_cov)
print("\nData statistics:")
print(f"Mean: {jnp.mean(data, axis=0)}")
print(f"Std:  {jnp.std(data, axis=0)}")
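Because both mixture components are zero-mean, the population covariance of the contaminated data is (1 − ε)·Σ + ε·s²·I, with ε = 0.15 and s = 3 in this notebook. That is the matrix the sample covariance (and therefore the MLE) approaches as the sample grows, so it gives the MLE's expected bias independently of any particular random draw. A small worked computation, reusing the generator's settings:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Same ground-truth covariance and contamination settings as the generator above
true_cov = jnp.array([
    [1.0, 0.5, 0.2, 0.1],
    [0.5, 1.0, 0.3, 0.0],
    [0.2, 0.3, 1.0, 0.4],
    [0.1, 0.0, 0.4, 1.0],
])
eps = 0.15           # outlier fraction
outlier_scale = 3.0  # outliers ~ N(0, outlier_scale**2 * I)

# Both components have mean zero, so the mixture covariance is the weighted sum
contaminated_cov = (1 - eps) * true_cov + eps * outlier_scale**2 * jnp.eye(4)
bias = jnp.linalg.norm(contaminated_cov - true_cov, "fro")

print(contaminated_cov)
print(f"Frobenius distance from true_cov: {float(bias):.4f}")
```

The diagonal is inflated from 1.0 to 0.85 + 1.35 = 2.2 while the off-diagonal correlations are shrunk by the factor 0.85, so even with unlimited data the MLE stays a Frobenius distance of about 2.4 away from `true_cov`.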

## Standard vs Robust Estimation

Let's compare traditional maximum likelihood estimation with robust manifold-based estimation.

In [None]:
def mle_covariance(data):
    """Gaussian maximum likelihood estimate of the covariance matrix.

    The MLE normalizes by n; the unbiased sample covariance would use n - 1.
    """
    n_samples = data.shape[0]
    centered_data = data - jnp.mean(data, axis=0)
    return (centered_data.T @ centered_data) / n_samples

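def robust_manifold_covariance_cost(cov_matrix, data, huber_delta=1.5):
    """
    Robust covariance estimation cost based on the Huber loss.

    This is the Gaussian negative log-likelihood, 0.5 * log det(C) + 0.5 * mean(d_i^2),
    with the quadratic term 0.5 * d^2 of each Mahalanobis distance d_i replaced by
    Huber's loss, which grows only linearly for d > huber_delta and so limits the
    influence of outliers. The location is fixed at the (non-robust) sample mean.

    Note: for clean 4-D Gaussian data d^2 follows a chi-square law with 4 degrees of
    freedom (E[d^2] = 4), so about 69% of clean points have d > 1.5; the default
    huber_delta therefore also down-weights many inliers and shrinks the estimate.
    """
    centered_data = data - jnp.mean(data, axis=0)

    # Squared Mahalanobis distances d_i^2 = x_i^T C^{-1} x_i (solve instead of an explicit inverse)
    solved = jnp.linalg.solve(cov_matrix, centered_data.T).T
    mahalanobis_sq = jnp.sum(centered_data * solved, axis=1)
    # Clamp before sqrt so the gradient stays finite if a distance is exactly zero
    mahalanobis = jnp.sqrt(jnp.maximum(mahalanobis_sq, 1e-12))

    # Huber loss (elementwise, so no vmap is needed)
    def huber_loss(x, delta):
        condition = jnp.abs(x) <= delta
        quadratic = 0.5 * x**2
        linear = delta * (jnp.abs(x) - 0.5 * delta)
        return jnp.where(condition, quadratic, linear)

    # slogdet is more stable than log(det(C)). The 0.5 factor matches the Gaussian
    # likelihood: without it, the all-quadratic case would be minimized by half the
    # sample covariance instead of the sample covariance itself.
    _, log_det = jnp.linalg.slogdet(cov_matrix)

    return 0.5 * log_det + jnp.mean(huber_loss(mahalanobis, huber_delta))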

# Standard MLE estimation
mle_cov = mle_covariance(data)
print("Standard MLE covariance estimate:")
print(mle_cov)
print(f"\nMLE Frobenius error: {jnp.linalg.norm(mle_cov - true_cov, 'fro'):.4f}")
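As a cross-check that does not depend on the riemannax optimizer: for a cost of the form ½ log det C + mean ρ(dᵢ), with ρ Huber's loss, setting the Euclidean gradient to zero gives the fixed-point equation C = mean(w(dᵢ) xᵢxᵢᵀ) with weights w(d) = ρ′(d)/d = min(1, δ/d). The sketch below simply iterates that equation (an IRLS-style M-estimator, swapped in here for illustration; `huber_weights` and `huber_fixed_point_cov` are hypothetical names). Its stationary point should coincide with a minimizer of that cost, assuming the ½-weighted log-determinant.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def huber_weights(d, delta):
    """w(d) = psi(d) / d for Huber's rho: 1 on [0, delta], delta / d beyond it."""
    return jnp.minimum(1.0, delta / jnp.maximum(d, 1e-12))


def huber_fixed_point_cov(data, delta=1.5, n_iter=200):
    """Iterate C <- mean_i w(d_i) x_i x_i^T with the location fixed at the sample mean."""
    X = data - jnp.mean(data, axis=0)
    n = X.shape[0]
    C = X.T @ X / n  # start from the Gaussian MLE
    for _ in range(n_iter):
        # Mahalanobis distances under the current estimate
        d = jnp.sqrt(jnp.sum(X * jnp.linalg.solve(C, X.T).T, axis=1))
        w = huber_weights(d, delta)
        C = (X * w[:, None]).T @ X / n  # weighted scatter matrix
    return C
```

Every weight is at most 1, so the result is never larger (in the PSD order) than the MLE; with a very large `delta` all weights equal 1 and the iteration returns the MLE unchanged.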

## Robust Estimation on SPD Manifold

Now we perform robust covariance estimation using Riemannian optimization on the SPD manifold.
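What a Riemannian optimizer does on P(n) can be sketched in a few lines: under the affine-invariant metric ⟨ξ, η⟩_X = tr(X⁻¹ξX⁻¹η), the Riemannian gradient is X·sym(∇f)·X, and the exponential map Exp_X(ξ) = X^{1/2} expm(X^{-1/2} ξ X^{-1/2}) X^{1/2} moves along a geodesic while staying positive definite. This is plain gradient descent with the exponential map as retraction; riemannax's own update rule and options are not reproduced here, and `spd_exp` and `riemannian_gd` are hypothetical names. Applied to the Gaussian NLL ½ log det C + ½ tr(C⁻¹S), it should recover C = S.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def _sym_apply(X, fn):
    """Apply a scalar function to a symmetric matrix through its eigendecomposition."""
    w, V = jnp.linalg.eigh(X)
    return (V * fn(w)) @ V.T


def spd_exp(X, xi):
    """Affine-invariant exponential map: X^{1/2} expm(X^{-1/2} xi X^{-1/2}) X^{1/2}."""
    X_sqrt = _sym_apply(X, jnp.sqrt)
    X_inv_sqrt = _sym_apply(X, lambda w: 1.0 / jnp.sqrt(w))
    inner = X_inv_sqrt @ xi @ X_inv_sqrt
    inner = 0.5 * (inner + inner.T)  # keep it exactly symmetric
    return X_sqrt @ _sym_apply(inner, jnp.exp) @ X_sqrt


def riemannian_gd(cost, X0, step=1.0, n_iter=60):
    """Riemannian gradient descent on SPD(n) under the affine-invariant metric."""
    egrad = jax.grad(cost)
    X = X0
    for _ in range(n_iter):
        G = egrad(X)
        G = 0.5 * (G + G.T)   # symmetric part of the Euclidean gradient
        rgrad = X @ G @ X     # Riemannian gradient for the affine-invariant metric
        X = spd_exp(X, -step * rgrad)
    return X


# A fixed SPD "sample covariance" (the notebook's ground-truth matrix)
S = jnp.array([
    [1.0, 0.5, 0.2, 0.1],
    [0.5, 1.0, 0.3, 0.0],
    [0.2, 0.3, 1.0, 0.4],
    [0.1, 0.0, 0.4, 1.0],
])


def gaussian_nll(C):
    _, logdet = jnp.linalg.slogdet(C)
    return 0.5 * logdet + 0.5 * jnp.trace(jnp.linalg.solve(C, S))


C_hat = riemannian_gd(gaussian_nll, jnp.eye(4))
print(jnp.max(jnp.abs(C_hat - S)))
```

The unit step is a deliberate choice for this cost: its Riemannian Hessian at the optimum is ½ times the identity, so near the optimum each step roughly halves the remaining error.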

In [None]:
def optimize_covariance_manifold(data, method="radam", max_iterations=100):
    """Optimize covariance matrix on SPD manifold using robust estimation."""
    n_features = data.shape[1]
    spd = rx.SymmetricPositiveDefinite(n=n_features)

    # Define robust cost function
    def cost_fn(C):
        return robust_manifold_covariance_cost(C, data)

    # Create optimization problem
    problem = rx.RiemannianProblem(spd, cost_fn)

    # Initialize with the MLE as a starting point, adding a small ridge
    # to guarantee strict positive definiteness
    initial_cov = mle_covariance(data) + 1e-6 * jnp.eye(n_features)

    # Solve the optimization problem
    result = rx.minimize(
        problem,
        initial_cov,
        method=method,
        options={"learning_rate": 0.01, "max_iterations": max_iterations}
    )

    return result

# Perform robust estimation
print("Optimizing robust covariance estimate on SPD manifold...")
robust_result = optimize_covariance_manifold(data)

print(f"\nOptimization completed in {robust_result.niter} iterations")
print(f"Final cost: {robust_result.fun:.6f}")

robust_cov = robust_result.x
print("\nRobust manifold covariance estimate:")
print(robust_cov)
print(f"\nRobust Frobenius error: {jnp.linalg.norm(robust_cov - true_cov, 'fro'):.4f}")

## Results Comparison and Analysis

In [None]:
# Compare the MLE and robust estimates against the ground truth
print("=" * 60)
print("COVARIANCE ESTIMATION COMPARISON")
print("=" * 60)

# Compute error metrics
mle_error = jnp.linalg.norm(mle_cov - true_cov, 'fro')
robust_error = jnp.linalg.norm(robust_cov - true_cov, 'fro')

print(f"\nFrobenius Norm Errors:")
print(f"MLE Error:       {mle_error:.4f}")
print(f"Robust Error:    {robust_error:.4f}")
print(f"Improvement:     {((mle_error - robust_error) / mle_error * 100):.1f}%")

# Compute condition numbers
true_cond = jnp.linalg.cond(true_cov)
mle_cond = jnp.linalg.cond(mle_cov)
robust_cond = jnp.linalg.cond(robust_cov)

print(f"\nCondition Numbers:")
print(f"True:            {true_cond:.2f}")
print(f"MLE:             {mle_cond:.2f}")
print(f"Robust:          {robust_cond:.2f}")

# Check positive definiteness
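def check_spd(matrix, name):
    # eigvalsh is specialized to symmetric matrices and returns real eigenvalues
    # (eigvals would return a complex array)
    eigenvals = jnp.linalg.eigvalsh(matrix)
    min_eigval = jnp.min(eigenvals)
    print(f"{name:12} min eigenvalue: {min_eigval:.6f} {'✓' if min_eigval > 0 else '✗'}")

print("\nPositive Definiteness Check:")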
check_spd(true_cov, "True")
check_spd(mle_cov, "MLE")
check_spd(robust_cov, "Robust")
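The symmetry, positive-definiteness, and conditioning checks can be folded into one helper built on `eigvalsh`, which returns real eigenvalues in ascending order. For an SPD matrix the 2-norm condition number reported by `jnp.linalg.cond` equals λ_max/λ_min, so one eigendecomposition answers all three questions. A minimal sketch; `spd_report` and its tolerance are hypothetical choices:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def spd_report(M, sym_tol=1e-10):
    """Return (is_spd, min_eig, cond) for a real matrix expected to be symmetric."""
    asym = jnp.max(jnp.abs(M - M.T))
    eigs = jnp.linalg.eigvalsh(0.5 * (M + M.T))  # real, ascending
    lam_min, lam_max = eigs[0], eigs[-1]
    is_spd = bool(asym <= sym_tol) and bool(lam_min > 0)
    # For SPD matrices the 2-norm condition number is lambda_max / lambda_min
    cond = lam_max / lam_min if is_spd else jnp.inf
    return is_spd, float(lam_min), float(cond)


print(spd_report(jnp.array([[2.0, 0.5], [0.5, 1.0]])))
print(spd_report(jnp.array([[1.0, 2.0], [2.0, 1.0]])))  # eigenvalues -1 and 3: not SPD
```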

## Visualization

In [None]:
def plot_covariance_comparison(true_cov, mle_cov, robust_cov):
    """Plot comparison of covariance matrices."""
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # Covariance matrices heatmaps
    matrices = [true_cov, mle_cov, robust_cov]
    titles = ["True Covariance", "MLE Estimate", "Robust Estimate"]
    
    vmin = min(mat.min() for mat in matrices)
    vmax = max(mat.max() for mat in matrices)
    
    for i, (mat, title) in enumerate(zip(matrices, titles)):
        im = axes[0, i].imshow(mat, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        axes[0, i].set_title(title)
        axes[0, i].set_xlabel('Feature')
        axes[0, i].set_ylabel('Feature')
        
        # Add colorbar
        cbar = plt.colorbar(im, ax=axes[0, i], shrink=0.8)
        
        # Add text annotations
        for j in range(mat.shape[0]):
            for k in range(mat.shape[1]):
                text = axes[0, i].text(k, j, f'{mat[j, k]:.2f}', 
                                     ha="center", va="center", color="black", fontsize=10)
    
    # Error matrices
    mle_error_mat = jnp.abs(mle_cov - true_cov)
    robust_error_mat = jnp.abs(robust_cov - true_cov)
    
    error_matrices = [mle_error_mat, robust_error_mat]
    error_titles = ["MLE Error (|est - true|)", "Robust Error (|est - true|)"]
    
    error_vmax = max(mat.max() for mat in error_matrices)
    
    for i, (mat, title) in enumerate(zip(error_matrices, error_titles)):
        im = axes[1, i].imshow(mat, cmap='Reds', vmin=0, vmax=error_vmax)
        axes[1, i].set_title(title)
        axes[1, i].set_xlabel('Feature')
        axes[1, i].set_ylabel('Feature')
        
        # Add colorbar
        cbar = plt.colorbar(im, ax=axes[1, i], shrink=0.8)
        
        # Add text annotations
        for j in range(mat.shape[0]):
            for k in range(mat.shape[1]):
                text = axes[1, i].text(k, j, f'{mat[j, k]:.3f}', 
                                     ha="center", va="center", color="white", fontsize=10)
    
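    # Bar plot comparison (errors computed here so the function does not rely on globals)
    methods = ['MLE', 'Robust']
    errors = [float(jnp.linalg.norm(mle_cov - true_cov, 'fro')),
              float(jnp.linalg.norm(robust_cov - true_cov, 'fro'))]
    colors = ['skyblue', 'lightcoral']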
    
    bars = axes[1, 2].bar(methods, errors, color=colors, alpha=0.7, edgecolor='black')
    axes[1, 2].set_title('Frobenius Norm Error Comparison')
    axes[1, 2].set_ylabel('||Estimate - True||_F')
    axes[1, 2].grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bar, error in zip(bars, errors):
        height = bar.get_height()
        axes[1, 2].text(bar.get_x() + bar.get_width()/2., height + error*0.01,
                        f'{error:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
    
    plt.tight_layout()
    return fig

# Create comparison visualization
fig = plot_covariance_comparison(true_cov, mle_cov, robust_cov)
plt.show()

In [None]:
def plot_data_with_ellipses(data, true_cov, mle_cov, robust_cov):
    """Plot data points with confidence ellipses for different covariance estimates."""
    # Plot first two dimensions for visualization
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    data_2d = data[:, :2]
    mean_2d = np.asarray(jnp.mean(data_2d, axis=0))  # plain NumPy center for matplotlib patches
    
    covariances = [true_cov[:2, :2], mle_cov[:2, :2], robust_cov[:2, :2]]
    titles = ['True Covariance', 'MLE Estimate', 'Robust Estimate']
    colors = ['green', 'blue', 'red']
    
    from matplotlib.patches import Ellipse

    for i, (cov_2d, title, color) in enumerate(zip(covariances, titles, colors)):
        # Scatter plot of data
        axes[i].scatter(data_2d[:, 0], data_2d[:, 1], alpha=0.6, s=30, color='gray')

        # Plot confidence ellipses (Gaussian contours centered at the sample mean)
        for confidence in [0.5, 0.95]:  # 50% and 95% confidence
            # Chi-square quantile for 2D
            chi2_val = -2 * np.log(1 - confidence)
            
            # Eigendecomposition for ellipse orientation
            eigenvals, eigenvecs = jnp.linalg.eigh(cov_2d)
            
            # Ellipse parameters
            angle = np.degrees(np.arctan2(eigenvecs[1, 0], eigenvecs[0, 0]))
            width = 2 * np.sqrt(chi2_val * eigenvals[0])
            height = 2 * np.sqrt(chi2_val * eigenvals[1])
            
            # Create ellipse
            alpha = 0.3 if confidence == 0.95 else 0.6
            ellipse = Ellipse(mean_2d, width, height, angle=angle, 
                            facecolor=color, alpha=alpha, edgecolor=color, linewidth=2)
            axes[i].add_patch(ellipse)
        
        axes[i].set_title(f'{title}\nConfidence Ellipses (50% & 95%)')
        axes[i].set_xlabel('Feature 1')
        axes[i].set_ylabel('Feature 2')
        axes[i].grid(True, alpha=0.3)
        axes[i].axis('equal')
    
    plt.tight_layout()
    return fig

# Create ellipse visualization
ellipse_fig = plot_data_with_ellipses(data, true_cov, mle_cov, robust_cov)
plt.show()
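The ellipse code above relies on two facts: for 2 degrees of freedom the chi-square quantile has the closed form −2 ln(1 − p), and the ellipse axes lie along the eigenvectors of the 2×2 covariance block with half-lengths √(χ²·λᵢ). The NumPy sketch below isolates that computation (`ellipse_params` and `inside_fraction` are hypothetical helper names) and checks by Monte Carlo that roughly the stated fraction of Gaussian samples falls inside the ellipse, using the same width/height/angle convention as `matplotlib.patches.Ellipse`.

```python
import numpy as np


def ellipse_params(cov_2d, confidence):
    """Full axis lengths and angle (degrees) of a Gaussian confidence ellipse."""
    chi2_val = -2.0 * np.log(1.0 - confidence)  # chi-square quantile, 2 dof (closed form)
    eigvals, eigvecs = np.linalg.eigh(cov_2d)   # ascending eigenvalues
    angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))  # direction of the width axis
    width = 2.0 * np.sqrt(chi2_val * eigvals[0])
    height = 2.0 * np.sqrt(chi2_val * eigvals[1])
    return width, height, angle


def inside_fraction(cov_2d, confidence, n=200_000, seed=0):
    """Fraction of N(0, cov_2d) samples that fall inside the ellipse from ellipse_params."""
    width, height, angle = ellipse_params(cov_2d, confidence)
    rng = np.random.default_rng(seed)
    pts = rng.multivariate_normal(np.zeros(2), cov_2d, size=n)
    t = np.radians(angle)
    # Rotate points into the ellipse frame (width axis along `angle`)
    u = pts[:, 0] * np.cos(t) + pts[:, 1] * np.sin(t)
    v = -pts[:, 0] * np.sin(t) + pts[:, 1] * np.cos(t)
    return np.mean((u / (width / 2)) ** 2 + (v / (height / 2)) ** 2 <= 1.0)


cov_example = np.array([[2.0, 0.8], [0.8, 1.0]])
for p in (0.5, 0.95):
    print(p, inside_fraction(cov_example, p))
```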

## Manifold Properties Analysis

In [None]:
# Analyze SPD manifold properties
n_features = data.shape[1]
spd = rx.SymmetricPositiveDefinite(n=n_features)

print("SPD Manifold Properties:")
print(f"Matrix dimension: {n_features}x{n_features}")
print(f"Manifold dimension: {spd.dimension}")
print(f"Ambient space dimension: {spd.ambient_dimension}")

# Test geodesic distance between estimates
mle_robust_dist = spd.dist(mle_cov, robust_cov)
true_mle_dist = spd.dist(true_cov, mle_cov)
true_robust_dist = spd.dist(true_cov, robust_cov)

print(f"\nGeodesic Distances on SPD Manifold:")
print(f"True ↔ MLE:       {true_mle_dist:.4f}")
print(f"True ↔ Robust:    {true_robust_dist:.4f}")
print(f"MLE ↔ Robust:     {mle_robust_dist:.4f}")

# Test tangent space properties
random_tangent = jax.random.normal(jax.random.key(999), (n_features, n_features))
# Make it symmetric (tangent vectors to SPD are symmetric matrices)
random_tangent = (random_tangent + random_tangent.T) / 2

projected_tangent = spd.proj(robust_cov, random_tangent)
print(f"\nTangent space projection test:")
print(f"Original tangent symmetry error: {jnp.linalg.norm(random_tangent - random_tangent.T):.2e}")
print(f"Projected tangent symmetry error: {jnp.linalg.norm(projected_tangent - projected_tangent.T):.2e}")

print("\nRobust SPD covariance estimation completed successfully!")
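The dimensions printed above follow from the tangent space: at every point of P(n) it is the space of symmetric n×n matrices, of dimension n(n+1)/2 (10 for n = 4), inside an ambient space of dimension n² (16). A common choice of tangent projection is symmetrization, sym(V) = (V + Vᵀ)/2; whether `spd.proj` uses exactly this map is an assumption here. The sketch builds that projection as an n²×n² linear map, confirms its rank equals n(n+1)/2, and checks that it is an orthogonal projection. `sym_proj` and `projection_rank` are hypothetical names.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def sym_proj(V):
    """Orthogonal (Frobenius) projection of an n x n matrix onto the symmetric matrices."""
    return 0.5 * (V + V.T)


def projection_rank(n):
    """Rank of sym_proj viewed as a linear map on R^{n*n}: the dimension of its image."""
    basis = jnp.eye(n * n).reshape(n * n, n, n)             # standard basis matrices E_ij
    P = jax.vmap(lambda E: sym_proj(E).reshape(-1))(basis)  # rows are images of the basis
    return int(jnp.linalg.matrix_rank(P))


for n in (2, 3, 4):
    print(n, projection_rank(n), n * (n + 1) // 2)
```

The residual V − sym(V) is the skew-symmetric part of V, which is Frobenius-orthogonal to every symmetric matrix; that is what makes symmetrization the orthogonal projection rather than just some map onto the tangent space.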

## Summary

This notebook demonstrated:

1. **SPD Manifold Structure**: Working with symmetric positive definite matrices as a Riemannian manifold
2. **Robust Estimation**: Using Huber loss to reduce outlier influence in covariance estimation
3. **Manifold Optimization**: Leveraging Riemannian optimization to respect the geometric constraints
4. **Comparison Analysis**: Quantitative and visual comparison of different estimation methods
5. **Geometric Properties**: Understanding geodesic distances and tangent spaces on the SPD manifold

Whether the robust estimate beats the MLE in a given run depends on the contamination level and on `huber_delta`; compare the Frobenius errors and geodesic distances printed above. In either case, positive definiteness is maintained by the manifold structure rather than by explicit constraints.

Key advantages of the manifold approach:
- **Geometric Consistency**: Respects the natural geometry of positive definite matrices
- **Constraint Satisfaction**: Automatically maintains positive definiteness
- **Outlier Robustness**: Huber loss reduces influence of extreme values
- **Theoretical Foundation**: Based on well-established Riemannian geometry principles