mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1922,7 +1922,9 @@ def TileUsingForallOp :

It is the user's responsibility to ensure that `num_threads/tile_sizes` is
a valid tiling specification (i.e. one that only tiles parallel dimensions,
e.g. in the Linalg case). If a tiled dimension is not parallelizable, a
warning is issued to notify the user that the generated code is not safe
to parallelize.

If non-empty, the `mapping` is added as an attribute to the
resulting `scf.forall`.
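A minimal illustration of the contract described above (hand-written for this review and mirroring the first test added below; it is not part of the patch): splitting only the parallel axis proceeds silently, while asking for more than one thread along the reduction axis trips the new warning.

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>

// Row-wise sum: axis #0 is "parallel", axis #1 is "reduction".
func.func @row_sum(%in: tensor<100x300xf32>, %init: tensor<100xf32>) -> tensor<100xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map1],
                       iterator_types = ["parallel", "reduction"]}
      ins(%in : tensor<100x300xf32>) outs(%init : tensor<100xf32>) {
  ^bb0(%a: f32, %acc: f32):
    %sum = arith.addf %a, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<100xf32>
  return %0 : tensor<100xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    // num_threads [4]    -> only parallel axis #0 is split: no diagnostic.
    // num_threads [4, 2] -> reduction axis #1 is split: "tiling is not thread
    //                       safe at axis #1" would be emitted.
    %forall, %tiled = transform.structured.tile_using_forall %0 num_threads [4]
        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}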
38 changes: 36 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,6 +304,28 @@ static void calculateTileOffsetsAndSizes(
}
}

/// Returns a vector of bools representing whether, for each axis, `op` can be
/// tiled along that axis without introducing a race condition, i.e. whether
/// the tiling is thread-safe. This is checked by iterating over `numThreads`:
/// any axis that is split across more than one thread (or across a
/// non-constant number of threads) must have a "parallel" iterator type;
/// otherwise that dimension is unsafe to tile.
SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
ArrayRef<OpFoldResult> numThreads) {
auto iterators = linalgOp.getIteratorTypesArray();
SmallVector<bool> safeToTile(numThreads.size(), true);

for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
}
} else {
safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
}
}
return safeToTile;
}
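
To make the per-axis decision concrete, a hand-written illustration (not part of the patch) using linalg.matmul, whose iterator types are ["parallel", "parallel", "reduction"]: an axis is only flagged when it is split across more than one thread while its iterator type is not "parallel".

func.func @matmul_axis_classification(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
                                      %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // Expected classification by safeToTileToForall (illustration only):
  //   num_threads [4, 4, 1] -> [true, true, true]: the reduction axis keeps a
  //                            single thread, so nothing can race.
  //   num_threads [4, 4, 2] -> [true, true, false]: axis #2 is a reduction
  //                            split across 2 threads, so a warning is emitted.
  //   A non-constant thread count is treated conservatively and likewise
  //   requires a "parallel" iterator.
  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}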

/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
/// tiling is specified by the number of tiles/threads `numThreads` and the
/// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
@@ -314,8 +336,10 @@ static void calculateTileOffsetsAndSizes(
/// size of data.
/// It is the user's responsibility to ensure that `numThreads` is a valid
/// tiling specification (i.e. one that only tiles parallel dimensions, e.g. in
/// the Linalg case). If a tiled dimension is not parallelizable, a warning is
/// issued to notify the user that the generated code is not safe to
/// parallelize. If `omitTileOffsetBoundsCheck` is true, then the function will
/// assume that `tileSize[i] * (numThread[i] - 1) <= dimSize[i]` holds.
static FailureOr<ForallTilingResult> tileToForallOpImpl(
RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
@@ -344,6 +368,16 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
return getValueOrCreateConstantIndexOp(b, loc, ofr);
}));

LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
if (linalgOp) {
// Check if tiling is thread safe and print a warning if not.
SmallVector<bool> tilingSafety =
safeToTileToForall(b.getContext(), linalgOp, numThreads);
for (size_t i = 0; i < tilingSafety.size(); i++)
if (!tilingSafety[i])
op.emitWarning() << "tiling is not thread safe at axis #" << i;
}

// 1. Create the ForallOp. We don't use the lambda body-builder
// version because we require the use of RewriterBase in the body, so we
// manually move the insertion point to the body below.
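For intuition about what the warning is protecting against, a simplified, hand-written sketch (not the verbatim output of the transform) of the forall produced when the row-sum reduction above is tiled with num_threads [4, 2]: the destination slice depends only on the parallel index %i, so both threads along %j write the same 25-element window of the shared output and the accumulation races.

func.func @race_sketch(%in: tensor<100x300xf32>, %init: tensor<100xf32>) -> tensor<100xf32> {
  %res = scf.forall (%i, %j) in (4, 2) shared_outs(%out = %init) -> (tensor<100xf32>) {
    %off = affine.apply affine_map<(d0) -> (d0 * 25)>(%i)
    // Placeholder (unregistered op) for the tiled 25x150 partial reduction
    // each thread computes; its details are irrelevant to the race.
    %partial = "sketch.partial_row_sum"(%in, %i, %j)
        : (tensor<100x300xf32>, index, index) -> tensor<25xf32>
    scf.forall.in_parallel {
      // The insertion offset ignores %j: threads (%i, 0) and (%i, 1) both
      // write rows [%off, %off + 25) of %out, i.e. the race flagged at axis #1.
      tensor.parallel_insert_slice %partial into %out[%off] [25] [1]
          : tensor<25xf32> into tensor<100xf32>
    }
  }
  return %res : tensor<100xf32>
}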
141 changes: 141 additions & 0 deletions mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -586,3 +586,144 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>

func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
// expected-warning@below {{tiling is not thread safe at axis #1}}
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
^bb0(%in: f32, %out: f32):
%1 = arith.addf %in, %out : f32
linalg.yield %1 : f32
} -> tensor<100xf32>
return %0 : tensor<100xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [4, 2]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>

func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
// expected-warning@below {{tiling is not thread safe at axis #0}}
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
^bb0(%in: f32, %out: f32):
%1 = arith.addf %in, %out : f32
linalg.yield %1 : f32
} -> tensor<300x8xf32>
return %0 : tensor<300x8xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>

func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> {
// expected-warning@below {{tiling is not thread safe at axis #1}}
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) {
^bb0(%in: f32, %out: f32):
%1 = arith.addf %in, %out : f32
linalg.yield %1 : f32
} -> tensor<100x8xf32>
return %0 : tensor<100x8xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>

func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) {
// expected-warning@+2 {{tiling is not thread safe at axis #0}}
// expected-warning@below {{tiling is not thread safe at axis #1}}
%0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) {
^bb0(%in: f32, %out1: f32, %out2: f32):
%1 = arith.addf %in, %out1 : f32
%2 = arith.addf %in, %out2 : f32
linalg.yield %1, %2 : f32, f32
} -> (tensor<100x8xf32>, tensor<8xf32>)
return %0#0, %0#1 : tensor<100x8xf32>, tensor<8xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>

func.func @tile_thread_safety5(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
// expected-warning@below {{tiling is not thread safe at axis #1}}
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) {
^bb0(%in: f32, %out: f32):
%1 = arith.addf %in, %out : f32
linalg.yield %1 : f32
} -> tensor<100xf32>
return %0 : tensor<100xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall, %tiled_generic = transform.structured.tile_using_forall %0 tile_sizes [10, 1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// -----

func.func @tile_thread_safety6(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-warning@below {{tiling is not thread safe at axis #2}}
%0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 0, 8]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}