[AMDGPU] Add bf16 storage support
- [Clang] Declare AMDGPU target as supporting BF16 for storage-only purposes on amdgcn
  - Add Sema & CodeGen tests cases.
  - Also add cases that D138651 would have covered as this patch replaces it.
- [AMDGPU] Add BF16 storage-only support
  - Support legalization/dealing with bf16 operations in DAGIsel.
  - bf16 as a type remains illegal and is represented as i16 for storage purposes.
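
To make the storage-only contract concrete, here is a minimal device-code sketch distilled from the tests added below (the function name copy_bf16 is illustrative, not part of the patch): loads, stores, and argument passing compile, while any arithmetic on __bf16 is rejected by Sema.

__device__ void copy_bf16(__bf16 *out, __bf16 in) {
  __bf16 v = in; // argument passing and local storage: accepted
  *out = v;      // plain store: accepted
  // v + v;      // amdgcn error: invalid operands to binary expression
}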

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D139398
Pierre-vh committed Dec 13, 2022
1 parent 1f9fe34 commit 678d894
Showing 11 changed files with 3,514 additions and 9 deletions.
6 changes: 6 additions & 0 deletions clang/lib/Basic/Targets/AMDGPU.cpp
@@ -365,6 +365,12 @@ AMDGPUTargetInfo::AMDGPUTargetInfo(const llvm::Triple &Triple,
!isAMDGCN(Triple));
UseAddrSpaceMapMangling = true;

if (isAMDGCN(Triple)) {
// __bf16 is always available as a load/store only type on AMDGCN.
BFloat16Width = BFloat16Align = 16;
BFloat16Format = &llvm::APFloat::BFloat();
}

HasLegalHalfType = true;
HasFloat16 = true;
WavefrontSize = GPUFeatures & llvm::AMDGPU::FEATURE_WAVE32 ? 32 : 64;
3 changes: 3 additions & 0 deletions clang/lib/Basic/Targets/AMDGPU.h
@@ -115,6 +115,9 @@ class LLVM_LIBRARY_VISIBILITY AMDGPUTargetInfo final : public TargetInfo {
return getTriple().getArch() == llvm::Triple::amdgcn ? 64 : 32;
}

bool hasBFloat16Type() const override { return isAMDGCN(getTriple()); }
const char *getBFloat16Mangling() const override { return "u6__bf16"; };

const char *getClobbers() const override { return ""; }

ArrayRef<const char *> getGCCRegNames() const override;
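
For reference, u6__bf16 is the Itanium C++ ABI encoding for a vendor-extended type ("u" followed by the length-prefixed source name), which is why the mangled names in the tests below embed it. A hypothetical example (take and take_ptr are illustrative names, not from this patch):

__device__ void take(__bf16 v);      // mangles to _Z4takeu6__bf16
__device__ void take_ptr(__bf16 *p); // mangles to _Z8take_ptrPu6__bf16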
129 changes: 129 additions & 0 deletions clang/test/CodeGenCUDA/amdgpu-bf16.cu
@@ -0,0 +1,129 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py
// REQUIRES: amdgpu-registered-target
// REQUIRES: x86-registered-target

// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "amdgcn-amd-amdhsa" \
// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -emit-llvm -o - %s | FileCheck %s

#include "Inputs/cuda.h"

// CHECK-LABEL: @_Z8test_argPu6__bf16u6__bf16(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[OUT_ADDR:%.*]] = alloca ptr, align 8, addrspace(5)
// CHECK-NEXT: [[IN_ADDR:%.*]] = alloca bfloat, align 2, addrspace(5)
// CHECK-NEXT: [[BF16:%.*]] = alloca bfloat, align 2, addrspace(5)
// CHECK-NEXT: [[OUT_ADDR_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[OUT_ADDR]] to ptr
// CHECK-NEXT: [[IN_ADDR_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[IN_ADDR]] to ptr
// CHECK-NEXT: [[BF16_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[BF16]] to ptr
// CHECK-NEXT: store ptr [[OUT:%.*]], ptr [[OUT_ADDR_ASCAST]], align 8
// CHECK-NEXT: store bfloat [[IN:%.*]], ptr [[IN_ADDR_ASCAST]], align 2
// CHECK-NEXT: [[TMP0:%.*]] = load bfloat, ptr [[IN_ADDR_ASCAST]], align 2
// CHECK-NEXT: store bfloat [[TMP0]], ptr [[BF16_ASCAST]], align 2
// CHECK-NEXT: [[TMP1:%.*]] = load bfloat, ptr [[BF16_ASCAST]], align 2
// CHECK-NEXT: [[TMP2:%.*]] = load ptr, ptr [[OUT_ADDR_ASCAST]], align 8
// CHECK-NEXT: store bfloat [[TMP1]], ptr [[TMP2]], align 2
// CHECK-NEXT: ret void
//
__device__ void test_arg(__bf16 *out, __bf16 in) {
__bf16 bf16 = in;
*out = bf16;
}

// CHECK-LABEL: @_Z9test_loadPu6__bf16S_(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[OUT_ADDR:%.*]] = alloca ptr, align 8, addrspace(5)
// CHECK-NEXT: [[IN_ADDR:%.*]] = alloca ptr, align 8, addrspace(5)
// CHECK-NEXT: [[BF16:%.*]] = alloca bfloat, align 2, addrspace(5)
// CHECK-NEXT: [[OUT_ADDR_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[OUT_ADDR]] to ptr
// CHECK-NEXT: [[IN_ADDR_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[IN_ADDR]] to ptr
// CHECK-NEXT: [[BF16_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[BF16]] to ptr
// CHECK-NEXT: store ptr [[OUT:%.*]], ptr [[OUT_ADDR_ASCAST]], align 8
// CHECK-NEXT: store ptr [[IN:%.*]], ptr [[IN_ADDR_ASCAST]], align 8
// CHECK-NEXT: [[TMP0:%.*]] = load ptr, ptr [[IN_ADDR_ASCAST]], align 8
// CHECK-NEXT: [[TMP1:%.*]] = load bfloat, ptr [[TMP0]], align 2
// CHECK-NEXT: store bfloat [[TMP1]], ptr [[BF16_ASCAST]], align 2
// CHECK-NEXT: [[TMP2:%.*]] = load bfloat, ptr [[BF16_ASCAST]], align 2
// CHECK-NEXT: [[TMP3:%.*]] = load ptr, ptr [[OUT_ADDR_ASCAST]], align 8
// CHECK-NEXT: store bfloat [[TMP2]], ptr [[TMP3]], align 2
// CHECK-NEXT: ret void
//
__device__ void test_load(__bf16 *out, __bf16 *in) {
__bf16 bf16 = *in;
*out = bf16;
}

// CHECK-LABEL: @_Z8test_retu6__bf16(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[RETVAL:%.*]] = alloca bfloat, align 2, addrspace(5)
// CHECK-NEXT: [[IN_ADDR:%.*]] = alloca bfloat, align 2, addrspace(5)
// CHECK-NEXT: [[RETVAL_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[RETVAL]] to ptr
// CHECK-NEXT: [[IN_ADDR_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[IN_ADDR]] to ptr
// CHECK-NEXT: store bfloat [[IN:%.*]], ptr [[IN_ADDR_ASCAST]], align 2
// CHECK-NEXT: [[TMP0:%.*]] = load bfloat, ptr [[IN_ADDR_ASCAST]], align 2
// CHECK-NEXT: ret bfloat [[TMP0]]
//
__device__ __bf16 test_ret( __bf16 in) {
return in;
}

// CHECK-LABEL: @_Z9test_callu6__bf16(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[RETVAL:%.*]] = alloca bfloat, align 2, addrspace(5)
// CHECK-NEXT: [[IN_ADDR:%.*]] = alloca bfloat, align 2, addrspace(5)
// CHECK-NEXT: [[RETVAL_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[RETVAL]] to ptr
// CHECK-NEXT: [[IN_ADDR_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[IN_ADDR]] to ptr
// CHECK-NEXT: store bfloat [[IN:%.*]], ptr [[IN_ADDR_ASCAST]], align 2
// CHECK-NEXT: [[TMP0:%.*]] = load bfloat, ptr [[IN_ADDR_ASCAST]], align 2
// CHECK-NEXT: [[CALL:%.*]] = call contract noundef bfloat @_Z8test_retu6__bf16(bfloat noundef [[TMP0]]) #[[ATTR1:[0-9]+]]
// CHECK-NEXT: ret bfloat [[CALL]]
//
__device__ __bf16 test_call( __bf16 in) {
return test_ret(in);
}


// CHECK-LABEL: @_Z15test_vec_assignv(
// CHECK-NEXT: entry:
// CHECK-NEXT: [[VEC2_A:%.*]] = alloca <2 x bfloat>, align 4, addrspace(5)
// CHECK-NEXT: [[VEC2_B:%.*]] = alloca <2 x bfloat>, align 4, addrspace(5)
// CHECK-NEXT: [[VEC4_A:%.*]] = alloca <4 x bfloat>, align 8, addrspace(5)
// CHECK-NEXT: [[VEC4_B:%.*]] = alloca <4 x bfloat>, align 8, addrspace(5)
// CHECK-NEXT: [[VEC8_A:%.*]] = alloca <8 x bfloat>, align 16, addrspace(5)
// CHECK-NEXT: [[VEC8_B:%.*]] = alloca <8 x bfloat>, align 16, addrspace(5)
// CHECK-NEXT: [[VEC16_A:%.*]] = alloca <16 x bfloat>, align 32, addrspace(5)
// CHECK-NEXT: [[VEC16_B:%.*]] = alloca <16 x bfloat>, align 32, addrspace(5)
// CHECK-NEXT: [[VEC2_A_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[VEC2_A]] to ptr
// CHECK-NEXT: [[VEC2_B_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[VEC2_B]] to ptr
// CHECK-NEXT: [[VEC4_A_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[VEC4_A]] to ptr
// CHECK-NEXT: [[VEC4_B_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[VEC4_B]] to ptr
// CHECK-NEXT: [[VEC8_A_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[VEC8_A]] to ptr
// CHECK-NEXT: [[VEC8_B_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[VEC8_B]] to ptr
// CHECK-NEXT: [[VEC16_A_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[VEC16_A]] to ptr
// CHECK-NEXT: [[VEC16_B_ASCAST:%.*]] = addrspacecast ptr addrspace(5) [[VEC16_B]] to ptr
// CHECK-NEXT: [[TMP0:%.*]] = load <2 x bfloat>, ptr [[VEC2_B_ASCAST]], align 4
// CHECK-NEXT: store <2 x bfloat> [[TMP0]], ptr [[VEC2_A_ASCAST]], align 4
// CHECK-NEXT: [[TMP1:%.*]] = load <4 x bfloat>, ptr [[VEC4_B_ASCAST]], align 8
// CHECK-NEXT: store <4 x bfloat> [[TMP1]], ptr [[VEC4_A_ASCAST]], align 8
// CHECK-NEXT: [[TMP2:%.*]] = load <8 x bfloat>, ptr [[VEC8_B_ASCAST]], align 16
// CHECK-NEXT: store <8 x bfloat> [[TMP2]], ptr [[VEC8_A_ASCAST]], align 16
// CHECK-NEXT: [[TMP3:%.*]] = load <16 x bfloat>, ptr [[VEC16_B_ASCAST]], align 32
// CHECK-NEXT: store <16 x bfloat> [[TMP3]], ptr [[VEC16_A_ASCAST]], align 32
// CHECK-NEXT: ret void
//
__device__ void test_vec_assign() {
typedef __attribute__((ext_vector_type(2))) __bf16 bf16_x2;
bf16_x2 vec2_a, vec2_b;
vec2_a = vec2_b;

typedef __attribute__((ext_vector_type(4))) __bf16 bf16_x4;
bf16_x4 vec4_a, vec4_b;
vec4_a = vec4_b;

typedef __attribute__((ext_vector_type(8))) __bf16 bf16_x8;
bf16_x8 vec8_a, vec8_b;
vec8_a = vec8_b;

typedef __attribute__((ext_vector_type(16))) __bf16 bf16_x16;
bf16_x16 vec16_a, vec16_b;
vec16_a = vec16_b;
}
99 changes: 99 additions & 0 deletions clang/test/SemaCUDA/amdgpu-bf16.cu
@@ -0,0 +1,99 @@
// REQUIRES: amdgpu-registered-target
// REQUIRES: x86-registered-target

// RUN: %clang_cc1 "-triple" "x86_64-unknown-linux-gnu" "-aux-triple" "amdgcn-amd-amdhsa"\
// RUN: "-target-cpu" "x86-64" -fsyntax-only -verify=amdgcn %s
// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "amdgcn-amd-amdhsa"\
// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -fsyntax-only -verify=amdgcn %s

// RUN: %clang_cc1 "-aux-triple" "x86_64-unknown-linux-gnu" "-triple" "r600-unknown-unknown"\
// RUN: -fcuda-is-device "-aux-target-cpu" "x86-64" -fsyntax-only -verify=amdgcn,r600 %s

// AMDGCN has storage-only support for bf16. R600 does not support it and should
// error out when it's the main target.

#include "Inputs/cuda.h"

// There should be no errors on using the type itself, or when loading/storing values for amdgcn.
// r600 should error on all uses of the type.

// r600-error@+1 {{__bf16 is not supported on this target}}
typedef __attribute__((ext_vector_type(2))) __bf16 bf16_x2;
// r600-error@+1 {{__bf16 is not supported on this target}}
typedef __attribute__((ext_vector_type(4))) __bf16 bf16_x4;
// r600-error@+1 {{__bf16 is not supported on this target}}
typedef __attribute__((ext_vector_type(8))) __bf16 bf16_x8;
// r600-error@+1 {{__bf16 is not supported on this target}}
typedef __attribute__((ext_vector_type(16))) __bf16 bf16_x16;

// r600-error@+1 2 {{__bf16 is not supported on this target}}
__device__ void test(bool b, __bf16 *out, __bf16 in) {
__bf16 bf16 = in; // r600-error {{__bf16 is not supported on this target}}

bf16 + bf16; // amdgcn-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
bf16 - bf16; // amdgcn-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
bf16 * bf16; // amdgcn-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}
bf16 / bf16; // amdgcn-error {{invalid operands to binary expression ('__bf16' and '__bf16')}}

__fp16 fp16;

bf16 + fp16; // amdgcn-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
fp16 + bf16; // amdgcn-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
bf16 - fp16; // amdgcn-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
fp16 - bf16; // amdgcn-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
bf16 * fp16; // amdgcn-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
fp16 * bf16; // amdgcn-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
bf16 / fp16; // amdgcn-error {{invalid operands to binary expression ('__bf16' and '__fp16')}}
fp16 / bf16; // amdgcn-error {{invalid operands to binary expression ('__fp16' and '__bf16')}}
bf16 = fp16; // amdgcn-error {{assigning to '__bf16' from incompatible type '__fp16'}}
fp16 = bf16; // amdgcn-error {{assigning to '__fp16' from incompatible type '__bf16'}}
bf16 + (b ? fp16 : bf16); // amdgcn-error {{incompatible operand types ('__fp16' and '__bf16')}}
*out = bf16;

// amdgcn-error@+1 {{static_cast from '__bf16' to 'unsigned short' is not allowed}}
unsigned short u16bf16 = static_cast<unsigned short>(bf16);
// amdgcn-error@+2 {{C-style cast from 'unsigned short' to '__bf16' is not allowed}}
// r600-error@+1 {{__bf16 is not supported on this target}}
bf16 = (__bf16)u16bf16;

// amdgcn-error@+1 {{static_cast from '__bf16' to 'float' is not allowed}}
float f32bf16 = static_cast<float>(bf16);
// amdgcn-error@+2 {{C-style cast from 'float' to '__bf16' is not allowed}}
// r600-error@+1 {{__bf16 is not supported on this target}}
bf16 = (__bf16)f32bf16;

// amdgcn-error@+1 {{static_cast from '__bf16' to 'double' is not allowed}}
double f64bf16 = static_cast<double>(bf16);
// amdgcn-error@+2 {{C-style cast from 'double' to '__bf16' is not allowed}}
// r600-error@+1 {{__bf16 is not supported on this target}}
bf16 = (__bf16)f64bf16;

// r600-error@+1 {{__bf16 is not supported on this target}}
typedef __attribute__((ext_vector_type(2))) __bf16 bf16_x2;
bf16_x2 vec2_a, vec2_b;
vec2_a = vec2_b;

// r600-error@+1 {{__bf16 is not supported on this target}}
typedef __attribute__((ext_vector_type(4))) __bf16 bf16_x4;
bf16_x4 vec4_a, vec4_b;
vec4_a = vec4_b;

// r600-error@+1 {{__bf16 is not supported on this target}}
typedef __attribute__((ext_vector_type(8))) __bf16 bf16_x8;
bf16_x8 vec8_a, vec8_b;
vec8_a = vec8_b;

// r600-error@+1 {{__bf16 is not supported on this target}}
typedef __attribute__((ext_vector_type(16))) __bf16 bf16_x16;
bf16_x16 vec16_a, vec16_b;
vec16_a = vec16_b;
}

// r600-error@+1 2 {{__bf16 is not supported on this target}}
__bf16 hostfn(__bf16 a) {
return a;
}

// r600-error@+2 {{__bf16 is not supported on this target}}
// r600-error@+1 {{vector size not an integral multiple of component size}}
typedef __bf16 foo __attribute__((__vector_size__(16), __aligned__(16)));
32 changes: 30 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -2908,8 +2908,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
break;
case ISD::BF16_TO_FP: {
// Always expand bf16 to f32 casts; they lower to ext + shift.
SDValue Op = DAG.getNode(ISD::BITCAST, dl, MVT::i16, Node->getOperand(0));
Op = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Op);
//
// Note that the operand of this node can be bf16 or an integer type in case
// bf16 is not supported on the target and was softened.
SDValue Op = Node->getOperand(0);
if (Op.getValueType() == MVT::bf16) {
Op = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32,
DAG.getNode(ISD::BITCAST, dl, MVT::i16, Op));
} else {
Op = DAG.getAnyExtOrTrunc(Op, dl, MVT::i32);
}
Op = DAG.getNode(
ISD::SHL, dl, MVT::i32, Op,
DAG.getConstant(16, dl,
@@ -2921,6 +2929,26 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
Results.push_back(Op);
break;
}
case ISD::FP_TO_BF16: {
SDValue Op = Node->getOperand(0);
if (Op.getValueType() != MVT::f32)
Op = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, Op,
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
Op = DAG.getNode(
ISD::SRL, dl, MVT::i32, DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op),
DAG.getConstant(16, dl,
TLI.getShiftAmountTy(MVT::i32, DAG.getDataLayout())));
// The result of this node can be bf16 or an integer type in case bf16 is
// not supported on the target and was softened to i16 for storage.
if (Node->getValueType(0) == MVT::bf16) {
Op = DAG.getNode(ISD::BITCAST, dl, MVT::bf16,
DAG.getNode(ISD::TRUNCATE, dl, MVT::i16, Op));
} else {
Op = DAG.getAnyExtOrTrunc(Op, dl, Node->getValueType(0));
}
Results.push_back(Op);
break;
}
case ISD::SIGN_EXTEND_INREG: {
EVT ExtraVT = cast<VTSDNode>(Node->getOperand(1))->getVT();
EVT VT = Node->getValueType(0);
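
For intuition, the following standalone C++ sketch mirrors the bit manipulation these two expansions perform (illustrative only, assuming IEEE-754 binary32; note that the FP_TO_BF16 expansion above truncates the low mantissa bits rather than rounding to nearest even):

#include <cstdint>
#include <cstring>

// BF16_TO_FP: bitcast to i16, any-extend to i32, shl by 16, bitcast to f32.
static float bf16_to_f32(uint16_t bits) {
  uint32_t w = uint32_t(bits) << 16; // bf16 occupies the high half of an f32
  float f;
  std::memcpy(&f, &w, sizeof f);     // bitcast i32 -> f32
  return f;
}

// FP_TO_BF16: bitcast f32 to i32, srl by 16, truncate to i16.
static uint16_t f32_to_bf16(float f) {
  uint32_t w;
  std::memcpy(&w, &f, sizeof w);     // bitcast f32 -> i32
  return uint16_t(w >> 16);          // drop low 16 mantissa bits (no rounding)
}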
8 changes: 6 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -148,7 +148,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::FP_TO_UINT_SAT:
Res = PromoteIntRes_FP_TO_XINT_SAT(N); break;

case ISD::FP_TO_FP16: Res = PromoteIntRes_FP_TO_FP16(N); break;
case ISD::FP_TO_BF16:
case ISD::FP_TO_FP16:
Res = PromoteIntRes_FP_TO_FP16_BF16(N);
break;

case ISD::FLT_ROUNDS_: Res = PromoteIntRes_FLT_ROUNDS(N); break;

@@ -720,7 +723,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_XINT_SAT(SDNode *N) {
N->getOperand(1));
}

SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16(SDNode *N) {
SDValue DAGTypeLegalizer::PromoteIntRes_FP_TO_FP16_BF16(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
SDLoc dl(N);

@@ -1667,6 +1670,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
OpNo); break;
case ISD::VP_TRUNCATE:
case ISD::TRUNCATE: Res = PromoteIntOp_TRUNCATE(N); break;
case ISD::BF16_TO_FP:
case ISD::FP16_TO_FP:
case ISD::VP_UINT_TO_FP:
case ISD::UINT_TO_FP: Res = PromoteIntOp_UINT_TO_FP(N); break;
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -324,7 +324,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_EXTRACT_VECTOR_ELT(SDNode *N);
SDValue PromoteIntRes_FP_TO_XINT(SDNode *N);
SDValue PromoteIntRes_FP_TO_XINT_SAT(SDNode *N);
SDValue PromoteIntRes_FP_TO_FP16(SDNode *N);
SDValue PromoteIntRes_FP_TO_FP16_BF16(SDNode *N);
SDValue PromoteIntRes_FREEZE(SDNode *N);
SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
SDValue PromoteIntRes_LOAD(LoadSDNode *N);
4 changes: 4 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -163,6 +163,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
Expand);

setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v3f32, MVT::v3f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
@@ -178,6 +179,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
setLoadExtAction(ISD::EXTLOAD, MVT::v16f64, MVT::v16f32, Expand);

setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v3f64, MVT::v3f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
@@ -272,6 +274,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
setTruncStoreAction(MVT::v2i64, MVT::v2i16, Expand);
setTruncStoreAction(MVT::v2i64, MVT::v2i32, Expand);

setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand);
setTruncStoreAction(MVT::v3f32, MVT::v3f16, Expand);
@@ -280,6 +283,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
setTruncStoreAction(MVT::v16f32, MVT::v16f16, Expand);
setTruncStoreAction(MVT::v32f32, MVT::v32f16, Expand);

setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
setTruncStoreAction(MVT::f64, MVT::f32, Expand);

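
These Expand entries mean a bf16 extending load or truncating store is never matched as a single operation; the extension or truncation is split out and handled by the shift-based expansions added in LegalizeDAG above. A hypothetical IR snippet (not from this patch) that would exercise both paths:

define float @ext_load(ptr %p) {
  %v = load bfloat, ptr %p        ; plain 16-bit load
  %f = fpext bfloat %v to float   ; expanded via shl 16 + bitcast
  ret float %f
}

define void @trunc_store(ptr %p, float %f) {
  %v = fptrunc float %f to bfloat ; expanded via bitcast + srl 16
  store bfloat %v, ptr %p         ; plain 16-bit store
  ret void
}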
