Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mlir/include/mlir/Conversion/MPIToLLVM/MPIToLLVM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MPITOLLVM_H
#define MLIR_CONVERSION_MPITOLLVM_H

#include "mlir/IR/DialectRegistry.h"

namespace mlir {

class LLVMTypeConverter;
class RewritePatternSet;

#define GEN_PASS_DECL_MPITOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"

namespace mpi {
void populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);

void registerConvertMPIToLLVMInterface(DialectRegistry &registry);

} // namespace mpi
} // namespace mlir

#endif // MLIR_CONVERSION_MPITOLLVM_H
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/MPI/IR/MPITypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class MPI_Type<string name, string typeMnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//

def MPI_Retval : MPI_Type<"Retval", "retval"> {
let summary = "MPI function call return value";
let summary = "MPI function call return value (!mpi.retval)";
let description = [{
This type represents a return value from an MPI function call.
This value can be MPI_SUCCESS, MPI_ERR_IN_STATUS, or any error code.
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef MLIR_INITALLEXTENSIONS_H_
#define MLIR_INITALLEXTENSIONS_H_

#include "Conversion/MPIToLLVM/MPIToLLVM.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include "Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
Expand Down Expand Up @@ -62,6 +63,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertFuncToLLVMInterface(registry);
index::registerConvertIndexToLLVMInterface(registry);
registerConvertMathToLLVMInterface(registry);
mpi::registerConvertMPIToLLVMInterface(registry);
registerConvertMemRefToLLVMInterface(registry);
registerConvertNVVMToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/MPIToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_mlir_conversion_library(MLIRMPIToLLVM
MPIToLLVM.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MPIToLLVM

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRMPIDialect
)
230 changes: 230 additions & 0 deletions mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Pass/Pass.h"

#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include <mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h>
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"


using namespace mlir;

namespace {

struct InitOpLowering : ConvertOpToLLVMPattern<mpi::InitOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct CommRankOpLowering : ConvertOpToLLVMPattern<mpi::CommRankOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct FinalizeOpLowering : ConvertOpToLLVMPattern<mpi::FinalizeOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, Location loc,

Comment on lines +45 to +48
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,
} // namespace
// TODO: this was copied from GPUOpsLowering.cpp:288
// is this okay, or should this be moved to some common file?
static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, const Location loc,

LLVM convention is to mark functions as static rather than having them in anonymous namespaces. Ditto the other function

ConversionPatternRewriter &rewriter,
StringRef name,
LLVM::LLVMFunctionType type) {
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
Comment on lines +52 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LLVM::LLVMFuncOp ret;
if (!(ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))) {
auto ret = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
if (ret)
return ret;
...

ulta nit: Looks nicer as an early return in my opinion, Ditto below

ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
LLVM::Linkage::External);
}
return ret;
}

// TODO: this is pretty close to getOrDefineFunction, can probably be factored
LLVM::GlobalOp getOrDefineExternalStruct(ModuleOp &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter,
StringRef name,
LLVM::LLVMStructType type) {
LLVM::GlobalOp ret;
if (!(ret = moduleOp.lookupSymbol<LLVM::GlobalOp>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
ret = rewriter.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/false, LLVM::Linkage::External, name,
/*value=*/Attribute(), /*alignment=*/0, 0);
}
return ret;
}

} // namespace

//===----------------------------------------------------------------------===//
// InitOpLowering
//===----------------------------------------------------------------------===//

LogicalResult
InitOpLowering::matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// get loc
auto loc = op.getLoc();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto loc = op.getLoc();
Location loc = op.getLoc();

Ditto in other places where the type of the variable does not appear in the right-hand side expression


// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

// instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
Value llvmnull = nullPtrOp.getRes();

// grab a reference to the global module op:
auto moduleOp = op->getParentOfType<ModuleOp>();

// LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
auto initFuncType =
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);

// replace init with function call
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
ValueRange{llvmnull, llvmnull});

return success();
}

//===----------------------------------------------------------------------===//
// FinalizeOpLowering
//===----------------------------------------------------------------------===//

LogicalResult
FinalizeOpLowering::matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// get loc
auto loc = op.getLoc();

// grab a reference to the global module op:
auto moduleOp = op->getParentOfType<ModuleOp>();

// LLVM Function type representing `i32 MPI_Finalize()`
auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(moduleOp, loc, rewriter,
"MPI_Finalize", initFuncType);

// replace init with function call
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});

return success();
}

//===----------------------------------------------------------------------===//
// CommRankLowering
//===----------------------------------------------------------------------===//

LogicalResult
CommRankOpLowering::matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// get some helper vars
auto loc = op.getLoc();
auto context = rewriter.getContext();
auto i32 = rewriter.getI32Type();

// ptrType `!llvm.ptr`
Type ptrType = LLVM::LLVMPointerType::get(context);

// get external opaque struct pointer type
auto commStructT = LLVM::LLVMStructType::getOpaque("MPI_ABI_Comm", context);

// grab a reference to the global module op:
auto moduleOp = op->getParentOfType<ModuleOp>();

// make sure global op definition exists
getOrDefineExternalStruct(moduleOp, loc, rewriter, "MPI_COMM_WORLD",
commStructT);

// get address of @MPI_COMM_WORLD
auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
auto commWorld = rewriter.create<LLVM::AddressOfOp>(
loc, ptrType, SymbolRefAttr::get(context, "MPI_COMM_WORLD"));

// LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
auto rankFuncType = LLVM::LLVMFunctionType::get(i32, {ptrType, ptrType});
// get or create function declaration:
LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);

// replace init with function call
auto callOp = rewriter.create<LLVM::CallOp>(
loc, initDecl, ValueRange{commWorld.getRes(), rankptr.getRes()});

// load the rank into a register
auto loadedRank =
rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());

// if retval is checked, replace uses of retval with the results from the call
// op
SmallVector<Value> replacements;
if (op.getRetval()) {
replacements.push_back(callOp.getResult());
}
Comment on lines +187 to +189
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (op.getRetval()) {
replacements.push_back(callOp.getResult());
}
if (op.getRetval())
replacements.push_back(callOp.getResult());

// replace all uses, then erase op
replacements.push_back(loadedRank.getRes());
rewriter.replaceOp(op, replacements);

return success();
}

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<InitOpLowering>(converter);
patterns.add<CommRankOpLowering>(converter);
patterns.add<FinalizeOpLowering>(converter);
}

//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//

namespace {
/// Implement the interface to convert Func to LLVM.
struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;

/// Hook for derived dialect interface to provide conversion patterns
/// and mark dialect legal for the conversion target.
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
mpi::populateMPIToLLVMConversionPatterns(typeConverter, patterns);
}
};
} // namespace

void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
dialect->addInterfaces<FuncToLLVMDialectInterface>();
});
}
40 changes: 40 additions & 0 deletions mlir/test/Conversion/MPIToLLVM/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: mlir-opt -convert-to-llvm %s | FileCheck %s

module {
// CHECK: llvm.func @MPI_Finalize() -> i32
// CHECK: llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32
// CHECK: llvm.mlir.global external @MPI_COMM_WORLD() {addr_space = 0 : i32} : !llvm.struct<"MPI_ABI_Comm", opaque>
// CHECK: llvm.func @MPI_Init(!llvm.ptr, !llvm.ptr) -> i32

func.func @mpi_test(%arg0: memref<100xf32>) {
%0 = mpi.init : !mpi.retval
// CHECK: %7 = llvm.mlir.zero : !llvm.ptr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it has always been convention that the check lines for a particular ops lowering are before the op being lowered.

// CHECK-NEXT: %8 = llvm.call @MPI_Init(%7, %7) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %8 : i32 to !mpi.retval


%retval, %rank = mpi.comm_rank : !mpi.retval, i32
// CHECK: %10 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %11 = llvm.alloca %10 x i32 : (i32) -> !llvm.ptr
// CHECK-NEXT: %12 = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr
// CHECK-NEXT: %13 = llvm.call @MPI_Comm_rank(%12, %11) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK-NEXT: %14 = llvm.load %11 : !llvm.ptr -> i32
// CHECK-NEXT: %15 = builtin.unrealized_conversion_cast %13 : i32 to !mpi.retval
Comment on lines +11 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests should be using filecheck variables for the SSA values. It is also convention that the // of the comments have the same indentation as the body.

You can also try (but don't see this as blocking) to slim down the tests a bit to not test the syntax of the LLVM operations so much, but rather just the lowering.
Just

%[[NULL_PTR:.*]] = llvm.mlir.zero
%[[INIT_RAW:.*]] = llvm.call @MPI_INIT(%[[NULL_PTR, %[[NULL_PTR]])
%[[INIT:.*]] = builtin.unrealized_conversion_cast %[[INIT_RAW]] : i32 to !mpi.retval

is fine (maybe splitting some of the type signatures with // CHECK-SAME: on the next lines if you feel they are very relevant.


mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32

%1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval

mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32

%2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval

%3 = mpi.finalize : !mpi.retval
// CHECK: %18 = llvm.call @MPI_Finalize() : () -> i32

%4 = mpi.retval_check %retval = <MPI_SUCCESS> : i1

%5 = mpi.error_class %0 : !mpi.retval
return
}
}