diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td index e91537186df59..34df9af7fc06d 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td @@ -44,6 +44,7 @@ def CIR_Dialect : Dialect { static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; } static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; } static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; } + static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; } void registerAttributes(); void registerTypes(); diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 2b361ed0982c6..8f3e25b3c9737 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2580,7 +2580,7 @@ def CIR_FuncOp : CIR_Op<"func", [ } //===----------------------------------------------------------------------===// -// CallOp +// CallOp and TryCallOp //===----------------------------------------------------------------------===// def CIR_SideEffect : CIR_I32EnumAttr< @@ -2707,6 +2707,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> { ]; } +def CIR_TryCallOp : CIR_CallOpBase<"try_call",[ + DeclareOpInterfaceMethods, + Terminator, AttrSizedOperandSegments +]> { + let summary = "try_call operation"; + + let description = [{ + Mostly similar to cir.call but requires two destination + branches, one for handling exceptions in case its thrown and + the other one to follow on regular control-flow. + + Example: + + ```mlir + // Direct call + %result = cir.try_call @division(%a, %b) ^continue, ^landing_pad + : (f32, f32) -> f32 + ``` + }]; + + let arguments = !con((ins + Variadic:$contOperands, + Variadic:$landingPadOperands + ), commonArgs); + + let results = (outs Optional:$result); + let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad); + + let skipDefaultBuilders = 1; + + let builders = [ + OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType, + "mlir::Block *":$cont, "mlir::Block *":$landing_pad, + CArg<"mlir::ValueRange", "{}">:$operands, + CArg<"mlir::ValueRange", "{}">:$contOperands, + CArg<"mlir::ValueRange", "{}">:$landingPadOperands, + CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{ + $_state.addOperands(operands); + if (callee) + $_state.addAttribute("callee", callee); + if (resType && !isa(resType)) + $_state.addTypes(resType); + + $_state.addAttribute("side_effect", + SideEffectAttr::get($_builder.getContext(), sideEffect)); + + // Handle branches + $_state.addOperands(contOperands); + $_state.addOperands(landingPadOperands); + // The TryCall ODS layout is: cont, landing_pad, operands. + llvm::copy(::llvm::ArrayRef({ + static_cast(contOperands.size()), + static_cast(landingPadOperands.size()), + static_cast(operands.size()) + }), + odsState.getOrAddProperties().operandSegmentSizes.begin()); + $_state.addSuccessors(cont); + $_state.addSuccessors(landing_pad); + }]>, + OpBuilder<(ins "mlir::Value":$ind_target, + "FuncType":$fn_type, + "mlir::Block *":$cont, "mlir::Block *":$landing_pad, + CArg<"mlir::ValueRange", "{}">:$operands, + CArg<"mlir::ValueRange", "{}">:$contOperands, + CArg<"mlir::ValueRange", "{}">:$landingPadOperands, + CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{ + ::llvm::SmallVector finalCallOperands({ind_target}); + finalCallOperands.append(operands.begin(), operands.end()); + $_state.addOperands(finalCallOperands); + + if (!fn_type.hasVoidReturn()) + $_state.addTypes(fn_type.getReturnType()); + + $_state.addAttribute("side_effect", + SideEffectAttr::get($_builder.getContext(), sideEffect)); + + // Handle branches + $_state.addOperands(contOperands); + $_state.addOperands(landingPadOperands); + // The TryCall ODS layout is: cont, landing_pad, operands. + llvm::copy(::llvm::ArrayRef({ + static_cast(contOperands.size()), + static_cast(landingPadOperands.size()), + static_cast(finalCallOperands.size()) + }), + odsState.getOrAddProperties().operandSegmentSizes.begin()); + $_state.addSuccessors(cont); + $_state.addSuccessors(landing_pad); + }]> + ]; +} + //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 2d2ef422bfaef..11074af3ef127 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -701,13 +701,78 @@ unsigned cir::CallOp::getNumArgOperands() { return this->getOperation()->getNumOperands(); } +static mlir::ParseResult +parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result, + llvm::SmallVectorImpl + &continueOperands, + llvm::SmallVectorImpl + &landingPadOperands, + llvm::SmallVectorImpl &continueTypes, + llvm::SmallVectorImpl &landingPadTypes, + llvm::SMLoc &continueOperandsLoc, + llvm::SMLoc &landingPadOperandsLoc) { + mlir::Block *continueSuccessor = nullptr; + mlir::Block *landingPadSuccessor = nullptr; + + if (parser.parseSuccessor(continueSuccessor)) + return mlir::failure(); + + if (mlir::succeeded(parser.parseOptionalLParen())) { + continueOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(continueOperands)) + return mlir::failure(); + if (parser.parseColon()) + return mlir::failure(); + + if (parser.parseTypeList(continueTypes)) + return mlir::failure(); + if (parser.parseRParen()) + return mlir::failure(); + } + + if (parser.parseComma()) + return mlir::failure(); + + if (parser.parseSuccessor(landingPadSuccessor)) + return mlir::failure(); + + if (mlir::succeeded(parser.parseOptionalLParen())) { + landingPadOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(landingPadOperands)) + return mlir::failure(); + if (parser.parseColon()) + return mlir::failure(); + + if (parser.parseTypeList(landingPadTypes)) + return mlir::failure(); + if (parser.parseRParen()) + return mlir::failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes)) + return mlir::failure(); + + result.addSuccessors(continueSuccessor); + result.addSuccessors(landingPadSuccessor); + return mlir::success(); +} + static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, - mlir::OperationState &result) { + mlir::OperationState &result, + bool hasDestinationBlocks = false) { llvm::SmallVector ops; llvm::SMLoc opsLoc; mlir::FlatSymbolRefAttr calleeAttr; llvm::ArrayRef allResultTypes; + // TryCall control flow related + llvm::SmallVector continueOperands; + llvm::SMLoc continueOperandsLoc; + llvm::SmallVector continueTypes; + llvm::SmallVector landingPadOperands; + llvm::SMLoc landingPadOperandsLoc; + llvm::SmallVector landingPadTypes; + // If we cannot parse a string callee, it means this is an indirect call. if (!parser .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(), @@ -729,6 +794,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, if (parser.parseRParen()) return mlir::failure(); + if (hasDestinationBlocks && + parseTryCallBranches(parser, result, continueOperands, landingPadOperands, + continueTypes, landingPadTypes, continueOperandsLoc, + landingPadOperandsLoc) + .failed()) { + return ::mlir::failure(); + } + if (parser.parseOptionalKeyword("nothrow").succeeded()) result.addAttribute(CIRDialect::getNoThrowAttrName(), mlir::UnitAttr::get(parser.getContext())); @@ -761,6 +834,24 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands)) return mlir::failure(); + if (hasDestinationBlocks) { + // The TryCall ODS layout is: cont, landing_pad, operands. + llvm::copy(::llvm::ArrayRef( + {static_cast(continueOperands.size()), + static_cast(landingPadOperands.size()), + static_cast(ops.size())}), + result.getOrAddProperties() + .operandSegmentSizes.begin()); + + if (parser.resolveOperands(continueOperands, continueTypes, + continueOperandsLoc, result.operands)) + return ::mlir::failure(); + + if (parser.resolveOperands(landingPadOperands, landingPadTypes, + landingPadOperandsLoc, result.operands)) + return ::mlir::failure(); + } + return mlir::success(); } @@ -768,7 +859,9 @@ static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, - cir::SideEffect sideEffect) { + cir::SideEffect sideEffect, + mlir::Block *cont = nullptr, + mlir::Block *landingPad = nullptr) { printer << ' '; auto callLikeOp = mlir::cast(op); @@ -782,8 +875,35 @@ static void printCallCommon(mlir::Operation *op, assert(indirectCallee); printer << indirectCallee; } + printer << "(" << ops << ")"; + if (cont) { + assert(landingPad && "expected two successors"); + auto tryCall = dyn_cast(op); + assert(tryCall && "regular calls do not branch"); + printer << ' ' << tryCall.getCont(); + if (!tryCall.getContOperands().empty()) { + printer << "("; + printer << tryCall.getContOperands(); + printer << ' ' << ":"; + printer << ' '; + printer << tryCall.getContOperands().getTypes(); + printer << ")"; + } + printer << ","; + printer << ' '; + printer << tryCall.getLandingPad(); + if (!tryCall.getLandingPadOperands().empty()) { + printer << "("; + printer << tryCall.getLandingPadOperands(); + printer << ' ' << ":"; + printer << ' '; + printer << tryCall.getLandingPadOperands().getTypes(); + printer << ")"; + } + } + if (isNothrow) printer << " nothrow"; @@ -793,10 +913,11 @@ static void printCallCommon(mlir::Operation *op, printer << ")"; } - printer.printOptionalAttrDict(op->getAttrs(), - {CIRDialect::getCalleeAttrName(), - CIRDialect::getNoThrowAttrName(), - CIRDialect::getSideEffectAttrName()}); + llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = { + CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(), + CIRDialect::getSideEffectAttrName(), + CIRDialect::getOperandSegmentSizesAttrName()}; + printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); printer << " : "; printer.printFunctionalType(op->getOperands().getTypes(), @@ -878,6 +999,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return verifyCallCommInSymbolUses(*this, symbolTable); } +//===----------------------------------------------------------------------===// +// TryCallOp +//===----------------------------------------------------------------------===// + +mlir::OperandRange cir::TryCallOp::getArgOperands() { + if (isIndirect()) + return getArgs().drop_front(1); + return getArgs(); +} + +mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() { + mlir::MutableOperandRange args = getArgsMutable(); + if (isIndirect()) + return args.slice(1, args.size() - 1); + return args; +} + +mlir::Value cir::TryCallOp::getIndirectCall() { + assert(isIndirect()); + return getOperand(0); +} + +/// Return the operand at index 'i'. +Value cir::TryCallOp::getArgOperand(unsigned i) { + if (isIndirect()) + ++i; + return getOperand(i); +} + +/// Return the number of operands. +unsigned cir::TryCallOp::getNumArgOperands() { + if (isIndirect()) + return this->getOperation()->getNumOperands() - 1; + return this->getOperation()->getNumOperands(); +} + +LogicalResult +cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + return verifyCallCommInSymbolUses(*this, symbolTable); +} + +mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true); +} + +void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) { + mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr; + cir::SideEffect sideEffect = getSideEffect(); + printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(), + sideEffect, getCont(), getLandingPad()); +} + +mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + if (index == 0) + return SuccessorOperands(getContOperandsMutable()); + if (index == 1) + return SuccessorOperands(getLandingPadOperandsMutable()); + + // index == 2 + return SuccessorOperands(getArgOperandsMutable()); +} + //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 5a6193fa8d840..12f3db01c77d8 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1385,7 +1385,9 @@ static mlir::LogicalResult rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, mlir::ConversionPatternRewriter &rewriter, const mlir::TypeConverter *converter, - mlir::FlatSymbolRefAttr calleeAttr) { + mlir::FlatSymbolRefAttr calleeAttr, + mlir::Block *continueBlock = nullptr, + mlir::Block *landingPadBlock = nullptr) { llvm::SmallVector llvmResults; mlir::ValueTypeRange cirResults = op->getResultTypes(); auto call = cast(op); @@ -1414,7 +1416,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, llvmFnTy = converter->convertType( fn.getFunctionType()); assert(llvmFnTy && "Failed to convert function type"); - } else if (auto alias = mlir::cast(callee)) { + } else if (auto alias = mlir::dyn_cast(callee)) { // If the callee was an alias. In that case, // we need to prepend the address of the alias to the operands. The // way aliases work in the LLVM dialect is a little counter-intuitive. @@ -1452,17 +1454,21 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, converter->convertType(calleeFuncTy)); } - assert(!cir::MissingFeatures::opCallLandingPad()); - assert(!cir::MissingFeatures::opCallContinueBlock()); assert(!cir::MissingFeatures::opCallCallConv()); + if (landingPadBlock) { + rewriter.replaceOpWithNewOp( + op, llvmFnTy, calleeAttr, callOperands, continueBlock, + mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{}); + return mlir::success(); + } + auto newOp = rewriter.replaceOpWithNewOp( op, llvmFnTy, calleeAttr, callOperands); if (memoryEffects) newOp.setMemoryEffectsAttr(memoryEffects); newOp.setNoUnwind(noUnwind); newOp.setWillReturn(willReturn); - return mlir::success(); } @@ -1473,6 +1479,14 @@ mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite( getTypeConverter(), op.getCalleeAttr()); } +mlir::LogicalResult CIRToLLVMTryCallOpLowering::matchAndRewrite( + cir::TryCallOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter, + getTypeConverter(), op.getCalleeAttr(), + op.getCont(), op.getLandingPad()); +} + mlir::LogicalResult CIRToLLVMReturnAddrOpLowering::matchAndRewrite( cir::ReturnAddrOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir new file mode 100644 index 0000000000000..6c23d3add15c8 --- /dev/null +++ b/clang/test/CIR/IR/try-call.cir @@ -0,0 +1,31 @@ +// RUN: cir-opt %s --verify-roundtrip | FileCheck %s + +!s32i = !cir.int + +module { + +cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i + +cir.func @flatten_structure_with_try_call_op() { + %a = cir.const #cir.int<1> : !s32i + %b = cir.const #cir.int<2> : !s32i + %3 = cir.try_call @division(%a, %b) ^continue, ^landing_pad : (!s32i, !s32i) -> !s32i + ^continue: + cir.br ^landing_pad + ^landing_pad: + cir.return +} + +// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i + +// CHECK: cir.func @flatten_structure_with_try_call_op() { +// CHECK-NEXT: %[[CONST_0:.*]] = cir.const #cir.int<1> : !s32i +// CHECK-NEXT: %[[CONST_1:.*]] = cir.const #cir.int<2> : !s32i +// CHECK-NEXT: %[[CALL:.*]] = cir.try_call @division(%0, %1) ^[[CONTINUE:.*]], ^[[LANDING_PAD:.*]] : (!s32i, !s32i) -> !s32i +// CHECK-NEXT: ^[[CONTINUE]]: +// CHECK-NEXT: cir.br ^[[LANDING_PAD]] +// CHECK-NEXT: ^[[LANDING_PAD]]: +// CHECK-NEXT: cir.return +// CHECK-NEXT: } + +}