diff --git a/CocoaMQTTTests/CocoaMQTTDeliverTests.swift b/CocoaMQTTTests/CocoaMQTTDeliverTests.swift index 5c6256e9..9108e6d9 100644 --- a/CocoaMQTTTests/CocoaMQTTDeliverTests.swift +++ b/CocoaMQTTTests/CocoaMQTTDeliverTests.swift @@ -131,13 +131,69 @@ class CocoaMQTTDeliverTests: XCTestCase { } } + func testStorage() { + + let clientID = "deliver-unit-testing" + let caller = Caller() + let deliver = CocoaMQTTDeliver() + + let frames = [FramePublish(topic: "t/0", payload: [0x00], qos: .qos0), + FramePublish(topic: "t/1", payload: [0x01], qos: .qos1, msgid: 1), + FramePublish(topic: "t/2", payload: [0x02], qos: .qos2, msgid: 2)] + + guard let storage = CocoaMQTTStorage(by: clientID) else { + XCTAssert(false, "Initial storage failed") + return + } + + deliver.delegate = caller + deliver.recoverSessionBy(storage) + + for f in frames { + _ = deliver.add(f) + } + + var saved = storage.readAll() + XCTAssertEqual(saved.count, 2) + + + deliver.ack(by: FramePubAck(msgid: 1)) + ms_sleep(100) + saved = storage.readAll() + XCTAssertEqual(saved.count, 1) + + deliver.ack(by: FramePubRec(msgid: 2)) + ms_sleep(100) + saved = storage.readAll() + XCTAssertEqual(saved.count, 1) + assertEqual(saved[0], FramePubRel(msgid: 2)) + + + deliver.ack(by: FramePubComp(msgid: 2)) + ms_sleep(100) + saved = storage.readAll() + XCTAssertEqual(saved.count, 0) + + caller.reset() + _ = storage.write(frames[1]) + deliver.recoverSessionBy(storage) + ms_sleep(100) + XCTAssertEqual(caller.frames.count, 1) + assertEqual(caller.frames[0], frames[1]) + + + deliver.ack(by: FramePubAck(msgid: 1)) + ms_sleep(100) + XCTAssertEqual(storage.readAll().count, 0) + } + func testTODO() { // TODO: How to test large of messages combined qos0/qos1/qos2 } // Helper for assert equality for Frame - private func assertEqual(_ f1: Frame, _ f2: Frame) { + private func assertEqual(_ f1: Frame, _ f2: Frame, _ lines: Int = #line) { if let pub1 = f1 as? FramePublish, let pub2 = f2 as? FramePublish { XCTAssertEqual(pub1.topic, pub2.topic) @@ -149,7 +205,7 @@ class CocoaMQTTDeliverTests: XCTestCase { let rel2 = f2 as? FramePubRel{ XCTAssertEqual(rel1.msgid, rel2.msgid) } else { - XCTAssert(false) + XCTAssert(false, "Assert equal failed line: \(lines)") } } @@ -160,16 +216,29 @@ class CocoaMQTTDeliverTests: XCTestCase { private class Caller: CocoaMQTTDeliverProtocol { + private let delegate_queue_key = DispatchSpecificKey() + private let delegate_queue_val = "_custom_delegate_queue_" + var delegateQueue: DispatchQueue var frames = [Frame]() init() { delegateQueue = DispatchQueue(label: "caller.deliver.test") + delegateQueue.setSpecific(key: delegate_queue_key, value: delegate_queue_val) + } + + func reset() { + frames = [] } func deliver(_ deliver: CocoaMQTTDeliver, wantToSend frame: Frame) { + assert_in_del_queue() frames.append(frame) } + + private func assert_in_del_queue() { + XCTAssertEqual(delegate_queue_val, DispatchQueue.getSpecific(key: delegate_queue_key)) + } } diff --git a/CocoaMQTTTests/CocoaMQTTStorageTests.swift b/CocoaMQTTTests/CocoaMQTTStorageTests.swift index 9ced0193..772042a0 100644 --- a/CocoaMQTTTests/CocoaMQTTStorageTests.swift +++ b/CocoaMQTTTests/CocoaMQTTStorageTests.swift @@ -42,6 +42,14 @@ class CocoaMQTTStorageTests: XCTestCase { for i in 0 ..< should.count { assertEqual(should[i], saved?[i]) } + + let taken = storage?.takeAll() + XCTAssertEqual(should.count, taken?.count) + for i in 0 ..< should.count { + assertEqual(should[i], taken?[i]) + } + + XCTAssertEqual(storage?.readAll().count, 0) } private func assertEqual(_ f1: Frame?, _ f2: Frame?) { diff --git a/Source/CocoaMQTT.swift b/Source/CocoaMQTT.swift index 1c7aedea..d6837833 100644 --- a/Source/CocoaMQTT.swift +++ b/Source/CocoaMQTT.swift @@ -623,32 +623,15 @@ extension CocoaMQTT: CocoaMQTTReaderDelegate { func didRecevied(_ reader: CocoaMQTTReader, connack: FrameConnAck) { printDebug("RECV: \(connack)") - switch connack.returnCode { - case .accept: - connState = .connected - default: - connState = .disconnected - internal_disconnect() - return - } - - // TODO: how to handle the cleanSession = false & auto-reconnect - if cleanSession { - deliver.cleanAll() - } - - delegate?.mqtt(self, didConnectAck: connack.returnCode) - didConnectAck(self, connack.returnCode) - - // reset auto-reconnect state if connack.returnCode == .accept { + + // Disable auto-reconnect + reconectTimeInterval = 0 autoReconnTimer = nil is_internal_disconnected = false - } - - // keep alive - if connack.returnCode == .accept { + + // Start keepalive timer let interval = Double(keepAlive <= 0 ? 60: keepAlive) @@ -662,7 +645,28 @@ extension CocoaMQTT: CocoaMQTTReaderDelegate { wself.ping() } } + + // recover session if enable + + if cleanSession { + deliver.cleanAll() + } else { + if let storage = CocoaMQTTStorage(by: clientID) { + deliver.recoverSessionBy(storage) + } else { + printWarning("Localstorage initial failed for key: \(clientID)") + } + } + + connState = .connected + + } else { + connState = .disconnected + internal_disconnect() } + + delegate?.mqtt(self, didConnectAck: connack.returnCode) + didConnectAck(self, connack.returnCode) } func didRecevied(_ reader: CocoaMQTTReader, publish: FramePublish) { diff --git a/Source/CocoaMQTTDeliver.swift b/Source/CocoaMQTTDeliver.swift index 822176e7..8b85b7b1 100644 --- a/Source/CocoaMQTTDeliver.swift +++ b/Source/CocoaMQTTDeliver.swift @@ -73,6 +73,31 @@ class CocoaMQTTDeliver: NSObject { var isInflightFull: Bool { get { return inflight.count >= inflightWindowSize }} var isInflightEmpty: Bool { get { return inflight.count == 0 }} + var storage: CocoaMQTTStorage? + + func recoverSessionBy(_ storage: CocoaMQTTStorage) { + + let frames = storage.takeAll() + guard frames.count >= 0 else { + return + } + + // Sync to push the frame to mqueue for avoiding overcommit + deliverQueue.sync { + for f in frames { + mqueue.append(f) + } + self.storage = storage + printInfo("Deliver recvoer \(frames.count) msgs") + printDebug("Recover message \(frames)") + } + + deliverQueue.async { [weak self] in + guard let wself = self else { return } + wself.tryTransport() + } + } + /// Add a FramePublish to the message queue to wait for sending /// /// return false means the frame is rejected because of the buffer is full @@ -85,6 +110,7 @@ class CocoaMQTTDeliver: NSObject { // Sync to push the frame to mqueue for avoiding overcommit deliverQueue.sync { mqueue.append(frame) + _ = storage?.write(frame) } deliverQueue.async { [weak self] in @@ -110,6 +136,12 @@ class CocoaMQTTDeliver: NSObject { if acked.count == 0 { printWarning("Acknowledge by \(frame), but not found in inflight window") } else { + // TODO: ACK DONT DELETE PUBREL + for f in acked { + if frame is FramePubAck || frame is FramePubComp { + wself.storage?.remove(f) + } + } printDebug("Acknowledge frame id \(msgid) success, acked: \(acked)") wself.tryTransport() } @@ -208,6 +240,7 @@ extension CocoaMQTTDeliver { nframe.frame = pubrel nframe.timestamp = Date(timeIntervalSinceNow: 0).timeIntervalSince1970 + _ = storage?.write(pubrel) sendfun(pubrel) ackedFrames.append(publish) @@ -236,6 +269,11 @@ extension CocoaMQTTDeliver { printError("The deliver delegate is nil!!! the frame will be drop: \(frame)") return } + + if frame.qos == .qos0 { + if let p = frame as? FramePublish { storage?.remove(p) } + } + delegate.delegateQueue.async { delegate.deliver(self, wantToSend: frame) } diff --git a/Source/CocoaMQTTStorage.swift b/Source/CocoaMQTTStorage.swift index cdca936b..156d803c 100644 --- a/Source/CocoaMQTTStorage.swift +++ b/Source/CocoaMQTTStorage.swift @@ -68,26 +68,24 @@ final class CocoaMQTTStorage: CocoaMQTTStorageProtocol { userDefault.removeObject(forKey: key(frame.msgid)) } + func remove(_ frame: Frame) { + if let pub = frame as? FramePublish { + userDefault.removeObject(forKey: key(pub.msgid)) + } else if let rel = frame as? FramePubRel { + userDefault.removeObject(forKey: key(rel.msgid)) + } + } + func synchronize() -> Bool { return userDefault.synchronize() } func readAll() -> [Frame] { - var frames = [Frame]() - let allObjs = userDefault.dictionaryRepresentation().sorted { (k1, k2) in - return k1.key < k2.key - } - for (_, v) in allObjs { - guard let bytes = v as? [UInt8] else { continue } - guard let parsed = parse(bytes) else { continue } - - if let f = FramePublish(fixedHeader: parsed.0, bytes: parsed.1) { - frames.append(f) - } else if let f = FramePubRel(fixedHeader: parsed.0, bytes: parsed.1) { - frames.append(f) - } - } - return frames + return __read(needDelete: false) + } + + func takeAll() -> [Frame] { + return __read(needDelete: true) } private func key(_ msgid: UInt16) -> String { @@ -108,4 +106,27 @@ final class CocoaMQTTStorage: CocoaMQTTStorageProtocol { return nil } + + private func __read(needDelete: Bool) -> [Frame] { + var frames = [Frame]() + let allObjs = userDefault.dictionaryRepresentation().sorted { (k1, k2) in + return k1.key < k2.key + } + for (k, v) in allObjs { + guard let bytes = v as? [UInt8] else { continue } + guard let parsed = parse(bytes) else { continue } + + if needDelete { + userDefault.removeObject(forKey: k) + } + + if let f = FramePublish(fixedHeader: parsed.0, bytes: parsed.1) { + frames.append(f) + } else if let f = FramePubRel(fixedHeader: parsed.0, bytes: parsed.1) { + frames.append(f) + } + } + return frames + } + }