Skip to content

Conversation

@Abhishek-Varma
Copy link
Contributor

@Abhishek-Varma Abhishek-Varma commented Nov 17, 2025

-- This commit is the fourth in the series of adding matchers
for linalg.conv/pool. Refer: #163724
-- In this commit all variants of Conv2D convolution ops have been
added.
-- It also refactors the way these matchers work to make adding more matchers concise.

Signed-off-by: Abhishek Varma abhvarma@amd.com

@llvmbot
Copy link
Member

llvmbot commented Nov 17, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

Changes

-- This commit is the third in the series of adding matchers
for linalg.conv/pool. Refer: #163724
-- In this commit all variants of Conv2D convolution ops have been
added.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>


Patch is 62.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168362.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+15)
  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+882-91)
  • (modified) mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir (+149-6)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c2485a08932dd..b52b93f8cc9b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -279,6 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
   CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
   CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
+  CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
   CONV_OP_SPECIALIZER(linalg::Conv3DOp);
   // -----------------------------
   // Depthwise Convolution ops.
@@ -287,6 +298,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
   CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
   // Pooling ops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6b85e6ba0ede2..57593abac7ab0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,8 +240,8 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 //===----------------------------------------------------------------------===//
 
 /// Returns the BlockArgument that leads to `val`, if any. Traverses optional
-/// ext* ops.
-static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+/// ext*/sitofp ops.
+static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
   BlockArgument blockArg = dyn_cast<BlockArgument>(val);
   if ((blockArg))
     return blockArg;
@@ -249,18 +249,62 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
   Operation *defOp = val.getDefiningOp();
   if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
       !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
-      !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
     return nullptr;
   }
   return dyn_cast<BlockArgument>(defOp->getOperand(0));
 }
 
+/// Utility function to match the zero point offset body of convolution ops.
+/// It takes input the addition op and multiplication op expected in every
+/// convolution op and matches the following for both operands of multiplication
+/// op :-
+///     %a - %b
+///   where, %a and %b can have optional upcast operation.
+static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
+                                           Block *body) {
+  Operation *subOp1 = mulOp->getOperand(0).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp1))
+    return false;
+  Operation *subOp2 = mulOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp2))
+    return false;
+  BlockArgument inputBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(0));
+  BlockArgument inputScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp1->getOperand(1));
+  BlockArgument filterBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(0));
+  BlockArgument filterScalarBlockArg =
+      getBlockArgumentWithOptionalCastOps(subOp2->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+  if (!inputBlockArg || !inputScalarBlockArg || !filterBlockArg ||
+      !filterScalarBlockArg || !outBlockArg ||
+      inputBlockArg.getOwner() != body ||
+      inputScalarBlockArg.getOwner() != body ||
+      filterBlockArg.getOwner() != body ||
+      filterScalarBlockArg.getOwner() != body ||
+      outBlockArg.getOwner() != body || inputBlockArg.getArgNumber() != 0 ||
+      inputScalarBlockArg.getArgNumber() != 2 ||
+      filterBlockArg.getArgNumber() != 1 ||
+      filterScalarBlockArg.getArgNumber() != 3 ||
+      outBlockArg.getArgNumber() != 4)
+    return false;
+  return true;
+}
+
 /// Utility to match block body for convolution ops.
 /// The body is thus expected to yield :-
 ///     %out + (%lhs * %rhs)
 ///   where: %lhs, %rhs and %out are block arguments and
 ///          %lhs and %rhs can have optional upcast operation.
-static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :-
+///       %input - %input_scalar
+///          where, %input_scalar can have optional upcast operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
+                                         bool zeroPointOffset = false) {
   Operation *addOp = yieldVal.getDefiningOp();
   if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
     return false;
@@ -269,12 +313,15 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
   if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
     return false;
 
+  if (zeroPointOffset) {
+    return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+  }
   BlockArgument lhsBlockArg =
-      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
   BlockArgument rhsBlockArg =
-      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
   BlockArgument outBlockArg =
-      getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
   if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
       lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
       outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -291,9 +338,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
     return false;
 
   BlockArgument lhsArg =
-      getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(defOp->getOperand(0));
   BlockArgument rhsArg =
-      getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(defOp->getOperand(1));
   if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
       rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
       rhsArg.getArgNumber() != 0)
@@ -599,49 +646,45 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
-// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
-// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
-                                              SmallVector<int64_t> *dilations,
-                                              SmallVector<int64_t> *strides) {
-  if (isa<linalg::Conv3DOp>(op))
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(3, 1);
-  *strides = SmallVector<int64_t>(3, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
-  AffineExpr D = getAffineDimExpr(0, context);
+  AffineExpr N = getAffineDimExpr(0, context);
   AffineExpr H = getAffineDimExpr(1, context);
   AffineExpr W = getAffineDimExpr(2, context);
-  AffineExpr d = getAffineDimExpr(3, context);
+  AffineExpr F = getAffineDimExpr(3, context);
   AffineExpr h = getAffineDimExpr(4, context);
   AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
-  // Match: D * stride + d * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
-                                  /*oDim=*/0, (*dilations)[0], (*strides)[0]))
-    return false;
   // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
-                                  /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
-                                  /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
-                         H * (*strides)[1] + h * (*dilations)[1],
-                         W * (*strides)[2] + w * (*dilations)[2]},
-           /*filterMap=*/{d, h, w},
-           /*outputMap=*/{D, H, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{F, h, w, c},
+           /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
   // Match body
@@ -651,37 +694,45 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap  = affine_map<(N, W, C, w) -> (N, C, W + w)>
-// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
   AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr F = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
-           /*filterMap=*/{C, w},
-           /*outputMap=*/{N, C, W}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{h, w, c, F},
+           /*outputMap=*/{N, H, W, F}},
           indexingMaps, context))
     return false;
   // Match body
@@ -691,37 +742,196 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
   return bodyMatcherForConvolutionOps(yieldVal, body);
 }
 
-// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
-// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+// #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
 template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
     LinalgOp op, SmallVector<int64_t> *dilations,
     SmallVector<int64_t> *strides) {
-  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+  if (isa<linalg::Conv2DNchwFchwOp>(op))
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
          "expected op to implement ConvolutionOpInterface");
 
-  *dilations = SmallVector<int64_t>(1, 1);
-  *strides = SmallVector<int64_t>(1, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   MLIRContext *context = op->getContext();
   AffineExpr N = getAffineDimExpr(0, context);
-  AffineExpr W = getAffineDimExpr(1, context);
-  AffineExpr C = getAffineDimExpr(2, context);
-  AffineExpr w = getAffineDimExpr(3, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr C = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
   ArrayAttr indexingMaps = op.getIndexingMaps();
   // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
   // Match: W * stride + w * dilation
-  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{F, C, h, w},
+           /*outputMap=*/{N, F, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap  = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr F = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  AffineExpr c = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
   // Match expected indexing maps
   if (!convLayoutMatches(
-          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
-           /*filterMap=*/{w, C},
-           /*outputMap=*/{N, W, C}},
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], c},
+           /*filterMap=*/{F, h, w, c},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, H, W, F}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap  = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #scalarMap = affine_map<(N, F, H, W, C, h, w) -> ()>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::Conv2DNchwFchwQOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr F = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr C = getAffineDimExpr(4, context);
+  AffineExpr h = getAffineDimExpr(5, context);
+  AffineExpr w = getAffineDimExpr(6, context);
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  // First fetch dilations/strides :-
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{F, C, h, w},
+           /*scalarMap=*/{},
+           /*scalarMap=*/{},
+           /*outputMap=*/{N, F, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match bo...
[truncated]

@hanhanW hanhanW requested a review from ftynse November 21, 2025 21:33
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't reviewed detailed implementation yet, just pointing out some missing tests. Github/git is not doing good at diff this time, so it'll take me some time to review it.

(also, I'm still thinking if we can refactor something or not, so we don't need many changes.)

@hanhanW
Copy link
Contributor

hanhanW commented Nov 25, 2025

@Abhishek-Varma I've been thinking how to make the implementation simpler. Below is the plan that I brainstormed with AI agent, and I think we can do the refactoring first. Then this PR will be much easier to review.

The main benefit is that the code itself documents the expected indexing maps, and it simplifies code a lot to me. Can you take a look and let me know what you think?


Problem: Each template specialization follows the exact same pattern (~50-60 lines each), leading to ~800+ lines of highly repetitive code across the 15 new matchers. This pattern is consistent across ALL conv types (1D, 2D, 3D, regular, depthwise, quantized).

The repeated boilerplate in each matcher:

template <>
bool isaConvolutionOpOfType<linalg::SomeConvOp>(...) {
  if (isa<linalg::SomeConvOp>(op)) return true;           // 1. type check
  assert(isaConvolutionOpInterface(op) && ...);           // 2. assert
  *dilations = SmallVector<int64_t>(N, 1);                // 3. init
  *strides = SmallVector<int64_t>(N, 1);
  MLIRContext *context = op->getContext();
  AffineExpr A = getAffineDimExpr(0, context);            // 4. dims (5-8 lines)
  AffineExpr B = getAffineDimExpr(1, context);
  // ... more dims ...
  ArrayAttr indexingMaps = op.getIndexingMaps();
  if (!matchConvDimAddExprPattern(...))                   // 5. stride matching
    return false;
  // ... more stride matching ...
  if (!convLayoutMatches({...}, indexingMaps, context))   // 6. map matching
    return false;
  Block *body = op.getBlock();                            // 7. body match
  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
  Value yieldVal = yieldOp.getOperand(0);
  return bodyMatcherForConvolutionOps(yieldVal, body);
}

Suggested Solution: Introduce a ConvMatcherBuilder helper class to reduce each matcher from ~50 lines to ~12 lines while improving readability:

/// Helper class for building convolution op matchers with minimal boilerplate.
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants.
class ConvMatcherBuilder {
  LinalgOp op;
  MLIRContext *ctx;
  SmallVector<int64_t> *dilations, *strides;
  ArrayAttr indexingMaps;
  bool matched = true;

public:
  ConvMatcherBuilder(LinalgOp op, unsigned spatialRank,
                     SmallVector<int64_t> *d, SmallVector<int64_t> *s)
      : op(op), ctx(op->getContext()), dilations(d), strides(s),
        indexingMaps(op.getIndexingMaps()) {
    *dilations = SmallVector<int64_t>(spatialRank, 1);
    *strides = SmallVector<int64_t>(spatialRank, 1);
  }

  /// Get affine dimension expression for dimension i.
  AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }

  /// Build strided expression: base * stride[idx] + kernel * dilation[idx]
  AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
    return base * (*strides)[idx] + kernel * (*dilations)[idx];
  }

  /// Match stride/dilation pattern for a spatial dimension.
  /// Returns *this for method chaining.
  ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim,
                                  unsigned oDim, unsigned idx) {
    if (matched) {
      matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
                                           (*dilations)[idx], (*strides)[idx]);
    }
    return *this;
  }

  /// Match expected indexing maps layout.
  /// Returns *this for method chaining.
  ConvMatcherBuilder &expectMaps(
      std::initializer_list<SmallVector<AffineExpr>> maps) {
    if (matched)
      matched = convLayoutMatches(maps, indexingMaps, ctx);
    return *this;
  }

  /// Match body pattern. This should be called last.
  bool matchBody(bool zeroPointOffset = false) {
    if (!matched)
      return false;
    Block *body = op.getBlock();
    auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
    return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
                                        zeroPointOffset);
  }
};

Example usage - Conv3DOp (before: ~55 lines, after: ~12 lines):

template <>
bool isaConvolutionOpOfType<linalg::Conv3DOp>(
    LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DOp>(op)) return true;
  assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  auto D = m.dim(0), H = m.dim(1), W = m.dim(2);
  auto d = m.dim(3), h = m.dim(4), w = m.dim(5);

  return m.matchStride(0, 0, 0, 0)
          .matchStride(1, 1, 1, 1)
          .matchStride(2, 2, 2, 2)
          .expectMaps({/*in=*/{m.strided(D,d,0), m.strided(H,h,1), m.strided(W,w,2)},
                       /*filter=*/{d, h, w},
                       /*out=*/{D, H, W}})
          .matchBody();
}

Example - DepthwiseConv2DNhwcHwcQOp (quantized, before: ~55 lines, after: ~14 lines):

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
  assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  auto N = m.dim(0), H = m.dim(1), W = m.dim(2), c = m.dim(3);
  auto h = m.dim(4), w = m.dim(5);

  return m.matchStride(1, 0, 1, 0)
          .matchStride(2, 1, 2, 1)
          .expectMaps({/*in=*/{N, m.strided(H,h,0), m.strided(W,w,1), c},
                       /*filter=*/{h, w, c},
                       /*inputScalar=*/{}, /*filterScalar=*/{},
                       /*out=*/{N, H, W, c}})
          .matchBody(/*zeroPointOffset=*/true);
}

Example - Conv2DNgchwGfchwOp (grouped conv, before: ~55 lines, after: ~14 lines):

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
  assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  auto N = m.dim(0), G = m.dim(1), FG = m.dim(2);
  auto H = m.dim(3), W = m.dim(4), C = m.dim(5);
  auto h = m.dim(6), w = m.dim(7);

  return m.matchStride(3, 3, 3, 0)
          .matchStride(4, 4, 4, 1)
          .expectMaps({/*in=*/{N, G, C, m.strided(H,h,0), m.strided(W,w,1)},
                       /*filter=*/{G, FG, C, h, w},
                       /*out=*/{N, G, FG, H, W}})
          .matchBody();
}

Benefits:

Metric Before After
Lines per matcher ~50-60 ~12-15
Total lines for 15 new ops ~800 ~200
Readability Low (wall of boilerplate) High (declarative layout visible)
Error-prone copy-paste High Low
Works for all conv types N/A Yes (1D, 2D, 3D, regular, depthwise, quantized)

@hanhanW
Copy link
Contributor

hanhanW commented Nov 25, 2025

I checked most matching except matchStride. Can you take a look at that one?

@Abhishek-Varma
Copy link
Contributor Author

Hi @hanhanW !

That was such a nice suggestion. I've addressed your comments!

I was trying to also see if there's a way for us to simply just initialize with "(N, H, W, C, h, w)" string using the ConvMatchBuilder and then we simply just pass the string versions of the indexing maps directly - so this way the caller doesn't need to initialize so many dims locally. But it turned out to be a bit more involved so I just stuck with the current suggestion. :D

Can you please take a look now ?

@hanhanW
Copy link
Contributor

hanhanW commented Nov 26, 2025

Hi @hanhanW !

That was such a nice suggestion. I've addressed your comments!

I was trying to also see if there's a way for us to simply just initialize with "(N, H, W, C, h, w)" string using the ConvMatchBuilder and then we simply just pass the string versions of the indexing maps directly - so this way the caller doesn't need to initialize so many dims locally. But it turned out to be a bit more involved so I just stuck with the current suggestion. :D

Can you please take a look now ?

I'll be out for a week. QQ: can we split it to two PRs? One for introducing ConvMatchBuilder and use it for existing patterns, and the other (which can be this one) is for 2D convs that we adding in this PR?

Abhishek-Varma added a commit to Abhishek-Varma/llvm-project that referenced this pull request Nov 26, 2025
-- This commit is a follow-up and third in the series of adding
   matchers for conv/pool ops. Refer: llvm#163724
-- It introduces ConvMatchBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D thread](llvm#168362 (comment)) for further context.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
@Abhishek-Varma
Copy link
Contributor Author

I've raised the ConvMatchBuilder PR : #169704

Let me know if I should close this PR for now and reopen after the ConvMatchBuilder PR gets merged - or just let it be.

CC: @hanhanW @banach-space @MaheshRavishankar

Abhishek-Varma added a commit that referenced this pull request Nov 29, 2025
…69704)

-- This commit is a follow-up and third in the series of adding
matchers for conv/pool ops. Refer:
#163724
-- It introduces ConvMatchBuilder class in order to reduce the
   repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to [Conv2D
thread](#168362 (comment))
for further context.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
-- This commit is the fourth in the series of adding matchers
   for linalg.conv/pool. Refer: llvm#163724
-- In this commit all variants of Conv2D convolution ops have been
   added.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
@Abhishek-Varma Abhishek-Varma force-pushed the avarma_conv_matcher_pr_3 branch from 3c90eee to 968cb69 Compare November 29, 2025 16:46
@Abhishek-Varma
Copy link
Contributor Author

I've rebased the PR to use ConvMatchBuilder.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants