Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"

#include <string>
#include <variant>

namespace mlir {
Expand All @@ -49,6 +50,10 @@ bool isSupportedFloatType(mlir::Type type);
/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);

/// Convert an MLIR type to its C type string representation.
/// Returns an empty string if the type cannot be represented as a C type.
std::string getCTypeString(Type type);

// Either a literal string, or an placeholder for the fmtArgs.
struct Placeholder {};
using ReplacementItem = std::variant<StringRef, Placeholder>;
Expand Down
21 changes: 20 additions & 1 deletion mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -134,8 +135,26 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder,
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));

// Determine the pointer type
Type pointerElementType = arrayType.getElementType();

// Check if the array comes from a const global
if (auto getGlobalOp = arrayValue.getDefiningOp<emitc::GetGlobalOp>()) {
auto globalOp = SymbolTable::lookupNearestSymbolFrom<emitc::GlobalOp>(
getGlobalOp, getGlobalOp.getNameAttr());
if (globalOp && globalOp.getConstSpecifier()) {
// Create a const pointer type using opaque type
std::string cTypeString = emitc::getCTypeString(pointerElementType);
if (!cTypeString.empty()) {
pointerElementType = emitc::OpaqueType::get(builder.getContext(),
"const " + cTypeString);
}
}
}

emitc::ApplyOp ptr = emitc::ApplyOp::create(
builder, loc, emitc::PointerType::get(arrayType.getElementType()),
builder, loc, emitc::PointerType::get(pointerElementType),
builder.getStringAttr("&"), subPtr);

return ptr;
Expand Down
37 changes: 37 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,43 @@ bool mlir::emitc::isFundamentalType(Type type) {
isa<emitc::PointerType>(type);
}

std::string mlir::emitc::getCTypeString(Type type) {
if (auto intType = dyn_cast<IntegerType>(type)) {
switch (intType.getWidth()) {
case 1:
return "bool";
case 8:
return intType.isUnsigned() ? "uint8_t" : "int8_t";
case 16:
return intType.isUnsigned() ? "uint16_t" : "int16_t";
case 32:
return intType.isUnsigned() ? "uint32_t" : "int32_t";
case 64:
return intType.isUnsigned() ? "uint64_t" : "int64_t";
default:
return "";
}
}
if (auto floatType = dyn_cast<FloatType>(type)) {
if (floatType.getWidth() == 16) {
if (isa<Float16Type>(type))
return "_Float16";
if (isa<BFloat16Type>(type))
return "__bf16";
return "";
}
if (floatType.getWidth() == 32)
return "float";
if (floatType.getWidth() == 64)
return "double";
return "";
}
if (auto opaqueType = dyn_cast<emitc::OpaqueType>(type))
return opaqueType.getValue().str();

return "";
}

/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
Expand Down
47 changes: 9 additions & 38 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1792,40 +1792,15 @@ LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
}

LogicalResult CppEmitter::emitType(Location loc, Type type) {
if (auto iType = dyn_cast<IntegerType>(type)) {
switch (iType.getWidth()) {
case 1:
return (os << "bool"), success();
case 8:
case 16:
case 32:
case 64:
if (shouldMapToUnsigned(iType.getSignedness()))
return (os << "uint" << iType.getWidth() << "_t"), success();
else
return (os << "int" << iType.getWidth() << "_t"), success();
default:
return emitError(loc, "cannot emit integer type ") << type;
}
}
if (auto fType = dyn_cast<FloatType>(type)) {
switch (fType.getWidth()) {
case 16: {
if (llvm::isa<Float16Type>(type))
return (os << "_Float16"), success();
if (llvm::isa<BFloat16Type>(type))
return (os << "__bf16"), success();
else
return emitError(loc, "cannot emit float type ") << type;
}
case 32:
return (os << "float"), success();
case 64:
return (os << "double"), success();
default:
return emitError(loc, "cannot emit float type ") << type;
}
}
std::string cTypeString = emitc::getCTypeString(type);
if (!cTypeString.empty())
return (os << cTypeString), success();

// Handle integer and float cases that failed above
if (isa<IntegerType>(type))
return emitError(loc, "cannot emit integer type ") << type;
if (isa<FloatType>(type))
return emitError(loc, "cannot emit float type ") << type;
if (auto iType = dyn_cast<IndexType>(type))
return (os << "size_t"), success();
if (auto sType = dyn_cast<emitc::SizeTType>(type))
Expand Down Expand Up @@ -1854,10 +1829,6 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
}
if (auto tType = dyn_cast<TupleType>(type))
return emitTupleType(loc, tType.getTypes());
if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
os << oType.getValue();
return success();
}
if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, aType.getElementType())))
return failure();
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,25 @@ module @globals {
return
}
}

// -----

// CHECK-LABEL: const_global_copy
module @const_global_copy {
memref.global "private" constant @const_data : memref<4xi8> = dense<[1, 2, 3, 4]>
// CHECK: emitc.global static const @const_data : !emitc.array<4xi8> = dense<[1, 2, 3, 4]>

func.func @copy_from_const_global() {
// CHECK: get_global @const_data : !emitc.array<4xi8>
%0 = memref.get_global @const_data : memref<4xi8>
// CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4xi8>
%1 = memref.alloca() : memref<4xi8>

// Verify that pointer from const global has const qualifier
// CHECK: apply "&"({{.*}}) : (!emitc.lvalue<i8>) -> !emitc.ptr<!emitc.opaque<"const int8_t">>
// CHECK: apply "&"({{.*}}) : (!emitc.lvalue<i8>) -> !emitc.ptr<i8>
// CHECK: call_opaque "memcpy"({{.*}}, {{.*}}, {{.*}}) : (!emitc.ptr<i8>, !emitc.ptr<!emitc.opaque<"const int8_t">>, !emitc.size_t) -> ()
memref.copy %0, %1 : memref<4xi8> to memref<4xi8>
return
}
}