Skip to content

Commit

Permalink
[LegalizeVectorOps][RISCV] Add scalable-vector SELECT expansion
Browse files Browse the repository at this point in the history
This patch extends VectorLegalizer::ExpandSELECT to permit expansion
also for scalable vector types. The only real change is conditionally
checking for BUILD_VECTOR or SPLAT_VECTOR legality depending on the
vector type.

We can use this to fix "cannot select" errors for scalable vector
selects on the RISCV target. Note that in future patches RISCV will
possibly custom-lower vector SELECTs to VSELECTs for branchless codegen.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D102063
  • Loading branch information
frasercrmck committed May 10, 2021
1 parent 9ba661f commit 6db0ced
Show file tree
Hide file tree
Showing 6 changed files with 6,885 additions and 5 deletions.
18 changes: 13 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,11 +924,16 @@ SDValue VectorLegalizer::ExpandSELECT(SDNode *Node) {
// AND,OR,XOR, we will have to scalarize the op.
// Notice that the operation may be 'promoted' which means that it is
// 'bitcasted' to another type which is handled.
// Also, we need to be able to construct a splat vector using BUILD_VECTOR.
// Also, we need to be able to construct a splat vector using either
// BUILD_VECTOR or SPLAT_VECTOR.
// FIXME: Should we also permit fixed-length SPLAT_VECTOR as a fallback to
// BUILD_VECTOR?
if (TLI.getOperationAction(ISD::AND, VT) == TargetLowering::Expand ||
TLI.getOperationAction(ISD::XOR, VT) == TargetLowering::Expand ||
TLI.getOperationAction(ISD::OR, VT) == TargetLowering::Expand ||
TLI.getOperationAction(ISD::BUILD_VECTOR, VT) == TargetLowering::Expand)
TLI.getOperationAction(ISD::OR, VT) == TargetLowering::Expand ||
TLI.getOperationAction(VT.isFixedLengthVector() ? ISD::BUILD_VECTOR
: ISD::SPLAT_VECTOR,
VT) == TargetLowering::Expand)
return DAG.UnrollVectorOp(Node);

// Generate a mask operand.
Expand All @@ -942,8 +947,11 @@ SDValue VectorLegalizer::ExpandSELECT(SDNode *Node) {
BitTy),
DAG.getConstant(0, DL, BitTy));

// Broadcast the mask so that the entire vector is all-one or all zero.
Mask = DAG.getSplatBuildVector(MaskTy, DL, Mask);
// Broadcast the mask so that the entire vector is all one or all zero.
if (VT.isFixedLengthVector())
Mask = DAG.getSplatBuildVector(MaskTy, DL, Mask);
else
Mask = DAG.getSplatVector(MaskTy, DL, Mask);

// Bitcast the operands to be the same type as the mask.
// This is needed when we select between FP types because
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);

setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);

setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
Expand Down Expand Up @@ -517,6 +520,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);

setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);

setOperationAction(ISD::STEP_VECTOR, VT, Custom);
setOperationAction(ISD::VECTOR_REVERSE, VT, Custom);

Expand Down Expand Up @@ -571,6 +577,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);

setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);

setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
Expand Down Expand Up @@ -695,6 +704,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP_TO_UINT, VT, Custom);

setOperationAction(ISD::VSELECT, VT, Custom);
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);

setOperationAction(ISD::ANY_EXTEND, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
Expand Down Expand Up @@ -762,6 +773,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setCondCodeAction(CC, VT, Expand);

setOperationAction(ISD::VSELECT, VT, Custom);
setOperationAction(ISD::SELECT, VT, Expand);
setOperationAction(ISD::SELECT_CC, VT, Expand);

setOperationAction(ISD::BITCAST, VT, Custom);

Expand Down
Loading

0 comments on commit 6db0ced

Please sign in to comment.