diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 12435119b98a1..ae209679ece6c 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -402,6 +402,57 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { return false; } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetFieldOp getFieldOp) { + emitter.cacheDeferredOpResult(getFieldOp.getResult(), + getFieldOp.getFieldName()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetGlobalOp getGlobalOp) { + emitter.cacheDeferredOpResult(getGlobalOp.getResult(), getGlobalOp.getName()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LiteralOp literalOp) { + emitter.cacheDeferredOpResult(literalOp.getResult(), literalOp.getValue()); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOp memberOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(memberOp.getOperand()); + ss << "." << memberOp.getMember(); + emitter.cacheDeferredOpResult(memberOp.getResult(), out); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOfPtrOp memberOfPtrOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(memberOfPtrOp.getOperand()); + ss << "->" << memberOfPtrOp.getMember(); + emitter.cacheDeferredOpResult(memberOfPtrOp.getResult(), out); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::SubscriptOp subscriptOp) { + std::string out; + llvm::raw_string_ostream ss(out); + ss << emitter.getOrCreateName(subscriptOp.getValue()); + for (auto index : subscriptOp.getIndices()) { + ss << "[" << emitter.getOrCreateName(index) << "]"; + } + emitter.cacheDeferredOpResult(subscriptOp.getResult(), out); + return success(); +} + static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); @@ -1761,41 +1812,19 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, emitc::DoOp, emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, - emitc::ForOp, emitc::FuncOp, emitc::GlobalOp, emitc::IfOp, - emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp, - emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, - emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp, - emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, - emitc::VerbatimOp>( + emitc::ForOp, emitc::FuncOp, emitc::GetFieldOp, + emitc::GetGlobalOp, emitc::GlobalOp, emitc::IfOp, + emitc::IncludeOp, emitc::LiteralOp, emitc::LoadOp, + emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, + emitc::MemberOfPtrOp, emitc::MemberOp, emitc::MulOp, + emitc::RemOp, emitc::ReturnOp, emitc::SubscriptOp, emitc::SubOp, + emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, + emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( [&](auto op) { return printOperation(*this, op); }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getName()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getFieldName()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getValue()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), getSubscriptName(op)); - return success(); - }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); });