[RFC] Make n_chains set the total number of chains across all MPI processes #706

Merged: 3 commits, May 27, 2021

PhilipVinc
Member

Recently we had several people confused by the fact that MPI does not particularly improve performance.
There are two issues:

  1. They don't read the documentation (and we don't have a page on MPI)
  2. Our n_chains is a rank-local property, so if you increase the number of MPI ranks you get more chains. However, the number of samples is kept fixed.

Point 1) can be solved with better docs.

Point 2) is about inconsistency with the way we set n_samples. I propose to change the behaviour of n_chains so that it sets the number of chains globally, according to the formula

n_chains_per_rank = max(
                    int(np.ceil(n_chains / mpi.n_nodes)), 1
                )

One can still specify n_chains_per_rank directly if they so desire.
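To make the rounding concrete, here is a minimal sketch of the proposed formula in plain Python. The helper name `chains_per_rank` and its `n_ranks` argument are illustrative stand-ins (the PR reads the rank count from mpi.n_nodes); this is not the code in netket/sampler/base.py.

```python
import math


def chains_per_rank(n_chains: int, n_ranks: int) -> int:
    """Number of chains each rank runs for a requested global n_chains.

    Hypothetical helper mirroring the formula proposed above; n_ranks
    stands in for mpi.n_nodes.
    """
    return max(math.ceil(n_chains / n_ranks), 1)


# (requested global chains, number of MPI ranks)
for n_chains, n_ranks in [(16, 1), (16, 4), (16, 3), (16, 1000)]:
    per_rank = chains_per_rank(n_chains, n_ranks)
    print(
        f"n_chains={n_chains} ranks={n_ranks}: "
        f"{per_rank} per rank, {per_rank * n_ranks} in total"
    )
```

Because of the ceiling, the actual total can exceed the requested value whenever n_chains is not a multiple of the rank count: 16 chains on 3 ranks gives 6 per rank, so 18 overall.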

This is just a skeleton implementation (though it should mostly work).
As fixing tests everywhere to use n_chains_per_rank instead of n_chains will take some time, I'll finish this PR only if we get consensus on this.

Note that it will be a fairly breaking change in the behaviour (though it won't technically break code)

@github-actions

Hello and thanks for your Contribution!
I will be building previews of the updated documentation at the following link:
https://netket.github.io/netket/preview/pv/n_chains

Once the PR is closed or merged, the preview will be automatically deleted.

Collaborator

@femtobit femtobit left a comment

I'm very much in favor of this change. n_samples and n_chains should either both be rank-local or both global (and the former option is just confusing).

@PhilipVinc
Member Author

Pff.
This plays very badly with Flax struct/dataclass.
I think we should roll our own dataclass.
It should not be much work.
Maybe I'll do this at some point.

@PhilipVinc
Member Author

This is now rebased on top of #716.
It works.
So if we merge #716 we can have this.

@codecov-commenter

codecov-commenter commented May 17, 2021

Codecov Report

Merging #706 (0e1a85f) into master (5340864) will decrease coverage by 0.33%.
The diff coverage is 54.38%.

@@            Coverage Diff             @@
##           master     #706      +/-   ##
==========================================
- Coverage   69.47%   69.13%   -0.34%     
==========================================
  Files         216      216              
  Lines       12480    12509      +29     
  Branches     1809     1817       +8     
==========================================
- Hits         8670     8648      -22     
- Misses       3335     3378      +43     
- Partials      475      483       +8     
Impacted Files                          Coverage Δ
netket/sampler/exact.py                 85.71% <ø> (ø)
netket/variational/mc_mixed_state.py    87.93% <ø> (ø)
netket/sampler/metropolis_pmap.py       48.75% <8.33%> (+0.60%) ⬆️
netket/variational/mc_state.py          81.01% <60.00%> (-1.38%) ⬇️
netket/sampler/base.py                  77.45% <65.38%> (-3.51%) ⬇️
netket/sampler/metropolis.py            84.37% <100.00%> (ø)
netket/utils/__init__.py                100.00% <100.00%> (ø)
netket/utils/mpi/mpi.py                 53.48% <0.00%> (-30.24%) ⬇️
netket/legacy/stats/_sum_inplace.py     52.94% <0.00%> (-21.57%) ⬇️
netket/utils/mpi/primitives.py          35.82% <0.00%> (-13.44%) ⬇️
... and 6 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@PhilipVinc PhilipVinc marked this pull request as ready for review May 20, 2021 09:44
@PhilipVinc
Member Author

@gcarleo Do we want to do this?

@PhilipVinc
Member Author

What this does, in the end, is that samplers can be built with

sa = nk.sampler.MetropolisLocal(n_chains=X)

and when run under MPI, every rank will use

n_chains_per_rank = max(
                    int(np.ceil(n_chains / mpi.n_nodes)), 1
                )

or you can create them with

sa = nk.sampler.MetropolisLocal(n_chains_per_rank=X)

which will match the current behaviour.

Without MPI, nothing changes.
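The two ways of constructing a sampler described in this comment can be sketched without NetKet or MPI. The function `resolve_chains_per_rank`, its `n_ranks` parameter, and the choice to reject both keywords at once (or neither) are assumptions made for illustration. The real handling lives in the sampler base class and may differ, for example by falling back to a default value.

```python
import math
from typing import Optional


def resolve_chains_per_rank(
    n_ranks: int,
    n_chains: Optional[int] = None,
    n_chains_per_rank: Optional[int] = None,
) -> int:
    """Return the rank-local chain count for either keyword.

    Hypothetical stand-in for the sampler constructor logic:
    n_chains is global, n_chains_per_rank is rank-local.
    """
    if n_chains is not None and n_chains_per_rank is not None:
        # Assumption: specifying both is ambiguous, so reject it.
        raise ValueError("Specify only one of n_chains and n_chains_per_rank.")
    if n_chains_per_rank is not None:
        return n_chains_per_rank  # rank-local value, used as given
    if n_chains is not None:
        return max(math.ceil(n_chains / n_ranks), 1)  # global value, split over ranks
    raise ValueError("Specify n_chains or n_chains_per_rank.")


# On 4 ranks the two keywords give different totals for the same X = 16.
global_case = resolve_chains_per_rank(4, n_chains=16)
local_case = resolve_chains_per_rank(4, n_chains_per_rank=16)
print(global_case, global_case * 4)  # chains per rank, chains in total
print(local_case, local_case * 4)

# On a single rank (no MPI) both keywords agree.
print(resolve_chains_per_rank(1, n_chains=16), resolve_chains_per_rank(1, n_chains_per_rank=16))
```

With the global keyword the total stays at 16 regardless of the rank count (up to rounding); with the per-rank keyword the total grows with the number of ranks, which is the pre-PR behaviour.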

@gcarleo
Member

gcarleo commented May 20, 2021

Yes, I am just worried that someone who leaves n_chains=16 and runs on 1000 MPI ranks might be really surprised by the new behavior...

@PhilipVinc
Member Author

What will happen in this case is 1 chain per rank, plus a warning saying that 1*1000 != 16.

The cleanest way to do this is usually to make n_chains an error and have n_chains_per_rank and n_chains_total, then in a future release deprecate n_chains_total and go back to n_chains.

But that would break everything for people running stuff locally.

@gcarleo
Member

gcarleo commented May 20, 2021

Yeah, I mean, I think this change is consistent with the fact that n_samples for us is really the total number of samples, not n_samples_per_rank (btw, that might make sense too...). I am not against merging this, actually.

@PhilipVinc
Member Author

PhilipVinc commented May 20, 2021

n_samples_per_rank I can add that in this PR.

I think a good alternative would be to print a warning always when running under MPI with n_chains for this release (warning can be disabled with a flag) saying that the behaviour changed.
Then we get rid of the warning in the next release (cc @femtobit)

Collaborator

@femtobit femtobit left a comment

n_samples_per_rank I can add that in this PR.

Yes, that would be nice to have for consistency. Either way, I'm happy with this PR.

I think a good alternative would be to print a warning always when running under MPI with n_chains for this release (warning can be disabled with a flag) saying that the behaviour changed.
Then we get rid of the warning in the next release (cc @femtobit)

Maybe... It'd be a warning that is displayed essentially every time NetKet is run, so it'd be pretty prominent (which is good to get people to notice, but can also be annoying - the flag helps but needs to be specified all the time). I'm undecided, feel free to do what you think is best.

@PhilipVinc
Member Author

So if @gcarleo agrees I'll add

  • n_samples_per_rank to MCVariationalState.

And change the behaviour so that

  • the current (rank-local) n_chains becomes n_chains_per_rank
  • n_chains will now set the global number of chains.

If n_chains is not perfectly divisible by the number of ranks we print a warning, only on rank 0.
I know it's annoying but I think this is the correct thing to do.
Regardless, when you run stuff under MPI you already have some visual noise so I think this is not so bad.
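The warning rule described here can be sketched with the standard `warnings` module. The helper name, the `rank` and `n_ranks` arguments (which would come from MPI in practice), and the wording of the message are all assumptions for illustration, not the text NetKet emits.

```python
import math
import warnings


def chains_per_rank_with_warning(n_chains: int, n_ranks: int, rank: int) -> int:
    """Split n_chains over n_ranks, warning on rank 0 if the split is uneven.

    Hypothetical helper: rank and n_ranks stand in for values read from MPI.
    """
    per_rank = max(math.ceil(n_chains / n_ranks), 1)
    total = per_rank * n_ranks
    if total != n_chains and rank == 0:
        # Only rank 0 reports, so the message appears once per run.
        warnings.warn(
            f"Requested n_chains={n_chains}, but {per_rank} chain(s) on each "
            f"of {n_ranks} ranks gives {total} chains in total.",
            stacklevel=2,
        )
    return per_rank


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    chains_per_rank_with_warning(16, 4, rank=0)     # even split: silent
    chains_per_rank_with_warning(16, 1000, rank=0)  # 1 * 1000 != 16: warns
    chains_per_rank_with_warning(16, 1000, rank=1)  # uneven, but not rank 0: silent

print(len(caught))  # number of warnings recorded across the three calls
```

Gating on rank 0 keeps the output readable on large jobs: in the 1000-rank case discussed earlier, the mismatch is reported once rather than a thousand times.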

@gcarleo
Member

gcarleo commented May 27, 2021

Ok, yes, please add n_samples_per_rank and change n_chains accordingly. This looks like a good solution.

fixup dtype
fixup!
impro
Update netket/sampler/base.py

Co-authored-by: Damian Hofmann <femtobit@users.noreply.github.com>
black