165 changes: 72 additions & 93 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -504,21 +504,25 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
return numOccurences;
}

/// Given the type of the un-rank reduced subview result type and the
/// rank-reduced result type, computes the dropped dimensions. This accounts for
/// cases where there are multiple unit-dims, but only a subset of those are
/// dropped. For MemRefTypes these can be disambiguated using the strides. If a
/// dimension is dropped the stride must be dropped too.
/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
/// to be a subset of `originalType` with some `1` entries erased, return the
/// set of indices that specifies which of the entries of `originalShape` are
/// dropped to obtain `reducedShape`.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped too.
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
ArrayAttr staticSizes) {
ArrayRef<OpFoldResult> sizes) {
llvm::SmallDenseSet<unsigned> unusedDims;
if (originalType.getRank() == reducedType.getRank())
return unusedDims;

for (auto dim : llvm::enumerate(staticSizes))
if (dim.value().cast<IntegerAttr>().getInt() == 1)
unusedDims.insert(dim.index());
for (auto dim : llvm::enumerate(sizes))
if (auto attr = dim.value().dyn_cast<Attribute>())
if (attr.cast<IntegerAttr>().getInt() == 1)
unusedDims.insert(dim.index());

SmallVector<int64_t> originalStrides, candidateStrides;
int64_t originalOffset, candidateOffset;
if (failed(
@@ -574,7 +578,7 @@ llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
MemRefType sourceType = getSourceType();
MemRefType resultType = getType();
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
assert(unusedDims && "unable to find unused dims of subview");
return *unusedDims;
}
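
For intuition, a small illustrative MLIR example (not part of this patch; the function and value names are invented): the mixed sizes contain two unit entries, but comparing the result strides against the source shows that only the leading dimension is dropped, so getDroppedDims() returns {0}.

func @drop_leading_unit_dim(%src: memref<1x16x1xf32>) -> memref<16x1xf32> {
  // Illustrative only: sizes [1, 16, 1] have unit entries at d0 and d2, yet
  // the result strides [1, 1] can only come from keeping d1 and d2.
  %0 = memref.subview %src[0, 0, 0] [1, 16, 1] [1, 1, 1]
      : memref<1x16x1xf32> to memref<16x1xf32>
  return %0 : memref<16x1xf32>
}
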
@@ -1546,8 +1550,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
staticStrides, ShapedType::kDynamicStrideOrOffset);
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
staticSizes, staticStrides)
.cast<MemRefType>();
staticSizes, staticStrides);
}

Type SubViewOp::inferRankReducedResultType(
@@ -1704,88 +1707,58 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }

enum SubViewVerificationResult {
Success,
RankTooLarge,
SizeMismatch,
ElemTypeMismatch,
MemSpaceMismatch,
AffineMapMismatch
};

/// Checks if `original` Type type can be rank reduced to `reduced` type.
/// This function is slight variant of `is subsequence` algorithm where
/// not matching dimension must be 1.
static SubViewVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
ArrayAttr staticSizes, std::string *errMsg = nullptr) {
if (originalType == candidateReducedType)
return SubViewVerificationResult::Success;
if (!originalType.isa<MemRefType>())
return SubViewVerificationResult::Success;
if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
return SubViewVerificationResult::Success;

ShapedType originalShapedType = originalType.cast<ShapedType>();
ShapedType candidateReducedShapedType =
candidateReducedType.cast<ShapedType>();

// Rank and size logic is valid for all ShapedTypes.
ArrayRef<int64_t> originalShape = originalShapedType.getShape();
ArrayRef<int64_t> candidateReducedShape =
candidateReducedShapedType.getShape();
unsigned originalRank = originalShape.size(),
candidateReducedRank = candidateReducedShape.size();
if (candidateReducedRank > originalRank)
return SubViewVerificationResult::RankTooLarge;
static SliceVerificationResult
isRankReducedMemRefType(MemRefType originalType,
MemRefType candidateReducedType,
ArrayRef<OpFoldResult> sizes) {
auto partialRes =
isRankReducedType(originalType, candidateReducedType);
if (partialRes != SliceVerificationResult::Success)
return partialRes;

MemRefType original = originalType.cast<MemRefType>();
MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
MemRefType candidateReduced =
candidateReducedType.cast<MemRefType>();

auto optionalUnusedDimsMask =
computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
computeMemRefRankReductionMask(original, candidateReduced, sizes);

// Sizes cannot be matched in case empty vector is returned.
if (!optionalUnusedDimsMask.hasValue())
return SubViewVerificationResult::SizeMismatch;
return SliceVerificationResult::LayoutMismatch;

if (originalShapedType.getElementType() !=
candidateReducedShapedType.getElementType())
return SubViewVerificationResult::ElemTypeMismatch;

// Strided layout logic is relevant for MemRefType only.
if (original.getMemorySpace() != candidateReduced.getMemorySpace())
return SubViewVerificationResult::MemSpaceMismatch;
return SubViewVerificationResult::Success;
return SliceVerificationResult::MemSpaceMismatch;

return SliceVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
OpTy op, Type expectedType,
StringRef errMsg = "") {
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
OpTy op, Type expectedType) {
auto memrefType = expectedType.cast<ShapedType>();
switch (result) {
case SubViewVerificationResult::Success:
case SliceVerificationResult::Success:
return success();
case SubViewVerificationResult::RankTooLarge:
case SliceVerificationResult::RankTooLarge:
return op.emitError("expected result rank to be smaller or equal to ")
<< "the source rank. " << errMsg;
case SubViewVerificationResult::SizeMismatch:
<< "the source rank. ";
case SliceVerificationResult::SizeMismatch:
return op.emitError("expected result type to be ")
<< expectedType
<< " or a rank-reduced version. (mismatch of result sizes) "
<< errMsg;
case SubViewVerificationResult::ElemTypeMismatch:
<< " or a rank-reduced version. (mismatch of result sizes) ";
case SliceVerificationResult::ElemTypeMismatch:
return op.emitError("expected result element type to be ")
<< memrefType.getElementType() << errMsg;
case SubViewVerificationResult::MemSpaceMismatch:
return op.emitError("expected result and source memory spaces to match.")
<< errMsg;
case SubViewVerificationResult::AffineMapMismatch:
<< memrefType.getElementType();
case SliceVerificationResult::MemSpaceMismatch:
return op.emitError("expected result and source memory spaces to match.");
case SliceVerificationResult::LayoutMismatch:
return op.emitError("expected result type to be ")
<< expectedType
<< " or a rank-reduced version. (mismatch of result affine map) "
<< errMsg;
<< " or a rank-reduced version. (mismatch of result layout) ";
}
llvm_unreachable("unexpected subview verification result");
}
@@ -1811,10 +1784,9 @@ static LogicalResult verify(SubViewOp op) {
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));

std::string errMsg;
auto result =
isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
return produceSubViewErrorMsg(result, op, expectedType, errMsg);
auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
subViewType, op.getMixedSizes());
return produceSubViewErrorMsg(result, op, expectedType);
}

raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
@@ -1854,21 +1826,29 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
/// Infer the canonical type of the result of a subview operation. Returns a
/// type with rank `resultRank` that is either the rank of the rank-reduced
/// type, or the non-rank-reduced type.
static MemRefType
getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
auto resultType =
SubViewOp::inferRankReducedResultType(
resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
if (resultType.getRank() != resultRank) {
resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
mixedSizes, mixedStrides)
.cast<MemRefType>();
static MemRefType getCanonicalSubViewResultType(
MemRefType currentResultType, MemRefType sourceType,
ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
mixedSizes, mixedStrides)
.cast<MemRefType>();
llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
// Return nullptr as failure mode.
if (!unusedDims)
return nullptr;
SmallVector<int64_t> shape;
for (auto sizes : llvm::enumerate(nonRankReducedType.getShape())) {
if (unusedDims->count(sizes.index()))
continue;
shape.push_back(sizes.value());
}
return resultType;
AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
if (!layoutMap.isIdentity())
layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
nonRankReducedType.getMemorySpace());
}
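
As a sketch of what this helper produces (illustrative MLIR, not from the patch; the names and the exact layout spelling are assumptions): the canonical result type of a rank-reducing subview is the non-rank-reduced inferred type with the unused unit dim removed from the shape and projected out of the layout map.

func @canonical_rank_reduced_type(%arg: memref<8x16x4xf32>, %i: index)
    -> memref<16x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>> {
  // Illustrative only. Non-rank-reduced inferred type:
  //   memref<1x16x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 64 + d1 * 4 + s0 + d2)>>
  // Dropping the unit dim d0 and projecting the layout yields the result type.
  %0 = memref.subview %arg[%i, 0, 0] [1, 16, 4] [1, 1, 1]
      : memref<8x16x4xf32>
        to memref<16x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>>
  return %0 : memref<16x4xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>>
}
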

namespace {
@@ -1911,8 +1891,7 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
/// the cast source operand type and the SubViewOp static information. This
/// is the resulting type if the MemRefCastOp were folded.
auto resultType = getCanonicalSubViewResultType(
subViewOp.getType().getRank(),
castOp.source().getType().cast<MemRefType>(),
subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
subViewOp.getMixedStrides());
Value newSubView = rewriter.create<SubViewOp>(
@@ -1931,9 +1910,9 @@ struct SubViewReturnTypeCanonicalizer {
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return getCanonicalSubViewResultType(op.getType().getRank(),
op.getSourceType(), mixedOffsets,
mixedSizes, mixedStrides);
return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
mixedOffsets, mixedSizes,
mixedStrides);
}
};

134 changes: 53 additions & 81 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -655,10 +656,11 @@ static LogicalResult verify(ReshapeOp op) {
/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case.
Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> leadingStaticOffsets,
ArrayRef<int64_t> leadingStaticSizes,
ArrayRef<int64_t> leadingStaticStrides) {
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> leadingStaticOffsets,
ArrayRef<int64_t> leadingStaticSizes,
ArrayRef<int64_t> leadingStaticStrides) {
// An extract_slice op may specify only a leading subset of offset/sizes/
// strides in which case we complete with offset=0, sizes from memref type and
// strides=1.
@@ -673,11 +675,11 @@ Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
sourceRankedTensorType.getElementType());
}

Type ExtractSliceOp::inferResultType(
RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> leadingStaticOffsets,
ArrayRef<OpFoldResult> leadingStaticSizes,
ArrayRef<OpFoldResult> leadingStaticStrides) {
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> leadingStaticOffsets,
ArrayRef<OpFoldResult> leadingStaticSizes,
ArrayRef<OpFoldResult> leadingStaticStrides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
@@ -693,7 +695,7 @@ Type ExtractSliceOp::inferResultType(
/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case.
Type ExtractSliceOp::inferRankReducedResultType(
RankedTensorType ExtractSliceOp::inferRankReducedResultType(
unsigned resultRank, RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> leadingStaticOffsets,
ArrayRef<int64_t> leadingStaticSizes,
@@ -717,7 +719,7 @@ Type ExtractSliceOp::inferRankReducedResultType(
return inferredType;
}

Type ExtractSliceOp::inferRankReducedResultType(
RankedTensorType ExtractSliceOp::inferRankReducedResultType(
unsigned resultRank, RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> leadingStaticOffsets,
ArrayRef<OpFoldResult> leadingStaticSizes,
@@ -746,10 +748,12 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
ShapedType::kDynamicStrideOrOffset);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
ShapedType::kDynamicSize);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamicStrideOrOffset);
auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
// Structuring implementation this way avoids duplication between builders.
@@ -797,89 +801,35 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

enum SliceVerificationResult {
Success,
RankTooLarge,
SizeMismatch,
ElemTypeMismatch,
};

/// Checks if `original` Type type can be rank reduced to `reduced` type.
/// This function is slight variant of `is subsequence` algorithm where
/// not matching dimension must be 1.
static SliceVerificationResult
isRankReducedType(Type originalType, Type candidateReducedType,
std::string *errMsg = nullptr) {
if (originalType == candidateReducedType)
return SliceVerificationResult::Success;
if (!originalType.isa<RankedTensorType>())
return SliceVerificationResult::Success;
if (originalType.isa<RankedTensorType>() &&
!candidateReducedType.isa<RankedTensorType>())
return SliceVerificationResult::Success;

ShapedType originalShapedType = originalType.cast<ShapedType>();
ShapedType candidateReducedShapedType =
candidateReducedType.cast<ShapedType>();

// Rank and size logic is valid for all ShapedTypes.
ArrayRef<int64_t> originalShape = originalShapedType.getShape();
ArrayRef<int64_t> candidateReducedShape =
candidateReducedShapedType.getShape();
unsigned originalRank = originalShape.size(),
candidateReducedRank = candidateReducedShape.size();
if (candidateReducedRank > originalRank)
return SliceVerificationResult::RankTooLarge;

auto optionalUnusedDimsMask =
computeRankReductionMask(originalShape, candidateReducedShape);

// Sizes cannot be matched in case empty vector is returned.
if (!optionalUnusedDimsMask.hasValue())
return SliceVerificationResult::SizeMismatch;

if (originalShapedType.getElementType() !=
candidateReducedShapedType.getElementType())
return SliceVerificationResult::ElemTypeMismatch;

// We are done for the tensor case.
if (originalType.isa<RankedTensorType>())
return SliceVerificationResult::Success;

return SliceVerificationResult::Success;
}

template <typename OpTy>
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
OpTy op, Type expectedType,
StringRef errMsg = "") {
OpTy op, Type expectedType) {
auto memrefType = expectedType.cast<ShapedType>();
switch (result) {
case SliceVerificationResult::Success:
return success();
case SliceVerificationResult::RankTooLarge:
return op.emitError("expected result rank to be smaller or equal to ")
<< "the source rank. " << errMsg;
return op.emitError("expected rank to be smaller or equal to ")
<< "the other rank. ";
case SliceVerificationResult::SizeMismatch:
return op.emitError("expected result type to be ")
<< expectedType
<< " or a rank-reduced version. (mismatch of result sizes) "
<< errMsg;
return op.emitError("expected type to be ")
<< expectedType << " or a rank-reduced version. (size mismatch) ";
case SliceVerificationResult::ElemTypeMismatch:
return op.emitError("expected result element type to be ")
<< memrefType.getElementType() << errMsg;
return op.emitError("expected element type to be ")
<< memrefType.getElementType();
default:
llvm_unreachable("unexpected extract_slice op verification result");
}
llvm_unreachable("unexpected extract_slice op verification result");
}

/// Verifier for ExtractSliceOp.
static LogicalResult verify(ExtractSliceOp op) {
// Verify result type against inferred type.
auto expectedType = ExtractSliceOp::inferResultType(
op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
auto result = isRankReducedType(expectedType, op.getType());
auto expectedType =
ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(),
op.getMixedSizes(), op.getMixedStrides());
auto result =
isRankReducedType(expectedType.cast<ShapedType>(), op.getType());
return produceSliceErrorMsg(result, op, expectedType);
}

@@ -1104,10 +1054,12 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
ShapedType::kDynamicStrideOrOffset);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
ShapedType::kDynamicSize);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamicStrideOrOffset);
build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getI64ArrayAttr(staticOffsets),
@@ -1128,6 +1080,19 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}

/// Verifier for InsertSliceOp.
static LogicalResult verify(InsertSliceOp op) {
// insert_slice is the inverse of extract_slice, use the same type inference.
auto expectedType = ExtractSliceOp::inferRankReducedResultType(
op.getSourceType().getRank(), op.getType(),
extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
auto result =
isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType());
return produceSliceErrorMsg(result, op, expectedType);
}
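
A small illustrative example of what the new verifier accepts (not part of the patch; it mirrors the ops.mlir test added further below): the source may be a rank-reduced version of the type inferred for the slice.

func @rank_reduced_insert(%src: tensor<4x4xf32>, %dst: tensor<8x16x4xf32>) -> tensor<8x16x4xf32> {
  // Illustrative only: the slice type inferred for sizes [4, 1, 4] is
  // tensor<4x1x4xf32>; tensor<4x4xf32> is an accepted rank-reduced version.
  %r = tensor.insert_slice %src into %dst[0, 2, 0] [4, 1, 4] [1, 1, 1]
      : tensor<4x4xf32> into tensor<8x16x4xf32>
  return %r : tensor<8x16x4xf32>
}
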

/// If we have two consecutive InsertSliceOp writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's.
///
@@ -1202,9 +1167,16 @@ class InsertSliceOpConstantArgumentFolder final
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);

// Create the new op in canonical form.
rewriter.replaceOpWithNewOp<InsertSliceOp>(
insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
auto sourceType = ExtractSliceOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
mixedOffsets, mixedSizes, mixedStrides);
Value toInsert = insertSliceOp.source();
if (sourceType != insertSliceOp.getSourceType())
toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
sourceType, toInsert);
rewriter.replaceOpWithNewOp<InsertSliceOp>(
insertSliceOp, toInsert, insertSliceOp.dest(), mixedOffsets, mixedSizes,
mixedStrides);
return success();
}
};
18 changes: 10 additions & 8 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -13,22 +13,24 @@

namespace mlir {

/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
/// Helper function to dispatch an OpFoldResult into `staticVec` if:
/// a) it is an IntegerAttr
/// In other cases, the OpFoldResult is dispatched to the `dynamicVec`.
/// In such dynamic cases, a copy of the `sentinel` value is also pushed to
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
/// come from an AttrSizedOperandSegments trait.
void dispatchIndexOpFoldResult(OpFoldResult ofr,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec,
int64_t sentinel) {
if (auto v = ofr.dyn_cast<Value>()) {
dynamicVec.push_back(v);
staticVec.push_back(sentinel);
auto v = ofr.dyn_cast<Value>();
if (!v) {
APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
staticVec.push_back(apInt.getSExtValue());
return;
}
APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
staticVec.push_back(apInt.getSExtValue());
dynamicVec.push_back(v);
staticVec.push_back(sentinel);
}
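
For context, an illustrative MLIR example (not from the patch; value names are invented) of the mixed form this helper supports: the constant size is dispatched into the static_sizes attribute, while %sz becomes a dynamic operand whose static slot records the ShapedType::kDynamicSize sentinel.

func @mixed_static_dynamic_sizes(%t: tensor<8x16xf32>, %sz: index) -> tensor<4x?xf32> {
  // Illustrative only: one static size (4) and one dynamic size (%sz).
  %0 = tensor.extract_slice %t[0, 0] [4, %sz] [1, 1]
      : tensor<8x16xf32> to tensor<4x?xf32>
  return %0 : tensor<4x?xf32>
}
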

void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
35 changes: 34 additions & 1 deletion mlir/lib/IR/BuiltinTypes.cpp
@@ -571,7 +571,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
llvm::SmallDenseSet<unsigned> unusedDims;
unsigned reducedIdx = 0;
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
// Greedily insert `originalIdx` if no match.
// Greedily insert `originalIdx` if match.
if (reducedIdx < reducedRank &&
originalShape[originalIdx] == reducedShape[reducedIdx]) {
reducedIdx++;
@@ -590,6 +590,39 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
return unusedDims;
}

SliceVerificationResult
mlir::isRankReducedType(ShapedType originalType,
ShapedType candidateReducedType) {
if (originalType == candidateReducedType)
return SliceVerificationResult::Success;

ShapedType originalShapedType = originalType.cast<ShapedType>();
ShapedType candidateReducedShapedType =
candidateReducedType.cast<ShapedType>();

// Rank and size logic is valid for all ShapedTypes.
ArrayRef<int64_t> originalShape = originalShapedType.getShape();
ArrayRef<int64_t> candidateReducedShape =
candidateReducedShapedType.getShape();
unsigned originalRank = originalShape.size(),
candidateReducedRank = candidateReducedShape.size();
if (candidateReducedRank > originalRank)
return SliceVerificationResult::RankTooLarge;

auto optionalUnusedDimsMask =
computeRankReductionMask(originalShape, candidateReducedShape);

// Sizes cannot be matched in case empty vector is returned.
if (!optionalUnusedDimsMask.hasValue())
return SliceVerificationResult::SizeMismatch;

if (originalShapedType.getElementType() !=
candidateReducedShapedType.getElementType())
return SliceVerificationResult::ElemTypeMismatch;

return SliceVerificationResult::Success;
}
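
An illustrative example of the shape-only check (not from the patch; names invented): the inferred type tensor<1x16x4xf32> rank-reduces to tensor<16x4xf32> because the only mismatched entry is a static 1; a result such as tensor<16x2xf32> would instead be reported as a size mismatch.

func @rank_reduced_extract(%t: tensor<8x16x4xf32>) -> tensor<16x4xf32> {
  // Illustrative only: tensor<16x4xf32> is a rank-reduced version of the
  // inferred tensor<1x16x4xf32>.
  %0 = tensor.extract_slice %t[0, 0, 0] [1, 16, 4] [1, 1, 1]
      : tensor<8x16x4xf32> to tensor<16x4xf32>
  return %0 : tensor<16x4xf32>
}
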

bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
// Empty attribute is allowed as default memory space.
if (!memorySpace)
22 changes: 4 additions & 18 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -820,38 +820,24 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
// CHECK: [[STRIDE:%.+]] = arith.constant 1
// CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
// CHECK: [[IDX0:%.+]] = arith.constant 0 : index
// CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]]
// CHECK: [[IDX1:%.+]] = arith.constant 1 : index
// CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]]
// CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg1, [[AXIS]]
// CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM0]], [[ARG1_AXIS]]
// CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1]
// CHECK: [[CST:%.+]] = arith.constant 0.0
// CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]])
// CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[AXIS]]
// CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
// CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM0]]
// CHECK: [[ARG1_DIM0:%.+]] = tensor.dim %arg1, [[AXIS]]
// CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
// CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1]
// CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg1 into [[INSERT0]][5, 0] [6, 1] [1, 1]
%0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>)

// CHECK: [[AXIS:%.+]] = arith.constant 1
// CHECK: [[STRIDE:%.+]] = arith.constant 1
// CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
// CHECK: [[IDX0:%.+]] = arith.constant 0 : index
// CHECK: [[ARG0_DIM0:%.+]] = tensor.dim %arg0, [[IDX0]]
// CHECK: [[IDX1:%.+]] = arith.constant 1 : index
// CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[IDX1]]
// CHECK: [[ARG1_AXIS:%.+]] = tensor.dim %arg0, [[AXIS]]
// CHECK: [[RESULT_AXIS:%.+]] = arith.addi [[ARG0_DIM1]], [[ARG1_AXIS]]
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2]
// CHECK: [[CST:%.+]] = arith.constant 0.0
// CHECK: [[FILL:%.+]] = linalg.fill([[CST]], [[INIT]])
// CHECK: [[ARG0_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]]
// CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
// CHECK: [[NEW_OFFSET:%.+]] = arith.addi [[OFFSET]], [[ARG0_DIM1]]
// CHECK: [[ARG1_DIM1:%.+]] = tensor.dim %arg0, [[AXIS]]
// CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
// CHECK: [[INSERT0:%.+]] = tensor.insert_slice %arg0 into [[FILL]][0, 0] [5, 1] [1, 1]
// CHECK: [[INSERT1:%.+]] = tensor.insert_slice %arg0 into [[INSERT0]][0, 1] [5, 1] [1, 1]
%1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
return
}
@@ -428,7 +428,9 @@ func @nested_extract_slice_and_insert(
%A : tensor<?x?xf32>,
%B : tensor<?x?xf32> {linalg.inplaceable = true},
%C : tensor<?x?xf32> {linalg.inplaceable = true},
%idx : index)
%idx : index,
%sz1 : index,
%sz2 : index)
-> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
{
%f0 = arith.constant 0.0 : f32
Expand Down Expand Up @@ -497,9 +499,9 @@ func @nested_extract_slice_and_insert(
// CHECK-NEXT: tensor.insert_slice
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
%sC = tensor.extract_slice %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%ssC = tensor.extract_slice %sC[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
%FC = linalg.fill(%f0, %ssC) : f32, tensor<4x4xf32> -> tensor<4x4xf32>
%rsC = tensor.insert_slice %FC into %sC[0, 0][12345, 67890][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
%ssC = tensor.extract_slice %sC[0, 0][%sz1, 4][1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
%FC = linalg.fill(%f0, %ssC) : f32, tensor<?x4xf32> -> tensor<?x4xf32>
%rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor<?x4xf32> into tensor<?x?xf32>
%rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>

return %rA, %rB, %rC: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -592,7 +592,7 @@ func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
linalg.yield %1 : f32
} -> tensor<4xf32>

%sum_sub = tensor.insert_slice %acc into %o_[%j][%c4][1]
%sum_sub = tensor.insert_slice %acc into %o_[%j][4][1]
: tensor<4xf32> into tensor<24xf32>
linalg.yield %sum_sub : tensor<24xf32>
}
24 changes: 23 additions & 1 deletion mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -47,7 +47,7 @@ func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,

// -----

#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?xf32, #map0>
{
@@ -395,3 +395,25 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
%collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
return %collapsed : memref<?x?xf32>
}

// -----

func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
-> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
%c0 = arith.constant 0 : index
%c5 = arith.constant 5 : index
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1]
: memref<2x5x7x1xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
%1 = memref.cast %0
: memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> to
memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
return %1 : memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
}

// CHECK-LABEL: func @reduced_memref
// CHECK: %[[RESULT:.+]] = memref.subview
// CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}>
// CHECK: return %[[RESULT]]
4 changes: 2 additions & 2 deletions mlir/test/Dialect/MemRef/invalid.mlir
@@ -644,7 +644,7 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
// -----

func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
// expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result sizes)}}
// expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result layout)}}
%0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
return
}
@@ -653,7 +653,7 @@ func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg

func @static_stride_to_dynamic_stride(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?xf32, offset:?, strides: [?, ?]> {
// expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}}
// expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result layout)}}
%0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref<?x?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
return %0 : memref<?x?xf32, offset: ?, strides: [?, ?]>
}
11 changes: 6 additions & 5 deletions mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -250,17 +250,17 @@ func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
// CHECK: scf.for
// CHECK: tensor.dim %[[t]]
func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
%t2 : tensor<?x?xf32>) -> index {
%t2 : tensor<10x10xf32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
-> (tensor<?x?xf32>, index) {
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
: tensor<10x10xf32> into tensor<?x?xf32>
%3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
: tensor<10x10xf32> into tensor<?x?xf32>
scf.yield %3, %dim : tensor<?x?xf32>, index
}
return %1 : index
@@ -274,7 +274,7 @@ func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
// CHECK: scf.for
// CHECK: tensor.dim %[[t]]
func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
%t2 : tensor<?x?xf32>) -> index {
%t2 : tensor<10x10xf32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
@@ -284,14 +284,15 @@ func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
-> (tensor<?x?xf32>, index) {
%dim = tensor.dim %arg2, %c0 : tensor<?x?xf32>
%4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1]
: tensor<?x?xf32> into tensor<?x?xf32>
: tensor<10x10xf32> into tensor<?x?xf32>
scf.yield %4, %dim : tensor<?x?xf32>, index
}
scf.yield %2, %3 : tensor<?x?xf32>, index
}
return %1 : index
}


// -----

// A test case that should not canonicalize because the loop is not shape
31 changes: 11 additions & 20 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -348,8 +348,10 @@ func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32x
// CHECK-NOT: tensor.cast
// CHECK: return %[[S]] : tensor<4x6x16x32xi8>
func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
%c0 = arith.constant 0: index
%cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
%res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
%sz = tensor.dim %cast, %c0: tensor<?x32xi8>
%res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, %sz] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
return %res : tensor<4x6x16x32xi8>
}

@@ -408,9 +410,10 @@ func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : i
}
// CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
// CHECK-SAME: : tensor<?x?xf32> into tensor<?x?x?xf32>
// CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32>
// CHECK: return %[[RESULT]]

// -----
@@ -450,7 +453,7 @@ func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i
^bb0(%arg4: index, %arg5: index):
tensor.yield %1 : i32
} : tensor<?x?xi32>
%3 = tensor.insert_slice %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
%3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor<?x?xi32>
return %3 : tensor<?x?xi32>
}
// CHECK-LABEL: func @insert_slice_propagate_dest_cast
@@ -462,17 +465,14 @@ func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i
// -----

func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c9 = arith.constant 9 : index
%c3 = arith.constant 3 : index
%2 = tensor.extract %arg1[] : tensor<i32>
%4 = tensor.generate %c3, %c9 {
^bb0(%arg2: index, %arg3: index):
tensor.yield %2 : i32
} : tensor<?x?xi32>
%5 = tensor.insert_slice %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
%5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
%6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
return %6 : tensor<3x9xi32>
}
@@ -527,8 +527,9 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
// CHECK: return %[[r]]
func @insert_tensor_cast_on_insert_slice_src(
%arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
%arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
%c64 = arith.constant 64: index
%r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
: tensor<?x5x?xf32> into tensor<?x?x?xf32>
return %r : tensor<?x?x?xf32>
}
@@ -559,13 +560,3 @@ func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8x
// CHECK: return %[[INSERT]]
return %1 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: func @folding_incorrect_ir_triggers_infinite_loop
func @folding_incorrect_ir_triggers_infinite_loop(
%A : tensor<4x4xf32>, %C : tensor<?x?xf32>) -> tensor<?x?xf32> {
%rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] :
tensor<4x4xf32> into tensor<?x?xf32>
return %rC: tensor<?x?xf32>
}
68 changes: 62 additions & 6 deletions mlir/test/Dialect/Tensor/invalid.mlir
@@ -149,8 +149,36 @@ func @tensor.reshape_num_elements_mismatch(

// -----

func @slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
// expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}}
func @extract_slice_wrong_result_rank(%t: tensor<?xf32>, %idx : index) {
// expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
%0 = tensor.extract_slice %t[0][4][1] : tensor<?xf32> to tensor<?x?xf32>

return
}

// -----

func @extract_slice_wrong_result_rank(%t: tensor<?xf32>, %idx : index) {
// expected-error @+1 {{expected element type to be 'f32'}}
%0 = tensor.extract_slice %t[0][4][1] : tensor<?xf32> to tensor<4xi8>

return
}

// -----

func @extract_slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
// expected-error @+1 {{expected type to be 'tensor<?x4x4xf32>' or a rank-reduced version. (size mismatch)}}
%0 = tensor.extract_slice %t[0, 0, 0][%idx, 4, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4x4xf32>

return
}

// -----

func @extract_slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {
// expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}}
%0 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<?x4x4xf32>

@@ -159,10 +187,38 @@ func @slice_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) {

// -----

func @slice_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) {
// expected-error @+1 {{expected result type to be 'tensor<?x3x?xf32>' or a rank-reduced version. (mismatch of result sizes)}}
%0 = tensor.extract_slice %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4x4xf32>
func @insert_slice_wrong_result_rank(%t1: tensor<?xf32>, %t2: tensor<?x?xf32>, %idx : index) {
// expected-error @+1 {{expected rank to be smaller or equal to the other rank.}}
%0 = tensor.insert_slice %t2 into %t1[0][4][1] : tensor<?x?xf32> into tensor<?xf32>

return
}

// -----

func @insert_slice_wrong_result_rank(%t1: tensor<4xi8>, %t2: tensor<?xf32>, %idx : index) {
// expected-error @+1 {{expected element type to be 'f32'}}
%0 = tensor.insert_slice %t1 into %t2[0][4][1] : tensor<4xi8> into tensor<?xf32>

return
}

// -----

func @insert_slice_wrong_static_type(%t1: tensor<4x4x4xf32>, %t2: tensor<8x16x4xf32>, %idx : index) {
// expected-error @+1 {{expected type to be 'tensor<?x4x4xf32>' or a rank-reduced version. (size mismatch)}}
%0 = tensor.insert_slice %t1 into %t2[0, 0, 0][%idx, 4, 4][1, 1, 1]
: tensor<4x4x4xf32> into tensor<8x16x4xf32>

return
}

// -----

func @insert_slice_wrong_dynamic_type(%t1: tensor<?x4x4xf32>, %t2: tensor<8x16x4xf32>, %idx : index) {
// expected-error @+1 {{expected type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (size mismatch)}}
%0 = tensor.insert_slice %t1 into %t2[0, 2, 0][4, 4, 4][1, 1, 1]
: tensor<?x4x4xf32> into tensor<8x16x4xf32>

return
}
57 changes: 57 additions & 0 deletions mlir/test/Dialect/Tensor/ops.mlir
@@ -78,3 +78,60 @@ func @tensor_reshape(%unranked: tensor<*xf32>, %shape1: tensor<1xi32>,
: (tensor<?x?xf32>, tensor<?xi32>) -> tensor<*xf32>
return %new_unranked : tensor<*xf32>
}

// CHECK-LABEL: func @slice({{.*}}) {
func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// CHECK: tensor.extract_slice
// CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
%1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
: tensor<8x16x4xf32> to tensor<?x?x?xf32>

// CHECK: tensor.extract_slice
// CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
%2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4x4xf32>

// CHECK: tensor.extract_slice
// CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
%3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4xf32>

return
}

// CHECK-LABEL: func @insert_slice({{.*}}) {
func @insert_slice(
%t: tensor<8x16x4xf32>,
%td: tensor<8x?x4xf32>,
%t2: tensor<16x32x8xf32>,
%t3: tensor<4x4xf32>,
%idx : index,
%sz : index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// CHECK: tensor.insert_slice
// CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
%1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][8, 16, 4][%c1, %c1, %c1]
: tensor<8x16x4xf32> into tensor<16x32x8xf32>

// CHECK: tensor.insert_slice
// CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
%2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][8, 16, 4][%c1, 1, %c1]
: tensor<8x16x4xf32> into tensor<16x32x8xf32>

// CHECK: tensor.insert_slice
// CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
%3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
: tensor<4x4xf32> into tensor<8x16x4xf32>

// CHECK: tensor.insert_slice
// CHECK-SAME: tensor<8x?x4xf32> into tensor<8x16x4xf32>
%4 = tensor.insert_slice %td into %t[0, %idx, 0][8, %sz, 4][1, 1, 1]
: tensor<8x?x4xf32> into tensor<8x16x4xf32>

return
}
50 changes: 0 additions & 50 deletions mlir/test/IR/core-ops.mlir
@@ -486,53 +486,3 @@ func @assume_alignment(%0: memref<4x4xf16>) {
memref.assume_alignment %0, 16 : memref<4x4xf16>
return
}

// CHECK-LABEL: func @slice({{.*}}) {
func @slice(%t: tensor<8x16x4xf32>, %idx : index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// CHECK: tensor.extract_slice
// CHECK-SAME: tensor<8x16x4xf32> to tensor<?x?x?xf32>
%1 = tensor.extract_slice %t[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
: tensor<8x16x4xf32> to tensor<?x?x?xf32>

// CHECK: tensor.extract_slice
// CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4x4xf32>
%2 = tensor.extract_slice %t[0, 2, 0][4, 4, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4x4xf32>

// CHECK: tensor.extract_slice
// CHECK-SAME: tensor<8x16x4xf32> to tensor<4x4xf32>
%3 = tensor.extract_slice %t[0, 2, 0][4, 1, 4][1, 1, 1]
: tensor<8x16x4xf32> to tensor<4x4xf32>

return
}

// CHECK-LABEL: func @insert_slice({{.*}}) {
func @insert_slice(
%t: tensor<8x16x4xf32>,
%t2: tensor<16x32x8xf32>,
%t3: tensor<4x4xf32>,
%idx : index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index

// CHECK: tensor.insert_slice
// CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
%1 = tensor.insert_slice %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
: tensor<8x16x4xf32> into tensor<16x32x8xf32>

// CHECK: tensor.insert_slice
// CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
%2 = tensor.insert_slice %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
: tensor<8x16x4xf32> into tensor<16x32x8xf32>

// CHECK: tensor.insert_slice
// CHECK-SAME: tensor<4x4xf32> into tensor<8x16x4xf32>
%3 = tensor.insert_slice %t3 into %t[0, 2, 0][4, 1, 4][1, 1, 1]
: tensor<4x4xf32> into tensor<8x16x4xf32>

return
}