-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[NFC][Linalg] Introduce ConvMatchBuilder + refactor Conv matchers #169704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NFC][Linalg] Introduce ConvMatchBuilder + refactor Conv matchers #169704
Conversation
-- This commit is a follow-up and third in the series of adding matchers for conv/pool ops. Refer: llvm#163724 -- It introduces ConvMatchBuilder class in order to reduce the repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants. -- Refer to [Conv2D thread](llvm#168362 (comment)) for further context. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Abhishek Varma (Abhishek-Varma) Changes-- This commit is a follow-up and third in the series of adding Signed-off-by: Abhishek Varma <abhvarma@amd.com> Patch is 35.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169704.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6b85e6ba0ede2..f7b86e7c385de 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -416,10 +416,6 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
return false;
}
-// ---------------------------------------------
-// Matchers for specific convolution operation.
-// ---------------------------------------------
-
/// Returns true if the given indexing maps matches with the expected indexing
/// maps.
static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
@@ -434,6 +430,92 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
})));
}
+/// Enum of all kinds of Pooling Op's type.
+enum PoolingType {
+ NONE,
+ MAX_SIGNED,
+ MAX_UNSIGNED,
+ MIN_SIGNED,
+ MIN_UNSIGNED,
+ SUM
+};
+
+/// Helper class for building convolution op matchers with minimal boilerplate.
+/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
+/// as Pooling ops.
+class ConvMatcherBuilder {
+ LinalgOp op;
+ MLIRContext *ctx;
+ SmallVector<int64_t> *dilations, *strides;
+ ArrayAttr indexingMaps;
+ PoolingType poolingType;
+ bool matched = true;
+
+public:
+ ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
+ SmallVector<int64_t> *s,
+ PoolingType poolingType = PoolingType::NONE)
+ : op(op), ctx(op->getContext()), dilations(d), strides(s),
+ indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
+ *dilations = SmallVector<int64_t>(spatialRank, 1);
+ *strides = SmallVector<int64_t>(spatialRank, 1);
+ }
+
+ /// Get affine dimension expression for dimension i.
+ AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }
+
+ /// Build strided expression: base * stride[idx] + kernel * dilation[idx]
+ AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
+ return base * (*strides)[idx] + kernel * (*dilations)[idx];
+ }
+
+ /// Match stride/dilation pattern for a spatial dimension.
+ /// Returns *this for method chaining.
+ ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
+ unsigned idx) {
+ if (matched) {
+ matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
+ (*dilations)[idx], (*strides)[idx]);
+ }
+ return *this;
+ }
+
+ /// Match expected indexing maps layout.
+ /// Returns *this for method chaining.
+ ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
+ if (matched)
+ matched = convLayoutMatches(maps, indexingMaps, ctx);
+ return *this;
+ }
+
+ /// Match body pattern. This should be called last.
+ bool matchBody() {
+ if (!matched)
+ return false;
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ switch (poolingType) {
+ case PoolingType::NONE:
+ return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
+ case PoolingType::MAX_SIGNED:
+ return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MAX_UNSIGNED:
+ return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MIN_SIGNED:
+ return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::MIN_UNSIGNED:
+ return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
+ case PoolingType::SUM:
+ return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
+ }
+ return false;
+ }
+};
+
+// ---------------------------------------------
+// Matchers for specific convolution operation.
+// ---------------------------------------------
+
// #inputMap = affine_map<(W, w) -> (W + w)>
// #filterMap = affine_map<(W, w) -> (w)>
// #outputMap = affine_map<(W, w) -> (W)>
@@ -447,29 +529,15 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
- MLIRContext *context = op->getContext();
- AffineExpr W = getAffineDimExpr(0, context);
- AffineExpr w = getAffineDimExpr(1, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
- /*oDim=*/0, (*dilations)[0], (*strides)[0]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{W * (*strides)[0] + w * (*dilations)[0]},
- /*filterMap=*/{w},
- /*outputMap=*/{W}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr W = m.dim(0);
+ AffineExpr w = m.dim(1);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
+ /*filterMap=*/{w},
+ /*outputMap=*/{W}})
+ .matchBody();
}
// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
@@ -485,32 +553,18 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
- MLIRContext *context = op->getContext();
- AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr F = getAffineDimExpr(2, context);
- AffineExpr w = getAffineDimExpr(3, context);
- AffineExpr c = getAffineDimExpr(4, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
- /*oDim=*/1, (*dilations)[0], (*strides)[0]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], c},
- /*filterMap=*/{w, c, F},
- /*outputMap=*/{N, W, F}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr F = m.dim(2);
+ AffineExpr w = m.dim(3);
+ AffineExpr c = m.dim(4);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+ /*filterMap=*/{w, c, F},
+ /*outputMap=*/{N, W, F}})
+ .matchBody();
}
// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
@@ -526,32 +580,18 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
- MLIRContext *context = op->getContext();
- AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr F = getAffineDimExpr(1, context);
- AffineExpr W = getAffineDimExpr(2, context);
- AffineExpr c = getAffineDimExpr(3, context);
- AffineExpr w = getAffineDimExpr(4, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
- /*oDim=*/2, (*dilations)[0], (*strides)[0]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{N, c, W * (*strides)[0] + w * (*dilations)[0]},
- /*filterMap=*/{F, c, w},
- /*outputMap=*/{N, F, W}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr F = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr c = m.dim(3);
+ AffineExpr w = m.dim(4);
+
+ return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+ .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+ /*filterMap=*/{F, c, w},
+ /*outputMap=*/{N, F, W}})
+ .matchBody();
}
// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)>
@@ -567,36 +607,18 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(2, 1);
- *strides = SmallVector<int64_t>(2, 1);
- MLIRContext *context = op->getContext();
- AffineExpr H = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr h = getAffineDimExpr(2, context);
- AffineExpr w = getAffineDimExpr(3, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: H * stride + h * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
- /*oDim=*/0, (*dilations)[0], (*strides)[0]))
- return false;
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
- /*oDim=*/1, (*dilations)[1], (*strides)[1]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{H * (*strides)[0] + h * (*dilations)[0],
- W * (*strides)[1] + w * (*dilations)[1]},
- /*filterMap=*/{h, w},
- /*outputMap=*/{H, W}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ AffineExpr H = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr h = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+ .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{H, W}})
+ .matchBody();
}
// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
@@ -612,43 +634,22 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(3, 1);
- *strides = SmallVector<int64_t>(3, 1);
- MLIRContext *context = op->getContext();
- AffineExpr D = getAffineDimExpr(0, context);
- AffineExpr H = getAffineDimExpr(1, context);
- AffineExpr W = getAffineDimExpr(2, context);
- AffineExpr d = getAffineDimExpr(3, context);
- AffineExpr h = getAffineDimExpr(4, context);
- AffineExpr w = getAffineDimExpr(5, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: D * stride + d * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
- /*oDim=*/0, (*dilations)[0], (*strides)[0]))
- return false;
- // Match: H * stride + h * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
- /*oDim=*/1, (*dilations)[1], (*strides)[1]))
- return false;
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
- /*oDim=*/2, (*dilations)[2], (*strides)[2]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
- H * (*strides)[1] + h * (*dilations)[1],
- W * (*strides)[2] + w * (*dilations)[2]},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{D, H, W}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ AffineExpr D = m.dim(0);
+ AffineExpr H = m.dim(1);
+ AffineExpr W = m.dim(2);
+ AffineExpr d = m.dim(3);
+ AffineExpr h = m.dim(4);
+ AffineExpr w = m.dim(5);
+
+ return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+ .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
+ .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2)},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{D, H, W}})
+ .matchBody();
}
// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
@@ -664,31 +665,17 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
- MLIRContext *context = op->getContext();
- AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr C = getAffineDimExpr(2, context);
- AffineExpr w = getAffineDimExpr(3, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
- /*oDim=*/2, (*dilations)[0], (*strides)[0]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
- /*filterMap=*/{C, w},
- /*outputMap=*/{N, C, W}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+ .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+ /*filterMap=*/{C, w},
+ /*outputMap=*/{N, C, W}})
+ .matchBody();
}
// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
@@ -704,31 +691,17 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
- MLIRContext *context = op->getContext();
- AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr C = getAffineDimExpr(2, context);
- AffineExpr w = getAffineDimExpr(3, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
- /*oDim=*/1, (*dilations)[0], (*strides)[0]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
- /*filterMap=*/{w, C},
- /*outputMap=*/{N, W, C}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr w = m.dim(3);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C},
+ /*outputMap=*/{N, W, C}})
+ .matchBody();
}
// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
@@ -744,32 +717,18 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
- MLIRContext *context = op->getContext();
- AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr C = getAffineDimExpr(2, context);
- AffineExpr CM = getAffineDimExpr(3, context);
- AffineExpr w = getAffineDimExpr(4, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
- /*oDim=*/1, (*dilations)[0], (*strides)[0]))
- return false;
- // Match expected indexing maps
- if (!convLayoutMatches(
- {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
- /*filterMap=*/{w, C, CM},
- /*outputMap=*/{N, W, C, CM}},
- indexingMaps, context))
- return false;
- // Match body
- Block *body = op.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- return bodyMatcherForConvolutionOps(yieldVal, body);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ AffineExpr N = m.dim(0);
+ AffineExpr W = m.dim(1);
+ AffineExpr C = m.dim(2);
+ AffineExpr CM = m.dim(3);
+ AffineExpr w = m.dim(4);
+
+ return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C, CM},
+ /*outputMap=*/{N, W, C, CM}})
+ .matchBody();
}
// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
@@ -785,38 +744,20 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(2, 1);
- *strides = SmallVector<int64_t>(2, 1);
- MLIRContext *context = op->getContext();
- AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr H = getAffineDimExpr(1, context);
- AffineExpr W = getAffineDimExpr(2, context);
- AffineExpr C = getAffineDimExpr(3, context);
- AffineExpr h = getAffineDimExpr(4, context);
- AffineExpr w = getAffineDimExpr(5, context);
- ArrayAttr indexingMaps = op.getIndexingMaps();
- // First fetch dilations/strides :-
- // Match: H * stride + h * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
- /*oDim=*/2, (*dilations)[0], (*strides)[...
[truncated]
|
hanhanW
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice cleanup, thanks a ton! Let's also mark the PR with nfc. Just few nits.
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/198/builds/10076 Here is the relevant piece of the build log for the reference |
-- This commit is a follow-up and third in the series of adding
matchers for conv/pool ops. Refer: #163724
-- It introduces ConvMatchBuilder class in order to reduce the
repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants.
-- Refer to Conv2D thread for further context.
Signed-off-by: Abhishek Varma abhvarma@amd.com