
Poor performance of MPI in v3.0 #464

Closed
kastoryano opened this issue Aug 17, 2020 · 12 comments
Labels
docs Documentation-related issues

Comments

@kastoryano

I have run a number of experiments comparing the performance of NK2.1 and NK3.0 on multiple cores using MPI. I see very inconsistent performance for NK3.0, which leads me to question whether MPI is implemented effectively (optimally) in v3.0. Below is an example experiment:

NetKet 2.1, heisenberg1d.py, L=60, alpha = 3, symmetries = True, samples = 2000

1 core (no mpi) : 24 s/it
9 cores : 3 s/it
18 cores: 1.5 s/it
36 cores: 2.3 s/it

It seems the time per iteration reaches a minimum at around 18 cores. This is also the case in all the other experiments I ran, varying the number of samples and the system size (number of variables). Similar performance is seen for convolutional networks.

For NetKet 3.0 (without Jax) on the exact same system, I get:

1 core (no mpi): 6 s/it
8 cores : 10 s/it
18 cores : 20 s/it

I have not found any scenario in v3.0 where multicore mode (MPI) helps. (Note: the convolutional NN examples do not run in 3.0; the layer function returns an error.)

In all experiments, I am using

mpirun -n x_cores python3 filename.py

without modifying anything in the library except the parameters of the experiments. Using mpiexec does not change anything. Any ideas what is happening?

@PhilipVinc
Member

PhilipVinc commented Aug 17, 2020

Hmm, interesting.
What is your test file? What network are you using? Are you using SR?

EDIT: heisenberg1d.py

Can you give me the following information:

  • What is the CPU?
  • Can you test netket v3.0 with 2, 3, 4 cores too?
  • Can you check the CPU usage of the python processes while doing so (using top, for example)?

My guess is that numpy multithreads BLAS differently than v2.1's Eigen, which could lead to a weird interplay with MPI on the same system.

If you could also test 1, 2, 4 cores without SR, it would be helpful for understanding where something is going wrong.

@kastoryano
Author

I think I now understand the problem. There is multithreading happening automatically in v3.0 (OpenMP, I guess?), which is not happening in v2.1. When I deactivate multithreading, the performance is about a factor of six better than in v2.1 on the example above. So there is actually a significant performance boost in v3.0. Nice!
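For reference, a minimal sketch of one way to deactivate this implicit multithreading (assuming numpy is built against one of the usual OpenBLAS/MKL/OpenMP backends; which variable actually matters depends on the backend, and this is not necessarily how it was done here):

```python
# Sketch: cap implicit BLAS/OpenMP threading at one thread per MPI process.
# The variables must be set before numpy (and hence netket) is first imported.
import os

for var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ[var] = "1"

import numpy as np   # the BLAS backend now stays single-threaded
import netket as nk  # the rest of the driver script is unchanged
```

The same variables can instead be exported on the mpirun command line, which avoids touching the script at all.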

It would be really helpful to have some documentation about the parallelization for us plebeians.

@PhilipVinc
Member

Yes and no: we don't have, and never have had, explicit multithreading (except for a few RBM kernels in v2.1), but BLAS does multithread linear-algebra operations on sufficiently big arrays.
Numpy (v3.0) is very aggressive with this multithreading (I guess because the GIL otherwise forces single-threaded execution), and its threshold is lower than that of C++'s Eigen (v2.1). Also, Eigen's multithreading is automatically disabled on some systems and requires an environment variable to activate.

If you want to disable this, I'm sure you can do it from numpy; check this comment or the threadpoolctl package.
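For instance, a minimal sketch with threadpoolctl (the driver function below is a placeholder, not actual NetKet API):

```python
# Sketch: limit BLAS thread pools for a block of code using threadpoolctl.
from threadpoolctl import threadpool_limits

with threadpool_limits(limits=1, user_api="blas"):
    # run the expensive part (sampling / SR solve) with BLAS pinned to one thread;
    # run_optimization() stands in for your own driver code
    run_optimization()
```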

About documenting this... yes, you are right, but documenting takes time. (And I'd like to argue that scientific-computing users should be aware of the implicit threading done by BLAS libraries in general. This is not a netket-specific issue, but, yeah... most people don't know, and only learn when they hit an issue like this.)

@gcarleo
Member

gcarleo commented Aug 19, 2020 via email

@PhilipVinc
Member

PhilipVinc commented Aug 19, 2020 via email

@femtobit
Collaborator

femtobit commented Sep 7, 2020

> If running across N MPI processes with M total samples, every process will be solving a system where one dimension is M/N, so it will be a smaller one.

Is this the case now in the v3.0 branch? I do recall that in NetKet 2, the parameters are first synchronized over all MPI workers and then the SR equation is solved using all samples (here for the full and here for the iterative solver), in which case it'd make sense to do this only on the root node; I think this is what Giuseppe was referring to. Of course, this is not necessarily an efficient way to do it, as you point out, since it means the performance of the root node on the full linear system becomes the bottleneck (which, I think, becomes noticeable in NetKet v2 for a larger number of parameters).

The question is whether solving the N systems with fewer samples separately and averaging the solutions is equivalent (up to numerical error) to solving the full system after computing the joint S matrix and gradient from all samples, or whether it risks introducing another source of error. Intuitively, I think it should not for sufficiently many samples, in which case your suggestion seems clearly preferable to only working on the root node. Maybe we should test this for v3.
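To make the distinction concrete (the notation below is mine and follows the usual SR conventions, just to illustrate the question): with $N$ ranks, the joint quantities are the sample averages

$$ S = \frac{1}{N}\sum_{k=1}^{N} S_k, \qquad F = \frac{1}{N}\sum_{k=1}^{N} F_k, $$

where $S_k$ and $F_k$ are estimated from the samples on rank $k$ alone. Solving the joint system means finding $\delta$ with $S\,\delta = F$, while averaging the per-rank solutions gives $\tilde\delta = \frac{1}{N}\sum_k S_k^{-1} F_k$. In general $\tilde\delta \neq S^{-1} F$, since inversion does not commute with averaging; the two only agree in the limit where each $S_k$, $F_k$ fluctuates weakly around $S$, $F$, i.e. for sufficiently many samples per rank.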

Since we're discussing parallel performance: In the version 3 Python code, is the MCMC part (i.e., the computation of samples using AbstractSampler) currently only parallelized via MPI or does it also benefit from adding more CPUs (and potentially GPUs using Jax or PyTorch) to a single MPI worker?

@VolodyaCO
Collaborator

I'm also having problems with parallelisation. I'm not explicitly doing anything with MPI, but I am currently using optuna to find hyperparameters of the wavefunction. Optuna allows one to generate hyperparameter proposals and run them in parallel, but it seems that every generated job also performs parallel computations of its own, thus requiring more resources than the computer actually has.

It would be good to have something that limits the number of cores. I tried with threadpoolctl, but since it acts on a job belonging to a pool of jobs, it gets confused (as stated in their issues).

@PhilipVinc
Member

What are you using, numpy, jax, or torch?

@VolodyaCO
Collaborator

I'm using a jax machine and a jax sampler. If I don't get it to work, I'll post details in a couple of days.

@PhilipVinc
Member

As I said above, Jax does some parallelization on its own. Check its docs and see what environment variables to set.

@VolodyaCO
Collaborator

OK, I had to use a somewhat ugly workaround, but it works. (Unfortunately, setting env variables did not completely work: XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" python my_file.py.) This was the solution: taskset -c 0 python myscript.py. I then used the subprocess module so that optuna does distributed optimisation of the hyperparameters.
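Roughly, the pattern looks like this (a sketch only: my_trial_script.py, the core count, and the convention that the script prints its final loss to stdout are assumptions for illustration, not the actual code used here):

```python
# Sketch: run each Optuna trial as a separate process pinned to one core via taskset.
import subprocess
import optuna

N_CORES = 8  # assumed number of cores to cycle through

def objective(trial):
    core = trial.number % N_CORES
    alpha = trial.suggest_int("alpha", 1, 4)
    # my_trial_script.py is a hypothetical driver that trains the wavefunction
    # for the proposed hyperparameters and prints its final loss to stdout
    out = subprocess.run(
        ["taskset", "-c", str(core),
         "python", "my_trial_script.py", "--alpha", str(alpha)],
        capture_output=True, text=True, check=True,
    )
    return float(out.stdout.strip())

study = optuna.create_study()
study.optimize(objective, n_trials=20, n_jobs=N_CORES)
```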

@PhilipVinc
Member

RECAP: this is a non-issue. Maybe we should document the fact that jax and numpy do parallelisation on their own, and how to disable it, but that is also somewhat well-known stuff.

PhilipVinc added the docs (Documentation-related issues) label on Dec 22, 2020