-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV][TTI] Refactor getCastInstrCost to exit early #86619
Conversation
To reduce the indentation by using early returns, this patch hoist the return for illegal type and non vector type earlier. It should mostly be an NFC.
@llvm/pr-subscribers-backend-risc-v Author: Shih-Po Hung (arcbbb) ChangesTo reduce the indentation by using early returns, this patch hoist the return for illegal type and non vector type earlier. It should mostly be an NFC. Full diff: https://github.com/llvm/llvm-project/pull/86619.diff 1 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index f75b3d3caa62f2..65142a03f0a624 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -897,76 +897,73 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind,
const Instruction *I) {
- if (isa<VectorType>(Dst) && isa<VectorType>(Src)) {
- // FIXME: Need to compute legalizing cost for illegal types.
- if (!isTypeLegal(Src) || !isTypeLegal(Dst))
- return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
- // Skip if element size of Dst or Src is bigger than ELEN.
- if (Src->getScalarSizeInBits() > ST->getELen() ||
- Dst->getScalarSizeInBits() > ST->getELen())
- return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
-
- int ISD = TLI->InstructionOpcodeToISD(Opcode);
- assert(ISD && "Invalid opcode");
-
- // FIXME: Need to consider vsetvli and lmul.
- int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
- (int)Log2_32(Src->getScalarSizeInBits());
- switch (ISD) {
- case ISD::SIGN_EXTEND:
- case ISD::ZERO_EXTEND:
- if (Src->getScalarSizeInBits() == 1) {
- // We do not use vsext/vzext to extend from mask vector.
- // Instead we use the following instructions to extend from mask vector:
- // vmv.v.i v8, 0
- // vmerge.vim v8, v8, -1, v0
- return 2;
- }
- return 1;
- case ISD::TRUNCATE:
- if (Dst->getScalarSizeInBits() == 1) {
- // We do not use several vncvt to truncate to mask vector. So we could
- // not use PowDiff to calculate it.
- // Instead we use the following instructions to truncate to mask vector:
- // vand.vi v8, v8, 1
- // vmsne.vi v0, v8, 0
- return 2;
- }
- [[fallthrough]];
- case ISD::FP_EXTEND:
- case ISD::FP_ROUND:
- // Counts of narrow/widen instructions.
- return std::abs(PowDiff);
- case ISD::FP_TO_SINT:
- case ISD::FP_TO_UINT:
- case ISD::SINT_TO_FP:
- case ISD::UINT_TO_FP:
- if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
- // The cost of convert from or to mask vector is different from other
- // cases. We could not use PowDiff to calculate it.
- // For mask vector to fp, we should use the following instructions:
- // vmv.v.i v8, 0
- // vmerge.vim v8, v8, -1, v0
- // vfcvt.f.x.v v8, v8
-
- // And for fp vector to mask, we use:
- // vfncvt.rtz.x.f.w v9, v8
- // vand.vi v8, v9, 1
- // vmsne.vi v0, v8, 0
- return 3;
- }
- if (std::abs(PowDiff) <= 1)
- return 1;
- // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
- // so it only need two conversion.
- if (Src->isIntOrIntVectorTy())
- return 2;
- // Counts of narrow/widen instructions.
- return std::abs(PowDiff);
+ bool IsVectorType = isa<VectorType>(Dst) && isa<VectorType>(Src);
+ bool IsTypeLegal = isTypeLegal(Src) && isTypeLegal(Dst) &&
+ (Src->getScalarSizeInBits() <= ST->getELen()) &&
+ (Dst->getScalarSizeInBits() <= ST->getELen());
+
+ // FIXME: Need to compute legalizing cost for illegal types.
+ if (!IsVectorType || !IsTypeLegal)
+ return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
+
+ int ISD = TLI->InstructionOpcodeToISD(Opcode);
+ assert(ISD && "Invalid opcode");
+
+ // FIXME: Need to consider vsetvli and lmul.
+ int PowDiff = (int)Log2_32(Dst->getScalarSizeInBits()) -
+ (int)Log2_32(Src->getScalarSizeInBits());
+ switch (ISD) {
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND:
+ if (Src->getScalarSizeInBits() == 1) {
+ // We do not use vsext/vzext to extend from mask vector.
+ // Instead we use the following instructions to extend from mask vector:
+ // vmv.v.i v8, 0
+ // vmerge.vim v8, v8, -1, v0
+ return 2;
}
+ return 1;
+ case ISD::TRUNCATE:
+ if (Dst->getScalarSizeInBits() == 1) {
+ // We do not use several vncvt to truncate to mask vector. So we could
+ // not use PowDiff to calculate it.
+ // Instead we use the following instructions to truncate to mask vector:
+ // vand.vi v8, v8, 1
+ // vmsne.vi v0, v8, 0
+ return 2;
+ }
+ [[fallthrough]];
+ case ISD::FP_EXTEND:
+ case ISD::FP_ROUND:
+ // Counts of narrow/widen instructions.
+ return std::abs(PowDiff);
+ case ISD::FP_TO_SINT:
+ case ISD::FP_TO_UINT:
+ case ISD::SINT_TO_FP:
+ case ISD::UINT_TO_FP:
+ if (Src->getScalarSizeInBits() == 1 || Dst->getScalarSizeInBits() == 1) {
+ // The cost of convert from or to mask vector is different from other
+ // cases. We could not use PowDiff to calculate it.
+ // For mask vector to fp, we should use the following instructions:
+ // vmv.v.i v8, 0
+ // vmerge.vim v8, v8, -1, v0
+ // vfcvt.f.x.v v8, v8
+
+ // And for fp vector to mask, we use:
+ // vfncvt.rtz.x.f.w v9, v8
+ // vand.vi v8, v9, 1
+ // vmsne.vi v0, v8, 0
+ return 3;
+ }
+ if (std::abs(PowDiff) <= 1)
+ return 1;
+ // Backend could lower (v[sz]ext i8 to double) to vfcvt(v[sz]ext.f8 i8),
+ // so it only need two conversion.
+ if (Src->isIntOrIntVectorTy())
+ return 2;
+ // Counts of narrow/widen instructions.
+ return std::abs(PowDiff);
}
- return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}
unsigned RISCVTTIImpl::getEstimatedVLFor(VectorType *Ty) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
} | ||
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're missing a return if none of the cases in the switch match
llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp:967:1: warning: non-void function does not return a value in all control paths [-Wreturn-type]
967 | }
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the trouble. I fix it with commit 5dc0c75
Hi, This patch is causing an assertion error when building builtins-riscv64-unknown-linux-gnu:
This issue still persist after your fix forward patch landed. Could you revert your changes and fix them and reland them? Entire stdout from the build: https://logs.chromium.org/logs/fuchsia/buildbucket/cr-buildbucket/8752425016802502289/+/u/clang/build/stdout Thanks. |
Shbould be fixed after 2fbc40d |
Thanks @topperc and @zeroomega ! |
To reduce the indentation by using early returns, this patch hoist the return for illegal type and non vector type earlier.
It should mostly be an NFC.