Skip to content

Commit

Permalink
[mlir] add types to the transform dialect
Browse files Browse the repository at this point in the history
Introduce a type system for the transform dialect. A transform IR type
captures the expectations of the transform IR on the payload IR
operations that are being transformed, such as being of a certain kind
or implementing an interface that enables the transformation. This
provides stricter checking and better readability of the transform IR
than using the catch-all "handle" type.

This change implements the basic support for a type system amendable to
dialect extensions and adds a drop-in replacement for the unrestricted
"handle" type. The actual switch of transform dialect ops to that type
will happen in a separate commit.

See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135164
  • Loading branch information
ftynse committed Oct 11, 2022
1 parent b845add commit bba85eb
Show file tree
Hide file tree
Showing 24 changed files with 538 additions and 33 deletions.
4 changes: 4 additions & 0 deletions mlir/docs/Dialects/Transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

[TOC]

[include "Dialects/TransformTypes.md"]

[include "Dialects/TransformOps.md"]

## Bufferization Transform Operations
Expand All @@ -16,4 +18,6 @@

[include "Dialects/LinalgStructuredTransformOps.md"]

[include "Dialects/TransformTypeInterfaces.md"]

[include "Dialects/TransformOpInterfaces.md"]
15 changes: 15 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform)
add_public_tablegen_target(MLIRTransformDialectIncGen)
add_dependencies(mlir-headers MLIRTransformDialectIncGen)

set(LLVM_TARGET_DEFINITIONS TransformTypes.td)
mlir_tablegen(TransformTypes.h.inc -gen-typedef-decls)
mlir_tablegen(TransformTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRTransformTypesIncGen)
add_dependencies(mlir-headers MLIRTransformTypesIncGen)
add_mlir_doc(TransformTypes TransformTypes Dialects/ -gen-typedef-docs)

set(LLVM_TARGET_DEFINITIONS TransformAttrs.td)
mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls)
mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs)
Expand All @@ -17,5 +24,13 @@ add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen)
add_mlir_dialect(TransformOps transform)
add_mlir_doc(TransformOps TransformOps Dialects/ -gen-dialect-doc -dialect=transform)

# Contrary to what the name claims, this only produces the _op_ interface.
add_mlir_interface(TransformInterfaces)
add_mlir_doc(TransformInterfaces TransformOpInterfaces Dialects/ -gen-op-interface-docs)

set(LLVM_TARGET_DEFINITIONS TransformInterfaces.td)
mlir_tablegen(TransformTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TransformTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRTransformDialectTypeInterfacesIncGen)
add_dependencies(mlir-headers MLIRTransformDialectTypeInterfacesIncGen)
add_mlir_doc(TransformInterfaces TransformTypeInterfaces Dialects/ -gen-type-interface-docs)
37 changes: 37 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"

namespace mlir {
Expand All @@ -37,6 +38,11 @@ static inline void checkImplementsTransformInterface(MLIRContext *context) {
"ops injected into the transform dialect must implement "
"MemoryEffectsOpInterface");
}

/// Asserts that the type provided as template argument implements the
/// TransformTypeInterface. This must be a dynamic assertion since interface
/// implementations may be registered at runtime.
void checkImplementsTransformTypeInterface(TypeID typeID, MLIRContext *context);
} // namespace detail
#endif // NDEBUG
} // namespace transform
Expand Down Expand Up @@ -120,6 +126,18 @@ class TransformDialectExtension
});
}

/// Injects the types into the Transform dialect. The types must implement
/// the TransformTypeInterface and the implementation must be already
/// available when the type is injected. Furthermore, the types must provide
/// a `getMnemonic` static method returning an object convertible to
/// `StringRef` that is unique across all injected types.
template <typename... TypeTys>
void registerTypes() {
opInitializers.push_back([](TransformDialect *transformDialect) {
transformDialect->addTypesChecked<TypeTys...>();
});
}

/// Declares that this Transform dialect extension depends on the dialect
/// provided as template parameter. When the Transform dialect is loaded,
/// dependent dialects will be loaded as well. This is intended for dialects
Expand Down Expand Up @@ -182,6 +200,25 @@ class TransformDialectExtension
bool buildOnly;
};

template <typename Type>
void TransformDialect::addTypeIfNotRegistered() {
// Use the address of the parse method as a proxy for identifying whether we
// are registering the same type class for the same mnemonic.
StringRef mnemonic = Type::getMnemonic();
auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse);
if (!inserted) {
const ExtensionTypeParsingHook &parsingHook = it->getValue();
if (*parsingHook.target<mlir::Type (*)(AsmParser &)>() != &Type::parse)
reportDuplicateTypeRegistration(mnemonic);
}
typePrintingHooks.try_emplace(
TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
printer << Type::getMnemonic();
cast<Type>(type).print(printer);
});
addTypes<Type>();
}

/// A wrapper for transform dialect extensions that forces them to be
/// constructed in the build-only mode.
template <typename DerivedTy>
Expand Down
55 changes: 54 additions & 1 deletion mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,23 @@ def Transform_Dialect : Dialect {
const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
getPDLConstraintHooks() const;

/// Parses a type registered by this dialect or one of its extensions.
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

/// Prints a type registered by this dialect or one of its extensions.
void printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const override;

/// Parser callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypeParsingHook =
std::function<::mlir::Type (::mlir::AsmParser &)>;

/// Printer callback for an individual type registered by this dialect or
/// its extensions.
using ExtensionTypePrintingHook =
std::function<void (::mlir::Type, ::mlir::AsmPrinter &)>;

private:
template <typename OpTy>
void addOperationIfNotRegistered() {
Expand Down Expand Up @@ -344,6 +361,28 @@ def Transform_Dialect : Dialect {
#endif // NDEBUG
}

/// Registers the types specified as template parameters with the
/// Transform dialect. Checks that they meet the requirements for
/// Transform IR types.
template <typename... TypeTys>
void addTypesChecked() {
(addTypeIfNotRegistered<TypeTys>(), ...);

#ifndef NDEBUG
(detail::checkImplementsTransformTypeInterface(
TypeID::get<TypeTys>(), getContext()), ...);
#endif // NDEBUG
}

/// Implementation of the type registration for a single type, should
/// not be called directly, use addTypesChecked instead.
template <typename Type>
void addTypeIfNotRegistered();

/// Reports a repeated registration error of a type with the given
/// mnemonic.
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);

template <typename, typename...>
friend class TransformDialectExtension;

Expand All @@ -352,9 +391,23 @@ def Transform_Dialect : Dialect {
void mergeInPDLMatchHooks(
::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns);

//===----------------------------------------------------------------===//
// Data fields
//===----------------------------------------------------------------===//

/// A container for PDL constraint function that can be used by
/// operations in this dialect.
PDLPatternModule pdlMatchHooks;
::mlir::PDLPatternModule pdlMatchHooks;

/// A map from type mnemonic to its parsing function for the remainder of
/// the syntax. The parser has access to the mnemonic, so it is used for
/// further dispatch.
::llvm::StringMap<ExtensionTypeParsingHook> typeParsingHooks;

/// A map from type TypeID to its printing function. No need to do string
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;
}];
}

Expand Down
37 changes: 24 additions & 13 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/IR/OpDefinition.h"

#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ScopeExit.h"

namespace mlir {
Expand Down Expand Up @@ -279,13 +280,16 @@ class TransformState {
/// list of operations in the payload IR. The arguments must be defined in
/// blocks of the currently processed transform IR region, typically after a
/// region scope is defined.
void mapBlockArguments(BlockArgument argument,
ArrayRef<Operation *> operations) {
///
/// Returns failure if the payload does not satisfy the conditions associated
/// with the type of the handle value.
LogicalResult mapBlockArguments(BlockArgument argument,
ArrayRef<Operation *> operations) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(argument.getParentRegion() == regionStack.back() &&
"mapping block arguments from a region other than the active one");
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
setPayloadOps(argument, operations);
return setPayloadOps(argument, operations);
}

// Forward declarations to support limited visibility.
Expand Down Expand Up @@ -478,7 +482,10 @@ class TransformState {
/// is invalid given the transformation "consumes" the handle as expressed
/// by side effects. Practically, a transformation consuming a handle means
/// that the associated payload operation may no longer exist.
void setPayloadOps(Value value, ArrayRef<Operation *> targets);
///
/// Returns failure if the payload does not satisfy the conditions associated
/// with the type of the handle value.
LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);

/// Forgets the payload IR ops associated with the given transform IR value.
void removePayloadOps(Value value);
Expand All @@ -488,8 +495,12 @@ class TransformState {
/// expected to return the modified operation or nullptr. In the latter case,
/// the corresponding operation is no longer associated with the transform IR
/// value.
void updatePayloadOps(Value value,
function_ref<Operation *(Operation *)> callback);
///
/// Returns failure if the payload does not satisfy the conditions associated
/// with the type of the handle value.
LogicalResult
updatePayloadOps(Value value,
function_ref<Operation *(Operation *)> callback);

/// If the operand is a handle consumed by the operation, i.e. has the "free"
/// memory effect associated with it, identifies other handles that are
Expand Down Expand Up @@ -574,9 +585,9 @@ namespace detail {
/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
/// to either the list of operations associated with its operand or the root of
/// the payload IR, depending on what is available in the context.
void mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
Operation *op,
Region &region);
LogicalResult
mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
Operation *op, Region &region);

/// Verification hook for PossibleTopLevelTransformOpTrait.
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
Expand Down Expand Up @@ -613,17 +624,17 @@ class PossibleTopLevelTransformOpTrait
/// Sets up the mapping between the entry block of the given region of this op
/// and the relevant list of Payload IR operations in the given state. The
/// state is expected to be already scoped at the region of this operation.
void mapBlockArguments(TransformState &state, Region &region) {
LogicalResult mapBlockArguments(TransformState &state, Region &region) {
assert(region.getParentOp() == this->getOperation() &&
"op comes from the wrong region");
detail::mapPossibleTopLevelTransformOpBlockArguments(
return detail::mapPossibleTopLevelTransformOpBlockArguments(
state, this->getOperation(), region);
}
void mapBlockArguments(TransformState &state) {
LogicalResult mapBlockArguments(TransformState &state) {
assert(
this->getOperation()->getNumRegions() == 1 &&
"must indicate the region to map if the operation has more than one");
mapBlockArguments(state, this->getOperation()->getRegion(0));
return mapBlockArguments(state, this->getOperation()->getRegion(0));
}
};

Expand Down
25 changes: 25 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,31 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
}];
}

def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
let description = [{
Types that can be used for Transform dialect handle values. Such types
define the properties of Payload IR operations associated with the handle.
A user of such a handle can assume that these properties have been verified
for any Payload IR operation associated with it.
}];

let cppNamespace = "::mlir::transform";

let methods = [
InterfaceMethod<
/*desc=*/[{
Checks if the given list of associated Payload IR operations satisfy
the conditions defined by this type. If not, produces a silenceable
error at the specified location.
}],
/*returnType=*/"::mlir::DiagnosedSilenceableFailure",
/*name=*/"checkPayload",
/*arguments=*/(ins "::mlir::Location":$loc,
"::mlir::ArrayRef<::mlir::Operation *>":$payload)
>
];
}

def FunctionalStyleTransformOpTrait
: NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
let cppNamespace = "::mlir::transform";
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H

#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"

namespace mlir {
Expand Down
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS

include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
Expand Down Expand Up @@ -96,6 +97,23 @@ def AlternativesOp : TransformDialectOp<"alternatives",
let hasVerifier = 1;
}

def CastOp : TransformDialectOp<"cast",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
// TODO: temporarily fallback support for casting from PDL_Operation type.
let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation *target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}

def ForeachOp : TransformDialectOp<"foreach",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- TransformTypes.h - Transform dialect types ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"

namespace mlir {
class DiagnosedSilenceableFailure;
class Operation;
class Type;
} // namespace mlir

#include "mlir/Dialect/Transform/IR/TransformTypeInterfaces.h.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformTypes.h.inc"

#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES_H

0 comments on commit bba85eb

Please sign in to comment.