Commit c3832b0

[mlir][amdgpu] Use existing scaled_ext_packed instead of new ops
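This change folds the new `amdgpu.scaled_ext_packed8` and `amdgpu.scaled_ext_packed16` ops into the existing `amdgpu.scaled_ext_packed` op, which now also accepts 8- and 16-element sources, makes its `index` attribute optional, and gains an optional `scaleSel` attribute. A minimal before/after sketch, taken from the updated tests:

    // Before: dedicated op for 8-wide sources.
    %r = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xf16>
    // After: the existing op handles the same conversion.
    %r = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xf16>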
1 parent f92db34 commit c3832b0

4 files changed: 60 additions & 91 deletions

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 15 additions & 47 deletions
@@ -112,60 +112,27 @@ def AMDGPU_ExtPackedFp8Op :
   }];
 }
 
-def AMDGPU_ScaledExtPacked8Op
-    : AMDGPU_Op<"scaled_ext_packed8", [Pure]>,
-      Arguments<(
-          ins VectorOfLengthAndType<[8], [F4E2M1FN,F8E4M3FN,F8E5M2]>:$source,
-          F32:$scale,
-          ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
-      Results<(
-          outs AnyTypeOf<[FixedVectorOfLengthAndType<[8], [F32]>,
-                          FixedVectorOfLengthAndType<[8], [F16]>,
-                          FixedVectorOfLengthAndType<[8], [BF16]>]>:$res)> {
-  let summary = "Extend a vector of packed floating point values";
-
-  let description = [{
-    Extend and scale eight packed floats in to eight floats and return them.
-  }];
-
-  let assemblyFormat = [{
-    attr-dict $source `,` $scale `[` $index `]` `:` type($source) `to` type($res)
-  }];
-}
-
-def AMDGPU_ScaledExtPacked16Op
-    : AMDGPU_Op<"scaled_ext_packed16", [Pure]>,
-      Arguments<(
-          ins VectorOfLengthAndType<[16], [F6E2M3FN, F6E3M2FN]>:$source,
-          F32:$scale,
-          ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
-      Results<(
-          outs AnyTypeOf<[FixedVectorOfLengthAndType<[16], [F32]>,
-                          FixedVectorOfLengthAndType<[16], [F16]>,
-                          FixedVectorOfLengthAndType<[16], [BF16]>]>:$res)> {
-  let summary = "Extend a vector of packed floating point values";
-
-  let description = [{
-    Extend and scale 16 packed floats to 16 floats and return them.
-  }];
-
-  let assemblyFormat = [{
-    attr-dict $source `,` $scale `[` $index `]` `:` type($source) `to` type($res)
-  }];
-}
-
 def AMDGPU_ScaledExtPackedOp
     : AMDGPU_Op<"scaled_ext_packed", [Pure]>,
       Arguments<(
           ins AnyTypeOf<[VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>,
-                         VectorOfLengthAndType<[1, 2, 3, 4, 5, 6, 7, 8],
-                                               [F4E2M1FN]>]>:$source,
+                         VectorOfLengthAndType<[1, 2, 3, 4, 5, 6, 7, 8],[F4E2M1FN]>,
+                         VectorOfLengthAndType<[8],[F4E2M1FN, F8E4M3FN, F8E5M2]>,
+                         VectorOfLengthAndType<[16], [F6E2M3FN, F6E3M2FN]>]>:$source,
           F32:$scale,
-          ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
+          OptionalAttr<ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>>:$index,
+          OptionalAttr<ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>>:$scaleSel)>,
       Results<(
           outs AnyTypeOf<[FixedVectorOfLengthAndType<[2], [F32]>,
                           FixedVectorOfLengthAndType<[2], [F16]>,
-                          FixedVectorOfLengthAndType<[2], [BF16]>]>:$res)> {
+                          FixedVectorOfLengthAndType<[2], [BF16]>,
+                          FixedVectorOfLengthAndType<[8], [F32]>,
+                          FixedVectorOfLengthAndType<[8], [F16]>,
+                          FixedVectorOfLengthAndType<[8], [BF16]>,
+                          FixedVectorOfLengthAndType<[16], [F32]>,
+                          FixedVectorOfLengthAndType<[16], [F16]>,
+                          FixedVectorOfLengthAndType<[16], [BF16]>]>:$res)> {
+
   let summary = "Extend a vector of packed floating point values";
 
   let description = [{
@@ -181,8 +148,9 @@ def AMDGPU_ScaledExtPackedOp
     the remaining values in the <2 x i8> will be filled with
     undefined values as needed.
   }];
+
   let assemblyFormat = [{
-    attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res)
+    attr-dict $source ( `[` $index^ `]` )? `,` $scale ( `[` $scaleSel^ `]` )? `:` type($source) `to` type($res)
   }];
 }

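Note on the updated assembly format: both the index on the source and the scale selector on the scale operand are now optional groups, so the existing 2-element spelling and the new wide spellings parse through the same op. A sketch of the two forms (the SSA names and fp8 operand values are illustrative, not taken from a specific test):

    // Existing form: `index` picks the packed pair inside the fp8 source.
    %a = amdgpu.scaled_ext_packed %v[1], %scale : vector<4xf8E4M3FN> to vector<2xf16>
    // New wide form: no `index`; the optional `scaleSel` follows the scale operand.
    %b = amdgpu.scaled_ext_packed %w, %scale[0] : vector<16xf6E2M3FN> to vector<16xf32>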
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 9 additions & 9 deletions
@@ -1510,31 +1510,31 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
   if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
   else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
   else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
   else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
   else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
   else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
   else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
   else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
     rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
-        op, destVecType, i32Source, scale, op.getIndex());
+        op, destVecType, i32Source, scale, *op.getIndex());
   else
     return failure();

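These branches cover the 2-element result paths, where the `index` attribute is required in practice, hence the direct `*op.getIndex()` dereference now that the attribute is optional. For reference, a hedged example of the kind of op these lowerings consume (operand values are illustrative):

    // Expected to map onto ROCDL::CvtScaleF32PkF16Bf8Op in this lowering.
    %r = amdgpu.scaled_ext_packed %v[1], %scale : vector<4xf8E5M2> to vector<2xf16>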
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 3 additions & 2 deletions
@@ -482,7 +482,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            VectorType::get(1, inType), in);
     // TODO: replace this with non-packed ScaledExtOp
     Value scaleExt = amdgpu::ScaledExtPackedOp::create(
-        rewriter, loc, extScaleResultType, inCast, scale, 0);
+        rewriter, loc, extScaleResultType, inCast, scale,
+        rewriter.getI32IntegerAttr(0), nullptr);
     scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
     return success();
   }
@@ -539,7 +540,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
       // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
       Value scaleExt = amdgpu::ScaledExtPackedOp::create(
           rewriter, loc, extScaleResultType, inSlice, uniformScale,
-          j / opOutWidth);
+          rewriter.getI32IntegerAttr(j / opOutWidth), nullptr);
       if (outSliceWidth < opOutWidth) {
         scaleExt = vector::ExtractStridedSliceOp::create(
             rewriter, loc, scaleExt, 0, outSliceWidth, 1);

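With the new builder signature, the rewrite passes the packing index as an explicit `I32IntegerAttr` and leaves `scaleSel` unset (`nullptr`). Roughly, the op it now creates looks like this (a sketch; the value names and element types depend on the `arith.scaling_extf` being rewritten):

    // `scaleSel` is simply omitted when the attribute is absent.
    %ext = amdgpu.scaled_ext_packed %slice[0], %scale : vector<1xf8E4M3FN> to vector<2xf16>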
mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 33 additions & 33 deletions
@@ -221,58 +221,58 @@ func.func @scaled_ext_scalar_f4e2m1_bf16(%v: vector<2xf4E2M1FN>, %scale: f32) ->
   func.return %ret : vector<2xbf16>
 }
 
-// CHECK-LABEL: func.func @scaled_ext_packed8_fp4
+// CHECK-LABEL: func.func @scaled_ext_packed8_fp
 func.func @scaled_ext_packed8_fp4(%v: vector<8xf4E2M1FN>, %scale: f32) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret0 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xf16>
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret1 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xbf16>
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret2 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xf32>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret0 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret1 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xbf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret2 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf4E2M1FN> to vector<8xf32>
   func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
 }
 
-// CHECK-LABEL: func.func @scaled_ext_packed8_fp8
+// CHECK-LABEL: func.func @scaled_ext_packed8_fp
 func.func @scaled_ext_packed8_fp8(%v: vector<8xf8E4M3FN>, %scale: f32) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret0 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xf16>
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret1 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xbf16>
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret2 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xf32>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret0 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret1 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xbf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret2 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf8E4M3FN> to vector<8xf32>
   func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
 }
 
-// CHECK-LABEL: func.func @scaled_ext_packed8_bf8
+// CHECK-LABEL: func.func @scaled_ext_packed8_bf
 func.func @scaled_ext_packed8_bf8(%v: vector<8xf8E5M2>, %scale: f32) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret0 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E5M2> to vector<8xf16>
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret1 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E5M2> to vector<8xbf16>
-  // CHECK: amdgpu.scaled_ext_packed8
-  %ret2 = amdgpu.scaled_ext_packed8 %v, %scale[0] : vector<8xf8E5M2> to vector<8xf32>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret0 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf8E5M2> to vector<8xf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret1 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf8E5M2> to vector<8xbf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret2 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<8xf8E5M2> to vector<8xf32>
   func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
 }
 
 // CHECK-LABEL: func.func @scaled_ext_packed16_fp6
 func.func @scaled_ext_packed16_fp6(%v: vector<16xf6E2M3FN>, %scale: f32) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
-  // CHECK: amdgpu.scaled_ext_packed16
-  %ret0 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xf16>
-  // CHECK: amdgpu.scaled_ext_packed16
-  %ret1 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xbf16>
-  // CHECK: amdgpu.scaled_ext_packed16
-  %ret2 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xf32>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret0 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret1 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xbf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret2 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<16xf6E2M3FN> to vector<16xf32>
   func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
 }
 
 // CHECK-LABEL: func.func @scaled_ext_packed16_bf16
 func.func @scaled_ext_packed16_bf16(%v: vector<16xf6E3M2FN>, %scale: f32) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
-  // CHECK: amdgpu.scaled_ext_packed16
-  %ret0 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xf16>
-  // CHECK: amdgpu.scaled_ext_packed16
-  %ret1 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xbf16>
-  // CHECK: amdgpu.scaled_ext_packed16
-  %ret2 = amdgpu.scaled_ext_packed16 %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xf32>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret0 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret1 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xbf16>
+  // CHECK: amdgpu.scaled_ext_packed
+  %ret2 = amdgpu.scaled_ext_packed %v, %scale[0] : vector<16xf6E3M2FN> to vector<16xf32>
   func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
 }
