diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 6b85e6ba0ede2..e85a2ab26bd32 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -416,10 +416,6 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, return false; } -// --------------------------------------------- -// Matchers for specific convolution operation. -// --------------------------------------------- - /// Returns true if the given indexing maps matches with the expected indexing /// maps. static bool convLayoutMatches(ArrayRef> mapListExpected, @@ -434,6 +430,91 @@ static bool convLayoutMatches(ArrayRef> mapListExpected, }))); } +/// Enum of all kinds of Pooling Op's type. +enum PoolingType { + NONE, + MAX_SIGNED, + MAX_UNSIGNED, + MIN_SIGNED, + MIN_UNSIGNED, + SUM +}; + +/// Helper class for building convolution op matchers with minimal boilerplate. +/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well +/// as Pooling ops. +class ConvMatcherBuilder { + LinalgOp op; + MLIRContext *ctx; + SmallVector *dilations, *strides; + ArrayAttr indexingMaps; + PoolingType poolingType; + bool matched = true; + +public: + ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector *d, + SmallVector *s, + PoolingType poolingType = PoolingType::NONE) + : op(op), ctx(op->getContext()), dilations(d), strides(s), + indexingMaps(op.getIndexingMaps()), poolingType(poolingType) { + *dilations = SmallVector(spatialRank, 1); + *strides = SmallVector(spatialRank, 1); + } + + /// Get affine dimension expression for dimension `i`. + AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); } + + /// Build strided expression: base * stride[idx] + kernel * dilation[idx]. + AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) { + return base * (*strides)[idx] + kernel * (*dilations)[idx]; + } + + /// Match stride/dilation pattern for a spatial dimension. + /// Returns *this for method chaining. + ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim, + unsigned idx) { + if (matched) { + matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim, + (*dilations)[idx], (*strides)[idx]); + } + return *this; + } + + /// Match expected indexing maps layout. Returns *this for method chaining. + ConvMatcherBuilder &expectMaps(ArrayRef> maps) { + if (matched) + matched = convLayoutMatches(maps, indexingMaps, ctx); + return *this; + } + + /// Match body pattern. This should be called last. + bool matchBody() { + if (!matched) + return false; + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + switch (poolingType) { + case PoolingType::NONE: + return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body); + case PoolingType::MAX_SIGNED: + return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MAX_UNSIGNED: + return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MIN_SIGNED: + return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::MIN_UNSIGNED: + return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body); + case PoolingType::SUM: + return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body); + } + return false; + } +}; + +//===----------------------------------------------------------------------===// +// Matchers for specific convolution operation. +//===----------------------------------------------------------------------===// + // #inputMap = affine_map<(W, w) -> (W + w)> // #filterMap = affine_map<(W, w) -> (w)> // #outputMap = affine_map<(W, w) -> (W)> @@ -447,29 +528,15 @@ bool isaConvolutionOpOfType(LinalgOp op, assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(1, 1); - *strides = SmallVector(1, 1); - MLIRContext *context = op->getContext(); - AffineExpr W = getAffineDimExpr(0, context); - AffineExpr w = getAffineDimExpr(1, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, - /*oDim=*/0, (*dilations)[0], (*strides)[0])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{W * (*strides)[0] + w * (*dilations)[0]}, - /*filterMap=*/{w}, - /*outputMap=*/{W}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr W = m.dim(0); + AffineExpr w = m.dim(1); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .expectMaps({/*inputMap=*/{m.strided(W, w, 0)}, + /*filterMap=*/{w}, + /*outputMap=*/{W}}) + .matchBody(); } // #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)> @@ -485,32 +552,18 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(1, 1); - *strides = SmallVector(1, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr W = getAffineDimExpr(1, context); - AffineExpr F = getAffineDimExpr(2, context); - AffineExpr w = getAffineDimExpr(3, context); - AffineExpr c = getAffineDimExpr(4, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], c}, - /*filterMap=*/{w, c, F}, - /*outputMap=*/{N, W, F}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr w = m.dim(3); + AffineExpr c = m.dim(4); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c}, + /*filterMap=*/{w, c, F}, + /*outputMap=*/{N, W, F}}) + .matchBody(); } // #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)> @@ -526,32 +579,18 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(1, 1); - *strides = SmallVector(1, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr F = getAffineDimExpr(1, context); - AffineExpr W = getAffineDimExpr(2, context); - AffineExpr c = getAffineDimExpr(3, context); - AffineExpr w = getAffineDimExpr(4, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, - /*oDim=*/2, (*dilations)[0], (*strides)[0])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, c, W * (*strides)[0] + w * (*dilations)[0]}, - /*filterMap=*/{F, c, w}, - /*outputMap=*/{N, F, W}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr c = m.dim(3); + AffineExpr w = m.dim(4); + + return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)}, + /*filterMap=*/{F, c, w}, + /*outputMap=*/{N, F, W}}) + .matchBody(); } // #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)> @@ -567,36 +606,18 @@ bool isaConvolutionOpOfType(LinalgOp op, assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(2, 1); - *strides = SmallVector(2, 1); - MLIRContext *context = op->getContext(); - AffineExpr H = getAffineDimExpr(0, context); - AffineExpr W = getAffineDimExpr(1, context); - AffineExpr h = getAffineDimExpr(2, context); - AffineExpr w = getAffineDimExpr(3, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, - /*oDim=*/0, (*dilations)[0], (*strides)[0])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, - /*oDim=*/1, (*dilations)[1], (*strides)[1])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{H * (*strides)[0] + h * (*dilations)[0], - W * (*strides)[1] + w * (*dilations)[1]}, - /*filterMap=*/{h, w}, - /*outputMap=*/{H, W}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr H = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr h = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) + .expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{h, w}, + /*outputMap=*/{H, W}}) + .matchBody(); } // #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)> @@ -612,43 +633,22 @@ bool isaConvolutionOpOfType(LinalgOp op, assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(3, 1); - *strides = SmallVector(3, 1); - MLIRContext *context = op->getContext(); - AffineExpr D = getAffineDimExpr(0, context); - AffineExpr H = getAffineDimExpr(1, context); - AffineExpr W = getAffineDimExpr(2, context); - AffineExpr d = getAffineDimExpr(3, context); - AffineExpr h = getAffineDimExpr(4, context); - AffineExpr w = getAffineDimExpr(5, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: D * stride + d * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, - /*oDim=*/0, (*dilations)[0], (*strides)[0])) - return false; - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, - /*oDim=*/1, (*dilations)[1], (*strides)[1])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, - /*oDim=*/2, (*dilations)[2], (*strides)[2])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0], - H * (*strides)[1] + h * (*dilations)[1], - W * (*strides)[2] + w * (*dilations)[2]}, - /*filterMap=*/{d, h, w}, - /*outputMap=*/{D, H, W}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr D = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr d = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0) + .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1) + .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2) + .expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2)}, + /*filterMap=*/{d, h, w}, + /*outputMap=*/{D, H, W}}) + .matchBody(); } // #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)> @@ -664,31 +664,17 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(1, 1); - *strides = SmallVector(1, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr W = getAffineDimExpr(1, context); - AffineExpr C = getAffineDimExpr(2, context); - AffineExpr w = getAffineDimExpr(3, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, (*dilations)[0], (*strides)[0])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]}, - /*filterMap=*/{C, w}, - /*outputMap=*/{N, C, W}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)}, + /*filterMap=*/{C, w}, + /*outputMap=*/{N, C, W}}) + .matchBody(); } // #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)> @@ -704,31 +690,17 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(1, 1); - *strides = SmallVector(1, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr W = getAffineDimExpr(1, context); - AffineExpr C = getAffineDimExpr(2, context); - AffineExpr w = getAffineDimExpr(3, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C}, - /*filterMap=*/{w, C}, - /*outputMap=*/{N, W, C}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr w = m.dim(3); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C}, + /*outputMap=*/{N, W, C}}) + .matchBody(); } // #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)> @@ -744,32 +716,18 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(1, 1); - *strides = SmallVector(1, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr W = getAffineDimExpr(1, context); - AffineExpr C = getAffineDimExpr(2, context); - AffineExpr CM = getAffineDimExpr(3, context); - AffineExpr w = getAffineDimExpr(4, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C}, - /*filterMap=*/{w, C, CM}, - /*outputMap=*/{N, W, C, CM}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr W = m.dim(1); + AffineExpr C = m.dim(2); + AffineExpr CM = m.dim(3); + AffineExpr w = m.dim(4); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C}, + /*filterMap=*/{w, C, CM}, + /*outputMap=*/{N, W, C, CM}}) + .matchBody(); } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> @@ -785,38 +743,20 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(2, 1); - *strides = SmallVector(2, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr H = getAffineDimExpr(1, context); - AffineExpr W = getAffineDimExpr(2, context); - AffineExpr C = getAffineDimExpr(3, context); - AffineExpr h = getAffineDimExpr(4, context); - AffineExpr w = getAffineDimExpr(5, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, (*dilations)[0], (*strides)[0])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, - /*oDim=*/3, (*dilations)[1], (*strides)[1])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0], - W * (*strides)[1] + w * (*dilations)[1]}, - /*filterMap=*/{C, h, w}, - /*outputMap=*/{N, C, H, W}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{C, h, w}, + /*outputMap=*/{N, C, H, W}}) + .matchBody(); } // #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) @@ -835,46 +775,25 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(3, 1); - *strides = SmallVector(3, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr D = getAffineDimExpr(1, context); - AffineExpr H = getAffineDimExpr(2, context); - AffineExpr W = getAffineDimExpr(3, context); - AffineExpr CM = getAffineDimExpr(4, context); - AffineExpr d = getAffineDimExpr(5, context); - AffineExpr h = getAffineDimExpr(6, context); - AffineExpr w = getAffineDimExpr(7, context); - AffineExpr C = getAffineDimExpr(8, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: D * stride + d * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, (*dilations)[1], (*strides)[1])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, - /*oDim=*/3, (*dilations)[2], (*strides)[2])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, D * (*strides)[0] + d * (*dilations)[0], - H * (*strides)[1] + h * (*dilations)[1], - W * (*strides)[2] + w * (*dilations)[2], C}, - /*filterMap=*/{d, h, w, C, CM}, - /*outputMap=*/{N, D, H, W, C, CM}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForConvolutionOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr D = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr CM = m.dim(4); + AffineExpr d = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + AffineExpr C = m.dim(8); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2) + .expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2), C}, + /*filterMap=*/{d, h, w, C, CM}, + /*outputMap=*/{N, D, H, W, C, CM}}) + .matchBody(); } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -890,38 +809,21 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(2, 1); - *strides = SmallVector(2, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr H = getAffineDimExpr(1, context); - AffineExpr W = getAffineDimExpr(2, context); - AffineExpr C = getAffineDimExpr(3, context); - AffineExpr h = getAffineDimExpr(4, context); - AffineExpr w = getAffineDimExpr(5, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, (*dilations)[1], (*strides)[1])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0], - W * (*strides)[1] + w * (*dilations)[1], C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForMaxSignedPoolOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MAX_SIGNED); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -937,38 +839,21 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(2, 1); - *strides = SmallVector(2, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr H = getAffineDimExpr(1, context); - AffineExpr W = getAffineDimExpr(2, context); - AffineExpr C = getAffineDimExpr(3, context); - AffineExpr h = getAffineDimExpr(4, context); - AffineExpr w = getAffineDimExpr(5, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, (*dilations)[1], (*strides)[1])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0], - W * (*strides)[1] + w * (*dilations)[1], C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForMinSignedPoolOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MIN_SIGNED); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -984,38 +869,21 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(2, 1); - *strides = SmallVector(2, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr H = getAffineDimExpr(1, context); - AffineExpr W = getAffineDimExpr(2, context); - AffineExpr C = getAffineDimExpr(3, context); - AffineExpr h = getAffineDimExpr(4, context); - AffineExpr w = getAffineDimExpr(5, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, (*dilations)[1], (*strides)[1])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0], - W * (*strides)[1] + w * (*dilations)[1], C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForSumPoolOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::SUM); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -1031,38 +899,21 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(2, 1); - *strides = SmallVector(2, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr H = getAffineDimExpr(1, context); - AffineExpr W = getAffineDimExpr(2, context); - AffineExpr C = getAffineDimExpr(3, context); - AffineExpr h = getAffineDimExpr(4, context); - AffineExpr w = getAffineDimExpr(5, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, (*dilations)[1], (*strides)[1])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0], - W * (*strides)[1] + w * (*dilations)[1], C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForMaxUnsignedPoolOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MAX_UNSIGNED); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -1078,38 +929,21 @@ bool isaConvolutionOpOfType( assert(isaConvolutionOpInterface(op) && "expected op to implement ConvolutionOpInterface"); - *dilations = SmallVector(2, 1); - *strides = SmallVector(2, 1); - MLIRContext *context = op->getContext(); - AffineExpr N = getAffineDimExpr(0, context); - AffineExpr H = getAffineDimExpr(1, context); - AffineExpr W = getAffineDimExpr(2, context); - AffineExpr C = getAffineDimExpr(3, context); - AffineExpr h = getAffineDimExpr(4, context); - AffineExpr w = getAffineDimExpr(5, context); - ArrayAttr indexingMaps = op.getIndexingMaps(); - // First fetch dilations/strides :- - // Match: H * stride + h * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, (*dilations)[0], (*strides)[0])) - return false; - // Match: W * stride + w * dilation - if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, (*dilations)[1], (*strides)[1])) - return false; - // Match expected indexing maps - if (!convLayoutMatches( - {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0], - W * (*strides)[1] + w * (*dilations)[1], C}, - /*filterMap=*/{h, w}, - /*outputMap=*/{N, H, W, C}}, - indexingMaps, context)) - return false; - // Match body - Block *body = op.getBlock(); - auto yieldOp = cast(body->getTerminator()); - Value yieldVal = yieldOp.getOperand(0); - return bodyMatcherForMinUnsignedPoolOps(yieldVal, body); + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides, + PoolingType::MIN_UNSIGNED); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); } Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,