diff --git a/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs b/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs index 5f049e6944538..b2a3134ae7501 100644 --- a/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs +++ b/src/libraries/System.Private.CoreLib/src/System/IO/UnmanagedMemoryStream.cs @@ -23,14 +23,9 @@ namespace System.IO * of the UnmanagedMemoryStream. * 3) You clean up the memory when appropriate. The UnmanagedMemoryStream * currently will do NOTHING to free this memory. - * 4) All calls to Write and WriteByte may not be threadsafe currently. - * - * It may become necessary to add in some sort of - * DeallocationMode enum, specifying whether we unmap a section of memory, - * call free, run a user-provided delegate to free the memory, etc. - * We'll suggest user write a subclass of UnmanagedMemoryStream that uses - * a SafeHandle subclass to hold onto the memory. - * + * 4) This type is not thread safe. However, the implementation should prevent buffer + * overruns or returning uninitialized memory when Reads and Writes are called + * concurrently in thread unsafe manner. */ /// @@ -40,10 +35,10 @@ public class UnmanagedMemoryStream : Stream { private SafeBuffer? _buffer; private unsafe byte* _mem; - private long _length; - private long _capacity; - private long _position; - private long _offset; + private nuint _capacity; + private nuint _offset; + private nuint _length; // nuint to guarantee atomic access on 32-bit platforms + private long _position; // long to allow seeking to any location beyond the length of the stream. private FileAccess _access; private bool _isOpen; private CachedCompletedInt32Task _lastReadTask; // The last successful task returned from ReadAsync @@ -123,10 +118,10 @@ protected void Initialize(SafeBuffer buffer, long offset, long length, FileAcces } } - _offset = offset; + _offset = (nuint)offset; _buffer = buffer; - _length = length; - _capacity = length; + _length = (nuint)length; + _capacity = (nuint)length; _access = access; _isOpen = true; } @@ -171,8 +166,8 @@ protected unsafe void Initialize(byte* pointer, long length, long capacity, File _mem = pointer; _offset = 0; - _length = length; - _capacity = capacity; + _length = (nuint)length; + _capacity = (nuint)capacity; _access = access; _isOpen = true; } @@ -259,7 +254,7 @@ public override long Length get { EnsureNotClosed(); - return Interlocked.Read(ref _length); + return (long)_length; } } @@ -271,7 +266,7 @@ public long Capacity get { EnsureNotClosed(); - return _capacity; + return (long)_capacity; } } @@ -283,14 +278,14 @@ public override long Position get { if (!CanSeek) ThrowHelper.ThrowObjectDisposedException_StreamClosed(null); - return Interlocked.Read(ref _position); + return _position; } set { ArgumentOutOfRangeException.ThrowIfNegative(value); if (!CanSeek) ThrowHelper.ThrowObjectDisposedException_StreamClosed(null); - Interlocked.Exchange(ref _position, value); + _position = value; } } @@ -308,11 +303,10 @@ public override long Position EnsureNotClosed(); // Use a temp to avoid a race - long pos = Interlocked.Read(ref _position); - if (pos > _capacity) + long pos = _position; + if (pos > (long)_capacity) throw new IndexOutOfRangeException(SR.IndexOutOfRange_UMSPosition); - byte* ptr = _mem + pos; - return ptr; + return _mem + pos; } set { @@ -327,7 +321,7 @@ public override long Position if (newPosition < 0) throw new ArgumentOutOfRangeException(nameof(value), SR.ArgumentOutOfRange_UnmanagedMemStreamLength); - Interlocked.Exchange(ref _position, newPosition); + _position = newPosition; } } @@ -367,8 +361,13 @@ internal int ReadCore(Span buffer) // Use a local variable to avoid a race where another thread // changes our position after we decide we can read some bytes. - long pos = Interlocked.Read(ref _position); - long len = Interlocked.Read(ref _length); + long pos = _position; + + // Use a volatile read to prevent reading of the uninitialized memory. This volatile read + // and matching volatile write that set _length avoids reordering of NativeMemory.Clear + // operations with reading of the buffer below. + long len = (long)Volatile.Read(ref _length); + long n = Math.Min(len - pos, buffer.Length); if (n <= 0) { @@ -407,7 +406,7 @@ internal int ReadCore(Span buffer) } } - Interlocked.Exchange(ref _position, pos + n); + _position = pos + n; return nInt; } @@ -484,11 +483,16 @@ public override int ReadByte() EnsureNotClosed(); EnsureReadable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + + // Use a volatile read to prevent reading of the uninitialized memory. This volatile read + // and matching volatile write that set _length avoids reordering of NativeMemory.Clear + // operations with reading of the buffer below. + long len = (long)Volatile.Read(ref _length); + if (pos >= len) return -1; - Interlocked.Exchange(ref _position, pos + 1); + _position = pos + 1; int result; if (_buffer != null) { @@ -529,35 +533,33 @@ public override long Seek(long offset, SeekOrigin loc) { EnsureNotClosed(); + long newPosition; switch (loc) { case SeekOrigin.Begin: - if (offset < 0) + newPosition = offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, offset); break; case SeekOrigin.Current: - long pos = Interlocked.Read(ref _position); - if (offset + pos < 0) + newPosition = _position + offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, offset + pos); break; case SeekOrigin.End: - long len = Interlocked.Read(ref _length); - if (len + offset < 0) + newPosition = (long)_length + offset; + if (newPosition < 0) throw new IOException(SR.IO_SeekBeforeBegin); - Interlocked.Exchange(ref _position, len + offset); break; default: throw new ArgumentException(SR.Argument_InvalidSeekOrigin); } - long finalPos = Interlocked.Read(ref _position); - Debug.Assert(finalPos >= 0, "_position >= 0"); - return finalPos; + _position = newPosition; + return newPosition; } /// @@ -573,11 +575,10 @@ public override void SetLength(long value) EnsureNotClosed(); EnsureWriteable(); - if (value > _capacity) + if (value > (long)_capacity) throw new IOException(SR.IO_FixedCapacity); - long pos = Interlocked.Read(ref _position); - long len = Interlocked.Read(ref _length); + long len = (long)_length; if (value > len) { unsafe @@ -585,10 +586,11 @@ public override void SetLength(long value) NativeMemory.Clear(_mem + len, (nuint)(value - len)); } } - Interlocked.Exchange(ref _length, value); - if (pos > value) + Volatile.Write(ref _length, (nuint)value); // volatile to prevent reading of uninitialized memory + + if (_position > value) { - Interlocked.Exchange(ref _position, value); + _position = value; } } @@ -625,8 +627,8 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) EnsureNotClosed(); EnsureWriteable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + long len = (long)_length; long n = pos + buffer.Length; // Check for overflow if (n < 0) @@ -634,7 +636,7 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) throw new IOException(SR.IO_StreamTooLong); } - if (n > _capacity) + if (n > (long)_capacity) { throw new NotSupportedException(SR.IO_FixedCapacity); } @@ -648,16 +650,16 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) NativeMemory.Clear(_mem + len, (nuint)(pos - len)); } - // set length after zeroing memory to avoid race condition of accessing unzeroed memory + // set length after zeroing memory to avoid race condition of accessing uninitialized memory if (n > len) { - Interlocked.Exchange(ref _length, n); + Volatile.Write(ref _length, (nuint)n); // volatile to prevent reading of uninitialized memory } } if (_buffer != null) { - long bytesLeft = _capacity - pos; + long bytesLeft = (long)_capacity - pos; if (bytesLeft < buffer.Length) { throw new ArgumentException(SR.Arg_BufferTooSmall); @@ -682,8 +684,7 @@ internal unsafe void WriteCore(ReadOnlySpan buffer) Buffer.Memmove(ref *(_mem + pos), ref MemoryMarshal.GetReference(buffer), (nuint)buffer.Length); } - Interlocked.Exchange(ref _position, n); - return; + _position = n; } /// @@ -754,8 +755,8 @@ public override void WriteByte(byte value) EnsureNotClosed(); EnsureWriteable(); - long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition - long len = Interlocked.Read(ref _length); + long pos = _position; // Use a local to avoid a race condition + long len = (long)_length; long n = pos + 1; if (pos >= len) { @@ -763,7 +764,7 @@ public override void WriteByte(byte value) if (n < 0) throw new IOException(SR.IO_StreamTooLong); - if (n > _capacity) + if (n > (long)_capacity) throw new NotSupportedException(SR.IO_FixedCapacity); // Check to see whether we are now expanding the stream and must @@ -779,8 +780,7 @@ public override void WriteByte(byte value) } } - // set length after zeroing memory to avoid race condition of accessing unzeroed memory - Interlocked.Exchange(ref _length, n); + Volatile.Write(ref _length, (nuint)n); // volatile to prevent reading of uninitialized memory } } @@ -810,7 +810,7 @@ public override void WriteByte(byte value) _mem[pos] = value; } } - Interlocked.Exchange(ref _position, n); + _position = n; } } }