Skip to content

[mlir][spirv] Add SPV_EXT_FP8 type support to SPIR-V TOSA ops#193199

Merged
davidegrohmann merged 1 commit into
llvm:mainfrom
davidegrohmann:mlir-spv-fp8-support-spirv-tosa
Apr 22, 2026
Merged

[mlir][spirv] Add SPV_EXT_FP8 type support to SPIR-V TOSA ops#193199
davidegrohmann merged 1 commit into
llvm:mainfrom
davidegrohmann:mlir-spv-fp8-support-spirv-tosa

Conversation

@davidegrohmann
Copy link
Copy Markdown
Contributor

Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types.

Also update verifier coverage to reflect the new constraints:

  • refresh existing negative tests whose diagnostics now list FP8 types
  • add negative tests for SPV_EXT_FP8-specific output, weight, accumulator, and cast restrictions

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 21, 2026

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Davide Grohmann (davidegrohmann)

Changes

Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type definitions and extending op constraints for the newly supported element types.

Also update verifier coverage to reflect the new constraints:

  • refresh existing negative tests whose diagnostics now list FP8 types
  • add negative tests for SPV_EXT_FP8-specific output, weight, accumulator, and cast restrictions

Patch is 59.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/193199.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td (+76-52)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td (+20-13)
  • (modified) mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir (+279-5)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index c873e3069733c..4e3689e6001f6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -156,14 +156,20 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
     TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
     TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
     TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
+    TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16]>,
+    TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16]>,
     TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
     TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
     TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
+    TypeConstraintImplicationOn<"input", F8E4M3FN, "weight", [F8E4M3FN]>,
+    TypeConstraintImplicationOn<"input", F8E5M2, "weight", [F8E5M2]>,
     TypeImpliesAccType<"input", I8, ["INT32"]>,
     TypeImpliesAccType<"input", I16, ["INT48"]>,
     TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
     TypeImpliesAccType<"input", BF16, ["FP32"]>,
     TypeImpliesAccType<"input", F32, ["FP32"]>,
+    TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
+    TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
     AllElementTypesMatch<["bias", "output"]>,
     AllElementTypesMatch<["input", "input_zp"]>,
     AllElementTypesMatch<["weight", "weight_zp"]>])> {
@@ -249,7 +255,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
   let arguments = (ins
     SPIRV_TensorArmAxisAttr: $axis,
     SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm: $input
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input
   );
 
   let results = (outs
@@ -277,6 +283,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
   TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
   TypeImpliesAccType<"input", BF16, ["FP32"]>,
   TypeImpliesAccType<"input", F32, ["FP32"]>,
+  TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
+  TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
   AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
   let summary = "Performs average pooling on the input.";
 
@@ -304,13 +312,13 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
     SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
     SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $output_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $output_zp
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $output
   );
 
   let assemblyFormat = [{
@@ -361,11 +369,11 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
     SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
-    SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input,
+    SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $weight,
     SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $weight_zp
   );
 
   let results = (outs
@@ -416,11 +424,11 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
     SPIRV_I32_1DTensorArmOfLength3Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D: $input,
-    SPIRV_I8OrF16OrF32OrBF16_TensorArm5D: $weight,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm5D: $input,
+    SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm5D: $weight,
     SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $weight_zp
   );
 
   let results = (outs
@@ -472,11 +480,11 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
     SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
-    SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input,
+    SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $weight,
     SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $weight_zp
   );
 
   let results = (outs
@@ -557,6 +565,8 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
   TypeConstraintImplicationOn<"A", BF16, "output", [F32]>,
   TypeConstraintImplicationOn<"A", F16, "output", [F16, F32]>,
   TypeConstraintImplicationOn<"A", F32, "output", [F32]>,
+  TypeConstraintImplicationOn<"A", F8E4M3FN, "output", [F16]>,
+  TypeConstraintImplicationOn<"A", F8E5M2, "output", [F16]>,
   AllElementTypesMatch<["A", "A_zp", "B", "B_zp"]>]> {
   let summary = "Matrix Multiplication operator.";
 
@@ -579,10 +589,10 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
   }];
 
   let arguments = (ins
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $A,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $B,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $A_zp,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $B_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $A,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $B,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $A_zp,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $B_zp
   );
 
   let results = (outs
@@ -634,11 +644,11 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
     SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
     SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
     SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $output
   );
 
   let assemblyFormat = [{
@@ -734,11 +744,11 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
     SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
     SPIRV_TosaExtAccTypeAttr: $acc_type,
     SPIRV_BoolConstAttr: $local_bound,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
-    SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $input,
+    SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D: $weight,
     SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
-    SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
-    SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
+    SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $input_zp,
+    SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $weight_zp
   );
 
   let results = (outs
@@ -2167,11 +2177,11 @@ def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
 
   let arguments = (ins
     SPIRV_TensorArmAxisAttr: $axis,
-    Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm>: $input1
+    Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm>: $input1
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2214,13 +2224,13 @@ def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1,
     SPIRV_I32_1DTensorArmOfEvenLength2To12: $padding,
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1: $pad_const
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_1DTensorArmOfLength1: $pad_const
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2267,12 +2277,12 @@ def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1,
     SPIRV_I32_1DTensorArmOfLength1To6: $shape
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2315,11 +2325,11 @@ def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,
 
   let arguments = (ins
     SPIRV_TensorArmAxisAttr: $axis,
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2362,13 +2372,13 @@ def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1,
     SPIRV_I32_1DTensorArmOfLength1To6: $start,
     SPIRV_I32_1DTensorArmOfLength1To6: $size
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2416,12 +2426,12 @@ def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1,
     SPIRV_I32_1DTensorArmOfLength1To6: $multiples
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2466,11 +2476,11 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,
 
   let arguments = (ins
     SPIRV_I32_1DTensorArmOfLength1To6Attr: $perms,
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $input1
   );
 
   let results = (outs
-    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
   );
 
   let assemblyFormat = [{
@@ -2512,12 +2522,12 @@ def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
   }];
 
   let arguments = (ins
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values,
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $values,
     SPIRV_I32_TensorArm2D: $indices
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $output
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $output
   );
 
   let assemblyFormat = [{
@@ -2566,13 +2576,13 @@ def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
   }];
 
   let arguments = (ins
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_in,
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $values_in,
     SPIRV_I32_TensorArm2D: $indices,
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $input
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $input
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_out
+    SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D: $values_out
   );
 
   let assemblyFormat = [{
@@ -2687,13 +2697,15 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,
 
 def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
   AllShapesMatch<["input", "output"]>,
-  TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
-  TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
+  TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
+  TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16, F8E4M3FN, F8E5M2]>,
   TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
   TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
   TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
   TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
-  TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
+  TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
+  TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16, F32, BF16]>,
+  TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16, F32, BF16]>]> {
   let summary = "Cast operation.";
 
   let description = [{
@@ -2737,6 +2749,18 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
     | int16   | bf16    |
     | int32   | bf16    |
     | int8    | bf16    |
+    | bf16    | fp8e4m3 |
+    | fp8e4m3 | bf16    |
+    | bf16    | fp8e5m2 |
+    | fp8e5m2 | bf16    |
+    | float16 | fp8e4m3 |
+    | float32 | fp8e4m3 |
+    | fp8e4m3 | float16 |
+    | fp8e4m3 | float32 |
+    | float16 | fp8e5m2 |
+    | float32 | fp8e5m2 |
+    | fp8e5m2 | float16 |
+    | fp8e5m2 | float32 |
 
     References:
       * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
@@ -2750,11 +2774,11 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
   }];
 
   let arguments = (ins
-    SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm: $input
+    SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32OrF8E4M3FNOrF8E5M2_TensorArm: $input
   );
 
   let results = (outs
-    SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm: $output
+    SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm: $output
   );
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 6c918aec28845..316dc02da5b1e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -22,16 +22,20 @@ def SPIRV_I8OrI16OrI32OrI64 : AnyIntOfWidths<[8, 16, 32, 64]>;
 def SPIRV_I16OrI32 : AnyIntOfWidths<[16, 32]>;
 def SPIRV_I32OrI64 : AnyIntOfWidths<[32, 64]>;
 def SPIRV_F16OrF32OrBF16 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
-def SPIRV_I8OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR, SPIRV_Float8E4M3EXT, SPIRV_Float8E5M2EXT]>;
 def SPIRV_I8OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
 def SPIRV_I32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int32, SPIRV_F16OrF32OrBF16]>;
 def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
 def SPIRV_I32OrI64OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_F16OrF32OrBF16]>;
 def SPIRV_I32OrI64OrF16OrF32 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_Float16, SPIRV_Float32]>;
 def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32OrI64, SPIRV_F16OrF32OrBF16]>;
-def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
 def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
-def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
+def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
+def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16OrF8E4M3FNOrF8E5M2]>;
 def SPIRV_I8OrI32 : AnyTypeOf<[SPIRV_Int8, SPIRV_Int32]>;
 
 def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
@@ -57,23 +61,26 @@ def SPIRV_I32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>;
 def SPIRV_F32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
 def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [1]>;
 def SPIRV_I8OrI16_TensorArm1D : TensorArmRankOf<[SPIRV_I8OrI16], [1]>;
-def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [3]>;
-def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16], [3]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2], [3]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2], [3]>;
 def SPIRV_I32OrI64OrF16OrF32_TensorArm3D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32], [3]>;
 def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [4]>;
-def SPIRV_I8OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16], [4]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2], [4]>;
+def SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16OrF8E4M3FNOrF8E5M2], [4]>;
 def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [4]>;
 def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16], [4]>;
-def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [5]>;
-...
[truncated]

Comment thread mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td Outdated
Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared
type definitions and extending op constraints for the newly supported
element types.

Also update verifier coverage to reflect the new constraints:
- refresh existing negative tests whose diagnostics now list FP8 types
- add negative tests for SPV_EXT_FP8-specific output, weight,
  accumulator, and cast restrictions

Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
Change-Id: Ie636acc87669a66b53410e7efbd3edafa6ee0da1
@davidegrohmann davidegrohmann force-pushed the mlir-spv-fp8-support-spirv-tosa branch from 0e521c0 to 3e16658 Compare April 22, 2026 08:11
@davidegrohmann davidegrohmann merged commit 34a917a into llvm:main Apr 22, 2026
11 checks passed
@davidegrohmann davidegrohmann deleted the mlir-spv-fp8-support-spirv-tosa branch April 22, 2026 09:31
linuxlonelyeagle pushed a commit to linuxlonelyeagle/llvm-project that referenced this pull request Apr 23, 2026
…93199)

Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type
definitions and extending op constraints for the newly supported element
types.

Also update verifier coverage to reflect the new constraints:
- refresh existing negative tests whose diagnostics now list FP8 types
- add negative tests for SPV_EXT_FP8-specific output, weight,
accumulator, and cast restrictions

Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
yingopq pushed a commit to yingopq/llvm-project that referenced this pull request Apr 29, 2026
…93199)

Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type
definitions and extending op constraints for the newly supported element
types.

Also update verifier coverage to reflect the new constraints:
- refresh existing negative tests whose diagnostics now list FP8 types
- add negative tests for SPV_EXT_FP8-specific output, weight,
accumulator, and cast restrictions

Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
KHicketts pushed a commit to KHicketts/llvm-project that referenced this pull request Apr 30, 2026
…93199)

Add SPV_EXT_FP8 support for SPIR-V TOSA ops by updating the shared type
definitions and extending op constraints for the newly supported element
types.

Also update verifier coverage to reflect the new constraints:
- refresh existing negative tests whose diagnostics now list FP8 types
- add negative tests for SPV_EXT_FP8-specific output, weight,
accumulator, and cast restrictions

Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants