diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index f75b3d3caa62f2..65142a03f0a624 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -897,76 +897,73 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, const Instruction *I) { - if (isa(Dst) && isa(Src)) { - // FIXME: Need to compute legalizing cost for illegal types. - if (!isTypeLegal(Src) || !isTypeLegal(Dst)) - return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); - - // Skip if element size of Dst or Src is bigger than ELEN. - if (Src->getScalarSizeInBits() > ST->getELen() || - Dst->getScalarSizeInBits() > ST->getELen()) - return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); - - int ISD = TLI->InstructionOpcodeToISD(Opcode); - assert(ISD && "Invalid opcode"); - - // FIXME: Need to consider vsetvli and lmul. - int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) - - (int)Log2_32(Src->getScalarSizeInBits()); - switch (ISD) { - case ISD::SIGN_EXTEND: - case ISD::ZERO_EXTEND: - if (Src->getScalarSizeInBits() == 1) { - // We do not use vsext/vzext to extend from mask vector. - // Instead we use the following instructions to extend from mask vector: - // vmv.v.i v8, 0 - // vmerge.vim v8, v8, -1, v0 - return 2; - } - return 1; - case ISD::TRUNCATE: - if (Dst->getScalarSizeInBits() == 1) { - // We do not use several vncvt to truncate to mask vector. So we could - // not use PowDiff to calculate it. - // Instead we use the following instructions to truncate to mask vector: - // vand.vi v8, v8, 1 - // vmsne.vi v0, v8, 0 - return 2; - } - [[fallthrough]]; - case ISD::FP_EXTEND: - case ISD::FP_ROUND: - // Counts of narrow/widen instructions. - return std::abs(PowDiff); - case ISD::FP_TO_SINT: - case ISD::FP_TO_UINT: - case ISD::SINT_TO_FP: - case ISD::UINT_TO_FP: - if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) { - // The cost of convert from or to mask vector is different from other - // cases. We could not use PowDiff to calculate it. - // For mask vector to fp, we should use the following instructions: - // vmv.v.i v8, 0 - // vmerge.vim v8, v8, -1, v0 - // vfcvt.f.x.v v8, v8 - - // And for fp vector to mask, we use: - // vfncvt.rtz.x.f.w v9, v8 - // vand.vi v8, v9, 1 - // vmsne.vi v0, v8, 0 - return 3; - } - if (std::abs(PowDiff) <= 1) - return 1; - // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8), - // so it only need two conversion. - if (Src->isIntOrIntVectorTy()) - return 2; - // Counts of narrow/widen instructions. - return std::abs(PowDiff); + bool IsVectorType = isa(Dst) && isa(Src); + bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) && + (Src->getScalarSizeInBits() <= ST->getELen()) && + (Dst->getScalarSizeInBits() <= ST->getELen()); + + // FIXME: Need to compute legalizing cost for illegal types. + if (!IsVectorType || !IsTypeLegal) + return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); + + int ISD = TLI->InstructionOpcodeToISD(Opcode); + assert(ISD && "Invalid opcode"); + + // FIXME: Need to consider vsetvli and lmul. + int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) - + (int)Log2_32(Src->getScalarSizeInBits()); + switch (ISD) { + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + if (Src->getScalarSizeInBits() == 1) { + // We do not use vsext/vzext to extend from mask vector. + // Instead we use the following instructions to extend from mask vector: + // vmv.v.i v8, 0 + // vmerge.vim v8, v8, -1, v0 + return 2; } + return 1; + case ISD::TRUNCATE: + if (Dst->getScalarSizeInBits() == 1) { + // We do not use several vncvt to truncate to mask vector. So we could + // not use PowDiff to calculate it. + // Instead we use the following instructions to truncate to mask vector: + // vand.vi v8, v8, 1 + // vmsne.vi v0, v8, 0 + return 2; + } + [[fallthrough]]; + case ISD::FP_EXTEND: + case ISD::FP_ROUND: + // Counts of narrow/widen instructions. + return std::abs(PowDiff); + case ISD::FP_TO_SINT: + case ISD::FP_TO_UINT: + case ISD::SINT_TO_FP: + case ISD::UINT_TO_FP: + if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) { + // The cost of convert from or to mask vector is different from other + // cases. We could not use PowDiff to calculate it. + // For mask vector to fp, we should use the following instructions: + // vmv.v.i v8, 0 + // vmerge.vim v8, v8, -1, v0 + // vfcvt.f.x.v v8, v8 + + // And for fp vector to mask, we use: + // vfncvt.rtz.x.f.w v9, v8 + // vand.vi v8, v9, 1 + // vmsne.vi v0, v8, 0 + return 3; + } + if (std::abs(PowDiff) <= 1) + return 1; + // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8), + // so it only need two conversion. + if (Src->isIntOrIntVectorTy()) + return 2; + // Counts of narrow/widen instructions. + return std::abs(PowDiff); } - return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); } unsigned RISCVTTIImpl::getEstimatedVLFor(VectorType *Ty) {