diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td index e91537186df59..599dd601fe7ac 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td @@ -44,6 +44,7 @@ def CIR_Dialect : Dialect { static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; } static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; } static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; } + static llvm::StringRef getExceptionAttrName() { return "exception"; } void registerAttributes(); void registerTypes(); diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 2b361ed0982c6..a335db42ee3c9 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2685,24 +2685,36 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> { empty. The first operand of this operation must be a pointer to the callee function. The rest operands are arguments to the callee function. + If the `cir.call` has the `exception` keyword, the call can throw. + Example: ```mlir + // Direct call %0 = cir.call @foo() + + // Indirect call + %20 = cir.call %18(%17) + + // Call that might throw + cir.call exception @foo_that_might_throw() -> () ``` }]; let results = (outs Optional:$result); - let arguments = commonArgs; + let arguments = !con((ins UnitAttr:$exception), commonArgs); let builders = [ OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType, - "mlir::ValueRange":$operands), [{ + "mlir::ValueRange":$operands, + CArg<"mlir::UnitAttr", "{}">:$exception), [{ $_state.addOperands(operands); if (callee) $_state.addAttribute("callee", callee); if (resType && !isa(resType)) $_state.addTypes(resType); + if (exception) + $_state.addAttribute("exception", exception); }]> ]; } diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 2d2ef422bfaef..e88d0bb41bd80 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -708,6 +708,12 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, mlir::FlatSymbolRefAttr calleeAttr; llvm::ArrayRef allResultTypes; + bool hasExceptions = false; + if (mlir::succeeded(parser.parseOptionalKeyword("exception"))) { + result.addAttribute("exception", parser.getBuilder().getUnitAttr()); + hasExceptions = true; + } + // If we cannot parse a string callee, it means this is an indirect call. if (!parser .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(), @@ -729,9 +735,15 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, if (parser.parseRParen()) return mlir::failure(); - if (parser.parseOptionalKeyword("nothrow").succeeded()) + llvm::SMLoc optionalNothrowLoc = parser.getCurrentLocation(); + if (parser.parseOptionalKeyword("nothrow").succeeded()) { + if (hasExceptions) + return parser.emitError( + optionalNothrowLoc, + "should have either `exception` or `nothrow`, but not both"); result.addAttribute(CIRDialect::getNoThrowAttrName(), mlir::UnitAttr::get(parser.getContext())); + } if (parser.parseOptionalKeyword("side_effect").succeeded()) { if (parser.parseLParen().failed()) @@ -767,13 +779,16 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, - mlir::OpAsmPrinter &printer, bool isNothrow, - cir::SideEffect sideEffect) { + mlir::OpAsmPrinter &printer, bool exception, + bool isNothrow, cir::SideEffect sideEffect) { printer << ' '; auto callLikeOp = mlir::cast(op); auto ops = callLikeOp.getArgOperands(); + if (exception) + printer << "exception "; + if (calleeSym) { // Direct calls printer.printAttributeWithoutType(calleeSym); @@ -793,10 +808,10 @@ static void printCallCommon(mlir::Operation *op, printer << ")"; } - printer.printOptionalAttrDict(op->getAttrs(), - {CIRDialect::getCalleeAttrName(), - CIRDialect::getNoThrowAttrName(), - CIRDialect::getSideEffectAttrName()}); + llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = { + CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(), + CIRDialect::getSideEffectAttrName(), CIRDialect::getExceptionAttrName()}; + printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); printer << " : "; printer.printFunctionalType(op->getOperands().getTypes(), @@ -811,8 +826,8 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser, void cir::CallOp::print(mlir::OpAsmPrinter &p) { mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr; cir::SideEffect sideEffect = getSideEffect(); - printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(), - sideEffect); + printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getException(), + getNothrow(), sideEffect); } static LogicalResult diff --git a/clang/test/CIR/IR/call.cir b/clang/test/CIR/IR/call.cir index 59f28be36846f..63e679936fdae 100644 --- a/clang/test/CIR/IR/call.cir +++ b/clang/test/CIR/IR/call.cir @@ -61,4 +61,32 @@ cir.func @f7(%arg0: !cir.ptr !s32i>>) -> !s32i { // CHECK-NEXT: cir.return %[[#ret]] : !s32i // CHECK-NEXT: } +cir.func private @division() -> !s32i + +cir.func dso_local @call_with_exception_attribute() { + cir.scope { + cir.try { + %0 = cir.call exception @division() : () -> !s32i + cir.yield + } catch all { + cir.yield + } + } + cir.return +} + +// CHECK: cir.func private @division() -> !s32i + +// CHECK: cir.func dso_local @call_with_exception_attribute() { +// CHECK-NEXT: cir.scope { +// CHECK-NEXT: cir.try { +// CHECK-NEXT: %[[CALL:.*]] = cir.call exception @division() : () -> !s32i +// CHECK-NEXT: cir.yield +// CHECK-NEXT: } catch all { +// CHECK-NEXT: cir.yield +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: cir.return +// CHECK-NEXT: } + } diff --git a/clang/test/CIR/IR/invalid-call.cir b/clang/test/CIR/IR/invalid-call.cir index a9c7e38f73af6..eb93b0558eac4 100644 --- a/clang/test/CIR/IR/invalid-call.cir +++ b/clang/test/CIR/IR/invalid-call.cir @@ -80,3 +80,22 @@ cir.func @f13() { cir.call @f12(%0) : (!s32i) -> () cir.return } + +// ----- + +!s32i = !cir.int + +cir.func private @division() -> !s32i + +cir.func dso_local @calling_division_inside_try_block() { + cir.scope { + cir.try { + // expected-error @below {{should have either `exception` or `nothrow`, but not both}} + %0 = cir.call exception @division() nothrow : () -> !s32i + cir.yield + } catch all { + cir.yield + } + } + cir.return +}