Expand Up
@@ -88,6 +88,9 @@ AVRTargetLowering::AVRTargetLowering(const AVRTargetMachine &TM,
setOperationAction (ISD::SRA, MVT::i16, Custom);
setOperationAction (ISD::SHL, MVT::i16, Custom);
setOperationAction (ISD::SRL, MVT::i16, Custom);
setOperationAction (ISD::SRA, MVT::i32, Custom);
setOperationAction (ISD::SHL, MVT::i32, Custom);
setOperationAction (ISD::SRL, MVT::i32, Custom);
setOperationAction (ISD::SHL_PARTS, MVT::i16, Expand);
setOperationAction (ISD::SRA_PARTS, MVT::i16, Expand);
setOperationAction (ISD::SRL_PARTS, MVT::i16, Expand);
Expand Down
Expand Up
@@ -247,10 +250,13 @@ const char *AVRTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE (CALL);
NODE (WRAPPER);
NODE (LSL);
NODE (LSLW);
NODE (LSR);
NODE (LSRW);
NODE (ROL);
NODE (ROR);
NODE (ASR);
NODE (ASRW);
NODE (LSLLOOP);
NODE (LSRLOOP);
NODE (ROLLOOP);
Expand Down
Expand Up
@@ -279,6 +285,57 @@ SDValue AVRTargetLowering::LowerShifts(SDValue Op, SelectionDAG &DAG) const {
assert (isPowerOf2_32 (VT.getSizeInBits ()) &&
" Expected power-of-2 shift amount" );
if (VT.getSizeInBits () == 32 ) {
if (!isa<ConstantSDNode>(N->getOperand (1 ))) {
// 32-bit shifts are converted to a loop in IR.
// This should be unreachable.
report_fatal_error (" Expected a constant shift amount!" );
}
SDVTList ResTys = DAG.getVTList (MVT::i16, MVT::i16);
SDValue SrcLo =
DAG.getNode (ISD::EXTRACT_ELEMENT, dl, MVT::i16, Op.getOperand (0 ),
DAG.getConstant (0 , dl, MVT::i16));
SDValue SrcHi =
DAG.getNode (ISD::EXTRACT_ELEMENT, dl, MVT::i16, Op.getOperand (0 ),
DAG.getConstant (1 , dl, MVT::i16));
uint64_t ShiftAmount =
cast<ConstantSDNode>(N->getOperand (1 ))->getZExtValue ();
if (ShiftAmount == 16 ) {
// Special case these two operations because they appear to be used by the
// generic codegen parts to lower 32-bit numbers.
// TODO: perhaps we can lower shift amounts bigger than 16 to a 16-bit
// shift of a part of the 32-bit value?
switch (Op.getOpcode ()) {
case ISD::SHL: {
SDValue Zero = DAG.getConstant (0 , dl, MVT::i16);
return DAG.getNode (ISD::BUILD_PAIR, dl, MVT::i32, Zero, SrcLo);
}
case ISD::SRL: {
SDValue Zero = DAG.getConstant (0 , dl, MVT::i16);
return DAG.getNode (ISD::BUILD_PAIR, dl, MVT::i32, SrcHi, Zero);
}
}
}
SDValue Cnt = DAG.getTargetConstant (ShiftAmount, dl, MVT::i8);
unsigned Opc;
switch (Op.getOpcode ()) {
default :
llvm_unreachable (" Invalid 32-bit shift opcode!" );
case ISD::SHL:
Opc = AVRISD::LSLW;
break ;
case ISD::SRL:
Opc = AVRISD::LSRW;
break ;
case ISD::SRA:
Opc = AVRISD::ASRW;
break ;
}
SDValue Result = DAG.getNode (Opc, dl, ResTys, SrcLo, SrcHi, Cnt);
return DAG.getNode (ISD::BUILD_PAIR, dl, MVT::i32, Result.getValue (0 ),
Result.getValue (1 ));
}
// Expand non-constant shifts to loops.
if (!isa<ConstantSDNode>(N->getOperand (1 ))) {
switch (Op.getOpcode ()) {
Expand Down
Expand Up
@@ -1789,6 +1846,359 @@ MachineBasicBlock *AVRTargetLowering::insertShift(MachineInstr &MI,
return RemBB;
}
// Do a multibyte AVR shift by a constant amount. Insert shift instructions and
// put the output registers in the Regs array.
// Because AVR does not have a normal shift instruction (only a single bit shift
// instruction), we have to emulate this behavior with other instructions.
// It first tries large steps (moving registers around) and then smaller steps
// like single bit shifts.
// Large shifts actually reduce the number of shifted registers, so the below
// algorithms have to work independently of the number of registers that are
// shifted.
// For more information and background, see this blogpost:
// https://aykevl.nl/2021/02/avr-bitshift
//
// Parameters:
//   MI       - The instruction being expanded. New instructions are built with
//              BuildMI(*BB, MI, ...) and take their debug location from MI.
//              MI itself is not erased here.
//   BB       - The basic block that contains MI.
//   Regs     - In/out view of (register, subregister index) pairs, one entry
//              per byte, with the most significant byte at index 0. On return
//              each entry names the register holding the shifted byte; entries
//              produced by this function use subregister index 0.
//   Opc      - ISD::SHL for a left shift, ISD::SRA for an arithmetic right
//              shift. Any other value is handled as a logical right shift.
//   ShiftAmt - Number of bits to shift.
//              NOTE(review): the code only makes sense for
//              0 <= ShiftAmt <= 8 * Regs.size(); confirm callers guarantee it.
static void insertMultibyteShift (MachineInstr &MI, MachineBasicBlock *BB,
MutableArrayRef<std::pair<Register, int >> Regs,
ISD::NodeType Opc, int64_t ShiftAmt) {
const TargetInstrInfo &TII = *BB->getParent ()->getSubtarget ().getInstrInfo ();
const AVRSubtarget &STI = BB->getParent ()->getSubtarget <AVRSubtarget>();
MachineRegisterInfo &MRI = BB->getParent ()->getRegInfo ();
const DebugLoc &dl = MI.getDebugLoc ();
const bool ShiftLeft = Opc == ISD::SHL;
const bool ArithmeticShift = Opc == ISD::SRA;
// Zero a register, for use in later operations.
// This is a virtual-register copy of the subtarget's zero register. Note that
// every (recursive) invocation of this function emits its own copy.
Register ZeroReg = MRI.createVirtualRegister (&AVR::GPR8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::COPY), ZeroReg)
.addReg (STI.getZeroRegister ());
// Handle shift amounts that are 6 or 7 modulo 8. This is a bit more
// complicated than most shifts and is hard to compose with the rest, so these
// are special cased.
// The basic idea is to shift one or two bits in the opposite direction and
// then move registers around to get the correct end result.
if (ShiftLeft && (ShiftAmt % 8 ) >= 6 ) {
// Left shift by 6 or 7 modulo 8.
// Create a slice of the registers we're going to modify, to ease working
// with them. The slice skips the (ShiftAmt / 8) most significant bytes,
// which are shifted out entirely.
size_t ShiftRegsOffset = ShiftAmt / 8 ;
size_t ShiftRegsSize = Regs.size () - ShiftRegsOffset;
MutableArrayRef<std::pair<Register, int >> ShiftRegs =
Regs.slice (ShiftRegsOffset, ShiftRegsSize);
// Shift one to the right, keeping the least significant bit as the carry
// bit. The RORRd emitted below relies on that carry, so nothing that
// changes the carry flag may be inserted in between.
insertMultibyteShift (MI, BB, ShiftRegs, ISD::SRL, 1 );
// Rotate the least significant bit from the carry bit into a new register
// (that starts out zero).
Register LowByte = MRI.createVirtualRegister (&AVR::GPR8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::RORRd), LowByte).addReg (ZeroReg);
// Shift one more to the right if this is a modulo-6 shift.
if (ShiftAmt % 8 == 6 ) {
insertMultibyteShift (MI, BB, ShiftRegs, ISD::SRL, 1 );
Register NewLowByte = MRI.createVirtualRegister (&AVR::GPR8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::RORRd), NewLowByte).addReg (LowByte);
LowByte = NewLowByte;
}
// Move all registers to the left, zeroing the bottom registers as needed.
// ShiftRegs aliases the tail of Regs. Iterating upwards is safe because
// entry I is only written after the higher-indexed entry it reads from.
for (size_t I = 0 ; I < Regs.size (); I++) {
int ShiftRegsIdx = I + 1 ;
if (ShiftRegsIdx < (int )ShiftRegs.size ()) {
Regs[I] = ShiftRegs[ShiftRegsIdx];
} else if (ShiftRegsIdx == (int )ShiftRegs.size ()) {
Regs[I] = std::pair (LowByte, 0 );
} else {
Regs[I] = std::pair (ZeroReg, 0 );
}
}
return ;
}
// Right shift by 6 or 7 modulo 8.
if (!ShiftLeft && (ShiftAmt % 8 ) >= 6 ) {
// Create a view on the registers we're going to modify, to ease working
// with them. The view drops the (ShiftAmt / 8) least significant bytes,
// which are shifted out entirely.
size_t ShiftRegsSize = Regs.size () - (ShiftAmt / 8 );
MutableArrayRef<std::pair<Register, int >> ShiftRegs =
Regs.slice (0 , ShiftRegsSize);
// Shift one to the left. The bit shifted out of the most significant byte
// is left in the carry flag and is consumed by the SBC/ADC emitted below.
insertMultibyteShift (MI, BB, ShiftRegs, ISD::SHL, 1 );
// Sign or zero extend the most significant register into a new register.
// The HighByte is the byte that still has one (or two) bits from the
// original value. The ExtByte is purely a zero/sign extend byte (all bits
// are either 0 or 1).
Register HighByte = MRI.createVirtualRegister (&AVR::GPR8RegClass);
Register ExtByte = 0 ;
if (ArithmeticShift) {
// Sign-extend bit that was shifted out last.
// Subtracting a register from itself with carry yields 0x00 or 0xff
// depending only on the carry flag, which is why both source operands may
// be marked undef.
BuildMI (*BB, MI, dl, TII.get (AVR::SBCRdRr), HighByte)
.addReg (HighByte, RegState::Undef)
.addReg (HighByte, RegState::Undef);
ExtByte = HighByte;
// The bit that was shifted out is the sign bit of the original value, so
// at this point the byte holding it is identical to the sign-extend byte:
// HighByte and ExtByte are the same register.
} else {
// Use the zero register for zero extending.
ExtByte = ZeroReg;
// Rotate most significant bit into a new register (that starts out zero).
BuildMI (*BB, MI, dl, TII.get (AVR::ADCRdRr), HighByte)
.addReg (ExtByte)
.addReg (ExtByte);
}
// Shift one more to the left for modulo 6 shifts.
if (ShiftAmt % 8 == 6 ) {
insertMultibyteShift (MI, BB, ShiftRegs, ISD::SHL, 1 );
// Shift the topmost bit into the HighByte.
Register NewExt = MRI.createVirtualRegister (&AVR::GPR8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::ADCRdRr), NewExt)
.addReg (HighByte)
.addReg (HighByte);
HighByte = NewExt;
}
// Move all to the right, while sign or zero extending.
// ShiftRegs aliases the head of Regs. Iterating downwards is safe because
// entry I is only written after the lower-indexed entry it reads from.
for (int I = Regs.size () - 1 ; I >= 0 ; I--) {
int ShiftRegsIdx = I - (Regs.size () - ShiftRegs.size ()) - 1 ;
if (ShiftRegsIdx >= 0 ) {
Regs[I] = ShiftRegs[ShiftRegsIdx];
} else if (ShiftRegsIdx == -1 ) {
Regs[I] = std::pair (HighByte, 0 );
} else {
Regs[I] = std::pair (ExtByte, 0 );
}
}
return ;
}
// For shift amounts of at least one register, simply rename the registers and
// zero the bottom registers.
while (ShiftLeft && ShiftAmt >= 8 ) {
// Move all registers one to the left.
for (size_t I = 0 ; I < Regs.size () - 1 ; I++) {
Regs[I] = Regs[I + 1 ];
}
// Zero the least significant register.
Regs[Regs.size () - 1 ] = std::pair (ZeroReg, 0 );
// Continue shifts with the leftover registers. Only the local view is
// narrowed here; the entry that is dropped from the view keeps the value
// that was just stored in it.
Regs = Regs.drop_back (1 );
ShiftAmt -= 8 ;
}
// And again, the same for right shifts.
Register ShrExtendReg = 0 ;
if (!ShiftLeft && ShiftAmt >= 8 ) {
if (ArithmeticShift) {
// Sign extend the most significant register into ShrExtendReg.
// Adding the most significant byte to itself moves its top bit into the
// carry flag; subtracting Tmp from itself with carry then produces 0x00
// or 0xff.
ShrExtendReg = MRI.createVirtualRegister (&AVR::GPR8RegClass);
Register Tmp = MRI.createVirtualRegister (&AVR::GPR8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::ADDRdRr), Tmp)
.addReg (Regs[0 ].first , 0 , Regs[0 ].second )
.addReg (Regs[0 ].first , 0 , Regs[0 ].second );
BuildMI (*BB, MI, dl, TII.get (AVR::SBCRdRr), ShrExtendReg)
.addReg (Tmp)
.addReg (Tmp);
} else {
ShrExtendReg = ZeroReg;
}
for (; ShiftAmt >= 8 ; ShiftAmt -= 8 ) {
// Move all registers one to the right.
for (size_t I = Regs.size () - 1 ; I != 0 ; I--) {
Regs[I] = Regs[I - 1 ];
}
// Zero or sign extend the most significant register.
Regs[0 ] = std::pair (ShrExtendReg, 0 );
// Continue shifts with the leftover registers.
Regs = Regs.drop_front (1 );
}
}
// The bigger shifts are already handled above.
assert ((ShiftAmt < 8 ) && " Unexpect shift amount" );
// Shift by four bits, using a complicated swap/eor/andi/eor sequence.
// It only works for logical shifts because the bits shifted in are all
// zeroes.
// To shift a single byte right, it produces code like this:
// swap r0
// andi r0, 0x0f
// For a two-byte (16-bit) shift, it adds the following instructions to shift
// the upper byte into the lower byte:
// swap r1
// eor r0, r1
// andi r1, 0x0f
// eor r0, r1
// For bigger shifts, it repeats the above sequence. For example, for a 3-byte
// (24-bit) shift it adds:
// swap r2
// eor r1, r2
// andi r2, 0x0f
// eor r1, r2
if (!ArithmeticShift && ShiftAmt >= 4 ) {
// Prev tracks the register that currently holds the previously processed
// byte (the one that receives the nibble crossing the byte boundary).
Register Prev = 0 ;
for (size_t I = 0 ; I < Regs.size (); I++) {
// Left shifts walk from index 0 upwards, right shifts from the last index
// downwards.
size_t Idx = ShiftLeft ? I : Regs.size () - I - 1 ;
// SwapReg and AndReg are created in LD8RegClass, unlike the other
// temporaries. NOTE(review): presumably because ANDI only accepts the
// upper registers - confirm against the instruction definitions.
Register SwapReg = MRI.createVirtualRegister (&AVR::LD8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::SWAPRd), SwapReg)
.addReg (Regs[Idx].first , 0 , Regs[Idx].second );
if (I != 0 ) {
Register R = MRI.createVirtualRegister (&AVR::GPR8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::EORRdRr), R)
.addReg (Prev)
.addReg (SwapReg);
Prev = R;
}
Register AndReg = MRI.createVirtualRegister (&AVR::LD8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::ANDIRdK), AndReg)
.addReg (SwapReg)
.addImm (ShiftLeft ? 0xf0 : 0x0f );
if (I != 0 ) {
Register R = MRI.createVirtualRegister (&AVR::GPR8RegClass);
BuildMI (*BB, MI, dl, TII.get (AVR::EORRdRr), R)
.addReg (Prev)
.addReg (AndReg);
// The second EOR completes the previously processed byte, so store it
// back into that byte's slot.
size_t PrevIdx = ShiftLeft ? Idx - 1 : Idx + 1 ;
Regs[PrevIdx] = std::pair (R, 0 );
}
Prev = AndReg;
Regs[Idx] = std::pair (AndReg, 0 );
}
ShiftAmt -= 4 ;
}
// Shift by one. This is the fallback that always works, and the shift
// operation that is used for 1, 2, and 3 bit shifts.
while (ShiftLeft && ShiftAmt) {
// Shift one to the left. Start at the least significant byte (last index)
// with ADD and propagate the carry upwards with ADC.
for (ssize_t I = Regs.size () - 1 ; I >= 0 ; I--) {
Register Out = MRI.createVirtualRegister (&AVR::GPR8RegClass);
Register In = Regs[I].first ;
Register InSubreg = Regs[I].second ;
if (I == (ssize_t )Regs.size () - 1 ) { // first iteration
BuildMI (*BB, MI, dl, TII.get (AVR::ADDRdRr), Out)
.addReg (In, 0 , InSubreg)
.addReg (In, 0 , InSubreg);
} else {
BuildMI (*BB, MI, dl, TII.get (AVR::ADCRdRr), Out)
.addReg (In, 0 , InSubreg)
.addReg (In, 0 , InSubreg);
}
Regs[I] = std::pair (Out, 0 );
}
ShiftAmt--;
}
while (!ShiftLeft && ShiftAmt) {
// Shift one to the right. Start at the most significant byte (index 0)
// with ASR or LSR and propagate the carry downwards with ROR.
for (size_t I = 0 ; I < Regs.size (); I++) {
Register Out = MRI.createVirtualRegister (&AVR::GPR8RegClass);
Register In = Regs[I].first ;
Register InSubreg = Regs[I].second ;
if (I == 0 ) {
unsigned Opc = ArithmeticShift ? AVR::ASRRd : AVR::LSRRd;
BuildMI (*BB, MI, dl, TII.get (Opc), Out).addReg (In, 0 , InSubreg);
} else {
BuildMI (*BB, MI, dl, TII.get (AVR::RORRd), Out).addReg (In, 0 , InSubreg);
}
Regs[I] = std::pair (Out, 0 );
}
ShiftAmt--;
}
// Both single-bit loops above run until ShiftAmt reaches zero, so anything
// left over here indicates a logic error.
if (ShiftAmt != 0 ) {
llvm_unreachable (" don't know how to shift!" ); // sanity check
}
}
// Do a wide (32-bit) shift.
//
// Expands one of the Lsl32 / Lsr32 / Asr32 pseudo instructions into a sequence
// of 8-bit instructions (via insertMultibyteShift) and then erases the pseudo.
//
// Operand layout of MI, as used below:
//   operand 0 - output register pair for the low 16 bits
//   operand 1 - output register pair for the high 16 bits
//   operand 2 - input register pair holding the low 16 bits
//   operand 3 - input register pair holding the high 16 bits
//   operand 4 - immediate shift amount in bits
//
// Returns BB; no new basic blocks are created.
MachineBasicBlock *
AVRTargetLowering::insertWideShift(MachineInstr &MI,
                                   MachineBasicBlock *BB) const {
  const TargetInstrInfo &TII = *Subtarget.getInstrInfo();
  const DebugLoc &dl = MI.getDebugLoc();

  // Number of bits to shift. The direction is carried by Opc, not by the sign
  // of this value: insertMultibyteShift treats it as a non-negative bit count.
  const int64_t ShiftAmt = MI.getOperand(4).getImm();

  // Map the pseudo opcode onto the generic shift kind. The default arm makes
  // sure Opc is initialized on every path; any opcode other than the three
  // wide-shift pseudos is a caller bug and trips the assertion.
  ISD::NodeType Opc;
  switch (MI.getOpcode()) {
  case AVR::Lsl32:
    Opc = ISD::SHL;
    break;
  case AVR::Lsr32:
    Opc = ISD::SRL;
    break;
  default:
    assert(MI.getOpcode() == AVR::Asr32 && "Unexpected wide shift opcode!");
    Opc = ISD::SRA;
    break;
  }

  // Read the input registers, with the most significant register at index 0.
  std::array<std::pair<Register, int>, 4> Registers = {
      std::pair(MI.getOperand(3).getReg(), AVR::sub_hi),
      std::pair(MI.getOperand(3).getReg(), AVR::sub_lo),
      std::pair(MI.getOperand(2).getReg(), AVR::sub_hi),
      std::pair(MI.getOperand(2).getReg(), AVR::sub_lo),
  };

  // Do the shift. The registers are modified in-place.
  insertMultibyteShift(MI, BB, Registers, Opc, ShiftAmt);

  // Combine the 8-bit registers into 16-bit register pairs.
  // This is done either from LSB to MSB or from MSB to LSB, depending on the
  // shift. It's an optimization so that the register allocator will use the
  // fewest movs possible (which order we use isn't a correctness issue, just an
  // optimization issue).
  //   - lsl prefers starting from the most significant byte (2nd case).
  //   - lshr prefers starting from the least significant byte (1st case).
  //   - for ashr it depends on the number of shifted bytes.
  // Some shift operations still don't get the most optimal mov sequences even
  // with this distinction. TODO: figure out why and try to fix it (but we're
  // already equal to or faster than avr-gcc in all cases except ashr 8).
  if (Opc != ISD::SHL &&
      (Opc != ISD::SRA || (ShiftAmt < 16 || ShiftAmt >= 22))) {
    // Use the resulting registers starting with the least significant byte.
    BuildMI(*BB, MI, dl, TII.get(AVR::REG_SEQUENCE), MI.getOperand(0).getReg())
        .addReg(Registers[3].first, 0, Registers[3].second)
        .addImm(AVR::sub_lo)
        .addReg(Registers[2].first, 0, Registers[2].second)
        .addImm(AVR::sub_hi);
    BuildMI(*BB, MI, dl, TII.get(AVR::REG_SEQUENCE), MI.getOperand(1).getReg())
        .addReg(Registers[1].first, 0, Registers[1].second)
        .addImm(AVR::sub_lo)
        .addReg(Registers[0].first, 0, Registers[0].second)
        .addImm(AVR::sub_hi);
  } else {
    // Use the resulting registers starting with the most significant byte.
    BuildMI(*BB, MI, dl, TII.get(AVR::REG_SEQUENCE), MI.getOperand(1).getReg())
        .addReg(Registers[0].first, 0, Registers[0].second)
        .addImm(AVR::sub_hi)
        .addReg(Registers[1].first, 0, Registers[1].second)
        .addImm(AVR::sub_lo);
    BuildMI(*BB, MI, dl, TII.get(AVR::REG_SEQUENCE), MI.getOperand(0).getReg())
        .addReg(Registers[2].first, 0, Registers[2].second)
        .addImm(AVR::sub_hi)
        .addReg(Registers[3].first, 0, Registers[3].second)
        .addImm(AVR::sub_lo);
  }

  // Remove the pseudo instruction.
  MI.eraseFromParent();
  return BB;
}
static bool isCopyMulResult (MachineBasicBlock::iterator const &I) {
if (I->getOpcode () == AVR::COPY) {
Register SrcReg = I->getOperand (1 ).getReg ();
Expand Down
Expand Up
@@ -1901,6 +2311,10 @@ AVRTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
case AVR::Asr8:
case AVR::Asr16:
return insertShift (MI, MBB);
case AVR::Lsl32:
case AVR::Lsr32:
case AVR::Asr32:
return insertWideShift (MI, MBB);
case AVR::MULRdRr:
case AVR::MULSRdRr:
return insertMul (MI, MBB);
Expand Down