diff --git a/Sources/GRPCCore/Internal/Concurrency Primitives/Lock.swift b/Sources/GRPCCore/Internal/Concurrency Primitives/Lock.swift index 0cadb250e..5ee259adc 100644 --- a/Sources/GRPCCore/Internal/Concurrency Primitives/Lock.swift +++ b/Sources/GRPCCore/Internal/Concurrency Primitives/Lock.swift @@ -253,6 +253,40 @@ public struct _LockedValueBox { public func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { return try self.storage.withLockedValue(mutate) } + + /// An unsafe view over the locked value box. + /// + /// Prefer ``withLockedValue(_:)`` where possible. + public var unsafe: Unsafe { + Unsafe(storage: self.storage) + } + + public struct Unsafe { + @usableFromInline + let storage: LockStorage + + /// Manually acquire the lock. + @inlinable + public func lock() { + self.storage.lock() + } + + /// Manually release the lock. + @inlinable + public func unlock() { + self.storage.unlock() + } + + /// Mutate the value, assuming the lock has been acquired manually. + @inlinable + public func withValueAssumingLockIsAcquired( + _ mutate: (inout Value) throws -> T + ) rethrows -> T { + return try self.storage.withUnsafeMutablePointerToHeader { value in + try mutate(&value.pointee) + } + } + } } extension _LockedValueBox: Sendable where Value: Sendable {} diff --git a/Sources/GRPCCore/Streaming/Internal/BroadcastAsyncSequence.swift b/Sources/GRPCCore/Streaming/Internal/BroadcastAsyncSequence.swift index 738cfb86d..a56380066 100644 --- a/Sources/GRPCCore/Streaming/Internal/BroadcastAsyncSequence.swift +++ b/Sources/GRPCCore/Streaming/Internal/BroadcastAsyncSequence.swift @@ -261,20 +261,26 @@ final class _BroadcastSequenceStorage: Sendable { func nextElement( forSubscriber id: _BroadcastSequenceStateMachine.Subscriptions.ID ) async throws -> Element? { - let onNext = self._state.withLockedValue { $0.nextElement(forSubscriber: id) } + return try await withTaskCancellationHandler { + self._state.unsafe.lock() + let onNext = self._state.unsafe.withValueAssumingLockIsAcquired { + $0.nextElement(forSubscriber: id) + } - switch onNext { - case .return(let returnAndProduceMore): - returnAndProduceMore.producers.resume() - return try returnAndProduceMore.nextResult.get() + switch onNext { + case .return(let returnAndProduceMore): + self._state.unsafe.unlock() + returnAndProduceMore.producers.resume() + return try returnAndProduceMore.nextResult.get() - case .suspend: - return try await withTaskCancellationHandler { + case .suspend: return try await withCheckedThrowingContinuation { continuation in - let onSetContinuation = self._state.withLockedValue { state in + let onSetContinuation = self._state.unsafe.withValueAssumingLockIsAcquired { state in state.setContinuation(continuation, forSubscription: id) } + self._state.unsafe.unlock() + switch onSetContinuation { case .resume(let continuation, let result): continuation.resume(with: result) @@ -282,17 +288,17 @@ final class _BroadcastSequenceStorage: Sendable { () } } - } onCancel: { - let onCancel = self._state.withLockedValue { state in - state.cancelSubscription(withID: id) - } + } + } onCancel: { + let onCancel = self._state.withLockedValue { state in + state.cancelSubscription(withID: id) + } - switch onCancel { - case .resume(let continuation, let result): - continuation.resume(with: result) - case .none: - () - } + switch onCancel { + case .resume(let continuation, let result): + continuation.resume(with: result) + case .none: + () } } } @@ -572,9 +578,18 @@ struct _BroadcastSequenceStateMachine: Sendable { self.producerToken += 1 onYield = .suspend(token) } else { - // No consumers are slow. Remove the oldest value. + // No consumers are slow, some subscribers exist, a subset of which are waiting + // for a value. Drop the oldest value and resume the fastest consumers. self.elements.removeFirst() - onYield = .none + let continuations = self.subscriptions.takeContinuations().map { + ConsumerContinuations(continuations: $0, result: .success(element)) + } + + if let continuations = continuations { + onYield = .resume(continuations) + } else { + onYield = .none + } } case self.subscriptions.count: