/
AMX.td
313 lines (271 loc) · 11.1 KB
/
AMX.td
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the AMX dialect.
//
// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
// multiply unit (TMUL), a tile control register (TILECFG), and eight
// tile registers TMM0 through TMM7 (TILEDATA).
//
// The AMX dialect provides a bridge between MLIR concepts, such as
// 2-d vectors, operations, and memrefs, and the lower level details
// of Intel AMX, such as configuration setup, tile sizes, instructions,
// and tile release.
//
// Note that since configuration changes (implicit at dialect level) are
// costly, it is highly recommended to use the AMX dialect on same-shaped
// vectors, at least within a single method.
//
// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
//
//===----------------------------------------------------------------------===//
#ifndef AMX
#define AMX
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
// AMX dialect definition.
//===----------------------------------------------------------------------===//
// Dialect record for AMX. Operations of this dialect are spelled
// `amx.<mnemonic>` in the IR and are associated with the `::mlir::amx`
// C++ namespace.
def AMX_Dialect : Dialect {
  let cppNamespace = "::mlir::amx";
  let name = "amx";

  // Long-form dialect documentation. The text is kept flush-left so that the
  // contents of the literal are exactly what they were before reformatting.
  let description = [{
The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
multiply unit (TMUL), a tile control register (TILECFG), and eight
tile registers TMM0 through TMM7 (TILEDATA).
This `AMX` dialect provides a bridge between MLIR concepts such as
vectors and memrefs and the lower level LLVM IR support of AMX.
The dialect is split into user-facing AMX ops (AMX_Op) and
backend-facing intrinsic ops (AMX_IntrOp).
Note that since configuration changes (implicit at dialect level) are
costly, it is highly recommended to use the AMX dialect on same-shaped
vectors, at least within a single method.
For details, see the Intel documentation:
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
}];
}
//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
// Base class for the user-facing AMX operations. It only fixes the dialect;
// the mnemonic and the (optional, default empty) trait list are forwarded to
// `Op` unchanged.
class AMX_Op<string mnemonic, list<Trait> traits = []>
    : Op<AMX_Dialect, mnemonic, traits>;
// Base class for the backend-facing intrinsic ops. The "internal" intrinsics
// are meant for compiler usage.
//
// The third argument handed to `LLVM_IntrOpBase` is derived from the
// mnemonic: every "." is replaced by "_" and the result is wrapped as
// "x86_<mnemonic>_internal". The two empty lists, `traits`, and `numResults`
// are passed through positionally.
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []>
    : LLVM_IntrOpBase<
          AMX_Dialect,
          mnemonic,
          !strconcat("x86_", !subst(".", "_", mnemonic), "_internal"),
          [],
          [],
          traits,
          numResults>;
//===----------------------------------------------------------------------===//
// AMX Op definitions (user facing).
//===----------------------------------------------------------------------===//
//
// Tile reset.
//
// `amx.tile_zero`: takes no operands and yields a rank-2 vector whose element
// type is one of f32, bf16, i32, or i8. Declared `Pure`; additional checks
// live in a hand-written verifier (`hasVerifier`), which is not part of this
// file.
def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
  let summary = "tile zero operation";

  // Description text is kept flush-left so the literal is unchanged.
  let description = [{
Zeroes the destination tile, with the shape defined by the 2-dim
vector type of the result. This is eventually lowered into the
"tilezero" instruction with the corresponding tile configuration.
Example:
```mlir
%0 = amx.tile_zero : vector<16x16xbf16>
```
}];

  let results = (outs
    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res
  );

  // Only the attribute dictionary and the result type are spelled out.
  let assemblyFormat = "attr-dict `:` type($res)";
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Returns the type of `$res` as a VectorType.
    VectorType getVectorType() {
      return ::llvm::cast<VectorType>(getRes().getType());
    }
  }];
}
//
// Tile memory operations.
//
// `amx.tile_load`: reads a tile from `$base` at `$indices` and yields a rank-2
// vector of f32, bf16, i32, or i8. Additional checks live in a hand-written
// verifier (`hasVerifier`), which is not part of this file.
//
// NOTE(review): the op is declared `Pure` while `$base` also carries a
// `MemRead` decorator -- confirm that this combination is intended.
def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
  let summary = "tile load operation";

  // Description text is kept flush-left so the literal is unchanged.
  let description = [{
Loads a tile from memory defined by a base and indices, with the
shape defined by the 2-dim vector type of the result. This is
eventually lowered into the "tileloadd" instruction with the
corresponding tile configuration.
Example:
```mlir
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
```
}];

  let arguments = (ins
    Arg<AnyMemRef, "load base", [MemRead]>:$base,
    Variadic<Index>:$indices
  );
  let results = (outs
    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res
  );

  let assemblyFormat =
      "$base `[` $indices `]` attr-dict `:` type($base) `into` type($res)";
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Returns the type of `$base` as a MemRefType.
    MemRefType getMemRefType() {
      return ::llvm::cast<MemRefType>(getBase().getType());
    }
    /// Returns the type of `$res` as a VectorType.
    VectorType getVectorType() {
      return ::llvm::cast<VectorType>(getRes().getType());
    }
  }];
}
// `amx.tile_store`: writes the rank-2 vector `$val` (f32, bf16, i32, or i8)
// to `$base` at `$indices`. The op has no results and no op-level traits; the
// write is recorded through the `MemWrite` decorator on `$base`. Additional
// checks live in a hand-written verifier (`hasVerifier`), which is not part
// of this file.
def TileStoreOp : AMX_Op<"tile_store"> {
  let summary = "tile store operation";

  // Description text is kept flush-left so the literal is unchanged.
  let description = [{
Stores a tile to memory defined by a base and indices, with the
shape defined by the 2-dim vector type of the value. This is
eventually lowered into the "tilestored" instruction with the
corresponding tile configuration.
Example:
```mlir
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
```
}];

  let arguments = (ins
    Arg<AnyMemRef, "store base", [MemWrite]>:$base,
    Variadic<Index>:$indices,
    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val
  );

  let assemblyFormat =
      "$base `[` $indices `]` `,` $val "
      "attr-dict `:` type($base) `,` type($val)";
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Returns the type of `$base` as a MemRefType.
    MemRefType getMemRefType() {
      return ::llvm::cast<MemRefType>(getBase().getType());
    }
    /// Returns the type of `$val` as a VectorType.
    VectorType getVectorType() {
      return ::llvm::cast<VectorType>(getVal().getType());
    }
  }];
}
//
// Tile arithmetic operations.
//
// `amx.tile_mulf`: floating-point tile multiply-accumulate over three rank-2
// vector operands. The result type is not spelled in the assembly; the
// `AllTypesMatch` trait ties it to the type of `$acc`.
//
// NOTE(review): the ODS constraints admit f32 or bf16 on every operand,
// whereas the description only mentions "f32 <- bf16 x bf16"; presumably the
// hand-written verifier (`hasVerifier`, not part of this file) narrows this --
// confirm.
def TileMulFOp : AMX_Op<"tile_mulf", [
    Pure,
    AllTypesMatch<["acc", "res"]>
  ]> {
  let summary = "tile multiplication operation (floating-point)";

  // Description text is kept flush-left so the literal is unchanged.
  let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
pairs of "bf16"). The operation is eventually lowered into the
"tdpbf16ps" instruction with the corresponding tile configuration.
Example:
```mlir
%0 = amx.tile_mulf %a, %b, %c
: vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
```
}];

  let arguments = (ins
    VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
    VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
    VectorOfRankAndType<[2], [F32, BF16]>:$acc
  );
  let results = (outs
    VectorOfRankAndType<[2], [F32, BF16]>:$res
  );

  let assemblyFormat =
      "$lhs `,` $rhs `,` $acc "
      "attr-dict `:` "
      "type($lhs) `,` type($rhs) `,` type($acc) ";
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Returns the type of `$lhs` as a VectorType.
    VectorType getLhsVectorType() {
      return ::llvm::cast<VectorType>(getLhs().getType());
    }
    /// Returns the type of `$rhs` as a VectorType.
    VectorType getRhsVectorType() {
      return ::llvm::cast<VectorType>(getRhs().getType());
    }
    /// Returns the type of `$res` as a VectorType.
    VectorType getVectorType() {
      return ::llvm::cast<VectorType>(getRes().getType());
    }
  }];
}
// `amx.tile_muli`: integer tile multiply-accumulate over three rank-2 vector
// operands (i32 or i8 per the ODS constraints). Two unit attributes,
// `isZextLhs` and `isZextRhs`, are surfaced in the assembly as an optional
// `zext` keyword after the corresponding operand. The result type is not
// spelled in the assembly; the `AllTypesMatch` trait ties it to the type of
// `$acc`. Additional checks live in a hand-written verifier (`hasVerifier`),
// which is not part of this file.
def TileMulIOp : AMX_Op<"tile_muli", [
    Pure,
    AllTypesMatch<["acc", "res"]>
  ]> {
  let summary = "tile multiplication operation (integer)";

  // Description text is kept flush-left so the literal is unchanged.
  let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
combinations (4 bytes packed into dwords in the columns of both the
source operand tiles; the zero or sign extension is specified with
the attributes and default to sign extended). The operation is eventually
lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
instructions with the corresponding tile configuration.
Example:
```mlir
%0 = amx.tile_muli %a zext, %b zext, %c
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
```
}];

  let arguments = (ins
    VectorOfRankAndType<[2], [I32, I8]>:$lhs,
    VectorOfRankAndType<[2], [I32, I8]>:$rhs,
    VectorOfRankAndType<[2], [I32, I8]>:$acc,
    UnitAttr:$isZextLhs,
    UnitAttr:$isZextRhs
  );
  let results = (outs
    VectorOfRankAndType<[2], [I32, I8]>:$res
  );

  let assemblyFormat =
      "$lhs (`zext` $isZextLhs^)? `,` "
      "$rhs (`zext` $isZextRhs^)? `,` "
      "$acc attr-dict `:` "
      "type($lhs) `,` type($rhs) `,` type($acc) ";
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Returns the type of `$lhs` as a VectorType.
    VectorType getLhsVectorType() {
      return ::llvm::cast<VectorType>(getLhs().getType());
    }
    /// Returns the type of `$rhs` as a VectorType.
    VectorType getRhsVectorType() {
      return ::llvm::cast<VectorType>(getRhs().getType());
    }
    /// Returns the type of `$res` as a VectorType.
    VectorType getVectorType() {
      return ::llvm::cast<VectorType>(getRes().getType());
    }
  }];
}
//===----------------------------------------------------------------------===//
// AMX IntrOp definitions (LLVM compiler facing).
//===----------------------------------------------------------------------===//
//
// Tile reset. Parameters define the tile size.
//
// Intrinsic op "tilezero": one result, two integer operands that define the
// tile size.
def LLVM_x86_amx_tilezero
    : AMX_IntrOp<"tilezero", 1>,
      Arguments<(ins
        AnyInteger,
        AnyInteger)>;
//
// Tile memory operations. Parameters define the tile size,
// base address, and stride between consecutive rows for the
// memory operation.
//
// Intrinsic op "tileloadd64": one result. Operands are two integers, a
// pointer, and one more integer -- presumably tile size, base address, and
// row stride in that order (TODO confirm against the lowering).
def LLVM_x86_amx_tileloadd64
    : AMX_IntrOp<"tileloadd64", 1>,
      Arguments<(ins
        AnyInteger,
        AnyInteger,
        LLVM_AnyPointer,
        AnyInteger)>;
// Intrinsic op "tilestored64": no results. Operands are two integers, a
// pointer, one more integer, and a trailing LLVM-typed value -- presumably
// tile size, base address, row stride, and the tile being stored (TODO
// confirm against the lowering).
def LLVM_x86_amx_tilestored64
    : AMX_IntrOp<"tilestored64", 0>,
      Arguments<(ins
        AnyInteger,
        AnyInteger,
        LLVM_AnyPointer,
        AnyInteger,
        LLVM_Type)>;
//
// Tile multiplication operations (series of dot products). Parameters
// define the tile sizes and source and destination tiles for the
// operation. Note that the prefix "tdp" stands for tile dot product.
//
// Intrinsic op "tdpbf16ps": dot product of bf16 tiles into f32 tile.
// One result; operands are three integers followed by three LLVM-typed
// values (presumably the tile sizes and then the tiles -- TODO confirm order).
def LLVM_x86_amx_tdpbf16ps
    : AMX_IntrOp<"tdpbf16ps", 1>,
      Arguments<(ins
        AnyInteger,
        AnyInteger,
        AnyInteger,
        LLVM_Type,
        LLVM_Type,
        LLVM_Type)>;
// Intrinsic op "tdpbssd": dot product of i8 tiles into i32 tile (with
// sign/sign extension). One result; operands are three integers followed by
// three LLVM-typed values (presumably the tile sizes and then the tiles --
// TODO confirm order).
def LLVM_x86_amx_tdpbssd
    : AMX_IntrOp<"tdpbssd", 1>,
      Arguments<(ins
        AnyInteger,
        AnyInteger,
        AnyInteger,
        LLVM_Type,
        LLVM_Type,
        LLVM_Type)>;
// Intrinsic op "tdpbsud": dot product of i8 tiles into i32 tile (with
// sign/zero extension). One result; operands are three integers followed by
// three LLVM-typed values (presumably the tile sizes and then the tiles --
// TODO confirm order).
def LLVM_x86_amx_tdpbsud
    : AMX_IntrOp<"tdpbsud", 1>,
      Arguments<(ins
        AnyInteger,
        AnyInteger,
        AnyInteger,
        LLVM_Type,
        LLVM_Type,
        LLVM_Type)>;
// Intrinsic op "tdpbusd": dot product of i8 tiles into i32 tile (with
// zero/sign extension). One result; operands are three integers followed by
// three LLVM-typed values (presumably the tile sizes and then the tiles --
// TODO confirm order).
def LLVM_x86_amx_tdpbusd
    : AMX_IntrOp<"tdpbusd", 1>,
      Arguments<(ins
        AnyInteger,
        AnyInteger,
        AnyInteger,
        LLVM_Type,
        LLVM_Type,
        LLVM_Type)>;
// Intrinsic op "tdpbuud": dot product of i8 tiles into i32 tile (with
// zero/zero extension). One result; operands are three integers followed by
// three LLVM-typed values (presumably the tile sizes and then the tiles --
// TODO confirm order).
def LLVM_x86_amx_tdpbuud
    : AMX_IntrOp<"tdpbuud", 1>,
      Arguments<(ins
        AnyInteger,
        AnyInteger,
        AnyInteger,
        LLVM_Type,
        LLVM_Type,
        LLVM_Type)>;
#endif // AMX