Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -8627,6 +8627,8 @@ def err_conditional_vector_size : Error<
def err_conditional_vector_element_size : Error<
"vector condition type %0 and result type %1 do not have elements of the "
"same size">;
def err_conditional_vector_scalar_type_unsupported : Error<
"scalar type %0 not supported with vector condition type %1">;
def err_conditional_vector_has_void : Error<
"GNU vector conditional operand cannot be %select{void|a throw expression}0">;
def err_conditional_vector_operand_type
Expand Down
4 changes: 0 additions & 4 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -8699,10 +8699,6 @@ class Sema final : public SemaBase {
ExprResult &RHS,
SourceLocation QuestionLoc);

QualType CheckSizelessVectorConditionalTypes(ExprResult &Cond,
ExprResult &LHS, ExprResult &RHS,
SourceLocation QuestionLoc);

//// Determines if a type is trivially relocatable
/// according to the C++26 rules.
// FIXME: This is in Sema because it requires
Expand Down
184 changes: 54 additions & 130 deletions clang/lib/Sema/SemaExprCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5658,20 +5658,13 @@ static bool ConvertForConditional(Sema &Self, ExprResult &E, QualType T) {
// extension.
static bool isValidVectorForConditionalCondition(ASTContext &Ctx,
QualType CondTy) {
if (!CondTy->isVectorType() && !CondTy->isExtVectorType())
bool IsSVEVectorType = CondTy->isSveVLSBuiltinType();
if (!CondTy->isVectorType() && !CondTy->isExtVectorType() && !IsSVEVectorType)
return false;
const QualType EltTy =
cast<VectorType>(CondTy.getCanonicalType())->getElementType();
assert(!EltTy->isEnumeralType() && "Vectors cant be enum types");
return EltTy->isIntegralType(Ctx);
}

static bool isValidSizelessVectorForConditionalCondition(ASTContext &Ctx,
QualType CondTy) {
if (!CondTy->isSveVLSBuiltinType())
return false;
const QualType EltTy =
cast<BuiltinType>(CondTy.getCanonicalType())->getSveEltType(Ctx);
IsSVEVectorType
? cast<BuiltinType>(CondTy.getCanonicalType())->getSveEltType(Ctx)
: cast<VectorType>(CondTy.getCanonicalType())->getElementType();
assert(!EltTy->isEnumeralType() && "Vectors cant be enum types");
return EltTy->isIntegralType(Ctx);
}
Expand All @@ -5683,21 +5676,29 @@ QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS,
RHS = DefaultFunctionArrayLvalueConversion(RHS.get());

QualType CondType = Cond.get()->getType();
const auto *CondVT = CondType->castAs<VectorType>();
QualType CondElementTy = CondVT->getElementType();
unsigned CondElementCount = CondVT->getNumElements();
QualType LHSType = LHS.get()->getType();
const auto *LHSVT = LHSType->getAs<VectorType>();
QualType RHSType = RHS.get()->getType();
const auto *RHSVT = RHSType->getAs<VectorType>();

QualType ResultType;
bool LHSIsVector = LHSType->isVectorType() || LHSType->isSizelessVectorType();
bool RHSIsVector = RHSType->isVectorType() || RHSType->isSizelessVectorType();

auto GetVectorInfo =
[&](QualType Type) -> std::pair<QualType, llvm::ElementCount> {
if (const auto *VT = Type->getAs<VectorType>())
return std::make_pair(VT->getElementType(),
llvm::ElementCount::getFixed(VT->getNumElements()));
ASTContext::BuiltinVectorTypeInfo VectorInfo =
Context.getBuiltinVectorTypeInfo(Type->castAs<BuiltinType>());
return std::make_pair(VectorInfo.ElementType, VectorInfo.EC);
};

auto [CondElementTy, CondElementCount] = GetVectorInfo(CondType);

if (LHSVT && RHSVT) {
if (isa<ExtVectorType>(CondVT) != isa<ExtVectorType>(LHSVT)) {
QualType ResultType;
if (LHSIsVector && RHSIsVector) {
if (CondType->isExtVectorType() != LHSType->isExtVectorType()) {
Diag(QuestionLoc, diag::err_conditional_vector_cond_result_mismatch)
<< /*isExtVector*/ isa<ExtVectorType>(CondVT);
<< /*isExtVector*/ CondType->isExtVectorType();
return {};
}

Expand All @@ -5708,12 +5709,17 @@ QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS,
return {};
}
ResultType = Context.getCommonSugaredType(LHSType, RHSType);
} else if (LHSVT || RHSVT) {
ResultType = CheckVectorOperands(
LHS, RHS, QuestionLoc, /*isCompAssign*/ false, /*AllowBothBool*/ true,
/*AllowBoolConversions*/ false,
/*AllowBoolOperation*/ true,
/*ReportInvalid*/ true);
} else if (LHSIsVector || RHSIsVector) {
if (CondType->isSizelessVectorType())
ResultType = CheckSizelessVectorOperands(LHS, RHS, QuestionLoc,
/*IsCompAssign*/ false,
ArithConvKind::Conditional);
else
ResultType = CheckVectorOperands(
LHS, RHS, QuestionLoc, /*isCompAssign*/ false, /*AllowBothBool*/ true,
/*AllowBoolConversions*/ false,
/*AllowBoolOperation*/ true,
/*ReportInvalid*/ true);
if (ResultType.isNull())
return {};
} else {
Expand All @@ -5731,24 +5737,33 @@ QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS,
<< ResultElementTy;
return {};
}
if (CondType->isExtVectorType())
ResultType =
Context.getExtVectorType(ResultElementTy, CondVT->getNumElements());
else
ResultType = Context.getVectorType(
ResultElementTy, CondVT->getNumElements(), VectorKind::Generic);

if (CondType->isExtVectorType()) {
ResultType = Context.getExtVectorType(ResultElementTy,
CondElementCount.getFixedValue());
} else if (CondType->isSizelessVectorType()) {
ResultType = Context.getScalableVectorType(
ResultElementTy, CondElementCount.getKnownMinValue());
// There are not scalable vector type mappings for all element counts.
if (ResultType.isNull()) {
Diag(QuestionLoc, diag::err_conditional_vector_scalar_type_unsupported)
<< ResultElementTy << CondType;
return {};
}
} else {
ResultType = Context.getVectorType(ResultElementTy,
CondElementCount.getFixedValue(),
VectorKind::Generic);
}
LHS = ImpCastExprToType(LHS.get(), ResultType, CK_VectorSplat);
RHS = ImpCastExprToType(RHS.get(), ResultType, CK_VectorSplat);
}

assert(!ResultType.isNull() && ResultType->isVectorType() &&
assert(!ResultType.isNull() &&
(ResultType->isVectorType() || ResultType->isSizelessVectorType()) &&
(!CondType->isExtVectorType() || ResultType->isExtVectorType()) &&
"Result should have been a vector type");
auto *ResultVectorTy = ResultType->castAs<VectorType>();
QualType ResultElementTy = ResultVectorTy->getElementType();
unsigned ResultElementCount = ResultVectorTy->getNumElements();

auto [ResultElementTy, ResultElementCount] = GetVectorInfo(ResultType);
if (ResultElementCount != CondElementCount) {
Diag(QuestionLoc, diag::err_conditional_vector_size) << CondType
<< ResultType;
Expand All @@ -5767,90 +5782,6 @@ QualType Sema::CheckVectorConditionalTypes(ExprResult &Cond, ExprResult &LHS,
return ResultType;
}

QualType Sema::CheckSizelessVectorConditionalTypes(ExprResult &Cond,
ExprResult &LHS,
ExprResult &RHS,
SourceLocation QuestionLoc) {
LHS = DefaultFunctionArrayLvalueConversion(LHS.get());
RHS = DefaultFunctionArrayLvalueConversion(RHS.get());

QualType CondType = Cond.get()->getType();
const auto *CondBT = CondType->castAs<BuiltinType>();
QualType CondElementTy = CondBT->getSveEltType(Context);
llvm::ElementCount CondElementCount =
Context.getBuiltinVectorTypeInfo(CondBT).EC;

QualType LHSType = LHS.get()->getType();
const auto *LHSBT =
LHSType->isSveVLSBuiltinType() ? LHSType->getAs<BuiltinType>() : nullptr;
QualType RHSType = RHS.get()->getType();
const auto *RHSBT =
RHSType->isSveVLSBuiltinType() ? RHSType->getAs<BuiltinType>() : nullptr;

QualType ResultType;

if (LHSBT && RHSBT) {
// If both are sizeless vector types, they must be the same type.
if (!Context.hasSameType(LHSType, RHSType)) {
Diag(QuestionLoc, diag::err_conditional_vector_mismatched)
<< LHSType << RHSType;
return QualType();
}
ResultType = LHSType;
} else if (LHSBT || RHSBT) {
ResultType = CheckSizelessVectorOperands(LHS, RHS, QuestionLoc,
/*IsCompAssign*/ false,
ArithConvKind::Conditional);
if (ResultType.isNull())
return QualType();
} else {
// Both are scalar so splat
QualType ResultElementTy;
LHSType = LHSType.getCanonicalType().getUnqualifiedType();
RHSType = RHSType.getCanonicalType().getUnqualifiedType();

if (Context.hasSameType(LHSType, RHSType))
ResultElementTy = LHSType;
else
ResultElementTy = UsualArithmeticConversions(LHS, RHS, QuestionLoc,
ArithConvKind::Conditional);

if (ResultElementTy->isEnumeralType()) {
Diag(QuestionLoc, diag::err_conditional_vector_operand_type)
<< ResultElementTy;
return QualType();
}

ResultType = Context.getScalableVectorType(
ResultElementTy, CondElementCount.getKnownMinValue());

LHS = ImpCastExprToType(LHS.get(), ResultType, CK_VectorSplat);
RHS = ImpCastExprToType(RHS.get(), ResultType, CK_VectorSplat);
}

assert(!ResultType.isNull() && ResultType->isSveVLSBuiltinType() &&
"Result should have been a vector type");
auto *ResultBuiltinTy = ResultType->castAs<BuiltinType>();
QualType ResultElementTy = ResultBuiltinTy->getSveEltType(Context);
llvm::ElementCount ResultElementCount =
Context.getBuiltinVectorTypeInfo(ResultBuiltinTy).EC;

if (ResultElementCount != CondElementCount) {
Diag(QuestionLoc, diag::err_conditional_vector_size)
<< CondType << ResultType;
return QualType();
}

if (Context.getTypeSize(ResultElementTy) !=
Context.getTypeSize(CondElementTy)) {
Diag(QuestionLoc, diag::err_conditional_vector_element_size)
<< CondType << ResultType;
return QualType();
}

return ResultType;
}

QualType Sema::CXXCheckConditionalOperands(ExprResult &Cond, ExprResult &LHS,
ExprResult &RHS, ExprValueKind &VK,
ExprObjectKind &OK,
Expand All @@ -5864,14 +5795,10 @@ QualType Sema::CXXCheckConditionalOperands(ExprResult &Cond, ExprResult &LHS,
bool IsVectorConditional =
isValidVectorForConditionalCondition(Context, Cond.get()->getType());

bool IsSizelessVectorConditional =
isValidSizelessVectorForConditionalCondition(Context,
Cond.get()->getType());

// C++11 [expr.cond]p1
// The first expression is contextually converted to bool.
if (!Cond.get()->isTypeDependent()) {
ExprResult CondRes = IsVectorConditional || IsSizelessVectorConditional
ExprResult CondRes = IsVectorConditional
? DefaultFunctionArrayLvalueConversion(Cond.get())
: CheckCXXBooleanCondition(Cond.get());
if (CondRes.isInvalid())
Expand Down Expand Up @@ -5940,9 +5867,6 @@ QualType Sema::CXXCheckConditionalOperands(ExprResult &Cond, ExprResult &LHS,
if (IsVectorConditional)
return CheckVectorConditionalTypes(Cond, LHS, RHS, QuestionLoc);

if (IsSizelessVectorConditional)
return CheckSizelessVectorConditionalTypes(Cond, LHS, RHS, QuestionLoc);

// WebAssembly tables are not allowed as conditional LHS or RHS.
if (LTy->isWebAssemblyTableType() || RTy->isWebAssemblyTableType()) {
Diag(QuestionLoc, diag::err_wasm_table_conditional_expression)
Expand Down
37 changes: 37 additions & 0 deletions clang/test/Sema/AArch64/sve-vector-conditional-op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: %clang_cc1 %s -fsyntax-only -triple aarch64-none-linux-gnu -target-feature +sve -verify

typedef int fixed_vector __attribute__((vector_size(4)));

auto error_fixed_vector_result(__SVBool_t svbool, fixed_vector a, fixed_vector b) {
// expected-error@+1 {{vector condition type '__SVBool_t' and result type 'fixed_vector' (vector of 1 'int' value) do not have the same number of elements}}
return svbool ? a : b;
}

auto error_void_result(__SVBool_t svbool) {
// expected-error@+1 {{GNU vector conditional operand cannot be void}}
return svbool ? (void)0 : (void)1;
}

auto error_sve_splat_result_unsupported(__SVBool_t svbool, long long a, long long b) {
// expected-error@+1 {{scalar type 'long long' not supported with vector condition type '__SVBool_t'}}
return svbool ? a : b;
}

auto error_sve_vector_result_matched_element_count(__SVBool_t svbool, __SVUint32_t a, __SVUint32_t b) {
// expected-error@+1 {{vector condition type '__SVBool_t' and result type '__SVUint32_t' do not have the same number of elements}}
return svbool ? a : b;
}

// The following cases should be supported:

__SVBool_t cond_svbool(__SVBool_t a, __SVBool_t b) {
return a < b ? a : b;
}

__SVFloat32_t cond_svf32(__SVFloat32_t a, __SVFloat32_t b) {
return a < b ? a : b;
}

__SVUint64_t cond_u64_splat(__SVUint64_t a) {
return a < 1ul ? a : 1ul;
}