From 44e3fceec8aa52c46ceb2fa35b9135c0c8c3138c Mon Sep 17 00:00:00 2001 From: Max Kochurov Date: Thu, 14 Nov 2019 19:05:14 +0300 Subject: [PATCH] fix poincare dist2 --- geoopt/manifolds/poincare/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py index 27305233..3224bcc5 100644 --- a/geoopt/manifolds/poincare/__init__.py +++ b/geoopt/manifolds/poincare/__init__.py @@ -62,6 +62,11 @@ def dist( ) -> torch.Tensor: return math.dist(x, y, c=self.c, keepdim=keepdim, dim=dim) + def dist2( + self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False, dim=-1 + ) -> torch.Tensor: + return math.dist(x, y, c=self.c, keepdim=keepdim, dim=dim) ** 2 + def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor, *, dim=-1) -> torch.Tensor: return math.egrad2rgrad(x, u, c=self.c, dim=dim)