Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poincare ball model #45

Merged
merged 82 commits into from Mar 31, 2019
Merged
Show file tree
Hide file tree
Changes from 74 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
237272a
add base
ferrine Feb 13, 2019
0025386
add mobius add|sub
ferrine Feb 14, 2019
3316cca
fix
ferrine Feb 14, 2019
34b19ac
missing formulas
ferrine Feb 14, 2019
f0438c1
remove unused import
ferrine Feb 14, 2019
32a4f9e
add scalar mul, test props
ferrine Feb 14, 2019
b0a3ca0
unnessesary cons in project
ferrine Feb 14, 2019
d268499
no cover script functions
ferrine Feb 14, 2019
b873819
add distance
ferrine Feb 15, 2019
c5380c1
fix typo in comment
ferrine Feb 15, 2019
27816d9
add geodesics
ferrine Feb 15, 2019
969b9d8
add expmap
ferrine Feb 16, 2019
a36eaf0
add functions
ferrine Feb 16, 2019
ffd7823
add singlt apply
ferrine Feb 16, 2019
0f79e41
black
ferrine Feb 16, 2019
7df0290
fix typos in docs
ferrine Feb 16, 2019
87d4612
fix typos in docs
ferrine Feb 16, 2019
87287d3
add parallel transport
ferrine Feb 17, 2019
80d3e3a
add dist to a plane and parallel transport. Parallel transport is num…
ferrine Feb 17, 2019
f9ea268
fix math bugs
ferrine Feb 17, 2019
5a6a43a
add cool plots
ferrine Feb 17, 2019
e23c1c6
fix small things
ferrine Feb 17, 2019
96d7a4f
add egrad2rgrad
ferrine Feb 17, 2019
c082a7a
add reference
ferrine Feb 17, 2019
c341c7c
docs
ferrine Feb 18, 2019
9d3732b
fix typos
ferrine Feb 18, 2019
7acdfdd
finish Poincare ball implementation
ferrine Feb 18, 2019
d460c1d
fix small typo
ferrine Feb 18, 2019
3d80a8e
add to inifinite and beyond test
ferrine Feb 19, 2019
663217a
add signed distance
ferrine Feb 19, 2019
ff2be7f
infinity and beyond test
ferrine Feb 19, 2019
4e34369
black
ferrine Feb 19, 2019
918a2cf
docfix
ferrine Feb 19, 2019
e3d91e6
fix docs
ferrine Feb 19, 2019
4f31855
fix doc
ferrine Feb 19, 2019
dfca854
fix docs typos
ferrine Feb 19, 2019
264ca20
add import
ferrine Feb 19, 2019
e0bcc6c
add dist0
ferrine Feb 20, 2019
3c82f08
optim fails
ferrine Feb 20, 2019
b42a114
fix numerics, do not repare broken test
ferrine Feb 21, 2019
9256996
black
ferrine Feb 21, 2019
bc7efdb
some refactoring
ferrine Feb 21, 2019
4288db8
fix typo
ferrine Feb 21, 2019
7500ca0
p.data -> p in optim
ferrine Feb 26, 2019
22f8de9
update docs a bit
ferrine Mar 1, 2019
2967d30
split pr
ferrine Mar 2, 2019
f3303b8
remove torch script (it gave minor improvemets), delay to https://git…
ferrine Mar 2, 2019
3d50d23
fix coadd impl
ferrine Mar 2, 2019
0b191ef
coma typo in docs
ferrine Mar 2, 2019
6d8396c
nan police float32
ferrine Mar 2, 2019
8f0dae8
nan police! arcsinh
ferrine Mar 2, 2019
fd6ca2f
typo
ferrine Mar 2, 2019
69c4d21
nan police scripted!\nwratpping artanh in a script function results i…
ferrine Mar 2, 2019
5f61a7f
tests
ferrine Mar 2, 2019
676981a
fix typo
ferrine Mar 2, 2019
75a5ccb
another test for parallel transport 0
ferrine Mar 2, 2019
8dd210c
random doc fix to make typechecker happy
ferrine Mar 2, 2019
c0ed68c
manifold->module migration fix
ferrine Mar 2, 2019
698cf75
black
ferrine Mar 2, 2019
8f5606e
fix test for poincare (autocast double)
ferrine Mar 3, 2019
3f3855a
add float32 tests
ferrine Mar 3, 2019
74cc1da
fix typo
ferrine Mar 3, 2019
42c0dc8
rename project->clip tangent
ferrine Mar 3, 2019
5b5fde0
docs
ferrine Mar 3, 2019
cbe1443
fix side effect in tests
ferrine Mar 3, 2019
189bfee
infinity anb beyond test was failing in torch==1.0.1 but not in torch…
ferrine Mar 3, 2019
e70bdd5
add dim argument for poincare math
ferrine Mar 4, 2019
2f89406
batched matvec
ferrine Mar 5, 2019
1ff3fe8
typo in dist formula
ferrine Mar 7, 2019
56ba08c
fix tracing issues and grad numerics for Arsinh,Artanh
ferrine Mar 7, 2019
e0a5916
_max_norm, specify device + dtype
ferrine Mar 7, 2019
3e6fbfa
clamp before save to backward in artanh
ferrine Mar 7, 2019
9f0df90
inplace ops in function impl
ferrine Mar 7, 2019
09a16cc
black
ferrine Mar 7, 2019
6a45012
fix typo
ferrine Mar 24, 2019
de4c9b6
fix spelling
ferrine Mar 24, 2019
ffd1108
some fixes to docs
ferrine Mar 24, 2019
dc9aa0e
euclidean -> Euclidean
ferrine Mar 31, 2019
92a8da3
black
ferrine Mar 31, 2019
44238f1
math font for number
ferrine Mar 31, 2019
2d47aa1
random travis fail?
ferrine Mar 31, 2019
164a1d4
pytorch future reminder
ferrine Mar 31, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
88 changes: 44 additions & 44 deletions docs/conf.py
Expand Up @@ -36,34 +36,35 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
"matplotlib.sphinxext.plot_directive",
"sphinx.ext.autodoc",
"sphinx.ext.coverage",
"sphinx.ext.mathjax",
"sphinx.ext.viewcode",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = ".rst"

# The encoding of source files.
#
# source_encoding = 'utf-8-sig'

# The master toctree document.
master_doc = 'index'
master_doc = "index"

# General information about the project.
project = u'geoopt'
copyright = u'2018, Max Kochurov'
author = u'Max Kochurov'
project = u"geoopt"
copyright = u"2018, Max Kochurov"
author = u"Max Kochurov"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
Expand Down Expand Up @@ -93,7 +94,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# The reST default role (used for this markup: `text`) to use for all
# documents.
Expand All @@ -115,7 +116,7 @@
# show_authors = False

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"

# A list of ignored prefixes for module index sorting.
# modindex_common_prefix = []
Expand All @@ -132,7 +133,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'
html_theme = "alabaster"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
Expand Down Expand Up @@ -166,7 +167,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]

# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
Expand Down Expand Up @@ -246,34 +247,30 @@
# html_search_scorer = 'scorer.js'

# Output file base name for HTML help builder.
htmlhelp_basename = 'geooptdoc'
htmlhelp_basename = "geooptdoc"

# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',

# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',

# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',

# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'geoopt.tex', u'geoopt Documentation',
u'Max Kochurov', 'manual'),
(master_doc, "geoopt.tex", u"geoopt Documentation", u"Max Kochurov", "manual")
]

# The name of an image file (relative to this directory) to place at the top of
Expand Down Expand Up @@ -313,10 +310,7 @@

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'geoopt', u'geoopt Documentation',
[author], 1)
]
man_pages = [(master_doc, "geoopt", u"geoopt Documentation", [author], 1)]

# If true, show URL addresses after external links.
#
Expand All @@ -329,9 +323,15 @@
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'geoopt', u'geoopt Documentation',
author, 'geoopt', 'One line description of project.',
'Miscellaneous'),
(
master_doc,
"geoopt",
u"geoopt Documentation",
author,
"geoopt",
"One line description of project.",
"Miscellaneous",
)
]

# Documents to append as an appendix to all manuals.
Expand All @@ -352,7 +352,7 @@

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'python': ('https://docs.python.org/', None),
'torch': ('https://pytorch.org/docs/master/', None),
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
"python": ("https://docs.python.org/", None),
"torch": ("https://pytorch.org/docs/master/", None),
}
6 changes: 3 additions & 3 deletions docs/devguide.rst
@@ -1,5 +1,5 @@
Extending ``geoopt``
====================
Developer Guide
===============

Base Manifold
-------------
Expand All @@ -9,7 +9,7 @@ The common base class for all manifolds is :class:`geoopt.manifolds.base.Manifol
.. autoclass:: geoopt.manifolds.base.Manifold
:private-members:
:members:

:noindex:

Metaclass
---------
Expand Down
7 changes: 7 additions & 0 deletions docs/extended.rst
@@ -0,0 +1,7 @@
Extended Guide
==============

.. toctree::
:maxdepth: 1

extended/poincare
117 changes: 117 additions & 0 deletions docs/extended/poincare.rst
@@ -0,0 +1,117 @@
Poincare Ball model
===================

Poincare ball model is a compact representation of hyperbolic space.
To have a nice introduction into this model we should start from
simple concepts, putting them all together to build a more complete picture.

Hyperbolic spaces
-----------------

Hyperbolic space is a constant negative curvature Riemannian manifold.
A very simple example of Riemannian manifold with constant, but positive curvature is sphere.

An (N+1)-dimensional hyperboloid spans the manifold that can be embedded into N-dimensional space via projections.

.. figure:: ../plots/extended/poincare/hyperboloid_projection.png
:width: 300

img source `Wikipedia, Hyperboloid Model <https://en.wikipedia.org/wiki/Hyperboloid_model/>`_

Originally, the distance between points on the hyperboloid is defined as

.. math::

d(x, y) = \arccosh(x, y)

It is difficult to work in (N+1)-dimensional space and there is a range of useful embeddings
exist in literature

Klein Model
~~~~~~~~~~~

.. figure:: ../plots/extended/poincare/klein_tiling.png
:width: 300

img source `Wikipedia, Klein Model <https://en.wikipedia.org/wiki/Beltrami-Klein_model/>`_


Poincare Model
~~~~~~~~~~~~~~

.. figure:: ../plots/extended/poincare/poincare_lines.gif
:width: 300

img source `Bulatov, Poincare Model <http://bulatov.org/math/1001/>`_

Here we go.

First of all we note, that Poincare ball is embedded in a Sphere of radius :math:`r=1/\sqrt{c}`,
where c is negative curvature. We also note, as :math:`c` goes to `0`, we recover infinite radius ball.
We should expect this limiting behaviour recovers Euclidean geometry.

To connect euclidean space with its embedded manifold we need to get :math:`g_x`.
It is done via `conformal factor` :math:`\lambda^c_x`.


.. autofunction:: geoopt.manifolds.poincare.math.lambda_x


:math:`\lambda^c_x` connects euclidean inner product with Riemannian one

.. autofunction:: geoopt.manifolds.poincare.math.inner
.. autofunction:: geoopt.manifolds.poincare.math.norm
.. autofunction:: geoopt.manifolds.poincare.math.egrad2rgrad

Math
----
The good thing about Poincare ball is that it forms a Gyrogroup. Minimal definition of a Gyrogroup
assumes a binary operation :math:`*` defined that satisfies a set of properties.

Left identity
For every element :math:`a\in G` there exist :math:`e\in G` such that :math:`e * a = a`.
Left Inverse
For every element :math:`a\in G` there exist :math:`b\in G` such that :math:`b * a = e`
Gyroassociativity
For any :math:`a,b,c\in G` there exist :math:`gyr[a, b]c\in G` such that :math:`a * (b * c)=(a * b) * gyr[a, b]c`
Gyroautomorphism
:math:`gyr[a, b]` is a magma automorphism in G
Left loop
:math:`gyr[a, b] = gyr[a * b, b]`

As mentioned above, hyperbolic space forms a Gyrogroup equipped with

.. autofunction:: geoopt.manifolds.poincare.math.mobius_add
.. autofunction:: geoopt.manifolds.poincare.math.gyration

Using this math, it is possible to define another useful operations

.. autofunction:: geoopt.manifolds.poincare.math.mobius_sub
.. autofunction:: geoopt.manifolds.poincare.math.mobius_scalar_mul
.. autofunction:: geoopt.manifolds.poincare.math.mobius_pointwise_mul
.. autofunction:: geoopt.manifolds.poincare.math.mobius_matvec
.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply
.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply_chain

Manifold
--------
Now we are ready to proceed with studying distances, geodesics, exponential maps and more

.. autofunction:: geoopt.manifolds.poincare.math.dist
.. autofunction:: geoopt.manifolds.poincare.math.dist2plane
.. autofunction:: geoopt.manifolds.poincare.math.parallel_transport
.. autofunction:: geoopt.manifolds.poincare.math.geodesic
.. autofunction:: geoopt.manifolds.poincare.math.geodesic_unit
.. autofunction:: geoopt.manifolds.poincare.math.expmap
.. autofunction:: geoopt.manifolds.poincare.math.expmap0
.. autofunction:: geoopt.manifolds.poincare.math.logmap
.. autofunction:: geoopt.manifolds.poincare.math.logmap0


Stability
---------
Numerical stability is a pain in this model. It is strongly recommended to work in ``float64``,
so expect adventures in ``float32`` (but this is not certain).

.. autofunction:: geoopt.manifolds.poincare.math.project
.. autofunction:: geoopt.manifolds.poincare.math.clip_tangent
1 change: 1 addition & 0 deletions docs/index.rst
Expand Up @@ -17,6 +17,7 @@ API
optimizers
tensors
samplers
extended
devguide

Indices and tables
Expand Down
8 changes: 3 additions & 5 deletions docs/manifolds.rst
Expand Up @@ -4,13 +4,11 @@ Manifolds
.. currentmodule:: geoopt.manifolds


All manifolds share same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.Euclidean` in the end of this file.
All manifolds share same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.Manifold` in the end of this file.

.. automodule:: geoopt.manifolds
:members: Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection
:members: Euclidean, Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection, PoincareBall


.. autoclass:: geoopt.manifolds.Euclidean
.. autoclass:: geoopt.manifolds.base.Manifold
:members:
:inherited-members:

27 changes: 27 additions & 0 deletions docs/plots/extended/poincare/distance.py
@@ -0,0 +1,27 @@
import geoopt.manifolds.poincare.math as pmath
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("white")
radius = 1
coords = np.linspace(-radius, radius, 100)
x = torch.tensor([-0.75, 0])
xx, yy = np.meshgrid(coords, coords)
dist2 = xx ** 2 + yy ** 2
mask = dist2 <= radius ** 2
grid = np.stack([xx, yy], axis=-1)
dists = pmath.dist(torch.from_numpy(grid).float(), x)
dists[(~mask).nonzero()] = np.nan
circle = plt.Circle((0, 0), 1, fill=False, color="b")
plt.gca().add_artist(circle)
plt.xlim(-1.1, 1.1)
plt.ylim(-1.1, 1.1)
plt.gca().set_aspect("equal")
plt.contourf(
grid[..., 0], grid[..., 1], dists.log().numpy(), levels=100, cmap="inferno"
)
plt.colorbar()
plt.title("log distance to ($-$0.75, 0)")
plt.show()