diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 309573a562872..e947720471f78 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1922,7 +1922,9 @@ def TileUsingForallOp :
 
     It is the user's responsibility to ensure that `num_threads/tile_sizes` is
     a valid tiling specification (i.e. that only tiles parallel dimensions,
-    e.g. in the Linalg case).
+    e.g. in the Linalg case). If the dimension is not parallelizable, a warning
+    is issued to notify the user that the generated code is not safe to
+    parallelize.
 
     If non-empty, the `mapping` is added as an attribute to the
     resulting `scf.forall`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 30aed850bed81..462f692615faa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,6 +304,28 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
+/// Returns a vector of bools representing if, for each axis, `op` can be tiled
+/// without incurring a race condition and thus it is thread-safe to do the
+/// tiling. This is checked by iterating over numThreads and ensuring that the
+/// corresponding iterator type is "parallel". If it is not, then we know that
+/// such a dimension is unsafe to tile.
+SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
+                                     ArrayRef<OpFoldResult> numThreads) {
+  auto iterators = linalgOp.getIteratorTypesArray();
+  SmallVector<bool> safeToTile(numThreads.size(), true);
+
+  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
+      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
+        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
+      }
+    } else {
+      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
+    }
+  }
+  return safeToTile;
+}
+
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
 /// tiling is specified by the number of tiles/threads `numThreads` and the
 /// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
@@ -314,8 +336,10 @@ static void calculateTileOffsetsAndSizes(
 /// size of data.
 /// It is the user's responsibility to ensure that `numThreads` is a valid
 /// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
-/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
+/// Linalg case). If the dimension is not parallelizable, a warning is issued to
+/// notify the user that the generated code is not safe to parallelize. If
+/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
+/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
 static FailureOr<ForallTilingResult> tileToForallOpImpl(
     RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
     std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
@@ -344,6 +368,16 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
         return getValueOrCreateConstantIndexOp(b, loc, ofr);
       }));
 
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+  if (linalgOp) {
+    // Check if tiling is thread safe and print a warning if not.
+    SmallVector<bool> tilingSafety =
+        safeToTileToForall(b.getContext(), linalgOp, numThreads);
+    for (size_t i = 0; i < tilingSafety.size(); i++)
+      if (!tilingSafety[i])
+        op.emitWarning() << "tiling is not thread safe at axis #" << i;
+  }
+
   // 1. Create the ForallOp. We don't use the lambda body-builder
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
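To see concretely what the new warning guards against, consider distributing a reduction axis across `scf.forall` threads. The sketch below is hand-written for illustration (the function name is made up, and the real transform additionally tiles the parallel axis and clamps tile bounds with affine min/max); it mirrors the shape of IR that tiling the 300-wide reduction axis of the first test case over 2 threads would take. Each thread extracts the same full slice of the shared accumulator and re-inserts its partial result into that same slice, so the threads' read-modify-write sequences race:

// Hand-written approximation of tiling a "reduction" axis (300 = 2 x 150)
// over two threads. Both threads accumulate into the same slice of %out.
func.func @racy_reduction_tiling(%in: tensor<100x300xf32>,
                                 %out: tensor<100xf32>) -> tensor<100xf32> {
  %res = scf.forall (%tid) in (2) shared_outs(%acc = %out) -> (tensor<100xf32>) {
    %off = affine.apply affine_map<(d0) -> (d0 * 150)>(%tid)
    %in_tile = tensor.extract_slice %in[0, %off] [100, 150] [1, 1]
        : tensor<100x300xf32> to tensor<100x150xf32>
    // Both threads read the same, full 100-element accumulator slice...
    %acc_tile = tensor.extract_slice %acc[0] [100] [1]
        : tensor<100xf32> to tensor<100xf32>
    %partial = linalg.generic
        {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0)>],
         iterator_types = ["parallel", "reduction"]}
        ins(%in_tile : tensor<100x150xf32>) outs(%acc_tile : tensor<100xf32>) {
    ^bb0(%a: f32, %b: f32):
      %s = arith.addf %a, %b : f32
      linalg.yield %s : f32
    } -> tensor<100xf32>
    scf.forall.in_parallel {
      // ...and write their partial sums back to the same slice: a data race.
      tensor.parallel_insert_slice %partial into %acc[0] [100] [1]
          : tensor<100xf32> into tensor<100xf32>
    }
  }
  return %res : tensor<100xf32>
}

When the tiled axis is "parallel", each thread instead owns a disjoint output slice and the identical pattern is race-free; that per-axis distinction is exactly what safeToTileToForall computes from the iterator types.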
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index abd807b3e4d3e..12e2dea5530b5 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -586,3 +586,144 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [4, 2]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+
+func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #0}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<300x8xf32>
+  return %0 : tensor<300x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100x8xf32>
+  return %0 : tensor<100x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2)>
+
+func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) {
+  // expected-warning@+2 {{tiling is not thread safe at axis #0}}
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) {
+  ^bb0(%in: f32, %out1: f32, %out2: f32):
+    %1 = arith.addf %in, %out1 : f32
+    %2 = arith.addf %in, %out2 : f32
+    linalg.yield %1, %2 : f32, f32
+  } -> (tensor<100x8xf32>, tensor<8xf32>)
+  return %0#0, %0#1 : tensor<100x8xf32>, tensor<8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @tile_thread_safety5(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #1}}
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 tile_sizes [10, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @tile_thread_safety6(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-warning@below {{tiling is not thread safe at axis #2}}
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 0, 8]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
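A note on the last case: `linalg.matmul` carries the canonical iterator types ["parallel", "parallel", "reduction"], so the reduction sits at axis #2, and `num_threads [2, 0, 8]` (a 0 leaves that axis untiled) distributes it across 8 threads, which is what triggers the warning. For reference, the standard `linalg.generic` expansion of matmul makes the iterator assignment explicit (the function name below is chosen for illustration only):

#mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>

// linalg.matmul as a linalg.generic: d0 and d1 index the parallel output
// dimensions, while d2 is the reduction, so distributing axis #2 across
// threads would make them accumulate into the same elements of %C.
func.func @matmul_as_generic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
                             %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#mapA, #mapB, #mapC],
                       iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%C : tensor<?x?xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %p = arith.mulf %a, %b : f32
    %s = arith.addf %c, %p : f32
    linalg.yield %s : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}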