-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][emitc] Fix creating pointer from constant array #162083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
When creating a pointer from a constant emitc array, check if it is constant. If it is, create the pointer as opaque<"const {type}">>. Move out C type string creation logic from TranslateToCpp.cpp to getCTypeString in EmitC.cpp as a shared utility function.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-emitc Author: Hendrik_Klug (Jimmy2027) ChangesWhen creating a pointer from a constant emitc array, check if it is constant. If it is, create the pointer as opaque<"const {type}">>. Move out C type string creation logic from TranslateToCpp.cpp to getCTypeString in EmitC.cpp as a shared utility function. Full diff: https://github.com/llvm/llvm-project/pull/162083.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index eb7ddeb3bfc54..614895977588a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -27,6 +27,7 @@
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
+#include <string>
#include <variant>
namespace mlir {
@@ -49,6 +50,10 @@ bool isSupportedFloatType(mlir::Type type);
/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);
+/// Convert an MLIR type to its C type string representation.
+/// Returns an empty string if the type cannot be represented as a C type.
+std::string getCTypeString(Type type);
+
// Either a literal string, or an placeholder for the fmtArgs.
struct Placeholder {};
using ReplacementItem = std::variant<StringRef, Placeholder>;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 2b7bdc9a7b7f8..7b05284818ecb 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -134,8 +135,26 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder,
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
+
+ // Determine the pointer type
+ Type pointerElementType = arrayType.getElementType();
+
+ // Check if the array comes from a const global
+ if (auto getGlobalOp = arrayValue.getDefiningOp<emitc::GetGlobalOp>()) {
+ auto globalOp = SymbolTable::lookupNearestSymbolFrom<emitc::GlobalOp>(
+ getGlobalOp, getGlobalOp.getNameAttr());
+ if (globalOp && globalOp.getConstSpecifier()) {
+ // Create a const pointer type using opaque type
+ std::string cTypeString = emitc::getCTypeString(pointerElementType);
+ if (!cTypeString.empty()) {
+ pointerElementType = emitc::OpaqueType::get(builder.getContext(),
+ "const " + cTypeString);
+ }
+ }
+ }
+
emitc::ApplyOp ptr = emitc::ApplyOp::create(
- builder, loc, emitc::PointerType::get(arrayType.getElementType()),
+ builder, loc, emitc::PointerType::get(pointerElementType),
builder.getStringAttr("&"), subPtr);
return ptr;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 5c8564bca6f86..d07993fb5a986 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -139,6 +139,43 @@ bool mlir::emitc::isFundamentalType(Type type) {
isa<emitc::PointerType>(type);
}
+std::string mlir::emitc::getCTypeString(Type type) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
+ switch (intType.getWidth()) {
+ case 1:
+ return "bool";
+ case 8:
+ return intType.isUnsigned() ? "uint8_t" : "int8_t";
+ case 16:
+ return intType.isUnsigned() ? "uint16_t" : "int16_t";
+ case 32:
+ return intType.isUnsigned() ? "uint32_t" : "int32_t";
+ case 64:
+ return intType.isUnsigned() ? "uint64_t" : "int64_t";
+ default:
+ return "";
+ }
+ }
+ if (auto floatType = dyn_cast<FloatType>(type)) {
+ if (floatType.getWidth() == 16) {
+ if (isa<Float16Type>(type))
+ return "_Float16";
+ if (isa<BFloat16Type>(type))
+ return "__bf16";
+ return "";
+ }
+ if (floatType.getWidth() == 32)
+ return "float";
+ if (floatType.getWidth() == 64)
+ return "double";
+ return "";
+ }
+ if (auto opaqueType = dyn_cast<emitc::OpaqueType>(type))
+ return opaqueType.getValue().str();
+
+ return "";
+}
+
/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a5bd80e9d6b8b..16db8e1aaa12d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1792,40 +1792,15 @@ LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
}
LogicalResult CppEmitter::emitType(Location loc, Type type) {
- if (auto iType = dyn_cast<IntegerType>(type)) {
- switch (iType.getWidth()) {
- case 1:
- return (os << "bool"), success();
- case 8:
- case 16:
- case 32:
- case 64:
- if (shouldMapToUnsigned(iType.getSignedness()))
- return (os << "uint" << iType.getWidth() << "_t"), success();
- else
- return (os << "int" << iType.getWidth() << "_t"), success();
- default:
- return emitError(loc, "cannot emit integer type ") << type;
- }
- }
- if (auto fType = dyn_cast<FloatType>(type)) {
- switch (fType.getWidth()) {
- case 16: {
- if (llvm::isa<Float16Type>(type))
- return (os << "_Float16"), success();
- if (llvm::isa<BFloat16Type>(type))
- return (os << "__bf16"), success();
- else
- return emitError(loc, "cannot emit float type ") << type;
- }
- case 32:
- return (os << "float"), success();
- case 64:
- return (os << "double"), success();
- default:
- return emitError(loc, "cannot emit float type ") << type;
- }
- }
+ std::string cTypeString = emitc::getCTypeString(type);
+ if (!cTypeString.empty())
+ return (os << cTypeString), success();
+
+ // Handle integer and float cases that failed above
+ if (isa<IntegerType>(type))
+ return emitError(loc, "cannot emit integer type ") << type;
+ if (isa<FloatType>(type))
+ return emitError(loc, "cannot emit float type ") << type;
if (auto iType = dyn_cast<IndexType>(type))
return (os << "size_t"), success();
if (auto sType = dyn_cast<emitc::SizeTType>(type))
@@ -1854,10 +1829,6 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
}
if (auto tType = dyn_cast<TupleType>(type))
return emitTupleType(loc, tType.getTypes());
- if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
- os << oType.getValue();
- return success();
- }
if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, aType.getElementType())))
return failure();
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 2b4eda37903d4..97c06639bf35b 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -58,3 +58,25 @@ module @globals {
return
}
}
+
+// -----
+
+// CHECK-LABEL: const_global_copy
+module @const_global_copy {
+ memref.global "private" constant @const_data : memref<4xi8> = dense<[1, 2, 3, 4]>
+ // CHECK: emitc.global static const @const_data : !emitc.array<4xi8> = dense<[1, 2, 3, 4]>
+
+ func.func @copy_from_const_global() {
+ // CHECK: get_global @const_data : !emitc.array<4xi8>
+ %0 = memref.get_global @const_data : memref<4xi8>
+ // CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4xi8>
+ %1 = memref.alloca() : memref<4xi8>
+
+ // Verify that pointer from const global has const qualifier
+ // CHECK: apply "&"({{.*}}) : (!emitc.lvalue<i8>) -> !emitc.ptr<!emitc.opaque<"const int8_t">>
+ // CHECK: apply "&"({{.*}}) : (!emitc.lvalue<i8>) -> !emitc.ptr<i8>
+ // CHECK: call_opaque "memcpy"({{.*}}, {{.*}}, {{.*}}) : (!emitc.ptr<i8>, !emitc.ptr<!emitc.opaque<"const int8_t">>, !emitc.size_t) -> ()
+ memref.copy %0, %1 : memref<4xi8> to memref<4xi8>
+ return
+ }
+}
|
Hi @simon-camp, could you have a look at this? |
When creating a pointer from a constant emitc array, check if it is constant. If it is, create the pointer as opaque<"const {type}">>.
Move out C type string creation logic from TranslateToCpp.cpp to getCTypeString in EmitC.cpp as a shared utility function.