In [1]:
import pymanopt
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from pymanopt.manifolds import Sphere
import plotly.graph_objects as go
from abc import ABC, abstractmethod

In [2]:
dim = 3

In [3]:
class Manifold(ABC):
    """
    Abstract base class for manifolds with exp and log maps.
    """

    @abstractmethod
    def exp(self, x, v):
        """Exponential map: tangent space -> manifold."""
        pass

    @abstractmethod
    def log(self, x, y):
        """Logarithmic map: manifold -> tangent space."""
        pass

In [4]:
# Wrapper around pymanopt Sphere manifold
class PymanoptSphereManifold(Manifold):
    def __init__(self, dim):
        self.manifold = Sphere(dim)

    def exp(self, x, v):
        return self.manifold.exp(x, v)

    def log(self, x, y):
        return self.manifold.log(x, y)

    def __getattr__(self, attr):
        return getattr(self.manifold, attr)

In [5]:
# Compute local metric tensor
def local_metric_tensor(x, data, rho=1e-5):
    """
    Computes the local diagonal metric tensor at point x.

    Parameters:
    - x: (D,) point at which the metric is evaluated
    - data: (N, D) array of data points
    - rho: regularization constant to prevent singularities

    Returns:
    - metric_tensor: (D,) diagonal elements of the metric tensor at point x
    """
    # Gaussian kernel bandwidth
    sigma = 0.5

    # The metric tensor is a diagonal matrix with diagonal elements
    diff_sq = (data - x)**2
    distances_sq = np.sum(diff_sq, axis=1)
    weights = np.exp(-distances_sq / (2 * sigma**2))
    weighted_cov_diag = np.sum(weights[:, np.newaxis] * diff_sq, axis=0) + rho
    metric_tensor = np.diag(1.0 / weighted_cov_diag)

    return metric_tensor

# Determinant of local metric (assumed diagonal)
def det_local_metric(x, data, rho=1e-5):
    metric_tensor = local_metric_tensor(x, data, rho)
    metric_det = np.prod(metric_tensor)
    return metric_det

In [6]:
tol = 1e-5

def gram_schmidt(vectors):
    Q = np.zeros_like(vectors)
    for i in range(len(vectors)):
        v = vectors[i]
        for j in range(i):
            proj = np.dot(v, Q[j]) / np.dot(Q[j], Q[j]) * Q[j]
            v = v - proj
        Q[i] = v / np.linalg.norm(v)
    return Q
    
def get_axis(normal):
    # Assumes it is norm 1
    Q = np.eye(dim)
    arg = np.argmax(normal)
    if normal[arg]>(1-tol) and np.sum(normal)>(1-tol):
        return np.concatenate([Q[:arg], Q[arg+1:]])
    return gram_schmidt(np.vstack([normal,Q[:-1]]))[1:]

def get_e(i,n):
    return np.array(i*[0]+[1]+(n-i-1)*[0])

In [7]:
# Compute derivative of the metric tensor numerically
def metric_tensor_jacobian(x, data, eps=1e-5):
    D = len(x)
    jac = np.zeros((D, D, D))
    for i in range(D):
        dx = np.zeros(D)
        dx[i] = eps
        jac[:, :, i] = (local_metric_tensor(x + dx, data) - local_metric_tensor(x - dx, data)) / (2 * eps)
    return jac

# Geodesic ODE function
def geodesic_ode(t, y, data):
    D = len(y) // 2
    pos, vel = y[:D], y[D:]

    M = local_metric_tensor(pos, data)
    M_inv = np.linalg.inv(M)
    M_jac = metric_tensor_jacobian(pos, data)

    # Compute Christoffel term
    christoffel = np.zeros(D)
    for i in range(D):
        christoffel[i] = vel @ M_jac[:, :, i] @ vel

    acc = -0.5 * M_inv @ christoffel
    return np.concatenate([vel, acc])

# Exponential map (Initial value problem)
def exp_map(x, v, data, t_final=1.0):
    y0 = np.concatenate([x, v])
    sol = solve_ivp(geodesic_ode, [0, t_final], y0, args=(data,), method='RK45', atol=1e-8)
    return sol.y[:len(x), -1]

# Logarithmic map (Boundary value problem solved via optimization)
def log_map(x, y_target, data, t_final=1.0):
    D = len(x)

    def objective(v_guess):
        y_final = exp_map(x, v_guess, data, t_final)
        return np.linalg.norm(y_final - y_target)**2

    res = minimize(objective, np.zeros(D), method='BFGS', options={'gtol':1e-8, 'disp':False})
    return res.x

In [22]:
class LearningManifold(Manifold):
    def __init__(self, data):
        self.data = data
        self.t_final = 1.0

    def exp(self, x, v):
        y0 = np.concatenate([x, v])
        sol = solve_ivp(geodesic_ode, [0, self.t_final], y0, args=(self.data,), method='RK45', atol=1e-8)
        return sol.y[:len(x), -1]

    def log(self, x, y):
        D = len(x)

        def objective(v_guess):
            y_final = exp_map(x, v_guess, self.data, self.t_final)
            return np.linalg.norm(y_final - y)**2

        res = minimize(objective, np.zeros(D), method='BFGS', options={'gtol':1e-8, 'disp':False})
        return res.x

In [44]:
def plot_tangent_points(x, tangent_vectors):    
    # Sphere mesh
    theta, phi = np.mgrid[0:2*np.pi:50j, 0:np.pi:25j]
    xs = np.cos(theta)*np.sin(phi)
    ys = np.sin(theta)*np.sin(phi)
    zs = np.cos(phi)

    tangent_points = tangent_vectors + x

    fig = go.Figure()

    # Sphere surface
    fig.add_trace(go.Surface(x=xs, y=ys, z=zs, opacity=0.3, colorscale='Viridis', showscale=False))

    # Projected points
    fig.add_trace(go.Scatter3d(x=tangent_points[:,0], y=tangent_points[:,1], z=tangent_points[:,2],
                            mode='markers', marker=dict(size=3, color='red'), name='Tangent Points'))

    # Base point
    fig.add_trace(go.Scatter3d(x=[x[0]], y=[x[1]], z=[x[2]],
                            mode='markers', marker=dict(size=8, color='black'), name='Base Point'))

    fig.update_layout(
        scene=dict(
            aspectmode='data',
        ),
        width=700, 
        height=700,
        title='Tangent Points and Sphere Mesh'
    )

    fig.show()

In [45]:
def plot_projected_sphere(x, projected_points):
    theta, phi = np.mgrid[0:2*np.pi:50j, 0:np.pi:25j]
    xs = np.cos(theta)*np.sin(phi)
    ys = np.sin(theta)*np.sin(phi)
    zs = np.cos(phi)

    fig = go.Figure()

    # Sphere surface
    fig.add_trace(go.Surface(x=xs, y=ys, z=zs, opacity=0.3, colorscale='Viridis', showscale=False))

    # Projected points
    fig.add_trace(go.Scatter3d(x=projected_points[:,0], y=projected_points[:,1], z=projected_points[:,2],
                            mode='markers', marker=dict(size=3, color='red'), name='Projected Points'))

    # Base point
    fig.add_trace(go.Scatter3d(x=[x[0]], y=[x[1]], z=[x[2]],
                            mode='markers', marker=dict(size=8, color='black'), name='Base Point'))

    fig.update_layout(
        scene=dict(
            aspectmode="cube"
        ), 
        width=700, 
        height=700,
        title='Interactive Projection on Sphere'
    )
    fig.show()


In [57]:
def plot_mean(target, xs):
    theta, phi = np.mgrid[0:2*np.pi:50j, 0:np.pi:25j]
    xs = np.cos(theta)*np.sin(phi)
    ys = np.sin(theta)*np.sin(phi)
    zs = np.cos(phi)

    fig = go.Figure()

    # Sphere surface
    fig.add_trace(go.Surface(x=xs, y=ys, z=zs, opacity=0.3, colorscale='Viridis', showscale=False))

    # Target point
    fig.add_trace(go.Scatter3d(x=[target[0]], y=[target[1]], z=[target[2]],
                            mode='markers', marker=dict(size=8, color='red'), name='Target Mean'))
    
    # Trail of previous means 
    for i, x in enumerate(xs[:-1]):
        opacity = i / len(xs + 1)
        fig.add_trace(go.Scatter3d(x=[x[0]], y=[x[1]], z=[x[2]],
                            mode='markers', marker=dict(size=2, color='black', opacity=opacity), showlegend=False))

    # Current mean
    fig.add_trace(go.Scatter3d(x=[x[0]], y=[x[1]], z=[x[2]],
                            mode='markers', marker=dict(size=8, color='black', opacity=opacity), name='Current Mean'))
    fig.update_layout(
        scene=dict(
            aspectmode="cube"
        ), 
        width=700, 
        height=700,
        title=f'Target vs Achieved Mean'
    )
    fig.show()


In [46]:
def estimate_normalization_constant(data, mu, Sigma, manifold, num_samples=1000):
    D = len(mu) - 1

    # Compute normalization of Euclidean normal distribution
    Z = np.sqrt((2 * np.pi) ** D * np.linalg.det(Sigma))

    # Generate tangent space samples
    axis = get_axis(mu)
    vectors = np.random.multivariate_normal(np.zeros(D), Sigma, num_samples)
    tangent_vectors = vectors@axis

    # Perform Monte Carlo integration
    metric_sum = np.sum(compute_vol(mu, tangent_vectors, manifold, data))

    C_hat = Z * metric_sum / num_samples

    return C_hat

def compute_vol(mu, vs, manifold, data):
    metric_tensors = np.array([local_metric_tensor(manifold.exp(mu, v), data) for v in vs])
    return np.sqrt(np.abs(np.linalg.det(metric_tensors)))

In [49]:
def random_cov():
    A = np.random.rand(dim-1, dim-1)
    return np.dot(A, A.transpose())

def extrinsic_to_log(manifold, mu, x, ax):
    point = manifold.log(mu, x)
    return np.dot(point, ax.T)

def objective_grad_mu(points, mu, Sigma, ax, manifold, S=100):
    d = manifold.dim
    samples = np.random.multivariate_normal(np.zeros(d), Sigma, S)
    vs = samples@ax
    ms = compute_vol(mu, vs, manifold, points)
    z = np.sqrt((2*np.pi)**d*np.linalg.det(Sigma))
    grad = (np.array([extrinsic_to_log(manifold,mu,p,ax) for p in points])
             .mean(0)-z*(ms.reshape(1,-1)@samples)/
             (S * estimate_normalization_constant(points, mu, Sigma, manifold)))
    return grad

def objective_grad_A(points, mu, Sigma, axis, manifold, S=100):
    d = manifold.dim
    vals, vecs = np.linalg.eig(Sigma)
    A = (vecs@np.diag(1/np.sqrt(vals))).T
    samples = np.random.multivariate_normal(np.zeros(d), Sigma, S)
    vs = samples@axis
    ms = compute_vol(mu, vs, manifold, points)
    term2 = np.zeros((d, d)).astype(dtype='float64')
    for m,s in zip(ms,samples):
        term2 += m*((s.reshape(-1,1))@(s.reshape(1,-1)))
    term2 *= np.sqrt((2*np.pi)**d*np.linalg.det(Sigma))
    term2 /= (S * estimate_normalization_constant(points, mu, Sigma, manifold))
    term1 = np.zeros((d, d)).astype(dtype='float64')
    for p,v in zip(points, vs):
        log = extrinsic_to_log(manifold,mu,p,axis)
        term1 += (log.reshape(-1,1))@(log.reshape(1,-1))
    term1 /= len(points)
    return A@(term1-term2)

def objective(points, mu, Sigma, manifold, axis):
    d = manifold.dim
    result = 0
    inv = np.linalg.inv(Sigma)
    for p in points:
        log = extrinsic_to_log(manifold, mu, p, axis)
        result += np.dot(log, inv@log)
    result /= 2*len(points)
    return result + np.log(estimate_normalization_constant(points, mu, Sigma, manifold))

def convergence_criteria(points, manifold, e=1e-4):
    x = lambda mu0, Sigma0, mu, Sigma, axis: (objective(points, mu, Sigma, manifold, axis)-
                                         objective(points, mu0, Sigma0, manifold, axis))
    return (lambda mu0, Sigma0, mu, Sigma, axis: np.abs(x(mu, Sigma, mu0, Sigma0, axis))>e)

def mle_manifold(points, manifold, step_size_mu=1e-2, step_size_A=1e-2):
    d = manifold.dim
    Sigma0 = random_cov()
    mu0 = manifold.random_point()
    Sigma = random_cov()
    # Sigma = 0.1**2 * np.eye(d)
    mu = manifold.random_point()
    axis = get_axis(mu)
    criterion = convergence_criteria(points, manifold)
    count, max_loops = 0, 100
    mus = []
    mus.append(mu)
    while criterion(mu0, Sigma0, mu, Sigma, axis) and count < max_loops:
        # norm_const = estimate_normalization_constant(mu, Sigma, manifold)
        grad_mu = objective_grad_mu(points, mu, Sigma, axis, manifold)@axis
        mu0, Sigma0 = mu, Sigma
        mu = manifold.exp(mu0, step_size_mu*grad_mu)[0]
        axis = get_axis(mu)
        # norm_const = estimate_normalization_constant(mu, Sigma0, manifold)
        vals, vecs = np.linalg.eig(Sigma)
        A = (vecs@np.diag(1/np.sqrt(vals))).T
        grad_A = objective_grad_A(points, mu, Sigma0, axis, manifold)
        A -= step_size_A*grad_A
        Sigma = np.linalg.inv(A.T@A)
        
        mus.append(mu)
        count += 1

    return mus, Sigma

In [11]:
sigma = 0.7
sphere = PymanoptSphereManifold(dim)
x = sphere.random_point()
ax = get_axis(x)
data = np.random.multivariate_normal(np.zeros(dim - 1), sigma**2 * np.eye(dim - 1), size=1000)
tangent_vectors = data@ax
tangent_vectors[:10]

array([[ 0.19603225, -0.89002908, -0.51496046],
       [ 0.2948223 ,  0.0883736 ,  0.01150023],
       [-0.11995227,  0.23705154,  0.1456974 ],
       [-0.43741723, -0.70185013, -0.33143058],
       [-0.95701767, -0.81495472, -0.32820837],
       [-0.54954556, -0.59915402, -0.26072467],
       [ 0.78628798, -0.10842027, -0.15887067],
       [ 1.23566861,  0.1687388 , -0.06287457],
       [ 0.19593861, -0.1951525 , -0.13220072],
       [-0.18132178,  0.05382602,  0.05251289]])

In [13]:
projected_points = np.array([sphere.exp(x, t) for t in tangent_vectors])
print(f'projected_points: {projected_points[:10]}')
recon_vectors = np.array([sphere.log(x, y) for y in projected_points])
print(f'reconstructed vectors: {recon_vectors[:10]}')

projected_points: [[ 0.10721073 -0.49621689 -0.86155363]
 [ 0.18556381  0.5439665  -0.81833155]
 [-0.22290949  0.691145   -0.68748087]
 [-0.45082277 -0.31113059 -0.83663409]
 [-0.73910021 -0.47562163 -0.47698527]
 [-0.55731404 -0.21384221 -0.80229207]
 [ 0.62745037  0.2338706  -0.7427049 ]
 [ 0.90391452  0.27997602 -0.32334498]
 [ 0.08821639  0.26508707 -0.96018056]
 [-0.28783671  0.5238214  -0.80172387]]
reconstructed vectors: [[ 0.19603225 -0.89002908 -0.51496046]
 [ 0.2948223   0.0883736   0.01150023]
 [-0.11995227  0.23705154  0.1456974 ]
 [-0.43741723 -0.70185013 -0.33143058]
 [-0.95701767 -0.81495472 -0.32820837]
 [-0.54954556 -0.59915402 -0.26072467]
 [ 0.78628798 -0.10842027 -0.15887067]
 [ 1.23566861  0.1687388  -0.06287457]
 [ 0.19593861 -0.1951525  -0.13220072]
 [-0.18132178  0.05382602  0.05251289]]


In [15]:
tangent_vector = tangent_vectors[np.random.randint(0, len(tangent_vectors))]
print(f'tangent vector: {tangent_vector}')

y = exp_map(x, tangent_vector, projected_points)
y_auto = sphere.exp(x, tangent_vector)
print(f'y (our implementation): {y}')
print(f'y (pymanopt implementation): {y_auto}')

recon_vector = log_map(x, y, projected_points)
print(f'reconstructed vector: {recon_vector}')

tangent vector: [0.29276603 1.16093737 0.60254296]
y (our implementation): [ 0.13210325  1.55833971 -0.32857802]
y (pymanopt implementation): [0.18757455 0.9527805  0.23879929]
reconstructed vector: [0.23303758 1.04977982 0.54979234]


In [30]:
plot_tangent_points(x, tangent_vectors)

In [40]:
plot_projected_sphere(x, projected_points)

In [56]:
sigma = 0.1
sphere = PymanoptSphereManifold(dim)
mean = sphere.random_point()
axis = get_axis(x)
data = np.random.multivariate_normal(np.zeros(dim - 1), sigma**2 * np.eye(dim - 1), size=1000)
tangent_vectors = data@axis
projected_points = np.array([sphere.exp(x, t) for t in tangent_vectors])

print(f'x (mean): {x}')
print(f'Sigma: {sigma**2 * np.eye(dim - 1)}')

mus, opt_cov = mle_manifold(projected_points, sphere, 0.01, 0.01)
opt_mu = mus[-1]
print(f'optimal mean: {opt_mu}')
print(f'optimal covariance: {opt_cov}')

x (mean): [-0.10978552  0.47954988 -0.87061992]
Sigma: [[0.01 0.  ]
 [0.   0.01]]
optimal mean: [-0.57144377  0.46490498 -0.67625097]
optimal covariance: [[0.0758472  0.04244441]
 [0.04244441 0.17312804]]


In [58]:
plot_mean(x, mus)