JIT ARM64-SVE: Add TrueMask and LoadVector (#98218)
* JIT ARM64-SVE: Add TrueMask

Change-Id: I285f8aba668409ca94e11be2489a6d9b50a4ec6e

* LoadVector

Change-Id: I3ad4fd9a8d823cb43a9546ba6356006a0907ac57

* Add SveLoadUnOpMaskedTest.template

* Add CreateTrueMaskByte etc

* Fix up tests

* Remove commented code

* Explain SveMaskPattern

* ARM64-SVE: Implement IF_SVE_BV_2A

* Create vector to/from mask nodes in intrinsic generation

* Add HW_Flag_LowMaskedOperation

* Revert "ARM64-SVE: Implement IF_SVE_BV_2A"

* Use NI_Sve_CreateTrueMaskAll

* Mark API as experimental

* Revert "Mark API as experimental"

This reverts commit 6beb760.

---------

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
a74nh and kunalspathak authored Mar 12, 2024
1 parent cefd7b2 commit 17eb59c
Showing 23 changed files with 989 additions and 29 deletions.
5 changes: 5 additions & 0 deletions src/coreclr/jit/compiler.h
@@ -4554,6 +4554,11 @@ class Compiler
NamedIntrinsic intrinsic, GenTree* immOp, bool mustExpand, int immLowerBound, int immUpperBound);
GenTree* addRangeCheckForHWIntrinsic(GenTree* immOp, int immLowerBound, int immUpperBound);

#if defined(TARGET_ARM64)
GenTree* convertHWIntrinsicToMask(var_types type, GenTree* node, CorInfoType simdBaseJitType, unsigned simdSize);
GenTree* convertHWIntrinsicFromMask(GenTreeHWIntrinsic* node, var_types type);
#endif

#endif // FEATURE_HW_INTRINSICS
GenTree* impArrayAccessIntrinsic(CORINFO_CLASS_HANDLE clsHnd,
CORINFO_SIG_INFO* sig,
34 changes: 34 additions & 0 deletions src/coreclr/jit/emitarm64.cpp
@@ -7303,6 +7303,34 @@ emitter::code_t emitter::emitInsCodeSve(instruction ins, insFormat fmt)
}
}

// For the given 'elemsize' returns the 'arrangement' when used in an SVE vector register arrangement.
// Asserts and returns INS_OPTS_NONE if an invalid 'elemsize' is passed
//
/*static*/ insOpts emitter::optGetSveInsOpt(emitAttr elemsize)
{
switch (elemsize)
{
case EA_1BYTE:
return INS_OPTS_SCALABLE_B;

case EA_2BYTE:
return INS_OPTS_SCALABLE_H;

case EA_4BYTE:
return INS_OPTS_SCALABLE_S;

case EA_8BYTE:
return INS_OPTS_SCALABLE_D;

case EA_16BYTE:
return INS_OPTS_SCALABLE_Q;

default:
assert(!"Invalid emitAttr for sve vector register");
return INS_OPTS_NONE;
}
}

// For the given 'arrangement' returns the 'elemsize' specified by the SVE vector register arrangement
// asserts and returns EA_UNKNOWN if an invalid 'arrangement' value is passed
//
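For context, optGetSveInsOpt is the inverse of the existing optGetSveElemsize helper described above: one maps an element size to a scalable arrangement, the other maps back. A minimal standalone sketch of the pair (stand-in enums, not the JIT's real emitAttr/insOpts types), which round-trips for every legal element size:

#include <cassert>

// Stand-ins for the JIT's emitAttr / insOpts values (illustrative only).
enum Attr { B1 = 1, B2 = 2, B4 = 4, B8 = 8, B16 = 16 };
enum Opt { OPT_B, OPT_H, OPT_S, OPT_D, OPT_Q };

// Element size -> arrangement, mirroring emitter::optGetSveInsOpt.
Opt getSveInsOpt(Attr elemsize)
{
    switch (elemsize)
    {
        case B1: return OPT_B; // byte lanes
        case B2: return OPT_H; // halfword lanes
        case B4: return OPT_S; // word lanes
        case B8: return OPT_D; // doubleword lanes
        default: return OPT_Q; // quadword lanes (B16)
    }
}

// Arrangement -> element size, mirroring the existing optGetSveElemsize.
Attr getSveElemsize(Opt opt)
{
    switch (opt)
    {
        case OPT_B: return B1;
        case OPT_H: return B2;
        case OPT_S: return B4;
        case OPT_D: return B8;
        default: return B16;
    }
}

int main()
{
    const Attr sizes[] = {B1, B2, B4, B8, B16};
    for (Attr a : sizes)
    {
        assert(getSveElemsize(getSveInsOpt(a)) == a);
    }
    return 0;
}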
@@ -13020,6 +13048,12 @@ void emitter::emitIns_R_R_R(instruction ins,
fmt = IF_SVE_HP_3A;
break;

case INS_sve_ld1b:
case INS_sve_ld1h:
case INS_sve_ld1w:
case INS_sve_ld1d:
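// A base-register-only address is emitted as the scalar+immediate form with a zero offset.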
return emitIns_R_R_R_I(ins, size, reg1, reg2, reg3, 0, opt);

default:
unreached();
break;
3 changes: 3 additions & 0 deletions src/coreclr/jit/emitarm64.h
@@ -1138,6 +1138,9 @@ static emitAttr optGetDatasize(insOpts arrangement);
// For the given 'arrangement' returns the 'elemsize' specified by the vector register arrangement
static emitAttr optGetElemsize(insOpts arrangement);

// For the given 'elemsize' returns the 'arrangement' when used in an SVE vector register arrangement.
static insOpts optGetSveInsOpt(emitAttr elemsize);

// For the given 'arrangement' returns the 'elemsize' specified by the SVE vector register arrangement
static emitAttr optGetSveElemsize(insOpts arrangement);

5 changes: 4 additions & 1 deletion src/coreclr/jit/gentree.cpp
@@ -26067,9 +26067,12 @@ bool GenTreeHWIntrinsic::OperIsMemoryLoad(GenTree** pAddr) const
case NI_AdvSimd_Arm64_LoadAndInsertScalarVector128x2:
case NI_AdvSimd_Arm64_LoadAndInsertScalarVector128x3:
case NI_AdvSimd_Arm64_LoadAndInsertScalarVector128x4:

addr = Op(3);
break;

case NI_Sve_LoadVector:
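// Op(1) is the governing mask; the address is the second operand.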
addr = Op(2);
break;
#endif // TARGET_ARM64

default:
49 changes: 39 additions & 10 deletions src/coreclr/jit/hwintrinsic.cpp
@@ -1356,6 +1356,15 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
compFloatingPointUsed = true;
}

var_types nodeRetType = retType;
#if defined(TARGET_ARM64)
if (HWIntrinsicInfo::ReturnsPerElementMask(intrinsic))
{
// Ensure the result is produced as a mask.
nodeRetType = TYP_MASK;
}
#endif // defined(TARGET_ARM64)

// table-driven importer of simple intrinsics
if (impIsTableDrivenHWIntrinsic(intrinsic, category))
{
@@ -1392,7 +1401,7 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
case 0:
{
assert(!isScalar);
retNode = gtNewSimdHWIntrinsicNode(retType, intrinsic, simdBaseJitType, simdSize);
retNode = gtNewSimdHWIntrinsicNode(nodeRetType, intrinsic, simdBaseJitType, simdSize);
break;
}

@@ -1410,8 +1419,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
}

retNode = isScalar ? gtNewScalarHWIntrinsicNode(retType, op1, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, intrinsic, simdBaseJitType, simdSize);
retNode = isScalar ? gtNewScalarHWIntrinsicNode(nodeRetType, op1, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, intrinsic, simdBaseJitType, simdSize);

#if defined(TARGET_XARCH)
switch (intrinsic)
@@ -1462,8 +1471,9 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

retNode = isScalar ? gtNewScalarHWIntrinsicNode(retType, op1, op2, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, op2, intrinsic, simdBaseJitType, simdSize);
retNode = isScalar
? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, intrinsic, simdBaseJitType, simdSize);

#ifdef TARGET_XARCH
if ((intrinsic == NI_SSE42_Crc32) || (intrinsic == NI_SSE42_X64_Crc32))
@@ -1543,9 +1553,9 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
op3 = addRangeCheckIfNeeded(intrinsic, op3, mustExpand, immLowerBound, immUpperBound);
}

retNode = isScalar
? gtNewScalarHWIntrinsicNode(retType, op1, op2, op3, intrinsic)
: gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
retNode = isScalar ? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, op3, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, intrinsic, simdBaseJitType,
simdSize);

#ifdef TARGET_XARCH
if ((intrinsic == NI_AVX2_GatherVector128) || (intrinsic == NI_AVX2_GatherVector256))
@@ -1566,7 +1576,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

assert(!isScalar);
retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
retNode =
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
break;
}

@@ -1576,8 +1587,26 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}
else
{
retNode = impSpecialIntrinsic(intrinsic, clsHnd, method, sig, simdBaseJitType, retType, simdSize);
retNode = impSpecialIntrinsic(intrinsic, clsHnd, method, sig, simdBaseJitType, nodeRetType, simdSize);
}

#if defined(TARGET_ARM64)
if (HWIntrinsicInfo::IsMaskedOperation(intrinsic))
{
// Op1 input is a vector. HWIntrinsic requires a mask, so convert to a mask.
assert(numArgs > 0);
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
op1 = convertHWIntrinsicToMask(retType, op1, simdBaseJitType, simdSize);
retNode->AsHWIntrinsic()->Op(1) = op1;
}

if (retType != nodeRetType)
{
// HWIntrinsic returns a mask, but all returns must be vectors, so convert mask to vector.
assert(HWIntrinsicInfo::ReturnsPerElementMask(intrinsic));
retNode = convertHWIntrinsicFromMask(retNode->AsHWIntrinsic(), retType);
}
#endif // defined(TARGET_ARM64)

if ((retNode != nullptr) && retNode->OperIs(GT_HWINTRINSIC))
{
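Taken together, the two ARM64 blocks above do the mask plumbing at import time: intrinsics flagged IsMaskedOperation get their first operand rewritten from a vector into a TYP_MASK node, and intrinsics flagged ReturnsPerElementMask are built as TYP_MASK and wrapped back into a vector for the managed caller. A self-contained sketch of the resulting tree shapes (toy node type, not the JIT's GenTree; leaks ignored for brevity):

#include <cassert>
#include <vector>

// Toy stand-ins for the importer's types (illustrative only).
enum VarType { TYP_SIMD16, TYP_MASK };

struct Node
{
    VarType            type;
    const char*        name;
    std::vector<Node*> ops;
};

// Vector -> mask: wrap in ConvertVectorToMask with an all-true embedded mask
// (mirrors convertHWIntrinsicToMask, defined later in this commit).
Node* toMask(Node* vec)
{
    Node* allTrue = new Node{TYP_MASK, "CreateTrueMaskAll", {}};
    return new Node{TYP_MASK, "ConvertVectorToMask", {allTrue, vec}};
}

// Mask -> vector (mirrors convertHWIntrinsicFromMask).
Node* fromMask(Node* mask)
{
    assert(mask->type == TYP_MASK);
    return new Node{TYP_SIMD16, "ConvertMaskToVector", {mask}};
}

int main()
{
    // Sve.LoadVector(mask, addr): the managed signature passes the mask as a
    // vector, so the importer rewrites Op(1) into a TYP_MASK node.
    // (Operand types other than the mask's are irrelevant to this sketch.)
    Node* maskAsVec = new Node{TYP_SIMD16, "lclVar_mask", {}};
    Node* addr      = new Node{TYP_SIMD16, "lclVar_addr", {}};
    Node* load      = new Node{TYP_SIMD16, "Sve_LoadVector", {toMask(maskAsVec), addr}};
    assert(load->ops[0]->type == TYP_MASK);

    // Sve.CreateTrueMask*: built as TYP_MASK (ReturnsPerElementMask), then
    // converted back because managed callers expect a vector.
    Node* trueMask = new Node{TYP_MASK, "Sve_CreateTrueMaskInt32", {}};
    Node* result   = fromMask(trueMask);
    assert(result->type == TYP_SIMD16);
    return 0;
}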
41 changes: 37 additions & 4 deletions src/coreclr/jit/hwintrinsic.h
@@ -58,6 +58,7 @@ enum HWIntrinsicCategory : uint8_t
HW_Category_ShiftLeftByImmediate,
HW_Category_ShiftRightByImmediate,
HW_Category_SIMDByIndexedElement,
HW_Category_EnumPattern,

// Helper intrinsics
// - do not directly correspond to an instruction, such as Vector64.AllBitsSet
@@ -175,6 +176,21 @@ enum HWIntrinsicFlag : unsigned int

// The intrinsic needs consecutive registers
HW_Flag_NeedsConsecutiveRegisters = 0x4000,

// The intrinsic uses scalable registers
HW_Flag_Scalable = 0x8000,

// Returns Per-Element Mask
// the intrinsic returns a vector containing elements that are either "all bits set" or "all bits clear"
// this output can be used as a per-element mask
HW_Flag_ReturnsPerElementMask = 0x10000,

// The intrinsic uses a mask in arg1 to select elements present in the result
HW_Flag_MaskedOperation = 0x20000,

// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low register.
HW_Flag_LowMaskedOperation = 0x40000,

#else
#error Unsupported platform
#endif
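For intuition, a "per-element mask" in vector form is a vector whose lanes are all ones or all zeros, typically produced by a lane-wise compare. A plain C++ sketch, independent of the JIT:

#include <cstdint>
#include <cstdio>

int main()
{
    // A per-element mask: each lane is either all bits set or all bits clear,
    // here produced by a lane-wise "less than" compare.
    uint32_t a[4] = {1, 5, 3, 7};
    uint32_t b[4] = {4, 4, 4, 4};
    uint32_t mask[4];

    for (int i = 0; i < 4; i++)
    {
        mask[i] = (a[i] < b[i]) ? 0xFFFFFFFFu : 0u;
    }

    for (int i = 0; i < 4; i++)
    {
        printf("%08X ", mask[i]); // prints: FFFFFFFF 00000000 FFFFFFFF 00000000
    }
    printf("\n");
    return 0;
}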
@@ -654,10 +670,8 @@ struct HWIntrinsicInfo
static bool ReturnsPerElementMask(NamedIntrinsic id)
{
HWIntrinsicFlag flags = lookupFlags(id);
#if defined(TARGET_XARCH)
#if defined(TARGET_XARCH) || defined(TARGET_ARM64)
return (flags & HW_Flag_ReturnsPerElementMask) != 0;
#elif defined(TARGET_ARM64)
unreached();
#else
#error Unsupported platform
#endif
@@ -848,6 +862,25 @@ struct HWIntrinsicInfo
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_HasImmediateOperand) != 0;
}

static bool IsScalable(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_Scalable) != 0;
}

static bool IsMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return ((flags & HW_Flag_MaskedOperation) != 0) || IsLowMaskedOperation(id);
}

static bool IsLowMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_LowMaskedOperation) != 0;
}

#endif // TARGET_ARM64

static bool HasSpecialSideEffect(NamedIntrinsic id)
Expand Down Expand Up @@ -907,7 +940,7 @@ struct HWIntrinsic final
InitializeBaseType(node);
}

bool IsTableDriven() const
bool codeGenIsTableDriven() const
{
// TODO-Arm64-Cleanup - make more categories to the table-driven framework
bool isTableDrivenCategory = category != HW_Category_Helper;
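A small illustration of how the new accessors compose: a low-masked entry also reports as masked, because IsMaskedOperation folds in IsLowMaskedOperation (flag values copied from above; the table entry is hypothetical):

#include <cassert>

enum : unsigned
{
    Flag_MaskedOperation    = 0x20000,
    Flag_LowMaskedOperation = 0x40000,
};

bool isLowMasked(unsigned flags)
{
    return (flags & Flag_LowMaskedOperation) != 0;
}

// "Masked" covers both variants; "low" additionally constrains the governing
// predicate to a low predicate register.
bool isMasked(unsigned flags)
{
    return ((flags & Flag_MaskedOperation) != 0) || isLowMasked(flags);
}

int main()
{
    unsigned entry = Flag_LowMaskedOperation; // hypothetical intrinsic entry
    assert(isMasked(entry) && isLowMasked(entry));
    return 0;
}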
54 changes: 54 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
@@ -280,6 +280,20 @@ void HWIntrinsicInfo::lookupImmBounds(
immUpperBound = Compiler::getSIMDVectorLength(simdSize, baseType) - 1;
break;

case NI_Sve_CreateTrueMaskByte:
case NI_Sve_CreateTrueMaskDouble:
case NI_Sve_CreateTrueMaskInt16:
case NI_Sve_CreateTrueMaskInt32:
case NI_Sve_CreateTrueMaskInt64:
case NI_Sve_CreateTrueMaskSByte:
case NI_Sve_CreateTrueMaskSingle:
case NI_Sve_CreateTrueMaskUInt16:
case NI_Sve_CreateTrueMaskUInt32:
case NI_Sve_CreateTrueMaskUInt64:
immLowerBound = (int)SVE_PATTERN_POW2;
immUpperBound = (int)SVE_PATTERN_ALL;
break;

default:
unreached();
}
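The bounds correspond to the SVE ptrue pattern operand, whose encodings run from POW2 (0) to ALL (31), with fixed-count patterns such as VL1-VL256 and MUL3/MUL4 in between. A sketch of the endpoint values (names mirror the emitter's constants; values follow the Arm encoding):

#include <cassert>

// Endpoints of the SVE predicate pattern immediate (Arm encoding); the values
// in between select fixed element counts (VL1-VL256) and multiples (MUL3, MUL4).
enum SvePattern
{
    SVE_PATTERN_POW2 = 0,
    SVE_PATTERN_ALL  = 31,
};

int main()
{
    // CreateTrueMask* takes the pattern as an immediate, so the importer
    // range-checks it against [POW2, ALL].
    int imm = SVE_PATTERN_ALL;
    assert((imm >= SVE_PATTERN_POW2) && (imm <= SVE_PATTERN_ALL));
    return 0;
}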
@@ -2179,6 +2193,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
break;
}

default:
{
return nullptr;
Expand All @@ -2188,4 +2203,43 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
return retNode;
}

//------------------------------------------------------------------------
// convertHWIntrinsicToMask: Convert a HW intrinsic vector node to a mask
//
// Arguments:
// type -- The type of the node to convert
// node -- The node to convert
// simdBaseJitType -- the base jit type of the converted node
// simdSize -- the simd size of the converted node
//
// Return Value:
// The node converted to a mask type
//
GenTree* Compiler::convertHWIntrinsicToMask(var_types type,
GenTree* node,
CorInfoType simdBaseJitType,
unsigned simdSize)
{
// ConvertVectorToMask uses cmpne, which requires an embedded mask.
GenTree* embeddedMask = gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(TYP_MASK, embeddedMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType,
simdSize);
}

//------------------------------------------------------------------------
// convertHWIntrinsicFromMask: Convert a HW intrinsic mask node to a vector
//
// Arguments:
// node -- The node to convert
// type -- The type of the node to convert to
//
// Return Value:
// The node converted to the given type
//
GenTree* Compiler::convertHWIntrinsicFromMask(GenTreeHWIntrinsic* node, var_types type)
{
assert(node->TypeGet() == TYP_MASK);
return gtNewSimdHWIntrinsicNode(type, node, NI_Sve_ConvertMaskToVector, node->GetSimdBaseJitType(),
node->GetSimdSize());
}

#endif // FEATURE_HW_INTRINSICS
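As for what the ConvertVectorToMask/cmpne pairing computes: under an all-true governing predicate (what CreateTrueMaskAll supplies), lane i of the resulting mask is simply "element i is non-zero". A scalar model in plain C++:

#include <cstdint>
#include <cstdio>

int main()
{
    // Scalar model of cmpne under a governing predicate: the destination
    // predicate bit is set for each active lane whose element is non-zero.
    uint32_t v[4]       = {0xFFFFFFFF, 0, 0xFFFFFFFF, 0}; // a vector-shaped mask
    bool     allTrue[4] = {true, true, true, true};       // governing predicate
    bool     mask[4];

    for (int i = 0; i < 4; i++)
    {
        mask[i] = allTrue[i] && (v[i] != 0); // cmpne <Pd>, <Pg>/z, <Zn>, #0
    }

    for (int i = 0; i < 4; i++)
    {
        printf("%d ", mask[i] ? 1 : 0); // prints: 1 0 1 0
    }
    printf("\n");
    return 0;
}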