In [7]:
import geoopt
import torch

manifold = geoopt.manifolds.Lorentz(k=1.0)

x = torch.randn(100, 1000)
x = x.double()  # Ensure features are in double precision
zeros = torch.zeros(100, 999) # [N, D-1]
ones = torch.ones(100, 1) # [N, 1]
origin = torch.cat([ones, zeros], dim=-1) # [N, D]


In [2]:
def proj_tan0(u, c):
    narrowed = u.narrow(-1, 0, 1)
    vals = torch.zeros_like(u)
    vals[:, 0:1] = narrowed
    return u - vals

In [8]:
x_T0 = proj_tan0(x, c=1.0)
print("Projected x onto T0:", x_T0)

Projected x onto T0: tensor([[ 0.0000e+00, -1.0741e+00,  1.3305e-01,  ..., -4.1644e-01,
         -2.1414e+00, -6.8367e-01],
        [ 0.0000e+00, -7.1185e-02, -1.2423e+00,  ...,  4.9543e-01,
         -6.9051e-02, -6.4188e-01],
        [ 0.0000e+00,  4.0684e-01,  8.8712e-01,  ...,  7.7168e-01,
         -3.5015e+00, -9.5821e-01],
        ...,
        [ 0.0000e+00,  9.6764e-01, -2.5009e+00,  ..., -7.9908e-01,
          4.5771e-01,  2.5597e-03],
        [ 0.0000e+00,  5.6271e-01,  8.7310e-01,  ...,  6.1095e-01,
         -5.2153e-01, -1.5108e+00],
        [ 0.0000e+00,  3.0983e+00, -5.7601e-01,  ...,  6.4639e-01,
         -1.6819e-01,  1.4049e-02]], dtype=torch.float64)


In [9]:
x_T0 = manifold.proju(origin, x)
print("Projected x onto T0 using manifold.proju:", x_T0)
x = manifold.projx(manifold.expmap(origin, x_T0))
print(manifold.check_point_on_manifold(x))

Projected x onto T0 using manifold.proju: tensor([[ 0.0000e+00, -1.0741e+00,  1.3305e-01,  ..., -4.1644e-01,
         -2.1414e+00, -6.8367e-01],
        [ 0.0000e+00, -7.1185e-02, -1.2423e+00,  ...,  4.9543e-01,
         -6.9051e-02, -6.4188e-01],
        [ 0.0000e+00,  4.0684e-01,  8.8712e-01,  ...,  7.7168e-01,
         -3.5015e+00, -9.5821e-01],
        ...,
        [ 0.0000e+00,  9.6764e-01, -2.5009e+00,  ..., -7.9908e-01,
          4.5771e-01,  2.5597e-03],
        [ 0.0000e+00,  5.6271e-01,  8.7310e-01,  ...,  6.1095e-01,
         -5.2153e-01, -1.5108e+00],
        [ 0.0000e+00,  3.0983e+00, -5.7601e-01,  ...,  6.4639e-01,
         -1.6819e-01,  1.4049e-02]], dtype=torch.float64)
False


In [5]:
x[0].shape


torch.Size([100])

In [6]:
dist = manifold.dist(x[0], x[0])

print(dist)

tensor(3.1623e-08, dtype=torch.float64)
