Skip to content

Commit

Permalink
[mlir][memref] Fix crash in SubViewReturnTypeCanonicalizer
Browse files Browse the repository at this point in the history
`SubViewReturnTypeCanonicalizer` is used by `OpWithOffsetSizesAndStridesConstantArgumentFolder`, which folds constant SSA value (dynamic) sizes into static sizes. The previous implementation crashed when a dynamic size was folded into a static `1` dimension, which was then mistaken as a rank reduction.

Differential Revision: https://reviews.llvm.org/D158721
  • Loading branch information
matthias-springer authored and tru committed Aug 31, 2023
1 parent 08d720d commit ad5ed49
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
64 changes: 33 additions & 31 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,17 @@ namespace {
namespace saturated_arith {
struct Wrapper {
static Wrapper stride(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
: Wrapper{false, v};
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
static Wrapper offset(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
: Wrapper{false, v};
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
static Wrapper size(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
int64_t asOffset() {
return saturated ? ShapedType::kDynamic : v;
}
int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
int64_t asStride() {
return saturated ? ShapedType::kDynamic : v;
}
int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
bool operator==(Wrapper other) {
return (saturated && other.saturated) ||
(!saturated && !other.saturated && v == other.v);
Expand Down Expand Up @@ -732,8 +726,7 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
if (ShapedType::isDynamic(ss) &&
!ShapedType::isDynamic(st))
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
return false;
}

Expand Down Expand Up @@ -766,8 +759,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// same. They are also compatible if either one is dynamic (see
// description of MemRefCastOp for details).
auto checkCompatible = [](int64_t a, int64_t b) {
return (ShapedType::isDynamic(a) ||
ShapedType::isDynamic(b) || a == b);
return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
};
if (!checkCompatible(aOffset, bOffset))
return false;
Expand Down Expand Up @@ -1890,8 +1882,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front();
if (!ShapedType::isDynamic(resultOffset) &&
!ShapedType::isDynamic(expectedOffset) &&
resultOffset != expectedOffset)
!ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
return emitError("expected result type with offset = ")
<< expectedOffset << " instead of " << resultOffset;

Expand Down Expand Up @@ -2945,18 +2936,6 @@ static MemRefType getCanonicalSubViewResultType(
nonRankReducedType.getMemorySpace());
}

/// Convenience overload of the three-type `getCanonicalSubViewResultType`.
/// Infers the result type of a SubViewOp via `inferResultType` and, when
/// `currentResultType` has lower rank than `sourceType`, reduces the rank of
/// the inferred type accordingly.
static MemRefType getCanonicalSubViewResultType(
    MemRefType currentResultType, MemRefType sourceType,
    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
    ArrayRef<OpFoldResult> mixedStrides) {
  // Delegate to the full overload, passing `sourceType` both as the original
  // source type and as the non-rank-reduced reference type.
  return getCanonicalSubViewResultType(
      currentResultType, sourceType, sourceType, mixedOffsets, mixedSizes,
      mixedStrides);
}

Value mlir::memref::createCanonicalRankReducingSubViewOp(
OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
auto memrefType = llvm::cast<MemRefType>(memref.getType());
Expand Down Expand Up @@ -3109,9 +3088,32 @@ struct SubViewReturnTypeCanonicalizer {
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                      ArrayRef<OpFoldResult> mixedSizes,
                      ArrayRef<OpFoldResult> mixedStrides) {
  // Start from the fully-ranked type that `inferResultType` produces for the
  // given offsets/sizes/strides; no rank reduction is applied at this point.
  auto inferredType = cast<MemRefType>(SubViewOp::inferResultType(
      op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));

  // If the op drops no dimensions, the inferred type is already canonical.
  llvm::SmallBitVector dropped = op.getDroppedDims();
  if (dropped.empty())
    return inferredType;

  // Extract the layout (strides and offset) of the non-rank-reduced type.
  auto [inferredStrides, inferredOffset] = getStridesAndOffset(inferredType);

  // Keep only the dimensions — and their corresponding strides — that
  // survive the rank reduction.
  SmallVector<int64_t> reducedShape;
  SmallVector<int64_t> reducedStrides;
  for (unsigned dim = 0, e = mixedSizes.size(); dim < e; ++dim) {
    if (!dropped.test(dim)) {
      reducedShape.push_back(inferredType.getDimSize(dim));
      reducedStrides.push_back(inferredStrides[dim]);
    }
  }

  return MemRefType::get(
      reducedShape, inferredType.getElementType(),
      StridedLayoutAttr::get(inferredType.getContext(), inferredOffset,
                             reducedStrides),
      inferredType.getMemorySpace());
}
};

Expand Down
17 changes: 16 additions & 1 deletion mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32

// -----

// CHECK-lABEL: func @ub_negative_alloc_size
// CHECK-LABEL: func private @ub_negative_alloc_size
func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
%idx1 = index.constant 1
%c-2 = arith.constant -2 : index
Expand All @@ -940,3 +940,18 @@ func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
%alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
return %alloc : memref<?x?x?xi1>
}

// -----

// Regression test: folding the constant %c1 into the subview's dynamic size
// produces a static `1` dimension; the return-type canonicalizer must not
// mistake that folded dimension for a rank reduction (it previously crashed).
// CHECK-LABEL: func @subview_rank_reduction(
// CHECK-SAME: %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
-> memref<?x?xf32, strided<[384, 1], offset: ?>> {
%c1 = arith.constant 1 : index
// The subview keeps its 3-D result type (leading static 1 retained); a cast
// then reconciles it with the function's 2-dynamic-dim return type.
// CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
// CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
%0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
: memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
// CHECK: return %[[cast]]
return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}

0 comments on commit ad5ed49

Please sign in to comment.