Allow ref byte to point just past the end of spans (#73)
Motivation
----------
I was under the impression that `ref byte` should never point past the
end of a `Span<byte>` or `ReadOnlySpan<byte>`, because the GC wouldn't
recognize such a pointer and adjust it when it moves the underlying
memory. However, there is a specific exception: a `ref byte` may point
exactly one byte past the end and still be tracked, which allows
pointer-arithmetic patterns like the ones in this algorithm.
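
A minimal sketch of the pattern this exception enables (`PastTheEndSketch` and `CountNonZero` are hypothetical names, not part of this commit); the one-past-the-end ref is only compared against, never dereferenced:

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

internal static class PastTheEndSketch
{
    // Counts non-zero bytes with a C-style end pointer. "end" points one byte
    // past the final element; the GC still tracks it and will adjust it if
    // the underlying memory moves, but it must never be read or written.
    public static int CountNonZero(ReadOnlySpan<byte> span)
    {
        ref byte ip = ref MemoryMarshal.GetReference(span);
        ref byte end = ref Unsafe.Add(ref ip, span.Length);

        int count = 0;
        while (Unsafe.IsAddressLessThan(ref ip, ref end))
        {
            if (ip != 0)
            {
                count++;
            }

            ip = ref Unsafe.Add(ref ip, 1);
        }

        return count;
    }
}
```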

Modifications
-------------
Where the code had logic more complex than the reference C++
implementation so that buffer end pointers could target the last byte in
the buffer, simplify that logic and point one byte past the end of the
buffer instead.

Also ensure that comparisons verifying a minimum remaining length never
form an intermediate pointer that falls before the beginning of the
buffer, as sketched below.
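
The shape of that change, pulled out into a standalone sketch (the method names are hypothetical; note that `s2Limit` means the last byte under the old convention and one byte past the end under the new one):

```csharp
using System.Runtime.CompilerServices;

internal static class RemainingLengthSketch
{
    // Old pattern: tests s2 against an intermediate ref formed at s2Limit - 3,
    // which can land before the start of the buffer when the input is short.
    public static bool HasFourBytesOld(ref byte s2, ref byte s2Limit) =>
        !Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 3));

    // New pattern: computes the remaining byte count directly, so no ref is
    // ever formed outside the buffer.
    public static bool HasFourBytesNew(ref byte s2, ref byte s2Limit) =>
        Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)4;
}
```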

Results
-------
We're now much closer to the reference C++ implementation while still
retaining memory safety.
brantburnett committed Mar 26, 2023
1 parent e3c9834 commit d7ac526
Showing 4 changed files with 45 additions and 56 deletions.
2 changes: 1 addition & 1 deletion Snappier.Benchmarks/FindMatchLength.cs
@@ -64,7 +64,7 @@ public void GlobalSetup()

ref byte s1 = ref _array[0];
ref byte s2 = ref Unsafe.Add(ref s1, 12);
-ref byte s2Limit = ref Unsafe.Add(ref s1, _array.Length - 1);
+ref byte s2Limit = ref Unsafe.Add(ref s1, _array.Length);

return SnappyCompressor.FindMatchLength(ref s1, ref s2, ref s2Limit, ref data);
}
2 changes: 1 addition & 1 deletion Snappier.Tests/Internal/SnappyCompressorTests.cs
@@ -92,7 +92,7 @@ public void FindMatchLength(int expectedResult, string s1String, string s2String
ref byte s2 = ref Unsafe.Add(ref s1, s1String.Length);

var result =
-SnappyCompressor.FindMatchLength(ref s1, ref s2, ref Unsafe.Add(ref s2, length - 1), ref data);
+SnappyCompressor.FindMatchLength(ref s1, ref s2, ref Unsafe.Add(ref s2, length), ref data);

Assert.Equal(result.matchLength < 8, result.matchLengthLessThan8);
Assert.Equal(expectedResult, result.matchLength);
40 changes: 17 additions & 23 deletions Snappier/Internal/SnappyCompressor.cs
@@ -125,16 +125,15 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
uint mask = (uint)(2 * (tableSpan.Length - 1));

ref byte inputStart = ref Unsafe.AsRef(in input[0]);
-// Last byte of the input, not one byte past the end, to avoid issues on GC moves
-ref byte inputEnd = ref Unsafe.Add(ref inputStart, input.Length - 1);
+ref byte inputEnd = ref Unsafe.Add(ref inputStart, input.Length);
ref byte ip = ref inputStart;

ref byte op = ref output[0];
ref ushort table = ref tableSpan[0];

if (input.Length >= Constants.InputMarginBytes)
{
-ref byte ipLimit = ref Unsafe.Subtract(ref inputEnd, Constants.InputMarginBytes - 1);
+ref byte ipLimit = ref Unsafe.Subtract(ref inputEnd, Constants.InputMarginBytes);

for (uint preload = Helpers.UnsafeReadUInt32(ref Unsafe.Add(ref ip, 1));;)
{
@@ -288,7 +287,7 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
// Step 2: A 4-byte match has been found. We'll later see if more
// than 4 bytes match. But, prior to the match, input
// bytes [next_emit, ip) are unmatched. Emit them as "literal bytes."
-Debug.Assert(!Unsafe.IsAddressGreaterThan(ref Unsafe.Add(ref nextEmit, 16), ref Unsafe.Add(ref inputEnd, 1)));
+Debug.Assert(!Unsafe.IsAddressGreaterThan(ref Unsafe.Add(ref nextEmit, 16), ref inputEnd));
op = ref EmitLiteralFast(ref op, ref nextEmit, (uint) Unsafe.ByteOffset(ref nextEmit, ref ip));

// Step 3: Call EmitCopy, and then see if another EmitCopy could
@@ -350,9 +349,9 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,

emit_remainder:
// Emit the remaining bytes as a literal
-if (!Unsafe.IsAddressGreaterThan(ref ip, ref inputEnd))
+if (Unsafe.IsAddressLessThan(ref ip, ref inputEnd))
{
-op = ref EmitLiteralSlow(ref op, ref ip, (uint) Unsafe.ByteOffset(ref ip, ref inputEnd) + 1);
+op = ref EmitLiteralSlow(ref op, ref ip, (uint) Unsafe.ByteOffset(ref ip, ref inputEnd));
}

return (int) Unsafe.ByteOffset(ref output[0], ref op);
@@ -490,28 +489,23 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
/// Find the largest n such that
///
/// s1[0,n-1] == s2[0,n-1]
-/// and n &lt;= (s2_limit - s2 + 1).
+/// and n &lt;= (s2_limit - s2).
///
/// Return (n, n &lt; 8).
-/// Reads up to and including *s2_limit but not beyond.
-/// Does not read *(s1 + (s2_limit - s2 + 1)) or beyond.
-/// Requires that s2_limit+1 &gt;= s2.
+/// Does not read *(s1 + (s2_limit - s2)) or beyond.
+/// Requires that s2_limit &gt;= s2.
///
/// In addition populate *data with the next 5 bytes from the end of the match.
/// This is only done if 8 bytes are available (s2_limit - s2 &gt;= 8). The point is
/// that on some arch's this can be done faster in this routine than subsequent
/// loading from s2 + n.
/// </summary>
-/// <remarks>
-/// The reference implementation has s2Limit as one byte past the end of the input,
-/// but this implementation has it at the end of the input. This ensures that it always
-/// points within the array in case GC moves the array.
-/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static (int matchLength, bool matchLengthLessThan8) FindMatchLength(
ref byte s1, ref byte s2, ref byte s2Limit, ref ulong data)
{
-Debug.Assert(!Unsafe.IsAddressLessThan(ref Unsafe.Add(ref s2Limit, 1), ref s2));
+Debug.Assert(!Unsafe.IsAddressLessThan(ref s2Limit, ref s2));

if (BitConverter.IsLittleEndian && IntPtr.Size == 8)
{
@@ -521,14 +515,14 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,

int matched = 0;

-while (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 3))
+while (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)4
&& Helpers.UnsafeReadUInt32(ref s2) == Helpers.UnsafeReadUInt32(ref Unsafe.Add(ref s1, matched)))
{
s2 = ref Unsafe.Add(ref s2, 4);
matched += 4;
}

-if (BitConverter.IsLittleEndian && !Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 3)))
+if (BitConverter.IsLittleEndian && Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)4)
{
uint x = Helpers.UnsafeReadUInt32(ref s2) ^ Helpers.UnsafeReadUInt32(ref Unsafe.Add(ref s1, matched));
int matchingBits = Helpers.FindLsbSetNonZero(x);
@@ -537,14 +531,14 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
}
else
{
-while (!Unsafe.IsAddressGreaterThan(ref s2, ref s2Limit) && Unsafe.Add(ref s1, matched) == s2)
+while (Unsafe.IsAddressLessThan(ref s2, ref s2Limit) && Unsafe.Add(ref s1, matched) == s2)
{
s2 = ref Unsafe.Add(ref s2, 1);
++matched;
}
}

-if (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 7)))
+if (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)8)
{
data = Helpers.UnsafeReadUInt64(ref s2);
}
@@ -562,7 +556,7 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
// immediately. As an optimization though, it is useful. It creates some not
// uncommon code paths that determine, without extra effort, whether the match
// length is less than 8.
-if (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 15)))
+if (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)16)
{
ulong a1 = Helpers.UnsafeReadUInt64(ref s1);
ulong a2 = Helpers.UnsafeReadUInt64(ref s2);
@@ -590,7 +584,7 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
// time until we find a 64-bit block that doesn't match; then we find
// the first non-matching bit and use that to calculate the total
// length of the match.
-while (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 15)))
+while (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)16)
{
ulong a1 = Helpers.UnsafeReadUInt64(ref Unsafe.Add(ref s1, matched));
ulong a2 = Helpers.UnsafeReadUInt64(ref s2);
@@ -615,7 +609,7 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
}
}

-while (!Unsafe.IsAddressGreaterThan(ref s2, ref s2Limit))
+while (Unsafe.IsAddressLessThan(ref s2, ref s2Limit))
{
if (Unsafe.Add(ref s1, matched) == s2)
{
@@ -624,7 +618,7 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
}
else
{
-if (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 7)))
+if (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)8)
{
data = Helpers.UnsafeReadUInt64(ref s2);
}
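For orientation, a simplified stand-alone sketch of the word-at-a-time technique the FindMatchLength hunks above adjust. It assumes a little-endian 64-bit platform, substitutes the BCL's `Unsafe.ReadUnaligned` and `BitOperations.TrailingZeroCount` for Snappier's internal `Helpers`, and omits the `data` preload and the `matchLengthLessThan8` flag that the real method returns:

```csharp
using System.Numerics;
using System.Runtime.CompilerServices;

internal static class MatchLengthSketch
{
    // Returns the length of the common prefix of s1 and s2, where s2Limit
    // points one byte past the last readable byte of s2.
    public static int FindMatchLength(ref byte s1, ref byte s2, ref byte s2Limit)
    {
        int matched = 0;

        // Compare eight bytes at a time while at least eight remain.
        while (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)8)
        {
            ulong a1 = Unsafe.ReadUnaligned<ulong>(ref Unsafe.Add(ref s1, matched));
            ulong a2 = Unsafe.ReadUnaligned<ulong>(ref s2);
            if (a1 != a2)
            {
                // On little-endian, the trailing zero count of the XOR locates
                // the first differing bit; dividing by 8 converts to whole bytes.
                return matched + (BitOperations.TrailingZeroCount(a1 ^ a2) >> 3);
            }

            s2 = ref Unsafe.Add(ref s2, 8);
            matched += 8;
        }

        // Fewer than eight bytes left: fall back to byte-by-byte comparison.
        while (Unsafe.IsAddressLessThan(ref s2, ref s2Limit) && Unsafe.Add(ref s1, matched) == s2)
        {
            s2 = ref Unsafe.Add(ref s2, 1);
            matched++;
        }

        return matched;
    }
}
```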
57 changes: 26 additions & 31 deletions Snappier/Internal/SnappyDecompressor.cs
@@ -186,17 +186,11 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
unchecked
{
ref byte input = ref Unsafe.AsRef(in inputSpan[0]);

-// The reference Snappy implementation uses inputEnd as a pointer one byte past the end of the buffer.
-// However, this is not safe when using ref locals. The ref must point to somewhere within the array
-// so that GC can adjust the ref if the memory is moved.
-ref byte inputEnd = ref Unsafe.Add(ref input, inputSpan.Length - 1);
+ref byte inputEnd = ref Unsafe.Add(ref input, inputSpan.Length);

// Track the point in the input before which input is guaranteed to have at least Constants.MaxTagLength bytes left
-ref byte inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd, Math.Min(inputSpan.Length, Constants.MaximumTagLength - 1) - 1);
+ref byte inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd, Math.Min(inputSpan.Length, Constants.MaximumTagLength - 1));

-// We always allocate buffer with at least one extra byte on the end, so bufferEnd doesn't have the same
-// restrictions as inputEnd.
ref byte buffer = ref _lookbackBuffer.Span[0];
ref byte bufferEnd = ref Unsafe.Add(ref buffer, _lookbackBuffer.Length);
ref byte op = ref Unsafe.Add(ref buffer, _lookbackPosition);
@@ -239,9 +233,9 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
{
// Data has been moved to the scratch buffer
input = ref scratch;
-inputEnd = ref Unsafe.Add(ref input, newScratchLength - 1);
+inputEnd = ref Unsafe.Add(ref input, newScratchLength);
inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd,
-Math.Min(newScratchLength, Constants.MaximumTagLength - 1) - 1);
+Math.Min(newScratchLength, Constants.MaximumTagLength - 1));
}
}

Expand All @@ -256,7 +250,7 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
{
nint literalLength = unchecked((c >> 2) + 1);

-if (TryFastAppend(ref op, ref bufferEnd, in input, Unsafe.ByteOffset(ref input, ref inputEnd) + 1, literalLength))
+if (TryFastAppend(ref op, ref bufferEnd, in input, Unsafe.ByteOffset(ref input, ref inputEnd), literalLength))
{
Debug.Assert(literalLength < 61);
op = ref Unsafe.Add(ref op, literalLength);
Expand All @@ -280,7 +274,7 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
input = ref Unsafe.Add(ref input, literalLengthLength);
}

-nint inputRemaining = Unsafe.ByteOffset(ref input, ref inputEnd) + 1;
+nint inputRemaining = Unsafe.ByteOffset(ref input, ref inputEnd);
if (inputRemaining < literalLength)
{
Append(ref op, ref bufferEnd, in input, inputRemaining);
Expand All @@ -306,9 +300,9 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
{
// Data has been moved to the scratch buffer
input = ref scratch;
-inputEnd = ref Unsafe.Add(ref input, newScratchLength - 1);
+inputEnd = ref Unsafe.Add(ref input, newScratchLength);
inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd,
-Math.Min(newScratchLength, Constants.MaximumTagLength - 1) - 1);
+Math.Min(newScratchLength, Constants.MaximumTagLength - 1));

}
}
@@ -367,9 +361,9 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
{
// Data has been moved to the scratch buffer
input = ref scratch;
-inputEnd = ref Unsafe.Add(ref input, newScratchLength - 1);
+inputEnd = ref Unsafe.Add(ref input, newScratchLength);
inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd,
-Math.Min(newScratchLength, Constants.MaximumTagLength - 1) - 1);
+Math.Min(newScratchLength, Constants.MaximumTagLength - 1));
}
}

@@ -415,7 +409,7 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
(int) literalLengthLength) + 1;
}

-nint inputRemaining = Unsafe.ByteOffset(ref input, ref inputEnd) + 1;
+nint inputRemaining = Unsafe.ByteOffset(ref input, ref inputEnd);
if (inputRemaining < literalLength)
{
Append(ref op, ref bufferEnd, in input, inputRemaining);
@@ -468,7 +462,7 @@ private uint RefillTagFromScratch(ref byte input, ref byte inputEnd, ref byte sc
{
Debug.Assert(_scratchLength > 0);

-if (Unsafe.IsAddressGreaterThan(ref input, ref inputEnd))
+if (!Unsafe.IsAddressLessThan(ref input, ref inputEnd))
{
return 0;
}
@@ -477,7 +471,7 @@ private uint RefillTagFromScratch(ref byte input, ref byte inputEnd, ref byte sc
uint entry = Constants.CharTable[scratch];
uint needed = (entry >> 11) + 1; // +1 byte for 'c'

-uint toCopy = Math.Min((uint)Unsafe.ByteOffset(ref input, ref inputEnd) + 1, needed - _scratchLength);
+uint toCopy = Math.Min((uint)Unsafe.ByteOffset(ref input, ref inputEnd), needed - _scratchLength);
Unsafe.CopyBlockUnaligned(ref Unsafe.Add(ref scratch, _scratchLength), ref input, toCopy);

_scratchLength += toCopy;
@@ -502,7 +496,7 @@ private uint RefillTagFromScratch(ref byte input, ref byte inputEnd, ref byte sc
// always have some extra bytes on the end so we don't risk buffer overruns.
private uint RefillTag(ref byte input, ref byte inputEnd, ref byte scratch)
{
-if (Unsafe.IsAddressGreaterThan(ref input, ref inputEnd))
+if (!Unsafe.IsAddressLessThan(ref input, ref inputEnd))
{
return uint.MaxValue;
}
@@ -511,7 +505,7 @@ private uint RefillTag(ref byte input, ref byte inputEnd, ref byte scratch)
uint entry = Constants.CharTable[input];
uint needed = (entry >> 11) + 1; // +1 byte for 'c'

-uint inputLength = (uint)Unsafe.ByteOffset(ref input, ref inputEnd) + 1;
+uint inputLength = (uint)Unsafe.ByteOffset(ref input, ref inputEnd);
if (inputLength < needed)
{
// Data is insufficient, copy to scratch
@@ -555,11 +549,8 @@ private uint RefillTag(ref byte input, ref byte inputEnd, ref byte scratch)
ArrayPool<byte>.Shared.Return(_lookbackBufferArray);
}

-// Always pad the lookback buffer with an extra byte that we don't use. This allows a "ref byte" reference past
-// the end of the perceived buffer that still points within the array. This is a requirement so that GC can recognize
-// the "ref byte" points within the array and adjust it if the array is moved.
-_lookbackBufferArray = ArrayPool<byte>.Shared.Rent(value.GetValueOrDefault() + 1);
-_lookbackBuffer = _lookbackBufferArray.AsMemory(0, _lookbackBufferArray.Length - 1);
+_lookbackBufferArray = ArrayPool<byte>.Shared.Rent(value.GetValueOrDefault());
+_lookbackBuffer = _lookbackBufferArray.AsMemory(0, _lookbackBufferArray.Length);
}
}
}
@@ -595,7 +586,7 @@ private void Append(ReadOnlySpan<byte> input)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
-private void Append(ref byte op, ref byte bufferEnd, in byte input, nint length)
+private static void Append(ref byte op, ref byte bufferEnd, in byte input, nint length)
{
if (length > Unsafe.ByteOffset(ref op, ref bufferEnd))
{
@@ -606,7 +597,7 @@ private void Append(ref byte op, ref byte bufferEnd, in byte input, nint length)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
-private bool TryFastAppend(ref byte op, ref byte bufferEnd, in byte input, nint available, nint length)
+private static bool TryFastAppend(ref byte op, ref byte bufferEnd, in byte input, nint available, nint length)
{
if (length <= 16 && available >= 16 + Constants.MaximumTagLength &&
Unsafe.ByteOffset(ref op, ref bufferEnd) >= (nint) 16)
Expand All @@ -619,10 +610,13 @@ private bool TryFastAppend(ref byte op, ref byte bufferEnd, in byte input, nint
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
-private void AppendFromSelf(ref byte op, ref byte buffer, ref byte bufferEnd, uint copyOffset, nint length)
+private static void AppendFromSelf(ref byte op, ref byte buffer, ref byte bufferEnd, uint copyOffset, nint length)
{
-ref byte source = ref Unsafe.Subtract(ref op, copyOffset);
-if (!Unsafe.IsAddressLessThan(ref source, ref op) || Unsafe.IsAddressLessThan(ref source, ref buffer))
+// ToInt64() ensures that this logic works correctly on x86 (with a slight perf hit on x86, though). This is because
+// nint is only 32-bit on x86, so casting uint copyOffset to an nint for the comparison can result in a negative number with some
+// forms of illegal data. This would then bypass the exception and cause unsafe memory access. Performing the comparison
+// as a long ensures we have enough bits to not lose data. On 64-bit platforms this is effectively a no-op.
+if (copyOffset == 0 || Unsafe.ByteOffset(ref buffer, ref op).ToInt64() < copyOffset)
{
ThrowHelper.ThrowInvalidDataException("Invalid copy offset");
}
@@ -632,6 +626,7 @@ private void AppendFromSelf(ref byte op, ref byte buffer, ref byte bufferEnd, ui
ThrowHelper.ThrowInvalidDataException("Data too long");
}

+ref byte source = ref Unsafe.Subtract(ref op, copyOffset);
CopyHelpers.IncrementalCopy(ref source, ref op,
ref Unsafe.Add(ref op, length), ref bufferEnd);
}
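A worked example (hypothetical values) of the x86 hazard that the new `ToInt64()` comment in `AppendFromSelf` above describes:

```csharp
using System;

// An illegal copy offset such as corrupt input could produce.
uint copyOffset = 0x8000_0001;

// nint is 32 bits in an x86 process, so this cast reinterprets the value as
// negative there; a guard of the form "distance < copyOffset" evaluated in
// nint would then never fire, bypassing the exception that rejects the data.
nint asNint = unchecked((nint)copyOffset);   // -2147483647 on x86; 2147483649 on x64

// Widening to long preserves the value on every platform, keeping the check correct.
long asLong = copyOffset;                    // always 2147483649

Console.WriteLine($"as nint: {asNint}, as long: {asLong}");
```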
