Sasaki metric #1568

Merged: 41 commits merged into geomstats:master on Jul 12, 2022

ambellan (Contributor) commented Jun 7, 2022

Checklist

  • My pull request has a clear and explanatory title.
  • If necessary, my code is vectorized.
  • I have added appropriate unit tests.
  • I have made sure the code passes all unit tests. (refer to comment below)
  • My PR follows PEP8 guidelines. (refer to comment below)
  • My PR follows geomstats coding style and API.
  • My code is properly documented and I made sure the documentation renders properly. (The code is documented, but I cannot get sphinx-build docs/ docs/html to run, as it fails immediately.)

Description

A natural choice of metric on the tangent bundle of a Riemannian manifold is the Sasaki metric and this PR implements it.
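For readers new to the construction, here is the standard textbook definition (stated from general differential geometry, not taken from the PR's code). Write π: TM → M for the bundle projection and K: T(TM) → TM for the connection map of the Levi-Civita connection of the base metric g. For tangent vectors A, B to TM at (p, u), the Sasaki metric is

```latex
g^{S}_{(p,u)}(A, B) = g_p\bigl(d\pi(A),\, d\pi(B)\bigr) + g_p\bigl(K(A),\, K(B)\bigr),
\qquad A, B \in T_{(p,u)}TM .
```

Horizontal vectors (K(A) = 0) and vertical vectors (dπ(A) = 0) are therefore orthogonal, and each part is measured with g; for flat M = R^n the formula reduces to the Euclidean metric on R^n × R^n.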

Additional context

This implementation was originally part of the submission 'Sasaki Metric and Applications in Geodesic Analysis' to the ICLR Computational Geometry & Topology Challenge 2022.

codecov bot commented Jun 7, 2022

Codecov Report

Merging #1568 (387133f) into master (56cce9e) will increase coverage by 2.37%.
The diff coverage is 98.94%.

@@            Coverage Diff             @@
##           master    #1568      +/-   ##
==========================================
+ Coverage   88.52%   90.89%   +2.37%     
==========================================
  Files         104      110       +6     
  Lines       10263    10594     +331     
==========================================
+ Hits         9084     9628     +544     
+ Misses       1179      966     -213     
Flag Coverage Δ
autograd 88.61% <98.94%> (+0.10%) ⬆️
numpy 87.10% <97.88%> (?)

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
geomstats/geometry/sasaki_metric.py 98.94% <98.94%> (ø)
geomstats/_backend/numpy/autodiff.py 90.00% <0.00%> (ø)
geomstats/_backend/numpy/_common.py 88.89% <0.00%> (ø)
geomstats/_backend/numpy/random.py 100.00% <0.00%> (ø)
geomstats/_backend/numpy/linalg.py 94.12% <0.00%> (ø)
geomstats/_backend/numpy/__init__.py 98.78% <0.00%> (ø)
geomstats/geometry/symmetric_matrices.py 92.14% <0.00%> (+1.13%) ⬆️
geomstats/geometry/stratified/graph_space.py 53.62% <0.00%> (+1.21%) ⬆️
geomstats/geometry/stratified/point_set.py 96.08% <0.00%> (+25.50%) ⬆️
geomstats/geometry/stratified/spider.py 94.32% <0.00%> (+62.50%) ⬆️
... and 1 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 56cce9e...387133f.

ninamiolane (Collaborator)

Very nice, thank you!

Could you address the lint errors reported by the GitHub workflows:

  • Linting
  • DeepSource

or let us know if you can't/shouldn't address these.

I will let @LPereira95 comment on the tests, and use of pytorch/tensorflow backends.

luisfpereira self-requested a review June 8, 2022 08:00

ambellan (Contributor Author) commented Jun 8, 2022

I fixed the linting/DeepSource issues (originally I kept the --ignoreD flag for flake8, which let through errors I did not notice in advance). However, it is a bit tricky that DeepSource spots problems that are similar to, but different from, those flake8 reports (i.e. ). Can DeepSource be run before opening a PR? The tests for the tf and torch backends are failing because the implementation currently supports only numpy and autograd. I used flags to restrict testing to numpy/autograd, and I have no idea why the tf and torch tests are still carried out.
Cheerio

luisfpereira (Collaborator) left a comment

Thanks for the very interesting contribution @ambellan!

As you can see, I have several remarks. Most of them are related to vectorization. Addressing my suggestions should speed up the code and will probably make it work with all the backends.

I'm looking forward to your feedback on this.

geomstats/geometry/sasaki_metric.py: three resolved review threads (outdated)
    def __init__(self, metric: RiemannianMetric, n_s=3):
        self.metric = metric  # Riemannian metric of underlying space
        self.n_s = n_s  # Number of discretization steps
        shape = (2, gs.prod(metric.shape))
Collaborator

Should shape be (2, gs.prod(metric.shape)) or (2, *metric.shape) or (2*metric.shape[0], *metric.shape[1:])?

I think (2, *metric.shape) is the right one.

Contributor Author

Well, this would make a difference if metric.shape has more than one element. I am not sure whether this should be changed, as it may introduce problems. OK, I just checked with @vontycowicz and he confirmed that (2, *metric.shape) did not work together with the Frechet mean in Kendall shape space. For the moment we would like to keep it as is, unrolling metric.shape. This may be changed later on.

Replacing (2, gs.prod(metric.shape)) with (2, gs.prod(gs.array(metric.shape))) is already enough to pass the unit tests with the pytorch backend.
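To make the three candidate shapes in this thread concrete, here is a small pure-Python sketch that only computes shapes. The helper name candidate_shapes and the (3, 2) base shape (three landmarks in the plane, Kendall-style) are illustrative assumptions. It uses math.prod, which accepts a plain tuple such as metric.shape and returns a Python int; that is one backend-independent way (an assumption, not what the PR does) to avoid handing a tuple to gs.prod, which apparently failed under the torch backend.

```python
import math


def candidate_shapes(base_shape):
    """Return the three tangent-bundle shapes discussed in the review.

    `base_shape` plays the role of `metric.shape` for the base manifold
    (hypothetical helper, for illustration only).
    """
    return {
        # PR choice: base point and tangent vector each flattened to one row.
        "flattened": (2, math.prod(base_shape)),
        # Reviewer's suggestion: keep the base shape, stack along a new axis.
        "stacked": (2, *base_shape),
        # Concatenate along the first axis of the base shape.
        "concatenated": (2 * base_shape[0], *base_shape[1:]),
    }


# Hypothetical example: 3 landmarks in the plane, as in Kendall shape space.
shapes = candidate_shapes((3, 2))
print(shapes)
# {'flattened': (2, 6), 'stacked': (2, 3, 2), 'concatenated': (6, 2)}

# All candidates store the same 12 scalars; only the layout differs.
assert len({math.prod(s) for s in shapes.values()}) == 1
```

The layouts coincide when the base shape is 1-D, which is why the choice only matters for matrix-valued manifolds such as Kendall shape space.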

Collaborator

Do you know why the unrolling hasn't worked with Frechet mean in Kendall shape space? Can you share a small reproducible example of the error?

The goal of unrolling would be that instead of reshaping to (-1, 2) + self.metric.shape later in the code, we could do (-1,) + self.shape. It would be more robust as we could control reshaping shape from one place only.
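The trade-off can be sketched with nested Python lists standing in for backend arrays. The helpers flatten, reshape, pack and unpack are hypothetical, and the batch axis is left out for brevity. With the unrolled (2, prod(base_shape)) layout, every call into the base metric has to reshape back to the base shape, which is what reshaping to (-1, 2) + self.metric.shape does; with a stacked (2, *base_shape) layout that target would simply be (-1,) + self.shape.

```python
import math


def flatten(nested):
    """Flatten an arbitrarily nested list into a flat list (row-major)."""
    if not isinstance(nested, list):
        return [nested]
    return [x for item in nested for x in flatten(item)]


def reshape(flat, shape):
    """Reshape a flat list into nested lists of the given shape (row-major)."""
    if len(shape) == 1:
        return list(flat)
    step = math.prod(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def pack(base_point, tangent_vec):
    """Store (p, u) as a 2 x prod(base_shape) array: the unrolled layout."""
    return [flatten(base_point), flatten(tangent_vec)]


def unpack(point, base_shape):
    """Recover (p, u) with the base-manifold shape before calling the base metric."""
    return reshape(point[0], base_shape), reshape(point[1], base_shape)


# Example: 3 landmarks in the plane, tangent vector of the same shape.
base_shape = (3, 2)
p = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
u = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
point = pack(p, u)  # [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 0.0, 0.0, 1.0, 1.0, 1.0]]
assert unpack(point, base_shape) == (p, u)

# With a stacked layout, both reshape targets coincide.
stacked_shape = (2, *base_shape)
assert (-1,) + stacked_shape == (-1, 2) + base_shape
```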

Contributor Author

@LPereira95, taking (2, *metric.shape) instead of (2, gs.prod(metric.shape)) causes problems at several places in the FrechetMean class. Two examples:

  1. In the constructor:

         if point_type is None:
             self.point_type = metric.default_point_type
         error.check_parameter_accepted_values(
             self.point_type, "point_type", ["vector", "matrix"]
         )

     We would need something like a 3-tensor point_type, and that is not handled.

  2. In _default_gradient_descent:

         if point_type == "vector":
             points = gs.to_ndarray(points, to_ndim=2)
             einsum_str = "n,nj->j"
         else:
             points = gs.to_ndarray(points, to_ndim=3)
             einsum_str = "n,nij->ij"

     There is no handling for point types other than vector and matrix.
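As an illustration of the second point, the einsum subscripts can be built for any number of point axes instead of branching on "vector" versus "matrix". This is a hypothetical helper, not FrechetMean's actual code; its output for one and two axes is equivalent to the hard-coded "n,nj->j" and "n,nij->ij" (only the index letters differ).

```python
def weighted_mean_einsum_str(point_ndim):
    """Return einsum subscripts for sum_n weights[n] * points[n, ...].

    `point_ndim` is the number of axes of a single point (1 for vectors,
    2 for matrices, 3 for a 3-tensor such as a stacked tangent-bundle point).
    """
    if not 1 <= point_ndim <= 8:
        raise ValueError("point_ndim must be between 1 and 8")
    # Index letters for the point axes; "n" is reserved for the sample axis.
    axes = "ijklmopq"[:point_ndim]
    return f"n,n{axes}->{axes}"


for ndim in (1, 2, 3):
    print(ndim, weighted_mean_einsum_str(ndim))
# 1 n,ni->i
# 2 n,nij->ij
# 3 n,nijk->ijk
```

With NumPy, a call would look like np.einsum(weighted_mean_einsum_str(points.ndim - 1), weights, points). NumPy's einsum also accepts an ellipsis, so "n,n...->..." covers every case in one string; whether every geomstats backend handles that form has not been checked here.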

Apart from that, vectorizing (flattening) the shape of the base manifold has the advantage that iterating the construction, i.e. taking the Sasaki metric of a Sasaki metric on the tangent bundle of the tangent bundle, does not add new dimensions to the shape.
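This argument can be checked with a small sketch that applies the tangent-bundle construction repeatedly (TM, then T(TM), and so on) under both layouts; iterated_shapes is a hypothetical helper and (3, 2) an example base shape.

```python
import math


def iterated_shapes(base_shape, n_iter, unroll=True):
    """Shapes of T(...T(M)) after `n_iter` tangent-bundle lifts of `base_shape`.

    unroll=True mimics (2, prod(shape)); unroll=False mimics (2, *shape).
    """
    shapes = []
    shape = tuple(base_shape)
    for _ in range(n_iter):
        shape = (2, math.prod(shape)) if unroll else (2, *shape)
        shapes.append(shape)
    return shapes


print(iterated_shapes((3, 2), 3))
# [(2, 6), (2, 12), (2, 24)]
print(iterated_shapes((3, 2), 3, unroll=False))
# [(2, 3, 2), (2, 2, 3, 2), (2, 2, 2, 3, 2)]
```

With unrolling, every level stays a 2-D (2, N) array, so code that only understands "vector" and "matrix" point types keeps working; the stacked layout gains one axis per level.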

geomstats/geometry/sasaki_metric.py: six resolved review threads (five outdated)
luisfpereira (Collaborator)

@ambellan, for DeepSource I don't know whether it can be run locally (I couldn't find a way by going through their website). Do you know if it is possible @ninamiolane?

About the docs, can you provide the error sphinx-build raises? It should work...

For the tests, the best option is to use the skip_all flag, as it is done here (sorry for the lack of documentation). Nevertheless, I would prefer to have it working in all the backends.

ninamiolane (Collaborator)

I don't know if DeepSource can be run locally, they do have a CLI https://deepsource.io/docs/cli/usage/#commands but it seems to do something different. :/

Note, though, that DeepSource's GitHub workflow usually finishes before all the others, in about 2 minutes, so the wait is not too long.

ambellan (Contributor Author)

@LPereira95 regarding sphinx: I got the following errors:
Extension error: Could not import extension nbsphinx (exception: No module named 'nbsphinx')
Extension error: Could not import extension nbsphinx_link (exception: No module named 'nbsphinx_link')
Extension error: Could not import extension sphinx_gallery.load_style (exception: No module named 'sphinx_gallery')
Theme error: no theme named 'pydata_sphinx_theme' found (missing theme.conf?)
After installing the missing packages I was able to proceed; however, running sphinx-build docs/ docs/html from the geomstats project folder built documentation for the geomstats package installed in the active Python environment. I additionally had to make sure that the local geomstats source tree is found first, which I did with export PYTHONPATH="${PYTHONPATH}:<...>/projects/07geomstats/geomstats". After that, sphinx built the correct docs, but I could not find anything related to the added Sasaki metric. Does anything need to be added to the sphinx configuration?
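A common alternative to exporting PYTHONPATH is a variant of the path snippet that sphinx-quickstart templates include (commented out) at the top of conf.py, which makes autodoc import the local source tree before any installed copy. This sketch assumes conf.py sits in docs/ directly below the repository root; Sphinx executes conf.py with the working directory set to the directory containing it, so ".." resolves to that root. Separately, sphinx.ext.autodoc only documents modules named in an automodule (or autoclass) directive in some .rst page, so the new module likely also needs such an entry in the API pages; which file that is has not been identified here.

```python
# docs/conf.py (excerpt). Assumption: docs/ is directly below the repo root.
import os
import sys

# Put the local checkout first on sys.path so autodoc imports the geomstats
# source tree rather than a copy installed in site-packages.
sys.path.insert(0, os.path.abspath(".."))
```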

ambellan (Contributor Author)

Dear @LPereira95, for now I have incorporated as many of your suggestions as seemed reasonable/feasible to me. Please take a look at my changes and comments. Thank you for all your valuable remarks; I think they are really helpful for further developing the code. However, there is one thing I don't understand: why don't you accept any additional code comments? I think they usually just clarify what happens in the code.

ambellan (Contributor Author) commented Jun 28, 2022

General note: Apart from black I also had to install and run isort in order to pass the linting test. Maybe this could be added to the guidelines.

ambellan (Contributor Author) commented Jun 30, 2022

@LPereira95 and @ninamiolane, I think I addressed all requested changes. Please have a look and iterate/accept the PR.

ninamiolane (Collaborator)

@ambellan Many thanks!! There is an issue with tensorflow, see here: https://github.com/geomstats/geomstats/runs/7131533833?check_suite_focus=true

Would you know what is happening? Good to merge on my side otherwise!

luisfpereira (Collaborator)

Quoting @ambellan: "General note: Apart from black I also had to install and run isort in order to pass the linting test. Maybe this could be added to the guidelines."

This may be of interest to you, @nanjekyejoannah.

luisfpereira (Collaborator) left a comment

Thanks a lot for addressing the requested changes @ambellan! The code is in great shape now. I'm also very pleased by your effort in creating the gradient document.

Can you please address these last comments?

(If you feel we are iterating too much, let me know and I'll address them myself)

geomstats/geometry/sasaki_metric.py: two resolved review threads (one outdated)
ambellan (Contributor Author)

@ninamiolane and @LPereira95, thank you for all your valuable feedback. I think I have handled all comments raised so far. Please let me know if anything is still open; otherwise, please accept the PR.

        self.metric = metric
        shape = (2, gs.prod(gs.array(metric.shape)))

        self.n_jobs = (
Collaborator

Pass n_jobs as input with default value 1 (for consistency with other uses of joblib in the library).
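The convention requested here can be sketched as follows: n_jobs defaults to 1, meaning a plain sequential loop, and larger values fan the work out to workers. ToyBatchedMetric and its squared-distance placeholder are hypothetical. The library uses joblib for this, whereas the sketch substitutes the standard library's ThreadPoolExecutor so it runs without extra dependencies. One difference to keep in mind: joblib also accepts n_jobs=-1 (all cores), which ThreadPoolExecutor does not.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyBatchedMetric:
    """Hypothetical stand-in showing the `n_jobs=1` constructor convention."""

    def __init__(self, n_jobs=1):
        # Default of 1: run sequentially, with no worker overhead.
        self.n_jobs = n_jobs

    @staticmethod
    def _single(pair):
        # Placeholder per-pair computation: squared Euclidean distance.
        a, b = pair
        return sum((x - y) ** 2 for x, y in zip(a, b))

    def map_pairs(self, pairs):
        """Apply the per-pair computation to every pair, preserving order."""
        if self.n_jobs == 1:
            return [self._single(p) for p in pairs]
        # Executor.map returns results in input order.
        with ThreadPoolExecutor(max_workers=self.n_jobs) as pool:
            return list(pool.map(self._single, pairs))


pairs = [([0, 0], [3, 4]), ([1, 1], [1, 1])]
print(ToyBatchedMetric().map_pairs(pairs))          # [25, 0]
print(ToyBatchedMetric(n_jobs=2).map_pairs(pairs))  # [25, 0]
```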

Contributor Author

Done

luisfpereira (Collaborator)

Thanks @ambellan. Please address the last comment from my side.

Also, can you please try out FrechetMean again, after merging the changes from master? I've made some changes to avoid the use of default_point_type there. I just want to ensure shape works fine for SasakiMetric (do you have a minimal test to add to Frechet mean tests?).

ambellan (Contributor Author)

@LPereira95, I tested FrechetMean for SasakiMetric on Kendall shape space and it worked as expected. However, I think the test case (the one from the ICLR challenge) is, at least in its current form, not suited as a minimal working example. Since we also plan to contribute our challenge notebook to Geomstats as an example of how to use SasakiMetric, that PR might be the right place to also set up a slimmed-down version as an additional test for FrechetMean. What do you think?

luisfpereira (Collaborator)

Thanks a lot @ambellan!

I agree with the idea of addressing the additional test in another PR. Looking forward to it!

luisfpereira merged commit b975b5f into geomstats:master on Jul 12, 2022
ninamiolane (Collaborator)

Thank you both!! @LPereira95 @ambellan

ambellan (Contributor Author)

@ninamiolane and @LPereira95, thank you both again for handling the PR! I learned a lot about Life, the Universe, and Everything (well, and of course about Geomstats and good code) :grinning:
