Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 85 additions & 71 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,19 +430,33 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
})));
}

/// Enum of all kinds of Pooling Op's type.
enum PoolingType {
NONE,
MAX_SIGNED,
MAX_UNSIGNED,
MIN_SIGNED,
MIN_UNSIGNED,
SUM
/// Enum representing pooling operation types used by ConvMatcherBuilder.
enum class PoolingType {
None,
MaxSigned,
MaxUnsigned,
MinSigned,
MinUnsigned,
Sum
};

/// Helper class for building convolution op matchers with minimal boilerplate.
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
/// as Pooling ops.
///
/// Usage: Create an instance with the op, spatial rank, and output pointers for
/// extracted dilations/strides. Then chain matchStride() calls for each spatial
/// dimension, followed by matchMaps() to verify indexing maps, and finally
/// matchBody() to verify the operation body pattern.
///
/// The `matched` flag starts as `true` and is set to `false` if any match step
/// fails. This allows chaining multiple match calls; once any match fails, all
/// subsequent calls become no-ops and the final result is `false`.
///
/// The `dilations` and `strides` pointers are output parameters that get
/// populated with the extracted dilation and stride values from the operation's
/// indexing maps during matchStride() calls. These values are initially set to
/// 1 for each spatial dimension and updated as patterns are matched.
class ConvMatcherBuilder {
LinalgOp op;
MLIRContext *ctx;
Expand All @@ -454,7 +468,7 @@ class ConvMatcherBuilder {
public:
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
SmallVector<int64_t> *s,
PoolingType poolingType = PoolingType::NONE)
PoolingType poolingType = PoolingType::None)
: op(op), ctx(op->getContext()), dilations(d), strides(s),
indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
*dilations = SmallVector<int64_t>(spatialRank, 1);
Expand All @@ -474,16 +488,16 @@ class ConvMatcherBuilder {
ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
unsigned idx) {
if (matched) {
matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
(*dilations)[idx], (*strides)[idx]);
matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
(*dilations)[idx], (*strides)[idx]);
}
return *this;
}

/// Match expected indexing maps layout. Returns *this for method chaining.
ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
if (matched)
matched = convLayoutMatches(maps, indexingMaps, ctx);
matched &= convLayoutMatches(maps, indexingMaps, ctx);
return *this;
}

Expand All @@ -494,17 +508,17 @@ class ConvMatcherBuilder {
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
switch (poolingType) {
case PoolingType::NONE:
case PoolingType::None:
return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
case PoolingType::MAX_SIGNED:
case PoolingType::MaxSigned:
return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
case PoolingType::MAX_UNSIGNED:
case PoolingType::MaxUnsigned:
return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
case PoolingType::MIN_SIGNED:
case PoolingType::MinSigned:
return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
case PoolingType::MIN_UNSIGNED:
case PoolingType::MinUnsigned:
return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
case PoolingType::SUM:
case PoolingType::Sum:
return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
}
return false;
Expand Down Expand Up @@ -533,9 +547,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
AffineExpr w = m.dim(1);

return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
.expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
/*filterMap=*/{w},
/*outputMap=*/{W}})
.matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
/*filterMap=*/{w},
/*outputMap=*/{W}})
.matchBody();
}

Expand All @@ -560,9 +574,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
AffineExpr c = m.dim(4);

return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
/*filterMap=*/{w, c, F},
/*outputMap=*/{N, W, F}})
.matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
/*filterMap=*/{w, c, F},
/*outputMap=*/{N, W, F}})
.matchBody();
}

Expand All @@ -587,9 +601,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
AffineExpr w = m.dim(4);

return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
.expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
/*filterMap=*/{F, c, w},
/*outputMap=*/{N, F, W}})
.matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
/*filterMap=*/{F, c, w},
/*outputMap=*/{N, F, W}})
.matchBody();
}

Expand All @@ -614,9 +628,9 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,

return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
.expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
/*filterMap=*/{h, w},
/*outputMap=*/{H, W}})
.matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
/*filterMap=*/{h, w},
/*outputMap=*/{H, W}})
.matchBody();
}

Expand Down Expand Up @@ -644,10 +658,10 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
.expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
m.strided(W, w, 2)},
/*filterMap=*/{d, h, w},
/*outputMap=*/{D, H, W}})
.matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
m.strided(W, w, 2)},
/*filterMap=*/{d, h, w},
/*outputMap=*/{D, H, W}})
.matchBody();
}

Expand All @@ -671,9 +685,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
AffineExpr w = m.dim(3);

return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
.expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
/*filterMap=*/{C, w},
/*outputMap=*/{N, C, W}})
.matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
/*filterMap=*/{C, w},
/*outputMap=*/{N, C, W}})
.matchBody();
}

Expand All @@ -697,9 +711,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
AffineExpr w = m.dim(3);

return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
/*filterMap=*/{w, C},
/*outputMap=*/{N, W, C}})
.matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
/*filterMap=*/{w, C},
/*outputMap=*/{N, W, C}})
.matchBody();
}

Expand All @@ -724,9 +738,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
AffineExpr w = m.dim(4);

return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
/*filterMap=*/{w, C, CM},
/*outputMap=*/{N, W, C, CM}})
.matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
/*filterMap=*/{w, C, CM},
/*outputMap=*/{N, W, C, CM}})
.matchBody();
}

Expand All @@ -753,9 +767,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(

return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
.matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
.expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
/*filterMap=*/{C, h, w},
/*outputMap=*/{N, C, H, W}})
.matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
/*filterMap=*/{C, h, w},
/*outputMap=*/{N, C, H, W}})
.matchBody();
}

Expand Down Expand Up @@ -789,10 +803,10 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
.matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
.expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
m.strided(W, w, 2), C},
/*filterMap=*/{d, h, w, C, CM},
/*outputMap=*/{N, D, H, W, C, CM}})
.matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
m.strided(W, w, 2), C},
/*filterMap=*/{d, h, w, C, CM},
/*outputMap=*/{N, D, H, W, C, CM}})
.matchBody();
}

Expand All @@ -810,7 +824,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
"expected op to implement ConvolutionOpInterface");

ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MAX_SIGNED);
PoolingType::MaxSigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
Expand All @@ -820,9 +834,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(

return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchBody();
}

Expand All @@ -840,7 +854,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
"expected op to implement ConvolutionOpInterface");

ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MIN_SIGNED);
PoolingType::MinSigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
Expand All @@ -850,9 +864,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(

return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchBody();
}

Expand All @@ -870,7 +884,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
"expected op to implement ConvolutionOpInterface");

ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::SUM);
PoolingType::Sum);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
Expand All @@ -880,9 +894,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(

return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchBody();
}

Expand All @@ -900,7 +914,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
"expected op to implement ConvolutionOpInterface");

ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MAX_UNSIGNED);
PoolingType::MaxUnsigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
Expand All @@ -910,9 +924,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(

return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchBody();
}

Expand All @@ -930,7 +944,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
"expected op to implement ConvolutionOpInterface");

ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MIN_UNSIGNED);
PoolingType::MinUnsigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
Expand All @@ -940,9 +954,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(

return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
/*filterMap=*/{h, w},
/*outputMap=*/{N, H, W, C}})
.matchBody();
}

Expand Down