Skip to content

Commit f92db34

Browse files
committed
[mlir][amdgpu] Add scaled_ext_packed{8,16} operations
1 parent 8009a5b commit f92db34

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,48 @@ def AMDGPU_ExtPackedFp8Op :
112112
}];
113113
}
114114

115+
def AMDGPU_ScaledExtPacked8Op
116+
: AMDGPU_Op<"scaled_ext_packed8", [Pure]>,
117+
Arguments<(
118+
ins VectorOfLengthAndType<[8], [F4E2M1FN,F8E4M3FN,F8E5M2]>:$source,
119+
F32:$scale,
120+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
121+
Results<(
122+
outs AnyTypeOf<[FixedVectorOfLengthAndType<[8], [F32]>,
123+
FixedVectorOfLengthAndType<[8], [F16]>,
124+
FixedVectorOfLengthAndType<[8], [BF16]>]>:$res)> {
125+
let summary = "Extend a vector of packed floating point values";
126+
127+
let description = [{
128+
Extend and scale eight packed floats in to eight floats and return them.
129+
}];
130+
131+
let assemblyFormat = [{
132+
attr-dict $source `,` $scale `[` $index `]` `:` type($source) `to` type($res)
133+
}];
134+
}
135+
136+
def AMDGPU_ScaledExtPacked16Op
137+
: AMDGPU_Op<"scaled_ext_packed16", [Pure]>,
138+
Arguments<(
139+
ins VectorOfLengthAndType<[16], [F6E2M3FN, F6E3M2FN]>:$source,
140+
F32:$scale,
141+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
142+
Results<(
143+
outs AnyTypeOf<[FixedVectorOfLengthAndType<[16], [F32]>,
144+
FixedVectorOfLengthAndType<[16], [F16]>,
145+
FixedVectorOfLengthAndType<[16], [BF16]>]>:$res)> {
146+
let summary = "Extend a vector of packed floating point values";
147+
148+
let description = [{
149+
Extend and scale 16 packed floats to 16 floats and return them.
150+
}];
151+
152+
let assemblyFormat = [{
153+
attr-dict $source `,` $scale `[` $index `]` `:` type($source) `to` type($res)
154+
}];
155+
}
156+
115157
def AMDGPU_ScaledExtPackedOp
116158
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
117159
Arguments<(

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,61 @@ func.func @scaled_ext_scalar_f4e2m1_bf16(%v: vector<2xf4E2M1FN>, %scale: f32) ->
221221
func.return %ret : vector<2xbf16>
222222
}
223223

224+
// CHECK-LABEL: func.func @scaled_ext_packed8_fp4
225+
func.func @scaled_ext_packed8_fp4(%v: vector<8xf4E2M1FN>, %scale: f32) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
226+
// CHECK: amdgpu.scaled_ext_packed8
227+
%ret0 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xf16>
228+
// CHECK: amdgpu.scaled_ext_packed8
229+
%ret1 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xbf16>
230+
// CHECK: amdgpu.scaled_ext_packed8
231+
%ret2 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xf32>
232+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
233+
}
234+
235+
// CHECK-LABEL: func.func @scaled_ext_packed8_fp8
236+
func.func @scaled_ext_packed8_fp8(%v: vector<8xf8E4M3FN>, %scale: f32) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
237+
// CHECK: amdgpu.scaled_ext_packed8
238+
%ret0 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xf16>
239+
// CHECK: amdgpu.scaled_ext_packed8
240+
%ret1 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xbf16>
241+
// CHECK: amdgpu.scaled_ext_packed8
242+
%ret2 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xf32>
243+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
244+
}
245+
246+
// CHECK-LABEL: func.func @scaled_ext_packed8_bf8
247+
func.func @scaled_ext_packed8_bf8(%v: vector<8xf8E5M2>, %scale: f32) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
248+
// CHECK: amdgpu.scaled_ext_packed8
249+
%ret0 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E5M2> to vector<8xf16>
250+
// CHECK: amdgpu.scaled_ext_packed8
251+
%ret1 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E5M2> to vector<8xbf16>
252+
// CHECK: amdgpu.scaled_ext_packed8
253+
%ret2 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E5M2> to vector<8xf32>
254+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
255+
}
256+
257+
// CHECK-LABEL: func.func @scaled_ext_packed16_fp6
258+
func.func @scaled_ext_packed16_fp6(%v: vector<16xf6E2M3FN>, %scale: f32) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
259+
// CHECK: amdgpu.scaled_ext_packed16
260+
%ret0 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xf16>
261+
// CHECK: amdgpu.scaled_ext_packed16
262+
%ret1 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xbf16>
263+
// CHECK: amdgpu.scaled_ext_packed16
264+
%ret2 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xf32>
265+
func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
266+
}
267+
268+
// CHECK-LABEL: func.func @scaled_ext_packed16_bf16
269+
func.func @scaled_ext_packed16_bf16(%v: vector<16xf6E3M2FN>, %scale: f32) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
270+
// CHECK: amdgpu.scaled_ext_packed16
271+
%ret0 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xf16>
272+
// CHECK: amdgpu.scaled_ext_packed16
273+
%ret1 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xbf16>
274+
// CHECK: amdgpu.scaled_ext_packed16
275+
%ret2 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xf32>
276+
func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
277+
}
278+
224279
// CHECK-LABEL: func.func @packed_scaled_trunc_f8e4m3_f32
225280
// CHECK: amdgpu.packed_scaled_trunc
226281
func.func @packed_scaled_trunc_f8e4m3_f32(%v: vector<2xf32>, %scale: f32) -> vector<4xf8E4M3FN> {

0 commit comments

Comments
 (0)