Skip to content

Conversation

@Abhishek-Varma
Copy link
Contributor

-- This commit addresses follow-up review comments on 169704.
-- Contains NFC nit/minor changes.

Signed-off-by: Abhishek Varma abhvarma@amd.com

-- This commit addresses [follow-up review comments on 169704](llvm#169704 (review)).

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Abhishek Varma (Abhishek-Varma)

Changes

-- This commit addresses follow-up review comments on 169704.
-- Contains NFC nit/minor changes.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>


Full diff: https://github.com/llvm/llvm-project/pull/170080.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+85-71)
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e85a2ab26bd32..01e6e1e248658 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -430,19 +430,33 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
                           })));
 }
 
-/// Enum of all kinds of Pooling Op's type.
-enum PoolingType {
-  NONE,
-  MAX_SIGNED,
-  MAX_UNSIGNED,
-  MIN_SIGNED,
-  MIN_UNSIGNED,
-  SUM
+/// Enum representing pooling operation types used by ConvMatcherBuilder.
+enum class PoolingType {
+  None,
+  MaxSigned,
+  MaxUnsigned,
+  MinSigned,
+  MinUnsigned,
+  Sum
 };
 
 /// Helper class for building convolution op matchers with minimal boilerplate.
 /// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
 /// as Pooling ops.
+///
+/// Usage: Create an instance with the op, spatial rank, and output pointers for
+/// extracted dilations/strides. Then chain matchStride() calls for each spatial
+/// dimension, followed by matchMaps() to verify indexing maps, and finally
+/// matchBody() to verify the operation body pattern.
+///
+/// The `matched` flag starts as `true` and is set to `false` if any match step
+/// fails. This allows chaining multiple match calls; once any match fails, all
+/// subsequent calls become no-ops and the final result is `false`.
+///
+/// The `dilations` and `strides` pointers are output parameters that get
+/// populated with the extracted dilation and stride values from the operation's
+/// indexing maps during matchStride() calls. These values are initially set to
+/// 1 for each spatial dimension and updated as patterns are matched.
 class ConvMatcherBuilder {
   LinalgOp op;
   MLIRContext *ctx;
@@ -454,7 +468,7 @@ class ConvMatcherBuilder {
 public:
   ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
                      SmallVector<int64_t> *s,
-                     PoolingType poolingType = PoolingType::NONE)
+                     PoolingType poolingType = PoolingType::None)
       : op(op), ctx(op->getContext()), dilations(d), strides(s),
         indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
     *dilations = SmallVector<int64_t>(spatialRank, 1);
@@ -474,16 +488,16 @@ class ConvMatcherBuilder {
   ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
                                   unsigned idx) {
     if (matched) {
-      matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
-                                           (*dilations)[idx], (*strides)[idx]);
+      matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+                                            (*dilations)[idx], (*strides)[idx]);
     }
     return *this;
   }
 
   /// Match expected indexing maps layout. Returns *this for method chaining.
-  ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+  ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
     if (matched)
-      matched = convLayoutMatches(maps, indexingMaps, ctx);
+      matched &= convLayoutMatches(maps, indexingMaps, ctx);
     return *this;
   }
 
@@ -494,17 +508,17 @@ class ConvMatcherBuilder {
     Block *body = op.getBlock();
     auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
     switch (poolingType) {
-    case PoolingType::NONE:
+    case PoolingType::None:
       return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
-    case PoolingType::MAX_SIGNED:
+    case PoolingType::MaxSigned:
       return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
-    case PoolingType::MAX_UNSIGNED:
+    case PoolingType::MaxUnsigned:
       return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
-    case PoolingType::MIN_SIGNED:
+    case PoolingType::MinSigned:
       return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
-    case PoolingType::MIN_UNSIGNED:
+    case PoolingType::MinUnsigned:
       return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
-    case PoolingType::SUM:
+    case PoolingType::Sum:
       return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
     }
     return false;
@@ -533,9 +547,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
   AffineExpr w = m.dim(1);
 
   return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
-                   /*filterMap=*/{w},
-                   /*outputMap=*/{W}})
+      .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
+                  /*filterMap=*/{w},
+                  /*outputMap=*/{W}})
       .matchBody();
 }
 
@@ -560,9 +574,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
   AffineExpr c = m.dim(4);
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
-                   /*filterMap=*/{w, c, F},
-                   /*outputMap=*/{N, W, F}})
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+                  /*filterMap=*/{w, c, F},
+                  /*outputMap=*/{N, W, F}})
       .matchBody();
 }
 
@@ -587,9 +601,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
   AffineExpr w = m.dim(4);
 
   return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
-                   /*filterMap=*/{F, c, w},
-                   /*outputMap=*/{N, F, W}})
+      .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+                  /*filterMap=*/{F, c, w},
+                  /*outputMap=*/{N, F, W}})
       .matchBody();
 }
 
@@ -614,9 +628,9 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
 
   return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
       .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{H, W}})
+      .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{H, W}})
       .matchBody();
 }
 
@@ -644,10 +658,10 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
   return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
       .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
       .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
-      .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
-                                 m.strided(W, w, 2)},
-                   /*filterMap=*/{d, h, w},
-                   /*outputMap=*/{D, H, W}})
+      .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2)},
+                  /*filterMap=*/{d, h, w},
+                  /*outputMap=*/{D, H, W}})
       .matchBody();
 }
 
@@ -671,9 +685,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
   AffineExpr w = m.dim(3);
 
   return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
-                   /*filterMap=*/{C, w},
-                   /*outputMap=*/{N, C, W}})
+      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+                  /*filterMap=*/{C, w},
+                  /*outputMap=*/{N, C, W}})
       .matchBody();
 }
 
@@ -697,9 +711,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
   AffineExpr w = m.dim(3);
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
-                   /*filterMap=*/{w, C},
-                   /*outputMap=*/{N, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w, C},
+                  /*outputMap=*/{N, W, C}})
       .matchBody();
 }
 
@@ -724,9 +738,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
   AffineExpr w = m.dim(4);
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
-      .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
-                   /*filterMap=*/{w, C, CM},
-                   /*outputMap=*/{N, W, C, CM}})
+      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+                  /*filterMap=*/{w, C, CM},
+                  /*outputMap=*/{N, W, C, CM}})
       .matchBody();
 }
 
@@ -753,9 +767,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
 
   return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
       .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
-                   /*filterMap=*/{C, h, w},
-                   /*outputMap=*/{N, C, H, W}})
+      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+                  /*filterMap=*/{C, h, w},
+                  /*outputMap=*/{N, C, H, W}})
       .matchBody();
 }
 
@@ -789,10 +803,10 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
       .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
-      .expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
-                                 m.strided(W, w, 2), C},
-                   /*filterMap=*/{d, h, w, C, CM},
-                   /*outputMap=*/{N, D, H, W, C, CM}})
+      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+                                m.strided(W, w, 2), C},
+                  /*filterMap=*/{d, h, w, C, CM},
+                  /*outputMap=*/{N, D, H, W, C, CM}})
       .matchBody();
 }
 
@@ -810,7 +824,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::MAX_SIGNED);
+                       PoolingType::MaxSigned);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -820,9 +834,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 
@@ -840,7 +854,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::MIN_SIGNED);
+                       PoolingType::MinSigned);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -850,9 +864,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 
@@ -870,7 +884,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::SUM);
+                       PoolingType::Sum);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -880,9 +894,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 
@@ -900,7 +914,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::MAX_UNSIGNED);
+                       PoolingType::MaxUnsigned);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -910,9 +924,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 
@@ -930,7 +944,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
          "expected op to implement ConvolutionOpInterface");
 
   ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
-                       PoolingType::MIN_UNSIGNED);
+                       PoolingType::MinUnsigned);
   AffineExpr N = m.dim(0);
   AffineExpr H = m.dim(1);
   AffineExpr W = m.dim(2);
@@ -940,9 +954,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
 
   return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
       .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
-      .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
-                   /*filterMap=*/{h, w},
-                   /*outputMap=*/{N, H, W, C}})
+      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+                  /*filterMap=*/{h, w},
+                  /*outputMap=*/{N, H, W, C}})
       .matchBody();
 }
 

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, many thanks for the prompt reply 🙏🏻

@Abhishek-Varma Abhishek-Varma merged commit 7ce7141 into llvm:main Dec 1, 2025
13 checks passed
aahrun pushed a commit to aahrun/llvm-project that referenced this pull request Dec 1, 2025
-- This commit addresses [follow-up review comments on
169704](llvm#169704 (review)).
-- Contains NFC nit/minor changes.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants