diff --git a/mlir/include/mlir/Analysis/AliasAnalysis.h b/mlir/include/mlir/Analysis/AliasAnalysis.h index f3fce42097a97c..925af24ddd69c4 100644 --- a/mlir/include/mlir/Analysis/AliasAnalysis.h +++ b/mlir/include/mlir/Analysis/AliasAnalysis.h @@ -67,14 +67,106 @@ class AliasResult { /// Returns if this result is a partial alias. bool isPartial() const { return kind == PartialAlias; } - /// Return the internal kind of this alias result. - Kind getKind() const { return kind; } + /// Print this alias result to the provided output stream. + void print(raw_ostream &os) const; private: /// The internal kind of the result. Kind kind; }; +inline raw_ostream &operator<<(raw_ostream &os, const AliasResult &result) { + result.print(os); + return os; +} + +//===----------------------------------------------------------------------===// +// ModRefResult +//===----------------------------------------------------------------------===// + +/// The possible results of whether a memory access modifies or references +/// a memory location. The possible results are: no access at all, a +/// modification, a reference, or both a modification and a reference. +class LLVM_NODISCARD ModRefResult { + /// Note: This is a simplified version of the ModRefResult in + /// `llvm/Analysis/AliasAnalysis.h`, and namely removes the `Must` concept. If + /// this becomes useful/necessary we should add it here. + enum class Kind { + /// The access neither references nor modifies the value stored in memory. + NoModRef = 0, + /// The access may reference the value stored in memory. + Ref = 1, + /// The access may modify the value stored in memory. + Mod = 2, + /// The access may reference and may modify the value stored in memory. + ModRef = Ref | Mod, + }; + +public: + bool operator==(const ModRefResult &rhs) const { return kind == rhs.kind; } + bool operator!=(const ModRefResult &rhs) const { return !(*this == rhs); } + + /// Return a new result that indicates that the memory access neither + /// references nor modifies the value stored in memory. + static ModRefResult getNoModRef() { return Kind::NoModRef; } + + /// Return a new result that indicates that the memory access may reference + /// the value stored in memory. + static ModRefResult getRef() { return Kind::Ref; } + + /// Return a new result that indicates that the memory access may modify the + /// value stored in memory. + static ModRefResult getMod() { return Kind::Mod; } + + /// Return a new result that indicates that the memory access may reference + /// and may modify the value stored in memory. + static ModRefResult getModAndRef() { return Kind::ModRef; } + + /// Returns if this result does not modify or reference memory. + LLVM_NODISCARD bool isNoModRef() const { return kind == Kind::NoModRef; } + + /// Returns if this result modifies memory. + LLVM_NODISCARD bool isMod() const { + return static_cast(kind) & static_cast(Kind::Mod); + } + + /// Returns if this result references memory. + LLVM_NODISCARD bool isRef() const { + return static_cast(kind) & static_cast(Kind::Ref); + } + + /// Returns if this result modifies *or* references memory. + LLVM_NODISCARD bool isModOrRef() const { return kind != Kind::NoModRef; } + + /// Returns if this result modifies *and* references memory. + LLVM_NODISCARD bool isModAndRef() const { return kind == Kind::ModRef; } + + /// Merge this ModRef result with `other` and return the result. + ModRefResult merge(const ModRefResult &other) { + return ModRefResult(static_cast(static_cast(kind) | + static_cast(other.kind))); + } + /// Intersect this ModRef result with `other` and return the result. + ModRefResult intersect(const ModRefResult &other) { + return ModRefResult(static_cast(static_cast(kind) & + static_cast(other.kind))); + } + + /// Print this ModRef result to the provided output stream. + void print(raw_ostream &os) const; + +private: + ModRefResult(Kind kind) : kind(kind) {} + + /// The internal kind of the result. + Kind kind; +}; + +inline raw_ostream &operator<<(raw_ostream &os, const ModRefResult &result) { + result.print(os); + return os; +} + //===----------------------------------------------------------------------===// // AliasAnalysisTraits //===----------------------------------------------------------------------===// @@ -92,6 +184,9 @@ struct AliasAnalysisTraits { /// Given two values, return their aliasing behavior. virtual AliasResult alias(Value lhs, Value rhs) = 0; + + /// Return the modify-reference behavior of `op` on `location`. + virtual ModRefResult getModRef(Operation *op, Value location) = 0; }; /// This class represents the `Model` of an alias analysis implementation @@ -108,6 +203,11 @@ struct AliasAnalysisTraits { return impl.alias(lhs, rhs); } + /// Return the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location) final { + return impl.getModRef(op, location); + } + private: ImplT impl; }; @@ -147,7 +247,12 @@ class AliasAnalysis { /// * AnalysisT(AnalysisT &&) /// * AliasResult alias(Value lhs, Value rhs) /// - This method returns an `AliasResult` that corresponds to the - /// aliasing behavior between `lhs` and `rhs`. + /// aliasing behavior between `lhs` and `rhs`. The conservative "I don't + /// know" result of this method should be MayAlias. + /// * ModRefResult getModRef(Operation *op, Value location) + /// - This method returns a `ModRefResult` that corresponds to the + /// modify-reference behavior of `op` on the given `location`. The + /// conservative "I don't know" result of this method should be ModRef. template void addAnalysisImplementation(AnalysisT &&analysis) { aliasImpls.push_back( @@ -161,6 +266,13 @@ class AliasAnalysis { /// Given two values, return their aliasing behavior. AliasResult alias(Value lhs, Value rhs); + //===--------------------------------------------------------------------===// + // ModRef Queries + //===--------------------------------------------------------------------===// + + /// Return the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location); + private: /// A set of internal alias analysis implementations. SmallVector, 4> aliasImpls; diff --git a/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h b/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h index 45edd2088cdd66..afed185e6c29c2 100644 --- a/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h +++ b/mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h @@ -25,6 +25,9 @@ class LocalAliasAnalysis { public: /// Given two values, return their aliasing behavior. AliasResult alias(Value lhs, Value rhs); + + /// Return the modify-reference behavior of `op` on `location`. + ModRefResult getModRef(Operation *op, Value location); }; } // end namespace mlir diff --git a/mlir/lib/Analysis/AliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis.cpp index 946825e156afc3..2f2b7824952866 100644 --- a/mlir/lib/Analysis/AliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis.cpp @@ -27,6 +27,44 @@ AliasResult AliasResult::merge(AliasResult other) const { return MayAlias; } +void AliasResult::print(raw_ostream &os) const { + switch (kind) { + case Kind::NoAlias: + os << "NoAlias"; + break; + case Kind::MayAlias: + os << "MayAlias"; + break; + case Kind::PartialAlias: + os << "PartialAlias"; + break; + case Kind::MustAlias: + os << "MustAlias"; + break; + } +} + +//===----------------------------------------------------------------------===// +// ModRefResult +//===----------------------------------------------------------------------===// + +void ModRefResult::print(raw_ostream &os) const { + switch (kind) { + case Kind::NoModRef: + os << "NoModRef"; + break; + case Kind::Ref: + os << "Ref"; + break; + case Kind::Mod: + os << "Mod"; + break; + case Kind::ModRef: + os << "ModRef"; + break; + } +} + //===----------------------------------------------------------------------===// // AliasAnalysis //===----------------------------------------------------------------------===// @@ -35,7 +73,6 @@ AliasAnalysis::AliasAnalysis(Operation *op) { addAnalysisImplementation(LocalAliasAnalysis()); } -/// Given the two values, return their aliasing behavior. AliasResult AliasAnalysis::alias(Value lhs, Value rhs) { // Check each of the alias analysis implemenations for an alias result. for (const std::unique_ptr &aliasImpl : aliasImpls) { @@ -45,3 +82,16 @@ AliasResult AliasAnalysis::alias(Value lhs, Value rhs) { } return AliasResult::MayAlias; } + +ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) { + // Compute the mod-ref behavior by refining a top `ModRef` result with each of + // the alias analysis implementations. We early exit at the point where we + // refine down to a `NoModRef`. + ModRefResult result = ModRefResult::getModAndRef(); + for (const std::unique_ptr &aliasImpl : aliasImpls) { + result = result.intersect(aliasImpl->getModRef(op, location)); + if (result.isNoModRef()) + return result; + } + return result; +} diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index 17a9ded8691732..062443a39619b8 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -195,7 +195,7 @@ static void collectUnderlyingAddressValues(Value value, } //===----------------------------------------------------------------------===// -// LocalAliasAnalysis +// LocalAliasAnalysis: alias //===----------------------------------------------------------------------===// /// Given a value, try to get an allocation effect attached to it. If @@ -336,3 +336,56 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) { // We should always have a valid result here. return *result; } + +//===----------------------------------------------------------------------===// +// LocalAliasAnalysis: getModRef +//===----------------------------------------------------------------------===// + +ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) { + // Check to see if this operation relies on nested side effects. + if (op->hasTrait()) { + // TODO: To check recursive operations we need to check all of the nested + // operations, which can result in a quadratic number of queries. We should + // introduce some caching of some kind to help alleviate this, especially as + // this caching could be used in other areas of the codebase (e.g. when + // checking `wouldOpBeTriviallyDead`). + return ModRefResult::getModAndRef(); + } + + // Otherwise, check to see if this operation has a memory effect interface. + MemoryEffectOpInterface interface = dyn_cast(op); + if (!interface) + return ModRefResult::getModAndRef(); + + // Build a ModRefResult by merging the behavior of the effects of this + // operation. + SmallVector effects; + interface.getEffects(effects); + + ModRefResult result = ModRefResult::getNoModRef(); + for (const MemoryEffects::EffectInstance &effect : effects) { + if (isa(effect.getEffect())) + continue; + + // Check for an alias between the effect and our memory location. + // TODO: Add support for checking an alias with a symbol reference. + AliasResult aliasResult = AliasResult::MayAlias; + if (Value effectValue = effect.getValue()) + aliasResult = alias(effectValue, location); + + // If we don't alias, ignore this effect. + if (aliasResult.isNo()) + continue; + + // Merge in the corresponding mod or ref for this effect. + if (isa(effect.getEffect())) { + result = result.merge(ModRefResult::getRef()); + } else { + assert(isa(effect.getEffect())); + result = result.merge(ModRefResult::getMod()); + } + if (result.isModAndRef()) + break; + } + return result; +} diff --git a/mlir/test/Analysis/test-alias-analysis-modref.mlir b/mlir/test/Analysis/test-alias-analysis-modref.mlir new file mode 100644 index 00000000000000..46ac7fbdf5d9f1 --- /dev/null +++ b/mlir/test/Analysis/test-alias-analysis-modref.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt %s -pass-pipeline='func(test-alias-analysis-modref)' -split-input-file -allow-unregistered-dialect 2>&1 | FileCheck %s + +// CHECK-LABEL: Testing : "no_side_effects" +// CHECK: alloc -> func.region0#0: NoModRef +// CHECK: dealloc -> func.region0#0: NoModRef +// CHECK: return -> func.region0#0: NoModRef +func @no_side_effects(%arg: memref<2xf32>) attributes {test.ptr = "func"} { + %1 = memref.alloc() {test.ptr = "alloc"} : memref<8x64xf32> + memref.dealloc %1 {test.ptr = "dealloc"} : memref<8x64xf32> + return {test.ptr = "return"} +} + +// ----- + +// CHECK-LABEL: Testing : "simple" +// CHECK-DAG: store -> alloc#0: Mod +// CHECK-DAG: load -> alloc#0: Ref + +// CHECK-DAG: store -> func.region0#0: NoModRef +// CHECK-DAG: load -> func.region0#0: NoModRef +func @simple(%arg: memref, %value: i32) attributes {test.ptr = "func"} { + %1 = memref.alloca() {test.ptr = "alloc"} : memref + memref.store %value, %1[] {test.ptr = "store"} : memref + %2 = memref.load %1[] {test.ptr = "load"} : memref + return {test.ptr = "return"} +} + +// ----- + +// CHECK-LABEL: Testing : "mayalias" +// CHECK-DAG: store -> func.region0#0: Mod +// CHECK-DAG: load -> func.region0#0: Ref + +// CHECK-DAG: store -> func.region0#1: Mod +// CHECK-DAG: load -> func.region0#1: Ref +func @mayalias(%arg0: memref, %arg1: memref, %value: i32) attributes {test.ptr = "func"} { + memref.store %value, %arg1[] {test.ptr = "store"} : memref + %1 = memref.load %arg1[] {test.ptr = "load"} : memref + return {test.ptr = "return"} +} + +// ----- + +// CHECK-LABEL: Testing : "recursive" +// CHECK-DAG: if -> func.region0#0: ModRef +// CHECK-DAG: if -> func.region0#1: ModRef + +// TODO: This is provably NoModRef, but requires handling recursive side +// effects. +// CHECK-DAG: if -> alloc#0: ModRef +func @recursive(%arg0: memref, %arg1: memref, %cond: i1, %value: i32) attributes {test.ptr = "func"} { + %0 = memref.alloca() {test.ptr = "alloc"} : memref + scf.if %cond { + memref.store %value, %arg0[] : memref + %1 = memref.load %arg0[] : memref + } {test.ptr = "if"} + return {test.ptr = "return"} +} + +// ----- + +// CHECK-LABEL: Testing : "unknown" +// CHECK-DAG: unknown -> func.region0#0: ModRef +func @unknown(%arg0: memref) attributes {test.ptr = "func"} { + "foo.op"() {test.ptr = "unknown"} : () -> () + return +} diff --git a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp index d17a1c1b360a91..c54e5d8ba58230 100644 --- a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp +++ b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp @@ -16,15 +16,38 @@ using namespace mlir; +/// Print a value that is used as an operand of an alias query. +static void printAliasOperand(Operation *op) { + llvm::errs() << op->getAttrOfType("test.ptr").getValue(); +} +static void printAliasOperand(Value value) { + if (BlockArgument arg = value.dyn_cast()) { + Region *region = arg.getParentRegion(); + unsigned parentBlockNumber = + std::distance(region->begin(), arg.getOwner()->getIterator()); + llvm::errs() << region->getParentOp() + ->getAttrOfType("test.ptr") + .getValue() + << ".region" << region->getRegionNumber(); + if (parentBlockNumber != 0) + llvm::errs() << ".block" << parentBlockNumber; + llvm::errs() << "#" << arg.getArgNumber(); + return; + } + OpResult result = value.cast(); + printAliasOperand(result.getOwner()); + llvm::errs() << "#" << result.getResultNumber(); +} + +//===----------------------------------------------------------------------===// +// Testing AliasResult +//===----------------------------------------------------------------------===// + namespace { struct TestAliasAnalysisPass : public PassWrapper> { void runOnOperation() override { - llvm::errs() << "Testing : "; - if (Attribute testName = getOperation()->getAttr("test.name")) - llvm::errs() << testName << "\n"; - else - llvm::errs() << getOperation()->getAttr("sym_name") << "\n"; + llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n"; // Collect all of the values to check for aliasing behavior. AliasAnalysis &aliasAnalysis = getAnalysis(); @@ -49,52 +72,64 @@ struct TestAliasAnalysisPass printAliasOperand(lhs); llvm::errs() << " <-> "; printAliasOperand(rhs); - llvm::errs() << ": "; + llvm::errs() << ": " << result << "\n"; + } +}; +} // end anonymous namespace - switch (result.getKind()) { - case AliasResult::NoAlias: - llvm::errs() << "NoAlias"; - break; - case AliasResult::MayAlias: - llvm::errs() << "MayAlias"; - break; - case AliasResult::PartialAlias: - llvm::errs() << "PartialAlias"; - break; - case AliasResult::MustAlias: - llvm::errs() << "MustAlias"; - break; +//===----------------------------------------------------------------------===// +// Testing ModRefResult +//===----------------------------------------------------------------------===// + +namespace { +struct TestAliasAnalysisModRefPass + : public PassWrapper> { + void runOnOperation() override { + llvm::errs() << "Testing : " << getOperation()->getAttr("sym_name") << "\n"; + + // Collect all of the values to check for aliasing behavior. + AliasAnalysis &aliasAnalysis = getAnalysis(); + SmallVector valsToCheck; + getOperation()->walk([&](Operation *op) { + if (!op->getAttr("test.ptr")) + return; + valsToCheck.append(op->result_begin(), op->result_end()); + for (Region ®ion : op->getRegions()) + for (Block &block : region) + valsToCheck.append(block.args_begin(), block.args_end()); + }); + + // Check for aliasing behavior between each of the values. + for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it) { + getOperation()->walk([&](Operation *op) { + if (!op->getAttr("test.ptr")) + return; + printModRefResult(aliasAnalysis.getModRef(op, *it), op, *it); + }); } - llvm::errs() << "\n"; } - /// Print a value that is used as an operand of an alias query. - void printAliasOperand(Value value) { - if (BlockArgument arg = value.dyn_cast()) { - Region *region = arg.getParentRegion(); - unsigned parentBlockNumber = - std::distance(region->begin(), arg.getOwner()->getIterator()); - llvm::errs() << region->getParentOp() - ->getAttrOfType("test.ptr") - .getValue() - << ".region" << region->getRegionNumber(); - if (parentBlockNumber != 0) - llvm::errs() << ".block" << parentBlockNumber; - llvm::errs() << "#" << arg.getArgNumber(); - return; - } - OpResult result = value.cast(); - llvm::errs() - << result.getOwner()->getAttrOfType("test.ptr").getValue() - << "#" << result.getResultNumber(); + + /// Print the result of an alias query. + void printModRefResult(ModRefResult result, Operation *op, Value location) { + printAliasOperand(op); + llvm::errs() << " -> "; + printAliasOperand(location); + llvm::errs() << ": " << result << "\n"; } }; } // end anonymous namespace +//===----------------------------------------------------------------------===// +// Pass Registration +//===----------------------------------------------------------------------===// + namespace mlir { namespace test { void registerTestAliasAnalysisPass() { - PassRegistration pass("test-alias-analysis", - "Test alias analysis results."); + PassRegistration aliasPass( + "test-alias-analysis", "Test alias analysis results."); + PassRegistration modRefPass( + "test-alias-analysis-modref", "Test alias analysis ModRef results."); } } // namespace test } // namespace mlir