Skip to content

Conversation

dcaballe
Copy link
Contributor

@dcaballe dcaballe commented Jul 23, 2024

When a function argument is annotated with the llvm.byval attribute, LLVM expects the function argument type to be an llvm.ptr. For example:

func.func (%args0 : llvm.ptr {llvm.byval = !llvm.struct<(i32)>} {
  ...
}

Unfortunately, this makes the type conversion context-dependent, which is something that the type conversion infrastructure (i.e., LLVMTypeConverter in this particular case) doesn't support. For example, we may want to convert MyType to llvm.struct<(i32)> in general, but to an llvm.ptr type only when it's a function argument passed by value.

To fix this problem, this PR changes the FuncToLLVM conversion logic to generate an llvm.ptr when the function argument has a llvm.byval attribute. An llvm.load is inserted into the function to retrieve the value expected by the argument users.

@llvmbot
Copy link
Member

llvmbot commented Jul 23, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Diego Caballero (dcaballe)

Changes

TODO: Still thinking how to test this. If no better suggestion, I'll create a test pass with a custom LLVM converter...

When a function argument is annotated with the llvm.byval attribute, [LLVM expects] (https://llvm.org/docs/LangRef.html#parameter-attributes) the function argument type to be an llvm.ptr. For example:

func.func (%args0 : llvm.ptr {llvm.byval = !llvm.struct&lt;(i32)&gt;} {
  ...
}

Unfortunately, this makes the type conversion context-dependent, which is something that the type conversion infrastructure (i.e., LLVMTypeConverter in this particular case) doesn't support. For example, we may want to convert MyType to llvm.struct&lt;(i32)&gt; in general, but to an llvm.ptr type only when it's a function argument passed by value.

To fix this problem, this PR changes the FuncToLLVM conversion logic to always generate an llvm.ptr when the type resulting from the function argument has a llvm.byval attribute.


Full diff: https://github.com/llvm/llvm-project/pull/100028.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (+19-2)
  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+3-1)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+3-2)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+36-4)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+3-2)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..e448409e24b2a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -40,6 +41,16 @@ class LLVMTypeConverter : public TypeConverter {
 public:
   using TypeConverter::convertType;
 
+  /// Encodes the passing mode for function arguments annotated with
+  /// `llvm.byval` and `llvm.byref` attributes:
+  ///   * BYVAL: The argument has an `llvm.byval` attribute and, therefore,
+  ///            it's passed by value.
+  ///   * BYREF: The argument has an `llvm.byref` attribute and, therefore,
+  ///            it's passed by reference.
+  ///   * UNKNOWN: The argument doesn't have either `llvm.byval` or
+  ///              `llvm.byref` attributes so its passing mode is unknown.
+  enum class ArgumentPassingMode { BYVAL, BYREF, UNKNOWN };
+
   /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
   /// Optionally takes a data layout analysis to use in conversions.
   LLVMTypeConverter(MLIRContext *ctx,
@@ -50,11 +61,17 @@ class LLVMTypeConverter : public TypeConverter {
   LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
                     const DataLayoutAnalysis *analysis = nullptr);
 
+  /// Returns the passing mode for function arguments annotated with
+  /// `llvm.byval` and `llvm.byref` attributes.
+  SmallVector<ArgumentPassingMode>
+  getArgumentsPassingMode(FunctionOpInterface funcOp) const;
+
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
-  Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
-                                bool useBarePtrCallConv,
+  Type convertFunctionSignature(FunctionType funcTy,
+                                ArrayRef<ArgumentPassingMode> argsPassingMode,
+                                bool isVariadic, bool useBarePtrCallConv,
                                 SignatureConversion &result) const;
 
   /// Convert a non-empty list of types to be returned from a function into an
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 059acb217709c..6272c22f45407 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -279,10 +280,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
 
   // Convert the original function arguments. They are converted using the
   // LLVMTypeConverter provided to this legalization pattern.
+  auto argsPassingMode = converter.getArgumentsPassingMode(funcOp);
   auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
   TypeConverter::SignatureConversion result(funcOp.getNumArguments());
   auto llvmType = converter.convertFunctionSignature(
-      funcTy, varargsAttr && varargsAttr.getValue(),
+      funcTy, argsPassingMode, varargsAttr && varargsAttr.getValue(),
       shouldUseBarePtrCallConv(funcOp, &converter), result);
   if (!llvmType)
     return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6053e34f30a41..3ccd661fc09fe 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -53,9 +54,9 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   // Remap proper input types.
   TypeConverter::SignatureConversion signatureConversion(
       gpuFuncOp.front().getNumArguments());
-
+  auto argsPassingMode = getTypeConverter()->getArgumentsPassingMode(gpuFuncOp);
   Type funcType = getTypeConverter()->convertFunctionSignature(
-      gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
+      gpuFuncOp.getFunctionType(), argsPassingMode, /*isVariadic=*/false,
       getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
   if (!funcType) {
     return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 32d02d5e438bd..5484fce738d04 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -269,23 +269,55 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
   return LLVM::LLVMPointerType::get(type.getContext());
 }
 
+SmallVector<LLVMTypeConverter::ArgumentPassingMode>
+LLVMTypeConverter::getArgumentsPassingMode(FunctionOpInterface funcOp) const {
+  SmallVector<ArgumentPassingMode> argsPassingMode(
+      funcOp.getNumArguments(), ArgumentPassingMode::UNKNOWN);
+
+  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+      if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName()) {
+        argsPassingMode[argIdx] = ArgumentPassingMode::BYVAL;
+        break;
+      }
+      if (namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
+        argsPassingMode[argIdx] = ArgumentPassingMode::BYREF;
+        break;
+      }
+    }
+  }
+
+  return argsPassingMode;
+}
+
 // Function types are converted to LLVM Function types by recursively converting
 // argument and result types.  If MLIR Function has zero results, the LLVM
 // Function has one VoidType result.  If MLIR Function has more than one result,
 // they are into an LLVM StructType in their order of appearance.
 Type LLVMTypeConverter::convertFunctionSignature(
-    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+    FunctionType funcTy, ArrayRef<ArgumentPassingMode> argsPassingMode,
+    bool isVariadic, bool useBarePtrCallConv,
     LLVMTypeConverter::SignatureConversion &result) const {
   // Select the argument converter depending on the calling convention.
   useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
   auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
                                              : structFuncArgTypeConverter;
   // Convert argument types one by one and check for errors.
+  MLIRContext *ctx = funcTy.getContext();
   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
-    SmallVector<Type, 8> converted;
-    if (failed(funcArgConverter(*this, type, converted)))
+    SmallVector<Type, 8> convertedTypes;
+    if (failed(funcArgConverter(*this, type, convertedTypes)))
       return {};
-    result.addInputs(idx, converted);
+    // Type converter can't differenciate between converting an argument type or
+    // any other type. If a converted argument has the `llvm.byval` attribute,
+    // we replace the type with an LLVM pointer so that the `llvm.byval`
+    // convention is correct.
+    ArgumentPassingMode argPassMode = argsPassingMode[idx];
+    for (Type &convertedTy : convertedTypes) {
+      if (argPassMode == ArgumentPassingMode::BYVAL)
+        convertedTy = LLVM::LLVMPointerType::get(ctx);
+    }
+    result.addInputs(idx, convertedTypes);
   }
 
   // If function does not return anything, create the void result type,
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index da09384bfbe89..106228c3a619f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1410,9 +1410,10 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
     auto funcType = funcOp.getFunctionType();
     TypeConverter::SignatureConversion signatureConverter(
         funcType.getNumInputs());
+    auto argsPassingMode = typeConverter.getArgumentsPassingMode(funcOp);
     auto llvmType = typeConverter.convertFunctionSignature(
-        funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
-        signatureConverter);
+        funcType, argsPassingMode, /*isVariadic=*/false,
+        /*useBarePtrCallConv=*/false, signatureConverter);
     if (!llvmType)
       return failure();
 

@Dinistro
Copy link
Contributor

Will this not lead to inconsistencies in the types?
Given your MyType example, this could result in IR where the argument is now an llvm.ptr, while all its usages expect an llvm.struct<(i32)>. Is this supposed to be fixed by some additional meterializations, or how would one deal with such cases?

@dcaballe
Copy link
Contributor Author

Good point! I see how that is automatically handled for my case... but that's may not be the cases in general.
We can address this by introducing a load (e.g., %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(...)>) and replacing all the uses of the argument with it. IIRC, there was an easy way to do that for function arguments... Let me look into it. Thanks!

@dcaballe dcaballe requested a review from Dinistro July 23, 2024 16:22
@joker-eph joker-eph requested a review from River707 July 29, 2024 17:27
@joker-eph
Copy link
Collaborator

Isn't there a problem with the way llvm.byval seems modeled here: why is this something that annotating before the lowering instead of a property of the type you lower?

@dcaballe
Copy link
Contributor Author

Isn't there a problem with the way llvm.byval seems modeled here: why is this something that annotating before the lowering instead of a property of the type you lower?

Good point! There are a few aspects related to this:

  • There seems to be multiple users of this “pre-lowering annotation” approach, including FIR and probably some downstreams, which goes beyond the llvm.byval case. We can see how a few attributes are processed here and they are not inserted by the LLVM lowering itself. In some cases, this is probably an attempt at reusing those attributes in other dialects before LLVM.

  • This llvm.byval mechanism seems very specific to LLVM. A dialect that is eventually lowered to LLVM (or any other egress dialect) may model a pass-by-value of a non-ptr type by just passing an SSA value of that type. Unfortunately, if that type is lowered to an aggregate type in LLVM, the pass-by-value will have to go through the llvm.byval mechanism, which requires passing an llvm.ptr instead of the actual aggregate type. To me, this looks like an LLVM ABI requirement more than a property of the type lowered to LLVM.

  • The llvm.byval mechanism can also be used for non-aggregate types (not sure why we would use it instead of just passing the non-aggregate type directly). This makes the previous bullet also applicable to non-aggregate types.

I’m planning to follow up on this and automatically introduce the llvm.byval attributes for aggregate type arguments. However, that won’t solve the problem for non-aggregate types and other LLVM attributes that are currently introduced before the lowering.

@joker-eph
Copy link
Collaborator

Taking another angle, the property to keep is that the type converter is always 1:1, that is for a given type before conversion you always get the same converted type.
We could do that and handle the byval conversion separately (outside of the type converter) directly in the signature conversion logic targeting llvm.func? (the logic would have to add the right load to the converted type as well).

@dcaballe
Copy link
Contributor Author

We could do that and handle the byval conversion separately (outside of the type converter) directly in the signature conversion logic targeting llvm.func? (the logic would have to add the right load to the converted type as well).

It sounds reasonable. I was hoping to avoid creating two function signatures but let me give it a try...

@joker-eph
Copy link
Collaborator

I was hoping to avoid creating two function signatures but let me give it a try...

I'm not sure what you're referring to? To be clear: in my mind there should be no change to the TypeConverter or the dialect conversion framework in general.

@dcaballe
Copy link
Contributor Author

I meant creating two converted function types if we introduced the llvm.ptr argument types after this point.

We can also do it before that point and create the converted function with the final type. However, it's not ideal either. We would have to update the argument types in the TypeConverter::SignatureConversion returned by converter.convertFunctionSignature. To some extent, this means the LLVM type converter wouldn't be providing a valid conversion in the first place, which is unexpected.

We would also have to map the attributes from the original function to the converted function, since the attributes haven’t been propagated to the converted function yet. Finally, we would have to replicate all this functionality in the SPIR-V and GPU lowerings to LLVM...

This is mostly why I encapsulated all this complexity behind the LLVM type converter (I tried a few other things before reaching this point). It makes the type conversion dependent on the argument attributes (which is accurate for the specific case of the function signature, after all) but the LLVM type converter would return a valid converted function type in the first place.

No strong opinion, though. It’s a matter of trade-offs. Happy to follow the feedback provided but I want to make sure we are all on the same page.

@River707
Copy link
Contributor

River707 commented Jul 30, 2024

From an initial read of the PR, I agree with Mehdi here. This PR seems to be creating an implicit transformation for a given LLVM attribute, which seems a bit off (I wouldn't necessarily expect the converter to automatically pointer-ize the argument type). It feels like this kind of behavior should be left directly to the llvm func conversion, with additional extensions added to the llvm converter (if it's to continue being used as an extension point, otherwise some other extension point) for opting-into/controlling this behavior. I'm just a bit wary of having double meanings for attributes that are contextual (given it's not really searchable/easily understandable, you'll only see the LLVM docs which are quite clear on expected usage).

@joker-eph
Copy link
Collaborator

I saw TypeConverter.h being modified and thought it was the general dialect conversion, not something specific to LLVM conversion.
It can be reasonable to have dedicated support for LLVM specific things in the LLVMTypeConverter indeed.

@dcaballe
Copy link
Contributor Author

I saw TypeConverter.h being modified and thought it was the general dialect conversion, not something specific to LLVM conversion.

Note that the current changes are to LLVMTypeConverter (under .../LLVMCommon/TypeConverter.h), not the generic type converter.

Ok, if no other comments, let me try to move the logic to the llvm function conversion so that we can compare both approaches.
Thanks!

dcaballe added 3 commits July 30, 2024 16:01
When a function argument is annotated with the `llvm.byval` attribute,
[LLVM expects] (https://llvm.org/docs/LangRef.html#parameter-attributes)
the function argument type to be an `llvm.ptr`. For example:

```
func.func (%args0 : llvm.ptr {llvm.byval = !llvm.struct<(i32)>} {
  ...
}
```

Unfortunately, this makes the type conversion context-dependent, which is
something that the type conversion infrastructure (i.e., `LLVMTypeConverter`
in this particular case) doesn't support. For example, we may want to convert
`MyType` to `llvm.struct<(i32)>` in general, but to an `llvm.ptr` type only
when it's a function argument passed by value.

To fix this problem, this PR changes the FuncToLLVM conversion logic to always
generate an `llvm.ptr` when the function argument has a `llvm.byval` attribute.
An `llvm.load` is inserted into the function to retrieve the value expected by
the argument users.
@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Aug 1, 2024
@dcaballe
Copy link
Contributor Author

dcaballe commented Aug 1, 2024

The second approach is ready for review! Hopefully this makes more sense.

@dcaballe
Copy link
Contributor Author

dcaballe commented Aug 2, 2024

Just for reference:

Copy link
Contributor

@River707 River707 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The general approach looks good, it would be nice to have the logic a bit cleaner though. Avoiding the secondary pass over the function arguments doing type conversion would probably clean this up a lot. Suppose that may be somewhere in the middle of the two approaches, but I leave things up to you.

Comment on lines 348 to 360
auto llvmType =
cast_or_null<LLVM::LLVMFunctionType>(converter.convertFunctionSignature(
funcTy, varargsAttr && varargsAttr.getValue(),
shouldUseBarePtrCallConv(funcOp, &converter), result));
if (!llvmType)
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");

// Replace the type of `llvm.byval` and `llvm.byref` arguments that were not
// converted to an LLVM pointer type.
SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
filterByValRefNonPtrAttrs(funcOp, result, byValRefNonPtrAttrs);
llvmType = converter.materializePtrForByValByRefFuncArgs(
llvmType, byValRefNonPtrAttrs, result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like the filtering and signature conversion could be in a new convertFunctionSignature overload that takes a FuncOp (would encapsulate this a bit more, cleanup the materializePtrForByValByRefFuncArgs logic, and avoid the double type construction), but no super strong feeling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that was the motivation for option num 1. I implemented the override and remove a bunch of new methods that are no longer needed. PTAL

Copy link
Contributor

@River707 River707 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

@dcaballe dcaballe merged commit 2ac2e9a into llvm:main Aug 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants