-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][emitc] Arith to EmitC conversion: constants #83798
Conversation
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Tina Jung (TinaAMD) ChangesAdd a conversion pass from Arith to EmitC. Add an initial conversion from Full diff: https://github.com/llvm/llvm-project/pull/83798.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
new file mode 100644
index 00000000000000..43322ac7f51f6c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -0,0 +1,22 @@
+//===- ArithToEmitC.h - Convert Arith to EmitC ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_ARITHTOEMITCCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateArithToEmitCConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 81f69210fade8d..f41400a633ef22 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -13,6 +13,7 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 94fc7a7d2194bf..358ac997fba2a3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -133,6 +133,18 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToEmitC
+//===----------------------------------------------------------------------===//
+
+def ArithToEmitCConversionPass : Pass<"convert-arith-to-emitc"> {
+ let summary = "Convert Arith ops to EmitC ops";
+ let description = [{
+ Convert `arith` operations to operations in the `emitc` dialect.
+ }];
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArithToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
new file mode 100644
index 00000000000000..648fd2b4af0b70
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -0,0 +1,104 @@
+//===- ArithToEmitC.cpp - Arith to EmitC conversion -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert arith ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOEMITCCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+static bool isConvertibleToEmitC(Type type) {
+ Type baseType = type;
+ if (auto tensorType = dyn_cast<TensorType>(type)) {
+ if (!tensorType.hasRank() || !tensorType.hasStaticShape()) {
+ return false;
+ }
+ baseType = tensorType.getElementType();
+ }
+
+ if (isa<IndexType>(baseType)) {
+ return true;
+ }
+
+ if (auto intType = dyn_cast<IntegerType>(baseType)) {
+ switch (intType.getWidth()) {
+ case 1:
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ }
+ return false;
+ }
+
+ if (auto floatType = dyn_cast<FloatType>(baseType)) {
+ return floatType.isF32() || floatType.isF64();
+ }
+
+ return false;
+}
+
+class ArithConstantOpConversionPattern
+ : public OpRewritePattern<arith::ConstantOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::ConstantOp arithConst,
+ PatternRewriter &rewriter) const override {
+
+ auto constantType = arithConst.getType();
+ if (!isConvertibleToEmitC(constantType)) {
+ return rewriter.notifyMatchFailure(arithConst.getLoc(),
+ "Type cannot be converted to emitc");
+ }
+
+ rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, constantType,
+ arithConst.getValue());
+ return success();
+ }
+};
+
+struct ConvertArithToEmitCPass
+ : public impl::ArithToEmitCConversionPassBase<ConvertArithToEmitCPass> {
+public:
+ void runOnOperation() override {
+
+ ConversionTarget target(getContext());
+ target.addIllegalDialect<arith::ArithDialect>();
+ target.addLegalDialect<emitc::EmitCDialect>();
+ RewritePatternSet patterns(&getContext());
+ populateArithToEmitCConversionPatterns(patterns);
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+void mlir::populateArithToEmitCConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ArithConstantOpConversionPattern>(patterns.getContext());
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..c1bb6d71310edb
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(ArithToEmitC
+ ArithToEmitC.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIREmitCDialect
+ MLIRArithDialect
+ MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9e421f7c49dbc3..8219cf98575f3c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToArmSME)
+add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir
new file mode 100644
index 00000000000000..b13c6506787c56
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
+
+func.func @arith_constant_complex_tensor() -> (tensor<complex<i32>>) {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
+ %c = arith.constant dense<(2, 2)> : tensor<complex<i32>>
+ return %c : tensor<complex<i32>>
+}
+
+// -----
+
+func.func @arith_constant_invalid_int_type() -> (i10) {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
+ %c = arith.constant 0 : i10
+ return %c : i10
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
new file mode 100644
index 00000000000000..2583dd832c314c
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s
+
+// CHECK-LABEL: arith_constants
+func.func @arith_constants() {
+ // CHECK: emitc.constant
+ // CHECK-SAME: value = 0 : index
+ %c_index = arith.constant 0 : index
+ // CHECK: emitc.constant
+ // CHECK-SAME: value = 0 : i32
+ %c_signless_int_32 = arith.constant 0 : i32
+ // CHECK: emitc.constant
+ // CHECK-SAME: value = 0.{{0+}}e+00 : f32
+ %c_float_32 = arith.constant 0.0 : f32
+ // CHECK: emitc.constant
+ // CHECK-SAME: value = dense<0> : tensor<i32>
+ %c_tensor_single_value = arith.constant dense<0> : tensor<i32>
+ // CHECK: emitc.constant
+ // CHECK-SAME: value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64>
+ %c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just send #84151, which adds what I already had for Arith to EmitC. The changes are structured a little bit different (e.g. with the patch it is allowed to register the patterns within other passes) and further adds a Bazel build configuration. I would therefore highly appreciate if you could apply your changes op top, in particular the conversion pattern for arith.constant
, after #84151 has landed.
* Add a conversion from `arith.constant` to `emitc.constant`. * Drop the translation for `arith.constant`s.
c2ebdad
to
a595b17
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please including dropping
":ArithDialect", |
Regarding the other changes, I need to free up some time to take a closer look, especially at the test.
Thanks, I dropped it! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please readd the dropped tests in const.mlir
file. Besides that this looks good.
Add tests for float/tensor/index types to the `emitc.constant` tests (they were previously tested on `arith.constant`s).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
@TinaAMD let me know if I should hit the "Squash and merge" button for you. |
Yes, that would be great. |
I've merged your PR but unfortunately it was associated with an anonymous GH email ( |
I changed it, thanks for the hint. |
* Add a conversion from `arith.constant` to `emitc.constant`. * Drop the translation for `arith.constant`s.
arith.constant
toemitc.constant
.arith.constant
s.