diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index d85b2ad9a0542..9ce413ceeaf6b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -423,6 +423,31 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) { return out; } +/// Return the constraint index of the predicate operand. The predicate +/// constraint ("b") is always the last non-tied token in the canonicalized +/// constraint string. Tied constraints (digit-only tokens from read-write +/// canonicalization) are appended at the end, so we walk backwards to skip +/// them. +static unsigned getPredicateConstraintIndex(StringRef constraints) { + SmallVector tokens; + constraints.split(tokens, ','); + assert(!tokens.empty() && "expected at least a predicate constraint"); + + auto isTiedConstraint = [](StringRef tok) { + unsigned idx; + return !tok.trim().getAsInteger(10, idx); + }; + + size_t numTied = 0; + for (StringRef tok : llvm::reverse(tokens)) { + if (!isTiedConstraint(tok)) + break; + ++numTied; + } + assert(numTied < tokens.size() && "all constraints are tied"); + return tokens.size() - numTied - 1; +} + LLVM::InlineAsmOp PtxBuilder::build() { auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), LLVM::AsmDialect::AD_ATT); @@ -443,8 +468,9 @@ LLVM::InlineAsmOp PtxBuilder::build() { // Add the predicate to the asm string. if (interfaceOp.getPredicate().has_value() && interfaceOp.getPredicate().value()) { + unsigned predIdx = getPredicateConstraintIndex(registerConstraints); std::string predicateStr = "@%"; - predicateStr += std::to_string((ptxOperands.size() - 1)); + predicateStr += std::to_string(predIdx); ptxInstruction = predicateStr + " " + ptxInstruction; } diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 8e16a92d96a7b..a188aec18134c 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -614,11 +614,25 @@ llvm.func @ex2(%input : f32, %pred : i1) { // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32 %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32) -> f32 - // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32 + // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$2 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32 %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred -> f32 llvm.return } +// CHECK-LABEL: @multi_return_pred( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[pred:[a-zA-Z0-9_]+]]: i1) +llvm.func @multi_return_pred(%a : i32, %b : i32, %pred : i1) -> i32 { + // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "@$4 {.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", "=r,=r,r,r,b" %[[arg0]], %[[arg1]], %[[pred]] : (i32, i32, i1) -> !llvm.struct<(i32, i32)> + // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)> + // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)> + // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32 + // CHECK: llvm.return %[[S4]] : i32 + %r1, %r2 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}" + ro (%a, %b : i32,i32), predicate = %pred -> i32,i32 + %r3 = llvm.add %r1, %r2 : i32 + llvm.return %r3 : i32 +} + // CHECK-LABEL: @multi_return( // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32) llvm.func @multi_return(%a : i32, %b : i32) -> i32 { @@ -651,6 +665,24 @@ llvm.func @inline_ptx_multi_rw(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> llvm.return %r4 : f32 } +// CHECK-LABEL: @inline_ptx_multi_rw_pred( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32, %[[pred:[a-zA-Z0-9_]+]]: i1) +llvm.func @inline_ptx_multi_rw_pred(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32, %pred : i1) -> f32 { +// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "@$4 {.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", +// CHECK-SAME: "=f,=f,r,r,b,0,1" +// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]], %[[pred]] +// CHECK-SAME: : (f32, f32, i32, i32, i1) -> !llvm.struct<(f32, f32)> +// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32)> +// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32)> +// CHECK: %[[S3:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32 +// CHECK: llvm.return %[[S3]] : f32 + nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p;}" + ro (%a, %b : i32,i32) + rw (%rw_c, %rw_d: f32,f32), predicate = %pred + %r4 = llvm.fadd %rw_c, %rw_d : f32 + llvm.return %r4 : f32 +} + // CHECK-LABEL: @inline_ptx_multi_rw_r( // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32) llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 { @@ -678,6 +710,33 @@ llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) llvm.return %r5 : f32 } +// CHECK-LABEL: @inline_ptx_multi_rw_r_pred( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32, %[[pred:[a-zA-Z0-9_]+]]: i1) +llvm.func @inline_ptx_multi_rw_r_pred(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32, %pred : i1) -> f32 { +// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "@$6 {.reg .pred p; setp.ge.s32 p, $4, $5; selp.s32 $0, $4,$5, p; selp.s32 $1, $4,$5, p; selp.s32 $2, $4,$5, p; selp.s32 $3, $4,$5, p;}", +// CHECK-SAME: "=f,=f,=r,=r,r,r,b,0,1" +// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]], %[[pred]] : +// CHECK-SAME: (f32, f32, i32, i32, i1) -> !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S0]][2] : !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S4:.+]] = llvm.extractvalue %[[S0]][3] : !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S5:.+]] = llvm.add %[[S3]], %[[S4]] : i32 +// CHECK: %[[S6:.+]] = llvm.sitofp %[[S5]] : i32 to f32 +// CHECK: %[[S7:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32 +// CHECK: %[[S8:.+]] = llvm.fadd %[[S6]], %[[S2]] : f32 +// CHECK: llvm.return %[[S8]] : f32 + + %wo0, %wo1 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}" + ro (%a, %b : i32,i32) + rw (%rw_c, %rw_d: f32,f32), predicate = %pred -> i32,i32 + %r3 = llvm.add %wo0, %wo1 : i32 + %r3f = llvm.sitofp %r3 : i32 to f32 + %r4 = llvm.fadd %rw_c, %rw_d : f32 + %r5 = llvm.fadd %r3f, %rw_d : f32 + llvm.return %r5 : f32 +} + // ----- llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>, %mask : i32, %zero: i32) {