Skip to content

Commit

Permalink
[NVPTX] Make tensor load/store intrinsics overloaded.
Browse files Browse the repository at this point in the history
This way we can support address-space specific variants without explicitly
encoding the space in the name of the intrinsic. Fewer intrinsics to deal with ->
less boilerplate.

Added a bit of tablegen magic to match/replace an intrinsic with a pointer
argument in a particular address space with the space-specific instruction
variant.

Updated tests to use non-default address spaces.

Differential Revision: https://reviews.llvm.org/D43268

llvm-svn: 328006
  • Loading branch information
Artem-B committed Mar 20, 2018
1 parent 3a99893 commit 914d4ba
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 157 deletions.
8 changes: 3 additions & 5 deletions clang/lib/CodeGen/CGBuiltin.cpp
Expand Up @@ -10527,8 +10527,7 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
llvm_unreachable("Unexpected builtin ID.");
}
Value *Result =
Builder.CreateCall(CGM.getIntrinsic(IID),
{Builder.CreatePointerCast(Src, VoidPtrTy), Ldm});
Builder.CreateCall(CGM.getIntrinsic(IID, Src->getType()), {Src, Ldm});

// Save returned values.
for (unsigned i = 0; i < NumResults; ++i) {
Expand Down Expand Up @@ -10567,10 +10566,9 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
default:
llvm_unreachable("Unexpected builtin ID.");
}
Function *Intrinsic = CGM.getIntrinsic(IID);
Function *Intrinsic = CGM.getIntrinsic(IID, Dst->getType());
llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1);
SmallVector<Value *, 10> Values;
Values.push_back(Builder.CreatePointerCast(Dst, VoidPtrTy));
SmallVector<Value *, 10> Values = {Dst};
for (unsigned i = 0; i < NumResults; ++i) {
Value *V = Builder.CreateAlignedLoad(
Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)),
Expand Down
65 changes: 20 additions & 45 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Expand Up @@ -3884,90 +3884,65 @@ def int_nvvm_match_all_sync_i64p :
//

// WMMA.LOAD
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
string Type, LLVMType regty, int WithStride>
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
LLVMType regty, int WithStride>
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
!if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
[], // Properties must be set during instantiation.
!if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
[IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
"llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
#Space
#!if(WithStride,".stride","")
#"."#Type>;

multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
}

multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
string Type, LLVMType regty> {
defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
LLVMType regty> {
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
}

multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
}

// For some reason ReadOnly<N> and NoCapture<N> confuse tblgen if they are
// passed to Intrinsic<> from inside of a multiclass. Setting them globally
// outside of the multiclass works.
let IntrProperties = [IntrReadMem, IntrArgMemOnly,
ReadOnly<0>, NoCapture<0>] in {
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
}
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;

// WMMA.STORE.D
class NVVM_WMMA_STD_LSTS<string Layout, string Space,
string Type, LLVMType regty, int WithStride,
class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
// This is only used to create a typed empty array we
// need to pass to !if below.
list<LLVMType>Empty=[]>
: Intrinsic<[],
!listconcat(
[llvm_ptr_ty],
[llvm_anyptr_ty],
!if(!eq(Type,"f16"),
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
!if(WithStride, [llvm_i32_ty], Empty)),
[], // Properties must be set during instantiation.
[IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
"llvm.nvvm.wmma.store.d.sync."#Layout
#".m16n16k16"#Space
#".m16n16k16"
#!if(WithStride,".stride","")
#"."#Type>;

multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
}

multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
}

multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
}

let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
WriteOnly<0>, NoCapture<0>] in {
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
}
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;

// WMMA.MMA
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
Expand Down
58 changes: 5 additions & 53 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Expand Up @@ -3327,26 +3327,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_load_a_f16_row:
case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_a_f16_col_global:
case Intrinsic::nvvm_wmma_load_a_f16_row_global:
case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col:
case Intrinsic::nvvm_wmma_load_b_f16_row:
case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col_global:
case Intrinsic::nvvm_wmma_load_b_f16_row_global:
case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f16;
Info.ptrVal = I.getArgOperand(0);
Expand All @@ -3359,15 +3343,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_load_c_f16_col:
case Intrinsic::nvvm_wmma_load_c_f16_row:
case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f16_col_global:
case Intrinsic::nvvm_wmma_load_c_f16_row_global:
case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
Expand All @@ -3380,15 +3356,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_load_c_f32_col:
case Intrinsic::nvvm_wmma_load_c_f32_row:
case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
case Intrinsic::nvvm_wmma_load_c_f32_col_global:
case Intrinsic::nvvm_wmma_load_c_f32_row_global:
case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
Expand All @@ -3401,15 +3369,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_store_d_f16_col:
case Intrinsic::nvvm_wmma_store_d_f16_row:
case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f16_col_global:
case Intrinsic::nvvm_wmma_store_d_f16_row_global:
case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
Expand All @@ -3422,15 +3382,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_store_d_f32_col:
case Intrinsic::nvvm_wmma_store_d_f32_row:
case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
case Intrinsic::nvvm_wmma_store_d_f32_col_global:
case Intrinsic::nvvm_wmma_store_d_f32_row_global:
case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
Expand Down

0 comments on commit 914d4ba

Please sign in to comment.