Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][EmitC] Add Arith to EmitC conversions #84151

Merged
merged 4 commits into from
Mar 7, 2024
Merged

Conversation

marbre
Copy link
Member

@marbre marbre commented Mar 6, 2024

This adds patterns and a pass to convert the Arith dialect to EmitC. For now, this covers arithemtic binary ops operating on floating point types.

It is not checked within the patterns whether the types, such as the Tensor type, are supported in the respective EmitC operations. If unsupported types should be converted, the conversion will fail anyway because no legal EmitC operation can be created. This can clearly be improved in a follow up, also resulting in better error messages. Functions for such checks should not solely be used in the conversions and should also be (re)used in the verifier.

@llvmbot llvmbot added the mlir label Mar 6, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 6, 2024

@llvm/pr-subscribers-mlir

Author: Marius Brehler (marbre)

Changes

This adds patterns and a pass to convert the Arith dialect to EmitC. For now, this covers arithemtic binary ops operating on floating point types.

It is not checked within the patterns whether the types, such as the Tensor type, are supported in the respective EmitC operations. If unsupported types should be converted, the conversion will fail anyway because no legal EmitC operation can be created. This can clearly be improved in a follow up, also resulting in better error messages. Functions for such checks should not solely be used in the conversions and should also be (re)used in the verifier.


Full diff: https://github.com/llvm/llvm-project/pull/84151.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+9)
  • (added) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+59)
  • (added) mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp (+47)
  • (added) mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt (+16)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+27)
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 81f69210fade8d..f2aa4fb535402d 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 94fc7a7d2194bf..0e76069faf44c0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -133,6 +133,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertArithToEmitC : Pass<"convert-arith-to-emitc", "ModuleOp"> {
+  let summary = "Convert Arith dialect to EmitC dialect";
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArithToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
new file mode 100644
index 00000000000000..15942f54441424
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -0,0 +1,59 @@
+//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+template <typename ArithOp, typename EmitCOp>
+class ArithOpConversion final : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
+                                                  adaptor.getOperands());
+
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+
+  // clang-format off
+  patterns.add<
+    ArithOpConversion<arith::AddFOp, emitc::AddOp>,
+    ArithOpConversion<arith::DivFOp, emitc::DivOp>,
+    ArithOpConversion<arith::MulFOp, emitc::MulOp>,
+    ArithOpConversion<arith::SubFOp, emitc::SubOp>
+  >(ctx);
+  // clang-format on
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
new file mode 100644
index 00000000000000..f6a531eaaca242
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -0,0 +1,47 @@
+//===- ArithToEmitCPass.cpp - Func to EmitC Pass ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARITHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct ConvertArithToEmitC
+    : public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertArithToEmitC::runOnOperation() {
+  ConversionTarget target(getContext());
+
+  target.addLegalDialect<emitc::EmitCDialect>();
+
+  RewritePatternSet patterns(&getContext());
+  populateArithToEmitCPatterns(patterns);
+
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..a3784f47c3bc2d
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArithToEmitC
+  ArithToEmitC.cpp
+  ArithToEmitCPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIREmitCDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9e421f7c49dbc3..8219cf98575f3c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
 add_subdirectory(ArithToArmSME)
+add_subdirectory(ArithToEmitC)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8a8dd6e10c48aa..2961b1574c49b7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4011,6 +4011,7 @@ cc_library(
         ":AffineToStandard",
         ":ArithToAMDGPU",
         ":ArithToArmSME",
+        ":ArithToEmitC",
         ":ArithToLLVM",
         ":ArithToSPIRV",
         ":ArmNeon2dToIntr",
@@ -8156,6 +8157,32 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ArithToEmitC",
+    srcs = glob([
+        "lib/Conversion/ArithToEmitC/*.cpp",
+        "lib/Conversion/ArithToEmitC/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/ArithToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/ArithToEmitC",
+    ],
+    deps = [
+        ":ArithDialect",
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":IR",
+        ":Pass",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "ArithToLLVM",
     srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),

@mgehre-amd
Copy link
Contributor

I guess you will still add tests?

@marbre
Copy link
Member Author

marbre commented Mar 6, 2024

I guess you will still add tests?

Of course! Accidentally missed to include the file in my commit, thanks for catching this.

This adds patterns and a pass to convert the Arith dialect to EmitC.
For now, this covers arithemtic binary ops operating on floating point
types.

It is not checked within the patterns whether the types, such as the
Tensor type, are supported in the respective EmitC operations. If
unsupported types should be converted, the conversion will fail anyway
because no legal EmitC operation can be created. This can clearly be
improved in a follow up, also resulting in better error messages.
Functions for such checks should not solely be used in the conversions
and should also be (re)used in the verifier.
Copy link

github-actions bot commented Mar 6, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@marbre marbre requested a review from simon-camp March 7, 2024 08:46
@marbre marbre merged commit c40146c into llvm:main Mar 7, 2024
3 of 4 checks passed
@marbre marbre deleted the arith-to-emitc branch March 7, 2024 10:34
mgehre-amd pushed a commit to Xilinx/llvm-project that referenced this pull request Mar 11, 2024
This adds patterns and a pass to convert the Arith dialect to EmitC. For
now, this covers arithemtic binary ops operating on floating point
types.

It is not checked within the patterns whether the types, such as the
Tensor type, are supported in the respective EmitC operations. If
unsupported types should be converted, the conversion will fail anyway
because no legal EmitC operation can be created. This can clearly be
improved in a follow up, also resulting in better error messages.
Functions for such checks should not solely be used in the conversions
and should also be (re)used in the verifier.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants