[mlir][Bufferize] Fix incorrect bufferization of rank-reducing tensor ops.

This revision fixes SubviewOp, InsertSliceOp, and ExtractSliceOp construction during bufferization,
where not all offset/size/stride operands were properly specified.

A test that exhibits problematic behavior related to incorrect memref casts is introduced.
Init tensor optimization is disabled in the testing func bufferize pass.

Differential Revision: https://reviews.llvm.org/D116899
Nicolas Vasilache committed Jan 10, 2022
1 parent 7543365 commit d0ee094
Showing 11 changed files with 162 additions and 42 deletions.
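Condensed from the hunks below, the core of the fix is to complete the offset/size/stride description of a rank-reducing slice to the full rank of the buffer being sliced before inferring the subview type and creating the subview. The following sketch is illustrative only and is not part of the commit: `b`, `loc`, `sliceOp`, `srcMemref`, and `resultRank` are stand-in names for values that exist inside the bufferization patterns changed here.

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

// Start from the (possibly shorter-than-rank) mixed operands of the slice op.
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();

// Complete the missing (most minor) dimensions with offset 0, the full
// dimension size, and stride 1, so all three vectors match srcMemref's rank.
OffsetSizeAndStrideOpInterface::expandToRank(
    srcMemref, offsets, sizes, strides,
    [&](Value target, int64_t dim) -> OpFoldResult {
      auto shapedType = target.getType().cast<ShapedType>();
      if (shapedType.isDynamicDim(dim))
        return b.create<memref::DimOp>(loc, target, dim).result();
      return b.getIndexAttr(shapedType.getDimSize(dim));
    });

// Infer the rank-reduced result type from the full-rank description and pass
// the same full-rank operands when creating the subview.
auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
                             resultRank, srcMemref.getType().cast<MemRefType>(),
                             offsets, sizes, strides)
                             .cast<MemRefType>();
Value subView = b.create<memref::SubViewOp>(loc, subviewMemRefType, srcMemref,
                                            offsets, sizes, strides);

The tensor-side rewrite for InsertSliceOp follows the same shape, using tensor::ExtractSliceOp and its inferRankReducedResultType instead of memref::SubViewOp.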
13 changes: 13 additions & 0 deletions mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -483,6 +483,19 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
::mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()};
return names;
}
/// Assume target is a shaped type and offsets/sizes/strides are vectors of
/// the same length, at most equal to target's rank.
/// Complete missing dims `i` with offset=0, size=dim(target, i), stride=1
/// until all vectors have size rank. The completion occurs for the most
/// minor dimensions (i.e. fastest varying).
/// Take a `createDim` lambda that knows how to build the size of a
/// particular dimension of `target` (to avoid dialect dependencies).
static void expandToRank(
Value target,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides,
llvm::function_ref<OpFoldResult(Value, int64_t)> createDim);
}];

let verify = [{
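As a concrete illustration of the completion rule documented above (not part of the commit; the shape and the names `dest`, `i`, and `b` are hypothetical), expanding a length-2 description against a value of type memref<4x1x6x8xf32>:

// `dest` : Value of type memref<4x1x6x8xf32>, `i` : Value of index type,
// `b` : OpBuilder. Only the two outermost dimensions are described so far.
SmallVector<OpFoldResult> offsets = {i, b.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {b.getIndexAttr(1), b.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {b.getIndexAttr(1), b.getIndexAttr(1)};

OffsetSizeAndStrideOpInterface::expandToRank(
    dest, offsets, sizes, strides,
    [&](Value target, int64_t dim) -> OpFoldResult {
      // The shape is static here, so the dimension size can be read directly.
      return b.getIndexAttr(
          target.getType().cast<ShapedType>().getDimSize(dim));
    });

// The two most minor dimensions were completed:
//   offsets == {i, 0, 0, 0}, sizes == {1, 1, 6, 8}, strides == {1, 1, 1, 1}.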
@@ -347,6 +347,14 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
});
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
rankedTensorType.getRank()) &&
"to_memref would be invalid: mismatching ranks");
}

static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");

@@ -364,6 +372,7 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
memrefType = getUnrankedMemRefType(
tensor.getType().cast<TensorType>().getElementType());
}
ensureToMemrefOpIsValid(tensor, memrefType);
return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
tensor);
}
@@ -563,10 +563,26 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
},
/*rewriteFunc=*/
[](OpBuilder &b, Location loc, OpOperand &operand) {
- auto insertSliceOp = cast<tensor::InsertSliceOp>(operand.getOwner());
auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
// Expand offsets, sizes and strides to the full rank to handle the
// rank-reducing case.
SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
OffsetSizeAndStrideOpInterface::expandToRank(
insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
[&](Value target, int64_t dim) -> OpFoldResult {
auto shapedType = target.getType().cast<ShapedType>();
if (shapedType.isDynamicDim(dim))
return b.create<tensor::DimOp>(loc, target, dim).result();
return b.getIndexAttr(shapedType.getDimSize(dim));
});
auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
insertOp.getSourceType().getRank(),
insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
mixedSizes, mixedStrides);
auto extractOp = b.create<tensor::ExtractSliceOp>(
- loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
return extractOp.result();
},
newOps);
@@ -19,6 +19,14 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace scf_ext {

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
assert((!rankedTensorType || (memrefType.cast<MemRefType>().getRank() ==
rankedTensorType.getRank())) &&
"to_memref would be invalid: mismatching ranks");
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
@@ -159,6 +167,8 @@ struct IfOpInterface
SmallVector<Value> thenYieldValues;
for (OpOperand &operand : thenYieldOp->getOpOperands()) {
if (operand.get().getType().isa<TensorType>()) {
ensureToMemrefOpIsValid(operand.get(),
newTypes[operand.getOperandNumber()]);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
operand.get().getLoc(), newTypes[operand.getOperandNumber()],
operand.get());
@@ -172,6 +182,8 @@ struct IfOpInterface
SmallVector<Value> elseYieldValues;
for (OpOperand &operand : elseYieldOp->getOpOperands()) {
if (operand.get().getType().isa<TensorType>()) {
ensureToMemrefOpIsValid(operand.get(),
newTypes[operand.getOperandNumber()]);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
operand.get().getLoc(), newTypes[operand.getOperandNumber()],
operand.get());
@@ -317,6 +329,7 @@ struct ForOpInterface
rewriter.setInsertionPoint(yieldOp);
SmallVector<Value> yieldValues =
convert(yieldOp.getResults(), [&](Value val, int64_t index) {
ensureToMemrefOpIsValid(val, initArgs[index].getType());
return rewriter.create<bufferization::ToMemrefOp>(
val.getLoc(), initArgs[index].getType(), val);
});
@@ -68,7 +68,7 @@ struct CastOpInterface

// Compute the new memref type.
Type resultMemRefType;
- if (auto rankedTensorType = resultTensorType.isa<RankedTensorType>()) {
if (resultTensorType.isa<RankedTensorType>()) {
resultMemRefType =
getContiguousMemRefType(resultTensorType, layout, memorySpace);
} else {
@@ -165,16 +165,27 @@ struct ExtractSliceOpInterface
alloc = *allocOrFailure;
}

// Expand offsets, sizes and strides to the full rank to handle the
// rank-reducing case.
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
OffsetSizeAndStrideOpInterface::expandToRank(
srcMemref, mixedOffsets, mixedSizes, mixedStrides,
[&](Value target, int64_t dim) -> OpFoldResult {
auto shapedType = target.getType().cast<ShapedType>();
if (shapedType.isDynamicDim(dim))
return rewriter.create<memref::DimOp>(loc, target, dim).result();
return rewriter.getIndexAttr(shapedType.getDimSize(dim));
});
// Bufferize to subview.
- auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
- dstTensorType.getRank(), srcMemrefType,
- extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
- extractSliceOp.getMixedStrides())
- .cast<MemRefType>();
auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
dstTensorType.getRank(), srcMemrefType,
mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
- extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
mixedStrides);

// If not inplaceable, copy.
if (!inplace) {
@@ -422,17 +433,29 @@ struct InsertSliceOpInterface
if (failed(dstMemref))
return failure();

// Expand offsets, sizes and strides to the full rank to handle the
// rank-reducing case.
SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
OffsetSizeAndStrideOpInterface::expandToRank(
*dstMemref, mixedOffsets, mixedSizes, mixedStrides,
[&](Value target, int64_t dim) -> OpFoldResult {
auto shapedType = target.getType().cast<ShapedType>();
if (shapedType.isDynamicDim(dim))
return rewriter.create<memref::DimOp>(loc, target, dim).result();
return rewriter.getIndexAttr(shapedType.getDimSize(dim));
});
// Take a subview of the dst.
auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstMemrefType,
- insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
- insertSliceOp.getMixedStrides())
mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, *dstMemref, insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
mixedStrides);

// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
@@ -96,6 +96,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
options->addPostAnalysisStep<
linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
}

if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();

19 changes: 19 additions & 0 deletions mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -176,3 +176,22 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
return false;
return true;
}

void OffsetSizeAndStrideOpInterface::expandToRank(
Value target, SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes, SmallVector<OpFoldResult> &strides,
llvm::function_ref<OpFoldResult(Value, int64_t)> createOrFoldDim) {
auto shapedType = target.getType().cast<ShapedType>();
unsigned rank = shapedType.getRank();
assert(offsets.size() == sizes.size() && "mismatched lengths");
assert(offsets.size() == strides.size() && "mismatched lengths");
assert(offsets.size() <= rank && "rank overflow");
MLIRContext *ctx = target.getContext();
Attribute zero = IntegerAttr::get(IndexType::get(ctx), APInt(64, 0));
Attribute one = IntegerAttr::get(IndexType::get(ctx), APInt(64, 1));
for (unsigned i = offsets.size(); i < rank; ++i) {
offsets.push_back(zero);
sizes.push_back(createOrFoldDim(target, i));
strides.push_back(one);
}
}
29 changes: 29 additions & 0 deletions mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -67,3 +67,32 @@ func private @private_func(tensor<?xf32>) -> ()
func @empty_func() -> () {
return
}

// -----

// CHECK-LABEL: func @rank_reducing
func @rank_reducing(
%i: index, %j: index,
%arg0: tensor<8x18x32xf32>)
-> tensor<?x1x6x8xf32> {
%c1 = arith.constant 1 : index
%c6 = arith.constant 6 : index
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
%1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
%2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
%5 = scf.for %arg7 = %c0 to %c32 step %c8 iter_args(%arg8 = %1) -> (tensor<?x1x6x8xf32>) {
%7 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg7)
%8 = tensor.extract_slice %arg0[%i, %j, %arg7] [1, 6, 8] [1, 1, 1] : tensor<8x18x32xf32> to tensor<1x6x8xf32>
%9 = scf.for %arg9 = %c0 to %c6 step %c1 iter_args(%arg10 = %2) -> (tensor<1x6x8xf32>) {
%11 = tensor.extract_slice %8[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x6x8xf32> to tensor<1x1x8xf32>
%12 = tensor.insert_slice %11 into %arg10[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x1x8xf32> into tensor<1x6x8xf32>
scf.yield %12 : tensor<1x6x8xf32>
}
%10 = tensor.insert_slice %9 into %arg8[%7, 0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] : tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
scf.yield %10 : tensor<?x1x6x8xf32>
}
return %5: tensor<?x1x6x8xf32>
}
@@ -1710,26 +1710,3 @@ func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf
}
return %1: tensor<?xf32>
}

// -----

//===----------------------------------------------------------------------===//
// InitTensorOp elimination would produce SSA violations for the example below.
//===----------------------------------------------------------------------===//

func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
-> tensor<?x1x6x8xf32> {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
%1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
%2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
%3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
%4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
%5 = tensor.insert_slice %2 into %arg4[%4,0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
scf.yield %5 : tensor<?x1x6x8xf32>
}
return %3 : tensor<?x1x6x8xf32>
}
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1199,3 +1199,26 @@ func @op_is_reading_but_following_ops_are_not(
// CHECK: return %[[ALLOC]]
return %r1 : tensor<?xf32>
}

// -----

//===----------------------------------------------------------------------===//
// InitTensorOp elimination would produce SSA violations for the example below.
//===----------------------------------------------------------------------===//

func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32>)
-> tensor<?x1x6x8xf32> {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
%1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
%2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
%3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
%4 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg3)
%5 = tensor.insert_slice %2 into %arg4[%4,0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] :
tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
scf.yield %5 : tensor<?x1x6x8xf32>
}
return %3 : tensor<?x1x6x8xf32>
}
3 changes: 0 additions & 3 deletions mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -96,9 +96,6 @@ struct TestComprehensiveFunctionBufferize
void TestComprehensiveFunctionBufferize::runOnFunction() {
auto options = std::make_unique<BufferizationOptions>();

// Enable InitTensorOp elimination.
options->addPostAnalysisStep<
linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();

