Skip to content

Commit

Permalink
Merge a5ab672 into b2bdd93
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed May 7, 2019
2 parents b2bdd93 + a5ab672 commit c0b8283
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 89 deletions.
2 changes: 1 addition & 1 deletion geoopt/manifolds/poincare/__init__.py
Expand Up @@ -65,7 +65,7 @@ def _projx(self, x):
return math.project(x, c=self.c)

def _proju(self, x, u):
return math.clip_tangent(x, u, c=self.c)
return u

def _inner(self, x, u, v, keepdim):
return math.inner(x, u, v, c=self.c, keepdim=keepdim)
Expand Down
131 changes: 44 additions & 87 deletions geoopt/manifolds/poincare/math.py
Expand Up @@ -10,6 +10,10 @@
import torch.jit


MIN_NORM = 1e-15
BALL_EPS = {torch.float32: 4e-3, torch.float64: 1e-5}


def tanh(x):
return x.clamp(-15, 15).tanh()

Expand All @@ -35,7 +39,7 @@ class Arsinh(torch.autograd.Function):
def forward(ctx, x):
ctx.save_for_backward(x)
z = x.double()
return (z + torch.sqrt_(1 + z.pow(2))).clamp_min_(1e-15).log_().to(x.dtype)
return (z + torch.sqrt_(1 + z.pow(2))).clamp_min_(MIN_NORM).log_().to(x.dtype)

@staticmethod
def backward(ctx, grad_output):
Expand All @@ -51,7 +55,7 @@ def arsinh(x):
return Arsinh.apply(x)


def project(x, *, c=1.0, dim=-1):
def project(x, *, c=1.0, dim=-1, eps=None):
r"""
Safe projection on the manifold for numerical stability.
Expand All @@ -63,27 +67,22 @@ def project(x, *, c=1.0, dim=-1):
ball negative curvature
dim : int
reduction dimension to compute norm
eps : float
stability parameter, uses default for dtype if not provided
Returns
-------
tensor
projected vector on the manifold
"""
return _project(x, c, dim)


@torch.jit.script
def _max_norm(x):
if x.dtype == torch.float32:
maxnorm = torch.full((), 1 - 4e-3, dtype=x.dtype, device=x.device)
else:
maxnorm = torch.full((), 1 - 1e-5, dtype=x.dtype, device=x.device)
return maxnorm
return _project(x, c, dim, eps)


def _project(x, c, dim: int = -1):
norm = x.norm(dim=dim, keepdim=True, p=2)
maxnorm = _max_norm(x) / (c ** 0.5)
def _project(x, c, dim: int = -1, eps: float = None):
norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM)
if eps is None:
eps = BALL_EPS[x.dtype]
maxnorm = (1 - eps) / (c ** 0.5)
cond = norm > maxnorm
projected = x / norm * maxnorm
return torch.where(cond, projected, x)
Expand Down Expand Up @@ -117,7 +116,7 @@ def lambda_x(x, *, c=1.0, keepdim=False, dim=-1):


def _lambda_x(x, c, keepdim: bool = False, dim: int = -1):
return 2 / (1 - c * x.pow(2).sum(dim=dim, keepdim=keepdim))
return 2 / (1 - c * x.pow(2).sum(dim=dim, keepdim=keepdim)).clamp_min(MIN_NORM)


def inner(x, u, v, *, c=1.0, keepdim=False, dim=-1):
Expand Down Expand Up @@ -252,14 +251,24 @@ def mobius_add(x, y, *, c=1.0, dim=-1):


def _mobius_add(x, y, c, dim=-1):
y = y + 1e-15
x2 = x.pow(2).sum(dim=dim, keepdim=True)
y2 = y.pow(2).sum(dim=dim, keepdim=True)
xy = (x * y).sum(dim=dim, keepdim=True)
num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
# avoid division by zero in this way
return num / (denom + 1e-15)
# minimize denom (omit c to simplify the notation)
# 1)
# {d(denom)/d(x) = 2 y + 2x * <y, y> = 0
# {d(denom)/d(y) = 2 x + 2y * <x, x> = 0
# 2)
# {y + x * <y, y> = 0
# {x + y * <x, x> = 0
# 3)
# {- y/<y, y> = x
# {- x/<x, x> = y
# 4)
# minimum = 1 - 2 <y, y>/<y, y> + <y, y>/<y, y> = 0
return num / denom.clamp_min(MIN_NORM)


def mobius_sub(x, y, *, c=1.0, dim=-1):
Expand Down Expand Up @@ -339,13 +348,12 @@ def mobius_coadd(x, y, *, c=1.0, dim=-1):


def _mobius_coadd(x, y, c, dim: int = -1):
y = y + 1e-15
x2 = x.pow(2).sum(dim=dim, keepdim=True)
y2 = y.pow(2).sum(dim=dim, keepdim=True)
num = (1 - c * y2) * x + (1 - c * x2) * y
denom = 1 - c ** 2 * x2 * y2
# avoid division by zero in this way
return num / (denom + 1e-15)
return num / denom.clamp_min(MIN_NORM)


def mobius_cosub(x, y, *, c=1.0, dim=-1):
Expand Down Expand Up @@ -434,8 +442,7 @@ def mobius_scalar_mul(r, x, *, c=1.0, dim=-1):


def _mobius_scalar_mul(r, x, c, dim: int = -1):
x = x + 1e-15
x_norm = x.norm(dim=dim, keepdim=True, p=2)
x_norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM)
sqrt_c = c ** 0.5
res_c = tanh(r * artanh(sqrt_c * x_norm)) * x / (x_norm * sqrt_c)
return res_c
Expand Down Expand Up @@ -509,51 +516,6 @@ def _dist0(x, c, keepdim: bool = False, dim: int = -1):
return dist_c * 2 / sqrt_c


def clip_tangent(x, u, *, c=1.0, dim=-1):
r"""
Project tangent vector to reasonable values that do not exceed
maximum allowed (vector norm allowing to travel to the opposite pole)
.. math::
\operatorname{maxnorm}_x = d_{c}(\operatorname{proj}(-\infty), \operatorname{proj}(\infty)) / \lambda_x^c
Parameters
----------
x : tensor
point on Poincare ball
u : tensor
tangent vector
c : float|tensor
ball negative curvature
dim : int
reduction dimension to compute norm
Returns
-------
tensor
same tangent vector with reasonable values
"""
return _clip_tangent(x, u, c, dim=dim)


def _clip_tangent(x, u, c, dim: int = -1):
# get the almost infinite vector estimate
# this is the norm of travel vector to the opposite pole
s = x.size(dim)
p = torch.ones((s,), dtype=x.dtype, device=x.device)
p = p / s ** 0.5 / (c ** 0.5)
p = _project(p, c, dim=dim)
# normalize its length based on x
maxnorm = _dist(p, -p, c, keepdim=True, dim=dim) / _lambda_x(
x, c, keepdim=True, dim=dim
)
norm = u.norm(dim=dim, keepdim=True, p=2)
cond = norm > maxnorm
projected = u / norm * maxnorm
return torch.where(cond, projected, u)


def geodesic(t, x, y, *, c=1.0, dim=-1):
r"""
Geodesic (the shortest) path connecting :math:`x` and :math:`y`.
Expand Down Expand Up @@ -659,9 +621,8 @@ def expmap(x, u, *, c=1.0, dim=-1):


def _expmap(x, u, c, dim: int = -1):
u = u + 1e-15
sqrt_c = c ** 0.5
u_norm = u.norm(dim=dim, p=2, keepdim=True)
u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM)
second_term = (
tanh(sqrt_c / 2 * _lambda_x(x, c, keepdim=True, dim=dim) * u_norm)
* u
Expand Down Expand Up @@ -697,9 +658,8 @@ def expmap0(u, *, c=1.0, dim=-1):


def _expmap0(u, c, dim: int = -1):
u = u + 1e-15
sqrt_c = c ** 0.5
u_norm = u.norm(dim=dim, p=2, keepdim=True)
u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM)
gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
return gamma_1

Expand Down Expand Up @@ -735,7 +695,7 @@ def geodesic_unit(t, x, u, *, c=1.0, dim=-1):

def _geodesic_unit(t, x, u, c, dim: int = -1):
sqrt_c = c ** 0.5
u_norm = u.norm(dim=dim, p=2, keepdim=True)
u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM)
second_term = tanh(sqrt_c / 2 * t) * u / (sqrt_c * u_norm)
gamma_1 = _mobius_add(x, second_term, c, dim=dim)
return gamma_1
Expand Down Expand Up @@ -779,7 +739,7 @@ def logmap(x, y, *, c=1.0, dim=-1):

def _logmap(x, y, c, dim: int = -1):
sub = _mobius_add(-x, y, c, dim=dim)
sub_norm = sub.norm(dim=dim, p=2, keepdim=True)
sub_norm = sub.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM)
lam = _lambda_x(x, c, keepdim=True, dim=dim)
sqrt_c = c ** 0.5
return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm
Expand Down Expand Up @@ -819,8 +779,7 @@ def logmap0(y, *, c=1.0, dim=-1):

def _logmap0(y, c, dim: int = -1):
sqrt_c = c ** 0.5
y = y + 1e-15
y_norm = y.norm(dim=dim, p=2, keepdim=True)
y_norm = y.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM)
return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm)


Expand Down Expand Up @@ -861,14 +820,13 @@ def _mobius_matvec(m, x, c, dim: int = -1):
raise RuntimeError(
"broadcasted Mobius matvec is supported for the last dim only"
)
x = x + 1e-15
x_norm = x.norm(dim=dim, keepdim=True, p=2)
x_norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM)
sqrt_c = c ** 0.5
if dim != -1 or m.dim() == 2:
mx = torch.tensordot(x, m, dims=([dim], [1]))
else:
mx = torch.matmul(m, x.unsqueeze(-1)).squeeze(-1)
mx_norm = mx.norm(dim=dim, keepdim=True, p=2)
mx_norm = mx.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM)
res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
cond = (mx == 0).prod(dim=dim, keepdim=True, dtype=torch.uint8)
res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
Expand Down Expand Up @@ -907,11 +865,10 @@ def mobius_pointwise_mul(w, x, *, c=1.0, dim=-1):


def _mobius_pointwise_mul(w, x, c, dim: int = -1):
x = x + 1e-15
x_norm = x.norm(dim=dim, keepdim=True, p=2)
x_norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM)
sqrt_c = c ** 0.5
wx = w * x
wx_norm = wx.norm(dim=dim, keepdim=True, p=2)
wx_norm = wx.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM)
res_c = tanh(wx_norm / x_norm * artanh(sqrt_c * x_norm)) * wx / (wx_norm * sqrt_c)
cond = (wx == 0).prod(dim=dim, keepdim=True, dtype=torch.uint8)
res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
Expand Down Expand Up @@ -1144,14 +1101,14 @@ def dist2plane(x, p, a, *, c=1.0, keepdim=False, signed=False, dim=-1):
def _dist2plane(x, a, p, c, keepdim: bool = False, signed: bool = False, dim: int = -1):
sqrt_c = c ** 0.5
diff = _mobius_add(-p, x, c, dim=dim)
diff_norm2 = diff.pow(2).sum(dim=dim, keepdim=keepdim)
diff_norm2 = diff.pow(2).sum(dim=dim, keepdim=keepdim).clamp_min(MIN_NORM)
sc_diff_a = (diff * a).sum(dim=dim, keepdim=keepdim)
if not signed:
sc_diff_a = sc_diff_a.abs()
a_norm = a.norm(dim=dim, keepdim=keepdim, p=2)
a_norm = a.norm(dim=dim, keepdim=keepdim, p=2).clamp_min(MIN_NORM)
num = 2 * sqrt_c * sc_diff_a
denom = (1 - c * diff_norm2) * a_norm
return arsinh(num / (denom + 1e-15)) / sqrt_c
return arsinh(num / denom.clamp_min(MIN_NORM)) / sqrt_c


def gyration(a, b, u, *, c=1.0, dim=-1):
Expand Down Expand Up @@ -1222,7 +1179,7 @@ def _gyration(u, v, w, c, dim: int = -1):
a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw
b = -c2 * vw * u2 - c * uw
d = 1 + 2 * c * uv + c2 * u2 * v2
return w + 2 * (a * u + b * v) / (d + 1e-15)
return w + 2 * (a * u + b * v) / d.clamp_min(MIN_NORM)


def parallel_transport(x, y, v, *, c=1.0, dim=-1):
Expand Down Expand Up @@ -1313,7 +1270,7 @@ def parallel_transport0(y, v, *, c=1.0, dim=-1):


def _parallel_transport0(y, v, c, dim: int = -1):
return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True))
return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM)


def egrad2rgrad(x, grad, *, c=1.0, dim=-1):
Expand Down
2 changes: 2 additions & 0 deletions geoopt/optim/radam.py
Expand Up @@ -198,6 +198,8 @@ def stabilize_group(self, group):
if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
continue
state = self.state[p]
if not state: # due to None grads
continue
manifold = p.manifold
exp_avg = state["exp_avg"]
p.set_(manifold.projx(p))
Expand Down
2 changes: 2 additions & 0 deletions geoopt/optim/rsgd.py
Expand Up @@ -183,6 +183,8 @@ def stabilize_group(self, group):
p.set_(manifold.projx(p))
if momentum > 0:
param_state = self.state[p]
if not param_state: # due to None grads
continue
if "momentum_buffer" in param_state:
buf = param_state["momentum_buffer"]
buf.set_(manifold.proju(p, buf))
Expand Down
1 change: 0 additions & 1 deletion tests/test_poincare_math.py
Expand Up @@ -331,7 +331,6 @@ def test_parallel_transport_a_b(a, b, c):

def test_add_infinity_and_beyond(a, b, c):
infty = b * 10000000
infty = poincare.math.clip_tangent(a, infty, c=c)
for i in range(100):
z = poincare.math.expmap(a, infty, c=c)
z = poincare.math.project(z, c=c)
Expand Down

0 comments on commit c0b8283

Please sign in to comment.