-
Notifications
You must be signed in to change notification settings - Fork 10.8k
/
GPUDialect.h
179 lines (147 loc) · 6.63 KB
/
GPUDialect.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
//===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the GPU kernel-related operations and puts them in the
// corresponding dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_GPU_IR_GPUDIALECT_H
#define MLIR_DIALECT_GPU_IR_GPUDIALECT_H
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
namespace gpu {
/// Utility class for the GPU dialect to represent triples of `Value`s
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
///
/// This is a plain aggregate of three `Value` members; it declares no
/// constructors or member functions of its own.
struct KernelDim3 {
/// Component accessed as `.x`.
Value x;
/// Component accessed as `.y`.
Value y;
/// Component accessed as `.z`.
Value z;
};
/// Type class of the GPU dialect's async token.
///
/// The type is built on `Type::TypeBase` with the plain `TypeStorage` as its
/// storage class, i.e. it declares no storage class or parameters of its own.
/// NOTE(review): presumably this is the C++ class behind the
/// `gpu.async.token` values mentioned on `addAsyncDependency` - confirm
/// against the dialect definition.
class AsyncTokenType
: public Type::TypeBase<AsyncTokenType, Type, TypeStorage> {
public:
// Inherit the constructors of `Base`; used for generic hooks in TypeBase.
using Base::Base;
};
/// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
/// and type.
struct MMAMatrixStorageType : public TypeStorage {
MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
Type elementType, StringRef operand)
: dimShapes(dimShapes), numDims(numDims), elementType(elementType),
operand(operand) {}
/// The hash key for uniquing.
using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(getShape(), elementType, operand);
}
/// Construction.
static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
StringRef operand = allocator.copyInto(std::get<2>(key));
return new (allocator.allocate<MMAMatrixStorageType>())
MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
operand);
}
ArrayRef<int64_t> getShape() const {
return ArrayRef<int64_t>(dimShapes, numDims);
}
StringRef getOperand() const { return operand; }
/// Reference to the shape of the MMA matrix.
const int64_t *dimShapes;
/// Number of dimensions in the MMA matrix.
unsigned numDims;
/// Element type of elements held in the MMA matrix.
Type elementType;
/// MMA operand that this MMAMatrix holds. The general form of operation this
/// type supports is given by the equation C += A*B. This field specifies
/// which operand in the given equation is held by this type. The valid values
/// are "AOp", "BOp" and "COp".
StringRef operand;
};
/// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
/// accumulate operations. MMAMatrices are taken as direct operands by these
/// operations and are also produced as results. These matrices are meant to
/// reside in the registers. A limited number of pointwise operations can be
/// performed on these matrices, i.e., operations which operate uniformly on
/// all the elements in the matrix and do not change the order of matrix
/// elements. The above conditions exist because the layout of matrix elements
/// inside the matrix is opaque, i.e., the elements may be present in the
/// matrix in any order. The general usage of this type is shown as follows:
///
/// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 :
/// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// The MMAMatrixType describes the shape of the matrix being loaded and the
/// operand being loaded too. The operand needs to be specified to aid the
/// lowering of this type to dialects such as NVVM where each workitem may
/// hold a different number of elements depending on the elementType of the
/// matrix. For example, each workitem holds 4 vector<2xf16>s for the f16 data
/// type and 8 f32s for the f32 data type of MMAMatrix. Some other instances of
/// usage are:
///
/// %3 = gpu.subgroup_mma_compute %0, %1, %2 :
/// !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
/// -> !gpu.mma_matrix<16x16xf32, "COp">
///
///
/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
// TODO: consider moving this to ODS.
class MMAMatrixType
: public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
public:
// Inherit the constructors of `Base`.
using Base::Base;
/// Get an MMAMatrixType with the given shape, element type and operand, and
/// verify construction invariants.
static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
/// Get an MMAMatrixType at a particular location and verify construction
/// invariants. `emitError` is the callback that produces the diagnostic
/// to report a violated invariant with.
static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
/// Check if a type is a valid MMAMatrixType elementType.
static bool isValidElementType(Type elementType);
/// Verify that shape and elementType are actually allowed for the
/// MMAMatrixType. `operand` is part of the signature as well.
/// NOTE(review): which operand strings are rejected is decided by the
/// out-of-line definition, which is not visible in this header.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
/// Get number of dims.
unsigned getNumDims() const;
/// Get shape of the matrix.
ArrayRef<int64_t> getShape() const;
/// Get elementType of a single element.
Type getElementType() const;
/// The general form of operation this type supports is given by the equation
/// C += A*B. This function returns which operand in the given equation is
/// held by this type. The string returned can be one of "AOp", "BOp" and
/// "COp".
StringRef getOperand() const;
};
/// Adds a `gpu.async.token` to the front of the argument list.
///
/// \param op    the operation passed in by the caller; presumably the one
///              whose argument list is extended - the definition is not in
///              this header, confirm there.
/// \param token the value to add; presumably the `gpu.async.token` mentioned
///              above.
void addAsyncDependency(Operation *op, Value token);
} // namespace gpu
} // namespace mlir
// Generated declarations for the GPU dialect. These `.h.inc` files are
// included after the `mlir::gpu` namespace above has been closed, so the
// hand-written types declared in this header are visible to them.
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.h.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsDialect.h.inc"
#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc"
// NOTE(review): presumably this macro selects the attribute class
// declarations section of the generated file - confirm against the generated
// header.
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc"
// NOTE(review): presumably this macro selects the operation class
// declarations section of the generated file - confirm against the generated
// header.
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.h.inc"
#endif // MLIR_DIALECT_GPU_IR_GPUDIALECT_H