Skip to content

Commit

Permalink
fix nans in corner case of logmap in sphere
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Nov 14, 2019
1 parent 55bb7b8 commit 2b00320
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
14 changes: 9 additions & 5 deletions geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

__all__ = ["Sphere", "SphereExact"]

EPS = {torch.float32: 1e-4, torch.float64: 1e-8}
EPS = {torch.float32: 1e-4, torch.float64: 1e-7}

_sphere_doc = r"""
Sphere manifold induced by the following constraint
Expand Down Expand Up @@ -147,12 +147,16 @@ def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Ten
def logmap(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
u = self.proju(x, y - x)
dist = self.dist(x, y, keepdim=True)
# If the two points are "far apart", correct the norm.
cond = dist.gt(EPS[dist.dtype])
return torch.where(cond, u * dist / u.norm(dim=-1, keepdim=True), u)
cond = dist.gt(EPS[x.dtype])
result = torch.where(
cond, u * dist / u.norm(dim=-1, keepdim=True).clamp_min(EPS[x.dtype]), u
)
return result

def dist(self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False) -> torch.Tensor:
inner = self.inner(x, x, y, keepdim=keepdim).clamp(-0.9999, 0.9999)
inner = self.inner(x, x, y, keepdim=keepdim).clamp(
-1 + EPS[x.dtype], 1 - EPS[x.dtype]
)
return torch.acos(inner)

egrad2rgrad = proju
Expand Down
2 changes: 1 addition & 1 deletion geoopt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def ismanifold(instance, cls):
return isinstance(instance, cls)


def canonical_manifold(manifold: geoopt.Manifold):
def canonical_manifold(manifold: "geoopt.Manifold"):
"""
Get a canonical manifold.
Expand Down
3 changes: 2 additions & 1 deletion tests/test_manifold_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,9 @@ def test_logmap(unary_case):
Y = unary_case.manifold.expmap(pX, U)
Uh = unary_case.manifold.logmap(pX, Y)
Yh = unary_case.manifold.expmap(pX, Uh)

np.testing.assert_allclose(Yh, Y, atol=1e-6, rtol=1e-6)
Zero = unary_case.manifold.logmap(pX, pX)
np.testing.assert_allclose(Zero, 0.0, atol=1e-6, rtol=1e-6)
except NotImplementedError:
pytest.skip("logmap was not implemented for {}".format(unary_case.manifold))

Expand Down

0 comments on commit 2b00320

Please sign in to comment.