diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index f79e955437ddbd..813dd7db5e9e38 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2504,7 +2504,7 @@ def SubIOp : IntArithmeticOp<"subi"> { //===----------------------------------------------------------------------===// def SubViewOp : Std_Op<"subview", [ - AttrSizedOperandSegments, + AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect, ]> { @@ -2516,17 +2516,14 @@ def SubViewOp : Std_Op<"subview", [ The SubView operation supports the following arguments: *) Memref: the "base" memref on which to create a "view" memref. - *) Offsets: zero or memref-rank number of dynamic offsets into the "base" - memref at which to create the "view" memref. - *) Sizes: zero or memref-rank dynamic size operands which specify the - dynamic sizes of the result "view" memref type. - *) Strides: zero or memref-rank number of dynamic strides which are applied - multiplicatively to the base memref strides in each dimension. - - Note on the number of operands for offsets, sizes and strides: For - each of these, the number of operands must either be same as the - memref-rank number or empty. For the latter, those values will be - treated as constants. + *) Offsets: memref-rank number of dynamic offsets or static integer + attributes into the "base" memref at which to create the "view" + memref. + *) Sizes: memref-rank number of dynamic sizes or static integer attributes + which specify the sizes of the result "view" memref type. + *) Strides: memref-rank number of dynamic strides or static integer + attributes multiplicatively to the base memref strides in each + dimension. Example 1: @@ -2537,7 +2534,7 @@ def SubViewOp : Std_Op<"subview", [ // dynamic sizes for each dimension, and stride arguments '%c1'. %1 = subview %0[%c0, %c0][%size0, %size1][%c1, %c1] : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1) > to - memref (d0 * s1 + d1 + s0)> + memref (d0 * s1 + d1 * s2 + s0)> ``` Example 2: @@ -2564,9 +2561,9 @@ def SubViewOp : Std_Op<"subview", [ %0 = alloc() : memref<8x16x4xf32, (d0, d1, d1) -> (d0 * 64 + d1 * 4 + d2)> // Subview with constant offsets, sizes and strides. - %1 = subview %0[][][] + %1 = subview %0[0, 2, 0][4, 4, 4][1, 1, 1] : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to - memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)> + memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)> ``` Example 4: @@ -2608,7 +2605,7 @@ def SubViewOp : Std_Op<"subview", [ // #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0) // // where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1. - %1 = subview %0[%i, %j][][%x, %y] : + %1 = subview %0[%i, %j][4, 4][%x, %y] : : memref (d0 * s1 + d1 * s2 + s0)> to memref<4x4xf32, (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)> @@ -2624,24 +2621,25 @@ def SubViewOp : Std_Op<"subview", [ AnyMemRef:$source, Variadic:$offsets, Variadic:$sizes, - Variadic:$strides + Variadic:$strides, + I64ArrayAttr:$static_offsets, + I64ArrayAttr:$static_sizes, + I64ArrayAttr:$static_strides ); let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $source `[` $offsets `]` `[` $sizes `]` `[` $strides `]` attr-dict `:` - type($source) `to` type($result) - }]; - let builders = [ + // Build a SubViewOp with mized static and dynamic entries. 
OpBuilder< "OpBuilder &b, OperationState &result, Value source, " - "ValueRange offsets, ValueRange sizes, " - "ValueRange strides, Type resultType = Type(), " - "ArrayRef attrs = {}">, + "ArrayRef staticOffsets, ArrayRef staticSizes," + "ArrayRef staticStrides, ValueRange offsets, ValueRange sizes, " + "ValueRange strides, ArrayRef attrs = {}">, + // Build a SubViewOp with all dynamic entries. OpBuilder< - "OpBuilder &builder, OperationState &result, " - "Type resultType, Value source"> + "OpBuilder &b, OperationState &result, Value source, " + "ValueRange offsets, ValueRange sizes, ValueRange strides, " + "ArrayRef attrs = {}"> ]; let extraClassDeclaration = [{ @@ -2670,13 +2668,34 @@ def SubViewOp : Std_Op<"subview", [ /// operands could not be retrieved. LogicalResult getStaticStrides(SmallVectorImpl &staticStrides); - // Auxiliary range data structure and helper function that unpacks the - // offset, size and stride operands of the SubViewOp into a list of triples. - // Such a list of triple is sometimes more convenient to manipulate. + /// Auxiliary range data structure and helper function that unpacks the + /// offset, size and stride operands of the SubViewOp into a list of triples. + /// Such a list of triple is sometimes more convenient to manipulate. struct Range { Value offset, size, stride; }; SmallVector getRanges(); + + /// Return the rank of the result MemRefType. + unsigned getRank() { return getType().getRank(); } + + static StringRef getStaticOffsetsAttrName() { + return "static_offsets"; + } + static StringRef getStaticSizesAttrName() { + return "static_sizes"; + } + static StringRef getStaticStridesAttrName() { + return "static_strides"; + } + static ArrayRef getSpecialAttrNames() { + static SmallVector names{ + getStaticOffsetsAttrName(), + getStaticSizesAttrName(), + getStaticStridesAttrName(), + getOperandSegmentSizeAttr()}; + return names; + } }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 553be944ab30d4..39dc5d203a6180 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1293,7 +1293,7 @@ static LogicalResult verify(DimOp op) { auto indexAttr = op.getAttrOfType("index"); if (!indexAttr) return op.emitOpError("requires an integer attribute named 'index'"); - int64_t index = indexAttr.getValue().getSExtValue(); + int64_t index = indexAttr.getInt(); auto type = op.getOperand().getType(); if (auto tensorType = type.dyn_cast()) { @@ -2183,59 +2183,272 @@ OpFoldResult SubIOp::fold(ArrayRef operands) { // SubViewOp //===----------------------------------------------------------------------===// -// Returns a MemRefType with dynamic sizes and offset and the same stride as the -// `memRefType` passed as argument. -// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep -// sizes and offset static. -static Type inferSubViewResultType(MemRefType memRefType) { - auto rank = memRefType.getRank(); - int64_t offset; - SmallVector strides; - auto res = getStridesAndOffset(memRefType, strides, offset); +/// Print a list with either (1) the static integer value in `arrayAttr` if +/// `isDynamic` evaluates to false or (2) the next value otherwise. +/// This allows idiomatic printing of mixed value and integer attributes in a +/// list. E.g. `[%arg0, 7, 42, %arg42]`. 
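Concretely, a subview with a single dynamic offset round-trips through this printer/parser as sketched below; the `8x16` source type and the `%0`/`%i` names are chosen purely for illustration and are not taken from the patch:

```
// The dynamic offset %i prints in place of its sentinel entry; every other
// value comes from the static_offsets/static_sizes/static_strides attributes.
%1 = subview %0[%i, 0] [4, 4] [1, 1]
  : memref<8x16xf32, offset: 0, strides: [16, 1]>
  to memref<4x4xf32, offset: ?, strides: [16, 1]>
```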
+static void printSubViewListOfOperandsOrIntegers( + OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, + llvm::function_ref isDynamic) { + p << "["; + unsigned idx = 0; + llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { + int64_t val = a.cast().getInt(); + if (isDynamic(val)) + p << values[idx++]; + else + p << val; + }); + p << "] "; +} + +/// Parse a mixed list with either (1) static integer values or (2) SSA values. +/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` +/// encode the position of SSA values. Add the parsed SSA values to `ssa` +/// in-order. +// +/// E.g. after parsing "[%arg0, 7, 42, %arg42]": +/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" +/// 2. `ssa` is filled with "[%arg0, %arg1]". +static ParseResult +parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, + StringRef attrName, int64_t dynVal, + SmallVectorImpl &ssa) { + if (failed(parser.parseLSquare())) + return failure(); + // 0-D. + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + SmallVector attrVals; + while (true) { + OpAsmParser::OperandType operand; + auto res = parser.parseOptionalOperand(operand); + if (res.hasValue() && succeeded(res.getValue())) { + ssa.push_back(operand); + attrVals.push_back(dynVal); + } else { + Attribute attr; + NamedAttrList placeholder; + if (failed(parser.parseAttribute(attr, "_", placeholder)) || + !attr.isa()) + return parser.emitError(parser.getNameLoc()) + << "expected SSA value or integer"; + attrVals.push_back(attr.cast().getInt()); + } + + if (succeeded(parser.parseOptionalComma())) + continue; + if (failed(parser.parseRSquare())) + return failure(); + else + break; + } + + auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); + result.addAttribute(attrName, arrayAttr); + return success(); +} + +namespace { +/// Helpers to write more idiomatic operations. +namespace saturated_arith { +struct Wrapper { + explicit Wrapper(int64_t v) : v(v) {} + operator int64_t() { return v; } + int64_t v; +}; +Wrapper operator+(Wrapper a, int64_t b) { + if (ShapedType::isDynamicStrideOrOffset(a) || + ShapedType::isDynamicStrideOrOffset(b)) + return Wrapper(ShapedType::kDynamicStrideOrOffset); + return Wrapper(a.v + b); +} +Wrapper operator*(Wrapper a, int64_t b) { + if (ShapedType::isDynamicStrideOrOffset(a) || + ShapedType::isDynamicStrideOrOffset(b)) + return Wrapper(ShapedType::kDynamicStrideOrOffset); + return Wrapper(a.v * b); +} +} // end namespace saturated_arith +} // end namespace + +/// A subview result type can be fully inferred from the source type and the +/// static representation of offsets, sizes and strides. Special sentinels +/// encode the dynamic case. +static Type inferSubViewResultType(MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides) { + unsigned rank = sourceMemRefType.getRank(); + (void)rank; + assert(staticOffsets.size() == rank && + "unexpected staticOffsets size mismatch"); + assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch"); + assert(staticStrides.size() == rank && + "unexpected staticStrides size mismatch"); + + // Extract source offset and strides. + int64_t sourceOffset; + SmallVector sourceStrides; + auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); assert(succeeded(res) && "SubViewOp expected strided memref type"); (void)res; - // Assume sizes and offset are fully dynamic for now until canonicalization - // occurs on the ranges. 
Typed strides don't change though. - offset = MemRefType::getDynamicStrideOrOffset(); - // Overwrite strides because verifier will not pass. - // TODO(b/144419106): don't force degrade the strides to fully dynamic. - for (auto &stride : strides) - stride = MemRefType::getDynamicStrideOrOffset(); - auto stridedLayout = - makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); - SmallVector sizes(rank, ShapedType::kDynamicSize); - return MemRefType::Builder(memRefType) - .setShape(sizes) - .setAffineMaps(stridedLayout); + // Compute target offset whose value is: + // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. + int64_t targetOffset = sourceOffset; + for (auto it : llvm::zip(staticOffsets, sourceStrides)) { + auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); + using namespace saturated_arith; + targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; + } + + // Compute target stride whose value is: + // `sourceStrides_i * staticStrides_i`. + SmallVector targetStrides; + targetStrides.reserve(staticOffsets.size()); + for (auto it : llvm::zip(sourceStrides, staticStrides)) { + auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); + using namespace saturated_arith; + targetStrides.push_back(Wrapper(sourceStride) * staticStride); + } + + // The type is now known. + return MemRefType::get( + staticSizes, sourceMemRefType.getElementType(), + makeStridedLinearLayoutMap(targetStrides, targetOffset, + sourceMemRefType.getContext()), + sourceMemRefType.getMemorySpace()); +} + +/// Print SubViewOp in the form: +/// ``` +/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` +/// `:` strided-memref-type `to` strided-memref-type +/// ``` +static void print(OpAsmPrinter &p, SubViewOp op) { + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; + p << op.getOperand(0); + printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset); + printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), + ShapedType::isDynamic); + printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset); + p.printOptionalAttrDict(op.getAttrs(), + /*elided=*/{SubViewOp::getSpecialAttrNames()}); + p << " : " << op.getOperand(0).getType() << " to " << op.getType(); +} + +/// Parse SubViewOp of the form: +/// ``` +/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` +/// `:` strided-memref-type `to` strided-memref-type +/// ``` +static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector offsetsInfo, sizesInfo, stridesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + if (parser.parseOperand(srcInfo)) + return failure(); + if (parseListOfOperandsOrIntegers( + parser, result, SubViewOp::getStaticOffsetsAttrName(), + ShapedType::kDynamicStrideOrOffset, offsetsInfo) || + parseListOfOperandsOrIntegers(parser, result, + SubViewOp::getStaticSizesAttrName(), + ShapedType::kDynamicSize, sizesInfo) || + parseListOfOperandsOrIntegers( + parser, result, SubViewOp::getStaticStridesAttrName(), + ShapedType::kDynamicStrideOrOffset, stridesInfo)) + return failure(); + + auto b = parser.getBuilder(); + SmallVector segmentSizes{1, static_cast(offsetsInfo.size()), + static_cast(sizesInfo.size()), + 
static_cast(stridesInfo.size())}; + result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(), + b.getI32VectorAttr(segmentSizes)); + + return failure( + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(offsetsInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.resolveOperands(stridesInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); } void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source, - ValueRange offsets, ValueRange sizes, - ValueRange strides, Type resultType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides, ValueRange offsets, + ValueRange sizes, ValueRange strides, ArrayRef attrs) { - if (!resultType) - resultType = inferSubViewResultType(source.getType().cast()); - build(b, result, resultType, source, offsets, sizes, strides); + auto sourceMemRefType = source.getType().cast(); + auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets, + staticSizes, staticStrides); + build(b, result, resultType, source, offsets, sizes, strides, + b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), + b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } -void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, - Type resultType, Value source) { - build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, - resultType); +/// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes` +/// and `staticStrides` are automatically filled with source-memref-rank +/// sentinel values that encode dynamic entries. +void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source, + ValueRange offsets, ValueRange sizes, + ValueRange strides, + ArrayRef attrs) { + auto sourceMemRefType = source.getType().cast(); + unsigned rank = sourceMemRefType.getRank(); + SmallVector staticOffsetsVector; + staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset); + SmallVector staticSizesVector; + staticSizesVector.assign(rank, ShapedType::kDynamicSize); + SmallVector staticStridesVector; + staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset); + build(b, result, source, staticOffsetsVector, staticSizesVector, + staticStridesVector, offsets, sizes, strides, attrs); +} + +/// Verify that a particular offset/size/stride static attribute is well-formed. +static LogicalResult +verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName, + ArrayAttr attr, llvm::function_ref isDynamic, + ValueRange values) { + /// Check static and dynamic offsets/sizes/strides breakdown. + if (attr.size() != op.getRank()) + return op.emitError("expected ") + << op.getRank() << " " << name << " values"; + unsigned expectedNumDynamicEntries = + llvm::count_if(attr.getValue(), [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + if (values.size() != expectedNumDynamicEntries) + return op.emitError("expected ") + << expectedNumDynamicEntries << " dynamic " << name << " values"; + return success(); +} + +/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr. 
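For the all-dynamic builder above, each static attribute is filled with rank-many sentinel values, so every offset/size/stride is an SSA operand and the inferred result type is fully dynamic. A sketch (the source type matches the updated StandardToLLVM tests below; operand names are illustrative):

```
// Rank-2 source with all entries dynamic: the static_* attributes carry only
// sentinels, and the inferred type has dynamic shape, strides and offset.
%1 = subview %0[%i, %j] [%sz0, %sz1] [%st0, %st1]
  : memref<64x4xf32, offset: 0, strides: [4, 1]>
  to memref<?x?xf32, offset: ?, strides: [?, ?]>
```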
+static SmallVector extractFromI64ArrayAttr(Attribute attr) { + return llvm::to_vector<4>( + llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { + return a.cast().getInt(); + })); } +/// Verifier for SubViewOp. static LogicalResult verify(SubViewOp op) { auto baseType = op.getBaseMemRefType().cast(); auto subViewType = op.getType(); - // The rank of the base and result subview must match. - if (baseType.getRank() != subViewType.getRank()) { - return op.emitError( - "expected rank of result type to match rank of base type "); - } - // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != subViewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " @@ -2243,96 +2456,32 @@ static LogicalResult verify(SubViewOp op) { << baseType << " and subview memref type " << subViewType; // Verify that the base memref type has a strided layout map. - int64_t baseOffset; - SmallVector baseStrides; - if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) - return op.emitError("base type ") << subViewType << " is not strided"; - - // Verify that the result memref type has a strided layout map. - int64_t subViewOffset; - SmallVector subViewStrides; - if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) - return op.emitError("result type ") << subViewType << " is not strided"; - - // Num offsets should either be zero or rank of memref. - if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) { - return op.emitError("expected number of dynamic offsets specified to match " - "the rank of the result type ") - << subViewType; - } - - // Num sizes should either be zero or rank of memref. - if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) { - return op.emitError("expected number of dynamic sizes specified to match " - "the rank of the result type ") - << subViewType; - } - - // Num strides should either be zero or rank of memref. - if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) { - return op.emitError("expected number of dynamic strides specified to match " - "the rank of the result type ") - << subViewType; - } - - // Verify that if the shape of the subview type is static, then sizes are not - // dynamic values, and vice versa. - if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || - (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { - return op.emitError("invalid to specify dynamic sizes when subview result " - "type is statically shaped and viceversa"); - } + if (!isStrided(baseType)) + return op.emitError("base type ") << baseType << " is not strided"; - // Verify that if dynamic sizes are specified, then the result memref type - // have full dynamic dimensions. - if (op.getNumSizes() > 0) { - if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { - return dim != ShapedType::kDynamicSize; - })) { - // TODO: This is based on the assumption that number of size arguments are - // either 0, or the rank of the result type. It is possible to have more - // fine-grained verification where only particular dimensions are - // dynamic. That probably needs further changes to the shape op - // specification. - return op.emitError("expected shape of result type to be fully dynamic " - "when sizes are specified"); - } - } + // Verify static attributes offsets/sizes/strides. 
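Together these checks require rank-many entries in each list (static attribute values plus matching dynamic operands) and, just below, a declared result type that equals the inferred one. For instance, in the fully static case only the following result type is accepted (example borrowed from the updated lowering test further down):

```
// offset = 0 + 0*4 + 8*1 = 8, strides = [4*1, 1*1], sizes = [62, 3].
%1 = subview %0[0, 8] [62, 3] [1, 1]
  : memref<64x4xf32, offset: 0, strides: [4, 1]>
  to memref<62x3xf32, offset: 8, strides: [4, 1]>
```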
+ if (failed(verifySubViewOpPart( + op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset, op.offsets()))) + return failure(); - // Verify that if dynamic offsets are specified or base memref has dynamic - // offset or base memref has dynamic strides, then the subview offset is - // dynamic. - if ((op.getNumOffsets() > 0 || - baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::is_contained(baseStrides, - MemRefType::getDynamicStrideOrOffset())) && - subViewOffset != MemRefType::getDynamicStrideOrOffset()) { - return op.emitError( - "expected result memref layout map to have dynamic offset"); - } + if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(), + op.static_sizes(), ShapedType::isDynamic, + op.sizes()))) + return failure(); + if (failed(verifySubViewOpPart( + op, "stride", op.getStaticStridesAttrName(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset, op.strides()))) + return failure(); - // For now, verify that if dynamic strides are specified, then all the result - // memref type have dynamic strides. - if (op.getNumStrides() > 0) { - if (llvm::any_of(subViewStrides, [](int64_t stride) { - return stride != MemRefType::getDynamicStrideOrOffset(); - })) { - return op.emitError("expected result type to have dynamic strides"); - } - } + // Verify result type against inferred type. + auto expectedType = inferSubViewResultType( + op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()), + extractFromI64ArrayAttr(op.static_sizes()), + extractFromI64ArrayAttr(op.static_strides())); + if (op.getType() != expectedType) + return op.emitError("expected result type to be ") << expectedType; - // If any of the base memref has dynamic stride, then the corresponding - // stride of the subview must also have dynamic stride. - assert(baseStrides.size() == subViewStrides.size()); - for (auto stride : enumerate(baseStrides)) { - if (stride.value() == MemRefType::getDynamicStrideOrOffset() && - subViewStrides[stride.index()] != - MemRefType::getDynamicStrideOrOffset()) { - return op.emitError( - "expected result type to have dynamic stride along a dimension if " - "the base memref type has dynamic stride along that dimension"); - } - } return success(); } @@ -2353,37 +2502,9 @@ SmallVector SubViewOp::getRanges() { LogicalResult SubViewOp::getStaticStrides(SmallVectorImpl &staticStrides) { - // If the strides are dynamic return failure. - if (getNumStrides()) - return failure(); - - // When static, the stride operands can be retrieved by taking the strides of - // the result of the subview op, and dividing the strides of the base memref. 
- int64_t resultOffset, baseOffset; - SmallVector resultStrides, baseStrides; - if (failed( - getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) || - llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - failed(getStridesAndOffset(getType(), resultStrides, resultOffset))) + if (!strides().empty()) return failure(); - - assert(static_cast(resultStrides.size()) == getType().getRank() && - baseStrides.size() == resultStrides.size() && - "base and result memrefs must have the same rank"); - assert(!llvm::is_contained(resultStrides, - MemRefType::getDynamicStrideOrOffset()) && - "strides of subview op must be static, when there are no dynamic " - "strides specified"); - staticStrides.resize(getType().getRank()); - for (auto resultStride : enumerate(resultStrides)) { - auto baseStride = baseStrides[resultStride.index()]; - // The result stride is expected to be a multiple of the base stride. Abort - // if that is not the case. - if (resultStride.value() < baseStride || - resultStride.value() % baseStride != 0) - return failure(); - staticStrides[resultStride.index()] = resultStride.value() / baseStride; - } + staticStrides = extractFromI64ArrayAttr(static_strides()); return success(); } @@ -2391,136 +2512,80 @@ Value SubViewOp::getViewSource() { return source(); } namespace { -/// Pattern to rewrite a subview op with constant size arguments. -class SubViewOpShapeFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SubViewOp subViewOp, - PatternRewriter &rewriter) const override { - MemRefType subViewType = subViewOp.getType(); - // Follow all or nothing approach for shapes for now. If all the operands - // for sizes are constants then fold it into the type of the result memref. - if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.sizes(), [](Value operand) { - return !matchPattern(operand, m_ConstantIndex()); - })) { - return failure(); - } - SmallVector staticShape(subViewOp.getNumSizes()); - for (auto size : llvm::enumerate(subViewOp.sizes())) { - auto defOp = size.value().getDefiningOp(); - assert(defOp); - staticShape[size.index()] = cast(defOp).getValue(); +/// Take a list of `values` with potential new constant to extract and a list +/// of `constantValues` with`values.size()` sentinel that evaluate to true by +/// applying `isDynamic`. +/// Detects the `values` produced by a ConstantIndexOp and places the new +/// constant in place of the corresponding sentinel value. +void canonicalizeSubViewPart(SmallVectorImpl &values, + SmallVectorImpl &constantValues, + llvm::function_ref isDynamic) { + bool hasNewStaticValue = llvm::any_of( + values, [](Value val) { return matchPattern(val, m_ConstantIndex()); }); + if (hasNewStaticValue) { + for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size(); + cstIdx != e; ++cstIdx) { + // Was already static, skip. + if (!isDynamic(constantValues[cstIdx])) + continue; + // Newly static, move from Value to constant. + if (matchPattern(values[valIdx], m_ConstantIndex())) { + constantValues[cstIdx] = + cast(values[valIdx].getDefiningOp()).getValue(); + // Erase for impl. simplicity. Reverse iterator if we really must. + values.erase(std::next(values.begin(), valIdx)); + continue; + } + // Remains dynamic move to next value. 
+ ++valIdx; } - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setShape(staticShape); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - ArrayRef(), subViewOp.strides(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); - return success(); } -}; +} -// Pattern to rewrite a subview op with constant stride arguments. -class SubViewOpStrideFolder final : public OpRewritePattern { +/// Pattern to rewrite a subview op with constant arguments. +class SubViewOpFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { - if (subViewOp.getNumStrides() == 0) { - return failure(); - } - // Follow all or nothing approach for strides for now. If all the operands - // for strides are constants then fold it into the strides of the result - // memref. - int64_t baseOffset, resultOffset; - SmallVector baseStrides, resultStrides; - MemRefType subViewType = subViewOp.getType(); - if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, - baseOffset)) || - failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || - llvm::is_contained(baseStrides, - MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.strides(), [](Value stride) { - return !matchPattern(stride, m_ConstantIndex()); - })) { + // No constant operand, just return; + if (llvm::none_of(subViewOp.getOperands(), [](Value operand) { + return matchPattern(operand, m_ConstantIndex()); + })) return failure(); - } - SmallVector staticStrides(subViewOp.getNumStrides()); - for (auto stride : llvm::enumerate(subViewOp.strides())) { - auto defOp = stride.value().getDefiningOp(); - assert(defOp); - assert(baseStrides[stride.index()] > 0); - staticStrides[stride.index()] = - cast(defOp).getValue() * baseStrides[stride.index()]; - } - AffineMap layoutMap = makeStridedLinearLayoutMap( - staticStrides, resultOffset, rewriter.getContext()); - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setAffineMaps(layoutMap); + // At least one of offsets/sizes/strides is a new constant. + // Form the new list of operands and constant attributes from the existing. + SmallVector newOffsets(subViewOp.offsets()); + SmallVector newStaticOffsets = + extractFromI64ArrayAttr(subViewOp.static_offsets()); + assert(newStaticOffsets.size() == subViewOp.getRank()); + canonicalizeSubViewPart(newOffsets, newStaticOffsets, + ShapedType::isDynamicStrideOrOffset); + + SmallVector newSizes(subViewOp.sizes()); + SmallVector newStaticSizes = + extractFromI64ArrayAttr(subViewOp.static_sizes()); + assert(newStaticOffsets.size() == subViewOp.getRank()); + canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic); + + SmallVector newStrides(subViewOp.strides()); + SmallVector newStaticStrides = + extractFromI64ArrayAttr(subViewOp.static_strides()); + assert(newStaticOffsets.size() == subViewOp.getRank()); + canonicalizeSubViewPart(newStrides, newStaticStrides, + ShapedType::isDynamicStrideOrOffset); + + // Create the new op in canonical form. auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - subViewOp.sizes(), ArrayRef(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. 
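The effect of SubViewOpFolder is to migrate any operand defined by a ConstantIndexOp into the corresponding static attribute and to cast the refined result type back to the original one. A before/after sketch (the `8x16` source and the `%sz` operand are placeholders, not taken from the tests):

```
// Before folding: constant offsets and strides arrive as SSA operands.
%c0 = constant 0 : index
%c1 = constant 1 : index
%1 = subview %0[%c0, %c0] [%sz, 4] [%c1, %c1]
  : memref<8x16xf32> to memref<?x4xf32, offset: ?, strides: [?, ?]>
// After folding: the constants live in the static attributes, and a
// memref_cast reconciles the more static type with the existing uses of %1.
%2 = subview %0[0, 0] [%sz, 4] [1, 1]
  : memref<8x16xf32> to memref<?x4xf32, offset: 0, strides: [16, 1]>
%3 = memref_cast %2
  : memref<?x4xf32, offset: 0, strides: [16, 1]> to memref<?x4xf32, offset: ?, strides: [?, ?]>
```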
- rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); - return success(); - } -}; + subViewOp.getLoc(), subViewOp.source(), newStaticOffsets, + newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides); -// Pattern to rewrite a subview op with constant offset arguments. -class SubViewOpOffsetFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SubViewOp subViewOp, - PatternRewriter &rewriter) const override { - if (subViewOp.getNumOffsets() == 0) { - return failure(); - } - // Follow all or nothing approach for offsets for now. If all the operands - // for offsets are constants then fold it into the offset of the result - // memref. - int64_t baseOffset, resultOffset; - SmallVector baseStrides, resultStrides; - MemRefType subViewType = subViewOp.getType(); - if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, - baseOffset)) || - failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || - llvm::is_contained(baseStrides, - MemRefType::getDynamicStrideOrOffset()) || - baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.offsets(), [](Value stride) { - return !matchPattern(stride, m_ConstantIndex()); - })) { - return failure(); - } - - auto staticOffset = baseOffset; - for (auto offset : llvm::enumerate(subViewOp.offsets())) { - auto defOp = offset.value().getDefiningOp(); - assert(defOp); - assert(baseStrides[offset.index()] > 0); - staticOffset += - cast(defOp).getValue() * baseStrides[offset.index()]; - } - - AffineMap layoutMap = makeStridedLinearLayoutMap( - resultStrides, staticOffset, rewriter.getContext()); - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setAffineMaps(layoutMap); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), ArrayRef(), - subViewOp.sizes(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, subViewOp.getType()); + return success(); } }; @@ -2633,8 +2698,7 @@ OpFoldResult SubViewOp::fold(ArrayRef) { void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 4cc8e11294d788..41ed5315ab1c96 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -839,7 +839,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) { // CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32, // CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32, // CHECK32: %[[ARG2:.*]]: !llvm.i32) -func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) { +func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { // The last "insertvalue" that populates the memref descriptor from the function arguments. 
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -883,7 +883,8 @@ func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32 %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref (d0 * s1 + d1 * s2 + s0)>> + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref return } @@ -899,7 +900,7 @@ func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg // CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32, // CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32, // CHECK32: %[[ARG2:.*]]: !llvm.i32) -func @subview_non_zero_addrspace(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>, 3>, %arg0 : index, %arg1 : index, %arg2 : index) { +func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1], 3>, %arg0 : index, %arg1 : index, %arg2 : index) { // The last "insertvalue" that populates the memref descriptor from the function arguments. // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -943,13 +944,14 @@ func @subview_non_zero_addrspace(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32 %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>, 3> to memref (d0 * s1 + d1 * s2 + s0)>, 3> + memref<64x4xf32, offset: 0, strides: [4, 1], 3> + to memref return } // CHECK-LABEL: func @subview_const_size( // CHECK32-LABEL: func @subview_const_size( -func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) { +func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { // The last "insertvalue" that populates the memref descriptor from the function arguments. // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -996,14 +998,15 @@ func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32 // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> - %1 = subview %0[%arg0, %arg1][][%arg0, %arg1] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<4x2xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>> + %1 = subview %0[%arg0, %arg1][4, 2][%arg0, %arg1] : + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref<4x2xf32, offset: ?, strides: [?, ?]> return } // CHECK-LABEL: func @subview_const_stride( // CHECK32-LABEL: func @subview_const_stride( -func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) { +func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { // The last "insertvalue" that populates the memref descriptor from the function arguments. 
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -1046,14 +1049,15 @@ func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> - %1 = subview %0[%arg0, %arg1][%arg0, %arg1][] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref (d0 * 4 + d1 * 2 + s0)>> + %1 = subview %0[%arg0, %arg1][%arg0, %arg1][1, 2] : + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref return } // CHECK-LABEL: func @subview_const_stride_and_offset( // CHECK32-LABEL: func @subview_const_stride_and_offset( -func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) { +func @subview_const_stride_and_offset(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>) { // The last "insertvalue" that populates the memref descriptor from the function arguments. // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -1092,8 +1096,9 @@ func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> - %1 = subview %0[][][] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>> + %1 = subview %0[0, 8][62, 3][1, 1] : + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref<62x3xf32, offset: 8, strides: [4, 1]> return } diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir index bb9c2728dcb8a1..1be148707458af 100644 --- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir +++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir @@ -7,7 +7,7 @@ func @invalid_memref_cast(%arg0: memref) { %c0 = constant 0 : index // expected-error@+1 {{'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, but got '!llvm<"{ double*, double*, i64, [2 x i64], [2 x i64] }">'}} %5 = memref_cast %arg0 : memref to memref - %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref to memref + %25 = std.subview %5[%c0, %c0][%c1, %c1][1, 1] : memref to memref return } diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir index 3540a101c55bf1..d3b339e82a88fa 100644 --- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir @@ -11,7 +11,7 @@ func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : in // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index // CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> %1 
= load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> return %1 : f32 } @@ -25,7 +25,8 @@ func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : i // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index // CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> + %0 = subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : + memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]> return %1 : f32 } @@ -41,7 +42,8 @@ func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : i // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index // CHECK: store [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : + memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> return } @@ -55,7 +57,8 @@ func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index // CHECK: store [[ARG7]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> + %0 = subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : + memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]> return } diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir index c0d904adb51c98..2e28079b4f15ec 100644 --- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir @@ -28,7 +28,7 @@ func @fold_static_stride_subview // CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[C3]] // CHECK: %[[T9:.*]] = addi %[[ARG2]], %[[T8]] // CHECK store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]] - %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> %2 = sqrt %1 : f32 store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir index 52eddfcd69f85b..5ca6de5023ebfc 100644 --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -103,8 +103,8 @@ func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) { affine.for %arg4 = 0 to %13 step 264 { %18 = dim %0, 0 : memref %20 = std.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref - to memref (d0 * s1 + d1 * s2 + s0)>> - %24 = dim %20, 0 : memref (d0 * s1 + d1 * s2 + s0)>> + to memref + %24 = dim %20, 0 : memref affine.for %arg5 = 0 to %24 step 768 { "foo"() : () -> () } diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir index d5148b3424c4cc..bd6a3e7d7033ba 100644 --- 
a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -23,9 +23,9 @@ func @matmul_f32(%A: memref, %M: index, %N: index, %K: index) { loop.for %arg4 = %c0 to %6 step %c2 { loop.for %arg5 = %c0 to %8 step %c3 { loop.for %arg6 = %c0 to %7 step %c4 { - %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref to memref - %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref to memref - %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref to memref + %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref + %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref + %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref linalg.matmul(%11, %14, %17) : memref, memref, memref } } @@ -88,9 +88,9 @@ func @matmul_f64(%A: memref, %M: index, %N: index, %K: index) { loop.for %arg4 = %c0 to %6 step %c2 { loop.for %arg5 = %c0 to %8 step %c3 { loop.for %arg6 = %c0 to %7 step %c4 { - %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref to memref - %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref to memref - %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref to memref + %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref + %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref + %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref linalg.matmul(%11, %14, %17) : memref, memref, memref } } @@ -153,9 +153,9 @@ func @matmul_i32(%A: memref, %M: index, %N: index, %K: index) { loop.for %arg4 = %c0 to %6 step %c2 { loop.for %arg5 = %c0 to %8 step %c3 { loop.for %arg6 = %c0 to %7 step %c4 { - %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref to memref - %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref to memref - %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref to memref + %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref + %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref + %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref linalg.matmul(%11, %14, %17) : memref, memref, memref } } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index da098a82b1b02b..41172aa22527bf 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -10,15 +10,14 @@ // CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> // CHECK-DAG: #[[BASE_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> -// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)> // CHECK-DAG: #[[BASE_MAP1:map[0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // CHECK-DAG: #[[BASE_MAP2:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 22 + d1)> -// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> -// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)> -// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)> +// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = 
affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)> // CHECK-LABEL: func @func_with_ops(%arg0: f32) { @@ -708,41 +707,56 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %c1 = constant 1 : index %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> - // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref + // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> + // CHECK-SAME: to memref %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> + : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to + memref %2 = alloc()[%arg2] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> - // CHECK: subview %2[%c1] [%arg0] [%c1] : memref<64xf32, #[[BASE_MAP1]]> to memref + // CHECK: subview %2[%c1] [%arg0] [%c1] : + // CHECK-SAME: memref<64xf32, #[[BASE_MAP1]]> + // CHECK-SAME: to memref %3 = subview %2[%c1][%arg0][%c1] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref (d0 * s1 + s0)>> %4 = alloc() : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>> - // CHECK: subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] : memref<64x22xf32, #[[BASE_MAP2]]> to memref + // CHECK: subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] : + // CHECK-SAME: memref<64x22xf32, #[[BASE_MAP2]]> + // CHECK-SAME: to memref %5 = subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0] - : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>> to - memref (d0 * s1 + d1 * s2 + s0)>> + : memref<64x22xf32, offset:0, strides: [22, 1]> to + memref - // CHECK: subview %0[] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]> - %6 = subview %0[][][] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>> + // CHECK: subview %0[0, 2, 0] [4, 4, 4] [1, 1, 1] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> + // CHECK-SAME: to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]> + %6 = subview %0[0, 2, 0][4, 4, 4][1, 1, 1] + : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to + memref<4x4x4xf32, offset:8, strides: [64, 4, 1]> %7 = alloc(%arg1, %arg2) : memref - // CHECK: subview {{%.*}}[] [] [] : memref to memref<4x4xf32, #[[SUBVIEW_MAP4]]> - %8 = subview %7[][][] - : memref to memref<4x4xf32, offset: ?, strides:[?, ?]> + // CHECK: subview {{%.*}}[0, 0] [4, 4] [1, 1] : + // CHECK-SAME: memref + // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP4]]> + %8 = subview %7[0, 0][4, 4][1, 1] + : memref to memref<4x4xf32, offset: ?, strides:[?, 1]> %9 = alloc() : memref<16x4xf32> - // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [] [{{%.*}}, {{%.*}}] : memref<16x4xf32> to memref<4x4xf32, #[[SUBVIEW_MAP4]] - %10 = subview %9[%arg1, %arg1][][%arg2, %arg2] + // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [{{%.*}}, {{%.*}}] : + // CHECK-SAME: memref<16x4xf32> + // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP2]] + %10 = subview %9[%arg1, %arg1][4, 4][%arg2, %arg2] : memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[?, ?]> - // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [] [] : memref<16x4xf32> to memref<4x4xf32, #[[SUBVIEW_MAP5]] - %11 = subview %9[%arg1, %arg2][][] + + // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [2, 2] : + // CHECK-SAME: memref<16x4xf32> + // CHECK-SAME: to memref<4x4xf32, 
#[[SUBVIEW_MAP5]] + %11 = subview %9[%arg1, %arg2][4, 4][2, 2] : memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[8, 2]> + return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 0f9fb3ccada568..b0535047874fd7 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -976,33 +976,22 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { - %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, 2> + %0 = alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> // expected-error@+1 {{different memory spaces}} - %1 = subview %0[][%arg2][] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, 2> to + %1 = subview %0[0, 0, 0][%arg2][1, 1, 1] + : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>> return } // ----- -func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { - %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> - // expected-error@+1 {{is not strided}} - %1 = subview %0[][%arg2][] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0, d1, d2)>> - return -} - -// ----- - func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> // expected-error@+1 {{is not strided}} - %1 = subview %0[][%arg2][] + %1 = subview %0[0, 0, 0][%arg2][1, 1, 1] : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to - memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>> + memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]> return } @@ -1010,8 +999,8 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> - // expected-error@+1 {{expected number of dynamic offsets specified to match the rank of the result type}} - %1 = subview %0[%arg0, %arg1][%arg2][] + // expected-error@+1 {{expected 3 offset values}} + %1 = subview %0[%arg0, %arg1][%arg2][1, 1, 1] : memref<8x16x4xf32> to memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]> return @@ -1021,7 +1010,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> - // expected-error@+1 {{expected result type to have dynamic strides}} + // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>'}} %1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2] : memref<8x16x4xf32> to memref @@ -1030,106 +1019,6 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { // ----- -func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { - %0 = alloc() : memref<8x16x4xf32> - %c0 = constant 0 : index - %c1 = constant 1 : index - // expected-error@+1 {{expected result memref layout map to have dynamic offset}} - %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] - : memref<8x16x4xf32> to - memref - return -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{expected rank of result type to match rank of base type}} - %0 = subview %arg1[%arg0, 
%arg0][][%arg0, %arg0] : memref to memref -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{expected number of dynamic offsets specified to match the rank of the result type}} - %0 = subview %arg1[%arg0][][] : memref to memref<4x4xf32, offset: ?, strides: [4, 1]> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{expected number of dynamic sizes specified to match the rank of the result type}} - %0 = subview %arg1[][%arg0][] : memref to memref -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{expected number of dynamic strides specified to match the rank of the result type}} - %0 = subview %arg1[][][%arg0] : memref to memref -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{invalid to specify dynamic sizes when subview result type is statically shaped and viceversa}} - %0 = subview %arg1[][%arg0, %arg0][] : memref to memref<4x8xf32, offset: ?, strides: [?, ?]> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{invalid to specify dynamic sizes when subview result type is statically shaped and viceversa}} - %0 = subview %arg1[][][] : memref to memref -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32>) { - // expected-error@+1 {{expected result memref layout map to have dynamic offset}} - %0 = subview %arg1[%arg0, %arg0][][] : memref<16x4xf32> to memref<4x2xf32> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: ?, strides: [4, 1]>) { - // expected-error@+1 {{expected result memref layout map to have dynamic offset}} - %0 = subview %arg1[][][] : memref<16x4xf32, offset: ?, strides: [4, 1]> to memref<4x2xf32> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: 8, strides:[?, 1]>) { - // expected-error@+1 {{expected result memref layout map to have dynamic offset}} - %0 = subview %arg1[][][] : memref<16x4xf32, offset: 8, strides:[?, 1]> to memref<4x2xf32> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32>) { - // expected-error@+1 {{expected result type to have dynamic strides}} - %0 = subview %arg1[][][%arg0, %arg0] : memref<16x4xf32> to memref<4x2xf32> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: 0, strides:[?, ?]>) { - // expected-error@+1 {{expected result type to have dynamic stride along a dimension if the base memref type has dynamic stride along that dimension}} - %0 = subview %arg1[][][] : memref<16x4xf32, offset: 0, strides:[?, ?]> to memref<4x2xf32, offset:?, strides:[2, 1]> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - %c0 = constant 0 : index - %c1 = constant 1 : index - // expected-error@+1 {{expected shape of result type to be fully dynamic when sizes are specified}} - %0 = subview %arg1[%c0, %c0, %c0][%c1, %arg0, %c1][%c1, %c1, %c1] : memref to memref - return -} - -// ----- - func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]> diff --git 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index e4090ccd6073b7..dfcf086c73de10 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -427,7 +427,7 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref, memref, memref
 }
 
-#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 #map2 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s2 + d1 * s1 + d2 + s0)>
 
 // CHECK-LABEL: func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index,
@@ -684,106 +684,138 @@ func @view(%arg0 : index) -> (f32, f32, f32, f32) {
 // CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
 // CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 128 + s0 + d1 * 28 + d2 * 11)>
 // CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2 + 79)>
-// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
-// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)>
+// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2 * 2)>
+// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
+// CHECK-DAG: #[[SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)>
+
 // CHECK-LABEL: func @subview
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
   // CHECK: %[[C0:.*]] = constant 0 : index
   %c0 = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK-NOT: constant 1 : index
   %c1 = constant 1 : index
-  // CHECK: %[[C2:.*]] = constant 2 : index
+  // CHECK-NOT: constant 2 : index
   %c2 = constant 2 : index
+  // Folded but reappears after subview folding into dim.
   // CHECK: %[[C7:.*]] = constant 7 : index
   %c7 = constant 7 : index
+  // Folded but reappears after subview folding into dim.
   // CHECK: %[[C11:.*]] = constant 11 : index
   %c11 = constant 11 : index
+  // CHECK-NOT: constant 15 : index
   %c15 = constant 15 : index
 
   // CHECK: %[[ALLOC0:.*]] = alloc()
-  %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
+  %0 = alloc() : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]>
 
   // Test: subview with constant base memref and constant operands is folded.
   // Note that the subview uses the base memrefs layout map because it used
   // zero offset and unit stride arguments.
-  // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[BASE_MAP0]]>
+  // CHECK: subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [1, 1, 1] :
+  // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
+  // CHECK-SAME: to memref<7x11x2xf32, #[[BASE_MAP0]]>
   %1 = subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
-    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
-      memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
-  %v0 = load %1[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
-
-  // Test: subview with one dynamic operand should not be folded.
-  // CHECK: subview %[[ALLOC0]][%[[C0]], %[[ARG0]], %[[C0]]] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
+    : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+      memref
+  %v0 = load %1[%c0, %c0, %c0] : memref
+
+  // Test: subview with one dynamic operand can also be folded.
+  // CHECK: subview %[[ALLOC0]][0, %[[ARG0]], 0] [7, 11, 15] [1, 1, 1] :
+  // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
+  // CHECK-SAME: to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
   %2 = subview %0[%c0, %arg0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1]
-    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
-      memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
-  store %v0, %2[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+    : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+      memref
+  store %v0, %2[%c0, %c0, %c0] : memref
 
   // CHECK: %[[ALLOC1:.*]] = alloc(%[[ARG0]])
-  %3 = alloc(%arg0) : memref (d0 * 64 + d1 * 4 + d2)>>
+  %3 = alloc(%arg0) : memref
 
   // Test: subview with constant operands but dynamic base memref is folded as long as the strides and offset of the base memref are static.
-  // CHECK: subview %[[ALLOC1]][] [] [] : memref to memref<7x11x15xf32, #[[BASE_MAP0]]>
+  // CHECK: subview %[[ALLOC1]][0, 0, 0] [7, 11, 15] [1, 1, 1] :
+  // CHECK-SAME: memref
+  // CHECK-SAME: to memref<7x11x15xf32, #[[BASE_MAP0]]>
   %4 = subview %3[%c0, %c0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1]
-    : memref (d0 * 64 + d1 * 4 + d2)>> to
-      memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
-  store %v0, %4[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+    : memref to
+      memref
+  store %v0, %4[%c0, %c0, %c0] : memref
 
   // Test: subview offset operands are folded correctly w.r.t. base strides.
-  // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
+  // CHECK: subview %[[ALLOC0]][1, 2, 7] [7, 11, 2] [1, 1, 1] :
+  // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
+  // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
   %5 = subview %0[%c1, %c2, %c7] [%c7, %c11, %c2] [%c1, %c1, %c1]
-    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
-      memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
-  store %v0, %5[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+    : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+      memref
+  store %v0, %5[%c0, %c0, %c0] : memref
 
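The folded offset for %5 can be checked by hand: offsets [1, 2, 7] applied against base strides [64, 4, 1] with base offset 0 give 1*64 + 2*4 + 7*1 = 79, the constant that also appears in #[[SUBVIEW_MAP5]] for the later test that folds the same offsets against dynamic strides. A tiny stand-alone check of that arithmetic (illustrative only, not test code):

```c++
#include <cassert>
#include <cstdint>

int main() {
  // Offsets [1, 2, 7] folded against base strides [64, 4, 1], base offset 0.
  const int64_t baseStrides[3] = {64, 4, 1};
  const int64_t offsets[3] = {1, 2, 7};
  int64_t resultOffset = 0;
  for (int i = 0; i < 3; ++i)
    resultOffset += offsets[i] * baseStrides[i];
  assert(resultOffset == 79); // the "+ 79" term in #[[SUBVIEW_MAP5]]
  return 0;
}
```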
   // Test: subview stride operands are folded correctly w.r.t. base strides.
-  // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]>
+  // CHECK: subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [2, 7, 11] :
+  // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
+  // CHECK-SAME: to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]>
   %6 = subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c2, %c7, %c11]
-    : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
-      memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
-  store %v0, %6[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+    : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+      memref
+  store %v0, %6[%c0, %c0, %c0] : memref
 
   // Test: subview shape are folded, but offsets and strides are not even if base memref is static
-  // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
-  %10 = subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to memref
-  store %v0, %10[%arg1, %arg1, %arg1] : memref
+  // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
+  // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
+  // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
+  %10 = subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] :
+    memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
+    memref
+  store %v0, %10[%arg1, %arg1, %arg1] :
+    memref
 
   // Test: subview strides are folded, but offsets and shape are not even if base memref is static
-  // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref to memref
-  store %v0, %11[%arg0, %arg0, %arg0] : memref
+  // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 7, 11] :
+  // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
+  // CHECK-SAME: memref to
+    memref
+  store %v0, %11[%arg0, %arg0, %arg0] :
+    memref
 
   // Test: subview offsets are folded, but strides and shape are not even if base memref is static
-  // CHECK: subview %[[ALLOC0]][] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref to memref
-  store %v0, %13[%arg1, %arg1, %arg1] : memref
+  // CHECK: subview %[[ALLOC0]][1, 2, 7] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] :
+  // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
+  // CHECK-SAME: memref to
+    memref
+  store %v0, %13[%arg1, %arg1, %arg1] :
+    memref
 
   // CHECK: %[[ALLOC2:.*]] = alloc(%[[ARG0]], %[[ARG0]], %[[ARG1]])
   %14 = alloc(%arg0, %arg0, %arg1) : memref
 
   // Test: subview shape are folded, even if base memref is not static
-  // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
-  %15 = subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref to memref
+  // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
+  // CHECK-SAME: memref to
+  // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
+  %15 = subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] :
+    memref to
+    memref
   store %v0, %15[%arg1, %arg1, %arg1] : memref
 
-  // TEST: subview strides are not folded when the base memref is not static
-  // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[C2]], %[[C2]], %[[C2]]] : memref to memref to memref
+  // TEST: subview strides are folded; only the most minor stride becomes static in the result type.
+  // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 2, 2] :
+  // CHECK-SAME: memref to
+  // CHECK-SAME: memref to
+    memref
   store %v0, %16[%arg0, %arg0, %arg0] : memref
 
-  // TEST: subview offsets are not folded when the base memref is not static
-  // CHECK: subview %[[ALLOC2]][%[[C1]], %[[C1]], %[[C1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref to memref to memref
+  // TEST: subview offsets are folded, but the offset in the result type remains dynamic because the base memref is not static.
+  // CHECK: subview %[[ALLOC2]][1, 1, 1] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
+  // CHECK-SAME: memref to
+  // CHECK-SAME: memref to
+    memref
   store %v0, %17[%arg0, %arg0, %arg0] : memref
 
   // CHECK: %[[ALLOC3:.*]] = alloc() : memref<12x4xf32>
@@ -791,20 +823,26 @@ func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
   %c4 = constant 4 : index
 
   // TEST: subview strides are maintained when sizes are folded
-  // CHECK: subview %[[ALLOC3]][%arg1, %arg1] [] [] : memref<12x4xf32> to memref<2x4xf32, #[[SUBVIEW_MAP6]]>
-  %19 = subview %18[%arg1, %arg1] [%c2, %c4] [] : memref<12x4xf32> to memref
+  // CHECK: subview %[[ALLOC3]][%arg1, %arg1] [2, 4] [1, 1] :
+  // CHECK-SAME: memref<12x4xf32> to
+  // CHECK-SAME: memref<2x4xf32, #[[SUBVIEW_MAP7]]>
+  %19 = subview %18[%arg1, %arg1] [%c2, %c4] [1, 1] :
+    memref<12x4xf32> to
+    memref
   store %v0, %19[%arg1, %arg1] : memref
 
   // TEST: subview strides and sizes are maintained when offsets are folded
-  // CHECK: subview %[[ALLOC3]][] [] [] : memref<12x4xf32> to memref<12x4xf32, #[[SUBVIEW_MAP7]]>
-  %20 = subview %18[%c2, %c4] [] [] : memref<12x4xf32> to memref<12x4xf32, offset: ?, strides:[4, 1]>
+  // CHECK: subview %[[ALLOC3]][2, 4] [12, 4] [1, 1] :
+  // CHECK-SAME: memref<12x4xf32> to
+  // CHECK-SAME: memref<12x4xf32, #[[SUBVIEW_MAP8]]>
+  %20 = subview %18[%c2, %c4] [12, 4] [1, 1] :
+    memref<12x4xf32> to
+    memref<12x4xf32, offset: ?, strides:[4, 1]>
   store %v0, %20[%arg1, %arg1] : memref<12x4xf32, offset: ?, strides:[4, 1]>
 
   // Test: dim on subview is rewritten to size operand.
-  %7 = dim %4, 0 : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
-  %8 = dim %4, 1 : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+  %7 = dim %4, 0 : memref
+  %8 = dim %4, 1 : memref
 
   // CHECK: return %[[C7]], %[[C11]]
   return %7, %8 : index, index
@@ -891,15 +929,3 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
   // CHECK: return %[[ARG]]
   return %res : tensor<4x5xi32>
 }
-
-// -----
-
-// CHECK-LABEL: func @memref_cast_folding_subview
-func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref) {
-  %0 = memref_cast %arg0 : memref<4x5xf32> to memref
-  // CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
-  %1 = subview %0[][%i,%i][]: memref to memref
-  // CHECK-NEXT: return %{{.*}}
-  return %1: memref
-}
-
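All of the canonicalize.mlir updates above exercise the same rewrite: offset, size and stride operands that are produced by constants get moved into the new static_offsets/static_sizes/static_strides attributes, and only genuinely dynamic values stay as SSA operands. The sketch below is a stand-alone model of that promotion step, not the actual pattern in Ops.cpp; the Operand stand-in and the -1 sentinel are assumptions made for illustration.

```c++
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for an SSA operand whose defining op may or may not be a known
// constant index (e.g. %c7 -> 7, %arg0 -> nullopt).
using Operand = std::optional<int64_t>;
constexpr int64_t kDynamic = -1;

struct CanonicalizedList {
  std::vector<int64_t> staticValues;    // one entry per dimension
  std::vector<Operand> dynamicOperands; // only the operands that must remain
};

// Promote every constant operand into the static attribute array and keep the
// rest as operands, as the subview canonicalization does for each of the
// offset/size/stride lists.
CanonicalizedList promoteConstants(const std::vector<Operand> &operands) {
  CanonicalizedList result;
  for (const Operand &op : operands) {
    if (op.has_value()) {
      result.staticValues.push_back(*op);
    } else {
      result.staticValues.push_back(kDynamic);
      result.dynamicOperands.push_back(op);
    }
  }
  return result;
}
```

Applied to %1 in the test above, the offsets [%c0, %c0, %c0], sizes [%c7, %c11, %c2] and strides [%c1, %c1, %c1] are all constant, so everything moves into the attributes and the CHECK lines expect subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [1, 1, 1] with no remaining index operands.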