diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 20290c958a70e..2854dd4d96746 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -24615,7 +24615,8 @@ void AArch64TargetLowering::ReplaceBITCASTResults( return; } - if (SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1) + if (SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 && + !VT.isVector()) return replaceBoolVectorBitcast(N, Results, DAG); if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16)) diff --git a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll index 1b22e2f900ddb..557aa010b3a7d 100644 --- a/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll +++ b/llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll @@ -489,3 +489,31 @@ define i6 @no_combine_illegal_num_elements(<6 x i32> %vec) { %bitmask = bitcast <6 x i1> %cmp_result to i6 ret i6 %bitmask } + +; Only apply the combine when casting a vector to a scalar. +define <2 x i8> @vector_to_vector_cast(<16 x i1> %arg) nounwind { +; CHECK-LABEL: vector_to_vector_cast: +; CHECK: ; %bb.0: +; CHECK-NEXT: sub sp, sp, #16 +; CHECK-NEXT: shl.16b v0, v0, #7 +; CHECK-NEXT: Lloh36: +; CHECK-NEXT: adrp x8, lCPI20_0@PAGE +; CHECK-NEXT: Lloh37: +; CHECK-NEXT: ldr q1, [x8, lCPI20_0@PAGEOFF] +; CHECK-NEXT: add x8, sp, #14 +; CHECK-NEXT: cmlt.16b v0, v0, #0 +; CHECK-NEXT: and.16b v0, v0, v1 +; CHECK-NEXT: ext.16b v1, v0, v0, #8 +; CHECK-NEXT: zip1.16b v0, v0, v1 +; CHECK-NEXT: addv.8h h0, v0 +; CHECK-NEXT: str h0, [sp, #14] +; CHECK-NEXT: ld1.b { v0 }[0], [x8] +; CHECK-NEXT: orr x8, x8, #0x1 +; CHECK-NEXT: ld1.b { v0 }[4], [x8] +; CHECK-NEXT: ; kill: def $d0 killed $d0 killed $q0 +; CHECK-NEXT: add sp, sp, #16 +; CHECK-NEXT: ret +; CHECK-NEXT: .loh AdrpLdr Lloh36, Lloh37 + %bc = bitcast <16 x i1> %arg to <2 x i8> + ret <2 x i8> %bc +}