Poincare Manifold methods #78
Conversation
d23625c to 0095f36
```diff
-def transp(self, x, y, v, *more):
+def transp(self, x, y, v, *more, dim=-1):
```
I have concerns about whether this method should work with non-matching v, *more. The easy way is to do this in a loop and not care.
Or maybe check and raise a warning? (still need to loop, though)
What is the point of `*more` anyway, and where will we use it? I have completely missed that part.
A flag in the parameters may save nerves and add more clarity about what's going on
like `stack=True`
The point is to be able to transport multiple vectors in one pass, as this can be much cheaper computationally. E.g. on Stiefel manifolds this allows performing the operation with a single LU decomposition.
This is how the flag is supposed to be implemented:

```python
def transp(self, x, y, v, *more, dim=-1, stack=True):
    if not more:
        return math.parallel_transport(x, y, v, c=self.c, dim=dim)
    else:
        if stack:
            vecs = torch.stack((v,) + more, dim=0)
            transp = math.parallel_transport(
                x, y, vecs, c=self.c, dim=idx2sign(dim, x.dim())
            )
            return transp.unbind(0)
        else:
            return tuple(
                math.parallel_transport(x, y, vec, c=self.c, dim=dim)
                for vec in (v, *more)
            )
```
However, the tuple approach is not that bad:

```python
def transp(self, x, y, v, *more, dim=-1):
    if not more:
        return math.parallel_transport(x, y, v, c=self.c, dim=dim)
    else:
        return tuple(
            math.parallel_transport(x, y, vec, c=self.c, dim=dim)
            for vec in (v, *more)
        )
```
I think the PR is ready for review
```python
try:
    vecs = torch.stack((v,) + more, dim=0)
except RuntimeError:
    return tuple(
```
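This try/except fallback can be sketched in plain Python (no torch): a toy `stack` mimics `torch.stack` raising `RuntimeError` on shape mismatch, and the transport is a dummy doubling; all names are hypothetical illustrations, not the library's API.

```python
def stack(vectors):
    # Mimic torch.stack: refuse to stack vectors of different lengths.
    if len({len(v) for v in vectors}) > 1:
        raise RuntimeError("stack expects tensors of equal shape")
    return list(vectors)

def transport_one(v):
    return [2 * x for x in v]  # dummy per-vector "transport"

def transp(v, *more):
    try:
        vecs = stack((v,) + more)  # batched path: one stacked pass
        return tuple(transport_one(w) for w in vecs)
    except RuntimeError:
        # shapes did not match: fall back to a per-vector loop
        return tuple(transport_one(w) for w in (v, *more))
```

Matching shapes take the batched path; mismatched shapes silently fall back to the loop, which is the behavior being debated here.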
When are we going to need to pass `*more`?

- We might not want `*more` in `transp()` to behave differently from `*` in e.g. `expmap()`
- There might be occasions when we'd want to handle `*more` by reshaping and stacking all tensors into a `bs*man_dims` tensor and then splitting back after transporting.
- The latter would be a rather weird default
- Unless we return the `bs*man_dims` array plus an indices/access object that knows how to extract arrays of the original shapes

Also not sure how I feel about changing the return type from array to tuple of arrays -- although this confusion is very much pythonic, it is a confusion.
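The reshape-stack-split idea above can be sketched in plain Python (no torch; `pack`/`unpack` are hypothetical names): flatten all vectors into one buffer, keep their lengths as the "access object", run one batched pass, then split back to the original shapes.

```python
def pack(vectors):
    # Flatten everything into one buffer; the lengths list plays the role
    # of the indices/access object that records the original shapes.
    flat = [x for v in vectors for x in v]
    lengths = [len(v) for v in vectors]
    return flat, lengths

def unpack(flat, lengths):
    # Split the flat buffer back into pieces of the recorded lengths.
    out, i = [], 0
    for n in lengths:
        out.append(flat[i:i + n])
        i += n
    return out

flat, lengths = pack([[1, 2], [3, 4, 5]])
restored = unpack([2 * x for x in flat], lengths)  # one "batched" pass
```

The single list comprehension stands in for the one batched transport; the round trip restores the original, non-matching shapes.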
```diff
@@ -1273,6 +1273,33 @@ def _parallel_transport0(y, v, c, dim: int = -1):
     return v * (1 - c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(MIN_NORM)


+def parallel_transport0back(x, v, *, c=1.0, dim: int = -1):
```
Perhaps `s/parallel_//` for brevity? And (later) `s/\btransp\b/transport/` in the `PoincareBall` class too, so that we have just "transport" everywhere.
Parallel transport is a special case of vector transport, the one that is exact. That's the purpose of having `parallel` in `math`. In the manifold we do not provide any different transport, but we could for the sphere.
geoopt/utils.py (outdated)
```python
if idx < 0:
    return idx
else:
    return (idx + 1) % -(dim + 1)
```
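For context, this appears to be the body of the `idx2sign` helper used in `transp` above; here is a self-contained reconstruction (the signature is assumed from the call site `idx2sign(dim, x.dim())`):

```python
def idx2sign(idx, dim):
    # Convert a possibly-positive axis index into the equivalent negative
    # index, which stays valid when broadcasting prepends leading dims.
    if idx < 0:
        return idx
    # Python's % with a negative divisor yields a result in (-(dim+1), 0],
    # so e.g. the last of 3 dims (idx=2) maps to -1.
    return (idx + 1) % -(dim + 1)
```

For a 3-dim tensor, `idx2sign(2, 3)` gives `-1`, `idx2sign(0, 3)` gives `-3`, and negative indices pass through unchanged.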
Some tl;dr and fewer nesting levels, please
If you have positive indexing of dims, this makes the proper negative index that is used in broadcasting. I will delete that.
Do we plan any more changes here?
* add methods
* make parallel transport more restrictive
* just use tuple
* remove unused fn
* update changelog
* remove unused
* add apply chain
* project kwarg
* check v is none
This adds shortcuts for `poincare.math` into `PoincareBall`.