diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 7ca31a83ecaa..59d28f65ed30 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -246,15 +246,20 @@ def DynamicCastOp : CIR_Op<"dyn_cast"> { cast-to-complete operation. }]; - let arguments = (ins DynamicCastKind:$kind, - RecordPtr:$src, - OptionalAttr:$info, - UnitAttr:$relative_layout); - let results = (outs CIR_PointerType:$result); + let arguments = (ins + DynamicCastKind:$kind, + CIR_PtrToRecordType:$src, + OptionalAttr:$info, + UnitAttr:$relative_layout + ); + + let results = (outs + CIR_PtrToAnyOf<[CIR_VoidType, CIR_RecordType]>:$result + ); let assemblyFormat = [{ `(` - $kind `,` $src `:` type($src) + $kind `,` $src `:` qualified(type($src)) (`,` qualified($info)^)? (`relative_layout` $relative_layout^)? `)` @@ -273,8 +278,6 @@ def DynamicCastOp : CIR_Op<"dyn_cast"> { return getType().isVoidPtr(); } }]; - - let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -3037,7 +3040,7 @@ def GetRuntimeMemberOp : CIR_Op<"get_runtime_member"> { }]; let arguments = (ins - Arg:$addr, + Arg:$addr, Arg:$member); let results = (outs Res:$result); @@ -3091,7 +3094,7 @@ def GetMethodOp : CIR_Op<"get_method"> { method. }]; - let arguments = (ins CIR_MethodType:$method, RecordPtr:$object); + let arguments = (ins CIR_MethodType:$method, CIR_PtrToRecordType:$object); let results = (outs FuncPtr:$callee, CIR_VoidPtrType:$adjusted_this); let assemblyFormat = [{ diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td index 2604680d9e4a..fd2d75b7f746 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td @@ -31,6 +31,18 @@ class CIR_ConfinedType preds, string summary = ""> : Type]>, summary, type.cppType>; +// Generates a type summary. +// - For a single type: returns its summary. +// - For multiple types: returns `any of `. +class CIR_TypeSummaries types> { + assert !not(!empty(types)), "expects non-empty list of types"; + + list summaries = !foreach(type, types, type.summary); + string joined = !interleave(summaries, ", "); + + string value = !if(!eq(!size(types), 1), joined, "any of " # joined); +} + //===----------------------------------------------------------------------===// // IntType predicates //===----------------------------------------------------------------------===// @@ -151,6 +163,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType], def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">; +//===----------------------------------------------------------------------===// +// Record Type predicates +//===----------------------------------------------------------------------===// + +def CIR_AnyRecordType : CIR_TypeBase<"::cir::RecordType", "record type">; + //===----------------------------------------------------------------------===// // Pointer Type predicates //===----------------------------------------------------------------------===// @@ -176,9 +194,14 @@ class CIR_PtrToPtrTo class CIR_PointeePred : SubstLeaves<"$_self", "::mlir::cast<::cir::PointerType>($_self).getPointee()", pred>; -class CIR_PtrToType - : CIR_ConfinedType], - "pointer to " # type.summary>; +class CIR_PtrToAnyOf types, string summary = ""> + : CIR_ConfinedType)>], + !if(!empty(summary), + "pointer to " # CIR_TypeSummaries.value, + summary)>; + +class CIR_PtrToType : CIR_PtrToAnyOf<[type]>; // Void pointer type constraints def CIR_VoidPtrType @@ -197,4 +220,6 @@ def CIR_PtrToIntOrFloatType : CIR_PtrToType; def CIR_PtrToComplexType : CIR_PtrToType; +def CIR_PtrToRecordType : CIR_PtrToType; + #endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index 6ccce77d2028..5bf4e2c432c9 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -505,15 +505,6 @@ def CIR_VoidType : CIR_Type<"Void", "void"> { // Constraints -// Pointer to record -def RecordPtr : Type< - And<[ - CPred<"::mlir::isa<::cir::PointerType>($_self)">, - CPred<"::mlir::isa<::cir::RecordType>(" - "::mlir::cast<::cir::PointerType>($_self).getPointee())"> - ]>, "!cir.record*"> { -} - // Pointer to exception info def ExceptionPtr : Type< And<[ diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 16238ece7b84..ee7d1f426b50 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -843,19 +843,6 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) { return {}; } -//===----------------------------------------------------------------------===// -// DynamicCastOp -//===----------------------------------------------------------------------===// - -LogicalResult cir::DynamicCastOp::verify() { - auto resultPointeeTy = mlir::cast(getType()).getPointee(); - if (!mlir::isa(resultPointeeTy)) - return emitOpError() - << "cir.dyn_cast must produce a void ptr or record ptr"; - - return mlir::success(); -} - //===----------------------------------------------------------------------===// // BaseDataMemberOp & DerivedDataMemberOp //===----------------------------------------------------------------------===// @@ -3650,8 +3637,7 @@ LogicalResult cir::InsertMemberOp::verify() { //===----------------------------------------------------------------------===// LogicalResult cir::GetRuntimeMemberOp::verify() { - auto recordTy = - cast(cast(getAddr().getType()).getPointee()); + auto recordTy = cast(getAddr().getType().getPointee()); auto memberPtrTy = getMember().getType(); if (recordTy != memberPtrTy.getClsTy()) { diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index a0968a7fb025..e2a4eb3bb1d2 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -894,7 +894,7 @@ module { module { cir.func @invalid_base_type(%arg0 : !cir.data_member) { %0 = cir.alloca !u32i, !cir.ptr, ["tmp"] {alignment = 4 : i64} - // expected-error@+1 {{'cir.get_runtime_member' op operand #0 must be !cir.record*}} + // expected-error@+1 {{'cir.get_runtime_member' op operand #0 must be pointer to record type}} %1 = cir.get_runtime_member %0[%arg0 : !cir.data_member] : !cir.ptr -> !cir.ptr cir.return }