diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index c427b0942aafd..8a3d07043013e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1567,6 +1567,133 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Permute Bytes (Prmt)
+//===----------------------------------------------------------------------===//
+
+// Attributes for the permute operation modes supported by PTX.
+def PermuteModeDefault : I32EnumAttrCase<"DEFAULT", 0, "default">;
+def PermuteModeF4E : I32EnumAttrCase<"F4E", 1, "f4e">;
+def PermuteModeB4E : I32EnumAttrCase<"B4E", 2, "b4e">;
+def PermuteModeRC8 : I32EnumAttrCase<"RC8", 3, "rc8">;
+def PermuteModeECL : I32EnumAttrCase<"ECL", 4, "ecl">;
+def PermuteModeECR : I32EnumAttrCase<"ECR", 5, "ecr">;
+def PermuteModeRC16 : I32EnumAttrCase<"RC16", 6, "rc16">;
+
+def PermuteMode : I32EnumAttr<"PermuteMode", "NVVM permute mode",
+                              [PermuteModeDefault, PermuteModeF4E,
+                               PermuteModeB4E, PermuteModeRC8, PermuteModeECL,
+                               PermuteModeECR, PermuteModeRC16]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def PermuteModeAttr : EnumAttr<NVVM_Dialect, PermuteMode, "permute_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_PermuteOp : NVVM_Op<"prmt", [Pure]>,
+  Results<(outs I32:$res)>,
+  Arguments<(ins I32:$lo, Optional<I32>:$hi, I32:$selector,
+             PermuteModeAttr:$mode)> {
+  let summary = "Permute bytes from two 32-bit registers";
+  let description = [{
+    The `nvvm.prmt` operation constructs a permutation of the
+    bytes of one or two 32-bit source operands (%lo and, for some
+    modes, %hi), as directed by the %selector operand and the mode.
+
+    The bytes in the first one or two source operands are numbered.
+    The first source operand (%lo) is numbered {b3, b2, b1, b0};
+    in the case of the '``default``', '``f4e``' and '``b4e``' variants,
+    the second source operand (%hi) is numbered {b7, b6, b5, b4}.
+
+    Modes:
+    - `default`: Index mode - each nibble in `selector` selects a byte from the 8-byte pool
+    - `f4e`    : Forward 4 extract - extracts 4 contiguous bytes starting from position in `selector`
+    - `b4e`    : Backward 4 extract - extracts 4 contiguous bytes in reverse order
+    - `rc8`    : Replicate 8 - replicates the byte of `lo` picked by `selector` into all four result bytes
+    - `ecl`    : Edge clamp left - clamps out-of-range indices to the leftmost valid byte
+    - `ecr`    : Edge clamp right - clamps out-of-range indices to the rightmost valid byte
+    - `rc16`   : Replicate 16 - replicates the 16-bit half of `lo` picked by `selector` into both result halves
+
+    In the non-default modes, the 2 least significant bits of the %selector
+    operand determine the result of the permutation as follows:
+
+    +------------+----------------+--------------+
+    |    Mode    | %selector[1:0] |    Output    |
+    +------------+----------------+--------------+
+    | '``f4e``'  |       0        | {3, 2, 1, 0} |
+    |            +----------------+--------------+
+    |            |       1        | {4, 3, 2, 1} |
+    |            +----------------+--------------+
+    |            |       2        | {5, 4, 3, 2} |
+    |            +----------------+--------------+
+    |            |       3        | {6, 5, 4, 3} |
+    +------------+----------------+--------------+
+    | '``b4e``'  |       0        | {5, 6, 7, 0} |
+    |            +----------------+--------------+
+    |            |       1        | {6, 7, 0, 1} |
+    |            +----------------+--------------+
+    |            |       2        | {7, 0, 1, 2} |
+    |            +----------------+--------------+
+    |            |       3        | {0, 1, 2, 3} |
+    +------------+----------------+--------------+
+    | '``rc8``'  |       0        | {0, 0, 0, 0} |
+    |            +----------------+--------------+
+    |            |       1        | {1, 1, 1, 1} |
+    |            +----------------+--------------+
+    |            |       2        | {2, 2, 2, 2} |
+    |            +----------------+--------------+
+    |            |       3        | {3, 3, 3, 3} |
+    +------------+----------------+--------------+
+    | '``ecl``'  |       0        | {3, 2, 1, 0} |
+    |            +----------------+--------------+
+    |            |       1        | {3, 2, 1, 1} |
+    |            +----------------+--------------+
+    |            |       2        | {3, 2, 2, 2} |
+    |            +----------------+--------------+
+    |            |       3        | {3, 3, 3, 3} |
+    +------------+----------------+--------------+
+    | '``ecr``'  |       0        | {0, 0, 0, 0} |
+    |            +----------------+--------------+
+    |            |       1        | {1, 1, 1, 0} |
+    |            +----------------+--------------+
+    |            |       2        | {2, 2, 1, 0} |
+    |            +----------------+--------------+
+    |            |       3        | {3, 2, 1, 0} |
+    +------------+----------------+--------------+
+    | '``rc16``' |       0        | {1, 0, 1, 0} |
+    |            +----------------+--------------+
+    |            |       1        | {3, 2, 3, 2} |
+    |            +----------------+--------------+
+    |            |       2        | {1, 0, 1, 0} |
+    |            +----------------+--------------+
+    |            |       3        | {3, 2, 3, 2} |
+    +------------+----------------+--------------+
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt)
+  }];
+
+  let assemblyFormat = [{
+    $mode $selector `,` $lo (`,` $hi^)? attr-dict `:` type($res)
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::PermuteOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
+  }];
+}
+
 def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
 def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
 def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 846b299312d67..d3c305555fde8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -448,6 +448,33 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
   return success();
 }
 
+/// Verifies that the optional 'hi' operand is present for the default, f4e
+/// and b4e modes, and absent for the rc8, ecl, ecr and rc16 modes.
+LogicalResult PermuteOp::verify() {
+  using Mode = NVVM::PermuteMode;
+  bool hasHi = static_cast<bool>(getHi());
+
+  switch (getMode()) {
+  case Mode::DEFAULT:
+  case Mode::F4E:
+  case Mode::B4E:
+    if (!hasHi)
+      return emitError("mode '")
+             << stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";
+    break;
+  case Mode::RC8:
+  case Mode::ECL:
+  case Mode::ECR:
+  case Mode::RC16:
+    if (hasHi)
+      return emitError("mode '") << stringifyPermuteMode(getMode())
+                                 << "' does not accept 'hi' operand.";
+    break;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Stochastic Rounding Conversion Ops
 //===----------------------------------------------------------------------===//
@@ -3855,6 +3882,34 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
   return {intrinsicID, args};
 }
 
+/// Returns the LLVM intrinsic ID selected by the op's permute mode together
+/// with the translated call arguments: lo, then hi (default/f4e/b4e modes
+/// only), then the selector.
+mlir::NVVM::IDArgPair
+PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                 llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::PermuteOp>(op);
+  NVVM::PermuteMode mode = thisOp.getMode();
+
+  static constexpr llvm::Intrinsic::ID IDs[] = {
+      llvm::Intrinsic::nvvm_prmt,     llvm::Intrinsic::nvvm_prmt_f4e,
+      llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
+      llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
+      llvm::Intrinsic::nvvm_prmt_rc16};
+
+  unsigned modeIndex = static_cast<unsigned>(mode);
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(thisOp.getLo()));
+
+  // Only first 3 modes (Default, f4e, b4e) need the hi operand.
+  if (modeIndex < 3)
+    args.push_back(mt.lookupValue(thisOp.getHi()));
+
+  args.push_back(mt.lookupValue(thisOp.getSelector()));
+
+  return {IDs[modeIndex], args};
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM tcgen05.mma functions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
new file mode 100644
index 0000000000000..1d6c23cc29fbc
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @invalid_default_missing_hi(%sel: i32, %lo: i32) -> i32 {
+  // expected-error @below {{mode 'default' requires 'hi' operand.}}
+  %r = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_f4e_missing_hi(%sel: i32, %lo: i32) -> i32 {
+  // expected-error @below {{mode 'f4e' requires 'hi' operand.}}
+  %r = nvvm.prmt #nvvm.permute_mode<f4e> %sel, %lo : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_b4e_missing_hi(%sel: i32, %lo: i32) -> i32 {
+  // expected-error @below {{mode 'b4e' requires 'hi' operand.}}
+  %r = nvvm.prmt #nvvm.permute_mode<b4e> %sel, %lo : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_rc8_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+  // expected-error @below {{mode 'rc8' does not accept 'hi' operand.}}
+  %r = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %lo, %hi : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_ecl_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+  // expected-error @below {{mode 'ecl' does not accept 'hi' operand.}}
+  %r = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %lo, %hi : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_ecr_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+  // expected-error @below {{mode 'ecr' does not accept 'hi' operand.}}
+  %r = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %lo, %hi : i32
+  llvm.return %r : i32
+}
+
+llvm.func @invalid_rc16_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+  // expected-error @below {{mode 'rc16' does not accept 'hi' operand.}}
+  %r = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %lo, %hi : i32
+  llvm.return %r : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
new file mode 100644
index 0000000000000..d2baae79ce95e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @test_prmt_default
+llvm.func @test_prmt_default(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_f4e
+llvm.func @test_prmt_f4e(%pos: i32, %lo: i32, %hi: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<f4e> %pos, %lo, %hi : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_b4e
+llvm.func @test_prmt_b4e(%pos: i32, %lo: i32, %hi: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.b4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<b4e> %pos, %lo, %hi : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_rc8
+llvm.func @test_prmt_rc8(%sel: i32, %val: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %val : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_ecl
+llvm.func @test_prmt_ecl(%sel: i32, %val: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.ecl(i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %val : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_ecr
+llvm.func @test_prmt_ecr(%sel: i32, %val: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.ecr(i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %val : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_rc16
+llvm.func @test_prmt_rc16(%sel: i32, %val: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt.rc16(i32 %{{.*}}, i32 %{{.*}})
+  %result = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %val : i32
+  llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_mixed
+llvm.func @test_prmt_mixed(%sel: i32, %lo: i32, %hi: i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r1 = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32
+
+  // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
+  %r2 = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %r1 : i32
+
+  // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r3 = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %r2, %sel : i32
+
+  llvm.return %r3 : i32
+}