[OpenMP] Move KernelInfoState and AAKernelInfo to OpenMPOpt.h #71878

Open
wants to merge 1 commit into
base: main

nmustakin

The current barrier modeling assumes that a barrier may access any memory, and that assumption is propagated to the kernel. To refine this behavior, the memory effects of all reaching kernels need to be accumulated into the memory effect of the barrier. To make the ReachingKernelEntries of AAKernelInfo available outside OpenMPOpt.cpp, the definitions of KernelInfoState and AAKernelInfo are moved to llvm/include/llvm/Transforms/IPO/OpenMPOpt.h, and the former AAKernelInfo in llvm/lib/Transforms/IPO/OpenMPOpt.cpp is redefined as struct AAKernelInfoImpl : AAKernelInfo.
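
To illustrate the motivation, here is a minimal standalone sketch of accumulating the effects of all reaching kernels into a barrier's effect. It is not code from this patch: `MemEffect`, `Kernel`, and `barrierEffect` are hypothetical stand-ins, and a plain bitmask replaces whatever memory attribute a real follow-up change would query.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a memory-effect lattice: a bitmask of Ref/Mod.
enum MemEffect : uint8_t { NoEffect = 0, Ref = 1, Mod = 2, ModRef = Ref | Mod };

// Hypothetical stand-in for a kernel entry and its known memory effect.
struct Kernel {
  const char *Name;
  MemEffect Effect;
};

// Join the effects of every kernel that can reach the barrier. If the set of
// reaching kernels is not known (Valid == false), fall back to the
// conservative answer that the barrier may read and write anything.
inline MemEffect barrierEffect(const std::vector<const Kernel *> &Reaching,
                               bool Valid) {
  if (!Valid)
    return ModRef;
  uint8_t Acc = NoEffect;
  for (const Kernel *K : Reaching) {
    assert(K && "reaching kernel must not be null");
    Acc |= K->Effect;
  }
  return static_cast<MemEffect>(Acc);
}
```

When the reaching set is unknown, the sketch keeps the conservative "accesses any memory" answer that the description attributes to the current modeling.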

@llvmbot llvmbot added llvm:transforms clang:openmp OpenMP related changes to Clang labels Nov 9, 2023
@llvmbot
Collaborator

llvmbot commented Nov 9, 2023

@llvm/pr-subscribers-llvm-transforms

Author: None (nmustakin)


Patch is 27.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71878.diff

2 Files Affected:

  • (modified) llvm/include/llvm/Transforms/IPO/OpenMPOpt.h (+202)
  • (modified) llvm/lib/Transforms/IPO/OpenMPOpt.cpp (+44-230)
diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 2499c2bbccf4554..049c1458bb9d630 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -12,6 +12,7 @@
 #include "llvm/Analysis/CGSCCPassManager.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/IPO/Attributor.h"
 
 namespace llvm {
 
@@ -62,6 +63,207 @@ class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
   const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
 };
 
+template <typename Ty, bool InsertInvalidates = true>
+struct BooleanStateWithSetVector : public BooleanState {
+  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
+  bool insert(const Ty &Elem) {
+    if (InsertInvalidates)
+      BooleanState::indicatePessimisticFixpoint();
+    return Set.insert(Elem);
+  }
+
+  const Ty &operator[](int Idx) const { return Set[Idx]; }
+  bool operator==(const BooleanStateWithSetVector &RHS) const {
+    return BooleanState::operator==(RHS) && Set == RHS.Set;
+  }
+  bool operator!=(const BooleanStateWithSetVector &RHS) const {
+    return !(*this == RHS);
+  }
+
+  bool empty() const { return Set.empty(); }
+  size_t size() const { return Set.size(); }
+
+  /// "Clamp" this state with \p RHS.
+  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
+    BooleanState::operator^=(RHS);
+    Set.insert(RHS.Set.begin(), RHS.Set.end());
+    return *this;
+  }
+
+private:
+  /// A set to keep track of elements.
+  SetVector<Ty> Set;
+
+public:
+  typename decltype(Set)::iterator begin() { return Set.begin(); }
+  typename decltype(Set)::iterator end() { return Set.end(); }
+  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
+  typename decltype(Set)::const_iterator end() const { return Set.end(); }
+};
+
+template <typename Ty, bool InsertInvalidates = true>
+using BooleanStateWithPtrSetVector =
+    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
+
+struct KernelInfoState : AbstractState {
+  /// Flag to track if we reached a fixpoint.
+  bool IsAtFixpoint = false;
+
+  /// The parallel regions (identified by the outlined parallel functions) that
+  /// can be reached from the associated function.
+  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
+      ReachedKnownParallelRegions;
+
+  /// State to track what parallel region we might reach.
+  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
+
+  /// State to track if we are in SPMD-mode, assumed or known, and why we
+  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
+  /// also be false.
+  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
+
+  /// The __kmpc_target_init call in this kernel, if any. If we find more than
+  /// one we abort as the kernel is malformed.
+  CallBase *KernelInitCB = nullptr;
+
+  /// The constant kernel environment as taken from and passed to
+  /// __kmpc_target_init.
+  ConstantStruct *KernelEnvC = nullptr;
+
+  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
+  /// one we abort as the kernel is malformed.
+  CallBase *KernelDeinitCB = nullptr;
+
+  /// Flag to indicate if the associated function is a kernel entry.
+  bool IsKernelEntry = false;
+
+  /// State to track what kernel entries can reach the associated function.
+  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
+
+  /// State to indicate if we can track parallel level of the associated
+  /// function. We will give up tracking if we encounter unknown caller or the
+  /// caller is __kmpc_parallel_51.
+  BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
+  /// Flag that indicates if the kernel has nested Parallelism
+  bool NestedParallelism = false;
+
+  /// Abstract State interface
+  ///{
+
+  KernelInfoState() = default;
+  KernelInfoState(bool BestState) {
+    if (!BestState)
+      indicatePessimisticFixpoint();
+  }
+
+  /// See AbstractState::isValidState(...)
+  bool isValidState() const override { return true; }
+
+  /// See AbstractState::isAtFixpoint(...)
+  bool isAtFixpoint() const override { return IsAtFixpoint; }
+
+  /// See AbstractState::indicatePessimisticFixpoint(...)
+  ChangeStatus indicatePessimisticFixpoint() override {
+    IsAtFixpoint = true;
+    ParallelLevels.indicatePessimisticFixpoint();
+    ReachingKernelEntries.indicatePessimisticFixpoint();
+    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
+    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
+    NestedParallelism = true;
+    return ChangeStatus::CHANGED;
+  }
+
+  /// See AbstractState::indicateOptimisticFixpoint(...)
+  ChangeStatus indicateOptimisticFixpoint() override {
+    IsAtFixpoint = true;
+    ParallelLevels.indicateOptimisticFixpoint();
+    ReachingKernelEntries.indicateOptimisticFixpoint();
+    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
+    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
+    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// Return the assumed state
+  KernelInfoState &getAssumed() { return *this; }
+  const KernelInfoState &getAssumed() const { return *this; }
+
+  bool operator==(const KernelInfoState &RHS) const {
+    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
+      return false;
+    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
+      return false;
+    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
+      return false;
+    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
+      return false;
+    if (ParallelLevels != RHS.ParallelLevels)
+      return false;
+    if (NestedParallelism != RHS.NestedParallelism)
+      return false;
+    return true;
+  }
+
+  /// Returns true if this kernel contains any OpenMP parallel regions.
+  bool mayContainParallelRegion() {
+    return !ReachedKnownParallelRegions.empty() ||
+           !ReachedUnknownParallelRegions.empty();
+  }
+
+  /// Return empty set as the best state of potential values.
+  static KernelInfoState getBestState() { return KernelInfoState(true); }
+
+  static KernelInfoState getBestState(KernelInfoState &KIS) {
+    return getBestState();
+  }
+
+  /// Return full set as the worst state of potential values.
+  static KernelInfoState getWorstState() { return KernelInfoState(false); }
+
+  /// "Clamp" this state with \p KIS.
+  KernelInfoState operator^=(const KernelInfoState &KIS) {
+    // Do not merge two different _init and _deinit call sites.
+    if (KIS.KernelInitCB) {
+      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelInitCB = KIS.KernelInitCB;
+    }
+    if (KIS.KernelDeinitCB) {
+      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelDeinitCB = KIS.KernelDeinitCB;
+    }
+    if (KIS.KernelEnvC) {
+      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelEnvC = KIS.KernelEnvC;
+    }
+    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
+    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
+    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
+    NestedParallelism |= KIS.NestedParallelism;
+    return *this;
+  }
+
+  KernelInfoState operator&=(const KernelInfoState &KIS) {
+    return (*this ^= KIS);
+  }
+};
+
+struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
+  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
+  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Public getter for ReachingKernelEntries
+  virtual BooleanStateWithPtrSetVector<Function, false>
+  getReachingKernels() = 0;
+};
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_IPO_OPENMPOPT_H
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 5b42f215fb40ca0..4dd32f236c03f3d 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -638,200 +638,6 @@ struct OMPInformationCache : public InformationCache {
   bool OpenMPPostLink = false;
 };
 
-template <typename Ty, bool InsertInvalidates = true>
-struct BooleanStateWithSetVector : public BooleanState {
-  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
-  bool insert(const Ty &Elem) {
-    if (InsertInvalidates)
-      BooleanState::indicatePessimisticFixpoint();
-    return Set.insert(Elem);
-  }
-
-  const Ty &operator[](int Idx) const { return Set[Idx]; }
-  bool operator==(const BooleanStateWithSetVector &RHS) const {
-    return BooleanState::operator==(RHS) && Set == RHS.Set;
-  }
-  bool operator!=(const BooleanStateWithSetVector &RHS) const {
-    return !(*this == RHS);
-  }
-
-  bool empty() const { return Set.empty(); }
-  size_t size() const { return Set.size(); }
-
-  /// "Clamp" this state with \p RHS.
-  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
-    BooleanState::operator^=(RHS);
-    Set.insert(RHS.Set.begin(), RHS.Set.end());
-    return *this;
-  }
-
-private:
-  /// A set to keep track of elements.
-  SetVector<Ty> Set;
-
-public:
-  typename decltype(Set)::iterator begin() { return Set.begin(); }
-  typename decltype(Set)::iterator end() { return Set.end(); }
-  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
-  typename decltype(Set)::const_iterator end() const { return Set.end(); }
-};
-
-template <typename Ty, bool InsertInvalidates = true>
-using BooleanStateWithPtrSetVector =
-    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
-
-struct KernelInfoState : AbstractState {
-  /// Flag to track if we reached a fixpoint.
-  bool IsAtFixpoint = false;
-
-  /// The parallel regions (identified by the outlined parallel functions) that
-  /// can be reached from the associated function.
-  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
-      ReachedKnownParallelRegions;
-
-  /// State to track what parallel region we might reach.
-  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
-
-  /// State to track if we are in SPMD-mode, assumed or know, and why we decided
-  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
-  /// false.
-  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
-
-  /// The __kmpc_target_init call in this kernel, if any. If we find more than
-  /// one we abort as the kernel is malformed.
-  CallBase *KernelInitCB = nullptr;
-
-  /// The constant kernel environement as taken from and passed to
-  /// __kmpc_target_init.
-  ConstantStruct *KernelEnvC = nullptr;
-
-  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
-  /// one we abort as the kernel is malformed.
-  CallBase *KernelDeinitCB = nullptr;
-
-  /// Flag to indicate if the associated function is a kernel entry.
-  bool IsKernelEntry = false;
-
-  /// State to track what kernel entries can reach the associated function.
-  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
-
-  /// State to indicate if we can track parallel level of the associated
-  /// function. We will give up tracking if we encounter unknown caller or the
-  /// caller is __kmpc_parallel_51.
-  BooleanStateWithSetVector<uint8_t> ParallelLevels;
-
-  /// Flag that indicates if the kernel has nested Parallelism
-  bool NestedParallelism = false;
-
-  /// Abstract State interface
-  ///{
-
-  KernelInfoState() = default;
-  KernelInfoState(bool BestState) {
-    if (!BestState)
-      indicatePessimisticFixpoint();
-  }
-
-  /// See AbstractState::isValidState(...)
-  bool isValidState() const override { return true; }
-
-  /// See AbstractState::isAtFixpoint(...)
-  bool isAtFixpoint() const override { return IsAtFixpoint; }
-
-  /// See AbstractState::indicatePessimisticFixpoint(...)
-  ChangeStatus indicatePessimisticFixpoint() override {
-    IsAtFixpoint = true;
-    ParallelLevels.indicatePessimisticFixpoint();
-    ReachingKernelEntries.indicatePessimisticFixpoint();
-    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
-    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
-    NestedParallelism = true;
-    return ChangeStatus::CHANGED;
-  }
-
-  /// See AbstractState::indicateOptimisticFixpoint(...)
-  ChangeStatus indicateOptimisticFixpoint() override {
-    IsAtFixpoint = true;
-    ParallelLevels.indicateOptimisticFixpoint();
-    ReachingKernelEntries.indicateOptimisticFixpoint();
-    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
-    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
-    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
-    return ChangeStatus::UNCHANGED;
-  }
-
-  /// Return the assumed state
-  KernelInfoState &getAssumed() { return *this; }
-  const KernelInfoState &getAssumed() const { return *this; }
-
-  bool operator==(const KernelInfoState &RHS) const {
-    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
-      return false;
-    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
-      return false;
-    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
-      return false;
-    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
-      return false;
-    if (ParallelLevels != RHS.ParallelLevels)
-      return false;
-    if (NestedParallelism != RHS.NestedParallelism)
-      return false;
-    return true;
-  }
-
-  /// Returns true if this kernel contains any OpenMP parallel regions.
-  bool mayContainParallelRegion() {
-    return !ReachedKnownParallelRegions.empty() ||
-           !ReachedUnknownParallelRegions.empty();
-  }
-
-  /// Return empty set as the best state of potential values.
-  static KernelInfoState getBestState() { return KernelInfoState(true); }
-
-  static KernelInfoState getBestState(KernelInfoState &KIS) {
-    return getBestState();
-  }
-
-  /// Return full set as the worst state of potential values.
-  static KernelInfoState getWorstState() { return KernelInfoState(false); }
-
-  /// "Clamp" this state with \p KIS.
-  KernelInfoState operator^=(const KernelInfoState &KIS) {
-    // Do not merge two different _init and _deinit call sites.
-    if (KIS.KernelInitCB) {
-      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelInitCB = KIS.KernelInitCB;
-    }
-    if (KIS.KernelDeinitCB) {
-      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelDeinitCB = KIS.KernelDeinitCB;
-    }
-    if (KIS.KernelEnvC) {
-      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelEnvC = KIS.KernelEnvC;
-    }
-    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
-    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
-    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
-    NestedParallelism |= KIS.NestedParallelism;
-    return *this;
-  }
-
-  KernelInfoState operator&=(const KernelInfoState &KIS) {
-    return (*this ^= KIS);
-  }
-
-  ///}
-};
-
 /// Used to map the values physically (in the IR) stored in an offload
 /// array, to a vector in memory.
 struct OffloadArray {
@@ -3596,9 +3402,9 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
   unsigned SharedMemoryUsed = 0;
 };
 
-struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
-  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
-  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+struct AAKernelInfoImpl : AAKernelInfo {
+  AAKernelInfoImpl(const IRPosition &IRP, Attributor &A)
+      : AAKernelInfo(IRP, A) {}
 
   /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
   /// unknown callees.
@@ -3635,27 +3441,34 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   }
 
   /// Create an abstract attribute biew for the position \p IRP.
-  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
+  static AAKernelInfoImpl &createForPosition(const IRPosition &IRP,
+                                             Attributor &A);
 
   /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAKernelInfo"; }
+  const std::string getName() const override { return "AAKernelInfoImpl"; }
 
   /// See AbstractAttribute::getIdAddr()
   const char *getIdAddr() const override { return &ID; }
 
-  /// This function should return true if the type of the \p AA is AAKernelInfo
+  /// This function should return true if the type of the \p AA is
+  /// AAKernelInfoImpl
   static bool classof(const AbstractAttribute *AA) {
     return (AA->getIdAddr() == &ID);
   }
 
   static const char ID;
+
+  /// Return the ReachingKernelEntries
+  BooleanStateWithPtrSetVector<Function, false> getReachingKernels() override {
+    return ReachingKernelEntries;
+  }
 };
 
 /// The function kernel info abstract attribute, basically, what can we say
 /// about a function with regards to the KernelInfoState.
-struct AAKernelInfoFunction : AAKernelInfo {
+struct AAKernelInfoFunction : AAKernelInfoImpl {
   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
-      : AAKernelInfo(IRP, A) {}
+      : AAKernelInfoImpl(IRP, A) {}
 
   SmallPtrSet<Instruction *, 4> GuardedInstructions;
 
@@ -3815,7 +3628,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     };
 
     // Add a dependence to ensure updates if the state changes.
-    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
+    auto AddDependence = [](Attributor &A, const AAKernelInfoImpl *KI,
                             const AbstractAttribute *QueryingAA) {
       if (QueryingAA) {
         A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
@@ -4119,10 +3932,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
 
     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
       BasicBlock *BB = GuardedI->getParent();
-      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
+      auto *CalleeAA = A.lookupAAFor<AAKernelInfoImpl>(
           IRPosition::function(*GuardedI->getFunction()), nullptr,
           DepClassTy::NONE);
-      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
+      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfoImpl");
       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
       // Continue if instruction is already guarded.
       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
@@ -4724,7 +4537,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
           // we cannot fix the internal spmd-zation state either.
           int SPMD = 0, Generic = 0;
           for (auto *Kernel : ReachingKernelEntries) {
-            auto *CBAA = A.getAAFor<AAKernelInfo>(
+            auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
                 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
             if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
                 CBAA->SPMDCompatibilityTracker.isAssumed())
@@ -4745,7 +4558,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     bool AllSPMDStatesWereFixed = true;
     auto CheckCallInst = [&](Instruction &I) {
       auto &CB = cast<CallBase>(I);
-      auto *CBAA = A.getAAFor<AAKernelInfo>(
+      auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
       if...
[truncated]
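
The semantics of the BooleanStateWithSetVector helper moved in the diff above can be sketched without LLVM headers. This is a simplified model under stated assumptions: `SimpleBoolState` and `SimpleStateWithSet` are hypothetical names, only the assumed bit of `BooleanState` is modeled, and a duplicate-free `std::vector` stands in for `llvm::SetVector`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-in for llvm::BooleanState: only the assumed bit is modeled.
struct SimpleBoolState {
  bool Assumed = true;
  void indicatePessimisticFixpoint() { Assumed = false; }
  bool isValidState() const { return Assumed; }
};

// Simplified stand-in for BooleanStateWithSetVector. A std::vector kept free
// of duplicates plays the role of an insertion-ordered set.
template <typename Ty, bool InsertInvalidates = true>
struct SimpleStateWithSet : SimpleBoolState {
  // Returns true if Elem was newly inserted. When InsertInvalidates is set,
  // any insertion also gives up the optimistic boolean state.
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
      indicatePessimisticFixpoint();
    if (contains(Elem))
      return false;
    Set.push_back(Elem);
    return true;
  }
  bool contains(const Ty &Elem) const {
    return std::find(Set.begin(), Set.end(), Elem) != Set.end();
  }
  const Ty &operator[](std::size_t Idx) const {
    assert(Idx < Set.size() && "index out of range");
    return Set[Idx];
  }
  std::size_t size() const { return Set.size(); }

  // "Clamp": the result is valid only if both sides are, and the element
  // sets are unioned without triggering the invalidating insert.
  SimpleStateWithSet &operator^=(const SimpleStateWithSet &RHS) {
    Assumed = Assumed && RHS.Assumed;
    for (const Ty &Elem : RHS.Set)
      if (!contains(Elem))
        Set.push_back(Elem);
    return *this;
  }

private:
  std::vector<Ty> Set;
};
```

With `InsertInvalidates = false`, as used for ReachingKernelEntries in the patch, inserting elements leaves the boolean state untouched; with the default `true`, the first insertion already makes the state pessimistic.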

@@ -5366,7 +5180,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
   ChangeStatus foldParallelLevel(Attributor &A) {
     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
 
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
Member

Don't change it here.
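
One way to read this request is that call sites can keep naming the public type while the concrete implementation stays private. The sketch below shows that split in plain C++. All names (`KernelInfoBase`, `detail::KernelInfoImpl`, `makeKernelInfo`, `countReachingKernels`) are hypothetical, and strings stand in for `Function *` entries; it is an illustration of the pattern, not the Attributor API.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical public interface, standing in for AAKernelInfo in the header.
struct KernelInfoBase {
  virtual ~KernelInfoBase() = default;
  // Mirrors the by-value getter declared in the patch.
  virtual std::vector<std::string> getReachingKernels() = 0;
};

namespace detail {
// Hypothetical implementation, standing in for AAKernelInfoImpl in the .cpp.
struct KernelInfoImpl : KernelInfoBase {
  std::vector<std::string> ReachingKernelEntries;
  std::vector<std::string> getReachingKernels() override {
    return ReachingKernelEntries;
  }
};
} // namespace detail

// Factory that hides the concrete type behind the public one.
inline std::unique_ptr<KernelInfoBase>
makeKernelInfo(std::vector<std::string> Kernels) {
  auto Impl = std::make_unique<detail::KernelInfoImpl>();
  Impl->ReachingKernelEntries = std::move(Kernels);
  return Impl;
}

// A caller that only names the public type; the virtual call dispatches to
// the hidden implementation.
inline std::size_t countReachingKernels(KernelInfoBase *KI) {
  assert(KI != nullptr && "expected a kernel info object");
  return KI->getReachingKernels().size();
}
```

Because the getter returns by value, each call copies the set; returning a const reference would avoid the copy if callers only read it.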

         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
         !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
       return indicatePessimisticFixpoint();
 
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
-      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
-                                          DepClassTy::REQUIRED);
+      auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
Member

Don't change it here.

// callee.
if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
const IRPosition &FnPos = IRPosition::function(*F);
auto *FnAA =
A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
Member

Don't change it here.

@@ -5054,12 +4868,12 @@ struct AAKernelInfoCallSite : AAKernelInfo {
auto CheckCallee = [&](Function *F, int NumCallees) {
const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);

// If F is not a runtime function, propagate the AAKernelInfo of the
Member

Don't change it here.

-  AAKernelInfo *AA = nullptr;
+AAKernelInfoImpl &AAKernelInfoImpl::createForPosition(const IRPosition &IRP,
+                                                      Attributor &A) {
+  AAKernelInfoImpl *AA = nullptr;
Member

Don't change it here.


   /// See AbstractAttribute::getIdAddr()
   const char *getIdAddr() const override { return &ID; }
 
-  /// This function should return true if the type of the \p AA is AAKernelInfo
+  /// This function should return true if the type of the \p AA is
+  /// AAKernelInfoImpl
   static bool classof(const AbstractAttribute *AA) {
     return (AA->getIdAddr() == &ID);
   }
Member

This needs to go to AAKernelInfo.
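
A plausible reason for this request is that the ID-based type check has to live on the public type so that code outside OpenMPOpt.cpp can identify the attribute. The following standalone sketch shows the address-of-ID pattern with the check on the public type. `AttrBase`, `KernelInfo`, `KernelInfoImpl`, `OtherAttr`, and `asKernelInfo` are hypothetical stand-ins, not the LLVM classes.

```cpp
#include <cassert>

// Hypothetical root, standing in for llvm::AbstractAttribute.
struct AttrBase {
  virtual ~AttrBase() = default;
  virtual const char *getIdAddr() const = 0;
};

// Public type: owns the unique ID and the classof check.
struct KernelInfo : AttrBase {
  const char *getIdAddr() const override { return &ID; }
  static bool classof(const AttrBase *AA) { return AA->getIdAddr() == &ID; }
  static const char ID;
};
const char KernelInfo::ID = 0;

// Implementation type: inherits the ID, so it is recognized as KernelInfo.
struct KernelInfoImpl : KernelInfo {};

// Unrelated attribute with its own ID.
struct OtherAttr : AttrBase {
  const char *getIdAddr() const override { return &ID; }
  static const char ID;
};
const char OtherAttr::ID = 0;

// Checked downcast in the spirit of llvm::dyn_cast: returns nullptr when the
// attribute is not a KernelInfo.
inline const KernelInfo *asKernelInfo(const AttrBase *AA) {
  assert(AA != nullptr && "expected an attribute");
  return KernelInfo::classof(AA) ? static_cast<const KernelInfo *>(AA)
                                 : nullptr;
}
```

Only the address of `ID` matters, not its value, which is why each attribute type defines its own static member.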


   /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAKernelInfo"; }
+  const std::string getName() const override { return "AAKernelInfoImpl"; }
 
   /// See AbstractAttribute::getIdAddr()
   const char *getIdAddr() const override { return &ID; }
Member

This needs to go to AAKernelInfo.


   /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAKernelInfo"; }
+  const std::string getName() const override { return "AAKernelInfoImpl"; }
Member

This needs to go to AAKernelInfo.

@@ -3635,27 +3441,34 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   }
 
   /// Create an abstract attribute biew for the position \p IRP.
-  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
+  static AAKernelInfoImpl &createForPosition(const IRPosition &IRP,
Member

This needs to go to AAKernelInfo.


   /// See AbstractAttribute::getIdAddr()
   const char *getIdAddr() const override { return &ID; }
 
-  /// This function should return true if the type of the \p AA is AAKernelInfo
+  /// This function should return true if the type of the \p AA is
+  /// AAKernelInfoImpl
   static bool classof(const AbstractAttribute *AA) {
     return (AA->getIdAddr() == &ID);
   }
 
   static const char ID;
Member

This needs to go to AAKernelInfo.


github-actions bot commented Nov 13, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@doru1004
Contributor

To update that behavior, we need to accumulate the memory effects of all reaching kernels into the memory effect of the barrier.

Can you explain what you mean by updating the behavior? Since you're not marking the patch as NFC I am guessing you are also performing a functional change to the code, you're not just relocating the classes.

@jdoerfert
Member

To update that behavior, we need to accumulate the memory effects of all reaching kernels into the memory effect of the barrier.

Can you explain what you mean by updating the behavior? Since you're not marking the patch as NFC I am guessing you are also performing a functional change to the code, you're not just relocating the classes.

This commit is NFC, as far as I can tell. The commit message is the justification for this commit, i.e., why we need to move this around (make it public).

…/IPO/OpenMPOpt.h

Move functions to AAKernelInfo

Fix formatting issues

Update OpenMPOpt.cpp

Update OpenMPOpt.cpp
Labels
clang:openmp OpenMP related changes to Clang llvm:transforms