Skip to content

Commit

Permalink
[mips][msa] Accept more values for constant splats
Browse files Browse the repository at this point in the history
This patches teaches the MIPS backend to accept more values for constant
splats. Previously, only 10 bit signed immediates or values that could be
loaded using an ldi.[bhwd] instruction would be acceptted. This patch relaxes
that constraint so that any constant value that be splatted is accepted.

As a result, the constant pool is used less for vector operations, and the
suite of bit manipulation instructions b(clr|set|neg)i can now be used with
the full range of their immediate operand.

Reviewers: slthakur

Differential Revision: https://reviews.llvm.org/D30640

llvm-svn: 297457
  • Loading branch information
Simon Dardis committed Mar 10, 2017
1 parent 94fb0bb commit 7090d14
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 29 deletions.
235 changes: 229 additions & 6 deletions llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp
Expand Up @@ -934,6 +934,9 @@ bool MipsSEDAGToDAGISel::trySelect(SDNode *Node) {
// same set/ of registers. Similarly, ldi.h isn't capable of producing {
// 0x00000000, 0x00000001, 0x00000000, 0x00000001 } but 'ldi.d wd, 1' can.

const MipsABIInfo &ABI =
static_cast<const MipsTargetMachine &>(TM).getABI();

BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(Node);
APInt SplatValue, SplatUndef;
unsigned SplatBitSize;
Expand Down Expand Up @@ -971,13 +974,233 @@ bool MipsSEDAGToDAGISel::trySelect(SDNode *Node) {
break;
}

if (!SplatValue.isSignedIntN(10))
return false;

SDValue Imm = CurDAG->getTargetConstant(SplatValue, DL,
ViaVecTy.getVectorElementType());
SDNode *Res;

SDNode *Res = CurDAG->getMachineNode(LdiOp, DL, ViaVecTy, Imm);
// If we have a signed 10 bit integer, we can splat it directly.
//
// If we have something bigger we can synthesize the value into a GPR and
// splat from there.
if (SplatValue.isSignedIntN(10)) {
SDValue Imm = CurDAG->getTargetConstant(SplatValue, DL,
ViaVecTy.getVectorElementType());

Res = CurDAG->getMachineNode(LdiOp, DL, ViaVecTy, Imm);
} else if (SplatValue.isSignedIntN(16) &&
((ABI.IsO32() && SplatBitSize < 64) ||
(ABI.IsN32() || ABI.IsN64()))) {
// Only handle signed 16 bit values when the element size is GPR width.
// MIPS64 can handle all the cases but MIPS32 would need to handle
// negative cases specifically here. Instead, handle those cases as
// 64bit values.

bool Is32BitSplat = ABI.IsO32() || SplatBitSize < 64;
const unsigned ADDiuOp = Is32BitSplat ? Mips::ADDiu : Mips::DADDiu;
const MVT SplatMVT = Is32BitSplat ? MVT::i32 : MVT::i64;
SDValue ZeroVal = CurDAG->getRegister(
Is32BitSplat ? Mips::ZERO : Mips::ZERO_64, SplatMVT);

const unsigned FILLOp =
SplatBitSize == 16
? Mips::FILL_H
: (SplatBitSize == 32 ? Mips::FILL_W
: (SplatBitSize == 64 ? Mips::FILL_D : 0));

assert(FILLOp != 0 && "Unknown FILL Op for splat synthesis!");
assert((!ABI.IsO32() || (ABI.IsO32() && FILLOp != Mips::FILL_D)) &&
"Attempting to use fill.d on MIPS32!");

const unsigned Lo = SplatValue.getLoBits(16).getZExtValue();
SDValue LoVal = CurDAG->getTargetConstant(Lo, DL, SplatMVT);

Res = CurDAG->getMachineNode(ADDiuOp, DL, SplatMVT, ZeroVal, LoVal);
Res = CurDAG->getMachineNode(FILLOp, DL, ViaVecTy, SDValue(Res, 0));

} else if (SplatValue.isSignedIntN(32) && SplatBitSize == 32) {
// Only handle the cases where the splat size agrees with the size
// of the SplatValue here.
const unsigned Lo = SplatValue.getLoBits(16).getZExtValue();
const unsigned Hi = SplatValue.lshr(16).getLoBits(16).getZExtValue();
SDValue ZeroVal = CurDAG->getRegister(Mips::ZERO, MVT::i32);

SDValue LoVal = CurDAG->getTargetConstant(Lo, DL, MVT::i32);
SDValue HiVal = CurDAG->getTargetConstant(Hi, DL, MVT::i32);

if (Hi)
Res = CurDAG->getMachineNode(Mips::LUi, DL, MVT::i32, HiVal);

if (Lo)
Res = CurDAG->getMachineNode(Mips::ORi, DL, MVT::i32,
Hi ? SDValue(Res, 0) : ZeroVal, LoVal);

assert((Hi || Lo) && "Zero case reached 32 bit case splat synthesis!");
Res = CurDAG->getMachineNode(Mips::FILL_W, DL, MVT::v4i32, SDValue(Res, 0));

} else if (SplatValue.isSignedIntN(32) && SplatBitSize == 64 &&
(ABI.IsN32() || ABI.IsN64())) {
// N32 and N64 can perform some tricks that O32 can't for signed 32 bit
// integers due to having 64bit registers. lui will cause the necessary
// zero/sign extension.
const unsigned Lo = SplatValue.getLoBits(16).getZExtValue();
const unsigned Hi = SplatValue.lshr(16).getLoBits(16).getZExtValue();
SDValue ZeroVal = CurDAG->getRegister(Mips::ZERO, MVT::i32);

SDValue LoVal = CurDAG->getTargetConstant(Lo, DL, MVT::i32);
SDValue HiVal = CurDAG->getTargetConstant(Hi, DL, MVT::i32);

if (Hi)
Res = CurDAG->getMachineNode(Mips::LUi, DL, MVT::i32, HiVal);

if (Lo)
Res = CurDAG->getMachineNode(Mips::ORi, DL, MVT::i32,
Hi ? SDValue(Res, 0) : ZeroVal, LoVal);

Res = CurDAG->getMachineNode(
Mips::SUBREG_TO_REG, DL, MVT::i64,
CurDAG->getTargetConstant(((Hi >> 15) & 0x1), DL, MVT::i64),
SDValue(Res, 0),
CurDAG->getTargetConstant(Mips::sub_32, DL, MVT::i64));

Res =
CurDAG->getMachineNode(Mips::FILL_D, DL, MVT::v2i64, SDValue(Res, 0));

} else if (SplatValue.isSignedIntN(64)) {
// If we have a 64 bit Splat value, we perform a similar sequence to the
// above:
//
// MIPS32: MIPS64:
// lui $res, %highest(val) lui $res, %highest(val)
// ori $res, $res, %higher(val) ori $res, $res, %higher(val)
// lui $res2, %hi(val) lui $res2, %hi(val)
// ori $res2, %res2, %lo(val) ori $res2, %res2, %lo(val)
// $res3 = fill $res2 dinsu $res, $res2, 0, 32
// $res4 = insert.w $res3[1], $res fill.d $res
// splat.d $res4, 0
//
// The ability to use dinsu is guaranteed as MSA requires MIPSR5. This saves
// having to materialize the value by shifts and ors.
//
// FIXME: Implement the preferred sequence for MIPS64R6:
//
// MIPS64R6:
// ori $res, $zero, %lo(val)
// daui $res, $res, %hi(val)
// dahi $res, $res, %higher(val)
// dati $res, $res, %highest(cal)
// fill.d $res
//

const unsigned Lo = SplatValue.getLoBits(16).getZExtValue();
const unsigned Hi = SplatValue.lshr(16).getLoBits(16).getZExtValue();
const unsigned Higher = SplatValue.lshr(32).getLoBits(16).getZExtValue();
const unsigned Highest = SplatValue.lshr(48).getLoBits(16).getZExtValue();

SDValue LoVal = CurDAG->getTargetConstant(Lo, DL, MVT::i32);
SDValue HiVal = CurDAG->getTargetConstant(Hi, DL, MVT::i32);
SDValue HigherVal = CurDAG->getTargetConstant(Higher, DL, MVT::i32);
SDValue HighestVal = CurDAG->getTargetConstant(Highest, DL, MVT::i32);
SDValue ZeroVal = CurDAG->getRegister(Mips::ZERO, MVT::i32);

// Independent of whether we're targeting MIPS64 or not, the basic
// operations are the same. Also, directly use the $zero register if
// the 16 bit chunk is zero.
//
// For optimization purposes we always synthesize the splat value as
// an i32 value, then if we're targetting MIPS64, use SUBREG_TO_REG
// just before combining the values with dinsu to produce an i64. This
// enables SelectionDAG to aggressively share components of splat values
// where possible.
//
// FIXME: This is the general constant synthesis problem. This code
// should be factored out into a class shared between all the
// classes that need it. Specifically, for a splat size of 64
// bits that's a negative number we can do better than LUi/ORi
// for the upper 32bits.

if (Hi)
Res = CurDAG->getMachineNode(Mips::LUi, DL, MVT::i32, HiVal);

if (Lo)
Res = CurDAG->getMachineNode(Mips::ORi, DL, MVT::i32,
Hi ? SDValue(Res, 0) : ZeroVal, LoVal);

SDNode *HiRes;
if (Highest)
HiRes = CurDAG->getMachineNode(Mips::LUi, DL, MVT::i32, HighestVal);

if (Higher)
HiRes = CurDAG->getMachineNode(Mips::ORi, DL, MVT::i32,
Highest ? SDValue(HiRes, 0) : ZeroVal,
HigherVal);


if (ABI.IsO32()) {
Res = CurDAG->getMachineNode(Mips::FILL_W, DL, MVT::v4i32,
(Hi || Lo) ? SDValue(Res, 0) : ZeroVal);

Res = CurDAG->getMachineNode(
Mips::INSERT_W, DL, MVT::v4i32, SDValue(Res, 0),
(Highest || Higher) ? SDValue(HiRes, 0) : ZeroVal,
CurDAG->getTargetConstant(1, DL, MVT::i32));

const TargetLowering *TLI = getTargetLowering();
const TargetRegisterClass *RC =
TLI->getRegClassFor(ViaVecTy.getSimpleVT());

Res = CurDAG->getMachineNode(
Mips::COPY_TO_REGCLASS, DL, ViaVecTy, SDValue(Res, 0),
CurDAG->getTargetConstant(RC->getID(), DL, MVT::i32));

Res = CurDAG->getMachineNode(
Mips::SPLATI_D, DL, MVT::v2i64, SDValue(Res, 0),
CurDAG->getTargetConstant(0, DL, MVT::i32));
} else if (ABI.IsN64() || ABI.IsN32()) {

SDValue Zero64Val = CurDAG->getRegister(Mips::ZERO_64, MVT::i64);
const bool HiResNonZero = Highest || Higher;
const bool ResNonZero = Hi || Lo;

if (HiResNonZero)
HiRes = CurDAG->getMachineNode(
Mips::SUBREG_TO_REG, DL, MVT::i64,
CurDAG->getTargetConstant(((Highest >> 15) & 0x1), DL, MVT::i64),
SDValue(HiRes, 0),
CurDAG->getTargetConstant(Mips::sub_32, DL, MVT::i64));

if (ResNonZero)
Res = CurDAG->getMachineNode(
Mips::SUBREG_TO_REG, DL, MVT::i64,
CurDAG->getTargetConstant(((Hi >> 15) & 0x1), DL, MVT::i64),
SDValue(Res, 0),
CurDAG->getTargetConstant(Mips::sub_32, DL, MVT::i64));

// We have 3 cases:
// The HiRes is nonzero but Res is $zero => dsll32 HiRes, 0
// The Res is nonzero but HiRes is $zero => dinsu Res, $zero, 32, 32
// Both are non zero => dinsu Res, HiRes, 32, 32
//
// The obvious "missing" case is when both are zero, but that case is
// handled by the ldi case.
if (ResNonZero) {
SDValue Ops[4] = {HiResNonZero ? SDValue(HiRes, 0) : Zero64Val,
CurDAG->getTargetConstant(64, DL, MVT::i32),
CurDAG->getTargetConstant(32, DL, MVT::i32),
SDValue(Res, 0)};

Res = CurDAG->getMachineNode(Mips::DINSU, DL, MVT::i64, Ops);
} else if (HiResNonZero) {
Res = CurDAG->getMachineNode(
Mips::DSLL32, DL, MVT::i64, SDValue(HiRes, 0),
CurDAG->getTargetConstant(0, DL, MVT::i32));
} else
llvm_unreachable(
"Zero splat value handled by non-zero 64bit splat synthesis!");

Res = CurDAG->getMachineNode(Mips::FILL_D, DL, MVT::v2i64, SDValue(Res, 0));
} else
llvm_unreachable("Unknown ABI in MipsISelDAGToDAG!");

} else
return false;

if (ResVecTy != ViaVecTy) {
// If LdiOp is writing to a different register class to ResVecTy, then
Expand Down
9 changes: 4 additions & 5 deletions llvm/lib/Target/Mips/MipsSEISelLowering.cpp
Expand Up @@ -2529,11 +2529,10 @@ SDValue MipsSETargetLowering::lowerBUILD_VECTOR(SDValue Op,
SplatBitSize != 64)
return SDValue();

// If the value fits into a simm10 then we can use ldi.[bhwd]
// However, if it isn't an integer type we will have to bitcast from an
// integer type first. Also, if there are any undefs, we must lower them
// to defined values first.
if (ResTy.isInteger() && !HasAnyUndefs && SplatValue.isSignedIntN(10))
// If the value isn't an integer type we will have to bitcast
// from an integer type first. Also, if there are any undefs, we must
// lower them to defined values first.
if (ResTy.isInteger() && !HasAnyUndefs)
return Op;

EVT ViaVecTy;
Expand Down
59 changes: 45 additions & 14 deletions llvm/test/CodeGen/Mips/msa/basic_operations.ll
@@ -1,9 +1,9 @@
; RUN: llc -march=mips -mattr=+msa,+fp64 -relocation-model=pic \
; RUN: -verify-machineinstrs < %s | \
; RUN: FileCheck -check-prefixes=ALL,O32,MIPS32,ALL-BE %s
; RUN: FileCheck -check-prefixes=ALL,O32,MIPS32,ALL-BE,O32-BE %s
; RUN: llc -march=mipsel -mattr=+msa,+fp64 -relocation-model=pic \
; RUN: -verify-machineinstrs < %s | \
; RUN: FileCheck -check-prefixes=ALL,O32,MIPS32,ALL-LE %s
; RUN: FileCheck -check-prefixes=ALL,O32,MIPS32,ALL-LE,O32-LE %s
; RUN: llc -march=mips64 -target-abi n32 -mattr=+msa,+fp64 \
; RUN: -relocation-model=pic -verify-machineinstrs < %s | \
; RUN: FileCheck -check-prefixes=ALL,N32,MIPS64,ALL-BE %s
Expand Down Expand Up @@ -58,10 +58,19 @@ define void @const_v16i8() nounwind {
; ALL-DAG: fill.w [[R1:\$w[0-9]+]], [[R2]]

store volatile <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 8, i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 8>, <16 x i8>*@v16i8
; O32: addiu [[G_PTR:\$[0-9]+]], {{.*}}, %lo($
; N32: addiu [[G_PTR:\$[0-9]+]], {{.*}}, %got_ofst(.L
; N64: daddiu [[G_PTR:\$[0-9]+]], {{.*}}, %got_ofst(.L
; ALL: ld.b [[R1:\$w[0-9]+]], 0([[G_PTR]])
; ALL-BE-DAG: lui [[R3:\$[0-9]+]], 1286
; ALL-LE-DAG: lui [[R3:\$[0-9]+]], 2055
; ALL-BE-DAG: ori [[R4:\$[0-9]+]], [[R3]], 1800
; ALL-LE-DAG: ori [[R4:\$[0-9]+]], [[R3]], 1541
; O32-BE: fill.w [[R1:\$w[0-9]+]], [[R4]]

; O32: insert.w [[R1]][1], [[R2]]
; O32: splati.d $w{{.*}}, [[R1]][0]

; MIPS64-BE: dinsu [[R4]], [[R2]], 32, 32
; MIPS64-LE: dinsu [[R2]], [[R4]], 32, 32
; MIPS64-BE: fill.d $w{{.*}}, [[R4]]
; MIPS64-LE: fill.d $w{{.*}}, [[R2]]

ret void
}
Expand Down Expand Up @@ -92,10 +101,19 @@ define void @const_v8i16() nounwind {
; ALL-DAG: fill.w [[R1:\$w[0-9]+]], [[R2]]

store volatile <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 1, i16 2, i16 3, i16 4>, <8 x i16>*@v8i16
; O32: addiu [[G_PTR:\$[0-9]+]], {{.*}}, %lo($
; N32: addiu [[G_PTR:\$[0-9]+]], {{.*}}, %got_ofst(.L
; N64: daddiu [[G_PTR:\$[0-9]+]], {{.*}}, %got_ofst(.L
; ALL: ld.h [[R1:\$w[0-9]+]], 0([[G_PTR]])
; ALL-BE-DAG: lui [[R3:\$[0-9]+]], 3
; ALL-LE-DAG: lui [[R3:\$[0-9]+]], 4
; ALL-BE-DAG: ori [[R4:\$[0-9]+]], [[R3]], 4
; ALL-LE-DAG: ori [[R4:\$[0-9]+]], [[R3]], 3

; O32-BE: fill.w [[R1:\$w[0-9]+]], [[R4]]
; O32: insert.w [[R1]][1], [[R2]]
; O32: splati.d $w{{.*}}, [[R1]][0]

; MIPS64-BE: dinsu [[R4]], [[R2]], 32, 32
; MIPS64-LE: dinsu [[R2]], [[R4]], 32, 32
; MIPS64-BE: fill.d $w{{.*}}, [[R4]]
; MIPS64-LE: fill.d $w{{.*}}, [[R2]]

ret void
}
Expand All @@ -122,10 +140,23 @@ define void @const_v4i32() nounwind {
; ALL: ldi.h [[R1:\$w[0-9]+]], 1

store volatile <4 x i32> <i32 1, i32 2, i32 1, i32 2>, <4 x i32>*@v4i32
; O32: addiu [[G_PTR:\$[0-9]+]], {{.*}}, %lo($
; N32: addiu [[G_PTR:\$[0-9]+]], {{.*}}, %got_ofst(.L
; N64: daddiu [[G_PTR:\$[0-9]+]], {{.*}}, %got_ofst(.L
; ALL: ld.w [[R1:\$w[0-9]+]], 0([[G_PTR]])
; -BE-DAG: ori [[R2:\$[0-9]+]], $zero, 1
; O32-BE-DAG: ori [[R3:\$[0-9]+]], $zero, 1
; O32-BE-DAG: ori [[R4:\$[0-9]+]], $zero, 2
; O32-LE-DAG: ori [[R3:\$[0-9]+]], $zero, 2
; O32-LE-DAG: ori [[R4:\$[0-9]+]], $zero, 1
; O32: fill.w [[W0:\$w[0-9]+]], [[R4]]
; O32: insert.w [[W0]][1], [[R3]]
; O32: splati.d [[W1:\$w[0-9]+]], [[W0]]

; MIPS64-DAG: ori [[R5:\$[0-9]+]], $zero, 2
; MIPS64-DAG: ori [[R6:\$[0-9]+]], $zero, 1

; MIPS64-BE: dinsu [[R5]], [[R6]], 32, 32
; MIPS64-LE: dinsu [[R6]], [[R5]], 32, 32
; MIPS64-BE: fill.d $w{{.*}}, [[R4]]
; MIPS64-LE: fill.d $w{{.*}}, [[R2]]


store volatile <4 x i32> <i32 3, i32 4, i32 5, i32 6>, <4 x i32>*@v4i32
; O32: addiu [[G_PTR:\$[0-9]+]], {{.*}}, %lo($
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/Mips/msa/immediates.ll
Expand Up @@ -920,7 +920,7 @@ entry:
define void @bclri_d(<2 x i64> * %ptr) {
entry:
; CHECK-LABEL: bclri_d:
; CHECK: and.v
; CHECK: bclri.d
%a = load <2 x i64>, <2 x i64> * %ptr, align 16
%r = call <2 x i64> @llvm.mips.bclri.d(<2 x i64> %a, i32 16)
store <2 x i64> %r, <2 x i64> * %ptr, align 16
Expand All @@ -930,7 +930,7 @@ entry:
define void @binsli_d(<2 x i64> * %ptr, <2 x i64> * %ptr2) {
entry:
; CHECK-LABEL: binsli_d:
; CHECK: bsel.v
; CHECK: binsli.d
%a = load <2 x i64>, <2 x i64> * %ptr, align 16
%b = load <2 x i64>, <2 x i64> * %ptr2, align 16
%r = call <2 x i64> @llvm.mips.binsli.d(<2 x i64> %a, <2 x i64> %b, i32 4)
Expand All @@ -952,7 +952,7 @@ entry:
define void @bnegi_d(<2 x i64> * %ptr) {
entry:
; CHECK-LABEL: bnegi_d:
; CHECK: xor.v
; CHECK: bnegi.d
%a = load <2 x i64>, <2 x i64> * %ptr, align 16
%r = call <2 x i64> @llvm.mips.bnegi.d(<2 x i64> %a, i32 9)
store <2 x i64> %r, <2 x i64> * %ptr, align 16
Expand All @@ -962,7 +962,7 @@ entry:
define void @bseti_d(<2 x i64> * %ptr) {
entry:
; CHECK-LABEL: bseti_d:
; CHECK: or.v
; CHECK: bseti.d
%a = load <2 x i64>, <2 x i64> * %ptr, align 16
%r = call <2 x i64> @llvm.mips.bseti.d(<2 x i64> %a, i32 25)
store <2 x i64> %r, <2 x i64> * %ptr, align 16
Expand Down

0 comments on commit 7090d14

Please sign in to comment.