diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index b7c02a205f5e6..dbb803bcb1e1e 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -839,8 +839,8 @@ def TileOp : Op:$dynamic_sizes, - DefaultValuedAttr:$static_sizes, - DefaultValuedAttr:$interchange); + DefaultValuedOptionalAttr:$static_sizes, + DefaultValuedOptionalAttr:$interchange); let results = (outs PDL_Operation:$tiled_linalg_op, Variadic:$loops); @@ -917,8 +917,8 @@ def TileToForeachThreadOp : let arguments = (ins PDL_Operation:$target, Variadic:$num_threads, Variadic:$tile_sizes, - DefaultValuedAttr:$static_num_threads, - DefaultValuedAttr:$static_tile_sizes, + DefaultValuedOptionalAttr:$static_num_threads, + DefaultValuedOptionalAttr:$static_tile_sizes, OptionalAttr:$mapping); let results = (outs PDL_Operation:$foreach_thread_op, PDL_Operation:$tiled_op); @@ -1009,8 +1009,8 @@ def TileToScfForOp : Op:$dynamic_sizes, - DefaultValuedAttr:$static_sizes, - DefaultValuedAttr:$interchange); + DefaultValuedOptionalAttr:$static_sizes, + DefaultValuedOptionalAttr:$interchange); let results = (outs PDL_Operation:$tiled_linalg_op, Variadic:$loops); diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index ccf01be858680..4a567b40e2e5e 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1260,9 +1260,9 @@ def MemRef_ReinterpretCastOp Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides); + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides); let results = (outs AnyMemRef:$result); let assemblyFormat = [{ @@ -1476,7 +1476,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ or copies. A reassociation is defined as a grouping of dimensions and is represented - with an array of I64ArrayAttr attributes. + with an array of DenseI64ArrayAttr attributes. Example: @@ -1563,7 +1563,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [ type. A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. + represented with an array of DenseI64ArrayAttr attribute. Note: Only the dimensions within a reassociation group must be contiguous. The remaining dimensions may be non-contiguous. @@ -1855,9 +1855,9 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides); + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides); let results = (outs AnyMemRef:$result); let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 14060075b2340..0af5811638a85 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -326,9 +326,9 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides ); let results = (outs AnyRankedTensor:$result); @@ -807,9 +807,9 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides ); let results = (outs AnyRankedTensor:$result); @@ -1013,7 +1013,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> { rank whose sizes are a reassociation of the original `src`. A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. + represented with an array of DenseI64ArrayAttr attribute. The verification rule is that the reassociation maps are applied to the result tensor with the higher rank to obtain the operand tensor with the @@ -1065,7 +1065,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { rank whose sizes are a reassociation of the original `src`. A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. + represented with an array of DenseI64ArrayAttr attribute. The verification rule is that the reassociation maps are applied to the operand tensor with the higher rank to obtain the result tensor with the @@ -1206,8 +1206,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [ AnyTensor:$source, Variadic:$low, Variadic:$high, - I64ArrayAttr:$static_low, - I64ArrayAttr:$static_high, + DenseI64ArrayAttr:$static_low, + DenseI64ArrayAttr:$static_high, UnitAttr:$nofold); let regions = (region SizedRegion<1>:$region); @@ -1254,16 +1254,17 @@ def Tensor_PadOp : Tensor_Op<"pad", [ // Return a vector of all the static or dynamic values (low/high padding) of // the op. - inline SmallVector getMixedPadImpl(ArrayAttr staticAttrs, + inline SmallVector getMixedPadImpl(ArrayRef staticAttrs, ValueRange values) { + Builder builder(*this); SmallVector res; unsigned numDynamic = 0; unsigned count = staticAttrs.size(); for (unsigned idx = 0; idx < count; ++idx) { - if (ShapedType::isDynamic(staticAttrs[idx].cast().getInt())) + if (ShapedType::isDynamic(staticAttrs[idx])) res.push_back(values[numDynamic++]); else - res.push_back(staticAttrs[idx]); + res.push_back(builder.getI64IntegerAttr(staticAttrs[idx])); } return res; } @@ -1400,9 +1401,9 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [ Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides ); let assemblyFormat = [{ $source `into` $dest `` @@ -1748,7 +1749,7 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [ DefaultValuedOptionalAttr:$outer_dims_perm, DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, - I64ArrayAttr:$static_inner_tiles); + DenseI64ArrayAttr:$static_inner_tiles); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $source @@ -1803,7 +1804,7 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> { DefaultValuedOptionalAttr:$outer_dims_perm, DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, - I64ArrayAttr:$static_inner_tiles); + DenseI64ArrayAttr:$static_inner_tiles); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $source diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index f09cf88afaab3..e72f7095b6da0 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -87,6 +87,18 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); SmallVector getAsValues(OpBuilder &b, Location loc, ArrayRef valueOrAttrVec); +/// Return a vector of OpFoldResults with the same size a staticValues, but all +/// elements for which ShapedType::isDynamic is true, will be replaced by +/// dynamicValues. +SmallVector getMixedValues(ArrayRef staticValues, + ValueRange dynamicValues, Builder &b); + +/// Decompose a vector of mixed static or dynamic values into the corresponding +/// pair of arrays. This is the inverse function of `getMixedValues`. +std::pair> +decomposeMixedValues(Builder &b, + const SmallVectorImpl &mixedValues); + } // namespace mlir #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index 700546d082e6f..f950933b23c7a 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -21,18 +21,6 @@ namespace mlir { -/// Return a vector of OpFoldResults with the same size a staticValues, but all -/// elements for which ShapedType::isDynamic is true, will be replaced by -/// dynamicValues. -SmallVector getMixedValues(ArrayAttr staticValues, - ValueRange dynamicValues); - -/// Decompose a vector of mixed static or dynamic values into the corresponding -/// pair of arrays. This is the inverse function of `getMixedValues`. -std::pair> -decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues); - class OffsetSizeAndStrideOpInterface; namespace detail { @@ -61,7 +49,7 @@ namespace mlir { /// idiomatic printing of mixed value and integer attributes in a list. E.g. /// `[%arg0, 7, 42, %arg42]`. void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, - OperandRange values, ArrayAttr integers); + OperandRange values, ArrayRef integers); /// Pasrer hook for custom directive in assemblyFormat. /// @@ -79,13 +67,14 @@ void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl &values, - ArrayAttr &integers); + DenseI64ArrayAttr &integers); /// Verify that a the `values` has as many elements as the number of entries in /// `attr` for which `isDynamic` evaluates to true. LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, - ArrayAttr attr, ValueRange values); + ArrayRef attr, + ValueRange values); } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td index aca01262134c4..b5870af8c7936 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -124,7 +124,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*desc=*/[{ Return the static offset attributes. }], - /*retTy=*/"::mlir::ArrayAttr", + /*retTy=*/"::llvm::ArrayRef", /*methodName=*/"static_offsets", /*args=*/(ins), /*methodBody=*/"", @@ -136,7 +136,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*desc=*/[{ Return the static size attributes. }], - /*retTy=*/"::mlir::ArrayAttr", + /*retTy=*/"::llvm::ArrayRef", /*methodName=*/"static_sizes", /*args=*/(ins), /*methodBody=*/"", @@ -148,7 +148,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*desc=*/[{ Return the dynamic stride attributes. }], - /*retTy=*/"::mlir::ArrayAttr", + /*retTy=*/"::llvm::ArrayRef", /*methodName=*/"static_strides", /*args=*/(ins), /*methodBody=*/"", @@ -165,8 +165,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ + Builder b($_op->getContext()); return ::mlir::getMixedValues($_op.getStaticOffsets(), - $_op.getOffsets()); + $_op.getOffsets(), b); }] >, InterfaceMethod< @@ -178,7 +179,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes()); + Builder b($_op->getContext()); + return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes(), b); }] >, InterfaceMethod< @@ -190,8 +192,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ + Builder b($_op->getContext()); return ::mlir::getMixedValues($_op.getStaticStrides(), - $_op.getStrides()); + $_op.getStrides(), b); }] >, @@ -202,9 +205,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - ::llvm::APInt v = *(static_offsets() - .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return ::mlir::ShapedType::isDynamic(v.getSExtValue()); + return ::mlir::ShapedType::isDynamic(static_offsets()[idx]); }] >, InterfaceMethod< @@ -214,9 +215,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - ::llvm::APInt v = *(static_sizes() - .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return ::mlir::ShapedType::isDynamic(v.getSExtValue()); + return ::mlir::ShapedType::isDynamic(static_sizes()[idx]); }] >, InterfaceMethod< @@ -226,9 +225,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - ::llvm::APInt v = *(static_strides() - .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return ::mlir::ShapedType::isDynamic(v.getSExtValue()); + return ::mlir::ShapedType::isDynamic(static_strides()[idx]); }] >, InterfaceMethod< @@ -241,9 +238,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*methodBody=*/"", /*defaultImplementation=*/[{ assert(!$_op.isDynamicOffset(idx) && "expected static offset"); - ::llvm::APInt v = *(static_offsets(). - template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return v.getSExtValue(); + return static_offsets()[idx]; }] >, InterfaceMethod< @@ -256,9 +251,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*methodBody=*/"", /*defaultImplementation=*/[{ assert(!$_op.isDynamicSize(idx) && "expected static size"); - ::llvm::APInt v = *(static_sizes(). - template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return v.getSExtValue(); + return static_sizes()[idx]; }] >, InterfaceMethod< @@ -271,9 +264,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*methodBody=*/"", /*defaultImplementation=*/[{ assert(!$_op.isDynamicStride(idx) && "expected static stride"); - ::llvm::APInt v = *(static_strides(). - template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return v.getSExtValue(); + return static_strides()[idx]; }] >, @@ -289,7 +280,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*defaultImplementation=*/[{ assert($_op.isDynamicOffset(idx) && "expected dynamic offset"); auto numDynamic = getNumDynamicEntriesUpToIdx( - static_offsets().template cast<::mlir::ArrayAttr>(), + static_offsets(), ::mlir::ShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + numDynamic; @@ -307,7 +298,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*defaultImplementation=*/[{ assert($_op.isDynamicSize(idx) && "expected dynamic size"); auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes().template cast<::mlir::ArrayAttr>(), ::mlir::ShapedType::isDynamic, idx); + static_sizes(), ::mlir::ShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + offsets().size() + numDynamic; }] @@ -324,7 +315,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*defaultImplementation=*/[{ assert($_op.isDynamicStride(idx) && "expected dynamic stride"); auto numDynamic = getNumDynamicEntriesUpToIdx( - static_strides().template cast<::mlir::ArrayAttr>(), + static_strides(), ::mlir::ShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + @@ -333,20 +324,20 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface >, InterfaceMethod< /*desc=*/[{ - Helper method to compute the number of dynamic entries of `attr`, up to + Helper method to compute the number of dynamic entries of `staticVals`, up to `idx` using `isDynamic` to determine whether an entry is dynamic. }], /*retTy=*/"unsigned", /*methodName=*/"getNumDynamicEntriesUpToIdx", - /*args=*/(ins "::mlir::ArrayAttr":$attr, + /*args=*/(ins "::llvm::ArrayRef":$staticVals, "::llvm::function_ref":$isDynamic, "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ return std::count_if( - attr.getValue().begin(), attr.getValue().begin() + idx, - [&](::mlir::Attribute attr) { - return isDynamic(attr.cast<::mlir::IntegerAttr>().getInt()); + staticVals.begin(), staticVals.begin() + idx, + [&](int64_t val) { + return isDynamic(val); }); }] >, diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 42d2d9a1b3097..7da3c3693bb69 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1705,10 +1705,8 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern { auto viewMemRefType = subViewOp.getType(); auto inferredType = memref::SubViewOp::inferResultType( - subViewOp.getSourceType(), - extractFromI64ArrayAttr(subViewOp.getStaticOffsets()), - extractFromI64ArrayAttr(subViewOp.getStaticSizes()), - extractFromI64ArrayAttr(subViewOp.getStaticStrides())) + subViewOp.getSourceType(), subViewOp.getStaticOffsets(), + subViewOp.getStaticSizes(), subViewOp.getStaticStrides()) .cast(); auto targetElementTy = typeConverter->convertType(viewMemRefType.getElementType()); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 92bb30eefa5de..cb2eea2960e3d 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -30,8 +30,8 @@ class SliceOpConverter : public OpRewritePattern { PatternRewriter &rewriter) const final { Location loc = sliceOp.getLoc(); Value input = sliceOp.getInput(); - SmallVector strides, sizes; - auto starts = sliceOp.getStart(); + SmallVector strides, sizes, starts; + starts = extractFromI64ArrayAttr(sliceOp.getStart()); strides.resize(sliceOp.getType().template cast().getRank(), 1); SmallVector dynSizes; @@ -44,15 +44,15 @@ class SliceOpConverter : public OpRewritePattern { auto dim = rewriter.create(loc, input, index); auto offset = rewriter.create( - loc, - rewriter.getIndexAttr(starts[index].cast().getInt())); + loc, rewriter.getIndexAttr(starts[index])); dynSizes.push_back(rewriter.create(loc, dim, offset)); } auto newSliceOp = rewriter.create( sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, - ValueRange({}), starts, rewriter.getI64ArrayAttr(sizes), - rewriter.getI64ArrayAttr(strides)); + ValueRange({}), rewriter.getDenseI64ArrayAttr(starts), + rewriter.getDenseI64ArrayAttr(sizes), + rewriter.getDenseI64ArrayAttr(strides)); rewriter.replaceOp(sliceOp, newSliceOp.getResult()); return success(); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index f02ccfae68934..e6123a4f17749 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -40,16 +40,6 @@ static SmallVector extractUIntArray(ArrayAttr attr) { return result; } -/// Extracts a vector of int64_t from an array attribute. Asserts if the -/// attribute contains values other than integers. -static SmallVector extractI64Array(ArrayAttr attr) { - SmallVector result; - result.reserve(attr.size()); - for (APInt value : attr.getAsValueRange()) - result.push_back(value.getSExtValue()); - return result; -} - namespace { /// A simple pattern rewriter that implements no special logic. class SimpleRewriter : public PatternRewriter { @@ -1205,7 +1195,7 @@ transform::TileReductionUsingForeachThreadOp::applyToOne( DiagnosedSilenceableFailure transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { - SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); + ArrayRef tileSizes = getStaticSizes(); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; @@ -1270,7 +1260,7 @@ transform::TileOp::apply(TransformResults &transformResults, }); } - tilingOptions.setInterchange(extractI64Array(getInterchange())); + tilingOptions.setInterchange(getInterchange()); SimpleRewriter rewriter(linalgOp.getContext()); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(linalgOp.getOperation()), @@ -1298,7 +1288,7 @@ transform::TileOp::apply(TransformResults &transformResults, SmallVector transform::TileOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); - SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); + ArrayRef tileSizes = getStaticSizes(); SmallVector results; results.reserve(tileSizes.size()); unsigned dynamicPos = 0; @@ -1313,22 +1303,51 @@ SmallVector transform::TileOp::getMixedSizes() { return results; } +// We want to parse `DenseI64ArrayAttr` using the short form without the +// `array` prefix to be consistent in the IR with `parseDynamicIndexList`. +ParseResult parseOptionalInterchange(OpAsmParser &parser, + OperationState &result) { + if (succeeded(parser.parseOptionalLBrace())) { + if (failed(parser.parseKeyword("interchange"))) + return parser.emitError(parser.getNameLoc()) << "expect `interchange`"; + if (failed(parser.parseEqual())) + return parser.emitError(parser.getNameLoc()) << "expect `=`"; + result.addAttribute("interchange", + DenseI64ArrayAttr::parse(parser, Type{})); + if (failed(parser.parseRBrace())) + return parser.emitError(parser.getNameLoc()) << "expect `}`"; + } + return success(); +} + +void printOptionalInterchange(OpAsmPrinter &p, + ArrayRef interchangeVals) { + if (!interchangeVals.empty()) { + p << " {interchange = ["; + llvm::interleaveComma(interchangeVals, p, + [&](int64_t integer) { p << integer; }); + p << "]}"; + } +} + ParseResult transform::TileOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; - ArrayAttr staticSizes; + DenseI64ArrayAttr staticSizes; auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parseDynamicIndexList(parser, dynamicSizes, staticSizes) || - parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || - parser.parseOptionalAttrDict(result.attributes)) + parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands)) return ParseResult::failure(); + // Parse optional interchange. + if (failed(parseOptionalInterchange(parser, result))) + return ParseResult::failure(); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); size_t numExpectedLoops = - staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); + staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); return success(); } @@ -1336,7 +1355,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser, void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); - p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); + printOptionalInterchange(p, getInterchange()); } void transform::TileOp::getEffects( @@ -1379,13 +1398,13 @@ void transform::TileToForeachThreadOp::build( // bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); - auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); + auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*num_threads=*/ValueRange{}, /*tile_sizes=*/dynamicTileSizes, - /*static_num_threads=*/builder.getI64ArrayAttr({}), + /*static_num_threads=*/builder.getDenseI64ArrayAttr({}), /*static_tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); } @@ -1414,14 +1433,14 @@ void transform::TileToForeachThreadOp::build( // bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); - auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); + auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*num_threads=*/dynamicNumThreads, /*tile_sizes=*/ValueRange{}, /*static_num_threads=*/staticNumThreadsAttr, - /*static_tile_sizes=*/builder.getI64ArrayAttr({}), + /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}), /*mapping=*/mapping); } @@ -1547,11 +1566,13 @@ void transform::TileToForeachThreadOp::getEffects( } SmallVector TileToForeachThreadOp::getMixedNumThreads() { - return getMixedValues(getStaticNumThreads(), getNumThreads()); + Builder b(getContext()); + return getMixedValues(getStaticNumThreads(), getNumThreads(), b); } SmallVector TileToForeachThreadOp::getMixedTileSizes() { - return getMixedValues(getStaticTileSizes(), getTileSizes()); + Builder b(getContext()); + return getMixedValues(getStaticTileSizes(), getTileSizes(), b); } LogicalResult TileToForeachThreadOp::verify() { @@ -1567,7 +1588,7 @@ LogicalResult TileToForeachThreadOp::verify() { DiagnosedSilenceableFailure transform::TileToScfForOp::apply(TransformResults &transformResults, TransformState &state) { - SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); + ArrayRef tileSizes = getStaticSizes(); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; @@ -1632,7 +1653,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults, }); } - tilingOptions.setInterchange(extractI64Array(getInterchange())); + tilingOptions.setInterchange(getInterchange()); SimpleRewriter rewriter(tilingInterfaceOp.getContext()); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); @@ -1655,7 +1676,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults, SmallVector transform::TileToScfForOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); - SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); + ArrayRef tileSizes = getStaticSizes(); SmallVector results; results.reserve(tileSizes.size()); unsigned dynamicPos = 0; @@ -1674,18 +1695,20 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; - ArrayAttr staticSizes; + DenseI64ArrayAttr staticSizes; auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parseDynamicIndexList(parser, dynamicSizes, staticSizes) || - parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || - parser.parseOptionalAttrDict(result.attributes)) + parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands)) return ParseResult::failure(); + // Parse optional interchange. + if (failed(parseOptionalInterchange(parser, result))) + return ParseResult::failure(); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); size_t numExpectedLoops = - staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); + staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); return success(); } @@ -1693,7 +1716,7 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser, void TileToScfForOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); - p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); + printOptionalInterchange(p, getInterchange()); } void transform::TileToScfForOp::getEffects( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 2b84860c5e735..91e0ae41f3914 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -348,7 +348,7 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, SmallVector outputExprs; for (unsigned i = 0; i < resultShapedType.getRank(); ++i) { outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) + - padOp.getStaticLow()[i].cast().getInt()); + padOp.getStaticLow()[i]); } SmallVector transferMaps = { diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 503c8aed2709d..4ba4050bb45e6 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1776,8 +1776,9 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamic); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -1823,8 +1824,8 @@ LogicalResult ReinterpretCastOp::verify() { << srcType << " and result memref type " << resultType; // Match sizes in result memref type and in static_sizes attribute. - for (auto &en : llvm::enumerate(llvm::zip( - resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) { + for (auto &en : + llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) { int64_t resultSize = std::get<0>(en.value()); int64_t expectedSize = std::get<1>(en.value()); if (!ShapedType::isDynamic(resultSize) && @@ -1844,7 +1845,7 @@ LogicalResult ReinterpretCastOp::verify() { << resultType; // Match offset in result memref type and in static_offsets attribute. - int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front(); + int64_t expectedOffset = getStaticOffsets().front(); if (!ShapedType::isDynamic(resultOffset) && !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset) @@ -1852,8 +1853,8 @@ LogicalResult ReinterpretCastOp::verify() { << resultOffset << " instead of " << expectedOffset; // Match strides in result memref type and in static_strides attribute. - for (auto &en : llvm::enumerate(llvm::zip( - resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) { + for (auto &en : + llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) { int64_t resultStride = std::get<0>(en.value()); int64_t expectedStride = std::get<1>(en.value()); if (!ShapedType::isDynamic(resultStride) && @@ -2665,8 +2666,9 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, .cast(); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -2831,9 +2833,7 @@ LogicalResult SubViewOp::verify() { // Verify result type against inferred type. auto expectedType = SubViewOp::inferResultType( - baseType, extractFromI64ArrayAttr(getStaticOffsets()), - extractFromI64ArrayAttr(getStaticSizes()), - extractFromI64ArrayAttr(getStaticStrides())); + baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()); auto result = isRankReducedMemRefType(expectedType.cast(), subViewType, getMixedSizes()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index faa00e2c97811..fae68a0a349e8 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -45,9 +45,8 @@ static void replaceUsesAndPropagateType(Operation *oldOp, Value val, builder.setInsertionPoint(subviewUse); Type newType = memref::SubViewOp::inferRankReducedResultType( subviewUse.getType().getShape(), val.getType().cast(), - extractFromI64ArrayAttr(subviewUse.getStaticOffsets()), - extractFromI64ArrayAttr(subviewUse.getStaticSizes()), - extractFromI64ArrayAttr(subviewUse.getStaticStrides())); + subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), + subviewUse.getStaticStrides()); Value newSubview = builder.create( subviewUse->getLoc(), newType.cast(), val, subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 23af46c6d7912..f279876d19541 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -337,8 +337,7 @@ struct TensorCastExtractSlice : public OpRewritePattern { SmallVector sizes = extractOperand.getMixedSizes(); auto dimMask = computeRankReductionMask( - extractFromI64ArrayAttr(extractOperand.getStaticSizes()), - extractOperand.getType().getShape()); + extractOperand.getStaticSizes(), extractOperand.getType().getShape()); size_t dimIndex = 0; for (size_t i = 0, e = sizes.size(); i < e; i++) { if (dimMask && dimMask->count(i)) @@ -1713,8 +1712,9 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result, .cast(); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -1949,13 +1949,13 @@ class ConstantOpExtractSliceFolder final return failure(); // Check if there are any dynamic parts, which are not supported. - auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets()); + auto offsets = op.getStaticOffsets(); if (llvm::is_contained(offsets, ShapedType::kDynamic)) return failure(); - auto sizes = extractFromI64ArrayAttr(op.getStaticSizes()); + auto sizes = op.getStaticSizes(); if (llvm::is_contained(sizes, ShapedType::kDynamic)) return failure(); - auto strides = extractFromI64ArrayAttr(op.getStaticStrides()); + auto strides = op.getStaticStrides(); if (llvm::is_contained(strides, ShapedType::kDynamic)) return failure(); @@ -2124,8 +2124,9 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamic); build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -2153,17 +2154,14 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, /// Rank-reducing type verification for both InsertSliceOp and /// ParallelInsertSliceOp. -static SliceVerificationResult -verifyInsertSliceOp(ShapedType srcType, ShapedType dstType, - ArrayAttr staticOffsets, ArrayAttr staticSizes, - ArrayAttr staticStrides, - ShapedType *expectedType = nullptr) { +static SliceVerificationResult verifyInsertSliceOp( + ShapedType srcType, ShapedType dstType, ArrayRef staticOffsets, + ArrayRef staticSizes, ArrayRef staticStrides, + ShapedType *expectedType = nullptr) { // insert_slice is the inverse of extract_slice, use the same type // inference. RankedTensorType expected = ExtractSliceOp::inferResultType( - dstType, extractFromI64ArrayAttr(staticOffsets), - extractFromI64ArrayAttr(staticSizes), - extractFromI64ArrayAttr(staticStrides)); + dstType, staticOffsets, staticSizes, staticStrides); if (expectedType) *expectedType = expected; return isRankReducedType(expected, srcType); @@ -2482,9 +2480,8 @@ ParseResult parseInferType(OpAsmParser &parser, LogicalResult PadOp::verify() { auto sourceType = getSource().getType().cast(); auto resultType = getResult().getType().cast(); - auto expectedType = PadOp::inferResultType( - sourceType, extractFromI64ArrayAttr(getStaticLow()), - extractFromI64ArrayAttr(getStaticHigh())); + auto expectedType = + PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh()); for (int i = 0, e = sourceType.getRank(); i < e; ++i) { if (resultType.getDimSize(i) == expectedType.getDimSize(i)) continue; @@ -2556,8 +2553,9 @@ void PadOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef attrs) { auto sourceType = source.getType().cast(); auto resultType = inferResultType(sourceType, staticLow, staticHigh); - build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow), - b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr()); + build(b, result, resultType, source, low, high, + b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), + nofold ? b.getUnitAttr() : UnitAttr()); result.addAttributes(attrs); } @@ -2591,7 +2589,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, } assert(resultType.isa()); build(b, result, resultType, source, dynamicLow, dynamicHigh, - b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh), + b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr()); result.addAttributes(attrs); } @@ -2658,8 +2656,7 @@ struct FoldSourceTensorCast : public OpRewritePattern { auto newResultType = PadOp::inferResultType( castOp.getSource().getType().cast(), - extractFromI64ArrayAttr(padTensorOp.getStaticLow()), - extractFromI64ArrayAttr(padTensorOp.getStaticHigh()), + padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), padTensorOp.getResultType().getShape()); if (newResultType == padTensorOp.getResultType()) { @@ -2940,8 +2937,9 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamic); build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -3086,12 +3084,12 @@ template static SmallVector getMixedTilesImpl(OpTy op) { static_assert(llvm::is_one_of::value, "applies to only pack or unpack operations"); + Builder builder(op); SmallVector mixedInnerTiles; unsigned dynamicValIndex = 0; - for (Attribute attr : op.getStaticInnerTiles()) { - auto tileAttr = attr.cast(); - if (!ShapedType::isDynamic(tileAttr.getInt())) - mixedInnerTiles.push_back(tileAttr); + for (int64_t staticTile : op.getStaticInnerTiles()) { + if (!ShapedType::isDynamic(staticTile)) + mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile)); else mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); } diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 5694cfeb5130f..432e75618917c 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -137,4 +137,41 @@ SmallVector getAsValues(OpBuilder &b, Location loc, return getValueOrCreateConstantIndexOp(b, loc, value); })); } + +/// Return a vector of OpFoldResults with the same size a staticValues, but all +/// elements for which ShapedType::isDynamic is true, will be replaced by +/// dynamicValues. +SmallVector getMixedValues(ArrayRef staticValues, + ValueRange dynamicValues, Builder &b) { + SmallVector res; + res.reserve(staticValues.size()); + unsigned numDynamic = 0; + unsigned count = static_cast(staticValues.size()); + for (unsigned idx = 0; idx < count; ++idx) { + int64_t value = staticValues[idx]; + res.push_back(ShapedType::isDynamic(value) + ? OpFoldResult{dynamicValues[numDynamic++]} + : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])}); + } + return res; +} + +/// Decompose a vector of mixed static or dynamic values into the corresponding +/// pair of arrays. This is the inverse function of `getMixedValues`. +std::pair> +decomposeMixedValues(Builder &b, + const SmallVectorImpl &mixedValues) { + SmallVector staticValues; + SmallVector dynamicValues; + for (const auto &it : mixedValues) { + if (it.is()) { + staticValues.push_back(it.get().cast().getInt()); + } else { + staticValues.push_back(ShapedType::kDynamic); + dynamicValues.push_back(it.get()); + } + } + return {b.getI64ArrayAttr(staticValues), dynamicValues}; +} + } // namespace mlir diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp index 775d26a6d1590..9a39a217bf442 100644 --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -20,15 +20,15 @@ using namespace mlir; LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned numElements, - ArrayAttr attr, + ArrayRef staticVals, ValueRange values) { - /// Check static and dynamic offsets/sizes/strides does not overflow type. - if (attr.size() != numElements) + // Check static and dynamic offsets/sizes/strides does not overflow type. + if (staticVals.size() != numElements) return op->emitError("expected ") << numElements << " " << name << " values"; unsigned expectedNumDynamicEntries = - llvm::count_if(attr.getValue(), [&](Attribute attr) { - return ShapedType::isDynamic(attr.cast().getInt()); + llvm::count_if(staticVals, [&](int64_t staticVal) { + return ShapedType::isDynamic(staticVal); }); if (values.size() != expectedNumDynamicEntries) return op->emitError("expected ") @@ -70,19 +70,19 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { } void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, - OperandRange values, ArrayAttr integers) { + OperandRange values, + ArrayRef integers) { printer << '['; if (integers.empty()) { printer << "]"; return; } unsigned idx = 0; - llvm::interleaveComma(integers, printer, [&](Attribute a) { - int64_t val = a.cast().getInt(); - if (ShapedType::isDynamic(val)) + llvm::interleaveComma(integers, printer, [&](int64_t integer) { + if (ShapedType::isDynamic(integer)) printer << values[idx++]; else - printer << val; + printer << integer; }); printer << ']'; } @@ -90,28 +90,28 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - ArrayAttr &integers) { + DenseI64ArrayAttr &integers) { if (failed(parser.parseLSquare())) return failure(); // 0-D. if (succeeded(parser.parseOptionalRSquare())) { - integers = parser.getBuilder().getArrayAttr({}); + integers = parser.getBuilder().getDenseI64ArrayAttr({}); return success(); } - SmallVector attrVals; + SmallVector integerVals; while (true) { OpAsmParser::UnresolvedOperand operand; auto res = parser.parseOptionalOperand(operand); if (res.has_value() && succeeded(res.value())) { values.push_back(operand); - attrVals.push_back(ShapedType::kDynamic); + integerVals.push_back(ShapedType::kDynamic); } else { - IntegerAttr attr; - if (failed(parser.parseAttribute(attr))) + int64_t integer; + if (failed(parser.parseInteger(integer))) return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; - attrVals.push_back(attr.getInt()); + integerVals.push_back(integer); } if (succeeded(parser.parseOptionalComma())) @@ -120,7 +120,7 @@ ParseResult mlir::parseDynamicIndexList( return failure(); break; } - integers = parser.getBuilder().getI64ArrayAttr(attrVals); + integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); return success(); } @@ -144,34 +144,3 @@ bool mlir::detail::sameOffsetsSizesAndStrides( return false; return true; } - -SmallVector mlir::getMixedValues(ArrayAttr staticValues, - ValueRange dynamicValues) { - SmallVector res; - res.reserve(staticValues.size()); - unsigned numDynamic = 0; - unsigned count = static_cast(staticValues.size()); - for (unsigned idx = 0; idx < count; ++idx) { - APInt value = staticValues[idx].cast().getValue(); - res.push_back(ShapedType::isDynamic(value.getSExtValue()) - ? OpFoldResult{dynamicValues[numDynamic++]} - : OpFoldResult{staticValues[idx]}); - } - return res; -} - -std::pair> -mlir::decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues) { - SmallVector staticValues; - SmallVector dynamicValues; - for (const auto &it : mixedValues) { - if (it.is()) { - staticValues.push_back(it.get().cast().getInt()); - } else { - staticValues.push_back(ShapedType::kDynamic); - dynamicValues.push_back(it.get()); - } - } - return {b.getI64ArrayAttr(staticValues), dynamicValues}; -} diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 527a8656f7e33..5fd5cfe1073ad 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -49,6 +49,15 @@ def _get_int_array_attr( return ArrayAttr.get([_get_int64_attr(v) for v in values]) +def _get_dense_int64_array_attr( + values: Sequence[int]) -> DenseI64ArrayAttr: + """Creates a dense integer array from a sequence of integers. + Expects the thread-local MLIR context to have been set by the context + manager. + """ + if values is None: + return DenseI64ArrayAttr.get([]) + return DenseI64ArrayAttr.get(values) def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, @@ -250,14 +259,11 @@ def __init__(self, else: for size in sizes: if isinstance(size, int): - static_sizes.append(IntegerAttr.get(i64_type, size)) - elif isinstance(size, IntegerAttr): static_sizes.append(size) else: - static_sizes.append( - IntegerAttr.get(i64_type, ShapedType.get_dynamic_size())) + static_sizes.append(ShapedType.get_dynamic_size()) dynamic_sizes.append(_get_op_result_or_value(size)) - sizes_attr = ArrayAttr.get(static_sizes) + sizes_attr = DenseI64ArrayAttr.get(static_sizes) num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) @@ -266,14 +272,14 @@ def __init__(self, _get_op_result_or_value(target), dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, - interchange=_get_int_array_attr(interchange) if interchange else None, + interchange=_get_dense_int64_array_attr(interchange) if interchange else None, loc=loc, ip=ip) - def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]: + def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: if not attr: return [] - return [IntegerAttr(element).value for element in attr] + return [element for element in attr] class VectorizeOp: diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 06c52f50e0fa2..482cbc786d485 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -138,7 +138,7 @@ func.func @permute_generic(%A: memref>, transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]} + transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]} } // CHECK-LABEL: func @permute_generic @@ -191,8 +191,8 @@ func.func @matmul_perm(%A: memref>, transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange=[1, 2, 0]} - %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange=[1, 0, 2]} + %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} + %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]} %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] } diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index f52c4b6d63b33..34c86a317920b 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -108,7 +108,6 @@ def testSplit(): # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1 # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3 - @run def testTileCompact(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) @@ -120,14 +119,11 @@ def testTileCompact(): # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] # CHECK: interchange = [0, 1] - @run def testTileAttributes(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - attr = ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]]) - ichange = ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]]) + attr = DenseI64ArrayAttr.get([4, 8]) + ichange = DenseI64ArrayAttr.get([0, 1]) with InsertionPoint(sequence.body): structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange) transform.YieldOp() @@ -136,7 +132,6 @@ def testTileAttributes(): # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] # CHECK: interchange = [0, 1] - @run def testTileZero(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) @@ -149,7 +144,6 @@ def testTileZero(): # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0] # CHECK: interchange = [0, 1, 2, 3] - @run def testTileDynamic(): with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())