diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td index 406fb43aaa4e8..7d4795ac40e93 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td @@ -83,6 +83,40 @@ class SPIRV_TosaElementwiseBinaryOp tra AllElementTypesMatch<["input1", "output"]>])> { } +class SPIRV_TosaConvolutionOp traits = []> : + SPIRV_TosaOpWithResult, + TypeConstraintImplicationOn<"input", I16, "output", [I64]>, + TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>, + TypeConstraintImplicationOn<"input", F16, "output", [F16]>, + TypeConstraintImplicationOn<"input", F32, "output", [F32]>, + TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>, + TypeConstraintImplicationOn<"input", F16, "weight", [F16]>, + TypeConstraintImplicationOn<"input", F32, "weight", [F32]>, + TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>, + TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>, + TypeImpliesAccType<"input", I8, ["INT32"]>, + TypeImpliesAccType<"input", I16, ["INT48"]>, + TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>, + TypeImpliesAccType<"input", BF16, ["FP32"]>, + TypeImpliesAccType<"input", F32, ["FP32"]>, + AllElementTypesMatch<["bias", "output"]>, + AllElementTypesMatch<["input", "input_zp"]>, + AllElementTypesMatch<["weight", "weight_zp"]>])> { + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + ::mlir::spirv::TensorArmType getWeightType() { + return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); + } + ::mlir::spirv::TensorArmType getBiasType() { + return cast<::mlir::spirv::TensorArmType>(getBias().getType()); + } + }]; +} + def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure, OutputRankIsInputRankMinusOne<"input", "output">, @@ -190,22 +224,7 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [Pure, } -def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure, - TypeConstraintImplicationOn<"input", I8, "output", [I32]>, - TypeConstraintImplicationOn<"input", I16, "output", [I64]>, - TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>, - TypeConstraintImplicationOn<"input", F16, "output", [F16]>, - TypeConstraintImplicationOn<"input", F32, "output", [F32]>, - TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>, - TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>, - TypeImpliesAccType<"input", I8, ["INT32"]>, - TypeImpliesAccType<"input", I16, ["INT48"]>, - TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>, - TypeImpliesAccType<"input", BF16, ["FP32"]>, - TypeImpliesAccType<"input", F32, ["FP32"]>, - AllElementTypesMatch<["bias", "output"]>, - AllElementTypesMatch<["input", "input_zp"]>, - AllElementTypesMatch<["weight", "weight_zp"]>]> { +def SPIRV_TosaConv2DOp : SPIRV_TosaConvolutionOp<"Conv2D", 2> { let summary = "2D Convolution operator."; let description = [{ @@ -257,36 +276,10 @@ def SPIRV_TosaConv2DOp : SPIRV_TosaOpWithResult<"Conv2D", 2, [Pure, attr-dict `:` type(operands) `->` type(results) }]; - let extraClassDeclaration = extraBaseClassDeclaration#[{ - ::mlir::spirv::TensorArmType getInputType() { - return cast<::mlir::spirv::TensorArmType>(getInput().getType()); - } - ::mlir::spirv::TensorArmType getWeightType() { - return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); - } - ::mlir::spirv::TensorArmType getBiasType() { - return cast<::mlir::spirv::TensorArmType>(getBias().getType()); - } - }]; } -def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure, - TypeConstraintImplicationOn<"input", I8, "output", [I32]>, - TypeConstraintImplicationOn<"input", I16, "output", [I64]>, - TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>, - TypeConstraintImplicationOn<"input", F16, "output", [F16]>, - TypeConstraintImplicationOn<"input", F32, "output", [F32]>, - TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>, - TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>, - TypeImpliesAccType<"input", I8, ["INT32"]>, - TypeImpliesAccType<"input", I16, ["INT48"]>, - TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>, - TypeImpliesAccType<"input", BF16, ["FP32"]>, - TypeImpliesAccType<"input", F32, ["FP32"]>, - AllElementTypesMatch<["bias", "output"]>, - AllElementTypesMatch<["input", "input_zp"]>, - AllElementTypesMatch<["weight", "weight_zp"]>]> { +def SPIRV_TosaConv3DOp : SPIRV_TosaConvolutionOp<"Conv3D", 3> { let summary = "3D Convolution operator."; let description = [{ @@ -337,36 +330,10 @@ def SPIRV_TosaConv3DOp : SPIRV_TosaOpWithResult<"Conv3D", 3, [Pure, attr-dict `:` type(operands) `->` type(results) }]; - let extraClassDeclaration = extraBaseClassDeclaration#[{ - ::mlir::spirv::TensorArmType getInputType() { - return cast<::mlir::spirv::TensorArmType>(getInput().getType()); - } - ::mlir::spirv::TensorArmType getWeightType() { - return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); - } - ::mlir::spirv::TensorArmType getBiasType() { - return cast<::mlir::spirv::TensorArmType>(getBias().getType()); - } - }]; } -def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [Pure, - TypeConstraintImplicationOn<"input", I8, "output", [I32]>, - TypeConstraintImplicationOn<"input", I16, "output", [I64]>, - TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>, - TypeConstraintImplicationOn<"input", F16, "output", [F16]>, - TypeConstraintImplicationOn<"input", F32, "output", [F32]>, - TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>, - TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>, - TypeImpliesAccType<"input", I8, ["INT32"]>, - TypeImpliesAccType<"input", I16, ["INT48"]>, - TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>, - TypeImpliesAccType<"input", BF16, ["FP32"]>, - TypeImpliesAccType<"input", F32, ["FP32"]>, - AllElementTypesMatch<["bias", "output"]>, - AllElementTypesMatch<["input", "input_zp"]>, - AllElementTypesMatch<["weight", "weight_zp"]>]> { +def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaConvolutionOp<"DepthwiseConv2D", 4> { let summary = "Depthwise 2D Convolution operator."; let description = [{ @@ -418,17 +385,6 @@ def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOpWithResult<"DepthwiseConv2D", 4, [ attr-dict `:` type(operands) `->` type(results) }]; - let extraClassDeclaration = extraBaseClassDeclaration#[{ - ::mlir::spirv::TensorArmType getInputType() { - return cast<::mlir::spirv::TensorArmType>(getInput().getType()); - } - ::mlir::spirv::TensorArmType getWeightType() { - return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); - } - ::mlir::spirv::TensorArmType getBiasType() { - return cast<::mlir::spirv::TensorArmType>(getBias().getType()); - } - }]; } @@ -635,22 +591,7 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> { } -def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [Pure, - TypeConstraintImplicationOn<"input", I8, "output", [I32]>, - TypeConstraintImplicationOn<"input", I16, "output", [I64]>, - TypeConstraintImplicationOn<"input", BF16, "output", [BF16]>, - TypeConstraintImplicationOn<"input", F16, "output", [F16]>, - TypeConstraintImplicationOn<"input", F32, "output", [F32]>, - TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>, - TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>, - TypeImpliesAccType<"input", I8, ["INT32"]>, - TypeImpliesAccType<"input", I16, ["INT48"]>, - TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>, - TypeImpliesAccType<"input", BF16, ["FP32"]>, - TypeImpliesAccType<"input", F32, ["FP32"]>, - AllElementTypesMatch<["bias", "output"]>, - AllElementTypesMatch<["input", "input_zp"]>, - AllElementTypesMatch<["weight", "weight_zp"]>]> { +def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaConvolutionOp<"TransposeConv2D", 9> { let summary = "Transpose 2D Convolution operator."; let description = [{ @@ -700,17 +641,6 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [ attr-dict `:` type(operands) `->` type(results) }]; - let extraClassDeclaration = extraBaseClassDeclaration#[{ - ::mlir::spirv::TensorArmType getInputType() { - return cast<::mlir::spirv::TensorArmType>(getInput().getType()); - } - ::mlir::spirv::TensorArmType getWeightType() { - return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); - } - ::mlir::spirv::TensorArmType getBiasType() { - return cast<::mlir::spirv::TensorArmType>(getBias().getType()); - } - }]; }