Skip to content

Commit

Permalink
add more docs + black
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Dec 31, 2018
1 parent abf42f5 commit 469f2a5
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 19 deletions.
9 changes: 8 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
'sphinx.ext.githubpages',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
]

Expand Down Expand Up @@ -349,3 +349,10 @@
# If true, do not generate a @detailmenu in the "Top" node's menu.
#
# texinfo_no_detailmenu = False

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'python': ('https://docs.python.org/', None),
'torch': ('https://pytorch.org/docs/master/', None),
}
3 changes: 1 addition & 2 deletions docs/manifolds.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ Manifolds
.. currentmodule:: geoopt.manifolds

.. automodule:: geoopt.manifolds
:members:
:imported-members: True
:members: Stiefel, Euclidean
4 changes: 1 addition & 3 deletions geoopt/manifolds/euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

class Euclidean(Manifold):
"""
Euclidean manifold
An unconstrained manifold
Simple Euclidean manifold
"""

name = "Euclidean"
Expand Down
35 changes: 26 additions & 9 deletions geoopt/manifolds/stiefel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,28 @@
__all__ = ["Stiefel", "EuclideanStiefel", "CanonicalStiefel"]


class Stiefel(Manifold):
r"""
_stiefel_doc = r"""
Manifold induced by the following matrix constraint:
.. math::
X^\top X = I
X \in \mathrm{R}^{n\times m}
X^\top X = I\\
X \in \mathrm{R}^{n\times m}\\
n \ge m
"""


class Stiefel(Manifold):
__doc__ = r"""
{}
Parameters
----------
canonical : bool
Use canonical inner product instead of euclidean one (defaults to canonical)
Notes
-----
works with batch sized tensors
"""
""".format(
_stiefel_doc
)
ndim = 2

def __new__(cls, canonical=True):
Expand Down Expand Up @@ -76,6 +79,13 @@ def _projx(self, x):


class CanonicalStiefel(Stiefel):
__doc__ = r"""Stiefel Manifold with Canonical inner product
{}
""".format(
_stiefel_doc
)

name = "Stiefel(canonical)"
reversible = True

Expand Down Expand Up @@ -131,6 +141,13 @@ def _retr(self, x, u, t):


class EuclideanStiefel(Stiefel):
__doc__ = r"""Stiefel Manifold with Euclidean inner product
{}
""".format(
_stiefel_doc
)

name = "Stiefel(euclidean)"
reversible = False

Expand Down
9 changes: 8 additions & 1 deletion geoopt/optim/radam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class RiemannianAdam(OptimMixin, torch.optim.Adam):
r"""Riemannian Adam
r"""Riemannian Adam with the same API as :class:`torch.optim.Adam`
Parameters
----------
Expand All @@ -29,9 +29,16 @@ class RiemannianAdam(OptimMixin, torch.optim.Adam):
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
Other Parameters
----------------
stabilize : int
Stabilize parameters if they are off-manifold due to numerical
reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""

def step(self, closure=None):
Expand Down
8 changes: 7 additions & 1 deletion geoopt/optim/rsgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class RiemannianSGD(OptimMixin, torch.optim.Optimizer):
r"""Riemannian Stochastic Gradient Descent
r"""Riemannian Stochastic Gradient Descent with the same API as :class:`torch.optim.SGD`
Parameters
----------
Expand All @@ -26,6 +26,12 @@ class RiemannianSGD(OptimMixin, torch.optim.Optimizer):
dampening for momentum (default: 0)
nesterov : bool (optional)
enables Nesterov momentum (default: False)
Other Parameters
----------------
stabilize : int
Stabilize parameters if they are off-manifold due to numerical
reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)
"""

def __init__(
Expand Down
18 changes: 16 additions & 2 deletions geoopt/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@


class ManifoldTensor(torch.Tensor):
"""A regular tensor that has information about its manifold.
It is a very tiny wrapper over regular tensor so that all API is the same
"""Same as :class:`torch.Tensor` that has information about its manifold.
Other Parameters
----------------
manifold : :class:`geoopt.Manifold`
A manifold for the tensor, (default: :class:`geoopt.Euclidean`)
"""

def __new__(cls, *args, manifold=Euclidean(), requires_grad=False, **kwargs):
Expand Down Expand Up @@ -50,6 +54,16 @@ def __repr__(self):


class ManifoldParameter(ManifoldTensor, torch.nn.Parameter):
"""Same as :class:`torch.nn.Parameter` that has information about its manifold.
It should be used within :class:`torch.nn.Module` to be recognized
in parameter collection.
Other Parameters
----------------
manifold : :class:`geoopt.Manifold` (optional)
A manifold for the tensor if ``data`` is not a :class:`geoopt.ManifoldTensor`
"""

def __new__(cls, data=None, manifold=None, requires_grad=True):
if data is None:
data = ManifoldTensor(manifold=manifold)
Expand Down

0 comments on commit 469f2a5

Please sign in to comment.