Skip to content

Commit

Permalink
implementation of SPD manifolds (#153)
Browse files Browse the repository at this point in the history
* implementation of SPD matrices manifold

* implement `SymmetricPositiveDefinite.retr(x,u)`

* Documented `SymmetricPositiveDefinite`

* Modified to fit python 3.6

* Modified to fit python 3.6

* Apply suggestions from code review ( change equation to math mode )

Co-authored-by: Maxim Kochurov <maxim.v.kochurov@gmail.com>

* fix `No module named 'geoopt'` problem in Sphinx

* move symmetric matrix operations to `batch_linalg`

* make `inner` and `_stein_metric` support `keepdim`

* test_manifold_basic of `SymmetricPositiveDefinite`

* fix lint error

* add `random()` method

* Add `rsgd` test for `spd`

* add `origin()` method

Co-authored-by: Maxim Kochurov <maxim.v.kochurov@gmail.com>
  • Loading branch information
tao-harald and ferrine committed Feb 7, 2021
1 parent bddba4f commit bb2eede
Show file tree
Hide file tree
Showing 8 changed files with 506 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import sys
import os

import geoopt

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath(".."))

import geoopt

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
Expand Down
2 changes: 1 addition & 1 deletion docs/manifolds.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ Manifolds
All manifolds share same API. Some manifols may have several implementations of retraction operation, every implementation has a corresponding class.

.. automodule:: geoopt.manifolds
:members: Euclidean, Stiefel, CanonicalStiefel, EuclideanStiefel, EuclideanStiefelExact, Sphere, SphereExact, Stereographic, StereographicExact, PoincareBall, PoincareBallExact, SphereProjection, SphereProjectionExact, Scaled, ProductManifold, Lorentz
:members: Euclidean, Stiefel, CanonicalStiefel, EuclideanStiefel, EuclideanStiefelExact, Sphere, SphereExact, Stereographic, StereographicExact, PoincareBall, PoincareBallExact, SphereProjection, SphereProjectionExact, Scaled, ProductManifold, Lorentz, SymmetricPositiveDefinite


1 change: 1 addition & 0 deletions geoopt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Scaled,
Lorentz,
BirkhoffPolytope,
SymmetricPositiveDefinite,
)

__version__ = "0.3.1"
165 changes: 163 additions & 2 deletions geoopt/linalg/batch_linalg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
from typing import List
from typing import List, Callable, Tuple
import torch
import torch.jit
from . import _expm

__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank", "expm", "block_matrix"]
__all__ = [
"svd",
"qr",
"sym",
"extract_diag",
"matrix_rank",
"expm",
"block_matrix",
"sym_funcm",
"sym_expm",
"sym_logm",
"sym_sqrtm",
"sym_invm",
"sym_inv_sqrtm1",
"sym_inv_sqrtm2",
]


@torch.jit.script
Expand Down Expand Up @@ -77,6 +93,151 @@ def block_matrix(blocks: List[List[torch.Tensor]], dim0: int = -2, dim1: int = -
return torch.cat(hblocks, dim=dim0)


@torch.jit.script
def trace(x: torch.Tensor) -> torch.Tensor:
r"""self-implemented matrix trace, since `torch.trace` only support 2-d input.
Parameters
----------
x : torch.Tensor
input matrix
Returns
-------
torch.Tensor
:math:`\operatorname{Tr}(x)`
"""
return torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)


def sym_funcm(
x: torch.Tensor, func: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
"""Apply function to symmetric matrix.
Parameters
----------
x : torch.Tensor
symmetric matrix
func : Callable[[torch.Tensor], torch.Tensor]
function to apply
Returns
-------
torch.Tensor
symmetric matrix with function applied to
"""
e, v = torch.symeig(x, eigenvectors=True)
return v @ torch.diag_embed(func(e)) @ v.transpose(-1, -2)


def sym_expm(x: torch.Tensor, using_native=False) -> torch.Tensor:
r"""Symmetric matrix exponent.
Parameters
----------
x : torch.Tensor
symmetric matrix
using_native : bool, optional
if using native matrix exponent `torch.matrix_exp`, by default False
Returns
-------
torch.Tensor
:math:`\exp(x)`
"""
if using_native:
return torch.matrix_exp(x)
else:
return sym_funcm(x, torch.exp)


def sym_logm(x: torch.Tensor) -> torch.Tensor:
r"""Symmetric matrix logarithm.
Parameters
----------
x : torch.Tensor
symmetric matrix
Returns
-------
torch.Tensor
:math:`\log(x)`
"""
return sym_funcm(x, torch.log)


def sym_sqrtm(x: torch.Tensor) -> torch.Tensor:
"""Symmetric matrix square root .
Parameters
----------
x : torch.Tensor
symmetric matrix
Returns
-------
torch.Tensor
:math:`x^{1/2}`
"""
return sym_funcm(x, torch.sqrt)


def sym_invm(x: torch.Tensor) -> torch.Tensor:
"""Symmetric matrix inverse.
Parameters
----------
x : torch.Tensor
symmetric matrix
Returns
-------
torch.Tensor
:math:`x^{-1}`
"""
return sym_funcm(x, torch.reciprocal)


def sym_inv_sqrtm1(x: torch.Tensor) -> torch.Tensor:
"""Symmetric matrix inverse square root.
Parameters
----------
x : torch.Tensor
symmetric matrix
Returns
-------
torch.Tensor
:math:`x^{-1/2}`
"""
return sym_funcm(x, lambda tensor: torch.reciprocal(torch.sqrt(tensor)))


def sym_inv_sqrtm2(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Symmetric matrix inverse square root, with square root return also.
Parameters
----------
x : torch.Tensor
symmetric matrix
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
:math:`x^{-1/2}`, :math:`x^{1/2}`
"""
e, v = torch.symeig(x, eigenvectors=True)
sqrt_e = torch.sqrt(e)
inv_sqrt_e = torch.reciprocal(sqrt_e)
return (
v @ torch.diag_embed(inv_sqrt_e) @ v.transpose(-1, -2),
v @ torch.diag_embed(sqrt_e) @ v.transpose(-1, -2),
)


# left here for convenience
qr = torch.qr

Expand Down
1 change: 1 addition & 0 deletions geoopt/manifolds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .stiefel import Stiefel, EuclideanStiefel, CanonicalStiefel, EuclideanStiefelExact
from .sphere import Sphere, SphereExact
from .birkhoff_polytope import BirkhoffPolytope
from .symmetric_positive_definite import SymmetricPositiveDefinite
from .stereographic import (
PoincareBall,
PoincareBallExact,
Expand Down

0 comments on commit bb2eede

Please sign in to comment.