Skip to content

Commit

Permalink
Merge 49f1e59 into 23314cb
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Feb 11, 2019
2 parents 23314cb + 49f1e59 commit 359029b
Show file tree
Hide file tree
Showing 17 changed files with 476 additions and 125 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ This file tracks important changes in PRs
geoopt v0.1.0 (unreleased)
==========================

Breaking Changes
----------------
* better public api, refactored developer api a lot (#40). See the corresponding PR for more details

New Features
------------
* Added ``Sphere`` manifold (#25)
Expand All @@ -11,7 +15,7 @@ New Features
Maintenance
-----------
* Add gitter chat (#31)
* Maintain torch>=1.0.0 only
* Maintain torch>=1.0.0 only (#39)

Deprecations
------------
Expand Down
10 changes: 5 additions & 5 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ There are two ways to install geoopt:

1. GitHub (preferred so far) due to active development

.. code-block::
.. code-block:: bash
pip install git+https://github.com/ferrine/geoopt.git
2. pypi (this might be significantly behind master branch)

.. code-block::
.. code-block:: bash
pip install geoopt
Expand Down Expand Up @@ -54,11 +54,11 @@ points on a certain manifold
- ``.inner(u, v=None)`` – inner product at this point for two
**tangent** vectors at this point. The passed vectors are not
projected, they are assumed to be already projected.
- ``.retr(u, t)`` – retraction map following vector ``u`` for time
- ``.retr(u, t=1.)`` – retraction map following vector ``u`` for time
``t``
- ``.transp(u, t, v, *more)`` – transport vector ``v`` (and possibly
- ``.transp(v, *more, u, t=1.)`` – transport vector ``v`` (and possibly
more vectors) with direction ``u`` for time ``t``
- ``.retr_transp(u, t, v, *more)`` – transport ``self``, vector ``v``
- ``.retr_transp(v, *more, u, t=1.)`` – transport ``self``, vector ``v``
(and possibly more vectors) with direction ``u`` for time ``t``
(returns are plain tensors)

Expand Down
18 changes: 18 additions & 0 deletions docs/devguide.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Extending ``geoopt``
====================

Base Manifold
-------------

The common base class for all manifolds is :class:`geoopt.manifolds.base.Manifold`.

.. autoclass:: geoopt.manifolds.base.Manifold
:private-members:
:members:


Metaclass
---------

.. autoclass:: geoopt.manifolds.base.ManifoldMeta

2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ API
optimizers
tensors
samplers
devguide

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

14 changes: 8 additions & 6 deletions docs/manifolds.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ Manifolds

.. currentmodule:: geoopt.manifolds

.. automodule:: geoopt.manifolds
:members: Euclidean, Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection

Extending ``geoopt``
--------------------
The common base class for all manifolds is :class:`geoopt.manifolds.base.Manifold`.
All manifolds share same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.Euclidean` in the end of this file.

.. automodule:: geoopt.manifolds
:members: Manifold
:members: Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection


.. autoclass:: geoopt.manifolds.Euclidean
:members:
:inherited-members:

6 changes: 3 additions & 3 deletions geoopt/linalg/_expm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@torch.jit.script
def torch_pade13(A):
def torch_pade13(A): # pragma: no cover
# avoid torch select operation and unpack coefs
(b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13) = (
64764752532480000.0,
Expand Down Expand Up @@ -49,14 +49,14 @@ def torch_pade13(A):


@torch.jit.script
def matrix_2_power(x, p):
def matrix_2_power(x, p): # pragma: no cover
for _ in range(int(p)):
x = x @ x
return x


@torch.jit.script
def expm_one(A):
def expm_one(A): # pragma: no cover
# no checks, this is private implementation
# but A should be a matrix
A_fro = torch.norm(A)
Expand Down
12 changes: 6 additions & 6 deletions geoopt/linalg/batch_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@torch.jit.script
def svd(x):
def svd(x): # pragma: no cover
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
Expand All @@ -32,7 +32,7 @@ def svd(x):


@torch.jit.script
def qr(x):
def qr(x): # pragma: no cover
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
Expand All @@ -56,12 +56,12 @@ def qr(x):


@torch.jit.script
def sym(x):
def sym(x): # pragma: no cover
return 0.5 * (x.transpose(-1, -2) + x)


@torch.jit.script
def extract_diag(x):
def extract_diag(x): # pragma: no cover
n, m = x.shape[-2:]
batch = x.shape[:-2]
k = n if n < m else m
Expand All @@ -72,7 +72,7 @@ def extract_diag(x):


@torch.jit.script
def matrix_rank(x):
def matrix_rank(x): # pragma: no cover
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
Expand All @@ -97,7 +97,7 @@ def matrix_rank(x):


@torch.jit.script
def expm(x):
def expm(x): # pragma: no cover
# inspired by
# https://discuss.pytorch.org/t/multidimensional-svd/4366/2
# prolonged here:
Expand Down
Loading

0 comments on commit 359029b

Please sign in to comment.