From c28197f21278d2b679ff821ec89026462750e1f9 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Mon, 23 Jun 2025 14:07:12 +0800 Subject: [PATCH 1/8] init --- .../SignalRClient/ConnectionProtocol.swift | 9 ++ Sources/SignalRClient/HttpConnection.swift | 42 ++++++++-- Sources/SignalRClient/HubConnection.swift | 83 ++++++++++++++----- .../SignalRClient/HubConnectionBuilder.swift | 2 +- Sources/SignalRClient/MessageBuffer.swift | 23 ++++- .../HubConnection+OnResultTests.swift | 2 +- .../HubConnection+OnTests.swift | 2 +- 7 files changed, 133 insertions(+), 30 deletions(-) diff --git a/Sources/SignalRClient/ConnectionProtocol.swift b/Sources/SignalRClient/ConnectionProtocol.swift index 291770c..6c1ce13 100644 --- a/Sources/SignalRClient/ConnectionProtocol.swift +++ b/Sources/SignalRClient/ConnectionProtocol.swift @@ -1,6 +1,13 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +enum ConnectionFeature: String, CaseIterable { + case Reconnect = "reconnect" + case Resend = "resend" + case Disconnected = "disconnected" + // Add more feature keys as needed +} + protocol ConnectionProtocol: AnyObject, Sendable { func onReceive(_ handler: @escaping Transport.OnReceiveHandler) async func onClose(_ handler: @escaping Transport.OnCloseHander) async @@ -8,4 +15,6 @@ protocol ConnectionProtocol: AnyObject, Sendable { func send(_ data: StringOrData) async throws func stop(error: Error?) async var inherentKeepAlive: Bool { get async } + var features: [ConnectionFeature: Any] { get async } + func setFeature(feature: ConnectionFeature, value: Any) async } \ No newline at end of file diff --git a/Sources/SignalRClient/HttpConnection.swift b/Sources/SignalRClient/HttpConnection.swift index 8bb2fc1..f6d6f8d 100644 --- a/Sources/SignalRClient/HttpConnection.swift +++ b/Sources/SignalRClient/HttpConnection.swift @@ -84,7 +84,7 @@ actor HttpConnection: ConnectionProtocol { private var accessTokenFactory: (@Sendable () async throws -> String?)? private var inherentKeepAlivePrivate: Bool = false - public var features: [String: Any] = [:] + public var features: [ConnectionFeature: Any] = [:] public var baseUrl: String public var connectionId: String? public var inherentKeepAlive: Bool { @@ -359,9 +359,31 @@ actor HttpConnection: ConnectionProtocol { private func startTransport(url: String, transferFormat: TransferFormat) async throws { await transport!.onReceive(self.onReceive) - await transport!.onClose { [weak self] error in - guard let self = self else { return } - await self.handleConnectionClose(error: error) + if (self.features[ConnectionFeature.Reconnect] as? Bool) == true { + await transport!.onClose { [weak self] error in + var callStop = false; + guard let self = self else { return } + if await (self.features[ConnectionFeature.Reconnect] as? Bool) == true { + do { + if let disconnectedHandler = await self.features[ConnectionFeature.Disconnected] as? (Error?) -> Void { + disconnectedHandler(error); + } + try await self.transport?.connect(url: url, transferFormat: transferFormat); + if let resendHandler = await self.features[ConnectionFeature.Resend] as? () -> Void { + resendHandler(); + } + } catch { + callStop = true; + } + } + else { + await self.handleConnectionClose(error: error); + return; + } + if callStop { + await self.handleConnectionClose(error: error); + } + } } do { @@ -435,6 +457,12 @@ actor HttpConnection: ConnectionProtocol { if let useStatefulReconnect = options.useStatefulReconnect, useStatefulReconnect { queryItems.append(URLQueryItem(name: "useStatefulReconnect", value: "true")) } + else { + if (queryItems.first(where: { $0.name == "useStatefulReconnect" })?.value ?? "false") == "true" { + options.useStatefulReconnect = true; + } + } + negotiateUrlComponents.queryItems = queryItems return negotiateUrlComponents.url!.absoluteString } @@ -476,7 +504,7 @@ actor HttpConnection: ConnectionProtocol { let transferFormats = endpoint.transferFormats.compactMap { TransferFormat($0) } if transferFormats.contains(requestedTransferFormat) { do { - features["reconnect"] = (transportType == .webSockets && useStatefulReconnect) ? true : nil + self.features[ConnectionFeature.Reconnect] = (transportType == .webSockets && useStatefulReconnect) ? true : false; let constructedTransport = try await constructTransport(transport: transportType) return constructedTransport } catch { @@ -511,4 +539,8 @@ actor HttpConnection: ConnectionProtocol { } return urlRequest } + + public func setFeature(feature: ConnectionFeature, value: Any) async { + features[feature] = value + } } diff --git a/Sources/SignalRClient/HubConnection.swift b/Sources/SignalRClient/HubConnection.swift index 7ba13d2..976a6c0 100644 --- a/Sources/SignalRClient/HubConnection.swift +++ b/Sources/SignalRClient/HubConnection.swift @@ -6,7 +6,7 @@ import Foundation public actor HubConnection { private static let defaultTimeout: TimeInterval = 30 private static let defaultPingInterval: TimeInterval = 15 - private static let defaultStatefulReconnectBufferSize: Int = 100_000_000 // bytes of messages + private static let defaultStatefulReconnectBufferSize: Int = 100_000_000 // bytes of messages private var invocationBinder: DefaultInvocationBinder private var invocationHandler: InvocationHandler @@ -16,7 +16,7 @@ public actor HubConnection { private let logger: Logger private let hubProtocol: HubProtocol private let connection: ConnectionProtocol - private let retryPolicy: RetryPolicy + private let reconnectPolicy: RetryPolicy private let keepAliveScheduler: TimeScheduler private let serverTimeoutScheduler: TimeScheduler private let statefulReconnectBufferSize: Int @@ -24,8 +24,8 @@ public actor HubConnection { private var connectionStarted: Bool = false private var receivedHandshakeResponse: Bool = false private var invocationId: Int = 0 + private var messageBuffer: MessageBuffer? = nil private var connectionStatus: HubConnectionState = .Stopped - private var stopping: Bool = false private var stopDuringStartError: Error? private nonisolated(unsafe) var handshakeResolver: ((HandshakeResponseMessage) -> Void)? private nonisolated(unsafe) var handshakeRejector: ((Error) -> Void)? @@ -40,16 +40,17 @@ public actor HubConnection { internal init(connection: ConnectionProtocol, logger: Logger, hubProtocol: HubProtocol, - retryPolicy: RetryPolicy, + reconnectPolicy: RetryPolicy, serverTimeout: TimeInterval?, keepAliveInterval: TimeInterval?, statefulReconnectBufferSize: Int?) { self.serverTimeout = serverTimeout ?? HubConnection.defaultTimeout self.keepAliveInterval = keepAliveInterval ?? HubConnection.defaultPingInterval - self.statefulReconnectBufferSize = statefulReconnectBufferSize ?? HubConnection.defaultStatefulReconnectBufferSize + self.statefulReconnectBufferSize = + statefulReconnectBufferSize ?? HubConnection.defaultStatefulReconnectBufferSize self.logger = logger - self.retryPolicy = retryPolicy + self.reconnectPolicy = reconnectPolicy self.connection = connection self.hubProtocol = hubProtocol @@ -58,10 +59,12 @@ public actor HubConnection { self.invocationHandler = InvocationHandler() self.keepAliveScheduler = TimeScheduler(initialInterval: self.keepAliveInterval) self.serverTimeoutScheduler = TimeScheduler(initialInterval: self.serverTimeout) + self.reconnectedHandlers = [] + self.reconnectingHandlers = [] } public func start() async throws { - if (connectionStatus != .Stopped) { + if connectionStatus != .Stopped { throw SignalRError.invalidOperation("Start client while not in a stopped state.") } @@ -77,7 +80,6 @@ public actor HubConnection { startSuccessfully = true } catch { connectionStatus = .Stopped - stopping = false await keepAliveScheduler.stop() await serverTimeoutScheduler.stop() logger.log(level: .debug, message: "HubConnection start failed \(error)") @@ -96,13 +98,14 @@ public actor HubConnection { } // 2. Another stop is running, just wait for it - if (stopping) { + if connectionStatus == .Stopping { logger.log(level: .debug, message: "Connection is already stopping") await stopTask?.value return } - stopping = true + connectionStatus = .Stopping + await self.connection.setFeature(feature: ConnectionFeature.Reconnect, value: false) // In this step, there's no other start running stopTask = Task { @@ -291,6 +294,13 @@ public actor HubConnection { private func stopInternal() async { if (connectionStatus == .Stopped) { + logger.log(level: .debug,message:"Call to HubConnection.stop ignored because it is already in the disconnected state.") + return + } + + if connectionStatus == .Stopping { + logger.log(level: .debug,message:"Call to HubConnection.stop ignored because it is already in the stopping state.") + await stopTask?.value return } @@ -325,7 +335,7 @@ public actor HubConnection { handshakeRejector!(SignalRError.connectionAborted) } - if (stopping) { + if connectionStatus == .Connecting { await completeClose(error: error) return } @@ -335,7 +345,7 @@ public actor HubConnection { // 2. Connected: In this case, we should reconnect // 3. Reconnecting: In this case, we're in the control of previous reconnect(), let that function handle the reconnection - if (connectionStatus == .Connected) { + if connectionStatus == .Connected { do { try await reconnect(error: error) } catch { @@ -351,13 +361,15 @@ public actor HubConnection { var lastError: Error? = error // reconnect - while let interval = retryPolicy.nextRetryInterval(retryContext: RetryContext( - retryCount: retryCount, - elapsed: elapsed, - retryReason: lastError - )) { + while let interval = reconnectPolicy.nextRetryInterval( + retryContext: RetryContext( + retryCount: retryCount, + elapsed: elapsed, + retryReason: lastError + )) + { try Task.checkCancellation() - if (stopping) { + if connectionStatus == .Stopping { break } @@ -380,7 +392,7 @@ public actor HubConnection { logger.log(level: .warning, message: "Connection reconnect failed: \(error)") } - if (stopping) { + if connectionStatus == .Stopping { break } @@ -424,6 +436,17 @@ public actor HubConnection { do { let hubMessage = try hubProtocol.parseMessages(input: data!, binder: invocationBinder) for message in hubMessage { + do { + if let messageBuffer = self.messageBuffer { + let shouldProcess = try await messageBuffer.shouldProcessMessage(message) + if !shouldProcess { + // Don't process the message, we are either waiting for a SequenceMessage or received a duplicate message + continue + } + } + } catch { + logger.log(level: .error, message: "Error parsing messages: \(error)") + } await dispatchMessage(message) } } catch { @@ -499,7 +522,6 @@ public actor HubConnection { private func completeClose(error: Error?) async { connectionStatus = .Stopped - stopping = false await keepAliveScheduler.stop() await serverTimeoutScheduler.stop() @@ -513,7 +535,7 @@ public actor HubConnection { private func startInternal() async throws { try Task.checkCancellation() - guard stopping == false else { + guard connectionStatus != .Stopping else { throw SignalRError.invalidOperation("Stopping is called") } @@ -532,6 +554,7 @@ public actor HubConnection { logger.log(level: .error, message: "Unsupported handshake version: \(version)") throw SignalRError.unsupportedHandshakeVersion } + // TODO: enable version 2 when stateful reconnect is done receivedHandshakeResponse = false let handshakeRequest = HandshakeRequestMessage(protocol: hubProtocol.name, version: version) @@ -567,6 +590,23 @@ public actor HubConnection { } } + let useStatefulReconnect = await (self.connection.features[ConnectionFeature.Reconnect] as? Bool) == true + if useStatefulReconnect { + self.messageBuffer = MessageBuffer( + hubProtocol: self.hubProtocol, connection: self.connection, + bufferSize: self.statefulReconnectBufferSize) + await self.connection.setFeature( + feature: ConnectionFeature.Disconnected, + value: { [weak self] () async -> Void in + _ = try? await self?.messageBuffer?.disconnected() + }) + await self.connection.setFeature( + feature: ConnectionFeature.Resend, + value: { [weak self] () async -> Any? in + return try? await self?.messageBuffer?.resend() + }) + } + let inherentKeepAlive = await connection.inherentKeepAlive if (!inherentKeepAlive) { await keepAliveScheduler.start { @@ -808,6 +848,7 @@ public actor HubConnection { public enum HubConnectionState { // The connection is stopped. Start can only be called if the connection is in this state. case Stopped + case Stopping case Connecting case Connected case Reconnecting diff --git a/Sources/SignalRClient/HubConnectionBuilder.swift b/Sources/SignalRClient/HubConnectionBuilder.swift index 8314662..d0ef592 100644 --- a/Sources/SignalRClient/HubConnectionBuilder.swift +++ b/Sources/SignalRClient/HubConnectionBuilder.swift @@ -103,7 +103,7 @@ public class HubConnectionBuilder { return HubConnection(connection: connection, logger: logger, hubProtocol: hubProtocol, - retryPolicy: retryPolicy, + reconnectPolicy: retryPolicy, serverTimeout: serverTimeout, keepAliveInterval: keepAliveInterval, statefulReconnectBufferSize: statefulReconnectBufferSize) diff --git a/Sources/SignalRClient/MessageBuffer.swift b/Sources/SignalRClient/MessageBuffer.swift index 3b5dc1c..da6fdce 100644 --- a/Sources/SignalRClient/MessageBuffer.swift +++ b/Sources/SignalRClient/MessageBuffer.swift @@ -4,6 +4,9 @@ import Foundation actor MessageBuffer { + private var hubProtocol: HubProtocol; + private var connection: ConnectionProtocol; + private var maxBufferSize: Int private var messages: [BufferedItem] = [] private var bufferedByteCount: Int = 0 @@ -13,10 +16,28 @@ actor MessageBuffer { private var dequeueContinuations: [CheckedContinuation] = [] private var closed: Bool = false - init(bufferSize: Int) { + init(hubProtocol: HubProtocol, connection: ConnectionProtocol, bufferSize: Int) { + self.hubProtocol = hubProtocol + self.connection = connection self.maxBufferSize = bufferSize } + public func send(message: HubMessage) async throws -> Void { + throw SignalRError.invalidOperation("Send is not implemented") + } + + public func resend() async throws -> Void { + throw SignalRError.invalidOperation("Resend is not implemented") + } + + public func disconnected() async throws -> Void { + throw SignalRError.invalidOperation("Disconnected is not implemented") + } + + public func shouldProcessMessage(_ message: HubMessage) throws -> Bool { + throw SignalRError.invalidOperation("ShouldProcessMessage is not implemented") + } + public func enqueue(content: StringOrData) async throws -> Void { if closed { throw SignalRError.invalidOperation("Message buffer has closed") diff --git a/Tests/SignalRClientTests/HubConnection+OnResultTests.swift b/Tests/SignalRClientTests/HubConnection+OnResultTests.swift index 1c11b0a..bbbda3c 100644 --- a/Tests/SignalRClientTests/HubConnection+OnResultTests.swift +++ b/Tests/SignalRClientTests/HubConnection+OnResultTests.swift @@ -29,7 +29,7 @@ final class HubConnectionOnResultTests: XCTestCase { connection: mockConnection, logger: Logger(logLevel: .debug, logHandler: logHandler), hubProtocol: hubProtocol, - retryPolicy: DefaultRetryPolicy(retryDelays: []), // No retry + reconnectPolicy: DefaultRetryPolicy(retryDelays: []), // No retry serverTimeout: nil, keepAliveInterval: nil, statefulReconnectBufferSize: nil diff --git a/Tests/SignalRClientTests/HubConnection+OnTests.swift b/Tests/SignalRClientTests/HubConnection+OnTests.swift index d1a853a..f8d6f8b 100644 --- a/Tests/SignalRClientTests/HubConnection+OnTests.swift +++ b/Tests/SignalRClientTests/HubConnection+OnTests.swift @@ -25,7 +25,7 @@ final class HubConnectionOnTests: XCTestCase { connection: mockConnection, logger: Logger(logLevel: .debug, logHandler: logHandler), hubProtocol: hubProtocol, - retryPolicy: DefaultRetryPolicy(retryDelays: []), // No retry + reconnectPolicy: DefaultRetryPolicy(retryDelays: []), // No retry serverTimeout: nil, keepAliveInterval: nil, statefulReconnectBufferSize: nil From 06eaab940009c4ee3e2ede328237aa889e850032 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Tue, 8 Jul 2025 16:29:16 +0800 Subject: [PATCH 2/8] udpate message buffer and ut --- .../SignalRClient/ConnectionProtocol.swift | 1 - Sources/SignalRClient/HttpConnection.swift | 12 +- Sources/SignalRClient/HubConnection.swift | 65 ++++----- .../SignalRClient/HubConnectionBuilder.swift | 2 +- Sources/SignalRClient/MessageBuffer.swift | 126 ++++++++++++++++-- .../HubConnection+OnResultTests.swift | 2 +- .../HubConnection+OnTests.swift | 2 +- .../HubConnectionTests.swift | 5 + .../MessageBufferTests.swift | 28 ++-- 9 files changed, 178 insertions(+), 65 deletions(-) diff --git a/Sources/SignalRClient/ConnectionProtocol.swift b/Sources/SignalRClient/ConnectionProtocol.swift index 6c1ce13..23ed33c 100644 --- a/Sources/SignalRClient/ConnectionProtocol.swift +++ b/Sources/SignalRClient/ConnectionProtocol.swift @@ -5,7 +5,6 @@ enum ConnectionFeature: String, CaseIterable { case Reconnect = "reconnect" case Resend = "resend" case Disconnected = "disconnected" - // Add more feature keys as needed } protocol ConnectionProtocol: AnyObject, Sendable { diff --git a/Sources/SignalRClient/HttpConnection.swift b/Sources/SignalRClient/HttpConnection.swift index f6d6f8d..e98ba2b 100644 --- a/Sources/SignalRClient/HttpConnection.swift +++ b/Sources/SignalRClient/HttpConnection.swift @@ -385,13 +385,11 @@ actor HttpConnection: ConnectionProtocol { } } } - - do { - try await transport!.connect(url: url, transferFormat: transferFormat) - } catch { - await transport!.onReceive(nil) - await transport!.onClose(nil) - throw error + else { + await transport!.onClose { [weak self] error in + guard let self = self else { return } + await self.handleConnectionClose(error: error) + } } } diff --git a/Sources/SignalRClient/HubConnection.swift b/Sources/SignalRClient/HubConnection.swift index 976a6c0..c6da5be 100644 --- a/Sources/SignalRClient/HubConnection.swift +++ b/Sources/SignalRClient/HubConnection.swift @@ -16,7 +16,7 @@ public actor HubConnection { private let logger: Logger private let hubProtocol: HubProtocol private let connection: ConnectionProtocol - private let reconnectPolicy: RetryPolicy + private let retryPolicy: RetryPolicy private let keepAliveScheduler: TimeScheduler private let serverTimeoutScheduler: TimeScheduler private let statefulReconnectBufferSize: Int @@ -40,17 +40,16 @@ public actor HubConnection { internal init(connection: ConnectionProtocol, logger: Logger, hubProtocol: HubProtocol, - reconnectPolicy: RetryPolicy, + retryPolicy: RetryPolicy, serverTimeout: TimeInterval?, keepAliveInterval: TimeInterval?, statefulReconnectBufferSize: Int?) { self.serverTimeout = serverTimeout ?? HubConnection.defaultTimeout self.keepAliveInterval = keepAliveInterval ?? HubConnection.defaultPingInterval - self.statefulReconnectBufferSize = - statefulReconnectBufferSize ?? HubConnection.defaultStatefulReconnectBufferSize + self.statefulReconnectBufferSize = statefulReconnectBufferSize ?? HubConnection.defaultStatefulReconnectBufferSize self.logger = logger - self.reconnectPolicy = reconnectPolicy + self.retryPolicy = retryPolicy self.connection = connection self.hubProtocol = hubProtocol @@ -64,7 +63,7 @@ public actor HubConnection { } public func start() async throws { - if connectionStatus != .Stopped { + if (connectionStatus != .Stopped) { throw SignalRError.invalidOperation("Start client while not in a stopped state.") } @@ -345,7 +344,7 @@ public actor HubConnection { // 2. Connected: In this case, we should reconnect // 3. Reconnecting: In this case, we're in the control of previous reconnect(), let that function handle the reconnection - if connectionStatus == .Connected { + if (connectionStatus == .Connected) { do { try await reconnect(error: error) } catch { @@ -361,13 +360,11 @@ public actor HubConnection { var lastError: Error? = error // reconnect - while let interval = reconnectPolicy.nextRetryInterval( - retryContext: RetryContext( - retryCount: retryCount, - elapsed: elapsed, - retryReason: lastError - )) - { + while let interval = retryPolicy.nextRetryInterval(retryContext: RetryContext( + retryCount: retryCount, + elapsed: elapsed, + retryReason: lastError + )) { try Task.checkCancellation() if connectionStatus == .Stopping { break @@ -436,16 +433,11 @@ public actor HubConnection { do { let hubMessage = try hubProtocol.parseMessages(input: data!, binder: invocationBinder) for message in hubMessage { - do { - if let messageBuffer = self.messageBuffer { - let shouldProcess = try await messageBuffer.shouldProcessMessage(message) - if !shouldProcess { - // Don't process the message, we are either waiting for a SequenceMessage or received a duplicate message - continue - } + if let messageBuffer = self.messageBuffer { + if !(try await messageBuffer.shouldProcessMessage(message)) { + // Don't process the message, we are either waiting for a SequenceMessage or received a duplicate message + continue } - } catch { - logger.log(level: .error, message: "Error parsing messages: \(error)") } await dispatchMessage(message) } @@ -482,11 +474,18 @@ public actor HubConnection { case _ as CloseMessage: // Close break - case _ as AckMessage: - // TODO: In stateful reconnect + case let message as AckMessage: + let result = await self.messageBuffer?.ack(sequenceId: message.sequenceId); + if (result == false) { + logger.log(level: .warning, message: "Ack message received for sequenceId: \(message.sequenceId), but failed.") + } break - case _ as SequenceMessage: - // TODO: In stateful reconnect + case let message as SequenceMessage: + if let messageBuffer = self.messageBuffer { + await messageBuffer.resetSequenceMessage(message: message) + } else { + logger.log(level: .warning, message: "Sequence message received but no message buffer is available.") + } break default: logger.log(level: .warning, message: "Unknown message type: \(message)") @@ -554,7 +553,7 @@ public actor HubConnection { logger.log(level: .error, message: "Unsupported handshake version: \(version)") throw SignalRError.unsupportedHandshakeVersion } - // TODO: enable version 2 when stateful reconnect is done + // TODO: enable version 2 when stateful reconnect is ready receivedHandshakeResponse = false let handshakeRequest = HandshakeRequestMessage(protocol: hubProtocol.name, version: version) @@ -593,17 +592,19 @@ public actor HubConnection { let useStatefulReconnect = await (self.connection.features[ConnectionFeature.Reconnect] as? Bool) == true if useStatefulReconnect { self.messageBuffer = MessageBuffer( - hubProtocol: self.hubProtocol, connection: self.connection, - bufferSize: self.statefulReconnectBufferSize) + bufferSize: self.statefulReconnectBufferSize, + hubProtocol: self.hubProtocol, + connection: self.connection + ) await self.connection.setFeature( feature: ConnectionFeature.Disconnected, value: { [weak self] () async -> Void in - _ = try? await self?.messageBuffer?.disconnected() + _ = await self?.messageBuffer?.disconnected() }) await self.connection.setFeature( feature: ConnectionFeature.Resend, value: { [weak self] () async -> Any? in - return try? await self?.messageBuffer?.resend() + return try? await self?.messageBuffer?.resend() }) } diff --git a/Sources/SignalRClient/HubConnectionBuilder.swift b/Sources/SignalRClient/HubConnectionBuilder.swift index d0ef592..8314662 100644 --- a/Sources/SignalRClient/HubConnectionBuilder.swift +++ b/Sources/SignalRClient/HubConnectionBuilder.swift @@ -103,7 +103,7 @@ public class HubConnectionBuilder { return HubConnection(connection: connection, logger: logger, hubProtocol: hubProtocol, - reconnectPolicy: retryPolicy, + retryPolicy: retryPolicy, serverTimeout: serverTimeout, keepAliveInterval: keepAliveInterval, statefulReconnectBufferSize: statefulReconnectBufferSize) diff --git a/Sources/SignalRClient/MessageBuffer.swift b/Sources/SignalRClient/MessageBuffer.swift index da6fdce..871cbdb 100644 --- a/Sources/SignalRClient/MessageBuffer.swift +++ b/Sources/SignalRClient/MessageBuffer.swift @@ -4,38 +4,131 @@ import Foundation actor MessageBuffer { - private var hubProtocol: HubProtocol; - private var connection: ConnectionProtocol; - private var maxBufferSize: Int private var messages: [BufferedItem] = [] private var bufferedByteCount: Int = 0 private var totalMessageCount: Int = 0 private var lastSendSequenceId: Int = 0 private var nextSendIdx = 0 + private var nextReceivingIdx: Int64 = 1 + private var lastReceivedSequenceId: Int64 = 0 private var dequeueContinuations: [CheckedContinuation] = [] private var closed: Bool = false + private var reconnectInprogress: Bool = false; + private var waitForSequenceMessage: Bool = false; + + private var ackTimerHandle: DispatchWorkItem? + + private var hubProtocol: HubProtocol; + private var connection: ConnectionProtocol; - init(hubProtocol: HubProtocol, connection: ConnectionProtocol, bufferSize: Int) { + init(bufferSize: Int, hubProtocol: HubProtocol, connection: ConnectionProtocol) { + self.maxBufferSize = bufferSize self.hubProtocol = hubProtocol self.connection = connection - self.maxBufferSize = bufferSize } public func send(message: HubMessage) async throws -> Void { - throw SignalRError.invalidOperation("Send is not implemented") + let serializedMessage = try self.hubProtocol.writeMessage(message: message); + + try await self.enqueue(content: serializedMessage); + + if (!self.reconnectInprogress) { + do { + try await self.connection.send(serializedMessage); + } + catch { + self.disconnected(); + } + } } public func resend() async throws -> Void { - throw SignalRError.invalidOperation("Resend is not implemented") + let sequenceId = Int64(self.messages.count > 0 ? self.messages[0].id : self.totalMessageCount + 1); + let serializedMessage = try self.hubProtocol.writeMessage(message: SequenceMessage(sequenceId: sequenceId)); + try await self.connection.send(serializedMessage); + + let messages = self.messages; + for element in messages { + try await self.connection.send(element.content); + } + + self.reconnectInprogress = false; + } + + public func disconnected() -> Void { + self.reconnectInprogress = true; + self.waitForSequenceMessage = true; + } + + private func ackTimer() { + guard ackTimerHandle == nil else { + return + } + + let workItem = DispatchWorkItem { [weak self] in + guard let self = self else { return } + + Task { + await self.performScheduledAck() + } + } + + ackTimerHandle = workItem + DispatchQueue.main.asyncAfter(deadline: .now() + 1.0, execute: workItem) } - public func disconnected() async throws -> Void { - throw SignalRError.invalidOperation("Disconnected is not implemented") + private func performScheduledAck() async { + defer { + // 在方法结束时清理定时器 + ackTimerHandle = nil + } + + do { + if !reconnectInprogress { + let ackMessage = AckMessage( + sequenceId: lastReceivedSequenceId + ) + + let serializedMessage = try hubProtocol.writeMessage(message: ackMessage) + try await connection.send(serializedMessage) + } + } catch { + // 忽略错误,连接关闭时不需要发送ACK + } } public func shouldProcessMessage(_ message: HubMessage) throws -> Bool { - throw SignalRError.invalidOperation("ShouldProcessMessage is not implemented") + if (self.waitForSequenceMessage) { + if (message.type != .sequence) { + return false; + } else { + self.waitForSequenceMessage = false; + return true; + } + } + + if !self.isInvocationMessage(message: message) { + return true + } + + let currentId = self.nextReceivingIdx; + self.nextReceivingIdx += 1; + if currentId <= self.lastReceivedSequenceId{ + if currentId == self.lastReceivedSequenceId{ + // Should only hit this if we just reconnected and the server is sending + // Messages it has buffered, which would mean it hasn't seen an Ack for these messages + self.ackTimer(); + } + // Ignore, this is a duplicate message + return false; + } + self.lastReceivedSequenceId = currentId; + + // Only start the timer for sending an Ack message when we have a message to ack. This also conveniently solves + // timer throttling by not having a recursive timer, and by starting the timer via a network call (recv) + self.ackTimer(); + return true; } public func enqueue(content: StringOrData) async throws -> Void { @@ -71,7 +164,7 @@ actor MessageBuffer { } } - public func ack(sequenceId: Int) throws -> Bool { + public func ack(sequenceId: Int64) -> Bool { // It might be wrong ack or the ack of previous connection if (sequenceId <= 0 || sequenceId > lastSendSequenceId) { return false @@ -137,6 +230,17 @@ actor MessageBuffer { } } + public func resetSequenceMessage(message: SequenceMessage) async { + if message.sequenceId > self.nextReceivingIdx { + // do not await stop + Task { + await self.connection.stop(error: SignalRError.invalidOperation("Received sequence message with sequenceId \(message.sequenceId) greater than nextReceivingIdx \(self.nextReceivingIdx)")) + } + return + } + self.nextReceivingIdx = message.sequenceId; + } + private func isInvocationMessage(message: HubMessage) -> Bool { switch (message.type) { case .invocation, .streamItem, .completion, .streamInvocation, .cancelInvocation: diff --git a/Tests/SignalRClientTests/HubConnection+OnResultTests.swift b/Tests/SignalRClientTests/HubConnection+OnResultTests.swift index bbbda3c..1c11b0a 100644 --- a/Tests/SignalRClientTests/HubConnection+OnResultTests.swift +++ b/Tests/SignalRClientTests/HubConnection+OnResultTests.swift @@ -29,7 +29,7 @@ final class HubConnectionOnResultTests: XCTestCase { connection: mockConnection, logger: Logger(logLevel: .debug, logHandler: logHandler), hubProtocol: hubProtocol, - reconnectPolicy: DefaultRetryPolicy(retryDelays: []), // No retry + retryPolicy: DefaultRetryPolicy(retryDelays: []), // No retry serverTimeout: nil, keepAliveInterval: nil, statefulReconnectBufferSize: nil diff --git a/Tests/SignalRClientTests/HubConnection+OnTests.swift b/Tests/SignalRClientTests/HubConnection+OnTests.swift index f8d6f8b..d1a853a 100644 --- a/Tests/SignalRClientTests/HubConnection+OnTests.swift +++ b/Tests/SignalRClientTests/HubConnection+OnTests.swift @@ -25,7 +25,7 @@ final class HubConnectionOnTests: XCTestCase { connection: mockConnection, logger: Logger(logLevel: .debug, logHandler: logHandler), hubProtocol: hubProtocol, - reconnectPolicy: DefaultRetryPolicy(retryDelays: []), // No retry + retryPolicy: DefaultRetryPolicy(retryDelays: []), // No retry serverTimeout: nil, keepAliveInterval: nil, statefulReconnectBufferSize: nil diff --git a/Tests/SignalRClientTests/HubConnectionTests.swift b/Tests/SignalRClientTests/HubConnectionTests.swift index da8d7c3..9c74d98 100644 --- a/Tests/SignalRClientTests/HubConnectionTests.swift +++ b/Tests/SignalRClientTests/HubConnectionTests.swift @@ -13,6 +13,7 @@ class MockConnection: ConnectionProtocol, @unchecked Sendable { var onSend: ((StringOrData) -> Void)? var onStart: (() -> Void)? var onStop: ((Error?) -> Void)? + var features: [ConnectionFeature : Any] = [:] private(set) var startCalled = false private(set) var sendCalled = false @@ -42,6 +43,10 @@ class MockConnection: ConnectionProtocol, @unchecked Sendable { func onClose(_ handler: @escaping @Sendable ((any Error)?) async -> Void) async { onClose = handler } + + func setFeature(feature: SignalRClient.ConnectionFeature, value: Any) async { + features[feature] = value + } } final class HubConnectionTests: XCTestCase { diff --git a/Tests/SignalRClientTests/MessageBufferTests.swift b/Tests/SignalRClientTests/MessageBufferTests.swift index 287cd4e..893dfbc 100644 --- a/Tests/SignalRClientTests/MessageBufferTests.swift +++ b/Tests/SignalRClientTests/MessageBufferTests.swift @@ -3,8 +3,14 @@ import XCTest @testable import SignalRClient class MessageBufferTest: XCTestCase { + func getTestMessageBuffer(bufferSize: Int) -> MessageBuffer { + return MessageBuffer(bufferSize: bufferSize, + hubProtocol: JsonHubProtocol(), + connection: MockConnection()) + } + func testSendWithinBufferSize() async throws { - let buffer = MessageBuffer(bufferSize: 100) + let buffer = getTestMessageBuffer(bufferSize: 100) let expectation = XCTestExpectation(description: "Should enqueue") Task { try await buffer.enqueue(content: .string("data")) @@ -14,7 +20,7 @@ class MessageBufferTest: XCTestCase { } func testSendTriggersBackpressure() async throws { - let buffer = MessageBuffer(bufferSize: 5) + let buffer = getTestMessageBuffer(bufferSize: 5) let expectation1 = XCTestExpectation(description: "Should not enqueue") expectation1.isInverted = true let expectation2 = XCTestExpectation(description: "Should enqueue") @@ -33,7 +39,7 @@ class MessageBufferTest: XCTestCase { } func testBackPressureAndRelease() async throws { - let buffer = MessageBuffer(bufferSize: 10) + let buffer = getTestMessageBuffer(bufferSize: 10) try await buffer.enqueue(content: .string("1234567890")) async let eq1 = buffer.enqueue(content: .string("1")) async let eq2 = buffer.enqueue(content: .string("2")) @@ -51,7 +57,7 @@ class MessageBufferTest: XCTestCase { } func testBackPressureAndRelease2() async throws { - let buffer = MessageBuffer(bufferSize: 10) + let buffer = getTestMessageBuffer(bufferSize: 10) let expect1 = XCTestExpectation(description: "Should not release 1") expect1.isInverted = true let expect2 = XCTestExpectation(description: "Should not release 2") @@ -93,7 +99,7 @@ class MessageBufferTest: XCTestCase { } func testAckInvalidSequenceIdIgnored() async throws { - let buffer = MessageBuffer(bufferSize: 100) + let buffer = getTestMessageBuffer(bufferSize: 100) let rst = try await buffer.ack(sequenceId: 1) // without any send XCTAssertEqual(false, rst) @@ -104,7 +110,7 @@ class MessageBufferTest: XCTestCase { } func testWaitToDequeueReturnsImmediatelyIfAvailable() async throws { - let buffer = MessageBuffer(bufferSize: 100) + let buffer = getTestMessageBuffer(bufferSize: 100) _ = try await buffer.enqueue(content: .string("msg")) let result = try await buffer.WaitToDequeue() XCTAssertTrue(result) @@ -113,7 +119,7 @@ class MessageBufferTest: XCTestCase { } func testWaitToDequeueFirst() async throws { - let buffer = MessageBuffer(bufferSize: 100) + let buffer = getTestMessageBuffer(bufferSize: 100) async let dqueue: Bool = try await buffer.WaitToDequeue() try await Task.sleep(for: .milliseconds(10)) @@ -127,7 +133,7 @@ class MessageBufferTest: XCTestCase { } func testMultipleDequeueWait() async throws { - let buffer = MessageBuffer(bufferSize: 100) + let buffer = getTestMessageBuffer(bufferSize: 100) async let dqueue1: Bool = try await buffer.WaitToDequeue() async let dqueue2: Bool = try await buffer.WaitToDequeue() try await Task.sleep(for: .milliseconds(10)) @@ -143,13 +149,13 @@ class MessageBufferTest: XCTestCase { } func testTryDequeueReturnsNilIfEmpty() async throws { - let buffer = MessageBuffer(bufferSize: 100) + let buffer = getTestMessageBuffer(bufferSize: 100) let result = try await buffer.TryDequeue() XCTAssertNil(result) } func testResetDequeueResetsCorrectly() async throws { - let buffer = MessageBuffer(bufferSize: 100) + let buffer = getTestMessageBuffer(bufferSize: 100) try await buffer.enqueue(content: .string("test1")) try await buffer.enqueue(content: .string("test2")) let t1 = try await buffer.TryDequeue() @@ -172,7 +178,7 @@ class MessageBufferTest: XCTestCase { } func testContinuousBackPressure() async throws { - let buffer = MessageBuffer(bufferSize: 5) + let buffer = getTestMessageBuffer(bufferSize: 5) var tasks: [Task] = [] for i in 0..<100 { let task = Task { From 37bcf8537041bd6ee0db5cfa00afe6ca0e9e8171 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Tue, 8 Jul 2025 16:34:10 +0800 Subject: [PATCH 3/8] update --- Sources/SignalRClient/HttpConnection.swift | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Sources/SignalRClient/HttpConnection.swift b/Sources/SignalRClient/HttpConnection.swift index e98ba2b..a8b4b34 100644 --- a/Sources/SignalRClient/HttpConnection.swift +++ b/Sources/SignalRClient/HttpConnection.swift @@ -391,6 +391,14 @@ actor HttpConnection: ConnectionProtocol { await self.handleConnectionClose(error: error) } } + + do { + try await transport!.connect(url: url, transferFormat: transferFormat) + } catch { + await transport!.onReceive(nil) + await transport!.onClose(nil) + throw error + } } private func handleConnectionClose(error: Error?) async { From 557efc5f96e38b280b38d04dc1e8f9b3343df454 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Wed, 9 Jul 2025 17:07:47 +0800 Subject: [PATCH 4/8] update ut --- .../SignalRClient/HubConnectionBuilder.swift | 6 +++ .../HubConnectionTests.swift | 44 ++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/Sources/SignalRClient/HubConnectionBuilder.swift b/Sources/SignalRClient/HubConnectionBuilder.swift index 8314662..84e70f5 100644 --- a/Sources/SignalRClient/HubConnectionBuilder.swift +++ b/Sources/SignalRClient/HubConnectionBuilder.swift @@ -65,6 +65,12 @@ public class HubConnectionBuilder { return self } + public func withStatefulReconnect(bufferSize: Int) -> HubConnectionBuilder { + self.httpConnectionOptions.useStatefulReconnect = true + self.statefulReconnectBufferSize = bufferSize + return self + } + public func withAutomaticReconnect() -> HubConnectionBuilder { self.retryPolicy = DefaultRetryPolicy(retryDelays: [0, 2, 10, 30]) return self diff --git a/Tests/SignalRClientTests/HubConnectionTests.swift b/Tests/SignalRClientTests/HubConnectionTests.swift index 9c74d98..591aa60 100644 --- a/Tests/SignalRClientTests/HubConnectionTests.swift +++ b/Tests/SignalRClientTests/HubConnectionTests.swift @@ -18,7 +18,7 @@ class MockConnection: ConnectionProtocol, @unchecked Sendable { private(set) var startCalled = false private(set) var sendCalled = false private(set) var stopCalled = false - private(set) var sentData: StringOrData? + private(set) var sentData: [StringOrData?] = [] func start(transferFormat: TransferFormat) async throws { startCalled = true @@ -27,7 +27,7 @@ class MockConnection: ConnectionProtocol, @unchecked Sendable { func send(_ data: StringOrData) async throws { sendCalled = true - sentData = data + sentData.append(data) onSend?(data) } @@ -535,6 +535,46 @@ final class HubConnectionTests: XCTestCase { await fulfillment(of: [pingExpectations[0], pingExpectations[1], pingExpectations[2]], timeout: 1.0) } + func testStatefulReconnect() async throws { + let bufferSize = 10 + let expectation = XCTestExpectation(description: "send() should be called") + + mockConnection.onSend = { data in + expectation.fulfill() + Task { await self.hubConnection.processIncomingData(.string(self.successHandshakeResponse)) } + } + await mockConnection.setFeature(feature: ConnectionFeature.Reconnect, value: true); + + hubConnection = HubConnection( + connection: mockConnection, + logger: Logger(logLevel: .debug, logHandler: logHandler), + hubProtocol: hubProtocol, + retryPolicy: DefaultRetryPolicy(retryDelays: [0, 1, 2]), + serverTimeout: nil, + keepAliveInterval: nil, + statefulReconnectBufferSize: bufferSize + ) + + let startTask = Task { try await hubConnection.start() } + defer { startTask.cancel() } + await fulfillment(of: [expectation], timeout: 1.0) + await whenTaskWithTimeout(startTask, timeout: 1.0) + + XCTAssertNotNil(mockConnection.features[ConnectionFeature.Disconnected]); + XCTAssertNotNil(mockConnection.features[ConnectionFeature.Resend]); + + if let disconnectedClosure = mockConnection.features[ConnectionFeature.Disconnected] as? () async -> Void { + await disconnectedClosure() + print("called disconnected closure") + } + + if let resendClosure = mockConnection.features[ConnectionFeature.Resend] as? () async -> Any? { + let _ = await resendClosure() + } + + print(mockConnection.sentData.count, mockConnection.sentData) + } + func serverTimeoutTest() async throws { hubConnection = HubConnection( connection: mockConnection, From 8b6b0746f77c5a6cbfd787be11e1f7d6a4be2912 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Mon, 21 Jul 2025 17:25:17 +0800 Subject: [PATCH 5/8] update HubConnection --- Sources/SignalRClient/HubConnection.swift | 51 +++++++++++------------ 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/Sources/SignalRClient/HubConnection.swift b/Sources/SignalRClient/HubConnection.swift index c6da5be..72bc6a2 100644 --- a/Sources/SignalRClient/HubConnection.swift +++ b/Sources/SignalRClient/HubConnection.swift @@ -118,9 +118,8 @@ public actor HubConnection { let (nonstreamArguments, streamArguments) = splitStreamArguments(arguments: arguments) let streamIds = await invocationHandler.createClientStreamIds(count: streamArguments.count) let invocationMessage = InvocationMessage(target: method, arguments: AnyEncodableArray(nonstreamArguments), streamIds: streamIds, headers: nil, invocationId: nil) - let data = try hubProtocol.writeMessage(message: invocationMessage) logger.log(level: .debug, message: "Sending message to target: \(method)") - try await sendMessageInternal(data) + try await sendWithProtocol(invocationMessage) launchStreams(streamIds: streamIds, clientStreams: streamArguments) } @@ -145,8 +144,7 @@ public actor HubConnection { do { for try await item in stream { let streamItem = StreamItemMessage(invocationId: streamIds[i], item: AnyEncodable(item), headers: nil) - let data = try hubProtocol.writeMessage(message: streamItem) - try await sendMessageInternal(data) + try await sendWithProtocol(streamItem) } } catch { err = "\(error)" @@ -154,8 +152,7 @@ public actor HubConnection { } do { let completionMessage = CompletionMessage(invocationId: streamIds[i], error: err, result: AnyEncodable(nil), headers: nil) - let data = try hubProtocol.writeMessage(message: completionMessage) - try await sendMessageInternal(data) + try await sendWithProtocol(completionMessage) } catch { logger.log(level: .error, message: "Fail to send client stream complete message :\(error)") } @@ -168,9 +165,8 @@ public actor HubConnection { let streamIds = await invocationHandler.createClientStreamIds(count: streamArguments.count) let (invocationId, tcs) = await invocationHandler.create() let invocationMessage = InvocationMessage(target: method, arguments: AnyEncodableArray(nonstreamArguments), streamIds: streamIds, headers: nil, invocationId: invocationId) - let data = try hubProtocol.writeMessage(message: invocationMessage) logger.log(level: .debug, message: "Invoke message to target: \(method), invocationId: \(invocationId)") - try await sendMessageInternal(data) + try await sendWithProtocol(invocationMessage) launchStreams(streamIds: streamIds, clientStreams: streamArguments) _ = try await tcs.task() } @@ -182,9 +178,8 @@ public actor HubConnection { invocationBinder.registerReturnValueType(invocationId: invocationId, types: TReturn.self) let invocationMessage = InvocationMessage(target: method, arguments: AnyEncodableArray(nonstreamArguments), streamIds: streamIds, headers: nil, invocationId: invocationId) do { - let data = try hubProtocol.writeMessage(message: invocationMessage) logger.log(level: .debug, message: "Invoke message to target: \(method), invocationId: \(invocationId)") - try await sendMessageInternal(data) + try await sendWithProtocol(invocationMessage) launchStreams(streamIds: streamIds, clientStreams: streamArguments) } catch { await invocationHandler.cancel(invocationId: invocationId, error: error) @@ -206,9 +201,8 @@ public actor HubConnection { invocationBinder.registerReturnValueType(invocationId: invocationId, types: Element.self) let StreamInvocationMessage = StreamInvocationMessage(invocationId: invocationId, target: method, arguments: AnyEncodableArray(nonstreamArguments), streamIds: streamIds, headers: nil) do { - let data = try hubProtocol.writeMessage(message: StreamInvocationMessage) logger.log(level: .debug, message: "Stream message to target: \(method), invocationId: \(invocationId)") - try await sendMessageInternal(data) + try await sendWithProtocol(StreamInvocationMessage) launchStreams(streamIds: streamIds, clientStreams: streamArguments) } catch { await invocationHandler.cancel(invocationId: invocationId, error: error) @@ -236,9 +230,8 @@ public actor HubConnection { streamResult.onCancel = { do { let cancelInvocation = CancelInvocationMessage(invocationId: invocationId, headers: nil) - let data = try self.hubProtocol.writeMessage(message: cancelInvocation) await self.invocationHandler.cancel(invocationId: invocationId, error: SignalRError.streamCancelled) - try await self.sendMessageInternal(data) + try await self.sendWithProtocol(cancelInvocation) } catch {} } @@ -498,8 +491,7 @@ public actor HubConnection { if let invocationId = message.invocationId { logger.log(level: .warning, message: "No result given for method: \(message.target), and invocationId: \(invocationId)") let completionMessage = CompletionMessage(invocationId: invocationId, error: "No handler registered for method: \(message.target)", result: AnyEncodable(nil), headers: nil) - let data = try hubProtocol.writeMessage(message: completionMessage) - try await sendMessageInternal(data) + try await sendWithProtocol(completionMessage) } return } @@ -512,8 +504,7 @@ public actor HubConnection { result = nil } let completionMessage = CompletionMessage(invocationId: message.invocationId!, error: nil, result: AnyEncodable(result), headers: nil) - let data = try hubProtocol.writeMessage(message: completionMessage) - try await sendMessageInternal(data) + try await sendWithProtocol(completionMessage) } else { _ = try await handler(message.arguments.value ?? []) } @@ -547,13 +538,12 @@ public actor HubConnection { try await connection.start(transferFormat: hubProtocol.transferFormat) // After connection open, perform handshake - let version = hubProtocol.version - // As we only support 1 now - guard version == 1 else { - logger.log(level: .error, message: "Unsupported handshake version: \(version)") - throw SignalRError.unsupportedHandshakeVersion + var version = hubProtocol.version + if !(await connection.features[ConnectionFeature.Reconnect] as? Bool ?? false) { + // Stateful Reconnect starts with HubProtocol version 2, newer clients connecting to older servers will fail to connect due to + // the handshake only supporting version 1, so we will try to send version 1 during the handshake to keep old servers working. + version = 1; } - // TODO: enable version 2 when stateful reconnect is ready receivedHandshakeResponse = false let handshakeRequest = HandshakeRequestMessage(protocol: hubProtocol.name, version: version) @@ -647,6 +637,16 @@ public actor HubConnection { try await connection.send(content) } + private func sendWithProtocol(_ message: HubMessage) async throws { + if self.messageBuffer != nil { + try await self.messageBuffer?.send(message: message) + } + else { + let data = try hubProtocol.writeMessage(message: message) + try await sendMessageInternal(data) + } + } + private func processHandshakeResponse(_ content: StringOrData) throws -> StringOrData? { var remainingData: StringOrData? var handshakeResponse: HandshakeResponseMessage @@ -684,8 +684,7 @@ public actor HubConnection { private func sendPing() async throws { let pingMessage = PingMessage() - let data = try hubProtocol.writeMessage(message: pingMessage) - try await sendMessageInternal(data) + try await sendWithProtocol(pingMessage) } private class SubscriptionEntity { From 6eb4789233512a934ac655ba3ae2f072b958d131 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Tue, 22 Jul 2025 17:05:46 +0800 Subject: [PATCH 6/8] add more ut --- Sources/SignalRClient/HubConnection.swift | 3 +- .../HubConnectionTests.swift | 145 +++++++++++++++--- Tests/SignalRClientTests/Utils.swift | 34 ++++ 3 files changed, 160 insertions(+), 22 deletions(-) create mode 100644 Tests/SignalRClientTests/Utils.swift diff --git a/Sources/SignalRClient/HubConnection.swift b/Sources/SignalRClient/HubConnection.swift index 72bc6a2..c7fbd0d 100644 --- a/Sources/SignalRClient/HubConnection.swift +++ b/Sources/SignalRClient/HubConnection.swift @@ -598,8 +598,7 @@ public actor HubConnection { }) } - let inherentKeepAlive = await connection.inherentKeepAlive - if (!inherentKeepAlive) { + if (!(await connection.inherentKeepAlive)) { await keepAliveScheduler.start { do { let state = self.state() diff --git a/Tests/SignalRClientTests/HubConnectionTests.swift b/Tests/SignalRClientTests/HubConnectionTests.swift index 591aa60..0311efe 100644 --- a/Tests/SignalRClientTests/HubConnectionTests.swift +++ b/Tests/SignalRClientTests/HubConnectionTests.swift @@ -61,6 +61,7 @@ final class HubConnectionTests: XCTestCase { var logHandler: LogHandler! var hubProtocol: HubProtocol! var hubConnection: HubConnection! + var hubConnectionForStatefulReconnect: HubConnection! override func setUp() async throws { mockConnection = MockConnection() @@ -75,6 +76,15 @@ final class HubConnectionTests: XCTestCase { keepAliveInterval: nil, statefulReconnectBufferSize: nil ) + hubConnectionForStatefulReconnect = HubConnection( + connection: mockConnection, + logger: Logger(logLevel: .debug, logHandler: logHandler), + hubProtocol: hubProtocol, + retryPolicy: DefaultRetryPolicy(retryDelays: [0, 1, 2]), + serverTimeout: nil, + keepAliveInterval: 0.5, + statefulReconnectBufferSize: 10000 + ) } func testStart_CallsStartOnConnection() async throws { @@ -535,44 +545,139 @@ final class HubConnectionTests: XCTestCase { await fulfillment(of: [pingExpectations[0], pingExpectations[1], pingExpectations[2]], timeout: 1.0) } - func testStatefulReconnect() async throws { - let bufferSize = 10 - let expectation = XCTestExpectation(description: "send() should be called") + func testStatefulReconnect_sendsSequenceMessageOnReconnect() async throws { + let pingExpectations = [XCTestExpectation(description: "ping should be called")] + let disconnectExpectation = XCTestExpectation(description: "disconnect should be called") + let resendExpectation = XCTestExpectation(description: "reconnect should be called") + var sentPingCount = 0 mockConnection.onSend = { data in - expectation.fulfill() - Task { await self.hubConnection.processIncomingData(.string(self.successHandshakeResponse)) } + do { + let messages = try self.hubProtocol.parseMessages(input: data, binder: TestInvocationBinder(binderTypes: [])) + for message in messages { + if message is PingMessage { + if sentPingCount < pingExpectations.count { + pingExpectations[sentPingCount].fulfill() + } + sentPingCount += 1 + return + } + } + Task { await self.hubConnectionForStatefulReconnect.processIncomingData(.string(self.successHandshakeResponse)) } // only success the first time + } catch { + XCTFail("Unexpected error: \(error)") + } + } await mockConnection.setFeature(feature: ConnectionFeature.Reconnect, value: true); - hubConnection = HubConnection( - connection: mockConnection, - logger: Logger(logLevel: .debug, logHandler: logHandler), - hubProtocol: hubProtocol, - retryPolicy: DefaultRetryPolicy(retryDelays: [0, 1, 2]), - serverTimeout: nil, - keepAliveInterval: nil, - statefulReconnectBufferSize: bufferSize - ) - - let startTask = Task { try await hubConnection.start() } + let startTask = Task { try await hubConnectionForStatefulReconnect.start() } defer { startTask.cancel() } - await fulfillment(of: [expectation], timeout: 1.0) - await whenTaskWithTimeout(startTask, timeout: 1.0) + await fulfillment(of: [pingExpectations[0]], timeout: 2) + + await whenTaskWithTimeout(startTask, timeout: 0.1) XCTAssertNotNil(mockConnection.features[ConnectionFeature.Disconnected]); XCTAssertNotNil(mockConnection.features[ConnectionFeature.Resend]); if let disconnectedClosure = mockConnection.features[ConnectionFeature.Disconnected] as? () async -> Void { await disconnectedClosure() - print("called disconnected closure") + disconnectExpectation.fulfill() } if let resendClosure = mockConnection.features[ConnectionFeature.Resend] as? () async -> Any? { let _ = await resendClosure() + resendExpectation.fulfill() + } + await fulfillment(of: [disconnectExpectation, resendExpectation], timeout: 0.1) + + // expected 3 sent messages: [{"protocol":"json","version":2}, {"type":6}, {"type":9,"sequenceId":1}] + XCTAssertEqual(mockConnection.sentData.count, 3); + let sentHubMessages = try getParsedData(data: mockConnection.sentData, binder: TestInvocationBinder(binderTypes: [])) + XCTAssertEqual(sentHubMessages.count, 2); + XCTAssertTrue(sentHubMessages[0] is PingMessage) + XCTAssertTrue(sentHubMessages[1] is SequenceMessage) + XCTAssertEqual((sentHubMessages[1] as! SequenceMessage).sequenceId, 1) + } + + func testStatefulReconnect_resendsMessagesOnReconnect() async throws { + let pingExpectations = [XCTestExpectation(description: "ping should be called")] + let disconnectExpectation = XCTestExpectation(description: "disconnect should be called") + let resendExpectation = XCTestExpectation(description: "reconnect should be called") + var sentPingCount = 0 + + mockConnection.onSend = { data in + do { + let messages = try self.hubProtocol.parseMessages(input: data, binder: TestInvocationBinder(binderTypes: [Int.self])) + for message in messages { + if message is PingMessage { + if sentPingCount < pingExpectations.count { + pingExpectations[sentPingCount].fulfill() + } + sentPingCount += 1 + return + } + } + Task { await self.hubConnectionForStatefulReconnect.processIncomingData(.string(self.successHandshakeResponse)) } // only success the first time + } catch { + XCTFail("Unexpected error: \(error)") + } + } + await mockConnection.setFeature(feature: ConnectionFeature.Reconnect, value: true); + + let startTask = Task { try await hubConnectionForStatefulReconnect.start() } + defer { startTask.cancel() } + await whenTaskWithTimeout(startTask, timeout: 1) + await fulfillment(of: [pingExpectations[0]], timeout: 1) + + XCTAssertNotNil(mockConnection.features[ConnectionFeature.Disconnected]); + XCTAssertNotNil(mockConnection.features[ConnectionFeature.Resend]); - print(mockConnection.sentData.count, mockConnection.sentData) + await whenTaskWithTimeout(Task { try await hubConnectionForStatefulReconnect.send(method: "test", arguments: 13) }, timeout: 0.1); + await whenTaskWithTimeout(Task { try await hubConnectionForStatefulReconnect.send(method: "test", arguments: 12) }, timeout: 0.1); + await whenTaskWithTimeout(Task { try await hubConnectionForStatefulReconnect.send(method: "test", arguments: 11) }, timeout: 0.1); + + if let disconnectedClosure = mockConnection.features[ConnectionFeature.Disconnected] as? () async -> Void { + await disconnectedClosure() + disconnectExpectation.fulfill() + } + + if let resendClosure = mockConnection.features[ConnectionFeature.Resend] as? () async -> Any? { + let _ = await resendClosure() + resendExpectation.fulfill() + } + await fulfillment(of: [disconnectExpectation], timeout: 1) + await fulfillment(of: [resendExpectation], timeout: 1) + + /* Expeceted mockConnection.SentData = [ + 0 {"protocol":"json","version":2} + 1 {"type":6} + 2 {"target":"test","arguments":[13],"type":1} + 3 {"target":"test","arguments":[12],"type":1} + 4 {"target":"test","arguments":[11],"type":1} + ... may contains additional ping messages, ignore them + 5 {"type":9,"sequenceId":1} + 6 {"target":"test","arguments":[13],"type":1} + 7 {"target":"test","arguments":[12],"type":1} + 8 {"target":"test","arguments":[11],"type":1} + ]*/ + // use hubProtocol to parse the messages + let parsedSentData = try getParsedData(data: mockConnection.sentData, binder: TestInvocationBinder(binderTypes: [Int.self])) + let sentHubMessage = removeAllPingMessagesButFirst(messages: parsedSentData) + XCTAssertEqual(sentHubMessage.count, 8) // the first message for handshake is not a HubMessage + XCTAssertTrue(sentHubMessage[0] is PingMessage) + XCTAssertTrue(sentHubMessage[4] is SequenceMessage) + XCTAssertTrue((sentHubMessage[4] as! SequenceMessage).sequenceId == 1) + XCTAssertTrue(sentHubMessage[5] is InvocationMessage) + XCTAssertTrue((sentHubMessage[5] as! InvocationMessage).target == "test") + XCTAssertTrue((sentHubMessage[5] as! InvocationMessage).arguments.value?[0] as? Int == 13) + XCTAssertTrue(sentHubMessage[6] is InvocationMessage) + XCTAssertTrue((sentHubMessage[6] as! InvocationMessage).target == "test") + XCTAssertTrue((sentHubMessage[6] as! InvocationMessage).arguments.value?[0] as? Int == 12) + XCTAssertTrue(sentHubMessage[7] is InvocationMessage) + XCTAssertTrue((sentHubMessage[7] as! InvocationMessage).target == "test") + XCTAssertTrue((sentHubMessage[7] as! InvocationMessage).arguments.value?[0] as? Int == 11) } func serverTimeoutTest() async throws { diff --git a/Tests/SignalRClientTests/Utils.swift b/Tests/SignalRClientTests/Utils.swift new file mode 100644 index 0000000..eb48a22 --- /dev/null +++ b/Tests/SignalRClientTests/Utils.swift @@ -0,0 +1,34 @@ +@testable import SignalRClient + +// Remove all ping messages from the message array but keep the first one. +func removeAllPingMessagesButFirst(messages: [HubMessage]) -> [HubMessage] { + var result = [HubMessage]() + var containsPing = false + for message in messages { + if let pingMessage = message as? PingMessage { + if containsPing { + continue // Skip all but the first ping message + } + result.append(pingMessage) + containsPing = true + } else { + result.append(message) + } + } + return result +} + +func getParsedData(data: [StringOrData?], binder: InvocationBinder) throws -> [HubMessage] { + var parsedData = [HubMessage]() + for item in data { + if item == nil { + continue + } + let messages = try JsonHubProtocol().parseMessages(input: item!, binder: binder); + // append the messages array as a single element + if !messages.isEmpty { + parsedData.append(contentsOf: messages) + } + } + return parsedData +} \ No newline at end of file From da293bd0e30ffae7f92e2bdbbb6f45e0b588cf79 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Mon, 4 Aug 2025 17:12:37 +0800 Subject: [PATCH 7/8] more --- .../HubConnectionTests.swift | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/Tests/SignalRClientTests/HubConnectionTests.swift b/Tests/SignalRClientTests/HubConnectionTests.swift index 0311efe..4ff86c9 100644 --- a/Tests/SignalRClientTests/HubConnectionTests.swift +++ b/Tests/SignalRClientTests/HubConnectionTests.swift @@ -680,6 +680,95 @@ final class HubConnectionTests: XCTestCase { XCTAssertTrue((sentHubMessage[7] as! InvocationMessage).arguments.value?[0] as? Int == 11) } + func testStatefulReconnect_resendsMessagesWhileDisconnectedOnReconnect() async throws { + let pingExpectations = [XCTestExpectation(description: "ping should be called")] + let disconnectExpectation = XCTestExpectation(description: "disconnect should be called") + let resendExpectation = XCTestExpectation(description: "reconnect should be called") + var sentPingCount = 0 + + mockConnection.onSend = { data in + do { + let messages = try self.hubProtocol.parseMessages(input: data, binder: TestInvocationBinder(binderTypes: [Int.self])) + for message in messages { + if message is PingMessage { + if sentPingCount < pingExpectations.count { + pingExpectations[sentPingCount].fulfill() + } + sentPingCount += 1 + return + } + } + Task { await self.hubConnectionForStatefulReconnect.processIncomingData(.string(self.successHandshakeResponse)) } + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + await mockConnection.setFeature(feature: ConnectionFeature.Reconnect, value: true) + + let startTask = Task { try await hubConnectionForStatefulReconnect.start() } + defer { startTask.cancel() } + await whenTaskWithTimeout(startTask, timeout: 1) + await fulfillment(of: [pingExpectations[0]], timeout: 1) + + XCTAssertNotNil(mockConnection.features[ConnectionFeature.Disconnected]) + XCTAssertNotNil(mockConnection.features[ConnectionFeature.Resend]) + + // Send first message before disconnect + await whenTaskWithTimeout(Task { try await hubConnectionForStatefulReconnect.send(method: "test", arguments: 13) }, timeout: 0.1) + + // Pretend TestConnection disconnected + if let disconnectedClosure = mockConnection.features[ConnectionFeature.Disconnected] as? () async -> Void { + await disconnectedClosure() + disconnectExpectation.fulfill() + } + + // Send while disconnected, should wait until resend completes + let sendTask = Task { try await hubConnectionForStatefulReconnect.send(method: "test", arguments: 22) } + var sendDone = false + let monitorTask = Task { + try await sendTask.value + sendDone = true + } + + // Give a small delay to ensure send is waiting + try await Task.sleep(nanoseconds: 10_000_000) // 10ms + XCTAssertFalse(sendDone) + + if let resendClosure = mockConnection.features[ConnectionFeature.Resend] as? () async -> Any? { + let _ = await resendClosure() + resendExpectation.fulfill() + } + + await whenTaskWithTimeout(monitorTask, timeout: 1) + XCTAssertTrue(sendDone) + + await fulfillment(of: [disconnectExpectation, resendExpectation], timeout: 1) + + /* Expected mockConnection.sentData = [ + 0 {"protocol":"json","version":2} + 1 {"type":6} // ping + 2 {"target":"test","arguments":[13],"type":1} // first send + 3 {"type":9,"sequenceId":1} // sequence message + 4 {"target":"test","arguments":[13],"type":1} // resend first message + 5 {"target":"test","arguments":[22],"type":1} // send message that waited + ]*/ + + let parsedSentData = try getParsedData(data: mockConnection.sentData, binder: TestInvocationBinder(binderTypes: [Int.self])) + let sentHubMessage = removeAllPingMessagesButFirst(messages: parsedSentData) + + XCTAssertEqual(sentHubMessage.count, 5) + XCTAssertTrue(sentHubMessage[0] is PingMessage) + XCTAssertTrue(sentHubMessage[2] is SequenceMessage) + XCTAssertEqual((sentHubMessage[2] as! SequenceMessage).sequenceId, 1) + XCTAssertTrue(sentHubMessage[3] is InvocationMessage) + XCTAssertEqual((sentHubMessage[3] as! InvocationMessage).target, "test") + XCTAssertEqual((sentHubMessage[3] as! InvocationMessage).arguments.value?[0] as? Int, 13) + XCTAssertTrue(sentHubMessage[4] is InvocationMessage) + XCTAssertEqual((sentHubMessage[4] as! InvocationMessage).target, "test") + XCTAssertEqual((sentHubMessage[4] as! InvocationMessage).arguments.value?[0] as? Int, 22) + } + func serverTimeoutTest() async throws { hubConnection = HubConnection( connection: mockConnection, From e83baaff44aa477a1719908060cd79b4be39e2d5 Mon Sep 17 00:00:00 2001 From: xingsy97 <87063252+xingsy97@users.noreply.github.com> Date: Mon, 8 Sep 2025 18:51:18 +0800 Subject: [PATCH 8/8] resolve comment --- Sources/SignalRClient/MessageBuffer.swift | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Sources/SignalRClient/MessageBuffer.swift b/Sources/SignalRClient/MessageBuffer.swift index 871cbdb..9c75514 100644 --- a/Sources/SignalRClient/MessageBuffer.swift +++ b/Sources/SignalRClient/MessageBuffer.swift @@ -80,7 +80,6 @@ actor MessageBuffer { private func performScheduledAck() async { defer { - // 在方法结束时清理定时器 ackTimerHandle = nil } @@ -94,7 +93,7 @@ actor MessageBuffer { try await connection.send(serializedMessage) } } catch { - // 忽略错误,连接关闭时不需要发送ACK + // Ignore exception, no need to send ACK when reconnecting } }