Skip to content

Commit

Permalink
Merge pull request #6610 from Akira1Saitoh/aarch64VectorReduceMin
Browse files Browse the repository at this point in the history
AArch64: Implement evaluator for vector reduction min/max
  • Loading branch information
knn-k committed Jul 13, 2022
2 parents c728c41 + 6dad45a commit 29756eb
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 2 deletions.
2 changes: 2 additions & 0 deletions compiler/aarch64/codegen/OMRCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,8 @@ bool OMR::ARM64::CodeGenerator::getSupportsOpCodeForAutoSIMD(TR::CPU *cpu, TR::I
case TR::vmax:
case TR::vreductionAdd:
case TR::vreductionMul:
case TR::vreductionMax:
case TR::vreductionMin:
return true;
case TR::vand:
case TR::vor:
Expand Down
161 changes: 159 additions & 2 deletions compiler/aarch64/codegen/OMRTreeEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1741,16 +1741,173 @@ OMR::ARM64::TreeEvaluator::vreductionFirstNonZeroEvaluator(TR::Node *node, TR::C
return TR::TreeEvaluator::unImpOpEvaluator(node, cg);
}

/**
 * @brief Emits the instruction sequence that reduces a vector of 8-, 16- or 32-bit
 *        integer elements to its minimum or maximum element.
 *
 * The across-vector min/max instruction leaves its result in a temporary vector
 * register; element 0 of that register is then moved into a general purpose register.
 *
 * @param[in] node: the reduction node; its first child is the source vector
 * @param[in] et: element type of the source vector (Int8, Int16 or Int32)
 * @param[in] isMax: true for a max reduction, false for a min reduction
 * @param[in] cg: CodeGenerator
 * @return general purpose register containing the result
 */
static TR::Register*
vreductionMinMaxIntHelper(TR::Node *node, TR::DataType et, bool isMax, TR::CodeGenerator *cg)
   {
   TR::Node *vectorChild = node->getFirstChild();
   TR::Register *vectorReg = cg->evaluate(vectorChild);

   TR_ASSERT_FATAL_WITH_NODE(node, vectorReg->getKind() == TR_VRF, "unexpected Register kind");

   /* Both opcodes are chosen by element width; any other type is a fatal error. */
   TR::InstOpCode::Mnemonic reduceOp;
   TR::InstOpCode::Mnemonic extractOp;
   switch (et)
      {
      case TR::Int8:
         if (isMax)
            reduceOp = TR::InstOpCode::vsmaxv16b;
         else
            reduceOp = TR::InstOpCode::vsminv16b;
         extractOp = TR::InstOpCode::smovwb;
         break;
      case TR::Int16:
         if (isMax)
            reduceOp = TR::InstOpCode::vsmaxv8h;
         else
            reduceOp = TR::InstOpCode::vsminv8h;
         extractOp = TR::InstOpCode::smovwh;
         break;
      case TR::Int32:
         if (isMax)
            reduceOp = TR::InstOpCode::vsmaxv4s;
         else
            reduceOp = TR::InstOpCode::vsminv4s;
         /* NOTE(review): a 32-bit lane presumably fills the whole W register, hence umov rather than smov */
         extractOp = TR::InstOpCode::umovws;
         break;
      default:
         TR_ASSERT_FATAL_WITH_NODE(node, false, "Unexpected element type");
         break;
      }

   TR::Register *reducedVecReg = cg->allocateRegister(TR_VRF);
   TR::Register *resultReg = cg->allocateRegister(TR_GPR);

   /* Reduce into the scratch vector register, then pull element 0 out into the GPR. */
   generateTrg1Src1Instruction(cg, reduceOp, node, reducedVecReg, vectorReg);
   generateMovVectorElementToGPRInstruction(cg, extractOp, node, resultReg, reducedVecReg, 0);

   cg->stopUsingRegister(reducedVecReg);
   node->setRegister(resultReg);
   cg->decReferenceCount(vectorChild);

   return resultReg;
   }

/**
 * @brief Emits the instruction sequence that reduces a vector of 64-bit integer
 *        elements to its minimum or maximum element.
 *
 * Elements 0 and 1 of the source vector are moved into general purpose registers,
 * compared, and the winner is selected with a conditional select.
 *
 * @param[in] node: the reduction node; its first child is the source vector
 * @param[in] isMax: true for a max reduction, false for a min reduction
 * @param[in] cg: CodeGenerator
 * @return general purpose register containing the result
 */
static TR::Register *
vreductionMinMaxInt64Helper(TR::Node *node, bool isMax, TR::CodeGenerator *cg)
   {
   TR::Node *vectorChild = node->getFirstChild();
   TR::Register *vectorReg = cg->evaluate(vectorChild);

   TR_ASSERT_FATAL_WITH_NODE(node, vectorReg->getKind() == TR_VRF, "unexpected Register kind");

   TR::Register *elem0Reg = cg->allocateRegister(TR_GPR);
   /* Holds element 1 at first, and is overwritten with the final result by the select below. */
   TR::Register *resultReg = cg->allocateRegister(TR_GPR);

   generateMovVectorElementToGPRInstruction(cg, TR::InstOpCode::umovxd, node, elem0Reg, vectorReg, 0);
   generateMovVectorElementToGPRInstruction(cg, TR::InstOpCode::umovxd, node, resultReg, vectorReg, 1);

   /* Compare element 0 against element 1 (NOTE(review): the trailing `true` presumably requests a 64-bit compare). */
   generateCompareInstruction(cg, node, elem0Reg, resultReg, true);

   /* Select element 0 when the condition holds for it, otherwise keep element 1. */
   if (isMax)
      {
      generateCondTrg1Src2Instruction(cg, TR::InstOpCode::cselx, node, resultReg, elem0Reg, resultReg, TR::CC_GT);
      }
   else
      {
      generateCondTrg1Src2Instruction(cg, TR::InstOpCode::cselx, node, resultReg, elem0Reg, resultReg, TR::CC_LT);
      }

   cg->stopUsingRegister(elem0Reg);
   node->setRegister(resultReg);
   cg->decReferenceCount(vectorChild);

   return resultReg;
   }

/**
 * @brief Emits the instruction that reduces a vector of float or double elements
 *        to its minimum or maximum element.
 *
 * A single instruction writes the result straight into a newly allocated floating
 * point register: the across-vector form for Float, the pairwise form for Double.
 *
 * @param[in] node: the reduction node; its first child is the source vector
 * @param[in] et: element type of the source vector (Float or Double)
 * @param[in] isMax: true for a max reduction, false for a min reduction
 * @param[in] cg: CodeGenerator
 * @return floating point register containing the result
 */
static TR::Register*
vreductionMinMaxFloatHelper(TR::Node *node, TR::DataType et, bool isMax, TR::CodeGenerator *cg)
   {
   TR::Node *vectorChild = node->getFirstChild();
   TR::Register *vectorReg = cg->evaluate(vectorChild);

   TR_ASSERT_FATAL_WITH_NODE(node, vectorReg->getKind() == TR_VRF, "unexpected Register kind");

   TR::Register *resultReg = cg->allocateRegister(TR_FPR);

   switch (et)
      {
      case TR::Float:
         {
         TR::InstOpCode::Mnemonic floatOp = isMax ? TR::InstOpCode::vfmaxv4s : TR::InstOpCode::vfminv4s;
         generateTrg1Src1Instruction(cg, floatOp, node, resultReg, vectorReg);
         break;
         }
      case TR::Double:
         {
         TR::InstOpCode::Mnemonic doubleOp = isMax ? TR::InstOpCode::fmaxp2d : TR::InstOpCode::fminp2d;
         generateTrg1Src1Instruction(cg, doubleOp, node, resultReg, vectorReg);
         break;
         }
      default:
         TR_ASSERT_FATAL_WITH_NODE(node, false, "Unexpected element type");
         break;
      }

   node->setRegister(resultReg);
   cg->decReferenceCount(vectorChild);

   return resultReg;
   }

/**
 * @brief Evaluator for vreductionMax: reduces a 128-bit vector to its maximum element.
 *
 * Dispatches on the element type of the first child to the matching min/max
 * reduction helper. Only 128-bit vectors are accepted; any other vector length or
 * an unrecognized element type is a fatal assertion.
 *
 * @param[in] node: the vreductionMax node
 * @param[in] cg: CodeGenerator
 * @return register containing the result (NULL only on the unrecognized-type path)
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vreductionMaxEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   /* The stale unconditional return to unImpOpEvaluator was removed: it left everything below unreachable. */
   TR_ASSERT_FATAL_WITH_NODE(node, node->getFirstChild()->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getFirstChild()->getDataType().toString());

   TR::DataType et = node->getFirstChild()->getDataType().getVectorElementType();
   switch(et)
      {
      case TR::Int8:
      case TR::Int16:
      case TR::Int32:
         return vreductionMinMaxIntHelper(node, et, true, cg);
      case TR::Int64:
         /* SMAXV does not accept 64bit elements */
         return vreductionMinMaxInt64Helper(node, true, cg);
      case TR::Float:
      case TR::Double:
         return vreductionMinMaxFloatHelper(node, et, true, cg);
      default:
         TR_ASSERT_FATAL_WITH_NODE(node, false, "unrecognized vector type %s", node->getFirstChild()->getDataType().toString());
         return NULL;
      }
   }

/**
 * @brief Evaluator for vreductionMin: reduces a 128-bit vector to its minimum element.
 *
 * Dispatches on the element type of the first child to the matching min/max
 * reduction helper. Only 128-bit vectors are accepted; any other vector length or
 * an unrecognized element type is a fatal assertion.
 *
 * @param[in] node: the vreductionMin node
 * @param[in] cg: CodeGenerator
 * @return register containing the result (NULL only on the unrecognized-type path)
 */
TR::Register*
OMR::ARM64::TreeEvaluator::vreductionMinEvaluator(TR::Node *node, TR::CodeGenerator *cg)
   {
   /* The stale unconditional return to unImpOpEvaluator was removed: it left everything below unreachable. */
   TR_ASSERT_FATAL_WITH_NODE(node, node->getFirstChild()->getDataType().getVectorLength() == TR::VectorLength128,
                             "Only 128-bit vectors are supported %s", node->getFirstChild()->getDataType().toString());

   TR::DataType et = node->getFirstChild()->getDataType().getVectorElementType();
   switch(et)
      {
      case TR::Int8:
      case TR::Int16:
      case TR::Int32:
         return vreductionMinMaxIntHelper(node, et, false, cg);
      case TR::Int64:
         /* SMINV does not accept 64bit elements */
         return vreductionMinMaxInt64Helper(node, false, cg);
      case TR::Float:
      case TR::Double:
         return vreductionMinMaxFloatHelper(node, et, false, cg);
      default:
         TR_ASSERT_FATAL_WITH_NODE(node, false, "unrecognized vector type %s", node->getFirstChild()->getDataType().toString());
         return NULL;
      }
   }

/**
Expand Down

0 comments on commit 29756eb

Please sign in to comment.