diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs index 3f7a967a60a47..6e911a797121f 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs @@ -146,8 +146,15 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) } } - // Otherwise, queue the reader. - var reader = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); + // Otherwise, queue a reader. Note that in addition to checking whether synchronous continuations were requested, + // we also check whether the supplied cancellation token can be canceled. The writer calls UnregisterCancellation + // while holding the lock, and if a callback needs to be unregistered and is currently running, it needs to wait + // for that callback to complete so that the subsequent code knows it won't be contending with another thread + // trying to complete the operation. However, if we allowed a synchronous continuation from this operation, that + // cancellation callback could end up running arbitrary code, including code that called back into the reader or + // writer and tried to take the same lock held by the thread running UnregisterCancellation... deadlock. As such, + // we only allow synchronous continuations here if both a) the caller requested it and the token isn't cancelable. + var reader = new AsyncOperation(parent._runContinuationsAsynchronously | cancellationToken.CanBeCanceled, cancellationToken); parent._blockedReaders.EnqueueTail(reader); return reader.ValueTaskOfT; } @@ -193,8 +200,15 @@ public override ValueTask WaitToReadAsync(CancellationToken cancellationTo } } - // Otherwise, queue a reader. - var waiter = new AsyncOperation(parent._runContinuationsAsynchronously, cancellationToken); + // Otherwise, queue a reader. Note that in addition to checking whether synchronous continuations were requested, + // we also check whether the supplied cancellation token can be canceled. The writer calls UnregisterCancellation + // while holding the lock, and if a callback needs to be unregistered and is currently running, it needs to wait + // for that callback to complete so that the subsequent code knows it won't be contending with another thread + // trying to complete the operation. However, if we allowed a synchronous continuation from this operation, that + // cancellation callback could end up running arbitrary code, including code that called back into the reader or + // writer and tried to take the same lock held by the thread running UnregisterCancellation... deadlock. As such, + // we only allow synchronous continuations here if both a) the caller requested it and the token isn't cancelable. + var waiter = new AsyncOperation(parent._runContinuationsAsynchronously | cancellationToken.CanBeCanceled, cancellationToken); ChannelUtilities.QueueWaiter(ref _parent._waitingReadersTail, waiter); return waiter.ValueTaskOfT; } diff --git a/src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs b/src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs index ea27c1482901b..19c28a248fc8e 100644 --- a/src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs +++ b/src/libraries/System.Threading.Channels/tests/BoundedChannelTests.cs @@ -390,16 +390,18 @@ public async Task WaitToWriteAsync_AfterFullThenRead_ReturnsTrue() } [Theory] - [InlineData(false)] - [InlineData(true)] - public void AllowSynchronousContinuations_WaitToReadAsync_ContinuationsInvokedAccordingToSetting(bool allowSynchronousContinuations) + [MemberData(nameof(ThreeBools))] + public void AllowSynchronousContinuations_Reading_ContinuationsInvokedAccordingToSetting(bool allowSynchronousContinuations, bool cancelable, bool waitToReadAsync) { var c = Channel.CreateBounded(new BoundedChannelOptions(1) { AllowSynchronousContinuations = allowSynchronousContinuations }); + CancellationToken ct = cancelable ? new CancellationTokenSource().Token : CancellationToken.None; + int expectedId = Environment.CurrentManagedThreadId; - Task r = c.Reader.WaitToReadAsync().AsTask().ContinueWith(_ => + Task t = waitToReadAsync ? (Task)c.Reader.WaitToReadAsync(ct).AsTask() : c.Reader.ReadAsync(ct).AsTask(); + Task r = t.ContinueWith(_ => { - Assert.Equal(allowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + Assert.Equal(allowSynchronousContinuations && !cancelable, expectedId == Environment.CurrentManagedThreadId); }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); Assert.True(c.Writer.WriteAsync(42).IsCompletedSuccessfully); diff --git a/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs b/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs index 42721f12c4bea..ed88b1bed9fcc 100644 --- a/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs +++ b/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs @@ -24,6 +24,12 @@ public abstract partial class ChannelTestBase : TestBase protected virtual bool RequiresSingleWriter => false; protected virtual bool BuffersItems => true; + public static IEnumerable ThreeBools => + from b1 in new[] { false, true } + from b2 in new[] { false, true } + from b3 in new[] { false, true } + select new object[] { b1, b2, b3 }; + [Fact] public void ValidateDebuggerAttributes() {