diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 48978eb7663d5..771d753a8bddb 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to); std::optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); +//===----------------------------------------------------------------------===// +// Convolution matcher utility +//===----------------------------------------------------------------------===// + +template +bool isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); + //===----------------------------------------------------------------------===// // Fusion / Tiling utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 40fc0d68e358f..0b3662c888010 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,80 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } +/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy` +/// with `dilations` and `strides`. +template +static FailureOr +specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, + ArrayRef dilations, ArrayRef strides) { + SmallVector inputs = genericOp.getDpsInputs(); + ValueRange outputs = genericOp.getDpsInits(); + SmallVector indexingMaps = genericOp.getIndexingMapsArray(); + SmallVector resultTypes = genericOp.hasPureTensorSemantics() + ? TypeRange(ValueRange(outputs)) + : TypeRange{}; + LinalgOp namedOp; + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, + inputs, outputs); + } else { + Attribute stridesAttr = rewriter.getI64TensorAttr(strides); + Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); + namedOp = rewriter.replaceOpWithNewOp( + genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + } + return namedOp; +} + +// Converts linalg.generic to named linalg.*conv/pooling* where possible. To +// improve the search speed, the convolution ops have been segregated based on +// the rank of iterator types array. +static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, + GenericOp genericOp) { + SmallVector dilations, strides; + // ----------------------------- + // Depthwise Convolution ops. + // ----------------------------- + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + // ----------------------------- + // Pooling ops. + // ----------------------------- + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType(genericOp, &dilations, + &strides)) + return specializeToConvOp(rewriter, genericOp, + dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + return failure(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -316,6 +390,11 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, if (isaContractionOpInterface(genericOp)) { return specializeLinalgContractions(rewriter, genericOp); } + + // Convolution - e.g. *conv/pooling* + if (isaConvolutionOpInterface(genericOp)) { + return specializeLinalgConvolutions(rewriter, genericOp); + } return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 24d3722cf5426..0be2668a9b346 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -240,6 +240,548 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } +//===----------------------------------------------------------------------===// +// Convolution matcher utilities +//===----------------------------------------------------------------------===// + +/// Returns the BlockArgument that leads to `val`, if any. Traverses optional +/// ext* ops. +static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { + BlockArgument blockArg; + if (!(blockArg = dyn_cast(val))) { + Operation *defOp = val.getDefiningOp(); + if (!dyn_cast_if_present(defOp) && + !dyn_cast_if_present(defOp) && + !dyn_cast_if_present(defOp)) { + return nullptr; + } + blockArg = dyn_cast(defOp->getOperand(0)); + } + return blockArg; +} + +/// Utility to match block body for matmul-like ops. +static bool bodyMatcherForMatmulLikeOps(Value yieldVal, Block *body) { + Operation *addOp = yieldVal.getDefiningOp(); + if (!isa_and_present(addOp)) + return false; + + Operation *mulOp = addOp->getOperand(1).getDefiningOp(); + if (!isa_and_present(mulOp)) + return false; + + BlockArgument lhsBlockArg = + getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0)); + BlockArgument rhsBlockArg = + getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1)); + BlockArgument outBlockArg = + getBlockArgumentWithOptionalExtOps(addOp->getOperand(0)); + if (!lhsBlockArg || !rhsBlockArg || !outBlockArg || + lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body || + outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 || + rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2) + return false; + return true; +} + +/// Utility to match block body for linalg.pool* ops. +template +static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { + Operation *defOp = yieldVal.getDefiningOp(); + if (!(isa_and_present(defOp) || ...)) + return false; + + BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); + BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); + if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || + rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 || + rhsArg.getArgNumber() != 0) + return false; + return true; +} + +static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +// max_unsigned ops should not allow float data type. +// TODO: Retire OPDSL logic. Refer to : +// https://github.com/llvm/llvm-project/issues/164800 +static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +// min_unsigned ops should not allow float data type. +// TODO: Retire OPDSL logic. Refer to : +// https://github.com/llvm/llvm-project/issues/164800 +static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, + uint32_t mapIndex, uint32_t dimIndex) { + auto affineMap = cast(indexingMaps[mapIndex]).getValue(); + if (dimIndex < affineMap.getNumResults()) + return affineMap.getResult(dimIndex); + return nullptr; +} + +// Check if `expr` is either: +// - a dimension expr alone (implying *1), or +// - a multiplication of dimension expr by constant. +static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, + int64_t &constantValue) { + if (auto dExpr = dyn_cast(expr)) { + dim = dExpr; + constantValue = 1; + return true; + } + + auto mulExpr = dyn_cast(expr); + if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul) + return false; + + AffineExpr lhs = mulExpr.getLHS(); + AffineExpr rhs = mulExpr.getRHS(); + + if (auto dExpr = dyn_cast(lhs)) { + if (auto cst = dyn_cast(rhs)) { + dim = dExpr; + constantValue = cst.getValue(); + return true; + } + } + if (auto cst = dyn_cast(lhs)) { + if (auto dExpr = dyn_cast(rhs)) { + dim = dExpr; + constantValue = cst.getValue(); + return true; + } + } + return false; +} + +/// Given an array of AffineMaps `indexingMaps` verify the following :- +/// indexingMaps[0].getResult(iDim) == +/// indexingMaps[1].getResult(fDim) * + +/// indexingMaps[n-1].getResult(oDim) * +/// where, CST_1 and CST_2 can be any constant. +static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, + unsigned fDim, unsigned oDim, + int64_t &dilation, int64_t &stride) { + unsigned inputMapIdx = 0, filterMapIdx = 1, + outputMapIdx = indexingMaps.size() - 1; + AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim); + auto addExpr = dyn_cast(inpExpr); + if (!addExpr || addExpr.getKind() != AffineExprKind::Add) + return false; + + AffineExpr dim0, dim1; + int64_t c0, c1; + + if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) && + isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) { + // Pattern matched with dims and constants extracted. + AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim); + AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim); + if (dim0 == fExpr && dim1 == oExpr) { + dilation = c0; + stride = c1; + return true; + } else if (dim1 == fExpr && dim0 == oExpr) { + dilation = c1; + stride = c0; + return true; + } + } + return false; +} + +/// Given an array of AffineMaps `indexingMaps` verify the following :- +/// indexingMaps[aIndex].getResult(aDim) == +/// indexingMaps[bIndex].getResult(bDim) +static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, + unsigned aDim, unsigned bIndex, + unsigned bDim) { + return getAffineMapDim(indexingMaps, aIndex, aDim) == + getAffineMapDim(indexingMaps, bIndex, bDim); +} + +/// Give an array of AffineMaps, verify each map to be of the corresponding +/// `expectedSize`. +static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, + ArrayRef expectedSizes) { + if (indexingMaps.size() != expectedSizes.size()) + return false; + + for (auto [indexingMap, expectedSize] : + llvm::zip_equal(indexingMaps, expectedSizes)) { + auto affineMap = cast(indexingMap).getValue(); + if (affineMap.getNumResults() != expectedSize) + return false; + } + return true; +} + +/// Utility to update `dilations` and `strides` by copy the corresponding data +/// from `tempDilations` and `tempStrides`. +static void updateConvDilationsAndStrides(SmallVector *dilations, + SmallVector *strides, + ArrayRef tempDilations, + ArrayRef tempStrides) { + if (!(dilations && strides)) + return; + for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) { + dilations->push_back(dilation); + strides->push_back(stride); + } + return; +} + +// --------------------------------------------- +// Matchers for specific convolution operation. +// --------------------------------------------- + +// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, w) -> (w, C)> +// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 2, filterMapIdx, 1) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 2, outputMapIdx, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + bodyMatcherForMatmulLikeOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 1, filterMapIdx, 0) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 1, outputMapIdx, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[1], + tempStrides[1]) && + bodyMatcherForMatmulLikeOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; +} + +// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D + d, H + h, W + w, C)> +// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (d, h, w, C, CM)> +// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D, H, W, C, CM)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 4, filterMapIdx, 3) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 4, outputMapIdx, 4) && + matchConvDimExprPattern(indexingMaps, filterMapIdx, 4, outputMapIdx, + 5) && + bodyMatcherForMatmulLikeOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, outputMapIdx = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, outputMapIdx = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, outputMapIdx = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && + bodyMatcherForSumPoolOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, outputMapIdx = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && + bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, outputMapIdx = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + bool returnVal = + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && + bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; +} + Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir new file mode 100644 index 0000000000000..06c9a84049d81 --- /dev/null +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -0,0 +1,124 @@ +// The following test examples of linalg convolution named ops lowered to linalg.generic and then +// lifted back up to named op. +// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s + +func.func @depthwise_conv_1d_nwc_wc(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @depthwise_conv_1d_nwc_wc +// CHECK: linalg.depthwise_conv_1d_nwc_wc +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nchw_chw +// CHECK: linalg.depthwise_conv_2d_nchw_chw +// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwcm +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_sum +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max_unsigned +// CHECK: linalg.pooling_nhwc_max_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min_unsigned_integer +// CHECK: linalg.pooling_nhwc_min_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +func.func @pooling_nhwc_min_unsigned_float(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min_unsigned_float +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic