From 6b649570cbc44dd775d9657805cc60b2075d8011 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Tue, 29 Sep 2020 08:23:37 -0400
Subject: [PATCH] [mlir][Linalg] Refactor Linalg op initTensors support - NFC

Manually defined named ops do not currently support `init_tensors` or return
values and may never support them. Add extra interface methods to the
StructuredOpInterface so that we can still write op-agnostic transformations
based on StructuredOpInterface.

This is an NFC extension in preparation for tiling on tensors.

Differential Revision: https://reviews.llvm.org/D88481
---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 33 +++++++--
 .../Linalg/IR/LinalgStructuredOpsInterface.td | 69 ++++++++++++++++---
 .../mlir/Dialect/Linalg/IR/LinalgTraits.h     | 14 ++++
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp |  2 +-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp |  2 +-
 5 files changed, 101 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ed87689822e5fe..d123229337370e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -22,14 +22,19 @@ include "mlir/Interfaces/CopyOpInterface.td"
 // The Linalg `NInputs` trait provides the API for ops that are known
 // to have a specified number of inputs, all passed as operands.
 // See Linalg/LinalgTraits.h for implementation details and usage.
-class NInputs<int args_in> :
-  NativeOpTrait<"linalg::NInputs<" # !cast<string>(args_in) # ">::Impl"> {}
+class NInputs<int n> :
+  NativeOpTrait<"linalg::NInputs<" # !cast<string>(n) # ">::Impl"> {}
+
+// The Linalg `ZeroInitTensors` trait provides the API for ops that are known
+// to not have init tensor operands.
+// See Linalg/LinalgTraits.h for implementation details and usage.
+def ZeroInitTensors : NativeOpTrait<"linalg::ZeroInitTensors"> {}
 
 // The Linalg `NOutputs` trait provides the API for ops that are known
 // to have a specified number of outputs, all passed as operands.
 // See Linalg/LinalgTraits.h for implementation details and usage.
-class NOutputs<int args_out> :
-  NativeOpTrait<"linalg::NOutputs<" # !cast<string>(args_out) # ">::Impl"> {}
+class NOutputs<int n> :
+  NativeOpTrait<"linalg::NOutputs<" # !cast<string>(n) # ">::Impl"> {}
 
 def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
 def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
@@ -62,6 +67,7 @@ class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
 def CopyOp : LinalgStructured_Op<"copy", [
     CopyOpInterface,
     NInputs<1>,
+    ZeroInitTensors,
     NOutputs<1>
   ]> {
   let description = [{
@@ -159,7 +165,10 @@ def CopyOp : LinalgStructured_Op<"copy", [
   let hasCanonicalizer = 1;
 }
 
-def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
+def FillOp : LinalgStructured_Op<"fill", [
+    NInputs<0>,
+    ZeroInitTensors,
+    NOutputs<1>]> {
   let arguments = (ins AnyStridedMemRef:$output,
                    AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
 
@@ -254,7 +263,12 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
   }];
 }
 
-def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
+def ConvOp : PoolingBase_Op<"conv", [
+    NInputs<2>,
+    // Despite having reductions, this manually defined ConvOp may only take
+    // memref operands and can never have init tensors.
+    ZeroInitTensors,
+    NOutputs<1>]> {
   let description = [{
     Generic n-D convolution as described in the TF documentation:
@@ -371,7 +385,12 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
 }
 
 class SingleInputPoolingBase_Op<string mnemonic>
-  : PoolingBase_Op<mnemonic, [NInputs<2>, NOutputs<1>]> {
+    : PoolingBase_Op<mnemonic, [
+      NInputs<2>,
+      // Despite having reductions, this manually defined pooling op may only
+      // take memref operands and can never have init tensors.
+      ZeroInitTensors,
+      NOutputs<1>]> {
   let description = [{
     A base class for single input pooling function.
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 17e16a15d39a38..23d296c392ff90 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -125,13 +125,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
           getNumIterators(getReductionIteratorTypeName(), iters) == 1;
       }]>,
     //===------------------------------------------------------------------===//
-    // Num input/output arguments handling.
+    // Num input/output/initTensors arguments handling.
    //===------------------------------------------------------------------===//
     // These special methods must be defined by each op that wants to implement
     // the LinalgStructuredInterface. For now, this is either:
-    // - inherited statically by using the NInputs or
-    //   NOutputs traits.
-    // - derived from args_in/args_out attributes (for linalg.generic and
+    // - Explicitly specified in the op definition.
+    // - Derived from variadic attributes (for "named" ops, linalg.generic and
     //   linalg.indexed_generic ops).
     InterfaceMethod<
      /*desc=*/[{
@@ -140,6 +139,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*retTy=*/"unsigned",
       /*methodName=*/"getNumInputs"
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of init tensors.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumInitTensors"
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the number of outputs.
       }],
       /*retTy=*/"unsigned",
       /*methodName=*/"getNumOutputs"
     >,
@@ -371,6 +377,46 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the range over init tensors.
+      }],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getInitTensors",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = this->getOperation()->getOperands();
+        return {range.begin() + getNumInputsAndOutputBuffers(),
+                range.begin() + getNumInputsAndOutputs()};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the init tensor at position `$i`.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getInitTensor",
+      /*args=*/(ins "unsigned":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(i < $_op.getNumInitTensors() && "overflowing init tensor index");
+        return getInitTensors()[i];
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the range over inputs, output buffers and init tensors.
+      }],
+      /*retTy=*/"Operation::operand_range",
+      /*methodName=*/"getShapedOperands",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = this->getOperation()->getOperands();
+        return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the `i`-th shaped type, there are 3 cases:
@@ -445,7 +491,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::to_vector<4>($_op.indexing_maps().template getAsValueRange<AffineMapAttr, AffineMap>());
+        return llvm::to_vector<4>(
+            $_op.indexing_maps().template getAsValueRange<AffineMapAttr, AffineMap>());
       }]
     >,
     InterfaceMethod<
      /*desc=*/[{
@@ -528,11 +575,11 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }],
       /*retTy=*/"Operation *",
       /*methodName=*/"create",
-      (ins "OpBuilder &":$builder, "Location":$loc,
+      (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
            "ValueRange":$operands, "ArrayRef<NamedAttribute>":$attributes), [{
-        return builder.create<ConcreteOp>(loc, TypeRange{}, operands,
-                                          attributes);
+        return builder.create<ConcreteOp>(
+            loc, resultTypes, operands, attributes);
       }]
     >,
     InterfaceMethod<
@@ -542,10 +589,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }],
       /*retTy=*/"Operation *",
       /*methodName=*/"clone",
-      (ins "OpBuilder &":$b, "Location":$loc, "ValueRange":$operands), [{
+      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
+           "ValueRange":$operands),
+      [{
         BlockAndValueMapping map;
         unsigned numRegions = $_op.getOperation()->getNumRegions();
-        Operation *res = create(b, loc, operands, $_op.getAttrs());
+        Operation *res = create(b, loc, resultTypes, operands, $_op.getAttrs());
         assert(res->getNumRegions() == numRegions && "inconsistent # regions");
         for (unsigned ridx = 0; ridx < numRegions; ++ridx)
           $_op.getOperation()->getRegion(ridx).cloneInto(
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index 1df2b21bdade68..5f1c756ca446f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -35,6 +35,17 @@ template <unsigned N> class NInputs {
   };
 };
 
+/// This class provides the API for ops that are known to not have init tensor
+/// operands. Use as a trait as follows:
+///
+///   class CopyOp : public Op<CopyOp, OpTrait::linalg::ZeroInitTensors> {
+///
+template <typename ConcreteType>
+class ZeroInitTensors : public TraitBase<ConcreteType, ZeroInitTensors> {
+public:
+  static unsigned getNumInitTensors() { return 0; }
+};
+
 /// This class provides the API for ops that are known to have a specified
 /// number of outputs, all passed as operands. Use as a trait as follows:
 ///
@@ -87,6 +98,9 @@ class NamedStructuredOpTrait
   unsigned getNumInputs() {
     return cast<ConcreteType>(this->getOperation()).inputs().size();
   }
+  unsigned getNumInitTensors() {
+    return cast<ConcreteType>(this->getOperation()).init_tensors().size();
+  }
   unsigned getNumOutputs() {
     ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
     return concreteOp.output_buffers().size() +
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 04d417480f3bf8..dfc977daa2071b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -99,7 +99,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
 
   auto operands = getAssumedNonViewOperands(op);
   clonedViews.append(operands.begin(), operands.end());
-  Operation *clonedOp = op.clone(b, loc, clonedViews);
+  Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
   // When the producer is an IndexedGenercOp, we have to transform its block
   // IV arguments according to the tiling of the consumer, i.e. offset them by
   // the values computed in `loopRanges`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 676caa145c3a25..3db801bc2d575a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -405,7 +405,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
                                                  tileSizes, allViewSizes);
         auto operands = getAssumedNonViewOperands(op);
         views.append(operands.begin(), operands.end());
-        res = op.clone(b, loc, views);
+        res = op.clone(b, loc, /*resultTypes*/ {}, views);
         return scf::ValueVector{};
       },
       options.distribution);
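
Illustration only, not part of the patch: a minimal C++ sketch of how a transformation
might consume the new hooks op-agnostically. The helper names `inspectInitTensors` and
`cloneOnBuffers` are hypothetical; only the interface calls themselves (getNumInitTensors,
getInitTensor, getShapedOperands, getNumInputsAndOutputs, and the resultTypes-taking
clone) come from this change.

  #include <cassert>

  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
  #include "mlir/IR/Builders.h"
  #include "llvm/Support/Debug.h"

  using namespace mlir;

  // Print how many init tensors a structured op carries. Named ops tagged
  // with the ZeroInitTensors trait always report 0 here.
  static void inspectInitTensors(linalg::LinalgOp op) {
    unsigned numInit = op.getNumInitTensors();
    llvm::dbgs() << op.getOperation()->getName().getStringRef() << " has "
                 << numInit << " init tensor(s)\n";
    for (unsigned i = 0; i < numInit; ++i)
      llvm::dbgs() << "  init tensor #" << i << ": "
                   << op.getInitTensor(i).getType() << "\n";
    // getShapedOperands() spans inputs, output buffers and init tensors.
    assert(op.getShapedOperands().size() == op.getNumInputsAndOutputs() &&
           "shaped operands should cover inputs, outputs and init tensors");
  }

  // Clone a structured op with explicit result types, as Fusion and Tiling
  // now do; an empty resultTypes list reproduces the old buffer-only clone.
  static Operation *cloneOnBuffers(OpBuilder &b, Location loc,
                                   linalg::LinalgOp op, ValueRange operands) {
    return op.clone(b, loc, /*resultTypes=*/{}, operands);
  }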