Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A bit better public api #40

Merged
merged 18 commits into from
Feb 12, 2019
Merged

A bit better public api #40

merged 18 commits into from
Feb 12, 2019

Conversation

ferrine
Copy link
Member

@ferrine ferrine commented Feb 8, 2019

API change

The purpose of this PR is to allow users to use the library in general setting. The current API does not provide enough flexibility in terms of such kind of operations:

  • vector transport from x to y
  • retraction of order k in Manifold.retr
  • retraction+vector transport approximating retraction at given order

These simple features require workarounds in development if we choose the way to maintain the very early API. Therefore I propose to change methods signatures a bit (the change in possible old code is mechanical) and introduce new optional arguments.

Breaking changes

Retraction

-    def retr(self, x, u, t):
+    def retr(self, x, u, t=1., order=None):

Vector transport

-    def transp(self, x, u, t, v, *more):
+    def transp(self, x, v, *more, u=None, t=1., y=None, order=None):

Retraction + vector transport

-    def retr_transp(self, x, u, t, v, *more):
+    def retr_transp(self, x, v, *more, u, t=1.0, order=None):

order is an order of retraction approximation, by default uses the simplest that is usually a first order approximation. Possible choices depend on a concrete manifold and -1 stays for exponential map

exponential map is assumed to be a separate method

Migration

Vector transport

-u = manifold.transp(u, 1.0, v)
+u = manifold.transp(v, u=u, t=1.0)

Retraction + vector transport

-p, v = manifold.retr_transp(p, u, 1.0, v)
+p, v = manifold.retr_transp(p, v, u=u, t=1.0)

New methods

Some new methods are introduced in the base Manifold class:

def expmap(self, x, u, t=1.0):

This method returns best possible approximation (of maximum order) for retraction map. Exponential map is assumed to be ideal and has order -1.

Same holds for exponential map + vector transport

def expmap_transp(self, x, v, *more, u, t=1.0):

It will be possible to specify the default order on per manifold basis with the following method

def set_default_order(self, order):

Developer API changes

Base class

From now I propose to use a metaclass that tracks and registers retractions. Right after a class creation, it filters dir(cls) and looks for special declared methods. If a method is not implemented, then it should be geoopt.base.not_implemented. geoopt.base.not_implemented is just a placeholder for a function that raises not implemented error.

Special private functions contain the following:

  • r"^_retr(\d+)?$" for retraction of a given order (if int postfix provided)
  • r"^_retr(\d+)?_transp$" for retraction and transport
  • r"^_expmap$" for exponential map (retraction with order -1)
  • r"^_expmap_transp$" for exponential map + vector transport (retraction with order -1)
  • r"^_transp_follow(\d+)?$" vector transport that uses direction + retraction rather that the final point
  • r"^_transp_follow_expmap$" vector transport that uses direction + exponential map rather that the final point (retraction+transport with order -1)

After this all is registered in MethodDict (default dict with not_implemented as missing value)

retractoins = MethodDict()
retractoins_transport = MethodDict()
transports_follow = MethodDict()

With this dict it comes possible to define generic dispatch methods for different orders of approximations like this:

    def retr(self, x, u, t=1.0, order=None):
        t = self.broadcast_scalar(t)
        return self._retr_funcs[order](self, x, u, t)

As you see, we avoid weird code that makes use of if or getattr with handling exceptions.

Exponential map is dispatched in the same way

    def expmap(self, x, u, t=1.0):
        t = self.broadcast_scalar(t)
        return self._retr_funcs[-1](self, x, u, t)

To make retr_transp work in case of _transp2y is much more efficient than _transp_follow there is a class attribute _retr_transp_default_preference to indicate this. The attribute should be present in the class definition if differs from default provided in Manifold

Deprecations

methods like

def _transp_one(self, x, u, t, v):
def _transp_many(self, x, u, t, *vs):

are deprecated in favour of their unified alternatives

def _transp_follow(self, x, v, *more, u, t):
def _transp2y(self, x, v, *more, y):
# and others

transp2y method appears to be very useful sometimes and allows public interface to have optional y in transp

@ferrine ferrine changed the title propose a bit better public api A bit better public api Feb 8, 2019
@ferrine ferrine merged commit 7d9bbcf into master Feb 12, 2019
@ferrine ferrine deleted the api branch February 12, 2019 22:11
@ferrine
Copy link
Member Author

ferrine commented Feb 12, 2019

This PR helps with #33

andbloch pushed a commit to andbloch/geoopt that referenced this pull request Dec 29, 2019
* propose public api

* add description to docs, ferify tests

* docs+black\nmove notes section below return section

* set adequate defaults for t

* black

* order note

* refactor developer api

* black

* add more docs

* remove unnessesary code from Euclidean

* set order refactor

* fix error message

* create developer section in docs

* create developer section in docs

* improve docs

* a bit more consistent pytorch usage

* black

* blank line in the end of the file
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant