-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[NFC][Linalg] Follow-up on ConvMatchBuilder #170080
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NFC][Linalg] Follow-up on ConvMatchBuilder #170080
Conversation
-- This commit addresses [follow-up review comments on 169704](llvm#169704 (review)). Signed-off-by: Abhishek Varma <abhvarma@amd.com>
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Abhishek Varma (Abhishek-Varma) Changes-- This commit addresses follow-up review comments on 169704. Signed-off-by: Abhishek Varma <abhvarma@amd.com> Full diff: https://github.com/llvm/llvm-project/pull/170080.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e85a2ab26bd32..01e6e1e248658 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -430,19 +430,33 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
})));
}
-/// Enum of all kinds of Pooling Op's type.
-enum PoolingType {
- NONE,
- MAX_SIGNED,
- MAX_UNSIGNED,
- MIN_SIGNED,
- MIN_UNSIGNED,
- SUM
+/// Enum representing pooling operation types used by ConvMatcherBuilder.
+enum class PoolingType {
+ None,
+ MaxSigned,
+ MaxUnsigned,
+ MinSigned,
+ MinUnsigned,
+ Sum
};
/// Helper class for building convolution op matchers with minimal boilerplate.
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
/// as Pooling ops.
+///
+/// Usage: Create an instance with the op, spatial rank, and output pointers for
+/// extracted dilations/strides. Then chain matchStride() calls for each spatial
+/// dimension, followed by matchMaps() to verify indexing maps, and finally
+/// matchBody() to verify the operation body pattern.
+///
+/// The `matched` flag starts as `true` and is set to `false` if any match step
+/// fails. This allows chaining multiple match calls; once any match fails, all
+/// subsequent calls become no-ops and the final result is `false`.
+///
+/// The `dilations` and `strides` pointers are output parameters that get
+/// populated with the extracted dilation and stride values from the operation's
+/// indexing maps during matchStride() calls. These values are initially set to
+/// 1 for each spatial dimension and updated as patterns are matched.
class ConvMatcherBuilder {
LinalgOp op;
MLIRContext *ctx;
@@ -454,7 +468,7 @@ class ConvMatcherBuilder {
public:
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
SmallVector<int64_t> *s,
- PoolingType poolingType = PoolingType::NONE)
+ PoolingType poolingType = PoolingType::None)
: op(op), ctx(op->getContext()), dilations(d), strides(s),
indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
*dilations = SmallVector<int64_t>(spatialRank, 1);
@@ -474,16 +488,16 @@ class ConvMatcherBuilder {
ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
unsigned idx) {
if (matched) {
- matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
- (*dilations)[idx], (*strides)[idx]);
+ matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+ (*dilations)[idx], (*strides)[idx]);
}
return *this;
}
/// Match expected indexing maps layout. Returns *this for method chaining.
- ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+ ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
if (matched)
- matched = convLayoutMatches(maps, indexingMaps, ctx);
+ matched &= convLayoutMatches(maps, indexingMaps, ctx);
return *this;
}
@@ -494,17 +508,17 @@ class ConvMatcherBuilder {
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
switch (poolingType) {
- case PoolingType::NONE:
+ case PoolingType::None:
return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
- case PoolingType::MAX_SIGNED:
+ case PoolingType::MaxSigned:
return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
- case PoolingType::MAX_UNSIGNED:
+ case PoolingType::MaxUnsigned:
return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
- case PoolingType::MIN_SIGNED:
+ case PoolingType::MinSigned:
return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
- case PoolingType::MIN_UNSIGNED:
+ case PoolingType::MinUnsigned:
return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
- case PoolingType::SUM:
+ case PoolingType::Sum:
return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
}
return false;
@@ -533,9 +547,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
AffineExpr w = m.dim(1);
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
- .expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
- /*filterMap=*/{w},
- /*outputMap=*/{W}})
+ .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
+ /*filterMap=*/{w},
+ /*outputMap=*/{W}})
.matchBody();
}
@@ -560,9 +574,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
AffineExpr c = m.dim(4);
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
- /*filterMap=*/{w, c, F},
- /*outputMap=*/{N, W, F}})
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+ /*filterMap=*/{w, c, F},
+ /*outputMap=*/{N, W, F}})
.matchBody();
}
@@ -587,9 +601,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
AffineExpr w = m.dim(4);
return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
- /*filterMap=*/{F, c, w},
- /*outputMap=*/{N, F, W}})
+ .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+ /*filterMap=*/{F, c, w},
+ /*outputMap=*/{N, F, W}})
.matchBody();
}
@@ -614,9 +628,9 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
- .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{h, w},
- /*outputMap=*/{H, W}})
+ .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{H, W}})
.matchBody();
}
@@ -644,10 +658,10 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
- .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2)},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{D, H, W}})
+ .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2)},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{D, H, W}})
.matchBody();
}
@@ -671,9 +685,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
AffineExpr w = m.dim(3);
return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
- /*filterMap=*/{C, w},
- /*outputMap=*/{N, C, W}})
+ .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+ /*filterMap=*/{C, w},
+ /*outputMap=*/{N, C, W}})
.matchBody();
}
@@ -697,9 +711,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
AffineExpr w = m.dim(3);
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w, C},
- /*outputMap=*/{N, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C},
+ /*outputMap=*/{N, W, C}})
.matchBody();
}
@@ -724,9 +738,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
AffineExpr w = m.dim(4);
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w, C, CM},
- /*outputMap=*/{N, W, C, CM}})
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C, CM},
+ /*outputMap=*/{N, W, C, CM}})
.matchBody();
}
@@ -753,9 +767,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
.matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{C, h, w},
- /*outputMap=*/{N, C, H, W}})
+ .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{C, h, w},
+ /*outputMap=*/{N, C, H, W}})
.matchBody();
}
@@ -789,10 +803,10 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
.matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), C},
- /*filterMap=*/{d, h, w, C, CM},
- /*outputMap=*/{N, D, H, W, C, CM}})
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w, C, CM},
+ /*outputMap=*/{N, D, H, W, C, CM}})
.matchBody();
}
@@ -810,7 +824,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MAX_SIGNED);
+ PoolingType::MaxSigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -820,9 +834,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
@@ -840,7 +854,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MIN_SIGNED);
+ PoolingType::MinSigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -850,9 +864,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
@@ -870,7 +884,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::SUM);
+ PoolingType::Sum);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -880,9 +894,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
@@ -900,7 +914,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MAX_UNSIGNED);
+ PoolingType::MaxUnsigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -910,9 +924,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
@@ -930,7 +944,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
"expected op to implement ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MIN_UNSIGNED);
+ PoolingType::MinUnsigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -940,9 +954,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
+ .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
.matchBody();
}
|
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, many thanks for the prompt reply 🙏🏻
-- This commit addresses [follow-up review comments on 169704](llvm#169704 (review)). -- Contains NFC nit/minor changes. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
-- This commit addresses follow-up review comments on 169704.
-- Contains NFC nit/minor changes.
Signed-off-by: Abhishek Varma abhvarma@amd.com