Skip to content

Commit a374017

Browse files
authored
[mlir][memref] Introduce memref.distinct_objects op (#156913)
The `distinct_objects` operation takes a list of memrefs and returns a list of memrefs of the same types, with the additional assumption that accesses to these memrefs will never alias with each other. This means that loads and stores to different memrefs in the list can be safely reordered. The discussion https://discourse.llvm.org/t/rfc-introducing-memref-aliasing-attributes/88049
1 parent 93c8305 commit a374017

File tree

6 files changed

+164
-5
lines changed

6 files changed

+164
-5
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
155155
The `assume_alignment` operation takes a memref and an integer alignment
156156
value. It returns a new SSA value of the same memref type, but associated
157157
with the assumption that the underlying buffer is aligned to the given
158-
alignment.
158+
alignment.
159159

160160
If the buffer isn't aligned to the given alignment, its result is poison.
161161
This operation doesn't affect the semantics of a program where the
@@ -170,14 +170,49 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
170170
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
171171
let extraClassDeclaration = [{
172172
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
173-
173+
174174
Value getViewSource() { return getMemref(); }
175175
}];
176176

177177
let hasVerifier = 1;
178178
let hasFolder = 1;
179179
}
180180

181+
//===----------------------------------------------------------------------===//
182+
// DistinctObjectsOp
183+
//===----------------------------------------------------------------------===//
184+
185+
def DistinctObjectsOp : MemRef_Op<"distinct_objects", [
186+
Pure,
187+
DeclareOpInterfaceMethods<InferTypeOpInterface>
188+
// ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument
189+
]> {
190+
let summary = "assumption that acesses to specific memrefs will never alias";
191+
let description = [{
192+
The `distinct_objects` operation takes a list of memrefs and returns the same
193+
memrefs, with the additional assumption that accesses to them will never
194+
alias with each other. This means that loads and stores to different
195+
memrefs in the list can be safely reordered.
196+
197+
If the memrefs do alias, the load/store behavior is undefined. This
198+
operation doesn't affect the semantics of a valid program. It is
199+
intended for optimization purposes, allowing the compiler to generate more
200+
efficient code based on the non-aliasing assumption. The optimization is
201+
best-effort.
202+
203+
Example:
204+
205+
```mlir
206+
%1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32>
207+
```
208+
}];
209+
let arguments = (ins Variadic<AnyMemRef>:$operands);
210+
let results = (outs Variadic<AnyMemRef>:$results);
211+
212+
let assemblyFormat = "$operands attr-dict `:` type($operands)";
213+
let hasVerifier = 1;
214+
}
215+
181216
//===----------------------------------------------------------------------===//
182217
// AllocOp
183218
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering
465465
}
466466
};
467467

468+
struct DistinctObjectsOpLowering
469+
: public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
470+
using ConvertOpToLLVMPattern<
471+
memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
472+
explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
473+
: ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}
474+
475+
LogicalResult
476+
matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
477+
ConversionPatternRewriter &rewriter) const override {
478+
ValueRange operands = adaptor.getOperands();
479+
if (operands.size() <= 1) {
480+
// Fast path.
481+
rewriter.replaceOp(op, operands);
482+
return success();
483+
}
484+
485+
Location loc = op.getLoc();
486+
SmallVector<Value> ptrs;
487+
for (auto [origOperand, newOperand] :
488+
llvm::zip_equal(op.getOperands(), operands)) {
489+
auto memrefType = cast<MemRefType>(origOperand.getType());
490+
MemRefDescriptor memRefDescriptor(newOperand);
491+
Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
492+
memrefType);
493+
ptrs.push_back(ptr);
494+
}
495+
496+
auto cond =
497+
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
498+
// Generate separate_storage assumptions for each pair of pointers.
499+
for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
500+
for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
501+
Value ptr1 = ptrs[i];
502+
Value ptr2 = ptrs[j];
503+
LLVM::AssumeOp::create(rewriter, loc, cond,
504+
LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
505+
}
506+
}
507+
508+
rewriter.replaceOp(op, operands);
509+
return success();
510+
}
511+
};
512+
468513
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
469514
// The memref descriptor being an SSA value, there is no need to clean it up
470515
// in any way.
@@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
19972042
patterns.add<
19982043
AllocaOpLowering,
19992044
AllocaScopeOpLowering,
2000-
AtomicRMWOpLowering,
20012045
AssumeAlignmentOpLowering,
2046+
AtomicRMWOpLowering,
20022047
ConvertExtractAlignedPointerAsIndex,
20032048
DimOpLowering,
2049+
DistinctObjectsOpLowering,
20042050
ExtractStridedMetadataOpLowering,
20052051
GenericAtomicRMWOpLowering,
20062052
GetGlobalMemrefOpLowering,
20072053
LoadOpLowering,
20082054
MemRefCastOpLowering,
2009-
MemorySpaceCastOpLowering,
20102055
MemRefReinterpretCastOpLowering,
20112056
MemRefReshapeOpLowering,
2057+
MemorySpaceCastOpLowering,
20122058
PrefetchOpLowering,
20132059
RankOpLowering,
2014-
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
20152060
ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
2061+
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
20162062
StoreOpLowering,
20172063
SubViewOpLowering,
20182064
TransposeOpLowering,

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,29 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
606606
return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
607607
}
608608

609+
//===----------------------------------------------------------------------===//
610+
// DistinctObjectsOp
611+
//===----------------------------------------------------------------------===//
612+
613+
LogicalResult DistinctObjectsOp::verify() {
614+
if (getOperandTypes() != getResultTypes())
615+
return emitOpError("operand types and result types must match");
616+
617+
if (getOperandTypes().empty())
618+
return emitOpError("expected at least one operand");
619+
620+
return success();
621+
}
622+
623+
LogicalResult DistinctObjectsOp::inferReturnTypes(
624+
MLIRContext * /*context*/, std::optional<Location> /*location*/,
625+
ValueRange operands, DictionaryAttr /*attributes*/,
626+
OpaqueProperties /*properties*/, RegionRange /*regions*/,
627+
SmallVectorImpl<Type> &inferredReturnTypes) {
628+
llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
629+
return success();
630+
}
631+
609632
//===----------------------------------------------------------------------===//
610633
// CastOp
611634
//===----------------------------------------------------------------------===//

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,36 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
195195

196196
// -----
197197

198+
// ALL-LABEL: func @distinct_objects
199+
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
200+
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
201+
// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
202+
// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
203+
// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref<?xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
204+
// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
205+
// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
206+
// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
207+
// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
208+
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1
209+
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
210+
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
211+
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
212+
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
213+
}
214+
215+
// -----
216+
217+
// ALL-LABEL: func @distinct_objects_noop
218+
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>)
219+
func.func @distinct_objects_noop(%arg0: memref<?xf16>) -> memref<?xf16> {
220+
// 1-operand version is noop
221+
// ALL-NEXT: return %[[ARG0]]
222+
%1 = memref.distinct_objects %arg0 : memref<?xf16>
223+
return %1 : memref<?xf16>
224+
}
225+
226+
// -----
227+
198228
// CHECK-LABEL: func @assume_alignment_w_offset
199229
// CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset
200230
func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,3 +1169,19 @@ func.func @expand_shape_invalid_output_shape(
11691169
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
11701170
return
11711171
}
1172+
1173+
// -----
1174+
1175+
func.func @distinct_objects_types_mismatch(%arg0: memref<?xf32>, %arg1: memref<?xi32>) -> (memref<?xi32>, memref<?xf32>) {
1176+
// expected-error @+1 {{operand types and result types must match}}
1177+
%0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref<?xf32>, memref<?xi32>) -> (memref<?xi32>, memref<?xf32>)
1178+
return %0, %1 : memref<?xi32>, memref<?xf32>
1179+
}
1180+
1181+
// -----
1182+
1183+
func.func @distinct_objects_0_operands() {
1184+
// expected-error @+1 {{expected at least one operand}}
1185+
"memref.distinct_objects"() : () -> ()
1186+
return
1187+
}

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
302302
return
303303
}
304304

305+
// CHECK-LABEL: func @distinct_objects
306+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
307+
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
308+
// CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<?xf16>, memref<?xf32>, memref<?xf64>
309+
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
310+
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
311+
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
312+
}
313+
305314
// CHECK-LABEL: func @expand_collapse_shape_static
306315
func.func @expand_collapse_shape_static(
307316
%arg0: memref<3x4x5xf32>,

0 commit comments

Comments
 (0)