diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index c2485a08932dd..bbfbd2e9736a1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -279,6 +279,17 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp); CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp); CONV_OP_SPECIALIZER(linalg::Conv2DOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNhwcHwcfQOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNhwcFhwcQOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNchwFchwQOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNgchwFgchwOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNgchwGfchwQOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp); + CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp); CONV_OP_SPECIALIZER(linalg::Conv3DOp); // ----------------------------- // Depthwise Convolution ops. @@ -287,6 +298,10 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp); // ----------------------------- // Pooling ops. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index e85a2ab26bd32..a59b2663f2998 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -240,8 +240,8 @@ bool isReductionIterator(utils::IteratorType iteratorType) { //===----------------------------------------------------------------------===// /// Returns the BlockArgument that leads to `val`, if any. Traverses optional -/// ext* ops. -static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { +/// ext*/sitofp ops. +static BlockArgument getBlockArgumentWithOptionalCastOps(Value val) { BlockArgument blockArg = dyn_cast(val); if ((blockArg)) return blockArg; @@ -249,18 +249,62 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { Operation *defOp = val.getDefiningOp(); if (!dyn_cast_if_present(defOp) && !dyn_cast_if_present(defOp) && - !dyn_cast_if_present(defOp)) { + !dyn_cast_if_present(defOp) && + !dyn_cast_if_present(defOp)) { return nullptr; } return dyn_cast(defOp->getOperand(0)); } +/// Utility function to match the zero point offset body of convolution ops. +/// It takes input the addition op and multiplication op expected in every +/// convolution op and matches the following for both operands of multiplication +/// op :- +/// %a - %b +/// where, %a and %b can have optional upcast operation. +static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp, + Block *body) { + Operation *subOp1 = mulOp->getOperand(0).getDefiningOp(); + if (!isa_and_present(subOp1)) + return false; + Operation *subOp2 = mulOp->getOperand(1).getDefiningOp(); + if (!isa_and_present(subOp2)) + return false; + BlockArgument inputBlockArg = + getBlockArgumentWithOptionalCastOps(subOp1->getOperand(0)); + BlockArgument inputScalarBlockArg = + getBlockArgumentWithOptionalCastOps(subOp1->getOperand(1)); + BlockArgument filterBlockArg = + getBlockArgumentWithOptionalCastOps(subOp2->getOperand(0)); + BlockArgument filterScalarBlockArg = + getBlockArgumentWithOptionalCastOps(subOp2->getOperand(1)); + BlockArgument outBlockArg = + getBlockArgumentWithOptionalCastOps(addOp->getOperand(0)); + if (!inputBlockArg || !inputScalarBlockArg || !filterBlockArg || + !filterScalarBlockArg || !outBlockArg || + inputBlockArg.getOwner() != body || + inputScalarBlockArg.getOwner() != body || + filterBlockArg.getOwner() != body || + filterScalarBlockArg.getOwner() != body || + outBlockArg.getOwner() != body || inputBlockArg.getArgNumber() != 0 || + inputScalarBlockArg.getArgNumber() != 2 || + filterBlockArg.getArgNumber() != 1 || + filterScalarBlockArg.getArgNumber() != 3 || + outBlockArg.getArgNumber() != 4) + return false; + return true; +} + /// Utility to match block body for convolution ops. /// The body is thus expected to yield :- /// %out + (%lhs * %rhs) /// where: %lhs, %rhs and %out are block arguments and /// %lhs and %rhs can have optional upcast operation. -static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) { +/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :- +/// %input - %input_scalar +/// where, %input_scalar can have optional upcast operation. +static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body, + bool zeroPointOffset = false) { Operation *addOp = yieldVal.getDefiningOp(); if (!isa_and_present(addOp)) return false; @@ -269,12 +313,15 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) { if (!isa_and_present(mulOp)) return false; + if (zeroPointOffset) { + return bodyMatcherForZeroPointOffsets(addOp, mulOp, body); + } BlockArgument lhsBlockArg = - getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0)); + getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0)); BlockArgument rhsBlockArg = - getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1)); + getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1)); BlockArgument outBlockArg = - getBlockArgumentWithOptionalExtOps(addOp->getOperand(0)); + getBlockArgumentWithOptionalCastOps(addOp->getOperand(0)); if (!lhsBlockArg || !rhsBlockArg || !outBlockArg || lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body || outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 || @@ -291,9 +338,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { return false; BlockArgument lhsArg = - getBlockArgumentWithOptionalExtOps(defOp->getOperand(0)); + getBlockArgumentWithOptionalCastOps(defOp->getOperand(0)); BlockArgument rhsArg = - getBlockArgumentWithOptionalExtOps(defOp->getOperand(1)); + getBlockArgumentWithOptionalCastOps(defOp->getOperand(1)); if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 || rhsArg.getArgNumber() != 0) @@ -488,14 +535,15 @@ class ConvMatcherBuilder { } /// Match body pattern. This should be called last. - bool matchBody() { + bool matchBody(bool zeroPointOffset = false) { if (!matched) return false; Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); switch (poolingType) { case PoolingType::NONE: - return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body); + return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body, + zeroPointOffset); case PoolingType::MAX_SIGNED: return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body); case PoolingType::MAX_UNSIGNED: @@ -620,6 +668,361 @@ bool isaConvolutionOpOfType(LinalgOp op, .matchBody(); } +// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> +// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)> +// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr F = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + AffineExpr c = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c}, + /*filterMap=*/{h, w, c, F}, + /*outputMap=*/{N, H, W, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> +// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (h, w, c, F)> +// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()> +// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr F = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + AffineExpr c = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c}, + /*filterMap=*/{h, w, c, F}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, F}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> +// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)> +// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr F = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + AffineExpr c = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c}, + /*filterMap=*/{F, h, w, c}, + /*outputMap=*/{N, H, W, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> +// #filterMap = affine_map<(N, H, W, F, h, w, c) -> (F, h, w, c)> +// #scalarMap = affine_map<(N, H, W, F, h, w, c) -> ()> +// #outputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr F = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + AffineExpr c = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c}, + /*filterMap=*/{F, h, w, c}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, F}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, F, H, W, c, h, w) -> (N, c, H + h, W + w)> +// #filterMap = affine_map<(N, F, H, W, c, h, w) -> (F, c, h, w)> +// #outputMap = affine_map<(N, F, H, W, c, h, w) -> (N, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr c = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + + return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{F, c, h, w}, + /*outputMap=*/{N, F, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, F, H, W, c, h, w) -> (N, c, H + h, W + w)> +// #filterMap = affine_map<(N, F, H, W, c, h, w) -> (F, c, h, w)> +// #scalarMap = affine_map<(N, F, H, W, c, h, w) -> ()> +// #outputMap = affine_map<(N, F, H, W, c, h, w) -> (N, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr c = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + + return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{F, c, h, w}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, F, H, W}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)> +// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (F, G, c, h, w)> +// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr G = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr H = m.dim(3); + AffineExpr W = m.dim(4); + AffineExpr c = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + + return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0) + .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1) + .expectMaps( + {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{F, G, c, h, w}, + /*outputMap=*/{N, G, F, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)> +// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (G, F, c, h, w)> +// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr G = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr H = m.dim(3); + AffineExpr W = m.dim(4); + AffineExpr c = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + + return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0) + .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1) + .expectMaps( + {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{G, F, c, h, w}, + /*outputMap=*/{N, G, F, H, W}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)> +// #filterMap = affine_map<(N, G, F, H, W, c, h, w) -> (G, F, c, h, w)> +// #scalarMap = affine_map<(N, G, F, H, W, c, h, w) -> ()> +// #outputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, F, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr G = m.dim(1); + AffineExpr F = m.dim(2); + AffineExpr H = m.dim(3); + AffineExpr W = m.dim(4); + AffineExpr c = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + + return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0) + .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1) + .expectMaps( + {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)}, + /*filterMap=*/{G, F, c, h, w}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, G, F, H, W}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H + h, W + w, G, c)> +// #filterMap = affine_map<(N, H, W, G, F, h, w, c) -> (G, F, h, w, c)> +// #outputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H, W, G, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr G = m.dim(3); + AffineExpr F = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + AffineExpr c = m.dim(7); + + return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1) + .expectMaps( + {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c}, + /*filterMap=*/{G, F, h, w, c}, + /*outputMap=*/{N, H, W, G, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H + h, W + w, G, c)> +// #filterMap = affine_map<(N, H, W, G, F, h, w, c) -> (G, F, h, w, c)> +// #scalarMap = affine_map<(N, H, W, G, F, h, w, c) -> ()> +// #outputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H, W, G, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr G = m.dim(3); + AffineExpr F = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + AffineExpr c = m.dim(7); + + return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1) + .expectMaps( + {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c}, + /*filterMap=*/{G, F, h, w, c}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, G, F}}) + .matchBody(/*zeroPointOffset=*/true); +} + // #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)> // #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)> // #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)> @@ -759,6 +1162,130 @@ bool isaConvolutionOpOfType( .matchBody(); } +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w, C)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w, C}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w, C)> +// #scalarMap = affine_map<(N, H, W, C, h, w) -> ()> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr h = m.dim(4); + AffineExpr w = m.dim(5); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w, C}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, C}}) + .matchBody(/*zeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, CM, h, w) -> (h, w, C, CM)> +// #outputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H, W, C, CM)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr CM = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w, C, CM}, + /*outputMap=*/{N, H, W, C, CM}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, CM, h, w) -> (h, w, C, CM)> +// #scalarMap = affine_map<(N, H, W, C, CM, h, w) -> ()> +// #outputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H, W, C, CM)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr H = m.dim(1); + AffineExpr W = m.dim(2); + AffineExpr C = m.dim(3); + AffineExpr CM = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C}, + /*filterMap=*/{h, w, C, CM}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, H, W, C, CM}}) + .matchBody(/*zeroPointOffset=*/true); +} + // #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) // -> (N, D + d, H + h, W + w, C)> // #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C) diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 4b2d42a3ae4e0..289d55ce9911a 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -55,6 +55,149 @@ func.func @conv_2d(%in : tensor, %filter : tensor, %out : tens // ----- +func.func @conv_2d_nhwc_hwcf(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_hwcf + {dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_hwcf +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: dilations = dense<2> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwc_hwcf_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_nhwc_hwcf_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_hwcf_q +// CHECK: linalg.conv_2d_nhwc_hwcf_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwc_fhwc(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_fhwc +// CHECK: linalg.conv_2d_nhwc_fhwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwc_fhwc_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_nhwc_fhwc_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwc_fhwc_q +// CHECK: linalg.conv_2d_nhwc_fhwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nchw_fchw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_nchw_fchw + {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 4]> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nchw_fchw +// CHECK: linalg.conv_2d_nchw_fchw +// CHECK-SAME: dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 4]> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nchw_fchw_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_nchw_fchw_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nchw_fchw_q +// CHECK: linalg.conv_2d_nchw_fchw_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_ngchw_fgchw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_ngchw_fgchw + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_ngchw_fgchw +// CHECK: linalg.conv_2d_ngchw_fgchw +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_ngchw_gfchw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_ngchw_gfchw + {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_ngchw_gfchw +// CHECK: linalg.conv_2d_ngchw_gfchw +// CHECK-SAME: dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_ngchw_gfchw_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_ngchw_gfchw_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_ngchw_gfchw_q +// CHECK: linalg.conv_2d_ngchw_gfchw_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwgc_gfhwc(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_2d_nhwgc_gfhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwgc_gfhwc +// CHECK: linalg.conv_2d_nhwgc_gfhwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64> + +// ----- + +func.func @conv_2d_nhwgc_gfhwc_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_2d_nhwgc_gfhwc_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d_nhwgc_gfhwc_q +// CHECK: linalg.conv_2d_nhwgc_gfhwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + func.func @conv_3d(%in : tensor, %filter : tensor, %out : tensor) -> tensor { %0 = linalg.conv_3d ins(%in, %filter : tensor, tensor) @@ -121,6 +264,58 @@ func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tens // ----- +func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwc +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64> + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwc_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwc_q +// CHECK: linalg.depthwise_conv_2d_nhwc_hwc_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwcm + {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 1]> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwcm +// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm +// CHECK-SAME: dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[3, 1]> : tensor<2xi64> + +// ----- + +func.func @depthwise_conv_2d_nhwc_hwcm_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nhwc_hwcm_q +// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm_q +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> + +// ----- + func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %output: tensor) -> tensor { %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}