[mlir][spirv] Add SPV_EXT_FP8 type support to SPIR-V TOSA ops#193199
Merged
davidegrohmann merged 1 commit intoApr 22, 2026
Conversation
Member
|
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Davide Grohmann (davidegrohmann) ChangesAdd SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types. Also update verifier coverage to reflect the new constraints:
Patch is 59.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/193199.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index c873e3069733c..4e3689e6001f6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -156,14 +156,20 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+ TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16]>,
+ TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16]>,
TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
+ TypeConstraintImplicationOn<"input", F8E4M3FN, "weight", [F8E4M3FN]>,
+ TypeConstraintImplicationOn<"input", F8E5M2, "weight", [F8E5M2]>,
TypeImpliesAccType<"input", I8, ["INT32"]>,
TypeImpliesAccType<"input", I16, ["INT48"]>,
TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
TypeImpliesAccType<"input", BF16, ["FP32"]>,
TypeImpliesAccType<"input", F32, ["FP32"]>,
+ TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
+ TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
AllElementTypesMatch<["bias", "output"]>,
AllElementTypesMatch<["input", "input_zp"]>,
AllElementTypesMatch<["weight", "weight_zp"]>])> {
@@ -249,7 +255,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm: $input
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input
);
let results = (outs
@@ -277,6 +283,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
TypeImpliesAccType<"input", BF16, ["FP32"]>,
TypeImpliesAccType<"input", F32, ["FP32"]>,
+ TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
+ TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
let summary = "Performs average pooling on the input.";
@@ -304,13 +312,13 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtAccTypeAttr: $acc_type,
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
- SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
- SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $output_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $output_zp
);
let results = (outs
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $output
);
let assemblyFormat = [{
@@ -361,11 +369,11 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
- SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input,
+ SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $weight,
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
- SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
- SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $weight_zp
);
let results = (outs
@@ -416,11 +424,11 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
SPIRV_I32_1DTensorArmOfLength3Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D: $input,
- SPIRV_I8OrF16OrF32OrBF16_TensorArm5D: $weight,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm5D: $input,
+ SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm5D: $weight,
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
- SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
- SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $weight_zp
);
let results = (outs
@@ -472,11 +480,11 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
- SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input,
+ SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $weight,
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
- SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
- SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $weight_zp
);
let results = (outs
@@ -557,6 +565,8 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
TypeConstraintImplicationOn<"A", BF16, "output", [F32]>,
TypeConstraintImplicationOn<"A", F16, "output", [F16, F32]>,
TypeConstraintImplicationOn<"A", F32, "output", [F32]>,
+ TypeConstraintImplicationOn<"A", F8E4M3FN, "output", [F16]>,
+ TypeConstraintImplicationOn<"A", F8E5M2, "output", [F16]>,
AllElementTypesMatch<["A", "A_zp", "B", "B_zp"]>]> {
let summary = "Matrix Multiplication operator.";
@@ -579,10 +589,10 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
}];
let arguments = (ins
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $A,
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $B,
- SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $A_zp,
- SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $B_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $A,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $B,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $A_zp,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $B_zp
);
let results = (outs
@@ -634,11 +644,11 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input
);
let results = (outs
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $output
);
let assemblyFormat = [{
@@ -734,11 +744,11 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
- SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input,
+ SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $weight,
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
- SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
- SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $weight_zp
);
let results = (outs
@@ -2167,11 +2177,11 @@ def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
- Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm>: $input1
+ Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm>: $input1
);
let results = (outs
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
);
let assemblyFormat = [{
@@ -2214,13 +2224,13 @@ def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
}];
let arguments = (ins
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1,
SPIRV_I32_1DTensorArmOfEvenLength2To12: $padding,
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1: $pad_const
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $pad_const
);
let results = (outs
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
);
let assemblyFormat = [{
@@ -2267,12 +2277,12 @@ def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
}];
let arguments = (ins
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1,
SPIRV_I32_1DTensorArmOfLength1To6: $shape
);
let results = (outs
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
);
let assemblyFormat = [{
@@ -2315,11 +2325,11 @@ def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1
);
let results = (outs
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
);
let assemblyFormat = [{
@@ -2362,13 +2372,13 @@ def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
}];
let arguments = (ins
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1,
SPIRV_I32_1DTensorArmOfLength1To6: $start,
SPIRV_I32_1DTensorArmOfLength1To6: $size
);
let results = (outs
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
);
let assemblyFormat = [{
@@ -2416,12 +2426,12 @@ def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
}];
let arguments = (ins
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1,
SPIRV_I32_1DTensorArmOfLength1To6: $multiples
);
let results = (outs
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
);
let assemblyFormat = [{
@@ -2466,11 +2476,11 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
let arguments = (ins
SPIRV_I32_1DTensorArmOfLength1To6Attr: $perms,
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1
);
let results = (outs
- SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
);
let assemblyFormat = [{
@@ -2512,12 +2522,12 @@ def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
}];
let arguments = (ins
- SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values,
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $values,
SPIRV_I32_TensorArm2D: $indices
);
let results = (outs
- SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $output
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $output
);
let assemblyFormat = [{
@@ -2566,13 +2576,13 @@ def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
}];
let arguments = (ins
- SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_in,
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $values_in,
SPIRV_I32_TensorArm2D: $indices,
- SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $input
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $input
);
let results = (outs
- SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_out
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $values_out
);
let assemblyFormat = [{
@@ -2687,13 +2697,15 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
AllShapesMatch<["input", "output"]>,
- TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
- TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
+ TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
+ TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16, F8E4M3FN, F8E5M2]>,
TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
- TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
+ TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
+ TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16, F32, BF16]>,
+ TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16, F32, BF16]>]> {
let summary = "Cast operation.";
let description = [{
@@ -2737,6 +2749,18 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
| int16 | bf16 |
| int32 | bf16 |
| int8 | bf16 |
+ | bf16 | fp8e4m3 |
+ | fp8e4m3 | bf16 |
+ | bf16 | fp8e5m2 |
+ | fp8e5m2 | bf16 |
+ | float16 | fp8e4m3 |
+ | float32 | fp8e4m3 |
+ | fp8e4m3 | float16 |
+ | fp8e4m3 | float32 |
+ | float16 | fp8e5m2 |
+ | float32 | fp8e5m2 |
+ | fp8e5m2 | float16 |
+ | fp8e5m2 | float32 |
References:
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
@@ -2750,11 +2774,11 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
}];
let arguments = (ins
- SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm: $input
+ SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32OrF8E4M3FNOrF8E5M2_TensorArm: $input
);
let results = (outs
- SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm: $output
+ SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
);
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 6c918aec28845..316dc02da5b1e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -22,16 +22,20 @@ def SPIRV_I8OrI16OrI32OrI64 : AnyIntOfWidths<[8, 16, 32, 64]>;
def SPIRV_I16OrI32 : AnyIntOfWidths<[16, 32]>;
def SPIRV_I32OrI64 : AnyIntOfWidths<[32, 64]>;
def SPIRV_F16OrF32OrBF16 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
-def SPIRV_I8OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR, SPIRV_Float8E4M3EXT, SPIRV_Float8E5M2EXT]>;
def SPIRV_I8OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
def SPIRV_I32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int32, SPIRV_F16OrF32OrBF16]>;
def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
def SPIRV_I32OrI64OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_F16OrF32OrBF16]>;
def SPIRV_I32OrI64OrF16OrF32 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_Float16, SPIRV_Float32]>;
def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32OrI64, SPIRV_F16OrF32OrBF16]>;
-def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
-def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
+def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
+def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
def SPIRV_I8OrI32 : AnyTypeOf<[SPIRV_Int8, SPIRV_Int32]>;
def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
@@ -57,23 +61,26 @@ def SPIRV_I32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>;
def SPIRV_F32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [1]>;
def SPIRV_I8OrI16_TensorArm1D : TensorArmRankOf<[SPIRV_I8OrI16], [1]>;
-def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [3]>;
-def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16], [3]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2], [3]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2], [3]>;
def SPIRV_I32OrI64OrF16OrF32_TensorArm3D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32], [3]>;
def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [4]>;
-def SPIRV_I8OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16], [4]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2], [4]>;
+def SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2], [4]>;
def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [4]>;
def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16], [4]>;
-def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [5]>;
-...
[truncated]
|
IgWod
reviewed
Apr 21, 2026
Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types. Also update verifier coverage to reflect the new constraints: - refresh existing negative tests whose diagnostics now list FP8 types - add negative tests for SPV_EXT_FP8-specific output, weight, accumulator, and cast restrictions Signed-off-by: Davide Grohmann <davide.grohmann@arm.com> Change-Id: Ie636acc87669a66b53410e7efbd3edafa6ee0da1
0e521c0 to
3e16658
Compare
IgWod
approved these changes
Apr 22, 2026
linuxlonelyeagle
pushed a commit
to linuxlonelyeagle/llvm-project
that referenced
this pull request
Apr 23, 2026
…93199) Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types. Also update verifier coverage to reflect the new constraints: - refresh existing negative tests whose diagnostics now list FP8 types - add negative tests for SPV_EXT_FP8-specific output, weight, accumulator, and cast restrictions Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
yingopq
pushed a commit
to yingopq/llvm-project
that referenced
this pull request
Apr 29, 2026
…93199) Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types. Also update verifier coverage to reflect the new constraints: - refresh existing negative tests whose diagnostics now list FP8 types - add negative tests for SPV_EXT_FP8-specific output, weight, accumulator, and cast restrictions Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
KHicketts
pushed a commit
to KHicketts/llvm-project
that referenced
this pull request
Apr 30, 2026
…93199) Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types. Also update verifier coverage to reflect the new constraints: - refresh existing negative tests whose diagnostics now list FP8 types - add negative tests for SPV_EXT_FP8-specific output, weight, accumulator, and cast restrictions Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types.
Also update verifier coverage to reflect the new constraints: