From 7a513df00552310c078f38d2d1a8aa20e9cf03e6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 4 Sep 2025 17:25:26 +0200 Subject: [PATCH 1/4] [mlir][memref] Introduce `memref.distinct_objects` op --- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 39 ++++++++++++++- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 49 +++++++++++++++++-- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 19 +++++++ .../MemRefToLLVM/memref-to-llvm.mlir | 19 +++++++ mlir/test/Dialect/MemRef/ops.mlir | 9 ++++ 5 files changed, 130 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 671cc05e963b4..933fb87de30ab 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -153,7 +153,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ The `assume_alignment` operation takes a memref and an integer alignment value. It returns a new SSA value of the same memref type, but associated with the assumption that the underlying buffer is aligned to the given - alignment. + alignment. If the buffer isn't aligned to the given alignment, its result is poison. 
This operation doesn't affect the semantics of a program where the @@ -168,7 +168,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; let extraClassDeclaration = [{ MemRefType getType() { return ::llvm::cast(getResult().getType()); } - + Value getViewSource() { return getMemref(); } }]; @@ -176,6 +176,41 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// DistinctObjectsOp +//===----------------------------------------------------------------------===// + +def DistinctObjectsOp : MemRef_Op<"distinct_objects", [ + Pure, + DeclareOpInterfaceMethods + // ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument + ]> { + let summary = "assumption that acesses to specific memrefs will never alias"; + let description = [{ + The `distinct_objects` operation takes a list of memrefs and returns a list of + memrefs of the same types, with the additional assumption that accesses to + these memrefs will never alias with each other. This means that loads and + stores to different memrefs in the list can be safely reordered. + + If the memrefs do alias, the behavior is undefined. This operation doesn't + affect the semantics of a program where the non-aliasing assumption holds + true. It is intended for optimization purposes, allowing the compiler to + generate more efficient code based on the non-aliasing assumption. The + optimization is best-effort. 
+ + Example: + + ```mlir + %1, %2 = memref.distinct_objects %a, %b : memref, memref + ``` + }]; + let arguments = (ins Variadic:$operands); + let results = (outs Variadic:$results); + + let assemblyFormat = "$operands attr-dict `:` type($operands)"; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 262e0e7a30c63..571e5000b3f51 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -465,6 +465,48 @@ struct AssumeAlignmentOpLowering } }; +struct DistinctObjectsOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + memref::DistinctObjectsOp>::ConvertOpToLLVMPattern; + explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter) {} + + LogicalResult + matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange operands = adaptor.getOperands(); + if (operands.empty()) { + rewriter.eraseOp(op); + return success(); + } + Location loc = op.getLoc(); + SmallVector ptrs; + for (auto [origOperand, newOperand] : + llvm::zip_equal(op.getOperands(), operands)) { + auto memrefType = cast(origOperand.getType()); + Value ptr = getStridedElementPtr(rewriter, loc, memrefType, newOperand, + /*indices=*/{}); + ptrs.push_back(ptr); + } + + auto cond = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1); + // Generate separate_storage assumptions for each pair of pointers. 
+ for (auto i : llvm::seq(ptrs.size() - 1)) { + for (auto j : llvm::seq(i + 1, ptrs.size())) { + Value ptr1 = ptrs[i]; + Value ptr2 = ptrs[j]; + LLVM::AssumeOp::create(rewriter, loc, cond, + LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2); + } + } + + rewriter.replaceOp(op, operands); + return success(); + } +}; + // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. @@ -1997,22 +2039,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns( patterns.add< AllocaOpLowering, AllocaScopeOpLowering, - AtomicRMWOpLowering, AssumeAlignmentOpLowering, + AtomicRMWOpLowering, ConvertExtractAlignedPointerAsIndex, DimOpLowering, + DistinctObjectsOpLowering, ExtractStridedMetadataOpLowering, GenericAtomicRMWOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, - MemorySpaceCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, + MemorySpaceCastOpLowering, PrefetchOpLowering, RankOpLowering, - ReassociatingReshapeOpConversion, ReassociatingReshapeOpConversion, + ReassociatingReshapeOpConversion, StoreOpLowering, SubViewOpLowering, TransposeOpLowering, diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 5d15d5f6e3de4..62fb57f24a870 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -542,6 +542,25 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) { return getMemref(); } +//===----------------------------------------------------------------------===// +// DistinctObjectsOp +//===----------------------------------------------------------------------===// + +LogicalResult DistinctObjectsOp::verify() { + if (getOperandTypes() != getResultTypes()) + return emitOpError("operand types and result types must match"); + return success(); +} + +LogicalResult DistinctObjectsOp::inferReturnTypes( + MLIRContext * 
/*context*/, std::optional /*location*/, + ValueRange operands, DictionaryAttr /*attributes*/, + OpaqueProperties /*properties*/, RegionRange /*regions*/, + SmallVectorImpl &inferredReturnTypes) { + llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes)); + return success(); +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 45b1a1f1ca40c..3eb8df093af10 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -195,6 +195,25 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) { // ----- +// ALL-LABEL: func @distinct_objects +// ALL-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref) +func.func @distinct_objects(%arg0: memref, %arg1: memref, %arg2: memref) -> (memref, memref, memref) { +// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], 
%[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1 + %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref, memref, memref + return %1, %2, %3 : memref, memref, memref +} + +// ----- + // CHECK-LABEL: func @assume_alignment_w_offset // CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) { diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 6c2298a3f8acb..a90c9505a8405 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) { return } +// CHECK-LABEL: func @distinct_objects +// CHECK-SAME: (%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref) +func.func @distinct_objects(%arg0: memref, %arg1: memref, %arg2: memref) -> (memref, memref, memref) { + // CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref, memref, memref + %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref, memref, memref + // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref, memref, memref + return %1, %2, %3 : memref, memref, memref +} + // CHECK-LABEL: func @expand_collapse_shape_static func.func @expand_collapse_shape_static( %arg0: memref<3x4x5xf32>, From 1e24829dbc19f9d3afbb8d11cd1d9a334c9c2c80 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 10 Sep 2025 17:25:38 +0200 Subject: [PATCH 2/4] verifier test --- mlir/test/Dialect/MemRef/invalid.mlir | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 3f96d907632b7..67951f8ef0765 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ 
b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1169,3 +1169,11 @@ func.func @expand_shape_invalid_output_shape( into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>> return } + +// ----- + +func.func @Invalid_distinct_objects(%arg0: memref, %arg1: memref) -> (memref, memref) { + // expected-error @+1 {{operand types and result types must match}} + %0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref, memref) -> (memref, memref) + return %0, %1 : memref, memref +} From 4b943e82d1727e00512c35fb0034beeab7346e1e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 23 Sep 2025 23:13:13 +0200 Subject: [PATCH 3/4] comments, reject 0-operand version --- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 20 +++++++++---------- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 6 ++++-- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 4 ++++ .../MemRefToLLVM/memref-to-llvm.mlir | 11 ++++++++++ mlir/test/Dialect/MemRef/invalid.mlir | 10 +++++++++- 5 files changed, 38 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 933fb87de30ab..f75e311645426 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -187,16 +187,16 @@ def DistinctObjectsOp : MemRef_Op<"distinct_objects", [ ]> { let summary = "assumption that acesses to specific memrefs will never alias"; let description = [{ - The `distinct_objects` operation takes a list of memrefs and returns a list of - memrefs of the same types, with the additional assumption that accesses to - these memrefs will never alias with each other. This means that loads and - stores to different memrefs in the list can be safely reordered. - - If the memrefs do alias, the behavior is undefined. This operation doesn't - affect the semantics of a program where the non-aliasing assumption holds - true. 
It is intended for optimization purposes, allowing the compiler to - generate more efficient code based on the non-aliasing assumption. The - optimization is best-effort. + The `distinct_objects` operation takes a list of memrefs and returns the same + memrefs, with the additional assumption that accesses to them will never + alias with each other. This means that loads and stores to different + memrefs in the list can be safely reordered. + + If the memrefs do alias, the load/store behavior is undefined. This + operation doesn't affect the semantics of a valid program. It is + intended for optimization purposes, allowing the compiler to generate more + efficient code based on the non-aliasing assumption. The optimization is + best-effort. Example: diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 571e5000b3f51..64270726f4a01 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -476,10 +476,12 @@ struct DistinctObjectsOpLowering matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ValueRange operands = adaptor.getOperands(); - if (operands.empty()) { - rewriter.eraseOp(op); + if (operands.size() <= 1) { + // Fast path. 
+ rewriter.replaceOp(op, operands); return success(); } + Location loc = op.getLoc(); SmallVector ptrs; for (auto [origOperand, newOperand] : diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 62fb57f24a870..0bca922b0c804 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -549,6 +549,10 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) { LogicalResult DistinctObjectsOp::verify() { if (getOperandTypes() != getResultTypes()) return emitOpError("operand types and result types must match"); + + if (getOperandTypes().empty()) + return emitOpError("expected at least one operand"); + return success(); } diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 3eb8df093af10..0cbe064572911 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -214,6 +214,17 @@ func.func @distinct_objects(%arg0: memref, %arg1: memref, %arg2: m // ----- +// ALL-LABEL: func @distinct_objects_noop +// ALL-SAME: (%[[ARG0:.*]]: memref) +func.func @distinct_objects_noop(%arg0: memref) -> memref { +// 1-operand version is noop +// ALL-NEXT: return %[[ARG0]] + %1 = memref.distinct_objects %arg0 : memref + return %1 : memref +} + +// ----- + // CHECK-LABEL: func @assume_alignment_w_offset // CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) { diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 67951f8ef0765..5ff292058ccc1 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1172,8 +1172,16 @@ func.func @expand_shape_invalid_output_shape( // ----- -func.func @Invalid_distinct_objects(%arg0: memref, %arg1: memref) -> (memref, memref) { +func.func 
@distinct_objects_types_mismatch(%arg0: memref, %arg1: memref) -> (memref, memref) { // expected-error @+1 {{operand types and result types must match}} %0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref, memref) -> (memref, memref) return %0, %1 : memref, memref } + +// ----- + +func.func @distinct_objects_0_operands() { + // expected-error @+1 {{expected at least one operand}} + "memref.distinct_objects"() : () -> () + return +} From 123a6921333111782c98bf8357546898404350b2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 25 Sep 2025 23:13:01 +0200 Subject: [PATCH 4/4] use bufferPtr --- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 64270726f4a01..c62137721a2b9 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -487,8 +487,9 @@ struct DistinctObjectsOpLowering for (auto [origOperand, newOperand] : llvm::zip_equal(op.getOperands(), operands)) { auto memrefType = cast(origOperand.getType()); - Value ptr = getStridedElementPtr(rewriter, loc, memrefType, newOperand, - /*indices=*/{}); + MemRefDescriptor memRefDescriptor(newOperand); + Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), + memrefType); ptrs.push_back(ptr); }