[RISCV] Add cost model for scalable scatter and gather
The costing we use for fixed length vector gather and scatter is to simply count up the memory ops, and multiply by a fixed memory op cost. For scalable vectors, we don't actually know how many lanes are active. Instead, we have to make a worst-case assumption on how many lanes could be active. In the generic +V case, this results in very high costs, but we can do better when we know an upper bound on the VLEN.

There are some obvious ways to improve this - e.g. using information about VL and mask bits from the instruction to reduce the upper bound - but this seems like a reasonable starting point.

The resulting costs do bias us pretty strongly away from generating scatter/gather for generic +V.  Without this, we'd be returning an invalid cost and thus definitely not vectorizing, so no major change in practical behavior expected.

Differential Revision: https://reviews.llvm.org/D127541
preames committed Jun 16, 2022
1 parent bbb73ad commit d764aa7
Showing 3 changed files with 259 additions and 93 deletions.
24 changes: 15 additions & 9 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -232,15 +232,21 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                          Alignment, CostKind, I);

-  // FIXME: Only supporting fixed vectors for now.
-  if (!isa<FixedVectorType>(DataTy))
-    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
-                                         Alignment, CostKind, I);
-
-  auto *VTy = cast<FixedVectorType>(DataTy);
-  unsigned NumLoads = VTy->getNumElements();
-  InstructionCost MemOpCost =
-      getMemoryOpCost(Opcode, VTy->getElementType(), Alignment, 0, CostKind, I);
+  // Cost is proportional to the number of memory operations implied. For
+  // scalable vectors, we use an upper bound on that number since we don't
+  // know exactly what VL will be.
+  auto &VTy = *cast<VectorType>(DataTy);
+  InstructionCost MemOpCost = getMemoryOpCost(Opcode, VTy.getElementType(),
+                                              Alignment, 0, CostKind, I);
+  if (isa<ScalableVectorType>(VTy)) {
+    const unsigned EltSize = VTy.getScalarSizeInBits();
+    const unsigned MinSize = VTy.getPrimitiveSizeInBits().getKnownMinValue();
+    const unsigned VectorBitsMax = ST->getRealMaxVLen();
+    const unsigned MaxVLMAX =
+        RISCVTargetLowering::computeVLMAX(VectorBitsMax, EltSize, MinSize);
+    return MaxVLMAX * MemOpCost;
+  }
+  unsigned NumLoads = cast<FixedVectorType>(VTy).getNumElements();
   return NumLoads * MemOpCost;
 }
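
To make the arithmetic in this change concrete, here is a minimal standalone sketch of the two cost paths. It is not the LLVM code. The names `kRVVBitsPerBlock`, `computeVLMAXSketch`, `fixedGatherScatterCost`, and `scalableGatherScatterCost` are hypothetical stand-ins introduced for illustration. The sketch assumes that VLMAX = (VLEN / SEW) * LMUL with LMUL = MinSize / 64, which is how I understand `RISCVTargetLowering::computeVLMAX` to behave. The memory op cost is passed in as a plain integer rather than an `InstructionCost`.

```cpp
#include <cassert>
#include <cstdint>

// Assumption: one RVV register block is 64 bits wide (LLVM names this
// RISCV::RVVBitsPerBlock); the constant name here is hypothetical.
constexpr unsigned kRVVBitsPerBlock = 64;

// Hypothetical stand-in for RISCVTargetLowering::computeVLMAX.
// VLMAX = (VectorBits / EltSizeBits) * LMUL, with LMUL = MinSizeBits / 64.
// The multiply happens before the divide so that a fractional LMUL
// (MinSizeBits < 64) does not truncate to zero.
constexpr unsigned computeVLMAXSketch(unsigned VectorBits, unsigned EltSizeBits,
                                      unsigned MinSizeBits) {
  return ((VectorBits / EltSizeBits) * MinSizeBits) / kRVVBitsPerBlock;
}

// Fixed-length path: one memory op per element.
constexpr uint64_t fixedGatherScatterCost(unsigned NumElts, uint64_t MemOpCost) {
  return static_cast<uint64_t>(NumElts) * MemOpCost;
}

// Scalable path: the lane count is unknown, so charge for the largest
// VLMAX that the upper bound on VLEN allows.
constexpr uint64_t scalableGatherScatterCost(unsigned MaxVLenBits,
                                             unsigned EltSizeBits,
                                             unsigned MinSizeBits,
                                             uint64_t MemOpCost) {
  return static_cast<uint64_t>(
             computeVLMAXSketch(MaxVLenBits, EltSizeBits, MinSizeBits)) *
         MemOpCost;
}
```

For example, `<vscale x 4 x i32>` has a 32-bit element and a known minimum size of 128 bits. With a memory op cost of 1 and the V extension's architectural maximum VLEN of 65536 bits, the sketch yields 4096, which is the "very high cost" the commit message describes for generic +V. With VLEN bounded at 128 bits, it yields 8.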
