12 changes: 5 additions & 7 deletions llvm/lib/IR/Operator.cpp
@@ -87,7 +87,7 @@ Align GEPOperator::getMaxPreservedAlignment(const DataLayout &DL) const {
/// If the index isn't known, we take 1 because it is the index that will
/// give the worse alignment of the offset.
const uint64_t ElemCount = OpC ? OpC->getZExtValue() : 1;
Offset = DL.getTypeAllocSize(GTI.getIndexedType()) * ElemCount;
Offset = GTI.getSequentialElementStride(DL) * ElemCount;
}
Result = Align(MinAlign(Offset, Result.value()));
}
@@ -157,7 +157,7 @@ bool GEPOperator::accumulateConstantOffset(
continue;
}
if (!AccumulateOffset(ConstOffset->getValue(),
DL.getTypeAllocSize(GTI.getIndexedType())))
GTI.getSequentialElementStride(DL)))
return false;
continue;
}
@@ -170,8 +170,7 @@
if (!ExternalAnalysis(*V, AnalysisIndex))
return false;
UsedExternalAnalysis = true;
if (!AccumulateOffset(AnalysisIndex,
DL.getTypeAllocSize(GTI.getIndexedType())))
if (!AccumulateOffset(AnalysisIndex, GTI.getSequentialElementStride(DL)))
return false;
}
return true;
@@ -218,14 +217,13 @@ bool GEPOperator::collectOffset(
continue;
}
CollectConstantOffset(ConstOffset->getValue(),
DL.getTypeAllocSize(GTI.getIndexedType()));
GTI.getSequentialElementStride(DL));
continue;
}

if (STy || ScalableType)
return false;
APInt IndexedSize =
APInt(BitWidth, DL.getTypeAllocSize(GTI.getIndexedType()));
APInt IndexedSize = APInt(BitWidth, GTI.getSequentialElementStride(DL));
// Insert an initial offset of 0 for V iff none exists already, then
// increment the offset by IndexedSize.
if (!IndexedSize.isZero()) {
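The getMaxPreservedAlignment hunk above keeps the existing rule of treating an unknown GEP index as 1 when computing the preserved alignment of the offset. A standalone sketch of why 1 is the conservative choice follows; it is plain C++ with made-up helper names (powerOfTwoFactor, minAlign) and an arbitrary 24-byte stride, not the LLVM API:

```cpp
#include <cassert>
#include <cstdint>

// Largest power of two dividing X (X != 0).
static uint64_t powerOfTwoFactor(uint64_t X) { return X & -X; }

// MinAlign-style combine: the alignment preserved after adding Offset to a
// pointer that is A-aligned is their common power-of-two factor.
static uint64_t minAlign(uint64_t A, uint64_t B) { return (A | B) & -(A | B); }

int main() {
  const uint64_t Stride = 24; // hypothetical element stride in bytes
  // Assuming the unknown index is 1 gives the smallest power-of-two factor
  // that Stride * k can have for any k >= 1, i.e. the worst-case alignment.
  const uint64_t WorstCase = powerOfTwoFactor(Stride * 1);
  for (uint64_t K = 1; K <= 1000; ++K)
    assert(powerOfTwoFactor(Stride * K) >= WorstCase);
  // Example: a 16-byte-aligned base plus 24 bytes is only 8-byte aligned.
  assert(minAlign(16, Stride) == 8);
  return 0;
}
```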
2 changes: 1 addition & 1 deletion llvm/lib/IR/Value.cpp
@@ -1015,7 +1015,7 @@ getOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, const DataLayout &DL) {

// Otherwise, we have a sequential type like an array or fixed-length
// vector. Multiply the index by the ElementSize.
TypeSize Size = DL.getTypeAllocSize(GTI.getIndexedType());
TypeSize Size = GTI.getSequentialElementStride(DL);
if (Size.isScalable())
return std::nullopt;
Offset += Size.getFixedValue() * OpC->getSExtValue();
17 changes: 13 additions & 4 deletions llvm/lib/Object/WasmObjectFile.cpp
@@ -1351,6 +1351,7 @@ Error WasmObjectFile::parseExportSection(ReadContext &Ctx) {
break;
case wasm::WASM_EXTERNAL_TABLE:
Info.Kind = wasm::WASM_SYMBOL_TYPE_TABLE;
Info.ElementIndex = Ex.Index;
break;
default:
return make_error<GenericBinaryError>("unexpected export kind",
@@ -1667,10 +1668,18 @@ Expected<StringRef> WasmObjectFile::getSymbolName(DataRefImpl Symb) const {
Expected<uint64_t> WasmObjectFile::getSymbolAddress(DataRefImpl Symb) const {
auto &Sym = getWasmSymbol(Symb);
if (Sym.Info.Kind == wasm::WASM_SYMBOL_TYPE_FUNCTION &&
isDefinedFunctionIndex(Sym.Info.ElementIndex))
return getDefinedFunction(Sym.Info.ElementIndex).CodeSectionOffset;
else
return getSymbolValue(Symb);
isDefinedFunctionIndex(Sym.Info.ElementIndex)) {
// For object files, use the section offset. The linker relies on this.
// For linked files, use the file offset. This behavior matches the way
// browsers print stack traces and is useful for binary size analysis.
// (see https://webassembly.github.io/spec/web-api/index.html#conventions)
uint32_t Adjustment = isRelocatableObject() || isSharedObject()
? 0
: Sections[CodeSection].Offset;
return getDefinedFunction(Sym.Info.ElementIndex).CodeSectionOffset +
Adjustment;
}
return getSymbolValue(Symb);
}

uint64_t WasmObjectFile::getWasmSymbolValue(const WasmSymbol &Sym) const {
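A minimal sketch of the address computation that the getSymbolAddress hunk above introduces, using hypothetical offsets; the struct and the numbers are placeholders, only the adjustment logic mirrors the change:

```cpp
#include <cstdint>
#include <iostream>

struct FakeWasmFunction {
  uint64_t CodeSectionOffset; // offset of the function inside the code section
};

static uint64_t symbolAddress(const FakeWasmFunction &F,
                              bool IsRelocatableOrShared,
                              uint64_t CodeSectionFileOffset) {
  // Object files (and shared objects): keep the section-relative offset,
  // which is what the linker expects. Fully linked modules: add the code
  // section's file offset so the value matches browser stack traces.
  const uint64_t Adjustment = IsRelocatableOrShared ? 0 : CodeSectionFileOffset;
  return F.CodeSectionOffset + Adjustment;
}

int main() {
  const FakeWasmFunction F{0x120};
  std::cout << std::hex << std::showbase
            << symbolAddress(F, /*IsRelocatableOrShared=*/true, 0x400) << '\n'   // 0x120
            << symbolAddress(F, /*IsRelocatableOrShared=*/false, 0x400) << '\n'; // 0x520
}
```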
10 changes: 4 additions & 6 deletions llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -645,7 +645,7 @@ bool AArch64FastISel::computeAddress(const Value *Obj, Address &Addr, Type *Ty)
unsigned Idx = cast<ConstantInt>(Op)->getZExtValue();
TmpOffset += SL->getElementOffset(Idx);
} else {
uint64_t S = DL.getTypeAllocSize(GTI.getIndexedType());
uint64_t S = GTI.getSequentialElementStride(DL);
while (true) {
if (const ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
// Constant-offset addressing.
@@ -4978,15 +4978,13 @@ bool AArch64FastISel::selectGetElementPtr(const Instruction *I) {
if (Field)
TotalOffs += DL.getStructLayout(StTy)->getElementOffset(Field);
} else {
Type *Ty = GTI.getIndexedType();

// If this is a constant subscript, handle it quickly.
if (const auto *CI = dyn_cast<ConstantInt>(Idx)) {
if (CI->isZero())
continue;
// N = N + Offset
TotalOffs +=
DL.getTypeAllocSize(Ty) * cast<ConstantInt>(CI)->getSExtValue();
TotalOffs += GTI.getSequentialElementStride(DL) *
cast<ConstantInt>(CI)->getSExtValue();
continue;
}
if (TotalOffs) {
@@ -4997,7 +4995,7 @@ bool AArch64FastISel::selectGetElementPtr(const Instruction *I) {
}

// N = N + Idx * ElementSize;
uint64_t ElementSize = DL.getTypeAllocSize(Ty);
uint64_t ElementSize = GTI.getSequentialElementStride(DL);
unsigned IdxN = getRegForGEPIndex(Idx);
if (!IdxN)
return false;
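The selectGetElementPtr hunk above folds constant subscripts into a running byte offset and leaves the rest to the `N = N + Idx * ElementSize` path. A small standalone model of the constant part of that accumulation, with illustrative strides rather than a real DataLayout:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

struct GEPIndex {
  std::optional<int64_t> Const; // constant subscript, if known
  uint64_t Stride;              // byte stride contributed by this index
};

static int64_t totalConstOffset(const std::vector<GEPIndex> &Indices) {
  int64_t TotalOffs = 0;
  for (const GEPIndex &I : Indices)
    if (I.Const) // constant subscript: fold it immediately (zeros add nothing)
      TotalOffs += static_cast<int64_t>(I.Stride) * *I.Const;
  return TotalOffs; // non-constant indices are emitted as N = N + Idx * Stride
}

int main() {
  // e.g. getelementptr [8 x i32], ptr %p, i64 1, i64 3 with 4-byte i32:
  // 1 * 32 + 3 * 4 = 44 bytes of constant offset.
  assert(totalConstOffset({{1, 32}, {3, 4}}) == 44);
}
```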
34 changes: 28 additions & 6 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -282,6 +282,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
// Regardless of FP16 support, widen 16-bit elements to 32-bits.
.minScalar(0, s32)
.libcallFor({s32, s64});
getActionDefinitionsBuilder(G_FPOWI)
.scalarize(0)
.minScalar(0, s32)
.libcallFor({{s32, s32}, {s64, s32}});

getActionDefinitionsBuilder(G_INSERT)
.legalIf(all(typeInSet(0, {s32, s64, p0}),
@@ -761,17 +765,35 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.lowerIf(
all(typeInSet(0, {s8, s16, s32, s64, s128}), typeIs(2, p0)));

LegalityPredicate UseOutlineAtomics = [&ST](const LegalityQuery &Query) {
return ST.outlineAtomics() && !ST.hasLSE();
};

getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG)
.legalIf(all(typeInSet(0, {s32, s64}), typeIs(1, p0)))
.customIf([](const LegalityQuery &Query) {
return Query.Types[0].getSizeInBits() == 128;
.legalIf(all(typeInSet(0, {s32, s64}), typeIs(1, p0),
predNot(UseOutlineAtomics)))
.customIf(all(typeIs(0, s128), predNot(UseOutlineAtomics)))
.customIf([UseOutlineAtomics](const LegalityQuery &Query) {
return Query.Types[0].getSizeInBits() == 128 &&
!UseOutlineAtomics(Query);
})
.libcallIf(all(typeInSet(0, {s8, s16, s32, s64, s128}), typeIs(1, p0),
UseOutlineAtomics))
.clampScalar(0, s32, s64);

getActionDefinitionsBuilder({G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD,
G_ATOMICRMW_SUB, G_ATOMICRMW_AND, G_ATOMICRMW_OR,
G_ATOMICRMW_XOR})
.legalIf(all(typeInSet(0, {s32, s64}), typeIs(1, p0),
predNot(UseOutlineAtomics)))
.libcallIf(all(typeInSet(0, {s8, s16, s32, s64}), typeIs(1, p0),
UseOutlineAtomics))
.clampScalar(0, s32, s64);

// Do not outline these atomics operations, as per comment in
// AArch64ISelLowering.cpp's shouldExpandAtomicRMWInIR().
getActionDefinitionsBuilder(
{G_ATOMICRMW_XCHG, G_ATOMICRMW_ADD, G_ATOMICRMW_SUB, G_ATOMICRMW_AND,
G_ATOMICRMW_OR, G_ATOMICRMW_XOR, G_ATOMICRMW_MIN, G_ATOMICRMW_MAX,
G_ATOMICRMW_UMIN, G_ATOMICRMW_UMAX})
{G_ATOMICRMW_MIN, G_ATOMICRMW_MAX, G_ATOMICRMW_UMIN, G_ATOMICRMW_UMAX})
.legalIf(all(typeInSet(0, {s32, s64}), typeIs(1, p0)))
.clampScalar(0, s32, s64);

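The new UseOutlineAtomics predicate routes 8/16/32/64-bit atomic RMW and cmpxchg (plus 128-bit cmpxchg) to libcalls when the subtarget enables outline atomics but lacks LSE. In user-level C++ terms, code like the sketch below is what takes the libcall path under those settings; the -moutline-atomics flag and helper names such as __aarch64_ldadd4_relax / __aarch64_cas4_relax are the conventional outline-atomics ABI and are assumptions on my part, not something this patch defines:

```cpp
#include <atomic>

// Built for e.g. -march=armv8-a (no LSE) with outline atomics enabled, the
// GlobalISel legalizer rules above should lower both functions to libcalls
// instead of LL/SC loops or LSE instructions.
int fetch_add_relaxed(std::atomic<int> &A) {
  // G_ATOMICRMW_ADD, 32-bit: legal when LSE is available, libcall otherwise.
  return A.fetch_add(1, std::memory_order_relaxed);
}

bool cas_acq_rel(std::atomic<int> &A, int &Expected) {
  // G_ATOMIC_CMPXCHG, 32-bit: same split between legal and libcall.
  return A.compare_exchange_strong(Expected, 42, std::memory_order_acq_rel,
                                   std::memory_order_acquire);
}
```

The min/max/umin/umax RMW opcodes deliberately keep their previous rules, matching the comment that points at shouldExpandAtomicRMWInIR().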
10 changes: 9 additions & 1 deletion llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp
@@ -646,7 +646,15 @@ void MetadataStreamerMsgPackV5::emitHiddenKernelArgs(
Offset += 8; // Skipped.
}

Offset += 72; // Reserved.
// Emit argument for hidden dynamic lds size
if (MFI.isDynamicLDSUsed()) {
emitKernelArg(DL, Int32Ty, Align(4), "hidden_dynamic_lds_size", Offset,
Args);
} else {
Offset += 4; // skipped
}

Offset += 68; // Reserved.

// hidden_private_base and hidden_shared_base are only when the subtarget has
// ApertureRegs.
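A small sketch of the offset bookkeeping in the hunk above: whether or not the kernel uses dynamic LDS, the hidden-argument area advances by the same 72 bytes (a 4-byte size slot plus 68 reserved bytes), so the offsets of the later hidden arguments do not move. Plain C++, illustrative only:

```cpp
#include <cassert>
#include <cstdint>

static uint32_t advanceHiddenArgs(uint32_t Offset, bool UsesDynamicLDS) {
  if (UsesDynamicLDS)
    Offset += 4; // emitted as hidden_dynamic_lds_size (i32, 4-byte aligned)
  else
    Offset += 4; // the slot is skipped, but its 4 bytes stay reserved
  return Offset + 68; // remainder of what used to be one 72-byte reserved block
}

int main() {
  assert(advanceHiddenArgs(256, true) == advanceHiddenArgs(256, false));
  assert(advanceHiddenArgs(256, true) == 256 + 72);
}
```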
20 changes: 5 additions & 15 deletions llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -317,26 +317,16 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
}
}

bool AMDGPUDAGToDAGISel::isInlineImmediate(const SDNode *N,
bool Negated) const {
bool AMDGPUDAGToDAGISel::isInlineImmediate(const SDNode *N) const {
if (N->isUndef())
return true;

const SIInstrInfo *TII = Subtarget->getInstrInfo();
if (Negated) {
if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(N))
return TII->isInlineConstant(-C->getAPIntValue());
if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(N))
return TII->isInlineConstant(C->getAPIntValue());

if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N))
return TII->isInlineConstant(-C->getValueAPF().bitcastToAPInt());

} else {
if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(N))
return TII->isInlineConstant(C->getAPIntValue());

if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N))
return TII->isInlineConstant(C->getValueAPF().bitcastToAPInt());
}
if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N))
return TII->isInlineConstant(C->getValueAPF().bitcastToAPInt());

return false;
}
14 changes: 3 additions & 11 deletions llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
@@ -50,25 +50,20 @@ static inline bool getConstantValue(SDValue N, uint32_t &Out) {
}

// TODO: Handle undef as zero
static inline SDNode *packConstantV2I16(const SDNode *N, SelectionDAG &DAG,
bool Negate = false) {
static inline SDNode *packConstantV2I16(const SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::BUILD_VECTOR && N->getNumOperands() == 2);
uint32_t LHSVal, RHSVal;
if (getConstantValue(N->getOperand(0), LHSVal) &&
getConstantValue(N->getOperand(1), RHSVal)) {
SDLoc SL(N);
uint32_t K = Negate ? (-LHSVal & 0xffff) | (-RHSVal << 16)
: (LHSVal & 0xffff) | (RHSVal << 16);
uint32_t K = (LHSVal & 0xffff) | (RHSVal << 16);
return DAG.getMachineNode(AMDGPU::S_MOV_B32, SL, N->getValueType(0),
DAG.getTargetConstant(K, SL, MVT::i32));
}

return nullptr;
}

static inline SDNode *packNegConstantV2I16(const SDNode *N, SelectionDAG &DAG) {
return packConstantV2I16(N, DAG, true);
}
} // namespace

/// AMDGPU specific code to select AMDGPU machine instructions for
@@ -110,10 +105,7 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {

private:
std::pair<SDValue, SDValue> foldFrameIndex(SDValue N) const;
bool isInlineImmediate(const SDNode *N, bool Negated = false) const;
bool isNegInlineImmediate(const SDNode *N) const {
return isInlineImmediate(N, true);
}
bool isInlineImmediate(const SDNode *N) const;

bool isInlineImmediate16(int64_t Imm) const {
return AMDGPU::isInlinableLiteral16(Imm, Subtarget->hasInv2PiInlineImm());
39 changes: 30 additions & 9 deletions llvm/lib/Target/AMDGPU/AMDGPUMachineFunction.cpp
@@ -19,6 +19,26 @@

using namespace llvm;

static const GlobalVariable *
getKernelDynLDSGlobalFromFunction(const Function &F) {
const Module *M = F.getParent();
SmallString<64> KernelDynLDSName("llvm.amdgcn.");
KernelDynLDSName += F.getName();
KernelDynLDSName += ".dynlds";
return M->getNamedGlobal(KernelDynLDSName);
}

static bool hasLDSKernelArgument(const Function &F) {
for (const Argument &Arg : F.args()) {
Type *ArgTy = Arg.getType();
if (auto PtrTy = dyn_cast<PointerType>(ArgTy)) {
if (PtrTy->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS)
return true;
}
}
return false;
}

AMDGPUMachineFunction::AMDGPUMachineFunction(const Function &F,
const AMDGPUSubtarget &ST)
: IsEntryFunction(AMDGPU::isEntryFunctionCC(F.getCallingConv())),
@@ -65,6 +85,10 @@ AMDGPUMachineFunction::AMDGPUMachineFunction(const Function &F,
Attribute NSZAttr = F.getFnAttribute("no-signed-zeros-fp-math");
NoSignedZerosFPMath =
NSZAttr.isStringAttribute() && NSZAttr.getValueAsString() == "true";

const GlobalVariable *DynLdsGlobal = getKernelDynLDSGlobalFromFunction(F);
if (DynLdsGlobal || hasLDSKernelArgument(F))
UsesDynamicLDS = true;
}

unsigned AMDGPUMachineFunction::allocateLDSGlobal(const DataLayout &DL,
@@ -139,15 +163,6 @@ unsigned AMDGPUMachineFunction::allocateLDSGlobal(const DataLayout &DL,
return Offset;
}

static const GlobalVariable *
getKernelDynLDSGlobalFromFunction(const Function &F) {
const Module *M = F.getParent();
std::string KernelDynLDSName = "llvm.amdgcn.";
KernelDynLDSName += F.getName();
KernelDynLDSName += ".dynlds";
return M->getNamedGlobal(KernelDynLDSName);
}

std::optional<uint32_t>
AMDGPUMachineFunction::getLDSKernelIdMetadata(const Function &F) {
// TODO: Would be more consistent with the abs symbols to use a range
@@ -210,3 +225,9 @@ void AMDGPUMachineFunction::setDynLDSAlign(const Function &F,
}
}
}

void AMDGPUMachineFunction::setUsesDynamicLDS(bool DynLDS) {
UsesDynamicLDS = DynLDS;
}

bool AMDGPUMachineFunction::isDynamicLDSUsed() const { return UsesDynamicLDS; }
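The two triggers added above are a module-level global named `llvm.amdgcn.<kernel>.dynlds` and any kernel argument that is a pointer into LDS (address space 3). A plain C++ sketch of that detection with stand-in types instead of LLVM's Function/GlobalVariable:

```cpp
#include <cassert>
#include <string>
#include <vector>

constexpr unsigned LocalAddressSpace = 3; // AMDGPUAS::LOCAL_ADDRESS

static std::string kernelDynLDSName(const std::string &KernelName) {
  return "llvm.amdgcn." + KernelName + ".dynlds";
}

struct FakeArg {
  bool IsPointer;
  unsigned AddrSpace;
};

static bool hasLDSKernelArgument(const std::vector<FakeArg> &Args) {
  for (const FakeArg &A : Args)
    if (A.IsPointer && A.AddrSpace == LocalAddressSpace)
      return true;
  return false;
}

int main() {
  assert(kernelDynLDSName("foo") == "llvm.amdgcn.foo.dynlds");
  assert(hasLDSKernelArgument({{true, 1}, {true, 3}}));   // has an LDS pointer arg
  assert(!hasLDSKernelArgument({{true, 1}, {false, 0}})); // does not
}
```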
7 changes: 7 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUMachineFunction.h
@@ -46,6 +46,9 @@ class AMDGPUMachineFunction : public MachineFunctionInfo {
/// stages.
Align DynLDSAlign;

// Flag to check dynamic LDS usage by kernel.
bool UsesDynamicLDS = false;

// Kernels + shaders. i.e. functions called by the hardware and not called
// by other functions.
bool IsEntryFunction = false;
@@ -119,6 +122,10 @@ class AMDGPUMachineFunction : public MachineFunctionInfo {
Align getDynLDSAlign() const { return DynLDSAlign; }

void setDynLDSAlign(const Function &F, const GlobalVariable &GV);

void setUsesDynamicLDS(bool DynLDS);

bool isDynamicLDSUsed() const;
};

}
21 changes: 13 additions & 8 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
@@ -1865,6 +1865,9 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
case AMDGPU::OPERAND_REG_IMM_V2FP32:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT32:
case AMDGPU::OPERAND_REG_IMM_V2INT32:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_KIMM32:
case AMDGPU::OPERAND_INLINE_SPLIT_BARRIER_INT32:
return &APFloat::IEEEsingle();
@@ -1879,13 +1882,10 @@
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_KIMM16:
return &APFloat::IEEEhalf();
@@ -2033,9 +2033,14 @@ bool AMDGPUOperand::isLiteralImm(MVT type) const {
// We allow fp literals with f16x2 operands assuming that the specified
// literal goes into the lower half and the upper half is zero. We also
// require that the literal may be losslessly converted to f16.
MVT ExpectedType = (type == MVT::v2f16)? MVT::f16 :
(type == MVT::v2i16)? MVT::i16 :
(type == MVT::v2f32)? MVT::f32 : type;
//
// For i16x2 operands, we assume that the specified literal is encoded as a
// single-precision float. This is pretty odd, but it matches SP3 and what
// happens in hardware.
MVT ExpectedType = (type == MVT::v2f16) ? MVT::f16
: (type == MVT::v2i16) ? MVT::f32
: (type == MVT::v2f32) ? MVT::f32
: type;

APFloat FPLiteral(APFloat::IEEEdouble(), APInt(64, Imm.Val));
return canLosslesslyConvertToFPType(FPLiteral, ExpectedType);
@@ -3401,12 +3406,12 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2INT16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2INT16 ||
OperandType == AMDGPU::OPERAND_REG_IMM_V2INT16)
return AMDGPU::isInlinableIntLiteralV216(Val);
return AMDGPU::isInlinableLiteralV2I16(Val);

if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2FP16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2FP16 ||
OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16)
return AMDGPU::isInlinableLiteralV216(Val, hasInv2PiInlineImm());
return AMDGPU::isInlinableLiteralV2F16(Val);

return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
}
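A worked example of the isLiteralImm comment above: the literal 1.0 on a v2f16 operand is checked against half precision, while on a v2i16 operand it is checked as a single-precision float, matching SP3 and the hardware. Plain C++ that just prints the two bit patterns:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint32_t f32Bits(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof(Bits));
  return Bits;
}

int main() {
  // IEEE half 1.0: what a v2f16 operand expects in its low 16 bits.
  const unsigned Half1 = 0x3C00;
  // IEEE single 1.0: what a v2i16 operand expects per the comment above.
  const unsigned Single1 = f32Bits(1.0f);
  std::printf("f16 1.0 = 0x%04X, f32 1.0 = 0x%08X\n", Half1, Single1);
  return 0; // prints f16 1.0 = 0x3C00, f32 1.0 = 0x3F800000
}
```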
6 changes: 6 additions & 0 deletions llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp
@@ -182,6 +182,9 @@ static DecodeStatus decodeSplitBarrier(MCInst &Inst, unsigned Val,
DECODE_SrcOp(decodeOperand_##RegClass##_Imm##ImmWidth, 9, OpWidth, Imm, \
false, ImmWidth)

#define DECODE_OPERAND_SRC_REG_OR_IMM_9_TYPED(Name, OpWidth, ImmWidth) \
DECODE_SrcOp(decodeOperand_##Name, 9, OpWidth, Imm, false, ImmWidth)

// Decoder for Src(9-bit encoding) AGPR or immediate. Set Imm{9} to 1 (set acc)
// and decode using 'enum10' from decodeSrcOp.
#define DECODE_OPERAND_SRC_REG_OR_IMM_A9(RegClass, OpWidth, ImmWidth) \
@@ -262,6 +265,9 @@ DECODE_OPERAND_SRC_REG_OR_IMM_9(VReg_256, OPW256, 64)
DECODE_OPERAND_SRC_REG_OR_IMM_9(VReg_512, OPW512, 32)
DECODE_OPERAND_SRC_REG_OR_IMM_9(VReg_1024, OPW1024, 32)

DECODE_OPERAND_SRC_REG_OR_IMM_9_TYPED(VS_32_ImmV2I16, OPW32, 32)
DECODE_OPERAND_SRC_REG_OR_IMM_9_TYPED(VS_32_ImmV2F16, OPW32, 16)

DECODE_OPERAND_SRC_REG_OR_IMM_A9(AReg_64, OPW64, 64)
DECODE_OPERAND_SRC_REG_OR_IMM_A9(AReg_128, OPW128, 32)
DECODE_OPERAND_SRC_REG_OR_IMM_A9(AReg_256, OPW256, 64)
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/GCNSubtarget.h
@@ -1096,7 +1096,7 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
bool hasDstSelForwardingHazard() const { return GFX940Insts; }

// Cannot use op_sel with v_dot instructions.
bool hasDOTOpSelHazard() const { return GFX940Insts; }
bool hasDOTOpSelHazard() const { return GFX940Insts || GFX11Insts; }

// Does not have HW interlocs for VALU writing and then reading SGPRs.
bool hasVDecCoExecHazard() const {
125 changes: 78 additions & 47 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
@@ -460,56 +460,84 @@ void AMDGPUInstPrinter::printImmediateInt16(uint32_t Imm,
}
}

void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int16_t SImm = static_cast<int16_t>(Imm);
if (isInlinableIntLiteral(SImm)) {
O << SImm;
return;
}

// This must accept a 32-bit immediate value to correctly handle packed 16-bit
// operations.
static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O) {
if (Imm == 0x3C00)
O<< "1.0";
O << "1.0";
else if (Imm == 0xBC00)
O<< "-1.0";
O << "-1.0";
else if (Imm == 0x3800)
O<< "0.5";
O << "0.5";
else if (Imm == 0xB800)
O<< "-0.5";
O << "-0.5";
else if (Imm == 0x4000)
O<< "2.0";
O << "2.0";
else if (Imm == 0xC000)
O<< "-2.0";
O << "-2.0";
else if (Imm == 0x4400)
O<< "4.0";
O << "4.0";
else if (Imm == 0xC400)
O<< "-4.0";
else if (Imm == 0x3118 &&
STI.hasFeature(AMDGPU::FeatureInv2PiInlineImm)) {
O << "-4.0";
else if (Imm == 0x3118 && STI.hasFeature(AMDGPU::FeatureInv2PiInlineImm))
O << "0.15915494";
} else {
uint64_t Imm16 = static_cast<uint16_t>(Imm);
O << formatHex(Imm16);
}
}
else
return false;

void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
uint16_t Lo16 = static_cast<uint16_t>(Imm);
printImmediate16(Lo16, STI, O);
return true;
}

void AMDGPUInstPrinter::printImmediate32(uint32_t Imm,
void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int16_t SImm = static_cast<int16_t>(Imm);
if (isInlinableIntLiteral(SImm)) {
O << SImm;
return;
}

uint16_t HImm = static_cast<uint16_t>(Imm);
if (printImmediateFloat16(HImm, STI, O))
return;

uint64_t Imm16 = static_cast<uint16_t>(Imm);
O << formatHex(Imm16);
}

void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int32_t SImm = static_cast<int32_t>(Imm);
if (SImm >= -16 && SImm <= 64) {
if (isInlinableIntLiteral(SImm)) {
O << SImm;
return;
}

switch (OpType) {
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
if (printImmediateFloat32(Imm, STI, O))
return;
break;
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
if (isUInt<16>(Imm) &&
printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
default:
llvm_unreachable("bad operand type");
}

O << formatHex(static_cast<uint64_t>(Imm));
}

bool AMDGPUInstPrinter::printImmediateFloat32(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
if (Imm == llvm::bit_cast<uint32_t>(0.0f))
O << "0.0";
else if (Imm == llvm::bit_cast<uint32_t>(1.0f))
@@ -532,7 +560,24 @@ void AMDGPUInstPrinter::printImmediate32(uint32_t Imm,
STI.hasFeature(AMDGPU::FeatureInv2PiInlineImm))
O << "0.15915494";
else
O << formatHex(static_cast<uint64_t>(Imm));
return false;

return true;
}

void AMDGPUInstPrinter::printImmediate32(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int32_t SImm = static_cast<int32_t>(Imm);
if (isInlinableIntLiteral(SImm)) {
O << SImm;
return;
}

if (printImmediateFloat32(Imm, STI, O))
return;

O << formatHex(static_cast<uint64_t>(Imm));
}

void AMDGPUInstPrinter::printImmediate64(uint64_t Imm,
@@ -755,25 +800,11 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
break;
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
if (!isUInt<16>(Op.getImm()) &&
STI.hasFeature(AMDGPU::FeatureVOP3Literal)) {
printImmediate32(Op.getImm(), STI, O);
break;
}

// Deal with 16-bit FP inline immediates not working.
if (OpTy == AMDGPU::OPERAND_REG_IMM_V2FP16) {
printImmediate16(static_cast<uint16_t>(Op.getImm()), STI, O);
break;
}
[[fallthrough]];
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
printImmediateInt16(static_cast<uint16_t>(Op.getImm()), STI, O);
break;
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
printImmediateV216(Op.getImm(), STI, O);
printImmediateV216(Op.getImm(), OpTy, STI, O);
break;
case MCOI::OPERAND_UNKNOWN:
case MCOI::OPERAND_PCREL:
6 changes: 4 additions & 2 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
@@ -88,8 +88,10 @@ class AMDGPUInstPrinter : public MCInstPrinter {
raw_ostream &O);
void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateV216(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateV216(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI, raw_ostream &O);
bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediate32(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediate64(uint64_t Imm, const MCSubtargetInfo &STI,
19 changes: 6 additions & 13 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
@@ -284,22 +284,15 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
// which does not have f16 support?
return getLit16Encoding(static_cast<uint16_t>(Imm), STI);
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_IMM_V2FP16: {
if (!isUInt<16>(Imm) && STI.hasFeature(AMDGPU::FeatureVOP3Literal))
return getLit32Encoding(static_cast<uint32_t>(Imm), STI);
if (OpInfo.OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16)
return getLit16Encoding(static_cast<uint16_t>(Imm), STI);
[[fallthrough]];
}
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm))
.value_or(255);
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
uint16_t Lo16 = static_cast<uint16_t>(Imm);
uint32_t Encoding = getLit16Encoding(Lo16, STI);
return Encoding;
}
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm))
.value_or(255);
case AMDGPU::OPERAND_KIMM32:
case AMDGPU::OPERAND_KIMM16:
return MO.getImm();
148 changes: 119 additions & 29 deletions llvm/lib/Target/AMDGPU/SIFoldOperands.cpp
@@ -208,9 +208,7 @@ bool SIFoldOperands::canUseImmWithOpSel(FoldCandidate &Fold) const {
assert(Old.isReg() && Fold.isImm());

if (!(TSFlags & SIInstrFlags::IsPacked) || (TSFlags & SIInstrFlags::IsMAI) ||
(ST->hasDOTOpSelHazard() && (TSFlags & SIInstrFlags::IsDOT)) ||
isUInt<16>(Fold.ImmToFold) ||
!AMDGPU::isFoldableLiteralV216(Fold.ImmToFold, ST->hasInv2PiInlineImm()))
(ST->hasDOTOpSelHazard() && (TSFlags & SIInstrFlags::IsDOT)))
return false;

unsigned Opcode = MI->getOpcode();
@@ -234,51 +232,143 @@ bool SIFoldOperands::tryFoldImmWithOpSel(FoldCandidate &Fold) const {
MachineOperand &Old = MI->getOperand(Fold.UseOpNo);
unsigned Opcode = MI->getOpcode();
int OpNo = MI->getOperandNo(&Old);
uint8_t OpType = TII->get(Opcode).operands()[OpNo].OperandType;

// If the literal can be inlined as-is, apply it and short-circuit the
// tests below. The main motivation for this is to avoid unintuitive
// uses of opsel.
if (AMDGPU::isInlinableLiteralV216(Fold.ImmToFold, OpType)) {
Old.ChangeToImmediate(Fold.ImmToFold);
return true;
}

// Set op_sel/op_sel_hi on this operand or bail out if op_sel is
// already set.
// Refer to op_sel/op_sel_hi and check if we can change the immediate and
// op_sel in a way that allows an inline constant.
int ModIdx = -1;
if (OpNo == AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::src0))
unsigned SrcIdx = ~0;
if (OpNo == AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::src0)) {
ModIdx = AMDGPU::OpName::src0_modifiers;
else if (OpNo == AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::src1))
SrcIdx = 0;
} else if (OpNo == AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::src1)) {
ModIdx = AMDGPU::OpName::src1_modifiers;
else if (OpNo == AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::src2))
SrcIdx = 1;
} else if (OpNo == AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::src2)) {
ModIdx = AMDGPU::OpName::src2_modifiers;
SrcIdx = 2;
}
assert(ModIdx != -1);
ModIdx = AMDGPU::getNamedOperandIdx(Opcode, ModIdx);
MachineOperand &Mod = MI->getOperand(ModIdx);
unsigned Val = Mod.getImm();
if ((Val & SISrcMods::OP_SEL_0) || !(Val & SISrcMods::OP_SEL_1))
unsigned ModVal = Mod.getImm();

uint16_t ImmLo = static_cast<uint16_t>(
Fold.ImmToFold >> (ModVal & SISrcMods::OP_SEL_0 ? 16 : 0));
uint16_t ImmHi = static_cast<uint16_t>(
Fold.ImmToFold >> (ModVal & SISrcMods::OP_SEL_1 ? 16 : 0));
uint32_t Imm = (static_cast<uint32_t>(ImmHi) << 16) | ImmLo;
unsigned NewModVal = ModVal & ~(SISrcMods::OP_SEL_0 | SISrcMods::OP_SEL_1);

// Helper function that attempts to inline the given value with a newly
// chosen opsel pattern.
auto tryFoldToInline = [&](uint32_t Imm) -> bool {
if (AMDGPU::isInlinableLiteralV216(Imm, OpType)) {
Mod.setImm(NewModVal | SISrcMods::OP_SEL_1);
Old.ChangeToImmediate(Imm);
return true;
}

// Try to shuffle the halves around and leverage opsel to get an inline
// constant.
uint16_t Lo = static_cast<uint16_t>(Imm);
uint16_t Hi = static_cast<uint16_t>(Imm >> 16);
if (Lo == Hi) {
if (AMDGPU::isInlinableLiteralV216(Lo, OpType)) {
Mod.setImm(NewModVal);
Old.ChangeToImmediate(Lo);
return true;
}

if (static_cast<int16_t>(Lo) < 0) {
int32_t SExt = static_cast<int16_t>(Lo);
if (AMDGPU::isInlinableLiteralV216(SExt, OpType)) {
Mod.setImm(NewModVal);
Old.ChangeToImmediate(SExt);
return true;
}
}

// This check is only useful for integer instructions
if (OpType == AMDGPU::OPERAND_REG_IMM_V2INT16 ||
OpType == AMDGPU::OPERAND_REG_INLINE_AC_V2INT16) {
if (AMDGPU::isInlinableLiteralV216(Lo << 16, OpType)) {
Mod.setImm(NewModVal | SISrcMods::OP_SEL_0 | SISrcMods::OP_SEL_1);
Old.ChangeToImmediate(static_cast<uint32_t>(Lo) << 16);
return true;
}
}
} else {
uint32_t Swapped = (static_cast<uint32_t>(Lo) << 16) | Hi;
if (AMDGPU::isInlinableLiteralV216(Swapped, OpType)) {
Mod.setImm(NewModVal | SISrcMods::OP_SEL_0);
Old.ChangeToImmediate(Swapped);
return true;
}
}

return false;
};

// Only apply the following transformation if that operand requires
// a packed immediate.
// If upper part is all zero we do not need op_sel_hi.
if (!(Fold.ImmToFold & 0xffff)) {
MachineOperand New =
MachineOperand::CreateImm((Fold.ImmToFold >> 16) & 0xffff);
if (!TII->isOperandLegal(*MI, OpNo, &New))
return false;
Mod.setImm(Mod.getImm() | SISrcMods::OP_SEL_0);
Mod.setImm(Mod.getImm() & ~SISrcMods::OP_SEL_1);
Old.ChangeToImmediate((Fold.ImmToFold >> 16) & 0xffff);
if (tryFoldToInline(Imm))
return true;

// Replace integer addition by subtraction and vice versa if it allows
// folding the immediate to an inline constant.
//
// We should only ever get here for SrcIdx == 1 due to canonicalization
// earlier in the pipeline, but we double-check here to be safe / fully
// general.
bool IsUAdd = Opcode == AMDGPU::V_PK_ADD_U16;
bool IsUSub = Opcode == AMDGPU::V_PK_SUB_U16;
if (SrcIdx == 1 && (IsUAdd || IsUSub)) {
unsigned ClampIdx =
AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::clamp);
bool Clamp = MI->getOperand(ClampIdx).getImm() != 0;

if (!Clamp) {
uint16_t NegLo = -static_cast<uint16_t>(Imm);
uint16_t NegHi = -static_cast<uint16_t>(Imm >> 16);
uint32_t NegImm = (static_cast<uint32_t>(NegHi) << 16) | NegLo;

if (tryFoldToInline(NegImm)) {
unsigned NegOpcode =
IsUAdd ? AMDGPU::V_PK_SUB_U16 : AMDGPU::V_PK_ADD_U16;
MI->setDesc(TII->get(NegOpcode));
return true;
}
}
}
MachineOperand New = MachineOperand::CreateImm(Fold.ImmToFold & 0xffff);
if (!TII->isOperandLegal(*MI, OpNo, &New))
return false;
Mod.setImm(Mod.getImm() & ~SISrcMods::OP_SEL_1);
Old.ChangeToImmediate(Fold.ImmToFold & 0xffff);
return true;

return false;
}

bool SIFoldOperands::updateOperand(FoldCandidate &Fold) const {
MachineInstr *MI = Fold.UseMI;
MachineOperand &Old = MI->getOperand(Fold.UseOpNo);
assert(Old.isReg());

if (Fold.isImm() && canUseImmWithOpSel(Fold))
return tryFoldImmWithOpSel(Fold);
if (Fold.isImm() && canUseImmWithOpSel(Fold)) {
if (tryFoldImmWithOpSel(Fold))
return true;

// We can't represent the candidate as an inline constant. Try as a literal
// with the original opsel, checking constant bus limitations.
MachineOperand New = MachineOperand::CreateImm(Fold.ImmToFold);
int OpNo = MI->getOperandNo(&Old);
if (!TII->isOperandLegal(*MI, OpNo, &New))
return false;
Old.ChangeToImmediate(Fold.ImmToFold);
return true;
}

if ((Fold.isImm() || Fold.isFI() || Fold.isGlobal()) && Fold.needsShrink()) {
MachineBasicBlock *MBB = MI->getParent();
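A simplified standalone model of the half shuffling that tryFoldImmWithOpSel now performs. It covers only the sign-extended integer inline range (-16..64); the FP16/FP32 inline patterns, the clamp check and the actual V_PK_ADD_U16 / V_PK_SUB_U16 rewrite are left out, and all values are illustrative:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

struct Fold {
  uint32_t Imm; // immediate to inline
  bool OpSelLo; // op_sel:    low result lane reads the high half of the source
  bool OpSelHi; // op_sel_hi: high result lane reads the high half of the source
};

static bool isInlineInt(uint32_t V) { // -16..64, sign-extended encoding
  int32_t S = static_cast<int32_t>(V);
  return S >= -16 && S <= 64;
}

static std::optional<Fold> shuffleToInline(uint32_t Imm) {
  uint16_t Lo = Imm, Hi = Imm >> 16;
  if (isInlineInt(Imm))
    return Fold{Imm, false, true};                // usable as-is
  if (Lo == Hi) {
    if (isInlineInt(Lo))
      return Fold{Lo, false, false};              // both lanes read the low half
    int32_t SExt = static_cast<int16_t>(Lo);
    if (isInlineInt(static_cast<uint32_t>(SExt)))
      return Fold{static_cast<uint32_t>(SExt), false, false};
  } else {
    uint32_t Swapped = (uint32_t(Lo) << 16) | Hi; // use op_sel to swap halves
    if (isInlineInt(Swapped))
      return Fold{Swapped, true, false};
  }
  return std::nullopt;
}

int main() {
  // {hi=64, lo=0}: 0x00400000 is not inlinable as-is, but swapping the halves
  // through op_sel leaves the inline constant 64.
  assert(shuffleToInline(0x00400000) && shuffleToInline(0x00400000)->Imm == 64);
  // v_pk_add_u16 with {-64,-64} (0xFFC0FFC0): nothing above works, but the
  // pass additionally negates both halves and flips add<->sub, and the
  // negated value {64,64} (0x00400040) then folds to the inline constant 64.
  assert(!shuffleToInline(0xFFC0FFC0));
  assert(shuffleToInline(0x00400040) && shuffleToInline(0x00400040)->Imm == 64);
}
```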
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -6890,6 +6890,7 @@ SDValue SITargetLowering::LowerGlobalAddress(AMDGPUMachineFunction *MFI,
// Adjust alignment for that dynamic shared memory array.
Function &F = DAG.getMachineFunction().getFunction();
MFI->setDynLDSAlign(F, *cast<GlobalVariable>(GV));
MFI->setUsesDynamicLDS(true);
return SDValue(
DAG.getMachineNode(AMDGPU::GET_GROUPSTATICSIZE, DL, PtrVT), 0);
}
12 changes: 6 additions & 6 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
@@ -4153,15 +4153,15 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO,
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
return (isInt<16>(Imm) || isUInt<16>(Imm)) &&
AMDGPU::isInlinableIntLiteral((int16_t)Imm);
return AMDGPU::isInlinableLiteralV2I16(Imm);
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
return AMDGPU::isInlinableLiteralV2F16(Imm);
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
case AMDGPU::OPERAND_REG_INLINE_AC_FP16: {
if (isInt<16>(Imm) || isUInt<16>(Imm)) {
// A few special case instructions have 16-bit operands on subtargets
// where 16-bit instructions are not legal.
17 changes: 0 additions & 17 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.td
@@ -860,23 +860,6 @@ def ShiftAmt32Imm : ImmLeaf <i32, [{
return Imm < 32;
}]>;

def getNegV2I16Imm : SDNodeXForm<build_vector, [{
return SDValue(packNegConstantV2I16(N, *CurDAG), 0);
}]>;

def NegSubInlineConstV216 : PatLeaf<(build_vector), [{
assert(N->getNumOperands() == 2);
assert(N->getOperand(0).getValueType().getSizeInBits() == 16);
SDValue Src0 = N->getOperand(0);
SDValue Src1 = N->getOperand(1);
if (Src0 == Src1)
return isNegInlineImmediate(Src0.getNode());

return (isNullConstantOrUndef(Src0) && isNegInlineImmediate(Src1.getNode())) ||
(isNullConstantOrUndef(Src1) && isNegInlineImmediate(Src0.getNode()));
}], getNegV2I16Imm>;


def fp16_zeros_high_16bits : PatLeaf<(f16 VGPR_32:$src), [{
return fp16SrcZerosHighBits(N->getOpcode());
}]>;
4 changes: 2 additions & 2 deletions llvm/lib/Target/AMDGPU/SIRegisterInfo.td
@@ -1152,11 +1152,11 @@ class RegOrF32 <string RegisterClass, string OperandTypePrefix>

class RegOrV2B16 <string RegisterClass, string OperandTypePrefix>
: RegOrImmOperand <RegisterClass, OperandTypePrefix # "_V2INT16",
!subst("_v2b16", "V2B16", NAME), "_Imm16">;
!subst("_v2b16", "V2B16", NAME), "_ImmV2I16">;

class RegOrV2F16 <string RegisterClass, string OperandTypePrefix>
: RegOrImmOperand <RegisterClass, OperandTypePrefix # "_V2FP16",
!subst("_v2f16", "V2F16", NAME), "_Imm16">;
!subst("_v2f16", "V2F16", NAME), "_ImmV2F16">;

class RegOrF64 <string RegisterClass, string OperandTypePrefix>
: RegOrImmOperand <RegisterClass, OperandTypePrefix # "_FP64",
106 changes: 74 additions & 32 deletions llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp
@@ -2506,53 +2506,95 @@ bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi) {
Val == 0x3118; // 1/2pi
}

bool isInlinableLiteralV216(int32_t Literal, bool HasInv2Pi) {
assert(HasInv2Pi);

if (isInt<16>(Literal) || isUInt<16>(Literal)) {
int16_t Trunc = static_cast<int16_t>(Literal);
return AMDGPU::isInlinableLiteral16(Trunc, HasInv2Pi);
std::optional<unsigned> getInlineEncodingV216(bool IsFloat, uint32_t Literal) {
// Unfortunately, the Instruction Set Architecture Reference Guide is
// misleading about how the inline operands work for (packed) 16-bit
// instructions. In a nutshell, the actual HW behavior is:
//
// - integer encodings (-16 .. 64) are always produced as sign-extended
// 32-bit values
// - float encodings are produced as:
// - for F16 instructions: corresponding half-precision float values in
// the LSBs, 0 in the MSBs
// - for UI16 instructions: corresponding single-precision float value
int32_t Signed = static_cast<int32_t>(Literal);
if (Signed >= 0 && Signed <= 64)
return 128 + Signed;

if (Signed >= -16 && Signed <= -1)
return 192 + std::abs(Signed);

if (IsFloat) {
// clang-format off
switch (Literal) {
case 0x3800: return 240; // 0.5
case 0xB800: return 241; // -0.5
case 0x3C00: return 242; // 1.0
case 0xBC00: return 243; // -1.0
case 0x4000: return 244; // 2.0
case 0xC000: return 245; // -2.0
case 0x4400: return 246; // 4.0
case 0xC400: return 247; // -4.0
case 0x3118: return 248; // 1.0 / (2.0 * pi)
default: break;
}
// clang-format on
} else {
// clang-format off
switch (Literal) {
case 0x3F000000: return 240; // 0.5
case 0xBF000000: return 241; // -0.5
case 0x3F800000: return 242; // 1.0
case 0xBF800000: return 243; // -1.0
case 0x40000000: return 244; // 2.0
case 0xC0000000: return 245; // -2.0
case 0x40800000: return 246; // 4.0
case 0xC0800000: return 247; // -4.0
case 0x3E22F983: return 248; // 1.0 / (2.0 * pi)
default: break;
}
// clang-format on
}
if (!(Literal & 0xffff))
return AMDGPU::isInlinableLiteral16(Literal >> 16, HasInv2Pi);

int16_t Lo16 = static_cast<int16_t>(Literal);
int16_t Hi16 = static_cast<int16_t>(Literal >> 16);
return Lo16 == Hi16 && isInlinableLiteral16(Lo16, HasInv2Pi);
return {};
}

bool isInlinableIntLiteralV216(int32_t Literal) {
int16_t Lo16 = static_cast<int16_t>(Literal);
if (isInt<16>(Literal) || isUInt<16>(Literal))
return isInlinableIntLiteral(Lo16);
// Encoding of the literal as an inline constant for a V_PK_*_IU16 instruction
// or nullopt.
std::optional<unsigned> getInlineEncodingV2I16(uint32_t Literal) {
return getInlineEncodingV216(false, Literal);
}

int16_t Hi16 = static_cast<int16_t>(Literal >> 16);
if (!(Literal & 0xffff))
return isInlinableIntLiteral(Hi16);
return Lo16 == Hi16 && isInlinableIntLiteral(Lo16);
// Encoding of the literal as an inline constant for a V_PK_*_F16 instruction
// or nullopt.
std::optional<unsigned> getInlineEncodingV2F16(uint32_t Literal) {
return getInlineEncodingV216(true, Literal);
}

bool isInlinableLiteralV216(int32_t Literal, bool HasInv2Pi, uint8_t OpType) {
// Whether the given literal can be inlined for a V_PK_* instruction.
bool isInlinableLiteralV216(uint32_t Literal, uint8_t OpType) {
switch (OpType) {
case AMDGPU::OPERAND_REG_IMM_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
return getInlineEncodingV216(false, Literal).has_value();
case AMDGPU::OPERAND_REG_IMM_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
return isInlinableLiteralV216(Literal, HasInv2Pi);
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
return getInlineEncodingV216(true, Literal).has_value();
default:
return isInlinableIntLiteralV216(Literal);
llvm_unreachable("bad packed operand type");
}
}

bool isFoldableLiteralV216(int32_t Literal, bool HasInv2Pi) {
assert(HasInv2Pi);

int16_t Lo16 = static_cast<int16_t>(Literal);
if (isInt<16>(Literal) || isUInt<16>(Literal))
return true;
// Whether the given literal can be inlined for a V_PK_*_IU16 instruction.
bool isInlinableLiteralV2I16(uint32_t Literal) {
return getInlineEncodingV2I16(Literal).has_value();
}

int16_t Hi16 = static_cast<int16_t>(Literal >> 16);
if (!(Literal & 0xffff))
return true;
return Lo16 == Hi16;
// Whether the given literal can be inlined for a V_PK_*_F16 instruction.
bool isInlinableLiteralV2F16(uint32_t Literal) {
return getInlineEncodingV2F16(Literal).has_value();
}

bool isValid32BitLiteral(uint64_t Val, bool IsFP64) {
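Worked values for the encoding rules described in the comment above. The helper below re-implements only the shared integer path so the example stays self-contained; the FP16-versus-FP32 pattern difference is noted in comments rather than reproduced:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Integer inline constants for packed 16-bit operands: -16..64, produced by
// the hardware as sign-extended 32-bit values (encodings 128..192, 193..208).
static std::optional<unsigned> packedIntEncoding(uint32_t Literal) {
  int32_t S = static_cast<int32_t>(Literal);
  if (S >= 0 && S <= 64)
    return 128 + S;
  if (S >= -16 && S <= -1)
    return 192 - S; // 192 + |S|
  return std::nullopt;
}

int main() {
  assert(packedIntEncoding(64) == 192u);          // +64
  assert(packedIntEncoding(0xFFFFFFFFu) == 193u); // -1, sign-extended 32-bit
  // Float inline constants differ by operand kind: V_PK_*_F16 operands match
  // half-precision patterns (1.0 is 0x3C00 -> 242), while V_PK_*_IU16 operands
  // match single-precision patterns (1.0 is 0x3F800000 -> 242), so 0x3C00 is
  // not an inline constant for the integer instructions.
  assert(!packedIntEncoding(0x3C00));
}
```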
11 changes: 7 additions & 4 deletions llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
@@ -1291,16 +1291,19 @@ LLVM_READNONE
bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi);

LLVM_READNONE
bool isInlinableLiteralV216(int32_t Literal, bool HasInv2Pi);
std::optional<unsigned> getInlineEncodingV2I16(uint32_t Literal);

LLVM_READNONE
bool isInlinableIntLiteralV216(int32_t Literal);
std::optional<unsigned> getInlineEncodingV2F16(uint32_t Literal);

LLVM_READNONE
bool isInlinableLiteralV216(int32_t Literal, bool HasInv2Pi, uint8_t OpType);
bool isInlinableLiteralV216(uint32_t Literal, uint8_t OpType);

LLVM_READNONE
bool isFoldableLiteralV216(int32_t Literal, bool HasInv2Pi);
bool isInlinableLiteralV2I16(uint32_t Literal);

LLVM_READNONE
bool isInlinableLiteralV2F16(uint32_t Literal);

LLVM_READNONE
bool isValid32BitLiteral(uint64_t Val, bool IsFP64);
9 changes: 0 additions & 9 deletions llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -125,15 +125,6 @@ defm V_PK_LSHRREV_B16 : VOP3PInst<"v_pk_lshrrev_b16", VOP3P_Profile<VOP_V2I16_V2

let SubtargetPredicate = HasVOP3PInsts in {

// Undo sub x, c -> add x, -c canonicalization since c is more likely
// an inline immediate than -c.
// The constant will be emitted as a mov, and folded later.
// TODO: We could directly encode the immediate now
def : GCNPat<
(add (v2i16 (VOP3PMods v2i16:$src0, i32:$src0_modifiers)), NegSubInlineConstV216:$src1),
(V_PK_SUB_U16 $src0_modifiers, $src0, SRCMODS.OP_SEL_1, NegSubInlineConstV216:$src1)
>;

// Integer operations with clamp bit set.
class VOP3PSatPat<SDPatternOperator pat, Instruction inst> : GCNPat<
(pat (v2i16 (VOP3PMods v2i16:$src0, i32:$src0_modifiers)),
2 changes: 1 addition & 1 deletion llvm/lib/Target/ARM/ARMFastISel.cpp
@@ -747,7 +747,7 @@ bool ARMFastISel::ARMComputeAddress(const Value *Obj, Address &Addr) {
unsigned Idx = cast<ConstantInt>(Op)->getZExtValue();
TmpOffset += SL->getElementOffset(Idx);
} else {
uint64_t S = DL.getTypeAllocSize(GTI.getIndexedType());
uint64_t S = GTI.getSequentialElementStride(DL);
while (true) {
if (const ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
// Constant-offset addressing.
2 changes: 1 addition & 1 deletion llvm/lib/Target/Mips/MipsFastISel.cpp
@@ -492,7 +492,7 @@ bool MipsFastISel::computeAddress(const Value *Obj, Address &Addr) {
unsigned Idx = cast<ConstantInt>(Op)->getZExtValue();
TmpOffset += SL->getElementOffset(Idx);
} else {
uint64_t S = DL.getTypeAllocSize(GTI.getIndexedType());
uint64_t S = GTI.getSequentialElementStride(DL);
while (true) {
if (const ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
// Constant-offset addressing.
2 changes: 1 addition & 1 deletion llvm/lib/Target/PowerPC/PPCFastISel.cpp
@@ -350,7 +350,7 @@ bool PPCFastISel::PPCComputeAddress(const Value *Obj, Address &Addr) {
unsigned Idx = cast<ConstantInt>(Op)->getZExtValue();
TmpOffset += SL->getElementOffset(Idx);
} else {
uint64_t S = DL.getTypeAllocSize(GTI.getIndexedType());
uint64_t S = GTI.getSequentialElementStride(DL);
for (;;) {
if (const ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
// Constant-offset addressing.
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -362,7 +362,7 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,

VecOperand = i;

TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
TypeSize TS = GTI.getSequentialElementStride(*DL);
if (TS.isScalable())
return std::make_pair(nullptr, nullptr);

78 changes: 45 additions & 33 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2711,11 +2711,19 @@ InstructionCost RISCVTargetLowering::getVRGatherVICost(MVT VT) const {
return getLMULCost(VT);
}

/// Return the cost of a vslidedown.vi/vx or vslideup.vi/vx instruction
/// Return the cost of a vslidedown.vx or vslideup.vx instruction
/// for the type VT. (This does not cover the vslide1up or vslide1down
/// variants.) Slides may be linear in the number of vregs implied by LMUL,
/// or may track the vrgather.vv cost. It is implementation-dependent.
InstructionCost RISCVTargetLowering::getVSlideCost(MVT VT) const {
InstructionCost RISCVTargetLowering::getVSlideVXCost(MVT VT) const {
return getLMULCost(VT);
}

/// Return the cost of a vslidedown.vi or vslideup.vi instruction
/// for the type VT. (This does not cover the vslide1up or vslide1down
/// variants.) Slides may be linear in the number of vregs implied by LMUL,
/// or may track the vrgather.vv cost. It is implementation-dependent.
InstructionCost RISCVTargetLowering::getVSlideVICost(MVT VT) const {
return getLMULCost(VT);
}

@@ -2811,8 +2819,8 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
SDValue SplatZero = DAG.getNode(
RISCVISD::VMV_V_X_VL, DL, DstContainerVT, DAG.getUNDEF(DstContainerVT),
DAG.getConstant(0, DL, Subtarget.getXLenVT()), VL);
Res = DAG.getNode(RISCVISD::VSELECT_VL, DL, DstContainerVT, IsNan, SplatZero,
Res, VL);
Res = DAG.getNode(RISCVISD::VMERGE_VL, DL, DstContainerVT, IsNan, SplatZero,
Res, DAG.getUNDEF(DstContainerVT), VL);

if (DstVT.isFixedLengthVector())
Res = convertFromScalableVector(DstVT, Res, DAG, Subtarget);
@@ -5401,17 +5409,17 @@ static SDValue lowerFMAXIMUM_FMINIMUM(SDValue Op, SelectionDAG &DAG,
SDValue XIsNonNan = DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(),
{X, X, DAG.getCondCode(ISD::SETOEQ),
DAG.getUNDEF(ContainerVT), Mask, VL});
NewY =
DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, XIsNonNan, Y, X, VL);
NewY = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, XIsNonNan, Y, X,
DAG.getUNDEF(ContainerVT), VL);
}

SDValue NewX = X;
if (!YIsNeverNan) {
SDValue YIsNonNan = DAG.getNode(RISCVISD::SETCC_VL, DL, Mask.getValueType(),
{Y, Y, DAG.getCondCode(ISD::SETOEQ),
DAG.getUNDEF(ContainerVT), Mask, VL});
NewX =
DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, YIsNonNan, X, Y, VL);
NewX = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, YIsNonNan, X, Y,
DAG.getUNDEF(ContainerVT), VL);
}

unsigned Opc =
@@ -5528,7 +5536,6 @@ static unsigned getRISCVVLOp(SDValue Op) {
return RISCVISD::VMXOR_VL;
return RISCVISD::XOR_VL;
case ISD::VP_SELECT:
return RISCVISD::VSELECT_VL;
case ISD::VP_MERGE:
return RISCVISD::VMERGE_VL;
case ISD::VP_ASHR:
@@ -5563,7 +5570,7 @@ static bool hasMergeOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
125 &&
124 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -5589,7 +5596,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
125 &&
124 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -7456,8 +7463,9 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
DAG.getUNDEF(ContainerVT), SplatZero, VL);
SplatTrueVal = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
DAG.getUNDEF(ContainerVT), SplatTrueVal, VL);
SDValue Select = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, CC,
SplatTrueVal, SplatZero, VL);
SDValue Select =
DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, CC, SplatTrueVal,
SplatZero, DAG.getUNDEF(ContainerVT), VL);

return convertFromScalableVector(VecVT, Select, DAG, Subtarget);
}
@@ -8240,8 +8248,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
return Vec;
// TAMU
if (Policy == RISCVII::TAIL_AGNOSTIC)
return DAG.getNode(RISCVISD::VSELECT_VL, DL, VT, Mask, Vec, MaskedOff,
AVL);
return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec, MaskedOff,
DAG.getUNDEF(VT), AVL);
// TUMA or TUMU: Currently we always emit tumu policy regardless of tuma.
// It's fine because vmerge does not care mask policy.
return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Mask, Vec, MaskedOff,
@@ -8489,8 +8497,8 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT,
{VID, SplattedIdx, DAG.getCondCode(ISD::SETEQ),
DAG.getUNDEF(MaskVT), Mask, VL});
return DAG.getNode(RISCVISD::VSELECT_VL, DL, VT, SelectCond, SplattedVal,
Vec, VL);
return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, SelectCond, SplattedVal,
Vec, DAG.getUNDEF(VT), VL);
}
// EGS * EEW >= 128 bits
case Intrinsic::riscv_vaesdf_vv:
@@ -10243,8 +10251,8 @@ SDValue RISCVTargetLowering::lowerFixedLengthVectorSelectToRVV(
SDLoc DL(Op);
SDValue VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;

SDValue Select =
DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, CC, Op1, Op2, VL);
SDValue Select = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, CC, Op1,
Op2, DAG.getUNDEF(ContainerVT), VL);

return convertFromScalableVector(VT, Select, DAG, Subtarget);
}
@@ -10327,9 +10335,14 @@ SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG) const {
Ops.push_back(DAG.getUNDEF(ContainerVT));
} else if (ISD::getVPExplicitVectorLengthIdx(Op.getOpcode()) ==
OpIdx.index()) {
// For VP_MERGE, copy the false operand instead of an undef value.
assert(Op.getOpcode() == ISD::VP_MERGE);
Ops.push_back(Ops.back());
if (Op.getOpcode() == ISD::VP_MERGE) {
// For VP_MERGE, copy the false operand instead of an undef value.
Ops.push_back(Ops.back());
} else {
assert(Op.getOpcode() == ISD::VP_SELECT);
// For VP_SELECT, add an undef value.
Ops.push_back(DAG.getUNDEF(ContainerVT));
}
}
}
// Pass through operands which aren't fixed-length vectors.
@@ -10379,8 +10392,8 @@ SDValue RISCVTargetLowering::lowerVPExtMaskOp(SDValue Op,
SDValue Splat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
DAG.getUNDEF(ContainerVT), SplatValue, VL);

SDValue Result = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, Src,
Splat, ZeroSplat, VL);
SDValue Result = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, Src, Splat,
ZeroSplat, DAG.getUNDEF(ContainerVT), VL);
if (!VT.isFixedLengthVector())
return Result;
return convertFromScalableVector(VT, Result, DAG, Subtarget);
@@ -10508,8 +10521,8 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op,
RISCVISDExtOpc == RISCVISD::VZEXT_VL ? 1 : -1, DL, XLenVT);
SDValue OneSplat = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT,
DAG.getUNDEF(IntVT), One, VL);
Src = DAG.getNode(RISCVISD::VSELECT_VL, DL, IntVT, Src, OneSplat,
ZeroSplat, VL);
Src = DAG.getNode(RISCVISD::VMERGE_VL, DL, IntVT, Src, OneSplat,
ZeroSplat, DAG.getUNDEF(IntVT), VL);
} else if (DstEltSize > (2 * SrcEltSize)) {
// Widen before converting.
MVT IntVT = MVT::getVectorVT(MVT::getIntegerVT(DstEltSize / 2),
@@ -10633,17 +10646,17 @@ RISCVTargetLowering::lowerVPSpliceExperimental(SDValue Op,
SDValue SplatZeroOp1 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
DAG.getUNDEF(ContainerVT),
DAG.getConstant(0, DL, XLenVT), EVL1);
Op1 = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, Op1, SplatOneOp1,
SplatZeroOp1, EVL1);
Op1 = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, Op1, SplatOneOp1,
SplatZeroOp1, DAG.getUNDEF(ContainerVT), EVL1);

SDValue SplatOneOp2 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
DAG.getUNDEF(ContainerVT),
DAG.getConstant(1, DL, XLenVT), EVL2);
SDValue SplatZeroOp2 = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, ContainerVT,
DAG.getUNDEF(ContainerVT),
DAG.getConstant(0, DL, XLenVT), EVL2);
Op2 = DAG.getNode(RISCVISD::VSELECT_VL, DL, ContainerVT, Op2, SplatOneOp2,
SplatZeroOp2, EVL2);
Op2 = DAG.getNode(RISCVISD::VMERGE_VL, DL, ContainerVT, Op2, SplatOneOp2,
SplatZeroOp2, DAG.getUNDEF(ContainerVT), EVL2);
}

int64_t ImmValue = cast<ConstantSDNode>(Offset)->getSExtValue();
@@ -10713,8 +10726,8 @@ RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op,
SDValue SplatZero = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IndicesVT,
DAG.getUNDEF(IndicesVT),
DAG.getConstant(0, DL, XLenVT), EVL);
Op1 = DAG.getNode(RISCVISD::VSELECT_VL, DL, IndicesVT, Op1, SplatOne,
SplatZero, EVL);
Op1 = DAG.getNode(RISCVISD::VMERGE_VL, DL, IndicesVT, Op1, SplatOne,
SplatZero, DAG.getUNDEF(IndicesVT), EVL);
}

unsigned EltSize = GatherVT.getScalarSizeInBits();
@@ -18683,7 +18696,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VWMACCSU_VL)
NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
NODE_NAME_CASE(VMERGE_VL)
NODE_NAME_CASE(VMAND_VL)
NODE_NAME_CASE(VMOR_VL)
6 changes: 3 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -330,9 +330,8 @@ enum NodeType : unsigned {
// operand is VL.
SETCC_VL,

// Vector select with an additional VL operand. This operation is unmasked.
VSELECT_VL,
// General vmerge node with mask, true, false, passthru, and vl operands.
// Tail agnostic vselect can be implemented by setting passthru to undef.
VMERGE_VL,

// Mask binary operators.
@@ -526,7 +525,8 @@ class RISCVTargetLowering : public TargetLowering {

InstructionCost getVRGatherVVCost(MVT VT) const;
InstructionCost getVRGatherVICost(MVT VT) const;
InstructionCost getVSlideCost(MVT VT) const;
InstructionCost getVSlideVXCost(MVT VT) const;
InstructionCost getVSlideVICost(MVT VT) const;

// Provide custom lowering hooks for some operations.
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
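A scalar model of why the dedicated VSELECT_VL node could be removed: VMERGE_VL carries an explicit passthru that supplies the lanes at and past VL, so an undef passthru gives the tail-agnostic vselect behaviour, and passing the false operand as passthru gives vp.merge its defined tail (this mirrors the VP_SELECT / VP_MERGE handling in lowerVPOp above). Plain C++, with -1 standing in for undef lanes:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-lane semantics of a VL-predicated merge with a passthru operand.
static std::vector<int> vmergeVL(const std::vector<bool> &Mask,
                                 const std::vector<int> &TrueV,
                                 const std::vector<int> &FalseV,
                                 const std::vector<int> &Passthru, size_t VL) {
  std::vector<int> R(TrueV.size());
  for (size_t I = 0; I < R.size(); ++I)
    R[I] = I < VL ? (Mask[I] ? TrueV[I] : FalseV[I]) : Passthru[I];
  return R;
}

int main() {
  const std::vector<bool> M{true, false, true, true};
  const std::vector<int> T{1, 2, 3, 4}, F{7, 8, 9, 10}, Undef{-1, -1, -1, -1};
  // vp.select with evl=3: body lanes select, the tail lane is a don't-care.
  assert((vmergeVL(M, T, F, Undef, 3) == std::vector<int>{1, 8, 3, -1}));
  // vp.merge with evl=3: the tail lane keeps the false operand, so the
  // lowering passes the false operand through as the passthru.
  assert((vmergeVL(M, T, F, F, 3) == std::vector<int>{1, 8, 3, 10}));
}
```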
103 changes: 22 additions & 81 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -338,13 +338,6 @@ def riscv_vrgatherei16_vv_vl : SDNode<"RISCVISD::VRGATHEREI16_VV_VL",
SDTCisSameNumEltsAs<0, 4>,
SDTCisVT<5, XLenVT>]>>;

def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisVT<4, XLenVT>
]>;

def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>;

def SDT_RISCVVMERGE_VL : SDTypeProfile<1, 5, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameAs<0, 4>,
@@ -1722,21 +1715,21 @@ multiclass VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> {
(!cast<Instruction>(instruction_name#"_VX_"# suffix #"_MASK")
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
def : Pat<(riscv_vselect_vl (vti.Mask V0),
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (op vti.RegClass:$rd,
(riscv_mul_vl_oneuse vti.RegClass:$rs1, vti.RegClass:$rs2,
srcvalue, (vti.Mask true_mask), VLOpFrag),
srcvalue, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
vti.RegClass:$rd, undef, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
def : Pat<(riscv_vselect_vl (vti.Mask V0),
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (op vti.RegClass:$rd,
(riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rs2,
srcvalue, (vti.Mask true_mask), VLOpFrag),
srcvalue, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
vti.RegClass:$rd, undef, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VX_"# suffix #"_MASK")
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -1861,17 +1854,17 @@ multiclass VPatFPMulAccVL_VV_VF<PatFrag vop, string instruction_name> {
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TU_MU)>;
def : Pat<(riscv_vselect_vl (vti.Mask V0),
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
vti.RegClass:$rd, undef, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
def : Pat<(riscv_vselect_vl (vti.Mask V0),
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
vti.RegClass:$rd, undef, VLOpFrag),
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
@@ -1905,21 +1898,21 @@ multiclass VPatFPMulAccVL_VV_VF_RM<PatFrag vop, string instruction_name> {
// RISCVInsertReadWriteCSR
FRM_DYN,
GPR:$vl, vti.Log2SEW, TU_MU)>;
def : Pat<(riscv_vselect_vl (vti.Mask V0),
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
vti.RegClass:$rd, undef, VLOpFrag),
(!cast<Instruction>(instruction_name#"_VV_"# suffix #"_MASK")
vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0),
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
FRM_DYN,
GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
def : Pat<(riscv_vselect_vl (vti.Mask V0),
def : Pat<(riscv_vmerge_vl (vti.Mask V0),
(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1), vti.RegClass:$rs2,
vti.RegClass:$rd, (vti.Mask true_mask), VLOpFrag)),
vti.RegClass:$rd, VLOpFrag),
vti.RegClass:$rd, undef, VLOpFrag),
(!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix # "_MASK")
vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask V0),
@@ -2255,31 +2248,6 @@ foreach vtiTowti = AllWidenableIntVectors in {
// 11.15. Vector Integer Merge Instructions
foreach vti = AllIntegerVectors in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
vti.RegClass:$rs1,
vti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs2, vti.RegClass:$rs1, (vti.Mask V0),
GPR:$vl, vti.Log2SEW)>;

def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
(SplatPat XLenVT:$rs1),
vti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs2, GPR:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;

def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask V0),
(SplatPat_simm5 simm5:$rs1),
vti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX)
(vti.Vector (IMPLICIT_DEF)),
vti.RegClass:$rs2, simm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;

def : Pat<(vti.Vector (riscv_vmerge_vl (vti.Mask V0),
vti.RegClass:$rs1,
vti.RegClass:$rs2,
@@ -2534,33 +2502,6 @@ foreach fvti = AllFloatVectors in {
// 13.15. Vector Floating-Point Merge Instruction
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
let Predicates = GetVTypePredicates<ivti>.Predicates in {
def : Pat<(fvti.Vector (riscv_vselect_vl (fvti.Mask V0),
fvti.RegClass:$rs1,
fvti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
GPR:$vl, fvti.Log2SEW)>;

def : Pat<(fvti.Vector (riscv_vselect_vl (fvti.Mask V0),
(SplatFPOp (SelectFPImm (XLenVT GPR:$imm))),
fvti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VXM_"#fvti.LMul.MX)
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs2,
GPR:$imm,
(fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;

def : Pat<(fvti.Vector (riscv_vselect_vl (fvti.Mask V0),
(SplatFPOp (fvti.Scalar fpimm0)),
fvti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs2, 0, (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;

def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
fvti.RegClass:$rs1,
fvti.RegClass:$rs2,
@@ -2570,6 +2511,16 @@ foreach fvti = AllFloatVectors in {
fvti.RegClass:$merge, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
GPR:$vl, fvti.Log2SEW)>;

def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
(SplatFPOp (SelectFPImm (XLenVT GPR:$imm))),
fvti.RegClass:$rs2,
fvti.RegClass:$merge,
VLOpFrag)),
(!cast<Instruction>("PseudoVMERGE_VXM_"#fvti.LMul.MX)
fvti.RegClass:$merge, fvti.RegClass:$rs2, GPR:$imm, (fvti.Mask V0),
GPR:$vl, fvti.Log2SEW)>;


def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
(SplatFPOp (fvti.Scalar fpimm0)),
fvti.RegClass:$rs2,
Expand All @@ -2581,16 +2532,6 @@ foreach fvti = AllFloatVectors in {
}

let Predicates = GetVTypePredicates<fvti>.Predicates in {
def : Pat<(fvti.Vector (riscv_vselect_vl (fvti.Mask V0),
(SplatFPOp fvti.ScalarRegClass:$rs1),
fvti.RegClass:$rs2,
VLOpFrag)),
(!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX)
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs2,
(fvti.Scalar fvti.ScalarRegClass:$rs1),
(fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;

def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask V0),
(SplatFPOp fvti.ScalarRegClass:$rs1),
fvti.RegClass:$rs2,
114 changes: 101 additions & 13 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -34,6 +34,65 @@ static cl::opt<unsigned> SLPMaxVF(
"exclusively by SLP vectorizer."),
cl::Hidden);

InstructionCost
RISCVTTIImpl::getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
TTI::TargetCostKind CostKind) {
size_t NumInstr = OpCodes.size();
if (CostKind == TTI::TCK_CodeSize)
return NumInstr;
InstructionCost LMULCost = TLI->getLMULCost(VT);
if ((CostKind != TTI::TCK_RecipThroughput) && (CostKind != TTI::TCK_Latency))
return LMULCost * NumInstr;
InstructionCost Cost = 0;
for (auto Op : OpCodes) {
switch (Op) {
case RISCV::VRGATHER_VI:
Cost += TLI->getVRGatherVICost(VT);
break;
case RISCV::VRGATHER_VV:
Cost += TLI->getVRGatherVVCost(VT);
break;
case RISCV::VSLIDEUP_VI:
case RISCV::VSLIDEDOWN_VI:
Cost += TLI->getVSlideVICost(VT);
break;
case RISCV::VSLIDEUP_VX:
case RISCV::VSLIDEDOWN_VX:
Cost += TLI->getVSlideVXCost(VT);
break;
case RISCV::VREDMAX_VS:
case RISCV::VREDMIN_VS:
case RISCV::VREDMAXU_VS:
case RISCV::VREDMINU_VS:
case RISCV::VREDSUM_VS:
case RISCV::VREDAND_VS:
case RISCV::VREDOR_VS:
case RISCV::VREDXOR_VS:
case RISCV::VFREDMAX_VS:
case RISCV::VFREDMIN_VS:
case RISCV::VFREDUSUM_VS: {
unsigned VL = VT.getVectorMinNumElements();
if (!VT.isFixedLengthVector())
VL *= *getVScaleForTuning();
Cost += Log2_32_Ceil(VL);
break;
}
case RISCV::VFREDOSUM_VS: {
unsigned VL = VT.getVectorMinNumElements();
if (!VT.isFixedLengthVector())
VL *= *getVScaleForTuning();
Cost += VL;
break;
}
case RISCV::VMV_S_X:
// FIXME: VMV_S_X doesn't use LMUL, the cost should be 1
default:
Cost += LMULCost;
}
}
return Cost;
}
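The reduction opcodes above are costed by the depth of a binary reduction tree: the effective lane count VL (scaled by the tuning vscale for scalable types) contributes ceil(log2(VL)) steps, while the strictly ordered VFREDOSUM_VS is linear in VL because it must accumulate lane by lane to preserve FP ordering. A minimal standalone sketch of that arithmetic, with hypothetical element counts (it deliberately avoids LLVM's MVT/vscale machinery):

```cpp
#include <bit>      // std::bit_width (C++20)
#include <cstdint>
#include <iostream>

// Mirrors llvm::Log2_32_Ceil: smallest N with 2^N >= X (0 for X <= 1).
static unsigned log2Ceil(uint32_t X) {
  return X <= 1 ? 0 : static_cast<unsigned>(std::bit_width(X - 1));
}

int main() {
  // Hypothetical type: <vscale x 4 x i32> with a tuning vscale of 2.
  unsigned MinElts = 4, VScaleForTuning = 2;
  unsigned VL = MinElts * VScaleForTuning;                                // 8 lanes
  std::cout << "unordered reduction cost term: " << log2Ceil(VL) << '\n'; // 3
  std::cout << "ordered vfredosum cost term:   " << VL << '\n';           // 8
  return 0;
}
```

For eight effective lanes the unordered reductions therefore add 3 to the cost, while the ordered sum adds 8.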

InstructionCost RISCVTTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
TTI::TargetCostKind CostKind) {
assert(Ty->isIntegerTy() &&
@@ -281,7 +340,8 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
// Example sequence:
// vnsrl.wi v10, v8, 0
if (equal(DeinterleaveMask, Mask))
return LT.first * TLI->getLMULCost(LT.second);
return LT.first * getRISCVInstructionCost(RISCV::VNSRL_WI,
LT.second, CostKind);
}
}
}
@@ -292,7 +352,8 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
LT.second.getVectorNumElements() <= 256)) {
VectorType *IdxTy = getVRGatherIndexType(LT.second, *ST, Tp->getContext());
InstructionCost IndexCost = getConstantPoolLoadCost(IdxTy, CostKind);
return IndexCost + TLI->getVRGatherVVCost(LT.second);
return IndexCost +
getRISCVInstructionCost(RISCV::VRGATHER_VV, LT.second, CostKind);
}
[[fallthrough]];
}
@@ -310,7 +371,10 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
VectorType *MaskTy = VectorType::get(IntegerType::getInt1Ty(C), EC);
InstructionCost IndexCost = getConstantPoolLoadCost(IdxTy, CostKind);
InstructionCost MaskCost = getConstantPoolLoadCost(MaskTy, CostKind);
return 2 * IndexCost + 2 * TLI->getVRGatherVVCost(LT.second) + MaskCost;
return 2 * IndexCost +
getRISCVInstructionCost({RISCV::VRGATHER_VV, RISCV::VRGATHER_VV},
LT.second, CostKind) +
MaskCost;
}
[[fallthrough]];
}
@@ -365,19 +429,24 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
// Example sequence:
// vsetivli zero, 4, e8, mf2, tu, ma (ignored)
// vslidedown.vi v8, v9, 2
return LT.first * TLI->getVSlideCost(LT.second);
return LT.first *
getRISCVInstructionCost(RISCV::VSLIDEDOWN_VI, LT.second, CostKind);
case TTI::SK_InsertSubvector:
// Example sequence:
// vsetivli zero, 4, e8, mf2, tu, ma (ignored)
// vslideup.vi v8, v9, 2
return LT.first * TLI->getVSlideCost(LT.second);
return LT.first *
getRISCVInstructionCost(RISCV::VSLIDEUP_VI, LT.second, CostKind);
case TTI::SK_Select: {
// Example sequence:
// li a0, 90
// vsetivli zero, 8, e8, mf2, ta, ma (ignored)
// vmv.s.x v0, a0
// vmerge.vvm v8, v9, v8, v0
return LT.first * 3 * TLI->getLMULCost(LT.second);
return LT.first *
(TLI->getLMULCost(LT.second) + // FIXME: should be 1 for li
getRISCVInstructionCost({RISCV::VMV_S_X, RISCV::VMERGE_VVM},
LT.second, CostKind));
}
case TTI::SK_Broadcast: {
bool HasScalar = (Args.size() > 0) && (Operator::getOpcode(Args[0]) ==
@@ -389,7 +458,10 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
// vsetivli zero, 2, e8, mf8, ta, ma (ignored)
// vmv.v.x v8, a0
// vmsne.vi v0, v8, 0
return LT.first * TLI->getLMULCost(LT.second) * 3;
return LT.first *
(TLI->getLMULCost(LT.second) + // FIXME: should be 1 for andi
getRISCVInstructionCost({RISCV::VMV_V_X, RISCV::VMSNE_VI},
LT.second, CostKind));
}
// Example sequence:
// vsetivli zero, 2, e8, mf8, ta, mu (ignored)
@@ -400,24 +472,38 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
// vmv.v.x v8, a0
// vmsne.vi v0, v8, 0

return LT.first * TLI->getLMULCost(LT.second) * 6;
return LT.first *
(TLI->getLMULCost(LT.second) + // FIXME: this should be 1 for andi
TLI->getLMULCost(
LT.second) + // FIXME: vmv.x.s is the same as extractelement
getRISCVInstructionCost({RISCV::VMV_V_I, RISCV::VMERGE_VIM,
RISCV::VMV_V_X, RISCV::VMSNE_VI},
LT.second, CostKind));
}

if (HasScalar) {
// Example sequence:
// vmv.v.x v8, a0
return LT.first * TLI->getLMULCost(LT.second);
return LT.first *
getRISCVInstructionCost(RISCV::VMV_V_X, LT.second, CostKind);
}

// Example sequence:
// vrgather.vi v9, v8, 0
return LT.first * TLI->getVRGatherVICost(LT.second);
return LT.first *
getRISCVInstructionCost(RISCV::VRGATHER_VI, LT.second, CostKind);
}
case TTI::SK_Splice:
case TTI::SK_Splice: {
// vslidedown+vslideup.
// TODO: Multiplying by LT.first implies this legalizes into multiple copies
// of similar code, but I think we expand through memory.
return 2 * LT.first * TLI->getVSlideCost(LT.second);
unsigned Opcodes[2] = {RISCV::VSLIDEDOWN_VX, RISCV::VSLIDEUP_VX};
if (Index >= 0 && Index < 32)
Opcodes[0] = RISCV::VSLIDEDOWN_VI;
else if (Index < 0 && Index > -32)
Opcodes[1] = RISCV::VSLIDEUP_VI;
return LT.first * getRISCVInstructionCost(Opcodes, LT.second, CostKind);
}
case TTI::SK_Reverse: {
// TODO: Cases to improve here:
// * Illegal vector types
@@ -437,7 +523,9 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
if (LT.second.isFixedLengthVector())
// vrsub.vi has a 5 bit immediate field, otherwise an li suffices
LenCost = isInt<5>(LT.second.getVectorNumElements() - 1) ? 0 : 1;
InstructionCost GatherCost = 2 + TLI->getVRGatherVVCost(LT.second);
// FIXME: replace the constant `2` below with cost of {VID_V,VRSUB_VX}
InstructionCost GatherCost =
2 + getRISCVInstructionCost(RISCV::VRGATHER_VV, LT.second, CostKind);
// Mask operation additionally required extend and truncate
InstructionCost ExtendCost = Tp->getElementType()->isIntegerTy(1) ? 3 : 0;
return LT.first * (LenCost + GatherCost + ExtendCost);
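One note on the SK_Splice case earlier in this file: the immediate slide forms are chosen only when the splice offset fits the 5-bit immediate of vslidedown.vi/vslideup.vi. As a reminder of what a splice actually computes (and why one slide-down of the first operand plus one slide-up of the second realizes it), here is a scalar model written against the assumed llvm.vector.splice semantics; the helper name and int element type are illustrative only.

```cpp
#include <cassert>
#include <iostream>
#include <vector>

// Scalar model of a vector splice: a non-negative Index takes VL elements of
// concat(A, B) starting at Index; a negative Index takes the trailing -Index
// lanes of A followed by the leading lanes of B.
std::vector<int> splice(const std::vector<int> &A, const std::vector<int> &B,
                        int Index) {
  assert(A.size() == B.size());
  const int VL = static_cast<int>(A.size());
  std::vector<int> Concat(A);
  Concat.insert(Concat.end(), B.begin(), B.end());
  int Start = Index >= 0 ? Index : VL + Index;
  return {Concat.begin() + Start, Concat.begin() + Start + VL};
}

int main() {
  std::vector<int> A{0, 1, 2, 3}, B{4, 5, 6, 7};
  for (int V : splice(A, B, 1))  std::cout << V << ' ';  // 1 2 3 4
  std::cout << '\n';
  for (int V : splice(A, B, -1)) std::cout << V << ' ';  // 3 4 5 6
  std::cout << '\n';
}
```

With Index = 1, the first operand supplies VL - 1 lanes via the slide-down and the second supplies the remaining lane via the slide-up, which is the two-instruction sequence the cost model charges for.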
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -48,6 +48,9 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
/// actual target hardware.
unsigned getEstimatedVLFor(VectorType *Ty);

InstructionCost getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
TTI::TargetCostKind CostKind);

/// Return the cost of accessing a constant pool entry of the specified
/// type.
InstructionCost getConstantPoolLoadCost(Type *Ty,
2 changes: 1 addition & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyFastISel.cpp
@@ -278,7 +278,7 @@ bool WebAssemblyFastISel::computeAddress(const Value *Obj, Address &Addr) {
unsigned Idx = cast<ConstantInt>(Op)->getZExtValue();
TmpOffset += SL->getElementOffset(Idx);
} else {
uint64_t S = DL.getTypeAllocSize(GTI.getIndexedType());
uint64_t S = GTI.getSequentialElementStride(DL);
for (;;) {
if (const auto *CI = dyn_cast<ConstantInt>(Op)) {
// Constant-offset addressing.
3 changes: 3 additions & 0 deletions llvm/lib/Target/X86/MCTargetDesc/X86MCCodeEmitter.cpp
@@ -1650,6 +1650,9 @@ void X86MCCodeEmitter::encodeInstruction(const MCInst &MI,
if (HasVEX_4V) // Skip 1st src (which is encoded in VEX_VVVV)
++SrcRegNum;

if (IsND) // Skip new data destination
++CurOp;

emitRegModRMByte(MI.getOperand(SrcRegNum),
getX86RegNum(MI.getOperand(CurOp)), CB);
CurOp = SrcRegNum + 1;
2 changes: 1 addition & 1 deletion llvm/lib/Target/X86/X86FastISel.cpp
@@ -916,7 +916,7 @@ bool X86FastISel::X86SelectAddress(const Value *V, X86AddressMode &AM) {

// A array/variable index is always of the form i*S where S is the
// constant scale size. See if we can push the scale into immediates.
uint64_t S = DL.getTypeAllocSize(GTI.getIndexedType());
uint64_t S = GTI.getSequentialElementStride(DL);
for (;;) {
if (const ConstantInt *CI = dyn_cast<ConstantInt>(Op)) {
// Constant-offset addressing.
13 changes: 6 additions & 7 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49959,18 +49959,17 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
SDValue Ptr = Ld->getBasePtr();
SDValue Chain = Ld->getChain();
for (SDNode *User : Chain->uses()) {
if (User != N &&
auto *UserLd = dyn_cast<MemSDNode>(User);
if (User != N && UserLd &&
(User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
ISD::isNormalLoad(User)) &&
cast<MemSDNode>(User)->getChain() == Chain &&
!User->hasAnyUseOfValue(1) &&
UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
User->getValueSizeInBits(0).getFixedValue() >
RegVT.getFixedSizeInBits()) {
if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
cast<MemSDNode>(User)->getBasePtr() == Ptr &&
cast<MemSDNode>(User)->getMemoryVT().getSizeInBits() ==
MemVT.getSizeInBits()) {
UserLd->getBasePtr() == Ptr &&
UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) {
SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
RegVT.getSizeInBits());
Extract = DAG.getBitcast(RegVT, Extract);
@@ -49989,7 +49988,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
// See if we are loading a constant that matches in the lower
// bits of a longer constant (but from a different constant pool ptr).
EVT UserVT = User->getValueType(0);
SDValue UserPtr = cast<MemSDNode>(User)->getBasePtr();
SDValue UserPtr = UserLd->getBasePtr();
const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
if (LdC && UserC && UserPtr != Ptr) {
171 changes: 139 additions & 32 deletions llvm/lib/Target/X86/X86InstrArithmetic.td
@@ -184,52 +184,139 @@ def IMUL64rmi32 : IMulOpMI_R<Xi64, WriteIMul64Imm>;
//===----------------------------------------------------------------------===//
// INC and DEC Instructions
//
class IncOpR_RF<X86TypeInfo t> : UnaryOpR_RF<0xFF, MRM0r, "inc", t, null_frag> {
class IncOpR_RF<X86TypeInfo t, bit ndd = 0> : UnaryOpR_RF<0xFF, MRM0r, "inc", t, null_frag, ndd> {
let Pattern = [(set t.RegClass:$dst, EFLAGS,
(X86add_flag_nocf t.RegClass:$src1, 1))];
}
class DecOpR_RF<X86TypeInfo t> : UnaryOpR_RF<0xFF, MRM1r, "dec", t, null_frag> {
class DecOpR_RF<X86TypeInfo t, bit ndd = 0> : UnaryOpR_RF<0xFF, MRM1r, "dec", t, null_frag, ndd> {
let Pattern = [(set t.RegClass:$dst, EFLAGS,
(X86sub_flag_nocf t.RegClass:$src1, 1))];
}
class IncOpM_M<X86TypeInfo t> : UnaryOpM_MF<0xFF, MRM0m, "inc", t, null_frag> {
class IncOpR_R<X86TypeInfo t, bit ndd = 0> : UnaryOpR_R<0xFF, MRM0r, "inc", t, null_frag, ndd>;
class DecOpR_R<X86TypeInfo t, bit ndd = 0> : UnaryOpR_R<0xFF, MRM1r, "dec", t, null_frag, ndd>;
class IncOpM_MF<X86TypeInfo t> : UnaryOpM_MF<0xFF, MRM0m, "inc", t, null_frag> {
let Pattern = [(store (add (t.LoadNode addr:$src1), 1), addr:$src1),
(implicit EFLAGS)];
}
class DecOpM_M<X86TypeInfo t> : UnaryOpM_MF<0xFF, MRM1m, "dec", t, null_frag> {
class DecOpM_MF<X86TypeInfo t> : UnaryOpM_MF<0xFF, MRM1m, "dec", t, null_frag> {
let Pattern = [(store (add (t.LoadNode addr:$src1), -1), addr:$src1),
(implicit EFLAGS)];
}
class IncOpM_RF<X86TypeInfo t> : UnaryOpM_RF<0xFF, MRM0m, "inc", t, null_frag> {
let Pattern = [(set t.RegClass:$dst, EFLAGS, (add (t.LoadNode addr:$src1), 1))];
}
class DecOpM_RF<X86TypeInfo t> : UnaryOpM_RF<0xFF, MRM1m, "dec", t, null_frag> {
let Pattern = [(set t.RegClass:$dst, EFLAGS, (add (t.LoadNode addr:$src1), -1))];
}
class IncOpM_M<X86TypeInfo t> : UnaryOpM_M<0xFF, MRM0m, "inc", t, null_frag>;
class DecOpM_M<X86TypeInfo t> : UnaryOpM_M<0xFF, MRM1m, "dec", t, null_frag>;
class IncOpM_R<X86TypeInfo t> : UnaryOpM_R<0xFF, MRM0m, "inc", t, null_frag>;
class DecOpM_R<X86TypeInfo t> : UnaryOpM_R<0xFF, MRM1m, "dec", t, null_frag>;

// IncDec_Alt - Instructions like "inc reg" short forms.
// Short forms only valid in 32-bit mode. Selected during MCInst lowering.
class IncDec_Alt<bits<8> o, string m, X86TypeInfo t>
: UnaryOpR_RF<o, AddRegFrm, m, t, null_frag>, Requires<[Not64BitMode]>;

let isConvertibleToThreeAddress = 1 in {
def INC16r_alt : IncDec_Alt<0x40, "inc", Xi16>, OpSize16;
def INC32r_alt : IncDec_Alt<0x40, "inc", Xi32>, OpSize32;
def DEC16r_alt : IncDec_Alt<0x48, "dec", Xi16>, OpSize16;
def DEC32r_alt : IncDec_Alt<0x48, "dec", Xi32>, OpSize32;
def INC8r : IncOpR_RF<Xi8>;
def INC16r : IncOpR_RF<Xi16>, OpSize16;
def INC32r : IncOpR_RF<Xi32>, OpSize32;
def INC64r : IncOpR_RF<Xi64>;
def DEC8r : DecOpR_RF<Xi8>;
def DEC16r : DecOpR_RF<Xi16>, OpSize16;
def DEC32r : DecOpR_RF<Xi32>, OpSize32;
def DEC64r : DecOpR_RF<Xi64>;
def INC16r_alt : IncDec_Alt<0x40, "inc", Xi16>, OpSize16;
def INC32r_alt : IncDec_Alt<0x40, "inc", Xi32>, OpSize32;
def DEC16r_alt : IncDec_Alt<0x48, "dec", Xi16>, OpSize16;
def DEC32r_alt : IncDec_Alt<0x48, "dec", Xi32>, OpSize32;
let Predicates = [NoNDD] in {
def INC8r : IncOpR_RF<Xi8>;
def INC16r : IncOpR_RF<Xi16>, OpSize16;
def INC32r : IncOpR_RF<Xi32>, OpSize32;
def INC64r : IncOpR_RF<Xi64>;
def DEC8r : DecOpR_RF<Xi8>;
def DEC16r : DecOpR_RF<Xi16>, OpSize16;
def DEC32r : DecOpR_RF<Xi32>, OpSize32;
def DEC64r : DecOpR_RF<Xi64>;
}
let Predicates = [HasNDD, In64BitMode] in {
def INC8r_ND : IncOpR_RF<Xi8, 1>;
def INC16r_ND : IncOpR_RF<Xi16, 1>, PD;
def INC32r_ND : IncOpR_RF<Xi32, 1>;
def INC64r_ND : IncOpR_RF<Xi64, 1>;
def DEC8r_ND : DecOpR_RF<Xi8, 1>;
def DEC16r_ND : DecOpR_RF<Xi16, 1>, PD;
def DEC32r_ND : DecOpR_RF<Xi32, 1>;
def DEC64r_ND : DecOpR_RF<Xi64, 1>;
}
let Predicates = [In64BitMode], Pattern = [(null_frag)] in {
def INC8r_NF : IncOpR_R<Xi8>, NF;
def INC16r_NF : IncOpR_R<Xi16>, NF, PD;
def INC32r_NF : IncOpR_R<Xi32>, NF;
def INC64r_NF : IncOpR_R<Xi64>, NF;
def DEC8r_NF : DecOpR_R<Xi8>, NF;
def DEC16r_NF : DecOpR_R<Xi16>, NF, PD;
def DEC32r_NF : DecOpR_R<Xi32>, NF;
def DEC64r_NF : DecOpR_R<Xi64>, NF;
def INC8r_NF_ND : IncOpR_R<Xi8, 1>, NF;
def INC16r_NF_ND : IncOpR_R<Xi16, 1>, NF, PD;
def INC32r_NF_ND : IncOpR_R<Xi32, 1>, NF;
def INC64r_NF_ND : IncOpR_R<Xi64, 1>, NF;
def DEC8r_NF_ND : DecOpR_R<Xi8, 1>, NF;
def DEC16r_NF_ND : DecOpR_R<Xi16, 1>, NF, PD;
def DEC32r_NF_ND : DecOpR_R<Xi32, 1>, NF;
def DEC64r_NF_ND : DecOpR_R<Xi64, 1>, NF;
def INC8r_EVEX : IncOpR_RF<Xi8>, PL;
def INC16r_EVEX : IncOpR_RF<Xi16>, PL, PD;
def INC32r_EVEX : IncOpR_RF<Xi32>, PL;
def INC64r_EVEX : IncOpR_RF<Xi64>, PL;
def DEC8r_EVEX : DecOpR_RF<Xi8>, PL;
def DEC16r_EVEX : DecOpR_RF<Xi16>, PL, PD;
def DEC32r_EVEX : DecOpR_RF<Xi32>, PL;
def DEC64r_EVEX : DecOpR_RF<Xi64>, PL;
}
}
let Predicates = [UseIncDec] in {
def INC8m : IncOpM_M<Xi8>;
def INC16m : IncOpM_M<Xi16>, OpSize16;
def INC32m : IncOpM_M<Xi32>, OpSize32;
def DEC8m : DecOpM_M<Xi8>;
def DEC16m : DecOpM_M<Xi16>, OpSize16;
def DEC32m : DecOpM_M<Xi32>, OpSize32;
def INC8m : IncOpM_MF<Xi8>;
def INC16m : IncOpM_MF<Xi16>, OpSize16;
def INC32m : IncOpM_MF<Xi32>, OpSize32;
def DEC8m : DecOpM_MF<Xi8>;
def DEC16m : DecOpM_MF<Xi16>, OpSize16;
def DEC32m : DecOpM_MF<Xi32>, OpSize32;
}
let Predicates = [UseIncDec, In64BitMode] in {
def INC64m : IncOpM_M<Xi64>;
def DEC64m : DecOpM_M<Xi64>;
def INC64m : IncOpM_MF<Xi64>;
def DEC64m : DecOpM_MF<Xi64>;
}
let Predicates = [HasNDD, In64BitMode, UseIncDec] in {
def INC8m_ND : IncOpM_RF<Xi8>;
def INC16m_ND : IncOpM_RF<Xi16>, PD;
def INC32m_ND : IncOpM_RF<Xi32>;
def DEC8m_ND : DecOpM_RF<Xi8>;
def DEC16m_ND : DecOpM_RF<Xi16>, PD;
def DEC32m_ND : DecOpM_RF<Xi32>;
def INC64m_ND : IncOpM_RF<Xi64>;
def DEC64m_ND : DecOpM_RF<Xi64>;
}
let Predicates = [In64BitMode], Pattern = [(null_frag)] in {
def INC8m_NF : IncOpM_M<Xi8>, NF;
def INC16m_NF : IncOpM_M<Xi16>, NF, PD;
def INC32m_NF : IncOpM_M<Xi32>, NF;
def INC64m_NF : IncOpM_M<Xi64>, NF;
def DEC8m_NF : DecOpM_M<Xi8>, NF;
def DEC16m_NF : DecOpM_M<Xi16>, NF, PD;
def DEC32m_NF : DecOpM_M<Xi32>, NF;
def DEC64m_NF : DecOpM_M<Xi64>, NF;
def INC8m_NF_ND : IncOpM_R<Xi8>, NF;
def INC16m_NF_ND : IncOpM_R<Xi16>, NF, PD;
def INC32m_NF_ND : IncOpM_R<Xi32>, NF;
def INC64m_NF_ND : IncOpM_R<Xi64>, NF;
def DEC8m_NF_ND : DecOpM_R<Xi8>, NF;
def DEC16m_NF_ND : DecOpM_R<Xi16>, NF, PD;
def DEC32m_NF_ND : DecOpM_R<Xi32>, NF;
def DEC64m_NF_ND : DecOpM_R<Xi64>, NF;
def INC8m_EVEX : IncOpM_MF<Xi8>, PL;
def INC16m_EVEX : IncOpM_MF<Xi16>, PL, PD;
def INC32m_EVEX : IncOpM_MF<Xi32>, PL;
def INC64m_EVEX : IncOpM_MF<Xi64>, PL;
def DEC8m_EVEX : DecOpM_MF<Xi8>, PL;
def DEC16m_EVEX : DecOpM_MF<Xi16>, PL, PD;
def DEC32m_EVEX : DecOpM_MF<Xi32>, PL;
def DEC64m_EVEX : DecOpM_MF<Xi64>, PL;
}
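The _ND and _NF flavors defined above correspond, as far as I can tell, to the APX-style encodings this series targets: _ND writes the result to a separate (new data) destination instead of overwriting the source operand, and _NF additionally suppresses the EFLAGS update. A rough, purely illustrative scalar model of the three flavors (only ZF/SF shown; all names here are invented):

```cpp
#include <cstdint>
#include <iostream>

struct Flags { bool ZF = false, SF = false; };

// Legacy INC: read-modify-write on one register, EFLAGS updated.
void inc_legacy(uint32_t &Reg, Flags &F) {
  Reg += 1;
  F.ZF = (Reg == 0);
  F.SF = (static_cast<int32_t>(Reg) < 0);
}

// "_ND" flavor: result lands in a distinct destination, flags still updated.
void inc_nd(uint32_t &Dst, uint32_t Src, Flags &F) {
  Dst = Src + 1;
  F.ZF = (Dst == 0);
  F.SF = (static_cast<int32_t>(Dst) < 0);
}

// "_NF" flavor: same arithmetic, but no flag update at all.
void inc_nf_nd(uint32_t &Dst, uint32_t Src) { Dst = Src + 1; }

int main() {
  Flags F;
  uint32_t A = 0xFFFFFFFFu, B = 0, C = 0;
  inc_legacy(A, F);          // A wraps to 0, ZF set
  inc_nd(B, /*Src=*/41, F);  // B becomes 42, source value untouched, flags updated
  inc_nf_nd(C, 7);           // C becomes 8, flags left alone
  std::cout << A << ' ' << B << ' ' << C << ' ' << F.ZF << '\n';
}
```

The snippet is only meant to convey why separate _ND, _NF, and combined _NF_ND opcode flavors exist; the actual selection and encoding stay in the TableGen definitions above.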

//===----------------------------------------------------------------------===//
@@ -1119,14 +1206,34 @@ defm MULX64 : MulX<Xi64, WriteMULX64>, REX_W;
// We don't have patterns for these as there is no advantage over ADC for
// most code.
let Form = MRMSrcReg in {
def ADCX32rr : BinOpRRF_RF<0xF6, "adcx", Xi32, null_frag>, T8, PD;
def ADCX64rr : BinOpRRF_RF<0xF6, "adcx", Xi64, null_frag>, T8, PD;
def ADOX32rr : BinOpRRF_RF<0xF6, "adox", Xi32, null_frag>, T8, XS;
def ADOX64rr : BinOpRRF_RF<0xF6, "adox", Xi64, null_frag>, T8, XS;
def ADCX32rr : BinOpRRF_RF<0xF6, "adcx", Xi32>, T8, PD;
def ADCX64rr : BinOpRRF_RF<0xF6, "adcx", Xi64>, T8, PD;
def ADOX32rr : BinOpRRF_RF<0xF6, "adox", Xi32>, T8, XS;
def ADOX64rr : BinOpRRF_RF<0xF6, "adox", Xi64>, T8, XS;
let Predicates =[In64BitMode] in {
def ADCX32rr_EVEX : BinOpRRF_RF<0x66, "adcx", Xi32>, EVEX, T_MAP4, PD;
def ADCX64rr_EVEX : BinOpRRF_RF<0x66, "adcx", Xi64>, EVEX, T_MAP4, PD;
def ADOX32rr_EVEX : BinOpRRF_RF<0x66, "adox", Xi32>, EVEX, T_MAP4, XS;
def ADOX64rr_EVEX : BinOpRRF_RF<0x66, "adox", Xi64>, EVEX, T_MAP4, XS;
def ADCX32rr_ND : BinOpRRF_RF<0x66, "adcx", Xi32, null_frag, 1>, PD;
def ADCX64rr_ND : BinOpRRF_RF<0x66, "adcx", Xi64, null_frag, 1>, PD;
def ADOX32rr_ND : BinOpRRF_RF<0x66, "adox", Xi32, null_frag, 1>, XS;
def ADOX64rr_ND : BinOpRRF_RF<0x66, "adox", Xi64, null_frag, 1>, XS;
}
}
let Form = MRMSrcMem in {
def ADCX32rm : BinOpRMF_RF<0xF6, "adcx", Xi32, null_frag>, T8, PD;
def ADCX64rm : BinOpRMF_RF<0xF6, "adcx", Xi64, null_frag>, T8, PD;
def ADOX32rm : BinOpRMF_RF<0xF6, "adox", Xi32, null_frag>, T8, XS;
def ADOX64rm : BinOpRMF_RF<0xF6, "adox", Xi64, null_frag>, T8, XS;
def ADCX32rm : BinOpRMF_RF<0xF6, "adcx", Xi32>, T8, PD;
def ADCX64rm : BinOpRMF_RF<0xF6, "adcx", Xi64>, T8, PD;
def ADOX32rm : BinOpRMF_RF<0xF6, "adox", Xi32>, T8, XS;
def ADOX64rm : BinOpRMF_RF<0xF6, "adox", Xi64>, T8, XS;
let Predicates =[In64BitMode] in {
def ADCX32rm_EVEX : BinOpRMF_RF<0x66, "adcx", Xi32>, EVEX, T_MAP4, PD;
def ADCX64rm_EVEX : BinOpRMF_RF<0x66, "adcx", Xi64>, EVEX, T_MAP4, PD;
def ADOX32rm_EVEX : BinOpRMF_RF<0x66, "adox", Xi32>, EVEX, T_MAP4, XS;
def ADOX64rm_EVEX : BinOpRMF_RF<0x66, "adox", Xi64>, EVEX, T_MAP4, XS;
def ADCX32rm_ND : BinOpRMF_RF<0x66, "adcx", Xi32, null_frag, 1>, PD;
def ADCX64rm_ND : BinOpRMF_RF<0x66, "adcx", Xi64, null_frag, 1>, PD;
def ADOX32rm_ND : BinOpRMF_RF<0x66, "adox", Xi32, null_frag, 1>, XS;
def ADOX64rm_ND : BinOpRMF_RF<0x66, "adox", Xi64, null_frag, 1>, XS;
}
}
12 changes: 6 additions & 6 deletions llvm/lib/Target/X86/X86InstrMisc.td
@@ -1353,22 +1353,22 @@ multiclass bmi_pdep_pext<string mnemonic, RegisterClass RC,
def rr#Suffix : I<0xF5, MRMSrcReg, (outs RC:$dst), (ins RC:$src1, RC:$src2),
!strconcat(mnemonic, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set RC:$dst, (OpNode RC:$src1, RC:$src2))]>,
VEX, VVVV, Sched<[WriteALU]>;
NoCD8, VVVV, Sched<[WriteALU]>;
def rm#Suffix : I<0xF5, MRMSrcMem, (outs RC:$dst), (ins RC:$src1, x86memop:$src2),
!strconcat(mnemonic, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set RC:$dst, (OpNode RC:$src1, (ld_frag addr:$src2)))]>,
VEX, VVVV, Sched<[WriteALU.Folded, WriteALU.ReadAfterFold]>;
NoCD8, VVVV, Sched<[WriteALU.Folded, WriteALU.ReadAfterFold]>;
}

let Predicates = [HasBMI2, NoEGPR] in {
defm PDEP32 : bmi_pdep_pext<"pdep{l}", GR32, i32mem,
X86pdep, loadi32>, T8, XD;
X86pdep, loadi32>, T8, XD, VEX;
defm PDEP64 : bmi_pdep_pext<"pdep{q}", GR64, i64mem,
X86pdep, loadi64>, T8, XD, REX_W;
X86pdep, loadi64>, T8, XD, REX_W, VEX;
defm PEXT32 : bmi_pdep_pext<"pext{l}", GR32, i32mem,
X86pext, loadi32>, T8, XS;
X86pext, loadi32>, T8, XS, VEX;
defm PEXT64 : bmi_pdep_pext<"pext{q}", GR64, i64mem,
X86pext, loadi64>, T8, XS, REX_W;
X86pext, loadi64>, T8, XS, REX_W, VEX;
}

let Predicates = [HasBMI2, HasEGPR] in {
2 changes: 1 addition & 1 deletion llvm/lib/Target/X86/X86InstrShiftRotate.td
@@ -868,7 +868,7 @@ let Predicates = [HasBMI2, NoEGPR] in {
defm SHLX64 : bmi_shift<"shlx{q}", GR64, i64mem>, T8, PD, REX_W;
}

let Predicates = [HasBMI2, HasEGPR] in {
let Predicates = [HasBMI2, HasEGPR, In64BitMode] in {
defm RORX32 : bmi_rotate<"rorx{l}", GR32, i32mem, "_EVEX">, EVEX;
defm RORX64 : bmi_rotate<"rorx{q}", GR64, i64mem, "_EVEX">, REX_W, EVEX;
defm SARX32 : bmi_shift<"sarx{l}", GR32, i32mem, "_EVEX">, T8, XS, EVEX;
4 changes: 2 additions & 2 deletions llvm/lib/Target/X86/X86InstrUtils.td
@@ -1005,7 +1005,7 @@ class BinOpRR_RF_Rev<bits<8> o, string m, X86TypeInfo t, bit ndd = 0>
}
// BinOpRRF_RF - Instructions that read "reg, reg", write "reg" and read/write
// EFLAGS.
class BinOpRRF_RF<bits<8> o, string m, X86TypeInfo t, SDPatternOperator node, bit ndd = 0>
class BinOpRRF_RF<bits<8> o, string m, X86TypeInfo t, SDPatternOperator node = null_frag, bit ndd = 0>
: BinOpRR<o, m, !if(!eq(ndd, 0), binop_args, binop_ndd_args), t, (outs t.RegClass:$dst),
[(set t.RegClass:$dst, EFLAGS,
(node t.RegClass:$src1, t.RegClass:$src2,
@@ -1041,7 +1041,7 @@ class BinOpRM_RF<bits<8> o, string m, X86TypeInfo t, SDPatternOperator node, bit
(t.LoadNode addr:$src2)))]>, DefEFLAGS, NDD<ndd>;
// BinOpRMF_RF - Instructions that read "reg, [mem]", write "reg" and read/write
// EFLAGS.
class BinOpRMF_RF<bits<8> o, string m, X86TypeInfo t, SDPatternOperator node, bit ndd = 0>
class BinOpRMF_RF<bits<8> o, string m, X86TypeInfo t, SDPatternOperator node = null_frag, bit ndd = 0>
: BinOpRM<o, m, !if(!eq(ndd, 0), binop_args, binop_ndd_args), t, (outs t.RegClass:$dst),
[(set t.RegClass:$dst, EFLAGS,
(node t.RegClass:$src1, (t.LoadNode addr:$src2), EFLAGS))]>,
7 changes: 0 additions & 7 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1666,13 +1666,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *Ashr = foldAddToAshr(I))
return Ashr;

// min(A, B) + max(A, B) => A + B.
if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)),
m_c_SMin(m_Deferred(A), m_Deferred(B))),
m_c_Add(m_UMax(m_Value(A), m_Value(B)),
m_c_UMin(m_Deferred(A), m_Deferred(B))))))
return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);

// (~X) + (~Y) --> -2 - (X + Y)
{
// To ensure we can save instructions we need to ensure that we consume both
46 changes: 5 additions & 41 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1536,11 +1536,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}

if (II->isCommutative()) {
if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
return I;

if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
return I;
if (auto Pair = matchSymmetricPair(II->getOperand(0), II->getOperand(1))) {
replaceOperand(*II, 0, Pair->first);
replaceOperand(*II, 1, Pair->second);
return II;
}

if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
return NewCall;
@@ -4246,39 +4246,3 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
Call.setCalledFunction(FTy, NestF);
return &Call;
}

// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
Instruction *
InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
assert(II.isCommutative());

Value *A, *B, *C;
if (match(II.getOperand(0), m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
match(II.getOperand(1),
m_Select(m_Specific(A), m_Specific(C), m_Specific(B)))) {
replaceOperand(II, 0, B);
replaceOperand(II, 1, C);
return &II;
}

return nullptr;
}

Instruction *
InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
assert(II.isCommutative() && "Instruction should be commutative");

PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));

if (!LHS || !RHS)
return nullptr;

if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
replaceOperand(II, 0, P->first);
replaceOperand(II, 1, P->second);
return &II;
}

return nullptr;
}
23 changes: 8 additions & 15 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -276,17 +276,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
bool transformConstExprCastCall(CallBase &Call);
Instruction *transformCallThroughTrampoline(CallBase &Call,
IntrinsicInst &Tramp);
Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);

// Match a pair of Phi Nodes like
// phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
// Return the matched two operands.
std::optional<std::pair<Value *, Value *>>
matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);

// Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
// while op is a commutative intrinsic call.
Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);
// Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
// Otherwise, return std::nullopt
// Currently it matches:
// - LHS = (select c, a, b), RHS = (select c, b, a)
// - LHS = (phi [a, BB0], [b, BB1]), RHS = (phi [b, BB0], [a, BB1])
// - LHS = min(a, b), RHS = max(a, b)
std::optional<std::pair<Value *, Value *>> matchSymmetricPair(Value *LHS,
Value *RHS);

Value *simplifyMaskedLoad(IntrinsicInst &II);
Instruction *simplifyMaskedStore(IntrinsicInst &II);
@@ -502,11 +500,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// X % (C0 * C1)
Value *SimplifyAddWithRemainder(BinaryOperator &I);

// Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
// while Binop is commutative.
Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
Value *RHS);

// Binary Op helper for select operations where the expression can be
// efficiently reorganized.
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
7 changes: 0 additions & 7 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -487,13 +487,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
return Res;

// min(X, Y) * max(X, Y) => X * Y.
if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
m_c_SMin(m_Deferred(X), m_Deferred(Y))),
m_c_Mul(m_UMax(m_Value(X), m_Value(Y)),
m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);

// (mul Op0 Op1):
// if Log2(Op0) folds away ->
// (shl Op1, Log2(Op0))
77 changes: 44 additions & 33 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -411,6 +411,14 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
getComplexity(I.getOperand(1)))
Changed = !I.swapOperands();

if (I.isCommutative()) {
if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) {
replaceOperand(I, 0, Pair->first);
replaceOperand(I, 1, Pair->second);
Changed = true;
}
}

BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));

@@ -1096,8 +1104,8 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
}

std::optional<std::pair<Value *, Value *>>
InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
static std::optional<std::pair<Value *, Value *>>
matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
if (LHS->getParent() != RHS->getParent())
return std::nullopt;

@@ -1123,25 +1131,41 @@ InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
return std::optional(std::pair(L0, R0));
}

Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
Value *Op0,
Value *Op1) {
assert(I.isCommutative() && "Instruction should be commutative");

PHINode *LHS = dyn_cast<PHINode>(Op0);
PHINode *RHS = dyn_cast<PHINode>(Op1);

if (!LHS || !RHS)
return nullptr;

if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second);
if (auto *BO = dyn_cast<BinaryOperator>(BI))
BO->copyIRFlags(&I);
return BI;
std::optional<std::pair<Value *, Value *>>
InstCombinerImpl::matchSymmetricPair(Value *LHS, Value *RHS) {
Instruction *LHSInst = dyn_cast<Instruction>(LHS);
Instruction *RHSInst = dyn_cast<Instruction>(RHS);
if (!LHSInst || !RHSInst || LHSInst->getOpcode() != RHSInst->getOpcode())
return std::nullopt;
switch (LHSInst->getOpcode()) {
case Instruction::PHI:
return matchSymmetricPhiNodesPair(cast<PHINode>(LHS), cast<PHINode>(RHS));
case Instruction::Select: {
Value *Cond = LHSInst->getOperand(0);
Value *TrueVal = LHSInst->getOperand(1);
Value *FalseVal = LHSInst->getOperand(2);
if (Cond == RHSInst->getOperand(0) && TrueVal == RHSInst->getOperand(2) &&
FalseVal == RHSInst->getOperand(1))
return std::pair(TrueVal, FalseVal);
return std::nullopt;
}
case Instruction::Call: {
// Match min(a, b) and max(a, b)
MinMaxIntrinsic *LHSMinMax = dyn_cast<MinMaxIntrinsic>(LHSInst);
MinMaxIntrinsic *RHSMinMax = dyn_cast<MinMaxIntrinsic>(RHSInst);
if (LHSMinMax && RHSMinMax &&
LHSMinMax->getPredicate() ==
ICmpInst::getSwappedPredicate(RHSMinMax->getPredicate()) &&
((LHSMinMax->getLHS() == RHSMinMax->getLHS() &&
LHSMinMax->getRHS() == RHSMinMax->getRHS()) ||
(LHSMinMax->getLHS() == RHSMinMax->getRHS() &&
LHSMinMax->getRHS() == RHSMinMax->getLHS())))
return std::pair(LHSMinMax->getLHS(), LHSMinMax->getRHS());
return std::nullopt;
}
default:
return std::nullopt;
}

return nullptr;
}
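matchSymmetricPair above generalizes several earlier one-off folds: whenever the two operands of a commutative operation are known to be the same pair of values in either order (swapped selects, swapped phis, or a min/max pair), the operation can be rewritten directly on that pair. The underlying identities are easy to sanity-check with scalars; the snippet below is only a plausibility check with made-up values, not LLVM code.

```cpp
#include <algorithm>
#include <cassert>

int main() {
  for (int a = -3; a <= 3; ++a) {
    for (int b = -3; b <= 3; ++b) {
      // min(a, b) + max(a, b) == a + b, and likewise for mul.
      assert(std::min(a, b) + std::max(a, b) == a + b);
      assert(std::min(a, b) * std::max(a, b) == a * b);
      for (bool c : {false, true}) {
        // select(c, a, b) op select(c, b, a) == a op b for commutative op.
        int lhs = c ? a : b, rhs = c ? b : a;
        assert(lhs + rhs == a + b);
      }
    }
  }
  return 0;
}
```

InstCombine performs the rewrite by replacing the operation's operands with the matched pair, as the call sites in the hunks above show.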

Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
@@ -1187,14 +1211,6 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
};

if (LHSIsSelect && RHSIsSelect && A == D) {
// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
if (I.isCommutative() && B == F && C == E) {
Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E);
if (auto *BO = dyn_cast<BinaryOperator>(BI))
BO->copyIRFlags(&I);
return BI;
}

// (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
Cond = A;
True = simplifyBinOp(Opcode, B, E, FMF, Q);
Expand Down Expand Up @@ -1577,11 +1593,6 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
BO.getParent() != Phi1->getParent())
return nullptr;

if (BO.isCommutative()) {
if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1))
return replaceInstUsesWith(BO, V);
}

// Fold if there is at least one specific constant value in phi0 or phi1's
// incoming values that comes from the same block and this specific constant
// value can be used to do optimization for specific binary operator.