diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index 894de4408c375..e004d5f64733e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -107,11 +107,32 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { ss << getModifier() << getRegisterType(v) << ","; } +/// Check if the operation needs to pack and unpack results. +static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) { + return interfaceOp->getNumResults() > 1; +} + +/// Pack the result types of the interface operation. +/// If the operation has multiple results, it packs them into a struct +/// type. Otherwise, it returns the original result types. +static SmallVector packResultTypes(MLIRContext *ctx, + BasicPtxBuilderInterface interfaceOp) { + TypeRange results = interfaceOp->getResultTypes(); + + if (!needsPackUnpack(interfaceOp)) + return llvm::to_vector<1>(results); + + SmallVector elems(results.begin(), results.end()); + auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false); + return {sTy}; +} + LLVM::InlineAsmOp PtxBuilder::build() { + MLIRContext *ctx = interfaceOp->getContext(); auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), LLVM::AsmDialect::AD_ATT); - auto resultTypes = interfaceOp->getResultTypes(); + SmallVector resultTypes = packResultTypes(ctx, interfaceOp); // Remove the last comma from the constraints string. if (!registerConstraints.empty() && @@ -136,7 +157,7 @@ LLVM::InlineAsmOp PtxBuilder::build() { rewriter, interfaceOp->getLoc(), /*result types=*/resultTypes, /*operands=*/ptxOperands, - /*asm_string=*/llvm::StringRef(ptxInstruction), + /*asm_string=*/ptxInstruction, /*constraints=*/registerConstraints.data(), /*has_side_effects=*/interfaceOp.hasSideEffect(), /*is_align_stack=*/false, LLVM::TailCallKind::None, @@ -147,9 +168,34 @@ LLVM::InlineAsmOp PtxBuilder::build() { void PtxBuilder::buildAndReplaceOp() { LLVM::InlineAsmOp inlineAsmOp = build(); LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n"); - if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) { - rewriter.replaceOp(interfaceOp, inlineAsmOp); - } else { + + // Case 1: no result + if (inlineAsmOp->getNumResults() == 0) { rewriter.eraseOp(interfaceOp); + return; + } + + // Case 2: single result, forward it directly + if (!needsPackUnpack(interfaceOp)) { + rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults()); + return; } + + // Case 3: multiple results were packed; unpack the struct. + assert(mlir::LLVM::LLVMStructType::classof( + inlineAsmOp.getResultTypes().front()) && + "Expected result type to be LLVMStructType when unpacking multiple " + "results"); + auto structTy = llvm::cast( + inlineAsmOp.getResultTypes().front()); + + SmallVector unpacked; + Value structVal = inlineAsmOp.getResult(0); + for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) { + Value unpackedValue = LLVM::ExtractValueOp::create( + rewriter, interfaceOp->getLoc(), structVal, idx); + unpacked.push_back(unpackedValue); + } + + rewriter.replaceOp(interfaceOp, unpacked); } diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 24873340d7122..b38347c7cd1b7 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -683,6 +683,18 @@ llvm.func @ex2(%input : f32, %pred : i1) { llvm.return } +// CHECK-LABEL: @multi_return( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32) +llvm.func @multi_return(%a : i32, %b : i32) -> i32 { + // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)> + // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)> + // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)> + // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32 + // CHECK: llvm.return %[[S4]] : i32 + %r1, %r2 = nvvm.inline_ptx "{\n\t .reg .pred p;\n\t setp.ge.s32 p, $2, $3;\n\t selp.s32 $0, $2, $3, p;\n\t selp.s32 $1, $2, $3, !p;\n\t}\n" (%a, %b) : i32,i32 -> i32,i32 + %r3 = llvm.add %r1, %r2 : i32 + llvm.return %r3 : i32 +} // ----- // CHECK-LABEL: @nvvm_pmevent