Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
The `assume_alignment` operation takes a memref and an integer alignment
value. It returns a new SSA value of the same memref type, but associated
with the assumption that the underlying buffer is aligned to the given
alignment.
alignment.

If the buffer isn't aligned to the given alignment, its result is poison.
This operation doesn't affect the semantics of a program where the
Expand All @@ -168,14 +168,49 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
let extraClassDeclaration = [{
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }

Value getViewSource() { return getMemref(); }
}];

let hasVerifier = 1;
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//

def DistinctObjectsOp : MemRef_Op<"distinct_objects", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>
// ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument
]> {
let summary = "assumption that acesses to specific memrefs will never alias";
let description = [{
The `distinct_objects` operation takes a list of memrefs and returns the same
memrefs, with the additional assumption that accesses to them will never
alias with each other. This means that loads and stores to different
memrefs in the list can be safely reordered.

If the memrefs do alias, the load/store behavior is undefined. This
operation doesn't affect the semantics of a valid program. It is
intended for optimization purposes, allowing the compiler to generate more
efficient code based on the non-aliasing assumption. The optimization is
best-effort.

Example:

```mlir
%1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32>
```
}];
let arguments = (ins Variadic<AnyMemRef>:$operands);
let results = (outs Variadic<AnyMemRef>:$results);

let assemblyFormat = "$operands attr-dict `:` type($operands)";
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
Expand Down
52 changes: 49 additions & 3 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering
}
};

struct DistinctObjectsOpLowering
: public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
using ConvertOpToLLVMPattern<
memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}

LogicalResult
matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ValueRange operands = adaptor.getOperands();
if (operands.size() <= 1) {
// Fast path.
rewriter.replaceOp(op, operands);
return success();
}

Location loc = op.getLoc();
SmallVector<Value> ptrs;
for (auto [origOperand, newOperand] :
llvm::zip_equal(op.getOperands(), operands)) {
auto memrefType = cast<MemRefType>(origOperand.getType());
MemRefDescriptor memRefDescriptor(newOperand);
Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
memrefType);
ptrs.push_back(ptr);
}

auto cond =
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
// Generate separate_storage assumptions for each pair of pointers.
for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
Value ptr1 = ptrs[i];
Value ptr2 = ptrs[j];
LLVM::AssumeOp::create(rewriter, loc, cond,
LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
}
}

rewriter.replaceOp(op, operands);
return success();
}
};

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
Expand Down Expand Up @@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
patterns.add<
AllocaOpLowering,
AllocaScopeOpLowering,
AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
AtomicRMWOpLowering,
ConvertExtractAlignedPointerAsIndex,
DimOpLowering,
DistinctObjectsOpLowering,
ExtractStridedMetadataOpLowering,
GenericAtomicRMWOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
MemorySpaceCastOpLowering,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
MemorySpaceCastOpLowering,
PrefetchOpLowering,
RankOpLowering,
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
StoreOpLowering,
SubViewOpLowering,
TransposeOpLowering,
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,29 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
return getMemref();
}

//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//

LogicalResult DistinctObjectsOp::verify() {
if (getOperandTypes() != getResultTypes())
return emitOpError("operand types and result types must match");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Should be implementable with the TypesMatchWith construct in ODS?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a variadic number of inputs/results and we need their types to match pairwise. I didn't quite figured how to express it using existing constraints.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I missed that they are variadics! Makes sense.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test though? It a custom verifier, so we should test it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a test


if (getOperandTypes().empty())
return emitOpError("expected at least one operand");

return success();
}

LogicalResult DistinctObjectsOp::inferReturnTypes(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be done with AllTypesMatch?

Copy link
Contributor Author

@Hardcode84 Hardcode84 Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how to express it using existing constraints, see my prev comment on verify.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably add a constraint - something like EachTypeMatches<"src", "res"> - and possibly use it to do return type inference

But that's a separate PR

MLIRContext * /*context*/, std::optional<Location> /*location*/,
ValueRange operands, DictionaryAttr /*attributes*/,
OpaqueProperties /*properties*/, RegionRange /*regions*/,
SmallVectorImpl<Type> &inferredReturnTypes) {
llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,36 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {

// -----

// ALL-LABEL: func @distinct_objects
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref<?xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
}

// -----

// ALL-LABEL: func @distinct_objects_noop
// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>)
func.func @distinct_objects_noop(%arg0: memref<?xf16>) -> memref<?xf16> {
// 1-operand version is noop
// ALL-NEXT: return %[[ARG0]]
%1 = memref.distinct_objects %arg0 : memref<?xf16>
return %1 : memref<?xf16>
}

// -----

// CHECK-LABEL: func @assume_alignment_w_offset
// CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset
func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1169,3 +1169,19 @@ func.func @expand_shape_invalid_output_shape(
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
return
}

// -----

func.func @distinct_objects_types_mismatch(%arg0: memref<?xf32>, %arg1: memref<?xi32>) -> (memref<?xi32>, memref<?xf32>) {
// expected-error @+1 {{operand types and result types must match}}
%0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref<?xf32>, memref<?xi32>) -> (memref<?xi32>, memref<?xf32>)
return %0, %1 : memref<?xi32>, memref<?xf32>
}

// -----

func.func @distinct_objects_0_operands() {
// expected-error @+1 {{expected at least one operand}}
"memref.distinct_objects"() : () -> ()
return
}
9 changes: 9 additions & 0 deletions mlir/test/Dialect/MemRef/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
return
}

// CHECK-LABEL: func @distinct_objects
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
// CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<?xf16>, memref<?xf32>, memref<?xf64>
%1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
}

// CHECK-LABEL: func @expand_collapse_shape_static
func.func @expand_collapse_shape_static(
%arg0: memref<3x4x5xf32>,
Expand Down