230 changes: 134 additions & 96 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ bool NVPTXProxyRegErasure::runOnMachineFunction(MachineFunction &MF) {
case NVPTX::ProxyRegI64:
case NVPTX::ProxyRegF16:
case NVPTX::ProxyRegF16x2:
case NVPTX::ProxyRegBF16:
case NVPTX::ProxyRegBF16x2:
case NVPTX::ProxyRegF32:
case NVPTX::ProxyRegF64:
replaceMachineInstructionUsage(MF, MI);
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>;
def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4), VRFrame32, VRFrameLocal32)>;
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>;
def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>;
def Float16Regs : NVPTXRegClass<[f16,bf16], 16, (add (sequence "H%u", 0, 4))>;
def Float16x2Regs : NVPTXRegClass<[v2f16,v2bf16], 32, (add (sequence "HH%u", 0, 4))>;
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
Expand Down
35 changes: 35 additions & 0 deletions llvm/test/CodeGen/NVPTX/bf16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
; RUN: llc < %s -march=nvptx | FileCheck %s
; RUN: %if ptxas %{ llc < %s -march=nvptx | %ptxas-verify %}

; LDST: .b8 bfloat_array[8] = {1, 2, 3, 4, 5, 6, 7, 8};
@"bfloat_array" = addrspace(1) constant [4 x bfloat]
[bfloat 0xR0201, bfloat 0xR0403, bfloat 0xR0605, bfloat 0xR0807]

define void @test_load_store(bfloat addrspace(1)* %in, bfloat addrspace(1)* %out) {
; CHECK-LABEL: @test_load_store
; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}]
; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]]
%val = load bfloat, bfloat addrspace(1)* %in
store bfloat %val, bfloat addrspace(1) * %out
ret void
}

define void @test_bitcast_from_bfloat(bfloat addrspace(1)* %in, i16 addrspace(1)* %out) {
; CHECK-LABEL: @test_bitcast_from_bfloat
; CHECK: ld.global.b16 [[TMP:%h[0-9]+]], [{{%r[0-9]+}}]
; CHECK: st.global.b16 [{{%r[0-9]+}}], [[TMP]]
%val = load bfloat, bfloat addrspace(1) * %in
%val_int = bitcast bfloat %val to i16
store i16 %val_int, i16 addrspace(1)* %out
ret void
}

define void @test_bitcast_to_bfloat(bfloat addrspace(1)* %out, i16 addrspace(1)* %in) {
; CHECK-LABEL: @test_bitcast_to_bfloat
; CHECK: ld.global.u16 [[TMP:%rs[0-9]+]], [{{%r[0-9]+}}]
; CHECK: st.global.u16 [{{%r[0-9]+}}], [[TMP]]
%val = load i16, i16 addrspace(1)* %in
%val_fp = bitcast i16 %val to bfloat
store bfloat %val_fp, bfloat addrspace(1)* %out
ret void
}