mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp (3 changes: 2 additions & 1 deletion)

@@ -71,8 +71,9 @@ struct ForallToParallelLoop final

     parentOp->walk([&](scf::ForallOp forallOp) {
       if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
-        return signalPassFailure();
+        return WalkResult::skip();
Collaborator:
Is this a failure? Why remove the signalPassFailure()?
There is also likely a missing error message before signalPassFailure() here (we shouldn't fail silently).

Contributor Author:
The result is not a real failure as it only occurs on match failure.

The main motivation for the change is that simply calling signalPassFailure() produces no output when the pass is invoked (at least from the CLI), yet I'd expect the IR to remain unchanged in such a case.
I think I should've captured the walk result and emitted an error on interruption (see the sketch below). But there is no reason to interrupt on this error.

Perhaps a greedy rewriter could be better here instead of walking the graph manually.
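
For reference, a minimal sketch of the captured-walk-result variant mentioned above (hypothetical code assuming this pass's runOnOperation(); not what this patch does):

    // Hypothetical alternative: capture the walk result and surface the
    // failure as a diagnostic instead of failing silently.
    WalkResult result = parentOp->walk([&](scf::ForallOp forallOp) {
      if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
        // Attach an error to the offending op so the failure is visible
        // even outside debug builds.
        forallOp.emitError("failed to convert scf.forall to scf.parallel");
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    if (result.wasInterrupted())
      signalPassFailure();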

Contributor:
The scf::forallToParallelLoop function internally calls notifyMatchFailure, so some diagnostic should occur. That may not mean much if the pass terminates successfully though.
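
For context, a sketch of the kind of guard scf::forallToParallelLoop performs internally (illustrative, not the exact upstream code); its message is only printed under --debug:

    // Illustrative guard: the conversion bails out on forall ops that
    // still carry results.
    if (!forallOp.getResults().empty())
      return rewriter.notifyMatchFailure(
          forallOp, "scf.forall ops with results are not supported");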

Collaborator (@joker-eph, Jun 13, 2024):

notifyMatchFailure is a debug function. This is a question of semantics for the pass though, and unfortunately this pass does not even have a description!
Can we start here and document the pass behavior before changing it?
(Is the pass promising to turn all scf.forall ops into scf.parallel? Or is it doing so opportunistically? Under which conditions? etc.)

Contributor:
> This is a question of semantics for the pass though, and unfortunately this pass does not even have a description!

The lack of description is my fault. My original intention was to error out when scf.forall cannot be lowered. I don't think it makes sense to run this transform before bufferization, and after bufferization all scf.forall operations should produce no results.
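
For illustration, after bufferization an scf.forall carries no results; a minimal hypothetical example of the shape this pass is meant to handle:

    func.func @after_bufferization(%buf: memref<8x112xf32>, %v: f32) {
      // No results and no shared_outs: this form can become scf.parallel.
      scf.forall (%i, %j) in (8, 112) {
        memref.store %v, %buf[%i, %j] : memref<8x112xf32>
      }
      return
    }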

Contributor Author:
Good points altogether. My change was too eager as well.

@sabauma My view on the pass is that indicating full failure (through signalPassFailure) is a bit heavy-handed in this case (I viewed it as an "error"), but if that is the intention, it is an equally valid approach.
I'll leave the pass as is. Perhaps the description could be made explicit about the intended behavior.

       }
+      return WalkResult::advance();
     });
   }
 };
mlir/test/Dialect/SCF/forall-to-parallel.mlir (18 changes: 18 additions & 0 deletions)

@@ -78,3 +78,21 @@ func.func @mapping_attr() -> () {
   return
 
 }
+
+// -----
+
+// CHECK-LABEL: @forall_with_outputs
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<32x32xf32>
+func.func @forall_with_outputs(%arg0: tensor<32x32xf32>) -> tensor<8x112x32x32xf32> {
+  // CHECK-NOT: scf.parallel
+  // CHECK: %[[RES:.+]] = scf.forall{{.*}}shared_outs
+  // CHECK: return %[[RES]] : tensor<8x112x32x32xf32>
+
+  %0 = tensor.empty() : tensor<8x112x32x32xf32>
+  %1 = scf.forall (%arg1, %arg2) in (8, 112) shared_outs(%arg3 = %0) -> (tensor<8x112x32x32xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %arg0 into %arg3[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x112x32x32xf32>
+    }
+  }
+  return %1 : tensor<8x112x32x32xf32>
+}