Skip to content

Commit

Permalink
[mlir][linalg][transform] Add TileOp to transform dialect
Browse files Browse the repository at this point in the history
This commit adds a tiling op to the transform dialect as an external op.

Differential Revision: https://reviews.llvm.org/D124661
  • Loading branch information
matthias-springer committed Apr 29, 2022
1 parent e66127e commit 3c2a74a
Show file tree
Hide file tree
Showing 12 changed files with 414 additions and 9 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(IR)
add_subdirectory(TransformOps)

set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg)
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS LinalgTransformOps.td)
mlir_tablegen(LinalgTransformOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRLinalgTransformOpsIncGen)
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -0,0 +1,30 @@
//===- LinalgTransformOps.h - Linalg transform ops --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"

//===----------------------------------------------------------------------===//
// Linalg Transform Operations
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"

namespace mlir {
class DialectRegistry;

namespace linalg {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
@@ -0,0 +1,45 @@
//===- LinalgTransformOps.td - Linalg transform ops --------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LINALG_TRANSFORM_OPS
#define LINALG_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Indicates that the given `target` op should be tiled with the options
provided as attributes. This transform generates a loop nest with a smaller
("tiled") target operation in its body. Currently limited to LinalgOps.

`sizes` are the tile sizes. A tile size of `0` indicates that the
respective dimension should not be tiled. No loop will be generated for such
dimensions. If all tile sizes are `0`, this transform is effectively a
no-op.

This op returns handles to the tiled op (in the generated loop nest) and the
generated loops. The number of loops is the number of non-zero tile sizes.
}];

let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
let results = (outs PDL_Operation:$tiled_linalg_op,
Variadic<PDL_Operation>:$loops);

let hasCustomAssemblyFormat = 1;
}

#endif // LINALG_TRANSFORM_OPS
6 changes: 6 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Expand Up @@ -33,6 +33,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
Expand Down Expand Up @@ -101,6 +102,11 @@ inline void registerAllDialects(DialectRegistry &registry) {
tosa::TosaDialect,
x86vector::X86VectorDialect>();
// clang-format on

// Register all dialect extensions.
linalg::registerTransformDialectExtension(registry);

// Register all external models.
arith::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(TransformOps)
add_subdirectory(Transforms)
add_subdirectory(Utils)
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
add_mlir_dialect_library(MLIRLinalgTransformOps
LinalgTransformOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/TransformOps

DEPENDS
MLIRLinalgTransformOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRLinalg
MLIRLinalgTransforms
MLIRParser
MLIRPDL
MLIRSideEffectInterfaces
MLIRTransformDialect
)
198 changes: 198 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -0,0 +1,198 @@
//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;

/// Extracts a vector of int64_t from an array attribute. Asserts if the
/// attribute contains values other than integers.
static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
SmallVector<int64_t> result;
result.reserve(attr.size());
for (APInt value : attr.getAsValueRange<IntegerAttr>())
result.push_back(value.getSExtValue());
return result;
}

/// Extracts a vector of unsigned from an array attribute. Asserts if the
/// attribute contains values other than intergers. May truncate.
static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
SmallVector<unsigned> result;
result.reserve(attr.size());
for (APInt value : attr.getAsValueRange<IntegerAttr>())
result.push_back(value.getZExtValue());
return result;
}

namespace {
/// A simple pattern rewriter that implements no special logic.
class SimpleRewriter : public PatternRewriter {
public:
SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
};
} // namespace

//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//

/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
static LogicalResult
applyTilingToAll(Operation *transformOp, Value target,
ArrayRef<int64_t> tileSizes,
transform::TransformResults &transformResults,
transform::TransformState &state,
function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
// Number of loops: Number of tiles sizes that are not zero.
size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
// All payload ops. These should all be LinalgOps for now.
ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);

SmallVector<Operation *> tiledLinalgOps;
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
for (unsigned int i = 0; i < numLoops; ++i)
loopOps[i].reserve(payloadOps.size());

for (Operation *target : payloadOps) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
if (!linalgOp)
return transformOp->emitError("only LinalgOps are supported");

FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
if (failed(tiled))
return failure();

tiledLinalgOps.push_back(tiled->op);
if (tiled->loops.size() != numLoops)
// Not enough loops were generated. This usually means that the input size
// was smaller than the tiling size.
// TODO: LinalgTilingPattern should return failure().
return failure();
for (unsigned int i = 0; i < numLoops; ++i)
loopOps[i].push_back(tiled->loops[i]);
}

transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
for (unsigned int i = 0; i < numLoops; ++i)
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
return success();
}

LogicalResult transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
LinalgTilingOptions tilingOptions;
SmallVector<int64_t> tileSizes = extractI64Array(getSizes());

if (!tileSizes.empty())
tilingOptions.setTileSizes(tileSizes);
tilingOptions.setInterchange(extractUIntArray(getInterchange()));
LinalgTilingPattern pattern(getContext(), tilingOptions);

return applyTilingToAll(getOperation(), getTarget(), tileSizes,
transformResults, state, [&](LinalgOp linalgOp) {
SimpleRewriter rewriter(linalgOp.getContext());
return pattern.returningMatchAndRewrite(linalgOp,
rewriter);
});
}

ParseResult transform::TileOp::parse(OpAsmParser &parser,
OperationState &result) {
StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue();
OpAsmParser::UnresolvedOperand targetOperand;
SMLoc opLoc;
parser.getCurrentLocation(&opLoc);
if (parser.parseOperand(targetOperand))
return parser.emitError(opLoc, "expected 'target' operand");
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
Attribute sizesAttr = result.attributes.get(sizesAttrName);
if (!sizesAttr)
return parser.emitError(opLoc)
<< "expected '" << sizesAttrName << "' attribute";
auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
if (!sizesArrayAttr)
return parser.emitError(opLoc)
<< "'" << sizesAttrName << "' attribute must be an array";
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
size_t numExpectedLoops =
sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
return failure();
return success();
}

void TileOp::print(OpAsmPrinter &p) {
p << ' ';
p << getTarget();
p.printOptionalAttrDict((*this)->getAttrs());
}

void TileOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
// `target` arg is consumed and can no longer be used.
effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
TransformMappingResource::get());

for (Value r : getResults()) {
effects.emplace_back(MemoryEffects::Write::get(), r,
TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Allocate::get(), r,
TransformMappingResource::get());
}

effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers new ops and declares PDL as dependent dialect since the additional
/// ops are using PDL types for operands and results.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
public:
LinalgTransformDialectExtension() {
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<scf::SCFDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

void mlir::linalg::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<LinalgTransformDialectExtension>();
}
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
Expand Up @@ -168,8 +168,8 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,

// Shift all IndexOp results by the tile offset.
SmallVector<Value> allIvs;
transform(loopRanges, std::back_inserter(allIvs),
[](Range range) { return range.offset; });
llvm::transform(loopRanges, std::back_inserter(allIvs),
[](Range range) { return range.offset; });
addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);

return clonedOp;
Expand Down
15 changes: 8 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Expand Up @@ -87,10 +87,11 @@ getTiledProducerLoops(OpResult producerResult,
assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
"expect slice and producer loop dimensions map one-to-one");
SmallVector<int64_t> tiledProducerLoopIndices;
transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
return tiledProducerIndexingSubMap.getDimPosition(idx);
});
llvm::transform(
llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
return tiledProducerIndexingSubMap.getDimPosition(idx);
});

return tiledProducerLoopIndices;
}
Expand Down Expand Up @@ -141,9 +142,9 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,

// Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
SmallVector<Value> producerLoopBounds;
transform(producerOp.createLoopRanges(b, loc),
std::back_inserter(producerLoopBounds),
[](Range range) { return range.size; });
llvm::transform(producerOp.createLoopRanges(b, loc),
std::back_inserter(producerLoopBounds),
[](Range range) { return range.size; });
SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);

// Tile the producer operands given the `sliceOp` ranges. Iterate the
Expand Down

0 comments on commit 3c2a74a

Please sign in to comment.