Skip to content

Commit

Permalink
[PowerPC] Legalize v256i1 and v512i1 and implement load and store of …
Browse files Browse the repository at this point in the history
…these types

This patch legalizes the v256i1 and v512i1 types that will be used for MMA.

It implements loads and stores of these types.
v256i1 is a pair of VSX registers, so for this type, we load/store the two
underlying registers. v512i1 is used for MMA accumulators. So in addition to
loading and storing the 4 associated VSX registers, we generate instructions to
prime (copy the VSX registers to the accumulator) after loading and unprime
(copy the accumulator back to the VSX registers) before storing.

This patch also adds the UACC register class that is necessary to implement the
loads and stores. This class represents accumulator in their unprimed form and
allow the distinction between primed and unprimed accumulators to avoid invalid
copies of the VSX registers associated with primed accumulators.

Differential Revision: https://reviews.llvm.org/D84968
  • Loading branch information
Baptiste Saleil committed Sep 28, 2020
1 parent 33125cf commit 0156914
Show file tree
Hide file tree
Showing 9 changed files with 532 additions and 5 deletions.
11 changes: 8 additions & 3 deletions clang/lib/Basic/Targets/PPC.h
Expand Up @@ -404,19 +404,20 @@ class LLVM_LIBRARY_VISIBILITY PPC64TargetInfo : public PPCTargetInfo {
LongWidth = LongAlign = PointerWidth = PointerAlign = 64;
IntMaxType = SignedLong;
Int64Type = SignedLong;
std::string DataLayout = "";

if (Triple.isOSAIX()) {
// TODO: Set appropriate ABI for AIX platform.
resetDataLayout("E-m:a-i64:64-n32:64");
DataLayout = "E-m:a-i64:64-n32:64";
SuitableAlign = 64;
LongDoubleWidth = 64;
LongDoubleAlign = DoubleAlign = 32;
LongDoubleFormat = &llvm::APFloat::IEEEdouble();
} else if ((Triple.getArch() == llvm::Triple::ppc64le)) {
resetDataLayout("e-m:e-i64:64-n32:64");
DataLayout = "e-m:e-i64:64-n32:64";
ABI = "elfv2";
} else {
resetDataLayout("E-m:e-i64:64-n32:64");
DataLayout = "E-m:e-i64:64-n32:64";
ABI = "elfv1";
}

Expand All @@ -425,6 +426,10 @@ class LLVM_LIBRARY_VISIBILITY PPC64TargetInfo : public PPCTargetInfo {
LongDoubleFormat = &llvm::APFloat::IEEEdouble();
}

if (Triple.isOSAIX() || Triple.isOSLinux())
DataLayout += "-v256:256:256-v512:512:512";
resetDataLayout(DataLayout);

// PPC64 supports atomics up to 8 bytes.
MaxAtomicPromoteWidth = MaxAtomicInlineWidth = 64;
}
Expand Down
20 changes: 18 additions & 2 deletions clang/test/CodeGen/target-data.c
Expand Up @@ -136,11 +136,27 @@

// RUN: %clang_cc1 -triple powerpc64-linux -o - -emit-llvm %s | \
// RUN: FileCheck %s -check-prefix=PPC64-LINUX
// PPC64-LINUX: target datalayout = "E-m:e-i64:64-n32:64"
// PPC64-LINUX: target datalayout = "E-m:e-i64:64-n32:64-v256:256:256-v512:512:512"

// RUN: %clang_cc1 -triple powerpc64-linux -o - -emit-llvm -target-cpu future %s | \
// RUN: FileCheck %s -check-prefix=PPC64-FUTURE
// PPC64-FUTURE: target datalayout = "E-m:e-i64:64-n32:64-v256:256:256-v512:512:512"

// RUN: %clang_cc1 -triple powerpc64-linux -o - -emit-llvm -target-cpu pwr10 %s | \
// RUN: FileCheck %s -check-prefix=PPC64-P10
// PPC64-P10: target datalayout = "E-m:e-i64:64-n32:64-v256:256:256-v512:512:512"

// RUN: %clang_cc1 -triple powerpc64le-linux -o - -emit-llvm %s | \
// RUN: FileCheck %s -check-prefix=PPC64LE-LINUX
// PPC64LE-LINUX: target datalayout = "e-m:e-i64:64-n32:64"
// PPC64LE-LINUX: target datalayout = "e-m:e-i64:64-n32:64-v256:256:256-v512:512:512"

// RUN: %clang_cc1 -triple powerpc64le-linux -o - -emit-llvm -target-cpu future %s | \
// RUN: FileCheck %s -check-prefix=PPC64LE-FUTURE
// PPC64LE-FUTURE: target datalayout = "e-m:e-i64:64-n32:64-v256:256:256-v512:512:512"

// RUN: %clang_cc1 -triple powerpc64le-linux -o - -emit-llvm -target-cpu pwr10 %s | \
// RUN: FileCheck %s -check-prefix=PPC64LE-P10
// PPC64LE-P10: target datalayout = "e-m:e-i64:64-n32:64-v256:256:256-v512:512:512"

// RUN: %clang_cc1 -triple nvptx-unknown -o - -emit-llvm %s | \
// RUN: FileCheck %s -check-prefix=NVPTX
Expand Down
109 changes: 109 additions & 0 deletions llvm/lib/Target/PowerPC/PPCISelLowering.cpp
Expand Up @@ -1181,6 +1181,18 @@ PPCTargetLowering::PPCTargetLowering(const PPCTargetMachine &TM,
}
}

if (Subtarget.pairedVectorMemops()) {
addRegisterClass(MVT::v256i1, &PPC::VSRpRCRegClass);
setOperationAction(ISD::LOAD, MVT::v256i1, Custom);
setOperationAction(ISD::STORE, MVT::v256i1, Custom);
}
if (Subtarget.hasMMA()) {
addRegisterClass(MVT::v512i1, &PPC::UACCRCRegClass);
setOperationAction(ISD::LOAD, MVT::v512i1, Custom);
setOperationAction(ISD::STORE, MVT::v512i1, Custom);
setOperationAction(ISD::BUILD_VECTOR, MVT::v512i1, Custom);
}

if (Subtarget.has64BitSupport())
setOperationAction(ISD::PREFETCH, MVT::Other, Legal);

Expand Down Expand Up @@ -1523,6 +1535,10 @@ const char *PPCTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "PPCISD::TLS_DYNAMIC_MAT_PCREL_ADDR";
case PPCISD::TLS_LOCAL_EXEC_MAT_ADDR:
return "PPCISD::TLS_LOCAL_EXEC_MAT_ADDR";
case PPCISD::ACC_BUILD: return "PPCISD::ACC_BUILD";
case PPCISD::PAIR_BUILD: return "PPCISD::PAIR_BUILD";
case PPCISD::EXTRACT_VSX_REG: return "PPCISD::EXTRACT_VSX_REG";
case PPCISD::XXMFACC: return "PPCISD::XXMFACC";
case PPCISD::LD_SPLAT: return "PPCISD::LD_SPLAT";
case PPCISD::FNMSUB: return "PPCISD::FNMSUB";
case PPCISD::STRICT_FADDRTZ:
Expand Down Expand Up @@ -7824,6 +7840,8 @@ SDValue PPCTargetLowering::lowerEH_SJLJ_LONGJMP(SDValue Op,
}

SDValue PPCTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
if (Op.getValueType().isVector())
return LowerVectorLoad(Op, DAG);

assert(Op.getValueType() == MVT::i1 &&
"Custom lowering only for i1 loads");
Expand All @@ -7847,6 +7865,9 @@ SDValue PPCTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
}

SDValue PPCTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
if (Op.getOperand(1).getValueType().isVector())
return LowerVectorStore(Op, DAG);

assert(Op.getOperand(1).getValueType() == MVT::i1 &&
"Custom lowering only for i1 stores");

Expand Down Expand Up @@ -10581,6 +10602,94 @@ SDValue PPCTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
return Op;
}

SDValue PPCTargetLowering::LowerVectorLoad(SDValue Op,
SelectionDAG &DAG) const {
SDLoc dl(Op);
LoadSDNode *LN = cast<LoadSDNode>(Op.getNode());
SDValue LoadChain = LN->getChain();
SDValue BasePtr = LN->getBasePtr();
EVT VT = Op.getValueType();

if (VT != MVT::v256i1 && VT != MVT::v512i1)
return Op;

// Type v256i1 is used for pairs and v512i1 is used for accumulators.
// Here we create 2 or 4 v16i8 loads to load the pair or accumulator value in
// 2 or 4 vsx registers.
assert((VT != MVT::v512i1 || Subtarget.hasMMA()) &&
"Type unsupported without MMA");
assert((VT != MVT::v256i1 || Subtarget.pairedVectorMemops()) &&
"Type unsupported without paired vector support");
Align Alignment = LN->getAlign();
SmallVector<SDValue, 4> Loads;
SmallVector<SDValue, 4> LoadChains;
unsigned NumVecs = VT.getSizeInBits() / 128;
for (unsigned Idx = 0; Idx < NumVecs; ++Idx) {
SDValue Load =
DAG.getLoad(MVT::v16i8, dl, LoadChain, BasePtr,
LN->getPointerInfo().getWithOffset(Idx * 16),
commonAlignment(Alignment, Idx * 16),
LN->getMemOperand()->getFlags(), LN->getAAInfo());
BasePtr = DAG.getNode(ISD::ADD, dl, BasePtr.getValueType(), BasePtr,
DAG.getConstant(16, dl, BasePtr.getValueType()));
Loads.push_back(Load);
LoadChains.push_back(Load.getValue(1));
}
if (Subtarget.isLittleEndian()) {
std::reverse(Loads.begin(), Loads.end());
std::reverse(LoadChains.begin(), LoadChains.end());
}
SDValue TF = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, LoadChains);
SDValue Value =
DAG.getNode(VT == MVT::v512i1 ? PPCISD::ACC_BUILD : PPCISD::PAIR_BUILD,
dl, VT, Loads);
SDValue RetOps[] = {Value, TF};
return DAG.getMergeValues(RetOps, dl);
}

SDValue PPCTargetLowering::LowerVectorStore(SDValue Op,
SelectionDAG &DAG) const {
SDLoc dl(Op);
StoreSDNode *SN = cast<StoreSDNode>(Op.getNode());
SDValue StoreChain = SN->getChain();
SDValue BasePtr = SN->getBasePtr();
SDValue Value = SN->getValue();
EVT StoreVT = Value.getValueType();

if (StoreVT != MVT::v256i1 && StoreVT != MVT::v512i1)
return Op;

// Type v256i1 is used for pairs and v512i1 is used for accumulators.
// Here we create 2 or 4 v16i8 stores to store the pair or accumulator
// underlying registers individually.
assert((StoreVT != MVT::v512i1 || Subtarget.hasMMA()) &&
"Type unsupported without MMA");
assert((StoreVT != MVT::v256i1 || Subtarget.pairedVectorMemops()) &&
"Type unsupported without paired vector support");
Align Alignment = SN->getAlign();
SmallVector<SDValue, 4> Stores;
unsigned NumVecs = 2;
if (StoreVT == MVT::v512i1) {
Value = DAG.getNode(PPCISD::XXMFACC, dl, MVT::v512i1, Value);
NumVecs = 4;
}
for (unsigned Idx = 0; Idx < NumVecs; ++Idx) {
unsigned VecNum = Subtarget.isLittleEndian() ? NumVecs - 1 - Idx : Idx;
SDValue Elt = DAG.getNode(PPCISD::EXTRACT_VSX_REG, dl, MVT::v16i8, Value,
DAG.getConstant(VecNum, dl, MVT::i64));
SDValue Store =
DAG.getStore(StoreChain, dl, Elt, BasePtr,
SN->getPointerInfo().getWithOffset(Idx * 16),
commonAlignment(Alignment, Idx * 16),
SN->getMemOperand()->getFlags(), SN->getAAInfo());
BasePtr = DAG.getNode(ISD::ADD, dl, BasePtr.getValueType(), BasePtr,
DAG.getConstant(16, dl, BasePtr.getValueType()));
Stores.push_back(Store);
}
SDValue TF = DAG.getTokenFactor(dl, Stores);
return TF;
}

SDValue PPCTargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
SDLoc dl(Op);
if (Op.getValueType() == MVT::v4i32) {
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/Target/PowerPC/PPCISelLowering.h
Expand Up @@ -450,6 +450,21 @@ namespace llvm {
/// available. This is used with ADD_TLS to produce an add like PADDI.
TLS_LOCAL_EXEC_MAT_ADDR,

/// ACC_BUILD = Build an accumulator register from 4 VSX registers.
ACC_BUILD,

/// PAIR_BUILD = Build a vector pair register from 2 VSX registers.
PAIR_BUILD,

/// EXTRACT_VSX_REG = Extract one of the underlying vsx registers of
/// an accumulator or pair register. This node is needed because
/// EXTRACT_SUBVECTOR expects the input and output vectors to have the same
/// element type.
EXTRACT_VSX_REG,

/// XXMFACC = This corresponds to the xxmfacc instruction.
XXMFACC,

// Constrained conversion from floating point to int
STRICT_FCTIDZ = ISD::FIRST_TARGET_STRICTFP_OPCODE,
STRICT_FCTIWZ,
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/Target/PowerPC/PPCInstrInfo.cpp
Expand Up @@ -2465,6 +2465,31 @@ bool PPCInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
auto DL = MI.getDebugLoc();

switch (MI.getOpcode()) {
case PPC::BUILD_UACC: {
MCRegister ACC = MI.getOperand(0).getReg();
MCRegister UACC = MI.getOperand(1).getReg();
if (ACC - PPC::ACC0 != UACC - PPC::UACC0) {
MCRegister SrcVSR = PPC::VSL0 + (UACC - PPC::UACC0) * 4;
MCRegister DstVSR = PPC::VSL0 + (ACC - PPC::ACC0) * 4;
// FIXME: This can easily be improved to look up to the top of the MBB
// to see if the inputs are XXLOR's. If they are and SrcReg is killed,
// we can just re-target any such XXLOR's to DstVSR + offset.
for (int VecNo = 0; VecNo < 4; VecNo++)
BuildMI(MBB, MI, DL, get(PPC::XXLOR), DstVSR + VecNo)
.addReg(SrcVSR + VecNo)
.addReg(SrcVSR + VecNo);
}
// BUILD_UACC is expanded to 4 copies of the underlying vsx regisers.
// So after building the 4 copies, we can replace the BUILD_UACC instruction
// with a NOP.
LLVM_FALLTHROUGH;
}
case PPC::KILL_PAIR: {
MI.setDesc(get(PPC::UNENCODED_NOP));
MI.RemoveOperand(1);
MI.RemoveOperand(0);
return true;
}
case TargetOpcode::LOAD_STACK_GUARD: {
assert(Subtarget.isTargetLinux() &&
"Only Linux target is expected to contain LOAD_STACK_GUARD");
Expand Down
87 changes: 87 additions & 0 deletions llvm/lib/Target/PowerPC/PPCInstrPrefix.td
Expand Up @@ -5,12 +5,35 @@
def SDT_PPCSplat32 : SDTypeProfile<1, 3, [ SDTCisVT<0, v2i64>,
SDTCisVec<1>, SDTCisInt<2>, SDTCisInt<3>
]>;
def SDT_PPCAccBuild : SDTypeProfile<1, 4, [
SDTCisVT<0, v512i1>, SDTCisVT<1, v4i32>, SDTCisVT<2, v4i32>,
SDTCisVT<3, v4i32>, SDTCisVT<4, v4i32>
]>;
def SDT_PPCPairBuild : SDTypeProfile<1, 2, [
SDTCisVT<0, v256i1>, SDTCisVT<1, v4i32>, SDTCisVT<2, v4i32>
]>;
def SDT_PPCAccExtractVsx : SDTypeProfile<1, 2, [
SDTCisVT<0, v4i32>, SDTCisVT<1, v512i1>, SDTCisInt<2>
]>;
def SDT_PPCPairExtractVsx : SDTypeProfile<1, 2, [
SDTCisVT<0, v4i32>, SDTCisVT<1, v256i1>, SDTCisInt<2>
]>;
def SDT_PPCxxmfacc : SDTypeProfile<1, 1, [
SDTCisVT<0, v512i1>, SDTCisVT<1, v512i1>
]>;

//===----------------------------------------------------------------------===//
// ISA 3.1 specific PPCISD nodes.
//

def PPCxxsplti32dx : SDNode<"PPCISD::XXSPLTI32DX", SDT_PPCSplat32, []>;
def PPCAccBuild : SDNode<"PPCISD::ACC_BUILD", SDT_PPCAccBuild, []>;
def PPCPairBuild : SDNode<"PPCISD::PAIR_BUILD", SDT_PPCPairBuild, []>;
def PPCAccExtractVsx : SDNode<"PPCISD::EXTRACT_VSX_REG", SDT_PPCAccExtractVsx,
[]>;
def PPCPairExtractVsx : SDNode<"PPCISD::EXTRACT_VSX_REG", SDT_PPCPairExtractVsx,
[]>;
def PPCxxmfacc : SDNode<"PPCISD::XXMFACC", SDT_PPCxxmfacc, []>;

//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -525,6 +548,16 @@ def vsrprc : RegisterOperand<VSRpRC> {
let ParserMatchClass = PPCRegVSRpRCAsmOperand;
}

def PPCRegVSRpEvenRCAsmOperand : AsmOperandClass {
let Name = "RegVSRpEvenRC"; let PredicateMethod = "isVSRpEvenRegNumber";
}

def vsrpevenrc : RegisterOperand<VSRpRC> {
let ParserMatchClass = PPCRegVSRpEvenRCAsmOperand;
let EncoderMethod = "getVSRpEvenEncoding";
let DecoderMethod = "decodeVSRpEvenOperands";
}

class DQForm_XTp5_RA17_MEM<bits<6> opcode, bits<4> xo, dag OOL, dag IOL,
string asmstr, InstrItinClass itin, list<dag> pattern>
: I<opcode, OOL, IOL, asmstr, itin> {
Expand Down Expand Up @@ -594,6 +627,10 @@ def acc : RegisterOperand<ACCRC> {
let ParserMatchClass = PPCRegACCRCAsmOperand;
}

def uacc : RegisterOperand<UACCRC> {
let ParserMatchClass = PPCRegACCRCAsmOperand;
}

// [PO AS XO2 XO]
class XForm_AT3<bits<6> opcode, bits<5> xo2, bits<10> xo, dag OOL, dag IOL,
string asmstr, InstrItinClass itin, list<dag> pattern>
Expand Down Expand Up @@ -774,6 +811,11 @@ let Predicates = [MMA] in {
XForm_AT3<31, 1, 177, (outs acc:$AT), (ins acc:$ATi), "xxmtacc $AT",
IIC_VecGeneral, []>, RegConstraint<"$ATi = $AT">,
NoEncode<"$ATi">;
def KILL_PAIR : PPCPostRAExpPseudo<(outs vsrprc:$XTp), (ins vsrprc:$XSp),
"#KILL_PAIR", []>,
RegConstraint<"$XTp = $XSp">;
def BUILD_UACC : PPCPostRAExpPseudo<(outs acc:$AT), (ins uacc:$AS),
"#BUILD_UACC $AT, $AS", []>;
// We define XXSETACCZ as rematerializable to undo CSE of that intrinsic in
// the backend. We avoid CSE here because it generates a copy of the acc
// register and this copy is more expensive than calling the intrinsic again.
Expand All @@ -784,6 +826,51 @@ let Predicates = [MMA] in {
}
}

def Concats {
dag VecsToVecPair0 =
(v256i1 (INSERT_SUBREG
(INSERT_SUBREG (IMPLICIT_DEF), $vs0, sub_vsx1),
$vs1, sub_vsx0));
dag VecsToVecPair1 =
(v256i1 (INSERT_SUBREG
(INSERT_SUBREG (IMPLICIT_DEF), $vs2, sub_vsx1),
$vs3, sub_vsx0));
dag VecsToVecQuad =
(BUILD_UACC (INSERT_SUBREG
(INSERT_SUBREG (v512i1 (IMPLICIT_DEF)),
(KILL_PAIR VecsToVecPair0), sub_pair0),
(KILL_PAIR VecsToVecPair1), sub_pair1));
}

def Extracts {
dag Pair0 = (v256i1 (EXTRACT_SUBREG $v, sub_pair0));
dag Pair1 = (v256i1 (EXTRACT_SUBREG $v, sub_pair1));
dag Vec0 = (v4i32 (EXTRACT_SUBREG Pair0, sub_vsx0));
dag Vec1 = (v4i32 (EXTRACT_SUBREG Pair0, sub_vsx1));
dag Vec2 = (v4i32 (EXTRACT_SUBREG Pair1, sub_vsx0));
dag Vec3 = (v4i32 (EXTRACT_SUBREG Pair1, sub_vsx1));
}

let Predicates = [MMA] in {
def : Pat<(v512i1 (PPCAccBuild v4i32:$vs1, v4i32:$vs0, v4i32:$vs3, v4i32:$vs2)),
(XXMTACC Concats.VecsToVecQuad)>;
def : Pat<(v256i1 (PPCPairBuild v4i32:$vs1, v4i32:$vs0)),
Concats.VecsToVecPair0>;
def : Pat<(v512i1 (PPCxxmfacc v512i1:$AS)), (XXMFACC acc:$AS)>;
def : Pat<(v4i32 (PPCAccExtractVsx acc:$v, (i64 0))),
Extracts.Vec0>;
def : Pat<(v4i32 (PPCAccExtractVsx acc:$v, (i64 1))),
Extracts.Vec1>;
def : Pat<(v4i32 (PPCAccExtractVsx acc:$v, (i64 2))),
Extracts.Vec2>;
def : Pat<(v4i32 (PPCAccExtractVsx acc:$v, (i64 3))),
Extracts.Vec3>;
def : Pat<(v4i32 (PPCPairExtractVsx vsrpevenrc:$v, (i64 0))),
(v4i32 (EXTRACT_SUBREG $v, sub_vsx0))>;
def : Pat<(v4i32 (PPCPairExtractVsx vsrpevenrc:$v, (i64 1))),
(v4i32 (EXTRACT_SUBREG $v, sub_vsx1))>;
}

let mayLoad = 1, mayStore = 0, Predicates = [PairedVectorMemops] in {
def LXVP : DQForm_XTp5_RA17_MEM<6, 0, (outs vsrprc:$XTp),
(ins memrix16:$DQ_RA), "lxvp $XTp, $DQ_RA",
Expand Down

0 comments on commit 0156914

Please sign in to comment.