From 3d4410d3e338df8550b9168e2315154193d8f3e0 Mon Sep 17 00:00:00 2001 From: Oleksandr Poliakov <31327136+sanych-sun@users.noreply.github.com> Date: Wed, 12 Nov 2025 13:36:07 -0800 Subject: [PATCH] CSHARP-5777: Avoid ThreadPool-dependent IO methods in sync API (#1805) --- .../Core/Misc/StreamExtensionMethods.cs | 179 ++++++++++-------- .../Core/Connections/BinaryConnectionTests.cs | 15 +- .../Core/Misc/StreamExtensionMethodsTests.cs | 82 ++++++-- .../ServerDiscoveryAndMonitoringProseTests.cs | 2 +- 4 files changed, 172 insertions(+), 106 deletions(-) diff --git a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs index 93e2da1a1ba..b552f8bf372 100644 --- a/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs +++ b/src/MongoDB.Driver/Core/Misc/StreamExtensionMethods.cs @@ -36,44 +36,15 @@ public static void EfficientCopyTo(this Stream input, Stream output) } } - public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) - { - try - { - using var manualResetEvent = new ManualResetEventSlim(); - var readOperation = stream.BeginRead( - buffer, - offset, - count, - state => ((ManualResetEventSlim)state.AsyncState).Set(), - manualResetEvent); - - if (readOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken)) - { - return stream.EndRead(readOperation); - } - } - catch (OperationCanceledException) - { - // Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed. - } - catch (ObjectDisposedException) - { - throw new IOException(); - } - - try - { - stream.Dispose(); - } - catch - { - // Ignore any exceptions - } - - cancellationToken.ThrowIfCancellationRequested(); - throw new TimeoutException(); - } + public static int Read(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) => + ExecuteOperationWithTimeout( + stream, + (str, state) => str.Read(state.Buffer, state.Offset, state.Count), + buffer, + offset, + count, + timeout, + cancellationToken); public static async Task ReadAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { @@ -217,46 +188,19 @@ public static async Task ReadBytesAsync(this Stream stream, byte[] destination, } } - public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) - { - try - { - using var manualResetEvent = new ManualResetEventSlim(); - var writeOperation = stream.BeginWrite( - buffer, - offset, - count, - state => ((ManualResetEventSlim)state.AsyncState).Set(), - manualResetEvent); - - if (writeOperation.IsCompleted || manualResetEvent.Wait(timeout, cancellationToken)) + public static void Write(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) => + ExecuteOperationWithTimeout( + stream, + (str, state) => { - stream.EndWrite(writeOperation); - return; - } - } - catch (OperationCanceledException) - { - // Have to suppress OperationCanceledException here, it will be thrown after the stream will be disposed. - } - catch (ObjectDisposedException) - { - // It's possible to get ObjectDisposedException when the connection pool was closed with interruptInUseConnections set to true. - throw new IOException(); - } - - try - { - stream.Dispose(); - } - catch - { - // Ignore any exceptions - } - - cancellationToken.ThrowIfCancellationRequested(); - throw new TimeoutException(); - } + str.Write(state.Buffer, state.Offset, state.Count); + return true; + }, + buffer, + offset, + count, + timeout, + cancellationToken); public static async Task WriteAsync(this Stream stream, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) { @@ -325,5 +269,86 @@ public static async Task WriteBytesAsync(this Stream stream, OperationContext op count -= bytesToWrite; } } + + private static TResult ExecuteOperationWithTimeout(Stream stream, Func operation, byte[] buffer, int offset, int count, TimeSpan timeout, CancellationToken cancellationToken) + { + StreamDisposeCallbackState callbackState = null; + Timer timer = null; + CancellationTokenRegistration cancellationSubscription = default; + if (timeout != Timeout.InfiniteTimeSpan) + { + callbackState = new StreamDisposeCallbackState(stream); + timer = new Timer(DisposeStreamCallback, callbackState, timeout, Timeout.InfiniteTimeSpan); + } + + if (cancellationToken.CanBeCanceled) + { + callbackState ??= new StreamDisposeCallbackState(stream); + cancellationSubscription = cancellationToken.Register(DisposeStreamCallback, callbackState); + } + + try + { + var result = operation(stream, (buffer, offset, count)); + if (callbackState?.TryChangeStateFromInProgress(OperationState.Done) == false) + { + // if cannot change the state - then the stream was/will be disposed, throw here + throw new IOException(); + } + + return result; + } + catch (IOException) + { + if (callbackState?.OperationState == OperationState.Interrupted) + { + cancellationToken.ThrowIfCancellationRequested(); + throw new TimeoutException(); + } + + throw; + } + finally + { + timer?.Dispose(); + cancellationSubscription.Dispose(); + } + + static void DisposeStreamCallback(object state) + { + var disposeCallbackState = (StreamDisposeCallbackState)state; + if (!disposeCallbackState.TryChangeStateFromInProgress(OperationState.Interrupted)) + { + // If the state can't be changed - then I/O had already succeeded + return; + } + + try + { + disposeCallbackState.Stream.Dispose(); + } + catch (Exception) + { + // callbacks should not fail, suppress any exceptions here + } + } + } + + private record StreamDisposeCallbackState(Stream Stream) + { + private int _operationState = 0; + + public OperationState OperationState => (OperationState)_operationState; + + public bool TryChangeStateFromInProgress(OperationState newState) => + Interlocked.CompareExchange(ref _operationState, (int)newState, (int)OperationState.InProgress) == (int)OperationState.InProgress; + } + + private enum OperationState + { + InProgress = 0, + Done, + Interrupted, + } } } diff --git a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs index f9e7cbdb952..374a2a9d226 100644 --- a/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Connections/BinaryConnectionTests.cs @@ -713,19 +713,8 @@ public async Task SendMessage_should_put_the_message_on_the_stream_and_raise_the private void SetupStreamRead(Mock streamMock, TaskCompletionSource tcs) { - streamMock.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] _, int __, int ___, AsyncCallback callback, object state) => - { - var innerTcs = new TaskCompletionSource(state); - tcs.Task.ContinueWith(t => - { - innerTcs.TrySetException(t.Exception.InnerException); - callback(innerTcs.Task); - }); - return innerTcs.Task; - }); - streamMock.Setup(s => s.EndRead(It.IsAny())) - .Returns(x => ((Task)x).GetAwaiter().GetResult()); + streamMock.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] _, int __, int ___) => tcs.Task.GetAwaiter().GetResult()); streamMock.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(tcs.Task); streamMock.Setup(s => s.Close()).Callback(() => tcs.TrySetException(new ObjectDisposedException("stream"))); diff --git a/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs b/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs index 8da7e5f7de8..cd51bdd785d 100644 --- a/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs +++ b/tests/MongoDB.Driver.Tests/Core/Misc/StreamExtensionMethodsTests.cs @@ -1,4 +1,4 @@ -/* Copyright 2013-present MongoDB Inc. +/* Copyright 2010-present MongoDB Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -90,20 +90,18 @@ public async Task ReadBytes_with_byte_array_should_have_expected_effect_for_part var bytes = new byte[] { 1, 2, 3 }; var n = 0; var position = 0; - Task ReadPartial (byte[] buffer, int offset, int count) + int ReadPartial (byte[] buffer, int offset, int count) { var length = partition[n++]; Buffer.BlockCopy(bytes, position, buffer, offset, length); position += length; - return Task.FromResult(length); + return length; } mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count)); - mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count)); - mockStream.Setup(s => s.EndRead(It.IsAny())) - .Returns(x => ((Task)x).GetAwaiter().GetResult()); + .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count))); + mockStream.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count)); var destination = new byte[3]; if (async) @@ -203,6 +201,49 @@ await Record.ExceptionAsync(() => stream.ReadBytesAsync(OperationContext.NoTimeo .ParamName.Should().Be("stream"); } + [Theory] + [ParameterAttributeData] + public async Task ReadBytes_with_byte_array_throws_on_timeout([Values(true, false)]bool async) + { + var streamMock = new Mock(); + SetupStreamRead(streamMock); + var stream = streamMock.Object; + + var destination = new byte[2]; + var timeout = TimeSpan.FromMilliseconds(10); + + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadAsync(destination, 0, 2, timeout, CancellationToken.None)) : + Record.Exception(() => stream.Read(destination, 0, 2, timeout, CancellationToken.None)); + + exception.Should().BeOfType(); + } + + [Theory] + [ParameterAttributeData] + public async Task ReadBytes_with_byte_array_throws_on_cancellation([Values(true, false)]bool async) + { + var streamMock = new Mock(); + SetupStreamRead(streamMock); + var stream = streamMock.Object; + + var destination = new byte[2]; + using var cancellationTokenSource = new CancellationTokenSource(10); + + var exception = async ? + await Record.ExceptionAsync(() => stream.ReadAsync(destination, 0, 2, Timeout.InfiniteTimeSpan, cancellationTokenSource.Token)) : + Record.Exception(() => stream.Read(destination, 0, 2, Timeout.InfiniteTimeSpan, cancellationTokenSource.Token)); + + if (async) + { + exception.Should().BeOfType(); + } + else + { + exception.Should().BeOfType(); + } + } + [Theory] [InlineData(true, 0, new byte[] { 0, 0 })] [InlineData(true, 1, new byte[] { 1, 0 })] @@ -267,20 +308,18 @@ public async Task ReadBytes_with_byte_buffer_should_have_expected_effect_for_par var destination = new ByteArrayBuffer(new byte[3], 3); var n = 0; var position = 0; - Task ReadPartial (byte[] buffer, int offset, int count) + int ReadPartial (byte[] buffer, int offset, int count) { var length = partition[n++]; Buffer.BlockCopy(bytes, position, buffer, offset, length); position += length; - return Task.FromResult(length); + return length; } mockStream.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadPartial(buffer, offset, count)); - mockStream.Setup(s => s.BeginRead(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((byte[] buffer, int offset, int count, AsyncCallback callback, object state) => ReadPartial(buffer, offset, count)); - mockStream.Setup(s => s.EndRead(It.IsAny())) - .Returns(x => ((Task)x).GetAwaiter().GetResult()); + .Returns((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(ReadPartial(buffer, offset, count))); + mockStream.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] buffer, int offset, int count) => ReadPartial(buffer, offset, count)); if (async) { @@ -533,5 +572,18 @@ private Mock CreateMockByteBuffer(int length) mockBuffer.SetupGet(b => b.Length).Returns(length); return mockBuffer; } + + private void SetupStreamRead(Mock streamMock, TaskCompletionSource readTaskCompletionSource = null) + { + readTaskCompletionSource ??= new TaskCompletionSource(); + streamMock.Setup(s => s.Close()).Callback(() => + { + readTaskCompletionSource.SetException(new IOException()); + }); + streamMock.Setup(s => s.Read(It.IsAny(), It.IsAny(), It.IsAny())).Returns(() => + readTaskCompletionSource.Task.GetAwaiter().GetResult()); + streamMock.Setup(s => s.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).Returns(() => + readTaskCompletionSource.Task); + } } } diff --git a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs index 4d61e9c7b60..7bdaba366f9 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/server-discovery-and-monitoring/ServerDiscoveryAndMonitoringProseTests.cs @@ -96,7 +96,7 @@ public void Heartbeat_should_be_emitted_before_connection_open() var mockStream = new Mock(); mockStream - .Setup(s => s.BeginWrite(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(s => s.Write(It.IsAny(), It.IsAny(), It.IsAny())) .Callback(() => EnqueueEvent(HelloReceivedEvent)) .Throws(new Exception("Stream is closed."));