From 1e3b9425fe96f735e5fce380a2c92314ad75897e Mon Sep 17 00:00:00 2001 From: Anton Lydike Date: Fri, 14 Jun 2024 11:26:26 +0100 Subject: [PATCH] add initial set of lowerings for MPI dialect --- .../mlir/Conversion/MPIToLLVM/MPIToLLVM.h | 30 +++ mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 2 +- mlir/include/mlir/InitAllExtensions.h | 2 + mlir/lib/Conversion/CMakeLists.txt | 1 + mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt | 17 ++ mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp | 230 ++++++++++++++++++ mlir/test/Conversion/MPIToLLVM/ops.mlir | 40 +++ 7 files changed, 321 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h create mode 100644 mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt create mode 100644 mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp create mode 100644 mlir/test/Conversion/MPIToLLVM/ops.mlir diff --git a/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h new file mode 100644 index 0000000000000..940e5e8097318 --- /dev/null +++ b/mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h @@ -0,0 +1,30 @@ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_MPITOLLVM_H +#define MLIR_CONVERSION_MPITOLLVM_H + +#include "mlir/IR/DialectRegistry.h" + +namespace mlir { + +class LLVMTypeConverter; +class RewritePatternSet; + +#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS +#include "mlir/Conversion/Passes.h.inc" + +namespace mpi { +void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +void registerConvertMPIToLLVMInterface(DialectRegistry ®istry); + +} // namespace mpi +} // namespace mlir + +#endif // MLIR_CONVERSION_MPITOLLVM_H diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index 87eefa719d45c..57ac512642829 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -30,7 +30,7 @@ class MPI_Type traits = []> //===----------------------------------------------------------------------===// def MPI_Retval : MPI_Type<"Retval", "retval"> { - let summary = "MPI function call return value"; + let summary = "MPI function call return value (!mpi.retval)"; let description = [{ This type represents a return value from an MPI function call. This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code. diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 20a4ab6f18a28..3d7a45b97785f 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -14,6 +14,7 @@ #ifndef MLIR_INITALLEXTENSIONS_H_ #define MLIR_INITALLEXTENSIONS_H_ +#include "Conversion/MPIToLLVM/MPIToLLVM.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -62,6 +63,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { registerConvertFuncToLLVMInterface(registry); index::registerConvertIndexToLLVMInterface(registry); registerConvertMathToLLVMInterface(registry); + mpi::registerConvertMPIToLLVMInterface(registry); registerConvertMemRefToLLVMInterface(registry); registerConvertNVVMToLLVMInterface(registry); ub::registerConvertUBToLLVMInterface(registry); diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 0a03a2e133db1..46e3768801560 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -39,6 +39,7 @@ add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) +add_subdirectory(MPIToLLVM) add_subdirectory(NVGPUToNVVM) add_subdirectory(NVVMToLLVM) add_subdirectory(OpenACCToSCF) diff --git a/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000..f81fb25e56840 --- /dev/null +++ b/mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRMPIToLLVM + MPIToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRMPIDialect + ) diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp new file mode 100644 index 0000000000000..d87a10aab8f49 --- /dev/null +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -0,0 +1,230 @@ +//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Pass/Pass.h" + +#include + +using namespace mlir; + +namespace { + +struct InitOpLowering : ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct CommRankOpLowering : ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +struct FinalizeOpLowering : ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +// TODO: this was copied from GPUOpsLowering.cpp:288 +// is this okay, or should this be moved to some common file? +LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc, + ConversionPatternRewriter &rewriter, + StringRef name, + LLVM::LLVMFunctionType type) { + LLVM::LLVMFuncOp ret; + if (!(ret = moduleOp.lookupSymbol(name))) { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + ret = rewriter.create(loc, name, type, + LLVM::Linkage::External); + } + return ret; +} + +// TODO: this is pretty close to getOrDefineFunction, can probably be factored +LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc, + ConversionPatternRewriter &rewriter, + StringRef name, + LLVM::LLVMStructType type) { + LLVM::GlobalOp ret; + if (!(ret = moduleOp.lookupSymbol(name))) { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + ret = rewriter.create( + loc, type, /*isConstant=*/false, LLVM::Linkage::External, name, + /*value=*/Attribute(), /*alignment=*/0, 0); + } + return ret; +} + +} // namespace + +//===----------------------------------------------------------------------===// +// InitOpLowering +//===----------------------------------------------------------------------===// + +LogicalResult +InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // get loc + auto loc = op.getLoc(); + + // ptrType `!llvm.ptr` + Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + + // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr` + auto nullPtrOp = rewriter.create(loc, ptrType); + Value llvmnull = nullPtrOp.getRes(); + + // grab a reference to the global module op: + auto moduleOp = op->getParentOfType(); + + // LLVM Function type representing `i32 MPI_Init(ptr, ptr)` + auto initFuncType = + LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType}); + // get or create function declaration: + LLVM::LLVMFuncOp initDecl = + getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType); + + // replace init with function call + rewriter.replaceOpWithNewOp(op, initDecl, + ValueRange{llvmnull, llvmnull}); + + return success(); +} + +//===----------------------------------------------------------------------===// +// FinalizeOpLowering +//===----------------------------------------------------------------------===// + +LogicalResult +FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // get loc + auto loc = op.getLoc(); + + // grab a reference to the global module op: + auto moduleOp = op->getParentOfType(); + + // LLVM Function type representing `i32 MPI_Finalize()` + auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {}); + // get or create function declaration: + LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter, + "MPI_Finalize", initFuncType); + + // replace init with function call + rewriter.replaceOpWithNewOp(op, initDecl, ValueRange{}); + + return success(); +} + +//===----------------------------------------------------------------------===// +// CommRankLowering +//===----------------------------------------------------------------------===// + +LogicalResult +CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // get some helper vars + auto loc = op.getLoc(); + auto context = rewriter.getContext(); + auto i32 = rewriter.getI32Type(); + + // ptrType `!llvm.ptr` + Type ptrType = LLVM::LLVMPointerType::get(context); + + // get external opaque struct pointer type + auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context); + + // grab a reference to the global module op: + auto moduleOp = op->getParentOfType(); + + // make sure global op definition exists + getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD", + commStructT); + + // get address of @MPI_COMM_WORLD + auto one = rewriter.create(loc, i32, 1); + auto rankptr = rewriter.create(loc, ptrType, i32, one); + auto commWorld = rewriter.create( + loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD")); + + // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)` + auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType}); + // get or create function declaration: + LLVM::LLVMFuncOp initDecl = getOrDefineFunction( + moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType); + + // replace init with function call + auto callOp = rewriter.create( + loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()}); + + // load the rank into a register + auto loadedRank = + rewriter.create(loc, i32, rankptr.getResult()); + + // if retval is checked, replace uses of retval with the results from the call + // op + SmallVector replacements; + if (op.getRetval()) { + replacements.push_back(callOp.getResult()); + } + // replace all uses, then erase op + replacements.push_back(loadedRank.getRes()); + rewriter.replaceOp(op, replacements); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add(converter); + patterns.add(converter); + patterns.add(converter); +} + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Func to LLVM. +struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) { + dialect->addInterfaces(); + }); +} diff --git a/mlir/test/Conversion/MPIToLLVM/ops.mlir b/mlir/test/Conversion/MPIToLLVM/ops.mlir new file mode 100644 index 0000000000000..a7a44ad24909a --- /dev/null +++ b/mlir/test/Conversion/MPIToLLVM/ops.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s + +module { +// CHECK: llvm.func @MPI_Finalize() -> i32 +// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32 +// CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque> +// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32 + + func.func @mpi_test(%arg0: memref<100xf32>) { + %0 = mpi.init : !mpi.retval +// CHECK: %7 = llvm.mlir.zero : !llvm.ptr +// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32 +// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval + + + %retval, %rank = mpi.comm_rank : !mpi.retval, i32 +// CHECK: %10 = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr +// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr +// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32 +// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32 +// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval + + mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 + + %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + + mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 + + %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + + %3 = mpi.finalize : !mpi.retval +// CHECK: %18 = llvm.call @MPI_Finalize() : () -> i32 + + %4 = mpi.retval_check %retval = : i1 + + %5 = mpi.error_class %0 : !mpi.retval + return + } +}