From 115efb87a91fcaad0ce40c3705f44bad3c297c45 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Mon, 1 Dec 2025 01:27:14 -0600 Subject: [PATCH] [NFC][Linalg] Follow-up on ConvMatchBuilder -- This commit addresses [follow-up review comments on 169704](https://github.com/llvm/llvm-project/pull/169704#pullrequestreview-3521785548). Signed-off-by: Abhishek Varma --- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 156 +++++++++++++----------- 1 file changed, 85 insertions(+), 71 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index e85a2ab26bd32..01e6e1e248658 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -430,19 +430,33 @@ static bool convLayoutMatches(ArrayRef> mapListExpected, }))); } -/// Enum of all kinds of Pooling Op's type. -enum PoolingType { - NONE, - MAX_SIGNED, - MAX_UNSIGNED, - MIN_SIGNED, - MIN_UNSIGNED, - SUM +/// Enum representing pooling operation types used by ConvMatcherBuilder. +enum class PoolingType { + None, + MaxSigned, + MaxUnsigned, + MinSigned, + MinUnsigned, + Sum }; /// Helper class for building convolution op matchers with minimal boilerplate. /// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well /// as Pooling ops. +/// +/// Usage: Create an instance with the op, spatial rank, and output pointers for +/// extracted dilations/strides. Then chain matchStride() calls for each spatial +/// dimension, followed by matchMaps() to verify indexing maps, and finally +/// matchBody() to verify the operation body pattern. +/// +/// The `matched` flag starts as `true` and is set to `false` if any match step +/// fails. This allows chaining multiple match calls; once any match fails, all +/// subsequent calls become no-ops and the final result is `false`. +/// +/// The `dilations` and `strides` pointers are output parameters that get +/// populated with the extracted dilation and stride values from the operation's +/// indexing maps during matchStride() calls. These values are initially set to +/// 1 for each spatial dimension and updated as patterns are matched. class ConvMatcherBuilder { LinalgOp op; MLIRContext *ctx; @@ -454,7 +468,7 @@ class ConvMatcherBuilder { public: ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector *d, SmallVector *s, - PoolingType poolingType = PoolingType::NONE) + PoolingType poolingType = PoolingType::None) : op(op), ctx(op->getContext()), dilations(d), strides(s), indexingMaps(op.getIndexingMaps()), poolingType(poolingType) { *dilations = SmallVector(spatialRank, 1); @@ -474,16 +488,16 @@ class ConvMatcherBuilder { ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim, unsigned idx) { if (matched) { - matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim, - (*dilations)[idx], (*strides)[idx]); + matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim, + (*dilations)[idx], (*strides)[idx]); } return *this; } /// Match expected indexing maps layout. Returns *this for method chaining. - ConvMatcherBuilder &expectMaps(ArrayRef> maps) { + ConvMatcherBuilder &matchMaps(ArrayRef> maps) { if (matched) - matched = convLayoutMatches(maps, indexingMaps, ctx); + matched &= convLayoutMatches(maps, indexingMaps, ctx); return *this; } @@ -494,17 +508,17 @@ class ConvMatcherBuilder { Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); switch (poolingType) { - case PoolingType::NONE: + case PoolingType::None: return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body); - case PoolingType::MAX_SIGNED: + case PoolingType::MaxSigned: return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body); - case PoolingType::MAX_UNSIGNED: + case PoolingType::MaxUnsigned: return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body); - case PoolingType::MIN_SIGNED: + case PoolingType::MinSigned: return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body); - case PoolingType::MIN_UNSIGNED: + case PoolingType::MinUnsigned: return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body); - case PoolingType::SUM: + case PoolingType::Sum: return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body); } return false; @@ -533,9 +547,9 @@ bool isaConvolutionOpOfType(LinalgOp op, AffineExpr w = m.dim(1); return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) - .expectMaps({/*inputMap=*/{m.strided(W, w, 0)}, - /*filterMap=*/{w}, - /*outputMap=*/{W}}) + .matchMaps({/*inputMap=*/{m.strided(W, w, 0)}, + /*filterMap=*/{w}, + /*outputMap=*/{W}}) .matchBody(); } @@ -560,9 +574,9 @@ bool isaConvolutionOpOfType( AffineExpr c = m.dim(4); return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) - .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c}, - /*filterMap=*/{w, c, F}, - /*outputMap=*/{N, W, F}}) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c}, + /*filterMap=*/{w, c, F}, + /*outputMap=*/{N, W, F}}) .matchBody(); } @@ -587,9 +601,9 @@ bool isaConvolutionOpOfType( AffineExpr w = m.dim(4); return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) - .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)}, - /*filterMap=*/{F, c, w}, - /*outputMap=*/{N, F, W}}) + .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)}, + /*filterMap=*/{F, c, w}, + /*outputMap=*/{N, F, W}}) .matchBody(); } @@ -614,9 +628,9 @@ bool isaConvolutionOpOfType(LinalgOp op, return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) - .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)}, - /*filterMap=*/{h, w}, - /*outputMap=*/{H, W}}) + .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{h, w}, + /*outputMap=*/{H, W}}) .matchBody(); } @@ -644,10 +658,10 @@ bool isaConvolutionOpOfType(LinalgOp op, return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2) - .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1), - m.strided(W, w, 2)}, - /*filterMap=*/{d, h, w}, - /*outputMap=*/{D, H, W}}) + .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2)}, + /*filterMap=*/{d, h, w}, + /*outputMap=*/{D, H, W}}) .matchBody(); } @@ -671,9 +685,9 @@ bool isaConvolutionOpOfType( AffineExpr w = m.dim(3); return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) - .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)}, - /*filterMap=*/{C, w}, - /*outputMap=*/{N, C, W}}) + .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)}, + /*filterMap=*/{C, w}, + /*outputMap=*/{N, C, W}}) .matchBody(); } @@ -697,9 +711,9 @@ bool isaConvolutionOpOfType( AffineExpr w = m.dim(3); return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) - .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, - /*filterMap=*/{w, C}, - /*outputMap=*/{N, W, C}}) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C}, + /*outputMap=*/{N, W, C}}) .matchBody(); } @@ -724,9 +738,9 @@ bool isaConvolutionOpOfType( AffineExpr w = m.dim(4); return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) - .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, - /*filterMap=*/{w, C, CM}, - /*outputMap=*/{N, W, C, CM}}) + .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C, CM}, + /*outputMap=*/{N, W, C, CM}}) .matchBody(); } @@ -753,9 +767,9 @@ bool isaConvolutionOpOfType( return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1) - .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)}, - /*filterMap=*/{C, h, w}, - /*outputMap=*/{N, C, H, W}}) + .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{C, h, w}, + /*outputMap=*/{N, C, H, W}}) .matchBody(); } @@ -789,10 +803,10 @@ bool isaConvolutionOpOfType( return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2) - .expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), - m.strided(W, w, 2), C}, - /*filterMap=*/{d, h, w, C, CM}, - /*outputMap=*/{N, D, H, W, C, CM}}) + .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2), C}, + /*filterMap=*/{d, h, w, C, CM}, + /*outputMap=*/{N, D, H, W, C, CM}}) .matchBody(); } @@ -810,7 +824,7 @@ bool isaConvolutionOpOfType( "expected op to implement ConvolutionOpInterface"); ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, - PoolingType::MAX_SIGNED); + PoolingType::MaxSigned); AffineExpr N = m.dim(0); AffineExpr H = m.dim(1); AffineExpr W = m.dim(2); @@ -820,9 +834,9 @@ bool isaConvolutionOpOfType( return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) - .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) .matchBody(); } @@ -840,7 +854,7 @@ bool isaConvolutionOpOfType( "expected op to implement ConvolutionOpInterface"); ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, - PoolingType::MIN_SIGNED); + PoolingType::MinSigned); AffineExpr N = m.dim(0); AffineExpr H = m.dim(1); AffineExpr W = m.dim(2); @@ -850,9 +864,9 @@ bool isaConvolutionOpOfType( return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) - .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) .matchBody(); } @@ -870,7 +884,7 @@ bool isaConvolutionOpOfType( "expected op to implement ConvolutionOpInterface"); ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, - PoolingType::SUM); + PoolingType::Sum); AffineExpr N = m.dim(0); AffineExpr H = m.dim(1); AffineExpr W = m.dim(2); @@ -880,9 +894,9 @@ bool isaConvolutionOpOfType( return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) - .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) .matchBody(); } @@ -900,7 +914,7 @@ bool isaConvolutionOpOfType( "expected op to implement ConvolutionOpInterface"); ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, - PoolingType::MAX_UNSIGNED); + PoolingType::MaxUnsigned); AffineExpr N = m.dim(0); AffineExpr H = m.dim(1); AffineExpr W = m.dim(2); @@ -910,9 +924,9 @@ bool isaConvolutionOpOfType( return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) - .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) .matchBody(); } @@ -930,7 +944,7 @@ bool isaConvolutionOpOfType( "expected op to implement ConvolutionOpInterface"); ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, - PoolingType::MIN_UNSIGNED); + PoolingType::MinUnsigned); AffineExpr N = m.dim(0); AffineExpr H = m.dim(1); AffineExpr W = m.dim(2); @@ -940,9 +954,9 @@ bool isaConvolutionOpOfType( return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) - .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}) + .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) .matchBody(); }