diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 7700867bb461f..86f4e7e3e2c4a 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1102,6 +1102,22 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { llvm::AttrBuilder().addAlignmentAttr(llvm::Align(attr.getInt()))); } + if (auto attr = func.getArgAttrOfType(argIdx, "llvm.sret")) { + auto argTy = mlirArg.getType().dyn_cast(); + if (!argTy.isa()) + return func.emitError( + "llvm.sret attribute attached to LLVM non-pointer argument"); + llvmArg.addAttr(llvm::Attribute::AttrKind::StructRet); + } + + if (auto attr = func.getArgAttrOfType(argIdx, "llvm.byval")) { + auto argTy = mlirArg.getType().dyn_cast(); + if (!argTy.isa()) + return func.emitError( + "llvm.byval attribute attached to LLVM non-pointer argument"); + llvmArg.addAttr(llvm::Attribute::AttrKind::ByVal); + } + valueMapping[mlirArg] = &llvmArg; argIdx++; } diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir index 65dc33cc1c4f9..2cec1bca1f74f 100644 --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -87,6 +87,16 @@ module { llvm.return } + // CHECK: llvm.func @byvalattr(%{{.*}}: !llvm.ptr {llvm.byval}) + llvm.func @byvalattr(%arg0: !llvm.ptr {llvm.byval}) { + llvm.return + } + + // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr {llvm.sret}) + llvm.func @sretattr(%arg0: !llvm.ptr {llvm.sret}) { + llvm.return + } + // CHECK: llvm.func @variadic(...) llvm.func @variadic(...) diff --git a/mlir/test/Target/llvmir-invalid.mlir b/mlir/test/Target/llvmir-invalid.mlir index 14117594e2f89..fcd98ef4b143a 100644 --- a/mlir/test/Target/llvmir-invalid.mlir +++ b/mlir/test/Target/llvmir-invalid.mlir @@ -14,6 +14,19 @@ llvm.func @invalid_noalias(%arg0 : !llvm.float {llvm.noalias = true}) -> !llvm.f // ----- +// expected-error @+1 {{llvm.sret attribute attached to LLVM non-pointer argument}} +llvm.func @invalid_noalias(%arg0 : !llvm.float {llvm.sret}) -> !llvm.float { + llvm.return %arg0 : !llvm.float +} +// ----- + +// expected-error @+1 {{llvm.byval attribute attached to LLVM non-pointer argument}} +llvm.func @invalid_noalias(%arg0 : !llvm.float {llvm.byval}) -> !llvm.float { + llvm.return %arg0 : !llvm.float +} + +// ----- + // expected-error @+1 {{llvm.align attribute attached to LLVM non-pointer argument}} llvm.func @invalid_align(%arg0 : !llvm.float {llvm.align = 4}) -> !llvm.float { llvm.return %arg0 : !llvm.float