Skip to content

Conversation

@Abhishek-Varma
Copy link
Contributor

-- This commit includes the basic infra/utilities to add matchers for
linalg.conv/pool ops - such that given a linalg.generic op it
identifies which linalg.conv/pool op it is.
-- It adds a few representative linalg.conv/pool ops to demo the
matchers' capability and does so as part of linalg-specialize-generic-ops
pass.
-- The goal is directed towards addressing the aim of
[RFC] Op explosion in Linalg
iteratively for *conv*/*pooling* ops.
-- This is part-1 of a series of PRs aimed to add matchers for Convolution ops.
-- For further details, refer to #163374 (review)

Signed-off-by: Abhishek Varma abhvarma@amd.com

-- This commit includes the basic infra/utilities to add matchers for
   linalg.*conv*/*pool* ops - such that given a `linalg.generic` op it
   identifies which linalg.*conv*/*pool* op it is.
-- It adds a few representative linalg.*conv*/*pool* ops to demo the
   matchers' capability and does so as part of `linalg-specialize-generic-ops`
   pass.
-- The goal is directed towards addressing the aim of
   [[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863)
   iteratively for `*conv*/*pooling*` ops.
-- This is part-1 of a series of PRs aimed to add matchers for Convolution ops.
-- For further details, refer to llvm#163374 (review)

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
@llvmbot
Copy link
Member

llvmbot commented Oct 16, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

Changes

-- This commit includes the basic infra/utilities to add matchers for
linalg.conv/pool ops - such that given a linalg.generic op it
identifies which linalg.conv/pool op it is.
-- It adds a few representative linalg.conv/pool ops to demo the
matchers' capability and does so as part of linalg-specialize-generic-ops
pass.
-- The goal is directed towards addressing the aim of
[RFC] Op explosion in Linalg
iteratively for *conv*/*pooling* ops.
-- This is part-1 of a series of PRs aimed to add matchers for Convolution ops.
-- For further details, refer to #163374 (review)

Signed-off-by: Abhishek Varma <abhvarma@amd.com>


Patch is 36.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163724.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+9)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+144)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+502)
  • (added) mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir (+112)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op,
+                            SmallVector<int64_t> *dilations = nullptr,
+                            SmallVector<int64_t> *strides = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..35861002e309e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
+/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy`
+/// with `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  LinalgOp namedOp;
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
+  return namedOp;
+}
+
+/// TODO(avarma): Convolution ops which rank-2 iteratory types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops which rank-5 iteratory types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops which rank-7 iteratory types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops which rank-8 iteratory types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// improve the search speed, the convolution ops have been segregated based on
+// the rank of iterator types array.
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  unsigned totalIterators = iteratorTypes.size();
+  switch (totalIterators) {
+  case 2:
+    return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+  case 4:
+    return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+  case 5:
+    return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+  case 6:
+    return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+  case 7:
+    return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+  case 8:
+    return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+  case 9:
+    return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+  }
+  return failure();
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
+
+  // Convolution - e.g. *conv/pooling*
+  if (isaConvolutionOpInterface(genericOp)) {
+    return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+  }
   return failure();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..c3c2819652129 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,508 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
   return iteratorType == utils::IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Utility to match block body for linalg.pool* ops.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *defOp = yieldVal.getDefiningOp();
+  if (!(isa_and_present<OpTypes>(defOp) || ...))
+    return false;
+
+  BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg)
+    return false;
+  return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  if (dimIndex < affineMap.getNumResults())
+    return affineMap.getResult(dimIndex);
+  return nullptr;
+}
+
+// Check if `expr` is either:
+// - a dimension expr alone (implying *1), or
+// - a multiplication of dimension expr by constant.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
+  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+    dim = dExpr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return false;
+
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+
+  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+///   indexingMaps[0].getResult(iDim) ==
+///         indexingMaps[1].getResult(fDim) * <CST_1> +
+///         indexingMaps[n-1].getResult(oDim) * <CST_2>
+///  where, CST_1 and CST_2 can be any constant.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  int64_t c0, c1;
+
+  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+    // Pattern matched with dims and constants extracted.
+    AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+    AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+    if (dim0 == fExpr && dim1 == oExpr) {
+      dilation = c0;
+      stride = c1;
+      return true;
+    } else if (dim1 == fExpr && dim0 == oExpr) {
+      dilation = c1;
+      stride = c0;
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+///   indexingMaps[aIndex].getResult(aDim) ==
+///   indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                    unsigned aDim, unsigned bIndex,
+                                    unsigned bDim) {
+  return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+         getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Give an array of AffineMaps, verify each map to be of the corresponding
+/// `expectedSize`.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<int64_t> expectedSizes) {
+  if (indexingMaps.size() != expectedSizes.size())
+    return false;
+
+  for (auto [indexingMap, expectedSize] :
+       llvm::zip_equal(indexingMaps, expectedSizes)) {
+    auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+    if (affineMap.getNumResults() != expectedSize)
+      return false;
+  }
+  return true;
+}
+
+/// Utility to update `dilations` and `strides` by copy the corresponding data
+/// from `tempDilations` and `tempStrides`.
+static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+                                          SmallVector<int64_t> *strides,
+                                          ArrayRef<int64_t> tempDilations,
+                                          ArrayRef<int64_t> tempStrides) {
+  if (!(dilations && strides))
+    return true;
+  for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
+    dilations->push_back(dilation);
+    strides->push_back(stride);
+  }
+  return true;
+}
+
+static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+                                      SmallVector<int64_t> *dilations,
+                                      SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(1, 1);
+  SmallVector<int64_t> tempStrides(1, 1);
+  // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+  // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+                                        SmallVector<int64_t> *dilations,
+                                        SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(2, 1);
+  SmallVector<int64_t> tempStrides(2, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[1],
+                                  tempStrides[1]));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+                                           SmallVector<int64_t> *dilations,
+                                           SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+  SmallVector<int64_t> tempDilations(3, 1);
+  SmallVector<int64_t> tempStrides(3, 1);
+  // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d5, d6, d7, d8, d4)>
+  // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+  //                  -> (d0, d1, d2, d3, d8, d4)>
+  bool returnVal =
+      (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, tempDilations[0],
+                                  tempStrides[0]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, tempDilations[1],
+                                  tempStrides[1]) &&
+       matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, tempDilations[2],
+                                  tempStrides[2]) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+       matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+       matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+  return returnVal && updateConvDilationsAndStrides(dilations, strides,
+                                                    tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+                                SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMaxOp>(op))
+    return true;
+
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAt...
[truncated]

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for extracting this! Sharing my first set of comments. This is still quite dense, so I've not read everything yet 😅

Comment on lines 410 to 411
if (!isaConvolutionOpInterface(op))
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the logic in specializeGenericOp that also checks this, this should be an assert. Calling this function with something that doesn't satisfy this is a user error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, so the intention of placing it here as well is for ANY user to simply invoke isaConvolutionOfType<SOME_CONV_NAME>(genericOp) (on ANY LinalgOp).

I'm using this as part of specializeGenericOp to demonstrate the working of the matcher. I've plans to update a few PatternRewrites to use these matchers. Will be sending a PR for that too once we land ALL the matchers. :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but right now, turning this into an assert would make sense, right?

In general, I would refrain from "tuning" the implementation towards what comes next. If an assert doesn't make sense in the future, you can always remove it. But "today" it feels 100% valid, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your point but I still think it makes sense to keep the check as is instead of an assert.

It's got more to do with the fact that a given isa matchers' existence should be independent of its usage.
As a black-box, we'd want it to simply operate on ANY linalg.generic without any assumption.

The matcher thus should just respond with a YES/NO to whether a given linalg.generic is a convolution op of particular kind instead of crashing on an assert/assumption, in my opinion.

It just so happens that one of the independent users of a isaConvolutionOpOfType<...> API in this patch (for demo) is Specialize.cpp pass where I was trying to bail out early if a linalg.generic wasn't implementing ConvolutionOpInterface. I've removed that now though and the intent of the matcher being independent of whoever the caller is gets demoed properly now I feel.

Let me know your thoughts - I'm happy to converge to an agreeable point. :) Just putting my points above to better communicate my initial intention. 😅

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just putting my points above to better communicate my initial intention. 😅

And that's much appreciated! Ultimately, this is a conversation and it's important that we openly share and clarify our viewpoints 🙏🏻

The matcher thus should just respond with a YES/NO to whether a given linalg.generic is a convolution op of particular kind instead of crashing on an assert/assumption, in my opinion.

Why? :)

If you view this as a generic hook that could be used by anyone anywhere then yes. However, I view this as an implementation detail for this file. Since the following check is present in every function, it feels like repetition that ought to be factored out (assert would make sense though - that's just asserting a pre-condition):

  if (!isaConvolutionOpInterface(op))
    return false;

I've removed that now though and the intent of the matcher being independent of whoever the caller is gets demoed properly now I feel.

Sure, this makes sense if we want to provide these hooks as some generic facility. But, IMO, we don't. Note, we need to re-wire how Convolutions are defined and that will require refactoring these helper hooks. This is why I advocate for hiding as much as possible.

My suggestion is to consider these methods as implementation details. This is not a blocker for me though. Like yourself, trying to clarify my viewpoint on this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I indeed view this as a generic hook that can be used by anyone/anywhere.

But I see that in order to re-wire how Convolutions are defined, it'd rather be better to abstract out the interface check.

I've thus removed it in the latest push. Could you please re-review this ? :)

Comment on lines +407 to +408
if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would avoid this and various unnecessary calls and expand inferAndSpecializeBasedOnRank6ConvIteratorTypes instead with e.g.:

if (isa<linalg::DepthwiseConv1DNwcWcOp, linalg::DepthwiseConv2DNchwChwOp, ...>(op))
  return op;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same reasoning as the previous thread applies here.

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for splitting the PR, it is easier to review! I'll take a look at [Utils.cpp] changes once we are aligned on the code structure.

Comment on lines 410 to 411
if (!isaConvolutionOpInterface(op))
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just putting my points above to better communicate my initial intention. 😅

And that's much appreciated! Ultimately, this is a conversation and it's important that we openly share and clarify our viewpoints 🙏🏻

The matcher thus should just respond with a YES/NO to whether a given linalg.generic is a convolution op of particular kind instead of crashing on an assert/assumption, in my opinion.

Why? :)

If you view this as a generic hook that could be used by anyone anywhere then yes. However, I view this as an implementation detail for this file. Since the following check is present in every function, it feels like repetition that ought to be factored out (assert would make sense though - that's just asserting a pre-condition):

  if (!isaConvolutionOpInterface(op))
    return false;

I've removed that now though and the intent of the matcher being independent of whoever the caller is gets demoed properly now I feel.

Sure, this makes sense if we want to provide these hooks as some generic facility. But, IMO, we don't. Note, we need to re-wire how Convolutions are defined and that will require refactoring these helper hooks. This is why I advocate for hiding as much as possible.

My suggestion is to consider these methods as implementation details. This is not a blocker for me though. Like yourself, trying to clarify my viewpoint on this.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are getting there 😅

I've started reviewing the utility functions, see my comments inline.

static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
unsigned fDim, unsigned oDim,
int64_t &dilation, int64_t &stride) {
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIndex = 2;

Also, instead of using "magic" 0, 1 and 2, could you define some global constants instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could do that but for Convolution ops outputMapIndex is indexingMaps.size() - 1 : so it can be either 2 or 4 (in a few Convolution ops' case).

How should I go about this?

Comment on lines +336 to +338
/// indexingMaps[1].getResult(fDim) * <CST_1> +
/// indexingMaps[n-1].getResult(oDim) * <CST_2>
/// where, CST_1 and CST_2 can be any constant.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are CST_1 and CST_2, I don't see those in the code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So CST_1 and CST_2 would be dilations and strides constant.

I'm checking that using isDimTimesConstantOrDimOnly here :-

if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {

Comment on lines +395 to +396
/// Utility to update `dilations` and `strides` by copy the corresponding data
/// from `tempDilations` and `tempStrides`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this return a bool if the return value is always true? Should be void instead.

Also, why are strides and dilations updated? I don't follow the logic here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated it to be void now.

So in order to "re-create" the named ops in Specialize.cpp I'd also need to know the dilations/strides value reflected in the affine.map of the linalg.generic.
I included it as part of the matcher itself because while matching the indexing maps of input/filter/output itself are we able to obtain these.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants