-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Add initial convert-memref-to-emitc pass #85389
Conversation
@llvm/pr-subscribers-mlir Author: Matthias Gehre (mgehre-amd) ChangesThis translates Full diff: https://github.com/llvm/llvm-project/pull/85389.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
new file mode 100644
index 00000000000000..9815267a2f28b8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -0,0 +1,29 @@
+//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+class TypeConverter;
+
+#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
+
+void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &converter);
+
+std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index f2aa4fb535402d..8dde20099e7aa3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -45,6 +45,7 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..7e7ee3a2f780f6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MemRefToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
+ let summary = "Convert MemRef dialect to EmitC dialect";
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MemRefToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 8219cf98575f3c..41ab7046b91ce3 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(NVGPUToNVVM)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..7bcec4cbadfce4
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMemRefToEmitC
+ MemRefToEmitC.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIREmitCDialect
+ MLIRFuncDialect
+ MLIRFuncTransforms
+ MLIRMemRefDialect
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
new file mode 100644
index 00000000000000..b40bf8e9bf0df8
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -0,0 +1,163 @@
+//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Disallow all memrefs even though we only have conversions
+/// for memrefs with static shape right now to have good diagnostics.
+bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
+
+template <typename RangeT>
+bool areLegal(RangeT &&range) {
+ return llvm::all_of(range, [](Type type) { return isLegal(type); });
+}
+
+bool isLegal(Operation *op) {
+ return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
+}
+
+bool isSignatureLegal(FunctionType ty) {
+ return areLegal(ty.getInputs()) && areLegal(ty.getResults());
+}
+
+struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getType().hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ op.getLoc(), "cannot transform alloca with dynamic shape");
+ }
+
+ if (op.getAlignment().value_or(1) > 1) {
+ // TODO: Allow alignment if it is not more than the natural alignment
+ // of the C array.
+ return rewriter.notifyMatchFailure(
+ op.getLoc(), "cannot transform alloca with alignment requirement");
+ }
+
+ auto resultTy = getTypeConverter()->convertType(op.getType());
+ auto noInit = emitc::OpaqueAttr::get(getContext(), "");
+ rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
+ return success();
+ }
+};
+
+struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
+ operands.getIndices());
+ return success();
+ }
+};
+
+struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto subscript = rewriter.create<emitc::SubscriptOp>(
+ op.getLoc(), operands.getMemref(), operands.getIndices());
+ rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+ operands.getValue());
+ return success();
+ }
+};
+
+struct ConvertMemRefToEmitCPass
+ : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ void runOnOperation() override {
+ TypeConverter converter;
+ // Fallback for other types.
+ converter.addConversion([](Type type) { return type; });
+ populateMemRefToEmitCTypeConversion(converter);
+ converter.addConversion(
+ [&converter](FunctionType ty) -> std::optional<Type> {
+ SmallVector<Type> inputs;
+ if (failed(converter.convertTypes(ty.getInputs(), inputs)))
+ return std::nullopt;
+
+ SmallVector<Type> results;
+ if (failed(converter.convertTypes(ty.getResults(), results)))
+ return std::nullopt;
+
+ return FunctionType::get(ty.getContext(), inputs, results);
+ });
+
+ RewritePatternSet patterns(&getContext());
+ populateMemRefToEmitCConversionPatterns(patterns, converter);
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+
+ ConversionTarget target(getContext());
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
+ target.addDynamicallyLegalDialect<func::FuncDialect>(
+ [](Operation *op) { return isLegal(op); });
+ target.addIllegalDialect<memref::MemRefDialect>();
+ target.addLegalDialect<emitc::EmitCDialect>();
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
+ typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
+ if (memRefType.hasStaticShape()) {
+ return emitc::ArrayType::get(memRefType.getShape(),
+ memRefType.getElementType());
+ }
+ return {};
+ });
+}
+
+void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &converter) {
+ patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
+ patterns.getContext());
+}
+
+std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
+ return std::make_unique<ConvertMemRefToEmitCPass>();
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
new file mode 100644
index 00000000000000..0d4e4139b85fb4
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
+
+// Unranked memrefs are not converted
+// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @memref_unranked(%arg0 : memref<*xf32>) {
+ return
+}
+
+// -----
+
+// Memrefs with dynamic shapes are not converted
+// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) {
+ return
+}
+
+// -----
+
+func.func @memref_op(%arg0 : memref<2x4xf32>) {
+ // expected-error@+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
+ memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+ return
+}
+
+// -----
+
+func.func @alloca_with_dynamic_shape() {
+ %0 = index.constant 1
+ // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+ %1 = memref.alloca(%0) : memref<4x?xf32>
+ return
+}
+
+// -----
+
+func.func @alloca_with_alignment() {
+ // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+ %1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
+ return
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
new file mode 100644
index 00000000000000..8a11dcf7603c62
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: memref_arg
+// CHECK-SAME: !emitc.array<32xf32>)
+func.func @memref_arg(%arg0 : memref<32xf32>) {
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: memref_return
+// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32>
+func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> {
+// CHECK: return %[[arg0]] : !emitc.array<32xf32>
+ func.return %arg0 : memref<32xf32>
+}
+
+// CHECK-LABEL: memref_call
+// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>)
+func.func @memref_call(%arg0 : memref<32xf32>) {
+// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32>
+ func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: alloca
+func.func @alloca() {
+ // CHECK "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+ %0 = memref.alloca() : memref<4x8xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: memref_load_store
+// CHECK-SAME: %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32>
+// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) {
+ // CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32>
+ %0 = memref.load %in[%i, %j] : memref<4x8xf32>
+ // CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32>
+ // CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32
+ memref.store %0, %out[%i, %j] : memref<3x5xf32>
+ return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a3243abde5bbc5..94f9b47092b725 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4168,6 +4168,7 @@ cc_library(
":MathToLLVM",
":MathToLibm",
":MathToSPIRV",
+ ":MemRefToEmitC",
":MemRefToLLVM",
":MemRefToSPIRV",
":NVGPUToNVVM",
@@ -8179,6 +8180,33 @@ cc_library(
"//llvm:Support",
],
)
+cc_library(
+ name = "MemRefToEmitC",
+ srcs = glob([
+ "lib/Conversion/MemRefToEmitC/*.cpp",
+ "lib/Conversion/MemRefToEmitC/*.h",
+ ]),
+ hdrs = glob([
+ "include/mlir/Conversion/MemRefToEmitC/*.h",
+ ]),
+ includes = [
+ "include",
+ "lib/Conversion/MemRefToEmitC",
+ ],
+ deps = [
+ ":ConversionPassIncGen",
+ ":EmitCDialect",
+ ":FuncDialect",
+ ":FuncTransforms",
+ ":MemRefDialect",
+ ":IR",
+ ":Pass",
+ ":Support",
+ ":TransformUtils",
+ ":Transforms",
+ "//llvm:Support",
+ ],
+)
cc_library(
name = "MemRefToLLVM",
|
This translates memref types in func.func, func.call and func.return to emitc.array and it translates memref.alloca, memref.load & memref.store to emitc.variable, emitc.subscipt and emitc.assign.
7939c31
to
88cdc80
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pushing this forward. Some suggestions after a very quick first look.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One NIT for the PR description / squash commit message, translates must be converts. See https://mlir.llvm.org/getting_started/Glossary/#translation vs. https://mlir.llvm.org/getting_started/Glossary/#conversion.
This converts `memref.alloca`, `memref.load` & `memref.store` to `emitc.variable`, `emitc.subscript` and `emitc.assign`.
This converts `memref.alloca`, `memref.load` & `memref.store` to `emitc.variable`, `emitc.subscript` and `emitc.assign`.
This converts
memref.alloca
,memref.load
&memref.store
toemitc.variable
,emitc.subscript
andemitc.assign
.