Skip to content

Commit

Permalink
Merge pull request #6650 from BradleyWood/vsplats
Browse files Browse the repository at this point in the history
x86: implement vsplats for all vector lengths
  • Loading branch information
0xdaryl committed Aug 11, 2022
2 parents e26f652 + f37cfdb commit 497c5e3
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 18 deletions.
17 changes: 12 additions & 5 deletions compiler/x/codegen/OMRCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,13 +1037,20 @@ bool OMR::X86::CodeGenerator::getSupportsOpCodeForAutoSIMD(TR::CPU *cpu, TR::ILO
case TR::VectorLength128:
return true;
default:
return false;
return false;
}
case TR::vsplats:
if (et == TR::Int32 || et == TR::Int64 || et == TR::Float || et == TR::Double)
return ot.getVectorLength() == TR::VectorLength128;
else
return false;
switch (ot.getVectorLength())
{
case TR::VectorLength128:
return true;
case TR::VectorLength256:
return cpu->supportsFeature(OMR_FEATURE_X86_AVX2);
case TR::VectorLength512:
return cpu->supportsFeature(OMR_FEATURE_X86_AVX512F);
default:
return false;
}

/*
* GRA does not work with vector registers on 32 bit due to a bug where xmm registers are not being assigned.
Expand Down
55 changes: 44 additions & 11 deletions compiler/x/codegen/SIMDTreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,19 @@ TR::Register* OMR::X86::TreeEvaluator::SIMDstoreEvaluator(TR::Node* node, TR::Co

TR::Register* OMR::X86::TreeEvaluator::SIMDsplatsEvaluator(TR::Node* node, TR::CodeGenerator* cg)
{
TR_ASSERT_FATAL_WITH_NODE(node, node->getDataType().getVectorLength() == TR::VectorLength128,
"Only 128-bit vectors are supported %s", node->getDataType().toString());

TR::Node* childNode = node->getChild(0);
TR::Register* childReg = cg->evaluate(childNode);

TR::DataType et = node->getDataType().getVectorElementType();
TR::VectorLength vl = node->getDataType().getVectorLength();
TR::Register* resultReg = cg->allocateRegister(TR_VRF);
switch (node->getDataType().getVectorElementType())
bool broadcast64 = et.isInt64() || et.isDouble();

switch (et)
{
case TR::Int8:
case TR::Int16:
case TR::Int32:
generateRegRegInstruction(TR::InstOpCode::MOVDRegReg4, node, resultReg, childReg, cg);
generateRegRegImmInstruction(TR::InstOpCode::PSHUFDRegRegImm1, node, resultReg, resultReg, 0x00, cg); // 00 00 00 00 shuffle xxxA to AAAA
break;
case TR::Int64:
if (cg->comp()->target().is32Bit())
Expand All @@ -176,18 +177,50 @@ TR::Register* OMR::X86::TreeEvaluator::SIMDsplatsEvaluator(TR::Node* node, TR::C
{
generateRegRegInstruction(TR::InstOpCode::MOVQRegReg8, node, resultReg, childReg, cg);
}
generateRegRegImmInstruction(TR::InstOpCode::PSHUFDRegRegImm1, node, resultReg, resultReg, 0x44, cg); // 01 00 01 00 shuffle xxBA to BABA
break;
case TR::Float:
generateRegRegImmInstruction(TR::InstOpCode::PSHUFDRegRegImm1, node, resultReg, childReg, 0x00, cg); // 00 00 00 00 shuffle xxxA to AAAA
break;
case TR::Double:
generateRegRegImmInstruction(TR::InstOpCode::PSHUFDRegRegImm1, node, resultReg, childReg, 0x44, cg); // 01 00 01 00 shuffle xxBA to BABA
generateRegRegInstruction(TR::InstOpCode::MOVSDRegReg, node, resultReg, childReg, cg);
break;
default:
if (cg->comp()->getOption(TR_TraceCG))
traceMsg(cg->comp(), "Unsupported data type, Node = %p\n", node);
TR_ASSERT(false, "Unsupported data type");
TR_ASSERT_FATAL(false, "Unsupported data type");
break;
}

// Expand byte & word to 32-bits
switch (et)
{
case TR::Int8:
generateRegRegInstruction(TR::InstOpCode::PUNPCKLBWRegReg, node, resultReg, resultReg, cg);
case TR::Int16:
generateRegRegImmInstruction(TR::InstOpCode::PSHUFLWRegRegImm1, node, resultReg, resultReg, 0x0, cg);
default:
break;
}

switch (vl)
{
case TR::VectorLength128:
generateRegRegImmInstruction(TR::InstOpCode::PSHUFDRegRegImm1, node, resultReg, resultReg, broadcast64 ? 0x44 : 0, cg);
break;
case TR::VectorLength256:
{
TR_ASSERT_FATAL(cg->comp()->target().cpu.supportsFeature(OMR_FEATURE_X86_AVX2), "256-bit vsplats requires AVX2");
TR::InstOpCode opcode = broadcast64 ? TR::InstOpCode::VBROADCASTSDYmmYmm : TR::InstOpCode::VBROADCASTSSRegReg;
generateRegRegInstruction(opcode.getMnemonic(), node, resultReg, resultReg, cg, opcode.getSIMDEncoding(&cg->comp()->target().cpu, TR::VectorLength256));
break;
}
case TR::VectorLength512:
{
TR_ASSERT_FATAL(cg->comp()->target().cpu.supportsFeature(OMR_FEATURE_X86_AVX512F), "512-bit vsplats requires AVX-512");
TR::InstOpCode opcode = broadcast64 ? TR::InstOpCode::VBROADCASTSDZmmXmm : TR::InstOpCode::VBROADCASTSSRegReg;
generateRegRegInstruction(opcode.getMnemonic(), node, resultReg, resultReg, cg, OMR::X86::EVEX_L512);
break;
}
default:
TR_ASSERT_FATAL(0, "Unsupported vector length");
break;
}

Expand Down
63 changes: 61 additions & 2 deletions fvtest/compilertriltest/VectorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ TEST_P(ParameterizedVectorTest, VLoadStore) {
TR::ILOpCode loadOp = TR::ILOpCode::createVectorOpCode(TR::vloadi, vt);
TR::ILOpCode storeOp = TR::ILOpCode::createVectorOpCode(TR::vstorei, vt);
TR::CPU cpu = TR::CPU::detect(privateOmrPortLibrary);
bool platformSupport = TR::CodeGenerator::getSupportsOpCodeForAutoSIMD(&cpu, loadOp) && TR::CodeGenerator::getSupportsOpCodeForAutoSIMD(&cpu, loadOp);
bool platformSupport = TR::CodeGenerator::getSupportsOpCodeForAutoSIMD(&cpu, loadOp) && TR::CodeGenerator::getSupportsOpCodeForAutoSIMD(&cpu, storeOp);
SKIP_IF(!platformSupport, MissingImplementation) << "Opcode is not supported by the target platform";

char inputTrees[1024];
Expand Down Expand Up @@ -467,7 +467,66 @@ TEST_P(ParameterizedVectorTest, VLoadStore) {
EXPECT_EQ(0, memcmp(output + TR::DataType::getSize(vt), zero, maxVectorLength - TR::DataType::getSize(vt)));
}

INSTANTIATE_TEST_CASE_P(VLoadStoreVectorTest, ParameterizedVectorTest, ::testing::ValuesIn(*TRTest::MakeVector<std::tuple<TR::VectorLength, TR::DataTypes>>(
/**
 * Verifies the vsplats opcode for the parameterized (vector length, element type) pair.
 *
 * The compiled Tril method takes two addresses: it loads a single element of type `et`
 * through the second argument, broadcasts it with vsplats, and stores the resulting
 * vector through the first argument at offset 0. The test then checks that every lane
 * of the stored vector equals the input element and that nothing past the vector width
 * was written.
 *
 * The test is skipped when the target platform does not report support for the vector
 * load, store, or splats opcode of the requested vector type.
 */
TEST_P(ParameterizedVectorTest, VSplats) {
    TR::VectorLength vl = std::get<0>(GetParam());
    TR::DataTypes et = std::get<1>(GetParam());

    SKIP_IF(vl > TR::NumVectorLengths, MissingImplementation) << "Vector length is not supported by the target platform";
    SKIP_ON_S390(KnownBug) << "This test is currently disabled on Z platforms because not all Z platforms have vector support (issue #1843)";
    SKIP_ON_S390X(KnownBug) << "This test is currently disabled on Z platforms because not all Z platforms have vector support (issue #1843)";

    TR::DataType vt = TR::DataType::createVectorType(et, vl);

    TR::ILOpCode loadOp = TR::ILOpCode::createVectorOpCode(TR::vloadi, vt);
    TR::ILOpCode storeOp = TR::ILOpCode::createVectorOpCode(TR::vstorei, vt);
    TR::ILOpCode splatsOp = TR::ILOpCode::createVectorOpCode(TR::vsplats, vt);
    TR::ILOpCode elementLoadOp = OMR::ILOpCode::indirectLoadOpCode(et);
    TR::CPU cpu = TR::CPU::detect(privateOmrPortLibrary);

    // All three vector opcodes must be reported as supported, otherwise the test is skipped.
    bool platformSupport = TR::CodeGenerator::getSupportsOpCodeForAutoSIMD(&cpu, loadOp) &&
                           TR::CodeGenerator::getSupportsOpCodeForAutoSIMD(&cpu, storeOp) &&
                           TR::CodeGenerator::getSupportsOpCodeForAutoSIMD(&cpu, splatsOp);

    SKIP_IF(!platformSupport, MissingImplementation) << "Opcode " << splatsOp.getName() << vt.toString() << " is not supported by the target platform";

    char inputTrees[1024];
    // String literals are const; binding one to a non-const char* is ill-formed since C++11.
    const char *formatStr = "(method return= NoType args=[Address,Address] "
                            " (block "
                            " (vstorei%s offset=0 "
                            " (aload parm=0) "
                            " (vsplats%s "
                            " (%s (aload parm=1)))) "
                            " (return))) ";

    sprintf(inputTrees, formatStr, vt.toString(), vt.toString(), elementLoadOp.getName());

    auto trees = parseString(inputTrees);
    ASSERT_NOTNULL(trees);

    Tril::DefaultCompiler compiler(trees);
    ASSERT_EQ(0, compiler.compile()) << "Compilation failed unexpectedly\n" << "Input trees: " << inputTrees;

    auto entry_point = compiler.getEntryPoint<void (*)(void *,void *)>();

    const uint8_t maxVectorLength = 64;
    char output[maxVectorLength] = {0};
    char expected[maxVectorLength] = {0};
    char input[maxVectorLength] = {0};

    int etSize = typeSize(et);
    int vlSize = vectorSize(vl);

    // Guard the fill loop below: a non-positive element size would never terminate,
    // and a vector wider than the scratch buffers would overflow them.
    ASSERT_GT(etSize, 0) << "Unexpected element size for the element type under test";
    ASSERT_LE(vlSize, static_cast<int>(maxVectorLength)) << "Vector size exceeds the test buffers";

    // NOTE(review): presumably populates `input` with test data of type `et` — confirm against the helper.
    generateByType(input, et, false);

    // Expected result: the first input element replicated into every lane of the vector.
    for (int i = 0; i < vlSize; i += etSize) {
        memcpy(expected + i, input, etSize);
    }

    entry_point(output, input);

    EXPECT_EQ(0, memcmp(expected, output, vlSize));
    // `expected` is still zero past vlSize, so this checks the store did not write beyond the vector width.
    EXPECT_EQ(0, memcmp(expected + vlSize, output + vlSize, maxVectorLength - vlSize));
}

INSTANTIATE_TEST_CASE_P(VectorTypeParameters, ParameterizedVectorTest, ::testing::ValuesIn(*TRTest::MakeVector<std::tuple<TR::VectorLength, TR::DataTypes>>(
std::make_tuple(TR::VectorLength128, TR::Int8),
std::make_tuple(TR::VectorLength128, TR::Int16),
std::make_tuple(TR::VectorLength128, TR::Int32),
Expand Down

0 comments on commit 497c5e3

Please sign in to comment.