Skip to content

Commit

Permalink
[mlir] add OperationType to the Transform dialect
Browse files Browse the repository at this point in the history
Add a new OperationType handle type to the Transform dialect. This
transform type is parameterized by the name of the payload operation it
can point to. It is intended as a constraint on transformations that are
only applicable to a specific kind of payload operations. If a
transformation is applicable to a small set of operation classes, it can
be wrapped into a transform op by using a disjunctive constraint, such
as `Type<Or<[Transform_ConcreteOperation<"foo">.predicate,
Transform_ConcreteOperation<"bar">.predicate]>>` for its operand without
modifying this type. Broader sets of accepted operations should be
modeled as specific types.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135586
  • Loading branch information
ftynse committed Oct 11, 2022
1 parent 6bb997c commit 3e1f6d0
Show file tree
Hide file tree
Showing 21 changed files with 451 additions and 11 deletions.
46 changes: 46 additions & 0 deletions mlir/include/mlir-c/Dialect/Transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//===-- mlir-c/Dialect/Transform.h - C API for Transform Dialect --*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_C_DIALECT_TRANSFORM_H
#define MLIR_C_DIALECT_TRANSFORM_H

#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);

//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);

MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);

//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);

MLIR_CAPI_EXPORTED MlirType
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);

MLIR_CAPI_EXPORTED MlirStringRef
mlirTransformOperationTypeGetOperationName(MlirType type);

#ifdef __cplusplus
}
#endif

#endif // MLIR_C_DIALECT_TRANSFORM_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ def Transform_Dialect : Dialect {
/// mnemonic.
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);

void initializeTypes();

template <typename, typename...>
friend class TransformDialectExtension;

Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
"::mlir::ArrayRef<::mlir::Operation *>":$payload)
>
];

let extraSharedClassDeclaration = [{
DiagnosedSilenceableFailure emitSilenceableError(Location loc) const {
Diagnostic diag(loc, DiagnosticSeverity::Error);
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
}];
}

def FunctionalStyleTransformOpTrait
Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,24 @@ def Transform_AnyOpType : TypeDef<Transform_Dialect, "AnyOp",
let assemblyFormat = "";
}

def Transform_OperationType : TypeDef<Transform_Dialect, "Operation",
[DeclareTypeInterfaceMethods<TransformTypeInterface>]> {
let description = [{
Transform IR handle that can be associated with a list of Payload IR
operations with the specified operation name.
}];
let mnemonic = "op";
let parameters = (ins
StringRefParameter<"Name of the allowed payload operation">:$operation_name
);
let assemblyFormat = "`<` $operation_name `>`";
}

class Transform_ConcreteOpType<string opname>
: Type<And<[Transform_OperationType.predicate,
CPred<"$_self.cast<::mlir::transform::OperationType>()"
".getOperationName() == \"" # opname # "\"">]>,
"Transform IR handle to " # opname # " operations",
"::mlir::transform::OperationType">;

#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
64 changes: 64 additions & 0 deletions mlir/lib/Bindings/Python/DialectTransform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;

void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//

auto anyOpType =
mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
anyOpType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirTransformAnyOpTypeGet(ctx));
},
"Get an instance of AnyOpType in the given context.", py::arg("cls"),
py::arg("context") = py::none());

//===-------------------------------------------------------------------===//
// OperationType
//===-------------------------------------------------------------------===//

auto operationType =
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
operationType.def_classmethod(
"get",
[](py::object cls, const std::string &operationName, MlirContext ctx) {
MlirStringRef cOperationName =
mlirStringRefCreate(operationName.data(), operationName.size());
return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
},
"Get an instance of OperationType for the given kind in the given "
"context",
py::arg("cls"), py::arg("operation_name"),
py::arg("context") = py::none());
operationType.def_property_readonly(
"operation_name",
[](MlirType type) {
MlirStringRef operationName =
mlirTransformOperationTypeGetOperationName(type);
return py::str(operationName.data, operationName.length);
},
"Get the name of the payload operation accepted by the handle.");
}

PYBIND11_MODULE(_mlirDialectsTransform, m) {
m.doc() = "MLIR Transform dialect.";
populateDialectTransformSubmodule(m);
}
9 changes: 9 additions & 0 deletions mlir/lib/CAPI/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITensor
MLIRTensorDialect
)

add_mlir_upstream_c_api_library(MLIRCAPITransformDialect
Transform.cpp

PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRTransformDialect
)

add_mlir_upstream_c_api_library(MLIRCAPIQuant
Quant.cpp

Expand Down
48 changes: 48 additions & 0 deletions mlir/lib/CAPI/Dialect/Transform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//===- Transform.cpp - C Interface for Transform dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"

using namespace mlir;

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
transform::TransformDialect)

//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//

bool mlirTypeIsATransformAnyOpType(MlirType type) {
return unwrap(type).isa<transform::AnyOpType>();
}

MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}

//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//

bool mlirTypeIsATransformOperationType(MlirType type) {
return unwrap(type).isa<transform::OperationType>();
}

MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef operationName) {
return wrap(
transform::OperationType::get(unwrap(ctx), unwrap(operationName)));
}

MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
return wrap(unwrap(type).cast<transform::OperationType>().getOperationName());
}
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,7 @@ void transform::TransformDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
addTypesChecked<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
>();
initializeTypes();

pdl::OperationType::attachInterface<
PDLOperationTypeTransformTypeInterfaceImpl>(*getContext());
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,31 @@ generatedTypePrinter(Type def, AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"

void transform::TransformDialect::initializeTypes() {
addTypesChecked<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
>();
}

DiagnosedSilenceableFailure
transform::AnyOpType::checkPayload(Location loc,
ArrayRef<Operation *> payload) const {
return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
transform::OperationType::checkPayload(Location loc,
ArrayRef<Operation *> payload) const {
OperationName opName(getOperationName(), loc.getContext());
for (Operation *op : payload) {
if (opName != op->getName()) {
DiagnosedSilenceableFailure diag =
emitSilenceableError(loc) << "incompatible payload operation name";
diag.attachNote(op->getLoc()) << "payload operation";
return diag;
}
}

return DiagnosedSilenceableFailure::success();
}
14 changes: 14 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ declare_mlir_dialect_python_bindings(
SOURCES
dialects/_transform_ops_ext.py
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform)

declare_mlir_dialect_extension_python_bindings(
Expand Down Expand Up @@ -353,6 +354,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
MLIRCAPISparseTensor
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MODULE_NAME _mlirDialectsTransform
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectTransform.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPITransformDialect
)

declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect
Expand Down
26 changes: 26 additions & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional

from mlir.ir import Type, Context


class AnyOpType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...

@staticmethod
def get(context: Optional[Context] = None) -> AnyOpType: ...


class OperationType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...

@staticmethod
def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ...

@property
def operation_name(self) -> str: ...
10 changes: 10 additions & 0 deletions mlir/python/mlir/dialects/_transform_ops_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]):
return FlatSymbolRefAttr.get(value)


class CastOp:

def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
result_type,
_get_op_result_or_value(target),
loc=loc,
ip=ip)


class GetClosestIsolatedParentOp:

def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ def _as_int(self):
return 2

from .._transform_ops_gen import *
from ..._mlir_libs._mlirDialectsTransform import *
14 changes: 11 additions & 3 deletions mlir/test/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ _add_capi_test_executable(mlir-capi-pass-test
MLIRCAPITransforms
)

_add_capi_test_executable(mlir-capi-pdl-test
pdl.c
LINK_LIBS PRIVATE
MLIRCAPIIR
MLIRCAPIRegisterEverything
MLIRCAPIPDL
)

_add_capi_test_executable(mlir-capi-sparse-tensor-test
sparse_tensor.c
LINK_LIBS PRIVATE
Expand All @@ -70,10 +78,10 @@ _add_capi_test_executable(mlir-capi-quant-test
MLIRCAPIQuant
)

_add_capi_test_executable(mlir-capi-pdl-test
pdl.c
_add_capi_test_executable(mlir-capi-transform-test
transform.c
LINK_LIBS PRIVATE
MLIRCAPIIR
MLIRCAPIRegisterEverything
MLIRCAPIPDL
MLIRCAPITransformDialect
)

0 comments on commit 3e1f6d0

Please sign in to comment.