Skip to content

Commit

Permalink
Coalesce 16-bit FP types to use integer register classes.
Browse files Browse the repository at this point in the history
i16/f16/bf16 will use the same .b16 registers and
i32/v2f16 and v2bf16 will share .b32 registers.

The changes are mostly mechanical, intended to remove unnecessary register
classes which tend to produce redundant register moves.

Differential Revision: https://reviews.llvm.org/D151601

v2f16 regtype conversion to i32
  • Loading branch information
Artem-B committed Jun 5, 2023
1 parent deecf89 commit dc90f42
Show file tree
Hide file tree
Showing 22 changed files with 975 additions and 1,140 deletions.
25 changes: 14 additions & 11 deletions clang/test/CodeGenCUDA/bf16.cu
Expand Up @@ -2,7 +2,7 @@
// REQUIRES: x86-registered-target

// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "nvptx64-nvidia-cuda" \
// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -S -o - %s | FileCheck %s
// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -O1 -S -o - %s | FileCheck %s

#include "Inputs/cuda.h"

Expand All @@ -11,36 +11,39 @@
// CHECK: .param .b16 _Z8test_argPDF16bDF16b_param_1
//
__device__ void test_arg(__bf16 *out, __bf16 in) {
// CHECK: ld.param.b16 %{{h.*}}, [_Z8test_argPDF16bDF16b_param_1];
// CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
// CHECK-DAG: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_argPDF16bDF16b_param_1];
__bf16 bf16 = in;
*out = bf16;
// CHECK: st.b16
// CHECK: st.b16 [%[[A]]], %[[R]]
// CHECK: ret;
}


// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retDF16b(
// CHECK: .param .b16 _Z8test_retDF16b_param_0
__device__ __bf16 test_ret( __bf16 in) {
// CHECK: ld.param.b16 %h{{.*}}, [_Z8test_retDF16b_param_0];
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_retDF16b_param_0];
return in;
// CHECK: st.param.b16 [func_retval0+0], %h
// CHECK: st.param.b16 [func_retval0+0], %[[R]]
// CHECK: ret;
}

__device__ __bf16 external_func( __bf16 in);

// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callDF16b(
// CHECK: .param .b16 _Z9test_callDF16b_param_0
__device__ __bf16 test_call( __bf16 in) {
// CHECK: ld.param.b16 %h{{.*}}, [_Z9test_callDF16b_param_0];
// CHECK: st.param.b16 [param0+0], %h2;
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
// CHECK: st.param.b16 [param0+0], %[[R]];
// CHECK: .param .b32 retval0;
// CHECK: call.uni (retval0),
// CHECK-NEXT: _Z8test_retDF16b,
// CHECK-NEXT: _Z13external_funcDF16b,
// CHECK-NEXT: (
// CHECK-NEXT: param0
// CHECK-NEXT );
// CHECK: ld.param.b16 %h{{.*}}, [retval0+0];
return test_ret(in);
// CHECK: st.param.b16 [func_retval0+0], %h
// CHECK: ld.param.b16 %[[RET:rs[0-9]+]], [retval0+0];
return external_func(in);
// CHECK: st.param.b16 [func_retval0+0], %[[RET]]
// CHECK: ret;
}
4 changes: 0 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Expand Up @@ -309,10 +309,6 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
Ret = (5 << 28);
} else if (RC == &NVPTX::Float64RegsRegClass) {
Ret = (6 << 28);
} else if (RC == &NVPTX::Float16RegsRegClass) {
Ret = (7 << 28);
} else if (RC == &NVPTX::Float16x2RegsRegClass) {
Ret = (8 << 28);
} else {
report_fatal_error("Bad register class");
}
Expand Down
338 changes: 117 additions & 221 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Expand Up @@ -410,10 +410,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
addRegisterClass(MVT::bf16, &NVPTX::Float16RegsRegClass);
addRegisterClass(MVT::v2bf16, &NVPTX::Float16x2RegsRegClass);
addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);

// Conversion to/from FP16/FP16x2 is always legal.
setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
Expand Down
5 changes: 0 additions & 5 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
Expand Up @@ -51,11 +51,6 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
} else if (DestRC == &NVPTX::Int64RegsRegClass) {
Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64rr
: NVPTX::BITCONVERT_64_F2I);
} else if (DestRC == &NVPTX::Float16RegsRegClass) {
Op = (SrcRC == &NVPTX::Float16RegsRegClass ? NVPTX::FMOV16rr
: NVPTX::BITCONVERT_16_I2F);
} else if (DestRC == &NVPTX::Float16x2RegsRegClass) {
Op = NVPTX::IMOV32rr;
} else if (DestRC == &NVPTX::Float32RegsRegClass) {
Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr
: NVPTX::BITCONVERT_32_I2F);
Expand Down

0 comments on commit dc90f42

Please sign in to comment.