[MLIR][NVVM] Add Permute Op #169793
This patch adds the permute op (`nvvm.prmt`) to the NVVM dialect. Lit tests are added to verify the lowering to the intrinsics.
PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt
Signed-off-by: Dharuni R Acharya <dharunira@nvidia.com>
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm
Author: Dharuni R Acharya (DharuniRAcharya)
Changes: This patch adds the `permute` op. Lit tests are added to verify the lowering to the intrinsics.
PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt
Full diff: https://github.com/llvm/llvm-project/pull/169793.diff
4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d78145d690fc8..0b3ca385d0a78 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1567,6 +1567,133 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
}];
}
+//===----------------------------------------------------------------------===//
+// Permute Bytes (Prmt)
+//===----------------------------------------------------------------------===//
+
+// Attributes for the permute operation modes supported by PTX.
+def PermuteModeDefault : I32EnumAttrCase<"DEFAULT", 0, "default">;
+def PermuteModeF4E : I32EnumAttrCase<"F4E", 1, "f4e">;
+def PermuteModeB4E : I32EnumAttrCase<"B4E", 2, "b4e">;
+def PermuteModeRC8 : I32EnumAttrCase<"RC8", 3, "rc8">;
+def PermuteModeECL : I32EnumAttrCase<"ECL", 4, "ecl">;
+def PermuteModeECR : I32EnumAttrCase<"ECR", 5, "ecr">;
+def PermuteModeRC16 : I32EnumAttrCase<"RC16", 6, "rc16">;
+
+def PermuteMode : I32EnumAttr<"PermuteMode", "NVVM permute mode",
+ [PermuteModeDefault, PermuteModeF4E,
+ PermuteModeB4E, PermuteModeRC8, PermuteModeECL,
+ PermuteModeECR, PermuteModeRC16]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+
+def PermuteModeAttr : EnumAttr<NVVM_Dialect, PermuteMode, "permute_mode"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_PermuteOp : NVVM_Op<"prmt", [Pure]>,
+ Results<(outs I32:$res)>,
+ Arguments<(ins I32:$lo, Optional<I32>:$hi, I32:$selector,
+ PermuteModeAttr:$mode)> {
+ let summary = "Permute bytes from two 32-bit registers";
+ let description = [{
+ The `nvvm.prmt` operation constructs a permutation of the
+ bytes of `%lo` and, where present, `%hi`, selecting based on `%selector`:
+ four 4-bit fields in `default` mode, the 2 least significant bits otherwise.
+
+ The bytes in the first one or two source operands are numbered.
+ The first source operand (%lo) is numbered {b3, b2, b1, b0},
+ in the case of the '``default``', '``f4e``' and '``b4e``' variants,
+ the second source operand (%hi) is numbered {b7, b6, b5, b4}.
+
+ Modes:
+ - `default`: Index mode - each of the four nibbles in the low 16 bits of `selector` selects a byte from the 8-byte pool
+ - `f4e` : Forward 4 extract - extracts 4 contiguous bytes in ascending order, starting at the byte index in `selector[1:0]`
+ - `b4e` : Backward 4 extract - extracts 4 bytes in descending order from the byte index in `selector[1:0]`, wrapping around the 8-byte pool
+ - `rc8` : Replicate 8 - replicates the byte of `%lo` chosen by `selector[1:0]` across the 32-bit result
+ - `ecl` : Edge clamp left - result bytes below position `selector[1:0]` take the byte at that position; the rest pass through
+ - `ecr` : Edge clamp right - result bytes above position `selector[1:0]` take the byte at that position; the rest pass through
+ - `rc16` : Replicate 16 - replicates the 16-bit half of `%lo` chosen by `selector[0]` across the 32-bit result
+
+ Depending on the 2 least significant bits of the %selector operand, the result
+ of the permutation is defined as follows:
+
+ +------------+----------------+--------------+
+ | Mode | %selector[1:0] | Output |
+ +------------+----------------+--------------+
+ | '``f4e``' | 0 | {3, 2, 1, 0} |
+ | +----------------+--------------+
+ | | 1 | {4, 3, 2, 1} |
+ | +----------------+--------------+
+ | | 2 | {5, 4, 3, 2} |
+ | +----------------+--------------+
+ | | 3 | {6, 5, 4, 3} |
+ +------------+----------------+--------------+
+ | '``b4e``' | 0 | {5, 6, 7, 0} |
+ | +----------------+--------------+
+ | | 1 | {6, 7, 0, 1} |
+ | +----------------+--------------+
+ | | 2 | {7, 0, 1, 2} |
+ | +----------------+--------------+
+ | | 3 | {0, 1, 2, 3} |
+ +------------+----------------+--------------+
+ | '``rc8``' | 0 | {0, 0, 0, 0} |
+ | +----------------+--------------+
+ | | 1 | {1, 1, 1, 1} |
+ | +----------------+--------------+
+ | | 2 | {2, 2, 2, 2} |
+ | +----------------+--------------+
+ | | 3 | {3, 3, 3, 3} |
+ +------------+----------------+--------------+
+ | '``ecl``' | 0 | {3, 2, 1, 0} |
+ | +----------------+--------------+
+ | | 1 | {3, 2, 1, 1} |
+ | +----------------+--------------+
+ | | 2 | {3, 2, 2, 2} |
+ | +----------------+--------------+
+ | | 3 | {3, 3, 3, 3} |
+ +------------+----------------+--------------+
+ | '``ecr``' | 0 | {0, 0, 0, 0} |
+ | +----------------+--------------+
+ | | 1 | {1, 1, 1, 0} |
+ | +----------------+--------------+
+ | | 2 | {2, 2, 1, 0} |
+ | +----------------+--------------+
+ | | 3 | {3, 2, 1, 0} |
+ +------------+----------------+--------------+
+ | '``rc16``' | 0 | {1, 0, 1, 0} |
+ | +----------------+--------------+
+ | | 1 | {3, 2, 3, 2} |
+ | +----------------+--------------+
+ | | 2 | {1, 0, 1, 0} |
+ | +----------------+--------------+
+ | | 3 | {3, 2, 3, 2} |
+ +------------+----------------+--------------+
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt)
+ }];
+
+ let assemblyFormat = [{
+ $mode $lo `,` $selector (`,` $hi^)? attr-dict `:` type($res)
+ }];
+
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(NVVM::PermuteMode mode, llvm::Value *lo,
+ llvm::Value *hi, llvm::Value *selector);
+ }];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::PermuteOp::getIntrinsicIDAndArgs(
+ $mode, $lo, $hi, $selector);
+ $res = createIntrinsicCall(builder, id, args);
+ }];
+}
+
def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 428bc72c88a30..cfe5a98929467 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -448,6 +448,33 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
return success();
}
+LogicalResult PermuteOp::verify() {
+ using Mode = NVVM::PermuteMode;
+ bool hasHi = static_cast<bool>(getHi());
+
+ switch (getMode()) {
+ case Mode::DEFAULT:
+ case Mode::F4E:
+ case Mode::B4E:
+ if (!hasHi)
+ return emitError("mode '") << stringifyPermuteMode(getMode())
+ << "' requires 'hi' operand i.e. it requires "
+ "3 operands - lo, hi, selector.";
+ break;
+ case Mode::RC8:
+ case Mode::ECL:
+ case Mode::ECR:
+ case Mode::RC16:
+ if (hasHi)
+ return emitError("mode '") << stringifyPermuteMode(getMode())
+ << "' does not accept 'hi' operand i.e. it "
+ "requires 2 operands - lo, selector.";
+ break;
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//
@@ -3379,6 +3406,29 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
return {intrinsicID, args};
}
+mlir::NVVM::IDArgPair PermuteOp::getIntrinsicIDAndArgs(NVVM::PermuteMode mode,
+ llvm::Value *lo,
+ llvm::Value *hi,
+ llvm::Value *selector) {
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
+ llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
+ llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
+ llvm::Intrinsic::nvvm_prmt_rc16};
+
+ unsigned modeIndex = static_cast<unsigned>(mode);
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(lo);
+
+ if (modeIndex < 3)
+ args.push_back(hi);
+
+ args.push_back(selector);
+
+ return {IDs[modeIndex], args};
+}
+
//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma functions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
new file mode 100644
index 0000000000000..06beec2e4b78b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @invalid_default_missing_hi(%lo: i32, %sel: i32) -> i32 {
+ // expected-error @below {{mode 'default' requires 'hi' operand i.e. it requires 3 operands - lo, hi, selector}}
+ %r = nvvm.prmt #nvvm.permute_mode<default> %lo, %sel : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_f4e_missing_hi(%lo: i32, %sel: i32) -> i32 {
+ // expected-error @below {{mode 'f4e' requires 'hi' operand i.e. it requires 3 operands - lo, hi, selector}}
+ %r = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %sel : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_b4e_missing_hi(%lo: i32, %sel: i32) -> i32 {
+ // expected-error @below {{mode 'b4e' requires 'hi' operand i.e. it requires 3 operands - lo, hi, selector}}
+ %r = nvvm.prmt #nvvm.permute_mode<b4e> %lo, %sel : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_rc8_with_hi(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+ // expected-error @below {{mode 'rc8' does not accept 'hi' operand i.e. it requires 2 operands - lo, selector}}
+ %r = nvvm.prmt #nvvm.permute_mode<rc8> %lo, %sel, %hi : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_ecl_with_hi(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+ // expected-error @below {{mode 'ecl' does not accept 'hi' operand i.e. it requires 2 operands - lo, selector}}
+ %r = nvvm.prmt #nvvm.permute_mode<ecl> %lo, %sel, %hi : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_ecr_with_hi(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+ // expected-error @below {{mode 'ecr' does not accept 'hi' operand i.e. it requires 2 operands - lo, selector}}
+ %r = nvvm.prmt #nvvm.permute_mode<ecr> %lo, %sel, %hi : i32
+ llvm.return %r : i32
+}
+
+llvm.func @invalid_rc16_with_hi(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+ // expected-error @below {{mode 'rc16' does not accept 'hi' operand i.e. it requires 2 operands - lo, selector}}
+ %r = nvvm.prmt #nvvm.permute_mode<rc16> %lo, %sel, %hi : i32
+ llvm.return %r : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
new file mode 100644
index 0000000000000..cfe150bfbac01
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @test_prmt_default
+llvm.func @test_prmt_default(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<default> %lo, %sel, %hi : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_f4e
+llvm.func @test_prmt_f4e(%lo: i32, %pos: i32, %hi: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %pos, %hi : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_b4e
+llvm.func @test_prmt_b4e(%lo: i32, %pos: i32, %hi: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.b4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<b4e> %lo, %pos, %hi : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_rc8
+llvm.func @test_prmt_rc8(%val: i32, %sel: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<rc8> %val, %sel : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_ecl
+llvm.func @test_prmt_ecl(%val: i32, %sel: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.ecl(i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<ecl> %val, %sel : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_ecr
+llvm.func @test_prmt_ecr(%val: i32, %sel: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.ecr(i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<ecr> %val, %sel : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_rc16
+llvm.func @test_prmt_rc16(%val: i32, %sel: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt.rc16(i32 %{{.*}}, i32 %{{.*}})
+ %result = nvvm.prmt #nvvm.permute_mode<rc16> %val, %sel : i32
+ llvm.return %result : i32
+}
+
+// CHECK-LABEL: @test_prmt_mixed
+llvm.func @test_prmt_mixed(%lo: i32, %sel: i32, %hi: i32) -> i32 {
+ // CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r1 = nvvm.prmt #nvvm.permute_mode<default> %lo, %sel, %hi : i32
+
+ // CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
+ %r2 = nvvm.prmt #nvvm.permute_mode<rc8> %r1, %sel : i32
+
+ // CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %r3 = nvvm.prmt #nvvm.permute_mode<f4e> %r2, %lo, %sel : i32
+
+ llvm.return %r3 : i32
+}
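The mode table in the `NVVM_PermuteOp` description can be read as a lookup from (mode, `selector[1:0]`) to four source byte indices. The sketch below is a host-side C++ model of that table and is not part of the patch; `emulatePrmt`, `poolByte`, `kTable`, and the local `PermuteMode` enum are hypothetical names chosen for illustration. The `default` branch additionally assumes the PTX ISA rule that bit 3 of each selector nibble requests sign replication of the chosen byte, which the op description does not spell out.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the patch's PermuteMode enum (same order and values).
enum class PermuteMode { DEFAULT, F4E, B4E, RC8, ECL, ECR, RC16 };

// Source byte indices listed as {d.b3, d.b2, d.b1, d.b0}, copied row by row
// from the table in the op description.
// First index: mode (F4E..RC16), second index: selector[1:0].
constexpr std::uint8_t kTable[6][4][4] = {
    /* f4e  */ {{3, 2, 1, 0}, {4, 3, 2, 1}, {5, 4, 3, 2}, {6, 5, 4, 3}},
    /* b4e  */ {{5, 6, 7, 0}, {6, 7, 0, 1}, {7, 0, 1, 2}, {0, 1, 2, 3}},
    /* rc8  */ {{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}},
    /* ecl  */ {{3, 2, 1, 0}, {3, 2, 1, 1}, {3, 2, 2, 2}, {3, 3, 3, 3}},
    /* ecr  */ {{0, 0, 0, 0}, {1, 1, 1, 0}, {2, 2, 1, 0}, {3, 2, 1, 0}},
    /* rc16 */ {{1, 0, 1, 0}, {3, 2, 3, 2}, {1, 0, 1, 0}, {3, 2, 3, 2}},
};

// Byte `index` (0..7) of the pool: hi holds {b7..b4}, lo holds {b3..b0}.
inline std::uint8_t poolByte(std::uint32_t lo, std::uint32_t hi,
                             unsigned index) {
  assert(index < 8 && "byte index must address the 8-byte pool");
  const std::uint64_t pool = (static_cast<std::uint64_t>(hi) << 32) | lo;
  return static_cast<std::uint8_t>(pool >> (8 * index));
}

// Model of the byte selection. For the two-operand modes (rc8, ecl, ecr,
// rc16) pass hi = 0; their table rows never index above b3.
inline std::uint32_t emulatePrmt(PermuteMode mode, std::uint32_t lo,
                                 std::uint32_t hi, std::uint32_t selector) {
  std::uint32_t result = 0;
  for (unsigned pos = 0; pos < 4; ++pos) { // pos 0 produces d.b0
    std::uint8_t byte = 0;
    if (mode == PermuteMode::DEFAULT) {
      const unsigned nibble = (selector >> (4 * pos)) & 0xF;
      byte = poolByte(lo, hi, nibble & 0x7);
      // Assumption taken from the PTX ISA: nibble bit 3 replicates the sign.
      if (nibble & 0x8)
        byte = (byte & 0x80) ? 0xFF : 0x00;
    } else {
      const unsigned row = static_cast<unsigned>(mode) - 1;
      // Table rows are written d.b3 first, so d.b0 sits at column 3.
      byte = poolByte(lo, hi, kTable[row][selector & 0x3][3 - pos]);
    }
    result |= static_cast<std::uint32_t>(byte) << (8 * pos);
  }
  return result;
}
```

For example, with `lo = 0x33221100` and `hi = 0x77665544`, `emulatePrmt(PermuteMode::F4E, lo, hi, 1)` follows the `{4, 3, 2, 1}` row and yields `0x44332211`.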
|
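One detail in the diff that is easy to miss: the assembly format lists operands as `lo, selector, hi`, while `getIntrinsicIDAndArgs` passes them to the intrinsic as `lo, hi, selector`. In `test_prmt_mixed`, the f4e line `%r2, %lo, %sel` therefore binds `%lo` as the selector and `%sel` as `hi`. The following standalone sketch restates the verifier rule and the argument ordering without any LLVM dependency. `Lowering`, `modeTakesHi`, `intrinsicName`, and `lowerPrmt` are hypothetical names, operands are modelled as plain strings, and a null `hi` pointer stands in for the absent optional operand. The explicit `switch` is one possible alternative to the `modeIndex < 3` comparison, not what the patch does.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical mirror of the patch's PermuteMode enum (same order and values).
enum class PermuteMode { DEFAULT, F4E, B4E, RC8, ECL, ECR, RC16 };

// True for the modes whose verifier branch requires the `hi` operand.
inline bool modeTakesHi(PermuteMode mode) {
  switch (mode) {
  case PermuteMode::DEFAULT:
  case PermuteMode::F4E:
  case PermuteMode::B4E:
    return true;
  case PermuteMode::RC8:
  case PermuteMode::ECL:
  case PermuteMode::ECR:
  case PermuteMode::RC16:
    return false;
  }
  return false;
}

// Intrinsic names as they appear in the CHECK lines of permute_valid.mlir.
inline const char *intrinsicName(PermuteMode mode) {
  static constexpr const char *names[] = {
      "llvm.nvvm.prmt",     "llvm.nvvm.prmt.f4e", "llvm.nvvm.prmt.b4e",
      "llvm.nvvm.prmt.rc8", "llvm.nvvm.prmt.ecl", "llvm.nvvm.prmt.ecr",
      "llvm.nvvm.prmt.rc16"};
  const unsigned index = static_cast<unsigned>(mode);
  assert(index < 7 && "unknown permute mode");
  return names[index];
}

struct Lowering {
  bool ok;                       // false where the verifier would reject
  std::string callee;            // intrinsic name
  std::vector<std::string> args; // intrinsic argument order
};

// Parameters follow the assembly order (lo, selector, hi); the returned
// argument list follows the intrinsic order (lo, [hi], selector).
inline Lowering lowerPrmt(PermuteMode mode, const std::string &lo,
                          const std::string &selector,
                          const std::string *hi) {
  if (modeTakesHi(mode) != (hi != nullptr))
    return Lowering{false, "", {}};
  Lowering out{true, intrinsicName(mode), {lo}};
  if (hi)
    out.args.push_back(*hi);
  out.args.push_back(selector);
  return out;
}
```

Calling `lowerPrmt(PermuteMode::F4E, "%lo", "%sel", &hi)` with `hi` set to `"%hi"` produces the argument list `%lo, %hi, %sel`, while passing a `hi` pointer together with `PermuteMode::RC8` is rejected, matching the negative tests.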
grypp
left a comment
Nice, the PR is very clean and in good shape.
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
durga4github
left a comment
The latest revision LGTM
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/22448
This patch adds the `permute` op. Lit tests are added to verify the lowering to the intrinsics. Negative tests are also added to check the error-handling of invalid combinations. PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt Signed-off-by: Dharuni R Acharya <dharunira@nvidia.com>