diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h index 21331e5aa89f3..cb2489335a317 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h +++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "llvm/Support/LogicalResult.h" namespace mlir { namespace NVVM { @@ -82,7 +83,8 @@ class PtxBuilder { needsManualRegisterMapping(needsManualRegisterMapping) {} /// Add an operand with the read/write input type. - void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read); + LogicalResult insertValue(Value v, + PTXRegisterMod itype = PTXRegisterMod::Read); /// Builds the inline assembly Op and returns it. The `insertValue` needs to /// be called to pass operands before building the PTX. diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index c67ec3642f121..314cbed2e4f33 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -26,6 +26,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/DebugLog.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" @@ -62,7 +63,8 @@ struct PtxLowering PtxBuilder generator(op, rewriter, needsManualMapping); for (auto &[asmValue, modifier] : asmValues) { LDBG() << asmValue << "\t Modifier : " << modifier; - generator.insertValue(asmValue, modifier); + if (failed(generator.insertValue(asmValue, modifier))) + return failure(); } generator.buildAndReplaceOp(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index 6d2a64f94e3ca..7220e10ea84d3 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -12,10 +12,17 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LLVM.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/Regex.h" #define DEBUG_TYPE "ptx-builder" @@ -31,35 +38,88 @@ using namespace NVVM; static constexpr int64_t kSharedMemorySpace = 3; -static char getRegisterType(Type type) { - if (type.isInteger(1)) - return 'b'; - if (type.isInteger(16)) - return 'h'; - if (type.isInteger(32)) - return 'r'; - if (type.isInteger(64)) - return 'l'; - if (type.isF32()) - return 'f'; - if (type.isF64()) - return 'd'; - if (auto ptr = dyn_cast(type)) { - // Shared address spaces is addressed with 32-bit pointers. - if (ptr.getAddressSpace() == kSharedMemorySpace) { +static FailureOr getRegisterType(Type type, Location loc) { + MLIRContext *ctx = type.getContext(); + auto i16 = IntegerType::get(ctx, 16); + auto i32 = IntegerType::get(ctx, 32); + auto f32 = Float32Type::get(ctx); + + auto getRegisterTypeForScalar = [&](Type type) -> FailureOr { + if (type.isInteger(1)) + return 'b'; + if (type.isInteger(16)) + return 'h'; + if (type.isInteger(32)) return 'r'; + if (type.isInteger(64)) + return 'l'; + if (type.isF32()) + return 'f'; + if (type.isF64()) + return 'd'; + if (auto ptr = dyn_cast(type)) { + // Shared address spaces is addressed with 32-bit pointers. + if (ptr.getAddressSpace() == kSharedMemorySpace) { + return 'r'; + } + return 'l'; + } + // register type for struct is not supported. + mlir::emitError( + loc, "The register type could not be deduced from MLIR type. The ") + << type + << " is not supported. Supported types are:" + "i1, i16, i32, i64, f32, f64," + "pointers.\nPlease use llvm.bitcast if you have different type. " + "\nSee the constraints from here: " + "https://docs.nvidia.com/cuda/inline-ptx-assembly/" + "index.html#constraints"; + return failure(); + }; + + // Packed registers + if (auto v = dyn_cast(type)) { + assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported"); + + int64_t lanes = v.getNumElements(); + Type elem = v.getElementType(); + + // Case 1. Single vector + if (lanes <= 1) + return getRegisterTypeForScalar(elem); + + // Case 2. Packed registers + Type widened = elem; + switch (lanes) { + + case 2: + if (elem.isF16() || elem.isBF16()) // vector<2xf16> + widened = f32; + else if (elem.isFloat(8)) // vector<2xf8> + widened = i16; + break; + case 4: + if (elem.isInteger(8)) // vector + widened = i32; + else if (elem.isFloat(8)) // vector + widened = f32; + else if (elem.isFloat(4)) // vector + widened = i16; + break; + // Other packing is not supported + default: + break; } - return 'l'; + return getRegisterTypeForScalar(widened); } - // register type for struct is not supported. - llvm_unreachable("The register type could not deduced from MLIR type"); - return '?'; + + return getRegisterTypeForScalar(type); } -static char getRegisterType(Value v) { +static FailureOr getRegisterType(Value v, Location loc) { if (v.getDefiningOp()) return 'n'; - return getRegisterType(v.getType()); + return getRegisterType(v.getType(), loc); } /// Extract every element of a struct value. @@ -75,10 +135,11 @@ static SmallVector extractStructElements(PatternRewriter &rewriter, return elems; } -void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { +LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { LDBG() << v << "\t Modifier : " << itype << "\n"; registerModifiers.push_back(itype); + Location loc = interfaceOp->getLoc(); auto getModifier = [&]() -> const char * { switch (itype) { case PTXRegisterMod::Read: @@ -111,21 +172,29 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { } for (auto [idx, t] : llvm::enumerate(stype.getBody())) { if (itype != PTXRegisterMod::Write) { - Value extractValue = LLVM::ExtractValueOp::create( - rewriter, interfaceOp->getLoc(), v, idx); + Value extractValue = + LLVM::ExtractValueOp::create(rewriter, loc, v, idx); addValue(extractValue); } if (itype == PTXRegisterMod::ReadWrite) { ss << idx << ","; } else { - ss << getModifier() << getRegisterType(t) << ","; + FailureOr regType = getRegisterType(t, loc); + if (failed(regType)) + return rewriter.notifyMatchFailure(loc, + "failed to get register type"); + ss << getModifier() << regType.value() << ","; } } - return; + return success(); } // Handle Scalars addValue(v); - ss << getModifier() << getRegisterType(v) << ","; + FailureOr regType = getRegisterType(v, loc); + if (failed(regType)) + return rewriter.notifyMatchFailure(loc, "failed to get register type"); + ss << getModifier() << regType.value() << ","; + return success(); } /// Check if the operation needs to pack and unpack results. diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 2a19c72ab0840..ce17650d16d32 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -756,3 +756,38 @@ llvm.func @nvvm_pmevent() { nvvm.pmevent id = 4 llvm.return } + +// ----- + +llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>, %mask : i32, %zero: i32) { +// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,r,r" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32 + %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};" + ro(%src, %mask, %zero : vector<4xi8>, i32, i32) + -> i32 + llvm.return +} + +llvm.func @inline_ptx_pack_2bf16(%a : f32, %b : f32) { + // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", "=f,f,f" %{{.*}}, %{{.*}} : (f32, f32) -> vector<2xbf16> + %wo0 = nvvm.inline_ptx "cvt.rn.satfinite.bf16x2.f32 {$w0}, {$r0}, {$r1};" + ro(%a, %b : f32, f32) + -> vector<2xbf16> + llvm.return +} + +llvm.func @inline_ptx_cvt_rn_e4m3x2_f16x2(%a : i16) { +// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rz.satfinite.ue8m0x2.bf16x2 $0, $1", "=f,h" %{{.*}} : (i16) -> vector<2xbf16> + %wo0 = nvvm.inline_ptx "cvt.rz.satfinite.ue8m0x2.bf16x2 {$w0}, {$r0}" + ro(%a : i16) + -> vector<2xbf16> + llvm.return +} + +llvm.func @cvt_i8_bf16(%a : i8) { + // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .b16 r;\0A\09.reg .b8 s;\0A\09mov.b16 {s,_}, $0;\0A\09cvt.rn.bf16.s8 r, s;\0A\09mov.b16 $1, r;\0A\09", "=h,h" %{{.*}} : (i16) -> i16 + %za = llvm.zext %a : i8 to i16 + %wo0 = nvvm.inline_ptx "{\n\t.reg .b16 r;\n\t.reg .b8 s;\n\tmov.b16 {s,_}, {$w0};\n\tcvt.rn.bf16.s8 r, s;\n\tmov.b16 {$r0}, r;\n\t" + ro(%za : i16) + -> i16 + llvm.return +}