simplify SSE implementation of row_lazy match finder #2929

Merged · merged 1 commit into from Dec 15, 2021
80 changes: 39 additions & 41 deletions lib/compress/zstd_lazy.c
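What the diff below does, in brief: the three hand-unrolled SSE2 branches inside ZSTD_row_getMatchMask (one each for rows of 16, 32, and 64 entries) are replaced by a single helper, ZSTD_row_getSSEMask, that loops over 16-byte chunks and stitches the per-chunk bitmasks together. The #if/#else nesting around the NEON and scalar fallbacks is flattened to match, and a few explicit integer casts are added so that shifts and pointer-difference conversions no longer happen in implicit signed arithmetic.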
@@ -417,7 +417,7 @@ void ZSTD_dedicatedDictSearch_lazy_loadDictionary(ZSTD_matchState_t* ms, const B
     U32 const hashLog = ms->cParams.hashLog - ZSTD_LAZY_DDSS_BUCKET_LOG;
     U32* const tmpHashTable = hashTable;
     U32* const tmpChainTable = hashTable + ((size_t)1 << hashLog);
-    U32 const tmpChainSize = ((1 << ZSTD_LAZY_DDSS_BUCKET_LOG) - 1) << hashLog;
+    U32 const tmpChainSize = (U32)((1 << ZSTD_LAZY_DDSS_BUCKET_LOG) - 1) << hashLog;
     U32 const tmpMinChain = tmpChainSize < target ? target - tmpChainSize : idx;
     U32 hashIdx;

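A note on the hunk above: (1 << ZSTD_LAZY_DDSS_BUCKET_LOG) - 1 has type int, so the shift by hashLog used to run in signed int arithmetic before the implicit conversion to U32; the new cast performs the shift in unsigned arithmetic from the start. A minimal standalone sketch of the difference (BUCKET_LOG and the oversized hashLog are illustrative stand-ins, not zstd values):

    #include <inttypes.h>
    #include <stdio.h>

    #define BUCKET_LOG 2  /* stand-in for ZSTD_LAZY_DDSS_BUCKET_LOG */

    int main(void) {
        uint32_t hashLog = 30;  /* larger than zstd ever passes; picked to show the hazard */
        /* Without the cast this would be 3 << 30 in int arithmetic, which
         * overflows a 32-bit int; the cast keeps the shift in uint32_t. */
        uint32_t chainSize = (uint32_t)((1 << BUCKET_LOG) - 1) << hashLog;
        printf("%" PRIu32 "\n", chainSize);  /* 3221225472 */
        return 0;
    }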
@@ -982,47 +982,45 @@ void ZSTD_row_update(ZSTD_matchState_t* const ms, const BYTE* ip) {
     ZSTD_row_update_internal(ms, ip, mls, rowLog, rowMask, 0 /* dont use cache */);
 }
 
+#if defined(ZSTD_ARCH_X86_SSE2)
+FORCE_INLINE_TEMPLATE ZSTD_VecMask
+ZSTD_row_getSSEMask(int nbChunks, const BYTE* const src, const BYTE tag, const U32 head)
+{
+    const __m128i comparisonMask = _mm_set1_epi8((char)tag);
+    int matches[4] = {0};
+    int i;
+    assert(nbChunks == 1 || nbChunks == 2 || nbChunks == 4);
+    for (i=0; i<nbChunks; i++) {
+        const __m128i chunk = _mm_loadu_si128((const __m128i*)(const void*)(src + 16*i));
+        const __m128i equalMask = _mm_cmpeq_epi8(chunk, comparisonMask);
+        matches[i] = _mm_movemask_epi8(equalMask);
+    }
+    if (nbChunks == 1) return ZSTD_rotateRight_U16((U16)matches[0], head);
+    if (nbChunks == 2) return ZSTD_rotateRight_U32((U32)matches[1] << 16 | (U32)matches[0], head);
+    assert(nbChunks == 4);
+    return ZSTD_rotateRight_U64((U64)matches[3] << 48 | (U64)matches[2] << 32 | (U64)matches[1] << 16 | (U64)matches[0], head);
+}
+#endif
+
 /* Returns a ZSTD_VecMask (U32) that has the nth bit set to 1 if the newly-computed "tag" matches
  * the hash at the nth position in a row of the tagTable.
  * Each row is a circular buffer beginning at the value of "head". So we must rotate the "matches" bitfield
  * to match up with the actual layout of the entries within the hashTable */
-FORCE_INLINE_TEMPLATE
-ZSTD_VecMask ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries) {
+FORCE_INLINE_TEMPLATE ZSTD_VecMask
+ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 head, const U32 rowEntries)
+{
     const BYTE* const src = tagRow + ZSTD_ROW_HASH_TAG_OFFSET;
     assert((rowEntries == 16) || (rowEntries == 32) || rowEntries == 64);
     assert(rowEntries <= ZSTD_ROW_HASH_MAX_ENTRIES);
 
 #if defined(ZSTD_ARCH_X86_SSE2)
-    if (rowEntries == 16) {
-        const __m128i chunk = _mm_loadu_si128((const __m128i*)(const void*)src);
-        const __m128i equalMask = _mm_cmpeq_epi8(chunk, _mm_set1_epi8(tag));
-        const U16 matches = (U16)_mm_movemask_epi8(equalMask);
-        return ZSTD_rotateRight_U16(matches, head);
-    } else if (rowEntries == 32) {
-        const __m128i chunk0 = _mm_loadu_si128((const __m128i*)(const void*)&src[0]);
-        const __m128i chunk1 = _mm_loadu_si128((const __m128i*)(const void*)&src[16]);
-        const __m128i equalMask0 = _mm_cmpeq_epi8(chunk0, _mm_set1_epi8(tag));
-        const __m128i equalMask1 = _mm_cmpeq_epi8(chunk1, _mm_set1_epi8(tag));
-        const U32 lo = (U32)_mm_movemask_epi8(equalMask0);
-        const U32 hi = (U32)_mm_movemask_epi8(equalMask1);
-        return ZSTD_rotateRight_U32((hi << 16) | lo, head);
-    } else { /* rowEntries == 64 */
-        const __m128i chunk0 = _mm_loadu_si128((const __m128i*)(const void*)&src[0]);
-        const __m128i chunk1 = _mm_loadu_si128((const __m128i*)(const void*)&src[16]);
-        const __m128i chunk2 = _mm_loadu_si128((const __m128i*)(const void*)&src[32]);
-        const __m128i chunk3 = _mm_loadu_si128((const __m128i*)(const void*)&src[48]);
-        const __m128i comparisonMask = _mm_set1_epi8(tag);
-        const __m128i equalMask0 = _mm_cmpeq_epi8(chunk0, comparisonMask);
-        const __m128i equalMask1 = _mm_cmpeq_epi8(chunk1, comparisonMask);
-        const __m128i equalMask2 = _mm_cmpeq_epi8(chunk2, comparisonMask);
-        const __m128i equalMask3 = _mm_cmpeq_epi8(chunk3, comparisonMask);
-        const U64 mask0 = (U64)_mm_movemask_epi8(equalMask0);
-        const U64 mask1 = (U64)_mm_movemask_epi8(equalMask1);
-        const U64 mask2 = (U64)_mm_movemask_epi8(equalMask2);
-        const U64 mask3 = (U64)_mm_movemask_epi8(equalMask3);
-        return ZSTD_rotateRight_U64((mask3 << 48) | (mask2 << 32) | (mask1 << 16) | mask0, head);
-    }
-#else
-# if defined(ZSTD_ARCH_ARM_NEON)
+
+    return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, head);
+
+#else /* SW or NEON-LE */
+
+# if defined(ZSTD_ARCH_ARM_NEON)
   /* This NEON path only works for little endian - otherwise use SWAR below */
     if (MEM_isLittleEndian()) {
         if (rowEntries == 16) {
             const uint8x16_t chunk = vld1q_u8(src);
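For readers new to the idiom that ZSTD_row_getSSEMask centralizes: _mm_cmpeq_epi8 yields 0xFF in every byte lane where the chunk equals the splatted tag, and _mm_movemask_epi8 packs the high bit of each lane into a 16-bit integer, so each set bit marks one matching row entry. The ZSTD_rotateRight_U16/U32/U64 step then rotates that mask by head so bit 0 lines up with the start of the circular row, as described in the comment above ZSTD_row_getMatchMask. A self-contained sketch of the compare-and-movemask step (tag value and row contents are made up):

    #include <emmintrin.h>  /* SSE2 */
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint8_t row[16] = {0};
        row[3] = row[9] = 0xAB;  /* pretend entries 3 and 9 carry the tag */
        const __m128i chunk = _mm_loadu_si128((const __m128i*)(const void*)row);
        const __m128i eq    = _mm_cmpeq_epi8(chunk, _mm_set1_epi8((char)0xAB));
        const unsigned mask = (unsigned)_mm_movemask_epi8(eq);
        printf("0x%x\n", mask);  /* prints 0x208: bits 3 and 9 set */
        return 0;
    }

Since the helper is FORCE_INLINE_TEMPLATE and rowEntries is a compile-time constant at each call site, the nbChunks loop should unroll completely, so the rewrite is presumably performance-neutral.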
@@ -1066,9 +1064,9 @@ ZSTD_VecMask ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, con
             return ZSTD_rotateRight_U64(matches, head);
         }
     }
-# endif
-    { /* SWAR */
-        const size_t chunkSize = sizeof(size_t);
+# endif /* ZSTD_ARCH_ARM_NEON */
+    /* SWAR */
+    {   const size_t chunkSize = sizeof(size_t);
         const size_t shiftAmount = ((chunkSize * 8) - chunkSize);
         const size_t xFF = ~((size_t)0);
         const size_t x01 = xFF / 0xFF;
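The scalar fallback beginning here (truncated in this view) is a SWAR routine: x01 is the all-0x01 pattern, so tag * x01 splats the tag into every byte of a size_t, and XORing a chunk of the row with that splat turns matching bytes into zero bytes, which can then be found without SIMD. A hedged sketch of one classic, carry-free way to detect the zero bytes (not necessarily the exact sequence zstd uses):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    /* Returns 0x80 in every byte of x that is zero, 0x00 in every other byte. */
    static size_t zeroByteMask(size_t x) {
        const size_t x01 = ((size_t)-1) / 0xFF;  /* 0x0101...01 */
        const size_t x7F = x01 * 0x7F;           /* 0x7F7F...7F */
        const size_t y = (x & x7F) + x7F;        /* per-byte high bit set iff low 7 bits nonzero */
        return ~(y | x | x7F);                   /* high bit survives only for zero bytes */
    }

    int main(void) {
        const uint8_t tag = 0xAB;
        const uint8_t row[8] = { 1, 2, 0xAB, 4, 5, 0xAB, 7, 8 };
        const size_t splat = (((size_t)-1) / 0xFF) * tag;  /* tag in every byte */
        size_t chunk;
        memcpy(&chunk, row, sizeof(chunk));  /* assumes 64-bit size_t */
        /* On a little-endian machine this prints 800000800000: bytes 2 and 5 match. */
        printf("%zx\n", zeroByteMask(chunk ^ splat));
        return 0;
    }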
@@ -1661,7 +1659,7 @@ ZSTD_compressBlock_lazy_generic(
                     { start--; matchLength++; }
             }
             if (isDxS) {
-                U32 const matchIndex = (U32)((start-base) - (offset - ZSTD_REP_MOVE));
+                U32 const matchIndex = (U32)((size_t)(start-base) - (offset - ZSTD_REP_MOVE));
                 const BYTE* match = (matchIndex < prefixLowestIndex) ? dictBase + matchIndex - dictIndexDelta : base + matchIndex;
                 const BYTE* const mStart = (matchIndex < prefixLowestIndex) ? dictLowest : prefixLowest;
                 while ((start>anchor) && (match>mStart) && (start[-1] == match[-1])) { start--; match--; matchLength++; }  /* catch up */
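Context for the cast above: at this point in zstd's history, an offset greater than ZSTD_REP_MOVE encodes a real match distance plus ZSTD_REP_MOVE (the small values are reserved for repeat-offset codes), so subtracting offset - ZSTD_REP_MOVE from the current index start - base recovers the match's index. The pointer difference start - base is a signed ptrdiff_t; converting it to size_t first keeps the subtraction unsigned, leaving the final (U32) narrowing as the only deliberate truncation. A toy illustration of the index arithmetic (REP_MOVE is a stand-in for ZSTD_REP_MOVE):

    #include <inttypes.h>
    #include <stdio.h>

    #define REP_MOVE 2  /* stand-in for ZSTD_REP_MOVE */

    int main(void) {
        uint32_t curr = 1000;                   /* start - base: current index in the window */
        uint32_t distance = 300;                /* real distance back to the match */
        uint32_t offset = distance + REP_MOVE;  /* how non-repeat offsets are encoded */
        printf("%" PRIu32 "\n", curr - (offset - REP_MOVE));  /* 700: the match's index */
        return 0;
    }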
@@ -1670,7 +1668,7 @@ ZSTD_compressBlock_lazy_generic(
         }
         /* store sequence */
 _storeSequence:
-        {   size_t const litLength = start - anchor;
+        {   size_t const litLength = (size_t)(start - anchor);
             ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offset, matchLength-MINMATCH);
             anchor = ip = start + matchLength;
         }
@@ -2003,7 +2001,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic(

         /* catch up */
         if (offset) {
-            U32 const matchIndex = (U32)((start-base) - (offset - ZSTD_REP_MOVE));
+            U32 const matchIndex = (U32)((size_t)(start-base) - (offset - ZSTD_REP_MOVE));
             const BYTE* match = (matchIndex < dictLimit) ? dictBase + matchIndex : base + matchIndex;
             const BYTE* const mStart = (matchIndex < dictLimit) ? dictStart : prefixStart;
             while ((start>anchor) && (match>mStart) && (start[-1] == match[-1])) { start--; match--; matchLength++; }  /* catch up */
@@ -2012,7 +2010,7 @@ size_t ZSTD_compressBlock_lazy_extDict_generic(

         /* store sequence */
 _storeSequence:
-        {   size_t const litLength = start - anchor;
+        {   size_t const litLength = (size_t)(start - anchor);
             ZSTD_storeSeq(seqStore, litLength, anchor, iend, (U32)offset, matchLength-MINMATCH);
             anchor = ip = start + matchLength;
         }