[RISCV] Fix vm operand constraint to fit GCC's behavior
- The `vm` constraint is used for the masking operand, which is always `v0`.

- Update the testcases: only the masking operand should use `vm`; vector mask
  operations should just use `vr`, since they can read and write any vector
  register.

- Revise the description of the `vm` constraint.

- This patch also fixes issues in RISCVRegisterInfo.td and RISCVISelLowering.cpp.

  RISCVRegisterInfo.td:
  - The first VT in the list must have the largest total size, since
    SelectionDAGBuilder uses the first type in the list as the canonical type
    for the register. For the mask types that is vbool1_t (nxv64i1); vbool64_t
    (nxv1i1) is the smallest. (See the sketch after that hunk below.)

  RISCVISelLowering.cpp:
  - Fix RISCVTargetLowering::splitValueIntoRegisterParts and
    RISCVTargetLowering::joinRegisterPartsIntoValue to handle vectors with
    different total sizes. This happens with fractional LMUL, since a
    fractional-LMUL value still occupies a whole vector register. (IR-level
    sketches follow the corresponding hunks below.)

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D112599
kito-cheng committed Dec 9, 2021
1 parent 352e36e commit 39c8617
Showing 4 changed files with 54 additions and 27 deletions.
2 changes: 1 addition & 1 deletion llvm/docs/LangRef.rst
@@ -4823,7 +4823,7 @@ RISC-V:
 - ``r``: A 32- or 64-bit general-purpose register (depending on the platform
   ``XLEN``).
 - ``vr``: A vector register. (requires V extension).
-- ``vm``: A vector mask register. (requires V extension).
+- ``vm``: A vector register for masking operand. (requires V extension).
 
 Sparc:
 
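
For illustration (not part of the diff): the intended division of labor between the two constraints, mirroring the updated testcases below, sketched in LLVM IR. The function is hypothetical; mask computations such as vmand.mm use `vr`, and only the true masking operand uses `vm`, which pins the value to `v0`.

; Hypothetical sketch, not from this patch.
define <vscale x 8 x i8> @constraint_sketch(<vscale x 8 x i1> %m1, <vscale x 8 x i1> %m2, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
entry:
  ; Mask computation: any vector register will do, so use "vr".
  %m = tail call <vscale x 8 x i1> asm "vmand.mm $0, $1, $2", "=^vr,^vr,^vr"(<vscale x 8 x i1> %m1, <vscale x 8 x i1> %m2)
  ; Masking operand: "vm" constrains %m to v0, as the ISA requires.
  %r = tail call <vscale x 8 x i8> asm "vadd.vv $0, $1, $2, $3.t", "=^vr,^vr,^vr,^vm"(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b, <vscale x 8 x i1> %m)
  ret <vscale x 8 x i8> %r
}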
46 changes: 31 additions & 15 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -9572,8 +9572,8 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
          return std::make_pair(0U, RC);
      }
    } else if (Constraint == "vm") {
-     if (TRI->isTypeLegalForClass(RISCV::VMRegClass, VT.SimpleTy))
-       return std::make_pair(0U, &RISCV::VMRegClass);
+     if (TRI->isTypeLegalForClass(RISCV::VMV0RegClass, VT.SimpleTy))
+       return std::make_pair(0U, &RISCV::VMV0RegClass);
    }
  }

@@ -10112,17 +10112,29 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
     unsigned ValueVTBitSize = ValueVT.getSizeInBits().getKnownMinSize();
     unsigned PartVTBitSize = PartVT.getSizeInBits().getKnownMinSize();
     if (PartVTBitSize % ValueVTBitSize == 0) {
+      assert(PartVTBitSize >= ValueVTBitSize);
       // If the element types are different, bitcast to the same element type of
       // PartVT first.
+      // Give an example here, we want copy a <vscale x 1 x i8> value to
+      // <vscale x 4 x i16>.
+      // We need to convert <vscale x 1 x i8> to <vscale x 8 x i8> by insert
+      // subvector, then we can bitcast to <vscale x 4 x i16>.
       if (ValueEltVT != PartEltVT) {
-        unsigned Count = ValueVTBitSize / PartEltVT.getSizeInBits();
-        assert(Count != 0 && "The number of element should not be zero.");
-        EVT SameEltTypeVT =
-            EVT::getVectorVT(Context, PartEltVT, Count, /*IsScalable=*/true);
-        Val = DAG.getNode(ISD::BITCAST, DL, SameEltTypeVT, Val);
+        if (PartVTBitSize > ValueVTBitSize) {
+          unsigned Count = PartVTBitSize / ValueEltVT.getFixedSizeInBits();
+          assert(Count != 0 && "The number of element should not be zero.");
+          EVT SameEltTypeVT =
+              EVT::getVectorVT(Context, ValueEltVT, Count, /*IsScalable=*/true);
+          Val = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, SameEltTypeVT,
+                            DAG.getUNDEF(SameEltTypeVT), Val,
+                            DAG.getVectorIdxConstant(0, DL));
+        }
+        Val = DAG.getNode(ISD::BITCAST, DL, PartVT, Val);
+      } else {
+        Val =
+            DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
+                        Val, DAG.getVectorIdxConstant(0, DL));
       }
-      Val = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PartVT, DAG.getUNDEF(PartVT),
-                        Val, DAG.getConstant(0, DL, Subtarget.getXLenVT()));
       Parts[0] = Val;
       return true;
     }
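
To make the new splitting path concrete, here is a rough IR-level analogue of the example in the comment above (a sketch, not code from this patch; llvm.vector.insert stands in for the ISD::INSERT_SUBVECTOR node the lowering emits):

; Copy a <vscale x 1 x i8> value (fractional LMUL) into a <vscale x 4 x i16>
; part: Count = 64 / 8 = 8, so widen to <vscale x 8 x i8> by inserting at
; index 0, then bitcast to the part type (same total size per vscale unit).
define <vscale x 4 x i16> @split_sketch(<vscale x 1 x i8> %v) {
  %widened = call <vscale x 8 x i8> @llvm.vector.insert.nxv8i8.nxv1i8(<vscale x 8 x i8> undef, <vscale x 1 x i8> %v, i64 0)
  %part = bitcast <vscale x 8 x i8> %widened to <vscale x 4 x i16>
  ret <vscale x 4 x i16> %part
}
declare <vscale x 8 x i8> @llvm.vector.insert.nxv8i8.nxv1i8(<vscale x 8 x i8>, <vscale x 1 x i8>, i64)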
@@ -10152,19 +10164,23 @@ SDValue RISCVTargetLowering::joinRegisterPartsIntoValue(
     unsigned ValueVTBitSize = ValueVT.getSizeInBits().getKnownMinSize();
     unsigned PartVTBitSize = PartVT.getSizeInBits().getKnownMinSize();
     if (PartVTBitSize % ValueVTBitSize == 0) {
+      assert(PartVTBitSize >= ValueVTBitSize);
       EVT SameEltTypeVT = ValueVT;
       // If the element types are different, convert it to the same element type
       // of PartVT.
+      // Give an example here, we want copy a <vscale x 1 x i8> value from
+      // <vscale x 4 x i16>.
+      // We need to convert <vscale x 4 x i16> to <vscale x 8 x i8> first,
+      // then we can extract <vscale x 1 x i8>.
       if (ValueEltVT != PartEltVT) {
-        unsigned Count = ValueVTBitSize / PartEltVT.getSizeInBits();
+        unsigned Count = PartVTBitSize / ValueEltVT.getFixedSizeInBits();
         assert(Count != 0 && "The number of element should not be zero.");
         SameEltTypeVT =
-            EVT::getVectorVT(Context, PartEltVT, Count, /*IsScalable=*/true);
+            EVT::getVectorVT(Context, ValueEltVT, Count, /*IsScalable=*/true);
+        Val = DAG.getNode(ISD::BITCAST, DL, SameEltTypeVT, Val);
       }
-      Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SameEltTypeVT, Val,
-                        DAG.getConstant(0, DL, Subtarget.getXLenVT()));
-      if (ValueEltVT != PartEltVT)
-        Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+      Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ValueVT, Val,
+                        DAG.getVectorIdxConstant(0, DL));
       return Val;
     }
   }
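
The joining direction reverses those steps; again an IR-level sketch under the same assumptions (llvm.vector.extract standing in for ISD::EXTRACT_SUBVECTOR):

; Recover a <vscale x 1 x i8> value from a <vscale x 4 x i16> part: bitcast
; to the matching element type first, then extract the subvector at index 0.
define <vscale x 1 x i8> @join_sketch(<vscale x 4 x i16> %part) {
  %cast = bitcast <vscale x 4 x i16> %part to <vscale x 8 x i8>
  %v = call <vscale x 1 x i8> @llvm.vector.extract.nxv1i8.nxv8i8(<vscale x 8 x i8> %cast, i64 0)
  ret <vscale x 1 x i8> %v
}
declare <vscale x 1 x i8> @llvm.vector.extract.nxv1i8.nxv8i8(<vscale x 8 x i8>, i64)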
7 changes: 3 additions & 4 deletions llvm/lib/Target/RISCV/RISCVRegisterInfo.td
@@ -550,16 +550,15 @@ def VRM8NoV0 : VReg<[vint8m8_t, vint16m8_t, vint32m8_t, vint64m8_t,
                      vfloat16m8_t, vfloat32m8_t, vfloat64m8_t],
               (add V8M8, V16M8, V24M8), 8>;
 
-defvar VMaskVTs = [vbool64_t, vbool32_t, vbool16_t, vbool8_t,
-                   vbool4_t, vbool2_t, vbool1_t];
+defvar VMaskVTs = [vbool1_t, vbool2_t, vbool4_t, vbool8_t, vbool16_t,
+                   vbool32_t, vbool64_t];
 
 def VMV0 : RegisterClass<"RISCV", VMaskVTs, 64, (add V0)> {
   let Size = 64;
 }
 
 // The register class is added for inline assembly for vector mask types.
-def VM : VReg<[vbool1_t, vbool2_t, vbool4_t, vbool8_t, vbool16_t,
-               vbool32_t, vbool64_t],
+def VM : VReg<VMaskVTs,
               (add (sequence "V%u", 8, 31),
                    (sequence "V%u", 0, 7)), 1>;
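
Why the ordering matters, as a hypothetical IR sketch (function made up, not part of the patch): a `vm` operand travels through VMV0, whose canonical type is the first entry of VMaskVTs (vbool1_t, i.e. <vscale x 64 x i1>, the largest known-minimum size), so any smaller vbool type can be placed into and recovered from the canonical type by the split/join changes above.

; The mask here is <vscale x 1 x i1> (vbool64_t, the smallest mask type).
; It is carried in v0 with canonical type <vscale x 64 x i1>; the value is
; widened/narrowed by splitValueIntoRegisterParts/joinRegisterPartsIntoValue.
define <vscale x 1 x i8> @mask_canonical_sketch(<vscale x 1 x i8> %a, <vscale x 1 x i8> %b, <vscale x 1 x i1> %m) {
entry:
  %r = tail call <vscale x 1 x i8> asm "vadd.vv $0, $1, $2, $3.t", "=^vr,^vr,^vr,^vm"(<vscale x 1 x i8> %a, <vscale x 1 x i8> %b, <vscale x 1 x i1> %m)
  ret <vscale x 1 x i8> %r
}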

26 changes: 19 additions & 7 deletions llvm/test/CodeGen/RISCV/rvv/inline-asm.ll
@@ -10,7 +10,7 @@ define <vscale x 1 x i1> @test_1xi1(<vscale x 1 x i1> %in, <vscale x 1 x i1> %in
 ; CHECK-NEXT:    #NO_APP
 ; CHECK-NEXT:    ret
 entry:
-  %0 = tail call <vscale x 1 x i1> asm "vmand.mm $0, $1, $2", "=^vm,^vm,^vm"(<vscale x 1 x i1> %in, <vscale x 1 x i1> %in2)
+  %0 = tail call <vscale x 1 x i1> asm "vmand.mm $0, $1, $2", "=^vr,^vr,^vr"(<vscale x 1 x i1> %in, <vscale x 1 x i1> %in2)
   ret <vscale x 1 x i1> %0
 }

@@ -22,7 +22,7 @@ define <vscale x 2 x i1> @test_2xi1(<vscale x 2 x i1> %in, <vscale x 2 x i1> %in
 ; CHECK-NEXT:    #NO_APP
 ; CHECK-NEXT:    ret
 entry:
-  %0 = tail call <vscale x 2 x i1> asm "vmand.mm $0, $1, $2", "=^vm,^vm,^vm"(<vscale x 2 x i1> %in, <vscale x 2 x i1> %in2)
+  %0 = tail call <vscale x 2 x i1> asm "vmand.mm $0, $1, $2", "=^vr,^vr,^vr"(<vscale x 2 x i1> %in, <vscale x 2 x i1> %in2)
   ret <vscale x 2 x i1> %0
 }

@@ -34,7 +34,7 @@ define <vscale x 4 x i1> @test_4xi1(<vscale x 4 x i1> %in, <vscale x 4 x i1> %in
 ; CHECK-NEXT:    #NO_APP
 ; CHECK-NEXT:    ret
 entry:
-  %0 = tail call <vscale x 4 x i1> asm "vmand.mm $0, $1, $2", "=^vm,^vm,^vm"(<vscale x 4 x i1> %in, <vscale x 4 x i1> %in2)
+  %0 = tail call <vscale x 4 x i1> asm "vmand.mm $0, $1, $2", "=^vr,^vr,^vr"(<vscale x 4 x i1> %in, <vscale x 4 x i1> %in2)
   ret <vscale x 4 x i1> %0
 }

@@ -46,7 +46,7 @@ define <vscale x 8 x i1> @test_8xi1(<vscale x 8 x i1> %in, <vscale x 8 x i1> %in
 ; CHECK-NEXT:    #NO_APP
 ; CHECK-NEXT:    ret
 entry:
-  %0 = tail call <vscale x 8 x i1> asm "vmand.mm $0, $1, $2", "=^vm,^vm,^vm"(<vscale x 8 x i1> %in, <vscale x 8 x i1> %in2)
+  %0 = tail call <vscale x 8 x i1> asm "vmand.mm $0, $1, $2", "=^vr,^vr,^vr"(<vscale x 8 x i1> %in, <vscale x 8 x i1> %in2)
   ret <vscale x 8 x i1> %0
 }

@@ -58,7 +58,7 @@ define <vscale x 16 x i1> @test_16xi1(<vscale x 16 x i1> %in, <vscale x 16 x i1>
 ; CHECK-NEXT:    #NO_APP
 ; CHECK-NEXT:    ret
 entry:
-  %0 = tail call <vscale x 16 x i1> asm "vmand.mm $0, $1, $2", "=^vm,^vm,^vm"(<vscale x 16 x i1> %in, <vscale x 16 x i1> %in2)
+  %0 = tail call <vscale x 16 x i1> asm "vmand.mm $0, $1, $2", "=^vr,^vr,^vr"(<vscale x 16 x i1> %in, <vscale x 16 x i1> %in2)
   ret <vscale x 16 x i1> %0
 }

@@ -70,7 +70,7 @@ define <vscale x 32 x i1> @test_32xi1(<vscale x 32 x i1> %in, <vscale x 32 x i1>
 ; CHECK-NEXT:    #NO_APP
 ; CHECK-NEXT:    ret
 entry:
-  %0 = tail call <vscale x 32 x i1> asm "vmand.mm $0, $1, $2", "=^vm,^vm,^vm"(<vscale x 32 x i1> %in, <vscale x 32 x i1> %in2)
+  %0 = tail call <vscale x 32 x i1> asm "vmand.mm $0, $1, $2", "=^vr,^vr,^vr"(<vscale x 32 x i1> %in, <vscale x 32 x i1> %in2)
   ret <vscale x 32 x i1> %0
 }

@@ -82,7 +82,7 @@ define <vscale x 64 x i1> @test_64xi1(<vscale x 64 x i1> %in, <vscale x 64 x i1>
 ; CHECK-NEXT:    #NO_APP
 ; CHECK-NEXT:    ret
 entry:
-  %0 = tail call <vscale x 64 x i1> asm "vmand.mm $0, $1, $2", "=^vm,^vm,^vm"(<vscale x 64 x i1> %in, <vscale x 64 x i1> %in2)
+  %0 = tail call <vscale x 64 x i1> asm "vmand.mm $0, $1, $2", "=^vr,^vr,^vr"(<vscale x 64 x i1> %in, <vscale x 64 x i1> %in2)
   ret <vscale x 64 x i1> %0
 }

@@ -350,6 +350,18 @@ entry:
   ret <vscale x 64 x i8> %0
 }
 
+define <vscale x 64 x i8> @test_64xi8_with_mask(<vscale x 64 x i8> %in, <vscale x 64 x i8> %in2, <vscale x 64 x i1> %mask) nounwind {
+; CHECK-LABEL: test_64xi8_with_mask:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    vadd.vv v8, v8, v16, v0.t
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call <vscale x 64 x i8> asm "vadd.vv $0, $1, $2, $3.t", "=^vr,^vr,^vr,^vm"(<vscale x 64 x i8> %in, <vscale x 64 x i8> %in2, <vscale x 64 x i1> %mask)
+  ret <vscale x 64 x i8> %0
+}
+
define <vscale x 4 x i8> @test_specify_reg_mf2(<vscale x 4 x i8> %in, <vscale x 4 x i8> %in2) nounwind {
; CHECK-LABEL: test_specify_reg_mf2:
; CHECK: # %bb.0: # %entry
