Skip to content

Commit

Permalink
Positive weights for weighted midpoint (#131)
Browse files Browse the repository at this point in the history
* positive weights for einstein midpoint

* fix test

* black
  • Loading branch information
ferrine committed Apr 24, 2020
1 parent f73b940 commit 7e68755
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
2 changes: 2 additions & 0 deletions geoopt/manifolds/stereographic/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def weighted_midpoint(
dim: int = -1,
keepdim: bool = False,
lincomb: bool = False,
posweight=False,
project=True,
):
mid = math.weighted_midpoint(
Expand All @@ -462,6 +463,7 @@ def weighted_midpoint(
dim=dim,
keepdim=keepdim,
lincomb=lincomb,
posweight=posweight,
)
if project:
return math.project(mid, k=self.k, dim=dim)
Expand Down
27 changes: 12 additions & 15 deletions geoopt/manifolds/stereographic/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,6 +1878,7 @@ def weighted_midpoint(
dim: int = -1,
keepdim: bool = False,
lincomb: bool = False,
posweight: bool = False,
):
r"""
Compute weighted Möbius gyromidpoint.
Expand Down Expand Up @@ -1925,6 +1926,10 @@ def weighted_midpoint(
retain the last dim? (default: false)
lincomb : bool
linear combination implementation
posweight : bool
make all weights positive. Negative weight will weight antipode of entry with positive weight instead.
This will give experimentally better numerics and nice interpolation
properties for linear combination and averaging
Returns
-------
Expand All @@ -1939,6 +1944,7 @@ def weighted_midpoint(
dim=dim,
keepdim=keepdim,
lincomb=lincomb,
posweight=posweight,
)


Expand All @@ -1951,6 +1957,7 @@ def _weighted_midpoint(
dim: int = -1,
keepdim: bool = False,
lincomb: bool = False,
posweight: bool = False,
):
if reducedim is None:
reducedim = list_range(xs.dim())
Expand All @@ -1960,21 +1967,12 @@ def _weighted_midpoint(
weights = torch.tensor(1.0, dtype=xs.dtype, device=xs.device)
else:
weights = weights.unsqueeze(dim)
if posweight and weights.lt(0).any():
xs = torch.where(weights.lt(0), _antipode(xs, k=k, dim=dim), xs)
weights = weights.abs()
denominator = ((gamma - 1) * weights).sum(reducedim, keepdim=True)
zero = torch.tensor(0.0, dtype=xs.dtype, device=xs.device)
one = torch.tensor(1.0, dtype=xs.dtype, device=xs.device)
ill_conditioned = torch.isclose(denominator, zero, atol=1e-7)
if lincomb:
nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
two_mean = nominator / torch.where(ill_conditioned, one, denominator)
elif ill_conditioned.any():
weights_denom = torch.where(ill_conditioned, weights + 1e-7, weights)
nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
denominator = ((gamma - 1) * weights_denom).sum(reducedim, keepdim=True)
two_mean = nominator / denominator
else:
nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
two_mean = nominator / clamp_abs(denominator)
nominator = (gamma * weights * xs).sum(reducedim, keepdim=True)
two_mean = nominator / clamp_abs(denominator, 1e-10)
a_mean = _mobius_scalar_mul(
torch.tensor(0.5, dtype=xs.dtype, device=xs.device), two_mean, k=k, dim=dim
)
Expand All @@ -1997,7 +1995,6 @@ def _weighted_midpoint(
else:
weights, _ = torch.broadcast_tensors(weights, gamma)
alpha = weights.sum(reducedim, keepdim=True)
alpha = torch.where(ill_conditioned, one, alpha)
a_mean = _mobius_scalar_mul(alpha, a_mean, k=k, dim=dim)
if not keepdim:
a_mean = drop_dims(a_mean, reducedim)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_gyrovector_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,9 @@ def test_weighted_midpoint_weighted_zero_sum(_k, lincomb):
a = manifold.expmap0(torch.eye(3, 10)).detach().requires_grad_(True)
weights = torch.rand_like(a[..., 0])
weights = weights - weights.sum() / weights.numel()
mid = manifold.weighted_midpoint(a, lincomb=lincomb, weights=weights)
mid = manifold.weighted_midpoint(
a, lincomb=lincomb, weights=weights, posweight=True
)
if _k == 0 and lincomb:
np.testing.assert_allclose(
mid.detach(),
Expand Down

0 comments on commit 7e68755

Please sign in to comment.