Skip to content

Commit

Permalink
Support for cleansession=false (#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
HJianBo committed Dec 29, 2019
1 parent 0b311a1 commit a652cb7
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 39 deletions.
73 changes: 71 additions & 2 deletions CocoaMQTTTests/CocoaMQTTDeliverTests.swift
Expand Up @@ -131,13 +131,69 @@ class CocoaMQTTDeliverTests: XCTestCase {
}
}

func testStorage() {

let clientID = "deliver-unit-testing"
let caller = Caller()
let deliver = CocoaMQTTDeliver()

let frames = [FramePublish(topic: "t/0", payload: [0x00], qos: .qos0),
FramePublish(topic: "t/1", payload: [0x01], qos: .qos1, msgid: 1),
FramePublish(topic: "t/2", payload: [0x02], qos: .qos2, msgid: 2)]

guard let storage = CocoaMQTTStorage(by: clientID) else {
XCTAssert(false, "Initial storage failed")
return
}

deliver.delegate = caller
deliver.recoverSessionBy(storage)

for f in frames {
_ = deliver.add(f)
}

var saved = storage.readAll()
XCTAssertEqual(saved.count, 2)


deliver.ack(by: FramePubAck(msgid: 1))
ms_sleep(100)
saved = storage.readAll()
XCTAssertEqual(saved.count, 1)

deliver.ack(by: FramePubRec(msgid: 2))
ms_sleep(100)
saved = storage.readAll()
XCTAssertEqual(saved.count, 1)
assertEqual(saved[0], FramePubRel(msgid: 2))


deliver.ack(by: FramePubComp(msgid: 2))
ms_sleep(100)
saved = storage.readAll()
XCTAssertEqual(saved.count, 0)

caller.reset()
_ = storage.write(frames[1])
deliver.recoverSessionBy(storage)
ms_sleep(100)
XCTAssertEqual(caller.frames.count, 1)
assertEqual(caller.frames[0], frames[1])


deliver.ack(by: FramePubAck(msgid: 1))
ms_sleep(100)
XCTAssertEqual(storage.readAll().count, 0)
}

func testTODO() {
// TODO: How to test large of messages combined qos0/qos1/qos2
}


// Helper for assert equality for Frame
private func assertEqual(_ f1: Frame, _ f2: Frame) {
private func assertEqual(_ f1: Frame, _ f2: Frame, _ lines: Int = #line) {
if let pub1 = f1 as? FramePublish,
let pub2 = f2 as? FramePublish {
XCTAssertEqual(pub1.topic, pub2.topic)
Expand All @@ -149,7 +205,7 @@ class CocoaMQTTDeliverTests: XCTestCase {
let rel2 = f2 as? FramePubRel{
XCTAssertEqual(rel1.msgid, rel2.msgid)
} else {
XCTAssert(false)
XCTAssert(false, "Assert equal failed line: \(lines)")
}
}

Expand All @@ -160,16 +216,29 @@ class CocoaMQTTDeliverTests: XCTestCase {

private class Caller: CocoaMQTTDeliverProtocol {

private let delegate_queue_key = DispatchSpecificKey<String>()
private let delegate_queue_val = "_custom_delegate_queue_"

var delegateQueue: DispatchQueue

var frames = [Frame]()

init() {
delegateQueue = DispatchQueue(label: "caller.deliver.test")
delegateQueue.setSpecific(key: delegate_queue_key, value: delegate_queue_val)
}

func reset() {
frames = []
}

func deliver(_ deliver: CocoaMQTTDeliver, wantToSend frame: Frame) {
assert_in_del_queue()

frames.append(frame)
}

private func assert_in_del_queue() {
XCTAssertEqual(delegate_queue_val, DispatchQueue.getSpecific(key: delegate_queue_key))
}
}
8 changes: 8 additions & 0 deletions CocoaMQTTTests/CocoaMQTTStorageTests.swift
Expand Up @@ -42,6 +42,14 @@ class CocoaMQTTStorageTests: XCTestCase {
for i in 0 ..< should.count {
assertEqual(should[i], saved?[i])
}

let taken = storage?.takeAll()
XCTAssertEqual(should.count, taken?.count)
for i in 0 ..< should.count {
assertEqual(should[i], taken?[i])
}

XCTAssertEqual(storage?.readAll().count, 0)
}

private func assertEqual(_ f1: Frame?, _ f2: Frame?) {
Expand Down
48 changes: 26 additions & 22 deletions Source/CocoaMQTT.swift
Expand Up @@ -623,32 +623,15 @@ extension CocoaMQTT: CocoaMQTTReaderDelegate {
func didRecevied(_ reader: CocoaMQTTReader, connack: FrameConnAck) {
printDebug("RECV: \(connack)")

switch connack.returnCode {
case .accept:
connState = .connected
default:
connState = .disconnected
internal_disconnect()
return
}

// TODO: how to handle the cleanSession = false & auto-reconnect
if cleanSession {
deliver.cleanAll()
}

delegate?.mqtt(self, didConnectAck: connack.returnCode)
didConnectAck(self, connack.returnCode)

// reset auto-reconnect state
if connack.returnCode == .accept {

// Disable auto-reconnect

reconectTimeInterval = 0
autoReconnTimer = nil
is_internal_disconnected = false
}

// keep alive
if connack.returnCode == .accept {

// Start keepalive timer

let interval = Double(keepAlive <= 0 ? 60: keepAlive)

Expand All @@ -662,7 +645,28 @@ extension CocoaMQTT: CocoaMQTTReaderDelegate {
wself.ping()
}
}

// recover session if enable

if cleanSession {
deliver.cleanAll()
} else {
if let storage = CocoaMQTTStorage(by: clientID) {
deliver.recoverSessionBy(storage)
} else {
printWarning("Localstorage initial failed for key: \(clientID)")
}
}

connState = .connected

} else {
connState = .disconnected
internal_disconnect()
}

delegate?.mqtt(self, didConnectAck: connack.returnCode)
didConnectAck(self, connack.returnCode)
}

func didRecevied(_ reader: CocoaMQTTReader, publish: FramePublish) {
Expand Down
38 changes: 38 additions & 0 deletions Source/CocoaMQTTDeliver.swift
Expand Up @@ -73,6 +73,31 @@ class CocoaMQTTDeliver: NSObject {
var isInflightFull: Bool { get { return inflight.count >= inflightWindowSize }}
var isInflightEmpty: Bool { get { return inflight.count == 0 }}

var storage: CocoaMQTTStorage?

func recoverSessionBy(_ storage: CocoaMQTTStorage) {

let frames = storage.takeAll()
guard frames.count >= 0 else {
return
}

// Sync to push the frame to mqueue for avoiding overcommit
deliverQueue.sync {
for f in frames {
mqueue.append(f)
}
self.storage = storage
printInfo("Deliver recvoer \(frames.count) msgs")
printDebug("Recover message \(frames)")
}

deliverQueue.async { [weak self] in
guard let wself = self else { return }
wself.tryTransport()
}
}

/// Add a FramePublish to the message queue to wait for sending
///
/// return false means the frame is rejected because of the buffer is full
Expand All @@ -85,6 +110,7 @@ class CocoaMQTTDeliver: NSObject {
// Sync to push the frame to mqueue for avoiding overcommit
deliverQueue.sync {
mqueue.append(frame)
_ = storage?.write(frame)
}

deliverQueue.async { [weak self] in
Expand All @@ -110,6 +136,12 @@ class CocoaMQTTDeliver: NSObject {
if acked.count == 0 {
printWarning("Acknowledge by \(frame), but not found in inflight window")
} else {
// TODO: ACK DONT DELETE PUBREL
for f in acked {
if frame is FramePubAck || frame is FramePubComp {
wself.storage?.remove(f)
}
}
printDebug("Acknowledge frame id \(msgid) success, acked: \(acked)")
wself.tryTransport()
}
Expand Down Expand Up @@ -208,6 +240,7 @@ extension CocoaMQTTDeliver {
nframe.frame = pubrel
nframe.timestamp = Date(timeIntervalSinceNow: 0).timeIntervalSince1970

_ = storage?.write(pubrel)
sendfun(pubrel)

ackedFrames.append(publish)
Expand Down Expand Up @@ -236,6 +269,11 @@ extension CocoaMQTTDeliver {
printError("The deliver delegate is nil!!! the frame will be drop: \(frame)")
return
}

if frame.qos == .qos0 {
if let p = frame as? FramePublish { storage?.remove(p) }
}

delegate.delegateQueue.async {
delegate.deliver(self, wantToSend: frame)
}
Expand Down
51 changes: 36 additions & 15 deletions Source/CocoaMQTTStorage.swift
Expand Up @@ -68,26 +68,24 @@ final class CocoaMQTTStorage: CocoaMQTTStorageProtocol {
userDefault.removeObject(forKey: key(frame.msgid))
}

func remove(_ frame: Frame) {
if let pub = frame as? FramePublish {
userDefault.removeObject(forKey: key(pub.msgid))
} else if let rel = frame as? FramePubRel {
userDefault.removeObject(forKey: key(rel.msgid))
}
}

func synchronize() -> Bool {
return userDefault.synchronize()
}

func readAll() -> [Frame] {
var frames = [Frame]()
let allObjs = userDefault.dictionaryRepresentation().sorted { (k1, k2) in
return k1.key < k2.key
}
for (_, v) in allObjs {
guard let bytes = v as? [UInt8] else { continue }
guard let parsed = parse(bytes) else { continue }

if let f = FramePublish(fixedHeader: parsed.0, bytes: parsed.1) {
frames.append(f)
} else if let f = FramePubRel(fixedHeader: parsed.0, bytes: parsed.1) {
frames.append(f)
}
}
return frames
return __read(needDelete: false)
}

func takeAll() -> [Frame] {
return __read(needDelete: true)
}

private func key(_ msgid: UInt16) -> String {
Expand All @@ -108,4 +106,27 @@ final class CocoaMQTTStorage: CocoaMQTTStorageProtocol {

return nil
}

private func __read(needDelete: Bool) -> [Frame] {
var frames = [Frame]()
let allObjs = userDefault.dictionaryRepresentation().sorted { (k1, k2) in
return k1.key < k2.key
}
for (k, v) in allObjs {
guard let bytes = v as? [UInt8] else { continue }
guard let parsed = parse(bytes) else { continue }

if needDelete {
userDefault.removeObject(forKey: k)
}

if let f = FramePublish(fixedHeader: parsed.0, bytes: parsed.1) {
frames.append(f)
} else if let f = FramePubRel(fixedHeader: parsed.0, bytes: parsed.1) {
frames.append(f)
}
}
return frames
}

}

0 comments on commit a652cb7

Please sign in to comment.