
Commit d7ac526

Allow ref byte to point just past the end of spans (#73)
Motivation
----------
I was under the impression that `ref byte` should never point past the end of a `Span<byte>` or `ReadOnlySpan<byte>`, because then the GC couldn't recognize the pointer and adjust it when it moves the memory. However, there is a specific exception that allows a `ref byte` to point precisely one byte past the end and still be recognized, which enables pointer-arithmetic scenarios like the ones in this algorithm.

Modifications
-------------
Where there is more complex logic (compared to the reference C++ implementation) that keeps buffer end pointers on the last byte of the buffer, simplify this logic and point one byte past the end of the buffer instead. Also ensure that comparison operations which verify a certain remaining length don't involve an intermediate pointer that moves off the beginning of the buffer.

Results
-------
We're now much closer to the reference C++ implementation but still have memory safety.
1 parent e3c9834 commit d7ac526
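For readers unfamiliar with the rule the motivation relies on, here is a minimal, self-contained sketch of the idiom, deliberately avoiding Snappier's internal helpers: a managed `ref byte` may be formed exactly one byte past the end of the span it came from, used only for comparisons and offset arithmetic, and never dereferenced. Per the exception described above, the GC still recognizes such a reference as belonging to the array.

    using System;
    using System.Runtime.CompilerServices;
    using System.Runtime.InteropServices;

    static class OnePastEndSketch
    {
        // Counts leading bytes equal to 'value'. 'end' points one byte past the
        // last readable byte; it is compared against but never dereferenced.
        static int CountLeading(ReadOnlySpan<byte> span, byte value)
        {
            ref byte start = ref MemoryMarshal.GetReference(span);

            // Legal to form and compare, illegal to read or write through.
            ref byte end = ref Unsafe.Add(ref start, span.Length);

            ref byte p = ref start;
            while (Unsafe.IsAddressLessThan(ref p, ref end) && p == value)
            {
                p = ref Unsafe.Add(ref p, 1);
            }

            // ByteOffset(origin, target) == target - origin, so an exclusive
            // limit needs no "+ 1" fixups anywhere.
            return (int)Unsafe.ByteOffset(ref start, ref p);
        }

        static void Main() =>
            Console.WriteLine(CountLeading(new byte[] { 7, 7, 7, 1, 2 }, 7)); // prints 3
    }

This is the shape the diffs below move toward: `inputEnd`, `s2Limit`, and the other limit refs become exclusive limits, and the `- 1` / `+ 1` corrections disappear.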

4 files changed (+45 −56)

Diff for: Snappier.Benchmarks/FindMatchLength.cs

+1 −1

@@ -64,7 +64,7 @@ public void GlobalSetup()

     ref byte s1 = ref _array[0];
     ref byte s2 = ref Unsafe.Add(ref s1, 12);
-    ref byte s2Limit = ref Unsafe.Add(ref s1, _array.Length - 1);
+    ref byte s2Limit = ref Unsafe.Add(ref s1, _array.Length);

     return SnappyCompressor.FindMatchLength(ref s1, ref s2, ref s2Limit, ref data);
 }

Diff for: Snappier.Tests/Internal/SnappyCompressorTests.cs

+1 −1

@@ -92,7 +92,7 @@ public void FindMatchLength(int expectedResult, string s1String, string s2String
     ref byte s2 = ref Unsafe.Add(ref s1, s1String.Length);

     var result =
-        SnappyCompressor.FindMatchLength(ref s1, ref s2, ref Unsafe.Add(ref s2, length - 1), ref data);
+        SnappyCompressor.FindMatchLength(ref s1, ref s2, ref Unsafe.Add(ref s2, length), ref data);

     Assert.Equal(result.matchLength < 8, result.matchLengthLessThan8);
     Assert.Equal(expectedResult, result.matchLength);
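Both call-site changes above reflect the same convention shift seen throughout the commit: `s2Limit` is now an exclusive limit (one byte past the last readable byte of `s2`), so the callee may match at most `s2_limit - s2` bytes. As a plain, safe restatement of that contract (an illustrative reference version, not Snappier's actual implementation):

    using System;

    static class MatchLengthContract
    {
        // Largest n such that s1[0, n) == s2[0, n) and n <= s2.Length, where
        // s2.Length stands in for (s2_limit - s2) under the exclusive-limit
        // convention. Assumes s1 has at least s2.Length readable bytes, as the
        // real routine's contract does.
        static int FindMatchLength(ReadOnlySpan<byte> s1, ReadOnlySpan<byte> s2)
        {
            int n = 0;
            while (n < s2.Length && s1[n] == s2[n])
            {
                n++;
            }
            return n;
        }

        static void Main() =>
            Console.WriteLine(FindMatchLength("abcdef"u8, "abcdff"u8)); // prints 4
    }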

Diff for: Snappier/Internal/SnappyCompressor.cs

+17 −23

@@ -125,16 +125,15 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
     uint mask = (uint)(2 * (tableSpan.Length - 1));

     ref byte inputStart = ref Unsafe.AsRef(in input[0]);
-    // Last byte of the input, not one byte past the end, to avoid issues on GC moves
-    ref byte inputEnd = ref Unsafe.Add(ref inputStart, input.Length - 1);
+    ref byte inputEnd = ref Unsafe.Add(ref inputStart, input.Length);
     ref byte ip = ref inputStart;

     ref byte op = ref output[0];
     ref ushort table = ref tableSpan[0];

     if (input.Length >= Constants.InputMarginBytes)
     {
-        ref byte ipLimit = ref Unsafe.Subtract(ref inputEnd, Constants.InputMarginBytes - 1);
+        ref byte ipLimit = ref Unsafe.Subtract(ref inputEnd, Constants.InputMarginBytes);

         for (uint preload = Helpers.UnsafeReadUInt32(ref Unsafe.Add(ref ip, 1));;)
         {

@@ -288,7 +287,7 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,
     // Step 2: A 4-byte match has been found. We'll later see if more
     // than 4 bytes match. But, prior to the match, input
     // bytes [next_emit, ip) are unmatched. Emit them as "literal bytes."
-    Debug.Assert(!Unsafe.IsAddressGreaterThan(ref Unsafe.Add(ref nextEmit, 16), ref Unsafe.Add(ref inputEnd, 1)));
+    Debug.Assert(!Unsafe.IsAddressGreaterThan(ref Unsafe.Add(ref nextEmit, 16), ref inputEnd));
     op = ref EmitLiteralFast(ref op, ref nextEmit, (uint) Unsafe.ByteOffset(ref nextEmit, ref ip));

     // Step 3: Call EmitCopy, and then see if another EmitCopy could

@@ -350,9 +349,9 @@ private static int CompressFragment(ReadOnlySpan<byte> input, Span<byte> output,

     emit_remainder:
     // Emit the remaining bytes as a literal
-    if (!Unsafe.IsAddressGreaterThan(ref ip, ref inputEnd))
+    if (Unsafe.IsAddressLessThan(ref ip, ref inputEnd))
     {
-        op = ref EmitLiteralSlow(ref op, ref ip, (uint) Unsafe.ByteOffset(ref ip, ref inputEnd) + 1);
+        op = ref EmitLiteralSlow(ref op, ref ip, (uint) Unsafe.ByteOffset(ref ip, ref inputEnd));
     }

     return (int) Unsafe.ByteOffset(ref output[0], ref op);

@@ -490,28 +489,23 @@ private static ref byte EmitCopyLenGreaterThanOrEqualTo12(ref byte op, long offs
     /// Find the largest n such that
     ///
     /// s1[0,n-1] == s2[0,n-1]
-    /// and n &lt;= (s2_limit - s2 + 1).
+    /// and n &lt;= (s2_limit - s2).
     ///
     /// Return (n, n &lt; 8).
     /// Reads up to and including *s2_limit but not beyond.
-    /// Does not read *(s1 + (s2_limit - s2 + 1)) or beyond.
-    /// Requires that s2_limit+1 &gt;= s2.
+    /// Does not read *(s1 + (s2_limit - s2)) or beyond.
+    /// Requires that s2_limit &gt;= s2.
     ///
     /// In addition populate *data with the next 5 bytes from the end of the match.
     /// This is only done if 8 bytes are available (s2_limit - s2 &gt;= 8). The point is
     /// that on some arch's this can be done faster in this routine than subsequent
     /// loading from s2 + n.
     /// </summary>
-    /// <remarks>
-    /// The reference implementation has s2Limit as one byte past the end of the input,
-    /// but this implementation has it at the end of the input. This ensures that it always
-    /// points within the array in case GC moves the array.
-    /// </remarks>
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     internal static (int matchLength, bool matchLengthLessThan8) FindMatchLength(
         ref byte s1, ref byte s2, ref byte s2Limit, ref ulong data)
     {
-        Debug.Assert(!Unsafe.IsAddressLessThan(ref Unsafe.Add(ref s2Limit, 1), ref s2));
+        Debug.Assert(!Unsafe.IsAddressLessThan(ref s2Limit, ref s2));

         if (BitConverter.IsLittleEndian && IntPtr.Size == 8)
         {

@@ -521,14 +515,14 @@ internal static (int matchLength, bool matchLengthLessThan8) FindMatchLength(

     int matched = 0;

-    while (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 3))
+    while (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)4
         && Helpers.UnsafeReadUInt32(ref s2) == Helpers.UnsafeReadUInt32(ref Unsafe.Add(ref s1, matched)))
     {
         s2 = ref Unsafe.Add(ref s2, 4);
         matched += 4;
     }

-    if (BitConverter.IsLittleEndian && !Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 3)))
+    if (BitConverter.IsLittleEndian && Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)4)
     {
         uint x = Helpers.UnsafeReadUInt32(ref s2) ^ Helpers.UnsafeReadUInt32(ref Unsafe.Add(ref s1, matched));
         int matchingBits = Helpers.FindLsbSetNonZero(x);

@@ -537,14 +531,14 @@ internal static (int matchLength, bool matchLengthLessThan8) FindMatchLength(
     }
     else
     {
-        while (!Unsafe.IsAddressGreaterThan(ref s2, ref s2Limit) && Unsafe.Add(ref s1, matched) == s2)
+        while (Unsafe.IsAddressLessThan(ref s2, ref s2Limit) && Unsafe.Add(ref s1, matched) == s2)
         {
            s2 = ref Unsafe.Add(ref s2, 1);
            ++matched;
         }
     }

-    if (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 7)))
+    if (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)8)
     {
         data = Helpers.UnsafeReadUInt64(ref s2);
     }

@@ -562,7 +556,7 @@ private static (int matchLength, bool matchLengthLessThan8) FindMatchLengthX64(
     // immediately. As an optimization though, it is useful. It creates some not
     // uncommon code paths that determine, without extra effort, whether the match
     // length is less than 8.
-    if (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 15)))
+    if (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)16)
     {
         ulong a1 = Helpers.UnsafeReadUInt64(ref s1);
         ulong a2 = Helpers.UnsafeReadUInt64(ref s2);

@@ -590,7 +584,7 @@ private static (int matchLength, bool matchLengthLessThan8) FindMatchLengthX64(
     // time until we find a 64-bit block that doesn't match; then we find
     // the first non-matching bit and use that to calculate the total
     // length of the match.
-    while (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 15)))
+    while (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)16)
     {
         ulong a1 = Helpers.UnsafeReadUInt64(ref Unsafe.Add(ref s1, matched));
         ulong a2 = Helpers.UnsafeReadUInt64(ref s2);

@@ -615,7 +609,7 @@ private static (int matchLength, bool matchLengthLessThan8) FindMatchLengthX64(
         }
     }

-    while (!Unsafe.IsAddressGreaterThan(ref s2, ref s2Limit))
+    while (Unsafe.IsAddressLessThan(ref s2, ref s2Limit))
     {
         if (Unsafe.Add(ref s1, matched) == s2)
         {

@@ -624,7 +618,7 @@ private static (int matchLength, bool matchLengthLessThan8) FindMatchLengthX64(
         }
         else
         {
-            if (!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, 7)))
+            if (Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)8)
             {
                 data = Helpers.UnsafeReadUInt64(ref s2);
             }
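The recurring rewrite in this file, where `!Unsafe.IsAddressGreaterThan(ref s2, ref Unsafe.Subtract(ref s2Limit, N - 1))` becomes `Unsafe.ByteOffset(ref s2, ref s2Limit) >= (nint)N`, is the second half of the commit message: a "do at least N bytes remain?" check should not materialize a reference that could land before the start of the buffer when the buffer is shorter than N. A standalone sketch of the safer form (the helper name is made up for the example):

    using System;
    using System.Runtime.CompilerServices;

    static class RemainingBytesSketch
    {
        // "Do at least 'needed' bytes remain in [cursor, limit)?"
        // Only 'cursor' and the exclusive 'limit' are ever formed; the remaining
        // length is computed as a signed byte offset and compared to a constant.
        static bool HasAtLeast(ref byte cursor, ref byte limit, nint needed) =>
            Unsafe.ByteOffset(ref cursor, ref limit) >= needed;

        static void Main()
        {
            byte[] buffer = new byte[4]; // shorter than an 8-byte read
            ref byte start = ref buffer[0];
            ref byte limit = ref Unsafe.Add(ref start, buffer.Length);

            Console.WriteLine(HasAtLeast(ref start, ref limit, 8)); // False
            Console.WriteLine(HasAtLeast(ref start, ref limit, 4)); // True
        }
    }

By contrast, `Unsafe.Subtract(ref limit, 7)` on this 4-byte buffer would produce a reference three bytes before `buffer[0]`, which is exactly the kind of intermediate pointer the commit removes.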

Diff for: Snappier/Internal/SnappyDecompressor.cs

+26 −31

@@ -186,17 +186,11 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
 unchecked
 {
     ref byte input = ref Unsafe.AsRef(in inputSpan[0]);
-
-    // The reference Snappy implementation uses inputEnd as a pointer one byte past the end of the buffer.
-    // However, this is not safe when using ref locals. The ref must point to somewhere within the array
-    // so that GC can adjust the ref if the memory is moved.
-    ref byte inputEnd = ref Unsafe.Add(ref input, inputSpan.Length - 1);
+    ref byte inputEnd = ref Unsafe.Add(ref input, inputSpan.Length);

     // Track the point in the input before which input is guaranteed to have at least Constants.MaxTagLength bytes left
-    ref byte inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd, Math.Min(inputSpan.Length, Constants.MaximumTagLength - 1) - 1);
+    ref byte inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd, Math.Min(inputSpan.Length, Constants.MaximumTagLength - 1));

-    // We always allocate buffer with at least one extra byte on the end, so bufferEnd doesn't have the same
-    // restrictions as inputEnd.
     ref byte buffer = ref _lookbackBuffer.Span[0];
     ref byte bufferEnd = ref Unsafe.Add(ref buffer, _lookbackBuffer.Length);
     ref byte op = ref Unsafe.Add(ref buffer, _lookbackPosition);

@@ -239,9 +233,9 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
 {
     // Data has been moved to the scratch buffer
     input = ref scratch;
-    inputEnd = ref Unsafe.Add(ref input, newScratchLength - 1);
+    inputEnd = ref Unsafe.Add(ref input, newScratchLength);
     inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd,
-        Math.Min(newScratchLength, Constants.MaximumTagLength - 1) - 1);
+        Math.Min(newScratchLength, Constants.MaximumTagLength - 1));
 }
 }

@@ -256,7 +250,7 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
 {
     nint literalLength = unchecked((c >> 2) + 1);

-    if (TryFastAppend(ref op, ref bufferEnd, in input, Unsafe.ByteOffset(ref input, ref inputEnd) + 1, literalLength))
+    if (TryFastAppend(ref op, ref bufferEnd, in input, Unsafe.ByteOffset(ref input, ref inputEnd), literalLength))
     {
         Debug.Assert(literalLength < 61);
         op = ref Unsafe.Add(ref op, literalLength);

@@ -280,7 +274,7 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
     input = ref Unsafe.Add(ref input, literalLengthLength);
 }

-nint inputRemaining = Unsafe.ByteOffset(ref input, ref inputEnd) + 1;
+nint inputRemaining = Unsafe.ByteOffset(ref input, ref inputEnd);
 if (inputRemaining < literalLength)
 {
     Append(ref op, ref bufferEnd, in input, inputRemaining);

@@ -306,9 +300,9 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
 {
     // Data has been moved to the scratch buffer
     input = ref scratch;
-    inputEnd = ref Unsafe.Add(ref input, newScratchLength - 1);
+    inputEnd = ref Unsafe.Add(ref input, newScratchLength);
     inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd,
-        Math.Min(newScratchLength, Constants.MaximumTagLength - 1) - 1);
+        Math.Min(newScratchLength, Constants.MaximumTagLength - 1));

 }
 }

@@ -367,9 +361,9 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
 {
     // Data has been moved to the scratch buffer
     input = ref scratch;
-    inputEnd = ref Unsafe.Add(ref input, newScratchLength - 1);
+    inputEnd = ref Unsafe.Add(ref input, newScratchLength);
     inputLimitMinMaxTagLength = ref Unsafe.Subtract(ref inputEnd,
-        Math.Min(newScratchLength, Constants.MaximumTagLength - 1) - 1);
+        Math.Min(newScratchLength, Constants.MaximumTagLength - 1));
 }
 }

@@ -415,7 +409,7 @@ internal void DecompressAllTags(ReadOnlySpan<byte> inputSpan)
         (int) literalLengthLength) + 1;
 }

-nint inputRemaining = Unsafe.ByteOffset(ref input, ref inputEnd) + 1;
+nint inputRemaining = Unsafe.ByteOffset(ref input, ref inputEnd);
 if (inputRemaining < literalLength)
 {
     Append(ref op, ref bufferEnd, in input, inputRemaining);

@@ -468,7 +462,7 @@ private uint RefillTagFromScratch(ref byte input, ref byte inputEnd, ref byte sc
 {
     Debug.Assert(_scratchLength > 0);

-    if (Unsafe.IsAddressGreaterThan(ref input, ref inputEnd))
+    if (!Unsafe.IsAddressLessThan(ref input, ref inputEnd))
     {
         return 0;
     }

@@ -477,7 +471,7 @@ private uint RefillTagFromScratch(ref byte input, ref byte inputEnd, ref byte sc
     uint entry = Constants.CharTable[scratch];
     uint needed = (entry >> 11) + 1; // +1 byte for 'c'

-    uint toCopy = Math.Min((uint)Unsafe.ByteOffset(ref input, ref inputEnd) + 1, needed - _scratchLength);
+    uint toCopy = Math.Min((uint)Unsafe.ByteOffset(ref input, ref inputEnd), needed - _scratchLength);
     Unsafe.CopyBlockUnaligned(ref Unsafe.Add(ref scratch, _scratchLength), ref input, toCopy);

     _scratchLength += toCopy;

@@ -502,7 +496,7 @@ private uint RefillTag(ref byte input, ref byte inputEnd, ref byte scratch)
 // always have some extra bytes on the end so we don't risk buffer overruns.
 private uint RefillTag(ref byte input, ref byte inputEnd, ref byte scratch)
 {
-    if (Unsafe.IsAddressGreaterThan(ref input, ref inputEnd))
+    if (!Unsafe.IsAddressLessThan(ref input, ref inputEnd))
     {
         return uint.MaxValue;
     }

@@ -511,7 +505,7 @@ private uint RefillTag(ref byte input, ref byte inputEnd, ref byte scratch)
     uint entry = Constants.CharTable[input];
     uint needed = (entry >> 11) + 1; // +1 byte for 'c'

-    uint inputLength = (uint)Unsafe.ByteOffset(ref input, ref inputEnd) + 1;
+    uint inputLength = (uint)Unsafe.ByteOffset(ref input, ref inputEnd);
     if (inputLength < needed)
     {
         // Data is insufficient, copy to scratch

@@ -555,11 +549,8 @@ private int? ExpectedLength
     ArrayPool<byte>.Shared.Return(_lookbackBufferArray);
 }

-// Always pad the lookback buffer with an extra byte that we don't use. This allows a "ref byte" reference past
-// the end of the perceived buffer that still points within the array. This is a requirement so that GC can recognize
-// the "ref byte" points within the array and adjust it if the array is moved.
-_lookbackBufferArray = ArrayPool<byte>.Shared.Rent(value.GetValueOrDefault() + 1);
-_lookbackBuffer = _lookbackBufferArray.AsMemory(0, _lookbackBufferArray.Length - 1);
+_lookbackBufferArray = ArrayPool<byte>.Shared.Rent(value.GetValueOrDefault());
+_lookbackBuffer = _lookbackBufferArray.AsMemory(0, _lookbackBufferArray.Length);
 }
 }
 }

@@ -595,7 +586,7 @@ private void Append(ReadOnlySpan<byte> input)
 }

 [MethodImpl(MethodImplOptions.AggressiveInlining)]
-private void Append(ref byte op, ref byte bufferEnd, in byte input, nint length)
+private static void Append(ref byte op, ref byte bufferEnd, in byte input, nint length)
 {
     if (length > Unsafe.ByteOffset(ref op, ref bufferEnd))
     {

@@ -606,7 +597,7 @@ private void Append(ref byte op, ref byte bufferEnd, in byte input, nint length)
 }

 [MethodImpl(MethodImplOptions.AggressiveInlining)]
-private bool TryFastAppend(ref byte op, ref byte bufferEnd, in byte input, nint available, nint length)
+private static bool TryFastAppend(ref byte op, ref byte bufferEnd, in byte input, nint available, nint length)
 {
     if (length <= 16 && available >= 16 + Constants.MaximumTagLength &&
         Unsafe.ByteOffset(ref op, ref bufferEnd) >= (nint) 16)

@@ -619,10 +610,13 @@ private bool TryFastAppend(ref byte op, ref byte bufferEnd, in byte input, nint
 }

 [MethodImpl(MethodImplOptions.AggressiveInlining)]
-private void AppendFromSelf(ref byte op, ref byte buffer, ref byte bufferEnd, uint copyOffset, nint length)
+private static void AppendFromSelf(ref byte op, ref byte buffer, ref byte bufferEnd, uint copyOffset, nint length)
 {
-    ref byte source = ref Unsafe.Subtract(ref op, copyOffset);
-    if (!Unsafe.IsAddressLessThan(ref source, ref op) || Unsafe.IsAddressLessThan(ref source, ref buffer))
+    // ToInt64() ensures that this logic works correctly on x86 (with a slight perf hit on x86, though). This is because
+    // nint is only 32-bit on x86, so casting uint copyOffset to an nint for the comparison can result in a negative number with some
+    // forms of illegal data. This would then bypass the exception and cause unsafe memory access. Performing the comparison
+    // as a long ensures we have enough bits to not lose data. On 64-bit platforms this is effectively a no-op.
+    if (copyOffset == 0 || Unsafe.ByteOffset(ref buffer, ref op).ToInt64() < copyOffset)
     {
         ThrowHelper.ThrowInvalidDataException("Invalid copy offset");
     }

@@ -632,6 +626,7 @@ private void AppendFromSelf(ref byte op, ref byte buffer, ref byte bufferEnd, ui
         ThrowHelper.ThrowInvalidDataException("Data too long");
     }

+    ref byte source = ref Unsafe.Subtract(ref op, copyOffset);
     CopyHelpers.IncrementalCopy(ref source, ref op,
         ref Unsafe.Add(ref op, length), ref bufferEnd);
 }
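The new `AppendFromSelf` check is worth spelling out, since it is the one place where the width of `nint` matters. Malformed input can supply a `copyOffset` near `uint.MaxValue`; in a 32-bit process, narrowing that to `nint` can flip it negative and slip past a signed comparison, so the comparison is done in 64 bits instead. A self-contained sketch of the same check (hypothetical helper, throwing a plain `InvalidDataException` rather than Snappier's `ThrowHelper`):

    using System;
    using System.IO;
    using System.Runtime.CompilerServices;

    static class CopyOffsetCheckSketch
    {
        // Validates a back-reference before copying: the source (op - copyOffset)
        // must lie inside [buffer, op). The byte count already produced is taken
        // as a 64-bit value so a huge uint copyOffset cannot wrap to a negative
        // nint on 32-bit platforms and bypass the check.
        static void CheckCopyOffset(ref byte buffer, ref byte op, uint copyOffset)
        {
            long bytesProduced = Unsafe.ByteOffset(ref buffer, ref op).ToInt64();
            if (copyOffset == 0 || bytesProduced < copyOffset)
            {
                throw new InvalidDataException("Invalid copy offset");
            }
        }

        static void Main()
        {
            byte[] buffer = new byte[64];
            ref byte start = ref buffer[0];
            ref byte op = ref Unsafe.Add(ref start, 10); // 10 bytes written so far

            CheckCopyOffset(ref start, ref op, 4);      // fine: source lies within the output
            try
            {
                CheckCopyOffset(ref start, ref op, 11); // source would precede the buffer
            }
            catch (InvalidDataException e)
            {
                Console.WriteLine(e.Message);           // "Invalid copy offset"
            }
        }
    }

Note the diff also moves `ref byte source = ref Unsafe.Subtract(ref op, copyOffset);` to after the validation, so the out-of-range reference is never formed for bad input.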
