From 8f72155036be688a55eb32f178389ba4d92babfc Mon Sep 17 00:00:00 2001 From: Hendrik Klug Date: Mon, 6 Oct 2025 12:13:29 +0000 Subject: [PATCH] [mlir][emitc] Fix creating pointer from constant array When creating a pointer from a constant emitc array, check if it is constant. If it is, create the pointer as opaque<"const {type}">>. Move out C type string creation logic from TranslateToCpp.cpp to getCTypeString in EmitC.cpp as a shared utility function. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 5 ++ .../MemRefToEmitC/MemRefToEmitC.cpp | 21 ++++++++- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 37 +++++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 47 ++++--------------- .../MemRefToEmitC/memref-to-emitc.mlir | 22 +++++++++ 5 files changed, 93 insertions(+), 39 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index eb7ddeb3bfc54..614895977588a 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -27,6 +27,7 @@ #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc" #include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc" +#include #include namespace mlir { @@ -49,6 +50,10 @@ bool isSupportedFloatType(mlir::Type type); /// Determines whether \p type is a emitc.size_t/ssize_t type. bool isPointerWideType(mlir::Type type); +/// Convert an MLIR type to its C type string representation. +/// Returns an empty string if the type cannot be represented as a C type. +std::string getCTypeString(Type type); + // Either a literal string, or an placeholder for the fmtArgs. struct Placeholder {}; using ReplacementItem = std::variant; diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 2b7bdc9a7b7f8..7b05284818ecb 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" @@ -134,8 +135,26 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder, llvm::SmallVector indices(arrayType.getRank(), zeroIndex); emitc::SubscriptOp subPtr = emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); + + // Determine the pointer type + Type pointerElementType = arrayType.getElementType(); + + // Check if the array comes from a const global + if (auto getGlobalOp = arrayValue.getDefiningOp()) { + auto globalOp = SymbolTable::lookupNearestSymbolFrom( + getGlobalOp, getGlobalOp.getNameAttr()); + if (globalOp && globalOp.getConstSpecifier()) { + // Create a const pointer type using opaque type + std::string cTypeString = emitc::getCTypeString(pointerElementType); + if (!cTypeString.empty()) { + pointerElementType = emitc::OpaqueType::get(builder.getContext(), + "const " + cTypeString); + } + } + } + emitc::ApplyOp ptr = emitc::ApplyOp::create( - builder, loc, emitc::PointerType::get(arrayType.getElementType()), + builder, loc, emitc::PointerType::get(pointerElementType), builder.getStringAttr("&"), subPtr); return ptr; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 5c8564bca6f86..d07993fb5a986 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -139,6 +139,43 @@ bool mlir::emitc::isFundamentalType(Type type) { isa(type); } +std::string mlir::emitc::getCTypeString(Type type) { + if (auto intType = dyn_cast(type)) { + switch (intType.getWidth()) { + case 1: + return "bool"; + case 8: + return intType.isUnsigned() ? "uint8_t" : "int8_t"; + case 16: + return intType.isUnsigned() ? "uint16_t" : "int16_t"; + case 32: + return intType.isUnsigned() ? "uint32_t" : "int32_t"; + case 64: + return intType.isUnsigned() ? "uint64_t" : "int64_t"; + default: + return ""; + } + } + if (auto floatType = dyn_cast(type)) { + if (floatType.getWidth() == 16) { + if (isa(type)) + return "_Float16"; + if (isa(type)) + return "__bf16"; + return ""; + } + if (floatType.getWidth() == 32) + return "float"; + if (floatType.getWidth() == 64) + return "double"; + return ""; + } + if (auto opaqueType = dyn_cast(type)) + return opaqueType.getValue().str(); + + return ""; +} + /// Check that the type of the initial value is compatible with the operations /// result type. static LogicalResult verifyInitializationAttribute(Operation *op, diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a5bd80e9d6b8b..16db8e1aaa12d 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1792,40 +1792,15 @@ LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, } LogicalResult CppEmitter::emitType(Location loc, Type type) { - if (auto iType = dyn_cast(type)) { - switch (iType.getWidth()) { - case 1: - return (os << "bool"), success(); - case 8: - case 16: - case 32: - case 64: - if (shouldMapToUnsigned(iType.getSignedness())) - return (os << "uint" << iType.getWidth() << "_t"), success(); - else - return (os << "int" << iType.getWidth() << "_t"), success(); - default: - return emitError(loc, "cannot emit integer type ") << type; - } - } - if (auto fType = dyn_cast(type)) { - switch (fType.getWidth()) { - case 16: { - if (llvm::isa(type)) - return (os << "_Float16"), success(); - if (llvm::isa(type)) - return (os << "__bf16"), success(); - else - return emitError(loc, "cannot emit float type ") << type; - } - case 32: - return (os << "float"), success(); - case 64: - return (os << "double"), success(); - default: - return emitError(loc, "cannot emit float type ") << type; - } - } + std::string cTypeString = emitc::getCTypeString(type); + if (!cTypeString.empty()) + return (os << cTypeString), success(); + + // Handle integer and float cases that failed above + if (isa(type)) + return emitError(loc, "cannot emit integer type ") << type; + if (isa(type)) + return emitError(loc, "cannot emit float type ") << type; if (auto iType = dyn_cast(type)) return (os << "size_t"), success(); if (auto sType = dyn_cast(type)) @@ -1854,10 +1829,6 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) { } if (auto tType = dyn_cast(type)) return emitTupleType(loc, tType.getTypes()); - if (auto oType = dyn_cast(type)) { - os << oType.getValue(); - return success(); - } if (auto aType = dyn_cast(type)) { if (failed(emitType(loc, aType.getElementType()))) return failure(); diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index 2b4eda37903d4..97c06639bf35b 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -58,3 +58,25 @@ module @globals { return } } + +// ----- + +// CHECK-LABEL: const_global_copy +module @const_global_copy { + memref.global "private" constant @const_data : memref<4xi8> = dense<[1, 2, 3, 4]> + // CHECK: emitc.global static const @const_data : !emitc.array<4xi8> = dense<[1, 2, 3, 4]> + + func.func @copy_from_const_global() { + // CHECK: get_global @const_data : !emitc.array<4xi8> + %0 = memref.get_global @const_data : memref<4xi8> + // CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4xi8> + %1 = memref.alloca() : memref<4xi8> + + // Verify that pointer from const global has const qualifier + // CHECK: apply "&"({{.*}}) : (!emitc.lvalue) -> !emitc.ptr> + // CHECK: apply "&"({{.*}}) : (!emitc.lvalue) -> !emitc.ptr + // CHECK: call_opaque "memcpy"({{.*}}, {{.*}}, {{.*}}) : (!emitc.ptr, !emitc.ptr>, !emitc.size_t) -> () + memref.copy %0, %1 : memref<4xi8> to memref<4xi8> + return + } +}