-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[HLSL] Add support for the HLSL matrix type #159446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
fixes llvm#109839 This change is really simple. It creates a matrix alias that will let HLSL use the existing clang `matrix_type` infra. The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL. Testing therefore is limited to exercising the alias. The main difference in this attempt is the type printer.
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-clang Author: Farzon Lotfi (farzonl) Changesfixes #109839 This change is really simple. It creates a matrix alias that will let HLSL use the existing clang The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL. Testing therefore is limited to exercising the alias, sema errors, and basic codegen. The main difference in this attempt is the type printer and less of an emphasis on tests where things overlap with existing Patch is 24.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159446.diff 10 Files Affected:
diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index a7c514e809aa9..e8869e04390f3 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4582,7 +4582,7 @@ defm ptrauth_block_descriptor_pointers : OptInCC1FFlag<"ptrauth-block-descriptor
def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
Visibility<[ClangOption, CC1Option]>,
HelpText<"Enable matrix data type and related builtin functions">,
- MarshallingInfoFlag<LangOpts<"MatrixTypes">>;
+ MarshallingInfoFlag<LangOpts<"MatrixTypes">, hlsl.KeyPath>;
defm raw_string_literals : BoolFOption<"raw-string-literals",
LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">,
diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h
index d93fb8c8eef6b..049fc7b8fe3f2 100644
--- a/clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
private:
void defineTrivialHLSLTypes();
void defineHLSLVectorAlias();
+ void defineHLSLMatrixAlias();
void defineHLSLTypesWithForwardDeclarations();
void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn);
};
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index cd59678d67f2f..82859c4015772 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -846,16 +846,41 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
}
}
-void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
- raw_ostream &OS) {
- printBefore(T->getElementType(), OS);
- OS << " __attribute__((matrix_type(";
+static void printDims(const ConstantMatrixType *T, raw_ostream &OS) {
OS << T->getNumRows() << ", " << T->getNumColumns();
+}
+
+static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) {
+ OS << "matrix<";
+ TP.printBefore(T->getElementType(), OS);
+}
+
+static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
+ OS << ", ";
+ printDims(T, OS);
+ OS << ">";
+}
+
+static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T, raw_ostream &OS) {
+ TP.printBefore(T->getElementType(), OS);
+ OS << " __attribute__((matrix_type(";
+ printDims(T, OS);
OS << ")))";
}
-void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
- raw_ostream &OS) {
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, raw_ostream &OS) {
+ if (Policy.UseHLSLTypes) {
+ printHLSLMatrixBefore(*this, T, OS);
+ return;
+ }
+ printClangMatrixBefore(*this, T, OS);
+}
+
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
+ if (Policy.UseHLSLTypes) {
+ printHLSLMatrixAfter(T, OS);
+ return;
+ }
printAfter(T->getElementType(), OS);
}
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index eff94e0d7f950..c750261952d4f 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2;
typedef vector<float64_t, 3> float64_t3;
typedef vector<float64_t, 4> float64_t4;
+ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<int16_t, 1, 1> int16_t1x1;
+typedef matrix<int16_t, 1, 2> int16_t1x2;
+typedef matrix<int16_t, 1, 3> int16_t1x3;
+typedef matrix<int16_t, 1, 4> int16_t1x4;
+typedef matrix<int16_t, 2, 1> int16_t2x1;
+typedef matrix<int16_t, 2, 2> int16_t2x2;
+typedef matrix<int16_t, 2, 3> int16_t2x3;
+typedef matrix<int16_t, 2, 4> int16_t2x4;
+typedef matrix<int16_t, 3, 1> int16_t3x1;
+typedef matrix<int16_t, 3, 2> int16_t3x2;
+typedef matrix<int16_t, 3, 3> int16_t3x3;
+typedef matrix<int16_t, 3, 4> int16_t3x4;
+typedef matrix<int16_t, 4, 1> int16_t4x1;
+typedef matrix<int16_t, 4, 2> int16_t4x2;
+typedef matrix<int16_t, 4, 3> int16_t4x3;
+typedef matrix<int16_t, 4, 4> int16_t4x4;
+typedef matrix<uint16_t, 1, 1> uint16_t1x1;
+typedef matrix<uint16_t, 1, 2> uint16_t1x2;
+typedef matrix<uint16_t, 1, 3> uint16_t1x3;
+typedef matrix<uint16_t, 1, 4> uint16_t1x4;
+typedef matrix<uint16_t, 2, 1> uint16_t2x1;
+typedef matrix<uint16_t, 2, 2> uint16_t2x2;
+typedef matrix<uint16_t, 2, 3> uint16_t2x3;
+typedef matrix<uint16_t, 2, 4> uint16_t2x4;
+typedef matrix<uint16_t, 3, 1> uint16_t3x1;
+typedef matrix<uint16_t, 3, 2> uint16_t3x2;
+typedef matrix<uint16_t, 3, 3> uint16_t3x3;
+typedef matrix<uint16_t, 3, 4> uint16_t3x4;
+typedef matrix<uint16_t, 4, 1> uint16_t4x1;
+typedef matrix<uint16_t, 4, 2> uint16_t4x2;
+typedef matrix<uint16_t, 4, 3> uint16_t4x3;
+typedef matrix<uint16_t, 4, 4> uint16_t4x4;
+#endif
+
+typedef matrix<int, 1, 1> int1x1;
+typedef matrix<int, 1, 2> int1x2;
+typedef matrix<int, 1, 3> int1x3;
+typedef matrix<int, 1, 4> int1x4;
+typedef matrix<int, 2, 1> int2x1;
+typedef matrix<int, 2, 2> int2x2;
+typedef matrix<int, 2, 3> int2x3;
+typedef matrix<int, 2, 4> int2x4;
+typedef matrix<int, 3, 1> int3x1;
+typedef matrix<int, 3, 2> int3x2;
+typedef matrix<int, 3, 3> int3x3;
+typedef matrix<int, 3, 4> int3x4;
+typedef matrix<int, 4, 1> int4x1;
+typedef matrix<int, 4, 2> int4x2;
+typedef matrix<int, 4, 3> int4x3;
+typedef matrix<int, 4, 4> int4x4;
+typedef matrix<uint, 1, 1> uint1x1;
+typedef matrix<uint, 1, 2> uint1x2;
+typedef matrix<uint, 1, 3> uint1x3;
+typedef matrix<uint, 1, 4> uint1x4;
+typedef matrix<uint, 2, 1> uint2x1;
+typedef matrix<uint, 2, 2> uint2x2;
+typedef matrix<uint, 2, 3> uint2x3;
+typedef matrix<uint, 2, 4> uint2x4;
+typedef matrix<uint, 3, 1> uint3x1;
+typedef matrix<uint, 3, 2> uint3x2;
+typedef matrix<uint, 3, 3> uint3x3;
+typedef matrix<uint, 3, 4> uint3x4;
+typedef matrix<uint, 4, 1> uint4x1;
+typedef matrix<uint, 4, 2> uint4x2;
+typedef matrix<uint, 4, 3> uint4x3;
+typedef matrix<uint, 4, 4> uint4x4;
+typedef matrix<int32_t, 1, 1> int32_t1x1;
+typedef matrix<int32_t, 1, 2> int32_t1x2;
+typedef matrix<int32_t, 1, 3> int32_t1x3;
+typedef matrix<int32_t, 1, 4> int32_t1x4;
+typedef matrix<int32_t, 2, 1> int32_t2x1;
+typedef matrix<int32_t, 2, 2> int32_t2x2;
+typedef matrix<int32_t, 2, 3> int32_t2x3;
+typedef matrix<int32_t, 2, 4> int32_t2x4;
+typedef matrix<int32_t, 3, 1> int32_t3x1;
+typedef matrix<int32_t, 3, 2> int32_t3x2;
+typedef matrix<int32_t, 3, 3> int32_t3x3;
+typedef matrix<int32_t, 3, 4> int32_t3x4;
+typedef matrix<int32_t, 4, 1> int32_t4x1;
+typedef matrix<int32_t, 4, 2> int32_t4x2;
+typedef matrix<int32_t, 4, 3> int32_t4x3;
+typedef matrix<int32_t, 4, 4> int32_t4x4;
+typedef matrix<uint32_t, 1, 1> uint32_t1x1;
+typedef matrix<uint32_t, 1, 2> uint32_t1x2;
+typedef matrix<uint32_t, 1, 3> uint32_t1x3;
+typedef matrix<uint32_t, 1, 4> uint32_t1x4;
+typedef matrix<uint32_t, 2, 1> uint32_t2x1;
+typedef matrix<uint32_t, 2, 2> uint32_t2x2;
+typedef matrix<uint32_t, 2, 3> uint32_t2x3;
+typedef matrix<uint32_t, 2, 4> uint32_t2x4;
+typedef matrix<uint32_t, 3, 1> uint32_t3x1;
+typedef matrix<uint32_t, 3, 2> uint32_t3x2;
+typedef matrix<uint32_t, 3, 3> uint32_t3x3;
+typedef matrix<uint32_t, 3, 4> uint32_t3x4;
+typedef matrix<uint32_t, 4, 1> uint32_t4x1;
+typedef matrix<uint32_t, 4, 2> uint32_t4x2;
+typedef matrix<uint32_t, 4, 3> uint32_t4x3;
+typedef matrix<uint32_t, 4, 4> uint32_t4x4;
+typedef matrix<int64_t, 1, 1> int64_t1x1;
+typedef matrix<int64_t, 1, 2> int64_t1x2;
+typedef matrix<int64_t, 1, 3> int64_t1x3;
+typedef matrix<int64_t, 1, 4> int64_t1x4;
+typedef matrix<int64_t, 2, 1> int64_t2x1;
+typedef matrix<int64_t, 2, 2> int64_t2x2;
+typedef matrix<int64_t, 2, 3> int64_t2x3;
+typedef matrix<int64_t, 2, 4> int64_t2x4;
+typedef matrix<int64_t, 3, 1> int64_t3x1;
+typedef matrix<int64_t, 3, 2> int64_t3x2;
+typedef matrix<int64_t, 3, 3> int64_t3x3;
+typedef matrix<int64_t, 3, 4> int64_t3x4;
+typedef matrix<int64_t, 4, 1> int64_t4x1;
+typedef matrix<int64_t, 4, 2> int64_t4x2;
+typedef matrix<int64_t, 4, 3> int64_t4x3;
+typedef matrix<int64_t, 4, 4> int64_t4x4;
+typedef matrix<uint64_t, 1, 1> uint64_t1x1;
+typedef matrix<uint64_t, 1, 2> uint64_t1x2;
+typedef matrix<uint64_t, 1, 3> uint64_t1x3;
+typedef matrix<uint64_t, 1, 4> uint64_t1x4;
+typedef matrix<uint64_t, 2, 1> uint64_t2x1;
+typedef matrix<uint64_t, 2, 2> uint64_t2x2;
+typedef matrix<uint64_t, 2, 3> uint64_t2x3;
+typedef matrix<uint64_t, 2, 4> uint64_t2x4;
+typedef matrix<uint64_t, 3, 1> uint64_t3x1;
+typedef matrix<uint64_t, 3, 2> uint64_t3x2;
+typedef matrix<uint64_t, 3, 3> uint64_t3x3;
+typedef matrix<uint64_t, 3, 4> uint64_t3x4;
+typedef matrix<uint64_t, 4, 1> uint64_t4x1;
+typedef matrix<uint64_t, 4, 2> uint64_t4x2;
+typedef matrix<uint64_t, 4, 3> uint64_t4x3;
+typedef matrix<uint64_t, 4, 4> uint64_t4x4;
+
+typedef matrix<half, 1, 1> half1x1;
+typedef matrix<half, 1, 2> half1x2;
+typedef matrix<half, 1, 3> half1x3;
+typedef matrix<half, 1, 4> half1x4;
+typedef matrix<half, 2, 1> half2x1;
+typedef matrix<half, 2, 2> half2x2;
+typedef matrix<half, 2, 3> half2x3;
+typedef matrix<half, 2, 4> half2x4;
+typedef matrix<half, 3, 1> half3x1;
+typedef matrix<half, 3, 2> half3x2;
+typedef matrix<half, 3, 3> half3x3;
+typedef matrix<half, 3, 4> half3x4;
+typedef matrix<half, 4, 1> half4x1;
+typedef matrix<half, 4, 2> half4x2;
+typedef matrix<half, 4, 3> half4x3;
+typedef matrix<half, 4, 4> half4x4;
+typedef matrix<float, 1, 1> float1x1;
+typedef matrix<float, 1, 2> float1x2;
+typedef matrix<float, 1, 3> float1x3;
+typedef matrix<float, 1, 4> float1x4;
+typedef matrix<float, 2, 1> float2x1;
+typedef matrix<float, 2, 2> float2x2;
+typedef matrix<float, 2, 3> float2x3;
+typedef matrix<float, 2, 4> float2x4;
+typedef matrix<float, 3, 1> float3x1;
+typedef matrix<float, 3, 2> float3x2;
+typedef matrix<float, 3, 3> float3x3;
+typedef matrix<float, 3, 4> float3x4;
+typedef matrix<float, 4, 1> float4x1;
+typedef matrix<float, 4, 2> float4x2;
+typedef matrix<float, 4, 3> float4x3;
+typedef matrix<float, 4, 4> float4x4;
+typedef matrix<double, 1, 1> double1x1;
+typedef matrix<double, 1, 2> double1x2;
+typedef matrix<double, 1, 3> double1x3;
+typedef matrix<double, 1, 4> double1x4;
+typedef matrix<double, 2, 1> double2x1;
+typedef matrix<double, 2, 2> double2x2;
+typedef matrix<double, 2, 3> double2x3;
+typedef matrix<double, 2, 4> double2x4;
+typedef matrix<double, 3, 1> double3x1;
+typedef matrix<double, 3, 2> double3x2;
+typedef matrix<double, 3, 3> double3x3;
+typedef matrix<double, 3, 4> double3x4;
+typedef matrix<double, 4, 1> double4x1;
+typedef matrix<double, 4, 2> double4x2;
+typedef matrix<double, 4, 3> double4x3;
+typedef matrix<double, 4, 4> double4x4;
+
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<float16_t, 1, 1> float16_t1x1;
+typedef matrix<float16_t, 1, 2> float16_t1x2;
+typedef matrix<float16_t, 1, 3> float16_t1x3;
+typedef matrix<float16_t, 1, 4> float16_t1x4;
+typedef matrix<float16_t, 2, 1> float16_t2x1;
+typedef matrix<float16_t, 2, 2> float16_t2x2;
+typedef matrix<float16_t, 2, 3> float16_t2x3;
+typedef matrix<float16_t, 2, 4> float16_t2x4;
+typedef matrix<float16_t, 3, 1> float16_t3x1;
+typedef matrix<float16_t, 3, 2> float16_t3x2;
+typedef matrix<float16_t, 3, 3> float16_t3x3;
+typedef matrix<float16_t, 3, 4> float16_t3x4;
+typedef matrix<float16_t, 4, 1> float16_t4x1;
+typedef matrix<float16_t, 4, 2> float16_t4x2;
+typedef matrix<float16_t, 4, 3> float16_t4x3;
+typedef matrix<float16_t, 4, 4> float16_t4x4;
+#endif
+
+typedef matrix<float32_t, 1, 1> float32_t1x1;
+typedef matrix<float32_t, 1, 2> float32_t1x2;
+typedef matrix<float32_t, 1, 3> float32_t1x3;
+typedef matrix<float32_t, 1, 4> float32_t1x4;
+typedef matrix<float32_t, 2, 1> float32_t2x1;
+typedef matrix<float32_t, 2, 2> float32_t2x2;
+typedef matrix<float32_t, 2, 3> float32_t2x3;
+typedef matrix<float32_t, 2, 4> float32_t2x4;
+typedef matrix<float32_t, 3, 1> float32_t3x1;
+typedef matrix<float32_t, 3, 2> float32_t3x2;
+typedef matrix<float32_t, 3, 3> float32_t3x3;
+typedef matrix<float32_t, 3, 4> float32_t3x4;
+typedef matrix<float32_t, 4, 1> float32_t4x1;
+typedef matrix<float32_t, 4, 2> float32_t4x2;
+typedef matrix<float32_t, 4, 3> float32_t4x3;
+typedef matrix<float32_t, 4, 4> float32_t4x4;
+typedef matrix<float64_t, 1, 1> float64_t1x1;
+typedef matrix<float64_t, 1, 2> float64_t1x2;
+typedef matrix<float64_t, 1, 3> float64_t1x3;
+typedef matrix<float64_t, 1, 4> float64_t1x4;
+typedef matrix<float64_t, 2, 1> float64_t2x1;
+typedef matrix<float64_t, 2, 2> float64_t2x2;
+typedef matrix<float64_t, 2, 3> float64_t2x3;
+typedef matrix<float64_t, 2, 4> float64_t2x4;
+typedef matrix<float64_t, 3, 1> float64_t3x1;
+typedef matrix<float64_t, 3, 2> float64_t3x2;
+typedef matrix<float64_t, 3, 3> float64_t3x3;
+typedef matrix<float64_t, 3, 4> float64_t3x4;
+typedef matrix<float64_t, 4, 1> float64_t4x1;
+typedef matrix<float64_t, 4, 2> float64_t4x2;
+typedef matrix<float64_t, 4, 3> float64_t4x3;
+typedef matrix<float64_t, 4, 4> float64_t4x4;
+
} // namespace hlsl
#endif //_HLSL_HLSL_BASIC_TYPES_H_
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 3386d8da281e9..9ac60712b87b8 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -121,8 +121,80 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
HLSLNamespace->addDecl(Template);
}
+void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
+ ASTContext &AST = SemaPtr->getASTContext();
+ llvm::SmallVector<NamedDecl *> TemplateParams;
+
+ auto *TypeParam = TemplateTypeParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+ &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+ TypeParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(
+ TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
+
+ TemplateParams.emplace_back(TypeParam);
+
+ // these should be 64 bit to be consistent with other clang matrices.
+ auto *RowsParam = NonTypeTemplateParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+ &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
+ false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+ llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
+ TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
+ /*IsDefaulted=*/true);
+ RowsParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
+ SourceLocation(), RowsParam));
+ TemplateParams.emplace_back(RowsParam);
+
+ auto *ColsParam = NonTypeTemplateParmDecl::Create(
+ AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
+ &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
+ false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+ llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
+ TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
+ /*IsDefaulted=*/true);
+ ColsParam->setDefaultArgument(
+ AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
+ SourceLocation(), ColsParam));
+ TemplateParams.emplace_back(ColsParam);
+
+ auto *ParamList =
+ TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+ TemplateParams, SourceLocation(), nullptr);
+
+ IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
+
+ QualType AliasType = AST.getDependentSizedMatrixType(
+ AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+ DeclRefExpr::Create(
+ AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
+ DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+ AST.IntTy, VK_LValue),
+ DeclRefExpr::Create(
+ AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
+ DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+ AST.IntTy, VK_LValue),
+ SourceLocation());
+
+ auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+ SourceLocation(), &II,
+ AST.getTrivialTypeSourceInfo(AliasType));
+ Record->setImplicit(true);
+
+ auto *Template =
+ TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+ Record->getIdentifier(), ParamList, Record);
+
+ Record->setDescribedAliasTemplate(Template);
+ Template->setImplicit(true);
+ Template->setLexicalDeclContext(Record->getDeclContext());
+ HLSLNamespace->addDecl(Template);
+}
+
void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
defineHLSLVectorAlias();
+ defineHLSLMatrixAlias();
}
/// Set up common members and attributes for buffer types
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 0af38472b0fec..38f81a24945c5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3285,7 +3285,6 @@ static void BuildFlattenedTypeList(QualType BaseTy,
while (!WorkList.empty()) {
QualType T = WorkList.pop_back_val();
T = T.getCanonicalType().getUnqualifiedType();
- assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
llvm::SmallVector<QualType, 16> ElementFields;
// Generally I've avoided recursion in this algorithm, but arrays of
diff --git a/clang/test/AST/HLSL/matrix-alias.hlsl b/clang/test/AST/HLSL/matrix-alias.hlsl
new file mode 100644
index 0000000000000..2758b6f0d202f
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-alias.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -ast-dump -o - %s | FileCheck %s
+
+// Test that matrix aliases are set up properly for HLSL
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+ // Verify that the alias is generated inside the hlsl namespace.
+ hlsl::matrix<float, 2, 2> Mat2x2f;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:36>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2f 'hlsl::matrix<float, 2, 2>'
+
+ // Verify that you don't need to specify the namespace.
+ matrix<int, 2, 2> Mat2x2i;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Mat2x2i 'matrix<int, 2, 2>'
+
+ // Build a bigger matrix.
+ matrix<double, 4, 4> Mat4x4d;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:31>
+ // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4d 'matrix<double, 4, 4>'
+
+ // Verify that the implicit arguments generate the correct type.
+ matrix<> ImpMat4x4;
+
+ // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:44:3, col:21>
+ // CHECK-NEXT: ...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
5081bc3
to
9e18ebb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me at a high level, great to see more users of the matrix intrinsics!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing we didn't do with the vector alias that we should do for the matrix
one is add a concept requirement that the dimensions be in the range [1,4].
We removed the restriction on vector length in HLSL, but I don't believe we have any plans to do the same for matrices.
fixes #109839
This change is really simple. It creates a matrix alias that will let HLSL use the existing clang
matrix_type
infra.The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL.
Testing therefore is limited to exercising the alias, sema errors, and basic codegen.
future work will add things like constructors and accessors.
The main difference in this attempt is the type printer and less of an emphasis on tests where things overlap with existing
matrix_type
testing like cast behavior.