Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: set OpenMP thread num to a proper default value #3943

Merged
merged 13 commits into from
May 22, 2024
11 changes: 10 additions & 1 deletion docs/quick_start/easy_install.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,16 @@ Use 4 MPI processes to run, for example:
mpirun -n 4 abacus
```

> The total thread count(i.e. OpenMP per-process thread count * MPI process count) should not exceed the number of cores in your machine.
The total thread count (i.e. OpenMP per-process thread count * MPI process count) should not exceed the number of cores in your machine.
To use 4 threads and 4 MPI processes, set the environment variable `OMP_NUM_THREADS` before running `mpirun`:

```bash
OMP_NUM_THREADS=4 mpirun -n 4 abacus
```

In this case, the total thread count is 16.

ABACUS will try to determine the number of threads used by each process if `OMP_NUM_THREADS` is not set. However, it is **required** to set `OMP_NUM_THREADS` before running `mpirun` to avoid potential performance issues.

Please refer to [hands-on guide](./hands_on.md) for more instructions.

Expand Down
65 changes: 39 additions & 26 deletions source/module_base/parallel_global.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,43 +172,56 @@ void Parallel_Global::read_mpi_parameters(int argc,char **argv)

// GlobalV::KPAR = atoi(argv[1]); // mohan abandon 2010-06-09

// get the size --> GlobalV::NPROC
// get the rank --> GlobalV::MY_RANK
MPI_Comm_size(MPI_COMM_WORLD,&GlobalV::NPROC);
// get world size --> GlobalV::NPROC
mohanchen marked this conversation as resolved.
Show resolved Hide resolved
// get global rank --> GlobalV::MY_RANK
MPI_Comm_size(MPI_COMM_WORLD,&GlobalV::NPROC);
MPI_Comm_rank(MPI_COMM_WORLD, &GlobalV::MY_RANK);
int process_num = 0; // number of processes in the current node
int local_rank = 0; // rank of the process in the current node
MPI_Comm shmcomm;
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
caic99 marked this conversation as resolved.
Show resolved Hide resolved
MPI_Comm_size(shmcomm, &process_num);
MPI_Comm_rank(shmcomm, &local_rank);
MPI_Comm_free(&shmcomm);

// determining appropriate thread number for OpenMP
// Determining appropriate thread number for OpenMP:
// 1. If the number of threads is set by the user by `OMP_NUM_THREADS`, use it.
// 2. Otherwise, set to number of CPU cores / number of processes.
// 3. If the number of threads is larger than the hardware availability (should only happens if route 1 taken),
// output a warning message.
// 4. If the number of threads is smaller than the hardware availability, output an info message.
// CAVEAT: The user should set the number of threads properly to avoid oversubscribing.
// This mechanism only handles the worst case for the default setting (not setting number of threads at all, causing oversubscribing and extremely slow performance), not guaranteed to be optimal.
const int max_thread_num = std::thread::hardware_concurrency(); // Consider Hyperthreading disabled.
#ifdef _OPENMP
int current_thread_num = omp_get_max_threads();
int current_thread_num = omp_get_max_threads(); // Get the number of threads set by the user.
if (current_thread_num == max_thread_num && process_num >= 1) // Avoid oversubscribing on the number of threads not set.
{
current_thread_num = max_thread_num / process_num;
omp_set_num_threads(current_thread_num);
}
#else
int current_thread_num = 1;
#endif
MPI_Comm shmcomm;
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm);
int process_num = 0, local_rank = 0;
MPI_Comm_size(shmcomm, &process_num);
MPI_Comm_rank(shmcomm, &local_rank);
MPI_Comm_free(&shmcomm);
mpi_number = process_num;
omp_number = current_thread_num;
if (current_thread_num * process_num > max_thread_num && local_rank==0)
{
std::stringstream mess;
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
mess << "%% WARNING: Total thread number(" << current_thread_num * process_num << ") "
<< "is larger than hardware availability(" << max_thread_num << ")." << std::endl;
mess << "%% WARNING: The results may be INCORRECT. Please be sure what you are doing." << std::endl;
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
mess << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl;
std::cerr << mess.str() << std::endl;
mess << "WARNING: Total thread number(" << current_thread_num * process_num << ") "
<< "is larger than hardware availability(" << max_thread_num << ")." << std::endl
<< "The results may be INCORRECT. Please set the environment variable OMP_NUM_THREADS to a proper value."
<< std::endl;
std::cerr << mess.str() << std::endl;
// the user may take their own risk by set the OMP_NUM_THREADS env var.
if (std::getenv("OMP_NUM_THREADS") == nullptr)
{
exit(1);
}
}
else if (current_thread_num * process_num < max_thread_num && local_rank==0)
{
// only output info in local rank 0
std::cerr << "WARNING: Total thread number on this node mismatches with hardware availability. "
"This may cause poor performance."<< std::endl;
std::cerr << "Info: Local MPI proc number: " << process_num << ","
<< "OpenMP thread number: " << current_thread_num << ","
<< "Total thread number: " << current_thread_num * process_num << ","
Expand Down Expand Up @@ -336,25 +349,25 @@ void Parallel_Global::divide_pools(void)
{
std::cout<<"\n NPROC=" << GlobalV::NPROC << " KPAR=" << GlobalV::KPAR;
std::cout<<"Error : Too many pools !"<<std::endl;
exit(0);
exit(1);
}

// (1) per process in each stogroup
if(GlobalV::NPROC%GlobalV::NSTOGROUP!=0)
{
std::cout<<"\n Error! NPROC="<<GlobalV::NPROC
<<" must be divided evenly by BNDPAR="<<GlobalV::NSTOGROUP<<std::endl;
exit(0);
exit(1);
}
GlobalV::NPROC_IN_STOGROUP = GlobalV::NPROC/GlobalV::NSTOGROUP;
GlobalV::MY_STOGROUP = int(GlobalV::MY_RANK / GlobalV::NPROC_IN_STOGROUP);
GlobalV::RANK_IN_STOGROUP = GlobalV::MY_RANK%GlobalV::NPROC_IN_STOGROUP;
if (GlobalV::NPROC_IN_STOGROUP < GlobalV::KPAR)
{
std::cout<<"\n Error! NPROC_IN_BNDGROUP=" << GlobalV::NPROC_IN_STOGROUP
std::cout<<"\n Error! NPROC_IN_BNDGROUP=" << GlobalV::NPROC_IN_STOGROUP
<<" is smaller than"<< " KPAR=" << GlobalV::KPAR<<std::endl;
std::cout<<" Please reduce KPAR or reduce BNDPAR"<<std::endl;
exit(0);
exit(1);
}

// (2) per process in each pool
Expand All @@ -370,7 +383,7 @@ void Parallel_Global::divide_pools(void)
GlobalV::MY_POOL = int( (GlobalV::RANK_IN_STOGROUP-GlobalV::NPROC_IN_STOGROUP%GlobalV::KPAR) / GlobalV::NPROC_IN_POOL);
GlobalV::RANK_IN_POOL = (GlobalV::RANK_IN_STOGROUP-GlobalV::NPROC_IN_STOGROUP%GlobalV::KPAR)%GlobalV::NPROC_IN_POOL;
}




Expand Down
6 changes: 3 additions & 3 deletions source/module_cell/test/parallel_kpoints_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class ParaKpoints : public ::testing::TestWithParam<ParaPrepare>
TEST(Parallel_KpointsTest, GatherkvecTest) {
// Initialize Parallel_Kpoints object
Parallel_Kpoints parallel_kpoints;

// Initialize local and global vectors
std::vector<ModuleBase::Vector3<double>> vec_local;
std::vector<ModuleBase::Vector3<double>> vec_global;
Expand Down Expand Up @@ -178,7 +178,7 @@ TEST(Parallel_KpointsTest, GatherkvecTest) {
parallel_kpoints.nkstot_np += 1;
parallel_kpoints.startk_pool[2] = 3;
}

// Call gatherkvec method
parallel_kpoints.gatherkvec(vec_local, vec_global);

Expand Down Expand Up @@ -214,7 +214,7 @@ TEST_P(ParaKpoints,DividePools)
{
std::string output;
testing::internal::CaptureStdout();
EXPECT_EXIT(Parallel_Global::init_pools(),testing::ExitedWithCode(0),"");
EXPECT_EXIT(Parallel_Global::init_pools(),testing::ExitedWithCode(1),"");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output,testing::HasSubstr("Too many pools"));
}
Expand Down