diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index 2b3c4a3f00dc..399240161363 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -311,12 +311,12 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) { auto operandTy = mlir::cast(operand.getType()); - return create(loc, operandTy.getElementTy(), operand); + return create(loc, operandTy.getElementType(), operand); } mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) { auto operandTy = mlir::cast(operand.getType()); - return create(loc, operandTy.getElementTy(), operand); + return create(loc, operandTy.getElementType(), operand); } mlir::Value createComplexBinOp(mlir::Location loc, mlir::Value lhs, diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 60e0b6f238ce..62193b10fd9d 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1386,7 +1386,10 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> { }]; let results = (outs CIR_ComplexType:$result); - let arguments = (ins CIR_AnyIntOrFloat:$real, CIR_AnyIntOrFloat:$imag); + let arguments = (ins + CIR_AnyIntOrFloatType:$real, + CIR_AnyIntOrFloatType:$imag + ); let assemblyFormat = [{ $real `,` $imag @@ -1414,7 +1417,7 @@ def ComplexRealOp : CIR_Op<"complex.real", [Pure]> { ``` }]; - let results = (outs CIR_AnyIntOrFloat:$result); + let results = (outs CIR_AnyIntOrFloatType:$result); let arguments = (ins CIR_ComplexType:$operand); let assemblyFormat = [{ @@ -1439,7 +1442,7 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> { ``` }]; - let results = (outs CIR_AnyIntOrFloat:$result); + let results = (outs CIR_AnyIntOrFloatType:$result); let arguments = (ins CIR_ComplexType:$operand); let assemblyFormat = [{ @@ -5564,9 +5567,9 @@ def AtomicFetch : CIR_Op<"atomic.fetch", %res = cir.atomic.fetch(add, %ptr : !cir.ptr, %val : !s32i, seq_cst) : !s32i }]; - let results = (outs CIR_AnyIntOrFloat:$result); + let results = (outs CIR_AnyIntOrFloatType:$result); let arguments = (ins Arg:$ptr, - CIR_AnyIntOrFloat:$val, + CIR_AnyIntOrFloatType:$val, AtomicFetchKind:$binop, Arg:$mem_order, UnitAttr:$is_volatile, diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td index 42e5fa4c73c5..b94fd6b934a8 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td @@ -135,6 +135,16 @@ def CIR_AnyFloatType : AnyTypeOf<[ let cppFunctionName = "isAnyFloatingPointType"; } -def CIR_AnyIntOrFloat : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType]>; +def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType], + "integer or floating point type" +> { + let cppFunctionName = "isAnyIntegerOrFloatingPointType"; +} + +//===----------------------------------------------------------------------===// +// Complex Type predicates +//===----------------------------------------------------------------------===// + +def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">; #endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index 25f77842fa4b..3c37c0208eac 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -159,24 +159,35 @@ def CIR_ComplexType : CIR_Type<"Complex", "complex", CIR type that represents a C complex number. `cir.complex` models the C type `T _Complex`. - The parameter `elementTy` gives the type of the real and imaginary part of - the complex number. `elementTy` must be either a CIR integer type or a CIR + The type models complex values, per C99 6.2.5p11. It supports the C99 + complex float types as well as the GCC integer complex extensions. + + The parameter `elementType` gives the type of the real and imaginary part of + the complex number. `elementType` must be either a CIR integer type or a CIR floating-point type. }]; - let parameters = (ins "mlir::Type":$elementTy); + let parameters = (ins CIR_AnyIntOrFloatType:$elementType); let builders = [ - TypeBuilderWithInferredContext<(ins "mlir::Type":$elementTy), [{ - return $_get(elementTy.getContext(), elementTy); + TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ + return $_get(elementType.getContext(), elementType); }]>, ]; let assemblyFormat = [{ - `<` $elementTy `>` + `<` $elementType `>` }]; - let genVerifyDecl = 1; + let extraClassDeclaration = [{ + bool isFloatingPointComplex() const { + return isAnyFloatingPointType(getElementType()); + } + + bool isIntegerComplex() const { + return mlir::isa(getElementType()); + } + }]; } //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h index 825d4d20fd2e..ef5ed4258c6a 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h @@ -816,7 +816,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { auto srcPtrTy = mlir::cast(value.getType()); auto srcComplexTy = mlir::cast(srcPtrTy.getPointee()); return create( - loc, getPointerTo(srcComplexTy.getElementTy()), value); + loc, getPointerTo(srcComplexTy.getElementType()), value); } Address createRealPtr(mlir::Location loc, Address addr) { @@ -830,7 +830,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { auto srcPtrTy = mlir::cast(value.getType()); auto srcComplexTy = mlir::cast(srcPtrTy.getPointee()); return create( - loc, getPointerTo(srcComplexTy.getElementTy()), value); + loc, getPointerTo(srcComplexTy.getElementType()), value); } Address createImagPtr(mlir::Location loc, Address addr) { diff --git a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp index 754fa895afce..32251d1ad796 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp @@ -831,7 +831,7 @@ mlir::Value ComplexExprEmitter::VisitImaginaryLiteral(const ImaginaryLiteral *IL) { auto Loc = CGF.getLoc(IL->getExprLoc()); auto Ty = mlir::cast(CGF.convertType(IL->getType())); - auto ElementTy = Ty.getElementTy(); + auto ElementTy = Ty.getElementType(); mlir::TypedAttr RealValueAttr; mlir::TypedAttr ImagValueAttr; @@ -875,8 +875,7 @@ mlir::Value CIRGenFunction::emitPromotedComplexExpr(const Expr *E, mlir::Value CIRGenFunction::emitPromotedValue(mlir::Value result, QualType PromotionType) { - assert(mlir::isa( - mlir::cast(result.getType()).getElementTy()) && + assert(!mlir::cast(result.getType()).isIntegerComplex() && "integral complex will never be promoted"); return builder.createCast(cir::CastKind::float_complex, result, convertType(PromotionType)); @@ -884,8 +883,7 @@ mlir::Value CIRGenFunction::emitPromotedValue(mlir::Value result, mlir::Value CIRGenFunction::emitUnPromotedValue(mlir::Value result, QualType UnPromotionType) { - assert(mlir::isa( - mlir::cast(result.getType()).getElementTy()) && + assert(!mlir::cast(result.getType()).isIntegerComplex() && "integral complex will never be promoted"); return builder.createCast(cir::CastKind::float_complex, result, convertType(UnPromotionType)); diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp index 88d5bebd8e7e..9cbf20881a92 100644 --- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp @@ -357,12 +357,12 @@ LogicalResult FPAttr::verify(function_ref emitError, LogicalResult ComplexAttr::verify(function_ref emitError, cir::ComplexType type, mlir::TypedAttr real, mlir::TypedAttr imag) { - auto elemTy = type.getElementTy(); - if (real.getType() != elemTy) { + auto elemType = type.getElementType(); + if (real.getType() != elemType) { emitError() << "type of the real part does not match the complex type"; return failure(); } - if (imag.getType() != elemTy) { + if (imag.getType() != elemType) { emitError() << "type of the imaginary part does not match the complex type"; return failure(); } diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index d289f4a6addd..04d24fe7fcda 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -655,7 +655,7 @@ LogicalResult cir::CastOp::verify() { auto resComplexTy = mlir::dyn_cast(resType); if (!resComplexTy) return emitOpError() << "requires !cir.complex type for result"; - if (srcType != resComplexTy.getElementTy()) + if (srcType != resComplexTy.getElementType()) return emitOpError() << "requires source type match result element type"; return success(); } @@ -665,7 +665,7 @@ LogicalResult cir::CastOp::verify() { auto resComplexTy = mlir::dyn_cast(resType); if (!resComplexTy) return emitOpError() << "requires !cir.complex type for result"; - if (srcType != resComplexTy.getElementTy()) + if (srcType != resComplexTy.getElementType()) return emitOpError() << "requires source type match result element type"; return success(); } @@ -675,7 +675,7 @@ LogicalResult cir::CastOp::verify() { return emitOpError() << "requires !cir.complex type for source"; if (!mlir::isa(resType)) return emitOpError() << "requires !cir.float type for result"; - if (srcComplexTy.getElementTy() != resType) + if (srcComplexTy.getElementType() != resType) return emitOpError() << "requires source element type match result type"; return success(); } @@ -685,71 +685,66 @@ LogicalResult cir::CastOp::verify() { return emitOpError() << "requires !cir.complex type for source"; if (!mlir::isa(resType)) return emitOpError() << "requires !cir.int type for result"; - if (srcComplexTy.getElementTy() != resType) + if (srcComplexTy.getElementType() != resType) return emitOpError() << "requires source element type match result type"; return success(); } case cir::CastKind::float_complex_to_bool: { auto srcComplexTy = mlir::dyn_cast(srcType); - if (!srcComplexTy || - !mlir::isa(srcComplexTy.getElementTy())) + if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex()) return emitOpError() - << "requires !cir.complex type for source"; + << "requires floating point !cir.complex type for source"; if (!mlir::isa(resType)) return emitOpError() << "requires !cir.bool type for result"; return success(); } case cir::CastKind::int_complex_to_bool: { auto srcComplexTy = mlir::dyn_cast(srcType); - if (!srcComplexTy || !mlir::isa(srcComplexTy.getElementTy())) + if (!srcComplexTy || !srcComplexTy.isIntegerComplex()) return emitOpError() - << "requires !cir.complex type for source"; + << "requires floating point !cir.complex type for source"; if (!mlir::isa(resType)) return emitOpError() << "requires !cir.bool type for result"; return success(); } case cir::CastKind::float_complex: { auto srcComplexTy = mlir::dyn_cast(srcType); - if (!srcComplexTy || - !mlir::isa(srcComplexTy.getElementTy())) + if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex()) return emitOpError() - << "requires !cir.complex type for source"; + << "requires floating point !cir.complex type for source"; auto resComplexTy = mlir::dyn_cast(resType); - if (!resComplexTy || - !mlir::isa(resComplexTy.getElementTy())) + if (!resComplexTy || !resComplexTy.isFloatingPointComplex()) return emitOpError() - << "requires !cir.complex type for result"; + << "requires floating point !cir.complex type for result"; return success(); } case cir::CastKind::float_complex_to_int_complex: { auto srcComplexTy = mlir::dyn_cast(srcType); - if (!srcComplexTy || - !mlir::isa(srcComplexTy.getElementTy())) + if (!srcComplexTy || !srcComplexTy.isFloatingPointComplex()) return emitOpError() - << "requires !cir.complex type for source"; + << "requires floating point !cir.complex type for source"; auto resComplexTy = mlir::dyn_cast(resType); - if (!resComplexTy || !mlir::isa(resComplexTy.getElementTy())) - return emitOpError() << "requires !cir.complex type for result"; + if (!resComplexTy || !resComplexTy.isIntegerComplex()) + return emitOpError() << "requires integer !cir.complex type for result"; return success(); } case cir::CastKind::int_complex: { auto srcComplexTy = mlir::dyn_cast(srcType); - if (!srcComplexTy || !mlir::isa(srcComplexTy.getElementTy())) - return emitOpError() << "requires !cir.complex type for source"; + if (!srcComplexTy || !srcComplexTy.isIntegerComplex()) + return emitOpError() << "requires integer !cir.complex type for source"; auto resComplexTy = mlir::dyn_cast(resType); - if (!resComplexTy || !mlir::isa(resComplexTy.getElementTy())) - return emitOpError() << "requires !cir.complex type for result"; + if (!resComplexTy || !resComplexTy.isIntegerComplex()) + return emitOpError() << "requires integer !cir.complex type for result"; return success(); } case cir::CastKind::int_complex_to_float_complex: { auto srcComplexTy = mlir::dyn_cast(srcType); - if (!srcComplexTy || !mlir::isa(srcComplexTy.getElementTy())) - return emitOpError() << "requires !cir.complex type for source"; + if (!srcComplexTy || !srcComplexTy.isIntegerComplex()) + return emitOpError() << "requires integer !cir.complex type for source"; auto resComplexTy = mlir::dyn_cast(resType); - if (!resComplexTy || - !mlir::isa(resComplexTy.getElementTy())) + if (!resComplexTy || !resComplexTy.isFloatingPointComplex()) return emitOpError() - << "requires !cir.complex type for result"; + << "requires floating point !cir.complex type for result"; return success(); } case cir::CastKind::member_ptr_to_bool: { @@ -912,7 +907,7 @@ LogicalResult cir::DerivedMethodOp::verify() { //===----------------------------------------------------------------------===// LogicalResult cir::ComplexCreateOp::verify() { - if (getType().getElementTy() != getReal().getType()) { + if (getType().getElementType() != getReal().getType()) { emitOpError() << "operand type of cir.complex.create does not match its result type"; return failure(); @@ -945,7 +940,7 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// LogicalResult cir::ComplexRealOp::verify() { - if (getType() != getOperand().getType().getElementTy()) { + if (getType() != getOperand().getType().getElementType()) { emitOpError() << "cir.complex.real result type does not match operand type"; return failure(); } @@ -960,7 +955,7 @@ OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) { } LogicalResult cir::ComplexImagOp::verify() { - if (getType() != getOperand().getType().getElementTy()) { + if (getType() != getOperand().getType().getElementType()) { emitOpError() << "cir.complex.imag result type does not match operand type"; return failure(); } @@ -984,7 +979,7 @@ LogicalResult cir::ComplexRealPtrOp::verify() { auto operandPointeeTy = mlir::cast(operandPtrTy.getPointee()); - if (resultPointeeTy != operandPointeeTy.getElementTy()) { + if (resultPointeeTy != operandPointeeTy.getElementType()) { emitOpError() << "cir.complex.real_ptr result type does not match operand type"; return failure(); @@ -999,7 +994,7 @@ LogicalResult cir::ComplexImagPtrOp::verify() { auto operandPointeeTy = mlir::cast(operandPtrTy.getPointee()); - if (resultPointeeTy != operandPointeeTy.getElementTy()) { + if (resultPointeeTy != operandPointeeTy.getElementType()) { emitOpError() << "cir.complex.imag_ptr result type does not match operand type"; return failure(); diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index 7135fc76f07c..60742f76eb13 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -798,18 +798,6 @@ bool cir::isIntOrIntVectorTy(mlir::Type t) { // ComplexType Definitions //===----------------------------------------------------------------------===// -mlir::LogicalResult cir::ComplexType::verify( - llvm::function_ref emitError, - mlir::Type elementTy) { - if (!mlir::isa(elementTy)) { - emitError() << "element type of !cir.complex must be either a " - "floating-point type or an integer type"; - return failure(); - } - - return success(); -} - llvm::TypeSize cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout, mlir::DataLayoutEntryListRef params) const { @@ -818,8 +806,7 @@ cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout, // as an array type containing exactly two elements of the corresponding // real type. - auto elementTy = getElementTy(); - return dataLayout.getTypeSizeInBits(elementTy) * 2; + return dataLayout.getTypeSizeInBits(getElementType()) * 2; } uint64_t @@ -830,8 +817,7 @@ cir::ComplexType::getABIAlignment(const mlir::DataLayout &dataLayout, // as an array type containing exactly two elements of the corresponding // real type. - auto elementTy = getElementTy(); - return dataLayout.getTypeABIAlignment(elementTy); + return dataLayout.getTypeABIAlignment(getElementType()); } //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp index f6ad6d3244e4..545f2b725277 100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp @@ -526,7 +526,7 @@ static mlir::Value lowerComplexToComplexCast(MLIRContext &ctx, CastOp op) { auto src = op.getSrc(); auto dstComplexElemTy = - mlir::cast(op.getType()).getElementTy(); + mlir::cast(op.getType()).getElementType(); auto srcReal = builder.createComplexReal(op.getLoc(), src); auto srcImag = builder.createComplexReal(op.getLoc(), src); @@ -591,7 +591,7 @@ static mlir::Value buildComplexBinOpLibCall( llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { - auto elementTy = mlir::cast(ty.getElementTy()); + auto elementTy = mlir::cast(ty.getElementType()); auto libFuncName = libFuncNameGetter( llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics())); @@ -673,7 +673,7 @@ static mlir::Value lowerComplexMul(LoweringPreparePass &pass, auto ty = op.getType(); auto range = op.getRange(); - if (mlir::isa(ty.getElementTy()) || + if (mlir::isa(ty.getElementType()) || range == cir::ComplexRangeKind::Basic || range == cir::ComplexRangeKind::Improved || range == cir::ComplexRangeKind::Promoted) @@ -809,7 +809,7 @@ static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { auto ty = op.getType(); - if (mlir::isa(ty.getElementTy())) { + if (mlir::isa(ty.getElementType())) { auto range = op.getRange(); if (range == cir::ComplexRangeKind::Improved || (range == cir::ComplexRangeKind::Promoted && !op.getPromoted())) diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 788c3844295b..d6928425e650 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1779,7 +1779,7 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( mlir::cast(op.getValue()).getValue()); } else if (auto complexTy = mlir::dyn_cast(op.getType())) { auto complexAttr = mlir::cast(op.getValue()); - auto complexElemTy = complexTy.getElementTy(); + auto complexElemTy = complexTy.getElementType(); auto complexElemLLVMTy = typeConverter->convertType(complexElemTy); mlir::Attribute components[2]; @@ -4482,7 +4482,7 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter, converter.addConversion([&](cir::ComplexType type) -> mlir::Type { // A complex type is lowered to an LLVM struct that contains the real and // imaginary part as data fields. - mlir::Type elementTy = converter.convertType(type.getElementTy()); + mlir::Type elementTy = converter.convertType(type.getElementType()); mlir::Type structFields[2] = {elementTy, elementTy}; return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(), structFields); diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index 243e2f074481..2e6d3ea7d94c 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -1525,3 +1525,10 @@ cir.global external dsolocal @vfp = #cir.ptr : !cir.ptr + +// ----- + +// Verify that complex type does not accept arbitrary type + +// expected-error @below {{integer or floating point type}} +!complex = !cir.complex>