diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 78431b9f66f90..839861c2369ca 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -13,6 +13,7 @@ #ifndef LINALG_IR_LINALGINTERFACES #define LINALG_IR_LINALGINTERFACES +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/IR/OpBase.td" // The 'LinalgContractionOpInterface' provides access to the @@ -178,7 +179,8 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> { } // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface. -def LinalgStructuredInterface : OpInterface<"LinalgOp"> { +def LinalgStructuredInterface + : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> { let cppNamespace = "::mlir::linalg"; let methods = [ //===------------------------------------------------------------------===// @@ -321,13 +323,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - // MLIR currently does not support dependent interfaces or interface - // inheritance. By construction all ops with StructuredOpInterface must - // implement DestinationStyleOpInterface. - // TODO: reevaluate the need for a cast when a better mechanism exists. - return getBlock()->getArguments().take_front( - cast(*this->getOperation()) - .getNumDpsInputs()); + return getBlock()->getArguments().take_front($_op.getNumDpsInputs()); }] >, InterfaceMethod< @@ -339,13 +335,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - // MLIR currently does not support dependent interfaces or interface - // inheritance. By construction all ops with StructuredOpInterface must - // implement DestinationStyleOpInterface. - // TODO: reevaluate the need for a cast when a better mechanism exists. - return getBlock()->getArguments().take_back( - cast(*this->getOperation()) - .getNumDpsInits()); + return getBlock()->getArguments().take_back($_op.getNumDpsInits()); }] >, InterfaceMethod< @@ -418,13 +408,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { assert(result.getOwner() == this->getOperation()); auto indexingMaps = $_op.getIndexingMaps().template getAsValueRange(); - // MLIR currently does not support dependent interfaces or interface - // inheritance. By construction all ops with StructuredOpInterface must - // implement DestinationStyleOpInterface. - // TODO: reevaluate the need for a cast when a better mechanism exists. - return *(indexingMaps.begin() + - cast(*this->getOperation()) - .getNumDpsInputs() + + return *(indexingMaps.begin() + $_op.getNumDpsInputs() + result.getResultNumber()); }] >, @@ -439,14 +423,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); - // MLIR currently does not support dependent interfaces or interface - // inheritance. By construction all ops with StructuredOpInterface must - // implement DestinationStyleOpInterface. - // TODO: reevaluate the need for a cast when a better mechanism exists. int64_t resultIndex = - opOperand->getOperandNumber() - - cast(*this->getOperation()) - .getNumDpsInputs(); + opOperand->getOperandNumber() - $_op.getNumDpsInputs(); assert(resultIndex >= 0 && resultIndex < this->getOperation()->getNumResults()); Operation *yieldOp = getBlock()->getTerminator(); @@ -800,80 +778,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /// Return the index in the indexingMaps vector that corresponds to this `opOperand` int64_t getIndexingMapIndex(OpOperand *opOperand); - - //========================================================================// - // Forwarding functions to access interface methods from the - // DestinationStyleOpInterface. - // MLIR currently does not support dependent interfaces or interface - // inheritance. By construction all ops with StructuredOpInterface must - // implement DestinationStyleOpInterface. - // TODO: reevaluate the need for a cast when a better mechanism exists. - //========================================================================// - - int64_t getNumDpsInputs() { - return cast(*this->getOperation()) - .getNumDpsInputs(); - } - - int64_t getNumDpsInits() { - return cast(*this->getOperation()) - .getNumDpsInits(); - } - - OpOperandVector getDpsInputOperands() { - return cast(*this->getOperation()) - .getDpsInputOperands(); - } - - OpOperand *getDpsInputOperand(int64_t i) { - return cast(*this->getOperation()) - .getDpsInputOperand(i); - } - - void setDpsInitOperand(int64_t i, Value value) { - return cast(*this->getOperation()) - .setDpsInitOperand(i, value); - } - - OpOperandVector getDpsInitOperands() { - return cast(*this->getOperation()) - .getDpsInitOperands(); - } - - OpOperand *getDpsInitOperand(int64_t i) { - return cast(*this->getOperation()) - .getDpsInitOperand(i); - } - - bool isDpsInput(OpOperand *opOperand) { - return cast(*this->getOperation()) - .isDpsInput(opOperand); - } - - bool isDpsInit(OpOperand *opOperand) { - return cast(*this->getOperation()) - .isDpsInit(opOperand); - } - - bool isScalar(OpOperand *opOperand) { - return cast(*this->getOperation()) - .isScalar(opOperand); - } - - OpResult getTiedOpResult(OpOperand *opOperand) { - return cast(*this->getOperation()) - .getTiedOpResult(opOperand); - } - - bool hasBufferSemantics() { - return cast(*this->getOperation()) - .hasBufferSemantics(); - } - - bool hasTensorSemantics() { - return cast(*this->getOperation()) - .hasTensorSemantics(); - } }]; let verify = [{ return detail::verifyStructuredOpInterface($_op); }];