diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 69426cf1fded9..ab3fdc72b2266 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1424,7 +1424,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /// `linalgOp`. Operation::operand_range getAssumedNonShapedOperands() { Operation::operand_range res{ - getOperation()->getOperands().begin() + getNumShapedOperands(), + getOperation()->getOperands().begin() + getNumInputsAndOutputs(), getOperation()->getOperands().end()}; for (Type t : TypeRange{res}) { (void)t; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 0a1b4adc3aeee..41fcc2495e658 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -140,7 +140,7 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. ArrayAttr iterator_types() { - unsigned nPar = getInputShapedType(0).getRank(); + int64_t nPar = getRank(getInputOperand(0)); return Builder(getContext()).getStrArrayAttr( SmallVector(nPar, getParallelIteratorTypeName())); } @@ -150,8 +150,8 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { MLIRContext *context = getContext(); auto maybeInputMap = inputPermutation(); auto maybeOutputMap = outputPermutation(); - unsigned inputRank = getInputShapedType(0).getRank(); - unsigned outputRank = getOutputShapedType(0).getRank(); + int64_t inputRank = getRank(getInputOperand(0)); + int64_t outputRank = getRank(getOutputOperand(0)); return Builder(getContext()).getAffineMapArrayAttr({ extractOrIdentityMap(maybeInputMap, inputRank, context), extractOrIdentityMap(maybeOutputMap, outputRank, context)}); @@ -195,7 +195,7 @@ def FillOp : LinalgStructured_Op<"fill", []> { // Rank-polymorphic. // filling_value -> O(ivs) with parallel iterators. ArrayAttr iterator_types() { - unsigned nPar = getOutputShapedType(0).getRank(); + int64_t nPar = getRank(getOutputOperand(0)); return Builder(getContext()).getStrArrayAttr( SmallVector(nPar, getParallelIteratorTypeName())); } @@ -351,14 +351,14 @@ def ConvOp : PoolingBase_Op<"conv", []> { unsigned getNumOutputFeatureDimensions() { return 1; } unsigned getNumSpatialDimensions() { - return getOutputShapedType(0).getRank() - getNumBatchDimensions() - + return getRank(getOutputOperand(0)) - getNumBatchDimensions() - getNumOutputFeatureDimensions(); } ArrayAttr iterator_types() { // Outer parallel loops are always the number of output dimensions; i.e. // [b, xs, q] in the TF notation above. - unsigned nPar = getOutputShapedType(0).getRank(); + int64_t nPar = getRank(getOutputOperand(0)); unsigned nRed = getNumInputFeatureDimensions(); // Window loops are a special kind of reduction that is never tiled or // parallelized across; i.e. [zs] in the TF notation above whose number @@ -457,7 +457,7 @@ class SingleInputPoolingBase_Op ArrayAttr iterator_types() { // Outer parallel loops are always the number of output dimensions. - unsigned nPar = getOutputShapedType(0).getRank(); + int64_t nPar = getRank(getOutputOperand(0)); // The window loops has the same number loops with output dimensions. unsigned nWin = nPar; SmallVector iters(nPar, getParallelIteratorTypeName()); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index f565e0091c448..8a48e89cda530 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -194,21 +194,18 @@ SmallVector mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, Location loc) { SmallVector res; - for (Value v : getShapedOperands()) { - ShapedType t = v.getType().template cast(); - for (unsigned i = 0, e = t.getRank(); i < e; ++i) - res.push_back(b.createOrFold(loc, v, i)); + for (OpOperand *opOperand : getInputAndOutputOperands()) { + for (int64_t i = 0, e = getRank(opOperand); i < e; ++i) + res.push_back(b.createOrFold(loc, opOperand->get(), i)); } return res; } SmallVector LinalgOp::createFlatListOfOperandStaticDims() { SmallVector res; - for (Value v : getShapedOperands()) { - ShapedType t = v.getType().template cast(); - assert(t.hasStaticShape() && "expected operands to have static shapes"); - llvm::append_range(res, t.getShape()); - } + assert(!hasDynamicShape() && "expected operands to have static shapes"); + for (OpOperand *opOperand : getInputAndOutputOperands()) + llvm::append_range(res, getShape(opOperand)); return res; } @@ -302,15 +299,14 @@ LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim( auto allResultDimValues = applyMapToValues(b, loc, resultShapesFromInputShapesMap, createFlatListOfOperandDims(b, loc)); - unsigned pos = 0; + int64_t pos = 0; ArrayRef shapeExprs = resultShapesFromInputShapesMap.getResults(); - for (auto resultIdx : llvm::seq(0, getNumOutputs())) { - ShapedType resultType = getOutputShapedType(resultIdx); + for (OpOperand *opOperand : getOutputOperands()) { SmallVector shapes; - for (unsigned dim : llvm::seq(0, resultType.getRank())) { + for (int64_t dim : llvm::seq(0, getRank(opOperand))) { if (checkDimExpr.visit(shapeExprs[pos])) shapes.push_back( - b.createOrFold(loc, getOutput(resultIdx), dim)); + b.createOrFold(loc, opOperand->get(), dim)); else shapes.push_back(allResultDimValues[pos]); pos++; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index c2c1633376595..47c6bc70339f1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -699,15 +699,19 @@ static void getGenericEffectsImpl( void GenericOp::getEffects( SmallVectorImpl> &effects) { - getGenericEffectsImpl(effects, getOperation()->getResults(), - getInputBuffers(), getOutputBuffers()); + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); + getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, + outputBuffers); } void IndexedGenericOp::getEffects( SmallVectorImpl> &effects) { - getGenericEffectsImpl(effects, getOperation()->getResults(), - getInputBuffers(), getOutputBuffers()); + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); + getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, + outputBuffers); } template diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index 71fb4364ac6a3..9e86bf3227c1a 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -2107,8 +2107,10 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os, } void {0}::getEffects(SmallVectorImpl< SideEffects::EffectInstance >&effects) {{ + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); getGenericEffectsImpl(effects, - getOperation()->getResults(), getInputBuffers(), getOutputBuffers()); + getOperation()->getResults(), inputBuffers, outputBuffers); })FMT"; os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName); } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 318613965ea30..312724c94b680 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -551,8 +551,10 @@ LogicalResult {0}::fold(ArrayRef, } void {0}::getEffects(SmallVectorImpl< SideEffects::EffectInstance >&effects) {{ - getGenericEffectsImpl(effects, - getOperation()->getResults(), getInputBuffers(), getOutputBuffers()); + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); + getGenericEffectsImpl(effects, + getOperation()->getResults(), inputBuffers, outputBuffers); } )FMT";