Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1567,6 +1567,133 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
}];
}

//===----------------------------------------------------------------------===//
// Permute Bytes (Prmt)
//===----------------------------------------------------------------------===//

// Attributes for the permute operation modes supported by PTX.
def PermuteModeDefault : I32EnumAttrCase<"DEFAULT", 0, "default">;
def PermuteModeF4E : I32EnumAttrCase<"F4E", 1, "f4e">;
def PermuteModeB4E : I32EnumAttrCase<"B4E", 2, "b4e">;
def PermuteModeRC8 : I32EnumAttrCase<"RC8", 3, "rc8">;
def PermuteModeECL : I32EnumAttrCase<"ECL", 4, "ecl">;
def PermuteModeECR : I32EnumAttrCase<"ECR", 5, "ecr">;
def PermuteModeRC16 : I32EnumAttrCase<"RC16", 6, "rc16">;

def PermuteMode : I32EnumAttr<"PermuteMode", "NVVM permute mode",
[PermuteModeDefault, PermuteModeF4E,
PermuteModeB4E, PermuteModeRC8, PermuteModeECL,
PermuteModeECR, PermuteModeRC16]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}

def PermuteModeAttr : EnumAttr<NVVM_Dialect, PermuteMode, "permute_mode"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_PermuteOp : NVVM_Op<"prmt", [Pure]>,
Results<(outs I32:$res)>,
Arguments<(ins I32:$lo, Optional<I32>:$hi, I32:$selector,
PermuteModeAttr:$mode)> {
let summary = "Permute bytes from two 32-bit registers";
let description = [{
The `nvvm.prmt` operation constructs a permutation of the
bytes of the first one or two operands, selecting based on
the 2 least significant bits of the final operand.

The bytes in the first one or two source operands are numbered.
The first source operand (%lo) is numbered {b3, b2, b1, b0},
in the case of the '``default``', '``f4e``' and '``b4e``' variants,
the second source operand (%hi) is numbered {b7, b6, b5, b4}.

Modes:
- `default`: Index mode - each nibble in `selector` selects a byte from the 8-byte pool
- `f4e` : Forward 4 extract - extracts 4 contiguous bytes starting from position in `selector`
- `b4e` : Backward 4 extract - extracts 4 contiguous bytes in reverse order
- `rc8` : Replicate 8 - replicates the lower 8 bits across the 32-bit result
- `ecl` : Edge clamp left - clamps out-of-range indices to the leftmost valid byte
- `ecr` : Edge clamp right - clamps out-of-range indices to the rightmost valid byte
- `rc16` : Replicate 16 - replicates the lower 16 bits across the 32-bit result

Depending on the 2 least significant bits of the %selector operand, the result
of the permutation is defined as follows:

+------------+----------------+--------------+
| Mode | %selector[1:0] | Output |
+------------+----------------+--------------+
| '``f4e``' | 0 | {3, 2, 1, 0} |
| +----------------+--------------+
| | 1 | {4, 3, 2, 1} |
| +----------------+--------------+
| | 2 | {5, 4, 3, 2} |
| +----------------+--------------+
| | 3 | {6, 5, 4, 3} |
+------------+----------------+--------------+
| '``b4e``' | 0 | {5, 6, 7, 0} |
| +----------------+--------------+
| | 1 | {6, 7, 0, 1} |
| +----------------+--------------+
| | 2 | {7, 0, 1, 2} |
| +----------------+--------------+
| | 3 | {0, 1, 2, 3} |
+------------+----------------+--------------+
| '``rc8``' | 0 | {0, 0, 0, 0} |
| +----------------+--------------+
| | 1 | {1, 1, 1, 1} |
| +----------------+--------------+
| | 2 | {2, 2, 2, 2} |
| +----------------+--------------+
| | 3 | {3, 3, 3, 3} |
+------------+----------------+--------------+
| '``ecl``' | 0 | {3, 2, 1, 0} |
| +----------------+--------------+
| | 1 | {3, 2, 1, 1} |
| +----------------+--------------+
| | 2 | {3, 2, 2, 2} |
| +----------------+--------------+
| | 3 | {3, 3, 3, 3} |
+------------+----------------+--------------+
| '``ecr``' | 0 | {0, 0, 0, 0} |
| +----------------+--------------+
| | 1 | {1, 1, 1, 0} |
| +----------------+--------------+
| | 2 | {2, 2, 1, 0} |
| +----------------+--------------+
| | 3 | {3, 2, 1, 0} |
+------------+----------------+--------------+
| '``rc16``' | 0 | {1, 0, 1, 0} |
| +----------------+--------------+
| | 1 | {3, 2, 3, 2} |
| +----------------+--------------+
| | 2 | {1, 0, 1, 0} |
| +----------------+--------------+
| | 3 | {3, 2, 3, 2} |
+------------+----------------+--------------+

[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prmt)
}];

let assemblyFormat = [{
$mode $selector `,` $lo (`,` $hi^)? attr-dict `:` type($res)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder);
}];

string llvmBuilder = [{
auto [id, args] = NVVM::PermuteOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
$res = createIntrinsicCall(builder, id, args);
}];
}

def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
Expand Down
50 changes: 50 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,31 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
return success();
}

LogicalResult PermuteOp::verify() {
using Mode = NVVM::PermuteMode;
bool hasHi = static_cast<bool>(getHi());

switch (getMode()) {
case Mode::DEFAULT:
case Mode::F4E:
case Mode::B4E:
if (!hasHi)
return emitError("mode '")
<< stringifyPermuteMode(getMode()) << "' requires 'hi' operand.";
break;
case Mode::RC8:
case Mode::ECL:
case Mode::ECR:
case Mode::RC16:
if (hasHi)
return emitError("mode '") << stringifyPermuteMode(getMode())
<< "' does not accept 'hi' operand.";
break;
}

return success();
}

//===----------------------------------------------------------------------===//
// Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3379,6 +3404,31 @@ NVVM::IDArgPair ClusterLaunchControlQueryCancelOp::getIntrinsicIDAndArgs(
return {intrinsicID, args};
}

mlir::NVVM::IDArgPair
PermuteOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::PermuteOp>(op);
NVVM::PermuteMode mode = thisOp.getMode();

static constexpr llvm::Intrinsic::ID IDs[] = {
llvm::Intrinsic::nvvm_prmt, llvm::Intrinsic::nvvm_prmt_f4e,
llvm::Intrinsic::nvvm_prmt_b4e, llvm::Intrinsic::nvvm_prmt_rc8,
llvm::Intrinsic::nvvm_prmt_ecl, llvm::Intrinsic::nvvm_prmt_ecr,
llvm::Intrinsic::nvvm_prmt_rc16};

unsigned modeIndex = static_cast<unsigned>(mode);
llvm::SmallVector<llvm::Value *> args;
args.push_back(mt.lookupValue(thisOp.getLo()));

// Only first 3 modes (Default, f4e, b4e) need the hi operand.
if (modeIndex < 3)
args.push_back(mt.lookupValue(thisOp.getHi()));

args.push_back(mt.lookupValue(thisOp.getSelector()));

return {IDs[modeIndex], args};
}

//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma functions
//===----------------------------------------------------------------------===//
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/permute_invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s

llvm.func @invalid_default_missing_hi(%sel: i32, %lo: i32) -> i32 {
// expected-error @below {{mode 'default' requires 'hi' operand.}}
%r = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo : i32
llvm.return %r : i32
}

llvm.func @invalid_f4e_missing_hi(%sel: i32, %lo: i32) -> i32 {
// expected-error @below {{mode 'f4e' requires 'hi' operand.}}
%r = nvvm.prmt #nvvm.permute_mode<f4e> %sel, %lo : i32
llvm.return %r : i32
}

llvm.func @invalid_b4e_missing_hi(%sel: i32, %lo: i32) -> i32 {
// expected-error @below {{mode 'b4e' requires 'hi' operand.}}
%r = nvvm.prmt #nvvm.permute_mode<b4e> %sel, %lo : i32
llvm.return %r : i32
}

llvm.func @invalid_rc8_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
// expected-error @below {{mode 'rc8' does not accept 'hi' operand.}}
%r = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %lo, %hi : i32
llvm.return %r : i32
}

llvm.func @invalid_ecl_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
// expected-error @below {{mode 'ecl' does not accept 'hi' operand.}}
%r = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %lo, %hi : i32
llvm.return %r : i32
}

llvm.func @invalid_ecr_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
// expected-error @below {{mode 'ecr' does not accept 'hi' operand.}}
%r = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %lo, %hi : i32
llvm.return %r : i32
}

llvm.func @invalid_rc16_with_hi(%sel: i32, %lo: i32, %hi: i32) -> i32 {
// expected-error @below {{mode 'rc16' does not accept 'hi' operand.}}
%r = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %lo, %hi : i32
llvm.return %r : i32
}
64 changes: 64 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/permute_valid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: @test_prmt_default
llvm.func @test_prmt_default(%sel: i32, %lo: i32, %hi: i32) -> i32 {
// CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%result = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32
llvm.return %result : i32
}

// CHECK-LABEL: @test_prmt_f4e
llvm.func @test_prmt_f4e(%pos: i32, %lo: i32, %hi: i32) -> i32 {
// CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%result = nvvm.prmt #nvvm.permute_mode<f4e> %pos, %lo, %hi : i32
llvm.return %result : i32
}

// CHECK-LABEL: @test_prmt_b4e
llvm.func @test_prmt_b4e(%pos: i32, %lo: i32, %hi: i32) -> i32 {
// CHECK: call i32 @llvm.nvvm.prmt.b4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%result = nvvm.prmt #nvvm.permute_mode<b4e> %pos, %lo, %hi : i32
llvm.return %result : i32
}

// CHECK-LABEL: @test_prmt_rc8
llvm.func @test_prmt_rc8(%sel: i32, %val: i32) -> i32 {
// CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
%result = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %val : i32
llvm.return %result : i32
}

// CHECK-LABEL: @test_prmt_ecl
llvm.func @test_prmt_ecl(%sel: i32, %val: i32) -> i32 {
// CHECK: call i32 @llvm.nvvm.prmt.ecl(i32 %{{.*}}, i32 %{{.*}})
%result = nvvm.prmt #nvvm.permute_mode<ecl> %sel, %val : i32
llvm.return %result : i32
}

// CHECK-LABEL: @test_prmt_ecr
llvm.func @test_prmt_ecr(%sel: i32, %val: i32) -> i32 {
// CHECK: call i32 @llvm.nvvm.prmt.ecr(i32 %{{.*}}, i32 %{{.*}})
%result = nvvm.prmt #nvvm.permute_mode<ecr> %sel, %val : i32
llvm.return %result : i32
}

// CHECK-LABEL: @test_prmt_rc16
llvm.func @test_prmt_rc16(%sel: i32, %val: i32) -> i32 {
// CHECK: call i32 @llvm.nvvm.prmt.rc16(i32 %{{.*}}, i32 %{{.*}})
%result = nvvm.prmt #nvvm.permute_mode<rc16> %sel, %val : i32
llvm.return %result : i32
}

// CHECK-LABEL: @test_prmt_mixed
llvm.func @test_prmt_mixed(%sel: i32, %lo: i32, %hi: i32) -> i32 {
// CHECK: call i32 @llvm.nvvm.prmt(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%r1 = nvvm.prmt #nvvm.permute_mode<default> %sel, %lo, %hi : i32

// CHECK: call i32 @llvm.nvvm.prmt.rc8(i32 %{{.*}}, i32 %{{.*}})
%r2 = nvvm.prmt #nvvm.permute_mode<rc8> %sel, %r1 : i32

// CHECK: call i32 @llvm.nvvm.prmt.f4e(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%r3 = nvvm.prmt #nvvm.permute_mode<f4e> %lo, %r2, %sel : i32

llvm.return %r3 : i32
}