diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 39799ee5d278d..210d0c34c8137 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -16731,7 +16731,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, // extract(dup x) -> x if (N0.getOpcode() == AArch64ISD::DUP) - return DAG.getZExtOrTrunc(N0.getOperand(0), SDLoc(N), VT); + return VT.isInteger() ? DAG.getZExtOrTrunc(N0.getOperand(0), SDLoc(N), VT) + : N0.getOperand(0); // Rewrite for pairwise fadd pattern // (f32 (extract_vector_elt diff --git a/llvm/test/CodeGen/AArch64/aarch64-dup-dot-crash.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-dot-crash.ll new file mode 100644 index 0000000000000..ebd7a8fce9e8a --- /dev/null +++ b/llvm/test/CodeGen/AArch64/aarch64-dup-dot-crash.ll @@ -0,0 +1,39 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2 +; RUN: llc -mtriple=arm64-unknown-unknown < %s -o -| FileCheck %s + +; This test covers a case where extract_vector_elt is selected when DUP is +; generated. Where it tries to generate a ZextOrTrunc node with floating point +; type resulting in a crash. 
+; See https://reviews.llvm.org/D128144#4280024 for context
+define void @dot_product(double %a) {
+; CHECK-LABEL: dot_product:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fmov d1, #1.00000000
+; CHECK-NEXT:    fadd d0, d0, d1
+; CHECK-NEXT:    fadd d0, d0, d1
+; CHECK-NEXT:    movi d1, #0000000000000000
+; CHECK-NEXT:    fadd d0, d0, d1
+; CHECK-NEXT:    fsqrt d0, d0
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    ret
+entry:
+  %fadd = call double @llvm.vector.reduce.fadd.v3f64(double %a, <3 x double> <double 1.000000e+00, double 1.000000e+00, double 0.000000e+00>)
+  %sqrt = call double @llvm.sqrt.f64(double %fadd)
+  %insert = insertelement <3 x double> zeroinitializer, double %sqrt, i64 0
+  %shuffle = shufflevector <3 x double> %insert, <3 x double> zeroinitializer, <3 x i32> zeroinitializer
+  %mul = fmul <3 x double> %shuffle, <double 1.000000e+00, double 1.000000e+00, double 0.000000e+00> ; NOTE(review): constant operand restored after extraction loss; lane 1 must be 1.0 to match the expected output, lanes 0 and 2 are assumed - confirm against upstream
+  %shuffle.1 = extractelement <3 x double> %mul, i64 0
+  %shuffle.2 = extractelement <3 x double> %mul, i64 1
+  %cmp = fcmp ogt double %shuffle.2, 0.000000e+00
+  br i1 %cmp, label %exit, label %bb.1
+
+bb.1:
+  %mul.2 = fmul double %shuffle.1, 0.000000e+00
+  br label %exit
+
+exit:
+  ret void
+}
+
+declare double @llvm.sqrt.f64(double)
+declare double @llvm.vector.reduce.fadd.v3f64(double, <3 x double>)