Skip to content

Commit

Permalink
add more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Dec 31, 2018
1 parent 0d3d14c commit 8caeb99
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 17 deletions.
57 changes: 41 additions & 16 deletions geoopt/manifolds/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,12 @@ def broadcast_scalar(self, t):
Parameters
----------
t : scalar
Potentially batched (individual for every point in a batch) scalar for points on the manifold.
Returns
-------
scalar
Notes
-----
scalar can be batch sized
broadcasted representation for ``t``
"""
if isinstance(t, torch.Tensor):
extra = (1,) * self.ndim
Expand All @@ -70,6 +68,7 @@ def check_point(self, x, explain=False):
Parameters
----------
x : tensor
point on the manifold
explain: bool
return an additional information on check
Expand All @@ -92,6 +91,7 @@ def assert_check_point(self, x):
Parameters
----------
x : tensor
point on the manifold
"""

ok, reason = self._check_shape(x, "x")
Expand All @@ -108,6 +108,7 @@ def check_vector(self, u, explain=False):
Parameters
----------
u : tensor
vector on the tangent plane
explain: bool
return an additional information on check
Expand All @@ -130,6 +131,7 @@ def assert_check_vector(self, u):
Parameters
----------
u : tensor
vector on the tangent plane
"""

ok, reason = self._check_shape(u, "u")
Expand All @@ -146,8 +148,11 @@ def check_point_on_manifold(self, x, explain=False, atol=1e-5, rtol=1e-5):
Parameters
----------
x : tensor
point on the manifold
atol: float
absolute tolerance as in :func:`numpy.allclose`
rtol: float
relative tolerance as in :func:`numpy.allclose`
explain: bool
return an additional information on check
Expand All @@ -172,8 +177,11 @@ def assert_check_point_on_manifold(self, x, atol=1e-5, rtol=1e-5):
Parameters
----------
x : tensor
point on the manifold
atol: float
absolute tolerance as in :func:`numpy.allclose`
rtol: float
relative tolerance as in :func:`numpy.allclose`
"""
self.assert_check_point(x)
ok, reason = self._check_point_on_manifold(x, atol=atol, rtol=rtol)
Expand All @@ -190,9 +198,13 @@ def check_vector_on_tangent(self, x, u, explain=False, atol=1e-5, rtol=1e-5):
Parameters
----------
x : tensor
point on the manifold
u : tensor
vector on the tangent space to ``x``
atol: float
absolute tolerance as in :func:`numpy.allclose`
rtol: float
relative tolerance as in :func:`numpy.allclose`
explain: bool
return an additional information on check
Expand Down Expand Up @@ -220,9 +232,13 @@ def assert_check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5):
Parameters
----------
x : tensor
point on the manifold
u : tensor
vector on the tangent space to ``x``
atol: float
absolute tolerance as in :func:`numpy.allclose`
rtol: float
relative tolerance as in :func:`numpy.allclose`
"""
ok, reason = self._check_shape(x, "x")
if ok:
Expand Down Expand Up @@ -256,7 +272,7 @@ def retr(self, x, u, t):
Returns
-------
tensor
new_x
transported point
"""
t = self.broadcast_scalar(t)
return self._retr(x, u, t)
Expand All @@ -281,7 +297,8 @@ def transp(self, x, u, t, v, *more):
Returns
-------
transported tensors
tensor or tuple of tensors
transported tensor(s)
"""
t = self.broadcast_scalar(t)
if more:
Expand All @@ -304,8 +321,8 @@ def inner(self, x, u, v=None):
Returns
-------
inner product (broadcasted)
scalar
inner product (broadcasted)
"""
if v is None and self._inner_autofill:
v = u
Expand All @@ -327,7 +344,8 @@ def proju(self, x, u):
Returns
-------
projected vector
tensor
projected vector
"""
return self._proju(x, u)

Expand All @@ -342,7 +360,8 @@ def projx(self, x):
Returns
-------
projected point
tensor
projected point
"""
return self._projx(x)

Expand Down Expand Up @@ -370,7 +389,7 @@ def retr_transp(self, x, u, t, v, *more):
Returns
-------
tuple of tensors
(new_x, new_vs, ...)
transported point and vectors
"""
return self._retr_transp(x, u, t, v, *more)

Expand All @@ -388,12 +407,14 @@ def _check_shape(self, x, name):
Parameters
----------
x : tensor
point on the manifold
name : str
name to be present in errors
Returns
-------
bool, str
bool, str or None
check result and the reason of fail if any
"""
# return True, None
raise NotImplementedError
Expand All @@ -412,13 +433,16 @@ def _check_point_on_manifold(self, x, atol=1e-5, rtol=1e-5):
Parameters
----------
x : tensor
atol : float
absolute tolerance
rtol :
relative tolerance
point on the manifold
atol: float
absolute tolerance as in :func:`numpy.allclose`
rtol: float
relative tolerance as in :func:`numpy.allclose`
Returns
-------
bool, str or None
check result and the reason of fail if any
"""
# return True, None
raise NotImplementedError
Expand All @@ -445,6 +469,7 @@ def _check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5):
Returns
-------
bool, str or None
check result and the reason of fail if any
"""
# return True, None
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion geoopt/samplers/sgrhmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class SGRHMC(Sampler):
n_steps : int
number of leapfrog steps
alpha : float
:math:`1 - alpha` - momentum term
:math:`(1 - alpha)` -- momentum term
"""

def __init__(self, params, epsilon=1e-3, n_steps=1, alpha=0.1):
Expand Down

0 comments on commit 8caeb99

Please sign in to comment.