diff --git a/docs/conf.py b/docs/conf.py index e1209002..cb2e0ae0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,34 +36,35 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinx.ext.intersphinx', - 'sphinx.ext.napoleon', + "matplotlib.sphinxext.plot_directive", + "sphinx.ext.autodoc", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. # # source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'geoopt' -copyright = u'2018, Max Kochurov' -author = u'Max Kochurov' +project = u"geoopt" +copyright = u"2018, Max Kochurov" +author = u"Max Kochurov" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -93,7 +94,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -115,7 +116,7 @@ # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] @@ -132,7 +133,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -166,7 +167,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied @@ -246,34 +247,30 @@ # html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. -htmlhelp_basename = 'geooptdoc' +htmlhelp_basename = "geooptdoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'geoopt.tex', u'geoopt Documentation', - u'Max Kochurov', 'manual'), + (master_doc, "geoopt.tex", u"geoopt Documentation", u"Max Kochurov", "manual") ] # The name of an image file (relative to this directory) to place at the top of @@ -313,10 +310,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'geoopt', u'geoopt Documentation', - [author], 1) -] +man_pages = [(master_doc, "geoopt", u"geoopt Documentation", [author], 1)] # If true, show URL addresses after external links. # @@ -329,9 +323,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'geoopt', u'geoopt Documentation', - author, 'geoopt', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "geoopt", + u"geoopt Documentation", + author, + "geoopt", + "One line description of project.", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. @@ -352,7 +352,7 @@ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'numpy': ('http://docs.scipy.org/doc/numpy/', None), - 'python': ('https://docs.python.org/', None), - 'torch': ('https://pytorch.org/docs/master/', None), + "numpy": ("http://docs.scipy.org/doc/numpy/", None), + "python": ("https://docs.python.org/", None), + "torch": ("https://pytorch.org/docs/master/", None), } diff --git a/docs/devguide.rst b/docs/devguide.rst index 040469bb..b3674acd 100644 --- a/docs/devguide.rst +++ b/docs/devguide.rst @@ -1,5 +1,5 @@ -Extending ``geoopt`` -==================== +Developer Guide +=============== Base Manifold ------------- @@ -9,7 +9,7 @@ The common base class for all manifolds is :class:`geoopt.manifolds.base.Manifol .. autoclass:: geoopt.manifolds.base.Manifold :private-members: :members: - + :noindex: Metaclass --------- diff --git a/docs/extended.rst b/docs/extended.rst new file mode 100644 index 00000000..e2d1b35d --- /dev/null +++ b/docs/extended.rst @@ -0,0 +1,7 @@ +Extended Guide +============== + +.. toctree:: + :maxdepth: 1 + + extended/poincare diff --git a/docs/extended/poincare.rst b/docs/extended/poincare.rst new file mode 100644 index 00000000..c411803b --- /dev/null +++ b/docs/extended/poincare.rst @@ -0,0 +1,117 @@ +Poincare Ball model +=================== + +Poincare ball model is a compact representation of hyperbolic space. +To have a nice introduction into this model we should start from +simple concepts, putting them all together to build a more complete picture. + +Hyperbolic spaces +----------------- + +Hyperbolic space is a constant negative curvature Riemannian manifold. +A very simple example of Riemannian manifold with constant, but positive curvature is sphere. + +An (N+1)-dimensional hyperboloid spans the manifold that can be embedded into N-dimensional space via projections. + +.. figure:: ../plots/extended/poincare/hyperboloid_projection.png + :width: 300 + + img source `Wikipedia, Hyperboloid Model `_ + +Originally, the distance between points on the hyperboloid is defined as + +.. math:: + + d(x, y) = \operatorname{arccosh}(x, y) + +It is difficult to work in (N+1)-dimensional space and there is a range of useful embeddings +exist in literature + +Klein Model +~~~~~~~~~~~ + +.. figure:: ../plots/extended/poincare/klein_tiling.png + :width: 300 + + img source `Wikipedia, Klein Model `_ + + +Poincare Model +~~~~~~~~~~~~~~ + +.. figure:: ../plots/extended/poincare/poincare_lines.gif + :width: 300 + + img source `Bulatov, Poincare Model `_ + +Here we go. + +First of all we note, that Poincare ball is embedded in a Sphere of radius :math:`r=1/\sqrt{c}`, +where c is negative curvature. We also note, as :math:`c` goes to :math:`0`, we recover infinite radius ball. +We should expect this limiting behaviour recovers Euclidean geometry. + +To connect Euclidean space with its embedded manifold we need to get :math:`g_x`. +It is done via `conformal factor` :math:`\lambda^c_x`. + + +.. autofunction:: geoopt.manifolds.poincare.math.lambda_x + + +:math:`\lambda^c_x` connects Euclidean inner product with Riemannian one + +.. autofunction:: geoopt.manifolds.poincare.math.inner +.. autofunction:: geoopt.manifolds.poincare.math.norm +.. autofunction:: geoopt.manifolds.poincare.math.egrad2rgrad + +Math +---- +The good thing about Poincare ball is that it forms a Gyrogroup. Minimal definition of a Gyrogroup +assumes a binary operation :math:`*` defined that satisfies a set of properties. + +Left identity + For every element :math:`a\in G` there exist :math:`e\in G` such that :math:`e * a = a`. +Left Inverse + For every element :math:`a\in G` there exist :math:`b\in G` such that :math:`b * a = e` +Gyroassociativity + For any :math:`a,b,c\in G` there exist :math:`gyr[a, b]c\in G` such that :math:`a * (b * c)=(a * b) * gyr[a, b]c` +Gyroautomorphism + :math:`gyr[a, b]` is a magma automorphism in G +Left loop + :math:`gyr[a, b] = gyr[a * b, b]` + +As mentioned above, hyperbolic space forms a Gyrogroup equipped with + +.. autofunction:: geoopt.manifolds.poincare.math.mobius_add +.. autofunction:: geoopt.manifolds.poincare.math.gyration + +Using this math, it is possible to define another useful operations + +.. autofunction:: geoopt.manifolds.poincare.math.mobius_sub +.. autofunction:: geoopt.manifolds.poincare.math.mobius_scalar_mul +.. autofunction:: geoopt.manifolds.poincare.math.mobius_pointwise_mul +.. autofunction:: geoopt.manifolds.poincare.math.mobius_matvec +.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply +.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply_chain + +Manifold +-------- +Now we are ready to proceed with studying distances, geodesics, exponential maps and more + +.. autofunction:: geoopt.manifolds.poincare.math.dist +.. autofunction:: geoopt.manifolds.poincare.math.dist2plane +.. autofunction:: geoopt.manifolds.poincare.math.parallel_transport +.. autofunction:: geoopt.manifolds.poincare.math.geodesic +.. autofunction:: geoopt.manifolds.poincare.math.geodesic_unit +.. autofunction:: geoopt.manifolds.poincare.math.expmap +.. autofunction:: geoopt.manifolds.poincare.math.expmap0 +.. autofunction:: geoopt.manifolds.poincare.math.logmap +.. autofunction:: geoopt.manifolds.poincare.math.logmap0 + + +Stability +--------- +Numerical stability is a pain in this model. It is strongly recommended to work in ``float64``, +so expect adventures in ``float32`` (but this is not certain). + +.. autofunction:: geoopt.manifolds.poincare.math.project +.. autofunction:: geoopt.manifolds.poincare.math.clip_tangent diff --git a/docs/index.rst b/docs/index.rst index 9f62ee12..9ee6098a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,6 +17,7 @@ API optimizers tensors samplers + extended devguide Indices and tables diff --git a/docs/manifolds.rst b/docs/manifolds.rst index 97b65a0e..6dd3f579 100644 --- a/docs/manifolds.rst +++ b/docs/manifolds.rst @@ -4,13 +4,11 @@ Manifolds .. currentmodule:: geoopt.manifolds -All manifolds share same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.Euclidean` in the end of this file. +All manifolds share same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.Manifold` in the end of this file. .. automodule:: geoopt.manifolds - :members: Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection + :members: Euclidean, Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection, PoincareBall - -.. autoclass:: geoopt.manifolds.Euclidean +.. autoclass:: geoopt.manifolds.base.Manifold :members: - :inherited-members: diff --git a/docs/plots/extended/poincare/distance.py b/docs/plots/extended/poincare/distance.py new file mode 100644 index 00000000..33408101 --- /dev/null +++ b/docs/plots/extended/poincare/distance.py @@ -0,0 +1,27 @@ +import geoopt.manifolds.poincare.math as pmath +import torch +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +sns.set_style("white") +radius = 1 +coords = np.linspace(-radius, radius, 100) +x = torch.tensor([-0.75, 0]) +xx, yy = np.meshgrid(coords, coords) +dist2 = xx ** 2 + yy ** 2 +mask = dist2 <= radius ** 2 +grid = np.stack([xx, yy], axis=-1) +dists = pmath.dist(torch.from_numpy(grid).float(), x) +dists[(~mask).nonzero()] = np.nan +circle = plt.Circle((0, 0), 1, fill=False, color="b") +plt.gca().add_artist(circle) +plt.xlim(-1.1, 1.1) +plt.ylim(-1.1, 1.1) +plt.gca().set_aspect("equal") +plt.contourf( + grid[..., 0], grid[..., 1], dists.log().numpy(), levels=100, cmap="inferno" +) +plt.colorbar() +plt.title("log distance to ($-$0.75, 0)") +plt.show() diff --git a/docs/plots/extended/poincare/distance2plane.py b/docs/plots/extended/poincare/distance2plane.py new file mode 100644 index 00000000..66cbc8fd --- /dev/null +++ b/docs/plots/extended/poincare/distance2plane.py @@ -0,0 +1,35 @@ +import geoopt.manifolds.poincare.math as pmath +import torch +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib import rcParams + +rcParams["text.latex.preamble"] = r"\usepackage{amsmath}" +rcParams["text.usetex"] = True + +sns.set_style("white") +radius = 1 +coords = np.linspace(-radius, radius, 100) +x = torch.tensor([-0.75, 0]) +v = torch.tensor([0.1 / 3, -1 / 3]) +xx, yy = np.meshgrid(coords, coords) +dist2 = xx ** 2 + yy ** 2 +mask = dist2 <= radius ** 2 +grid = np.stack([xx, yy], axis=-1) +dists = pmath.dist2plane(torch.from_numpy(grid).float(), x, v) +dists[(~mask).nonzero()] = np.nan +circle = plt.Circle((0, 0), 1, fill=False, color="b") +plt.gca().add_artist(circle) +plt.xlim(-1.1, 1.1) +plt.ylim(-1.1, 1.1) + +plt.gca().set_aspect("equal") +plt.contourf( + grid[..., 0], grid[..., 1], dists.log().numpy(), levels=100, cmap="inferno" +) +plt.colorbar() +plt.scatter(*x, color="g") +plt.arrow(*x, *v, color="g", width=0.01) +plt.title(r"log distance to $\tilde{H}_{a, p}$") +plt.show() diff --git a/docs/plots/extended/poincare/gyrovector_parallel_transport.py b/docs/plots/extended/poincare/gyrovector_parallel_transport.py new file mode 100644 index 00000000..894f6c94 --- /dev/null +++ b/docs/plots/extended/poincare/gyrovector_parallel_transport.py @@ -0,0 +1,52 @@ +import geoopt.manifolds.poincare.math as pmath +import torch +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib import rcParams + +rcParams["text.latex.preamble"] = r"\usepackage{amsmath}" +rcParams["text.usetex"] = True + +sns.set_style("white") + +x = torch.tensor((-0.25, -0.75)) +xv1 = torch.tensor((np.sin(np.pi / 3), np.cos(np.pi / 3))) / 5 +xv2 = torch.tensor((np.sin(-np.pi / 3), np.cos(np.pi / 3))) / 5 +t = torch.linspace(0, 1, 10)[:, None] + +y = torch.tensor((0.65, -0.55)) +xy = pmath.logmap(x, y) +path = pmath.geodesic(t, x, y) +yv1 = pmath.parallel_transport(x, y, xv1) +yv2 = pmath.parallel_transport(x, y, xv2) + +xgv1 = pmath.geodesic_unit(t, x, xv1) +xgv2 = pmath.geodesic_unit(t, x, xv2) + +ygv1 = pmath.geodesic_unit(t, y, yv1) +ygv2 = pmath.geodesic_unit(t, y, yv2) + + +def plot_gv(gv, **kwargs): + plt.plot(*gv.t().numpy(), **kwargs) + plt.arrow(*gv[-2], *(gv[-1] - gv[-2]), width=0.01, **kwargs) + + +circle = plt.Circle((0, 0), 1, fill=False, color="b") +plt.gca().add_artist(circle) +plt.xlim(-1.1, 1.1) +plt.ylim(-1.1, 1.1) +plt.gca().set_aspect("equal") +plt.annotate("x", x - 0.09, fontsize=15) +plt.annotate("y", y - 0.09, fontsize=15) +plt.annotate(r"$\vec{v}$", x + torch.tensor([0.3, 0.5]), fontsize=15) +plot_gv(xgv1, color="r") +plot_gv(xgv2, color="b") +plt.arrow(*x, *xy, width=0.01, color="g") +plot_gv(ygv1, color="r") +plot_gv(ygv2, color="b") + +plt.plot(*path.t().numpy(), color="g") +plt.title(r"gyrovector parallel transport $P_{x\to y}$") +plt.show() diff --git a/docs/plots/extended/poincare/hyperboloid_projection.png b/docs/plots/extended/poincare/hyperboloid_projection.png new file mode 100644 index 00000000..df4843c5 Binary files /dev/null and b/docs/plots/extended/poincare/hyperboloid_projection.png differ diff --git a/docs/plots/extended/poincare/klein_tiling.png b/docs/plots/extended/poincare/klein_tiling.png new file mode 100644 index 00000000..c2184592 Binary files /dev/null and b/docs/plots/extended/poincare/klein_tiling.png differ diff --git a/docs/plots/extended/poincare/mobius_add.py b/docs/plots/extended/poincare/mobius_add.py new file mode 100644 index 00000000..159656a8 --- /dev/null +++ b/docs/plots/extended/poincare/mobius_add.py @@ -0,0 +1,29 @@ +import geoopt.manifolds.poincare.math as pmath +import torch +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib import rcParams + +rcParams["text.latex.preamble"] = r"\usepackage{amsmath}" +rcParams["text.usetex"] = True + +sns.set_style("white") + +x = torch.tensor((-0.25, -0.75)) / 2 +y = torch.tensor((0.65, -0.55)) / 2 +x_plus_y = pmath.mobius_add(x, y) + + +circle = plt.Circle((0, 0), 1, fill=False, color="b") +plt.gca().add_artist(circle) +plt.xlim(-1.1, 1.1) +plt.ylim(-1.1, 1.1) +plt.gca().set_aspect("equal") +plt.annotate("x", x - 0.09, fontsize=15) +plt.annotate("y", y - 0.09, fontsize=15) +plt.annotate(r"$x\oplus y$", x_plus_y - torch.tensor([0.1, 0.15]), fontsize=15) +plt.arrow(0, 0, *x, width=0.01, color="r") +plt.arrow(0, 0, *y, width=0.01, color="g") +plt.arrow(0, 0, *x_plus_y, width=0.01, color="b") +plt.title(r"Addition $x\oplus y$") +plt.show() diff --git a/docs/plots/extended/poincare/mobius_matvec.py b/docs/plots/extended/poincare/mobius_matvec.py new file mode 100644 index 00000000..df02968d --- /dev/null +++ b/docs/plots/extended/poincare/mobius_matvec.py @@ -0,0 +1,30 @@ +import geoopt.manifolds.poincare.math as pmath +import torch +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib import rcParams + +rcParams["text.latex.preamble"] = r"\usepackage{amsmath}" +rcParams["text.usetex"] = True +sns.set_style("white") +x = torch.tensor((-0.25, -0.75)) / 3 +M = torch.tensor([[-1, -1.5], [0.2, 0.5]]) +M_x = pmath.mobius_matvec(M, x) + + +circle = plt.Circle((0, 0), 1, fill=False, color="b") +plt.gca().add_artist(circle) +plt.xlim(-1.1, 1.1) +plt.ylim(-1.1, 1.1) +plt.gca().set_aspect("equal") +plt.annotate("x", x - 0.09, fontsize=15) +plt.annotate( + r"$M=\begin{bmatrix}-1 &-1.5\\.2 &.5\end{bmatrix}$", + x + torch.tensor([-0.5, 0.5]), + fontsize=15, +) +plt.annotate(r"$M^\otimes x$", M_x - torch.tensor([0.1, 0.15]), fontsize=15) +plt.arrow(0, 0, *x, width=0.01, color="r") +plt.arrow(0, 0, *M_x, width=0.01, color="b") +plt.title(r"Matrix multiplication $M\otimes x$") +plt.show() diff --git a/docs/plots/extended/poincare/mobius_sigmoid_apply.py b/docs/plots/extended/poincare/mobius_sigmoid_apply.py new file mode 100644 index 00000000..deeb7f7e --- /dev/null +++ b/docs/plots/extended/poincare/mobius_sigmoid_apply.py @@ -0,0 +1,27 @@ +import geoopt.manifolds.poincare.math as pmath +from matplotlib import rcParams +import torch +import matplotlib.pyplot as plt +import seaborn as sns + +sns.set_style("white") +rcParams["text.latex.preamble"] = r"\usepackage{amsmath}" +rcParams["text.usetex"] = True +x = torch.tensor((-0.25, -0.75)) / 3 +f_x = pmath.mobius_fn_apply(torch.sigmoid, x) + + +circle = plt.Circle((0, 0), 1, fill=False, color="b") +plt.gca().add_artist(circle) +plt.xlim(-1.1, 1.1) +plt.ylim(-1.1, 1.1) +plt.gca().set_aspect("equal") +plt.annotate("x", x - 0.09, fontsize=15) +plt.annotate( + r"$\sigma(x)=\frac{1}{1+e^{-x}}$", x + torch.tensor([-0.7, 0.5]), fontsize=15 +) +plt.annotate(r"$\sigma^\otimes(x)$", f_x - torch.tensor([0.1, 0.15]), fontsize=15) +plt.arrow(0, 0, *x, width=0.01, color="r") +plt.arrow(0, 0, *f_x, width=0.01, color="b") +plt.title(r"Mobius function (sigmoid) apply $\sigma^\otimes(x)$") +plt.show() diff --git a/docs/plots/extended/poincare/parallel_transport.py b/docs/plots/extended/poincare/parallel_transport.py new file mode 100644 index 00000000..be624c52 --- /dev/null +++ b/docs/plots/extended/poincare/parallel_transport.py @@ -0,0 +1,40 @@ +import geoopt.manifolds.poincare.math as pmath +import torch +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib import rcParams + +rcParams["text.latex.preamble"] = r"\usepackage{amsmath}" +rcParams["text.usetex"] = True + + +sns.set_style("white") + +x = torch.tensor((-0.25, -0.75)) +v1 = torch.tensor((np.sin(np.pi / 3), np.cos(np.pi / 3))) / 5 +v2 = torch.tensor((np.sin(-np.pi / 3), np.cos(np.pi / 3))) / 5 +y = torch.tensor((0.65, -0.55)) +t = torch.linspace(0, 1) +xy = pmath.logmap(x, y) +path = pmath.geodesic(t[:, None], x, y) +yv1 = pmath.parallel_transport(x, y, v1) +yv2 = pmath.parallel_transport(x, y, v2) + + +circle = plt.Circle((0, 0), 1, fill=False, color="b") +plt.gca().add_artist(circle) +plt.xlim(-1.1, 1.1) +plt.ylim(-1.1, 1.1) +plt.gca().set_aspect("equal") +plt.annotate("x", x - 0.07, fontsize=15) +plt.annotate("y", y - 0.07, fontsize=15) +plt.annotate(r"$\vec{v}$", x + torch.tensor([0.3, 0.5]), fontsize=15) +plt.arrow(*x, *v1, width=0.01, color="r") +plt.arrow(*x, *xy, width=0.01, color="g") +plt.arrow(*x, *v2, width=0.01, color="b") +plt.arrow(*y, *yv1, width=0.01, color="r") +plt.arrow(*y, *yv2, width=0.01, color="b") +plt.plot(*path.t().numpy(), color="g") +plt.title(r"parallel transport $P^c_{x\to y}$") +plt.show() diff --git a/docs/plots/extended/poincare/poincare_lines.gif b/docs/plots/extended/poincare/poincare_lines.gif new file mode 100644 index 00000000..da427047 Binary files /dev/null and b/docs/plots/extended/poincare/poincare_lines.gif differ diff --git a/geoopt/__init__.py b/geoopt/__init__.py index 43f95f74..fe0609f3 100644 --- a/geoopt/__init__.py +++ b/geoopt/__init__.py @@ -11,6 +11,7 @@ Sphere, SphereSubspaceIntersection, SphereSubspaceComplementIntersection, + PoincareBall, ) __version__ = "0.0.1" diff --git a/geoopt/linalg/_expm.py b/geoopt/linalg/_expm.py index 1727a33d..1ffb88a7 100644 --- a/geoopt/linalg/_expm.py +++ b/geoopt/linalg/_expm.py @@ -3,7 +3,7 @@ maintainer: ferrine """ -import torch +import torch.jit @torch.jit.script @@ -73,6 +73,7 @@ def expm_one(A): # pragma: no cover U, V = torch_pade13(Ascaled) P = U + V Q = -U + V + # TODO: torch.gesv -> torch.solve after pytorch release R, _ = torch.gesv(P, Q) # solve P = Q*R expmA = matrix_2_power(R, n_squarings) return expmA diff --git a/geoopt/linalg/batch_linalg.py b/geoopt/linalg/batch_linalg.py index f4d8842e..2a7aabe7 100644 --- a/geoopt/linalg/batch_linalg.py +++ b/geoopt/linalg/batch_linalg.py @@ -1,4 +1,4 @@ -import torch +import torch.jit from . import _expm __all__ = ["svd", "qr", "sym", "extract_diag", "matrix_rank", "expm", "block_matrix"] diff --git a/geoopt/manifolds/__init__.py b/geoopt/manifolds/__init__.py index a54c88b6..99fe4840 100644 --- a/geoopt/manifolds/__init__.py +++ b/geoopt/manifolds/__init__.py @@ -6,3 +6,5 @@ SphereSubspaceComplementIntersection, SphereSubspaceIntersection, ) +from .poincare import PoincareBall +from . import poincare diff --git a/geoopt/manifolds/base.py b/geoopt/manifolds/base.py index 3440cb86..ee1949e7 100644 --- a/geoopt/manifolds/base.py +++ b/geoopt/manifolds/base.py @@ -210,7 +210,7 @@ def set_default_order(self, order): Returns ------- - self + Manifold returns same instance """ if order is None: @@ -245,7 +245,7 @@ def reset_default_order(self): Returns ------- - self + Manifold returns same instance """ return self.set_default_order(None) @@ -471,7 +471,7 @@ def assert_check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5): ) ) - def dist(self, x, y): + def dist(self, x, y, keepdim=False): """ Compute distance between 2 points on the manifold that is the shortest path along geodesics @@ -481,13 +481,15 @@ def dist(self, x, y): point on the manifold y : tensor point on the manifold + keepdim : bool + keep the last dim? Returns ------- scalar distance between two points """ - return self._dist(x, y) + return self._dist(x, y, keepdim=keepdim) def retr(self, x, u, t=1.0, order=None): """ @@ -638,7 +640,7 @@ def transp(self, x, v, *more, u=None, t=1.0, y=None, order=None): else: raise TypeError("transp() requires either y or u") - def inner(self, x, u, v=None): + def inner(self, x, u, v=None, keepdim=False): """ Inner product for tangent vectors at point :math:`x` @@ -650,6 +652,8 @@ def inner(self, x, u, v=None): tangent vector at point :math:`x` v : tensor (optional) tangent vector at point :math:`x` + keepdim : bool + keep the last dim? Returns ------- @@ -658,7 +662,7 @@ def inner(self, x, u, v=None): """ if v is None and self._inner_autofill: v = u - return self._inner(x, u, v) + return self._inner(x, u, v, keepdim=keepdim) # dev: autofill None parameter or propagate None? _inner_autofill = True @@ -907,7 +911,7 @@ def _retr(self, x, u, t): _dist = not_implemented @abc.abstractmethod - def _inner(self, x, u, v): + def _inner(self, x, u, v, keepdim): """ Developer Guide diff --git a/geoopt/manifolds/euclidean.py b/geoopt/manifolds/euclidean.py index e4f5576d..60bf4f3c 100644 --- a/geoopt/manifolds/euclidean.py +++ b/geoopt/manifolds/euclidean.py @@ -24,7 +24,7 @@ def _check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5): def _retr(self, x, u, t): return x + t * u - def _inner(self, x, u, v): + def _inner(self, x, u, v, keepdim): return u * v def _proju(self, x, u): @@ -50,5 +50,5 @@ def _transp2y(self, x, v, *more, y): def _logmap(self, x, y): return y - x - def _dist(self, x, y): + def _dist(self, x, y, keepdim): return (x - y).abs() diff --git a/geoopt/manifolds/poincare/__init__.py b/geoopt/manifolds/poincare/__init__.py new file mode 100644 index 00000000..1d8d723f --- /dev/null +++ b/geoopt/manifolds/poincare/__init__.py @@ -0,0 +1,102 @@ +import torch +from . import math +from ..base import Manifold + +__all__ = ["PoincareBall"] + + +class PoincareBall(Manifold): + """ + Poincare ball model, see more in :doc:`/extended/poincare` + + Parameters + ---------- + c : float|tensor + ball negative curvature + + Notes + ----- + It is extremely recommended to work with this manifold in double precision + """ + + ndim = 1 + reversible = False + _default_order = 1 + name = "Poincare ball" + + def __init__(self, c=1.0): + super().__init__() + self.register_buffer("c", torch.as_tensor(c)) + + def _check_shape(self, x, name): + ok = x.dim() > 0 + if not ok: + reason = "'{}' on poincare ball requires more that zero dim".format(name) + else: + reason = None + return ok, reason + + def _check_point_on_manifold(self, x, atol=1e-5, rtol=1e-5): + px = math.project(x, c=self.c) + ok = torch.allclose(x, px, atol=atol, rtol=rtol) + if not ok: + reason = "'x' norm lies out of the bounds [-1/sqrt(c)+eps, 1/sqrt(c)-eps]" + else: + reason = None + return ok, reason + + def _check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5): + return True, None + + def _dist(self, x, y, keepdim): + return math.dist(x, y, c=self.c, keepdim=keepdim) + + def _egrad2rgrad(self, x, u): + return math.egrad2rgrad(x, u, c=self.c) + + def _retr(self, x, u, t): + # always assume u is scaled properly + approx = x + u * t + return math.project(approx, c=self.c) + + _retr_transp_default_preference = "2y" + + def _projx(self, x): + return math.project(x, c=self.c) + + def _proju(self, x, u): + return math.clip_tangent(x, u, c=self.c) + + def _inner(self, x, u, v, keepdim): + return math.inner(x, u, v, c=self.c, keepdim=keepdim) + + def _expmap(self, x, u, t): + return math.project(math.expmap(x, u * t, c=self.c), c=self.c) + + def _logmap(self, x, y): + return math.logmap(x, y, c=self.c) + + def _transp2y(self, x, v, *more, y): + if not more: + return math.parallel_transport(x, y, v, c=self.c) + else: + n = len(more) + 1 + vecs = torch.stack((v,) + more, dim=0) + transp = math.parallel_transport(x, y, vecs, c=self.c) + return tuple(transp[i] for i in range(n)) + + def _transp_follow(self, x, v, *more, u, t): + y = self._retr(x, u, t) + return self._transp2y(x, v, *more, y=y) + + def _expmap_transp(self, x, v, *more, u, t): + y = self._expmap(x, u, t) + vs = self._transp2y(x, v, *more, y=y) + if more: + return (y,) + vs + else: + return y, vs + + def _transp_follow_expmap(self, x, v, *more, u, t): + y = self._expmap(x, u, t) + return self._transp2y(x, v, *more, y=y) diff --git a/geoopt/manifolds/poincare/math.py b/geoopt/manifolds/poincare/math.py new file mode 100644 index 00000000..bd8f9469 --- /dev/null +++ b/geoopt/manifolds/poincare/math.py @@ -0,0 +1,1347 @@ +""" +Functions for math on Poincare ball model. Most of this is taken from +a well written paper by Octavian-Eugen Ganea (2018) [1]_ + + +.. [1] Octavian-Eugen Ganea et al., Hyperbolic Neural Networks, NIPS 2018 +""" + +import functools +import torch.jit + + +def tanh(x): + return x.clamp(-15, 15).tanh() + + +class Artanh(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + x = x.clamp(-1 + 1e-15, 1 - 1e-15) + ctx.save_for_backward(x) + dtype = x.dtype + x = x.double() + res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5) + return res.to(dtype) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output / (1 - input ** 2) + + +class Arsinh(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + z = x.double() + return (z + torch.sqrt_(1 + z.pow(2))).clamp_min_(1e-15).log_().to(x.dtype) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return grad_output / (1 + input ** 2) ** 0.5 + + +def artanh(x): + return Artanh.apply(x) + + +def arsinh(x): + return Arsinh.apply(x) + + +def project(x, *, c=1.0, dim=-1): + r""" + Safe projection on the manifold for numerical stability. + + Parameters + ---------- + x : tensor + point on the Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension to compute norm + + Returns + ------- + tensor + projected vector on the manifold + """ + return _project(x, c, dim) + + +@torch.jit.script +def _max_norm(x): + if x.dtype == torch.float32: + maxnorm = torch.full((), 1 - 3e-3, dtype=x.dtype, device=x.device) + else: + maxnorm = torch.full((), 1 - 1e-5, dtype=x.dtype, device=x.device) + return maxnorm + + +def _project(x, c, dim: int = -1): + norm = x.norm(dim=dim, keepdim=True, p=2) + maxnorm = _max_norm(x) / (c ** 0.5) + cond = norm > maxnorm + projected = x / norm * maxnorm + return torch.where(cond, projected, x) + + +def lambda_x(x, *, c=1.0, keepdim=False, dim=-1): + r""" + Compute the conformal factor :math:`\lambda^c_x` for a point on the ball + + .. math:: + + \lambda^c_x = \frac{1}{1 - c \|x\|_2^2} + + Parameters + ---------- + x : tensor + point on the Poincare ball + c : float|tensor + ball negative curvature + keepdim : bool + retain the last dim? (default: false) + dim : int + reduction dimension + + Returns + ------- + tensor + conformal factor + """ + return _lambda_x(x, c, keepdim=keepdim, dim=dim) + + +def _lambda_x(x, c, keepdim: bool = False, dim: int = -1): + return 2 / (1 - c * x.pow(2).sum(dim=dim, keepdim=keepdim)) + + +def inner(x, u, v, *, c=1.0, keepdim=False, dim=-1): + r""" + Compute inner product for two vectors on the tangent space w.r.t Riemannian metric on the Poincare ball + + .. math:: + + \langle u, v\rangle_x = (\lambda^c_x)^2 \langle u, v \rangle + + Parameters + ---------- + x : tensor + point on the Poincare ball + u : tensor + tangent vector to :math:`x` on Poincare ball + v : tensor + tangent vector to :math:`x` on Poincare ball + c : float|tensor + ball negative curvature + keepdim : bool + retain the last dim? (default: false) + dim : int + reduction dimension + + Returns + ------- + tensor + inner product + """ + return _inner(x, u, v, c, keepdim=keepdim, dim=dim) + + +def _inner(x, u, v, c, keepdim: bool = False, dim: int = -1): + return _lambda_x(x, c, keepdim=True, dim=dim) ** 2 * (u * v).sum( + dim=dim, keepdim=keepdim + ) + + +def norm(x, u, *, c=1.0, keepdim=False, dim=-1): + r""" + Compute vector norm on the tangent space w.r.t Riemannian metric on the Poincare ball + + .. math:: + + \|u\|_x = \lambda^c_x \|u\|_2 + + Parameters + ---------- + x : tensor + point on the Poincare ball + u : tensor + tangent vector to :math:`x` on Poincare ball + c : float|tensor + ball negative curvature + keepdim : bool + retain the last dim? (default: false) + dim : int + reduction dimension + + Returns + ------- + tensor + norm of vector + """ + return _norm(x, u, c, keepdim=keepdim, dim=dim) + + +def _norm(x, u, c, keepdim: bool = False, dim: int = -1): + return _lambda_x(x, c, keepdim=keepdim, dim=dim) * u.norm( + dim=dim, keepdim=keepdim, p=2 + ) + + +def mobius_add(x, y, *, c=1.0, dim=-1): + r""" + Mobius addition is a special operation in a hyperbolic space. + + .. math:: + + x \oplus_c y = \frac{ + (1 + 2 c \langle x, y\rangle + c \|y\|^2_2) x + (1 - c \|x\|_2^2) y + }{ + 1 + 2 c \langle x, y\rangle + c^2 \|x\|^2_2 \|y\|^2_2 + } + + .. plot:: plots/extended/poincare/mobius_add.py + + In general this operation is not commutative: + + .. math:: + + x \oplus_c y \ne y \oplus_c x + + But in some cases this property holds: + + * zero vector case + + .. math:: + + \mathbf{0} \oplus_c x = x \oplus_c \mathbf{0} + + * zero negative curvature case that is same as Euclidean addition + + .. math:: + + x \oplus_0 y = y \oplus_0 x + + Another useful property is so called left-cancellation law: + + .. math:: + + (-x) \oplus_c (x \oplus_c y) = y + + Parameters + ---------- + x : tensor + point on the Poincare ball + y : tensor + point on the Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + the result of mobius addition + """ + return _mobius_add(x, y, c, dim=dim) + + +def _mobius_add(x, y, c, dim=-1): + y = y + 1e-15 + x2 = x.pow(2).sum(dim=dim, keepdim=True) + y2 = y.pow(2).sum(dim=dim, keepdim=True) + xy = (x * y).sum(dim=dim, keepdim=True) + num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y + denom = 1 + 2 * c * xy + c ** 2 * x2 * y2 + # avoid division by zero in this way + return num / (denom + 1e-15) + + +def mobius_sub(x, y, *, c=1.0, dim=-1): + r""" + Mobius substraction that can be represented via Mobius addition as follows: + + .. math:: + + x \ominus_c y = x \oplus_c (-y) + + Parameters + ---------- + x : tensor + point on Poincare ball + y : tensor + point on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + the result of mobius substraction + """ + return _mobius_sub(x, y, c, dim=dim) + + +def _mobius_sub(x, y, c, dim: int = -1): + return _mobius_add(x, -y, c, dim=dim) + + +def mobius_coadd(x, y, *, c=1.0, dim=-1): + r""" + Mobius coaddition operation + + Addition operation :math:`\oplus_c` is neither associative, nor commutative. Coaddition, or cooperation in + Gyrogroup is an associative operation that is defined as follows. + + .. math:: + + a \boxplus_c b = b \boxplus_c a = a\operatorname{gyr}[a, -b]b\\ + = \frac{ + (1 + c \|y\|^2_2) x + (1 - c \|x\|_2^2) y + }{ + 1 + c^2 \|x\|^2_2 \|y\|^2_2 + }, + + where :math:`\operatorname{gyr}[a, b]c = \ominus_c (a \oplus b) \oplus_c (a \oplus_c (b \oplus_c c))` + + The following right cancellation property holds + + .. math:: + + (a \boxplus_c b) \ominus_c b = a\\ + (a \oplus_c b) \boxminus_c b = a + + Parameters + ---------- + x : tensor + point on Poincare ball + y : tensor + point on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + the result of mobius coaddition + + """ + return _mobius_coadd(x, y, c, dim=dim) + + +def _mobius_coadd(x, y, c, dim: int = -1): + y = y + 1e-15 + x2 = x.pow(2).sum(dim=dim, keepdim=True) + y2 = y.pow(2).sum(dim=dim, keepdim=True) + num = (1 - c * y2) * x + (1 - c * x2) * y + denom = 1 - c ** 2 * x2 * y2 + # avoid division by zero in this way + return num / (denom + 1e-15) + + +def mobius_cosub(x, y, *, c=1.0, dim=-1): + """ + Mobius cosubstraction operation + + .. math:: + + a \boxminus_c b = a \boxplus_c -b + + Parameters + ---------- + x : tensor + point on Poincare ball + y : tensor + point on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + the result of mobius coaddition + + """ + return _mobius_cosub(x, y, c, dim=dim) + + +def _mobius_cosub(x, y, c, dim: int = -1): + return _mobius_coadd(x, -y, c, dim=dim) + + +def mobius_scalar_mul(r, x, *, c=1.0, dim=-1): + r""" + Left scalar multiplication on the Poincare ball + + .. math:: + + r \otimes_c x = (1/\sqrt{c}) \tanh(r\tanh^{-1}(\sqrt{c}\|x\|_2))\frac{x}{\|x\|_2} + + This operation has properties similar to Euclidean + + * `n-addition` property + + .. math:: + + r \otimes_c x = x \oplus_c \dots \oplus_c x + + * Distributive property + + .. math:: + + (r_1 + r_2) \otimes_c x = r_1 \otimes_c x \oplus r_2 \otimes_c x + + * Scalar associativity + + .. math:: + + (r_1 r_2) \otimes_c x = r_1 \otimes_c (r_2 \otimes_c x) + + * Scaling property + + .. math:: + + |r| \otimes_c x / \|r \otimes_c x\|_2 = x/\|x\|_2 + + Parameters + ---------- + r : float|tensor + scalar for multiplication + x : tensor + point on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + the result of mobius scalar multiplication + """ + return _mobius_scalar_mul(r, x, c, dim=dim) + + +def _mobius_scalar_mul(r, x, c, dim: int = -1): + x = x + 1e-15 + x_norm = x.norm(dim=dim, keepdim=True, p=2) + sqrt_c = c ** 0.5 + res_c = tanh(r * artanh(sqrt_c * x_norm)) * x / (x_norm * sqrt_c) + return res_c + + +def dist(x, y, *, c=1.0, keepdim=False, dim=-1): + r""" + Distance on the Poincare ball + + .. math:: + + d_c(x, y) = \frac{2}{\sqrt{c}}\tanh^{-1}(\sqrt{c}\|(-x)\oplus_c y\|_2) + + .. plot:: plots/extended/poincare/distance.py + + Parameters + ---------- + x : tensor + point on Poincare ball + y : tensor + point on Poincare ball + c : float|tensor + ball negative curvature + keepdim : bool + retain the last dim? (default: false) + dim : int + reduction dimension + + Returns + ------- + tensor + geodesic distance between :math:`x` and :math:`y` + """ + return _dist(x, y, c, keepdim=keepdim, dim=dim) + + +def _dist(x, y, c, keepdim: bool = False, dim: int = -1): + sqrt_c = c ** 0.5 + dist_c = artanh( + sqrt_c * _mobius_add(-x, y, c, dim=dim).norm(dim=dim, p=2, keepdim=keepdim) + ) + return dist_c * 2 / sqrt_c + + +def dist0(x, *, c=1.0, keepdim=False, dim=-1): + r""" + Distance on the Poincare ball to zero + + Parameters + ---------- + x : tensor + point on Poincare ball + c : float|tensor + ball negative curvature + keepdim : bool + retain the last dim? (default: false) + dim : int + reduction dimension for operations + + Returns + ------- + tensor + geodesic distance between :math:`x` and :math:`0` + """ + return _dist0(x, c, keepdim=keepdim, dim=dim) + + +def _dist0(x, c, keepdim: bool = False, dim: int = -1): + sqrt_c = c ** 0.5 + dist_c = artanh(sqrt_c * x.norm(dim=dim, p=2, keepdim=keepdim)) + return dist_c * 2 / sqrt_c + + +def clip_tangent(x, u, *, c=1.0, dim=-1): + r""" + Project tangent vector to reasonable values that do not exceed + maximum allowed (vector norm allowing to travel to the opposite pole) + + .. math:: + + \operatorname{maxnorm}_x = d_{c}(\operatorname{proj}(-\infty), \operatorname{proj}(\infty)) / \lambda_x^c + + Parameters + ---------- + x : tensor + point on Poincare ball + u : tensor + tangent vector + c : float|tensor + ball negative curvature + dim : int + reduction dimension to compute norm + + Returns + ------- + tensor + same tangent vector with reasonable values + """ + return _clip_tangent(x, u, c, dim=dim) + + +def _clip_tangent(x, u, c, dim: int = -1): + # get the almost infinite vecotor estimate + # this is the norm of travel vector to the opposite pole + s = x.size(dim) + p = torch.ones((s,), dtype=x.dtype, device=x.device) + p = p / s ** 0.5 / (c ** 0.5) + p = _project(p, c, dim=dim) + # normalize its length based on x + maxnorm = _dist(p, -p, c, keepdim=True, dim=dim) / _lambda_x( + x, c, keepdim=True, dim=dim + ) + norm = u.norm(dim=dim, keepdim=True, p=2) + cond = norm > maxnorm + projected = u / norm * maxnorm + return torch.where(cond, projected, u) + + +def geodesic(t, x, y, *, c=1.0, dim=-1): + r""" + Geodesic (the shortest) path connecting :math:`x` and :math:`y`. + The path can be treated as and extension of a line segment between + points but in a Riemannian manifold. In Poincare ball model, the path + is expressed using Mobius addition and scalar multiplication: + + .. math:: + + \gamma_{x\to y}(t) = x \oplus_c r \otimes_c ((-x) \oplus_c y) + + The required properties of this path are the following: + + .. math:: + + \gamma_{x\to y}(0) = x\\ + \gamma_{x\to y}(1) = y\\ + \dot\gamma_{x\to y}(t) = v + + Moreover, as geodesic path is not only the shortest path connecting points and Poincare ball. + This definition also requires local distance minimization and thus another property appears: + + .. math:: + + d_c(\gamma_{x\to y}(t_1), \gamma_{x\to y}(t_2)) = v|t_1-t_2| + + "Natural parametrization" of the curve ensures unit speed geodesics which yields the above formula with :math:`v=1`. + However, for Poincare ball we can always compute the constant speed :math:`v` from the points + that particular path connects: + + .. math:: + + v = d_c(\gamma_{x\to y}(0), \gamma_{x\to y}(1)) = d_c(x, y) + + + Parameters + ---------- + t : float|tensor + travelling time + x : tensor + starting point on Poincare ball + y : tensor + target point on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + point on the Poincare ball + """ + return _geodesic(t, x, y, c, dim=dim) + + +def _geodesic(t, x, y, c, dim: int = -1): + # this is not very numerically unstable + v = _mobius_add(-x, y, c, dim=dim) + tv = _mobius_scalar_mul(t, v, c, dim=dim) + gamma_t = _mobius_add(x, tv, c, dim=dim) + return gamma_t + + +def expmap(x, u, *, c=1.0, dim=-1): + r""" + Exponential map for Poincare ball model. This is tightly related with :func:`geodesic`. + Intuitively Exponential map is a smooth constant travelling from starting point :math:`x` with speed :math:`u`. + + A bit more formally this is travelling along curve :math:`\gamma_{x, u}(t)` such that + + .. math:: + + \gamma_{x, u}(0) = x\\ + \dot\gamma_{x, u}(0) = u\\ + \|\dot\gamma_{x, u}(t)\|_{\gamma_{x, u}(t)} = \|u\|_x + + The existence of this curve relies on uniqueness of differential equation solution, that is local. + For the Poincare ball model the solution is well defined globally and we have. + + .. math:: + + \operatorname{Exp}^c_x(u) = \gamma_{x, u}(1) = \\ + x\oplus_c \tanh(\sqrt{c}/2 \|u\|_x) \frac{u}{\sqrt{c}\|u\|_2} + + Parameters + ---------- + x : tensor + starting point on Poincare ball + u : tensor + speed vector on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + :math:`\gamma_{x, u}(1)` end point + """ + return _expmap(x, u, c, dim=dim) + + +def _expmap(x, u, c, dim: int = -1): + u += 1e-15 + sqrt_c = c ** 0.5 + u_norm = u.norm(dim=dim, p=2, keepdim=True) + second_term = ( + tanh(sqrt_c / 2 * _lambda_x(x, c, keepdim=True, dim=dim) * u_norm) + * u + / (sqrt_c * u_norm) + ) + gamma_1 = _mobius_add(x, second_term, c, dim=dim) + return gamma_1 + + +def expmap0(u, *, c=1.0, dim=-1): + r""" + Exponential map for Poincare ball model from :math:`0`. + + .. math:: + + \operatorname{Exp}^c_0(u) = \tanh(\sqrt{c}/2 \|u\|_2) \frac{u}{\sqrt{c}\|u\|_2} + + Parameters + ---------- + u : tensor + speed vector on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + :math:`\gamma_{0, u}(1)` end point + """ + return _expmap0(u, c, dim=dim) + + +def _expmap0(u, c, dim: int = -1): + u = u + 1e-15 + sqrt_c = c ** 0.5 + u_norm = u.norm(dim=dim, p=2, keepdim=True) + gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm) + return gamma_1 + + +def geodesic_unit(t, x, u, *, c=1.0, dim=-1): + r""" + Unit speed geodesic starting from :math:`x` with direction :math:`u/\|u\|_x` + + .. math:: + + \gamma_{x,u}(t) = x\oplus_c \tanh(t\sqrt{c}/2) \frac{u}{\sqrt{c}\|u\|_2} + + Parameters + ---------- + t : tensor + travelling time + x : tensor + initial point + u : tensor + direction + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + the point on geodesic line + """ + return _geodesic_unit(t, x, u, c, dim=dim) + + +def _geodesic_unit(t, x, u, c, dim: int = -1): + sqrt_c = c ** 0.5 + u_norm = u.norm(dim=dim, p=2, keepdim=True) + second_term = tanh(sqrt_c / 2 * t) * u / (sqrt_c * u_norm) + gamma_1 = _mobius_add(x, second_term, c, dim=dim) + return gamma_1 + + +def logmap(x, y, *, c=1.0, dim=-1): + r""" + Logarithmic map for two points :math:`x` and :math:`y` on the manifold. + + .. math:: + + \operatorname{Log}^c_x(y) = \frac{2}{\sqrt{c}\lambda_x^c} \tanh^{-1}( + \sqrt{c} \|(-x)\oplus_c y\|_2 + ) * \frac{(-x)\oplus_c y}{\|(-x)\oplus_c y\|_2} + + The result of Logarithmic map is a vector such that + + .. math:: + + y = \operatorname{Exp}^c_x(\operatorname{Log}^c_x(y)) + + + Parameters + ---------- + x : tensor + starting point on Poincare ball + y : tensor + target point on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + tangent vector that transports :math:`x` to :math:`y` + """ + return _logmap(x, y, c, dim=dim) + + +def _logmap(x, y, c, dim: int = -1): + sub = _mobius_add(-x, y, c, dim=dim) + sub_norm = sub.norm(dim=dim, p=2, keepdim=True) + lam = _lambda_x(x, c, keepdim=True, dim=dim) + sqrt_c = c ** 0.5 + return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm + + +def logmap0(y, *, c=1.0, dim=-1): + r""" + Logarithmic map for :math:`y` from :math:`0` on the manifold. + + + .. math:: + + \operatorname{Log}^c_0(y) = \tanh^{-1}(\sqrt{c}\|y\|_2) \frac{y}{\|y\|_2} + + The result is such that + + .. math:: + + y = \operatorname{Exp}^c_0(\operatorname{Log}^c_0(y)) + + Parameters + ---------- + y : tensor + target point on Poincare ball + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + tangent vector that transports :math:`0` to :math:`y` + """ + return _logmap0(y, c, dim=dim) + + +def _logmap0(y, c, dim: int = -1): + sqrt_c = c ** 0.5 + y = y + 1e-15 + y_norm = y.norm(dim=dim, p=2, keepdim=True) + return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm) + + +def mobius_matvec(m, x, *, c=1.0, dim=-1): + r""" + Generalization for matrix-vector multiplication to hyperbolic space defined as + + .. math:: + + M \otimes_c x = (1/\sqrt{c}) \tanh\left( + \frac{\|Mx\|_2}{\|x\|_2}\tanh^{-1}(\sqrt{c}\|x\|_2) + \right)\frac{Mx}{\|Mx\|_2} + + .. plot:: plots/extended/poincare/mobius_matvec.py + + Parameters + ---------- + m : tensor + matrix for multiplication. + Batched matmul is performed if ``m.dim() > 2``, but only last dim reduction is supported + x : tensor + point on Poincare ball + c : float|tensor + negative ball curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + Mobius matvec result + """ + return _mobius_matvec(m, x, c, dim=dim) + + +def _mobius_matvec(m, x, c, dim: int = -1): + if m.dim() > 2 and dim != -1: + raise RuntimeError( + "broadcasted Mobius matvec is supported for the last dim only" + ) + x = x + 1e-15 + x_norm = x.norm(dim=dim, keepdim=True, p=2) + sqrt_c = c ** 0.5 + if dim != -1 or m.dim() == 2: + mx = torch.tensordot(x, m, dims=([dim], [1])) + else: + mx = torch.matmul(m, x.unsqueeze(-1)).squeeze(-1) + mx_norm = mx.norm(dim=dim, keepdim=True, p=2) + res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c) + cond = (mx == 0).prod(dim=dim, keepdim=True, dtype=torch.uint8) + res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device) + res = torch.where(cond, res_0, res_c) + return res + + +def mobius_pointwise_mul(w, x, *, c=1.0, dim=-1): + r""" + Generalization for point-wise multiplication to hyperbolic space defined as + + .. math:: + + \operatorname{diag}(w) \otimes_c x = (1/\sqrt{c}) \tanh\left( + \frac{\|\operatorname{diag}(w)x\|_2}{x}\tanh^{-1}(\sqrt{c}\|x\|_2) + \right)\frac{\|\operatorname{diag}(w)x\|_2}{\|x\|_2} + + + Parameters + ---------- + w : tensor + weights for multiplication (should be broadcastable to x) + x : tensor + point on Poincare ball + c : float|tensor + negative ball curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + Mobius point-wise mul result + """ + return _mobius_pointwise_mul(w, x, c, dim=dim) + + +def _mobius_pointwise_mul(w, x, c, dim: int = -1): + x = x + 1e-15 + x_norm = x.norm(dim=dim, keepdim=True, p=2) + sqrt_c = c ** 0.5 + wx = w * x + wx_norm = wx.norm(dim=dim, keepdim=True, p=2) + res_c = tanh(wx_norm / x_norm * artanh(sqrt_c * x_norm)) * wx / (wx_norm * sqrt_c) + cond = (wx == 0).prod(dim=dim, keepdim=True, dtype=torch.uint8) + res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device) + res = torch.where(cond, res_0, res_c) + return res + + +def mobius_fn_apply_chain(x, *fns, c=1.0, dim=-1): + r""" + Generalization for functions in hyperbolic space. + First, hyperbolic vector is mapped to a Euclidean space via + :math:`\operatorname{Log}^c_0` and nonlinear function is applied in this tangent space. + The resulting vector is then mapped back with :math:`\operatorname{Exp}^c_0` + + .. math:: + + f^{\otimes_c}(x) = \operatorname{Exp}^c_0(f(\operatorname{Log}^c_0(y))) + + The definition of mobius function application allows chaining as + + .. math:: + + y = \operatorname{Exp}^c_0(\operatorname{Log}^c_0(y)) + + Resulting in + + .. math:: + + (f \circ g)^{\otimes_c}(x) = \operatorname{Exp}^c_0((f \circ g) (\operatorname{Log}^c_0(y))) + + Parameters + ---------- + x : tensor + point on Poincare ball + fns : callable[] + functions to apply + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + Apply chain result + """ + if not fns: + return x + else: + ex = _logmap0(x, c, dim=dim) + for fn in fns: + ex = fn(ex) + y = _expmap0(ex, c, dim=dim) + return y + + +def mobius_fn_apply(fn, x, *args, c=1.0, dim=-1, **kwargs): + r""" + Generalization for functions in hyperbolic space. + First, hyperbolic vector is mapped to a Euclidean space via + :math:`\operatorname{Log}^c_0` and nonlinear function is applied in this tangent space. + The resulting vector is then mapped back with :math:`\operatorname{Exp}^c_0` + + .. math:: + + f^{\otimes_c}(x) = \operatorname{Exp}^c_0(f(\operatorname{Log}^c_0(y))) + + .. plot:: plots/extended/poincare/mobius_sigmoid_apply.py + + Parameters + ---------- + x : tensor + point on Poincare ball + fn : callable + function to apply + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + Result of function in hyperbolic space + """ + ex = _logmap0(x, c, dim=dim) + ex = fn(ex, *args, **kwargs) + y = _expmap0(ex, c, dim=dim) + return y + + +def mobiusify(fn): + r""" + Wraps a function so that is works in hyperbolic space. New function will accept additional argument ``c`` + + Parameters + ---------- + fn : callable + function in Euclidean space, only its first argument is treated as hyperbolic + + Returns + ------- + callable + function working in hyperbolic space + """ + + @functools.wraps(fn) + def mobius_fn(x, *args, c=1.0, dim=-1, **kwargs): + ex = _logmap0(x, c, dim=dim) + ex = fn(ex, *args, **kwargs) + y = _expmap0(ex, c, dim=dim) + return y + + return mobius_fn + + +def dist2plane(x, p, a, *, c=1.0, keepdim=False, signed=False, dim=-1): + r""" + Distance from :math:`x` to a hyperbolic hyperplane in Poincare ball + that is orthogonal to :math:`a` and contains :math:`p`. + + .. plot:: plots/extended/poincare/distance2plane.py + + To form an intuition what is a hyperbolic hyperplane, let's first consider Euclidean hyperplane + + .. math:: + + H_{a, b} = \left\{ + x \in \mathbb{R}^n\;:\;\langle x, a\rangle - b = 0 + \right\}, + + where :math:`a\in \mathbb{R}^n\backslash \{\mathbf{0}\}` and :math:`b\in \mathbb{R}^n`. + + This formulation of a hyperplane is hard to generalize, + therefore we can rewrite :math:`\langle x, a\rangle - b` + utilizing orthogonal completion. + Setting any :math:`p` s.t. :math:`b=\langle a, p\rangle` we have + + .. math:: + + H_{a, b} = \left\{ + x \in \mathbb{R}^n\;:\;\langle x, a\rangle - b = 0 + \right\}\\ + =H_{a, \langle a, p\rangle} = \tilde{H}_{a, p}\\ + = \left\{ + x \in \mathbb{R}^n\;:\;\langle x, a\rangle - \langle a, p\rangle = 0 + \right\}\\ + =\left\{ + x \in \mathbb{R}^n\;:\;\langle -p + x, a\rangle = 0 + \right\}\\ + = p + \{a\}^\perp + + Naturally we have a set :math:`\{a\}^\perp` with applied :math:`+` operator to each element. + Generalizing a notion of summation to the hyperbolic space we replace :math:`+` with :math:`\oplus_c`. + + Next, we should figure out what is :math:`\{a\}^\perp` in the Poincare ball. + + First thing that we should acknowledge is that notion of orthogonality is defined for vectors in tangent spaces. + Let's consider now :math:`p\in \mathbb{D}_c^n` and :math:`a\in T_p\mathbb{D}_c^n\backslash \{\mathbf{0}\}`. + + Slightly deviating from traditional notation let's write :math:`\{a\}_p^\perp` + highlighting the tight relationship of :math:`a\in T_p\mathbb{D}_c^n\backslash \{\mathbf{0}\}` + with :math:`p \in \mathbb{D}_c^n`. We then define + + .. math:: + + \{a\}_p^\perp := \left\{ + z\in T_p\mathbb{D}_c^n \;:\; \langle z, a\rangle_p = 0 + \right\} + + Recalling that a tangent vector :math:`z` for point :math:`p` yields :math:`x = \operatorname{Exp}^c_p(z)` + we rewrite the above equation as + + .. math:: + \{a\}_p^\perp := \left\{ + x\in \mathbb{D}_c^n \;:\; \langle \operatorname{Log}_p^c(x), a\rangle_p = 0 + \right\} + + This formulation is something more pleasant to work with. + Putting all together + + .. math:: + + \tilde{H}_{a, p}^c = p + \{a\}^\perp_p\\ + = \left\{ + x \in \mathbb{D}_c^n\;:\;\langle\operatorname{Log}^c_p(x), a\rangle_p = 0 + \right\} \\ + = \left\{ + x \in \mathbb{D}_c^n\;:\;\langle -p \oplus_c x, a\rangle = 0 + \right\} + + To compute the distance :math:`d_c(x, \tilde{H}_{a, p}^c)` we find + + .. math:: + + d_c(x, \tilde{H}_{a, p}^c) = \inf_{w\in \tilde{H}_{a, p}^c} d_c(x, w)\\ + = \frac{1}{\sqrt{c}} \sinh^{-1}\left\{ + \frac{ + 2\sqrt{c} |\langle(-p)\oplus_c x, a\rangle| + }{ + (1-c\|(-p)\oplus_c x\|^2_2)\|a\|_2 + } + \right\} + + Parameters + ---------- + x : tensor + point on Poincare ball + a : tensor + vector on tangent space of :math:`p` + p : tensor + point on Poincare ball lying on the hyperplane + c : float|tensor + ball negative curvature + keepdim : bool + retain the last dim? (default: false) + signed : bool + return signed distance + dim : int + reduction dimension for operations + + Returns + ------- + tensor + distance to the hyperplane + """ + return _dist2plane(x, a, p, c, keepdim=keepdim, signed=signed, dim=dim) + + +def _dist2plane(x, a, p, c, keepdim: bool = False, signed: bool = False, dim: int = -1): + sqrt_c = c ** 0.5 + diff = _mobius_add(-p, x, c, dim=dim) + diff_norm2 = diff.pow(2).sum(dim=dim, keepdim=keepdim) + sc_diff_a = (diff * a).sum(dim=dim, keepdim=keepdim) + if not signed: + sc_diff_a = sc_diff_a.abs() + a_norm = a.norm(dim=dim, keepdim=keepdim, p=2) + num = 2 * sqrt_c * sc_diff_a + denom = (1 - c * diff_norm2) * a_norm + return arsinh(num / (denom + 1e-15)) / sqrt_c + + +def gyration(a, b, u, *, c=1.0, dim=-1): + r""" + Gyration is a special operation in hyperbolic geometry. + Addition operation :math:`\oplus_c` is not associative (as mentioned in :func:`mobius_add`), + but gyroassociative which means + + .. math:: + + u \oplus_c (v \oplus_c w) = (u\oplus_c v) \oplus_c \operatorname{gyr}[u, v]w, + + where + + .. math:: + + \operatorname{gyr}[u, v]w = \ominus (u \oplus_c v) \oplus (u \oplus_c (v \oplus_c w)) + + We can simplify this equation using explicit formula for Mobius addition [1]. Recall + + .. math:: + + A = - c^2 \langle u, w\rangle \langle v, v\rangle + c \langle v, w\rangle + + 2 c^2 \langle u, v\rangle \langle v, w\rangle\\ + B = - c^2 \langle v, w\rangle \langle u, u\rangle - c \langle u, w\rangle\\ + D = 1 + 2 c \langle u, v\rangle + c^2 \langle u, u\rangle \langle v, v\rangle\\ + + \operatorname{gyr}[u, v]w = w + 2 \frac{A u + B v}{D} + + Parameters + ---------- + a : tensor + first point on Poincare ball + b : tensor + second point on Poincare ball + u : tensor + vector field for operation + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + the result of automorphism + + References + ---------- + [1] A. A. Ungar (2009), A Gyrovector Space Approach to Hyperbolic Geometry + """ + return _gyration(a, b, u, c, dim=dim) + + +def _gyration(u, v, w, c, dim: int = -1): + # non-simplified + # mupv = -_mobius_add(u, v, c) + # vpw = _mobius_add(u, w, c) + # upvpw = _mobius_add(u, vpw, c) + # return _mobius_add(mupv, upvpw, c) + # simplified + u2 = u.pow(2).sum(dim=dim, keepdim=True) + v2 = v.pow(2).sum(dim=dim, keepdim=True) + uv = (u * v).sum(dim=dim, keepdim=True) + uw = (u * w).sum(dim=dim, keepdim=True) + vw = (v * w).sum(dim=dim, keepdim=True) + c2 = c ** 2 + a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw + b = -c2 * vw * u2 - c * uw + d = 1 + 2 * c * uv + c2 * u2 * v2 + return w + 2 * (a * u + b * v) / (d + 1e-15) + + +def parallel_transport(x, y, v, *, c=1.0, dim=-1): + r""" + Parallel transport is essential for adaptive algorithms in Riemannian manifolds. + For Hyperbolic spaces parallel transport is expressed via gyration. + + .. plot:: plots/extended/poincare/gyrovector_parallel_transport.py + + To recover parallel transport we first need to study isomorphism between gyrovectors and vectors. + The reason is that originally, parallel transport is well defined for gyrovectors as + + .. math:: + + P_{x\to y}(z) = \operatorname{gyr}[y, -x]z, + + where :math:`x,\:y,\:z \in \mathbb{D}_c^n` and + :math:`\operatorname{gyr}[a, b]c = \ominus (a \oplus_c b) \oplus_c (a \oplus_c (b \oplus_c c))` + + But we want to obtain parallel transport for vectors, not for gyrovectors. + The blessing is isomorphism mentioned above. This mapping is given by + + .. math:: + + U^c_p \: : \: T_p\mathbb{D}_c^n \to \mathbb{G} = v \mapsto \lambda^c_p v + + + Finally, having points :math:`x,\:y \in \mathbb{D}_c^n` and a tangent vector :math:`u\in T_x\mathbb{D}_c^n` we obtain + + .. math:: + + P^c_{x\to y}(v) = (U^c_y)^{-1}\left(\operatorname{gyr}[y, -x] U^c_x(v)\right)\\ + = \operatorname{gyr}[y, -x] v \lambda^c_x / \lambda^c_y + + .. plot:: plots/extended/poincare/parallel_transport.py + + + Parameters + ---------- + x : tensor + starting point + y : tensor + end point + v : tensor + tangent vector to be transported + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + transported vector + """ + return _parallel_transport(x, y, v, c, dim=dim) + + +def _parallel_transport(x, y, u, c, dim: int = -1): + return ( + _gyration(y, -x, u, c, dim=dim) + * _lambda_x(x, c, keepdim=True, dim=dim) + / _lambda_x(y, c, keepdim=True, dim=dim) + ) + + +def parallel_transport0(y, v, *, c=1.0, dim=-1): + r""" + Special case parallel transport with starting point at zero that + can be computed more efficiently and numerically stable + + Parameters + ---------- + y : tensor + target point + v : tensor + vector to be transported + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + """ + return _parallel_transport0(y, v, c, dim=dim) + + +def _parallel_transport0(y, v, c, dim: int = -1): + return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True)) + + +def egrad2rgrad(x, grad, *, c=1.0, dim=-1): + r""" + Translate Euclidean gradient to Riemannian gradient on tangent space of :math:`x` + + .. math:: + + \nabla_x = \nabla^E_x / (\lambda_x^c)^2 + + Parameters + ---------- + x : tensor + point on the Poincare ball + grad : tensor + Euclidean gradient for :math:`x` + c : float|tensor + ball negative curvature + dim : int + reduction dimension for operations + + Returns + ------- + tensor + Riemannian gradient :math:`u\in T_x\mathbb{D}_c^n` + """ + return _egrad2rgrad(x, grad, c, dim=dim) + + +def _egrad2rgrad(x, grad, c, dim: int = -1): + return grad / _lambda_x(x, c, keepdim=True, dim=dim) ** 2 diff --git a/geoopt/manifolds/sphere.py b/geoopt/manifolds/sphere.py index 5c3741d1..2795d37e 100644 --- a/geoopt/manifolds/sphere.py +++ b/geoopt/manifolds/sphere.py @@ -38,14 +38,14 @@ def _check_point_on_manifold(self, x, atol=1e-5, rtol=1e-5): return True, None def _check_vector_on_tangent(self, x, u, atol=1e-5, rtol=1e-5): - inner = self._inner(None, x, u) + inner = self._inner(None, x, u, keepdim=True) ok = torch.allclose(inner, inner.new((1,)).fill_(0), atol=atol, rtol=rtol) if not ok: return False, "` != 0` with atol={}, rtol={}".format(atol, rtol) return True, None - def _inner(self, x, u, v): - return (u * v).sum(-1) + def _inner(self, x, u, v, keepdim): + return (u * v).sum(-1, keepdim=keepdim) def _projx(self, x): return x / x.norm(dim=-1, keepdim=True) @@ -88,13 +88,13 @@ def _expmap_transp(self, x, v, *more, u, t): def _logmap(self, x, y): u = self._proju(x, y - x) - dist = self._dist(x, y).unsqueeze(-1) + dist = self._dist(x, y, keepdim=True) # If the two points are "far apart", correct the norm. cond = dist.gt(1e-6) return torch.where(cond, u * dist / u.norm(dim=-1, keepdim=True), u) - def _dist(self, x, y): - inner = self._inner(None, x, y).clamp(-1, 1) + def _dist(self, x, y, keepdim): + inner = self._inner(None, x, y, keepdim=keepdim).clamp(-1, 1) return torch.acos(inner) diff --git a/geoopt/manifolds/stiefel.py b/geoopt/manifolds/stiefel.py index 6bfa5c64..a6100ba6 100644 --- a/geoopt/manifolds/stiefel.py +++ b/geoopt/manifolds/stiefel.py @@ -89,7 +89,7 @@ class CanonicalStiefel(Stiefel): name = "Stiefel(canonical)" reversible = True - def _inner(self, x, u, v): + def _inner(self, x, u, v, keepdim): # _x = tr(u^T(I-1/2xx^T)v) # = tr(u^T(v-1/2xx^Tv)) # = tr(u^Tv-1/2u^Txx^Tv) @@ -102,7 +102,9 @@ def _inner(self, x, u, v): v = u else: xtv = x.transpose(-1, -2) @ v - return (u * v).sum([-1, -2]) - 0.5 * (xtv * xtu).sum([-1, -2]) + return (u * v).sum([-1, -2], keepdim=keepdim) - 0.5 * (xtv * xtu).sum( + [-1, -2], keepdim=keepdim + ) # we do faster on inner without autofill _inner_autofill = False @@ -112,6 +114,7 @@ def _transp_follow_one(self, x, v, *, u, t): rhs = v + t / 2 * a @ v lhs = -t / 2 * a lhs[..., torch.arange(a.shape[-2]), torch.arange(x.shape[-2])] += 1 + # TODO: torch.gesv -> torch.solve after pytorch release qv, _ = torch.gesv(rhs, lhs) return qv @@ -182,8 +185,8 @@ def _retr_transp(self, x, v, *more, u, t): else: return y, vs - def _inner(self, x, u, v): - return (u * v).sum([-1, -2]) + def _inner(self, x, u, v, keepdim): + return (u * v).sum([-1, -2], keepdim=keepdim) def _retr(self, x, u, t): q, r = linalg.batch_linalg.qr(x + u * t) diff --git a/geoopt/optim/radam.py b/geoopt/optim/radam.py index b669d6ad..b6a6d98a 100644 --- a/geoopt/optim/radam.py +++ b/geoopt/optim/radam.py @@ -83,17 +83,10 @@ def step(self, closure=None): # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - inner_prod_shape = p.shape - if manifold.ndim > 0: - inner_prod_shape = inner_prod_shape[: -manifold.ndim] - state["exp_avg_sq"] = torch.zeros( - inner_prod_shape, dtype=p.dtype, device=p.device - ) + state["exp_avg_sq"] = torch.zeros_like(p) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values - state["max_exp_avg_sq"] = torch.zeros( - inner_prod_shape, dtype=p.dtype, device=p.device - ) + state["max_exp_avg_sq"] = torch.zeros_like(p) # this is assumed to be already transported if "traced_step" not in state: @@ -174,7 +167,9 @@ def perform_step( grad.add_(weight_decay, point) grad = manifold.egrad2rgrad(point, grad) exp_avg.mul_(betas[0]).add_(1 - betas[0], grad) - exp_avg_sq.mul_(betas[1]).add_(1 - betas[1], manifold.inner(point, grad)) + exp_avg_sq.mul_(betas[1]).add_( + 1 - betas[1], manifold.inner(point, grad, keepdim=True) + ) if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) @@ -182,7 +177,6 @@ def perform_step( denom = max_exp_avg_sq.sqrt().add_(eps) else: denom = exp_avg_sq.sqrt().add_(eps) - denom = manifold.broadcast_scalar(denom) step.add_(1) bias_correction1 = 1 - betas[0] ** step.type_as(betas) bias_correction2 = 1 - betas[1] ** step.type_as(betas) diff --git a/geoopt/optim/tracing.py b/geoopt/optim/tracing.py index 842520b1..7b014b23 100644 --- a/geoopt/optim/tracing.py +++ b/geoopt/optim/tracing.py @@ -33,5 +33,6 @@ def create_traced_update(step, manifold, point, *buffers, **kwargs): def partial(*args): step(manifold, *args, **kwargs) return args + return partial # return torch.jit.trace(partial, (point, grad, lr) + buffers) diff --git a/requirements-dev.txt b/requirements-dev.txt index 5d33904c..ee440468 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,3 +7,4 @@ pymanopt twine wheel sphinx +seaborn diff --git a/setup.py b/setup.py index fe8f2fc0..f1a55f78 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ DESCRIPTION = """Unofficial implementation for “Riemannian Adaptive Optimization Methods” ICLR2019 and more""" PROJECT_ROOT = os.path.dirname(os.path.realpath(__file__)) -with open(os.path.join(PROJECT_ROOT, 'README.rst'), encoding='utf-8') as buff: +with open(os.path.join(PROJECT_ROOT, "README.rst"), encoding="utf-8") as buff: LONG_DESCRIPTION = buff.read() @@ -44,5 +44,5 @@ def get_version(*path): url="https://github.com/geoopt/geoopt", python_requires=">=3.6.0", license="Apache License, Version 2.0", - classifiers=classifiers + classifiers=classifiers, ) diff --git a/tests/test_adam.py b/tests/test_adam.py index 318b2879..96a3159e 100644 --- a/tests/test_adam.py +++ b/tests/test_adam.py @@ -30,3 +30,24 @@ def closure(): np.testing.assert_allclose(X.data, Xstar, atol=1e-5, rtol=1e-5) optim.load_state_dict(optim.state_dict()) optim.step(closure) + + +def test_adam_poincare(): + torch.manual_seed(44) + ideal = torch.tensor([0.5, 0.5]) + start = torch.randn(2) / 2 + start = geoopt.manifolds.poincare.math.expmap0(start, c=1.0) + start = geoopt.ManifoldParameter(start, manifold=geoopt.PoincareBall()) + + def closure(): + optim.zero_grad() + loss = geoopt.manifolds.poincare.math.dist(start, ideal) ** 2 + loss.backward() + return loss.item() + + optim = geoopt.optim.RiemannianAdam([start], lr=1e-2) + + for _ in range(2000): + optim.step(closure) + + np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5) diff --git a/tests/test_manifold.py b/tests/test_manifold.py index a10c8d06..4d3b6ec8 100644 --- a/tests/test_manifold.py +++ b/tests/test_manifold.py @@ -26,6 +26,7 @@ def t(request): # match implementation of pymanopt for stiefel functools.partial(geoopt.manifolds.Stiefel, canonical=False), functools.partial(geoopt.manifolds.Stiefel, canonical=True), + geoopt.manifolds.PoincareBall, geoopt.manifolds.Euclidean, geoopt.manifolds.Sphere, functools.partial( @@ -41,7 +42,7 @@ def t(request): def manifold(request, retraction_order): man = request.param() try: - return man.set_default_order(retraction_order) + return man.set_default_order(retraction_order).double() except ValueError: pytest.skip("not supported retraction order for {}".format(man)) @@ -63,6 +64,7 @@ def manifold(request, retraction_order): # shapes to verify unary element implementation shapes = { + geoopt.manifolds.PoincareBall: (3,), geoopt.manifolds.EuclideanStiefel: (10, 5), geoopt.manifolds.CanonicalStiefel: (10, 5), geoopt.manifolds.Euclidean: (1,), @@ -79,11 +81,17 @@ def manifold(request, retraction_order): @pytest.fixture() def unary_case(manifold): shape = shapes[type(manifold)] - manopt_manifold = mannopt[type(manifold)](*shape) np.random.seed(42) - rand = manopt_manifold.rand().astype("float64") - x = geoopt.ManifoldTensor(torch.from_numpy(rand), manifold=manifold) torch.manual_seed(43) + if type(manifold) in mannopt: + manopt_manifold = mannopt[type(manifold)](*shape) + rand = manopt_manifold.rand().astype("float64") + x = geoopt.ManifoldTensor(torch.from_numpy(rand), manifold=manifold) + else: + manopt_manifold = None + x = geoopt.ManifoldTensor( + torch.randn(shape, dtype=torch.float64) * 0.1, manifold=manifold + ) ex = geoopt.ManifoldTensor(torch.randn_like(x), manifold=manifold) v = x.proju(torch.randn_like(x)) ev = torch.randn_like(x) @@ -105,6 +113,8 @@ def test_projection_via_assert(unary_case): def test_vector_projection(unary_case): if isinstance(unary_case.manifold, geoopt.manifolds.CanonicalStiefel): pytest.skip("pymanopt uses euclidean Stiefel") + elif unary_case.manopt_manifold is None: + pytest.skip("pymanopt does not have {}".format(unary_case.manifold)) x = unary_case.x ev = unary_case.ev @@ -124,6 +134,8 @@ def test_vector_projection_via_assert(unary_case): def test_retraction(unary_case, retraction_order, t): + if unary_case.manopt_manifold is None: + pytest.skip("pymanopt does not have {}".format(unary_case.manifold)) if isinstance(unary_case.manifold, geoopt.manifolds.CanonicalStiefel): pytest.skip("pymanopt uses euclidean Stiefel") x = unary_case.x @@ -139,6 +151,8 @@ def test_retraction(unary_case, retraction_order, t): def test_transport(unary_case, t): + if unary_case.manopt_manifold is None: + pytest.skip("pymanopt does not have {}".format(unary_case.manifold)) if isinstance(unary_case.manifold, geoopt.manifolds.CanonicalStiefel): pytest.skip("pymanopt uses euclidean Stiefel") x = unary_case.x @@ -261,24 +275,28 @@ def test_broadcast_retr_transp_many(unary_case, t): def test_reversibility(unary_case, t): - torch.manual_seed(43) - X = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype) - U = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype) - X = unary_case.manifold.projx(X) - U = unary_case.manifold.proju(X, U) - Z, Q = unary_case.manifold.retr_transp(X, U, u=U, t=t) - X1, U1 = unary_case.manifold.retr_transp(Z, Q, u=Q, t=-t) if unary_case.manifold.reversible: + torch.manual_seed(43) + X = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype) + U = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype) + X = unary_case.manifold.projx(X) + U = unary_case.manifold.proju(X, U) + Z, Q = unary_case.manifold.retr_transp(X, U, u=U, t=t) + X1, U1 = unary_case.manifold.retr_transp(Z, Q, u=Q, t=-t) + np.testing.assert_allclose(X1, X, atol=1e-5) np.testing.assert_allclose(U1, U, atol=1e-5) else: - assert not np.allclose(X1, X, atol=1e-5) - assert not np.allclose(U1, U, atol=1e-5) + pytest.skip("The manifold {} is not supposed to be checked") def test_dist(unary_case): if type(unary_case.manifold)._dist is geoopt.manifolds.base.not_implemented: - pytest.skip("logmap is not implemented for {}".format(unary_case.manifold)) + pytest.skip("dist is not implemented for {}".format(unary_case.manifold)) + if unary_case.manopt_manifold is None: + pytest.skip( + "dist is not implemented for pymanopt {}".format(unary_case.manifold) + ) torch.manual_seed(43) x = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype) y = torch.randn(*unary_case.shape, dtype=unary_case.x.dtype) @@ -295,12 +313,16 @@ def test_logmap(unary_case, t): x = unary_case.x v = unary_case.v - y = unary_case.manopt_manifold.exp(x.numpy(), v.numpy() * t) - vman = unary_case.manopt_manifold.log(x.numpy(), y) - vhat = unary_case.manifold.logmap(x, torch.as_tensor(y)) - np.testing.assert_allclose(vhat, vman) + if unary_case.manopt_manifold is not None: + y = unary_case.manopt_manifold.exp(x.numpy(), v.numpy() * t) + vman = unary_case.manopt_manifold.log(x.numpy(), y) + vhat = unary_case.manifold.logmap(x, torch.as_tensor(y)) + np.testing.assert_allclose(vhat, vman, atol=1e-7) + else: + y = unary_case.manifold.expmap(x, v) + vhat = unary_case.manifold.logmap(x, torch.as_tensor(y)) ey = unary_case.manifold.expmap(x, vhat) - np.testing.assert_allclose(y, ey) + np.testing.assert_allclose(y, ey, atol=1e-7) def test_logmap_many(unary_case, t): @@ -317,4 +339,4 @@ def test_logmap_many(unary_case, t): Uh = unary_case.manifold.logmap(X, Y) Yh = unary_case.manifold.expmap(X, Uh) - np.testing.assert_allclose(Yh, Y) + np.testing.assert_allclose(Yh, Y, atol=1e-7) diff --git a/tests/test_poincare_math.py b/tests/test_poincare_math.py new file mode 100644 index 00000000..31780120 --- /dev/null +++ b/tests/test_poincare_math.py @@ -0,0 +1,371 @@ +""" +Tests ideas are taken mostly from https://github.com/dalab/hyperbolic_nn/blob/master/util.py with some changes +""" +import torch +import random +import numpy as np +import pytest +from geoopt.manifolds import poincare + + +@pytest.fixture("function", autouse=True, params=range(30, 40)) +def seed(request): + seed = request.param + torch.manual_seed(seed) + random.seed(seed) + return seed + + +@pytest.fixture("function", params=[torch.float64, torch.float32]) +def dtype(request): + return request.param + + +@pytest.fixture +def c(seed, dtype): + # test broadcasted and non broadcasted versions + if seed == 30: + c = torch.tensor(0.0).to(dtype) + elif seed == 35: + c = torch.zeros(100, 1, dtype=dtype) + elif seed > 35: + c = torch.rand(100, 1, dtype=dtype) + else: + c = torch.tensor(random.random()).to(dtype) + return c + 1e-10 + + +@pytest.fixture +def a(seed, c): + if seed in {30, 35}: + a = torch.randn(100, 10, dtype=c.dtype) + elif seed > 35: + # do not check numerically unstable regions + # I've manually observed small differences there + a = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1) + a /= a.norm(dim=-1, keepdim=True) * 1.3 + a *= (torch.rand_like(c) * c) ** 0.5 + else: + a = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1) + a /= a.norm(dim=-1, keepdim=True) * 1.3 + a *= random.uniform(0, c) ** 0.5 + return poincare.math.project(a, c=c) + + +@pytest.fixture +def b(seed, c): + if seed in {30, 35}: + b = torch.randn(100, 10, dtype=c.dtype) + elif seed > 35: + b = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1) + b /= b.norm(dim=-1, keepdim=True) * 1.3 + b *= (torch.rand_like(c) * c) ** 0.5 + else: + b = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1) + b /= b.norm(dim=-1, keepdim=True) * 1.3 + b *= random.uniform(0, c) ** 0.5 + return poincare.math.project(b, c=c) + + +def test_mobius_addition_left_cancelation(a, b, c): + res = poincare.math.mobius_add(-a, poincare.math.mobius_add(a, b, c=c), c=c) + tolerance = {torch.float32: dict(atol=1e-6, rtol=1e-6), torch.float64: dict()} + np.testing.assert_allclose(res, b, **tolerance[c.dtype]) + + +def test_mobius_addition_zero_a(b, c): + a = torch.zeros(100, 10, dtype=c.dtype) + res = poincare.math.mobius_add(a, b, c=c) + np.testing.assert_allclose(res, b) + + +def test_mobius_addition_zero_b(a, c): + b = torch.zeros(100, 10, dtype=c.dtype) + res = poincare.math.mobius_add(a, b, c=c) + np.testing.assert_allclose(res, a) + + +def test_mobius_addition_negative_cancellation(a, c): + res = poincare.math.mobius_add(a, -a, c=c) + tolerance = { + torch.float32: dict(atol=1e-7, rtol=1e-6), + torch.float64: dict(atol=1e-10), + } + np.testing.assert_allclose(res, torch.zeros_like(res), **tolerance[c.dtype]) + + +def test_mobius_negative_addition(a, b, c): + res = poincare.math.mobius_add(-b, -a, c=c) + res1 = -poincare.math.mobius_add(b, a, c=c) + tolerance = { + torch.float32: dict(atol=1e-7, rtol=1e-6), + torch.float64: dict(atol=1e-10), + } + np.testing.assert_allclose(res, res1, **tolerance[c.dtype]) + + +@pytest.mark.parametrize("n", list(range(5))) +def test_n_additions_via_scalar_multiplication(n, a, c): + y = torch.zeros_like(a) + for _ in range(n): + y = poincare.math.mobius_add(a, y, c=c) + ny = poincare.math.mobius_scalar_mul(n, a, c=c) + tolerance = { + torch.float32: dict(atol=1e-7, rtol=1e-6), + torch.float64: dict(atol=1e-10), + } + np.testing.assert_allclose(y, ny, **tolerance[c.dtype]) + + +@pytest.fixture +def r1(seed, dtype): + if seed % 3 == 0: + return random.uniform(-1, 1) + else: + return torch.rand(100, 1, dtype=dtype) * 2 - 1 + + +@pytest.fixture +def r2(seed, dtype): + if seed % 3 == 1: + return random.uniform(-1, 1) + else: + return torch.rand(100, 1, dtype=dtype) * 2 - 1 + + +def test_scalar_multiplication_distributive(a, c, r1, r2): + res = poincare.math.mobius_scalar_mul(r1 + r2, a, c=c) + res1 = poincare.math.mobius_add( + poincare.math.mobius_scalar_mul(r1, a, c=c), + poincare.math.mobius_scalar_mul(r2, a, c=c), + c=c, + ) + res2 = poincare.math.mobius_add( + poincare.math.mobius_scalar_mul(r1, a, c=c), + poincare.math.mobius_scalar_mul(r2, a, c=c), + c=c, + ) + tolerance = { + torch.float32: dict(atol=1e-6, rtol=1e-7), + torch.float64: dict(atol=1e-7, rtol=1e-10), + } + np.testing.assert_allclose(res1, res, **tolerance[c.dtype]) + np.testing.assert_allclose(res2, res, **tolerance[c.dtype]) + + +def test_scalar_multiplication_associative(a, c, r1, r2): + res = poincare.math.mobius_scalar_mul(r1 * r2, a, c=c) + res1 = poincare.math.mobius_scalar_mul( + r1, poincare.math.mobius_scalar_mul(r2, a, c=c), c=c + ) + res2 = poincare.math.mobius_scalar_mul( + r2, poincare.math.mobius_scalar_mul(r1, a, c=c), c=c + ) + tolerance = { + torch.float32: dict(atol=1e-7, rtol=1e-6), # worked with rtol=1e-7 locally + torch.float64: dict(atol=1e-7, rtol=1e-10), + } + np.testing.assert_allclose(res1, res, **tolerance[c.dtype]) + np.testing.assert_allclose(res2, res, **tolerance[c.dtype]) + + +def test_scaling_property(a, c, r1): + x1 = a / a.norm(dim=-1, keepdim=True) + ra = poincare.math.mobius_scalar_mul(r1, a, c=c) + x2 = poincare.math.mobius_scalar_mul(abs(r1), a, c=c) / ra.norm( + dim=-1, keepdim=True + ) + tolerance = { + torch.float32: dict(rtol=1e-5, atol=1e-6), + torch.float64: dict(atol=1e-10), + } + np.testing.assert_allclose(x1, x2, **tolerance[c.dtype]) + + +def test_geodesic_borders(a, b, c): + geo0 = poincare.math.geodesic(0.0, a, b, c=c) + geo1 = poincare.math.geodesic(1.0, a, b, c=c) + tolerance = { + torch.float32: dict(rtol=1e-5, atol=1e-6), + torch.float64: dict(atol=1e-10), + } + np.testing.assert_allclose(geo0, a, **tolerance[c.dtype]) + np.testing.assert_allclose(geo1, b, **tolerance[c.dtype]) + + +def test_geodesic_segment_length_property(a, b, c): + extra_dims = len(a.shape) + segments = 12 + t = torch.linspace(0, 1, segments + 1, dtype=c.dtype).view( + (segments + 1,) + (1,) * extra_dims + ) + gamma_ab_t = poincare.math.geodesic(t, a, b, c=c) + gamma_ab_t0 = gamma_ab_t[:-1] + gamma_ab_t1 = gamma_ab_t[1:] + dist_ab_t0mt1 = poincare.math.dist(gamma_ab_t0, gamma_ab_t1, c=c, keepdim=True) + speed = ( + poincare.math.dist(a, b, c=c, keepdim=True) + .unsqueeze(0) + .expand_as(dist_ab_t0mt1) + ) + # we have exactly 12 line segments + tolerance = {torch.float32: dict(rtol=1e-5), torch.float64: dict(atol=1e-10)} + np.testing.assert_allclose(dist_ab_t0mt1, speed / segments, **tolerance[c.dtype]) + + +def test_geodesic_segement_unit_property(a, b, c): + extra_dims = len(a.shape) + segments = 12 + t = torch.linspace(0, 1, segments + 1, dtype=c.dtype).view( + (segments + 1,) + (1,) * extra_dims + ) + gamma_ab_t = poincare.math.geodesic_unit(t, a, b, c=c) + gamma_ab_t0 = gamma_ab_t[:1] + gamma_ab_t1 = gamma_ab_t + dist_ab_t0mt1 = poincare.math.dist(gamma_ab_t0, gamma_ab_t1, c=c, keepdim=True) + true_distance_travelled = t.expand_as(dist_ab_t0mt1) + # we have exactly 12 line segments + tolerance = { + torch.float32: dict(atol=1e-6, rtol=1e-5), + torch.float64: dict(atol=1e-10), + } + np.testing.assert_allclose( + dist_ab_t0mt1, true_distance_travelled, **tolerance[c.dtype] + ) + + +def test_expmap_logmap(a, b, c): + # this test appears to be numerical unstable once a and b may appear on the opposite sides + bh = poincare.math.expmap(x=a, u=poincare.math.logmap(a, b, c=c), c=c) + tolerance = {torch.float32: dict(rtol=1e-5, atol=1e-6), torch.float64: dict()} + np.testing.assert_allclose(bh, b, **tolerance[c.dtype]) + + +def test_expmap0_logmap0(a, c): + # this test appears to be numerical unstable once a and b may appear on the opposite sides + v = poincare.math.logmap0(a, c=c) + norm = poincare.math.norm(torch.zeros_like(v), v, c=c, keepdim=True) + dist = poincare.math.dist0(a, c=c, keepdim=True) + bh = poincare.math.expmap0(v, c=c) + tolerance = {torch.float32: dict(rtol=1e-6), torch.float64: dict()} + np.testing.assert_allclose(bh, a, **tolerance[c.dtype]) + np.testing.assert_allclose(norm, dist, **tolerance[c.dtype]) + + +def test_matvec_zeros(a, c): + mat = a.new_zeros(3, a.shape[-1]) + z = poincare.math.mobius_matvec(mat, a, c=c) + np.testing.assert_allclose(z, 0.0) + + +def test_matvec_via_equiv_fn_apply(a, c): + mat = a.new(3, a.shape[-1]).normal_() + y = poincare.math.mobius_fn_apply(lambda x: x @ mat.transpose(-1, -2), a, c=c) + y1 = poincare.math.mobius_matvec(mat, a, c=c) + tolerance = {torch.float32: dict(atol=1e-5), torch.float64: dict()} + np.testing.assert_allclose(y, y1, **tolerance[c.dtype]) + + +def test_mobiusify(a, c): + mat = a.new(3, a.shape[-1]).normal_() + + @poincare.math.mobiusify + def matvec(x): + return x @ mat.transpose(-1, -2) + + y = matvec(a, c=c) + y1 = poincare.math.mobius_matvec(mat, a, c=c) + tolerance = {torch.float32: dict(atol=1e-5), torch.float64: dict()} + np.testing.assert_allclose(y, y1, **tolerance[c.dtype]) + + +def test_matvec_chain_via_equiv_fn_apply(a, c): + mat1 = a.new(a.shape[-1], a.shape[-1]).normal_() + mat2 = a.new(a.shape[-1], a.shape[-1]).normal_() + y = poincare.math.mobius_fn_apply_chain( + a, + lambda x: x @ mat1.transpose(-1, -2), + lambda x: x @ mat2.transpose(-1, -2), + c=c, + ) + y1 = poincare.math.mobius_matvec(mat1, a, c=c) + y1 = poincare.math.mobius_matvec(mat2, y1, c=c) + np.testing.assert_allclose(y, y1, atol=1e-5) + + +def test_parallel_transport0_preserves_inner_products(a, c): + # pointing to the center + v_0 = torch.rand_like(a) + 1e-5 + u_0 = torch.rand_like(a) + 1e-5 + zero = torch.zeros_like(a) + v_a = poincare.math.parallel_transport0(a, v_0, c=c) + u_a = poincare.math.parallel_transport0(a, u_0, c=c) + # compute norms + vu_0 = poincare.math.inner(zero, v_0, u_0, c=c, keepdim=True) + vu_a = poincare.math.inner(a, v_a, u_a, c=c, keepdim=True) + np.testing.assert_allclose(vu_a, vu_0, atol=1e-6, rtol=1e-6) + + +def test_parallel_transport0_is_same_as_usual(a, c): + # pointing to the center + v_0 = torch.rand_like(a) + 1e-5 + zero = torch.zeros_like(a) + v_a = poincare.math.parallel_transport0(a, v_0, c=c) + v_a1 = poincare.math.parallel_transport(zero, a, v_0, c=c) + # compute norms + np.testing.assert_allclose(v_a, v_a1, atol=1e-6, rtol=1e-6) + + +def test_parallel_transport_a_b(a, b, c): + # pointing to the center + v_0 = torch.rand_like(a) + u_0 = torch.rand_like(a) + v_1 = poincare.math.parallel_transport(a, b, v_0, c=c) + u_1 = poincare.math.parallel_transport(a, b, u_0, c=c) + # compute norms + vu_1 = poincare.math.inner(b, v_1, u_1, c=c, keepdim=True) + vu_0 = poincare.math.inner(a, v_0, u_0, c=c, keepdim=True) + np.testing.assert_allclose(vu_0, vu_1, atol=1e-6, rtol=1e-6) + + +def test_add_infinity_and_beyond(a, b, c): + infty = b * 10000000 + infty = poincare.math.clip_tangent(a, infty, c=c) + for i in range(100): + z = poincare.math.expmap(a, infty, c=c) + z = poincare.math.project(z, c=c) + z = poincare.math.mobius_scalar_mul(1000.0, z, c=c) + z = poincare.math.project(z, c=c) + infty = poincare.math.parallel_transport(a, z, infty, c=c) + assert np.isfinite(z).all(), (i, z) + assert np.isfinite(infty).all(), (i, infty) + a = z + z = poincare.math.expmap(a, -infty, c=c) + # they just need to be very far, exact answer is not supposed + tolerance = { + torch.float32: dict(rtol=3e-1, atol=2e-1), + torch.float64: dict(rtol=1e-1, atol=1e-3), + } + np.testing.assert_allclose(z, -a, **tolerance[c.dtype]) + + +def test_mobius_coadd(a, b, c): + # (a \boxplus_c b) \ominus_c b = a + ah = poincare.math.mobius_sub(poincare.math.mobius_coadd(a, b, c=c), b, c=c) + np.testing.assert_allclose(ah, a, atol=1e-5) + + +def test_mobius_cosub(a, b, c): + # (a \oplus_c b) \boxminus b = a + ah = poincare.math.mobius_cosub(poincare.math.mobius_add(a, b, c=c), b, c=c) + np.testing.assert_allclose(ah, a, atol=1e-5) + + +def test_distance2plane(a, c): + v = torch.rand_like(a) + vr = v / poincare.math.norm(a, v, c=c, keepdim=True) + z = poincare.math.expmap(a, vr, c=c) + dist1 = poincare.math.dist(a, z, c=c) + dist = poincare.math.dist2plane(z, a, vr, c=c) + + np.testing.assert_allclose(dist, dist1, atol=1e-5, rtol=1e-5) diff --git a/tests/test_rhmc.py b/tests/test_rhmc.py index 3fb8bdd1..2ce01c8f 100644 --- a/tests/test_rhmc.py +++ b/tests/test_rhmc.py @@ -6,6 +6,15 @@ import geoopt.samplers.rhmc +@pytest.fixture(autouse=True) +def withdtype(): + torch.set_default_dtype(torch.float64) + try: + yield + finally: + torch.set_default_dtype(torch.float32) + + @pytest.mark.parametrize( "params", [ @@ -20,8 +29,6 @@ ], ) def test_leapfrog_reversibility(params): - torch.set_default_dtype(torch.float64) - class NormalDist(torch.nn.Module): def __init__(self, mu, sigma): super().__init__() diff --git a/tests/test_utils.py b/tests/test_utils.py index e335ad82..da936667 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -106,3 +106,14 @@ def test_manifold_is_submodule(): container = torch.nn.ModuleDict({"sphere": sub_sphere}) container.to(torch.float64) assert sub_sphere._projector.dtype == torch.float64 + + +def test_manifold_is_submodule_poincare(): + print(torch.get_default_dtype()) + c = torch.tensor(1.0) + ball = geoopt.manifolds.PoincareBall(c) + assert ball.c.dtype == torch.float32 + ball.to(torch.float64) + container = torch.nn.ModuleDict({"ball": ball}) + container.to(torch.float64) + assert ball.c.dtype == torch.float64