diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a5bd80e9d6b8b..4726656869d62 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -782,6 +782,64 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, + mlir::UnrealizedConversionCastOp castOp) { + raw_ostream &os = emitter.ostream(); + Operation &op = *castOp.getOperation(); + + if (castOp.getResults().size() != 1 || castOp.getOperands().size() != 1) { + return castOp.emitOpError( + "expected single result and single operand for conversion cast"); + } + + Type destType = castOp.getResult(0).getType(); + + auto srcPtrType = + mlir::dyn_cast(castOp.getOperand(0).getType()); + auto destArrayType = mlir::dyn_cast(destType); + + if (srcPtrType && destArrayType) { + + // Emit declaration: (*v13)[dims] = + if (failed(emitter.emitType(op.getLoc(), destArrayType.getElementType()))) + return failure(); + os << " (*" << emitter.getOrCreateName(op.getResult(0)) << ")"; + for (int64_t dim : destArrayType.getShape()) + os << "[" << dim << "]"; + os << " = "; + + os << "("; + + // Emit the C++ type for "datatype (*)[dim1][dim2]..." + if (failed(emitter.emitType(op.getLoc(), destArrayType.getElementType()))) + return failure(); + + os << " (*)"; // Pointer to array + + for (int64_t dim : destArrayType.getShape()) { + os << "[" << dim << "]"; + } + os << ")"; + if (failed(emitter.emitOperand(castOp.getOperand(0)))) + return failure(); + + return success(); + } + + // Fallback to generic C-style cast for other cases + if (failed(emitter.emitAssignPrefix(op))) + return failure(); + + os << "("; + if (failed(emitter.emitType(op.getLoc(), destType))) + return failure(); + os << ")"; + if (failed(emitter.emitOperand(castOp.getOperand(0)))) + return failure(); + + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::ApplyOp applyOp) { raw_ostream &os = emitter.ostream(); @@ -1291,7 +1349,29 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { std::string out; llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getValue()); + Value baseValue = op.getValue(); + + // Check if the baseValue (%arg1) is a result of UnrealizedConversionCastOp + // that converts a pointer to an array type. + if (auto castOp = dyn_cast_or_null( + baseValue.getDefiningOp())) { + auto destArrayType = + mlir::dyn_cast(castOp.getResult(0).getType()); + auto srcPtrType = + mlir::dyn_cast(castOp.getOperand(0).getType()); + + // If it's a pointer being cast to an array, emit (*varName) + if (srcPtrType && destArrayType) { + ss << "(*" << getOrCreateName(baseValue) << ")"; + } else { + // Fallback if the cast is not our specific pointer-to-array case + ss << getOrCreateName(baseValue); + } + } else { + // Default behavior for a regular array or other base types + ss << getOrCreateName(baseValue); + } + for (auto index : op.getIndices()) { ss << "[" << getOrCreateName(index) << "]"; } @@ -1747,6 +1827,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { cacheDeferredOpResult(op.getResult(), getSubscriptName(op)); return success(); }) + .Case( + [&](auto op) { return printOperation(*this, op); }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); }); diff --git a/mlir/test/Target/Cpp/unrealized_conversion_cast.mlir b/mlir/test/Target/Cpp/unrealized_conversion_cast.mlir new file mode 100644 index 0000000000000..3971189218c39 --- /dev/null +++ b/mlir/test/Target/Cpp/unrealized_conversion_cast.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +// CHECK-LABEL: void builtin_cast +func.func @builtin_cast(%arg0: !emitc.ptr){ + // CHECK : float (*v2)[1][3][4][4] = (float (*)[1][3][4][4])v1 + %1 = builtin.unrealized_conversion_cast %arg0 : !emitc.ptr to !emitc.array<1x3x4x4xf32> +return +} + +// CHECK-LABEL: void builtin_cast_index +func.func @builtin_cast_index(%arg0: !emitc.size_t){ + // CHECK : size_t v2 = (size_t)v1 + %1 = builtin.unrealized_conversion_cast %arg0 : !emitc.size_t to index +return +}