CryptoStream Memory-based ReadAsync/WriteAsync overrides #47207

Merged
merged 14 commits into from Feb 26, 2021
Changes from 3 commits
@@ -85,11 +85,13 @@ public partial class CryptoStream : System.IO.Stream, System.IDisposable
public System.Threading.Tasks.ValueTask FlushFinalBlockAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override int Read(byte[] buffer, int offset, int count) { throw null; }
public override System.Threading.Tasks.Task<int> ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.ValueTask<int> ReadAsync(System.Memory<byte> buffer, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public override int ReadByte() { throw null; }
public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; }
public override void SetLength(long value) { }
public override void Write(byte[] buffer, int offset, int count) { }
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; }
public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory<byte> buffer, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public override void WriteByte(byte value) { }
}
public enum CryptoStreamMode
@@ -5,6 +5,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

@@ -202,7 +203,7 @@ public override void SetLength(long value)
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
CheckReadArguments(buffer, offset, count);
return ReadAsyncInternal(buffer, offset, count, cancellationToken);
return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
Member

Possible breaking change: if somebody subclasses CryptoStream and overrides ReadAsync(Memory<byte>, ...) as a wrapper around ReadAsync(byte[], ...), the code will now stack overflow. I don't know if anybody is likely to have done this in practice. But it's a potential hazard of having one virtual method begin to dispatch to a different already-existing virtual method.

@bartonjs is there a pattern for handling this?

Member

If we're concerned about that, both overloads of ReadAsync can delegate to a non-virtual ReadAsyncCore.

I've never seen a type derived from CryptoStream, though. I'm sure someone somewhere has done it, but I don't think we need to be too concerned (famous last words). If we were actually concerned, we'd potentially want to take it a step further and have the ReadAsync memory overload check whether the type was derived or not, delegating to the base implementation if it was, in case someone had overridden ReadAsync(byte[], ...) to do something special, in which case it would be a breaking change for ReadAsync(Memory, ...) to not use it.
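
To make the suggestion above concrete, here is a minimal sketch of the "non-virtual core" pattern. `BaseReader` and `DerivedReader` are hypothetical stand-ins, not the real `CryptoStream`, and the sketch uses synchronous `Read` overloads to stay short; the same shape would presumably apply to the `ReadAsync` pair.

```csharp
using System;

// Hypothetical stand-in for a stream-like type (an assumption for illustration).
public class BaseReader
{
    private readonly byte[] _data = { 1, 2, 3, 4 };
    private int _position;

    // Both public virtual overloads forward to the same private, non-virtual
    // core, so neither overload reaches the other through virtual dispatch.
    public virtual int Read(byte[] buffer, int offset, int count) =>
        ReadCore(buffer.AsSpan(offset, count));

    public virtual int Read(Span<byte> buffer) => ReadCore(buffer);

    private int ReadCore(Span<byte> destination)
    {
        int n = Math.Min(destination.Length, _data.Length - _position);
        _data.AsSpan(_position, n).CopyTo(destination);
        _position += n;
        return n;
    }
}

// A derived type that implements the span overload as a wrapper around the
// array overload. If BaseReader.Read(byte[], ...) called the virtual
// Read(Span<byte>) instead of ReadCore, this override would recurse forever.
public class DerivedReader : BaseReader
{
    public override int Read(Span<byte> buffer)
    {
        byte[] temp = new byte[buffer.Length];
        int n = Read(temp, 0, temp.Length); // binds to the array overload
        temp.AsSpan(0, n).CopyTo(buffer);
        return n;
    }
}
```

The trade-off raised in the same comment still applies: with a shared core, a subclass that customizes only the array overload will not see its override used by the memory overload.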

Contributor Author

Would we want to use reflection to see whether the derived type has actually overridden ReadAsync(byte[], ...), so we can possibly skip the memory copy?

Member

> Would we want to do reflection

The answer to that question is pretty much always "no" 😄

}

public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) =>
@@ -211,7 +212,8 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
public override int EndRead(IAsyncResult asyncResult) =>
TaskToApm.End<int>(asyncResult);

private async Task<int> ReadAsyncInternal(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
/// <inheritdoc/>
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
// To avoid a race with a stream's position pointer & generating race
// conditions with internal buffer indexes in our own streams that
@@ -222,7 +224,7 @@ private async Task<int> ReadAsyncInternal(byte[] buffer, int offset, int count,
await AsyncActiveSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await ReadAsyncCore(buffer, offset, count, cancellationToken, useAsync: true).ConfigureAwait(false);
return await ReadAsyncCore(buffer, cancellationToken, useAsync: true).ConfigureAwait(false);
}
finally
{
@@ -268,7 +270,10 @@ public override void WriteByte(byte value)
public override int Read(byte[] buffer, int offset, int count)
{
CheckReadArguments(buffer, offset, count);
return ReadAsyncCore(buffer, offset, count, default(CancellationToken), useAsync: false).GetAwaiter().GetResult();
var completedValueTask = ReadAsyncCore(buffer.AsMemory(offset, count), default(CancellationToken), useAsync: false);
NewellClark marked this conversation as resolved.
Debug.Assert(completedValueTask.IsCompleted);

return completedValueTask.GetAwaiter().GetResult();
}

private void CheckReadArguments(byte[] buffer, int offset, int count)
@@ -278,22 +283,22 @@ private void CheckReadArguments(byte[] buffer, int offset, int count)
throw new NotSupportedException(SR.NotSupported_UnreadableStream);
}

private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken, bool useAsync)
private async ValueTask<int> ReadAsyncCore(Memory<byte> buffer, CancellationToken cancellationToken, bool useAsync)
{
// read <= count bytes from the input stream, transforming as we go.
// Basic idea: first we deliver any bytes we already have in the
// _OutputBuffer, because we know they're good. Then, if asked to deliver
// more bytes, we read & transform a block at a time until either there are
// no bytes ready or we've delivered enough.
int bytesToDeliver = count;
int currentOutputIndex = offset;
int bytesToDeliver = buffer.Length;
int currentOutputIndex = 0;
Debug.Assert(_outputBuffer != null);
if (_outputBufferIndex != 0)
{
// we have some already-transformed bytes in the output buffer
if (_outputBufferIndex <= count)
if (_outputBufferIndex <= buffer.Length)
{
Buffer.BlockCopy(_outputBuffer, 0, buffer, offset, _outputBufferIndex);
_outputBuffer.AsSpan(0, _outputBufferIndex).CopyTo(buffer.Span);
bytesToDeliver -= _outputBufferIndex;
currentOutputIndex += _outputBufferIndex;
int toClear = _outputBuffer.Length - _outputBufferIndex;
@@ -302,14 +307,14 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc
}
else
{
Buffer.BlockCopy(_outputBuffer, 0, buffer, offset, count);
Buffer.BlockCopy(_outputBuffer, count, _outputBuffer, 0, _outputBufferIndex - count);
_outputBufferIndex -= count;
_outputBuffer.AsSpan(0, buffer.Length).CopyTo(buffer.Span);
Buffer.BlockCopy(_outputBuffer, buffer.Length, _outputBuffer, 0, _outputBufferIndex - buffer.Length);
_outputBufferIndex -= buffer.Length;

int toClear = _outputBuffer.Length - _outputBufferIndex;
CryptographicOperations.ZeroMemory(new Span<byte>(_outputBuffer, _outputBufferIndex, toClear));

return (count);
return buffer.Length;
}
}
// _finalBlockTransformed == true implies we're at the end of the input stream
@@ -319,7 +324,7 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc
// eventually, we'll just always return 0 here because there's no more to read
if (_finalBlockTransformed)
{
return (count - bytesToDeliver);
return buffer.Length - bytesToDeliver;
}
// ok, now loop until we've delivered enough or there's nothing available
int amountRead = 0;
@@ -373,7 +378,7 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc
// Use ArrayPool.Shared instead of CryptoPool because the array is passed out.
tempOutputBuffer = ArrayPool<byte>.Shared.Rent(numWholeReadBlocks * _outputBlockSize);
numOutputBytes = _transform.TransformBlock(tempInputBuffer, 0, numWholeReadBlocksInBytes, tempOutputBuffer, 0);
Buffer.BlockCopy(tempOutputBuffer, 0, buffer, currentOutputIndex, numOutputBytes);
tempOutputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span.Slice(currentOutputIndex));

// Clear what was written while we know how much that was
CryptographicOperations.ZeroMemory(new Span<byte>(tempOutputBuffer, 0, numOutputBytes));
@@ -429,22 +434,22 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc

if (bytesToDeliver >= numOutputBytes)
{
Buffer.BlockCopy(_outputBuffer, 0, buffer, currentOutputIndex, numOutputBytes);
_outputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span.Slice(currentOutputIndex));
CryptographicOperations.ZeroMemory(new Span<byte>(_outputBuffer, 0, numOutputBytes));
currentOutputIndex += numOutputBytes;
bytesToDeliver -= numOutputBytes;
}
else
{
Buffer.BlockCopy(_outputBuffer, 0, buffer, currentOutputIndex, bytesToDeliver);
_outputBuffer.AsSpan(0, bytesToDeliver).CopyTo(buffer.Span.Slice(currentOutputIndex));
_outputBufferIndex = numOutputBytes - bytesToDeliver;
Buffer.BlockCopy(_outputBuffer, bytesToDeliver, _outputBuffer, 0, _outputBufferIndex);
int toClear = _outputBuffer.Length - _outputBufferIndex;
CryptographicOperations.ZeroMemory(new Span<byte>(_outputBuffer, _outputBufferIndex, toClear));
return count;
return buffer.Length;
}
}
return count;
return buffer.Length;

ProcessFinalBlock:
// if so, then call TransformFinalBlock to get whatever is left
@@ -458,27 +463,27 @@ private async Task<int> ReadAsyncCore(byte[] buffer, int offset, int count, Canc
// now, return either everything we just got or just what's asked for, whichever is smaller
if (bytesToDeliver < _outputBufferIndex)
{
Buffer.BlockCopy(_outputBuffer, 0, buffer, currentOutputIndex, bytesToDeliver);
_outputBuffer.AsSpan(0, bytesToDeliver).CopyTo(buffer.Span.Slice(currentOutputIndex));
_outputBufferIndex -= bytesToDeliver;
Buffer.BlockCopy(_outputBuffer, bytesToDeliver, _outputBuffer, 0, _outputBufferIndex);
int toClear = _outputBuffer.Length - _outputBufferIndex;
CryptographicOperations.ZeroMemory(new Span<byte>(_outputBuffer, _outputBufferIndex, toClear));
return (count);
return buffer.Length;
}
else
{
Buffer.BlockCopy(_outputBuffer, 0, buffer, currentOutputIndex, _outputBufferIndex);
_outputBuffer.AsSpan(0, _outputBufferIndex).CopyTo(buffer.Span.Slice(currentOutputIndex));
bytesToDeliver -= _outputBufferIndex;
_outputBufferIndex = 0;
CryptographicOperations.ZeroMemory(_outputBuffer);
return (count - bytesToDeliver);
return buffer.Length - bytesToDeliver;
}
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
CheckWriteArguments(buffer, offset, count);
return WriteAsyncInternal(buffer, offset, count, cancellationToken);
return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) =>
@@ -487,7 +492,8 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
public override void EndWrite(IAsyncResult asyncResult) =>
TaskToApm.End(asyncResult);

private async Task WriteAsyncInternal(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
/// <inheritdoc/>
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
// To avoid a race with a stream's position pointer & generating race
// conditions with internal buffer indexes in our own streams that
@@ -498,7 +504,7 @@ private async Task WriteAsyncInternal(byte[] buffer, int offset, int count, Canc
await AsyncActiveSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await WriteAsyncCore(buffer, offset, count, cancellationToken, useAsync: true).ConfigureAwait(false);
await WriteAsyncCore(buffer, cancellationToken, useAsync: true).ConfigureAwait(false);
}
finally
{
@@ -509,7 +515,7 @@ private async Task WriteAsyncInternal(byte[] buffer, int offset, int count, Canc
public override void Write(byte[] buffer, int offset, int count)
{
CheckWriteArguments(buffer, offset, count);
WriteAsyncCore(buffer, offset, count, default(CancellationToken), useAsync: false).AsTask().GetAwaiter().GetResult();
WriteAsyncCore(buffer.AsMemory(offset, count), default, useAsync: false).AsTask().GetAwaiter().GetResult();
}

private void CheckWriteArguments(byte[] buffer, int offset, int count)
@@ -519,22 +525,22 @@ private void CheckWriteArguments(byte[] buffer, int offset, int count)
throw new NotSupportedException(SR.NotSupported_UnwritableStream);
}

private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken, bool useAsync)
private async ValueTask WriteAsyncCore(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken, bool useAsync)
{
// write <= count bytes to the output stream, transforming as we go.
// Basic idea: using bytes in the _InputBuffer first, make whole blocks,
// transform them, and write them out. Cache any remaining bytes in the _InputBuffer.
int bytesToWrite = count;
int currentInputIndex = offset;
int bytesToWrite = buffer.Length;
int currentInputIndex = 0;
// if we have some bytes in the _InputBuffer, we have to deal with those first,
// so let's try to make an entire block out of it
if (_inputBufferIndex > 0)
{
Debug.Assert(_inputBuffer != null);
if (count >= _inputBlockSize - _inputBufferIndex)
if (buffer.Length >= _inputBlockSize - _inputBufferIndex)
{
// we have enough to transform at least a block, so fill the input block
Buffer.BlockCopy(buffer, offset, _inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex);
buffer.Slice(0, _inputBlockSize - _inputBufferIndex).CopyTo(_inputBuffer.AsMemory(_inputBufferIndex));
currentInputIndex += (_inputBlockSize - _inputBufferIndex);
bytesToWrite -= (_inputBlockSize - _inputBufferIndex);
_inputBufferIndex = _inputBlockSize;
@@ -544,8 +550,8 @@ private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, Can
{
// not enough to transform a block, so just copy the bytes into the _InputBuffer
// and return
Buffer.BlockCopy(buffer, offset, _inputBuffer, _inputBufferIndex, count);
_inputBufferIndex += count;
buffer.Slice(0, buffer.Length).CopyTo(_inputBuffer.AsMemory(_inputBufferIndex));
Member

Suggested change
- buffer.Slice(0, buffer.Length).CopyTo(_inputBuffer.AsMemory(_inputBufferIndex));
+ buffer.CopyTo(_inputBuffer.AsMemory(_inputBufferIndex));

_inputBufferIndex += buffer.Length;
return;
}
}
@@ -585,8 +591,7 @@ private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, Can

try
{
numOutputBytes =
_transform.TransformBlock(buffer, currentInputIndex, numWholeBlocksInBytes, tempOutputBuffer, 0);
numOutputBytes = TransformBlock(_transform, buffer.Slice(currentInputIndex, numWholeBlocksInBytes), tempOutputBuffer, 0);

if (useAsync)
{
@@ -614,7 +619,7 @@ private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, Can
{
Debug.Assert(_outputBuffer != null);
// do it the slow way
numOutputBytes = _transform.TransformBlock(buffer, currentInputIndex, _inputBlockSize, _outputBuffer, 0);
numOutputBytes = TransformBlock(_transform, buffer.Slice(currentInputIndex, _inputBlockSize), _outputBuffer, 0);

if (useAsync)
await _stream.WriteAsync(new ReadOnlyMemory<byte>(_outputBuffer, 0, numOutputBytes), cancellationToken).ConfigureAwait(false);
@@ -630,12 +635,43 @@ private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, Can
Debug.Assert(_inputBuffer != null);
// In this case, we don't have an entire block's worth left, so store it up in the
// input buffer, which by now must be empty.
Buffer.BlockCopy(buffer, currentInputIndex, _inputBuffer, 0, bytesToWrite);
buffer.Slice(currentInputIndex, bytesToWrite).CopyTo(_inputBuffer);
_inputBufferIndex += bytesToWrite;
return;
}
}
return;

static int TransformBlock(ICryptoTransform transform, ReadOnlyMemory<byte> inputBuffer, byte[] outputBuffer, int outputOffset)
{
if (MemoryMarshal.TryGetArray(inputBuffer, out ArraySegment<byte> segment))
{
// Skip the copy if readonlymemory is actually an array.
Debug.Assert(segment.Array is not null);
return transform.TransformBlock(segment.Array, segment.Offset, inputBuffer.Length, outputBuffer, outputOffset);
}
else
{
var rentedBuffer = ArrayPool<byte>.Shared.Rent(inputBuffer.Length);
Member

var shouldn't be used here, and please add the comment explaining why this isn't using CryptoPool:

Suggested change
- var rentedBuffer = ArrayPool<byte>.Shared.Rent(inputBuffer.Length);
+ // Use ArrayPool.Shared instead of CryptoPool because the array is passed out.
+ byte[] rentedBuffer = ArrayPool<byte>.Shared.Rent(inputBuffer.Length);

try
{
inputBuffer.CopyTo(rentedBuffer);
int result = transform.TransformBlock(rentedBuffer, 0, inputBuffer.Length, outputBuffer, outputOffset);
CryptographicOperations.ZeroMemory(rentedBuffer.AsSpan(0, inputBuffer.Length));
ArrayPool<byte>.Shared.Return(rentedBuffer);
rentedBuffer = null;

return result;
}
catch
Member

The catch block here is unnecessary. Instead, consider restructuring the outer try block as the following pseudocode:

byte[] rented = ArrayPool.Rent();
try
{
    input.CopyTo(rented);
    Transform(from: rented, to: output);
}
finally
{
    ZeroMem(rented);
}
ArrayPool.Return(rented);

By putting a finally block around the code where the rented buffer contains potentially sensitive data, we can ensure that the whole thing is zeroed out whether the operation completes successfully or fails.

We could also consider pinning the temporary buffer as part of this operation to provide further protection against copies of the data being made.

Note to other reviewers: Should we skip the array pool entirely and instead use pre-pinned arrays? If an adversary can force the crypto transform to fail, they can force the application to abandon a bunch of Gen2 arrays, which could lead to perf degradation.
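
As a hedged illustration of the pseudocode in this comment, the sketch below spells the same shape out in C#. The `TransformHelper` class and `TransformViaRentedBuffer` method are hypothetical names chosen for the example, not code from the PR; only `ArrayPool<byte>.Shared`, `CryptographicOperations.ZeroMemory`, and `ICryptoTransform.TransformBlock` are real APIs.

```csharp
using System;
using System.Buffers;
using System.Security.Cryptography;

// Hypothetical helper; the names are assumptions made for this sketch.
static class TransformHelper
{
    public static int TransformViaRentedBuffer(
        ICryptoTransform transform,
        ReadOnlyMemory<byte> input,
        byte[] output,
        int outputOffset)
    {
        byte[] rented = ArrayPool<byte>.Shared.Rent(input.Length);
        int result;

        try
        {
            // From here on, the rented array may hold sensitive data.
            input.CopyTo(rented);
            result = transform.TransformBlock(rented, 0, input.Length, output, outputOffset);
        }
        finally
        {
            // Runs on success and on failure, so the copy is always cleared.
            CryptographicOperations.ZeroMemory(rented.AsSpan(0, input.Length));
        }

        // Reached only on success; on failure the array is cleared but not
        // returned to the pool, which is the abandonment trade-off the
        // comment raises.
        ArrayPool<byte>.Shared.Return(rented);
        return result;
    }
}
```

Compared with the `catch`/`throw` version in the diff, the zeroing call appears once and covers both paths.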

{
CryptographicOperations.ZeroMemory(rentedBuffer.AsSpan(0, inputBuffer.Length));
Member

If the array is being pinned for the operation, that's presumably so that it can be populated and cleared without being subject to GC compaction, so I'd expect the clear to happen inside the pin.
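
One way to read this ordering is sketched below, assuming the pin is taken with `GCHandle.Alloc(..., GCHandleType.Pinned)`. The thread does not say which pinning mechanism was intended (a `fixed` statement in an `unsafe` block would be another option), and `PinnedTransformHelper` and its method are hypothetical names.

```csharp
using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Security.Cryptography;

// Hypothetical helper; GCHandle-based pinning is an assumption of this sketch.
static class PinnedTransformHelper
{
    public static int TransformViaPinnedBuffer(
        ICryptoTransform transform,
        ReadOnlyMemory<byte> input,
        byte[] output,
        int outputOffset)
    {
        byte[] rented = ArrayPool<byte>.Shared.Rent(input.Length);

        // Pin before any sensitive bytes are written into the array.
        GCHandle pin = GCHandle.Alloc(rented, GCHandleType.Pinned);
        int result;

        try
        {
            input.CopyTo(rented);
            result = transform.TransformBlock(rented, 0, input.Length, output, outputOffset);
        }
        finally
        {
            // Clear while the array is still pinned, then release the pin.
            CryptographicOperations.ZeroMemory(rented.AsSpan(0, input.Length));
            pin.Free();
        }

        ArrayPool<byte>.Shared.Return(rented);
        return result;
    }
}
```

The point of the ordering is that the array is populated and cleared entirely within the pinned window, so the GC has no opportunity to relocate it while it holds plaintext.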

rentedBuffer = null;

throw;
}
}
}
}

public void Clear()