@@ -204,38 +204,21 @@ func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
204
204
return %res_0 : vector <4 xf32 >
205
205
}
206
206
207
-
208
207
// -----
209
208
210
209
// CHECK-LABEL: func @scaled_mfma_ugly_shapes
211
- // CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
212
- // CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
213
- // CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
214
- // CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
215
210
// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
216
211
// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
217
212
// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
218
213
// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
219
- func.func @scaled_mfma_ugly_shapes (%opA: vector <32 xf4 E2 M1 FN>, %opB: vector <32 xf4 E2 M1 FN>, %scalesA: vector <5 x5 xf8 E8 M0 FNU>, %scalesB: vector <7 x23 xf8 E8 M0 FNU>) -> (vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector < 4 x f32 >, vector < 4 x f32 >, vector < 4 x f32 >, vector < 4 x f32 > ) {
214
+ func.func @scaled_mfma_ugly_shapes (%opA: vector <32 xf4 E2 M1 FN>, %opB: vector <32 xf4 E2 M1 FN>, %scalesA: vector <5 x5 xf8 E8 M0 FNU>, %scalesB: vector <7 x23 xf8 E8 M0 FNU>) -> (vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >) {
220
215
%cst_0 = arith.constant dense <0.000000e+00 > : vector <4 xf32 >
221
216
%cst_1 = arith.constant dense <5.877470e-39 > : vector <4 xf8 E8 M0 FNU>
222
- %scaleA_0_0 = vector.extract %scalesA [0 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
223
- %scaleA_0_1 = vector.extract %scalesA [1 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
224
- %scaleA_0_2 = vector.extract %scalesA [2 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
225
- %scaleA_0_3 = vector.extract %scalesA [3 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
226
217
%scaleA_0_4 = vector.extract %scalesA [4 , 0 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
227
218
%scaleA_0_5 = vector.extract %scalesA [4 , 1 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
228
219
%scaleA_0_6 = vector.extract %scalesA [4 , 2 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
229
220
%scaleA_0_7 = vector.extract %scalesA [4 , 3 ] : f8E8M0FNU from vector <5 x5 xf8 E8 M0 FNU>
230
221
231
- // idx = 138 + 8 = 146 => opsel = 2
232
- %scaleB_6_8 = vector.extract %scalesB [6 , 8 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
233
- // idx = 147 => opsel = 3
234
- %scaleB_6_9 = vector.extract %scalesB [6 , 9 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
235
- // idx = 148 => opsel = 0
236
- %scaleB_6_10 = vector.extract %scalesB [6 , 10 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
237
- // idx = 149 => opsel = 1
238
- %scaleB_6_11 = vector.extract %scalesB [6 , 11 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
239
222
// idx = 160 => opsel = 3 (last idx of last 4 bytes)
240
223
%scaleB_6_22 = vector.extract %scalesB [6 , 22 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
241
224
// idx = 159 => opsel = 3
@@ -245,31 +228,19 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
245
228
// idx = 157 => opsel = 1
246
229
%scaleB_6_19 = vector.extract %scalesB [6 , 19 ] : f8E8M0FNU from vector <7 x23 xf8 E8 M0 FNU>
247
230
248
- %sA_0_0 = vector.insert %scaleA_0_0 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
249
- %sA_0_1 = vector.insert %scaleA_0_1 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
250
- %sA_0_2 = vector.insert %scaleA_0_2 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
251
- %sA_0_3 = vector.insert %scaleA_0_3 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
252
231
%sA_0_4 = vector.insert %scaleA_0_4 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
253
232
%sA_0_5 = vector.insert %scaleA_0_5 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
254
233
%sA_0_6 = vector.insert %scaleA_0_6 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
255
234
%sA_0_7 = vector.insert %scaleA_0_7 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
256
235
257
- %sB_6_8 = vector.insert %scaleB_6_8 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
258
- %sB_6_9 = vector.insert %scaleB_6_9 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
259
- %sB_6_10 = vector.insert %scaleB_6_10 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
260
- %sB_6_11 = vector.insert %scaleB_6_11 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
261
236
%sB_6_22 = vector.insert %scaleB_6_22 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
262
237
%sB_6_21 = vector.insert %scaleB_6_21 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
263
238
%sB_6_20 = vector.insert %scaleB_6_20 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
264
239
%sB_6_19 = vector.insert %scaleB_6_19 , %cst_1 [0 ] : f8E8M0FNU into vector <4 xf8 E8 M0 FNU>
265
240
266
- %res_0 = amdgpu.scaled_mfma (%sA_0_0 [0 ] * %opA ) * (%sB_6_8 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
267
- %res_1 = amdgpu.scaled_mfma (%sA_0_1 [0 ] * %opA ) * (%sB_6_9 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
268
- %res_2 = amdgpu.scaled_mfma (%sA_0_2 [0 ] * %opA ) * (%sB_6_10 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
269
- %res_3 = amdgpu.scaled_mfma (%sA_0_3 [0 ] * %opA ) * (%sB_6_11 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
270
241
%res_4 = amdgpu.scaled_mfma (%sA_0_4 [0 ] * %opA ) * (%sB_6_22 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
271
242
%res_5 = amdgpu.scaled_mfma (%sA_0_5 [0 ] * %opA ) * (%sB_6_21 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
272
243
%res_6 = amdgpu.scaled_mfma (%sA_0_6 [0 ] * %opA ) * (%sB_6_20 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
273
244
%res_7 = amdgpu.scaled_mfma (%sA_0_7 [0 ] * %opA ) * (%sB_6_19 [0 ] * %opB ) + %cst_0 {k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
274
- return %res_0 , %res_1 , %res_2 , %res_3 , % res_4 , %res_5 , %res_6 , %res_7 : vector < 4 x f32 >, vector < 4 x f32 >, vector < 4 x f32 >, vector < 4 x f32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >
245
+ return %res_4 , %res_5 , %res_6 , %res_7 : vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >
275
246
}
0 commit comments