diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 166b0d84a2ce2..040752b92bad3 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7376,6 +7376,22 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) { return ZExtBool; } +void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI, + SDNode *Node) const { + // Live-in physreg copies that are glued to SMSTART are applied as + // implicit-def's in the InstrEmitter. Here we remove them, allowing the + // register allocator to pass call args in callee saved regs, without extra + // copies to avoid these fake clobbers of actually-preserved GPRs. + if (MI.getOpcode() == AArch64::MSRpstatesvcrImm1 || + MI.getOpcode() == AArch64::MSRpstatePseudo) + for (unsigned I = MI.getNumOperands() - 1; I > 0; --I) + if (MachineOperand &MO = MI.getOperand(I); + MO.isReg() && MO.isImplicit() && MO.isDef() && + (AArch64::GPR32RegClass.contains(MO.getReg()) || + AArch64::GPR64RegClass.contains(MO.getReg()))) + MI.removeOperand(I); +} + SDValue AArch64TargetLowering::changeStreamingMode( SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue, SDValue PStateSM, bool Entry) const { diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 169b0dbab65cd..357c9fe3f7025 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -975,6 +975,9 @@ class AArch64TargetLowering : public TargetLowering { const SDLoc &DL, SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const override; + void AdjustInstrPostInstrSelection(MachineInstr &MI, + SDNode *Node) const override; + SDValue LowerCall(CallLoweringInfo & /*CLI*/, SmallVectorImpl<SDValue> &InVals) const override; diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index 
bb9464a8d2e1c..8a76690a9add4 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -203,7 +203,9 @@ def : Pat<(i64 (int_aarch64_sme_get_tpidr2)), def MSRpstatePseudo : Pseudo<(outs), (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>, - Sched<[WriteSys]>; + Sched<[WriteSys]> { + let hasPostISelHook = 1; +} def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)), (MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>; diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td index 6c9b1f11a4dec..777d64d888edb 100644 --- a/llvm/lib/Target/AArch64/SMEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td @@ -221,6 +221,7 @@ def MSRpstatesvcrImm1 let Inst{11-9} = pstatefield; let Inst{8} = imm; let Inst{7-5} = 0b011; // op2 + let hasPostISelHook = 1; } def : InstAlias<"smstart", (MSRpstatesvcrImm1 0b011, 0b1)>; diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-compatible-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-compatible-interface.ll index 1ad6b189d6fa5..5d0c9127d3ebb 100644 --- a/llvm/test/CodeGen/AArch64/sme-streaming-compatible-interface.ll +++ b/llvm/test/CodeGen/AArch64/sme-streaming-compatible-interface.ll @@ -436,3 +436,56 @@ define void @disable_tailcallopt() "aarch64_pstate_sm_compatible" nounwind { tail call void @normal_callee(); ret void; } + +define void @call_to_non_streaming_pass_args(ptr nocapture noundef readnone %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2) "aarch64_pstate_sm_compatible" { +; CHECK-LABEL: call_to_non_streaming_pass_args: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: sub sp, sp, #112 +; CHECK-NEXT: stp d15, d14, [sp, #32] // 16-byte Folded Spill +; CHECK-NEXT: stp d13, d12, [sp, #48] // 16-byte Folded Spill +; 
CHECK-NEXT: stp d11, d10, [sp, #64] // 16-byte Folded Spill +; CHECK-NEXT: stp d9, d8, [sp, #80] // 16-byte Folded Spill +; CHECK-NEXT: stp x30, x19, [sp, #96] // 16-byte Folded Spill +; CHECK-NEXT: .cfi_def_cfa_offset 112 +; CHECK-NEXT: .cfi_offset w19, -8 +; CHECK-NEXT: .cfi_offset w30, -16 +; CHECK-NEXT: .cfi_offset b8, -24 +; CHECK-NEXT: .cfi_offset b9, -32 +; CHECK-NEXT: .cfi_offset b10, -40 +; CHECK-NEXT: .cfi_offset b11, -48 +; CHECK-NEXT: .cfi_offset b12, -56 +; CHECK-NEXT: .cfi_offset b13, -64 +; CHECK-NEXT: .cfi_offset b14, -72 +; CHECK-NEXT: .cfi_offset b15, -80 +; CHECK-NEXT: stp d2, d3, [sp, #16] // 16-byte Folded Spill +; CHECK-NEXT: mov x8, x1 +; CHECK-NEXT: mov x9, x0 +; CHECK-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill +; CHECK-NEXT: bl __arm_sme_state +; CHECK-NEXT: and x19, x0, #0x1 +; CHECK-NEXT: tbz w19, #0, .LBB10_2 +; CHECK-NEXT: // %bb.1: // %entry +; CHECK-NEXT: smstop sm +; CHECK-NEXT: .LBB10_2: // %entry +; CHECK-NEXT: ldp s0, s1, [sp, #8] // 8-byte Folded Reload +; CHECK-NEXT: mov x0, x9 +; CHECK-NEXT: ldp d2, d3, [sp, #16] // 16-byte Folded Reload +; CHECK-NEXT: mov x1, x8 +; CHECK-NEXT: bl bar +; CHECK-NEXT: tbz w19, #0, .LBB10_4 +; CHECK-NEXT: // %bb.3: // %entry +; CHECK-NEXT: smstart sm +; CHECK-NEXT: .LBB10_4: // %entry +; CHECK-NEXT: ldp x30, x19, [sp, #96] // 16-byte Folded Reload +; CHECK-NEXT: ldp d9, d8, [sp, #80] // 16-byte Folded Reload +; CHECK-NEXT: ldp d11, d10, [sp, #64] // 16-byte Folded Reload +; CHECK-NEXT: ldp d13, d12, [sp, #48] // 16-byte Folded Reload +; CHECK-NEXT: ldp d15, d14, [sp, #32] // 16-byte Folded Reload +; CHECK-NEXT: add sp, sp, #112 +; CHECK-NEXT: ret +entry: + call void @bar(ptr noundef nonnull %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2) + ret void +} + +declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 noundef, float noundef, float noundef, double noundef, double noundef) diff --git 
a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll index 102ed896ce7b3..dd7d6470ad7b0 100644 --- a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll +++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll @@ -368,15 +368,11 @@ define i8 @call_to_non_streaming_pass_sve_objects(ptr nocapture noundef readnone ; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill ; CHECK-NEXT: stp x29, x30, [sp, #64] // 16-byte Folded Spill ; CHECK-NEXT: addvl sp, sp, #-3 -; CHECK-NEXT: rdsvl x8, #1 -; CHECK-NEXT: addvl x9, sp, #2 -; CHECK-NEXT: addvl x10, sp, #1 -; CHECK-NEXT: mov x11, sp +; CHECK-NEXT: rdsvl x3, #1 +; CHECK-NEXT: addvl x0, sp, #2 +; CHECK-NEXT: addvl x1, sp, #1 +; CHECK-NEXT: mov x2, sp ; CHECK-NEXT: smstop sm -; CHECK-NEXT: mov x0, x9 -; CHECK-NEXT: mov x1, x10 -; CHECK-NEXT: mov x2, x11 -; CHECK-NEXT: mov x3, x8 ; CHECK-NEXT: bl foo ; CHECK-NEXT: smstart sm ; CHECK-NEXT: ptrue p0.b @@ -400,8 +396,37 @@ entry: ret i8 %vecext } +define void @call_to_non_streaming_pass_args(ptr nocapture noundef readnone %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2) #0 { +; CHECK-LABEL: call_to_non_streaming_pass_args: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: sub sp, sp, #112 +; CHECK-NEXT: stp d15, d14, [sp, #32] // 16-byte Folded Spill +; CHECK-NEXT: stp d13, d12, [sp, #48] // 16-byte Folded Spill +; CHECK-NEXT: stp d11, d10, [sp, #64] // 16-byte Folded Spill +; CHECK-NEXT: stp d9, d8, [sp, #80] // 16-byte Folded Spill +; CHECK-NEXT: str x30, [sp, #96] // 8-byte Folded Spill +; CHECK-NEXT: stp d2, d3, [sp, #16] // 16-byte Folded Spill +; CHECK-NEXT: stp s0, s1, [sp, #8] // 8-byte Folded Spill +; CHECK-NEXT: smstop sm +; CHECK-NEXT: ldp s0, s1, [sp, #8] // 8-byte Folded Reload +; CHECK-NEXT: ldp d2, d3, [sp, #16] // 16-byte Folded Reload +; CHECK-NEXT: bl bar +; CHECK-NEXT: smstart sm +; CHECK-NEXT: ldp d9, d8, [sp, #80] // 16-byte Folded 
Reload +; CHECK-NEXT: ldr x30, [sp, #96] // 8-byte Folded Reload +; CHECK-NEXT: ldp d11, d10, [sp, #64] // 16-byte Folded Reload +; CHECK-NEXT: ldp d13, d12, [sp, #48] // 16-byte Folded Reload +; CHECK-NEXT: ldp d15, d14, [sp, #32] // 16-byte Folded Reload +; CHECK-NEXT: add sp, sp, #112 +; CHECK-NEXT: ret +entry: + call void @bar(ptr noundef nonnull %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2) + ret void +} + declare i64 @llvm.aarch64.sme.cntsb() declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef) +declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 noundef, float noundef, float noundef, double noundef, double noundef) attributes #0 = { nounwind vscale_range(1,16) "aarch64_pstate_sm_enabled" }