Commit

add scalar mul, test props
ferrine committed Feb 14, 2019
1 parent a09bcc0 commit ffc4f6d
Showing 2 changed files with 257 additions and 28 deletions.
139 changes: 127 additions & 12 deletions geoopt/manifolds/poincare/math.py
@@ -1,4 +1,17 @@
import torch
import torch.jit


@torch.jit.script
def tanh(x):
return x.clamp(-15, 15).tanh()


# noinspection PyTypeChecker,PyUnresolvedReferences
@torch.jit.script
def artanh(x):
res = (0.5 * (torch.log(1 + x) - torch.log(1 - x))).clamp(-1 + 1e-5, 1 - 1e-5)
return res
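Both helpers clamp their numerically delicate ranges; a minimal sketch of the overflow the raw artanh formula hits near |x| = 1 (illustrative only, not part of the commit):

import torch

x = torch.tensor([0.9, 1.0], dtype=torch.float64)
naive = 0.5 * (torch.log(1 + x) - torch.log(1 - x))
print(naive)  # tensor([1.4722, inf]) -- the formula diverges at |x| = 1, which the clamp bounds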


def project(x, *, c):
@@ -22,9 +35,16 @@ def project(x, *, c):
.. [1] Hyperbolic Neural Networks, NIPS2018
https://arxiv.org/abs/1805.09112
"""
norm = x.norm(-1, keepdim=True)
maxnorm = (1 - 1e-5) / (c ** 0.5 + 1e-15)
cond = norm > maxnorm
if not isinstance(c, torch.Tensor):
c = torch.as_tensor(c).type_as(x)
return _project(x, c)


@torch.jit.script
def _project(x, c):
norm = x.norm(dim=-1, keepdim=True, p=2)
maxnorm = (1 - 1e-5) / (c ** 0.5)
cond = (norm > maxnorm) & (c > 1e-10)
projected = x / norm * maxnorm
return torch.where(cond, projected, x)
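A minimal usage sketch for the projection (assuming poincare refers to geoopt.manifolds.poincare, as in the tests below): points whose norm exceeds (1 - 1e-5)/sqrt(c) are pulled back onto that radius, everything else passes through unchanged.

import torch
from geoopt.manifolds import poincare  # assumed import; the test module's own import is outside this diff

c = 1.0
x = torch.randn(4, 10, dtype=torch.float64) * 5    # mostly far outside the ball of radius 1/sqrt(c)
p = poincare.math.project(x, c=c)
maxnorm = (1 - 1e-5) / c ** 0.5
assert (p.norm(dim=-1) <= maxnorm + 1e-12).all()   # every point now lies inside the ball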

@@ -33,7 +53,7 @@ def lambda_x(x, *, c):
r"""
Compute the conformal factor :math:`\lambda_x` for a point on the ball
..math::
.. math::
\lambda_x = \frac{2}{1 - c \|x\|_2^2}
@@ -49,14 +69,21 @@ def lambda_x(x, *, c):
scalar
conformal factor
"""
if not isinstance(c, torch.Tensor):
c = torch.as_tensor(c).type_as(x)
return _lambda_x(x, c)


@torch.jit.script
def _lambda_x(x, c):
return 2 / (1 - c * x.pow(2).sum(-1))


def inner(x, u, v, *, c):
r"""
Compute inner product for two vectors on the tangent space w.r.t Riemannian metric on the Poincare ball
..math::
.. math::
\langle u, v\rangle_x = \lambda_x^2 \langle u, v \rangle
@@ -76,7 +103,14 @@ def inner(x, u, v, *, c):
scalar
inner product
"""
return lambda_x(x, c=c) ** 2 * (u * v).sum(-1)
if not isinstance(c, torch.Tensor):
c = torch.as_tensor(c).type_as(x)
return _inner(x, u, v, c)


@torch.jit.script
def _inner(x, u, v, c):
return _lambda_x(x, c) ** 2 * (u * v).sum(-1)
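At the origin the conformal factor is lambda_0 = 2, so the Riemannian inner product there is just 4 times the Euclidean one; a quick sketch of that special case (same assumed poincare import, not part of the commit):

import torch
from geoopt.manifolds import poincare

c = 1.0
x = torch.zeros(10, dtype=torch.float64)   # origin of the ball, lambda_x = 2
u = torch.randn(10, dtype=torch.float64)
v = torch.randn(10, dtype=torch.float64)
lhs = poincare.math.inner(x, u, v, c=c)
rhs = 4 * (u * v).sum()
assert torch.allclose(lhs, rhs)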


def mobius_add(x, y, *, c):
@@ -99,7 +133,7 @@ def mobius_add(x, y, *, c):
But in some cases this property holds:
* zero vector vase
* zero vector case
.. math::
@@ -131,10 +165,17 @@ def mobius_add(x, y, *, c):
tensor
the result of mobius addition
"""
y = y + 1e-15 # add small epsilon for stability
x2 = x.pow(2).sum(-1, keepdim=True)
y2 = y.pow(2).sum(-1, keepdim=True)
xy = (x * y).sum(-1, keepdim=True)
if not isinstance(c, torch.Tensor):
c = torch.as_tensor(c).type_as(x)
return _mobius_add(x, y, c)


@torch.jit.script
def _mobius_add(x, y, c):
y = y + 1e-15
x2 = x.pow(2).sum(dim=-1, keepdim=True)
y2 = y.pow(2).sum(dim=-1, keepdim=True)
xy = (x * y).sum(dim=-1, keepdim=True)
num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
return num / denom
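A small sketch of the zero-vector case from the docstring and of the left-cancellation law that the tests below exercise (same assumed poincare import, not part of the commit):

import torch
from geoopt.manifolds import poincare

c = 1.0
x = poincare.math.project(torch.randn(5, 10, dtype=torch.float64) * 0.1, c=c)
y = poincare.math.project(torch.randn(5, 10, dtype=torch.float64) * 0.1, c=c)
zero = torch.zeros_like(x)

assert torch.allclose(poincare.math.mobius_add(zero, y, c=c), y)   # 0 (+) y == y
assert torch.allclose(poincare.math.mobius_add(x, zero, c=c), x)   # x (+) 0 == x
res = poincare.math.mobius_add(-x, poincare.math.mobius_add(x, y, c=c), c=c)
assert torch.allclose(res, y, atol=1e-10)                          # (-x) (+) (x (+) y) == y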
@@ -162,4 +203,78 @@ def mobius_sub(x, y, *, c):
tensor
the result of mobius subtraction
"""
return mobius_add(x, -y, c=c)
if not isinstance(c, torch.Tensor):
c = torch.as_tensor(c).type_as(x)
return _mobius_sub(x, y, c)


@torch.jit.script
def _mobius_sub(x, y, c):
return _mobius_add(x, -y, c)


def mobius_scalar_mul(r, x, *, c):
r"""
Left scalar multiplication on the Poincare ball
.. math::
r \otimes_c x = (1/\sqrt{c}) \tanh(r\tanh^{-1}(\sqrt{c}\|x\|_2))\frac{x}{\|x\|_2}
This operation has properties similar to euclidean
* `n-addition` property
.. math::
r \otimes_c x = x \oplus_c \dots \oplus_c x
* Distributive property
.. math::
(r_1 + r_2) \otimes_c x = r_1 \otimes_c x \oplus r_2 \otimes_c x
* Scalar associativity
.. math::
(r_1 r_2) \otimes_c x = r_1 \otimes_c (r_2 \otimes_c x)
* Scaling property
.. math::
|r| \otimes_c x / \|r \otimes_c x\|_2 = x/\|x\|_2
Parameters
----------
r : float|tensor
scalar for multiplication
x : tensor
point on poincare ball
c : float|tensor
ball negative curvature
Returns
-------
tensor
the result of mobius scalar multiplication
"""
if not isinstance(c, torch.Tensor):
c = torch.as_tensor(c).type_as(x)
if not isinstance(r, torch.Tensor):
r = torch.as_tensor(r).type_as(x)
return _mobius_scalar_mul(r, x, c)


@torch.jit.script
def _mobius_scalar_mul(r, x, c):
x = x + 1e-15
x_norm = x.norm(dim=-1, keepdim=True, p=2)
cond = c < 1e-10
sqrt_c = c ** 0.5
res_0 = x * r
res_c = tanh(r * artanh(sqrt_c * x_norm)) * x / (x_norm * sqrt_c)
res = torch.where(cond, res_0, res_c)
return _project(res, c)
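A sketch of the n-addition property stated in the docstring, mirroring test_n_additions_via_scalar_multiplication in the test file below (same assumed poincare import, not part of the commit):

import torch
from geoopt.manifolds import poincare

c = 1.0
x = poincare.math.project(torch.randn(5, 10, dtype=torch.float64) * 0.1, c=c)
y = torch.zeros_like(x)
for _ in range(3):                                   # x (+) x (+) x via repeated mobius addition
    y = poincare.math.mobius_add(x, y, c=c)
three_x = poincare.math.mobius_scalar_mul(3.0, x, c=c)
assert torch.allclose(y, three_x, atol=1e-7)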
146 changes: 130 additions & 16 deletions tests/test_poincare_math.py
@@ -1,3 +1,6 @@
"""
Test ideas are taken mostly from https://github.com/dalab/hyperbolic_nn/blob/master/util.py with some changes
"""
import torch
import random
import numpy as np
@@ -13,31 +16,142 @@ def seed(request):
return seed


def test_mobius_addition_left_cancelation_test(seed):
a = torch.randn(100, 10, dtype=torch.float64)
b = torch.randn(100, 10, dtype=torch.float64)

@pytest.fixture
def c(seed):
# test broadcasted and non broadcasted versions
if seed == 30:
c = 0
elif seed == 35:
c = torch.zeros(100, 1, dtype=torch.float64)
elif seed > 35:
c = torch.rand(100, 1, dtype=torch.float64)
else:
c = random.random()
a = poincare.math.project(a, c=c)
b = poincare.math.project(b, c=c)
return c

res = poincare.math.mobius_add(-a, poincare.math.mobius_add(a, b, c=c), c=c)
np.testing.assert_allclose(res, b)

@pytest.fixture
def a(seed, c):
if seed in {30, 35}:
a = torch.randn(100, 10, dtype=torch.float64)
elif seed > 35:
# do not check numerically unstable regions
# I've manually observed small differences there
a = torch.empty(100, 10, dtype=torch.float64).normal_(-1, 1)
a /= a.norm(dim=-1, keepdim=True) * 1.3
a *= (torch.rand_like(c) * c) ** 0.5
else:
a = torch.empty(100, 10, dtype=torch.float64).normal_(-1, 1)
a /= a.norm(dim=-1, keepdim=True) * 1.3
a *= random.uniform(0, c) ** 0.5
return a

def test_mobius_addition_left_cancelation_test_broadcasted(seed):
a = torch.randn(100, 10, dtype=torch.float64)
b = torch.randn(100, 10, dtype=torch.float64)

if seed == 30:
c = torch.zeros(*a.shape[:-1], 1, dtype=torch.float64)
@pytest.fixture
def b(seed, c):
if seed in {30, 35}:
b = torch.randn(100, 10, dtype=torch.float64)
elif seed > 35:
b = torch.empty(100, 10, dtype=torch.float64).normal_(-1, 1)
b /= b.norm(dim=-1, keepdim=True) * 1.3
b *= (torch.rand_like(c) * c) ** 0.5
else:
c = torch.rand(*a.shape[:-1], 1, dtype=torch.float64)
a = poincare.math.project(a, c=c)
b = poincare.math.project(b, c=c)
b = torch.empty(100, 10, dtype=torch.float64).normal_(-1, 1)
b /= b.norm(dim=-1, keepdim=True) * 1.3
b *= random.uniform(0, c) ** 0.5
return b


def test_mobius_addition_left_cancelation(a, b, c):
res = poincare.math.mobius_add(-a, poincare.math.mobius_add(a, b, c=c), c=c)
np.testing.assert_allclose(res, b)


def test_mobius_addition_left_cancelation_broadcasted(a, b, c):
res = poincare.math.mobius_add(-a, poincare.math.mobius_add(a, b, c=c), c=c)
np.testing.assert_allclose(res, b)


def test_mobius_addition_zero_a(b, c):
a = torch.zeros(100, 10, dtype=torch.float64)
res = poincare.math.mobius_add(a, b, c=c)
np.testing.assert_allclose(res, b)


def test_mobius_addition_zero_b(a, c):
b = torch.zeros(100, 10, dtype=torch.float64)
res = poincare.math.mobius_add(a, b, c=c)
np.testing.assert_allclose(res, a)


def test_mobius_addition_negative_cancellation(a, c):
res = poincare.math.mobius_add(a, -a, c=c)
np.testing.assert_allclose(res, torch.zeros_like(res), atol=1e-10)


def test_mobius_negative_addition(a, b, c):
res = poincare.math.mobius_add(-b, -a, c=c)
res1 = -poincare.math.mobius_add(b, a, c=c)
np.testing.assert_allclose(res, res1, atol=1e-10)


@pytest.mark.parametrize("n", list(range(5)))
def test_n_additions_via_scalar_multiplication(n, a, c):
y = torch.zeros_like(a)
for _ in range(n):
y = poincare.math.mobius_add(a, y, c=c)
ny = poincare.math.mobius_scalar_mul(n, a, c=c)
np.testing.assert_allclose(y, ny, atol=1e-7, rtol=1e-10)


@pytest.fixture
def r1(seed):
if seed % 3 == 0:
return random.uniform(-1, 1)
else:
return torch.rand(100, 1, dtype=torch.float64) * 2 - 1


@pytest.fixture
def r2(seed):
if seed % 3 == 1:
return random.uniform(-1, 1)
else:
return torch.rand(100, 1, dtype=torch.float64) * 2 - 1


def test_scalar_multiplication_distributive(a, c, r1, r2):
res = poincare.math.mobius_scalar_mul(r1 + r2, a, c=c)
res1 = poincare.math.mobius_add(
poincare.math.mobius_scalar_mul(r1, a, c=c),
poincare.math.mobius_scalar_mul(r2, a, c=c),
c=c,
)
res2 = poincare.math.mobius_add(
poincare.math.mobius_scalar_mul(r2, a, c=c),
poincare.math.mobius_scalar_mul(r1, a, c=c),
c=c,
)
np.testing.assert_allclose(res1, res, atol=1e-7, rtol=1e-10)
np.testing.assert_allclose(res2, res, atol=1e-7, rtol=1e-10)


def test_scalar_multiplication_associative(a, c, r1, r2):
res = poincare.math.mobius_scalar_mul(r1 * r2, a, c=c)
res1 = poincare.math.mobius_scalar_mul(
r1, poincare.math.mobius_scalar_mul(r2, a, c=c), c=c
)
res2 = poincare.math.mobius_scalar_mul(
r2, poincare.math.mobius_scalar_mul(r1, a, c=c), c=c
)
np.testing.assert_allclose(res1, res, atol=1e-7, rtol=1e-10)
np.testing.assert_allclose(res2, res, atol=1e-7, rtol=1e-10)


def test_scaling_property(a, c, r1):
x1 = a / a.norm(dim=-1, keepdim=True)
ra = poincare.math.mobius_scalar_mul(r1, a, c=c)
x2 = poincare.math.mobius_scalar_mul(abs(r1), a, c=c) / ra.norm(
dim=-1, keepdim=True
)
np.testing.assert_allclose(x1, x2, atol=1e-7, rtol=1e-10)
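A sketch of the scaling property checked by test_scaling_property above, with a negative scalar to show that only the direction of x survives the normalization (same assumed poincare import, not part of the commit):

import torch
from geoopt.manifolds import poincare

c = 1.0
r = -2.5
x = poincare.math.project(torch.randn(5, 10, dtype=torch.float64) * 0.1, c=c)
rx = poincare.math.mobius_scalar_mul(r, x, c=c)
lhs = poincare.math.mobius_scalar_mul(abs(r), x, c=c) / rx.norm(dim=-1, keepdim=True)
rhs = x / x.norm(dim=-1, keepdim=True)
assert torch.allclose(lhs, rhs, atol=1e-7)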
