Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Don't look through EXTRACT_ELEMENT in lowerScalarInsert if the element types are different. #78668

Merged
merged 1 commit into from
Jan 19, 2024

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented Jan 19, 2024

If the element type of we're extracting doesn't match the type we're inserting into, we can't directly insert or extract the subvector.

…e element types are different.

If the element type of we're extracting doesn't match the type
we're inserting into, we can't directly insert or extract the subvector.
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 19, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

If the element type of we're extracting doesn't match the type we're inserting into, we can't directly insert or extract the subvector.


Full diff: https://github.com/llvm/llvm-project/pull/78668.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+16-12)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll (+23)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 019ea93aac9dc6..b41e2f40dc72f0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4040,19 +4040,23 @@ static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL, MVT VT,
   if (Scalar.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
       isNullConstant(Scalar.getOperand(1))) {
     SDValue ExtractedVal = Scalar.getOperand(0);
-    MVT ExtractedVT = ExtractedVal.getSimpleValueType();
-    MVT ExtractedContainerVT = ExtractedVT;
-    if (ExtractedContainerVT.isFixedLengthVector()) {
-      ExtractedContainerVT = getContainerForFixedLengthVector(
-          DAG, ExtractedContainerVT, Subtarget);
-      ExtractedVal = convertToScalableVector(ExtractedContainerVT, ExtractedVal,
-                                             DAG, Subtarget);
-    }
-    if (ExtractedContainerVT.bitsLE(VT))
-      return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, ExtractedVal,
+    // The element types must be the same.
+    if (ExtractedVal.getValueType().getVectorElementType() ==
+        VT.getVectorElementType()) {
+      MVT ExtractedVT = ExtractedVal.getSimpleValueType();
+      MVT ExtractedContainerVT = ExtractedVT;
+      if (ExtractedContainerVT.isFixedLengthVector()) {
+        ExtractedContainerVT = getContainerForFixedLengthVector(
+            DAG, ExtractedContainerVT, Subtarget);
+        ExtractedVal = convertToScalableVector(ExtractedContainerVT,
+                                               ExtractedVal, DAG, Subtarget);
+      }
+      if (ExtractedContainerVT.bitsLE(VT))
+        return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru,
+                           ExtractedVal, DAG.getConstant(0, DL, XLenVT));
+      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal,
                          DAG.getConstant(0, DL, XLenVT));
-    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal,
-                       DAG.getConstant(0, DL, XLenVT));
+    }
   }
 
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll b/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll
index 8fd8c2548ff34d..351c0bab9dca89 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll
@@ -343,3 +343,26 @@ declare i64 @llvm.smax.i64(i64, i64)
 declare i64 @llvm.smin.i64(i64, i64)
 declare float @llvm.maxnum.f32(float ,float)
 declare float @llvm.minnum.f32(float ,float)
+
+define void @crash(<2 x i32> %0) {
+; CHECK-LABEL: crash:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    sb a0, 0(zero)
+; CHECK-NEXT:    ret
+entry:
+  %1 = extractelement <2 x i32> %0, i64 0
+  %2 = tail call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> zeroinitializer)
+  %3 = zext i16 %2 to i32
+  %op.rdx = add i32 %1, %3
+  %conv18.us = trunc i32 %op.rdx to i8
+  store i8 %conv18.us, ptr null, align 1
+  ret void
+}
+declare i16 @llvm.vector.reduce.add.v4i16(<4 x i16>)

Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@topperc topperc merged commit 0ad83bc into llvm:main Jan 19, 2024
4 of 5 checks passed
@topperc topperc deleted the pr/lowerScalarInsert-crash branch January 19, 2024 06:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants