Skip to content

Commit d47dd11

Browse files
committed
[mlir] Add support for querying the ModRef behavior from the AliasAnalysis class
This allows for checking if a given operation may modify/reference/or both a given value. Right now this API is limited to Value based memory locations, but we should expand this to include attribute based values at some point. This is left for future work because the rest of the AliasAnalysis API also has this restriction. Differential Revision: https://reviews.llvm.org/D101673
1 parent b3ceffd commit d47dd11

File tree

6 files changed

+366
-46
lines changed

6 files changed

+366
-46
lines changed

mlir/include/mlir/Analysis/AliasAnalysis.h

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,106 @@ class AliasResult {
6767
/// Returns if this result is a partial alias.
6868
bool isPartial() const { return kind == PartialAlias; }
6969

70-
/// Return the internal kind of this alias result.
71-
Kind getKind() const { return kind; }
70+
/// Print this alias result to the provided output stream.
71+
void print(raw_ostream &os) const;
7272

7373
private:
7474
/// The internal kind of the result.
7575
Kind kind;
7676
};
7777

78+
inline raw_ostream &operator<<(raw_ostream &os, const AliasResult &result) {
79+
result.print(os);
80+
return os;
81+
}
82+
83+
//===----------------------------------------------------------------------===//
84+
// ModRefResult
85+
//===----------------------------------------------------------------------===//
86+
87+
/// The possible results of whether a memory access modifies or references
88+
/// a memory location. The possible results are: no access at all, a
89+
/// modification, a reference, or both a modification and a reference.
90+
class LLVM_NODISCARD ModRefResult {
91+
/// Note: This is a simplified version of the ModRefResult in
92+
/// `llvm/Analysis/AliasAnalysis.h`, and namely removes the `Must` concept. If
93+
/// this becomes useful/necessary we should add it here.
94+
enum class Kind {
95+
/// The access neither references nor modifies the value stored in memory.
96+
NoModRef = 0,
97+
/// The access may reference the value stored in memory.
98+
Ref = 1,
99+
/// The access may modify the value stored in memory.
100+
Mod = 2,
101+
/// The access may reference and may modify the value stored in memory.
102+
ModRef = Ref | Mod,
103+
};
104+
105+
public:
106+
bool operator==(const ModRefResult &rhs) const { return kind == rhs.kind; }
107+
bool operator!=(const ModRefResult &rhs) const { return !(*this == rhs); }
108+
109+
/// Return a new result that indicates that the memory access neither
110+
/// references nor modifies the value stored in memory.
111+
static ModRefResult getNoModRef() { return Kind::NoModRef; }
112+
113+
/// Return a new result that indicates that the memory access may reference
114+
/// the value stored in memory.
115+
static ModRefResult getRef() { return Kind::Ref; }
116+
117+
/// Return a new result that indicates that the memory access may modify the
118+
/// value stored in memory.
119+
static ModRefResult getMod() { return Kind::Mod; }
120+
121+
/// Return a new result that indicates that the memory access may reference
122+
/// and may modify the value stored in memory.
123+
static ModRefResult getModAndRef() { return Kind::ModRef; }
124+
125+
/// Returns if this result does not modify or reference memory.
126+
LLVM_NODISCARD bool isNoModRef() const { return kind == Kind::NoModRef; }
127+
128+
/// Returns if this result modifies memory.
129+
LLVM_NODISCARD bool isMod() const {
130+
return static_cast<int>(kind) & static_cast<int>(Kind::Mod);
131+
}
132+
133+
/// Returns if this result references memory.
134+
LLVM_NODISCARD bool isRef() const {
135+
return static_cast<int>(kind) & static_cast<int>(Kind::Ref);
136+
}
137+
138+
/// Returns if this result modifies *or* references memory.
139+
LLVM_NODISCARD bool isModOrRef() const { return kind != Kind::NoModRef; }
140+
141+
/// Returns if this result modifies *and* references memory.
142+
LLVM_NODISCARD bool isModAndRef() const { return kind == Kind::ModRef; }
143+
144+
/// Merge this ModRef result with `other` and return the result.
145+
ModRefResult merge(const ModRefResult &other) {
146+
return ModRefResult(static_cast<Kind>(static_cast<int>(kind) |
147+
static_cast<int>(other.kind)));
148+
}
149+
/// Intersect this ModRef result with `other` and return the result.
150+
ModRefResult intersect(const ModRefResult &other) {
151+
return ModRefResult(static_cast<Kind>(static_cast<int>(kind) &
152+
static_cast<int>(other.kind)));
153+
}
154+
155+
/// Print this ModRef result to the provided output stream.
156+
void print(raw_ostream &os) const;
157+
158+
private:
159+
ModRefResult(Kind kind) : kind(kind) {}
160+
161+
/// The internal kind of the result.
162+
Kind kind;
163+
};
164+
165+
inline raw_ostream &operator<<(raw_ostream &os, const ModRefResult &result) {
166+
result.print(os);
167+
return os;
168+
}
169+
78170
//===----------------------------------------------------------------------===//
79171
// AliasAnalysisTraits
80172
//===----------------------------------------------------------------------===//
@@ -92,6 +184,9 @@ struct AliasAnalysisTraits {
92184

93185
/// Given two values, return their aliasing behavior.
94186
virtual AliasResult alias(Value lhs, Value rhs) = 0;
187+
188+
/// Return the modify-reference behavior of `op` on `location`.
189+
virtual ModRefResult getModRef(Operation *op, Value location) = 0;
95190
};
96191

97192
/// This class represents the `Model` of an alias analysis implementation
@@ -108,6 +203,11 @@ struct AliasAnalysisTraits {
108203
return impl.alias(lhs, rhs);
109204
}
110205

206+
/// Return the modify-reference behavior of `op` on `location`.
207+
ModRefResult getModRef(Operation *op, Value location) final {
208+
return impl.getModRef(op, location);
209+
}
210+
111211
private:
112212
ImplT impl;
113213
};
@@ -147,7 +247,12 @@ class AliasAnalysis {
147247
/// * AnalysisT(AnalysisT &&)
148248
/// * AliasResult alias(Value lhs, Value rhs)
149249
/// - This method returns an `AliasResult` that corresponds to the
150-
/// aliasing behavior between `lhs` and `rhs`.
250+
/// aliasing behavior between `lhs` and `rhs`. The conservative "I don't
251+
/// know" result of this method should be MayAlias.
252+
/// * ModRefResult getModRef(Operation *op, Value location)
253+
/// - This method returns a `ModRefResult` that corresponds to the
254+
/// modify-reference behavior of `op` on the given `location`. The
255+
/// conservative "I don't know" result of this method should be ModRef.
151256
template <typename AnalysisT>
152257
void addAnalysisImplementation(AnalysisT &&analysis) {
153258
aliasImpls.push_back(
@@ -161,6 +266,13 @@ class AliasAnalysis {
161266
/// Given two values, return their aliasing behavior.
162267
AliasResult alias(Value lhs, Value rhs);
163268

269+
//===--------------------------------------------------------------------===//
270+
// ModRef Queries
271+
//===--------------------------------------------------------------------===//
272+
273+
/// Return the modify-reference behavior of `op` on `location`.
274+
ModRefResult getModRef(Operation *op, Value location);
275+
164276
private:
165277
/// A set of internal alias analysis implementations.
166278
SmallVector<std::unique_ptr<Concept>, 4> aliasImpls;

mlir/include/mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ class LocalAliasAnalysis {
2525
public:
2626
/// Given two values, return their aliasing behavior.
2727
AliasResult alias(Value lhs, Value rhs);
28+
29+
/// Return the modify-reference behavior of `op` on `location`.
30+
ModRefResult getModRef(Operation *op, Value location);
2831
};
2932
} // end namespace mlir
3033

mlir/lib/Analysis/AliasAnalysis.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,44 @@ AliasResult AliasResult::merge(AliasResult other) const {
2727
return MayAlias;
2828
}
2929

30+
void AliasResult::print(raw_ostream &os) const {
31+
switch (kind) {
32+
case Kind::NoAlias:
33+
os << "NoAlias";
34+
break;
35+
case Kind::MayAlias:
36+
os << "MayAlias";
37+
break;
38+
case Kind::PartialAlias:
39+
os << "PartialAlias";
40+
break;
41+
case Kind::MustAlias:
42+
os << "MustAlias";
43+
break;
44+
}
45+
}
46+
47+
//===----------------------------------------------------------------------===//
48+
// ModRefResult
49+
//===----------------------------------------------------------------------===//
50+
51+
void ModRefResult::print(raw_ostream &os) const {
52+
switch (kind) {
53+
case Kind::NoModRef:
54+
os << "NoModRef";
55+
break;
56+
case Kind::Ref:
57+
os << "Ref";
58+
break;
59+
case Kind::Mod:
60+
os << "Mod";
61+
break;
62+
case Kind::ModRef:
63+
os << "ModRef";
64+
break;
65+
}
66+
}
67+
3068
//===----------------------------------------------------------------------===//
3169
// AliasAnalysis
3270
//===----------------------------------------------------------------------===//
@@ -35,7 +73,6 @@ AliasAnalysis::AliasAnalysis(Operation *op) {
3573
addAnalysisImplementation(LocalAliasAnalysis());
3674
}
3775

38-
/// Given the two values, return their aliasing behavior.
3976
AliasResult AliasAnalysis::alias(Value lhs, Value rhs) {
4077
// Check each of the alias analysis implemenations for an alias result.
4178
for (const std::unique_ptr<Concept> &aliasImpl : aliasImpls) {
@@ -45,3 +82,16 @@ AliasResult AliasAnalysis::alias(Value lhs, Value rhs) {
4582
}
4683
return AliasResult::MayAlias;
4784
}
85+
86+
ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) {
87+
// Compute the mod-ref behavior by refining a top `ModRef` result with each of
88+
// the alias analysis implementations. We early exit at the point where we
89+
// refine down to a `NoModRef`.
90+
ModRefResult result = ModRefResult::getModAndRef();
91+
for (const std::unique_ptr<Concept> &aliasImpl : aliasImpls) {
92+
result = result.intersect(aliasImpl->getModRef(op, location));
93+
if (result.isNoModRef())
94+
return result;
95+
}
96+
return result;
97+
}

mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ static void collectUnderlyingAddressValues(Value value,
195195
}
196196

197197
//===----------------------------------------------------------------------===//
198-
// LocalAliasAnalysis
198+
// LocalAliasAnalysis: alias
199199
//===----------------------------------------------------------------------===//
200200

201201
/// Given a value, try to get an allocation effect attached to it. If
@@ -336,3 +336,56 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
336336
// We should always have a valid result here.
337337
return *result;
338338
}
339+
340+
//===----------------------------------------------------------------------===//
341+
// LocalAliasAnalysis: getModRef
342+
//===----------------------------------------------------------------------===//
343+
344+
ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
345+
// Check to see if this operation relies on nested side effects.
346+
if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
347+
// TODO: To check recursive operations we need to check all of the nested
348+
// operations, which can result in a quadratic number of queries. We should
349+
// introduce some caching of some kind to help alleviate this, especially as
350+
// this caching could be used in other areas of the codebase (e.g. when
351+
// checking `wouldOpBeTriviallyDead`).
352+
return ModRefResult::getModAndRef();
353+
}
354+
355+
// Otherwise, check to see if this operation has a memory effect interface.
356+
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
357+
if (!interface)
358+
return ModRefResult::getModAndRef();
359+
360+
// Build a ModRefResult by merging the behavior of the effects of this
361+
// operation.
362+
SmallVector<MemoryEffects::EffectInstance> effects;
363+
interface.getEffects(effects);
364+
365+
ModRefResult result = ModRefResult::getNoModRef();
366+
for (const MemoryEffects::EffectInstance &effect : effects) {
367+
if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect()))
368+
continue;
369+
370+
// Check for an alias between the effect and our memory location.
371+
// TODO: Add support for checking an alias with a symbol reference.
372+
AliasResult aliasResult = AliasResult::MayAlias;
373+
if (Value effectValue = effect.getValue())
374+
aliasResult = alias(effectValue, location);
375+
376+
// If we don't alias, ignore this effect.
377+
if (aliasResult.isNo())
378+
continue;
379+
380+
// Merge in the corresponding mod or ref for this effect.
381+
if (isa<MemoryEffects::Read>(effect.getEffect())) {
382+
result = result.merge(ModRefResult::getRef());
383+
} else {
384+
assert(isa<MemoryEffects::Write>(effect.getEffect()));
385+
result = result.merge(ModRefResult::getMod());
386+
}
387+
if (result.isModAndRef())
388+
break;
389+
}
390+
return result;
391+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: mlir-opt %s -pass-pipeline='func(test-alias-analysis-modref)' -split-input-file -allow-unregistered-dialect 2>&1 | FileCheck %s
2+
3+
// CHECK-LABEL: Testing : "no_side_effects"
4+
// CHECK: alloc -> func.region0#0: NoModRef
5+
// CHECK: dealloc -> func.region0#0: NoModRef
6+
// CHECK: return -> func.region0#0: NoModRef
7+
func @no_side_effects(%arg: memref<2xf32>) attributes {test.ptr = "func"} {
8+
%1 = memref.alloc() {test.ptr = "alloc"} : memref<8x64xf32>
9+
memref.dealloc %1 {test.ptr = "dealloc"} : memref<8x64xf32>
10+
return {test.ptr = "return"}
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: Testing : "simple"
16+
// CHECK-DAG: store -> alloc#0: Mod
17+
// CHECK-DAG: load -> alloc#0: Ref
18+
19+
// CHECK-DAG: store -> func.region0#0: NoModRef
20+
// CHECK-DAG: load -> func.region0#0: NoModRef
21+
func @simple(%arg: memref<i32>, %value: i32) attributes {test.ptr = "func"} {
22+
%1 = memref.alloca() {test.ptr = "alloc"} : memref<i32>
23+
memref.store %value, %1[] {test.ptr = "store"} : memref<i32>
24+
%2 = memref.load %1[] {test.ptr = "load"} : memref<i32>
25+
return {test.ptr = "return"}
26+
}
27+
28+
// -----
29+
30+
// CHECK-LABEL: Testing : "mayalias"
31+
// CHECK-DAG: store -> func.region0#0: Mod
32+
// CHECK-DAG: load -> func.region0#0: Ref
33+
34+
// CHECK-DAG: store -> func.region0#1: Mod
35+
// CHECK-DAG: load -> func.region0#1: Ref
36+
func @mayalias(%arg0: memref<i32>, %arg1: memref<i32>, %value: i32) attributes {test.ptr = "func"} {
37+
memref.store %value, %arg1[] {test.ptr = "store"} : memref<i32>
38+
%1 = memref.load %arg1[] {test.ptr = "load"} : memref<i32>
39+
return {test.ptr = "return"}
40+
}
41+
42+
// -----
43+
44+
// CHECK-LABEL: Testing : "recursive"
45+
// CHECK-DAG: if -> func.region0#0: ModRef
46+
// CHECK-DAG: if -> func.region0#1: ModRef
47+
48+
// TODO: This is provably NoModRef, but requires handling recursive side
49+
// effects.
50+
// CHECK-DAG: if -> alloc#0: ModRef
51+
func @recursive(%arg0: memref<i32>, %arg1: memref<i32>, %cond: i1, %value: i32) attributes {test.ptr = "func"} {
52+
%0 = memref.alloca() {test.ptr = "alloc"} : memref<i32>
53+
scf.if %cond {
54+
memref.store %value, %arg0[] : memref<i32>
55+
%1 = memref.load %arg0[] : memref<i32>
56+
} {test.ptr = "if"}
57+
return {test.ptr = "return"}
58+
}
59+
60+
// -----
61+
62+
// CHECK-LABEL: Testing : "unknown"
63+
// CHECK-DAG: unknown -> func.region0#0: ModRef
64+
func @unknown(%arg0: memref<i32>) attributes {test.ptr = "func"} {
65+
"foo.op"() {test.ptr = "unknown"} : () -> ()
66+
return
67+
}

0 commit comments

Comments
 (0)