Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,22 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
SmallVector<Value> inputs = genericOp.getDpsInputs();
ValueRange outputs = genericOp.getDpsInits();
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
? TypeRange(ValueRange(outputs))
: TypeRange{};
Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
LinalgOp namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
LinalgOp namedOp;
// Ops with no dilations and no strides.
if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
inputs, outputs);
} else {
Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
}
return namedOp;
}

Expand All @@ -265,9 +273,19 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
strides); \
// -----------------------------
// Convolution ops.
// -----------------------------
CONV_OP_SPECIALIZER(linalg::Conv1DOp);
CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
CONV_OP_SPECIALIZER(linalg::Conv2DOp);
CONV_OP_SPECIALIZER(linalg::Conv3DOp);
// -----------------------------
// Depthwise Convolution ops.
// -----------------------------
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
// -----------------------------
Expand Down
300 changes: 299 additions & 1 deletion mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
unsigned inputMapIdx = 0, filterMapIdx = 1,
outputMapIdx = indexingMaps.size() - 1;
AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
return false;

Expand Down Expand Up @@ -434,6 +434,263 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
})));
}

// Checks whether `op` has the structure of linalg::Conv1DOp, i.e. uses:
//   #inputMap  = affine_map<(W, w) -> (W + w)>
//   #filterMap = affine_map<(W, w) -> (w)>
//   #outputMap = affine_map<(W, w) -> (W)>
// An op that already is a linalg::Conv1DOp is accepted right away and
// `dilations`/`strides` are left untouched. For any other op, both vectors are
// reset to a single 1 and then handed to matchConvDimAddExprPattern for the W
// dimension before the indexing maps and the body are verified.
template <>
bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  // The named op trivially matches itself.
  if (isa<linalg::Conv1DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  // One spatial dimension (W): start from unit dilation and stride.
  *dilations = SmallVector<int64_t>(1, 1);
  *strides = SmallVector<int64_t>(1, 1);
  int64_t &dilationW = (*dilations)[0];
  int64_t &strideW = (*strides)[0];

  MLIRContext *ctx = op->getContext();
  AffineExpr W = getAffineDimExpr(0, ctx);
  AffineExpr w = getAffineDimExpr(1, ctx);
  ArrayAttr maps = op.getIndexingMaps();

  // Fetch stride/dilation from the `W * stride + w * dilation` input access.
  if (!matchConvDimAddExprPattern(maps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0,
                                  dilationW, strideW))
    return false;

  // All three indexing maps must agree with the expected layout.
  if (!convLayoutMatches({/*inputMap=*/{W * strideW + w * dilationW},
                          /*filterMap=*/{w},
                          /*outputMap=*/{W}},
                         maps, ctx))
    return false;

  // The value yielded by the region must pass the convolution body matcher.
  Block *block = op.getBlock();
  auto terminator = cast<linalg::YieldOp>(block->getTerminator());
  return bodyMatcherForConvolutionOps(terminator.getOperand(0), block);
}

// Checks whether `op` has the structure of linalg::Conv1DNwcWcfOp, i.e. uses:
//   #inputMap  = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
//   #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)>
//   #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)>
// An op that already is a linalg::Conv1DNwcWcfOp is accepted right away and
// `dilations`/`strides` are left untouched. For any other op, both vectors are
// reset to a single 1 and then handed to matchConvDimAddExprPattern for the W
// dimension before the indexing maps and the body are verified.
template <>
bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  // The named op trivially matches itself.
  if (isa<linalg::Conv1DNwcWcfOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  // One spatial dimension (W): start from unit dilation and stride.
  *dilations = SmallVector<int64_t>(1, 1);
  *strides = SmallVector<int64_t>(1, 1);
  int64_t &dilationW = (*dilations)[0];
  int64_t &strideW = (*strides)[0];

  MLIRContext *ctx = op->getContext();
  AffineExpr N = getAffineDimExpr(0, ctx);
  AffineExpr W = getAffineDimExpr(1, ctx);
  AffineExpr F = getAffineDimExpr(2, ctx);
  AffineExpr w = getAffineDimExpr(3, ctx);
  AffineExpr c = getAffineDimExpr(4, ctx);
  ArrayAttr maps = op.getIndexingMaps();

  // Fetch stride/dilation from the `W * stride + w * dilation` input access:
  // result 1 of the input map, result 0 of the filter, result 1 of the output.
  if (!matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1,
                                  dilationW, strideW))
    return false;

  // All three indexing maps must agree with the expected layout.
  if (!convLayoutMatches({/*inputMap=*/{N, W * strideW + w * dilationW, c},
                          /*filterMap=*/{w, c, F},
                          /*outputMap=*/{N, W, F}},
                         maps, ctx))
    return false;

  // The value yielded by the region must pass the convolution body matcher.
  Block *block = op.getBlock();
  auto terminator = cast<linalg::YieldOp>(block->getTerminator());
  return bodyMatcherForConvolutionOps(terminator.getOperand(0), block);
}

// Checks whether `op` has the structure of linalg::Conv1DNcwFcwOp, i.e. uses:
//   #inputMap  = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
//   #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)>
//   #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)>
// An op that already is a linalg::Conv1DNcwFcwOp is accepted right away and
// `dilations`/`strides` are left untouched. For any other op, both vectors are
// reset to a single 1 and then handed to matchConvDimAddExprPattern for the W
// dimension before the indexing maps and the body are verified.
template <>
bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  // The named op trivially matches itself.
  if (isa<linalg::Conv1DNcwFcwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  // One spatial dimension (W): start from unit dilation and stride.
  *dilations = SmallVector<int64_t>(1, 1);
  *strides = SmallVector<int64_t>(1, 1);
  int64_t &dilationW = (*dilations)[0];
  int64_t &strideW = (*strides)[0];

  MLIRContext *ctx = op->getContext();
  AffineExpr N = getAffineDimExpr(0, ctx);
  AffineExpr F = getAffineDimExpr(1, ctx);
  AffineExpr W = getAffineDimExpr(2, ctx);
  AffineExpr c = getAffineDimExpr(3, ctx);
  AffineExpr w = getAffineDimExpr(4, ctx);
  ArrayAttr maps = op.getIndexingMaps();

  // Fetch stride/dilation from the `W * stride + w * dilation` input access;
  // W is the last result of the input, filter and output maps alike.
  if (!matchConvDimAddExprPattern(maps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2,
                                  dilationW, strideW))
    return false;

  // All three indexing maps must agree with the expected layout.
  if (!convLayoutMatches({/*inputMap=*/{N, c, W * strideW + w * dilationW},
                          /*filterMap=*/{F, c, w},
                          /*outputMap=*/{N, F, W}},
                         maps, ctx))
    return false;

  // The value yielded by the region must pass the convolution body matcher.
  Block *block = op.getBlock();
  auto terminator = cast<linalg::YieldOp>(block->getTerminator());
  return bodyMatcherForConvolutionOps(terminator.getOperand(0), block);
}

// Checks whether `op` has the structure of linalg::Conv2DOp, i.e. uses:
//   #inputMap  = affine_map<(H, W, h, w) -> (H + h, W + w)>
//   #filterMap = affine_map<(H, W, h, w) -> (h, w)>
//   #outputMap = affine_map<(H, W, h, w) -> (H, W)>
// An op that already is a linalg::Conv2DOp is accepted right away and
// `dilations`/`strides` are left untouched. For any other op, both vectors are
// reset to two 1s and then handed, entry by entry (H first, then W), to
// matchConvDimAddExprPattern before the indexing maps and the body are
// verified.
template <>
bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  // The named op trivially matches itself.
  if (isa<linalg::Conv2DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  // Two spatial dimensions (H, W): start from unit dilations and strides.
  *dilations = SmallVector<int64_t>(2, 1);
  *strides = SmallVector<int64_t>(2, 1);

  MLIRContext *ctx = op->getContext();
  AffineExpr H = getAffineDimExpr(0, ctx);
  AffineExpr W = getAffineDimExpr(1, ctx);
  AffineExpr h = getAffineDimExpr(2, ctx);
  AffineExpr w = getAffineDimExpr(3, ctx);
  ArrayAttr maps = op.getIndexingMaps();

  // Fetch stride/dilation per spatial dimension. Result `dim` of the input,
  // filter and output maps all refer to the same spatial dimension:
  //   dim 0: H * stride + h * dilation
  //   dim 1: W * stride + w * dilation
  for (unsigned dim = 0; dim < 2; ++dim) {
    if (!matchConvDimAddExprPattern(maps, /*iDim=*/dim, /*fDim=*/dim,
                                    /*oDim=*/dim, (*dilations)[dim],
                                    (*strides)[dim]))
      return false;
  }

  // All three indexing maps must agree with the expected layout.
  int64_t &dilationH = (*dilations)[0];
  int64_t &dilationW = (*dilations)[1];
  int64_t &strideH = (*strides)[0];
  int64_t &strideW = (*strides)[1];
  if (!convLayoutMatches({/*inputMap=*/{H * strideH + h * dilationH,
                                        W * strideW + w * dilationW},
                          /*filterMap=*/{h, w},
                          /*outputMap=*/{H, W}},
                         maps, ctx))
    return false;

  // The value yielded by the region must pass the convolution body matcher.
  Block *block = op.getBlock();
  auto terminator = cast<linalg::YieldOp>(block->getTerminator());
  return bodyMatcherForConvolutionOps(terminator.getOperand(0), block);
}

// Checks whether `op` has the structure of linalg::Conv3DOp, i.e. uses:
//   #inputMap  = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
//   #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
//   #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
// An op that already is a linalg::Conv3DOp is accepted right away and
// `dilations`/`strides` are left untouched. For any other op, both vectors are
// reset to three 1s and then handed, entry by entry (D, then H, then W), to
// matchConvDimAddExprPattern before the indexing maps and the body are
// verified.
template <>
bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  // The named op trivially matches itself.
  if (isa<linalg::Conv3DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  // Three spatial dimensions (D, H, W): start from unit dilations and strides.
  *dilations = SmallVector<int64_t>(3, 1);
  *strides = SmallVector<int64_t>(3, 1);

  MLIRContext *ctx = op->getContext();
  AffineExpr D = getAffineDimExpr(0, ctx);
  AffineExpr H = getAffineDimExpr(1, ctx);
  AffineExpr W = getAffineDimExpr(2, ctx);
  AffineExpr d = getAffineDimExpr(3, ctx);
  AffineExpr h = getAffineDimExpr(4, ctx);
  AffineExpr w = getAffineDimExpr(5, ctx);
  ArrayAttr maps = op.getIndexingMaps();

  // Fetch stride/dilation per spatial dimension. Result `dim` of the input,
  // filter and output maps all refer to the same spatial dimension:
  //   dim 0: D * stride + d * dilation
  //   dim 1: H * stride + h * dilation
  //   dim 2: W * stride + w * dilation
  for (unsigned dim = 0; dim < 3; ++dim) {
    if (!matchConvDimAddExprPattern(maps, /*iDim=*/dim, /*fDim=*/dim,
                                    /*oDim=*/dim, (*dilations)[dim],
                                    (*strides)[dim]))
      return false;
  }

  // All three indexing maps must agree with the expected layout.
  int64_t &dilationD = (*dilations)[0];
  int64_t &dilationH = (*dilations)[1];
  int64_t &dilationW = (*dilations)[2];
  int64_t &strideD = (*strides)[0];
  int64_t &strideH = (*strides)[1];
  int64_t &strideW = (*strides)[2];
  if (!convLayoutMatches({/*inputMap=*/{D * strideD + d * dilationD,
                                        H * strideH + h * dilationH,
                                        W * strideW + w * dilationW},
                          /*filterMap=*/{d, h, w},
                          /*outputMap=*/{D, H, W}},
                         maps, ctx))
    return false;

  // The value yielded by the region must pass the convolution body matcher.
  Block *block = op.getBlock();
  auto terminator = cast<linalg::YieldOp>(block->getTerminator());
  return bodyMatcherForConvolutionOps(terminator.getOperand(0), block);
}

// Checks whether `op` has the structure of linalg::DepthwiseConv1DNcwCwOp,
// i.e. uses:
//   #inputMap  = affine_map<(N, W, C, w) -> (N, C, W + w)>
//   #filterMap = affine_map<(N, W, C, w) -> (C, w)>
//   #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
// An op that already is a linalg::DepthwiseConv1DNcwCwOp is accepted right
// away and `dilations`/`strides` are left untouched. For any other op, both
// vectors are reset to a single 1 and then handed to
// matchConvDimAddExprPattern for the W dimension before the indexing maps and
// the body are verified.
template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  // The named op trivially matches itself.
  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  // One spatial dimension (W): start from unit dilation and stride.
  *dilations = SmallVector<int64_t>(1, 1);
  *strides = SmallVector<int64_t>(1, 1);
  int64_t &dilationW = (*dilations)[0];
  int64_t &strideW = (*strides)[0];

  MLIRContext *ctx = op->getContext();
  AffineExpr N = getAffineDimExpr(0, ctx);
  AffineExpr W = getAffineDimExpr(1, ctx);
  AffineExpr C = getAffineDimExpr(2, ctx);
  AffineExpr w = getAffineDimExpr(3, ctx);
  ArrayAttr maps = op.getIndexingMaps();

  // Fetch stride/dilation from the `W * stride + w * dilation` input access:
  // result 2 of the input map, result 1 of the filter, result 2 of the output.
  if (!matchConvDimAddExprPattern(maps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2,
                                  dilationW, strideW))
    return false;

  // All three indexing maps must agree with the expected layout.
  if (!convLayoutMatches({/*inputMap=*/{N, C, W * strideW + w * dilationW},
                          /*filterMap=*/{C, w},
                          /*outputMap=*/{N, C, W}},
                         maps, ctx))
    return false;

  // The value yielded by the region must pass the convolution body matcher.
  Block *block = op.getBlock();
  auto terminator = cast<linalg::YieldOp>(block->getTerminator());
  return bodyMatcherForConvolutionOps(terminator.getOperand(0), block);
}

// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
Expand Down Expand Up @@ -474,6 +731,47 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
return bodyMatcherForConvolutionOps(yieldVal, body);
}

// Checks whether `op` has the structure of linalg::DepthwiseConv1DNwcWcmOp,
// i.e. uses:
//   #inputMap  = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
//   #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
//   #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
// An op that already is a linalg::DepthwiseConv1DNwcWcmOp is accepted right
// away and `dilations`/`strides` are left untouched. For any other op, both
// vectors are reset to a single 1 and then handed to
// matchConvDimAddExprPattern for the W dimension before the indexing maps and
// the body are verified.
template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  // The named op trivially matches itself.
  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  // One spatial dimension (W): start from unit dilation and stride.
  *dilations = SmallVector<int64_t>(1, 1);
  *strides = SmallVector<int64_t>(1, 1);
  int64_t &dilationW = (*dilations)[0];
  int64_t &strideW = (*strides)[0];

  MLIRContext *ctx = op->getContext();
  AffineExpr N = getAffineDimExpr(0, ctx);
  AffineExpr W = getAffineDimExpr(1, ctx);
  AffineExpr C = getAffineDimExpr(2, ctx);
  AffineExpr CM = getAffineDimExpr(3, ctx);
  AffineExpr w = getAffineDimExpr(4, ctx);
  ArrayAttr maps = op.getIndexingMaps();

  // Fetch stride/dilation from the `W * stride + w * dilation` input access:
  // result 1 of the input map, result 0 of the filter, result 1 of the output.
  if (!matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1,
                                  dilationW, strideW))
    return false;

  // All three indexing maps must agree with the expected layout.
  if (!convLayoutMatches({/*inputMap=*/{N, W * strideW + w * dilationW, C},
                          /*filterMap=*/{w, C, CM},
                          /*outputMap=*/{N, W, C, CM}},
                         maps, ctx))
    return false;

  // The value yielded by the region must pass the convolution body matcher.
  Block *block = op.getBlock();
  auto terminator = cast<linalg::YieldOp>(block->getTerminator());
  return bodyMatcherForConvolutionOps(terminator.getOperand(0), block);
}

// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
Expand Down
Loading
Loading