Skip to content

Commit

Permalink
AMDGPU: Make v4bf16 a legal type (#76217)
Browse files Browse the repository at this point in the history
Gets a few code quality improvements. A few cases are worse
from losing load narrowing.
Depends #76213 #76214 #76215
  • Loading branch information
arsenm committed Jan 5, 2024
1 parent c1eef48 commit 4768563
Show file tree
Hide file tree
Showing 8 changed files with 5,678 additions and 6,430 deletions.
19 changes: 10 additions & 9 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,16 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
Custom);
setOperationAction(
ISD::EXTRACT_SUBVECTOR,
{MVT::v2f16, MVT::v2i16, MVT::v4f16, MVT::v4i16, MVT::v2f32,
MVT::v2i32, MVT::v3f32, MVT::v3i32, MVT::v4f32, MVT::v4i32,
MVT::v5f32, MVT::v5i32, MVT::v6f32, MVT::v6i32, MVT::v7f32,
MVT::v7i32, MVT::v8f32, MVT::v8i32, MVT::v9f32, MVT::v9i32,
MVT::v10i32, MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32,
MVT::v12f32, MVT::v16f16, MVT::v16i16, MVT::v16f32, MVT::v16i32,
MVT::v32f32, MVT::v32i32, MVT::v2f64, MVT::v2i64, MVT::v3f64,
MVT::v3i64, MVT::v4f64, MVT::v4i64, MVT::v8f64, MVT::v8i64,
MVT::v16f64, MVT::v16i64, MVT::v32i16, MVT::v32f16},
{MVT::v2f16, MVT::v2i16, MVT::v2bf16, MVT::v4f16, MVT::v4i16,
MVT::v4bf16, MVT::v2f32, MVT::v2i32, MVT::v3f32, MVT::v3i32,
MVT::v4f32, MVT::v4i32, MVT::v5f32, MVT::v5i32, MVT::v6f32,
MVT::v6i32, MVT::v7f32, MVT::v7i32, MVT::v8f32, MVT::v8i32,
MVT::v9f32, MVT::v9i32, MVT::v10i32, MVT::v10f32, MVT::v11i32,
MVT::v11f32, MVT::v12i32, MVT::v12f32, MVT::v16f16, MVT::v16i16,
MVT::v16f32, MVT::v16i32, MVT::v32f32, MVT::v32i32, MVT::v2f64,
MVT::v2i64, MVT::v3f64, MVT::v3i64, MVT::v4f64, MVT::v4i64,
MVT::v8f64, MVT::v8i64, MVT::v16f64, MVT::v16i64, MVT::v32i16,
MVT::v32f16},
Custom);

setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
Expand Down
25 changes: 16 additions & 9 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
Expand Down Expand Up @@ -312,10 +313,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
{MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
MVT::v4f16, MVT::v3i64, MVT::v3f64, MVT::v6i32, MVT::v6f32,
MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64, MVT::v8i16,
MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64, MVT::v16f64,
MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64,
MVT::v16f64, MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
switch (Op) {
case ISD::LOAD:
Expand Down Expand Up @@ -421,13 +422,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
{MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
Expand);

setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16}, Custom);
setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16, MVT::v4bf16},
Custom);

// Avoid stack access for these.
// TODO: Generalize to more vector types.
setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
{MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
MVT::v8i8, MVT::v4i16, MVT::v4f16},
MVT::v8i8, MVT::v4i16, MVT::v4f16, MVT::v4bf16},
Custom);

// Deal with vec3 vector operations when widened to vec4.
Expand Down Expand Up @@ -667,11 +669,15 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
AddPromotedToType(ISD::LOAD, MVT::v4i16, MVT::v2i32);
setOperationAction(ISD::LOAD, MVT::v4f16, Promote);
AddPromotedToType(ISD::LOAD, MVT::v4f16, MVT::v2i32);
setOperationAction(ISD::LOAD, MVT::v4bf16, Promote);
AddPromotedToType(ISD::LOAD, MVT::v4bf16, MVT::v2i32);

setOperationAction(ISD::STORE, MVT::v4i16, Promote);
AddPromotedToType(ISD::STORE, MVT::v4i16, MVT::v2i32);
setOperationAction(ISD::STORE, MVT::v4f16, Promote);
AddPromotedToType(ISD::STORE, MVT::v4f16, MVT::v2i32);
setOperationAction(ISD::STORE, MVT::v4bf16, Promote);
AddPromotedToType(ISD::STORE, MVT::v4bf16, MVT::v2i32);

setOperationAction(ISD::LOAD, MVT::v8i16, Promote);
AddPromotedToType(ISD::LOAD, MVT::v8i16, MVT::v4i32);
Expand Down Expand Up @@ -781,7 +787,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
Custom);

setOperationAction(ISD::FEXP, MVT::v2f16, Custom);
setOperationAction(ISD::SELECT, {MVT::v4i16, MVT::v4f16}, Custom);
setOperationAction(ISD::SELECT, {MVT::v4i16, MVT::v4f16, MVT::v4bf16},
Custom);

if (Subtarget->hasPackedFP32Ops()) {
setOperationAction({ISD::FADD, ISD::FMUL, ISD::FMA, ISD::FNEG},
Expand Down Expand Up @@ -6805,7 +6812,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
SDLoc SL(Op);
EVT VT = Op.getValueType();

if (VT == MVT::v4i16 || VT == MVT::v4f16 ||
if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
VT == MVT::v8i16 || VT == MVT::v8f16) {
EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
VT.getVectorNumElements() / 2);
Expand Down Expand Up @@ -6871,7 +6878,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
}

assert(VT == MVT::v2f16 || VT == MVT::v2i16);
assert(VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16);
assert(!Subtarget->hasVOP3PInsts() && "this should be legal");

SDValue Lo = Op.getOperand(0);
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,19 @@ def : BitConvert <f64, v2i32, VReg_64>;
def : BitConvert <v2i32, f64, VReg_64>;
def : BitConvert <v4i16, v4f16, VReg_64>;
def : BitConvert <v4f16, v4i16, VReg_64>;
def : BitConvert <v4bf16, v2i32, VReg_64>;
def : BitConvert <v2i32, v4bf16, VReg_64>;
def : BitConvert <v4bf16, i64, VReg_64>;
def : BitConvert <i64, v4bf16, VReg_64>;
def : BitConvert <v4bf16, v4i16, VReg_64>;
def : BitConvert <v4i16, v4bf16, VReg_64>;
def : BitConvert <v4bf16, v4f16, VReg_64>;
def : BitConvert <v4f16, v4bf16, VReg_64>;
def : BitConvert <v4bf16, v2f32, VReg_64>;
def : BitConvert <v2f32, v4bf16, VReg_64>;
def : BitConvert <v4bf16, f64, VReg_64>;
def : BitConvert <f64, v4bf16, VReg_64>;


// FIXME: Make SGPR
def : BitConvert <v2i32, v4f16, VReg_64>;
Expand Down
Loading

0 comments on commit 4768563

Please sign in to comment.