Skip to content

Commit 6ac40aa

Browse files
authored
[HLSL] Add support for the HLSL matrix type (#159446)
fixes #109839 This change is really simple. It creates a matrix alias that will let HLSL use the existing clang `matrix_type` infra. The only additional change was to add explict alias for the typed dimensions of 1-4 inclusive matricies available in HLSL. Testing therefore is limited to exercising the alias, sema errors, and basic codegen. future work will add things like constructors and accessors. The main difference in this attempt is the type printer and less of an emphasis on tests where things overlap with existing `matrix_type` testing like cast behavior.
1 parent c31d503 commit 6ac40aa

File tree

10 files changed

+490
-6
lines changed

10 files changed

+490
-6
lines changed

clang/include/clang/Driver/Options.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4587,7 +4587,7 @@ defm ptrauth_block_descriptor_pointers : OptInCC1FFlag<"ptrauth-block-descriptor
45874587
def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
45884588
Visibility<[ClangOption, CC1Option]>,
45894589
HelpText<"Enable matrix data type and related builtin functions">,
4590-
MarshallingInfoFlag<LangOpts<"MatrixTypes">>;
4590+
MarshallingInfoFlag<LangOpts<"MatrixTypes">, hlsl.KeyPath>;
45914591

45924592
defm raw_string_literals : BoolFOption<"raw-string-literals",
45934593
LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">,

clang/include/clang/Sema/HLSLExternalSemaSource.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
4444
private:
4545
void defineTrivialHLSLTypes();
4646
void defineHLSLVectorAlias();
47+
void defineHLSLMatrixAlias();
4748
void defineHLSLTypesWithForwardDeclarations();
4849
void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn);
4950
};

clang/lib/AST/TypePrinter.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -846,16 +846,45 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
846846
}
847847
}
848848

849-
void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
850-
raw_ostream &OS) {
851-
printBefore(T->getElementType(), OS);
852-
OS << " __attribute__((matrix_type(";
849+
static void printDims(const ConstantMatrixType *T, raw_ostream &OS) {
853850
OS << T->getNumRows() << ", " << T->getNumColumns();
851+
}
852+
853+
static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T,
854+
raw_ostream &OS) {
855+
OS << "matrix<";
856+
TP.printBefore(T->getElementType(), OS);
857+
}
858+
859+
static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
860+
OS << ", ";
861+
printDims(T, OS);
862+
OS << ">";
863+
}
864+
865+
static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T,
866+
raw_ostream &OS) {
867+
TP.printBefore(T->getElementType(), OS);
868+
OS << " __attribute__((matrix_type(";
869+
printDims(T, OS);
854870
OS << ")))";
855871
}
856872

873+
void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
874+
raw_ostream &OS) {
875+
if (Policy.UseHLSLTypes) {
876+
printHLSLMatrixBefore(*this, T, OS);
877+
return;
878+
}
879+
printClangMatrixBefore(*this, T, OS);
880+
}
881+
857882
void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
858883
raw_ostream &OS) {
884+
if (Policy.UseHLSLTypes) {
885+
printHLSLMatrixAfter(T, OS);
886+
return;
887+
}
859888
printAfter(T->getElementType(), OS);
860889
}
861890

clang/lib/Headers/hlsl/hlsl_basic_types.h

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2;
115115
typedef vector<float64_t, 3> float64_t3;
116116
typedef vector<float64_t, 4> float64_t4;
117117

118+
#ifdef __HLSL_ENABLE_16_BIT
119+
typedef matrix<int16_t, 1, 1> int16_t1x1;
120+
typedef matrix<int16_t, 1, 2> int16_t1x2;
121+
typedef matrix<int16_t, 1, 3> int16_t1x3;
122+
typedef matrix<int16_t, 1, 4> int16_t1x4;
123+
typedef matrix<int16_t, 2, 1> int16_t2x1;
124+
typedef matrix<int16_t, 2, 2> int16_t2x2;
125+
typedef matrix<int16_t, 2, 3> int16_t2x3;
126+
typedef matrix<int16_t, 2, 4> int16_t2x4;
127+
typedef matrix<int16_t, 3, 1> int16_t3x1;
128+
typedef matrix<int16_t, 3, 2> int16_t3x2;
129+
typedef matrix<int16_t, 3, 3> int16_t3x3;
130+
typedef matrix<int16_t, 3, 4> int16_t3x4;
131+
typedef matrix<int16_t, 4, 1> int16_t4x1;
132+
typedef matrix<int16_t, 4, 2> int16_t4x2;
133+
typedef matrix<int16_t, 4, 3> int16_t4x3;
134+
typedef matrix<int16_t, 4, 4> int16_t4x4;
135+
typedef matrix<uint16_t, 1, 1> uint16_t1x1;
136+
typedef matrix<uint16_t, 1, 2> uint16_t1x2;
137+
typedef matrix<uint16_t, 1, 3> uint16_t1x3;
138+
typedef matrix<uint16_t, 1, 4> uint16_t1x4;
139+
typedef matrix<uint16_t, 2, 1> uint16_t2x1;
140+
typedef matrix<uint16_t, 2, 2> uint16_t2x2;
141+
typedef matrix<uint16_t, 2, 3> uint16_t2x3;
142+
typedef matrix<uint16_t, 2, 4> uint16_t2x4;
143+
typedef matrix<uint16_t, 3, 1> uint16_t3x1;
144+
typedef matrix<uint16_t, 3, 2> uint16_t3x2;
145+
typedef matrix<uint16_t, 3, 3> uint16_t3x3;
146+
typedef matrix<uint16_t, 3, 4> uint16_t3x4;
147+
typedef matrix<uint16_t, 4, 1> uint16_t4x1;
148+
typedef matrix<uint16_t, 4, 2> uint16_t4x2;
149+
typedef matrix<uint16_t, 4, 3> uint16_t4x3;
150+
typedef matrix<uint16_t, 4, 4> uint16_t4x4;
151+
#endif
152+
153+
typedef matrix<int, 1, 1> int1x1;
154+
typedef matrix<int, 1, 2> int1x2;
155+
typedef matrix<int, 1, 3> int1x3;
156+
typedef matrix<int, 1, 4> int1x4;
157+
typedef matrix<int, 2, 1> int2x1;
158+
typedef matrix<int, 2, 2> int2x2;
159+
typedef matrix<int, 2, 3> int2x3;
160+
typedef matrix<int, 2, 4> int2x4;
161+
typedef matrix<int, 3, 1> int3x1;
162+
typedef matrix<int, 3, 2> int3x2;
163+
typedef matrix<int, 3, 3> int3x3;
164+
typedef matrix<int, 3, 4> int3x4;
165+
typedef matrix<int, 4, 1> int4x1;
166+
typedef matrix<int, 4, 2> int4x2;
167+
typedef matrix<int, 4, 3> int4x3;
168+
typedef matrix<int, 4, 4> int4x4;
169+
typedef matrix<uint, 1, 1> uint1x1;
170+
typedef matrix<uint, 1, 2> uint1x2;
171+
typedef matrix<uint, 1, 3> uint1x3;
172+
typedef matrix<uint, 1, 4> uint1x4;
173+
typedef matrix<uint, 2, 1> uint2x1;
174+
typedef matrix<uint, 2, 2> uint2x2;
175+
typedef matrix<uint, 2, 3> uint2x3;
176+
typedef matrix<uint, 2, 4> uint2x4;
177+
typedef matrix<uint, 3, 1> uint3x1;
178+
typedef matrix<uint, 3, 2> uint3x2;
179+
typedef matrix<uint, 3, 3> uint3x3;
180+
typedef matrix<uint, 3, 4> uint3x4;
181+
typedef matrix<uint, 4, 1> uint4x1;
182+
typedef matrix<uint, 4, 2> uint4x2;
183+
typedef matrix<uint, 4, 3> uint4x3;
184+
typedef matrix<uint, 4, 4> uint4x4;
185+
typedef matrix<int32_t, 1, 1> int32_t1x1;
186+
typedef matrix<int32_t, 1, 2> int32_t1x2;
187+
typedef matrix<int32_t, 1, 3> int32_t1x3;
188+
typedef matrix<int32_t, 1, 4> int32_t1x4;
189+
typedef matrix<int32_t, 2, 1> int32_t2x1;
190+
typedef matrix<int32_t, 2, 2> int32_t2x2;
191+
typedef matrix<int32_t, 2, 3> int32_t2x3;
192+
typedef matrix<int32_t, 2, 4> int32_t2x4;
193+
typedef matrix<int32_t, 3, 1> int32_t3x1;
194+
typedef matrix<int32_t, 3, 2> int32_t3x2;
195+
typedef matrix<int32_t, 3, 3> int32_t3x3;
196+
typedef matrix<int32_t, 3, 4> int32_t3x4;
197+
typedef matrix<int32_t, 4, 1> int32_t4x1;
198+
typedef matrix<int32_t, 4, 2> int32_t4x2;
199+
typedef matrix<int32_t, 4, 3> int32_t4x3;
200+
typedef matrix<int32_t, 4, 4> int32_t4x4;
201+
typedef matrix<uint32_t, 1, 1> uint32_t1x1;
202+
typedef matrix<uint32_t, 1, 2> uint32_t1x2;
203+
typedef matrix<uint32_t, 1, 3> uint32_t1x3;
204+
typedef matrix<uint32_t, 1, 4> uint32_t1x4;
205+
typedef matrix<uint32_t, 2, 1> uint32_t2x1;
206+
typedef matrix<uint32_t, 2, 2> uint32_t2x2;
207+
typedef matrix<uint32_t, 2, 3> uint32_t2x3;
208+
typedef matrix<uint32_t, 2, 4> uint32_t2x4;
209+
typedef matrix<uint32_t, 3, 1> uint32_t3x1;
210+
typedef matrix<uint32_t, 3, 2> uint32_t3x2;
211+
typedef matrix<uint32_t, 3, 3> uint32_t3x3;
212+
typedef matrix<uint32_t, 3, 4> uint32_t3x4;
213+
typedef matrix<uint32_t, 4, 1> uint32_t4x1;
214+
typedef matrix<uint32_t, 4, 2> uint32_t4x2;
215+
typedef matrix<uint32_t, 4, 3> uint32_t4x3;
216+
typedef matrix<uint32_t, 4, 4> uint32_t4x4;
217+
typedef matrix<int64_t, 1, 1> int64_t1x1;
218+
typedef matrix<int64_t, 1, 2> int64_t1x2;
219+
typedef matrix<int64_t, 1, 3> int64_t1x3;
220+
typedef matrix<int64_t, 1, 4> int64_t1x4;
221+
typedef matrix<int64_t, 2, 1> int64_t2x1;
222+
typedef matrix<int64_t, 2, 2> int64_t2x2;
223+
typedef matrix<int64_t, 2, 3> int64_t2x3;
224+
typedef matrix<int64_t, 2, 4> int64_t2x4;
225+
typedef matrix<int64_t, 3, 1> int64_t3x1;
226+
typedef matrix<int64_t, 3, 2> int64_t3x2;
227+
typedef matrix<int64_t, 3, 3> int64_t3x3;
228+
typedef matrix<int64_t, 3, 4> int64_t3x4;
229+
typedef matrix<int64_t, 4, 1> int64_t4x1;
230+
typedef matrix<int64_t, 4, 2> int64_t4x2;
231+
typedef matrix<int64_t, 4, 3> int64_t4x3;
232+
typedef matrix<int64_t, 4, 4> int64_t4x4;
233+
typedef matrix<uint64_t, 1, 1> uint64_t1x1;
234+
typedef matrix<uint64_t, 1, 2> uint64_t1x2;
235+
typedef matrix<uint64_t, 1, 3> uint64_t1x3;
236+
typedef matrix<uint64_t, 1, 4> uint64_t1x4;
237+
typedef matrix<uint64_t, 2, 1> uint64_t2x1;
238+
typedef matrix<uint64_t, 2, 2> uint64_t2x2;
239+
typedef matrix<uint64_t, 2, 3> uint64_t2x3;
240+
typedef matrix<uint64_t, 2, 4> uint64_t2x4;
241+
typedef matrix<uint64_t, 3, 1> uint64_t3x1;
242+
typedef matrix<uint64_t, 3, 2> uint64_t3x2;
243+
typedef matrix<uint64_t, 3, 3> uint64_t3x3;
244+
typedef matrix<uint64_t, 3, 4> uint64_t3x4;
245+
typedef matrix<uint64_t, 4, 1> uint64_t4x1;
246+
typedef matrix<uint64_t, 4, 2> uint64_t4x2;
247+
typedef matrix<uint64_t, 4, 3> uint64_t4x3;
248+
typedef matrix<uint64_t, 4, 4> uint64_t4x4;
249+
250+
typedef matrix<half, 1, 1> half1x1;
251+
typedef matrix<half, 1, 2> half1x2;
252+
typedef matrix<half, 1, 3> half1x3;
253+
typedef matrix<half, 1, 4> half1x4;
254+
typedef matrix<half, 2, 1> half2x1;
255+
typedef matrix<half, 2, 2> half2x2;
256+
typedef matrix<half, 2, 3> half2x3;
257+
typedef matrix<half, 2, 4> half2x4;
258+
typedef matrix<half, 3, 1> half3x1;
259+
typedef matrix<half, 3, 2> half3x2;
260+
typedef matrix<half, 3, 3> half3x3;
261+
typedef matrix<half, 3, 4> half3x4;
262+
typedef matrix<half, 4, 1> half4x1;
263+
typedef matrix<half, 4, 2> half4x2;
264+
typedef matrix<half, 4, 3> half4x3;
265+
typedef matrix<half, 4, 4> half4x4;
266+
typedef matrix<float, 1, 1> float1x1;
267+
typedef matrix<float, 1, 2> float1x2;
268+
typedef matrix<float, 1, 3> float1x3;
269+
typedef matrix<float, 1, 4> float1x4;
270+
typedef matrix<float, 2, 1> float2x1;
271+
typedef matrix<float, 2, 2> float2x2;
272+
typedef matrix<float, 2, 3> float2x3;
273+
typedef matrix<float, 2, 4> float2x4;
274+
typedef matrix<float, 3, 1> float3x1;
275+
typedef matrix<float, 3, 2> float3x2;
276+
typedef matrix<float, 3, 3> float3x3;
277+
typedef matrix<float, 3, 4> float3x4;
278+
typedef matrix<float, 4, 1> float4x1;
279+
typedef matrix<float, 4, 2> float4x2;
280+
typedef matrix<float, 4, 3> float4x3;
281+
typedef matrix<float, 4, 4> float4x4;
282+
typedef matrix<double, 1, 1> double1x1;
283+
typedef matrix<double, 1, 2> double1x2;
284+
typedef matrix<double, 1, 3> double1x3;
285+
typedef matrix<double, 1, 4> double1x4;
286+
typedef matrix<double, 2, 1> double2x1;
287+
typedef matrix<double, 2, 2> double2x2;
288+
typedef matrix<double, 2, 3> double2x3;
289+
typedef matrix<double, 2, 4> double2x4;
290+
typedef matrix<double, 3, 1> double3x1;
291+
typedef matrix<double, 3, 2> double3x2;
292+
typedef matrix<double, 3, 3> double3x3;
293+
typedef matrix<double, 3, 4> double3x4;
294+
typedef matrix<double, 4, 1> double4x1;
295+
typedef matrix<double, 4, 2> double4x2;
296+
typedef matrix<double, 4, 3> double4x3;
297+
typedef matrix<double, 4, 4> double4x4;
298+
299+
#ifdef __HLSL_ENABLE_16_BIT
300+
typedef matrix<float16_t, 1, 1> float16_t1x1;
301+
typedef matrix<float16_t, 1, 2> float16_t1x2;
302+
typedef matrix<float16_t, 1, 3> float16_t1x3;
303+
typedef matrix<float16_t, 1, 4> float16_t1x4;
304+
typedef matrix<float16_t, 2, 1> float16_t2x1;
305+
typedef matrix<float16_t, 2, 2> float16_t2x2;
306+
typedef matrix<float16_t, 2, 3> float16_t2x3;
307+
typedef matrix<float16_t, 2, 4> float16_t2x4;
308+
typedef matrix<float16_t, 3, 1> float16_t3x1;
309+
typedef matrix<float16_t, 3, 2> float16_t3x2;
310+
typedef matrix<float16_t, 3, 3> float16_t3x3;
311+
typedef matrix<float16_t, 3, 4> float16_t3x4;
312+
typedef matrix<float16_t, 4, 1> float16_t4x1;
313+
typedef matrix<float16_t, 4, 2> float16_t4x2;
314+
typedef matrix<float16_t, 4, 3> float16_t4x3;
315+
typedef matrix<float16_t, 4, 4> float16_t4x4;
316+
#endif
317+
318+
typedef matrix<float32_t, 1, 1> float32_t1x1;
319+
typedef matrix<float32_t, 1, 2> float32_t1x2;
320+
typedef matrix<float32_t, 1, 3> float32_t1x3;
321+
typedef matrix<float32_t, 1, 4> float32_t1x4;
322+
typedef matrix<float32_t, 2, 1> float32_t2x1;
323+
typedef matrix<float32_t, 2, 2> float32_t2x2;
324+
typedef matrix<float32_t, 2, 3> float32_t2x3;
325+
typedef matrix<float32_t, 2, 4> float32_t2x4;
326+
typedef matrix<float32_t, 3, 1> float32_t3x1;
327+
typedef matrix<float32_t, 3, 2> float32_t3x2;
328+
typedef matrix<float32_t, 3, 3> float32_t3x3;
329+
typedef matrix<float32_t, 3, 4> float32_t3x4;
330+
typedef matrix<float32_t, 4, 1> float32_t4x1;
331+
typedef matrix<float32_t, 4, 2> float32_t4x2;
332+
typedef matrix<float32_t, 4, 3> float32_t4x3;
333+
typedef matrix<float32_t, 4, 4> float32_t4x4;
334+
typedef matrix<float64_t, 1, 1> float64_t1x1;
335+
typedef matrix<float64_t, 1, 2> float64_t1x2;
336+
typedef matrix<float64_t, 1, 3> float64_t1x3;
337+
typedef matrix<float64_t, 1, 4> float64_t1x4;
338+
typedef matrix<float64_t, 2, 1> float64_t2x1;
339+
typedef matrix<float64_t, 2, 2> float64_t2x2;
340+
typedef matrix<float64_t, 2, 3> float64_t2x3;
341+
typedef matrix<float64_t, 2, 4> float64_t2x4;
342+
typedef matrix<float64_t, 3, 1> float64_t3x1;
343+
typedef matrix<float64_t, 3, 2> float64_t3x2;
344+
typedef matrix<float64_t, 3, 3> float64_t3x3;
345+
typedef matrix<float64_t, 3, 4> float64_t3x4;
346+
typedef matrix<float64_t, 4, 1> float64_t4x1;
347+
typedef matrix<float64_t, 4, 2> float64_t4x2;
348+
typedef matrix<float64_t, 4, 3> float64_t4x3;
349+
typedef matrix<float64_t, 4, 4> float64_t4x4;
350+
118351
} // namespace hlsl
119352

120353
#endif //_HLSL_HLSL_BASIC_TYPES_H_

0 commit comments

Comments
 (0)