diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 3df28f2ef3334..b53d9ebe70587 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -8627,6 +8627,8 @@ def err_conditional_vector_size : Error< def err_conditional_vector_element_size : Error< "vector condition type %0 and result type %1 do not have elements of the " "same size">; +def err_conditional_vector_scalar_type_unsupported : Error< + "scalar type %0 not supported with vector condition type %1">; def err_conditional_vector_has_void : Error< "GNU vector conditional operand cannot be %select{void|a throw expression}0">; def err_conditional_vector_operand_type diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 37598f8530c09..3b69d6f723045 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -8699,10 +8699,6 @@ class Sema final : public SemaBase { ExprResult &RHS, SourceLocation QuestionLoc); - QualType CheckSizelessVectorConditionalTypes(ExprResult &Cond, - ExprResult &LHS, ExprResult &RHS, - SourceLocation QuestionLoc); - //// Determines if a type is trivially relocatable /// according to the C++26 rules. // FIXME: This is in Sema because it requires diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp index 0fe242dce45e9..56c3f496569c8 100644 --- a/clang/lib/Sema/SemaExprCXX.cpp +++ b/clang/lib/Sema/SemaExprCXX.cpp @@ -5658,20 +5658,13 @@ static bool ConvertForConditional(Sema &Self, ExprResult &E, QualType T) { // extension. static bool isValidVectorForConditionalCondition(ASTContext &Ctx, QualType CondTy) { - if (!CondTy->isVectorType() && !CondTy->isExtVectorType()) + bool IsSVEVectorType = CondTy->isSveVLSBuiltinType(); + if (!CondTy->isVectorType() && !CondTy->isExtVectorType() && !IsSVEVectorType) return false; const QualType EltTy = - cast(CondTy.getCanonicalType())->getElementType(); - assert(!EltTy->isEnumeralType() && "Vectors cant be enum types"); - return EltTy->isIntegralType(Ctx); -} - -static bool isValidSizelessVectorForConditionalCondition(ASTContext &Ctx, - QualType CondTy) { - if (!CondTy->isSveVLSBuiltinType()) - return false; - const QualType EltTy = - cast(CondTy.getCanonicalType())->getSveEltType(Ctx); + IsSVEVectorType + ? cast(CondTy.getCanonicalType())->getSveEltType(Ctx) + : cast(CondTy.getCanonicalType())->getElementType(); assert(!EltTy->isEnumeralType() && "Vectors cant be enum types"); return EltTy->isIntegralType(Ctx); } @@ -5683,21 +5676,29 @@ QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS, RHS = DefaultFunctionArrayLvalueConversion(RHS.get()); QualType CondType = Cond.get()->getType(); - const auto *CondVT = CondType->castAs(); - QualType CondElementTy = CondVT->getElementType(); - unsigned CondElementCount = CondVT->getNumElements(); QualType LHSType = LHS.get()->getType(); - const auto *LHSVT = LHSType->getAs(); QualType RHSType = RHS.get()->getType(); - const auto *RHSVT = RHSType->getAs(); - QualType ResultType; + bool LHSIsVector = LHSType->isVectorType() || LHSType->isSizelessVectorType(); + bool RHSIsVector = RHSType->isVectorType() || RHSType->isSizelessVectorType(); + + auto GetVectorInfo = + [&](QualType Type) -> std::pair { + if (const auto *VT = Type->getAs()) + return std::make_pair(VT->getElementType(), + llvm::ElementCount::getFixed(VT->getNumElements())); + ASTContext::BuiltinVectorTypeInfo VectorInfo = + Context.getBuiltinVectorTypeInfo(Type->castAs()); + return std::make_pair(VectorInfo.ElementType, VectorInfo.EC); + }; + auto [CondElementTy, CondElementCount] = GetVectorInfo(CondType); - if (LHSVT && RHSVT) { - if (isa(CondVT) != isa(LHSVT)) { + QualType ResultType; + if (LHSIsVector && RHSIsVector) { + if (CondType->isExtVectorType() != LHSType->isExtVectorType()) { Diag(QuestionLoc, diag::err_conditional_vector_cond_result_mismatch) - << /*isExtVector*/ isa(CondVT); + << /*isExtVector*/ CondType->isExtVectorType(); return {}; } @@ -5708,12 +5709,17 @@ QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS, return {}; } ResultType = Context.getCommonSugaredType(LHSType, RHSType); - } else if (LHSVT || RHSVT) { - ResultType = CheckVectorOperands( - LHS, RHS, QuestionLoc, /*isCompAssign*/ false, /*AllowBothBool*/ true, - /*AllowBoolConversions*/ false, - /*AllowBoolOperation*/ true, - /*ReportInvalid*/ true); + } else if (LHSIsVector || RHSIsVector) { + if (CondType->isSizelessVectorType()) + ResultType = CheckSizelessVectorOperands(LHS, RHS, QuestionLoc, + /*IsCompAssign*/ false, + ArithConvKind::Conditional); + else + ResultType = CheckVectorOperands( + LHS, RHS, QuestionLoc, /*isCompAssign*/ false, /*AllowBothBool*/ true, + /*AllowBoolConversions*/ false, + /*AllowBoolOperation*/ true, + /*ReportInvalid*/ true); if (ResultType.isNull()) return {}; } else { @@ -5731,24 +5737,33 @@ QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS, << ResultElementTy; return {}; } - if (CondType->isExtVectorType()) - ResultType = - Context.getExtVectorType(ResultElementTy, CondVT->getNumElements()); - else - ResultType = Context.getVectorType( - ResultElementTy, CondVT->getNumElements(), VectorKind::Generic); - + if (CondType->isExtVectorType()) { + ResultType = Context.getExtVectorType(ResultElementTy, + CondElementCount.getFixedValue()); + } else if (CondType->isSizelessVectorType()) { + ResultType = Context.getScalableVectorType( + ResultElementTy, CondElementCount.getKnownMinValue()); + // There are not scalable vector type mappings for all element counts. + if (ResultType.isNull()) { + Diag(QuestionLoc, diag::err_conditional_vector_scalar_type_unsupported) + << ResultElementTy << CondType; + return {}; + } + } else { + ResultType = Context.getVectorType(ResultElementTy, + CondElementCount.getFixedValue(), + VectorKind::Generic); + } LHS = ImpCastExprToType(LHS.get(), ResultType, CK_VectorSplat); RHS = ImpCastExprToType(RHS.get(), ResultType, CK_VectorSplat); } - assert(!ResultType.isNull() && ResultType->isVectorType() && + assert(!ResultType.isNull() && + (ResultType->isVectorType() || ResultType->isSizelessVectorType()) && (!CondType->isExtVectorType() || ResultType->isExtVectorType()) && "Result should have been a vector type"); - auto *ResultVectorTy = ResultType->castAs(); - QualType ResultElementTy = ResultVectorTy->getElementType(); - unsigned ResultElementCount = ResultVectorTy->getNumElements(); + auto [ResultElementTy, ResultElementCount] = GetVectorInfo(ResultType); if (ResultElementCount != CondElementCount) { Diag(QuestionLoc, diag::err_conditional_vector_size) << CondType << ResultType; @@ -5767,90 +5782,6 @@ QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS, return ResultType; } -QualType Sema::CheckSizelessVectorConditionalTypes(ExprResult &Cond, - ExprResult &LHS, - ExprResult &RHS, - SourceLocation QuestionLoc) { - LHS = DefaultFunctionArrayLvalueConversion(LHS.get()); - RHS = DefaultFunctionArrayLvalueConversion(RHS.get()); - - QualType CondType = Cond.get()->getType(); - const auto *CondBT = CondType->castAs(); - QualType CondElementTy = CondBT->getSveEltType(Context); - llvm::ElementCount CondElementCount = - Context.getBuiltinVectorTypeInfo(CondBT).EC; - - QualType LHSType = LHS.get()->getType(); - const auto *LHSBT = - LHSType->isSveVLSBuiltinType() ? LHSType->getAs() : nullptr; - QualType RHSType = RHS.get()->getType(); - const auto *RHSBT = - RHSType->isSveVLSBuiltinType() ? RHSType->getAs() : nullptr; - - QualType ResultType; - - if (LHSBT && RHSBT) { - // If both are sizeless vector types, they must be the same type. - if (!Context.hasSameType(LHSType, RHSType)) { - Diag(QuestionLoc, diag::err_conditional_vector_mismatched) - << LHSType << RHSType; - return QualType(); - } - ResultType = LHSType; - } else if (LHSBT || RHSBT) { - ResultType = CheckSizelessVectorOperands(LHS, RHS, QuestionLoc, - /*IsCompAssign*/ false, - ArithConvKind::Conditional); - if (ResultType.isNull()) - return QualType(); - } else { - // Both are scalar so splat - QualType ResultElementTy; - LHSType = LHSType.getCanonicalType().getUnqualifiedType(); - RHSType = RHSType.getCanonicalType().getUnqualifiedType(); - - if (Context.hasSameType(LHSType, RHSType)) - ResultElementTy = LHSType; - else - ResultElementTy = UsualArithmeticConversions(LHS, RHS, QuestionLoc, - ArithConvKind::Conditional); - - if (ResultElementTy->isEnumeralType()) { - Diag(QuestionLoc, diag::err_conditional_vector_operand_type) - << ResultElementTy; - return QualType(); - } - - ResultType = Context.getScalableVectorType( - ResultElementTy, CondElementCount.getKnownMinValue()); - - LHS = ImpCastExprToType(LHS.get(), ResultType, CK_VectorSplat); - RHS = ImpCastExprToType(RHS.get(), ResultType, CK_VectorSplat); - } - - assert(!ResultType.isNull() && ResultType->isSveVLSBuiltinType() && - "Result should have been a vector type"); - auto *ResultBuiltinTy = ResultType->castAs(); - QualType ResultElementTy = ResultBuiltinTy->getSveEltType(Context); - llvm::ElementCount ResultElementCount = - Context.getBuiltinVectorTypeInfo(ResultBuiltinTy).EC; - - if (ResultElementCount != CondElementCount) { - Diag(QuestionLoc, diag::err_conditional_vector_size) - << CondType << ResultType; - return QualType(); - } - - if (Context.getTypeSize(ResultElementTy) != - Context.getTypeSize(CondElementTy)) { - Diag(QuestionLoc, diag::err_conditional_vector_element_size) - << CondType << ResultType; - return QualType(); - } - - return ResultType; -} - QualType Sema::CXXCheckConditionalOperands(ExprResult &Cond, ExprResult &LHS, ExprResult &RHS, ExprValueKind &VK, ExprObjectKind &OK, @@ -5864,14 +5795,10 @@ QualType Sema::CXXCheckConditionalOperands(ExprResult &Cond, ExprResult &LHS, bool IsVectorConditional = isValidVectorForConditionalCondition(Context, Cond.get()->getType()); - bool IsSizelessVectorConditional = - isValidSizelessVectorForConditionalCondition(Context, - Cond.get()->getType()); - // C++11 [expr.cond]p1 // The first expression is contextually converted to bool. if (!Cond.get()->isTypeDependent()) { - ExprResult CondRes = IsVectorConditional || IsSizelessVectorConditional + ExprResult CondRes = IsVectorConditional ? DefaultFunctionArrayLvalueConversion(Cond.get()) : CheckCXXBooleanCondition(Cond.get()); if (CondRes.isInvalid()) @@ -5940,9 +5867,6 @@ QualType Sema::CXXCheckConditionalOperands(ExprResult &Cond, ExprResult &LHS, if (IsVectorConditional) return CheckVectorConditionalTypes(Cond, LHS, RHS, QuestionLoc); - if (IsSizelessVectorConditional) - return CheckSizelessVectorConditionalTypes(Cond, LHS, RHS, QuestionLoc); - // WebAssembly tables are not allowed as conditional LHS or RHS. if (LTy->isWebAssemblyTableType() || RTy->isWebAssemblyTableType()) { Diag(QuestionLoc, diag::err_wasm_table_conditional_expression) diff --git a/clang/test/Sema/AArch64/sve-vector-conditional-op.cpp b/clang/test/Sema/AArch64/sve-vector-conditional-op.cpp new file mode 100644 index 0000000000000..0ca55e6268658 --- /dev/null +++ b/clang/test/Sema/AArch64/sve-vector-conditional-op.cpp @@ -0,0 +1,37 @@ +// RUN: %clang_cc1 %s -fsyntax-only -triple aarch64-none-linux-gnu -target-feature +sve -verify + +typedef int fixed_vector __attribute__((vector_size(4))); + +auto error_fixed_vector_result(__SVBool_t svbool, fixed_vector a, fixed_vector b) { + // expected-error@+1 {{vector condition type '__SVBool_t' and result type 'fixed_vector' (vector of 1 'int' value) do not have the same number of elements}} + return svbool ? a : b; +} + +auto error_void_result(__SVBool_t svbool) { + // expected-error@+1 {{GNU vector conditional operand cannot be void}} + return svbool ? (void)0 : (void)1; +} + +auto error_sve_splat_result_unsupported(__SVBool_t svbool, long long a, long long b) { + // expected-error@+1 {{scalar type 'long long' not supported with vector condition type '__SVBool_t'}} + return svbool ? a : b; +} + +auto error_sve_vector_result_matched_element_count(__SVBool_t svbool, __SVUint32_t a, __SVUint32_t b) { + // expected-error@+1 {{vector condition type '__SVBool_t' and result type '__SVUint32_t' do not have the same number of elements}} + return svbool ? a : b; +} + +// The following cases should be supported: + +__SVBool_t cond_svbool(__SVBool_t a, __SVBool_t b) { + return a < b ? a : b; +} + +__SVFloat32_t cond_svf32(__SVFloat32_t a, __SVFloat32_t b) { + return a < b ? a : b; +} + +__SVUint64_t cond_u64_splat(__SVUint64_t a) { + return a < 1ul ? a : 1ul; +}