diff --git a/src/lib/lwan-websocket.c b/src/lib/lwan-websocket.c index 9641aa997..626df961d 100644 --- a/src/lib/lwan-websocket.c +++ b/src/lib/lwan-websocket.c @@ -157,14 +157,9 @@ static void unmask(char *msg, size_t msg_len, char mask[static 4]) #if defined(__AVX2__) const __m256i mask256 = _mm256_castps_si256(_mm256_broadcast_ss((const float *)mask)); - if (msg_len >= 32) { - do { - __m256i v = _mm256_lddqu_si256((const __m256i *)msg); - _mm256_storeu_si256((__m256i *)msg, _mm256_xor_si256(v, mask256)); - - msg += 32; - msg_len -= 32; - } while (msg_len >= 32); + for (; msg_len >= 32; msg_len -= 32, msg += 32) { + __m256i v = _mm256_lddqu_si256((const __m256i *)msg); + _mm256_storeu_si256((__m256i *)msg, _mm256_xor_si256(v, mask256)); } #endif @@ -176,40 +171,33 @@ static void unmask(char *msg, size_t msg_len, char mask[static 4]) #else const __m128i mask128 = _mm_loadu_si128((const __m128i *)mask); #endif - if (msg_len >= 16) { - do { + for (; msg_len >= 16; msg_len -= 16, msg += 16) { #if defined(__SSE3__) - __m128i v = _mm_lddqu_si128((const __m128i *)msg); + __m128i v = _mm_lddqu_si128((const __m128i *)msg); #else - __m128i v = _mm_loadu_si128((const __m128i *)msg); + __m128i v = _mm_loadu_si128((const __m128i *)msg); #endif - _mm_storeu_si128((__m128i *)msg, _mm_xor_si128(v, mask128)); - - msg += 16; - msg_len -= 16; - } while (msg_len >= 16); + _mm_storeu_si128((__m128i *)msg, _mm_xor_si128(v, mask128)); } #endif - if (sizeof(void *) == 8) { - if (msg_len >= 8) { + if (sizeof(void *) == 8 && msg_len >= 8) { #if defined(__SSE_4_1__) - /* We're far away enough from the AVX2 path that it's - * probably better to use mask128 instead of mask256 - * here. */ - const __int64 mask64 = _mm_extract_epi64(mask128, 0); + /* We're far away enough from the AVX2 path that it's + * probably better to use mask128 instead of mask256 + * here. */ + const __int64 mask64 = _mm_extract_epi64(mask128, 0); #else - const uint32_t mask32 = string_as_uint32(mask); - const uint64_t mask64 = (uint64_t)mask32 << 32 | (uint64_t)mask32; + const uint32_t mask32 = string_as_uint32(mask); + const uint64_t mask64 = (uint64_t)mask32 << 32 | (uint64_t)mask32; #endif - do { - uint64_t v = string_as_uint64(msg); - v ^= (uint64_t)mask64; - msg = mempcpy(msg, &v, sizeof(v)); - msg_len -= 8; - } while (msg_len >= 8); - } + do { + uint64_t v = string_as_uint64(msg); + v ^= (uint64_t)mask64; + msg = mempcpy(msg, &v, sizeof(v)); + msg_len -= 8; + } while (msg_len >= 8); } if (msg_len >= 4) {