-
Notifications
You must be signed in to change notification settings - Fork 10.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir] add OperationType to the Transform dialect
Add a new OperationType handle type to the Transform dialect. This transform type is parameterized by the name of the payload operation it can point to. It is intended as a constraint on transformations that are only applicable to a specific kind of payload operations. If a transformation is applicable to a small set of operation classes, it can be wrapped into a transform op by using a disjunctive constraint, such as `Type<Or<[Transform_ConcreteOperation<"foo">.predicate, Transform_ConcreteOperation<"bar">.predicate]>>` for its operand without modifying this type. Broader sets of accepted operations should be modeled as specific types. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135586
- Loading branch information
Showing
21 changed files
with
451 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
//===-- mlir-c/Dialect/Transform.h - C API for Transform Dialect --*- C -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM | ||
// Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_C_DIALECT_TRANSFORM_H | ||
#define MLIR_C_DIALECT_TRANSFORM_H | ||
|
||
#include "mlir-c/IR.h" | ||
#include "mlir-c/Support.h" | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform); | ||
|
||
//===---------------------------------------------------------------------===// | ||
// AnyOpType | ||
//===---------------------------------------------------------------------===// | ||
|
||
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type); | ||
|
||
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); | ||
|
||
//===---------------------------------------------------------------------===// | ||
// OperationType | ||
//===---------------------------------------------------------------------===// | ||
|
||
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type); | ||
|
||
MLIR_CAPI_EXPORTED MlirType | ||
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName); | ||
|
||
MLIR_CAPI_EXPORTED MlirStringRef | ||
mlirTransformOperationTypeGetOperationName(MlirType type); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
|
||
#endif // MLIR_C_DIALECT_TRANSFORM_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir-c/Dialect/Transform.h" | ||
#include "mlir-c/IR.h" | ||
#include "mlir-c/Support.h" | ||
#include "mlir/Bindings/Python/PybindAdaptors.h" | ||
|
||
namespace py = pybind11; | ||
using namespace mlir; | ||
using namespace mlir::python; | ||
using namespace mlir::python::adaptors; | ||
|
||
void populateDialectTransformSubmodule(const pybind11::module &m) { | ||
//===-------------------------------------------------------------------===// | ||
// AnyOpType | ||
//===-------------------------------------------------------------------===// | ||
|
||
auto anyOpType = | ||
mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType); | ||
anyOpType.def_classmethod( | ||
"get", | ||
[](py::object cls, MlirContext ctx) { | ||
return cls(mlirTransformAnyOpTypeGet(ctx)); | ||
}, | ||
"Get an instance of AnyOpType in the given context.", py::arg("cls"), | ||
py::arg("context") = py::none()); | ||
|
||
//===-------------------------------------------------------------------===// | ||
// OperationType | ||
//===-------------------------------------------------------------------===// | ||
|
||
auto operationType = | ||
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType); | ||
operationType.def_classmethod( | ||
"get", | ||
[](py::object cls, const std::string &operationName, MlirContext ctx) { | ||
MlirStringRef cOperationName = | ||
mlirStringRefCreate(operationName.data(), operationName.size()); | ||
return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); | ||
}, | ||
"Get an instance of OperationType for the given kind in the given " | ||
"context", | ||
py::arg("cls"), py::arg("operation_name"), | ||
py::arg("context") = py::none()); | ||
operationType.def_property_readonly( | ||
"operation_name", | ||
[](MlirType type) { | ||
MlirStringRef operationName = | ||
mlirTransformOperationTypeGetOperationName(type); | ||
return py::str(operationName.data, operationName.length); | ||
}, | ||
"Get the name of the payload operation accepted by the handle."); | ||
} | ||
|
||
PYBIND11_MODULE(_mlirDialectsTransform, m) { | ||
m.doc() = "MLIR Transform dialect."; | ||
populateDialectTransformSubmodule(m); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
//===- Transform.cpp - C Interface for Transform dialect ------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir-c/Dialect/Transform.h" | ||
#include "mlir-c/Support.h" | ||
#include "mlir/CAPI/Registration.h" | ||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
#include "mlir/Dialect/Transform/IR/TransformTypes.h" | ||
|
||
using namespace mlir; | ||
|
||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform, | ||
transform::TransformDialect) | ||
|
||
//===---------------------------------------------------------------------===// | ||
// AnyOpType | ||
//===---------------------------------------------------------------------===// | ||
|
||
bool mlirTypeIsATransformAnyOpType(MlirType type) { | ||
return unwrap(type).isa<transform::AnyOpType>(); | ||
} | ||
|
||
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { | ||
return wrap(transform::AnyOpType::get(unwrap(ctx))); | ||
} | ||
|
||
//===---------------------------------------------------------------------===// | ||
// OperationType | ||
//===---------------------------------------------------------------------===// | ||
|
||
bool mlirTypeIsATransformOperationType(MlirType type) { | ||
return unwrap(type).isa<transform::OperationType>(); | ||
} | ||
|
||
MlirType mlirTransformOperationTypeGet(MlirContext ctx, | ||
MlirStringRef operationName) { | ||
return wrap( | ||
transform::OperationType::get(unwrap(ctx), unwrap(operationName))); | ||
} | ||
|
||
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) { | ||
return wrap(unwrap(type).cast<transform::OperationType>().getOperationName()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from typing import Optional | ||
|
||
from mlir.ir import Type, Context | ||
|
||
|
||
class AnyOpType(Type): | ||
@staticmethod | ||
def isinstance(type: Type) -> bool: ... | ||
|
||
@staticmethod | ||
def get(context: Optional[Context] = None) -> AnyOpType: ... | ||
|
||
|
||
class OperationType(Type): | ||
@staticmethod | ||
def isinstance(type: Type) -> bool: ... | ||
|
||
@staticmethod | ||
def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ... | ||
|
||
@property | ||
def operation_name(self) -> str: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.