-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV][TTI] Add checks for invalid cast operations #88854
Conversation
In issue llvm#88802, the LV cost model would query the cost of the TRUNC for source type 2xi1 and destination type 2xi32. This patch adds an early exit check to prevent invalid operations.
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-risc-v Author: Shih-Po Hung (arcbbb) ChangesIn issue #88802, the LV cost model would query the cost of the TRUNC for source type 2xi1 and destination type 2xi32. This patch adds an early exit check to prevent invalid operations. Full diff: https://github.com/llvm/llvm-project/pull/88854.diff 1 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 38304ff90252f0..c4f1c275f63b65 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -956,6 +956,9 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
return getRISCVInstructionCost(Op, DstLT.second, CostKind);
}
case ISD::TRUNCATE:
+ // Early return for invalid operation
+ if (Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+ break;
if (Dst->getScalarSizeInBits() == 1) {
// We do not use several vncvt to truncate to mask vector. So we could
// not use PowDiff to calculate it.
@@ -968,6 +971,13 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
[[fallthrough]];
case ISD::FP_EXTEND:
case ISD::FP_ROUND: {
+ // Early return for invalid operation
+ if ((ISD == ISD::FP_ROUND) &&
+ Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits())
+ break;
+ if ((ISD == ISD::FP_EXTEND) &&
+ Src->getScalarSizeInBits() >= Dst->getScalarSizeInBits())
+ break;
// Counts of narrow/widen instructions.
unsigned SrcEltSize = Src->getScalarSizeInBits();
unsigned DstEltSize = Dst->getScalarSizeInBits();
|
This fixes a crash in #88802. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
If I'm reading the code and discussion right, the LV is querying a truncate with result wider than source? If so, this does not seem like the right fix. If LV is querying an invalid argument combination, that's a LV bug and should be fixed as such. |
@@ -956,6 +956,9 @@ InstructionCost RISCVTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, | |||
return getRISCVInstructionCost(Op, DstLT.second, CostKind); | |||
} | |||
case ISD::TRUNCATE: | |||
// Early return for invalid operation | |||
if (Dst->getScalarSizeInBits() >= Src->getScalarSizeInBits()) | |||
break; | |||
if (Dst->getScalarSizeInBits() == 1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't repeatedly call getScalarSizeInBits on the same object. Use a variable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated. Thanks!
Test case? |
Added in the updates. Thanks! |
Yes, the actual root cause is LV cost model. I think for TTI cost it is better to return invalid cost or delegate it to the base TTI rather than crash on it, |
TTI should assert sanity of the API, and we should fix the caller. It's very likely there's a higher level semantics bug in the LV causing this, and we need to find and fix that one. |
OK, I create #89161 for this. |
In issue #88802, the LV cost model would query the cost of the TRUNC for source type 2xi1 and destination type 2xi32. This patch adds an early exit check to prevent invalid operations.