diff --git a/.github/workflows/ReleaseNotes.md b/.github/workflows/ReleaseNotes.md index f6ec0cc0f..554b146f5 100644 --- a/.github/workflows/ReleaseNotes.md +++ b/.github/workflows/ReleaseNotes.md @@ -1,5 +1,6 @@ * [Core] MQTT Packets being sent over web socket transport are now setting the web socket frame boundaries correctly (#1499). * [Core] Add support for attaching and detaching events from different threads. +* [Core] Fixed a deadlock in _AsyncLock_ implementation (#1520). * [Client] Keep alive mechanism now uses the configured timeout value from the options (thanks to @Stannieman, #1495). * [Client] The _PingAsync_ will fallback to the timeout specified in the client options when the cancellation token cannot be cancelled. * [Server] A DISCONNECT packet is no longer sent to MQTT clients < 5.0.0 (thanks to @logicaloud, #1506). diff --git a/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs b/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs new file mode 100644 index 000000000..5b50c3e36 --- /dev/null +++ b/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; +using MQTTnet.Internal; + +namespace MQTTnet.Benchmarks +{ + [SimpleJob(RuntimeMoniker.Net60)] + [MemoryDiagnoser] + public class AsyncLockBenchmark : BaseBenchmark + { + [Benchmark] + public async Task Synchronize_100_Tasks() + { + const int tasksCount = 100; + + var tasks = new Task[tasksCount]; + var asyncLock = new AsyncLock(); + var globalI = 0; + + for (var i = 0; i < tasksCount; i++) + { + tasks[i] = Task.Run( + async () => + { + using (await asyncLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + { + var localI = globalI; + await Task.Delay(5).ConfigureAwait(false); // Increase the chance for wrong data. + localI++; + globalI = localI; + } + }); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + + if (globalI != tasksCount) + { + throw new Exception($"Code is broken ({globalI})!"); + } + } + + [Benchmark] + public async Task Wait_100_000_Times() + { + var asyncLock = new AsyncLock(); + + using (var cancellationToken = new CancellationTokenSource()) + { + for (var i = 0; i < 100000; i++) + { + using (await asyncLock.WaitAsync(cancellationToken.Token).ConfigureAwait(false)) + { + } + } + } + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj b/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj index eae2a1577..63f0fbcac 100644 --- a/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj +++ b/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj @@ -12,7 +12,7 @@ - + diff --git a/Source/MQTTnet.Benchmarks/Program.cs b/Source/MQTTnet.Benchmarks/Program.cs index 13921d887..1fcc670da 100644 --- a/Source/MQTTnet.Benchmarks/Program.cs +++ b/Source/MQTTnet.Benchmarks/Program.cs @@ -29,6 +29,7 @@ public static void Main(string[] args) Console.WriteLine("c = SubscribeBenchmark"); Console.WriteLine("d = UnsubscribeBenchmark"); Console.WriteLine("e = MessageDeliveryBenchmark"); + Console.WriteLine("f = AsyncLockBenchmark"); var pressedKey = Console.ReadKey(true); switch (pressedKey.KeyChar) @@ -75,6 +76,9 @@ public static void Main(string[] args) case 'e': BenchmarkRunner.Run(); break; + case 'f': + BenchmarkRunner.Run(); + break; } Console.ReadLine(); diff --git a/Source/MQTTnet.TestApp/AsyncLockTest.cs b/Source/MQTTnet.TestApp/AsyncLockTest.cs new file mode 100644 index 000000000..4f943b197 --- /dev/null +++ b/Source/MQTTnet.TestApp/AsyncLockTest.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading; +using System.Threading.Tasks; +using MQTTnet.Internal; + +namespace MQTTnet.TestApp +{ + public sealed class AsyncLockTest + { + public async Task Run() + { + var asyncLock = new AsyncLock(); + + using (var cancellationToken = new CancellationTokenSource()) + { + for (var i = 0; i < 100000; i++) + { + using (await asyncLock.WaitAsync(cancellationToken.Token).ConfigureAwait(false)) + { + } + } + } + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet.TestApp/Program.cs b/Source/MQTTnet.TestApp/Program.cs index f1a325814..8ac2c9b50 100644 --- a/Source/MQTTnet.TestApp/Program.cs +++ b/Source/MQTTnet.TestApp/Program.cs @@ -28,6 +28,7 @@ public static void Main() Console.WriteLine("c = Start QoS 0 benchmark"); Console.WriteLine("d = Start server with logging"); Console.WriteLine("e = Start Message Throughput Test"); + Console.WriteLine("f = Start AsyncLock Test"); var pressedKey = Console.ReadKey(true); if (pressedKey.KeyChar == '1') @@ -88,6 +89,10 @@ public static void Main() { Task.Run(new MessageThroughputTest().Run); } + else if (pressedKey.KeyChar == 'f') + { + Task.Run(new AsyncLockTest().Run); + } Thread.Sleep(Timeout.Infinite); } diff --git a/Source/MQTTnet.Tests/Internal/AsyncLock_Tests.cs b/Source/MQTTnet.Tests/Internal/AsyncLock_Tests.cs index 69cf79a35..fae812272 100644 --- a/Source/MQTTnet.Tests/Internal/AsyncLock_Tests.cs +++ b/Source/MQTTnet.Tests/Internal/AsyncLock_Tests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -13,6 +14,80 @@ namespace MQTTnet.Tests.Internal [TestClass] public class AsyncLock_Tests { + [TestMethod] + public void Lock_10_Parallel_Tasks() + { + const int ThreadsCount = 10; + + var threads = new Task[ThreadsCount]; + var @lock = new AsyncLock(); + var globalI = 0; + for (var i = 0; i < ThreadsCount; i++) + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + threads[i] = Task.Run( + async () => + { + using (await @lock.WaitAsync(CancellationToken.None)) + { + var localI = globalI; + await Task.Delay(10); // Increase the chance for wrong data. + localI++; + globalI = localI; + } + }); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } + + Task.WaitAll(threads); + Assert.AreEqual(ThreadsCount, globalI); + } + + [TestMethod] + public void Lock_10_Parallel_Tasks_With_Dispose_Doesnt_Lockup() + { + const int ThreadsCount = 10; + + var threads = new Task[ThreadsCount]; + var @lock = new AsyncLock(); + var globalI = 0; + for (var i = 0; i < ThreadsCount; i++) + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + threads[i] = Task.Run( + async () => + { + using (await @lock.WaitAsync(CancellationToken.None)) + { + var localI = globalI; + await Task.Delay(10); // Increase the chance for wrong data. + localI++; + globalI = localI; + } + }) + .ContinueWith( + x => + { + if (globalI == 5) + { + @lock.Dispose(); + @lock = new AsyncLock(); + } + + if (x.Exception != null) + { + Debug.WriteLine(x.Exception.GetBaseException().GetType().Name); + } + }); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } + + Task.WaitAll(threads); + + // Expect only 6 because the others are failing due to disposal (if (globalI == 5)). + Assert.AreEqual(6, globalI); + } + [TestMethod] public async Task Lock_Serial_Calls() { @@ -45,58 +120,32 @@ public async Task Test_Cancellation() } } - //[TestMethod] - //public async Task Test_Cancellation_With_Later_Access() - //{ - // var @lock = new AsyncLock(); - - // var releaser = await @lock.WaitAsync().ConfigureAwait(false); - - // try - // { - // using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(3))) - // { - // await @lock.WaitAsync(cts.Token).ConfigureAwait(false); - // } - // } - // catch (OperationCanceledException) - // { - // } - - // releaser.Dispose(); - - // using (await @lock.WaitAsync().ConfigureAwait(false)) - // { - // // When the method finished, the thread got access. - // } - //} - [TestMethod] - public void Lock_10_Parallel_Tasks() + public async Task Test_Cancellation_With_Later_Access() { - const int ThreadsCount = 10; + var asyncLock = new AsyncLock(); - var threads = new Task[ThreadsCount]; - var @lock = new AsyncLock(); - var globalI = 0; - for (var i = 0; i < ThreadsCount; i++) + var releaser = await asyncLock.WaitAsync(CancellationToken.None).ConfigureAwait(false); + + try { -#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - threads[i] = Task.Run(async () => + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(3))) { - using (var releaser = await @lock.WaitAsync(CancellationToken.None)) - { - var localI = globalI; - await Task.Delay(10); // Increase the chance for wrong data. - localI++; - globalI = localI; - } - }); -#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + await asyncLock.WaitAsync(timeout.Token).ConfigureAwait(false); + } + + Assert.Fail("Exception should be thrown!"); + } + catch (OperationCanceledException) + { } - Task.WaitAll(threads); - Assert.AreEqual(ThreadsCount, globalI); + releaser.Dispose(); + + using (await asyncLock.WaitAsync(CancellationToken.None).ConfigureAwait(false)) + { + // When the method finished, the thread got access. + } } } -} +} \ No newline at end of file diff --git a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs index 11f3ecbec..d8803349c 100644 --- a/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs +++ b/Source/MQTTnet/Implementations/PlatformAbstractionLayer.cs @@ -11,11 +11,11 @@ namespace MQTTnet.Implementations public static class PlatformAbstractionLayer { #if NET452 - public static Task CompletedTask => Task.FromResult(0); + public static Task CompletedTask { get; } = Task.FromResult(0); public static byte[] EmptyByteArray { get; } = new byte[0]; #else - public static Task CompletedTask => Task.CompletedTask; + public static Task CompletedTask { get; } = Task.CompletedTask; public static byte[] EmptyByteArray { get; } = Array.Empty(); #endif diff --git a/Source/MQTTnet/Internal/AsyncLock.cs b/Source/MQTTnet/Internal/AsyncLock.cs index 748ca0d8c..17384ef37 100644 --- a/Source/MQTTnet/Internal/AsyncLock.cs +++ b/Source/MQTTnet/Internal/AsyncLock.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; @@ -10,73 +12,173 @@ namespace MQTTnet.Internal { public sealed class AsyncLock : IDisposable { - readonly Task _releaser; - readonly object _syncRoot = new object(); + /* + * This async supporting lock does not support reentrancy! + */ + + readonly List _queuedTasks = new List(64); + readonly Task _releaserTaskWithDirectApproval; - SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); + readonly Releaser _releaserWithDirectApproval; + + readonly object _syncRoot = new object(); + bool _isDisposed; + public AsyncLock() { - _releaser = Task.FromResult((IDisposable)new Releaser(this)); + _releaserWithDirectApproval = new Releaser(this, null, CancellationToken.None); + _releaserTaskWithDirectApproval = Task.FromResult((IDisposable)_releaserWithDirectApproval); } public void Dispose() { lock (_syncRoot) { - _semaphore?.Dispose(); - _semaphore = null; + foreach (var waitingTask in _queuedTasks) + { + waitingTask.Fail(new ObjectDisposedException(nameof(AsyncLock))); + } + + _queuedTasks.Clear(); + + _isDisposed = true; } } public Task WaitAsync(CancellationToken cancellationToken) { - Task task; + var hasDirectApproval = false; + Releaser releaser; - // This lock is required to avoid ObjectDisposedExceptions. - // These are fired when this lock gets disposed (and thus the semaphore) - // and a worker thread tries to call this method at the same time. - // Another way would be catching all ObjectDisposedExceptions but this situation happens - // quite often when clients are disconnecting. lock (_syncRoot) { - task = _semaphore?.WaitAsync(cancellationToken); - } + if (_isDisposed) + { + throw new ObjectDisposedException(nameof(AsyncLock)); + } + + if (_queuedTasks.Count == 0) + { + // There is no other waiting task apart from the current one. + // So we can approve the current task directly. + releaser = _releaserWithDirectApproval; + hasDirectApproval = true; + Debug.WriteLine("AsyncLock: Task -1 directly approved."); + } + else + { + releaser = new Releaser(this, new TaskCompletionSource(), cancellationToken); + } - if (task == null) - { - throw new ObjectDisposedException("The AsyncLock is disposed."); + _queuedTasks.Add(releaser); } - if (task.Status == TaskStatus.RanToCompletion) + if (!hasDirectApproval) { - return _releaser; + return releaser.Task; } - // Wait for the _WaitAsync_ method and return the releaser afterwards. - return task.ContinueWith((_, state) => (IDisposable)state, _releaser.Result, cancellationToken, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + return _releaserTaskWithDirectApproval; } - void Release() + void Release(Releaser releaser) { lock (_syncRoot) { - _semaphore?.Release(); + if (_isDisposed) + { + // There is no much left to do! + return; + } + + var activeTask = _queuedTasks[0]; + if (!ReferenceEquals(activeTask, releaser)) + { + throw new InvalidOperationException("The active task must be the current releaser."); + } + + _queuedTasks.RemoveAt(0); + + while (_queuedTasks.Count > 0) + { + var nextTask = _queuedTasks[0]; + if (!nextTask.IsPending) + { + // Dequeue all canceled or failed tasks. + _queuedTasks.RemoveAt(0); + continue; + } + + nextTask.Approve(); + return; + } + + Debug.WriteLine("AsyncLock: No Task pending."); } } sealed class Releaser : IDisposable { - readonly AsyncLock _lock; + readonly AsyncLock _asyncLock; + readonly CancellationToken _cancellationToken; + readonly int _id; + readonly TaskCompletionSource _promise; + + // ReSharper disable once FieldCanBeMadeReadOnly.Local + CancellationTokenRegistration _cancellationTokenRegistration; - internal Releaser(AsyncLock @lock) + internal Releaser(AsyncLock asyncLock, TaskCompletionSource promise, CancellationToken cancellationToken) { - _lock = @lock; + _asyncLock = asyncLock ?? throw new ArgumentNullException(nameof(asyncLock)); + _promise = promise; + _cancellationToken = cancellationToken; + + if (cancellationToken.CanBeCanceled) + { + _cancellationTokenRegistration = cancellationToken.Register(Cancel); + } + + _id = promise?.Task.Id ?? -1; + + Debug.WriteLine($"AsyncLock: Task {_id} queued."); + } + + public bool IsPending => _promise != null && !_promise.Task.IsCanceled && !_promise.Task.IsFaulted && !_promise.Task.IsCompleted; + + public Task Task => _promise?.Task; + + public void Approve() + { + _promise?.TrySetResult(this); + + Debug.WriteLine($"AsyncLock: Task {_id} approved."); } public void Dispose() { - _lock.Release(); + if (_cancellationToken.CanBeCanceled) + { + _cancellationTokenRegistration.Dispose(); + } + + Debug.WriteLine($"AsyncLock: Task {_id} completed."); + + _asyncLock.Release(this); + } + + public void Fail(Exception exception) + { + _promise?.TrySetException(exception); + + Debug.WriteLine($"AsyncLock: Task {_id} failed ({exception.GetType().Name})."); + } + + void Cancel() + { + _promise?.TrySetCanceled(); + + Debug.WriteLine($"AsyncLock: Task {_id} canceled."); } } }