Skip to content

Commit

Permalink
Merge a716f52 into d7c115c
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Apr 23, 2020
2 parents d7c115c + a716f52 commit 981b203
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
2 changes: 2 additions & 0 deletions geoopt/manifolds/stereographic/manifold.py
Expand Up @@ -452,6 +452,7 @@ def weighted_midpoint(
dim: int = -1,
keepdim: bool = False,
lincomb: bool = False,
posweight=False,
project=True,
):
mid = math.weighted_midpoint(
Expand All @@ -462,6 +463,7 @@ def weighted_midpoint(
dim=dim,
keepdim=keepdim,
lincomb=lincomb,
posweight=posweight,
)
if project:
return math.project(mid, k=self.k, dim=dim)
Expand Down
27 changes: 12 additions & 15 deletions geoopt/manifolds/stereographic/math.py
Expand Up @@ -1878,6 +1878,7 @@ def weighted_midpoint(
dim: int = -1,
keepdim: bool = False,
lincomb: bool = False,
posweight: bool = False,
):
r"""
Compute weighted Möbius gyromidpoint.
Expand Down Expand Up @@ -1925,6 +1926,10 @@ def weighted_midpoint(
retain the last dim? (default: false)
lincomb : bool
linear combination implementation
posweight : bool
make all weights positive. Negative weight will weight antipode of entry with positive weight instead.
This will give experimentally better numerics and nice interpolation
properties for linear combination and averaging
Returns
-------
Expand All @@ -1939,6 +1944,7 @@ def weighted_midpoint(
dim=dim,
keepdim=keepdim,
lincomb=lincomb,
posweight=posweight,
)


Expand All @@ -1951,6 +1957,7 @@ def _weighted_midpoint(
dim: int = -1,
keepdim: bool = False,
lincomb: bool = False,
posweight: bool = False,
):
if reducedim is None:
reducedim = list_range(xs.dim())
Expand All @@ -1960,21 +1967,12 @@ def _weighted_midpoint(
weights = torch.tensor(1.0, dtype=xs.dtype, device=xs.device)
else:
weights = weights.unsqueeze(dim)
if posweight and weights.lt(0).any():
xs = torch.where(weights.lt(0), _antipode(xs, k=k, dim=dim), xs)
weights = weights.abs()
denominator = ((gamma - 1) * weights).sum(reducedim, keepdim=True)
zero = torch.tensor(0.0, dtype=xs.dtype, device=xs.device)
one = torch.tensor(1.0, dtype=xs.dtype, device=xs.device)
ill_conditioned = torch.isclose(denominator, zero, atol=1e-7)
if lincomb:
nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
two_mean = nominator / torch.where(ill_conditioned, one, denominator)
elif ill_conditioned.any():
weights_denom = torch.where(ill_conditioned, weights + 1e-7, weights)
nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
denominator = ((gamma - 1) * weights_denom).sum(reducedim, keepdim=True)
two_mean = nominator / denominator
else:
nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
two_mean = nominator / clamp_abs(denominator)
nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
two_mean = nominator / clamp_abs(denominator, 1e-10)
a_mean = _mobius_scalar_mul(
torch.tensor(0.5, dtype=xs.dtype, device=xs.device), two_mean, k=k, dim=dim
)
Expand All @@ -1997,7 +1995,6 @@ def _weighted_midpoint(
else:
weights, _ = torch.broadcast_tensors(weights, gamma)
alpha = weights.sum(reducedim, keepdim=True)
alpha = torch.where(ill_conditioned, one, alpha)
a_mean = _mobius_scalar_mul(alpha, a_mean, k=k, dim=dim)
if not keepdim:
a_mean = drop_dims(a_mean, reducedim)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_gyrovector_math.py
Expand Up @@ -672,7 +672,9 @@ def test_weighted_midpoint_weighted_zero_sum(_k, lincomb):
a = manifold.expmap0(torch.eye(3, 10)).detach().requires_grad_(True)
weights = torch.rand_like(a[..., 0])
weights = weights - weights.sum() / weights.numel()
mid = manifold.weighted_midpoint(a, lincomb=lincomb, weights=weights)
mid = manifold.weighted_midpoint(
a, lincomb=lincomb, weights=weights, posweight=True
)
if _k == 0 and lincomb:
np.testing.assert_allclose(
mid.detach(),
Expand Down

0 comments on commit 981b203

Please sign in to comment.