diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h index 096510a09e324..6f3b0916a7a60 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op, // returns a small vector of int64_t values that attr contains SmallVector convertFromIntAttr(const DenseElementsAttr &attr, const int rank); + +// returns true iff constant indices for scatter op contains unique indices +// per batch +bool hasUniqueConstantScatterIndices(ShapedType indicesType, + DenseIntElementsAttr indicesAttr); } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index d33fc902de3a1..229f42d3178b5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) { return true; } +bool checkErrorIfScatter(Operation *op) { + auto scatterOp = dyn_cast(op); + if (!scatterOp) + return true; + + // for constant indices, check that there are no duplicate values + DenseIntElementsAttr indicesAttr; + if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr))) + return true; + + auto const indicesType = + dyn_cast(scatterOp.getIndices().getType()); + if (!indicesType || !indicesType.hasRank()) { + op->emitOpError("expect ranked indices tensor"); + return false; + } + + if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) { + op->emitOpError("indices values contain duplicates"); + return false; + } + + return true; +} + LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op) || !checkErrorIfRescale(op) || - !checkErrorIfPad(op) || !checkErrorIfCondIf(op)) + !checkErrorIfPad(op) || !checkErrorIfCondIf(op) || + !checkErrorIfScatter(op)) return failure(); return success(); } diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index e1b3be74b50fd..9844abcc34cb1 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) { } return {}; } + +bool mlir::tosa::hasUniqueConstantScatterIndices( + ShapedType indicesType, DenseIntElementsAttr indicesAttr) { + llvm::ArrayRef const indicesShape = indicesType.getShape(); + const unsigned int indicesRank = indicesShape.size(); + const unsigned int lastDimSize = indicesShape[indicesRank - 1]; + + // check each batch of indices from the flat indicesAttr values + // for duplicates + auto const indicesValues = indicesAttr.getValues(); + assert( + (indicesValues.size() % lastDimSize == 0) && + "Constant indices data length should be a multiple of indicesShape[-1]"); + + std::vector indices(lastDimSize); + for (auto beg = indicesValues.begin(); beg < indicesValues.end(); + beg += lastDimSize) { + std::copy(beg, beg + lastDimSize, indices.begin()); + std::sort(indices.begin(), indices.end()); + if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) { + // found duplicate values in indices in batch + return false; + } + } + + return true; +} diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index a4617fc6fba8b..805522799a6d8 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8> return %r : tensor<1x1xui8> } + +// ----- + +// CHECK-LABEL: test_scatter_duplicate_indices +func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32> + // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}} + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +}