diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index e088eb31338dc..b80232f112b64 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -28,19 +28,22 @@ struct TosaLevel { int32_t MAX_LOG2_SIZE = 0; int32_t MAX_NESTING = 0; int32_t MAX_TENSOR_LIST_SIZE = 0; + int32_t MAX_SHAPE_LEN = 0; bool operator==(const TosaLevel &rhs) { return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && MAX_NESTING == rhs.MAX_NESTING && - MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; + MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE && + MAX_SHAPE_LEN == rhs.MAX_SHAPE_LEN; } }; -static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; +static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, + 31, 6, 64, 16}; static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, - 63, 256, 256}; + 63, 256, 256, 64}; TargetEnvAttr lookupTargetEnv(Operation *op); TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index a900aef04f753..387d38411f0fe 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -228,12 +228,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { if (type.getRank() > highest_rank) return op->emitOpError() << "failed level check: " << operandOrResult << " rank(shape) <= MAX_RANK"; - } else if (tosa::shapeType shapeType = - dyn_cast(typeToCheck)) { - if (shapeType.getRank() > highest_rank) - return op->emitOpError() - << "failed shape type level check: " << typeToCheck - << " exceeds MAX_RANK"; } return success(); } @@ -255,6 +249,18 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { return levelCheckSize(op, v.getType(), operandOrResult); } + // Perform the Level shape length check on a value. + LogicalResult levelCheckShapeLength(Operation *op, const Type typeToCheck, + const StringRef operandOrResult) { + if (tosa::shapeType shapeType = dyn_cast(typeToCheck)) { + if (shapeType.getRank() > targetEnv.getLevel().MAX_SHAPE_LEN) + return op->emitOpError() + << "failed shape type level check: " << typeToCheck + << " exceeds MAX_SHAPE_LEN"; + } + return success(); + } + // Level check sizes of all operands and results of the operation. template LogicalResult levelCheckSizes(T tosaOp) { @@ -288,6 +294,20 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { return success(); } + // Level check shape lengths of all operands and results of an operation that + // are tosa.shape type. + template + LogicalResult levelCheckShapeLengths(T tosaOp) { + for (const auto &v : tosaOp->getOperands()) { + if (failed(levelCheckShapeLength(tosaOp, v.getType(), "operand"))) + return failure(); + } + if (failed(levelCheckShapeLength(tosaOp, tosaOp.getResult().getType(), + "result"))) + return failure(); + return success(); + } + // Level check ranks and sizes. LogicalResult levelCheckRanksAndSizes(Operation *op); @@ -591,9 +611,9 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { return failure(); \ } -#define CHECK_RANKS(tosaOp) \ +#define CHECK_SHAPE_LEN(tosaOp) \ if (isa(op)) { \ - if (failed(levelCheckRanks(cast(op)))) \ + if (failed(levelCheckShapeLengths(cast(op)))) \ return failure(); \ } @@ -700,27 +720,27 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { // Shape Operators CHECK_SIZES(ConstShape); - // For the following operations, check whether the rank of each operand - // is valid given a level. + // For the following operations, check whether the shape length of each + // operand is valid given a level. // Shape Operators - CHECK_RANKS(AddShape); - CHECK_RANKS(ConcatShape); - CHECK_RANKS(DivCeilShape); - CHECK_RANKS(DivFloorShape); - CHECK_RANKS(Exp2Shape); - CHECK_RANKS(Log2CeilShape); - CHECK_RANKS(Log2FloorShape); - CHECK_RANKS(MaxShape); - CHECK_RANKS(MinShape); - CHECK_RANKS(ModShape); - CHECK_RANKS(MulShape); - CHECK_RANKS(SliceShape); - CHECK_RANKS(SubShape); + CHECK_SHAPE_LEN(AddShape); + CHECK_SHAPE_LEN(ConcatShape); + CHECK_SHAPE_LEN(DivCeilShape); + CHECK_SHAPE_LEN(DivFloorShape); + CHECK_SHAPE_LEN(Exp2Shape); + CHECK_SHAPE_LEN(Log2CeilShape); + CHECK_SHAPE_LEN(Log2FloorShape); + CHECK_SHAPE_LEN(MaxShape); + CHECK_SHAPE_LEN(MinShape); + CHECK_SHAPE_LEN(ModShape); + CHECK_SHAPE_LEN(MulShape); + CHECK_SHAPE_LEN(SliceShape); + CHECK_SHAPE_LEN(SubShape); #undef CHECK_RANKS_AND_SIZES #undef CHECK_SIZES -#undef CHECK_RANKS +#undef CHECK_SHAPE_LEN return success(); } diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index d874ab6d23a50..dd5ece417cf9e 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1 func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> { %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7> - // expected-error@+1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}} + // expected-error@+1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}} %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32> return %0 : tensor<1x1x1x1x1x1x819xf32> } @@ -1665,22 +1665,22 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32 // ----- -func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> { - %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13> - %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13> - // expected-error@+1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}} - %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13> - return %c : !tosa.shape<13> +func.func @test_add_shape_invalid_rank() -> !tosa.shape<17> { + %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + // expected-error@+1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}} + %c = tosa.add_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17> + return %c : !tosa.shape<17> } // ----- -func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> { - %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7> - %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7> - // expected-error@+1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}} - %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7> - return %c : !tosa.shape<7> +func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<17> { + %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + // expected-error@+1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}} + %c = tosa.div_floor_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17> + return %c : !tosa.shape<17> } // ----- @@ -1721,37 +1721,37 @@ func.func @test_concat_shape_invalid_list_size() { // ----- -func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<7> { - %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7> - // expected-error@+1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}} - %1 = tosa.exp2_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7> - return %1 : !tosa.shape<7> +func.func @test_exp2_shape_invalid_rank() -> !tosa.shape<17> { + %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + // expected-error@+1 {{'tosa.exp2_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}} + %1 = tosa.exp2_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17> + return %1 : !tosa.shape<17> } // ----- -func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<7> { - %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7> - // expected-error@+1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}} - %1 = tosa.log2_floor_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7> - return %1 : !tosa.shape<7> +func.func @test_log2_floor_shape_invalid_rank() -> !tosa.shape<17> { + %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + // expected-error@+1 {{'tosa.log2_floor_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}} + %1 = tosa.log2_floor_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17> + return %1 : !tosa.shape<17> } // ----- -func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<7> { - %0 = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7> - // expected-error@+1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}} - %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<7>) -> !tosa.shape<7> - return %1 : !tosa.shape<7> +func.func @test_log2_ceil_shape_invalid_rank() -> !tosa.shape<17> { + %0 = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + // expected-error@+1 {{'tosa.log2_ceil_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}} + %1 = tosa.log2_ceil_shape %0 : (!tosa.shape<17>) -> !tosa.shape<17> + return %1 : !tosa.shape<17> } // ----- -func.func @test_mod_shape_invalid_rank() -> !tosa.shape<9> { - %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9> - %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<9xindex>} : () -> !tosa.shape<9> - // expected-error@+1 {{'tosa.mod_shape' op failed shape type level check: '!tosa.shape<9>' exceeds MAX_RANK}} - %c = tosa.mod_shape %a, %b : (!tosa.shape<9>, !tosa.shape<9>) -> !tosa.shape<9> - return %c : !tosa.shape<9> +func.func @test_mod_shape_invalid_rank() -> !tosa.shape<17> { + %a = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + %b = tosa.const_shape {values = dense<0> : tensor<17xindex>} : () -> !tosa.shape<17> + // expected-error@+1 {{'tosa.mod_shape' op failed shape type level check: '!tosa.shape<17>' exceeds MAX_SHAPE_LEN}} + %c = tosa.mod_shape %a, %b : (!tosa.shape<17>, !tosa.shape<17>) -> !tosa.shape<17> + return %c : !tosa.shape<17> } diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir index 036a3d4d0ba2e..d2d010d3a0845 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir @@ -37,3 +37,13 @@ func.func @test_validate_without_tosa(%arg0: f32) -> f32 { %0 = math.asin %arg0 : f32 return %0 : f32 } + +// ----- + +// CHECK-LABEL: test_pad_large_input_rank +func.func @test_pad_large_input_rank(%arg0: tensor<13x21x3x1x1x1xf32>) -> tensor<13x21x3x1x1x1xf32> { + %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32> + %padding = tosa.const_shape {values = dense<0> : tensor<12xindex>} : () -> !tosa.shape<12> + %1 = tosa.pad %arg0, %padding, %0 : (tensor<13x21x3x1x1x1xf32>, !tosa.shape<12>, tensor<1xf32>) -> tensor<13x21x3x1x1x1xf32> + return %1 : tensor<13x21x3x1x1x1xf32> +}