From 97a2bd8415dc6792b99ec0f091ad7570673c3f37 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 4 Jul 2024 09:24:23 +0200 Subject: [PATCH] Revert "[mlir][loops] Reland Refactor LoopFuseSiblingOp and support parallel fusion #94391 (#97607)" This reverts commit edbc0e30a9e587cee1189be023b9385adc2f239a. Reason for rollback. ASAN complains about this PR: ==4320==ERROR: AddressSanitizer: heap-use-after-free on address 0x502000006cd8 at pc 0x55e2978d63cf bp 0x7ffe6431c2b0 sp 0x7ffe6431c2a8 READ of size 8 at 0x502000006cd8 thread T0 #0 0x55e2978d63ce in map &, llvm::MutableArrayRef, nullptr> mlir/include/mlir/IR/IRMapping.h:40:11 #1 0x55e2978d63ce in mlir::createFused(mlir::LoopLikeOpInterface, mlir::LoopLikeOpInterface, mlir::RewriterBase&, std::__u::function (mlir::OpBuilder&, mlir::Location, llvm::ArrayRef)>, llvm::function_ref) mlir/lib/Interfaces/LoopLikeInterface.cpp:156:11 #2 0x55e2952a614b in mlir::fuseIndependentSiblingForLoops(mlir::scf::ForOp, mlir::scf::ForOp, mlir::RewriterBase&) mlir/lib/Dialect/SCF/Utils/Utils.cpp:1398:43 #3 0x55e291480c6f in mlir::transform::LoopFuseSiblingOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp:482:17 #4 0x55e29149ed5e in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56 #5 0x55e297494a60 in apply blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14 #6 0x55e297494a60 in mlir::transform::TransformState::applyTransform(mlir::transform::TransformOpInterface) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:953:48 #7 0x55e294646a8d in applySequenceBlock(mlir::Block&, mlir::transform::FailurePropagationMode, mlir::transform::TransformState&, mlir::transform::TransformResults&) mlir/lib/Dialect/Transform/IR/TransformOps.cpp:1788:15 #8 0x55e29464f927 in mlir::transform::NamedSequenceOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) mlir/lib/Dialect/Transform/IR/TransformOps.cpp:2155:10 #9 0x55e2945d28ee in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56 #10 0x55e297494a60 in apply blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14 #11 0x55e297494a60 in mlir::transform::TransformState::applyTransform(mlir::transform::TransformOpInterface) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:953:48 #12 0x55e2974a5fe2 in mlir::transform::applyTransforms(mlir::Operation*, mlir::transform::TransformOpInterface, mlir::RaggedArray> const&, mlir::transform::TransformOptions const&, bool) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:2016:16 #13 0x55e2945888d7 in mlir::transform::applyTransformNamedSequence(mlir::RaggedArray>, mlir::transform::TransformOpInterface, mlir::ModuleOp, mlir::transform::TransformOptions const&) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp:234:10 #14 0x55e294582446 in (anonymous namespace)::InterpreterPass::runOnOperation() mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp:147:16 #15 0x55e2978e93c6 in operator() mlir/lib/Pass/Pass.cpp:527:17 #16 0x55e2978e93c6 in void llvm::function_ref::callback_fn(long) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #17 0x55e2978e207a in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #18 0x55e2978e207a in executeAction mlir/include/mlir/IR/MLIRContext.h:275:7 #19 0x55e2978e207a in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) mlir/lib/Pass/Pass.cpp:521:21 #20 0x55e2978e5fbf in runPipeline mlir/lib/Pass/Pass.cpp:593:16 #21 0x55e2978e5fbf in mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) mlir/lib/Pass/Pass.cpp:904:10 #22 0x55e2978e5b65 in mlir::PassManager::run(mlir::Operation*) mlir/lib/Pass/Pass.cpp:884:60 #23 0x55e291ebb460 in performActions(llvm::raw_ostream&, std::__u::shared_ptr const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:408:17 #24 0x55e291ebabd9 in processBuffer mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:481:9 #25 0x55e291ebabd9 in operator() mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:548:12 #26 0x55e291ebabd9 in llvm::LogicalResult llvm::function_ref>, llvm::raw_ostream&)>::callback_fn>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__u::unique_ptr>, llvm::raw_ostream&) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #27 0x55e297b1cffe in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #28 0x55e297b1cffe in mlir::splitAndProcessBuffer(std::__u::unique_ptr>, llvm::function_ref>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef)::$_0::operator()(llvm::StringRef) const mlir/lib/Support/ToolUtilities.cpp:86:16 #29 0x55e297b1c9c5 in interleave llvm/include/llvm/ADT/STLExtras.h:2125:3 #30 0x55e297b1c9c5 in interleave, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), llvm::raw_ostream, llvm::StringRef> llvm/include/llvm/ADT/STLExtras.h:2147:3 #31 0x55e297b1c9c5 in mlir::splitAndProcessBuffer(std::__u::unique_ptr>, llvm::function_ref>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) mlir/lib/Support/ToolUtilities.cpp:89:3 #32 0x55e291eb0cf0 in mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:551:10 #33 0x55e291eb115c in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:589:14 #34 0x55e291eb15f8 in mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:605:10 #35 0x55e29130d1be in main mlir/tools/mlir-opt/mlir-opt.cpp:311:33 #36 0x7fbcf3fff3d3 in __libc_start_main (/usr/grte/v5/lib64/libc.so.6+0x613d3) (BuildId: 9a996398ce14a94560b0c642eb4f6e94) #37 0x55e2912365a9 in _start /usr/grte/v5/debug-src/src/csu/../sysdeps/x86_64/start.S:120 0x502000006cd8 is located 8 bytes inside of 16-byte region [0x502000006cd0,0x502000006ce0) freed by thread T0 here: #0 0x55e29130b7e2 in operator delete(void*, unsigned long) compiler-rt/lib/asan/asan_new_delete.cpp:155:3 #1 0x55e2979eb657 in __libcpp_operator_delete #2 0x55e2979eb657 in __do_deallocate_handle_size<> #3 0x55e2979eb657 in __libcpp_deallocate #4 0x55e2979eb657 in deallocate #5 0x55e2979eb657 in deallocate #6 0x55e2979eb657 in operator() #7 0x55e2979eb657 in ~vector #8 0x55e2979eb657 in mlir::Block::~Block() mlir/lib/IR/Block.cpp:24:1 #9 0x55e2979ebc17 in deleteNode llvm/include/llvm/ADT/ilist.h:42:39 #10 0x55e2979ebc17 in erase llvm/include/llvm/ADT/ilist.h:205:5 #11 0x55e2979ebc17 in erase llvm/include/llvm/ADT/ilist.h:209:39 #12 0x55e2979ebc17 in mlir::Block::erase() mlir/lib/IR/Block.cpp:67:28 #13 0x55e297aef978 in mlir::RewriterBase::eraseBlock(mlir::Block*) mlir/lib/IR/PatternMatch.cpp:245:10 #14 0x55e297af0563 in mlir::RewriterBase::inlineBlockBefore(mlir::Block*, mlir::Block*, llvm::ilist_iterator, false, false>, mlir::ValueRange) mlir/lib/IR/PatternMatch.cpp:331:3 #15 0x55e297af06d8 in mlir::RewriterBase::mergeBlocks(mlir::Block*, mlir::Block*, mlir::ValueRange) mlir/lib/IR/PatternMatch.cpp:341:3 #16 0x55e297036608 in mlir::scf::ForOp::replaceWithAdditionalYields(mlir::RewriterBase&, mlir::ValueRange, bool, std::__u::function (mlir::OpBuilder&, mlir::Location, llvm::ArrayRef)> const&) mlir/lib/Dialect/SCF/IR/SCF.cpp:575:12 #17 0x55e2970673ca in mlir::detail::LoopLikeOpInterfaceInterfaceTraits::Model::replaceWithAdditionalYields(mlir::detail::LoopLikeOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::RewriterBase&, mlir::ValueRange, bool, std::__u::function (mlir::OpBuilder&, mlir::Location, llvm::ArrayRef)> const&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Interfaces/LoopLikeInterface.h.inc:658:56 #18 0x55e2978d5feb in replaceWithAdditionalYields blaze-out/k8-opt-asan/bin/mlir/include/mlir/Interfaces/LoopLikeInterface.cpp.inc:105:14 #19 0x55e2978d5feb in mlir::createFused(mlir::LoopLikeOpInterface, mlir::LoopLikeOpInterface, mlir::RewriterBase&, std::__u::function (mlir::OpBuilder&, mlir::Location, llvm::ArrayRef)>, llvm::function_ref) mlir/lib/Interfaces/LoopLikeInterface.cpp:135:14 #20 0x55e2952a614b in mlir::fuseIndependentSiblingForLoops(mlir::scf::ForOp, mlir::scf::ForOp, mlir::RewriterBase&) mlir/lib/Dialect/SCF/Utils/Utils.cpp:1398:43 #21 0x55e291480c6f in mlir::transform::LoopFuseSiblingOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp:482:17 #22 0x55e29149ed5e in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56 #23 0x55e297494a60 in apply blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14 #24 0x55e297494a60 in mlir::transform::TransformState::applyTransform(mlir::transform::TransformOpInterface) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:953:48 #25 0x55e294646a8d in applySequenceBlock(mlir::Block&, mlir::transform::FailurePropagationMode, mlir::transform::TransformState&, mlir::transform::TransformResults&) mlir/lib/Dialect/Transform/IR/TransformOps.cpp:1788:15 #26 0x55e29464f927 in mlir::transform::NamedSequenceOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) mlir/lib/Dialect/Transform/IR/TransformOps.cpp:2155:10 #27 0x55e2945d28ee in mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56 #28 0x55e297494a60 in apply blaze-out/k8-opt-asan/bin/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14 #29 0x55e297494a60 in mlir::transform::TransformState::applyTransform(mlir::transform::TransformOpInterface) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:953:48 #30 0x55e2974a5fe2 in mlir::transform::applyTransforms(mlir::Operation*, mlir::transform::TransformOpInterface, mlir::RaggedArray> const&, mlir::transform::TransformOptions const&, bool) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp:2016:16 #31 0x55e2945888d7 in mlir::transform::applyTransformNamedSequence(mlir::RaggedArray>, mlir::transform::TransformOpInterface, mlir::ModuleOp, mlir::transform::TransformOptions const&) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp:234:10 #32 0x55e294582446 in (anonymous namespace)::InterpreterPass::runOnOperation() mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp:147:16 #33 0x55e2978e93c6 in operator() mlir/lib/Pass/Pass.cpp:527:17 #34 0x55e2978e93c6 in void llvm::function_ref::callback_fn(long) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #35 0x55e2978e207a in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #36 0x55e2978e207a in executeAction mlir/include/mlir/IR/MLIRContext.h:275:7 #37 0x55e2978e207a in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) mlir/lib/Pass/Pass.cpp:521:21 #38 0x55e2978e5fbf in runPipeline mlir/lib/Pass/Pass.cpp:593:16 #39 0x55e2978e5fbf in mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) mlir/lib/Pass/Pass.cpp:904:10 #40 0x55e2978e5b65 in mlir::PassManager::run(mlir::Operation*) mlir/lib/Pass/Pass.cpp:884:60 #41 0x55e291ebb460 in performActions(llvm::raw_ostream&, std::__u::shared_ptr const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:408:17 #42 0x55e291ebabd9 in processBuffer mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:481:9 #43 0x55e291ebabd9 in operator() mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:548:12 #44 0x55e291ebabd9 in llvm::LogicalResult llvm::function_ref>, llvm::raw_ostream&)>::callback_fn>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__u::unique_ptr>, llvm::raw_ostream&) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #45 0x55e297b1cffe in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #46 0x55e297b1cffe in mlir::splitAndProcessBuffer(std::__u::unique_ptr>, llvm::function_ref>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef)::$_0::operator()(llvm::StringRef) const mlir/lib/Support/ToolUtilities.cpp:86:16 #47 0x55e297b1c9c5 in interleave llvm/include/llvm/ADT/STLExtras.h:2125:3 #48 0x55e297b1c9c5 in interleave, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), llvm::raw_ostream, llvm::StringRef> llvm/include/llvm/ADT/STLExtras.h:2147:3 #49 0x55e297b1c9c5 in mlir::splitAndProcessBuffer(std::__u::unique_ptr>, llvm::function_ref>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) mlir/lib/Support/ToolUtilities.cpp:89:3 #50 0x55e291eb0cf0 in mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:551:10 #51 0x55e291eb115c in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:589:14 previously allocated by thread T0 here: #0 0x55e29130ab5d in operator new(unsigned long) compiler-rt/lib/asan/asan_new_delete.cpp:86:3 #1 0x55e2979ed5d4 in __libcpp_operator_new #2 0x55e2979ed5d4 in __libcpp_allocate #3 0x55e2979ed5d4 in allocate #4 0x55e2979ed5d4 in __allocate_at_least > #5 0x55e2979ed5d4 in __split_buffer #6 0x55e2979ed5d4 in mlir::BlockArgument* std::__u::vector>::__push_back_slow_path(mlir::BlockArgument const&) #7 0x55e2979ec0f2 in push_back #8 0x55e2979ec0f2 in mlir::Block::addArgument(mlir::Type, mlir::Location) mlir/lib/IR/Block.cpp:154:13 #9 0x55e29796e457 in parseRegionBody mlir/lib/AsmParser/Parser.cpp:2172:34 #10 0x55e29796e457 in (anonymous namespace)::OperationParser::parseRegion(mlir::Region&, llvm::ArrayRef, bool) mlir/lib/AsmParser/Parser.cpp:2121:7 #11 0x55e29796b25e in (anonymous namespace)::CustomOpAsmParser::parseRegion(mlir::Region&, llvm::ArrayRef, bool) mlir/lib/AsmParser/Parser.cpp:1785:16 #12 0x55e297035742 in mlir::scf::ForOp::parse(mlir::OpAsmParser&, mlir::OperationState&) mlir/lib/Dialect/SCF/IR/SCF.cpp:521:14 #13 0x55e291322c18 in llvm::ParseResult llvm::detail::UniqueFunctionBase::CallImpl(void*, mlir::OpAsmParser&, mlir::OperationState&) llvm/include/llvm/ADT/FunctionExtras.h:220:12 #14 0x55e29795bea3 in operator() llvm/include/llvm/ADT/FunctionExtras.h:384:12 #15 0x55e29795bea3 in callback_fn > llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #16 0x55e29795bea3 in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #17 0x55e29795bea3 in parseOperation mlir/lib/AsmParser/Parser.cpp:1521:9 #18 0x55e29795bea3 in parseCustomOperation mlir/lib/AsmParser/Parser.cpp:2017:19 #19 0x55e29795bea3 in (anonymous namespace)::OperationParser::parseOperation() mlir/lib/AsmParser/Parser.cpp:1174:10 #20 0x55e297971d20 in parseBlockBody mlir/lib/AsmParser/Parser.cpp:2296:9 #21 0x55e297971d20 in (anonymous namespace)::OperationParser::parseBlock(mlir::Block*&) mlir/lib/AsmParser/Parser.cpp:2226:12 #22 0x55e29796e4f5 in parseRegionBody mlir/lib/AsmParser/Parser.cpp:2184:7 #23 0x55e29796e4f5 in (anonymous namespace)::OperationParser::parseRegion(mlir::Region&, llvm::ArrayRef, bool) mlir/lib/AsmParser/Parser.cpp:2121:7 #24 0x55e29796b25e in (anonymous namespace)::CustomOpAsmParser::parseRegion(mlir::Region&, llvm::ArrayRef, bool) mlir/lib/AsmParser/Parser.cpp:1785:16 #25 0x55e29796b2cf in (anonymous namespace)::CustomOpAsmParser::parseOptionalRegion(mlir::Region&, llvm::ArrayRef, bool) mlir/lib/AsmParser/Parser.cpp:1796:12 #26 0x55e2978d89ff in mlir::function_interface_impl::parseFunctionOp(mlir::OpAsmParser&, mlir::OperationState&, bool, mlir::StringAttr, llvm::function_ref, llvm::ArrayRef, mlir::function_interface_impl::VariadicFlag, std::__u::basic_string, std::__u::allocator>&)>, mlir::StringAttr, mlir::StringAttr) mlir/lib/Interfaces/FunctionImplementation.cpp:232:14 #27 0x55e2969ba41d in mlir::func::FuncOp::parse(mlir::OpAsmParser&, mlir::OperationState&) mlir/lib/Dialect/Func/IR/FuncOps.cpp:203:10 #28 0x55e291322c18 in llvm::ParseResult llvm::detail::UniqueFunctionBase::CallImpl(void*, mlir::OpAsmParser&, mlir::OperationState&) llvm/include/llvm/ADT/FunctionExtras.h:220:12 #29 0x55e29795bea3 in operator() llvm/include/llvm/ADT/FunctionExtras.h:384:12 #30 0x55e29795bea3 in callback_fn > llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #31 0x55e29795bea3 in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #32 0x55e29795bea3 in parseOperation mlir/lib/AsmParser/Parser.cpp:1521:9 #33 0x55e29795bea3 in parseCustomOperation mlir/lib/AsmParser/Parser.cpp:2017:19 #34 0x55e29795bea3 in (anonymous namespace)::OperationParser::parseOperation() mlir/lib/AsmParser/Parser.cpp:1174:10 #35 0x55e297959b78 in parse mlir/lib/AsmParser/Parser.cpp:2725:20 #36 0x55e297959b78 in mlir::parseAsmSourceFile(llvm::SourceMgr const&, mlir::Block*, mlir::ParserConfig const&, mlir::AsmParserState*, mlir::AsmParserCodeCompleteContext*) mlir/lib/AsmParser/Parser.cpp:2785:41 #37 0x55e29790d5c2 in mlir::parseSourceFile(std::__u::shared_ptr const&, mlir::Block*, mlir::ParserConfig const&, mlir::LocationAttr*) mlir/lib/Parser/Parser.cpp:46:10 #38 0x55e291ebbfe2 in parseSourceFile &> mlir/include/mlir/Parser/Parser.h:159:14 #39 0x55e291ebbfe2 in parseSourceFile mlir/include/mlir/Parser/Parser.h:189:10 #40 0x55e291ebbfe2 in mlir::parseSourceFileForTool(std::__u::shared_ptr const&, mlir::ParserConfig const&, bool) mlir/include/mlir/Tools/ParseUtilities.h:31:12 #41 0x55e291ebb263 in performActions(llvm::raw_ostream&, std::__u::shared_ptr const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:383:33 #42 0x55e291ebabd9 in processBuffer mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:481:9 #43 0x55e291ebabd9 in operator() mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:548:12 #44 0x55e291ebabd9 in llvm::LogicalResult llvm::function_ref>, llvm::raw_ostream&)>::callback_fn>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__u::unique_ptr>, llvm::raw_ostream&) llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #45 0x55e297b1cffe in operator() llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #46 0x55e297b1cffe in mlir::splitAndProcessBuffer(std::__u::unique_ptr>, llvm::function_ref>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef)::$_0::operator()(llvm::StringRef) const mlir/lib/Support/ToolUtilities.cpp:86:16 #47 0x55e297b1c9c5 in interleave llvm/include/llvm/ADT/STLExtras.h:2125:3 #48 0x55e297b1c9c5 in interleave, (lambda at mlir/lib/Support/ToolUtilities.cpp:79:23), llvm::raw_ostream, llvm::StringRef> llvm/include/llvm/ADT/STLExtras.h:2147:3 #49 0x55e297b1c9c5 in mlir::splitAndProcessBuffer(std::__u::unique_ptr>, llvm::function_ref>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) mlir/lib/Support/ToolUtilities.cpp:89:3 #50 0x55e291eb0cf0 in mlir::MlirOptMain(llvm::raw_ostream&, std::__u::unique_ptr>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:551:10 #51 0x55e291eb115c in mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:589:14 #52 0x55e291eb15f8 in mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:605:10 #53 0x55e29130d1be in main mlir/tools/mlir-opt/mlir-opt.cpp:311:33 #54 0x7fbcf3fff3d3 in __libc_start_main (/usr/grte/v5/lib64/libc.so.6+0x613d3) (BuildId: 9a996398ce14a94560b0c642eb4f6e94) #55 0x55e2912365a9 in _start /usr/grte/v5/debug-src/src/csu/../sysdeps/x86_64/start.S:120 SUMMARY: AddressSanitizer: heap-use-after-free mlir/include/mlir/IR/IRMapping.h:40:11 in map &, llvm::MutableArrayRef, nullptr> Shadow bytes around the buggy address: 0x502000006a00: fa fa 00 fa fa fa 00 00 fa fa 00 fa fa fa 00 fa 0x502000006a80: fa fa 00 fa fa fa 00 00 fa fa 00 00 fa fa 00 00 0x502000006b00: fa fa 00 00 fa fa 00 00 fa fa 00 fa fa fa 00 fa 0x502000006b80: fa fa 00 fa fa fa 00 fa fa fa 00 00 fa fa 00 00 0x502000006c00: fa fa 00 00 fa fa 00 00 fa fa 00 00 fa fa fd fa =>0x502000006c80: fa fa fd fa fa fa fd fd fa fa fd[fd]fa fa fd fd 0x502000006d00: fa fa 00 fa fa fa 00 fa fa fa 00 fa fa fa 00 fa 0x502000006d80: fa fa 00 fa fa fa 00 fa fa fa 00 fa fa fa 00 fa 0x502000006e00: fa fa 00 fa fa fa 00 fa fa fa 00 00 fa fa 00 fa 0x502000006e80: fa fa 00 fa fa fa 00 00 fa fa 00 fa fa fa 00 fa 0x502000006f00: fa fa 00 fa fa fa 00 fa fa fa 00 fa fa fa 00 fa Shadow byte legend (one shadow byte represents 8 application bytes): Addressable: 00 Partially addressable: 01 02 03 04 05 06 07 Heap left redzone: fa Freed heap region: fd Stack left redzone: f1 Stack mid redzone: f2 Stack right redzone: f3 Stack after return: f5 Stack use after scope: f8 Global redzone: f9 Global init order: f6 Poisoned by user: f7 Container overflow: fc Array cookie: ac Intra object redzone: bb ASan internal: fe Left alloca redzone: ca Right alloca redzone: cb ==4320==ABORTING --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 3 +- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 20 -- .../mlir/Interfaces/LoopLikeInterface.h | 20 -- mlir/lib/Dialect/SCF/IR/SCF.cpp | 38 --- .../SCF/TransformOps/SCFTransformOps.cpp | 140 +++++++-- .../SCF/Transforms/ParallelLoopFusion.cpp | 80 ++++- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 279 ++++++----------- mlir/lib/Interfaces/LoopLikeInterface.cpp | 59 ---- .../SCF/transform-loop-fuse-sibling.mlir | 290 +----------------- 9 files changed, 283 insertions(+), 646 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index bf95fbe6721cf..f35ea962bea16 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -303,8 +303,7 @@ def ForallOp : SCF_Op<"forall", [ DeclareOpInterfaceMethods, + "promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 6a40304e2eeba..de807c3e4e1f8 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -181,16 +181,6 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef sizes); void getPerfectlyNestedLoops(SmallVectorImpl &nestedLoops, scf::ForOp root); -//===----------------------------------------------------------------------===// -// Fusion related helpers -//===----------------------------------------------------------------------===// - -/// Check structural compatibility between two loops such as iteration space -/// and dominance. -bool checkFusionStructuralLegality(LoopLikeOpInterface target, - LoopLikeOpInterface source, - Diagnostic &diag); - /// Given two scf.forall loops, `target` and `source`, fuses `target` into /// `source`. Assumes that the given loops are siblings and are independent of /// each other. @@ -212,16 +202,6 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter); -/// Given two scf.parallel loops, `target` and `source`, fuses `target` into -/// `source`. Assumes that the given loops are siblings and are independent of -/// each other. -/// -/// This function does not perform any legality checks and simply fuses the -/// loops. The caller is responsible for ensuring that the loops are legal to -/// fuse. -scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target, - scf::ParallelOp source, - RewriterBase &rewriter); } // namespace mlir #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_ diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h index d08e097a9b4af..9925fc6ce6ca9 100644 --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -90,24 +90,4 @@ struct JamBlockGatherer { /// Include the generated interface declarations. #include "mlir/Interfaces/LoopLikeInterface.h.inc" -namespace mlir { -/// A function that rewrites `target`'s terminator as a teminator obtained by -/// fusing `source` into `target`. -using FuseTerminatorFn = - function_ref; - -/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to -/// `target`. The `NewYieldValuesFn` callback is used to pass to the -/// `replaceWithAdditionalYields` interface method to replace the loop with a -/// new loop with (possibly) additional yields, while the `FuseTerminatorFn` -/// callback is repsonsible for updating the fused loop terminator. -LoopLikeOpInterface createFused(LoopLikeOpInterface target, - LoopLikeOpInterface source, - RewriterBase &rewriter, - NewYieldValuesFn newYieldValuesFn, - FuseTerminatorFn fuseTerminatorFn); - -} // namespace mlir - #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_ diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index cb15e0ecebf05..907d7f794593d 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -618,44 +618,6 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, SmallVector ForallOp::getLoopRegions() { return {&getRegion()}; } -FailureOr ForallOp::replaceWithAdditionalYields( - RewriterBase &rewriter, ValueRange newInitOperands, - bool replaceInitOperandUsesInLoop, - const NewYieldValuesFn &newYieldValuesFn) { - // Create a new loop before the existing one, with the extra operands. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(getOperation()); - SmallVector inits(getOutputs()); - llvm::append_range(inits, newInitOperands); - scf::ForallOp newLoop = rewriter.create( - getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(), - inits, getMapping(), - /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); - - // Move the loop body to the new op. - rewriter.mergeBlocks(getBody(), newLoop.getBody(), - newLoop.getBody()->getArguments().take_front( - getBody()->getNumArguments())); - - if (replaceInitOperandUsesInLoop) { - // Replace all uses of `newInitOperands` with the corresponding basic block - // arguments. - for (auto &&[newOperand, oldOperand] : - llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back( - newInitOperands.size()))) { - rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) { - Operation *user = use.getOwner(); - return newLoop->isProperAncestor(user); - }); - } - } - - // Replace the old loop. - rewriter.replaceOp(getOperation(), - newLoop->getResults().take_front(getNumResults())); - return cast(newLoop.getOperation()); -} - /// Promotes the loop body of a forallOp to its containing block if it can be /// determined that the loop has a single iteration. LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 41834fea3bb84..56ff2709a589e 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -261,10 +261,8 @@ loopScheduling(scf::ForOp forOp, return 1; }; - std::optional ubConstant = - getConstantIntValue(forOp.getUpperBound()); - std::optional lbConstant = - getConstantIntValue(forOp.getLowerBound()); + std::optional ubConstant = getConstantIntValue(forOp.getUpperBound()); + std::optional lbConstant = getConstantIntValue(forOp.getLowerBound()); DenseMap opCycles; std::map> wrappedSchedule; for (Operation &op : forOp.getBody()->getOperations()) { @@ -449,6 +447,113 @@ void transform::TakeAssumedBranchOp::getEffects( // LoopFuseSiblingOp //===----------------------------------------------------------------------===// +/// Check if `target` and `source` are siblings, in the context that `target` +/// is being fused into `source`. +/// +/// This is a simple check that just checks if both operations are in the same +/// block and some checks to ensure that the fused IR does not violate +/// dominance. +static DiagnosedSilenceableFailure isOpSibling(Operation *target, + Operation *source) { + // Check if both operations are same. + if (target == source) + return emitSilenceableFailure(source) + << "target and source need to be different loops"; + + // Check if both operations are in the same block. + if (target->getBlock() != source->getBlock()) + return emitSilenceableFailure(source) + << "target and source are not in the same block"; + + // Check if fusion will violate dominance. + DominanceInfo domInfo(source); + if (target->isBeforeInBlock(source)) { + // Since `target` is before `source`, all users of results of `target` + // need to be dominated by `source`. + for (Operation *user : target->getUsers()) { + if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { + return emitSilenceableFailure(target) + << "user of results of target should be properly dominated by " + "source"; + } + } + } else { + // Since `target` is after `source`, all values used by `target` need + // to dominate `source`. + + // Check if operands of `target` are dominated by `source`. + for (Value operand : target->getOperands()) { + Operation *operandOp = operand.getDefiningOp(); + // Operands without defining operations are block arguments. When `target` + // and `source` occur in the same block, these operands dominate `source`. + if (!operandOp) + continue; + + // Operand's defining operation should properly dominate `source`. + if (!domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) + return emitSilenceableFailure(target) + << "operands of target should be properly dominated by source"; + } + + // Check if values used by `target` are dominated by `source`. + bool failed = false; + OpOperand *failedValue = nullptr; + visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { + Operation *operandOp = operand->get().getDefiningOp(); + if (operandOp && !domInfo.properlyDominates(operandOp, source, + /*enclosingOpOk=*/false)) { + // `operand` is not an argument of an enclosing block and the defining + // op of `operand` is outside `target` but does not dominate `source`. + failed = true; + failedValue = operand; + } + }); + + if (failed) + return emitSilenceableFailure(failedValue->getOwner()) + << "values used inside regions of target should be properly " + "dominated by source"; + } + + return DiagnosedSilenceableFailure::success(); +} + +/// Check if `target` scf.forall can be fused into `source` scf.forall. +/// +/// This simply checks if both loops have the same bounds, steps and mapping. +/// No attempt is made at checking that the side effects of `target` and +/// `source` are independent of each other. +static bool isForallWithIdenticalConfiguration(Operation *target, + Operation *source) { + auto targetOp = dyn_cast(target); + auto sourceOp = dyn_cast(source); + if (!targetOp || !sourceOp) + return false; + + return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && + targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && + targetOp.getMixedStep() == sourceOp.getMixedStep() && + targetOp.getMapping() == sourceOp.getMapping(); +} + +/// Check if `target` scf.for can be fused into `source` scf.for. +/// +/// This simply checks if both loops have the same bounds and steps. No attempt +/// is made at checking that the side effects of `target` and `source` are +/// independent of each other. +static bool isForWithIdenticalConfiguration(Operation *target, + Operation *source) { + auto targetOp = dyn_cast(target); + auto sourceOp = dyn_cast(source); + if (!targetOp || !sourceOp) + return false; + + return targetOp.getLowerBound() == sourceOp.getLowerBound() && + targetOp.getUpperBound() == sourceOp.getUpperBound() && + targetOp.getStep() == sourceOp.getStep(); +} + DiagnosedSilenceableFailure transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -464,32 +569,25 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, << "source handle (got " << llvm::range_size(sourceOps) << ")"; } - auto target = dyn_cast(*targetOps.begin()); - auto source = dyn_cast(*sourceOps.begin()); - if (!target || !source) - return emitSilenceableFailure(target->getLoc()) - << "target or source is not a loop op"; + Operation *target = *targetOps.begin(); + Operation *source = *sourceOps.begin(); - // Check if loops can be fused - Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error); - if (!mlir::checkFusionStructuralLegality(target, source, diag)) - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + // Check if the target and source are siblings. + DiagnosedSilenceableFailure diag = isOpSibling(target, source); + if (!diag.succeeded()) + return diag; Operation *fusedLoop; - // TODO: Support fusion for loop-like ops besides scf.for, scf.forall - // and scf.parallel. - if (isa(target) && isa(source)) { + /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. + if (isForWithIdenticalConfiguration(target, source)) { fusedLoop = fuseIndependentSiblingForLoops( cast(target), cast(source), rewriter); - } else if (isa(target) && isa(source)) { + } else if (isForallWithIdenticalConfiguration(target, source)) { fusedLoop = fuseIndependentSiblingForallLoops( cast(target), cast(source), rewriter); - } else if (isa(target) && isa(source)) { - fusedLoop = fuseIndependentSiblingParallelLoops( - cast(target), cast(source), rewriter); } else return emitSilenceableFailure(target->getLoc()) - << "unsupported loop type for fusion"; + << "operations cannot be fused"; assert(fusedLoop && "failed to fuse operations"); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index b775f988576e3..5934d85373b03 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" @@ -38,6 +37,24 @@ static bool hasNestedParallelOp(ParallelOp ploop) { return walkResult.wasInterrupted(); } +/// Verify equal iteration spaces. +static bool equalIterationSpaces(ParallelOp firstPloop, + ParallelOp secondPloop) { + if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) + return false; + + auto matchOperands = [&](const OperandRange &lhs, + const OperandRange &rhs) -> bool { + // TODO: Extend this to support aliases and equal constants. + return std::equal(lhs.begin(), lhs.end(), rhs.begin()); + }; + return matchOperands(firstPloop.getLowerBound(), + secondPloop.getLowerBound()) && + matchOperands(firstPloop.getUpperBound(), + secondPloop.getUpperBound()) && + matchOperands(firstPloop.getStep(), secondPloop.getStep()); +} + /// Checks if the parallel loops have mixed access to the same buffers. Returns /// `true` if the first parallel loop writes to the same indices that the second /// loop reads. @@ -136,10 +153,9 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop, static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref mayAlias) { - Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark); return !hasNestedParallelOp(firstPloop) && !hasNestedParallelOp(secondPloop) && - checkFusionStructuralLegality(firstPloop, secondPloop, diag) && + equalIterationSpaces(firstPloop, secondPloop) && succeeded(verifyDependencies(firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias)); } @@ -158,9 +174,61 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, mayAlias)) return; - IRRewriter rewriter(builder); - secondPloop = mlir::fuseIndependentSiblingParallelLoops( - firstPloop, secondPloop, rewriter); + DominanceInfo dom; + // We are fusing first loop into second, make sure there are no users of the + // first loop results between loops. + for (Operation *user : firstPloop->getUsers()) + if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false)) + return; + + ValueRange inits1 = firstPloop.getInitVals(); + ValueRange inits2 = secondPloop.getInitVals(); + + SmallVector newInitVars(inits1.begin(), inits1.end()); + newInitVars.append(inits2.begin(), inits2.end()); + + IRRewriter b(builder); + b.setInsertionPoint(secondPloop); + auto newSecondPloop = b.create( + secondPloop.getLoc(), secondPloop.getLowerBound(), + secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); + + Block *newBlock = newSecondPloop.getBody(); + auto term1 = cast(block1->getTerminator()); + auto term2 = cast(block2->getTerminator()); + + b.inlineBlockBefore(block2, newBlock, newBlock->begin(), + newBlock->getArguments()); + b.inlineBlockBefore(block1, newBlock, newBlock->begin(), + newBlock->getArguments()); + + ValueRange results = newSecondPloop.getResults(); + if (!results.empty()) { + b.setInsertionPointToEnd(newBlock); + + ValueRange reduceArgs1 = term1.getOperands(); + ValueRange reduceArgs2 = term2.getOperands(); + SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); + newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); + + auto newReduceOp = b.create(term2.getLoc(), newReduceArgs); + + for (auto &&[i, reg] : llvm::enumerate(llvm::concat( + term1.getReductions(), term2.getReductions()))) { + Block &oldRedBlock = reg.front(); + Block &newRedBlock = newReduceOp.getReductions()[i].front(); + b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(), + newRedBlock.getArguments()); + } + + firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); + secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); + } + term1->erase(); + term2->erase(); + firstPloop.erase(); + secondPloop.erase(); + secondPloop = newSecondPloop; } void mlir::scf::naivelyFuseParallelOps( diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index abfc9a1b4d444..c0ee9d2afe91c 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -1263,131 +1262,54 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, return tileLoops; } -//===----------------------------------------------------------------------===// -// Fusion related helpers -//===----------------------------------------------------------------------===// - -/// Check if `target` and `source` are siblings, in the context that `target` -/// is being fused into `source`. -/// -/// This is a simple check that just checks if both operations are in the same -/// block and some checks to ensure that the fused IR does not violate -/// dominance. -static bool isOpSibling(Operation *target, Operation *source, - Diagnostic &diag) { - // Check if both operations are same. - if (target == source) { - diag << "target and source need to be different loops"; - return false; - } - - // Check if both operations are in the same block. - if (target->getBlock() != source->getBlock()) { - diag << "target and source are not in the same block"; - return false; - } - - // Check if fusion will violate dominance. - DominanceInfo domInfo(source); - if (target->isBeforeInBlock(source)) { - // Since `target` is before `source`, all users of results of `target` - // need to be dominated by `source`. - for (Operation *user : target->getUsers()) { - if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { - diag << "user of results of target should " - "be properly dominated by source"; - return false; - } - } - } else { - // Since `target` is after `source`, all values used by `target` need - // to dominate `source`. - - // Check if operands of `target` are dominated by `source`. - for (Value operand : target->getOperands()) { - Operation *operandOp = operand.getDefiningOp(); - // Operands without defining operations are block arguments. When `target` - // and `source` occur in the same block, these operands dominate `source`. - if (!operandOp) - continue; - - // Operand's defining operation should properly dominate `source`. - if (!domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) { - diag << "operands of target should be properly dominated by source"; - return false; - } - } - - // Check if values used by `target` are dominated by `source`. - bool failed = false; - OpOperand *failedValue = nullptr; - visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { - Operation *operandOp = operand->get().getDefiningOp(); - if (operandOp && !domInfo.properlyDominates(operandOp, source, - /*enclosingOpOk=*/false)) { - // `operand` is not an argument of an enclosing block and the defining - // op of `operand` is outside `target` but does not dominate `source`. - failed = true; - failedValue = operand; - } - }); - - if (failed) { - diag << "values used inside regions of target should be properly " - "dominated by source"; - diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation"; - return false; - } - } - - return true; -} - -bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target, - LoopLikeOpInterface source, - Diagnostic &diag) { - if (target->getName() != source->getName()) { - diag << "target and source must be same loop type"; - return false; - } - - bool iterSpaceEq = - target.getLoopLowerBounds() == source.getLoopLowerBounds() && - target.getLoopUpperBounds() == source.getLoopUpperBounds() && - target.getLoopSteps() == source.getLoopSteps(); - // TODO: Decouple checks on concrete loop types and move this function - // somewhere for general utility for `LoopLikeOpInterface` - if (auto forAllTarget = dyn_cast(*target)) - iterSpaceEq = iterSpaceEq && forAllTarget.getMapping() == - cast(*source).getMapping(); - if (!iterSpaceEq) { - diag << "target and source iteration spaces must be equal"; - return false; - } - return isOpSibling(target, source, diag); -} - scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter) { - scf::ForallOp fusedLoop = cast(createFused( - target, source, rewriter, - [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { - // `ForallOp` does not have yields, rather an `InParallelOp` terminator. - return ValueRange{}; - }, - [&](RewriterBase &b, LoopLikeOpInterface source, - LoopLikeOpInterface &target, IRMapping mapping) { - auto sourceForall = cast(source); - auto targetForall = cast(target); - scf::InParallelOp fusedTerm = targetForall.getTerminator(); - b.setInsertionPointToEnd(fusedTerm.getBody()); - for (Operation &op : sourceForall.getTerminator().getYieldingOps()) - b.clone(op, mapping); - })); - rewriter.replaceOp(source, - fusedLoop.getResults().take_back(source.getNumResults())); + unsigned numTargetOuts = target.getNumResults(); + unsigned numSourceOuts = source.getNumResults(); + + // Create fused shared_outs. + SmallVector fusedOuts; + llvm::append_range(fusedOuts, target.getOutputs()); + llvm::append_range(fusedOuts, source.getOutputs()); + + // Create a new scf.forall op after the source loop. + rewriter.setInsertionPointAfter(source); + scf::ForallOp fusedLoop = rewriter.create( + source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), + source.getMixedStep(), fusedOuts, source.getMapping()); + + // Map control operands. + IRMapping mapping; + mapping.map(target.getInductionVars(), fusedLoop.getInductionVars()); + mapping.map(source.getInductionVars(), fusedLoop.getInductionVars()); + + // Map shared outs. + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); + + // Append everything except the terminator into the fused operation. + rewriter.setInsertionPointToStart(fusedLoop.getBody()); + for (Operation &op : target.getBody()->without_terminator()) + rewriter.clone(op, mapping); + for (Operation &op : source.getBody()->without_terminator()) + rewriter.clone(op, mapping); + + // Fuse the old terminator in_parallel ops into the new one. + scf::InParallelOp targetTerm = target.getTerminator(); + scf::InParallelOp sourceTerm = source.getTerminator(); + scf::InParallelOp fusedTerm = fusedLoop.getTerminator(); + rewriter.setInsertionPointToStart(fusedTerm.getBody()); + for (Operation &op : targetTerm.getYieldingOps()) + rewriter.clone(op, mapping); + for (Operation &op : sourceTerm.getYieldingOps()) + rewriter.clone(op, mapping); + + // Replace old loops by substituting their uses by results of the fused loop. + rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); return fusedLoop; } @@ -1395,74 +1317,49 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { - scf::ForOp fusedLoop = cast(createFused( - target, source, rewriter, - [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { - return source.getYieldedValues(); - }, - [&](RewriterBase &b, LoopLikeOpInterface source, - LoopLikeOpInterface &target, IRMapping mapping) { - auto targetFor = cast(target); - auto newTerm = b.clone(*targetFor.getBody()->getTerminator(), mapping); - b.replaceOp(targetFor.getBody()->getTerminator(), newTerm); - })); - rewriter.replaceOp(source, - fusedLoop.getResults().take_back(source.getNumResults())); - return fusedLoop; -} - -// TODO: Finish refactoring this a la the above, but likely requires additional -// interface methods. -scf::ParallelOp mlir::fuseIndependentSiblingParallelLoops( - scf::ParallelOp target, scf::ParallelOp source, RewriterBase &rewriter) { - OpBuilder::InsertionGuard guard(rewriter); - Block *block1 = target.getBody(); - Block *block2 = source.getBody(); - auto term1 = cast(block1->getTerminator()); - auto term2 = cast(block2->getTerminator()); - - ValueRange inits1 = target.getInitVals(); - ValueRange inits2 = source.getInitVals(); - - SmallVector newInitVars(inits1.begin(), inits1.end()); - newInitVars.append(inits2.begin(), inits2.end()); - - rewriter.setInsertionPoint(source); - auto fusedLoop = rewriter.create( - rewriter.getFusedLoc(target.getLoc(), source.getLoc()), - source.getLowerBound(), source.getUpperBound(), source.getStep(), - newInitVars); - Block *newBlock = fusedLoop.getBody(); - rewriter.inlineBlockBefore(block2, newBlock, newBlock->begin(), - newBlock->getArguments()); - rewriter.inlineBlockBefore(block1, newBlock, newBlock->begin(), - newBlock->getArguments()); - - ValueRange results = fusedLoop.getResults(); - if (!results.empty()) { - rewriter.setInsertionPointToEnd(newBlock); - - ValueRange reduceArgs1 = term1.getOperands(); - ValueRange reduceArgs2 = term2.getOperands(); - SmallVector newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); - newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - - auto newReduceOp = rewriter.create( - rewriter.getFusedLoc(term1.getLoc(), term2.getLoc()), newReduceArgs); - - for (auto &&[i, reg] : llvm::enumerate(llvm::concat( - term1.getReductions(), term2.getReductions()))) { - Block &oldRedBlock = reg.front(); - Block &newRedBlock = newReduceOp.getReductions()[i].front(); - rewriter.inlineBlockBefore(&oldRedBlock, &newRedBlock, - newRedBlock.begin(), - newRedBlock.getArguments()); - } - } - rewriter.replaceOp(target, results.take_front(inits1.size())); - rewriter.replaceOp(source, results.take_back(inits2.size())); - rewriter.eraseOp(term1); - rewriter.eraseOp(term2); + unsigned numTargetOuts = target.getNumResults(); + unsigned numSourceOuts = source.getNumResults(); + + // Create fused init_args, with target's init_args before source's init_args. + SmallVector fusedInitArgs; + llvm::append_range(fusedInitArgs, target.getInitArgs()); + llvm::append_range(fusedInitArgs, source.getInitArgs()); + + // Create a new scf.for op after the source loop (with scf.yield terminator + // (without arguments) only in case its init_args is empty). + rewriter.setInsertionPointAfter(source); + scf::ForOp fusedLoop = rewriter.create( + source.getLoc(), source.getLowerBound(), source.getUpperBound(), + source.getStep(), fusedInitArgs); + + // Map original induction variables and operands to those of the fused loop. + IRMapping mapping; + mapping.map(target.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(target.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_front(numTargetOuts)); + mapping.map(source.getInductionVar(), fusedLoop.getInductionVar()); + mapping.map(source.getRegionIterArgs(), + fusedLoop.getRegionIterArgs().take_back(numSourceOuts)); + + // Merge target's body into the new (fused) for loop and then source's body. + rewriter.setInsertionPointToStart(fusedLoop.getBody()); + for (Operation &op : target.getBody()->without_terminator()) + rewriter.clone(op, mapping); + for (Operation &op : source.getBody()->without_terminator()) + rewriter.clone(op, mapping); + + // Build fused yield results by appropriately mapping original yield operands. + SmallVector yieldResults; + for (Value operand : target.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + for (Value operand : source.getBody()->getTerminator()->getOperands()) + yieldResults.push_back(mapping.lookupOrDefault(operand)); + if (!yieldResults.empty()) + rewriter.create(source.getLoc(), yieldResults); + + // Replace old loops by substituting their uses by results of the fused loop. + rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); + rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts)); return fusedLoop; } diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp index 5a119a7cf2659..1e0e87b64e811 100644 --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -8,8 +8,6 @@ #include "mlir/Interfaces/LoopLikeInterface.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/ADT/DenseSet.h" @@ -115,60 +113,3 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) { return success(); } - -LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target, - LoopLikeOpInterface source, - RewriterBase &rewriter, - NewYieldValuesFn newYieldValuesFn, - FuseTerminatorFn fuseTerminatorFn) { - auto targetIterArgs = target.getRegionIterArgs(); - std::optional> targetInductionVar = - target.getLoopInductionVars(); - SmallVector targetYieldOperands(target.getYieldedValues()); - auto sourceIterArgs = source.getRegionIterArgs(); - std::optional> sourceInductionVar = - *source.getLoopInductionVars(); - SmallVector sourceYieldOperands(source.getYieldedValues()); - auto sourceRegion = source.getLoopRegions().front(); - - FailureOr maybeFusedLoop = - target.replaceWithAdditionalYields(rewriter, source.getInits(), - /*replaceInitOperandUsesInLoop=*/false, - newYieldValuesFn); - if (failed(maybeFusedLoop)) - llvm_unreachable("failed to replace loop"); - LoopLikeOpInterface fusedLoop = *maybeFusedLoop; - // Since the target op is rewritten at the original's location, we move it to - // the soure op's location. - rewriter.moveOpBefore(fusedLoop, source); - - // Map control operands. - IRMapping mapping; - std::optional> fusedInductionVar = - fusedLoop.getLoopInductionVars(); - if (fusedInductionVar) { - if (!targetInductionVar || !sourceInductionVar) - llvm_unreachable( - "expected target and source loops to have induction vars"); - mapping.map(*targetInductionVar, *fusedInductionVar); - mapping.map(*sourceInductionVar, *fusedInductionVar); - } - mapping.map(targetIterArgs, - fusedLoop.getRegionIterArgs().take_front(targetIterArgs.size())); - mapping.map(targetYieldOperands, - fusedLoop.getYieldedValues().take_front(targetIterArgs.size())); - mapping.map(sourceIterArgs, - fusedLoop.getRegionIterArgs().take_back(sourceIterArgs.size())); - mapping.map(sourceYieldOperands, - fusedLoop.getYieldedValues().take_back(sourceIterArgs.size())); - // Append everything except the terminator into the fused operation. - rewriter.setInsertionPoint( - fusedLoop.getLoopRegions().front()->front().getTerminator()); - for (Operation &op : sourceRegion->front().without_terminator()) - rewriter.clone(op, mapping); - - // TODO: Replace with corresponding interface method if added - fuseTerminatorFn(rewriter, source, fusedLoop, mapping); - - return fusedLoop; -} diff --git a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir index f8246b74a5744..54dd2bdf953ca 100644 --- a/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir +++ b/mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir @@ -47,169 +47,6 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @fuse_two_parallel -// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { -func.func @fuse_two_parallel(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { -// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index -// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index -// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index -// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. - %c2 = arith.constant 2 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c1fp = arith.constant 1.0 : f32 -// CHECK: [[SUM:%.*]] = memref.alloc() - %sum = memref.alloc() : memref<2x2xf32> -// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) -// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { -// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] -// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] -// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] -// CHECK-NOT: scf.parallel -// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] -// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] -// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] -// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] -// CHECK: scf.reduce -// CHECK: } - scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> - %sum_elem = arith.addf %B_elem, %c1fp : f32 - memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> - scf.reduce - } - scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> - %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> - %product_elem = arith.mulf %sum_elem, %A_elem : f32 - memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> - scf.reduce - } -// CHECK: memref.dealloc [[SUM]] - memref.dealloc %sum : memref<2x2xf32> - return -} -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: func @fuse_two_parallel_reverse -// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}) { -func.func @fuse_two_parallel_reverse(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { -// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index -// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index -// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index -// CHECK-DAG: [[C1FP:%.*]] = arith.constant 1. - %c2 = arith.constant 2 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c1fp = arith.constant 1.0 : f32 -// CHECK: [[SUM:%.*]] = memref.alloc() - %sum = memref.alloc() : memref<2x2xf32> -// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) -// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { -// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]] -// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] -// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] -// CHECK: memref.store [[PRODUCT_ELEM]], [[B]]{{\[}}[[I]], [[J]]] -// CHECK-NOT: scf.parallel -// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]] -// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C1FP]] -// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]] -// CHECK: scf.reduce -// CHECK: } - scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> - %sum_elem = arith.addf %B_elem, %c1fp : f32 - memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> - scf.reduce - } - scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> - %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> - %product_elem = arith.mulf %sum_elem, %A_elem : f32 - memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> - scf.reduce - } -// CHECK: memref.dealloc [[SUM]] - memref.dealloc %sum : memref<2x2xf32> - return -} -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused = transform.loop.fuse_sibling %parallel#1 into %parallel#0 : (!transform.any_op,!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: func @fuse_reductions_two -// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32) -func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32 -// CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) -// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32) -// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]] -// CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]] -// CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) { -// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): -// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 -// CHECK: scf.reduce.return %[[R]] : f32 -// CHECK: } -// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): -// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32 -// CHECK: scf.reduce.return %[[R]] : f32 -// CHECK: } -// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32 - %c2 = arith.constant 2 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %init1 = arith.constant 1.0 : f32 - %init2 = arith.constant 2.0 : f32 - %res1 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init1) -> f32 { - %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> - scf.reduce(%A_elem : f32) { - ^bb0(%lhs: f32, %rhs: f32): - %1 = arith.addf %lhs, %rhs : f32 - scf.reduce.return %1 : f32 - } - } - %res2 = scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init(%init2) -> f32 { - %B_elem = memref.load %B[%i, %j] : memref<2x2xf32> - scf.reduce(%B_elem : f32) { - ^bb0(%lhs: f32, %rhs: f32): - %1 = arith.mulf %lhs, %rhs : f32 - scf.reduce.return %1 : f32 - } - } - return %res1, %res2 : f32, f32 -} -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - // CHECK: func.func @fuse_2nd_for_into_1st([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}} func.func @fuse_2nd_for_into_1st(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index @@ -371,62 +208,6 @@ module attributes {transform.with_named_sequence} { } } - -// ----- - -// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 32) -#map = affine_map<(d0) -> (d0 * 32)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> -module { - // CHECK: func.func @loop_sibling_fusion(%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}} - func.func @loop_sibling_fusion(%arg0: tensor<128xf32>, %arg1: tensor<128x128xf16>, %arg2: tensor<128x64xf32>, %arg3: tensor<128x128xf32>) -> (tensor<128xf32>, tensor<128x128xf16>) { - // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<128x128xf16> - // CHECK-NEXT: %[[RESULTS:.*]]:2 = scf.forall (%[[I:.*]]) in (4) shared_outs(%[[S1:.*]] = %[[ARG0]], %[[S2:.*]] = %[[ARG1]]) -> (tensor<128xf32>, tensor<128x128xf16>) { - // CHECK-NEXT: %[[IDX:.*]] = affine.apply #[[$MAP]](%[[I]]) - // CHECK-NEXT: %[[SLICE0:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32> - // CHECK-NEXT: %[[SLICE1:.*]] = tensor.extract_slice %[[ARG3]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32> - // CHECK-NEXT: %[[SLICE2:.*]] = tensor.extract_slice %[[EMPTY]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16> - // CHECK-NEXT: %[[GENERIC:.*]] = linalg.generic {{.*}} ins(%[[SLICE1]] : {{.*}}) outs(%[[SLICE2]] : {{.*}}) - // CHECK: scf.forall.in_parallel { - // CHECK-NEXT: tensor.parallel_insert_slice %[[SLICE0]] into %[[S1]][%[[IDX]]] [32] [1] : tensor<32xf32> into tensor<128xf32> - // CHECK-NEXT: tensor.parallel_insert_slice %[[GENERIC]] into %[[S2]][%[[IDX]], 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16> - // CHECK-NEXT: } - // CHECK-NEXT: } {mapping = [#gpu.warp]} - // CHECK-NEXT: return %[[RESULTS]]#0, %[[RESULTS]]#1 - %0 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg0) -> (tensor<128xf32>) { - %3 = affine.apply #map(%arg4) - %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 1] [1, 1] : tensor<128x128xf32> to tensor<32xf32> - scf.forall.in_parallel { - tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [32] [1] : tensor<32xf32> into tensor<128xf32> - } - } {mapping = [#gpu.warp]} - %1 = tensor.empty() : tensor<128x128xf16> - %2 = scf.forall (%arg4) in (4) shared_outs(%arg5 = %arg1) -> (tensor<128x128xf16>) { - %3 = affine.apply #map(%arg4) - %extracted_slice = tensor.extract_slice %arg3[%3, 0] [32, 128] [1, 1] : tensor<128x128xf32> to tensor<32x128xf32> - %extracted_slice_0 = tensor.extract_slice %1[%3, 0] [32, 128] [1, 1] : tensor<128x128xf16> to tensor<32x128xf16> - %4 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<32x128xf32>) outs(%extracted_slice_0 : tensor<32x128xf16>) { - ^bb0(%in: f32, %out: f16): - %5 = arith.truncf %in : f32 to f16 - linalg.yield %5 : f16 - } -> tensor<32x128xf16> - scf.forall.in_parallel { - tensor.parallel_insert_slice %4 into %arg5[%3, 0] [32, 128] [1, 1] : tensor<32x128xf16> into tensor<128x128xf16> - } - } {mapping = [#gpu.warp]} - return %0, %2 : tensor<128xf32>, tensor<128x128xf16> - } -} - -module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%root: !transform.any_op) { - %loops = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op - %loop1, %loop2 = transform.split_handle %loops : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %loop3 = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op - transform.yield - } -} - // ----- func.func @source_for_uses_result_of_target_for_err(%A: tensor<128xf32>, %B: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) { @@ -501,9 +282,8 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>, %6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32> scf.yield %6 : tensor<128xf32> } - // expected-error @below {{values used inside regions of target should be properly dominated by source}} %dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) { - // expected-note @below {{see operation}} + // expected-error @below {{values used inside regions of target should be properly dominated by source}} %dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> %dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32> %dup5 = arith.addf %dup3, %dup2 : vector<16xf32> @@ -548,74 +328,6 @@ module attributes {transform.with_named_sequence} { transform.yield } } - -// ----- - -func.func @non_matching_iteration_spaces_err(%A: memref<2x2xf32>, %B: memref<2x2xf32>) { - %c2 = arith.constant 2 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c1fp = arith.constant 1.0 : f32 - %sum = memref.alloc() : memref<2x2xf32> - // expected-error @below {{target and source iteration spaces must be equal}} - scf.parallel (%i) = (%c0) to (%c2) step (%c1) { - %B_elem = memref.load %B[%i, %c0] : memref<2x2xf32> - %sum_elem = arith.addf %B_elem, %c1fp : f32 - memref.store %sum_elem, %sum[%i, %c0] : memref<2x2xf32> - scf.reduce - } - scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> - %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> - %product_elem = arith.mulf %sum_elem, %A_elem : f32 - memref.store %product_elem, %B[%i, %j] : memref<2x2xf32> - scf.reduce - } - memref.dealloc %sum : memref<2x2xf32> - return -} -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %parallel:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %fused = transform.loop.fuse_sibling %parallel#0 into %parallel#1 : (!transform.any_op,!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) { - %c2 = arith.constant 2 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c1fp = arith.constant 1.0 : f32 - %sum = memref.alloc() : memref<2xf32> - // expected-error @below {{target and source must be same loop type}} - scf.for %i = %c0 to %c2 step %c1 { - %B_elem = memref.load %B[%i] : memref<2xf32> - %sum_elem = arith.addf %B_elem, %c1fp : f32 - memref.store %sum_elem, %sum[%i] : memref<2xf32> - } - scf.parallel (%i) = (%c0) to (%c2) step (%c1) { - %sum_elem = memref.load %sum[%i] : memref<2xf32> - %A_elem = memref.load %A[%i] : memref<2xf32> - %product_elem = arith.mulf %sum_elem, %A_elem : f32 - memref.store %product_elem, %B[%i] : memref<2xf32> - scf.reduce - } - memref.dealloc %sum : memref<2xf32> - return -} -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op - %fused = transform.loop.fuse_sibling %0 into %1 : (!transform.any_op,!transform.any_op) -> !transform.any_op - transform.yield - } -} - // ----- // CHECK: func.func @foreach_loop_pair_fuse([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}