diff --git a/clang/include/clang/AST/TypeBase.h b/clang/include/clang/AST/TypeBase.h index 625cc77dc1f08..5c51ec52daed3 100644 --- a/clang/include/clang/AST/TypeBase.h +++ b/clang/include/clang/AST/TypeBase.h @@ -4378,8 +4378,6 @@ class ConstantMatrixType final : public MatrixType { unsigned NumRows; unsigned NumColumns; - static constexpr unsigned MaxElementsPerDimension = (1 << 20) - 1; - ConstantMatrixType(QualType MatrixElementType, unsigned NRows, unsigned NColumns, QualType CanonElementType); @@ -4398,16 +4396,6 @@ class ConstantMatrixType final : public MatrixType { return getNumRows() * getNumColumns(); } - /// Returns true if \p NumElements is a valid matrix dimension. - static constexpr bool isDimensionValid(size_t NumElements) { - return NumElements > 0 && NumElements <= MaxElementsPerDimension; - } - - /// Returns the maximum number of elements per dimension. - static constexpr unsigned getMaxElementsPerDimension() { - return MaxElementsPerDimension; - } - void Profile(llvm::FoldingSetNodeID &ID) { Profile(ID, getElementType(), getNumRows(), getNumColumns(), getTypeClass()); diff --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def index 9e850089ad87f..690439c1258c1 100644 --- a/clang/include/clang/Basic/LangOptions.def +++ b/clang/include/clang/Basic/LangOptions.def @@ -432,6 +432,7 @@ ENUM_LANGOPT(RegisterStaticDestructors, RegisterStaticDestructorsKind, 2, LANGOPT(RegCall4, 1, 0, NotCompatible, "Set __regcall4 as a default calling convention to respect __regcall ABI v.4") LANGOPT(MatrixTypes, 1, 0, NotCompatible, "Enable or disable the builtin matrix type") +VALUE_LANGOPT(MaxMatrixDimension, 32, (1 << 20) - 1, NotCompatible, "maximum allowed matrix dimension") LANGOPT(CXXAssumptions, 1, 1, NotCompatible, "Enable or disable codegen and compile-time checks for C++23's [[assume]] attribute") diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp index a8b41ba18fa01..fa363bc6fea7c 100644 --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -4713,8 +4713,8 @@ QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows, assert(MatrixType::isValidElementType(ElementTy) && "need a valid element type"); - assert(ConstantMatrixType::isDimensionValid(NumRows) && - ConstantMatrixType::isDimensionValid(NumColumns) && + assert(NumRows > 0 && NumRows <= LangOpts.MaxMatrixDimension && + NumColumns > 0 && NumColumns <= LangOpts.MaxMatrixDimension && "need valid matrix dimensions"); void *InsertPos = nullptr; if (ConstantMatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos)) diff --git a/clang/lib/Basic/LangOptions.cpp b/clang/lib/Basic/LangOptions.cpp index f034514466d3f..4f5d20d946de8 100644 --- a/clang/lib/Basic/LangOptions.cpp +++ b/clang/lib/Basic/LangOptions.cpp @@ -131,8 +131,12 @@ void LangOptions::setLangDefaults(LangOptions &Opts, Language Lang, Opts.NamedLoops = Std.isC2y(); Opts.HLSL = Lang == Language::HLSL; - if (Opts.HLSL && Opts.IncludeDefaultHeader) - Includes.push_back("hlsl.h"); + if (Opts.HLSL) { + if (Opts.IncludeDefaultHeader) + Includes.push_back("hlsl.h"); + // Set maximum matrix dimension to 4 for HLSL + Opts.MaxMatrixDimension = 4; + } // Set OpenCL Version. Opts.OpenCL = Std.isOpenCL(); diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp index e118dda4780e2..81c836fe60452 100644 --- a/clang/lib/Sema/HLSLExternalSemaSource.cpp +++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp @@ -159,7 +159,8 @@ void HLSLExternalSemaSource::defineHLSLMatrixAlias() { SourceLocation(), ColsParam)); TemplateParams.emplace_back(ColsParam); - const unsigned MaxMatDim = 4; + const unsigned MaxMatDim = SemaPtr->getLangOpts().MaxMatrixDimension; + ; auto *MaxRow = IntegerLiteral::Create( AST, llvm::APInt(AST.getIntWidth(AST.IntTy), MaxMatDim), AST.IntTy, SourceLocation()); diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 063db05665af1..de37ee8ba5783 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -16240,9 +16240,9 @@ getAndVerifyMatrixDimension(Expr *Expr, StringRef Name, Sema &S) { return {}; } uint64_t Dim = Value->getZExtValue(); - if (!ConstantMatrixType::isDimensionValid(Dim)) { + if (Dim == 0 || Dim > S.Context.getLangOpts().MaxMatrixDimension) { S.Diag(Expr->getBeginLoc(), diag::err_builtin_matrix_invalid_dimension) - << Name << ConstantMatrixType::getMaxElementsPerDimension(); + << Name << S.Context.getLangOpts().MaxMatrixDimension; return {}; } return Dim; diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp index a9e7c34de94f4..811295c566a66 100644 --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -2517,12 +2517,12 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols, Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange; return QualType(); } - if (!ConstantMatrixType::isDimensionValid(MatrixRows)) { + if (MatrixRows > Context.getLangOpts().MaxMatrixDimension) { Diag(AttrLoc, diag::err_attribute_size_too_large) << RowRange << "matrix row"; return QualType(); } - if (!ConstantMatrixType::isDimensionValid(MatrixColumns)) { + if (MatrixColumns > Context.getLangOpts().MaxMatrixDimension) { Diag(AttrLoc, diag::err_attribute_size_too_large) << ColRange << "matrix column"; return QualType();