Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poincare Manifold methods #78

Merged
merged 9 commits into from Jun 28, 2019
Merged

Poincare Manifold methods #78

merged 9 commits into from Jun 28, 2019

Conversation

ferrine
Copy link
Member

@ferrine ferrine commented Jun 18, 2019

This adds shortcuts for poincare.math into PoincareBall


def transp(self, x, y, v, *more):
def transp(self, x, y, v, *more, dim=-1):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have concerns whether this method should work with non-matching v, *more. The easy way is to do this in a loop and do not care.

Copy link
Member

@rrkarim rrkarim Jun 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe check and raise the warning? (still need to loop tho)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point in *more anyway and where we will use it? I have completely missed that part.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A flag in parameters may save nerves and add more clarity in what's going on

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like stack=True

Copy link
Member Author

@ferrine ferrine Jun 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point is to be able to transport multiple vectors with one pass, as this might be much more computationally cheaper. E.g. in Stiefel manifolds this allows performing this operation with one LU decomposition

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how flag is supposed to be implemented

    def transp(self, x, y, v, *more, dim=-1, stack=True):
        if not more:
            return math.parallel_transport(x, y, v, c=self.c, dim=dim)
        else:
            if stack:
                vecs = torch.stack((v,) + more, dim=0)
                transp = math.parallel_transport(
                    x, y, vecs, c=self.c, dim=idx2sign(dim, x.dim())
                )
                return transp.unbind(0)
            else:
                return tuple(
                    math.parallel_transport(x, y, vec, c=self.c, dim=dim)
                    for vec in (v, *more)
                )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, tuple approach is not that bad

    def transp(self, x, y, v, *more, dim=-1):
        if not more:
            return math.parallel_transport(x, y, v, c=self.c, dim=dim)
        else:
            return tuple(
                math.parallel_transport(x, y, vec, c=self.c, dim=dim)
                for vec in (v, *more)
            )

@ferrine
Copy link
Member Author

ferrine commented Jun 19, 2019

I think the PR is ready for review

try:
vecs = torch.stack((v,) + more, dim=0)
except RuntimeError:
return tuple(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When are we going to need to pass *more?

  • We mightn't want *more in transp() to behave differently from * in e.g. expmap()
  • There might be occasions when we'd want to handle *more by reshaping and stacking all tensors into bs*man_dims tensor and then splitting back after transporting.
  • The latter would be a rather weird default
  • Unless we return the bs*man_dims array plus indices/access object that knows how to extract arrays of original shapes

Also not sure how I feel about changing return type from array to tupple of arrays -- although this confusion is very much pythonic, it is a confusion

@@ -1273,6 +1273,33 @@ def _parallel_transport0(y, v, c, dim: int = -1):
return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM)


def parallel_transport0back(x, v, *, c=1.0, dim: int = -1):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps s/parallel_// for brevity?
And (later) s/\btransp\b/transport/ in PoincareBall class too so that we have just "transport" everywhere

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parallel transport is a special case of vector transport that is exact one. That's the purpose of having parallel in math. In the manifold we do not provide any different, however but could for sphere

geoopt/utils.py Outdated
if idx < 0:
return idx
else:
return (idx + 1) % -(dim + 1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some tldr and less nesting levels

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have positive indexing of dims, this will make the proper negative index that is used in broadcasting. I will delete that

@ferrine
Copy link
Member Author

ferrine commented Jun 20, 2019

Do we plan any more changes here?

@ferrine ferrine merged commit 93ca52d into master Jun 28, 2019
@ferrine ferrine deleted the manifold-methods branch June 28, 2019 17:11
andbloch pushed a commit to andbloch/geoopt that referenced this pull request Dec 29, 2019
* add methods

* make parallel transport more restrictive

* just use tuple

* remove unused fn

* update changelog

* remove unused

* add apply chain

* project kwarg

* check v is none
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants