Skip to content

Commit

Permalink
[OpenMP] Make OpenMPOpt aware of the OpenMP runtime's status
Browse files Browse the repository at this point in the history
The `OpenMPOpt` pass contains optimizations that generate new calls into
the OpenMP runtime. This causes problems if we are in a state where the
runtime has already been linked statically. Generating these new calls
will result in them never being resolved. We should indicate if we are
in a "post-link" LTO phase and prevent OpenMPOpt from generating new
runtime calls.

Generally, it's not desirable for passes to maintain state about the
context in which they're called. But this is the only reasonable
solution to static linking when we have a pass that generates new
runtime calls.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D142646

(cherry picked from commit 0bdde9d)
  • Loading branch information
jhuber6 authored and tstellar committed Jan 28, 2023
1 parent 9833a55 commit c0e53ac
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 8 deletions.
12 changes: 12 additions & 0 deletions llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
Expand Up @@ -37,13 +37,25 @@ KernelSet getDeviceKernels(Module &M);
/// Module-level OpenMP optimization pass.
///
/// Each instance records the LTO phase it was scheduled for. The phase is
/// fixed at construction time and cannot change afterwards.
class OpenMPOptPass : public PassInfoMixin<OpenMPOptPass> {
public:
  /// Create a pass that is not tied to any LTO phase; the phase stays at its
  /// in-class default, ThinOrFullLTOPhase::None.
  OpenMPOptPass() = default;

  /// Create a pass for the given \p LTOPhase.
  OpenMPOptPass(ThinOrFullLTOPhase LTOPhase) : LTOPhase(LTOPhase) {}

  /// Run the pass over module \p M using the analyses provided by \p AM.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);

private:
  /// LTO phase this pass instance was created for.
  const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
};

/// CGSCC-level variant of the OpenMP optimization pass.
///
/// Each instance records the LTO phase it was scheduled for. The phase is
/// fixed at construction time and cannot change afterwards.
class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
public:
  /// Create a pass that is not tied to any LTO phase; the phase stays at its
  /// in-class default, ThinOrFullLTOPhase::None.
  OpenMPOptCGSCCPass() = default;

  /// Create a pass for the given \p LTOPhase.
  OpenMPOptCGSCCPass(ThinOrFullLTOPhase LTOPhase) : LTOPhase(LTOPhase) {}

  /// Run the pass over the SCC \p C of the lazy call graph \p CG, using the
  /// analyses provided by \p AM and the update record \p UR.
  PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                        LazyCallGraph &CG, CGSCCUpdateResult &UR);

private:
  /// LTO phase this pass instance was created for.
  const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
};

} // end namespace llvm
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Passes/PassBuilderPipelines.cpp
Expand Up @@ -1604,7 +1604,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
}

// Try to run OpenMP optimizations, quick no-op if no OpenMP metadata present.
MPM.addPass(OpenMPOptPass());
MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink));

// Remove unused virtual tables to improve the quality of code generated by
// whole-program devirtualization and bitset lowering.
Expand Down Expand Up @@ -1811,7 +1811,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
addVectorPasses(Level, MainFPM, /* IsFullLTO */ true);

// Run the OpenMPOpt CGSCC pass again late.
MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(OpenMPOptCGSCCPass()));
MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(
OpenMPOptCGSCCPass(ThinOrFullLTOPhase::FullLTOPostLink)));

invokePeepholeEPCallbacks(MainFPM, Level);
MainFPM.addPass(JumpThreadingPass());
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Expand Up @@ -44,6 +44,7 @@ MODULE_PASS("always-inline", AlwaysInlinerPass())
MODULE_PASS("attributor", AttributorPass())
MODULE_PASS("annotation2metadata", Annotation2MetadataPass())
MODULE_PASS("openmp-opt", OpenMPOptPass())
MODULE_PASS("openmp-opt-postlink", OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink))
MODULE_PASS("called-value-propagation", CalledValuePropagationPass())
MODULE_PASS("canonicalize-aliases", CanonicalizeAliasesPass())
MODULE_PASS("cg-profile", CGProfilePass())
Expand Down
53 changes: 47 additions & 6 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Expand Up @@ -188,9 +188,9 @@ struct AAICVTracker;
struct OMPInformationCache : public InformationCache {
OMPInformationCache(Module &M, AnalysisGetter &AG,
BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
KernelSet &Kernels)
KernelSet &Kernels, bool OpenMPPostLink)
: InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
Kernels(Kernels) {
Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) {

OMPBuilder.initialize();
initializeRuntimeFunctions(M);
Expand Down Expand Up @@ -448,6 +448,24 @@ struct OMPInformationCache : public InformationCache {
CI->setCallingConv(Fn->getCallingConv());
}

/// Helper function to determine if it's legal to create calls to all of the
/// runtime functions listed in \p Fns.
///
/// \returns true if OpenMPPostLink is not set, or if every function in
/// \p Fns has a known Declaration that is an actual definition; false
/// otherwise.
bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
  // We can always emit calls if we haven't yet linked in the runtime.
  if (!OpenMPPostLink)
    return true;

  // Once the runtime has already been linked in we cannot emit calls to any
  // undefined functions. A function with no recorded Declaration is treated
  // as undefined too: calling it would presumably require creating a fresh
  // declaration that nothing can resolve after linking.
  for (RuntimeFunction Fn : Fns) {
    RuntimeFunctionInfo &RFI = RFIs[Fn];

    if (!RFI.Declaration || RFI.Declaration->isDeclaration())
      return false;
  }
  return true;
}

/// Helper to initialize all runtime function information for those defined
/// in OpenMPKinds.def.
void initializeRuntimeFunctions(Module &M) {
Expand Down Expand Up @@ -523,6 +541,9 @@ struct OMPInformationCache : public InformationCache {

/// Collection of known OpenMP runtime functions.
DenseSet<const Function *> RTLFunctions;

/// Indicates if we have already linked in the OpenMP device library.
bool OpenMPPostLink = false;
};

template <typename Ty, bool InsertInvalidates = true>
Expand Down Expand Up @@ -1412,7 +1433,10 @@ struct OpenMPOpt {
Changed |= WasSplit;
return WasSplit;
};
RFI.foreachUse(SCC, SplitMemTransfers);
if (OMPInfoCache.runtimeFnsAvailable(
{OMPRTL___tgt_target_data_begin_mapper_issue,
OMPRTL___tgt_target_data_begin_mapper_wait}))
RFI.foreachUse(SCC, SplitMemTransfers);

return Changed;
}
Expand Down Expand Up @@ -3914,6 +3938,12 @@ struct AAKernelInfoFunction : AAKernelInfo {
bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

// We cannot change to SPMD mode if the runtime functions aren't available.
if (!OMPInfoCache.runtimeFnsAvailable(
{OMPRTL___kmpc_get_hardware_thread_id_in_block,
OMPRTL___kmpc_barrier_simple_spmd}))
return false;

if (!SPMDCompatibilityTracker.isAssumed()) {
for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
if (!NonCompatibleI)
Expand Down Expand Up @@ -4021,6 +4051,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
if (!ReachedKnownParallelRegions.isValidState())
return ChangeStatus::UNCHANGED;

auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
if (!OMPInfoCache.runtimeFnsAvailable(
{OMPRTL___kmpc_get_hardware_num_threads_in_block,
OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
return ChangeStatus::UNCHANGED;

const int InitModeArgNo = 1;
const int InitUseStateMachineArgNo = 2;

Expand Down Expand Up @@ -4167,7 +4204,6 @@ struct AAKernelInfoFunction : AAKernelInfo {
BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);

Module &M = *Kernel->getParent();
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
FunctionCallee BlockHwSizeFn =
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
Expand Down Expand Up @@ -5343,7 +5379,10 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
BumpPtrAllocator Allocator;
CallGraphUpdater CGUpdater;

OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels);
bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels,
PostLink);

unsigned MaxFixpointIterations =
(isOpenMPDevice(M)) ? SetFixpointIterations : 32;
Expand Down Expand Up @@ -5417,9 +5456,11 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
CallGraphUpdater CGUpdater;
CGUpdater.initialize(CG, C, AM, UR);

bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
SetVector<Function *> Functions(SCC.begin(), SCC.end());
OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
/*CGSCC*/ &Functions, Kernels);
/*CGSCC*/ &Functions, Kernels, PostLink);

unsigned MaxFixpointIterations =
(isOpenMPDevice(M)) ? SetFixpointIterations : 32;
Expand Down
2 changes: 2 additions & 0 deletions llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll
Expand Up @@ -2,7 +2,9 @@
; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=AMDGPU
; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=NVPTX
; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -openmp-opt-disable-state-machine-rewrite -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=AMDGPU
; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=AMDGPU
; RUN: opt --mtriple=nvptx64-- -openmp-opt-disable-state-machine-rewrite -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=NVPTX
; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=NVPTX

;; void p0(void);
;; void p1(void);
Expand Down
2 changes: 2 additions & 0 deletions llvm/test/Transforms/OpenMP/spmdization.ll
Expand Up @@ -2,7 +2,9 @@
; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt < %s | FileCheck %s --check-prefixes=AMDGPU
; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt < %s | FileCheck %s --check-prefixes=NVPTX
; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt -openmp-opt-disable-spmdization < %s | FileCheck %s --check-prefix=AMDGPU-DISABLED
; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=AMDGPU-DISABLED
; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt -openmp-opt-disable-spmdization < %s | FileCheck %s --check-prefix=NVPTX-DISABLED
; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=NVPTX-DISABLED

;; void unknown(void);
;; void spmd_amenable(void) __attribute__((assume("ompx_spmd_amenable")));
Expand Down

0 comments on commit c0e53ac

Please sign in to comment.