diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h index 3e3fcd7d1fb82..21331e5aa89f3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h +++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h @@ -26,11 +26,11 @@ namespace NVVM { enum class PTXRegisterMod { /// Read register with no modifier Read = 0, - /// Read register with '+' modifier + /// Write register with '=' modifier Write = 2, - /// Read register with '=' modifier. - /// Note that, this is not natively supported by LLVM, but it is possible to - /// set read and write for the same operand. + /// ReadWrite register with '+' modifier. + /// Note that, this is not natively supported by LLVM, the Interface does + /// mapping ReadWrite = 1, }; @@ -67,13 +67,19 @@ class PtxBuilder { SmallVector ptxOperands; // Register constraints (read, write, readwrite) and register data types std::string registerConstraints; - + // Modifiers + SmallVector registerModifiers; + // Has return value as write-only or read-write bool hasResult = false; + // Indicates if the Op will handle the register mapping manually. + bool needsManualRegisterMapping = false; public: /// Single constructor that only initializes members. - PtxBuilder(Operation *op, PatternRewriter &rewriter) - : interfaceOp(op), rewriter(rewriter) {} + PtxBuilder(Operation *op, PatternRewriter &rewriter, + bool needsManualRegisterMapping = false) + : interfaceOp(op), rewriter(rewriter), + needsManualRegisterMapping(needsManualRegisterMapping) {} /// Add an operand with the read/write input type. void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read); @@ -87,6 +93,16 @@ class PtxBuilder { void buildAndReplaceOp(); }; +/// Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the +/// PTX code. +void countPlaceholderNumbers(StringRef ptxCode, + llvm::SmallDenseSet &seenRW, + llvm::SmallDenseSet &seenW, + llvm::SmallDenseSet &seenR, + llvm::SmallVectorImpl &rwNums, + llvm::SmallVectorImpl &wNums, + llvm::SmallVectorImpl &rNums); + } // namespace NVVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td index e98b94b5b3052..086cdccb01221 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td +++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td @@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> { following this order: 1) Adds results 2) Adds operands - 3) Adds attributes + 3) Adds attributes + Returns true if the OP is going to do register mapping itself }], - /*retType=*/"void", + /*retType=*/"bool", /*methodName=*/"getAsmValues", /*args=*/(ins "::mlir::RewriterBase &":$rewriter, - "llvm::SmallVectorImpl>&" : $asmValues), + "llvm::SmallVectorImpl>&" : $asmValues + ), /*methodBody=*/"", /*defaultImpl=*/ [{ mlir::Operation* op = $_op; // Step 1. Add results - for (auto val : op->getResults()) - asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write}); + for (auto val : op->getResults()) + asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write}); // Step 2. Add operands for (auto val : op->getOperands()) @@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> { asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read}); } } + return false; // No manual mapping needed }] > ]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index f9cd58de8915f..786d42cf15666 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx", }]; let arguments = (ins Variadic:$readOnlyArgs, + Variadic:$readWriteArgs, StrAttr:$ptxCode, PtxPredicate:$predicate); let results = (outs Variadic:$writeOnlyArgs); - - let assemblyFormat = [{ - $ptxCode `(` $readOnlyArgs `)` - (`,` `predicate` `=` $predicate^)? attr-dict - `:` type(operands) - (`->` type($writeOnlyArgs)^)? + + let assemblyFormat = [{ + $ptxCode + ( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )? + ( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )? + (`,` `predicate` `=` $predicate^)? + attr-dict + ( `->` type($writeOnlyArgs)^ )? }]; let extraClassDefinition = [{ @@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx", return std::string(ptxInstStr.data()); } }]; + + let extraClassDeclaration = [{ + bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl> &); + }]; } //===----------------------------------------------------------------------===// @@ -3027,8 +3034,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async", let hasVerifier = 1; let extraClassDeclaration = [{ - void getAsmValues(RewriterBase &rewriter, - llvm::SmallVectorImpl> &asmValues); + bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl> &); }]; } diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index e0144bff4d371..c67ec3642f121 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -57,9 +57,9 @@ struct PtxLowering SmallVector> asmValues; LDBG() << op.getPtx(); - PtxBuilder generator(op, rewriter); - op.getAsmValues(rewriter, asmValues); + bool needsManualMapping = op.getAsmValues(rewriter, asmValues); + PtxBuilder generator(op, rewriter, needsManualMapping); for (auto &[asmValue, modifier] : asmValues) { LDBG() << asmValue << "\t Modifier : " << modifier; generator.insertValue(asmValue, modifier); diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index 6765aa4b9dac9..6d2a64f94e3ca 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -13,7 +13,10 @@ #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/DebugLog.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Regex.h" #define DEBUG_TYPE "ptx-builder" @@ -59,19 +62,37 @@ static char getRegisterType(Value v) { return getRegisterType(v.getType()); } +/// Extract every element of a struct value. +static SmallVector extractStructElements(PatternRewriter &rewriter, + Location loc, Value structVal) { + auto structTy = dyn_cast(structVal.getType()); + assert(structTy && "expected LLVM struct"); + + SmallVector elems; + for (unsigned i : llvm::seq(0, structTy.getBody().size())) + elems.push_back(rewriter.create(loc, structVal, i)); + + return elems; +} + void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { - LDBG() << v << "\t Modifier : " << &itype; + LDBG() << v << "\t Modifier : " << itype << "\n"; + registerModifiers.push_back(itype); + auto getModifier = [&]() -> const char * { - if (itype == PTXRegisterMod::ReadWrite) { - assert(false && "Read-Write modifier is not supported. Try setting the " - "same value as Write and Read separately."); - return "+"; - } - if (itype == PTXRegisterMod::Write) { + switch (itype) { + case PTXRegisterMod::Read: + return ""; + case PTXRegisterMod::Write: return "="; + case PTXRegisterMod::ReadWrite: + // "Read-Write modifier is not actually supported + // Interface will change it to "=" later and add integer mapping + return "+"; } - return ""; + llvm_unreachable("Unknown PTX register modifier"); }; + auto addValue = [&](Value v) { if (itype == PTXRegisterMod::Read) { ptxOperands.push_back(v); @@ -108,38 +129,247 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { } /// Check if the operation needs to pack and unpack results. -static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) { - return interfaceOp->getNumResults() > 1; +static bool +needsPackUnpack(BasicPtxBuilderInterface interfaceOp, + bool needsManualRegisterMapping, + SmallVectorImpl ®isterModifiers) { + if (needsManualRegisterMapping) + return false; + const unsigned writeOnlyVals = interfaceOp->getNumResults(); + const unsigned readWriteVals = + llvm::count_if(registerModifiers, [](PTXRegisterMod m) { + return m == PTXRegisterMod::ReadWrite; + }); + return (writeOnlyVals + readWriteVals) > 1; } /// Pack the result types of the interface operation. /// If the operation has multiple results, it packs them into a struct /// type. Otherwise, it returns the original result types. -static SmallVector packResultTypes(MLIRContext *ctx, - BasicPtxBuilderInterface interfaceOp) { - TypeRange results = interfaceOp->getResultTypes(); +static SmallVector +packResultTypes(BasicPtxBuilderInterface interfaceOp, + bool needsManualRegisterMapping, + SmallVectorImpl ®isterModifiers, + SmallVectorImpl &ptxOperands) { + MLIRContext *ctx = interfaceOp->getContext(); + TypeRange resultRange = interfaceOp->getResultTypes(); + + if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping, + registerModifiers)) { + // Single value path: + if (interfaceOp->getResults().size() == 1) + return SmallVector{resultRange.front()}; + + // No declared results: if there is an RW, forward its type. + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) + return SmallVector{v.getType()}; + } + + SmallVector packed; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) + packed.push_back(v.getType()); + for (Type t : resultRange) + packed.push_back(t); + + if (packed.empty()) + return {}; + + auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false); + return SmallVector{sTy}; +} + +/// Canonicalize the register constraints: +/// - Turn every "+X" into "=X" +/// - Append (at the very end) the 0-based indices of tokens that were "+X" +/// Examples: +/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2" +/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2" +static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) { + SmallVector toks; + SmallVector out; + SmallVector plusIdx; + + csv.split(toks, ','); + out.reserve(toks.size() + 8); + + for (unsigned i = 0, e = toks.size(); i < e; ++i) { + StringRef t = toks[i].trim(); + if (t.consume_front("+")) { + plusIdx.push_back(i); + out.push_back(("=" + t).str()); + } else { + out.push_back(t.str()); + } + } + + // Append indices of original "+X" tokens. + for (unsigned idx : plusIdx) + out.push_back(std::to_string(idx)); + + // Join back to CSV. + std::string result; + result.reserve(csv.size() + plusIdx.size() * 2); + llvm::raw_string_ostream os(result); + for (size_t i = 0; i < out.size(); ++i) { + if (i) + os << ','; + os << out[i]; + } + return os.str(); +} + +constexpr llvm::StringLiteral kReadWritePrefix{"rw"}; +constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"}; +constexpr llvm::StringLiteral kReadOnlyPrefix{"r"}; + +/// Returns a regex that matches {$rwN}, {$wN}, {$rN} +static llvm::Regex getPredicateMappingRegex() { + llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", + kReadWritePrefix, kWriteOnlyPrefix, + kReadOnlyPrefix) + .str()); + return rx; +} + +void mlir::NVVM::countPlaceholderNumbers( + StringRef ptxCode, llvm::SmallDenseSet &seenRW, + llvm::SmallDenseSet &seenW, + llvm::SmallDenseSet &seenR, + llvm::SmallVectorImpl &rwNums, + llvm::SmallVectorImpl &wNums, + llvm::SmallVectorImpl &rNums) { - if (!needsPackUnpack(interfaceOp)) - return llvm::to_vector<1>(results); + llvm::Regex rx = getPredicateMappingRegex(); + StringRef rest = ptxCode; - SmallVector elems(results.begin(), results.end()); - auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false); - return {sTy}; + SmallVector m; // 0: full, 1: kind, 2: number + while (!rest.empty() && rx.match(rest, &m)) { + unsigned num = 0; + (void)m[2].getAsInteger(10, num); + // Insert it into the vector only the first time we see this number + if (m[1].equals_insensitive(kReadWritePrefix)) { + if (seenRW.insert(num).second) + rwNums.push_back(num); + } else if (m[1].equals_insensitive(kWriteOnlyPrefix)) { + if (seenW.insert(num).second) + wNums.push_back(num); + } else { + if (seenR.insert(num).second) + rNums.push_back(num); + } + + const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size(); + rest = rest.drop_front(advance); + } +} + +/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into +/// compact `$K` indices: +/// - All `rw*` first (sorted by N), +/// - Then `w*`, +/// - Then `r*`. +/// If there a predicate, it comes always in the end. +/// Each number is assigned once; duplicates are ignored. +/// +/// Example Input: +/// "{ +/// reg .pred p; +/// setp.ge.s32 p, {$r0}, {$r1};" +/// selp.s32 {$rw0}, {$r0}, {$r1}, p; +/// selp.s32 {$rw1}, {$r0}, {$r1}, p; +/// selp.s32 {$w0}, {$r0}, {$r1}, p; +/// selp.s32 {$w1}, {$r0}, {$r1}, p; +/// }\n" +/// Example Output: +/// "{ +/// reg .pred p; +/// setp.ge.s32 p, $4, $5;" +/// selp.s32 $0, $4, $5, p; +/// selp.s32 $1, $4, $5, p; +/// selp.s32 $2, $4, $5, p; +/// selp.s32 $3, $4, $5, p; +/// }\n" +static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) { + llvm::SmallDenseSet seenRW, seenW, seenR; + llvm::SmallVector rwNums, wNums, rNums; + + // Step 1. Count Register Placeholder numbers + countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums); + + // Step 2. Sort the Register Placeholder numbers + llvm::sort(rwNums); + llvm::sort(wNums); + llvm::sort(rNums); + + // Step 3. Create mapping from original to new IDs + llvm::DenseMap rwMap, wMap, rMap; + unsigned nextId = 0; + for (unsigned n : rwNums) + rwMap[n] = nextId++; + for (unsigned n : wNums) + wMap[n] = nextId++; + for (unsigned n : rNums) + rMap[n] = nextId++; + + // Step 4. Rewrite the PTX code with new IDs + std::string out; + out.reserve(ptxCode.size()); + size_t prev = 0; + StringRef rest = ptxCode; + SmallVector matches; + llvm::Regex rx = getPredicateMappingRegex(); + while (!rest.empty() && rx.match(rest, &matches)) { + // Compute absolute match bounds in the original buffer. + size_t absStart = (size_t)(matches[0].data() - ptxCode.data()); + size_t absEnd = absStart + matches[0].size(); + + // Emit text before the match. + out.append(ptxCode.data() + prev, ptxCode.data() + absStart); + + // Emit compact $K + unsigned num = 0; + (void)matches[2].getAsInteger(10, num); + unsigned id = 0; + if (matches[1].equals_insensitive(kReadWritePrefix)) + id = rwMap.lookup(num); + else if (matches[1].equals_insensitive(kWriteOnlyPrefix)) + id = wMap.lookup(num); + else + id = rMap.lookup(num); + + out.push_back('$'); + out += std::to_string(id); + + prev = absEnd; + + const size_t advance = + (size_t)(matches[0].data() - rest.data()) + matches[0].size(); + rest = rest.drop_front(advance); + } + + // Step 5. Tail. + out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size()); + return out; } LLVM::InlineAsmOp PtxBuilder::build() { - MLIRContext *ctx = interfaceOp->getContext(); auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), LLVM::AsmDialect::AD_ATT); - SmallVector resultTypes = packResultTypes(ctx, interfaceOp); + SmallVector resultTypes = packResultTypes( + interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands); // Remove the last comma from the constraints string. if (!registerConstraints.empty() && registerConstraints[registerConstraints.size() - 1] == ',') registerConstraints.pop_back(); + registerConstraints = canonicalizeRegisterConstraints(registerConstraints); std::string ptxInstruction = interfaceOp.getPtx(); + if (!needsManualRegisterMapping) + ptxInstruction = rewriteAsmPlaceholders(ptxInstruction); // Add the predicate to the asm string. if (interfaceOp.getPredicate().has_value() && @@ -169,33 +399,87 @@ void PtxBuilder::buildAndReplaceOp() { LLVM::InlineAsmOp inlineAsmOp = build(); LDBG() << "\n Generated PTX \n\t" << inlineAsmOp; - // Case 1: no result - if (inlineAsmOp->getNumResults() == 0) { + // Case 0: no result at all → just erase wrapper op. + if (!hasResult) { rewriter.eraseOp(interfaceOp); return; } - // Case 2: single result, forward it directly - if (!needsPackUnpack(interfaceOp)) { + if (needsManualRegisterMapping) { rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults()); return; } - // Case 3: multiple results were packed; unpack the struct. - assert(mlir::LLVM::LLVMStructType::classof( - inlineAsmOp.getResultTypes().front()) && - "Expected result type to be LLVMStructType when unpacking multiple " - "results"); - auto structTy = llvm::cast( - inlineAsmOp.getResultTypes().front()); + // Case 1: Simple path, return single scalar + if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping, + registerModifiers)) { + if (inlineAsmOp->getNumResults() > 0) { + rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults()); + } else { + // RW-only case with no declared results: forward the RW value. + SmallVector results; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) { + results.push_back(v); + break; + } + rewriter.replaceOp(interfaceOp, results); + } + return; + } + + const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) { + return m == PTXRegisterMod::ReadWrite; + }); - SmallVector unpacked; + // All multi-value paths produce a single struct result we need to unpack. + assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) && + "expected struct return for multi-result inline asm"); Value structVal = inlineAsmOp.getResult(0); - for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) { - Value unpackedValue = LLVM::ExtractValueOp::create( - rewriter, interfaceOp->getLoc(), structVal, idx); - unpacked.push_back(unpackedValue); + SmallVector unpacked = + extractStructElements(rewriter, interfaceOp->getLoc(), structVal); + + // Case 2: only declared results (no RW): replace the op with all unpacked. + if (!hasRW && interfaceOp->getResults().size() > 0) { + rewriter.replaceOp(interfaceOp, unpacked); + return; } - rewriter.replaceOp(interfaceOp, unpacked); + // Case 3: RW-only (no declared results): update RW uses and erase wrapper. + if (hasRW && interfaceOp->getResults().size() == 0) { + unsigned idx = 0; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) { + if (m != PTXRegisterMod::ReadWrite) + continue; + Value repl = unpacked[idx++]; + v.replaceUsesWithIf(repl, [&](OpOperand &use) { + Operation *owner = use.getOwner(); + return owner != interfaceOp && owner != inlineAsmOp; + }); + } + rewriter.eraseOp(interfaceOp); + return; + } + + // Case 4: mixed (RW + declared results). + { + // First rewrite RW operands in place. + unsigned idx = 0; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) { + if (m != PTXRegisterMod::ReadWrite) + continue; + Value repl = unpacked[idx++]; + v.replaceUsesWithIf(repl, [&](OpOperand &use) { + Operation *owner = use.getOwner(); + return owner != interfaceOp && owner != inlineAsmOp; + }); + } + // The remaining unpacked values correspond to the declared results. + SmallVector tail; + tail.reserve(unpacked.size() - idx); + for (unsigned i = idx, e = unpacked.size(); i < e; ++i) + tail.push_back(unpacked[i]); + + rewriter.replaceOp(interfaceOp, tail); + } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index dbcc738b4419f..042103cad83ca 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1123,7 +1123,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() { return ptx; } -void NVVM::WgmmaMmaAsyncOp::getAsmValues( +bool NVVM::WgmmaMmaAsyncOp::getAsmValues( RewriterBase &rewriter, llvm::SmallVectorImpl> &asmValues) { @@ -1154,7 +1154,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues( {makeConstantI32(rewriter, 1 - static_cast(getLayoutB())), mlir::NVVM::PTXRegisterMod::Read}); } + return true; // Has manual mapping } + LogicalResult NVVM::FenceProxyOp::verify() { if (getKind() == NVVM::ProxyKind::TENSORMAP) return emitOpError() << "tensormap proxy is not a supported proxy kind"; @@ -1870,6 +1872,21 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) { } } +bool NVVM::InlinePtxOp::getAsmValues( + RewriterBase &rewriter, + llvm::SmallVectorImpl> + &asmValues) { + for (auto arg : getReadWriteArgs()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite}); + for (auto arg : getResults()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write}); + for (auto arg : getReadOnlyArgs()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read}); + if (getPredicate()) + asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read}); + return false; // No manual mapping needed +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index b38347c7cd1b7..2a19c72ab0840 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -667,34 +667,82 @@ llvm.func @init_mbarrier( %count : i32, %pred : i1) { // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r" - nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32 + nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32) // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" - nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1 + nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32), predicate = %pred llvm.return } // ----- llvm.func @ex2(%input : f32, %pred : i1) { // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32 - %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32 + %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32) -> f32 // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32 - %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred : f32, i1 -> f32 + %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred -> f32 llvm.return } // CHECK-LABEL: @multi_return( // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32) llvm.func @multi_return(%a : i32, %b : i32) -> i32 { - // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)> + // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)> // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)> // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)> // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32 // CHECK: llvm.return %[[S4]] : i32 - %r1, %r2 = nvvm.inline_ptx "{\n\t .reg .pred p;\n\t setp.ge.s32 p, $2, $3;\n\t selp.s32 $0, $2, $3, p;\n\t selp.s32 $1, $2, $3, !p;\n\t}\n" (%a, %b) : i32,i32 -> i32,i32 + %r1, %r2 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}" + ro (%a, %b : i32,i32) -> i32,i32 %r3 = llvm.add %r1, %r2 : i32 llvm.return %r3 : i32 } + +// CHECK-LABEL: @inline_ptx_multi_rw( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32) +llvm.func @inline_ptx_multi_rw(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 { +// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", +// CHECK-SAME: "=f,=f,r,r,0,1" +// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]] +// CHECK-SAME: : (f32, f32, i32, i32) -> !llvm.struct<(f32, f32)> +// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32)> +// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32)> +// CHECK: %[[S3:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32 +// CHECK: llvm.return %[[S3]] : f32 + nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p;}" + ro (%a, %b : i32,i32) + rw (%rw_c, %rw_d: f32,f32) + %r4 = llvm.fadd %rw_c, %rw_d : f32 + llvm.return %r4 : f32 +} + +// CHECK-LABEL: @inline_ptx_multi_rw_r( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32) +llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 { +// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $4, $5; selp.s32 $0, $4,$5, p; selp.s32 $1, $4,$5, p; selp.s32 $2, $4,$5, p; selp.s32 $3, $4,$5, p;}", +// CHECK-SAME: "=f,=f,=r,=r,r,r,0,1" +// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]] : +// CHECK-SAME: (f32, f32, i32, i32) -> !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S0]][2] : !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S4:.+]] = llvm.extractvalue %[[S0]][3] : !llvm.struct<(f32, f32, i32, i32)> +// CHECK: %[[S5:.+]] = llvm.add %[[S3]], %[[S4]] : i32 +// CHECK: %[[S6:.+]] = llvm.sitofp %[[S5]] : i32 to f32 +// CHECK: %[[S7:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32 +// CHECK: %[[S8:.+]] = llvm.fadd %[[S6]], %[[S2]] : f32 +// CHECK: llvm.return %[[S8]] : f32 + + %wo0, %wo1 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}" + ro (%a, %b : i32,i32) + rw (%rw_c, %rw_d: f32,f32) -> i32,i32 + %r3 = llvm.add %wo0, %wo1 : i32 + %r3f = llvm.sitofp %r3 : i32 to f32 + %r4 = llvm.fadd %rw_c, %rw_d : f32 + %r5 = llvm.fadd %r3f, %rw_d : f32 + llvm.return %r5 : f32 +} + + // ----- // CHECK-LABEL: @nvvm_pmevent diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py index 0eef97d95479a..3eb62bef50de9 100644 --- a/mlir/test/python/dialects/nvvm.py +++ b/mlir/test/python/dialects/nvvm.py @@ -5,6 +5,8 @@ from mlir.dialects import nvvm from mlir.dialects import llvm from mlir.dialects import func +import mlir.extras.types as T +from mlir.dialects import arith def constructAndPrintInModule(f): @@ -25,6 +27,7 @@ def testSmoke(): "!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>" ) shape_attr = Attribute.parse("#nvvm.shape") + # CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64) @func.FuncOp.from_py_func(i64, i64) def wgmma_f32_f16_f16(desc_a, desc_b): @@ -48,3 +51,41 @@ def wgmma_f32_f16_f16(desc_a, desc_b): layoutA=nvvm.MMALayout.col, layoutB=nvvm.MMALayout.col, ) + + +# CHECK-LABEL: TEST: test_inline_ptx +# CHECK-LABEL: func.func @my_inline_ptx( +# CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: f32, %[[arg1:[a-zA-Z0-9_]+]]: f32, %[[arg2:[a-zA-Z0-9_]+]]: i32, %[[arg3:[a-zA-Z0-9_]+]]: i32) +# CHECK: %[[S0:.+]]:2 = nvvm.inline_ptx +# CHECK-SAME: ro(%[[arg0]], %[[arg1]] : f32, f32) rw(%[[arg2]], %[[arg3]] : i32, i32) -> f32, f32 +# CHECK: %[[S1:.+]] = arith.addf %[[arg0]], %[[arg1]] : f32 +# CHECK: %[[S2:.+]] = arith.addi %[[arg2]], %[[arg3]] : i32 +# CHECK: %[[S3:.+]] = arith.addf %[[S0]]#0, %[[S0]]#1 : f32 + + +@constructAndPrintInModule +def test_inline_ptx(): + i32 = T.i32() + f32 = T.f32() + + @func.FuncOp.from_py_func(f32, f32, i32, i32) + def my_inline_ptx(a, b, c, d): + ptx = r""" + { + .reg .pred p; + setp.ge.s32 p, {$r0}, {$r1}; + selp.s32 {$r0}, {$r0}, {$r1}, p; + selp.s32 {$r1}, {$r0}, {$r1}, p; + selp.s32 {$rw0}, {$r0}, {$r1}, p; + selp.s32 {$rw1}, {$r0}, {$r1}, p; + } + """ + wo0, wo1 = nvvm.inline_ptx( + read_only_args=[a, b], + read_write_args=[c, d], + write_only_args=[f32, f32], + ptx_code=ptx, + ) + arith.addf(a, b) + arith.addi(c, d) + arith.addf(wo0, wo1)