In [1]:
import pymanopt
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from pymanopt.manifolds import Sphere
import plotly.graph_objects as go

In [45]:
dim = 3
man = Sphere(dim)

In [46]:
# Compute local metric tensor
def local_metric_tensor(x, data, rho=1e-5):
    """
    Computes the local diagonal metric tensor at point x.

    Parameters:
    - x: (D,) point at which the metric is evaluated
    - data: (N, D) array of data points
    - rho: regularization constant to prevent singularities

    Returns:
    - metric_tensor: (D,) diagonal elements of the metric tensor at point x
    """
    # Gaussian kernel bandwidth
    sigma = 0.5

    # The metric tensor is a diagonal matrix with diagonal elements
    diff_sq = (data - x)**2
    distances_sq = np.sum(diff_sq, axis=1)
    weights = np.exp(-distances_sq / (2 * sigma**2))
    weighted_cov_diag = np.sum(weights[:, np.newaxis] * diff_sq, axis=0) + rho
    metric_tensor = np.diag(1.0 / weighted_cov_diag)

    return metric_tensor

# Determinant of local metric (assumed diagonal)
def det_local_metric(x, data, rho=1e-5):
    metric_tensor = local_metric_tensor(x, data, rho)
    metric_det = np.prod(metric_tensor)
    return metric_det

In [47]:
tol = 1e-5

def gram_schmidt(vectors):
    Q = np.zeros_like(vectors)
    for i in range(len(vectors)):
        v = vectors[i]
        for j in range(i):
            proj = np.dot(v, Q[j]) / np.dot(Q[j], Q[j]) * Q[j]
            v = v - proj
        Q[i] = v / np.linalg.norm(v)
    return Q
    
def get_axis(normal):
    # Assumes it is norm 1
    Q = np.eye(dim)
    arg = np.argmax(normal)
    if normal[arg]>(1-tol) and np.sum(normal)>(1-tol):
        return np.concatenate([Q[:arg], Q[arg+1:]])
    return gram_schmidt(np.vstack([normal,Q[:-1]]))[1:]

def get_e(i,n):
    return np.array(i*[0]+[1]+(n-i-1)*[0])

In [48]:
# Compute derivative of the metric tensor numerically
def metric_tensor_jacobian(x, data, eps=1e-5):
    D = len(x)
    jac = np.zeros((D, D, D))
    for i in range(D):
        dx = np.zeros(D)
        dx[i] = eps
        jac[:, :, i] = (local_metric_tensor(x + dx, data) - local_metric_tensor(x - dx, data)) / (2 * eps)
    return jac

# Geodesic ODE function
def geodesic_ode(t, y, data):
    D = len(y) // 2
    pos, vel = y[:D], y[D:]

    M = local_metric_tensor(pos, data)
    M_inv = np.linalg.inv(M)
    M_jac = metric_tensor_jacobian(pos, data)

    # Compute Christoffel term
    christoffel = np.zeros(D)
    for i in range(D):
        christoffel[i] = vel @ M_jac[:, :, i] @ vel

    acc = -0.5 * M_inv @ christoffel
    return np.concatenate([vel, acc])

# Exponential map (Initial value problem)
def exp_map(x, v, data, t_final=1.0):
    y0 = np.concatenate([x, v])
    sol = solve_ivp(geodesic_ode, [0, t_final], y0, args=(data,), method='RK45', atol=1e-8)
    return sol.y[:len(x), -1]

# Logarithmic map (Boundary value problem solved via optimization)
def log_map(x, y_target, data, t_final=1.0):
    D = len(x)

    def objective(v_guess):
        y_final = exp_map(x, v_guess, data, t_final)
        return np.linalg.norm(y_final - y_target)**2

    res = minimize(objective, np.zeros(D), method='BFGS', options={'gtol':1e-8, 'disp':False})
    return res.x

In [55]:
def estimate_normalization_constant(data, mu, Sigma, exp, num_samples=1000):
    D = len(mu)

    # Compute normalization of Euclidean normal distribution
    Z = np.sqrt((2 * np.pi) ** D * np.linalg.det(Sigma))

    # Generate tangent space samples
    axis = get_axis(mu)
    vectors = np.random.multivariate_normal(np.zeros(D), Sigma, num_samples)
    tangent_vectors = vectors@axis

    # Perform Monte Carlo integration
    metric_sum = 0
    for v in tangent_vectors:
        x_s = exp(mu, v, data)
        metric_det = np.abs(np.linalg.det(local_metric_tensor(x_s, data)))
        metric_sum += np.sqrt(metric_det)

    C_hat = Z * metric_sum / num_samples

    return C_hat

In [50]:
sigma = 0.7
x = man.random_point()
ax = get_axis(x)
data = np.random.multivariate_normal(np.zeros(dim - 1), sigma**2 * np.eye(dim - 1), size=1000)
tangent_vectors = data@ax
tangent_vectors[:10]

array([[-0.0332278 , -0.1162302 ,  0.27888432],
       [-1.41482827, -1.02367815, -0.37403805],
       [ 0.00483307,  0.06926743, -0.20395488],
       [-0.31167053,  0.22296873, -1.48182758],
       [-0.28979637, -0.27018546,  0.1121964 ],
       [ 0.37309062,  0.31906908, -0.05465689],
       [-0.47267079, -0.38056677, -0.00459616],
       [ 0.23148886, -0.02912111,  0.67471172],
       [-0.37641178, -0.47829956,  0.54314869],
       [-0.33855164,  0.03061379, -0.94939398]])

In [51]:
projected_points = np.array([man.exp(x, t) for t in tangent_vectors])
print(f'projected_points: {projected_points[:10]}')
recon_vectors = np.array([man.log(x, y) for y in projected_points])
print(f'reconstructed vectors: {recon_vectors[:10]}')

projected_points: [[ 0.54926535 -0.83449632  0.04385733]
 [-0.90415565 -0.3988733  -0.15297927]
 [ 0.60063792 -0.66846073 -0.43862779]
 [-0.17893994  0.11521563 -0.97709051]
 [ 0.27728106 -0.95417149 -0.11256992]
 [ 0.89513777 -0.35818378 -0.26539169]
 [ 0.05684808 -0.97753198 -0.20297666]
 [ 0.67331924 -0.59707149  0.43605829]
 [ 0.08197234 -0.94415407  0.31914515]
 [ 0.04120247 -0.37669257 -0.92542156]]
reconstructed vectors: [[-0.0332278  -0.1162302   0.27888432]
 [-1.41482827 -1.02367815 -0.37403805]
 [ 0.00483307  0.06926743 -0.20395488]
 [-0.31167053  0.22296873 -1.48182758]
 [-0.28979637 -0.27018546  0.1121964 ]
 [ 0.37309062  0.31906908 -0.05465689]
 [-0.47267079 -0.38056677 -0.00459616]
 [ 0.23148886 -0.02912111  0.67471172]
 [-0.37641178 -0.47829956  0.54314869]
 [-0.33855164  0.03061379 -0.94939398]]


In [52]:
tangent_vector = tangent_vectors[np.random.randint(0, len(tangent_vectors))]
print(f'tangent vector: {tangent_vector}')

y = exp_map(x, tangent_vector, projected_points)
y_auto = man.exp(x, tangent_vector)
print(f'y (our implementation): {y}')
print(f'y (pymanopt implementation): {y_auto}')

recon_vector = log_map(x, y, projected_points)
print(f'reconstructed vector: {recon_vector}')

tangent vector: [-0.63675844 -0.72356075  0.65184608]
y (our implementation): [ 0.01451016 -1.38283575  0.36971809]
y (pymanopt implementation): [-0.26091793 -0.8698749   0.41861616]
reconstructed vector: [-0.62505274 -0.68970806  0.6678184 ]


In [53]:
# Sphere mesh
theta, phi = np.mgrid[0:2*np.pi:50j, 0:np.pi:25j]
xs = np.cos(theta)*np.sin(phi)
ys = np.sin(theta)*np.sin(phi)
zs = np.cos(phi)

tangent_points = tangent_vectors + x

fig = go.Figure()

# Sphere surface
fig.add_trace(go.Surface(x=xs, y=ys, z=zs, opacity=0.3, colorscale='Viridis', showscale=False))

# Projected points
fig.add_trace(go.Scatter3d(x=tangent_points[:,0], y=tangent_points[:,1], z=tangent_points[:,2],
                           mode='markers', marker=dict(size=3, color='red'), name='Tangent Points'))

# Base point
fig.add_trace(go.Scatter3d(x=[x[0]], y=[x[1]], z=[x[2]],
                           mode='markers', marker=dict(size=8, color='black'), name='Base Point'))

fig.update_layout(
    scene=dict(
        aspectmode='data',
    ),
    width=700, 
    height=700,
    title='Tangent Points and Sphere Mesh'
)

fig.show()

In [54]:
fig = go.Figure()

# Sphere surface
fig.add_trace(go.Surface(x=xs, y=ys, z=zs, opacity=0.3, colorscale='Viridis', showscale=False))

# Projected points
fig.add_trace(go.Scatter3d(x=projected_points[:,0], y=projected_points[:,1], z=projected_points[:,2],
                           mode='markers', marker=dict(size=3, color='red'), name='Projected Points'))

# Base point
fig.add_trace(go.Scatter3d(x=[x[0]], y=[x[1]], z=[x[2]],
                           mode='markers', marker=dict(size=8, color='black'), name='Base Point'))

fig.update_layout(
    scene=dict(
        aspectmode="cube"
    ), 
    width=700, 
    height=700,
    title='Interactive Projection on Sphere'
)
fig.show()