Skip to content

Commit

Permalink
Optimise pairwise multiplication
Browse files Browse the repository at this point in the history
This speedup was first inspired by a comment by @AndyGrant on my recent PR "If mullo_epi16 would preserve the signedness, then this could be used to remove 50% of the max operations during the halfkp-pairwise mat-mul relu deal."

That got me thinking, because although mullo_epi16 did not preserve the signedness, mulhi_epi16 did, and so we could shift left and then use mulhi_epi16, instead of shifting right after the mullo.

However, due to some issues with shifting into the sign bit, the FT weights and biases had to be multiplied by 2 for the optimisation to work. In the future we can increase the quantisation number in the FT to 255 (instead of 2x127=254) and also train a net from scratch to make use of the extra bit (currently all the weights and biases are even)

Speedup on "Arch=x86-64-bmi2 COMP=clang", courtesy of @Torom
Result of 50 runs
base (...es/stockfish) =     962946  +/- 1202
test (...ise-max-less) =     979696  +/- 1084
diff                   =     +16750  +/- 1794

speedup        = +0.0174
P(speedup > 0) =  1.0000

CPU: 4 x Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz
Hyperthreading: on

Also a speedup on "COMP=gcc", courtesy of Torom once again
Result of 50 runs
base (...tockfish_gcc) =     966033  +/- 1574
test (...max-less_gcc) =     983319  +/- 1513
diff                   =     +17286  +/- 2515

speedup        = +0.0179
P(speedup > 0) =  1.0000

CPU: 4 x Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz
Hyperthreading: on

closes #5282

Passed STC:
LLR: 2.96 (-2.94,2.94) <0.00,2.00>
Total: 67712 W: 17715 L: 17358 D: 32639
Ptnml(0-2): 225, 7472, 18140, 7759, 260
https://tests.stockfishchess.org/tests/view/664c1d75830eb9f886616906

No functional change
  • Loading branch information
cj5716 committed May 23, 2024
1 parent c14b697 commit 8d29658
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 41 deletions.
4 changes: 2 additions & 2 deletions src/evaluate.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ namespace Eval {
// for the build process (profile-build and fishtest) to work. Do not change the
// name of the macro or the location where this macro is defined, as it is used
// in the Makefile/Fishtest.
#define EvalFileDefaultNameBig "nn-c721dfca8cd3.nnue"
#define EvalFileDefaultNameSmall "nn-baff1ede1f90.nnue"
#define EvalFileDefaultNameBig "nn-ea6ac98d29d4.nnue"
#define EvalFileDefaultNameSmall "nn-321bbccb5bcf.nnue"

namespace NNUE {
struct Networks;
Expand Down
64 changes: 38 additions & 26 deletions src/nnue/nnue_feature_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ using psqt_vec_t = __m256i;
#define vec_store(a, b) _mm512_store_si512(a, b)
#define vec_add_16(a, b) _mm512_add_epi16(a, b)
#define vec_sub_16(a, b) _mm512_sub_epi16(a, b)
#define vec_mul_16(a, b) _mm512_mullo_epi16(a, b)
#define vec_mulhi_16(a, b) _mm512_mulhi_epi16(a, b)
#define vec_zero() _mm512_setzero_epi32()
#define vec_set_16(a) _mm512_set1_epi16(a)
#define vec_max_16(a, b) _mm512_max_epi16(a, b)
#define vec_min_16(a, b) _mm512_min_epi16(a, b)
#define vec_slli_16(a, b) _mm512_slli_epi16(a, b)
// Inverse permuted at load time
#define vec_msb_pack_16(a, b) \
_mm512_packs_epi16(_mm512_srli_epi16(a, 7), _mm512_srli_epi16(b, 7))
#define vec_packus_16(a, b) _mm512_packus_epi16(a, b)
#define vec_load_psqt(a) _mm256_load_si256(a)
#define vec_store_psqt(a, b) _mm256_store_si256(a, b)
#define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
Expand All @@ -78,14 +78,14 @@ using psqt_vec_t = __m256i;
#define vec_store(a, b) _mm256_store_si256(a, b)
#define vec_add_16(a, b) _mm256_add_epi16(a, b)
#define vec_sub_16(a, b) _mm256_sub_epi16(a, b)
#define vec_mul_16(a, b) _mm256_mullo_epi16(a, b)
#define vec_mulhi_16(a, b) _mm256_mulhi_epi16(a, b)
#define vec_zero() _mm256_setzero_si256()
#define vec_set_16(a) _mm256_set1_epi16(a)
#define vec_max_16(a, b) _mm256_max_epi16(a, b)
#define vec_min_16(a, b) _mm256_min_epi16(a, b)
#define vec_slli_16(a, b) _mm256_slli_epi16(a, b)
// Inverse permuted at load time
#define vec_msb_pack_16(a, b) \
_mm256_packs_epi16(_mm256_srli_epi16(a, 7), _mm256_srli_epi16(b, 7))
#define vec_packus_16(a, b) _mm256_packus_epi16(a, b)
#define vec_load_psqt(a) _mm256_load_si256(a)
#define vec_store_psqt(a, b) _mm256_store_si256(a, b)
#define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
Expand All @@ -101,12 +101,13 @@ using psqt_vec_t = __m128i;
#define vec_store(a, b) *(a) = (b)
#define vec_add_16(a, b) _mm_add_epi16(a, b)
#define vec_sub_16(a, b) _mm_sub_epi16(a, b)
#define vec_mul_16(a, b) _mm_mullo_epi16(a, b)
#define vec_mulhi_16(a, b) _mm_mulhi_epi16(a, b)
#define vec_zero() _mm_setzero_si128()
#define vec_set_16(a) _mm_set1_epi16(a)
#define vec_max_16(a, b) _mm_max_epi16(a, b)
#define vec_min_16(a, b) _mm_min_epi16(a, b)
#define vec_msb_pack_16(a, b) _mm_packs_epi16(_mm_srli_epi16(a, 7), _mm_srli_epi16(b, 7))
#define vec_slli_16(a, b) _mm_slli_epi16(a, b)
#define vec_packus_16(a, b) _mm_packus_epi16(a, b)
#define vec_load_psqt(a) (*(a))
#define vec_store_psqt(a, b) *(a) = (b)
#define vec_add_psqt_32(a, b) _mm_add_epi32(a, b)
Expand All @@ -122,18 +123,14 @@ using psqt_vec_t = int32x4_t;
#define vec_store(a, b) *(a) = (b)
#define vec_add_16(a, b) vaddq_s16(a, b)
#define vec_sub_16(a, b) vsubq_s16(a, b)
#define vec_mul_16(a, b) vmulq_s16(a, b)
#define vec_mulhi_16(a, b) vqdmulhq_s16(a, b)
#define vec_zero() \
vec_t { 0 }
#define vec_set_16(a) vdupq_n_s16(a)
#define vec_max_16(a, b) vmaxq_s16(a, b)
#define vec_min_16(a, b) vminq_s16(a, b)
inline vec_t vec_msb_pack_16(vec_t a, vec_t b) {
const int8x8_t shifta = vshrn_n_s16(a, 7);
const int8x8_t shiftb = vshrn_n_s16(b, 7);
const int8x16_t compacted = vcombine_s8(shifta, shiftb);
return *reinterpret_cast<const vec_t*>(&compacted);
}
#define vec_slli_16(a, b) vshlq_s16(a, vec_set_16(b))
#define vec_packus_16(a, b) reinterpret_cast<vec_t>(vcombine_u8(vqmovun_s16(a), vqmovun_s16(b)))
#define vec_load_psqt(a) (*(a))
#define vec_store_psqt(a, b) *(a) = (b)
#define vec_add_psqt_32(a, b) vaddq_s32(a, b)
Expand Down Expand Up @@ -332,7 +329,7 @@ class FeatureTransformer {
constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;

const vec_t Zero = vec_zero();
const vec_t One = vec_set_16(127);
const vec_t One = vec_set_16(254);

const vec_t* in0 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][0]));
const vec_t* in1 =
Expand All @@ -341,15 +338,30 @@ class FeatureTransformer {

for (IndexType j = 0; j < NumOutputChunks; ++j)
{
const vec_t sum0a = vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero);
const vec_t sum0b = vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero);
const vec_t sum1a = vec_max_16(vec_min_16(in1[j * 2 + 0], One), Zero);
const vec_t sum1b = vec_max_16(vec_min_16(in1[j * 2 + 1], One), Zero);
// What we want to do is multiply inputs in a pairwise manner (after clipping), and then shift right by 9.
// Instead, we shift left by 7, and use mulhi, stripping the bottom 16 bits, effectively shifting right by 16,
// resulting in a net shift of 9 bits. We use mulhi because it maintains the sign of the multiplication (unlike mullo),
// allowing us to make use of packus to clip 2 of the inputs, resulting in a save of 2 "vec_max_16" calls.
// A special case is when we use NEON, where we shift left by 6 instead, because the instruction "vqdmulhq_s16"
// also doubles the return value after the multiplication, adding an extra shift to the left by 1, so we
// compensate by shifting less before the multiplication.

#if defined(USE_SSE2)
constexpr int shift = 7;
#else
constexpr int shift = 6;
#endif
const vec_t sum0a =
vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero), shift);
const vec_t sum0b =
vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero), shift);
const vec_t sum1a = vec_min_16(in1[j * 2 + 0], One);
const vec_t sum1b = vec_min_16(in1[j * 2 + 1], One);

const vec_t pa = vec_mul_16(sum0a, sum1a);
const vec_t pb = vec_mul_16(sum0b, sum1b);
const vec_t pa = vec_mulhi_16(sum0a, sum1a);
const vec_t pb = vec_mulhi_16(sum0b, sum1b);

out[j] = vec_msb_pack_16(pa, pb);
out[j] = vec_packus_16(pa, pb);
}

#else
Expand All @@ -359,9 +371,9 @@ class FeatureTransformer {
BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];
BiasType sum1 =
accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];
sum0 = std::clamp<BiasType>(sum0, 0, 127);
sum1 = std::clamp<BiasType>(sum1, 0, 127);
output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 128);
sum0 = std::clamp<BiasType>(sum0, 0, 254);
sum1 = std::clamp<BiasType>(sum1, 0, 254);
output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 512);
}

#endif
Expand Down
9 changes: 3 additions & 6 deletions src/nnue/nnue_misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,11 @@ trace(Position& pos, const Eval::NNUE::Networks& networks, Eval::NNUE::Accumulat
ss << "| " << bucket << " ";
ss << " | ";
format_cp_aligned_dot(t.psqt[bucket], ss, pos);
ss << " "
<< " | ";
ss << " " << " | ";
format_cp_aligned_dot(t.positional[bucket], ss, pos);
ss << " "
<< " | ";
ss << " " << " | ";
format_cp_aligned_dot(t.psqt[bucket] + t.positional[bucket], ss, pos);
ss << " "
<< " |";
ss << " " << " |";
if (bucket == t.correctBucket)
ss << " <-- this bucket is used";
ss << '\n';
Expand Down
6 changes: 2 additions & 4 deletions src/tune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ void make_option(OptionsMap* options, const string& n, int v, const SetRange& r)

// Print formatted parameters, ready to be copy-pasted in Fishtest
std::cout << n << "," << v << "," << r(v).first << "," << r(v).second << ","
<< (r(v).second - r(v).first) / 20.0 << ","
<< "0.0020" << std::endl;
<< (r(v).second - r(v).first) / 20.0 << "," << "0.0020" << std::endl;
}
}

Expand Down Expand Up @@ -118,7 +117,6 @@ void Tune::Entry<Tune::PostUpdate>::read_option() {

namespace Stockfish {

void Tune::read_results() { /* ...insert your values here... */
}
void Tune::read_results() { /* ...insert your values here... */ }

} // namespace Stockfish
6 changes: 3 additions & 3 deletions src/uci.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ void UCIEngine::bench(std::istream& args) {

dbg_print();

std::cerr << "\n==========================="
<< "\nTotal time (ms) : " << elapsed << "\nNodes searched : " << nodes
<< "\nNodes/second : " << 1000 * nodes / elapsed << std::endl;
std::cerr << "\n===========================" << "\nTotal time (ms) : " << elapsed
<< "\nNodes searched : " << nodes << "\nNodes/second : " << 1000 * nodes / elapsed
<< std::endl;

// reset callback, to not capture a dangling reference to nodesSearched
engine.set_on_update_full([&](const auto& i) { on_update_full(i, options["UCI_ShowWDL"]); });
Expand Down

0 comments on commit 8d29658

Please sign in to comment.