diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h new file mode 100644 index 0000000000000..9cb43689d1ce6 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h @@ -0,0 +1,20 @@ +//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H +#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H + +namespace mlir { +class RewritePatternSet; +class TypeConverter; + +void populateArithToEmitCPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h new file mode 100644 index 0000000000000..6b98fed7185ea --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h @@ -0,0 +1,21 @@ +//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H +#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTARITHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 81f69210fade8..f2aa4fb535402 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -13,6 +13,7 @@ #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" +#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 94fc7a7d2194b..bd81cc6d5323b 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -133,6 +133,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> { ]; } +//===----------------------------------------------------------------------===// +// ArithToEmitC +//===----------------------------------------------------------------------===// + +def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> { + let summary = "Convert Arith dialect to EmitC dialect"; + let dependentDialects = ["emitc::EmitCDialect"]; +} + //===----------------------------------------------------------------------===// // ArithToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp new file mode 100644 index 0000000000000..6909534d4790f --- /dev/null +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -0,0 +1,60 @@ +//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert the Arith dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +namespace { +template +class ArithOpConversion final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.template replaceOpWithNewOp(arithOp, arithOp.getType(), + adaptor.getOperands()); + + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + + // clang-format off + patterns.add< + ArithOpConversion, + ArithOpConversion, + ArithOpConversion, + ArithOpConversion + >(typeConverter, ctx); + // clang-format on +} diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp new file mode 100644 index 0000000000000..b377c063a7aa0 --- /dev/null +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp @@ -0,0 +1,53 @@ +//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert the Arith dialect to the EmitC +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTARITHTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertArithToEmitC + : public impl::ConvertArithToEmitCBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertArithToEmitC::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalDialect(); + target.addLegalOp(); + + RewritePatternSet patterns(&getContext()); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + + populateArithToEmitCPatterns(typeConverter, patterns); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt new file mode 100644 index 0000000000000..a3784f47c3bc2 --- /dev/null +++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRArithToEmitC + ArithToEmitC.cpp + ArithToEmitCPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIREmitCDialect + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 9e421f7c49dbc..8219cf98575f3 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL) add_subdirectory(ArithCommon) add_subdirectory(ArithToAMDGPU) add_subdirectory(ArithToArmSME) +add_subdirectory(ArithToEmitC) add_subdirectory(ArithToLLVM) add_subdirectory(ArithToSPIRV) add_subdirectory(ArmNeon2dToIntr) diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir new file mode 100644 index 0000000000000..6a56474a5c48b --- /dev/null +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s + +func.func @arith_ops(%arg0: f32, %arg1: f32) { + // CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32 + %0 = arith.addf %arg0, %arg1 : f32 + // CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32 + %1 = arith.divf %arg0, %arg1 : f32 + // CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32 + %2 = arith.mulf %arg0, %arg1 : f32 + // CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32 + %3 = arith.subf %arg0, %arg1 : f32 + + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 8a8dd6e10c48a..2961b1574c49b 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4011,6 +4011,7 @@ cc_library( ":AffineToStandard", ":ArithToAMDGPU", ":ArithToArmSME", + ":ArithToEmitC", ":ArithToLLVM", ":ArithToSPIRV", ":ArmNeon2dToIntr", @@ -8156,6 +8157,32 @@ cc_library( ], ) +cc_library( + name = "ArithToEmitC", + srcs = glob([ + "lib/Conversion/ArithToEmitC/*.cpp", + "lib/Conversion/ArithToEmitC/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/ArithToEmitC/*.h", + ]), + includes = [ + "include", + "lib/Conversion/ArithToEmitC", + ], + deps = [ + ":ArithDialect", + ":ConversionPassIncGen", + ":EmitCDialect", + ":IR", + ":Pass", + ":Support", + ":TransformUtils", + ":Transforms", + "//llvm:Support", + ], +) + cc_library( name = "ArithToLLVM", srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),