Skip to content

Commit

Permalink
GlobalISel: Rewrite getLCMType
Browse files Browse the repository at this point in the history
Try to make the behavior more consistent with getGCDType, and bias
towards returning something closer to the source type whenever there's
an ambiguity.
  • Loading branch information
arsenm committed Jul 21, 2020
1 parent 12d5bec commit 1ef3ed0
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 39 deletions.
10 changes: 5 additions & 5 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Expand Up @@ -190,12 +190,12 @@ inline bool isKnownNeverSNaN(Register Val, const MachineRegisterInfo &MRI) {

Align inferAlignFromPtrInfo(MachineFunction &MF, const MachinePointerInfo &MPO);

/// Return the least common multiple type of \p Ty0 and \p Ty1, by changing
/// the number of vector elements or scalar bitwidth. The intent is a
/// G_MERGE_VALUES can be constructed from \p Ty0 elements, and unmerged into
/// \p Ty1.
/// Return the least common multiple type of \p OrigTy and \p TargetTy, by changing the
/// number of vector elements or scalar bitwidth. The intent is a
/// G_MERGE_VALUES, G_BUILD_VECTOR, or G_CONCAT_VECTORS can be constructed from
/// \p OrigTy elements, and unmerged into \p TargetTy
LLVM_READNONE
LLT getLCMType(LLT Ty0, LLT Ty1);
LLT getLCMType(LLT OrigTy, LLT TargetTy);

/// Return a type where the total size is the greatest common divisor of \p
/// OrigTy and \p TargetTy. This will try to either change the number of vector
Expand Down
64 changes: 42 additions & 22 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Expand Up @@ -510,35 +510,55 @@ void llvm::getSelectionDAGFallbackAnalysisUsage(AnalysisUsage &AU) {
AU.addPreserved<StackProtector>();
}

LLT llvm::getLCMType(LLT Ty0, LLT Ty1) {
if (!Ty0.isVector() && !Ty1.isVector()) {
unsigned Mul = Ty0.getSizeInBits() * Ty1.getSizeInBits();
int GCDSize = greatestCommonDivisor(Ty0.getSizeInBits(),
Ty1.getSizeInBits());
return LLT::scalar(Mul / GCDSize);
}
static unsigned getLCMSize(unsigned OrigSize, unsigned TargetSize) {
unsigned Mul = OrigSize * TargetSize;
unsigned GCDSize = greatestCommonDivisor(OrigSize, TargetSize);
return Mul / GCDSize;
}

if (Ty0.isVector() && !Ty1.isVector()) {
assert(Ty0.getElementType() == Ty1 && "not yet handled");
return Ty0;
}
LLT llvm::getLCMType(LLT OrigTy, LLT TargetTy) {
const unsigned OrigSize = OrigTy.getSizeInBits();
const unsigned TargetSize = TargetTy.getSizeInBits();

if (Ty1.isVector() && !Ty0.isVector()) {
assert(Ty1.getElementType() == Ty0 && "not yet handled");
return Ty1;
}
if (OrigSize == TargetSize)
return OrigTy;

if (OrigTy.isVector()) {
const LLT OrigElt = OrigTy.getElementType();

if (TargetTy.isVector()) {
const LLT TargetElt = TargetTy.getElementType();

if (Ty0.isVector() && Ty1.isVector()) {
assert(Ty0.getElementType() == Ty1.getElementType() && "not yet handled");
if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
int GCDElts = greatestCommonDivisor(OrigTy.getNumElements(),
TargetTy.getNumElements());
// Prefer the original element type.
int Mul = OrigTy.getNumElements() * TargetTy.getNumElements();
return LLT::vector(Mul / GCDElts, OrigTy.getElementType());
}
} else {
if (OrigElt.getSizeInBits() == TargetSize)
return OrigTy;
}

int GCDElts = greatestCommonDivisor(Ty0.getNumElements(),
Ty1.getNumElements());
unsigned LCMSize = getLCMSize(OrigSize, TargetSize);
return LLT::vector(LCMSize / OrigElt.getSizeInBits(), OrigElt);
}

int Mul = Ty0.getNumElements() * Ty1.getNumElements();
return LLT::vector(Mul / GCDElts, Ty0.getElementType());
if (TargetTy.isVector()) {
unsigned LCMSize = getLCMSize(OrigSize, TargetSize);
return LLT::vector(LCMSize / OrigSize, OrigTy);
}

llvm_unreachable("not yet handled");
unsigned LCMSize = getLCMSize(OrigSize, TargetSize);

// Preserve pointer types.
if (LCMSize == OrigSize)
return OrigTy;
if (LCMSize == TargetSize)
return TargetTy;

return LLT::scalar(LCMSize);
}

LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
Expand Down
70 changes: 58 additions & 12 deletions llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp
Expand Up @@ -22,6 +22,7 @@ static const LLT P1 = LLT::pointer(1, 32);

static const LLT V2S8 = LLT::vector(2, 8);
static const LLT V4S8 = LLT::vector(4, 8);
static const LLT V8S8 = LLT::vector(8, 8);

static const LLT V2S16 = LLT::vector(2, 16);
static const LLT V3S16 = LLT::vector(3, 16);
Expand All @@ -33,6 +34,7 @@ static const LLT V4S32 = LLT::vector(4, 32);
static const LLT V6S32 = LLT::vector(6, 32);

static const LLT V2S64 = LLT::vector(2, 64);
static const LLT V3S64 = LLT::vector(3, 64);
static const LLT V4S64 = LLT::vector(4, 64);

static const LLT V2P0 = LLT::vector(2, P0);
Expand Down Expand Up @@ -157,18 +159,18 @@ TEST(GISelUtilsTest, getLCMType) {
EXPECT_EQ(S32, getLCMType(S16, S32));

EXPECT_EQ(S64, getLCMType(S64, P0));
EXPECT_EQ(S64, getLCMType(P0, S64));
EXPECT_EQ(P0, getLCMType(P0, S64));

EXPECT_EQ(S64, getLCMType(S32, P0));
EXPECT_EQ(S64, getLCMType(P0, S32));
EXPECT_EQ(P0, getLCMType(S32, P0));
EXPECT_EQ(P0, getLCMType(P0, S32));

EXPECT_EQ(S32, getLCMType(S32, P1));
EXPECT_EQ(S32, getLCMType(P1, S32));
EXPECT_EQ(S64, getLCMType(P0, P0));
EXPECT_EQ(S32, getLCMType(P1, P1));
EXPECT_EQ(P1, getLCMType(P1, S32));
EXPECT_EQ(P0, getLCMType(P0, P0));
EXPECT_EQ(P1, getLCMType(P1, P1));

EXPECT_EQ(S64, getLCMType(P0, P1));
EXPECT_EQ(S64, getLCMType(P1, P0));
EXPECT_EQ(P0, getLCMType(P0, P1));
EXPECT_EQ(P0, getLCMType(P1, P0));

EXPECT_EQ(V2S32, getLCMType(V2S32, V2S32));
EXPECT_EQ(V2S32, getLCMType(V2S32, S32));
Expand All @@ -188,11 +190,55 @@ TEST(GISelUtilsTest, getLCMType) {
EXPECT_EQ(LLT::vector(12, P0), getLCMType(V4P0, V3P0));
EXPECT_EQ(LLT::vector(12, P0), getLCMType(V3P0, V4P0));

// FIXME
// EXPECT_EQ(V2S32, getLCMType(V2S32, S64));
EXPECT_EQ(LLT::vector(12, S64), getLCMType(V4S64, V3P0));
EXPECT_EQ(LLT::vector(12, P0), getLCMType(V3P0, V4S64));

// FIXME
//EXPECT_EQ(S64, getLCMType(S64, V2S32));
EXPECT_EQ(LLT::vector(12, P0), getLCMType(V4P0, V3S64));
EXPECT_EQ(LLT::vector(12, S64), getLCMType(V3S64, V4P0));

EXPECT_EQ(V2P0, getLCMType(V2P0, S32));
EXPECT_EQ(V4S32, getLCMType(S32, V2P0));
EXPECT_EQ(V2P0, getLCMType(V2P0, S64));
EXPECT_EQ(V2S64, getLCMType(S64, V2P0));


EXPECT_EQ(V2P0, getLCMType(V2P0, V2P1));
EXPECT_EQ(V4P1, getLCMType(V2P1, V2P0));

EXPECT_EQ(V2P0, getLCMType(V2P0, V4P1));
EXPECT_EQ(V4P1, getLCMType(V4P1, V2P0));


EXPECT_EQ(V2S32, getLCMType(V2S32, S64));
EXPECT_EQ(S64, getLCMType(S64, V2S32));

EXPECT_EQ(V4S16, getLCMType(V4S16, V2S32));
EXPECT_EQ(V2S32, getLCMType(V2S32, V4S16));

EXPECT_EQ(V2S32, getLCMType(V2S32, V4S8));
EXPECT_EQ(V8S8, getLCMType(V4S8, V2S32));

EXPECT_EQ(V2S16, getLCMType(V2S16, V4S8));
EXPECT_EQ(V4S8, getLCMType(V4S8, V2S16));

EXPECT_EQ(LLT::vector(6, S16), getLCMType(V3S16, V4S8));
EXPECT_EQ(LLT::vector(12, S8), getLCMType(V4S8, V3S16));
EXPECT_EQ(V4S16, getLCMType(V4S16, V4S8));
EXPECT_EQ(V8S8, getLCMType(V4S8, V4S16));

EXPECT_EQ(LLT::vector(6, 4), getLCMType(LLT::vector(3, 4), S8));
EXPECT_EQ(LLT::vector(3, 8), getLCMType(S8, LLT::vector(3, 4)));

EXPECT_EQ(LLT::vector(6, 4),
getLCMType(LLT::vector(3, 4), LLT::pointer(4, 8)));
EXPECT_EQ(LLT::vector(3, LLT::pointer(4, 8)),
getLCMType(LLT::pointer(4, 8), LLT::vector(3, 4)));

EXPECT_EQ(V2S64, getLCMType(V2S64, P0));
EXPECT_EQ(V2P0, getLCMType(P0, V2S64));

EXPECT_EQ(V2S64, getLCMType(V2S64, P1));
EXPECT_EQ(V4P1, getLCMType(P1, V2S64));
}

}

0 comments on commit 1ef3ed0

Please sign in to comment.