[AVX-512] Fix accidental uses of AH/BH/CH/DH after copies to/from mask registers

We've had several bugs (PR32256, PR32241) recently that resulted from uses of AH/BH/CH/DH either before or after a copy to/from a mask register.

This ultimately occurs because we create COPY_TO_REGCLASS with VK1 and GR8. Then in CopyToFromAsymmetricReg in X86InstrInfo we find a 32-bit super register for the GR8 to emit the KMOV with. But as these tests demonstrate, it's possible for the GR8 register to be a high register, and we end up doing an accidental extract or insert from bits 15:8.
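For reference, the GR8 path being removed from CopyToFromAsymmetricReg looked roughly like this; a condensed excerpt of the code deleted in the X86InstrInfo.cpp hunk further down, not new code:

  // Removed: widen the GR8 operand to its 32-bit super register and emit the
  // KMOV against that. Together with the VK1/GR8 COPY_TO_REGCLASS patterns,
  // this leaves room for a high register (AH/BH/CH/DH) to appear, turning the
  // copy into an unintended extract/insert at bits 15:8.
  if (X86::GR8RegClass.contains(DestReg)) {
    assert(!isHReg(DestReg) && "Cannot move between mask and h-reg");
    DestReg = getX86SubSuperRegister(DestReg, 32);
    return Subtarget.hasDQI() ? X86::KMOVBrk : X86::KMOVWrk;
  }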

I think the best way forward is to stop making copies directly between mask registers and GR8/GR16. Instead, I think we should restrict copies to only those between mask registers and GR32/GR64, and use EXTRACT_SUBREG/INSERT_SUBREG to handle the conversion from GR32 to GR16/GR8 or vice versa.

Unfortunately, this complicates fastisel a bit more, since it now has to create subreg extracts where it used to create GR8 copies. We can probably add a helper function to cut down the repetition.
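A minimal sketch of what such a helper could look like, assuming a hypothetical name and signature (it is not part of this commit), based on the sequence that repeats in the X86FastISel.cpp hunks below:

  // Hypothetical helper (not in this commit): copy a VK1 value into a fresh
  // GR32 virtual register, then return its low 8 bits via a sub_8bit extract.
  unsigned X86FastISel::fastEmitCopyVK1ToGR8(unsigned KReg) {
    unsigned Reg32 = createResultReg(&X86::GR32RegClass);
    BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
            TII.get(TargetOpcode::COPY), Reg32)
        .addReg(KReg);
    return fastEmitInst_extractsubreg(MVT::i8, Reg32, /*Kill=*/true,
                                      X86::sub_8bit);
  }

Each of the call sites changed below (X86FastEmitStore, X86SelectRet, X86SelectZExt, X86SelectBranch, the two select lowerings, and fastLowerCall) could then collapse to a single call.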

This does result in KMOVD being used for copies when BWI is available, because we don't know the original mask register size. It caused a lot of test deltas, since the checks have to be split between KMOVD and KMOVW based on BWI.

Differential Revision: https://reviews.llvm.org/D30968

llvm-svn: 298928
topperc committed Mar 28, 2017
1 parent 7333a9e commit 058f2f6
Showing 57 changed files with 2,874 additions and 2,374 deletions.
58 changes: 45 additions & 13 deletions llvm/lib/Target/X86/X86FastISel.cpp
@@ -367,6 +367,10 @@ bool X86FastISel::X86FastEmitLoad(EVT VT, X86AddressMode &AM,
switch (VT.getSimpleVT().SimpleTy) {
default: return false;
case MVT::i1:
// TODO: Support this properly.
if (Subtarget->hasAVX512())
return false;
LLVM_FALLTHROUGH;
case MVT::i8:
Opc = X86::MOV8rm;
RC = &X86::GR8RegClass;
@@ -540,11 +544,12 @@ bool X86FastISel::X86FastEmitStore(EVT VT, unsigned ValReg, bool ValIsKill,
// In case ValReg is a K register, COPY to a GPR
if (MRI.getRegClass(ValReg) == &X86::VK1RegClass) {
unsigned KValReg = ValReg;
ValReg = createResultReg(Subtarget->is64Bit() ? &X86::GR8RegClass
: &X86::GR8_ABCD_LRegClass);
ValReg = createResultReg(&X86::GR32RegClass);
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
TII.get(TargetOpcode::COPY), ValReg)
.addReg(KValReg);
ValReg = fastEmitInst_extractsubreg(MVT::i8, ValReg, /*Kill=*/true,
X86::sub_8bit);
}
// Mask out all but lowest bit.
unsigned AndResult = createResultReg(&X86::GR8RegClass);
@@ -1280,11 +1285,12 @@ bool X86FastISel::X86SelectRet(const Instruction *I) {
// In case SrcReg is a K register, COPY to a GPR
if (MRI.getRegClass(SrcReg) == &X86::VK1RegClass) {
unsigned KSrcReg = SrcReg;
SrcReg = createResultReg(Subtarget->is64Bit() ? &X86::GR8RegClass
: &X86::GR8_ABCD_LRegClass);
SrcReg = createResultReg(&X86::GR32RegClass);
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
TII.get(TargetOpcode::COPY), SrcReg)
.addReg(KSrcReg);
SrcReg = fastEmitInst_extractsubreg(MVT::i8, SrcReg, /*Kill=*/true,
X86::sub_8bit);
}
SrcReg = fastEmitZExtFromI1(MVT::i8, SrcReg, /*TODO: Kill=*/false);
SrcVT = MVT::i8;
@@ -1580,11 +1586,12 @@ bool X86FastISel::X86SelectZExt(const Instruction *I) {
// In case ResultReg is a K register, COPY to a GPR
if (MRI.getRegClass(ResultReg) == &X86::VK1RegClass) {
unsigned KResultReg = ResultReg;
ResultReg = createResultReg(Subtarget->is64Bit() ? &X86::GR8RegClass
: &X86::GR8_ABCD_LRegClass);
ResultReg = createResultReg(&X86::GR32RegClass);
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
TII.get(TargetOpcode::COPY), ResultReg)
.addReg(KResultReg);
ResultReg = fastEmitInst_extractsubreg(MVT::i8, ResultReg, /*Kill=*/true,
X86::sub_8bit);
}

// Set the high bits to zero.
@@ -1768,11 +1775,12 @@ bool X86FastISel::X86SelectBranch(const Instruction *I) {
// In case OpReg is a K register, COPY to a GPR
if (MRI.getRegClass(OpReg) == &X86::VK1RegClass) {
unsigned KOpReg = OpReg;
OpReg = createResultReg(Subtarget->is64Bit() ? &X86::GR8RegClass
: &X86::GR8_ABCD_LRegClass);
OpReg = createResultReg(&X86::GR32RegClass);
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
TII.get(TargetOpcode::COPY), OpReg)
.addReg(KOpReg);
OpReg = fastEmitInst_extractsubreg(MVT::i8, OpReg, /*Kill=*/true,
X86::sub_8bit);
}
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc, TII.get(X86::TEST8ri))
.addReg(OpReg)
@@ -2113,11 +2121,12 @@ bool X86FastISel::X86FastEmitCMoveSelect(MVT RetVT, const Instruction *I) {
// In case OpReg is a K register, COPY to a GPR
if (MRI.getRegClass(CondReg) == &X86::VK1RegClass) {
unsigned KCondReg = CondReg;
CondReg = createResultReg(Subtarget->is64Bit() ?
&X86::GR8RegClass : &X86::GR8_ABCD_LRegClass);
CondReg = createResultReg(&X86::GR32RegClass);
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
TII.get(TargetOpcode::COPY), CondReg)
.addReg(KCondReg, getKillRegState(CondIsKill));
CondReg = fastEmitInst_extractsubreg(MVT::i8, CondReg, /*Kill=*/true,
X86::sub_8bit);
}
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc, TII.get(X86::TEST8ri))
.addReg(CondReg, getKillRegState(CondIsKill))
@@ -2327,11 +2336,12 @@ bool X86FastISel::X86FastEmitPseudoSelect(MVT RetVT, const Instruction *I) {
// In case OpReg is a K register, COPY to a GPR
if (MRI.getRegClass(CondReg) == &X86::VK1RegClass) {
unsigned KCondReg = CondReg;
CondReg = createResultReg(Subtarget->is64Bit() ?
&X86::GR8RegClass : &X86::GR8_ABCD_LRegClass);
CondReg = createResultReg(&X86::GR32RegClass);
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
TII.get(TargetOpcode::COPY), CondReg)
.addReg(KCondReg, getKillRegState(CondIsKill));
CondReg = fastEmitInst_extractsubreg(MVT::i8, CondReg, /*Kill=*/true,
X86::sub_8bit);
}
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc, TII.get(X86::TEST8ri))
.addReg(CondReg, getKillRegState(CondIsKill))
@@ -3307,6 +3317,16 @@ bool X86FastISel::fastLowerCall(CallLoweringInfo &CLI) {

// Handle zero-extension from i1 to i8, which is common.
if (ArgVT == MVT::i1) {
// In case SrcReg is a K register, COPY to a GPR
if (MRI.getRegClass(ArgReg) == &X86::VK1RegClass) {
unsigned KArgReg = ArgReg;
ArgReg = createResultReg(&X86::GR32RegClass);
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
TII.get(TargetOpcode::COPY), ArgReg)
.addReg(KArgReg);
ArgReg = fastEmitInst_extractsubreg(MVT::i8, ArgReg, /*Kill=*/true,
X86::sub_8bit);
}
// Set the high bits to zero.
ArgReg = fastEmitZExtFromI1(MVT::i8, ArgReg, /*TODO: Kill=*/false);
ArgVT = MVT::i8;
@@ -3642,6 +3662,13 @@ unsigned X86FastISel::X86MaterializeInt(const ConstantInt *CI, MVT VT) {
switch (VT.SimpleTy) {
default: llvm_unreachable("Unexpected value type");
case MVT::i1:
if (Subtarget->hasAVX512()) {
// Need to copy to a VK1 register.
unsigned ResultReg = createResultReg(&X86::VK1RegClass);
BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
TII.get(TargetOpcode::COPY), ResultReg).addReg(SrcReg);
return ResultReg;
}
case MVT::i8:
return fastEmitInst_extractsubreg(MVT::i8, SrcReg, /*Kill=*/true,
X86::sub_8bit);
@@ -3663,7 +3690,12 @@ unsigned X86FastISel::X86MaterializeInt(const ConstantInt *CI, MVT VT) {
unsigned Opc = 0;
switch (VT.SimpleTy) {
default: llvm_unreachable("Unexpected value type");
case MVT::i1: VT = MVT::i8; LLVM_FALLTHROUGH;
case MVT::i1:
// TODO: Support this properly.
if (Subtarget->hasAVX512())
return 0;
VT = MVT::i8;
LLVM_FALLTHROUGH;
case MVT::i8: Opc = X86::MOV8ri; break;
case MVT::i16: Opc = X86::MOV16ri; break;
case MVT::i32: Opc = X86::MOV32ri; break;
79 changes: 61 additions & 18 deletions llvm/lib/Target/X86/X86InstrAVX512.td
@@ -2183,28 +2183,26 @@ let Predicates = [HasBWI] in {

// GR from/to mask register
def : Pat<(v16i1 (bitconvert (i16 GR16:$src))),
(COPY_TO_REGCLASS GR16:$src, VK16)>;
(COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit)), VK16)>;
def : Pat<(i16 (bitconvert (v16i1 VK16:$src))),
(COPY_TO_REGCLASS VK16:$src, GR16)>;
(EXTRACT_SUBREG (i32 (COPY_TO_REGCLASS VK16:$src, GR32)), sub_16bit)>;

def : Pat<(v8i1 (bitconvert (i8 GR8:$src))),
(COPY_TO_REGCLASS GR8:$src, VK8)>;
(COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), GR8:$src, sub_8bit)), VK8)>;
def : Pat<(i8 (bitconvert (v8i1 VK8:$src))),
(COPY_TO_REGCLASS VK8:$src, GR8)>;
(EXTRACT_SUBREG (i32 (COPY_TO_REGCLASS VK8:$src, GR32)), sub_8bit)>;

def : Pat<(i32 (zext (i16 (bitconvert (v16i1 VK16:$src))))),
(KMOVWrk VK16:$src)>;
def : Pat<(i32 (anyext (i16 (bitconvert (v16i1 VK16:$src))))),
(i32 (INSERT_SUBREG (IMPLICIT_DEF),
(i16 (COPY_TO_REGCLASS VK16:$src, GR16)), sub_16bit))>;
(COPY_TO_REGCLASS VK16:$src, GR32)>;

def : Pat<(i32 (zext (i8 (bitconvert (v8i1 VK8:$src))))),
(MOVZX32rr8 (COPY_TO_REGCLASS VK8:$src, GR8))>, Requires<[NoDQI]>;
(MOVZX32rr8 (EXTRACT_SUBREG (i32 (COPY_TO_REGCLASS VK8:$src, GR32)), sub_8bit))>, Requires<[NoDQI]>;
def : Pat<(i32 (zext (i8 (bitconvert (v8i1 VK8:$src))))),
(KMOVBrk VK8:$src)>, Requires<[HasDQI]>;
def : Pat<(i32 (anyext (i8 (bitconvert (v8i1 VK8:$src))))),
(i32 (INSERT_SUBREG (IMPLICIT_DEF),
(i8 (COPY_TO_REGCLASS VK8:$src, GR8)), sub_8bit))>;
(COPY_TO_REGCLASS VK8:$src, GR32)>;

def : Pat<(v32i1 (bitconvert (i32 GR32:$src))),
(COPY_TO_REGCLASS GR32:$src, VK32)>;
@@ -3288,6 +3286,23 @@ def : Pat<(masked_store addr:$dst, Mask,

}

multiclass avx512_store_scalar_lowering_subreg<string InstrStr,
AVX512VLVectorVTInfo _,
dag Mask, RegisterClass MaskRC,
SubRegIndex subreg> {

def : Pat<(masked_store addr:$dst, Mask,
(_.info512.VT (insert_subvector undef,
(_.info256.VT (insert_subvector undef,
(_.info128.VT _.info128.RC:$src),
(iPTR 0))),
(iPTR 0)))),
(!cast<Instruction>(InstrStr#mrk) addr:$dst,
(i1 (COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), MaskRC:$mask, subreg)), VK1WM)),
(COPY_TO_REGCLASS _.info128.RC:$src, _.info128.FRC))>;

}

multiclass avx512_load_scalar_lowering<string InstrStr, AVX512VLVectorVTInfo _,
dag Mask, RegisterClass MaskRC> {

@@ -3314,22 +3329,50 @@ def : Pat<(_.info128.VT (extract_subvector

}

multiclass avx512_load_scalar_lowering_subreg<string InstrStr,
AVX512VLVectorVTInfo _,
dag Mask, RegisterClass MaskRC,
SubRegIndex subreg> {

def : Pat<(_.info128.VT (extract_subvector
(_.info512.VT (masked_load addr:$srcAddr, Mask,
(_.info512.VT (bitconvert
(v16i32 immAllZerosV))))),
(iPTR 0))),
(!cast<Instruction>(InstrStr#rmkz)
(i1 (COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), MaskRC:$mask, subreg)), VK1WM)),
addr:$srcAddr)>;

def : Pat<(_.info128.VT (extract_subvector
(_.info512.VT (masked_load addr:$srcAddr, Mask,
(_.info512.VT (insert_subvector undef,
(_.info256.VT (insert_subvector undef,
(_.info128.VT (X86vzmovl _.info128.RC:$src)),
(iPTR 0))),
(iPTR 0))))),
(iPTR 0))),
(!cast<Instruction>(InstrStr#rmk) _.info128.RC:$src,
(i1 (COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), MaskRC:$mask, subreg)), VK1WM)),
addr:$srcAddr)>;

}

defm : avx512_move_scalar_lowering<"VMOVSSZ", X86Movss, fp32imm0, v4f32x_info>;
defm : avx512_move_scalar_lowering<"VMOVSDZ", X86Movsd, fp64imm0, v2f64x_info>;

defm : avx512_store_scalar_lowering<"VMOVSSZ", avx512vl_f32_info,
(v16i1 (bitconvert (i16 (trunc (and GR32:$mask, (i32 1)))))), GR32>;
defm : avx512_store_scalar_lowering<"VMOVSSZ", avx512vl_f32_info,
(v16i1 (bitconvert (i16 (and GR16:$mask, (i16 1))))), GR16>;
defm : avx512_store_scalar_lowering<"VMOVSDZ", avx512vl_f64_info,
(v8i1 (bitconvert (i8 (and GR8:$mask, (i8 1))))), GR8>;
defm : avx512_store_scalar_lowering_subreg<"VMOVSSZ", avx512vl_f32_info,
(v16i1 (bitconvert (i16 (and GR16:$mask, (i16 1))))), GR16, sub_16bit>;
defm : avx512_store_scalar_lowering_subreg<"VMOVSDZ", avx512vl_f64_info,
(v8i1 (bitconvert (i8 (and GR8:$mask, (i8 1))))), GR8, sub_8bit>;

defm : avx512_load_scalar_lowering<"VMOVSSZ", avx512vl_f32_info,
(v16i1 (bitconvert (i16 (trunc (and GR32:$mask, (i32 1)))))), GR32>;
defm : avx512_load_scalar_lowering<"VMOVSSZ", avx512vl_f32_info,
(v16i1 (bitconvert (i16 (and GR16:$mask, (i16 1))))), GR16>;
defm : avx512_load_scalar_lowering<"VMOVSDZ", avx512vl_f64_info,
(v8i1 (bitconvert (i8 (and GR8:$mask, (i8 1))))), GR8>;
defm : avx512_load_scalar_lowering_subreg<"VMOVSSZ", avx512vl_f32_info,
(v16i1 (bitconvert (i16 (and GR16:$mask, (i16 1))))), GR16, sub_16bit>;
defm : avx512_load_scalar_lowering_subreg<"VMOVSDZ", avx512vl_f64_info,
(v8i1 (bitconvert (i8 (and GR8:$mask, (i8 1))))), GR8, sub_8bit>;

def : Pat<(f32 (X86selects VK1WM:$mask, (f32 FR32X:$src1), (f32 FR32X:$src2))),
(COPY_TO_REGCLASS (VMOVSSZrrk (COPY_TO_REGCLASS FR32X:$src2, VR128X),
Expand All @@ -3340,7 +3383,7 @@ def : Pat<(f64 (X86selects VK1WM:$mask, (f64 FR64X:$src1), (f64 FR64X:$src2))),
VK1WM:$mask, (v2f64 (IMPLICIT_DEF)), FR64X:$src1), FR64X)>;

def : Pat<(int_x86_avx512_mask_store_ss addr:$dst, VR128X:$src, GR8:$mask),
(VMOVSSZmrk addr:$dst, (i1 (COPY_TO_REGCLASS GR8:$mask, VK1WM)),
(VMOVSSZmrk addr:$dst, (i1 (COPY_TO_REGCLASS (i32 (INSERT_SUBREG (IMPLICIT_DEF), GR8:$mask, sub_8bit)), VK1WM)),
(COPY_TO_REGCLASS VR128X:$src, FR32X))>;

let hasSideEffects = 0 in
22 changes: 0 additions & 22 deletions llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -6309,8 +6309,6 @@ static unsigned CopyToFromAsymmetricReg(unsigned &DestReg, unsigned &SrcReg,

// SrcReg(MaskReg) -> DestReg(GR64)
// SrcReg(MaskReg) -> DestReg(GR32)
// SrcReg(MaskReg) -> DestReg(GR16)
// SrcReg(MaskReg) -> DestReg(GR8)

// All KMASK RegClasses hold the same k registers, can be tested against anyone.
if (X86::VK16RegClass.contains(SrcReg)) {
@@ -6320,21 +6318,10 @@
}
if (X86::GR32RegClass.contains(DestReg))
return Subtarget.hasBWI() ? X86::KMOVDrk : X86::KMOVWrk;
if (X86::GR16RegClass.contains(DestReg)) {
DestReg = getX86SubSuperRegister(DestReg, 32);
return X86::KMOVWrk;
}
if (X86::GR8RegClass.contains(DestReg)) {
assert(!isHReg(DestReg) && "Cannot move between mask and h-reg");
DestReg = getX86SubSuperRegister(DestReg, 32);
return Subtarget.hasDQI() ? X86::KMOVBrk : X86::KMOVWrk;
}
}

// SrcReg(GR64) -> DestReg(MaskReg)
// SrcReg(GR32) -> DestReg(MaskReg)
// SrcReg(GR16) -> DestReg(MaskReg)
// SrcReg(GR8) -> DestReg(MaskReg)

// All KMASK RegClasses hold the same k registers, can be tested against anyone.
if (X86::VK16RegClass.contains(DestReg)) {
@@ -6344,15 +6331,6 @@
}
if (X86::GR32RegClass.contains(SrcReg))
return Subtarget.hasBWI() ? X86::KMOVDkr : X86::KMOVWkr;
if (X86::GR16RegClass.contains(SrcReg)) {
SrcReg = getX86SubSuperRegister(SrcReg, 32);
return X86::KMOVWkr;
}
if (X86::GR8RegClass.contains(SrcReg)) {
assert(!isHReg(SrcReg) && "Cannot move between mask and h-reg");
SrcReg = getX86SubSuperRegister(SrcReg, 32);
return Subtarget.hasDQI() ? X86::KMOVBkr : X86::KMOVWkr;
}
}


2 changes: 1 addition & 1 deletion llvm/test/CodeGen/X86/avx512-calling-conv.ll
@@ -298,7 +298,7 @@ define <8 x i1> @test7a(<8 x i32>%a, <8 x i32>%b) {
; SKX-NEXT: vpsllw $15, %xmm0, %xmm0
; SKX-NEXT: vpmovw2m %xmm0, %k0
; SKX-NEXT: movb $85, %al
; SKX-NEXT: kmovb %eax, %k1
; SKX-NEXT: kmovd %eax, %k1
; SKX-NEXT: kandb %k1, %k0, %k0
; SKX-NEXT: vpmovm2w %k0, %xmm0
; SKX-NEXT: popq %rax
1 change: 1 addition & 0 deletions llvm/test/CodeGen/X86/avx512-cmp-kor-sequence.ll
@@ -19,6 +19,7 @@ define zeroext i16 @cmp_kor_seq_16(<16 x float> %a, <16 x float> %b, <16 x float
; CHECK-NEXT: korw %k3, %k2, %k1
; CHECK-NEXT: korw %k1, %k0, %k0
; CHECK-NEXT: kmovw %k0, %eax
; CHECK-NEXT: # kill: %AX<def> %AX<kill> %EAX<kill>
; CHECK-NEXT: retq
entry:
%0 = tail call i16 @llvm.x86.avx512.mask.cmp.ps.512(<16 x float> %a, <16 x float> %x, i32 13, i16 -1, i32 4)