-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[Linalg] Add *Conv2D* matchers #168362
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[Linalg] Add *Conv2D* matchers #168362
Conversation
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Abhishek Varma (Abhishek-Varma) Changes-- This commit is the third in the series of adding matchers Signed-off-by: Abhishek Varma <abhvarma@amd.com> Patch is 62.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168362.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c2485a08932dd..b52b93f8cc9b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -279,6 +279,17 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
CONV_OP_SPECIALIZER(linalg::Conv2DOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp);
+ CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp);
CONV_OP_SPECIALIZER(linalg::Conv3DOp);
// -----------------------------
// Depthwise Convolution ops.
@@ -287,6 +298,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp);
+ CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
// -----------------------------
// Pooling ops.
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6b85e6ba0ede2..57593abac7ab0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,8 +240,8 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
//===----------------------------------------------------------------------===//
/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
-/// ext* ops.
-static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+/// ext*/sitofp ops.
+static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) {
BlockArgument blockArg = dyn_cast<BlockArgument>(val);
if ((blockArg))
return blockArg;
@@ -249,18 +249,62 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
Operation *defOp = val.getDefiningOp();
if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
!dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
- !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+ !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
+ !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
return nullptr;
}
return dyn_cast<BlockArgument>(defOp->getOperand(0));
}
+/// Utility function to match the zero point offset body of convolution ops.
+/// It takes input the addition op and multiplication op expected in every
+/// convolution op and matches the following for both operands of multiplication
+/// op :-
+/// %a - %b
+/// where, %a and %b can have optional upcast operation.
+static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
+ Block *body) {
+ Operation *subOp1 = mulOp->getOperand(0).getDefiningOp();
+ if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp1))
+ return false;
+ Operation *subOp2 = mulOp->getOperand(1).getDefiningOp();
+ if (!isa_and_present<arith::SubIOp, arith::SubFOp>(subOp2))
+ return false;
+ BlockArgument inputBlockArg =
+ getBlockArgumentWithOptionalCastOps(subOp1->getOperand(0));
+ BlockArgument inputScalarBlockArg =
+ getBlockArgumentWithOptionalCastOps(subOp1->getOperand(1));
+ BlockArgument filterBlockArg =
+ getBlockArgumentWithOptionalCastOps(subOp2->getOperand(0));
+ BlockArgument filterScalarBlockArg =
+ getBlockArgumentWithOptionalCastOps(subOp2->getOperand(1));
+ BlockArgument outBlockArg =
+ getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+ if (!inputBlockArg || !inputScalarBlockArg || !filterBlockArg ||
+ !filterScalarBlockArg || !outBlockArg ||
+ inputBlockArg.getOwner() != body ||
+ inputScalarBlockArg.getOwner() != body ||
+ filterBlockArg.getOwner() != body ||
+ filterScalarBlockArg.getOwner() != body ||
+ outBlockArg.getOwner() != body || inputBlockArg.getArgNumber() != 0 ||
+ inputScalarBlockArg.getArgNumber() != 2 ||
+ filterBlockArg.getArgNumber() != 1 ||
+ filterScalarBlockArg.getArgNumber() != 3 ||
+ outBlockArg.getArgNumber() != 4)
+ return false;
+ return true;
+}
+
/// Utility to match block body for convolution ops.
/// The body is thus expected to yield :-
/// %out + (%lhs * %rhs)
/// where: %lhs, %rhs and %out are block arguments and
/// %lhs and %rhs can have optional upcast operation.
-static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :-
+/// %input - %input_scalar
+/// where, %input_scalar can have optional upcast operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
+ bool zeroPointOffset = false) {
Operation *addOp = yieldVal.getDefiningOp();
if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
return false;
@@ -269,12 +313,15 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
return false;
+ if (zeroPointOffset) {
+ return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+ }
BlockArgument lhsBlockArg =
- getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
BlockArgument rhsBlockArg =
- getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+ getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
BlockArgument outBlockArg =
- getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -291,9 +338,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
return false;
BlockArgument lhsArg =
- getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(defOp->getOperand(0));
BlockArgument rhsArg =
- getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+ getBlockArgumentWithOptionalCastOps(defOp->getOperand(1));
if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
rhsArg.getArgNumber() != 0)
@@ -599,49 +646,45 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
-// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
-// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
+// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
template <>
-bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv3DOp>(op))
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwcFhwcOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(3, 1);
- *strides = SmallVector<int64_t>(3, 1);
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
MLIRContext *context = op->getContext();
- AffineExpr D = getAffineDimExpr(0, context);
+ AffineExpr N = getAffineDimExpr(0, context);
AffineExpr H = getAffineDimExpr(1, context);
AffineExpr W = getAffineDimExpr(2, context);
- AffineExpr d = getAffineDimExpr(3, context);
+ AffineExpr F = getAffineDimExpr(3, context);
AffineExpr h = getAffineDimExpr(4, context);
AffineExpr w = getAffineDimExpr(5, context);
+ AffineExpr c = getAffineDimExpr(6, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
- // Match: D * stride + d * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
- /*oDim=*/0, (*dilations)[0], (*strides)[0]))
- return false;
// Match: H * stride + h * dilation
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
- /*oDim=*/1, (*dilations)[1], (*strides)[1]))
+ /*oDim=*/1, (*dilations)[0], (*strides)[0]))
return false;
// Match: W * stride + w * dilation
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
- /*oDim=*/2, (*dilations)[2], (*strides)[2]))
+ /*oDim=*/2, (*dilations)[1], (*strides)[1]))
return false;
// Match expected indexing maps
if (!convLayoutMatches(
- {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
- H * (*strides)[1] + h * (*dilations)[1],
- W * (*strides)[2] + w * (*dilations)[2]},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{D, H, W}},
+ {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1], c},
+ /*filterMap=*/{F, h, w, c},
+ /*outputMap=*/{N, H, W, F}},
indexingMaps, context))
return false;
// Match body
@@ -651,37 +694,45 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
-// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
+// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+ if (isa<linalg::Conv2DNhwcHwcfOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
MLIRContext *context = op->getContext();
AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr C = getAffineDimExpr(2, context);
- AffineExpr w = getAffineDimExpr(3, context);
+ AffineExpr H = getAffineDimExpr(1, context);
+ AffineExpr W = getAffineDimExpr(2, context);
+ AffineExpr F = getAffineDimExpr(3, context);
+ AffineExpr h = getAffineDimExpr(4, context);
+ AffineExpr w = getAffineDimExpr(5, context);
+ AffineExpr c = getAffineDimExpr(6, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+ /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+ return false;
// Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
- /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[1], (*strides)[1]))
return false;
// Match expected indexing maps
if (!convLayoutMatches(
- {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
- /*filterMap=*/{C, w},
- /*outputMap=*/{N, C, W}},
+ {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1], c},
+ /*filterMap=*/{h, w, c, F},
+ /*outputMap=*/{N, H, W, F}},
indexingMaps, context))
return false;
// Match body
@@ -691,37 +742,196 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
return bodyMatcherForConvolutionOps(yieldVal, body);
}
-// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
-// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
-// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+// #inputMap = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+ if (isa<linalg::Conv2DNchwFchwOp>(op))
return true;
assert(isaConvolutionOpInterface(op) &&
"expected op to implement ConvolutionOpInterface");
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
MLIRContext *context = op->getContext();
AffineExpr N = getAffineDimExpr(0, context);
- AffineExpr W = getAffineDimExpr(1, context);
- AffineExpr C = getAffineDimExpr(2, context);
- AffineExpr w = getAffineDimExpr(3, context);
+ AffineExpr F = getAffineDimExpr(1, context);
+ AffineExpr H = getAffineDimExpr(2, context);
+ AffineExpr W = getAffineDimExpr(3, context);
+ AffineExpr C = getAffineDimExpr(4, context);
+ AffineExpr h = getAffineDimExpr(5, context);
+ AffineExpr w = getAffineDimExpr(6, context);
ArrayAttr indexingMaps = op.getIndexingMaps();
// First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+ return false;
// Match: W * stride + w * dilation
- if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+ return false;
+ // Match expected indexing maps
+ if (!convLayoutMatches(
+ {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1]},
+ /*filterMap=*/{F, C, h, w},
+ /*outputMap=*/{N, F, H, W}},
+ indexingMaps, context))
+ return false;
+ // Match body
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ return bodyMatcherForConvolutionOps(yieldVal, body);
+}
+
+// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)>
+// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)>
+// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()>
+// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
+ MLIRContext *context = op->getContext();
+ AffineExpr N = getAffineDimExpr(0, context);
+ AffineExpr H = getAffineDimExpr(1, context);
+ AffineExpr W = getAffineDimExpr(2, context);
+ AffineExpr F = getAffineDimExpr(3, context);
+ AffineExpr h = getAffineDimExpr(4, context);
+ AffineExpr w = getAffineDimExpr(5, context);
+ AffineExpr c = getAffineDimExpr(6, context);
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ // First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
return false;
+ // Match: W * stride + w * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+ return false;
// Match expected indexing maps
if (!convLayoutMatches(
- {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
- /*filterMap=*/{w, C},
- /*outputMap=*/{N, W, C}},
+ {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1], c},
+ /*filterMap=*/{F, h, w, c},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, H, W, F}},
+ indexingMaps, context))
+ return false;
+ // Match body
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ return bodyMatcherForConvolutionOps(yieldVal, body, /*zeroPointOffset=*/true);
+}
+
+// #inputMap = affine_map<(N, F, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, F, H, W, C, h, w) -> (F, C, h, w)>
+// #scalarMap = affine_map<(N, F, H, W, C, h, w) -> ()>
+// #outputMap = affine_map<(N, F, H, W, C, h, w) -> (N, F, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNchwFchwQOp>(op))
+ return true;
+
+ assert(isaConvolutionOpInterface(op) &&
+ "expected op to implement ConvolutionOpInterface");
+
+ *dilations = SmallVector<int64_t>(2, 1);
+ *strides = SmallVector<int64_t>(2, 1);
+ MLIRContext *context = op->getContext();
+ AffineExpr N = getAffineDimExpr(0, context);
+ AffineExpr F = getAffineDimExpr(1, context);
+ AffineExpr H = getAffineDimExpr(2, context);
+ AffineExpr W = getAffineDimExpr(3, context);
+ AffineExpr C = getAffineDimExpr(4, context);
+ AffineExpr h = getAffineDimExpr(5, context);
+ AffineExpr w = getAffineDimExpr(6, context);
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ // First fetch dilations/strides :-
+ // Match: H * stride + h * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+ return false;
+ // Match: W * stride + w * dilation
+ if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+ return false;
+ // Match expected indexing maps
+ if (!convLayoutMatches(
+ {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+ W * (*strides)[1] + w * (*dilations)[1]},
+ /*filterMap=*/{F, C, h, w},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, F, H, W}},
+ indexingMaps, context))
+ return false;
+ // Match bo...
[truncated]
|
hanhanW
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't reviewed detailed implementation yet, just pointing out some missing tests. Github/git is not doing good at diff this time, so it'll take me some time to review it.
(also, I'm still thinking if we can refactor something or not, so we don't need many changes.)
|
@Abhishek-Varma I've been thinking how to make the implementation simpler. Below is the plan that I brainstormed with AI agent, and I think we can do the refactoring first. Then this PR will be much easier to review. The main benefit is that the code itself documents the expected indexing maps, and it simplifies code a lot to me. Can you take a look and let me know what you think? Problem: Each template specialization follows the exact same pattern (~50-60 lines each), leading to ~800+ lines of highly repetitive code across the 15 new matchers. This pattern is consistent across ALL conv types (1D, 2D, 3D, regular, depthwise, quantized). The repeated boilerplate in each matcher: template <>
bool isaConvolutionOpOfType<linalg::SomeConvOp>(...) {
if (isa<linalg::SomeConvOp>(op)) return true; // 1. type check
assert(isaConvolutionOpInterface(op) && ...); // 2. assert
*dilations = SmallVector<int64_t>(N, 1); // 3. init
*strides = SmallVector<int64_t>(N, 1);
MLIRContext *context = op->getContext();
AffineExpr A = getAffineDimExpr(0, context); // 4. dims (5-8 lines)
AffineExpr B = getAffineDimExpr(1, context);
// ... more dims ...
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!matchConvDimAddExprPattern(...)) // 5. stride matching
return false;
// ... more stride matching ...
if (!convLayoutMatches({...}, indexingMaps, context)) // 6. map matching
return false;
Block *body = op.getBlock(); // 7. body match
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
return bodyMatcherForConvolutionOps(yieldVal, body);
}Suggested Solution: Introduce a /// Helper class for building convolution op matchers with minimal boilerplate.
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants.
class ConvMatcherBuilder {
LinalgOp op;
MLIRContext *ctx;
SmallVector<int64_t> *dilations, *strides;
ArrayAttr indexingMaps;
bool matched = true;
public:
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank,
SmallVector<int64_t> *d, SmallVector<int64_t> *s)
: op(op), ctx(op->getContext()), dilations(d), strides(s),
indexingMaps(op.getIndexingMaps()) {
*dilations = SmallVector<int64_t>(spatialRank, 1);
*strides = SmallVector<int64_t>(spatialRank, 1);
}
/// Get affine dimension expression for dimension i.
AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }
/// Build strided expression: base * stride[idx] + kernel * dilation[idx]
AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
return base * (*strides)[idx] + kernel * (*dilations)[idx];
}
/// Match stride/dilation pattern for a spatial dimension.
/// Returns *this for method chaining.
ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim,
unsigned oDim, unsigned idx) {
if (matched) {
matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
(*dilations)[idx], (*strides)[idx]);
}
return *this;
}
/// Match expected indexing maps layout.
/// Returns *this for method chaining.
ConvMatcherBuilder &expectMaps(
std::initializer_list<SmallVector<AffineExpr>> maps) {
if (matched)
matched = convLayoutMatches(maps, indexingMaps, ctx);
return *this;
}
/// Match body pattern. This should be called last.
bool matchBody(bool zeroPointOffset = false) {
if (!matched)
return false;
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
zeroPointOffset);
}
};Example usage - Conv3DOp (before: ~55 lines, after: ~12 lines): template <>
bool isaConvolutionOpOfType<linalg::Conv3DOp>(
LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
if (isa<linalg::Conv3DOp>(op)) return true;
assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
auto D = m.dim(0), H = m.dim(1), W = m.dim(2);
auto d = m.dim(3), h = m.dim(4), w = m.dim(5);
return m.matchStride(0, 0, 0, 0)
.matchStride(1, 1, 1, 1)
.matchStride(2, 2, 2, 2)
.expectMaps({/*in=*/{m.strided(D,d,0), m.strided(H,h,1), m.strided(W,w,2)},
/*filter=*/{d, h, w},
/*out=*/{D, H, W}})
.matchBody();
}Example - DepthwiseConv2DNhwcHwcQOp (quantized, before: ~55 lines, after: ~14 lines): template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
auto N = m.dim(0), H = m.dim(1), W = m.dim(2), c = m.dim(3);
auto h = m.dim(4), w = m.dim(5);
return m.matchStride(1, 0, 1, 0)
.matchStride(2, 1, 2, 1)
.expectMaps({/*in=*/{N, m.strided(H,h,0), m.strided(W,w,1), c},
/*filter=*/{h, w, c},
/*inputScalar=*/{}, /*filterScalar=*/{},
/*out=*/{N, H, W, c}})
.matchBody(/*zeroPointOffset=*/true);
}Example - Conv2DNgchwGfchwOp (grouped conv, before: ~55 lines, after: ~14 lines): template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations, SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
assert(isaConvolutionOpInterface(op) && "expected ConvolutionOpInterface");
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
auto N = m.dim(0), G = m.dim(1), FG = m.dim(2);
auto H = m.dim(3), W = m.dim(4), C = m.dim(5);
auto h = m.dim(6), w = m.dim(7);
return m.matchStride(3, 3, 3, 0)
.matchStride(4, 4, 4, 1)
.expectMaps({/*in=*/{N, G, C, m.strided(H,h,0), m.strided(W,w,1)},
/*filter=*/{G, FG, C, h, w},
/*out=*/{N, G, FG, H, W}})
.matchBody();
}Benefits:
|
|
I checked most matching except |
|
Hi @hanhanW ! That was such a nice suggestion. I've addressed your comments! I was trying to also see if there's a way for us to simply just initialize with "(N, H, W, C, h, w)" string using the Can you please take a look now ? |
I'll be out for a week. QQ: can we split it to two PRs? One for introducing |
-- This commit is a follow-up and third in the series of adding matchers for conv/pool ops. Refer: llvm#163724 -- It introduces ConvMatchBuilder class in order to reduce the repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants. -- Refer to [Conv2D thread](llvm#168362 (comment)) for further context. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
|
I've raised the Let me know if I should close this PR for now and reopen after the |
…69704) -- This commit is a follow-up and third in the series of adding matchers for conv/pool ops. Refer: #163724 -- It introduces ConvMatchBuilder class in order to reduce the repetitive code across Conv1D/2D/3D/Depthwise/Pooling variants. -- Refer to [Conv2D thread](#168362 (comment)) for further context. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
-- This commit is the fourth in the series of adding matchers for linalg.conv/pool. Refer: llvm#163724 -- In this commit all variants of Conv2D convolution ops have been added. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
3c90eee to
968cb69
Compare
|
I've rebased the PR to use |
-- This commit is the fourth in the series of adding matchers
for linalg.conv/pool. Refer: #163724
-- In this commit all variants of Conv2D convolution ops have been
added.
-- It also refactors the way these matchers work to make adding more matchers concise.
Signed-off-by: Abhishek Varma abhvarma@amd.com