Skip to content

Commit

Permalink
[mlir] add decompose and generalize to structured transform ops
Browse files Browse the repository at this point in the history
These ops complement the tiling/padding transformations by transforming
higher-level named structured operations such as depthwise convolutions into
lower-level and/or generic equivalents that are better handled by some
downstream transformations.

Differential Revision: https://reviews.llvm.org/D126698
  • Loading branch information
ftynse committed Jun 2, 2022
1 parent d42fe9a commit ce2e198
Show file tree
Hide file tree
Showing 8 changed files with 448 additions and 192 deletions.
Expand Up @@ -16,6 +16,49 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
let description = [{
Decomposes named complex operations, such as higher-dimensional
(depthwise) convolutions, into combinations of lower-dimensional equivalents
when possible. The operand handle must point to a list of such operations.
The returning handle points to the main produced computational operation,
such as the lower-dimensional convolution.
}];

let arguments = (ins PDL_Operation:$target);
let results = (outs PDL_Operation:$transformed);
let assemblyFormat = "$target attr-dict";

let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
::mlir::linalg::LinalgOp target);
}];
}

def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
let description = [{
Transforms a named structued operation into the generic form with the
explicit attached region. The operand handle must point to a list of
structured operations, it is consumed by the transformation and is not
expected to be used afterwards. The resulting handle points to the list
of equivalent generic operations, in the same order as the original named
operations.
}];

let arguments = (ins PDL_Operation:$target);
let results = (outs PDL_Operation:$transformed);
let assemblyFormat = "$target attr-dict";

let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
::mlir::linalg::LinalgOp target);
}];
}

def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
Expand Down
50 changes: 50 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Expand Up @@ -708,6 +708,56 @@ struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
LinalgPaddingOptions options;
};

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
: public OpRewritePattern<Conv2DNhwcHwcfOp> {
DownscaleSizeOneWindowed2DConvolution(
MLIRContext *context,
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
filter(std::move(f)) {}

FailureOr<Conv1DNwcWcfOp>
returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
PatternRewriter &rewriter) const;

LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}

private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
: public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
DownscaleDepthwiseConv2DNhwcHwcOp(
MLIRContext *context,
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
: OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
filter(std::move(f)) {}

FailureOr<DepthwiseConv1DNwcWcOp>
returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
PatternRewriter &rewriter) const;

LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}

private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};

struct LinalgFusionOptions {
/// List of operands indices to use for fusion.
llvm::SmallSet<unsigned, 1> indicesToFuse = {};
Expand Down
93 changes: 72 additions & 21 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Expand Up @@ -50,6 +50,68 @@ class SimpleRewriter : public PatternRewriter {
};
} // namespace

/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
template <typename PatternTy, typename... Args>
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
// Check if the given operation has the type expected by the pattern.
using OpTy = typename llvm::function_traits<
decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
auto op = dyn_cast<OpTy>(operation);
if (!op)
return failure();

// Apply the pattern directly to the op.
PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
SimpleRewriter rewriter(operation->getContext());
rewriter.setInsertionPoint(operation);
auto result = pattern.returningMatchAndRewrite(op, rewriter);
if (failed(result))
return failure();
return cast<LinalgOp>(result->getOperation());
}

//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//

FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
FailureOr<LinalgOp> windowed =
tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
if (succeeded(windowed))
return windowed;

FailureOr<LinalgOp> depthwise =
tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
if (succeeded(depthwise))
return depthwise;

InFlightDiagnostic diag = emitError() << "failed to apply";
diag.attachNote(target.getLoc()) << "attempted to apply to this op";
return diag;
}

//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//

FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
// Exit early if no transformation is needed.
if (isa<GenericOp>(target))
return target;

FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
if (succeeded(generic))
return generic;

InFlightDiagnostic diag = emitError() << "failed to apply";
diag.attachNote(target.getLoc()) << "attempted to apply to this op";
return diag;
}

//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
Expand All @@ -70,15 +132,7 @@ FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
return diag;
}

GenericOpInterchangePattern pattern(getContext(), interchangeVector);
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<GenericOp> result =
pattern.returningMatchAndRewrite(genericTarget, rewriter);
if (failed(result))
return failure();

return cast<LinalgOp>(result->getOperation());
return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
}

LogicalResult transform::InterchangeOp::verify() {
Expand Down Expand Up @@ -147,18 +201,15 @@ FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
paddingOptions.setTransposePaddings(transposePaddings);

LinalgPaddingPattern pattern(getContext(), paddingOptions);
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<LinalgOp> patternResult =
pattern.returningMatchAndRewrite(target, rewriter);
if (failed(patternResult)) {
InFlightDiagnostic diag = emitError()
<< "failed to apply pattern to target op";
diag.attachNote(target.getLoc()) << "target op";
return diag;
}
return patternResult;
FailureOr<LinalgOp> result =
tryApply<LinalgPaddingPattern>(target, paddingOptions);
if (succeeded(result))
return result;

InFlightDiagnostic diag = emitError()
<< "failed to apply pattern to target op";
diag.attachNote(target.getLoc()) << "target op";
return diag;
}

LogicalResult transform::PadOp::verify() {
Expand Down

0 comments on commit ce2e198

Please sign in to comment.