Skip to content

Commit

Permalink
JIT ARM64-SVE: Allow LCL_VARs to store as mask
Browse files Browse the repository at this point in the history
  • Loading branch information
a74nh committed Mar 12, 2024
1 parent 4018d58 commit 6628904
Show file tree
Hide file tree
Showing 12 changed files with 202 additions and 50 deletions.
4 changes: 4 additions & 0 deletions src/coreclr/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,16 @@ function(create_standalone_jit)
if ((TARGETDETAILS_ARCH STREQUAL "x64") OR (TARGETDETAILS_ARCH STREQUAL "arm64") OR ((TARGETDETAILS_ARCH STREQUAL "x86") AND NOT (TARGETDETAILS_OS STREQUAL "unix")))
target_compile_definitions(${TARGETDETAILS_TARGET} PRIVATE FEATURE_SIMD)
target_compile_definitions(${TARGETDETAILS_TARGET} PRIVATE FEATURE_HW_INTRINSICS)
target_compile_definitions(${TARGETDETAILS_TARGET} PRIVATE FEATURE_MASKED_SIMD)
target_compile_definitions(${TARGETDETAILS_TARGET} PRIVATE FEATURE_MASKED_HW_INTRINSICS)
endif ()
endfunction()

if (CLR_CMAKE_TARGET_ARCH_AMD64 OR CLR_CMAKE_TARGET_ARCH_ARM64 OR (CLR_CMAKE_TARGET_ARCH_I386 AND NOT CLR_CMAKE_HOST_UNIX))
add_compile_definitions($<$<NOT:$<BOOL:$<TARGET_PROPERTY:IGNORE_DEFAULT_TARGET_ARCH>>>:FEATURE_SIMD>)
add_compile_definitions($<$<NOT:$<BOOL:$<TARGET_PROPERTY:IGNORE_DEFAULT_TARGET_ARCH>>>:FEATURE_HW_INTRINSICS>)
add_compile_definitions($<$<NOT:$<BOOL:$<TARGET_PROPERTY:IGNORE_DEFAULT_TARGET_ARCH>>>:FEATURE_MASKED_SIMD>)
add_compile_definitions($<$<NOT:$<BOOL:$<TARGET_PROPERTY:IGNORE_DEFAULT_TARGET_ARCH>>>:FEATURE_MASKED_HW_INTRINSICS>)
endif ()

# JIT_BUILD disables certain PAL_TRY debugging features
Expand Down
21 changes: 19 additions & 2 deletions src/coreclr/jit/codegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2771,7 +2771,16 @@ void CodeGen::genCodeForLclVar(GenTreeLclVar* tree)
emitAttr attr = emitActualTypeSize(targetType);

emitter* emit = GetEmitter();
emit->emitIns_R_S(ins, attr, tree->GetRegNum(), varNum, 0);

if (ins == INS_sve_ldr && !varTypeUsesMaskReg(targetType))
{
emit->emitIns_R_S(ins, attr, tree->GetRegNum(), varNum, 0, INS_SCALABLE_OPTS_UNPREDICATED);
}
else
{
emit->emitIns_R_S(ins, attr, tree->GetRegNum(), varNum, 0);
}

genProduceReg(tree);
}
}
Expand Down Expand Up @@ -2956,7 +2965,15 @@ void CodeGen::genCodeForStoreLclVar(GenTreeLclVar* lclNode)
instruction ins = ins_StoreFromSrc(dataReg, targetType);
emitAttr attr = emitActualTypeSize(targetType);

emit->emitIns_S_R(ins, attr, dataReg, varNum, /* offset */ 0);
// TODO-SVE: Removable once REG_V0 and REG_P0 are distinct
if (ins == INS_sve_str && !varTypeUsesMaskReg(targetType))
{
emit->emitIns_S_R(ins, attr, dataReg, varNum, /* offset */ 0, INS_SCALABLE_OPTS_UNPREDICATED);
}
else
{
emit->emitIns_S_R(ins, attr, dataReg, varNum, /* offset */ 0);
}
}
else // store into register (i.e move into register)
{
Expand Down
129 changes: 103 additions & 26 deletions src/coreclr/jit/emitarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17311,13 +17311,19 @@ void emitter::emitIns_S(instruction ins, emitAttr attr, int varx, int offs)
*
* Add an instruction referencing a register and a stack-based local variable.
*/
void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int varx, int offs)
void emitter::emitIns_R_S(instruction ins,
emitAttr attr,
regNumber reg1,
int varx,
int offs,
insScalableOpts sopt /* = INS_SCALABLE_OPTS_NONE */)
{
emitAttr size = EA_SIZE(attr);
insFormat fmt = IF_NONE;
int disp = 0;
unsigned scale = 0;
bool isLdrStr = false;
emitAttr size = EA_SIZE(attr);
insFormat fmt = IF_NONE;
int disp = 0;
unsigned scale = 0;
bool isLdrStr = false;
bool isScalable = false;

assert(offs >= 0);

Expand Down Expand Up @@ -17353,16 +17359,42 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
scale = 0;
break;

case INS_sve_ldr:
assert(isVectorRegister(reg1) || isPredicateRegister(reg1));
isScalable = true;

// TODO-SVE: This should probably be set earlier in the caller
size = EA_SCALABLE;
attr = size;

// TODO-SVE: Use register number instead of enum
if (sopt == INS_SCALABLE_OPTS_UNPREDICATED)
{
fmt = IF_SVE_IE_2A;
// TODO-SVE: Don't assume 128bit vectors
scale = NaturalScale_helper(EA_16BYTE);
}
else
{
assert(insScalableOptsNone(sopt));
fmt = IF_SVE_ID_2A;
// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
scale = NaturalScale_helper(EA_2BYTE);
}
break;

default:
NYI("emitIns_R_S"); // FP locals?
return;

} // end switch (ins)

/* Figure out the variable's frame position */
ssize_t imm;
int base;
bool FPbased;
ssize_t imm;
int base;
bool FPbased;
insFormat scalarfmt = fmt;

base = emitComp->lvaFrameAddress(varx, &FPbased);
disp = base + offs;
Expand All @@ -17387,13 +17419,13 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va

if (imm <= 0x0fff)
{
fmt = IF_DI_2A; // add reg1,reg2,#disp
scalarfmt = IF_DI_2A; // add reg1,reg2,#disp
}
else
{
regNumber rsvdReg = codeGen->rsGetRsvdReg();
codeGen->instGen_Set_Reg_To_Imm(EA_PTRSIZE, rsvdReg, imm);
fmt = IF_DR_3A; // add reg1,reg2,rsvdReg
scalarfmt = IF_DR_3A; // add reg1,reg2,rsvdReg
}
}
else
Expand All @@ -17402,13 +17434,13 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
imm = disp;
if (imm == 0)
{
fmt = IF_LS_2A;
scalarfmt = IF_LS_2A;
}
else if ((imm < 0) || ((imm & mask) != 0))
{
if ((imm >= -256) && (imm <= 255))
{
fmt = IF_LS_2C;
scalarfmt = IF_LS_2C;
}
else
{
Expand All @@ -17417,11 +17449,13 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
}
else if (imm > 0)
{
// TODO: We should be able to scale values <0 for all variants.

if (((imm & mask) == 0) && ((imm >> scale) < 0x1000))
{
imm >>= scale; // The immediate is scaled by the size of the ld/st

fmt = IF_LS_2B;
scalarfmt = IF_LS_2B;
}
else
{
Expand All @@ -17433,10 +17467,15 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
{
regNumber rsvdReg = codeGen->rsGetRsvdReg();
codeGen->instGen_Set_Reg_To_Imm(EA_PTRSIZE, rsvdReg, imm);
fmt = IF_LS_3A;
scalarfmt = IF_LS_3A;
}
}

// Set the format based on the immediate encoding
if (!isScalable)
{
fmt = scalarfmt;
}
assert(fmt != IF_NONE);

// Try to optimize a load/store with an alternative instruction.
Expand Down Expand Up @@ -17564,7 +17603,12 @@ void emitter::emitIns_R_R_S_S(
*
* Add an instruction referencing a stack-based local variable and a register
*/
void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int varx, int offs)
void emitter::emitIns_S_R(instruction ins,
emitAttr attr,
regNumber reg1,
int varx,
int offs,
insScalableOpts sopt /* = INS_SCALABLE_OPTS_NONE */)
{
assert(offs >= 0);
emitAttr size = EA_SIZE(attr);
Expand All @@ -17573,6 +17617,7 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
unsigned scale = 0;
bool isVectorStore = false;
bool isStr = false;
bool isScalable = false;

// TODO-ARM64-CQ: use unscaled loads?
/* Figure out the encoding format of the instruction */
Expand Down Expand Up @@ -17604,6 +17649,31 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
isStr = true;
break;

case INS_sve_str:
assert(isVectorRegister(reg1) || isPredicateRegister(reg1));
isScalable = true;

// TODO-SVE: This should probably be set earlier in the caller
size = EA_SCALABLE;
attr = size;

// TODO-SVE: Use register number instead of enum
if (sopt == INS_SCALABLE_OPTS_UNPREDICATED)
{
fmt = IF_SVE_JH_2A;
// TODO-SVE: Don't assume 128bit vectors
scale = NaturalScale_helper(EA_16BYTE);
}
else
{
assert(insScalableOptsNone(sopt));
fmt = IF_SVE_JG_2A;
// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
scale = NaturalScale_helper(EA_2BYTE);
}
break;

default:
NYI("emitIns_S_R"); // FP locals?
return;
Expand All @@ -17617,7 +17687,7 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
base = emitComp->lvaFrameAddress(varx, &FPbased);
disp = base + offs;
assert(scale >= 0);
if (isVectorStore)
if (isVectorStore || isScalable)
{
assert(scale <= 4);
}
Expand All @@ -17630,18 +17700,19 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
regNumber reg2 = FPbased ? REG_FPBASE : REG_SPBASE;
reg2 = encodingSPtoZR(reg2);

bool useRegForImm = false;
ssize_t imm = disp;
ssize_t mask = (1 << scale) - 1; // the mask of low bits that must be zero to encode the immediate
bool useRegForImm = false;
ssize_t imm = disp;
ssize_t mask = (1 << scale) - 1; // the mask of low bits that must be zero to encode the immediate
insFormat scalarfmt = fmt;
if (imm == 0)
{
fmt = IF_LS_2A;
scalarfmt = IF_LS_2A;
}
else if ((imm < 0) || ((imm & mask) != 0))
{
if ((imm >= -256) && (imm <= 255))
if (isValidSimm9(imm))
{
fmt = IF_LS_2C;
scalarfmt = IF_LS_2C;
}
else
{
Expand All @@ -17650,11 +17721,12 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
}
else if (imm > 0)
{
// TODO: We should be able to scale values <0 for all variants.

if (((imm & mask) == 0) && ((imm >> scale) < 0x1000))
{
imm >>= scale; // The immediate is scaled by the size of the ld/st

fmt = IF_LS_2B;
scalarfmt = IF_LS_2B;
}
else
{
Expand All @@ -17668,9 +17740,14 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
// It is instead implicit when idSetIsLclVar() is set, with this encoding format.
regNumber rsvdReg = codeGen->rsGetRsvdReg();
codeGen->instGen_Set_Reg_To_Imm(EA_PTRSIZE, rsvdReg, imm);
fmt = IF_LS_3A;
scalarfmt = IF_LS_3A;
}

// Set the format based on the immediate encoding
if (!isScalable)
{
fmt = scalarfmt;
}
assert(fmt != IF_NONE);

// Try to optimize a store with an alternative instruction.
Expand Down
6 changes: 4 additions & 2 deletions src/coreclr/jit/emitarm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -1782,7 +1782,8 @@ void emitIns_C(instruction ins, emitAttr attr, CORINFO_FIELD_HANDLE fdlHnd, int

void emitIns_S(instruction ins, emitAttr attr, int varx, int offs);

void emitIns_S_R(instruction ins, emitAttr attr, regNumber ireg, int varx, int offs);
void emitIns_S_R(
instruction ins, emitAttr attr, regNumber ireg, int varx, int offs, insScalableOpts sopt = INS_SCALABLE_OPTS_NONE);

void emitIns_S_S_R_R(
instruction ins, emitAttr attr, emitAttr attr2, regNumber ireg, regNumber ireg2, int varx, int offs);
Expand All @@ -1800,7 +1801,8 @@ void emitIns_R_R_R_I_LdStPair(instruction ins,
int offs2 = -1 DEBUG_ARG(unsigned var1RefsOffs = BAD_IL_OFFSET)
DEBUG_ARG(unsigned var2RefsOffs = BAD_IL_OFFSET));

void emitIns_R_S(instruction ins, emitAttr attr, regNumber ireg, int varx, int offs);
void emitIns_R_S(
instruction ins, emitAttr attr, regNumber ireg, int varx, int offs, insScalableOpts sopt = INS_SCALABLE_OPTS_NONE);

void emitIns_R_R_S_S(
instruction ins, emitAttr attr, emitAttr attr2, regNumber ireg, regNumber ireg2, int varx, int offs);
Expand Down
15 changes: 11 additions & 4 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,11 @@ GenTree* Compiler::getArgForHWIntrinsic(var_types argType,
{
arg = impSIMDPopStack();
}
#if defined(TARGET_ARM64) && defined(FEATURE_MASKED_SIMD)
assert(varTypeIsSIMD(arg) || varTypeIsMask(arg));
#else
assert(varTypeIsSIMD(arg));
#endif // TARGET_ARM64 && FEATURE_MASKED_SIMD
}
else
{
Expand Down Expand Up @@ -1591,13 +1595,16 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
}

#if defined(TARGET_ARM64)

if (HWIntrinsicInfo::IsMaskedOperation(intrinsic))
{
// Op1 input is a vector. HWInstrinsic requires a mask, so convert to a mask.
assert(numArgs > 0);
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
op1 = convertHWIntrinsicToMask(retType, op1, simdBaseJitType, simdSize);
retNode->AsHWIntrinsic()->Op(1) = op1;
GenTree* op1 = retNode->AsHWIntrinsic()->Op(1);
if (op1->TypeGet() != TYP_MASK)
{
// Op1 input is a vector. HWInstrinsic requires a mask.
retNode->AsHWIntrinsic()->Op(1) = convertHWIntrinsicToMask(retType, op1, simdBaseJitType, simdSize);
}
}

if (retType != nodeRetType)
Expand Down
11 changes: 11 additions & 0 deletions src/coreclr/jit/importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6419,6 +6419,17 @@ void Compiler::impImportBlockCode(BasicBlock* block)
impSpillSideEffects(false, CHECK_SPILL_ALL DEBUGARG("Spill before store to pinned local"));
}

#if defined(TARGET_ARM64) && defined(FEATURE_MASKED_SIMD)
// Masks must be converted to vectors before being stored to memory.
// But, for local stores we can optimise away the conversion
if (op1->OperIsHWIntrinsic() && op1->AsHWIntrinsic()->GetHWIntrinsicId() == NI_Sve_ConvertMaskToVector)
{
op1 = op1->AsHWIntrinsic()->Op(1);
lvaTable[lclNum].lvType = TYP_MASK;
lclTyp = lvaGetActualType(lclNum);
}
#endif // TARGET_ARM64 && FEATURE_MASKED_SIMD

op1 = gtNewStoreLclVarNode(lclNum, op1);

// TODO-ASG: delete this zero-diff quirk. Requires some forward substitution work.
Expand Down

0 comments on commit 6628904

Please sign in to comment.