# Advanced Riemannian Optimization with RiemannAX

This notebook provides an in-depth exploration of advanced Riemannian optimization concepts and their implementation in RiemannAX. We'll cover:

1. **Mathematical Foundations**: Riemannian geometry basics
2. **Manifold Operations**: Deep dive into exponential maps, logarithmic maps, and parallel transport
3. **Optimization Algorithms**: Comparative analysis of Riemannian SGD, Adam, and Momentum
4. **Numerical Considerations**: Stability, conditioning, and performance
5. **Advanced Applications**: Multi-manifold optimization and custom problems
6. **Performance Optimization**: JIT compilation, vectorization, and retraction vs. exponential map

**Prerequisites**: Linear algebra, basic optimization theory, Python/JAX familiarity

**Learning Objectives**:
- Understand when and why to use Riemannian optimization
- Master the RiemannAX API for complex optimization problems
- Develop intuition for manifold geometry and optimization dynamics
- Learn best practices for numerical stability and performance

In [None]:
# Essential imports
import time
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# RiemannAX imports
import riemannax as rx

# Set up plotting
plt.style.use("seaborn-v0_8")
plt.rcParams["figure.figsize"] = (12, 8)
plt.rcParams["font.size"] = 12

# JAX configuration
jax.config.update("jax_enable_x64", True)  # Use double precision

print("RiemannAX Advanced Tutorial - Setup Complete")
print(f"JAX backend: {jax.default_backend()}")
print(f"JAX devices: {jax.devices()}")

## 1. Mathematical Foundations

### Riemannian Manifolds: Key Concepts

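A **Riemannian manifold** $(M, g)$ is a smooth manifold $M$ equipped with a Riemannian metric $g$, a smoothly varying inner product on each tangent space. The key objects are:

- **Tangent Spaces**: $T_x M$, the vector space of admissible directions at each point $x \in M$
- **Inner Products**: $g_x : T_x M \times T_x M \to \mathbb{R}$, which define lengths of and angles between tangent vectors
- **Geodesics**: Locally length-minimizing curves (the manifold analogue of straight lines)
- **Exponential Map**: $\exp_x : T_x M \to M$, which follows the geodesic starting at $x$ with initial velocity $v$ for unit time
- **Parallel Transport**: Moving tangent vectors along curves while preserving inner products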

### Why Riemannian Optimization?

Many optimization problems have **natural constraints** that define manifold structure:
- Orthogonality constraints → Stiefel/Grassmann manifolds
- Positive definiteness → SPD manifolds  
- Unit norm constraints → Sphere manifolds

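Riemannian optimization keeps every iterate on the manifold by construction (up to floating-point rounding), instead of projecting back after the fact or penalizing constraint violations.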
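A minimal pure-JAX illustration of the difference (not RiemannAX code): a plain Euclidean gradient step on the Rayleigh quotient $f(x) = x^\top A x$ leaves the unit sphere, while a Riemannian step, which projects the gradient onto the tangent space and maps back with the normalization retraction $x \mapsto (x + v)/\|x + v\|$, stays on it. The matrix `A`, the starting point, and the helper names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Illustrative cost on the unit sphere: the Rayleigh quotient x^T A x (A symmetric).
A = jnp.diag(jnp.array([3.0, 2.0, 1.0]))


def f(x):
    return x @ A @ x


def riemannian_grad(x):
    # Project the Euclidean gradient onto the tangent space T_x S = {v : <x, v> = 0}.
    g = jax.grad(f)(x)
    return g - jnp.dot(x, g) * x


def retract(x, v):
    # Normalization retraction: map x + v back onto the sphere.
    y = x + v
    return y / jnp.linalg.norm(y)


x = jnp.array([1.0, 1.0, 1.0]) / jnp.sqrt(3.0)
lr = 0.1

x_euclid = x - lr * jax.grad(f)(x)              # ordinary gradient step
x_riem = retract(x, -lr * riemannian_grad(x))   # Riemannian gradient step

print(jnp.linalg.norm(x_euclid))  # not 1: the unit-norm constraint is violated
print(jnp.linalg.norm(x_riem))    # 1 up to floating-point rounding
```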

In [None]:
# Let's explore manifold operations with the sphere
sphere = rx.Sphere()

# Generate random points on the sphere
key = jax.random.key(42)
keys = jax.random.split(key, 5)

x = sphere.random_point(keys[0])
y = sphere.random_point(keys[1])

print("Sphere Manifold Operations Demo")
print("=" * 40)
print(f"Point x: {x}")
print(f"Point y: {y}")
print(f"||x|| = {jnp.linalg.norm(x):.6f} (should be 1.0)")
print(f"||y|| = {jnp.linalg.norm(y):.6f} (should be 1.0)")

# Logarithmic map: find tangent vector from x to y
log_xy = sphere.log(x, y)
print(f"\nLogarithmic map log_x(y): {log_xy}")
print(f"||log_x(y)|| = {jnp.linalg.norm(log_xy):.6f} (geodesic distance)")

# Verify it's in tangent space (orthogonal to x)
print(f"<x, log_x(y)> = {jnp.dot(x, log_xy):.10f} (should be ~0)")

# Exponential map: recover y from x and tangent vector
exp_result = sphere.exp(x, log_xy)
print(f"\nExponential map exp_x(log_x(y)): {exp_result}")
print(f"||y - exp_x(log_x(y))|| = {jnp.linalg.norm(y - exp_result):.10f}")

# Geodesic distance
distance = sphere.dist(x, y)
print(f"\nGeodesic distance d(x,y) = {distance:.6f}")
print(f"Euclidean distance ||x-y|| = {jnp.linalg.norm(x - y):.6f}")
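To make the `log`/`exp` calls above less of a black box, here is a sketch of the standard closed-form sphere maps in plain JAX. It is an assumption about what `sphere.exp` and `sphere.log` compute, not RiemannAX's actual implementation, and the helper names are hypothetical. For a nonzero tangent vector $v$, $\exp_x(v) = \cos(\|v\|)\,x + \sin(\|v\|)\,v/\|v\|$, and $\log_x(y)$ rescales the component of $y$ orthogonal to $x$ to length $\theta = \arccos\langle x, y\rangle$.

```python
import jax.numpy as jnp


def sphere_exp(x, v):
    """Exponential map on the unit sphere: follow the great circle from x along v."""
    t = jnp.linalg.norm(v)
    # sin(t)/t -> 1 as t -> 0; guard the division to avoid 0/0 at v = 0.
    safe_t = jnp.where(t > 1e-12, t, 1.0)
    return jnp.cos(t) * x + jnp.where(t > 1e-12, jnp.sin(t) / safe_t, 1.0) * v


def sphere_log(x, y):
    """Logarithmic map: tangent vector at x pointing toward y, with length d(x, y)."""
    cos_theta = jnp.clip(jnp.dot(x, y), -1.0, 1.0)
    theta = jnp.arccos(cos_theta)          # geodesic distance
    u = y - cos_theta * x                  # component of y orthogonal to x
    u_norm = jnp.linalg.norm(u)
    safe_norm = jnp.where(u_norm > 1e-12, u_norm, 1.0)
    return jnp.where(u_norm > 1e-12, theta / safe_norm, 0.0) * u


x = jnp.array([1.0, 0.0, 0.0])
y = jnp.array([1.0, 1.0, 0.0]) / jnp.sqrt(2.0)
v = sphere_log(x, y)
print(jnp.linalg.norm(v))                        # geodesic distance pi/4
print(jnp.linalg.norm(sphere_exp(x, v) - y))     # ~0: exp inverts log

# Beyond the injectivity radius (pi) the round trip does not recover v.
w = 4.0 * jnp.array([0.0, 1.0, 0.0])
print(jnp.linalg.norm(sphere_log(x, sphere_exp(x, w))))
```

The logarithm is undefined at the antipode $y = -x$, and it always returns the shortest tangent vector, so the round trip only recovers $v$ when $\|v\| < \pi$: the last line yields a vector of length $2\pi - 4$ pointing the opposite way.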

## 2. Deep Dive: Manifold Operations

### Parallel Transport

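Parallel transport is essential for momentum-based Riemannian optimizers such as Adam and Momentum. Their moment estimates live in the tangent space at the previous iterate, so they must be moved to the tangent space at the new iterate before they can be combined with the new gradient. Parallel transport does this while preserving inner products, and hence norms and angles.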

In [None]:
def visualize_parallel_transport():
    """Demonstrate parallel transport on the sphere with numerical checks (no plot)."""
    sphere = rx.Sphere()

    # Sample two endpoints and a tangent vector at the starting point
    key = jax.random.key(123)
    keys = jax.random.split(key, 3)

    x = sphere.random_point(keys[0])
    y = sphere.random_point(keys[1])
    v = sphere.random_tangent(keys[2], x)

    # Rescale the tangent vector to norm 0.3
    v = 0.3 * v / jnp.linalg.norm(v)

    print("Parallel Transport Demonstration")
    print("=" * 40)
    print(f"Starting point x: {x}")
    print(f"Ending point y: {y}")
    print(f"Tangent vector v at x: {v}")
    print(f"||v|| = {jnp.linalg.norm(v):.6f}")

    # Parallel transport v from x to y
    v_transported = sphere.transp(x, y, v)

    print(f"\nTransported vector at y: {v_transported}")
    print(f"||v_transported|| = {jnp.linalg.norm(v_transported):.6f}")

    # Verify transported vector is in tangent space at y
    print(f"<y, v_transported> = {jnp.dot(y, v_transported):.10f} (should be ~0)")

    # Verify norm preservation (key property of parallel transport)
    norm_preserved = jnp.allclose(jnp.linalg.norm(v), jnp.linalg.norm(v_transported))
    print(f"Norm preserved: {norm_preserved}")

    return x, y, v, v_transported


x, y, v, v_transported = visualize_parallel_transport()
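For the unit sphere, parallel transport along the minimizing geodesic from $x$ to $y$ has the closed form $\mathcal{T}_{x \to y}(v) = v - \frac{\langle y, v\rangle}{1 + \langle x, y\rangle}(x + y)$, which is undefined when $y = -x$. The pure-JAX sketch below, with a hypothetical helper name and hand-picked points, checks the two properties tested in the cell above: the result is tangent at $y$, and norms (in fact all inner products) are preserved. Whether `sphere.transp` uses exactly this formula is an assumption.

```python
import jax.numpy as jnp


def sphere_transport(x, y, v):
    """Parallel transport of v in T_x S to T_y S along the minimizing geodesic.

    Closed form for the unit sphere; undefined when y = -x (no unique geodesic).
    """
    return v - (jnp.dot(y, v) / (1.0 + jnp.dot(x, y))) * (x + y)


x = jnp.array([1.0, 0.0, 0.0])
y = jnp.array([0.0, 0.6, 0.8])
v = jnp.array([0.0, 0.3, -0.4])  # tangent at x: <x, v> = 0

w = sphere_transport(x, y, v)
print(jnp.dot(y, w))                            # ~0: w lies in T_y S
print(jnp.linalg.norm(v), jnp.linalg.norm(w))   # equal norms
```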

## 3. Optimization Algorithm Deep Dive

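### Riemannian SGD
$$x_{k+1} = \text{Exp}_{x_k}(-\alpha \, \text{grad} f(x_k))$$

### Riemannian Adam
$$m_k = \beta_1 \mathcal{T}_{x_{k-1} \to x_k} m_{k-1} + (1-\beta_1)\, \text{grad} f(x_k)$$
$$v_k = \beta_2 \mathcal{T}_{x_{k-1} \to x_k} v_{k-1} + (1-\beta_2)\, (\text{grad} f(x_k))^2$$
$$x_{k+1} = \text{Exp}_{x_k}\!\left(-\alpha \frac{\hat m_k}{\sqrt{\hat v_k} + \epsilon}\right), \qquad \hat m_k = \frac{m_k}{1-\beta_1^k}, \quad \hat v_k = \frac{v_k}{1-\beta_2^k}$$

where $\mathcal{T}_{x \to y}$ denotes parallel transport from $T_x M$ to $T_y M$, and the square, square root, and division act elementwise. Implementations differ in how they handle $v_k$ (some keep a scalar estimate of $\|\text{grad} f\|^2$ instead), so read this as one common variant.

### Riemannian Momentum
$$m_k = \mu\, \mathcal{T}_{x_{k-1} \to x_k} m_{k-1} + \text{grad} f(x_k)$$
$$x_{k+1} = \text{Exp}_{x_k}(-\alpha\, m_k)$$

Both recursions start from zero moments. Each $m_k$ lives in $T_{x_k} M$, which is why the previous moment must be transported to $x_k$ before it is combined with the new gradient.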

In [None]:
def detailed_optimizer_analysis():
    """Detailed analysis of optimizer behavior."""
    # Create a challenging optimization problem on SO(3)
    so3 = rx.SpecialOrthogonal(n=3)

    # Target: a random rotation matrix
    key = jax.random.key(456)
    keys = jax.random.split(key, 3)

    target = so3.random_point(keys[0])
    x0 = so3.random_point(keys[1])

    def cost_fn(R):
        # Geodesic distance squared (more challenging than Frobenius norm)
        return 0.5 * so3.dist(R, target) ** 2

    problem = rx.RiemannianProblem(so3, cost_fn)

    print("Detailed Optimizer Analysis")
    print("=" * 40)
    print("Problem: SO(3) rotation alignment")
    print(f"Initial cost: {cost_fn(x0):.6f}")
    print("Target cost: 0.0")

    # Test different optimizers with manual stepping
    max_iterations = 100

    optimizers = {
        "RSGD": rx.riemannian_gradient_descent(learning_rate=0.1),
        "RAdaM": rx.riemannian_adam(learning_rate=0.01, beta1=0.9, beta2=0.999),
        "RMomentum": rx.riemannian_momentum(learning_rate=0.05, momentum=0.9),
    }

    results = {}

    for name, (init_fn, update_fn) in optimizers.items():
        print(f"\nAnalyzing {name}...")

        state = init_fn(x0)
        costs = [float(cost_fn(state.x))]
        grad_norms = []
        step_sizes = []

        for _i in range(max_iterations):
            # Compute gradient
            gradient = problem.grad(state.x)
            grad_norm = jnp.linalg.norm(gradient)
            grad_norms.append(float(grad_norm))

            # Store previous state for step size analysis
            prev_x = state.x

            # Update
            state = update_fn(gradient, state, so3)

            # Compute step size
            step_size = so3.dist(prev_x, state.x)
            step_sizes.append(float(step_size))

            # Record cost
            current_cost = float(cost_fn(state.x))
            costs.append(current_cost)

            # Early stopping
            if current_cost < 1e-10:
                break

        results[name] = {
            "costs": costs,
            "grad_norms": grad_norms,
            "step_sizes": step_sizes,
            "final_cost": costs[-1],
            "iterations": len(costs) - 1,
        }

        print(f"  Final cost: {costs[-1]:.2e}")
        print(f"  Iterations: {len(costs) - 1}")
        print(f"  Final gradient norm: {grad_norms[-1]:.2e}")

    return results


# Run detailed analysis
optimizer_results = detailed_optimizer_analysis()

In [None]:
def plot_optimizer_analysis(results):
    """Create detailed plots of optimizer behavior."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    colors = {"RSGD": "red", "RAdaM": "blue", "RMomentum": "green"}

    # 1. Convergence plot
    ax = axes[0, 0]
    for name, data in results.items():
        iterations = range(len(data["costs"]))
        ax.semilogy(iterations, data["costs"], color=colors[name], label=name, linewidth=2, alpha=0.8)

    ax.set_title("Convergence Comparison")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Cost (log scale)")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 2. Gradient norm evolution
    ax = axes[0, 1]
    for name, data in results.items():
        iterations = range(len(data["grad_norms"]))
        ax.semilogy(iterations, data["grad_norms"], color=colors[name], label=name, linewidth=2, alpha=0.8)

    ax.set_title("Gradient Norm Evolution")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("||grad f|| (log scale)")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 3. Step size evolution
    ax = axes[1, 0]
    for name, data in results.items():
        iterations = range(len(data["step_sizes"]))
        ax.plot(iterations, data["step_sizes"], color=colors[name], label=name, linewidth=2, alpha=0.8)

    ax.set_title("Step Size Evolution")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Step Size (geodesic distance)")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 4. Performance summary
    ax = axes[1, 1]

    names = list(results.keys())
    final_costs = [results[name]["final_cost"] for name in names]
    iterations = [results[name]["iterations"] for name in names]

    x_pos = np.arange(len(names))

    # Dual y-axis plot
    bars1 = ax.bar(x_pos - 0.2, final_costs, 0.4, color=[colors[name] for name in names], alpha=0.7, label="Final Cost")

    ax2 = ax.twinx()
    bars2 = ax2.bar(x_pos + 0.2, iterations, 0.4, color="gray", alpha=0.5, label="Iterations")

    ax.set_xlabel("Optimizer")
    ax.set_ylabel("Final Cost", color="black")
    ax2.set_ylabel("Iterations", color="gray")
    ax.set_title("Final Performance Summary")
    ax.set_xticks(x_pos)
    ax.set_xticklabels(names)
    ax.set_yscale("log")

    # Add value labels
    for bar, cost in zip(bars1, final_costs, strict=False):
        ax.text(
            bar.get_x() + bar.get_width() / 2, bar.get_height(), f"{cost:.1e}", ha="center", va="bottom", fontsize=9
        )

    for bar, iters in zip(bars2, iterations, strict=False):
        ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f"{iters}", ha="center", va="bottom", fontsize=9)

    plt.tight_layout()

    # Save plot
    output_dir = Path("./output")
    output_dir.mkdir(exist_ok=True)
    plt.savefig(output_dir / "advanced_optimizer_analysis.png", dpi=300, bbox_inches="tight")

    plt.show()


# Create detailed plots
plot_optimizer_analysis(optimizer_results)

## 4. Numerical Considerations

### Conditioning and Stability

Riemannian optimization can face unique numerical challenges:

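1. **Exponential Map Computation**: Can be expensive for some manifolds (e.g. matrix exponentials on SO(n) or SPD)
2. **Parallel Transport**: Rounding errors can accumulate and push transported vectors out of the tangent space
3. **Manifold Constraints**: Exact in theory, but floating-point rounding lets iterates drift off the manifold over many steps
4. **Step Size Selection**: Exponential maps and retractions always return points on the manifold, but too large a step can overshoot the minimizer or exceed the injectivity radius, where the logarithm no longer inverts the exponential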

### Best Practices

- Use **retraction** when exponential map is expensive
- **Clip step sizes** to prevent overshooting
- **Validate manifold constraints** regularly, and re-project (re-normalize or re-orthonormalize) when the violation exceeds a tolerance
- Consider **double precision** for ill-conditioned problems
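Validating and restoring a constraint can be cheap. The sketch below (plain JAX; helper names are hypothetical) measures the orthogonality violation $\|Q^\top Q - I\|_F$ and projects a drifted matrix back with a QR decomposition, flipping column signs so that $R$ has a positive diagonal. The artificial perturbation stands in for accumulated rounding error. Because $\det R > 0$ after the sign fix, the result keeps the determinant sign of its input, so a slightly drifted element of SO(n) is restored to SO(n).

```python
import jax
import jax.numpy as jnp


def orthogonality_error(Q):
    """Frobenius norm of Q^T Q - I; zero exactly for orthogonal matrices."""
    n = Q.shape[-1]
    return jnp.linalg.norm(Q.T @ Q - jnp.eye(n))


def reorthonormalize(Q):
    """Project a nearly orthogonal matrix back via QR, with diag(R) made positive."""
    q, r = jnp.linalg.qr(Q)
    signs = jnp.sign(jnp.diag(r))
    signs = jnp.where(signs == 0, 1.0, signs)
    return q * signs  # broadcasting flips the sign of column j when diag(R)[j] < 0


Q, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(0), (4, 4)))

# Simulate accumulated drift, e.g. from many low-precision updates.
drifted = Q + 1e-3 * jax.random.normal(jax.random.key(1), (4, 4))
print(orthogonality_error(drifted))                    # noticeably nonzero
print(orthogonality_error(reorthonormalize(drifted)))  # back to rounding level
```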

In [None]:
def numerical_stability_analysis():
    """Analyze numerical stability across different manifolds."""
    print("Numerical Stability Analysis")
    print("=" * 40)

    # Test different manifolds with varying conditioning
    manifolds = {
        "Sphere": rx.Sphere(),
        "SO(3)": rx.SpecialOrthogonal(n=3),
        "Grassmann(4,2)": rx.Grassmann(n=4, p=2),
        "SPD(3)": rx.SymmetricPositiveDefinite(n=3),
    }

    key = jax.random.key(789)

    for name, manifold in manifolds.items():
        print(f"\nTesting {name}:")

        # Generate random point
        key, subkey = jax.random.split(key)
        x = manifold.random_point(subkey)

        # Test manifold constraint satisfaction
        if hasattr(manifold, "_is_in_manifold"):
            on_manifold = manifold._is_in_manifold(x)
            print(f"  Point on manifold: {on_manifold}")

        # Test exp/log consistency
        key, subkey = jax.random.split(key)
        v = manifold.random_tangent(subkey, x)

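        # Scale tangent vector to test different step sizes.
        # Large round-trip errors are expected once the step exceeds the injectivity radius
        # (pi on the unit sphere), because log returns the shortest connecting tangent vector.
        for scale in [0.1, 0.5, 1.0]: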
            v_scaled = scale * v

            try:
                # Test exp/log round trip
                y = manifold.exp(x, v_scaled)
                v_recovered = manifold.log(x, y)

                error = jnp.linalg.norm(v_scaled - v_recovered)
                print(f"  Scale {scale}: exp/log error = {error:.2e}")

                # Test if y is on manifold
                if hasattr(manifold, "_is_in_manifold"):
                    y_on_manifold = manifold._is_in_manifold(y)
                    if not y_on_manifold:
                        print("    WARNING: exp result not on manifold!")

            except Exception as e:
                print(f"  Scale {scale}: ERROR - {e!s}")

        # Test parallel transport properties
        key, subkey = jax.random.split(key)
        y = manifold.random_point(subkey)

        v_transported = manifold.transp(x, y, v)

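        # Check norm preservation. This uses the ambient Euclidean (Frobenius) norm; for a
        # manifold with a non-Euclidean metric (e.g. an affine-invariant SPD metric), a nonzero
        # value here does not by itself indicate a transport error.
        norm_error = abs(jnp.linalg.norm(v) - jnp.linalg.norm(v_transported))
        print(f"  Parallel transport norm error: {norm_error:.2e}")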

        # Check tangent space property
        if hasattr(manifold, "proj"):
            v_projected = manifold.proj(y, v_transported)
            tangent_error = jnp.linalg.norm(v_transported - v_projected)
            print(f"  Tangent space error: {tangent_error:.2e}")


numerical_stability_analysis()

## 5. Advanced Applications

### Multi-Manifold Optimization

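Some problems involve optimization over **multiple manifolds simultaneously**, i.e. over a product manifold $M_1 \times M_2 \times \cdots$. For example:
- Joint diagonalization: $\text{SO}(n) \times \text{SPD}(n)$
- Subspace clustering: Multiple Grassmann manifolds
- Multi-view learning: Product of Stiefel manifolds

The joint diagonalization demo below uses the simpler single-manifold formulation: it optimizes only the rotation $R \in \text{SO}(n)$.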
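As a minimal product-manifold sketch (pure JAX, an illustrative problem rather than one of the examples listed above), consider maximizing $f(x, y) = x^\top A y$ over $S^4 \times S^2$, whose maximum is the largest singular value of $A$. Storing the point as a tuple lets `jax.tree_util.tree_map` apply the sphere geometry (tangent projection and normalization retraction) to each factor independently. The test matrix with singular values 5, 2, 1, the step size, and the iteration count are assumptions.

```python
import jax
import jax.numpy as jnp

# Build A = U diag(5, 2, 1) V^T so the largest singular value is known to be 5.
U, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(0), (5, 3)))
V, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(1), (3, 3)))
A = U @ jnp.diag(jnp.array([5.0, 2.0, 1.0])) @ V.T


def f(params):
    x, y = params  # x on S^4 (in R^5), y on S^2 (in R^3)
    return x @ A @ y


def project(p, g):
    return g - jnp.dot(p, g) * p  # tangent projection on a sphere


def retract(p, v):
    q = p + v
    return q / jnp.linalg.norm(q)


def ascent_step(params, lr=0.1):
    grads = jax.grad(f)(params)  # tuple of Euclidean gradients
    # Apply the sphere geometry to each factor of the product independently.
    rgrads = jax.tree_util.tree_map(project, params, grads)
    return jax.tree_util.tree_map(lambda p, g: retract(p, lr * g), params, rgrads)


x0 = jnp.ones(5) / jnp.sqrt(5.0)
y0 = jnp.ones(3) / jnp.sqrt(3.0)
params = jax.lax.fori_loop(0, 500, lambda _, p: ascent_step(p), (x0, y0))
print(f(params))  # approaches the largest singular value, 5.0
```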

In [None]:
def joint_diagonalization_demo():
    """Demonstrate joint diagonalization by optimizing a rotation R over SO(n)."""
    print("Joint Diagonalization Demo")
    print("=" * 40)

    # Problem: Given matrices A1, A2, ..., find rotation R such that
    # R^T A_i R are as diagonal as possible

    key = jax.random.key(999)
    keys = jax.random.split(key, 10)

    n = 4
    n_matrices = 3

    # Generate test matrices with known joint diagonalizer
    true_rotation = rx.SpecialOrthogonal(n).random_point(keys[0])

    # Create matrices that are diagonalized by true_rotation
    matrices = []
    for i in range(n_matrices):
        # Random diagonal matrix
        D = jnp.diag(jax.random.uniform(keys[i + 1], (n,), minval=0.1, maxval=2.0))
        # Rotate to create non-diagonal matrix
        A = true_rotation @ D @ true_rotation.T
        matrices.append(A)

    print(f"Generated {n_matrices} matrices of size {n}x{n}")
    print("True rotation matrix known")

    # Define joint diagonalization cost
    def joint_diag_cost(R):
        total_off_diag = 0.0
        for A in matrices:
            # Transform matrix
            transformed = R.T @ A @ R
            # Penalize off-diagonal elements
            off_diag_mask = 1 - jnp.eye(n)
            off_diag_sum = jnp.sum((transformed * off_diag_mask) ** 2)
            total_off_diag += off_diag_sum
        return total_off_diag

    # Initial cost
    initial_cost = joint_diag_cost(jnp.eye(n))
    true_cost = joint_diag_cost(true_rotation)

    print(f"\nCost with identity: {initial_cost:.6f}")
    print(f"Cost with true rotation: {true_cost:.6f}")

    # Optimize over SO(n) with n = 4 (the variable name `so3` is kept from the earlier SO(3) example)
    so3 = rx.SpecialOrthogonal(n)
    problem = rx.RiemannianProblem(so3, joint_diag_cost)

    # Random initialization
    x0 = so3.random_point(keys[-1])
    initial_opt_cost = joint_diag_cost(x0)

    print(f"Initial optimization cost: {initial_opt_cost:.6f}")

    # Test different optimizers
    methods = ["rsgd", "radam", "rmom"]
    results = {}

    for method in methods:
        print(f"\nOptimizing with {method.upper()}...")

        if method == "radam":
            options = {"learning_rate": 0.01, "max_iterations": 200}
        elif method == "rmom":
            options = {"learning_rate": 0.05, "momentum": 0.9, "max_iterations": 200}
        else:
            options = {"learning_rate": 0.1, "max_iterations": 200}

        result = rx.minimize(problem, x0, method=method, options=options)

        final_cost = result.fun
        optimal_R = result.x

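        # Measure distance to the true solution. The joint diagonalizer is only determined up to
        # signed column permutations, so a large distance does not by itself indicate failure.
        rotation_error = so3.dist(optimal_R, true_rotation)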

        results[method] = {"final_cost": final_cost, "rotation_error": rotation_error, "optimal_R": optimal_R}

        print(f"  Final cost: {final_cost:.6f}")
        print(f"  Distance to true rotation: {rotation_error:.6f}")
        print(f"  Improvement: {((initial_opt_cost - final_cost) / initial_opt_cost * 100):.2f}%")

    return results, matrices, true_rotation


joint_diag_results, matrices, true_rotation = joint_diagonalization_demo()
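The cost above loops over the matrices in Python. A sketch of a batched alternative (plain JAX; names and data are illustrative and independent of the cell above) stacks the $A_i$ into one array and computes every $R^\top A_i R$ with a single `jnp.einsum`. It also illustrates why distance to `true_rotation` can be a misleading success measure: multiplying the diagonalizer by a signed permutation matrix gives another zero-cost solution.

```python
import jax
import jax.numpy as jnp

n, n_matrices = 4, 3
key_q, key_d = jax.random.split(jax.random.key(0))
Q, _ = jnp.linalg.qr(jax.random.normal(key_q, (n, n)))  # orthogonal "true" diagonalizer
D = jax.random.uniform(key_d, (n_matrices, n), minval=0.1, maxval=2.0)
As = jnp.einsum("ij,kj,lj->kil", Q, D, Q)  # As[k] = Q diag(D[k]) Q^T


def joint_diag_cost_loop(R, As):
    """Reference version with a Python loop, mirroring the notebook's cost."""
    total = 0.0
    for A in As:
        T = R.T @ A @ R
        total += jnp.sum((T * (1.0 - jnp.eye(R.shape[0]))) ** 2)
    return total


def joint_diag_cost_batched(R, As):
    """Sum of squared off-diagonal entries of R^T A_k R over all k, in one einsum."""
    T = jnp.einsum("ji,kjl,lm->kim", R, As, R)  # T[k] = R^T As[k] R
    return jnp.sum((T * (1.0 - jnp.eye(R.shape[0]))) ** 2)


# Signed permutation with determinant +1: swap the first two columns, flip one sign.
P = jnp.array([[0.0, 1.0, 0.0, 0.0],
               [-1.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 1.0, 0.0],
               [0.0, 0.0, 0.0, 1.0]])
print(joint_diag_cost_batched(Q, As), joint_diag_cost_batched(Q @ P, As))  # both ~0
```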

## 6. Performance Optimization and Best Practices

### JAX Optimization Tips

1. **JIT Compilation**: Use `@jax.jit` for expensive operations
2. **Vectorization**: Use `jax.vmap` for batch operations
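3. **Memory Management**: JAX arrays are immutable, so "in-place" updates such as `x.at[idx].set(value)` return new arrays; keep this in mind for large intermediates
4. **Device Placement**: JAX places arrays on the default device (GPU/TPU when available); check with `jax.devices()`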
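A sketch of the first two tips together (plain JAX; the cost and batch are illustrative): write the Riemannian gradient for a single point on the sphere, let `jax.vmap` add the batch dimension, and compile the batched function once with `jax.jit`.

```python
import jax
import jax.numpy as jnp

A = jnp.diag(jnp.arange(1.0, 6.0))  # illustrative symmetric matrix


def riemannian_grad(x):
    """Riemannian gradient of x^T A x on the unit sphere (tangent projection)."""
    g = 2.0 * A @ x
    return g - jnp.dot(x, g) * x


# vmap maps the single-point function over a leading batch axis; jit compiles it once.
batched_grad = jax.jit(jax.vmap(riemannian_grad))

points = jax.random.normal(jax.random.key(0), (8, 5))
points = points / jnp.linalg.norm(points, axis=1, keepdims=True)  # 8 points on S^4

grads = batched_grad(points)  # shape (8, 5)
print(grads.shape)
print(jnp.max(jnp.abs(jnp.sum(points * grads, axis=1))))  # ~0: every gradient is tangent
```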

### RiemannAX Specific Tips

1. **Choose the Right Manifold**: Match problem structure to manifold geometry
2. **Optimizer Selection**: As a rule of thumb, Adam for ill-conditioned problems, Momentum for faster convergence on smooth problems, SGD for simplicity and as a baseline
3. **Learning Rate Tuning**: Start conservative, especially for Adam
4. **Retraction vs Exponential**: Use retraction for computational efficiency
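Why a retraction is usually good enough: the normalization retraction on the sphere, $R_x(v) = (x + v)/\|x + v\|$, costs one norm and one division, yet agrees with the exponential map to second order, so for a step of length $t$ the gap is $2\sin\big((t - \arctan t)/2\big) \approx t^3/3$. The pure-JAX sketch below measures that gap; it illustrates the general idea and makes no claim about how RiemannAX implements its `use_retraction` option.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 0.0, 0.0])
v = jnp.array([0.0, 1.0, 0.0])  # unit tangent vector at x


def exp_map(x, v):
    t = jnp.linalg.norm(v)
    return jnp.cos(t) * x + jnp.sin(t) * v / t  # assumes v != 0


def retraction(x, v):
    y = x + v
    return y / jnp.linalg.norm(y)  # one norm and one division


gaps = {t: float(jnp.linalg.norm(exp_map(x, t * v) - retraction(x, t * v)))
        for t in (0.2, 0.1, 0.05)}
for t, gap in gaps.items():
    print(t, gap)  # the gap shrinks roughly 8x each time t halves (cubic order)
```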

In [None]:
def performance_comparison_demo():
    """Compare performance of different implementation choices."""
    print("Performance Comparison Demo")
    print("=" * 40)

    # Large-scale Grassmann optimization
    n, p = 100, 10  # Large problem size
    grassmann = rx.Grassmann(n=n, p=p)

    # Generate synthetic data
    key = jax.random.key(12345)
    keys = jax.random.split(key, 5)

    # Large data matrix
    data = jax.random.normal(keys[0], (n, 1000))

    def subspace_cost(subspace):
        projector = subspace @ subspace.T
        reconstruction = projector @ data
        return jnp.sum((data - reconstruction) ** 2)

    problem = rx.RiemannianProblem(grassmann, subspace_cost)
    x0 = grassmann.random_point(keys[1])

    print(f"Problem size: Grassmann({n}, {p})")
    print(f"Data size: {data.shape}")
    print(f"Initial cost: {subspace_cost(x0):.2e}")

    # Compare retraction vs exponential map
    configurations = [("Exponential Map", False), ("Retraction", True)]

    for config_name, use_retraction in configurations:
        print(f"\nTesting {config_name}:")

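        # perf_counter is a monotonic, high-resolution clock. The first call may include JIT
        # compilation, so run the cell twice for steady-state timings.
        start_time = time.perf_counter()

        result = rx.minimize(
            problem,
            x0,
            method="rsgd",
            options={"learning_rate": 0.01, "max_iterations": 50, "use_retraction": use_retraction},
        )

        # JAX dispatches work asynchronously; wait for the result before stopping the clock
        jax.block_until_ready(result.x)
        elapsed_time = time.perf_counter() - start_time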

        print(f"  Time: {elapsed_time:.3f} seconds")
        print(f"  Final cost: {result.fun:.2e}")
        print(f"  Iterations: {getattr(result, 'nit', 'N/A')}")

    # Test batch optimization
    print("\nBatch Evaluation Test:")

    batch_size = 5
    batch_x0 = jax.vmap(lambda k: grassmann.random_point(k))(jax.random.split(keys[2], batch_size))

    def batch_cost(batch_subspaces):
        return jax.vmap(subspace_cost)(batch_subspaces)

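    # Time batch evaluation; block_until_ready waits for JAX's asynchronous dispatch.
    # batch_cost is not jitted, so this timing also includes Python and dispatch overhead.
    start_time = time.perf_counter()
    batch_costs = batch_cost(batch_x0).block_until_ready()
    elapsed_time = time.perf_counter() - start_time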

    print(f"  Batch size: {batch_size}")
    print(f"  Batch evaluation time: {elapsed_time:.4f} seconds")
    print(f"  Average cost: {jnp.mean(batch_costs):.2e}")


performance_comparison_demo()
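Timing JAX code reliably needs two precautions: exclude the first (compiling) call, and wait for asynchronous dispatch with `block_until_ready` before reading the clock. Below is a minimal sketch of a benchmarking helper (the name `benchmark` is hypothetical), applied to a jitted version of the subspace cost that avoids forming the $n \times n$ projector $U U^\top$.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=10):
    """Average seconds per call, excluding compilation and waiting for async results."""
    fn(*args).block_until_ready()  # warm-up call triggers JIT compilation
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()  # JAX returns before the device finishes computing
    return (time.perf_counter() - start) / repeats


data = jax.random.normal(jax.random.key(0), (100, 1000))


@jax.jit
def subspace_cost(U):
    # ||data - U U^T data||^2 computed as U (U^T data): O(n p m) instead of O(n^2 m)
    return jnp.sum((data - U @ (U.T @ data)) ** 2)


U, _ = jnp.linalg.qr(jax.random.normal(jax.random.key(1), (100, 10)))
print(f"{benchmark(subspace_cost, U) * 1e3:.3f} ms per call")
```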

## 7. Summary and Next Steps

### Key Takeaways

1. **Geometric Insight**: Riemannian optimization provides natural solutions for constrained problems
2. **Algorithm Choice**: Different optimizers excel in different scenarios
3. **Numerical Care**: Stability and efficiency require thoughtful implementation
4. **Practical Impact**: The same tools apply directly to common machine-learning problems such as subspace estimation, rotation alignment, and covariance modeling

### Advanced Topics to Explore

- **Riemannian Natural Gradients**: Preconditioning the gradient with the Fisher information metric
- **Constrained Optimization**: Equality and inequality constraints on manifolds
- **Stochastic Methods**: Riemannian SGD with mini-batches
- **Multi-scale Optimization**: Hierarchical manifold optimization

### Further Reading

- Absil, P.-A., Mahony, R., & Sepulchre, R. (2008). *Optimization Algorithms on Matrix Manifolds*
- Boumal, N. (2023). *An Introduction to Optimization on Smooth Manifolds*
- RiemannAX Documentation: [github.com/lv416e/riemannax](https://github.com/lv416e/riemannax)

In [None]:
# Final summary and cleanup
print("=" * 60)
print("ADVANCED RIEMANNIAN OPTIMIZATION TUTORIAL COMPLETE")
print("=" * 60)
print("\n✓ Mathematical foundations covered")
print("✓ Manifold operations understood")
print("✓ Optimization algorithms analyzed")
print("✓ Numerical considerations addressed")
print("✓ Advanced applications demonstrated")
print("✓ Performance optimization explored")
print("\nNext: Apply these concepts to your own optimization problems!")
print("=" * 60)