From 5d459622a7a91cf212b0ff1c512b8043364d94bc Mon Sep 17 00:00:00 2001 From: Pablo Antonio Martinez Date: Mon, 5 Feb 2024 16:57:59 +0000 Subject: [PATCH 1/5] [mlir][Linalg][Transform] Emit a warning when tile_using_forall generates non thread-safe code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This warning aims to complement the comment in the documentation that says: "It is the user’s responsibility to ensure that num_threads/tile_sizes is a valid tiling specification (i.e. that only tiles parallel dimensions, e.g. in the Linalg case)." because: 1. Not all users of tile_using_forall know that tiling the wrong dimension/s (e.g., a non-parallel dimension) will generate non thread-safe code, so this warning will inform the user about it. 2. Users of tile_using_forall may know this limitation, but they may not realize that they are tiling a non-parallel dimension, so the warning may help in the debugging process. --- .../Linalg/TransformOps/LinalgTransformOps.td | 4 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 57 +++++++++- mlir/test/Dialect/Linalg/tile-to-forall.mlir | 100 ++++++++++++++++++ 3 files changed, 158 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 309573a562872..e947720471f78 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1922,7 +1922,9 @@ def TileUsingForallOp : It is the user's responsibility to ensure that `num_threads/tile_sizes` is a valid tiling specification (i.e. that only tiles parallel dimensions, - e.g. in the Linalg case). + e.g. in the Linalg case). If the dimension is not parallelizable, a warning + is issued to notify the user that the generated code is not safe to + parallelize. 
If non-empty, the `mapping` is added as an attribute to the resulting `scf.forall`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 30aed850bed81..ed97ad70e6e39 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -304,6 +304,50 @@ static void calculateTileOffsetsAndSizes( } } +/// Returns a vector of bools representing if, for the given axis, `op` can be +/// tiled by `numThreads` without incurring in a race condition and thus it is +/// thread-safe to do the tiling. This is checked by iterating over the affine +/// maps of the outputs in `op` and ensuring that all the results in the map are +/// present in the affine map represented by the tiling sizes, which is derived +/// from `numThreads` or `nominalTileSizes`. +SmallVector +safeToTileToForall(mlir::MLIRContext *ctx, TilingInterface op, + ArrayRef numThreads, + std::optional> nominalTileSizes, + int numDims) { + ArrayRef tilingValues = + nominalTileSizes.has_value() ? *nominalTileSizes : numThreads; + + SmallVector safeToTile(tilingValues.size(), true); + LinalgOp linalgOp = dyn_cast(op.getOperation()); + if (!linalgOp) + return safeToTile; + + SmallVector dimExprs; + dimExprs.reserve(numDims); + for (unsigned i = 0; i < tilingValues.size(); i++) { + if (auto attr = llvm::dyn_cast_if_present(tilingValues[i])) { + if (cast(attr).getValue().getSExtValue() > 1) + dimExprs.push_back(mlir::getAffineDimExpr(i, ctx)); + } else { + dimExprs.push_back(mlir::getAffineDimExpr(i, ctx)); + } + } + + for (uint32_t resNum = 0; resNum < op->getNumResults(); resNum++) { + AffineMap map = + linalgOp.getIndexingMapMatchingResult(op->getResult(resNum)); + + for (AffineExpr r : dimExprs) { + unsigned int axis = cast(r).getPosition(); + if (!llvm::is_contained(map.getResults(), r)) + safeToTile[axis] = false; + } + } + + return safeToTile; +} + /// Rewrite a TilingInterface `op` to a tiled `scf.forall`. 
The /// tiling is specified by the number of tiles/threads `numThreads` and the /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is @@ -314,8 +358,10 @@ static void calculateTileOffsetsAndSizes( /// size of data. /// It is the user's responsibility to ensure that `numThreads` is a valid /// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the -/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will -/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds. +/// Linalg case). If the dimension is not parallelizable, a warning is issued to +/// notify the user that the generated code is not safe to parallelize. If +/// `omitTileOffsetBoundsCheck` is true, then the function will assume that +/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds. static FailureOr tileToForallOpImpl( RewriterBase &b, TilingInterface op, ArrayRef numThreads, std::optional> nominalTileSizes, @@ -344,6 +390,13 @@ static FailureOr tileToForallOpImpl( return getValueOrCreateConstantIndexOp(b, loc, ofr); })); + // Check if tiling is thread safe and print a warning if not. + SmallVector tilingSafety = safeToTileToForall( + b.getContext(), op, numThreads, nominalTileSizes, loopRanges.size()); + for (size_t i = 0; i < tilingSafety.size(); i++) + if (!tilingSafety[i]) + op.emitWarning() << "tiling is not thread safe at axis #" << i; + // 1. Create the ForallOp. We don't use the lambda body-builder // version because we require the use of RewriterBase in the body, so we // manually move the insertion point to the body below. 
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir index abd807b3e4d3e..e52f76c619575 100644 --- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir @@ -586,3 +586,103 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { + // expected-warning@+1 {{tiling is not thread safe at axis #1}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100xf32> + return %0 : tensor<100xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> + +func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> { + // expected-warning@+1 {{tiling is not thread safe at axis #0}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<300x8xf32> + return %0 : tensor<300x8xf32> +} + +module attributes 
{transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> + +func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> { + // expected-warning@+1 {{tiling is not thread safe at axis #1}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100x8xf32> + return %0 : tensor<100x8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d2)> + +func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) { + // expected-warning@+2 {{tiling is not thread safe at axis #0}} + // expected-warning@+1 {{tiling is not thread safe at axis #1}} + %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = 
["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) { + ^bb0(%in: f32, %out1: f32, %out2: f32): + %1 = arith.addf %in, %out1 : f32 + %2 = arith.addf %in, %out2 : f32 + linalg.yield %1, %2 : f32, f32 + } -> (tensor<100x8xf32>, tensor<8xf32>) + return %0#0, %0#1 : tensor<100x8xf32>, tensor<8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + From cb774445e9b4069af84d072e06866497136c90e5 Mon Sep 17 00:00:00 2001 From: Pablo Antonio Martinez Date: Thu, 14 Mar 2024 16:57:02 +0000 Subject: [PATCH 2/5] [mlir][Linalg][Transform] Small nits, bugfix and tests. Fix bug when tile_size=1 was specified. Also, add test case using tile_size and another use case to tile linalg.matmul to show that this also works for non linalg.generic ops --- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 40 +++++++-------- mlir/test/Dialect/Linalg/tile-to-forall.mlir | 49 +++++++++++++++++-- 2 files changed, 65 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index ed97ad70e6e39..fe9a3c658e6d5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -305,38 +305,35 @@ static void calculateTileOffsetsAndSizes( } /// Returns a vector of bools representing if, for the given axis, `op` can be -/// tiled by `numThreads` without incurring in a race condition and thus it is -/// thread-safe to do the tiling. 
This is checked by iterating over the affine -/// maps of the outputs in `op` and ensuring that all the results in the map are -/// present in the affine map represented by the tiling sizes, which is derived -/// from `numThreads` or `nominalTileSizes`. +/// tiled by without incurring in a race condition and thus it is thread-safe to +/// do the tiling. This is checked by iterating over the affine maps of the +/// outputs in `op` and ensuring that all the results in the map are present in +/// the affine map represented by the tiling sizes, which is derived from +/// `numThreads` or `nominalTileSizes`. SmallVector -safeToTileToForall(mlir::MLIRContext *ctx, TilingInterface op, +safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp, ArrayRef numThreads, std::optional> nominalTileSizes, int numDims) { ArrayRef tilingValues = nominalTileSizes.has_value() ? *nominalTileSizes : numThreads; + int minTile = nominalTileSizes.has_value() ? 0 : 1; SmallVector safeToTile(tilingValues.size(), true); - LinalgOp linalgOp = dyn_cast(op.getOperation()); - if (!linalgOp) - return safeToTile; - SmallVector dimExprs; dimExprs.reserve(numDims); - for (unsigned i = 0; i < tilingValues.size(); i++) { + for (unsigned i = 0, e = tilingValues.size(); i != e; i++) { if (auto attr = llvm::dyn_cast_if_present(tilingValues[i])) { - if (cast(attr).getValue().getSExtValue() > 1) + if (cast(attr).getValue().getSExtValue() > minTile) dimExprs.push_back(mlir::getAffineDimExpr(i, ctx)); } else { dimExprs.push_back(mlir::getAffineDimExpr(i, ctx)); } } - for (uint32_t resNum = 0; resNum < op->getNumResults(); resNum++) { + for (unsigned resNum = 0; resNum < linalgOp->getNumResults(); resNum++) { AffineMap map = - linalgOp.getIndexingMapMatchingResult(op->getResult(resNum)); + linalgOp.getIndexingMapMatchingResult(linalgOp->getResult(resNum)); for (AffineExpr r : dimExprs) { unsigned int axis = cast(r).getPosition(); @@ -390,12 +387,15 @@ static FailureOr tileToForallOpImpl( return 
getValueOrCreateConstantIndexOp(b, loc, ofr); })); - // Check if tiling is thread safe and print a warning if not. - SmallVector tilingSafety = safeToTileToForall( - b.getContext(), op, numThreads, nominalTileSizes, loopRanges.size()); - for (size_t i = 0; i < tilingSafety.size(); i++) - if (!tilingSafety[i]) - op.emitWarning() << "tiling is not thread safe at axis #" << i; + LinalgOp linalgOp = dyn_cast(op.getOperation()); + if (linalgOp) { + // Check if tiling is thread safe and print a warning if not. + SmallVector tilingSafety = safeToTileToForall( + b.getContext(), linalgOp, numThreads, nominalTileSizes, loopRanges.size()); + for (size_t i = 0; i < tilingSafety.size(); i++) + if (!tilingSafety[i]) + op.emitWarning() << "tiling is not thread safe at axis #" << i; + } // 1. Create the ForallOp. We don't use the lambda body-builder // version because we require the use of RewriterBase in the body, so we diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir index e52f76c619575..74eb0b12aa8d1 100644 --- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir @@ -593,7 +593,7 @@ module attributes {transform.with_named_sequence} { #map1 = affine_map<(d0, d1) -> (d0)> func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { - // expected-warning@+1 {{tiling is not thread safe at axis #1}} + // expected-warning@below {{tiling is not thread safe at axis #1}} %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) { ^bb0(%in: f32, %out: f32): %1 = arith.addf %in, %out : f32 @@ -617,7 +617,7 @@ module attributes {transform.with_named_sequence} { #map1 = affine_map<(d0, d1, d2) -> (d1, d2)> func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> { - // expected-warning@+1 {{tiling is 
not thread safe at axis #0}} + // expected-warning@below {{tiling is not thread safe at axis #0}} %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) { ^bb0(%in: f32, %out: f32): %1 = arith.addf %in, %out : f32 @@ -641,7 +641,7 @@ module attributes {transform.with_named_sequence} { #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> { - // expected-warning@+1 {{tiling is not thread safe at axis #1}} + // expected-warning@below {{tiling is not thread safe at axis #1}} %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) { ^bb0(%in: f32, %out: f32): %1 = arith.addf %in, %out : f32 @@ -667,7 +667,7 @@ module attributes {transform.with_named_sequence} { func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) { // expected-warning@+2 {{tiling is not thread safe at axis #0}} - // expected-warning@+1 {{tiling is not thread safe at axis #1}} + // expected-warning@below {{tiling is not thread safe at axis #1}} %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) { ^bb0(%in: f32, %out1: f32, %out2: f32): %1 = arith.addf %in, %out1 : f32 @@ -686,3 +686,44 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @tile_thread_safety5(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0 = 
linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100xf32> + return %0 : tensor<100xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 tile_sizes [10, 1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @tile_thread_safety6(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // expected-warning@below {{tiling is not thread safe at axis #2}} + %0 = linalg.matmul ins(%A, %B : tensor, tensor) + outs(%C : tensor) -> (tensor) + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 4, 8] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} \ No newline at end of file From 3cfccde6c1170371ed1ca48949b69ccbd39fc255 Mon Sep 17 00:00:00 2001 From: Pablo Antonio Martinez Date: Thu, 14 Mar 2024 16:59:13 +0000 Subject: [PATCH 3/5] [mlir][Linalg][Transform] Simplify implementation Rather than comparing the outputs' affine maps against the maps inferred from the tile sizes, we can simply check that tiled dimensions do not contain the "reduction" iterator type. If they do, then we are certain they are not safe to tile. 
--- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index fe9a3c658e6d5..05d2d7cd94553 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -306,10 +306,11 @@ static void calculateTileOffsetsAndSizes( /// Returns a vector of bools representing if, for the given axis, `op` can be /// tiled by without incurring in a race condition and thus it is thread-safe to -/// do the tiling. This is checked by iterating over the affine maps of the -/// outputs in `op` and ensuring that all the results in the map are present in -/// the affine map represented by the tiling sizes, which is derived from -/// `numThreads` or `nominalTileSizes`. +/// do the tiling. This is checked by iterating over the affine map represented +/// by the tiling sizes (which is derived from `numThreads` or +/// `nominalTileSizes`) and ensuring that the corresponding iterator type is +/// not "reduction". If it is, then we know that such dimension is unsafe to +/// tile. 
SmallVector safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp, ArrayRef numThreads, @@ -331,15 +332,11 @@ safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp, } } - for (unsigned resNum = 0; resNum < linalgOp->getNumResults(); resNum++) { - AffineMap map = - linalgOp.getIndexingMapMatchingResult(linalgOp->getResult(resNum)); - - for (AffineExpr r : dimExprs) { - unsigned int axis = cast(r).getPosition(); - if (!llvm::is_contained(map.getResults(), r)) - safeToTile[axis] = false; - } + auto iterators = linalgOp.getIteratorTypesArray(); + for (AffineExpr r : dimExprs) { + unsigned int axis = cast(r).getPosition(); + if (iterators[axis] == utils::IteratorType::reduction) + safeToTile[axis] = false; } return safeToTile; @@ -390,8 +387,9 @@ static FailureOr tileToForallOpImpl( LinalgOp linalgOp = dyn_cast(op.getOperation()); if (linalgOp) { // Check if tiling is thread safe and print a warning if not. - SmallVector tilingSafety = safeToTileToForall( - b.getContext(), linalgOp, numThreads, nominalTileSizes, loopRanges.size()); + SmallVector tilingSafety = + safeToTileToForall(b.getContext(), linalgOp, numThreads, + nominalTileSizes, loopRanges.size()); for (size_t i = 0; i < tilingSafety.size(); i++) if (!tilingSafety[i]) op.emitWarning() << "tiling is not thread safe at axis #" << i; From 903e883c316e44cd3361408b2580e667f8bba9cf Mon Sep 17 00:00:00 2001 From: Pablo Antonio Martinez Date: Thu, 21 Mar 2024 16:26:19 +0000 Subject: [PATCH 4/5] [mlir][Linalg][Transform] Simplify even more the implementation Simplify the implementation by removing the end loop and by only checking numThreads (before it was also checking nominalTileSizes). This is possible because in the event that num_threads is not specified, it is automatically derived from tile_sizes, so we are always sure that numThreads will contain the values we need. Also changed the tests a bit to show that the implementation properly handles the case where there is a zero. 
--- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 53 +++++++------------ mlir/test/Dialect/Linalg/tile-to-forall.mlir | 4 +- 2 files changed, 20 insertions(+), 37 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 05d2d7cd94553..19a74b15c947f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -304,41 +304,25 @@ static void calculateTileOffsetsAndSizes( } } -/// Returns a vector of bools representing if, for the given axis, `op` can be -/// tiled by without incurring in a race condition and thus it is thread-safe to -/// do the tiling. This is checked by iterating over the affine map represented -/// by the tiling sizes (which is derived from `numThreads` or -/// `nominalTileSizes`) and ensuring that the corresponding iterator type is -/// not "reduction". If it is, then we know that such dimension is unsafe to -/// tile. -SmallVector -safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp, - ArrayRef numThreads, - std::optional> nominalTileSizes, - int numDims) { - ArrayRef tilingValues = - nominalTileSizes.has_value() ? *nominalTileSizes : numThreads; - int minTile = nominalTileSizes.has_value() ? 0 : 1; - - SmallVector safeToTile(tilingValues.size(), true); - SmallVector dimExprs; - dimExprs.reserve(numDims); - for (unsigned i = 0, e = tilingValues.size(); i != e; i++) { - if (auto attr = llvm::dyn_cast_if_present(tilingValues[i])) { - if (cast(attr).getValue().getSExtValue() > minTile) - dimExprs.push_back(mlir::getAffineDimExpr(i, ctx)); +/// Returns a vector of bools representing if, for each axis, `op` can be tiled +/// without incurring in a race condition and thus it is thread-safe to do the +/// tiling. This is checked by iterating over numThreads and ensuring that the +/// corresponding iterator type is "parallel". If it is not, then we know that +/// such dimension is unsafe to tile. 
+SmallVector safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp, + ArrayRef numThreads) { + auto iterators = linalgOp.getIteratorTypesArray(); + SmallVector safeToTile(numThreads.size(), true); + + for (unsigned i = 0, e = numThreads.size(); i != e; i++) { + if (auto attr = llvm::dyn_cast_if_present(numThreads[i])) { + if (cast(attr).getValue().getSExtValue() > 1) { + safeToTile[i] = iterators[i] == utils::IteratorType::parallel; + } } else { - dimExprs.push_back(mlir::getAffineDimExpr(i, ctx)); + safeToTile[i] = iterators[i] == utils::IteratorType::parallel; } } - - auto iterators = linalgOp.getIteratorTypesArray(); - for (AffineExpr r : dimExprs) { - unsigned int axis = cast(r).getPosition(); - if (iterators[axis] == utils::IteratorType::reduction) - safeToTile[axis] = false; - } - return safeToTile; } @@ -387,9 +371,8 @@ static FailureOr tileToForallOpImpl( LinalgOp linalgOp = dyn_cast(op.getOperation()); if (linalgOp) { // Check if tiling is thread safe and print a warning if not. 
- SmallVector tilingSafety = - safeToTileToForall(b.getContext(), linalgOp, numThreads, - nominalTileSizes, loopRanges.size()); + SmallVector tilingSafety = safeToTileToForall( + b.getContext(), linalgOp, numThreads); for (size_t i = 0; i < tilingSafety.size(); i++) if (!tilingSafety[i]) op.emitWarning() << "tiling is not thread safe at axis #" << i; diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir index 74eb0b12aa8d1..12e2dea5530b5 100644 --- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir @@ -722,8 +722,8 @@ func.func @tile_thread_safety6(%A: tensor, %B: tensor, %C: ten module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 4, 8] + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 0, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } -} \ No newline at end of file +} From ad3fc41b831b1e2995126fd6128c522407583061 Mon Sep 17 00:00:00 2001 From: Pablo Antonio Martinez Date: Thu, 21 Mar 2024 16:35:20 +0000 Subject: [PATCH 5/5] clang-format --- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 19a74b15c947f..462f692615faa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -371,8 +371,8 @@ static FailureOr tileToForallOpImpl( LinalgOp linalgOp = dyn_cast(op.getOperation()); if (linalgOp) { // Check if tiling is thread safe and print a warning if not. 
- SmallVector tilingSafety = safeToTileToForall( - b.getContext(), linalgOp, numThreads); + SmallVector tilingSafety = + safeToTileToForall(b.getContext(), linalgOp, numThreads); for (size_t i = 0; i < tilingSafety.size(); i++) if (!tilingSafety[i]) op.emitWarning() << "tiling is not thread safe at axis #" << i;