Skip to content

Commit

Permalink
Simplify Square Clipped ReLU code.
Browse files Browse the repository at this point in the history
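Squared numbers are never negative, so barring any wraparound there
is no need to clamp the result to 0. From reading the code, there is no
obvious way to get wraparound: in the SSE2 path, _mm_mulhi_epi16 of a
signed 16-bit word with itself yields at most 0x4000 (the high half of
2^30), so the sign bit is never set, and the scalar path squares in
long long. The lower clamp can therefore be simplified away entirely.
Also updated the original, truncated code comment to be sensible.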

Verified by running ./stockfish bench 128 1 24 and by the following test:

STC: https://tests.stockfishchess.org/tests/view/64da4db95b17f7c21c0eabe7
LLR: 2.94 (-2.94,2.94) <-1.75,0.25>
Total: 60224 W: 15425 L: 15236 D: 29563
Ptnml(0-2): 195, 6576, 16382, 6763, 196

closes #4751

No functional change
  • Loading branch information
gcp authored and snicolet committed Aug 22, 2023
1 parent 4c5919f commit c6f6236
Showing 1 changed file with 5 additions and 19 deletions.
24 changes: 5 additions & 19 deletions src/nnue/layers/sqr_clipped_relu.h
Expand Up @@ -65,12 +65,6 @@ namespace Stockfish::Eval::NNUE::Layers {
#if defined(USE_SSE2)
constexpr IndexType NumChunks = InputDimensions / 16;

#ifdef USE_SSE41
const __m128i Zero = _mm_setzero_si128();
#else
const __m128i k0x80s = _mm_set1_epi8(-128);
#endif

static_assert(WeightScaleBits == 6);
const auto in = reinterpret_cast<const __m128i*>(input);
const auto out = reinterpret_cast<__m128i*>(output);
Expand All @@ -82,21 +76,13 @@ namespace Stockfish::Eval::NNUE::Layers {
_mm_load_si128(&in[i * 4 + 2]),
_mm_load_si128(&in[i * 4 + 3]));

// Not sure if
// We shift by WeightScaleBits * 2 = 12 and divide by 128
// which is an additional shift-right of 7, meaning 19 in total.
// MulHi strips the lower 16 bits so we need to shift out 3 more to match.
words0 = _mm_srli_epi16(_mm_mulhi_epi16(words0, words0), 3);
words1 = _mm_srli_epi16(_mm_mulhi_epi16(words1, words1), 3);

const __m128i packedbytes = _mm_packs_epi16(words0, words1);

_mm_store_si128(&out[i],

#ifdef USE_SSE41
_mm_max_epi8(packedbytes, Zero)
#else
_mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s)
#endif

);
_mm_store_si128(&out[i], _mm_packs_epi16(words0, words1));
}
constexpr IndexType Start = NumChunks * 16;

Expand All @@ -108,7 +94,7 @@ namespace Stockfish::Eval::NNUE::Layers {
output[i] = static_cast<OutputType>(
// really should be /127 but we need to make it fast
// needs to be accounted for in the trainer
std::max(0ll, std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128)));
std::min(127ll, (((long long)input[i] * input[i]) >> (2 * WeightScaleBits)) / 128));
}
}
};
Expand Down

0 comments on commit c6f6236

Please sign in to comment.