Skip to content

Commit

Permalink
Manual acknowledgments
Browse files Browse the repository at this point in the history
  • Loading branch information
fracasula committed Apr 26, 2021
1 parent 7bbf6d3 commit 6be477c
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 35 deletions.
3 changes: 1 addition & 2 deletions packets/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@ func (p *Publish) Buffers() net.Buffers {
var b bytes.Buffer
writeString(p.Topic, &b)
if p.QoS > 0 {
writeUint16(p.PacketID, &b)
_ = writeUint16(p.PacketID, &b)
}
idvp := p.Properties.Pack(PUBLISH)
encodeVBIdirect(len(idvp), &b)
return net.Buffers{b.Bytes(), idvp, p.Payload}

}

// WriteTo is the implementation of the interface required function for a packet
Expand Down
74 changes: 74 additions & 0 deletions paho/acks_tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package paho

import (
"errors"
"sync"

"github.com/eclipse/paho.golang/packets"
)

var (
ErrPacketNotFound = errors.New("packet not found")
)

type acksTracker struct {
mx sync.Mutex
order []packet
}

func (t *acksTracker) add(pb *packets.Publish) {
t.mx.Lock()
defer t.mx.Unlock()

// @TODO don't just add, check for duplicates and skip if already acked
t.order = append(t.order, packet{pb: pb})
}

func (t *acksTracker) markAsAcked(pb *packets.Publish) error {
t.mx.Lock()
defer t.mx.Unlock()

for k, v := range t.order {
if pb.PacketID == v.pb.PacketID {
t.order[k].acknowledged = true
return nil
}
}

return ErrPacketNotFound
}

func (t *acksTracker) flush(do func([]*packets.Publish)) {
t.mx.Lock()
defer t.mx.Unlock()

var (
buf []*packets.Publish
)
for _, v := range t.order {
if v.acknowledged {
buf = append(buf, v.pb)
} else {
break
}
}

if len(buf) == 0 {
return
}

do(buf)
t.order = t.order[len(buf):]
}

// reset should be used upon disconnections
func (t *acksTracker) reset() {
t.mx.Lock()
defer t.mx.Unlock()
t.order = nil
}

type packet struct {
pb *packets.Publish
acknowledged bool
}
79 changes: 79 additions & 0 deletions paho/acks_tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package paho

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/eclipse/paho.golang/packets"
)

func TestAcksTracker(t *testing.T) {
var (
at acksTracker
p1 = &packets.Publish{PacketID: 1}
p2 = &packets.Publish{PacketID: 2}
p3 = &packets.Publish{PacketID: 3}
p4 = &packets.Publish{PacketID: 4} // to test not found
)

t.Run("flush-without-acking", func(t *testing.T) {
at.add(p1)
at.add(p2)
at.add(p3)
require.Equal(t, ErrPacketNotFound, at.markAsAcked(p4))
at.flush(func(_ []*packets.Publish) {
t.Fatal("flush should not call 'do' since no packets have been acknowledged so far")
})
})

t.Run("ack-in-the-middle", func(t *testing.T) {
require.NoError(t, at.markAsAcked(p3))
at.flush(func(_ []*packets.Publish) {
t.Fatal("flush should not call 'do' since p1 and p2 have not been acknowledged yet")
})
})

t.Run("idempotent-acking", func(t *testing.T) {
require.NoError(t, at.markAsAcked(p3))
require.NoError(t, at.markAsAcked(p3))
require.NoError(t, at.markAsAcked(p3))
})

t.Run("ack-first", func(t *testing.T) {
var flushCalled bool
require.NoError(t, at.markAsAcked(p1))
at.flush(func(pbs []*packets.Publish) {
require.Equal(t, []*packets.Publish{p1}, pbs, "Only p1 expected even though p3 was acked, p2 is still missing")
flushCalled = true
})
require.True(t, flushCalled)
})

t.Run("ack-after-flush", func(t *testing.T) {
var flushCalled bool
require.NoError(t, at.markAsAcked(p2))
at.add(p4) // this should just be appended and not flushed (yet)
at.flush(func(pbs []*packets.Publish) {
require.Equal(t, []*packets.Publish{p2, p3}, pbs, "Only p2 and p3 expected, p1 was flushed in the previous call")
flushCalled = true
})
require.True(t, flushCalled)
})

t.Run("ack-last", func(t *testing.T) {
var flushCalled bool
require.NoError(t, at.markAsAcked(p4))
at.flush(func(pbs []*packets.Publish) {
require.Equal(t, []*packets.Publish{p4}, pbs, "Only p4 expected, the rest was flushed in previous calls")
flushCalled = true
})
require.True(t, flushCalled)
})

t.Run("flush-after-acking-everything", func(t *testing.T) {
at.flush(func(_ []*packets.Publish) {
t.Fatal("no call to 'do' expected, we flushed all packets")
})
})
}
123 changes: 90 additions & 33 deletions paho/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,31 @@ type (
// are required to be set, defaults are provided for Persistence, MIDs,
// PingHandler, PacketTimeout and Router.
ClientConfig struct {
ClientID string
Conn net.Conn
MIDs MIDService
AuthHandler Auther
PingHandler Pinger
Router Router
Persistence Persistence
PacketTimeout time.Duration
OnServerDisconnect func(*Disconnect)
ClientID string
Conn net.Conn
MIDs MIDService
AuthHandler Auther
PingHandler Pinger
Router Router
Persistence Persistence
PacketTimeout time.Duration
// Only called when receiving packets.DISCONNECT from server
OnClientError func(error)
OnServerDisconnect func(*Disconnect)
// Client error call, For example: net.Error
PublishHook func(*Publish)
OnClientError func(error)
// PublishHook allows a user provided function to be called before
// a Publish packet is sent allowing it to inspect or modify the
// Publish, an example of the utility of this is provided in the
// Topic Alias Handler extension which will automatically assign
// and use topic alias values rather than topic strings.
PublishHook func(*Publish)
// EnableManualAcknowledgment is used to control the acknowledgment of packets manually.
// BEWARE that the MQTT specs require clients to send acknowledgments in the order in which the corresponding
// PUBLISH packets were received.
// Consider the following scenario: the client receives packets 1,2,3,4
// If you acknowledge 3 first, no ack is actually sent to the server but it's buffered until also 1 and 2
// are acknowledged.
EnableManualAcknowledgment bool
}
// Client is the struct representing an MQTT client
Client struct {
Expand All @@ -56,6 +63,7 @@ type (
raCtx *CPContext
stop chan struct{}
publishPackets chan *packets.Publish
acksTracker acksTracker
workers sync.WaitGroup
serverProps CommsProperties
clientProps CommsProperties
Expand Down Expand Up @@ -203,6 +211,27 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) {
c.routePublishPackets()
}()

if c.EnableManualAcknowledgment {
c.workers.Add(1)
go func() {
defer c.workers.Done()
defer c.debug.Println("returning for ack tracker routine")
t := time.NewTicker(10 * time.Millisecond) // @TODO ticker should be configurable
for {
select {
case <-c.stop:
return
case <-t.C:
c.acksTracker.flush(func(pbs []*packets.Publish) {
for _, pb := range pbs {
c.ack(pb) // @TODO handle potential error from c.ack(pb)?
}
})
}
}
}()
}

c.debug.Println("starting Incoming")
c.workers.Add(1)
go func() {
Expand Down Expand Up @@ -301,29 +330,55 @@ func (c *Client) routePublishPackets() {
if !open {
return
}
c.Router.Route(pb)
switch pb.QoS {
case 1:
pa := packets.Puback{
Properties: &packets.Properties{},
PacketID: pb.PacketID,
}
c.debug.Println("sending PUBACK")
_, err := pa.WriteTo(c.Conn)
if err != nil {
c.errors.Printf("failed to send PUBACK for %d: %s", pb.PacketID, err)
}
case 2:
pr := packets.Pubrec{
Properties: &packets.Properties{},
PacketID: pb.PacketID,
}
c.debug.Printf("sending PUBREC")
_, err := pr.WriteTo(c.Conn)
if err != nil {
c.errors.Printf("failed to send PUBREC for %d: %s", pb.PacketID, err)
}

if !c.ClientConfig.EnableManualAcknowledgment {
c.Router.Route(pb)
c.ack(pb)
continue
}

if pb.QoS == 0 {
continue
}

c.acksTracker.add(pb)
c.Router.Route(pb)
}
}
}

func (c *Client) Ack(pb *packets.Publish) error {
if !c.EnableManualAcknowledgment {
// @TODO use variable? e.g. ErrManualAcknowledgmentDisabled
return fmt.Errorf("cannot ack with manual acknowledgment is disabled")
}
if pb.QoS == 0 {
return nil
}
return c.acksTracker.markAsAcked(pb)
}

func (c *Client) ack(pb *packets.Publish) {
switch pb.QoS {
case 1:
pa := packets.Puback{
Properties: &packets.Properties{},
PacketID: pb.PacketID,
}
c.debug.Println("sending PUBACK")
_, err := pa.WriteTo(c.Conn)
if err != nil {
c.errors.Printf("failed to send PUBACK for %d: %s", pb.PacketID, err)
}
case 2:
pr := packets.Pubrec{
Properties: &packets.Properties{},
PacketID: pb.PacketID,
}
c.debug.Printf("sending PUBREC")
_, err := pr.WriteTo(c.Conn)
if err != nil {
c.errors.Printf("failed to send PUBREC for %d: %s", pb.PacketID, err)
}
}
}
Expand Down Expand Up @@ -464,6 +519,8 @@ func (c *Client) close() {
c.debug.Println("ping stopped")
_ = c.Conn.Close()
c.debug.Println("conn closed")
// @TODO check if c.close() is always called on disconnections (client & server)
c.acksTracker.reset() // upon reconnection the unacked messages will be redelivered
}

// Error is called to signify that an error situation has occurred, this
Expand Down

0 comments on commit 6be477c

Please sign in to comment.