Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 76 additions & 52 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,20 @@ class SPIRV_TosaConvolutionOp<string mnemonic, int opcode, list<Trait> traits =
TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>,
TypeConstraintImplicationOn<"input", F16, "output", [F16]>,
TypeConstraintImplicationOn<"input", F32, "output", [F32]>,
TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16]>,
TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16]>,
TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
TypeConstraintImplicationOn<"input", F8E4M3FN, "weight", [F8E4M3FN]>,
TypeConstraintImplicationOn<"input", F8E5M2, "weight", [F8E5M2]>,
TypeImpliesAccType<"input", I8, ["INT32"]>,
TypeImpliesAccType<"input", I16, ["INT48"]>,
TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
TypeImpliesAccType<"input", BF16, ["FP32"]>,
TypeImpliesAccType<"input", F32, ["FP32"]>,
TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
AllElementTypesMatch<["bias", "output"]>,
AllElementTypesMatch<["input", "input_zp"]>,
AllElementTypesMatch<["weight", "weight_zp"]>])> {
Expand Down Expand Up @@ -249,7 +255,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure,
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm: $input
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm: $input
);

let results = (outs
Expand Down Expand Up @@ -277,6 +283,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
TypeImpliesAccType<"input", BF16, ["FP32"]>,
TypeImpliesAccType<"input", F32, ["FP32"]>,
TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>,
TypeImpliesAccType<"input", F8E5M2, ["FP16"]>,
AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> {
let summary = "Performs average pooling on the input.";

Expand Down Expand Up @@ -304,13 +312,13 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $output_zp
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $output_zp
);

let results = (outs
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -361,11 +369,11 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> {
SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
);

let results = (outs
Expand Down Expand Up @@ -416,11 +424,11 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> {
SPIRV_I32_1DTensorArmOfLength3Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D: $input,
SPIRV_I8OrF16OrF32OrBF16_TensorArm5D: $weight,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm5D: $input,
SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm5D: $weight,
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
);

let results = (outs
Expand Down Expand Up @@ -472,11 +480,11 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4>
SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
);

let results = (outs
Expand Down Expand Up @@ -557,6 +565,8 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
TypeConstraintImplicationOn<"A", BF16, "output", [F32]>,
TypeConstraintImplicationOn<"A", F16, "output", [F16, F32]>,
TypeConstraintImplicationOn<"A", F32, "output", [F32]>,
TypeConstraintImplicationOn<"A", F8E4M3FN, "output", [F16]>,
TypeConstraintImplicationOn<"A", F8E5M2, "output", [F16]>,
AllElementTypesMatch<["A", "A_zp", "B", "B_zp"]>]> {
let summary = "Matrix Multiplication operator.";

Expand All @@ -579,10 +589,10 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect,
}];

let arguments = (ins
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $A,
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $B,
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $A_zp,
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $B_zp
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm3D: $A,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm3D: $B,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $A_zp,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $B_zp
);

let results = (outs
Expand Down Expand Up @@ -634,11 +644,11 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure,
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input
);

let results = (outs
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -734,11 +744,11 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9>
SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input,
SPIRV_I8OrF16OrF32OrBF16OrFP8_TensorArm4D: $weight,
SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp,
SPIRV_I8OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $weight_zp
);

let results = (outs
Expand Down Expand Up @@ -2167,11 +2177,11 @@ def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,

let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm>: $input1
Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm>: $input1
);

let results = (outs
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2214,13 +2224,13 @@ def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
}];

let arguments = (ins
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
SPIRV_I32_1DTensorArmOfEvenLength2To12: $padding,
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1: $pad_const
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $pad_const
);

let results = (outs
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2267,12 +2277,12 @@ def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
}];

let arguments = (ins
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
SPIRV_I32_1DTensorArmOfLength1To6: $shape
);

let results = (outs
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2315,11 +2325,11 @@ def SPIRV_TosaReverseOp : SPIRV_TosaOpWithResult<"Reverse", 57, [Pure,

let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1
);

let results = (outs
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2362,13 +2372,13 @@ def SPIRV_TosaSliceOp : SPIRV_TosaOpWithResult<"Slice", 58, [Pure,
}];

let arguments = (ins
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
SPIRV_I32_1DTensorArmOfLength1To6: $start,
SPIRV_I32_1DTensorArmOfLength1To6: $size
);

let results = (outs
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2416,12 +2426,12 @@ def SPIRV_TosaTileOp : SPIRV_TosaOpWithResult<"Tile", 59, [Pure,
}];

let arguments = (ins
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1,
SPIRV_I32_1DTensorArmOfLength1To6: $multiples
);

let results = (outs
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2466,11 +2476,11 @@ def SPIRV_TosaTransposeOp : SPIRV_TosaOpWithResult<"Transpose", 60, [Pure,

let arguments = (ins
SPIRV_I32_1DTensorArmOfLength1To6Attr: $perms,
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input1
);

let results = (outs
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2512,12 +2522,12 @@ def SPIRV_TosaGatherOp : SPIRV_TosaOpWithResult<"Gather", 61, [NoMemoryEffect,
}];

let arguments = (ins
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values,
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values,
SPIRV_I32_TensorArm2D: $indices
);

let results = (outs
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $output
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $output
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2566,13 +2576,13 @@ def SPIRV_TosaScatterOp : SPIRV_TosaOpWithResult<"Scatter", 62, [NoMemoryEffect,
}];

let arguments = (ins
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_in,
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values_in,
SPIRV_I32_TensorArm2D: $indices,
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $input
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $input
);

let results = (outs
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_out
SPIRV_I8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm3D: $values_out
);

let assemblyFormat = [{
Expand Down Expand Up @@ -2687,13 +2697,15 @@ def SPIRV_TosaResizeOp : SPIRV_TosaOpWithResult<"Resize", 63, [Pure,

def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
AllShapesMatch<["input", "output"]>,
TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8]>,
TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16]>,
TypeConstraintImplicationOn<"input", F16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
TypeConstraintImplicationOn<"input", F32, "output", [F16, I16, I32, I8, BF16, F8E4M3FN, F8E5M2]>,
TypeConstraintImplicationOn<"input", I16, "output", [F16, F32, I32, I8, SPIRV_Bool, BF16]>,
TypeConstraintImplicationOn<"input", I32, "output", [F16, F32, I16, I8, SPIRV_Bool, BF16]>,
TypeConstraintImplicationOn<"input", I8, "output", [F16, F32, I16, I32, SPIRV_Bool, BF16]>,
TypeConstraintImplicationOn<"input", SPIRV_Bool, "output", [I16, I32, I8]>,
TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8]>]> {
TypeConstraintImplicationOn<"input", BF16, "output", [F32, I16, I32, I8, F8E4M3FN, F8E5M2]>,
TypeConstraintImplicationOn<"input", F8E4M3FN, "output", [F16, F32, BF16]>,
TypeConstraintImplicationOn<"input", F8E5M2, "output", [F16, F32, BF16]>]> {
let summary = "Cast operation.";

let description = [{
Expand Down Expand Up @@ -2737,6 +2749,18 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
| int16 | bf16 |
| int32 | bf16 |
| int8 | bf16 |
| bf16 | fp8e4m3 |
| fp8e4m3 | bf16 |
| bf16 | fp8e5m2 |
| fp8e5m2 | bf16 |
| float16 | fp8e4m3 |
| float32 | fp8e4m3 |
| fp8e4m3 | float16 |
| fp8e4m3 | float32 |
| float16 | fp8e5m2 |
| float32 | fp8e5m2 |
| fp8e5m2 | float16 |
| fp8e5m2 | float32 |

References:
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_cast
Expand All @@ -2750,11 +2774,11 @@ def SPIRV_TosaCastOp : SPIRV_TosaOpWithResult<"Cast", 64, [Pure,
}];

let arguments = (ins
SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm: $input
SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16OrFP8_TensorArm: $input
);

let results = (outs
SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm: $output
SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16OrFP8_TensorArm: $output
);

let assemblyFormat = [{
Expand Down
Loading