Skip to content

Commit

Permalink
[mlir][gpu] Extend shuffle op modes and add nvvm lowering
Browse files Browse the repository at this point in the history
Add up, down and idx modes to gpu shuffle ops, also change the mode from
string to enum

Differential Revision: https://reviews.llvm.org/D114188
  • Loading branch information
ThomasRaoux committed Nov 19, 2021
1 parent ff7f2cf commit 47555d7
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 67 deletions.
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/GPU/GPUOps.td
Expand Up @@ -647,13 +647,21 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
}

def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
def GPU_ShuffleOpDown : StrEnumAttrCase<"DOWN", -1, "down">;
def GPU_ShuffleOpUp : StrEnumAttrCase<"UP", -1, "up">;
def GPU_ShuffleOpIdx : StrEnumAttrCase<"IDX", -1, "idx">;

def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
"Indexing modes supported by gpu.shuffle.",
[
GPU_ShuffleOpXor,
GPU_ShuffleOpXor, GPU_ShuffleOpUp, GPU_ShuffleOpDown, GPU_ShuffleOpIdx,
]>{
let cppNamespace = "::mlir::gpu";
let storageType = "mlir::StringAttr";
let returnType = "::mlir::gpu::ShuffleModeAttr";
let convertFromStorage =
"*symbolizeEnum<::mlir::gpu::ShuffleModeAttr>($_self.getValue())";
let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
}


Expand Down
36 changes: 27 additions & 9 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Expand Up @@ -97,22 +97,36 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
let assemblyFormat = "attr-dict";
}

def NVVM_ShflBflyOp :
NVVM_Op<"shfl.sync.bfly">,
def ShflKindBfly : StrEnumAttrCase<"bfly">;
def ShflKindUp : StrEnumAttrCase<"up">;
def ShflKindDown : StrEnumAttrCase<"down">;
def ShflKindIdx : StrEnumAttrCase<"idx">;

/// Enum attribute of the different shuffle kinds.
def ShflKind : StrEnumAttr<"ShflKind", "NVVM shuffle kind",
[ShflKindBfly, ShflKindUp, ShflKindDown, ShflKindIdx]> {
let cppNamespace = "::mlir::NVVM";
let storageType = "mlir::StringAttr";
let returnType = "NVVM::ShflKind";
let convertFromStorage = "*symbolizeEnum<NVVM::ShflKind>($_self.getValue())";
let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
}

def NVVM_ShflOp :
NVVM_Op<"shfl.sync">,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$dst,
Arguments<(ins I32:$dst,
LLVM_Type:$val,
LLVM_Type:$offset,
LLVM_Type:$mask_and_clamp,
I32:$offset,
I32:$mask_and_clamp,
ShflKind:$kind,
OptionalAttr<UnitAttr>:$return_value_and_is_valid)> {
string llvmBuilder = [{
auto intId = getShflBflyIntrinsicId(
$_resultType, static_cast<bool>($return_value_and_is_valid));
auto intId = getShflIntrinsicId(
$_resultType, $kind, static_cast<bool>($return_value_and_is_valid));
$res = createIntrinsicCall(builder,
intId, {$dst, $val, $offset, $mask_and_clamp});
}];
let parser = [{ return parseNVVMShflSyncBflyOp(parser, result); }];
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
let verifier = [{
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
Expand All @@ -125,6 +139,10 @@ def NVVM_ShflBflyOp :
"i1 as the second element");
return success();
}];
let assemblyFormat = [{
$kind $dst `,` $val `,` $offset `,` $mask_and_clamp attr-dict
`:` type($val) `->` type($res)
}];
}

def NVVM_VoteBallotOp :
Expand Down
19 changes: 17 additions & 2 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Expand Up @@ -39,6 +39,21 @@ using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleModeAttr mode) {
switch (mode) {
case gpu::ShuffleModeAttr::XOR:
return NVVM::ShflKind::bfly;
case gpu::ShuffleModeAttr::UP:
return NVVM::ShflKind::up;
case gpu::ShuffleModeAttr::DOWN:
return NVVM::ShflKind::down;
case gpu::ShuffleModeAttr::IDX:
return NVVM::ShflKind::idx;
}
llvm_unreachable("unknown shuffle mode");
}

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

Expand Down Expand Up @@ -81,9 +96,9 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.width(), one);

auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
Value shfl = rewriter.create<NVVM::ShflBflyOp>(
Value shfl = rewriter.create<NVVM::ShflOp>(
loc, resultTy, activeMask, adaptor.value(), adaptor.offset(),
maskAndClamp, returnValueAndIsValidAttr);
maskAndClamp, convertShflKind(op.mode()), returnValueAndIsValidAttr);
Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
loc, valueTy, shfl, rewriter.getIndexArrayAttr(0));
Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Expand Up @@ -302,7 +302,7 @@ static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
}

static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
p << ' ' << op.getOperands() << ' ' << op.mode() << " : "
p << ' ' << op.getOperands() << ' ' << stringifyEnum(op.mode()) << " : "
<< op.value().getType();
}

Expand Down
27 changes: 0 additions & 27 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Expand Up @@ -43,33 +43,6 @@ static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
p << " : " << op->getResultTypes();
}

// <operation> ::=
// `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
// ({return_value_and_is_valid})? : result_type
static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> ops;
Type resultType;
if (parser.parseOperandList(ops) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(resultType) ||
parser.addTypeToList(resultType, result.types))
return failure();

for (auto &attr : result.attributes) {
if (attr.getName() != "return_value_and_is_valid")
continue;
auto structType = resultType.dyn_cast<LLVM::LLVMStructType>();
if (structType && !structType.getBody().empty())
resultType = structType.getBody()[0];
break;
}

auto int32Ty = IntegerType::get(parser.getContext(), 32);
return parser.resolveOperands(ops, {int32Ty, resultType, int32Ty, int32Ty},
parser.getNameLoc(), result.operands);
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
OperationState &result) {
Expand Down
42 changes: 36 additions & 6 deletions mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Expand Up @@ -23,15 +23,45 @@ using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;

static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
bool withPredicate) {
static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
NVVM::ShflKind kind,
bool withPredicate) {

if (withPredicate) {
resultType = cast<llvm::StructType>(resultType)->getElementType(0);
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
switch (kind) {
case NVVM::ShflKind::bfly:
return resultType->isFloatTy()
? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
case NVVM::ShflKind::up:
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
: llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
case NVVM::ShflKind::down:
return resultType->isFloatTy()
? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
: llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
case NVVM::ShflKind::idx:
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
: llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
}
} else {
switch (kind) {
case NVVM::ShflKind::bfly:
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
case NVVM::ShflKind::up:
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
: llvm::Intrinsic::nvvm_shfl_sync_up_i32;
case NVVM::ShflKind::down:
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
: llvm::Intrinsic::nvvm_shfl_sync_down_i32;
case NVVM::ShflKind::idx:
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
: llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
}
}
return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
: llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
llvm_unreachable("unknown shuffle kind");
}

namespace {
Expand Down
18 changes: 12 additions & 6 deletions mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
Expand Up @@ -78,7 +78,7 @@ gpu.module @test_module {
gpu.func @gpu_all_reduce_op() {
%arg0 = arith.constant 1.0 : f32
// TODO: Check full IR expansion once lowering has settled.
// CHECK: nvvm.shfl.sync.bfly
// CHECK: nvvm.shfl.sync "bfly" {{.*}}
// CHECK: nvvm.barrier0
// CHECK: llvm.fadd
%result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
Expand All @@ -94,7 +94,7 @@ gpu.module @test_module {
gpu.func @gpu_all_reduce_region() {
%arg0 = arith.constant 1 : i32
// TODO: Check full IR expansion once lowering has settled.
// CHECK: nvvm.shfl.sync.bfly
// CHECK: nvvm.shfl.sync "bfly" {{.*}}
// CHECK: nvvm.barrier0
%result = "gpu.all_reduce"(%arg0) ({
^bb(%lhs : i32, %rhs : i32):
Expand All @@ -109,7 +109,7 @@ gpu.module @test_module {

gpu.module @test_module {
// CHECK-LABEL: func @gpu_shuffle()
builtin.func @gpu_shuffle() -> (f32) {
builtin.func @gpu_shuffle() -> (f32, f32, f32, f32) {
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
%arg0 = arith.constant 1.0 : f32
// CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
Expand All @@ -120,12 +120,18 @@ gpu.module @test_module {
// CHECK: %[[#SHL:]] = llvm.shl %[[#ONE]], %[[#WIDTH]] : i32
// CHECK: %[[#MASK:]] = llvm.sub %[[#SHL]], %[[#ONE]] : i32
// CHECK: %[[#CLAMP:]] = llvm.sub %[[#WIDTH]], %[[#ONE]] : i32
// CHECK: %[[#SHFL:]] = nvvm.shfl.sync.bfly %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] : !llvm.struct<(f32, i1)>
// CHECK: %[[#SHFL:]] = nvvm.shfl.sync "bfly" %[[#MASK]], %[[#VALUE]], %[[#OFFSET]], %[[#CLAMP]] {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
// CHECK: llvm.extractvalue %[[#SHFL]][0 : index] : !llvm.struct<(f32, i1)>
// CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm.struct<(f32, i1)>
%shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)

std.return %shfl : f32
// CHECK: nvvm.shfl.sync "up" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
%shflu, %predu = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "up" } : (f32, i32, i32) -> (f32, i1)
// CHECK: nvvm.shfl.sync "down" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
%shfld, %predd = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "down" } : (f32, i32, i32) -> (f32, i1)
// CHECK: nvvm.shfl.sync "idx" {{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
%shfli, %predi = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "idx" } : (f32, i32, i32) -> (f32, i1)

std.return %shfl, %shflu, %shfld, %shfli : f32, f32,f32, f32
}
}

Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/GPU/ops.mlir
Expand Up @@ -55,6 +55,12 @@ module attributes {gpu.container_module} {
%offset = arith.constant 3 : i32
// CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} xor : f32
%shfl, %pred = gpu.shuffle %arg0, %offset, %width xor : f32
// CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} up : f32
%shfl1, %pred1 = gpu.shuffle %arg0, %offset, %width up : f32
// CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} down : f32
%shfl2, %pred2 = gpu.shuffle %arg0, %offset, %width down : f32
// CHECK: gpu.shuffle %{{.*}}, %{{.*}}, %{{.*}} idx : f32
%shfl3, %pred3 = gpu.shuffle %arg0, %offset, %width idx : f32

"gpu.barrier"() : () -> ()

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Expand Up @@ -495,21 +495,21 @@ func @null_non_llvm_type() {

func @nvvm_invalid_shfl_pred_1(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
// expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32
%0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> i32
}

// -----

func @nvvm_invalid_shfl_pred_2(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
// expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32)>
%0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32)>
}

// -----

func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32) {
// expected-error@+1 {{expected return type to be a two-element struct with i1 as the second element}}
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32, i32)>
%0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i32)>
}

// -----
Expand Down
22 changes: 14 additions & 8 deletions mlir/test/Dialect/LLVMIR/nvvm.mlir
Expand Up @@ -37,20 +37,26 @@ func @llvm.nvvm.barrier0() {
func @nvvm_shfl(
%arg0 : i32, %arg1 : i32, %arg2 : i32,
%arg3 : i32, %arg4 : f32) -> i32 {
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 : i32
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
%1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 : f32
// CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32 -> i32
%0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 : i32 -> i32
// CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
%1 = nvvm.shfl.sync "bfly" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
// CHECK: nvvm.shfl.sync "up" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
%2 = nvvm.shfl.sync "up" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
// CHECK: nvvm.shfl.sync "down" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
%3 = nvvm.shfl.sync "down" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
// CHECK: nvvm.shfl.sync "idx" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
%4 = nvvm.shfl.sync "idx" %arg0, %arg4, %arg1, %arg2 : f32 -> f32
llvm.return %0 : i32
}

func @nvvm_shfl_pred(
%arg0 : i32, %arg1 : i32, %arg2 : i32,
%arg3 : i32, %arg4 : f32) -> !llvm.struct<(i32, i1)> {
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.struct<(i32, i1)>
%0 = nvvm.shfl.sync.bfly %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(i32, i1)>
// CHECK: nvvm.shfl.sync.bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !llvm.struct<(f32, i1)>
%1 = nvvm.shfl.sync.bfly %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : !llvm.struct<(f32, i1)>
// CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
%0 = nvvm.shfl.sync "bfly" %arg0, %arg3, %arg1, %arg2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
// CHECK: nvvm.shfl.sync "bfly" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
%1 = nvvm.shfl.sync "bfly" %arg0, %arg4, %arg1, %arg2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
llvm.return %0 : !llvm.struct<(i32, i1)>
}

Expand Down
32 changes: 28 additions & 4 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Expand Up @@ -42,19 +42,43 @@ llvm.func @nvvm_shfl(
%0 : i32, %1 : i32, %2 : i32,
%3 : i32, %4 : f32) -> i32 {
// CHECK: call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 : i32
%6 = nvvm.shfl.sync "bfly" %0, %3, %1, %2 : i32 -> i32
// CHECK: call float @llvm.nvvm.shfl.sync.bfly.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 : f32
%7 = nvvm.shfl.sync "bfly" %0, %4, %1, %2 : f32 -> f32
// CHECK: call i32 @llvm.nvvm.shfl.sync.up.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%8 = nvvm.shfl.sync "up" %0, %3, %1, %2 : i32 -> i32
// CHECK: call float @llvm.nvvm.shfl.sync.up.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%9 = nvvm.shfl.sync "up" %0, %4, %1, %2 : f32 -> f32
// CHECK: call i32 @llvm.nvvm.shfl.sync.down.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%10 = nvvm.shfl.sync "down" %0, %3, %1, %2 : i32 -> i32
// CHECK: call float @llvm.nvvm.shfl.sync.down.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%11 = nvvm.shfl.sync "down" %0, %4, %1, %2 : f32 -> f32
// CHECK: call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%12 = nvvm.shfl.sync "idx" %0, %3, %1, %2 : i32 -> i32
// CHECK: call float @llvm.nvvm.shfl.sync.idx.f32(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%13 = nvvm.shfl.sync "idx" %0, %4, %1, %2 : f32 -> f32
llvm.return %6 : i32
}

llvm.func @nvvm_shfl_pred(
%0 : i32, %1 : i32, %2 : i32,
%3 : i32, %4 : f32) -> !llvm.struct<(i32, i1)> {
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.bfly.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%6 = nvvm.shfl.sync.bfly %0, %3, %1, %2 {return_value_and_is_valid} : !llvm.struct<(i32, i1)>
%6 = nvvm.shfl.sync "bfly" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.bfly.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%7 = nvvm.shfl.sync.bfly %0, %4, %1, %2 {return_value_and_is_valid} : !llvm.struct<(f32, i1)>
%7 = nvvm.shfl.sync "bfly" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.up.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%8 = nvvm.shfl.sync "up" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.up.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%9 = nvvm.shfl.sync "up" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.down.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%10 = nvvm.shfl.sync "down" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.down.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%11 = nvvm.shfl.sync "down" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
// CHECK: call { i32, i1 } @llvm.nvvm.shfl.sync.idx.i32p(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%12 = nvvm.shfl.sync "idx" %0, %3, %1, %2 {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
// CHECK: call { float, i1 } @llvm.nvvm.shfl.sync.idx.f32p(i32 %{{.*}}, float %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%13 = nvvm.shfl.sync "idx" %0, %4, %1, %2 {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
llvm.return %6 : !llvm.struct<(i32, i1)>
}

Expand Down

0 comments on commit 47555d7

Please sign in to comment.