-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][LLVM] Improve lowering of llvm.byval
function arguments
#100028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-gpu Author: Diego Caballero (dcaballe) ChangesTODO: Still thinking how to test this. If no better suggestion, I'll create a test pass with a custom LLVM converter... When a function argument is annotated with the
Unfortunately, this makes the type conversion context-dependent, which is something that the type conversion infrastructure (i.e., To fix this problem, this PR changes the FuncToLLVM conversion logic to always generate an Full diff: https://github.com/llvm/llvm-project/pull/100028.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..e448409e24b2a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -16,6 +16,7 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -40,6 +41,16 @@ class LLVMTypeConverter : public TypeConverter {
public:
using TypeConverter::convertType;
+ /// Encodes the passing mode for function arguments annotated with
+ /// `llvm.byval` and `llvm.byref` attributes:
+ /// * BYVAL: The argument has an `llvm.byval` attribute and, therefore,
+ /// it's passed by value.
+ /// * BYREF: The argument has an `llvm.byref` attribute and, therefore,
+ /// it's passed by reference.
+ /// * UNKNOWN: The argument doesn't have either `llvm.byval` or
+ /// `llvm.byref` attributes so its passing mode is unknown.
+ enum class ArgumentPassingMode { BYVAL, BYREF, UNKNOWN };
+
/// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
/// Optionally takes a data layout analysis to use in conversions.
LLVMTypeConverter(MLIRContext *ctx,
@@ -50,11 +61,17 @@ class LLVMTypeConverter : public TypeConverter {
LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
const DataLayoutAnalysis *analysis = nullptr);
+ /// Returns the passing mode for function arguments annotated with
+ /// `llvm.byval` and `llvm.byref` attributes.
+ SmallVector<ArgumentPassingMode>
+ getArgumentsPassingMode(FunctionOpInterface funcOp) const;
+
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
- Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
- bool useBarePtrCallConv,
+ Type convertFunctionSignature(FunctionType funcTy,
+ ArrayRef<ArgumentPassingMode> argsPassingMode,
+ bool isVariadic, bool useBarePtrCallConv,
SignatureConversion &result) const;
/// Convert a non-empty list of types to be returned from a function into an
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 059acb217709c..6272c22f45407 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -20,6 +20,7 @@
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -279,10 +280,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
+ auto argsPassingMode = converter.getArgumentsPassingMode(funcOp);
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = converter.convertFunctionSignature(
- funcTy, varargsAttr && varargsAttr.getValue(),
+ funcTy, argsPassingMode, varargsAttr && varargsAttr.getValue(),
shouldUseBarePtrCallConv(funcOp, &converter), result);
if (!llvmType)
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6053e34f30a41..3ccd661fc09fe 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -53,9 +54,9 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
-
+ auto argsPassingMode = getTypeConverter()->getArgumentsPassingMode(gpuFuncOp);
Type funcType = getTypeConverter()->convertFunctionSignature(
- gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
+ gpuFuncOp.getFunctionType(), argsPassingMode, /*isVariadic=*/false,
getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
if (!funcType) {
return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 32d02d5e438bd..5484fce738d04 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -269,23 +269,55 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
return LLVM::LLVMPointerType::get(type.getContext());
}
+SmallVector<LLVMTypeConverter::ArgumentPassingMode>
+LLVMTypeConverter::getArgumentsPassingMode(FunctionOpInterface funcOp) const {
+ SmallVector<ArgumentPassingMode> argsPassingMode(
+ funcOp.getNumArguments(), ArgumentPassingMode::UNKNOWN);
+
+ for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+ for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+ if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName()) {
+ argsPassingMode[argIdx] = ArgumentPassingMode::BYVAL;
+ break;
+ }
+ if (namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
+ argsPassingMode[argIdx] = ArgumentPassingMode::BYREF;
+ break;
+ }
+ }
+ }
+
+ return argsPassingMode;
+}
+
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionSignature(
- FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+ FunctionType funcTy, ArrayRef<ArgumentPassingMode> argsPassingMode,
+ bool isVariadic, bool useBarePtrCallConv,
LLVMTypeConverter::SignatureConversion &result) const {
// Select the argument converter depending on the calling convention.
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
: structFuncArgTypeConverter;
// Convert argument types one by one and check for errors.
+ MLIRContext *ctx = funcTy.getContext();
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
- SmallVector<Type, 8> converted;
- if (failed(funcArgConverter(*this, type, converted)))
+ SmallVector<Type, 8> convertedTypes;
+ if (failed(funcArgConverter(*this, type, convertedTypes)))
return {};
- result.addInputs(idx, converted);
+ // Type converter can't differenciate between converting an argument type or
+ // any other type. If a converted argument has the `llvm.byval` attribute,
+ // we replace the type with an LLVM pointer so that the `llvm.byval`
+ // convention is correct.
+ ArgumentPassingMode argPassMode = argsPassingMode[idx];
+ for (Type &convertedTy : convertedTypes) {
+ if (argPassMode == ArgumentPassingMode::BYVAL)
+ convertedTy = LLVM::LLVMPointerType::get(ctx);
+ }
+ result.addInputs(idx, convertedTypes);
}
// If function does not return anything, create the void result type,
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index da09384bfbe89..106228c3a619f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1410,9 +1410,10 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
auto funcType = funcOp.getFunctionType();
TypeConverter::SignatureConversion signatureConverter(
funcType.getNumInputs());
+ auto argsPassingMode = typeConverter.getArgumentsPassingMode(funcOp);
auto llvmType = typeConverter.convertFunctionSignature(
- funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
- signatureConverter);
+ funcType, argsPassingMode, /*isVariadic=*/false,
+ /*useBarePtrCallConv=*/false, signatureConverter);
if (!llvmType)
return failure();
|
Will this not lead to inconsistencies in the types? |
Good point! I see how that is automatically handled for my case... but that's may not be the cases in general. |
Isn't there a problem with the way |
Good point! There are a few aspects related to this:
I’m planning to follow up on this and automatically introduce the |
Taking another angle, the property to keep is that the type converter is always 1:1, that is for a given type before conversion you always get the same converted type. |
It sounds reasonable. I was hoping to avoid creating two function signatures but let me give it a try... |
I'm not sure what you're referring to? To be clear: in my mind there should be no change to the TypeConverter or the dialect conversion framework in general. |
I meant creating two converted function types if we introduced the We can also do it before that point and create the converted function with the final type. However, it's not ideal either. We would have to update the argument types in the We would also have to map the attributes from the original function to the converted function, since the attributes haven’t been propagated to the converted function yet. Finally, we would have to replicate all this functionality in the SPIR-V and GPU lowerings to LLVM... This is mostly why I encapsulated all this complexity behind the LLVM type converter (I tried a few other things before reaching this point). It makes the type conversion dependent on the argument attributes (which is accurate for the specific case of the function signature, after all) but the LLVM type converter would return a valid converted function type in the first place. No strong opinion, though. It’s a matter of trade-offs. Happy to follow the feedback provided but I want to make sure we are all on the same page. |
From an initial read of the PR, I agree with Mehdi here. This PR seems to be creating an implicit transformation for a given LLVM attribute, which seems a bit off (I wouldn't necessarily expect the converter to automatically pointer-ize the argument type). It feels like this kind of behavior should be left directly to the llvm func conversion, with additional extensions added to the llvm converter (if it's to continue being used as an extension point, otherwise some other extension point) for opting-into/controlling this behavior. I'm just a bit wary of having double meanings for attributes that are contextual (given it's not really searchable/easily understandable, you'll only see the LLVM docs which are quite clear on expected usage). |
I saw |
Note that the current changes are to Ok, if no other comments, let me try to move the logic to the llvm function conversion so that we can compare both approaches. |
When a function argument is annotated with the `llvm.byval` attribute, [LLVM expects] (https://llvm.org/docs/LangRef.html#parameter-attributes) the function argument type to be an `llvm.ptr`. For example: ``` func.func (%args0 : llvm.ptr {llvm.byval = !llvm.struct<(i32)>} { ... } ``` Unfortunately, this makes the type conversion context-dependent, which is something that the type conversion infrastructure (i.e., `LLVMTypeConverter` in this particular case) doesn't support. For example, we may want to convert `MyType` to `llvm.struct<(i32)>` in general, but to an `llvm.ptr` type only when it's a function argument passed by value. To fix this problem, this PR changes the FuncToLLVM conversion logic to always generate an `llvm.ptr` when the function argument has a `llvm.byval` attribute. An `llvm.load` is inserted into the function to retrieve the value expected by the argument users.
This reverts commit eb7cec3.
The second approach is ready for review! Hopefully this makes more sense. |
Just for reference:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The general approach looks good, it would be nice to have the logic a bit cleaner though. Avoiding the secondary pass over the function arguments doing type conversion would probably clean this up a lot. Suppose that may be somewhere in the middle of the two approaches, but I leave things up to you.
auto llvmType = | ||
cast_or_null<LLVM::LLVMFunctionType>(converter.convertFunctionSignature( | ||
funcTy, varargsAttr && varargsAttr.getValue(), | ||
shouldUseBarePtrCallConv(funcOp, &converter), result)); | ||
if (!llvmType) | ||
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed"); | ||
|
||
// Replace the type of `llvm.byval` and `llvm.byref` arguments that were not | ||
// converted to an LLVM pointer type. | ||
SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs; | ||
filterByValRefNonPtrAttrs(funcOp, result, byValRefNonPtrAttrs); | ||
llvmType = converter.materializePtrForByValByRefFuncArgs( | ||
llvmType, byValRefNonPtrAttrs, result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels like the filtering and signature conversion could be in a new convertFunctionSignature
overload that takes a FuncOp (would encapsulate this a bit more, cleanup the materializePtrForByValByRefFuncArgs logic, and avoid the double type construction), but no super strong feeling.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that was the motivation for option num 1. I implemented the override and remove a bunch of new methods that are no longer needed. PTAL
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice
When a function argument is annotated with the
llvm.byval
attribute, LLVM expects the function argument type to be anllvm.ptr
. For example:Unfortunately, this makes the type conversion context-dependent, which is something that the type conversion infrastructure (i.e.,
LLVMTypeConverter
in this particular case) doesn't support. For example, we may want to convertMyType
tollvm.struct<(i32)>
in general, but to anllvm.ptr
type only when it's a function argument passed by value.To fix this problem, this PR changes the FuncToLLVM conversion logic to generate an
llvm.ptr
when the function argument has allvm.byval
attribute. Anllvm.load
is inserted into the function to retrieve the value expected by the argument users.