-
Notifications
You must be signed in to change notification settings - Fork 11.1k
/
SPIRVTypes.h
494 lines (385 loc) · 18 KB
/
SPIRVTypes.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
#define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"

#include <cstdint>
#include <optional>
#include <tuple>
namespace mlir {
namespace spirv {
// Forward declarations of the storage classes that back the parametric SPIR-V
// types declared below (each is passed as the storage parameter of a
// Type::TypeBase). Only the names are needed here; the definitions are not
// part of this header.
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct JointMatrixTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct SampledImageTypeStorage;
struct StructTypeStorage;
} // namespace detail
// Base SPIR-V type for providing availability queries.
class SPIRVType : public Type {
public:
  using Type::Type;

  /// Hook for LLVM-style casting (`isa`/`cast`/`dyn_cast`). The set of types
  /// accepted is decided by the out-of-line definition.
  static bool classof(Type type);

  /// NOTE(review): judging by the name, reports whether this type is a scalar
  /// or a vector; the exact rule lives in the out-of-line definition.
  bool isScalarOrVector();

  /// The extension requirements for each type are following the
  /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
  /// convention: each inner ArrayRef lists alternatives (any one suffices),
  /// and every entry of the outer vector must be satisfied.
  using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<Extension>>;

  /// Appends to `extensions` the extensions needed for this type to appear in
  /// the given `storage` class. This method does not guarantee the uniqueness
  /// of extensions; the same extension may be appended multiple times.
  void getExtensions(ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);

  /// The capability requirements for each type are following the
  /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
  /// convention: each inner ArrayRef lists alternatives (any one suffices),
  /// and every entry of the outer vector must be satisfied.
  using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<Capability>>;

  /// Appends to `capabilities` the capabilities needed for this type to appear
  /// in the given `storage` class. This method does not guarantee the
  /// uniqueness of capabilities; the same capability may be appended multiple
  /// times.
  void getCapabilities(CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);

  /// Returns the size in bytes for each type. If no size can be calculated,
  /// returns `std::nullopt`. Note that if the type has explicit layout, it is
  /// also taken into account in calculation.
  std::optional<int64_t> getSizeInBytes();
};
// SPIR-V scalar type: bool type, integer type, floating point type.
class ScalarType : public SPIRVType {
public:
  using SPIRVType::SPIRVType;

  /// Hook for LLVM-style casting (`isa`/`cast`/`dyn_cast`).
  static bool classof(Type type);

  /// Returns true if the given float type is valid for the SPIR-V dialect.
  static bool isValid(FloatType);
  /// Returns true if the given integer type is valid for the SPIR-V dialect.
  static bool isValid(IntegerType);

  /// Scalar-type counterpart of SPIRVType::getExtensions (same parameters and
  /// append-without-deduplication contract).
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  /// Scalar-type counterpart of SPIRVType::getCapabilities (same parameters
  /// and append-without-deduplication contract).
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);

  /// Scalar-type counterpart of SPIRVType::getSizeInBytes; `std::nullopt`
  /// when no size can be computed.
  std::optional<int64_t> getSizeInBytes();
};
// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType
// (the matrix types declared later in this header also derive from it).
class CompositeType : public SPIRVType {
public:
  using SPIRVType::SPIRVType;

  /// Hook for LLVM-style casting (`isa`/`cast`/`dyn_cast`).
  static bool classof(Type type);

  /// Checks whether a builtin vector type is acceptable to the SPIR-V dialect.
  static bool isValid(VectorType);

  /// Reports whether the element count is fixed at compile time rather than
  /// being implementation dependent. Check this before getNumElements().
  bool hasCompileTimeKnownNumElements() const;

  /// Element count of the composite. Precondition:
  /// hasCompileTimeKnownNumElements() is true.
  unsigned getNumElements() const;

  /// NOTE(review): presumably the type of the element at position `index`.
  Type getElementType(unsigned index) const;

  /// Composite-type counterparts of the SPIRVType availability/size queries;
  /// same parameters and contracts as the base-class versions.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
  std::optional<int64_t> getSizeInBytes();
};
// SPIR-V array type
class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
                                        detail::ArrayTypeStorage> {
public:
  using Base::Base;

  static constexpr StringLiteral name = "spirv.array";

  /// Builds an array of `elementCount` values of `elementType` with no
  /// explicit stride (presumably reported as stride 0 — confirm in the .cpp).
  static ArrayType get(Type elementType, unsigned elementCount);
  /// Builds an array whose stride is decorated as `stride` bytes.
  static ArrayType get(Type elementType, unsigned elementCount,
                       unsigned stride);

  /// Type of every element in the array.
  Type getElementType() const;
  /// Number of elements in the array.
  unsigned getNumElements() const;
  /// Stride in bytes; 0 signals that no stride decoration is attached.
  unsigned getArrayStride() const;

  /// Array-type counterparts of the SPIRVType availability queries.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);

  /// Size of the whole array in bytes. An explicit stride decoration (in
  /// bytes) is factored into the result when present.
  std::optional<int64_t> getSizeInBytes();
};
// SPIR-V image type
class ImageType
    : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
public:
  using Base::Base;

  static constexpr StringLiteral name = "spirv.image";

  /// Convenience builder taking each image parameter separately; everything
  /// after `dim` has a default. Packs the arguments into the tuple form and
  /// forwards to the tuple-based get() below.
  static ImageType
  get(Type elementType, Dim dim,
      ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
      ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
      ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
      ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
      ImageFormat format = ImageFormat::Unknown) {
    // make_tuple deduces exactly the tuple type the overload below expects,
    // since every argument already has its final (non-reference) type.
    return ImageType::get(std::make_tuple(elementType, dim, depth, arrayed,
                                          samplingInfo, samplerUse, format));
  }

  /// Tuple-based builder (defined out of line).
  static ImageType
  get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);

  /// Accessors for the individual image parameters.
  Type getElementType() const;
  Dim getDim() const;
  ImageDepthInfo getDepthInfo() const;
  ImageArrayedInfo getArrayedInfo() const;
  ImageSamplingInfo getSamplingInfo() const;
  ImageSamplerUseInfo getSamplerUseInfo() const;
  ImageFormat getImageFormat() const;
  // TODO: Add support for Access qualifier

  /// Image-type counterparts of the SPIRVType availability queries.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V pointer type
class PointerType : public Type::TypeBase<PointerType, SPIRVType,
                                          detail::PointerTypeStorage> {
public:
  using Base::Base;

  static constexpr StringLiteral name = "spirv.pointer";

  /// Builds a pointer to `pointeeType` in the given storage class.
  static PointerType get(Type pointeeType, StorageClass storageClass);

  /// Storage class this pointer was built with.
  StorageClass getStorageClass() const;
  /// Type of the pointed-to value.
  Type getPointeeType() const;

  /// Pointer-type counterparts of the SPIRVType availability queries.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V run-time array type
class RuntimeArrayType
    : public Type::TypeBase<RuntimeArrayType, SPIRVType,
                            detail::RuntimeArrayTypeStorage> {
public:
  using Base::Base;

  static constexpr StringLiteral name = "spirv.rtarray";

  /// Builds a run-time array of `elementType` without a stride argument.
  static RuntimeArrayType get(Type elementType);
  /// Builds a run-time array whose stride is decorated as `stride` bytes.
  static RuntimeArrayType get(Type elementType, unsigned stride);

  /// Type of every element in the array.
  Type getElementType() const;
  /// Stride in bytes; 0 signals that no stride decoration is attached.
  unsigned getArrayStride() const;

  /// Run-time-array counterparts of the SPIRVType availability queries.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V sampled image type
class SampledImageType
    : public Type::TypeBase<SampledImageType, SPIRVType,
                            detail::SampledImageTypeStorage> {
public:
  using Base::Base;

  static constexpr StringLiteral name = "spirv.sampled_image";

  /// Builds a sampled image wrapping `imageType`.
  static SampledImageType get(Type imageType);

  /// Checked variant of get() that routes diagnostics through `emitError`.
  /// NOTE(review): per the usual MLIR getChecked convention this presumably
  /// returns a null type when verify() fails — confirm against the definition.
  static SampledImageType
  getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);

  /// Checks that `imageType` is acceptable for a sampled image, reporting
  /// problems through `emitError`.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              Type imageType);

  /// The wrapped image type.
  Type getImageType() const;

  /// Sampled-image counterparts of the SPIRVType availability queries.
  /// `StorageClass` is spelled unqualified, as in every sibling type; inside
  /// namespace spirv it names the same type as `spirv::StorageClass`.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
};
/// SPIR-V struct type. Two kinds of struct types are supported:
/// - Literal: a literal struct type is uniqued by its fields (types + offset
/// info + decoration info).
/// - Identified: an identified struct type is uniqued by its string identifier
/// (name). This is useful in representing recursive structs. For example, the
/// following C struct:
///
/// struct A {
///   A* next;
/// };
///
/// would be represented in MLIR as:
///
/// !spirv.struct<A, (!spirv.ptr<!spirv.struct<A>, Generic>)>
///
/// In the above, expressing recursive struct types is accomplished by giving a
/// recursive struct a unique identifier and using that identifier in the
/// struct definition for recursive references.
class StructType
    : public Type::TypeBase<StructType, CompositeType,
                            detail::StructTypeStorage, TypeTrait::IsMutable> {
public:
  using Base::Base;

  /// Type for specifying the offset of the struct members.
  using OffsetInfo = uint32_t;

  static constexpr StringLiteral name = "spirv.struct";

  /// Type for specifying the decoration(s) on struct members.
  struct MemberDecorationInfo {
    /// Index of the decorated member (stored in a 31-bit bit-field).
    uint32_t memberIndex : 31;
    /// Flag (1 bit): set when `decorationValue` holds an operand for
    /// `decoration`. NOTE(review): presumably 0 for value-less decorations.
    uint32_t hasValue : 1;
    Decoration decoration;
    uint32_t decorationValue;

    MemberDecorationInfo(uint32_t index, uint32_t hasValue,
                         Decoration decoration, uint32_t decorationValue)
        : memberIndex(index), hasValue(hasValue), decoration(decoration),
          decorationValue(decorationValue) {}

    /// Field-wise equality over every field. `hasValue` participates so that
    /// a decoration without a value is never conflated with the same
    /// decoration carrying an explicit value of 0.
    bool operator==(const MemberDecorationInfo &other) const {
      return (this->memberIndex == other.memberIndex) &&
             (this->hasValue == other.hasValue) &&
             (this->decoration == other.decoration) &&
             (this->decorationValue == other.decorationValue);
    }

    /// Orders by member index, then by decoration kind. The decoration value
    /// does not participate in the ordering.
    bool operator<(const MemberDecorationInfo &other) const {
      return this->memberIndex < other.memberIndex ||
             (this->memberIndex == other.memberIndex &&
              static_cast<uint32_t>(this->decoration) <
                  static_cast<uint32_t>(other.decoration));
    }
  };

  /// Construct a literal StructType with at least one member.
  static StructType get(ArrayRef<Type> memberTypes,
                        ArrayRef<OffsetInfo> offsetInfo = {},
                        ArrayRef<MemberDecorationInfo> memberDecorations = {});

  /// Construct an identified StructType. This creates a StructType whose body
  /// (member types, offset info, and decorations) is not set yet. A call to
  /// StructType::trySetBody(...) must follow when the StructType contents are
  /// available (e.g. parsed or deserialized).
  ///
  /// Note: If another thread creates (or had already created) a struct with the
  /// same identifier, that struct will be returned as a result.
  static StructType getIdentified(MLIRContext *context, StringRef identifier);

  /// Construct a (possibly identified) StructType with no members.
  ///
  /// Note: this method might fail in a multi-threaded setup if another thread
  /// created an identified struct with the same identifier but with different
  /// contents before returning. In which case, an empty (default-constructed)
  /// StructType is returned.
  static StructType getEmpty(MLIRContext *context, StringRef identifier = "");

  /// For literal structs, return an empty string.
  /// For identified structs, return the struct's identifier.
  StringRef getIdentifier() const;

  /// Returns true if the StructType is identified.
  bool isIdentified() const;

  /// Number of members in the struct.
  unsigned getNumElements() const;
  /// Type of the member at the given index.
  Type getElementType(unsigned) const;
  /// Types of all members, in declaration order.
  TypeRange getElementTypes() const;

  /// NOTE(review): presumably true when member offset info is attached.
  bool hasOffset() const;
  /// Offset of the given member. NOTE(review): presumably only meaningful
  /// when hasOffset() is true.
  uint64_t getMemberOffset(unsigned) const;

  /// Returns in `memberDecorations` the Decorations (apart from Offset)
  /// associated with all members of the StructType.
  void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo>
                                &memberDecorations) const;

  /// Returns in `decorationsInfo` all the Decorations (apart from Offset)
  /// associated with the `i`-th member of the StructType.
  void getMemberDecorations(
      unsigned i,
      SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;

  /// Sets the contents of an incomplete identified StructType. This method must
  /// be called only for identified StructTypes and it must be called only once
  /// per instance. Otherwise, failure() is returned.
  LogicalResult
  trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
             ArrayRef<MemberDecorationInfo> memberDecorations = {});

  /// Struct-type counterparts of the SPIRVType availability queries.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
};
/// LLVM-style hashing hook for StructType::MemberDecorationInfo, found by
/// argument-dependent lookup because it lives in the same namespace as the
/// type. Whatever the out-of-line definition hashes must be consistent with
/// MemberDecorationInfo::operator== (equal values must hash equally).
llvm::hash_code hash_value(const StructType::MemberDecorationInfo &info);
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
    : public Type::TypeBase<CooperativeMatrixType, CompositeType,
                            detail::CooperativeMatrixTypeStorage> {
public:
  using Base::Base;

  static constexpr StringLiteral name = "spirv.coopmatrix";

  /// Builds a `rows` x `columns` cooperative matrix of `elementType` with the
  /// given scope and use.
  static CooperativeMatrixType get(Type elementType, uint32_t rows,
                                   uint32_t columns, Scope scope,
                                   CooperativeMatrixUseKHR use);

  /// Type of each matrix element.
  Type getElementType() const;
  /// Row count of the matrix.
  uint32_t getRows() const;
  /// Column count of the matrix.
  uint32_t getColumns() const;
  /// Scope the matrix was built with.
  Scope getScope() const;
  /// The `use` parameter the matrix was built with.
  CooperativeMatrixUseKHR getUse() const;

  /// Cooperative-matrix counterparts of the SPIRVType availability queries.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V joint matrix type
class JointMatrixINTELType
    : public Type::TypeBase<JointMatrixINTELType, CompositeType,
                            detail::JointMatrixTypeStorage> {
public:
  using Base::Base;

  static constexpr StringLiteral name = "spirv.jointmatrix";

  /// Builds a `rows` x `columns` joint matrix of `elementType` with the given
  /// scope and layout. Note the argument order differs from
  /// CooperativeMatrixType::get (scope comes second here).
  static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
                                  unsigned columns, MatrixLayout matrixLayout);

  /// Type of each matrix element.
  Type getElementType() const;
  /// Row count of the matrix.
  unsigned getRows() const;
  /// Column count of the matrix.
  unsigned getColumns() const;
  /// Scope the joint matrix was built with.
  Scope getScope() const;
  /// Layout the matrix was built with.
  MatrixLayout getMatrixLayout() const;

  /// Joint-matrix counterparts of the SPIRVType availability queries.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
};
// SPIR-V matrix type
class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
                                         detail::MatrixTypeStorage> {
public:
  using Base::Base;

  static constexpr StringLiteral name = "spirv.matrix";

  /// Builds a matrix made of `columnCount` columns of type `columnType`.
  static MatrixType get(Type columnType, uint32_t columnCount);
  /// Checked variant of get() that routes diagnostics through `emitError`.
  static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
                               Type columnType, uint32_t columnCount);
  /// Validates the column type and count, reporting through `emitError`.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              Type columnType, uint32_t columnCount);

  /// Whether `columnType` is a vector whose elements are floats, i.e. an
  /// acceptable matrix column.
  static bool isValidColumnType(Type columnType);

  /// Type of a single column.
  Type getColumnType() const;
  /// Scalar type of a single matrix element (i.e., not the column type).
  Type getElementType() const;

  /// Row count.
  unsigned getNumRows() const;
  /// Column count.
  unsigned getNumColumns() const;
  /// Total element count, rows * columns.
  unsigned getNumElements() const;

  /// Matrix-type counterparts of the SPIRVType availability queries.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     std::optional<StorageClass> storage = std::nullopt);
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       std::optional<StorageClass> storage = std::nullopt);
};
} // namespace spirv
} // namespace mlir
#endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_