diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index a6dd83cb199fe..f28ce858a169a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1830,6 +1830,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } case Intrinsic::matrix_multiply: { + // Optimize negation in matrix multiplication. + // -A * -B -> A * B Value *A, *B; if (match(II->getArgOperand(0), m_FNeg(m_Value(A))) && @@ -1838,6 +1840,50 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { replaceOperand(*II, 1, B); return II; } + + Value *Op0 = II->getOperand(0); + Value *Op1 = II->getOperand(1); + Value *OpNotNeg, *NegatedOp; + unsigned NegatedOpArg, OtherOpArg; + if (match(Op0, m_FNeg(m_Value(OpNotNeg)))) { + NegatedOp = Op0; + NegatedOpArg = 0; + OtherOpArg = 1; + } else if (match(Op1, m_FNeg(m_Value(OpNotNeg)))) { + NegatedOp = Op1; + NegatedOpArg = 1; + OtherOpArg = 0; + } else + // Multiplication doesn't have a negated operand. + break; + + // Only optimize if the negated operand has only one use. + if (!NegatedOp->hasOneUse()) + break; + + Value *OtherOp = II->getOperand(OtherOpArg); + VectorType *RetTy = cast(II->getType()); + VectorType *NegatedOpTy = cast(NegatedOp->getType()); + VectorType *OtherOpTy = cast(OtherOp->getType()); + ElementCount NegatedCount = NegatedOpTy->getElementCount(); + ElementCount OtherCount = OtherOpTy->getElementCount(); + ElementCount RetCount = RetTy->getElementCount(); + // (-A) * B -> A * (-B), if it is cheaper to negate B and vice versa. + if (ElementCount::isKnownGT(NegatedCount, OtherCount) && + ElementCount::isKnownLT(OtherCount, RetCount)) { + Value *InverseOtherOp = Builder.CreateFNeg(OtherOp); + replaceOperand(*II, NegatedOpArg, OpNotNeg); + replaceOperand(*II, OtherOpArg, InverseOtherOp); + return II; + } + // (-A) * B -> -(A * B), if it is cheaper to negate the result + if (ElementCount::isKnownGT(NegatedCount, RetCount)) { + SmallVector NewArgs(II->args()); + NewArgs[NegatedOpArg] = OpNotNeg; + Instruction *NewMul = + Builder.CreateIntrinsic(II->getType(), IID, NewArgs, II); + return replaceInstUsesWith(*II, Builder.CreateFNegFMF(NewMul, II)); + } break; } case Intrinsic::fmuladd: { diff --git a/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll b/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll index b0eecd9d5255e..bd1050efc160f 100644 --- a/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll +++ b/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll @@ -4,9 +4,9 @@ ; The result has the fewest vector elements between the result and the two operands so the negation can be moved there define <2 x double> @test_negation_move_to_result(<6 x double> %a, <3 x double> %b) { ; CHECK-LABEL: @test_negation_move_to_result( -; CHECK-NEXT: [[A_NEG:%.*]] = fneg <6 x double> [[A:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1) -; CHECK-NEXT: ret <2 x double> [[RES]] +; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1) +; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x double> [[TMP1]] +; CHECK-NEXT: ret <2 x double> [[TMP2]] ; %a.neg = fneg <6 x double> %a %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1) @@ -17,20 +17,53 @@ define <2 x double> @test_negation_move_to_result(<6 x double> %a, <3 x double> ; Fast flag should be preserved define <2 x double> @test_negation_move_to_result_with_fastflags(<6 x double> %a, <3 x double> %b) { ; CHECK-LABEL: @test_negation_move_to_result_with_fastflags( -; CHECK-NEXT: [[A_NEG:%.*]] = fneg <6 x double> [[A:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1) -; CHECK-NEXT: ret <2 x double> [[RES]] +; CHECK-NEXT: [[TMP1:%.*]] = call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1) +; CHECK-NEXT: [[TMP2:%.*]] = fneg fast <2 x double> [[TMP1]] +; CHECK-NEXT: ret <2 x double> [[TMP2]] ; %a.neg = fneg <6 x double> %a %res = tail call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1) ret <2 x double> %res } +define <2 x double> @test_negation_move_to_result_with_nnan_flag(<6 x double> %a, <3 x double> %b) { +; CHECK-LABEL: @test_negation_move_to_result_with_nnan_flag( +; CHECK-NEXT: [[TMP1:%.*]] = call nnan <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1) +; CHECK-NEXT: [[TMP2:%.*]] = fneg nnan <2 x double> [[TMP1]] +; CHECK-NEXT: ret <2 x double> [[TMP2]] +; + %a.neg = fneg <6 x double> %a + %res = tail call nnan <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1) + ret <2 x double> %res +} + +define <2 x double> @test_negation_move_to_result_with_nsz_flag(<6 x double> %a, <3 x double> %b) { +; CHECK-LABEL: @test_negation_move_to_result_with_nsz_flag( +; CHECK-NEXT: [[TMP1:%.*]] = call nsz <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1) +; CHECK-NEXT: [[TMP2:%.*]] = fneg nsz <2 x double> [[TMP1]] +; CHECK-NEXT: ret <2 x double> [[TMP2]] +; + %a.neg = fneg <6 x double> %a + %res = tail call nsz <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1) + ret <2 x double> %res +} + +define <2 x double> @test_negation_move_to_result_with_fastflag_on_negation(<6 x double> %a, <3 x double> %b) { +; CHECK-LABEL: @test_negation_move_to_result_with_fastflag_on_negation( +; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1) +; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x double> [[TMP1]] +; CHECK-NEXT: ret <2 x double> [[TMP2]] +; + %a.neg = fneg fast<6 x double> %a + %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1) + ret <2 x double> %res +} + ; %b has the fewest vector elements between the result and the two operands so the negation can be moved there define <9 x double> @test_move_negation_to_second_operand(<27 x double> %a, <3 x double> %b) { ; CHECK-LABEL: @test_move_negation_to_second_operand( -; CHECK-NEXT: [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1) +; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]] +; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1) ; CHECK-NEXT: ret <9 x double> [[RES]] ; %a.neg = fneg <27 x double> %a @@ -42,8 +75,8 @@ define <9 x double> @test_move_negation_to_second_operand(<27 x double> %a, <3 x ; Fast flag should be preserved define <9 x double> @test_move_negation_to_second_operand_with_fast_flags(<27 x double> %a, <3 x double> %b) { ; CHECK-LABEL: @test_move_negation_to_second_operand_with_fast_flags( -; CHECK-NEXT: [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call fast <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1) +; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]] +; CHECK-NEXT: [[RES:%.*]] = tail call fast <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1) ; CHECK-NEXT: ret <9 x double> [[RES]] ; %a.neg = fneg <27 x double> %a @@ -54,9 +87,9 @@ define <9 x double> @test_move_negation_to_second_operand_with_fast_flags(<27 x ; The result has the fewest vector elements between the result and the two operands so the negation can be moved there define <2 x double> @test_negation_move_to_result_from_second_operand(<3 x double> %a, <6 x double> %b){ ; CHECK-LABEL: @test_negation_move_to_result_from_second_operand( -; CHECK-NEXT: [[B_NEG:%.*]] = fneg <6 x double> [[B:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> [[A:%.*]], <6 x double> [[B_NEG]], i32 1, i32 3, i32 2) -; CHECK-NEXT: ret <2 x double> [[RES]] +; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> [[A:%.*]], <6 x double> [[B:%.*]], i32 1, i32 3, i32 2) +; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x double> [[TMP1]] +; CHECK-NEXT: ret <2 x double> [[TMP2]] ; %b.neg = fneg <6 x double> %b %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> %a, <6 x double> %b.neg, i32 1, i32 3, i32 2) @@ -66,8 +99,8 @@ define <2 x double> @test_negation_move_to_result_from_second_operand(<3 x doubl ; %a has the fewest vector elements between the result and the two operands so the negation can be moved there define <9 x double> @test_move_negation_to_first_operand(<3 x double> %a, <27 x double> %b) { ; CHECK-LABEL: @test_move_negation_to_first_operand( -; CHECK-NEXT: [[B_NEG:%.*]] = fneg <27 x double> [[B:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double> [[A:%.*]], <27 x double> [[B_NEG]], i32 1, i32 3, i32 9) +; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[A:%.*]] +; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double> [[TMP1]], <27 x double> [[B:%.*]], i32 1, i32 3, i32 9) ; CHECK-NEXT: ret <9 x double> [[RES]] ; %b.neg = fneg <27 x double> %b @@ -234,8 +267,8 @@ define <12 x double> @fneg_with_multiple_uses_2(<15 x double> %a, <20 x double> ; negation should be moved to the second operand given it has the smallest operand count define <72 x double> @chain_of_matrix_mutliplies(<27 x double> %a, <3 x double> %b, <8 x double> %c) { ; CHECK-LABEL: @chain_of_matrix_mutliplies( -; CHECK-NEXT: [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1) +; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]] +; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1) ; CHECK-NEXT: [[RES_2:%.*]] = tail call <72 x double> @llvm.matrix.multiply.v72f64.v9f64.v8f64(<9 x double> [[RES]], <8 x double> [[C:%.*]], i32 9, i32 1, i32 8) ; CHECK-NEXT: ret <72 x double> [[RES_2]] ; @@ -249,11 +282,11 @@ define <72 x double> @chain_of_matrix_mutliplies(<27 x double> %a, <3 x double> ; second negation should be moved to the result of the second multipication define <6 x double> @chain_of_matrix_mutliplies_with_two_negations(<3 x double> %a, <5 x double> %b, <10 x double> %c) { ; CHECK-LABEL: @chain_of_matrix_mutliplies_with_two_negations( -; CHECK-NEXT: [[B_NEG:%.*]] = fneg <5 x double> [[B:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[A:%.*]], <5 x double> [[B_NEG]], i32 3, i32 1, i32 5) -; CHECK-NEXT: [[RES_NEG:%.*]] = fneg <15 x double> [[RES]] -; CHECK-NEXT: [[RES_2:%.*]] = tail call <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double> [[RES_NEG]], <10 x double> [[C:%.*]], i32 3, i32 5, i32 2) -; CHECK-NEXT: ret <6 x double> [[RES_2]] +; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[A:%.*]] +; CHECK-NEXT: [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[TMP1]], <5 x double> [[B:%.*]], i32 3, i32 1, i32 5) +; CHECK-NEXT: [[TMP2:%.*]] = call <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double> [[RES]], <10 x double> [[C:%.*]], i32 3, i32 5, i32 2) +; CHECK-NEXT: [[TMP3:%.*]] = fneg <6 x double> [[TMP2]] +; CHECK-NEXT: ret <6 x double> [[TMP3]] ; %b.neg = fneg <5 x double> %b %res = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> %a, <5 x double> %b.neg, i32 3, i32 1, i32 5) @@ -265,10 +298,10 @@ define <6 x double> @chain_of_matrix_mutliplies_with_two_negations(<3 x double> ; negation should be propagated to the result of the second matrix multiplication define <6 x double> @chain_of_matrix_mutliplies_propagation(<15 x double> %a, <20 x double> %b, <8 x double> %c){ ; CHECK-LABEL: @chain_of_matrix_mutliplies_propagation( -; CHECK-NEXT: [[A_NEG:%.*]] = fneg <15 x double> [[A:%.*]] -; CHECK-NEXT: [[RES:%.*]] = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A_NEG]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4) -; CHECK-NEXT: [[RES_2:%.*]] = tail call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double> [[RES]], <8 x double> [[C:%.*]], i32 3, i32 4, i32 2) -; CHECK-NEXT: ret <6 x double> [[RES_2]] +; CHECK-NEXT: [[TMP1:%.*]] = call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A:%.*]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4) +; CHECK-NEXT: [[TMP2:%.*]] = call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double> [[TMP1]], <8 x double> [[C:%.*]], i32 3, i32 4, i32 2) +; CHECK-NEXT: [[TMP3:%.*]] = fneg <6 x double> [[TMP2]] +; CHECK-NEXT: ret <6 x double> [[TMP3]] ; %a.neg = fneg <15 x double> %a %res = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> %a.neg, <20 x double> %b, i32 3, i32 5, i32 4)