Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute optimal register count for feature transformer accumulation dynamically. #3543

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions src/nnue/nnue_feature_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,47 @@

namespace Stockfish::Eval::NNUE {

// We use __m* types as template arguments which causes GCC to emit warnings
// about losing some attribute information. This is irrelevant to us as we
// only take their size.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"

namespace Detail {
template <int RegisterWidth, int LaneWidth, int NumLanes, int MaxRegs>
static inline constexpr int compute_best_reg_count() {
static_assert(RegisterWidth > 0);
static_assert(LaneWidth > 0);
static_assert(RegisterWidth >= LaneWidth);
static_assert(MaxRegs > 0);
static_assert(RegisterWidth % LaneWidth == 0);
static_assert((NumLanes * LaneWidth) % RegisterWidth == 0);

const int idealRegisterCount = (NumLanes * LaneWidth) / RegisterWidth;
if (idealRegisterCount <= MaxRegs)
return idealRegisterCount;

// Look for the largest divisor of idealRegisterCount that is smaller than MaxRegs
int divisor = MaxRegs;
for (; divisor > 1; --divisor)
if (idealRegisterCount % divisor == 0)
break;

return divisor;
}
}

template <typename RegisterT, typename LaneT, int NumLanes, int MaxRegs>
static inline constexpr int BestRegCount =
Detail::compute_best_reg_count<sizeof(RegisterT), sizeof(LaneT), NumLanes, MaxRegs>();

// If vector instructions are enabled, we update and refresh the
// accumulator tile by tile such that each tile fits in the CPU's
// vector registers.
#define VECTOR

static_assert(PSQTBuckets == 8, "Assumed by the current choice of constants.");
static_assert(PSQTBuckets % 8 == 0,
"Per feature PSQT values cannot be processed at granularity lower than 8 at a time.");

#ifdef USE_AVX512
typedef __m512i vec_t;
Expand All @@ -49,8 +84,8 @@ namespace Stockfish::Eval::NNUE {
#define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
#define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
#define vec_zero_psqt() _mm256_setzero_si256()
static constexpr IndexType NumRegs = 8; // only 8 are needed
static constexpr IndexType NumPsqtRegs = 1;
static constexpr IndexType NumRegs = BestRegCount<vec_t, std::int16_t, TransformedFeatureDimensions, 32>;
static constexpr IndexType NumPsqtRegs = BestRegCount<psqt_vec_t, std::int32_t, PSQTBuckets, 32>;

#elif USE_AVX2
typedef __m256i vec_t;
Expand All @@ -64,8 +99,8 @@ namespace Stockfish::Eval::NNUE {
#define vec_add_psqt_32(a,b) _mm256_add_epi32(a,b)
#define vec_sub_psqt_32(a,b) _mm256_sub_epi32(a,b)
#define vec_zero_psqt() _mm256_setzero_si256()
static constexpr IndexType NumRegs = 16;
static constexpr IndexType NumPsqtRegs = 1;
static constexpr IndexType NumRegs = BestRegCount<vec_t, std::int16_t, TransformedFeatureDimensions, 16>;
static constexpr IndexType NumPsqtRegs = BestRegCount<psqt_vec_t, std::int32_t, PSQTBuckets, 16>;

#elif USE_SSE2
typedef __m128i vec_t;
Expand All @@ -79,8 +114,8 @@ namespace Stockfish::Eval::NNUE {
#define vec_add_psqt_32(a,b) _mm_add_epi32(a,b)
#define vec_sub_psqt_32(a,b) _mm_sub_epi32(a,b)
#define vec_zero_psqt() _mm_setzero_si128()
static constexpr IndexType NumRegs = Is64Bit ? 16 : 8;
static constexpr IndexType NumPsqtRegs = 2;
static constexpr IndexType NumRegs = BestRegCount<vec_t, std::int16_t, TransformedFeatureDimensions, Is64Bit ? 16 : 8>;
static constexpr IndexType NumPsqtRegs = BestRegCount<psqt_vec_t, std::int32_t, PSQTBuckets, Is64Bit ? 16 : 8>;

#elif USE_MMX
typedef __m64 vec_t;
Expand All @@ -94,8 +129,8 @@ namespace Stockfish::Eval::NNUE {
#define vec_add_psqt_32(a,b) _mm_add_pi32(a,b)
#define vec_sub_psqt_32(a,b) _mm_sub_pi32(a,b)
#define vec_zero_psqt() _mm_setzero_si64()
static constexpr IndexType NumRegs = 8;
static constexpr IndexType NumPsqtRegs = 4;
static constexpr IndexType NumRegs = BestRegCount<vec_t, std::int16_t, TransformedFeatureDimensions, 8>;
static constexpr IndexType NumPsqtRegs = BestRegCount<psqt_vec_t, std::int32_t, PSQTBuckets, 8>;

#elif USE_NEON
typedef int16x8_t vec_t;
Expand All @@ -109,14 +144,16 @@ namespace Stockfish::Eval::NNUE {
#define vec_add_psqt_32(a,b) vaddq_s32(a,b)
#define vec_sub_psqt_32(a,b) vsubq_s32(a,b)
#define vec_zero_psqt() psqt_vec_t{0}
static constexpr IndexType NumRegs = 16;
static constexpr IndexType NumPsqtRegs = 2;
static constexpr IndexType NumRegs = BestRegCount<vec_t, std::int16_t, TransformedFeatureDimensions, 16>;
static constexpr IndexType NumPsqtRegs = BestRegCount<psqt_vec_t, std::int32_t, PSQTBuckets, 16>;

#else
#undef VECTOR

#endif

#pragma GCC diagnostic pop

// Input feature converter
class FeatureTransformer {

Expand Down