In [24]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

class SphereManifold:
    def __init__(self, dim=2):
        self.dim = dim
    
    def project_to_tangent_space(self, x, v):
        """
        Project a vector v onto the tangent space at point x on the sphere.
        This is done by subtracting the component of v that is parallel to x.
        """
        return v - np.dot(x, v) * x
    
    def retraction(self, x, eta):
        """
        Retract a tangent vector eta back to the sphere.
        This is done by moving in the direction of eta and then normalizing.
        """
        y = x + eta
        return y / np.linalg.norm(y)
    
    def gradient(self, f, x, eps=1e-8):
        """
        Compute the Riemannian gradient of a function f at point x on the sphere.
        The gradient is computed using central finite differences.
        """
        grad = np.zeros_like(x)
        for i in range(len(x)):
            e_i = np.zeros_like(x)
            e_i[i] = 1
            grad[i] = (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)
        return self.project_to_tangent_space(x, grad)
    
    def random_point(self):
        """
        Generate a random point on the sphere.
        """
        x = np.random.randn(3)
        return x / np.linalg.norm(x)

sphere = SphereManifold()

x = sphere.random_point()

def plot_sphere_with_point(point):
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    
    x_sphere = np.outer(np.cos(u), np.sin(v))
    y_sphere = np.outer(np.sin(u), np.sin(v))
    z_sphere = np.outer(np.ones(np.size(u)), np.cos(v))
    
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(x_sphere, y_sphere, z_sphere, color='b', alpha=0.6, rstride=5, cstride=5, linewidth=0)

    ax.scatter(point[0], point[1], point[2], color='r', s=100)

    ax.set_title("2D Sphere $S^2$ with a Random Point")
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_box_aspect([1, 1, 1])  
    
    plt.show()

#plot_sphere_with_point(x)


In [20]:
from scipy.linalg import qr
def random_point(n,p):
  A = np.random.randn(n,p)
  Q,_ = qr(A)
  return Q

A = random_point(2,2)

def check_is_stiefel(X):
  size = X.shape[0]
  A = X.T @ X
  return np.allclose(A,np.eye(size),10e-3)

def project_to_tangent_space(X,Z):
  I = np.eye(X.shape[0])
  M = X.T @ Z
  skew = 1/2 * (M - M.T)
  proj = (I - X@X.T)@Z - X@skew
  return proj

B = np.eye(2)

C = project_to_tangent_space(A,B.T)

C

array([[ 1.11022302e-16, -1.42102144e-17],
       [-1.42102144e-17,  0.00000000e+00]])