Skip to content

Commit

Permalink
Updating Unix to save/restore Avx512 state
Browse files Browse the repository at this point in the history
  • Loading branch information
tannergooding committed Mar 22, 2023
1 parent 12e9711 commit aa6e6b2
Show file tree
Hide file tree
Showing 10 changed files with 514 additions and 113 deletions.
10 changes: 5 additions & 5 deletions src/coreclr/nativeaot/Runtime/windows/PalRedhawkMinWin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,17 @@ REDHAWK_PALEXPORT CONTEXT* PalAllocateCompleteOSContext(_Out_ uint8_t** contextB
}
#endif //TARGET_X86

// Determine if the processor supports AVX so we could
// Determine if the processor supports AVX or AVX512 so we could
// retrieve extended registers
DWORD64 FeatureMask = GetEnabledXStateFeatures();
if ((FeatureMask & XSTATE_MASK_AVX) != 0)
if ((FeatureMask & (XSTATE_MASK_AVX | XSTATE_MASK_AVX512)) != 0)
{
context = context | CONTEXT_XSTATE;
}

// Retrieve contextSize by passing NULL for Buffer
DWORD contextSize = 0;
ULONG64 xStateCompactionMask = XSTATE_MASK_LEGACY | XSTATE_MASK_AVX;
ULONG64 xStateCompactionMask = XSTATE_MASK_LEGACY | XSTATE_MASK_AVX | XSTATE_MASK_MPX | XSTATE_MASK_AVX512;
// The initialize call should fail but return contextSize
BOOL success = pfnInitializeContext2 ?
pfnInitializeContext2(NULL, context, NULL, &contextSize, xStateCompactionMask) :
Expand Down Expand Up @@ -426,9 +426,9 @@ REDHAWK_PALEXPORT _Success_(return) bool REDHAWK_PALAPI PalGetCompleteThreadCont
#if defined(TARGET_X86) || defined(TARGET_AMD64)
// Make sure that AVX feature mask is set, if supported. This should not normally fail.
// The system silently ignores any feature specified in the FeatureMask which is not enabled on the processor.
if (!SetXStateFeaturesMask(pCtx, XSTATE_MASK_AVX))
if (!SetXStateFeaturesMask(pCtx, XSTATE_MASK_AVX | XSTATE_MASK_AVX512))
{
_ASSERTE(!"Could not apply XSTATE_MASK_AVX");
_ASSERTE(!"Could not apply XSTATE_MASK_AVX | XSTATE_MASK_AVX512");
return FALSE;
}
#endif //defined(TARGET_X86) || defined(TARGET_AMD64)
Expand Down
113 changes: 113 additions & 0 deletions src/coreclr/pal/inc/pal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,12 @@ QueueUserAPC(

#ifdef HOST_X86

// MSVC directly defines intrinsics for __cpuid and __cpuidex matching the below signatures
// We define matching signatures for use on Unix platforms.
//
// Only the prototypes are provided here; the Unix definitions live outside this header.
//   cpuInfo        - caller-supplied array of four ints that receives the CPUID output
//   function_id    - CPUID function (leaf) to query
//   subFunction_id - CPUID sub-function (sub-leaf) to query; __cpuidex only
// NOTE(review): per the MSVC intrinsic documentation cpuInfo is filled with EAX, EBX, ECX, EDX
// in that order - confirm the Unix implementation follows the same contract.

extern "C" void __cpuid(int cpuInfo[4], int function_id);
extern "C" void __cpuidex(int cpuInfo[4], int function_id, int subFunction_id);

//
// ***********************************************************************************
//
Expand Down Expand Up @@ -1461,6 +1467,13 @@ typedef struct _KNONVOLATILE_CONTEXT_POINTERS {
//

#elif defined(HOST_AMD64)

// MSVC directly defines intrinsics for __cpuid and __cpuidex matching the below signatures
// We define matching signatures for use on Unix platforms.
//
// Only the prototypes are provided here; the Unix definitions live outside this header.
//   cpuInfo        - caller-supplied array of four ints that receives the CPUID output
//   function_id    - CPUID function (leaf) to query
//   subFunction_id - CPUID sub-function (sub-leaf) to query; __cpuidex only
// NOTE(review): per the MSVC intrinsic documentation cpuInfo is filled with EAX, EBX, ECX, EDX
// in that order - confirm the Unix implementation follows the same contract.

extern "C" void __cpuid(int cpuInfo[4], int function_id);
extern "C" void __cpuidex(int cpuInfo[4], int function_id, int subFunction_id);

// copied from winnt.h

#define CONTEXT_AMD64 0x100000
Expand All @@ -1482,11 +1495,33 @@ typedef struct _KNONVOLATILE_CONTEXT_POINTERS {
#define CONTEXT_EXCEPTION_REQUEST 0x40000000
#define CONTEXT_EXCEPTION_REPORTING 0x80000000

// XSTATE feature indices. Each value is the bit position of the feature inside an
// XSTATE feature mask (the XSTATE_MASK_* macros below shift 1 by these values).
// They correspond to the XSTATE sections appended to CONTEXT:
//   XSTATE_AVX          - Ymm0H..Ymm15H
//   XSTATE_AVX512_KMASK - KMask0..KMask7
//   XSTATE_AVX512_ZMM_H - Zmm0H..Zmm15H
//   XSTATE_AVX512_ZMM   - Zmm16..Zmm31
#define XSTATE_GSSE (2)
#define XSTATE_AVX (XSTATE_GSSE)
#define XSTATE_AVX512_KMASK (5)
#define XSTATE_AVX512_ZMM_H (6)
#define XSTATE_AVX512_ZMM (7)

// 64-bit feature masks built from the indices above. XSTATE_MASK_AVX512 is the union of
// the three AVX-512 components (bits 5, 6 and 7).
// NOTE(review): the `ui64` literal suffix is an MSVC extension - confirm that every
// compiler used to build the Unix PAL accepts it.
#define XSTATE_MASK_GSSE (1ui64 << (XSTATE_GSSE))
#define XSTATE_MASK_AVX (XSTATE_MASK_GSSE)
#define XSTATE_MASK_AVX512 ((1ui64 << (XSTATE_AVX512_KMASK)) | \
(1ui64 << (XSTATE_AVX512_ZMM_H)) | \
(1ui64 << (XSTATE_AVX512_ZMM)))

// Raw vector-register storage types used by CONTEXT. Each wider type is composed of a
// Low and a High half of the next narrower type, and is declared with a DECLSPEC_ALIGN
// value equal to its own size (16, 32 and 64 bytes respectively).

// 128-bit value: two 64-bit halves. Used in CONTEXT for the upper halves of the YMM
// registers (Ymm0H..Ymm15H).
typedef struct DECLSPEC_ALIGN(16) _M128A {
ULONGLONG Low;
LONGLONG High;
} M128A, *PM128A;

// 256-bit value: two M128A halves. Used in CONTEXT for the upper halves of the first
// sixteen ZMM registers (Zmm0H..Zmm15H).
typedef struct DECLSPEC_ALIGN(32) _M256A {
M128A Low;
M128A High;
} M256A, *PM256A;

// 512-bit value: two M256A halves. Used in CONTEXT for the full contents of
// Zmm16..Zmm31.
typedef struct DECLSPEC_ALIGN(64) _M512A {
M256A Low;
M256A High;
} M512A, *PM512A;

typedef struct _XMM_SAVE_AREA32 {
WORD ControlWord;
WORD StatusWord;
Expand Down Expand Up @@ -1623,6 +1658,84 @@ typedef struct DECLSPEC_ALIGN(16) _CONTEXT {
DWORD64 LastBranchFromRip;
DWORD64 LastExceptionToRip;
DWORD64 LastExceptionFromRip;

// XSTATE
DWORD64 XStateFeaturesMask;
DWORD64 XStateReserved0;

// XSTATE_AVX
struct {
M128A Ymm0H;
M128A Ymm1H;
M128A Ymm2H;
M128A Ymm3H;
M128A Ymm4H;
M128A Ymm5H;
M128A Ymm6H;
M128A Ymm7H;
M128A Ymm8H;
M128A Ymm9H;
M128A Ymm10H;
M128A Ymm11H;
M128A Ymm12H;
M128A Ymm13H;
M128A Ymm14H;
M128A Ymm15H;
};

// XSTATE_AVX512_KMASK
struct {
DWORD64 KMask0;
DWORD64 KMask1;
DWORD64 KMask2;
DWORD64 KMask3;
DWORD64 KMask4;
DWORD64 KMask5;
DWORD64 KMask6;
DWORD64 KMask7;
};

// XSTATE_AVX512_ZMM_H
struct {
M256A Zmm0H;
M256A Zmm1H;
M256A Zmm2H;
M256A Zmm3H;
M256A Zmm4H;
M256A Zmm5H;
M256A Zmm6H;
M256A Zmm7H;
M256A Zmm8H;
M256A Zmm9H;
M256A Zmm10H;
M256A Zmm11H;
M256A Zmm12H;
M256A Zmm13H;
M256A Zmm14H;
M256A Zmm15H;
};

DWORD64 XStateReserved1[4];

// XSTATE_AVX512_ZMM
struct {
M512A Zmm16;
M512A Zmm17;
M512A Zmm18;
M512A Zmm19;
M512A Zmm20;
M512A Zmm21;
M512A Zmm22;
M512A Zmm23;
M512A Zmm24;
M512A Zmm25;
M512A Zmm26;
M512A Zmm27;
M512A Zmm28;
M512A Zmm29;
M512A Zmm30;
M512A Zmm31;
};
} CONTEXT, *PCONTEXT, *LPCONTEXT;

//
Expand Down
22 changes: 17 additions & 5 deletions src/coreclr/pal/src/arch/amd64/asmconstants.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

#ifdef HOST_64BIT

// Feature bits tested against CONTEXT_XStateFeaturesMask by the context-restore assembly.
// The bit numbers match XSTATE_AVX512_KMASK (5), XSTATE_AVX512_ZMM_H (6) and
// XSTATE_AVX512_ZMM (7) in pal.h.
#define XFEATURE_MASK_OPMASK (1 << 5)
#define XFEATURE_MASK_ZMM_Hi256 (1 << 6)
#define XFEATURE_MASK_Hi16_ZMM (1 << 7)
// Union of the three AVX-512 components. The value is 0xE0, so it fits in the single
// byte that the assembly inspects with `test BYTE PTR`.
#define XFEATURE_MASK_AVX512 (XFEATURE_MASK_OPMASK | XFEATURE_MASK_ZMM_Hi256 | XFEATURE_MASK_Hi16_ZMM)

// The arch bit is normally set in the flag constants below. Since this is already arch-specific code and the arch bit is not
// relevant, the arch bit is excluded from the flag constants below for simpler tests.
#define CONTEXT_AMD64 0x100000
Expand All @@ -17,7 +22,7 @@

#define CONTEXT_XSTATE 64

#define CONTEXT_ContextFlags 6*8
#define CONTEXT_ContextFlags (6*8)
#define CONTEXT_SegCs CONTEXT_ContextFlags+8
#define CONTEXT_SegDs CONTEXT_SegCs+2
#define CONTEXT_SegEs CONTEXT_SegDs+2
Expand Down Expand Up @@ -49,8 +54,8 @@
#define CONTEXT_R15 CONTEXT_R14+8
#define CONTEXT_Rip CONTEXT_R15+8
#define CONTEXT_FltSave CONTEXT_Rip+8
#define FLOATING_SAVE_AREA_SIZE 4*8+24*16+96
#define CONTEXT_Xmm0 CONTEXT_FltSave+10*16
#define FLOATING_SAVE_AREA_SIZE (4*8)+(24*16)+96
#define CONTEXT_Xmm0 CONTEXT_FltSave+(10*16)
#define CONTEXT_Xmm1 CONTEXT_Xmm0+16
#define CONTEXT_Xmm2 CONTEXT_Xmm1+16
#define CONTEXT_Xmm3 CONTEXT_Xmm2+16
Expand All @@ -67,13 +72,20 @@
#define CONTEXT_Xmm14 CONTEXT_Xmm13+16
#define CONTEXT_Xmm15 CONTEXT_Xmm14+16
#define CONTEXT_VectorRegister CONTEXT_FltSave+FLOATING_SAVE_AREA_SIZE
#define CONTEXT_VectorControl CONTEXT_VectorRegister+16*26
#define CONTEXT_VectorControl CONTEXT_VectorRegister+(16*26)
#define CONTEXT_DebugControl CONTEXT_VectorControl+8
#define CONTEXT_LastBranchToRip CONTEXT_DebugControl+8
#define CONTEXT_LastBranchFromRip CONTEXT_LastBranchToRip+8
#define CONTEXT_LastExceptionToRip CONTEXT_LastBranchFromRip+8
#define CONTEXT_LastExceptionFromRip CONTEXT_LastExceptionToRip+8
#define CONTEXT_Size CONTEXT_LastExceptionFromRip+8
// Byte offsets of the XSTATE area appended to CONTEXT. Every offset is the previous
// field's offset plus the previous field's size, mirroring the field order of the
// CONTEXT struct in pal.h:
//   XStateFeaturesMask : DWORD64        ( 8 bytes)
//   XStateReserved0    : DWORD64        ( 8 bytes)
//   Ymm0H..Ymm15H      : 16 x M128A     (16*16 bytes)
//   KMask0..KMask7     :  8 x DWORD64   ( 8*8  bytes)
//   Zmm0H..Zmm15H      : 16 x M256A     (32*16 bytes)
//   XStateReserved1    :  4 x DWORD64   ( 8*4  bytes)
//   Zmm16..Zmm31       : 16 x M512A     (64*16 bytes)
#define CONTEXT_XStateFeaturesMask CONTEXT_LastExceptionFromRip+8
#define CONTEXT_XStateReserved0 CONTEXT_XStateFeaturesMask+8
#define CONTEXT_Ymm0H CONTEXT_XStateReserved0+8
#define CONTEXT_KMask0 CONTEXT_Ymm0H+(16*16)
// Zmm0H follows the eight KMask registers, so it must be based on CONTEXT_KMask0.
// Basing it on CONTEXT_Ymm0H would place it inside the Ymm block and shift every
// later offset (and CONTEXT_Size) by 192 bytes.
#define CONTEXT_Zmm0H CONTEXT_KMask0+(8*8)
#define CONTEXT_XStateReserved1 CONTEXT_Zmm0H+(32*16)
#define CONTEXT_Zmm16 CONTEXT_XStateReserved1+(8*4)
#define CONTEXT_Size CONTEXT_Zmm16+(64*16)

#else // HOST_64BIT

Expand Down
84 changes: 67 additions & 17 deletions src/coreclr/pal/src/arch/amd64/context2.S
Original file line number Diff line number Diff line change
Expand Up @@ -107,23 +107,73 @@ LOCAL_LABEL(Done_Restore_CONTEXT_FLOATING_POINT):
test BYTE PTR [rdi + CONTEXT_ContextFlags], CONTEXT_XSTATE
je LOCAL_LABEL(Done_Restore_CONTEXT_XSTATE)

// Restore the extended state (for now, this is just the upper halves of YMM registers)
vinsertf128 ymm0, ymm0, xmmword ptr [rdi + (CONTEXT_VectorRegister + 0 * 16)], 1
vinsertf128 ymm1, ymm1, xmmword ptr [rdi + (CONTEXT_VectorRegister + 1 * 16)], 1
vinsertf128 ymm2, ymm2, xmmword ptr [rdi + (CONTEXT_VectorRegister + 2 * 16)], 1
vinsertf128 ymm3, ymm3, xmmword ptr [rdi + (CONTEXT_VectorRegister + 3 * 16)], 1
vinsertf128 ymm4, ymm4, xmmword ptr [rdi + (CONTEXT_VectorRegister + 4 * 16)], 1
vinsertf128 ymm5, ymm5, xmmword ptr [rdi + (CONTEXT_VectorRegister + 5 * 16)], 1
vinsertf128 ymm6, ymm6, xmmword ptr [rdi + (CONTEXT_VectorRegister + 6 * 16)], 1
vinsertf128 ymm7, ymm7, xmmword ptr [rdi + (CONTEXT_VectorRegister + 7 * 16)], 1
vinsertf128 ymm8, ymm8, xmmword ptr [rdi + (CONTEXT_VectorRegister + 8 * 16)], 1
vinsertf128 ymm9, ymm9, xmmword ptr [rdi + (CONTEXT_VectorRegister + 9 * 16)], 1
vinsertf128 ymm10, ymm10, xmmword ptr [rdi + (CONTEXT_VectorRegister + 10 * 16)], 1
vinsertf128 ymm11, ymm11, xmmword ptr [rdi + (CONTEXT_VectorRegister + 11 * 16)], 1
vinsertf128 ymm12, ymm12, xmmword ptr [rdi + (CONTEXT_VectorRegister + 12 * 16)], 1
vinsertf128 ymm13, ymm13, xmmword ptr [rdi + (CONTEXT_VectorRegister + 13 * 16)], 1
vinsertf128 ymm14, ymm14, xmmword ptr [rdi + (CONTEXT_VectorRegister + 14 * 16)], 1
vinsertf128 ymm15, ymm15, xmmword ptr [rdi + (CONTEXT_VectorRegister + 15 * 16)], 1
// Restore the YMM state
vinsertf128 ymm0, ymm0, xmmword ptr [rdi + (CONTEXT_Ymm0H + 0 * 16)], 1
vinsertf128 ymm1, ymm1, xmmword ptr [rdi + (CONTEXT_Ymm0H + 1 * 16)], 1
vinsertf128 ymm2, ymm2, xmmword ptr [rdi + (CONTEXT_Ymm0H + 2 * 16)], 1
vinsertf128 ymm3, ymm3, xmmword ptr [rdi + (CONTEXT_Ymm0H + 3 * 16)], 1
vinsertf128 ymm4, ymm4, xmmword ptr [rdi + (CONTEXT_Ymm0H + 4 * 16)], 1
vinsertf128 ymm5, ymm5, xmmword ptr [rdi + (CONTEXT_Ymm0H + 5 * 16)], 1
vinsertf128 ymm6, ymm6, xmmword ptr [rdi + (CONTEXT_Ymm0H + 6 * 16)], 1
vinsertf128 ymm7, ymm7, xmmword ptr [rdi + (CONTEXT_Ymm0H + 7 * 16)], 1
vinsertf128 ymm8, ymm8, xmmword ptr [rdi + (CONTEXT_Ymm0H + 8 * 16)], 1
vinsertf128 ymm9, ymm9, xmmword ptr [rdi + (CONTEXT_Ymm0H + 9 * 16)], 1
vinsertf128 ymm10, ymm10, xmmword ptr [rdi + (CONTEXT_Ymm0H + 10 * 16)], 1
vinsertf128 ymm11, ymm11, xmmword ptr [rdi + (CONTEXT_Ymm0H + 11 * 16)], 1
vinsertf128 ymm12, ymm12, xmmword ptr [rdi + (CONTEXT_Ymm0H + 12 * 16)], 1
vinsertf128 ymm13, ymm13, xmmword ptr [rdi + (CONTEXT_Ymm0H + 13 * 16)], 1
vinsertf128 ymm14, ymm14, xmmword ptr [rdi + (CONTEXT_Ymm0H + 14 * 16)], 1
vinsertf128 ymm15, ymm15, xmmword ptr [rdi + (CONTEXT_Ymm0H + 15 * 16)], 1

test BYTE PTR [rdi + CONTEXT_XStateFeaturesMask], XFEATURE_MASK_AVX512
je LOCAL_LABEL(Done_Restore_CONTEXT_XSTATE)

// Restore the Opmask state
kmovq k0, qword ptr [rdi + (CONTEXT_KMask0 + 0 * 8)]
kmovq k1, qword ptr [rdi + (CONTEXT_KMask0 + 1 * 8)]
kmovq k2, qword ptr [rdi + (CONTEXT_KMask0 + 2 * 8)]
kmovq k3, qword ptr [rdi + (CONTEXT_KMask0 + 3 * 8)]
kmovq k4, qword ptr [rdi + (CONTEXT_KMask0 + 4 * 8)]
kmovq k5, qword ptr [rdi + (CONTEXT_KMask0 + 5 * 8)]
kmovq k6, qword ptr [rdi + (CONTEXT_KMask0 + 6 * 8)]
kmovq k7, qword ptr [rdi + (CONTEXT_KMask0 + 7 * 8)]

// Restore the ZMM_Hi256 state
vinsertf64x4 zmm0, zmm0, ymmword ptr [rdi + (CONTEXT_Zmm0H + 0 * 32)], 1
vinsertf64x4 zmm1, zmm1, ymmword ptr [rdi + (CONTEXT_Zmm0H + 1 * 32)], 1
vinsertf64x4 zmm2, zmm2, ymmword ptr [rdi + (CONTEXT_Zmm0H + 2 * 32)], 1
vinsertf64x4 zmm3, zmm3, ymmword ptr [rdi + (CONTEXT_Zmm0H + 3 * 32)], 1
vinsertf64x4 zmm4, zmm4, ymmword ptr [rdi + (CONTEXT_Zmm0H + 4 * 32)], 1
vinsertf64x4 zmm5, zmm5, ymmword ptr [rdi + (CONTEXT_Zmm0H + 5 * 32)], 1
vinsertf64x4 zmm6, zmm6, ymmword ptr [rdi + (CONTEXT_Zmm0H + 6 * 32)], 1
vinsertf64x4 zmm7, zmm7, ymmword ptr [rdi + (CONTEXT_Zmm0H + 7 * 32)], 1
vinsertf64x4 zmm8, zmm8, ymmword ptr [rdi + (CONTEXT_Zmm0H + 8 * 32)], 1
vinsertf64x4 zmm9, zmm9, ymmword ptr [rdi + (CONTEXT_Zmm0H + 9 * 32)], 1
vinsertf64x4 zmm10, zmm10, ymmword ptr [rdi + (CONTEXT_Zmm0H + 10 * 32)], 1
vinsertf64x4 zmm11, zmm11, ymmword ptr [rdi + (CONTEXT_Zmm0H + 11 * 32)], 1
vinsertf64x4 zmm12, zmm12, ymmword ptr [rdi + (CONTEXT_Zmm0H + 12 * 32)], 1
vinsertf64x4 zmm13, zmm13, ymmword ptr [rdi + (CONTEXT_Zmm0H + 13 * 32)], 1
vinsertf64x4 zmm14, zmm14, ymmword ptr [rdi + (CONTEXT_Zmm0H + 14 * 32)], 1
vinsertf64x4 zmm15, zmm15, ymmword ptr [rdi + (CONTEXT_Zmm0H + 15 * 32)], 1

// Restore the Hi16_ZMM state
vmovups zmm16, zmmword ptr [rdi + (CONTEXT_Zmm16 + 0 * 64)]
vmovups zmm17, zmmword ptr [rdi + (CONTEXT_Zmm16 + 1 * 64)]
vmovups zmm18, zmmword ptr [rdi + (CONTEXT_Zmm16 + 2 * 64)]
vmovups zmm19, zmmword ptr [rdi + (CONTEXT_Zmm16 + 3 * 64)]
vmovups zmm20, zmmword ptr [rdi + (CONTEXT_Zmm16 + 4 * 64)]
vmovups zmm21, zmmword ptr [rdi + (CONTEXT_Zmm16 + 5 * 64)]
vmovups zmm22, zmmword ptr [rdi + (CONTEXT_Zmm16 + 6 * 64)]
vmovups zmm23, zmmword ptr [rdi + (CONTEXT_Zmm16 + 7 * 64)]
vmovups zmm24, zmmword ptr [rdi + (CONTEXT_Zmm16 + 8 * 64)]
vmovups zmm25, zmmword ptr [rdi + (CONTEXT_Zmm16 + 9 * 64)]
vmovups zmm26, zmmword ptr [rdi + (CONTEXT_Zmm16 + 10 * 64)]
vmovups zmm27, zmmword ptr [rdi + (CONTEXT_Zmm16 + 11 * 64)]
vmovups zmm28, zmmword ptr [rdi + (CONTEXT_Zmm16 + 12 * 64)]
vmovups zmm29, zmmword ptr [rdi + (CONTEXT_Zmm16 + 13 * 64)]
vmovups zmm30, zmmword ptr [rdi + (CONTEXT_Zmm16 + 14 * 64)]
vmovups zmm31, zmmword ptr [rdi + (CONTEXT_Zmm16 + 15 * 64)]

LOCAL_LABEL(Done_Restore_CONTEXT_XSTATE):

test BYTE PTR [rdi + CONTEXT_ContextFlags], CONTEXT_CONTROL
Expand Down
Loading

0 comments on commit aa6e6b2

Please sign in to comment.