diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index dec5936fa7e83..e5801c3733ed5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -112,18 +112,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
 namespace {
 
 //===----------------------------------------------------------------------===//
-// ConvertMemRefAlloc
+// ConvertMemRefAllocation
 //===----------------------------------------------------------------------===//
 
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto currentType = op.getMemref().getType().cast<MemRefType>();
-    auto newResultType =
-        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+                      std::is_same<OpTy, memref::AllocaOp>(),
+                  "expected only memref::AllocOp or memref::AllocaOp");
+    auto currentType = cast<MemRefType>(op.getMemref().getType());
+    auto newResultType = dyn_cast<MemRefType>(
+        this->getTypeConverter()->convertType(op.getType()));
     if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -132,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
 
     // Special case zero-rank memrefs.
     if (currentType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<memref::AllocOp>(
-          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
-          adaptor.getAlignmentAttr());
+      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+                                        adaptor.getSymbolOperands(),
+                                        adaptor.getAlignmentAttr());
       return success();
     }
 
@@ -156,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
           rewriter, loc, linearizedMemRefInfo.linearizedSize));
     }
 
-    rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
-        adaptor.getAlignmentAttr());
+    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+                                      adaptor.getSymbolOperands(),
+                                      adaptor.getAlignmentAttr());
     return success();
   }
 };
@@ -344,10 +348,11 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns
-      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
-           ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
-          typeConverter, patterns.getContext());
+  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
+               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemRefReinterpretCast>(typeConverter,
+                                             patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
 
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 2c411defb47e3..dc32a59a1a149 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -232,3 +232,36 @@ func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
 // CHECK32: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
 // CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
 // CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+  %0 = memref.alloca() : memref<5xi4>
+  %1 = memref.load %0[%arg0] : memref<5xi4>
+  return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
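
Usage note: for anyone wanting to exercise the new memref.alloca emulation out of tree, the sketch below shows one way these patterns are typically driven through the dialect conversion framework. arith::NarrowTypeEmulationConverter, populateMemRefNarrowTypeEmulationConversions, and populateMemRefNarrowTypeEmulationPatterns are the existing upstream entry points; the wrapper function and its legality setup are illustrative assumptions, not part of this patch.

// Minimal sketch (hypothetical helper, not part of this patch): emulate
// sub-byte memref element types, e.g. i4, on an i8 backing store. Using
// targetBitwidth=32 instead corresponds to the CHECK32 expectations above.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult emulateNarrowMemRefTypes(Operation *root) {
  MLIRContext *ctx = root->getContext();
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  // Register the memref type conversions, e.g. memref<5xi4> -> memref<3xi8>.
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  // Treat an op as legal once none of its operand/result types need
  // conversion any more.
  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
  return applyPartialConversion(root, target, std::move(patterns));
}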