Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Add initial convert-memref-to-emitc pass #85389

Merged
merged 10 commits into from
Mar 21, 2024

Conversation

mgehre-amd
Copy link
Contributor

@mgehre-amd mgehre-amd commented Mar 15, 2024

This converts memref.alloca, memref.load & memref.store to emitc.variable, emitc.subscript and emitc.assign.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 15, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Gehre (mgehre-amd)

Changes

This translates memref types in func.func, func.call and func.return to emitc.array
and it translates memref.alloca, memref.load & memref.store to emitc.variable, emitc.subscript and emitc.assign.


Full diff: https://github.com/llvm/llvm-project/pull/85389.diff

9 Files Affected:

  • (added) mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h (+29)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+9)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt (+19)
  • (added) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+163)
  • (added) mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir (+40)
  • (added) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir (+47)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+28)
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
new file mode 100644
index 00000000000000..9815267a2f28b8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -0,0 +1,29 @@
+//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+class TypeConverter;
+
+#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
+
+void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+                                             TypeConverter &converter);
+
+std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index f2aa4fb535402d..8dde20099e7aa3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -45,6 +45,7 @@
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..7e7ee3a2f780f6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
+  let summary = "Convert MemRef dialect to EmitC dialect";
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 8219cf98575f3c..41ab7046b91ce3 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToSPIRV)
+add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(NVGPUToNVVM)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..7bcec4cbadfce4
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMemRefToEmitC
+  MemRefToEmitC.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIREmitCDialect
+  MLIRFuncDialect
+  MLIRFuncTransforms
+  MLIRMemRefDialect
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
new file mode 100644
index 00000000000000..b40bf8e9bf0df8
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -0,0 +1,163 @@
+//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Disallow all memrefs even though we only have conversions
+/// for memrefs with static shape right now to have good diagnostics.
+bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
+
+template <typename RangeT>
+bool areLegal(RangeT &&range) {
+  return llvm::all_of(range, [](Type type) { return isLegal(type); });
+}
+
+bool isLegal(Operation *op) {
+  return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
+}
+
+bool isSignatureLegal(FunctionType ty) {
+  return areLegal(ty.getInputs()) && areLegal(ty.getResults());
+}
+
+struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getType().hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "cannot transform alloca with dynamic shape");
+    }
+
+    if (op.getAlignment().value_or(1) > 1) {
+      // TODO: Allow alignment if it is not more than the natural alignment
+      // of the C array.
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "cannot transform alloca with alignment requirement");
+    }
+
+    auto resultTy = getTypeConverter()->convertType(op.getType());
+    auto noInit = emitc::OpaqueAttr::get(getContext(), "");
+    rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
+    return success();
+  }
+};
+
+struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
+                                                    operands.getIndices());
+    return success();
+  }
+};
+
+struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto subscript = rewriter.create<emitc::SubscriptOp>(
+        op.getLoc(), operands.getMemref(), operands.getIndices());
+    rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+                                                 operands.getValue());
+    return success();
+  }
+};
+
+struct ConvertMemRefToEmitCPass
+    : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+  void runOnOperation() override {
+    TypeConverter converter;
+    // Fallback for other types.
+    converter.addConversion([](Type type) { return type; });
+    populateMemRefToEmitCTypeConversion(converter);
+    converter.addConversion(
+        [&converter](FunctionType ty) -> std::optional<Type> {
+          SmallVector<Type> inputs;
+          if (failed(converter.convertTypes(ty.getInputs(), inputs)))
+            return std::nullopt;
+
+          SmallVector<Type> results;
+          if (failed(converter.convertTypes(ty.getResults(), results)))
+            return std::nullopt;
+
+          return FunctionType::get(ty.getContext(), inputs, results);
+        });
+
+    RewritePatternSet patterns(&getContext());
+    populateMemRefToEmitCConversionPatterns(patterns, converter);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    populateReturnOpTypeConversionPattern(patterns, converter);
+
+    ConversionTarget target(getContext());
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
+    target.addDynamicallyLegalDialect<func::FuncDialect>(
+        [](Operation *op) { return isLegal(op); });
+    target.addIllegalDialect<memref::MemRefDialect>();
+    target.addLegalDialect<emitc::EmitCDialect>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
+  typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
+    if (memRefType.hasStaticShape()) {
+      return emitc::ArrayType::get(memRefType.getShape(),
+                                   memRefType.getElementType());
+    }
+    return {};
+  });
+}
+
+void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+                                                   TypeConverter &converter) {
+  patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
+                                                         patterns.getContext());
+}
+
+std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
+  return std::make_unique<ConvertMemRefToEmitCPass>();
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
new file mode 100644
index 00000000000000..0d4e4139b85fb4
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
+
+// Unranked memrefs are not converted
+// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @memref_unranked(%arg0 : memref<*xf32>) {
+  return
+}
+
+// -----
+
+// Memrefs with dynamic shapes are not converted
+// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) {
+  return
+}
+
+// -----
+
+func.func @memref_op(%arg0 : memref<2x4xf32>) {
+  // expected-error@+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
+  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+  return
+}
+
+// -----
+
+func.func @alloca_with_dynamic_shape() {
+  %0 = index.constant 1
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+  %1 = memref.alloca(%0) : memref<4x?xf32>
+  return
+}
+
+// -----
+
+func.func @alloca_with_alignment() {
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+  %1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
+  return
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
new file mode 100644
index 00000000000000..8a11dcf7603c62
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: memref_arg
+// CHECK-SAME:  !emitc.array<32xf32>)
+func.func @memref_arg(%arg0 : memref<32xf32>) {
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: memref_return
+// CHECK-SAME:  %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32>
+func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> {
+// CHECK: return %[[arg0]] : !emitc.array<32xf32>
+  func.return %arg0 : memref<32xf32>
+}
+
+// CHECK-LABEL: memref_call
+// CHECK-SAME:  %[[arg0:.*]]: !emitc.array<32xf32>)
+func.func @memref_call(%arg0 : memref<32xf32>) {
+// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32>
+  func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: alloca
+func.func @alloca() {
+  // CHECK "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+  %0 = memref.alloca() : memref<4x8xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: memref_load_store
+// CHECK-SAME:  %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32>
+// CHECK-SAME:  %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) {
+  // CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32>
+  %0 = memref.load %in[%i, %j] : memref<4x8xf32>
+  // CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32>
+  // CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32
+  memref.store %0, %out[%i, %j] : memref<3x5xf32>
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a3243abde5bbc5..94f9b47092b725 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4168,6 +4168,7 @@ cc_library(
         ":MathToLLVM",
         ":MathToLibm",
         ":MathToSPIRV",
+        ":MemRefToEmitC",
         ":MemRefToLLVM",
         ":MemRefToSPIRV",
         ":NVGPUToNVVM",
@@ -8179,6 +8180,33 @@ cc_library(
         "//llvm:Support",
     ],
 )
+cc_library(
+    name = "MemRefToEmitC",
+    srcs = glob([
+        "lib/Conversion/MemRefToEmitC/*.cpp",
+        "lib/Conversion/MemRefToEmitC/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/MemRefToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/MemRefToEmitC",
+    ],
+    deps = [
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":FuncDialect",
+        ":FuncTransforms",
+        ":MemRefDialect",
+        ":IR",
+        ":Pass",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
 
 cc_library(
     name = "MemRefToLLVM",

This translates memref types in func.func, func.call and func.return
to emitc.array
and it translates memref.alloca, memref.load & memref.store to
emitc.variable, emitc.subscipt and emitc.assign.
Copy link
Member

@marbre marbre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pushing this forward. Some suggestions after a very quick first look.

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp Outdated Show resolved Hide resolved
mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h Outdated Show resolved Hide resolved
@llvmbot llvmbot added the bazel "Peripheral" support tier build system: utils/bazel label Mar 18, 2024
mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt Outdated Show resolved Hide resolved
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel Outdated Show resolved Hide resolved
@mgehre-amd mgehre-amd requested a review from marbre March 21, 2024 12:34
Copy link
Member

@marbre marbre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One NIT for the PR description / squash commit message, translates must be converts. See https://mlir.llvm.org/getting_started/Glossary/#translation vs. https://mlir.llvm.org/getting_started/Glossary/#conversion.

@mgehre-amd mgehre-amd merged commit 0aa6d57 into llvm:main Mar 21, 2024
3 of 4 checks passed
@mgehre-amd mgehre-amd deleted the mgehre.memref_to_emitc branch March 21, 2024 13:27
mgehre-amd added a commit to Xilinx/llvm-project that referenced this pull request Mar 22, 2024
This converts `memref.alloca`, `memref.load` & `memref.store` to
`emitc.variable`, `emitc.subscript` and `emitc.assign`.
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
This converts `memref.alloca`, `memref.load` & `memref.store` to
`emitc.variable`, `emitc.subscript` and `emitc.assign`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants