diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index ac5d6207259eb..62c015a85ee36 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -216,22 +216,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) { bool mlir::tosa::hasUniqueConstantScatterIndices( ShapedType indicesType, DenseIntElementsAttr indicesAttr) { - llvm::ArrayRef const indicesShape = indicesType.getShape(); + const llvm::ArrayRef indicesShape = indicesType.getShape(); const unsigned int indicesRank = indicesShape.size(); const unsigned int lastDimSize = indicesShape[indicesRank - 1]; // check each batch of indices from the flat indicesAttr values // for duplicates - auto const indicesValues = indicesAttr.getValues(); + auto const indicesValues = indicesAttr.getValues(); assert( (indicesValues.size() % lastDimSize == 0) && "Constant indices data length should be a multiple of indicesShape[-1]"); - std::vector indices(lastDimSize); + std::vector indices(lastDimSize); for (auto beg = indicesValues.begin(); beg < indicesValues.end(); beg += lastDimSize) { std::copy(beg, beg + lastDimSize, indices.begin()); - std::sort(indices.begin(), indices.end()); + std::sort(indices.begin(), indices.end(), + [](const APInt &a, const APInt &b) { return a.slt(b); }); if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) { // found duplicate values in indices in batch return false; diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index c9e03ca53a729..3d24928487ed2 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_cast(%arg0: tensor) -> tensor<5xi32> { @@ -2044,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens // ----- +// CHECK-LABEL: test_scatter_duplicate_indices_int64 +func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64> + // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}} + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +} + +// ----- + func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> { // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}} %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32> diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index acbff73b8b948..c285ae3cf44ee 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -2,6 +2,7 @@ // ----- +// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> { %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> @@ -146,3 +147,12 @@ func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64 %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64> return %0 : tensor<12x16xi64> } + +// ----- + +// CHECK-LABEL: test_scatter_const_indices_int64 +func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64> + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +}