Skip to content

Commit

Permalink
[OpenMP][NFCI] Avoid storing non-constant values in ICV
Browse files Browse the repository at this point in the history
If we store a constant in an ICV it is easier for the optimizer to
propagate it. Since we often use the full block for the thread limit and
the parallel team size, we can instead replace that dynamic value with a
constant that otherwise cannot occur, here 0.
  • Loading branch information
jdoerfert committed Jul 18, 2023
1 parent 88a68de commit f914208
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 10 deletions.
4 changes: 4 additions & 0 deletions openmp/libomptarget/DeviceRTL/include/State.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ struct ICVStateTy {
uint32_t NThreadsVar;
uint32_t LevelVar;
uint32_t ActiveLevelVar;
uint32_t Padding0Val;
uint32_t MaxActiveLevelsVar;
uint32_t RunSchedVar;
uint32_t RunSchedChunkVar;
Expand Down Expand Up @@ -339,6 +340,9 @@ void runAndCheckState(void(Func(void)));

void assumeInitialState(bool IsSPMD);

/// Return the effective parallel team size: the value of the ParallelTeamSize
/// ICV, or the block size if that ICV holds 0.
int getEffectivePTeamSize();

} // namespace state

namespace icv {
Expand Down
1 change: 0 additions & 1 deletion openmp/libomptarget/DeviceRTL/src/Mapping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ uint32_t mapping::getThreadIdInWarp() {

/// Return the thread id within the block, as reported by
/// impl::getThreadIdInBlock().
uint32_t mapping::getThreadIdInBlock() {
uint32_t ThreadIdInBlock = impl::getThreadIdInBlock();
// Sanity check: the id must be smaller than the number of hardware threads
// in the block.
ASSERT(ThreadIdInBlock < impl::getNumHardwareThreadsInBlock(), nullptr);
return ThreadIdInBlock;
}

Expand Down
12 changes: 7 additions & 5 deletions openmp/libomptarget/DeviceRTL/src/Parallelism.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ void __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
ASSERT(state::HasThreadState == false, nullptr);

uint32_t NumThreads = determineNumberOfThreads(num_threads);
uint32_t BlockSize = mapping::getBlockSize();
uint32_t PTeamSize = NumThreads == BlockSize ? 0 : NumThreads;
if (mapping::isSPMDMode()) {
// Avoid the race between the read of the `icv::Level` above and the write
// below by synchronizing all threads here.
Expand All @@ -118,7 +120,7 @@ void __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
// Note that the order here is important. `icv::Level` has to be updated
// last or the other updates will cause a thread specific state to be
// created.
state::ValueRAII ParallelTeamSizeRAII(state::ParallelTeamSize, NumThreads,
state::ValueRAII ParallelTeamSizeRAII(state::ParallelTeamSize, PTeamSize,
1u, TId == 0, ident,
/* ForceTeamState */ true);
state::ValueRAII ActiveLevelRAII(icv::ActiveLevel, 1u, 0u, TId == 0,
Expand All @@ -130,7 +132,7 @@ void __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
// team state properly.
synchronize::threadsAligned(atomic::acq_rel);

state::ParallelTeamSize.assert_eq(NumThreads, ident,
state::ParallelTeamSize.assert_eq(PTeamSize, ident,
/* ForceTeamState */ true);
icv::ActiveLevel.assert_eq(1u, ident, /* ForceTeamState */ true);
icv::Level.assert_eq(1u, ident, /* ForceTeamState */ true);
Expand All @@ -139,7 +141,7 @@ void __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
// assumptions above.
synchronize::threadsAligned(atomic::relaxed);

if (TId < NumThreads)
if (!PTeamSize || TId < PTeamSize)
invokeMicrotask(TId, 0, fn, args, nargs);

// Synchronize all threads at the end of a parallel region.
Expand Down Expand Up @@ -239,7 +241,7 @@ void __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
// Note that the order here is important. `icv::Level` has to be updated
// last or the other updates will cause a thread specific state to be
// created.
state::ValueRAII ParallelTeamSizeRAII(state::ParallelTeamSize, NumThreads,
state::ValueRAII ParallelTeamSizeRAII(state::ParallelTeamSize, PTeamSize,
1u, true, ident,
/* ForceTeamState */ true);
state::ValueRAII ParallelRegionFnRAII(state::ParallelRegionFn, wrapper_fn,
Expand Down Expand Up @@ -272,7 +274,7 @@ __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) {

// Set to true for workers participating in the parallel region.
uint32_t TId = mapping::getThreadIdInBlock();
bool ThreadIsActive = TId < state::ParallelTeamSize;
bool ThreadIsActive = TId < state::getEffectivePTeamSize();
return ThreadIsActive;
}

Expand Down
17 changes: 13 additions & 4 deletions openmp/libomptarget/DeviceRTL/src/State.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ void state::ICVStateTy::assertEqual(const ICVStateTy &Other) const {
}

void state::TeamStateTy::init(bool IsSPMD) {
ICVState.NThreadsVar = mapping::getBlockSize(IsSPMD);
ICVState.NThreadsVar = 0;
ICVState.LevelVar = 0;
ICVState.ActiveLevelVar = 0;
ICVState.Padding0Val = 0;
ICVState.MaxActiveLevelsVar = 1;
ICVState.RunSchedVar = omp_sched_static;
ICVState.RunSchedChunkVar = 1;
Expand Down Expand Up @@ -312,14 +313,22 @@ void state::assumeInitialState(bool IsSPMD) {
ASSERT(mapping::isSPMDMode() == IsSPMD, nullptr);
}

/// Return the parallel team size that is actually in effect.
///
/// A stored ParallelTeamSize of 0 is a placeholder meaning "the whole block";
/// in that case the block size is looked up instead. Any other stored value is
/// returned unchanged.
int state::getEffectivePTeamSize() {
  const int StoredSize = state::ParallelTeamSize;
  if (StoredSize != 0)
    return StoredSize;
  // Placeholder value: resolve it to the current block size.
  return mapping::getBlockSize();
}

extern "C" {
/// OpenMP API entry point. The requested setting is discarded: this
/// implementation does nothing with `V`.
void omp_set_dynamic(int V) {
  (void)V; // Intentionally unused; the call is a no-op.
}

/// OpenMP API entry point. Always reports 0.
int omp_get_dynamic(void) {
  constexpr int DynamicSetting = 0;
  return DynamicSetting;
}

/// OpenMP API entry point: store `V` in the NThreads ICV.
void omp_set_num_threads(int V) {
  icv::NThreads = V;
}

// NOTE(review): two definitions of omp_get_max_threads appear back to back
// here. The one-line form returns icv::NThreads unchanged; the longer form
// substitutes mapping::getBlockSize() whenever the stored value is not
// positive. Presumably these are the old and new versions of a single change
// shown together -- confirm that only one of them is kept.
int omp_get_max_threads(void) { return icv::NThreads; }
int omp_get_max_threads(void) {
int NT = icv::NThreads;
// A stored value that is not positive falls back to the block size.
return NT > 0 ? NT : mapping::getBlockSize();
}

int omp_get_level(void) {
int LevelVar = icv::Level;
Expand Down Expand Up @@ -350,11 +359,11 @@ int omp_get_thread_num(void) {
}

/// OpenMP API entry point: team size for nesting level `Level`.
/// The result is whatever returnValIfLevelIsActive produces when given
/// `Level`, a team size, and 1.
// NOTE(review): two return statements appear here, so only the first can ever
// execute. The first reads state::ParallelTeamSize directly; the second goes
// through state::getEffectivePTeamSize(). Presumably these are the old and new
// lines of a single change shown together -- confirm which one should remain.
int omp_get_team_size(int Level) {
return returnValIfLevelIsActive(Level, state::ParallelTeamSize, 1);
return returnValIfLevelIsActive(Level, state::getEffectivePTeamSize(), 1);
}

/// OpenMP API entry point: number of threads in the current team.
// NOTE(review): two return statements appear here, so only the first can ever
// execute. They are not equivalent:
//  - the first yields 1 only when omp_get_level() is greater than 1 and
//    otherwise reads state::ParallelTeamSize directly;
//  - the second yields 1 at every level other than 1 (including level 0) and
//    otherwise uses state::getEffectivePTeamSize().
// Presumably these are the old and new lines of a single change shown
// together -- confirm which one should remain.
int omp_get_num_threads(void) {
return omp_get_level() > 1 ? 1 : state::ParallelTeamSize;
return omp_get_level() != 1 ? 1 : state::getEffectivePTeamSize();
}

/// OpenMP API entry point: the thread limit is the block size reported by
/// mapping::getBlockSize().
int omp_get_thread_limit(void) {
  const int BlockSize = mapping::getBlockSize();
  return BlockSize;
}
Expand Down

0 comments on commit f914208

Please sign in to comment.