diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h index 62745f8fa1dfa..4d6c54d74d2a9 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -89,9 +89,15 @@ void populateMemRefWideIntEmulationConversions( /// over wider types. /// When `disableAtomicRMW` is true, the store patterns generate non-atomic /// read-modify-write sequences instead of atomic operations. +/// When `assumeAligned` is true, `memref.subview` and +/// `memref.reinterpret_cast` patterns accept dynamic offsets under the +/// alignment contract that the caller guarantees those offsets are a multiple +/// of `dstBits / srcBits`. When false (the default), dynamic offsets are +/// rejected to preserve soundness for callers that cannot prove divisibility. void populateMemRefNarrowTypeEmulationPatterns( const arith::NarrowTypeEmulationConverter &typeConverter, - RewritePatternSet &patterns, bool disableAtomicRMW = false); + RewritePatternSet &patterns, bool disableAtomicRMW = false, + bool assumeAligned = false); /// Appends type conversions for emulating memref operations over narrow types /// with ops over wider types. diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h index 9af0f301d763c..f58b776138def 100644 --- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h +++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h @@ -31,6 +31,18 @@ namespace memref { /// contiguous chunk of memory. bool isStaticShapeAndContiguousRowMajor(MemRefType type); +/// Controls how the per-dimension contribution to `linearizedSize` is divided +/// by `dstBits / srcBits` when scaling down to the emulated type. The offset +/// and intra-data offset are unaffected; they always use floor division and +/// remainder respectively. +/// - `Floor`: round each `stride * size / scaler` down. Suitable for indexing +/// computations where a partial trailing byte is not included. +/// - `Ceil`: round up, matching the result-shape size used by narrow-type +/// memref type conversion (see `getLinearizedShape`). Use this when the +/// caller needs the linearized size to cover all source elements, e.g. when +/// building the size attribute of a converted `memref.reinterpret_cast`. +enum class LinearizedDivKind { Floor, Ceil }; + /// For a `memref` with `offset`, `sizes` and `strides`, returns the /// offset, size, and potentially the size padded at the front to use for the /// linearized `memref`. @@ -47,6 +59,8 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type); /// load/store, the memory region emulated is larger than the actual memory /// region needed. `intraDataOffset` returns the element offset of the data /// relevant at the beginning. +/// - `sizeDivKind` selects floor vs ceil rounding for the `linearizedSize` +/// contribution from each dimension (see `LinearizedDivKind`). struct LinearizedMemRefInfo { OpFoldResult linearizedOffset; OpFoldResult linearizedSize; @@ -55,7 +69,8 @@ struct LinearizedMemRefInfo { std::pair getLinearizedMemRefOffsetAndSize( OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef sizes, - ArrayRef strides, ArrayRef indices = {}); + ArrayRef strides, ArrayRef indices = {}, + LinearizedDivKind sizeDivKind = LinearizedDivKind::Floor); /// For a `memref` with `offset` and `sizes`, returns the /// offset and size to use for the linearized `memref`, assuming that @@ -64,10 +79,12 @@ std::pair getLinearizedMemRefOffsetAndSize( /// element type with bitwidth `srcBits` using element type with /// bitwidth `dstBits`, the linearized offset and size are /// scaled down by `dstBits`/`srcBits`. -LinearizedMemRefInfo -getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, - int dstBits, OpFoldResult offset, - ArrayRef sizes); +/// - `sizeDivKind` selects floor vs ceil rounding for the `linearizedSize` +/// contribution from each dimension (see `LinearizedDivKind`). +LinearizedMemRefInfo getLinearizedMemRefOffsetAndSize( + OpBuilder &builder, Location loc, int srcBits, int dstBits, + OpFoldResult offset, ArrayRef sizes, + LinearizedDivKind sizeDivKind = LinearizedDivKind::Floor); /// Track temporary allocations that are never read from. If this is the case /// it means both the allocations and associated stores can be removed. diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 8686f22c9e3c2..a11e14faa5475 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -32,15 +32,24 @@ using namespace mlir; //===----------------------------------------------------------------------===// /// Converts a memref::ReinterpretCastOp to the converted type. The result -/// MemRefType of the old op must have a rank and stride of 1, with static -/// offset and size. The number of bits in the offset must evenly divide the -/// bitwidth of the new converted type. +/// memref is linearized to a rank-1 byte view (or rank-0 if the source is +/// rank-0). When `assumeAligned` is true, dynamic offsets are accepted under +/// the alignment contract that the caller guarantees the offset is a multiple +/// of `dstBits / srcBits`; statically-provable misalignment is rejected. +/// When `assumeAligned` is false, dynamic offsets are rejected outright since +/// divisibility cannot be proven from the IR alone. static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter, memref::ReinterpretCastOp::Adaptor adaptor, - memref::ReinterpretCastOp op, MemRefType newTy) { - auto convertedElementType = newTy.getElementType(); - auto oldElementType = op.getType().getElementType(); + memref::ReinterpretCastOp op, MemRefType newTy, + bool assumeAligned) { + if (newTy == op.getType()) { + return rewriter.notifyMatchFailure( + op, "result type was not converted by narrow-type emulation"); + } + + Type convertedElementType = newTy.getElementType(); + Type oldElementType = op.getType().getElementType(); int srcBits = oldElementType.getIntOrFloatBitWidth(); int dstBits = convertedElementType.getIntOrFloatBitWidth(); if (dstBits % srcBits != 0) { @@ -48,35 +57,70 @@ convertCastingOp(ConversionPatternRewriter &rewriter, "only dstBits % srcBits == 0 supported"); } - // Only support stride of 1. - if (llvm::any_of(op.getStaticStrides(), - [](int64_t stride) { return stride != 1; })) { - return rewriter.notifyMatchFailure(op->getLoc(), - "stride != 1 is not supported"); + ArrayRef staticStrides = op.getStaticStrides(); + if (!staticStrides.empty() && staticStrides.back() != 1) { + return rewriter.notifyMatchFailure( + op->getLoc(), "innermost stride != 1 is not supported"); + } + + // TODO: support dynamic sizes. Requires a divisibility analysis or a + // stronger alignment contract; tracked as follow-up work. + if (llvm::is_contained(op.getStaticSizes(), ShapedType::kDynamic)) { + return rewriter.notifyMatchFailure(op, "dynamic sizes are not supported"); } - auto sizes = op.getStaticSizes(); - int64_t offset = op.getStaticOffset(0); - // Only support static sizes and offsets. - if (llvm::is_contained(sizes, ShapedType::kDynamic) || - offset == ShapedType::kDynamic) { + if (!memref::isStaticShapeAndContiguousRowMajor(op.getType())) { return rewriter.notifyMatchFailure( - op, "dynamic size or offset is not supported"); + op, "result memref is not row-major contiguous"); } - int elementsPerByte = dstBits / srcBits; - if (offset % elementsPerByte != 0) { + // Reject dynamic offsets unless the caller has opted into the alignment + // contract via `assumeAligned`. Without it we cannot prove the offset is a + // multiple of `dstBits / srcBits`. + if (!assumeAligned && + llvm::is_contained(op.getStaticOffsets(), ShapedType::kDynamic)) { return rewriter.notifyMatchFailure( - op, "offset not multiple of elementsPerByte is not supported"); + op, "dynamic offsets require assumeAligned=true to ensure the offset " + "is a multiple of dstBits / srcBits"); } - SmallVector size; - if (!sizes.empty()) - size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte)); - offset = offset / elementsPerByte; + Location loc = op.getLoc(); + SmallVector mixedSizes = op.getMixedSizes(); + OpFoldResult origOffset = op.getMixedOffsets()[0]; + + SmallVector newSizes; + SmallVector newStrides; + OpFoldResult newOffset; + OpFoldResult intraOffset; + if (mixedSizes.empty()) { + int64_t elementsPerByte = dstBits / srcBits; + AffineExpr s0; + bindSymbols(rewriter.getContext(), s0); + newOffset = affine::makeComposedFoldedAffineApply( + rewriter, loc, s0.floorDiv(elementsPerByte), {origOffset}); + intraOffset = affine::makeComposedFoldedAffineApply( + rewriter, loc, s0 % elementsPerByte, {origOffset}); + } else { + // Use ceil division so the produced linearized size matches the converted + // result memref shape (see `getLinearizedShape` in the type converter), + // which also rounds up to fit all source elements. + memref::LinearizedMemRefInfo info = + memref::getLinearizedMemRefOffsetAndSize( + rewriter, loc, srcBits, dstBits, origOffset, mixedSizes, + memref::LinearizedDivKind::Ceil); + newOffset = info.linearizedOffset; + intraOffset = info.intraDataOffset; + newSizes.push_back(info.linearizedSize); + newStrides.push_back(rewriter.getIndexAttr(1)); + } + + if (auto cst = getConstantIntValue(intraOffset); cst && *cst != 0) { + return rewriter.notifyMatchFailure( + op, "offset is provably not a multiple of dstBits / srcBits"); + } rewriter.replaceOpWithNewOp( - op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides()); + op, newTy, adaptor.getSource(), newOffset, newSizes, newStrides); return success(); } @@ -349,6 +393,32 @@ struct ConvertMemRefLoad final : OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// ConvertMemRefCast +//===----------------------------------------------------------------------===// + +/// `memref.cast` between two narrow-typed memrefs forwards through the type +/// converter to a cast between the converted byte-typed memrefs. +struct ConvertMemRefCast final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) { + return rewriter.notifyMatchFailure( + op->getLoc(), + llvm::formatv("failed to convert memref type: {0}", op.getType())); + } + if (newTy == op.getType()) + return failure(); + + rewriter.replaceOpWithNewOp(op, newTy, adaptor.getSource()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ConvertMemRefMemorySpaceCast //===----------------------------------------------------------------------===// @@ -377,11 +447,15 @@ struct ConvertMemRefMemorySpaceCast final // ConvertMemRefReinterpretCast //===----------------------------------------------------------------------===// -/// Output types should be at most one dimensional, so only the 0 or 1 -/// dimensional cases are supported. +/// Forwards to `convertCastingOp`, which enforces all preconditions. +/// `assumeAligned` is propagated from the populate entry point and controls +/// acceptance of dynamic offsets. struct ConvertMemRefReinterpretCast final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + ConvertMemRefReinterpretCast(const TypeConverter &typeConverter, + MLIRContext *context, bool assumeAligned) + : OpConversionPattern(typeConverter, context), + assumeAligned(assumeAligned) {} LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, @@ -394,14 +468,11 @@ struct ConvertMemRefReinterpretCast final llvm::formatv("failed to convert memref type: {0}", op.getType())); } - // Only support for 0 or 1 dimensional cases. - if (op.getType().getRank() > 1) { - return rewriter.notifyMatchFailure( - op->getLoc(), "subview with rank > 1 is not supported"); - } - - return convertCastingOp(rewriter, adaptor, op, newTy); + return convertCastingOp(rewriter, adaptor, op, newTy, assumeAligned); } + +private: + bool assumeAligned; }; //===----------------------------------------------------------------------===// @@ -503,11 +574,17 @@ struct ConvertMemrefStore final : OpConversionPattern { //===----------------------------------------------------------------------===// /// Emulating narrow ints on subview have limited support, supporting only -/// static offset and size and stride of 1. Ideally, the subview should be +/// static sizes and stride of 1. When `assumeAligned` is true, dynamic +/// offsets are accepted under the alignment contract that the caller +/// guarantees the offset is a multiple of `dstBits / srcBits`. Without that +/// opt-in, dynamic offsets are rejected. Ideally, the subview should be /// folded away before running narrow type emulation, and this pattern should /// only run for cases that can't be folded. struct ConvertMemRefSubview final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + ConvertMemRefSubview(const TypeConverter &typeConverter, MLIRContext *context, + bool assumeAligned) + : OpConversionPattern(typeConverter, context), + assumeAligned(assumeAligned) {} LogicalResult matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, @@ -543,12 +620,21 @@ struct ConvertMemRefSubview final : OpConversionPattern { } auto sizes = subViewOp.getStaticSizes(); - int64_t lastOffset = subViewOp.getStaticOffsets().back(); - // Only support static sizes and offsets. - if (llvm::is_contained(sizes, ShapedType::kDynamic) || - lastOffset == ShapedType::kDynamic) { + // TODO: support dynamic sizes. Requires a divisibility analysis or a + // stronger alignment contract; tracked as follow-up work. + if (llvm::is_contained(sizes, ShapedType::kDynamic)) { + return rewriter.notifyMatchFailure(subViewOp->getLoc(), + "dynamic size is not supported"); + } + + // Reject dynamic offsets unless the caller has opted into the alignment + // contract via `assumeAligned`. + if (!assumeAligned && llvm::is_contained(subViewOp.getStaticOffsets(), + ShapedType::kDynamic)) { return rewriter.notifyMatchFailure( - subViewOp->getLoc(), "dynamic size or offset is not supported"); + subViewOp, + "dynamic offsets require assumeAligned=true to ensure the offset " + "is a multiple of dstBits / srcBits"); } // Transform the offsets, sizes and strides according to the emulation. @@ -566,11 +652,21 @@ struct ConvertMemRefSubview final : OpConversionPattern { getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)); + if (auto cst = getConstantIntValue(linearizedInfo.intraDataOffset); + cst && *cst != 0) { + return rewriter.notifyMatchFailure( + subViewOp, + "subview offset is provably not a multiple of dstBits / srcBits"); + } + rewriter.replaceOpWithNewOp( subViewOp, newTy, adaptor.getSource(), linearizedIndices, linearizedInfo.linearizedSize, strides.back()); return success(); } + +private: + bool assumeAligned; }; //===----------------------------------------------------------------------===// @@ -630,16 +726,18 @@ struct ConvertMemRefExpandShape final void memref::populateMemRefNarrowTypeEmulationPatterns( const arith::NarrowTypeEmulationConverter &typeConverter, - RewritePatternSet &patterns, bool disableAtomicRMW) { + RewritePatternSet &patterns, bool disableAtomicRMW, bool assumeAligned) { // Populate `memref.*` conversion patterns. - patterns.add, - ConvertMemRefAllocation, ConvertMemRefCopy, - ConvertMemRefDealloc, ConvertMemRefCollapseShape, - ConvertMemRefExpandShape, ConvertMemRefLoad, - ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast, - ConvertMemRefSubview, ConvertMemRefReinterpretCast>( - typeConverter, patterns.getContext()); + patterns + .add, + ConvertMemRefAllocation, ConvertMemRefCast, + ConvertMemRefCopy, ConvertMemRefDealloc, ConvertMemRefCollapseShape, + ConvertMemRefExpandShape, ConvertMemRefLoad, + ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast>( + typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext(), assumeAligned); patterns.insert(typeConverter, patterns.getContext(), disableAtomicRMW); memref::populateResolveExtractStridedMetadataPatterns(patterns); diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index cf126cd85ddce..0899b1a9faeb4 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -51,7 +51,8 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) { std::pair getLinearizedMemRefOffsetAndSize( OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef sizes, - ArrayRef strides, ArrayRef indices) { + ArrayRef strides, ArrayRef indices, + LinearizedDivKind sizeDivKind) { unsigned sourceRank = sizes.size(); assert(sizes.size() == strides.size() && "expected as many sizes as strides for a memref"); @@ -88,7 +89,10 @@ std::pair getLinearizedMemRefOffsetAndSize( AffineExpr sizeExpr = symbols[symbolIndex++]; values.push_back(sizes[i]); - productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler)); + AffineExpr product = strideExpr * sizeExpr; + productExpressions.push_back(sizeDivKind == LinearizedDivKind::Ceil + ? product.ceilDiv(scaler) + : product.floorDiv(scaler)); } AffineMap maxMap = AffineMap::get( /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions, @@ -112,7 +116,8 @@ std::pair getLinearizedMemRefOffsetAndSize( LinearizedMemRefInfo getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, - ArrayRef sizes) { + ArrayRef sizes, + LinearizedDivKind sizeDivKind) { SmallVector strides(sizes.size()); if (!sizes.empty()) { strides.back() = builder.getIndexAttr(1); @@ -128,7 +133,8 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, LinearizedMemRefInfo linearizedMemRefInfo; std::tie(linearizedMemRefInfo, std::ignore) = getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset, - sizes, strides); + sizes, strides, /*indices=*/{}, + sizeDivKind); return linearizedMemRefInfo; } diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-no-assume-aligned.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-no-assume-aligned.mlir new file mode 100644 index 0000000000000..3625f91cedbd6 --- /dev/null +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-no-assume-aligned.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s + +// Without `assume-aligned=true`, dynamic offsets in `memref.subview` and +// `memref.reinterpret_cast` cannot be proven to be multiples of +// `dstBits / srcBits`. The patterns must reject them so partial conversion +// fails to legalize the op. + +func.func @negative_subview_dynamic_inner_offset_i4(%off: index) -> i4 { + %c0 = arith.constant 0 : index + %arr = memref.alloc() : memref<128xi4> + // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}} + %subview = memref.subview %arr[%off] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset: ?>> + %ld = memref.load %subview[%c0] : memref<32xi4, strided<[1], offset: ?>> + return %ld : i4 +} + +// ----- + +func.func @negative_reinterpret_cast_memref_rank3_dynamic_offset_i4(%arg0: memref<2x4x8xi4>, %off: index) -> memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>> { + // expected-error @+1 {{failed to legalize operation 'memref.reinterpret_cast' that was explicitly marked illegal}} + %r = memref.reinterpret_cast %arg0 to offset: [%off], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>> + return %r : memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>> +} diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir index dd64ecc98721a..adc8fe3b36096 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s -// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32 +// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8 assume-aligned=true" --cse --verify-diagnostics --split-input-file %s | FileCheck %s +// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32 assume-aligned=true" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32 // Expect no conversions. func.func @memref_i8() -> i8 { @@ -238,6 +238,51 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 { // ----- +func.func @memref_subview_dynamic_inner_offset_i4(%off: index) -> i4 { + %c0 = arith.constant 0 : index + %arr = memref.alloc() : memref<128xi4> + %subview = memref.subview %arr[%off] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset: ?>> + %ld = memref.load %subview[%c0] : memref<32xi4, strided<[1], offset: ?>> + return %ld : i4 +} + +// CHECK-LABEL: func.func @memref_subview_dynamic_inner_offset_i4( +// CHECK-SAME: %[[OFF:[a-zA-Z0-9_]+]]: index +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8> +// CHECK: %[[IDX:.+]] = affine.apply {{.*}}%[[OFF]] +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[IDX]]] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: ?>> +// CHECK: memref.load %[[SUBVIEW]] + +// CHECK32-LABEL: func.func @memref_subview_dynamic_inner_offset_i4( +// CHECK32-SAME: %[[OFF:[a-zA-Z0-9_]+]]: index +// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32> +// CHECK32: %[[IDX:.+]] = affine.apply {{.*}}%[[OFF]] +// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[IDX]]] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: ?>> +// CHECK32: memref.load %[[SUBVIEW]] + +// ----- + +// Dynamic innermost offset that is provably aligned (multiple of +// `dstBits / srcBits`). The affine simplifier folds the `floordiv` away. + +func.func @memref_subview_aligned_dynamic_inner_offset_i4(%x: index) -> i4 { + %c0 = arith.constant 0 : index + %off = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%x] + %arr = memref.alloc() : memref<128xi4> + %subview = memref.subview %arr[%off] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset: ?>> + %ld = memref.load %subview[%c0] : memref<32xi4, strided<[1], offset: ?>> + return %ld : i4 +} + +// CHECK-LABEL: func.func @memref_subview_aligned_dynamic_inner_offset_i4( +// CHECK-SAME: %[[X:[a-zA-Z0-9_]+]]: index +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8> +// CHECK-NOT: affine.apply +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][%[[X]]] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: ?>> +// CHECK: memref.load %[[SUBVIEW]] + +// ----- + func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 { %c0 = arith.constant 0 : index %arr = memref.alloc() : memref<40x40xi4> @@ -249,6 +294,61 @@ func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 { // ----- +// Rank-3 reinterpret_cast on a sub-byte (i4) memref with a static, aligned +// offset. + +func.func @reinterpret_cast_memref_rank3_static_offset_i4(%arg0: memref<2x4x8xi4>) -> memref<4x4x8xi4, strided<[32, 8, 1]>> { + %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1]>> + return %r : memref<4x4x8xi4, strided<[32, 8, 1]>> +} + +// CHECK-LABEL: func @reinterpret_cast_memref_rank3_static_offset_i4( +// CHECK-SAME: %[[ARG0:.+]]: memref<32xi8> +// CHECK: %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [64], strides: [1] : memref<32xi8> to memref<64xi8> +// CHECK: return %[[R]] + +// CHECK32-LABEL: func @reinterpret_cast_memref_rank3_static_offset_i4( +// CHECK32-SAME: %[[ARG0:.+]]: memref<8xi32> +// CHECK32: %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [16], strides: [1] : memref<8xi32> to memref<16xi32> +// CHECK32: return %[[R]] + +// ----- + +// Rank-3 reinterpret_cast with a dynamic offset accepted under the alignment +// contract. + +func.func @reinterpret_cast_memref_rank3_dynamic_offset_i4(%arg0: memref<2x4x8xi4>, %off: index) -> memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>> { + %r = memref.reinterpret_cast %arg0 to offset: [%off], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>> + return %r : memref<4x4x8xi4, strided<[32, 8, 1], offset: ?>> +} + +// CHECK-LABEL: func @reinterpret_cast_memref_rank3_dynamic_offset_i4( +// CHECK-SAME: %[[ARG0:.+]]: memref<32xi8>, +// CHECK-SAME: %[[OFF:.+]]: index +// CHECK: %[[NEWOFF:.+]] = affine.apply {{.*}}%[[OFF]] +// CHECK: %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[NEWOFF]]{{\]}}, sizes: [64], strides: [1] : memref<32xi8> to memref<64xi8, strided<[1], offset: ?>> +// CHECK: return %[[R]] + +// CHECK32-LABEL: func @reinterpret_cast_memref_rank3_dynamic_offset_i4( +// CHECK32-SAME: %[[ARG0:.+]]: memref<8xi32>, +// CHECK32-SAME: %[[OFF:.+]]: index +// CHECK32: %[[NEWOFF:.+]] = affine.apply {{.*}}%[[OFF]] +// CHECK32: %[[R:.+]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[NEWOFF]]{{\]}}, sizes: [16], strides: [1] : memref<8xi32> to memref<16xi32, strided<[1], offset: ?>> +// CHECK32: return %[[R]] + +// ----- + +// Provably-misaligned static offset (1 is not a multiple of i4 -> i8 ratio +// of 2). Lowering must fail. + +func.func @negative_reinterpret_cast_memref_misaligned_static_offset_i4(%arg0: memref<2x4x8xi4>) -> memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>> { + // expected-error @+1 {{failed to legalize operation 'memref.reinterpret_cast' that was explicitly marked illegal}} + %r = memref.reinterpret_cast %arg0 to offset: [1], sizes: [4, 4, 8], strides: [32, 8, 1] : memref<2x4x8xi4> to memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>> + return %r : memref<4x4x8xi4, strided<[32, 8, 1], offset: 1>> +} + +// ----- + func.func @reinterpret_cast_memref_load_0D() -> i4 { %0 = memref.alloc() : memref<5xi4> %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp index 9313a0945d86b..bec83a8dcbef9 100644 --- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp @@ -99,8 +99,8 @@ struct TestEmulateNarrowTypePass RewritePatternSet patterns(ctx); arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns); - memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns, - disableAtomicRMW); + memref::populateMemRefNarrowTypeEmulationPatterns( + typeConverter, patterns, disableAtomicRMW, assumeAligned); vector::populateVectorNarrowTypeEmulationPatterns( typeConverter, patterns, disableAtomicRMW, assumeAligned);