Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

memref.reshape causes segfault in JIT backend with opt_level=3 #200

Open
zzzDavid opened this issue Sep 1, 2023 · 3 comments
Open

memref.reshape causes segfault in JIT backend with opt_level=3 #200

zzzDavid opened this issue Sep 1, 2023 · 3 comments
Assignees
Labels
MLIR Limitation MLIR limitations

Comments

@zzzDavid
Copy link
Collaborator

zzzDavid commented Sep 1, 2023

Description

This thread documents an issue we encountered with memref.reshape. The generated IR is correct: it can be compiled with clang, and it executes correctly when the MLIR ExecutionEngine optimization level is set to 0, 1, or 2. However, setting the ExecutionEngine optimization level to 3 triggers a segfault.

Specifically, this step causes the segfault:

execution_engine = ExecutionEngine(
            lowered, opt_level=3, shared_libs=shared_libs)

Current solution

This is likely an issue with the MLIR JIT compiler. We work around it by setting the optimization level lower than 3.

Sample IR to reproduce this issue

module {
  memref.global "private" constant @const_0 : memref<3xi64> = dense<[5, 2, 4]>
  memref.global "private" constant @const_1 : memref<2xi64> = dense<[5, 8]>
  // Reproducer kernel. Transposes %arg0 (5x3x2 -> 5x2x3), multiplies it by the
  // transposed %arg1 via a batched matmul, adds %arg2 broadcast to 5x2x4, and
  // returns the result reshaped to 5x8 through memref.reshape.
  func.func @kernel(%arg0: memref<5x3x2xf32>, %arg1: memref<4x3xf32>, %arg2: memref<4xf32>) -> memref<5x8xf32> attributes {itypes = "___", otypes = "_"} {
    // The results %1 and %3 below are never used by any later op.
    %c1_i32 = arith.constant 1 : i32
    %0 = arith.sitofp %c1_i32 : i32 to f32
    %1 = arith.negf %0 : f32
    %c2_i32 = arith.constant 2 : i32
    %2 = arith.sitofp %c2_i32 : i32 to f32
    %3 = arith.negf %2 : f32
    // output1: zero-fill, then transpose %arg0 with the last two dims swapped.
    %alloc = memref.alloc() {name = "output1"} : memref<5x2x3xf32>
    %c0_i32 = arith.constant 0 : i32
    %4 = arith.sitofp %c0_i32 : i32 to f32
    linalg.fill {op_name = "transpose_init_zero_0"} ins(%4 : f32) outs(%alloc : memref<5x2x3xf32>)
    linalg.transpose ins(%arg0 : memref<5x3x2xf32>) outs(%alloc : memref<5x2x3xf32>) permutation = [0, 2, 1]  {op_name = "transpose_1"}
    // %alloc_0 is zero-filled but not referenced again in this function.
    %alloc_0 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_1 = arith.constant 0 : i32
    %5 = arith.sitofp %c0_i32_1 : i32 to f32
    linalg.fill {op_name = "linear_init_zero_2"} ins(%5 : f32) outs(%alloc_0 : memref<5x2x4xf32>)
    // %alloc_2: transpose of the 4x3 input %arg1 into 3x4.
    %alloc_2 = memref.alloc() : memref<3x4xf32>
    %c0_i32_3 = arith.constant 0 : i32
    %6 = arith.sitofp %c0_i32_3 : i32 to f32
    linalg.fill {op_name = "transpose_init_zero_3"} ins(%6 : f32) outs(%alloc_2 : memref<3x4xf32>)
    linalg.transpose ins(%arg1 : memref<4x3xf32>) outs(%alloc_2 : memref<3x4xf32>) permutation = [1, 0]  {op_name = "transpose_4"}
    // %alloc_4 is zero-filled but not referenced again in this function.
    %alloc_4 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_5 = arith.constant 0 : i32
    %7 = arith.sitofp %c0_i32_5 : i32 to f32
    linalg.fill {op_name = "matmul_init_zero_5"} ins(%7 : f32) outs(%alloc_4 : memref<5x2x4xf32>)
    // Broadcast the 3x4 transposed weight along a new leading dim of size 5 so
    // it can be used as the right-hand operand of the batched matmul.
    %alloc_6 = memref.alloc() : memref<5x3x4xf32>
    linalg.broadcast ins(%alloc_2 : memref<3x4xf32>) outs(%alloc_6 : memref<5x3x4xf32>) dimensions = [0] 
    // %alloc_7 = batch_matmul(%alloc, %alloc_6), accumulated into a zeroed buffer.
    %alloc_7 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_8 = arith.constant 0 : i32
    %8 = arith.sitofp %c0_i32_8 : i32 to f32
    linalg.fill {op_name = "bmm_init_zero_6"} ins(%8 : f32) outs(%alloc_7 : memref<5x2x4xf32>)
    linalg.batch_matmul {op_name = "bmm_7"} ins(%alloc, %alloc_6 : memref<5x2x3xf32>, memref<5x3x4xf32>) outs(%alloc_7 : memref<5x2x4xf32>)
    // %alloc_9 is zero-filled but not referenced again in this function.
    %alloc_9 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_10 = arith.constant 0 : i32
    %9 = arith.sitofp %c0_i32_10 : i32 to f32
    linalg.fill {op_name = "view_init_zero_8"} ins(%9 : f32) outs(%alloc_9 : memref<5x2x4xf32>)
    // First memref.reshape: the shape operand @const_0 is [5, 2, 4], so the
    // result type equals the source type (5x2x4 -> 5x2x4).
    %10 = memref.get_global @const_0 : memref<3xi64>
    %reshape = memref.reshape %alloc_7(%10) : (memref<5x2x4xf32>, memref<3xi64>) -> memref<5x2x4xf32>
    // Broadcast the length-4 input %arg2 across the two leading dims.
    %alloc_11 = memref.alloc() : memref<5x2x4xf32>
    linalg.broadcast ins(%arg2 : memref<4xf32>) outs(%alloc_11 : memref<5x2x4xf32>) dimensions = [0, 1] 
    // output2 = %reshape + broadcast %arg2.
    %alloc_12 = memref.alloc() {name = "output2"} : memref<5x2x4xf32>
    %c0_i32_13 = arith.constant 0 : i32
    %11 = arith.sitofp %c0_i32_13 : i32 to f32
    linalg.fill {op_name = "add_init_zero_9"} ins(%11 : f32) outs(%alloc_12 : memref<5x2x4xf32>)
    linalg.add {op_name = "add_10"} ins(%reshape, %alloc_11 : memref<5x2x4xf32>, memref<5x2x4xf32>) outs(%alloc_12 : memref<5x2x4xf32>)
    // %alloc_14 is zero-filled but not referenced again; the returned value is
    // the reshape result below, not this buffer.
    %alloc_14 = memref.alloc() : memref<5x8xf32>
    %c0_i32_15 = arith.constant 0 : i32
    %12 = arith.sitofp %c0_i32_15 : i32 to f32
    linalg.fill {op_name = "view_init_zero_11"} ins(%12 : f32) outs(%alloc_14 : memref<5x8xf32>)
    // Second memref.reshape: shape operand @const_1 is [5, 8] (5x2x4 -> 5x8).
    %13 = memref.get_global @const_1 : memref<2xi64>
    %reshape_16 = memref.reshape %alloc_12(%13) {name = "output"} : (memref<5x2x4xf32>, memref<2xi64>) -> memref<5x8xf32>
    return %reshape_16 : memref<5x8xf32>
  }


  // Minimal driver for @kernel: allocates the three input buffers (nothing
  // here writes to them before the call), invokes @kernel, and discards the
  // returned memref.
  func.func @main() {
	%arg0 = memref.alloc() : memref<5x3x2xf32>
	%arg1 = memref.alloc() : memref<4x3xf32>
	%arg2 = memref.alloc() : memref<4xf32>
	// The result %arg3 is not used after the call.
	%arg3 = func.call @kernel(%arg0, %arg1, %arg2) : (memref<5x3x2xf32>, memref<4x3xf32>, memref<4xf32>) -> memref<5x8xf32>
	return
  }

}

Stack trace

 #3 0x00007f2186a5d950 llvm::isPotentiallyReachable(llvm::Instruction const*, llvm::Instruction const*, llvm::SmallPtrSetImpl<llvm::BasicBlock*> const*, llvm::DominatorTree const*, llvm::LoopInfo const*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a21950)
 #4 0x00007f2186a2d619 llvm::EarliestEscapeInfo::isNotCapturedBeforeOrAt(llvm::Value const*, llvm::Instruction const*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59f1619)
 #5 0x00007f2186a276bb llvm::BasicAAResult::getModRefInfo(llvm::CallBase const*, llvm::MemoryLocation const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59eb6bb)
 #6 0x00007f2186a0756b llvm::AAResults::getModRefInfo(llvm::CallBase const*, llvm::MemoryLocation const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59cb56b)
 #7 0x00007f2186a08e51 llvm::AAResults::getModRefInfo(llvm::Instruction const*, std::optional<llvm::MemoryLocation> const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59cce51)
 #8 0x00007f218638ebdf (anonymous namespace)::DSEState::isReadClobber(llvm::MemoryLocation const&, llvm::Instruction*) DeadStoreElimination.cpp:0:0
 #9 0x00007f2186398eae (anonymous namespace)::DSEState::getDomMemoryDef(llvm::MemoryDef*, llvm::MemoryAccess*, llvm::MemoryLocation const&, llvm::Value const*, unsigned int&, unsigned int&, bool, unsigned int&) DeadStoreElimination.cpp:0:0
#10 0x00007f218639af52 (anonymous namespace)::eliminateDeadStores(llvm::Function&, llvm::AAResults&, llvm::MemorySSA&, llvm::DominatorTree&, llvm::PostDominatorTree&, llvm::AssumptionCache&, llvm::TargetLibraryInfo const&, llvm::LoopInfo const&) DeadStoreElimination.cpp:0:0
#11 0x00007f218639d038 llvm::DSEPass::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5361038)
#12 0x00007f218554e35e llvm::detail::PassModel<llvm::Function, llvm::DSEPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x451235e)
#13 0x00007f218737d514 llvm::PassManager<llvm::Function, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x6341514)
#14 0x00007f21855472ce llvm::detail::PassModel<llvm::Function, llvm::PassManager<llvm::Function, llvm::AnalysisManager<llvm::Function>>, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x450b2ce)
#15 0x00007f2186a7739f llvm::CGSCCToFunctionPassAdaptor::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a3b39f)
#16 0x00007f218554d40e llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::CGSCCToFunctionPassAdaptor, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x451140e)
#17 0x00007f2186a6ff5b llvm::PassManager<llvm::LazyCallGraph::SCC, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a33f5b)
#18 0x00007f218554d3ce llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::PassManager<llvm::LazyCallGraph::SCC, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x45113ce)
#19 0x00007f2186a73a55 llvm::DevirtSCCRepeatedPass::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a37a55)
#20 0x00007f218554d3ee llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::DevirtSCCRepeatedPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x45113ee)
#21 0x00007f2186a71bf9 llvm::ModuleToPostOrderCGSCCPassAdaptor::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a35bf9)
#22 0x00007f21857956ef llvm::ModuleInlinerWrapperPass::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x47596ef)
#23 0x00007f218554cf8e llvm::detail::PassModel<llvm::Module, llvm::ModuleInlinerWrapperPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x4510f8e)
#24 0x00007f2185542f50 mlir::makeOptimizingTransformer(unsigned int, unsigned int, llvm::TargetMachine*)::'lambda'(llvm::Module*)::operator()(llvm::Module*) const OptUtils.cpp:0:0
#25 0x00007f2185543cad std::_Function_handler<llvm::Error (llvm::Module*), mlir::makeOptimizingTransformer(unsigned int, unsigned int, llvm::TargetMachine*)::'lambda'(llvm::Module*)>::_M_invoke(std::_Any_data const&, llvm::Module*&&) OptUtils.cpp:0:0
#26 0x00007f218264550d llvm::Error llvm::function_ref<llvm::Error (llvm::Module*)>::callback_fn<std::function<llvm::Error (llvm::Module*)>>(long, llvm::Module*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x160950d)
#27 0x00007f218306e1e6 mlir::ExecutionEngine::create(mlir::Operation*, mlir::ExecutionEngineOptions const&, std::unique_ptr<llvm::TargetMachine, std::default_delete<llvm::TargetMachine>>) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x20321e6)
#28 0x00007f2182646ab5 mlirExecutionEngineCreate (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x160aab5)
#29 0x00007f21803de5e5 pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)::operator()(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool) const /work/shared/users/common/llvm-project-18.x/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp:82:77
#30 0x00007f21803df91a void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)::operator()(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool) const /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/detail/init.h:242:29
#31 0x00007f21803e37e3 pybind11::class_<(anonymous namespace)::PyExecutionEngine> pybind11::detail::argument_loader<pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool>::call_impl<void, void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)&, 0ul, 1ul, 2ul, 3ul, 4ul, pybind11::detail::void_type>(void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, 
bool)&, std::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul, 4ul>, pybind11::detail::void_type&&) && /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1205:91
#32 0x00007f21803e3439 _ZNO8pybind116detail15argument_loaderIJRNS0_16value_and_holderE10MlirModuleiRKSt6vectorISsSaISsEEbEE4callIvNS0_9void_typeERZNOS0_8initimpl7factoryIZL34pybind11_init__mlirExecutionEngineRNS_7module_EEUlS4_iS9_bE_PFSC_vEFPN12_GLOBAL__N_117PyExecutionEngineES4_iS9_bESI_E7executeINS_6class_ISL_JEEEJNS_3argENS_5arg_vEST_ST_A327_cEEEvRT_DpRKT0_EUlS3_S4_iS9_bE_EENSt9enable_ifIXsrSt7is_voidISV_E5valueESC_E4typeEOT1_ /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1183:26
@zzzDavid zzzDavid added the MLIR Limitation MLIR limitations label Sep 1, 2023
@zzzDavid zzzDavid self-assigned this Sep 1, 2023
@zzzDavid
Copy link
Collaborator Author

zzzDavid commented Sep 1, 2023

To compile MLIR IR with gcc/clang, lower it to the LLVM dialect, translate it to LLVM IR, and then assemble and link it:

# Step 1: run the lowering passes on example.mlir and write the result to
# example.llvm.mlir.
mlir-opt example.mlir \
	--convert-linalg-to-affine-loops \
	--one-shot-bufferize \
	--lower-affine \
	--convert-scf-to-cf \
	--convert-cf-to-llvm \
	--convert-func-to-llvm \
	--convert-arith-to-llvm \
	--finalize-memref-to-llvm \
	--reconcile-unrealized-casts \
	-o example.llvm.mlir


# Step 2: translate the lowered MLIR into LLVM IR (example.ll).
mlir-translate example.llvm.mlir \
	--mlir-to-llvmir \
	-o example.ll

# Step 3: compile LLVM IR to assembly, assemble it into an object file, and
# link the object file into an executable.
llc example.ll -o example.s
as example.s -o example.o
gcc example.o -o example.exe

@zzzDavid
Copy link
Collaborator Author

zzzDavid commented Sep 1, 2023

The Allo program associated with this sample IR:

import allo
from allo.ir.types import int32, float32
import numpy as np

def test_library_higher_dimension_ops(enable_tensor):
    """Check an Allo-built transpose -> linear -> view kernel against Python.

    Random float32 inputs are passed both to the module built by Allo and to
    the plain Python ``kernel`` function, and the two results are compared
    with a relative tolerance of 1e-5.

    Args:
        enable_tensor: Forwarded unchanged to ``allo.customize``.
    """
    # Problem sizes; A is (M, K, L), B is (N, K), C is (N,).
    M = 5
    N = 4
    K = 3
    L = 2
    A = np.random.uniform(size=(M, K, L)).astype(np.float32)
    B = np.random.uniform(size=(N, K)).astype(np.float32)
    C = np.random.uniform(size=(N,)).astype(np.float32)

    def kernel(
        A: float32[M, K, L], B: float32[N, K], C: float32[N]
    ) -> float32[M, L * N]:
        # Transpose A over its last two axes.
        output1 = allo.transpose(A, (-1, -2))
        # Linear layer on output1 using B and C
        # (presumably weight and bias -- confirm against allo.linear).
        output2 = allo.linear(output1, B, C)
        # Reshape to (5, 8), i.e. (M, L * N) for the sizes above.
        output = allo.view(output2, (5, 8))
        return output

    s = allo.customize(kernel, enable_tensor=enable_tensor)
    mod = s.build()
    # Result from the built module.
    outp = mod(A, B, C)
    # Reference result from calling the undecorated Python function directly.
    np_outp = kernel(A, B, C)
    np.testing.assert_allclose(outp, np_outp, rtol=1e-5)
    
if __name__ == "__main__":
	# Run the reproducer directly, passing False for enable_tensor.
	test_library_higher_dimension_ops(False)

@Zhichenzzz
Copy link

Zhichenzzz commented Sep 1, 2023

Thank you! This works for me. However, when opt_level is set to 0 or 1, the following type test case fails:

def test_compare_int_float():
        Ty = Int(5)
    
        def kernel(A: Ty) -> Ty:
            B: Ty = 0
            if A > B or A + 1 < 0.0:
                B = A
            return B
    
        s = allo.customize(kernel)
        mod = s.build()
        assert mod(2) == kernel(2)
>       assert mod(-3) == kernel(-3)
E       assert 29 == -3
E        +  where 29 = <allo.backend.llvm.LLVMModule object at 0x7f8d21573c70>(-3)
E        +  and   -3 = <function test_compare_int_float.<locals>.kernel at 0x7f8d21d588b0>(-3)

tests/test_types.py:165: AssertionError

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
MLIR Limitation MLIR limitations
Projects
None yet
Development

No branches or pull requests

2 participants