Refactoring (#149)
* fix API warnings

* fix #148

* per group stabilize, following remark in #140

* version

* update changelog

* add sphinx checks and fix issues there

* add sphinx checks and fix some issues

* do not support torch < 1.4
ferrine committed Oct 7, 2020
1 parent 088e30b commit 6111f68
Showing 15 changed files with 101 additions and 41 deletions.
1 change: 0 additions & 1 deletion .travis.yml
@@ -12,7 +12,6 @@ env:
- PYTHON_VERSION=3.6 PYTORCH_CHANNEL='pytorch' PYTORCH_VERSION='=1.6.0' COVERAGE='--cov geoopt'
- PYTHON_VERSION=3.6 PYTORCH_CHANNEL='pytorch' PYTORCH_VERSION='=1.5.1' COVERAGE=''
- PYTHON_VERSION=3.6 PYTORCH_CHANNEL='pytorch' PYTORCH_VERSION='=1.4.0' COVERAGE=''
-   - PYTHON_VERSION=3.6 PYTORCH_CHANNEL='pytorch' PYTORCH_VERSION='=1.3.1' COVERAGE=''
- PYTHON_VERSION=3.6 PYTORCH_CHANNEL='pytorch-nightly' PYTORCH_VERSION='' COVERAGE=''

matrix:
12 changes: 12 additions & 0 deletions CHANGELOG.rst
@@ -1,5 +1,17 @@
This file tracks important changes in PRs

+ geoopt v0.3.0
+ =============
+
+ New Features
+ ------------
+ * Riemannian Line Search (#140)
+ * Per group stabilization (#149)
+
+ Maintenance
+ -----------
+ * Fix API warnings (raised in #148)
+
geoopt v0.2.0
=============

9 changes: 7 additions & 2 deletions Makefile
@@ -1,4 +1,4 @@
- .PHONY: help dockstyle-check codestyle-check linter-check black test lint check
+ .PHONY: help dockstyle-check codestyle-check linter-check black test lint check sphinx-check
.DEFAULT_GOAL = help

PYTHON = python
@@ -15,6 +15,11 @@ docstyle-check: # Check geoopt with pydocstyle
pydocstyle geoopt
@printf "\033[1;34mPydocstyle passes!\033[0m\n\n"

+ sphinx-check:
+ 	@printf "Checking sphinx build...\n"
+ 	SPHINXOPTS=-W make -C docs -f Makefile clean html
+ 	@printf "\033[1;34mSphinx passes!\033[0m\n\n"
+
codestyle-check: # Check geoopt with black
@printf "Checking code style with black...\n"
black --check --diff geoopt tests
@@ -31,7 +36,7 @@ black: # Format code in-place using black.
test: # Test code using pytest.
pytest -v geoopt tests --doctest-modules --html=testing-report.html --self-contained-html

- lint: linter-check codestyle-check docstyle-check # Lint code using black and pylint (no pydocstyle yet).
+ lint: linter-check codestyle-check docstyle-check sphinx-check # Lint code using black, pylint, pydocstyle and sphinx.

check: lint test # Both lint and test code. Runs `make lint` followed by `make test`.

4 changes: 2 additions & 2 deletions README.rst
@@ -79,8 +79,8 @@ Manifolds
- ``geoopt.BirkhoffPolytope`` - manifold of Doubly Stochastic matrices
- ``geoopt.Stereographic`` - Constant curvature stereographic projection model
- ``geoopt.SphereProjection`` - Sphere stereographic projection model
- - ``geoopt.PoincareBall`` - Poincare ball model (`wiki <https://en.wikipedia.org/wiki/Poincar%C3%A9_disk_model>`_)
- - ``geoopt.Lorentz`` - Hyperboloid model (`wiki <https://en.wikipedia.org/wiki/Hyperboloid_model>`_)
+ - ``geoopt.PoincareBall`` - `Poincare ball model <https://en.wikipedia.org/wiki/Poincar%C3%A9_disk_model>`_
+ - ``geoopt.Lorentz`` - `Hyperboloid model <https://en.wikipedia.org/wiki/Hyperboloid_model>`_
- ``geoopt.ProductManifold`` - Product manifold constructor
- ``geoopt.Scaled`` - Scaled version of the manifold. Similar to `Learning Mixed-Curvature Representations in Product Spaces <https://openreview.net/forum?id=HJxeWnCcF7>`_ if combined with ``ProductManifold``

2 changes: 1 addition & 1 deletion geoopt/__init__.py
@@ -28,4 +28,4 @@
BirkhoffPolytope,
)

__version__ = "0.2.0"
__version__ = "0.3.0"
4 changes: 4 additions & 0 deletions geoopt/optim/mixin.py
@@ -9,6 +9,10 @@ def __init__(self, *args, stabilize=None, **kwargs):
self._stabilize = stabilize
super().__init__(*args, **kwargs)

+     def add_param_group(self, param_group: dict):
+         param_group.setdefault("stabilize", self._stabilize)
+         return super().add_param_group(param_group)
+
def stabilize_group(self, group):
pass

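This hook is the mechanism behind the per-group stabilization feature: every parameter group now carries its own "stabilize" interval, defaulting to the value passed to the optimizer constructor. A minimal sketch of the resulting API, mirroring the usage exercised in tests/test_adam.py below (the sphere manifold and learning rate are illustrative choices, not part of this commit):

import torch
import geoopt

sphere = geoopt.Sphere()
X = geoopt.ManifoldParameter(sphere.projx(torch.randn(10)), manifold=sphere)

# Optimizer-wide default: every group inherits stabilize=4500.
optim = geoopt.optim.RiemannianAdam([X], lr=1e-2, stabilize=4500)
assert optim.param_groups[0]["stabilize"] == 4500

# Or set the interval per parameter group instead.
optim = geoopt.optim.RiemannianAdam([dict(params=[X], stabilize=100)], lr=1e-2)
assert optim.param_groups[0]["stabilize"] == 100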
11 changes: 7 additions & 4 deletions geoopt/optim/radam.py
@@ -87,11 +87,11 @@ def step(self, closure=None):
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
# actual step
- grad.add_(weight_decay, point)
+ grad.add_(point, alpha=weight_decay)
grad = manifold.egrad2rgrad(point, grad)
- exp_avg.mul_(betas[0]).add_(1 - betas[0], grad)
+ exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
exp_avg_sq.mul_(betas[1]).add_(
-     1 - betas[1], manifold.component_inner(point, grad)
+     manifold.component_inner(point, grad), alpha=1 - betas[1]
)
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
@@ -120,7 +120,10 @@ def step(self, closure=None):
exp_avg.set_(exp_avg_new)

group["step"] += 1
- if self._stabilize is not None and group["step"] % self._stabilize == 0:
+ if (
+     group["stabilize"] is not None
+     and group["step"] % group["stabilize"] == 0
+ ):
self.stabilize_group(group)
return loss

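The recurring add_ rewrites in this and the following optimizer files are the "fix API warnings" item: PyTorch deprecated the positional tensor.add_(scalar, other) form (around torch 1.5) in favor of an explicit alpha keyword. A small sketch of the two spellings, with arbitrary numbers:

import torch

grad = torch.ones(3)
point = torch.full((3,), 2.0)
weight_decay = 0.1

# Deprecated positional form, which triggers the warning:
#     grad.add_(weight_decay, point)
# Keyword form used throughout this commit; computes grad += weight_decay * point
grad.add_(point, alpha=weight_decay)
print(grad)  # tensor([1.2000, 1.2000, 1.2000])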
54 changes: 36 additions & 18 deletions geoopt/optim/rlinesearch.py
@@ -4,12 +4,13 @@
This module uses the same syntax as a Torch optimizer
"""

- from .mixin import OptimMixin
- from ..tensor import ManifoldParameter, ManifoldTensor
- from ..manifolds import Euclidean
from scipy.optimize.linesearch import scalar_search_wolfe2, scalar_search_armijo
import warnings
import torch
+ from .mixin import OptimMixin
+ from ..tensor import ManifoldParameter, ManifoldTensor
+ from ..manifolds import Euclidean
+ from ..utils import copy_or_set_


__all__ = ["RiemannianLineSearch"]
@@ -27,13 +28,15 @@ class RiemannianLineSearch(OptimMixin, torch.optim.Optimizer):
This is done by minimizing the line search objective
.. math::
\phi(\alpha) = f(R_x(\alpha\eta)),
where :math:`R_x` is the retraction at :math:`x`.
Its derivative is given by
.. math::
- \phi'(alpha) = \langle\mathrm{grad} f(R_x(\alpha\eta)),\,
+ \phi'(\alpha) = \langle\mathrm{grad} f(R_x(\alpha\eta)),\,
\mathcal T_{\alpha\eta}(\eta) \rangle_{R_x(\alpha\eta)},
where :math:`\mathcal T_\xi(\eta)` denotes the vector transport of :math:`\eta`
@@ -42,8 +45,9 @@ class RiemannianLineSearch(OptimMixin, torch.optim.Optimizer):
The search direction :math:`\eta` is defined recursively by
.. math::
- \eta_{k+1} = -\mathrm{grad} f(R_{x_k}(\alpha_k\eta_k))+
- \beta \mathcal T_{\alpha_k\eta_k}(\eta_k)
+ \eta_{k+1} = -\mathrm{grad} f(R_{x_k}(\alpha_k\eta_k))
+ + \beta \mathcal T_{\alpha_k\eta_k}(\eta_k)
Here :math:`\beta` is the scale parameter. If :math:`\beta=0` this is steepest
descent, other choices are Riemannian version of Fletcher-Reeves and
@@ -53,11 +57,13 @@ class RiemannianLineSearch(OptimMixin, torch.optim.Optimizer):
sufficient decrease condition:
.. math::
\phi(\alpha)\leq \phi(0)+c_1\alpha\phi'(0)
And additionally the curvature / (strong) Wolfe condition
.. math::
\phi'(\alpha)\geq c_2\phi'(0)
The Wolfe conditions are more restrictive, but guarantee that search direction
@@ -77,11 +83,11 @@ class RiemannianLineSearch(OptimMixin, torch.optim.Optimizer):
where phi is scalar line search objective, and derphi is its derivative.
If no suitable step size can be found, the method should return `None`.
The following arguments are always passed in `**kwargs`:
- * **phi0:** float, Value of phi at 0
- * **old_phi0:** float, Value of phi at previous point
- * **derphi0:** float, Value derphi at 0
- * **old_derphi0:** float, Value of derphi at previous point
- * **old_step_size:** float, Stepsize at previous point
+     * **phi0:** float, Value of phi at 0
+     * **old_phi0:** float, Value of phi at previous point
+     * **derphi0:** float, Value derphi at 0
+     * **old_derphi0:** float, Value of derphi at previous point
+     * **old_step_size:** float, Stepsize at previous point
If any of these arguments are undefined, they default to `None`.
Additional arguments can be supplied through the `line_search_params` parameter
line_search_params : dict
@@ -120,7 +126,7 @@ class RiemannianLineSearch(OptimMixin, torch.optim.Optimizer):
line search conditions. See also :meth:`step` (default: 1)
stabilize : int
Stabilize parameters if they are off-manifold due to numerical
- reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)
+ reasons every `stabilize` steps (default: `None` -- no stabilize)
cg_kwargs : dict
Additional parameters to pass to the method used to compute the
conjugate gradient scale parameter.
@@ -168,6 +174,7 @@ def __init__(
)
self._params = []
for group in self.param_groups:
group.setdefault("step", 0)
self._params.extend(group["params"])
if len(self.param_groups) > 1:
warning_string = """Multiple parameter groups detected.
@@ -337,7 +344,6 @@ def _init_loss(self, recompute_gradients=False):
loss = self.prev_loss
reuse_grads = True

- derphi0 = 0
self._step_size_dic = dict()

for point in self._params:
@@ -513,11 +519,13 @@ def step(self, closure, force_step=False, recompute_gradients=False, no_step=False):
with torch.no_grad(): # Take suggested step
point.copy_(new_point)

- if (
-     self._stabilize is not None
-     and len(self.step_size_history) % self._stabilize == 0
- ):
-     point.copy_(manifold.projx(point))
+ for group in self.param_groups:
+     group["step"] += 1
+     if (
+         group["stabilize"] is not None
+         and group["step"] % group["stabilize"] == 0
+     ):
+         self.stabilize_group(group)

# Update loss value
if step_size is not None:
Expand All @@ -527,6 +535,16 @@ def step(self, closure, force_step=False, recompute_gradients=False, no_step=Fal
new_loss = self.prev_loss
return new_loss

+ def stabilize_group(self, group):
+     for p in group["params"]:
+         if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
+             continue
+         state = self.state[p]
+         if not state:  # due to None grads
+             continue
+         manifold = p.manifold
+         copy_or_set_(p, manifold.projx(p))
+

#################################################################################
# Conjugate gradient scale factor
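To place these changes, here is a minimal, hypothetical driver for RiemannianLineSearch: step takes a closure that re-evaluates the loss, and with this refactor stabilization now runs per parameter group every `stabilize` steps via the new stabilize_group method. The objective, shapes, and stabilize interval are illustrative only:

import torch
import geoopt

sphere = geoopt.Sphere()
X = geoopt.ManifoldParameter(sphere.projx(torch.randn(10)), manifold=sphere)
target = sphere.projx(torch.randn(10))

optim = geoopt.optim.RiemannianLineSearch([X], stabilize=10)

def closure():
    optim.zero_grad()
    loss = (X - target).pow(2).sum()
    loss.backward()
    return loss.item()

for _ in range(100):
    # Every 10 steps, stabilize_group projects X back onto the sphere.
    optim.step(closure)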
11 changes: 7 additions & 4 deletions geoopt/optim/rsgd.py
@@ -93,13 +93,13 @@ def step(self, closure=None):
else:
manifold = self._default_manifold

- grad.add_(weight_decay, point)
+ grad.add_(point, alpha=weight_decay)
grad = manifold.egrad2rgrad(point, grad)
if momentum > 0:
    momentum_buffer = state["momentum_buffer"]
-     momentum_buffer.mul_(momentum).add_(1 - dampening, grad)
+     momentum_buffer.mul_(momentum).add_(grad, alpha=1 - dampening)
    if nesterov:
-         grad = grad.add_(momentum, momentum_buffer)
+         grad = grad.add_(momentum_buffer, alpha=momentum)
else:
grad = momentum_buffer
# we have all the things projected
@@ -114,7 +114,10 @@ def step(self, closure=None):
copy_or_set_(point, new_point)

group["step"] += 1
- if self._stabilize is not None and group["step"] % self._stabilize == 0:
+ if (
+     group["stabilize"] is not None
+     and group["step"] % group["stabilize"] == 0
+ ):
self.stabilize_group(group)
return loss

9 changes: 6 additions & 3 deletions geoopt/optim/sparse_radam.py
@@ -110,9 +110,9 @@ def step(self, closure=None):
exp_avg_sq = state["exp_avg_sq"][rows]
# actual step
grad = manifold.egrad2rgrad(point, grad)
- exp_avg.mul_(betas[0]).add_(1 - betas[0], grad)
+ exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
exp_avg_sq.mul_(betas[1]).add_(
-     1 - betas[1], manifold.component_inner(point, grad)
+     manifold.component_inner(point, grad), alpha=1 - betas[1]
)
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"][rows]
@@ -144,7 +144,10 @@ def step(self, closure=None):
state["exp_avg_sq"][rows] = exp_avg_sq

group["step"] += 1
- if self._stabilize is not None and group["step"] % self._stabilize == 0:
+ if (
+     group["stabilize"] is not None
+     and group["step"] % group["stabilize"] == 0
+ ):
self.stabilize_group(group)
return loss

9 changes: 6 additions & 3 deletions geoopt/optim/sparse_rsgd.py
@@ -99,9 +99,9 @@ def step(self, closure=None):
grad = manifold.egrad2rgrad(point, grad)
if momentum > 0:
momentum_buffer = state["momentum_buffer"][rows]
- momentum_buffer.mul_(momentum).add_(1 - dampening, grad)
+ momentum_buffer.mul_(momentum).add_(grad, alpha=1 - dampening)
if nesterov:
-     grad = grad.add_(momentum, momentum_buffer)
+     grad = grad.add_(momentum_buffer, alpha=momentum)
else:
grad = momentum_buffer
# we have all the things projected
@@ -116,7 +116,10 @@ def step(self, closure=None):
full_point[rows] = new_point

group["step"] += 1
- if self._stabilize is not None and group["step"] % self._stabilize == 0:
+ if (
+     group["stabilize"] is not None
+     and group["step"] % group["stabilize"] == 0
+ ):
self.stabilize_group(group)
return loss

5 changes: 5 additions & 0 deletions setup.cfg
@@ -4,3 +4,8 @@
# due to bug in black: https://github.com/ambv/black/issues/355
add-ignore = D100,D101,D102,D103,D104,D105,D106,D107,D202
convention = numpy

+ [tool:pytest]
+ filterwarnings =
+     error
+     default::RuntimeWarning
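This new pytest section escalates any warning raised during the test run to an error, while default::RuntimeWarning keeps pytest's normal reporting for RuntimeWarnings — which is why tolerant_allclose_check in tests/test_gyrovector_math.py below now tags its numerics warning explicitly. A sketch of the effect, with hypothetical test bodies:

import warnings

def test_strict():
    # With "filterwarnings = error" this warning now fails the test.
    warnings.warn("deprecated API call", UserWarning)

def test_tolerated():
    # "default::RuntimeWarning" leaves this as a reported, non-fatal warning.
    warnings.warn("Unstable numerics", RuntimeWarning)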
2 changes: 1 addition & 1 deletion setup.py
@@ -39,7 +39,7 @@ def get_version(*path):
maintainer_email="maxim.v.kochurov@gmail.com",
long_description=LONG_DESCRIPTION,
packages=find_packages(),
install_requires=["torch", "numpy"],
install_requires=["torch>=1.4.0", "numpy"],
version=get_version(PROJECT_ROOT, "geoopt", "__init__.py"),
url="https://github.com/geoopt/geoopt",
python_requires=">=3.6.0",
4 changes: 3 additions & 1 deletion tests/test_adam.py
@@ -20,6 +20,7 @@ def closure():
return loss.item()

optim = geoopt.optim.RiemannianAdam([X], stabilize=4500, **params)
+ assert optim.param_groups[0]["stabilize"] == 4500
for _ in range(10000):
if (Xstar - X).norm() < 1e-5:
break
@@ -47,7 +48,8 @@ def closure():
loss.backward()
return loss.item()

- optim = geoopt.optim.RiemannianAdam([X], stabilize=4500, **params)
+ optim = geoopt.optim.RiemannianAdam([dict(params=[X], stabilize=4500)], **params)
+ assert optim.param_groups[0]["stabilize"] == 4500
assert (X - Xstar).norm() > 1e-5
for _ in range(10000):
if (X - Xstar).norm() < 1e-5:
5 changes: 4 additions & 1 deletion tests/test_gyrovector_math.py
@@ -35,7 +35,10 @@ def tolerant_allclose_check(a, b, strict=True, **tolerance):
except AssertionError as e:
assert not torch.isnan(a).any(), "Found nans"
assert not torch.isnan(b).any(), "Found nans"
warnings.warn("Unstable numerics: " + " | ".join(str(e).splitlines()[3:6]))
warnings.warn(
"Unstable numerics: " + " | ".join(str(e).splitlines()[3:6]),
RuntimeWarning,
)


@pytest.fixture(params=[True, False], ids=["negative", "positive"])