# Differential Geometry for Machine Learning
## Manifold Learning and Natural Gradients

Welcome to the **geometry of curved spaces**! Differential geometry provides the mathematical framework for understanding data that lives on curved manifolds rather than flat Euclidean spaces.

### What You'll Master
By the end of this notebook, you'll understand:
1. **Manifolds** - Curved spaces that locally look Euclidean
2. **Tangent spaces** - Linear approximations to curved spaces
3. **Riemannian metrics** - How to measure distances and angles on manifolds
4. **Natural gradients** - Optimization that respects the geometry
5. **Manifold learning** - Discovering hidden structure in high-dimensional data
6. **Information geometry** - The geometry of probability distributions

### Why This is Revolutionary
- **Data manifolds** - Real data often lies on low-dimensional manifolds
- **Natural gradients** - Often faster convergence, because updates account for the curvature of parameter space
- **Geometric deep learning** - Neural networks on manifolds and graphs
- **Information geometry** - Learning algorithms derived from the geometry of probability distributions

### Real-World Applications
- **Computer vision**: Face recognition on face manifolds
- **Robotics**: Motion planning on configuration manifolds
- **Neuroscience**: Neural population dynamics on neural manifolds
- **Economics**: Market dynamics on economic manifolds

Let's explore the beautiful geometry of curved spaces! 📐

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
from scipy import linalg
from scipy.optimize import minimize
from sklearn.datasets import make_swiss_roll, make_s_curve, load_digits
from sklearn.manifold import LocallyLinearEmbedding, Isomap, TSNE, SpectralEmbedding
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("viridis")
np.random.seed(42)

print("📐 Differential Geometry toolkit loaded!")
print("Ready to explore curved spaces and manifolds!")

## 1. Manifolds and Tangent Spaces

### What is a Manifold?
A **manifold** is a space that locally looks like Euclidean space but can be globally curved.

**Formal Definition**: An n-dimensional smooth manifold M is a topological space (Hausdorff and second-countable) where:
1. Every point has a neighborhood homeomorphic to an open subset of ℝⁿ
2. Transition maps between overlapping charts are smooth (infinitely differentiable)

### Examples of Manifolds
- **Circle S¹**: 1D manifold embedded in ℝ²
- **Sphere S²**: 2D manifold embedded in ℝ³
- **Torus T²**: 2D manifold (donut shape) embedded in ℝ³
- **Swiss roll**: 2D manifold embedded in ℝ³
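
One way to see why the circle, sphere, and torus have these dimensions: each is cut out of its ambient space by a single equation, which removes one degree of freedom. The sketch below samples points from the usual parametrizations and checks that they satisfy the defining equation. The torus radii `R=2`, `r=1` and all variable names are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.uniform(0.0, 2.0 * np.pi, 500)   # angle around the main axis
v = rng.uniform(0.0, 2.0 * np.pi, 500)   # second angle

# Circle S^1 in R^2: x^2 + y^2 = 1  (2 - 1 = 1 dimension)
circle = np.column_stack([np.cos(u), np.sin(u)])
circle_residual = np.abs((circle**2).sum(axis=1) - 1.0).max()

# Sphere S^2 in R^3: x^2 + y^2 + z^2 = 1  (3 - 1 = 2 dimensions)
polar = v / 2.0                           # polar angle in [0, pi)
sphere = np.column_stack([np.cos(u) * np.sin(polar),
                          np.sin(u) * np.sin(polar),
                          np.cos(polar)])
sphere_residual = np.abs((sphere**2).sum(axis=1) - 1.0).max()

# Torus in R^3: (sqrt(x^2 + y^2) - R)^2 + z^2 = r^2  (3 - 1 = 2 dimensions)
R, r = 2.0, 1.0
torus = np.column_stack([(R + r * np.cos(v)) * np.cos(u),
                         (R + r * np.cos(v)) * np.sin(u),
                         r * np.sin(v)])
rho = np.hypot(torus[:, 0], torus[:, 1])
torus_residual = np.abs((rho - R)**2 + torus[:, 2]**2 - r**2).max()

print("max residuals:", circle_residual, sphere_residual, torus_residual)
```

All three residuals should be at the level of floating-point round-off.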

### Tangent Spaces
At each point p on a manifold M, the **tangent space** TₚM is the vector space of velocity vectors of curves through p, that is, every "direction" in which you can move from p.

**Properties**:
- **Dimension**: dim(TₚM) = dim(M)
- **Linear structure**: Can add vectors and multiply by scalars
- **Local approximation**: Best linear approximation to the manifold at p
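
As a small illustration of these properties, the sketch below builds the tangent space of the unit sphere S² at the north pole, assuming the standard embedding in ℝ³, where TₚS² is the plane orthogonal to p. The helper names `tangent_projector` and `curve` are hypothetical, not library functions.

```python
import numpy as np

def tangent_projector(p):
    """Orthogonal projector onto T_p S^2 for a nonzero vector p (illustrative helper)."""
    p = p / np.linalg.norm(p)
    return np.eye(3) - np.outer(p, p)

p = np.array([0.0, 0.0, 1.0])            # north pole
P = tangent_projector(p)

# Dimension: the projector has rank 2 = dim(S^2)
rank = np.linalg.matrix_rank(P)

# Linear structure: linear combinations of tangent vectors stay tangent
v = P @ np.array([1.0, 2.0, 3.0])
w = P @ np.array([-2.0, 0.5, 1.0])
combo = 2.0 * v - 3.0 * w

# Local approximation: a curve on the sphere through p has a tangent velocity
def curve(t):
    q = p + t * v
    return q / np.linalg.norm(q)         # stays on the sphere, equals p at t = 0

h = 1e-6
velocity = (curve(h) - curve(-h)) / (2.0 * h)   # central finite difference

print("rank of projector:", rank)
print("p . combo:", p @ combo)
print("velocity matches v:", np.allclose(velocity, v, atol=1e-6))
```

The finite-difference velocity of the curve agrees with the tangent vector `v`, which is the sense in which TₚM linearizes the manifold at p.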

### Charts and Coordinates
A **chart** (φ, U) is a local coordinate system:
```
φ: U ⊆ M → V ⊆ ℝⁿ
```
where U is an open set in M, V is an open set in ℝⁿ, and φ is a homeomorphism from U onto V.
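
To make charts and transition maps concrete, here is a minimal sketch for the circle S¹ using two stereographic projections, one from the north pole (0, 1) and one from the south pole (0, -1). The function names are illustrative. On the overlap of the two charts, the transition map works out to t ↦ 1/t, which is smooth wherever t ≠ 0.

```python
import numpy as np

def phi_north(x, y):
    """Chart on S^1 minus the north pole (0, 1): project onto the x-axis."""
    return x / (1.0 - y)

def phi_south(x, y):
    """Chart on S^1 minus the south pole (0, -1)."""
    return x / (1.0 + y)

def phi_north_inv(t):
    """Inverse of the north chart: coordinate t -> point on the circle."""
    d = t**2 + 1.0
    return 2.0 * t / d, (t**2 - 1.0) / d

# Nonzero coordinates in the north chart (t = 0 is the south pole,
# which the south chart does not cover)
t = np.array([-3.0, -0.5, 0.25, 2.0])
x, y = phi_north_inv(t)

on_circle = np.allclose(x**2 + y**2, 1.0)        # inverse lands on S^1
round_trip = np.allclose(phi_north(x, y), t)     # chart and inverse agree
transition = phi_south(x, y)                     # phi_south after phi_north_inv

print("on circle:", on_circle)
print("round trip:", round_trip)
print("transition equals 1/t:", np.allclose(transition, 1.0 / t))
```

Neither chart covers the whole circle on its own; together they do, which is why an atlas generally needs more than one chart.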

### Why Manifolds Matter in ML
- **Data manifold hypothesis**: High-dimensional data often lies near a low-dimensional manifold
- **Dimensionality reduction**: Find the underlying manifold structure
- **Natural gradients**: Optimization respecting the manifold geometry
- **Geometric deep learning**: Neural networks that respect manifold structure
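
The gap between linear and intrinsic dimension can be seen on a toy example. The curve below is an illustrative choice, not data from this notebook: it is a closed 1D curve in ℝ⁶ that spans all six linear directions, so a global linear method needs six components, while a small neighborhood is dominated by a single direction.

```python
import numpy as np

# A 1D closed curve in R^6, parametrized by a single angle theta
theta = np.linspace(0.0, 2.0 * np.pi, 2000, endpoint=False)
X = np.column_stack([np.cos(theta), np.sin(theta),
                     np.cos(2 * theta), np.sin(2 * theta),
                     np.cos(3 * theta), np.sin(3 * theta)])

# Global view: singular values of the centered data (what PCA sees)
Xc = X - X.mean(axis=0)
global_sv = np.linalg.svd(Xc, compute_uv=False)
global_rank = int(np.sum(global_sv > 1e-8 * global_sv[0]))

# Local view: 21 consecutive samples, as displacements from the middle one
patch = X[100:121] - X[110]
local_sv = np.linalg.svd(patch, compute_uv=False)

print("global linear dimension:", global_rank)
print("local singular values (normalized):", local_sv / local_sv[0])
```

Globally the data looks six-dimensional, but locally one singular value dominates, matching the single intrinsic coordinate `theta`.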

In [None]:
def demonstrate_manifolds_and_tangent_spaces():
    """Explore manifolds, tangent spaces, and local geometry"""
    
    print("🌐 Manifolds and Tangent Spaces: Curved Geometry")
    print("=" * 48)
    
    fig = plt.figure(figsize=(20, 15))
    
    # 1. Classic manifold examples
    print("\n1. Classic Manifold Examples")
    
    # Circle manifold
    theta = np.linspace(0, 2*np.pi, 100)
    circle_x = np.cos(theta)
    circle_y = np.sin(theta)
    
    ax1 = fig.add_subplot(3, 4, 1)
    ax1.plot(circle_x, circle_y, 'b-', linewidth=3)
    
    # Show tangent vectors at a few points
    tangent_points = [0, np.pi/4, np.pi/2, np.pi]
    for t in tangent_points:
        x, y = np.cos(t), np.sin(t)
        # Tangent vector (perpendicular to radius)
        tx, ty = -np.sin(t), np.cos(t)
        ax1.arrow(x, y, 0.3*tx, 0.3*ty, head_width=0.05, head_length=0.05, fc='red', ec='red')
        ax1.plot(x, y, 'ro', markersize=6)
    
    ax1.set_xlim(-1.5, 1.5)
    ax1.set_ylim(-1.5, 1.5)
    ax1.set_aspect('equal')
    ax1.set_title('Circle S¹ with Tangent Vectors')
    ax1.grid(True, alpha=0.3)
    
    # Sphere manifold
    ax2 = fig.add_subplot(3, 4, 2, projection='3d')
    
    # Create sphere
    u = np.linspace(0, 2 * np.pi, 50)
    v = np.linspace(0, np.pi, 50)
    sphere_x = np.outer(np.cos(u), np.sin(v))
    sphere_y = np.outer(np.sin(u), np.sin(v))
    sphere_z = np.outer(np.ones(np.size(u)), np.cos(v))
    
    ax2.plot_surface(sphere_x, sphere_y, sphere_z, alpha=0.3, color='blue')
    
    # Show tangent plane at north pole
    xx, yy = np.meshgrid(np.linspace(-0.5, 0.5, 10), np.linspace(-0.5, 0.5, 10))
    zz = np.ones_like(xx) * 1
    ax2.plot_surface(xx, yy, zz, alpha=0.7, color='red')
    
    ax2.set_title('Sphere S² with Tangent Plane')
    ax2.set_box_aspect([1,1,1])
    
    print(f"   Circle S¹: 1D manifold embedded in ℝ²")
    print(f"   Sphere S²: 2D manifold embedded in ℝ³")
    print(f"   Tangent vectors span the tangent space at each point")
    
    # 2. Swiss roll manifold
    print("\n2. Swiss Roll: A Classic ML Manifold")
    
    # Generate Swiss roll data
    n_samples = 1000
    X_swiss, color_swiss = make_swiss_roll(n_samples=n_samples, noise=0.1, random_state=42)
    
    # 3D plot
    ax3 = fig.add_subplot(3, 4, 3, projection='3d')
    ax3.scatter(X_swiss[:, 0], X_swiss[:, 1], X_swiss[:, 2], 
               c=color_swiss, cmap='viridis', s=20, alpha=0.8)
    ax3.set_title('Swiss Roll in ℝ³')
    ax3.set_xlabel('X₁')
    ax3.set_ylabel('X₂')
    ax3.set_zlabel('X₃')
    
    # Intrinsic 2D coordinates
    ax4 = fig.add_subplot(3, 4, 4)
    ax4.scatter(color_swiss, X_swiss[:, 1], c=color_swiss, cmap='viridis', s=20, alpha=0.8)
    ax4.set_title('Intrinsic 2D Coordinates')
    ax4.set_xlabel('Angle (intrinsic coordinate)')
    ax4.set_ylabel('Height')
    ax4.grid(True, alpha=0.3)
    
    print(f"   Swiss roll: 2D manifold embedded in ℝ³")
    print(f"   Intrinsic dimension: 2, Ambient dimension: 3")
    print(f"   Goal: Recover intrinsic coordinates from 3D embedding")
    
    # 3. Tangent space estimation
    print("\n3. Tangent Space Estimation")
    
    def estimate_tangent_space(X, point_idx, k=10, d=2):
        """Estimate a basis of the d-dimensional tangent space at X[point_idx] via local PCA"""
        # Find k nearest neighbors (k+1 because the query point is its own nearest neighbor)
        nbrs = NearestNeighbors(n_neighbors=k+1).fit(X)
        _, indices = nbrs.kneighbors([X[point_idx]])
        neighbors = X[indices[0][1:]]  # Exclude the point itself
        
        # Express the neighbors as displacements from the base point
        centered_neighbors = neighbors - X[point_idx]
        
        # SVD of the (ambient_dim x k) displacement matrix: columns of U are
        # principal directions, s holds the matching singular values
        U, s, _ = np.linalg.svd(centered_neighbors.T, full_matrices=False)
        
        # Tangent vectors are the first d principal directions
        # (d=2 for the Swiss roll)
        tangent_vectors = U[:, :d]
        
        return tangent_vectors, s
    
    # Select a point on the Swiss roll
    point_idx = 100
    point = X_swiss[point_idx]
    
    # Estimate tangent space
    tangent_vecs, singular_vals = estimate_tangent_space(X_swiss, point_idx, k=20)
    
    # Plot the point and its tangent space
    ax5 = fig.add_subplot(3, 4, 5, projection='3d')
    
    # Plot Swiss roll (subset for clarity)
    subset_idx = np.random.choice(n_samples, 200, replace=False)
    ax5.scatter(X_swiss[subset_idx, 0], X_swiss[subset_idx, 1], X_swiss[subset_idx, 2],
               c=color_swiss[subset_idx], cmap='viridis', s=10, alpha=0.3)
    
    # Highlight the selected point
    ax5.scatter([point[0]], [point[1]], [point[2]], c='red', s=100, marker='o')
    
    # Plot tangent vectors
    scale = 2.0
    for i in range(2):
        vec = tangent_vecs[:, i] * scale
        ax5.quiver(point[0], point[1], point[2], vec[0], vec[1], vec[2],
                  color='red', arrow_length_ratio=0.1, linewidth=3)
    
    ax5.set_title('Tangent Space Estimation')
    ax5.set_xlabel('X₁')
    ax5.set_ylabel('X₂')
    ax5.set_zlabel('X₃')
    
    print(f"   Selected point: ({point[0]:.2f}, {point[1]:.2f}, {point[2]:.2f})")
    print(f"   Singular values: {singular_vals[:3]}")
    print(f"   Ratio of 2nd to 3rd singular value: {singular_vals[1] / singular_vals[2]:.1f} (a large gap suggests a locally 2D manifold)")
    
    # 4. Local coordinates and charts
    print("\n4. Local Coordinate Charts")
    
    # Create a parametric surface (torus)
    def torus_parametrization(u, v, R=2, r=1):
        """Parametrization of a torus"""
        x = (R + r * np.cos(v)) * np.cos(u)
        y = (R + r * np.cos(v)) * np.sin(u)
        z = r * np.sin(v)
        return np.array([x, y, z])
    
    # Generate torus
    u_vals = np.linspace(0, 2*np.pi, 30)
    v_vals = np.linspace(0, 2*np.pi, 20)
    U_grid, V_grid = np.meshgrid(u_vals, v_vals)
    
    # Vectorized evaluation: the parametrization returns shape (3, N), so transpose to (N, 3)
    torus_points = torus_parametrization(U_grid.flatten(), V_grid.flatten()).T
    
    # Plot torus with coordinate lines
    ax6 = fig.add_subplot(3, 4, 6, projection='3d')
    
    # Plot surface
    X_torus = torus_points[:, 0].reshape(U_grid.shape)
    Y_torus = torus_points[:, 1].reshape(U_grid.shape)
    Z_torus = torus_points[:, 2].reshape(U_grid.shape)
    
    ax6.plot_surface(X_torus, Y_torus, Z_torus, alpha=0.3, color='blue')
    
    # Plot coordinate lines
    for i in range(0, len(u_vals), 5):
        u_line = [torus_parametrization(u_vals[i], v) for v in v_vals]
        u_line = np.array(u_line)
        ax6.plot(u_line[:, 0], u_line[:, 1], u_line[:, 2], 'r-', linewidth=2)
    
    for j in range(0, len(v_vals), 4):
        v_line = [torus_parametrization(u, v_vals[j]) for u in u_vals]
        v_line = np.array(v_line)
        ax6.plot(v_line[:, 0], v_line[:, 1], v_line[:, 2], 'g-', linewidth=2)
    
    ax6.set_title('Torus with Coordinate Lines')
    ax6.set_box_aspect([1,1,0.5])
    
    # Parameter space
    ax7 = fig.add_subplot(3, 4, 7)
    ax7.imshow(np.zeros_like(U_grid), extent=[0, 2*np.pi, 0, 2*np.pi], cmap='gray', alpha=0.3)
    
    # Draw coordinate grid
    for u in u_vals[::5]:
        ax7.axvline(x=u, color='red', linewidth=2, alpha=0.7)
    for v in v_vals[::4]:
        ax7.axhline(y=v, color='green', linewidth=2, alpha=0.7)
    
    ax7.set_xlabel('u parameter')
    ax7.set_ylabel('v parameter')
    ax7.set_title('Parameter Space (Chart)')
    ax7.grid(True, alpha=0.3)
    
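    print(f"   Torus: 2D manifold described by a periodic parametrization")
    print(f"   Parametrization: (u, v) ∈ [0, 2π) × [0, 2π) → ℝ³")
    print(f"   Chart: its inverse on the open square (0, 2π) × (0, 2π); covering the whole torus takes several charts")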
    print(f"   Red lines: constant u, Green lines: constant v")
    
    # 5. Manifold learning algorithms comparison
    print("\n5. Manifold Learning Algorithms")
    
    # Use S-curve for clearer visualization
    X_scurve, color_scurve = make_s_curve(n_samples=300, noise=0.1, random_state=42)
    
    # Apply different manifold learning algorithms
    algorithms = {
        'PCA': PCA(n_components=2),
        'Isomap': Isomap(n_components=2, n_neighbors=10),
        'LLE': LocallyLinearEmbedding(n_components=2, n_neighbors=10),
        'Spectral': SpectralEmbedding(n_components=2, n_neighbors=10)
    }
    
    embeddings = {}
    for name, algorithm in algorithms.items():
        embedding = algorithm.fit_transform(X_scurve)
        embeddings[name] = embedding
    
    # Plot original S-curve
    ax8 = fig.add_subplot(3, 4, 8, projection='3d')
    ax8.scatter(X_scurve[:, 0], X_scurve[:, 1], X_scurve[:, 2],
               c=color_scurve, cmap='viridis', s=20)
    ax8.set_title('Original S-Curve')
    ax8.set_xlabel('X₁')
    ax8.set_ylabel('X₂')
    ax8.set_zlabel('X₃')
    
    # Plot embeddings
    for i, (name, embedding) in enumerate(embeddings.items()):
        ax = fig.add_subplot(3, 4, 9 + i)
        ax.scatter(embedding[:, 0], embedding[:, 1], c=color_scurve, cmap='viridis', s=20)
        ax.set_title(f'{name} Embedding')
        ax.set_xlabel('Component 1')
        ax.set_ylabel('Component 2')
        ax.grid(True, alpha=0.3)
    
    print(f"   S-curve: 2D manifold embedded in ℝ³")
    print(f"   PCA: Linear projection (misses curvature)")
    print(f"   Isomap: Approximately preserves geodesic distances (shortest paths on a neighbor graph)")
    print(f"   LLE: Preserves local linear relationships")
    print(f"   Spectral: Uses graph Laplacian eigenvectors")
    
    plt.tight_layout()
    plt.show()
    
    return X_swiss, color_swiss, tangent_vecs, embeddings

X_swiss_demo, color_demo, tangent_demo, embeddings_demo = demonstrate_manifolds_and_tangent_spaces()
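
The difference between geodesic and straight-line distance, which Isomap relies on, can be checked by hand on a toy curve. The sketch below is not scikit-learn's Isomap implementation. It only mimics the first stage (a nearest-neighbor graph followed by shortest paths) on a half circle of radius 1, an example chosen here for illustration.

```python
import numpy as np
from scipy.sparse.csgraph import shortest_path
from sklearn.neighbors import kneighbors_graph

# 200 evenly spaced points on a half circle of radius 1 (arc length = pi)
theta = np.linspace(0.0, np.pi, 200)
arc = np.column_stack([np.cos(theta), np.sin(theta)])

# Neighbor graph whose edge weights are Euclidean distances between neighbors
graph = kneighbors_graph(arc, n_neighbors=2, mode="distance")

# Shortest paths along the graph (Dijkstra) approximate geodesic distances
geodesic = shortest_path(graph, method="D", directed=False)

euclidean_end_to_end = np.linalg.norm(arc[0] - arc[-1])   # straight through the ambient space
graph_end_to_end = geodesic[0, -1]                        # along the curve

print(f"Euclidean distance between endpoints: {euclidean_end_to_end:.4f}")
print(f"Graph geodesic distance:              {graph_end_to_end:.4f}")
print(f"True arc length (pi):                 {np.pi:.4f}")
```

The straight-line distance between the endpoints is 2, while the graph distance is close to the true arc length π. This is the quantity an Isomap-style method tries to preserve in the embedding.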