Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add base * add mobius add|sub * fix * missing formulas * remove unused import * add scalar mul, test props * unnessesary cons in project * no cover script functions * add distance * fix typo in comment * add geodesics * add expmap * add functions * add singlt apply * black * fix typos in docs * fix typos in docs * add parallel transport * add dist to a plane and parallel transport. Parallel transport is numerically unstable * fix math bugs * add cool plots * fix small things * add egrad2rgrad * add reference * docs * fix typos * finish Poincare ball implementation * fix small typo * add to inifinite and beyond test * add signed distance * infinity and beyond test * black * docfix * fix docs * fix doc * fix docs typos * add import * add dist0 * optim fails * fix numerics, do not repare broken test * black * some refactoring * fix typo * p.data -> p in optim * update docs a bit * split pr * remove torch script (it gave minor improvemets), delay to pytorch/pytorch#14455 resolution * fix coadd impl * coma typo in docs * nan police float32 * nan police! arcsinh * typo * nan police scripted!\nwratpping artanh in a script function results in umstable behavior * tests * fix typo * another test for parallel transport 0 * random doc fix to make typechecker happy * manifold->module migration fix * black * fix test for poincare (autocast double) * add float32 tests * fix typo * rename project->clip tangent * docs * fix side effect in tests * infinity anb beyond test was failing in torch==1.0.1 but not in torch_nightly, acceptable tolerance differs * add dim argument for poincare math * batched matvec * typo in dist formula * fix tracing issues and grad numerics for Arsinh,Artanh * _max_norm, specify device + dtype * clamp before save to backward in artanh * inplace ops in function impl * black * fix typo * fix spelling * some fixes to docs * euclidean -> Euclidean * black * math font for number * random travis fail? * pytorch future reminder
- Loading branch information
Showing
34 changed files
with
2,358 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Extended Guide | ||
============== | ||
|
||
.. toctree:: | ||
:maxdepth: 1 | ||
|
||
extended/poincare |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
Poincare Ball model | ||
=================== | ||
|
||
Poincare ball model is a compact representation of hyperbolic space. | ||
To have a nice introduction into this model we should start from | ||
simple concepts, putting them all together to build a more complete picture. | ||
|
||
Hyperbolic spaces | ||
----------------- | ||
|
||
Hyperbolic space is a constant negative curvature Riemannian manifold. | ||
A very simple example of Riemannian manifold with constant, but positive curvature is sphere. | ||
|
||
An (N+1)-dimensional hyperboloid spans the manifold that can be embedded into N-dimensional space via projections. | ||
|
||
.. figure:: ../plots/extended/poincare/hyperboloid_projection.png | ||
:width: 300 | ||
|
||
img source `Wikipedia, Hyperboloid Model <https://en.wikipedia.org/wiki/Hyperboloid_model/>`_ | ||
|
||
Originally, the distance between points on the hyperboloid is defined as | ||
|
||
.. math:: | ||
d(x, y) = \operatorname{arccosh}(x, y) | ||
It is difficult to work in (N+1)-dimensional space and there is a range of useful embeddings | ||
exist in literature | ||
|
||
Klein Model | ||
~~~~~~~~~~~ | ||
|
||
.. figure:: ../plots/extended/poincare/klein_tiling.png | ||
:width: 300 | ||
|
||
img source `Wikipedia, Klein Model <https://en.wikipedia.org/wiki/Beltrami-Klein_model/>`_ | ||
|
||
|
||
Poincare Model | ||
~~~~~~~~~~~~~~ | ||
|
||
.. figure:: ../plots/extended/poincare/poincare_lines.gif | ||
:width: 300 | ||
|
||
img source `Bulatov, Poincare Model <http://bulatov.org/math/1001/>`_ | ||
|
||
Here we go. | ||
|
||
First of all we note, that Poincare ball is embedded in a Sphere of radius :math:`r=1/\sqrt{c}`, | ||
where c is negative curvature. We also note, as :math:`c` goes to :math:`0`, we recover infinite radius ball. | ||
We should expect this limiting behaviour recovers Euclidean geometry. | ||
|
||
To connect Euclidean space with its embedded manifold we need to get :math:`g_x`. | ||
It is done via `conformal factor` :math:`\lambda^c_x`. | ||
|
||
|
||
.. autofunction:: geoopt.manifolds.poincare.math.lambda_x | ||
|
||
|
||
:math:`\lambda^c_x` connects Euclidean inner product with Riemannian one | ||
|
||
.. autofunction:: geoopt.manifolds.poincare.math.inner | ||
.. autofunction:: geoopt.manifolds.poincare.math.norm | ||
.. autofunction:: geoopt.manifolds.poincare.math.egrad2rgrad | ||
|
||
Math | ||
---- | ||
The good thing about Poincare ball is that it forms a Gyrogroup. Minimal definition of a Gyrogroup | ||
assumes a binary operation :math:`*` defined that satisfies a set of properties. | ||
|
||
Left identity | ||
For every element :math:`a\in G` there exist :math:`e\in G` such that :math:`e * a = a`. | ||
Left Inverse | ||
For every element :math:`a\in G` there exist :math:`b\in G` such that :math:`b * a = e` | ||
Gyroassociativity | ||
For any :math:`a,b,c\in G` there exist :math:`gyr[a, b]c\in G` such that :math:`a * (b * c)=(a * b) * gyr[a, b]c` | ||
Gyroautomorphism | ||
:math:`gyr[a, b]` is a magma automorphism in G | ||
Left loop | ||
:math:`gyr[a, b] = gyr[a * b, b]` | ||
|
||
As mentioned above, hyperbolic space forms a Gyrogroup equipped with | ||
|
||
.. autofunction:: geoopt.manifolds.poincare.math.mobius_add | ||
.. autofunction:: geoopt.manifolds.poincare.math.gyration | ||
|
||
Using this math, it is possible to define another useful operations | ||
|
||
.. autofunction:: geoopt.manifolds.poincare.math.mobius_sub | ||
.. autofunction:: geoopt.manifolds.poincare.math.mobius_scalar_mul | ||
.. autofunction:: geoopt.manifolds.poincare.math.mobius_pointwise_mul | ||
.. autofunction:: geoopt.manifolds.poincare.math.mobius_matvec | ||
.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply | ||
.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply_chain | ||
|
||
Manifold | ||
-------- | ||
Now we are ready to proceed with studying distances, geodesics, exponential maps and more | ||
|
||
.. autofunction:: geoopt.manifolds.poincare.math.dist | ||
.. autofunction:: geoopt.manifolds.poincare.math.dist2plane | ||
.. autofunction:: geoopt.manifolds.poincare.math.parallel_transport | ||
.. autofunction:: geoopt.manifolds.poincare.math.geodesic | ||
.. autofunction:: geoopt.manifolds.poincare.math.geodesic_unit | ||
.. autofunction:: geoopt.manifolds.poincare.math.expmap | ||
.. autofunction:: geoopt.manifolds.poincare.math.expmap0 | ||
.. autofunction:: geoopt.manifolds.poincare.math.logmap | ||
.. autofunction:: geoopt.manifolds.poincare.math.logmap0 | ||
|
||
|
||
Stability | ||
--------- | ||
Numerical stability is a pain in this model. It is strongly recommended to work in ``float64``, | ||
so expect adventures in ``float32`` (but this is not certain). | ||
|
||
.. autofunction:: geoopt.manifolds.poincare.math.project | ||
.. autofunction:: geoopt.manifolds.poincare.math.clip_tangent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ API | |
optimizers | ||
tensors | ||
samplers | ||
extended | ||
devguide | ||
|
||
Indices and tables | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import geoopt.manifolds.poincare.math as pmath | ||
import torch | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
sns.set_style("white") | ||
radius = 1 | ||
coords = np.linspace(-radius, radius, 100) | ||
x = torch.tensor([-0.75, 0]) | ||
xx, yy = np.meshgrid(coords, coords) | ||
dist2 = xx ** 2 + yy ** 2 | ||
mask = dist2 <= radius ** 2 | ||
grid = np.stack([xx, yy], axis=-1) | ||
dists = pmath.dist(torch.from_numpy(grid).float(), x) | ||
dists[(~mask).nonzero()] = np.nan | ||
circle = plt.Circle((0, 0), 1, fill=False, color="b") | ||
plt.gca().add_artist(circle) | ||
plt.xlim(-1.1, 1.1) | ||
plt.ylim(-1.1, 1.1) | ||
plt.gca().set_aspect("equal") | ||
plt.contourf( | ||
grid[..., 0], grid[..., 1], dists.log().numpy(), levels=100, cmap="inferno" | ||
) | ||
plt.colorbar() | ||
plt.title("log distance to ($-$0.75, 0)") | ||
plt.show() |
Oops, something went wrong.