diff --git a/docs/conf.py b/docs/conf.py
index e1209002..cb2e0ae0 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -36,34 +36,35 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.coverage',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.viewcode',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.napoleon',
+ "matplotlib.sphinxext.plot_directive",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.coverage",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.napoleon",
]
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
-source_suffix = '.rst'
+source_suffix = ".rst"
# The encoding of source files.
#
# source_encoding = 'utf-8-sig'
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# General information about the project.
-project = u'geoopt'
-copyright = u'2018, Max Kochurov'
-author = u'Max Kochurov'
+project = u"geoopt"
+copyright = u"2018, Max Kochurov"
+author = u"Max Kochurov"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -93,7 +94,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The reST default role (used for this markup: `text`) to use for all
# documents.
@@ -115,7 +116,7 @@
# show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting.
# modindex_common_prefix = []
@@ -132,7 +133,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
-html_theme = 'alabaster'
+html_theme = "alabaster"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
@@ -166,7 +167,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
@@ -246,34 +247,30 @@
# html_search_scorer = 'scorer.js'
# Output file base name for HTML help builder.
-htmlhelp_basename = 'geooptdoc'
+htmlhelp_basename = "geooptdoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
- # The paper size ('letterpaper' or 'a4paper').
- #
- # 'papersize': 'letterpaper',
-
- # The font size ('10pt', '11pt' or '12pt').
- #
- # 'pointsize': '10pt',
-
- # Additional stuff for the LaTeX preamble.
- #
- # 'preamble': '',
-
- # Latex figure (float) alignment
- #
- # 'figure_align': 'htbp',
+ # The paper size ('letterpaper' or 'a4paper').
+ #
+ # 'papersize': 'letterpaper',
+ # The font size ('10pt', '11pt' or '12pt').
+ #
+ # 'pointsize': '10pt',
+ # Additional stuff for the LaTeX preamble.
+ #
+ # 'preamble': '',
+ # Latex figure (float) alignment
+ #
+ # 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
- (master_doc, 'geoopt.tex', u'geoopt Documentation',
- u'Max Kochurov', 'manual'),
+ (master_doc, "geoopt.tex", u"geoopt Documentation", u"Max Kochurov", "manual")
]
# The name of an image file (relative to this directory) to place at the top of
@@ -313,10 +310,7 @@
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
-man_pages = [
- (master_doc, 'geoopt', u'geoopt Documentation',
- [author], 1)
-]
+man_pages = [(master_doc, "geoopt", u"geoopt Documentation", [author], 1)]
# If true, show URL addresses after external links.
#
@@ -329,9 +323,15 @@
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
- (master_doc, 'geoopt', u'geoopt Documentation',
- author, 'geoopt', 'One line description of project.',
- 'Miscellaneous'),
+ (
+ master_doc,
+ "geoopt",
+ u"geoopt Documentation",
+ author,
+ "geoopt",
+ "One line description of project.",
+ "Miscellaneous",
+ )
]
# Documents to append as an appendix to all manuals.
@@ -352,7 +352,7 @@
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
- 'numpy': ('http://docs.scipy.org/doc/numpy/', None),
- 'python': ('https://docs.python.org/', None),
- 'torch': ('https://pytorch.org/docs/master/', None),
+ "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+ "python": ("https://docs.python.org/", None),
+ "torch": ("https://pytorch.org/docs/master/", None),
}
diff --git a/docs/devguide.rst b/docs/devguide.rst
index 040469bb..b3674acd 100644
--- a/docs/devguide.rst
+++ b/docs/devguide.rst
@@ -1,5 +1,5 @@
-Extending ``geoopt``
-====================
+Developer Guide
+===============
Base Manifold
-------------
@@ -9,7 +9,7 @@ The common base class for all manifolds is :class:`geoopt.manifolds.base.Manifol
.. autoclass:: geoopt.manifolds.base.Manifold
:private-members:
:members:
-
+ :noindex:
Metaclass
---------
diff --git a/docs/extended.rst b/docs/extended.rst
new file mode 100644
index 00000000..e2d1b35d
--- /dev/null
+++ b/docs/extended.rst
@@ -0,0 +1,7 @@
+Extended Guide
+==============
+
+.. toctree::
+ :maxdepth: 1
+
+ extended/poincare
diff --git a/docs/extended/poincare.rst b/docs/extended/poincare.rst
new file mode 100644
index 00000000..c411803b
--- /dev/null
+++ b/docs/extended/poincare.rst
@@ -0,0 +1,117 @@
+Poincare Ball model
+===================
+
+The Poincare ball model is a compact representation of hyperbolic space.
+For a gentle introduction to this model we start from
+simple concepts and put them all together to build a more complete picture.
+
+Hyperbolic spaces
+-----------------
+
+Hyperbolic space is a Riemannian manifold of constant negative curvature.
+A very simple example of a Riemannian manifold with constant, but positive, curvature is the sphere.
+
+The hyperboloid model realizes this manifold as a hyperboloid in (N+1)-dimensional space, which can be mapped into N-dimensional space via projections.
+
+.. figure:: ../plots/extended/poincare/hyperboloid_projection.png
+ :width: 300
+
+ img source `Wikipedia, Hyperboloid Model <https://en.wikipedia.org/wiki/Hyperboloid_model>`_
+
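+Originally, the distance between points on the hyperboloid is defined via the Minkowski inner product :math:`\langle x, y\rangle_{\mathcal{L}} = -x_0 y_0 + \sum_{i\ge 1} x_i y_i` as
+
+.. math::
+
+ d(x, y) = \operatorname{arccosh}(-\langle x, y\rangle_{\mathcal{L}})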
+
+It is difficult to work in (N+1)-dimensional space, and a range of useful embeddings
+exists in the literature.
+
+Klein Model
+~~~~~~~~~~~
+
+.. figure:: ../plots/extended/poincare/klein_tiling.png
+ :width: 300
+
+ img source `Wikipedia, Klein Model <https://en.wikipedia.org/wiki/Beltrami%E2%80%93Klein_model>`_
+
+
+Poincare Model
+~~~~~~~~~~~~~~
+
+.. figure:: ../plots/extended/poincare/poincare_lines.gif
+ :width: 300
+
+ img source `Bulatov, Poincare Model `_
+
+Here we go.
+
+First of all, we note that the Poincare ball is embedded in a sphere of radius :math:`r=1/\sqrt{c}`,
+where :math:`c` is the negative curvature. We also note that, as :math:`c` goes to :math:`0`, we recover a ball of infinite radius.
+We should expect this limiting behaviour to recover Euclidean geometry.
+
+To connect Euclidean space with its embedded manifold we need the Riemannian metric :math:`g_x`.
+It is obtained via the `conformal factor` :math:`\lambda^c_x`.
+
+
+.. autofunction:: geoopt.manifolds.poincare.math.lambda_x
+
+
+:math:`\lambda^c_x` connects the Euclidean inner product with the Riemannian one
+
+.. autofunction:: geoopt.manifolds.poincare.math.inner
+.. autofunction:: geoopt.manifolds.poincare.math.norm
+.. autofunction:: geoopt.manifolds.poincare.math.egrad2rgrad
+
+Math
+----
+The good thing about the Poincare ball is that it forms a Gyrogroup. The minimal definition of a Gyrogroup
+assumes a binary operation :math:`*` that satisfies the following set of properties.
+
+Left identity
+ There exists :math:`e\in G` such that :math:`e * a = a` for every element :math:`a\in G`.
+Left Inverse
+ For every element :math:`a\in G` there exists :math:`b\in G` such that :math:`b * a = e`.
+Gyroassociativity
+ For any :math:`a,b,c\in G` there exists a unique :math:`gyr[a, b]c\in G` such that :math:`a * (b * c)=(a * b) * gyr[a, b]c`.
+Gyroautomorphism
+ :math:`gyr[a, b]` is a magma automorphism in G
+Left loop
+ :math:`gyr[a, b] = gyr[a * b, b]`
+
+As mentioned above, hyperbolic space forms a Gyrogroup equipped with
+
+.. autofunction:: geoopt.manifolds.poincare.math.mobius_add
+.. autofunction:: geoopt.manifolds.poincare.math.gyration
+
+Using this math, it is possible to define other useful operations
+
+.. autofunction:: geoopt.manifolds.poincare.math.mobius_sub
+.. autofunction:: geoopt.manifolds.poincare.math.mobius_scalar_mul
+.. autofunction:: geoopt.manifolds.poincare.math.mobius_pointwise_mul
+.. autofunction:: geoopt.manifolds.poincare.math.mobius_matvec
+.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply
+.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply_chain
+
+Manifold
+--------
+Now we are ready to proceed with studying distances, geodesics, exponential maps and more
+
+.. autofunction:: geoopt.manifolds.poincare.math.dist
+.. autofunction:: geoopt.manifolds.poincare.math.dist2plane
+.. autofunction:: geoopt.manifolds.poincare.math.parallel_transport
+.. autofunction:: geoopt.manifolds.poincare.math.geodesic
+.. autofunction:: geoopt.manifolds.poincare.math.geodesic_unit
+.. autofunction:: geoopt.manifolds.poincare.math.expmap
+.. autofunction:: geoopt.manifolds.poincare.math.expmap0
+.. autofunction:: geoopt.manifolds.poincare.math.logmap
+.. autofunction:: geoopt.manifolds.poincare.math.logmap0
+
+
+Stability
+---------
+Numerical stability is a pain in this model. It is strongly recommended to work in ``float64``;
+expect numerical issues in ``float32`` (although they are not guaranteed to occur).
+
+.. autofunction:: geoopt.manifolds.poincare.math.project
+.. autofunction:: geoopt.manifolds.poincare.math.clip_tangent
diff --git a/docs/index.rst b/docs/index.rst
index 9f62ee12..9ee6098a 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -17,6 +17,7 @@ API
optimizers
tensors
samplers
+ extended
devguide
Indices and tables
diff --git a/docs/manifolds.rst b/docs/manifolds.rst
index 97b65a0e..6dd3f579 100644
--- a/docs/manifolds.rst
+++ b/docs/manifolds.rst
@@ -4,13 +4,11 @@ Manifolds
.. currentmodule:: geoopt.manifolds
-All manifolds share same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.Euclidean` in the end of this file.
+All manifolds share the same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.base.Manifold` at the end of this file.
.. automodule:: geoopt.manifolds
- :members: Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection
+ :members: Euclidean, Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection, PoincareBall
-
-.. autoclass:: geoopt.manifolds.Euclidean
+.. autoclass:: geoopt.manifolds.base.Manifold
:members:
- :inherited-members:
diff --git a/docs/plots/extended/poincare/distance.py b/docs/plots/extended/poincare/distance.py
new file mode 100644
index 00000000..33408101
--- /dev/null
+++ b/docs/plots/extended/poincare/distance.py
@@ -0,0 +1,27 @@
+import geoopt.manifolds.poincare.math as pmath  # plot script: log of pmath.dist from a fixed point over the unit disk
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+sns.set_style("white")
+radius = 1  # radius of the disk that gets evaluated
+coords = np.linspace(-radius, radius, 100)  # 100 samples per axis
+x = torch.tensor([-0.75, 0])  # reference point, the one named in the plot title
+xx, yy = np.meshgrid(coords, coords)
+dist2 = xx ** 2 + yy ** 2  # squared Euclidean norm of every grid point
+mask = dist2 <= radius ** 2  # True for grid points inside the disk
+grid = np.stack([xx, yy], axis=-1)  # coordinates stacked along a new last axis
+dists = pmath.dist(torch.from_numpy(grid).float(), x)  # distance from every grid point to x
+dists[(~mask).nonzero()] = np.nan  # write NaN where mask is False so those points are not drawn
+circle = plt.Circle((0, 0), 1, fill=False, color="b")  # outline of the unit circle
+plt.gca().add_artist(circle)
+plt.xlim(-1.1, 1.1)
+plt.ylim(-1.1, 1.1)
+plt.gca().set_aspect("equal")
+plt.contourf(
+ grid[..., 0], grid[..., 1], dists.log().numpy(), levels=100, cmap="inferno"
+)  # filled contours of the log-distance
+plt.colorbar()
+plt.title("log distance to ($-$0.75, 0)")
+plt.show()
diff --git a/docs/plots/extended/poincare/distance2plane.py b/docs/plots/extended/poincare/distance2plane.py
new file mode 100644
index 00000000..66cbc8fd
--- /dev/null
+++ b/docs/plots/extended/poincare/distance2plane.py
@@ -0,0 +1,35 @@
+import geoopt.manifolds.poincare.math as pmath  # plot script: log of pmath.dist2plane over the unit disk
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from matplotlib import rcParams
+
+rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"  # amsmath is loaded for the LaTeX-rendered labels
+rcParams["text.usetex"] = True  # render text through LaTeX
+
+sns.set_style("white")
+radius = 1  # radius of the disk that gets evaluated
+coords = np.linspace(-radius, radius, 100)  # 100 samples per axis
+x = torch.tensor([-0.75, 0])  # drawn below as the green point; passed to dist2plane
+v = torch.tensor([0.1 / 3, -1 / 3])  # drawn below as the green arrow starting at x; passed to dist2plane
+xx, yy = np.meshgrid(coords, coords)
+dist2 = xx ** 2 + yy ** 2  # squared Euclidean norm of every grid point
+mask = dist2 <= radius ** 2  # True for grid points inside the disk
+grid = np.stack([xx, yy], axis=-1)  # coordinates stacked along a new last axis
+dists = pmath.dist2plane(torch.from_numpy(grid).float(), x, v)  # presumably x and v define the hyperplane - confirm against dist2plane
+dists[(~mask).nonzero()] = np.nan  # write NaN where mask is False so those points are not drawn
+circle = plt.Circle((0, 0), 1, fill=False, color="b")  # outline of the unit circle
+plt.gca().add_artist(circle)
+plt.xlim(-1.1, 1.1)
+plt.ylim(-1.1, 1.1)
+
+plt.gca().set_aspect("equal")
+plt.contourf(
+ grid[..., 0], grid[..., 1], dists.log().numpy(), levels=100, cmap="inferno"
+)  # filled contours of the log-distance
+plt.colorbar()
+plt.scatter(*x, color="g")
+plt.arrow(*x, *v, color="g", width=0.01)
+plt.title(r"log distance to $\tilde{H}_{a, p}$")
+plt.show()
diff --git a/docs/plots/extended/poincare/gyrovector_parallel_transport.py b/docs/plots/extended/poincare/gyrovector_parallel_transport.py
new file mode 100644
index 00000000..894f6c94
--- /dev/null
+++ b/docs/plots/extended/poincare/gyrovector_parallel_transport.py
@@ -0,0 +1,52 @@
+import geoopt.manifolds.poincare.math as pmath  # plot script: two vectors at x and their transports at y, drawn as curves
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from matplotlib import rcParams
+
+rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"  # amsmath is loaded for the LaTeX-rendered labels
+rcParams["text.usetex"] = True  # render text through LaTeX
+
+sns.set_style("white")
+
+x = torch.tensor((-0.25, -0.75))  # start point
+xv1 = torch.tensor((np.sin(np.pi / 3), np.cos(np.pi / 3))) / 5  # first vector attached to x
+xv2 = torch.tensor((np.sin(-np.pi / 3), np.cos(np.pi / 3))) / 5  # second vector, xv1 mirrored in its first coordinate
+t = torch.linspace(0, 1, 10)[:, None]  # 10 parameter values in [0, 1] as a column
+
+y = torch.tensor((0.65, -0.55))  # end point
+xy = pmath.logmap(x, y)  # drawn below as the green arrow starting at x
+path = pmath.geodesic(t, x, y)  # sampled curve between x and y, drawn in green
+yv1 = pmath.parallel_transport(x, y, xv1)  # xv1 transported from x to y
+yv2 = pmath.parallel_transport(x, y, xv2)  # xv2 transported from x to y
+
+xgv1 = pmath.geodesic_unit(t, x, xv1)  # curve from x along xv1, sampled at t
+xgv2 = pmath.geodesic_unit(t, x, xv2)
+
+ygv1 = pmath.geodesic_unit(t, y, yv1)  # curve from y along the transported vector
+ygv2 = pmath.geodesic_unit(t, y, yv2)
+
+
+def plot_gv(gv, **kwargs):  # draw a sampled curve and finish it with an arrow head
+ plt.plot(*gv.t().numpy(), **kwargs)  # transpose so the two rows are the x and y coordinate sequences
+ plt.arrow(*gv[-2], *(gv[-1] - gv[-2]), width=0.01, **kwargs)  # arrow from the second-to-last sample to the last
+
+
+circle = plt.Circle((0, 0), 1, fill=False, color="b")  # outline of the unit circle
+plt.gca().add_artist(circle)
+plt.xlim(-1.1, 1.1)
+plt.ylim(-1.1, 1.1)
+plt.gca().set_aspect("equal")
+plt.annotate("x", x - 0.09, fontsize=15)
+plt.annotate("y", y - 0.09, fontsize=15)
+plt.annotate(r"$\vec{v}$", x + torch.tensor([0.3, 0.5]), fontsize=15)
+plot_gv(xgv1, color="r")
+plot_gv(xgv2, color="b")
+plt.arrow(*x, *xy, width=0.01, color="g")
+plot_gv(ygv1, color="r")
+plot_gv(ygv2, color="b")
+
+plt.plot(*path.t().numpy(), color="g")
+plt.title(r"gyrovector parallel transport $P_{x\to y}$")
+plt.show()
diff --git a/docs/plots/extended/poincare/hyperboloid_projection.png b/docs/plots/extended/poincare/hyperboloid_projection.png
new file mode 100644
index 00000000..df4843c5
Binary files /dev/null and b/docs/plots/extended/poincare/hyperboloid_projection.png differ
diff --git a/docs/plots/extended/poincare/klein_tiling.png b/docs/plots/extended/poincare/klein_tiling.png
new file mode 100644
index 00000000..c2184592
Binary files /dev/null and b/docs/plots/extended/poincare/klein_tiling.png differ
diff --git a/docs/plots/extended/poincare/mobius_add.py b/docs/plots/extended/poincare/mobius_add.py
new file mode 100644
index 00000000..159656a8
--- /dev/null
+++ b/docs/plots/extended/poincare/mobius_add.py
@@ -0,0 +1,29 @@
+import geoopt.manifolds.poincare.math as pmath  # plot script: x, y and pmath.mobius_add(x, y) as arrows from the origin
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+from matplotlib import rcParams
+
+rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"  # amsmath is loaded for the LaTeX-rendered labels
+rcParams["text.usetex"] = True  # render text through LaTeX
+
+sns.set_style("white")
+
+x = torch.tensor((-0.25, -0.75)) / 2  # first operand
+y = torch.tensor((0.65, -0.55)) / 2  # second operand
+x_plus_y = pmath.mobius_add(x, y)  # result that is drawn in blue
+
+
+circle = plt.Circle((0, 0), 1, fill=False, color="b")  # outline of the unit circle
+plt.gca().add_artist(circle)
+plt.xlim(-1.1, 1.1)
+plt.ylim(-1.1, 1.1)
+plt.gca().set_aspect("equal")
+plt.annotate("x", x - 0.09, fontsize=15)
+plt.annotate("y", y - 0.09, fontsize=15)
+plt.annotate(r"$x\oplus y$", x_plus_y - torch.tensor([0.1, 0.15]), fontsize=15)
+plt.arrow(0, 0, *x, width=0.01, color="r")  # arrows start at the origin
+plt.arrow(0, 0, *y, width=0.01, color="g")
+plt.arrow(0, 0, *x_plus_y, width=0.01, color="b")
+plt.title(r"Addition $x\oplus y$")
+plt.show()
diff --git a/docs/plots/extended/poincare/mobius_matvec.py b/docs/plots/extended/poincare/mobius_matvec.py
new file mode 100644
index 00000000..df02968d
--- /dev/null
+++ b/docs/plots/extended/poincare/mobius_matvec.py
@@ -0,0 +1,30 @@
+import geoopt.manifolds.poincare.math as pmath  # plot script: x and pmath.mobius_matvec(M, x) as arrows from the origin
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+from matplotlib import rcParams
+
+rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"  # amsmath provides the bmatrix environment used in a label below
+rcParams["text.usetex"] = True  # render text through LaTeX
+sns.set_style("white")
+x = torch.tensor((-0.25, -0.75)) / 3  # input vector
+M = torch.tensor([[-1, -1.5], [0.2, 0.5]])  # 2x2 matrix, the same values as in the annotation below
+M_x = pmath.mobius_matvec(M, x)  # result that is drawn in blue
+
+
+circle = plt.Circle((0, 0), 1, fill=False, color="b")  # outline of the unit circle
+plt.gca().add_artist(circle)
+plt.xlim(-1.1, 1.1)
+plt.ylim(-1.1, 1.1)
+plt.gca().set_aspect("equal")
+plt.annotate("x", x - 0.09, fontsize=15)
+plt.annotate(
+ r"$M=\begin{bmatrix}-1 &-1.5\\.2 &.5\end{bmatrix}$",
+ x + torch.tensor([-0.5, 0.5]),
+ fontsize=15,
+)  # matrix label placed relative to x
+plt.annotate(r"$M^\otimes x$", M_x - torch.tensor([0.1, 0.15]), fontsize=15)
+plt.arrow(0, 0, *x, width=0.01, color="r")  # arrows start at the origin
+plt.arrow(0, 0, *M_x, width=0.01, color="b")
+plt.title(r"Matrix multiplication $M\otimes x$")
+plt.show()
diff --git a/docs/plots/extended/poincare/mobius_sigmoid_apply.py b/docs/plots/extended/poincare/mobius_sigmoid_apply.py
new file mode 100644
index 00000000..deeb7f7e
--- /dev/null
+++ b/docs/plots/extended/poincare/mobius_sigmoid_apply.py
@@ -0,0 +1,27 @@
+import geoopt.manifolds.poincare.math as pmath  # plot script: x and pmath.mobius_fn_apply(torch.sigmoid, x) as arrows
+from matplotlib import rcParams
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+sns.set_style("white")
+rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"  # amsmath is loaded for the LaTeX-rendered labels
+rcParams["text.usetex"] = True  # render text through LaTeX
+x = torch.tensor((-0.25, -0.75)) / 3  # input vector
+f_x = pmath.mobius_fn_apply(torch.sigmoid, x)  # result that is drawn in blue
+
+
+circle = plt.Circle((0, 0), 1, fill=False, color="b")  # outline of the unit circle
+plt.gca().add_artist(circle)
+plt.xlim(-1.1, 1.1)
+plt.ylim(-1.1, 1.1)
+plt.gca().set_aspect("equal")
+plt.annotate("x", x - 0.09, fontsize=15)
+plt.annotate(
+ r"$\sigma(x)=\frac{1}{1+e^{-x}}$", x + torch.tensor([-0.7, 0.5]), fontsize=15
+)  # formula label placed relative to x
+plt.annotate(r"$\sigma^\otimes(x)$", f_x - torch.tensor([0.1, 0.15]), fontsize=15)
+plt.arrow(0, 0, *x, width=0.01, color="r")  # arrows start at the origin
+plt.arrow(0, 0, *f_x, width=0.01, color="b")
+plt.title(r"Mobius function (sigmoid) apply $\sigma^\otimes(x)$")
+plt.show()
diff --git a/docs/plots/extended/poincare/parallel_transport.py b/docs/plots/extended/poincare/parallel_transport.py
new file mode 100644
index 00000000..be624c52
--- /dev/null
+++ b/docs/plots/extended/poincare/parallel_transport.py
@@ -0,0 +1,40 @@
+import geoopt.manifolds.poincare.math as pmath  # plot script: two vectors at x and their parallel transports at y
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from matplotlib import rcParams
+
+rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"  # amsmath is loaded for the LaTeX-rendered labels
+rcParams["text.usetex"] = True  # render text through LaTeX
+
+
+sns.set_style("white")
+
+x = torch.tensor((-0.25, -0.75))  # start point
+v1 = torch.tensor((np.sin(np.pi / 3), np.cos(np.pi / 3))) / 5  # first vector attached to x
+v2 = torch.tensor((np.sin(-np.pi / 3), np.cos(np.pi / 3))) / 5  # second vector, v1 mirrored in its first coordinate
+y = torch.tensor((0.65, -0.55))  # end point
+t = torch.linspace(0, 1)  # NOTE(review): relies on the default number of steps; newer PyTorch may require `steps` - confirm
+xy = pmath.logmap(x, y)  # drawn below as the green arrow starting at x
+path = pmath.geodesic(t[:, None], x, y)  # sampled curve between x and y; t is passed as a column
+yv1 = pmath.parallel_transport(x, y, v1)  # v1 transported from x to y
+yv2 = pmath.parallel_transport(x, y, v2)  # v2 transported from x to y
+
+
+circle = plt.Circle((0, 0), 1, fill=False, color="b")  # outline of the unit circle
+plt.gca().add_artist(circle)
+plt.xlim(-1.1, 1.1)
+plt.ylim(-1.1, 1.1)
+plt.gca().set_aspect("equal")
+plt.annotate("x", x - 0.07, fontsize=15)
+plt.annotate("y", y - 0.07, fontsize=15)
+plt.annotate(r"$\vec{v}$", x + torch.tensor([0.3, 0.5]), fontsize=15)
+plt.arrow(*x, *v1, width=0.01, color="r")  # original vectors are drawn at x
+plt.arrow(*x, *xy, width=0.01, color="g")
+plt.arrow(*x, *v2, width=0.01, color="b")
+plt.arrow(*y, *yv1, width=0.01, color="r")  # transported vectors are drawn at y in matching colours
+plt.arrow(*y, *yv2, width=0.01, color="b")
+plt.plot(*path.t().numpy(), color="g")  # transpose so the two rows are the x and y coordinate sequences
+plt.title(r"parallel transport $P^c_{x\to y}$")
+plt.show()
diff --git a/docs/plots/extended/poincare/poincare_lines.gif b/docs/plots/extended/poincare/poincare_lines.gif
new file mode 100644
index 00000000..da427047
Binary files /dev/null and b/docs/plots/extended/poincare/poincare_lines.gif differ
diff --git a/geoopt/__init__.py b/geoopt/__init__.py
index 43f95f74..fe0609f3 100644
--- a/geoopt/__init__.py
+++ b/geoopt/__init__.py
@@ -11,6 +11,7 @@
Sphere,
SphereSubspaceIntersection,
SphereSubspaceComplementIntersection,
+ PoincareBall,
)
__version__ = "0.0.1"
diff --git a/geoopt/linalg/_expm.py b/geoopt/linalg/_expm.py
index 1727a33d..1ffb88a7 100644
--- a/geoopt/linalg/_expm.py
+++ b/geoopt/linalg/_expm.py
@@ -3,7 +3,7 @@
maintainer: ferrine
"""
-import torch
+import torch.jit
@torch.jit.script
@@ -73,6 +73,7 @@ def expm_one(A): # pragma: no cover
U, V = torch_pade13(Ascaled)
P = U + V
Q = -U + V
+ # TODO: torch.gesv -> torch.solve after pytorch release
R, _ = torch.gesv(P, Q) # solve P = Q*R
expmA = matrix_2_power(R, n_squarings)
return expmA
diff --git a/geoopt/linalg/batch_linalg.py b/geoopt/linalg/batch_linalg.py
index f4d8842e..2a7aabe7 100644
--- a/geoopt/linalg/batch_linalg.py
+++ b/geoopt/linalg/batch_linalg.py
@@ -1,4 +1,4 @@
-import torch
+import torch.jit
from . import _expm
__all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank", "expm", "block_matrix"]
diff --git a/geoopt/manifolds/__init__.py b/geoopt/manifolds/__init__.py
index a54c88b6..99fe4840 100644
--- a/geoopt/manifolds/__init__.py
+++ b/geoopt/manifolds/__init__.py
@@ -6,3 +6,5 @@
SphereSubspaceComplementIntersection,
SphereSubspaceIntersection,
)
+from .poincare import PoincareBall
+from . import poincare
diff --git a/geoopt/manifolds/base.py b/geoopt/manifolds/base.py
index 3440cb86..ee1949e7 100644
--- a/geoopt/manifolds/base.py
+++ b/geoopt/manifolds/base.py
@@ -210,7 +210,7 @@ def set_default_order(self, order):
Returns
-------
- self
+ Manifold
returns same instance
"""
if order is None:
@@ -245,7 +245,7 @@ def reset_default_order(self):
Returns
-------
- self
+ Manifold
returns same instance
"""
return self.set_default_order(None)
@@ -471,7 +471,7 @@ def assert_check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5):
)
)
- def dist(self, x, y):
+ def dist(self, x, y, keepdim=False):
"""
Compute distance between 2 points on the manifold that is the shortest path along geodesics
@@ -481,13 +481,15 @@ def dist(self, x, y):
point on the manifold
y : tensor
point on the manifold
+ keepdim : bool
+ keep the last dim?
Returns
-------
scalar
distance between two points
"""
- return self._dist(x, y)
+ return self._dist(x, y, keepdim=keepdim)
def retr(self, x, u, t=1.0, order=None):
"""
@@ -638,7 +640,7 @@ def transp(self, x, v, *more, u=None, t=1.0, y=None, order=None):
else:
raise TypeError("transp() requires either y or u")
- def inner(self, x, u, v=None):
+ def inner(self, x, u, v=None, keepdim=False):
"""
Inner product for tangent vectors at point :math:`x`
@@ -650,6 +652,8 @@ def inner(self, x, u, v=None):
tangent vector at point :math:`x`
v : tensor (optional)
tangent vector at point :math:`x`
+ keepdim : bool
+ keep the last dim?
Returns
-------
@@ -658,7 +662,7 @@ def inner(self, x, u, v=None):
"""
if v is None and self._inner_autofill:
v = u
- return self._inner(x, u, v)
+ return self._inner(x, u, v, keepdim=keepdim)
# dev: autofill None parameter or propagate None?
_inner_autofill = True
@@ -907,7 +911,7 @@ def _retr(self, x, u, t):
_dist = not_implemented
@abc.abstractmethod
- def _inner(self, x, u, v):
+ def _inner(self, x, u, v, keepdim):
"""
Developer Guide
diff --git a/geoopt/manifolds/euclidean.py b/geoopt/manifolds/euclidean.py
index e4f5576d..60bf4f3c 100644
--- a/geoopt/manifolds/euclidean.py
+++ b/geoopt/manifolds/euclidean.py
@@ -24,7 +24,7 @@ def _check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5):
def _retr(self, x, u, t):
return x + t * u
- def _inner(self, x, u, v):
+ def _inner(self, x, u, v, keepdim):  # keepdim is accepted to match the base-class signature; it is not used below
return u * v
def _proju(self, x, u):
@@ -50,5 +50,5 @@ def _transp2y(self, x, v, *more, y):
def _logmap(self, x, y):
return y - x
- def _dist(self, x, y):
+ def _dist(self, x, y, keepdim):  # keepdim is accepted because Manifold.dist forwards it; it is not used below
return (x - y).abs()
diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py
new file mode 100644
index 00000000..1d8d723f
--- /dev/null
+++ b/geoopt/manifolds/poincare/__init__.py
@@ -0,0 +1,102 @@
+import torch
+from . import math
+from ..base import Manifold
+
+__all__ = ["PoincareBall"]
+
+
+class PoincareBall(Manifold):
+ """
+ Poincare ball model, see more in :doc:`/extended/poincare`
+
+ Parameters
+ ----------
+ c : float|tensor
+ ball negative curvature
+
+ Notes
+ -----
+ It is extremely recommended to work with this manifold in double precision
+ """
+
+ ndim = 1
+ reversible = False
+ _default_order = 1
+ name = "Poincare ball"
+
+ def __init__(self, c=1.0):
+ super().__init__()
+ self.register_buffer("c", torch.as_tensor(c))
+
+ def _check_shape(self, x, name):
+ ok = x.dim() > 0
+ if not ok:
+ reason = "'{}' on poincare ball requires more that zero dim".format(name)
+ else:
+ reason = None
+ return ok, reason
+
+ def _check_point_on_manifold(self, x, atol=1e-5, rtol=1e-5):
+ px = math.project(x, c=self.c)
+ ok = torch.allclose(x, px, atol=atol, rtol=rtol)
+ if not ok:
+ reason = "'x' norm lies out of the bounds [-1/sqrt(c)+eps, 1/sqrt(c)-eps]"
+ else:
+ reason = None
+ return ok, reason
+
+ def _check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5):
+ return True, None
+
+ def _dist(self, x, y, keepdim):
+ return math.dist(x, y, c=self.c, keepdim=keepdim)
+
+ def _egrad2rgrad(self, x, u):
+ return math.egrad2rgrad(x, u, c=self.c)
+
+ def _retr(self, x, u, t):
+ # always assume u is scaled properly
+ approx = x + u * t
+ return math.project(approx, c=self.c)
+
+ _retr_transp_default_preference = "2y"
+
+ def _projx(self, x):
+ return math.project(x, c=self.c)
+
+ def _proju(self, x, u):
+ return math.clip_tangent(x, u, c=self.c)
+
+ def _inner(self, x, u, v, keepdim):
+ return math.inner(x, u, v, c=self.c, keepdim=keepdim)
+
+ def _expmap(self, x, u, t):
+ return math.project(math.expmap(x, u * t, c=self.c), c=self.c)
+
+ def _logmap(self, x, y):
+ return math.logmap(x, y, c=self.c)
+
+ def _transp2y(self, x, v, *more, y):
+ if not more:
+ return math.parallel_transport(x, y, v, c=self.c)
+ else:
+ n = len(more) + 1
+ vecs = torch.stack((v,) + more, dim=0)
+ transp = math.parallel_transport(x, y, vecs, c=self.c)
+ return tuple(transp[i] for i in range(n))
+
+ def _transp_follow(self, x, v, *more, u, t):
+ y = self._retr(x, u, t)
+ return self._transp2y(x, v, *more, y=y)
+
+ def _expmap_transp(self, x, v, *more, u, t):
+ y = self._expmap(x, u, t)
+ vs = self._transp2y(x, v, *more, y=y)
+ if more:
+ return (y,) + vs
+ else:
+ return y, vs
+
+ def _transp_follow_expmap(self, x, v, *more, u, t):
+ y = self._expmap(x, u, t)
+ return self._transp2y(x, v, *more, y=y)
diff --git a/geoopt/manifolds/poincare/math.py b/geoopt/manifolds/poincare/math.py
new file mode 100644
index 00000000..bd8f9469
--- /dev/null
+++ b/geoopt/manifolds/poincare/math.py
@@ -0,0 +1,1347 @@
+"""
+Functions for math on Poincare ball model. Most of this is taken from
+a well written paper by Octavian-Eugen Ganea (2018) [1]_
+
+
+.. [1] Octavian-Eugen Ganea et al., Hyperbolic Neural Networks, NIPS 2018
+"""
+
+import functools
+import torch.jit
+
+
+ def tanh(x):
+     """Hyperbolic tangent of ``x`` after clamping the input to ``[-15, 15]``."""
+     clamped = x.clamp(-15, 15)
+     return clamped.tanh()
+
+
+ class Artanh(torch.autograd.Function):
+     """Inverse hyperbolic tangent with a clamped input and a hand-written gradient.
+
+     The forward pass is evaluated in double precision and the result is cast
+     back to the dtype of the input.
+     """
+
+     @staticmethod
+     def forward(ctx, x):
+         dtype = x.dtype
+         # Clamp *after* promoting to double: in float32 the bounds
+         # +-(1 - 1e-15) round to +-1, so clamping in the input dtype would not
+         # keep the argument away from the poles of the logarithm.
+         z = x.double().clamp(-1 + 1e-15, 1 - 1e-15)
+         ctx.save_for_backward(z)
+         res = (torch.log_(1 + z).sub_(torch.log_(1 - z))).mul_(0.5)
+         return res.to(dtype)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         z, = ctx.saved_tensors
+         # d/dx artanh(x) = 1 / (1 - x^2), evaluated at the clamped input
+         grad = grad_output.double() / (1 - z ** 2)
+         return grad.to(grad_output.dtype)
+
+
+ class Arsinh(torch.autograd.Function):
+     """Inverse hyperbolic sine evaluated in double precision with a hand-written gradient."""
+
+     @staticmethod
+     def forward(ctx, x):
+         ctx.save_for_backward(x)
+         z = x.double()
+         # arsinh is odd, so evaluate log(|z| + sqrt(1 + z^2)) and restore the
+         # sign afterwards. The direct form log(z + sqrt(1 + z^2)) cancels
+         # catastrophically for large negative z.
+         magnitude = (z.abs() + torch.sqrt(1 + z.pow(2))).log()
+         return (magnitude * z.sign()).to(x.dtype)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         saved, = ctx.saved_tensors
+         # d/dx arsinh(x) = 1 / sqrt(1 + x^2)
+         return grad_output / (1 + saved ** 2) ** 0.5
+
+
+ def artanh(x):
+     # Functional entry point for the custom autograd function ``Artanh``.
+     return Artanh.apply(x)
+
+
+ def arsinh(x):
+     # Functional entry point for the custom autograd function ``Arsinh``.
+     return Arsinh.apply(x)
+
+
+def project(x, *, c=1.0, dim=-1):
+ r"""
+ Safe projection on the manifold for numerical stability.
+
+ Parameters
+ ----------
+ x : tensor
+ point on the Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension to compute norm
+
+ Returns
+ -------
+ tensor
+ projected vector on the manifold
+ """
+ return _project(x, c, dim)
+
+
+ @torch.jit.script
+ def _max_norm(x):
+     # Largest allowed norm before scaling by the curvature: float32 gets a
+     # wider safety margin than every other dtype.
+     bound = 1 - 1e-5
+     if x.dtype == torch.float32:
+         bound = 1 - 3e-3
+     return torch.full((), bound, dtype=x.dtype, device=x.device)
+
+
+ def _project(x, c, dim: int = -1):
+     # Points whose norm exceeds the allowed radius are rescaled onto it;
+     # all other points are returned unchanged.
+     length = x.norm(dim=dim, keepdim=True, p=2)
+     radius = _max_norm(x) / (c ** 0.5)
+     rescaled = x / length * radius
+     return torch.where(length > radius, rescaled, x)
+
+
+def lambda_x(x, *, c=1.0, keepdim=False, dim=-1):
+ r"""
+ Compute the conformal factor :math:`\lambda^c_x` for a point on the ball
+
+ .. math::
+
+ \lambda^c_x = \frac{2}{1 - c \|x\|_2^2}
+
+ Parameters
+ ----------
+ x : tensor
+ point on the Poincare ball
+ c : float|tensor
+ ball negative curvature
+ keepdim : bool
+ retain the last dim? (default: false)
+ dim : int
+ reduction dimension
+
+ Returns
+ -------
+ tensor
+ conformal factor
+ """
+ return _lambda_x(x, c, keepdim=keepdim, dim=dim)
+
+
+ def _lambda_x(x, c, keepdim: bool = False, dim: int = -1):
+     # Conformal factor 2 / (1 - c * ||x||^2), with the norm reduced along ``dim``.
+     sq_norm = x.pow(2).sum(dim=dim, keepdim=keepdim)
+     return 2 / (1 - c * sq_norm)
+
+
+def inner(x, u, v, *, c=1.0, keepdim=False, dim=-1):
+ r"""
+ Compute inner product for two vectors on the tangent space w.r.t Riemannian metric on the Poincare ball
+
+ .. math::
+
+ \langle u, v\rangle_x = (\lambda^c_x)^2 \langle u, v \rangle
+
+ Parameters
+ ----------
+ x : tensor
+ point on the Poincare ball
+ u : tensor
+ tangent vector to :math:`x` on Poincare ball
+ v : tensor
+ tangent vector to :math:`x` on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ keepdim : bool
+ retain the last dim? (default: false)
+ dim : int
+ reduction dimension
+
+ Returns
+ -------
+ tensor
+ inner product
+ """
+ return _inner(x, u, v, c, keepdim=keepdim, dim=dim)
+
+
+ def _inner(x, u, v, c, keepdim: bool = False, dim: int = -1):
+     # Riemannian inner product: (lambda_x)^2 * <u, v>.
+     # Both factors are reduced with the same ``keepdim`` so their shapes
+     # match. A keepdim=True conformal factor of shape (B, 1) would broadcast
+     # against a keepdim=False dot product of shape (B,) into (B, B).
+     conformal = _lambda_x(x, c, keepdim=keepdim, dim=dim)
+     return conformal ** 2 * (u * v).sum(dim=dim, keepdim=keepdim)
+
+
+def norm(x, u, *, c=1.0, keepdim=False, dim=-1):
+ r"""
+ Compute vector norm on the tangent space w.r.t Riemannian metric on the Poincare ball
+
+ .. math::
+
+ \|u\|_x = \lambda^c_x \|u\|_2
+
+ Parameters
+ ----------
+ x : tensor
+ point on the Poincare ball
+ u : tensor
+ tangent vector to :math:`x` on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ keepdim : bool
+ retain the last dim? (default: false)
+ dim : int
+ reduction dimension
+
+ Returns
+ -------
+ tensor
+ norm of vector
+ """
+ return _norm(x, u, c, keepdim=keepdim, dim=dim)
+
+
+ def _norm(x, u, c, keepdim: bool = False, dim: int = -1):
+     # Riemannian norm: conformal factor at ``x`` times the Euclidean norm of ``u``.
+     euclid = u.norm(dim=dim, keepdim=keepdim, p=2)
+     return _lambda_x(x, c, keepdim=keepdim, dim=dim) * euclid
+
+
+def mobius_add(x, y, *, c=1.0, dim=-1):
+ r"""
+ Mobius addition is a special operation in a hyperbolic space.
+
+ .. math::
+
+ x \oplus_c y = \frac{
+ (1 + 2 c \langle x, y\rangle + c \|y\|^2_2) x + (1 - c \|x\|_2^2) y
+ }{
+ 1 + 2 c \langle x, y\rangle + c^2 \|x\|^2_2 \|y\|^2_2
+ }
+
+ .. plot:: plots/extended/poincare/mobius_add.py
+
+ In general this operation is not commutative:
+
+ .. math::
+
+ x \oplus_c y \ne y \oplus_c x
+
+ But in some cases this property holds:
+
+ * zero vector case
+
+ .. math::
+
+ \mathbf{0} \oplus_c x = x \oplus_c \mathbf{0}
+
+ * zero negative curvature case that is same as Euclidean addition
+
+ .. math::
+
+ x \oplus_0 y = y \oplus_0 x
+
+ Another useful property is so called left-cancellation law:
+
+ .. math::
+
+ (-x) \oplus_c (x \oplus_c y) = y
+
+ Parameters
+ ----------
+ x : tensor
+ point on the Poincare ball
+ y : tensor
+ point on the Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ the result of mobius addition
+ """
+ return _mobius_add(x, y, c, dim=dim)
+
+
+ def _mobius_add(x, y, c, dim=-1):
+     # A small offset on ``y`` and on the denominator guards against division by zero.
+     y = y + 1e-15
+     dot_xy = (x * y).sum(dim=dim, keepdim=True)
+     sq_x = x.pow(2).sum(dim=dim, keepdim=True)
+     sq_y = y.pow(2).sum(dim=dim, keepdim=True)
+     numerator = (1 + 2 * c * dot_xy + c * sq_y) * x + (1 - c * sq_x) * y
+     denominator = 1 + 2 * c * dot_xy + c ** 2 * sq_x * sq_y
+     return numerator / (denominator + 1e-15)
+
+
+def mobius_sub(x, y, *, c=1.0, dim=-1):
+ r"""
+ Mobius subtraction that can be represented via Mobius addition as follows:
+
+ .. math::
+
+ x \ominus_c y = x \oplus_c (-y)
+
+ Parameters
+ ----------
+ x : tensor
+ point on Poincare ball
+ y : tensor
+ point on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ the result of mobius subtraction
+ """
+ return _mobius_sub(x, y, c, dim=dim)
+
+
+ def _mobius_sub(x, y, c, dim: int = -1):
+     # Subtraction is Mobius addition of the negated second operand.
+     return _mobius_add(x, -y, c, dim=dim)
+
+
+def mobius_coadd(x, y, *, c=1.0, dim=-1):
+ r"""
+ Mobius coaddition operation
+
+ Addition operation :math:`\oplus_c` is neither associative, nor commutative. Coaddition, or cooperation in
+ Gyrogroup is a commutative operation that is defined as follows.
+
+ .. math::
+
+ a \boxplus_c b = b \boxplus_c a = a \oplus_c \operatorname{gyr}[a, -b]b\\
+ = \frac{
+ (1 - c \|b\|^2_2) a + (1 - c \|a\|_2^2) b
+ }{
+ 1 - c^2 \|a\|^2_2 \|b\|^2_2
+ },
+
+ where :math:`\operatorname{gyr}[a, b]c = \ominus_c (a \oplus b) \oplus_c (a \oplus_c (b \oplus_c c))`
+
+ The following right cancellation property holds
+
+ .. math::
+
+ (a \boxplus_c b) \ominus_c b = a\\
+ (a \oplus_c b) \boxminus_c b = a
+
+ Parameters
+ ----------
+ x : tensor
+ point on Poincare ball
+ y : tensor
+ point on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ the result of mobius coaddition
+
+ """
+ return _mobius_coadd(x, y, c, dim=dim)
+
+
+ def _mobius_coadd(x, y, c, dim: int = -1):
+     # A small offset on ``y`` and on the denominator guards against division by zero.
+     y = y + 1e-15
+     sq_x = x.pow(2).sum(dim=dim, keepdim=True)
+     sq_y = y.pow(2).sum(dim=dim, keepdim=True)
+     numerator = (1 - c * sq_y) * x + (1 - c * sq_x) * y
+     denominator = 1 - c ** 2 * sq_x * sq_y
+     return numerator / (denominator + 1e-15)
+
+
+ def mobius_cosub(x, y, *, c=1.0, dim=-1):
+     r"""
+     Mobius cosubtraction operation
+
+     .. math::
+
+         a \boxminus_c b = a \boxplus_c -b
+
+     Parameters
+     ----------
+     x : tensor
+         point on Poincare ball
+     y : tensor
+         point on Poincare ball
+     c : float|tensor
+         ball negative curvature
+     dim : int
+         reduction dimension for operations
+
+     Returns
+     -------
+     tensor
+         the result of mobius cosubtraction
+
+     """
+     return _mobius_cosub(x, y, c, dim=dim)
+
+
+ def _mobius_cosub(x, y, c, dim: int = -1):
+     # Cosubtraction is Mobius coaddition of the negated second operand.
+     return _mobius_coadd(x, -y, c, dim=dim)
+
+
+def mobius_scalar_mul(r, x, *, c=1.0, dim=-1):
+ r"""
+ Left scalar multiplication on the Poincare ball
+
+ .. math::
+
+ r \otimes_c x = (1/\sqrt{c}) \tanh(r\tanh^{-1}(\sqrt{c}\|x\|_2))\frac{x}{\|x\|_2}
+
+ This operation has properties similar to Euclidean
+
+ * `n-addition` property
+
+ .. math::
+
+ r \otimes_c x = x \oplus_c \dots \oplus_c x
+
+ * Distributive property
+
+ .. math::
+
+ (r_1 + r_2) \otimes_c x = r_1 \otimes_c x \oplus r_2 \otimes_c x
+
+ * Scalar associativity
+
+ .. math::
+
+ (r_1 r_2) \otimes_c x = r_1 \otimes_c (r_2 \otimes_c x)
+
+ * Scaling property
+
+ .. math::
+
+ |r| \otimes_c x / \|r \otimes_c x\|_2 = x/\|x\|_2
+
+ Parameters
+ ----------
+ r : float|tensor
+ scalar for multiplication
+ x : tensor
+ point on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ the result of mobius scalar multiplication
+ """
+ return _mobius_scalar_mul(r, x, c, dim=dim)
+
+
+ def _mobius_scalar_mul(r, x, c, dim: int = -1):
+     # r (*) x = tanh(r * artanh(sqrt(c) * ||x||)) * x / (sqrt(c) * ||x||)
+     # the tiny offset keeps the norm of a zero vector away from exact zero
+     x = x + 1e-15
+     root_c = c ** 0.5
+     length = x.norm(dim=dim, keepdim=True, p=2)
+     scale = tanh(r * artanh(root_c * length))
+     return scale * x / (length * root_c)
+
+
+def dist(x, y, *, c=1.0, keepdim=False, dim=-1):
+ r"""
+ Distance on the Poincare ball
+
+ .. math::
+
+ d_c(x, y) = \frac{2}{\sqrt{c}}\tanh^{-1}(\sqrt{c}\|(-x)\oplus_c y\|_2)
+
+ .. plot:: plots/extended/poincare/distance.py
+
+ Parameters
+ ----------
+ x : tensor
+ point on Poincare ball
+ y : tensor
+ point on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ keepdim : bool
+ retain the last dim? (default: false)
+ dim : int
+ reduction dimension
+
+ Returns
+ -------
+ tensor
+ geodesic distance between :math:`x` and :math:`y`
+ """
+ return _dist(x, y, c, keepdim=keepdim, dim=dim)
+
+
+ def _dist(x, y, c, keepdim: bool = False, dim: int = -1):
+     # d_c(x, y) = 2 / sqrt(c) * artanh(sqrt(c) * ||(-x) (+) y||)
+     root_c = c ** 0.5
+     gap = _mobius_add(-x, y, c, dim=dim)
+     gap_norm = gap.norm(dim=dim, p=2, keepdim=keepdim)
+     return artanh(root_c * gap_norm) * 2 / root_c
+
+
+def dist0(x, *, c=1.0, keepdim=False, dim=-1):
+ r"""
+ Distance on the Poincare ball to zero
+
+ Parameters
+ ----------
+ x : tensor
+ point on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ keepdim : bool
+ retain the last dim? (default: false)
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ geodesic distance between :math:`x` and :math:`0`
+ """
+ return _dist0(x, c, keepdim=keepdim, dim=dim)
+
+
+ def _dist0(x, c, keepdim: bool = False, dim: int = -1):
+     # Distance to the origin: 2 / sqrt(c) * artanh(sqrt(c) * ||x||)
+     root_c = c ** 0.5
+     length = x.norm(dim=dim, p=2, keepdim=keepdim)
+     return artanh(root_c * length) * 2 / root_c
+
+
+def clip_tangent(x, u, *, c=1.0, dim=-1):
+ r"""
+ Project tangent vector to reasonable values that do not exceed
+ maximum allowed (vector norm allowing to travel to the opposite pole)
+
+ .. math::
+
+ \operatorname{maxnorm}_x = d_{c}(\operatorname{proj}(-\infty), \operatorname{proj}(\infty)) / \lambda_x^c
+
+ Parameters
+ ----------
+ x : tensor
+ point on Poincare ball
+ u : tensor
+ tangent vector
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension to compute norm
+
+ Returns
+ -------
+ tensor
+ same tangent vector with reasonable values
+ """
+ return _clip_tangent(x, u, c, dim=dim)
+
+
+ def _clip_tangent(x, u, c, dim: int = -1):
+     # Reference "pole": a vector with equal components and norm 1/sqrt(c),
+     # pulled back inside the ball by ``_project``.
+     size = x.size(dim)
+     pole = torch.ones((size,), dtype=x.dtype, device=x.device)
+     pole = pole / size ** 0.5 / (c ** 0.5)
+     pole = _project(pole, c, dim=dim)
+     # The distance between the two opposite poles, divided by the conformal
+     # factor at ``x``, is the largest tangent norm that is kept as is.
+     limit = _dist(pole, -pole, c, keepdim=True, dim=dim) / _lambda_x(
+         x, c, keepdim=True, dim=dim
+     )
+     length = u.norm(dim=dim, keepdim=True, p=2)
+     clipped = u / length * limit
+     return torch.where(length > limit, clipped, u)
+
+
+def geodesic(t, x, y, *, c=1.0, dim=-1):
+ r"""
+ Geodesic (the shortest) path connecting :math:`x` and :math:`y`.
+ The path can be treated as an extension of a line segment between
+ points but in a Riemannian manifold. In Poincare ball model, the path
+ is expressed using Mobius addition and scalar multiplication:
+
+ .. math::
+
+ \gamma_{x\to y}(t) = x \oplus_c t \otimes_c ((-x) \oplus_c y)
+
+ The required properties of this path are the following:
+
+ .. math::
+
+ \gamma_{x\to y}(0) = x\\
+ \gamma_{x\to y}(1) = y\\
+ \dot\gamma_{x\to y}(t) = v
+
+ Moreover, a geodesic is not only the shortest path connecting two points on the Poincare ball.
+ Its definition also requires local distance minimization and thus another property appears:
+
+ .. math::
+
+ d_c(\gamma_{x\to y}(t_1), \gamma_{x\to y}(t_2)) = v|t_1-t_2|
+
+ "Natural parametrization" of the curve ensures unit speed geodesics which yields the above formula with :math:`v=1`.
+ However, for Poincare ball we can always compute the constant speed :math:`v` from the points
+ that particular path connects:
+
+ .. math::
+
+ v = d_c(\gamma_{x\to y}(0), \gamma_{x\to y}(1)) = d_c(x, y)
+
+
+ Parameters
+ ----------
+ t : float|tensor
+ travelling time
+ x : tensor
+ starting point on Poincare ball
+ y : tensor
+ target point on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ point on the Poincare ball
+ """
+ return _geodesic(t, x, y, c, dim=dim)
+
+
+ def _geodesic(t, x, y, c, dim: int = -1):
+     # gamma(t) = x (+) t (*) ((-x) (+) y)
+     # NOTE: the original author left a remark about the numerical stability
+     # of this composition of Mobius operations.
+     direction = _mobius_add(-x, y, c, dim=dim)
+     step = _mobius_scalar_mul(t, direction, c, dim=dim)
+     return _mobius_add(x, step, c, dim=dim)
+
+
+def expmap(x, u, *, c=1.0, dim=-1):
+ r"""
+ Exponential map for Poincare ball model. This is tightly related with :func:`geodesic`.
+ Intuitively Exponential map is a smooth constant travelling from starting point :math:`x` with speed :math:`u`.
+
+ A bit more formally this is travelling along curve :math:`\gamma_{x, u}(t)` such that
+
+ .. math::
+
+ \gamma_{x, u}(0) = x\\
+ \dot\gamma_{x, u}(0) = u\\
+ \|\dot\gamma_{x, u}(t)\|_{\gamma_{x, u}(t)} = \|u\|_x
+
+ The existence of this curve relies on uniqueness of differential equation solution, that is local.
+ For the Poincare ball model the solution is well defined globally and we have.
+
+ .. math::
+
+ \operatorname{Exp}^c_x(u) = \gamma_{x, u}(1) = \\
+ x\oplus_c \tanh(\sqrt{c}/2 \|u\|_x) \frac{u}{\sqrt{c}\|u\|_2}
+
+ Parameters
+ ----------
+ x : tensor
+ starting point on Poincare ball
+ u : tensor
+ speed vector on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ :math:`\gamma_{x, u}(1)` end point
+ """
+ return _expmap(x, u, c, dim=dim)
+
+
+ def _expmap(x, u, c, dim: int = -1):
+     # Exp_x(u) = x (+) tanh(sqrt(c)/2 * lambda_x * ||u||) * u / (sqrt(c) * ||u||)
+     # The offset is applied out of place: ``u += ...`` would silently modify
+     # the caller's tensor and fails on leaf tensors that require grad.
+     u = u + 1e-15
+     sqrt_c = c ** 0.5
+     u_norm = u.norm(dim=dim, p=2, keepdim=True)
+     second_term = (
+         tanh(sqrt_c / 2 * _lambda_x(x, c, keepdim=True, dim=dim) * u_norm)
+         * u
+         / (sqrt_c * u_norm)
+     )
+     gamma_1 = _mobius_add(x, second_term, c, dim=dim)
+     return gamma_1
+
+
+def expmap0(u, *, c=1.0, dim=-1):
+ r"""
+ Exponential map for Poincare ball model from :math:`0`.
+
+ .. math::
+
+ \operatorname{Exp}^c_0(u) = \tanh(\sqrt{c} \|u\|_2) \frac{u}{\sqrt{c}\|u\|_2}
+
+ Parameters
+ ----------
+ u : tensor
+ speed vector on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ :math:`\gamma_{0, u}(1)` end point
+ """
+ return _expmap0(u, c, dim=dim)
+
+
+ def _expmap0(u, c, dim: int = -1):
+     # Exp_0(u) = tanh(sqrt(c) * ||u||) * u / (sqrt(c) * ||u||)
+     # the tiny offset keeps the norm of a zero vector away from exact zero
+     u = u + 1e-15
+     root_c = c ** 0.5
+     length = u.norm(dim=dim, p=2, keepdim=True)
+     return tanh(root_c * length) * u / (root_c * length)
+
+
+def geodesic_unit(t, x, u, *, c=1.0, dim=-1):
+ r"""
+ Unit speed geodesic starting from :math:`x` with direction :math:`u/\|u\|_x`
+
+ .. math::
+
+ \gamma_{x,u}(t) = x\oplus_c \tanh(t\sqrt{c}/2) \frac{u}{\sqrt{c}\|u\|_2}
+
+ Parameters
+ ----------
+ t : tensor
+ travelling time
+ x : tensor
+ initial point
+ u : tensor
+ direction
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ the point on geodesic line
+ """
+ return _geodesic_unit(t, x, u, c, dim=dim)
+
+
+ def _geodesic_unit(t, x, u, c, dim: int = -1):
+     # gamma(t) = x (+) tanh(t * sqrt(c) / 2) * u / (sqrt(c) * ||u||)
+     root_c = c ** 0.5
+     length = u.norm(dim=dim, p=2, keepdim=True)
+     offset = tanh(root_c / 2 * t) * u / (root_c * length)
+     return _mobius_add(x, offset, c, dim=dim)
+
+
+def logmap(x, y, *, c=1.0, dim=-1):
+ r"""
+ Logarithmic map for two points :math:`x` and :math:`y` on the manifold.
+
+ .. math::
+
+ \operatorname{Log}^c_x(y) = \frac{2}{\sqrt{c}\lambda_x^c} \tanh^{-1}(
+ \sqrt{c} \|(-x)\oplus_c y\|_2
+ ) * \frac{(-x)\oplus_c y}{\|(-x)\oplus_c y\|_2}
+
+ The result of Logarithmic map is a vector such that
+
+ .. math::
+
+ y = \operatorname{Exp}^c_x(\operatorname{Log}^c_x(y))
+
+
+ Parameters
+ ----------
+ x : tensor
+ starting point on Poincare ball
+ y : tensor
+ target point on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ tangent vector that transports :math:`x` to :math:`y`
+ """
+ return _logmap(x, y, c, dim=dim)
+
+
+ def _logmap(x, y, c, dim: int = -1):
+     # Log_x(y) = 2 / (sqrt(c) * lambda_x) * artanh(sqrt(c) * ||d||) * d / ||d||
+     # with d = (-x) (+) y
+     root_c = c ** 0.5
+     delta = _mobius_add(-x, y, c, dim=dim)
+     delta_norm = delta.norm(dim=dim, p=2, keepdim=True)
+     conformal = _lambda_x(x, c, keepdim=True, dim=dim)
+     scale = 2 / root_c / conformal * artanh(root_c * delta_norm)
+     return scale * delta / delta_norm
+
+
+def logmap0(y, *, c=1.0, dim=-1):
+ r"""
+ Logarithmic map for :math:`y` from :math:`0` on the manifold.
+
+
+ .. math::
+
+ \operatorname{Log}^c_0(y) = \tanh^{-1}(\sqrt{c}\|y\|_2) \frac{y}{\sqrt{c}\|y\|_2}
+
+ The result is such that
+
+ .. math::
+
+ y = \operatorname{Exp}^c_0(\operatorname{Log}^c_0(y))
+
+ Parameters
+ ----------
+ y : tensor
+ target point on Poincare ball
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ tangent vector that transports :math:`0` to :math:`y`
+ """
+ return _logmap0(y, c, dim=dim)
+
+
+ def _logmap0(y, c, dim: int = -1):
+     # Log_0(y) = artanh(sqrt(c) * ||y||) * y / (sqrt(c) * ||y||)
+     root_c = c ** 0.5
+     # the tiny offset keeps the norm of a zero vector away from exact zero
+     y = y + 1e-15
+     length = y.norm(dim=dim, p=2, keepdim=True)
+     unit = y / length
+     return unit / root_c * artanh(root_c * length)
+
+
+def mobius_matvec(m, x, *, c=1.0, dim=-1):
+ r"""
+ Generalization for matrix-vector multiplication to hyperbolic space defined as
+
+ .. math::
+
+ M \otimes_c x = (1/\sqrt{c}) \tanh\left(
+ \frac{\|Mx\|_2}{\|x\|_2}\tanh^{-1}(\sqrt{c}\|x\|_2)
+ \right)\frac{Mx}{\|Mx\|_2}
+
+ .. plot:: plots/extended/poincare/mobius_matvec.py
+
+ Parameters
+ ----------
+ m : tensor
+ matrix for multiplication.
+ Batched matmul is performed if ``m.dim() > 2``, but only last dim reduction is supported
+ x : tensor
+ point on Poincare ball
+ c : float|tensor
+ negative ball curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ Mobius matvec result
+ """
+ return _mobius_matvec(m, x, c, dim=dim)
+
+
+ def _mobius_matvec(m, x, c, dim: int = -1):
+     # Mobius matrix-vector product of ``m`` and ``x`` for curvature ``c``.
+     # Batched matrices go through ``torch.matmul`` below, which is only
+     # used when ``dim`` is -1, so any other ``dim`` is rejected up front.
+     if m.dim() > 2 and dim != -1:
+         raise RuntimeError(
+             "broadcasted Mobius matvec is supported for the last dim only"
+         )
+     # tiny offset keeps ``x_norm`` away from exact zero in the division below
+     x = x + 1e-15
+     x_norm = x.norm(dim=dim, keepdim=True, p=2)
+     sqrt_c = c ** 0.5
+     if dim != -1 or m.dim() == 2:
+         # single matrix: contract ``x`` along ``dim`` with dim 1 of ``m``
+         mx = torch.tensordot(x, m, dims=([dim], [1]))
+     else:
+         # batched matrices applied to ``x`` as a column vector
+         mx = torch.matmul(m, x.unsqueeze(-1)).squeeze(-1)
+     mx_norm = mx.norm(dim=dim, keepdim=True, p=2)
+     res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
+     # where every component of ``mx`` along ``dim`` is exactly zero, zeros are
+     # returned instead of ``res_c`` (which divides by ``mx_norm``)
+     cond = (mx == 0).prod(dim=dim, keepdim=True, dtype=torch.uint8)
+     res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
+     res = torch.where(cond, res_0, res_c)
+     return res
+
+
+def mobius_pointwise_mul(w, x, *, c=1.0, dim=-1):
+ r"""
+ Generalization for point-wise multiplication to hyperbolic space defined as
+
+ .. math::
+
+ \operatorname{diag}(w) \otimes_c x = (1/\sqrt{c}) \tanh\left(
+ \frac{\|\operatorname{diag}(w)x\|_2}{\|x\|_2}\tanh^{-1}(\sqrt{c}\|x\|_2)
+ \right)\frac{\operatorname{diag}(w)x}{\|\operatorname{diag}(w)x\|_2}
+
+
+ Parameters
+ ----------
+ w : tensor
+ weights for multiplication (should be broadcastable to x)
+ x : tensor
+ point on Poincare ball
+ c : float|tensor
+ negative ball curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ Mobius point-wise mul result
+ """
+ return _mobius_pointwise_mul(w, x, c, dim=dim)
+
+
+ def _mobius_pointwise_mul(w, x, c, dim: int = -1):
+     # Mobius point-wise product of weights ``w`` and point ``x`` for curvature ``c``.
+     # tiny offset keeps ``x_norm`` away from exact zero in the division below
+     x = x + 1e-15
+     x_norm = x.norm(dim=dim, keepdim=True, p=2)
+     sqrt_c = c ** 0.5
+     wx = w * x
+     wx_norm = wx.norm(dim=dim, keepdim=True, p=2)
+     res_c = tanh(wx_norm / x_norm * artanh(sqrt_c * x_norm)) * wx / (wx_norm * sqrt_c)
+     # where every component of ``wx`` along ``dim`` is exactly zero, zeros are
+     # returned instead of ``res_c`` (which divides by ``wx_norm``)
+     cond = (wx == 0).prod(dim=dim, keepdim=True, dtype=torch.uint8)
+     res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
+     res = torch.where(cond, res_0, res_c)
+     return res
+
+
+def mobius_fn_apply_chain(x, *fns, c=1.0, dim=-1):
+ r"""
+ Generalization for functions in hyperbolic space.
+ First, hyperbolic vector is mapped to a Euclidean space via
+ :math:`\operatorname{Log}^c_0` and nonlinear function is applied in this tangent space.
+ The resulting vector is then mapped back with :math:`\operatorname{Exp}^c_0`
+
+ .. math::
+
+ f^{\otimes_c}(x) = \operatorname{Exp}^c_0(f(\operatorname{Log}^c_0(x)))
+
+ The definition of mobius function application allows chaining as
+
+ .. math::
+
+ y = \operatorname{Exp}^c_0(\operatorname{Log}^c_0(y))
+
+ Resulting in
+
+ .. math::
+
+ (f \circ g)^{\otimes_c}(x) = \operatorname{Exp}^c_0((f \circ g) (\operatorname{Log}^c_0(x)))
+
+ Parameters
+ ----------
+ x : tensor
+ point on Poincare ball
+ fns : callable[]
+ functions to apply
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ Apply chain result
+ """
+ if not fns:
+ return x
+ else:
+ ex = _logmap0(x, c, dim=dim)
+ for fn in fns:
+ ex = fn(ex)
+ y = _expmap0(ex, c, dim=dim)
+ return y
+
+
+def mobius_fn_apply(fn, x, *args, c=1.0, dim=-1, **kwargs):
+ r"""
+ Generalization for functions in hyperbolic space.
+ First, hyperbolic vector is mapped to a Euclidean space via
+ :math:`\operatorname{Log}^c_0` and nonlinear function is applied in this tangent space.
+ The resulting vector is then mapped back with :math:`\operatorname{Exp}^c_0`
+
+ .. math::
+
+ f^{\otimes_c}(x) = \operatorname{Exp}^c_0(f(\operatorname{Log}^c_0(x)))
+
+ .. plot:: plots/extended/poincare/mobius_sigmoid_apply.py
+
+ Parameters
+ ----------
+ x : tensor
+ point on Poincare ball
+ fn : callable
+ function to apply
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ Result of function in hyperbolic space
+ """
+ ex = _logmap0(x, c, dim=dim)
+ ex = fn(ex, *args, **kwargs)
+ y = _expmap0(ex, c, dim=dim)
+ return y
+
+
+ def mobiusify(fn):
+     r"""
+     Wrap a function so that it works in hyperbolic space. The new function accepts
+     the additional keyword arguments ``c`` and ``dim``.
+
+     Parameters
+     ----------
+     fn : callable
+         function in Euclidean space, only its first argument is treated as hyperbolic
+
+     Returns
+     -------
+     callable
+         function working in hyperbolic space
+     """
+
+     @functools.wraps(fn)
+     def mobius_fn(x, *args, c=1.0, dim=-1, **kwargs):
+         # map to the tangent space at zero, apply ``fn``, map back to the ball
+         tangent = _logmap0(x, c, dim=dim)
+         mapped = fn(tangent, *args, **kwargs)
+         return _expmap0(mapped, c, dim=dim)
+
+     return mobius_fn
+
+
+def dist2plane(x, p, a, *, c=1.0, keepdim=False, signed=False, dim=-1):
+ r"""
+ Distance from :math:`x` to a hyperbolic hyperplane in Poincare ball
+ that is orthogonal to :math:`a` and contains :math:`p`.
+
+ .. plot:: plots/extended/poincare/distance2plane.py
+
+ To form an intuition what is a hyperbolic hyperplane, let's first consider Euclidean hyperplane
+
+ .. math::
+
+ H_{a, b} = \left\{
+ x \in \mathbb{R}^n\;:\;\langle x, a\rangle - b = 0
+ \right\},
+
+ where :math:`a\in \mathbb{R}^n\backslash \{\mathbf{0}\}` and :math:`b\in \mathbb{R}`.
+
+ This formulation of a hyperplane is hard to generalize,
+ therefore we can rewrite :math:`\langle x, a\rangle - b`
+ utilizing orthogonal completion.
+ Setting any :math:`p` s.t. :math:`b=\langle a, p\rangle` we have
+
+ .. math::
+
+ H_{a, b} = \left\{
+ x \in \mathbb{R}^n\;:\;\langle x, a\rangle - b = 0
+ \right\}\\
+ =H_{a, \langle a, p\rangle} = \tilde{H}_{a, p}\\
+ = \left\{
+ x \in \mathbb{R}^n\;:\;\langle x, a\rangle - \langle a, p\rangle = 0
+ \right\}\\
+ =\left\{
+ x \in \mathbb{R}^n\;:\;\langle -p + x, a\rangle = 0
+ \right\}\\
+ = p + \{a\}^\perp
+
+ Naturally we have a set :math:`\{a\}^\perp` with applied :math:`+` operator to each element.
+ Generalizing a notion of summation to the hyperbolic space we replace :math:`+` with :math:`\oplus_c`.
+
+ Next, we should figure out what is :math:`\{a\}^\perp` in the Poincare ball.
+
+ First thing that we should acknowledge is that notion of orthogonality is defined for vectors in tangent spaces.
+ Let's consider now :math:`p\in \mathbb{D}_c^n` and :math:`a\in T_p\mathbb{D}_c^n\backslash \{\mathbf{0}\}`.
+
+ Slightly deviating from traditional notation let's write :math:`\{a\}_p^\perp`
+ highlighting the tight relationship of :math:`a\in T_p\mathbb{D}_c^n\backslash \{\mathbf{0}\}`
+ with :math:`p \in \mathbb{D}_c^n`. We then define
+
+ .. math::
+
+ \{a\}_p^\perp := \left\{
+ z\in T_p\mathbb{D}_c^n \;:\; \langle z, a\rangle_p = 0
+ \right\}
+
+ Recalling that a tangent vector :math:`z` for point :math:`p` yields :math:`x = \operatorname{Exp}^c_p(z)`
+ we rewrite the above equation as
+
+ .. math::
+ \{a\}_p^\perp := \left\{
+ x\in \mathbb{D}_c^n \;:\; \langle \operatorname{Log}_p^c(x), a\rangle_p = 0
+ \right\}
+
+ This formulation is something more pleasant to work with.
+ Putting all together
+
+ .. math::
+
+ \tilde{H}_{a, p}^c = p + \{a\}^\perp_p\\
+ = \left\{
+ x \in \mathbb{D}_c^n\;:\;\langle\operatorname{Log}^c_p(x), a\rangle_p = 0
+ \right\} \\
+ = \left\{
+ x \in \mathbb{D}_c^n\;:\;\langle -p \oplus_c x, a\rangle = 0
+ \right\}
+
+ To compute the distance :math:`d_c(x, \tilde{H}_{a, p}^c)` we find
+
+ .. math::
+
+ d_c(x, \tilde{H}_{a, p}^c) = \inf_{w\in \tilde{H}_{a, p}^c} d_c(x, w)\\
+ = \frac{1}{\sqrt{c}} \sinh^{-1}\left\{
+ \frac{
+ 2\sqrt{c} |\langle(-p)\oplus_c x, a\rangle|
+ }{
+ (1-c\|(-p)\oplus_c x\|^2_2)\|a\|_2
+ }
+ \right\}
+
+ Parameters
+ ----------
+ x : tensor
+ point on Poincare ball
+ a : tensor
+ vector on tangent space of :math:`p`
+ p : tensor
+ point on Poincare ball lying on the hyperplane
+ c : float|tensor
+ ball negative curvature
+ keepdim : bool
+ retain the last dim? (default: false)
+ signed : bool
+ return signed distance
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ distance to the hyperplane
+ """
+ return _dist2plane(x, a, p, c, keepdim=keepdim, signed=signed, dim=dim)
+
+
+ def _dist2plane(x, a, p, c, keepdim: bool = False, signed: bool = False, dim: int = -1):
+     # d = arsinh(2 sqrt(c) <d, a> / ((1 - c ||d||^2) ||a||)) / sqrt(c),  d = (-p) (+) x
+     root_c = c ** 0.5
+     shifted = _mobius_add(-p, x, c, dim=dim)
+     shifted_sq = shifted.pow(2).sum(dim=dim, keepdim=keepdim)
+     projection = (shifted * a).sum(dim=dim, keepdim=keepdim)
+     # the unsigned distance drops the sign of the projection onto ``a``
+     projection = projection if signed else projection.abs()
+     a_len = a.norm(dim=dim, keepdim=keepdim, p=2)
+     numerator = 2 * root_c * projection
+     denominator = (1 - c * shifted_sq) * a_len
+     return arsinh(numerator / (denominator + 1e-15)) / root_c
+
+
+def gyration(a, b, u, *, c=1.0, dim=-1):
+ r"""
+ Gyration is a special operation in hyperbolic geometry.
+ Addition operation :math:`\oplus_c` is not associative (as mentioned in :func:`mobius_add`),
+ but gyroassociative which means
+
+ .. math::
+
+ u \oplus_c (v \oplus_c w) = (u\oplus_c v) \oplus_c \operatorname{gyr}[u, v]w,
+
+ where
+
+ .. math::
+
+ \operatorname{gyr}[u, v]w = \ominus (u \oplus_c v) \oplus (u \oplus_c (v \oplus_c w))
+
+ We can simplify this equation using explicit formula for Mobius addition [1]. Recall
+
+ .. math::
+
+ A = - c^2 \langle u, w\rangle \langle v, v\rangle + c \langle v, w\rangle +
+ 2 c^2 \langle u, v\rangle \langle v, w\rangle\\
+ B = - c^2 \langle v, w\rangle \langle u, u\rangle - c \langle u, w\rangle\\
+ D = 1 + 2 c \langle u, v\rangle + c^2 \langle u, u\rangle \langle v, v\rangle\\
+
+ \operatorname{gyr}[u, v]w = w + 2 \frac{A u + B v}{D}
+
+ Parameters
+ ----------
+ a : tensor
+ first point on Poincare ball
+ b : tensor
+ second point on Poincare ball
+ u : tensor
+ vector field for operation
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ the result of automorphism
+
+ References
+ ----------
+ [1] A. A. Ungar (2009), A Gyrovector Space Approach to Hyperbolic Geometry
+ """
+ return _gyration(a, b, u, c, dim=dim)
+
+
+ def _gyration(u, v, w, c, dim: int = -1):
+     # Closed form of gyr[u, v]w = -(u (+) v) (+) (u (+) (v (+) w)):
+     #     gyr[u, v]w = w + 2 * (A u + B v) / D
+     # which avoids three nested Mobius additions.
+     sq_u = u.pow(2).sum(dim=dim, keepdim=True)
+     sq_v = v.pow(2).sum(dim=dim, keepdim=True)
+     dot_uv = (u * v).sum(dim=dim, keepdim=True)
+     dot_uw = (u * w).sum(dim=dim, keepdim=True)
+     dot_vw = (v * w).sum(dim=dim, keepdim=True)
+     c_sq = c ** 2
+     coef_u = -c_sq * dot_uw * sq_v + c * dot_vw + 2 * c_sq * dot_uv * dot_vw
+     coef_v = -c_sq * dot_vw * sq_u - c * dot_uw
+     denom = 1 + 2 * c * dot_uv + c_sq * sq_u * sq_v
+     # small offset guards against division by zero
+     return w + 2 * (coef_u * u + coef_v * v) / (denom + 1e-15)
+
+
+def parallel_transport(x, y, v, *, c=1.0, dim=-1):
+ r"""
+ Parallel transport is essential for adaptive algorithms in Riemannian manifolds.
+ For Hyperbolic spaces parallel transport is expressed via gyration.
+
+ .. plot:: plots/extended/poincare/gyrovector_parallel_transport.py
+
+ To recover parallel transport we first need to study isomorphism between gyrovectors and vectors.
+ The reason is that originally, parallel transport is well defined for gyrovectors as
+
+ .. math::
+
+ P_{x\to y}(z) = \operatorname{gyr}[y, -x]z,
+
+ where :math:`x,\:y,\:z \in \mathbb{D}_c^n` and
+ :math:`\operatorname{gyr}[a, b]c = \ominus (a \oplus_c b) \oplus_c (a \oplus_c (b \oplus_c c))`
+
+ But we want to obtain parallel transport for vectors, not for gyrovectors.
+ The blessing is isomorphism mentioned above. This mapping is given by
+
+ .. math::
+
+ U^c_p \: : \: T_p\mathbb{D}_c^n \to \mathbb{G} = v \mapsto \lambda^c_p v
+
+
+ Finally, having points :math:`x,\:y \in \mathbb{D}_c^n` and a tangent vector :math:`v\in T_x\mathbb{D}_c^n` we obtain
+
+ .. math::
+
+ P^c_{x\to y}(v) = (U^c_y)^{-1}\left(\operatorname{gyr}[y, -x] U^c_x(v)\right)\\
+ = \operatorname{gyr}[y, -x] v \lambda^c_x / \lambda^c_y
+
+ .. plot:: plots/extended/poincare/parallel_transport.py
+
+
+ Parameters
+ ----------
+ x : tensor
+ starting point
+ y : tensor
+ end point
+ v : tensor
+ tangent vector to be transported
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ transported vector
+ """
+ return _parallel_transport(x, y, v, c, dim=dim)
+
+
+ def _parallel_transport(x, y, u, c, dim: int = -1):
+     # P_{x -> y}(u) = gyr[y, -x] u * lambda_x / lambda_y
+     rotated = _gyration(y, -x, u, c, dim=dim)
+     lam_from = _lambda_x(x, c, keepdim=True, dim=dim)
+     lam_to = _lambda_x(y, c, keepdim=True, dim=dim)
+     return rotated * lam_from / lam_to
+
+
+def parallel_transport0(y, v, *, c=1.0, dim=-1):
+ r"""
+ Special case parallel transport with starting point at zero that
+ can be computed more efficiently and numerically stable
+
+ Parameters
+ ----------
+ y : tensor
+ target point
+ v : tensor
+ vector to be transported
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ """
+ return _parallel_transport0(y, v, c, dim=dim)
+
+
+ def _parallel_transport0(y, v, c, dim: int = -1):
+     # Transport from the origin only rescales ``v`` by 1 - c * ||y||^2.
+     shrink = 1 - c * y.pow(2).sum(dim=dim, keepdim=True)
+     return v * shrink
+
+
+def egrad2rgrad(x, grad, *, c=1.0, dim=-1):
+ r"""
+ Translate Euclidean gradient to Riemannian gradient on tangent space of :math:`x`
+
+ .. math::
+
+ \nabla_x = \nabla^E_x / (\lambda_x^c)^2
+
+ Parameters
+ ----------
+ x : tensor
+ point on the Poincare ball
+ grad : tensor
+ Euclidean gradient for :math:`x`
+ c : float|tensor
+ ball negative curvature
+ dim : int
+ reduction dimension for operations
+
+ Returns
+ -------
+ tensor
+ Riemannian gradient :math:`u\in T_x\mathbb{D}_c^n`
+ """
+ return _egrad2rgrad(x, grad, c, dim=dim)
+
+
+ def _egrad2rgrad(x, grad, c, dim: int = -1):
+     # Riemannian gradient: Euclidean gradient divided by (lambda_x)^2.
+     conformal = _lambda_x(x, c, keepdim=True, dim=dim)
+     return grad / conformal ** 2
diff --git a/geoopt/manifolds/sphere.py b/geoopt/manifolds/sphere.py
index 5c3741d1..2795d37e 100644
--- a/geoopt/manifolds/sphere.py
+++ b/geoopt/manifolds/sphere.py
@@ -38,14 +38,14 @@ def _check_point_on_manifold(self, x, atol=1e-5, rtol=1e-5):
return True, None
def _check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5):
- inner = self._inner(None, x, u)
+ inner = self._inner(None, x, u, keepdim=True)
ok = torch.allclose(inner, inner.new((1,)).fill_(0), atol=atol, rtol=rtol)
if not ok:
return False, "` != 0` with atol={}, rtol={}".format(atol, rtol)
return True, None
- def _inner(self, x, u, v):
- return (u * v).sum(-1)
+ def _inner(self, x, u, v, keepdim):
+ return (u * v).sum(-1, keepdim=keepdim)
def _projx(self, x):
return x / x.norm(dim=-1, keepdim=True)
@@ -88,13 +88,13 @@ def _expmap_transp(self, x, v, *more, u, t):
def _logmap(self, x, y):
u = self._proju(x, y - x)
- dist = self._dist(x, y).unsqueeze(-1)
+ dist = self._dist(x, y, keepdim=True)
# If the two points are "far apart", correct the norm.
cond = dist.gt(1e-6)
return torch.where(cond, u * dist / u.norm(dim=-1, keepdim=True), u)
- def _dist(self, x, y):
- inner = self._inner(None, x, y).clamp(-1, 1)
+ def _dist(self, x, y, keepdim):
+ inner = self._inner(None, x, y, keepdim=keepdim).clamp(-1, 1)
return torch.acos(inner)
diff --git a/geoopt/manifolds/stiefel.py b/geoopt/manifolds/stiefel.py
index 6bfa5c64..a6100ba6 100644
--- a/geoopt/manifolds/stiefel.py
+++ b/geoopt/manifolds/stiefel.py
@@ -89,7 +89,7 @@ class CanonicalStiefel(Stiefel):
name = "Stiefel(canonical)"
reversible = True
- def _inner(self, x, u, v):
+ def _inner(self, x, u, v, keepdim):
# _x = tr(u^T(I-1/2xx^T)v)
# = tr(u^T(v-1/2xx^Tv))
# = tr(u^Tv-1/2u^Txx^Tv)
@@ -102,7 +102,9 @@ def _inner(self, x, u, v):
v = u
else:
xtv = x.transpose(-1, -2) @ v
- return (u * v).sum([-1, -2]) - 0.5 * (xtv * xtu).sum([-1, -2])
+ return (u * v).sum([-1, -2], keepdim=keepdim) - 0.5 * (xtv * xtu).sum(
+ [-1, -2], keepdim=keepdim
+ )
# we do faster on inner without autofill
_inner_autofill = False
@@ -112,6 +114,7 @@ def _transp_follow_one(self, x, v, *, u, t):
rhs = v + t / 2 * a @ v
lhs = -t / 2 * a
lhs[..., torch.arange(a.shape[-2]), torch.arange(x.shape[-2])] += 1
+ # TODO: torch.gesv -> torch.solve after pytorch release
qv, _ = torch.gesv(rhs, lhs)
return qv
@@ -182,8 +185,8 @@ def _retr_transp(self, x, v, *more, u, t):
else:
return y, vs
- def _inner(self, x, u, v):
- return (u * v).sum([-1, -2])
+ def _inner(self, x, u, v, keepdim):
+ return (u * v).sum([-1, -2], keepdim=keepdim)
def _retr(self, x, u, t):
q, r = linalg.batch_linalg.qr(x + u * t)
diff --git a/geoopt/optim/radam.py b/geoopt/optim/radam.py
index b669d6ad..b6a6d98a 100644
--- a/geoopt/optim/radam.py
+++ b/geoopt/optim/radam.py
@@ -83,17 +83,10 @@ def step(self, closure=None):
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
- inner_prod_shape = p.shape
- if manifold.ndim > 0:
- inner_prod_shape = inner_prod_shape[: -manifold.ndim]
- state["exp_avg_sq"] = torch.zeros(
- inner_prod_shape, dtype=p.dtype, device=p.device
- )
+ state["exp_avg_sq"] = torch.zeros_like(p)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
- state["max_exp_avg_sq"] = torch.zeros(
- inner_prod_shape, dtype=p.dtype, device=p.device
- )
+ state["max_exp_avg_sq"] = torch.zeros_like(p)
# this is assumed to be already transported
if "traced_step" not in state:
@@ -174,7 +167,9 @@ def perform_step(
grad.add_(weight_decay, point)
grad = manifold.egrad2rgrad(point, grad)
exp_avg.mul_(betas[0]).add_(1 - betas[0], grad)
- exp_avg_sq.mul_(betas[1]).add_(1 - betas[1], manifold.inner(point, grad))
+ exp_avg_sq.mul_(betas[1]).add_(
+ 1 - betas[1], manifold.inner(point, grad, keepdim=True)
+ )
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
@@ -182,7 +177,6 @@ def perform_step(
denom = max_exp_avg_sq.sqrt().add_(eps)
else:
denom = exp_avg_sq.sqrt().add_(eps)
- denom = manifold.broadcast_scalar(denom)
step.add_(1)
bias_correction1 = 1 - betas[0] ** step.type_as(betas)
bias_correction2 = 1 - betas[1] ** step.type_as(betas)
diff --git a/geoopt/optim/tracing.py b/geoopt/optim/tracing.py
index 842520b1..7b014b23 100644
--- a/geoopt/optim/tracing.py
+++ b/geoopt/optim/tracing.py
@@ -33,5 +33,6 @@ def create_traced_update(step, manifold, point, *buffers, **kwargs):
def partial(*args):
step(manifold, *args, **kwargs)
return args
+
return partial
# return torch.jit.trace(partial, (point, grad, lr) + buffers)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5d33904c..ee440468 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,3 +7,4 @@ pymanopt
twine
wheel
sphinx
+seaborn
diff --git a/setup.py b/setup.py
index fe8f2fc0..f1a55f78 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
DESCRIPTION = """Unofficial implementation for “Riemannian Adaptive Optimization Methods” ICLR2019 and more"""
PROJECT_ROOT = os.path.dirname(os.path.realpath(__file__))
-with open(os.path.join(PROJECT_ROOT, 'README.rst'), encoding='utf-8') as buff:
+with open(os.path.join(PROJECT_ROOT, "README.rst"), encoding="utf-8") as buff:
LONG_DESCRIPTION = buff.read()
@@ -44,5 +44,5 @@ def get_version(*path):
url="https://github.com/geoopt/geoopt",
python_requires=">=3.6.0",
license="Apache License, Version 2.0",
- classifiers=classifiers
+ classifiers=classifiers,
)
diff --git a/tests/test_adam.py b/tests/test_adam.py
index 318b2879..96a3159e 100644
--- a/tests/test_adam.py
+++ b/tests/test_adam.py
@@ -30,3 +30,24 @@ def closure():
np.testing.assert_allclose(X.data, Xstar, atol=1e-5, rtol=1e-5)
optim.load_state_dict(optim.state_dict())
optim.step(closure)
+
+
+def test_adam_poincare():
+ torch.manual_seed(44)
+ ideal = torch.tensor([0.5, 0.5])
+ start = torch.randn(2) / 2
+ start = geoopt.manifolds.poincare.math.expmap0(start, c=1.0)
+ start = geoopt.ManifoldParameter(start, manifold=geoopt.PoincareBall())
+
+ def closure():
+ optim.zero_grad()
+ loss = geoopt.manifolds.poincare.math.dist(start, ideal) ** 2
+ loss.backward()
+ return loss.item()
+
+ optim = geoopt.optim.RiemannianAdam([start], lr=1e-2)
+
+ for _ in range(2000):
+ optim.step(closure)
+
+ np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5)
diff --git a/tests/test_manifold.py b/tests/test_manifold.py
index a10c8d06..4d3b6ec8 100644
--- a/tests/test_manifold.py
+++ b/tests/test_manifold.py
@@ -26,6 +26,7 @@ def t(request):
# match implementation of pymanopt for stiefel
functools.partial(geoopt.manifolds.Stiefel, canonical=False),
functools.partial(geoopt.manifolds.Stiefel, canonical=True),
+ geoopt.manifolds.PoincareBall,
geoopt.manifolds.Euclidean,
geoopt.manifolds.Sphere,
functools.partial(
@@ -41,7 +42,7 @@ def t(request):
def manifold(request, retraction_order):
man = request.param()
try:
- return man.set_default_order(retraction_order)
+ return man.set_default_order(retraction_order).double()
except ValueError:
pytest.skip("not supported retraction order for {}".format(man))
@@ -63,6 +64,7 @@ def manifold(request, retraction_order):
# shapes to verify unary element implementation
shapes = {
+ geoopt.manifolds.PoincareBall: (3,),
geoopt.manifolds.EuclideanStiefel: (10, 5),
geoopt.manifolds.CanonicalStiefel: (10, 5),
geoopt.manifolds.Euclidean: (1,),
@@ -79,11 +81,17 @@ def manifold(request, retraction_order):
@pytest.fixture()
def unary_case(manifold):
shape = shapes[type(manifold)]
- manopt_manifold = mannopt[type(manifold)](*shape)
np.random.seed(42)
- rand = manopt_manifold.rand().astype("float64")
- x = geoopt.ManifoldTensor(torch.from_numpy(rand), manifold=manifold)
torch.manual_seed(43)
+ if type(manifold) in mannopt:
+ manopt_manifold = mannopt[type(manifold)](*shape)
+ rand = manopt_manifold.rand().astype("float64")
+ x = geoopt.ManifoldTensor(torch.from_numpy(rand), manifold=manifold)
+ else:
+ manopt_manifold = None
+ x = geoopt.ManifoldTensor(
+ torch.randn(shape, dtype=torch.float64) * 0.1, manifold=manifold
+ )
ex = geoopt.ManifoldTensor(torch.randn_like(x), manifold=manifold)
v = x.proju(torch.randn_like(x))
ev = torch.randn_like(x)
@@ -105,6 +113,8 @@ def test_projection_via_assert(unary_case):
def test_vector_projection(unary_case):
if isinstance(unary_case.manifold, geoopt.manifolds.CanonicalStiefel):
pytest.skip("pymanopt uses euclidean Stiefel")
+ elif unary_case.manopt_manifold is None:
+ pytest.skip("pymanopt does not have {}".format(unary_case.manifold))
x = unary_case.x
ev = unary_case.ev
@@ -124,6 +134,8 @@ def test_vector_projection_via_assert(unary_case):
def test_retraction(unary_case, retraction_order, t):
+ if unary_case.manopt_manifold is None:
+ pytest.skip("pymanopt does not have {}".format(unary_case.manifold))
if isinstance(unary_case.manifold, geoopt.manifolds.CanonicalStiefel):
pytest.skip("pymanopt uses euclidean Stiefel")
x = unary_case.x
@@ -139,6 +151,8 @@ def test_retraction(unary_case, retraction_order, t):
def test_transport(unary_case, t):
+ if unary_case.manopt_manifold is None:
+ pytest.skip("pymanopt does not have {}".format(unary_case.manifold))
if isinstance(unary_case.manifold, geoopt.manifolds.CanonicalStiefel):
pytest.skip("pymanopt uses euclidean Stiefel")
x = unary_case.x
@@ -261,24 +275,28 @@ def test_broadcast_retr_transp_many(unary_case, t):
def test_reversibility(unary_case, t):
- torch.manual_seed(43)
- X = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
- U = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
- X = unary_case.manifold.projx(X)
- U = unary_case.manifold.proju(X, U)
- Z, Q = unary_case.manifold.retr_transp(X, U, u=U, t=t)
- X1, U1 = unary_case.manifold.retr_transp(Z, Q, u=Q, t=-t)
if unary_case.manifold.reversible:
+ torch.manual_seed(43)
+ X = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
+ U = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
+ X = unary_case.manifold.projx(X)
+ U = unary_case.manifold.proju(X, U)
+ Z, Q = unary_case.manifold.retr_transp(X, U, u=U, t=t)
+ X1, U1 = unary_case.manifold.retr_transp(Z, Q, u=Q, t=-t)
+
np.testing.assert_allclose(X1, X, atol=1e-5)
np.testing.assert_allclose(U1, U, atol=1e-5)
else:
- assert not np.allclose(X1, X, atol=1e-5)
- assert not np.allclose(U1, U, atol=1e-5)
+ pytest.skip("The manifold {} is not supposed to be checked".format(unary_case.manifold))
def test_dist(unary_case):
if type(unary_case.manifold)._dist is geoopt.manifolds.base.not_implemented:
- pytest.skip("logmap is not implemented for {}".format(unary_case.manifold))
+ pytest.skip("dist is not implemented for {}".format(unary_case.manifold))
+ if unary_case.manopt_manifold is None:
+ pytest.skip(
+ "dist is not implemented for pymanopt {}".format(unary_case.manifold)
+ )
torch.manual_seed(43)
x = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
y = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype)
@@ -295,12 +313,16 @@ def test_logmap(unary_case, t):
x = unary_case.x
v = unary_case.v
- y = unary_case.manopt_manifold.exp(x.numpy(), v.numpy() * t)
- vman = unary_case.manopt_manifold.log(x.numpy(), y)
- vhat = unary_case.manifold.logmap(x, torch.as_tensor(y))
- np.testing.assert_allclose(vhat, vman)
+ if unary_case.manopt_manifold is not None:
+ y = unary_case.manopt_manifold.exp(x.numpy(), v.numpy() * t)
+ vman = unary_case.manopt_manifold.log(x.numpy(), y)
+ vhat = unary_case.manifold.logmap(x, torch.as_tensor(y))
+ np.testing.assert_allclose(vhat, vman, atol=1e-7)
+ else:
+ y = unary_case.manifold.expmap(x, v)
+ vhat = unary_case.manifold.logmap(x, torch.as_tensor(y))
ey = unary_case.manifold.expmap(x, vhat)
- np.testing.assert_allclose(y, ey)
+ np.testing.assert_allclose(y, ey, atol=1e-7)
def test_logmap_many(unary_case, t):
@@ -317,4 +339,4 @@ def test_logmap_many(unary_case, t):
Uh = unary_case.manifold.logmap(X, Y)
Yh = unary_case.manifold.expmap(X, Uh)
- np.testing.assert_allclose(Yh, Y)
+ np.testing.assert_allclose(Yh, Y, atol=1e-7)
diff --git a/tests/test_poincare_math.py b/tests/test_poincare_math.py
new file mode 100644
index 00000000..31780120
--- /dev/null
+++ b/tests/test_poincare_math.py
@@ -0,0 +1,371 @@
+"""
+Test ideas are mostly taken from https://github.com/dalab/hyperbolic_nn/blob/master/util.py, with some changes
+"""
+import torch
+import random
+import numpy as np
+import pytest
+from geoopt.manifolds import poincare
+
+
+@pytest.fixture(scope="function", autouse=True, params=range(30, 40))
+def seed(request):
+ seed = request.param
+ torch.manual_seed(seed)
+ random.seed(seed)
+ return seed
+
+
+@pytest.fixture(scope="function", params=[torch.float64, torch.float32])
+def dtype(request):
+ return request.param
+
+
+@pytest.fixture
+def c(seed, dtype):
+ # test broadcasted and non broadcasted versions
+ if seed == 30:
+ c = torch.tensor(0.0).to(dtype)
+ elif seed == 35:
+ c = torch.zeros(100, 1, dtype=dtype)
+ elif seed > 35:
+ c = torch.rand(100, 1, dtype=dtype)
+ else:
+ c = torch.tensor(random.random()).to(dtype)
+ return c + 1e-10
+
+
+@pytest.fixture
+def a(seed, c):
+ if seed in {30, 35}:
+ a = torch.randn(100, 10, dtype=c.dtype)
+ elif seed > 35:
+ # do not check numerically unstable regions
+ # I've manually observed small differences there
+ a = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1)
+ a /= a.norm(dim=-1, keepdim=True) * 1.3
+ a *= (torch.rand_like(c) * c) ** 0.5
+ else:
+ a = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1)
+ a /= a.norm(dim=-1, keepdim=True) * 1.3
+ a *= random.uniform(0, c) ** 0.5
+ return poincare.math.project(a, c=c)
+
+
+@pytest.fixture
+def b(seed, c):
+ if seed in {30, 35}:
+ b = torch.randn(100, 10, dtype=c.dtype)
+ elif seed > 35:
+ b = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1)
+ b /= b.norm(dim=-1, keepdim=True) * 1.3
+ b *= (torch.rand_like(c) * c) ** 0.5
+ else:
+ b = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1)
+ b /= b.norm(dim=-1, keepdim=True) * 1.3
+ b *= random.uniform(0, c) ** 0.5
+ return poincare.math.project(b, c=c)
+
+
+def test_mobius_addition_left_cancelation(a, b, c):
+ res = poincare.math.mobius_add(-a, poincare.math.mobius_add(a, b, c=c), c=c)
+ tolerance = {torch.float32: dict(atol=1e-6, rtol=1e-6), torch.float64: dict()}
+ np.testing.assert_allclose(res, b, **tolerance[c.dtype])
+
+
+def test_mobius_addition_zero_a(b, c):
+ a = torch.zeros(100, 10, dtype=c.dtype)
+ res = poincare.math.mobius_add(a, b, c=c)
+ np.testing.assert_allclose(res, b)
+
+
+def test_mobius_addition_zero_b(a, c):
+ b = torch.zeros(100, 10, dtype=c.dtype)
+ res = poincare.math.mobius_add(a, b, c=c)
+ np.testing.assert_allclose(res, a)
+
+
+def test_mobius_addition_negative_cancellation(a, c):
+ res = poincare.math.mobius_add(a, -a, c=c)
+ tolerance = {
+ torch.float32: dict(atol=1e-7, rtol=1e-6),
+ torch.float64: dict(atol=1e-10),
+ }
+ np.testing.assert_allclose(res, torch.zeros_like(res), **tolerance[c.dtype])
+
+
+def test_mobius_negative_addition(a, b, c):
+ res = poincare.math.mobius_add(-b, -a, c=c)
+ res1 = -poincare.math.mobius_add(b, a, c=c)
+ tolerance = {
+ torch.float32: dict(atol=1e-7, rtol=1e-6),
+ torch.float64: dict(atol=1e-10),
+ }
+ np.testing.assert_allclose(res, res1, **tolerance[c.dtype])
+
+
+@pytest.mark.parametrize("n", list(range(5)))
+def test_n_additions_via_scalar_multiplication(n, a, c):
+ y = torch.zeros_like(a)
+ for _ in range(n):
+ y = poincare.math.mobius_add(a, y, c=c)
+ ny = poincare.math.mobius_scalar_mul(n, a, c=c)
+ tolerance = {
+ torch.float32: dict(atol=1e-7, rtol=1e-6),
+ torch.float64: dict(atol=1e-10),
+ }
+ np.testing.assert_allclose(y, ny, **tolerance[c.dtype])
+
+
+@pytest.fixture
+def r1(seed, dtype):
+ if seed % 3 == 0:
+ return random.uniform(-1, 1)
+ else:
+ return torch.rand(100, 1, dtype=dtype) * 2 - 1
+
+
+@pytest.fixture
+def r2(seed, dtype):
+ if seed % 3 == 1:
+ return random.uniform(-1, 1)
+ else:
+ return torch.rand(100, 1, dtype=dtype) * 2 - 1
+
+
+def test_scalar_multiplication_distributive(a, c, r1, r2):
+ res = poincare.math.mobius_scalar_mul(r1 + r2, a, c=c)
+ res1 = poincare.math.mobius_add(
+ poincare.math.mobius_scalar_mul(r1, a, c=c),
+ poincare.math.mobius_scalar_mul(r2, a, c=c),
+ c=c,
+ )
+ res2 = poincare.math.mobius_add(
+ poincare.math.mobius_scalar_mul(r2, a, c=c),
+ poincare.math.mobius_scalar_mul(r1, a, c=c),
+ c=c,
+ )
+ tolerance = {
+ torch.float32: dict(atol=1e-6, rtol=1e-7),
+ torch.float64: dict(atol=1e-7, rtol=1e-10),
+ }
+ np.testing.assert_allclose(res1, res, **tolerance[c.dtype])
+ np.testing.assert_allclose(res2, res, **tolerance[c.dtype])
+
+
+def test_scalar_multiplication_associative(a, c, r1, r2):
+ res = poincare.math.mobius_scalar_mul(r1 * r2, a, c=c)
+ res1 = poincare.math.mobius_scalar_mul(
+ r1, poincare.math.mobius_scalar_mul(r2, a, c=c), c=c
+ )
+ res2 = poincare.math.mobius_scalar_mul(
+ r2, poincare.math.mobius_scalar_mul(r1, a, c=c), c=c
+ )
+ tolerance = {
+ torch.float32: dict(atol=1e-7, rtol=1e-6), # worked with rtol=1e-7 locally
+ torch.float64: dict(atol=1e-7, rtol=1e-10),
+ }
+ np.testing.assert_allclose(res1, res, **tolerance[c.dtype])
+ np.testing.assert_allclose(res2, res, **tolerance[c.dtype])
+
+
+def test_scaling_property(a, c, r1):
+ x1 = a / a.norm(dim=-1, keepdim=True)
+ ra = poincare.math.mobius_scalar_mul(r1, a, c=c)
+ x2 = poincare.math.mobius_scalar_mul(abs(r1), a, c=c) / ra.norm(
+ dim=-1, keepdim=True
+ )
+ tolerance = {
+ torch.float32: dict(rtol=1e-5, atol=1e-6),
+ torch.float64: dict(atol=1e-10),
+ }
+ np.testing.assert_allclose(x1, x2, **tolerance[c.dtype])
+
+
+def test_geodesic_borders(a, b, c):
+ geo0 = poincare.math.geodesic(0.0, a, b, c=c)
+ geo1 = poincare.math.geodesic(1.0, a, b, c=c)
+ tolerance = {
+ torch.float32: dict(rtol=1e-5, atol=1e-6),
+ torch.float64: dict(atol=1e-10),
+ }
+ np.testing.assert_allclose(geo0, a, **tolerance[c.dtype])
+ np.testing.assert_allclose(geo1, b, **tolerance[c.dtype])
+
+
+def test_geodesic_segment_length_property(a, b, c):
+ extra_dims = len(a.shape)
+ segments = 12
+ t = torch.linspace(0, 1, segments + 1, dtype=c.dtype).view(
+ (segments + 1,) + (1,) * extra_dims
+ )
+ gamma_ab_t = poincare.math.geodesic(t, a, b, c=c)
+ gamma_ab_t0 = gamma_ab_t[:-1]
+ gamma_ab_t1 = gamma_ab_t[1:]
+ dist_ab_t0mt1 = poincare.math.dist(gamma_ab_t0, gamma_ab_t1, c=c, keepdim=True)
+ speed = (
+ poincare.math.dist(a, b, c=c, keepdim=True)
+ .unsqueeze(0)
+ .expand_as(dist_ab_t0mt1)
+ )
+ # we have exactly 12 line segments
+ tolerance = {torch.float32: dict(rtol=1e-5), torch.float64: dict(atol=1e-10)}
+ np.testing.assert_allclose(dist_ab_t0mt1, speed / segments, **tolerance[c.dtype])
+
+
+def test_geodesic_segement_unit_property(a, b, c):
+ extra_dims = len(a.shape)
+ segments = 12
+ t = torch.linspace(0, 1, segments + 1, dtype=c.dtype).view(
+ (segments + 1,) + (1,) * extra_dims
+ )
+ gamma_ab_t = poincare.math.geodesic_unit(t, a, b, c=c)
+ gamma_ab_t0 = gamma_ab_t[:1]
+ gamma_ab_t1 = gamma_ab_t
+ dist_ab_t0mt1 = poincare.math.dist(gamma_ab_t0, gamma_ab_t1, c=c, keepdim=True)
+ true_distance_travelled = t.expand_as(dist_ab_t0mt1)
+ # we have exactly 12 line segments
+ tolerance = {
+ torch.float32: dict(atol=1e-6, rtol=1e-5),
+ torch.float64: dict(atol=1e-10),
+ }
+ np.testing.assert_allclose(
+ dist_ab_t0mt1, true_distance_travelled, **tolerance[c.dtype]
+ )
+
+
+def test_expmap_logmap(a, b, c):
+ # this test appears to be numerically unstable when a and b lie on opposite sides
+ bh = poincare.math.expmap(x=a, u=poincare.math.logmap(a, b, c=c), c=c)
+ tolerance = {torch.float32: dict(rtol=1e-5, atol=1e-6), torch.float64: dict()}
+ np.testing.assert_allclose(bh, b, **tolerance[c.dtype])
+
+
+def test_expmap0_logmap0(a, c):
+ # this test appears to be numerically unstable when a and b lie on opposite sides
+ v = poincare.math.logmap0(a, c=c)
+ norm = poincare.math.norm(torch.zeros_like(v), v, c=c, keepdim=True)
+ dist = poincare.math.dist0(a, c=c, keepdim=True)
+ bh = poincare.math.expmap0(v, c=c)
+ tolerance = {torch.float32: dict(rtol=1e-6), torch.float64: dict()}
+ np.testing.assert_allclose(bh, a, **tolerance[c.dtype])
+ np.testing.assert_allclose(norm, dist, **tolerance[c.dtype])
+
+
+def test_matvec_zeros(a, c):
+ mat = a.new_zeros(3, a.shape[-1])
+ z = poincare.math.mobius_matvec(mat, a, c=c)
+ np.testing.assert_allclose(z, 0.0)
+
+
+def test_matvec_via_equiv_fn_apply(a, c):
+ mat = a.new(3, a.shape[-1]).normal_()
+ y = poincare.math.mobius_fn_apply(lambda x: x @ mat.transpose(-1, -2), a, c=c)
+ y1 = poincare.math.mobius_matvec(mat, a, c=c)
+ tolerance = {torch.float32: dict(atol=1e-5), torch.float64: dict()}
+ np.testing.assert_allclose(y, y1, **tolerance[c.dtype])
+
+
+def test_mobiusify(a, c):
+ mat = a.new(3, a.shape[-1]).normal_()
+
+ @poincare.math.mobiusify
+ def matvec(x):
+ return x @ mat.transpose(-1, -2)
+
+ y = matvec(a, c=c)
+ y1 = poincare.math.mobius_matvec(mat, a, c=c)
+ tolerance = {torch.float32: dict(atol=1e-5), torch.float64: dict()}
+ np.testing.assert_allclose(y, y1, **tolerance[c.dtype])
+
+
+def test_matvec_chain_via_equiv_fn_apply(a, c):
+ mat1 = a.new(a.shape[-1], a.shape[-1]).normal_()
+ mat2 = a.new(a.shape[-1], a.shape[-1]).normal_()
+ y = poincare.math.mobius_fn_apply_chain(
+ a,
+ lambda x: x @ mat1.transpose(-1, -2),
+ lambda x: x @ mat2.transpose(-1, -2),
+ c=c,
+ )
+ y1 = poincare.math.mobius_matvec(mat1, a, c=c)
+ y1 = poincare.math.mobius_matvec(mat2, y1, c=c)
+ np.testing.assert_allclose(y, y1, atol=1e-5)
+
+
+def test_parallel_transport0_preserves_inner_products(a, c):
+ # tangent vectors at the center (origin)
+ v_0 = torch.rand_like(a) + 1e-5
+ u_0 = torch.rand_like(a) + 1e-5
+ zero = torch.zeros_like(a)
+ v_a = poincare.math.parallel_transport0(a, v_0, c=c)
+ u_a = poincare.math.parallel_transport0(a, u_0, c=c)
+ # compute inner products
+ vu_0 = poincare.math.inner(zero, v_0, u_0, c=c, keepdim=True)
+ vu_a = poincare.math.inner(a, v_a, u_a, c=c, keepdim=True)
+ np.testing.assert_allclose(vu_a, vu_0, atol=1e-6, rtol=1e-6)
+
+
+def test_parallel_transport0_is_same_as_usual(a, c):
+ # tangent vector at the center (origin)
+ v_0 = torch.rand_like(a) + 1e-5
+ zero = torch.zeros_like(a)
+ v_a = poincare.math.parallel_transport0(a, v_0, c=c)
+ v_a1 = poincare.math.parallel_transport(zero, a, v_0, c=c)
+ # both transports must agree
+ np.testing.assert_allclose(v_a, v_a1, atol=1e-6, rtol=1e-6)
+
+
+def test_parallel_transport_a_b(a, b, c):
+ # tangent vectors at a
+ v_0 = torch.rand_like(a)
+ u_0 = torch.rand_like(a)
+ v_1 = poincare.math.parallel_transport(a, b, v_0, c=c)
+ u_1 = poincare.math.parallel_transport(a, b, u_0, c=c)
+ # compute inner products
+ vu_1 = poincare.math.inner(b, v_1, u_1, c=c, keepdim=True)
+ vu_0 = poincare.math.inner(a, v_0, u_0, c=c, keepdim=True)
+ np.testing.assert_allclose(vu_0, vu_1, atol=1e-6, rtol=1e-6)
+
+
+def test_add_infinity_and_beyond(a, b, c):
+ infty = b * 10000000
+ infty = poincare.math.clip_tangent(a, infty, c=c)
+ for i in range(100):
+ z = poincare.math.expmap(a, infty, c=c)
+ z = poincare.math.project(z, c=c)
+ z = poincare.math.mobius_scalar_mul(1000.0, z, c=c)
+ z = poincare.math.project(z, c=c)
+ infty = poincare.math.parallel_transport(a, z, infty, c=c)
+ assert np.isfinite(z).all(), (i, z)
+ assert np.isfinite(infty).all(), (i, infty)
+ a = z
+ z = poincare.math.expmap(a, -infty, c=c)
+ # they just need to be very far apart; an exact answer is not expected
+ tolerance = {
+ torch.float32: dict(rtol=3e-1, atol=2e-1),
+ torch.float64: dict(rtol=1e-1, atol=1e-3),
+ }
+ np.testing.assert_allclose(z, -a, **tolerance[c.dtype])
+
+
+def test_mobius_coadd(a, b, c):
+ # (a \boxplus_c b) \ominus_c b = a
+ ah = poincare.math.mobius_sub(poincare.math.mobius_coadd(a, b, c=c), b, c=c)
+ np.testing.assert_allclose(ah, a, atol=1e-5)
+
+
+def test_mobius_cosub(a, b, c):
+ # (a \oplus_c b) \boxminus b = a
+ ah = poincare.math.mobius_cosub(poincare.math.mobius_add(a, b, c=c), b, c=c)
+ np.testing.assert_allclose(ah, a, atol=1e-5)
+
+
+def test_distance2plane(a, c):
+ v = torch.rand_like(a)
+ vr = v / poincare.math.norm(a, v, c=c, keepdim=True)
+ z = poincare.math.expmap(a, vr, c=c)
+ dist1 = poincare.math.dist(a, z, c=c)
+ dist = poincare.math.dist2plane(z, a, vr, c=c)
+
+ np.testing.assert_allclose(dist, dist1, atol=1e-5, rtol=1e-5)
diff --git a/tests/test_rhmc.py b/tests/test_rhmc.py
index 3fb8bdd1..2ce01c8f 100644
--- a/tests/test_rhmc.py
+++ b/tests/test_rhmc.py
@@ -6,6 +6,15 @@
import geoopt.samplers.rhmc
+@pytest.fixture(autouse=True)
+def withdtype():
+ torch.set_default_dtype(torch.float64)
+ try:
+ yield
+ finally:
+ torch.set_default_dtype(torch.float32)
+
+
@pytest.mark.parametrize(
"params",
[
@@ -20,8 +29,6 @@
],
)
def test_leapfrog_reversibility(params):
- torch.set_default_dtype(torch.float64)
-
class NormalDist(torch.nn.Module):
def __init__(self, mu, sigma):
super().__init__()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index e335ad82..da936667 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -106,3 +106,13 @@ def test_manifold_is_submodule():
container = torch.nn.ModuleDict({"sphere": sub_sphere})
container.to(torch.float64)
assert sub_sphere._projector.dtype == torch.float64
+
+
+def test_manifold_is_submodule_poincare():
+ c = torch.tensor(1.0)
+ ball = geoopt.manifolds.PoincareBall(c)
+ assert ball.c.dtype == torch.float32
+ ball.to(torch.float64)
+ container = torch.nn.ModuleDict({"ball": ball})
+ container.to(torch.float64)
+ assert ball.c.dtype == torch.float64