Skip to content

Commit

Permalink
[mlir] Add support for memref.alloca sub-byte emulation (#73138)
Browse files Browse the repository at this point in the history
Extends the existing `memref.alloc` sub-byte (narrow-type) emulation pattern so it also handles `memref.alloca`, by templating the conversion pattern over the op type in EmulateNarrowType.

Fixes iree-org/iree#15515
  • Loading branch information
Max191 committed Nov 28, 2023
1 parent e1f911e commit b823f84
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 17 deletions.
39 changes: 22 additions & 17 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAlloc
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto currentType = op.getMemref().getType().cast<MemRefType>();
auto newResultType =
getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
static_assert(std::is_same<OpTy, memref::AllocOp>() ||
std::is_same<OpTy, memref::AllocaOp>(),
"expected only memref::AllocOp or memref::AllocaOp");
auto currentType = cast<MemRefType>(op.getMemref().getType());
auto newResultType = dyn_cast<MemRefType>(
this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
Expand All @@ -132,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {

// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
rewriter.replaceOpWithNewOp<memref::AllocOp>(
op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
return success();
}

Expand All @@ -156,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}

rewriter.replaceOpWithNewOp<memref::AllocOp>(
op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
return success();
}
};
Expand Down Expand Up @@ -344,10 +348,11 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {

// Populate `memref.*` conversion patterns.
patterns
.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
ConvertMemRefReinterpretCast>(typeConverter,
patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,36 @@ func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
// CHECK32: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
// CHECK32: return %[[TRUNC]]

// -----

// Regression test for sub-byte emulation of memref.alloca: allocates five i4
// values on the stack and loads one at a dynamic index. After the
// narrow-type emulation pass, the alloca is rewritten to wider backing
// storage (i8, or i32 under the 32-bit check prefix) and the load becomes a
// load + shift + truncate sequence, as pinned by the FileCheck lines below.
func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
%0 = memref.alloca() : memref<5xi4>
%1 = memref.load %0[%arg0] : memref<5xi4>
return %1 : i4
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
// CHECK: func @memref_alloca_load_i4(
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
// CHECK: return %[[TRUNC]]

// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
// CHECK32: func @memref_alloca_load_i4(
// CHECK32-SAME: %[[ARG0:.+]]: index
// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
// CHECK32: return %[[TRUNC]]

0 comments on commit b823f84

Please sign in to comment.