diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td index 8497f4f0c4b46..777db957b4d19 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td @@ -195,6 +195,15 @@ class ShapedTypeOf : class DimOf : DimOfType.result, dim>; +class LastDimOf : + StrFunc.result # ".back()">; + +class LastDimIsDynamic : + CPred<"::mlir::ShapedType::isDynamic(" # LastDimOf.result # ")">; + +class DimMatchesLastDim : + CPred.result # " == " # LastDimOf.result>; + class DimIsDynamic : CPred<"::mlir::ShapedType::isDynamic(" # DimOf.result # ")">; @@ -484,9 +493,15 @@ class ElementTypeMatchesScale32 : class TensorLengthMatchesPerChannel : PredOpTrait($" # tensor # ".getType()).getShape()[0] == " - "(getPerChannel() ? " - "::llvm::cast<::mlir::ShapedType>($input.getType()).getShape().back() : 1)">>; + Or<[ + Neg.result>>, + Neg.result>>, + DimIsDynamic, + And<[CPred<"getPerChannel()">, + Or<[LastDimIsDynamic<"input">, + DimMatchesLastDim]>]>, + And<[Neg>, DimIsOne]> + ]>>; #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir new file mode 100644 index 0000000000000..a2973c73ab362 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-dynamic.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Rescale +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @rescale_per_channel_dynamic_input_last_dimension(%arg0: !spirv.arm.tensor) -> (!spirv.arm.tensor) { + %1 = spirv.Constant dense<[1]> : !spirv.arm.tensor<1xi16> + %2 = spirv.Constant dense<[0, 0]> : !spirv.arm.tensor<2xi8> + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // CHECK: {{%.*}} = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.arm.tensor, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<2xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor +} + +spirv.ARM.Graph @rescale_per_channel_dynamic_multiplier_and_shift_length(%arg0: !spirv.arm.tensor<2x3x4xi16>, %multiplier: !spirv.arm.tensor, %shift: !spirv.arm.tensor) -> (!spirv.arm.tensor<2x3x4xi16>) { + %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + // CHECK: {{%.*}} = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor, !spirv.arm.tensor, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = , per_channel = true, input_unsigned = false, output_unsigned = false, %arg0, %multiplier, %shift, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor, !spirv.arm.tensor, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<2x3x4xi16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x3x4xi16> + spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16> +}