diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index cc23955f31f23..419340256fa59 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -241,6 +241,7 @@ class Tosa_I32EnumAttr; @@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>; def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>; def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>; def Tosa_EXT_INT64 : I32EnumAttrCase<"int64", 13>; +def Tosa_EXT_SHAPE : I32EnumAttrCase<"shape", 14>; def Tosa_ExtensionAttr @@ -281,7 +283,7 @@ def Tosa_ExtensionAttr Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, - Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64 + Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_SHAPE, ]> { let extraClassDeclaration = [{ static llvm::SmallVector getAllValues() { @@ -290,7 +292,7 @@ def Tosa_ExtensionAttr Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft, Extension::variable, Extension::controlflow, Extension::doubleround, Extension::inexactround, Extension::dynamic, Extension::mxfp, - Extension::int64 + Extension::int64, Extension::shape }; } }]; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index ea58f49b64c44..bee253689bab7 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -154,6 +154,7 @@ class TosaProfileCompliance { case Extension::controlflow: case Extension::dynamic: case Extension::int64: + case Extension::shape: return {Profile::pro_fp, Profile::pro_int}; case Extension::none: return {}; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td index 90cda42d95624..7b1c7e208ebe3 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td @@ -30,15 +30,8 @@ def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> { class Tosa_ShapeOp traits = []> : Tosa_Op { - list availability = [ - Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, - Extension<[]>, - ]; - let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; - - let hasFolder = 1; } // op trait: shape operator has same ranks for operands and results @@ -53,6 +46,29 @@ class Tosa_ElementwiseShapeOp traits = []> } +//===----------------------------------------------------------------------===// +// Operator: AddShape +//===----------------------------------------------------------------------===// +def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> { + let summary = "Elementwise addition of shapes."; + + let description = [{ + Elementwise addition of input1 and input2. Size of shapes must match. + }]; + + let arguments = (ins + Tosa_Shape:$input1, + Tosa_Shape:$input2 + ); + + let results = (outs Tosa_Shape:$output); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_SHAPE]>, + ]; +} + //===----------------------------------------------------------------------===// // Operator: ConstShape //===----------------------------------------------------------------------===// @@ -80,6 +96,99 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> { ]; let hasVerifier = 1; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// Operator: DivCeilShape +//===----------------------------------------------------------------------===// +def Tosa_DivCeilShapeOp : Tosa_ElementwiseShapeOp<"div_ceil_shape", [Pure]> { + let summary = "Elementwise ceiling divide of shapes."; + + let description = [{ + Elementwise divide of input1 by input2. The result of the divide is rounded up. + }]; + + let arguments = (ins + Tosa_Shape:$input1, + Tosa_Shape:$input2 + ); + + let results = (outs Tosa_Shape:$output); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_SHAPE]> + ]; +} + +//===----------------------------------------------------------------------===// +// Operator: DivFloorShape +//===----------------------------------------------------------------------===// +def Tosa_DivFloorShapeOp : Tosa_ElementwiseShapeOp<"div_floor_shape", [Pure]> { + let summary = "Elementwise floor divide of shapes."; + + let description = [{ + Elementwise integer divide of input1 by input2. The result of the divide is rounded down. + }]; + + let arguments = (ins + Tosa_Shape:$input1, + Tosa_Shape:$input2 + ); + + let results = (outs Tosa_Shape:$output); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_SHAPE]> + ]; +} + +//===----------------------------------------------------------------------===// +// Operator: MulShape +//===----------------------------------------------------------------------===// +def Tosa_MulShapeOp : Tosa_ElementwiseShapeOp<"mul_shape", [Pure]> { + let summary = "Elementwise multiplication of shapes."; + + let description = [{ + Elementwise multiplication of input1 and input2. + }]; + + let arguments = (ins + Tosa_Shape:$input1, + Tosa_Shape:$input2 + ); + + let results = (outs Tosa_Shape:$output); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_SHAPE]> + ]; +} + +//===----------------------------------------------------------------------===// +// Operator: SubShape +//===----------------------------------------------------------------------===// +def Tosa_SubShapeOp : Tosa_ElementwiseShapeOp<"sub_shape", [Pure]> { + let summary = "Elementwise subtraction of shapes."; + + let description = [{ + Elementwise subtraction of input1 and input2. Size of shapes must match. + }]; + + let arguments = (ins + Tosa_Shape:$input1, + Tosa_Shape:$input2 + ); + + let results = (outs Tosa_Shape:$output); + + list availability = [ + Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, + Extension<[Tosa_EXT_SHAPE]>, + ]; } #endif // TOSA_SHAPE_OPS diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index eb47e85cf9b0b..01f78f86d427b 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) { return TosaSpecificationVersion(1, 0); case Extension::mxfp: case Extension::int64: + case Extension::shape: return TosaSpecificationVersion(1, 1); case Extension::none: return TosaSpecificationVersion(0, 0); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index ddd9c70402fdc..c9150d5b34d00 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -317,7 +317,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // Type Invariant Extension, a capability extension that is independent // of the data type, meaning any compatible type can be used. No type // constraint for those operations. + POPULATE_PROFILE_INFO_SKIP(AddShape) POPULATE_PROFILE_INFO_SKIP(ConstShape) + POPULATE_PROFILE_INFO_SKIP(DivCeilShape) + POPULATE_PROFILE_INFO_SKIP(DivFloorShape) + POPULATE_PROFILE_INFO_SKIP(MulShape) + POPULATE_PROFILE_INFO_SKIP(SubShape) POPULATE_PROFILE_INFO_SKIP(Yield) POPULATE_PROFILE_INFO_SKIP(If) POPULATE_PROFILE_INFO_SKIP(While) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index b54ed5585d72d..421ef237e628f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -218,6 +218,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { if (type.getRank() > highest_rank) return op->emitOpError() << "failed level check: " << operandOrResult << " rank(shape) <= MAX_RANK"; + } else if (tosa::shapeType shapeType = + dyn_cast(typeToCheck)) { + if (shapeType.getRank() > highest_rank) + return op->emitOpError() + << "failed shape type level check: " << typeToCheck + << " exceeds MAX_RANK"; } return success(); } @@ -638,15 +644,21 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_RANKS_AND_SIZES(CastFromBlockScaled); CHECK_RANKS_AND_SIZES(CastToBlockScaled); CHECK_RANKS_AND_SIZES(Rescale); + // Data Nodes + CHECK_RANKS_AND_SIZES(Const); + CHECK_RANKS_AND_SIZES(Identity); // Control Flow Operators CHECK_RANKS_AND_SIZES(If); // Variable Operators CHECK_RANKS_AND_SIZES(Variable); CHECK_RANKS_AND_SIZES(VariableWrite); CHECK_RANKS_AND_SIZES(VariableRead); - // Data Nodes - CHECK_RANKS_AND_SIZES(Const); - CHECK_RANKS_AND_SIZES(Identity); + // Shape Operators + CHECK_RANKS_AND_SIZES(AddShape); + CHECK_RANKS_AND_SIZES(DivCeilShape); + CHECK_RANKS_AND_SIZES(DivFloorShape); + CHECK_RANKS_AND_SIZES(MulShape); + CHECK_RANKS_AND_SIZES(SubShape); // For the following operators, check whether the size of each tensor // operand is valid in a given Level. diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 68a95787b81c7..a06406fcdab1f 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -584,3 +584,13 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor< %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU> } + +// ----- + +func.func @test_mul_shape() -> !tosa.shape<4> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> + %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> + // expected-error@+1 {{'tosa.mul_shape' op illegal: requires [shape] but not enabled in target}} + %c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4> + return %c : !tosa.shape<4> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index a7087647e542b..213c4ae054c51 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1 func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> { %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7> - // expected-error@+1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}} + // expected-error@+1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}} %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32> return %0 : tensor<1x1x1x1x1x1x819xf32> } @@ -1662,3 +1662,23 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32 %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU> } + +// ----- + +func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13> + %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13> + // expected-error@+1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}} + %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13> + return %c : !tosa.shape<13> +} + +// ----- + +func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7> + %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7> + // expected-error@+1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}} + %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7> + return %c : !tosa.shape<7> +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index a4591f7ffd393..2c4ec857ad20e 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -1374,3 +1374,48 @@ func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> { %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8> return %0 : tensor<2x!tosa.mxint8> } + +// ----- +// CHECK-LABEL: test_add_shape +func.func @test_add_shape() -> !tosa.shape<4> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> + %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> + %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4> + return %c : !tosa.shape<4> +} + +// ----- +// CHECK-LABEL: test_sub_shape +func.func @test_sub_shape() -> !tosa.shape<3> { + %a = tosa.const_shape {values = dense<[10, 5, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> + %b = tosa.const_shape {values = dense<[2, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> + %c = tosa.sub_shape %a, %b : (!tosa.shape<3>, !tosa.shape<3>) -> !tosa.shape<3> + return %c : !tosa.shape<3> +} + +// ----- +// CHECK-LABEL: test_mul_shape +func.func @test_mul_shape() -> !tosa.shape<4> { + %a = tosa.const_shape {values = dense<[2, 3, 4, 5]> : tensor<4xindex>} : () -> !tosa.shape<4> + %b = tosa.const_shape {values = dense<[7, 0, 2, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> + %c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4> + return %c : !tosa.shape<4> +} + +// ----- +// CHECK-LABEL: test_div_ceil_shape +func.func @test_div_ceil_shape() -> !tosa.shape<4> { + %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %b = tosa.const_shape {values = dense<[2, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> + %c = tosa.div_ceil_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4> + return %c : !tosa.shape<4> +} + +// ----- +// CHECK-LABEL: test_div_floor_shape +func.func @test_div_floor_shape() -> !tosa.shape<4> { + %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %b = tosa.const_shape {values = dense<[2, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> + %c = tosa.div_floor_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4> + return %c : !tosa.shape<4> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index c285ae3cf44ee..66a94559348a8 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64" -tosa-validate="strict-op-spec-alignment" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64,shape" -tosa-validate="strict-op-spec-alignment" | FileCheck %s // ----- @@ -156,3 +156,12 @@ func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: te %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> return %0 : tensor<2x52x3xf32> } + +// ----- +// CHECK-LABEL: test_add_shape +func.func @test_add_shape() -> !tosa.shape<4> { + %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> + %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> + %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4> + return %c : !tosa.shape<4> +} diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 6cf76cdc7ad8e..a70709b4ecc6a 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -1222,3 +1222,19 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4 %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU> } + +// ----- + +func.func @test_elementwise_shape_op_same_inputs_rank(%arg0: !tosa.shape<4>, %arg1: !tosa.shape<3>) -> !tosa.shape<4> { + // expected-error@+1 {{'tosa.add_shape' op operands don't have matching ranks}} + %0 = tosa.add_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<3>) -> !tosa.shape<4> + return %0 : !tosa.shape<4> +} + +// ----- + +func.func @test_elementwise_shape_op_same_input_output_rank(%arg0: !tosa.shape<4>, %arg1: !tosa.shape<4>) -> !tosa.shape<3> { + // expected-error@+1 {{'tosa.div_floor_shape' op result shape has different rank than operands}} + %0 = tosa.div_floor_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<3> + return %0 : !tosa.shape<3> +}