diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acad..553af2adb60f3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,9 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -29,6 +32,26 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Replaces the memref::StoreOp with two new memref::AtomicRMWOps. The first
+/// memref::AtomicRMWOp sets the destination bits to all zero to prepare the
+/// destination byte for the write. The second memref::AtomicRMWOp writes the
+/// value to store using an `ori` operation. The value to store and the write
+/// mask should both have the destination type bitwidth, and the bits of the
+/// value to store should be all zero except for the bits aligned with the
+/// store destination.
+static void replaceStoreWithAtomics(ConversionPatternRewriter &rewriter,
+                                    memref::StoreOp op, Value writeMask,
+                                    Value storeVal, Value memref,
+                                    ValueRange storeIndices) {
+  // Clear destination bits
+  rewriter.create<memref::AtomicRMWOp>(op.getLoc(), arith::AtomicRMWKind::andi,
+                                       writeMask, memref, storeIndices);
+  // Write src bits to destination
+  rewriter.create<memref::AtomicRMWOp>(op->getLoc(), arith::AtomicRMWKind::ori,
+                                       storeVal, memref, storeIndices);
+  rewriter.eraseOp(op);
+}
+
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -43,13 +66,67 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// When writing a subbyte value, the write needs to happen atomically in case
+/// another write happens on the same byte at the same time. To do the write,
+/// we first must clear `dstBits` at the `linearizedIndices` of the subbyte
+/// store. This function returns the appropriate mask for clearing these bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -155,32 +232,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));
 
       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
@@ -211,6 +271,60 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We can compute the mask at compile
+    // time.
+    if (convertedType.getRank() == 0) {
+      // Create mask to clear destination bits
+      auto writeMaskValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, ~(1 << (srcBits)) - 1);
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                           writeMaskValAttr);
+
+      replaceStoreWithAtomics(rewriter, op, writeMask, extendedInput,
+                              adaptor.getMemref(), ValueRange{});
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                         dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+    replaceStoreWithAtomics(rewriter, op, writeMask, alignedVal,
+                            adaptor.getMemref(), storeIndices);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -292,7 +406,7 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview, ConvertMemrefStore>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cf..22c5947fd2ac9 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,172 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
 // CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
 // CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+  %0 = memref.alloc() : memref<5xi4>
+  memref.store %arg1, %0[%arg0] : memref<5xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+  %0 = memref.alloc() : memref<3x125xi4>
+  memref.assume_alignment %0, 64 : memref<3x125xi4>
+  memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_rank2(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_dynamic(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK32-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @rank_zero_memref_store(%arg0: i4) -> () {
+  %0 = memref.alloc() : memref<i4>
+  memref.store %arg0, %0[] : memref<i4>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+// CHECK: %[[MASK:.+]] = arith.constant -18 : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: return
+
+// CHECK32-LABEL: func @rank_zero_memref
+// CHECK32-SAME: %[[ARG0:.+]]: i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
+// CHECK32: %[[MASK:.+]] = arith.constant -18 : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: return
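
Editor's note (not part of the patch): the mask/shift arithmetic emitted by getOffsetForBitwidth, getAtomicWriteMask, and replaceStoreWithAtomics can be checked on plain integers. The standalone C++ sketch below reproduces it for the i4-on-i8 case exercised by @memref_store_i4; the initial byte contents (0x3C), the stored nibble (0xA), and all variable names are purely illustrative.

#include <cstdint>
#include <cstdio>

int main() {
  const int srcBits = 4, dstBits = 8;
  uint8_t byte = 0x3C;                              // assumed current contents of the target byte
  const int64_t srcIdx = 3;                         // store into element 3 of memref<5xi4>
  const int64_t scale = dstBits / srcBits;          // 2 i4 elements per i8 byte
  const int64_t byteIdx = srcIdx / scale;           // scaled index (s0 floordiv 2) -> 1
  const int bitOffset = (srcIdx % scale) * srcBits; // bit offset within the byte -> 4

  const uint8_t maskRightAligned = (1u << srcBits) - 1;           // 0x0F
  const uint8_t writeMaskInverse = maskRightAligned << bitOffset; // 0xF0
  const uint8_t writeMask = writeMaskInverse ^ 0xFF;              // 0x0F (xori with -1)
  const uint8_t storeVal = uint8_t{0xA} << bitOffset;             // value aligned to bits [4, 8): 0xA0

  byte &= writeMask; // mirrors memref.atomic_rmw andi: clears bits [4, 8)
  byte |= storeVal;  // mirrors memref.atomic_rmw ori: writes the nibble
  std::printf("byte %d becomes 0x%02X\n", (int)byteIdx, (unsigned)byte); // prints 0xAC
  return 0;
}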